geo/backend/app/agent_framework/base.py

240 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""BaseAgent 基类 - 定义所有 Agent 的标准生命周期"""
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone
import redis.asyncio as aioredis
from app.agent_framework.exceptions import AgentNotReadyError
from app.agent_framework.protocol import (
AgentCapability,
AgentStatus,
TaskMessage,
TaskProgress,
TaskResult,
TaskStatus,
)
from app.config import settings
# Module-level logger, named after this module so output can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class BaseAgent(ABC):
    """Base class for all agents; defines the standard agent lifecycle.

    ``start()`` opens a Redis connection, registers the agent's capabilities
    with ``AgentRegistry`` and spawns two background loops: a periodic
    heartbeat and a listener that pops JSON-encoded ``TaskMessage`` payloads
    from the Redis list ``agent:<name>:tasks``. Each popped task runs in its
    own asyncio task, bounded by a semaphore sized from the capability's
    ``max_concurrency``. ``stop()`` cancels both loops, unregisters the agent
    and closes the Redis connection.

    Subclasses implement ``execute()`` and ``get_capabilities()``.
    """

    #: Seconds between heartbeats sent to the registry.
    HEARTBEAT_INTERVAL_SECONDS: float = 30
    #: BRPOP timeout in seconds; bounds how long the listener blocks before
    #: re-checking whether the agent has been stopped.
    TASK_POLL_TIMEOUT_SECONDS: int = 1
    #: Back-off in seconds after a listener error or while Redis is unavailable.
    LISTENER_RETRY_DELAY_SECONDS: float = 1

    def __init__(self, name: str, agent_type: str, version: str = "1.0.0"):
        self.name = name
        self.agent_type = agent_type
        self.version = version
        self._status: AgentStatus = AgentStatus.OFFLINE
        self._redis: aioredis.Redis | None = None
        # IDs of tasks currently being executed.
        self._running_tasks: set[str] = set()
        # Strong references to spawned per-task asyncio tasks. The event loop
        # only keeps weak references, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._task_handles: set[asyncio.Task] = set()
        self._listen_task: asyncio.Task | None = None
        self._heartbeat_task: asyncio.Task | None = None
        self._semaphore: asyncio.Semaphore | None = None
        # True between start() and stop(); drives the background loops. Kept
        # separate from _status because _status flips to BUSY while tasks run,
        # and the loops must keep running in that state.
        self._active: bool = False

    @property
    def status(self) -> AgentStatus:
        """Current status as maintained by the lifecycle methods (ONLINE/BUSY/OFFLINE)."""
        return self._status

    @abstractmethod
    async def execute(self, task: TaskMessage) -> TaskResult:
        """Core task logic; subclasses must implement it and return a TaskResult."""
        ...

    @abstractmethod
    def get_capabilities(self) -> AgentCapability:
        """Return this agent's capability declaration."""
        ...

    async def start(self):
        """Start the agent: connect to Redis, register, and begin listening for tasks.

        Calling ``start()`` while the agent is already running logs a warning
        and returns, so the background loops are never duplicated.

        Raises:
            Whatever ``AgentRegistry.register`` raises; the Redis connection
            opened by this call is closed before re-raising.
        """
        if self._active:
            logger.warning("Agent '%s' is already running; ignoring start()", self.name)
            return
        logger.info(
            "Starting agent '%s' (type=%s, version=%s)",
            self.name, self.agent_type, self.version,
        )

        # Query capabilities once: the same object is registered and used to
        # size the concurrency limit.
        capability = self.get_capabilities()
        # Treat missing/None/0/negative as 1 (asyncio.Semaphore rejects values < 0).
        max_concurrency = max(1, getattr(capability, "max_concurrency", 1) or 1)

        self._redis = aioredis.from_url(settings.REDIS_URL, decode_responses=True)
        try:
            await self._new_registry().register(capability, endpoint=f"agent:{self.name}")
        except Exception:
            # Don't leak the connection when registration fails.
            await self._close_redis()
            raise

        self._semaphore = asyncio.Semaphore(max_concurrency)
        logger.info("Agent '%s' concurrency limit set to %s", self.name, max_concurrency)

        self._active = True
        self._status = AgentStatus.ONLINE
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        self._listen_task = asyncio.create_task(self._listen_for_tasks())
        logger.info("Agent '%s' started successfully", self.name)

    async def stop(self):
        """Stop the agent: cancel the background loops, unregister, close Redis.

        Tasks already spawned for execution are not cancelled. They keep
        running and report results through a dispatcher built from
        ``settings.REDIS_URL`` rather than this agent's Redis client.
        """
        logger.info("Stopping agent '%s'", self.name)
        self._active = False
        self._status = AgentStatus.OFFLINE

        await self._cancel_and_wait(self._listen_task)
        await self._cancel_and_wait(self._heartbeat_task)
        self._listen_task = None
        self._heartbeat_task = None

        try:
            await self._new_registry().unregister(self.name)
        finally:
            # Close Redis even if unregistering fails.
            await self._close_redis()
        logger.info("Agent '%s' stopped", self.name)

    async def heartbeat(self):
        """Send a single heartbeat for this agent to the registry."""
        await self._new_registry().update_heartbeat(self.name)

    async def report_progress(self, task_id: str, progress: float, message: str):
        """Report progress for a task.

        The progress is published (best effort; failures are only logged) on
        the Redis Pub/Sub channel ``agent:<name>:progress``, then passed to
        ``TaskDispatcher.handle_progress`` (presumably persisted there).
        Errors from the dispatcher propagate to the caller.
        """
        progress_obj = TaskProgress(
            task_id=task_id,
            agent_name=self.name,
            progress=progress,
            message=message,
            updated_at=datetime.now(timezone.utc),
        )
        if self._redis is not None:
            try:
                await self._redis.publish(
                    f"agent:{self.name}:progress",
                    json.dumps(progress_obj.to_dict()),
                )
            except Exception as e:
                logger.warning("Failed to publish progress for task %s: %s", task_id, e)
        await self._new_dispatcher().handle_progress(progress_obj)

    async def _heartbeat_loop(self):
        """Send a heartbeat every HEARTBEAT_INTERVAL_SECONDS while the agent is active.

        The loop keeps running while the agent is BUSY. A failed heartbeat is
        logged and retried on the next tick instead of ending the loop. The
        loop ends when ``stop()`` clears ``_active`` or cancels the task.
        """
        while self._active:
            try:
                await self.heartbeat()
            except Exception as e:
                logger.exception("Heartbeat error for agent '%s': %s", self.name, e)
            await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)

    async def _listen_for_tasks(self):
        """Pop task messages from ``agent:<name>:tasks`` and spawn their execution.

        Redis errors are logged and retried after a short back-off rather than
        ending the loop. Malformed messages are logged and dropped.

        NOTE(review): popped tasks wait on the semaphore in-process, so a
        backlog is drained from Redis into memory when at capacity.
        """
        queue_key = f"agent:{self.name}:tasks"
        while self._active:
            if self._redis is None:
                await asyncio.sleep(self.LISTENER_RETRY_DELAY_SECONDS)
                continue
            try:
                # Short blocking timeout so the loop regularly re-checks _active.
                result = await self._redis.brpop(
                    queue_key, timeout=self.TASK_POLL_TIMEOUT_SECONDS
                )
            except Exception as e:
                logger.exception("Task listener error for agent '%s': %s", self.name, e)
                await asyncio.sleep(self.LISTENER_RETRY_DELAY_SECONDS)
                continue
            if not result:
                continue
            _, task_json = result
            try:
                task = TaskMessage.from_dict(json.loads(task_json))
            except Exception as e:
                logger.error("Failed to parse task message: %s", e)
                continue
            self._spawn_task(self._execute_task_with_semaphore(task))

    def _spawn_task(self, coro) -> asyncio.Task:
        """Schedule *coro* as a task and hold a strong reference until it completes."""
        handle = asyncio.create_task(coro)
        self._task_handles.add(handle)
        handle.add_done_callback(self._task_handles.discard)
        return handle

    async def _execute_task_with_semaphore(self, task: TaskMessage):
        """Run *task* once a concurrency slot is free (no limit if start() never ran)."""
        if self._semaphore is None:
            await self._execute_task(task)
            return
        async with self._semaphore:
            await self._execute_task(task)

    async def _execute_task(self, task: TaskMessage):
        """Execute one task and hand its TaskResult to the dispatcher.

        Any exception from ``execute()`` is converted into a FAILED
        ``TaskResult``; so is a return value whose timestamps cannot be set.
        Errors while delivering the result are logged rather than raised, so
        they never become an unretrieved background-task exception. A
        successful run is never re-reported as failed.
        """
        self._running_tasks.add(task.task_id)
        if self._active:
            # Don't overwrite the OFFLINE status of a stopped agent.
            self._status = AgentStatus.BUSY
        started_at = datetime.now(timezone.utc)
        try:
            logger.info(
                "Agent '%s' executing task %s (type=%s)",
                self.name, task.task_id, task.task_type,
            )
            try:
                result = await self.execute(task)
                result.started_at = started_at
                result.completed_at = datetime.now(timezone.utc)
            except Exception as e:
                logger.exception("Agent '%s' task %s failed: %s", self.name, task.task_id, e)
                result = TaskResult(
                    task_id=task.task_id,
                    agent_name=self.name,
                    status=TaskStatus.FAILED,
                    output_data=None,
                    # Some exceptions (e.g. a bare TimeoutError()) stringify to "".
                    error_message=str(e) or type(e).__name__,
                    started_at=started_at,
                    completed_at=datetime.now(timezone.utc),
                    metrics=None,
                )
            try:
                await self._new_dispatcher().handle_result(result)
            except Exception as e:
                logger.exception(
                    "Agent '%s' failed to deliver result for task %s: %s",
                    self.name, task.task_id, e,
                )
        finally:
            self._running_tasks.discard(task.task_id)
            # Only return to ONLINE from BUSY; never undo OFFLINE set by stop().
            if not self._running_tasks and self._status == AgentStatus.BUSY:
                self._status = AgentStatus.ONLINE

    @staticmethod
    def _new_registry():
        """Build an AgentRegistry (imported lazily, presumably to avoid an import cycle)."""
        from app.agent_framework.registry import AgentRegistry

        return AgentRegistry()

    @staticmethod
    def _new_dispatcher():
        """Build a TaskDispatcher for settings.REDIS_URL (imported lazily, presumably to avoid a cycle)."""
        from app.agent_framework.dispatcher import TaskDispatcher

        return TaskDispatcher(settings.REDIS_URL)

    async def _close_redis(self):
        """Close and drop this agent's Redis client, if one is open."""
        client, self._redis = self._redis, None
        if client is None:
            return
        # redis-py >= 5.0.1 deprecates close() in favour of aclose(); fall back for older versions.
        close = getattr(client, "aclose", None) or client.close
        await close()

    @staticmethod
    async def _cancel_and_wait(task: "asyncio.Task | None"):
        """Cancel *task* if it is still running and wait for it to finish."""
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass