274 lines
9.7 KiB
Python
274 lines
9.7 KiB
Python
"""Tests for AgentRegistry - Agent 注册中心"""
|
|
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.protocol import AgentCapability, AgentStatus
|
|
from agentkit.core.registry import AgentRegistry, HEARTBEAT_TIMEOUT_SECONDS
|
|
|
|
|
|
class _ColumnMock:
    """Stand-in for a SQLAlchemy column attribute used while building queries.

    Each rich comparison, ``like()`` and ``desc()`` hands back a brand-new
    ``MagicMock`` so ``where(...)`` / ``order_by(...)`` clauses can be
    assembled without a real SQLAlchemy expression. Since ``__eq__`` is
    overridden without a ``__hash__``, instances are unhashable.
    """

    def __init__(self, name):
        # Column name; not consulted by any of the operators below.
        self._name = name

    def _compare(self, other):
        """Produce a fresh opaque placeholder for a comparison expression."""
        return MagicMock()

    # All six rich comparisons share the same placeholder factory.
    __eq__ = __ne__ = __lt__ = __le__ = __gt__ = __ge__ = _compare

    def like(self, pattern):
        """Placeholder for ``column.like(pattern)``."""
        return MagicMock()

    def desc(self):
        """Placeholder for ``column.desc()`` in ``order_by`` clauses."""
        return MagicMock()
|
|
|
|
|
|
class MockAgentORM:
    """Plain stand-in for an Agent ORM row.

    Every column comes from the matching keyword argument when given and
    otherwise falls back to a test-friendly default. Keyword arguments other
    than the column names below are ignored.
    """

    def __init__(self, **kwargs):
        # Capture "now" once so that, by default, created_at, updated_at and
        # last_heartbeat are identical instead of differing by microseconds.
        now = datetime.now(timezone.utc)
        name = kwargs.get("name", "test_agent")

        # Defaults that are random (uuid4) or come from project imports
        # (AgentStatus) are only built when the caller did not supply a value.
        self.id = kwargs["id"] if "id" in kwargs else uuid.uuid4()
        self.name = name
        self.display_name = kwargs.get("display_name", "Test Agent")
        self.agent_type = kwargs.get("agent_type", "test")
        self.description = kwargs.get("description", "Test agent")
        self.version = kwargs.get("version", "1.0")
        self.endpoint = kwargs.get("endpoint", "http://localhost:8000")
        self.status = kwargs["status"] if "status" in kwargs else AgentStatus.ONLINE
        if "capabilities" in kwargs:
            self.capabilities = kwargs["capabilities"]
        else:
            # A fresh dict per instance; agent_name mirrors the row's name.
            self.capabilities = {
                "agent_name": name,
                "supported_tasks": ["test_task"],
            }
        self.last_heartbeat = kwargs.get("last_heartbeat", now)
        self.created_at = kwargs.get("created_at", now)
        self.updated_at = kwargs.get("updated_at", now)
|
|
|
|
|
|
class _ModelColumn:
    """Descriptor pairing a query-building column mock with per-row values.

    Accessed on the class (e.g. ``MockAgentModel.status``) it returns a
    ``_ColumnMock`` so where()/order_by() clauses can be built. Accessed on an
    instance it returns the wrapped ORM object's attribute of the same name.
    """

    def __init__(self, name):
        self._attr = name
        self._column = _ColumnMock(name)

    def __get__(self, instance, owner=None):
        if instance is None:
            return self._column
        # Read through to the wrapped MockAgentORM, exactly like
        # MockAgentModel.__getattr__ does for non-column attributes.
        return getattr(instance._orm, self._attr)


class MockAgentModel:
    """Mock Agent ORM model class with class-level column mocks for queries.

    Instances wrap a ``MockAgentORM``: public attribute reads and writes are
    forwarded to it. Column attributes use ``_ModelColumn`` so that a plain
    class attribute does not shadow the row value on instances.
    """

    # Class access -> _ColumnMock (for queries); instance access -> row value.
    name = _ModelColumn("name")
    status = _ModelColumn("status")
    agent_type = _ModelColumn("agent_type")
    created_at = _ModelColumn("created_at")
    last_heartbeat = _ModelColumn("last_heartbeat")
    id = _ModelColumn("id")

    def __init__(self, **kwargs):
        self._orm = MockAgentORM(**kwargs)

    def __getattr__(self, item):
        # Only reached when normal lookup fails. Private names are never
        # forwarded, which also prevents infinite recursion on self._orm
        # before it has been set.
        if item.startswith("_"):
            raise AttributeError(item)
        return getattr(self._orm, item)

    def __setattr__(self, key, value):
        # Private attributes (e.g. _orm) live on the wrapper; everything else
        # is written to the wrapped ORM row.
        if key.startswith("_"):
            super().__setattr__(key, value)
        else:
            setattr(self._orm, key, value)
