feat: initial fischer-agentkit package with unified agent architecture

- BaseAgent with handle_task() pattern (execute template moved up)
- Protocol: TaskMessage, TaskResult, HandoffMessage, EvolutionEvent
- Tool system: FunctionTool, AgentTool, ToolRegistry with versioning
- Memory system: WorkingMemory (Redis), EpisodicMemory (pgvector), SemanticMemory (RAG adapter), MemoryRetriever (hybrid)
- Evolution engine: Reflector, PromptOptimizer (DSPy-style), StrategyTuner, ABTester, EvolutionStore
- Orchestrator: PipelineEngine (parallel DAG), PipelineLoader (YAML), HandoffManager, DynamicPipeline
- MCP: Server (FastAPI), Client (httpx), MCPTool
- Prompts: PromptTemplate, PromptSection
- Exceptions: full hierarchy including Tool, Schema, Handoff, Evolution errors
- Tests: unit tests for core, tools, protocol, evolution, pipeline
This commit is contained in:
chiguyong 2026-06-04 22:24:06 +08:00
commit 9a6d6fee4e
45 changed files with 4402 additions and 0 deletions

48
pyproject.toml Normal file
View File

@ -0,0 +1,48 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "fischer-agentkit"
version = "0.1.0"
description = "Unified Agent Framework with Tool/Skill plugins, Memory, Self-Evolution, MCP support, and Multi-Agent orchestration"
readme = "README.md"
requires-python = ">=3.11"
license = {text = "MIT"}
authors = [
{name = "Fischer Team"},
]
dependencies = [
"pydantic>=2.0",
"redis[hiredis]>=5.0",
"sqlalchemy[asyncio]>=2.0",
"asyncpg>=0.29",
"httpx>=0.27",
"pyyaml>=6.0",
"jsonschema>=4.0",
]
[project.optional-dependencies]
mcp = [
"mcp>=1.0",
]
evolution = [
"scipy>=1.12",
]
dev = [
"pytest>=8.0",
"pytest-asyncio>=0.23",
"pytest-cov>=5.0",
"ruff>=0.4",
]
[tool.hatch.build.targets.wheel]
packages = ["src/agentkit"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
[tool.ruff]
target-version = "py311"
line-length = 100

25
src/agentkit/__init__.py Normal file
View File

@ -0,0 +1,25 @@
"""Fischer AgentKit - Unified Agent Framework"""
from agentkit.core.base import BaseAgent
from agentkit.core.protocol import (
AgentCapability,
AgentStatus,
HandoffMessage,
TaskMessage,
TaskProgress,
TaskResult,
TaskStatus,
)
__version__ = "0.1.0"
__all__ = [
"BaseAgent",
"AgentCapability",
"AgentStatus",
"HandoffMessage",
"TaskMessage",
"TaskProgress",
"TaskResult",
"TaskStatus",
]

View File

@ -0,0 +1,61 @@
"""AgentKit Core - 基础组件"""
from agentkit.core.base import BaseAgent
from agentkit.core.exceptions import (
AgentAlreadyRegisteredError,
AgentFrameworkError,
AgentNotFoundError,
AgentNotReadyError,
AgentUnavailableError,
ConfigValidationError,
EvolutionError,
HandoffError,
NoAvailableAgentError,
SchemaValidationError,
TaskCancelledError,
TaskDispatchError,
TaskExecutionError,
TaskNotFoundError,
TaskTimeoutError,
ToolExecutionError,
ToolNotFoundError,
)
from agentkit.core.protocol import (
AgentCapability,
AgentStatus,
EvolutionEvent,
HandoffMessage,
TaskMessage,
TaskProgress,
TaskResult,
TaskStatus,
)
__all__ = [
"BaseAgent",
"AgentCapability",
"AgentStatus",
"AgentFrameworkError",
"AgentNotFoundError",
"AgentAlreadyRegisteredError",
"AgentUnavailableError",
"AgentNotReadyError",
"TaskNotFoundError",
"TaskDispatchError",
"TaskExecutionError",
"TaskTimeoutError",
"TaskCancelledError",
"NoAvailableAgentError",
"ConfigValidationError",
"SchemaValidationError",
"HandoffError",
"EvolutionError",
"ToolNotFoundError",
"ToolExecutionError",
"HandoffMessage",
"EvolutionEvent",
"TaskMessage",
"TaskProgress",
"TaskResult",
"TaskStatus",
]

395
src/agentkit/core/base.py Normal file
View File

@ -0,0 +1,395 @@
"""BaseAgent 基类 - 统一 Agent 生命周期管理
核心设计
- execute() final 方法包含完整的计时try/exceptTaskResult 构建
- 子类只需实现 handle_task(task) -> dict 返回业务数据
- 生命周期钩子on_task_start / on_task_complete / on_task_failed
- 支持 Tool 插件Memory 系统可选注入
"""
import asyncio
import json
import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
import redis.asyncio as aioredis
from agentkit.core.exceptions import AgentNotReadyError, SchemaValidationError
from agentkit.core.protocol import (
AgentCapability,
AgentStatus,
HandoffMessage,
TaskMessage,
TaskProgress,
TaskResult,
TaskStatus,
)
if TYPE_CHECKING:
from agentkit.memory.base import Memory
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class BaseAgent(ABC):
"""所有 Agent 的基类,定义标准生命周期。
子类只需实现
- handle_task(task) -> dict: 业务逻辑返回 output_data
- get_capabilities() -> AgentCapability: 能力声明
可选覆写
- on_task_start(task): 任务开始前的钩子
- on_task_complete(task, output): 任务成功后的钩子
- on_task_failed(task, error): 任务失败后的钩子
"""
def __init__(self, name: str, agent_type: str, version: str = "1.0.0"):
self.name = name
self.agent_type = agent_type
self.version = version
self._status: AgentStatus = AgentStatus.OFFLINE
self._redis: aioredis.Redis | None = None
self._redis_url: str = ""
self._running_tasks: set[str] = set()
self._listen_task: asyncio.Task | None = None
self._heartbeat_task: asyncio.Task | None = None
self._semaphore: asyncio.Semaphore | None = None
# 可插拔能力(由子类或配置注入)
self._tools: list["Tool"] = []
self._memory: "Memory | None" = None
# 外部依赖注入(由 start() 时设置)
self._registry = None
self._dispatcher = None
@property
def status(self) -> AgentStatus:
return self._status
@property
def is_distributed(self) -> bool:
return self._redis is not None
@property
def tools(self) -> list["Tool"]:
return self._tools
@property
def memory(self) -> "Memory | None":
return self._memory
# ── 抽象方法(子类必须实现) ──────────────────────────────
@abstractmethod
async def handle_task(self, task: TaskMessage) -> dict:
"""执行任务的核心业务逻辑,子类必须实现。
返回 output_data dict框架自动包装为 TaskResult
"""
...
@abstractmethod
def get_capabilities(self) -> AgentCapability:
"""返回 Agent 能力声明"""
...
# ── 生命周期钩子(可选覆写) ──────────────────────────────
async def on_task_start(self, task: TaskMessage) -> None:
"""任务开始前的钩子,可用于加载记忆、准备上下文等"""
pass
async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
"""任务成功后的钩子,可用于存储记忆、触发反思等"""
pass
async def on_task_failed(self, task: TaskMessage, error: Exception) -> None:
"""任务失败后的钩子,可用于记录失败模式等"""
pass
# ── 可插拔能力注入 ──────────────────────────────────────
def use_tool(self, tool: "Tool") -> "BaseAgent":
"""添加工具到 Agent"""
self._tools.append(tool)
return self
def use_memory(self, memory: "Memory") -> "BaseAgent":
"""设置记忆系统"""
self._memory = memory
return self
def set_registry(self, registry: Any) -> "BaseAgent":
"""注入注册中心"""
self._registry = registry
return self
def set_dispatcher(self, dispatcher: Any) -> "BaseAgent":
"""注入任务分发器"""
self._dispatcher = dispatcher
return self
# ── 核心生命周期 ──────────────────────────────────────────
async def start(self, redis_url: str = ""):
"""启动 Agent连接 Redis → 注册 → 心跳 → 监听"""
self._redis_url = redis_url
logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})")
if redis_url:
try:
self._redis = aioredis.from_url(redis_url, decode_responses=True)
await self._redis.ping()
logger.info(f"Agent '{self.name}' connected to Redis")
except Exception as e:
self._redis = None
logger.warning(f"Agent '{self.name}' Redis unavailable: {e}, falling back to local mode")
# 注册到 Registry
if self._registry is not None:
capability = self.get_capabilities()
await self._registry.register(capability, endpoint=f"agent:{self.name}")
self._status = AgentStatus.ONLINE
# 设置并发控制
capability = self.get_capabilities()
max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
self._semaphore = asyncio.Semaphore(max_concurrency)
# 启动心跳和监听
if self._redis is not None:
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
self._listen_task = asyncio.create_task(self._listen_for_tasks())
logger.info(f"Agent '{self.name}' started ({'distributed' if self._redis else 'local'} mode)")
async def stop(self):
"""停止 Agent"""
logger.info(f"Stopping agent '{self.name}'")
self._status = AgentStatus.OFFLINE
for task in [self._listen_task, self._heartbeat_task]:
if task and not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
if self._redis is not None:
if self._registry is not None:
await self._registry.unregister(self.name)
await self._redis.close()
self._redis = None
logger.info(f"Agent '{self.name}' stopped")
# ── execute 为 final 方法 ─────────────────────────────────
async def execute(self, task: TaskMessage) -> TaskResult:
"""执行任务(框架方法,不可覆写)。
完整流程on_task_start handle_task on_task_complete/on_task_failed
自动处理计时TaskResult 构建错误捕获
"""
started_at = datetime.now(timezone.utc)
start_time = time.monotonic()
try:
# 前置钩子
await self.on_task_start(task)
# Schema 校验(如果 Agent 声明了 input_schema
capability = self.get_capabilities()
if capability.input_schema:
self._validate_input(task.input_data, capability.input_schema)
# 执行业务逻辑
output = await self.handle_task(task)
# 后置钩子
await self.on_task_complete(task, output)
elapsed = time.monotonic() - start_time
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.COMPLETED,
output_data=output,
error_message=None,
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
except Exception as e:
logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}")
# 失败钩子
try:
await self.on_task_failed(task, e)
except Exception as hook_err:
logger.error(f"on_task_failed hook error: {hook_err}")
elapsed = time.monotonic() - start_time
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.FAILED,
output_data=None,
error_message=str(e),
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
"error_type": type(e).__name__,
},
)
# ── Handoff ───────────────────────────────────────────────
async def handoff(self, target_agent: str, task: TaskMessage, reason: str, context: dict[str, Any] | None = None):
"""将当前任务转交给另一个 Agent"""
if self._redis is None:
raise RuntimeError("Handoff requires Redis connection")
handoff_msg = HandoffMessage(
source_agent=self.name,
target_agent=target_agent,
task_id=task.task_id,
task_type=task.task_type,
context=context or task.input_data,
reason=reason,
)
# 发布到目标 Agent 的 handoff 频道
await self._redis.publish(
f"agent:{target_agent}:handoff",
json.dumps(handoff_msg.to_dict()),
)
logger.info(f"Agent '{self.name}' handed off task {task.task_id} to '{target_agent}': {reason}")
# ── 进度上报 ──────────────────────────────────────────────
async def report_progress(self, task_id: str, progress: float, message: str):
progress_obj = TaskProgress(
task_id=task_id,
agent_name=self.name,
progress=progress,
message=message,
updated_at=datetime.now(timezone.utc),
)
if self._redis:
try:
await self._redis.publish(
f"agent:{self.name}:progress",
json.dumps(progress_obj.to_dict()),
)
except Exception as e:
logger.warning(f"Failed to publish progress for task {task_id}: {e}")
if self._dispatcher is not None:
try:
await self._dispatcher.handle_progress(progress_obj)
except Exception as e:
logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}")
# ── 内部方法 ──────────────────────────────────────────────
async def heartbeat(self):
if self._registry is not None:
await self._registry.update_heartbeat(self.name)
async def _heartbeat_loop(self):
try:
while self._status == AgentStatus.ONLINE:
await self.heartbeat()
await asyncio.sleep(30)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"Heartbeat error for agent '{self.name}': {e}")
async def _listen_for_tasks(self):
try:
queue_key = f"agent:{self.name}:tasks"
while self._status == AgentStatus.ONLINE:
if not self._redis:
await asyncio.sleep(1)
continue
result = await self._redis.brpop(queue_key, timeout=1)
if result:
_, task_json = result
try:
task_data = json.loads(task_json)
task = TaskMessage.from_dict(task_data)
asyncio.create_task(self._execute_task_with_semaphore(task))
except Exception as e:
logger.error(f"Failed to parse task message: {e}")
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"Task listener error for agent '{self.name}': {e}")
async def _execute_task_with_semaphore(self, task: TaskMessage):
if self._semaphore is None:
await self._execute_task(task)
return
async with self._semaphore:
await self._execute_task(task)
async def _execute_task(self, task: TaskMessage):
self._running_tasks.add(task.task_id)
self._status = AgentStatus.BUSY
try:
logger.info(f"Agent '{self.name}' executing task {task.task_id} (type={task.task_type})")
result = await self.execute(task)
if self._redis is not None and self._dispatcher is not None:
await self._dispatcher.handle_result(result)
except Exception as e:
logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}")
error_result = TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.FAILED,
output_data=None,
error_message=str(e),
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
metrics=None,
)
if self._redis is not None and self._dispatcher is not None:
await self._dispatcher.handle_result(error_result)
finally:
self._running_tasks.discard(task.task_id)
if not self._running_tasks:
self._status = AgentStatus.ONLINE
def _validate_input(self, data: dict, schema: dict) -> None:
"""校验输入数据是否符合 JSON Schema"""
try:
import jsonschema
jsonschema.validate(data, schema)
except ImportError:
logger.warning("jsonschema not installed, skipping input validation")
except Exception as e:
raise SchemaValidationError(self.name, str(e))

View File

