229 lines
7.3 KiB
Python
229 lines
7.3 KiB
Python
"""Agent注册表测试"""
|
||
import pytest
|
||
import uuid
|
||
|
||
from app.agent_framework.registry import AgentRegistry
|
||
from app.agent_framework.protocol import AgentCapability, AgentStatus
|
||
|
||
|
||
# Async DBAPI drivers that may appear in DATABASE_URL. Removing them from the
# URL scheme makes SQLAlchemy fall back to the backend's default sync driver,
# which is what a synchronous availability probe needs.
_ASYNC_DRIVERS = frozenset({"aiosqlite", "asyncpg", "aiomysql", "asyncmy"})


def to_sync_url(url):
    """Return *url* with a known async driver removed from its scheme.

    Only the ``dialect+driver`` scheme part is rewritten, e.g.
    ``sqlite+aiosqlite:///app.db`` -> ``sqlite:///app.db``. Sync drivers such
    as ``postgresql+psycopg2`` are left untouched, and everything after
    ``://`` (credentials, host, path) is never modified.

    Args:
        url: A SQLAlchemy database URL string.

    Returns:
        The URL with the async driver stripped, or *url* unchanged when it has
        no scheme or no known async driver.
    """
    scheme, sep, rest = url.partition("://")
    backend, plus, driver = scheme.partition("+")
    if sep and plus and driver in _ASYNC_DRIVERS:
        return f"{backend}{sep}{rest}"
    return url


def is_database_available():
    """Check synchronously whether the configured database accepts connections.

    Builds a throw-away sync engine from ``settings.DATABASE_URL`` and runs
    ``SELECT 1``. Any failure (missing dependency, bad configuration,
    unreachable server) is reported as ``False`` instead of being raised.

    Returns:
        bool: ``True`` if ``SELECT 1`` succeeded, otherwise ``False``.
    """
    try:
        # Imported lazily so a missing driver or config only marks the DB as
        # unavailable instead of breaking test collection.
        from sqlalchemy import create_engine, text
        from app.config import settings

        sync_url = to_sync_url(settings.DATABASE_URL)
        backend = sync_url.partition("://")[0].partition("+")[0]
        if backend == "sqlite":
            # sqlite3.connect() does not accept a ``connect_timeout`` keyword.
            engine = create_engine(sync_url)
        else:
            # Fail fast when the server is down instead of hanging the test run.
            engine = create_engine(sync_url, connect_args={"connect_timeout": 1})
    except Exception:
        return False

    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True
    except Exception:
        return False
    finally:
        # Release pooled connections even when the probe fails.
        engine.dispose()
|
||
|
||
|
||
# Memoised result of the database probe; stays None until check_db() first runs.
_db_available = None


def check_db():
    """Return whether the database is reachable, probing it at most once.

    The first call runs ``is_database_available()`` and caches the result in
    the module-level ``_db_available``; if the probe itself raises, the cache
    is set to ``False``. Later calls return the cached value directly.
    """
    global _db_available
    if _db_available is not None:
        return _db_available

    try:
        probed = is_database_available()
    except Exception:
        probed = False
    _db_available = probed
    return _db_available
