# fischer-agentkit/tests/unit/test_agent_bus.py

# 332 lines
# 11 KiB
# Python

"""Tests for Agent inter-communication bus — new features.
Covers: msg_type, content, TTL, CascadeDetector integration,
AlignmentGuard integration, negotiate, request-timeout-returns-None.
"""
import asyncio
from datetime import datetime, timezone, timedelta
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.bus.message import AgentMessage
from agentkit.bus.memory_bus import InMemoryMessageBus
from agentkit.quality.cascade_detector import CascadeDetector
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig
# ── AgentMessage new fields ───────────────────────────────
class TestAgentMessageNewFields:
    """Defaults, TTL expiry and dict round-tripping of the newer AgentMessage fields."""

    def test_msg_type_default(self):
        """A message created without msg_type is a "notify" message."""
        assert AgentMessage(sender="a").msg_type == "notify"

    def test_content_field(self):
        """The content keyword is stored as given."""
        message = AgentMessage(sender="a", content="hello")
        assert message.content == "hello"

    def test_ttl_default(self):
        """Without an explicit TTL the message carries ttl_seconds == 300."""
        assert AgentMessage(sender="a").ttl_seconds == 300

    def test_is_expired_fresh(self):
        """A message created just now is not expired."""
        assert AgentMessage(sender="a").is_expired() is False

    def test_is_expired_old(self):
        """A message stamped 600s in the past with a 300s TTL is expired."""
        stale_timestamp = datetime.now(timezone.utc) - timedelta(seconds=600)
        message = AgentMessage(
            sender="a",
            ttl_seconds=300,
            timestamp=stale_timestamp,
        )
        assert message.is_expired() is True

    def test_to_dict_roundtrip_with_new_fields(self):
        """content, msg_type and ttl_seconds survive to_dict() -> from_dict()."""
        original = AgentMessage(
            sender="a",
            content="data",
            msg_type="negotiate",
            ttl_seconds=60,
        )
        restored = AgentMessage.from_dict(original.to_dict())
        assert restored.content == "data"
        assert restored.msg_type == "negotiate"
        assert restored.ttl_seconds == 60
# ── Request-Response with correlation_id ──────────────────
class TestRequestResponse:
    """bus.request() returns the reply that carries the request's correlation_id."""

    @pytest.mark.asyncio
    async def test_request_response_correlation(self):
        """Agent A sends request, Agent B responds, correlation_id matches."""
        bus = InMemoryMessageBus()

        async def answer_as_b(incoming: AgentMessage):
            # Reply to whoever asked, copying the incoming correlation_id onto
            # the reply so it can be matched to the original request.
            await bus.publish(
                AgentMessage(
                    sender="agent_b",
                    recipient=incoming.sender,
                    content="answer",
                    msg_type="response",
                    correlation_id=incoming.correlation_id,
                )
            )

        await bus.subscribe("agent_b", answer_as_b)

        question = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="question",
        )
        reply = await bus.request(question, timeout_seconds=5.0)

        assert reply is not None
        assert reply.content == "answer"
        assert reply.correlation_id == question.correlation_id
# ── Broadcast ─────────────────────────────────────────────
class TestBroadcast:
    """A message published without a recipient reaches every subscriber."""

    @pytest.mark.asyncio
    async def test_broadcast_all_subscribers_receive(self):
        """Agent A broadcasts, all subscribers receive."""
        bus = InMemoryMessageBus()
        received: dict[str, list[AgentMessage]] = {"b": [], "c": []}

        async def handler_b(msg: AgentMessage):
            received["b"].append(msg)

        async def handler_c(msg: AgentMessage):
            received["c"].append(msg)

        await bus.subscribe("agent_b", handler_b)
        await bus.subscribe("agent_c", handler_c)

        # No recipient is set: this is the broadcast case.
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            content="hello everyone",
            msg_type="notify",
        ))

        assert result is True
        # Check EVERY subscriber got exactly one message with the broadcast
        # payload -- not just agent_b.
        for subscriber in ("b", "c"):
            assert len(received[subscriber]) == 1
            assert received[subscriber][0].content == "hello everyone"
# ── Negotiate ─────────────────────────────────────────────
class TestNegotiate:
    """A "negotiate" message sent via bus.request() is delivered and answered."""

    @pytest.mark.asyncio
    async def test_negotiate_response(self):
        """Agent A sends negotiate, Agent B responds."""
        bus = InMemoryMessageBus()
        # Records what agent_b actually received, so the test can check that
        # the negotiate type made it through the bus.
        seen_by_b: list[AgentMessage] = []

        async def handler_b(msg: AgentMessage):
            seen_by_b.append(msg)
            reply = AgentMessage(
                sender="agent_b",
                recipient=msg.sender,
                content="deal accepted",
                msg_type="response",
                correlation_id=msg.correlation_id,
            )
            await bus.publish(reply)

        await bus.subscribe("agent_b", handler_b)
        request = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="propose deal",
            msg_type="negotiate",
        )
        response = await bus.request(request, timeout_seconds=5.0)

        assert response is not None
        assert response.content == "deal accepted"
        assert response.msg_type == "response"
        # The reply must belong to this negotiation...
        assert response.correlation_id == request.correlation_id
        # ...and agent_b must have received the proposal as a negotiate message.
        assert [m.msg_type for m in seen_by_b] == ["negotiate"]
