517 lines
16 KiB
Python
517 lines
16 KiB
Python
"""Unit tests for HandoffManager and the HandoffMessage protocol type."""
|
||
|
||
import asyncio
|
||
import json
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.protocol import HandoffMessage
|
||
from agentkit.orchestrator.handoff import HandoffManager
|
||
|
||
|
||
# ── HandoffMessage 创建与序列化测试 ─────────────────────────────
|
||
|
||
|
||
class TestHandoffMessage:
    """Tests for creating and serializing HandoffMessage."""

    @staticmethod
    def _message(**overrides):
        """Build an agent_a -> agent_b HandoffMessage; keyword args override fields."""
        fields = {
            "source_agent": "agent_a",
            "target_agent": "agent_b",
            "task_id": "task-001",
            "task_type": "analysis",
            "context": {"key": "value"},
            "reason": "needs expertise",
        }
        fields.update(overrides)
        return HandoffMessage(**fields)

    def test_creation_with_required_fields(self):
        msg = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="task-001",
            task_type="analysis",
            context={"key": "value"},
            reason="needs expertise",
        )
        assert msg.source_agent == "agent_a"
        assert msg.target_agent == "agent_b"
        assert msg.task_id == "task-001"
        assert msg.task_type == "analysis"
        assert msg.context == {"key": "value"}
        assert msg.reason == "needs expertise"
        assert msg.created_at is not None

    def test_to_dict_roundtrip(self):
        msg = self._message(context={"data": [1, 2, 3]}, reason="specialization")
        restored = HandoffMessage.from_dict(msg.to_dict())

        assert restored.source_agent == msg.source_agent
        assert restored.target_agent == msg.target_agent
        assert restored.task_id == msg.task_id
        assert restored.task_type == msg.task_type
        assert restored.context == msg.context
        assert restored.reason == msg.reason

    def test_to_dict_contains_all_fields(self):
        d = self._message().to_dict()

        expected = {
            "source_agent",
            "target_agent",
            "task_id",
            "task_type",
            "context",
            "reason",
            "created_at",
        }
        missing = expected - d.keys()
        assert not missing, f"to_dict() is missing keys: {sorted(missing)}"

    def test_from_dict_defaults_context(self):
        data = {
            "source_agent": "a",
            "target_agent": "b",
            "task_id": "t1",
            "task_type": "search",
            "reason": "test",
        }
        msg = HandoffMessage.from_dict(data)
        assert msg.context == {}

    def test_from_dict_parses_created_at_string(self):
        data = {
            "source_agent": "a",
            "target_agent": "b",
            "task_id": "t1",
            "task_type": "search",
            "context": {},
            "reason": "test",
            "created_at": "2025-01-15T10:30:00+00:00",
        }
        msg = HandoffMessage.from_dict(data)
        assert msg.created_at.year == 2025
        assert msg.created_at.month == 1
        assert msg.created_at.day == 15

    def test_json_serializable(self):
        msg = self._message()
        # to_dict() output must survive a real JSON encode/decode cycle.
        restored = HandoffMessage.from_dict(json.loads(json.dumps(msg.to_dict())))

        assert restored.source_agent == msg.source_agent
        assert restored.target_agent == msg.target_agent
        assert restored.task_id == msg.task_id
        # The payload fields must round-trip too, not only the identity fields.
        assert restored.task_type == msg.task_type
        assert restored.context == msg.context
        assert restored.reason == msg.reason
||
|
||
|
||
# ── HandoffManager 无 Redis(本地模式)测试 ──────────────────────
|
||
|
||
|
||
class TestHandoffManagerLocalMode:
    """HandoffManager tests without Redis (local mode)."""

    def test_construction_without_redis(self):
        manager = HandoffManager()
        assert manager._redis is None
        assert manager._handlers == {}

    def test_construction_with_dispatcher(self):
        manager = HandoffManager(dispatcher="mock_dispatcher")
        assert manager._dispatcher == "mock_dispatcher"

    async def test_send_handoff_without_redis_raises(self):
        manager = HandoffManager()
        handoff = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t1",
            task_type="search",
            context={},
            reason="test",
        )
        with pytest.raises(RuntimeError, match="Redis connection"):
            await manager.send_handoff(handoff)

    async def test_listen_for_handoffs_without_redis_returns(self):
        manager = HandoffManager()
        # Without Redis the listener should return immediately and not raise.
        # Bound the wait so a regression that makes it block fails fast
        # instead of hanging the whole test session.
        await asyncio.wait_for(manager.listen_for_handoffs("agent_a"), timeout=1.0)

    def test_register_handler(self):
        manager = HandoffManager()

        async def handler(msg):
            pass

        manager.register_handler("agent_a", handler)
        assert "agent_a" in manager._handlers
        assert handler in manager._handlers["agent_a"]

    def test_register_multiple_handlers_for_same_agent(self):
        manager = HandoffManager()

        async def handler1(msg):
            pass

        async def handler2(msg):
            pass

        manager.register_handler("agent_a", handler1)
        manager.register_handler("agent_a", handler2)
        assert len(manager._handlers["agent_a"]) == 2

    def test_register_handlers_for_different_agents(self):
        manager = HandoffManager()

        async def handler_a(msg):
            pass

        async def handler_b(msg):
            pass

        manager.register_handler("agent_a", handler_a)
        manager.register_handler("agent_b", handler_b)
        assert "agent_a" in manager._handlers
        assert "agent_b" in manager._handlers
        assert len(manager._handlers) == 2
