# fischer-agentkit/tests/unit/test_dispatcher.py
#
# 270 lines
# 9.6 KiB
# Python

"""Tests for TaskDispatcher - 任务分发器"""
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.core.dispatcher import TaskDispatcher
from agentkit.core.exceptions import TaskDispatchError, TaskNotFoundError
from agentkit.core.protocol import AgentStatus, TaskResult, TaskStatus
class _ColumnMock:
"""Mock for SQLAlchemy column attributes that supports comparison operators."""
def __init__(self, name):
self._name = name
def __eq__(self, other):
return MagicMock()
def __ne__(self, other):
return MagicMock()
def __lt__(self, other):
return MagicMock()
def __le__(self, other):
return MagicMock()
def __gt__(self, other):
return MagicMock()
def __ge__(self, other):
return MagicMock()
def like(self, pattern):
return MagicMock()
def desc(self):
return MagicMock()
class MockAgentModel:
    """Fake ``Agent`` ORM model for the dispatcher tests.

    Class attributes are ``_ColumnMock`` placeholders, so query-building code
    can write expressions like ``MockAgentModel.status == ...``. Instances
    shadow them with concrete values. Each keyword argument overrides the
    matching default below, and unrecognised keywords are ignored.
    """

    id = _ColumnMock("id")
    name = _ColumnMock("name")
    agent_type = _ColumnMock("agent_type")
    status = _ColumnMock("status")

    def __init__(self, **kwargs):
        # Built per call so every instance gets its own freshly generated id.
        defaults = {
            "id": uuid.uuid4(),
            "name": "test_agent",
            "agent_type": "test",
            "status": AgentStatus.ONLINE,
            "version": "1.0",
            "endpoint": "http://localhost:8000",
            "description": "Test agent",
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, kwargs.get(attr, fallback))
class MockTaskModel:
    """Fake ``Task`` ORM model for the dispatcher tests.

    Every persisted column is exposed at class level as a ``_ColumnMock`` so
    that filters and orderings (e.g. ``MockTaskModel.created_at.desc()``) can
    be constructed. Instances replace those placeholders with real values:
    keyword arguments win, otherwise the defaults below apply. Unknown keywords
    are ignored.
    """

    id = _ColumnMock("id")
    agent_id = _ColumnMock("agent_id")
    task_type = _ColumnMock("task_type")
    status = _ColumnMock("status")
    priority = _ColumnMock("priority")
    input_data = _ColumnMock("input_data")
    output_data = _ColumnMock("output_data")
    error_message = _ColumnMock("error_message")
    started_at = _ColumnMock("started_at")
    completed_at = _ColumnMock("completed_at")
    organization_id = _ColumnMock("organization_id")
    created_by = _ColumnMock("created_by")
    project_id = _ColumnMock("project_id")
    scheduled_at = _ColumnMock("scheduled_at")
    created_at = _ColumnMock("created_at")

    def __init__(self, **kwargs):
        # Rebuilt on every call: fresh UUIDs and a fresh ``input_data`` dict,
        # so no mutable default is shared between instances.
        defaults = {
            "id": uuid.uuid4(),
            "agent_id": uuid.uuid4(),
            "task_type": "test_task",
            "status": TaskStatus.PENDING,
            "priority": 1,
            "input_data": {},
            "output_data": None,
            "error_message": None,
            "started_at": None,
            "completed_at": None,
            "organization_id": uuid.uuid4(),
            "created_by": None,
            "project_id": None,
            "scheduled_at": None,
            "created_at": None,
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, kwargs.get(attr, fallback))
class MockTaskLogModel:
    """Fake ``TaskLog`` ORM model for the dispatcher tests.

    Columns are ``_ColumnMock`` placeholders at class level. Instances hold
    concrete values taken from keyword arguments, falling back to the defaults
    below.
    """

    id = _ColumnMock("id")
    task_id = _ColumnMock("task_id")
    agent_id = _ColumnMock("agent_id")
    log_level = _ColumnMock("log_level")
    message = _ColumnMock("message")

    def __init__(self, **kwargs):
        defaults = {
            "id": uuid.uuid4(),
            "task_id": uuid.uuid4(),
            "agent_id": uuid.uuid4(),
            "log_level": "info",
            "message": "",
        }
        for attr, fallback in defaults.items():
            setattr(self, attr, kwargs.get(attr, fallback))
