feat(bus): add MessageBus abstraction layer with InMemory + Redis Streams (U6)

- AgentMessage: message model with sender/recipient/topic/payload/correlation_id
- MessageBus Protocol: publish/subscribe/unsubscribe/request/broadcast/health_check
- InMemoryMessageBus: asyncio.Queue-based implementation for testing
- RedisMessageBus: Redis Streams (XADD/XREADGROUP) implementation with
  consumer groups, message acknowledgment, and dead letter queue
- create_message_bus() factory with graceful Redis→InMemory fallback
- Request-response pattern via correlation_id + asyncio.Future
- 13 new tests, all passing
This commit is contained in:
chiguyong 2026-06-07 23:58:16 +08:00
parent 88d8298871
commit 13d6e74099
6 changed files with 723 additions and 0 deletions

View File

@ -0,0 +1,14 @@
"""AgentKit Bus - Agent 间通信基础设施"""
from agentkit.bus.message import AgentMessage
from agentkit.bus.protocol import MessageBus
from agentkit.bus.memory_bus import InMemoryMessageBus
from agentkit.bus.redis_bus import RedisMessageBus, create_message_bus
# Names re-exported as the public API of the ``agentkit.bus`` package.
__all__ = [
"AgentMessage",
"MessageBus",
"InMemoryMessageBus",
"RedisMessageBus",
"create_message_bus",
]

View File

