fischer-agentkit/tests/unit/test_handoff.py

517 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""HandoffManager 单元测试"""
import asyncio
import json
import pytest
from agentkit.core.protocol import HandoffMessage
from agentkit.orchestrator.handoff import HandoffManager
# ── HandoffMessage 创建与序列化测试 ─────────────────────────────
class TestHandoffMessage:
    """Construction and (de)serialization tests for HandoffMessage."""

    def test_creation_with_required_fields(self):
        msg = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="task-001",
            task_type="analysis",
            context={"key": "value"},
            reason="needs expertise",
        )
        assert msg.source_agent == "agent_a"
        assert msg.target_agent == "agent_b"
        assert msg.task_id == "task-001"
        assert msg.task_type == "analysis"
        assert msg.context == {"key": "value"}
        assert msg.reason == "needs expertise"
        # created_at is not passed in, so the message must populate it itself.
        assert msg.created_at is not None

    def test_to_dict_roundtrip(self):
        msg = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="task-001",
            task_type="analysis",
            context={"data": [1, 2, 3]},
            reason="specialization",
        )
        restored = HandoffMessage.from_dict(msg.to_dict())
        assert restored.source_agent == msg.source_agent
        assert restored.target_agent == msg.target_agent
        assert restored.task_id == msg.task_id
        assert restored.task_type == msg.task_type
        assert restored.context == msg.context
        assert restored.reason == msg.reason

    def test_to_dict_contains_all_fields(self):
        msg = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t1",
            task_type="search",
            context={"q": "test"},
            reason="handoff",
        )
        d = msg.to_dict()
        expected = {
            "source_agent",
            "target_agent",
            "task_id",
            "task_type",
            "context",
            "reason",
            "created_at",
        }
        # Report every missing key at once rather than stopping at the first.
        missing = expected - set(d)
        assert not missing, f"to_dict() is missing keys: {sorted(missing)}"

    def test_from_dict_defaults_context(self):
        # "context" is deliberately absent; from_dict must fall back to an empty dict.
        data = {
            "source_agent": "a",
            "target_agent": "b",
            "task_id": "t1",
            "task_type": "search",
            "reason": "test",
        }
        msg = HandoffMessage.from_dict(data)
        assert msg.context == {}

    def test_from_dict_parses_created_at_string(self):
        data = {
            "source_agent": "a",
            "target_agent": "b",
            "task_id": "t1",
            "task_type": "search",
            "context": {},
            "reason": "test",
            "created_at": "2025-01-15T10:30:00+00:00",
        }
        msg = HandoffMessage.from_dict(data)
        created = msg.created_at
        # The ISO-8601 string must be parsed into a datetime, time of day included.
        assert created.year == 2025
        assert created.month == 1
        assert created.day == 15
        assert created.hour == 10
        assert created.minute == 30

    def test_json_serializable(self):
        msg = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="task-001",
            task_type="analysis",
            context={"key": "value"},
            reason="needs expertise",
        )
        # json.dumps fails outright if to_dict() returns non-JSON types.
        serialized = json.dumps(msg.to_dict())
        restored = HandoffMessage.from_dict(json.loads(serialized))
        # Every user-supplied field, including the nested context dict, must
        # survive a real JSON encode/decode cycle.
        assert restored.source_agent == msg.source_agent
        assert restored.target_agent == msg.target_agent
        assert restored.task_id == msg.task_id
        assert restored.task_type == msg.task_type
        assert restored.context == msg.context
        assert restored.reason == msg.reason
# ── HandoffManager 无 Redis(本地模式)测试 ──────────────────────
class TestHandoffManagerLocalMode:
    """HandoffManager tests that run without a Redis connection (local mode)."""

    def test_construction_without_redis(self):
        manager = HandoffManager()
        assert manager._redis is None
        assert manager._handlers == {}

    def test_construction_with_dispatcher(self):
        manager = HandoffManager(dispatcher="mock_dispatcher")
        assert manager._dispatcher == "mock_dispatcher"

    async def test_send_handoff_without_redis_raises(self):
        manager = HandoffManager()
        handoff = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t1",
            task_type="search",
            context={},
            reason="test",
        )
        with pytest.raises(RuntimeError, match="Redis connection"):
            await manager.send_handoff(handoff)

    async def test_listen_for_handoffs_without_redis_returns(self):
        manager = HandoffManager()
        # Without Redis the listener should return right away without raising.
        # The timeout turns a regression that blocks forever into a test
        # failure instead of hanging the whole test run.
        await asyncio.wait_for(manager.listen_for_handoffs("agent_a"), timeout=1.0)

    def test_register_handler(self):
        manager = HandoffManager()

        async def handler(msg):
            pass

        manager.register_handler("agent_a", handler)
        assert "agent_a" in manager._handlers
        assert handler in manager._handlers["agent_a"]

    def test_register_multiple_handlers_for_same_agent(self):
        manager = HandoffManager()

        async def handler1(msg):
            pass

        async def handler2(msg):
            pass

        manager.register_handler("agent_a", handler1)
        manager.register_handler("agent_a", handler2)
        # The second registration must add to, not replace, the first.
        assert len(manager._handlers["agent_a"]) == 2
        assert handler1 in manager._handlers["agent_a"]
        assert handler2 in manager._handlers["agent_a"]

    def test_register_handlers_for_different_agents(self):
        manager = HandoffManager()

        async def handler_a(msg):
            pass

        async def handler_b(msg):
            pass

        manager.register_handler("agent_a", handler_a)
        manager.register_handler("agent_b", handler_b)
        assert "agent_a" in manager._handlers
        assert "agent_b" in manager._handlers
        assert len(manager._handlers) == 2
        # Each handler is filed under its own agent only.
        assert handler_a in manager._handlers["agent_a"]
        assert handler_b in manager._handlers["agent_b"]
        assert handler_a not in manager._handlers["agent_b"]
        assert handler_b not in manager._handlers["agent_a"]
# ── HandoffManager _handle_handoff 测试 ─────────────────────────
class TestHandoffManagerHandleHandoff:
    """Direct tests of HandoffManager._handle_handoff (no Redis involved)."""

    @staticmethod
    def _handoff_to_agent_b(context=None, reason="test"):
        """Build the agent_a -> agent_b "search" handoff (task t1) these tests use."""
        return HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t1",
            task_type="search",
            context={} if context is None else context,
            reason=reason,
        )

    async def test_handle_handoff_calls_registered_handlers(self):
        manager = HandoffManager()
        seen = []

        async def record(msg):
            seen.append(msg)

        manager.register_handler("agent_b", record)
        await manager._handle_handoff(
            self._handoff_to_agent_b(context={"q": "test"}, reason="delegation")
        )

        assert len(seen) == 1
        delivered = seen[0]
        assert delivered.task_id == "t1"
        assert delivered.source_agent == "agent_a"

    async def test_handle_handoff_no_handler_does_nothing(self):
        # Nothing is registered for agent_b; dispatch must complete without raising.
        await HandoffManager()._handle_handoff(self._handoff_to_agent_b())

    async def test_handle_handoff_handler_error_is_caught(self):
        manager = HandoffManager()

        async def failing(msg):
            raise ValueError("handler error")

        manager.register_handler("agent_b", failing)
        # The handler's ValueError must not propagate out of _handle_handoff.
        await manager._handle_handoff(self._handoff_to_agent_b())

    async def test_handle_handoff_multiple_handlers(self):
        manager = HandoffManager()
        calls = []

        async def first(msg):
            calls.append("handler1")

        async def second(msg):
            calls.append("handler2")

        for registered in (first, second):
            manager.register_handler("agent_b", registered)
        await manager._handle_handoff(self._handoff_to_agent_b())

        # Each handler ran exactly once; invocation order is not asserted.
        assert sorted(calls) == ["handler1", "handler2"]
# ── HandoffManager Redis Pub/Sub 测试 ───────────────────────────
def _redis_available():
"""检查 Redis 是否可用"""
import os
import redis
url = os.environ.get("REDIS_URL", "redis://localhost:6381/0")
try:
r = redis.from_url(url)
r.ping()
r.close()
return True
except Exception:
return False
redis_available = _redis_available()
async def _next_message(pubsub, msg_type="message", timeout=3.0):
    """Return the next pub/sub message whose ``type`` equals *msg_type*, or None.

    Messages of other types (e.g. subscribe confirmations while waiting for a
    "message") are skipped. The total wait is capped at *timeout* seconds, so a
    message that never arrives fails the calling test instead of polling forever.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while True:
        remaining = deadline - loop.time()
        if remaining <= 0:
            return None
        # The outer wait_for is a safety net in case get_message overruns its own timeout.
        msg = await asyncio.wait_for(
            pubsub.get_message(timeout=remaining), timeout=remaining + 1.0
        )
        if msg and msg.get("type") == msg_type:
            return msg


