fischer-agentkit/tests/unit/test_async_tasks.py

513 lines
19 KiB
Python

"""Async Task System 单元测试 - TaskStore + BackgroundRunner + API"""
import asyncio
from datetime import datetime, timezone, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.server.task_store import TaskRecord, TaskStore
from agentkit.server.runner import BackgroundRunner
# ═══════════════════════════════════════════════════════════
# TaskStore Tests
# ═══════════════════════════════════════════════════════════
class TestTaskRecord:
    """Tests for the TaskRecord dataclass and its ``to_dict`` serialization."""

    def test_to_dict_returns_complete_dict(self):
        # Only the required fields are supplied; the remaining fields rely on
        # the dataclass defaults, which the assertions below pin down.
        serialized = TaskRecord(
            task_id="t1",
            agent_name="agent_a",
            skill_name="skill_x",
            input_data={"query": "hello"},
        ).to_dict()

        expected_values = {
            "task_id": "t1",
            "agent_name": "agent_a",
            "skill_name": "skill_x",
            "input_data": {"query": "hello"},
            "status": "pending",
        }
        for key, value in expected_values.items():
            assert serialized[key] == value, key

        for empty_key in ("output_data", "error_message"):
            assert serialized[empty_key] is None, empty_key

        assert serialized["progress"] == 0.0
        assert serialized["created_at"] is not None

    def test_to_dict_with_timestamps(self):
        # Both timestamps share one instant, so both must serialize to the
        # same ISO-8601 string.
        moment = datetime.now(timezone.utc)
        serialized = TaskRecord(
            task_id="t2",
            agent_name="agent_b",
            skill_name=None,
            input_data={},
            started_at=moment,
            completed_at=moment,
        ).to_dict()

        iso = moment.isoformat()
        assert serialized["started_at"] == iso
        assert serialized["completed_at"] == iso
class TestTaskStore:
    """TaskStore in-memory storage tests."""

    def test_create_task_record_stored_correctly(self):
        store = TaskStore()
        record = store.create("t1", "agent_a", {"q": "hello"}, skill_name="skill_x")
        assert record.task_id == "t1"
        assert record.agent_name == "agent_a"
        assert record.skill_name == "skill_x"
        assert record.input_data == {"q": "hello"}
        assert record.status == TaskStatus.PENDING

    def test_get_task_by_id_returns_record(self):
        store = TaskStore()
        store.create("t1", "agent_a", {})
        record = store.get("t1")
        assert record is not None
        assert record.task_id == "t1"

    def test_get_nonexistent_task_returns_none(self):
        store = TaskStore()
        assert store.get("nonexistent") is None

    def test_update_status_fields_updated(self):
        store = TaskStore()
        store.create("t1", "agent_a", {})
        now = datetime.now(timezone.utc)
        record = store.update_status(
            "t1", TaskStatus.RUNNING, started_at=now, progress=0.5, progress_message="Halfway"
        )
        assert record.status == TaskStatus.RUNNING
        assert record.started_at == now
        assert record.progress == 0.5
        assert record.progress_message == "Halfway"

    def test_update_nonexistent_task_raises_keyerror(self):
        store = TaskStore()
        with pytest.raises(KeyError, match="not found"):
            store.update_status("nonexistent", TaskStatus.RUNNING)

    def test_list_tasks_returns_all_sorted_desc(self):
        store = TaskStore()
        store.create("t1", "agent_a", {})
        store.create("t2", "agent_b", {})
        # Back-to-back creates can receive identical ``created_at`` values on
        # coarse clocks. If list_tasks sorts by created_at (assumed), a stable
        # descending sort would then keep insertion order and make this test
        # flaky. Backdating t1 makes the newest-first expectation unambiguous.
        older = store.get("t1")
        older.created_at = older.created_at - timedelta(seconds=1)

        tasks = store.list_tasks()
        assert len(tasks) == 2
        # Most recent first
        assert [task.task_id for task in tasks] == ["t2", "t1"]

    def test_list_tasks_filtered_by_status(self):
        store = TaskStore()
        store.create("t1", "agent_a", {})
        store.create("t2", "agent_b", {})
        store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
        tasks = store.list_tasks(status=TaskStatus.COMPLETED)
        assert len(tasks) == 1
        assert tasks[0].task_id == "t1"

    def test_max_records_limit_evicts_oldest_completed(self):
        store = TaskStore(max_records=2)
        store.create("t1", "agent_a", {})
        store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
        store.create("t2", "agent_b", {})
        # t3 should evict t1 (oldest completed)
        store.create("t3", "agent_c", {})
        assert store.get("t1") is None
        assert store.get("t2") is not None
        assert store.get("t3") is not None

    def test_max_records_full_no_completed_raises(self):
        store = TaskStore(max_records=1)
        store.create("t1", "agent_a", {})
        # All tasks are PENDING, so there is no completed task to evict.
        with pytest.raises(RuntimeError, match="full"):
            store.create("t2", "agent_b", {})

    def test_ttl_cleanup_removes_expired_completed(self):
        store = TaskStore(ttl_seconds=0)  # Immediate expiry
        store.create("t1", "agent_a", {})
        store.update_status(
            "t1", TaskStatus.COMPLETED,
            completed_at=datetime.now(timezone.utc) - timedelta(seconds=10),
        )
        store.create("t2", "agent_b", {})
        # t2 is PENDING and must survive cleanup; only finished tasks expire.
        store._cleanup_expired()
        assert store.get("t1") is None  # Expired completed
        assert store.get("t2") is not None  # Pending stays

    def test_size_property_correct_count(self):
        store = TaskStore()
        assert store.size == 0
        store.create("t1", "agent_a", {})
        assert store.size == 1
        store.create("t2", "agent_b", {})
        assert store.size == 2

    def test_list_tasks_respects_limit(self):
        store = TaskStore()
        for i in range(5):
            store.create(f"t{i}", "agent_a", {})
        tasks = store.list_tasks(limit=3)
        assert len(tasks) == 3