@ -0,0 +1,143 @@
"""InMemoryMessageBus — 基于 asyncio.Queue 的内存消息总线。
用于开发和测试行为与 Redis 实现一致
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Callable, Awaitable
from agentkit.bus.message import AgentMessage
logger = logging.getLogger(__name__)
class InMemoryMessageBus:
    """In-memory message bus backed by one ``asyncio.Queue`` per agent.

    Point-to-point messages are queued for the recipient and delivered by a
    single background consumer task per agent, which calls every handler
    registered for that agent. Broadcasts call the handlers directly.
    Replies are matched to ``request()`` callers through ``correlation_id``.
    """

    def __init__(self) -> None:
        self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
        self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
        # correlation_id -> message_id of the request that opened it, so the
        # request itself is never mistaken for its own reply.
        self._request_ids: dict[str, str] = {}
        self._queues: dict[str, asyncio.Queue[AgentMessage]] = {}
        # One consumer task per agent; kept so the task is not garbage
        # collected mid-flight and can be cancelled on unsubscribe().
        self._consumer_tasks: dict[str, asyncio.Task[None]] = {}

    async def publish(self, message: AgentMessage) -> None:
        """Publish a message; a message without a recipient is broadcast.

        A point-to-point message goes onto the recipient's queue when one
        exists; otherwise the recipient's handlers (if any) are called
        directly. Afterwards a matching pending request is resolved.
        """
        if message.is_broadcast:
            await self.broadcast(message)
            return

        recipient = message.recipient
        queue = self._queues.get(recipient) if recipient else None
        if queue is not None:
            await queue.put(message)
        elif recipient and recipient in self._subscribers:
            # No queue: call the handlers directly.
            await self._run_handlers(
                list(self._subscribers[recipient]),
                message,
                f"Handler error for {recipient}",
            )

        self._resolve_pending(message)

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Register ``handler`` for ``agent_name`` and ensure a consumer runs.

        Several handlers may be registered for one agent; each queued message
        is passed to all of them.
        """
        self._subscribers.setdefault(agent_name, []).append(handler)
        if agent_name not in self._queues:
            self._queues[agent_name] = asyncio.Queue()

        task = self._consumer_tasks.get(agent_name)
        if task is None or task.done():
            self._consumer_tasks[agent_name] = asyncio.create_task(
                self._consume_queue(agent_name),
            )

    async def _consume_queue(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]] | None = None,
    ) -> None:
        """Drain the agent's queue until cancelled or unsubscribed.

        Args:
            agent_name: Agent whose queue is consumed.
            handler: When given, only this handler is called; otherwise every
                handler currently registered for ``agent_name`` is called.
        """
        queue = self._queues.get(agent_name)
        if queue is None:
            return
        try:
            # Stop once unsubscribe() has dropped (or replaced) this queue.
            while self._queues.get(agent_name) is queue:
                message = await queue.get()
                if handler is not None:
                    targets = [handler]
                else:
                    targets = list(self._subscribers.get(agent_name, ()))
                await self._run_handlers(
                    targets, message, f"Handler error for {agent_name}",
                )
        except asyncio.CancelledError:
            pass

    @staticmethod
    async def _run_handlers(
        handlers: list[Callable[[AgentMessage], Awaitable[None]]],
        message: AgentMessage,
        error_prefix: str,
    ) -> None:
        """Call each handler in order; a failing handler is logged, not raised."""
        for handler in handlers:
            try:
                await handler(message)
            except Exception as e:
                logger.warning(f"{error_prefix}: {e}")

    def _resolve_pending(self, message: AgentMessage) -> None:
        """Complete the pending request that ``message`` replies to, if any."""
        correlation_id = message.correlation_id
        if not correlation_id:
            return
        future = self._pending_requests.get(correlation_id)
        if future is None or future.done():
            return
        # Skip the original request: either its id doubles as the correlation
        # id, or it is the message that request() registered.
        if message.message_id in (correlation_id, self._request_ids.get(correlation_id)):
            return
        future.set_result(message)

    async def unsubscribe(self, agent_name: str) -> None:
        """Remove all handlers for ``agent_name`` and stop its consumer task.

        Messages still sitting in the agent's queue are discarded.
        """
        self._subscribers.pop(agent_name, None)
        self._queues.pop(agent_name, None)
        task = self._consumer_tasks.pop(agent_name, None)
        if task is None or task.done():
            return
        if task is asyncio.current_task():
            # Called from inside a handler: the consumer loop notices that its
            # queue is gone and exits after the current message.
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def request(
        self,
        message: AgentMessage,
        timeout: float = 30.0,
    ) -> AgentMessage:
        """Publish ``message`` and wait for a reply with the same correlation id.

        If ``message.correlation_id`` is empty it is set to the message id.

        Args:
            message: The request message.
            timeout: Seconds to wait for the reply.

        Returns:
            The first published message that carries the same
            ``correlation_id`` and is not the request itself.

        Raises:
            TimeoutError: No reply arrived within ``timeout`` seconds.
        """
        if not message.correlation_id:
            message.correlation_id = message.message_id
        correlation_id = message.correlation_id

        future: asyncio.Future[AgentMessage] = asyncio.get_running_loop().create_future()
        self._pending_requests[correlation_id] = future
        self._request_ids[correlation_id] = message.message_id
        try:
            await self.publish(message)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"Request {correlation_id} timed out after {timeout}s"
            ) from exc
        finally:
            self._pending_requests.pop(correlation_id, None)
            self._request_ids.pop(correlation_id, None)

    async def broadcast(self, message: AgentMessage) -> None:
        """Call every registered handler of every agent with ``message``.

        ``message.recipient`` is reset to ``None``. Handlers are invoked
        directly, not through the per-agent queues.
        """
        message.recipient = None
        # Iterate over a snapshot: a handler may subscribe or unsubscribe.
        for agent_name in list(self._subscribers):
            await self._run_handlers(
                list(self._subscribers.get(agent_name, ())),
                message,
                f"Broadcast handler error for {agent_name}",
            )
        self._resolve_pending(message)

    async def health_check(self) -> bool:
        """Always report healthy; there is no external dependency to probe."""
        return True

    @property
    def backend_type(self) -> str:
        """Identifier of this backend (``"memory"``)."""
        return "memory"

View File

@ -0,0 +1,54 @@
"""AgentMessage — Agent 间通信消息模型。"""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
@dataclass
class AgentMessage:
    """A message exchanged between agents.

    Supports point-to-point delivery (``recipient`` set) and broadcast
    (``recipient`` is ``None``). ``correlation_id`` links a reply to the
    request it answers.
    """

    # First 12 characters of a random UUID4 string.
    message_id: str = field(default_factory=lambda: str(uuid.uuid4())[:12])
    sender: str = ""
    recipient: str | None = None  # None = broadcast
    topic: str = ""
    payload: dict[str, Any] = field(default_factory=dict)
    # ISO-8601 creation time in UTC.
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat(),
    )
    correlation_id: str | None = None  # request/response correlation

    def to_dict(self) -> dict[str, Any]:
        """Return the message as a plain dict (``payload`` is not copied)."""
        return {
            "message_id": self.message_id,
            "sender": self.sender,
            "recipient": self.recipient,
            "topic": self.topic,
            "payload": self.payload,
            "timestamp": self.timestamp,
            "correlation_id": self.correlation_id,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> AgentMessage:
        """Build a message from a dict such as the one ``to_dict`` produces.

        Keys that are missing or ``None`` fall back to the field defaults, so
        a partial dict still gets a generated ``message_id``, a current
        ``timestamp`` and an empty ``payload``. Unknown keys are ignored.
        """
        present = {
            name: data[name]
            for name in cls.__dataclass_fields__
            if data.get(name) is not None
        }
        return cls(**present)

    @property
    def is_broadcast(self) -> bool:
        """True when the message has no recipient."""
        return self.recipient is None

View File

@ -0,0 +1,61 @@
"""MessageBus Protocol — Agent 间通信抽象层。"""
from __future__ import annotations
from typing import Any, Callable, Awaitable, Protocol as TypingProtocol, runtime_checkable
from agentkit.bus.message import AgentMessage
@runtime_checkable
class MessageBus(TypingProtocol):
    """Structural interface of an inter-agent message bus.

    Three communication modes are covered:

    - point-to-point: ``publish()`` with a ``recipient``;
    - broadcast: ``publish()`` without a ``recipient``, or ``broadcast()``;
    - request/response: ``request()`` waits for a reply that carries the
      same ``correlation_id``.
    """

    async def publish(self, message: AgentMessage) -> None:
        """Publish a message; broadcast it when ``message.recipient`` is None."""
        ...

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Subscribe ``handler``; it is called when a message is received."""
        ...

    async def unsubscribe(self, agent_name: str) -> None:
        """Cancel the subscription of ``agent_name``."""
        ...

    async def request(
        self,
        message: AgentMessage,
        timeout: float = 30.0,
    ) -> AgentMessage:
        """Send a message and wait for the reply.

        Args:
            message: The request message.
            timeout: Timeout in seconds.

        Returns:
            The response message.

        Raises:
            TimeoutError: No response was received before the timeout.
        """
        ...

    async def broadcast(self, message: AgentMessage) -> None:
        """Broadcast a message to all subscribers."""
        ...

    async def health_check(self) -> bool:
        """Report whether the bus is healthy."""
        ...

View File

@ -0,0 +1,268 @@
"""RedisMessageBus — 基于 Redis Streams 的消息总线。
使用 XADD/XREADGROUP 实现可靠消息传递支持消费者组
消息确认和死信队列
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Callable, Awaitable
from agentkit.bus.message import AgentMessage
from agentkit.bus.memory_bus import InMemoryMessageBus
logger = logging.getLogger(__name__)
# Prefix of every Redis stream key used by the bus; the agent name is appended.
_STREAM_PREFIX = "agentkit:bus:"
# Suffix appended to an agent's stream key to form its dead-letter stream key.
_DEAD_LETTER_SUFFIX = ":dead"
class RedisMessageBus:
    """Message bus built on Redis Streams.

    Each agent has its own stream (``agentkit:bus:<agent>``). Messages are
    written with XADD and read with XREADGROUP by one consumer task per
    subscribed agent. Entries that cannot be processed are copied to the
    agent's dead-letter stream and acknowledged.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        consumer_group: str = "agentkit_bus",
        max_retries: int = 3,
    ) -> None:
        """Store the configuration; no connection is opened here.

        Args:
            redis_url: Redis connection URL.
            consumer_group: Consumer group name used on every stream.
            max_retries: Maximum number of retries for a message.
        """
        self._redis_url = redis_url
        self._consumer_group = consumer_group
        # NOTE(review): stored but not consulted anywhere in this class;
        # failed entries are dead-lettered on the first failure.
        self._max_retries = max_retries
        self._redis: Any = None
        self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
        self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
        # correlation_id -> message_id of the request that opened it, so the
        # request itself is never mistaken for its own reply.
        self._request_ids: dict[str, str] = {}
        self._consumer_tasks: dict[str, asyncio.Task] = {}

    async def _get_redis(self) -> Any:
        """Return the Redis client, creating it on first use."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _stream_key(self, agent_name: str) -> str:
        """Return the stream key of ``agent_name``."""
        return f"{_STREAM_PREFIX}{agent_name}"

    def _dead_letter_key(self, agent_name: str) -> str:
        """Return the dead-letter stream key of ``agent_name``."""
        return f"{_STREAM_PREFIX}{agent_name}{_DEAD_LETTER_SUFFIX}"

    def _resolve_pending(self, message: AgentMessage) -> None:
        """Complete the pending request that ``message`` replies to, if any."""
        correlation_id = message.correlation_id
        if not correlation_id:
            return
        future = self._pending_requests.get(correlation_id)
        if future is None or future.done():
            return
        # Skip the original request: either its id doubles as the correlation
        # id, or it is the message that request() registered.
        if message.message_id in (correlation_id, self._request_ids.get(correlation_id)):
            return
        future.set_result(message)

    async def publish(self, message: AgentMessage) -> None:
        """Publish a message; a message without a recipient is broadcast.

        Raises:
            Exception: Whatever the XADD call raised; it is logged first.
        """
        if message.is_broadcast:
            await self.broadcast(message)
            return

        redis = await self._get_redis()
        stream_key = self._stream_key(message.recipient)
        data = message.to_dict()
        try:
            await redis.xadd(stream_key, {"data": json.dumps(data)})
        except Exception as e:
            logger.error(f"Failed to publish message to {stream_key}: {e}")
            raise

        # A reply published through this instance resolves the request
        # immediately, without a round trip through Redis.
        self._resolve_pending(message)

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Register ``handler`` and start the agent's consumer task if needed."""
        self._subscribers.setdefault(agent_name, []).append(handler)
        if agent_name not in self._consumer_tasks:
            self._consumer_tasks[agent_name] = asyncio.create_task(
                self._consume_stream(agent_name),
            )

    async def _ensure_group(self, redis: Any, stream_key: str) -> None:
        """Create the consumer group (and stream) unless it already exists."""
        try:
            await redis.xgroup_create(
                stream_key, self._consumer_group, id="0", mkstream=True,
            )
        except Exception as e:
            # Redis answers BUSYGROUP when the group already exists; anything
            # else (e.g. a connection failure) is worth surfacing in the log.
            if "BUSYGROUP" not in str(e):
                logger.warning(
                    f"Failed to create consumer group for {stream_key}: {e}"
                )

    async def _consume_stream(self, agent_name: str) -> None:
        """Read the agent's stream in a loop and dispatch each entry."""
        redis = await self._get_redis()
        stream_key = self._stream_key(agent_name)
        await self._ensure_group(redis, stream_key)

        while True:
            try:
                results = await redis.xreadgroup(
                    groupname=self._consumer_group,
                    consumername=agent_name,
                    streams={stream_key: ">"},
                    count=10,
                    block=1000,
                )
                for _stream_name, messages in results or ():
                    for msg_id, fields in messages:
                        await self._process_entry(
                            redis, stream_key, agent_name, msg_id, fields,
                        )
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Consumer error for {agent_name}: {e}")
                await asyncio.sleep(1)
                # The group may be missing if its creation failed earlier.
                await self._ensure_group(redis, stream_key)

    async def _process_entry(
        self,
        redis: Any,
        stream_key: str,
        agent_name: str,
        msg_id: str,
        fields: dict,
    ) -> None:
        """Decode one stream entry, run the handlers and acknowledge it.

        Handler errors are logged and do not prevent acknowledgment. Any
        other failure sends the entry to the dead-letter stream.
        """
        try:
            data = json.loads(fields.get("data", "{}"))
            message = AgentMessage.from_dict(data)
            # A reply that arrived through Redis (for instance from another
            # process) must wake up the request() waiting on it.
            self._resolve_pending(message)
            for handler in list(self._subscribers.get(agent_name, [])):
                try:
                    await handler(message)
                except Exception as e:
                    logger.warning(f"Handler error for {agent_name}: {e}")
            # Acknowledge message
            await redis.xack(stream_key, self._consumer_group, msg_id)
        except Exception as e:
            logger.warning(f"Failed to process message {msg_id}: {e}")
            await self._handle_failed_message(
                redis, stream_key, msg_id, fields, agent_name,
            )

    async def _handle_failed_message(
        self,
        redis: Any,
        stream_key: str,
        msg_id: str,
        fields: dict,
        agent_name: str,
    ) -> None:
        """Copy a failed entry to the dead-letter stream and acknowledge it."""
        dead_key = self._dead_letter_key(agent_name)
        try:
            await redis.xadd(dead_key, fields)
            await redis.xack(stream_key, self._consumer_group, msg_id)
            logger.warning(f"Message {msg_id} moved to dead letter queue")
        except Exception as e:
            logger.error(f"Failed to move message to dead letter: {e}")

    async def unsubscribe(self, agent_name: str) -> None:
        """Remove all handlers for ``agent_name`` and cancel its consumer task."""
        self._subscribers.pop(agent_name, None)
        task = self._consumer_tasks.pop(agent_name, None)
        if task is None:
            return
        task.cancel()
        if task is asyncio.current_task():
            # Called from inside a handler: a task cannot await itself.
            return
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def request(
        self,
        message: AgentMessage,
        timeout: float = 30.0,
    ) -> AgentMessage:
        """Publish ``message`` and wait for a reply with the same correlation id.

        If ``message.correlation_id`` is empty it is set to the message id.
        The reply is picked up either when it is published through this
        instance or when one of this instance's stream consumers reads it.

        Args:
            message: The request message.
            timeout: Seconds to wait for the reply.

        Returns:
            The reply message.

        Raises:
            TimeoutError: No reply arrived within ``timeout`` seconds.
        """
        if not message.correlation_id:
            message.correlation_id = message.message_id
        correlation_id = message.correlation_id

        future: asyncio.Future[AgentMessage] = asyncio.get_running_loop().create_future()
        self._pending_requests[correlation_id] = future
        self._request_ids[correlation_id] = message.message_id
        try:
            await self.publish(message)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"Request {correlation_id} timed out after {timeout}s"
            ) from exc
        finally:
            self._pending_requests.pop(correlation_id, None)
            self._request_ids.pop(correlation_id, None)

    async def broadcast(self, message: AgentMessage) -> None:
        """Write ``message`` to the stream of every agent subscribed here.

        ``message.recipient`` is reset to ``None``. Only agents subscribed on
        this bus instance are targeted. A failed XADD is logged and the
        remaining agents are still served.
        """
        message.recipient = None
        redis = await self._get_redis()
        data = message.to_dict()
        # Snapshot: subscriptions may change while XADD is awaited.
        for agent_name in list(self._subscribers):
            stream_key = self._stream_key(agent_name)
            try:
                await redis.xadd(stream_key, {"data": json.dumps(data)})
            except Exception as e:
                logger.error(f"Failed to broadcast to {agent_name}: {e}")

        self._resolve_pending(message)

    async def health_check(self) -> bool:
        """Return True when Redis answers PING, False on any error."""
        try:
            redis = await self._get_redis()
            return bool(await redis.ping())
        except Exception:
            return False

    @property
    def backend_type(self) -> str:
        """Identifier of this backend (``"redis_streams"``)."""
        return "redis_streams"
