332 lines
11 KiB
Python
332 lines
11 KiB
Python
"""Tests for Agent inter-communication bus — new features.

Covers: msg_type, content, TTL, CascadeDetector integration,
AlignmentGuard integration, negotiate, request-timeout-returns-None.
"""
|
|
|
|
import asyncio
|
|
from datetime import datetime, timezone, timedelta
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from agentkit.bus.message import AgentMessage
|
|
from agentkit.bus.memory_bus import InMemoryMessageBus
|
|
from agentkit.quality.cascade_detector import CascadeDetector
|
|
from agentkit.quality.alignment import AlignmentGuard, AlignmentConfig
|
|
|
|
|
|
# ── AgentMessage new fields ───────────────────────────────
|
|
|
|
|
|
class TestAgentMessageNewFields:
    """Field-level checks for the msg_type, content and TTL attributes of AgentMessage."""

    def test_msg_type_default(self):
        """msg_type is "notify" when the constructor is not given one."""
        assert AgentMessage(sender="a").msg_type == "notify"

    def test_content_field(self):
        """content passed to the constructor is read back unchanged."""
        message = AgentMessage(sender="a", content="hello")
        assert message.content == "hello"

    def test_ttl_default(self):
        """ttl_seconds is 300 when the constructor is not given one."""
        assert AgentMessage(sender="a").ttl_seconds == 300

    def test_is_expired_fresh(self):
        """A message built just now reports is_expired() as False."""
        fresh = AgentMessage(sender="a")
        assert fresh.is_expired() is False

    def test_is_expired_old(self):
        """A message stamped 600 s in the past with a 300 s TTL reports expired."""
        stale_timestamp = datetime.now(timezone.utc) - timedelta(seconds=600)
        stale = AgentMessage(
            sender="a",
            ttl_seconds=300,
            timestamp=stale_timestamp,
        )
        assert stale.is_expired() is True

    def test_to_dict_roundtrip_with_new_fields(self):
        """content, msg_type and ttl_seconds survive to_dict() -> from_dict()."""
        original = AgentMessage(
            sender="a",
            content="data",
            msg_type="negotiate",
            ttl_seconds=60,
        )
        restored = AgentMessage.from_dict(original.to_dict())
        assert restored.content == "data"
        assert restored.msg_type == "negotiate"
        assert restored.ttl_seconds == 60
|
|
|
|
|
|
# ── Request-Response with correlation_id ──────────────────
|
|
|
|
|
|
class TestRequestResponse:
    """Request/response round trip through bus.request()."""

    @pytest.mark.asyncio
    async def test_request_response_correlation(self):
        """Agent A sends request, Agent B responds, correlation_id matches."""
        bus = InMemoryMessageBus()

        async def answer_from_b(incoming: AgentMessage):
            # The reply reuses the request's correlation_id; the final
            # assertion below checks that it comes back unchanged.
            await bus.publish(
                AgentMessage(
                    sender="agent_b",
                    recipient=incoming.sender,
                    content="answer",
                    msg_type="response",
                    correlation_id=incoming.correlation_id,
                )
            )

        await bus.subscribe("agent_b", answer_from_b)

        question = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="question",
        )
        reply = await bus.request(question, timeout_seconds=5.0)

        assert reply is not None
        assert reply.content == "answer"
        assert reply.correlation_id == question.correlation_id
|
|
|
|
|
|
# ── Broadcast ─────────────────────────────────────────────
|
|
|
|
|
|
class TestBroadcast:
    """Publishing without a recipient reaches every subscriber."""

    @pytest.mark.asyncio
    async def test_broadcast_all_subscribers_receive(self):
        """Agent A broadcasts, all subscribers receive."""
        bus = InMemoryMessageBus()
        received: dict[str, list[AgentMessage]] = {"b": [], "c": []}

        def recorder(key: str):
            """Return an async handler that appends each message to received[key]."""

            async def _record(msg: AgentMessage):
                received[key].append(msg)

            return _record

        await bus.subscribe("agent_b", recorder("b"))
        await bus.subscribe("agent_c", recorder("c"))

        # No recipient is set on this message.
        announcement = AgentMessage(
            sender="agent_a",
            content="hello everyone",
            msg_type="notify",
        )
        result = await bus.publish(announcement)

        assert result is True
        assert len(received["b"]) == 1
        assert len(received["c"]) == 1
        assert received["b"][0].content == "hello everyone"