@ -0,0 +1,361 @@
"""任务分发器 - 通过 Redis Queue 将任务分发给 Agent
与业务系统解耦通过依赖注入获取 Redis 连接和数据库会话
"""
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Callable, Awaitable
from agentkit.core.exceptions import (
NoAvailableAgentError,
TaskDispatchError,
TaskNotFoundError,
)
from agentkit.core.protocol import (
AgentStatus,
TaskMessage,
TaskProgress,
TaskResult,
TaskStatus,
)
logger = logging.getLogger(__name__)
class TaskDispatcher:
"""任务分发器,通过 Redis Queue 将任务分发给 Agent"""
def __init__(
self,
redis_factory: Callable[[], Awaitable[Any]],
session_factory: Callable[[], Any],
agent_model: Any,
task_model: Any,
task_log_model: Any,
):
"""
Args:
redis_factory: 返回 Redis 连接的异步工厂
session_factory: 返回 async context manager 的工厂用于获取数据库会话
agent_model: Agent ORM 模型类
task_model: AgentTask ORM 模型类
task_log_model: AgentTaskLog ORM 模型类
"""
self._redis_factory = redis_factory
self._session_factory = session_factory
self._agent_model = agent_model
self._task_model = task_model
self._task_log_model = task_log_model
async def _get_redis(self):
return await self._redis_factory()
async def dispatch(
self,
task: TaskMessage,
organization_id: str | None = None,
created_by: str | None = None,
) -> str:
"""分发任务到对应 Agent 的队列,返回 task_id"""
async with self._session_factory() as db:
try:
from sqlalchemy import select
AgentModel = self._agent_model
TaskModel = self._task_model
# 查找目标 Agent
stmt = select(AgentModel).where(AgentModel.name == task.agent_name)
result = await db.execute(stmt)
agent = result.scalar_one_or_none()
if not agent:
raise TaskDispatchError(task.task_id, f"Agent '{task.agent_name}' not found")
if agent.status != AgentStatus.ONLINE:
raise TaskDispatchError(
task.task_id,
f"Agent '{task.agent_name}' is not online (status={agent.status})",
)
# 写入 agent_tasks 表
task_id_uuid = uuid.UUID(task.task_id)
agent_task = TaskModel(
id=task_id_uuid,
agent_id=agent.id,
task_type=task.task_type,
status=TaskStatus.PENDING,
priority=task.priority,
input_data=task.input_data,
organization_id=uuid.UUID(organization_id) if organization_id else agent.id,
created_by=uuid.UUID(created_by) if created_by else None,
)
db.add(agent_task)
await db.commit()
# 推送到 Redis Queue
redis = await self._get_redis()
queue_key = f"agent:{task.agent_name}:tasks"
await redis.lpush(queue_key, json.dumps(task.to_dict()))
logger.info(
f"Task {task.task_id} dispatched to agent '{task.agent_name}' "
f"(type={task.task_type}, priority={task.priority})"
)
return task.task_id
except TaskDispatchError:
raise
except Exception as e:
await db.rollback()
logger.error(f"Failed to dispatch task {task.task_id}: {e}")
raise TaskDispatchError(task.task_id, str(e))
async def cancel_task(self, task_id: str):
"""取消任务"""
async with self._session_factory() as db:
try:
from sqlalchemy import select
TaskModel = self._task_model
task_uuid = uuid.UUID(task_id)
stmt = select(TaskModel).where(TaskModel.id == task_uuid)
result = await db.execute(stmt)
task = result.scalar_one_or_none()
if not task:
raise TaskNotFoundError(task_id)
if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
logger.warning(f"Cannot cancel task {task_id} with status '{task.status}'")
return
task.status = TaskStatus.CANCELLED
task.completed_at = datetime.now(timezone.utc)
await db.commit()
await self._write_log(
db, task_id=task_id, agent_id=str(task.agent_id),
log_level="info", message="Task cancelled by user",
)
await db.commit()
logger.info(f"Task {task_id} cancelled")
except TaskNotFoundError:
raise
except Exception as e:
await db.rollback()
logger.error(f"Failed to cancel task {task_id}: {e}")
raise
async def get_task_status(self, task_id: str) -> dict:
"""获取任务状态"""
async with self._session_factory() as db:
from sqlalchemy import select
TaskModel = self._task_model
task_uuid = uuid.UUID(task_id)
stmt = select(TaskModel).where(TaskModel.id == task_uuid)
result = await db.execute(stmt)
task = result.scalar_one_or_none()
if not task:
raise TaskNotFoundError(task_id)
return self._task_to_dict(task)
async def handle_result(self, result: TaskResult):
"""处理 Agent 返回的结果"""
async with self._session_factory() as db:
try:
from sqlalchemy import select
TaskModel = self._task_model
task_uuid = uuid.UUID(result.task_id)
stmt = select(TaskModel).where(TaskModel.id == task_uuid)
db_result = await db.execute(stmt)
task = db_result.scalar_one_or_none()
if not task:
logger.error(f"Task {result.task_id} not found when handling result")
return
task.status = result.status
task.output_data = result.output_data
task.error_message = result.error_message
task.started_at = result.started_at
task.completed_at = result.completed_at
await db.commit()
log_level = "info" if result.status == TaskStatus.COMPLETED else "error"
log_message = (
f"Task {result.status}"
if result.status == TaskStatus.COMPLETED
else f"Task failed: {result.error_message}"
)
await self._write_log(
db,
task_id=result.task_id,
agent_id=str(task.agent_id),
log_level=log_level,
message=log_message,
extra_metadata=result.metrics,
)
await db.commit()
# 触发回调
if result.output_data and result.output_data.get("callback_url"):
await self._trigger_callback(result.output_data["callback_url"], result)
logger.info(f"Task {result.task_id} result handled (status={result.status})")
except Exception as e:
await db.rollback()
logger.error(f"Failed to handle result for task {result.task_id}: {e}")
async def handle_progress(self, progress: TaskProgress):
"""处理进度上报"""
async with self._session_factory() as db:
try:
from sqlalchemy import select
AgentModel = self._agent_model
stmt = select(AgentModel).where(AgentModel.name == progress.agent_name)
result = await db.execute(stmt)
agent = result.scalar_one_or_none()
if not agent:
logger.warning(f"Agent '{progress.agent_name}' not found for progress report")
return
await self._write_log(
db,
task_id=progress.task_id,
agent_id=str(agent.id),
log_level="info",
message=f"Progress: {progress.progress:.0%} - {progress.message}",
extra_metadata={
"progress": progress.progress,
"updated_at": progress.updated_at.isoformat(),
},
)
await db.commit()
except Exception as e:
await db.rollback()
logger.error(f"Failed to handle progress for task {progress.task_id}: {e}")
async def retry_failed_tasks(self, max_retries: int = 3):
"""重试失败的任务"""
async with self._session_factory() as db:
try:
from sqlalchemy import select
TaskModel = self._task_model
AgentModel = self._agent_model
LogModel = self._task_log_model
stmt = select(TaskModel).where(TaskModel.status == TaskStatus.FAILED)
result = await db.execute(stmt)
failed_tasks = result.scalars().all()
retried = 0
for task in failed_tasks:
log_stmt = select(LogModel).where(
LogModel.task_id == task.id,
LogModel.message.like("%retry%"),
)
log_result = await db.execute(log_stmt)
retry_count = len(log_result.scalars().all())
if retry_count < max_retries:
task.status = TaskStatus.PENDING
task.error_message = None
task.started_at = None
task.completed_at = None
agent_stmt = select(AgentModel).where(AgentModel.id == task.agent_id)
agent_result = await db.execute(agent_stmt)
agent = agent_result.scalar_one_or_none()
if agent and agent.status == AgentStatus.ONLINE:
task_msg = TaskMessage(
task_id=str(task.id),
agent_name=agent.name,
task_type=task.task_type,
priority=task.priority,
input_data=task.input_data or {},
callback_url=None,
created_at=datetime.now(timezone.utc),
)
redis = await self._get_redis()
queue_key = f"agent:{agent.name}:tasks"
await redis.lpush(queue_key, json.dumps(task_msg.to_dict()))
await self._write_log(
db,
task_id=str(task.id),
agent_id=str(agent.id),
log_level="info",
message=f"Task retry attempt {retry_count + 1}/{max_retries}",
)
retried += 1
await db.commit()
if retried > 0:
logger.info(f"Retried {retried} failed tasks")
except Exception as e:
await db.rollback()
logger.error(f"Failed to retry failed tasks: {e}")
async def _write_log(
self,
db: Any,
task_id: str,
agent_id: str,
log_level: str,
message: str,
extra_metadata: dict | None = None,
):
LogModel = self._task_log_model
log_entry = LogModel(
task_id=uuid.UUID(task_id),
agent_id=uuid.UUID(agent_id),
log_level=log_level,
message=message,
extra_metadata=extra_metadata,
)
db.add(log_entry)
async def _trigger_callback(self, callback_url: str, result: TaskResult):
try:
import httpx
async with httpx.AsyncClient(timeout=10) as client:
await client.post(callback_url, json=result.to_dict())
logger.info(f"Callback triggered for task {result.task_id}")
except Exception as e:
logger.warning(f"Callback failed for task {result.task_id}: {e}")
def _task_to_dict(self, task: Any) -> dict:
return {
"id": str(task.id),
"agent_id": str(task.agent_id),
"task_type": task.task_type,
"status": task.status,
"priority": task.priority,
"input_data": task.input_data,
"output_data": task.output_data,
"error_message": task.error_message,
"created_by": str(task.created_by) if task.created_by else None,
"organization_id": str(task.organization_id),
"project_id": str(task.project_id) if task.project_id else None,
"scheduled_at": task.scheduled_at.isoformat() if task.scheduled_at else None,
"started_at": task.started_at.isoformat() if task.started_at else None,
"completed_at": task.completed_at.isoformat() if task.completed_at else None,
"created_at": task.created_at.isoformat() if task.created_at else None,
}

View File

@ -0,0 +1,110 @@
"""Agent 框架自定义异常"""
class AgentFrameworkError(Exception):
"""Agent 框架基础异常"""
def __init__(self, message: str = "Agent framework error"):
self.message = message
super().__init__(self.message)
class AgentNotFoundError(AgentFrameworkError):
def __init__(self, agent_name: str):
self.agent_name = agent_name
super().__init__(f"Agent not found: {agent_name}")
class AgentAlreadyRegisteredError(AgentFrameworkError):
def __init__(self, agent_name: str):
self.agent_name = agent_name
super().__init__(f"Agent already registered: {agent_name}")
class AgentUnavailableError(AgentFrameworkError):
def __init__(self, agent_name: str, status: str = "offline"):
self.agent_name = agent_name
self.status = status
super().__init__(f"Agent '{agent_name}' is unavailable (status: {status})")
class TaskNotFoundError(AgentFrameworkError):
def __init__(self, task_id: str):
self.task_id = task_id
super().__init__(f"Task not found: {task_id}")
class TaskDispatchError(AgentFrameworkError):
def __init__(self, task_id: str, reason: str = ""):
self.task_id = task_id
super().__init__(f"Task dispatch failed for {task_id}: {reason}")
class TaskExecutionError(AgentFrameworkError):
def __init__(self, task_id: str, agent_name: str, reason: str = ""):
self.task_id = task_id
self.agent_name = agent_name
super().__init__(f"Task {task_id} execution failed on agent '{agent_name}': {reason}")
class TaskTimeoutError(AgentFrameworkError):
def __init__(self, task_id: str, timeout_seconds: int):
self.task_id = task_id
self.timeout_seconds = timeout_seconds
super().__init__(f"Task {task_id} timed out after {timeout_seconds}s")
class TaskCancelledError(AgentFrameworkError):
def __init__(self, task_id: str):
self.task_id = task_id
super().__init__(f"Task {task_id} was cancelled")
class NoAvailableAgentError(AgentFrameworkError):
def __init__(self, task_type: str):
self.task_type = task_type
super().__init__(f"No available agent for task type: {task_type}")
class ConfigValidationError(AgentFrameworkError):
def __init__(self, agent_name: str, key: str, reason: str = ""):
self.agent_name = agent_name
self.key = key
super().__init__(f"Config validation failed for agent '{agent_name}' key '{key}': {reason}")
class AgentNotReadyError(AgentFrameworkError):
def __init__(self, agent_name: str):
self.agent_name = agent_name
super().__init__(f"Agent '{agent_name}' is not ready")
class ToolNotFoundError(AgentFrameworkError):
def __init__(self, tool_name: str):
self.tool_name = tool_name
super().__init__(f"Tool not found: {tool_name}")
class ToolExecutionError(AgentFrameworkError):
def __init__(self, tool_name: str, reason: str = ""):
self.tool_name = tool_name
super().__init__(f"Tool '{tool_name}' execution failed: {reason}")
class SchemaValidationError(AgentFrameworkError):
def __init__(self, agent_name: str, detail: str = ""):
self.agent_name = agent_name
super().__init__(f"Schema validation failed for agent '{agent_name}': {detail}")
class HandoffError(AgentFrameworkError):
def __init__(self, source: str, target: str, reason: str = ""):
self.source = source
self.target = target
super().__init__(f"Handoff from '{source}' to '{target}' failed: {reason}")
class EvolutionError(AgentFrameworkError):
def __init__(self, agent_name: str, reason: str = ""):
self.agent_name = agent_name
super().__init__(f"Evolution failed for agent '{agent_name}': {reason}")

View File

