"""Async Task System unit tests - TaskStore + BackgroundRunner + API"""

import asyncio
import time
from datetime import datetime, timezone, timedelta
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.server.task_store import TaskRecord, TaskStore
from agentkit.server.runner import BackgroundRunner


# Statuses after which a background task's record is no longer expected to change.
_TERMINAL_STATUSES = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)


async def _wait_until(predicate, timeout=2.0, interval=0.01):
    """Poll ``predicate`` on the running event loop until it returns truthy.

    Used instead of fixed ``asyncio.sleep`` delays: the wait ends as soon as
    the condition holds (fast on an idle machine) but tolerates slow CI hosts
    for up to ``timeout`` seconds.

    Returns:
        True if the predicate became truthy before the deadline, else False.
        Callers still assert on the resulting state so failures stay descriptive.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while not predicate():
        if loop.time() >= deadline:
            return False
        await asyncio.sleep(interval)
    return True


async def _wait_for_terminal(store, task_id, timeout=2.0):
    """Wait until ``task_id`` in ``store`` reaches a terminal status and return its record.

    The record is returned even if the timeout elapsed, so the caller's
    status assertion reports what state the task was actually stuck in.
    """

    def _is_done():
        record = store.get(task_id)
        return record is not None and record.status in _TERMINAL_STATUSES

    await _wait_until(_is_done, timeout=timeout)
    return store.get(task_id)


def _completed_result(task_id, agent_name, output_data):
    """Build a successful ``TaskResult`` for use as a mocked ``agent.execute`` result."""
    return TaskResult(
        task_id=task_id,
        agent_name=agent_name,
        status="completed",
        output_data=output_data,
        error_message=None,
        started_at=datetime.now(timezone.utc),
        completed_at=datetime.now(timezone.utc),
    )


# ═══════════════════════════════════════════════════════════
# TaskStore Tests
# ═══════════════════════════════════════════════════════════


class TestTaskRecord:
    """TaskRecord dataclass tests"""

    def test_to_dict_returns_complete_dict(self):
        record = TaskRecord(
            task_id="t1",
            agent_name="agent_a",
            skill_name="skill_x",
            input_data={"query": "hello"},
        )
        d = record.to_dict()
        assert d["task_id"] == "t1"
        assert d["agent_name"] == "agent_a"
        assert d["skill_name"] == "skill_x"
        assert d["input_data"] == {"query": "hello"}
        assert d["status"] == "pending"
        assert d["output_data"] is None
        assert d["error_message"] is None
        assert d["progress"] == 0.0
        assert d["created_at"] is not None

    def test_to_dict_with_timestamps(self):
        now = datetime.now(timezone.utc)
        record = TaskRecord(
            task_id="t2",
            agent_name="agent_b",
            skill_name=None,
            input_data={},
            started_at=now,
            completed_at=now,
        )
        d = record.to_dict()
        assert d["started_at"] == now.isoformat()
        assert d["completed_at"] == now.isoformat()


class TestTaskStore:
    """TaskStore in-memory storage tests"""

    def test_create_task_record_stored_correctly(self):
        store = TaskStore()
        record = store.create("t1", "agent_a", {"q": "hello"}, skill_name="skill_x")
        assert record.task_id == "t1"
        assert record.agent_name == "agent_a"
        assert record.skill_name == "skill_x"
        assert record.input_data == {"q": "hello"}
        assert record.status == TaskStatus.PENDING

    def test_get_task_by_id_returns_record(self):
        store = TaskStore()
        store.create("t1", "agent_a", {})
        record = store.get("t1")
        assert record is not None
        assert record.task_id == "t1"

    def test_get_nonexistent_task_returns_none(self):
        store = TaskStore()
        assert store.get("nonexistent") is None

    def test_update_status_fields_updated(self):
        store = TaskStore()
        store.create("t1", "agent_a", {})
        now = datetime.now(timezone.utc)
        record = store.update_status(
            "t1", TaskStatus.RUNNING, started_at=now, progress=0.5, progress_message="Halfway"
        )
        assert record.status == TaskStatus.RUNNING
        assert record.started_at == now
        assert record.progress == 0.5
        assert record.progress_message == "Halfway"

    def test_update_nonexistent_task_raises_keyerror(self):
        store = TaskStore()
        with pytest.raises(KeyError, match="not found"):
            store.update_status("nonexistent", TaskStatus.RUNNING)

    def test_list_tasks_returns_all_sorted_desc(self):
        store = TaskStore()
        store.create("t1", "agent_a", {})
        store.create("t2", "agent_b", {})
        tasks = store.list_tasks()
        assert len(tasks) == 2
        # Most recent first.
        # NOTE(review): relies on t2 getting a strictly later created_at than t1;
        # on platforms with coarse clock resolution the timestamps may tie —
        # confirm how TaskStore breaks ties.
        assert tasks[0].task_id == "t2"
        assert tasks[1].task_id == "t1"

    def test_list_tasks_filtered_by_status(self):
        store = TaskStore()
        store.create("t1", "agent_a", {})
        store.create("t2", "agent_b", {})
        store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
        tasks = store.list_tasks(status=TaskStatus.COMPLETED)
        assert len(tasks) == 1
        assert tasks[0].task_id == "t1"

    def test_max_records_limit_evicts_oldest_completed(self):
        store = TaskStore(max_records=2)
        store.create("t1", "agent_a", {})
        store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
        store.create("t2", "agent_b", {})
        # t3 should evict t1 (oldest completed)
        store.create("t3", "agent_c", {})
        assert store.get("t1") is None
        assert store.get("t2") is not None
        assert store.get("t3") is not None

    def test_max_records_full_no_completed_raises(self):
        store = TaskStore(max_records=1)
        store.create("t1", "agent_a", {})
        # All tasks are PENDING, no completed to evict
        with pytest.raises(RuntimeError, match="full"):
            store.create("t2", "agent_b", {})

    def test_ttl_cleanup_removes_expired_completed(self):
        store = TaskStore(ttl_seconds=0)  # Immediate expiry
        store.create("t1", "agent_a", {})
        store.update_status(
            "t1",
            TaskStatus.COMPLETED,
            completed_at=datetime.now(timezone.utc) - timedelta(seconds=10),
        )
        store.create("t2", "agent_b", {})
        # t2 is PENDING, should not be cleaned
        store._cleanup_expired()
        assert store.get("t1") is None  # Expired completed
        assert store.get("t2") is not None  # Pending stays

    def test_size_property_correct_count(self):
        store = TaskStore()
        assert store.size == 0
        store.create("t1", "agent_a", {})
        assert store.size == 1
        store.create("t2", "agent_b", {})
        assert store.size == 2

    def test_list_tasks_respects_limit(self):
        store = TaskStore()
        for i in range(5):
            store.create(f"t{i}", "agent_a", {})
        tasks = store.list_tasks(limit=3)
        assert len(tasks) == 3


# ═══════════════════════════════════════════════════════════
# BackgroundRunner Tests
# ═══════════════════════════════════════════════════════════


class TestBackgroundRunner:
    """BackgroundRunner async task execution tests"""

    @pytest.fixture
    def task_store(self):
        return TaskStore()

    @pytest.fixture
    def runner(self, task_store):
        return BackgroundRunner(task_store=task_store, max_concurrent=5)

    def _make_mock_agent(self, name="test_agent", output=None, raise_error=None):
        """Create a mock agent whose ``execute`` returns ``output`` or raises ``raise_error``.

        ``output=None`` selects the default payload ``{"result": "ok"}``; an
        explicitly passed empty dict is kept as-is.
        """
        agent = MagicMock()
        agent.name = name
        agent.agent_type = "test_type"
        if raise_error is not None:
            agent.execute = AsyncMock(side_effect=raise_error)
        else:
            payload = output if output is not None else {"result": "ok"}
            agent.execute = AsyncMock(return_value=_completed_result("mock", name, payload))
        return agent

    @pytest.mark.asyncio
    async def test_submit_returns_task_id_immediately(self, runner, task_store):
        agent = self._make_mock_agent()
        task_id = await runner.submit(agent, {"query": "test"})
        assert task_id is not None
        assert isinstance(task_id, str)
        # Task record should exist in store
        record = task_store.get(task_id)
        assert record is not None
        assert record.status == TaskStatus.PENDING

    @pytest.mark.asyncio
    async def test_submit_task_runs_to_completion(self, runner, task_store):
        agent = self._make_mock_agent(output={"answer": "42"})
        task_id = await runner.submit(agent, {"query": "meaning of life"})

        record = await _wait_for_terminal(task_store, task_id)
        assert record is not None
        assert record.status == TaskStatus.COMPLETED
        assert record.output_data == {"answer": "42"}
        assert record.progress == 1.0

    @pytest.mark.asyncio
    async def test_submit_task_failure_recorded(self, runner, task_store):
        agent = self._make_mock_agent(raise_error=RuntimeError("boom"))
        task_id = await runner.submit(agent, {"query": "fail"})

        record = await _wait_for_terminal(task_store, task_id)
        assert record is not None
        assert record.status == TaskStatus.FAILED
        assert record.error_message is not None
        assert "boom" in record.error_message

    @pytest.mark.asyncio
    async def test_cancel_running_task(self, runner, task_store):
        async def slow_execute(msg):
            await asyncio.sleep(10)  # Long running; the test cancels it before it finishes
            return _completed_result(msg.task_id, "test_agent", {"result": "done"})

        agent = MagicMock()
        agent.name = "slow_agent"
        agent.agent_type = "test_type"
        agent.execute = AsyncMock(side_effect=slow_execute)

        task_id = await runner.submit(agent, {"query": "slow"})
        # Wait until the runner has actually called agent.execute, so we cancel
        # a task that is running rather than one still pending.
        assert await _wait_until(lambda: agent.execute.called)

        cancelled = await runner.cancel(task_id)
        assert cancelled is True
        record = task_store.get(task_id)
        assert record.status == TaskStatus.CANCELLED

    @pytest.mark.asyncio
    async def test_cancel_non_running_task_returns_false(self, runner, task_store):
        result = await runner.cancel("nonexistent")
        assert result is False

    @pytest.mark.asyncio
    async def test_concurrent_tasks_respects_semaphore(self, task_store):
        max_concurrent = 2
        runner = BackgroundRunner(task_store=task_store, max_concurrent=max_concurrent)
        # Track how many executions overlap at any moment.
        in_flight = 0
        peak_in_flight = 0

        async def tracked_execute(msg):
            nonlocal in_flight, peak_in_flight
            in_flight += 1
            peak_in_flight = max(peak_in_flight, in_flight)
            try:
                await asyncio.sleep(0.1)
            finally:
                in_flight -= 1
            return _completed_result(msg.task_id, "test", {})

        task_ids = []
        for idx in range(4):
            agent = MagicMock()
            agent.name = f"agent_{idx}"
            agent.agent_type = "test_type"
            agent.execute = AsyncMock(side_effect=tracked_execute)
            task_ids.append(await runner.submit(agent, {"idx": idx}))

        # All tasks should have completed
        for tid in task_ids:
            record = await _wait_for_terminal(task_store, tid)
            assert record.status == TaskStatus.COMPLETED

        # The semaphore capped overlap at the limit, and tasks really ran in
        # parallel up to that limit (not serialized).
        assert peak_in_flight == max_concurrent

    @pytest.mark.asyncio
    async def test_active_count_tracks_running(self, task_store):
        runner = BackgroundRunner(task_store=task_store, max_concurrent=10)

        async def slow_execute(msg):
            await asyncio.sleep(0.2)
            return _completed_result(msg.task_id, "test", {})

        agent = MagicMock()
        agent.name = "slow_agent"
        agent.agent_type = "test_type"
        agent.execute = AsyncMock(side_effect=slow_execute)

        await runner.submit(agent, {})
        assert await _wait_until(lambda: runner.active_count >= 1)
        # Poll active_count itself (not the task status): the counter may be
        # released slightly after the record is marked completed.
        assert await _wait_until(lambda: runner.active_count == 0)


# ═══════════════════════════════════════════════════════════
# API Tests (using TestClient)
# ═══════════════════════════════════════════════════════════


class TestAsyncTaskAPI:
    """Async task API endpoint tests"""

    @pytest.fixture
    def mock_llm_gateway(self):
        from agentkit.llm.gateway import LLMGateway
        from agentkit.llm.protocol import LLMResponse, TokenUsage

        gateway = LLMGateway()
        mock_provider = AsyncMock()
        mock_provider.chat.return_value = LLMResponse(
            content='{"result": "mocked"}',
            model="test-model",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
        gateway.register_provider("test", mock_provider)
        return gateway

    @pytest.fixture
    def skill_registry(self):
        from agentkit.skills.registry import SkillRegistry

        return SkillRegistry()

    @pytest.fixture
    def tool_registry(self):
        from agentkit.tools.registry import ToolRegistry

        return ToolRegistry()

    @pytest.fixture
    def app(self, mock_llm_gateway, skill_registry, tool_registry):
        from agentkit.server.app import create_app

        return create_app(
            llm_gateway=mock_llm_gateway,
            skill_registry=skill_registry,
            tool_registry=tool_registry,
        )

    @pytest.fixture
    def client(self, app):
        # NOTE(review): used without ``with``, so no lifespan events run and each
        # request may get its own short-lived event loop; background tasks may
        # therefore not progress between requests — which is presumably why the
        # status assertions below accept pending/running. Confirm before tightening.
        return TestClient(app)

    def _register_skill_and_create_agent(self, client, skill_registry):
        """Helper: register a skill and create an agent for it; return the agent name."""
        from agentkit.skills.base import Skill, SkillConfig

        skill_config = SkillConfig(
            name="async_skill",
            agent_type="async_type",
            task_mode="llm_generate",
            prompt={"identity": "Async Skill", "instructions": "Handle async"},
            intent={"keywords": ["async"], "description": "Async skill"},
        )
        skill = Skill(config=skill_config)
        skill_registry.register(skill)

        # Create agent
        resp = client.post(
            "/api/v1/agents",
            json={"skill_name": "async_skill"},
        )
        assert resp.status_code == 201
        return "async_skill"

    def test_submit_task_async_returns_task_id(self, client, skill_registry):
        agent_name = self._register_skill_and_create_agent(client, skill_registry)

        response = client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "async test"},
                "agent_name": agent_name,
                "mode": "async",
            },
        )
        assert response.status_code == 200
        data = response.json()
        assert "task_id" in data
        assert data["status"] == "pending"
        assert data["mode"] == "async"

    def test_get_task_status_returns_record(self, client, skill_registry):
        agent_name = self._register_skill_and_create_agent(client, skill_registry)

        # Submit async task
        submit_resp = client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "status test"},
                "agent_name": agent_name,
                "mode": "async",
            },
        )
        task_id = submit_resp.json()["task_id"]

        # Wait a bit for completion
        time.sleep(0.3)

        # Get status
        response = client.get(f"/api/v1/tasks/{task_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["task_id"] == task_id
        assert data["status"] in ("completed", "running", "pending")

    def test_get_task_status_not_found_404(self, client):
        response = client.get("/api/v1/tasks/nonexistent-id")
        assert response.status_code == 404

    def test_cancel_task(self, client, skill_registry):
        agent_name = self._register_skill_and_create_agent(client, skill_registry)

        # Submit async task
        submit_resp = client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "cancel test"},
                "agent_name": agent_name,
                "mode": "async",
            },
        )
        task_id = submit_resp.json()["task_id"]

        # Try to cancel (may or may not succeed depending on timing)
        response = client.post(f"/api/v1/tasks/{task_id}/cancel")
        # Either cancelled or 400 (already completed)
        assert response.status_code in (200, 400)

    def test_list_tasks(self, client, skill_registry):
        agent_name = self._register_skill_and_create_agent(client, skill_registry)

        # Submit an async task to ensure at least one exists
        client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "list test"},
                "agent_name": agent_name,
                "mode": "async",
            },
        )

        response = client.get("/api/v1/tasks")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)

    def test_list_tasks_filter_by_status(self, client, skill_registry):
        agent_name = self._register_skill_and_create_agent(client, skill_registry)

        # Submit an async task
        client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "filter test"},
                "agent_name": agent_name,
                "mode": "async",
            },
        )

        response = client.get("/api/v1/tasks?status=completed")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # All returned tasks should be completed
        for task in data:
            assert task["status"] == "completed"

    def test_sync_mode_still_works(self, client, skill_registry):
        """Ensure existing sync mode is not broken"""
        agent_name = self._register_skill_and_create_agent(client, skill_registry)

        response = client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "sync test"},
                "agent_name": agent_name,
            },
        )
        assert response.status_code == 200
        data = response.json()
        # Sync mode returns task_id and output
        assert "task_id" in data