|
||
|
||
|
||
class TestAgentRegistry:
    """Integration tests for ``AgentRegistry``.

    Every test needs a reachable database and is skipped when ``check_db()``
    reports it unavailable. Tests register agents under a random, unique name
    and unregister them in a ``finally`` block, so a failing assertion does not
    leave test agents behind in the database.
    """

    @pytest.fixture(autouse=True)
    def _require_db(self):
        """Skip the current test when the database cannot be reached."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

    @staticmethod
    def _unique_agent_name():
        """Return a fresh agent name so repeated runs never collide."""
        return f"test_agent_{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _make_capability(
        agent_name,
        *,
        version="1.0.0",
        supported_tasks=("citation_detect",),
        max_concurrency=3,
        description="测试Agent",
    ):
        """Build a ``citation_detector`` capability with test defaults.

        ``supported_tasks`` defaults to a tuple (never a shared mutable list)
        and is copied into a new list for every capability.
        """
        return AgentCapability(
            agent_name=agent_name,
            agent_type="citation_detector",
            version=version,
            supported_tasks=list(supported_tasks),
            max_concurrency=max_concurrency,
            description=description,
        )

    @pytest.mark.asyncio
    async def test_register_agent(self):
        """Registering an agent returns a non-empty agent id."""
        registry = AgentRegistry()
        agent_name = self._unique_agent_name()
        capability = self._make_capability(
            agent_name,
            supported_tasks=("citation_detect", "citation_detect_single"),
        )

        agent_id = await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            assert agent_id is not None
            assert len(agent_id) > 0
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_get_registered_agent(self):
        """A registered agent can be fetched back by name."""
        registry = AgentRegistry()
        agent_name = self._unique_agent_name()
        capability = self._make_capability(agent_name)

        await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            retrieved = await registry.get_agent(agent_name)

            assert retrieved is not None
            assert retrieved["name"] == agent_name
            assert retrieved["agent_type"] == "citation_detector"
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_list_all_agents(self):
        """Listing agents returns a list."""
        registry = AgentRegistry()
        agents = await registry.list_agents()

        assert isinstance(agents, list)

    @pytest.mark.asyncio
    async def test_unregister_agent(self):
        """After unregistering, the agent is gone or marked OFFLINE."""
        registry = AgentRegistry()
        agent_name = self._unique_agent_name()
        capability = self._make_capability(agent_name)

        await registry.register(capability, endpoint=f"agent:{agent_name}")
        await registry.unregister(agent_name)

        retrieved = await registry.get_agent(agent_name)
        # After unregistering, the status should be OFFLINE or the record removed.
        assert retrieved is None or retrieved["status"] == AgentStatus.OFFLINE

    @pytest.mark.asyncio
    async def test_update_heartbeat(self):
        """Updating the heartbeat sets ``last_heartbeat``."""
        registry = AgentRegistry()
        agent_name = self._unique_agent_name()
        capability = self._make_capability(agent_name)

        await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            await registry.update_heartbeat(agent_name)

            agent = await registry.get_agent(agent_name)
            assert agent is not None
            assert agent["last_heartbeat"] is not None
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_get_available_agent(self):
        """Looking up an available agent by task type runs without error."""
        registry = AgentRegistry()
        agent_name = self._unique_agent_name()
        capability = self._make_capability(
            agent_name,
            supported_tasks=("citation_detect", "citation_detect_single"),
        )

        await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            available = await registry.get_available_agent("citation_detect")
            # May be None when the freshly registered agent is not ONLINE yet,
            # so this only checks that the lookup completes with a sane type.
            assert available is None or isinstance(available, str)
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_agent_reregistration(self):
        """Re-registering an existing agent name updates it instead of failing."""
        registry = AgentRegistry()
        agent_name = self._unique_agent_name()
        capability1 = self._make_capability(
            agent_name,
            description="测试AgentV1",
        )
        capability2 = self._make_capability(
            agent_name,
            version="2.0.0",
            supported_tasks=("citation_detect", "new_task"),
            max_concurrency=5,
            description="测试AgentV2",
        )

        await registry.register(capability1, endpoint=f"agent:{agent_name}")
        try:
            # The first registration must be visible.
            agent_data = await registry.get_agent(agent_name)
            assert agent_data is not None

            # Registering the same name again must succeed without raising...
            agent_id2 = await registry.register(capability2, endpoint=f"agent:{agent_name}")
            assert agent_id2 is not None

            # ...and the agent must still be retrievable under the same name.
            updated = await registry.get_agent(agent_name)
            assert updated is not None
            assert updated["name"] == agent_name
        finally:
            await registry.unregister(agent_name)
|