geo/backend/app/api/agents.py

374 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Agent 管理 API 路由"""
import uuid
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agent_framework.config_manager import AgentConfigManager
from app.agent_framework.dispatcher import TaskDispatcher
from app.agent_framework.exceptions import (
AgentNotFoundError,
NoAvailableAgentError,
TaskDispatchError,
TaskNotFoundError,
)
from app.agent_framework.protocol import TaskMessage, TaskStatus
from app.agent_framework.registry import AgentRegistry
from app.api.deps import get_current_user
from app.config import settings
from app.database import get_db
from app.models.agent import AgentTask as AgentTaskModel
from app.models.agent import AgentTaskLog as AgentTaskLogModel
from app.models.user import User
from app.schemas.common import ErrorCode, ErrorResponse
# Router that all agent-management and agent-task endpoints in this module register on.
# NOTE(review): presumably included by the application under a prefix such as "/agents" — confirm.
router = APIRouter()
# ---------------------------------------------------------------------------
# Pydantic Schemas
# ---------------------------------------------------------------------------
class TaskCreateRequest(BaseModel):
    """Request body for creating a task and dispatching it to an agent."""

    # Name of the target agent; copied into TaskMessage.agent_name.
    agent_name: str
    # Task type string, forwarded unchanged to the dispatcher.
    task_type: str
    # Dispatch priority; no range is enforced here.
    priority: int = 0
    # Free-form task payload. pydantic copies mutable defaults per instance,
    # so the literal `{}` is not shared between requests.
    input_data: dict = {}
    # Optional callback URL, forwarded as TaskMessage.callback_url.
    callback_url: str | None = None
    # Timeout in seconds, forwarded as TaskMessage.timeout_seconds (not validated as positive here).
    timeout_seconds: int = 300
class TaskCreateResponse(BaseModel):
    """Response body returned after a task has been dispatched."""

    # Generated task id (string form of a UUID).
    task_id: str
    # Initial task status string.
    status: str
    # Human-readable confirmation message.
    message: str
class ConfigUpdateRequest(BaseModel):
    """Request body for updating one or more configuration keys of an agent."""

    # Mapping of config key -> new value, passed to AgentConfigManager.bulk_update_config.
    configs: dict[str, Any]
class ConfigUpdateResponse(BaseModel):
    """Response body summarizing an agent configuration update."""

    # Keys taken from the request's `configs` mapping, in request order.
    updated_keys: list[str]
    # Human-readable summary of the update.
    message: str
# ---------------------------------------------------------------------------
# Agent 端点
# ---------------------------------------------------------------------------
@router.get("/", summary="列出所有 Agent")
async def list_agents(
    agent_type: str | None = Query(None, description="按 Agent 类型筛选"),
    agent_status: str | None = Query(None, alias="status", description="按状态筛选"),
    current_user: User = Depends(get_current_user),
):
    """List registered agents, optionally filtered by type and status.

    Returns:
        ``{"items": [...], "total": N}`` where ``total`` is the number of agents returned.
    """
    # The `status` query parameter is bound to `agent_status`, presumably so it does
    # not shadow the imported `fastapi.status` module. `current_user` only enforces
    # authentication and is not otherwise used.
    matching = await AgentRegistry().list_agents(
        agent_type=agent_type,
        status=agent_status,
    )
    return {"items": matching, "total": len(matching)}
@router.get("/{agent_name}", summary="获取 Agent 详情")
async def get_agent_detail(
    agent_name: str,
    current_user: User = Depends(get_current_user),
):
    """Return the registry entry for a single agent.

    Raises:
        HTTPException: 404 when the registry has no entry for ``agent_name``.
    """
    # NOTE(review): this "/{agent_name}" route is registered before the "/tasks/"
    # routes, so a request to "/tasks" (no trailing slash) is likely captured here
    # with agent_name="tasks" and answered with 404 — confirm and consider
    # registering the task routes first.
    found = await AgentRegistry().get_agent(agent_name)
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Agent '{agent_name}' not found",
    )
@router.get("/{agent_name}/config", summary="获取 Agent 配置")
async def get_agent_config(
    agent_name: str,
    current_user: User = Depends(get_current_user),
):
    """Return the stored configuration of ``agent_name``.

    Raises:
        HTTPException: 404 when the config manager raises ``AgentNotFoundError``.
    """
    manager = AgentConfigManager()
    try:
        agent_config = await manager.get_config(agent_name)
    except AgentNotFoundError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Agent '{agent_name}' not found",
        )
    return agent_config
@router.put("/{agent_name}/config", summary="更新 Agent 配置", response_model=ConfigUpdateResponse)
async def update_agent_config(
    agent_name: str,
    body: ConfigUpdateRequest,
    current_user: User = Depends(get_current_user),
):
    """Apply all key/value pairs in ``body.configs`` to ``agent_name``'s configuration.

    The update is recorded as made by the current user's id.

    Raises:
        HTTPException: 404 when the config manager raises ``AgentNotFoundError``.
    """
    manager = AgentConfigManager()
    try:
        await manager.bulk_update_config(
            agent_name=agent_name,
            configs=body.configs,
            updated_by=str(current_user.id),
        )
    except AgentNotFoundError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Agent '{agent_name}' not found",
        )
    changed = body.configs
    return ConfigUpdateResponse(
        updated_keys=list(changed),
        message=f"Updated {len(changed)} config(s) for agent '{agent_name}'",
    )
# ---------------------------------------------------------------------------
# Task 端点
# ---------------------------------------------------------------------------
@router.get("/tasks/", summary="列出任务")
async def list_tasks(
    agent_name: str | None = Query(None, description="按 Agent 名称筛选"),
    task_status: str | None = Query(None, alias="status", description="按状态筛选"),
    task_type: str | None = Query(None, description="按任务类型筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the current user's agent tasks, newest first.

    Optional filters: agent name, task status and task type. Results are
    paginated with ``skip``/``limit``.

    Returns:
        ``{"items": [...], "total": N}`` where ``total`` is the number of tasks
        matching the filters, independent of ``skip``/``limit``.
    """
    # Local imports: `func` is only needed here, and the ORM model is aliased
    # because `AgentRegistry` already names the agent-framework registry class.
    from sqlalchemy import func

    from app.models.agent import AgentRegistry as AgentRegistryModel

    # Only tasks created by the caller are visible; extra filters are ANDed on.
    conditions = [AgentTaskModel.created_by == current_user.id]
    if task_status:
        conditions.append(AgentTaskModel.status == task_status)
    if task_type:
        conditions.append(AgentTaskModel.task_type == task_type)
    if agent_name:
        # Tasks reference agents by id, so resolve the name via the agent_registry table.
        agent_ids = select(AgentRegistryModel.id).where(AgentRegistryModel.name == agent_name)
        conditions.append(AgentTaskModel.agent_id.in_(agent_ids))

    # Count every matching row so `total` describes the whole result set, not
    # just the current page (len(items) can never exceed `limit`).
    count_stmt = select(func.count()).select_from(AgentTaskModel).where(*conditions)
    total = (await db.execute(count_stmt)).scalar_one()

    page_stmt = (
        select(AgentTaskModel)
        .where(*conditions)
        .order_by(AgentTaskModel.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    tasks = (await db.execute(page_stmt)).scalars().all()

    items = [
        {
            "id": str(t.id),
            "agent_id": str(t.agent_id),
            "task_type": t.task_type,
            "status": t.status,
            "priority": t.priority,
            "input_data": t.input_data,
            "output_data": t.output_data,
            "error_message": t.error_message,
            "started_at": t.started_at.isoformat() if t.started_at else None,
            "completed_at": t.completed_at.isoformat() if t.completed_at else None,
            "created_at": t.created_at.isoformat() if t.created_at else None,
        }
        for t in tasks
    ]
    return {"items": items, "total": total}
@router.post("/tasks/", summary="创建任务(分发给 Agent)", response_model=TaskCreateResponse)
async def create_task(
    body: TaskCreateRequest,
    current_user: User = Depends(get_current_user),
):
    """Build a task message from ``body`` and dispatch it to the named agent.

    A fresh UUID4 is generated as the task id. The task is attributed to the
    user's organization when one is set, otherwise to the user's own id.

    Returns:
        ``TaskCreateResponse`` with the new task id and status ``"pending"``.

    Raises:
        HTTPException: 400 when the dispatcher raises ``TaskDispatchError`` or
            ``NoAvailableAgentError``.
    """
    task_id = str(uuid.uuid4())
    task = TaskMessage(
        task_id=task_id,
        agent_name=body.agent_name,
        task_type=body.task_type,
        priority=body.priority,
        input_data=body.input_data,
        callback_url=body.callback_url,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=body.timeout_seconds,
    )
    # Prefer the user's organization id; fall back to the user id when the user has none.
    org_id = (
        str(current_user.organization_id)
        if current_user.organization_id
        else str(current_user.id)
    )

    dispatcher = TaskDispatcher(settings.REDIS_URL)
    try:
        await dispatcher.dispatch(
            task=task,
            organization_id=org_id,
            created_by=str(current_user.id),
        )
    except (TaskDispatchError, NoAvailableAgentError) as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(exc),
        ) from exc
    finally:
        # Close the dispatcher whether or not the dispatch succeeded.
        await dispatcher.close()

    return TaskCreateResponse(
        task_id=task_id,
        status="pending",
        message=f"Task dispatched to agent '{body.agent_name}'",
    )
@router.get("/tasks/{task_id}", summary="获取任务状态")
async def get_task_status(
    task_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the current status of a task owned by the calling user.

    Ownership is checked against the task row in the database before the
    dispatcher is queried.

    Raises:
        HTTPException: 400 if ``task_id`` is not a valid UUID; 404 if the task is
            not in the database or the dispatcher raises ``TaskNotFoundError``;
            403 if the task was created by another user.
    """
    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid task_id format",
        ) from exc
    # uuid.UUID also accepts upper-case, "{...}", "urn:uuid:..." and hyphen-less
    # spellings, all of which match the DB row. Task ids are created with
    # str(uuid.uuid4()), so query the dispatcher with that canonical form rather
    # than the raw path value to avoid a spurious "not found".
    canonical_id = str(task_uuid)

    # Permission check: the task must belong to the current user.
    stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
    result = await db.execute(stmt)
    task = result.scalar_one_or_none()
    if task is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        )
    if task.created_by != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ErrorResponse(
                detail="无权访问此任务",
                code=ErrorCode.FORBIDDEN,
            ).dict(),
        )

    dispatcher = TaskDispatcher(settings.REDIS_URL)
    try:
        return await dispatcher.get_task_status(canonical_id)
    except TaskNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        ) from exc
    finally:
        await dispatcher.close()
@router.post("/tasks/{task_id}/cancel", summary="取消任务")
async def cancel_task(
    task_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Cancel a task owned by the calling user.

    Ownership is checked against the task row in the database before the
    dispatcher is asked to cancel.

    Returns:
        ``{"task_id": ..., "status": "cancelled", "message": "Task cancelled"}``.

    Raises:
        HTTPException: 400 if ``task_id`` is not a valid UUID; 404 if the task is
            not in the database or the dispatcher raises ``TaskNotFoundError``;
            403 if the task was created by another user.
    """
    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid task_id format",
        ) from exc
    # uuid.UUID also accepts non-canonical spellings; task ids are created with
    # str(uuid.uuid4()), so talk to the dispatcher using that canonical form.
    canonical_id = str(task_uuid)

    # Permission check: the task must belong to the current user.
    stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
    result = await db.execute(stmt)
    task = result.scalar_one_or_none()
    if task is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        )
    if task.created_by != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ErrorResponse(
                detail="无权取消此任务",
                code=ErrorCode.FORBIDDEN,
            ).dict(),
        )

    dispatcher = TaskDispatcher(settings.REDIS_URL)
    try:
        await dispatcher.cancel_task(canonical_id)
    except TaskNotFoundError as exc:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        ) from exc
    finally:
        await dispatcher.close()
    return {"task_id": canonical_id, "status": "cancelled", "message": "Task cancelled"}
