# fischer-agentkit/tests/unit/test_registry.py
# (file-viewer header residue: 274 lines, 9.7 KiB, Python)

"""Tests for AgentRegistry - Agent 注册中心"""
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.core.protocol import AgentCapability, AgentStatus
from agentkit.core.registry import AgentRegistry, HEARTBEAT_TIMEOUT_SECONDS
class _ColumnMock:
"""Mock for SQLAlchemy column attributes that supports comparison operators."""
def __init__(self, name):
self._name = name
def __eq__(self, other):
return MagicMock()
def __ne__(self, other):
return MagicMock()
def __lt__(self, other):
return MagicMock()
def __le__(self, other):
return MagicMock()
def __gt__(self, other):
return MagicMock()
def __ge__(self, other):
return MagicMock()
def like(self, pattern):
return MagicMock()
def desc(self):
return MagicMock()
class MockAgentORM:
    """Plain attribute bag mimicking a persisted Agent ORM row.

    Any column can be overridden with a keyword argument of the same name;
    omitted columns get test defaults. Keyword arguments that do not name a
    column are ignored.
    """

    def __init__(self, **kwargs):
        # Sample the clock once so a freshly built row has identical
        # last_heartbeat / created_at / updated_at defaults.
        now = datetime.now(timezone.utc)
        name = kwargs.get("name", "test_agent")

        # Only generate a UUID when the caller did not supply an id.
        self.id = kwargs["id"] if "id" in kwargs else uuid.uuid4()
        self.name = name
        self.display_name = kwargs.get("display_name", "Test Agent")
        self.agent_type = kwargs.get("agent_type", "test")
        self.description = kwargs.get("description", "Test agent")
        self.version = kwargs.get("version", "1.0")
        self.endpoint = kwargs.get("endpoint", "http://localhost:8000")
        self.status = kwargs.get("status", AgentStatus.ONLINE)
        self.capabilities = kwargs.get("capabilities", {
            "agent_name": name,
            "supported_tasks": ["test_task"],
        })
        self.last_heartbeat = kwargs.get("last_heartbeat", now)
        self.created_at = kwargs.get("created_at", now)
        self.updated_at = kwargs.get("updated_at", now)
class MockAgentModel:
    """Stand-in for the Agent ORM model class handed to AgentRegistry.

    Class-level attributes are ``_ColumnMock`` objects so query expressions
    such as ``MockAgentModel.status == ...`` can be built. Instances wrap a
    ``MockAgentORM`` and forward every public attribute read and write to it.
    """

    # Class-level column mocks used in SQLAlchemy where/order clauses
    name = _ColumnMock("name")
    status = _ColumnMock("status")
    agent_type = _ColumnMock("agent_type")
    created_at = _ColumnMock("created_at")
    last_heartbeat = _ColumnMock("last_heartbeat")
    id = _ColumnMock("id")

    def __init__(self, **kwargs):
        self._orm = MockAgentORM(**kwargs)

    def __getattribute__(self, item):
        # Instance reads must bypass the class-level column mocks above.
        # A plain __getattr__ is only consulted when normal lookup fails, so
        # with it `instance.id` / `instance.name` / `instance.status` would
        # return the _ColumnMock instead of the wrapped row's value.
        if item.startswith("_"):
            return object.__getattribute__(self, item)
        return getattr(object.__getattribute__(self, "_orm"), item)

    def __setattr__(self, key, value):
        if key.startswith("_"):
            super().__setattr__(key, value)
        else:
            setattr(self._orm, key, value)
def _make_mock_session(agents=None, online_agents=None):
"""Create a mock async session with pre-loaded agents.
Args:
agents: Agents returned by scalar_one_or_none (first match) and
general scalars().all() queries.
online_agents: Agents returned when querying for ONLINE agents
(used by get_available_agent). If not provided,
filters `agents` by status == ONLINE.
"""
session = AsyncMock()
agents = agents or []
# Compute online agents for get_available_agent filtering
if online_agents is None:
online_agents = [a for a in agents if getattr(a, "status", None) == AgentStatus.ONLINE]
# Track call count to differentiate query types
call_count = [0]
async def mock_execute(stmt):
result = MagicMock()
call_count[0] += 1
result.scalar_one_or_none.return_value = agents[0] if agents else None
# Return online_agents for queries filtering by ONLINE status,
# all agents otherwise
result.scalars.return_value.all.return_value = online_agents
result.rowcount = len(online_agents) if online_agents else 0
return result
session.execute = mock_execute
session.add = MagicMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
# Fix: make type(session).execute.__self__.__class__ work for registry.py line 51
# type(session) returns AsyncMock, so we need AsyncMock.execute to be a
# mock with __self__ attribute (simulating a bound method)
_execute_class_mock = MagicMock()
_execute_method = MagicMock()
_execute_method.__self__ = MagicMock()
_execute_method.__self__.class_ = MagicMock()
_execute_class_mock.__get__ = MagicMock(return_value=_execute_method)
type(session).execute = _execute_class_mock
return session, online_agents
def _make_registry(agents=None, load_balancer="round_robin"):
    """Build an AgentRegistry wired to a mock async session factory.

    Returns:
        ``(registry, session, online_agents)``: the registry under test, the
        mock session yielded by the factory's ``async with``, and the second
        value returned by ``_make_mock_session``.
    """
    session, online = _make_mock_session(agents=agents)

    # Presumably the registry enters `session_factory()` with `async with`;
    # the context manager hands back the mock session and never suppresses
    # exceptions.
    factory = MagicMock()
    context = factory.return_value
    context.__aenter__ = AsyncMock(return_value=session)
    context.__aexit__ = AsyncMock(return_value=False)

    registry = AgentRegistry(
        session_factory=factory,
        agent_model=MockAgentModel,
        load_balancer=load_balancer,
    )
    return registry, session, online
