geo/backend/tests/test_agent_framework/test_agent_dispatcher.py

214 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""任务分发器测试"""
import pytest
import uuid
from datetime import datetime, timezone
from app.agent_framework.dispatcher import TaskDispatcher
from app.agent_framework.registry import AgentRegistry
from app.agent_framework.protocol import (
AgentCapability,
TaskMessage,
)
from app.config import settings
def is_database_available():
    """Return True if the configured database answers ``SELECT 1``, else False.

    The probe is synchronous. It derives a sync URL from
    ``settings.DATABASE_URL`` by stripping the async driver suffixes
    ``+aiosqlite`` / ``+asyncpg`` so SQLAlchemy falls back to its default sync
    driver. Any ``Exception`` is reported as "unavailable" rather than raised:
    a missing package, a bad URL, a refused connection, or a timeout.
    """
    try:
        from sqlalchemy import create_engine, text

        # NOTE(review): stripping "+asyncpg" makes SQLAlchemy use its default
        # sync postgres driver; if that driver is not installed this returns
        # False even when the database is up - confirm test environments have it.
        sync_url = (
            settings.DATABASE_URL
            .replace('+aiosqlite', '')
            .replace('+asyncpg', '')
        )
        # connect_timeout is a network-driver option; sqlite's driver does not
        # accept it, so it is only passed for non-sqlite URLs.
        connect_args = {} if 'sqlite' in sync_url else {"connect_timeout": 1}
        engine = create_engine(sync_url, connect_args=connect_args)
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        finally:
            # Dispose on failure too, so an unreachable database never leaves
            # the probe's engine/pool behind.
            engine.dispose()
        return True
    except Exception:
        return False
# Cached outcome of the database probe; None means it has not run yet.
_db_available = None


def check_db():
    """Return whether the database is reachable, probing it at most once.

    The first call runs ``is_database_available()`` and caches the result in
    the module-level ``_db_available``; later calls return the cached value.
    An exception escaping the probe is recorded as "unavailable" (False).
    """
    global _db_available
    if _db_available is not None:
        return _db_available
    try:
        outcome = is_database_available()
    except Exception:
        outcome = False
    _db_available = outcome
    return _db_available
def is_redis_available():
    """Return True if the Redis server at ``settings.REDIS_URL`` answers PING.

    Any ``Exception`` yields False instead of propagating: a missing ``redis``
    package, a malformed URL, a refused connection, or a timeout. Short socket
    timeouts (1 s, matching the database probe's connect timeout) keep an
    unreachable host from stalling the test run.
    """
    try:
        # Imported inside the try so a missing redis package reports
        # "unavailable" instead of raising ImportError to the caller.
        import redis

        client = redis.Redis.from_url(
            settings.REDIS_URL,
            socket_connect_timeout=1,
            socket_timeout=1,
        )
        try:
            client.ping()
        finally:
            # Release the probe's connection; previously every probe leaked one.
            client.close()
        return True
    except Exception:
        return False
# Memoized Redis probe result; stays None until check_redis() first runs.
_redis_available = None


def check_redis():
    """Report whether Redis is reachable, running the real probe only once.

    The outcome of ``is_redis_available()`` is stored in ``_redis_available``
    so repeated calls across tests reuse it. An exception escaping the probe
    counts as "not available" (False).
    """
    global _redis_available
    if _redis_available is None:
        try:
            probe_ok = is_redis_available()
        except Exception:
            # Treat any probe failure as "Redis unavailable".
            probe_ok = False
        _redis_available = probe_ok
    return _redis_available
class TestTaskDispatcher:
    """Tests for TaskDispatcher.

    Tests that need the database and/or Redis probe for them at run time via
    check_db() / check_redis() and skip when the service is unreachable.
    """

    @pytest.fixture
    def dispatcher(self):
        """Provide a TaskDispatcher bound to the configured Redis URL."""
        # NOTE(review): the dispatcher is never closed after each test. Since
        # it opens Redis lazily (see _get_redis()/close() in
        # test_close_dispatcher), consider an async yield-fixture that awaits
        # dispatcher.close() - confirm the project's pytest-asyncio mode first.
        return TaskDispatcher(settings.REDIS_URL)

    @pytest.mark.asyncio
    async def test_dispatcher_initialization(self, dispatcher):
        """The dispatcher is constructed and remembers its Redis URL."""
        assert dispatcher is not None
        assert dispatcher._redis_url == settings.REDIS_URL

    @pytest.mark.asyncio
    async def test_get_task_status_not_found(self, dispatcher):
        """Querying the status of an unknown task id raises TaskNotFoundError."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")
        non_existent_id = str(uuid.uuid4())
        from app.agent_framework.exceptions import TaskNotFoundError
        with pytest.raises(TaskNotFoundError):
            await dispatcher.get_task_status(non_existent_id)

    @pytest.mark.asyncio
    async def test_dispatch_without_agent(self, dispatcher):
        """Dispatching to an agent that does not exist raises TaskDispatchError."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")
        task = TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name="non_existent_agent",
            task_type="test_task",
            priority=5,
            input_data={},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )
        from app.agent_framework.exceptions import TaskDispatchError
        with pytest.raises(TaskDispatchError):
            await dispatcher.dispatch(task)

    @pytest.mark.asyncio
    async def test_cancel_task_not_found(self, dispatcher):
        """Cancelling an unknown task id raises TaskNotFoundError."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")
        non_existent_id = str(uuid.uuid4())
        from app.agent_framework.exceptions import TaskNotFoundError
        with pytest.raises(TaskNotFoundError):
            await dispatcher.cancel_task(non_existent_id)

    @pytest.mark.asyncio
    async def test_handle_progress(self, dispatcher):
        """Handling a progress report for an unknown task does not raise."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")
        from app.agent_framework.protocol import TaskProgress
        # A progress report for a task id that was never dispatched.
        progress = TaskProgress(
            task_id=str(uuid.uuid4()),
            agent_name="non_existent",
            progress=0.5,
            message="测试进度",
            updated_at=datetime.now(timezone.utc),
        )
        # Must not raise.
        await dispatcher.handle_progress(progress)

    @pytest.mark.asyncio
    async def test_close_dispatcher(self, dispatcher):
        """close() drops the Redis connection opened by _get_redis()."""
        if not check_redis():
            pytest.skip("Redis不可用跳过此测试")
        # Open the Redis connection first.
        await dispatcher._get_redis()
        assert dispatcher._redis is not None
        # Then close it.
        await dispatcher.close()
        assert dispatcher._redis is None

    @pytest.mark.asyncio
    async def test_dispatch_and_query_flow(self, dispatcher):
        """Register a throwaway agent, dispatch a task to it, then clean up."""
        if not check_db() or not check_redis():
            pytest.skip("数据库或Redis不可用跳过此测试")
        # 1. Register a uniquely named test agent.
        registry = AgentRegistry()
        agent_name = f"test_dispatch_agent_{uuid.uuid4().hex[:8]}"
        capability = AgentCapability(
            agent_name=agent_name,
            agent_type="test_type",
            version="1.0.0",
            supported_tasks=["test_task"],
            max_concurrency=3,
            description="测试Agent",
        )
        await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            # 2. Try to dispatch a task to it.
            task = TaskMessage(
                task_id=str(uuid.uuid4()),
                agent_name=agent_name,
                task_type="test_task",
                priority=5,
                input_data={"test": "data"},
                callback_url=None,
                created_at=datetime.now(timezone.utc),
            )
            # The agent is registered but may not be online, so a dispatch
            # failure is tolerated; this only checks the call path works.
            try:
                task_id = await dispatcher.dispatch(task)
            except Exception:
                # Agent offline: expected, nothing further to verify.
                pass
            else:
                # Kept outside the try so a failing assert is not swallowed
                # by the broad except above.
                assert task_id is not None
        finally:
            # Always unregister so a failing test cannot leak the agent.
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_retry_failed_tasks_empty(self, dispatcher):
        """Retrying failed tasks when none exist completes without raising."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")
        result = await dispatcher.retry_failed_tasks(max_retries=3)
        # With no failed tasks this must not raise; accept an empty result.
        assert result is None or result == [] or isinstance(result, int)