geo/backend/tests/test_agent_framework/test_agent_registry.py

229 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Agent注册表测试"""
import pytest
import uuid
from app.agent_framework.registry import AgentRegistry
from app.agent_framework.protocol import AgentCapability, AgentStatus
def is_database_available():
    """Synchronously check whether the configured database accepts connections.

    A synchronous SQLAlchemy engine is built from ``settings.DATABASE_URL`` by
    stripping the async driver suffixes (``+aiosqlite`` / ``+asyncpg``), and
    ``SELECT 1`` is executed against it.

    Returns:
        True if the probe query succeeded; False on any error, including a
        missing SQLAlchemy install or an unimportable ``app.config``.
    """
    engine = None
    try:
        # Imported lazily so a missing dependency reports "unavailable"
        # instead of breaking test collection.
        from sqlalchemy import create_engine, text
        from app.config import settings

        # Drop async driver suffixes so the URL works with a sync engine.
        sync_url = settings.DATABASE_URL.replace('+aiosqlite', '').replace('+asyncpg', '')
        if 'sqlite' in sync_url:
            # sqlite3.connect has no connect_timeout argument.
            engine = create_engine(sync_url)
        else:
            # Short timeout so an unreachable server fails the probe quickly.
            engine = create_engine(sync_url, connect_args={"connect_timeout": 1})
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True
    except Exception:
        return False
    finally:
        # Dispose the pool on failure as well as success (the original only
        # disposed after a successful query, leaking it on errors).
        if engine is not None:
            engine.dispose()
# Memoized result of the database probe; None means "not probed yet".
_db_available = None


def check_db():
    """Report whether the database is reachable, probing it at most once.

    The first call runs ``is_database_available()`` and stores the outcome in
    the module-level ``_db_available``. Later calls return the stored value.
    An exception raised by the probe counts as "unavailable".
    """
    global _db_available
    if _db_available is not None:
        return _db_available
    try:
        probed = is_database_available()
    except Exception:
        probed = False
    _db_available = probed
    return _db_available
class TestAgentRegistry:
    """Integration tests for ``AgentRegistry``; skipped when no database is reachable."""

    @pytest.fixture(autouse=True)
    def _require_db(self):
        """Skip each test in this class when the database probe fails."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

    @staticmethod
    def _new_agent_name():
        """Return a unique, randomly suffixed agent name for one test."""
        return f"test_agent_{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _make_capability(
        agent_name,
        *,
        supported_tasks=("citation_detect",),
        version="1.0.0",
        max_concurrency=3,
        description="测试Agent",
    ):
        """Build a ``citation_detector`` capability for *agent_name*."""
        return AgentCapability(
            agent_name=agent_name,
            agent_type="citation_detector",
            version=version,
            supported_tasks=list(supported_tasks),
            max_concurrency=max_concurrency,
            description=description,
        )

    @pytest.mark.asyncio
    async def test_register_agent(self):
        """Registering an agent returns a non-empty id."""
        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        capability = self._make_capability(
            agent_name,
            supported_tasks=("citation_detect", "citation_detect_single"),
        )
        agent_id = await registry.register(capability, endpoint=f"agent:{agent_name}")
        # try/finally so a failed assertion does not leave the agent registered.
        try:
            assert agent_id is not None
            assert len(agent_id) > 0
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_get_registered_agent(self):
        """A registered agent can be fetched back by name."""
        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        await registry.register(
            self._make_capability(agent_name), endpoint=f"agent:{agent_name}"
        )
        try:
            retrieved = await registry.get_agent(agent_name)
            assert retrieved is not None
            assert retrieved["name"] == agent_name
            assert retrieved["agent_type"] == "citation_detector"
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_list_all_agents(self):
        """``list_agents`` returns a list."""
        registry = AgentRegistry()
        agents = await registry.list_agents()
        assert isinstance(agents, list)

    @pytest.mark.asyncio
    async def test_unregister_agent(self):
        """After unregistering, the agent is either gone or reported OFFLINE."""
        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        await registry.register(
            self._make_capability(agent_name), endpoint=f"agent:{agent_name}"
        )
        await registry.unregister(agent_name)

        retrieved = await registry.get_agent(agent_name)
        assert retrieved is None or retrieved["status"] == AgentStatus.OFFLINE

    @pytest.mark.asyncio
    async def test_update_heartbeat(self):
        """A heartbeat update leaves a non-null ``last_heartbeat``."""
        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        await registry.register(
            self._make_capability(agent_name), endpoint=f"agent:{agent_name}"
        )
        try:
            await registry.update_heartbeat(agent_name)
            agent = await registry.get_agent(agent_name)
            assert agent is not None
            assert agent["last_heartbeat"] is not None
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_get_available_agent(self):
        """Looking up an agent by task type runs and returns None or a str."""
        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        capability = self._make_capability(
            agent_name,
            supported_tasks=("citation_detect", "citation_detect_single"),
        )
        await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            available = await registry.get_available_agent("citation_detect")
            # May be None because the freshly registered agent might not be
            # ONLINE; this only checks that the lookup executes.
            assert available is None or isinstance(available, str)
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_agent_reregistration(self):
        """Re-registering an existing agent name succeeds instead of raising."""
        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        capability1 = self._make_capability(
            agent_name,
            supported_tasks=("citation_detect",),
            version="1.0.0",
            max_concurrency=3,
            description="测试AgentV1",
        )
        capability2 = self._make_capability(
            agent_name,
            supported_tasks=("citation_detect", "new_task"),
            version="2.0.0",
            max_concurrency=5,
            description="测试AgentV2",
        )
        await registry.register(capability1, endpoint=f"agent:{agent_name}")
        try:
            # The first registration is visible.
            agent_data = await registry.get_agent(agent_name)
            assert agent_data is not None
            # Registering the same name again must not raise.
            agent_id2 = await registry.register(capability2, endpoint=f"agent:{agent_name}")
            assert agent_id2 is not None
        finally:
            await registry.unregister(agent_name)