374 lines
12 KiB
Python
374 lines
12 KiB
Python
"""Agent 管理 API 路由"""
|
||
|
||
import uuid
from datetime import datetime, timezone
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.agent_framework.config_manager import AgentConfigManager
from app.agent_framework.dispatcher import TaskDispatcher
from app.agent_framework.exceptions import (
    AgentNotFoundError,
    NoAvailableAgentError,
    TaskDispatchError,
    TaskNotFoundError,
)
from app.agent_framework.protocol import TaskMessage, TaskStatus
from app.agent_framework.registry import AgentRegistry
from app.api.deps import get_current_user
from app.config import settings
from app.database import get_db
from app.models.agent import AgentTask as AgentTaskModel
from app.models.agent import AgentTaskLog as AgentTaskLogModel
from app.models.user import User
from app.schemas.common import ErrorCode, ErrorResponse
|
||
|
||
# Router collecting the agent and task endpoints declared in this module.
# NOTE(review): the URL prefix it is mounted under is set where the router is
# included, not here — confirm there when documenting full paths.
router = APIRouter()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Pydantic Schemas
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TaskCreateRequest(BaseModel):
    """Request body for creating a task and dispatching it to an agent."""

    # Name of the agent the task is addressed to (copied into TaskMessage.agent_name).
    agent_name: str
    # Task type identifier, copied into TaskMessage.task_type.
    task_type: str
    # Copied into TaskMessage.priority.
    priority: int = 0
    # Arbitrary task payload. default_factory builds a fresh dict per request
    # instead of declaring one shared mutable literal as the default.
    input_data: dict = Field(default_factory=dict)
    # Optional URL copied into TaskMessage.callback_url.
    callback_url: str | None = None
    # Copied into TaskMessage.timeout_seconds.
    timeout_seconds: int = 300
|
||
|
||
|
||
class TaskCreateResponse(BaseModel):
    """Response body returned after a task has been dispatched."""

    # Identifier generated for the new task (create_task uses a UUID4 string).
    task_id: str
    # Status reported at creation time; create_task always sends "pending".
    status: str
    # Human-readable summary of the dispatch.
    message: str
|
||
|
||
|
||
class ConfigUpdateRequest(BaseModel):
    """Request body for a bulk agent-config update."""

    # Config keys mapped to their new values; passed whole to bulk_update_config.
    configs: dict[str, Any]
|
||
|
||
|
||
class ConfigUpdateResponse(BaseModel):
    """Response body for a bulk agent-config update."""

    # The keys that were submitted in the request's `configs` mapping.
    updated_keys: list[str]
    # Human-readable summary of the update.
    message: str
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Agent 端点
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@router.get("/", summary="列出所有 Agent")
async def list_agents(
    agent_type: str | None = Query(None, description="按 Agent 类型筛选"),
    agent_status: str | None = Query(None, alias="status", description="按状态筛选"),
    current_user: User = Depends(get_current_user),
):
    """List registered agents, optionally filtered by type and status.

    ``current_user`` is not read here; the dependency is kept so the route
    requires a resolved user.

    Returns:
        ``{"items": [...], "total": n}`` where ``n`` is the number of agents
        returned by the registry.
    """
    found = await AgentRegistry().list_agents(
        agent_type=agent_type,
        status=agent_status,
    )
    return dict(items=found, total=len(found))
|
||
|
||
|
||
@router.get("/{agent_name}", summary="获取 Agent 详情")
async def get_agent_detail(
    agent_name: str,
    current_user: User = Depends(get_current_user),
):
    """Return the registry entry for ``agent_name``.

    Raises:
        HTTPException: 404 when the registry returns a falsy value for the name.
    """
    found = await AgentRegistry().get_agent(agent_name)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Agent '{agent_name}' not found",
    )
|
||
|
||
|
||
@router.get("/{agent_name}/config", summary="获取 Agent 配置")
async def get_agent_config(
    agent_name: str,
    current_user: User = Depends(get_current_user),
):
    """Return the configuration of ``agent_name`` from AgentConfigManager.

    Raises:
        HTTPException: 404 when the config manager raises AgentNotFoundError.
    """
    config_mgr = AgentConfigManager()
    try:
        config = await config_mgr.get_config(agent_name)
    except AgentNotFoundError as exc:
        # Chain the domain error so server-side tracebacks keep the root cause.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Agent '{agent_name}' not found",
        ) from exc
    return config
