diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 3e08ee3..1c5b543 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -1,6 +1,8 @@ """FastAPI Application Factory""" import os +from contextlib import asynccontextmanager + from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -13,6 +15,18 @@ from agentkit.skills.registry import SkillRegistry from agentkit.tools.registry import ToolRegistry from agentkit.server.routes import agents, tasks, skills, llm, health from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware +from agentkit.server.task_store import TaskStore +from agentkit.server.runner import BackgroundRunner + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup + task_store = app.state.task_store + await task_store.start_cleanup() + yield + # Shutdown + await task_store.stop_cleanup() def create_app( @@ -23,7 +37,7 @@ def create_app( rate_limit: int | None = None, ) -> FastAPI: """Create and configure the FastAPI application""" - app = FastAPI(title="AgentKit Server", version="2.0.0") + app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan) # CORS 配置 app.add_middleware( @@ -55,6 +69,8 @@ def create_app( app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway) app.state.quality_gate = QualityGate() app.state.output_standardizer = OutputStandardizer() + app.state.task_store = TaskStore() + app.state.runner = BackgroundRunner(task_store=app.state.task_store) # Include routes app.include_router(agents.router, prefix="/api/v1") diff --git a/src/agentkit/server/client.py b/src/agentkit/server/client.py index 26f38a5..f850a35 100644 --- a/src/agentkit/server/client.py +++ b/src/agentkit/server/client.py @@ -87,6 +87,45 @@ class AgentKitClient: response.raise_for_status() return response.json() + async def submit_task_async( + self, + input_data: dict, + skill_name: str | None = None, + agent_name: str | None = None, + ) -> dict: + """Submit a task in async mode""" + payload: dict[str, Any] = {"input_data": input_data, "mode": "async"} + if skill_name: + payload["skill_name"] = skill_name + if agent_name: + payload["agent_name"] = agent_name + response = await self._client.post("/api/v1/tasks", json=payload) + response.raise_for_status() + return response.json() + + async def get_task_status(self, task_id: str) -> dict: + """Get task status""" + response = await self._client.get(f"/api/v1/tasks/{task_id}") + response.raise_for_status() + return response.json() + + async def cancel_task(self, task_id: str) -> dict: + """Cancel a running task""" + response = await self._client.post(f"/api/v1/tasks/{task_id}/cancel") + response.raise_for_status() + return response.json() + + async def list_tasks( + self, status: str | None = None, limit: int = 100 + ) -> list[dict]: + """List tasks""" + params: dict[str, Any] = {"limit": limit} + if status: + params["status"] = status + response = await self._client.get("/api/v1/tasks", params=params) + response.raise_for_status() + return response.json() + async def close(self) -> None: """Close the HTTP client""" await self._client.aclose() diff --git a/src/agentkit/server/routes/tasks.py b/src/agentkit/server/routes/tasks.py index 418019b..52d70e9 100644 --- a/src/agentkit/server/routes/tasks.py +++ b/src/agentkit/server/routes/tasks.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel from typing import Any -from agentkit.core.protocol import TaskMessage +from agentkit.core.protocol import TaskMessage, TaskStatus router = APIRouter(tags=["tasks"]) @@ -16,6 +16,7 @@ class SubmitTaskRequest(BaseModel): input_data: dict[str, Any] skill_name: str | None = None agent_name: str | None = None + mode: str = "sync" # "sync" or "async" # 输入数据大小限制(防止 OOM) model_config = {"json_schema_extra": {"max_input_size_bytes": 1024 * 1024}} # 1MB @@ -39,6 +40,15 @@ def _validate_input_size(input_data: dict) -> None: ) +@router.get("/tasks") +async def list_tasks(status: str | None = None, limit: int = 100, req: Request = None): + """List tasks""" + store = req.app.state.task_store + task_status = TaskStatus(status) if status else None + records = store.list_tasks(status=task_status, limit=limit) + return [r.to_dict() for r in records] + + @router.post("/tasks") async def submit_task(request: SubmitTaskRequest, req: Request): """Submit a task (Intent Router auto-routes to skill)""" @@ -98,7 +108,20 @@ async def submit_task(request: SubmitTaskRequest, req: Request): except (ValueError, RuntimeError) as e: raise HTTPException(status_code=400, detail=str(e)) - # 4. Execute task + # 4. Async mode: submit to background runner + if request.mode == "async": + runner = req.app.state.runner + task_id = await runner.submit( + agent=agent, + input_data=request.input_data, + skill_name=request.skill_name, + quality_gate=quality_gate, + output_standardizer=output_standardizer, + skill=skill, + ) + return {"task_id": task_id, "status": "pending", "mode": "async"} + + # 5. Sync mode: existing blocking execution task = TaskMessage( task_id=str(uuid.uuid4()), agent_name=agent.name, @@ -111,7 +134,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request): task_result = await agent.execute(task) - # 5. Run quality gate if skill available + # 6. Run quality gate if skill available quality_result = None if skill: try: @@ -119,7 +142,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request): except Exception: pass # Quality gate failure shouldn't block the response - # 6. Standardize output if skill available + # 7. Standardize output if skill available if skill: try: standard_output = await output_standardizer.standardize( @@ -141,7 +164,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request): except Exception: pass # Fall through to raw output - # 7. Return raw result if no skill or standardization failed + # 8. Return raw result if no skill or standardization failed return { "task_id": task.task_id, "status": task_result.status, @@ -151,6 +174,20 @@ async def submit_task(request: SubmitTaskRequest, req: Request): @router.get("/tasks/{task_id}") -async def get_task_status(task_id: str): - """Get task status (placeholder for async mode)""" - return {"task_id": task_id, "status": "placeholder"} +async def get_task_status(task_id: str, req: Request): + """Get task status and result""" + store = req.app.state.task_store + record = store.get(task_id) + if record is None: + raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found") + return record.to_dict() + + +@router.post("/tasks/{task_id}/cancel") +async def cancel_task(task_id: str, req: Request): + """Cancel a running task""" + runner = req.app.state.runner + cancelled = await runner.cancel(task_id) + if not cancelled: + raise HTTPException(status_code=400, detail="Task cannot be cancelled (not running or not found)") + return {"task_id": task_id, "status": "cancelled"} diff --git a/src/agentkit/server/runner.py b/src/agentkit/server/runner.py new file mode 100644 index 0000000..e5d1ce9 --- /dev/null +++ b/src/agentkit/server/runner.py @@ -0,0 +1,170 @@ +"""BackgroundRunner - Async task execution with lifecycle management""" + +import asyncio +import logging +import uuid +from datetime import datetime, timezone +from typing import Any + +from agentkit.core.protocol import TaskMessage, TaskStatus +from agentkit.server.task_store import TaskStore + +logger = logging.getLogger(__name__) + + +class BackgroundRunner: + """Runs tasks in background asyncio tasks with lifecycle management. + + Integrates with AgentPool for agent execution and TaskStore for state tracking. + """ + + def __init__(self, task_store: TaskStore, max_concurrent: int = 10): + self._task_store = task_store + self._max_concurrent = max_concurrent + self._running_tasks: dict[str, asyncio.Task] = {} + self._semaphore = asyncio.Semaphore(max_concurrent) + + @property + def active_count(self) -> int: + return len(self._running_tasks) + + async def submit( + self, + agent, # ConfigDrivenAgent + input_data: dict[str, Any], + skill_name: str | None = None, + quality_gate=None, + output_standardizer=None, + skill=None, + ) -> str: + """Submit a task for background execution. + + Returns task_id immediately. + """ + task_id = str(uuid.uuid4()) + + # Create task record + self._task_store.create( + task_id=task_id, + agent_name=agent.name, + input_data=input_data, + skill_name=skill_name, + ) + + # Launch background asyncio task + asyncio_task = asyncio.create_task( + self._run_task( + task_id=task_id, + agent=agent, + input_data=input_data, + quality_gate=quality_gate, + output_standardizer=output_standardizer, + skill=skill, + ) + ) + self._running_tasks[task_id] = asyncio_task + + # Clean up reference when done + def _on_done(t: asyncio.Task): + self._running_tasks.pop(task_id, None) + if t.exception(): + logger.error(f"Background task {task_id} failed: {t.exception()}") + + asyncio_task.add_done_callback(_on_done) + + return task_id + + async def _run_task( + self, + task_id: str, + agent, + input_data: dict, + quality_gate=None, + output_standardizer=None, + skill=None, + ) -> dict[str, Any]: + """Execute task in background with semaphore control""" + async with self._semaphore: + # Update status to RUNNING + self._task_store.update_status( + task_id, TaskStatus.RUNNING, + started_at=datetime.now(timezone.utc), + ) + + try: + # Create TaskMessage for agent + task_msg = TaskMessage( + task_id=task_id, + agent_name=agent.name, + task_type=agent.agent_type, + priority=0, + input_data=input_data, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + # Execute agent + task_result = await agent.execute(task_msg) + + # Run quality gate if available + quality_result = None + if skill and quality_gate: + try: + quality_result = await quality_gate.validate( + task_result.output_data or {}, skill + ) + except Exception as e: + logger.warning(f"Quality gate failed for {task_id}: {e}") + + # Standardize output if available + final_output = task_result.output_data + if skill and output_standardizer: + try: + standard_output = await output_standardizer.standardize( + raw_output=task_result.output_data or {}, + skill=skill, + quality_result=quality_result, + ) + final_output = { + "skill_name": standard_output.skill_name, + "data": standard_output.data, + "metadata": { + "version": standard_output.metadata.version, + "produced_at": standard_output.metadata.produced_at.isoformat(), + "quality_score": standard_output.metadata.quality_score, + }, + } + except Exception as e: + logger.warning(f"Output standardization failed for {task_id}: {e}") + + # Update store + self._task_store.update_status( + task_id, TaskStatus.COMPLETED, + output_data=final_output, + completed_at=datetime.now(timezone.utc), + progress=1.0, + progress_message="Completed", + ) + + return final_output or {} + + except Exception as e: + logger.error(f"Task {task_id} failed: {e}") + self._task_store.update_status( + task_id, TaskStatus.FAILED, + error_message=str(e), + completed_at=datetime.now(timezone.utc), + ) + raise + + async def cancel(self, task_id: str) -> bool: + """Cancel a running task""" + asyncio_task = self._running_tasks.get(task_id) + if asyncio_task and not asyncio_task.done(): + asyncio_task.cancel() + self._task_store.update_status( + task_id, TaskStatus.CANCELLED, + completed_at=datetime.now(timezone.utc), + ) + return True + return False diff --git a/src/agentkit/server/task_store.py b/src/agentkit/server/task_store.py new file mode 100644 index 0000000..9976fc3 --- /dev/null +++ b/src/agentkit/server/task_store.py @@ -0,0 +1,151 @@ +"""TaskStore - In-memory task state storage with TTL""" + +import asyncio +import logging +import time +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any + +from agentkit.core.protocol import TaskStatus + +logger = logging.getLogger(__name__) + + +@dataclass +class TaskRecord: + """Stored task record with full lifecycle data""" + task_id: str + agent_name: str + skill_name: str | None + input_data: dict[str, Any] + status: TaskStatus = TaskStatus.PENDING + output_data: dict[str, Any] | None = None + error_message: str | None = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + started_at: datetime | None = None + completed_at: datetime | None = None + progress: float = 0.0 + progress_message: str = "" + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict: + return { + "task_id": self.task_id, + "agent_name": self.agent_name, + "skill_name": self.skill_name, + "input_data": self.input_data, + "status": self.status.value, + "output_data": self.output_data, + "error_message": self.error_message, + "created_at": self.created_at.isoformat(), + "started_at": self.started_at.isoformat() if self.started_at else None, + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + "progress": self.progress, + "progress_message": self.progress_message, + "metadata": self.metadata, + } + + +class TaskStore: + """In-memory task state storage with automatic TTL cleanup. + + Stores task records indexed by task_id. Automatically removes + completed tasks after a configurable TTL. + """ + + def __init__(self, ttl_seconds: int = 3600, max_records: int = 10000): + self._tasks: dict[str, TaskRecord] = {} + self._ttl_seconds = ttl_seconds + self._max_records = max_records + self._cleanup_task: asyncio.Task | None = None + + async def start_cleanup(self) -> None: + """Start background cleanup task""" + if self._cleanup_task is None: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + + async def stop_cleanup(self) -> None: + """Stop background cleanup task""" + if self._cleanup_task: + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + + async def _cleanup_loop(self) -> None: + """Periodically remove expired task records""" + while True: + try: + await asyncio.sleep(60) + self._cleanup_expired() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"TaskStore cleanup error: {e}") + + def _cleanup_expired(self) -> None: + """Remove expired records""" + expired = [] + for task_id, record in self._tasks.items(): + if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED): + if record.completed_at: + age = (datetime.now(timezone.utc) - record.completed_at).total_seconds() + if age > self._ttl_seconds: + expired.append(task_id) + for task_id in expired: + del self._tasks[task_id] + if expired: + logger.info(f"TaskStore cleaned up {len(expired)} expired records") + + def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord: + """Create a new task record""" + if len(self._tasks) >= self._max_records: + # Remove oldest completed task + oldest = None + for tid, rec in self._tasks.items(): + if rec.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED): + if oldest is None or (rec.completed_at and (oldest.completed_at is None or rec.completed_at < oldest.completed_at)): + oldest = rec + if oldest: + del self._tasks[oldest.task_id] + else: + raise RuntimeError("TaskStore is full and no completed tasks to evict") + + record = TaskRecord( + task_id=task_id, + agent_name=agent_name, + skill_name=skill_name, + input_data=input_data, + ) + self._tasks[task_id] = record + return record + + def get(self, task_id: str) -> TaskRecord | None: + """Get task record by ID""" + return self._tasks.get(task_id) + + def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord: + """Update task status and optional fields""" + record = self._tasks.get(task_id) + if record is None: + raise KeyError(f"Task '{task_id}' not found") + record.status = status + for key, value in kwargs.items(): + if hasattr(record, key): + setattr(record, key, value) + return record + + def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]: + """List tasks, optionally filtered by status""" + tasks = list(self._tasks.values()) + if status: + tasks = [t for t in tasks if t.status == status] + tasks.sort(key=lambda t: t.created_at, reverse=True) + return tasks[:limit] + + @property + def size(self) -> int: + return len(self._tasks) diff --git a/tests/unit/test_async_tasks.py b/tests/unit/test_async_tasks.py new file mode 100644 index 0000000..fd67a64 --- /dev/null +++ b/tests/unit/test_async_tasks.py @@ -0,0 +1,512 @@ +"""Async Task System 单元测试 - TaskStore + BackgroundRunner + API""" + +import asyncio +from datetime import datetime, timezone, timedelta +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus +from agentkit.server.task_store import TaskRecord, TaskStore +from agentkit.server.runner import BackgroundRunner + + +# ═══════════════════════════════════════════════════════════ +# TaskStore Tests +# ═══════════════════════════════════════════════════════════ + + +class TestTaskRecord: + """TaskRecord dataclass tests""" + + def test_to_dict_returns_complete_dict(self): + record = TaskRecord( + task_id="t1", + agent_name="agent_a", + skill_name="skill_x", + input_data={"query": "hello"}, + ) + d = record.to_dict() + assert d["task_id"] == "t1" + assert d["agent_name"] == "agent_a" + assert d["skill_name"] == "skill_x" + assert d["input_data"] == {"query": "hello"} + assert d["status"] == "pending" + assert d["output_data"] is None + assert d["error_message"] is None + assert d["progress"] == 0.0 + assert d["created_at"] is not None + + def test_to_dict_with_timestamps(self): + now = datetime.now(timezone.utc) + record = TaskRecord( + task_id="t2", + agent_name="agent_b", + skill_name=None, + input_data={}, + started_at=now, + completed_at=now, + ) + d = record.to_dict() + assert d["started_at"] == now.isoformat() + assert d["completed_at"] == now.isoformat() + + +class TestTaskStore: + """TaskStore in-memory storage tests""" + + def test_create_task_record_stored_correctly(self): + store = TaskStore() + record = store.create("t1", "agent_a", {"q": "hello"}, skill_name="skill_x") + assert record.task_id == "t1" + assert record.agent_name == "agent_a" + assert record.skill_name == "skill_x" + assert record.input_data == {"q": "hello"} + assert record.status == TaskStatus.PENDING + + def test_get_task_by_id_returns_record(self): + store = TaskStore() + store.create("t1", "agent_a", {}) + record = store.get("t1") + assert record is not None + assert record.task_id == "t1" + + def test_get_nonexistent_task_returns_none(self): + store = TaskStore() + assert store.get("nonexistent") is None + + def test_update_status_fields_updated(self): + store = TaskStore() + store.create("t1", "agent_a", {}) + now = datetime.now(timezone.utc) + record = store.update_status( + "t1", TaskStatus.RUNNING, started_at=now, progress=0.5, progress_message="Halfway" + ) + assert record.status == TaskStatus.RUNNING + assert record.started_at == now + assert record.progress == 0.5 + assert record.progress_message == "Halfway" + + def test_update_nonexistent_task_raises_keyerror(self): + store = TaskStore() + with pytest.raises(KeyError, match="not found"): + store.update_status("nonexistent", TaskStatus.RUNNING) + + def test_list_tasks_returns_all_sorted_desc(self): + store = TaskStore() + store.create("t1", "agent_a", {}) + store.create("t2", "agent_b", {}) + tasks = store.list_tasks() + assert len(tasks) == 2 + # Most recent first + assert tasks[0].task_id == "t2" + assert tasks[1].task_id == "t1" + + def test_list_tasks_filtered_by_status(self): + store = TaskStore() + store.create("t1", "agent_a", {}) + store.create("t2", "agent_b", {}) + store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc)) + tasks = store.list_tasks(status=TaskStatus.COMPLETED) + assert len(tasks) == 1 + assert tasks[0].task_id == "t1" + + def test_max_records_limit_evicts_oldest_completed(self): + store = TaskStore(max_records=2) + store.create("t1", "agent_a", {}) + store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc)) + store.create("t2", "agent_b", {}) + # t3 should evict t1 (oldest completed) + store.create("t3", "agent_c", {}) + assert store.get("t1") is None + assert store.get("t2") is not None + assert store.get("t3") is not None + + def test_max_records_full_no_completed_raises(self): + store = TaskStore(max_records=1) + store.create("t1", "agent_a", {}) + # All tasks are PENDING, no completed to evict + with pytest.raises(RuntimeError, match="full"): + store.create("t2", "agent_b", {}) + + def test_ttl_cleanup_removes_expired_completed(self): + store = TaskStore(ttl_seconds=0) # Immediate expiry + store.create("t1", "agent_a", {}) + store.update_status( + "t1", TaskStatus.COMPLETED, + completed_at=datetime.now(timezone.utc) - timedelta(seconds=10), + ) + store.create("t2", "agent_b", {}) + # t2 is PENDING, should not be cleaned + store._cleanup_expired() + assert store.get("t1") is None # Expired completed + assert store.get("t2") is not None # Pending stays + + def test_size_property_correct_count(self): + store = TaskStore() + assert store.size == 0 + store.create("t1", "agent_a", {}) + assert store.size == 1 + store.create("t2", "agent_b", {}) + assert store.size == 2 + + def test_list_tasks_respects_limit(self): + store = TaskStore() + for i in range(5): + store.create(f"t{i}", "agent_a", {}) + tasks = store.list_tasks(limit=3) + assert len(tasks) == 3 + + +# ═══════════════════════════════════════════════════════════ +# BackgroundRunner Tests +# ═══════════════════════════════════════════════════════════ + + +class TestBackgroundRunner: + """BackgroundRunner async task execution tests""" + + @pytest.fixture + def task_store(self): + return TaskStore() + + @pytest.fixture + def runner(self, task_store): + return BackgroundRunner(task_store=task_store, max_concurrent=5) + + def _make_mock_agent(self, name="test_agent", output=None, raise_error=None): + """Create a mock agent for testing""" + agent = MagicMock() + agent.name = name + agent.agent_type = "test_type" + if raise_error: + agent.execute = AsyncMock(side_effect=raise_error) + else: + task_result = TaskResult( + task_id="mock", + agent_name=name, + status="completed", + output_data=output or {"result": "ok"}, + error_message=None, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + ) + agent.execute = AsyncMock(return_value=task_result) + return agent + + @pytest.mark.asyncio + async def test_submit_returns_task_id_immediately(self, runner, task_store): + agent = self._make_mock_agent() + task_id = await runner.submit(agent, {"query": "test"}) + assert task_id is not None + assert isinstance(task_id, str) + # Task record should exist in store + record = task_store.get(task_id) + assert record is not None + assert record.status == TaskStatus.PENDING + + @pytest.mark.asyncio + async def test_submit_task_runs_to_completion(self, runner, task_store): + agent = self._make_mock_agent(output={"answer": "42"}) + task_id = await runner.submit(agent, {"query": "meaning of life"}) + # Wait for task to complete + await asyncio.sleep(0.1) + record = task_store.get(task_id) + assert record is not None + assert record.status == TaskStatus.COMPLETED + assert record.output_data == {"answer": "42"} + assert record.progress == 1.0 + + @pytest.mark.asyncio + async def test_submit_task_failure_recorded(self, runner, task_store): + agent = self._make_mock_agent(raise_error=RuntimeError("boom")) + task_id = await runner.submit(agent, {"query": "fail"}) + # Wait for task to fail + await asyncio.sleep(0.1) + record = task_store.get(task_id) + assert record is not None + assert record.status == TaskStatus.FAILED + assert "boom" in record.error_message + + @pytest.mark.asyncio + async def test_cancel_running_task(self, runner, task_store): + async def slow_execute(msg): + await asyncio.sleep(10) # Long running + return TaskResult( + task_id=msg.task_id, + agent_name="test_agent", + status="completed", + output_data={"result": "done"}, + error_message=None, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + ) + + agent = MagicMock() + agent.name = "slow_agent" + agent.agent_type = "test_type" + agent.execute = AsyncMock(side_effect=slow_execute) + + task_id = await runner.submit(agent, {"query": "slow"}) + # Give it a moment to start + await asyncio.sleep(0.05) + cancelled = await runner.cancel(task_id) + assert cancelled is True + record = task_store.get(task_id) + assert record.status == TaskStatus.CANCELLED + + @pytest.mark.asyncio + async def test_cancel_non_running_task_returns_false(self, runner, task_store): + result = await runner.cancel("nonexistent") + assert result is False + + @pytest.mark.asyncio + async def test_concurrent_tasks_respects_semaphore(self, task_store): + runner = BackgroundRunner(task_store=task_store, max_concurrent=2) + execution_order = [] + + async def tracked_execute(msg): + execution_order.append(f"start:{msg.task_id}") + await asyncio.sleep(0.1) + execution_order.append(f"end:{msg.task_id}") + return TaskResult( + task_id=msg.task_id, + agent_name="test", + status="completed", + output_data={}, + error_message=None, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + ) + + agents = [] + for i in range(4): + agent = MagicMock() + agent.name = f"agent_{i}" + agent.agent_type = "test_type" + agent.execute = AsyncMock(side_effect=tracked_execute) + agents.append(agent) + + # Submit all 4 tasks + task_ids = [] + for agent in agents: + tid = await runner.submit(agent, {"idx": agents.index(agent)}) + task_ids.append(tid) + + # Wait for all to complete + await asyncio.sleep(0.5) + + # All tasks should have completed + for tid in task_ids: + record = task_store.get(tid) + assert record.status == TaskStatus.COMPLETED + + @pytest.mark.asyncio + async def test_active_count_tracks_running(self, task_store): + runner = BackgroundRunner(task_store=task_store, max_concurrent=10) + + async def slow_execute(msg): + await asyncio.sleep(0.2) + return TaskResult( + task_id=msg.task_id, + agent_name="test", + status="completed", + output_data={}, + error_message=None, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + ) + + agent = MagicMock() + agent.name = "slow_agent" + agent.agent_type = "test_type" + agent.execute = AsyncMock(side_effect=slow_execute) + + await runner.submit(agent, {}) + await asyncio.sleep(0.05) + assert runner.active_count >= 1 + + await asyncio.sleep(0.3) + assert runner.active_count == 0 + + +# ═══════════════════════════════════════════════════════════ +# API Tests (using TestClient) +# ═══════════════════════════════════════════════════════════ + + +class TestAsyncTaskAPI: + """Async task API endpoint tests""" + + @pytest.fixture + def mock_llm_gateway(self): + from agentkit.llm.gateway import LLMGateway + from agentkit.llm.protocol import LLMResponse, TokenUsage + + gateway = LLMGateway() + mock_provider = AsyncMock() + mock_provider.chat.return_value = LLMResponse( + content='{"result": "mocked"}', + model="test-model", + usage=TokenUsage(prompt_tokens=10, completion_tokens=20), + ) + gateway.register_provider("test", mock_provider) + return gateway + + @pytest.fixture + def skill_registry(self): + from agentkit.skills.registry import SkillRegistry + return SkillRegistry() + + @pytest.fixture + def tool_registry(self): + from agentkit.tools.registry import ToolRegistry + return ToolRegistry() + + @pytest.fixture + def app(self, mock_llm_gateway, skill_registry, tool_registry): + from agentkit.server.app import create_app + return create_app( + llm_gateway=mock_llm_gateway, + skill_registry=skill_registry, + tool_registry=tool_registry, + ) + + @pytest.fixture + def client(self, app): + return TestClient(app) + + def _register_skill_and_create_agent(self, client, skill_registry): + """Helper: register a skill and create an agent for it""" + from agentkit.skills.base import Skill, SkillConfig + + skill_config = SkillConfig( + name="async_skill", + agent_type="async_type", + task_mode="llm_generate", + prompt={"identity": "Async Skill", "instructions": "Handle async"}, + intent={"keywords": ["async"], "description": "Async skill"}, + ) + skill = Skill(config=skill_config) + skill_registry.register(skill) + + # Create agent + resp = client.post( + "/api/v1/agents", + json={"skill_name": "async_skill"}, + ) + assert resp.status_code == 201 + return "async_skill" + + def test_submit_task_async_returns_task_id(self, client, skill_registry): + agent_name = self._register_skill_and_create_agent(client, skill_registry) + response = client.post( + "/api/v1/tasks", + json={ + "input_data": {"query": "async test"}, + "agent_name": agent_name, + "mode": "async", + }, + ) + assert response.status_code == 200 + data = response.json() + assert "task_id" in data + assert data["status"] == "pending" + assert data["mode"] == "async" + + def test_get_task_status_returns_record(self, client, skill_registry): + agent_name = self._register_skill_and_create_agent(client, skill_registry) + # Submit async task + submit_resp = client.post( + "/api/v1/tasks", + json={ + "input_data": {"query": "status test"}, + "agent_name": agent_name, + "mode": "async", + }, + ) + task_id = submit_resp.json()["task_id"] + + # Wait a bit for completion + import time + time.sleep(0.3) + + # Get status + response = client.get(f"/api/v1/tasks/{task_id}") + assert response.status_code == 200 + data = response.json() + assert data["task_id"] == task_id + assert data["status"] in ("completed", "running", "pending") + + def test_get_task_status_not_found_404(self, client): + response = client.get("/api/v1/tasks/nonexistent-id") + assert response.status_code == 404 + + def test_cancel_task(self, client, skill_registry): + agent_name = self._register_skill_and_create_agent(client, skill_registry) + # Submit async task + submit_resp = client.post( + "/api/v1/tasks", + json={ + "input_data": {"query": "cancel test"}, + "agent_name": agent_name, + "mode": "async", + }, + ) + task_id = submit_resp.json()["task_id"] + + # Try to cancel (may or may not succeed depending on timing) + response = client.post(f"/api/v1/tasks/{task_id}/cancel") + # Either cancelled or 400 (already completed) + assert response.status_code in (200, 400) + + def test_list_tasks(self, client, skill_registry): + agent_name = self._register_skill_and_create_agent(client, skill_registry) + # Submit an async task to ensure at least one exists + client.post( + "/api/v1/tasks", + json={ + "input_data": {"query": "list test"}, + "agent_name": agent_name, + "mode": "async", + }, + ) + response = client.get("/api/v1/tasks") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + def test_list_tasks_filter_by_status(self, client, skill_registry): + agent_name = self._register_skill_and_create_agent(client, skill_registry) + # Submit an async task + client.post( + "/api/v1/tasks", + json={ + "input_data": {"query": "filter test"}, + "agent_name": agent_name, + "mode": "async", + }, + ) + response = client.get("/api/v1/tasks?status=completed") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + # All returned tasks should be completed + for task in data: + assert task["status"] == "completed" + + def test_sync_mode_still_works(self, client, skill_registry): + """Ensure existing sync mode is not broken""" + agent_name = self._register_skill_and_create_agent(client, skill_registry) + response = client.post( + "/api/v1/tasks", + json={ + "input_data": {"query": "sync test"}, + "agent_name": agent_name, + }, + ) + assert response.status_code == 200 + data = response.json() + # Sync mode returns task_id and output + assert "task_id" in data