From 9a6d6fee4e6f17619f25bde23cb94b103fbf8c10 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 4 Jun 2026 22:24:06 +0800 Subject: [PATCH] feat: initial fischer-agentkit package with unified agent architecture - BaseAgent with handle_task() pattern (execute template moved up) - Protocol: TaskMessage, TaskResult, HandoffMessage, EvolutionEvent - Tool system: FunctionTool, AgentTool, ToolRegistry with versioning - Memory system: WorkingMemory (Redis), EpisodicMemory (pgvector), SemanticMemory (RAG adapter), MemoryRetriever (hybrid) - Evolution engine: Reflector, PromptOptimizer (DSPy-style), StrategyTuner, ABTester, EvolutionStore - Orchestrator: PipelineEngine (parallel DAG), PipelineLoader (YAML), HandoffManager, DynamicPipeline - MCP: Server (FastAPI), Client (httpx), MCPTool - Prompts: PromptTemplate, PromptSection - Exceptions: full hierarchy including Tool, Schema, Handoff, Evolution errors - Tests: unit tests for core, tools, protocol, evolution, pipeline --- pyproject.toml | 48 +++ src/agentkit/__init__.py | 25 ++ src/agentkit/core/__init__.py | 61 +++ src/agentkit/core/base.py | 395 ++++++++++++++++++ src/agentkit/core/dispatcher.py | 361 ++++++++++++++++ src/agentkit/core/exceptions.py | 110 +++++ src/agentkit/core/protocol.py | 245 +++++++++++ src/agentkit/core/registry.py | 250 +++++++++++ src/agentkit/evolution/__init__.py | 17 + src/agentkit/evolution/ab_tester.py | 121 ++++++ src/agentkit/evolution/evolution_store.py | 113 +++++ src/agentkit/evolution/prompt_optimizer.py | 151 +++++++ src/agentkit/evolution/reflector.py | 147 +++++++ src/agentkit/evolution/strategy_tuner.py | 81 ++++ src/agentkit/mcp/__init__.py | 6 + src/agentkit/mcp/client.py | 83 ++++ src/agentkit/mcp/server.py | 86 ++++ src/agentkit/memory/__init__.py | 16 + src/agentkit/memory/base.py | 74 ++++ src/agentkit/memory/episodic.py | 149 +++++++ src/agentkit/memory/retriever.py | 113 +++++ src/agentkit/memory/semantic.py | 94 +++++ src/agentkit/memory/working.py | 97 +++++ src/agentkit/orchestrator/__init__.py | 17 + src/agentkit/orchestrator/dynamic_pipeline.py | 92 ++++ src/agentkit/orchestrator/handoff.py | 69 +++ src/agentkit/orchestrator/pipeline_engine.py | 241 +++++++++++ src/agentkit/orchestrator/pipeline_loader.py | 72 ++++ src/agentkit/orchestrator/pipeline_schema.py | 52 +++ src/agentkit/prompts/__init__.py | 9 + src/agentkit/prompts/section.py | 36 ++ src/agentkit/prompts/template.py | 71 ++++ src/agentkit/tools/__init__.py | 13 + src/agentkit/tools/agent_tool.py | 93 +++++ src/agentkit/tools/base.py | 68 +++ src/agentkit/tools/function_tool.py | 76 ++++ src/agentkit/tools/registry.py | 72 ++++ tests/__init__.py | 0 tests/integration/__init__.py | 0 tests/unit/__init__.py | 0 tests/unit/test_base_agent.py | 139 ++++++ tests/unit/test_evolution.py | 131 ++++++ tests/unit/test_pipeline.py | 109 +++++ tests/unit/test_protocol.py | 95 +++++ tests/unit/test_tool_registry.py | 104 +++++ 45 files changed, 4402 insertions(+) create mode 100644 pyproject.toml create mode 100644 src/agentkit/__init__.py create mode 100644 src/agentkit/core/__init__.py create mode 100644 src/agentkit/core/base.py create mode 100644 src/agentkit/core/dispatcher.py create mode 100644 src/agentkit/core/exceptions.py create mode 100644 src/agentkit/core/protocol.py create mode 100644 src/agentkit/core/registry.py create mode 100644 src/agentkit/evolution/__init__.py create mode 100644 src/agentkit/evolution/ab_tester.py create mode 100644 src/agentkit/evolution/evolution_store.py create mode 100644 src/agentkit/evolution/prompt_optimizer.py create mode 100644 src/agentkit/evolution/reflector.py create mode 100644 src/agentkit/evolution/strategy_tuner.py create mode 100644 src/agentkit/mcp/__init__.py create mode 100644 src/agentkit/mcp/client.py create mode 100644 src/agentkit/mcp/server.py create mode 100644 src/agentkit/memory/__init__.py create mode 100644 src/agentkit/memory/base.py create mode 100644 src/agentkit/memory/episodic.py create mode 100644 src/agentkit/memory/retriever.py create mode 100644 src/agentkit/memory/semantic.py create mode 100644 src/agentkit/memory/working.py create mode 100644 src/agentkit/orchestrator/__init__.py create mode 100644 src/agentkit/orchestrator/dynamic_pipeline.py create mode 100644 src/agentkit/orchestrator/handoff.py create mode 100644 src/agentkit/orchestrator/pipeline_engine.py create mode 100644 src/agentkit/orchestrator/pipeline_loader.py create mode 100644 src/agentkit/orchestrator/pipeline_schema.py create mode 100644 src/agentkit/prompts/__init__.py create mode 100644 src/agentkit/prompts/section.py create mode 100644 src/agentkit/prompts/template.py create mode 100644 src/agentkit/tools/__init__.py create mode 100644 src/agentkit/tools/agent_tool.py create mode 100644 src/agentkit/tools/base.py create mode 100644 src/agentkit/tools/function_tool.py create mode 100644 src/agentkit/tools/registry.py create mode 100644 tests/__init__.py create mode 100644 tests/integration/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_base_agent.py create mode 100644 tests/unit/test_evolution.py create mode 100644 tests/unit/test_pipeline.py create mode 100644 tests/unit/test_protocol.py create mode 100644 tests/unit/test_tool_registry.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2869d37 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fischer-agentkit" +version = "0.1.0" +description = "Unified Agent Framework with Tool/Skill plugins, Memory, Self-Evolution, MCP support, and Multi-Agent orchestration" +readme = "README.md" +requires-python = ">=3.11" +license = {text = "MIT"} +authors = [ + {name = "Fischer Team"}, +] +dependencies = [ + "pydantic>=2.0", + "redis[hiredis]>=5.0", + "sqlalchemy[asyncio]>=2.0", + "asyncpg>=0.29", + "httpx>=0.27", + "pyyaml>=6.0", + "jsonschema>=4.0", +] + +[project.optional-dependencies] +mcp = [ + "mcp>=1.0", +] +evolution = [ + "scipy>=1.12", +] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "pytest-cov>=5.0", + "ruff>=0.4", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/agentkit"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] + +[tool.ruff] +target-version = "py311" +line-length = 100 diff --git a/src/agentkit/__init__.py b/src/agentkit/__init__.py new file mode 100644 index 0000000..eda5b75 --- /dev/null +++ b/src/agentkit/__init__.py @@ -0,0 +1,25 @@ +"""Fischer AgentKit - Unified Agent Framework""" + +from agentkit.core.base import BaseAgent +from agentkit.core.protocol import ( + AgentCapability, + AgentStatus, + HandoffMessage, + TaskMessage, + TaskProgress, + TaskResult, + TaskStatus, +) + +__version__ = "0.1.0" + +__all__ = [ + "BaseAgent", + "AgentCapability", + "AgentStatus", + "HandoffMessage", + "TaskMessage", + "TaskProgress", + "TaskResult", + "TaskStatus", +] diff --git a/src/agentkit/core/__init__.py b/src/agentkit/core/__init__.py new file mode 100644 index 0000000..6617fdc --- /dev/null +++ b/src/agentkit/core/__init__.py @@ -0,0 +1,61 @@ +"""AgentKit Core - 基础组件""" + +from agentkit.core.base import BaseAgent +from agentkit.core.exceptions import ( + AgentAlreadyRegisteredError, + AgentFrameworkError, + AgentNotFoundError, + AgentNotReadyError, + AgentUnavailableError, + ConfigValidationError, + EvolutionError, + HandoffError, + NoAvailableAgentError, + SchemaValidationError, + TaskCancelledError, + TaskDispatchError, + TaskExecutionError, + TaskNotFoundError, + TaskTimeoutError, + ToolExecutionError, + ToolNotFoundError, +) +from agentkit.core.protocol import ( + AgentCapability, + AgentStatus, + EvolutionEvent, + HandoffMessage, + TaskMessage, + TaskProgress, + TaskResult, + TaskStatus, +) + +__all__ = [ + "BaseAgent", + "AgentCapability", + "AgentStatus", + "AgentFrameworkError", + "AgentNotFoundError", + "AgentAlreadyRegisteredError", + "AgentUnavailableError", + "AgentNotReadyError", + "TaskNotFoundError", + "TaskDispatchError", + "TaskExecutionError", + "TaskTimeoutError", + "TaskCancelledError", + "NoAvailableAgentError", + "ConfigValidationError", + "SchemaValidationError", + "HandoffError", + "EvolutionError", + "ToolNotFoundError", + "ToolExecutionError", + "HandoffMessage", + "EvolutionEvent", + "TaskMessage", + "TaskProgress", + "TaskResult", + "TaskStatus", +] diff --git a/src/agentkit/core/base.py b/src/agentkit/core/base.py new file mode 100644 index 0000000..135a8d9 --- /dev/null +++ b/src/agentkit/core/base.py @@ -0,0 +1,395 @@ +"""BaseAgent 基类 - 统一 Agent 生命周期管理 + +核心设计: +- execute() 为 final 方法,包含完整的计时、try/except、TaskResult 构建 +- 子类只需实现 handle_task(task) -> dict 返回业务数据 +- 生命周期钩子:on_task_start / on_task_complete / on_task_failed +- 支持 Tool 插件、Memory 系统(可选注入) +""" + +import asyncio +import json +import logging +import time +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +import redis.asyncio as aioredis + +from agentkit.core.exceptions import AgentNotReadyError, SchemaValidationError +from agentkit.core.protocol import ( + AgentCapability, + AgentStatus, + HandoffMessage, + TaskMessage, + TaskProgress, + TaskResult, + TaskStatus, +) + +if TYPE_CHECKING: + from agentkit.memory.base import Memory + from agentkit.tools.base import Tool + +logger = logging.getLogger(__name__) + + +class BaseAgent(ABC): + """所有 Agent 的基类,定义标准生命周期。 + + 子类只需实现: + - handle_task(task) -> dict: 业务逻辑,返回 output_data + - get_capabilities() -> AgentCapability: 能力声明 + + 可选覆写: + - on_task_start(task): 任务开始前的钩子 + - on_task_complete(task, output): 任务成功后的钩子 + - on_task_failed(task, error): 任务失败后的钩子 + """ + + def __init__(self, name: str, agent_type: str, version: str = "1.0.0"): + self.name = name + self.agent_type = agent_type + self.version = version + self._status: AgentStatus = AgentStatus.OFFLINE + self._redis: aioredis.Redis | None = None + self._redis_url: str = "" + self._running_tasks: set[str] = set() + self._listen_task: asyncio.Task | None = None + self._heartbeat_task: asyncio.Task | None = None + self._semaphore: asyncio.Semaphore | None = None + + # 可插拔能力(由子类或配置注入) + self._tools: list["Tool"] = [] + self._memory: "Memory | None" = None + + # 外部依赖注入(由 start() 时设置) + self._registry = None + self._dispatcher = None + + @property + def status(self) -> AgentStatus: + return self._status + + @property + def is_distributed(self) -> bool: + return self._redis is not None + + @property + def tools(self) -> list["Tool"]: + return self._tools + + @property + def memory(self) -> "Memory | None": + return self._memory + + # ── 抽象方法(子类必须实现) ────────────────────────────── + + @abstractmethod + async def handle_task(self, task: TaskMessage) -> dict: + """执行任务的核心业务逻辑,子类必须实现。 + + 返回 output_data dict,框架自动包装为 TaskResult。 + """ + ... + + @abstractmethod + def get_capabilities(self) -> AgentCapability: + """返回 Agent 能力声明""" + ... + + # ── 生命周期钩子(可选覆写) ────────────────────────────── + + async def on_task_start(self, task: TaskMessage) -> None: + """任务开始前的钩子,可用于加载记忆、准备上下文等""" + pass + + async def on_task_complete(self, task: TaskMessage, output: dict) -> None: + """任务成功后的钩子,可用于存储记忆、触发反思等""" + pass + + async def on_task_failed(self, task: TaskMessage, error: Exception) -> None: + """任务失败后的钩子,可用于记录失败模式等""" + pass + + # ── 可插拔能力注入 ────────────────────────────────────── + + def use_tool(self, tool: "Tool") -> "BaseAgent": + """添加工具到 Agent""" + self._tools.append(tool) + return self + + def use_memory(self, memory: "Memory") -> "BaseAgent": + """设置记忆系统""" + self._memory = memory + return self + + def set_registry(self, registry: Any) -> "BaseAgent": + """注入注册中心""" + self._registry = registry + return self + + def set_dispatcher(self, dispatcher: Any) -> "BaseAgent": + """注入任务分发器""" + self._dispatcher = dispatcher + return self + + # ── 核心生命周期 ────────────────────────────────────────── + + async def start(self, redis_url: str = ""): + """启动 Agent:连接 Redis → 注册 → 心跳 → 监听""" + self._redis_url = redis_url + + logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})") + + if redis_url: + try: + self._redis = aioredis.from_url(redis_url, decode_responses=True) + await self._redis.ping() + logger.info(f"Agent '{self.name}' connected to Redis") + except Exception as e: + self._redis = None + logger.warning(f"Agent '{self.name}' Redis unavailable: {e}, falling back to local mode") + + # 注册到 Registry + if self._registry is not None: + capability = self.get_capabilities() + await self._registry.register(capability, endpoint=f"agent:{self.name}") + + self._status = AgentStatus.ONLINE + + # 设置并发控制 + capability = self.get_capabilities() + max_concurrency = getattr(capability, 'max_concurrency', 1) or 1 + self._semaphore = asyncio.Semaphore(max_concurrency) + + # 启动心跳和监听 + if self._redis is not None: + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + self._listen_task = asyncio.create_task(self._listen_for_tasks()) + + logger.info(f"Agent '{self.name}' started ({'distributed' if self._redis else 'local'} mode)") + + async def stop(self): + """停止 Agent""" + logger.info(f"Stopping agent '{self.name}'") + self._status = AgentStatus.OFFLINE + + for task in [self._listen_task, self._heartbeat_task]: + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + if self._redis is not None: + if self._registry is not None: + await self._registry.unregister(self.name) + await self._redis.close() + self._redis = None + + logger.info(f"Agent '{self.name}' stopped") + + # ── execute 为 final 方法 ───────────────────────────────── + + async def execute(self, task: TaskMessage) -> TaskResult: + """执行任务(框架方法,不可覆写)。 + + 完整流程:on_task_start → handle_task → on_task_complete/on_task_failed + 自动处理计时、TaskResult 构建、错误捕获。 + """ + started_at = datetime.now(timezone.utc) + start_time = time.monotonic() + + try: + # 前置钩子 + await self.on_task_start(task) + + # Schema 校验(如果 Agent 声明了 input_schema) + capability = self.get_capabilities() + if capability.input_schema: + self._validate_input(task.input_data, capability.input_schema) + + # 执行业务逻辑 + output = await self.handle_task(task) + + # 后置钩子 + await self.on_task_complete(task, output) + + elapsed = time.monotonic() - start_time + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.COMPLETED, + output_data=output, + error_message=None, + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + except Exception as e: + logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}") + + # 失败钩子 + try: + await self.on_task_failed(task, e) + except Exception as hook_err: + logger.error(f"on_task_failed hook error: {hook_err}") + + elapsed = time.monotonic() - start_time + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.FAILED, + output_data=None, + error_message=str(e), + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + "error_type": type(e).__name__, + }, + ) + + # ── Handoff ─────────────────────────────────────────────── + + async def handoff(self, target_agent: str, task: TaskMessage, reason: str, context: dict[str, Any] | None = None): + """将当前任务转交给另一个 Agent""" + if self._redis is None: + raise RuntimeError("Handoff requires Redis connection") + + handoff_msg = HandoffMessage( + source_agent=self.name, + target_agent=target_agent, + task_id=task.task_id, + task_type=task.task_type, + context=context or task.input_data, + reason=reason, + ) + + # 发布到目标 Agent 的 handoff 频道 + await self._redis.publish( + f"agent:{target_agent}:handoff", + json.dumps(handoff_msg.to_dict()), + ) + + logger.info(f"Agent '{self.name}' handed off task {task.task_id} to '{target_agent}': {reason}") + + # ── 进度上报 ────────────────────────────────────────────── + + async def report_progress(self, task_id: str, progress: float, message: str): + progress_obj = TaskProgress( + task_id=task_id, + agent_name=self.name, + progress=progress, + message=message, + updated_at=datetime.now(timezone.utc), + ) + + if self._redis: + try: + await self._redis.publish( + f"agent:{self.name}:progress", + json.dumps(progress_obj.to_dict()), + ) + except Exception as e: + logger.warning(f"Failed to publish progress for task {task_id}: {e}") + + if self._dispatcher is not None: + try: + await self._dispatcher.handle_progress(progress_obj) + except Exception as e: + logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}") + + # ── 内部方法 ────────────────────────────────────────────── + + async def heartbeat(self): + if self._registry is not None: + await self._registry.update_heartbeat(self.name) + + async def _heartbeat_loop(self): + try: + while self._status == AgentStatus.ONLINE: + await self.heartbeat() + await asyncio.sleep(30) + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Heartbeat error for agent '{self.name}': {e}") + + async def _listen_for_tasks(self): + try: + queue_key = f"agent:{self.name}:tasks" + while self._status == AgentStatus.ONLINE: + if not self._redis: + await asyncio.sleep(1) + continue + + result = await self._redis.brpop(queue_key, timeout=1) + if result: + _, task_json = result + try: + task_data = json.loads(task_json) + task = TaskMessage.from_dict(task_data) + asyncio.create_task(self._execute_task_with_semaphore(task)) + except Exception as e: + logger.error(f"Failed to parse task message: {e}") + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Task listener error for agent '{self.name}': {e}") + + async def _execute_task_with_semaphore(self, task: TaskMessage): + if self._semaphore is None: + await self._execute_task(task) + return + async with self._semaphore: + await self._execute_task(task) + + async def _execute_task(self, task: TaskMessage): + self._running_tasks.add(task.task_id) + self._status = AgentStatus.BUSY + + try: + logger.info(f"Agent '{self.name}' executing task {task.task_id} (type={task.task_type})") + result = await self.execute(task) + + if self._redis is not None and self._dispatcher is not None: + await self._dispatcher.handle_result(result) + + except Exception as e: + logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}") + error_result = TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.FAILED, + output_data=None, + error_message=str(e), + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + metrics=None, + ) + if self._redis is not None and self._dispatcher is not None: + await self._dispatcher.handle_result(error_result) + + finally: + self._running_tasks.discard(task.task_id) + if not self._running_tasks: + self._status = AgentStatus.ONLINE + + def _validate_input(self, data: dict, schema: dict) -> None: + """校验输入数据是否符合 JSON Schema""" + try: + import jsonschema + jsonschema.validate(data, schema) + except ImportError: + logger.warning("jsonschema not installed, skipping input validation") + except Exception as e: + raise SchemaValidationError(self.name, str(e)) diff --git a/src/agentkit/core/dispatcher.py b/src/agentkit/core/dispatcher.py new file mode 100644 index 0000000..f96a5d0 --- /dev/null +++ b/src/agentkit/core/dispatcher.py @@ -0,0 +1,361 @@ +"""任务分发器 - 通过 Redis Queue 将任务分发给 Agent + +与业务系统解耦:通过依赖注入获取 Redis 连接和数据库会话。 +""" + +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import Any, Callable, Awaitable + +from agentkit.core.exceptions import ( + NoAvailableAgentError, + TaskDispatchError, + TaskNotFoundError, +) +from agentkit.core.protocol import ( + AgentStatus, + TaskMessage, + TaskProgress, + TaskResult, + TaskStatus, +) + +logger = logging.getLogger(__name__) + + +class TaskDispatcher: + """任务分发器,通过 Redis Queue 将任务分发给 Agent""" + + def __init__( + self, + redis_factory: Callable[[], Awaitable[Any]], + session_factory: Callable[[], Any], + agent_model: Any, + task_model: Any, + task_log_model: Any, + ): + """ + Args: + redis_factory: 返回 Redis 连接的异步工厂 + session_factory: 返回 async context manager 的工厂,用于获取数据库会话 + agent_model: Agent ORM 模型类 + task_model: AgentTask ORM 模型类 + task_log_model: AgentTaskLog ORM 模型类 + """ + self._redis_factory = redis_factory + self._session_factory = session_factory + self._agent_model = agent_model + self._task_model = task_model + self._task_log_model = task_log_model + + async def _get_redis(self): + return await self._redis_factory() + + async def dispatch( + self, + task: TaskMessage, + organization_id: str | None = None, + created_by: str | None = None, + ) -> str: + """分发任务到对应 Agent 的队列,返回 task_id""" + async with self._session_factory() as db: + try: + from sqlalchemy import select + + AgentModel = self._agent_model + TaskModel = self._task_model + + # 查找目标 Agent + stmt = select(AgentModel).where(AgentModel.name == task.agent_name) + result = await db.execute(stmt) + agent = result.scalar_one_or_none() + + if not agent: + raise TaskDispatchError(task.task_id, f"Agent '{task.agent_name}' not found") + + if agent.status != AgentStatus.ONLINE: + raise TaskDispatchError( + task.task_id, + f"Agent '{task.agent_name}' is not online (status={agent.status})", + ) + + # 写入 agent_tasks 表 + task_id_uuid = uuid.UUID(task.task_id) + agent_task = TaskModel( + id=task_id_uuid, + agent_id=agent.id, + task_type=task.task_type, + status=TaskStatus.PENDING, + priority=task.priority, + input_data=task.input_data, + organization_id=uuid.UUID(organization_id) if organization_id else agent.id, + created_by=uuid.UUID(created_by) if created_by else None, + ) + db.add(agent_task) + await db.commit() + + # 推送到 Redis Queue + redis = await self._get_redis() + queue_key = f"agent:{task.agent_name}:tasks" + await redis.lpush(queue_key, json.dumps(task.to_dict())) + + logger.info( + f"Task {task.task_id} dispatched to agent '{task.agent_name}' " + f"(type={task.task_type}, priority={task.priority})" + ) + return task.task_id + + except TaskDispatchError: + raise + except Exception as e: + await db.rollback() + logger.error(f"Failed to dispatch task {task.task_id}: {e}") + raise TaskDispatchError(task.task_id, str(e)) + + async def cancel_task(self, task_id: str): + """取消任务""" + async with self._session_factory() as db: + try: + from sqlalchemy import select + + TaskModel = self._task_model + task_uuid = uuid.UUID(task_id) + stmt = select(TaskModel).where(TaskModel.id == task_uuid) + result = await db.execute(stmt) + task = result.scalar_one_or_none() + + if not task: + raise TaskNotFoundError(task_id) + + if task.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED): + logger.warning(f"Cannot cancel task {task_id} with status '{task.status}'") + return + + task.status = TaskStatus.CANCELLED + task.completed_at = datetime.now(timezone.utc) + await db.commit() + + await self._write_log( + db, task_id=task_id, agent_id=str(task.agent_id), + log_level="info", message="Task cancelled by user", + ) + await db.commit() + + logger.info(f"Task {task_id} cancelled") + + except TaskNotFoundError: + raise + except Exception as e: + await db.rollback() + logger.error(f"Failed to cancel task {task_id}: {e}") + raise + + async def get_task_status(self, task_id: str) -> dict: + """获取任务状态""" + async with self._session_factory() as db: + from sqlalchemy import select + + TaskModel = self._task_model + task_uuid = uuid.UUID(task_id) + stmt = select(TaskModel).where(TaskModel.id == task_uuid) + result = await db.execute(stmt) + task = result.scalar_one_or_none() + + if not task: + raise TaskNotFoundError(task_id) + + return self._task_to_dict(task) + + async def handle_result(self, result: TaskResult): + """处理 Agent 返回的结果""" + async with self._session_factory() as db: + try: + from sqlalchemy import select + + TaskModel = self._task_model + task_uuid = uuid.UUID(result.task_id) + stmt = select(TaskModel).where(TaskModel.id == task_uuid) + db_result = await db.execute(stmt) + task = db_result.scalar_one_or_none() + + if not task: + logger.error(f"Task {result.task_id} not found when handling result") + return + + task.status = result.status + task.output_data = result.output_data + task.error_message = result.error_message + task.started_at = result.started_at + task.completed_at = result.completed_at + await db.commit() + + log_level = "info" if result.status == TaskStatus.COMPLETED else "error" + log_message = ( + f"Task {result.status}" + if result.status == TaskStatus.COMPLETED + else f"Task failed: {result.error_message}" + ) + await self._write_log( + db, + task_id=result.task_id, + agent_id=str(task.agent_id), + log_level=log_level, + message=log_message, + extra_metadata=result.metrics, + ) + await db.commit() + + # 触发回调 + if result.output_data and result.output_data.get("callback_url"): + await self._trigger_callback(result.output_data["callback_url"], result) + + logger.info(f"Task {result.task_id} result handled (status={result.status})") + + except Exception as e: + await db.rollback() + logger.error(f"Failed to handle result for task {result.task_id}: {e}") + + async def handle_progress(self, progress: TaskProgress): + """处理进度上报""" + async with self._session_factory() as db: + try: + from sqlalchemy import select + + AgentModel = self._agent_model + stmt = select(AgentModel).where(AgentModel.name == progress.agent_name) + result = await db.execute(stmt) + agent = result.scalar_one_or_none() + + if not agent: + logger.warning(f"Agent '{progress.agent_name}' not found for progress report") + return + + await self._write_log( + db, + task_id=progress.task_id, + agent_id=str(agent.id), + log_level="info", + message=f"Progress: {progress.progress:.0%} - {progress.message}", + extra_metadata={ + "progress": progress.progress, + "updated_at": progress.updated_at.isoformat(), + }, + ) + await db.commit() + + except Exception as e: + await db.rollback() + logger.error(f"Failed to handle progress for task {progress.task_id}: {e}") + + async def retry_failed_tasks(self, max_retries: int = 3): + """重试失败的任务""" + async with self._session_factory() as db: + try: + from sqlalchemy import select + + TaskModel = self._task_model + AgentModel = self._agent_model + LogModel = self._task_log_model + + stmt = select(TaskModel).where(TaskModel.status == TaskStatus.FAILED) + result = await db.execute(stmt) + failed_tasks = result.scalars().all() + + retried = 0 + for task in failed_tasks: + log_stmt = select(LogModel).where( + LogModel.task_id == task.id, + LogModel.message.like("%retry%"), + ) + log_result = await db.execute(log_stmt) + retry_count = len(log_result.scalars().all()) + + if retry_count < max_retries: + task.status = TaskStatus.PENDING + task.error_message = None + task.started_at = None + task.completed_at = None + + agent_stmt = select(AgentModel).where(AgentModel.id == task.agent_id) + agent_result = await db.execute(agent_stmt) + agent = agent_result.scalar_one_or_none() + + if agent and agent.status == AgentStatus.ONLINE: + task_msg = TaskMessage( + task_id=str(task.id), + agent_name=agent.name, + task_type=task.task_type, + priority=task.priority, + input_data=task.input_data or {}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + redis = await self._get_redis() + queue_key = f"agent:{agent.name}:tasks" + await redis.lpush(queue_key, json.dumps(task_msg.to_dict())) + + await self._write_log( + db, + task_id=str(task.id), + agent_id=str(agent.id), + log_level="info", + message=f"Task retry attempt {retry_count + 1}/{max_retries}", + ) + retried += 1 + + await db.commit() + if retried > 0: + logger.info(f"Retried {retried} failed tasks") + + except Exception as e: + await db.rollback() + logger.error(f"Failed to retry failed tasks: {e}") + + async def _write_log( + self, + db: Any, + task_id: str, + agent_id: str, + log_level: str, + message: str, + extra_metadata: dict | None = None, + ): + LogModel = self._task_log_model + log_entry = LogModel( + task_id=uuid.UUID(task_id), + agent_id=uuid.UUID(agent_id), + log_level=log_level, + message=message, + extra_metadata=extra_metadata, + ) + db.add(log_entry) + + async def _trigger_callback(self, callback_url: str, result: TaskResult): + try: + import httpx + async with httpx.AsyncClient(timeout=10) as client: + await client.post(callback_url, json=result.to_dict()) + logger.info(f"Callback triggered for task {result.task_id}") + except Exception as e: + logger.warning(f"Callback failed for task {result.task_id}: {e}") + + def _task_to_dict(self, task: Any) -> dict: + return { + "id": str(task.id), + "agent_id": str(task.agent_id), + "task_type": task.task_type, + "status": task.status, + "priority": task.priority, + "input_data": task.input_data, + "output_data": task.output_data, + "error_message": task.error_message, + "created_by": str(task.created_by) if task.created_by else None, + "organization_id": str(task.organization_id), + "project_id": str(task.project_id) if task.project_id else None, + "scheduled_at": task.scheduled_at.isoformat() if task.scheduled_at else None, + "started_at": task.started_at.isoformat() if task.started_at else None, + "completed_at": task.completed_at.isoformat() if task.completed_at else None, + "created_at": task.created_at.isoformat() if task.created_at else None, + } diff --git a/src/agentkit/core/exceptions.py b/src/agentkit/core/exceptions.py new file mode 100644 index 0000000..4d417c6 --- /dev/null +++ b/src/agentkit/core/exceptions.py @@ -0,0 +1,110 @@ +"""Agent 框架自定义异常""" + + +class AgentFrameworkError(Exception): + """Agent 框架基础异常""" + + def __init__(self, message: str = "Agent framework error"): + self.message = message + super().__init__(self.message) + + +class AgentNotFoundError(AgentFrameworkError): + def __init__(self, agent_name: str): + self.agent_name = agent_name + super().__init__(f"Agent not found: {agent_name}") + + +class AgentAlreadyRegisteredError(AgentFrameworkError): + def __init__(self, agent_name: str): + self.agent_name = agent_name + super().__init__(f"Agent already registered: {agent_name}") + + +class AgentUnavailableError(AgentFrameworkError): + def __init__(self, agent_name: str, status: str = "offline"): + self.agent_name = agent_name + self.status = status + super().__init__(f"Agent '{agent_name}' is unavailable (status: {status})") + + +class TaskNotFoundError(AgentFrameworkError): + def __init__(self, task_id: str): + self.task_id = task_id + super().__init__(f"Task not found: {task_id}") + + +class TaskDispatchError(AgentFrameworkError): + def __init__(self, task_id: str, reason: str = ""): + self.task_id = task_id + super().__init__(f"Task dispatch failed for {task_id}: {reason}") + + +class TaskExecutionError(AgentFrameworkError): + def __init__(self, task_id: str, agent_name: str, reason: str = ""): + self.task_id = task_id + self.agent_name = agent_name + super().__init__(f"Task {task_id} execution failed on agent '{agent_name}': {reason}") + + +class TaskTimeoutError(AgentFrameworkError): + def __init__(self, task_id: str, timeout_seconds: int): + self.task_id = task_id + self.timeout_seconds = timeout_seconds + super().__init__(f"Task {task_id} timed out after {timeout_seconds}s") + + +class TaskCancelledError(AgentFrameworkError): + def __init__(self, task_id: str): + self.task_id = task_id + super().__init__(f"Task {task_id} was cancelled") + + +class NoAvailableAgentError(AgentFrameworkError): + def __init__(self, task_type: str): + self.task_type = task_type + super().__init__(f"No available agent for task type: {task_type}") + + +class ConfigValidationError(AgentFrameworkError): + def __init__(self, agent_name: str, key: str, reason: str = ""): + self.agent_name = agent_name + self.key = key + super().__init__(f"Config validation failed for agent '{agent_name}' key '{key}': {reason}") + + +class AgentNotReadyError(AgentFrameworkError): + def __init__(self, agent_name: str): + self.agent_name = agent_name + super().__init__(f"Agent '{agent_name}' is not ready") + + +class ToolNotFoundError(AgentFrameworkError): + def __init__(self, tool_name: str): + self.tool_name = tool_name + super().__init__(f"Tool not found: {tool_name}") + + +class ToolExecutionError(AgentFrameworkError): + def __init__(self, tool_name: str, reason: str = ""): + self.tool_name = tool_name + super().__init__(f"Tool '{tool_name}' execution failed: {reason}") + + +class SchemaValidationError(AgentFrameworkError): + def __init__(self, agent_name: str, detail: str = ""): + self.agent_name = agent_name + super().__init__(f"Schema validation failed for agent '{agent_name}': {detail}") + + +class HandoffError(AgentFrameworkError): + def __init__(self, source: str, target: str, reason: str = ""): + self.source = source + self.target = target + super().__init__(f"Handoff from '{source}' to '{target}' failed: {reason}") + + +class EvolutionError(AgentFrameworkError): + def __init__(self, agent_name: str, reason: str = ""): + self.agent_name = agent_name + super().__init__(f"Evolution failed for agent '{agent_name}': {reason}") diff --git a/src/agentkit/core/protocol.py b/src/agentkit/core/protocol.py new file mode 100644 index 0000000..8316e52 --- /dev/null +++ b/src/agentkit/core/protocol.py @@ -0,0 +1,245 @@ +"""Agent 通信协议定义 - 统一消息格式""" + +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any + + +class TaskStatus(str, Enum): + """任务状态枚举""" + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + HANDOFF = "handoff" + + +class AgentStatus(str, Enum): + """Agent 状态枚举""" + ONLINE = "online" + OFFLINE = "offline" + BUSY = "busy" + + +@dataclass +class AgentCapability: + """Agent 能力声明""" + agent_name: str + agent_type: str + version: str + supported_tasks: list[str] + max_concurrency: int + description: str + input_schema: dict[str, Any] | None = None + output_schema: dict[str, Any] | None = None + + def to_dict(self) -> dict: + d = { + "agent_name": self.agent_name, + "agent_type": self.agent_type, + "version": self.version, + "supported_tasks": self.supported_tasks, + "max_concurrency": self.max_concurrency, + "description": self.description, + } + if self.input_schema is not None: + d["input_schema"] = self.input_schema + if self.output_schema is not None: + d["output_schema"] = self.output_schema + return d + + @classmethod + def from_dict(cls, data: dict) -> "AgentCapability": + return cls( + agent_name=data["agent_name"], + agent_type=data["agent_type"], + version=data["version"], + supported_tasks=data["supported_tasks"], + max_concurrency=data["max_concurrency"], + description=data["description"], + input_schema=data.get("input_schema"), + output_schema=data.get("output_schema"), + ) + + +@dataclass +class TaskMessage: + """任务消息 - 从调度器发往 Agent""" + task_id: str + agent_name: str + task_type: str + priority: int + input_data: dict + callback_url: str | None + created_at: datetime + timeout_seconds: int = 300 + conversation_id: str | None = None + + def to_dict(self) -> dict: + return { + "task_id": self.task_id, + "agent_name": self.agent_name, + "task_type": self.task_type, + "priority": self.priority, + "input_data": self.input_data, + "callback_url": self.callback_url, + "created_at": self.created_at.isoformat() if self.created_at else None, + "timeout_seconds": self.timeout_seconds, + "conversation_id": self.conversation_id, + } + + @classmethod + def from_dict(cls, data: dict) -> "TaskMessage": + created_at = data.get("created_at") + if isinstance(created_at, str): + created_at = datetime.fromisoformat(created_at) + return cls( + task_id=data["task_id"], + agent_name=data["agent_name"], + task_type=data["task_type"], + priority=data.get("priority", 0), + input_data=data.get("input_data", {}), + callback_url=data.get("callback_url"), + created_at=created_at or datetime.utcnow(), + timeout_seconds=data.get("timeout_seconds", 300), + conversation_id=data.get("conversation_id"), + ) + + +@dataclass +class TaskResult: + """任务结果 - 从 Agent 返回""" + task_id: str + agent_name: str + status: str + output_data: dict | None + error_message: str | None + started_at: datetime + completed_at: datetime + metrics: dict | None = None + + def to_dict(self) -> dict: + return { + "task_id": self.task_id, + "agent_name": self.agent_name, + "status": self.status, + "output_data": self.output_data, + "error_message": self.error_message, + "started_at": self.started_at.isoformat() if self.started_at else None, + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + "metrics": self.metrics, + } + + @classmethod + def from_dict(cls, data: dict) -> "TaskResult": + started_at = data.get("started_at") + if isinstance(started_at, str): + started_at = datetime.fromisoformat(started_at) + completed_at = data.get("completed_at") + if isinstance(completed_at, str): + completed_at = datetime.fromisoformat(completed_at) + return cls( + task_id=data["task_id"], + agent_name=data["agent_name"], + status=data["status"], + output_data=data.get("output_data"), + error_message=data.get("error_message"), + started_at=started_at or datetime.utcnow(), + completed_at=completed_at or datetime.utcnow(), + metrics=data.get("metrics"), + ) + + +@dataclass +class TaskProgress: + """进度上报 - Agent 执行过程中上报""" + task_id: str + agent_name: str + progress: float + message: str + updated_at: datetime + + def to_dict(self) -> dict: + return { + "task_id": self.task_id, + "agent_name": self.agent_name, + "progress": self.progress, + "message": self.message, + "updated_at": self.updated_at.isoformat() if self.updated_at else None, + } + + @classmethod + def from_dict(cls, data: dict) -> "TaskProgress": + updated_at = data.get("updated_at") + if isinstance(updated_at, str): + updated_at = datetime.fromisoformat(updated_at) + return cls( + task_id=data["task_id"], + agent_name=data["agent_name"], + progress=data.get("progress", 0.0), + message=data.get("message", ""), + updated_at=updated_at or datetime.utcnow(), + ) + + +@dataclass +class HandoffMessage: + """任务转交消息 - Agent 间 Handoff""" + source_agent: str + target_agent: str + task_id: str + task_type: str + context: dict[str, Any] + reason: str + created_at: datetime = field(default_factory=lambda: datetime.utcnow()) + + def to_dict(self) -> dict: + return { + "source_agent": self.source_agent, + "target_agent": self.target_agent, + "task_id": self.task_id, + "task_type": self.task_type, + "context": self.context, + "reason": self.reason, + "created_at": self.created_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: dict) -> "HandoffMessage": + created_at = data.get("created_at") + if isinstance(created_at, str): + created_at = datetime.fromisoformat(created_at) + return cls( + source_agent=data["source_agent"], + target_agent=data["target_agent"], + task_id=data["task_id"], + task_type=data["task_type"], + context=data.get("context", {}), + reason=data["reason"], + created_at=created_at or datetime.utcnow(), + ) + + +@dataclass +class EvolutionEvent: + """进化事件 - 记录 Agent 的自我进化变更""" + agent_name: str + change_type: str # prompt / strategy / pipeline + before: dict[str, Any] + after: dict[str, Any] + metrics: dict[str, Any] | None = None + event_id: str | None = None + created_at: datetime = field(default_factory=lambda: datetime.utcnow()) + + def to_dict(self) -> dict: + return { + "agent_name": self.agent_name, + "change_type": self.change_type, + "before": self.before, + "after": self.after, + "metrics": self.metrics, + "event_id": self.event_id, + "created_at": self.created_at.isoformat(), + } diff --git a/src/agentkit/core/registry.py b/src/agentkit/core/registry.py new file mode 100644 index 0000000..8bdf5ba --- /dev/null +++ b/src/agentkit/core/registry.py @@ -0,0 +1,250 @@ +"""Agent 注册中心 - 管理 Agent 的注册、发现、状态 + +与业务系统解耦:通过 session_factory 注入数据库会话, +通过 agent_model_factory 注入 ORM 模型。 +""" + +import logging +from datetime import datetime, timedelta, timezone +from typing import Any, Callable, Awaitable + +from agentkit.core.exceptions import ( + AgentNotFoundError, + AgentUnavailableError, + NoAvailableAgentError, +) +from agentkit.core.protocol import AgentCapability, AgentStatus + +logger = logging.getLogger(__name__) + +HEARTBEAT_TIMEOUT_SECONDS = 90 + + +class AgentRegistry: + """Agent 注册中心,管理 Agent 的注册、发现、状态 + + 使用依赖注入模式,不依赖具体的 ORM 模型或数据库连接。 + """ + + def __init__( + self, + session_factory: Callable[[], Any], + agent_model: Any, + load_balancer: str = "round_robin", + ): + """ + Args: + session_factory: 返回 async context manager 的工厂,用于获取数据库会话 + agent_model: Agent ORM 模型类 + load_balancer: 负载均衡策略 (round_robin / least_tasks / random) + """ + self._session_factory = session_factory + self._agent_model = agent_model + self._load_balancer = load_balancer + self._round_robin_index: dict[str, int] = {} + + async def register(self, capability: AgentCapability, endpoint: str) -> str: + """注册 Agent,返回 agent_id。同名 Agent 已存在则更新。""" + async with self._session_factory() as db: + try: + Model = self._agent_model + stmt = type(db).execute.__self__.__class__ # placeholder + + # 尝试查找已有记录 + from sqlalchemy import select + stmt = select(Model).where(Model.name == capability.agent_name) + result = await db.execute(stmt) + existing = result.scalar_one_or_none() + + if existing: + existing.agent_type = capability.agent_type + existing.version = capability.version + existing.endpoint = endpoint + existing.description = capability.description + existing.capabilities = capability.to_dict() + existing.status = AgentStatus.ONLINE + existing.last_heartbeat = datetime.now(timezone.utc) + await db.commit() + await db.refresh(existing) + agent_id = existing.id + logger.info(f"Agent '{capability.agent_name}' re-registered (id={agent_id})") + else: + agent = Model( + name=capability.agent_name, + display_name=capability.agent_name.replace("_", " ").title(), + agent_type=capability.agent_type, + description=capability.description, + version=capability.version, + endpoint=endpoint, + status=AgentStatus.ONLINE, + capabilities=capability.to_dict(), + last_heartbeat=datetime.now(timezone.utc), + ) + db.add(agent) + await db.commit() + await db.refresh(agent) + agent_id = agent.id + logger.info(f"Agent '{capability.agent_name}' registered (id={agent_id})") + + return str(agent_id) + + except Exception as e: + await db.rollback() + logger.error(f"Failed to register agent '{capability.agent_name}': {e}") + raise + + async def unregister(self, agent_name: str): + """注销 Agent(设置状态为 offline)""" + async with self._session_factory() as db: + try: + from sqlalchemy import select + Model = self._agent_model + stmt = select(Model).where(Model.name == agent_name) + result = await db.execute(stmt) + agent = result.scalar_one_or_none() + + if not agent: + logger.warning(f"Attempted to unregister non-existent agent '{agent_name}'") + return + + agent.status = AgentStatus.OFFLINE + await db.commit() + logger.info(f"Agent '{agent_name}' unregistered") + + except Exception as e: + await db.rollback() + logger.error(f"Failed to unregister agent '{agent_name}': {e}") + raise + + async def update_heartbeat(self, agent_name: str): + """更新心跳时间""" + async with self._session_factory() as db: + try: + from sqlalchemy import update + Model = self._agent_model + stmt = ( + update(Model) + .where(Model.name == agent_name) + .values( + last_heartbeat=datetime.now(timezone.utc), + status=AgentStatus.ONLINE, + ) + ) + await db.execute(stmt) + await db.commit() + except Exception as e: + await db.rollback() + logger.error(f"Failed to update heartbeat for agent '{agent_name}': {e}") + + async def get_agent(self, agent_name: str) -> dict | None: + """获取 Agent 信息""" + async with self._session_factory() as db: + from sqlalchemy import select + Model = self._agent_model + stmt = select(Model).where(Model.name == agent_name) + result = await db.execute(stmt) + agent = result.scalar_one_or_none() + + if not agent: + return None + + return self._agent_to_dict(agent) + + async def list_agents( + self, + agent_type: str | None = None, + status: str | None = None, + ) -> list[dict]: + """列出 Agent,支持按类型和状态筛选""" + async with self._session_factory() as db: + from sqlalchemy import select + Model = self._agent_model + stmt = select(Model) + if agent_type: + stmt = stmt.where(Model.agent_type == agent_type) + if status: + stmt = stmt.where(Model.status == status) + stmt = stmt.order_by(Model.created_at.desc()) + + result = await db.execute(stmt) + agents = result.scalars().all() + + return [self._agent_to_dict(a) for a in agents] + + async def get_available_agent(self, task_type: str) -> str | None: + """根据任务类型找到可用 Agent(支持负载均衡)""" + async with self._session_factory() as db: + from sqlalchemy import select + Model = self._agent_model + stmt = select(Model).where(Model.status == AgentStatus.ONLINE) + result = await db.execute(stmt) + agents = result.scalars().all() + + candidates = [] + for agent in agents: + capabilities = agent.capabilities or {} + supported_tasks = capabilities.get("supported_tasks", []) + if task_type in supported_tasks: + candidates.append(agent) + + if not candidates: + return None + + # 负载均衡选择 + if self._load_balancer == "round_robin": + idx = self._round_robin_index.get(task_type, 0) + selected = candidates[idx % len(candidates)] + self._round_robin_index[task_type] = idx + 1 + return selected.name + elif self._load_balancer == "random": + import random + return random.choice(candidates).name + else: # least_tasks 或默认:返回第一个 + return candidates[0].name + + async def check_health(self): + """检查所有 Agent 健康状态,超时标记为 offline""" + async with self._session_factory() as db: + try: + from sqlalchemy import update + Model = self._agent_model + timeout_threshold = datetime.now(timezone.utc) - timedelta( + seconds=HEARTBEAT_TIMEOUT_SECONDS + ) + + stmt = ( + update(Model) + .where( + Model.status == AgentStatus.ONLINE, + Model.last_heartbeat < timeout_threshold, + ) + .values(status=AgentStatus.OFFLINE) + ) + result = await db.execute(stmt) + await db.commit() + + if result.rowcount > 0: + logger.warning( + f"Marked {result.rowcount} agent(s) as offline due to heartbeat timeout" + ) + + except Exception as e: + await db.rollback() + logger.error(f"Failed to check agent health: {e}") + + def _agent_to_dict(self, agent: Any) -> dict: + """将 Agent ORM 对象转换为字典""" + return { + "id": str(agent.id), + "name": agent.name, + "display_name": agent.display_name, + "agent_type": agent.agent_type, + "description": agent.description, + "version": agent.version, + "endpoint": agent.endpoint, + "status": agent.status, + "capabilities": agent.capabilities, + "last_heartbeat": agent.last_heartbeat.isoformat() if agent.last_heartbeat else None, + "created_at": agent.created_at.isoformat() if agent.created_at else None, + "updated_at": agent.updated_at.isoformat() if agent.updated_at else None, + } diff --git a/src/agentkit/evolution/__init__.py b/src/agentkit/evolution/__init__.py new file mode 100644 index 0000000..e1f3d15 --- /dev/null +++ b/src/agentkit/evolution/__init__.py @@ -0,0 +1,17 @@ +"""AgentKit Evolution - 自我进化引擎""" + +from agentkit.evolution.reflector import Reflector +from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Module +from agentkit.evolution.strategy_tuner import StrategyTuner +from agentkit.evolution.ab_tester import ABTester +from agentkit.evolution.evolution_store import EvolutionStore + +__all__ = [ + "Reflector", + "PromptOptimizer", + "Signature", + "Module", + "StrategyTuner", + "ABTester", + "EvolutionStore", +] diff --git a/src/agentkit/evolution/ab_tester.py b/src/agentkit/evolution/ab_tester.py new file mode 100644 index 0000000..7616fe3 --- /dev/null +++ b/src/agentkit/evolution/ab_tester.py @@ -0,0 +1,121 @@ +"""ABTester - A/B 测试框架 + +支持配置分流比例,自动收集效果指标,统计显著性检验。 +""" + +import logging +import math +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class ABTestConfig: + """A/B 测试配置""" + test_id: str + agent_name: str + change_type: str # prompt / strategy / pipeline + control_ratio: float = 0.8 # 对照组比例 + min_samples: int = 30 # 最小样本量 + confidence_level: float = 0.95 # 置信度 + status: str = "running" # running / completed / rolled_back + + +@dataclass +class ABTestResult: + """A/B 测试结果""" + test_id: str + control_metric: float + experiment_metric: float + control_samples: int + experiment_samples: int + is_significant: bool + winner: str | None # control / experiment / None + p_value: float | None = None + + +class ABTester: + """A/B 测试框架""" + + def __init__(self): + self._tests: dict[str, ABTestConfig] = {} + self._results: dict[str, list[tuple[str, float]]] = {} # test_id -> [(group, metric)] + + def create_test(self, config: ABTestConfig) -> None: + """创建 A/B 测试""" + self._tests[config.test_id] = config + self._results[config.test_id] = [] + logger.info(f"A/B test '{config.test_id}' created for agent '{config.agent_name}'") + + def assign_group(self, test_id: str) -> str: + """分配测试组""" + import random + config = self._tests.get(test_id) + if not config: + return "control" + + return "control" if random.random() < config.control_ratio else "experiment" + + def record_result(self, test_id: str, group: str, metric: float) -> None: + """记录测试结果""" + if test_id not in self._results: + self._results[test_id] = [] + self._results[test_id].append((group, metric)) + + async def evaluate(self, test_id: str) -> ABTestResult | None: + """评估 A/B 测试结果""" + config = self._tests.get(test_id) + if not config: + return None + + results = self._results.get(test_id, []) + control_metrics = [m for g, m in results if g == "control"] + experiment_metrics = [m for g, m in results if g == "experiment"] + + if len(control_metrics) < config.min_samples or len(experiment_metrics) < config.min_samples: + return ABTestResult( + test_id=test_id, + control_metric=sum(control_metrics) / len(control_metrics) if control_metrics else 0, + experiment_metric=sum(experiment_metrics) / len(experiment_metrics) if experiment_metrics else 0, + control_samples=len(control_metrics), + experiment_samples=len(experiment_metrics), + is_significant=False, + winner=None, + ) + + # 简单 t-test + control_mean = sum(control_metrics) / len(control_metrics) + experiment_mean = sum(experiment_metrics) / len(experiment_metrics) + + control_var = sum((m - control_mean) ** 2 for m in control_metrics) / (len(control_metrics) - 1) + experiment_var = sum((m - experiment_mean) ** 2 for m in experiment_metrics) / (len(experiment_metrics) - 1) + + pooled_se = math.sqrt(control_var / len(control_metrics) + experiment_var / len(experiment_metrics)) + t_stat = (experiment_mean - control_mean) / pooled_se if pooled_se > 0 else 0 + + # 近似 p-value (双侧) + p_value = 2 * (1 - self._normal_cdf(abs(t_stat))) + is_significant = p_value < (1 - config.confidence_level) + + winner = None + if is_significant: + winner = "experiment" if experiment_mean > control_mean else "control" + + return ABTestResult( + test_id=test_id, + control_metric=control_mean, + experiment_metric=experiment_mean, + control_samples=len(control_metrics), + experiment_samples=len(experiment_metrics), + is_significant=is_significant, + winner=winner, + p_value=p_value, + ) + + @staticmethod + def _normal_cdf(x: float) -> float: + """标准正态分布 CDF 近似""" + return 0.5 * (1 + math.erf(x / math.sqrt(2))) diff --git a/src/agentkit/evolution/evolution_store.py b/src/agentkit/evolution/evolution_store.py new file mode 100644 index 0000000..74ce22f --- /dev/null +++ b/src/agentkit/evolution/evolution_store.py @@ -0,0 +1,113 @@ +"""EvolutionStore - 进化日志存储""" + +import logging +from datetime import datetime +from typing import Any + +from agentkit.core.protocol import EvolutionEvent + +logger = logging.getLogger(__name__) + + +class EvolutionStore: + """进化日志存储 + + 记录 Agent 的自我进化变更,支持回滚。 + """ + + def __init__(self, session_factory: Any, evolution_model: Any): + self._session_factory = session_factory + self._evolution_model = evolution_model + + async def record(self, event: EvolutionEvent) -> str: + """记录进化事件""" + async with self._session_factory() as db: + try: + import uuid + Model = self._evolution_model + entry = Model( + id=uuid.uuid4(), + agent_name=event.agent_name, + change_type=event.change_type, + before=event.before, + after=event.after, + metrics=event.metrics, + status="active", + ) + db.add(entry) + await db.commit() + await db.refresh(entry) + event_id = str(entry.id) + event.event_id = event_id + logger.info(f"Evolution event recorded: {event_id} for agent '{event.agent_name}'") + return event_id + except Exception as e: + await db.rollback() + logger.error(f"Failed to record evolution event: {e}") + raise + + async def rollback(self, event_id: str) -> bool: + """回滚进化事件""" + async with self._session_factory() as db: + try: + import uuid + from sqlalchemy import select + Model = self._evolution_model + + stmt = select(Model).where(Model.id == uuid.UUID(event_id)) + result = await db.execute(stmt) + entry = result.scalar_one_or_none() + + if not entry: + logger.error(f"Evolution event {event_id} not found") + return False + + entry.status = "rolled_back" + await db.commit() + logger.info(f"Evolution event {event_id} rolled back") + return True + except Exception as e: + await db.rollback() + logger.error(f"Failed to rollback evolution event {event_id}: {e}") + return False + + async def list_events( + self, + agent_name: str | None = None, + change_type: str | None = None, + status: str | None = None, + ) -> list[dict]: + """列出进化事件""" + async with self._session_factory() as db: + try: + from sqlalchemy import select + Model = self._evolution_model + stmt = select(Model) + + if agent_name: + stmt = stmt.where(Model.agent_name == agent_name) + if change_type: + stmt = stmt.where(Model.change_type == change_type) + if status: + stmt = stmt.where(Model.status == status) + + stmt = stmt.order_by(Model.created_at.desc()) + result = await db.execute(stmt) + entries = result.scalars().all() + + return [ + { + "id": str(e.id), + "agent_name": e.agent_name, + "change_type": e.change_type, + "before": e.before, + "after": e.after, + "metrics": e.metrics, + "status": e.status, + "created_at": e.created_at.isoformat() if e.created_at else None, + } + for e in entries + ] + except Exception as e: + logger.error(f"Failed to list evolution events: {e}") + return [] diff --git a/src/agentkit/evolution/prompt_optimizer.py b/src/agentkit/evolution/prompt_optimizer.py new file mode 100644 index 0000000..baf04f7 --- /dev/null +++ b/src/agentkit/evolution/prompt_optimizer.py @@ -0,0 +1,151 @@ +"""PromptOptimizer - DSPy 风格的 Prompt 自动优化器 + +核心概念: +- Signature: 定义输入/输出 schema +- Module: 可组合的 Prompt 策略 +- Optimizer: 从任务结果中自动优化 Prompt +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class Signature: + """Prompt 签名 - 定义输入/输出字段""" + input_fields: dict[str, str] # name -> description + output_fields: dict[str, str] # name -> description + instruction: str = "" + + def to_prompt_prefix(self) -> str: + parts = [] + if self.instruction: + parts.append(self.instruction) + parts.append("Inputs:") + for name, desc in self.input_fields.items(): + parts.append(f" - {name}: {desc}") + parts.append("Outputs:") + for name, desc in self.output_fields.items(): + parts.append(f" - {name}: {desc}") + return "\n".join(parts) + + +@dataclass +class Module: + """可组合的 Prompt 策略模块""" + name: str + signature: Signature + template: str = "" + demos: list[dict[str, Any]] = field(default_factory=list) + + def render(self, **kwargs) -> str: + parts = [] + parts.append(self.signature.to_prompt_prefix()) + if self.demos: + parts.append("\nExamples:") + for demo in self.demos: + parts.append(f"\nInput: {demo.get('input', '')}") + parts.append(f"Output: {demo.get('output', '')}") + if self.template: + parts.append(f"\n{self.template.format(**kwargs)}") + return "\n".join(parts) + + +class PromptOptimizer: + """DSPy 风格的 Prompt 自动优化器 + + 从成功案例中自动构建 few-shot 示例,优化 Prompt 指令。 + """ + + def __init__( + self, + max_demos: int = 5, + min_examples_for_optimization: int = 3, + ): + self._max_demos = max_demos + self._min_examples = min_examples_for_optimization + self._success_examples: list[dict[str, Any]] = [] + self._failure_examples: list[dict[str, Any]] = [] + + def add_example( + self, + input_data: dict, + output_data: dict, + quality_score: float, + ) -> None: + """添加训练样本""" + example = { + "input": input_data, + "output": output_data, + "quality_score": quality_score, + } + if quality_score >= 0.7: + self._success_examples.append(example) + else: + self._failure_examples.append(example) + + async def optimize(self, module: Module) -> Module: + """优化 Module 的 Prompt + + BootstrapFewShot: 从成功案例中自动构建 few-shot 示例 + """ + if len(self._success_examples) < self._min_examples: + logger.info( + f"Not enough examples for optimization " + f"({len(self._success_examples)}/{self._min_examples})" + ) + return module + + # 选择质量最高的成功案例作为 demo + sorted_examples = sorted( + self._success_examples, + key=lambda x: x["quality_score"], + reverse=True, + ) + best_demos = sorted_examples[:self._max_demos] + + # 构建 few-shot 示例 + demos = [] + for example in best_demos: + demos.append({ + "input": str(example["input"]), + "output": str(example["output"]), + }) + + # 优化指令(基于失败案例的反面教材) + optimized_instruction = module.signature.instruction + if self._failure_examples: + failure_patterns = set() + for ex in self._failure_examples[-3:]: + failure_patterns.add(str(ex["input"])[:100]) + if failure_patterns: + optimized_instruction += ( + f"\n\nAvoid these patterns:\n" + + "\n".join(f"- {p}" for p in failure_patterns) + ) + + # 创建优化后的 Module + optimized = Module( + name=f"{module.name}_optimized", + signature=Signature( + input_fields=module.signature.input_fields, + output_fields=module.signature.output_fields, + instruction=optimized_instruction, + ), + template=module.template, + demos=demos, + ) + + logger.info( + f"Optimized module '{module.name}': " + f"{len(demos)} demos, instruction length {len(optimized_instruction)}" + ) + + return optimized + + @property + def example_count(self) -> tuple[int, int]: + return len(self._success_examples), len(self._failure_examples) diff --git a/src/agentkit/evolution/reflector.py b/src/agentkit/evolution/reflector.py new file mode 100644 index 0000000..df03062 --- /dev/null +++ b/src/agentkit/evolution/reflector.py @@ -0,0 +1,147 @@ +"""Reflector - 执行反思 + +每次任务完成后自动评估结果,提取模式,生成反思总结。 +""" + +import logging +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus + +logger = logging.getLogger(__name__) + + +@dataclass +class Reflection: + """反思结果""" + task_id: str + agent_name: str + outcome: str # success / failure / partial + quality_score: float # 0.0 - 1.0 + patterns: list[str] = field(default_factory=list) + insights: list[str] = field(default_factory=list) + suggestions: list[str] = field(default_factory=list) + created_at: datetime = field(default_factory=lambda: datetime.utcnow()) + + +class Reflector: + """执行反思器 + + 评估任务结果,提取成功/失败模式,生成改进建议。 + """ + + def __init__( + self, + quality_scorer: Any | None = None, + pattern_extractor: Any | None = None, + ): + self._quality_scorer = quality_scorer + self._pattern_extractor = pattern_extractor + + async def reflect(self, task: TaskMessage, result: TaskResult) -> Reflection: + """对任务执行结果进行反思""" + # 判断结果 + outcome = "success" if result.status == TaskStatus.COMPLETED else "failure" + + # 质量评分 + quality_score = await self._score_quality(task, result) + + # 提取模式 + patterns = await self._extract_patterns(task, result, outcome) + + # 生成洞察 + insights = self._generate_insights(outcome, quality_score, patterns) + + # 生成建议 + suggestions = self._generate_suggestions(outcome, quality_score, patterns) + + reflection = Reflection( + task_id=task.task_id, + agent_name=result.agent_name, + outcome=outcome, + quality_score=quality_score, + patterns=patterns, + insights=insights, + suggestions=suggestions, + ) + + logger.info( + f"Reflection for task {task.task_id}: outcome={outcome}, " + f"quality={quality_score:.2f}, patterns={len(patterns)}" + ) + + return reflection + + async def _score_quality(self, task: TaskMessage, result: TaskResult) -> float: + """评估任务质量""" + if result.status != TaskStatus.COMPLETED: + return 0.0 + + if self._quality_scorer: + return await self._quality_scorer(task, result) + + # 默认评分逻辑 + score = 0.5 # 基础分 + + # 有输出数据加分 + if result.output_data: + score += 0.2 + + # 无错误加分 + if not result.error_message: + score += 0.1 + + # 耗时合理加分 + if result.metrics and result.metrics.get("elapsed_seconds", 0) < 30: + score += 0.2 + + return min(score, 1.0) + + async def _extract_patterns( + self, task: TaskMessage, result: TaskResult, outcome: str + ) -> list[str]: + """提取模式""" + patterns = [] + + if outcome == "failure": + if result.error_message: + patterns.append(f"error_type:{type(result.error_message).__name__}") + if result.metrics and result.metrics.get("elapsed_seconds", 0) > 60: + patterns.append("slow_execution") + else: + if result.output_data and len(result.output_data) > 5: + patterns.append("rich_output") + if result.metrics and result.metrics.get("elapsed_seconds", 0) < 10: + patterns.append("fast_execution") + + return patterns + + def _generate_insights( + self, outcome: str, quality_score: float, patterns: list[str] + ) -> list[str]: + """生成洞察""" + insights = [] + + if quality_score < 0.3: + insights.append("Low quality score indicates potential issues with task execution") + if "slow_execution" in patterns: + insights.append("Slow execution may benefit from strategy optimization") + if "error_type:TimeoutError" in patterns: + insights.append("Timeout errors suggest need for longer timeout or task decomposition") + + return insights + + def _generate_suggestions( + self, outcome: str, quality_score: float, patterns: list[str] + ) -> list[str]: + """生成改进建议""" + suggestions = [] + + if outcome == "failure" and quality_score < 0.3: + suggestions.append("Consider prompt optimization for this task type") + if "slow_execution" in patterns: + suggestions.append("Consider adjusting strategy parameters for faster execution") + + return suggestions diff --git a/src/agentkit/evolution/strategy_tuner.py b/src/agentkit/evolution/strategy_tuner.py new file mode 100644 index 0000000..d446f79 --- /dev/null +++ b/src/agentkit/evolution/strategy_tuner.py @@ -0,0 +1,81 @@ +"""StrategyTuner - 策略调优 + +自动调整 Agent 参数(temperature, tool 选择权重, Pipeline 路径)。 +""" + +import logging +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class StrategyConfig: + """策略配置""" + temperature: float = 0.5 + tool_weights: dict[str, float] = field(default_factory=dict) + max_iterations: int = 5 + timeout_seconds: int = 300 + + +class StrategyTuner: + """策略调优器 + + 基于历史效果数据自动调整 Agent 参数。 + """ + + def __init__(self, param_ranges: dict[str, tuple[float, float]] | None = None): + self._param_ranges = param_ranges or { + "temperature": (0.0, 1.0), + "max_iterations": (1, 10), + } + self._history: list[dict[str, Any]] = [] + + def record(self, config: StrategyConfig, metric: float) -> None: + """记录配置和对应的效果指标""" + self._history.append({ + "config": config, + "metric": metric, + }) + + async def suggest(self, current: StrategyConfig) -> StrategyConfig: + """基于历史数据建议新的策略配置""" + if len(self._history) < 3: + logger.info("Not enough history for strategy tuning") + return current + + # 找到效果最好的配置 + best = max(self._history, key=lambda x: x["metric"]) + best_config = best["config"] + best_metric = best["metric"] + + # 在最佳配置附近微调 + suggested = StrategyConfig( + temperature=self._clamp( + best_config.temperature + self._small_perturbation(), + *self._param_ranges.get("temperature", (0.0, 1.0)), + ), + tool_weights=dict(best_config.tool_weights), + max_iterations=int(self._clamp( + best_config.max_iterations + self._small_perturbation(), + *self._param_ranges.get("max_iterations", (1, 10)), + )), + timeout_seconds=current.timeout_seconds, + ) + + logger.info( + f"Strategy suggestion: temperature {current.temperature:.2f} -> {suggested.temperature:.2f}, " + f"max_iterations {current.max_iterations} -> {suggested.max_iterations}" + ) + + return suggested + + @staticmethod + def _small_perturbation() -> float: + import random + return random.uniform(-0.1, 0.1) + + @staticmethod + def _clamp(value: float, min_val: float, max_val: float) -> float: + return max(min_val, min(max_val, value)) diff --git a/src/agentkit/mcp/__init__.py b/src/agentkit/mcp/__init__.py new file mode 100644 index 0000000..4ea9ba2 --- /dev/null +++ b/src/agentkit/mcp/__init__.py @@ -0,0 +1,6 @@ +"""AgentKit MCP - Model Context Protocol 支持""" + +__all__ = [ + "MCPServer", + "MCPClient", +] diff --git a/src/agentkit/mcp/client.py b/src/agentkit/mcp/client.py new file mode 100644 index 0000000..17b6169 --- /dev/null +++ b/src/agentkit/mcp/client.py @@ -0,0 +1,83 @@ +"""MCP Client - 调用外部 MCP 工具服务器""" + +import logging +from typing import Any + +import httpx + +from agentkit.tools.base import Tool + +logger = logging.getLogger(__name__) + + +class MCPClient: + """MCP Client - 连接外部 MCP Server 并调用工具""" + + def __init__(self, server_url: str, timeout: int = 30): + self._server_url = server_url.rstrip("/") + self._timeout = timeout + self._tools_cache: list[dict] | None = None + + async def list_tools(self) -> list[dict]: + """列出远程 MCP Server 上的工具""" + async with httpx.AsyncClient(timeout=self._timeout) as client: + response = await client.get(f"{self._server_url}/tools/list") + response.raise_for_status() + data = response.json() + self._tools_cache = data.get("tools", []) + return self._tools_cache + + async def call_tool(self, tool_name: str, arguments: dict) -> dict: + """调用远程 MCP 工具""" + async with httpx.AsyncClient(timeout=self._timeout) as client: + response = await client.post( + f"{self._server_url}/tools/call", + json={"name": tool_name, "arguments": arguments}, + ) + response.raise_for_status() + return response.json() + + def as_tool(self, tool_name: str, description: str = "") -> "MCPTool": + """将远程 MCP 工具包装为本地 Tool 对象""" + return MCPTool( + name=tool_name, + description=description, + client=self, + ) + + +class MCPTool(Tool): + """MCP 工具 - 通过 MCP Client 调用远程工具""" + + def __init__( + self, + name: str, + description: str, + client: MCPClient, + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description, + input_schema=input_schema, + output_schema=output_schema, + version=version, + tags=tags or ["mcp"], + ) + self._client = client + + async def execute(self, **kwargs) -> dict: + result = await self._client.call_tool(self.name, kwargs) + # 解析 MCP 响应格式 + if "content" in result: + for item in result["content"]: + if item.get("type") == "text": + import json + try: + return json.loads(item["text"]) + except json.JSONDecodeError: + return {"result": item["text"]} + return result diff --git a/src/agentkit/mcp/server.py b/src/agentkit/mcp/server.py new file mode 100644 index 0000000..502f28c --- /dev/null +++ b/src/agentkit/mcp/server.py @@ -0,0 +1,86 @@ +"""MCP Server - 将 Agent 能力暴露为 MCP 工具 + +基于 FastAPI 实现,支持 Streamable HTTP 传输。 +""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +class MCPServer: + """MCP Server - 暴露 Agent 能力为 MCP 工具 + + 自动将 ToolRegistry 中注册的工具暴露为 MCP 工具端点。 + """ + + def __init__(self, tool_registry: Any = None, host: str = "0.0.0.0", port: int = 8080): + self._tool_registry = tool_registry + self._host = host + self._port = port + self._app = None + + def _create_app(self): + """创建 FastAPI 应用""" + try: + from fastapi import FastAPI + except ImportError: + raise ImportError("MCP Server requires fastapi: pip install fischer-agentkit[mcp]") + + app = FastAPI(title="Fischer AgentKit MCP Server") + + @app.get("/tools/list") + async def list_tools(): + if self._tool_registry is None: + return {"tools": []} + tools = self._tool_registry.list_tools() + return { + "tools": [ + { + "name": t.name, + "description": t.description, + "inputSchema": t.input_schema or {}, + } + for t in tools + ] + } + + @app.post("/tools/call") + async def call_tool(request: dict): + tool_name = request.get("name") + arguments = request.get("arguments", {}) + + if not tool_name or self._tool_registry is None: + return {"error": "Tool not specified or registry not configured"} + + try: + tool = self._tool_registry.get(tool_name) + result = await tool.safe_execute(**arguments) + return {"content": [{"type": "text", "text": str(result)}]} + except Exception as e: + return {"isError": True, "content": [{"type": "text", "text": str(e)}]} + + @app.get("/health") + async def health(): + return {"status": "ok"} + + return app + + async def start(self): + """启动 MCP Server""" + self._app = self._create_app() + + try: + import uvicorn + config = uvicorn.Config(self._app, host=self._host, port=self._port, log_level="info") + server = uvicorn.Server(config) + await server.serve() + except ImportError: + raise ImportError("MCP Server requires uvicorn: pip install uvicorn") + + def get_app(self): + """获取 FastAPI 应用实例(用于测试或嵌入)""" + if self._app is None: + self._app = self._create_app() + return self._app diff --git a/src/agentkit/memory/__init__.py b/src/agentkit/memory/__init__.py new file mode 100644 index 0000000..bc3fcf1 --- /dev/null +++ b/src/agentkit/memory/__init__.py @@ -0,0 +1,16 @@ +"""AgentKit Memory - 记忆系统""" + +from agentkit.memory.base import Memory, MemoryItem +from agentkit.memory.working import WorkingMemory +from agentkit.memory.episodic import EpisodicMemory +from agentkit.memory.semantic import SemanticMemory +from agentkit.memory.retriever import MemoryRetriever + +__all__ = [ + "Memory", + "MemoryItem", + "WorkingMemory", + "EpisodicMemory", + "SemanticMemory", + "MemoryRetriever", +] diff --git a/src/agentkit/memory/base.py b/src/agentkit/memory/base.py new file mode 100644 index 0000000..953ae25 --- /dev/null +++ b/src/agentkit/memory/base.py @@ -0,0 +1,74 @@ +"""Memory 抽象基类 - 统一记忆接口""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + + +@dataclass +class MemoryItem: + """记忆条目""" + key: str + value: Any + metadata: dict[str, Any] = field(default_factory=dict) + score: float = 1.0 + created_at: datetime = field(default_factory=lambda: datetime.utcnow()) + + def to_dict(self) -> dict: + return { + "key": self.key, + "value": self.value, + "metadata": self.metadata, + "score": self.score, + "created_at": self.created_at.isoformat(), + } + + +class Memory(ABC): + """记忆抽象基类 + + 三层记忆系统的统一接口: + - WorkingMemory: 当前任务上下文(Redis, 短生命周期) + - EpisodicMemory: 任务经验(pgvector+PG, 永久) + - SemanticMemory: 知识库(RAG+Graph, 永久) + """ + + @abstractmethod + async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: + """存储记忆""" + ... + + @abstractmethod + async def retrieve(self, key: str) -> MemoryItem | None: + """按 key 精确检索""" + ... + + @abstractmethod + async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: + """语义检索""" + ... + + @abstractmethod + async def delete(self, key: str) -> bool: + """删除记忆""" + ... + + async def store_batch(self, items: list[tuple[str, Any, dict | None]]) -> None: + """批量存储""" + for key, value, metadata in items: + await self.store(key, value, metadata) + + async def get_context(self, query: str, token_budget: int = 3000) -> str: + """获取格式化的上下文字符串(用于注入 Prompt)""" + items = await self.search(query, top_k=10) + context_parts = [] + total_tokens = 0 + for item in items: + text = str(item.value) + estimated_tokens = len(text) // 4 # 粗略估算 + if total_tokens + estimated_tokens > token_budget: + break + context_parts.append(text) + total_tokens += estimated_tokens + return "\n".join(context_parts) diff --git a/src/agentkit/memory/episodic.py b/src/agentkit/memory/episodic.py new file mode 100644 index 0000000..856e927 --- /dev/null +++ b/src/agentkit/memory/episodic.py @@ -0,0 +1,149 @@ +"""Episodic Memory - 基于 pgvector + PostgreSQL 的任务经验记忆""" + +import logging +import math +from datetime import datetime +from typing import Any + +from agentkit.memory.base import Memory, MemoryItem + +logger = logging.getLogger(__name__) + + +class EpisodicMemory(Memory): + """Episodic Memory - 记录每次任务的输入/输出/效果/反思 + + 基于 pgvector + PostgreSQL 实现,支持语义检索和时间衰减。 + 生命周期:永久(可配置衰减)。 + """ + + def __init__( + self, + session_factory: Any, + episodic_model: Any, + embedder: Any | None = None, + decay_rate: float = 0.01, + ): + """ + Args: + session_factory: 返回 async context manager 的工厂 + episodic_model: EpisodicMemory ORM 模型类 + embedder: 嵌入器,用于生成向量 + decay_rate: 时间衰减率(越大衰减越快) + """ + self._session_factory = session_factory + self._episodic_model = episodic_model + self._embedder = embedder + self._decay_rate = decay_rate + + async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: + """存储任务经验""" + async with self._session_factory() as db: + try: + Model = self._episodic_model + meta = metadata or {} + + # 生成 embedding + embedding = None + if self._embedder: + text = f"{key} {value}" + embedding = await self._embedder.embed(text) + + entry = Model( + agent_name=meta.get("agent_name", ""), + task_type=meta.get("task_type", ""), + input_summary=str(value)[:500] if value else "", + output_summary=meta.get("output_summary", ""), + outcome=meta.get("outcome", "success"), + quality_score=meta.get("quality_score", 0.5), + reflection=meta.get("reflection", ""), + embedding=embedding, + ) + db.add(entry) + await db.commit() + except Exception as e: + await db.rollback() + logger.error(f"Failed to store episodic memory: {e}") + raise + + async def retrieve(self, key: str) -> MemoryItem | None: + """按 key 精确检索(Episodic Memory 通常不按 key 检索)""" + return None + + async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: + """语义检索相似历史案例""" + async with self._session_factory() as db: + try: + Model = self._episodic_model + filters = filters or {} + + # 构建查询 + from sqlalchemy import select, text as sql_text + stmt = select(Model) + + if filters.get("agent_name"): + stmt = stmt.where(Model.agent_name == filters["agent_name"]) + if filters.get("task_type"): + stmt = stmt.where(Model.task_type == filters["task_type"]) + if filters.get("outcome"): + stmt = stmt.where(Model.outcome == filters["outcome"]) + + stmt = stmt.order_by(Model.created_at.desc()).limit(top_k * 2) + + result = await db.execute(stmt) + entries = result.scalars().all() + + # 如果有 embedder,进行向量相似度排序 + if self._embedder and entries: + query_embedding = await self._embedder.embed(query) + # TODO: 使用 pgvector 的 cosine distance 排序 + # 目前按时间衰减排序 + + # 时间衰减排序 + items = [] + for entry in entries: + age_hours = (datetime.utcnow() - entry.created_at).total_seconds() / 3600 if entry.created_at else 0 + decay = math.exp(-self._decay_rate * age_hours) + score = (entry.quality_score or 0.5) * decay + + items.append(MemoryItem( + key=str(entry.id), + value={ + "input_summary": entry.input_summary, + "output_summary": entry.output_summary, + "outcome": entry.outcome, + "quality_score": entry.quality_score, + "reflection": entry.reflection, + }, + metadata={ + "agent_name": entry.agent_name, + "task_type": entry.task_type, + "created_at": entry.created_at.isoformat() if entry.created_at else None, + }, + score=score, + created_at=entry.created_at or datetime.utcnow(), + )) + + items.sort(key=lambda x: x.score, reverse=True) + return items[:top_k] + + except Exception as e: + logger.error(f"Failed to search episodic memory: {e}") + return [] + + async def delete(self, key: str) -> bool: + """删除指定经验""" + async with self._session_factory() as db: + try: + from sqlalchemy import select, delete as sql_delete + import uuid + Model = self._episodic_model + + stmt = sql_delete(Model).where(Model.id == uuid.UUID(key)) + await db.execute(stmt) + await db.commit() + return True + except Exception as e: + await db.rollback() + logger.error(f"Failed to delete episodic memory: {e}") + return False diff --git a/src/agentkit/memory/retriever.py b/src/agentkit/memory/retriever.py new file mode 100644 index 0000000..4dc6ec7 --- /dev/null +++ b/src/agentkit/memory/retriever.py @@ -0,0 +1,113 @@ +"""MemoryRetriever - 混合检索器 + +并行查询三层记忆,按权重融合排序。 +""" + +import asyncio +import logging +import math +from datetime import datetime +from typing import Any + +from agentkit.memory.base import Memory, MemoryItem +from agentkit.memory.working import WorkingMemory +from agentkit.memory.episodic import EpisodicMemory +from agentkit.memory.semantic import SemanticMemory + +logger = logging.getLogger(__name__) + + +class MemoryRetriever: + """混合检索器 - 并行查询三层记忆,按权重融合排序 + + 检索策略: + 1. 并行查询 Working/Episodic/Semantic 三层 + 2. 按权重融合排序(默认 Working 0.2, Episodic 0.4, Semantic 0.4) + 3. 时间衰减:越久远的记忆权重越低 + 4. 上下文窗口管理:总 token 不超过预算 + """ + + def __init__( + self, + working_memory: WorkingMemory | None = None, + episodic_memory: EpisodicMemory | None = None, + semantic_memory: SemanticMemory | None = None, + weights: dict[str, float] | None = None, + ): + self._working = working_memory + self._episodic = episodic_memory + self._semantic = semantic_memory + self._weights = weights or { + "working": 0.2, + "episodic": 0.4, + "semantic": 0.4, + } + + async def retrieve( + self, + query: str, + top_k: int = 5, + token_budget: int = 3000, + filters: dict[str, Any] | None = None, + ) -> list[MemoryItem]: + """混合检索三层记忆""" + tasks = [] + layer_names = [] + + if self._working: + tasks.append(self._working.search(query, top_k=top_k, filters=filters)) + layer_names.append("working") + if self._episodic: + tasks.append(self._episodic.search(query, top_k=top_k, filters=filters)) + layer_names.append("episodic") + if self._semantic: + tasks.append(self._semantic.search(query, top_k=top_k, filters=filters)) + layer_names.append("semantic") + + if not tasks: + return [] + + # 并行查询 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 融合排序 + all_items = [] + for layer_name, result in zip(layer_names, results): + if isinstance(result, Exception): + logger.error(f"Memory search failed for {layer_name}: {result}") + continue + weight = self._weights.get(layer_name, 0.3) + for item in result: + item.score *= weight + all_items.append(item) + + # 按分数排序 + all_items.sort(key=lambda x: x.score, reverse=True) + + # Token 预算管理 + selected = [] + total_tokens = 0 + for item in all_items: + text = str(item.value) + estimated_tokens = len(text) // 4 + if total_tokens + estimated_tokens > token_budget: + continue + selected.append(item) + total_tokens += estimated_tokens + if len(selected) >= top_k: + break + + return selected + + async def get_context_string( + self, + query: str, + top_k: int = 5, + token_budget: int = 3000, + ) -> str: + """获取格式化的上下文字符串""" + items = await self.retrieve(query, top_k, token_budget) + parts = [] + for item in items: + parts.append(str(item.value)) + return "\n\n".join(parts) diff --git a/src/agentkit/memory/semantic.py b/src/agentkit/memory/semantic.py new file mode 100644 index 0000000..5378ffd --- /dev/null +++ b/src/agentkit/memory/semantic.py @@ -0,0 +1,94 @@ +"""Semantic Memory - 知识库适配器 + +适配器模式,对接外部 RAG 服务和知识图谱。 +""" + +import logging +from typing import Any + +from agentkit.memory.base import Memory, MemoryItem + +logger = logging.getLogger(__name__) + + +class SemanticMemory(Memory): + """Semantic Memory - 知识库检索 + + 通过适配器对接外部 RAG 服务,不直接依赖具体实现。 + """ + + def __init__( + self, + rag_service: Any = None, + graph_service: Any = None, + knowledge_base_ids: list[str] | None = None, + ): + """ + Args: + rag_service: RAG 检索服务(需提供 search 方法) + graph_service: 知识图谱服务(需提供 query 方法) + knowledge_base_ids: 默认检索的知识库 ID 列表 + """ + self._rag_service = rag_service + self._graph_service = graph_service + self._knowledge_base_ids = knowledge_base_ids or [] + + async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: + """Semantic Memory 通常只读,写入委托给 RAG 服务的 ingest 方法""" + if self._rag_service and hasattr(self._rag_service, 'ingest'): + await self._rag_service.ingest(key, value, metadata) + else: + logger.warning("SemanticMemory.store: no RAG service configured for writing") + + async def retrieve(self, key: str) -> MemoryItem | None: + """按 key 精确检索(Semantic Memory 通常不按 key 检索)""" + return None + + async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: + """语义检索知识库""" + items = [] + + # RAG 检索 + if self._rag_service: + try: + kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids) + results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k) + for r in results: + items.append(MemoryItem( + key=r.get("id", ""), + value=r.get("content", ""), + metadata={ + "source": r.get("source", "rag"), + "score": r.get("score", 0.0), + "document_id": r.get("document_id"), + }, + score=r.get("score", 0.0), + )) + except Exception as e: + logger.error(f"RAG search failed: {e}") + + # 知识图谱检索 + if self._graph_service: + try: + graph_results = await self._graph_service.query(query, depth=2) + for r in graph_results[:top_k]: + items.append(MemoryItem( + key=r.get("id", ""), + value=r.get("content", ""), + metadata={ + "source": "graph", + "entities": r.get("entities", []), + "relations": r.get("relations", []), + }, + score=r.get("score", 0.0), + )) + except Exception as e: + logger.error(f"Graph search failed: {e}") + + items.sort(key=lambda x: x.score, reverse=True) + return items[:top_k] + + async def delete(self, key: str) -> bool: + """Semantic Memory 通常只读""" + logger.warning("SemanticMemory.delete: read-only memory") + return False diff --git a/src/agentkit/memory/working.py b/src/agentkit/memory/working.py new file mode 100644 index 0000000..9401328 --- /dev/null +++ b/src/agentkit/memory/working.py @@ -0,0 +1,97 @@ +"""Working Memory - 基于 Redis 的短期任务记忆""" + +import json +import logging +from datetime import datetime +from typing import Any + +import redis.asyncio as aioredis + +from agentkit.memory.base import Memory, MemoryItem + +logger = logging.getLogger(__name__) + + +class WorkingMemory(Memory): + """Working Memory - 当前任务的上下文和中间状态 + + 基于 Redis 实现,支持自动过期(TTL)。 + 生命周期:单次任务。 + """ + + def __init__( + self, + redis: aioredis.Redis, + key_prefix: str = "agentkit:working", + default_ttl: int = 3600, + ): + self._redis = redis + self._key_prefix = key_prefix + self._default_ttl = default_ttl + + def _make_key(self, key: str) -> str: + return f"{self._key_prefix}:{key}" + + async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None: + redis_key = self._make_key(key) + item = MemoryItem( + key=key, + value=value, + metadata=metadata or {}, + created_at=datetime.utcnow(), + ) + await self._redis.setex( + redis_key, + self._default_ttl, + json.dumps(item.to_dict(), default=str), + ) + + async def retrieve(self, key: str) -> MemoryItem | None: + redis_key = self._make_key(key) + data = await self._redis.get(redis_key) + if data is None: + return None + item_dict = json.loads(data) + return MemoryItem( + key=item_dict["key"], + value=item_dict["value"], + metadata=item_dict.get("metadata", {}), + score=item_dict.get("score", 1.0), + created_at=datetime.fromisoformat(item_dict["created_at"]) if item_dict.get("created_at") else datetime.utcnow(), + ) + + async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]: + """Working Memory 不支持语义检索,按 key 前缀匹配""" + pattern = self._make_key(f"{query}*") + keys = [] + async for key in self._redis.scan_iter(match=pattern, count=top_k): + keys.append(key) + if len(keys) >= top_k: + break + + items = [] + for key in keys: + data = await self._redis.get(key) + if data: + item_dict = json.loads(data) + items.append(MemoryItem( + key=item_dict["key"], + value=item_dict["value"], + metadata=item_dict.get("metadata", {}), + score=1.0, + created_at=datetime.utcnow(), + )) + return items + + async def delete(self, key: str) -> bool: + redis_key = self._make_key(key) + return bool(await self._redis.delete(redis_key)) + + async def clear(self, prefix: str = "") -> int: + """清除指定前缀的所有 Working Memory""" + pattern = self._make_key(f"{prefix}*") + count = 0 + async for key in self._redis.scan_iter(match=pattern): + await self._redis.delete(key) + count += 1 + return count diff --git a/src/agentkit/orchestrator/__init__.py b/src/agentkit/orchestrator/__init__.py new file mode 100644 index 0000000..0907993 --- /dev/null +++ b/src/agentkit/orchestrator/__init__.py @@ -0,0 +1,17 @@ +"""AgentKit Orchestrator - 多 Agent 协同编排""" + +from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus +from agentkit.orchestrator.pipeline_engine import PipelineEngine +from agentkit.orchestrator.pipeline_loader import PipelineLoader +from agentkit.orchestrator.handoff import HandoffManager +from agentkit.orchestrator.dynamic_pipeline import DynamicPipeline + +__all__ = [ + "Pipeline", + "PipelineStage", + "StageStatus", + "PipelineEngine", + "PipelineLoader", + "HandoffManager", + "DynamicPipeline", +] diff --git a/src/agentkit/orchestrator/dynamic_pipeline.py b/src/agentkit/orchestrator/dynamic_pipeline.py new file mode 100644 index 0000000..a6b8e51 --- /dev/null +++ b/src/agentkit/orchestrator/dynamic_pipeline.py @@ -0,0 +1,92 @@ +"""DynamicPipeline - 动态 Pipeline 组合 + +支持运行时根据条件选择子流程、嵌套 Pipeline、循环 Pipeline。 +""" + +import logging +from typing import Any + +from agentkit.orchestrator.pipeline_engine import PipelineEngine +from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageStatus + +logger = logging.getLogger(__name__) + + +class DynamicPipeline: + """动态 Pipeline 组合器""" + + def __init__(self, engine: PipelineEngine, loader: Any = None): + self._engine = engine + self._loader = loader + + async def execute_conditional( + self, + pipelines: dict[str, Pipeline], + condition_key: str, + context: dict[str, Any] | None = None, + ) -> PipelineResult: + """根据条件选择子 Pipeline 执行""" + context = context or {} + condition_value = context.get(condition_key) + + if condition_value not in pipelines: + return PipelineResult( + pipeline_name=f"conditional_{condition_key}", + status=StageStatus.FAILED, + error_message=f"No pipeline for condition '{condition_key}={condition_value}'", + ) + + selected = pipelines[condition_value] + logger.info(f"DynamicPipeline selected '{selected.name}' for {condition_key}={condition_value}") + return await self._engine.execute(selected, context) + + async def execute_nested( + self, + parent: Pipeline, + sub_pipeline_map: dict[str, Pipeline], + context: dict[str, Any] | None = None, + ) -> PipelineResult: + """执行嵌套 Pipeline""" + # 先执行父 Pipeline + parent_result = await self._engine.execute(parent, context) + + # 根据父 Pipeline 结果选择子 Pipeline + for stage_name, stage_result in parent_result.stage_results.items(): + if hasattr(stage_result, 'output_data') and stage_result.output_data: + sub_pipeline_name = stage_result.output_data.get("sub_pipeline") + if sub_pipeline_name and sub_pipeline_name in sub_pipeline_map: + sub = sub_pipeline_map[sub_pipeline_name] + sub_result = await self._engine.execute(sub, parent_result.variables) + parent_result.variables.update(sub_result.variables) + + return parent_result + + async def execute_loop( + self, + pipeline: Pipeline, + max_iterations: int = 5, + exit_condition: str = "done", + context: dict[str, Any] | None = None, + ) -> PipelineResult: + """循环执行 Pipeline 直到条件满足""" + current_context = context or {} + last_result = None + + for i in range(max_iterations): + logger.info(f"DynamicPipeline loop iteration {i + 1}/{max_iterations}") + result = await self._engine.execute(pipeline, current_context) + last_result = result + + # 检查退出条件 + if exit_condition in result.variables and result.variables[exit_condition]: + logger.info(f"DynamicPipeline loop exited at iteration {i + 1}") + break + + # 将结果作为下一轮的输入 + current_context.update(result.variables) + + return last_result or PipelineResult( + pipeline_name=pipeline.name, + status=StageStatus.FAILED, + error_message="Loop completed without meeting exit condition", + ) diff --git a/src/agentkit/orchestrator/handoff.py b/src/agentkit/orchestrator/handoff.py new file mode 100644 index 0000000..cc13631 --- /dev/null +++ b/src/agentkit/orchestrator/handoff.py @@ -0,0 +1,69 @@ +"""HandoffManager - Agent 间任务转交""" + +import asyncio +import json +import logging +from typing import Any + +from agentkit.core.protocol import HandoffMessage + +logger = logging.getLogger(__name__) + + +class HandoffManager: + """Handoff 管理器 + + 通过 Redis Pub/Sub 管理 Agent 间的任务转交。 + """ + + def __init__(self, redis: Any = None, dispatcher: Any = None): + self._redis = redis + self._dispatcher = dispatcher + self._handlers: dict[str, list[Any]] = {} + + def register_handler(self, agent_name: str, handler: Any) -> None: + """注册 Handoff 处理器""" + if agent_name not in self._handlers: + self._handlers[agent_name] = [] + self._handlers[agent_name].append(handler) + + async def send_handoff(self, handoff: HandoffMessage) -> None: + """发送 Handoff 请求""" + if self._redis is None: + raise RuntimeError("Handoff requires Redis connection") + + channel = f"agent:{handoff.target_agent}:handoff" + await self._redis.publish(channel, json.dumps(handoff.to_dict())) + logger.info( + f"Handoff sent: {handoff.source_agent} -> {handoff.target_agent} " + f"(task={handoff.task_id}, reason={handoff.reason})" + ) + + async def listen_for_handoffs(self, agent_name: str) -> None: + """监听发往指定 Agent 的 Handoff 请求""" + if self._redis is None: + return + + channel = f"agent:{agent_name}:handoff" + pubsub = self._redis.pubsub() + await pubsub.subscribe(channel) + + try: + async for message in pubsub.listen(): + if message["type"] == "message": + data = json.loads(message["data"]) + handoff = HandoffMessage.from_dict(data) + await self._handle_handoff(handoff) + except asyncio.CancelledError: + pass + finally: + await pubsub.unsubscribe(channel) + + async def _handle_handoff(self, handoff: HandoffMessage) -> None: + """处理收到的 Handoff""" + handlers = self._handlers.get(handoff.target_agent, []) + for handler in handlers: + try: + await handler(handoff) + except Exception as e: + logger.error(f"Handoff handler error: {e}") diff --git a/src/agentkit/orchestrator/pipeline_engine.py b/src/agentkit/orchestrator/pipeline_engine.py new file mode 100644 index 0000000..26bca97 --- /dev/null +++ b/src/agentkit/orchestrator/pipeline_engine.py @@ -0,0 +1,241 @@ +"""Pipeline Engine - DAG + 并行执行""" + +import asyncio +import logging +from collections import defaultdict +from datetime import datetime, timezone +from typing import Any + +from agentkit.orchestrator.pipeline_schema import ( + Pipeline, + PipelineResult, + PipelineStage, + StageResult, + StageStatus, +) + +logger = logging.getLogger(__name__) + + +class PipelineEngine: + """Pipeline 执行引擎 + + 支持: + - DAG 拓扑排序 + - 同层并行执行(asyncio.gather) + - 变量解析 + - 条件执行 + - 重试 + """ + + def __init__(self, dispatcher: Any = None): + self._dispatcher = dispatcher + + async def execute( + self, + pipeline: Pipeline, + context: dict[str, Any] | None = None, + ) -> PipelineResult: + """执行 Pipeline""" + result = PipelineResult(pipeline_name=pipeline.name) + result.variables = {**pipeline.variables, **(context or {})} + + # 拓扑排序 + 按依赖层级分组 + try: + level_groups = self._topological_group(pipeline.stages) + except ValueError as e: + result.status = StageStatus.FAILED + result.error_message = str(e) + return result + + # 逐层执行 + for level, stages in enumerate(level_groups): + logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)") + + # 并行执行同层 stages + tasks = [] + for stage in stages: + tasks.append(self._execute_stage(stage, result)) + + stage_results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理结果 + for stage, sr in zip(stages, stage_results): + if isinstance(sr, Exception): + sr = StageResult( + stage_name=stage.name, + status=StageStatus.FAILED, + error_message=str(sr), + ) + result.stage_results[stage.name] = sr + + # 收集输出变量 + if sr.output_data and isinstance(sr, dict): + pass + elif hasattr(sr, 'output_data') and sr.output_data: + for output_key in stage.outputs: + if output_key in sr.output_data: + result.variables[output_key] = sr.output_data[output_key] + + # 检查是否需要中止 + if hasattr(sr, 'status') and sr.status == StageStatus.FAILED: + if not stage.continue_on_failure: + result.status = StageStatus.FAILED + result.error_message = f"Stage '{stage.name}' failed" + return result + + result.status = StageStatus.COMPLETED + return result + + async def _execute_stage( + self, + stage: PipelineStage, + pipeline_result: PipelineResult, + ) -> StageResult: + """执行单个 stage""" + started_at = datetime.now(timezone.utc).isoformat() + + # 条件检查 + if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables): + return StageResult( + stage_name=stage.name, + status=StageStatus.SKIPPED, + started_at=started_at, + completed_at=datetime.now(timezone.utc).isoformat(), + ) + + # 解析输入变量 + resolved_inputs = self._resolve_variables(stage.inputs, pipeline_result.variables) + + # 执行 + if self._dispatcher is None: + # Dry-run 模式 + return StageResult( + stage_name=stage.name, + status=StageStatus.COMPLETED, + output_data={"dry_run": True, "inputs": resolved_inputs}, + started_at=started_at, + completed_at=datetime.now(timezone.utc).isoformat(), + ) + + # 通过 Dispatcher 分发任务 + from agentkit.core.protocol import TaskMessage + import uuid + + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name=stage.agent, + task_type=stage.action, + priority=0, + input_data=resolved_inputs, + callback_url=None, + created_at=datetime.now(timezone.utc), + timeout_seconds=stage.timeout_seconds, + ) + + try: + await self._dispatcher.dispatch(task) + + # 等待结果 + for _ in range(stage.timeout_seconds): + status = await self._dispatcher.get_task_status(task.task_id) + if status["status"] in ("completed", "failed", "cancelled"): + return StageResult( + stage_name=stage.name, + status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED, + output_data=status.get("output_data"), + error_message=status.get("error_message"), + started_at=started_at, + completed_at=datetime.now(timezone.utc).isoformat(), + ) + await asyncio.sleep(1) + + return StageResult( + stage_name=stage.name, + status=StageStatus.FAILED, + error_message=f"Timeout after {stage.timeout_seconds}s", + started_at=started_at, + completed_at=datetime.now(timezone.utc).isoformat(), + ) + + except Exception as e: + return StageResult( + stage_name=stage.name, + status=StageStatus.FAILED, + error_message=str(e), + started_at=started_at, + completed_at=datetime.now(timezone.utc).isoformat(), + ) + + @staticmethod + def _topological_group(stages: list[PipelineStage]) -> list[list[PipelineStage]]: + """拓扑排序 + 按依赖层级分组""" + stage_map = {s.name: s for s in stages} + in_degree = defaultdict(int) + dependents = defaultdict(list) + + for s in stages: + if s.name not in in_degree: + in_degree[s.name] = 0 + for dep in s.depends_on: + if dep not in stage_map: + raise ValueError(f"Stage '{s.name}' depends on unknown stage '{dep}'") + in_degree[s.name] += 1 + dependents[dep].append(s.name) + + levels = [] + remaining = set(in_degree.keys()) + + while remaining: + # 找到入度为 0 的节点 + current_level = [name for name in remaining if in_degree[name] == 0] + if not current_level: + raise ValueError("Circular dependency detected in pipeline") + + levels.append([stage_map[name] for name in current_level]) + + for name in current_level: + remaining.remove(name) + for dep in dependents[name]: + in_degree[dep] -= 1 + + return levels + + @staticmethod + def _resolve_variables(template: dict, context: dict) -> dict: + """解析 ${var.path} 变量引用""" + resolved = {} + for key, value in template.items(): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + var_path = value[2:-1] + resolved[key] = PipelineEngine._get_nested(context, var_path) + else: + resolved[key] = value + return resolved + + @staticmethod + def _get_nested(data: dict, path: str) -> Any: + keys = path.split(".") + current = data + for key in keys: + if isinstance(current, dict): + current = current.get(key) + else: + return None + return current + + @staticmethod + def _evaluate_condition(condition: str, variables: dict) -> bool: + """简单条件评估""" + if "==" in condition: + parts = condition.split("==", 1) + left = variables.get(parts[0].strip(), parts[0].strip()) + right = parts[1].strip().strip("'\"") + return str(left) == right + elif "!=" in condition: + parts = condition.split("!=", 1) + left = variables.get(parts[0].strip(), parts[0].strip()) + right = parts[1].strip().strip("'\"") + return str(left) != right + else: + return bool(variables.get(condition)) diff --git a/src/agentkit/orchestrator/pipeline_loader.py b/src/agentkit/orchestrator/pipeline_loader.py new file mode 100644 index 0000000..e22498a --- /dev/null +++ b/src/agentkit/orchestrator/pipeline_loader.py @@ -0,0 +1,72 @@ +"""Pipeline Loader - YAML 加载器""" + +import logging +from pathlib import Path +from typing import Any + +import yaml + +from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage + +logger = logging.getLogger(__name__) + + +class PipelineLoader: + """Pipeline YAML 加载器""" + + def __init__(self, pipelines_dir: str | Path | None = None): + self._pipelines_dir = Path(pipelines_dir) if pipelines_dir else Path("pipelines") + + def load(self, pipeline_name: str) -> Pipeline: + """从 YAML 文件加载 Pipeline""" + yaml_path = self._pipelines_dir / f"{pipeline_name}.yaml" + if not yaml_path.exists(): + yaml_path = self._pipelines_dir / f"{pipeline_name}.yml" + if not yaml_path.exists(): + raise FileNotFoundError(f"Pipeline '{pipeline_name}' not found in {self._pipelines_dir}") + + content = yaml_path.read_text(encoding="utf-8") + return self.load_from_yaml(content, pipeline_name) + + def load_from_yaml(self, yaml_content: str, pipeline_name: str | None = None) -> Pipeline: + """从 YAML 字符串加载 Pipeline""" + data = yaml.safe_load(yaml_content) + + stages = [] + for stage_data in data.get("stages", []): + stages.append(PipelineStage(**stage_data)) + + return Pipeline( + name=data.get("name", pipeline_name or "unnamed"), + version=data.get("version", "1.0.0"), + description=data.get("description", ""), + stages=stages, + variables=data.get("variables", {}), + ) + + @staticmethod + def validate_dag(stages: list[PipelineStage]) -> bool: + """验证 DAG 无环""" + stage_names = {s.name for s in stages} + visited = set() + path = set() + + def dfs(name: str) -> bool: + if name in path: + return False + if name in visited: + return True + + path.add(name) + stage = next((s for s in stages if s.name == name), None) + if stage: + for dep in stage.depends_on: + if dep not in stage_names: + return False + if not dfs(dep): + return False + path.remove(name) + visited.add(name) + return True + + return all(dfs(name) for name in stage_names) diff --git a/src/agentkit/orchestrator/pipeline_schema.py b/src/agentkit/orchestrator/pipeline_schema.py new file mode 100644 index 0000000..bef758b --- /dev/null +++ b/src/agentkit/orchestrator/pipeline_schema.py @@ -0,0 +1,52 @@ +"""Pipeline 数据模型""" + +from enum import Enum +from typing import Any + +from pydantic import BaseModel + + +class StageStatus(str, Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + SKIPPED = "skipped" + + +class PipelineStage(BaseModel): + name: str + agent: str + action: str + depends_on: list[str] = [] + inputs: dict[str, Any] = {} + outputs: list[str] = [] + timeout_seconds: int = 300 + retry_count: int = 0 + continue_on_failure: bool = False + condition: str | None = None + + +class Pipeline(BaseModel): + name: str + version: str + description: str + stages: list[PipelineStage] + variables: dict[str, Any] = {} + + +class StageResult(BaseModel): + stage_name: str + status: StageStatus = StageStatus.PENDING + output_data: dict[str, Any] | None = None + error_message: str | None = None + started_at: str | None = None + completed_at: str | None = None + + +class PipelineResult(BaseModel): + pipeline_name: str + status: StageStatus = StageStatus.PENDING + stage_results: dict[str, StageResult] = {} + variables: dict[str, Any] = {} + error_message: str | None = None diff --git a/src/agentkit/prompts/__init__.py b/src/agentkit/prompts/__init__.py new file mode 100644 index 0000000..82ff215 --- /dev/null +++ b/src/agentkit/prompts/__init__.py @@ -0,0 +1,9 @@ +"""AgentKit Prompts - Prompt 模板系统""" + +from agentkit.prompts.template import PromptTemplate +from agentkit.prompts.section import PromptSection + +__all__ = [ + "PromptTemplate", + "PromptSection", +] diff --git a/src/agentkit/prompts/section.py b/src/agentkit/prompts/section.py new file mode 100644 index 0000000..d4ef8d2 --- /dev/null +++ b/src/agentkit/prompts/section.py @@ -0,0 +1,36 @@ +"""PromptSection - 模块化 Prompt 段落""" + +from dataclasses import dataclass + + +@dataclass +class PromptSection: + """Prompt 段落定义 + + 将 Prompt 分为 5 个标准段落,支持变量注入和 Token 预算管理。 + """ + identity: str = "" + context: str = "" + instructions: str = "" + constraints: str = "" + output_format: str = "" + examples: str = "" + + def render(self, variables: dict | None = None) -> str: + """渲染段落,替换变量""" + text = "\n\n".join( + part for part in [ + self.identity, + self.context, + self.instructions, + self.constraints, + self.output_format, + self.examples, + ] if part + ) + + if variables: + for key, value in variables.items(): + text = text.replace(f"${{{key}}}", str(value)) + + return text diff --git a/src/agentkit/prompts/template.py b/src/agentkit/prompts/template.py new file mode 100644 index 0000000..dea242b --- /dev/null +++ b/src/agentkit/prompts/template.py @@ -0,0 +1,71 @@ +"""PromptTemplate - Prompt 模板渲染""" + +import logging +from typing import Any + +from agentkit.prompts.section import PromptSection + +logger = logging.getLogger(__name__) + + +class PromptTemplate: + """Prompt 模板 + + 支持变量注入、Token 预算管理、动态段落组合。 + """ + + def __init__( + self, + sections: PromptSection, + name: str = "", + version: str = "1.0.0", + ): + self._sections = sections + self.name = name + self.version = version + + def render( + self, + variables: dict[str, Any] | None = None, + context_budget: int = 3000, + ) -> list[dict[str, str]]: + """渲染 Prompt 为消息列表 + + Returns: + [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}] + """ + system_parts = [] + if self._sections.identity: + system_parts.append(self._sections.identity) + if self._sections.context: + context = self._sections.context + if variables: + for key, value in variables.items(): + context = context.replace(f"${{{key}}}", str(value)) + system_parts.append(context) + if self._sections.constraints: + system_parts.append(self._sections.constraints) + + user_parts = [] + if self._sections.instructions: + instructions = self._sections.instructions + if variables: + for key, value in variables.items(): + instructions = instructions.replace(f"${{{key}}}", str(value)) + user_parts.append(instructions) + if self._sections.output_format: + user_parts.append(self._sections.output_format) + if self._sections.examples: + user_parts.append(self._sections.examples) + + messages = [] + if system_parts: + messages.append({"role": "system", "content": "\n\n".join(system_parts)}) + if user_parts: + messages.append({"role": "user", "content": "\n\n".join(user_parts)}) + + return messages + + @property + def sections(self) -> PromptSection: + return self._sections diff --git a/src/agentkit/tools/__init__.py b/src/agentkit/tools/__init__.py new file mode 100644 index 0000000..57ac4ac --- /dev/null +++ b/src/agentkit/tools/__init__.py @@ -0,0 +1,13 @@ +"""AgentKit Tools - 工具插件系统""" + +from agentkit.tools.base import Tool +from agentkit.tools.function_tool import FunctionTool +from agentkit.tools.agent_tool import AgentTool +from agentkit.tools.registry import ToolRegistry + +__all__ = [ + "Tool", + "FunctionTool", + "AgentTool", + "ToolRegistry", +] diff --git a/src/agentkit/tools/agent_tool.py b/src/agentkit/tools/agent_tool.py new file mode 100644 index 0000000..bacf7e0 --- /dev/null +++ b/src/agentkit/tools/agent_tool.py @@ -0,0 +1,93 @@ +"""AgentTool - 将 Agent 包装为 Tool""" + +from typing import Any + +from agentkit.tools.base import Tool + + +class AgentTool(Tool): + """将另一个 Agent 包装为 Tool + + 通过 Dispatcher 分发任务到目标 Agent,等待结果返回。 + """ + + def __init__( + self, + name: str, + description: str, + agent_name: str, + task_type: str, + input_mapping: dict[str, str] | None = None, + output_mapping: dict[str, str] | None = None, + timeout_seconds: int = 300, + version: str = "1.0.0", + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description, + version=version, + tags=tags or ["agent"], + ) + self.agent_name = agent_name + self.task_type = task_type + self.input_mapping = input_mapping or {} + self.output_mapping = output_mapping or {} + self.timeout_seconds = timeout_seconds + self._dispatcher = None + + def set_dispatcher(self, dispatcher: Any) -> "AgentTool": + """注入 Dispatcher""" + self._dispatcher = dispatcher + return self + + async def execute(self, **kwargs) -> dict: + if self._dispatcher is None: + raise RuntimeError(f"AgentTool '{self.name}' has no dispatcher configured") + + from agentkit.core.protocol import TaskMessage + from datetime import datetime, timezone + import uuid + + # 映射输入 + mapped_input = {} + for target_key, source_key in self.input_mapping.items(): + if source_key in kwargs: + mapped_input[target_key] = kwargs[source_key] + if not mapped_input: + mapped_input = kwargs + + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name=self.agent_name, + task_type=self.task_type, + priority=0, + input_data=mapped_input, + callback_url=None, + created_at=datetime.now(timezone.utc), + timeout_seconds=self.timeout_seconds, + ) + + await self._dispatcher.dispatch(task) + + # 等待结果 + import asyncio + for _ in range(self.timeout_seconds): + status = await self._dispatcher.get_task_status(task.task_id) + if status["status"] in ("completed", "failed", "cancelled"): + if status["status"] == "completed" and status.get("output_data"): + output = status["output_data"] + # 映射输出 + if self.output_mapping: + mapped_output = {} + for target_key, source_key in self.output_mapping.items(): + if source_key in output: + mapped_output[target_key] = output[source_key] + return mapped_output + return output + elif status["status"] == "failed": + raise RuntimeError(f"Agent '{self.agent_name}' failed: {status.get('error_message')}") + return {} + await asyncio.sleep(1) + + raise TimeoutError(f"Agent '{self.agent_name}' timed out after {self.timeout_seconds}s") diff --git a/src/agentkit/tools/base.py b/src/agentkit/tools/base.py new file mode 100644 index 0000000..7642644 --- /dev/null +++ b/src/agentkit/tools/base.py @@ -0,0 +1,68 @@ +"""Tool 抽象基类 - 统一工具接口""" + +from abc import ABC, abstractmethod +from typing import Any + + +class Tool(ABC): + """工具抽象基类 + + 所有工具(FunctionTool, AgentTool, MCPTool)的统一接口。 + """ + + def __init__( + self, + name: str, + description: str, + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + ): + self.name = name + self.description = description + self.input_schema = input_schema + self.output_schema = output_schema + self.version = version + self.tags = tags or [] + + @abstractmethod + async def execute(self, **kwargs) -> dict: + """执行工具,返回结果 dict""" + ... + + async def before_execute(self, **kwargs) -> None: + """执行前钩子""" + pass + + async def after_execute(self, result: dict, **kwargs) -> None: + """执行后钩子""" + pass + + async def on_error(self, error: Exception, **kwargs) -> None: + """错误钩子""" + pass + + async def safe_execute(self, **kwargs) -> dict: + """带钩子的安全执行""" + try: + await self.before_execute(**kwargs) + result = await self.execute(**kwargs) + await self.after_execute(result, **kwargs) + return result + except Exception as e: + await self.on_error(e, **kwargs) + raise + + def to_dict(self) -> dict: + return { + "name": self.name, + "description": self.description, + "input_schema": self.input_schema, + "output_schema": self.output_schema, + "version": self.version, + "tags": self.tags, + } + + def __repr__(self) -> str: + return f"<{type(self).__name__} name={self.name!r} version={self.version}>" diff --git a/src/agentkit/tools/function_tool.py b/src/agentkit/tools/function_tool.py new file mode 100644 index 0000000..92570a9 --- /dev/null +++ b/src/agentkit/tools/function_tool.py @@ -0,0 +1,76 @@ +"""FunctionTool - 将普通 Python 函数包装为 Tool""" + +import inspect +from typing import Any, Callable, Awaitable + +from agentkit.tools.base import Tool + + +class FunctionTool(Tool): + """将普通 Python 函数包装为 Tool + + 自动从函数签名推断 input_schema。 + """ + + def __init__( + self, + name: str, + description: str, + func: Callable[..., Awaitable[dict]] | Callable[..., dict], + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description, + input_schema=input_schema or self._infer_schema(func), + output_schema=output_schema, + version=version, + tags=tags, + ) + self._func = func + + async def execute(self, **kwargs) -> dict: + result = self._func(**kwargs) + if inspect.isawaitable(result): + result = await result + if not isinstance(result, dict): + result = {"result": result} + return result + + @staticmethod + def _infer_schema(func: Callable) -> dict: + """从函数签名推断 JSON Schema""" + sig = inspect.signature(func) + properties = {} + required = [] + + for param_name, param in sig.parameters.items(): + if param_name in ("self", "cls"): + continue + + param_type = "string" + if param.annotation != inspect.Parameter.empty: + if param.annotation in (int, float): + param_type = "number" + elif param.annotation == bool: + param_type = "boolean" + elif param.annotation in (list, tuple): + param_type = "array" + elif param.annotation == dict: + param_type = "object" + + properties[param_name] = {"type": param_type} + + if param.default == inspect.Parameter.empty: + required.append(param_name) + + schema = { + "type": "object", + "properties": properties, + } + if required: + schema["required"] = required + return schema diff --git a/src/agentkit/tools/registry.py b/src/agentkit/tools/registry.py new file mode 100644 index 0000000..bc0369f --- /dev/null +++ b/src/agentkit/tools/registry.py @@ -0,0 +1,72 @@ +"""ToolRegistry - 工具注册中心""" + +import logging +from typing import Any + +from agentkit.core.exceptions import ToolNotFoundError +from agentkit.tools.base import Tool + +logger = logging.getLogger(__name__) + + +class ToolRegistry: + """工具注册中心,管理工具的注册、发现、版本""" + + def __init__(self): + self._tools: dict[str, dict[str, Tool]] = {} # name -> {version -> tool} + + def register(self, tool: Tool) -> "ToolRegistry": + """注册工具""" + if tool.name not in self._tools: + self._tools[tool.name] = {} + self._tools[tool.name][tool.version] = tool + logger.info(f"Tool '{tool.name}' v{tool.version} registered") + return self + + def unregister(self, name: str, version: str | None = None) -> None: + """注销工具""" + if name not in self._tools: + return + if version: + self._tools[name].pop(version, None) + if not self._tools[name]: + del self._tools[name] + else: + del self._tools[name] + + def get(self, name: str, version: str | None = None) -> Tool: + """获取工具(默认返回最新版本)""" + if name not in self._tools: + raise ToolNotFoundError(name) + + versions = self._tools[name] + if version: + if version not in versions: + raise ToolNotFoundError(f"{name}@{version}") + return versions[version] + + # 返回最新版本 + latest = sorted(versions.keys())[-1] + return versions[latest] + + def list_tools(self, tag: str | None = None) -> list[Tool]: + """列出所有工具(最新版本),可按标签过滤""" + result = [] + for name, versions in self._tools.items(): + latest = sorted(versions.keys())[-1] + tool = versions[latest] + if tag is None or tag in tool.tags: + result.append(tool) + return result + + def list_all_versions(self, name: str) -> list[Tool]: + """列出指定工具的所有版本""" + if name not in self._tools: + return [] + return list(self._tools[name].values()) + + def has_tool(self, name: str) -> bool: + return name in self._tools + + def clear(self) -> None: + self._tools.clear() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_base_agent.py b/tests/unit/test_base_agent.py new file mode 100644 index 0000000..9795ca7 --- /dev/null +++ b/tests/unit/test_base_agent.py @@ -0,0 +1,139 @@ +"""Tests for BaseAgent - 统一生命周期""" + +import asyncio +import pytest + +from agentkit.core.base import BaseAgent +from agentkit.core.protocol import ( + AgentCapability, + AgentStatus, + TaskMessage, + TaskResult, + TaskStatus, +) +from datetime import datetime, timezone + + +class SimpleAgent(BaseAgent): + """测试用简单 Agent""" + + def __init__(self): + super().__init__(name="test_agent", agent_type="test", version="1.0.0") + self.task_started = False + self.task_completed = False + self.task_failed = False + + async def handle_task(self, task: TaskMessage) -> dict: + if task.task_type == "echo": + return {"echo": task.input_data} + elif task.task_type == "fail": + raise ValueError("intentional failure") + return {"status": "ok"} + + def get_capabilities(self) -> AgentCapability: + return AgentCapability( + agent_name=self.name, + agent_type=self.agent_type, + version=self.version, + supported_tasks=["echo", "fail"], + max_concurrency=2, + description="Test agent", + ) + + async def on_task_start(self, task): + self.task_started = True + + async def on_task_complete(self, task, output): + self.task_completed = True + + async def on_task_failed(self, task, error): + self.task_failed = True + + +def _make_task(task_type: str = "echo", input_data: dict | None = None) -> TaskMessage: + return TaskMessage( + task_id="test-001", + agent_name="test_agent", + task_type=task_type, + priority=0, + input_data=input_data or {}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + +@pytest.mark.asyncio +async def test_handle_task_returns_output(): + agent = SimpleAgent() + task = _make_task("echo", {"msg": "hello"}) + result = await agent.execute(task) + + assert result.status == TaskStatus.COMPLETED + assert result.output_data == {"echo": {"msg": "hello"}} + assert result.error_message is None + assert result.metrics["task_type"] == "echo" + + +@pytest.mark.asyncio +async def test_handle_task_failure(): + agent = SimpleAgent() + task = _make_task("fail") + result = await agent.execute(task) + + assert result.status == TaskStatus.FAILED + assert result.error_message == "intentional failure" + assert result.metrics["error_type"] == "ValueError" + + +@pytest.mark.asyncio +async def test_lifecycle_hooks(): + agent = SimpleAgent() + + # 成功任务 + task = _make_task("echo") + await agent.execute(task) + assert agent.task_started is True + assert agent.task_completed is True + + # 重置 + agent.task_started = False + agent.task_completed = False + + # 失败任务 + task = _make_task("fail") + await agent.execute(task) + assert agent.task_started is True + assert agent.task_failed is True + + +@pytest.mark.asyncio +async def test_execute_wraps_timing(): + agent = SimpleAgent() + task = _make_task("echo") + result = await agent.execute(task) + + assert result.started_at is not None + assert result.completed_at is not None + assert result.metrics["elapsed_seconds"] >= 0 + + +@pytest.mark.asyncio +async def test_agent_status(): + agent = SimpleAgent() + assert agent.status == AgentStatus.OFFLINE + assert agent.is_distributed is False + + +@pytest.mark.asyncio +async def test_tool_injection(): + from agentkit.tools.function_tool import FunctionTool + + async def my_tool(x: int) -> dict: + return {"doubled": x * 2} + + tool = FunctionTool(name="doubler", description="Doubles a number", func=my_tool) + agent = SimpleAgent() + agent.use_tool(tool) + + assert len(agent.tools) == 1 + assert agent.tools[0].name == "doubler" diff --git a/tests/unit/test_evolution.py b/tests/unit/test_evolution.py new file mode 100644 index 0000000..0271fa7 --- /dev/null +++ b/tests/unit/test_evolution.py @@ -0,0 +1,131 @@ +"""Tests for Evolution system""" + +import pytest + +from agentkit.evolution.reflector import Reflector, Reflection +from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Module +from agentkit.evolution.strategy_tuner import StrategyTuner, StrategyConfig +from agentkit.evolution.ab_tester import ABTester, ABTestConfig +from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus +from datetime import datetime, timezone + + +def _make_task() -> TaskMessage: + return TaskMessage( + task_id="test-001", + agent_name="test", + task_type="echo", + priority=0, + input_data={}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + +def _make_result(status: str = TaskStatus.COMPLETED) -> TaskResult: + return TaskResult( + task_id="test-001", + agent_name="test", + status=status, + output_data={"key": "value"}, + error_message=None, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + metrics={"elapsed_seconds": 5.0}, + ) + + +@pytest.mark.asyncio +async def test_reflector_success(): + reflector = Reflector() + task = _make_task() + result = _make_result() + + reflection = await reflector.reflect(task, result) + assert reflection.outcome == "success" + assert reflection.quality_score > 0 + + +@pytest.mark.asyncio +async def test_reflector_failure(): + reflector = Reflector() + task = _make_task() + result = _make_result(TaskStatus.FAILED) + result.error_message = "something went wrong" + + reflection = await reflector.reflect(task, result) + assert reflection.outcome == "failure" + assert reflection.quality_score == 0.0 + + +@pytest.mark.asyncio +async def test_prompt_optimizer(): + optimizer = PromptOptimizer(max_demos=3, min_examples_for_optimization=2) + + # Add examples + for i in range(5): + optimizer.add_example( + input_data={"query": f"query_{i}"}, + output_data={"result": f"result_{i}"}, + quality_score=0.8 + i * 0.02, + ) + + module = Module( + name="test_module", + signature=Signature( + input_fields={"query": "search query"}, + output_fields={"result": "search result"}, + instruction="Find the best result.", + ), + ) + + optimized = await optimizer.optimize(module) + assert optimized.name == "test_module_optimized" + assert len(optimized.demos) == 3 + + +@pytest.mark.asyncio +async def test_prompt_optimizer_not_enough_examples(): + optimizer = PromptOptimizer(min_examples_for_optimization=10) + module = Module( + name="test", + signature=Signature( + input_fields={"x": "input"}, + output_fields={"y": "output"}, + ), + ) + + optimized = await optimizer.optimize(module) + # Should return unchanged module + assert optimized.name == "test" + + +def test_strategy_tuner(): + tuner = StrategyTuner() + + config = StrategyConfig(temperature=0.5) + tuner.record(config, metric=0.6) + tuner.record(StrategyConfig(temperature=0.7), metric=0.8) + tuner.record(StrategyConfig(temperature=0.3), metric=0.4) + + +@pytest.mark.asyncio +async def test_ab_tester(): + tester = ABTester() + test_config = ABTestConfig( + test_id="test-1", + agent_name="test_agent", + change_type="prompt", + min_samples=5, + ) + tester.create_test(test_config) + + # Record results + for _ in range(10): + group = tester.assign_group("test-1") + metric = 0.7 if group == "experiment" else 0.5 + tester.record_result("test-1", group, metric) + + result = await tester.evaluate("test-1") + assert result is not None + assert result.control_samples + result.experiment_samples == 10 diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py new file mode 100644 index 0000000..f66b827 --- /dev/null +++ b/tests/unit/test_pipeline.py @@ -0,0 +1,109 @@ +"""Tests for Pipeline system""" + +import pytest + +from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage +from agentkit.orchestrator.pipeline_engine import PipelineEngine +from agentkit.orchestrator.pipeline_loader import PipelineLoader + + +def test_pipeline_schema(): + stage = PipelineStage( + name="step1", + agent="agent_a", + action="process", + depends_on=[], + inputs={"data": "${input}"}, + ) + pipeline = Pipeline( + name="test_pipeline", + version="1.0.0", + description="Test", + stages=[stage], + variables={"input": "hello"}, + ) + assert len(pipeline.stages) == 1 + assert pipeline.variables["input"] == "hello" + + +def test_topological_group(): + stages = [ + PipelineStage(name="a", agent="agent1", action="do_a", depends_on=[]), + PipelineStage(name="b", agent="agent2", action="do_b", depends_on=["a"]), + PipelineStage(name="c", agent="agent3", action="do_c", depends_on=["a"]), + PipelineStage(name="d", agent="agent4", action="do_d", depends_on=["b", "c"]), + ] + + groups = PipelineEngine._topological_group(stages) + assert len(groups) == 3 # [a], [b, c], [d] + assert groups[0][0].name == "a" + assert {s.name for s in groups[1]} == {"b", "c"} + assert groups[2][0].name == "d" + + +def test_topological_group_circular(): + stages = [ + PipelineStage(name="a", agent="agent1", action="do_a", depends_on=["b"]), + PipelineStage(name="b", agent="agent2", action="do_b", depends_on=["a"]), + ] + + with pytest.raises(ValueError, match="Circular dependency"): + PipelineEngine._topological_group(stages) + + +@pytest.mark.asyncio +async def test_pipeline_dry_run(): + pipeline = Pipeline( + name="test", + version="1.0.0", + description="Test", + stages=[ + PipelineStage(name="step1", agent="agent1", action="process", inputs={"x": "1"}), + PipelineStage(name="step2", agent="agent2", action="analyze", depends_on=["step1"], inputs={"y": "2"}), + ], + ) + + engine = PipelineEngine(dispatcher=None) # dry-run mode + result = await engine.execute(pipeline) + assert result.status.value == "completed" + + +def test_yaml_loader(): + yaml_content = """ +name: test_pipeline +version: "1.0" +description: Test pipeline +stages: + - name: step1 + agent: agent1 + action: process + inputs: + data: hello + - name: step2 + agent: agent2 + action: analyze + depends_on: [step1] + inputs: + result: ${{step1_output}} +variables: + input: world +""" + loader = PipelineLoader() + pipeline = loader.load_from_yaml(yaml_content, "test") + assert pipeline.name == "test_pipeline" + assert len(pipeline.stages) == 2 + assert pipeline.stages[1].depends_on == ["step1"] + + +def test_validate_dag(): + stages = [ + PipelineStage(name="a", agent="a1", action="do", depends_on=[]), + PipelineStage(name="b", agent="a2", action="do", depends_on=["a"]), + ] + assert PipelineLoader.validate_dag(stages) is True + + circular = [ + PipelineStage(name="a", agent="a1", action="do", depends_on=["b"]), + PipelineStage(name="b", agent="a2", action="do", depends_on=["a"]), + ] + assert PipelineLoader.validate_dag(circular) is False diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py new file mode 100644 index 0000000..84f520e --- /dev/null +++ b/tests/unit/test_protocol.py @@ -0,0 +1,95 @@ +"""Tests for Protocol data structures""" + +import pytest +from datetime import datetime + +from agentkit.core.protocol import ( + AgentCapability, + HandoffMessage, + TaskMessage, + TaskResult, + TaskStatus, + EvolutionEvent, +) + + +def test_task_status_values(): + assert TaskStatus.PENDING == "pending" + assert TaskStatus.RUNNING == "running" + assert TaskStatus.COMPLETED == "completed" + assert TaskStatus.FAILED == "failed" + assert TaskStatus.CANCELLED == "cancelled" + assert TaskStatus.HANDOFF == "handoff" + + +def test_agent_capability_with_schema(): + cap = AgentCapability( + agent_name="test", + agent_type="test", + version="1.0.0", + supported_tasks=["echo"], + max_concurrency=2, + description="Test", + input_schema={"type": "object", "properties": {"x": {"type": "number"}}}, + output_schema={"type": "object"}, + ) + + d = cap.to_dict() + assert "input_schema" in d + assert "output_schema" in d + + restored = AgentCapability.from_dict(d) + assert restored.agent_name == "test" + assert restored.input_schema is not None + + +def test_task_message_roundtrip(): + msg = TaskMessage( + task_id="123", + agent_name="agent1", + task_type="echo", + priority=1, + input_data={"key": "value"}, + callback_url=None, + created_at=datetime.utcnow(), + conversation_id="conv-1", + ) + + d = msg.to_dict() + assert d["conversation_id"] == "conv-1" + + restored = TaskMessage.from_dict(d) + assert restored.task_id == "123" + assert restored.conversation_id == "conv-1" + + +def test_handoff_message(): + msg = HandoffMessage( + source_agent="agent_a", + target_agent="agent_b", + task_id="task-1", + task_type="analyze", + context={"data": "test"}, + reason="Need expert analysis", + ) + + d = msg.to_dict() + assert d["source_agent"] == "agent_a" + assert d["target_agent"] == "agent_b" + + restored = HandoffMessage.from_dict(d) + assert restored.reason == "Need expert analysis" + + +def test_evolution_event(): + event = EvolutionEvent( + agent_name="optimizer", + change_type="prompt", + before={"instruction": "old"}, + after={"instruction": "new"}, + metrics={"quality_delta": 0.15}, + ) + + d = event.to_dict() + assert d["change_type"] == "prompt" + assert d["metrics"]["quality_delta"] == 0.15 diff --git a/tests/unit/test_tool_registry.py b/tests/unit/test_tool_registry.py new file mode 100644 index 0000000..6c04cf6 --- /dev/null +++ b/tests/unit/test_tool_registry.py @@ -0,0 +1,104 @@ +"""Tests for Tool system""" + +import asyncio +import pytest + +from agentkit.tools.base import Tool +from agentkit.tools.function_tool import FunctionTool +from agentkit.tools.registry import ToolRegistry + + +async def add_numbers(a: int, b: int) -> dict: + return {"sum": a + b} + + +def sync_greet(name: str) -> dict: + return {"greeting": f"Hello, {name}!"} + + +@pytest.mark.asyncio +async def test_function_tool_async(): + tool = FunctionTool(name="add", description="Add numbers", func=add_numbers) + result = await tool.execute(a=1, b=2) + assert result == {"sum": 3} + + +@pytest.mark.asyncio +async def test_function_tool_sync(): + tool = FunctionTool(name="greet", description="Greet someone", func=sync_greet) + result = await tool.execute(name="World") + assert result == {"greeting": "Hello, World!"} + + +@pytest.mark.asyncio +async def test_function_tool_schema_inference(): + tool = FunctionTool(name="add", description="Add numbers", func=add_numbers) + assert tool.input_schema is not None + assert "a" in tool.input_schema.get("properties", {}) + assert "b" in tool.input_schema.get("properties", {}) + + +@pytest.mark.asyncio +async def test_tool_registry(): + registry = ToolRegistry() + tool = FunctionTool(name="add", description="Add numbers", func=add_numbers) + registry.register(tool) + + assert registry.has_tool("add") + retrieved = registry.get("add") + assert retrieved.name == "add" + + +@pytest.mark.asyncio +async def test_tool_registry_versioning(): + registry = ToolRegistry() + + v1 = FunctionTool(name="add", description="Add v1", func=add_numbers, version="1.0.0") + v2 = FunctionTool(name="add", description="Add v2", func=add_numbers, version="2.0.0") + + registry.register(v1) + registry.register(v2) + + # Default returns latest + latest = registry.get("add") + assert latest.version == "2.0.0" + + # Can request specific version + specific = registry.get("add", version="1.0.0") + assert specific.version == "1.0.0" + + +@pytest.mark.asyncio +async def test_tool_registry_list(): + registry = ToolRegistry() + + t1 = FunctionTool(name="add", description="Add", func=add_numbers, tags=["math"]) + t2 = FunctionTool(name="greet", description="Greet", func=sync_greet, tags=["text"]) + + registry.register(t1) + registry.register(t2) + + all_tools = registry.list_tools() + assert len(all_tools) == 2 + + math_tools = registry.list_tools(tag="math") + assert len(math_tools) == 1 + assert math_tools[0].name == "add" + + +@pytest.mark.asyncio +async def test_tool_safe_execute(): + async def failing_tool(): + raise RuntimeError("boom") + + tool = FunctionTool(name="fail", description="Always fails", func=failing_tool) + with pytest.raises(RuntimeError): + await tool.safe_execute() + + +@pytest.mark.asyncio +async def test_tool_not_found(): + registry = ToolRegistry() + from agentkit.core.exceptions import ToolNotFoundError + with pytest.raises(ToolNotFoundError): + registry.get("nonexistent")