215 lines
6.5 KiB
Python
215 lines
6.5 KiB
Python
"""任务分发器测试"""
|
||
import pytest
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
|
||
from app.agent_framework.dispatcher import TaskDispatcher
|
||
from app.agent_framework.registry import AgentRegistry
|
||
from app.agent_framework.protocol import (
|
||
AgentCapability,
|
||
TaskMessage,
|
||
)
|
||
from app.config import settings
|
||
|
||
|
||
def is_database_available():
    """Probe the configured database synchronously and report whether it answers.

    The async driver suffix (``+aiosqlite`` / ``+asyncpg``) is stripped from
    ``settings.DATABASE_URL`` so a plain synchronous SQLAlchemy engine can run
    ``SELECT 1``. Every failure (missing package, unreachable server, bad URL)
    is reported as ``False`` instead of raised, so callers can use the result
    to skip tests.

    Returns:
        bool: ``True`` if ``SELECT 1`` succeeded, ``False`` otherwise.
    """
    try:
        from sqlalchemy import create_engine, text

        from app.config import settings

        # Build a synchronous engine from the async URL for a one-off probe.
        # NOTE(review): stripping "+asyncpg" makes SQLAlchemy fall back to its
        # default PostgreSQL driver; if that driver is not installed the probe
        # reports False even when the server is up — confirm this is intended.
        sync_url = settings.DATABASE_URL.replace('+aiosqlite', '').replace('+asyncpg', '')
        if 'sqlite' in sync_url:
            engine = create_engine(sync_url)
        else:
            # Fail fast (1s) instead of waiting for the driver's default timeout.
            engine = create_engine(sync_url, connect_args={"connect_timeout": 1})

        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        finally:
            # Release the engine's pool even when connecting or querying fails.
            engine.dispose()
        return True
    except Exception:
        return False
|
||
|
||
|
||
# Check whether the backing services are available.
# Lazily cached result of the database probe (None = not probed yet).
_db_available = None


def check_db():
    """Return the cached database availability, probing only on first use.

    Any exception escaping the probe is recorded as "unavailable" (False).
    """
    global _db_available
    if _db_available is not None:
        return _db_available
    try:
        probed = is_database_available()
    except Exception:
        probed = False
    _db_available = probed
    return _db_available
|
||
|
||
|
||
def is_redis_available():
    """Ping the configured Redis server and report whether it answered.

    Mirrors the database probe: every failure (redis package missing, bad
    URL, unreachable server) yields ``False`` instead of raising. Short
    socket timeouts keep an unreachable host from stalling the test run.

    Returns:
        bool: ``True`` if ``PING`` succeeded, ``False`` otherwise.
    """
    try:
        import redis

        client = redis.Redis.from_url(
            settings.REDIS_URL,
            # redis-py defaults to no socket timeout, which would block on an
            # unreachable host; match the DB probe's 1s budget instead.
            socket_connect_timeout=1,
            socket_timeout=1,
        )
    except Exception:
        return False
    try:
        client.ping()
        return True
    except Exception:
        return False
    finally:
        # Drop the probe's connection(s) instead of leaving them open.
        client.connection_pool.disconnect()
|
||
|
||
|
||
# Lazily cached result of the Redis probe (None = not probed yet).
_redis_available = None


def check_redis():
    """Return the cached Redis availability, probing only on first use.

    Any exception escaping the probe is recorded as "unavailable" (False).
    """
    global _redis_available
    if _redis_available is not None:
        return _redis_available
    try:
        probed = is_redis_available()
    except Exception:
        probed = False
    _redis_available = probed
    return _redis_available
