# geo/backend/app/agent_framework/dispatcher.py
#
# Repository-viewer listing for this file: 360 lines, 14 KiB, Python.
#
# Viewer warning: this file contains ambiguous Unicode characters, i.e.
# characters that might be confused with other characters. If that is
# intentional the warning can be ignored. (Presumably it refers to the
# full-width punctuation used in the original comments - TODO confirm.)

"""任务分发器 - 通过 Redis Queue 将任务分发给 Agent"""
import json
import logging
import uuid
from datetime import datetime, timezone
import redis.asyncio as aioredis
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.agent_framework.exceptions import (
NoAvailableAgentError,
TaskDispatchError,
TaskNotFoundError,
)
from app.agent_framework.protocol import (
AgentStatus,
TaskMessage,
TaskProgress,
TaskResult,
TaskStatus,
)
from app.database import AsyncSessionLocal
from app.models.agent import AgentRegistry as AgentRegistryModel
from app.models.agent import AgentTask as AgentTaskModel
from app.models.agent import AgentTaskLog as AgentTaskLogModel
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class TaskDispatcher:
    """Task dispatcher that hands tasks to agents through Redis queues.

    Tasks are persisted in the ``agent_tasks`` table and pushed onto the
    Redis list ``agent:{agent_name}:tasks``. Results and progress reports
    coming back from agents are written to the task row and the task log.
    """

    def __init__(self, redis_url: str):
        """Remember the Redis URL.

        The URL is only stored; the connection itself is obtained through
        ``_get_redis``.
        """
        self._redis_url = redis_url

    async def _get_redis(self):
        """Return the shared Redis connection from ``app.core.redis``."""
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle - TODO confirm).
        from app.core.redis import get_redis

        return await get_redis()

    async def close(self):
        """Close the Redis connection.

        Intentionally a no-op: the connection is managed by the global
        connection pool, so nothing has to be closed here.
        """

    async def dispatch(
        self,
        task: TaskMessage,
        organization_id: str | None = None,
        created_by: str | None = None,
    ) -> str:
        """Dispatch a task to the queue of its target agent.

        Steps:
        1. Insert a row into ``agent_tasks`` with status ``PENDING``.
        2. Push the serialized task onto ``agent:{agent_name}:tasks``.

        If step 2 fails after step 1 has been committed, the row is marked
        ``FAILED`` so that it is not left pending forever.

        Args:
            task: The task message to dispatch; ``task.task_id`` must be a
                UUID string.
            organization_id: Optional organization UUID string.
            created_by: Optional creator UUID string.

        Returns:
            The task id (``task.task_id``).

        Raises:
            TaskDispatchError: If the agent is unknown or not online, or if
                persisting / queueing the task fails.
        """
        async with AsyncSessionLocal() as db:
            try:
                # Look up the target agent.
                stmt = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == task.agent_name
                )
                result = await db.execute(stmt)
                agent = result.scalar_one_or_none()
                if not agent:
                    raise TaskDispatchError(
                        task.task_id, f"Agent '{task.agent_name}' not found"
                    )
                if agent.status != AgentStatus.ONLINE:
                    raise TaskDispatchError(
                        task.task_id,
                        f"Agent '{task.agent_name}' is not online (status={agent.status})",
                    )

                # Persist the task row.
                task_id_uuid = uuid.UUID(task.task_id)
                agent_task = AgentTaskModel(
                    id=task_id_uuid,
                    agent_id=agent.id,
                    task_type=task.task_type,
                    status=TaskStatus.PENDING,
                    priority=task.priority,
                    input_data=task.input_data,
                    # NOTE(review): falling back to the agent id as the
                    # organization id looks like a placeholder - confirm.
                    organization_id=uuid.UUID(organization_id) if organization_id else agent.id,
                    created_by=uuid.UUID(created_by) if created_by else None,
                )
                db.add(agent_task)
                await db.commit()

                # Push to the Redis queue. The row above is already
                # committed, so a rollback cannot undo it any more; flag the
                # row as failed instead of leaving an orphan PENDING task.
                try:
                    redis = await self._get_redis()
                    queue_key = f"agent:{task.agent_name}:tasks"
                    await redis.lpush(queue_key, json.dumps(task.to_dict()))
                except Exception as push_error:
                    await self._mark_dispatch_failed(db, task_id_uuid, push_error)
                    raise

                logger.info(
                    f"Task {task.task_id} dispatched to agent '{task.agent_name}' "
                    f"(type={task.task_type}, priority={task.priority})"
                )
                return task.task_id
            except TaskDispatchError:
                raise
            except Exception as e:
                await db.rollback()
                logger.error(f"Failed to dispatch task {task.task_id}: {e}")
                raise TaskDispatchError(task.task_id, str(e)) from e

    async def _mark_dispatch_failed(
        self, db: AsyncSession, task_uuid: uuid.UUID, error: Exception
    ):
        """Mark an already-committed task row as failed after a queue error.

        Best effort: a failure to update the row is logged and swallowed so
        that the original queueing error is the one reported to the caller.
        """
        try:
            # A Core UPDATE avoids touching ORM instances that may have been
            # expired by the preceding commit.
            await db.execute(
                update(AgentTaskModel)
                .where(AgentTaskModel.id == task_uuid)
                .values(
                    status=TaskStatus.FAILED,
                    error_message=f"Failed to enqueue task: {error}",
                    completed_at=datetime.now(timezone.utc),
                )
            )
            await db.commit()
        except Exception as mark_error:
            await db.rollback()
            logger.error(
                f"Failed to mark task {task_uuid} as failed after dispatch error: {mark_error}"
            )

    async def cancel_task(self, task_id: str):
        """Cancel a task.

        Tasks that are already completed, failed or cancelled are left
        untouched (a warning is logged). The status change and its log entry
        are committed together.

        Raises:
            TaskNotFoundError: If no task with this id exists.
        """
        async with AsyncSessionLocal() as db:
            try:
                task_uuid = uuid.UUID(task_id)
                stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
                result = await db.execute(stmt)
                task = result.scalar_one_or_none()
                if not task:
                    raise TaskNotFoundError(task_id)
                if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
                    logger.warning(
                        f"Cannot cancel task {task_id} with status '{task.status}'"
                    )
                    return

                # NOTE(review): the queued Redis message is not removed here,
                # so an agent may still pick the task up - confirm that
                # agents check the task status before running it.
                task.status = TaskStatus.CANCELLED
                task.completed_at = datetime.now(timezone.utc)
                # Write the log entry before committing so that the status
                # change and the log row are stored atomically, and so that
                # ``task.agent_id`` is read while the instance is still
                # loaded.
                await self._write_log(
                    db, task_id=task_id, agent_id=str(task.agent_id),
                    log_level="info", message="Task cancelled by user",
                )
                await db.commit()
                logger.info(f"Task {task_id} cancelled")
            except TaskNotFoundError:
                raise
            except Exception as e:
                await db.rollback()
                logger.error(f"Failed to cancel task {task_id}: {e}")
                raise

    async def get_task_status(self, task_id: str) -> dict:
        """Return the task row as a dictionary.

        Raises:
            TaskNotFoundError: If no task with this id exists.
        """
        async with AsyncSessionLocal() as db:
            task_uuid = uuid.UUID(task_id)
            stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
            result = await db.execute(stmt)
            task = result.scalar_one_or_none()
            if not task:
                raise TaskNotFoundError(task_id)
            return self._task_to_dict(task)

    async def handle_result(self, result: TaskResult):
        """Handle a result returned by an agent.

        Updates the task row, writes a log entry (both in one commit) and
        then triggers the callback if ``output_data`` carries a
        ``callback_url``. Errors are logged, not raised.
        """
        async with AsyncSessionLocal() as db:
            try:
                task_uuid = uuid.UUID(result.task_id)
                stmt = select(AgentTaskModel).where(AgentTaskModel.id == task_uuid)
                db_result = await db.execute(stmt)
                task = db_result.scalar_one_or_none()
                if not task:
                    logger.error(f"Task {result.task_id} not found when handling result")
                    return

                # Update the agent_tasks row.
                task.status = result.status
                task.output_data = result.output_data
                task.error_message = result.error_message
                task.started_at = result.started_at
                task.completed_at = result.completed_at

                # Write the log entry in the same transaction as the update.
                log_level = "info" if result.status == TaskStatus.COMPLETED else "error"
                log_message = (
                    f"Task {result.status}"
                    if result.status == TaskStatus.COMPLETED
                    else f"Task failed: {result.error_message}"
                )
                await self._write_log(
                    db,
                    task_id=result.task_id,
                    agent_id=str(task.agent_id),
                    log_level=log_level,
                    message=log_message,
                    extra_metadata=result.metrics,
                )
                await db.commit()

                # Trigger the callback (if a callback_url is present).
                if result.output_data and result.output_data.get("callback_url"):
                    await self._trigger_callback(result.output_data["callback_url"], result)
                logger.info(
                    f"Task {result.task_id} result handled (status={result.status})"
                )
            except Exception as e:
                await db.rollback()
                logger.error(f"Failed to handle result for task {result.task_id}: {e}")

    async def handle_progress(self, progress: TaskProgress):
        """Handle a progress report by writing it to the task log.

        Reports from unknown agents are dropped with a warning. Errors are
        logged, not raised.
        """
        async with AsyncSessionLocal() as db:
            try:
                # Resolve the agent id from the agent name.
                stmt = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == progress.agent_name
                )
                result = await db.execute(stmt)
                agent = result.scalar_one_or_none()
                if not agent:
                    logger.warning(f"Agent '{progress.agent_name}' not found for progress report")
                    return

                # Write the progress log entry.
                await self._write_log(
                    db,
                    task_id=progress.task_id,
                    agent_id=str(agent.id),
                    log_level="info",
                    message=f"Progress: {progress.progress:.0%} - {progress.message}",
                    extra_metadata={
                        "progress": progress.progress,
                        "updated_at": progress.updated_at.isoformat(),
                    },
                )
                await db.commit()
            except Exception as e:
                await db.rollback()
                logger.error(f"Failed to handle progress for task {progress.task_id}: {e}")

    async def retry_failed_tasks(self, max_retries: int = 3):
        """Re-queue failed tasks that have not used up their retries.

        The number of earlier retries is derived from the task log (entries
        written by this method). A task is only reset to ``PENDING`` when it
        is actually pushed to its agent's queue; if the agent is missing or
        not online the task stays ``FAILED`` so a later run can retry it.
        Errors are logged, not raised.

        Args:
            max_retries: Maximum number of retry attempts per task.
        """
        async with AsyncSessionLocal() as db:
            try:
                stmt = select(AgentTaskModel).where(
                    AgentTaskModel.status == TaskStatus.FAILED,
                )
                result = await db.execute(stmt)
                failed_tasks = result.scalars().all()
                retried = 0
                redis = None
                for task in failed_tasks:
                    # Count earlier retries through the log. Match the exact
                    # prefix written below, so that unrelated messages which
                    # merely contain "retry" are not counted.
                    log_stmt = select(AgentTaskLogModel).where(
                        AgentTaskLogModel.task_id == task.id,
                        AgentTaskLogModel.message.like("Task retry attempt %"),
                    )
                    log_result = await db.execute(log_stmt)
                    retry_count = len(log_result.scalars().all())
                    if retry_count >= max_retries:
                        continue

                    agent_stmt = select(AgentRegistryModel).where(
                        AgentRegistryModel.id == task.agent_id
                    )
                    agent_result = await db.execute(agent_stmt)
                    agent = agent_result.scalar_one_or_none()
                    if not agent or agent.status != AgentStatus.ONLINE:
                        # Nothing can be queued right now; keep the task
                        # FAILED instead of stranding it in PENDING.
                        continue

                    # Reset the task state.
                    task.status = TaskStatus.PENDING
                    task.error_message = None
                    task.started_at = None
                    task.completed_at = None

                    # Push the task to Redis again.
                    task_msg = TaskMessage(
                        task_id=str(task.id),
                        agent_name=agent.name,
                        task_type=task.task_type,
                        priority=task.priority,
                        input_data=task.input_data or {},
                        callback_url=None,
                        created_at=datetime.now(timezone.utc),
                    )
                    if redis is None:
                        redis = await self._get_redis()
                    queue_key = f"agent:{agent.name}:tasks"
                    await redis.lpush(queue_key, json.dumps(task_msg.to_dict()))

                    # Write the retry log entry.
                    await self._write_log(
                        db,
                        task_id=str(task.id),
                        agent_id=str(agent.id),
                        log_level="info",
                        message=f"Task retry attempt {retry_count + 1}/{max_retries}",
                    )
                    retried += 1
                await db.commit()
                if retried > 0:
                    logger.info(f"Retried {retried} failed tasks")
            except Exception as e:
                await db.rollback()
                logger.error(f"Failed to retry failed tasks: {e}")

    async def _write_log(
        self,
        db: AsyncSession,
        task_id: str,
        agent_id: str,
        log_level: str,
        message: str,
        extra_metadata: dict | None = None,
    ):
        """Add a task log entry to the session.

        The entry is only added to ``db``; committing is left to the caller.
        ``task_id`` and ``agent_id`` must be UUID strings.
        """
        log_entry = AgentTaskLogModel(
            task_id=uuid.UUID(task_id),
            agent_id=uuid.UUID(agent_id),
            log_level=log_level,
            message=message,
            extra_metadata=extra_metadata,
        )
        db.add(log_entry)

    async def _trigger_callback(self, callback_url: str, result: TaskResult):
        """Trigger the callback with a simple HTTP POST of the result.

        Best effort: any failure, including an HTTP error status, is logged
        as a warning and not raised.
        """
        try:
            import httpx

            async with httpx.AsyncClient(timeout=10) as client:
                response = await client.post(callback_url, json=result.to_dict())
                # Treat HTTP error statuses as a failed callback instead of
                # reporting them as success.
                response.raise_for_status()
            logger.info(f"Callback triggered for task {result.task_id}")
        except Exception as e:
            logger.warning(f"Callback failed for task {result.task_id}: {e}")

    def _task_to_dict(self, task: AgentTaskModel) -> dict:
        """Convert a task ORM object into a plain dictionary.

        Ids are rendered as strings and timestamps as ISO-8601 strings;
        missing optional values are rendered as ``None``.
        """
        return {
            "id": str(task.id),
            "agent_id": str(task.agent_id),
            "task_type": task.task_type,
            "status": task.status,
            "priority": task.priority,
            "input_data": task.input_data,
            "output_data": task.output_data,
            "error_message": task.error_message,
            "created_by": str(task.created_by) if task.created_by else None,
            # Guarded like the other optional ids so that a missing value is
            # None rather than the string "None".
            "organization_id": str(task.organization_id) if task.organization_id else None,
            "project_id": str(task.project_id) if task.project_id else None,
            "scheduled_at": task.scheduled_at.isoformat() if task.scheduled_at else None,
            "started_at": task.started_at.isoformat() if task.started_at else None,
            "completed_at": task.completed_at.isoformat() if task.completed_at else None,
            "created_at": task.created_at.isoformat() if task.created_at else None,
        }