feat(server): Phase D - async task system (TaskStore + BackgroundRunner + API)

U5: TaskStore - in-memory task state with TTL cleanup and max records
U6: BackgroundRunner - async task execution with semaphore concurrency control
U7: Task status/result API + cancel endpoint + async submit mode

45 tests passing (28 new + 17 existing, no regression).
This commit is contained in:
chiguyong 2026-06-06 11:39:41 +08:00
parent 5f1c51cf9a
commit ec0e221beb
6 changed files with 934 additions and 9 deletions

View File

@ -1,6 +1,8 @@
"""FastAPI Application Factory"""
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
@ -13,6 +15,18 @@ from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
from agentkit.server.routes import agents, tasks, skills, llm, health
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
from agentkit.server.task_store import TaskStore
from agentkit.server.runner import BackgroundRunner
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
task_store = app.state.task_store
await task_store.start_cleanup()
yield
# Shutdown
await task_store.stop_cleanup()
def create_app(
@ -23,7 +37,7 @@ def create_app(
rate_limit: int | None = None,
) -> FastAPI:
"""Create and configure the FastAPI application"""
app = FastAPI(title="AgentKit Server", version="2.0.0")
app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)
# CORS 配置
app.add_middleware(
@ -55,6 +69,8 @@ def create_app(
app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway)
app.state.quality_gate = QualityGate()
app.state.output_standardizer = OutputStandardizer()
app.state.task_store = TaskStore()
app.state.runner = BackgroundRunner(task_store=app.state.task_store)
# Include routes
app.include_router(agents.router, prefix="/api/v1")

View File

@ -87,6 +87,45 @@ class AgentKitClient:
response.raise_for_status()
return response.json()
async def submit_task_async(
self,
input_data: dict,
skill_name: str | None = None,
agent_name: str | None = None,
) -> dict:
"""Submit a task in async mode"""
payload: dict[str, Any] = {"input_data": input_data, "mode": "async"}
if skill_name:
payload["skill_name"] = skill_name
if agent_name:
payload["agent_name"] = agent_name
response = await self._client.post("/api/v1/tasks", json=payload)
response.raise_for_status()
return response.json()
async def get_task_status(self, task_id: str) -> dict:
"""Get task status"""
response = await self._client.get(f"/api/v1/tasks/{task_id}")
response.raise_for_status()
return response.json()
async def cancel_task(self, task_id: str) -> dict:
"""Cancel a running task"""
response = await self._client.post(f"/api/v1/tasks/{task_id}/cancel")
response.raise_for_status()
return response.json()
async def list_tasks(
self, status: str | None = None, limit: int = 100
) -> list[dict]:
"""List tasks"""
params: dict[str, Any] = {"limit": limit}
if status:
params["status"] = status
response = await self._client.get("/api/v1/tasks", params=params)
response.raise_for_status()
return response.json()
async def close(self) -> None:
"""Close the HTTP client"""
await self._client.aclose()

View File

@ -7,7 +7,7 @@ from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from typing import Any
from agentkit.core.protocol import TaskMessage
from agentkit.core.protocol import TaskMessage, TaskStatus
router = APIRouter(tags=["tasks"])
@ -16,6 +16,7 @@ class SubmitTaskRequest(BaseModel):
input_data: dict[str, Any]
skill_name: str | None = None
agent_name: str | None = None
mode: str = "sync" # "sync" or "async"
# 输入数据大小限制(防止 OOM
model_config = {"json_schema_extra": {"max_input_size_bytes": 1024 * 1024}} # 1MB
@ -39,6 +40,15 @@ def _validate_input_size(input_data: dict) -> None:
)
@router.get("/tasks")
async def list_tasks(status: str | None = None, limit: int = 100, req: Request = None):
"""List tasks"""
store = req.app.state.task_store
task_status = TaskStatus(status) if status else None
records = store.list_tasks(status=task_status, limit=limit)
return [r.to_dict() for r in records]
@router.post("/tasks")
async def submit_task(request: SubmitTaskRequest, req: Request):
"""Submit a task (Intent Router auto-routes to skill)"""
@ -98,7 +108,20 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
except (ValueError, RuntimeError) as e:
raise HTTPException(status_code=400, detail=str(e))
# 4. Execute task
# 4. Async mode: submit to background runner
if request.mode == "async":
runner = req.app.state.runner
task_id = await runner.submit(
agent=agent,
input_data=request.input_data,
skill_name=request.skill_name,
quality_gate=quality_gate,
output_standardizer=output_standardizer,
skill=skill,
)
return {"task_id": task_id, "status": "pending", "mode": "async"}
# 5. Sync mode: existing blocking execution
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name=agent.name,
@ -111,7 +134,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
task_result = await agent.execute(task)
# 5. Run quality gate if skill available
# 6. Run quality gate if skill available
quality_result = None
if skill:
try:
@ -119,7 +142,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
except Exception:
pass # Quality gate failure shouldn't block the response
# 6. Standardize output if skill available
# 7. Standardize output if skill available
if skill:
try:
standard_output = await output_standardizer.standardize(
@ -141,7 +164,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
except Exception:
pass # Fall through to raw output
# 7. Return raw result if no skill or standardization failed
# 8. Return raw result if no skill or standardization failed
return {
"task_id": task.task_id,
"status": task_result.status,
@ -151,6 +174,20 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
"""Get task status (placeholder for async mode)"""
return {"task_id": task_id, "status": "placeholder"}
async def get_task_status(task_id: str, req: Request):
"""Get task status and result"""
store = req.app.state.task_store
record = store.get(task_id)
if record is None:
raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")
return record.to_dict()
@router.post("/tasks/{task_id}/cancel")
async def cancel_task(task_id: str, req: Request):
"""Cancel a running task"""
runner = req.app.state.runner
cancelled = await runner.cancel(task_id)
if not cancelled:
raise HTTPException(status_code=400, detail="Task cannot be cancelled (not running or not found)")
return {"task_id": task_id, "status": "cancelled"}

View File

@ -0,0 +1,170 @@
"""BackgroundRunner - Async task execution with lifecycle management"""
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
from agentkit.core.protocol import TaskMessage, TaskStatus
from agentkit.server.task_store import TaskStore
logger = logging.getLogger(__name__)
class BackgroundRunner:
"""Runs tasks in background asyncio tasks with lifecycle management.
Integrates with AgentPool for agent execution and TaskStore for state tracking.
"""
def __init__(self, task_store: TaskStore, max_concurrent: int = 10):
self._task_store = task_store
self._max_concurrent = max_concurrent
self._running_tasks: dict[str, asyncio.Task] = {}
self._semaphore = asyncio.Semaphore(max_concurrent)
@property
def active_count(self) -> int:
return len(self._running_tasks)
async def submit(
self,
agent, # ConfigDrivenAgent
input_data: dict[str, Any],
skill_name: str | None = None,
quality_gate=None,
output_standardizer=None,
skill=None,
) -> str:
"""Submit a task for background execution.
Returns task_id immediately.
"""
task_id = str(uuid.uuid4())
# Create task record
self._task_store.create(
task_id=task_id,
agent_name=agent.name,
input_data=input_data,
skill_name=skill_name,
)
# Launch background asyncio task
asyncio_task = asyncio.create_task(
self._run_task(
task_id=task_id,
agent=agent,
input_data=input_data,
quality_gate=quality_gate,
output_standardizer=output_standardizer,
skill=skill,
)
)
self._running_tasks[task_id] = asyncio_task
# Clean up reference when done
def _on_done(t: asyncio.Task):
self._running_tasks.pop(task_id, None)
if t.exception():
logger.error(f"Background task {task_id} failed: {t.exception()}")
asyncio_task.add_done_callback(_on_done)
return task_id
async def _run_task(
self,
task_id: str,
agent,
input_data: dict,
quality_gate=None,
output_standardizer=None,
skill=None,
) -> dict[str, Any]:
"""Execute task in background with semaphore control"""
async with self._semaphore:
# Update status to RUNNING
self._task_store.update_status(
task_id, TaskStatus.RUNNING,
started_at=datetime.now(timezone.utc),
)
try:
# Create TaskMessage for agent
task_msg = TaskMessage(
task_id=task_id,
agent_name=agent.name,
task_type=agent.agent_type,
priority=0,
input_data=input_data,
callback_url=None,
created_at=datetime.now(timezone.utc),
)
# Execute agent
task_result = await agent.execute(task_msg)
# Run quality gate if available
quality_result = None
if skill and quality_gate:
try:
quality_result = await quality_gate.validate(
task_result.output_data or {}, skill
)
except Exception as e:
logger.warning(f"Quality gate failed for {task_id}: {e}")
# Standardize output if available
final_output = task_result.output_data
if skill and output_standardizer:
try:
standard_output = await output_standardizer.standardize(
raw_output=task_result.output_data or {},
skill=skill,
quality_result=quality_result,
)
final_output = {
"skill_name": standard_output.skill_name,
"data": standard_output.data,
"metadata": {
"version": standard_output.metadata.version,
"produced_at": standard_output.metadata.produced_at.isoformat(),
"quality_score": standard_output.metadata.quality_score,
},
}
except Exception as e:
logger.warning(f"Output standardization failed for {task_id}: {e}")
# Update store
self._task_store.update_status(
task_id, TaskStatus.COMPLETED,
output_data=final_output,
completed_at=datetime.now(timezone.utc),
progress=1.0,
progress_message="Completed",
)
return final_output or {}
except Exception as e:
logger.error(f"Task {task_id} failed: {e}")
self._task_store.update_status(
task_id, TaskStatus.FAILED,
error_message=str(e),
completed_at=datetime.now(timezone.utc),
)
raise
async def cancel(self, task_id: str) -> bool:
"""Cancel a running task"""
asyncio_task = self._running_tasks.get(task_id)
if asyncio_task and not asyncio_task.done():
asyncio_task.cancel()
self._task_store.update_status(
task_id, TaskStatus.CANCELLED,
completed_at=datetime.now(timezone.utc),
)
return True
return False

View File

@ -0,0 +1,151 @@
"""TaskStore - In-memory task state storage with TTL"""
import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
from agentkit.core.protocol import TaskStatus
logger = logging.getLogger(__name__)
@dataclass
class TaskRecord:
"""Stored task record with full lifecycle data"""
task_id: str
agent_name: str
skill_name: str | None
input_data: dict[str, Any]
status: TaskStatus = TaskStatus.PENDING
output_data: dict[str, Any] | None = None
error_message: str | None = None
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
started_at: datetime | None = None
completed_at: datetime | None = None
progress: float = 0.0
progress_message: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict:
return {
"task_id": self.task_id,
"agent_name": self.agent_name,
"skill_name": self.skill_name,
"input_data": self.input_data,
"status": self.status.value,
"output_data": self.output_data,
"error_message": self.error_message,
"created_at": self.created_at.isoformat(),
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"progress": self.progress,
"progress_message": self.progress_message,
"metadata": self.metadata,
}
class TaskStore:
"""In-memory task state storage with automatic TTL cleanup.
Stores task records indexed by task_id. Automatically removes
completed tasks after a configurable TTL.
"""
def __init__(self, ttl_seconds: int = 3600, max_records: int = 10000):
self._tasks: dict[str, TaskRecord] = {}
self._ttl_seconds = ttl_seconds
self._max_records = max_records
self._cleanup_task: asyncio.Task | None = None
async def start_cleanup(self) -> None:
"""Start background cleanup task"""
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
async def stop_cleanup(self) -> None:
"""Stop background cleanup task"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
async def _cleanup_loop(self) -> None:
"""Periodically remove expired task records"""
while True:
try:
await asyncio.sleep(60)
self._cleanup_expired()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"TaskStore cleanup error: {e}")
def _cleanup_expired(self) -> None:
"""Remove expired records"""
expired = []
for task_id, record in self._tasks.items():
if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
if record.completed_at:
age = (datetime.now(timezone.utc) - record.completed_at).total_seconds()
if age > self._ttl_seconds:
expired.append(task_id)
for task_id in expired:
del self._tasks[task_id]
if expired:
logger.info(f"TaskStore cleaned up {len(expired)} expired records")
def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord:
"""Create a new task record"""
if len(self._tasks) >= self._max_records:
# Remove oldest completed task
oldest = None
for tid, rec in self._tasks.items():
if rec.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
if oldest is None or (rec.completed_at and (oldest.completed_at is None or rec.completed_at < oldest.completed_at)):
oldest = rec
if oldest:
del self._tasks[oldest.task_id]
else:
raise RuntimeError("TaskStore is full and no completed tasks to evict")
record = TaskRecord(
task_id=task_id,
agent_name=agent_name,
skill_name=skill_name,
input_data=input_data,
)
self._tasks[task_id] = record
return record
def get(self, task_id: str) -> TaskRecord | None:
"""Get task record by ID"""
return self._tasks.get(task_id)
def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord:
"""Update task status and optional fields"""
record = self._tasks.get(task_id)
if record is None:
raise KeyError(f"Task '{task_id}' not found")
record.status = status
for key, value in kwargs.items():
if hasattr(record, key):
setattr(record, key, value)
return record
def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]:
"""List tasks, optionally filtered by status"""
tasks = list(self._tasks.values())
if status:
tasks = [t for t in tasks if t.status == status]
tasks.sort(key=lambda t: t.created_at, reverse=True)
return tasks[:limit]
@property
def size(self) -> int:
return len(self._tasks)

View File

@ -0,0 +1,512 @@
"""Async Task System 单元测试 - TaskStore + BackgroundRunner + API"""
import asyncio
from datetime import datetime, timezone, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.server.task_store import TaskRecord, TaskStore
from agentkit.server.runner import BackgroundRunner
# ═══════════════════════════════════════════════════════════
# TaskStore Tests
# ═══════════════════════════════════════════════════════════
class TestTaskRecord:
"""TaskRecord dataclass tests"""
def test_to_dict_returns_complete_dict(self):
record = TaskRecord(
task_id="t1",
agent_name="agent_a",
skill_name="skill_x",
input_data={"query": "hello"},
)
d = record.to_dict()
assert d["task_id"] == "t1"
assert d["agent_name"] == "agent_a"
assert d["skill_name"] == "skill_x"
assert d["input_data"] == {"query": "hello"}
assert d["status"] == "pending"
assert d["output_data"] is None
assert d["error_message"] is None
assert d["progress"] == 0.0
assert d["created_at"] is not None
def test_to_dict_with_timestamps(self):
now = datetime.now(timezone.utc)
record = TaskRecord(
task_id="t2",
agent_name="agent_b",
skill_name=None,
input_data={},
started_at=now,
completed_at=now,
)
d = record.to_dict()
assert d["started_at"] == now.isoformat()
assert d["completed_at"] == now.isoformat()
class TestTaskStore:
"""TaskStore in-memory storage tests"""
def test_create_task_record_stored_correctly(self):
store = TaskStore()
record = store.create("t1", "agent_a", {"q": "hello"}, skill_name="skill_x")
assert record.task_id == "t1"
assert record.agent_name == "agent_a"
assert record.skill_name == "skill_x"
assert record.input_data == {"q": "hello"}
assert record.status == TaskStatus.PENDING
def test_get_task_by_id_returns_record(self):
store = TaskStore()
store.create("t1", "agent_a", {})
record = store.get("t1")
assert record is not None
assert record.task_id == "t1"
def test_get_nonexistent_task_returns_none(self):
store = TaskStore()
assert store.get("nonexistent") is None
def test_update_status_fields_updated(self):
store = TaskStore()
store.create("t1", "agent_a", {})
now = datetime.now(timezone.utc)
record = store.update_status(
"t1", TaskStatus.RUNNING, started_at=now, progress=0.5, progress_message="Halfway"
)
assert record.status == TaskStatus.RUNNING
assert record.started_at == now
assert record.progress == 0.5
assert record.progress_message == "Halfway"
def test_update_nonexistent_task_raises_keyerror(self):
store = TaskStore()
with pytest.raises(KeyError, match="not found"):
store.update_status("nonexistent", TaskStatus.RUNNING)
def test_list_tasks_returns_all_sorted_desc(self):
store = TaskStore()
store.create("t1", "agent_a", {})
store.create("t2", "agent_b", {})
tasks = store.list_tasks()
assert len(tasks) == 2
# Most recent first
assert tasks[0].task_id == "t2"
assert tasks[1].task_id == "t1"
def test_list_tasks_filtered_by_status(self):
store = TaskStore()
store.create("t1", "agent_a", {})
store.create("t2", "agent_b", {})
store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
tasks = store.list_tasks(status=TaskStatus.COMPLETED)
assert len(tasks) == 1
assert tasks[0].task_id == "t1"
def test_max_records_limit_evicts_oldest_completed(self):
store = TaskStore(max_records=2)
store.create("t1", "agent_a", {})
store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
store.create("t2", "agent_b", {})
# t3 should evict t1 (oldest completed)
store.create("t3", "agent_c", {})
assert store.get("t1") is None
assert store.get("t2") is not None
assert store.get("t3") is not None
def test_max_records_full_no_completed_raises(self):
store = TaskStore(max_records=1)
store.create("t1", "agent_a", {})
# All tasks are PENDING, no completed to evict
with pytest.raises(RuntimeError, match="full"):
store.create("t2", "agent_b", {})
def test_ttl_cleanup_removes_expired_completed(self):
store = TaskStore(ttl_seconds=0) # Immediate expiry
store.create("t1", "agent_a", {})
store.update_status(
"t1", TaskStatus.COMPLETED,
completed_at=datetime.now(timezone.utc) - timedelta(seconds=10),
)
store.create("t2", "agent_b", {})
# t2 is PENDING, should not be cleaned
store._cleanup_expired()
assert store.get("t1") is None # Expired completed
assert store.get("t2") is not None # Pending stays
def test_size_property_correct_count(self):
store = TaskStore()
assert store.size == 0
store.create("t1", "agent_a", {})
assert store.size == 1
store.create("t2", "agent_b", {})
assert store.size == 2
def test_list_tasks_respects_limit(self):
store = TaskStore()
for i in range(5):
store.create(f"t{i}", "agent_a", {})
tasks = store.list_tasks(limit=3)
assert len(tasks) == 3
# ═══════════════════════════════════════════════════════════
# BackgroundRunner Tests
# ═══════════════════════════════════════════════════════════
class TestBackgroundRunner:
"""BackgroundRunner async task execution tests"""
@pytest.fixture
def task_store(self):
return TaskStore()
@pytest.fixture
def runner(self, task_store):
return BackgroundRunner(task_store=task_store, max_concurrent=5)
def _make_mock_agent(self, name="test_agent", output=None, raise_error=None):
"""Create a mock agent for testing"""
agent = MagicMock()
agent.name = name
agent.agent_type = "test_type"
if raise_error:
agent.execute = AsyncMock(side_effect=raise_error)
else:
task_result = TaskResult(
task_id="mock",
agent_name=name,
status="completed",
output_data=output or {"result": "ok"},
error_message=None,
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
agent.execute = AsyncMock(return_value=task_result)
return agent
@pytest.mark.asyncio
async def test_submit_returns_task_id_immediately(self, runner, task_store):
agent = self._make_mock_agent()
task_id = await runner.submit(agent, {"query": "test"})
assert task_id is not None
assert isinstance(task_id, str)
# Task record should exist in store
record = task_store.get(task_id)
assert record is not None
assert record.status == TaskStatus.PENDING
@pytest.mark.asyncio
async def test_submit_task_runs_to_completion(self, runner, task_store):
agent = self._make_mock_agent(output={"answer": "42"})
task_id = await runner.submit(agent, {"query": "meaning of life"})
# Wait for task to complete
await asyncio.sleep(0.1)
record = task_store.get(task_id)
assert record is not None
assert record.status == TaskStatus.COMPLETED
assert record.output_data == {"answer": "42"}
assert record.progress == 1.0
@pytest.mark.asyncio
async def test_submit_task_failure_recorded(self, runner, task_store):
agent = self._make_mock_agent(raise_error=RuntimeError("boom"))
task_id = await runner.submit(agent, {"query": "fail"})
# Wait for task to fail
await asyncio.sleep(0.1)
record = task_store.get(task_id)
assert record is not None
assert record.status == TaskStatus.FAILED
assert "boom" in record.error_message
@pytest.mark.asyncio
async def test_cancel_running_task(self, runner, task_store):
async def slow_execute(msg):
await asyncio.sleep(10) # Long running
return TaskResult(
task_id=msg.task_id,
agent_name="test_agent",
status="completed",
output_data={"result": "done"},
error_message=None,
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
agent = MagicMock()
agent.name = "slow_agent"
agent.agent_type = "test_type"
agent.execute = AsyncMock(side_effect=slow_execute)
task_id = await runner.submit(agent, {"query": "slow"})
# Give it a moment to start
await asyncio.sleep(0.05)
cancelled = await runner.cancel(task_id)
assert cancelled is True
record = task_store.get(task_id)
assert record.status == TaskStatus.CANCELLED
@pytest.mark.asyncio
async def test_cancel_non_running_task_returns_false(self, runner, task_store):
result = await runner.cancel("nonexistent")
assert result is False
@pytest.mark.asyncio
async def test_concurrent_tasks_respects_semaphore(self, task_store):
runner = BackgroundRunner(task_store=task_store, max_concurrent=2)
execution_order = []
async def tracked_execute(msg):
execution_order.append(f"start:{msg.task_id}")
await asyncio.sleep(0.1)
execution_order.append(f"end:{msg.task_id}")
return TaskResult(
task_id=msg.task_id,
agent_name="test",
status="completed",
output_data={},
error_message=None,
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
agents = []
for i in range(4):
agent = MagicMock()
agent.name = f"agent_{i}"
agent.agent_type = "test_type"
agent.execute = AsyncMock(side_effect=tracked_execute)
agents.append(agent)
# Submit all 4 tasks
task_ids = []
for agent in agents:
tid = await runner.submit(agent, {"idx": agents.index(agent)})
task_ids.append(tid)
# Wait for all to complete
await asyncio.sleep(0.5)
# All tasks should have completed
for tid in task_ids:
record = task_store.get(tid)
assert record.status == TaskStatus.COMPLETED
@pytest.mark.asyncio
async def test_active_count_tracks_running(self, task_store):
runner = BackgroundRunner(task_store=task_store, max_concurrent=10)
async def slow_execute(msg):
await asyncio.sleep(0.2)
return TaskResult(
task_id=msg.task_id,
agent_name="test",
status="completed",
output_data={},
error_message=None,
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
)
agent = MagicMock()
agent.name = "slow_agent"
agent.agent_type = "test_type"
agent.execute = AsyncMock(side_effect=slow_execute)
await runner.submit(agent, {})
await asyncio.sleep(0.05)
assert runner.active_count >= 1
await asyncio.sleep(0.3)
assert runner.active_count == 0
# ═══════════════════════════════════════════════════════════
# API Tests (using TestClient)
# ═══════════════════════════════════════════════════════════
class TestAsyncTaskAPI:
"""Async task API endpoint tests"""
@pytest.fixture
def mock_llm_gateway(self):
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
gateway = LLMGateway()
mock_provider = AsyncMock()
mock_provider.chat.return_value = LLMResponse(
content='{"result": "mocked"}',
model="test-model",
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
)
gateway.register_provider("test", mock_provider)
return gateway
@pytest.fixture
def skill_registry(self):
from agentkit.skills.registry import SkillRegistry
return SkillRegistry()
@pytest.fixture
def tool_registry(self):
from agentkit.tools.registry import ToolRegistry
return ToolRegistry()
@pytest.fixture
def app(self, mock_llm_gateway, skill_registry, tool_registry):
from agentkit.server.app import create_app
return create_app(
llm_gateway=mock_llm_gateway,
skill_registry=skill_registry,
tool_registry=tool_registry,
)
@pytest.fixture
def client(self, app):
return TestClient(app)
def _register_skill_and_create_agent(self, client, skill_registry):
"""Helper: register a skill and create an agent for it"""
from agentkit.skills.base import Skill, SkillConfig
skill_config = SkillConfig(
name="async_skill",
agent_type="async_type",
task_mode="llm_generate",
prompt={"identity": "Async Skill", "instructions": "Handle async"},
intent={"keywords": ["async"], "description": "Async skill"},
)
skill = Skill(config=skill_config)
skill_registry.register(skill)
# Create agent
resp = client.post(
"/api/v1/agents",
json={"skill_name": "async_skill"},
)
assert resp.status_code == 201
return "async_skill"
def test_submit_task_async_returns_task_id(self, client, skill_registry):
agent_name = self._register_skill_and_create_agent(client, skill_registry)
response = client.post(
"/api/v1/tasks",
json={
"input_data": {"query": "async test"},
"agent_name": agent_name,
"mode": "async",
},
)
assert response.status_code == 200
data = response.json()
assert "task_id" in data
assert data["status"] == "pending"
assert data["mode"] == "async"
def test_get_task_status_returns_record(self, client, skill_registry):
agent_name = self._register_skill_and_create_agent(client, skill_registry)
# Submit async task
submit_resp = client.post(
"/api/v1/tasks",
json={
"input_data": {"query": "status test"},
"agent_name": agent_name,
"mode": "async",
},
)
task_id = submit_resp.json()["task_id"]
# Wait a bit for completion
import time
time.sleep(0.3)
# Get status
response = client.get(f"/api/v1/tasks/{task_id}")
assert response.status_code == 200
data = response.json()
assert data["task_id"] == task_id
assert data["status"] in ("completed", "running", "pending")
def test_get_task_status_not_found_404(self, client):
response = client.get("/api/v1/tasks/nonexistent-id")
assert response.status_code == 404
def test_cancel_task(self, client, skill_registry):
agent_name = self._register_skill_and_create_agent(client, skill_registry)
# Submit async task
submit_resp = client.post(
"/api/v1/tasks",
json={
"input_data": {"query": "cancel test"},
"agent_name": agent_name,
"mode": "async",
},
)
task_id = submit_resp.json()["task_id"]
# Try to cancel (may or may not succeed depending on timing)
response = client.post(f"/api/v1/tasks/{task_id}/cancel")
# Either cancelled or 400 (already completed)
assert response.status_code in (200, 400)
def test_list_tasks(self, client, skill_registry):
agent_name = self._register_skill_and_create_agent(client, skill_registry)
# Submit an async task to ensure at least one exists
client.post(
"/api/v1/tasks",
json={
"input_data": {"query": "list test"},
"agent_name": agent_name,
"mode": "async",
},
)
response = client.get("/api/v1/tasks")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
def test_list_tasks_filter_by_status(self, client, skill_registry):
agent_name = self._register_skill_and_create_agent(client, skill_registry)
# Submit an async task
client.post(
"/api/v1/tasks",
json={
"input_data": {"query": "filter test"},
"agent_name": agent_name,
"mode": "async",
},
)
response = client.get("/api/v1/tasks?status=completed")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
# All returned tasks should be completed
for task in data:
assert task["status"] == "completed"
def test_sync_mode_still_works(self, client, skill_registry):
"""Ensure existing sync mode is not broken"""
agent_name = self._register_skill_and_create_agent(client, skill_registry)
response = client.post(
"/api/v1/tasks",
json={
"input_data": {"query": "sync test"},
"agent_name": agent_name,
},
)
assert response.status_code == 200
data = response.json()
# Sync mode returns task_id and output
assert "task_id" in data