"""Unit tests for HandoffManager."""

import asyncio
import contextlib
import json
import os

import pytest

from agentkit.core.protocol import HandoffMessage
from agentkit.orchestrator.handoff import HandoffManager


# ── HandoffMessage creation and serialization ────────────────────


class TestHandoffMessage:
    """HandoffMessage creation and (de)serialization tests."""

    def test_creation_with_required_fields(self):
        msg = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="task-001",
            task_type="analysis",
            context={"key": "value"},
            reason="needs expertise",
        )
        assert msg.source_agent == "agent_a"
        assert msg.target_agent == "agent_b"
        assert msg.task_id == "task-001"
        assert msg.task_type == "analysis"
        assert msg.context == {"key": "value"}
        assert msg.reason == "needs expertise"
        assert msg.created_at is not None

    def test_to_dict_roundtrip(self):
        msg = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="task-001",
            task_type="analysis",
            context={"data": [1, 2, 3]},
            reason="specialization",
        )
        d = msg.to_dict()
        restored = HandoffMessage.from_dict(d)
        assert restored.source_agent == msg.source_agent
        assert restored.target_agent == msg.target_agent
        assert restored.task_id == msg.task_id
        assert restored.task_type == msg.task_type
        assert restored.context == msg.context
        assert restored.reason == msg.reason

    def test_to_dict_contains_all_fields(self):
        msg = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t1",
            task_type="search",
            context={"q": "test"},
            reason="handoff",
        )
        d = msg.to_dict()
        assert "source_agent" in d
        assert "target_agent" in d
        assert "task_id" in d
        assert "task_type" in d
        assert "context" in d
        assert "reason" in d
        assert "created_at" in d

    def test_from_dict_defaults_context(self):
        data = {
            "source_agent": "a",
            "target_agent": "b",
            "task_id": "t1",
            "task_type": "search",
            "reason": "test",
        }
        msg = HandoffMessage.from_dict(data)
        assert msg.context == {}

    def test_from_dict_parses_created_at_string(self):
        data = {
            "source_agent": "a",
            "target_agent": "b",
            "task_id": "t1",
            "task_type": "search",
            "context": {},
            "reason": "test",
            "created_at": "2025-01-15T10:30:00+00:00",
        }
        msg = HandoffMessage.from_dict(data)
        assert msg.created_at.year == 2025
        assert msg.created_at.month == 1
        assert msg.created_at.day == 15

    def test_json_serializable(self):
        """to_dict output survives a real JSON encode/decode round trip."""
        msg = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="task-001",
            task_type="analysis",
            context={"key": "value"},
            reason="needs expertise",
        )
        serialized = json.dumps(msg.to_dict())
        deserialized = json.loads(serialized)
        restored = HandoffMessage.from_dict(deserialized)
        assert restored.source_agent == msg.source_agent
        assert restored.target_agent == msg.target_agent
        assert restored.task_id == msg.task_id
        assert restored.task_type == msg.task_type
        assert restored.context == msg.context
        assert restored.reason == msg.reason


# ── HandoffManager without Redis (local mode) ────────────────────


class TestHandoffManagerLocalMode:
    """HandoffManager tests that need no Redis connection."""

    def test_construction_without_redis(self):
        manager = HandoffManager()
        assert manager._redis is None
        assert manager._handlers == {}

    def test_construction_with_dispatcher(self):
        manager = HandoffManager(dispatcher="mock_dispatcher")
        assert manager._dispatcher == "mock_dispatcher"

    async def test_send_handoff_without_redis_raises(self):
        manager = HandoffManager()
        handoff = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t1",
            task_type="search",
            context={},
            reason="test",
        )
        with pytest.raises(RuntimeError, match="Redis connection"):
            await manager.send_handoff(handoff)

    async def test_listen_for_handoffs_without_redis_returns(self):
        """Without Redis, listen_for_handoffs must return promptly instead of raising."""
        manager = HandoffManager()
        # Bounded so a regression that blocks forever fails the test instead of
        # hanging the whole suite.
        await asyncio.wait_for(manager.listen_for_handoffs("agent_a"), timeout=2.0)

    def test_register_handler(self):
        manager = HandoffManager()

        async def handler(msg):
            pass

        manager.register_handler("agent_a", handler)
        assert "agent_a" in manager._handlers
        assert handler in manager._handlers["agent_a"]

    def test_register_multiple_handlers_for_same_agent(self):
        manager = HandoffManager()

        async def handler1(msg):
            pass

        async def handler2(msg):
            pass

        manager.register_handler("agent_a", handler1)
        manager.register_handler("agent_a", handler2)
        assert len(manager._handlers["agent_a"]) == 2

    def test_register_handlers_for_different_agents(self):
        manager = HandoffManager()

        async def handler_a(msg):
            pass

        async def handler_b(msg):
            pass

        manager.register_handler("agent_a", handler_a)
        manager.register_handler("agent_b", handler_b)
        assert "agent_a" in manager._handlers
        assert "agent_b" in manager._handlers
        assert len(manager._handlers) == 2