def _make_mock_session(agent=None, task=None, log_entries=None):
"""Create a mock async session that simulates SQLAlchemy queries."""
session = AsyncMock()
async def mock_execute(stmt):
result = MagicMock()
if agent is not None:
result.scalar_one_or_none.return_value = agent
elif task is not None:
result.scalar_one_or_none.return_value = task
result.scalars.return_value.all.return_value = [task] if task else []
else:
result.scalar_one_or_none.return_value = None
result.scalars.return_value.all.return_value = log_entries or []
if log_entries is not None:
result.scalars.return_value.all.return_value = log_entries
return result
session.execute = mock_execute
session.add = MagicMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.refresh = AsyncMock()
return session
def _make_dispatcher(agent=None, task=None, log_entries=None):
    """Wire a ``TaskDispatcher`` to in-memory fakes for one test.

    ``agent``, ``task`` and ``log_entries`` are forwarded to
    ``_make_mock_session`` and decide what every query returns.

    Returns:
        ``(dispatcher, session, redis)``:

        * ``session`` is the mock yielded by ``session_factory()`` used as an
          async context manager.
        * ``redis`` is the ``AsyncMock`` client that the (awaitable)
          ``redis_factory`` resolves to, so tests can inspect ``redis.lpush``.
    """
    session = _make_mock_session(agent=agent, task=task, log_entries=log_entries)

    # session_factory() returns an object usable as ``async with ... as s``.
    session_factory = MagicMock()
    session_ctx = session_factory.return_value
    session_ctx.__aenter__ = AsyncMock(return_value=session)
    # Returning False from __aexit__ means exceptions are never suppressed.
    session_ctx.__aexit__ = AsyncMock(return_value=False)

    redis = AsyncMock()
    redis.lpush = AsyncMock()
    redis_factory = AsyncMock(return_value=redis)

    dispatcher = TaskDispatcher(
        agent_model=MockAgentModel,
        task_model=MockTaskModel,
        task_log_model=MockTaskLogModel,
        session_factory=session_factory,
        redis_factory=redis_factory,
    )
    return dispatcher, session, redis
# Replacement for ``sqlalchemy.select`` that the test classes below patch in,
# so statements built from the mock models never reach real SQLAlchemy.
# A single module-level instance is shared by every test, so its call history
# accumulates across tests; no test asserts on it.
# NOTE(review): patching "sqlalchemy.select" only takes effect if the dispatcher
# looks ``select`` up on the sqlalchemy module at call time (e.g. a
# function-local import). If it does ``from sqlalchemy import select`` at
# module import time, the target should be "agentkit.core.dispatcher.select"
# — confirm.
_mock_select: MagicMock = MagicMock()
@patch("sqlalchemy.select", _mock_select)
class TestTaskDispatcherDispatch:
    """``TaskDispatcher.dispatch``: routing a task onto an agent's Redis queue.

    The class-level ``patch`` wraps every ``test*`` method, so each test runs
    with ``sqlalchemy.select`` replaced by the shared ``_mock_select``.
    """

    async def test_dispatch_to_online_agent(self, make_task):
        """Dispatching to an online agent pushes the task onto that agent's queue."""
        agent = MockAgentModel(name="test_agent", status=AgentStatus.ONLINE)
        dispatcher, session, redis = _make_dispatcher(agent=agent)
        task_id = str(uuid.uuid4())
        task = make_task(task_id=task_id, agent_name="test_agent")

        result_task_id = await dispatcher.dispatch(task)

        assert result_task_id == task_id
        # Awaited, not merely called: an un-awaited lpush would never run.
        redis.lpush.assert_awaited_once()
        # The first positional argument is the per-agent queue key.
        assert redis.lpush.call_args.args[0] == "agent:test_agent:tasks"

    async def test_dispatch_agent_not_found(self, make_task):
        """Dispatching to an unknown agent raises and queues nothing."""
        dispatcher, session, redis = _make_dispatcher(agent=None)
        task_id = str(uuid.uuid4())
        task = make_task(task_id=task_id, agent_name="nonexistent")

        with pytest.raises(TaskDispatchError):
            await dispatcher.dispatch(task)

        redis.lpush.assert_not_called()

    async def test_dispatch_agent_offline(self, make_task):
        """Dispatching to an offline agent raises and queues nothing."""
        agent = MockAgentModel(name="offline_agent", status=AgentStatus.OFFLINE)
        dispatcher, session, redis = _make_dispatcher(agent=agent)
        task_id = str(uuid.uuid4())
        task = make_task(task_id=task_id, agent_name="offline_agent")

        with pytest.raises(TaskDispatchError):
            await dispatcher.dispatch(task)

        redis.lpush.assert_not_called()