|
|
|
|
|
|
# ── Negotiate ─────────────────────────────────────────────
|
|
|
|
|
|
class TestNegotiate:
    """A "negotiate" message sent via bus.request() gets a response back."""

    @pytest.mark.asyncio
    async def test_negotiate_response(self):
        """Agent A sends negotiate, Agent B responds."""
        bus = InMemoryMessageBus()

        async def accept_deal(proposal: AgentMessage):
            # Reply to whoever sent the proposal, reusing its correlation_id.
            acceptance = AgentMessage(
                sender="agent_b",
                recipient=proposal.sender,
                content="deal accepted",
                msg_type="response",
                correlation_id=proposal.correlation_id,
            )
            await bus.publish(acceptance)

        await bus.subscribe("agent_b", accept_deal)

        proposal = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="propose deal",
            msg_type="negotiate",
        )
        outcome = await bus.request(proposal, timeout_seconds=5.0)

        assert outcome is not None
        assert outcome.content == "deal accepted"
        assert outcome.msg_type == "response"
|
|
|
|
|
|
# ── TTL Expired ───────────────────────────────────────────
|
|
|
|
|
|
class TestTTLExpired:
    """publish() refuses a message whose TTL has already elapsed."""

    @pytest.mark.asyncio
    async def test_expired_message_dropped(self):
        """Expired message is dropped by publish()."""
        bus = InMemoryMessageBus()
        delivered: list[AgentMessage] = []

        async def record(msg: AgentMessage):
            delivered.append(msg)

        await bus.subscribe("agent_b", record)

        # Stamped 600 s ago with a 300 s TTL, so already past its lifetime.
        stale_timestamp = datetime.now(timezone.utc) - timedelta(seconds=600)
        stale = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="old news",
            ttl_seconds=300,
            timestamp=stale_timestamp,
        )

        accepted = await bus.publish(stale)
        assert accepted is False

        # Pause briefly before checking that the handler was never invoked.
        await asyncio.sleep(0.05)
        assert len(delivered) == 0
|
|
|
|
|
|
# ── Unsubscribe ───────────────────────────────────────────
|
|
|
|
|
|
class TestUnsubscribe:
    """A handler removed with unsubscribe() stops receiving messages."""

    @pytest.mark.asyncio
    async def test_unsubscribed_agent_no_receive(self):
        """Unsubscribed agent doesn't receive messages."""
        bus = InMemoryMessageBus()
        delivered: list[AgentMessage] = []

        async def record(msg: AgentMessage):
            delivered.append(msg)

        # Subscribe and immediately unsubscribe the same agent id.
        await bus.subscribe("agent_b", record)
        await bus.unsubscribe("agent_b")

        probe = AgentMessage(
            sender="agent_a",
            content="still here?",
        )
        await bus.publish(probe)

        # Pause briefly before checking that nothing arrived.
        await asyncio.sleep(0.05)
        assert len(delivered) == 0
|
|
|
|
|
|
# ── Request Timeout ───────────────────────────────────────
|
|
|
|
|
|
class TestRequestTimeout:
    """bus.request() yields None when nobody answers in time."""

    @pytest.mark.asyncio
    async def test_request_timeout_returns_none(self):
        """No response within timeout returns None."""
        bus = InMemoryMessageBus()

        # Deliberately no subscriber for agent_b, so nothing can reply.
        unanswered = AgentMessage(
            sender="agent_a",
            recipient="agent_b",
            content="anyone there?",
        )
        outcome = await bus.request(unanswered, timeout_seconds=0.1)

        assert outcome is None
