fischer-agentkit/src/agentkit/bus/redis_bus.py

277 lines
10 KiB
Python

"""RedisMessageBus — 基于 Redis Streams 的消息总线。
使用 XADD/XREADGROUP 实现可靠消息传递,支持消费者组、
消息确认和死信队列。
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Callable, Awaitable
from agentkit.bus.message import AgentMessage
from agentkit.bus.memory_bus import InMemoryMessageBus
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Prefix of every Redis key built by this module (see RedisMessageBus._stream_key).
_STREAM_PREFIX = "agentkit:bus:"
# Suffix added after the agent name to form that agent's dead-letter stream key
# (see RedisMessageBus._dead_letter_key).
_DEAD_LETTER_SUFFIX = ":dead"
class RedisMessageBus:
    """Message bus backed by Redis Streams.

    Every agent has its own stream whose key is ``_STREAM_PREFIX`` followed by
    the agent name. Messages are appended with XADD, read through a consumer
    group with XREADGROUP and acknowledged with XACK. Entries whose processing
    raises are copied to a per-agent dead-letter stream.

    The Redis client is created lazily on first use, so constructing the bus
    performs no I/O and does not require the ``redis`` package to be importable.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        consumer_group: str = "agentkit_bus",
        max_retries: int = 3,
    ) -> None:
        """Store the configuration; no connection is opened here.

        Args:
            redis_url: URL handed to ``redis.asyncio.from_url`` on first use.
            consumer_group: Consumer group name used for every stream.
            max_retries: Stored on the instance. NOTE(review): no method of this
                class reads it yet, so a failing message is dead-lettered on its
                first failure rather than retried.
        """
        self._redis_url = redis_url
        self._consumer_group = consumer_group
        self._max_retries = max_retries
        self._redis: Any = None
        # agent name -> handlers invoked for each message read from its stream
        self._subscribers: dict[str, list[Callable[[AgentMessage], Awaitable[None]]]] = {}
        # correlation id -> future awaited by request()
        self._pending_requests: dict[str, asyncio.Future[AgentMessage]] = {}
        # agent name -> background task running _consume_stream()
        self._consumer_tasks: dict[str, asyncio.Task] = {}

    async def _get_redis(self) -> Any:
        """Return the shared Redis client, creating it on first call.

        ``redis`` is imported here rather than at module level so that the
        module can be imported when the package is not installed.
        """
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _stream_key(self, agent_name: str) -> str:
        """Return the Redis stream key for ``agent_name``."""
        return f"{_STREAM_PREFIX}{agent_name}"

    def _dead_letter_key(self, agent_name: str) -> str:
        """Return the dead-letter stream key for ``agent_name``."""
        return f"{_STREAM_PREFIX}{agent_name}{_DEAD_LETTER_SUFFIX}"

    def _resolve_pending(self, message: AgentMessage) -> None:
        """Complete the local ``request()`` future that ``message`` replies to.

        A message counts as a reply only when its correlation id differs from
        its own message id; otherwise the original request would resolve itself.
        """
        correlation_id = message.correlation_id
        if not correlation_id or message.message_id == correlation_id:
            return
        future = self._pending_requests.get(correlation_id)
        if future is not None and not future.done():
            future.set_result(message)

    async def publish(self, message: AgentMessage) -> None:
        """Append ``message`` to its recipient's stream.

        Broadcast messages are delegated to :meth:`broadcast`.

        Raises:
            Exception: Whatever ``xadd`` raised; it is logged and re-raised.
        """
        if message.is_broadcast:
            await self.broadcast(message)
            return

        redis = await self._get_redis()
        stream_key = self._stream_key(message.recipient)
        payload = json.dumps(message.to_dict())
        try:
            await redis.xadd(stream_key, {"data": payload})
        except asyncio.CancelledError:
            raise
        except Exception as e:  # noqa: BLE001 — redis-py errors derive from RedisError, not builtin ConnectionError/OSError
            logger.error(f"Failed to publish message to {stream_key}: {e}")
            raise

        self._resolve_pending(message)

    async def subscribe(
        self,
        agent_name: str,
        handler: Callable[[AgentMessage], Awaitable[None]],
    ) -> None:
        """Register ``handler`` for ``agent_name`` and ensure a consumer task runs.

        Several handlers may be registered for the same agent; only one
        consumer task is started per agent.
        """
        self._subscribers.setdefault(agent_name, []).append(handler)
        if agent_name not in self._consumer_tasks:
            self._consumer_tasks[agent_name] = asyncio.create_task(
                self._consume_stream(agent_name),
            )

    async def _ensure_consumer_group(self, redis: Any, stream_key: str) -> None:
        """Create the consumer group (and the stream) unless it already exists."""
        try:
            await redis.xgroup_create(
                stream_key, self._consumer_group, id="0", mkstream=True,
            )
        except asyncio.CancelledError:
            raise
        except Exception as e:  # noqa: BLE001 — redis is an optional dep, so ResponseError cannot be named here
            # Redis answers BUSYGROUP when the group already exists, which is
            # the normal case after the first run; anything else is reported.
            if "BUSYGROUP" not in str(e):
                logger.warning(f"Failed to create consumer group for {stream_key}: {e}")

    async def _process_message(
        self,
        redis: Any,
        stream_key: str,
        agent_name: str,
        msg_id: str,
        fields: dict,
    ) -> None:
        """Decode one stream entry, run the agent's handlers, then acknowledge it.

        A handler that raises is logged and does not prevent the remaining
        handlers or the acknowledgement. A failure while decoding or
        acknowledging sends the entry to the dead-letter stream.
        """
        try:
            data = json.loads(fields.get("data", "{}"))
            message = AgentMessage.from_dict(data)
            for handler in self._subscribers.get(agent_name, []):
                try:
                    await handler(message)
                except asyncio.CancelledError:
                    raise
                except Exception as e:  # noqa: BLE001 — user-defined async handlers can throw arbitrary exceptions
                    logger.warning(f"Handler error for {agent_name}: {e}")
            await redis.xack(stream_key, self._consumer_group, msg_id)
        except asyncio.CancelledError:
            raise
        except Exception as e:  # noqa: BLE001 — wraps json parse + from_dict + xack; diverse failure modes
            logger.warning(f"Failed to process message {msg_id}: {e}")
            await self._handle_failed_message(
                redis, stream_key, msg_id, fields, agent_name,
            )

    async def _consume_stream(self, agent_name: str) -> None:
        """Read and dispatch entries from ``agent_name``'s stream until cancelled.

        Only entries not yet delivered to the group (id ``">"``) are requested,
        at most 10 per call, with ``block=1000``. Any error in the loop is
        logged and followed by a one-second pause before the next read.
        """
        redis = await self._get_redis()
        stream_key = self._stream_key(agent_name)
        await self._ensure_consumer_group(redis, stream_key)

        while True:
            try:
                results = await redis.xreadgroup(
                    groupname=self._consumer_group,
                    consumername=agent_name,
                    streams={stream_key: ">"},
                    count=10,
                    block=1000,
                )
                for _stream_name, messages in results or ():
                    for msg_id, fields in messages:
                        await self._process_message(
                            redis, stream_key, agent_name, msg_id, fields,
                        )
            except asyncio.CancelledError:
                break
            except Exception as e:  # noqa: BLE001 — consumer loop top-level fallback; must keep running on transient errors
                logger.error(f"Consumer error for {agent_name}: {e}")
                await asyncio.sleep(1)

    async def _handle_failed_message(
        self,
        redis: Any,
        stream_key: str,
        msg_id: str,
        fields: dict,
        agent_name: str,
    ) -> None:
        """Copy a failed entry to the dead-letter stream and acknowledge it.

        Failures are logged rather than raised, so that one entry which cannot
        be dead-lettered does not abandon the rest of the batch being processed.
        """
        dead_key = self._dead_letter_key(agent_name)
        try:
            await redis.xadd(dead_key, fields)
            await redis.xack(stream_key, self._consumer_group, msg_id)
            logger.warning(f"Message {msg_id} moved to dead letter queue")
        except asyncio.CancelledError:
            raise
        except Exception as e:  # noqa: BLE001 — redis-py errors derive from RedisError, not builtin ConnectionError/OSError
            logger.error(f"Failed to move message to dead letter: {e}")

    async def unsubscribe(self, agent_name: str) -> None:
        """Drop all handlers for ``agent_name`` and stop its consumer task."""
        self._subscribers.pop(agent_name, None)
        task = self._consumer_tasks.pop(agent_name, None)
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def request(
        self,
        message: AgentMessage,
        timeout: float = 30.0,
    ) -> AgentMessage:
        """Publish ``message`` and wait for the reply carrying its correlation id.

        If the message has no correlation id, its message id is used.

        Args:
            message: The request to send.
            timeout: Seconds to wait for the reply.

        Returns:
            The reply message.

        Raises:
            TimeoutError: No reply arrived within ``timeout`` seconds.
        """
        if not message.correlation_id:
            message.correlation_id = message.message_id

        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated for this use.
        future: asyncio.Future[AgentMessage] = asyncio.get_running_loop().create_future()
        self._pending_requests[message.correlation_id] = future
        try:
            await self.publish(message)
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"Request {message.correlation_id} timed out after {timeout}s"
            ) from exc
        finally:
            self._pending_requests.pop(message.correlation_id, None)

    async def broadcast(self, message: AgentMessage) -> None:
        """Append ``message`` to the stream of every locally subscribed agent.

        Delivery is best-effort: a failure for one agent is logged and the
        remaining agents are still attempted. ``message.recipient`` is set to
        ``None`` as a side effect.
        """
        message.recipient = None
        redis = await self._get_redis()
        data = message.to_dict()

        # Snapshot the names: subscribe()/unsubscribe() can run while xadd is
        # awaited, and resizing a dict during iteration raises RuntimeError.
        recipients = list(self._subscribers)
        if recipients:
            payload = json.dumps(data)
            for agent_name in recipients:
                stream_key = self._stream_key(agent_name)
                try:
                    await redis.xadd(stream_key, {"data": payload})
                except asyncio.CancelledError:
                    raise
                except Exception as e:  # noqa: BLE001 — redis-py errors derive from RedisError, not builtin ConnectionError/OSError
                    logger.error(f"Failed to broadcast to {agent_name}: {e}")

        self._resolve_pending(message)

    async def health_check(self) -> bool:
        """Return True when Redis answers PING, False on any failure."""
        try:
            redis = await self._get_redis()
            return bool(await redis.ping())
        except asyncio.CancelledError:
            raise
        except Exception:  # noqa: BLE001 — redis-py errors derive from RedisError, not builtin ConnectionError/OSError
            return False

    @property
    def backend_type(self) -> str:
        """Identifier of this backend."""
        return "redis_streams"