|
||
|
||
|
||
class TestTaskDispatcher:
    """Tests for TaskDispatcher."""

    @pytest.fixture
    def dispatcher(self):
        """Create a TaskDispatcher pointed at the configured Redis URL."""
        # NOTE(review): this fixture never awaits dispatcher.close(), so a
        # Redis connection opened by a test (e.g. via _get_redis) is not
        # released here; consider an async fixture with teardown.
        return TaskDispatcher(settings.REDIS_URL)

    @pytest.mark.asyncio
    async def test_dispatcher_initialization(self, dispatcher):
        """A fresh dispatcher stores the URL and holds no Redis client yet."""
        assert dispatcher is not None
        assert dispatcher._redis_url == settings.REDIS_URL
        assert dispatcher._redis is None

    @pytest.mark.asyncio
    async def test_get_task_status_not_found(self, dispatcher):
        """Querying the status of an unknown task raises TaskNotFoundError."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        non_existent_id = str(uuid.uuid4())

        from app.agent_framework.exceptions import TaskNotFoundError
        with pytest.raises(TaskNotFoundError):
            await dispatcher.get_task_status(non_existent_id)

    @pytest.mark.asyncio
    async def test_dispatch_without_agent(self, dispatcher):
        """Dispatching to an agent that does not exist raises TaskDispatchError."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        task = TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name="non_existent_agent",
            task_type="test_task",
            priority=5,
            input_data={},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

        from app.agent_framework.exceptions import TaskDispatchError
        with pytest.raises(TaskDispatchError):
            await dispatcher.dispatch(task)

    @pytest.mark.asyncio
    async def test_cancel_task_not_found(self, dispatcher):
        """Cancelling an unknown task raises TaskNotFoundError."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        non_existent_id = str(uuid.uuid4())

        from app.agent_framework.exceptions import TaskNotFoundError
        with pytest.raises(TaskNotFoundError):
            await dispatcher.cancel_task(non_existent_id)

    @pytest.mark.asyncio
    async def test_handle_progress(self, dispatcher):
        """Reporting progress for an unknown task must not raise."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        from app.agent_framework.protocol import TaskProgress

        # A fake progress report for a random (non-existent) task id.
        progress = TaskProgress(
            task_id=str(uuid.uuid4()),
            agent_name="non_existent",
            progress=0.5,
            message="测试进度",
            updated_at=datetime.now(timezone.utc),
        )

        # Must not raise.
        await dispatcher.handle_progress(progress)

    @pytest.mark.asyncio
    async def test_close_dispatcher(self, dispatcher):
        """close() drops the Redis client that _get_redis() created."""
        if not check_redis():
            pytest.skip("Redis不可用,跳过此测试")

        # Open the Redis connection first.
        await dispatcher._get_redis()
        assert dispatcher._redis is not None

        # Close it.
        await dispatcher.close()
        assert dispatcher._redis is None

    @pytest.mark.asyncio
    async def test_dispatch_and_query_flow(self, dispatcher):
        """Register a throwaway agent, dispatch a task to it, then unregister."""
        if not check_db() or not check_redis():
            pytest.skip("数据库或Redis不可用,跳过此测试")

        # 1. Register a test agent under a unique name.
        registry = AgentRegistry()
        agent_name = f"test_dispatch_agent_{uuid.uuid4().hex[:8]}"
        capability = AgentCapability(
            agent_name=agent_name,
            agent_type="test_type",
            version="1.0.0",
            supported_tasks=["test_task"],
            max_concurrency=3,
            description="测试Agent",
        )
        await registry.register(capability, endpoint=f"agent:{agent_name}")

        try:
            # 2. Try to dispatch a task to it.
            task = TaskMessage(
                task_id=str(uuid.uuid4()),
                agent_name=agent_name,
                task_type="test_task",
                priority=5,
                input_data={"test": "data"},
                callback_url=None,
                created_at=datetime.now(timezone.utc),
            )

            # The agent is registered but may not be online, so a dispatch
            # failure is tolerated; only a successful dispatch is checked.
            try:
                task_id = await dispatcher.dispatch(task)
            except Exception:
                # The agent may be offline; this is expected.
                pass
            else:
                # Kept outside the try so a failing assertion is not swallowed
                # by the broad except above.
                assert task_id is not None
        finally:
            # Clean up even if dispatch or the assertion fails.
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_retry_failed_tasks_empty(self, dispatcher):
        """retry_failed_tasks() runs without raising when nothing needs retrying."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        # NOTE(review): nothing here guarantees the DB has no failed tasks;
        # the test only checks that the call succeeds.
        result = await dispatcher.retry_failed_tasks(max_retries=3)
        # Accept None, an empty list, or an integer count as the result.
        assert result is None or result == [] or isinstance(result, int)
|