|
|
|
|
|
|
# ── CascadeDetector Integration ───────────────────────────
|
|
|
|
|
|
class TestCascadeDetectorIntegration:
    """Checks that a CascadeDetector attached to the bus can block publishing."""

    @pytest.mark.asyncio
    async def test_cascade_alert_blocks_message(self):
        """Too many messages trigger cascade alert and block publishing.

        Beyond checking publish()'s return value, this also verifies that the
        rejected message never reaches the subscriber's handler.
        """
        detector = CascadeDetector(max_interactions=3)
        bus = InMemoryMessageBus(cascade_detector=detector)

        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)

        # With max_interactions=3 the first three publishes are expected to be
        # accepted; the fourth is the one expected to trip the alert.
        for i in range(3):
            result = await bus.publish(
                AgentMessage(
                    sender="agent_a",
                    recipient="agent_b",
                    content=f"msg {i}",
                )
            )
            assert result is True

        # 4th interaction should be blocked.
        result = await bus.publish(
            AgentMessage(
                sender="agent_a",
                recipient="agent_b",
                content="msg 4",
            )
        )
        assert result is False

        # Yield briefly so any delivery that might have been scheduled has had
        # a chance to run, then confirm the blocked message was not delivered.
        await asyncio.sleep(0.05)
        assert all(msg.content != "msg 4" for msg in received)
|
|
|
|
|
|
# ── AlignmentGuard Integration ────────────────────────────
|
|
|
|
|
|
class TestAlignmentGuardIntegration:
    """Checks how an AlignmentGuard attached to the bus affects publish()."""

    async def _guarded_bus(self):
        """Build a bus guarded by the single constraint "禁止暴力" ("no violence").

        Subscribes a recording handler as "agent_b".

        Returns:
            A ``(bus, received)`` pair, where ``received`` is the list that
            agent_b's handler appends every delivered message to.
        """
        config = AlignmentConfig(constraints=["禁止暴力"])
        guard = AlignmentGuard(config=config)
        bus = InMemoryMessageBus(alignment_guard=guard)

        received: list[AgentMessage] = []

        async def handler(msg: AgentMessage):
            received.append(msg)

        await bus.subscribe("agent_b", handler)
        return bus, received

    @pytest.mark.asyncio
    async def test_violating_message_blocked(self):
        """Message violating alignment constraints is blocked.

        Checks both that publish() returns False and that the handler never
        sees the message.
        """
        bus, received = await self._guarded_bus()

        # A "request" whose content contains the forbidden term should be blocked.
        result = await bus.publish(
            AgentMessage(
                sender="agent_a",
                recipient="agent_b",
                content="执行暴力行为",
                msg_type="request",
            )
        )
        assert result is False

        # Yield briefly so any delivery that might have been scheduled has had
        # a chance to run, then confirm nothing reached the subscriber.
        await asyncio.sleep(0.05)
        assert not received

    @pytest.mark.asyncio
    async def test_non_violating_request_passes(self):
        """Non-violating request message passes alignment check."""
        bus, _received = await self._guarded_bus()

        result = await bus.publish(
            AgentMessage(
                sender="agent_a",
                recipient="agent_b",
                content="和平交流",
                msg_type="request",
            )
        )
        assert result is True

    @pytest.mark.asyncio
    async def test_notify_not_checked_by_alignment(self):
        """notify type messages are not checked by alignment guard."""
        bus, _received = await self._guarded_bus()

        # Same content as the blocked request, but sent as "notify": the test
        # expects publish() to accept it.
        result = await bus.publish(
            AgentMessage(
                sender="agent_a",
                recipient="agent_b",
                content="执行暴力行为",
                msg_type="notify",
            )
        )
        assert result is True
|