"""Tests for Agent inter-communication bus — new features.

Covers: msg_type, content, TTL, CascadeDetector integration,
AlignmentGuard integration, negotiate, request-timeout-returns-None.
"""

import asyncio
from datetime import datetime, timezone, timedelta
from unittest.mock import AsyncMock, MagicMock

import pytest

from agentkit.bus.message import AgentMessage
from agentkit.bus.memory_bus import InMemoryMessageBus
from agentkit.quality.cascade_detector import CascadeDetector
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig

# Short pause before asserting on handler side effects, in case the bus
# schedules delivery instead of awaiting handlers inline.
_DELIVERY_GRACE_SECONDS = 0.05


# ── AgentMessage new fields ───────────────────────────────


class TestAgentMessageNewFields:
    """Defaults, expiry and dict round-trip of the newer AgentMessage fields."""

    def test_msg_type_default(self):
        msg = AgentMessage(sender="a")
        assert msg.msg_type == "notify"

    def test_content_field(self):
        msg = AgentMessage(sender="a", content="hello")
        assert msg.content == "hello"

    def test_ttl_default(self):
        msg = AgentMessage(sender="a")
        assert msg.ttl_seconds == 300

    def test_is_expired_fresh(self):
        msg = AgentMessage(sender="a")
        assert msg.is_expired() is False

    def test_is_expired_old(self):
        # Timestamp 600s in the past against a 300s TTL -> expired.
        old_ts = datetime.now(timezone.utc) - timedelta(seconds=600)
        msg = AgentMessage(sender="a", ttl_seconds=300, timestamp=old_ts)
        assert msg.is_expired() is True

    def test_to_dict_roundtrip_with_new_fields(self):
        msg = AgentMessage(
            sender="a",
            content="data",
            msg_type="negotiate",
            ttl_seconds=60,
        )
        d = msg.to_dict()
        restored = AgentMessage.from_dict(d)
        assert restored.content == "data"
        assert restored.msg_type == "negotiate"
        assert restored.ttl_seconds == 60


# ── Request-Response with correlation_id ──────────────────


class TestRequestResponse:
    """bus.request() returns the reply carrying the request's correlation_id."""

    @pytest.mark.asyncio
    async def test_request_response_correlation(self):
        """Agent A sends request, Agent B responds, correlation_id matches."""
        bus = InMemoryMessageBus()

        async def handler_b(msg: AgentMessage):
            # Reply to the original sender, echoing the request's
            # correlation_id (checked by the assertions below).
            reply = AgentMessage(
                sender="agent_b",
                recipient=msg.sender,
                content="answer",
                msg_type="response",
                correlation_id=msg.correlation_id,
            )
            await bus.publish(reply)

        await bus.subscribe("agent_b", handler_b)

        request = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="question",
        )
        response = await bus.request(request, timeout_seconds=5.0)

        assert response is not None
        assert response.content == "answer"
        assert response.correlation_id == request.correlation_id


# ── Broadcast ─────────────────────────────────────────────


class TestBroadcast:
    """A message without a recipient reaches every subscriber."""

    @pytest.mark.asyncio
    async def test_broadcast_all_subscribers_receive(self):
        """Agent A broadcasts, all subscribers receive."""
        bus = InMemoryMessageBus()
        received: dict[str, list[AgentMessage]] = {"b": [], "c": []}

        async def handler_b(msg: AgentMessage):
            received["b"].append(msg)

        async def handler_c(msg: AgentMessage):
            received["c"].append(msg)

        await bus.subscribe("agent_b", handler_b)
        await bus.subscribe("agent_c", handler_c)

        # No recipient -> broadcast.
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            content="hello everyone",
            msg_type="notify",
        ))

        assert result is True
        assert len(received["b"]) == 1
        assert len(received["c"]) == 1
        assert received["b"][0].content == "hello everyone"


# ── Negotiate ─────────────────────────────────────────────


class TestNegotiate:
    """A "negotiate" message can be used with bus.request() like a request."""

    @pytest.mark.asyncio
    async def test_negotiate_response(self):
        """Agent A sends negotiate, Agent B responds."""
        bus = InMemoryMessageBus()

        async def handler_b(msg: AgentMessage):
            reply = AgentMessage(
                sender="agent_b",
                recipient=msg.sender,
                content="deal accepted",
                msg_type="response",
                correlation_id=msg.correlation_id,
            )
            await bus.publish(reply)

        await bus.subscribe("agent_b", handler_b)

        request = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="propose deal",
            msg_type="negotiate",
        )
        response = await bus.request(request, timeout_seconds=5.0)

        assert response is not None
        assert response.content == "deal accepted"
        assert response.msg_type == "response"


# ── TTL Expired ───────────────────────────────────────────