async def _subscribe(pubsub, channel, timeout=3.0):
    """Subscribe *pubsub* to *channel* and wait for the server's confirmation.

    Redis pub/sub does not buffer: a PUBLISH that races ahead of the SUBSCRIBE
    is lost, so tests must not publish until the confirmation has been read.
    """
    await pubsub.subscribe(channel)
    confirmation = await _next_message(pubsub, "subscribe", timeout)
    assert confirmation is not None, f"subscription to {channel} was not confirmed"


async def _capture_handoff(redis_client, manager, handoff, channel, timeout=3.0):
    """Send *handoff* via *manager* while subscribed to *channel*; return the decoded JSON payload.

    Fails the calling test if nothing is published to *channel* within
    *timeout* seconds. The ``async with`` block releases the pubsub connection
    even when an assertion fails.
    """
    async with redis_client.pubsub() as pubsub:
        await _subscribe(pubsub, channel, timeout)
        await manager.send_handoff(handoff)
        msg = await _next_message(pubsub, "message", timeout)
    assert msg is not None, f"no handoff was published to {channel}"
    return json.loads(msg["data"])


async def _wait_for_subscriber(redis_client, channel, timeout=3.0):
    """Poll ``PUBSUB NUMSUB`` until *channel* has at least one subscriber.

    Returns True once a subscriber is seen, False if *timeout* seconds elapse first.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        counts = await redis_client.pubsub_numsub(channel)
        if counts and counts[0][1] > 0:
            return True
        await asyncio.sleep(0.05)
    return False


@pytest.mark.redis
@pytest.mark.skipif(not redis_available, reason="Redis not available")
class TestHandoffManagerRedisMode:
    """HandoffManager tests over Redis pub/sub; skipped when Redis is unreachable."""

    async def test_send_handoff_publishes_to_channel(self, redis_client, clean_redis):
        """send_handoff publishes the serialized message on the target agent's channel."""
        manager = HandoffManager(redis=redis_client)
        handoff = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t1",
            task_type="search",
            context={"q": "hello"},
            reason="delegation",
        )
        # Publishing while nobody is subscribed must not raise; Redis drops that
        # message, so only the send below (made while subscribed) is observed.
        await manager.send_handoff(handoff)

        data = await _capture_handoff(
            redis_client, manager, handoff, "agent:agent_b:handoff"
        )
        assert data["source_agent"] == "agent_a"
        assert data["target_agent"] == "agent_b"
        assert data["task_id"] == "t1"
        assert data["reason"] == "delegation"

    async def test_send_handoff_channel_format(self, redis_client, clean_redis):
        """Handoffs are published on the agent:{target_agent}:handoff channel."""
        manager = HandoffManager(redis=redis_client)
        handoff = HandoffMessage(
            source_agent="planner",
            target_agent="executor",
            task_id="t2",
            task_type="execute",
            context={"plan": "step1"},
            reason="execute plan",
        )
        # The channel name is spelled out literally so the format itself is checked.
        data = await _capture_handoff(
            redis_client, manager, handoff, "agent:executor:handoff"
        )
        assert data["target_agent"] == "executor"

    async def test_different_agents_different_channels(self, redis_client, clean_redis):
        """Each agent's channel carries only the handoffs addressed to that agent."""
        manager = HandoffManager(redis=redis_client)
        handoff_b = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t3",
            task_type="search",
            context={},
            reason="to b",
        )
        handoff_c = HandoffMessage(
            source_agent="a",
            target_agent="c",
            task_id="t4",
            task_type="search",
            context={},
            reason="to c",
        )

        async with redis_client.pubsub() as pubsub_b, redis_client.pubsub() as pubsub_c:
            await _subscribe(pubsub_b, "agent:b:handoff")
            await _subscribe(pubsub_c, "agent:c:handoff")

            await manager.send_handoff(handoff_b)
            await manager.send_handoff(handoff_c)

            msg_b = await _next_message(pubsub_b)
            msg_c = await _next_message(pubsub_c)
            # Nothing else may arrive: a handoff must not leak onto another agent's channel.
            extra_b = await _next_message(pubsub_b, timeout=0.2)
            extra_c = await _next_message(pubsub_c, timeout=0.2)

        assert msg_b is not None, "agent b received no handoff"
        assert json.loads(msg_b["data"])["target_agent"] == "b"
        assert msg_c is not None, "agent c received no handoff"
        assert json.loads(msg_c["data"])["target_agent"] == "c"
        assert extra_b is None, "unexpected extra message on agent:b:handoff"
        assert extra_c is None, "unexpected extra message on agent:c:handoff"

    async def test_listen_for_handoffs_receives_and_handles(self, redis_client, clean_redis):
        """listen_for_handoffs receives a published handoff and invokes the handler."""
        manager = HandoffManager(redis=redis_client)
        received = []

        async def handler(msg):
            received.append(msg)

        manager.register_handler("agent_b", handler)
        listen_task = asyncio.create_task(manager.listen_for_handoffs("agent_b"))
        try:
            # Bounded, best-effort wait for the listener's SUBSCRIBE to reach the
            # server; a handoff published before that would be dropped.
            await _wait_for_subscriber(redis_client, "agent:agent_b:handoff")

            handoff = HandoffMessage(
                source_agent="agent_a",
                target_agent="agent_b",
                task_id="t5",
                task_type="search",
                context={"q": "test"},
                reason="delegation",
            )
            await manager.send_handoff(handoff)

            # Poll rather than sleep a fixed amount; stop early if the listener
            # task has already finished (e.g. it crashed).
            loop = asyncio.get_running_loop()
            deadline = loop.time() + 3.0
            while not received and not listen_task.done() and loop.time() < deadline:
                await asyncio.sleep(0.05)
            # Brief grace period so a duplicate delivery is still counted below.
            await asyncio.sleep(0.1)
        finally:
            # Always stop the listener, even if sending or waiting failed.
            listen_task.cancel()
            try:
                await listen_task
            except asyncio.CancelledError:
                pass

        assert len(received) == 1
        assert received[0].task_id == "t5"
        assert received[0].source_agent == "agent_a"
        assert received[0].target_agent == "agent_b"
        assert received[0].context == {"q": "test"}
        assert received[0].reason == "delegation"

    async def test_handoff_message_contains_all_fields(self, redis_client, clean_redis):
        """The published payload carries source_agent, target_agent, context, reason and the rest."""
        manager = HandoffManager(redis=redis_client)
        handoff = HandoffMessage(
            source_agent="researcher",
            target_agent="writer",
            task_id="t6",
            task_type="compose",
            context={"research": "findings", "style": "formal"},
            reason="needs writing expertise",
        )
        data = await _capture_handoff(
            redis_client, manager, handoff, "agent:writer:handoff"
        )
        assert data["source_agent"] == "researcher"
        assert data["target_agent"] == "writer"
        assert data["context"] == {"research": "findings", "style": "formal"}
        assert data["reason"] == "needs writing expertise"
        assert data["task_id"] == "t6"
        assert data["task_type"] == "compose"
        assert "created_at" in data