"""BaseAgent base class - defines the standard lifecycle shared by all agents."""

import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone

import redis.asyncio as aioredis

from app.agent_framework.exceptions import AgentNotReadyError
from app.agent_framework.protocol import (
    AgentCapability,
    AgentStatus,
    TaskMessage,
    TaskProgress,
    TaskResult,
    TaskStatus,
)
from app.config import settings

logger = logging.getLogger(__name__)


class BaseAgent(ABC):
    """Base class for all agents, defining the standard lifecycle.

    ``start()`` opens a Redis connection, registers the agent with
    ``AgentRegistry`` and spawns two background loops: a heartbeat loop and a
    listener that pops tasks from the Redis list ``agent:<name>:tasks``.
    ``stop()`` cancels the loops, unregisters the agent and closes Redis.
    Subclasses implement ``execute()`` and ``get_capabilities()``.
    """

    # Seconds between heartbeats sent by ``_heartbeat_loop``.
    HEARTBEAT_INTERVAL: float = 30
    # Seconds ``brpop`` blocks before the listener re-checks the agent status.
    QUEUE_POLL_TIMEOUT: int = 1
    # Seconds a background loop waits before retrying after Redis is missing
    # or an operation raised.
    RETRY_DELAY: float = 1

    def __init__(self, name: str, agent_type: str, version: str = "1.0.0"):
        self.name = name
        self.agent_type = agent_type
        self.version = version
        self._status: AgentStatus = AgentStatus.OFFLINE
        self._redis: aioredis.Redis | None = None
        # IDs of tasks currently executing; used to switch BUSY back to ONLINE.
        self._running_tasks: set[str] = set()
        # Strong references to spawned worker tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # mid-execution.
        self._worker_tasks: set[asyncio.Task] = set()
        self._listen_task: asyncio.Task | None = None
        self._heartbeat_task: asyncio.Task | None = None
        self._semaphore: asyncio.Semaphore | None = None

    @property
    def status(self) -> AgentStatus:
        """Current lifecycle status of the agent."""
        return self._status

    @abstractmethod
    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run the core logic for ``task``; subclasses must implement this."""
        ...

    @abstractmethod
    def get_capabilities(self) -> AgentCapability:
        """Return this agent's capability declaration."""
        ...

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #

    async def start(self):
        """Start the agent: register with the registry and listen for tasks."""
        logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})")

        # Initialise the Redis connection
        self._redis = aioredis.from_url(
            settings.REDIS_URL,
            decode_responses=True,
        )

        # Register with the registry; the capability object is reused below.
        capability = self.get_capabilities()
        await self._get_registry().register(capability, endpoint=f"agent:{self.name}")

        self._status = AgentStatus.ONLINE

        # Size the concurrency limit from the declared max_concurrency
        # (missing / falsy values fall back to 1).
        max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
        self._semaphore = asyncio.Semaphore(max_concurrency)
        logger.info(
            f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
        )

        # Start the heartbeat and the task-queue listener
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        self._listen_task = asyncio.create_task(self._listen_for_tasks())

        logger.info(f"Agent '{self.name}' started successfully")

    async def stop(self):
        """Stop the agent: cancel background loops, unregister, close Redis.

        Worker tasks that are already executing are not cancelled.
        """
        logger.info(f"Stopping agent '{self.name}'")
        self._status = AgentStatus.OFFLINE

        await self._cancel_and_wait(self._listen_task)
        await self._cancel_and_wait(self._heartbeat_task)

        try:
            await self._get_registry().unregister(self.name)
        finally:
            # Release the Redis connection even if unregistering failed.
            if self._redis:
                await self._redis.close()
                self._redis = None

        logger.info(f"Agent '{self.name}' stopped")

    async def heartbeat(self):
        """Report one heartbeat for this agent to the registry."""
        await self._get_registry().update_heartbeat(self.name)

    async def report_progress(self, task_id: str, progress: float, message: str):
        """Report task progress via Redis Pub/Sub and the task dispatcher.

        The Pub/Sub publish is best-effort (failures are logged); errors from
        the dispatcher's ``handle_progress`` propagate to the caller.
        """
        progress_obj = TaskProgress(
            task_id=task_id,
            agent_name=self.name,
            progress=progress,
            message=message,
            updated_at=datetime.now(timezone.utc),
        )

        # Publish progress on the agent's progress channel
        if self._redis:
            try:
                await self._redis.publish(
                    f"agent:{self.name}:progress",
                    json.dumps(progress_obj.to_dict()),
                )
            except Exception as e:
                logger.warning(f"Failed to publish progress for task {task_id}: {e}")

        # Also hand the update to the dispatcher
        await self._get_dispatcher().handle_progress(progress_obj)

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _is_active(self) -> bool:
        """Return True while the agent is started, whether idle or busy."""
        return self._status in (AgentStatus.ONLINE, AgentStatus.BUSY)

    @staticmethod
    def _get_registry():
        """Build an ``AgentRegistry``.

        Imported lazily, as in the original code; presumably to avoid an
        import cycle (TODO confirm).
        """
        from app.agent_framework.registry import AgentRegistry

        return AgentRegistry()

    @staticmethod
    def _get_dispatcher():
        """Build a ``TaskDispatcher`` bound to ``settings.REDIS_URL`` (lazy import)."""
        from app.agent_framework.dispatcher import TaskDispatcher

        return TaskDispatcher(settings.REDIS_URL)

    @staticmethod
    async def _cancel_and_wait(task: asyncio.Task | None) -> None:
        """Cancel ``task`` if it is still running and wait for it to finish."""
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def _heartbeat_loop(self):
        """Send a heartbeat every ``HEARTBEAT_INTERVAL`` seconds while active.

        The loop keeps running while the agent is BUSY. A failed heartbeat is
        logged and retried on the next tick rather than ending the loop.
        """
        try:
            while self._is_active():
                try:
                    await self.heartbeat()
                except Exception as e:
                    logger.error(f"Heartbeat error for agent '{self.name}': {e}")
                await asyncio.sleep(self.HEARTBEAT_INTERVAL)
        except asyncio.CancelledError:
            pass

    async def _listen_for_tasks(self):
        """Pop tasks from the Redis list ``agent:<name>:tasks`` while active.

        ``brpop`` blocks for at most ``QUEUE_POLL_TIMEOUT`` seconds so the
        status is re-checked regularly. Redis errors are logged and retried
        after ``RETRY_DELAY`` instead of killing the listener.
        """
        queue_key = f"agent:{self.name}:tasks"
        try:
            while self._is_active():
                if not self._redis:
                    await asyncio.sleep(self.RETRY_DELAY)
                    continue

                try:
                    result = await self._redis.brpop(queue_key, timeout=self.QUEUE_POLL_TIMEOUT)
                except Exception as e:
                    logger.error(f"Task listener error for agent '{self.name}': {e}")
                    await asyncio.sleep(self.RETRY_DELAY)
                    continue

                if result:
                    _, task_json = result
                    self._spawn_task(task_json)
        except asyncio.CancelledError:
            pass

    def _spawn_task(self, task_json: str) -> None:
        """Parse a queued JSON payload and schedule it as a tracked worker task."""
        try:
            task = TaskMessage.from_dict(json.loads(task_json))
        except Exception as e:
            logger.error(f"Failed to parse task message: {e}")
            return

        worker = asyncio.create_task(self._execute_task_with_semaphore(task))
        # Keep a strong reference until the worker finishes.
        self._worker_tasks.add(worker)
        worker.add_done_callback(self._worker_tasks.discard)

    async def _execute_task_with_semaphore(self, task: TaskMessage):
        """Run ``task`` under the concurrency semaphore, if one is configured."""
        if self._semaphore is None:
            await self._execute_task(task)
            return

        async with self._semaphore:
            await self._execute_task(task)

    async def _execute_task(self, task: TaskMessage):
        """Execute one task and report its result (or a FAILED result) to the dispatcher.

        If ``execute`` or reporting the successful result raises, a FAILED
        ``TaskResult`` is reported instead. Errors while reporting that
        fallback are logged so they do not escape the worker task.
        """
        self._running_tasks.add(task.task_id)
        # Only ONLINE -> BUSY; never resurrect an agent that has been stopped.
        if self._status == AgentStatus.ONLINE:
            self._status = AgentStatus.BUSY
        started_at = datetime.now(timezone.utc)

        try:
            logger.info(f"Agent '{self.name}' executing task {task.task_id} (type={task.task_type})")
            result = await self.execute(task)
            result.started_at = started_at
            result.completed_at = datetime.now(timezone.utc)

            await self._get_dispatcher().handle_result(result)
        except Exception as e:
            logger.exception(f"Agent '{self.name}' task {task.task_id} failed: {e}")
            error_result = TaskResult(
                task_id=task.task_id,
                agent_name=self.name,
                status=TaskStatus.FAILED,
                output_data=None,
                error_message=str(e),
                started_at=started_at,
                completed_at=datetime.now(timezone.utc),
                metrics=None,
            )
            try:
                await self._get_dispatcher().handle_result(error_result)
            except Exception as report_err:
                logger.error(
                    f"Agent '{self.name}' failed to report failure of task {task.task_id}: {report_err}"
                )
        finally:
            self._running_tasks.discard(task.task_id)
            # Only BUSY -> ONLINE; an agent stopped meanwhile stays OFFLINE.
            if not self._running_tasks and self._status == AgentStatus.BUSY:
                self._status = AgentStatus.ONLINE