class TestTTLExpired:
    """publish() rejects messages whose TTL has already elapsed."""

    @pytest.mark.asyncio
    async def test_expired_message_dropped(self):
        """Expired message is dropped by publish()."""
        bus = InMemoryMessageBus()
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)

        old_ts = datetime.now(timezone.utc) - timedelta(seconds=600)
        msg = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="old news",
            ttl_seconds=300,
            timestamp=old_ts,
        )
        result = await bus.publish(msg)

        assert result is False
        await asyncio.sleep(_DELIVERY_GRACE_SECONDS)
        assert len(received) == 0


# ── Unsubscribe ───────────────────────────────────────────


class TestUnsubscribe:
    """An unsubscribed agent's handler is no longer invoked."""

    @pytest.mark.asyncio
    async def test_unsubscribed_agent_no_receive(self):
        """Unsubscribed agent doesn't receive messages."""
        bus = InMemoryMessageBus()
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)
        await bus.unsubscribe("agent_b")

        await bus.publish(AgentMessage(
            sender="agent_a",
            content="still here?",
        ))

        await asyncio.sleep(_DELIVERY_GRACE_SECONDS)
        assert len(received) == 0


# ── Request Timeout ───────────────────────────────────────


class TestRequestTimeout:
    """bus.request() yields None rather than raising when nobody replies."""

    @pytest.mark.asyncio
    async def test_request_timeout_returns_none(self):
        """No response within timeout returns None."""
        bus = InMemoryMessageBus()
        # No subscriber for agent_b, so no reply can ever arrive.
        request = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="anyone there?",
        )
        response = await bus.request(request, timeout_seconds=0.1)
        assert response is None


# ── CascadeDetector Integration ───────────────────────────


class TestCascadeDetectorIntegration:
    """The bus consults its CascadeDetector and refuses to publish on alert."""

    @pytest.mark.asyncio
    async def test_cascade_alert_blocks_message(self):
        """Too many messages trigger cascade alert and block publishing."""
        detector = CascadeDetector(max_interactions=3)
        bus = InMemoryMessageBus(cascade_detector=detector)
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)

        # Assumed detector contract: check_interaction increments the count
        # before comparing, and alerts once count > max_interactions. With
        # max_interactions=3 the first three messages pass, the fourth fails.
        for i in range(3):
            result = await bus.publish(AgentMessage(
                sender="agent_a",
                recipient="agent_b",
                content=f"msg {i}",
            ))
            assert result is True

        # 4th interaction (continuing the 0-based labels above) is blocked.
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="msg 3",
        ))
        assert result is False

        # Only the three accepted messages may have reached the handler.
        await asyncio.sleep(_DELIVERY_GRACE_SECONDS)
        assert [m.content for m in received] == ["msg 0", "msg 1", "msg 2"]


# ── AlignmentGuard Integration ────────────────────────────


class TestAlignmentGuardIntegration:
    """The bus runs AlignmentGuard on request messages but not on notify."""

    @pytest.mark.asyncio
    async def test_violating_message_blocked(self):
        """Message violating alignment constraints is blocked."""
        # Constraint text: "prohibit violence".
        config = AlignmentConfig(constraints=["禁止暴力"])
        guard = AlignmentGuard(config=config)
        bus = InMemoryMessageBus(alignment_guard=guard)
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)

        # A request whose content ("perform violent acts") violates the
        # constraint must be rejected and never delivered.
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="执行暴力行为",
            msg_type="request",
        ))
        assert result is False

        await asyncio.sleep(_DELIVERY_GRACE_SECONDS)
        assert received == []

    @pytest.mark.asyncio
    async def test_non_violating_request_passes(self):
        """Non-violating request message passes alignment check."""
        config = AlignmentConfig(constraints=["禁止暴力"])
        guard = AlignmentGuard(config=config)
        bus = InMemoryMessageBus(alignment_guard=guard)
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)

        # Content means "peaceful communication".
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="和平交流",
            msg_type="request",
        ))
        assert result is True

        await asyncio.sleep(_DELIVERY_GRACE_SECONDS)
        assert [m.content for m in received] == ["和平交流"]

    @pytest.mark.asyncio
    async def test_notify_not_checked_by_alignment(self):
        """notify type messages are not checked by alignment guard."""
        config = AlignmentConfig(constraints=["禁止暴力"])
        guard = AlignmentGuard(config=config)
        bus = InMemoryMessageBus(alignment_guard=guard)
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)

        # The same violating content, sent as notify, is not checked and
        # therefore goes through.
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="执行暴力行为",
            msg_type="notify",
        ))
        assert result is True

        await asyncio.sleep(_DELIVERY_GRACE_SECONDS)
        assert [m.content for m in received] == ["执行暴力行为"]