@ -0,0 +1,245 @@
"""Agent 通信协议定义 - 统一消息格式"""
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
class TaskStatus(str, Enum):
"""任务状态枚举"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
HANDOFF = "handoff"
class AgentStatus(str, Enum):
"""Agent 状态枚举"""
ONLINE = "online"
OFFLINE = "offline"
BUSY = "busy"
@dataclass
class AgentCapability:
"""Agent 能力声明"""
agent_name: str
agent_type: str
version: str
supported_tasks: list[str]
max_concurrency: int
description: str
input_schema: dict[str, Any] | None = None
output_schema: dict[str, Any] | None = None
def to_dict(self) -> dict:
d = {
"agent_name": self.agent_name,
"agent_type": self.agent_type,
"version": self.version,
"supported_tasks": self.supported_tasks,
"max_concurrency": self.max_concurrency,
"description": self.description,
}
if self.input_schema is not None:
d["input_schema"] = self.input_schema
if self.output_schema is not None:
d["output_schema"] = self.output_schema
return d
@classmethod
def from_dict(cls, data: dict) -> "AgentCapability":
return cls(
agent_name=data["agent_name"],
agent_type=data["agent_type"],
version=data["version"],
supported_tasks=data["supported_tasks"],
max_concurrency=data["max_concurrency"],
description=data["description"],
input_schema=data.get("input_schema"),
output_schema=data.get("output_schema"),
)
@dataclass
class TaskMessage:
"""任务消息 - 从调度器发往 Agent"""
task_id: str
agent_name: str
task_type: str
priority: int
input_data: dict
callback_url: str | None
created_at: datetime
timeout_seconds: int = 300
conversation_id: str | None = None
def to_dict(self) -> dict:
return {
"task_id": self.task_id,
"agent_name": self.agent_name,
"task_type": self.task_type,
"priority": self.priority,
"input_data": self.input_data,
"callback_url": self.callback_url,
"created_at": self.created_at.isoformat() if self.created_at else None,
"timeout_seconds": self.timeout_seconds,
"conversation_id": self.conversation_id,
}
@classmethod
def from_dict(cls, data: dict) -> "TaskMessage":
created_at = data.get("created_at")
if isinstance(created_at, str):
created_at = datetime.fromisoformat(created_at)
return cls(
task_id=data["task_id"],
agent_name=data["agent_name"],
task_type=data["task_type"],
priority=data.get("priority", 0),
input_data=data.get("input_data", {}),
callback_url=data.get("callback_url"),
created_at=created_at or datetime.utcnow(),
timeout_seconds=data.get("timeout_seconds", 300),
conversation_id=data.get("conversation_id"),
)
@dataclass
class TaskResult:
"""任务结果 - 从 Agent 返回"""
task_id: str
agent_name: str
status: str
output_data: dict | None
error_message: str | None
started_at: datetime
completed_at: datetime
metrics: dict | None = None
def to_dict(self) -> dict:
return {
"task_id": self.task_id,
"agent_name": self.agent_name,
"status": self.status,
"output_data": self.output_data,
"error_message": self.error_message,
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"metrics": self.metrics,
}
@classmethod
def from_dict(cls, data: dict) -> "TaskResult":
started_at = data.get("started_at")
if isinstance(started_at, str):
started_at = datetime.fromisoformat(started_at)
completed_at = data.get("completed_at")
if isinstance(completed_at, str):
completed_at = datetime.fromisoformat(completed_at)
return cls(
task_id=data["task_id"],
agent_name=data["agent_name"],
status=data["status"],
output_data=data.get("output_data"),
error_message=data.get("error_message"),
started_at=started_at or datetime.utcnow(),
completed_at=completed_at or datetime.utcnow(),
metrics=data.get("metrics"),
)
@dataclass
class TaskProgress:
"""进度上报 - Agent 执行过程中上报"""
task_id: str
agent_name: str
progress: float
message: str
updated_at: datetime
def to_dict(self) -> dict:
return {
"task_id": self.task_id,
"agent_name": self.agent_name,
"progress": self.progress,
"message": self.message,
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
}
@classmethod
def from_dict(cls, data: dict) -> "TaskProgress":
updated_at = data.get("updated_at")
if isinstance(updated_at, str):
updated_at = datetime.fromisoformat(updated_at)
return cls(
task_id=data["task_id"],
agent_name=data["agent_name"],
progress=data.get("progress", 0.0),
message=data.get("message", ""),
updated_at=updated_at or datetime.utcnow(),
)
@dataclass
class HandoffMessage:
"""任务转交消息 - Agent 间 Handoff"""
source_agent: str
target_agent: str
task_id: str
task_type: str
context: dict[str, Any]
reason: str
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
def to_dict(self) -> dict:
return {
"source_agent": self.source_agent,
"target_agent": self.target_agent,
"task_id": self.task_id,
"task_type": self.task_type,
"context": self.context,
"reason": self.reason,
"created_at": self.created_at.isoformat(),
}
@classmethod
def from_dict(cls, data: dict) -> "HandoffMessage":
created_at = data.get("created_at")
if isinstance(created_at, str):
created_at = datetime.fromisoformat(created_at)
return cls(
source_agent=data["source_agent"],
target_agent=data["target_agent"],
task_id=data["task_id"],
task_type=data["task_type"],
context=data.get("context", {}),
reason=data["reason"],
created_at=created_at or datetime.utcnow(),
)
@dataclass
class EvolutionEvent:
"""进化事件 - 记录 Agent 的自我进化变更"""
agent_name: str
change_type: str # prompt / strategy / pipeline
before: dict[str, Any]
after: dict[str, Any]
metrics: dict[str, Any] | None = None
event_id: str | None = None
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
def to_dict(self) -> dict:
return {
"agent_name": self.agent_name,
"change_type": self.change_type,
"before": self.before,
"after": self.after,
"metrics": self.metrics,
"event_id": self.event_id,
"created_at": self.created_at.isoformat(),
}

View File

@ -0,0 +1,250 @@
"""Agent 注册中心 - 管理 Agent 的注册、发现、状态
与业务系统解耦通过 session_factory 注入数据库会话
通过 agent_model_factory 注入 ORM 模型
"""
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Awaitable
from agentkit.core.exceptions import (
AgentNotFoundError,
AgentUnavailableError,
NoAvailableAgentError,
)
from agentkit.core.protocol import AgentCapability, AgentStatus
logger = logging.getLogger(__name__)
HEARTBEAT_TIMEOUT_SECONDS = 90
class AgentRegistry:
"""Agent 注册中心,管理 Agent 的注册、发现、状态
使用依赖注入模式不依赖具体的 ORM 模型或数据库连接
"""
def __init__(
self,
session_factory: Callable[[], Any],
agent_model: Any,
load_balancer: str = "round_robin",
):
"""
Args:
session_factory: 返回 async context manager 的工厂用于获取数据库会话
agent_model: Agent ORM 模型类
load_balancer: 负载均衡策略 (round_robin / least_tasks / random)
"""
self._session_factory = session_factory
self._agent_model = agent_model
self._load_balancer = load_balancer
self._round_robin_index: dict[str, int] = {}
async def register(self, capability: AgentCapability, endpoint: str) -> str:
"""注册 Agent返回 agent_id。同名 Agent 已存在则更新。"""
async with self._session_factory() as db:
try:
Model = self._agent_model
stmt = type(db).execute.__self__.__class__ # placeholder
# 尝试查找已有记录
from sqlalchemy import select
stmt = select(Model).where(Model.name == capability.agent_name)
result = await db.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
existing.agent_type = capability.agent_type
existing.version = capability.version
existing.endpoint = endpoint
existing.description = capability.description
existing.capabilities = capability.to_dict()
existing.status = AgentStatus.ONLINE
existing.last_heartbeat = datetime.now(timezone.utc)
await db.commit()
await db.refresh(existing)
agent_id = existing.id
logger.info(f"Agent '{capability.agent_name}' re-registered (id={agent_id})")
else:
agent = Model(
name=capability.agent_name,
display_name=capability.agent_name.replace("_", " ").title(),
agent_type=capability.agent_type,
description=capability.description,
version=capability.version,
endpoint=endpoint,
status=AgentStatus.ONLINE,
capabilities=capability.to_dict(),
last_heartbeat=datetime.now(timezone.utc),
)
db.add(agent)
await db.commit()
await db.refresh(agent)
agent_id = agent.id
logger.info(f"Agent '{capability.agent_name}' registered (id={agent_id})")
return str(agent_id)
except Exception as e:
await db.rollback()
logger.error(f"Failed to register agent '{capability.agent_name}': {e}")
raise
async def unregister(self, agent_name: str):
"""注销 Agent设置状态为 offline"""
async with self._session_factory() as db:
try:
from sqlalchemy import select
Model = self._agent_model
stmt = select(Model).where(Model.name == agent_name)
result = await db.execute(stmt)
agent = result.scalar_one_or_none()
if not agent:
logger.warning(f"Attempted to unregister non-existent agent '{agent_name}'")
return
agent.status = AgentStatus.OFFLINE
await db.commit()
logger.info(f"Agent '{agent_name}' unregistered")
except Exception as e:
await db.rollback()
logger.error(f"Failed to unregister agent '{agent_name}': {e}")
raise
async def update_heartbeat(self, agent_name: str):
"""更新心跳时间"""
async with self._session_factory() as db:
try:
from sqlalchemy import update
Model = self._agent_model
stmt = (
update(Model)
.where(Model.name == agent_name)
.values(
last_heartbeat=datetime.now(timezone.utc),
status=AgentStatus.ONLINE,
)
)
await db.execute(stmt)
await db.commit()
except Exception as e:
await db.rollback()
logger.error(f"Failed to update heartbeat for agent '{agent_name}': {e}")
async def get_agent(self, agent_name: str) -> dict | None:
"""获取 Agent 信息"""
async with self._session_factory() as db:
from sqlalchemy import select
Model = self._agent_model
stmt = select(Model).where(Model.name == agent_name)
result = await db.execute(stmt)
agent = result.scalar_one_or_none()
if not agent:
return None
return self._agent_to_dict(agent)
async def list_agents(
self,
agent_type: str | None = None,
status: str | None = None,
) -> list[dict]:
"""列出 Agent支持按类型和状态筛选"""
async with self._session_factory() as db:
from sqlalchemy import select
Model = self._agent_model
stmt = select(Model)
if agent_type:
stmt = stmt.where(Model.agent_type == agent_type)
if status:
stmt = stmt.where(Model.status == status)
stmt = stmt.order_by(Model.created_at.desc())
result = await db.execute(stmt)
agents = result.scalars().all()
return [self._agent_to_dict(a) for a in agents]
async def get_available_agent(self, task_type: str) -> str | None:
"""根据任务类型找到可用 Agent支持负载均衡"""
async with self._session_factory() as db:
from sqlalchemy import select
Model = self._agent_model
stmt = select(Model).where(Model.status == AgentStatus.ONLINE)
result = await db.execute(stmt)
agents = result.scalars().all()
candidates = []
for agent in agents:
capabilities = agent.capabilities or {}
supported_tasks = capabilities.get("supported_tasks", [])
if task_type in supported_tasks:
candidates.append(agent)
if not candidates:
return None
# 负载均衡选择
if self._load_balancer == "round_robin":
idx = self._round_robin_index.get(task_type, 0)
selected = candidates[idx % len(candidates)]
self._round_robin_index[task_type] = idx + 1
return selected.name
elif self._load_balancer == "random":
import random
return random.choice(candidates).name
else: # least_tasks 或默认:返回第一个
return candidates[0].name
async def check_health(self):
"""检查所有 Agent 健康状态,超时标记为 offline"""
async with self._session_factory() as db:
try:
from sqlalchemy import update
Model = self._agent_model
timeout_threshold = datetime.now(timezone.utc) - timedelta(
seconds=HEARTBEAT_TIMEOUT_SECONDS
)
stmt = (
update(Model)
.where(
Model.status == AgentStatus.ONLINE,
Model.last_heartbeat < timeout_threshold,
)
.values(status=AgentStatus.OFFLINE)
)
result = await db.execute(stmt)
await db.commit()
if result.rowcount > 0:
logger.warning(
f"Marked {result.rowcount} agent(s) as offline due to heartbeat timeout"
)
except Exception as e:
await db.rollback()
logger.error(f"Failed to check agent health: {e}")
def _agent_to_dict(self, agent: Any) -> dict:
"""将 Agent ORM 对象转换为字典"""
return {
"id": str(agent.id),
"name": agent.name,
"display_name": agent.display_name,
"agent_type": agent.agent_type,
"description": agent.description,
"version": agent.version,
"endpoint": agent.endpoint,
"status": agent.status,
"capabilities": agent.capabilities,
"last_heartbeat": agent.last_heartbeat.isoformat() if agent.last_heartbeat else None,
"created_at": agent.created_at.isoformat() if agent.created_at else None,
"updated_at": agent.updated_at.isoformat() if agent.updated_at else None,
}

View File

@ -0,0 +1,17 @@
"""AgentKit Evolution - 自我进化引擎"""
from agentkit.evolution.reflector import Reflector
from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Module
from agentkit.evolution.strategy_tuner import StrategyTuner
from agentkit.evolution.ab_tester import ABTester
from agentkit.evolution.evolution_store import EvolutionStore
__all__ = [
"Reflector",
"PromptOptimizer",
"Signature",
"Module",
"StrategyTuner",
"ABTester",
"EvolutionStore",
]

View File

@ -0,0 +1,121 @@
"""ABTester - A/B 测试框架
支持配置分流比例自动收集效果指标统计显著性检验
"""
import logging
import math
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class ABTestConfig:
"""A/B 测试配置"""
test_id: str
agent_name: str
change_type: str # prompt / strategy / pipeline
control_ratio: float = 0.8 # 对照组比例
min_samples: int = 30 # 最小样本量
confidence_level: float = 0.95 # 置信度
status: str = "running" # running / completed / rolled_back
@dataclass
class ABTestResult:
"""A/B 测试结果"""
test_id: str
control_metric: float
experiment_metric: float
control_samples: int
experiment_samples: int
is_significant: bool
winner: str | None # control / experiment / None
p_value: float | None = None
class ABTester:
"""A/B 测试框架"""
def __init__(self):
self._tests: dict[str, ABTestConfig] = {}
self._results: dict[str, list[tuple[str, float]]] = {} # test_id -> [(group, metric)]
def create_test(self, config: ABTestConfig) -> None:
"""创建 A/B 测试"""
self._tests[config.test_id] = config
self._results[config.test_id] = []
logger.info(f"A/B test '{config.test_id}' created for agent '{config.agent_name}'")
def assign_group(self, test_id: str) -> str:
"""分配测试组"""
import random
config = self._tests.get(test_id)
if not config:
return "control"
return "control" if random.random() < config.control_ratio else "experiment"
def record_result(self, test_id: str, group: str, metric: float) -> None:
"""记录测试结果"""
if test_id not in self._results:
self._results[test_id] = []
self._results[test_id].append((group, metric))
async def evaluate(self, test_id: str) -> ABTestResult | None:
"""评估 A/B 测试结果"""
config = self._tests.get(test_id)
if not config:
return None
results = self._results.get(test_id, [])
control_metrics = [m for g, m in results if g == "control"]
experiment_metrics = [m for g, m in results if g == "experiment"]
if len(control_metrics) < config.min_samples or len(experiment_metrics) < config.min_samples:
return ABTestResult(
test_id=test_id,
control_metric=sum(control_metrics) / len(control_metrics) if control_metrics else 0,
experiment_metric=sum(experiment_metrics) / len(experiment_metrics) if experiment_metrics else 0,
control_samples=len(control_metrics),
experiment_samples=len(experiment_metrics),
is_significant=False,
winner=None,
)
# 简单 t-test
control_mean = sum(control_metrics) / len(control_metrics)
experiment_mean = sum(experiment_metrics) / len(experiment_metrics)
control_var = sum((m - control_mean) ** 2 for m in control_metrics) / (len(control_metrics) - 1)
experiment_var = sum((m - experiment_mean) ** 2 for m in experiment_metrics) / (len(experiment_metrics) - 1)
pooled_se = math.sqrt(control_var / len(control_metrics) + experiment_var / len(experiment_metrics))
t_stat = (experiment_mean - control_mean) / pooled_se if pooled_se > 0 else 0
# 近似 p-value (双侧)
p_value = 2 * (1 - self._normal_cdf(abs(t_stat)))
is_significant = p_value < (1 - config.confidence_level)
winner = None
if is_significant:
winner = "experiment" if experiment_mean > control_mean else "control"
return ABTestResult(
test_id=test_id,
control_metric=control_mean,
experiment_metric=experiment_mean,
control_samples=len(control_metrics),
experiment_samples=len(experiment_metrics),
is_significant=is_significant,
winner=winner,
p_value=p_value,
)
@staticmethod
def _normal_cdf(x: float) -> float:
"""标准正态分布 CDF 近似"""
return 0.5 * (1 + math.erf(x / math.sqrt(2)))

View File

@ -0,0 +1,113 @@
"""EvolutionStore - 进化日志存储"""
import logging
from datetime import datetime
from typing import Any
from agentkit.core.protocol import EvolutionEvent
logger = logging.getLogger(__name__)
class EvolutionStore:
"""进化日志存储
记录 Agent 的自我进化变更支持回滚
"""
def __init__(self, session_factory: Any, evolution_model: Any):
self._session_factory = session_factory
self._evolution_model = evolution_model
async def record(self, event: EvolutionEvent) -> str:
"""记录进化事件"""
async with self._session_factory() as db:
try:
import uuid
Model = self._evolution_model
entry = Model(
id=uuid.uuid4(),
agent_name=event.agent_name,
change_type=event.change_type,
before=event.before,
after=event.after,
metrics=event.metrics,
status="active",
)
db.add(entry)
await db.commit()
await db.refresh(entry)
event_id = str(entry.id)
event.event_id = event_id
logger.info(f"Evolution event recorded: {event_id} for agent '{event.agent_name}'")
return event_id
except Exception as e:
await db.rollback()
logger.error(f"Failed to record evolution event: {e}")
raise
async def rollback(self, event_id: str) -> bool:
"""回滚进化事件"""
async with self._session_factory() as db:
try:
import uuid
from sqlalchemy import select
Model = self._evolution_model
stmt = select(Model).where(Model.id == uuid.UUID(event_id))
result = await db.execute(stmt)
entry = result.scalar_one_or_none()
if not entry:
logger.error(f"Evolution event {event_id} not found")
return False
entry.status = "rolled_back"
await db.commit()
logger.info(f"Evolution event {event_id} rolled back")
return True
except Exception as e:
await db.rollback()
logger.error(f"Failed to rollback evolution event {event_id}: {e}")
return False
async def list_events(
self,
agent_name: str | None = None,
change_type: str | None = None,
status: str | None = None,
) -> list[dict]:
"""列出进化事件"""
async with self._session_factory() as db:
try:
from sqlalchemy import select
Model = self._evolution_model
stmt = select(Model)
if agent_name:
stmt = stmt.where(Model.agent_name == agent_name)
if change_type:
stmt = stmt.where(Model.change_type == change_type)
if status:
stmt = stmt.where(Model.status == status)
stmt = stmt.order_by(Model.created_at.desc())
result = await db.execute(stmt)
entries = result.scalars().all()
return [
{
"id": str(e.id),
"agent_name": e.agent_name,
"change_type": e.change_type,
"before": e.before,
"after": e.after,
"metrics": e.metrics,
"status": e.status,
"created_at": e.created_at.isoformat() if e.created_at else None,
}
for e in entries
]
except Exception as e:
logger.error(f"Failed to list evolution events: {e}")
return []

View File

@ -0,0 +1,151 @@
"""PromptOptimizer - DSPy 风格的 Prompt 自动优化器
核心概念
- Signature: 定义输入/输出 schema
- Module: 可组合的 Prompt 策略
- Optimizer: 从任务结果中自动优化 Prompt
"""
import logging
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class Signature:
"""Prompt 签名 - 定义输入/输出字段"""
input_fields: dict[str, str] # name -> description
output_fields: dict[str, str] # name -> description
instruction: str = ""
def to_prompt_prefix(self) -> str:
parts = []
if self.instruction:
parts.append(self.instruction)
parts.append("Inputs:")
for name, desc in self.input_fields.items():
parts.append(f" - {name}: {desc}")
parts.append("Outputs:")
for name, desc in self.output_fields.items():
parts.append(f" - {name}: {desc}")
return "\n".join(parts)
@dataclass
class Module:
"""可组合的 Prompt 策略模块"""
name: str
signature: Signature
template: str = ""
demos: list[dict[str, Any]] = field(default_factory=list)
def render(self, **kwargs) -> str:
parts = []
parts.append(self.signature.to_prompt_prefix())
if self.demos:
parts.append("\nExamples:")
for demo in self.demos:
parts.append(f"\nInput: {demo.get('input', '')}")
parts.append(f"Output: {demo.get('output', '')}")
if self.template:
parts.append(f"\n{self.template.format(**kwargs)}")
return "\n".join(parts)
class PromptOptimizer:
"""DSPy 风格的 Prompt 自动优化器
从成功案例中自动构建 few-shot 示例优化 Prompt 指令
"""
def __init__(
self,
max_demos: int = 5,
min_examples_for_optimization: int = 3,
):
self._max_demos = max_demos
self._min_examples = min_examples_for_optimization
self._success_examples: list[dict[str, Any]] = []
self._failure_examples: list[dict[str, Any]] = []
def add_example(
self,
input_data: dict,
output_data: dict,
quality_score: float,
) -> None:
"""添加训练样本"""
example = {
"input": input_data,
"output": output_data,
"quality_score": quality_score,
}
if quality_score >= 0.7:
self._success_examples.append(example)
else:
self._failure_examples.append(example)
async def optimize(self, module: Module) -> Module:
"""优化 Module 的 Prompt
BootstrapFewShot: 从成功案例中自动构建 few-shot 示例
"""
if len(self._success_examples) < self._min_examples:
logger.info(
f"Not enough examples for optimization "
f"({len(self._success_examples)}/{self._min_examples})"
)
return module
# 选择质量最高的成功案例作为 demo
sorted_examples = sorted(
self._success_examples,
key=lambda x: x["quality_score"],
reverse=True,
)
best_demos = sorted_examples[:self._max_demos]
# 构建 few-shot 示例
demos = []
for example in best_demos:
demos.append({
"input": str(example["input"]),
"output": str(example["output"]),
})
# 优化指令(基于失败案例的反面教材)
optimized_instruction = module.signature.instruction
if self._failure_examples:
failure_patterns = set()
for ex in self._failure_examples[-3:]:
failure_patterns.add(str(ex["input"])[:100])
if failure_patterns:
optimized_instruction += (
f"\n\nAvoid these patterns:\n"
+ "\n".join(f"- {p}" for p in failure_patterns)
)
# 创建优化后的 Module
optimized = Module(
name=f"{module.name}_optimized",
signature=Signature(
input_fields=module.signature.input_fields,
output_fields=module.signature.output_fields,
instruction=optimized_instruction,
),
template=module.template,
demos=demos,
)
logger.info(
f"Optimized module '{module.name}': "
f"{len(demos)} demos, instruction length {len(optimized_instruction)}"
)
return optimized
@property
def example_count(self) -> tuple[int, int]:
return len(self._success_examples), len(self._failure_examples)

View File

@ -0,0 +1,147 @@
"""Reflector - 执行反思
每次任务完成后自动评估结果提取模式生成反思总结
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
logger = logging.getLogger(__name__)
@dataclass
class Reflection:
"""反思结果"""
task_id: str
agent_name: str
outcome: str # success / failure / partial
quality_score: float # 0.0 - 1.0
patterns: list[str] = field(default_factory=list)
insights: list[str] = field(default_factory=list)
suggestions: list[str] = field(default_factory=list)
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
class Reflector:
"""执行反思器
评估任务结果提取成功/失败模式生成改进建议
"""
def __init__(
self,
quality_scorer: Any | None = None,
pattern_extractor: Any | None = None,
):
self._quality_scorer = quality_scorer
self._pattern_extractor = pattern_extractor
async def reflect(self, task: TaskMessage, result: TaskResult) -> Reflection:
"""对任务执行结果进行反思"""
# 判断结果
outcome = "success" if result.status == TaskStatus.COMPLETED else "failure"
# 质量评分
quality_score = await self._score_quality(task, result)
# 提取模式
patterns = await self._extract_patterns(task, result, outcome)
# 生成洞察
insights = self._generate_insights(outcome, quality_score, patterns)
# 生成建议
suggestions = self._generate_suggestions(outcome, quality_score, patterns)
reflection = Reflection(
task_id=task.task_id,
agent_name=result.agent_name,
outcome=outcome,
quality_score=quality_score,
patterns=patterns,
insights=insights,
suggestions=suggestions,
)
logger.info(
f"Reflection for task {task.task_id}: outcome={outcome}, "
f"quality={quality_score:.2f}, patterns={len(patterns)}"
)
return reflection
async def _score_quality(self, task: TaskMessage, result: TaskResult) -> float:
"""评估任务质量"""
if result.status != TaskStatus.COMPLETED:
return 0.0
if self._quality_scorer:
return await self._quality_scorer(task, result)
# 默认评分逻辑
score = 0.5 # 基础分
# 有输出数据加分
if result.output_data:
score += 0.2
# 无错误加分
if not result.error_message:
score += 0.1
# 耗时合理加分
if result.metrics and result.metrics.get("elapsed_seconds", 0) < 30:
score += 0.2
return min(score, 1.0)
async def _extract_patterns(
self, task: TaskMessage, result: TaskResult, outcome: str
) -> list[str]:
"""提取模式"""
patterns = []
if outcome == "failure":
if result.error_message:
patterns.append(f"error_type:{type(result.error_message).__name__}")
if result.metrics and result.metrics.get("elapsed_seconds", 0) > 60:
patterns.append("slow_execution")
else:
if result.output_data and len(result.output_data) > 5:
patterns.append("rich_output")
if result.metrics and result.metrics.get("elapsed_seconds", 0) < 10:
patterns.append("fast_execution")
return patterns
def _generate_insights(
self, outcome: str, quality_score: float, patterns: list[str]
) -> list[str]:
"""生成洞察"""
insights = []
if quality_score < 0.3:
insights.append("Low quality score indicates potential issues with task execution")
if "slow_execution" in patterns:
insights.append("Slow execution may benefit from strategy optimization")
if "error_type:TimeoutError" in patterns:
insights.append("Timeout errors suggest need for longer timeout or task decomposition")
return insights
def _generate_suggestions(
self, outcome: str, quality_score: float, patterns: list[str]
) -> list[str]:
"""生成改进建议"""
suggestions = []
if outcome == "failure" and quality_score < 0.3:
suggestions.append("Consider prompt optimization for this task type")
if "slow_execution" in patterns:
suggestions.append("Consider adjusting strategy parameters for faster execution")
return suggestions