|
||
|
||
|
||
@router.put("/{agent_name}/config", summary="更新 Agent 配置", response_model=ConfigUpdateResponse)
async def update_agent_config(
    agent_name: str,
    body: ConfigUpdateRequest,
    current_user: User = Depends(get_current_user),
):
    """Apply ``body.configs`` to ``agent_name`` in a single bulk update.

    The current user's id is recorded as ``updated_by``.

    Returns:
        ConfigUpdateResponse listing the submitted keys.

    Raises:
        HTTPException: 404 when the config manager raises AgentNotFoundError.
    """
    config_mgr = AgentConfigManager()
    try:
        await config_mgr.bulk_update_config(
            agent_name=agent_name,
            configs=body.configs,
            updated_by=str(current_user.id),
        )
    except AgentNotFoundError as exc:
        # Chain the domain error so server-side tracebacks keep the root cause.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Agent '{agent_name}' not found",
        ) from exc
    return ConfigUpdateResponse(
        updated_keys=list(body.configs.keys()),
        message=f"Updated {len(body.configs)} config(s) for agent '{agent_name}'",
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Task 端点
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@router.get("/tasks/", summary="列出任务")
async def list_tasks(
    agent_name: str | None = Query(None, description="按 Agent 名称筛选"),
    task_status: str | None = Query(None, alias="status", description="按状态筛选"),
    task_type: str | None = Query(None, description="按任务类型筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List tasks created by the current user, newest first.

    Optional filters are ``agent_name`` (resolved through the agent registry
    table), ``status`` and ``task_type``. ``skip``/``limit`` select the page.

    Returns:
        ``{"items": [...], "total": n}``. ``items`` is the requested page.
        ``n`` is the number of tasks matching the filters across all pages,
        not the length of the page.
    """
    stmt = select(AgentTaskModel).where(AgentTaskModel.created_by == current_user.id)

    if task_status:
        stmt = stmt.where(AgentTaskModel.status == task_status)
    if task_type:
        stmt = stmt.where(AgentTaskModel.task_type == task_type)
    if agent_name:
        # Aliased so it does not clash with the agent_framework AgentRegistry
        # imported at module level.
        from app.models.agent import AgentRegistry as AgentRegistryModel

        agent_ids = select(AgentRegistryModel.id).where(
            AgentRegistryModel.name == agent_name
        )
        stmt = stmt.where(AgentTaskModel.agent_id.in_(agent_ids))

    # Count the filtered set before pagination. Previously "total" was
    # len(items), which can never exceed `limit`.
    count_stmt = select(func.count()).select_from(stmt.subquery())
    total = (await db.execute(count_stmt)).scalar_one()

    page_stmt = stmt.order_by(AgentTaskModel.created_at.desc()).offset(skip).limit(limit)
    result = await db.execute(page_stmt)
    tasks = result.scalars().all()

    items = [
        {
            "id": str(t.id),
            "agent_id": str(t.agent_id),
            "task_type": t.task_type,
            "status": t.status,
            "priority": t.priority,
            "input_data": t.input_data,
            "output_data": t.output_data,
            "error_message": t.error_message,
            "started_at": t.started_at.isoformat() if t.started_at else None,
            "completed_at": t.completed_at.isoformat() if t.completed_at else None,
            "created_at": t.created_at.isoformat() if t.created_at else None,
        }
        for t in tasks
    ]

    return {"items": items, "total": total}
|
||
|
||
|
||
@router.post("/tasks/", summary="创建任务(分发给 Agent)", response_model=TaskCreateResponse)
async def create_task(
    body: TaskCreateRequest,
    current_user: User = Depends(get_current_user),
):
    """Build a TaskMessage from ``body`` and dispatch it through TaskDispatcher.

    The organization id is the user's ``organization_id`` when it is truthy.
    Otherwise the user's own id is used.

    Returns:
        TaskCreateResponse with the generated task id and status "pending".

    Raises:
        HTTPException: 400 when dispatch raises TaskDispatchError or
            NoAvailableAgentError.
    """
    task_id = str(uuid.uuid4())
    task = TaskMessage(
        task_id=task_id,
        agent_name=body.agent_name,
        task_type=body.task_type,
        priority=body.priority,
        input_data=body.input_data,
        callback_url=body.callback_url,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=body.timeout_seconds,
    )

    # Prefer the user's organization; fall back to the user id when unset.
    org_id = str(current_user.organization_id) if current_user.organization_id else str(current_user.id)

    dispatcher = TaskDispatcher(settings.REDIS_URL)
    try:
        await dispatcher.dispatch(
            task=task,
            organization_id=org_id,
            created_by=str(current_user.id),
        )
    except (TaskDispatchError, NoAvailableAgentError) as exc:
        # Chain the dispatcher error so server-side tracebacks keep the root cause.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    finally:
        # Release the dispatcher's connection whether or not dispatch succeeded.
        await dispatcher.close()

    return TaskCreateResponse(
        task_id=task_id,
        status="pending",
        message=f"Task dispatched to agent '{body.agent_name}'",
    )
|
||
|
||
|
||
@router.get("/tasks/{task_id}", summary="获取任务状态")
async def get_task_status(
    task_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the dispatcher's status data for a task owned by the current user.

    Ownership is checked against the database row before the dispatcher is queried.

    Raises:
        HTTPException: 400 for a malformed UUID. 404 when the task is missing
            from the database or the dispatcher raises TaskNotFoundError. 403
            when the task was created by another user.
    """
    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid task_id format",
        ) from exc

    # Permission check: the task must belong to the current user.
    stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
    result = await db.execute(stmt)
    task = result.scalar_one_or_none()
    if task is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        )
    if task.created_by != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ErrorResponse(
                detail="无权访问此任务",
                code=ErrorCode.FORBIDDEN,
            ).dict(),
        )

    dispatcher = TaskDispatcher(settings.REDIS_URL)
    try:
        return await dispatcher.get_task_status(task_id)
    except TaskNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        ) from exc
    finally:
        await dispatcher.close()