def create_message_bus(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    consumer_group: str = "agentkit_bus",
    max_retries: int = 3,
) -> InMemoryMessageBus | RedisMessageBus:
    """Create a message bus instance.

    For ``backend="redis"`` the function checks that ``redis.asyncio`` can be
    imported and builds a ``RedisMessageBus``; if that fails it falls back to
    ``InMemoryMessageBus``. No connection is attempted here, so an
    unreachable server is not detected by this factory.

    Args:
        backend: ``"memory"`` or ``"redis"``. Any other value is logged as a
            warning and yields the in-memory bus.
        redis_url: Redis connection URL.
        consumer_group: Redis consumer group name.
        max_retries: Maximum number of retries for a message; passed to
            ``RedisMessageBus``.

    Returns:
        A message bus instance.
    """
    if backend == "redis":
        try:
            import redis.asyncio as aioredis  # noqa: F401

            bus = RedisMessageBus(
                redis_url=redis_url,
                consumer_group=consumer_group,
                max_retries=max_retries,
            )
            logger.info(f"MessageBus backend: redis_streams ({redis_url})")
            return bus
        except Exception as exc:
            logger.warning(
                f"Failed to initialise RedisMessageBus ({exc}), "
                f"falling back to InMemoryMessageBus"
            )
    elif backend != "memory":
        # Previously a typo in the backend name was accepted silently.
        logger.warning(
            f"Unknown MessageBus backend {backend!r}, using InMemoryMessageBus"
        )

    bus = InMemoryMessageBus()
    logger.info("MessageBus backend: memory")
    return bus