View File

@ -0,0 +1,81 @@
"""StrategyTuner - 策略调优
自动调整 Agent 参数temperature, tool 选择权重, Pipeline 路径
"""
import logging
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
@dataclass
class StrategyConfig:
"""策略配置"""
temperature: float = 0.5
tool_weights: dict[str, float] = field(default_factory=dict)
max_iterations: int = 5
timeout_seconds: int = 300
class StrategyTuner:
"""策略调优器
基于历史效果数据自动调整 Agent 参数
"""
def __init__(self, param_ranges: dict[str, tuple[float, float]] | None = None):
self._param_ranges = param_ranges or {
"temperature": (0.0, 1.0),
"max_iterations": (1, 10),
}
self._history: list[dict[str, Any]] = []
def record(self, config: StrategyConfig, metric: float) -> None:
"""记录配置和对应的效果指标"""
self._history.append({
"config": config,
"metric": metric,
})
async def suggest(self, current: StrategyConfig) -> StrategyConfig:
"""基于历史数据建议新的策略配置"""
if len(self._history) < 3:
logger.info("Not enough history for strategy tuning")
return current
# 找到效果最好的配置
best = max(self._history, key=lambda x: x["metric"])
best_config = best["config"]
best_metric = best["metric"]
# 在最佳配置附近微调
suggested = StrategyConfig(
temperature=self._clamp(
best_config.temperature + self._small_perturbation(),
*self._param_ranges.get("temperature", (0.0, 1.0)),
),
tool_weights=dict(best_config.tool_weights),
max_iterations=int(self._clamp(
best_config.max_iterations + self._small_perturbation(),
*self._param_ranges.get("max_iterations", (1, 10)),
)),
timeout_seconds=current.timeout_seconds,
)
logger.info(
f"Strategy suggestion: temperature {current.temperature:.2f} -> {suggested.temperature:.2f}, "
f"max_iterations {current.max_iterations} -> {suggested.max_iterations}"
)
return suggested
@staticmethod
def _small_perturbation() -> float:
import random
return random.uniform(-0.1, 0.1)
@staticmethod
def _clamp(value: float, min_val: float, max_val: float) -> float:
return max(min_val, min(max_val, value))