|
||
|
||
|
||
# ── HandoffManager _handle_handoff 测试 ─────────────────────────
|
||
|
||
|
||
class TestHandoffManagerHandleHandoff:
    """Tests for HandoffManager's internal ``_handle_handoff`` dispatch."""

    @staticmethod
    def _handoff(**overrides):
        """Return an agent_a -> agent_b handoff for task t1; kwargs replace defaults."""
        params = dict(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t1",
            task_type="search",
            context={},
            reason="test",
        )
        params.update(overrides)
        return HandoffMessage(**params)

    async def test_handle_handoff_calls_registered_handlers(self):
        manager = HandoffManager()
        seen = []

        async def record(msg):
            seen.append(msg)

        manager.register_handler("agent_b", record)
        await manager._handle_handoff(
            self._handoff(context={"q": "test"}, reason="delegation")
        )

        assert len(seen) == 1
        delivered = seen[0]
        assert delivered.task_id == "t1"
        assert delivered.source_agent == "agent_a"

    async def test_handle_handoff_no_handler_does_nothing(self):
        # Nothing is registered for agent_b; dispatching must not raise.
        await HandoffManager()._handle_handoff(self._handoff())

    async def test_handle_handoff_handler_error_is_caught(self):
        manager = HandoffManager()

        async def explode(msg):
            raise ValueError("handler error")

        manager.register_handler("agent_b", explode)
        # The handler's ValueError must be contained by _handle_handoff.
        await manager._handle_handoff(self._handoff())

    async def test_handle_handoff_multiple_handlers(self):
        manager = HandoffManager()
        calls = []

        async def first(msg):
            calls.append("handler1")

        async def second(msg):
            calls.append("handler2")

        for fn in (first, second):
            manager.register_handler("agent_b", fn)

        await manager._handle_handoff(self._handoff())

        # Order is intentionally not asserted: each handler must run exactly once.
        assert sorted(calls) == ["handler1", "handler2"]
|
||
|
||
|
||
# ── HandoffManager Redis Pub/Sub 测试 ───────────────────────────
|
||
|
||
|
||
def _redis_available():
    """Return True if a Redis server answers PING at ``$REDIS_URL``.

    ``REDIS_URL`` defaults to ``redis://localhost:6381/0``. Any failure is
    reported as ``False`` so that the Redis-backed tests are skipped instead
    of breaking collection of the whole module. Failures include the
    ``redis`` package not being installed, a malformed URL, and a refused or
    timed-out connection.
    """
    import os

    try:
        import redis
    except ImportError:
        # The client library is optional: local-mode tests must still run.
        return False

    url = os.environ.get("REDIS_URL", "redis://localhost:6381/0")
    try:
        # Short timeouts: this probe runs at import time and must not stall
        # collection when the configured host is unreachable.
        client = redis.from_url(url, socket_connect_timeout=1.0, socket_timeout=1.0)
    except Exception:
        # e.g. malformed URL / unsupported scheme.
        return False
    try:
        client.ping()
    except Exception:
        # Deliberate best-effort probe: any connection/auth error means "not available".
        return False
    finally:
        # Release the connection on both the success and the failure path.
        client.close()
    return True