View File

@ -0,0 +1,183 @@
"""Tests for MessageBus (U6) — InMemory implementation and message model."""
import asyncio
import pytest
from agentkit.bus.message import AgentMessage
from agentkit.bus.memory_bus import InMemoryMessageBus
from agentkit.bus.redis_bus import create_message_bus
# ── AgentMessage Tests ────────────────────────────────────
class TestAgentMessage:
    """Tests for the AgentMessage model."""

    def test_default_values(self):
        """A message built with only a sender gets the documented defaults."""
        msg = AgentMessage(sender="agent_a")
        assert msg.message_id
        assert msg.sender == "agent_a"
        assert msg.recipient is None
        assert msg.topic == ""
        assert msg.payload == {}
        assert msg.correlation_id is None
        assert msg.is_broadcast is True

    def test_point_to_point(self):
        """A message with a recipient is not a broadcast."""
        msg = AgentMessage(sender="a", recipient="b", topic="test")
        assert msg.is_broadcast is False

    def test_to_dict_and_from_dict(self):
        """to_dict() followed by from_dict() reproduces every field."""
        msg = AgentMessage(
            sender="a",
            recipient="b",
            topic="result",
            payload={"key": "value"},
            correlation_id="corr-123",
        )
        d = msg.to_dict()
        restored = AgentMessage.from_dict(d)
        assert restored.sender == "a"
        assert restored.recipient == "b"
        assert restored.topic == "result"
        assert restored.payload == {"key": "value"}
        assert restored.correlation_id == "corr-123"
        # Generated fields must survive the round trip as well.
        assert restored.message_id == msg.message_id
        assert restored.timestamp == msg.timestamp
        assert restored == msg

    def test_unique_message_ids(self):
        """100 freshly created messages have 100 distinct ids."""
        ids = {AgentMessage().message_id for _ in range(100)}
        assert len(ids) == 100