View File

@ -0,0 +1,6 @@
"""AgentKit MCP - Model Context Protocol 支持"""
__all__ = [
"MCPServer",
"MCPClient",
]

View File

@ -0,0 +1,83 @@
"""MCP Client - 调用外部 MCP 工具服务器"""
import logging
from typing import Any
import httpx
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class MCPClient:
"""MCP Client - 连接外部 MCP Server 并调用工具"""
def __init__(self, server_url: str, timeout: int = 30):
self._server_url = server_url.rstrip("/")
self._timeout = timeout
self._tools_cache: list[dict] | None = None
async def list_tools(self) -> list[dict]:
"""列出远程 MCP Server 上的工具"""
async with httpx.AsyncClient(timeout=self._timeout) as client:
response = await client.get(f"{self._server_url}/tools/list")
response.raise_for_status()
data = response.json()
self._tools_cache = data.get("tools", [])
return self._tools_cache
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
"""调用远程 MCP 工具"""
async with httpx.AsyncClient(timeout=self._timeout) as client:
response = await client.post(
f"{self._server_url}/tools/call",
json={"name": tool_name, "arguments": arguments},
)
response.raise_for_status()
return response.json()
def as_tool(self, tool_name: str, description: str = "") -> "MCPTool":
"""将远程 MCP 工具包装为本地 Tool 对象"""
return MCPTool(
name=tool_name,
description=description,
client=self,
)
class MCPTool(Tool):
"""MCP 工具 - 通过 MCP Client 调用远程工具"""
def __init__(
self,
name: str,
description: str,
client: MCPClient,
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description,
input_schema=input_schema,
output_schema=output_schema,
version=version,
tags=tags or ["mcp"],
)
self._client = client
async def execute(self, **kwargs) -> dict:
result = await self._client.call_tool(self.name, kwargs)
# 解析 MCP 响应格式
if "content" in result:
for item in result["content"]:
if item.get("type") == "text":
import json
try:
return json.loads(item["text"])
except json.JSONDecodeError:
return {"result": item["text"]}
return result

View File

@ -0,0 +1,86 @@
"""MCP Server - 将 Agent 能力暴露为 MCP 工具
基于 FastAPI 实现支持 Streamable HTTP 传输
"""
import logging
from typing import Any
logger = logging.getLogger(__name__)
class MCPServer:
"""MCP Server - 暴露 Agent 能力为 MCP 工具
自动将 ToolRegistry 中注册的工具暴露为 MCP 工具端点
"""
def __init__(self, tool_registry: Any = None, host: str = "0.0.0.0", port: int = 8080):
self._tool_registry = tool_registry
self._host = host
self._port = port
self._app = None
def _create_app(self):
"""创建 FastAPI 应用"""
try:
from fastapi import FastAPI
except ImportError:
raise ImportError("MCP Server requires fastapi: pip install fischer-agentkit[mcp]")
app = FastAPI(title="Fischer AgentKit MCP Server")
@app.get("/tools/list")
async def list_tools():
if self._tool_registry is None:
return {"tools": []}
tools = self._tool_registry.list_tools()
return {
"tools": [
{
"name": t.name,
"description": t.description,
"inputSchema": t.input_schema or {},
}
for t in tools
]
}
@app.post("/tools/call")
async def call_tool(request: dict):
tool_name = request.get("name")
arguments = request.get("arguments", {})
if not tool_name or self._tool_registry is None:
return {"error": "Tool not specified or registry not configured"}
try:
tool = self._tool_registry.get(tool_name)
result = await tool.safe_execute(**arguments)
return {"content": [{"type": "text", "text": str(result)}]}
except Exception as e:
return {"isError": True, "content": [{"type": "text", "text": str(e)}]}
@app.get("/health")
async def health():
return {"status": "ok"}
return app
async def start(self):
"""启动 MCP Server"""
self._app = self._create_app()
try:
import uvicorn
config = uvicorn.Config(self._app, host=self._host, port=self._port, log_level="info")
server = uvicorn.Server(config)
await server.serve()
except ImportError:
raise ImportError("MCP Server requires uvicorn: pip install uvicorn")
def get_app(self):
"""获取 FastAPI 应用实例(用于测试或嵌入)"""
if self._app is None:
self._app = self._create_app()
return self._app

View File

@ -0,0 +1,16 @@
"""AgentKit Memory - 记忆系统"""
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.retriever import MemoryRetriever
__all__ = [
"Memory",
"MemoryItem",
"WorkingMemory",
"EpisodicMemory",
"SemanticMemory",
"MemoryRetriever",
]

View File

@ -0,0 +1,74 @@
"""Memory 抽象基类 - 统一记忆接口"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class MemoryItem:
"""记忆条目"""
key: str
value: Any
metadata: dict[str, Any] = field(default_factory=dict)
score: float = 1.0
created_at: datetime = field(default_factory=lambda: datetime.utcnow())
def to_dict(self) -> dict:
return {
"key": self.key,
"value": self.value,
"metadata": self.metadata,
"score": self.score,
"created_at": self.created_at.isoformat(),
}
class Memory(ABC):
"""记忆抽象基类
三层记忆系统的统一接口
- WorkingMemory: 当前任务上下文Redis, 短生命周期
- EpisodicMemory: 任务经验pgvector+PG, 永久
- SemanticMemory: 知识库RAG+Graph, 永久
"""
@abstractmethod
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
"""存储记忆"""
...
@abstractmethod
async def retrieve(self, key: str) -> MemoryItem | None:
"""按 key 精确检索"""
...
@abstractmethod
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
"""语义检索"""
...
@abstractmethod
async def delete(self, key: str) -> bool:
"""删除记忆"""
...
async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None:
"""批量存储"""
for key, value, metadata in items:
await self.store(key, value, metadata)
async def get_context(self, query: str, token_budget: int = 3000) -> str:
"""获取格式化的上下文字符串(用于注入 Prompt"""
items = await self.search(query, top_k=10)
context_parts = []
total_tokens = 0
for item in items:
text = str(item.value)
estimated_tokens = len(text) // 4 # 粗略估算
if total_tokens + estimated_tokens > token_budget:
break
context_parts.append(text)
total_tokens += estimated_tokens
return "\n".join(context_parts)

View File

@ -0,0 +1,149 @@
"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆"""
import logging
import math
from datetime import datetime
from typing import Any
from agentkit.memory.base import Memory, MemoryItem
logger = logging.getLogger(__name__)
class EpisodicMemory(Memory):
"""Episodic Memory - 记录每次任务的输入/输出/效果/反思
基于 pgvector + PostgreSQL 实现支持语义检索和时间衰减
生命周期永久可配置衰减
"""
def __init__(
self,
session_factory: Any,
episodic_model: Any,
embedder: Any | None = None,
decay_rate: float = 0.01,
):
"""
Args:
session_factory: 返回 async context manager 的工厂
episodic_model: EpisodicMemory ORM 模型类
embedder: 嵌入器用于生成向量
decay_rate: 时间衰减率越大衰减越快
"""
self._session_factory = session_factory
self._episodic_model = episodic_model
self._embedder = embedder
self._decay_rate = decay_rate
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
"""存储任务经验"""
async with self._session_factory() as db:
try:
Model = self._episodic_model
meta = metadata or {}
# 生成 embedding
embedding = None
if self._embedder:
text = f"{key} {value}"
embedding = await self._embedder.embed(text)
entry = Model(
agent_name=meta.get("agent_name", ""),
task_type=meta.get("task_type", ""),
input_summary=str(value)[:500] if value else "",
output_summary=meta.get("output_summary", ""),
outcome=meta.get("outcome", "success"),
quality_score=meta.get("quality_score", 0.5),
reflection=meta.get("reflection", ""),
embedding=embedding,
)
db.add(entry)
await db.commit()
except Exception as e:
await db.rollback()
logger.error(f"Failed to store episodic memory: {e}")
raise
async def retrieve(self, key: str) -> MemoryItem | None:
"""按 key 精确检索Episodic Memory 通常不按 key 检索)"""
return None
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
"""语义检索相似历史案例"""
async with self._session_factory() as db:
try:
Model = self._episodic_model
filters = filters or {}
# 构建查询
from sqlalchemy import select, text as sql_text
stmt = select(Model)
if filters.get("agent_name"):
stmt = stmt.where(Model.agent_name == filters["agent_name"])
if filters.get("task_type"):
stmt = stmt.where(Model.task_type == filters["task_type"])
if filters.get("outcome"):
stmt = stmt.where(Model.outcome == filters["outcome"])
stmt = stmt.order_by(Model.created_at.desc()).limit(top_k * 2)
result = await db.execute(stmt)
entries = result.scalars().all()
# 如果有 embedder进行向量相似度排序
if self._embedder and entries:
query_embedding = await self._embedder.embed(query)
# TODO: 使用 pgvector 的 cosine distance 排序
# 目前按时间衰减排序
# 时间衰减排序
items = []
for entry in entries:
age_hours = (datetime.utcnow() - entry.created_at).total_seconds() / 3600 if entry.created_at else 0
decay = math.exp(-self._decay_rate * age_hours)
score = (entry.quality_score or 0.5) * decay
items.append(MemoryItem(
key=str(entry.id),
value={
"input_summary": entry.input_summary,
"output_summary": entry.output_summary,
"outcome": entry.outcome,
"quality_score": entry.quality_score,
"reflection": entry.reflection,
},
metadata={
"agent_name": entry.agent_name,
"task_type": entry.task_type,
"created_at": entry.created_at.isoformat() if entry.created_at else None,
},
score=score,
created_at=entry.created_at or datetime.utcnow(),
))
items.sort(key=lambda x: x.score, reverse=True)
return items[:top_k]
except Exception as e:
logger.error(f"Failed to search episodic memory: {e}")
return []
async def delete(self, key: str) -> bool:
"""删除指定经验"""
async with self._session_factory() as db:
try:
from sqlalchemy import select, delete as sql_delete
import uuid
Model = self._episodic_model
stmt = sql_delete(Model).where(Model.id == uuid.UUID(key))
await db.execute(stmt)
await db.commit()
return True
except Exception as e:
await db.rollback()
logger.error(f"Failed to delete episodic memory: {e}")
return False

View File

@ -0,0 +1,113 @@
"""MemoryRetriever - 混合检索器
并行查询三层记忆按权重融合排序
"""
import asyncio
import logging
import math
from datetime import datetime
from typing import Any
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory
logger = logging.getLogger(__name__)
class MemoryRetriever:
"""混合检索器 - 并行查询三层记忆,按权重融合排序
检索策略
1. 并行查询 Working/Episodic/Semantic 三层
2. 按权重融合排序默认 Working 0.2, Episodic 0.4, Semantic 0.4
3. 时间衰减越久远的记忆权重越低
4. 上下文窗口管理 token 不超过预算
"""
def __init__(
self,
working_memory: WorkingMemory | None = None,
episodic_memory: EpisodicMemory | None = None,
semantic_memory: SemanticMemory | None = None,
weights: dict[str, float] | None = None,
):
self._working = working_memory
self._episodic = episodic_memory
self._semantic = semantic_memory
self._weights = weights or {
"working": 0.2,
"episodic": 0.4,
"semantic": 0.4,
}
async def retrieve(
self,
query: str,
top_k: int = 5,
token_budget: int = 3000,
filters: dict[str, Any] | None = None,
) -> list[MemoryItem]:
"""混合检索三层记忆"""
tasks = []
layer_names = []
if self._working:
tasks.append(self._working.search(query, top_k=top_k, filters=filters))
layer_names.append("working")
if self._episodic:
tasks.append(self._episodic.search(query, top_k=top_k, filters=filters))
layer_names.append("episodic")
if self._semantic:
tasks.append(self._semantic.search(query, top_k=top_k, filters=filters))
layer_names.append("semantic")
if not tasks:
return []
# 并行查询
results = await asyncio.gather(*tasks, return_exceptions=True)
# 融合排序
all_items = []
for layer_name, result in zip(layer_names, results):
if isinstance(result, Exception):
logger.error(f"Memory search failed for {layer_name}: {result}")
continue
weight = self._weights.get(layer_name, 0.3)
for item in result:
item.score *= weight
all_items.append(item)
# 按分数排序
all_items.sort(key=lambda x: x.score, reverse=True)
# Token 预算管理
selected = []
total_tokens = 0
for item in all_items:
text = str(item.value)
estimated_tokens = len(text) // 4
if total_tokens + estimated_tokens > token_budget:
continue
selected.append(item)
total_tokens += estimated_tokens
if len(selected) >= top_k:
break
return selected
async def get_context_string(
self,
query: str,
top_k: int = 5,
token_budget: int = 3000,
) -> str:
"""获取格式化的上下文字符串"""
items = await self.retrieve(query, top_k, token_budget)
parts = []
for item in items:
parts.append(str(item.value))
return "\n\n".join(parts)

View File

@ -0,0 +1,94 @@
"""Semantic Memory - 知识库适配器
适配器模式对接外部 RAG 服务和知识图谱
"""
import logging
from typing import Any
from agentkit.memory.base import Memory, MemoryItem
logger = logging.getLogger(__name__)
class SemanticMemory(Memory):
"""Semantic Memory - 知识库检索
通过适配器对接外部 RAG 服务不直接依赖具体实现
"""
def __init__(
self,
rag_service: Any = None,
graph_service: Any = None,
knowledge_base_ids: list[str] | None = None,
):
"""
Args:
rag_service: RAG 检索服务需提供 search 方法
graph_service: 知识图谱服务需提供 query 方法
knowledge_base_ids: 默认检索的知识库 ID 列表
"""
self._rag_service = rag_service
self._graph_service = graph_service
self._knowledge_base_ids = knowledge_base_ids or []
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
"""Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法"""
if self._rag_service and hasattr(self._rag_service, 'ingest'):
await self._rag_service.ingest(key, value, metadata)
else:
logger.warning("SemanticMemory.store: no RAG service configured for writing")
async def retrieve(self, key: str) -> MemoryItem | None:
"""按 key 精确检索Semantic Memory 通常不按 key 检索)"""
return None
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
"""语义检索知识库"""
items = []
# RAG 检索
if self._rag_service:
try:
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k)
for r in results:
items.append(MemoryItem(
key=r.get("id", ""),
value=r.get("content", ""),
metadata={
"source": r.get("source", "rag"),
"score": r.get("score", 0.0),
"document_id": r.get("document_id"),
},
score=r.get("score", 0.0),
))
except Exception as e:
logger.error(f"RAG search failed: {e}")
# 知识图谱检索
if self._graph_service:
try:
graph_results = await self._graph_service.query(query, depth=2)
for r in graph_results[:top_k]:
items.append(MemoryItem(
key=r.get("id", ""),
value=r.get("content", ""),
metadata={
"source": "graph",
"entities": r.get("entities", []),
"relations": r.get("relations", []),
},
score=r.get("score", 0.0),
))
except Exception as e:
logger.error(f"Graph search failed: {e}")
items.sort(key=lambda x: x.score, reverse=True)
return items[:top_k]
async def delete(self, key: str) -> bool:
"""Semantic Memory 通常只读"""
logger.warning("SemanticMemory.delete: read-only memory")
return False