# ═══════════════════════════════════════════════════════════
# BackgroundRunner Tests
# ═══════════════════════════════════════════════════════════
class TestBackgroundRunner:
    """BackgroundRunner async task execution tests."""

    @pytest.fixture
    def task_store(self):
        return TaskStore()

    @pytest.fixture
    def runner(self, task_store):
        return BackgroundRunner(task_store=task_store, max_concurrent=5)

    @staticmethod
    async def _wait_until(predicate, timeout=2.0, interval=0.01):
        """Poll ``predicate`` until it is truthy or ``timeout`` seconds elapse.

        Returns True if the predicate became truthy, False on timeout. Polling
        replaces fixed sleeps, which are slow when generous and flaky when tight.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not predicate():
            if loop.time() >= deadline:
                return False
            await asyncio.sleep(interval)
        return True

    async def _wait_for_terminal(self, task_store, task_id, timeout=2.0):
        """Wait until ``task_id`` reaches COMPLETED/FAILED/CANCELLED, then return its record.

        On timeout the current record is still returned so the caller's status
        assertion reports what state the task was stuck in.
        """
        terminal = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)

        def reached_terminal():
            record = task_store.get(task_id)
            return record is not None and record.status in terminal

        await self._wait_until(reached_terminal, timeout=timeout)
        return task_store.get(task_id)

    def _make_mock_agent(self, name="test_agent", output=None, raise_error=None):
        """Create a mock agent whose ``execute`` returns a completed TaskResult.

        Args:
            name: Agent name reported by the mock.
            output: ``output_data`` for the result; ``None`` selects
                ``{"result": "ok"}``. An explicit empty dict is preserved.
            raise_error: If given, ``execute`` raises this exception instead.
        """
        agent = MagicMock()
        agent.name = name
        agent.agent_type = "test_type"
        if raise_error is not None:
            agent.execute = AsyncMock(side_effect=raise_error)
        else:
            task_result = TaskResult(
                task_id="mock",
                agent_name=name,
                status="completed",
                output_data={"result": "ok"} if output is None else output,
                error_message=None,
                started_at=datetime.now(timezone.utc),
                completed_at=datetime.now(timezone.utc),
            )
            agent.execute = AsyncMock(return_value=task_result)
        return agent

    @pytest.mark.asyncio
    async def test_submit_returns_task_id_immediately(self, runner, task_store):
        agent = self._make_mock_agent()
        task_id = await runner.submit(agent, {"query": "test"})
        assert task_id is not None
        assert isinstance(task_id, str)
        # The record exists and has not been picked up yet.
        record = task_store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.PENDING

    @pytest.mark.asyncio
    async def test_submit_task_runs_to_completion(self, runner, task_store):
        agent = self._make_mock_agent(output={"answer": "42"})
        task_id = await runner.submit(agent, {"query": "meaning of life"})
        record = await self._wait_for_terminal(task_store, task_id)
        assert record is not None
        assert record.status == TaskStatus.COMPLETED
        assert record.output_data == {"answer": "42"}
        assert record.progress == 1.0

    @pytest.mark.asyncio
    async def test_submit_task_failure_recorded(self, runner, task_store):
        agent = self._make_mock_agent(raise_error=RuntimeError("boom"))
        task_id = await runner.submit(agent, {"query": "fail"})
        record = await self._wait_for_terminal(task_store, task_id)
        assert record is not None
        assert record.status == TaskStatus.FAILED
        assert "boom" in record.error_message

    @pytest.mark.asyncio
    async def test_cancel_running_task(self, runner, task_store):
        started = asyncio.Event()

        async def slow_execute(msg):
            started.set()
            await asyncio.sleep(10)  # Long enough to still be running when cancelled.
            return TaskResult(
                task_id=msg.task_id,
                agent_name="test_agent",
                status="completed",
                output_data={"result": "done"},
                error_message=None,
                started_at=datetime.now(timezone.utc),
                completed_at=datetime.now(timezone.utc),
            )

        agent = MagicMock()
        agent.name = "slow_agent"
        agent.agent_type = "test_type"
        agent.execute = AsyncMock(side_effect=slow_execute)
        task_id = await runner.submit(agent, {"query": "slow"})
        # Wait until the agent is actually executing instead of guessing a delay.
        await asyncio.wait_for(started.wait(), timeout=2.0)
        cancelled = await runner.cancel(task_id)
        assert cancelled is True
        record = task_store.get(task_id)
        assert record.status == TaskStatus.CANCELLED

    @pytest.mark.asyncio
    async def test_cancel_non_running_task_returns_false(self, runner, task_store):
        result = await runner.cancel("nonexistent")
        assert result is False

    @pytest.mark.asyncio
    async def test_concurrent_tasks_respects_semaphore(self, task_store):
        max_concurrent = 2
        runner = BackgroundRunner(task_store=task_store, max_concurrent=max_concurrent)
        in_flight = 0
        peak = 0

        async def tracked_execute(msg):
            # Record how many executions overlap so the limit can be asserted.
            nonlocal in_flight, peak
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                await asyncio.sleep(0.1)
            finally:
                in_flight -= 1
            return TaskResult(
                task_id=msg.task_id,
                agent_name="test",
                status="completed",
                output_data={},
                error_message=None,
                started_at=datetime.now(timezone.utc),
                completed_at=datetime.now(timezone.utc),
            )

        # Submit 4 tasks, twice the concurrency limit.
        task_ids = []
        for idx in range(4):
            agent = MagicMock()
            agent.name = f"agent_{idx}"
            agent.agent_type = "test_type"
            agent.execute = AsyncMock(side_effect=tracked_execute)
            task_ids.append(await runner.submit(agent, {"idx": idx}))

        for tid in task_ids:
            record = await self._wait_for_terminal(task_store, tid)
            assert record.status == TaskStatus.COMPLETED

        # The actual semaphore check: never more than the limit at once.
        assert 1 <= peak <= max_concurrent

    @pytest.mark.asyncio
    async def test_active_count_tracks_running(self, task_store):
        runner = BackgroundRunner(task_store=task_store, max_concurrent=10)

        async def slow_execute(msg):
            await asyncio.sleep(0.2)
            return TaskResult(
                task_id=msg.task_id,
                agent_name="test",
                status="completed",
                output_data={},
                error_message=None,
                started_at=datetime.now(timezone.utc),
                completed_at=datetime.now(timezone.utc),
            )

        agent = MagicMock()
        agent.name = "slow_agent"
        agent.agent_type = "test_type"
        agent.execute = AsyncMock(side_effect=slow_execute)
        await runner.submit(agent, {})
        assert await self._wait_until(lambda: runner.active_count >= 1, timeout=1.0)
        assert await self._wait_until(lambda: runner.active_count == 0, timeout=2.0)
# ═══════════════════════════════════════════════════════════
# API Tests (using TestClient)
# ═══════════════════════════════════════════════════════════
class TestAsyncTaskAPI:
    """Async task API endpoint tests."""

    @pytest.fixture
    def mock_llm_gateway(self):
        from agentkit.llm.gateway import LLMGateway
        from agentkit.llm.protocol import LLMResponse, TokenUsage
        gateway = LLMGateway()
        mock_provider = AsyncMock()
        mock_provider.chat.return_value = LLMResponse(
            content='{"result": "mocked"}',
            model="test-model",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
        gateway.register_provider("test", mock_provider)
        return gateway

    @pytest.fixture
    def skill_registry(self):
        from agentkit.skills.registry import SkillRegistry
        return SkillRegistry()

    @pytest.fixture
    def tool_registry(self):
        from agentkit.tools.registry import ToolRegistry
        return ToolRegistry()

    @pytest.fixture
    def app(self, mock_llm_gateway, skill_registry, tool_registry):
        from agentkit.server.app import create_app
        return create_app(
            llm_gateway=mock_llm_gateway,
            skill_registry=skill_registry,
            tool_registry=tool_registry,
        )

    @pytest.fixture
    def client(self, app):
        # NOTE(review): TestClient is not entered as a context manager, so app
        # lifespan events do not run. In Starlette's implementation, each
        # request may also use its own short-lived event loop, so background
        # tasks submitted in one request might not progress afterwards.
        # Confirm this before tightening the status assertions below.
        return TestClient(app)

    def _register_skill_and_create_agent(self, client, skill_registry):
        """Register the ``async_skill`` skill, create its agent, and return the agent name."""
        from agentkit.skills.base import Skill, SkillConfig
        skill_config = SkillConfig(
            name="async_skill",
            agent_type="async_type",
            task_mode="llm_generate",
            prompt={"identity": "Async Skill", "instructions": "Handle async"},
            intent={"keywords": ["async"], "description": "Async skill"},
        )
        skill = Skill(config=skill_config)
        skill_registry.register(skill)
        # Create agent
        resp = client.post(
            "/api/v1/agents",
            json={"skill_name": "async_skill"},
        )
        assert resp.status_code == 201
        return "async_skill"

    def _submit_async(self, client, agent_name, query):
        """Submit an async task, assert the API accepted it, and return its task id.

        Checking the status first turns a rejected submit into a clear failure
        instead of a KeyError on ``["task_id"]``.
        """
        resp = client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": query},
                "agent_name": agent_name,
                "mode": "async",
            },
        )
        assert resp.status_code == 200, resp.text
        return resp.json()["task_id"]

    def test_submit_task_async_returns_task_id(self, client, skill_registry):
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        response = client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "async test"},
                "agent_name": agent_name,
                "mode": "async",
            },
        )
        assert response.status_code == 200
        data = response.json()
        assert "task_id" in data
        assert data["status"] == "pending"
        assert data["mode"] == "async"

    def test_get_task_status_returns_record(self, client, skill_registry):
        import time

        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        task_id = self._submit_async(client, agent_name, "status test")
        # Give the background task a chance to progress before querying.
        time.sleep(0.3)
        response = client.get(f"/api/v1/tasks/{task_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["task_id"] == task_id
        assert data["status"] in ("completed", "running", "pending")

    def test_get_task_status_not_found_404(self, client):
        response = client.get("/api/v1/tasks/nonexistent-id")
        assert response.status_code == 404

    def test_cancel_task(self, client, skill_registry):
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        task_id = self._submit_async(client, agent_name, "cancel test")
        # Cancellation races with completion, so either outcome is accepted.
        response = client.post(f"/api/v1/tasks/{task_id}/cancel")
        # Either cancelled (200) or already finished (400).
        assert response.status_code in (200, 400)

    def test_list_tasks(self, client, skill_registry):
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        task_id = self._submit_async(client, agent_name, "list test")
        response = client.get("/api/v1/tasks")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # The task just submitted must be among the listed tasks.
        assert task_id in {task["task_id"] for task in data}

    def test_list_tasks_filter_by_status(self, client, skill_registry):
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        self._submit_async(client, agent_name, "filter test")
        response = client.get("/api/v1/tasks?status=completed")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # All returned tasks should be completed
        for task in data:
            assert task["status"] == "completed"

    def test_sync_mode_still_works(self, client, skill_registry):
        """Ensure existing sync mode is not broken."""
        agent_name = self._register_skill_and_create_agent(client, skill_registry)
        response = client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "sync test"},
                "agent_name": agent_name,
            },
        )
        assert response.status_code == 200
        data = response.json()
        # Sync mode returns task_id and output
        assert "task_id" in data