@router.get("/tasks/{task_id}/logs", summary="获取任务日志")
async def get_task_logs(
    task_id: str,
    skip: int = Query(0, ge=0),
    limit: int = Query(50, ge=1, le=200),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return a page of log entries for a task owned by the calling user, newest first.

    Returns:
        ``{"items": [...], "total": N}`` where ``total`` is the number of log
        entries for the task, independent of ``skip``/``limit``.

    Raises:
        HTTPException: 400 if ``task_id`` is not a valid UUID; 404 if the task is
            not in the database; 403 if the task was created by another user.
    """
    from sqlalchemy import func

    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid task_id format",
        ) from exc

    # Permission check: the task must belong to the current user.
    task_stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
    task_result = await db.execute(task_stmt)
    task = task_result.scalar_one_or_none()
    if task is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Task '{task_id}' not found",
        )
    if task.created_by != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=ErrorResponse(
                detail="无权访问此任务日志",
                code=ErrorCode.FORBIDDEN,
            ).dict(),
        )

    belongs_to_task = AgentTaskLogModel.task_id == task_uuid
    # Count all log rows for the task so `total` is meaningful for pagination
    # (len(items) can never exceed `limit`).
    count_stmt = select(func.count()).select_from(AgentTaskLogModel).where(belongs_to_task)
    total = (await db.execute(count_stmt)).scalar_one()

    stmt = (
        select(AgentTaskLogModel)
        .where(belongs_to_task)
        .order_by(AgentTaskLogModel.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    logs = (await db.execute(stmt)).scalars().all()

    items = [
        {
            "id": str(log.id),
            "task_id": str(log.task_id),
            "agent_id": str(log.agent_id),
            "log_level": log.log_level,
            "message": log.message,
            "metadata": log.extra_metadata,
            "created_at": log.created_at.isoformat() if log.created_at else None,
        }
        for log in logs
    ]
    return {"items": items, "total": total}