# Source listing metadata: 513 lines, 19 KiB, Python
"""Async Task System 单元测试 - TaskStore + BackgroundRunner + API"""
|
|
|
|
import asyncio
|
|
from datetime import datetime, timezone, timedelta
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
|
from agentkit.server.task_store import TaskRecord, TaskStore
|
|
from agentkit.server.runner import BackgroundRunner
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# TaskStore Tests
|
|
# ═══════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestTaskRecord:
    """Serialization checks for ``TaskRecord.to_dict()``."""

    def test_to_dict_returns_complete_dict(self):
        """A record built from the four basic fields serializes them plus its defaults."""
        payload = TaskRecord(
            task_id="t1",
            agent_name="agent_a",
            skill_name="skill_x",
            input_data={"query": "hello"},
        ).to_dict()

        # Fields compared by equality, in the order they are checked.
        expected_values = {
            "task_id": "t1",
            "agent_name": "agent_a",
            "skill_name": "skill_x",
            "input_data": {"query": "hello"},
            "status": "pending",
        }
        for key, expected in expected_values.items():
            assert payload[key] == expected

        # Fields that must be exactly None on a record that has not run yet.
        for key in ("output_data", "error_message"):
            assert payload[key] is None

        assert payload["progress"] == 0.0
        assert payload["created_at"] is not None

    def test_to_dict_with_timestamps(self):
        """Explicit start/completion timestamps are emitted in ISO-8601 form."""
        stamp = datetime.now(timezone.utc)
        record = TaskRecord(
            task_id="t2",
            agent_name="agent_b",
            skill_name=None,
            input_data={},
            started_at=stamp,
            completed_at=stamp,
        )

        serialized = record.to_dict()
        iso_stamp = stamp.isoformat()

        assert serialized["started_at"] == iso_stamp
        assert serialized["completed_at"] == iso_stamp
