"""Tests for AgentRegistry (the agent registration center)."""

import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.core.protocol import AgentCapability, AgentStatus
from agentkit.core.registry import AgentRegistry, HEARTBEAT_TIMEOUT_SECONDS


class _ColumnMock:
    """Stand-in for a SQLAlchemy column attribute on ``MockAgentModel``.

    Accessed on the model *class* (e.g. ``MockAgentModel.name``) it returns
    itself, so the comparison operators and helpers used inside
    ``where``/``order_by`` clauses yield opaque ``MagicMock`` objects.
    Accessed on a model *instance* it returns the value stored on the wrapped
    ``MockAgentORM``. This mirrors how SQLAlchemy instrumented attributes
    behave, and keeps the class-level column from shadowing the per-instance
    value.
    """

    def __init__(self, name):
        self._name = name

    def __get__(self, instance, owner):
        if instance is None:
            # Class access: behave like a column expression for query building.
            return self
        # Instance access: without this, Python would return the class-level
        # column mock and MockAgentModel.__getattr__ would never be consulted.
        return getattr(instance._orm, self._name)

    def __eq__(self, other):
        return MagicMock()

    def __ne__(self, other):
        return MagicMock()

    def __lt__(self, other):
        return MagicMock()

    def __le__(self, other):
        return MagicMock()

    def __gt__(self, other):
        return MagicMock()

    def __ge__(self, other):
        return MagicMock()

    def like(self, pattern):
        return MagicMock()

    def desc(self):
        return MagicMock()


class MockAgentORM:
    """Plain attribute bag standing in for a loaded Agent ORM row.

    Every field can be overridden via keyword arguments. Unspecified fields get
    defaults: an ONLINE agent named ``test_agent`` supporting ``test_task``,
    with timestamps set to "now" in UTC.
    """

    def __init__(self, **kwargs):
        self.id = kwargs.get("id", uuid.uuid4())
        self.name = kwargs.get("name", "test_agent")
        self.display_name = kwargs.get("display_name", "Test Agent")
        self.agent_type = kwargs.get("agent_type", "test")
        self.description = kwargs.get("description", "Test agent")
        self.version = kwargs.get("version", "1.0")
        self.endpoint = kwargs.get("endpoint", "http://localhost:8000")
        self.status = kwargs.get("status", AgentStatus.ONLINE)
        self.capabilities = kwargs.get("capabilities", {
            "agent_name": kwargs.get("name", "test_agent"),
            "supported_tasks": ["test_task"],
        })
        self.last_heartbeat = kwargs.get("last_heartbeat", datetime.now(timezone.utc))
        self.created_at = kwargs.get("created_at", datetime.now(timezone.utc))
        self.updated_at = kwargs.get("updated_at", datetime.now(timezone.utc))


class MockAgentModel:
    """Mock Agent ORM model class passed to ``AgentRegistry`` as ``agent_model``.

    Class-level ``_ColumnMock`` attributes serve SQLAlchemy-style query
    expressions. Instances wrap a ``MockAgentORM``: reads of column attributes
    and of any other public name resolve to the wrapped object, and public
    attribute writes are forwarded to it.
    """

    # Class-level column mocks used in SQLAlchemy where/order clauses.
    name = _ColumnMock("name")
    status = _ColumnMock("status")
    agent_type = _ColumnMock("agent_type")
    created_at = _ColumnMock("created_at")
    last_heartbeat = _ColumnMock("last_heartbeat")
    id = _ColumnMock("id")

    def __init__(self, **kwargs):
        self._orm = MockAgentORM(**kwargs)

    def __getattr__(self, item):
        # Only reached for names not found normally (i.e. non-column fields).
        if item.startswith("_"):
            raise AttributeError(item)
        return getattr(self._orm, item)

    def __setattr__(self, key, value):
        # Private names (e.g. ``_orm``) live on the wrapper; everything else is
        # written through to the wrapped ORM object.
        if key.startswith("_"):
            super().__setattr__(key, value)
        else:
            setattr(self._orm, key, value)


