"""Task dispatcher - distributes tasks to Agents through Redis queues."""

import json
import logging
import uuid
from datetime import datetime, timezone

import redis.asyncio as aioredis
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.agent_framework.exceptions import (
    NoAvailableAgentError,
    TaskDispatchError,
    TaskNotFoundError,
)
from app.agent_framework.protocol import (
    AgentStatus,
    TaskMessage,
    TaskProgress,
    TaskResult,
    TaskStatus,
)
from app.database import AsyncSessionLocal
from app.models.agent import AgentRegistry as AgentRegistryModel
from app.models.agent import AgentTask as AgentTaskModel
from app.models.agent import AgentTaskLog as AgentTaskLogModel

logger = logging.getLogger(__name__)

# Statuses from which a task can no longer be cancelled.
_TERMINAL_STATUSES = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)

# LIKE pattern matching the log message written by retry_failed_tasks for each
# retry attempt ("Task retry attempt N/M"); used to count previous retries.
_RETRY_LOG_PATTERN = "Task retry attempt%"


class TaskDispatcher:
    """Dispatch tasks to agents via per-agent Redis queues and track them in the DB."""

    def __init__(self, redis_url: str):
        # NOTE(review): stored but not used by this class; connections come from
        # app.core.redis.get_redis(). Kept for interface compatibility.
        self._redis_url = redis_url

    async def _get_redis(self):
        """Return the shared Redis connection from ``app.core.redis.get_redis``."""
        from app.core.redis import get_redis

        return await get_redis()

    async def close(self):
        """No-op: the Redis connection belongs to the global pool, not this object."""
        pass

    @staticmethod
    def _queue_key(agent_name: str) -> str:
        """Return the Redis list key holding pending tasks for ``agent_name``."""
        return f"agent:{agent_name}:tasks"

    async def _enqueue(self, task: TaskMessage) -> None:
        """LPUSH the JSON-serialized ``task`` onto its agent's Redis queue."""
        redis = await self._get_redis()
        await redis.lpush(self._queue_key(task.agent_name), json.dumps(task.to_dict()))

    @staticmethod
    async def _get_task(db: AsyncSession, task_uuid: uuid.UUID) -> AgentTaskModel | None:
        """Return the ``agent_tasks`` row with id ``task_uuid``, or ``None``."""
        stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
        return (await db.execute(stmt)).scalar_one_or_none()

    async def dispatch(
        self,
        task: TaskMessage,
        organization_id: str | None = None,
        created_by: str | None = None,
    ) -> str:
        """Persist ``task`` and push it onto its agent's queue; return the task id.

        1. Insert a row into ``agent_tasks`` with status PENDING.
        2. LPUSH the serialized task onto ``agent:{agent_name}:tasks``.

        Args:
            task: The task message; ``task.task_id`` must be a UUID string.
            organization_id: Optional organization UUID string.
            created_by: Optional creator UUID string.

        Returns:
            ``task.task_id``.

        Raises:
            TaskDispatchError: The agent is unknown or not ONLINE, the row could
                not be written, or the task could not be enqueued. In the last
                case the committed row is marked FAILED so it does not stay
                PENDING with nothing in the queue.
        """
        async with AsyncSessionLocal() as db:
            try:
                stmt = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == task.agent_name
                )
                agent = (await db.execute(stmt)).scalar_one_or_none()

                if not agent:
                    raise TaskDispatchError(
                        task.task_id, f"Agent '{task.agent_name}' not found"
                    )
                if agent.status != AgentStatus.ONLINE:
                    raise TaskDispatchError(
                        task.task_id,
                        f"Agent '{task.agent_name}' is not online (status={agent.status})",
                    )

                task_id_uuid = uuid.UUID(task.task_id)
                agent_task = AgentTaskModel(
                    id=task_id_uuid,
                    agent_id=agent.id,
                    task_type=task.task_type,
                    status=TaskStatus.PENDING,
                    priority=task.priority,
                    input_data=task.input_data,
                    # NOTE(review): falls back to the *agent's* id when no
                    # organization is given — confirm this is intended.
                    organization_id=uuid.UUID(organization_id) if organization_id else agent.id,
                    created_by=uuid.UUID(created_by) if created_by else None,
                )
                db.add(agent_task)
                await db.commit()
            except TaskDispatchError:
                raise
            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to dispatch task {task.task_id}: {e}")
                raise TaskDispatchError(task.task_id, str(e)) from e

            # The row is committed at this point; if enqueueing fails, record the
            # failure on the row instead of leaving an orphaned PENDING task.
            try:
                await self._enqueue(task)
            except Exception as e:
                logger.exception(f"Failed to dispatch task {task.task_id}: {e}")
                try:
                    await db.execute(
                        update(AgentTaskModel)
                        .where(AgentTaskModel.id == task_id_uuid)
                        .values(
                            status=TaskStatus.FAILED,
                            error_message=f"Failed to enqueue task: {e}",
                            completed_at=datetime.now(timezone.utc),
                        )
                    )
                    await db.commit()
                except Exception:
                    await db.rollback()
                    logger.exception(
                        f"Failed to mark task {task.task_id} as failed after enqueue error"
                    )
                raise TaskDispatchError(task.task_id, str(e)) from e

            logger.info(
                f"Task {task.task_id} dispatched to agent '{task.agent_name}' "
                f"(type={task.task_type}, priority={task.priority})"
            )
            return task.task_id

    async def cancel_task(self, task_id: str):
        """Mark a task CANCELLED and write a task log entry.

        Tasks already COMPLETED, FAILED or CANCELLED are left unchanged; only a
        warning is logged.

        NOTE(review): only the database row is updated; a message already in the
        agent's Redis queue is not removed, so the agent may still receive it.

        Raises:
            TaskNotFoundError: No task with ``task_id`` exists.
            ValueError: ``task_id`` is not a valid UUID string.
            Exception: Any database error, re-raised after rollback.
        """
        async with AsyncSessionLocal() as db:
            try:
                task = await self._get_task(db, uuid.UUID(task_id))
                if not task:
                    raise TaskNotFoundError(task_id)

                if task.status in _TERMINAL_STATUSES:
                    logger.warning(
                        f"Cannot cancel task {task_id} with status '{task.status}'"
                    )
                    return

                task.status = TaskStatus.CANCELLED
                task.completed_at = datetime.now(timezone.utc)
                # Status change and its log row are committed together, and
                # task.agent_id is read before the commit can expire the instance.
                await self._write_log(
                    db,
                    task_id=task_id,
                    agent_id=str(task.agent_id),
                    log_level="info",
                    message="Task cancelled by user",
                )
                await db.commit()
                logger.info(f"Task {task_id} cancelled")
            except TaskNotFoundError:
                raise
            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to cancel task {task_id}: {e}")
                raise

    async def get_task_status(self, task_id: str) -> dict:
        """Return the task row for ``task_id`` as a dict (see ``_task_to_dict``).

        Raises:
            TaskNotFoundError: No task with ``task_id`` exists.
            ValueError: ``task_id`` is not a valid UUID string.
        """
        async with AsyncSessionLocal() as db:
            task = await self._get_task(db, uuid.UUID(task_id))
            if not task:
                raise TaskNotFoundError(task_id)
            return self._task_to_dict(task)

    async def handle_result(self, result: TaskResult):
        """Persist a result reported by an agent.

        Updates the task row (status, output, error, timestamps) and writes a task
        log entry in one commit. It then POSTs the result to
        ``result.output_data["callback_url"]`` when present. Errors are logged and
        swallowed (best effort); nothing is raised to the caller.
        """
        async with AsyncSessionLocal() as db:
            try:
                task = await self._get_task(db, uuid.UUID(result.task_id))
                if not task:
                    logger.error(f"Task {result.task_id} not found when handling result")
                    return

                # NOTE(review): a task cancelled via cancel_task is overwritten
                # here by a late agent result — confirm this is intended.
                task.status = result.status
                task.output_data = result.output_data
                task.error_message = result.error_message
                task.started_at = result.started_at
                task.completed_at = result.completed_at

                # Any non-COMPLETED status is logged at "error" level as a failure.
                completed = result.status == TaskStatus.COMPLETED
                await self._write_log(
                    db,
                    task_id=result.task_id,
                    agent_id=str(task.agent_id),
                    log_level="info" if completed else "error",
                    message=(
                        f"Task {result.status}"
                        if completed
                        else f"Task failed: {result.error_message}"
                    ),
                    extra_metadata=result.metrics,
                )
                await db.commit()

                # _trigger_callback handles its own errors.
                if result.output_data and result.output_data.get("callback_url"):
                    await self._trigger_callback(result.output_data["callback_url"], result)

                logger.info(
                    f"Task {result.task_id} result handled (status={result.status})"
                )
            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to handle result for task {result.task_id}: {e}")

    async def handle_progress(self, progress: TaskProgress):
        """Record a progress report as a task log entry.

        The reporting agent is looked up by ``progress.agent_name``; unknown agents
        are ignored with a warning. Progress is only logged, not stored on the task
        row. Errors are logged and swallowed.
        """
        async with AsyncSessionLocal() as db:
            try:
                stmt = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == progress.agent_name
                )
                agent = (await db.execute(stmt)).scalar_one_or_none()
                if not agent:
                    logger.warning(f"Agent '{progress.agent_name}' not found for progress report")
                    return

                await self._write_log(
                    db,
                    task_id=progress.task_id,
                    agent_id=str(agent.id),
                    log_level="info",
                    message=f"Progress: {progress.progress:.0%} - {progress.message}",
                    extra_metadata={
                        "progress": progress.progress,
                        "updated_at": progress.updated_at.isoformat(),
                    },
                )
                await db.commit()
            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to handle progress for task {progress.task_id}: {e}")

    async def retry_failed_tasks(self, max_retries: int = 3):
        """Re-enqueue FAILED tasks that have been retried fewer than ``max_retries`` times.

        A task's retry count is the number of its log rows written by this method
        ("Task retry attempt N/M"). A task is reset to PENDING only when its agent
        exists and is ONLINE and the message was pushed to the agent's queue.
        Otherwise it stays FAILED and is reconsidered on the next call. All row
        changes are committed once at the end. Errors are logged and swallowed.
        """
        async with AsyncSessionLocal() as db:
            try:
                stmt = select(AgentTaskModel).where(
                    AgentTaskModel.status == TaskStatus.FAILED,
                )
                failed_tasks = (await db.execute(stmt)).scalars().all()

                retried = 0
                for task in failed_tasks:
                    count_stmt = (
                        select(func.count())
                        .select_from(AgentTaskLogModel)
                        .where(
                            AgentTaskLogModel.task_id == task.id,
                            AgentTaskLogModel.message.like(_RETRY_LOG_PATTERN),
                        )
                    )
                    retry_count = (await db.execute(count_stmt)).scalar_one()
                    if retry_count >= max_retries:
                        continue

                    agent_stmt = select(AgentRegistryModel).where(
                        AgentRegistryModel.id == task.agent_id
                    )
                    agent = (await db.execute(agent_stmt)).scalar_one_or_none()
                    if not agent or agent.status != AgentStatus.ONLINE:
                        # Leave the task FAILED: resetting it to PENDING without
                        # enqueueing would strand it (only FAILED tasks are retried).
                        continue

                    task_msg = TaskMessage(
                        task_id=str(task.id),
                        agent_name=agent.name,
                        task_type=task.task_type,
                        priority=task.priority,
                        input_data=task.input_data or {},
                        callback_url=None,
                        created_at=datetime.now(timezone.utc),
                    )
                    await self._enqueue(task_msg)

                    task.status = TaskStatus.PENDING
                    task.error_message = None
                    task.started_at = None
                    task.completed_at = None

                    await self._write_log(
                        db,
                        task_id=str(task.id),
                        agent_id=str(agent.id),
                        log_level="info",
                        message=f"Task retry attempt {retry_count + 1}/{max_retries}",
                    )
                    retried += 1

                await db.commit()
                if retried > 0:
                    logger.info(f"Retried {retried} failed tasks")
            except Exception as e:
                await db.rollback()
                logger.exception(f"Failed to retry failed tasks: {e}")

    async def _write_log(
        self,
        db: AsyncSession,
        task_id: str,
        agent_id: str,
        log_level: str,
        message: str,
        extra_metadata: dict | None = None,
    ):
        """Add (but do not commit) an ``agent_task_logs`` row to ``db``.

        ``task_id`` and ``agent_id`` must be UUID strings; ``uuid.UUID`` raises
        ``ValueError`` otherwise.
        """
        log_entry = AgentTaskLogModel(
            task_id=uuid.UUID(task_id),
            agent_id=uuid.UUID(agent_id),
            log_level=log_level,
            message=message,
            extra_metadata=extra_metadata,
        )
        db.add(log_entry)

    async def _trigger_callback(self, callback_url: str, result: TaskResult):
        """POST ``result.to_dict()`` as JSON to ``callback_url`` (10 s timeout).

        Best effort: any failure is logged as a warning and not raised.
        """
        try:
            import httpx

            async with httpx.AsyncClient(timeout=10) as client:
                await client.post(callback_url, json=result.to_dict())
            logger.info(f"Callback triggered for task {result.task_id}")
        except Exception as e:
            logger.warning(f"Callback failed for task {result.task_id}: {e}")

    def _task_to_dict(self, task: AgentTaskModel) -> dict:
        """Convert a task ORM object to a plain dict with string ids and ISO dates."""
        return {
            "id": str(task.id),
            "agent_id": str(task.agent_id),
            "task_type": task.task_type,
            "status": task.status,
            "priority": task.priority,
            "input_data": task.input_data,
            "output_data": task.output_data,
            "error_message": task.error_message,
            "created_by": str(task.created_by) if task.created_by else None,
            "organization_id": str(task.organization_id),
            "project_id": str(task.project_id) if task.project_id else None,
            "scheduled_at": task.scheduled_at.isoformat() if task.scheduled_at else None,
            "started_at": task.started_at.isoformat() if task.started_at else None,
            "completed_at": task.completed_at.isoformat() if task.completed_at else None,
            "created_at": task.created_at.isoformat() if task.created_at else None,
        }