322 lines
12 KiB
Python
322 lines
12 KiB
Python
"""Tests for TaskDispatcher - 任务分发器"""
|
|
|
|
import json
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.dispatcher import TaskDispatcher, _validate_callback_url
|
|
from agentkit.core.exceptions import TaskDispatchError, TaskNotFoundError
|
|
from agentkit.core.protocol import AgentStatus, TaskResult, TaskStatus
|
|
|
|
|
|
class _ColumnMock:
    """Mock for SQLAlchemy column attributes that supports comparison operators.

    Every rich comparison and the common column-operator methods return a
    fresh ``MagicMock``. This lets expressions such as ``Model.status == X``
    or ``Model.created_at.desc()`` be handed to a mocked ``select(...)``
    chain without touching real SQLAlchemy.

    Args:
        name: Column name; kept only for debugging (shown in ``repr``).
    """

    def __init__(self, name):
        self._name = name

    def _expression(self, *args, **kwargs):
        """Return a placeholder for any SQL expression built from this column."""
        return MagicMock()

    # In SQLAlchemy, comparisons build clause objects rather than booleans,
    # so each one simply yields a placeholder expression here.
    __eq__ = _expression
    __ne__ = _expression
    __lt__ = _expression
    __le__ = _expression
    __gt__ = _expression
    __ge__ = _expression

    # Column-operator methods commonly used when building queries.
    like = _expression
    ilike = _expression
    in_ = _expression
    is_ = _expression
    is_not = _expression
    asc = _expression
    desc = _expression

    # Defining ``__eq__`` in a class body implicitly sets ``__hash__`` to
    # ``None``. Restore identity hashing so column mocks stay usable in sets
    # and as dict keys.
    __hash__ = object.__hash__

    def __repr__(self):
        return f"{type(self).__name__}({self._name!r})"
|
|
|
|
|
|
class MockAgentModel:
    """Stand-in for the Agent ORM model.

    Class attributes are ``_ColumnMock`` objects so query expressions such as
    ``MockAgentModel.name == "x"`` can be built. Each instance shadows them
    with concrete values taken from keyword arguments, falling back to the
    defaults in ``__init__``. Unrecognised keyword arguments are ignored.
    """

    id = _ColumnMock("id")
    name = _ColumnMock("name")
    agent_type = _ColumnMock("agent_type")
    status = _ColumnMock("status")

    def __init__(self, **kwargs):
        # Rebuilt on every call so each instance gets its own fresh UUID.
        fallbacks = {
            "id": uuid.uuid4(),
            "name": "test_agent",
            "agent_type": "test",
            "status": AgentStatus.ONLINE,
            "version": "1.0",
            "endpoint": "http://localhost:8000",
            "description": "Test agent",
        }
        for field_name, fallback in fallbacks.items():
            setattr(self, field_name, kwargs.get(field_name, fallback))
|
|
|
|
|
|
class MockTaskModel:
    """Stand-in for the Task ORM model.

    Class attributes are ``_ColumnMock`` objects so query expressions on task
    columns can be built. Each instance shadows them with concrete values
    taken from keyword arguments, falling back to the defaults in
    ``__init__``. Unrecognised keyword arguments are ignored.
    """

    id = _ColumnMock("id")
    agent_id = _ColumnMock("agent_id")
    task_type = _ColumnMock("task_type")
    status = _ColumnMock("status")
    priority = _ColumnMock("priority")
    input_data = _ColumnMock("input_data")
    output_data = _ColumnMock("output_data")
    error_message = _ColumnMock("error_message")
    started_at = _ColumnMock("started_at")
    completed_at = _ColumnMock("completed_at")
    organization_id = _ColumnMock("organization_id")
    created_by = _ColumnMock("created_by")
    project_id = _ColumnMock("project_id")
    scheduled_at = _ColumnMock("scheduled_at")
    created_at = _ColumnMock("created_at")

    def __init__(self, **kwargs):
        # Rebuilt on every call: each instance gets new UUIDs and its own
        # ``input_data`` dict instead of sharing one mutable default.
        fallbacks = {
            "id": uuid.uuid4(),
            "agent_id": uuid.uuid4(),
            "task_type": "test_task",
            "status": TaskStatus.PENDING,
            "priority": 1,
            "input_data": {},
            "output_data": None,
            "error_message": None,
            "started_at": None,
            "completed_at": None,
            "organization_id": uuid.uuid4(),
            "created_by": None,
            "project_id": None,
            "scheduled_at": None,
            "created_at": None,
        }
        for field_name, fallback in fallbacks.items():
            setattr(self, field_name, kwargs.get(field_name, fallback))
