184 lines
5.8 KiB
Python
184 lines
5.8 KiB
Python
"""Tests for MessageBus (U6) — InMemory implementation and message model."""
|
||
|
||
import asyncio
|
||
import pytest
|
||
|
||
from agentkit.bus.message import AgentMessage
|
||
from agentkit.bus.memory_bus import InMemoryMessageBus
|
||
from agentkit.bus.redis_bus import create_message_bus
|
||
|
||
|
||
# ── AgentMessage Tests ────────────────────────────────────
|
||
|
||
|
||
class TestAgentMessage:
    """Unit tests for the ``AgentMessage`` data model."""

    def test_default_values(self):
        """Only ``sender`` is supplied; every other field keeps its default."""
        message = AgentMessage(sender="agent_a")
        # message_id is generated automatically, so only its truthiness is checked.
        assert message.message_id
        assert message.sender == "agent_a"
        assert message.recipient is None
        assert message.topic == ""
        assert message.payload == {}
        assert message.correlation_id is None
        # Without a recipient the message counts as a broadcast.
        assert message.is_broadcast is True

    def test_point_to_point(self):
        """Naming an explicit recipient turns broadcast mode off."""
        direct = AgentMessage(sender="a", recipient="b", topic="test")
        assert direct.is_broadcast is False

    def test_to_dict_and_from_dict(self):
        """A to_dict() -> from_dict() round trip keeps every explicitly set field."""
        source = AgentMessage(
            sender="a",
            recipient="b",
            topic="result",
            payload={"key": "value"},
            correlation_id="corr-123",
        )
        restored = AgentMessage.from_dict(source.to_dict())

        expected_fields = {
            "sender": "a",
            "recipient": "b",
            "topic": "result",
            "payload": {"key": "value"},
            "correlation_id": "corr-123",
        }
        for field_name, field_value in expected_fields.items():
            assert getattr(restored, field_name) == field_value, field_name

    def test_unique_message_ids(self):
        """Each newly constructed message receives a distinct message_id."""
        sample_size = 100
        seen_ids = set()
        for _ in range(sample_size):
            seen_ids.add(AgentMessage().message_id)
        assert len(seen_ids) == sample_size
|
||
|
||
|
||
# ── InMemoryMessageBus Tests ──────────────────────────────
|
||
|
||
|
||
class TestInMemoryMessageBus:
    """Behavioural tests for the in-process ``InMemoryMessageBus``."""

    # Seconds to yield so the bus's consumer task can run before a test
    # inspects what its handlers received. Shared so every delivery test
    # waits the same amount.
    _DELIVERY_GRACE = 0.1

    @pytest.mark.asyncio
    async def test_point_to_point_delivery(self):
        """Agent A sends a message to Agent B, and B receives it."""
        bus = InMemoryMessageBus()
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)
        await bus.publish(AgentMessage(
            sender="agent_a", recipient="agent_b",
            topic="test", payload={"data": "hello"},
        ))

        # Give consumer task time to process
        await asyncio.sleep(self._DELIVERY_GRACE)
        assert len(received) == 1
        assert received[0].payload["data"] == "hello"

    @pytest.mark.asyncio
    async def test_broadcast_delivery(self):
        """Agent A broadcasts, and every subscriber receives the message."""
        bus = InMemoryMessageBus()
        a_received: list[AgentMessage] = []
        b_received: list[AgentMessage] = []

        async def handler_a(msg: AgentMessage):
            a_received.append(msg)

        async def handler_b(msg: AgentMessage):
            b_received.append(msg)

        await bus.subscribe("agent_a", handler_a)
        await bus.subscribe("agent_b", handler_b)

        await bus.broadcast(AgentMessage(
            sender="orchestrator", topic="status",
            payload={"status": "started"},
        ))

        # As with publish(), delivery may go through the consumer task.
        # Without this wait the assertions below race the handlers.
        await asyncio.sleep(self._DELIVERY_GRACE)
        assert len(a_received) == 1
        assert len(b_received) == 1

    @pytest.mark.asyncio
    async def test_request_response(self):
        """Agent A sends a request, Agent B replies, and A receives the response."""
        bus = InMemoryMessageBus()

        async def handler_b(msg: AgentMessage):
            # Reply with the request's correlation_id so request() can match it.
            reply = AgentMessage(
                sender="agent_b",
                recipient=msg.sender,
                topic="reply",
                payload={"answer": 42},
                correlation_id=msg.correlation_id,
            )
            await bus.publish(reply)

        await bus.subscribe("agent_b", handler_b)

        # Send request
        request = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            topic="question",
            payload={"q": "What is the answer?"},
        )

        response = await bus.request(request, timeout=5.0)
        assert response.payload["answer"] == 42

    @pytest.mark.asyncio
    async def test_request_timeout(self):
        """A request that gets no reply raises a timeout error."""
        bus = InMemoryMessageBus()

        # No one is subscribed to handle the request
        request = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            topic="question",
        )

        # Before Python 3.11, asyncio.TimeoutError is not the builtin
        # TimeoutError. Accept either so the test does not depend on how
        # request() implements its timeout.
        with pytest.raises((TimeoutError, asyncio.TimeoutError)):
            await bus.request(request, timeout=0.1)

    @pytest.mark.asyncio
    async def test_unsubscribe_stops_delivery(self):
        """After unsubscribing, the handler no longer receives messages."""
        bus = InMemoryMessageBus()
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)
        await bus.unsubscribe("agent_b")

        await bus.broadcast(AgentMessage(sender="a", topic="test"))
        # Wait long enough that a (wrongly) still-subscribed handler would have run.
        await asyncio.sleep(self._DELIVERY_GRACE)

        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_health_check(self):
        """A freshly created in-memory bus reports itself healthy."""
        bus = InMemoryMessageBus()
        assert await bus.health_check() is True

    @pytest.mark.asyncio
    async def test_backend_type(self):
        """The in-memory bus identifies its backend as "memory"."""
        bus = InMemoryMessageBus()
        assert bus.backend_type == "memory"
|
||
|
||
|
||
# ── Factory Tests ─────────────────────────────────────────
|
||
|
||
|
||
class TestCreateMessageBus:
    """Tests for the ``create_message_bus`` backend factory."""

    def test_memory_backend(self):
        """Requesting the "memory" backend yields an InMemoryMessageBus."""
        bus = create_message_bus(backend="memory")
        assert isinstance(bus, InMemoryMessageBus)

    def test_redis_fallback_to_memory(self):
        """Fall back to InMemory when Redis is unavailable.

        The returned type depends on whether ``redis.asyncio`` is importable.
        The strict fallback check therefore runs only when the redis client
        library is missing. Otherwise the test only requires that the factory
        returned some bus.
        """
        bus = create_message_bus(backend="redis")
        try:
            import redis.asyncio  # noqa: F401  (availability probe only)
        except ImportError:
            # Without the redis client library the factory must fall back.
            assert isinstance(bus, InMemoryMessageBus)
        else:
            # Redis client available: either backend is acceptable here.
            assert bus is not None
|