View File

@ -0,0 +1,97 @@
"""Working Memory - 基于 Redis 的短期任务记忆"""
import json
import logging
from datetime import datetime
from typing import Any
import redis.asyncio as aioredis
from agentkit.memory.base import Memory, MemoryItem
logger = logging.getLogger(__name__)
class WorkingMemory(Memory):
"""Working Memory - 当前任务的上下文和中间状态
基于 Redis 实现支持自动过期TTL
生命周期单次任务
"""
def __init__(
self,
redis: aioredis.Redis,
key_prefix: str = "agentkit:working",
default_ttl: int = 3600,
):
self._redis = redis
self._key_prefix = key_prefix
self._default_ttl = default_ttl
def _make_key(self, key: str) -> str:
return f"{self._key_prefix}:{key}"
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
redis_key = self._make_key(key)
item = MemoryItem(
key=key,
value=value,
metadata=metadata or {},
created_at=datetime.utcnow(),
)
await self._redis.setex(
redis_key,
self._default_ttl,
json.dumps(item.to_dict(), default=str),
)
async def retrieve(self, key: str) -> MemoryItem | None:
redis_key = self._make_key(key)
data = await self._redis.get(redis_key)
if data is None:
return None
item_dict = json.loads(data)
return MemoryItem(
key=item_dict["key"],
value=item_dict["value"],
metadata=item_dict.get("metadata", {}),
score=item_dict.get("score", 1.0),
created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.utcnow(),
)
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
"""Working Memory 不支持语义检索,按 key 前缀匹配"""
pattern = self._make_key(f"{query}*")
keys = []
async for key in self._redis.scan_iter(match=pattern, count=top_k):
keys.append(key)
if len(keys) >= top_k:
break
items = []
for key in keys:
data = await self._redis.get(key)
if data:
item_dict = json.loads(data)
items.append(MemoryItem(
key=item_dict["key"],
value=item_dict["value"],
metadata=item_dict.get("metadata", {}),
score=1.0,
created_at=datetime.utcnow(),
))
return items
async def delete(self, key: str) -> bool:
redis_key = self._make_key(key)
return bool(await self._redis.delete(redis_key))
async def clear(self, prefix: str = "") -> int:
"""清除指定前缀的所有 Working Memory"""
pattern = self._make_key(f"{prefix}*")
count = 0
async for key in self._redis.scan_iter(match=pattern):
await self._redis.delete(key)
count += 1
return count

View File

@ -0,0 +1,17 @@
"""AgentKit Orchestrator - 多 Agent 协同编排"""
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_loader import PipelineLoader
from agentkit.orchestrator.handoff import HandoffManager
from agentkit.orchestrator.dynamic_pipeline import DynamicPipeline
__all__ = [
"Pipeline",
"PipelineStage",
"StageStatus",
"PipelineEngine",
"PipelineLoader",
"HandoffManager",
"DynamicPipeline",
]

View File

@ -0,0 +1,92 @@
"""DynamicPipeline - 动态 Pipeline 组合
支持运行时根据条件选择子流程嵌套 Pipeline循环 Pipeline
"""
import logging
from typing import Any
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus
logger = logging.getLogger(__name__)
class DynamicPipeline:
"""动态 Pipeline 组合器"""
def __init__(self, engine: PipelineEngine, loader: Any = None):
self._engine = engine
self._loader = loader
async def execute_conditional(
self,
pipelines: dict[str, Pipeline],
condition_key: str,
context: dict[str, Any] | None = None,
) -> PipelineResult:
"""根据条件选择子 Pipeline 执行"""
context = context or {}
condition_value = context.get(condition_key)
if condition_value not in pipelines:
return PipelineResult(
pipeline_name=f"conditional_{condition_key}",
status=StageStatus.FAILED,
error_message=f"No pipeline for condition '{condition_key}={condition_value}'",
)
selected = pipelines[condition_value]
logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}")
return await self._engine.execute(selected, context)
async def execute_nested(
self,
parent: Pipeline,
sub_pipeline_map: dict[str, Pipeline],
context: dict[str, Any] | None = None,
) -> PipelineResult:
"""执行嵌套 Pipeline"""
# 先执行父 Pipeline
parent_result = await self._engine.execute(parent, context)
# 根据父 Pipeline 结果选择子 Pipeline
for stage_name, stage_result in parent_result.stage_results.items():
if hasattr(stage_result, 'output_data') and stage_result.output_data:
sub_pipeline_name = stage_result.output_data.get("sub_pipeline")
if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map:
sub = sub_pipeline_map[sub_pipeline_name]
sub_result = await self._engine.execute(sub, parent_result.variables)
parent_result.variables.update(sub_result.variables)
return parent_result
async def execute_loop(
self,
pipeline: Pipeline,
max_iterations: int = 5,
exit_condition: str = "done",
context: dict[str, Any] | None = None,
) -> PipelineResult:
"""循环执行 Pipeline 直到条件满足"""
current_context = context or {}
last_result = None
for i in range(max_iterations):
logger.info(f"DynamicPipeline loop iteration {i + 1}/{max_iterations}")
result = await self._engine.execute(pipeline, current_context)
last_result = result
# 检查退出条件
if exit_condition in result.variables and result.variables[exit_condition]:
logger.info(f"DynamicPipeline loop exited at iteration {i + 1}")
break
# 将结果作为下一轮的输入
current_context.update(result.variables)
return last_result or PipelineResult(
pipeline_name=pipeline.name,
status=StageStatus.FAILED,
error_message="Loop completed without meeting exit condition",
)

View File

@ -0,0 +1,69 @@
"""HandoffManager - Agent 间任务转交"""
import asyncio
import json
import logging
from typing import Any
from agentkit.core.protocol import HandoffMessage
logger = logging.getLogger(__name__)
class HandoffManager:
"""Handoff 管理器
通过 Redis Pub/Sub 管理 Agent 间的任务转交
"""
def __init__(self, redis: Any = None, dispatcher: Any = None):
self._redis = redis
self._dispatcher = dispatcher
self._handlers: dict[str, list[Any]] = {}
def register_handler(self, agent_name: str, handler: Any) -> None:
"""注册 Handoff 处理器"""
if agent_name not in self._handlers:
self._handlers[agent_name] = []
self._handlers[agent_name].append(handler)
async def send_handoff(self, handoff: HandoffMessage) -> None:
"""发送 Handoff 请求"""
if self._redis is None:
raise RuntimeError("Handoff requires Redis connection")
channel = f"agent:{handoff.target_agent}:handoff"
await self._redis.publish(channel, json.dumps(handoff.to_dict()))
logger.info(
f"Handoff sent: {handoff.source_agent} -> {handoff.target_agent} "
f"(task={handoff.task_id}, reason={handoff.reason})"
)
async def listen_for_handoffs(self, agent_name: str) -> None:
"""监听发往指定 Agent 的 Handoff 请求"""
if self._redis is None:
return
channel = f"agent:{agent_name}:handoff"
pubsub = self._redis.pubsub()
await pubsub.subscribe(channel)
try:
async for message in pubsub.listen():
if message["type"] == "message":
data = json.loads(message["data"])
handoff = HandoffMessage.from_dict(data)
await self._handle_handoff(handoff)
except asyncio.CancelledError:
pass
finally:
await pubsub.unsubscribe(channel)
async def _handle_handoff(self, handoff: HandoffMessage) -> None:
"""处理收到的 Handoff"""
handlers = self._handlers.get(handoff.target_agent, [])
for handler in handlers:
try:
await handler(handoff)
except Exception as e:
logger.error(f"Handoff handler error: {e}")

View File

@ -0,0 +1,241 @@
"""Pipeline Engine - DAG + 并行执行"""
import asyncio
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any
from agentkit.orchestrator.pipeline_schema import (
Pipeline,
PipelineResult,
PipelineStage,
StageResult,
StageStatus,
)
logger = logging.getLogger(__name__)
class PipelineEngine:
"""Pipeline 执行引擎
支持
- DAG 拓扑排序
- 同层并行执行asyncio.gather
- 变量解析
- 条件执行
- 重试
"""
def __init__(self, dispatcher: Any = None):
self._dispatcher = dispatcher
async def execute(
self,
pipeline: Pipeline,
context: dict[str, Any] | None = None,
) -> PipelineResult:
"""执行 Pipeline"""
result = PipelineResult(pipeline_name=pipeline.name)
result.variables = {**pipeline.variables, **(context or {})}
# 拓扑排序 + 按依赖层级分组
try:
level_groups = self._topological_group(pipeline.stages)
except ValueError as e:
result.status = StageStatus.FAILED
result.error_message = str(e)
return result
# 逐层执行
for level, stages in enumerate(level_groups):
logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")
# 并行执行同层 stages
tasks = []
for stage in stages:
tasks.append(self._execute_stage(stage, result))
stage_results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果
for stage, sr in zip(stages, stage_results):
if isinstance(sr, Exception):
sr = StageResult(
stage_name=stage.name,
status=StageStatus.FAILED,
error_message=str(sr),
)
result.stage_results[stage.name] = sr
# 收集输出变量
if sr.output_data and isinstance(sr, dict):
pass
elif hasattr(sr, 'output_data') and sr.output_data:
for output_key in stage.outputs:
if output_key in sr.output_data:
result.variables[output_key] = sr.output_data[output_key]
# 检查是否需要中止
if hasattr(sr, 'status') and sr.status == StageStatus.FAILED:
if not stage.continue_on_failure:
result.status = StageStatus.FAILED
result.error_message = f"Stage '{stage.name}' failed"
return result
result.status = StageStatus.COMPLETED
return result
async def _execute_stage(
self,
stage: PipelineStage,
pipeline_result: PipelineResult,
) -> StageResult:
"""执行单个 stage"""
started_at = datetime.now(timezone.utc).isoformat()
# 条件检查
if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables):
return StageResult(
stage_name=stage.name,
status=StageStatus.SKIPPED,
started_at=started_at,
completed_at=datetime.now(timezone.utc).isoformat(),
)
# 解析输入变量
resolved_inputs = self._resolve_variables(stage.inputs, pipeline_result.variables)
# 执行
if self._dispatcher is None:
# Dry-run 模式
return StageResult(
stage_name=stage.name,
status=StageStatus.COMPLETED,
output_data={"dry_run": True, "inputs": resolved_inputs},
started_at=started_at,
completed_at=datetime.now(timezone.utc).isoformat(),
)
# 通过 Dispatcher 分发任务
from agentkit.core.protocol import TaskMessage
import uuid
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name=stage.agent,
task_type=stage.action,
priority=0,
input_data=resolved_inputs,
callback_url=None,
created_at=datetime.now(timezone.utc),
timeout_seconds=stage.timeout_seconds,
)
try:
await self._dispatcher.dispatch(task)
# 等待结果
for _ in range(stage.timeout_seconds):
status = await self._dispatcher.get_task_status(task.task_id)
if status["status"] in ("completed", "failed", "cancelled"):
return StageResult(
stage_name=stage.name,
status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
output_data=status.get("output_data"),
error_message=status.get("error_message"),
started_at=started_at,
completed_at=datetime.now(timezone.utc).isoformat(),
)
await asyncio.sleep(1)
return StageResult(
stage_name=stage.name,
status=StageStatus.FAILED,
error_message=f"Timeout after {stage.timeout_seconds}s",
started_at=started_at,
completed_at=datetime.now(timezone.utc).isoformat(),
)
except Exception as e:
return StageResult(
stage_name=stage.name,
status=StageStatus.FAILED,
error_message=str(e),
started_at=started_at,
completed_at=datetime.now(timezone.utc).isoformat(),
)
@staticmethod
def _topological_group(stages: list[PipelineStage]) -> list[list[PipelineStage]]:
"""拓扑排序 + 按依赖层级分组"""
stage_map = {s.name: s for s in stages}
in_degree = defaultdict(int)
dependents = defaultdict(list)
for s in stages:
if s.name not in in_degree:
in_degree[s.name] = 0
for dep in s.depends_on:
if dep not in stage_map:
raise ValueError(f"Stage '{s.name}' depends on unknown stage '{dep}'")
in_degree[s.name] += 1
dependents[dep].append(s.name)
levels = []
remaining = set(in_degree.keys())
while remaining:
# 找到入度为 0 的节点
current_level = [name for name in remaining if in_degree[name] == 0]
if not current_level:
raise ValueError("Circular dependency detected in pipeline")
levels.append([stage_map[name] for name in current_level])
for name in current_level:
remaining.remove(name)
for dep in dependents[name]:
in_degree[dep] -= 1
return levels
@staticmethod
def _resolve_variables(template: dict, context: dict) -> dict:
"""解析 ${var.path} 变量引用"""
resolved = {}
for key, value in template.items():
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
var_path = value[2:-1]
resolved[key] = PipelineEngine._get_nested(context, var_path)
else:
resolved[key] = value
return resolved
@staticmethod
def _get_nested(data: dict, path: str) -> Any:
keys = path.split(".")
current = data
for key in keys:
if isinstance(current, dict):
current = current.get(key)
else:
return None
return current
@staticmethod
def _evaluate_condition(condition: str, variables: dict) -> bool:
"""简单条件评估"""
if "==" in condition:
parts = condition.split("==", 1)
left = variables.get(parts[0].strip(), parts[0].strip())
right = parts[1].strip().strip("'\"")
return str(left) == right
elif "!=" in condition:
parts = condition.split("!=", 1)
left = variables.get(parts[0].strip(), parts[0].strip())
right = parts[1].strip().strip("'\"")
return str(left) != right
else:
return bool(variables.get(condition))

