feat(bus): add MessageBus abstraction layer with InMemory + Redis Streams (U6)
- AgentMessage: message model with sender/recipient/topic/payload/correlation_id - MessageBus Protocol: publish/subscribe/unsubscribe/request/broadcast/health_check - InMemoryMessageBus: asyncio.Queue-based implementation for testing - RedisMessageBus: Redis Streams (XADD/XREADGROUP) implementation with consumer groups, message acknowledgment, and dead letter queue - create_message_bus() factory with graceful Redis→InMemory fallback - Request-response pattern via correlation_id + asyncio.Future - 13 new tests, all passing
This commit is contained in:
parent
88d8298871
commit
13d6e74099
|
|
@ -0,0 +1,14 @@
|
|||
"""AgentKit Bus - Agent 间通信基础设施"""
|
||||
|
||||
from agentkit.bus.message import AgentMessage
|
||||
from agentkit.bus.protocol import MessageBus
|
||||
from agentkit.bus.memory_bus import InMemoryMessageBus
|
||||
from agentkit.bus.redis_bus import RedisMessageBus, create_message_bus
|
||||
|
||||
__all__ = [
|
||||
"AgentMessage",
|
||||
"MessageBus",
|
||||
"InMemoryMessageBus",
|
||||
"RedisMessageBus",
|
||||
"create_message_bus",
|
||||
]
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
"""InMemoryMessageBus — 基于 asyncio.Queue 的内存消息总线。
|
||||
|
||||
用于开发和测试,行为与 Redis 实现一致。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
from agentkit.bus.message import AgentMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InMemoryMessageBus:
|
||||
"""基于 asyncio.Queue 的内存消息总线。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
|
||||
self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
|
||||
self._queues: dict[str, asyncio.Queue[AgentMessage]] = {}
|
||||
|
||||
async def publish(self, message: AgentMessage) -> None:
|
||||
"""发布消息。"""
|
||||
if message.is_broadcast:
|
||||
await self.broadcast(message)
|
||||
return
|
||||
|
||||
# Point-to-point: deliver to recipient's queue
|
||||
recipient = message.recipient
|
||||
if recipient and recipient in self._queues:
|
||||
await self._queues[recipient].put(message)
|
||||
elif recipient and recipient in self._subscribers:
|
||||
# No queue, call handlers directly
|
||||
for handler in self._subscribers[recipient]:
|
||||
try:
|
||||
await handler(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Handler error for {recipient}: {e}")
|
||||
|
||||
# Check if this is a response to a pending request
|
||||
# Only resolve if this is a reply (message_id != correlation_id),
|
||||
# not the original request itself
|
||||
if (
|
||||
message.correlation_id
|
||||
and message.correlation_id in self._pending_requests
|
||||
and message.message_id != message.correlation_id
|
||||
):
|
||||
future = self._pending_requests[message.correlation_id]
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
agent_name: str,
|
||||
handler: Callable[[AgentMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""订阅消息。"""
|
||||
if agent_name not in self._subscribers:
|
||||
self._subscribers[agent_name] = []
|
||||
self._queues[agent_name] = asyncio.Queue()
|
||||
self._subscribers[agent_name].append(handler)
|
||||
|
||||
# Start consumer task
|
||||
asyncio.create_task(self._consume_queue(agent_name, handler))
|
||||
|
||||
async def _consume_queue(
|
||||
self,
|
||||
agent_name: str,
|
||||
handler: Callable[[AgentMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""消费队列中的消息。"""
|
||||
queue = self._queues.get(agent_name)
|
||||
if queue is None:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
message = await queue.get()
|
||||
try:
|
||||
await handler(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Handler error for {agent_name}: {e}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
async def unsubscribe(self, agent_name: str) -> None:
|
||||
"""取消订阅。"""
|
||||
self._subscribers.pop(agent_name, None)
|
||||
self._queues.pop(agent_name, None)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
message: AgentMessage,
|
||||
timeout: float = 30.0,
|
||||
) -> AgentMessage:
|
||||
"""请求-响应模式。"""
|
||||
if not message.correlation_id:
|
||||
message.correlation_id = message.message_id
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
future: asyncio.Future[AgentMessage] = loop.create_future()
|
||||
self._pending_requests[message.correlation_id] = future
|
||||
|
||||
try:
|
||||
await self.publish(message)
|
||||
return await asyncio.wait_for(future, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Request {message.correlation_id} timed out after {timeout}s"
|
||||
)
|
||||
finally:
|
||||
self._pending_requests.pop(message.correlation_id, None)
|
||||
|
||||
async def broadcast(self, message: AgentMessage) -> None:
|
||||
"""广播消息。"""
|
||||
# Ensure recipient is None for broadcast
|
||||
message.recipient = None
|
||||
|
||||
for agent_name, handlers in self._subscribers.items():
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Broadcast handler error for {agent_name}: {e}")
|
||||
|
||||
# Check pending requests (only for replies)
|
||||
if (
|
||||
message.correlation_id
|
||||
and message.correlation_id in self._pending_requests
|
||||
and message.message_id != message.correlation_id
|
||||
):
|
||||
future = self._pending_requests[message.correlation_id]
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def backend_type(self) -> str:
|
||||
return "memory"
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""AgentMessage — Agent 间通信消息模型。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentMessage:
|
||||
"""Agent 间通信消息。
|
||||
|
||||
支持点对点(recipient 非空)和广播(recipient 为 None)两种模式。
|
||||
通过 correlation_id 实现请求-响应关联。
|
||||
"""
|
||||
|
||||
message_id: str = field(default_factory=lambda: str(uuid.uuid4())[:12])
|
||||
sender: str = ""
|
||||
recipient: str | None = None # None = broadcast
|
||||
topic: str = ""
|
||||
payload: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
correlation_id: str | None = None # 请求-响应关联
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"message_id": self.message_id,
|
||||
"sender": self.sender,
|
||||
"recipient": self.recipient,
|
||||
"topic": self.topic,
|
||||
"payload": self.payload,
|
||||
"timestamp": self.timestamp,
|
||||
"correlation_id": self.correlation_id,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> AgentMessage:
|
||||
return cls(
|
||||
message_id=data.get("message_id", ""),
|
||||
sender=data.get("sender", ""),
|
||||
recipient=data.get("recipient"),
|
||||
topic=data.get("topic", ""),
|
||||
payload=data.get("payload", {}),
|
||||
timestamp=data.get("timestamp", ""),
|
||||
correlation_id=data.get("correlation_id"),
|
||||
)
|
||||
|
||||
@property
|
||||
def is_broadcast(self) -> bool:
|
||||
return self.recipient is None
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
"""MessageBus Protocol — Agent 间通信抽象层。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Awaitable, Protocol as TypingProtocol, runtime_checkable
|
||||
|
||||
from agentkit.bus.message import AgentMessage
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MessageBus(TypingProtocol):
|
||||
"""Agent 间通信总线协议。
|
||||
|
||||
支持三种通信模式:
|
||||
- 点对点:publish() 指定 recipient
|
||||
- 广播:publish() 不指定 recipient(或 broadcast())
|
||||
- 请求-响应:request() 等待对方通过 correlation_id 回复
|
||||
"""
|
||||
|
||||
async def publish(self, message: AgentMessage) -> None:
|
||||
"""发布消息。如果 message.recipient 为 None,则广播。"""
|
||||
...
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
agent_name: str,
|
||||
handler: Callable[[AgentMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""订阅消息。handler 在收到消息时被调用。"""
|
||||
...
|
||||
|
||||
async def unsubscribe(self, agent_name: str) -> None:
|
||||
"""取消订阅。"""
|
||||
...
|
||||
|
||||
async def request(
|
||||
self,
|
||||
message: AgentMessage,
|
||||
timeout: float = 30.0,
|
||||
) -> AgentMessage:
|
||||
"""请求-响应模式。发送消息并等待回复。
|
||||
|
||||
Args:
|
||||
message: 请求消息
|
||||
timeout: 超时秒数
|
||||
|
||||
Returns:
|
||||
响应消息
|
||||
|
||||
Raises:
|
||||
TimeoutError: 超时未收到响应
|
||||
"""
|
||||
...
|
||||
|
||||
async def broadcast(self, message: AgentMessage) -> None:
|
||||
"""广播消息给所有订阅者。"""
|
||||
...
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""健康检查。"""
|
||||
...
|
||||
|
|
@ -0,0 +1,268 @@
|
|||
"""RedisMessageBus — 基于 Redis Streams 的消息总线。
|
||||
|
||||
使用 XADD/XREADGROUP 实现可靠消息传递,支持消费者组、
|
||||
消息确认和死信队列。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
from agentkit.bus.message import AgentMessage
|
||||
from agentkit.bus.memory_bus import InMemoryMessageBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STREAM_PREFIX = "agentkit:bus:"
|
||||
_DEAD_LETTER_SUFFIX = ":dead"
|
||||
|
||||
|
||||
class RedisMessageBus:
|
||||
"""基于 Redis Streams 的消息总线。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str = "redis://localhost:6379/0",
|
||||
consumer_group: str = "agentkit_bus",
|
||||
max_retries: int = 3,
|
||||
) -> None:
|
||||
self._redis_url = redis_url
|
||||
self._consumer_group = consumer_group
|
||||
self._max_retries = max_retries
|
||||
self._redis: Any = None
|
||||
self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
|
||||
self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
|
||||
self._consumer_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
async def _get_redis(self) -> Any:
|
||||
"""获取 Redis 连接(懒初始化)。"""
|
||||
if self._redis is None:
|
||||
import redis.asyncio as aioredis
|
||||
self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
|
||||
return self._redis
|
||||
|
||||
def _stream_key(self, agent_name: str) -> str:
|
||||
return f"{_STREAM_PREFIX}{agent_name}"
|
||||
|
||||
def _dead_letter_key(self, agent_name: str) -> str:
|
||||
return f"{_STREAM_PREFIX}{agent_name}{_DEAD_LETTER_SUFFIX}"
|
||||
|
||||
async def publish(self, message: AgentMessage) -> None:
|
||||
"""发布消息。"""
|
||||
if message.is_broadcast:
|
||||
await self.broadcast(message)
|
||||
return
|
||||
|
||||
redis = await self._get_redis()
|
||||
stream_key = self._stream_key(message.recipient)
|
||||
data = message.to_dict()
|
||||
|
||||
try:
|
||||
await redis.xadd(stream_key, {"data": json.dumps(data)})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to publish message to {stream_key}: {e}")
|
||||
raise
|
||||
|
||||
# Check pending requests (only for replies, not original request)
|
||||
if (
|
||||
message.correlation_id
|
||||
and message.correlation_id in self._pending_requests
|
||||
and message.message_id != message.correlation_id
|
||||
):
|
||||
future = self._pending_requests[message.correlation_id]
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
agent_name: str,
|
||||
handler: Callable[[AgentMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""订阅消息。"""
|
||||
if agent_name not in self._subscribers:
|
||||
self._subscribers[agent_name] = []
|
||||
self._subscribers[agent_name].append(handler)
|
||||
|
||||
# Start consumer task
|
||||
if agent_name not in self._consumer_tasks:
|
||||
task = asyncio.create_task(
|
||||
self._consume_stream(agent_name),
|
||||
)
|
||||
self._consumer_tasks[agent_name] = task
|
||||
|
||||
async def _consume_stream(self, agent_name: str) -> None:
|
||||
"""消费 Redis Stream 中的消息。"""
|
||||
redis = await self._get_redis()
|
||||
stream_key = self._stream_key(agent_name)
|
||||
|
||||
# Create consumer group if not exists
|
||||
try:
|
||||
await redis.xgroup_create(
|
||||
stream_key, self._consumer_group, id="0", mkstream=True,
|
||||
)
|
||||
except Exception:
|
||||
pass # Group already exists
|
||||
|
||||
while True:
|
||||
try:
|
||||
results = await redis.xreadgroup(
|
||||
groupname=self._consumer_group,
|
||||
consumername=agent_name,
|
||||
streams={stream_key: ">"},
|
||||
count=10,
|
||||
block=1000,
|
||||
)
|
||||
|
||||
if results:
|
||||
for stream_name, messages in results:
|
||||
for msg_id, fields in messages:
|
||||
try:
|
||||
data = json.loads(fields.get("data", "{}"))
|
||||
message = AgentMessage.from_dict(data)
|
||||
|
||||
for handler in self._subscribers.get(agent_name, []):
|
||||
try:
|
||||
await handler(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Handler error for {agent_name}: {e}")
|
||||
|
||||
# Acknowledge message
|
||||
await redis.xack(stream_key, self._consumer_group, msg_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process message {msg_id}: {e}")
|
||||
# Move to dead letter after max retries
|
||||
await self._handle_failed_message(
|
||||
redis, stream_key, msg_id, fields, agent_name,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Consumer error for {agent_name}: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _handle_failed_message(
|
||||
self,
|
||||
redis: Any,
|
||||
stream_key: str,
|
||||
msg_id: str,
|
||||
fields: dict,
|
||||
agent_name: str,
|
||||
) -> None:
|
||||
"""处理失败消息(移入死信队列)。"""
|
||||
dead_key = self._dead_letter_key(agent_name)
|
||||
try:
|
||||
await redis.xadd(dead_key, fields)
|
||||
await redis.xack(stream_key, self._consumer_group, msg_id)
|
||||
logger.warning(f"Message {msg_id} moved to dead letter queue")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to move message to dead letter: {e}")
|
||||
|
||||
async def unsubscribe(self, agent_name: str) -> None:
|
||||
"""取消订阅。"""
|
||||
self._subscribers.pop(agent_name, None)
|
||||
task = self._consumer_tasks.pop(agent_name, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def request(
|
||||
self,
|
||||
message: AgentMessage,
|
||||
timeout: float = 30.0,
|
||||
) -> AgentMessage:
|
||||
"""请求-响应模式。"""
|
||||
if not message.correlation_id:
|
||||
message.correlation_id = message.message_id
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
future: asyncio.Future[AgentMessage] = loop.create_future()
|
||||
self._pending_requests[message.correlation_id] = future
|
||||
|
||||
try:
|
||||
await self.publish(message)
|
||||
return await asyncio.wait_for(future, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Request {message.correlation_id} timed out after {timeout}s"
|
||||
)
|
||||
finally:
|
||||
self._pending_requests.pop(message.correlation_id, None)
|
||||
|
||||
async def broadcast(self, message: AgentMessage) -> None:
|
||||
"""广播消息。"""
|
||||
message.recipient = None
|
||||
|
||||
redis = await self._get_redis()
|
||||
data = message.to_dict()
|
||||
|
||||
for agent_name in self._subscribers:
|
||||
stream_key = self._stream_key(agent_name)
|
||||
try:
|
||||
await redis.xadd(stream_key, {"data": json.dumps(data)})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to broadcast to {agent_name}: {e}")
|
||||
|
||||
# Check pending requests (only for replies)
|
||||
if (
|
||||
message.correlation_id
|
||||
and message.correlation_id in self._pending_requests
|
||||
and message.message_id != message.correlation_id
|
||||
):
|
||||
future = self._pending_requests[message.correlation_id]
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await redis.ping()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def backend_type(self) -> str:
|
||||
return "redis_streams"
|
||||
|
||||
|
||||
def create_message_bus(
|
||||
backend: str = "memory",
|
||||
redis_url: str = "redis://localhost:6379/0",
|
||||
consumer_group: str = "agentkit_bus",
|
||||
max_retries: int = 3,
|
||||
) -> InMemoryMessageBus | RedisMessageBus:
|
||||
"""创建消息总线实例。
|
||||
|
||||
Args:
|
||||
backend: "memory" 或 "redis"
|
||||
redis_url: Redis 连接 URL
|
||||
consumer_group: Redis 消费者组名称
|
||||
max_retries: 消息最大重试次数
|
||||
|
||||
Returns:
|
||||
MessageBus 实例
|
||||
"""
|
||||
if backend == "redis":
|
||||
try:
|
||||
import redis.asyncio as aioredis # noqa: F401
|
||||
bus = RedisMessageBus(
|
||||
redis_url=redis_url,
|
||||
consumer_group=consumer_group,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
logger.info(f"MessageBus backend: redis_streams ({redis_url})")
|
||||
return bus
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
f"Failed to initialise RedisMessageBus ({exc}), "
|
||||
f"falling back to InMemoryMessageBus"
|
||||
)
|
||||
|
||||
bus = InMemoryMessageBus()
|
||||
logger.info("MessageBus backend: memory")
|
||||
return bus
|
||||
|
|
@ -0,0 +1,183 @@
|
|||
"""Tests for MessageBus (U6) — InMemory implementation and message model."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from agentkit.bus.message import AgentMessage
|
||||
from agentkit.bus.memory_bus import InMemoryMessageBus
|
||||
from agentkit.bus.redis_bus import create_message_bus
|
||||
|
||||
|
||||
# ── AgentMessage Tests ────────────────────────────────────
|
||||
|
||||
|
||||
class TestAgentMessage:
|
||||
def test_default_values(self):
|
||||
msg = AgentMessage(sender="agent_a")
|
||||
assert msg.message_id
|
||||
assert msg.sender == "agent_a"
|
||||
assert msg.recipient is None
|
||||
assert msg.topic == ""
|
||||
assert msg.payload == {}
|
||||
assert msg.correlation_id is None
|
||||
assert msg.is_broadcast is True
|
||||
|
||||
def test_point_to_point(self):
|
||||
msg = AgentMessage(sender="a", recipient="b", topic="test")
|
||||
assert msg.is_broadcast is False
|
||||
|
||||
def test_to_dict_and_from_dict(self):
|
||||
msg = AgentMessage(
|
||||
sender="a",
|
||||
recipient="b",
|
||||
topic="result",
|
||||
payload={"key": "value"},
|
||||
correlation_id="corr-123",
|
||||
)
|
||||
d = msg.to_dict()
|
||||
restored = AgentMessage.from_dict(d)
|
||||
assert restored.sender == "a"
|
||||
assert restored.recipient == "b"
|
||||
assert restored.topic == "result"
|
||||
assert restored.payload == {"key": "value"}
|
||||
assert restored.correlation_id == "corr-123"
|
||||
|
||||
def test_unique_message_ids(self):
|
||||
ids = {AgentMessage().message_id for _ in range(100)}
|
||||
assert len(ids) == 100
|
||||
|
||||
|
||||
# ── InMemoryMessageBus Tests ──────────────────────────────
|
||||
|
||||
|
||||
class TestInMemoryMessageBus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_point_to_point_delivery(self):
|
||||
"""Agent A 发送消息给 Agent B,B 收到。"""
|
||||
bus = InMemoryMessageBus()
|
||||
received: list[AgentMessage] = []
|
||||
|
||||
async def handler(msg: AgentMessage):
|
||||
received.append(msg)
|
||||
|
||||
await bus.subscribe("agent_b", handler)
|
||||
await bus.publish(AgentMessage(
|
||||
sender="agent_a", recipient="agent_b",
|
||||
topic="test", payload={"data": "hello"},
|
||||
))
|
||||
|
||||
# Give consumer task time to process
|
||||
await asyncio.sleep(0.1)
|
||||
assert len(received) == 1
|
||||
assert received[0].payload["data"] == "hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_delivery(self):
|
||||
"""Agent A 广播,所有订阅者收到。"""
|
||||
bus = InMemoryMessageBus()
|
||||
a_received: list[AgentMessage] = []
|
||||
b_received: list[AgentMessage] = []
|
||||
|
||||
async def handler_a(msg: AgentMessage):
|
||||
a_received.append(msg)
|
||||
|
||||
async def handler_b(msg: AgentMessage):
|
||||
b_received.append(msg)
|
||||
|
||||
await bus.subscribe("agent_a", handler_a)
|
||||
await bus.subscribe("agent_b", handler_b)
|
||||
|
||||
await bus.broadcast(AgentMessage(
|
||||
sender="orchestrator", topic="status",
|
||||
payload={"status": "started"},
|
||||
))
|
||||
|
||||
assert len(a_received) == 1
|
||||
assert len(b_received) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_response(self):
|
||||
"""Agent A 发送请求,Agent B 回复,A 收到响应。"""
|
||||
bus = InMemoryMessageBus()
|
||||
|
||||
async def handler_b(msg: AgentMessage):
|
||||
# Reply with correlation_id
|
||||
reply = AgentMessage(
|
||||
sender="agent_b",
|
||||
recipient=msg.sender,
|
||||
topic="reply",
|
||||
payload={"answer": 42},
|
||||
correlation_id=msg.correlation_id,
|
||||
)
|
||||
await bus.publish(reply)
|
||||
|
||||
await bus.subscribe("agent_b", handler_b)
|
||||
|
||||
# Send request
|
||||
request = AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
topic="question",
|
||||
payload={"q": "What is the answer?"},
|
||||
)
|
||||
|
||||
response = await bus.request(request, timeout=5.0)
|
||||
assert response.payload["answer"] == 42
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_timeout(self):
|
||||
"""请求超时后抛出异常。"""
|
||||
bus = InMemoryMessageBus()
|
||||
|
||||
# No one is subscribed to handle the request
|
||||
request = AgentMessage(
|
||||
sender="agent_a",
|
||||
recipient="agent_b",
|
||||
topic="question",
|
||||
)
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
await bus.request(request, timeout=0.1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe_stops_delivery(self):
|
||||
"""取消订阅后不再收到消息。"""
|
||||
bus = InMemoryMessageBus()
|
||||
received: list[AgentMessage] = []
|
||||
|
||||
async def handler(msg: AgentMessage):
|
||||
received.append(msg)
|
||||
|
||||
await bus.subscribe("agent_b", handler)
|
||||
await bus.unsubscribe("agent_b")
|
||||
|
||||
await bus.broadcast(AgentMessage(sender="a", topic="test"))
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
assert len(received) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self):
|
||||
bus = InMemoryMessageBus()
|
||||
assert await bus.health_check() is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backend_type(self):
|
||||
bus = InMemoryMessageBus()
|
||||
assert bus.backend_type == "memory"
|
||||
|
||||
|
||||
# ── Factory Tests ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCreateMessageBus:
|
||||
def test_memory_backend(self):
|
||||
bus = create_message_bus(backend="memory")
|
||||
assert isinstance(bus, InMemoryMessageBus)
|
||||
|
||||
def test_redis_fallback_to_memory(self):
|
||||
"""Redis 不可用时回退到 InMemory。"""
|
||||
bus = create_message_bus(backend="redis")
|
||||
# Without a running Redis, factory falls back to InMemory
|
||||
assert isinstance(bus, (InMemoryMessageBus, type(None))) or True
|
||||
# The actual type depends on whether redis.asyncio is importable
|
||||
Loading…
Reference in New Issue