367 lines
14 KiB
Python
367 lines
14 KiB
Python
"""任务分发器 - 通过 Redis Queue 将任务分发给 Agent"""
|
||
|
||
import json
import logging
import uuid
from datetime import datetime, timezone

import redis.asyncio as aioredis
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.agent_framework.exceptions import (
    NoAvailableAgentError,
    TaskDispatchError,
    TaskNotFoundError,
)
from app.agent_framework.protocol import (
    AgentStatus,
    TaskMessage,
    TaskProgress,
    TaskResult,
    TaskStatus,
)
from app.database import AsyncSessionLocal
from app.models.agent import AgentRegistry as AgentRegistryModel
from app.models.agent import AgentTask as AgentTaskModel
from app.models.agent import AgentTaskLog as AgentTaskLogModel
|
||
|
||
# Logger for this module, named after the module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class TaskDispatcher:
    """Dispatch tasks to agents through per-agent Redis list queues.

    Task state is persisted via ``AgentTaskModel`` and an audit trail is written
    via ``AgentTaskLogModel``. Each agent's queue is the Redis list
    ``agent:{agent_name}:tasks``; messages are pushed with ``LPUSH`` as JSON.
    """

    # Prefix of the log message written for each retry; also used to count
    # previous retries, so the two must stay in sync.
    _RETRY_LOG_PREFIX = "Task retry attempt"

    def __init__(self, redis_url: str):
        """Store the Redis URL; the client is created lazily by ``_get_redis``."""
        self._redis_url = redis_url
        self._redis: aioredis.Redis | None = None

    async def _get_redis(self) -> aioredis.Redis:
        """Return the shared Redis client, creating it on first use.

        The client is built with ``decode_responses=True``.
        """
        if self._redis is None:
            self._redis = aioredis.from_url(
                self._redis_url,
                decode_responses=True,
            )
        return self._redis

    async def close(self):
        """Close the Redis client, if one was created, and drop the reference."""
        if self._redis:
            await self._redis.close()
            self._redis = None

    async def dispatch(
        self,
        task: TaskMessage,
        organization_id: str | None = None,
        created_by: str | None = None,
    ) -> str:
        """Persist ``task`` and push it onto its agent's Redis queue.

        Steps:
          1. Look up the target agent by ``task.agent_name``; it must be ONLINE.
          2. Insert an ``agent_tasks`` row with status PENDING and commit it.
          3. ``LPUSH`` the serialized message onto ``agent:{agent_name}:tasks``.
             If the push fails, the already-committed row is marked FAILED so it
             does not sit in PENDING without ever having been queued.

        Args:
            task: Message to dispatch; ``task.task_id`` must be a UUID string.
            organization_id: Optional organization UUID string.
            created_by: Optional UUID string of the creating user.

        Returns:
            ``task.task_id``.

        Raises:
            TaskDispatchError: If the agent is missing or not online, or if any
                database/Redis step fails (the underlying error is chained).
        """
        async with AsyncSessionLocal() as db:
            try:
                # Find the target agent.
                stmt = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == task.agent_name
                )
                result = await db.execute(stmt)
                agent = result.scalar_one_or_none()

                if not agent:
                    raise TaskDispatchError(
                        task.task_id, f"Agent '{task.agent_name}' not found"
                    )

                if agent.status != AgentStatus.ONLINE:
                    raise TaskDispatchError(
                        task.task_id,
                        f"Agent '{task.agent_name}' is not online (status={agent.status})",
                    )

                # Persist the task row first so a fast agent never reports a
                # result for a task the database does not know about yet.
                task_id_uuid = uuid.UUID(task.task_id)
                agent_task = AgentTaskModel(
                    id=task_id_uuid,
                    agent_id=agent.id,
                    task_type=task.task_type,
                    status=TaskStatus.PENDING,
                    priority=task.priority,
                    input_data=task.input_data,
                    # NOTE(review): without an explicit organization_id this falls
                    # back to the *agent's* id, which is unlikely to be a real
                    # organization id. Kept for compatibility; confirm intent.
                    organization_id=uuid.UUID(organization_id) if organization_id else agent.id,
                    created_by=uuid.UUID(created_by) if created_by else None,
                )
                db.add(agent_task)
                await db.commit()

                # Push to the agent's Redis queue.
                queue_key = f"agent:{task.agent_name}:tasks"
                try:
                    redis = await self._get_redis()
                    await redis.lpush(queue_key, json.dumps(task.to_dict()))
                except Exception as push_error:
                    # The row is already committed, so a rollback cannot undo it;
                    # flag it FAILED instead of leaving an unqueued PENDING task.
                    await self._mark_enqueue_failed(db, task_id_uuid, push_error)
                    raise

                logger.info(
                    f"Task {task.task_id} dispatched to agent '{task.agent_name}' "
                    f"(type={task.task_type}, priority={task.priority})"
                )
                return task.task_id

            except TaskDispatchError:
                raise
            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to dispatch task {task.task_id}: {e}")
                raise TaskDispatchError(task.task_id, str(e)) from e

    async def _mark_enqueue_failed(
        self, db: AsyncSession, task_id: uuid.UUID, error: Exception
    ) -> None:
        """Best-effort: mark a committed-but-unqueued task as FAILED.

        Uses a bulk UPDATE (no session synchronization) so it does not depend on
        the state of ORM instances after the earlier commit. Failures here are
        logged and suppressed so the caller still sees the original error.
        """
        try:
            await db.execute(
                update(AgentTaskModel)
                .where(AgentTaskModel.id == task_id)
                .values(
                    status=TaskStatus.FAILED,
                    error_message=f"Failed to enqueue task: {error}",
                    completed_at=datetime.now(timezone.utc),
                )
                .execution_options(synchronize_session=False)
            )
            await db.commit()
        except Exception:
            await db.rollback()
            logger.exception(f"Failed to mark task {task_id} as failed after enqueue error")

    async def cancel_task(self, task_id: str):
        """Mark a task as CANCELLED and record a log entry.

        Tasks already COMPLETED, FAILED or CANCELLED are left unchanged (a
        warning is logged). The status change and its log entry are committed
        in a single transaction.

        NOTE(review): the queued Redis message is not removed, so an agent may
        still run the task; ``handle_result`` ignores results for tasks that are
        already CANCELLED.

        Raises:
            TaskNotFoundError: If no task with this id exists.
            ValueError: If ``task_id`` is not a valid UUID string.
        """
        async with AsyncSessionLocal() as db:
            try:
                task_uuid = uuid.UUID(task_id)
                stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
                result = await db.execute(stmt)
                task = result.scalar_one_or_none()

                if not task:
                    raise TaskNotFoundError(task_id)

                if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
                    logger.warning(
                        f"Cannot cancel task {task_id} with status '{task.status}'"
                    )
                    return

                task.status = TaskStatus.CANCELLED
                task.completed_at = datetime.now(timezone.utc)

                # Read task.agent_id before committing so no post-commit
                # attribute refresh is needed; one commit keeps both writes atomic.
                await self._write_log(
                    db, task_id=task_id, agent_id=str(task.agent_id),
                    log_level="info", message="Task cancelled by user",
                )
                await db.commit()

                logger.info(f"Task {task_id} cancelled")

            except TaskNotFoundError:
                raise
            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to cancel task {task_id}: {e}")
                raise

    async def get_task_status(self, task_id: str) -> dict:
        """Return the task's current state as a dict (see ``_task_to_dict``).

        Raises:
            TaskNotFoundError: If no task with this id exists.
            ValueError: If ``task_id`` is not a valid UUID string.
        """
        async with AsyncSessionLocal() as db:
            task_uuid = uuid.UUID(task_id)
            stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
            result = await db.execute(stmt)
            task = result.scalar_one_or_none()

            if not task:
                raise TaskNotFoundError(task_id)

            return self._task_to_dict(task)

    async def handle_result(self, result: TaskResult):
        """Apply a result reported by an agent to its task row.

        Updates status, output, error and timestamps, writes a log entry (both
        in one commit), then triggers ``output_data["callback_url"]`` if present.
        Results for unknown tasks or tasks already CANCELLED are ignored. All
        errors are logged and swallowed so a consumer loop keeps running.
        """
        async with AsyncSessionLocal() as db:
            try:
                task_uuid = uuid.UUID(result.task_id)
                stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
                db_result = await db.execute(stmt)
                task = db_result.scalar_one_or_none()

                if not task:
                    logger.error(f"Task {result.task_id} not found when handling result")
                    return

                # A late result must not resurrect a cancelled task (and a late
                # failure must not make it eligible for retry_failed_tasks).
                if task.status == TaskStatus.CANCELLED:
                    logger.warning(
                        f"Ignoring result for cancelled task {result.task_id} "
                        f"(reported status={result.status})"
                    )
                    return

                # Update the task row.
                task.status = result.status
                task.output_data = result.output_data
                task.error_message = result.error_message
                task.started_at = result.started_at
                task.completed_at = result.completed_at

                # Write the log entry in the same transaction as the update.
                log_level = "info" if result.status == TaskStatus.COMPLETED else "error"
                log_message = (
                    f"Task {result.status}"
                    if result.status == TaskStatus.COMPLETED
                    else f"Task failed: {result.error_message}"
                )
                await self._write_log(
                    db,
                    task_id=result.task_id,
                    agent_id=str(task.agent_id),
                    log_level=log_level,
                    message=log_message,
                    extra_metadata=result.metrics,
                )
                await db.commit()

                # Trigger the callback, if the agent's output provided one.
                if result.output_data and result.output_data.get("callback_url"):
                    await self._trigger_callback(result.output_data["callback_url"], result)

                logger.info(
                    f"Task {result.task_id} result handled (status={result.status})"
                )

            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to handle result for task {result.task_id}: {e}")

    async def handle_progress(self, progress: TaskProgress):
        """Record a progress report from an agent as a task log entry.

        Reports from agents not found in the registry are dropped with a
        warning. All errors are logged and swallowed.
        """
        async with AsyncSessionLocal() as db:
            try:
                # Resolve the reporting agent's id.
                stmt = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == progress.agent_name
                )
                result = await db.execute(stmt)
                agent = result.scalar_one_or_none()

                if not agent:
                    logger.warning(f"Agent '{progress.agent_name}' not found for progress report")
                    return

                # Formatted with :.0%, so progress is presumably a 0..1 fraction.
                await self._write_log(
                    db,
                    task_id=progress.task_id,
                    agent_id=str(agent.id),
                    log_level="info",
                    message=f"Progress: {progress.progress:.0%} - {progress.message}",
                    extra_metadata={
                        "progress": progress.progress,
                        "updated_at": progress.updated_at.isoformat(),
                    },
                )
                await db.commit()

            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to handle progress for task {progress.task_id}: {e}")

    async def retry_failed_tasks(self, max_retries: int = 3):
        """Re-queue FAILED tasks that have been retried fewer than ``max_retries`` times.

        A task's retry count is the number of its log entries whose message
        starts with ``_RETRY_LOG_PREFIX`` (written by this method). A task is
        reset to PENDING only when its agent exists, is ONLINE and the message
        was pushed to Redis; otherwise it stays FAILED for a later run. All
        database changes are committed once at the end. Errors are logged and
        swallowed.
        """
        async with AsyncSessionLocal() as db:
            try:
                stmt = select(AgentTaskModel).where(
                    AgentTaskModel.status == TaskStatus.FAILED,
                )
                result = await db.execute(stmt)
                failed_tasks = result.scalars().all()

                # Agents looked up during this run, keyed by agent id.
                agents_by_id: dict = {}
                retried = 0
                for task in failed_tasks:
                    # Count earlier retries in SQL rather than loading log rows;
                    # match only our own prefix so agent error text that merely
                    # mentions "retry" is not counted.
                    count_stmt = (
                        select(func.count())
                        .select_from(AgentTaskLogModel)
                        .where(
                            AgentTaskLogModel.task_id == task.id,
                            AgentTaskLogModel.message.like(f"{self._RETRY_LOG_PREFIX}%"),
                        )
                    )
                    retry_count = (await db.execute(count_stmt)).scalar_one()
                    if retry_count >= max_retries:
                        continue

                    if task.agent_id not in agents_by_id:
                        agent_stmt = select(AgentRegistryModel).where(
                            AgentRegistryModel.id == task.agent_id
                        )
                        agents_by_id[task.agent_id] = (
                            await db.execute(agent_stmt)
                        ).scalar_one_or_none()
                    agent = agents_by_id[task.agent_id]

                    # Leave the task FAILED if it cannot be re-queued right now;
                    # resetting it to PENDING here would strand it forever.
                    if not agent or agent.status != AgentStatus.ONLINE:
                        continue

                    task_msg = TaskMessage(
                        task_id=str(task.id),
                        agent_name=agent.name,
                        task_type=task.task_type,
                        priority=task.priority,
                        input_data=task.input_data or {},
                        callback_url=None,
                        created_at=datetime.now(timezone.utc),
                    )
                    redis = await self._get_redis()
                    queue_key = f"agent:{agent.name}:tasks"
                    await redis.lpush(queue_key, json.dumps(task_msg.to_dict()))

                    # Reset the task state only after it was actually queued.
                    task.status = TaskStatus.PENDING
                    task.error_message = None
                    task.started_at = None
                    task.completed_at = None

                    await self._write_log(
                        db,
                        task_id=str(task.id),
                        agent_id=str(agent.id),
                        log_level="info",
                        message=f"{self._RETRY_LOG_PREFIX} {retry_count + 1}/{max_retries}",
                    )
                    retried += 1

                await db.commit()
                if retried > 0:
                    logger.info(f"Retried {retried} failed tasks")

            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to retry failed tasks: {e}")

    async def _write_log(
        self,
        db: AsyncSession,
        task_id: str,
        agent_id: str,
        log_level: str,
        message: str,
        extra_metadata: dict | None = None,
    ):
        """Add a task log entry to ``db``; the caller is responsible for committing.

        ``task_id`` and ``agent_id`` are UUID strings (parsed with ``uuid.UUID``).
        """
        log_entry = AgentTaskLogModel(
            task_id=uuid.UUID(task_id),
            agent_id=uuid.UUID(agent_id),
            log_level=log_level,
            message=message,
            extra_metadata=extra_metadata,
        )
        db.add(log_entry)

    async def _trigger_callback(self, callback_url: str, result: TaskResult):
        """POST ``result.to_dict()`` as JSON to ``callback_url`` (best-effort).

        Network errors and non-2xx responses are logged as warnings, never raised.
        NOTE(review): the URL comes from agent-supplied output data; consider an
        allow-list if agents are not fully trusted (SSRF risk).
        """
        try:
            import httpx  # local import: only needed when a callback is configured

            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.post(callback_url, json=result.to_dict())
                # Without this, a 4xx/5xx reply would be logged as a success.
                response.raise_for_status()
            logger.info(f"Callback triggered for task {result.task_id}")
        except Exception as e:
            logger.warning(f"Callback failed for task {result.task_id}: {e}")

    def _task_to_dict(self, task: AgentTaskModel) -> dict:
        """Convert a task ORM object to a JSON-friendly dict.

        UUIDs become strings and datetimes become ISO-8601 strings; absent
        optional values become ``None``.
        """
        return {
            "id": str(task.id),
            "agent_id": str(task.agent_id),
            "task_type": task.task_type,
            "status": task.status,
            "priority": task.priority,
            "input_data": task.input_data,
            "output_data": task.output_data,
            "error_message": task.error_message,
            "created_by": str(task.created_by) if task.created_by else None,
            "organization_id": str(task.organization_id),
            "project_id": str(task.project_id) if task.project_id else None,
            "scheduled_at": task.scheduled_at.isoformat() if task.scheduled_at else None,
            "started_at": task.started_at.isoformat() if task.started_at else None,
            "completed_at": task.completed_at.isoformat() if task.completed_at else None,
            "created_at": task.created_at.isoformat() if task.created_at else None,
        }
|