|
|
|
|
|
|
class MockTaskLogModel:
    """Stand-in for the TaskLog ORM model.

    Class attributes are ``_ColumnMock`` objects so query expressions on log
    columns can be built. Each instance shadows them with concrete values
    taken from keyword arguments, falling back to the defaults in
    ``__init__``. Unrecognised keyword arguments are ignored.
    """

    id = _ColumnMock("id")
    task_id = _ColumnMock("task_id")
    agent_id = _ColumnMock("agent_id")
    log_level = _ColumnMock("log_level")
    message = _ColumnMock("message")

    def __init__(self, **kwargs):
        # Rebuilt on every call so each log entry gets fresh UUIDs.
        fallbacks = {
            "id": uuid.uuid4(),
            "task_id": uuid.uuid4(),
            "agent_id": uuid.uuid4(),
            "log_level": "info",
            "message": "",
        }
        for field_name, fallback in fallbacks.items():
            setattr(self, field_name, kwargs.get(field_name, fallback))
|
|
|
|
|
|
def _make_mock_session(agent=None, task=None, log_entries=None):
    """Create a mock async session that simulates SQLAlchemy queries.

    ``session.execute`` ignores the statement it receives. Every call returns
    a fresh result mock configured the same way:

    * ``agent`` given: ``scalar_one_or_none()`` returns ``agent``.
    * else ``task`` given: ``scalar_one_or_none()`` returns ``task`` and
      ``scalars().all()`` returns ``[task]``.
    * else: ``scalar_one_or_none()`` returns ``None`` and ``scalars().all()``
      returns ``log_entries`` (or ``[]``).

    Whenever ``log_entries`` is not ``None``, it overrides
    ``scalars().all()`` in every case.

    Args:
        agent: Object returned by single-row lookups; takes priority over ``task``.
        task: Object returned by single-row lookups when no ``agent`` is given.
        log_entries: List returned by ``scalars().all()`` queries.

    Returns:
        An ``AsyncMock`` session. ``execute`` is itself an ``AsyncMock``, so
        tests can use ``await_count`` / ``assert_awaited*`` on it. ``add`` is
        a synchronous ``MagicMock`` (matching ``AsyncSession.add``).
        ``commit``, ``rollback`` and ``refresh`` are awaitable.
    """
    session = AsyncMock()

    async def _execute(stmt, *args, **kwargs):
        # Extra positional/keyword arguments (bound params, execution
        # options, ...) are accepted and ignored, like the statement itself.
        result = MagicMock()

        if agent is not None:
            result.scalar_one_or_none.return_value = agent
        elif task is not None:
            result.scalar_one_or_none.return_value = task
            result.scalars.return_value.all.return_value = [task] if task else []
        else:
            result.scalar_one_or_none.return_value = None
            result.scalars.return_value.all.return_value = log_entries or []

        # Explicit log entries always win for list-style queries.
        if log_entries is not None:
            result.scalars.return_value.all.return_value = log_entries

        return result

    # Wrapping in AsyncMock records each call (unlike assigning the bare
    # coroutine function) while still returning ``_execute``'s result.
    session.execute = AsyncMock(side_effect=_execute)
    session.add = MagicMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()

    return session
|
|
|
|
|
|
def _make_dispatcher(agent=None, task=None, log_entries=None):
    """Build a ``TaskDispatcher`` whose database and Redis access are mocked.

    Args:
        agent: Forwarded to ``_make_mock_session`` to seed query results.
        task: Forwarded to ``_make_mock_session`` to seed query results.
        log_entries: Forwarded to ``_make_mock_session`` to seed query results.

    Returns:
        A ``(dispatcher, mock_session, mock_redis)`` tuple, so tests can
        inspect the session and Redis mocks after exercising the dispatcher.
    """
    redis_client = AsyncMock(lpush=AsyncMock())
    # An AsyncMock factory: awaiting ``redis_factory()`` yields ``redis_client``.
    redis_factory = AsyncMock(return_value=redis_client)

    db_session = _make_mock_session(agent=agent, task=task, log_entries=log_entries)
    # ``session_factory()`` returns an async context manager whose
    # ``__aenter__`` yields ``db_session``. ``__aexit__`` returning False
    # means exceptions raised inside ``async with`` are not suppressed.
    session_factory = MagicMock()
    session_factory.return_value.__aenter__ = AsyncMock(return_value=db_session)
    session_factory.return_value.__aexit__ = AsyncMock(return_value=False)

    dispatcher = TaskDispatcher(
        redis_factory=redis_factory,
        session_factory=session_factory,
        agent_model=MockAgentModel,
        task_model=MockTaskModel,
        task_log_model=MockTaskLogModel,
    )
    return dispatcher, db_session, redis_client