|
|
|
|
|
|
class TestTaskStore:
    """Behavioral checks for the in-memory ``TaskStore``."""

    def test_create_task_record_stored_correctly(self):
        """create() returns a record carrying the supplied fields in PENDING state."""
        created = TaskStore().create(
            "t1", "agent_a", {"q": "hello"}, skill_name="skill_x"
        )

        assert created.task_id == "t1"
        assert created.agent_name == "agent_a"
        assert created.skill_name == "skill_x"
        assert created.input_data == {"q": "hello"}
        assert created.status == TaskStatus.PENDING

    def test_get_task_by_id_returns_record(self):
        """get() finds a record that was previously created."""
        ts = TaskStore()
        ts.create("t1", "agent_a", {})

        found = ts.get("t1")

        assert found is not None
        assert found.task_id == "t1"

    def test_get_nonexistent_task_returns_none(self):
        """get() answers None for an unknown task id."""
        ts = TaskStore()
        missing = ts.get("nonexistent")
        assert missing is None

    def test_update_status_fields_updated(self):
        """update_status() applies the new status and the extra keyword fields."""
        ts = TaskStore()
        ts.create("t1", "agent_a", {})
        started = datetime.now(timezone.utc)

        updated = ts.update_status(
            "t1",
            TaskStatus.RUNNING,
            started_at=started,
            progress=0.5,
            progress_message="Halfway",
        )

        assert updated.status == TaskStatus.RUNNING
        assert updated.started_at == started
        assert updated.progress == 0.5
        assert updated.progress_message == "Halfway"

    def test_update_nonexistent_task_raises_keyerror(self):
        """update_status() on an unknown id raises KeyError mentioning 'not found'."""
        ts = TaskStore()
        with pytest.raises(KeyError, match="not found"):
            ts.update_status("nonexistent", TaskStatus.RUNNING)

    def test_list_tasks_returns_all_sorted_desc(self):
        """list_tasks() returns every record, most recently created first."""
        ts = TaskStore()
        for task_id, agent in (("t1", "agent_a"), ("t2", "agent_b")):
            ts.create(task_id, agent, {})

        listed = ts.list_tasks()

        assert len(listed) == 2
        # Most recent first
        assert listed[0].task_id == "t2"
        assert listed[1].task_id == "t1"

    def test_list_tasks_filtered_by_status(self):
        """list_tasks(status=...) returns only records in that status."""
        ts = TaskStore()
        ts.create("t1", "agent_a", {})
        ts.create("t2", "agent_b", {})
        finished_at = datetime.now(timezone.utc)
        ts.update_status("t1", TaskStatus.COMPLETED, completed_at=finished_at)

        completed = ts.list_tasks(status=TaskStatus.COMPLETED)

        assert len(completed) == 1
        assert completed[0].task_id == "t1"

    def test_max_records_limit_evicts_oldest_completed(self):
        """When the store is full, creating a task drops the completed one."""
        ts = TaskStore(max_records=2)
        ts.create("t1", "agent_a", {})
        ts.update_status(
            "t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc)
        )
        ts.create("t2", "agent_b", {})

        # t3 should evict t1 (oldest completed)
        ts.create("t3", "agent_c", {})

        assert ts.get("t1") is None
        for survivor in ("t2", "t3"):
            assert ts.get(survivor) is not None

    def test_max_records_full_no_completed_raises(self):
        """A full store with nothing completed refuses new tasks with RuntimeError."""
        ts = TaskStore(max_records=1)
        ts.create("t1", "agent_a", {})

        # All tasks are PENDING, no completed to evict
        with pytest.raises(RuntimeError, match="full"):
            ts.create("t2", "agent_b", {})

    def test_ttl_cleanup_removes_expired_completed(self):
        """_cleanup_expired() removes expired completed records but keeps pending ones."""
        ts = TaskStore(ttl_seconds=0)  # Immediate expiry
        ts.create("t1", "agent_a", {})
        ten_seconds_ago = datetime.now(timezone.utc) - timedelta(seconds=10)
        ts.update_status("t1", TaskStatus.COMPLETED, completed_at=ten_seconds_ago)
        ts.create("t2", "agent_b", {})  # t2 is PENDING, should not be cleaned

        ts._cleanup_expired()

        assert ts.get("t1") is None  # Expired completed
        assert ts.get("t2") is not None  # Pending stays

    def test_size_property_correct_count(self):
        """size reflects the number of records created so far."""
        ts = TaskStore()
        assert ts.size == 0

        additions = [("t1", "agent_a"), ("t2", "agent_b")]
        for expected_size, (task_id, agent) in enumerate(additions, start=1):
            ts.create(task_id, agent, {})
            assert ts.size == expected_size

    def test_list_tasks_respects_limit(self):
        """list_tasks(limit=n) returns at most n records."""
        ts = TaskStore()
        for index in range(5):
            ts.create(f"t{index}", "agent_a", {})

        limited = ts.list_tasks(limit=3)

        assert len(limited) == 3
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# BackgroundRunner Tests
|
|
# ═══════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestBackgroundRunner:
    """BackgroundRunner async task execution tests.

    Every test drives the runner with mock agents whose ``execute`` attribute
    is an ``AsyncMock``; results are observed through the shared ``TaskStore``.
    """

    @pytest.fixture
    def task_store(self):
        """Provide a fresh, empty TaskStore for each test."""
        return TaskStore()

    @pytest.fixture
    def runner(self, task_store):
        """Provide a BackgroundRunner bound to ``task_store`` (max_concurrent=5)."""
        return BackgroundRunner(task_store=task_store, max_concurrent=5)

    @staticmethod
    def _make_result(task_id, agent_name, output_data):
        """Build a "completed" TaskResult with no error and current UTC timestamps."""
        return TaskResult(
            task_id=task_id,
            agent_name=agent_name,
            status="completed",
            output_data=output_data,
            error_message=None,
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
        )

    def _make_mock_agent(
        self, name="test_agent", output=None, raise_error=None, execute=None
    ):
        """Create a mock agent for testing.

        Args:
            name: Value assigned to ``agent.name``.
            output: ``output_data`` of the canned result; a falsy value falls
                back to ``{"result": "ok"}``.
            raise_error: Exception that ``agent.execute`` raises when awaited.
            execute: Optional coroutine function used as the ``side_effect``
                of ``agent.execute``; takes precedence over the other two.

        Returns:
            A ``MagicMock`` with ``name``, ``agent_type`` and an ``AsyncMock``
            ``execute`` attribute.
        """
        agent = MagicMock()
        agent.name = name
        agent.agent_type = "test_type"
        if execute is not None:
            agent.execute = AsyncMock(side_effect=execute)
        elif raise_error:
            agent.execute = AsyncMock(side_effect=raise_error)
        else:
            canned = self._make_result("mock", name, output or {"result": "ok"})
            agent.execute = AsyncMock(return_value=canned)
        return agent

    @pytest.mark.asyncio
    async def test_submit_returns_task_id_immediately(self, runner, task_store):
        """submit() returns a string id and the record is still PENDING right after."""
        agent = self._make_mock_agent()
        task_id = await runner.submit(agent, {"query": "test"})
        assert task_id is not None
        assert isinstance(task_id, str)
        # Task record should exist in store
        record = task_store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.PENDING

    @pytest.mark.asyncio
    async def test_submit_task_runs_to_completion(self, runner, task_store):
        """A successful agent run ends COMPLETED with its output and progress 1.0."""
        agent = self._make_mock_agent(output={"answer": "42"})
        task_id = await runner.submit(agent, {"query": "meaning of life"})
        # Wait for task to complete
        await asyncio.sleep(0.1)
        record = task_store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.COMPLETED
        assert record.output_data == {"answer": "42"}
        assert record.progress == 1.0

    @pytest.mark.asyncio
    async def test_submit_task_failure_recorded(self, runner, task_store):
        """An exception raised by the agent is recorded as FAILED with its message."""
        agent = self._make_mock_agent(raise_error=RuntimeError("boom"))
        task_id = await runner.submit(agent, {"query": "fail"})
        # Wait for task to fail
        await asyncio.sleep(0.1)
        record = task_store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.FAILED
        assert "boom" in record.error_message

    @pytest.mark.asyncio
    async def test_cancel_running_task(self, runner, task_store):
        """cancel() on an in-flight task returns True and marks it CANCELLED."""

        async def slow_execute(msg):
            await asyncio.sleep(10)  # Long running
            return self._make_result(msg.task_id, "test_agent", {"result": "done"})

        agent = self._make_mock_agent(name="slow_agent", execute=slow_execute)

        task_id = await runner.submit(agent, {"query": "slow"})
        # Give it a moment to start
        await asyncio.sleep(0.05)
        cancelled = await runner.cancel(task_id)
        assert cancelled is True
        record = task_store.get(task_id)
        assert record.status == TaskStatus.CANCELLED

    @pytest.mark.asyncio
    async def test_cancel_non_running_task_returns_false(self, runner, task_store):
        """cancel() with an unknown task id returns False."""
        result = await runner.cancel("nonexistent")
        assert result is False

    @pytest.mark.asyncio
    async def test_concurrent_tasks_respects_semaphore(self, task_store):
        """No more than ``max_concurrent`` agent executions overlap at any time."""
        max_concurrent = 2
        runner = BackgroundRunner(task_store=task_store, max_concurrent=max_concurrent)
        in_flight = 0  # executions currently inside tracked_execute
        peak = 0  # highest value in_flight ever reached

        async def tracked_execute(msg):
            nonlocal in_flight, peak
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                await asyncio.sleep(0.1)
            finally:
                # Keep the counter correct even if the task is cancelled.
                in_flight -= 1
            return self._make_result(msg.task_id, "test", {})

        agents = [
            self._make_mock_agent(name=f"agent_{i}", execute=tracked_execute)
            for i in range(4)
        ]

        # Submit all 4 tasks
        task_ids = []
        for idx, agent in enumerate(agents):
            task_ids.append(await runner.submit(agent, {"idx": idx}))

        # Wait for all to complete
        await asyncio.sleep(0.5)

        # All tasks should have completed
        for tid in task_ids:
            record = task_store.get(tid)
            assert record.status == TaskStatus.COMPLETED

        # The limit under test: executions ran, and never more than allowed at once.
        assert 1 <= peak <= max_concurrent

    @pytest.mark.asyncio
    async def test_active_count_tracks_running(self, task_store):
        """active_count is non-zero while a task runs and zero once it finishes."""
        runner = BackgroundRunner(task_store=task_store, max_concurrent=10)

        async def slow_execute(msg):
            await asyncio.sleep(0.2)
            return self._make_result(msg.task_id, "test", {})

        agent = self._make_mock_agent(name="slow_agent", execute=slow_execute)

        await runner.submit(agent, {})
        await asyncio.sleep(0.05)
        assert runner.active_count >= 1

        await asyncio.sleep(0.3)
        assert runner.active_count == 0
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# API Tests (using TestClient)
|
|
# ═══════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestAsyncTaskAPI:
    """Async task API endpoint tests.

    Each test builds an app through ``create_app`` with a mocked LLM provider
    and talks to it via ``TestClient``.
    """

    # Name of the skill registered by the helper; also used as ``agent_name``.
    _SKILL_NAME = "async_skill"

    @pytest.fixture
    def mock_llm_gateway(self):
        """LLMGateway with one provider ("test") whose chat() returns a canned response."""
        from agentkit.llm.gateway import LLMGateway
        from agentkit.llm.protocol import LLMResponse, TokenUsage

        gateway = LLMGateway()
        mock_provider = AsyncMock()
        mock_provider.chat.return_value = LLMResponse(
            content='{"result": "mocked"}',
            model="test-model",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
        gateway.register_provider("test", mock_provider)
        return gateway

    @pytest.fixture
    def skill_registry(self):
        """Fresh SkillRegistry per test."""
        from agentkit.skills.registry import SkillRegistry

        return SkillRegistry()

    @pytest.fixture
    def tool_registry(self):
        """Fresh ToolRegistry per test."""
        from agentkit.tools.registry import ToolRegistry

        return ToolRegistry()

    @pytest.fixture
    def app(self, mock_llm_gateway, skill_registry, tool_registry):
        """Application created from the mocked gateway and the empty registries."""
        from agentkit.server.app import create_app

        return create_app(
            llm_gateway=mock_llm_gateway,
            skill_registry=skill_registry,
            tool_registry=tool_registry,
        )

    @pytest.fixture
    def client(self, app):
        """TestClient bound to the test application."""
        return TestClient(app)

    def _register_skill_and_create_agent(self, client, skill_registry):
        """Helper: register a skill and create an agent for it.

        Returns:
            The skill name, which the tests pass as ``agent_name`` when
            submitting tasks.
        """
        from agentkit.skills.base import Skill, SkillConfig

        skill_config = SkillConfig(
            name=self._SKILL_NAME,
            agent_type="async_type",
            task_mode="llm_generate",
            prompt={"identity": "Async Skill", "instructions": "Handle async"},
            intent={"keywords": ["async"], "description": "Async skill"},
        )
        skill_registry.register(Skill(config=skill_config))

        # Create agent
        resp = client.post(
            "/api/v1/agents",
            json={"skill_name": self._SKILL_NAME},
        )
        assert resp.status_code == 201
        return self._SKILL_NAME

    def _submit_async(self, client, agent_name, query):
        """POST an async-mode task and return the decoded JSON body.

        Asserts HTTP 200 so that a failed submission is reported at its
        source instead of as a later KeyError or a vacuously passing test.
        """
        response = client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": query},
                "agent_name": agent_name,
                "mode": "async",
            },
        )
        assert response.status_code == 200
        return response.json()

    def test_submit_task_async_returns_task_id(self, client, skill_registry):
        """Async submission answers with a task id, "pending" status and async mode."""
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        data = self._submit_async(client, agent_name, "async test")
        assert "task_id" in data
        assert data["status"] == "pending"
        assert data["mode"] == "async"

    def test_get_task_status_returns_record(self, client, skill_registry):
        """GET /tasks/{id} returns the record of a previously submitted task."""
        import time

        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        # Submit async task
        task_id = self._submit_async(client, agent_name, "status test")["task_id"]

        # Wait a bit for completion
        time.sleep(0.3)

        # Get status
        response = client.get(f"/api/v1/tasks/{task_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["task_id"] == task_id
        assert data["status"] in ("completed", "running", "pending")

    def test_get_task_status_not_found_404(self, client):
        """GET /tasks/{id} with an unknown id answers 404."""
        response = client.get("/api/v1/tasks/nonexistent-id")
        assert response.status_code == 404

    def test_cancel_task(self, client, skill_registry):
        """POST /tasks/{id}/cancel answers 200 or 400 depending on timing."""
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        # Submit async task
        task_id = self._submit_async(client, agent_name, "cancel test")["task_id"]

        # Try to cancel (may or may not succeed depending on timing)
        response = client.post(f"/api/v1/tasks/{task_id}/cancel")
        # Either cancelled or 400 (already completed)
        assert response.status_code in (200, 400)

    def test_list_tasks(self, client, skill_registry):
        """GET /tasks returns a non-empty list once a task has been submitted."""
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        # Submit an async task to ensure at least one exists
        self._submit_async(client, agent_name, "list test")

        response = client.get("/api/v1/tasks")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) >= 1

    def test_list_tasks_filter_by_status(self, client, skill_registry):
        """GET /tasks?status=completed returns only completed tasks."""
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        # Submit an async task
        self._submit_async(client, agent_name, "filter test")

        response = client.get("/api/v1/tasks?status=completed")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # All returned tasks should be completed (the list may be empty if
        # the submitted task has not finished yet).
        for task in data:
            assert task["status"] == "completed"

    def test_sync_mode_still_works(self, client, skill_registry):
        """Ensure existing sync mode is not broken"""
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        # No "mode" key: the request exercises the default (sync) path.
        response = client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "sync test"},
                "agent_name": agent_name,
            },
        )
        assert response.status_code == 200
        data = response.json()
        # Only the presence of task_id is checked here.
        assert "task_id" in data
|