"""InMemoryMessageBus — an in-memory message bus built on asyncio.Queue.

Intended for development and testing; it is meant to behave the same as the
Redis-backed implementation. Integrates CascadeDetector and AlignmentGuard to
enforce message quality.
"""

from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Callable, Awaitable
|
|
|
|
from agentkit.bus.message import AgentMessage
|
|
|
|
# Module-level logger: InMemoryMessageBus reports dropped/blocked messages and
# handler errors through it.
logger = logging.getLogger(name=__name__)


class InMemoryMessageBus:
    """In-memory message bus built on ``asyncio.Queue``.

    Point-to-point messages are put on the recipient's queue and consumed by
    one background task per registered handler; broadcasts call every
    subscribed handler directly. Optional ``cascade_detector`` and
    ``alignment_guard`` hooks can veto messages in :meth:`publish`.
    """

    def __init__(
        self,
        cascade_detector: Any = None,
        alignment_guard: Any = None,
    ) -> None:
        """Create an empty bus.

        Args:
            cascade_detector: Optional object exposing
                ``check_interaction(session_id=...)``; a truthy return value
                blocks the message.
            alignment_guard: Optional object exposing an awaitable
                ``check_output(output=...)`` whose result has ``passed`` and
                ``violations`` attributes.
        """
        self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
        # correlation_id -> future completed by the first matching reply.
        self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
        # correlation_id -> message_id of the request that opened it, so the
        # request message itself is never mistaken for its own reply.
        self._pending_request_ids: dict[str, str] = {}
        self._queues: dict[str, asyncio.Queue[AgentMessage]] = {}
        self._consumer_tasks: dict[str, list[asyncio.Task]] = {}
        self._cascade_detector = cascade_detector
        self._alignment_guard = alignment_guard

    async def publish(self, message: AgentMessage) -> bool:
        """Publish ``message`` and report whether it was accepted.

        Returns ``False`` (delivering nothing) when the message is expired,
        the cascade detector raises an alert, or the alignment guard rejects a
        ``request``/``negotiate`` message. Otherwise the message is broadcast
        (if ``message.is_broadcast``) or handed to its recipient and ``True``
        is returned. A point-to-point message whose recipient has no
        subscription is dropped but still reported as ``True``.
        """
        # TTL expiry check.
        if message.is_expired():
            logger.warning("Message %s expired, dropping", message.message_id)
            return False

        # Cascade-failure detection, keyed per sender/recipient pair.
        if self._cascade_detector and message.sender:
            alert = self._cascade_detector.check_interaction(
                session_id=f"bus-{message.sender}-{message.recipient or 'broadcast'}"
            )
            if alert:
                logger.warning("Cascade alert: %s", alert)
                return False

        # Alignment guard: only request / negotiate messages are checked.
        if self._alignment_guard and message.msg_type in ("request", "negotiate"):
            check = await self._alignment_guard.check_output(
                output={"content": str(message.content), **message.payload},
            )
            if not check.passed:
                logger.warning("Message blocked by alignment guard: %s", check.violations)
                return False

        if message.is_broadcast:
            await self.broadcast(message)
            return True

        # Point-to-point: deliver to the recipient's queue.
        recipient = message.recipient
        if recipient and recipient in self._queues:
            await self._queues[recipient].put(message)
        elif recipient and recipient in self._subscribers:
            # Subscriber without a queue: call its handlers inline. Iterate a
            # snapshot so a handler that (un)subscribes cannot disturb the loop.
            for handler in list(self._subscribers[recipient]):
                try:
                    await handler(message)
                except Exception as e:
                    logger.warning("Handler error for %s: %s", recipient, e, exc_info=True)

        # The message may be the reply to an outstanding request().
        self._try_resolve_pending(message)
        return True

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Register ``handler`` for ``agent_name`` and start a consumer task for it.

        The first subscription for an agent creates its queue; every call
        starts one more consumer task reading from that shared queue.

        NOTE(review): with several handlers for one agent, each queued message
        is handled by whichever consumer task dequeues it first, whereas
        broadcast() calls every handler — confirm this is intended.
        """
        if agent_name not in self._subscribers:
            self._subscribers[agent_name] = []
            self._queues[agent_name] = asyncio.Queue()
            self._consumer_tasks[agent_name] = []
        self._subscribers[agent_name].append(handler)

        # Keep a reference to the task so unsubscribe() can cancel it.
        task = asyncio.create_task(self._consume_queue(agent_name, handler))
        self._consumer_tasks[agent_name].append(task)

    async def _consume_queue(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Feed messages from ``agent_name``'s queue to ``handler`` until cancelled."""
        queue = self._queues.get(agent_name)
        if queue is None:
            return

        while True:
            try:
                message = await queue.get()
                try:
                    await handler(message)
                except Exception as e:
                    logger.warning("Handler error for %s: %s", agent_name, e, exc_info=True)

                # The dequeued message may itself be a reply to a pending
                # request; this is a no-op if publish() already resolved it.
                self._try_resolve_pending(message)
            except asyncio.CancelledError:
                # unsubscribe() cancels this task; exit quietly.
                break

    def _try_resolve_pending(self, message: AgentMessage) -> None:
        """Complete the pending request future that ``message`` replies to, if any.

        A message counts as a reply when its ``correlation_id`` names a pending
        request and it is not that request message itself (its ``message_id``
        differs from both the correlation id and the originating request id).
        """
        correlation_id = message.correlation_id
        if not correlation_id:
            return
        future = self._pending_requests.get(correlation_id)
        if future is None or future.done():
            return
        if message.message_id in (
            correlation_id,
            self._pending_request_ids.get(correlation_id),
        ):
            # This is the request being delivered, not its reply.
            return
        future.set_result(message)

    async def unsubscribe(self, agent_name: str) -> None:
        """Remove all handlers for ``agent_name`` and cancel its consumer tasks.

        Messages still waiting in the agent's queue are discarded.
        """
        self._subscribers.pop(agent_name, None)
        self._queues.pop(agent_name, None)
        for task in self._consumer_tasks.pop(agent_name, []):
            if not task.done():
                task.cancel()

    async def request(
        self,
        message: AgentMessage,
        timeout_seconds: float = 30.0,
    ) -> AgentMessage | None:
        """Publish ``message`` as a request and wait for the correlated reply.

        ``message`` is mutated: ``msg_type`` is set to ``"request"`` and, if it
        has no ``correlation_id``, its ``message_id`` is used. The first other
        message the bus sees carrying that ``correlation_id`` is the reply.

        Returns:
            The reply, or ``None`` if publishing was rejected or no reply
            arrived within ``timeout_seconds``.

        Raises:
            Any exception raised by :meth:`publish` (e.g. from a guard hook);
            the pending bookkeeping is cleaned up before it propagates.
        """
        message.msg_type = "request"
        if not message.correlation_id:
            message.correlation_id = message.message_id
        correlation_id = message.correlation_id

        future: asyncio.Future[AgentMessage] = asyncio.get_running_loop().create_future()
        self._pending_requests[correlation_id] = future
        self._pending_request_ids[correlation_id] = message.message_id

        try:
            if not await self.publish(message):
                return None
            try:
                return await asyncio.wait_for(future, timeout=timeout_seconds)
            except asyncio.TimeoutError:
                logger.warning("Request %s timed out", correlation_id)
                return None
        finally:
            # Always drop our bookkeeping (also when publish() raises), but never
            # an entry a concurrent request with the same id has since installed.
            if self._pending_requests.get(correlation_id) is future:
                self._pending_requests.pop(correlation_id, None)
                self._pending_request_ids.pop(correlation_id, None)

    async def broadcast(self, message: AgentMessage) -> None:
        """Deliver ``message`` to every handler of every subscribed agent.

        ``message.recipient`` is cleared first. Handlers are awaited one after
        another; a failing handler is logged and does not stop delivery. Unlike
        :meth:`publish`, no expiry or guard checks run here.
        """
        # Ensure recipient is None for broadcast.
        message.recipient = None

        # Snapshot the subscribers: a handler may subscribe/unsubscribe while it
        # is awaited, and mutating the dict mid-iteration raises RuntimeError.
        snapshot = [(name, list(handlers)) for name, handlers in self._subscribers.items()]
        for agent_name, handlers in snapshot:
            for handler in handlers:
                try:
                    await handler(message)
                except Exception as e:
                    logger.warning(
                        "Broadcast handler error for %s: %s", agent_name, e, exc_info=True
                    )

        # A broadcast reply can also complete a pending request.
        self._try_resolve_pending(message)

    async def health_check(self) -> bool:
        """Report bus health; the in-memory bus always reports healthy."""
        return True

    @property
    def backend_type(self) -> str:
        """Identifier of this bus backend."""
        return "memory"