# fischer-agentkit/src/agentkit/bus/memory_bus.py
"""InMemoryMessageBus — 基于 asyncio.Queue 的内存消息总线。
用于开发和测试,行为与 Redis 实现一致。
集成 CascadeDetector 和 AlignmentGuard 进行消息质量管控。
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Callable, Awaitable
from agentkit.bus.message import AgentMessage
logger = logging.getLogger(__name__)
class InMemoryMessageBus:
    """In-process message bus backed by ``asyncio.Queue``.

    Each subscribed agent gets one queue and one consumer task; the consumer
    delivers every queued message to all handlers registered for that agent.
    Outgoing messages may be screened by an optional cascade detector and an
    optional alignment guard before delivery.
    """

    def __init__(
        self,
        cascade_detector: Any = None,
        alignment_guard: Any = None,
    ) -> None:
        """Create an empty bus.

        Args:
            cascade_detector: Optional object with
                ``check_interaction(session_id=...)``; a truthy return value
                blocks the message.
            alignment_guard: Optional object with an awaitable
                ``check_output(output=...)`` whose result exposes ``passed``
                and ``violations``; only request/negotiate messages are checked.
        """
        # agent name -> handlers invoked for messages addressed to that agent
        self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
        # correlation id -> future resolved by the first reply to a request()
        self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
        # correlation id -> message_id of the waiting request, so the request
        # itself is never mistaken for its own reply
        self._pending_origins: dict[str, str] = {}
        self._queues: dict[str, asyncio.Queue[AgentMessage]] = {}
        self._consumer_tasks: dict[str, list[asyncio.Task]] = {}
        self._cascade_detector = cascade_detector
        self._alignment_guard = alignment_guard

    async def publish(self, message: AgentMessage) -> bool:
        """Publish ``message``; return False if it was rejected.

        The message is rejected when it has expired, when the cascade detector
        raises an alert, or when the alignment guard blocks it. A True result
        does not mean a subscriber received it: a point-to-point message for an
        agent with no subscription is only used to resolve a pending request.
        """
        if message.is_expired():
            logger.warning(f"Message {message.message_id} expired, dropping")
            return False

        # Cascade-failure check: any truthy alert blocks the message.
        if self._cascade_detector and message.sender:
            alert = self._cascade_detector.check_interaction(
                session_id=f"bus-{message.sender}-{message.recipient or 'broadcast'}"
            )
            if alert:
                logger.warning(f"Cascade alert: {alert}")
                return False

        # Alignment guard: only request / negotiate messages are screened.
        if self._alignment_guard and message.msg_type in ("request", "negotiate"):
            check = await self._alignment_guard.check_output(
                output={"content": str(message.content), **message.payload},
            )
            if not check.passed:
                logger.warning(
                    f"Message blocked by alignment guard: {check.violations}"
                )
                return False

        if message.is_broadcast:
            await self.broadcast(message)
            return True

        recipient = message.recipient
        if recipient and recipient in self._queues:
            # Delivered later by the recipient's consumer task.
            await self._queues[recipient].put(message)
        elif recipient and recipient in self._subscribers:
            # Subscribed but without a queue: call the handlers inline.
            await self._run_handlers(
                list(self._subscribers[recipient]),
                message,
                f"Handler error for {recipient}",
            )
        # The message may be the reply a request() is waiting for.
        self._try_resolve_pending(message)
        return True

    @staticmethod
    async def _run_handlers(
        handlers: list[Callable[[AgentMessage], Awaitable[None]]],
        message: AgentMessage,
        error_prefix: str,
    ) -> None:
        """Await each handler in order, logging (not raising) handler errors."""
        for handler in handlers:
            try:
                await handler(message)
            except Exception as e:
                logger.warning(f"{error_prefix}: {e}")

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Register ``handler`` for messages addressed to ``agent_name``.

        The first subscription for an agent creates its queue and a single
        consumer task. Later subscriptions only add handlers, and every
        handler receives every queued message.
        """
        handlers = self._subscribers.get(agent_name)
        if handlers is None:
            handlers = self._subscribers[agent_name] = []
            self._queues[agent_name] = asyncio.Queue()
            self._consumer_tasks[agent_name] = [
                asyncio.create_task(self._consume_queue(agent_name))
            ]
        handlers.append(handler)

    async def _consume_queue(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]] | None = None,
    ) -> None:
        """Deliver messages from ``agent_name``'s queue until cancelled.

        Each message goes to every handler currently subscribed for the agent
        (or only to ``handler`` when one is passed explicitly). Pending
        requests are then checked, because a handler may have published a
        reply.
        """
        queue = self._queues.get(agent_name)
        if queue is None:
            return
        while True:
            try:
                message = await queue.get()
                try:
                    if handler is not None:
                        handlers = [handler]
                    else:
                        # Snapshot: handlers may (un)subscribe while we await.
                        handlers = list(self._subscribers.get(agent_name, ()))
                    await self._run_handlers(
                        handlers, message, f"Handler error for {agent_name}"
                    )
                    self._try_resolve_pending(message)
                finally:
                    # Keeps Queue.join() usable for callers that want to drain.
                    queue.task_done()
            except asyncio.CancelledError:
                break

    def _try_resolve_pending(self, message: AgentMessage) -> None:
        """Resolve the pending request future that ``message`` replies to, if any.

        A message counts as a reply when its ``correlation_id`` matches a
        pending request and it is not the originating request message.
        """
        correlation_id = message.correlation_id
        if not correlation_id:
            return
        future = self._pending_requests.get(correlation_id)
        if future is None or future.done():
            return
        origin_id = self._pending_origins.get(correlation_id, correlation_id)
        if message.message_id in (correlation_id, origin_id):
            return
        future.set_result(message)

    async def unsubscribe(self, agent_name: str) -> None:
        """Remove all handlers for ``agent_name`` and cancel its consumer task.

        Messages still queued for the agent are discarded with the queue.
        """
        self._subscribers.pop(agent_name, None)
        self._queues.pop(agent_name, None)
        for task in self._consumer_tasks.pop(agent_name, []):
            if not task.done():
                task.cancel()

    async def request(
        self,
        message: AgentMessage,
        timeout_seconds: float = 30.0,
    ) -> AgentMessage | None:
        """Send ``message`` as a request and wait for the correlated reply.

        Sets ``msg_type`` to "request" and, if unset, uses the message id as
        the correlation id. Returns the first other message carrying that
        correlation id, or None if publishing was rejected or no reply arrived
        within ``timeout_seconds``.
        """
        message.msg_type = "request"
        if not message.correlation_id:
            message.correlation_id = message.message_id
        correlation_id = message.correlation_id

        future: asyncio.Future[AgentMessage] = asyncio.get_running_loop().create_future()
        self._pending_requests[correlation_id] = future
        self._pending_origins[correlation_id] = message.message_id
        try:
            # Publishing inside the try block means the pending entry is
            # cleaned up even if a guard or detector raises.
            if not await self.publish(message):
                return None
            return await asyncio.wait_for(future, timeout=timeout_seconds)
        except asyncio.TimeoutError:
            logger.warning(f"Request {correlation_id} timed out")
            return None
        finally:
            # Only remove our own entry; a concurrent request may have
            # re-registered the same correlation id.
            if self._pending_requests.get(correlation_id) is future:
                del self._pending_requests[correlation_id]
                self._pending_origins.pop(correlation_id, None)

    async def broadcast(self, message: AgentMessage) -> None:
        """Deliver ``message`` inline to every handler of every subscriber.

        ``message.recipient`` is cleared first. Unlike ``publish``, no expiry,
        cascade, or alignment checks are applied here.
        """
        message.recipient = None
        # Snapshot both levels: a handler may (un)subscribe while awaited.
        for agent_name, handlers in list(self._subscribers.items()):
            await self._run_handlers(
                list(handlers), message, f"Broadcast handler error for {agent_name}"
            )
        # A broadcast can also be the reply to a pending request.
        self._try_resolve_pending(message)

    async def health_check(self) -> bool:
        """Always report healthy; this bus holds only in-process state."""
        return True

    @property
    def backend_type(self) -> str:
        """Identifier of this bus backend."""
        return "memory"