"""Tests for TaskDispatcher - 任务分发器""" import json import uuid from unittest.mock import AsyncMock, MagicMock, patch import pytest from agentkit.core.dispatcher import TaskDispatcher from agentkit.core.exceptions import TaskDispatchError, TaskNotFoundError from agentkit.core.protocol import AgentStatus, TaskResult, TaskStatus class _ColumnMock: """Mock for SQLAlchemy column attributes that supports comparison operators.""" def __init__(self, name): self._name = name def __eq__(self, other): return MagicMock() def __ne__(self, other): return MagicMock() def __lt__(self, other): return MagicMock() def __le__(self, other): return MagicMock() def __gt__(self, other): return MagicMock() def __ge__(self, other): return MagicMock() def like(self, pattern): return MagicMock() def desc(self): return MagicMock() class MockAgentModel: """Mock Agent ORM model with class-level column mocks.""" name = _ColumnMock("name") status = _ColumnMock("status") agent_type = _ColumnMock("agent_type") id = _ColumnMock("id") def __init__(self, **kwargs): self.id = kwargs.get("id", uuid.uuid4()) self.name = kwargs.get("name", "test_agent") self.agent_type = kwargs.get("agent_type", "test") self.status = kwargs.get("status", AgentStatus.ONLINE) self.version = kwargs.get("version", "1.0") self.endpoint = kwargs.get("endpoint", "http://localhost:8000") self.description = kwargs.get("description", "Test agent") class MockTaskModel: """Mock Task ORM model with class-level column mocks.""" id = _ColumnMock("id") agent_id = _ColumnMock("agent_id") task_type = _ColumnMock("task_type") status = _ColumnMock("status") priority = _ColumnMock("priority") input_data = _ColumnMock("input_data") output_data = _ColumnMock("output_data") error_message = _ColumnMock("error_message") started_at = _ColumnMock("started_at") completed_at = _ColumnMock("completed_at") organization_id = _ColumnMock("organization_id") created_by = _ColumnMock("created_by") project_id = _ColumnMock("project_id") scheduled_at = _ColumnMock("scheduled_at") created_at = _ColumnMock("created_at") def __init__(self, **kwargs): self.id = kwargs.get("id", uuid.uuid4()) self.agent_id = kwargs.get("agent_id", uuid.uuid4()) self.task_type = kwargs.get("task_type", "test_task") self.status = kwargs.get("status", TaskStatus.PENDING) self.priority = kwargs.get("priority", 1) self.input_data = kwargs.get("input_data", {}) self.output_data = kwargs.get("output_data", None) self.error_message = kwargs.get("error_message", None) self.started_at = kwargs.get("started_at", None) self.completed_at = kwargs.get("completed_at", None) self.organization_id = kwargs.get("organization_id", uuid.uuid4()) self.created_by = kwargs.get("created_by", None) self.project_id = kwargs.get("project_id", None) self.scheduled_at = kwargs.get("scheduled_at", None) self.created_at = kwargs.get("created_at", None) class MockTaskLogModel: """Mock TaskLog ORM model with class-level column mocks.""" id = _ColumnMock("id") task_id = _ColumnMock("task_id") agent_id = _ColumnMock("agent_id") log_level = _ColumnMock("log_level") message = _ColumnMock("message") def __init__(self, **kwargs): self.id = kwargs.get("id", uuid.uuid4()) self.task_id = kwargs.get("task_id", uuid.uuid4()) self.agent_id = kwargs.get("agent_id", uuid.uuid4()) self.log_level = kwargs.get("log_level", "info") self.message = kwargs.get("message", "") def _make_mock_session(agent=None, task=None, log_entries=None): """Create a mock async session that simulates SQLAlchemy queries.""" session = AsyncMock() async def mock_execute(stmt): result = MagicMock() if agent is not None: result.scalar_one_or_none.return_value = agent elif task is not None: result.scalar_one_or_none.return_value = task result.scalars.return_value.all.return_value = [task] if task else [] else: result.scalar_one_or_none.return_value = None result.scalars.return_value.all.return_value = log_entries or [] if log_entries is not None: result.scalars.return_value.all.return_value = log_entries return result session.execute = mock_execute session.add = MagicMock() session.commit = AsyncMock() session.rollback = AsyncMock() session.refresh = AsyncMock() return session def _make_dispatcher(agent=None, task=None, log_entries=None): """Create a TaskDispatcher with mocked dependencies.""" mock_session = _make_mock_session(agent=agent, task=task, log_entries=log_entries) session_factory = MagicMock() session_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) session_factory.return_value.__aexit__ = AsyncMock(return_value=False) mock_redis = AsyncMock() mock_redis.lpush = AsyncMock() redis_factory = AsyncMock(return_value=mock_redis) dispatcher = TaskDispatcher( redis_factory=redis_factory, session_factory=session_factory, agent_model=MockAgentModel, task_model=MockTaskModel, task_log_model=MockTaskLogModel, ) return dispatcher, mock_session, mock_redis _mock_select = MagicMock() class TestTaskDispatcherDispatch: @patch("sqlalchemy.select", _mock_select) async def test_dispatch_to_online_agent(self, make_task): """分发任务到在线 Agent""" agent = MockAgentModel(name="test_agent", status=AgentStatus.ONLINE) dispatcher, session, redis = _make_dispatcher(agent=agent) task_id = str(uuid.uuid4()) task = make_task(task_id=task_id, agent_name="test_agent") result_task_id = await dispatcher.dispatch(task) assert result_task_id == task_id redis.lpush.assert_called_once() # Verify the queue key format call_args = redis.lpush.call_args assert call_args[0][0] == "agent:test_agent:tasks" @patch("sqlalchemy.select", _mock_select) async def test_dispatch_agent_not_found(self, make_task): """分发到不存在的 Agent 抛出异常""" dispatcher, session, redis = _make_dispatcher(agent=None) task_id = str(uuid.uuid4()) task = make_task(task_id=task_id, agent_name="nonexistent") with pytest.raises(TaskDispatchError): await dispatcher.dispatch(task) @patch("sqlalchemy.select", _mock_select) async def test_dispatch_agent_offline(self, make_task): """分发到离线 Agent 抛出异常""" agent = MockAgentModel(name="offline_agent", status=AgentStatus.OFFLINE) dispatcher, session, redis = _make_dispatcher(agent=agent) task_id = str(uuid.uuid4()) task = make_task(task_id=task_id, agent_name="offline_agent") with pytest.raises(TaskDispatchError): await dispatcher.dispatch(task) class TestTaskDispatcherCancel: @patch("sqlalchemy.select", _mock_select) async def test_cancel_pending_task(self, make_task): """取消待执行的任务""" task_uuid = uuid.uuid4() task = MockTaskModel(id=task_uuid, status=TaskStatus.PENDING) dispatcher, session, redis = _make_dispatcher(task=task) await dispatcher.cancel_task(str(task_uuid)) assert task.status == TaskStatus.CANCELLED @patch("sqlalchemy.select", _mock_select) async def test_cancel_completed_task(self, make_task): """取消已完成的任务不改变状态""" task_uuid = uuid.uuid4() task = MockTaskModel(id=task_uuid, status=TaskStatus.COMPLETED) dispatcher, session, redis = _make_dispatcher(task=task) await dispatcher.cancel_task(str(task_uuid)) # Status should remain COMPLETED (not changed to CANCELLED) assert task.status == TaskStatus.COMPLETED @patch("sqlalchemy.select", _mock_select) async def test_cancel_nonexistent_task(self): """取消不存在的任务抛出异常""" dispatcher, session, redis = _make_dispatcher(task=None) with pytest.raises(TaskNotFoundError): await dispatcher.cancel_task(str(uuid.uuid4())) class TestTaskDispatcherHandleResult: @patch("sqlalchemy.select", _mock_select) async def test_handle_completed_result(self, make_task, make_result): """处理成功结果""" task_uuid = uuid.uuid4() task = MockTaskModel(id=task_uuid, status=TaskStatus.RUNNING) dispatcher, session, redis = _make_dispatcher(task=task) result = make_result(task_id=str(task_uuid), status=TaskStatus.COMPLETED) await dispatcher.handle_result(result) assert task.status == TaskStatus.COMPLETED assert task.output_data == result.output_data @patch("sqlalchemy.select", _mock_select) async def test_handle_failed_result(self, make_task, make_result): """处理失败结果""" task_uuid = uuid.uuid4() task = MockTaskModel(id=task_uuid, status=TaskStatus.RUNNING) dispatcher, session, redis = _make_dispatcher(task=task) result = make_result( task_id=str(task_uuid), status=TaskStatus.FAILED, error_message="Something went wrong", ) await dispatcher.handle_result(result) assert task.status == TaskStatus.FAILED assert task.error_message == "Something went wrong"