277 lines
10 KiB
Python
277 lines
10 KiB
Python
"""RedisMessageBus — 基于 Redis Streams 的消息总线。
|
|
|
|
使用 XADD/XREADGROUP 实现可靠消息传递,支持消费者组、
|
|
消息确认和死信队列。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any, Callable, Awaitable
|
|
|
|
from agentkit.bus.message import AgentMessage
|
|
from agentkit.bus.memory_bus import InMemoryMessageBus
|
|
|
|
# Key layout for the Redis streams used by RedisMessageBus: each agent reads
# from "<prefix><agent_name>", and its dead-letter stream is that key plus
# the suffix below.
_STREAM_PREFIX = "agentkit:bus:"
_DEAD_LETTER_SUFFIX = ":dead"

# Module-scoped logger shared by every class and function in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RedisMessageBus:
    """Message bus backed by Redis Streams.

    Every agent has its own stream (``_STREAM_PREFIX + agent_name``). The
    stream is consumed through a consumer group with XREADGROUP. Entries are
    acknowledged with XACK once the agent's local handlers have run; handler
    exceptions are logged and do not prevent acknowledgement.

    Entries that cannot be decoded are re-queued on the same stream up to
    ``max_retries`` times (counted in the ``retries`` field). After that they
    are moved to the agent's dead-letter stream.

    ``request`` resolves in-process only: its future is completed when a
    reply carrying the same ``correlation_id`` is published or broadcast
    through this same bus instance. ``broadcast`` fans out only to agents
    subscribed on this instance.
    """

    # Builtin exception types treated as transport failures. redis-py's own
    # errors (redis.exceptions.RedisError, including its ConnectionError and
    # TimeoutError) do NOT subclass these builtins, so _transport_errors()
    # appends RedisError whenever the redis package is importable.
    _BASE_TRANSPORT_ERRORS: tuple[type[BaseException], ...] = (
        ConnectionError,
        OSError,
        asyncio.TimeoutError,
        RuntimeError,
    )

    # Stream-entry field counting how often an entry has been re-queued.
    _RETRY_FIELD = "retries"

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        consumer_group: str = "agentkit_bus",
        max_retries: int = 3,
    ) -> None:
        """Store connection settings; the Redis client is created lazily.

        Args:
            redis_url: URL passed to ``redis.asyncio.from_url``.
            consumer_group: Consumer-group name used on every agent stream.
            max_retries: How many times an entry that failed processing is
                re-queued before it is moved to the dead-letter stream
                (``0`` dead-letters on the first failure).
        """
        self._redis_url = redis_url
        self._consumer_group = consumer_group
        self._max_retries = max_retries
        self._redis: Any = None
        self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
        self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
        # correlation_id -> message_id of the outstanding request, so that
        # publishing the request itself never resolves its own future.
        self._request_ids: dict[str, str] = {}
        self._consumer_tasks: dict[str, asyncio.Task] = {}

    async def _get_redis(self) -> Any:
        """Return the shared Redis client, creating it on first use."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    @classmethod
    def _transport_errors(cls) -> tuple[type[BaseException], ...]:
        """Return the exception types that signal a Redis/transport failure."""
        try:
            from redis.exceptions import RedisError
        except ImportError:
            return cls._BASE_TRANSPORT_ERRORS
        return (*cls._BASE_TRANSPORT_ERRORS, RedisError)

    def _stream_key(self, agent_name: str) -> str:
        """Return the stream key that *agent_name* consumes from."""
        return f"{_STREAM_PREFIX}{agent_name}"

    def _dead_letter_key(self, agent_name: str) -> str:
        """Return the dead-letter stream key for *agent_name*."""
        return f"{_STREAM_PREFIX}{agent_name}{_DEAD_LETTER_SUFFIX}"

    def _resolve_pending(self, message: AgentMessage) -> None:
        """Complete the pending ``request`` future that *message* replies to.

        The message is ignored when its correlation id equals its own
        message id, or when it is the outstanding request itself. Both
        cases mean the message is the request, not a reply.
        """
        correlation_id = message.correlation_id
        if not correlation_id or message.message_id == correlation_id:
            return
        if self._request_ids.get(correlation_id) == message.message_id:
            return
        future = self._pending_requests.get(correlation_id)
        if future is not None and not future.done():
            future.set_result(message)

    async def publish(self, message: AgentMessage) -> None:
        """Append *message* to its recipient's stream.

        Broadcast messages are delegated to ``broadcast``. Transport errors
        are logged and re-raised.
        """
        if message.is_broadcast:
            await self.broadcast(message)
            return

        redis = await self._get_redis()
        stream_key = self._stream_key(message.recipient)
        encoded = json.dumps(message.to_dict())

        try:
            await redis.xadd(stream_key, {"data": encoded})
        except self._transport_errors() as e:
            logger.error(f"Failed to publish message to {stream_key}: {e}")
            raise

        self._resolve_pending(message)

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Register *handler* for *agent_name* and make sure it is being consumed.

        A consumer task is started when none exists, or when the previous
        one has already finished (for example it failed before entering its
        read loop).
        """
        self._subscribers.setdefault(agent_name, []).append(handler)

        task = self._consumer_tasks.get(agent_name)
        if task is None or task.done():
            self._consumer_tasks[agent_name] = asyncio.create_task(
                self._consume_stream(agent_name),
            )

    async def _ensure_group(self, redis: Any, stream_key: str) -> bool:
        """Create the consumer group on *stream_key*; return True when it exists."""
        try:
            await redis.xgroup_create(
                stream_key, self._consumer_group, id="0", mkstream=True,
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:  # noqa: BLE001 — redis is an optional dep, so ResponseError cannot be named here
            if "BUSYGROUP" in str(e):
                return True  # group already exists
            logger.warning(
                f"Could not create consumer group {self._consumer_group} "
                f"on {stream_key}: {e}"
            )
            return False
        return True

    async def _consume_stream(self, agent_name: str) -> None:
        """Read *agent_name*'s stream until cancelled, dispatching each entry.

        Group creation is retried until it succeeds, and again after a
        NOGROUP error (stream or group deleted). Any other error is logged
        and the loop continues after a one-second pause.
        """
        redis = await self._get_redis()
        stream_key = self._stream_key(agent_name)
        group_ready = False

        while True:
            try:
                if not group_ready:
                    group_ready = await self._ensure_group(redis, stream_key)
                    if not group_ready:
                        await asyncio.sleep(1)
                        continue

                results = await redis.xreadgroup(
                    groupname=self._consumer_group,
                    consumername=agent_name,
                    streams={stream_key: ">"},
                    count=10,
                    block=1000,
                )
                for _stream_name, entries in results or []:
                    for msg_id, fields in entries:
                        await self._process_entry(
                            redis, stream_key, agent_name, msg_id, fields,
                        )
            except asyncio.CancelledError:
                break
            except Exception as e:  # noqa: BLE001 — consumer loop top-level fallback; must keep running on transient errors
                logger.error(f"Consumer error for {agent_name}: {e}")
                if "NOGROUP" in str(e):
                    group_ready = False
                await asyncio.sleep(1)

    async def _process_entry(
        self,
        redis: Any,
        stream_key: str,
        agent_name: str,
        msg_id: str,
        fields: dict,
    ) -> None:
        """Decode one stream entry, run the agent's handlers, then XACK it.

        Entries that cannot be decoded go to ``_handle_failed_message``.
        Handler exceptions are logged per handler. A failed XACK is logged
        only; the handlers already ran, so re-queueing would deliver the
        message twice.
        """
        try:
            message = AgentMessage.from_dict(json.loads(fields.get("data", "{}")))
        except asyncio.CancelledError:
            raise
        except Exception as e:  # noqa: BLE001 — malformed JSON or a payload AgentMessage.from_dict rejects
            logger.warning(f"Failed to process message {msg_id}: {e}")
            await self._handle_failed_message(
                redis, stream_key, msg_id, fields, agent_name,
            )
            return

        # Iterate over a snapshot so a handler that subscribes another
        # handler does not change this delivery.
        for handler in list(self._subscribers.get(agent_name, [])):
            try:
                await handler(message)
            except asyncio.CancelledError:
                raise
            except Exception as e:  # noqa: BLE001 — user-defined async handlers can throw arbitrary exceptions
                logger.warning(f"Handler error for {agent_name}: {e}")

        try:
            await redis.xack(stream_key, self._consumer_group, msg_id)
        except self._transport_errors() as e:
            logger.error(f"Failed to acknowledge message {msg_id}: {e}")

    async def _handle_failed_message(
        self,
        redis: Any,
        stream_key: str,
        msg_id: str,
        fields: dict,
        agent_name: str,
    ) -> None:
        """Re-queue a failed entry, or dead-letter it once retries are exhausted.

        The retry count travels with the entry in the ``_RETRY_FIELD`` field.
        While the count is below ``max_retries``, a copy with the count
        incremented is appended to *stream_key*. Otherwise the entry is
        copied unchanged to the agent's dead-letter stream. In both cases
        the original entry is then acknowledged. Redis failures are logged,
        not raised.
        """
        try:
            retries = int(fields.get(self._RETRY_FIELD, 0))
        except (TypeError, ValueError):
            retries = 0  # corrupt counter: treat as a first failure

        try:
            if retries < self._max_retries:
                await redis.xadd(
                    stream_key, {**fields, self._RETRY_FIELD: str(retries + 1)},
                )
                await redis.xack(stream_key, self._consumer_group, msg_id)
                logger.warning(
                    f"Message {msg_id} re-queued for retry "
                    f"{retries + 1}/{self._max_retries}"
                )
            else:
                await redis.xadd(self._dead_letter_key(agent_name), fields)
                await redis.xack(stream_key, self._consumer_group, msg_id)
                logger.warning(f"Message {msg_id} moved to dead letter queue")
        except self._transport_errors() as e:
            logger.error(f"Failed to re-queue or dead-letter message {msg_id}: {e}")

    async def unsubscribe(self, agent_name: str) -> None:
        """Drop all handlers for *agent_name* and stop its consumer task."""
        self._subscribers.pop(agent_name, None)
        task = self._consumer_tasks.pop(agent_name, None)
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def request(
        self,
        message: AgentMessage,
        timeout: float = 30.0,
    ) -> AgentMessage:
        """Publish *message* and wait for the reply that shares its correlation id.

        If ``message.correlation_id`` is unset, it is set to
        ``message.message_id`` in place.

        Raises:
            TimeoutError: no reply arrived within *timeout* seconds.
        """
        if not message.correlation_id:
            message.correlation_id = message.message_id
        correlation_id = message.correlation_id

        future: asyncio.Future[AgentMessage] = asyncio.get_running_loop().create_future()
        self._pending_requests[correlation_id] = future
        self._request_ids[correlation_id] = message.message_id

        try:
            await self.publish(message)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"Request {correlation_id} timed out after {timeout}s"
            ) from exc
        finally:
            self._pending_requests.pop(correlation_id, None)
            self._request_ids.pop(correlation_id, None)

    async def broadcast(self, message: AgentMessage) -> None:
        """Write *message* to the stream of every agent subscribed on this instance.

        ``message.recipient`` is cleared in place. A failure to write to one
        stream is logged and does not stop delivery to the others.
        """
        message.recipient = None

        redis = await self._get_redis()
        encoded = json.dumps(message.to_dict())
        transport_errors = self._transport_errors()

        # Iterate over a snapshot: subscribe() may add keys while we await
        # xadd, and mutating a dict during iteration raises RuntimeError.
        for agent_name in list(self._subscribers):
            stream_key = self._stream_key(agent_name)
            try:
                await redis.xadd(stream_key, {"data": encoded})
            except transport_errors as e:
                logger.error(f"Failed to broadcast to {agent_name}: {e}")

        self._resolve_pending(message)

    async def health_check(self) -> bool:
        """Return True if the Redis server answers PING, False on any failure."""
        try:
            redis = await self._get_redis()
            return bool(await redis.ping())
        except asyncio.CancelledError:
            raise
        except Exception as e:  # noqa: BLE001 — redis-py raises RedisError subclasses (not builtins); ImportError if redis is absent
            logger.debug(f"Redis health check failed: {e}")
            return False

    @property
    def backend_type(self) -> str:
        """Identifier of this bus implementation."""
        return "redis_streams"
|
|
|
|
|
|
def _redact_redis_url(url: str) -> str:
    """Return *url* with any password in its user-info replaced by ``***``.

    Used so connection URLs can be logged without leaking credentials.
    URLs without a password are returned unchanged.
    """
    from urllib.parse import urlsplit, urlunsplit

    try:
        parts = urlsplit(url)
        password = parts.password
    except ValueError:
        return "<unparseable redis url>"
    if password is None:
        return url
    host = parts.netloc.rpartition("@")[2]
    user = parts.username or ""
    return urlunsplit(parts._replace(netloc=f"{user}:***@{host}"))


def create_message_bus(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    consumer_group: str = "agentkit_bus",
    max_retries: int = 3,
) -> InMemoryMessageBus | RedisMessageBus:
    """Create a message-bus instance.

    Args:
        backend: ``"redis"`` or ``"memory"``. Any other value logs a warning
            and falls back to the in-memory bus.
        redis_url: Redis connection URL (only used for ``"redis"``). Any
            password in it is masked before the URL is logged.
        consumer_group: Redis consumer-group name.
        max_retries: Maximum number of retries for a failed message;
            forwarded to ``RedisMessageBus``.

    Returns:
        A ``RedisMessageBus`` when ``backend == "redis"`` and the ``redis``
        package can be imported; otherwise an ``InMemoryMessageBus``.
    """
    if backend == "redis":
        try:
            # Availability probe only: RedisMessageBus imports redis lazily.
            import redis.asyncio as aioredis  # noqa: F401
            bus = RedisMessageBus(
                redis_url=redis_url,
                consumer_group=consumer_group,
                max_retries=max_retries,
            )
            logger.info(
                f"MessageBus backend: redis_streams ({_redact_redis_url(redis_url)})"
            )
            return bus
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001 — factory fallback to InMemoryMessageBus; must catch import/init errors broadly
            logger.warning(
                f"Failed to initialise RedisMessageBus ({exc}), "
                f"falling back to InMemoryMessageBus"
            )
    elif backend != "memory":
        # Surface configuration typos instead of silently using the memory bus.
        logger.warning(
            f"Unknown MessageBus backend {backend!r}, "
            f"falling back to InMemoryMessageBus"
        )

    bus = InMemoryMessageBus()
    logger.info("MessageBus backend: memory")
    return bus
|