|
||
|
||
|
||
@router.post("/tasks/{task_id}/cancel", summary="取消任务")
async def cancel_task(
    task_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Cancel a task owned by the current user via TaskDispatcher.

    Ownership is checked against the database row before the dispatcher is called.

    Returns:
        ``{"task_id": ..., "status": "cancelled", "message": "Task cancelled"}``
        once ``dispatcher.cancel_task`` returns without raising.

    Raises:
        HTTPException: 400 for a malformed UUID. 404 when the task is missing
            from the database or the dispatcher raises TaskNotFoundError. 403
            when the task was created by another user.
    """
    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid task_id format",
        ) from exc

    # Permission check: the task must belong to the current user.
    stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
    result = await db.execute(stmt)
    task = result.scalar_one_or_none()
    if task is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        )
    if task.created_by != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ErrorResponse(
                detail="无权取消此任务",
                code=ErrorCode.FORBIDDEN,
            ).dict(),
        )

    dispatcher = TaskDispatcher(settings.REDIS_URL)
    try:
        await dispatcher.cancel_task(task_id)
    except TaskNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        ) from exc
    finally:
        await dispatcher.close()

    return {"task_id": task_id, "status": "cancelled", "message": "Task cancelled"}
|
||
|
||
|
||
@router.get("/tasks/{task_id}/logs", summary="获取任务日志")
async def get_task_logs(
    task_id: str,
    skip: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return a page of log entries for a task owned by the current user, newest first.

    Returns:
        ``{"items": [...], "total": n}``. ``items`` is the requested page.
        ``n`` is the number of log rows stored for the task, not the length
        of the page.

    Raises:
        HTTPException: 400 for a malformed UUID. 404 when the task does not
            exist. 403 when it was created by another user.
    """
    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid task_id format",
        ) from exc

    # Permission check: the task must belong to the current user.
    task_stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
    task_result = await db.execute(task_stmt)
    task = task_result.scalar_one_or_none()
    if task is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        )
    if task.created_by != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ErrorResponse(
                detail="无权访问此任务日志",
                code=ErrorCode.FORBIDDEN,
            ).dict(),
        )

    # Count all log rows for the task. Previously "total" was len(items),
    # which can never exceed `limit`.
    count_stmt = (
        select(func.count())
        .select_from(AgentTaskLogModel)
        .where(AgentTaskLogModel.task_id == task_uuid)
    )
    total = (await db.execute(count_stmt)).scalar_one()

    stmt = (
        select(AgentTaskLogModel)
        .where(AgentTaskLogModel.task_id == task_uuid)
        .order_by(AgentTaskLogModel.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(stmt)
    logs = result.scalars().all()

    items = [
        {
            "id": str(log.id),
            "task_id": str(log.task_id),
            "agent_id": str(log.agent_id),
            "log_level": log.log_level,
            "message": log.message,
            "metadata": log.extra_metadata,
            "created_at": log.created_at.isoformat() if log.created_at else None,
        }
        for log in logs
    ]

    return {"items": items, "total": total}
|