def _make_mock_session(agents=None, online_agents=None):
    """Create a mock async session pre-loaded with agents.

    The mock does not inspect the executed statement. Every ``execute()`` call
    returns a result with the same shape:

    - ``scalar_one_or_none()`` returns ``agents[0]``, or ``None`` when
      ``agents`` is empty.
    - ``scalars().all()`` returns ``online_agents``, for every query.
    - ``rowcount`` equals ``len(online_agents)``.

    Args:
        agents: Agents backing the session.
        online_agents: Agents returned by ``scalars().all()``. Defaults to
            ``agents`` filtered to ``status == AgentStatus.ONLINE``.

    Returns:
        A ``(session, online_agents)`` tuple.
    """
    session = AsyncMock()
    agents = agents or []

    if online_agents is None:
        online_agents = [
            a for a in agents if getattr(a, "status", None) == AgentStatus.ONLINE
        ]

    async def mock_execute(stmt):
        result = MagicMock()
        result.scalar_one_or_none.return_value = agents[0] if agents else None
        result.scalars.return_value.all.return_value = online_agents
        result.rowcount = len(online_agents)
        return result

    session.execute = mock_execute
    session.add = MagicMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()

    # Workaround for registry.py, which (per the original note) reaches through
    # ``type(session).execute.__self__``. Class access to ``execute`` must
    # therefore yield a bound-method-like mock that has ``__self__``.
    # ``__self__`` is dunder-shaped, so MagicMock will not auto-create it and it
    # is set explicitly. The instance attribute assigned above still wins for
    # ``session.execute``, because a mock that only defines ``__get__`` is a
    # non-data descriptor.
    # NOTE(review): ``class_`` below looks like it was meant as ``__class__``;
    # MagicMock would auto-create ``class_`` anyway. Confirm against registry.py.
    _execute_class_mock = MagicMock()
    _execute_method = MagicMock()
    _execute_method.__self__ = MagicMock()
    _execute_method.__self__.class_ = MagicMock()
    _execute_class_mock.__get__ = MagicMock(return_value=_execute_method)
    type(session).execute = _execute_class_mock

    return session, online_agents


def _make_registry(agents=None, load_balancer="round_robin"):
    """Create an ``AgentRegistry`` wired to a mock session factory.

    Returns:
        A ``(registry, mock_session, online_agents)`` tuple. ``mock_session`` is
        what the factory's async context manager yields.
    """
    mock_session, online_agents = _make_mock_session(agents=agents)
    session_factory = MagicMock()
    session_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
    session_factory.return_value.__aexit__ = AsyncMock(return_value=False)

    registry = AgentRegistry(
        session_factory=session_factory,
        agent_model=MockAgentModel,
        load_balancer=load_balancer,
    )
    return registry, mock_session, online_agents


# Shared stand-ins for sqlalchemy.select / sqlalchemy.update.
_mock_select = MagicMock()
_mock_update = MagicMock()


class TestAgentRegistryRegister:

    @patch("sqlalchemy.update", _mock_update)
    @patch("sqlalchemy.select", _mock_select)
    async def test_register_new_agent(self, make_capability):
        """Registering a new agent adds it to the session and commits."""
        registry, session, _ = _make_registry(agents=None)
        cap = make_capability(agent_name="new_agent", supported_tasks=["task_a"])

        agent_id = await registry.register(cap, endpoint="http://localhost:8001")

        assert agent_id is not None
        session.add.assert_called_once()
        session.commit.assert_called()

    @patch("sqlalchemy.update", _mock_update)
    @patch("sqlalchemy.select", _mock_select)
    async def test_register_existing_agent_updates(self, make_capability):
        """Re-registering an existing agent updates its stored info."""
        existing = MockAgentORM(name="existing_agent", agent_type="old_type")
        registry, session, _ = _make_registry(agents=[existing])
        cap = make_capability(agent_name="existing_agent", agent_type="new_type")

        agent_id = await registry.register(cap, endpoint="http://localhost:8002")

        assert agent_id is not None
        assert existing.agent_type == "new_type"
        assert existing.status == AgentStatus.ONLINE