# ── InMemoryMessageBus Tests ──────────────────────────────
class TestInMemoryMessageBus:
    """Behavioral tests for InMemoryMessageBus."""

    @staticmethod
    def _recorder():
        """Return ``(inbox, handler)``; the handler appends to the inbox."""
        inbox: list[AgentMessage] = []

        async def on_message(msg: AgentMessage):
            inbox.append(msg)

        return inbox, on_message

    @pytest.mark.asyncio
    async def test_point_to_point_delivery(self):
        """Agent A sends a message to agent B and B receives it."""
        bus = InMemoryMessageBus()
        inbox, on_message = self._recorder()
        await bus.subscribe("agent_b", on_message)

        outgoing = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            topic="test",
            payload={"data": "hello"},
        )
        await bus.publish(outgoing)

        # Let the background consumer pick the message up.
        await asyncio.sleep(0.1)
        assert len(inbox) == 1
        assert inbox[0].payload["data"] == "hello"

    @pytest.mark.asyncio
    async def test_broadcast_delivery(self):
        """A broadcast reaches every subscriber."""
        bus = InMemoryMessageBus()
        inbox_a, on_a = self._recorder()
        inbox_b, on_b = self._recorder()
        await bus.subscribe("agent_a", on_a)
        await bus.subscribe("agent_b", on_b)

        announcement = AgentMessage(
            sender="orchestrator",
            topic="status",
            payload={"status": "started"},
        )
        await bus.broadcast(announcement)

        assert len(inbox_a) == 1
        assert len(inbox_b) == 1

    @pytest.mark.asyncio
    async def test_request_response(self):
        """Agent A sends a request, agent B replies, A gets the response."""
        bus = InMemoryMessageBus()

        async def answer(incoming: AgentMessage):
            # The reply carries the correlation id of the request.
            await bus.publish(
                AgentMessage(
                    sender="agent_b",
                    recipient=incoming.sender,
                    topic="reply",
                    payload={"answer": 42},
                    correlation_id=incoming.correlation_id,
                )
            )

        await bus.subscribe("agent_b", answer)

        question = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            topic="question",
            payload={"q": "What is the answer?"},
        )
        reply = await bus.request(question, timeout=5.0)
        assert reply.payload["answer"] == 42

    @pytest.mark.asyncio
    async def test_request_timeout(self):
        """A request nobody answers raises TimeoutError."""
        bus = InMemoryMessageBus()
        # Nobody is subscribed, so no reply can arrive.
        unanswered = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            topic="question",
        )
        with pytest.raises(TimeoutError):
            await bus.request(unanswered, timeout=0.1)

    @pytest.mark.asyncio
    async def test_unsubscribe_stops_delivery(self):
        """After unsubscribing, no further messages are received."""
        bus = InMemoryMessageBus()
        inbox, on_message = self._recorder()

        await bus.subscribe("agent_b", on_message)
        await bus.unsubscribe("agent_b")
        await bus.broadcast(AgentMessage(sender="a", topic="test"))
        await asyncio.sleep(0.1)

        assert len(inbox) == 0

    @pytest.mark.asyncio
    async def test_health_check(self):
        """The in-memory bus reports healthy."""
        healthy = await InMemoryMessageBus().health_check()
        assert healthy is True

    @pytest.mark.asyncio
    async def test_backend_type(self):
        """The in-memory bus identifies itself as "memory"."""
        assert InMemoryMessageBus().backend_type == "memory"
# ── Factory Tests ─────────────────────────────────────────
class TestCreateMessageBus:
    """Tests for the create_message_bus() factory."""

    def test_memory_backend(self):
        """backend="memory" yields the in-memory bus."""
        bus = create_message_bus(backend="memory")
        assert isinstance(bus, InMemoryMessageBus)

    def test_redis_fallback_to_memory(self):
        """Without an importable redis client the factory falls back to memory."""
        bus = create_message_bus(backend="redis")
        # The factory only checks that redis.asyncio can be imported, so the
        # expected backend depends on the test environment. Mirror that check
        # instead of asserting something that can never fail.
        try:
            import redis.asyncio  # noqa: F401
        except Exception:
            expected = "memory"
        else:
            expected = "redis_streams"
        assert bus.backend_type == expected
        if expected == "memory":
            assert isinstance(bus, InMemoryMessageBus)