|
|
|
|
|
|
# Shared stand-in for ``sqlalchemy.select``. The dispatcher tests patch
# ``sqlalchemy.select`` with this object so that query construction against
# the ``Mock*Model`` classes (whose columns are ``_ColumnMock`` objects, not
# real SQLAlchemy columns) never reaches real SQLAlchemy.
# NOTE(review): ``patch("sqlalchemy.select", ...)`` only takes effect if the
# dispatcher looks up ``select`` on the ``sqlalchemy`` module at call time; a
# module-level ``from sqlalchemy import select`` in the dispatcher would bypass
# it — confirm against agentkit.core.dispatcher.
# This single MagicMock is shared by every patched test and is never reset, so
# its call history accumulates across tests; do not assert on its calls.
_mock_select = MagicMock()
|
|
|
|
|
|
class TestTaskDispatcherDispatch:
    """Tests for ``TaskDispatcher.dispatch``."""

    @patch("sqlalchemy.select", _mock_select)
    async def test_dispatch_to_online_agent(self, make_task):
        """Dispatching to an online agent enqueues the task on its Redis list."""
        agent = MockAgentModel(name="test_agent", status=AgentStatus.ONLINE)
        dispatcher, _session, redis = _make_dispatcher(agent=agent)
        task_id = str(uuid.uuid4())
        task = make_task(task_id=task_id, agent_name="test_agent")

        result_task_id = await dispatcher.dispatch(task)
        assert result_task_id == task_id
        # ``assert_awaited_once`` (rather than ``assert_called_once``) also
        # catches a missing ``await``, which would leave the push un-run.
        redis.lpush.assert_awaited_once()

        # Verify the queue key format
        assert redis.lpush.call_args.args[0] == "agent:test_agent:tasks"

    @patch("sqlalchemy.select", _mock_select)
    async def test_dispatch_agent_not_found(self, make_task):
        """Dispatching to an unknown agent raises and enqueues nothing."""
        dispatcher, _session, redis = _make_dispatcher(agent=None)
        task_id = str(uuid.uuid4())
        task = make_task(task_id=task_id, agent_name="nonexistent")

        with pytest.raises(TaskDispatchError):
            await dispatcher.dispatch(task)
        redis.lpush.assert_not_called()

    @patch("sqlalchemy.select", _mock_select)
    async def test_dispatch_agent_offline(self, make_task):
        """Dispatching to an offline agent raises and enqueues nothing."""
        agent = MockAgentModel(name="offline_agent", status=AgentStatus.OFFLINE)
        dispatcher, _session, redis = _make_dispatcher(agent=agent)
        task_id = str(uuid.uuid4())
        task = make_task(task_id=task_id, agent_name="offline_agent")

        with pytest.raises(TaskDispatchError):
            await dispatcher.dispatch(task)
        redis.lpush.assert_not_called()
|
|
|
|
|
|
class TestTaskDispatcherCancel:
    """Tests for ``TaskDispatcher.cancel_task``."""

    @patch("sqlalchemy.select", _mock_select)
    async def test_cancel_pending_task(self):
        """Cancelling a pending task marks it CANCELLED."""
        task_uuid = uuid.uuid4()
        task = MockTaskModel(id=task_uuid, status=TaskStatus.PENDING)
        dispatcher, _session, _redis = _make_dispatcher(task=task)

        await dispatcher.cancel_task(str(task_uuid))
        assert task.status == TaskStatus.CANCELLED

    @patch("sqlalchemy.select", _mock_select)
    async def test_cancel_completed_task(self):
        """Cancelling an already-completed task leaves its status unchanged."""
        task_uuid = uuid.uuid4()
        task = MockTaskModel(id=task_uuid, status=TaskStatus.COMPLETED)
        dispatcher, _session, _redis = _make_dispatcher(task=task)

        await dispatcher.cancel_task(str(task_uuid))
        # Status should remain COMPLETED (not changed to CANCELLED)
        assert task.status == TaskStatus.COMPLETED

    @patch("sqlalchemy.select", _mock_select)
    async def test_cancel_nonexistent_task(self):
        """Cancelling an unknown task ID raises TaskNotFoundError."""
        dispatcher, _session, _redis = _make_dispatcher(task=None)

        with pytest.raises(TaskNotFoundError):
            await dispatcher.cancel_task(str(uuid.uuid4()))