class TestAgentRegistryUnregister:

    @patch("sqlalchemy.select", _mock_select)
    async def test_unregister_existing_agent(self):
        """Unregistering an online agent marks it offline."""
        agent = MockAgentORM(name="to_unregister", status=AgentStatus.ONLINE)
        registry, session, _ = _make_registry(agents=[agent])

        await registry.unregister("to_unregister")

        assert agent.status == AgentStatus.OFFLINE

    @patch("sqlalchemy.select", _mock_select)
    async def test_unregister_nonexistent_agent(self):
        """Unregistering an unknown agent does not raise."""
        registry, session, _ = _make_registry(agents=None)
        # Should not raise
        await registry.unregister("nonexistent")


class TestAgentRegistryGetAvailable:

    @patch("sqlalchemy.select", _mock_select)
    async def test_get_available_agent_round_robin(self):
        """The round-robin strategy hands out different agents."""
        agent_a = MockAgentORM(name="agent_a", capabilities={
            "supported_tasks": ["task_x"],
        })
        agent_b = MockAgentORM(name="agent_b", capabilities={
            "supported_tasks": ["task_x"],
        })
        registry, session, _ = _make_registry(
            agents=[agent_a, agent_b], load_balancer="round_robin"
        )

        first = await registry.get_available_agent("task_x")
        second = await registry.get_available_agent("task_x")

        # Both online agents support task_x, so neither pick may be empty.
        assert first is not None
        assert second is not None
        # NOTE(review): this checks alternation only loosely. Once the return
        # type of get_available_agent is confirmed (name vs. id), tighten to
        # `{first, second} == {"agent_a", "agent_b"}`.
        assert first != second or first in ("agent_a", "agent_b")

    @patch("sqlalchemy.select", _mock_select)
    async def test_get_available_agent_no_match(self):
        """Returns None when no agent supports the requested task."""
        agent = MockAgentORM(name="agent_a", capabilities={
            "supported_tasks": ["task_y"],
        })
        registry, session, _ = _make_registry(agents=[agent])

        result = await registry.get_available_agent("task_x")

        assert result is None

    @patch("sqlalchemy.select", _mock_select)
    async def test_get_available_agent_offline_excluded(self):
        """Offline agents are excluded from selection."""
        agent = MockAgentORM(name="offline_agent", status=AgentStatus.OFFLINE, capabilities={
            "supported_tasks": ["task_x"],
        })
        registry, session, online_agents = _make_registry(agents=[agent])

        # Sanity check: the mock session filtered the offline agent out.
        assert online_agents == []

        result = await registry.get_available_agent("task_x")

        assert result is None


class TestAgentRegistryHealthCheck:

    @patch("sqlalchemy.update", _mock_update)
    async def test_check_health_marks_timeout_agents_offline(self):
        """Agents whose heartbeat timed out are marked offline."""
        registry, session, _ = _make_registry(agents=[])

        await registry.check_health()

        # session.execute is a plain coroutine function (not a mock), so its
        # calls cannot be asserted; only the commit of the update is checked.
        session.commit.assert_called()


class TestAgentRegistryListAgents:

    @patch("sqlalchemy.select", _mock_select)
    async def test_list_agents(self):
        """Lists every registered agent."""
        agent_a = MockAgentORM(name="agent_a")
        agent_b = MockAgentORM(name="agent_b")
        registry, session, _ = _make_registry(agents=[agent_a, agent_b])

        agents = await registry.list_agents()

        assert len(agents) == 2

    @patch("sqlalchemy.select", _mock_select)
    async def test_list_agents_empty(self):
        """An empty registry yields an empty list."""
        registry, session, _ = _make_registry(agents=None)

        agents = await registry.list_agents()

        assert agents == []