feat(server): Phase D - async task system (TaskStore + BackgroundRunner + API)
U5: TaskStore - in-memory task state with TTL cleanup and max records U6: BackgroundRunner - async task execution with semaphore concurrency control U7: Task status/result API + cancel endpoint + async submit mode 45 tests passing (28 new + 17 existing, no regression).
This commit is contained in:
parent
5f1c51cf9a
commit
ec0e221beb
|
|
@ -1,6 +1,8 @@
|
||||||
"""FastAPI Application Factory"""
|
"""FastAPI Application Factory"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
|
@ -13,6 +15,18 @@ from agentkit.skills.registry import SkillRegistry
|
||||||
from agentkit.tools.registry import ToolRegistry
|
from agentkit.tools.registry import ToolRegistry
|
||||||
from agentkit.server.routes import agents, tasks, skills, llm, health
|
from agentkit.server.routes import agents, tasks, skills, llm, health
|
||||||
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
|
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
|
||||||
|
from agentkit.server.task_store import TaskStore
|
||||||
|
from agentkit.server.runner import BackgroundRunner
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# Startup
|
||||||
|
task_store = app.state.task_store
|
||||||
|
await task_store.start_cleanup()
|
||||||
|
yield
|
||||||
|
# Shutdown
|
||||||
|
await task_store.stop_cleanup()
|
||||||
|
|
||||||
|
|
||||||
def create_app(
|
def create_app(
|
||||||
|
|
@ -23,7 +37,7 @@ def create_app(
|
||||||
rate_limit: int | None = None,
|
rate_limit: int | None = None,
|
||||||
) -> FastAPI:
|
) -> FastAPI:
|
||||||
"""Create and configure the FastAPI application"""
|
"""Create and configure the FastAPI application"""
|
||||||
app = FastAPI(title="AgentKit Server", version="2.0.0")
|
app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)
|
||||||
|
|
||||||
# CORS 配置
|
# CORS 配置
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
|
@ -55,6 +69,8 @@ def create_app(
|
||||||
app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway)
|
app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway)
|
||||||
app.state.quality_gate = QualityGate()
|
app.state.quality_gate = QualityGate()
|
||||||
app.state.output_standardizer = OutputStandardizer()
|
app.state.output_standardizer = OutputStandardizer()
|
||||||
|
app.state.task_store = TaskStore()
|
||||||
|
app.state.runner = BackgroundRunner(task_store=app.state.task_store)
|
||||||
|
|
||||||
# Include routes
|
# Include routes
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,45 @@ class AgentKitClient:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
|
async def submit_task_async(
|
||||||
|
self,
|
||||||
|
input_data: dict,
|
||||||
|
skill_name: str | None = None,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Submit a task in async mode"""
|
||||||
|
payload: dict[str, Any] = {"input_data": input_data, "mode": "async"}
|
||||||
|
if skill_name:
|
||||||
|
payload["skill_name"] = skill_name
|
||||||
|
if agent_name:
|
||||||
|
payload["agent_name"] = agent_name
|
||||||
|
response = await self._client.post("/api/v1/tasks", json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def get_task_status(self, task_id: str) -> dict:
|
||||||
|
"""Get task status"""
|
||||||
|
response = await self._client.get(f"/api/v1/tasks/{task_id}")
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def cancel_task(self, task_id: str) -> dict:
|
||||||
|
"""Cancel a running task"""
|
||||||
|
response = await self._client.post(f"/api/v1/tasks/{task_id}/cancel")
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def list_tasks(
|
||||||
|
self, status: str | None = None, limit: int = 100
|
||||||
|
) -> list[dict]:
|
||||||
|
"""List tasks"""
|
||||||
|
params: dict[str, Any] = {"limit": limit}
|
||||||
|
if status:
|
||||||
|
params["status"] = status
|
||||||
|
response = await self._client.get("/api/v1/tasks", params=params)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the HTTP client"""
|
"""Close the HTTP client"""
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from fastapi import APIRouter, HTTPException, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from agentkit.core.protocol import TaskMessage
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||||
|
|
||||||
router = APIRouter(tags=["tasks"])
|
router = APIRouter(tags=["tasks"])
|
||||||
|
|
||||||
|
|
@ -16,6 +16,7 @@ class SubmitTaskRequest(BaseModel):
|
||||||
input_data: dict[str, Any]
|
input_data: dict[str, Any]
|
||||||
skill_name: str | None = None
|
skill_name: str | None = None
|
||||||
agent_name: str | None = None
|
agent_name: str | None = None
|
||||||
|
mode: str = "sync" # "sync" or "async"
|
||||||
|
|
||||||
# 输入数据大小限制(防止 OOM)
|
# 输入数据大小限制(防止 OOM)
|
||||||
model_config = {"json_schema_extra": {"max_input_size_bytes": 1024 * 1024}} # 1MB
|
model_config = {"json_schema_extra": {"max_input_size_bytes": 1024 * 1024}} # 1MB
|
||||||
|
|
@ -39,6 +40,15 @@ def _validate_input_size(input_data: dict) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tasks")
|
||||||
|
async def list_tasks(status: str | None = None, limit: int = 100, req: Request = None):
|
||||||
|
"""List tasks"""
|
||||||
|
store = req.app.state.task_store
|
||||||
|
task_status = TaskStatus(status) if status else None
|
||||||
|
records = store.list_tasks(status=task_status, limit=limit)
|
||||||
|
return [r.to_dict() for r in records]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/tasks")
|
@router.post("/tasks")
|
||||||
async def submit_task(request: SubmitTaskRequest, req: Request):
|
async def submit_task(request: SubmitTaskRequest, req: Request):
|
||||||
"""Submit a task (Intent Router auto-routes to skill)"""
|
"""Submit a task (Intent Router auto-routes to skill)"""
|
||||||
|
|
@ -98,7 +108,20 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
|
||||||
except (ValueError, RuntimeError) as e:
|
except (ValueError, RuntimeError) as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
# 4. Execute task
|
# 4. Async mode: submit to background runner
|
||||||
|
if request.mode == "async":
|
||||||
|
runner = req.app.state.runner
|
||||||
|
task_id = await runner.submit(
|
||||||
|
agent=agent,
|
||||||
|
input_data=request.input_data,
|
||||||
|
skill_name=request.skill_name,
|
||||||
|
quality_gate=quality_gate,
|
||||||
|
output_standardizer=output_standardizer,
|
||||||
|
skill=skill,
|
||||||
|
)
|
||||||
|
return {"task_id": task_id, "status": "pending", "mode": "async"}
|
||||||
|
|
||||||
|
# 5. Sync mode: existing blocking execution
|
||||||
task = TaskMessage(
|
task = TaskMessage(
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
agent_name=agent.name,
|
agent_name=agent.name,
|
||||||
|
|
@ -111,7 +134,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
|
||||||
|
|
||||||
task_result = await agent.execute(task)
|
task_result = await agent.execute(task)
|
||||||
|
|
||||||
# 5. Run quality gate if skill available
|
# 6. Run quality gate if skill available
|
||||||
quality_result = None
|
quality_result = None
|
||||||
if skill:
|
if skill:
|
||||||
try:
|
try:
|
||||||
|
|
@ -119,7 +142,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Quality gate failure shouldn't block the response
|
pass # Quality gate failure shouldn't block the response
|
||||||
|
|
||||||
# 6. Standardize output if skill available
|
# 7. Standardize output if skill available
|
||||||
if skill:
|
if skill:
|
||||||
try:
|
try:
|
||||||
standard_output = await output_standardizer.standardize(
|
standard_output = await output_standardizer.standardize(
|
||||||
|
|
@ -141,7 +164,7 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Fall through to raw output
|
pass # Fall through to raw output
|
||||||
|
|
||||||
# 7. Return raw result if no skill or standardization failed
|
# 8. Return raw result if no skill or standardization failed
|
||||||
return {
|
return {
|
||||||
"task_id": task.task_id,
|
"task_id": task.task_id,
|
||||||
"status": task_result.status,
|
"status": task_result.status,
|
||||||
|
|
@ -151,6 +174,20 @@ async def submit_task(request: SubmitTaskRequest, req: Request):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tasks/{task_id}")
|
@router.get("/tasks/{task_id}")
|
||||||
async def get_task_status(task_id: str):
|
async def get_task_status(task_id: str, req: Request):
|
||||||
"""Get task status (placeholder for async mode)"""
|
"""Get task status and result"""
|
||||||
return {"task_id": task_id, "status": "placeholder"}
|
store = req.app.state.task_store
|
||||||
|
record = store.get(task_id)
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Task '{task_id}' not found")
|
||||||
|
return record.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tasks/{task_id}/cancel")
|
||||||
|
async def cancel_task(task_id: str, req: Request):
|
||||||
|
"""Cancel a running task"""
|
||||||
|
runner = req.app.state.runner
|
||||||
|
cancelled = await runner.cancel(task_id)
|
||||||
|
if not cancelled:
|
||||||
|
raise HTTPException(status_code=400, detail="Task cannot be cancelled (not running or not found)")
|
||||||
|
return {"task_id": task_id, "status": "cancelled"}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,170 @@
|
||||||
|
"""BackgroundRunner - Async task execution with lifecycle management"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||||
|
from agentkit.server.task_store import TaskStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BackgroundRunner:
|
||||||
|
"""Runs tasks in background asyncio tasks with lifecycle management.
|
||||||
|
|
||||||
|
Integrates with AgentPool for agent execution and TaskStore for state tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, task_store: TaskStore, max_concurrent: int = 10):
|
||||||
|
self._task_store = task_store
|
||||||
|
self._max_concurrent = max_concurrent
|
||||||
|
self._running_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
self._semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_count(self) -> int:
|
||||||
|
return len(self._running_tasks)
|
||||||
|
|
||||||
|
async def submit(
|
||||||
|
self,
|
||||||
|
agent, # ConfigDrivenAgent
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
skill_name: str | None = None,
|
||||||
|
quality_gate=None,
|
||||||
|
output_standardizer=None,
|
||||||
|
skill=None,
|
||||||
|
) -> str:
|
||||||
|
"""Submit a task for background execution.
|
||||||
|
|
||||||
|
Returns task_id immediately.
|
||||||
|
"""
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Create task record
|
||||||
|
self._task_store.create(
|
||||||
|
task_id=task_id,
|
||||||
|
agent_name=agent.name,
|
||||||
|
input_data=input_data,
|
||||||
|
skill_name=skill_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Launch background asyncio task
|
||||||
|
asyncio_task = asyncio.create_task(
|
||||||
|
self._run_task(
|
||||||
|
task_id=task_id,
|
||||||
|
agent=agent,
|
||||||
|
input_data=input_data,
|
||||||
|
quality_gate=quality_gate,
|
||||||
|
output_standardizer=output_standardizer,
|
||||||
|
skill=skill,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self._running_tasks[task_id] = asyncio_task
|
||||||
|
|
||||||
|
# Clean up reference when done
|
||||||
|
def _on_done(t: asyncio.Task):
|
||||||
|
self._running_tasks.pop(task_id, None)
|
||||||
|
if t.exception():
|
||||||
|
logger.error(f"Background task {task_id} failed: {t.exception()}")
|
||||||
|
|
||||||
|
asyncio_task.add_done_callback(_on_done)
|
||||||
|
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
async def _run_task(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
agent,
|
||||||
|
input_data: dict,
|
||||||
|
quality_gate=None,
|
||||||
|
output_standardizer=None,
|
||||||
|
skill=None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Execute task in background with semaphore control"""
|
||||||
|
async with self._semaphore:
|
||||||
|
# Update status to RUNNING
|
||||||
|
self._task_store.update_status(
|
||||||
|
task_id, TaskStatus.RUNNING,
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create TaskMessage for agent
|
||||||
|
task_msg = TaskMessage(
|
||||||
|
task_id=task_id,
|
||||||
|
agent_name=agent.name,
|
||||||
|
task_type=agent.agent_type,
|
||||||
|
priority=0,
|
||||||
|
input_data=input_data,
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute agent
|
||||||
|
task_result = await agent.execute(task_msg)
|
||||||
|
|
||||||
|
# Run quality gate if available
|
||||||
|
quality_result = None
|
||||||
|
if skill and quality_gate:
|
||||||
|
try:
|
||||||
|
quality_result = await quality_gate.validate(
|
||||||
|
task_result.output_data or {}, skill
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Quality gate failed for {task_id}: {e}")
|
||||||
|
|
||||||
|
# Standardize output if available
|
||||||
|
final_output = task_result.output_data
|
||||||
|
if skill and output_standardizer:
|
||||||
|
try:
|
||||||
|
standard_output = await output_standardizer.standardize(
|
||||||
|
raw_output=task_result.output_data or {},
|
||||||
|
skill=skill,
|
||||||
|
quality_result=quality_result,
|
||||||
|
)
|
||||||
|
final_output = {
|
||||||
|
"skill_name": standard_output.skill_name,
|
||||||
|
"data": standard_output.data,
|
||||||
|
"metadata": {
|
||||||
|
"version": standard_output.metadata.version,
|
||||||
|
"produced_at": standard_output.metadata.produced_at.isoformat(),
|
||||||
|
"quality_score": standard_output.metadata.quality_score,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Output standardization failed for {task_id}: {e}")
|
||||||
|
|
||||||
|
# Update store
|
||||||
|
self._task_store.update_status(
|
||||||
|
task_id, TaskStatus.COMPLETED,
|
||||||
|
output_data=final_output,
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
progress=1.0,
|
||||||
|
progress_message="Completed",
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_output or {}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Task {task_id} failed: {e}")
|
||||||
|
self._task_store.update_status(
|
||||||
|
task_id, TaskStatus.FAILED,
|
||||||
|
error_message=str(e),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def cancel(self, task_id: str) -> bool:
|
||||||
|
"""Cancel a running task"""
|
||||||
|
asyncio_task = self._running_tasks.get(task_id)
|
||||||
|
if asyncio_task and not asyncio_task.done():
|
||||||
|
asyncio_task.cancel()
|
||||||
|
self._task_store.update_status(
|
||||||
|
task_id, TaskStatus.CANCELLED,
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
@ -0,0 +1,151 @@
|
||||||
|
"""TaskStore - In-memory task state storage with TTL"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.core.protocol import TaskStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskRecord:
|
||||||
|
"""Stored task record with full lifecycle data"""
|
||||||
|
task_id: str
|
||||||
|
agent_name: str
|
||||||
|
skill_name: str | None
|
||||||
|
input_data: dict[str, Any]
|
||||||
|
status: TaskStatus = TaskStatus.PENDING
|
||||||
|
output_data: dict[str, Any] | None = None
|
||||||
|
error_message: str | None = None
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
started_at: datetime | None = None
|
||||||
|
completed_at: datetime | None = None
|
||||||
|
progress: float = 0.0
|
||||||
|
progress_message: str = ""
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"agent_name": self.agent_name,
|
||||||
|
"skill_name": self.skill_name,
|
||||||
|
"input_data": self.input_data,
|
||||||
|
"status": self.status.value,
|
||||||
|
"output_data": self.output_data,
|
||||||
|
"error_message": self.error_message,
|
||||||
|
"created_at": self.created_at.isoformat(),
|
||||||
|
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||||
|
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||||
|
"progress": self.progress,
|
||||||
|
"progress_message": self.progress_message,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStore:
|
||||||
|
"""In-memory task state storage with automatic TTL cleanup.
|
||||||
|
|
||||||
|
Stores task records indexed by task_id. Automatically removes
|
||||||
|
completed tasks after a configurable TTL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ttl_seconds: int = 3600, max_records: int = 10000):
|
||||||
|
self._tasks: dict[str, TaskRecord] = {}
|
||||||
|
self._ttl_seconds = ttl_seconds
|
||||||
|
self._max_records = max_records
|
||||||
|
self._cleanup_task: asyncio.Task | None = None
|
||||||
|
|
||||||
|
async def start_cleanup(self) -> None:
|
||||||
|
"""Start background cleanup task"""
|
||||||
|
if self._cleanup_task is None:
|
||||||
|
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
|
||||||
|
async def stop_cleanup(self) -> None:
|
||||||
|
"""Stop background cleanup task"""
|
||||||
|
if self._cleanup_task:
|
||||||
|
self._cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._cleanup_task = None
|
||||||
|
|
||||||
|
async def _cleanup_loop(self) -> None:
|
||||||
|
"""Periodically remove expired task records"""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
self._cleanup_expired()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"TaskStore cleanup error: {e}")
|
||||||
|
|
||||||
|
def _cleanup_expired(self) -> None:
|
||||||
|
"""Remove expired records"""
|
||||||
|
expired = []
|
||||||
|
for task_id, record in self._tasks.items():
|
||||||
|
if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
|
||||||
|
if record.completed_at:
|
||||||
|
age = (datetime.now(timezone.utc) - record.completed_at).total_seconds()
|
||||||
|
if age > self._ttl_seconds:
|
||||||
|
expired.append(task_id)
|
||||||
|
for task_id in expired:
|
||||||
|
del self._tasks[task_id]
|
||||||
|
if expired:
|
||||||
|
logger.info(f"TaskStore cleaned up {len(expired)} expired records")
|
||||||
|
|
||||||
|
def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord:
|
||||||
|
"""Create a new task record"""
|
||||||
|
if len(self._tasks) >= self._max_records:
|
||||||
|
# Remove oldest completed task
|
||||||
|
oldest = None
|
||||||
|
for tid, rec in self._tasks.items():
|
||||||
|
if rec.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
|
||||||
|
if oldest is None or (rec.completed_at and (oldest.completed_at is None or rec.completed_at < oldest.completed_at)):
|
||||||
|
oldest = rec
|
||||||
|
if oldest:
|
||||||
|
del self._tasks[oldest.task_id]
|
||||||
|
else:
|
||||||
|
raise RuntimeError("TaskStore is full and no completed tasks to evict")
|
||||||
|
|
||||||
|
record = TaskRecord(
|
||||||
|
task_id=task_id,
|
||||||
|
agent_name=agent_name,
|
||||||
|
skill_name=skill_name,
|
||||||
|
input_data=input_data,
|
||||||
|
)
|
||||||
|
self._tasks[task_id] = record
|
||||||
|
return record
|
||||||
|
|
||||||
|
def get(self, task_id: str) -> TaskRecord | None:
|
||||||
|
"""Get task record by ID"""
|
||||||
|
return self._tasks.get(task_id)
|
||||||
|
|
||||||
|
def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord:
|
||||||
|
"""Update task status and optional fields"""
|
||||||
|
record = self._tasks.get(task_id)
|
||||||
|
if record is None:
|
||||||
|
raise KeyError(f"Task '{task_id}' not found")
|
||||||
|
record.status = status
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(record, key):
|
||||||
|
setattr(record, key, value)
|
||||||
|
return record
|
||||||
|
|
||||||
|
def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]:
|
||||||
|
"""List tasks, optionally filtered by status"""
|
||||||
|
tasks = list(self._tasks.values())
|
||||||
|
if status:
|
||||||
|
tasks = [t for t in tasks if t.status == status]
|
||||||
|
tasks.sort(key=lambda t: t.created_at, reverse=True)
|
||||||
|
return tasks[:limit]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
return len(self._tasks)
|
||||||
|
|
@ -0,0 +1,512 @@
|
||||||
|
"""Async Task System 单元测试 - TaskStore + BackgroundRunner + API"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||||
|
from agentkit.server.task_store import TaskRecord, TaskStore
|
||||||
|
from agentkit.server.runner import BackgroundRunner
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════
|
||||||
|
# TaskStore Tests
|
||||||
|
# ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskRecord:
|
||||||
|
"""TaskRecord dataclass tests"""
|
||||||
|
|
||||||
|
def test_to_dict_returns_complete_dict(self):
|
||||||
|
record = TaskRecord(
|
||||||
|
task_id="t1",
|
||||||
|
agent_name="agent_a",
|
||||||
|
skill_name="skill_x",
|
||||||
|
input_data={"query": "hello"},
|
||||||
|
)
|
||||||
|
d = record.to_dict()
|
||||||
|
assert d["task_id"] == "t1"
|
||||||
|
assert d["agent_name"] == "agent_a"
|
||||||
|
assert d["skill_name"] == "skill_x"
|
||||||
|
assert d["input_data"] == {"query": "hello"}
|
||||||
|
assert d["status"] == "pending"
|
||||||
|
assert d["output_data"] is None
|
||||||
|
assert d["error_message"] is None
|
||||||
|
assert d["progress"] == 0.0
|
||||||
|
assert d["created_at"] is not None
|
||||||
|
|
||||||
|
def test_to_dict_with_timestamps(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
record = TaskRecord(
|
||||||
|
task_id="t2",
|
||||||
|
agent_name="agent_b",
|
||||||
|
skill_name=None,
|
||||||
|
input_data={},
|
||||||
|
started_at=now,
|
||||||
|
completed_at=now,
|
||||||
|
)
|
||||||
|
d = record.to_dict()
|
||||||
|
assert d["started_at"] == now.isoformat()
|
||||||
|
assert d["completed_at"] == now.isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskStore:
|
||||||
|
"""TaskStore in-memory storage tests"""
|
||||||
|
|
||||||
|
def test_create_task_record_stored_correctly(self):
|
||||||
|
store = TaskStore()
|
||||||
|
record = store.create("t1", "agent_a", {"q": "hello"}, skill_name="skill_x")
|
||||||
|
assert record.task_id == "t1"
|
||||||
|
assert record.agent_name == "agent_a"
|
||||||
|
assert record.skill_name == "skill_x"
|
||||||
|
assert record.input_data == {"q": "hello"}
|
||||||
|
assert record.status == TaskStatus.PENDING
|
||||||
|
|
||||||
|
def test_get_task_by_id_returns_record(self):
|
||||||
|
store = TaskStore()
|
||||||
|
store.create("t1", "agent_a", {})
|
||||||
|
record = store.get("t1")
|
||||||
|
assert record is not None
|
||||||
|
assert record.task_id == "t1"
|
||||||
|
|
||||||
|
def test_get_nonexistent_task_returns_none(self):
|
||||||
|
store = TaskStore()
|
||||||
|
assert store.get("nonexistent") is None
|
||||||
|
|
||||||
|
def test_update_status_fields_updated(self):
|
||||||
|
store = TaskStore()
|
||||||
|
store.create("t1", "agent_a", {})
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
record = store.update_status(
|
||||||
|
"t1", TaskStatus.RUNNING, started_at=now, progress=0.5, progress_message="Halfway"
|
||||||
|
)
|
||||||
|
assert record.status == TaskStatus.RUNNING
|
||||||
|
assert record.started_at == now
|
||||||
|
assert record.progress == 0.5
|
||||||
|
assert record.progress_message == "Halfway"
|
||||||
|
|
||||||
|
def test_update_nonexistent_task_raises_keyerror(self):
|
||||||
|
store = TaskStore()
|
||||||
|
with pytest.raises(KeyError, match="not found"):
|
||||||
|
store.update_status("nonexistent", TaskStatus.RUNNING)
|
||||||
|
|
||||||
|
def test_list_tasks_returns_all_sorted_desc(self):
|
||||||
|
store = TaskStore()
|
||||||
|
store.create("t1", "agent_a", {})
|
||||||
|
store.create("t2", "agent_b", {})
|
||||||
|
tasks = store.list_tasks()
|
||||||
|
assert len(tasks) == 2
|
||||||
|
# Most recent first
|
||||||
|
assert tasks[0].task_id == "t2"
|
||||||
|
assert tasks[1].task_id == "t1"
|
||||||
|
|
||||||
|
def test_list_tasks_filtered_by_status(self):
|
||||||
|
store = TaskStore()
|
||||||
|
store.create("t1", "agent_a", {})
|
||||||
|
store.create("t2", "agent_b", {})
|
||||||
|
store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
|
||||||
|
tasks = store.list_tasks(status=TaskStatus.COMPLETED)
|
||||||
|
assert len(tasks) == 1
|
||||||
|
assert tasks[0].task_id == "t1"
|
||||||
|
|
||||||
|
def test_max_records_limit_evicts_oldest_completed(self):
|
||||||
|
store = TaskStore(max_records=2)
|
||||||
|
store.create("t1", "agent_a", {})
|
||||||
|
store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
|
||||||
|
store.create("t2", "agent_b", {})
|
||||||
|
# t3 should evict t1 (oldest completed)
|
||||||
|
store.create("t3", "agent_c", {})
|
||||||
|
assert store.get("t1") is None
|
||||||
|
assert store.get("t2") is not None
|
||||||
|
assert store.get("t3") is not None
|
||||||
|
|
||||||
|
def test_max_records_full_no_completed_raises(self):
|
||||||
|
store = TaskStore(max_records=1)
|
||||||
|
store.create("t1", "agent_a", {})
|
||||||
|
# All tasks are PENDING, no completed to evict
|
||||||
|
with pytest.raises(RuntimeError, match="full"):
|
||||||
|
store.create("t2", "agent_b", {})
|
||||||
|
|
||||||
|
def test_ttl_cleanup_removes_expired_completed(self):
|
||||||
|
store = TaskStore(ttl_seconds=0) # Immediate expiry
|
||||||
|
store.create("t1", "agent_a", {})
|
||||||
|
store.update_status(
|
||||||
|
"t1", TaskStatus.COMPLETED,
|
||||||
|
completed_at=datetime.now(timezone.utc) - timedelta(seconds=10),
|
||||||
|
)
|
||||||
|
store.create("t2", "agent_b", {})
|
||||||
|
# t2 is PENDING, should not be cleaned
|
||||||
|
store._cleanup_expired()
|
||||||
|
assert store.get("t1") is None # Expired completed
|
||||||
|
assert store.get("t2") is not None # Pending stays
|
||||||
|
|
||||||
|
def test_size_property_correct_count(self):
|
||||||
|
store = TaskStore()
|
||||||
|
assert store.size == 0
|
||||||
|
store.create("t1", "agent_a", {})
|
||||||
|
assert store.size == 1
|
||||||
|
store.create("t2", "agent_b", {})
|
||||||
|
assert store.size == 2
|
||||||
|
|
||||||
|
def test_list_tasks_respects_limit(self):
|
||||||
|
store = TaskStore()
|
||||||
|
for i in range(5):
|
||||||
|
store.create(f"t{i}", "agent_a", {})
|
||||||
|
tasks = store.list_tasks(limit=3)
|
||||||
|
assert len(tasks) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════
|
||||||
|
# BackgroundRunner Tests
|
||||||
|
# ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackgroundRunner:
|
||||||
|
"""BackgroundRunner async task execution tests"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def task_store(self):
|
||||||
|
return TaskStore()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def runner(self, task_store):
|
||||||
|
return BackgroundRunner(task_store=task_store, max_concurrent=5)
|
||||||
|
|
||||||
|
def _make_mock_agent(self, name="test_agent", output=None, raise_error=None):
|
||||||
|
"""Create a mock agent for testing"""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.name = name
|
||||||
|
agent.agent_type = "test_type"
|
||||||
|
if raise_error:
|
||||||
|
agent.execute = AsyncMock(side_effect=raise_error)
|
||||||
|
else:
|
||||||
|
task_result = TaskResult(
|
||||||
|
task_id="mock",
|
||||||
|
agent_name=name,
|
||||||
|
status="completed",
|
||||||
|
output_data=output or {"result": "ok"},
|
||||||
|
error_message=None,
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
agent.execute = AsyncMock(return_value=task_result)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_returns_task_id_immediately(self, runner, task_store):
|
||||||
|
agent = self._make_mock_agent()
|
||||||
|
task_id = await runner.submit(agent, {"query": "test"})
|
||||||
|
assert task_id is not None
|
||||||
|
assert isinstance(task_id, str)
|
||||||
|
# Task record should exist in store
|
||||||
|
record = task_store.get(task_id)
|
||||||
|
assert record is not None
|
||||||
|
assert record.status == TaskStatus.PENDING
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_task_runs_to_completion(self, runner, task_store):
|
||||||
|
agent = self._make_mock_agent(output={"answer": "42"})
|
||||||
|
task_id = await runner.submit(agent, {"query": "meaning of life"})
|
||||||
|
# Wait for task to complete
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
record = task_store.get(task_id)
|
||||||
|
assert record is not None
|
||||||
|
assert record.status == TaskStatus.COMPLETED
|
||||||
|
assert record.output_data == {"answer": "42"}
|
||||||
|
assert record.progress == 1.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_task_failure_recorded(self, runner, task_store):
|
||||||
|
agent = self._make_mock_agent(raise_error=RuntimeError("boom"))
|
||||||
|
task_id = await runner.submit(agent, {"query": "fail"})
|
||||||
|
# Wait for task to fail
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
record = task_store.get(task_id)
|
||||||
|
assert record is not None
|
||||||
|
assert record.status == TaskStatus.FAILED
|
||||||
|
assert "boom" in record.error_message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_running_task(self, runner, task_store):
|
||||||
|
async def slow_execute(msg):
|
||||||
|
await asyncio.sleep(10) # Long running
|
||||||
|
return TaskResult(
|
||||||
|
task_id=msg.task_id,
|
||||||
|
agent_name="test_agent",
|
||||||
|
status="completed",
|
||||||
|
output_data={"result": "done"},
|
||||||
|
error_message=None,
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.name = "slow_agent"
|
||||||
|
agent.agent_type = "test_type"
|
||||||
|
agent.execute = AsyncMock(side_effect=slow_execute)
|
||||||
|
|
||||||
|
task_id = await runner.submit(agent, {"query": "slow"})
|
||||||
|
# Give it a moment to start
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
cancelled = await runner.cancel(task_id)
|
||||||
|
assert cancelled is True
|
||||||
|
record = task_store.get(task_id)
|
||||||
|
assert record.status == TaskStatus.CANCELLED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_non_running_task_returns_false(self, runner, task_store):
|
||||||
|
result = await runner.cancel("nonexistent")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_tasks_respects_semaphore(self, task_store):
|
||||||
|
runner = BackgroundRunner(task_store=task_store, max_concurrent=2)
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
async def tracked_execute(msg):
|
||||||
|
execution_order.append(f"start:{msg.task_id}")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
execution_order.append(f"end:{msg.task_id}")
|
||||||
|
return TaskResult(
|
||||||
|
task_id=msg.task_id,
|
||||||
|
agent_name="test",
|
||||||
|
status="completed",
|
||||||
|
output_data={},
|
||||||
|
error_message=None,
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
agents = []
|
||||||
|
for i in range(4):
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.name = f"agent_{i}"
|
||||||
|
agent.agent_type = "test_type"
|
||||||
|
agent.execute = AsyncMock(side_effect=tracked_execute)
|
||||||
|
agents.append(agent)
|
||||||
|
|
||||||
|
# Submit all 4 tasks
|
||||||
|
task_ids = []
|
||||||
|
for agent in agents:
|
||||||
|
tid = await runner.submit(agent, {"idx": agents.index(agent)})
|
||||||
|
task_ids.append(tid)
|
||||||
|
|
||||||
|
# Wait for all to complete
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# All tasks should have completed
|
||||||
|
for tid in task_ids:
|
||||||
|
record = task_store.get(tid)
|
||||||
|
assert record.status == TaskStatus.COMPLETED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_active_count_tracks_running(self, task_store):
|
||||||
|
runner = BackgroundRunner(task_store=task_store, max_concurrent=10)
|
||||||
|
|
||||||
|
async def slow_execute(msg):
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
return TaskResult(
|
||||||
|
task_id=msg.task_id,
|
||||||
|
agent_name="test",
|
||||||
|
status="completed",
|
||||||
|
output_data={},
|
||||||
|
error_message=None,
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.name = "slow_agent"
|
||||||
|
agent.agent_type = "test_type"
|
||||||
|
agent.execute = AsyncMock(side_effect=slow_execute)
|
||||||
|
|
||||||
|
await runner.submit(agent, {})
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
assert runner.active_count >= 1
|
||||||
|
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
assert runner.active_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════
|
||||||
|
# API Tests (using TestClient)
|
||||||
|
# ═══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncTaskAPI:
|
||||||
|
"""Async task API endpoint tests"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm_gateway(self):
|
||||||
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||||
|
|
||||||
|
gateway = LLMGateway()
|
||||||
|
mock_provider = AsyncMock()
|
||||||
|
mock_provider.chat.return_value = LLMResponse(
|
||||||
|
content='{"result": "mocked"}',
|
||||||
|
model="test-model",
|
||||||
|
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
|
||||||
|
)
|
||||||
|
gateway.register_provider("test", mock_provider)
|
||||||
|
return gateway
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def skill_registry(self):
|
||||||
|
from agentkit.skills.registry import SkillRegistry
|
||||||
|
return SkillRegistry()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tool_registry(self):
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
return ToolRegistry()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(self, mock_llm_gateway, skill_registry, tool_registry):
|
||||||
|
from agentkit.server.app import create_app
|
||||||
|
return create_app(
|
||||||
|
llm_gateway=mock_llm_gateway,
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
tool_registry=tool_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(self, app):
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
def _register_skill_and_create_agent(self, client, skill_registry):
|
||||||
|
"""Helper: register a skill and create an agent for it"""
|
||||||
|
from agentkit.skills.base import Skill, SkillConfig
|
||||||
|
|
||||||
|
skill_config = SkillConfig(
|
||||||
|
name="async_skill",
|
||||||
|
agent_type="async_type",
|
||||||
|
task_mode="llm_generate",
|
||||||
|
prompt={"identity": "Async Skill", "instructions": "Handle async"},
|
||||||
|
intent={"keywords": ["async"], "description": "Async skill"},
|
||||||
|
)
|
||||||
|
skill = Skill(config=skill_config)
|
||||||
|
skill_registry.register(skill)
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/agents",
|
||||||
|
json={"skill_name": "async_skill"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
return "async_skill"
|
||||||
|
|
||||||
|
def test_submit_task_async_returns_task_id(self, client, skill_registry):
|
||||||
|
agent_name = self._register_skill_and_create_agent(client, skill_registry)
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/tasks",
|
||||||
|
json={
|
||||||
|
"input_data": {"query": "async test"},
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"mode": "async",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert "task_id" in data
|
||||||
|
assert data["status"] == "pending"
|
||||||
|
assert data["mode"] == "async"
|
||||||
|
|
||||||
|
def test_get_task_status_returns_record(self, client, skill_registry):
|
||||||
|
agent_name = self._register_skill_and_create_agent(client, skill_registry)
|
||||||
|
# Submit async task
|
||||||
|
submit_resp = client.post(
|
||||||
|
"/api/v1/tasks",
|
||||||
|
json={
|
||||||
|
"input_data": {"query": "status test"},
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"mode": "async",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
task_id = submit_resp.json()["task_id"]
|
||||||
|
|
||||||
|
# Wait a bit for completion
|
||||||
|
import time
|
||||||
|
time.sleep(0.3)
|
||||||
|
|
||||||
|
# Get status
|
||||||
|
response = client.get(f"/api/v1/tasks/{task_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["task_id"] == task_id
|
||||||
|
assert data["status"] in ("completed", "running", "pending")
|
||||||
|
|
||||||
|
def test_get_task_status_not_found_404(self, client):
|
||||||
|
response = client.get("/api/v1/tasks/nonexistent-id")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
def test_cancel_task(self, client, skill_registry):
|
||||||
|
agent_name = self._register_skill_and_create_agent(client, skill_registry)
|
||||||
|
# Submit async task
|
||||||
|
submit_resp = client.post(
|
||||||
|
"/api/v1/tasks",
|
||||||
|
json={
|
||||||
|
"input_data": {"query": "cancel test"},
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"mode": "async",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
task_id = submit_resp.json()["task_id"]
|
||||||
|
|
||||||
|
# Try to cancel (may or may not succeed depending on timing)
|
||||||
|
response = client.post(f"/api/v1/tasks/{task_id}/cancel")
|
||||||
|
# Either cancelled or 400 (already completed)
|
||||||
|
assert response.status_code in (200, 400)
|
||||||
|
|
||||||
|
def test_list_tasks(self, client, skill_registry):
|
||||||
|
agent_name = self._register_skill_and_create_agent(client, skill_registry)
|
||||||
|
# Submit an async task to ensure at least one exists
|
||||||
|
client.post(
|
||||||
|
"/api/v1/tasks",
|
||||||
|
json={
|
||||||
|
"input_data": {"query": "list test"},
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"mode": "async",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response = client.get("/api/v1/tasks")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
|
||||||
|
def test_list_tasks_filter_by_status(self, client, skill_registry):
|
||||||
|
agent_name = self._register_skill_and_create_agent(client, skill_registry)
|
||||||
|
# Submit an async task
|
||||||
|
client.post(
|
||||||
|
"/api/v1/tasks",
|
||||||
|
json={
|
||||||
|
"input_data": {"query": "filter test"},
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"mode": "async",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response = client.get("/api/v1/tasks?status=completed")
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert isinstance(data, list)
|
||||||
|
# All returned tasks should be completed
|
||||||
|
for task in data:
|
||||||
|
assert task["status"] == "completed"
|
||||||
|
|
||||||
|
def test_sync_mode_still_works(self, client, skill_registry):
|
||||||
|
"""Ensure existing sync mode is not broken"""
|
||||||
|
agent_name = self._register_skill_and_create_agent(client, skill_registry)
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/tasks",
|
||||||
|
json={
|
||||||
|
"input_data": {"query": "sync test"},
|
||||||
|
"agent_name": agent_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
# Sync mode returns task_id and output
|
||||||
|
assert "task_id" in data
|
||||||
Loading…
Reference in New Issue