# ── TTL Expired ───────────────────────────────────────────
class TestTTLExpired:
    """publish() rejects messages whose TTL has already elapsed."""

    @pytest.mark.asyncio
    async def test_expired_message_dropped(self):
        """Expired message is dropped by publish()."""
        bus = InMemoryMessageBus()
        inbox: list[AgentMessage] = []

        async def collect(message: AgentMessage):
            inbox.append(message)

        await bus.subscribe("agent_b", collect)

        # Stamped 600s in the past with a 300s TTL, so it is already expired.
        ten_minutes_ago = datetime.now(timezone.utc) - timedelta(seconds=600)
        stale = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="old news",
            ttl_seconds=300,
            timestamp=ten_minutes_ago,
        )

        assert await bus.publish(stale) is False
        # Give any erroneous delivery a moment to happen before checking.
        await asyncio.sleep(0.05)
        assert not inbox
# ── Unsubscribe ───────────────────────────────────────────
class TestUnsubscribe:
    """unsubscribe() stops delivery to the removed agent."""

    @pytest.mark.asyncio
    async def test_unsubscribed_agent_no_receive(self):
        """Unsubscribed agent doesn't receive messages."""
        bus = InMemoryMessageBus()
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)

        # Sanity check: while subscribed, agent_b does receive broadcasts.
        # Without this, the assertion below would also pass if subscribe()
        # never delivered anything at all.
        await bus.publish(AgentMessage(
            sender="agent_a",
            content="are you there?",
        ))
        await asyncio.sleep(0.05)
        assert len(received) == 1

        await bus.unsubscribe("agent_b")
        await bus.publish(AgentMessage(
            sender="agent_a",
            content="still here?",
        ))
        await asyncio.sleep(0.05)

        # Nothing new arrived after unsubscribing.
        assert len(received) == 1
        assert received[0].content == "are you there?"
# ── Request Timeout ───────────────────────────────────────
class TestRequestTimeout:
    """bus.request() returns None when no reply arrives in time."""

    @pytest.mark.asyncio
    async def test_request_timeout_returns_none(self):
        """No response within timeout returns None."""
        bus = InMemoryMessageBus()
        # Nothing subscribes as agent_b, so no reply can arrive.
        unanswered = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="anyone there?",
        )
        assert await bus.request(unanswered, timeout_seconds=0.1) is None
# ── CascadeDetector Integration ───────────────────────────
class TestCascadeDetectorIntegration:
    """A CascadeDetector attached to the bus blocks excess agent-to-agent traffic."""

    @pytest.mark.asyncio
    async def test_cascade_alert_blocks_message(self):
        """Too many messages trigger cascade alert and block publishing."""
        detector = CascadeDetector(max_interactions=3)
        bus = InMemoryMessageBus(cascade_detector=detector)
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)

        # NOTE(review): per the original test, check_interaction increments
        # before checking and alerts when count > max_interactions, so with
        # max_interactions=3 the first 3 publishes succeed and the 4th fails.
        for i in range(3):
            result = await bus.publish(AgentMessage(
                sender="agent_a",
                recipient="agent_b",
                content=f"msg {i}",
            ))
            assert result is True

        # 4th interaction should be blocked
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="msg 4",
        ))
        assert result is False

        # Blocking must also mean non-delivery: only the first three messages
        # reached the subscriber.
        await asyncio.sleep(0.05)
        assert len(received) == 3
        assert all(m.content != "msg 4" for m in received)
# ── AlignmentGuard Integration ────────────────────────────
class TestAlignmentGuardIntegration:
    """An AlignmentGuard attached to the bus screens "request" messages."""

    async def _make_guarded_bus(self) -> tuple[InMemoryMessageBus, list[AgentMessage]]:
        """Build a bus guarded by one constraint; return it plus agent_b's inbox.

        The constraint "禁止暴力" reads "violence is forbidden".
        """
        config = AlignmentConfig(constraints=["禁止暴力"])
        guard = AlignmentGuard(config=config)
        bus = InMemoryMessageBus(alignment_guard=guard)
        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)
        return bus, received

    @pytest.mark.asyncio
    async def test_violating_message_blocked(self):
        """Message violating alignment constraints is blocked."""
        bus, received = await self._make_guarded_bus()
        # A request whose content ("carry out violent acts") violates the
        # constraint should be blocked.
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="执行暴力行为",
            msg_type="request",
        ))
        assert result is False
        # Blocked means the subscriber never sees it.
        await asyncio.sleep(0.05)
        assert received == []

    @pytest.mark.asyncio
    async def test_non_violating_request_passes(self):
        """Non-violating request message passes alignment check."""
        bus, received = await self._make_guarded_bus()
        # "Peaceful communication": does not violate the constraint.
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="和平交流",
            msg_type="request",
        ))
        assert result is True
        await asyncio.sleep(0.05)
        assert len(received) == 1

    @pytest.mark.asyncio
    async def test_notify_not_checked_by_alignment(self):
        """notify type messages are not checked by alignment guard."""
        bus, received = await self._make_guarded_bus()
        # Violating content sent as "notify" should pass, because notify
        # messages are not screened.
        result = await bus.publish(AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="执行暴力行为",
            msg_type="notify",
        ))
        assert result is True
        await asyncio.sleep(0.05)
        assert len(received) == 1