"""InMemoryMessageBus — 基于 asyncio.Queue 的内存消息总线。 用于开发和测试,行为与 Redis 实现一致。 集成 CascadeDetector 和 AlignmentGuard 进行消息质量管控。 """ from __future__ import annotations import asyncio import logging from typing import Any, Callable, Awaitable from agentkit.bus.message import AgentMessage logger = logging.getLogger(__name__) class InMemoryMessageBus: """基于 asyncio.Queue 的内存消息总线。""" def __init__( self, cascade_detector: Any = None, alignment_guard: Any = None, ) -> None: self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {} self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {} self._queues: dict[str, asyncio.Queue[AgentMessage]] = {} self._consumer_tasks: dict[str, list[asyncio.Task]] = {} self._cascade_detector = cascade_detector self._alignment_guard = alignment_guard async def publish(self, message: AgentMessage) -> bool: """发布消息,返回是否成功。""" # TTL 过期检查 if message.is_expired(): logger.warning(f"Message {message.message_id} expired, dropping") return False # Cascade detection — 级联故障检测 if self._cascade_detector and message.sender: alert = self._cascade_detector.check_interaction( session_id=f"bus-{message.sender}-{message.recipient or 'broadcast'}" ) if alert: logger.warning(f"Cascade alert: {alert}") return False # Alignment check — 对齐守卫检查(仅对 request / negotiate 类型) if self._alignment_guard and message.msg_type in ("request", "negotiate"): check = await self._alignment_guard.check_output( output={"content": str(message.content), **message.payload}, ) if not check.passed: logger.warning( f"Message blocked by alignment guard: {check.violations}" ) return False if message.is_broadcast: await self.broadcast(message) return True # Point-to-point: deliver to recipient's queue recipient = message.recipient if recipient and recipient in self._queues: await self._queues[recipient].put(message) elif recipient and recipient in self._subscribers: # No queue, call handlers directly for handler in self._subscribers[recipient]: try: await handler(message) except Exception as e: logger.warning(f"Handler error for {recipient}: {e}") # Check if this is a response to a pending request self._try_resolve_pending(message) return True async def subscribe( self, agent_name: str, handler: Callable[[AgentMessage], Awaitable[None]], ) -> None: """订阅消息。""" if agent_name not in self._subscribers: self._subscribers[agent_name] = [] self._queues[agent_name] = asyncio.Queue() self._consumer_tasks[agent_name] = [] self._subscribers[agent_name].append(handler) # Start consumer task and track it task = asyncio.create_task(self._consume_queue(agent_name, handler)) self._consumer_tasks[agent_name].append(task) async def _consume_queue( self, agent_name: str, handler: Callable[[AgentMessage], Awaitable[None]], ) -> None: """消费队列中的消息。""" queue = self._queues.get(agent_name) if queue is None: return while True: try: message = await queue.get() try: await handler(message) except Exception as e: logger.warning(f"Handler error for {agent_name}: {e}") # Check pending requests after handler processes the message # (e.g., handler may publish a response that resolves a future) self._try_resolve_pending(message) except asyncio.CancelledError: break def _try_resolve_pending(self, message: AgentMessage) -> None: """Try to resolve a pending request future if this message is a response.""" if ( message.correlation_id and message.correlation_id in self._pending_requests and message.message_id != message.correlation_id ): future = self._pending_requests[message.correlation_id] if not future.done(): future.set_result(message) async def unsubscribe(self, agent_name: str) -> None: """取消订阅。""" self._subscribers.pop(agent_name, None) self._queues.pop(agent_name, None) # Cancel tracked consumer tasks tasks = self._consumer_tasks.pop(agent_name, []) for task in tasks: if not task.done(): task.cancel() async def request( self, message: AgentMessage, timeout_seconds: float = 30.0, ) -> AgentMessage | None: """请求-响应模式。超时返回 None。""" message.msg_type = "request" if not message.correlation_id: message.correlation_id = message.message_id loop = asyncio.get_running_loop() future: asyncio.Future[AgentMessage] = loop.create_future() self._pending_requests[message.correlation_id] = future published = await self.publish(message) if not published: self._pending_requests.pop(message.correlation_id, None) return None try: return await asyncio.wait_for(future, timeout=timeout_seconds) except asyncio.TimeoutError: self._pending_requests.pop(message.correlation_id, None) logger.warning(f"Request {message.correlation_id} timed out") return None finally: self._pending_requests.pop(message.correlation_id, None) async def broadcast(self, message: AgentMessage) -> None: """广播消息。""" # Ensure recipient is None for broadcast message.recipient = None for agent_name, handlers in self._subscribers.items(): for handler in handlers: try: await handler(message) except Exception as e: logger.warning(f"Broadcast handler error for {agent_name}: {e}") # Check pending requests (only for replies) if ( message.correlation_id and message.correlation_id in self._pending_requests and message.message_id != message.correlation_id ): future = self._pending_requests[message.correlation_id] if not future.done(): future.set_result(message) async def health_check(self) -> bool: return True @property def backend_type(self) -> str: return "memory"