redis_available = _redis_available()
|
||
|
||
|
||
@pytest.mark.redis
@pytest.mark.skipif(not redis_available, reason="Redis not available")
class TestHandoffManagerRedisMode:
    """HandoffManager Redis Pub/Sub tests (require a reachable Redis server).

    Redis Pub/Sub is fire-and-forget: a message published while nobody is
    subscribed is dropped. Every test therefore subscribes first and only
    then publishes.
    """

    @staticmethod
    async def _subscribe(redis_client, channel):
        """Open a pub/sub connection on *channel* and consume the subscribe confirmation."""
        pubsub = redis_client.pubsub()
        await pubsub.subscribe(channel)
        await asyncio.wait_for(pubsub.get_message(timeout=2.0), timeout=3.0)
        return pubsub

    @staticmethod
    async def _next_payload(pubsub, timeout=5.0):
        """Return the JSON-decoded body of the next ``message`` on *pubsub*.

        Non-``message`` entries, such as subscribe confirmations, are skipped.
        The whole wait is bounded by *timeout*. Each ``get_message`` poll
        returns ``None`` on its own short timeout, so an unbounded polling
        loop would spin forever if the message never arrived.

        Raises:
            asyncio.TimeoutError: no message arrived within *timeout* seconds.
        """

        async def _poll():
            while True:
                msg = await pubsub.get_message(timeout=1.0)
                if msg and msg.get("type") == "message":
                    return json.loads(msg["data"])

        return await asyncio.wait_for(_poll(), timeout=timeout)

    @staticmethod
    async def _close(pubsub, channel):
        """Unsubscribe from *channel* and release the pub/sub connection."""
        await pubsub.unsubscribe(channel)
        # Newer redis-py releases expose ``aclose`` (``close`` is deprecated
        # there); older releases only provide ``close``.
        close = getattr(pubsub, "aclose", None) or pubsub.close
        await close()

    async def test_send_handoff_publishes_to_channel(self, redis_client, clean_redis):
        manager = HandoffManager(redis=redis_client)
        handoff = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t1",
            task_type="search",
            context={"q": "hello"},
            reason="delegation",
        )
        channel = "agent:agent_b:handoff"

        pubsub = await self._subscribe(redis_client, channel)
        try:
            await manager.send_handoff(handoff)
            data = await self._next_payload(pubsub)
        finally:
            await self._close(pubsub, channel)

        assert data["source_agent"] == "agent_a"
        assert data["target_agent"] == "agent_b"
        assert data["task_id"] == "t1"
        assert data["reason"] == "delegation"

    async def test_send_handoff_channel_format(self, redis_client, clean_redis):
        """A handoff is published on the ``agent:{target_agent}:handoff`` channel."""
        manager = HandoffManager(redis=redis_client)
        handoff = HandoffMessage(
            source_agent="planner",
            target_agent="executor",
            task_id="t2",
            task_type="execute",
            context={"plan": "step1"},
            reason="execute plan",
        )
        channel = "agent:executor:handoff"

        pubsub = await self._subscribe(redis_client, channel)
        try:
            await manager.send_handoff(handoff)
            data = await self._next_payload(pubsub)
        finally:
            await self._close(pubsub, channel)

        assert data["target_agent"] == "executor"

    async def test_different_agents_different_channels(self, redis_client, clean_redis):
        """Each agent receives handoffs only on its own channel."""
        manager = HandoffManager(redis=redis_client)
        handoff_b = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t3",
            task_type="search",
            context={},
            reason="to b",
        )
        handoff_c = HandoffMessage(
            source_agent="a",
            target_agent="c",
            task_id="t4",
            task_type="search",
            context={},
            reason="to c",
        )
        channel_b = "agent:b:handoff"
        channel_c = "agent:c:handoff"

        pubsub_b = await self._subscribe(redis_client, channel_b)
        try:
            pubsub_c = await self._subscribe(redis_client, channel_c)
            try:
                await manager.send_handoff(handoff_b)
                await manager.send_handoff(handoff_c)
                data_b = await self._next_payload(pubsub_b)
                data_c = await self._next_payload(pubsub_c)
            finally:
                await self._close(pubsub_c, channel_c)
        finally:
            await self._close(pubsub_b, channel_b)

        assert data_b["target_agent"] == "b"
        assert data_c["target_agent"] == "c"

    async def test_listen_for_handoffs_receives_and_handles(self, redis_client, clean_redis):
        """listen_for_handoffs receives a published handoff and invokes the handler."""
        manager = HandoffManager(redis=redis_client)
        received = []
        handled = asyncio.Event()

        async def handler(msg):
            received.append(msg)
            handled.set()

        manager.register_handler("agent_b", handler)

        listen_task = asyncio.create_task(manager.listen_for_handoffs("agent_b"))
        try:
            # Give the listener time to subscribe; earlier publishes are dropped.
            await asyncio.sleep(0.5)

            await manager.send_handoff(
                HandoffMessage(
                    source_agent="agent_a",
                    target_agent="agent_b",
                    task_id="t5",
                    task_type="search",
                    context={"q": "test"},
                    reason="delegation",
                )
            )
            # Wait for the handler itself instead of a fixed sleep: faster when
            # things work, more tolerant of slow CI, and bounded when they don't.
            await asyncio.wait_for(handled.wait(), timeout=5.0)
        finally:
            # Always stop the listener, even if publishing or waiting failed.
            listen_task.cancel()
            try:
                await listen_task
            except asyncio.CancelledError:
                pass

        assert len(received) == 1
        assert received[0].task_id == "t5"
        assert received[0].source_agent == "agent_a"
        assert received[0].target_agent == "agent_b"
        assert received[0].context == {"q": "test"}
        assert received[0].reason == "delegation"

    async def test_handoff_message_contains_all_fields(self, redis_client, clean_redis):
        """The published payload carries source_agent, target_agent, context, reason, etc."""
        manager = HandoffManager(redis=redis_client)
        handoff = HandoffMessage(
            source_agent="researcher",
            target_agent="writer",
            task_id="t6",
            task_type="compose",
            context={"research": "findings", "style": "formal"},
            reason="needs writing expertise",
        )
        channel = "agent:writer:handoff"

        pubsub = await self._subscribe(redis_client, channel)
        try:
            await manager.send_handoff(handoff)
            data = await self._next_payload(pubsub)
        finally:
            await self._close(pubsub, channel)

        assert data["source_agent"] == "researcher"
        assert data["target_agent"] == "writer"
        assert data["context"] == {"research": "findings", "style": "formal"}
        assert data["reason"] == "needs writing expertise"
        assert data["task_id"] == "t6"
        assert data["task_type"] == "compose"
        assert "created_at" in data
|