|
|
|
|
|
|
class TestTaskDispatcherHandleResult:
    """Tests for ``TaskDispatcher.handle_result``."""

    @patch("sqlalchemy.select", _mock_select)
    async def test_handle_completed_result(self, make_result):
        """A successful result marks the task COMPLETED and stores its output."""
        task_uuid = uuid.uuid4()
        task = MockTaskModel(id=task_uuid, status=TaskStatus.RUNNING)
        dispatcher, _session, _redis = _make_dispatcher(task=task)

        result = make_result(task_id=str(task_uuid), status=TaskStatus.COMPLETED)
        await dispatcher.handle_result(result)

        assert task.status == TaskStatus.COMPLETED
        assert task.output_data == result.output_data

    @patch("sqlalchemy.select", _mock_select)
    async def test_handle_failed_result(self, make_result):
        """A failed result marks the task FAILED and stores the error message."""
        task_uuid = uuid.uuid4()
        task = MockTaskModel(id=task_uuid, status=TaskStatus.RUNNING)
        dispatcher, _session, _redis = _make_dispatcher(task=task)

        result = make_result(
            task_id=str(task_uuid),
            status=TaskStatus.FAILED,
            error_message="Something went wrong",
        )
        await dispatcher.handle_result(result)

        assert task.status == TaskStatus.FAILED
        assert task.error_message == "Something went wrong"
|
|
|
|
|
|
class TestValidateCallbackUrl:
    """SSRF-protection tests for ``_validate_callback_url``.

    Each test checks a single URL: public http(s) targets must be accepted;
    loopback, private ranges, non-http schemes and malformed input must be
    rejected.
    """

    @staticmethod
    def _assert_allowed(url):
        """The validator must return exactly ``True`` for *url*."""
        assert _validate_callback_url(url) is True

    @staticmethod
    def _assert_blocked(url):
        """The validator must return exactly ``False`` for *url*."""
        assert _validate_callback_url(url) is False

    def test_valid_public_https_url(self):
        """A public HTTPS callback is accepted."""
        self._assert_allowed("https://example.com/callback")

    def test_valid_public_http_url(self):
        """A public plain-HTTP callback is accepted."""
        self._assert_allowed("http://example.com/callback")

    def test_localhost_blocked(self):
        """The ``localhost`` hostname is rejected."""
        self._assert_blocked("http://localhost:8080/callback")

    def test_loopback_ip_blocked(self):
        """The IPv4 loopback address is rejected."""
        self._assert_blocked("http://127.0.0.1:8080/callback")

    def test_private_10_range_blocked(self):
        """Addresses in 10.0.0.0/8 are rejected."""
        self._assert_blocked("http://10.0.0.1/internal")

    def test_private_192_range_blocked(self):
        """Addresses in 192.168.0.0/16 are rejected."""
        self._assert_blocked("http://192.168.1.1/admin")

    def test_private_172_range_blocked(self):
        """Addresses in 172.16.0.0/12 are rejected."""
        self._assert_blocked("http://172.16.0.1/internal")

    def test_ftp_protocol_blocked(self):
        """The ``ftp`` scheme is rejected."""
        self._assert_blocked("ftp://example.com/file")

    def test_file_protocol_blocked(self):
        """The ``file`` scheme is rejected."""
        self._assert_blocked("file:///etc/passwd")

    def test_javascript_protocol_blocked(self):
        """The ``javascript`` scheme is rejected."""
        self._assert_blocked("javascript:alert(1)")

    def test_empty_url_blocked(self):
        """An empty string is rejected."""
        self._assert_blocked("")

    def test_malformed_url_blocked(self):
        """A string that is not a URL is rejected."""
        self._assert_blocked("not-a-valid-url")
|