171 lines
5.8 KiB
Python
171 lines
5.8 KiB
Python
"""BackgroundRunner - Async task execution with lifecycle management"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
|
from agentkit.server.task_store import TaskStore
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class BackgroundRunner:
    """Runs tasks in background asyncio tasks with lifecycle management.

    Integrates with AgentPool for agent execution and TaskStore for state tracking.

    Each submitted task gets a fresh UUID4 id, a record in the ``TaskStore``,
    and an ``asyncio.Task`` that executes the agent. At most
    ``max_concurrent`` tasks execute inside the semaphore-guarded section at
    once; the rest wait for a slot. Status updates written to the store are
    RUNNING, then COMPLETED or FAILED, or CANCELLED via :meth:`cancel`.
    """

    def __init__(self, task_store: TaskStore, max_concurrent: int = 10):
        """Create a runner.

        Args:
            task_store: Store used to create task records and update status.
            max_concurrent: Maximum number of tasks allowed to execute at the
                same time (size of the internal semaphore).
        """
        self._task_store = task_store
        self._max_concurrent = max_concurrent
        # Holds a strong reference to every unfinished asyncio.Task (asyncio
        # only keeps weak references to tasks) and lets cancel() find a task
        # by id. Entries are removed by the done-callback set up in submit().
        self._running_tasks: dict[str, asyncio.Task] = {}
        self._semaphore = asyncio.Semaphore(max_concurrent)

    @property
    def active_count(self) -> int:
        """Number of submitted tasks that have not finished yet.

        This includes tasks still waiting for a concurrency slot, not only
        those currently executing an agent.
        """
        return len(self._running_tasks)

    async def submit(
        self,
        agent,  # ConfigDrivenAgent
        input_data: dict[str, Any],
        skill_name: str | None = None,
        quality_gate=None,
        output_standardizer=None,
        skill=None,
    ) -> str:
        """Submit a task for background execution.

        Returns task_id immediately; the work runs in a new asyncio task, so
        this must be called from within a running event loop.

        Args:
            agent: Agent to run; its ``name`` is recorded on the task record.
            input_data: Input payload forwarded to the agent.
            skill_name: Optional skill name recorded on the task record.
            quality_gate: Optional validator; only used when ``skill`` is set.
            output_standardizer: Optional output formatter; only used when
                ``skill`` is set.
            skill: Skill definition passed to the gate and the standardizer.

        Returns:
            The generated task id (a UUID4 string).
        """
        task_id = str(uuid.uuid4())

        # Create task record
        self._task_store.create(
            task_id=task_id,
            agent_name=agent.name,
            input_data=input_data,
            skill_name=skill_name,
        )

        # Launch background asyncio task
        asyncio_task = asyncio.create_task(
            self._run_task(
                task_id=task_id,
                agent=agent,
                input_data=input_data,
                quality_gate=quality_gate,
                output_standardizer=output_standardizer,
                skill=skill,
            )
        )
        self._running_tasks[task_id] = asyncio_task

        def _on_done(t: asyncio.Task) -> None:
            # Clean up our reference now that the task has finished.
            self._running_tasks.pop(task_id, None)
            # Task.exception() raises CancelledError on a cancelled task, so
            # cancellation must be checked first; otherwise this callback
            # itself raises and the loop reports a spurious error.
            if t.cancelled():
                logger.debug("Background task %s was cancelled", task_id)
                return
            exc = t.exception()
            if exc is not None:
                logger.error("Background task %s failed: %s", task_id, exc)

        asyncio_task.add_done_callback(_on_done)

        return task_id

    async def _run_task(
        self,
        task_id: str,
        agent,
        input_data: dict,
        quality_gate=None,
        output_standardizer=None,
        skill=None,
    ) -> dict[str, Any]:
        """Execute task in background with semaphore control.

        Marks the task RUNNING once a slot is acquired, runs the agent, then
        optionally validates and standardizes its output (both only when
        ``skill`` is set) before marking the task COMPLETED. Failures in the
        quality gate or the standardizer are logged and tolerated. Any other
        exception inside the main ``try`` marks the task FAILED and is
        re-raised.

        Returns:
            The final output, or ``{}`` if it is empty or None.
        """
        async with self._semaphore:
            self._task_store.update_status(
                task_id, TaskStatus.RUNNING,
                started_at=datetime.now(timezone.utc),
            )

            try:
                task_msg = TaskMessage(
                    task_id=task_id,
                    agent_name=agent.name,
                    task_type=agent.agent_type,
                    priority=0,
                    input_data=input_data,
                    callback_url=None,
                    created_at=datetime.now(timezone.utc),
                )

                task_result = await agent.execute(task_msg)

                quality_result = None
                if skill and quality_gate:
                    quality_result = await self._run_quality_gate(
                        task_id, quality_gate, task_result.output_data, skill
                    )

                final_output = task_result.output_data
                if skill and output_standardizer:
                    final_output = await self._standardize_output(
                        task_id,
                        output_standardizer,
                        task_result.output_data,
                        skill,
                        quality_result,
                    )

                self._task_store.update_status(
                    task_id, TaskStatus.COMPLETED,
                    output_data=final_output,
                    completed_at=datetime.now(timezone.utc),
                    progress=1.0,
                    progress_message="Completed",
                )

                return final_output or {}

            # CancelledError derives from BaseException (Python 3.8+), so a
            # cancellation is not caught here and the CANCELLED status written
            # by cancel() is not overwritten with FAILED.
            except Exception as e:
                # logger.exception keeps the traceback, unlike logger.error.
                logger.exception("Task %s failed: %s", task_id, e)
                self._task_store.update_status(
                    task_id, TaskStatus.FAILED,
                    error_message=str(e),
                    completed_at=datetime.now(timezone.utc),
                )
                raise

    @staticmethod
    async def _run_quality_gate(task_id: str, quality_gate, raw_output, skill):
        """Validate raw agent output; return the gate's result, or None if it raised."""
        try:
            return await quality_gate.validate(raw_output or {}, skill)
        except Exception as e:
            logger.warning("Quality gate failed for %s: %s", task_id, e)
            return None

    @staticmethod
    async def _standardize_output(
        task_id: str, output_standardizer, raw_output, skill, quality_result
    ):
        """Wrap raw agent output in the standard envelope; return raw output on any failure."""
        try:
            standard_output = await output_standardizer.standardize(
                raw_output=raw_output or {},
                skill=skill,
                quality_result=quality_result,
            )
            return {
                "skill_name": standard_output.skill_name,
                "data": standard_output.data,
                "metadata": {
                    "version": standard_output.metadata.version,
                    "produced_at": standard_output.metadata.produced_at.isoformat(),
                    "quality_score": standard_output.metadata.quality_score,
                },
            }
        except Exception as e:
            # Best-effort: fall back to the unmodified agent output.
            logger.warning("Output standardization failed for %s: %s", task_id, e)
            return raw_output

    async def cancel(self, task_id: str) -> bool:
        """Cancel a running task.

        Requests cancellation of the asyncio task and immediately marks the
        store record CANCELLED. The coroutine sees the CancelledError on a
        later event-loop iteration.

        Returns:
            True if cancellation was requested; False if ``task_id`` is not
            tracked or its task has already finished.
        """
        asyncio_task = self._running_tasks.get(task_id)
        if asyncio_task is None or asyncio_task.done():
            return False

        asyncio_task.cancel()
        self._task_store.update_status(
            task_id, TaskStatus.CANCELLED,
            completed_at=datetime.now(timezone.utc),
        )
        return True
|