_mock_select = MagicMock()
_mock_update = MagicMock()
@patch("sqlalchemy.update", _mock_update)
@patch("sqlalchemy.select", _mock_select)
class TestAgentRegistryRegister:
    """register(): inserting new agents and refreshing existing ones."""

    async def test_register_new_agent(self, make_capability):
        """Register a new agent."""
        registry, session, _ = _make_registry(agents=None)
        capability = make_capability(agent_name="new_agent", supported_tasks=["task_a"])

        agent_id = await registry.register(capability, endpoint="http://localhost:8001")

        assert agent_id is not None
        session.add.assert_called_once()
        session.commit.assert_called()

    async def test_register_existing_agent_updates(self, make_capability):
        """Registering an already-known agent updates its stored fields."""
        stored = MockAgentORM(name="existing_agent", agent_type="old_type")
        registry, _session, _ = _make_registry(agents=[stored])
        capability = make_capability(agent_name="existing_agent", agent_type="new_type")

        agent_id = await registry.register(capability, endpoint="http://localhost:8002")

        assert agent_id is not None
        assert stored.agent_type == "new_type"
        assert stored.status == AgentStatus.ONLINE
@patch("sqlalchemy.select", _mock_select)
class TestAgentRegistryUnregister:
    """unregister(): flipping agents offline and tolerating unknown names."""

    async def test_unregister_existing_agent(self):
        """Unregistering an online agent marks it offline."""
        target = MockAgentORM(name="to_unregister", status=AgentStatus.ONLINE)
        registry, _session, _ = _make_registry(agents=[target])

        await registry.unregister("to_unregister")

        assert target.status == AgentStatus.OFFLINE

    async def test_unregister_nonexistent_agent(self):
        """Unregistering an unknown agent completes without raising."""
        registry, _session, _ = _make_registry(agents=None)
        await registry.unregister("nonexistent")  # must not raise
class TestAgentRegistryGetAvailable:
    """get_available_agent(): load-balanced selection among online agents."""

    @patch("sqlalchemy.select", _mock_select)
    async def test_get_available_agent_round_robin(self):
        """Round-robin strategy returns a different agent on consecutive calls."""
        agent_a = MockAgentORM(name="agent_a", capabilities={
            "supported_tasks": ["task_x"],
        })
        agent_b = MockAgentORM(name="agent_b", capabilities={
            "supported_tasks": ["task_x"],
        })
        registry, _session, _ = _make_registry(agents=[agent_a, agent_b], load_balancer="round_robin")

        first = await registry.get_available_agent("task_x")
        second = await registry.get_available_agent("task_x")

        # Both picks must be known agents AND differ. A disjunction such as
        # `first != second or first in (...)` would pass even if the same
        # agent were handed out twice, so it could never detect a broken
        # round robin.
        # NOTE(review): assumes get_available_agent returns the agent name,
        # as the names compared here suggest — confirm against registry.py.
        assert first in ("agent_a", "agent_b")
        assert second in ("agent_a", "agent_b")
        assert first != second

    @patch("sqlalchemy.select", _mock_select)
    async def test_get_available_agent_no_match(self):
        """Returns None when no agent supports the requested task."""
        agent = MockAgentORM(name="agent_a", capabilities={
            "supported_tasks": ["task_y"],
        })
        registry, _session, _ = _make_registry(agents=[agent])

        result = await registry.get_available_agent("task_x")

        assert result is None

    @patch("sqlalchemy.select", _mock_select)
    async def test_get_available_agent_offline_excluded(self):
        """Offline agents are excluded from selection."""
        agent = MockAgentORM(name="offline_agent", status=AgentStatus.OFFLINE, capabilities={
            "supported_tasks": ["task_x"],
        })
        registry, _session, online_agents = _make_registry(agents=[agent])
        # Precondition: the mock session exposes no online agents at all.
        assert online_agents == []

        result = await registry.get_available_agent("task_x")

        assert result is None
class TestAgentRegistryHealthCheck:
    """check_health(): timing out agents whose heartbeat is stale."""

    @patch("sqlalchemy.update", _mock_update)
    async def test_check_health_marks_timeout_agents_offline(self):
        """Agents whose heartbeat timed out are marked offline.

        The mock session cannot show which rows were updated, so this verifies
        that check_health issues at least one statement and commits.
        """
        registry, session, _ = _make_registry(agents=[])

        # session.execute is a plain coroutine function, not a mock, so it
        # cannot be asserted on directly; wrap it to record what is executed.
        executed = []
        mock_execute = session.execute

        async def recording_execute(stmt, *args, **kwargs):
            executed.append(stmt)
            return await mock_execute(stmt, *args, **kwargs)

        session.execute = recording_execute

        await registry.check_health()

        assert executed, "check_health should execute at least one statement"
        session.commit.assert_called()
@patch("sqlalchemy.select", _mock_select)
class TestAgentRegistryListAgents:
    """list_agents(): returning every stored agent."""

    async def test_list_agents(self):
        """All stored agents are listed."""
        stored = [MockAgentORM(name="agent_a"), MockAgentORM(name="agent_b")]
        registry, _session, _ = _make_registry(agents=stored)

        listed = await registry.list_agents()

        assert len(listed) == 2

    async def test_list_agents_empty(self):
        """An empty registry yields an empty list."""
        registry, _session, _ = _make_registry(agents=None)
        assert await registry.list_agents() == []