fischer-agentkit/src/agentkit/core/base.py

648 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""BaseAgent 基类 - 统一 Agent 生命周期管理
核心设计:
- execute() 为 final 方法，包含完整的计时、try/except、TaskResult 构建
- 子类只需实现 handle_task(task) -> dict，返回业务数据
- 生命周期钩子：on_task_start / on_task_complete / on_task_failed
- 支持 Tool 插件、Memory 系统(可选注入)
"""
import asyncio
import contextlib
import json
import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

import redis.asyncio as aioredis

from agentkit.core.exceptions import SchemaValidationError, TaskCancelledError, TaskTimeoutError
from agentkit.core.protocol import (
    AgentCapability,
    AgentStatus,
    CancellationToken,
    HandoffMessage,
    TaskMessage,
    TaskProgress,
    TaskResult,
    TaskStatus,
)

if TYPE_CHECKING:
    from agentkit.llm.gateway import LLMGateway
    from agentkit.memory.base import Memory
    from agentkit.quality.gate import QualityGate
    from agentkit.skills.base import Skill
    from agentkit.tools.base import Tool

logger = logging.getLogger(__name__)
class BaseAgent(ABC):
    """Base class for all agents; defines the standard agent lifecycle.

    Subclasses must implement:
    - handle_task(task) -> dict: business logic; returns the task's output_data.
    - get_capabilities() -> AgentCapability: capability declaration.

    Optional overrides:
    - on_task_start(task): hook run before a task starts.
    - on_task_complete(task, output): hook run after a task succeeds.
    - on_task_failed(task, error): hook run after a task fails or is cancelled.

    Status transitions performed by this class: OFFLINE -> ONLINE in start(),
    ONLINE <-> BUSY while listener-dispatched tasks run, and -> OFFLINE in
    stop(). The heartbeat/listener loops keep running while ONLINE or BUSY.
    """

    def __init__(self, name: str, agent_type: str, version: str = "1.0.0"):
        self.name = name
        self.agent_type = agent_type
        self.version = version
        self._status: AgentStatus = AgentStatus.OFFLINE
        self._redis: aioredis.Redis | None = None
        self._redis_url: str = ""
        # IDs of tasks currently inside _execute_task(); drives ONLINE <-> BUSY.
        self._running_tasks: set[str] = set()
        # Cancellation tokens of tasks currently inside execute(), by task_id.
        self._active_tokens: dict[str, CancellationToken] = {}
        self._listen_task: asyncio.Task | None = None
        self._heartbeat_task: asyncio.Task | None = None
        # Strong references to fire-and-forget task executions: the event loop
        # only keeps weak references, so unreferenced tasks may be collected.
        self._background_tasks: set[asyncio.Task] = set()
        self._semaphore: asyncio.Semaphore | None = None
        self._status_lock: asyncio.Lock = asyncio.Lock()
        self._lock_timeout: float = 30.0  # Lock acquisition timeout (seconds)
        self._heartbeat_interval: float = 30.0  # Seconds between heartbeats
        self._config_version: int = 0  # Configuration version counter
        # Pluggable capabilities (injected by subclasses or configuration)
        self._tools: list["Tool"] = []
        self._memory: "Memory | None" = None
        self._memory_retriever: Any | None = None
        # External dependencies (injected via set_registry() / set_dispatcher())
        self._registry = None
        self._dispatcher = None
        # v2 pluggable capabilities
        self._llm_gateway: "LLMGateway | None" = None
        self._skill: "Skill | None" = None
        self._quality_gate: "QualityGate | None" = None

    @property
    def status(self) -> AgentStatus:
        return self._status

    @property
    def config_version(self) -> int:
        return self._config_version

    @property
    def is_distributed(self) -> bool:
        """True when the agent holds a live Redis connection."""
        return self._redis is not None

    async def _acquire_status_lock(self) -> None:
        """Acquire the status lock, giving up after ``self._lock_timeout`` seconds.

        Not called by this class itself; the loops and task bookkeeping here use
        ``async with self._status_lock`` directly.

        Raises:
            RuntimeError: if the lock is not acquired in time (possible deadlock).
        """
        try:
            await asyncio.wait_for(self._status_lock.acquire(), timeout=self._lock_timeout)
        except asyncio.TimeoutError:
            logger.error(
                f"Agent '{self.name}' status lock acquisition timed out "
                f"after {self._lock_timeout}s — possible deadlock"
            )
            raise RuntimeError("Status lock acquisition timed out") from None

    def _release_status_lock(self) -> None:
        """Release the status lock; a no-op if the lock is not held."""
        with contextlib.suppress(RuntimeError):
            self._status_lock.release()

    @property
    def tools(self) -> list["Tool"]:
        return self._tools

    @property
    def memory(self) -> "Memory | None":
        return self._memory

    @property
    def llm_gateway(self) -> "LLMGateway | None":
        return self._llm_gateway

    @llm_gateway.setter
    def llm_gateway(self, gateway: "LLMGateway") -> None:
        self._llm_gateway = gateway

    @property
    def skill(self) -> "Skill | None":
        return self._skill

    @skill.setter
    def skill(self, skill: "Skill") -> None:
        self._skill = skill

    @property
    def quality_gate(self) -> "QualityGate":
        """Return the QualityGate, creating a default one on first access."""
        if self._quality_gate is None:
            from agentkit.quality.gate import QualityGate

            self._quality_gate = QualityGate()
        return self._quality_gate

    # ── Abstract methods (subclasses must implement) ──────────────────────

    @abstractmethod
    async def handle_task(self, task: TaskMessage) -> dict:
        """Core business logic of a task; subclasses must implement it.

        Return the output_data dict; the framework wraps it in a TaskResult.
        """
        ...

    @abstractmethod
    def get_capabilities(self) -> AgentCapability:
        """Return this agent's capability declaration."""
        ...

    # ── Lifecycle hooks (optional overrides) ──────────────────────────────

    async def on_task_start(self, task: TaskMessage) -> None:
        """Hook run before a task starts (e.g. load memory, prepare context)."""
        pass

    async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
        """Hook run after a task succeeds (e.g. store memory, trigger reflection)."""
        pass

    async def on_task_failed(self, task: TaskMessage, error: Exception) -> None:
        """Hook run after a task fails, times out or is cancelled."""
        pass

    # ── v2 methods ────────────────────────────────────────────────────────

    async def handle_task_with_feedback(self, task: TaskMessage, feedback: str) -> dict:
        """Re-execute a task with quality-gate feedback (used for retries).

        The default implementation ignores ``feedback`` and calls handle_task;
        subclasses may override it to make use of the feedback.
        """
        return await self.handle_task(task)

    def _build_quality_feedback(self, quality_result) -> str:
        """Build a feedback string listing the failed checks of a QualityResult."""
        failed_checks = [c for c in quality_result.checks if not c.passed]
        lines = ["Quality check failed. Issues:"]
        for check in failed_checks:
            msg = check.message or f"Check '{check.name}' failed"
            lines.append(f" - {msg}")
        return "\n".join(lines)

    def _build_skill_context(self) -> dict[str, Any] | None:
        """Build the skill_context for the QualityGate skill_match check.

        Returns ``{"intent_keywords": [...]}`` from the current skill's intent
        keywords plus disambiguation keywords, or None when there is no skill,
        no intent, or no keywords.
        """
        if not self._skill:
            return None
        intent = getattr(self._skill.config, "intent", None)
        if intent is None:
            return None
        keywords = list(intent.keywords) + list(intent.disambiguation_keywords)
        if not keywords:
            return None
        return {"intent_keywords": keywords}

    # ── Pluggable capability injection (fluent setters) ───────────────────

    def use_tool(self, tool: "Tool") -> "BaseAgent":
        """Add a tool to the agent; returns self for chaining."""
        self._tools.append(tool)
        return self

    def use_memory(self, memory: "Memory") -> "BaseAgent":
        """Set the memory system; returns self for chaining."""
        self._memory = memory
        return self

    def use_memory_retriever(self, retriever: Any) -> "BaseAgent":
        """Set the memory retriever used for context injection; returns self."""
        self._memory_retriever = retriever
        return self

    def set_registry(self, registry: Any) -> "BaseAgent":
        """Inject the agent registry; returns self for chaining."""
        self._registry = registry
        return self

    def set_dispatcher(self, dispatcher: Any) -> "BaseAgent":
        """Inject the task dispatcher; returns self for chaining."""
        self._dispatcher = dispatcher
        return self

    # ── Core lifecycle ────────────────────────────────────────────────────

    async def start(self, redis_url: str = ""):
        """Start the agent: connect to Redis -> register -> heartbeat -> listen.

        With an empty or unreachable ``redis_url`` the agent runs in local mode,
        and the heartbeat and listener loops are not started.
        """
        self._redis_url = redis_url
        logger.info(
            f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})"
        )
        if redis_url:
            self._redis = await self._connect_redis(redis_url)

        capability = self.get_capabilities()
        if self._registry is not None:
            await self._registry.register(capability, endpoint=f"agent:{self.name}")

        async with self._status_lock:
            self._status = AgentStatus.ONLINE

        # Bound the number of concurrently executing listener-dispatched tasks.
        max_concurrency = getattr(capability, "max_concurrency", 1) or 1
        self._semaphore = asyncio.Semaphore(max_concurrency)

        if self._redis is not None:
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            self._listen_task = asyncio.create_task(self._listen_for_tasks())
        logger.info(
            f"Agent '{self.name}' started ({'distributed' if self._redis else 'local'} mode)"
        )

    async def _connect_redis(self, redis_url: str) -> "aioredis.Redis | None":
        """Create a Redis client for ``redis_url`` and verify it with PING.

        Returns the connected client. Returns None when Redis is unreachable or
        the URL is invalid, after logging a warning and closing the client, so
        the caller can fall back to local mode.
        """
        client = None
        try:
            client = aioredis.from_url(redis_url, decode_responses=True)
            await client.ping()
        except (
            ConnectionError,
            OSError,
            asyncio.TimeoutError,
            ValueError,
            # redis-py raises its own ConnectionError/TimeoutError, which derive
            # from RedisError rather than from the builtin exception classes.
            aioredis.RedisError,
        ) as e:
            logger.warning(
                f"Agent '{self.name}' Redis unavailable: {e}, falling back to local mode"
            )
            if client is not None:
                try:
                    await client.close()
                except Exception as close_err:  # best-effort cleanup only
                    logger.debug(
                        f"Agent '{self.name}' failed to close Redis client: {close_err}"
                    )
            return None
        logger.info(f"Agent '{self.name}' connected to Redis")
        return client

    async def stop(self):
        """Stop the agent.

        Marks the agent OFFLINE, cancels the heartbeat and listener loops,
        unregisters from the registry, and closes the Redis connection.

        Task executions already scheduled by the listener are not cancelled.
        _execute_task only reports their results while Redis is connected.
        """
        logger.info(f"Stopping agent '{self.name}'")
        async with self._status_lock:
            self._status = AgentStatus.OFFLINE

        for loop_task in (self._listen_task, self._heartbeat_task):
            if loop_task is not None and not loop_task.done():
                loop_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await loop_task
        self._listen_task = None
        self._heartbeat_task = None

        try:
            # start() registers regardless of Redis, so unregister the same way;
            # previously local-mode agents stayed registered after stop().
            if self._registry is not None:
                await self._registry.unregister(self.name)
        finally:
            client, self._redis = self._redis, None
            if client is not None:
                await client.close()
        logger.info(f"Agent '{self.name}' stopped")

    # ── execute is the framework (final) method ───────────────────────────

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Execute a task (framework method; subclasses should not override it).

        Flow: on_task_start -> input schema validation -> handle_task (bounded
        by ``task.timeout_seconds`` when positive) -> cancellation check ->
        quality gate (when a skill is set) -> on_task_complete. On failure,
        on_task_failed runs instead.

        Always returns a TaskResult with status COMPLETED, CANCELLED, or FAILED.
        Only ``asyncio.CancelledError`` propagates to the caller.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()
        # Register a token so cancel_task() can request cooperative cancellation.
        token = CancellationToken()
        self._active_tokens[task.task_id] = token
        try:
            await self.on_task_start(task)

            # Validate input against the declared input_schema, if any.
            capability = self.get_capabilities()
            if capability.input_schema:
                self._validate_input(task.input_data, capability.input_schema)

            output = await self._run_handle_task(task)
            # Cancellation is cooperative: honour a request made while handle_task ran.
            token.check()

            if self._skill:
                output = await self._apply_quality_gate(task, output, token)

            await self.on_task_complete(task, output)
            return self._build_result(
                task,
                status=TaskStatus.COMPLETED,
                started_at=started_at,
                start_time=start_time,
                output=output,
            )
        except TaskCancelledError as exc:
            logger.warning(f"Agent '{self.name}' task {task.task_id} was cancelled")
            await self._safe_on_task_failed(task, exc)
            return self._build_result(
                task,
                status=TaskStatus.CANCELLED,
                started_at=started_at,
                start_time=start_time,
                error_message=f"Task {task.task_id} was cancelled",
            )
        except TaskTimeoutError as exc:
            logger.warning(
                f"Agent '{self.name}' task {task.task_id} timed out after {task.timeout_seconds}s"
            )
            await self._safe_on_task_failed(task, exc)
            return self._build_result(
                task,
                status=TaskStatus.FAILED,
                started_at=started_at,
                start_time=start_time,
                error_message=f"Task {task.task_id} timed out after {task.timeout_seconds}s",
                error_type="TaskTimeoutError",
            )
        except asyncio.CancelledError:
            # CancelledError must propagate; it must not be swallowed below.
            raise
        except Exception as e:
            # Framework-boundary catch-all. handle_task and the hooks are user
            # code and may raise anything, and execute() must always return a
            # TaskResult.
            logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}", exc_info=True)
            await self._safe_on_task_failed(task, e)
            return self._build_result(
                task,
                status=TaskStatus.FAILED,
                started_at=started_at,
                start_time=start_time,
                error_message=str(e),
                error_type=type(e).__name__,
            )
        finally:
            # Only drop our own token, in case the same task_id was re-submitted.
            if self._active_tokens.get(task.task_id) is token:
                del self._active_tokens[task.task_id]

    async def _run_handle_task(self, task: TaskMessage) -> dict:
        """Run handle_task, bounded by ``task.timeout_seconds`` when it is positive.

        Raises:
            TaskTimeoutError: if handle_task does not finish within the timeout.
        """
        timeout_seconds = task.timeout_seconds
        # A missing (None/0) or negative timeout means "no limit".
        if not timeout_seconds or timeout_seconds <= 0:
            return await self.handle_task(task)
        try:
            return await asyncio.wait_for(self.handle_task(task), timeout=timeout_seconds)
        except asyncio.TimeoutError:
            raise TaskTimeoutError(
                task_id=task.task_id,
                timeout_seconds=timeout_seconds,
            )

    async def _apply_quality_gate(
        self, task: TaskMessage, output: dict, token: CancellationToken
    ) -> dict:
        """Validate ``output`` with the quality gate, retrying with feedback if allowed.

        Retries up to ``skill.config.quality_gate.max_retries`` times, but only
        if the first validation reports ``can_retry``. The latest output is
        returned even if the gate still fails, so the task is reported
        COMPLETED either way.

        NOTE(review): retries are not bounded by task.timeout_seconds. Confirm
        this is intended.

        Raises:
            TaskCancelledError: if cancellation is requested between retries
                (raised by ``token.check()``).
        """
        skill = self._skill
        skill_context = self._build_skill_context()
        quality_result = await self.quality_gate.validate(
            output, skill, skill_context=skill_context
        )
        if quality_result.passed or not quality_result.can_retry:
            return output

        max_retries = skill.config.quality_gate.max_retries
        retry_count = 0
        while not quality_result.passed and retry_count < max_retries:
            # Do not start another (potentially expensive) retry once cancelled.
            token.check()
            feedback = self._build_quality_feedback(quality_result)
            output = await self.handle_task_with_feedback(task, feedback)
            quality_result = await self.quality_gate.validate(
                output, skill, skill_context=skill_context
            )
            retry_count += 1
        return output

    async def _safe_on_task_failed(self, task: TaskMessage, error: Exception) -> None:
        """Run the on_task_failed hook, logging (not raising) any hook error.

        The hook is user code, so any exception is possible. It must not
        prevent the TaskResult from being built. CancelledError still propagates.
        """
        try:
            await self.on_task_failed(task, error)
        except asyncio.CancelledError:
            raise
        except Exception as hook_err:
            logger.error(f"on_task_failed hook error: {hook_err}")

    def _build_result(
        self,
        task: TaskMessage,
        *,
        status: TaskStatus,
        started_at: datetime,
        start_time: float,
        output: dict | None = None,
        error_message: str | None = None,
        error_type: str | None = None,
    ) -> TaskResult:
        """Build the TaskResult returned by execute().

        ``start_time`` is the ``time.monotonic()`` reading taken when execution
        began. ``error_type`` is added to the metrics only when given.
        """
        metrics: dict[str, Any] = {
            "elapsed_seconds": round(time.monotonic() - start_time, 2),
            "task_type": task.task_type,
        }
        if error_type is not None:
            metrics["error_type"] = error_type
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics=metrics,
        )

    def cancel_task(self, task_id: str) -> bool:
        """Request cancellation of a task currently inside execute().

        Cancellation is cooperative. It only sets the task's CancellationToken.
        execute() checks the token after handle_task returns and before each
        quality-gate retry. Long-running handle_task implementations can
        presumably poll the token themselves (confirm with subclasses).

        Returns True if the cancellation flag was set, False if the task is not
        running.
        """
        token = self._active_tokens.get(task_id)
        if token is None:
            return False
        token.cancel()
        logger.info(f"Agent '{self.name}' cancellation requested for task {task_id}")
        return True

    # ── Handoff ───────────────────────────────────────────────────────────

    async def handoff(
        self,
        target_agent: str,
        task: TaskMessage,
        reason: str,
        context: dict[str, Any] | None = None,
    ):
        """Hand the current task off to another agent via its Redis handoff channel.

        ``context`` defaults to the task's input_data when not given. An
        explicitly passed empty dict is respected.

        Raises:
            RuntimeError: if the agent is not connected to Redis.
        """
        if self._redis is None:
            raise RuntimeError("Handoff requires Redis connection")
        handoff_msg = HandoffMessage(
            source_agent=self.name,
            target_agent=target_agent,
            task_id=task.task_id,
            task_type=task.task_type,
            context=context if context is not None else task.input_data,
            reason=reason,
        )
        # Publish to the target agent's handoff channel.
        await self._redis.publish(
            f"agent:{target_agent}:handoff",
            json.dumps(handoff_msg.to_dict()),
        )
        logger.info(
            f"Agent '{self.name}' handed off task {task.task_id} to '{target_agent}': {reason}"
        )

    # ── Progress reporting ────────────────────────────────────────────────

    async def report_progress(self, task_id: str, progress: float, message: str):
        """Report task progress via Redis pub/sub and/or the dispatcher.

        Delivery is best-effort. Transport failures are logged as warnings and
        never raised.
        """
        progress_obj = TaskProgress(
            task_id=task_id,
            agent_name=self.name,
            progress=progress,
            message=message,
            updated_at=datetime.now(timezone.utc),
        )
        if self._redis:
            try:
                await self._redis.publish(
                    f"agent:{self.name}:progress",
                    json.dumps(progress_obj.to_dict()),
                )
            except (ConnectionError, asyncio.TimeoutError, OSError, aioredis.RedisError) as e:
                logger.warning(f"Failed to publish progress for task {task_id}: {e}")
        if self._dispatcher is not None:
            try:
                await self._dispatcher.handle_progress(progress_obj)
            except (asyncio.TimeoutError, ConnectionError, RuntimeError) as e:
                logger.warning(
                    f"Failed to report progress to dispatcher for task {task_id}: {e}"
                )

    # ── Internal methods ──────────────────────────────────────────────────

    async def heartbeat(self):
        """Send one heartbeat to the registry, if one is configured."""
        if self._registry is not None:
            await self._registry.update_heartbeat(self.name)

    async def _is_active(self) -> bool:
        """Return True while the agent is ONLINE or BUSY, i.e. not stopped.

        The background loops must keep running while tasks execute. Checking
        only for ONLINE made them exit as soon as the agent became BUSY.
        """
        async with self._status_lock:
            return self._status in (AgentStatus.ONLINE, AgentStatus.BUSY)

    async def _heartbeat_loop(self):
        """Send a heartbeat every ``self._heartbeat_interval`` seconds until stopped.

        Heartbeat errors are logged and the loop continues, so a single
        transient failure no longer stops heartbeats for good.
        """
        while await self._is_active():
            try:
                await self.heartbeat()
            except (
                ConnectionError,
                asyncio.TimeoutError,
                OSError,
                RuntimeError,
                aioredis.RedisError,
            ) as e:
                logger.error(f"Heartbeat error for agent '{self.name}': {e}")
            await asyncio.sleep(self._heartbeat_interval)

    async def _listen_for_tasks(self):
        """Pop tasks from this agent's Redis queue and schedule them until stopped.

        Redis/network errors are logged and retried after a one-second pause
        instead of killing the listener. Malformed messages are logged and
        skipped. A RuntimeError ends the listener after being logged.
        """
        queue_key = f"agent:{self.name}:tasks"
        try:
            while await self._is_active():
                client = self._redis
                if client is None:
                    await asyncio.sleep(1)
                    continue
                try:
                    result = await client.brpop(queue_key, timeout=1)
                except (
                    ConnectionError,
                    asyncio.TimeoutError,
                    OSError,
                    aioredis.RedisError,
                ) as e:
                    logger.error(f"Task listener error for agent '{self.name}': {e}")
                    await asyncio.sleep(1)
                    continue
                if not result:
                    continue  # brpop timed out with an empty queue
                _, task_json = result
                try:
                    task = TaskMessage.from_dict(json.loads(task_json))
                except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
                    logger.error(f"Failed to parse task message: {e}")
                    continue
                self._spawn_background(self._execute_task_with_semaphore(task))
        except RuntimeError as e:
            logger.error(f"Task listener error for agent '{self.name}': {e}")

    def _spawn_background(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a task and keep a strong reference until it finishes."""
        bg_task = asyncio.create_task(coro)
        self._background_tasks.add(bg_task)
        bg_task.add_done_callback(self._background_tasks.discard)
        return bg_task

    async def _execute_task_with_semaphore(self, task: TaskMessage):
        """Run ``task`` via _execute_task, bounded by the concurrency semaphore if set."""
        if self._semaphore is None:
            await self._execute_task(task)
            return
        async with self._semaphore:
            await self._execute_task(task)

    async def _execute_task(self, task: TaskMessage):
        """Execute ``task`` and report its result to the dispatcher (when connected).

        Marks the agent BUSY while any task runs and ONLINE again when the last
        one finishes. These transitions only apply between ONLINE and BUSY, so
        an OFFLINE status set by stop() is never overridden.
        """
        async with self._status_lock:
            self._running_tasks.add(task.task_id)
            if self._status == AgentStatus.ONLINE:
                self._status = AgentStatus.BUSY
        started_at = datetime.now(timezone.utc)
        try:
            logger.info(
                f"Agent '{self.name}' executing task {task.task_id} (type={task.task_type})"
            )
            result = await self.execute(task)
            if self._redis is not None and self._dispatcher is not None:
                await self._dispatcher.handle_result(result)
        except asyncio.CancelledError:
            # CancelledError must propagate; it must not be swallowed below.
            raise
        except Exception as e:
            # Fallback: execute() already turns handler errors into a TaskResult,
            # so this mainly catches dispatcher failures or errors outside execute().
            logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}")
            error_result = TaskResult(
                task_id=task.task_id,
                agent_name=self.name,
                status=TaskStatus.FAILED,
                output_data=None,
                error_message=str(e),
                started_at=started_at,
                completed_at=datetime.now(timezone.utc),
                metrics=None,
            )
            if self._redis is not None and self._dispatcher is not None:
                # The dispatcher may be what just failed; do not let a second
                # failure escape this fire-and-forget task.
                try:
                    await self._dispatcher.handle_result(error_result)
                except asyncio.CancelledError:
                    raise
                except Exception as report_err:
                    logger.error(
                        f"Agent '{self.name}' failed to report error result "
                        f"for task {task.task_id}: {report_err}"
                    )
        finally:
            async with self._status_lock:
                self._running_tasks.discard(task.task_id)
                if not self._running_tasks and self._status == AgentStatus.BUSY:
                    self._status = AgentStatus.ONLINE

    def _validate_input(self, data: dict, schema: dict) -> None:
        """Validate ``data`` against the JSON Schema ``schema``.

        Validation is skipped, with a warning, when ``jsonschema`` is not
        installed.

        Raises:
            SchemaValidationError: if ``data`` does not match ``schema`` or the
                schema itself is invalid.
        """
        try:
            import jsonschema
        except ImportError:
            logger.warning("jsonschema not installed, skipping input validation")
            return
        try:
            jsonschema.validate(data, schema)
        except (
            # jsonschema's ValidationError/SchemaError derive from Exception, not
            # ValueError, so they must be named explicitly. The builtin types
            # cover schema/data type errors.
            jsonschema.ValidationError,
            jsonschema.SchemaError,
            ValueError,
            TypeError,
            KeyError,
        ) as e:
            raise SchemaValidationError(self.name, str(e)) from e