# ── HandoffManager._handle_handoff ───────────────────────────────


class TestHandoffManagerHandleHandoff:
    """Tests for the internal HandoffManager._handle_handoff dispatch."""

    async def test_handle_handoff_calls_registered_handlers(self):
        manager = HandoffManager()
        received = []

        async def handler(msg):
            received.append(msg)

        manager.register_handler("agent_b", handler)
        handoff = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t1",
            task_type="search",
            context={"q": "test"},
            reason="delegation",
        )
        await manager._handle_handoff(handoff)
        assert len(received) == 1
        assert received[0].task_id == "t1"
        assert received[0].source_agent == "agent_a"

    async def test_handle_handoff_no_handler_does_nothing(self):
        manager = HandoffManager()
        handoff = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t1",
            task_type="search",
            context={},
            reason="test",
        )
        # Must not raise when no handler is registered for the target agent.
        await manager._handle_handoff(handoff)

    async def test_handle_handoff_handler_error_is_caught(self):
        manager = HandoffManager()

        async def bad_handler(msg):
            raise ValueError("handler error")

        manager.register_handler("agent_b", bad_handler)
        handoff = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t1",
            task_type="search",
            context={},
            reason="test",
        )
        # The handler's exception must not propagate out of _handle_handoff.
        await manager._handle_handoff(handoff)

    async def test_handle_handoff_multiple_handlers(self):
        manager = HandoffManager()
        results = []

        async def handler1(msg):
            results.append("handler1")

        async def handler2(msg):
            results.append("handler2")

        manager.register_handler("agent_b", handler1)
        manager.register_handler("agent_b", handler2)
        handoff = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t1",
            task_type="search",
            context={},
            reason="test",
        )
        await manager._handle_handoff(handoff)
        assert len(results) == 2
        assert "handler1" in results
        assert "handler2" in results


# ── HandoffManager Redis Pub/Sub ─────────────────────────────────


def _redis_available():
    """Return True if a Redis server answers PING at $REDIS_URL.

    Defaults to ``redis://localhost:6381/0``. Returns False, rather than
    breaking collection of this whole module, when the ``redis`` package is
    not installed, the URL is invalid, or the server cannot be reached.
    """
    try:
        import redis
    except ImportError:
        return False

    url = os.environ.get("REDIS_URL", "redis://localhost:6381/0")
    client = None
    try:
        # Short timeouts so an unreachable host cannot stall test collection.
        client = redis.from_url(url, socket_connect_timeout=1, socket_timeout=1)
        client.ping()
        return True
    except Exception:  # deliberate best-effort probe: any failure means "unavailable"
        return False
    finally:
        if client is not None:
            client.close()


# Probed once at import time so the skip marker below can use it.
redis_available = _redis_available()

# Upper bound for a single pubsub poll; overall waits are bounded separately.
_POLL_INTERVAL = 0.5


async def _wait_for_message(pubsub, msg_type, timeout=5.0):
    """Poll *pubsub* until a message whose ``type`` is *msg_type* arrives; return it.

    Messages of other types and empty polls are skipped. Fails the test,
    instead of looping forever, if nothing matching arrives within *timeout*
    seconds.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        # The outer wait_for is a safety net in case get_message overruns its own timeout.
        msg = await asyncio.wait_for(
            pubsub.get_message(timeout=_POLL_INTERVAL),
            timeout=_POLL_INTERVAL + 1.0,
        )
        if msg and msg.get("type") == msg_type:
            return msg
    pytest.fail(f"no {msg_type!r} pubsub message received within {timeout}s")


async def _next_published(pubsub, timeout=5.0):
    """Return the JSON-decoded payload of the next published message on *pubsub*."""
    msg = await _wait_for_message(pubsub, "message", timeout)
    return json.loads(msg["data"])


@contextlib.asynccontextmanager
async def _subscribed(redis_client, channel, timeout=5.0):
    """Yield a pubsub subscribed to *channel*, after the server confirms the subscription.

    Waiting for the "subscribe" confirmation ensures the server has registered
    the subscription before the test publishes. On exit the channel is
    unsubscribed and the pubsub closed, even if the test body failed.
    """
    pubsub = redis_client.pubsub()
    await pubsub.subscribe(channel)
    try:
        await _wait_for_message(pubsub, "subscribe", timeout)
        yield pubsub
    finally:
        await pubsub.unsubscribe(channel)
        # aclose() exists on redis-py >= 5.0.1; older asyncio clients only offer reset().
        close = getattr(pubsub, "aclose", None) or pubsub.reset
        await close()


async def _wait_for_subscriber(redis_client, channel, listen_task, timeout=5.0):
    """Best-effort wait until some client is subscribed to *channel*.

    Polls PUBSUB NUMSUB so the test only publishes once the listener is ready.
    If no subscriber is seen before *timeout*, this simply returns. The caller's
    delivery assertion then reports the failure. Fails immediately if
    *listen_task* finishes before a subscriber appears.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        if listen_task.done():
            pytest.fail("listen_for_handoffs exited before subscribing")
        # NUMSUB is parsed as (channel, count) pairs or a mapping; dict() accepts both.
        counts = dict(await redis_client.pubsub_numsub(channel))
        if sum(counts.values()) > 0:
            return
        await asyncio.sleep(0.05)