View File

@ -0,0 +1,72 @@
"""Pipeline Loader - YAML 加载器"""
import logging
from pathlib import Path
from typing import Any
import yaml
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage
logger = logging.getLogger(__name__)
class PipelineLoader:
"""Pipeline YAML 加载器"""
def __init__(self, pipelines_dir: str | Path | None = None):
self._pipelines_dir = Path(pipelines_dir) if pipelines_dir else Path("pipelines")
def load(self, pipeline_name: str) -> Pipeline:
"""从 YAML 文件加载 Pipeline"""
yaml_path = self._pipelines_dir / f"{pipeline_name}.yaml"
if not yaml_path.exists():
yaml_path = self._pipelines_dir / f"{pipeline_name}.yml"
if not yaml_path.exists():
raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}")
content = yaml_path.read_text(encoding="utf-8")
return self.load_from_yaml(content, pipeline_name)
def load_from_yaml(self, yaml_content: str, pipeline_name: str | None = None) -> Pipeline:
"""从 YAML 字符串加载 Pipeline"""
data = yaml.safe_load(yaml_content)
stages = []
for stage_data in data.get("stages", []):
stages.append(PipelineStage(**stage_data))
return Pipeline(
name=data.get("name", pipeline_name or "unnamed"),
version=data.get("version", "1.0.0"),
description=data.get("description", ""),
stages=stages,
variables=data.get("variables", {}),
)
@staticmethod
def validate_dag(stages: list[PipelineStage]) -> bool:
"""验证 DAG 无环"""
stage_names = {s.name for s in stages}
visited = set()
path = set()
def dfs(name: str) -> bool:
if name in path:
return False
if name in visited:
return True
path.add(name)
stage = next((s for s in stages if s.name == name), None)
if stage:
for dep in stage.depends_on:
if dep not in stage_names:
return False
if not dfs(dep):
return False
path.remove(name)
visited.add(name)
return True
return all(dfs(name) for name in stage_names)

View File

@ -0,0 +1,52 @@
"""Pipeline 数据模型"""
from enum import Enum
from typing import Any
from pydantic import BaseModel
class StageStatus(str, Enum):
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
SKIPPED = "skipped"
class PipelineStage(BaseModel):
name: str
agent: str
action: str
depends_on: list[str] = []
inputs: dict[str, Any] = {}
outputs: list[str] = []
timeout_seconds: int = 300
retry_count: int = 0
continue_on_failure: bool = False
condition: str | None = None
class Pipeline(BaseModel):
name: str
version: str
description: str
stages: list[PipelineStage]
variables: dict[str, Any] = {}
class StageResult(BaseModel):
stage_name: str
status: StageStatus = StageStatus.PENDING
output_data: dict[str, Any] | None = None
error_message: str | None = None
started_at: str | None = None
completed_at: str | None = None
class PipelineResult(BaseModel):
pipeline_name: str
status: StageStatus = StageStatus.PENDING
stage_results: dict[str, StageResult] = {}
variables: dict[str, Any] = {}
error_message: str | None = None

View File

@ -0,0 +1,9 @@
"""AgentKit Prompts - Prompt 模板系统"""
from agentkit.prompts.template import PromptTemplate
from agentkit.prompts.section import PromptSection
__all__ = [
"PromptTemplate",
"PromptSection",
]

View File

@ -0,0 +1,36 @@
"""PromptSection - 模块化 Prompt 段落"""
from dataclasses import dataclass
@dataclass
class PromptSection:
"""Prompt 段落定义
Prompt 分为 5 个标准段落支持变量注入和 Token 预算管理
"""
identity: str = ""
context: str = ""
instructions: str = ""
constraints: str = ""
output_format: str = ""
examples: str = ""
def render(self, variables: dict | None = None) -> str:
"""渲染段落,替换变量"""
text = "\n\n".join(
part for part in [
self.identity,
self.context,
self.instructions,
self.constraints,
self.output_format,
self.examples,
] if part
)
if variables:
for key, value in variables.items():
text = text.replace(f"${{{key}}}", str(value))
return text

View File

@ -0,0 +1,71 @@
"""PromptTemplate - Prompt 模板渲染"""
import logging
from typing import Any
from agentkit.prompts.section import PromptSection
logger = logging.getLogger(__name__)
class PromptTemplate:
"""Prompt 模板
支持变量注入Token 预算管理动态段落组合
"""
def __init__(
self,
sections: PromptSection,
name: str = "",
version: str = "1.0.0",
):
self._sections = sections
self.name = name
self.version = version
def render(
self,
variables: dict[str, Any] | None = None,
context_budget: int = 3000,
) -> list[dict[str, str]]:
"""渲染 Prompt 为消息列表
Returns:
[{"role": "system", "content": "..."}, {"role": "user", "content": "..."}]
"""
system_parts = []
if self._sections.identity:
system_parts.append(self._sections.identity)
if self._sections.context:
context = self._sections.context
if variables:
for key, value in variables.items():
context = context.replace(f"${{{key}}}", str(value))
system_parts.append(context)
if self._sections.constraints:
system_parts.append(self._sections.constraints)
user_parts = []
if self._sections.instructions:
instructions = self._sections.instructions
if variables:
for key, value in variables.items():
instructions = instructions.replace(f"${{{key}}}", str(value))
user_parts.append(instructions)
if self._sections.output_format:
user_parts.append(self._sections.output_format)
if self._sections.examples:
user_parts.append(self._sections.examples)
messages = []
if system_parts:
messages.append({"role": "system", "content": "\n\n".join(system_parts)})
if user_parts:
messages.append({"role": "user", "content": "\n\n".join(user_parts)})
return messages
@property
def sections(self) -> PromptSection:
return self._sections

View File

@ -0,0 +1,13 @@
"""AgentKit Tools - 工具插件系统"""
from agentkit.tools.base import Tool
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.agent_tool import AgentTool
from agentkit.tools.registry import ToolRegistry
__all__ = [
"Tool",
"FunctionTool",
"AgentTool",
"ToolRegistry",
]

View File

@ -0,0 +1,93 @@
"""AgentTool - 将 Agent 包装为 Tool"""
from typing import Any
from agentkit.tools.base import Tool
class AgentTool(Tool):
"""将另一个 Agent 包装为 Tool
通过 Dispatcher 分发任务到目标 Agent等待结果返回
"""
def __init__(
self,
name: str,
description: str,
agent_name: str,
task_type: str,
input_mapping: dict[str, str] | None = None,
output_mapping: dict[str, str] | None = None,
timeout_seconds: int = 300,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description,
version=version,
tags=tags or ["agent"],
)
self.agent_name = agent_name
self.task_type = task_type
self.input_mapping = input_mapping or {}
self.output_mapping = output_mapping or {}
self.timeout_seconds = timeout_seconds
self._dispatcher = None
def set_dispatcher(self, dispatcher: Any) -> "AgentTool":
"""注入 Dispatcher"""
self._dispatcher = dispatcher
return self
async def execute(self, **kwargs) -> dict:
if self._dispatcher is None:
raise RuntimeError(f"AgentTool '{self.name}' has no dispatcher configured")
from agentkit.core.protocol import TaskMessage
from datetime import datetime, timezone
import uuid
# 映射输入
mapped_input = {}
for target_key, source_key in self.input_mapping.items():
if source_key in kwargs:
mapped_input[target_key] = kwargs[source_key]
if not mapped_input:
mapped_input = kwargs
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name=self.agent_name,
task_type=self.task_type,
priority=0,
input_data=mapped_input,
callback_url=None,
created_at=datetime.now(timezone.utc),
timeout_seconds=self.timeout_seconds,
)
await self._dispatcher.dispatch(task)
# 等待结果
import asyncio
for _ in range(self.timeout_seconds):
status = await self._dispatcher.get_task_status(task.task_id)
if status["status"] in ("completed", "failed", "cancelled"):
if status["status"] == "completed" and status.get("output_data"):
output = status["output_data"]
# 映射输出
if self.output_mapping:
mapped_output = {}
for target_key, source_key in self.output_mapping.items():
if source_key in output:
mapped_output[target_key] = output[source_key]
return mapped_output
return output
elif status["status"] == "failed":
raise RuntimeError(f"Agent '{self.agent_name}' failed: {status.get('error_message')}")
return {}
await asyncio.sleep(1)
raise TimeoutError(f"Agent '{self.agent_name}' timed out after {self.timeout_seconds}s")

View File

@ -0,0 +1,68 @@
"""Tool 抽象基类 - 统一工具接口"""
from abc import ABC, abstractmethod
from typing import Any
class Tool(ABC):
"""工具抽象基类
所有工具FunctionTool, AgentTool, MCPTool的统一接口
"""
def __init__(
self,
name: str,
description: str,
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
self.name = name
self.description = description
self.input_schema = input_schema
self.output_schema = output_schema
self.version = version
self.tags = tags or []
@abstractmethod
async def execute(self, **kwargs) -> dict:
"""执行工具,返回结果 dict"""
...
async def before_execute(self, **kwargs) -> None:
"""执行前钩子"""
pass
async def after_execute(self, result: dict, **kwargs) -> None:
"""执行后钩子"""
pass
async def on_error(self, error: Exception, **kwargs) -> None:
"""错误钩子"""
pass
async def safe_execute(self, **kwargs) -> dict:
"""带钩子的安全执行"""
try:
await self.before_execute(**kwargs)
result = await self.execute(**kwargs)
await self.after_execute(result, **kwargs)
return result
except Exception as e:
await self.on_error(e, **kwargs)
raise
def to_dict(self) -> dict:
return {
"name": self.name,
"description": self.description,
"input_schema": self.input_schema,
"output_schema": self.output_schema,
"version": self.version,
"tags": self.tags,
}
def __repr__(self) -> str:
return f"<{type(self).__name__} name={self.name!r} version={self.version}>"

View File

@ -0,0 +1,76 @@
"""FunctionTool - 将普通 Python 函数包装为 Tool"""
import inspect
from typing import Any, Callable, Awaitable
from agentkit.tools.base import Tool
class FunctionTool(Tool):
"""将普通 Python 函数包装为 Tool
自动从函数签名推断 input_schema
"""
def __init__(
self,
name: str,
description: str,
func: Callable[..., Awaitable[dict]] | Callable[..., dict],
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description,
input_schema=input_schema or self._infer_schema(func),
output_schema=output_schema,
version=version,
tags=tags,
)
self._func = func
async def execute(self, **kwargs) -> dict:
result = self._func(**kwargs)
if inspect.isawaitable(result):
result = await result
if not isinstance(result, dict):
result = {"result": result}
return result
@staticmethod
def _infer_schema(func: Callable) -> dict:
"""从函数签名推断 JSON Schema"""
sig = inspect.signature(func)
properties = {}
required = []
for param_name, param in sig.parameters.items():
if param_name in ("self", "cls"):
continue
param_type = "string"
if param.annotation != inspect.Parameter.empty:
if param.annotation in (int, float):
param_type = "number"
elif param.annotation == bool:
param_type = "boolean"
elif param.annotation in (list, tuple):
param_type = "array"
elif param.annotation == dict:
param_type = "object"
properties[param_name] = {"type": param_type}
if param.default == inspect.Parameter.empty:
required.append(param_name)
schema = {
"type": "object",
"properties": properties,
}
if required:
schema["required"] = required
return schema

View File

@ -0,0 +1,72 @@
"""ToolRegistry - 工具注册中心"""
import logging
from typing import Any
from agentkit.core.exceptions import ToolNotFoundError
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class ToolRegistry:
"""工具注册中心,管理工具的注册、发现、版本"""
def __init__(self):
self._tools: dict[str, dict[str, Tool]] = {} # name -> {version -> tool}
def register(self, tool: Tool) -> "ToolRegistry":
"""注册工具"""
if tool.name not in self._tools:
self._tools[tool.name] = {}
self._tools[tool.name][tool.version] = tool
logger.info(f"Tool '{tool.name}' v{tool.version} registered")
return self
def unregister(self, name: str, version: str | None = None) -> None:
"""注销工具"""
if name not in self._tools:
return
if version:
self._tools[name].pop(version, None)
if not self._tools[name]:
del self._tools[name]
else:
del self._tools[name]
def get(self, name: str, version: str | None = None) -> Tool:
"""获取工具(默认返回最新版本)"""
if name not in self._tools:
raise ToolNotFoundError(name)
versions = self._tools[name]
if version:
if version not in versions:
raise ToolNotFoundError(f"{name}@{version}")
return versions[version]
# 返回最新版本
latest = sorted(versions.keys())[-1]
return versions[latest]
def list_tools(self, tag: str | None = None) -> list[Tool]:
"""列出所有工具(最新版本),可按标签过滤"""
result = []
for name, versions in self._tools.items():
latest = sorted(versions.keys())[-1]
tool = versions[latest]
if tag is None or tag in tool.tags:
result.append(tool)
return result
def list_all_versions(self, name: str) -> list[Tool]:
"""列出指定工具的所有版本"""
if name not in self._tools:
return []
return list(self._tools[name].values())
def has_tool(self, name: str) -> bool:
return name in self._tools
def clear(self) -> None:
self._tools.clear()

0
tests/__init__.py Normal file
View File

View File

0
tests/unit/__init__.py Normal file
View File

View File

@ -0,0 +1,139 @@
"""Tests for BaseAgent - 统一生命周期"""
import asyncio
import pytest
from agentkit.core.base import BaseAgent
from agentkit.core.protocol import (
AgentCapability,
AgentStatus,
TaskMessage,
TaskResult,
TaskStatus,
)
from datetime import datetime, timezone
class SimpleAgent(BaseAgent):
"""测试用简单 Agent"""
def __init__(self):
super().__init__(name="test_agent", agent_type="test", version="1.0.0")
self.task_started = False
self.task_completed = False
self.task_failed = False
async def handle_task(self, task: TaskMessage) -> dict:
if task.task_type == "echo":
return {"echo": task.input_data}
elif task.task_type == "fail":
raise ValueError("intentional failure")
return {"status": "ok"}
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
agent_name=self.name,
agent_type=self.agent_type,
version=self.version,
supported_tasks=["echo", "fail"],
max_concurrency=2,
description="Test agent",
)
async def on_task_start(self, task):
self.task_started = True
async def on_task_complete(self, task, output):
self.task_completed = True
async def on_task_failed(self, task, error):
self.task_failed = True
def _make_task(task_type: str = "echo", input_data: dict | None = None) -> TaskMessage:
return TaskMessage(
task_id="test-001",
agent_name="test_agent",
task_type=task_type,
priority=0,
input_data=input_data or {},
callback_url=None,
created_at=datetime.now(timezone.utc),
)
@pytest.mark.asyncio
async def test_handle_task_returns_output():
agent = SimpleAgent()
task = _make_task("echo", {"msg": "hello"})
result = await agent.execute(task)
assert result.status == TaskStatus.COMPLETED
assert result.output_data == {"echo": {"msg": "hello"}}
assert result.error_message is None
assert result.metrics["task_type"] == "echo"
@pytest.mark.asyncio
async def test_handle_task_failure():
agent = SimpleAgent()
task = _make_task("fail")
result = await agent.execute(task)
assert result.status == TaskStatus.FAILED
assert result.error_message == "intentional failure"
assert result.metrics["error_type"] == "ValueError"
@pytest.mark.asyncio
async def test_lifecycle_hooks():
agent = SimpleAgent()
# 成功任务
task = _make_task("echo")
await agent.execute(task)
assert agent.task_started is True
assert agent.task_completed is True
# 重置
agent.task_started = False
agent.task_completed = False
# 失败任务
task = _make_task("fail")
await agent.execute(task)
assert agent.task_started is True
assert agent.task_failed is True
@pytest.mark.asyncio
async def test_execute_wraps_timing():
agent = SimpleAgent()
task = _make_task("echo")
result = await agent.execute(task)
assert result.started_at is not None
assert result.completed_at is not None
assert result.metrics["elapsed_seconds"] >= 0
@pytest.mark.asyncio
async def test_agent_status():
agent = SimpleAgent()
assert agent.status == AgentStatus.OFFLINE
assert agent.is_distributed is False
@pytest.mark.asyncio
async def test_tool_injection():
from agentkit.tools.function_tool import FunctionTool
async def my_tool(x: int) -> dict:
return {"doubled": x * 2}
tool = FunctionTool(name="doubler", description="Doubles a number", func=my_tool)
agent = SimpleAgent()
agent.use_tool(tool)
assert len(agent.tools) == 1
assert agent.tools[0].name == "doubler"

