"""BaseAgent base class - defines the standard lifecycle shared by every agent."""

import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone

import redis.asyncio as aioredis

from app.agent_framework.exceptions import AgentNotReadyError
from app.agent_framework.protocol import (
    AgentCapability,
    AgentStatus,
    TaskMessage,
    TaskProgress,
    TaskResult,
    TaskStatus,
)
from app.config import settings

logger = logging.getLogger(__name__)

# Cached TaskDispatcher, built on first use and reused afterwards. The
# dispatcher module is imported inside _get_dispatcher() rather than at the
# top of the file to avoid circular-import problems at module-load time.
_dispatcher_instance = None


def _get_dispatcher():
    """Return the module-level TaskDispatcher, constructing it on first call."""
    global _dispatcher_instance
    if _dispatcher_instance is not None:
        return _dispatcher_instance
    from app.agent_framework.dispatcher import TaskDispatcher

    _dispatcher_instance = TaskDispatcher(settings.REDIS_URL)
    return _dispatcher_instance


# Module-level lazy singleton for AgentRegistry — same rationale as the
# dispatcher singleton above.
_registry_instance = None


def _get_registry():
    """Return a cached AgentRegistry instance (lazy singleton)."""
    global _registry_instance
    if _registry_instance is None:
        from app.agent_framework.registry import AgentRegistry

        _registry_instance = AgentRegistry()
    return _registry_instance