def create_message_bus(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    consumer_group: str = "agentkit_bus",
    max_retries: int = 3,
) -> InMemoryMessageBus | RedisMessageBus:
    """Build a message bus for the requested backend.

    Any backend other than ``"redis"`` yields an in-memory bus. When ``"redis"``
    is requested but the ``redis`` package cannot be imported, or the bus cannot
    be constructed, a warning is logged and an in-memory bus is returned instead.

    Args:
        backend: ``"memory"`` or ``"redis"``.
        redis_url: Redis connection URL.
        consumer_group: Redis consumer group name.
        max_retries: Maximum retry count, forwarded to ``RedisMessageBus``.

    Returns:
        A ``RedisMessageBus`` or an ``InMemoryMessageBus``.
    """
    if backend == "redis":
        try:
            # Imported only to verify that the optional dependency is installed.
            import redis.asyncio as aioredis  # noqa: F401

            redis_bus = RedisMessageBus(
                redis_url=redis_url,
                consumer_group=consumer_group,
                max_retries=max_retries,
            )
            logger.info(f"MessageBus backend: redis_streams ({redis_url})")
        except asyncio.CancelledError:
            raise
        except Exception as exc:  # noqa: BLE001 — factory fallback to InMemoryMessageBus; must catch import/init errors broadly
            logger.warning(
                f"Failed to initialise RedisMessageBus ({exc}), "
                f"falling back to InMemoryMessageBus"
            )
        else:
            return redis_bus

    # Reached for the memory backend and after a failed Redis initialisation.
    memory_bus = InMemoryMessageBus()
    logger.info("MessageBus backend: memory")
    return memory_bus