|
|
|
|
|
|
def _make_mock_session(agents=None, online_agents=None):
    """Create a mock async session with pre-loaded agents.

    The mock never inspects the statement it is given, so every
    ``await session.execute(stmt)`` returns the same canned result:

    * ``scalar_one_or_none()`` -> first element of ``agents`` (or ``None``),
      regardless of any filter in the query;
    * ``scalars().all()`` -> ``online_agents`` for *every* query, including
      ones that in production would also return offline agents;
    * ``rowcount`` -> ``len(online_agents)``.

    Args:
        agents: Pool of agents; its first element answers
            ``scalar_one_or_none()``.
        online_agents: Agents returned by ``scalars().all()``. Defaults to the
            members of ``agents`` whose ``status`` equals
            ``AgentStatus.ONLINE``.

    Returns:
        ``(session, online_agents)``.
    """
    session = AsyncMock()
    agents = agents or []

    if online_agents is None:
        online_agents = [a for a in agents if getattr(a, "status", None) == AgentStatus.ONLINE]

    async def mock_execute(stmt):
        result = MagicMock()
        result.scalar_one_or_none.return_value = agents[0] if agents else None
        result.scalars.return_value.all.return_value = online_agents
        result.rowcount = len(online_agents)
        return result

    session.execute = mock_execute
    session.add = MagicMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()

    # Workaround for registry.py (line 51 at time of writing), which accesses
    # ``type(session).execute.__self__.__class__``. unittest.mock gives each
    # mock instance its own class, so this only affects this session. The
    # descriptor has no __set__, so the instance-level ``session.execute``
    # assigned above still wins for normal attribute access.
    execute_descriptor = MagicMock()
    bound_execute = MagicMock()
    bound_execute.__self__ = MagicMock()
    # NOTE(review): ``class_`` (not ``__class__``) looks like a typo; kept in
    # case registry.py reads it. Confirm and drop if unused.
    bound_execute.__self__.class_ = MagicMock()
    execute_descriptor.__get__ = MagicMock(return_value=bound_execute)
    type(session).execute = execute_descriptor

    return session, online_agents
|
|
|
|
|
|
def _make_registry(agents=None, load_balancer="round_robin"):
    """Build an AgentRegistry wired to a mocked async session.

    ``session_factory()`` yields an async context manager whose ``__aenter__``
    returns the mock session and whose ``__aexit__`` returns ``False``.

    Returns:
        ``(registry, session, online_agents)``.
    """
    session, online = _make_mock_session(agents=agents)

    factory = MagicMock()
    context = factory.return_value
    context.__aenter__ = AsyncMock(return_value=session)
    context.__aexit__ = AsyncMock(return_value=False)

    registry = AgentRegistry(
        session_factory=factory,
        agent_model=MockAgentModel,
        load_balancer=load_balancer,
    )
    return registry, session, online
|
|
|
|
|
|
# Shared replacements for ``sqlalchemy.select`` / ``sqlalchemy.update``.
# The tests below install them via ``@patch(target, new)``; because ``new``
# is given, no extra mock argument is passed to the test methods.
_mock_select, _mock_update = MagicMock(), MagicMock()
|
|
|
|
|
|
class TestAgentRegistryRegister:
    """AgentRegistry.register: inserting new agents and refreshing known ones."""

    @patch("sqlalchemy.update", _mock_update)
    @patch("sqlalchemy.select", _mock_select)
    async def test_register_new_agent(self, make_capability):
        """Registering an unknown agent adds a row and commits."""
        registry, mock_session, _ = _make_registry(agents=None)
        capability = make_capability(agent_name="new_agent", supported_tasks=["task_a"])

        new_id = await registry.register(capability, endpoint="http://localhost:8001")

        assert new_id is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called()

    @patch("sqlalchemy.update", _mock_update)
    @patch("sqlalchemy.select", _mock_select)
    async def test_register_existing_agent_updates(self, make_capability):
        """Re-registering a known agent overwrites its type and marks it ONLINE."""
        row = MockAgentORM(name="existing_agent", agent_type="old_type")
        registry = _make_registry(agents=[row])[0]
        capability = make_capability(agent_name="existing_agent", agent_type="new_type")

        returned_id = await registry.register(capability, endpoint="http://localhost:8002")

        assert returned_id is not None
        assert row.agent_type == "new_type"
        assert row.status == AgentStatus.ONLINE