class BaseAgent(ABC):
    """Base class for every agent; defines the standard lifecycle.

    ``start()`` tries to connect to Redis and register with the AgentRegistry.
    On success it launches the heartbeat and task-listener background loops
    (distributed mode). On failure it falls back to local mode with no Redis
    and no background loops. ``stop()`` tears everything down again.

    Status: OFFLINE until started, ONLINE once started, BUSY while at least
    one task is executing, and OFFLINE again after ``stop()``.

    Subclasses must implement ``execute()`` and ``get_capabilities()``.
    """

    # Seconds to wait between heartbeat reports in distributed mode.
    HEARTBEAT_INTERVAL_SECONDS = 30

    def __init__(self, name: str, agent_type: str, version: str = "1.0.0"):
        self.name = name
        self.agent_type = agent_type
        self.version = version
        self._status: AgentStatus = AgentStatus.OFFLINE
        self._redis: aioredis.Redis | None = None
        self._running_tasks: set[str] = set()
        self._listen_task: asyncio.Task | None = None
        self._heartbeat_task: asyncio.Task | None = None
        self._semaphore: asyncio.Semaphore | None = None
        # Strong references to background execution tasks. The event loop
        # only keeps weak references, so an unreferenced task can be
        # garbage-collected before it finishes.
        self._execution_handles: set[asyncio.Task] = set()

    @property
    def status(self) -> AgentStatus:
        return self._status

    @property
    def is_distributed(self) -> bool:
        return self._redis is not None

    @abstractmethod
    async def execute(self, task: TaskMessage) -> TaskResult:
        """Core task-execution logic; subclasses must implement it."""
        ...

    @abstractmethod
    def get_capabilities(self) -> AgentCapability:
        """Return this agent's capability declaration."""
        ...

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def start(self):
        """Start the agent in distributed mode, or local mode as a fallback.

        The concurrency semaphore is sized from
        ``get_capabilities().max_concurrency`` (1 if missing or falsy).

        Distributed mode: connect to Redis, ping it, register with the
        AgentRegistry, set ONLINE, then start the heartbeat and listener
        loops. If any of those steps raises, the half-initialised Redis
        client is closed and the agent becomes ONLINE in local mode.
        """
        logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})")

        capability = self.get_capabilities()
        max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
        self._semaphore = asyncio.Semaphore(max_concurrency)

        client = None
        try:
            client = aioredis.from_url(
                settings.REDIS_URL,
                decode_responses=True,
            )
            self._redis = client
            await client.ping()
            registry = _get_registry()
            await registry.register(capability, endpoint=f"agent:{self.name}")
        except Exception as e:
            self._redis = None
            # Release the connection pool of the client we failed to use.
            if client is not None:
                await self._close_redis_client(client)
            self._status = AgentStatus.ONLINE
            logger.warning(f"Agent '{self.name}' started in local mode (Redis unavailable: {e})")
            return

        self._status = AgentStatus.ONLINE
        logger.info(
            f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
        )
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        self._listen_task = asyncio.create_task(self._listen_for_tasks())
        logger.info(f"Agent '{self.name}' started in distributed mode")

    async def stop(self):
        """Stop the agent.

        Marks the agent OFFLINE and cancels the listener and heartbeat loops.
        In distributed mode it then unregisters from the AgentRegistry and
        closes the Redis connection. The connection is closed even if
        unregistering raises; that exception still propagates.

        Execution tasks already in flight are not cancelled. Once the
        connection is gone, their results are no longer forwarded to the
        dispatcher.
        """
        logger.info(f"Stopping agent '{self.name}'")
        self._status = AgentStatus.OFFLINE

        await self._cancel_and_wait(self._listen_task)
        await self._cancel_and_wait(self._heartbeat_task)

        if self._redis is not None:
            client = self._redis
            try:
                registry = _get_registry()
                await registry.unregister(self.name)
            finally:
                self._redis = None
                await self._close_redis_client(client)
        logger.info(f"Agent '{self.name}' stopped")

    @staticmethod
    async def _cancel_and_wait(task: asyncio.Task | None) -> None:
        """Cancel *task* if it is still running and wait for it to finish."""
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    @staticmethod
    async def _close_redis_client(client: aioredis.Redis) -> None:
        """Close *client*, logging rather than raising if closing fails.

        Closing runs during error handling or in a ``finally``, where raising
        would mask the original exception.
        """
        try:
            await client.close()
        except Exception as e:
            logger.warning(f"Failed to close Redis connection: {e}")

    def _is_active(self) -> bool:
        """Return True while the agent is started (ONLINE or BUSY)."""
        return self._status in (AgentStatus.ONLINE, AgentStatus.BUSY)

    # ------------------------------------------------------------------
    # Heartbeat / progress
    # ------------------------------------------------------------------

    async def heartbeat(self):
        """Report one heartbeat for this agent to the AgentRegistry."""
        registry = _get_registry()
        await registry.update_heartbeat(self.name)

    async def report_progress(self, task_id: str, progress: float, message: str):
        """Report progress for *task_id*, best effort.

        When connected to Redis, the progress is published as JSON on the
        ``agent:{name}:progress`` channel. It is also passed to the
        dispatcher's ``handle_progress``. Failures of either path are logged
        and not raised.
        """
        progress_obj = TaskProgress(
            task_id=task_id,
            agent_name=self.name,
            progress=progress,
            message=message,
            updated_at=datetime.now(timezone.utc),
        )
        if self._redis:
            try:
                await self._redis.publish(
                    f"agent:{self.name}:progress",
                    json.dumps(progress_obj.to_dict()),
                )
            except Exception as e:
                logger.warning(f"Failed to publish progress for task {task_id}: {e}")
        try:
            dispatcher = _get_dispatcher()
            await dispatcher.handle_progress(progress_obj)
        except Exception as e:
            logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}")

    async def _heartbeat_loop(self):
        """Send a heartbeat every HEARTBEAT_INTERVAL_SECONDS while started.

        Keeps running while the agent is BUSY. A failed heartbeat is logged
        and retried on the next tick instead of ending the loop.
        """
        try:
            while self._is_active():
                try:
                    await self.heartbeat()
                except Exception as e:
                    logger.error(f"Heartbeat error for agent '{self.name}': {e}")
                await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)
        except asyncio.CancelledError:
            pass

    # ------------------------------------------------------------------
    # Task intake and execution
    # ------------------------------------------------------------------

    async def _listen_for_tasks(self):
        """Pop task messages from the Redis list ``agent:{name}:tasks``.

        Each parsed message is scheduled as a background execution. The loop
        keeps running while the agent is BUSY, so it keeps accepting work up
        to the semaphore limit. Redis errors are logged and retried after a
        1-second back-off. Malformed messages are logged and skipped.
        """
        queue_key = f"agent:{self.name}:tasks"
        try:
            while self._is_active():
                if not self._redis:
                    await asyncio.sleep(1)
                    continue

                # Blocking pop with a 1 s timeout so the status is re-checked
                # regularly.
                try:
                    result = await self._redis.brpop(queue_key, timeout=1)
                except Exception as e:
                    logger.error(f"Task listener error for agent '{self.name}': {e}")
                    await asyncio.sleep(1)
                    continue

                if not result:
                    continue

                _, task_json = result
                try:
                    task_data = json.loads(task_json)
                    task = TaskMessage.from_dict(task_data)
                except Exception as e:
                    logger.error(f"Failed to parse task message: {e}")
                    continue

                self._spawn_execution(task)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # Last-resort guard for unexpected errors outside the per-iteration
            # handlers (e.g. scheduling the execution task).
            logger.error(f"Task listener error for agent '{self.name}': {e}")

    def _spawn_execution(self, task: TaskMessage) -> None:
        """Schedule *task* in the background and keep a strong reference to it."""
        handle = asyncio.create_task(self._execute_task_with_semaphore(task))
        self._execution_handles.add(handle)
        handle.add_done_callback(self._execution_handles.discard)

    async def _execute_task_with_semaphore(self, task: TaskMessage):
        """Run *task*, holding the concurrency semaphore when one is configured."""
        if self._semaphore is None:
            await self._execute_task(task)
            return

        async with self._semaphore:
            await self._execute_task(task)

    async def _execute_task(self, task: TaskMessage):
        """Execute *task* and report its result.

        Any exception raised by ``execute()`` is converted into a FAILED
        ``TaskResult``. The result (success or failure) is then forwarded to
        the dispatcher via ``_report_result``.

        The status is BUSY while tasks run. It returns to ONLINE when the last
        running task finishes, but only if it is still BUSY at that point, so
        a ``stop()`` issued mid-task is not undone.
        """
        self._running_tasks.add(task.task_id)
        if self._status != AgentStatus.OFFLINE:
            self._status = AgentStatus.BUSY
        started_at = datetime.now(timezone.utc)
        try:
            logger.info(f"Agent '{self.name}' executing task {task.task_id} (type={task.task_type})")
            try:
                result = await self.execute(task)
                result.started_at = started_at
                result.completed_at = datetime.now(timezone.utc)
            except Exception as e:
                logger.exception(f"Agent '{self.name}' task {task.task_id} failed: {e}")
                result = TaskResult(
                    task_id=task.task_id,
                    agent_name=self.name,
                    status=TaskStatus.FAILED,
                    output_data=None,
                    error_message=str(e),
                    started_at=started_at,
                    completed_at=datetime.now(timezone.utc),
                    metrics=None,
                )
            await self._report_result(task.task_id, result)
        finally:
            self._running_tasks.discard(task.task_id)
            if not self._running_tasks and self._status == AgentStatus.BUSY:
                self._status = AgentStatus.ONLINE

    async def _report_result(self, task_id: str, result: TaskResult) -> None:
        """Forward *result* to the dispatcher when in distributed mode.

        A reporting failure is logged, not raised. It is not a task failure,
        and background execution tasks started by the listener are never
        awaited, so a raised error would go unobserved.
        """
        if self._redis is None:
            return
        try:
            dispatcher = _get_dispatcher()
            await dispatcher.handle_result(result)
        except Exception as e:
            logger.error(f"Failed to report result for task {task_id} to dispatcher: {e}")