@pytest.mark.redis
@pytest.mark.skipif(not redis_available, reason="Redis not available")
class TestHandoffManagerRedisMode:
    """HandoffManager Redis Pub/Sub tests (require a running Redis)."""

    async def test_send_handoff_publishes_to_channel(self, redis_client, clean_redis):
        """send_handoff publishes the serialized message on the target agent's channel."""
        manager = HandoffManager(redis=redis_client)
        handoff = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t1",
            task_type="search",
            context={"q": "hello"},
            reason="delegation",
        )
        # Publishing while nobody is subscribed must not raise.
        await manager.send_handoff(handoff)

        async with _subscribed(redis_client, "agent:agent_b:handoff") as pubsub:
            await manager.send_handoff(handoff)
            data = await _next_published(pubsub)

        assert data["source_agent"] == "agent_a"
        assert data["target_agent"] == "agent_b"
        assert data["task_id"] == "t1"
        assert data["reason"] == "delegation"

    async def test_send_handoff_channel_format(self, redis_client, clean_redis):
        """Handoffs are published on the ``agent:{target_agent}:handoff`` channel."""
        manager = HandoffManager(redis=redis_client)
        handoff = HandoffMessage(
            source_agent="planner",
            target_agent="executor",
            task_id="t2",
            task_type="execute",
            context={"plan": "step1"},
            reason="execute plan",
        )

        async with _subscribed(redis_client, "agent:executor:handoff") as pubsub:
            await manager.send_handoff(handoff)
            data = await _next_published(pubsub)

        assert data["target_agent"] == "executor"

    async def test_different_agents_different_channels(self, redis_client, clean_redis):
        """Each agent's channel receives only the handoff addressed to that agent."""
        manager = HandoffManager(redis=redis_client)
        handoff_b = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t3",
            task_type="search",
            context={},
            reason="to b",
        )
        handoff_c = HandoffMessage(
            source_agent="a",
            target_agent="c",
            task_id="t4",
            task_type="search",
            context={},
            reason="to c",
        )

        async with _subscribed(redis_client, "agent:b:handoff") as pubsub_b:
            async with _subscribed(redis_client, "agent:c:handoff") as pubsub_c:
                await manager.send_handoff(handoff_b)
                await manager.send_handoff(handoff_c)
                data_b = await _next_published(pubsub_b)
                data_c = await _next_published(pubsub_c)

        assert data_b["target_agent"] == "b"
        assert data_c["target_agent"] == "c"

    async def test_listen_for_handoffs_receives_and_handles(self, redis_client, clean_redis):
        """listen_for_handoffs receives a published handoff and invokes the registered handler."""
        manager = HandoffManager(redis=redis_client)
        received = []
        handled = asyncio.Event()

        async def handler(msg):
            received.append(msg)
            handled.set()

        manager.register_handler("agent_b", handler)

        handoff = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t5",
            task_type="search",
            context={"q": "test"},
            reason="delegation",
        )

        listen_task = asyncio.create_task(manager.listen_for_handoffs("agent_b"))
        try:
            # Redis pub/sub does not buffer for late subscribers, so publish only
            # once the listener's subscription is visible.
            await _wait_for_subscriber(redis_client, "agent:agent_b:handoff", listen_task)
            await manager.send_handoff(handoff)
            # Event-based wait instead of a fixed sleep: fast on success, bounded on failure.
            await asyncio.wait_for(handled.wait(), timeout=5.0)
        finally:
            listen_task.cancel()
            try:
                await listen_task
            except asyncio.CancelledError:
                pass

        assert len(received) == 1
        assert received[0].task_id == "t5"
        assert received[0].source_agent == "agent_a"
        assert received[0].target_agent == "agent_b"
        assert received[0].context == {"q": "test"}
        assert received[0].reason == "delegation"

    async def test_handoff_message_contains_all_fields(self, redis_client, clean_redis):
        """The published payload carries every HandoffMessage field."""
        manager = HandoffManager(redis=redis_client)
        handoff = HandoffMessage(
            source_agent="researcher",
            target_agent="writer",
            task_id="t6",
            task_type="compose",
            context={"research": "findings", "style": "formal"},
            reason="needs writing expertise",
        )

        async with _subscribed(redis_client, "agent:writer:handoff") as pubsub:
            await manager.send_handoff(handoff)
            data = await _next_published(pubsub)

        assert data["source_agent"] == "researcher"
        assert data["target_agent"] == "writer"
        assert data["context"] == {"research": "findings", "style": "formal"}
        assert data["reason"] == "needs writing expertise"
        assert data["task_id"] == "t6"
        assert data["task_type"] == "compose"
        assert "created_at" in data