"""BaseAgent 基类 - 统一 Agent 生命周期管理 核心设计: - execute() 为 final 方法,包含完整的计时、try/except、TaskResult 构建 - 子类只需实现 handle_task(task) -> dict 返回业务数据 - 生命周期钩子:on_task_start / on_task_complete / on_task_failed - 支持 Tool 插件、Memory 系统(可选注入) """ import asyncio import json import logging import time from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from datetime import datetime, timezone from typing import TYPE_CHECKING import redis.asyncio as aioredis from agentkit.core.exceptions import SchemaValidationError, TaskCancelledError, TaskTimeoutError from agentkit.core.protocol import ( AgentCapability, AgentStatus, CancellationToken, HandoffMessage, TaskMessage, TaskProgress, TaskResult, TaskStatus, ) if TYPE_CHECKING: from agentkit.core.react import ReActEvent from agentkit.memory.base import Memory from agentkit.tools.base import Tool from agentkit.llm.gateway import LLMGateway from agentkit.skills.base import Skill from agentkit.quality.gate import QualityGate logger = logging.getLogger(__name__) class BaseAgent(ABC): """所有 Agent 的基类,定义标准生命周期。 子类只需实现: - handle_task(task) -> dict: 业务逻辑,返回 output_data - get_capabilities() -> AgentCapability: 能力声明 可选覆写: - on_task_start(task): 任务开始前的钩子 - on_task_complete(task, output): 任务成功后的钩子 - on_task_failed(task, error): 任务失败后的钩子 """ def __init__(self, name: str, agent_type: str, version: str = "1.0.0"): self.name = name self.agent_type = agent_type self.version = version self._status: AgentStatus = AgentStatus.OFFLINE self._redis: aioredis.Redis | None = None self._redis_url: str = "" self._running_tasks: set[str] = set() self._active_tokens: dict[str, CancellationToken] = {} self._listen_task: asyncio.Task | None = None self._heartbeat_task: asyncio.Task | None = None self._semaphore: asyncio.Semaphore | None = None self._status_lock: asyncio.Lock = asyncio.Lock() self._lock_timeout: float = 30.0 # Lock acquisition timeout (seconds) self._config_version: int = 0 # Configuration version counter # 可插拔能力(由子类或配置注入) self._tools: list["Tool"] = [] self._memory: "Memory | None" = None self._memory_retriever: object | None = None # 外部依赖注入(由 start() 时设置) self._registry = None self._dispatcher = None # v2 可插拔能力 self._llm_gateway: "LLMGateway | None" = None self._skill: "Skill | None" = None self._quality_gate: "QualityGate | None" = None @property def status(self) -> AgentStatus: return self._status @property def config_version(self) -> int: return self._config_version @property def is_distributed(self) -> bool: return self._redis is not None async def _acquire_status_lock(self) -> None: """Acquire status lock with timeout to prevent deadlocks.""" try: await asyncio.wait_for(self._status_lock.acquire(), timeout=self._lock_timeout) except asyncio.TimeoutError: logger.error( f"Agent '{self.name}' status lock acquisition timed out " f"after {self._lock_timeout}s — possible deadlock" ) raise RuntimeError("Status lock acquisition timed out") def _release_status_lock(self) -> None: """Release status lock safely.""" try: self._status_lock.release() except RuntimeError: pass # Lock not held, ignore @property def tools(self) -> list["Tool"]: return self._tools @property def memory(self) -> "Memory | None": return self._memory @property def llm_gateway(self) -> "LLMGateway | None": return self._llm_gateway @llm_gateway.setter def llm_gateway(self, gateway: "LLMGateway") -> None: self._llm_gateway = gateway @property def skill(self) -> "Skill | None": return self._skill @skill.setter def skill(self, skill: "Skill") -> None: self._skill = skill @property def quality_gate(self) -> "QualityGate": """获取 QualityGate 实例,懒初始化""" if self._quality_gate is None: from agentkit.quality.gate import QualityGate self._quality_gate = QualityGate() return self._quality_gate # ── 抽象方法(子类必须实现) ────────────────────────────── @abstractmethod async def handle_task(self, task: TaskMessage) -> dict: """执行任务的核心业务逻辑,子类必须实现。 返回 output_data dict,框架自动包装为 TaskResult。 """ ... @abstractmethod def get_capabilities(self) -> AgentCapability: """返回 Agent 能力声明""" ... # ── 流式执行(U3) ──────────────────────────────────────── async def execute_stream(self, task: TaskMessage) -> AsyncGenerator["ReActEvent", None]: """流式执行任务,yield ReActEvent 事件。 与 execute() 不同,此方法不包装 TaskResult — 直接 yield 事件, 由调用方(如 PhaseExecutorMixin._run_agent_steps)负责转发和 最终结果累积。默认实现回退到 handle_task 并包装为单个 final_answer 事件;子类应覆写以提供真正的流式输出。 """ # Default fallback: run sync handle_task, wrap as single final_answer. # Subclasses (e.g. ConfigDrivenAgent) override for real streaming. from agentkit.core.react import ReActEvent output = await self.handle_task(task) yield ReActEvent( event_type="final_answer", step=0, data={"output": output.get("content", str(output))}, ) # ── 生命周期钩子(可选覆写) ────────────────────────────── async def on_task_start(self, task: TaskMessage) -> None: """任务开始前的钩子,可用于加载记忆、准备上下文等""" pass async def on_task_complete(self, task: TaskMessage, output: dict) -> None: """任务成功后的钩子,可用于存储记忆、触发反思等""" pass async def on_task_failed(self, task: TaskMessage, error: Exception) -> None: """任务失败后的钩子,可用于记录失败模式等""" pass # ── v2 方法 ────────────────────────────────────────────── async def handle_task_with_feedback(self, task: TaskMessage, feedback: str) -> dict: """Re-execute task with quality feedback (for retry) 默认实现直接调用 handle_task,子类可覆写以利用 feedback。 """ return await self.handle_task(task) def _build_quality_feedback(self, quality_result) -> str: """从 QualityResult 构建反馈字符串""" failed_checks = [c for c in quality_result.checks if not c.passed] lines = ["Quality check failed. Issues:"] for check in failed_checks: msg = check.message or f"Check '{check.name}' failed" lines.append(f" - {msg}") return "\n".join(lines) def _build_skill_context(self) -> dict[str, object] | None: """从当前技能配置构建 skill_context,用于 QualityGate skill_match 校验""" if not self._skill: return None intent = getattr(self._skill.config, "intent", None) if intent is None: return None keywords = list(intent.keywords) + list(intent.disambiguation_keywords) if not keywords: return None return {"intent_keywords": keywords} # ── 可插拔能力注入 ────────────────────────────────────── def use_tool(self, tool: "Tool") -> "BaseAgent": """添加工具到 Agent""" self._tools.append(tool) return self def use_memory(self, memory: "Memory") -> "BaseAgent": """设置记忆系统""" self._memory = memory return self def use_memory_retriever(self, retriever: object) -> "BaseAgent": """设置记忆检索器,用于上下文注入""" self._memory_retriever = retriever return self def set_registry(self, registry: object) -> "BaseAgent": """注入注册中心""" self._registry = registry return self def set_dispatcher(self, dispatcher: object) -> "BaseAgent": """注入任务分发器""" self._dispatcher = dispatcher return self # ── 核心生命周期 ────────────────────────────────────────── async def start(self, redis_url: str = ""): """启动 Agent:连接 Redis → 注册 → 心跳 → 监听""" self._redis_url = redis_url logger.info( f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})" ) if redis_url: try: self._redis = aioredis.from_url(redis_url, decode_responses=True) await self._redis.ping() logger.info(f"Agent '{self.name}' connected to Redis") except (ConnectionError, OSError, asyncio.TimeoutError, ValueError) as e: self._redis = None logger.warning( f"Agent '{self.name}' Redis unavailable: {e}, falling back to local mode" ) # 注册到 Registry if self._registry is not None: capability = self.get_capabilities() await self._registry.register(capability, endpoint=f"agent:{self.name}") async with self._status_lock: self._status = AgentStatus.ONLINE # 设置并发控制 capability = self.get_capabilities() max_concurrency = getattr(capability, "max_concurrency", 1) or 1 self._semaphore = asyncio.Semaphore(max_concurrency) # 启动心跳和监听 if self._redis is not None: self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) self._listen_task = asyncio.create_task(self._listen_for_tasks()) logger.info( f"Agent '{self.name}' started ({'distributed' if self._redis else 'local'} mode)" ) async def stop(self): """停止 Agent""" logger.info(f"Stopping agent '{self.name}'") async with self._status_lock: self._status = AgentStatus.OFFLINE for task in [self._listen_task, self._heartbeat_task]: if task and not task.done(): task.cancel() try: await task except asyncio.CancelledError: pass if self._redis is not None: if self._registry is not None: await self._registry.unregister(self.name) await self._redis.close() self._redis = None logger.info(f"Agent '{self.name}' stopped") # ── execute 为 final 方法 ───────────────────────────────── async def execute(self, task: TaskMessage) -> TaskResult: """执行任务(框架方法,不可覆写)。 完整流程:on_task_start → handle_task → quality_gate → on_task_complete/on_task_failed 自动处理计时、TaskResult 构建、错误捕获、超时和取消。 """ started_at = datetime.now(timezone.utc) start_time = time.monotonic() # 创建 CancellationToken 并存储 token = CancellationToken() self._active_tokens[task.task_id] = token try: # 前置钩子 await self.on_task_start(task) # Schema 校验(如果 Agent 声明了 input_schema) capability = self.get_capabilities() if capability.input_schema: self._validate_input(task.input_data, capability.input_schema) # 执行业务逻辑,带超时控制 timeout_seconds = task.timeout_seconds if timeout_seconds > 0: try: output = await asyncio.wait_for( self.handle_task(task), timeout=timeout_seconds, ) except asyncio.TimeoutError: raise TaskTimeoutError( task_id=task.task_id, timeout_seconds=timeout_seconds, ) else: output = await self.handle_task(task) # 检查是否在执行期间被取消 token.check() # v2: Quality Gate 检查 if self._skill: skill_context = self._build_skill_context() quality_result = await self.quality_gate.validate( output, self._skill, skill_context=skill_context ) if not quality_result.passed and quality_result.can_retry: max_retries = self._skill.config.quality_gate.max_retries retry_count = 0 while not quality_result.passed and retry_count < max_retries: feedback = self._build_quality_feedback(quality_result) output = await self.handle_task_with_feedback(task, feedback) quality_result = await self.quality_gate.validate( output, self._skill, skill_context=skill_context ) retry_count += 1 # 后置钩子 await self.on_task_complete(task, output) elapsed = time.monotonic() - start_time return TaskResult( task_id=task.task_id, agent_name=self.name, status=TaskStatus.COMPLETED, output_data=output, error_message=None, started_at=started_at, completed_at=datetime.now(timezone.utc), metrics={ "elapsed_seconds": round(elapsed, 2), "task_type": task.task_type, }, ) except TaskCancelledError: logger.warning(f"Agent '{self.name}' task {task.task_id} was cancelled") # 失败钩子 try: await self.on_task_failed(task, TaskCancelledError(task.task_id)) except asyncio.CancelledError: raise except Exception as hook_err: # 用户提供的 hook — 任意异常都可能,不阻塞 TaskResult 构建 logger.error(f"on_task_failed hook error: {hook_err}") elapsed = time.monotonic() - start_time return TaskResult( task_id=task.task_id, agent_name=self.name, status=TaskStatus.CANCELLED, output_data=None, error_message=f"Task {task.task_id} was cancelled", started_at=started_at, completed_at=datetime.now(timezone.utc), metrics={ "elapsed_seconds": round(elapsed, 2), "task_type": task.task_type, }, ) except TaskTimeoutError: logger.warning( f"Agent '{self.name}' task {task.task_id} timed out after {task.timeout_seconds}s" ) # 失败钩子 try: await self.on_task_failed( task, TaskTimeoutError(task.task_id, task.timeout_seconds) ) except asyncio.CancelledError: raise except Exception as hook_err: # 用户提供的 hook — 任意异常都可能,不阻塞 TaskResult 构建 logger.error(f"on_task_failed hook error: {hook_err}") elapsed = time.monotonic() - start_time return TaskResult( task_id=task.task_id, agent_name=self.name, status=TaskStatus.FAILED, output_data=None, error_message=f"Task {task.task_id} timed out after {task.timeout_seconds}s", started_at=started_at, completed_at=datetime.now(timezone.utc), metrics={ "elapsed_seconds": round(elapsed, 2), "task_type": task.task_type, "error_type": "TaskTimeoutError", }, ) except asyncio.CancelledError: # CancelledError 必须传播,不被 except Exception 吞掉 raise except Exception as e: # 框架边界 catch-all:handle_task 是用户实现,可能抛任意异常; # execute() 契约要求始终返回 TaskResult,故保留兜底。 logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}") # 失败钩子 try: await self.on_task_failed(task, e) except asyncio.CancelledError: raise except Exception as hook_err: logger.error(f"on_task_failed hook error: {hook_err}") elapsed = time.monotonic() - start_time return TaskResult( task_id=task.task_id, agent_name=self.name, status=TaskStatus.FAILED, output_data=None, error_message=str(e), started_at=started_at, completed_at=datetime.now(timezone.utc), metrics={ "elapsed_seconds": round(elapsed, 2), "task_type": task.task_type, "error_type": type(e).__name__, }, ) finally: self._active_tokens.pop(task.task_id, None) def cancel_task(self, task_id: str) -> bool: """取消正在执行的任务。 通过 CancellationToken 协作式取消,ReAct 循环在下次迭代时检查并停止。 返回 True 表示成功设置取消标志,False 表示任务不存在。 """ token = self._active_tokens.get(task_id) if token is not None: token.cancel() logger.info(f"Agent '{self.name}' cancellation requested for task {task_id}") return True return False # ── Handoff ─────────────────────────────────────────────── async def handoff( self, target_agent: str, task: TaskMessage, reason: str, context: dict[str, object] | None = None, ): """将当前任务转交给另一个 Agent""" if self._redis is None: raise RuntimeError("Handoff requires Redis connection") handoff_msg = HandoffMessage( source_agent=self.name, target_agent=target_agent, task_id=task.task_id, task_type=task.task_type, context=context or task.input_data, reason=reason, ) # 发布到目标 Agent 的 handoff 频道 await self._redis.publish( f"agent:{target_agent}:handoff", json.dumps(handoff_msg.to_dict()), ) logger.info( f"Agent '{self.name}' handed off task {task.task_id} to '{target_agent}': {reason}" ) # ── 进度上报 ────────────────────────────────────────────── async def report_progress(self, task_id: str, progress: float, message: str): progress_obj = TaskProgress( task_id=task_id, agent_name=self.name, progress=progress, message=message, updated_at=datetime.now(timezone.utc), ) if self._redis: try: await self._redis.publish( f"agent:{self.name}:progress", json.dumps(progress_obj.to_dict()), ) except (ConnectionError, asyncio.TimeoutError, OSError) as e: logger.warning(f"Failed to publish progress for task {task_id}: {e}") if self._dispatcher is not None: try: await self._dispatcher.handle_progress(progress_obj) except (asyncio.TimeoutError, ConnectionError, RuntimeError) as e: logger.warning( f"Failed to report progress to dispatcher for task {task_id}: {e}" ) # ── 内部方法 ────────────────────────────────────────────── async def heartbeat(self): if self._registry is not None: await self._registry.update_heartbeat(self.name) async def _heartbeat_loop(self): try: while True: async with self._status_lock: if self._status != AgentStatus.ONLINE: break await self.heartbeat() await asyncio.sleep(30) except asyncio.CancelledError: pass except (ConnectionError, asyncio.TimeoutError, OSError, RuntimeError) as e: logger.error(f"Heartbeat error for agent '{self.name}': {e}") async def _listen_for_tasks(self): try: queue_key = f"agent:{self.name}:tasks" while True: async with self._status_lock: if self._status != AgentStatus.ONLINE: break if not self._redis: await asyncio.sleep(1) continue result = await self._redis.brpop(queue_key, timeout=1) if result: _, task_json = result try: task_data = json.loads(task_json) task = TaskMessage.from_dict(task_data) asyncio.create_task(self._execute_task_with_semaphore(task)) except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e: logger.error(f"Failed to parse task message: {e}") except asyncio.CancelledError: pass except (ConnectionError, asyncio.TimeoutError, OSError, RuntimeError) as e: logger.error(f"Task listener error for agent '{self.name}': {e}") async def _execute_task_with_semaphore(self, task: TaskMessage): if self._semaphore is None: await self._execute_task(task) return async with self._semaphore: await self._execute_task(task) async def _execute_task(self, task: TaskMessage): async with self._status_lock: self._running_tasks.add(task.task_id) self._status = AgentStatus.BUSY try: logger.info( f"Agent '{self.name}' executing task {task.task_id} (type={task.task_type})" ) result = await self.execute(task) if self._redis is not None and self._dispatcher is not None: await self._dispatcher.handle_result(result) except asyncio.CancelledError: # CancelledError 必须传播,不被 except 吞掉 raise except Exception as e: # 兜底:execute() 内部已捕获大部分异常并返回 TaskResult, # 此处仅捕获 dispatcher 失败或 execute() 边界外的异常 logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}") error_result = TaskResult( task_id=task.task_id, agent_name=self.name, status=TaskStatus.FAILED, output_data=None, error_message=str(e), started_at=datetime.now(timezone.utc), completed_at=datetime.now(timezone.utc), metrics=None, ) if self._redis is not None and self._dispatcher is not None: await self._dispatcher.handle_result(error_result) finally: async with self._status_lock: self._running_tasks.discard(task.task_id) if not self._running_tasks: self._status = AgentStatus.ONLINE def _validate_input(self, data: dict, schema: dict) -> None: """校验输入数据是否符合 JSON Schema""" try: import jsonschema jsonschema.validate(data, schema) except ImportError: logger.warning("jsonschema not installed, skipping input validation") except (ValueError, TypeError, KeyError) as e: # jsonschema.ValidationError 继承 ValueError;其余为 schema/data 类型错误 raise SchemaValidationError(self.name, str(e))