|
|
|
|
|
|
class TestAgentRegistryUnregister:
    """AgentRegistry.unregister."""

    @patch("sqlalchemy.select", _mock_select)
    async def test_unregister_existing_agent(self):
        """Unregistering an online agent sets its status to OFFLINE."""
        target = MockAgentORM(name="to_unregister", status=AgentStatus.ONLINE)
        registry = _make_registry(agents=[target])[0]

        await registry.unregister("to_unregister")

        assert target.status == AgentStatus.OFFLINE

    @patch("sqlalchemy.select", _mock_select)
    async def test_unregister_nonexistent_agent(self):
        """Unregistering an unknown agent does not raise."""
        registry = _make_registry(agents=None)[0]
        # Passing means the await completes without an exception.
        await registry.unregister("nonexistent")
|
|
|
|
|
|
class TestAgentRegistryGetAvailable:
    """AgentRegistry.get_available_agent: task matching, status filtering, balancing."""

    @patch("sqlalchemy.select", _mock_select)
    async def test_get_available_agent_round_robin(self):
        """Round-robin selection returns a different agent on consecutive calls."""
        # Distinct endpoints (ids already differ via uuid4) so the two agents
        # differ in every identifying field, whichever representation
        # get_available_agent returns.
        agent_a = MockAgentORM(
            name="agent_a",
            endpoint="http://localhost:8001",
            capabilities={"supported_tasks": ["task_x"]},
        )
        agent_b = MockAgentORM(
            name="agent_b",
            endpoint="http://localhost:8002",
            capabilities={"supported_tasks": ["task_x"]},
        )
        registry, session, _ = _make_registry(agents=[agent_a, agent_b], load_balancer="round_robin")

        first = await registry.get_available_agent("task_x")
        second = await registry.get_available_agent("task_x")

        # Both calls must find a candidate, and round robin must alternate
        # between the two matching agents.
        assert first is not None
        assert second is not None
        assert first != second

    @patch("sqlalchemy.select", _mock_select)
    async def test_get_available_agent_no_match(self):
        """No agent supporting the task -> None."""
        agent = MockAgentORM(name="agent_a", capabilities={
            "supported_tasks": ["task_y"],
        })
        registry, session, _ = _make_registry(agents=[agent])

        result = await registry.get_available_agent("task_x")
        assert result is None

    @patch("sqlalchemy.select", _mock_select)
    async def test_get_available_agent_offline_excluded(self):
        """Offline agents are never selected."""
        agent = MockAgentORM(name="offline_agent", status=AgentStatus.OFFLINE, capabilities={
            "supported_tasks": ["task_x"],
        })
        registry, session, online_agents = _make_registry(agents=[agent])
        # Precondition: the mocked session must not report the agent as online.
        assert online_agents == []

        result = await registry.get_available_agent("task_x")
        assert result is None
|
|
|
|
|
|
class TestAgentRegistryHealthCheck:
    """AgentRegistry.check_health."""

    @patch("sqlalchemy.update", _mock_update)
    async def test_check_health_marks_timeout_agents_offline(self):
        """Agents whose heartbeat has timed out are marked offline.

        The mocked session returns canned results and cannot apply an UPDATE,
        so this only verifies that the health check runs and commits.
        """
        registry, mock_session, _ = _make_registry(agents=[])

        await registry.check_health()

        # Only the commit is asserted; the statement passed to execute is not inspected.
        mock_session.commit.assert_called()
|
|
|
|
|
|
class TestAgentRegistryListAgents:
    """AgentRegistry.list_agents."""

    @patch("sqlalchemy.select", _mock_select)
    async def test_list_agents(self):
        """Every registered agent is listed."""
        # Both agents default to ONLINE; the mocked session's scalars().all()
        # only yields online agents, so an offline one would not appear here.
        registry = _make_registry(
            agents=[MockAgentORM(name="agent_a"), MockAgentORM(name="agent_b")],
        )[0]

        listed = await registry.list_agents()
        assert len(listed) == 2

    @patch("sqlalchemy.select", _mock_select)
    async def test_list_agents_empty(self):
        """An empty registry lists no agents."""
        registry = _make_registry(agents=None)[0]
        assert await registry.list_agents() == []
|