"""Agent 管理 API 路由""" import uuid from datetime import datetime, timezone from typing import Any from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.agent_framework.config_manager import AgentConfigManager from app.agent_framework.dispatcher import TaskDispatcher from app.agent_framework.exceptions import ( AgentNotFoundError, NoAvailableAgentError, TaskDispatchError, TaskNotFoundError, ) from app.agent_framework.protocol import TaskMessage, TaskStatus from app.agent_framework.registry import AgentRegistry from app.api.deps import get_current_user from app.config import settings from app.database import get_db from app.models.agent import AgentTask as AgentTaskModel from app.models.agent import AgentTaskLog as AgentTaskLogModel from app.models.user import User from app.schemas.common import ErrorCode, ErrorResponse router = APIRouter() # --------------------------------------------------------------------------- # Pydantic Schemas # --------------------------------------------------------------------------- class TaskCreateRequest(BaseModel): agent_name: str task_type: str priority: int = 0 input_data: dict = {} callback_url: str | None = None timeout_seconds: int = 300 class TaskCreateResponse(BaseModel): task_id: str status: str message: str class ConfigUpdateRequest(BaseModel): configs: dict[str, Any] class ConfigUpdateResponse(BaseModel): updated_keys: list[str] message: str # --------------------------------------------------------------------------- # Agent 端点 # --------------------------------------------------------------------------- @router.get("/", summary="列出所有 Agent") async def list_agents( agent_type: str | None = Query(None, description="按 Agent 类型筛选"), agent_status: str | None = Query(None, alias="status", description="按状态筛选"), current_user: User = Depends(get_current_user), ): registry = AgentRegistry() agents = await registry.list_agents(agent_type=agent_type, status=agent_status) return {"items": agents, "total": len(agents)} @router.get("/{agent_name}", summary="获取 Agent 详情") async def get_agent_detail( agent_name: str, current_user: User = Depends(get_current_user), ): registry = AgentRegistry() agent = await registry.get_agent(agent_name) if not agent: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Agent '{agent_name}' not found", ) return agent @router.get("/{agent_name}/config", summary="获取 Agent 配置") async def get_agent_config( agent_name: str, current_user: User = Depends(get_current_user), ): config_mgr = AgentConfigManager() try: config = await config_mgr.get_config(agent_name) return config except AgentNotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Agent '{agent_name}' not found", ) @router.put("/{agent_name}/config", summary="更新 Agent 配置", response_model=ConfigUpdateResponse) async def update_agent_config( agent_name: str, body: ConfigUpdateRequest, current_user: User = Depends(get_current_user), ): config_mgr = AgentConfigManager() try: await config_mgr.bulk_update_config( agent_name=agent_name, configs=body.configs, updated_by=str(current_user.id), ) return ConfigUpdateResponse( updated_keys=list(body.configs.keys()), message=f"Updated {len(body.configs)} config(s) for agent '{agent_name}'", ) except AgentNotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Agent '{agent_name}' not found", ) # --------------------------------------------------------------------------- # Task 端点 # --------------------------------------------------------------------------- @router.get("/tasks/", summary="列出任务") async def list_tasks( agent_name: str | None = Query(None, description="按 Agent 名称筛选"), task_status: str | None = Query(None, alias="status", description="按状态筛选"), task_type: str | None = Query(None, description="按任务类型筛选"), skip: int = Query(0, ge=0), limit: int = Query(20, ge=1, le=100), db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): stmt = select(AgentTaskModel).where(AgentTaskModel.created_by == current_user.id) if task_status: stmt = stmt.where(AgentTaskModel.status == task_status) if task_type: stmt = stmt.where(AgentTaskModel.task_type == task_type) if agent_name: # 需要关联查询 agent_registry 表 from app.models.agent import AgentRegistry as AgentRegistryModel sub_stmt = select(AgentRegistryModel.id).where( AgentRegistryModel.name == agent_name ) stmt = stmt.where(AgentTaskModel.agent_id.in_(sub_stmt)) stmt = stmt.order_by(AgentTaskModel.created_at.desc()).offset(skip).limit(limit) result = await db.execute(stmt) tasks = result.scalars().all() items = [] for t in tasks: items.append({ "id": str(t.id), "agent_id": str(t.agent_id), "task_type": t.task_type, "status": t.status, "priority": t.priority, "input_data": t.input_data, "output_data": t.output_data, "error_message": t.error_message, "started_at": t.started_at.isoformat() if t.started_at else None, "completed_at": t.completed_at.isoformat() if t.completed_at else None, "created_at": t.created_at.isoformat() if t.created_at else None, }) return {"items": items, "total": len(items)} @router.post("/tasks/", summary="创建任务(分发给 Agent)", response_model=TaskCreateResponse) async def create_task( body: TaskCreateRequest, current_user: User = Depends(get_current_user), ): task_id = str(uuid.uuid4()) task = TaskMessage( task_id=task_id, agent_name=body.agent_name, task_type=body.task_type, priority=body.priority, input_data=body.input_data, callback_url=body.callback_url, created_at=datetime.now(timezone.utc), timeout_seconds=body.timeout_seconds, ) dispatcher = TaskDispatcher(settings.REDIS_URL) try: # 从 current_user 获取 organization_id,优先使用用户的组织ID org_id = str(current_user.organization_id) if current_user.organization_id else str(current_user.id) await dispatcher.dispatch( task=task, organization_id=org_id, created_by=str(current_user.id), ) return TaskCreateResponse( task_id=task_id, status="pending", message=f"Task dispatched to agent '{body.agent_name}'", ) except (TaskDispatchError, NoAvailableAgentError) as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=str(e), ) finally: await dispatcher.close() @router.get("/tasks/{task_id}", summary="获取任务状态") async def get_task_status( task_id: str, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): # 权限校验:验证任务归属当前用户 try: task_uuid = uuid.UUID(task_id) except ValueError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid task_id format", ) stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid) result = await db.execute(stmt) task = result.scalar_one_or_none() if task is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Task '{task_id}' not found", ) if task.created_by != current_user.id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ErrorResponse( detail="无权访问此任务", code=ErrorCode.FORBIDDEN, ).dict(), ) dispatcher = TaskDispatcher(settings.REDIS_URL) try: task_status_data = await dispatcher.get_task_status(task_id) return task_status_data except TaskNotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Task '{task_id}' not found", ) finally: await dispatcher.close() @router.post("/tasks/{task_id}/cancel", summary="取消任务") async def cancel_task( task_id: str, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): # 权限校验:验证任务归属当前用户 try: task_uuid = uuid.UUID(task_id) except ValueError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid task_id format", ) stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid) result = await db.execute(stmt) task = result.scalar_one_or_none() if task is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Task '{task_id}' not found", ) if task.created_by != current_user.id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ErrorResponse( detail="无权取消此任务", code=ErrorCode.FORBIDDEN, ).dict(), ) dispatcher = TaskDispatcher(settings.REDIS_URL) try: await dispatcher.cancel_task(task_id) return {"task_id": task_id, "status": "cancelled", "message": "Task cancelled"} except TaskNotFoundError: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Task '{task_id}' not found", ) finally: await dispatcher.close() @router.get("/tasks/{task_id}/logs", summary="获取任务日志") async def get_task_logs( task_id: str, skip: int = Query(0, ge=0), limit: int = Query(50, ge=1, le=200), db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): try: task_uuid = uuid.UUID(task_id) except ValueError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid task_id format", ) # 权限校验:验证任务归属当前用户 task_stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid) task_result = await db.execute(task_stmt) task = task_result.scalar_one_or_none() if task is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Task '{task_id}' not found", ) if task.created_by != current_user.id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ErrorResponse( detail="无权访问此任务日志", code=ErrorCode.FORBIDDEN, ).dict(), ) stmt = ( select(AgentTaskLogModel) .where(AgentTaskLogModel.task_id == task_uuid) .order_by(AgentTaskLogModel.created_at.desc()) .offset(skip) .limit(limit) ) result = await db.execute(stmt) logs = result.scalars().all() items = [] for log in logs: items.append({ "id": str(log.id), "task_id": str(log.task_id), "agent_id": str(log.agent_id), "log_level": log.log_level, "message": log.message, "metadata": log.extra_metadata, "created_at": log.created_at.isoformat() if log.created_at else None, }) return {"items": items, "total": len(items)}