feat: initial fischer-agentkit package with unified agent architecture
- BaseAgent with handle_task() pattern (execute template moved up) - Protocol: TaskMessage, TaskResult, HandoffMessage, EvolutionEvent - Tool system: FunctionTool, AgentTool, ToolRegistry with versioning - Memory system: WorkingMemory (Redis), EpisodicMemory (pgvector), SemanticMemory (RAG adapter), MemoryRetriever (hybrid) - Evolution engine: Reflector, PromptOptimizer (DSPy-style), StrategyTuner, ABTester, EvolutionStore - Orchestrator: PipelineEngine (parallel DAG), PipelineLoader (YAML), HandoffManager, DynamicPipeline - MCP: Server (FastAPI), Client (httpx), MCPTool - Prompts: PromptTemplate, PromptSection - Exceptions: full hierarchy including Tool, Schema, Handoff, Evolution errors - Tests: unit tests for core, tools, protocol, evolution, pipeline
This commit is contained in:
commit
9a6d6fee4e
|
|
@ -0,0 +1,48 @@
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "fischer-agentkit"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Unified Agent Framework with Tool/Skill plugins, Memory, Self-Evolution, MCP support, and Multi-Agent orchestration"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
license = {text = "MIT"}
|
||||||
|
authors = [
|
||||||
|
{name = "Fischer Team"},
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"pydantic>=2.0",
|
||||||
|
"redis[hiredis]>=5.0",
|
||||||
|
"sqlalchemy[asyncio]>=2.0",
|
||||||
|
"asyncpg>=0.29",
|
||||||
|
"httpx>=0.27",
|
||||||
|
"pyyaml>=6.0",
|
||||||
|
"jsonschema>=4.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
mcp = [
|
||||||
|
"mcp>=1.0",
|
||||||
|
]
|
||||||
|
evolution = [
|
||||||
|
"scipy>=1.12",
|
||||||
|
]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0",
|
||||||
|
"pytest-asyncio>=0.23",
|
||||||
|
"pytest-cov>=5.0",
|
||||||
|
"ruff>=0.4",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/agentkit"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py311"
|
||||||
|
line-length = 100
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
"""Fischer AgentKit - Unified Agent Framework"""
|
||||||
|
|
||||||
|
from agentkit.core.base import BaseAgent
|
||||||
|
from agentkit.core.protocol import (
|
||||||
|
AgentCapability,
|
||||||
|
AgentStatus,
|
||||||
|
HandoffMessage,
|
||||||
|
TaskMessage,
|
||||||
|
TaskProgress,
|
||||||
|
TaskResult,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseAgent",
|
||||||
|
"AgentCapability",
|
||||||
|
"AgentStatus",
|
||||||
|
"HandoffMessage",
|
||||||
|
"TaskMessage",
|
||||||
|
"TaskProgress",
|
||||||
|
"TaskResult",
|
||||||
|
"TaskStatus",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
"""AgentKit Core - 基础组件"""
|
||||||
|
|
||||||
|
from agentkit.core.base import BaseAgent
|
||||||
|
from agentkit.core.exceptions import (
|
||||||
|
AgentAlreadyRegisteredError,
|
||||||
|
AgentFrameworkError,
|
||||||
|
AgentNotFoundError,
|
||||||
|
AgentNotReadyError,
|
||||||
|
AgentUnavailableError,
|
||||||
|
ConfigValidationError,
|
||||||
|
EvolutionError,
|
||||||
|
HandoffError,
|
||||||
|
NoAvailableAgentError,
|
||||||
|
SchemaValidationError,
|
||||||
|
TaskCancelledError,
|
||||||
|
TaskDispatchError,
|
||||||
|
TaskExecutionError,
|
||||||
|
TaskNotFoundError,
|
||||||
|
TaskTimeoutError,
|
||||||
|
ToolExecutionError,
|
||||||
|
ToolNotFoundError,
|
||||||
|
)
|
||||||
|
from agentkit.core.protocol import (
|
||||||
|
AgentCapability,
|
||||||
|
AgentStatus,
|
||||||
|
EvolutionEvent,
|
||||||
|
HandoffMessage,
|
||||||
|
TaskMessage,
|
||||||
|
TaskProgress,
|
||||||
|
TaskResult,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseAgent",
|
||||||
|
"AgentCapability",
|
||||||
|
"AgentStatus",
|
||||||
|
"AgentFrameworkError",
|
||||||
|
"AgentNotFoundError",
|
||||||
|
"AgentAlreadyRegisteredError",
|
||||||
|
"AgentUnavailableError",
|
||||||
|
"AgentNotReadyError",
|
||||||
|
"TaskNotFoundError",
|
||||||
|
"TaskDispatchError",
|
||||||
|
"TaskExecutionError",
|
||||||
|
"TaskTimeoutError",
|
||||||
|
"TaskCancelledError",
|
||||||
|
"NoAvailableAgentError",
|
||||||
|
"ConfigValidationError",
|
||||||
|
"SchemaValidationError",
|
||||||
|
"HandoffError",
|
||||||
|
"EvolutionError",
|
||||||
|
"ToolNotFoundError",
|
||||||
|
"ToolExecutionError",
|
||||||
|
"HandoffMessage",
|
||||||
|
"EvolutionEvent",
|
||||||
|
"TaskMessage",
|
||||||
|
"TaskProgress",
|
||||||
|
"TaskResult",
|
||||||
|
"TaskStatus",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,395 @@
|
||||||
|
"""BaseAgent 基类 - 统一 Agent 生命周期管理
|
||||||
|
|
||||||
|
核心设计:
|
||||||
|
- execute() 为 final 方法,包含完整的计时、try/except、TaskResult 构建
|
||||||
|
- 子类只需实现 handle_task(task) -> dict 返回业务数据
|
||||||
|
- 生命周期钩子:on_task_start / on_task_complete / on_task_failed
|
||||||
|
- 支持 Tool 插件、Memory 系统(可选注入)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
from agentkit.core.exceptions import AgentNotReadyError, SchemaValidationError
|
||||||
|
from agentkit.core.protocol import (
|
||||||
|
AgentCapability,
|
||||||
|
AgentStatus,
|
||||||
|
HandoffMessage,
|
||||||
|
TaskMessage,
|
||||||
|
TaskProgress,
|
||||||
|
TaskResult,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.memory.base import Memory
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(ABC):
|
||||||
|
"""所有 Agent 的基类,定义标准生命周期。
|
||||||
|
|
||||||
|
子类只需实现:
|
||||||
|
- handle_task(task) -> dict: 业务逻辑,返回 output_data
|
||||||
|
- get_capabilities() -> AgentCapability: 能力声明
|
||||||
|
|
||||||
|
可选覆写:
|
||||||
|
- on_task_start(task): 任务开始前的钩子
|
||||||
|
- on_task_complete(task, output): 任务成功后的钩子
|
||||||
|
- on_task_failed(task, error): 任务失败后的钩子
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, agent_type: str, version: str = "1.0.0"):
|
||||||
|
self.name = name
|
||||||
|
self.agent_type = agent_type
|
||||||
|
self.version = version
|
||||||
|
self._status: AgentStatus = AgentStatus.OFFLINE
|
||||||
|
self._redis: aioredis.Redis | None = None
|
||||||
|
self._redis_url: str = ""
|
||||||
|
self._running_tasks: set[str] = set()
|
||||||
|
self._listen_task: asyncio.Task | None = None
|
||||||
|
self._heartbeat_task: asyncio.Task | None = None
|
||||||
|
self._semaphore: asyncio.Semaphore | None = None
|
||||||
|
|
||||||
|
# 可插拔能力(由子类或配置注入)
|
||||||
|
self._tools: list["Tool"] = []
|
||||||
|
self._memory: "Memory | None" = None
|
||||||
|
|
||||||
|
# 外部依赖注入(由 start() 时设置)
|
||||||
|
self._registry = None
|
||||||
|
self._dispatcher = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def status(self) -> AgentStatus:
|
||||||
|
return self._status
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_distributed(self) -> bool:
|
||||||
|
return self._redis is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tools(self) -> list["Tool"]:
|
||||||
|
return self._tools
|
||||||
|
|
||||||
|
@property
|
||||||
|
def memory(self) -> "Memory | None":
|
||||||
|
return self._memory
|
||||||
|
|
||||||
|
# ── 抽象方法(子类必须实现) ──────────────────────────────
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def handle_task(self, task: TaskMessage) -> dict:
|
||||||
|
"""执行任务的核心业务逻辑,子类必须实现。
|
||||||
|
|
||||||
|
返回 output_data dict,框架自动包装为 TaskResult。
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_capabilities(self) -> AgentCapability:
|
||||||
|
"""返回 Agent 能力声明"""
|
||||||
|
...
|
||||||
|
|
||||||
|
# ── 生命周期钩子(可选覆写) ──────────────────────────────
|
||||||
|
|
||||||
|
async def on_task_start(self, task: TaskMessage) -> None:
|
||||||
|
"""任务开始前的钩子,可用于加载记忆、准备上下文等"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
|
||||||
|
"""任务成功后的钩子,可用于存储记忆、触发反思等"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def on_task_failed(self, task: TaskMessage, error: Exception) -> None:
|
||||||
|
"""任务失败后的钩子,可用于记录失败模式等"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ── 可插拔能力注入 ──────────────────────────────────────
|
||||||
|
|
||||||
|
def use_tool(self, tool: "Tool") -> "BaseAgent":
|
||||||
|
"""添加工具到 Agent"""
|
||||||
|
self._tools.append(tool)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def use_memory(self, memory: "Memory") -> "BaseAgent":
|
||||||
|
"""设置记忆系统"""
|
||||||
|
self._memory = memory
|
||||||
|
return self
|
||||||
|
|
||||||
|
def set_registry(self, registry: Any) -> "BaseAgent":
|
||||||
|
"""注入注册中心"""
|
||||||
|
self._registry = registry
|
||||||
|
return self
|
||||||
|
|
||||||
|
def set_dispatcher(self, dispatcher: Any) -> "BaseAgent":
|
||||||
|
"""注入任务分发器"""
|
||||||
|
self._dispatcher = dispatcher
|
||||||
|
return self
|
||||||
|
|
||||||
|
# ── 核心生命周期 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
async def start(self, redis_url: str = ""):
|
||||||
|
"""启动 Agent:连接 Redis → 注册 → 心跳 → 监听"""
|
||||||
|
self._redis_url = redis_url
|
||||||
|
|
||||||
|
logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})")
|
||||||
|
|
||||||
|
if redis_url:
|
||||||
|
try:
|
||||||
|
self._redis = aioredis.from_url(redis_url, decode_responses=True)
|
||||||
|
await self._redis.ping()
|
||||||
|
logger.info(f"Agent '{self.name}' connected to Redis")
|
||||||
|
except Exception as e:
|
||||||
|
self._redis = None
|
||||||
|
logger.warning(f"Agent '{self.name}' Redis unavailable: {e}, falling back to local mode")
|
||||||
|
|
||||||
|
# 注册到 Registry
|
||||||
|
if self._registry is not None:
|
||||||
|
capability = self.get_capabilities()
|
||||||
|
await self._registry.register(capability, endpoint=f"agent:{self.name}")
|
||||||
|
|
||||||
|
self._status = AgentStatus.ONLINE
|
||||||
|
|
||||||
|
# 设置并发控制
|
||||||
|
capability = self.get_capabilities()
|
||||||
|
max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
|
||||||
|
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||||
|
|
||||||
|
# 启动心跳和监听
|
||||||
|
if self._redis is not None:
|
||||||
|
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||||
|
self._listen_task = asyncio.create_task(self._listen_for_tasks())
|
||||||
|
|
||||||
|
logger.info(f"Agent '{self.name}' started ({'distributed' if self._redis else 'local'} mode)")
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止 Agent"""
|
||||||
|
logger.info(f"Stopping agent '{self.name}'")
|
||||||
|
self._status = AgentStatus.OFFLINE
|
||||||
|
|
||||||
|
for task in [self._listen_task, self._heartbeat_task]:
|
||||||
|
if task and not task.done():
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self._redis is not None:
|
||||||
|
if self._registry is not None:
|
||||||
|
await self._registry.unregister(self.name)
|
||||||
|
await self._redis.close()
|
||||||
|
self._redis = None
|
||||||
|
|
||||||
|
logger.info(f"Agent '{self.name}' stopped")
|
||||||
|
|
||||||
|
# ── execute 为 final 方法 ─────────────────────────────────
|
||||||
|
|
||||||
|
async def execute(self, task: TaskMessage) -> TaskResult:
|
||||||
|
"""执行任务(框架方法,不可覆写)。
|
||||||
|
|
||||||
|
完整流程:on_task_start → handle_task → on_task_complete/on_task_failed
|
||||||
|
自动处理计时、TaskResult 构建、错误捕获。
|
||||||
|
"""
|
||||||
|
started_at = datetime.now(timezone.utc)
|
||||||
|
start_time = time.monotonic()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 前置钩子
|
||||||
|
await self.on_task_start(task)
|
||||||
|
|
||||||
|
# Schema 校验(如果 Agent 声明了 input_schema)
|
||||||
|
capability = self.get_capabilities()
|
||||||
|
if capability.input_schema:
|
||||||
|
self._validate_input(task.input_data, capability.input_schema)
|
||||||
|
|
||||||
|
# 执行业务逻辑
|
||||||
|
output = await self.handle_task(task)
|
||||||
|
|
||||||
|
# 后置钩子
|
||||||
|
await self.on_task_complete(task, output)
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start_time
|
||||||
|
return TaskResult(
|
||||||
|
task_id=task.task_id,
|
||||||
|
agent_name=self.name,
|
||||||
|
status=TaskStatus.COMPLETED,
|
||||||
|
output_data=output,
|
||||||
|
error_message=None,
|
||||||
|
started_at=started_at,
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
metrics={
|
||||||
|
"elapsed_seconds": round(elapsed, 2),
|
||||||
|
"task_type": task.task_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}")
|
||||||
|
|
||||||
|
# 失败钩子
|
||||||
|
try:
|
||||||
|
await self.on_task_failed(task, e)
|
||||||
|
except Exception as hook_err:
|
||||||
|
logger.error(f"on_task_failed hook error: {hook_err}")
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start_time
|
||||||
|
return TaskResult(
|
||||||
|
task_id=task.task_id,
|
||||||
|
agent_name=self.name,
|
||||||
|
status=TaskStatus.FAILED,
|
||||||
|
output_data=None,
|
||||||
|
error_message=str(e),
|
||||||
|
started_at=started_at,
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
metrics={
|
||||||
|
"elapsed_seconds": round(elapsed, 2),
|
||||||
|
"task_type": task.task_type,
|
||||||
|
"error_type": type(e).__name__,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Handoff ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def handoff(self, target_agent: str, task: TaskMessage, reason: str, context: dict[str, Any] | None = None):
|
||||||
|
"""将当前任务转交给另一个 Agent"""
|
||||||
|
if self._redis is None:
|
||||||
|
raise RuntimeError("Handoff requires Redis connection")
|
||||||
|
|
||||||
|
handoff_msg = HandoffMessage(
|
||||||
|
source_agent=self.name,
|
||||||
|
target_agent=target_agent,
|
||||||
|
task_id=task.task_id,
|
||||||
|
task_type=task.task_type,
|
||||||
|
context=context or task.input_data,
|
||||||
|
reason=reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 发布到目标 Agent 的 handoff 频道
|
||||||
|
await self._redis.publish(
|
||||||
|
f"agent:{target_agent}:handoff",
|
||||||
|
json.dumps(handoff_msg.to_dict()),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Agent '{self.name}' handed off task {task.task_id} to '{target_agent}': {reason}")
|
||||||
|
|
||||||
|
# ── 进度上报 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def report_progress(self, task_id: str, progress: float, message: str):
|
||||||
|
progress_obj = TaskProgress(
|
||||||
|
task_id=task_id,
|
||||||
|
agent_name=self.name,
|
||||||
|
progress=progress,
|
||||||
|
message=message,
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._redis:
|
||||||
|
try:
|
||||||
|
await self._redis.publish(
|
||||||
|
f"agent:{self.name}:progress",
|
||||||
|
json.dumps(progress_obj.to_dict()),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to publish progress for task {task_id}: {e}")
|
||||||
|
|
||||||
|
if self._dispatcher is not None:
|
||||||
|
try:
|
||||||
|
await self._dispatcher.handle_progress(progress_obj)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}")
|
||||||
|
|
||||||
|
# ── 内部方法 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def heartbeat(self):
|
||||||
|
if self._registry is not None:
|
||||||
|
await self._registry.update_heartbeat(self.name)
|
||||||
|
|
||||||
|
async def _heartbeat_loop(self):
|
||||||
|
try:
|
||||||
|
while self._status == AgentStatus.ONLINE:
|
||||||
|
await self.heartbeat()
|
||||||
|
await asyncio.sleep(30)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Heartbeat error for agent '{self.name}': {e}")
|
||||||
|
|
||||||
|
async def _listen_for_tasks(self):
|
||||||
|
try:
|
||||||
|
queue_key = f"agent:{self.name}:tasks"
|
||||||
|
while self._status == AgentStatus.ONLINE:
|
||||||
|
if not self._redis:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
result = await self._redis.brpop(queue_key, timeout=1)
|
||||||
|
if result:
|
||||||
|
_, task_json = result
|
||||||
|
try:
|
||||||
|
task_data = json.loads(task_json)
|
||||||
|
task = TaskMessage.from_dict(task_data)
|
||||||
|
asyncio.create_task(self._execute_task_with_semaphore(task))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to parse task message: {e}")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Task listener error for agent '{self.name}': {e}")
|
||||||
|
|
||||||
|
async def _execute_task_with_semaphore(self, task: TaskMessage):
|
||||||
|
if self._semaphore is None:
|
||||||
|
await self._execute_task(task)
|
||||||
|
return
|
||||||
|
async with self._semaphore:
|
||||||
|
await self._execute_task(task)
|
||||||
|
|
||||||
|
async def _execute_task(self, task: TaskMessage):
|
||||||
|
self._running_tasks.add(task.task_id)
|
||||||
|
self._status = AgentStatus.BUSY
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Agent '{self.name}' executing task {task.task_id} (type={task.task_type})")
|
||||||
|
result = await self.execute(task)
|
||||||
|
|
||||||
|
if self._redis is not None and self._dispatcher is not None:
|
||||||
|
await self._dispatcher.handle_result(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}")
|
||||||
|
error_result = TaskResult(
|
||||||
|
task_id=task.task_id,
|
||||||
|
agent_name=self.name,
|
||||||
|
status=TaskStatus.FAILED,
|
||||||
|
output_data=None,
|
||||||
|
error_message=str(e),
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
metrics=None,
|
||||||
|
)
|
||||||
|
if self._redis is not None and self._dispatcher is not None:
|
||||||
|
await self._dispatcher.handle_result(error_result)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self._running_tasks.discard(task.task_id)
|
||||||
|
if not self._running_tasks:
|
||||||
|
self._status = AgentStatus.ONLINE
|
||||||
|
|
||||||
|
def _validate_input(self, data: dict, schema: dict) -> None:
|
||||||
|
"""校验输入数据是否符合 JSON Schema"""
|
||||||
|
try:
|
||||||
|
import jsonschema
|
||||||
|
jsonschema.validate(data, schema)
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("jsonschema not installed, skipping input validation")
|
||||||
|
except Exception as e:
|
||||||
|
raise SchemaValidationError(self.name, str(e))
|
||||||
|
|
@ -0,0 +1,361 @@
|
||||||
|
"""任务分发器 - 通过 Redis Queue 将任务分发给 Agent
|
||||||
|
|
||||||
|
与业务系统解耦:通过依赖注入获取 Redis 连接和数据库会话。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Callable, Awaitable
|
||||||
|
|
||||||
|
from agentkit.core.exceptions import (
|
||||||
|
NoAvailableAgentError,
|
||||||
|
TaskDispatchError,
|
||||||
|
TaskNotFoundError,
|
||||||
|
)
|
||||||
|
from agentkit.core.protocol import (
|
||||||
|
AgentStatus,
|
||||||
|
TaskMessage,
|
||||||
|
TaskProgress,
|
||||||
|
TaskResult,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskDispatcher:
|
||||||
|
"""任务分发器,通过 Redis Queue 将任务分发给 Agent"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
redis_factory: Callable[[], Awaitable[Any]],
|
||||||
|
session_factory: Callable[[], Any],
|
||||||
|
agent_model: Any,
|
||||||
|
task_model: Any,
|
||||||
|
task_log_model: Any,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
redis_factory: 返回 Redis 连接的异步工厂
|
||||||
|
session_factory: 返回 async context manager 的工厂,用于获取数据库会话
|
||||||
|
agent_model: Agent ORM 模型类
|
||||||
|
task_model: AgentTask ORM 模型类
|
||||||
|
task_log_model: AgentTaskLog ORM 模型类
|
||||||
|
"""
|
||||||
|
self._redis_factory = redis_factory
|
||||||
|
self._session_factory = session_factory
|
||||||
|
self._agent_model = agent_model
|
||||||
|
self._task_model = task_model
|
||||||
|
self._task_log_model = task_log_model
|
||||||
|
|
||||||
|
async def _get_redis(self):
|
||||||
|
return await self._redis_factory()
|
||||||
|
|
||||||
|
async def dispatch(
|
||||||
|
self,
|
||||||
|
task: TaskMessage,
|
||||||
|
organization_id: str | None = None,
|
||||||
|
created_by: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""分发任务到对应 Agent 的队列,返回 task_id"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
AgentModel = self._agent_model
|
||||||
|
TaskModel = self._task_model
|
||||||
|
|
||||||
|
# 查找目标 Agent
|
||||||
|
stmt = select(AgentModel).where(AgentModel.name == task.agent_name)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
agent = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not agent:
|
||||||
|
raise TaskDispatchError(task.task_id, f"Agent '{task.agent_name}' not found")
|
||||||
|
|
||||||
|
if agent.status != AgentStatus.ONLINE:
|
||||||
|
raise TaskDispatchError(
|
||||||
|
task.task_id,
|
||||||
|
f"Agent '{task.agent_name}' is not online (status={agent.status})",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 写入 agent_tasks 表
|
||||||
|
task_id_uuid = uuid.UUID(task.task_id)
|
||||||
|
agent_task = TaskModel(
|
||||||
|
id=task_id_uuid,
|
||||||
|
agent_id=agent.id,
|
||||||
|
task_type=task.task_type,
|
||||||
|
status=TaskStatus.PENDING,
|
||||||
|
priority=task.priority,
|
||||||
|
input_data=task.input_data,
|
||||||
|
organization_id=uuid.UUID(organization_id) if organization_id else agent.id,
|
||||||
|
created_by=uuid.UUID(created_by) if created_by else None,
|
||||||
|
)
|
||||||
|
db.add(agent_task)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# 推送到 Redis Queue
|
||||||
|
redis = await self._get_redis()
|
||||||
|
queue_key = f"agent:{task.agent_name}:tasks"
|
||||||
|
await redis.lpush(queue_key, json.dumps(task.to_dict()))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Task {task.task_id} dispatched to agent '{task.agent_name}' "
|
||||||
|
f"(type={task.task_type}, priority={task.priority})"
|
||||||
|
)
|
||||||
|
return task.task_id
|
||||||
|
|
||||||
|
except TaskDispatchError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to dispatch task {task.task_id}: {e}")
|
||||||
|
raise TaskDispatchError(task.task_id, str(e))
|
||||||
|
|
||||||
|
async def cancel_task(self, task_id: str):
|
||||||
|
"""取消任务"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
TaskModel = self._task_model
|
||||||
|
task_uuid = uuid.UUID(task_id)
|
||||||
|
stmt = select(TaskModel).where(TaskModel.id == task_uuid)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
task = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
raise TaskNotFoundError(task_id)
|
||||||
|
|
||||||
|
if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
|
||||||
|
logger.warning(f"Cannot cancel task {task_id} with status '{task.status}'")
|
||||||
|
return
|
||||||
|
|
||||||
|
task.status = TaskStatus.CANCELLED
|
||||||
|
task.completed_at = datetime.now(timezone.utc)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await self._write_log(
|
||||||
|
db, task_id=task_id, agent_id=str(task.agent_id),
|
||||||
|
log_level="info", message="Task cancelled by user",
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
logger.info(f"Task {task_id} cancelled")
|
||||||
|
|
||||||
|
except TaskNotFoundError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to cancel task {task_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_task_status(self, task_id: str) -> dict:
|
||||||
|
"""获取任务状态"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
TaskModel = self._task_model
|
||||||
|
task_uuid = uuid.UUID(task_id)
|
||||||
|
stmt = select(TaskModel).where(TaskModel.id == task_uuid)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
task = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
raise TaskNotFoundError(task_id)
|
||||||
|
|
||||||
|
return self._task_to_dict(task)
|
||||||
|
|
||||||
|
async def handle_result(self, result: TaskResult):
|
||||||
|
"""处理 Agent 返回的结果"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
TaskModel = self._task_model
|
||||||
|
task_uuid = uuid.UUID(result.task_id)
|
||||||
|
stmt = select(TaskModel).where(TaskModel.id == task_uuid)
|
||||||
|
db_result = await db.execute(stmt)
|
||||||
|
task = db_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not task:
|
||||||
|
logger.error(f"Task {result.task_id} not found when handling result")
|
||||||
|
return
|
||||||
|
|
||||||
|
task.status = result.status
|
||||||
|
task.output_data = result.output_data
|
||||||
|
task.error_message = result.error_message
|
||||||
|
task.started_at = result.started_at
|
||||||
|
task.completed_at = result.completed_at
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
log_level = "info" if result.status == TaskStatus.COMPLETED else "error"
|
||||||
|
log_message = (
|
||||||
|
f"Task {result.status}"
|
||||||
|
if result.status == TaskStatus.COMPLETED
|
||||||
|
else f"Task failed: {result.error_message}"
|
||||||
|
)
|
||||||
|
await self._write_log(
|
||||||
|
db,
|
||||||
|
task_id=result.task_id,
|
||||||
|
agent_id=str(task.agent_id),
|
||||||
|
log_level=log_level,
|
||||||
|
message=log_message,
|
||||||
|
extra_metadata=result.metrics,
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# 触发回调
|
||||||
|
if result.output_data and result.output_data.get("callback_url"):
|
||||||
|
await self._trigger_callback(result.output_data["callback_url"], result)
|
||||||
|
|
||||||
|
logger.info(f"Task {result.task_id} result handled (status={result.status})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to handle result for task {result.task_id}: {e}")
|
||||||
|
|
||||||
|
async def handle_progress(self, progress: TaskProgress):
|
||||||
|
"""处理进度上报"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
AgentModel = self._agent_model
|
||||||
|
stmt = select(AgentModel).where(AgentModel.name == progress.agent_name)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
agent = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not agent:
|
||||||
|
logger.warning(f"Agent '{progress.agent_name}' not found for progress report")
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._write_log(
|
||||||
|
db,
|
||||||
|
task_id=progress.task_id,
|
||||||
|
agent_id=str(agent.id),
|
||||||
|
log_level="info",
|
||||||
|
message=f"Progress: {progress.progress:.0%} - {progress.message}",
|
||||||
|
extra_metadata={
|
||||||
|
"progress": progress.progress,
|
||||||
|
"updated_at": progress.updated_at.isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to handle progress for task {progress.task_id}: {e}")
|
||||||
|
|
||||||
|
async def retry_failed_tasks(self, max_retries: int = 3):
|
||||||
|
"""重试失败的任务"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
TaskModel = self._task_model
|
||||||
|
AgentModel = self._agent_model
|
||||||
|
LogModel = self._task_log_model
|
||||||
|
|
||||||
|
stmt = select(TaskModel).where(TaskModel.status == TaskStatus.FAILED)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
failed_tasks = result.scalars().all()
|
||||||
|
|
||||||
|
retried = 0
|
||||||
|
for task in failed_tasks:
|
||||||
|
log_stmt = select(LogModel).where(
|
||||||
|
LogModel.task_id == task.id,
|
||||||
|
LogModel.message.like("%retry%"),
|
||||||
|
)
|
||||||
|
log_result = await db.execute(log_stmt)
|
||||||
|
retry_count = len(log_result.scalars().all())
|
||||||
|
|
||||||
|
if retry_count < max_retries:
|
||||||
|
task.status = TaskStatus.PENDING
|
||||||
|
task.error_message = None
|
||||||
|
task.started_at = None
|
||||||
|
task.completed_at = None
|
||||||
|
|
||||||
|
agent_stmt = select(AgentModel).where(AgentModel.id == task.agent_id)
|
||||||
|
agent_result = await db.execute(agent_stmt)
|
||||||
|
agent = agent_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if agent and agent.status == AgentStatus.ONLINE:
|
||||||
|
task_msg = TaskMessage(
|
||||||
|
task_id=str(task.id),
|
||||||
|
agent_name=agent.name,
|
||||||
|
task_type=task.task_type,
|
||||||
|
priority=task.priority,
|
||||||
|
input_data=task.input_data or {},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
redis = await self._get_redis()
|
||||||
|
queue_key = f"agent:{agent.name}:tasks"
|
||||||
|
await redis.lpush(queue_key, json.dumps(task_msg.to_dict()))
|
||||||
|
|
||||||
|
await self._write_log(
|
||||||
|
db,
|
||||||
|
task_id=str(task.id),
|
||||||
|
agent_id=str(agent.id),
|
||||||
|
log_level="info",
|
||||||
|
message=f"Task retry attempt {retry_count + 1}/{max_retries}",
|
||||||
|
)
|
||||||
|
retried += 1
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
if retried > 0:
|
||||||
|
logger.info(f"Retried {retried} failed tasks")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to retry failed tasks: {e}")
|
||||||
|
|
||||||
|
async def _write_log(
|
||||||
|
self,
|
||||||
|
db: Any,
|
||||||
|
task_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
log_level: str,
|
||||||
|
message: str,
|
||||||
|
extra_metadata: dict | None = None,
|
||||||
|
):
|
||||||
|
LogModel = self._task_log_model
|
||||||
|
log_entry = LogModel(
|
||||||
|
task_id=uuid.UUID(task_id),
|
||||||
|
agent_id=uuid.UUID(agent_id),
|
||||||
|
log_level=log_level,
|
||||||
|
message=message,
|
||||||
|
extra_metadata=extra_metadata,
|
||||||
|
)
|
||||||
|
db.add(log_entry)
|
||||||
|
|
||||||
|
async def _trigger_callback(self, callback_url: str, result: TaskResult):
|
||||||
|
try:
|
||||||
|
import httpx
|
||||||
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
await client.post(callback_url, json=result.to_dict())
|
||||||
|
logger.info(f"Callback triggered for task {result.task_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Callback failed for task {result.task_id}: {e}")
|
||||||
|
|
||||||
|
def _task_to_dict(self, task: Any) -> dict:
|
||||||
|
return {
|
||||||
|
"id": str(task.id),
|
||||||
|
"agent_id": str(task.agent_id),
|
||||||
|
"task_type": task.task_type,
|
||||||
|
"status": task.status,
|
||||||
|
"priority": task.priority,
|
||||||
|
"input_data": task.input_data,
|
||||||
|
"output_data": task.output_data,
|
||||||
|
"error_message": task.error_message,
|
||||||
|
"created_by": str(task.created_by) if task.created_by else None,
|
||||||
|
"organization_id": str(task.organization_id),
|
||||||
|
"project_id": str(task.project_id) if task.project_id else None,
|
||||||
|
"scheduled_at": task.scheduled_at.isoformat() if task.scheduled_at else None,
|
||||||
|
"started_at": task.started_at.isoformat() if task.started_at else None,
|
||||||
|
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
|
||||||
|
"created_at": task.created_at.isoformat() if task.created_at else None,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
"""Agent 框架自定义异常"""
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFrameworkError(Exception):
|
||||||
|
"""Agent 框架基础异常"""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Agent framework error"):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentNotFoundError(AgentFrameworkError):
|
||||||
|
def __init__(self, agent_name: str):
|
||||||
|
self.agent_name = agent_name
|
||||||
|
super().__init__(f"Agent not found: {agent_name}")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentAlreadyRegisteredError(AgentFrameworkError):
|
||||||
|
def __init__(self, agent_name: str):
|
||||||
|
self.agent_name = agent_name
|
||||||
|
super().__init__(f"Agent already registered: {agent_name}")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentUnavailableError(AgentFrameworkError):
|
||||||
|
def __init__(self, agent_name: str, status: str = "offline"):
|
||||||
|
self.agent_name = agent_name
|
||||||
|
self.status = status
|
||||||
|
super().__init__(f"Agent '{agent_name}' is unavailable (status: {status})")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskNotFoundError(AgentFrameworkError):
|
||||||
|
def __init__(self, task_id: str):
|
||||||
|
self.task_id = task_id
|
||||||
|
super().__init__(f"Task not found: {task_id}")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskDispatchError(AgentFrameworkError):
|
||||||
|
def __init__(self, task_id: str, reason: str = ""):
|
||||||
|
self.task_id = task_id
|
||||||
|
super().__init__(f"Task dispatch failed for {task_id}: {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskExecutionError(AgentFrameworkError):
|
||||||
|
def __init__(self, task_id: str, agent_name: str, reason: str = ""):
|
||||||
|
self.task_id = task_id
|
||||||
|
self.agent_name = agent_name
|
||||||
|
super().__init__(f"Task {task_id} execution failed on agent '{agent_name}': {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskTimeoutError(AgentFrameworkError):
|
||||||
|
def __init__(self, task_id: str, timeout_seconds: int):
|
||||||
|
self.task_id = task_id
|
||||||
|
self.timeout_seconds = timeout_seconds
|
||||||
|
super().__init__(f"Task {task_id} timed out after {timeout_seconds}s")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskCancelledError(AgentFrameworkError):
|
||||||
|
def __init__(self, task_id: str):
|
||||||
|
self.task_id = task_id
|
||||||
|
super().__init__(f"Task {task_id} was cancelled")
|
||||||
|
|
||||||
|
|
||||||
|
class NoAvailableAgentError(AgentFrameworkError):
|
||||||
|
def __init__(self, task_type: str):
|
||||||
|
self.task_type = task_type
|
||||||
|
super().__init__(f"No available agent for task type: {task_type}")
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigValidationError(AgentFrameworkError):
|
||||||
|
def __init__(self, agent_name: str, key: str, reason: str = ""):
|
||||||
|
self.agent_name = agent_name
|
||||||
|
self.key = key
|
||||||
|
super().__init__(f"Config validation failed for agent '{agent_name}' key '{key}': {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentNotReadyError(AgentFrameworkError):
|
||||||
|
def __init__(self, agent_name: str):
|
||||||
|
self.agent_name = agent_name
|
||||||
|
super().__init__(f"Agent '{agent_name}' is not ready")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolNotFoundError(AgentFrameworkError):
|
||||||
|
def __init__(self, tool_name: str):
|
||||||
|
self.tool_name = tool_name
|
||||||
|
super().__init__(f"Tool not found: {tool_name}")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolExecutionError(AgentFrameworkError):
|
||||||
|
def __init__(self, tool_name: str, reason: str = ""):
|
||||||
|
self.tool_name = tool_name
|
||||||
|
super().__init__(f"Tool '{tool_name}' execution failed: {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaValidationError(AgentFrameworkError):
|
||||||
|
def __init__(self, agent_name: str, detail: str = ""):
|
||||||
|
self.agent_name = agent_name
|
||||||
|
super().__init__(f"Schema validation failed for agent '{agent_name}': {detail}")
|
||||||
|
|
||||||
|
|
||||||
|
class HandoffError(AgentFrameworkError):
|
||||||
|
def __init__(self, source: str, target: str, reason: str = ""):
|
||||||
|
self.source = source
|
||||||
|
self.target = target
|
||||||
|
super().__init__(f"Handoff from '{source}' to '{target}' failed: {reason}")
|
||||||
|
|
||||||
|
|
||||||
|
class EvolutionError(AgentFrameworkError):
|
||||||
|
def __init__(self, agent_name: str, reason: str = ""):
|
||||||
|
self.agent_name = agent_name
|
||||||
|
super().__init__(f"Evolution failed for agent '{agent_name}': {reason}")
|
||||||
|
|
@ -0,0 +1,245 @@
|
||||||
|
"""Agent 通信协议定义 - 统一消息格式"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(str, Enum):
|
||||||
|
"""任务状态枚举"""
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
HANDOFF = "handoff"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStatus(str, Enum):
|
||||||
|
"""Agent 状态枚举"""
|
||||||
|
ONLINE = "online"
|
||||||
|
OFFLINE = "offline"
|
||||||
|
BUSY = "busy"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentCapability:
|
||||||
|
"""Agent 能力声明"""
|
||||||
|
agent_name: str
|
||||||
|
agent_type: str
|
||||||
|
version: str
|
||||||
|
supported_tasks: list[str]
|
||||||
|
max_concurrency: int
|
||||||
|
description: str
|
||||||
|
input_schema: dict[str, Any] | None = None
|
||||||
|
output_schema: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
d = {
|
||||||
|
"agent_name": self.agent_name,
|
||||||
|
"agent_type": self.agent_type,
|
||||||
|
"version": self.version,
|
||||||
|
"supported_tasks": self.supported_tasks,
|
||||||
|
"max_concurrency": self.max_concurrency,
|
||||||
|
"description": self.description,
|
||||||
|
}
|
||||||
|
if self.input_schema is not None:
|
||||||
|
d["input_schema"] = self.input_schema
|
||||||
|
if self.output_schema is not None:
|
||||||
|
d["output_schema"] = self.output_schema
|
||||||
|
return d
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "AgentCapability":
|
||||||
|
return cls(
|
||||||
|
agent_name=data["agent_name"],
|
||||||
|
agent_type=data["agent_type"],
|
||||||
|
version=data["version"],
|
||||||
|
supported_tasks=data["supported_tasks"],
|
||||||
|
max_concurrency=data["max_concurrency"],
|
||||||
|
description=data["description"],
|
||||||
|
input_schema=data.get("input_schema"),
|
||||||
|
output_schema=data.get("output_schema"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskMessage:
|
||||||
|
"""任务消息 - 从调度器发往 Agent"""
|
||||||
|
task_id: str
|
||||||
|
agent_name: str
|
||||||
|
task_type: str
|
||||||
|
priority: int
|
||||||
|
input_data: dict
|
||||||
|
callback_url: str | None
|
||||||
|
created_at: datetime
|
||||||
|
timeout_seconds: int = 300
|
||||||
|
conversation_id: str | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"agent_name": self.agent_name,
|
||||||
|
"task_type": self.task_type,
|
||||||
|
"priority": self.priority,
|
||||||
|
"input_data": self.input_data,
|
||||||
|
"callback_url": self.callback_url,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
"timeout_seconds": self.timeout_seconds,
|
||||||
|
"conversation_id": self.conversation_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "TaskMessage":
|
||||||
|
created_at = data.get("created_at")
|
||||||
|
if isinstance(created_at, str):
|
||||||
|
created_at = datetime.fromisoformat(created_at)
|
||||||
|
return cls(
|
||||||
|
task_id=data["task_id"],
|
||||||
|
agent_name=data["agent_name"],
|
||||||
|
task_type=data["task_type"],
|
||||||
|
priority=data.get("priority", 0),
|
||||||
|
input_data=data.get("input_data", {}),
|
||||||
|
callback_url=data.get("callback_url"),
|
||||||
|
created_at=created_at or datetime.utcnow(),
|
||||||
|
timeout_seconds=data.get("timeout_seconds", 300),
|
||||||
|
conversation_id=data.get("conversation_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskResult:
|
||||||
|
"""任务结果 - 从 Agent 返回"""
|
||||||
|
task_id: str
|
||||||
|
agent_name: str
|
||||||
|
status: str
|
||||||
|
output_data: dict | None
|
||||||
|
error_message: str | None
|
||||||
|
started_at: datetime
|
||||||
|
completed_at: datetime
|
||||||
|
metrics: dict | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"agent_name": self.agent_name,
|
||||||
|
"status": self.status,
|
||||||
|
"output_data": self.output_data,
|
||||||
|
"error_message": self.error_message,
|
||||||
|
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||||
|
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||||
|
"metrics": self.metrics,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "TaskResult":
|
||||||
|
started_at = data.get("started_at")
|
||||||
|
if isinstance(started_at, str):
|
||||||
|
started_at = datetime.fromisoformat(started_at)
|
||||||
|
completed_at = data.get("completed_at")
|
||||||
|
if isinstance(completed_at, str):
|
||||||
|
completed_at = datetime.fromisoformat(completed_at)
|
||||||
|
return cls(
|
||||||
|
task_id=data["task_id"],
|
||||||
|
agent_name=data["agent_name"],
|
||||||
|
status=data["status"],
|
||||||
|
output_data=data.get("output_data"),
|
||||||
|
error_message=data.get("error_message"),
|
||||||
|
started_at=started_at or datetime.utcnow(),
|
||||||
|
completed_at=completed_at or datetime.utcnow(),
|
||||||
|
metrics=data.get("metrics"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TaskProgress:
|
||||||
|
"""进度上报 - Agent 执行过程中上报"""
|
||||||
|
task_id: str
|
||||||
|
agent_name: str
|
||||||
|
progress: float
|
||||||
|
message: str
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"agent_name": self.agent_name,
|
||||||
|
"progress": self.progress,
|
||||||
|
"message": self.message,
|
||||||
|
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "TaskProgress":
|
||||||
|
updated_at = data.get("updated_at")
|
||||||
|
if isinstance(updated_at, str):
|
||||||
|
updated_at = datetime.fromisoformat(updated_at)
|
||||||
|
return cls(
|
||||||
|
task_id=data["task_id"],
|
||||||
|
agent_name=data["agent_name"],
|
||||||
|
progress=data.get("progress", 0.0),
|
||||||
|
message=data.get("message", ""),
|
||||||
|
updated_at=updated_at or datetime.utcnow(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HandoffMessage:
|
||||||
|
"""任务转交消息 - Agent 间 Handoff"""
|
||||||
|
source_agent: str
|
||||||
|
target_agent: str
|
||||||
|
task_id: str
|
||||||
|
task_type: str
|
||||||
|
context: dict[str, Any]
|
||||||
|
reason: str
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"source_agent": self.source_agent,
|
||||||
|
"target_agent": self.target_agent,
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"task_type": self.task_type,
|
||||||
|
"context": self.context,
|
||||||
|
"reason": self.reason,
|
||||||
|
"created_at": self.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "HandoffMessage":
|
||||||
|
created_at = data.get("created_at")
|
||||||
|
if isinstance(created_at, str):
|
||||||
|
created_at = datetime.fromisoformat(created_at)
|
||||||
|
return cls(
|
||||||
|
source_agent=data["source_agent"],
|
||||||
|
target_agent=data["target_agent"],
|
||||||
|
task_id=data["task_id"],
|
||||||
|
task_type=data["task_type"],
|
||||||
|
context=data.get("context", {}),
|
||||||
|
reason=data["reason"],
|
||||||
|
created_at=created_at or datetime.utcnow(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvolutionEvent:
|
||||||
|
"""进化事件 - 记录 Agent 的自我进化变更"""
|
||||||
|
agent_name: str
|
||||||
|
change_type: str # prompt / strategy / pipeline
|
||||||
|
before: dict[str, Any]
|
||||||
|
after: dict[str, Any]
|
||||||
|
metrics: dict[str, Any] | None = None
|
||||||
|
event_id: str | None = None
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"agent_name": self.agent_name,
|
||||||
|
"change_type": self.change_type,
|
||||||
|
"before": self.before,
|
||||||
|
"after": self.after,
|
||||||
|
"metrics": self.metrics,
|
||||||
|
"event_id": self.event_id,
|
||||||
|
"created_at": self.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,250 @@
|
||||||
|
"""Agent 注册中心 - 管理 Agent 的注册、发现、状态
|
||||||
|
|
||||||
|
与业务系统解耦:通过 session_factory 注入数据库会话,
|
||||||
|
通过 agent_model_factory 注入 ORM 模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Callable, Awaitable
|
||||||
|
|
||||||
|
from agentkit.core.exceptions import (
|
||||||
|
AgentNotFoundError,
|
||||||
|
AgentUnavailableError,
|
||||||
|
NoAvailableAgentError,
|
||||||
|
)
|
||||||
|
from agentkit.core.protocol import AgentCapability, AgentStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HEARTBEAT_TIMEOUT_SECONDS = 90
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistry:
|
||||||
|
"""Agent 注册中心,管理 Agent 的注册、发现、状态
|
||||||
|
|
||||||
|
使用依赖注入模式,不依赖具体的 ORM 模型或数据库连接。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_factory: Callable[[], Any],
|
||||||
|
agent_model: Any,
|
||||||
|
load_balancer: str = "round_robin",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
session_factory: 返回 async context manager 的工厂,用于获取数据库会话
|
||||||
|
agent_model: Agent ORM 模型类
|
||||||
|
load_balancer: 负载均衡策略 (round_robin / least_tasks / random)
|
||||||
|
"""
|
||||||
|
self._session_factory = session_factory
|
||||||
|
self._agent_model = agent_model
|
||||||
|
self._load_balancer = load_balancer
|
||||||
|
self._round_robin_index: dict[str, int] = {}
|
||||||
|
|
||||||
|
async def register(self, capability: AgentCapability, endpoint: str) -> str:
|
||||||
|
"""注册 Agent,返回 agent_id。同名 Agent 已存在则更新。"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
Model = self._agent_model
|
||||||
|
stmt = type(db).execute.__self__.__class__ # placeholder
|
||||||
|
|
||||||
|
# 尝试查找已有记录
|
||||||
|
from sqlalchemy import select
|
||||||
|
stmt = select(Model).where(Model.name == capability.agent_name)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
existing.agent_type = capability.agent_type
|
||||||
|
existing.version = capability.version
|
||||||
|
existing.endpoint = endpoint
|
||||||
|
existing.description = capability.description
|
||||||
|
existing.capabilities = capability.to_dict()
|
||||||
|
existing.status = AgentStatus.ONLINE
|
||||||
|
existing.last_heartbeat = datetime.now(timezone.utc)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(existing)
|
||||||
|
agent_id = existing.id
|
||||||
|
logger.info(f"Agent '{capability.agent_name}' re-registered (id={agent_id})")
|
||||||
|
else:
|
||||||
|
agent = Model(
|
||||||
|
name=capability.agent_name,
|
||||||
|
display_name=capability.agent_name.replace("_", " ").title(),
|
||||||
|
agent_type=capability.agent_type,
|
||||||
|
description=capability.description,
|
||||||
|
version=capability.version,
|
||||||
|
endpoint=endpoint,
|
||||||
|
status=AgentStatus.ONLINE,
|
||||||
|
capabilities=capability.to_dict(),
|
||||||
|
last_heartbeat=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
agent_id = agent.id
|
||||||
|
logger.info(f"Agent '{capability.agent_name}' registered (id={agent_id})")
|
||||||
|
|
||||||
|
return str(agent_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to register agent '{capability.agent_name}': {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def unregister(self, agent_name: str):
|
||||||
|
"""注销 Agent(设置状态为 offline)"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
Model = self._agent_model
|
||||||
|
stmt = select(Model).where(Model.name == agent_name)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
agent = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not agent:
|
||||||
|
logger.warning(f"Attempted to unregister non-existent agent '{agent_name}'")
|
||||||
|
return
|
||||||
|
|
||||||
|
agent.status = AgentStatus.OFFLINE
|
||||||
|
await db.commit()
|
||||||
|
logger.info(f"Agent '{agent_name}' unregistered")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to unregister agent '{agent_name}': {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def update_heartbeat(self, agent_name: str):
|
||||||
|
"""更新心跳时间"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import update
|
||||||
|
Model = self._agent_model
|
||||||
|
stmt = (
|
||||||
|
update(Model)
|
||||||
|
.where(Model.name == agent_name)
|
||||||
|
.values(
|
||||||
|
last_heartbeat=datetime.now(timezone.utc),
|
||||||
|
status=AgentStatus.ONLINE,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to update heartbeat for agent '{agent_name}': {e}")
|
||||||
|
|
||||||
|
async def get_agent(self, agent_name: str) -> dict | None:
|
||||||
|
"""获取 Agent 信息"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
from sqlalchemy import select
|
||||||
|
Model = self._agent_model
|
||||||
|
stmt = select(Model).where(Model.name == agent_name)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
agent = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not agent:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._agent_to_dict(agent)
|
||||||
|
|
||||||
|
async def list_agents(
|
||||||
|
self,
|
||||||
|
agent_type: str | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""列出 Agent,支持按类型和状态筛选"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
from sqlalchemy import select
|
||||||
|
Model = self._agent_model
|
||||||
|
stmt = select(Model)
|
||||||
|
if agent_type:
|
||||||
|
stmt = stmt.where(Model.agent_type == agent_type)
|
||||||
|
if status:
|
||||||
|
stmt = stmt.where(Model.status == status)
|
||||||
|
stmt = stmt.order_by(Model.created_at.desc())
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
agents = result.scalars().all()
|
||||||
|
|
||||||
|
return [self._agent_to_dict(a) for a in agents]
|
||||||
|
|
||||||
|
async def get_available_agent(self, task_type: str) -> str | None:
|
||||||
|
"""根据任务类型找到可用 Agent(支持负载均衡)"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
from sqlalchemy import select
|
||||||
|
Model = self._agent_model
|
||||||
|
stmt = select(Model).where(Model.status == AgentStatus.ONLINE)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
agents = result.scalars().all()
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
for agent in agents:
|
||||||
|
capabilities = agent.capabilities or {}
|
||||||
|
supported_tasks = capabilities.get("supported_tasks", [])
|
||||||
|
if task_type in supported_tasks:
|
||||||
|
candidates.append(agent)
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 负载均衡选择
|
||||||
|
if self._load_balancer == "round_robin":
|
||||||
|
idx = self._round_robin_index.get(task_type, 0)
|
||||||
|
selected = candidates[idx % len(candidates)]
|
||||||
|
self._round_robin_index[task_type] = idx + 1
|
||||||
|
return selected.name
|
||||||
|
elif self._load_balancer == "random":
|
||||||
|
import random
|
||||||
|
return random.choice(candidates).name
|
||||||
|
else: # least_tasks 或默认:返回第一个
|
||||||
|
return candidates[0].name
|
||||||
|
|
||||||
|
async def check_health(self):
|
||||||
|
"""检查所有 Agent 健康状态,超时标记为 offline"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import update
|
||||||
|
Model = self._agent_model
|
||||||
|
timeout_threshold = datetime.now(timezone.utc) - timedelta(
|
||||||
|
seconds=HEARTBEAT_TIMEOUT_SECONDS
|
||||||
|
)
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
update(Model)
|
||||||
|
.where(
|
||||||
|
Model.status == AgentStatus.ONLINE,
|
||||||
|
Model.last_heartbeat < timeout_threshold,
|
||||||
|
)
|
||||||
|
.values(status=AgentStatus.OFFLINE)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
if result.rowcount > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Marked {result.rowcount} agent(s) as offline due to heartbeat timeout"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to check agent health: {e}")
|
||||||
|
|
||||||
|
def _agent_to_dict(self, agent: Any) -> dict:
|
||||||
|
"""将 Agent ORM 对象转换为字典"""
|
||||||
|
return {
|
||||||
|
"id": str(agent.id),
|
||||||
|
"name": agent.name,
|
||||||
|
"display_name": agent.display_name,
|
||||||
|
"agent_type": agent.agent_type,
|
||||||
|
"description": agent.description,
|
||||||
|
"version": agent.version,
|
||||||
|
"endpoint": agent.endpoint,
|
||||||
|
"status": agent.status,
|
||||||
|
"capabilities": agent.capabilities,
|
||||||
|
"last_heartbeat": agent.last_heartbeat.isoformat() if agent.last_heartbeat else None,
|
||||||
|
"created_at": agent.created_at.isoformat() if agent.created_at else None,
|
||||||
|
"updated_at": agent.updated_at.isoformat() if agent.updated_at else None,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
"""AgentKit Evolution - 自我进化引擎"""
|
||||||
|
|
||||||
|
from agentkit.evolution.reflector import Reflector
|
||||||
|
from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Module
|
||||||
|
from agentkit.evolution.strategy_tuner import StrategyTuner
|
||||||
|
from agentkit.evolution.ab_tester import ABTester
|
||||||
|
from agentkit.evolution.evolution_store import EvolutionStore
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Reflector",
|
||||||
|
"PromptOptimizer",
|
||||||
|
"Signature",
|
||||||
|
"Module",
|
||||||
|
"StrategyTuner",
|
||||||
|
"ABTester",
|
||||||
|
"EvolutionStore",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
"""ABTester - A/B 测试框架
|
||||||
|
|
||||||
|
支持配置分流比例,自动收集效果指标,统计显著性检验。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ABTestConfig:
|
||||||
|
"""A/B 测试配置"""
|
||||||
|
test_id: str
|
||||||
|
agent_name: str
|
||||||
|
change_type: str # prompt / strategy / pipeline
|
||||||
|
control_ratio: float = 0.8 # 对照组比例
|
||||||
|
min_samples: int = 30 # 最小样本量
|
||||||
|
confidence_level: float = 0.95 # 置信度
|
||||||
|
status: str = "running" # running / completed / rolled_back
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ABTestResult:
|
||||||
|
"""A/B 测试结果"""
|
||||||
|
test_id: str
|
||||||
|
control_metric: float
|
||||||
|
experiment_metric: float
|
||||||
|
control_samples: int
|
||||||
|
experiment_samples: int
|
||||||
|
is_significant: bool
|
||||||
|
winner: str | None # control / experiment / None
|
||||||
|
p_value: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ABTester:
|
||||||
|
"""A/B 测试框架"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tests: dict[str, ABTestConfig] = {}
|
||||||
|
self._results: dict[str, list[tuple[str, float]]] = {} # test_id -> [(group, metric)]
|
||||||
|
|
||||||
|
def create_test(self, config: ABTestConfig) -> None:
|
||||||
|
"""创建 A/B 测试"""
|
||||||
|
self._tests[config.test_id] = config
|
||||||
|
self._results[config.test_id] = []
|
||||||
|
logger.info(f"A/B test '{config.test_id}' created for agent '{config.agent_name}'")
|
||||||
|
|
||||||
|
def assign_group(self, test_id: str) -> str:
|
||||||
|
"""分配测试组"""
|
||||||
|
import random
|
||||||
|
config = self._tests.get(test_id)
|
||||||
|
if not config:
|
||||||
|
return "control"
|
||||||
|
|
||||||
|
return "control" if random.random() < config.control_ratio else "experiment"
|
||||||
|
|
||||||
|
def record_result(self, test_id: str, group: str, metric: float) -> None:
|
||||||
|
"""记录测试结果"""
|
||||||
|
if test_id not in self._results:
|
||||||
|
self._results[test_id] = []
|
||||||
|
self._results[test_id].append((group, metric))
|
||||||
|
|
||||||
|
async def evaluate(self, test_id: str) -> ABTestResult | None:
|
||||||
|
"""评估 A/B 测试结果"""
|
||||||
|
config = self._tests.get(test_id)
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
results = self._results.get(test_id, [])
|
||||||
|
control_metrics = [m for g, m in results if g == "control"]
|
||||||
|
experiment_metrics = [m for g, m in results if g == "experiment"]
|
||||||
|
|
||||||
|
if len(control_metrics) < config.min_samples or len(experiment_metrics) < config.min_samples:
|
||||||
|
return ABTestResult(
|
||||||
|
test_id=test_id,
|
||||||
|
control_metric=sum(control_metrics) / len(control_metrics) if control_metrics else 0,
|
||||||
|
experiment_metric=sum(experiment_metrics) / len(experiment_metrics) if experiment_metrics else 0,
|
||||||
|
control_samples=len(control_metrics),
|
||||||
|
experiment_samples=len(experiment_metrics),
|
||||||
|
is_significant=False,
|
||||||
|
winner=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 简单 t-test
|
||||||
|
control_mean = sum(control_metrics) / len(control_metrics)
|
||||||
|
experiment_mean = sum(experiment_metrics) / len(experiment_metrics)
|
||||||
|
|
||||||
|
control_var = sum((m - control_mean) ** 2 for m in control_metrics) / (len(control_metrics) - 1)
|
||||||
|
experiment_var = sum((m - experiment_mean) ** 2 for m in experiment_metrics) / (len(experiment_metrics) - 1)
|
||||||
|
|
||||||
|
pooled_se = math.sqrt(control_var / len(control_metrics) + experiment_var / len(experiment_metrics))
|
||||||
|
t_stat = (experiment_mean - control_mean) / pooled_se if pooled_se > 0 else 0
|
||||||
|
|
||||||
|
# 近似 p-value (双侧)
|
||||||
|
p_value = 2 * (1 - self._normal_cdf(abs(t_stat)))
|
||||||
|
is_significant = p_value < (1 - config.confidence_level)
|
||||||
|
|
||||||
|
winner = None
|
||||||
|
if is_significant:
|
||||||
|
winner = "experiment" if experiment_mean > control_mean else "control"
|
||||||
|
|
||||||
|
return ABTestResult(
|
||||||
|
test_id=test_id,
|
||||||
|
control_metric=control_mean,
|
||||||
|
experiment_metric=experiment_mean,
|
||||||
|
control_samples=len(control_metrics),
|
||||||
|
experiment_samples=len(experiment_metrics),
|
||||||
|
is_significant=is_significant,
|
||||||
|
winner=winner,
|
||||||
|
p_value=p_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normal_cdf(x: float) -> float:
|
||||||
|
"""标准正态分布 CDF 近似"""
|
||||||
|
return 0.5 * (1 + math.erf(x / math.sqrt(2)))
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
"""EvolutionStore - 进化日志存储"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.core.protocol import EvolutionEvent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EvolutionStore:
|
||||||
|
"""进化日志存储
|
||||||
|
|
||||||
|
记录 Agent 的自我进化变更,支持回滚。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session_factory: Any, evolution_model: Any):
|
||||||
|
self._session_factory = session_factory
|
||||||
|
self._evolution_model = evolution_model
|
||||||
|
|
||||||
|
async def record(self, event: EvolutionEvent) -> str:
|
||||||
|
"""记录进化事件"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
import uuid
|
||||||
|
Model = self._evolution_model
|
||||||
|
entry = Model(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
agent_name=event.agent_name,
|
||||||
|
change_type=event.change_type,
|
||||||
|
before=event.before,
|
||||||
|
after=event.after,
|
||||||
|
metrics=event.metrics,
|
||||||
|
status="active",
|
||||||
|
)
|
||||||
|
db.add(entry)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(entry)
|
||||||
|
event_id = str(entry.id)
|
||||||
|
event.event_id = event_id
|
||||||
|
logger.info(f"Evolution event recorded: {event_id} for agent '{event.agent_name}'")
|
||||||
|
return event_id
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to record evolution event: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def rollback(self, event_id: str) -> bool:
|
||||||
|
"""回滚进化事件"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy import select
|
||||||
|
Model = self._evolution_model
|
||||||
|
|
||||||
|
stmt = select(Model).where(Model.id == uuid.UUID(event_id))
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
entry = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not entry:
|
||||||
|
logger.error(f"Evolution event {event_id} not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
entry.status = "rolled_back"
|
||||||
|
await db.commit()
|
||||||
|
logger.info(f"Evolution event {event_id} rolled back")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to rollback evolution event {event_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def list_events(
|
||||||
|
self,
|
||||||
|
agent_name: str | None = None,
|
||||||
|
change_type: str | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""列出进化事件"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select
|
||||||
|
Model = self._evolution_model
|
||||||
|
stmt = select(Model)
|
||||||
|
|
||||||
|
if agent_name:
|
||||||
|
stmt = stmt.where(Model.agent_name == agent_name)
|
||||||
|
if change_type:
|
||||||
|
stmt = stmt.where(Model.change_type == change_type)
|
||||||
|
if status:
|
||||||
|
stmt = stmt.where(Model.status == status)
|
||||||
|
|
||||||
|
stmt = stmt.order_by(Model.created_at.desc())
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
entries = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": str(e.id),
|
||||||
|
"agent_name": e.agent_name,
|
||||||
|
"change_type": e.change_type,
|
||||||
|
"before": e.before,
|
||||||
|
"after": e.after,
|
||||||
|
"metrics": e.metrics,
|
||||||
|
"status": e.status,
|
||||||
|
"created_at": e.created_at.isoformat() if e.created_at else None,
|
||||||
|
}
|
||||||
|
for e in entries
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to list evolution events: {e}")
|
||||||
|
return []
|
||||||
|
|
@ -0,0 +1,151 @@
|
||||||
|
"""PromptOptimizer - DSPy 风格的 Prompt 自动优化器
|
||||||
|
|
||||||
|
核心概念:
|
||||||
|
- Signature: 定义输入/输出 schema
|
||||||
|
- Module: 可组合的 Prompt 策略
|
||||||
|
- Optimizer: 从任务结果中自动优化 Prompt
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Signature:
|
||||||
|
"""Prompt 签名 - 定义输入/输出字段"""
|
||||||
|
input_fields: dict[str, str] # name -> description
|
||||||
|
output_fields: dict[str, str] # name -> description
|
||||||
|
instruction: str = ""
|
||||||
|
|
||||||
|
def to_prompt_prefix(self) -> str:
|
||||||
|
parts = []
|
||||||
|
if self.instruction:
|
||||||
|
parts.append(self.instruction)
|
||||||
|
parts.append("Inputs:")
|
||||||
|
for name, desc in self.input_fields.items():
|
||||||
|
parts.append(f" - {name}: {desc}")
|
||||||
|
parts.append("Outputs:")
|
||||||
|
for name, desc in self.output_fields.items():
|
||||||
|
parts.append(f" - {name}: {desc}")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Module:
|
||||||
|
"""可组合的 Prompt 策略模块"""
|
||||||
|
name: str
|
||||||
|
signature: Signature
|
||||||
|
template: str = ""
|
||||||
|
demos: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
def render(self, **kwargs) -> str:
|
||||||
|
parts = []
|
||||||
|
parts.append(self.signature.to_prompt_prefix())
|
||||||
|
if self.demos:
|
||||||
|
parts.append("\nExamples:")
|
||||||
|
for demo in self.demos:
|
||||||
|
parts.append(f"\nInput: {demo.get('input', '')}")
|
||||||
|
parts.append(f"Output: {demo.get('output', '')}")
|
||||||
|
if self.template:
|
||||||
|
parts.append(f"\n{self.template.format(**kwargs)}")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptOptimizer:
|
||||||
|
"""DSPy 风格的 Prompt 自动优化器
|
||||||
|
|
||||||
|
从成功案例中自动构建 few-shot 示例,优化 Prompt 指令。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_demos: int = 5,
|
||||||
|
min_examples_for_optimization: int = 3,
|
||||||
|
):
|
||||||
|
self._max_demos = max_demos
|
||||||
|
self._min_examples = min_examples_for_optimization
|
||||||
|
self._success_examples: list[dict[str, Any]] = []
|
||||||
|
self._failure_examples: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def add_example(
|
||||||
|
self,
|
||||||
|
input_data: dict,
|
||||||
|
output_data: dict,
|
||||||
|
quality_score: float,
|
||||||
|
) -> None:
|
||||||
|
"""添加训练样本"""
|
||||||
|
example = {
|
||||||
|
"input": input_data,
|
||||||
|
"output": output_data,
|
||||||
|
"quality_score": quality_score,
|
||||||
|
}
|
||||||
|
if quality_score >= 0.7:
|
||||||
|
self._success_examples.append(example)
|
||||||
|
else:
|
||||||
|
self._failure_examples.append(example)
|
||||||
|
|
||||||
|
async def optimize(self, module: Module) -> Module:
|
||||||
|
"""优化 Module 的 Prompt
|
||||||
|
|
||||||
|
BootstrapFewShot: 从成功案例中自动构建 few-shot 示例
|
||||||
|
"""
|
||||||
|
if len(self._success_examples) < self._min_examples:
|
||||||
|
logger.info(
|
||||||
|
f"Not enough examples for optimization "
|
||||||
|
f"({len(self._success_examples)}/{self._min_examples})"
|
||||||
|
)
|
||||||
|
return module
|
||||||
|
|
||||||
|
# 选择质量最高的成功案例作为 demo
|
||||||
|
sorted_examples = sorted(
|
||||||
|
self._success_examples,
|
||||||
|
key=lambda x: x["quality_score"],
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
best_demos = sorted_examples[:self._max_demos]
|
||||||
|
|
||||||
|
# 构建 few-shot 示例
|
||||||
|
demos = []
|
||||||
|
for example in best_demos:
|
||||||
|
demos.append({
|
||||||
|
"input": str(example["input"]),
|
||||||
|
"output": str(example["output"]),
|
||||||
|
})
|
||||||
|
|
||||||
|
# 优化指令(基于失败案例的反面教材)
|
||||||
|
optimized_instruction = module.signature.instruction
|
||||||
|
if self._failure_examples:
|
||||||
|
failure_patterns = set()
|
||||||
|
for ex in self._failure_examples[-3:]:
|
||||||
|
failure_patterns.add(str(ex["input"])[:100])
|
||||||
|
if failure_patterns:
|
||||||
|
optimized_instruction += (
|
||||||
|
f"\n\nAvoid these patterns:\n"
|
||||||
|
+ "\n".join(f"- {p}" for p in failure_patterns)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建优化后的 Module
|
||||||
|
optimized = Module(
|
||||||
|
name=f"{module.name}_optimized",
|
||||||
|
signature=Signature(
|
||||||
|
input_fields=module.signature.input_fields,
|
||||||
|
output_fields=module.signature.output_fields,
|
||||||
|
instruction=optimized_instruction,
|
||||||
|
),
|
||||||
|
template=module.template,
|
||||||
|
demos=demos,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Optimized module '{module.name}': "
|
||||||
|
f"{len(demos)} demos, instruction length {len(optimized_instruction)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimized
|
||||||
|
|
||||||
|
@property
|
||||||
|
def example_count(self) -> tuple[int, int]:
|
||||||
|
return len(self._success_examples), len(self._failure_examples)
|
||||||
|
|
@ -0,0 +1,147 @@
|
||||||
|
"""Reflector - 执行反思
|
||||||
|
|
||||||
|
每次任务完成后自动评估结果,提取模式,生成反思总结。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Reflection:
|
||||||
|
"""反思结果"""
|
||||||
|
task_id: str
|
||||||
|
agent_name: str
|
||||||
|
outcome: str # success / failure / partial
|
||||||
|
quality_score: float # 0.0 - 1.0
|
||||||
|
patterns: list[str] = field(default_factory=list)
|
||||||
|
insights: list[str] = field(default_factory=list)
|
||||||
|
suggestions: list[str] = field(default_factory=list)
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
|
||||||
|
|
||||||
|
|
||||||
|
class Reflector:
|
||||||
|
"""执行反思器
|
||||||
|
|
||||||
|
评估任务结果,提取成功/失败模式,生成改进建议。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
quality_scorer: Any | None = None,
|
||||||
|
pattern_extractor: Any | None = None,
|
||||||
|
):
|
||||||
|
self._quality_scorer = quality_scorer
|
||||||
|
self._pattern_extractor = pattern_extractor
|
||||||
|
|
||||||
|
async def reflect(self, task: TaskMessage, result: TaskResult) -> Reflection:
|
||||||
|
"""对任务执行结果进行反思"""
|
||||||
|
# 判断结果
|
||||||
|
outcome = "success" if result.status == TaskStatus.COMPLETED else "failure"
|
||||||
|
|
||||||
|
# 质量评分
|
||||||
|
quality_score = await self._score_quality(task, result)
|
||||||
|
|
||||||
|
# 提取模式
|
||||||
|
patterns = await self._extract_patterns(task, result, outcome)
|
||||||
|
|
||||||
|
# 生成洞察
|
||||||
|
insights = self._generate_insights(outcome, quality_score, patterns)
|
||||||
|
|
||||||
|
# 生成建议
|
||||||
|
suggestions = self._generate_suggestions(outcome, quality_score, patterns)
|
||||||
|
|
||||||
|
reflection = Reflection(
|
||||||
|
task_id=task.task_id,
|
||||||
|
agent_name=result.agent_name,
|
||||||
|
outcome=outcome,
|
||||||
|
quality_score=quality_score,
|
||||||
|
patterns=patterns,
|
||||||
|
insights=insights,
|
||||||
|
suggestions=suggestions,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Reflection for task {task.task_id}: outcome={outcome}, "
|
||||||
|
f"quality={quality_score:.2f}, patterns={len(patterns)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return reflection
|
||||||
|
|
||||||
|
async def _score_quality(self, task: TaskMessage, result: TaskResult) -> float:
|
||||||
|
"""评估任务质量"""
|
||||||
|
if result.status != TaskStatus.COMPLETED:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
if self._quality_scorer:
|
||||||
|
return await self._quality_scorer(task, result)
|
||||||
|
|
||||||
|
# 默认评分逻辑
|
||||||
|
score = 0.5 # 基础分
|
||||||
|
|
||||||
|
# 有输出数据加分
|
||||||
|
if result.output_data:
|
||||||
|
score += 0.2
|
||||||
|
|
||||||
|
# 无错误加分
|
||||||
|
if not result.error_message:
|
||||||
|
score += 0.1
|
||||||
|
|
||||||
|
# 耗时合理加分
|
||||||
|
if result.metrics and result.metrics.get("elapsed_seconds", 0) < 30:
|
||||||
|
score += 0.2
|
||||||
|
|
||||||
|
return min(score, 1.0)
|
||||||
|
|
||||||
|
async def _extract_patterns(
|
||||||
|
self, task: TaskMessage, result: TaskResult, outcome: str
|
||||||
|
) -> list[str]:
|
||||||
|
"""提取模式"""
|
||||||
|
patterns = []
|
||||||
|
|
||||||
|
if outcome == "failure":
|
||||||
|
if result.error_message:
|
||||||
|
patterns.append(f"error_type:{type(result.error_message).__name__}")
|
||||||
|
if result.metrics and result.metrics.get("elapsed_seconds", 0) > 60:
|
||||||
|
patterns.append("slow_execution")
|
||||||
|
else:
|
||||||
|
if result.output_data and len(result.output_data) > 5:
|
||||||
|
patterns.append("rich_output")
|
||||||
|
if result.metrics and result.metrics.get("elapsed_seconds", 0) < 10:
|
||||||
|
patterns.append("fast_execution")
|
||||||
|
|
||||||
|
return patterns
|
||||||
|
|
||||||
|
def _generate_insights(
|
||||||
|
self, outcome: str, quality_score: float, patterns: list[str]
|
||||||
|
) -> list[str]:
|
||||||
|
"""生成洞察"""
|
||||||
|
insights = []
|
||||||
|
|
||||||
|
if quality_score < 0.3:
|
||||||
|
insights.append("Low quality score indicates potential issues with task execution")
|
||||||
|
if "slow_execution" in patterns:
|
||||||
|
insights.append("Slow execution may benefit from strategy optimization")
|
||||||
|
if "error_type:TimeoutError" in patterns:
|
||||||
|
insights.append("Timeout errors suggest need for longer timeout or task decomposition")
|
||||||
|
|
||||||
|
return insights
|
||||||
|
|
||||||
|
def _generate_suggestions(
|
||||||
|
self, outcome: str, quality_score: float, patterns: list[str]
|
||||||
|
) -> list[str]:
|
||||||
|
"""生成改进建议"""
|
||||||
|
suggestions = []
|
||||||
|
|
||||||
|
if outcome == "failure" and quality_score < 0.3:
|
||||||
|
suggestions.append("Consider prompt optimization for this task type")
|
||||||
|
if "slow_execution" in patterns:
|
||||||
|
suggestions.append("Consider adjusting strategy parameters for faster execution")
|
||||||
|
|
||||||
|
return suggestions
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
"""StrategyTuner - 策略调优
|
||||||
|
|
||||||
|
自动调整 Agent 参数(temperature, tool 选择权重, Pipeline 路径)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StrategyConfig:
|
||||||
|
"""策略配置"""
|
||||||
|
temperature: float = 0.5
|
||||||
|
tool_weights: dict[str, float] = field(default_factory=dict)
|
||||||
|
max_iterations: int = 5
|
||||||
|
timeout_seconds: int = 300
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyTuner:
|
||||||
|
"""策略调优器
|
||||||
|
|
||||||
|
基于历史效果数据自动调整 Agent 参数。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, param_ranges: dict[str, tuple[float, float]] | None = None):
|
||||||
|
self._param_ranges = param_ranges or {
|
||||||
|
"temperature": (0.0, 1.0),
|
||||||
|
"max_iterations": (1, 10),
|
||||||
|
}
|
||||||
|
self._history: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def record(self, config: StrategyConfig, metric: float) -> None:
|
||||||
|
"""记录配置和对应的效果指标"""
|
||||||
|
self._history.append({
|
||||||
|
"config": config,
|
||||||
|
"metric": metric,
|
||||||
|
})
|
||||||
|
|
||||||
|
async def suggest(self, current: StrategyConfig) -> StrategyConfig:
|
||||||
|
"""基于历史数据建议新的策略配置"""
|
||||||
|
if len(self._history) < 3:
|
||||||
|
logger.info("Not enough history for strategy tuning")
|
||||||
|
return current
|
||||||
|
|
||||||
|
# 找到效果最好的配置
|
||||||
|
best = max(self._history, key=lambda x: x["metric"])
|
||||||
|
best_config = best["config"]
|
||||||
|
best_metric = best["metric"]
|
||||||
|
|
||||||
|
# 在最佳配置附近微调
|
||||||
|
suggested = StrategyConfig(
|
||||||
|
temperature=self._clamp(
|
||||||
|
best_config.temperature + self._small_perturbation(),
|
||||||
|
*self._param_ranges.get("temperature", (0.0, 1.0)),
|
||||||
|
),
|
||||||
|
tool_weights=dict(best_config.tool_weights),
|
||||||
|
max_iterations=int(self._clamp(
|
||||||
|
best_config.max_iterations + self._small_perturbation(),
|
||||||
|
*self._param_ranges.get("max_iterations", (1, 10)),
|
||||||
|
)),
|
||||||
|
timeout_seconds=current.timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Strategy suggestion: temperature {current.temperature:.2f} -> {suggested.temperature:.2f}, "
|
||||||
|
f"max_iterations {current.max_iterations} -> {suggested.max_iterations}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return suggested
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _small_perturbation() -> float:
|
||||||
|
import random
|
||||||
|
return random.uniform(-0.1, 0.1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _clamp(value: float, min_val: float, max_val: float) -> float:
|
||||||
|
return max(min_val, min(max_val, value))
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
"""AgentKit MCP - Model Context Protocol 支持"""
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MCPServer",
|
||||||
|
"MCPClient",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,83 @@
|
||||||
|
"""MCP Client - 调用外部 MCP 工具服务器"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
"""MCP Client - 连接外部 MCP Server 并调用工具"""
|
||||||
|
|
||||||
|
def __init__(self, server_url: str, timeout: int = 30):
|
||||||
|
self._server_url = server_url.rstrip("/")
|
||||||
|
self._timeout = timeout
|
||||||
|
self._tools_cache: list[dict] | None = None
|
||||||
|
|
||||||
|
async def list_tools(self) -> list[dict]:
|
||||||
|
"""列出远程 MCP Server 上的工具"""
|
||||||
|
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||||
|
response = await client.get(f"{self._server_url}/tools/list")
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
self._tools_cache = data.get("tools", [])
|
||||||
|
return self._tools_cache
|
||||||
|
|
||||||
|
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||||
|
"""调用远程 MCP 工具"""
|
||||||
|
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{self._server_url}/tools/call",
|
||||||
|
json={"name": tool_name, "arguments": arguments},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def as_tool(self, tool_name: str, description: str = "") -> "MCPTool":
|
||||||
|
"""将远程 MCP 工具包装为本地 Tool 对象"""
|
||||||
|
return MCPTool(
|
||||||
|
name=tool_name,
|
||||||
|
description=description,
|
||||||
|
client=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPTool(Tool):
|
||||||
|
"""MCP 工具 - 通过 MCP Client 调用远程工具"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
client: MCPClient,
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
input_schema=input_schema,
|
||||||
|
output_schema=output_schema,
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["mcp"],
|
||||||
|
)
|
||||||
|
self._client = client
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
result = await self._client.call_tool(self.name, kwargs)
|
||||||
|
# 解析 MCP 响应格式
|
||||||
|
if "content" in result:
|
||||||
|
for item in result["content"]:
|
||||||
|
if item.get("type") == "text":
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
return json.loads(item["text"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {"result": item["text"]}
|
||||||
|
return result
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
"""MCP Server - 将 Agent 能力暴露为 MCP 工具
|
||||||
|
|
||||||
|
基于 FastAPI 实现,支持 Streamable HTTP 传输。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPServer:
|
||||||
|
"""MCP Server - 暴露 Agent 能力为 MCP 工具
|
||||||
|
|
||||||
|
自动将 ToolRegistry 中注册的工具暴露为 MCP 工具端点。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tool_registry: Any = None, host: str = "0.0.0.0", port: int = 8080):
|
||||||
|
self._tool_registry = tool_registry
|
||||||
|
self._host = host
|
||||||
|
self._port = port
|
||||||
|
self._app = None
|
||||||
|
|
||||||
|
def _create_app(self):
|
||||||
|
"""创建 FastAPI 应用"""
|
||||||
|
try:
|
||||||
|
from fastapi import FastAPI
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("MCP Server requires fastapi: pip install fischer-agentkit[mcp]")
|
||||||
|
|
||||||
|
app = FastAPI(title="Fischer AgentKit MCP Server")
|
||||||
|
|
||||||
|
@app.get("/tools/list")
|
||||||
|
async def list_tools():
|
||||||
|
if self._tool_registry is None:
|
||||||
|
return {"tools": []}
|
||||||
|
tools = self._tool_registry.list_tools()
|
||||||
|
return {
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": t.name,
|
||||||
|
"description": t.description,
|
||||||
|
"inputSchema": t.input_schema or {},
|
||||||
|
}
|
||||||
|
for t in tools
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.post("/tools/call")
|
||||||
|
async def call_tool(request: dict):
|
||||||
|
tool_name = request.get("name")
|
||||||
|
arguments = request.get("arguments", {})
|
||||||
|
|
||||||
|
if not tool_name or self._tool_registry is None:
|
||||||
|
return {"error": "Tool not specified or registry not configured"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool = self._tool_registry.get(tool_name)
|
||||||
|
result = await tool.safe_execute(**arguments)
|
||||||
|
return {"content": [{"type": "text", "text": str(result)}]}
|
||||||
|
except Exception as e:
|
||||||
|
return {"isError": True, "content": [{"type": "text", "text": str(e)}]}
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""启动 MCP Server"""
|
||||||
|
self._app = self._create_app()
|
||||||
|
|
||||||
|
try:
|
||||||
|
import uvicorn
|
||||||
|
config = uvicorn.Config(self._app, host=self._host, port=self._port, log_level="info")
|
||||||
|
server = uvicorn.Server(config)
|
||||||
|
await server.serve()
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("MCP Server requires uvicorn: pip install uvicorn")
|
||||||
|
|
||||||
|
def get_app(self):
|
||||||
|
"""获取 FastAPI 应用实例(用于测试或嵌入)"""
|
||||||
|
if self._app is None:
|
||||||
|
self._app = self._create_app()
|
||||||
|
return self._app
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""AgentKit Memory - 记忆系统"""
|
||||||
|
|
||||||
|
from agentkit.memory.base import Memory, MemoryItem
|
||||||
|
from agentkit.memory.working import WorkingMemory
|
||||||
|
from agentkit.memory.episodic import EpisodicMemory
|
||||||
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Memory",
|
||||||
|
"MemoryItem",
|
||||||
|
"WorkingMemory",
|
||||||
|
"EpisodicMemory",
|
||||||
|
"SemanticMemory",
|
||||||
|
"MemoryRetriever",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
"""Memory 抽象基类 - 统一记忆接口"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MemoryItem:
|
||||||
|
"""记忆条目"""
|
||||||
|
key: str
|
||||||
|
value: Any
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
score: float = 1.0
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"key": self.key,
|
||||||
|
"value": self.value,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
"score": self.score,
|
||||||
|
"created_at": self.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Memory(ABC):
|
||||||
|
"""记忆抽象基类
|
||||||
|
|
||||||
|
三层记忆系统的统一接口:
|
||||||
|
- WorkingMemory: 当前任务上下文(Redis, 短生命周期)
|
||||||
|
- EpisodicMemory: 任务经验(pgvector+PG, 永久)
|
||||||
|
- SemanticMemory: 知识库(RAG+Graph, 永久)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
||||||
|
"""存储记忆"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def retrieve(self, key: str) -> MemoryItem | None:
|
||||||
|
"""按 key 精确检索"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
||||||
|
"""语义检索"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
"""删除记忆"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None:
|
||||||
|
"""批量存储"""
|
||||||
|
for key, value, metadata in items:
|
||||||
|
await self.store(key, value, metadata)
|
||||||
|
|
||||||
|
async def get_context(self, query: str, token_budget: int = 3000) -> str:
|
||||||
|
"""获取格式化的上下文字符串(用于注入 Prompt)"""
|
||||||
|
items = await self.search(query, top_k=10)
|
||||||
|
context_parts = []
|
||||||
|
total_tokens = 0
|
||||||
|
for item in items:
|
||||||
|
text = str(item.value)
|
||||||
|
estimated_tokens = len(text) // 4 # 粗略估算
|
||||||
|
if total_tokens + estimated_tokens > token_budget:
|
||||||
|
break
|
||||||
|
context_parts.append(text)
|
||||||
|
total_tokens += estimated_tokens
|
||||||
|
return "\n".join(context_parts)
|
||||||
|
|
@ -0,0 +1,149 @@
|
||||||
|
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.memory.base import Memory, MemoryItem
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodicMemory(Memory):
|
||||||
|
"""Episodic Memory - 记录每次任务的输入/输出/效果/反思
|
||||||
|
|
||||||
|
基于 pgvector + PostgreSQL 实现,支持语义检索和时间衰减。
|
||||||
|
生命周期:永久(可配置衰减)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_factory: Any,
|
||||||
|
episodic_model: Any,
|
||||||
|
embedder: Any | None = None,
|
||||||
|
decay_rate: float = 0.01,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
session_factory: 返回 async context manager 的工厂
|
||||||
|
episodic_model: EpisodicMemory ORM 模型类
|
||||||
|
embedder: 嵌入器,用于生成向量
|
||||||
|
decay_rate: 时间衰减率(越大衰减越快)
|
||||||
|
"""
|
||||||
|
self._session_factory = session_factory
|
||||||
|
self._episodic_model = episodic_model
|
||||||
|
self._embedder = embedder
|
||||||
|
self._decay_rate = decay_rate
|
||||||
|
|
||||||
|
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
||||||
|
"""存储任务经验"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
Model = self._episodic_model
|
||||||
|
meta = metadata or {}
|
||||||
|
|
||||||
|
# 生成 embedding
|
||||||
|
embedding = None
|
||||||
|
if self._embedder:
|
||||||
|
text = f"{key} {value}"
|
||||||
|
embedding = await self._embedder.embed(text)
|
||||||
|
|
||||||
|
entry = Model(
|
||||||
|
agent_name=meta.get("agent_name", ""),
|
||||||
|
task_type=meta.get("task_type", ""),
|
||||||
|
input_summary=str(value)[:500] if value else "",
|
||||||
|
output_summary=meta.get("output_summary", ""),
|
||||||
|
outcome=meta.get("outcome", "success"),
|
||||||
|
quality_score=meta.get("quality_score", 0.5),
|
||||||
|
reflection=meta.get("reflection", ""),
|
||||||
|
embedding=embedding,
|
||||||
|
)
|
||||||
|
db.add(entry)
|
||||||
|
await db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to store episodic memory: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def retrieve(self, key: str) -> MemoryItem | None:
|
||||||
|
"""按 key 精确检索(Episodic Memory 通常不按 key 检索)"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
||||||
|
"""语义检索相似历史案例"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
Model = self._episodic_model
|
||||||
|
filters = filters or {}
|
||||||
|
|
||||||
|
# 构建查询
|
||||||
|
from sqlalchemy import select, text as sql_text
|
||||||
|
stmt = select(Model)
|
||||||
|
|
||||||
|
if filters.get("agent_name"):
|
||||||
|
stmt = stmt.where(Model.agent_name == filters["agent_name"])
|
||||||
|
if filters.get("task_type"):
|
||||||
|
stmt = stmt.where(Model.task_type == filters["task_type"])
|
||||||
|
if filters.get("outcome"):
|
||||||
|
stmt = stmt.where(Model.outcome == filters["outcome"])
|
||||||
|
|
||||||
|
stmt = stmt.order_by(Model.created_at.desc()).limit(top_k * 2)
|
||||||
|
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
entries = result.scalars().all()
|
||||||
|
|
||||||
|
# 如果有 embedder,进行向量相似度排序
|
||||||
|
if self._embedder and entries:
|
||||||
|
query_embedding = await self._embedder.embed(query)
|
||||||
|
# TODO: 使用 pgvector 的 cosine distance 排序
|
||||||
|
# 目前按时间衰减排序
|
||||||
|
|
||||||
|
# 时间衰减排序
|
||||||
|
items = []
|
||||||
|
for entry in entries:
|
||||||
|
age_hours = (datetime.utcnow() - entry.created_at).total_seconds() / 3600 if entry.created_at else 0
|
||||||
|
decay = math.exp(-self._decay_rate * age_hours)
|
||||||
|
score = (entry.quality_score or 0.5) * decay
|
||||||
|
|
||||||
|
items.append(MemoryItem(
|
||||||
|
key=str(entry.id),
|
||||||
|
value={
|
||||||
|
"input_summary": entry.input_summary,
|
||||||
|
"output_summary": entry.output_summary,
|
||||||
|
"outcome": entry.outcome,
|
||||||
|
"quality_score": entry.quality_score,
|
||||||
|
"reflection": entry.reflection,
|
||||||
|
},
|
||||||
|
metadata={
|
||||||
|
"agent_name": entry.agent_name,
|
||||||
|
"task_type": entry.task_type,
|
||||||
|
"created_at": entry.created_at.isoformat() if entry.created_at else None,
|
||||||
|
},
|
||||||
|
score=score,
|
||||||
|
created_at=entry.created_at or datetime.utcnow(),
|
||||||
|
))
|
||||||
|
|
||||||
|
items.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return items[:top_k]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to search episodic memory: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
"""删除指定经验"""
|
||||||
|
async with self._session_factory() as db:
|
||||||
|
try:
|
||||||
|
from sqlalchemy import select, delete as sql_delete
|
||||||
|
import uuid
|
||||||
|
Model = self._episodic_model
|
||||||
|
|
||||||
|
stmt = sql_delete(Model).where(Model.id == uuid.UUID(key))
|
||||||
|
await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
logger.error(f"Failed to delete episodic memory: {e}")
|
||||||
|
return False
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
"""MemoryRetriever - 混合检索器
|
||||||
|
|
||||||
|
并行查询三层记忆,按权重融合排序。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.memory.base import Memory, MemoryItem
|
||||||
|
from agentkit.memory.working import WorkingMemory
|
||||||
|
from agentkit.memory.episodic import EpisodicMemory
|
||||||
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRetriever:
|
||||||
|
"""混合检索器 - 并行查询三层记忆,按权重融合排序
|
||||||
|
|
||||||
|
检索策略:
|
||||||
|
1. 并行查询 Working/Episodic/Semantic 三层
|
||||||
|
2. 按权重融合排序(默认 Working 0.2, Episodic 0.4, Semantic 0.4)
|
||||||
|
3. 时间衰减:越久远的记忆权重越低
|
||||||
|
4. 上下文窗口管理:总 token 不超过预算
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
working_memory: WorkingMemory | None = None,
|
||||||
|
episodic_memory: EpisodicMemory | None = None,
|
||||||
|
semantic_memory: SemanticMemory | None = None,
|
||||||
|
weights: dict[str, float] | None = None,
|
||||||
|
):
|
||||||
|
self._working = working_memory
|
||||||
|
self._episodic = episodic_memory
|
||||||
|
self._semantic = semantic_memory
|
||||||
|
self._weights = weights or {
|
||||||
|
"working": 0.2,
|
||||||
|
"episodic": 0.4,
|
||||||
|
"semantic": 0.4,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
token_budget: int = 3000,
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
) -> list[MemoryItem]:
|
||||||
|
"""混合检索三层记忆"""
|
||||||
|
tasks = []
|
||||||
|
layer_names = []
|
||||||
|
|
||||||
|
if self._working:
|
||||||
|
tasks.append(self._working.search(query, top_k=top_k, filters=filters))
|
||||||
|
layer_names.append("working")
|
||||||
|
if self._episodic:
|
||||||
|
tasks.append(self._episodic.search(query, top_k=top_k, filters=filters))
|
||||||
|
layer_names.append("episodic")
|
||||||
|
if self._semantic:
|
||||||
|
tasks.append(self._semantic.search(query, top_k=top_k, filters=filters))
|
||||||
|
layer_names.append("semantic")
|
||||||
|
|
||||||
|
if not tasks:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 并行查询
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# 融合排序
|
||||||
|
all_items = []
|
||||||
|
for layer_name, result in zip(layer_names, results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(f"Memory search failed for {layer_name}: {result}")
|
||||||
|
continue
|
||||||
|
weight = self._weights.get(layer_name, 0.3)
|
||||||
|
for item in result:
|
||||||
|
item.score *= weight
|
||||||
|
all_items.append(item)
|
||||||
|
|
||||||
|
# 按分数排序
|
||||||
|
all_items.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
|
||||||
|
# Token 预算管理
|
||||||
|
selected = []
|
||||||
|
total_tokens = 0
|
||||||
|
for item in all_items:
|
||||||
|
text = str(item.value)
|
||||||
|
estimated_tokens = len(text) // 4
|
||||||
|
if total_tokens + estimated_tokens > token_budget:
|
||||||
|
continue
|
||||||
|
selected.append(item)
|
||||||
|
total_tokens += estimated_tokens
|
||||||
|
if len(selected) >= top_k:
|
||||||
|
break
|
||||||
|
|
||||||
|
return selected
|
||||||
|
|
||||||
|
async def get_context_string(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
token_budget: int = 3000,
|
||||||
|
) -> str:
|
||||||
|
"""获取格式化的上下文字符串"""
|
||||||
|
items = await self.retrieve(query, top_k, token_budget)
|
||||||
|
parts = []
|
||||||
|
for item in items:
|
||||||
|
parts.append(str(item.value))
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""Semantic Memory - 知识库适配器
|
||||||
|
|
||||||
|
适配器模式,对接外部 RAG 服务和知识图谱。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.memory.base import Memory, MemoryItem
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticMemory(Memory):
|
||||||
|
"""Semantic Memory - 知识库检索
|
||||||
|
|
||||||
|
通过适配器对接外部 RAG 服务,不直接依赖具体实现。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
rag_service: Any = None,
|
||||||
|
graph_service: Any = None,
|
||||||
|
knowledge_base_ids: list[str] | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
rag_service: RAG 检索服务(需提供 search 方法)
|
||||||
|
graph_service: 知识图谱服务(需提供 query 方法)
|
||||||
|
knowledge_base_ids: 默认检索的知识库 ID 列表
|
||||||
|
"""
|
||||||
|
self._rag_service = rag_service
|
||||||
|
self._graph_service = graph_service
|
||||||
|
self._knowledge_base_ids = knowledge_base_ids or []
|
||||||
|
|
||||||
|
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
||||||
|
"""Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法"""
|
||||||
|
if self._rag_service and hasattr(self._rag_service, 'ingest'):
|
||||||
|
await self._rag_service.ingest(key, value, metadata)
|
||||||
|
else:
|
||||||
|
logger.warning("SemanticMemory.store: no RAG service configured for writing")
|
||||||
|
|
||||||
|
async def retrieve(self, key: str) -> MemoryItem | None:
|
||||||
|
"""按 key 精确检索(Semantic Memory 通常不按 key 检索)"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
||||||
|
"""语义检索知识库"""
|
||||||
|
items = []
|
||||||
|
|
||||||
|
# RAG 检索
|
||||||
|
if self._rag_service:
|
||||||
|
try:
|
||||||
|
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
|
||||||
|
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k)
|
||||||
|
for r in results:
|
||||||
|
items.append(MemoryItem(
|
||||||
|
key=r.get("id", ""),
|
||||||
|
value=r.get("content", ""),
|
||||||
|
metadata={
|
||||||
|
"source": r.get("source", "rag"),
|
||||||
|
"score": r.get("score", 0.0),
|
||||||
|
"document_id": r.get("document_id"),
|
||||||
|
},
|
||||||
|
score=r.get("score", 0.0),
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"RAG search failed: {e}")
|
||||||
|
|
||||||
|
# 知识图谱检索
|
||||||
|
if self._graph_service:
|
||||||
|
try:
|
||||||
|
graph_results = await self._graph_service.query(query, depth=2)
|
||||||
|
for r in graph_results[:top_k]:
|
||||||
|
items.append(MemoryItem(
|
||||||
|
key=r.get("id", ""),
|
||||||
|
value=r.get("content", ""),
|
||||||
|
metadata={
|
||||||
|
"source": "graph",
|
||||||
|
"entities": r.get("entities", []),
|
||||||
|
"relations": r.get("relations", []),
|
||||||
|
},
|
||||||
|
score=r.get("score", 0.0),
|
||||||
|
))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Graph search failed: {e}")
|
||||||
|
|
||||||
|
items.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return items[:top_k]
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
"""Semantic Memory 通常只读"""
|
||||||
|
logger.warning("SemanticMemory.delete: read-only memory")
|
||||||
|
return False
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""Working Memory - 基于 Redis 的短期任务记忆"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
from agentkit.memory.base import Memory, MemoryItem
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkingMemory(Memory):
|
||||||
|
"""Working Memory - 当前任务的上下文和中间状态
|
||||||
|
|
||||||
|
基于 Redis 实现,支持自动过期(TTL)。
|
||||||
|
生命周期:单次任务。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
redis: aioredis.Redis,
|
||||||
|
key_prefix: str = "agentkit:working",
|
||||||
|
default_ttl: int = 3600,
|
||||||
|
):
|
||||||
|
self._redis = redis
|
||||||
|
self._key_prefix = key_prefix
|
||||||
|
self._default_ttl = default_ttl
|
||||||
|
|
||||||
|
def _make_key(self, key: str) -> str:
|
||||||
|
return f"{self._key_prefix}:{key}"
|
||||||
|
|
||||||
|
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
||||||
|
redis_key = self._make_key(key)
|
||||||
|
item = MemoryItem(
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
metadata=metadata or {},
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
await self._redis.setex(
|
||||||
|
redis_key,
|
||||||
|
self._default_ttl,
|
||||||
|
json.dumps(item.to_dict(), default=str),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def retrieve(self, key: str) -> MemoryItem | None:
|
||||||
|
redis_key = self._make_key(key)
|
||||||
|
data = await self._redis.get(redis_key)
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
item_dict = json.loads(data)
|
||||||
|
return MemoryItem(
|
||||||
|
key=item_dict["key"],
|
||||||
|
value=item_dict["value"],
|
||||||
|
metadata=item_dict.get("metadata", {}),
|
||||||
|
score=item_dict.get("score", 1.0),
|
||||||
|
created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.utcnow(),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
||||||
|
"""Working Memory 不支持语义检索,按 key 前缀匹配"""
|
||||||
|
pattern = self._make_key(f"{query}*")
|
||||||
|
keys = []
|
||||||
|
async for key in self._redis.scan_iter(match=pattern, count=top_k):
|
||||||
|
keys.append(key)
|
||||||
|
if len(keys) >= top_k:
|
||||||
|
break
|
||||||
|
|
||||||
|
items = []
|
||||||
|
for key in keys:
|
||||||
|
data = await self._redis.get(key)
|
||||||
|
if data:
|
||||||
|
item_dict = json.loads(data)
|
||||||
|
items.append(MemoryItem(
|
||||||
|
key=item_dict["key"],
|
||||||
|
value=item_dict["value"],
|
||||||
|
metadata=item_dict.get("metadata", {}),
|
||||||
|
score=1.0,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
))
|
||||||
|
return items
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
redis_key = self._make_key(key)
|
||||||
|
return bool(await self._redis.delete(redis_key))
|
||||||
|
|
||||||
|
async def clear(self, prefix: str = "") -> int:
|
||||||
|
"""清除指定前缀的所有 Working Memory"""
|
||||||
|
pattern = self._make_key(f"{prefix}*")
|
||||||
|
count = 0
|
||||||
|
async for key in self._redis.scan_iter(match=pattern):
|
||||||
|
await self._redis.delete(key)
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
"""AgentKit Orchestrator - 多 Agent 协同编排"""
|
||||||
|
|
||||||
|
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus
|
||||||
|
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
||||||
|
from agentkit.orchestrator.pipeline_loader import PipelineLoader
|
||||||
|
from agentkit.orchestrator.handoff import HandoffManager
|
||||||
|
from agentkit.orchestrator.dynamic_pipeline import DynamicPipeline
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Pipeline",
|
||||||
|
"PipelineStage",
|
||||||
|
"StageStatus",
|
||||||
|
"PipelineEngine",
|
||||||
|
"PipelineLoader",
|
||||||
|
"HandoffManager",
|
||||||
|
"DynamicPipeline",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
"""DynamicPipeline - 动态 Pipeline 组合
|
||||||
|
|
||||||
|
支持运行时根据条件选择子流程、嵌套 Pipeline、循环 Pipeline。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
||||||
|
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicPipeline:
|
||||||
|
"""动态 Pipeline 组合器"""
|
||||||
|
|
||||||
|
def __init__(self, engine: PipelineEngine, loader: Any = None):
|
||||||
|
self._engine = engine
|
||||||
|
self._loader = loader
|
||||||
|
|
||||||
|
async def execute_conditional(
|
||||||
|
self,
|
||||||
|
pipelines: dict[str, Pipeline],
|
||||||
|
condition_key: str,
|
||||||
|
context: dict[str, Any] | None = None,
|
||||||
|
) -> PipelineResult:
|
||||||
|
"""根据条件选择子 Pipeline 执行"""
|
||||||
|
context = context or {}
|
||||||
|
condition_value = context.get(condition_key)
|
||||||
|
|
||||||
|
if condition_value not in pipelines:
|
||||||
|
return PipelineResult(
|
||||||
|
pipeline_name=f"conditional_{condition_key}",
|
||||||
|
status=StageStatus.FAILED,
|
||||||
|
error_message=f"No pipeline for condition '{condition_key}={condition_value}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
selected = pipelines[condition_value]
|
||||||
|
logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}")
|
||||||
|
return await self._engine.execute(selected, context)
|
||||||
|
|
||||||
|
async def execute_nested(
|
||||||
|
self,
|
||||||
|
parent: Pipeline,
|
||||||
|
sub_pipeline_map: dict[str, Pipeline],
|
||||||
|
context: dict[str, Any] | None = None,
|
||||||
|
) -> PipelineResult:
|
||||||
|
"""执行嵌套 Pipeline"""
|
||||||
|
# 先执行父 Pipeline
|
||||||
|
parent_result = await self._engine.execute(parent, context)
|
||||||
|
|
||||||
|
# 根据父 Pipeline 结果选择子 Pipeline
|
||||||
|
for stage_name, stage_result in parent_result.stage_results.items():
|
||||||
|
if hasattr(stage_result, 'output_data') and stage_result.output_data:
|
||||||
|
sub_pipeline_name = stage_result.output_data.get("sub_pipeline")
|
||||||
|
if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map:
|
||||||
|
sub = sub_pipeline_map[sub_pipeline_name]
|
||||||
|
sub_result = await self._engine.execute(sub, parent_result.variables)
|
||||||
|
parent_result.variables.update(sub_result.variables)
|
||||||
|
|
||||||
|
return parent_result
|
||||||
|
|
||||||
|
async def execute_loop(
|
||||||
|
self,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
max_iterations: int = 5,
|
||||||
|
exit_condition: str = "done",
|
||||||
|
context: dict[str, Any] | None = None,
|
||||||
|
) -> PipelineResult:
|
||||||
|
"""循环执行 Pipeline 直到条件满足"""
|
||||||
|
current_context = context or {}
|
||||||
|
last_result = None
|
||||||
|
|
||||||
|
for i in range(max_iterations):
|
||||||
|
logger.info(f"DynamicPipeline loop iteration {i + 1}/{max_iterations}")
|
||||||
|
result = await self._engine.execute(pipeline, current_context)
|
||||||
|
last_result = result
|
||||||
|
|
||||||
|
# 检查退出条件
|
||||||
|
if exit_condition in result.variables and result.variables[exit_condition]:
|
||||||
|
logger.info(f"DynamicPipeline loop exited at iteration {i + 1}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 将结果作为下一轮的输入
|
||||||
|
current_context.update(result.variables)
|
||||||
|
|
||||||
|
return last_result or PipelineResult(
|
||||||
|
pipeline_name=pipeline.name,
|
||||||
|
status=StageStatus.FAILED,
|
||||||
|
error_message="Loop completed without meeting exit condition",
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""HandoffManager - Agent 间任务转交"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.core.protocol import HandoffMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HandoffManager:
|
||||||
|
"""Handoff 管理器
|
||||||
|
|
||||||
|
通过 Redis Pub/Sub 管理 Agent 间的任务转交。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, redis: Any = None, dispatcher: Any = None):
|
||||||
|
self._redis = redis
|
||||||
|
self._dispatcher = dispatcher
|
||||||
|
self._handlers: dict[str, list[Any]] = {}
|
||||||
|
|
||||||
|
def register_handler(self, agent_name: str, handler: Any) -> None:
|
||||||
|
"""注册 Handoff 处理器"""
|
||||||
|
if agent_name not in self._handlers:
|
||||||
|
self._handlers[agent_name] = []
|
||||||
|
self._handlers[agent_name].append(handler)
|
||||||
|
|
||||||
|
async def send_handoff(self, handoff: HandoffMessage) -> None:
|
||||||
|
"""发送 Handoff 请求"""
|
||||||
|
if self._redis is None:
|
||||||
|
raise RuntimeError("Handoff requires Redis connection")
|
||||||
|
|
||||||
|
channel = f"agent:{handoff.target_agent}:handoff"
|
||||||
|
await self._redis.publish(channel, json.dumps(handoff.to_dict()))
|
||||||
|
logger.info(
|
||||||
|
f"Handoff sent: {handoff.source_agent} -> {handoff.target_agent} "
|
||||||
|
f"(task={handoff.task_id}, reason={handoff.reason})"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def listen_for_handoffs(self, agent_name: str) -> None:
|
||||||
|
"""监听发往指定 Agent 的 Handoff 请求"""
|
||||||
|
if self._redis is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
channel = f"agent:{agent_name}:handoff"
|
||||||
|
pubsub = self._redis.pubsub()
|
||||||
|
await pubsub.subscribe(channel)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for message in pubsub.listen():
|
||||||
|
if message["type"] == "message":
|
||||||
|
data = json.loads(message["data"])
|
||||||
|
handoff = HandoffMessage.from_dict(data)
|
||||||
|
await self._handle_handoff(handoff)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
await pubsub.unsubscribe(channel)
|
||||||
|
|
||||||
|
async def _handle_handoff(self, handoff: HandoffMessage) -> None:
|
||||||
|
"""处理收到的 Handoff"""
|
||||||
|
handlers = self._handlers.get(handoff.target_agent, [])
|
||||||
|
for handler in handlers:
|
||||||
|
try:
|
||||||
|
await handler(handoff)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Handoff handler error: {e}")
|
||||||
|
|
@ -0,0 +1,241 @@
|
||||||
|
"""Pipeline Engine - DAG + 并行执行"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.orchestrator.pipeline_schema import (
|
||||||
|
Pipeline,
|
||||||
|
PipelineResult,
|
||||||
|
PipelineStage,
|
||||||
|
StageResult,
|
||||||
|
StageStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineEngine:
|
||||||
|
"""Pipeline 执行引擎
|
||||||
|
|
||||||
|
支持:
|
||||||
|
- DAG 拓扑排序
|
||||||
|
- 同层并行执行(asyncio.gather)
|
||||||
|
- 变量解析
|
||||||
|
- 条件执行
|
||||||
|
- 重试
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dispatcher: Any = None):
|
||||||
|
self._dispatcher = dispatcher
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
context: dict[str, Any] | None = None,
|
||||||
|
) -> PipelineResult:
|
||||||
|
"""执行 Pipeline"""
|
||||||
|
result = PipelineResult(pipeline_name=pipeline.name)
|
||||||
|
result.variables = {**pipeline.variables, **(context or {})}
|
||||||
|
|
||||||
|
# 拓扑排序 + 按依赖层级分组
|
||||||
|
try:
|
||||||
|
level_groups = self._topological_group(pipeline.stages)
|
||||||
|
except ValueError as e:
|
||||||
|
result.status = StageStatus.FAILED
|
||||||
|
result.error_message = str(e)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 逐层执行
|
||||||
|
for level, stages in enumerate(level_groups):
|
||||||
|
logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")
|
||||||
|
|
||||||
|
# 并行执行同层 stages
|
||||||
|
tasks = []
|
||||||
|
for stage in stages:
|
||||||
|
tasks.append(self._execute_stage(stage, result))
|
||||||
|
|
||||||
|
stage_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# 处理结果
|
||||||
|
for stage, sr in zip(stages, stage_results):
|
||||||
|
if isinstance(sr, Exception):
|
||||||
|
sr = StageResult(
|
||||||
|
stage_name=stage.name,
|
||||||
|
status=StageStatus.FAILED,
|
||||||
|
error_message=str(sr),
|
||||||
|
)
|
||||||
|
result.stage_results[stage.name] = sr
|
||||||
|
|
||||||
|
# 收集输出变量
|
||||||
|
if sr.output_data and isinstance(sr, dict):
|
||||||
|
pass
|
||||||
|
elif hasattr(sr, 'output_data') and sr.output_data:
|
||||||
|
for output_key in stage.outputs:
|
||||||
|
if output_key in sr.output_data:
|
||||||
|
result.variables[output_key] = sr.output_data[output_key]
|
||||||
|
|
||||||
|
# 检查是否需要中止
|
||||||
|
if hasattr(sr, 'status') and sr.status == StageStatus.FAILED:
|
||||||
|
if not stage.continue_on_failure:
|
||||||
|
result.status = StageStatus.FAILED
|
||||||
|
result.error_message = f"Stage '{stage.name}' failed"
|
||||||
|
return result
|
||||||
|
|
||||||
|
result.status = StageStatus.COMPLETED
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _execute_stage(
|
||||||
|
self,
|
||||||
|
stage: PipelineStage,
|
||||||
|
pipeline_result: PipelineResult,
|
||||||
|
) -> StageResult:
|
||||||
|
"""执行单个 stage"""
|
||||||
|
started_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
# 条件检查
|
||||||
|
if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables):
|
||||||
|
return StageResult(
|
||||||
|
stage_name=stage.name,
|
||||||
|
status=StageStatus.SKIPPED,
|
||||||
|
started_at=started_at,
|
||||||
|
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析输入变量
|
||||||
|
resolved_inputs = self._resolve_variables(stage.inputs, pipeline_result.variables)
|
||||||
|
|
||||||
|
# 执行
|
||||||
|
if self._dispatcher is None:
|
||||||
|
# Dry-run 模式
|
||||||
|
return StageResult(
|
||||||
|
stage_name=stage.name,
|
||||||
|
status=StageStatus.COMPLETED,
|
||||||
|
output_data={"dry_run": True, "inputs": resolved_inputs},
|
||||||
|
started_at=started_at,
|
||||||
|
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 通过 Dispatcher 分发任务
|
||||||
|
from agentkit.core.protocol import TaskMessage
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
agent_name=stage.agent,
|
||||||
|
task_type=stage.action,
|
||||||
|
priority=0,
|
||||||
|
input_data=resolved_inputs,
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
timeout_seconds=stage.timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._dispatcher.dispatch(task)
|
||||||
|
|
||||||
|
# 等待结果
|
||||||
|
for _ in range(stage.timeout_seconds):
|
||||||
|
status = await self._dispatcher.get_task_status(task.task_id)
|
||||||
|
if status["status"] in ("completed", "failed", "cancelled"):
|
||||||
|
return StageResult(
|
||||||
|
stage_name=stage.name,
|
||||||
|
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
|
||||||
|
output_data=status.get("output_data"),
|
||||||
|
error_message=status.get("error_message"),
|
||||||
|
started_at=started_at,
|
||||||
|
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||||
|
)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
return StageResult(
|
||||||
|
stage_name=stage.name,
|
||||||
|
status=StageStatus.FAILED,
|
||||||
|
error_message=f"Timeout after {stage.timeout_seconds}s",
|
||||||
|
started_at=started_at,
|
||||||
|
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return StageResult(
|
||||||
|
stage_name=stage.name,
|
||||||
|
status=StageStatus.FAILED,
|
||||||
|
error_message=str(e),
|
||||||
|
started_at=started_at,
|
||||||
|
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _topological_group(stages: list[PipelineStage]) -> list[list[PipelineStage]]:
|
||||||
|
"""拓扑排序 + 按依赖层级分组"""
|
||||||
|
stage_map = {s.name: s for s in stages}
|
||||||
|
in_degree = defaultdict(int)
|
||||||
|
dependents = defaultdict(list)
|
||||||
|
|
||||||
|
for s in stages:
|
||||||
|
if s.name not in in_degree:
|
||||||
|
in_degree[s.name] = 0
|
||||||
|
for dep in s.depends_on:
|
||||||
|
if dep not in stage_map:
|
||||||
|
raise ValueError(f"Stage '{s.name}' depends on unknown stage '{dep}'")
|
||||||
|
in_degree[s.name] += 1
|
||||||
|
dependents[dep].append(s.name)
|
||||||
|
|
||||||
|
levels = []
|
||||||
|
remaining = set(in_degree.keys())
|
||||||
|
|
||||||
|
while remaining:
|
||||||
|
# 找到入度为 0 的节点
|
||||||
|
current_level = [name for name in remaining if in_degree[name] == 0]
|
||||||
|
if not current_level:
|
||||||
|
raise ValueError("Circular dependency detected in pipeline")
|
||||||
|
|
||||||
|
levels.append([stage_map[name] for name in current_level])
|
||||||
|
|
||||||
|
for name in current_level:
|
||||||
|
remaining.remove(name)
|
||||||
|
for dep in dependents[name]:
|
||||||
|
in_degree[dep] -= 1
|
||||||
|
|
||||||
|
return levels
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_variables(template: dict, context: dict) -> dict:
|
||||||
|
"""解析 ${var.path} 变量引用"""
|
||||||
|
resolved = {}
|
||||||
|
for key, value in template.items():
|
||||||
|
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||||
|
var_path = value[2:-1]
|
||||||
|
resolved[key] = PipelineEngine._get_nested(context, var_path)
|
||||||
|
else:
|
||||||
|
resolved[key] = value
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_nested(data: dict, path: str) -> Any:
|
||||||
|
keys = path.split(".")
|
||||||
|
current = data
|
||||||
|
for key in keys:
|
||||||
|
if isinstance(current, dict):
|
||||||
|
current = current.get(key)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
return current
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _evaluate_condition(condition: str, variables: dict) -> bool:
|
||||||
|
"""简单条件评估"""
|
||||||
|
if "==" in condition:
|
||||||
|
parts = condition.split("==", 1)
|
||||||
|
left = variables.get(parts[0].strip(), parts[0].strip())
|
||||||
|
right = parts[1].strip().strip("'\"")
|
||||||
|
return str(left) == right
|
||||||
|
elif "!=" in condition:
|
||||||
|
parts = condition.split("!=", 1)
|
||||||
|
left = variables.get(parts[0].strip(), parts[0].strip())
|
||||||
|
right = parts[1].strip().strip("'\"")
|
||||||
|
return str(left) != right
|
||||||
|
else:
|
||||||
|
return bool(variables.get(condition))
|
||||||
|
|
@ -0,0 +1,72 @@
|
||||||
|
"""Pipeline Loader - YAML 加载器"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineLoader:
|
||||||
|
"""Pipeline YAML 加载器"""
|
||||||
|
|
||||||
|
def __init__(self, pipelines_dir: str | Path | None = None):
|
||||||
|
self._pipelines_dir = Path(pipelines_dir) if pipelines_dir else Path("pipelines")
|
||||||
|
|
||||||
|
def load(self, pipeline_name: str) -> Pipeline:
|
||||||
|
"""从 YAML 文件加载 Pipeline"""
|
||||||
|
yaml_path = self._pipelines_dir / f"{pipeline_name}.yaml"
|
||||||
|
if not yaml_path.exists():
|
||||||
|
yaml_path = self._pipelines_dir / f"{pipeline_name}.yml"
|
||||||
|
if not yaml_path.exists():
|
||||||
|
raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}")
|
||||||
|
|
||||||
|
content = yaml_path.read_text(encoding="utf-8")
|
||||||
|
return self.load_from_yaml(content, pipeline_name)
|
||||||
|
|
||||||
|
def load_from_yaml(self, yaml_content: str, pipeline_name: str | None = None) -> Pipeline:
|
||||||
|
"""从 YAML 字符串加载 Pipeline"""
|
||||||
|
data = yaml.safe_load(yaml_content)
|
||||||
|
|
||||||
|
stages = []
|
||||||
|
for stage_data in data.get("stages", []):
|
||||||
|
stages.append(PipelineStage(**stage_data))
|
||||||
|
|
||||||
|
return Pipeline(
|
||||||
|
name=data.get("name", pipeline_name or "unnamed"),
|
||||||
|
version=data.get("version", "1.0.0"),
|
||||||
|
description=data.get("description", ""),
|
||||||
|
stages=stages,
|
||||||
|
variables=data.get("variables", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_dag(stages: list[PipelineStage]) -> bool:
|
||||||
|
"""验证 DAG 无环"""
|
||||||
|
stage_names = {s.name for s in stages}
|
||||||
|
visited = set()
|
||||||
|
path = set()
|
||||||
|
|
||||||
|
def dfs(name: str) -> bool:
|
||||||
|
if name in path:
|
||||||
|
return False
|
||||||
|
if name in visited:
|
||||||
|
return True
|
||||||
|
|
||||||
|
path.add(name)
|
||||||
|
stage = next((s for s in stages if s.name == name), None)
|
||||||
|
if stage:
|
||||||
|
for dep in stage.depends_on:
|
||||||
|
if dep not in stage_names:
|
||||||
|
return False
|
||||||
|
if not dfs(dep):
|
||||||
|
return False
|
||||||
|
path.remove(name)
|
||||||
|
visited.add(name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return all(dfs(name) for name in stage_names)
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
"""Pipeline 数据模型"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class StageStatus(str, Enum):
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
SKIPPED = "skipped"
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineStage(BaseModel):
|
||||||
|
name: str
|
||||||
|
agent: str
|
||||||
|
action: str
|
||||||
|
depends_on: list[str] = []
|
||||||
|
inputs: dict[str, Any] = {}
|
||||||
|
outputs: list[str] = []
|
||||||
|
timeout_seconds: int = 300
|
||||||
|
retry_count: int = 0
|
||||||
|
continue_on_failure: bool = False
|
||||||
|
condition: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Pipeline(BaseModel):
|
||||||
|
name: str
|
||||||
|
version: str
|
||||||
|
description: str
|
||||||
|
stages: list[PipelineStage]
|
||||||
|
variables: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class StageResult(BaseModel):
|
||||||
|
stage_name: str
|
||||||
|
status: StageStatus = StageStatus.PENDING
|
||||||
|
output_data: dict[str, Any] | None = None
|
||||||
|
error_message: str | None = None
|
||||||
|
started_at: str | None = None
|
||||||
|
completed_at: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineResult(BaseModel):
|
||||||
|
pipeline_name: str
|
||||||
|
status: StageStatus = StageStatus.PENDING
|
||||||
|
stage_results: dict[str, StageResult] = {}
|
||||||
|
variables: dict[str, Any] = {}
|
||||||
|
error_message: str | None = None
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
"""AgentKit Prompts - Prompt 模板系统"""
|
||||||
|
|
||||||
|
from agentkit.prompts.template import PromptTemplate
|
||||||
|
from agentkit.prompts.section import PromptSection
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PromptTemplate",
|
||||||
|
"PromptSection",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""PromptSection - 模块化 Prompt 段落"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptSection:
|
||||||
|
"""Prompt 段落定义
|
||||||
|
|
||||||
|
将 Prompt 分为 5 个标准段落,支持变量注入和 Token 预算管理。
|
||||||
|
"""
|
||||||
|
identity: str = ""
|
||||||
|
context: str = ""
|
||||||
|
instructions: str = ""
|
||||||
|
constraints: str = ""
|
||||||
|
output_format: str = ""
|
||||||
|
examples: str = ""
|
||||||
|
|
||||||
|
def render(self, variables: dict | None = None) -> str:
|
||||||
|
"""渲染段落,替换变量"""
|
||||||
|
text = "\n\n".join(
|
||||||
|
part for part in [
|
||||||
|
self.identity,
|
||||||
|
self.context,
|
||||||
|
self.instructions,
|
||||||
|
self.constraints,
|
||||||
|
self.output_format,
|
||||||
|
self.examples,
|
||||||
|
] if part
|
||||||
|
)
|
||||||
|
|
||||||
|
if variables:
|
||||||
|
for key, value in variables.items():
|
||||||
|
text = text.replace(f"${{{key}}}", str(value))
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
"""PromptTemplate - Prompt 模板渲染"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.prompts.section import PromptSection
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplate:
|
||||||
|
"""Prompt 模板
|
||||||
|
|
||||||
|
支持变量注入、Token 预算管理、动态段落组合。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sections: PromptSection,
|
||||||
|
name: str = "",
|
||||||
|
version: str = "1.0.0",
|
||||||
|
):
|
||||||
|
self._sections = sections
|
||||||
|
self.name = name
|
||||||
|
self.version = version
|
||||||
|
|
||||||
|
def render(
|
||||||
|
self,
|
||||||
|
variables: dict[str, Any] | None = None,
|
||||||
|
context_budget: int = 3000,
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""渲染 Prompt 为消息列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]
|
||||||
|
"""
|
||||||
|
system_parts = []
|
||||||
|
if self._sections.identity:
|
||||||
|
system_parts.append(self._sections.identity)
|
||||||
|
if self._sections.context:
|
||||||
|
context = self._sections.context
|
||||||
|
if variables:
|
||||||
|
for key, value in variables.items():
|
||||||
|
context = context.replace(f"${{{key}}}", str(value))
|
||||||
|
system_parts.append(context)
|
||||||
|
if self._sections.constraints:
|
||||||
|
system_parts.append(self._sections.constraints)
|
||||||
|
|
||||||
|
user_parts = []
|
||||||
|
if self._sections.instructions:
|
||||||
|
instructions = self._sections.instructions
|
||||||
|
if variables:
|
||||||
|
for key, value in variables.items():
|
||||||
|
instructions = instructions.replace(f"${{{key}}}", str(value))
|
||||||
|
user_parts.append(instructions)
|
||||||
|
if self._sections.output_format:
|
||||||
|
user_parts.append(self._sections.output_format)
|
||||||
|
if self._sections.examples:
|
||||||
|
user_parts.append(self._sections.examples)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if system_parts:
|
||||||
|
messages.append({"role": "system", "content": "\n\n".join(system_parts)})
|
||||||
|
if user_parts:
|
||||||
|
messages.append({"role": "user", "content": "\n\n".join(user_parts)})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sections(self) -> PromptSection:
|
||||||
|
return self._sections
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
"""AgentKit Tools - 工具插件系统"""
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
from agentkit.tools.function_tool import FunctionTool
|
||||||
|
from agentkit.tools.agent_tool import AgentTool
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Tool",
|
||||||
|
"FunctionTool",
|
||||||
|
"AgentTool",
|
||||||
|
"ToolRegistry",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
"""AgentTool - 将 Agent 包装为 Tool"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTool(Tool):
|
||||||
|
"""将另一个 Agent 包装为 Tool
|
||||||
|
|
||||||
|
通过 Dispatcher 分发任务到目标 Agent,等待结果返回。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
agent_name: str,
|
||||||
|
task_type: str,
|
||||||
|
input_mapping: dict[str, str] | None = None,
|
||||||
|
output_mapping: dict[str, str] | None = None,
|
||||||
|
timeout_seconds: int = 300,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["agent"],
|
||||||
|
)
|
||||||
|
self.agent_name = agent_name
|
||||||
|
self.task_type = task_type
|
||||||
|
self.input_mapping = input_mapping or {}
|
||||||
|
self.output_mapping = output_mapping or {}
|
||||||
|
self.timeout_seconds = timeout_seconds
|
||||||
|
self._dispatcher = None
|
||||||
|
|
||||||
|
def set_dispatcher(self, dispatcher: Any) -> "AgentTool":
|
||||||
|
"""注入 Dispatcher"""
|
||||||
|
self._dispatcher = dispatcher
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
if self._dispatcher is None:
|
||||||
|
raise RuntimeError(f"AgentTool '{self.name}' has no dispatcher configured")
|
||||||
|
|
||||||
|
from agentkit.core.protocol import TaskMessage
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
# 映射输入
|
||||||
|
mapped_input = {}
|
||||||
|
for target_key, source_key in self.input_mapping.items():
|
||||||
|
if source_key in kwargs:
|
||||||
|
mapped_input[target_key] = kwargs[source_key]
|
||||||
|
if not mapped_input:
|
||||||
|
mapped_input = kwargs
|
||||||
|
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
agent_name=self.agent_name,
|
||||||
|
task_type=self.task_type,
|
||||||
|
priority=0,
|
||||||
|
input_data=mapped_input,
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
timeout_seconds=self.timeout_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._dispatcher.dispatch(task)
|
||||||
|
|
||||||
|
# 等待结果
|
||||||
|
import asyncio
|
||||||
|
for _ in range(self.timeout_seconds):
|
||||||
|
status = await self._dispatcher.get_task_status(task.task_id)
|
||||||
|
if status["status"] in ("completed", "failed", "cancelled"):
|
||||||
|
if status["status"] == "completed" and status.get("output_data"):
|
||||||
|
output = status["output_data"]
|
||||||
|
# 映射输出
|
||||||
|
if self.output_mapping:
|
||||||
|
mapped_output = {}
|
||||||
|
for target_key, source_key in self.output_mapping.items():
|
||||||
|
if source_key in output:
|
||||||
|
mapped_output[target_key] = output[source_key]
|
||||||
|
return mapped_output
|
||||||
|
return output
|
||||||
|
elif status["status"] == "failed":
|
||||||
|
raise RuntimeError(f"Agent '{self.agent_name}' failed: {status.get('error_message')}")
|
||||||
|
return {}
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
raise TimeoutError(f"Agent '{self.agent_name}' timed out after {self.timeout_seconds}s")
|
||||||
|
|
@ -0,0 +1,68 @@
|
||||||
|
"""Tool 抽象基类 - 统一工具接口"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class Tool(ABC):
|
||||||
|
"""工具抽象基类
|
||||||
|
|
||||||
|
所有工具(FunctionTool, AgentTool, MCPTool)的统一接口。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.input_schema = input_schema
|
||||||
|
self.output_schema = output_schema
|
||||||
|
self.version = version
|
||||||
|
self.tags = tags or []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
"""执行工具,返回结果 dict"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def before_execute(self, **kwargs) -> None:
|
||||||
|
"""执行前钩子"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def after_execute(self, result: dict, **kwargs) -> None:
|
||||||
|
"""执行后钩子"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def on_error(self, error: Exception, **kwargs) -> None:
|
||||||
|
"""错误钩子"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def safe_execute(self, **kwargs) -> dict:
|
||||||
|
"""带钩子的安全执行"""
|
||||||
|
try:
|
||||||
|
await self.before_execute(**kwargs)
|
||||||
|
result = await self.execute(**kwargs)
|
||||||
|
await self.after_execute(result, **kwargs)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
await self.on_error(e, **kwargs)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"input_schema": self.input_schema,
|
||||||
|
"output_schema": self.output_schema,
|
||||||
|
"version": self.version,
|
||||||
|
"tags": self.tags,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<{type(self).__name__} name={self.name!r} version={self.version}>"
|
||||||
|
|
@ -0,0 +1,76 @@
|
||||||
|
"""FunctionTool - 将普通 Python 函数包装为 Tool"""
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import Any, Callable, Awaitable
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionTool(Tool):
|
||||||
|
"""将普通 Python 函数包装为 Tool
|
||||||
|
|
||||||
|
自动从函数签名推断 input_schema。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
func: Callable[..., Awaitable[dict]] | Callable[..., dict],
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
input_schema=input_schema or self._infer_schema(func),
|
||||||
|
output_schema=output_schema,
|
||||||
|
version=version,
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
self._func = func
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
result = self._func(**kwargs)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
result = await result
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
result = {"result": result}
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _infer_schema(func: Callable) -> dict:
|
||||||
|
"""从函数签名推断 JSON Schema"""
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
properties = {}
|
||||||
|
required = []
|
||||||
|
|
||||||
|
for param_name, param in sig.parameters.items():
|
||||||
|
if param_name in ("self", "cls"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param_type = "string"
|
||||||
|
if param.annotation != inspect.Parameter.empty:
|
||||||
|
if param.annotation in (int, float):
|
||||||
|
param_type = "number"
|
||||||
|
elif param.annotation == bool:
|
||||||
|
param_type = "boolean"
|
||||||
|
elif param.annotation in (list, tuple):
|
||||||
|
param_type = "array"
|
||||||
|
elif param.annotation == dict:
|
||||||
|
param_type = "object"
|
||||||
|
|
||||||
|
properties[param_name] = {"type": param_type}
|
||||||
|
|
||||||
|
if param.default == inspect.Parameter.empty:
|
||||||
|
required.append(param_name)
|
||||||
|
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": properties,
|
||||||
|
}
|
||||||
|
if required:
|
||||||
|
schema["required"] = required
|
||||||
|
return schema
|
||||||
|
|
@ -0,0 +1,72 @@
|
||||||
|
"""ToolRegistry - 工具注册中心"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.core.exceptions import ToolNotFoundError
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
"""工具注册中心,管理工具的注册、发现、版本"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._tools: dict[str, dict[str, Tool]] = {} # name -> {version -> tool}
|
||||||
|
|
||||||
|
def register(self, tool: Tool) -> "ToolRegistry":
|
||||||
|
"""注册工具"""
|
||||||
|
if tool.name not in self._tools:
|
||||||
|
self._tools[tool.name] = {}
|
||||||
|
self._tools[tool.name][tool.version] = tool
|
||||||
|
logger.info(f"Tool '{tool.name}' v{tool.version} registered")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def unregister(self, name: str, version: str | None = None) -> None:
|
||||||
|
"""注销工具"""
|
||||||
|
if name not in self._tools:
|
||||||
|
return
|
||||||
|
if version:
|
||||||
|
self._tools[name].pop(version, None)
|
||||||
|
if not self._tools[name]:
|
||||||
|
del self._tools[name]
|
||||||
|
else:
|
||||||
|
del self._tools[name]
|
||||||
|
|
||||||
|
def get(self, name: str, version: str | None = None) -> Tool:
|
||||||
|
"""获取工具(默认返回最新版本)"""
|
||||||
|
if name not in self._tools:
|
||||||
|
raise ToolNotFoundError(name)
|
||||||
|
|
||||||
|
versions = self._tools[name]
|
||||||
|
if version:
|
||||||
|
if version not in versions:
|
||||||
|
raise ToolNotFoundError(f"{name}@{version}")
|
||||||
|
return versions[version]
|
||||||
|
|
||||||
|
# 返回最新版本
|
||||||
|
latest = sorted(versions.keys())[-1]
|
||||||
|
return versions[latest]
|
||||||
|
|
||||||
|
def list_tools(self, tag: str | None = None) -> list[Tool]:
|
||||||
|
"""列出所有工具(最新版本),可按标签过滤"""
|
||||||
|
result = []
|
||||||
|
for name, versions in self._tools.items():
|
||||||
|
latest = sorted(versions.keys())[-1]
|
||||||
|
tool = versions[latest]
|
||||||
|
if tag is None or tag in tool.tags:
|
||||||
|
result.append(tool)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def list_all_versions(self, name: str) -> list[Tool]:
|
||||||
|
"""列出指定工具的所有版本"""
|
||||||
|
if name not in self._tools:
|
||||||
|
return []
|
||||||
|
return list(self._tools[name].values())
|
||||||
|
|
||||||
|
def has_tool(self, name: str) -> bool:
|
||||||
|
return name in self._tools
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._tools.clear()
|
||||||
|
|
@ -0,0 +1,139 @@
|
||||||
|
"""Tests for BaseAgent - 统一生命周期"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.base import BaseAgent
|
||||||
|
from agentkit.core.protocol import (
|
||||||
|
AgentCapability,
|
||||||
|
AgentStatus,
|
||||||
|
TaskMessage,
|
||||||
|
TaskResult,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleAgent(BaseAgent):
|
||||||
|
"""测试用简单 Agent"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(name="test_agent", agent_type="test", version="1.0.0")
|
||||||
|
self.task_started = False
|
||||||
|
self.task_completed = False
|
||||||
|
self.task_failed = False
|
||||||
|
|
||||||
|
async def handle_task(self, task: TaskMessage) -> dict:
|
||||||
|
if task.task_type == "echo":
|
||||||
|
return {"echo": task.input_data}
|
||||||
|
elif task.task_type == "fail":
|
||||||
|
raise ValueError("intentional failure")
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
def get_capabilities(self) -> AgentCapability:
|
||||||
|
return AgentCapability(
|
||||||
|
agent_name=self.name,
|
||||||
|
agent_type=self.agent_type,
|
||||||
|
version=self.version,
|
||||||
|
supported_tasks=["echo", "fail"],
|
||||||
|
max_concurrency=2,
|
||||||
|
description="Test agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_task_start(self, task):
|
||||||
|
self.task_started = True
|
||||||
|
|
||||||
|
async def on_task_complete(self, task, output):
|
||||||
|
self.task_completed = True
|
||||||
|
|
||||||
|
async def on_task_failed(self, task, error):
|
||||||
|
self.task_failed = True
|
||||||
|
|
||||||
|
|
||||||
|
def _make_task(task_type: str = "echo", input_data: dict | None = None) -> TaskMessage:
|
||||||
|
return TaskMessage(
|
||||||
|
task_id="test-001",
|
||||||
|
agent_name="test_agent",
|
||||||
|
task_type=task_type,
|
||||||
|
priority=0,
|
||||||
|
input_data=input_data or {},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_task_returns_output():
|
||||||
|
agent = SimpleAgent()
|
||||||
|
task = _make_task("echo", {"msg": "hello"})
|
||||||
|
result = await agent.execute(task)
|
||||||
|
|
||||||
|
assert result.status == TaskStatus.COMPLETED
|
||||||
|
assert result.output_data == {"echo": {"msg": "hello"}}
|
||||||
|
assert result.error_message is None
|
||||||
|
assert result.metrics["task_type"] == "echo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_task_failure():
|
||||||
|
agent = SimpleAgent()
|
||||||
|
task = _make_task("fail")
|
||||||
|
result = await agent.execute(task)
|
||||||
|
|
||||||
|
assert result.status == TaskStatus.FAILED
|
||||||
|
assert result.error_message == "intentional failure"
|
||||||
|
assert result.metrics["error_type"] == "ValueError"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lifecycle_hooks():
|
||||||
|
agent = SimpleAgent()
|
||||||
|
|
||||||
|
# 成功任务
|
||||||
|
task = _make_task("echo")
|
||||||
|
await agent.execute(task)
|
||||||
|
assert agent.task_started is True
|
||||||
|
assert agent.task_completed is True
|
||||||
|
|
||||||
|
# 重置
|
||||||
|
agent.task_started = False
|
||||||
|
agent.task_completed = False
|
||||||
|
|
||||||
|
# 失败任务
|
||||||
|
task = _make_task("fail")
|
||||||
|
await agent.execute(task)
|
||||||
|
assert agent.task_started is True
|
||||||
|
assert agent.task_failed is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_wraps_timing():
|
||||||
|
agent = SimpleAgent()
|
||||||
|
task = _make_task("echo")
|
||||||
|
result = await agent.execute(task)
|
||||||
|
|
||||||
|
assert result.started_at is not None
|
||||||
|
assert result.completed_at is not None
|
||||||
|
assert result.metrics["elapsed_seconds"] >= 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_status():
|
||||||
|
agent = SimpleAgent()
|
||||||
|
assert agent.status == AgentStatus.OFFLINE
|
||||||
|
assert agent.is_distributed is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_injection():
|
||||||
|
from agentkit.tools.function_tool import FunctionTool
|
||||||
|
|
||||||
|
async def my_tool(x: int) -> dict:
|
||||||
|
return {"doubled": x * 2}
|
||||||
|
|
||||||
|
tool = FunctionTool(name="doubler", description="Doubles a number", func=my_tool)
|
||||||
|
agent = SimpleAgent()
|
||||||
|
agent.use_tool(tool)
|
||||||
|
|
||||||
|
assert len(agent.tools) == 1
|
||||||
|
assert agent.tools[0].name == "doubler"
|
||||||
|
|
@ -0,0 +1,131 @@
|
||||||
|
"""Tests for Evolution system"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.evolution.reflector import Reflector, Reflection
|
||||||
|
from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Module
|
||||||
|
from agentkit.evolution.strategy_tuner import StrategyTuner, StrategyConfig
|
||||||
|
from agentkit.evolution.ab_tester import ABTester, ABTestConfig
|
||||||
|
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
def _make_task() -> TaskMessage:
|
||||||
|
return TaskMessage(
|
||||||
|
task_id="test-001",
|
||||||
|
agent_name="test",
|
||||||
|
task_type="echo",
|
||||||
|
priority=0,
|
||||||
|
input_data={},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_result(status: str = TaskStatus.COMPLETED) -> TaskResult:
|
||||||
|
return TaskResult(
|
||||||
|
task_id="test-001",
|
||||||
|
agent_name="test",
|
||||||
|
status=status,
|
||||||
|
output_data={"key": "value"},
|
||||||
|
error_message=None,
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
completed_at=datetime.now(timezone.utc),
|
||||||
|
metrics={"elapsed_seconds": 5.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reflector_success():
|
||||||
|
reflector = Reflector()
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result()
|
||||||
|
|
||||||
|
reflection = await reflector.reflect(task, result)
|
||||||
|
assert reflection.outcome == "success"
|
||||||
|
assert reflection.quality_score > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reflector_failure():
|
||||||
|
reflector = Reflector()
|
||||||
|
task = _make_task()
|
||||||
|
result = _make_result(TaskStatus.FAILED)
|
||||||
|
result.error_message = "something went wrong"
|
||||||
|
|
||||||
|
reflection = await reflector.reflect(task, result)
|
||||||
|
assert reflection.outcome == "failure"
|
||||||
|
assert reflection.quality_score == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_optimizer():
|
||||||
|
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=2)
|
||||||
|
|
||||||
|
# Add examples
|
||||||
|
for i in range(5):
|
||||||
|
optimizer.add_example(
|
||||||
|
input_data={"query": f"query_{i}"},
|
||||||
|
output_data={"result": f"result_{i}"},
|
||||||
|
quality_score=0.8 + i * 0.02,
|
||||||
|
)
|
||||||
|
|
||||||
|
module = Module(
|
||||||
|
name="test_module",
|
||||||
|
signature=Signature(
|
||||||
|
input_fields={"query": "search query"},
|
||||||
|
output_fields={"result": "search result"},
|
||||||
|
instruction="Find the best result.",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
optimized = await optimizer.optimize(module)
|
||||||
|
assert optimized.name == "test_module_optimized"
|
||||||
|
assert len(optimized.demos) == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_optimizer_not_enough_examples():
|
||||||
|
optimizer = PromptOptimizer(min_examples_for_optimization=10)
|
||||||
|
module = Module(
|
||||||
|
name="test",
|
||||||
|
signature=Signature(
|
||||||
|
input_fields={"x": "input"},
|
||||||
|
output_fields={"y": "output"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
optimized = await optimizer.optimize(module)
|
||||||
|
# Should return unchanged module
|
||||||
|
assert optimized.name == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_strategy_tuner():
|
||||||
|
tuner = StrategyTuner()
|
||||||
|
|
||||||
|
config = StrategyConfig(temperature=0.5)
|
||||||
|
tuner.record(config, metric=0.6)
|
||||||
|
tuner.record(StrategyConfig(temperature=0.7), metric=0.8)
|
||||||
|
tuner.record(StrategyConfig(temperature=0.3), metric=0.4)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ab_tester():
|
||||||
|
tester = ABTester()
|
||||||
|
test_config = ABTestConfig(
|
||||||
|
test_id="test-1",
|
||||||
|
agent_name="test_agent",
|
||||||
|
change_type="prompt",
|
||||||
|
min_samples=5,
|
||||||
|
)
|
||||||
|
tester.create_test(test_config)
|
||||||
|
|
||||||
|
# Record results
|
||||||
|
for _ in range(10):
|
||||||
|
group = tester.assign_group("test-1")
|
||||||
|
metric = 0.7 if group == "experiment" else 0.5
|
||||||
|
tester.record_result("test-1", group, metric)
|
||||||
|
|
||||||
|
result = await tester.evaluate("test-1")
|
||||||
|
assert result is not None
|
||||||
|
assert result.control_samples + result.experiment_samples == 10
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""Tests for Pipeline system"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage
|
||||||
|
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
||||||
|
from agentkit.orchestrator.pipeline_loader import PipelineLoader
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_schema():
|
||||||
|
stage = PipelineStage(
|
||||||
|
name="step1",
|
||||||
|
agent="agent_a",
|
||||||
|
action="process",
|
||||||
|
depends_on=[],
|
||||||
|
inputs={"data": "${input}"},
|
||||||
|
)
|
||||||
|
pipeline = Pipeline(
|
||||||
|
name="test_pipeline",
|
||||||
|
version="1.0.0",
|
||||||
|
description="Test",
|
||||||
|
stages=[stage],
|
||||||
|
variables={"input": "hello"},
|
||||||
|
)
|
||||||
|
assert len(pipeline.stages) == 1
|
||||||
|
assert pipeline.variables["input"] == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_topological_group():
|
||||||
|
stages = [
|
||||||
|
PipelineStage(name="a", agent="agent1", action="do_a", depends_on=[]),
|
||||||
|
PipelineStage(name="b", agent="agent2", action="do_b", depends_on=["a"]),
|
||||||
|
PipelineStage(name="c", agent="agent3", action="do_c", depends_on=["a"]),
|
||||||
|
PipelineStage(name="d", agent="agent4", action="do_d", depends_on=["b", "c"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
groups = PipelineEngine._topological_group(stages)
|
||||||
|
assert len(groups) == 3 # [a], [b, c], [d]
|
||||||
|
assert groups[0][0].name == "a"
|
||||||
|
assert {s.name for s in groups[1]} == {"b", "c"}
|
||||||
|
assert groups[2][0].name == "d"
|
||||||
|
|
||||||
|
|
||||||
|
def test_topological_group_circular():
|
||||||
|
stages = [
|
||||||
|
PipelineStage(name="a", agent="agent1", action="do_a", depends_on=["b"]),
|
||||||
|
PipelineStage(name="b", agent="agent2", action="do_b", depends_on=["a"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Circular dependency"):
|
||||||
|
PipelineEngine._topological_group(stages)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_dry_run():
|
||||||
|
pipeline = Pipeline(
|
||||||
|
name="test",
|
||||||
|
version="1.0.0",
|
||||||
|
description="Test",
|
||||||
|
stages=[
|
||||||
|
PipelineStage(name="step1", agent="agent1", action="process", inputs={"x": "1"}),
|
||||||
|
PipelineStage(name="step2", agent="agent2", action="analyze", depends_on=["step1"], inputs={"y": "2"}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = PipelineEngine(dispatcher=None) # dry-run mode
|
||||||
|
result = await engine.execute(pipeline)
|
||||||
|
assert result.status.value == "completed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_yaml_loader():
|
||||||
|
yaml_content = """
|
||||||
|
name: test_pipeline
|
||||||
|
version: "1.0"
|
||||||
|
description: Test pipeline
|
||||||
|
stages:
|
||||||
|
- name: step1
|
||||||
|
agent: agent1
|
||||||
|
action: process
|
||||||
|
inputs:
|
||||||
|
data: hello
|
||||||
|
- name: step2
|
||||||
|
agent: agent2
|
||||||
|
action: analyze
|
||||||
|
depends_on: [step1]
|
||||||
|
inputs:
|
||||||
|
result: ${{step1_output}}
|
||||||
|
variables:
|
||||||
|
input: world
|
||||||
|
"""
|
||||||
|
loader = PipelineLoader()
|
||||||
|
pipeline = loader.load_from_yaml(yaml_content, "test")
|
||||||
|
assert pipeline.name == "test_pipeline"
|
||||||
|
assert len(pipeline.stages) == 2
|
||||||
|
assert pipeline.stages[1].depends_on == ["step1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_dag():
|
||||||
|
stages = [
|
||||||
|
PipelineStage(name="a", agent="a1", action="do", depends_on=[]),
|
||||||
|
PipelineStage(name="b", agent="a2", action="do", depends_on=["a"]),
|
||||||
|
]
|
||||||
|
assert PipelineLoader.validate_dag(stages) is True
|
||||||
|
|
||||||
|
circular = [
|
||||||
|
PipelineStage(name="a", agent="a1", action="do", depends_on=["b"]),
|
||||||
|
PipelineStage(name="b", agent="a2", action="do", depends_on=["a"]),
|
||||||
|
]
|
||||||
|
assert PipelineLoader.validate_dag(circular) is False
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
"""Tests for Protocol data structures"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from agentkit.core.protocol import (
|
||||||
|
AgentCapability,
|
||||||
|
HandoffMessage,
|
||||||
|
TaskMessage,
|
||||||
|
TaskResult,
|
||||||
|
TaskStatus,
|
||||||
|
EvolutionEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_status_values():
|
||||||
|
assert TaskStatus.PENDING == "pending"
|
||||||
|
assert TaskStatus.RUNNING == "running"
|
||||||
|
assert TaskStatus.COMPLETED == "completed"
|
||||||
|
assert TaskStatus.FAILED == "failed"
|
||||||
|
assert TaskStatus.CANCELLED == "cancelled"
|
||||||
|
assert TaskStatus.HANDOFF == "handoff"
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_capability_with_schema():
|
||||||
|
cap = AgentCapability(
|
||||||
|
agent_name="test",
|
||||||
|
agent_type="test",
|
||||||
|
version="1.0.0",
|
||||||
|
supported_tasks=["echo"],
|
||||||
|
max_concurrency=2,
|
||||||
|
description="Test",
|
||||||
|
input_schema={"type": "object", "properties": {"x": {"type": "number"}}},
|
||||||
|
output_schema={"type": "object"},
|
||||||
|
)
|
||||||
|
|
||||||
|
d = cap.to_dict()
|
||||||
|
assert "input_schema" in d
|
||||||
|
assert "output_schema" in d
|
||||||
|
|
||||||
|
restored = AgentCapability.from_dict(d)
|
||||||
|
assert restored.agent_name == "test"
|
||||||
|
assert restored.input_schema is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_message_roundtrip():
|
||||||
|
msg = TaskMessage(
|
||||||
|
task_id="123",
|
||||||
|
agent_name="agent1",
|
||||||
|
task_type="echo",
|
||||||
|
priority=1,
|
||||||
|
input_data={"key": "value"},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.utcnow(),
|
||||||
|
conversation_id="conv-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
d = msg.to_dict()
|
||||||
|
assert d["conversation_id"] == "conv-1"
|
||||||
|
|
||||||
|
restored = TaskMessage.from_dict(d)
|
||||||
|
assert restored.task_id == "123"
|
||||||
|
assert restored.conversation_id == "conv-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_handoff_message():
|
||||||
|
msg = HandoffMessage(
|
||||||
|
source_agent="agent_a",
|
||||||
|
target_agent="agent_b",
|
||||||
|
task_id="task-1",
|
||||||
|
task_type="analyze",
|
||||||
|
context={"data": "test"},
|
||||||
|
reason="Need expert analysis",
|
||||||
|
)
|
||||||
|
|
||||||
|
d = msg.to_dict()
|
||||||
|
assert d["source_agent"] == "agent_a"
|
||||||
|
assert d["target_agent"] == "agent_b"
|
||||||
|
|
||||||
|
restored = HandoffMessage.from_dict(d)
|
||||||
|
assert restored.reason == "Need expert analysis"
|
||||||
|
|
||||||
|
|
||||||
|
def test_evolution_event():
|
||||||
|
event = EvolutionEvent(
|
||||||
|
agent_name="optimizer",
|
||||||
|
change_type="prompt",
|
||||||
|
before={"instruction": "old"},
|
||||||
|
after={"instruction": "new"},
|
||||||
|
metrics={"quality_delta": 0.15},
|
||||||
|
)
|
||||||
|
|
||||||
|
d = event.to_dict()
|
||||||
|
assert d["change_type"] == "prompt"
|
||||||
|
assert d["metrics"]["quality_delta"] == 0.15
|
||||||
|
|
@ -0,0 +1,104 @@
|
||||||
|
"""Tests for Tool system"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
from agentkit.tools.function_tool import FunctionTool
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
async def add_numbers(a: int, b: int) -> dict:
|
||||||
|
return {"sum": a + b}
|
||||||
|
|
||||||
|
|
||||||
|
def sync_greet(name: str) -> dict:
|
||||||
|
return {"greeting": f"Hello, {name}!"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_function_tool_async():
|
||||||
|
tool = FunctionTool(name="add", description="Add numbers", func=add_numbers)
|
||||||
|
result = await tool.execute(a=1, b=2)
|
||||||
|
assert result == {"sum": 3}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_function_tool_sync():
|
||||||
|
tool = FunctionTool(name="greet", description="Greet someone", func=sync_greet)
|
||||||
|
result = await tool.execute(name="World")
|
||||||
|
assert result == {"greeting": "Hello, World!"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_function_tool_schema_inference():
|
||||||
|
tool = FunctionTool(name="add", description="Add numbers", func=add_numbers)
|
||||||
|
assert tool.input_schema is not None
|
||||||
|
assert "a" in tool.input_schema.get("properties", {})
|
||||||
|
assert "b" in tool.input_schema.get("properties", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_registry():
|
||||||
|
registry = ToolRegistry()
|
||||||
|
tool = FunctionTool(name="add", description="Add numbers", func=add_numbers)
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
assert registry.has_tool("add")
|
||||||
|
retrieved = registry.get("add")
|
||||||
|
assert retrieved.name == "add"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_registry_versioning():
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
||||||
|
v1 = FunctionTool(name="add", description="Add v1", func=add_numbers, version="1.0.0")
|
||||||
|
v2 = FunctionTool(name="add", description="Add v2", func=add_numbers, version="2.0.0")
|
||||||
|
|
||||||
|
registry.register(v1)
|
||||||
|
registry.register(v2)
|
||||||
|
|
||||||
|
# Default returns latest
|
||||||
|
latest = registry.get("add")
|
||||||
|
assert latest.version == "2.0.0"
|
||||||
|
|
||||||
|
# Can request specific version
|
||||||
|
specific = registry.get("add", version="1.0.0")
|
||||||
|
assert specific.version == "1.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_registry_list():
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
||||||
|
t1 = FunctionTool(name="add", description="Add", func=add_numbers, tags=["math"])
|
||||||
|
t2 = FunctionTool(name="greet", description="Greet", func=sync_greet, tags=["text"])
|
||||||
|
|
||||||
|
registry.register(t1)
|
||||||
|
registry.register(t2)
|
||||||
|
|
||||||
|
all_tools = registry.list_tools()
|
||||||
|
assert len(all_tools) == 2
|
||||||
|
|
||||||
|
math_tools = registry.list_tools(tag="math")
|
||||||
|
assert len(math_tools) == 1
|
||||||
|
assert math_tools[0].name == "add"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_safe_execute():
|
||||||
|
async def failing_tool():
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
|
tool = FunctionTool(name="fail", description="Always fails", func=failing_tool)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await tool.safe_execute()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_not_found():
|
||||||
|
registry = ToolRegistry()
|
||||||
|
from agentkit.core.exceptions import ToolNotFoundError
|
||||||
|
with pytest.raises(ToolNotFoundError):
|
||||||
|
registry.get("nonexistent")
|
||||||
Loading…
Reference in New Issue