@patch("sqlalchemy.select", _mock_select)
class TestTaskDispatcherCancel:
    """``TaskDispatcher.cancel_task``: status transitions on cancellation.

    The class-level ``patch`` wraps every ``test*`` method with
    ``sqlalchemy.select`` replaced by the shared ``_mock_select``.
    """

    async def test_cancel_pending_task(self):
        """Cancelling a pending task marks it CANCELLED."""
        task_uuid = uuid.uuid4()
        task = MockTaskModel(id=task_uuid, status=TaskStatus.PENDING)
        dispatcher, session, redis = _make_dispatcher(task=task)

        await dispatcher.cancel_task(str(task_uuid))

        assert task.status == TaskStatus.CANCELLED

    async def test_cancel_completed_task(self):
        """Cancelling an already-completed task leaves its status unchanged."""
        task_uuid = uuid.uuid4()
        task = MockTaskModel(id=task_uuid, status=TaskStatus.COMPLETED)
        dispatcher, session, redis = _make_dispatcher(task=task)

        await dispatcher.cancel_task(str(task_uuid))

        # Must stay COMPLETED rather than being flipped to CANCELLED.
        assert task.status == TaskStatus.COMPLETED

    async def test_cancel_nonexistent_task(self):
        """Cancelling an unknown task id raises TaskNotFoundError."""
        dispatcher, session, redis = _make_dispatcher(task=None)

        with pytest.raises(TaskNotFoundError):
            await dispatcher.cancel_task(str(uuid.uuid4()))
@patch("sqlalchemy.select", _mock_select)
class TestTaskDispatcherHandleResult:
    """``TaskDispatcher.handle_result``: applying agent results to task rows.

    The class-level ``patch`` wraps every ``test*`` method with
    ``sqlalchemy.select`` replaced by the shared ``_mock_select``.
    """

    async def test_handle_completed_result(self, make_result):
        """A COMPLETED result updates the task's status and output data."""
        task_uuid = uuid.uuid4()
        task = MockTaskModel(id=task_uuid, status=TaskStatus.RUNNING)
        dispatcher, session, redis = _make_dispatcher(task=task)
        result = make_result(task_id=str(task_uuid), status=TaskStatus.COMPLETED)

        await dispatcher.handle_result(result)

        assert task.status == TaskStatus.COMPLETED
        assert task.output_data == result.output_data

    async def test_handle_failed_result(self, make_result):
        """A FAILED result updates the task's status and error message."""
        task_uuid = uuid.uuid4()
        task = MockTaskModel(id=task_uuid, status=TaskStatus.RUNNING)
        dispatcher, session, redis = _make_dispatcher(task=task)
        result = make_result(
            task_id=str(task_uuid),
            status=TaskStatus.FAILED,
            error_message="Something went wrong",
        )

        await dispatcher.handle_result(result)

        assert task.status == TaskStatus.FAILED
        assert task.error_message == "Something went wrong"