View File

@ -0,0 +1,131 @@
"""Tests for Evolution system"""
import pytest
from agentkit.evolution.reflector import Reflector, Reflection
from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Module
from agentkit.evolution.strategy_tuner import StrategyTuner, StrategyConfig
from agentkit.evolution.ab_tester import ABTester, ABTestConfig
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from datetime import datetime, timezone
def _make_task() -> TaskMessage:
return TaskMessage(
task_id="test-001",
agent_name="test",
task_type="echo",
priority=0,
input_data={},
callback_url=None,
created_at=datetime.now(timezone.utc),
)
def _make_result(status: str = TaskStatus.COMPLETED) -> TaskResult:
return TaskResult(
task_id="test-001",
agent_name="test",
status=status,
output_data={"key": "value"},
error_message=None,
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
metrics={"elapsed_seconds": 5.0},
)
@pytest.mark.asyncio
async def test_reflector_success():
reflector = Reflector()
task = _make_task()
result = _make_result()
reflection = await reflector.reflect(task, result)
assert reflection.outcome == "success"
assert reflection.quality_score > 0
@pytest.mark.asyncio
async def test_reflector_failure():
reflector = Reflector()
task = _make_task()
result = _make_result(TaskStatus.FAILED)
result.error_message = "something went wrong"
reflection = await reflector.reflect(task, result)
assert reflection.outcome == "failure"
assert reflection.quality_score == 0.0
@pytest.mark.asyncio
async def test_prompt_optimizer():
optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=2)
# Add examples
for i in range(5):
optimizer.add_example(
input_data={"query": f"query_{i}"},
output_data={"result": f"result_{i}"},
quality_score=0.8 + i * 0.02,
)
module = Module(
name="test_module",
signature=Signature(
input_fields={"query": "search query"},
output_fields={"result": "search result"},
instruction="Find the best result.",
),
)
optimized = await optimizer.optimize(module)
assert optimized.name == "test_module_optimized"
assert len(optimized.demos) == 3
@pytest.mark.asyncio
async def test_prompt_optimizer_not_enough_examples():
optimizer = PromptOptimizer(min_examples_for_optimization=10)
module = Module(
name="test",
signature=Signature(
input_fields={"x": "input"},
output_fields={"y": "output"},
),
)
optimized = await optimizer.optimize(module)
# Should return unchanged module
assert optimized.name == "test"
def test_strategy_tuner():
tuner = StrategyTuner()
config = StrategyConfig(temperature=0.5)
tuner.record(config, metric=0.6)
tuner.record(StrategyConfig(temperature=0.7), metric=0.8)
tuner.record(StrategyConfig(temperature=0.3), metric=0.4)
@pytest.mark.asyncio
async def test_ab_tester():
tester = ABTester()
test_config = ABTestConfig(
test_id="test-1",
agent_name="test_agent",
change_type="prompt",
min_samples=5,
)
tester.create_test(test_config)
# Record results
for _ in range(10):
group = tester.assign_group("test-1")
metric = 0.7 if group == "experiment" else 0.5
tester.record_result("test-1", group, metric)
result = await tester.evaluate("test-1")
assert result is not None
assert result.control_samples + result.experiment_samples == 10

109
tests/unit/test_pipeline.py Normal file
View File

@ -0,0 +1,109 @@
"""Tests for Pipeline system"""
import pytest
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_loader import PipelineLoader
def test_pipeline_schema():
stage = PipelineStage(
name="step1",
agent="agent_a",
action="process",
depends_on=[],
inputs={"data": "${input}"},
)
pipeline = Pipeline(
name="test_pipeline",
version="1.0.0",
description="Test",
stages=[stage],
variables={"input": "hello"},
)
assert len(pipeline.stages) == 1
assert pipeline.variables["input"] == "hello"
def test_topological_group():
stages = [
PipelineStage(name="a", agent="agent1", action="do_a", depends_on=[]),
PipelineStage(name="b", agent="agent2", action="do_b", depends_on=["a"]),
PipelineStage(name="c", agent="agent3", action="do_c", depends_on=["a"]),
PipelineStage(name="d", agent="agent4", action="do_d", depends_on=["b", "c"]),
]
groups = PipelineEngine._topological_group(stages)
assert len(groups) == 3 # [a], [b, c], [d]
assert groups[0][0].name == "a"
assert {s.name for s in groups[1]} == {"b", "c"}
assert groups[2][0].name == "d"
def test_topological_group_circular():
stages = [
PipelineStage(name="a", agent="agent1", action="do_a", depends_on=["b"]),
PipelineStage(name="b", agent="agent2", action="do_b", depends_on=["a"]),
]
with pytest.raises(ValueError, match="Circular dependency"):
PipelineEngine._topological_group(stages)
@pytest.mark.asyncio
async def test_pipeline_dry_run():
pipeline = Pipeline(
name="test",
version="1.0.0",
description="Test",
stages=[
PipelineStage(name="step1", agent="agent1", action="process", inputs={"x": "1"}),
PipelineStage(name="step2", agent="agent2", action="analyze", depends_on=["step1"], inputs={"y": "2"}),
],
)
engine = PipelineEngine(dispatcher=None) # dry-run mode
result = await engine.execute(pipeline)
assert result.status.value == "completed"
def test_yaml_loader():
yaml_content = """
name: test_pipeline
version: "1.0"
description: Test pipeline
stages:
- name: step1
agent: agent1
action: process
inputs:
data: hello
- name: step2
agent: agent2
action: analyze
depends_on: [step1]
inputs:
result: ${{step1_output}}
variables:
input: world
"""
loader = PipelineLoader()
pipeline = loader.load_from_yaml(yaml_content, "test")
assert pipeline.name == "test_pipeline"
assert len(pipeline.stages) == 2
assert pipeline.stages[1].depends_on == ["step1"]
def test_validate_dag():
stages = [
PipelineStage(name="a", agent="a1", action="do", depends_on=[]),
PipelineStage(name="b", agent="a2", action="do", depends_on=["a"]),
]
assert PipelineLoader.validate_dag(stages) is True
circular = [
PipelineStage(name="a", agent="a1", action="do", depends_on=["b"]),
PipelineStage(name="b", agent="a2", action="do", depends_on=["a"]),
]
assert PipelineLoader.validate_dag(circular) is False

View File

@ -0,0 +1,95 @@
"""Tests for Protocol data structures"""
import pytest
from datetime import datetime
from agentkit.core.protocol import (
AgentCapability,
HandoffMessage,
TaskMessage,
TaskResult,
TaskStatus,
EvolutionEvent,
)
def test_task_status_values():
assert TaskStatus.PENDING == "pending"
assert TaskStatus.RUNNING == "running"
assert TaskStatus.COMPLETED == "completed"
assert TaskStatus.FAILED == "failed"
assert TaskStatus.CANCELLED == "cancelled"
assert TaskStatus.HANDOFF == "handoff"
def test_agent_capability_with_schema():
cap = AgentCapability(
agent_name="test",
agent_type="test",
version="1.0.0",
supported_tasks=["echo"],
max_concurrency=2,
description="Test",
input_schema={"type": "object", "properties": {"x": {"type": "number"}}},
output_schema={"type": "object"},
)
d = cap.to_dict()
assert "input_schema" in d
assert "output_schema" in d
restored = AgentCapability.from_dict(d)
assert restored.agent_name == "test"
assert restored.input_schema is not None
def test_task_message_roundtrip():
msg = TaskMessage(
task_id="123",
agent_name="agent1",
task_type="echo",
priority=1,
input_data={"key": "value"},
callback_url=None,
created_at=datetime.utcnow(),
conversation_id="conv-1",
)
d = msg.to_dict()
assert d["conversation_id"] == "conv-1"
restored = TaskMessage.from_dict(d)
assert restored.task_id == "123"
assert restored.conversation_id == "conv-1"
def test_handoff_message():
msg = HandoffMessage(
source_agent="agent_a",
target_agent="agent_b",
task_id="task-1",
task_type="analyze",
context={"data": "test"},
reason="Need expert analysis",
)
d = msg.to_dict()
assert d["source_agent"] == "agent_a"
assert d["target_agent"] == "agent_b"
restored = HandoffMessage.from_dict(d)
assert restored.reason == "Need expert analysis"
def test_evolution_event():
event = EvolutionEvent(
agent_name="optimizer",
change_type="prompt",
before={"instruction": "old"},
after={"instruction": "new"},
metrics={"quality_delta": 0.15},
)
d = event.to_dict()
assert d["change_type"] == "prompt"
assert d["metrics"]["quality_delta"] == 0.15

View File

@ -0,0 +1,104 @@
"""Tests for Tool system"""
import asyncio
import pytest
from agentkit.tools.base import Tool
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
async def add_numbers(a: int, b: int) -> dict:
return {"sum": a + b}
def sync_greet(name: str) -> dict:
return {"greeting": f"Hello, {name}!"}
@pytest.mark.asyncio
async def test_function_tool_async():
tool = FunctionTool(name="add", description="Add numbers", func=add_numbers)
result = await tool.execute(a=1, b=2)
assert result == {"sum": 3}
@pytest.mark.asyncio
async def test_function_tool_sync():
tool = FunctionTool(name="greet", description="Greet someone", func=sync_greet)
result = await tool.execute(name="World")
assert result == {"greeting": "Hello, World!"}
@pytest.mark.asyncio
async def test_function_tool_schema_inference():
tool = FunctionTool(name="add", description="Add numbers", func=add_numbers)
assert tool.input_schema is not None
assert "a" in tool.input_schema.get("properties", {})
assert "b" in tool.input_schema.get("properties", {})
@pytest.mark.asyncio
async def test_tool_registry():
registry = ToolRegistry()
tool = FunctionTool(name="add", description="Add numbers", func=add_numbers)
registry.register(tool)
assert registry.has_tool("add")
retrieved = registry.get("add")
assert retrieved.name == "add"
@pytest.mark.asyncio
async def test_tool_registry_versioning():
registry = ToolRegistry()
v1 = FunctionTool(name="add", description="Add v1", func=add_numbers, version="1.0.0")
v2 = FunctionTool(name="add", description="Add v2", func=add_numbers, version="2.0.0")
registry.register(v1)
registry.register(v2)
# Default returns latest
latest = registry.get("add")
assert latest.version == "2.0.0"
# Can request specific version
specific = registry.get("add", version="1.0.0")
assert specific.version == "1.0.0"
@pytest.mark.asyncio
async def test_tool_registry_list():
registry = ToolRegistry()
t1 = FunctionTool(name="add", description="Add", func=add_numbers, tags=["math"])
t2 = FunctionTool(name="greet", description="Greet", func=sync_greet, tags=["text"])
registry.register(t1)
registry.register(t2)
all_tools = registry.list_tools()
assert len(all_tools) == 2
math_tools = registry.list_tools(tag="math")
assert len(math_tools) == 1
assert math_tools[0].name == "add"
@pytest.mark.asyncio
async def test_tool_safe_execute():
async def failing_tool():
raise RuntimeError("boom")
tool = FunctionTool(name="fail", description="Always fails", func=failing_tool)
with pytest.raises(RuntimeError):
await tool.safe_execute()
@pytest.mark.asyncio
async def test_tool_not_found():
registry = ToolRegistry()
from agentkit.core.exceptions import ToolNotFoundError
with pytest.raises(ToolNotFoundError):
registry.get("nonexistent")