feat(agentkit): Phase 3 upgrade - persistence, memory, evolution, observability
10 Implementation Units across 3 phases:

Phase A - Infrastructure:
- U1: RedisTaskStore with Redis/memory backend + factory function
- U2: TraceRecorder for execution trace recording
- U3: PersistentEvolutionStore with SQLite backend

Phase B - Core Capabilities:
- U4: MemoryRetriever integration into ReAct engine
- U5: Embedder abstraction + EpisodicMemory vector search
- U6: LLMReflector for LLM-in-the-loop reflection
- U7: SkillPipeline for multi-skill orchestration

Phase C - Enhancement:
- U8: SKILL.md format + progressive disclosure levels
- U9: ContextCompressor + prompt cache rendering
- U10: Structured logging + metrics endpoint + enhanced health check

Tests: 924 passed, 18 skipped, 0 failed
parent 74e2223153
commit f858d279f3
@@ -18,6 +18,7 @@ COPY --from=builder /install /usr/local
 
 COPY pyproject.toml README.md ./
 COPY src/ ./src/
+COPY configs/ ./configs/
 
 RUN addgroup --system --gid 1001 appuser \
     && adduser --system --uid 1001 appuser \
@@ -30,5 +31,4 @@ EXPOSE 8001
 HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8001/api/v1/health')"
 
-ENTRYPOINT ["agentkit"]
-CMD ["serve", "--host", "0.0.0.0", "--port", "8001"]
+CMD ["uvicorn", "configs.geo_server:create_geo_app", "--factory", "--host", "0.0.0.0", "--port", "8001"]
@@ -13,7 +13,9 @@ from agentkit.core.protocol import TaskMessage
 logger = logging.getLogger(__name__)
 
 GEO_BACKEND_URL = os.getenv("GEO_BACKEND_URL", "http://localhost:8000")
-INTERNAL_API_TOKEN = os.getenv("INTERNAL_API_TOKEN", "")
+INTERNAL_API_TOKEN = os.getenv("INTERNAL_API_TOKEN")
+if not INTERNAL_API_TOKEN:
+    logger.warning("INTERNAL_API_TOKEN not set — callbacks to GEO Backend will fail")
 
 
 def _internal_headers() -> dict:
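@@ -0,0 +1,379 @@
# GEO System and AgentKit Integration Guide

## 1. What AgentKit Is

AgentKit is a **unified agent development framework**. Its core capabilities:

| Capability | Description |
|------|------|
| **ReAct reasoning engine** | Think → Act → Observe loop; the LLM picks tools on its own and decides when to produce output |
| **LLM Gateway** | Single entry point for LLM calls; manages API keys, model routing, fallback strategy, and usage accounting |
| **Skill system** | Skills are defined in YAML (prompt + tools + quality gate), with no code required |
| **Intent routing** | Keyword matching (zero cost) with LLM classification as a fallback; routes each request to the best-matching Skill |
| **Output quality management** | Required fields, minimum length, schema validation, and custom validators; output that fails is retried automatically |
| **Standardized output** | Schema validation, type normalization, and attached metadata, so every Skill returns the same output format |
| **Memory system** | Semantic memory (pgvector) + episodic memory (Redis) + working memory |
| **MCP protocol** | Supports the Model Context Protocol for connecting to external tool servers |
| **CLI tool** | The `agentkit` command line, with init/serve/task/skill/pair/doctor/usage subcommands |
| **Standalone deployment** | FastAPI server + Docker; business systems call it over an HTTP API |

**In one sentence**: AgentKit reduces roughly 150 lines of agent code to 10-20 lines of YAML configuration.

---
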
## 2. Architecture

```
┌──────────────────────┐     HTTP API      ┌───────────────────────────┐
│  GEO Backend         │ ────────────────→ │  AgentKit Server          │
│  (FastAPI :8000)     │                   │  (FastAPI :8001)          │
│                      │  POST /tasks      │                           │
│  No imports of       │  GET /tasks/{id}  │  Intent Router            │
│  agentkit internals  │  GET /skills      │  ReAct Engine             │
│                      │  GET /llm/usage   │  LLM Gateway              │
│  AgentKitClient only │                   │  Quality Gate             │
│                      │ ←── callback ──── │  Output Standardizer      │
│  /internal/* API     │ (custom_handler)  │  AgentPool + SkillRegistry│
└──────────────────────┘                   └───────────────────────────┘
                                                         │
                                                   ┌─────┴─────┐
                                                   │ LLM APIs  │
                                                   │ (DeepSeek │
                                                   │  OpenAI…) │
                                                   └───────────┘
```

**Key principles**:

- GEO Backend **does not import agentkit internal classes**; it calls AgentKit only through the HTTP API
- AgentKit Server **does not access the GEO database directly**; when it needs database access, it calls back into GEO's internal API
- LLM API keys are **configured only in AgentKit Server**; GEO does not need them

---

## 3. Integration Steps

### Step 1: Deploy AgentKit Server

```bash
cd fischer-agentkit

# Initialize the configuration
agentkit init

# Edit .env and fill in the LLM API keys
cp .env.example .env
# DEEPSEEK_API_KEY=sk-xxx
# OPENAI_API_KEY=sk-xxx

# Pair the GEO business system
agentkit pair --name geo-backend --skills-dir ./configs/skills
# Output: API Key = ak_live_xxxxxxxxxxxx

# Start the server
agentkit serve --host 0.0.0.0 --port 8001

# Verify
agentkit doctor
```

### Step 2: Configure environment variables in GEO Backend

Add the following to GEO's `.env`:

```bash
# AgentKit Server connection
AGENTKIT_SERVER_URL=http://localhost:8001
AGENTKIT_API_KEY=ak_live_xxxxxxxxxxxx  # the key generated by pair in Step 1
```

### Step 3: Rework GEO's agent_framework adapter layer

Switch `app/agent_framework/adapter.py` from import mode to HTTP API mode:

```python
# app/agent_framework/adapter.py (Mode A version)
import os
import logging
from agentkit.server.client import AgentKitClient

logger = logging.getLogger(__name__)
_CLIENT: AgentKitClient | None = None


def get_agentkit_client() -> AgentKitClient:
    """Return the shared AgentKit Server HTTP client, creating it on first use."""
    global _CLIENT
    if _CLIENT is None:
        base_url = os.getenv("AGENTKIT_SERVER_URL", "http://localhost:8001")
        api_key = os.getenv("AGENTKIT_API_KEY")
        _CLIENT = AgentKitClient(base_url=base_url, api_key=api_key)
    return _CLIENT


async def submit_task(input_data: dict, skill_name: str | None = None) -> dict:
    """Submit a task to AgentKit Server."""
    client = get_agentkit_client()
    return await client.submit_task(input_data=input_data, skill_name=skill_name)


async def get_task_status(task_id: str) -> dict:
    """Query the status of a task."""
    client = get_agentkit_client()
    return await client.get_task_status(task_id)


async def get_llm_usage(agent_name: str | None = None) -> dict:
    """Query LLM usage."""
    client = get_agentkit_client()
    return await client.get_usage(agent_name=agent_name)
```

### Step 4: Rework the business call sites

**Content generation** (previously 3 dispatch calls, now 1 `submit_task` call):

```python
# Before
from app.agent_framework.dispatcher import TaskDispatcher
dispatcher = TaskDispatcher(settings.REDIS_URL)
task = TaskMessage(agent_name="content_generator", ...)
result = await dispatcher.dispatch(task, ...)

# After
from app.agent_framework.adapter import submit_task
result = await submit_task(
    input_data={"target_keyword": keyword, "brand_name": brand, ...},
    skill_name="content_generator",
)
content = result["data"]["content"]
```

**Citation detection**:

```python
# Before
from app.agent_framework.agents import CitationDetectorAgent
agent = CitationDetectorAgent()
result = await agent.execute(task)

# After
from app.agent_framework.adapter import submit_task
result = await submit_task(
    input_data={"keyword": keyword, "platform": platform, ...},
    skill_name="citation_detector",
)
```

### Step 5: Add internal APIs (called back by AgentKit Server)

When a custom_handler needs database access, AgentKit Server calls back into GEO over HTTP:

```python
# app/api/internal.py
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db

router = APIRouter(prefix="/internal", tags=["internal"])


@router.post("/citation/detect")
async def citation_detect(input_data: dict, db: AsyncSession = Depends(get_db)):
    """Callback target for AgentKit Server's citation_handler."""
    from app.services.citation.citation import CitationService
    service = CitationService()
    return await service.detect_full(input_data, db=db)


@router.post("/knowledge/search")
async def knowledge_search(input_data: dict, db: AsyncSession = Depends(get_db)):
    """Callback target for AgentKit Server's retrieve_knowledge tool."""
    from app.services.knowledge.rag_service import RAGService
    service = RAGService()
    results = await service.search(session=db, query=input_data["query"])
    return {"results": results}
```

### Step 6: Joint deployment with Docker Compose

```yaml
# docker-compose.yml
version: "3.8"
services:
  geo-backend:
    build: ./geo/backend
    ports: ["8000:8000"]
    environment:
      - AGENTKIT_SERVER_URL=http://agentkit-server:8001
      - AGENTKIT_API_KEY=${AGENTKIT_API_KEY}
    depends_on:
      - agentkit-server

  agentkit-server:
    build: ./fischer-agentkit
    command: serve --host 0.0.0.0 --port 8001
    ports: ["8001:8001"]
    env_file: ./fischer-agentkit/.env
    environment:
      - GEO_BACKEND_URL=http://geo-backend:8000
    depends_on:
      - redis
      - postgres

  redis:
    image: redis:7-alpine

  postgres:
    image: pgvector/pgvector:pg15
    environment:
      POSTGRES_USER: agentkit
      POSTGRES_PASSWORD: agentkit
      POSTGRES_DB: agentkit
```

---

## 4. Mapping of GEO's 8 Current Skills

| Original agent | Skill name | Mode | Migration notes |
|-------------|---------|------|---------|
| citation_detector | citation_detector | custom | handler calls back to GEO `/internal/citation/detect` |
| monitor | monitor | custom | handler calls back to GEO `/internal/monitor/check` |
| schema_advisor | schema_advisor | custom | handler calls back to GEO `/internal/schema/advise` |
| content_generator | content_generator | llm_generate | Migrate the YAML as-is; add intent + quality_gate |
| deai_agent | deai_agent | llm_generate | Migrate the YAML as-is |
| geo_optimizer | geo_optimizer | llm_generate | Migrate the YAML as-is |
| competitor_analyzer | competitor_analyzer | tool_call | Tools move to AgentKit Server |
| trend_agent | trend_agent | tool_call | Tools move to AgentKit Server |

**No YAML changes required**: AgentKit loads the 8 existing YAML configurations unmodified (SkillConfig is backward compatible with AgentConfig). For Skills in llm_generate mode, adding the `intent` and `quality_gate` fields is recommended to enable the new capabilities.

---

## 5. API Reference

### AgentKit Server REST API

| Path | Method | Description |
|------|------|------|
| `/api/v1/tasks` | POST | Submit a task (intent routing can match a Skill automatically) |
| `/api/v1/tasks/{id}` | GET | Query task status and result |
| `/api/v1/tasks` | GET | List tasks |
| `/api/v1/tasks/{id}` | DELETE | Cancel a task |
| `/api/v1/agents` | POST | Create an agent instance |
| `/api/v1/agents` | GET | List agent instances |
| `/api/v1/skills` | POST | Register a Skill |
| `/api/v1/skills` | GET | List registered Skills |
| `/api/v1/llm/usage` | GET | Query LLM usage statistics |
| `/api/v1/health` | GET | Health check |

### Authentication

Every API request must carry this header:

```
X-API-Key: ak_live_xxxxxxxxxxxx
```
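
For callers that do not use the SDK, the header can be attached by hand. The following is a minimal sketch using only the Python standard library; it builds the request but does not send it. `build_task_request` is a hypothetical helper, not part of AgentKit, and the path and payload shape are taken from the table and examples in this guide.

```python
import json
import urllib.request


def build_task_request(base_url: str, api_key: str, payload: dict) -> urllib.request.Request:
    """Build (but do not send) an authenticated POST to /api/v1/tasks."""
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/api/v1/tasks",
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "X-API-Key": api_key,  # the key produced by `agentkit pair`
        },
        method="POST",
    )


req = build_task_request(
    "http://localhost:8001",
    "ak_live_xxxxxxxxxxxx",
    {"skill_name": "content_generator", "input_data": {"target_keyword": "AI"}},
)
# Sending would be: urllib.request.urlopen(req, timeout=10)
```

Keeping request construction separate from sending makes the authentication logic easy to unit test without a running server.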

### Task submission examples

```bash
# Specify the Skill explicitly
curl -X POST http://localhost:8001/api/v1/tasks \
  -H "Content-Type: application/json" \
  -H "X-API-Key: ak_live_xxxxxxxxxxxx" \
  -d '{
    "skill_name": "content_generator",
    "input_data": {"target_keyword": "AI", "brand_name": "BrandX"}
  }'

# Let intent routing pick the Skill
curl -X POST http://localhost:8001/api/v1/tasks \
  -H "Content-Type: application/json" \
  -H "X-API-Key: ak_live_xxxxxxxxxxxx" \
  -d '{
    "input_data": {"query": "Write an article about AI for me"}
  }'
```

### Python SDK

```python
import asyncio

from agentkit.server.client import AgentKitClient


async def main() -> None:
    client = AgentKitClient(
        base_url="http://localhost:8001",
        api_key="ak_live_xxxxxxxxxxxx",
    )

    # Submit a task
    result = await client.submit_task(
        skill_name="content_generator",
        input_data={"target_keyword": "AI", "brand_name": "BrandX"},
    )

    # Query usage
    usage = await client.get_usage()


asyncio.run(main())
```

---

## 6. CLI Quick Reference

```bash
agentkit init                          # Initialize the project configuration
agentkit serve --port 8001             # Start the server
agentkit doctor                        # Diagnose health status
agentkit version                       # Show the version

agentkit pair --name geo-backend       # Pair a business system and generate an API key
agentkit pair --list                   # List paired clients
agentkit pair --revoke geo-backend     # Revoke a pairing

agentkit task submit --skill content_generator --input '{"topic":"AI"}' --server-url http://localhost:8001
agentkit task status <task_id> --server-url http://localhost:8001
agentkit task list --server-url http://localhost:8001

agentkit skill list --server-url http://localhost:8001
agentkit skill load ./my_skill.yaml
agentkit skill info content_generator --server-url http://localhost:8001

agentkit usage --server-url http://localhost:8001
```

---

## 7. Migration Checklist

### Phase 1: AgentKit Server deployment

- [ ] Run `agentkit init` to generate the configuration
- [ ] Fill in the LLM API keys in `.env`
- [ ] Run `agentkit pair --name geo-backend` to generate an API key
- [ ] Copy the 8 YAML configurations to `configs/skills/`
- [ ] Migrate the 14 FunctionTools to `configs/geo_tools.py`
- [ ] Migrate the 3 custom_handlers to `configs/geo_handlers.py`
- [ ] `agentkit serve` starts successfully
- [ ] `agentkit doctor` health check passes

### Phase 2: GEO Backend changes

- [ ] Add `AGENTKIT_SERVER_URL` + `AGENTKIT_API_KEY` to `.env`
- [ ] Switch `adapter.py` to HTTP API mode
- [ ] Switch `content_generation_service.py` to `submit_task()`
- [ ] Switch `citation.py` to `submit_task()`
- [ ] Switch `scheduler.py` to `submit_task()`
- [ ] Add the `/internal/*` API routes
- [ ] End-to-end tests pass

### Phase 3: Cleanup

- [ ] Delete the old framework files (base.py, dispatcher.py, registry.py, etc.)
- [ ] Delete the old Agent classes
- [ ] Update the `__init__.py` exports
- [ ] Run the full regression suite

---

## 8. Configuration Precedence

```
Client-specific configuration (the --skills-dir given at pair time)
    ↓ overrides
init default configuration (agentkit.yaml)
    ↓ overrides
Hard-coded defaults
```

A business system can point to its own Skill directory with `agentkit pair --name geo-backend --skills-dir ./custom_skills`; that directory takes precedence over AgentKit Server's default configuration.
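
The three-layer precedence above can be illustrated with a short sketch. This is not AgentKit's actual loader: the keys (`skills_dir`, `port`, `log_level`) and the `resolve_config` helper are hypothetical, chosen only to show how a lookup falls through the layers.

```python
from collections import ChainMap

# Hypothetical hard-coded defaults (lowest precedence)
HARDCODED_DEFAULTS = {"skills_dir": "./configs/skills", "port": 8001, "log_level": "INFO"}


def resolve_config(client_overrides: dict, init_config: dict) -> ChainMap:
    """Layer the three sources; ChainMap searches left to right, so earlier maps win."""
    return ChainMap(client_overrides, init_config, HARDCODED_DEFAULTS)


config = resolve_config(
    client_overrides={"skills_dir": "./custom_skills"},        # from `pair --skills-dir`
    init_config={"skills_dir": "./skills", "port": 9001},      # from agentkit.yaml
)

print(config["skills_dir"])  # client value wins over agentkit.yaml
print(config["port"])        # agentkit.yaml wins over the hard-coded default
print(config["log_level"])   # only the hard-coded default defines this key
```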
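@@ -0,0 +1,625 @@
---
title: "feat: AgentKit Phase 3 - persistence, memory, evolution, skills, and observability upgrade"
status: active
created: 2026-06-06
plan_type: feat
depth: deep
origin: Hermes Agent comparative analysis + assessment of 5 major problems
branch: feat/agentkit-phase3-upgrade
---

# AgentKit Phase 3 Upgrade Plan

## Summary

Based on a benchmark against Hermes Agent and an assessment of AgentKit's current state, this plan addresses 5 core problems: no durable operation, a memory system that is not wired in, gaps in the evolution architecture, limited skill capabilities, and missing observability. It covers 10 upgrades across P0, P1, and P2, delivered in 3 phases. Mainline code stays untouched; development happens on the `feat/agentkit-phase3-upgrade` branch.

## Problem Frame

AgentKit today has the framework in place, but the pieces are not wired together:

- **Persistence gap**: docker-compose provisions Redis + PostgreSQL, but TaskStore is purely in-memory, so a process restart loses all state
- **Memory gap**: the three-tier memory architecture is fully designed, but the agent loop makes no memory calls; ReActEngine neither reads nor writes memory
- **Evolution gap**: EvolutionConfig defines settings that EvolutionMixin never reads, Reflector relies on hard-coded rules, and the A/B test data is fabricated
- **Skill gap**: a Skill is a pure data container with no automatic creation, orchestration, or curation, and the SKILL.md open standard is not supported
- **Observability gap**: no structured logging, no metrics, and no execution trace export

Hermes Agent's core innovation is a closed-loop flywheel: execution trace → LLM reflection → skill consolidation → faster reuse. AgentKit needs a similar evolution capability, adapted to enterprise scenarios.
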
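## Requirements

| ID | Requirement | Priority | Source |
|----|------|--------|------|
| R1 | Persist TaskStore to Redis/PG so state survives a process restart | P0 | Durable-operation assessment |
| R2 | Wire the memory system into the agent loop: retrieve context before execution, write the trace afterwards | P0 | Memory architecture assessment |
| R3 | Replace the hard-coded Reflector with an LLM-driven reflector | P0 | Evolution architecture assessment |
| R4 | Implement pgvector vector retrieval in EpisodicMemory | P1 | Memory architecture assessment |
| R5 | Execution trace recorder that supplies data for reflection and observability | P1 | Evolution + observability |
| R6 | Skill orchestration / pipeline capability | P1 | Skill completeness assessment |
| R7 | Persist EvolutionStore | P1 | Evolution architecture assessment |
| R8 | SKILL.md format + progressive disclosure levels | P2 | Skill completeness assessment |
| R9 | Context compression and prompt caching | P2 | Token cost optimization |
| R10 | Observability (structured logging + metrics + enhanced health check) | P2 | Production operations |
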
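## Scope Boundaries

### In Scope

- 10 upgrades (R1-R10) across 3 delivery phases
- Existing APIs remain backward compatible
- Branch-based development; mainline is not modified

### Out of Scope

- Multi-platform messaging gateway (Telegram/Discord/Slack, etc.): a positioning difference, since AgentKit is an AI engine rather than a personal agent
- Parallel sub-agent execution: needs a more complex scheduling architecture; deferred to Phase 4
- Automatic skill creation + Curator: depends on the LLM reflector and execution traces; deferred to Phase 4
- agentskills.io skill marketplace: needs community infrastructure; deferred to Phase 4
- RAG / knowledge-graph backend for SemanticMemory: depends on external services; the adapter pattern is kept for now

### Deferred to Follow-Up Work

- Migrate RateLimiter to Redis-based distributed rate limiting
- State sharing in multi-worker mode
- Graceful shutdown (SIGTERM handling)
- User modeling (user_id + preference tracking)

---
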
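## Key Technical Decisions

### KTD1: TaskStore persistence strategy (Redis first)

**Decision**: TaskStore uses the Redis backend by default; InMemoryTaskStore is for development and testing only.

**Rationale**:
- docker-compose already provisions Redis, so the infrastructure is ready
- `RedisTaskStore` already exists (`server/task_store.py`); it only needs to become the default
- Redis supports TTL natively, which matches the need to expire and clean up tasks
- No new storage dependency is introduced

**Alternative**: a PostgreSQL backend, which is more durable but has higher latency and suits archiving better than active task state.

### KTD2: Memory integration (MemoryRetriever injected into ReActEngine)

**Decision**: Inject a `MemoryRetriever | None` parameter into ReActEngine.execute(). Before execution, retrieve relevant context and inject it into the system_prompt; after execution, write the trace to EpisodicMemory.

**Rationale**:
- ReActEngine is the engine underneath every execution mode, so integrating at this layer gives the widest coverage
- MemoryRetriever already implements parallel retrieval across the three tiers with weighted fusion, so nothing needs rewriting
- Injection rather than inheritance keeps ReActEngine independent

**Alternative**: integrate at the ConfigDrivenAgent layer, which is simpler but covers only ConfigDrivenAgent and misses callers that use ReActEngine directly.

### KTD3: Reflector strategy (LLM-in-the-loop + rule-based fallback)

**Decision**: Add `LLMReflector`, which has an LLM analyze the execution trace and produce a reflection. Keep `RuleBasedReflector` (the current implementation) as the fallback, switching to it automatically when the LLM is unavailable.

**Rationale**:
- GEPA's core insight is that natural-language reflection is more effective than numeric rewards, which calls for LLM-level reflection
- Enterprise settings need a fallback strategy; reflection must not disappear entirely when the LLM is unavailable
- The DSPy/GEPA framework is not used directly: AgentKit already has LLMGateway, so no new dependency is needed

**Alternative**: integrate DSPy + GEPA, which is more powerful but adds a heavy dependency, and AgentKit's positioning does not call for GEPA's full evolution pipeline.

### KTD4: Execution trace storage (local SQLite + optional PG)

**Decision**: Execution traces are stored in local SQLite by default (`~/.agentkit/traces/`), with an optional PostgreSQL backend for large-scale deployments.

**Rationale**:
- Consistent with Hermes Agent (SQLite FTS5) and lightweight
- Single-machine deployments need no PG, which lowers the barrier to entry
- The PG backend serves multi-instance deployments

### KTD5: Skill orchestration (reuse the existing PipelineEngine)

**Decision**: Skill orchestration reuses PipelineEngine from `orchestrator/pipeline_engine.py` and adds a `SkillPipeline` adapter layer that wraps each Skill as a pipeline step.

**Rationale**:
- PipelineEngine already implements sequential, parallel, and conditional execution, so it is functionally complete
- It avoids reinventing the wheel; only an adapter layer is needed
- The pipeline YAML format is already defined, so users can orchestrate skills declaratively

### KTD6: SKILL.md format (YAML metadata + Markdown body)

**Decision**: SKILL.md uses a hybrid format of YAML frontmatter plus a Markdown body, compatible with the agentskills.io standard.

**Rationale**:
- YAML frontmatter is machine-readable (metadata parsing); the Markdown body is readable by both humans and machines (it describes the skill steps)
- It is compatible with the existing YAML configuration format, so migration cost is low
- The agentskills.io standard uses plain Markdown, and YAML frontmatter is a superset of it

---
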
## High-Level Technical Design

### Evolution flywheel architecture

```mermaid
graph LR
    A[Task execution] --> B[Execution trace recording]
    B --> C[LLM reflection analysis]
    C --> D{Quality threshold met?}
    D -->|No| E[Prompt optimization]
    D -->|Yes| F[Skill consolidation]
    E --> G[A/B test]
    G --> H{Statistically significant?}
    H -->|Yes| I[Apply / roll back]
    H -->|No| J[Keep collecting samples]
    F --> K[Skill library]
    K -->|Reuse| A
    I --> K
```

### Memory integration data flow

```mermaid
sequenceDiagram
    participant Client
    participant Agent as ConfigDrivenAgent
    participant Engine as ReActEngine
    participant Retriever as MemoryRetriever
    participant Episodic as EpisodicMemory

    Client->>Agent: handle_task(task)
    Agent->>Retriever: get_context(task.input_data)
    Retriever->>Episodic: search(similar tasks)
    Episodic-->>Retriever: relevant memories
    Retriever-->>Agent: context string
    Agent->>Engine: execute(messages + context)
    Engine-->>Agent: result + trace
    Agent->>Episodic: store(trace summary)
    Agent-->>Client: TaskResult
```

### Three-phase delivery dependencies

```mermaid
graph TD
    subgraph Phase A - Infrastructure
        U1[U1: TaskStore persistence]
        U2[U2: Execution trace recorder]
        U3[U3: EvolutionStore persistence]
    end
    subgraph Phase B - Core capabilities
        U4[U4: Memory wired into the agent loop]
        U5[U5: Episodic vector retrieval]
        U6[U6: LLM reflector]
        U7[U7: Skill orchestration]
    end
    subgraph Phase C - Enhancement
        U8[U8: SKILL.md format]
        U9[U9: Context compression and caching]
        U10[U10: Observability]
    end
    U1 --> U4
    U2 --> U4
    U2 --> U6
    U3 --> U6
    U4 --> U5
    U6 --> U8
```

---

## Implementation Units

### U1. Persist TaskStore to Redis

**Goal**: Switch the default TaskStore backend from in-memory to Redis so that task state survives a process restart.

**Requirements**: R1

**Dependencies**: None

**Files**:
- Modify: `src/agentkit/server/task_store.py`: make `create_task_store()` default to the Redis backend
- Modify: `src/agentkit/server/app.py`: `create_app()` selects the TaskStore backend from configuration
- Modify: `src/agentkit/server/config.py`: add the `task_store_backend` setting
- Modify: `src/agentkit/cli/main.py`: the serve command passes the task_store configuration through
- Test: `tests/unit/test_task_store_redis.py`

**Approach**:
1. `RedisTaskStore` already exists in `task_store.py`; verify that it is functionally complete
2. Add a `backend` parameter to the `create_task_store()` factory, defaulting to `redis`
3. Add a `task_store` block to `ServerConfig` (backend/redis_url/ttl/max_records)
4. `create_app()` reads the configuration from `ServerConfig` and creates the matching TaskStore
5. Keep InMemoryTaskStore for tests, enabled explicitly with `backend: memory`

**Patterns to follow**: the existing `RedisTaskStore` implementation in `src/agentkit/server/task_store.py`

**Test scenarios**:
- Happy path: create a task → simulate a restart (close and reopen the Redis connection) → the task is still there when queried
- Edge case: when Redis is unavailable, fall back to InMemoryTaskStore and log a warning
- Edge case: tasks are cleaned up automatically after the TTL expires
- Error path: error handling and fallback when the Redis connection fails
- Integration: after the serve command starts, submit a task and query its status

**Verification**: `PYTHONPATH=src pytest tests/unit/test_task_store_redis.py -v` passes in full
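
The factory-with-fallback behaviour described in the U1 approach and edge cases could look roughly like the sketch below. It is an illustration under stated assumptions, not the real `task_store.py`: the `TaskStore` protocol, the `put`/`get` method names, and the `connect_redis` callable are all hypothetical, and no Redis client library is used so that the example stays self-contained.

```python
from __future__ import annotations

import logging
from typing import Callable, Protocol

logger = logging.getLogger(__name__)


class TaskStore(Protocol):
    """Hypothetical minimal store interface."""

    def put(self, task_id: str, state: dict) -> None: ...
    def get(self, task_id: str) -> dict | None: ...


class InMemoryTaskStore:
    def __init__(self) -> None:
        self._tasks: dict[str, dict] = {}

    def put(self, task_id: str, state: dict) -> None:
        self._tasks[task_id] = state

    def get(self, task_id: str) -> dict | None:
        return self._tasks.get(task_id)


def create_task_store(
    backend: str = "redis",
    connect_redis: Callable[[], TaskStore] | None = None,
) -> TaskStore:
    """Return the configured store; degrade to memory if Redis cannot be reached."""
    if backend == "memory":
        return InMemoryTaskStore()
    if backend != "redis":
        raise ValueError(f"unknown task store backend: {backend!r}")
    try:
        if connect_redis is None:
            raise ConnectionError("no Redis connector configured")
        return connect_redis()
    except ConnectionError as exc:
        logger.warning("Redis unavailable (%s); falling back to InMemoryTaskStore", exc)
        return InMemoryTaskStore()


def _unreachable_redis() -> TaskStore:
    # Stand-in for a real connection attempt that fails
    raise ConnectionError("connection refused")


store = create_task_store("redis", connect_redis=_unreachable_redis)
store.put("t1", {"status": "pending"})
```

Passing the connector as a callable keeps the fallback path testable without a live Redis instance.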

---

### U2. Execution trace recorder

**Goal**: Record the full execution trace while ReActEngine runs (each step's action, input and output, duration, and token usage) to supply data for reflection and observability.

**Requirements**: R5

**Dependencies**: None

**Files**:
- Create: `src/agentkit/core/trace.py`: TraceStep + ExecutionTrace dataclasses + TraceRecorder
- Modify: `src/agentkit/core/react.py`: inject TraceRecorder into execute() and record every step
- Modify: `src/agentkit/core/protocol.py`: add a `trace` field to TaskResult
- Test: `tests/unit/test_trace_recorder.py`

**Approach**:
1. Define `TraceStep` (step/action/tool_name/input/output/duration_ms/tokens_used/error) and `ExecutionTrace` (task_id/agent_name/skill_name/steps/total_duration/total_tokens/outcome/quality_score)
2. `TraceRecorder` class: `start_trace()`, `record_step()`, `end_trace()`, `get_trace()`
3. Add a `trace_recorder: TraceRecorder | None = None` parameter to `ReActEngine.execute()`
4. Call `record_step()` after every tool call and every LLM call
5. Add an optional `trace: ExecutionTrace | None` field to `TaskResult`
6. Traces live in memory by default (for the lifetime of a single request); U3 persists them later

**Patterns to follow**: the existing `ReActStep` and `ReActResult` data structures in `src/agentkit/core/react.py`

**Test scenarios**:
- Happy path: run a 3-step ReAct loop and verify that the trace contains 3 TraceSteps
- Happy path: a tool call records tool_name/input/output/duration
- Edge case: a pure LLM response with no tool calls yields a 1-step trace
- Error path: when a tool call fails, TraceStep.error is non-empty
- Integration: ConfigDrivenAgent runs a task through ReActEngine and the TaskResult contains the trace

**Verification**: `PYTHONPATH=src pytest tests/unit/test_trace_recorder.py -v` passes in full
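
As a rough illustration of the data model and recorder named in the U2 approach, here is a minimal in-memory sketch. Field names follow the plan where it lists them; `skill_name` and `quality_score` are left out for brevity, and the method signatures are assumptions rather than the final API.

```python
from __future__ import annotations

import time
from dataclasses import dataclass, field


@dataclass
class TraceStep:
    step: int
    action: str
    tool_name: str | None = None
    input: str = ""
    output: str = ""
    duration_ms: float = 0.0
    tokens_used: int = 0
    error: str | None = None


@dataclass
class ExecutionTrace:
    task_id: str
    agent_name: str
    steps: list[TraceStep] = field(default_factory=list)
    total_duration: float = 0.0  # milliseconds
    total_tokens: int = 0
    outcome: str = "running"


class TraceRecorder:
    def __init__(self) -> None:
        self._trace: ExecutionTrace | None = None
        self._started = 0.0

    def start_trace(self, task_id: str, agent_name: str) -> None:
        self._trace = ExecutionTrace(task_id=task_id, agent_name=agent_name)
        self._started = time.perf_counter()

    def record_step(self, action: str, **fields) -> TraceStep:
        if self._trace is None:
            raise RuntimeError("start_trace() must be called first")
        # Steps are numbered from 1 in the order they are recorded
        step = TraceStep(step=len(self._trace.steps) + 1, action=action, **fields)
        self._trace.steps.append(step)
        self._trace.total_tokens += step.tokens_used
        return step

    def end_trace(self, outcome: str) -> ExecutionTrace:
        if self._trace is None:
            raise RuntimeError("start_trace() must be called first")
        self._trace.total_duration = (time.perf_counter() - self._started) * 1000
        self._trace.outcome = outcome
        return self._trace

    def get_trace(self) -> ExecutionTrace | None:
        return self._trace


recorder = TraceRecorder()
recorder.start_trace("task-1", "content_generator")
recorder.record_step("llm_call", tokens_used=120)
recorder.record_step("tool_call", tool_name="retrieve_knowledge", error="timeout")
trace = recorder.end_trace("failed")
```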

---

### U3. Persist EvolutionStore

**Goal**: Move evolution events from memory to persistent SQLite storage, with support for querying evolution history and rolling back.

**Requirements**: R7

**Dependencies**: None

**Files**:
- Modify: `src/agentkit/evolution/evolution_store.py`: add a SQLite backend to replace in-memory storage
- Create: `src/agentkit/evolution/models.py`: SQLAlchemy ORM models (EvolutionEvent/SkillVersion/ABTestResult)
- Test: `tests/unit/test_evolution_store_persistent.py`

**Approach**:
1. Define the SQLAlchemy ORM models: `EvolutionEvent` (id/agent_name/event_type/trace_id/reflection_id/proposal_id/status/created_at), `SkillVersion` (id/skill_name/version/content/parent_version/created_at), `ABTestResult` (id/test_id/variant/score/sample_count/created_at)
2. Add a `backend` parameter to `EvolutionStore`, defaulting to `sqlite` (path `~/.agentkit/evolution.db`)
3. The `record()`/`query()`/`rollback()` methods operate on SQLite
4. Keep the in-memory backend for tests
5. Create the tables automatically on first run

**Patterns to follow**: the existing interface of `src/agentkit/evolution/evolution_store.py`

**Test scenarios**:
- Happy path: record an evolution event → close the connection → reopen → the event is found
- Happy path: record a skill version → query the version history
- Edge case: the first query against an empty database returns an empty list
- Error path: error handling when the SQLite file is not writable
- Integration: EvolutionMixin.evolve_after_task() writes to EvolutionStore

**Verification**: `PYTHONPATH=src pytest tests/unit/test_evolution_store_persistent.py -v` passes in full
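
The record / reopen / rollback flow from the U3 test scenarios can be sketched with the standard-library `sqlite3` module. Note the substitution: the plan calls for SQLAlchemy ORM models, while this example uses raw SQL to stay dependency-free. The class name, the reduced column set, and the `rolled_back` status value are assumptions for illustration.

```python
import os
import sqlite3
import tempfile
from datetime import datetime, timezone


class SqliteEvolutionStore:
    """Simplified stand-in for a persistent EvolutionStore (one table only)."""

    def __init__(self, path: str) -> None:
        self._conn = sqlite3.connect(path)
        # Tables are created on first use, as the approach requires
        self._conn.execute(
            """CREATE TABLE IF NOT EXISTS evolution_events (
                   id INTEGER PRIMARY KEY AUTOINCREMENT,
                   agent_name TEXT NOT NULL,
                   event_type TEXT NOT NULL,
                   status TEXT NOT NULL,
                   created_at TEXT NOT NULL
               )"""
        )
        self._conn.commit()

    def record(self, agent_name: str, event_type: str, status: str = "applied") -> int:
        cur = self._conn.execute(
            "INSERT INTO evolution_events (agent_name, event_type, status, created_at) "
            "VALUES (?, ?, ?, ?)",
            (agent_name, event_type, status, datetime.now(timezone.utc).isoformat()),
        )
        self._conn.commit()
        return cur.lastrowid

    def query(self, agent_name: str) -> list:
        rows = self._conn.execute(
            "SELECT id, event_type, status FROM evolution_events "
            "WHERE agent_name = ? ORDER BY id",
            (agent_name,),
        )
        return rows.fetchall()

    def rollback(self, event_id: int) -> None:
        # Mark the event instead of deleting it so the history stays auditable
        self._conn.execute(
            "UPDATE evolution_events SET status = 'rolled_back' WHERE id = ?",
            (event_id,),
        )
        self._conn.commit()

    def close(self) -> None:
        self._conn.close()


db_path = os.path.join(tempfile.mkdtemp(), "evolution.db")
store = SqliteEvolutionStore(db_path)
event_id = store.record("content_generator", "prompt_update")
store.close()

# Reopening the same file simulates a process restart
reopened = SqliteEvolutionStore(db_path)
reopened.rollback(event_id)
events = reopened.query("content_generator")
reopened.close()
```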

---

### U4. Wire memory into the agent loop

**Goal**: Inject MemoryRetriever into ReActEngine. Before execution, retrieve relevant context and inject it into the system_prompt; after execution, write a trace summary to EpisodicMemory.

**Requirements**: R2

**Dependencies**: U1, U2

**Files**:
- Modify: `src/agentkit/core/react.py`: add a `memory_retriever` parameter to execute() and retrieve context before execution
- Modify: `src/agentkit/core/config_driven.py`: instantiate the three memory tiers from config.memory and inject them into ReActEngine
- Modify: `src/agentkit/core/base.py`: add a `use_memory_retriever()` method to BaseAgent
- Modify: `src/agentkit/server/app.py`: initialize the memory components in create_app()
- Test: `tests/unit/test_memory_integration.py`

**Approach**:
1. Add `memory_retriever: MemoryRetriever | None = None` to `ReActEngine.__init__`
2. Before `execute()` starts: call `memory_retriever.get_context_string(task_input)` to fetch relevant memories
3. Append the memory context to the end of the system_prompt (under a `## Relevant Past Experience` section)
4. After `execute()` finishes: write the execution trace summary to EpisodicMemory
5. `ConfigDrivenAgent.__init__` creates WorkingMemory/EpisodicMemory/MemoryRetriever automatically from the `config.memory` settings
6. `create_app()` reads the memory configuration from ServerConfig and initializes the memory components

**Patterns to follow**: the `MemoryRetriever` interface in `src/agentkit/memory/retriever.py`

**Test scenarios**:
- Happy path: relevant past memories are retrieved during execution and injected into the system_prompt
- Happy path: the trace summary is written to EpisodicMemory after the task completes
- Edge case: execution proceeds normally without memory (memory_retriever=None)
- Edge case: a failed memory retrieval does not affect task execution
- Integration: run two similar tasks back to back; the second retrieves the memory of the first

**Verification**: `PYTHONPATH=src pytest tests/unit/test_memory_integration.py -v` passes in full
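
Steps 2 and 3 of the U4 approach, together with the two edge cases, can be sketched as one small function. The retriever is modelled here as a plain callable standing in for `memory_retriever.get_context_string`; `build_system_prompt` is a hypothetical helper name, and the real integration would live inside `ReActEngine.execute()`.

```python
from __future__ import annotations

from typing import Callable

MEMORY_HEADING = "## Relevant Past Experience"


def build_system_prompt(
    base_prompt: str,
    task_input: str,
    retrieve: Callable[[str], str] | None,
) -> str:
    """Append retrieved memory to the prompt; never let retrieval break the task."""
    if retrieve is None:
        return base_prompt
    try:
        context = retrieve(task_input)
    except Exception:
        # A failed retrieval must not affect task execution
        return base_prompt
    if not context.strip():
        return base_prompt
    return f"{base_prompt}\n\n{MEMORY_HEADING}\n{context}"


def failing_retriever(_query: str) -> str:
    raise TimeoutError("memory backend down")


prompt = build_system_prompt(
    "You are a GEO assistant.",
    "write about AI",
    lambda query: "- Earlier task on AI scored 0.9",
)
```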

---

### U5. EpisodicMemory vector retrieval

**Goal**: Implement pgvector cosine-distance ranking in EpisodicMemory, replacing the current ranking based on time decay alone, to support retrieval by semantic similarity.

**Requirements**: R4

**Dependencies**: U4

**Files**:
- Modify: `src/agentkit/memory/episodic.py`: implement pgvector vector retrieval
- Create: `src/agentkit/memory/embedder.py`: Embedder interface + OpenAIEmbedder implementation
- Test: `tests/unit/test_episodic_vector_search.py`

**Approach**:
1. Add an `Embedder` abstract base class: `embed(text: str) -> list[float]`
2. Add `OpenAIEmbedder`, which calls the OpenAI Embeddings API (text-embedding-3-small)
3. `EpisodicMemory.store()` calls the embedder to generate an embedding and stores it in a pgvector Vector column
4. `EpisodicMemory.search()` ranks by cosine distance blended with time decay: `score = alpha * cosine_similarity + (1-alpha) * time_decay`
5. Default `alpha=0.7` (semantic similarity weighted higher), adjustable through configuration
6. Implement `retrieve(key)`: embed the query first, then rank by cosine distance

**Patterns to follow**: the existing interface of `src/agentkit/memory/episodic.py`

**Test scenarios**:
- Happy path: store 3 memories and retrieve the most relevant one with a semantically similar query
- Happy path: blended ranking by time decay + semantic similarity
- Edge case: fall back to pure time-decay ranking when the embedder is unavailable
- Edge case: an empty query returns no results
- Error path: a clear error message when the pgvector extension is not installed

**Verification**: `PYTHONPATH=src pytest tests/unit/test_episodic_vector_search.py -v` passes in full
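
The blended score from step 4 of the U5 approach can be worked through in pure Python. The plan does not specify the shape of `time_decay`; the exponential half-life used below (24 hours) is an assumption, as are the toy two-dimensional vectors. In production the cosine term would come from pgvector rather than a Python loop.

```python
import math


def cosine_similarity(a: list, b: list) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def time_decay(age_hours: float, half_life_hours: float = 24.0) -> float:
    # Assumed decay shape: the value halves every `half_life_hours`
    return 0.5 ** (age_hours / half_life_hours)


def hybrid_score(query_vec: list, memory_vec: list, age_hours: float, alpha: float = 0.7) -> float:
    """score = alpha * cosine_similarity + (1 - alpha) * time_decay"""
    return alpha * cosine_similarity(query_vec, memory_vec) + (1 - alpha) * time_decay(age_hours)


query = [1.0, 0.0]
memories = [
    ("fresh but unrelated", [0.0, 1.0], 0.0),   # 0.7 * 0 + 0.3 * 1.0  = 0.3
    ("old but relevant", [1.0, 0.0], 48.0),     # 0.7 * 1 + 0.3 * 0.25 = 0.775
]
ranked = sorted(memories, key=lambda m: hybrid_score(query, m[1], m[2]), reverse=True)
```

With `alpha=0.7`, a two-day-old memory that matches the query still outranks a fresh memory that does not, which is the intent of weighting semantic similarity higher.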

---

### U6. LLM reflector

**Goal**: Add LLMReflector, which uses an LLM to analyze execution traces and produce structured reflections. Keep RuleBasedReflector as the fallback.

**Requirements**: R3

**Dependencies**: U2, U3

**Files**:
- Create: `src/agentkit/evolution/llm_reflector.py`: the LLMReflector class
- Modify: `src/agentkit/evolution/reflector.py`: rename to RuleBasedReflector, keeping the interface compatible
- Modify: `src/agentkit/evolution/lifecycle.py`: EvolutionMixin supports choosing the reflector type
- Modify: `src/agentkit/skills/base.py`: add a `reflector_type` field to EvolutionConfig
- Test: `tests/unit/test_llm_reflector.py`

**Approach**:
1. `LLMReflector` takes an `ExecutionTrace` and builds the reflection prompt (trace details + quality score)
2. Call LLM Gateway to generate a structured reflection (failure root cause / success patterns / improvement suggestions)
3. The output is compatible with the `Reflection` dataclass (outcome/quality_score/patterns/insights/suggestions)
4. Add a `reflector_type` setting to `EvolutionMixin`: `llm` (default) / `rule` / `auto` (LLM first, falling back to rule on failure)
5. LLM reflection uses an auxiliary model (not the primary model) to reduce cost
6. Add `reflector_type` and `auxiliary_model` fields to `EvolutionConfig`, aligned with EvolutionMixin

**Patterns to follow**: the `Reflector` interface and `Reflection` dataclass in `src/agentkit/evolution/reflector.py`

**Test scenarios**:
- Happy path: the LLM analyzes an execution trace and produces a Reflection containing insights and suggestions
- Happy path: in auto mode, fall back to RuleBasedReflector when the LLM fails
- Edge case: an empty execution trace returns a default Reflection
- Edge case: tolerant parsing when the LLM returns unstructured text
- Integration: EvolutionMixin completes the full evolution flow using LLMReflector

**Verification**: `PYTHONPATH=src pytest tests/unit/test_llm_reflector.py -v` passes in full
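
The `llm` / `rule` / `auto` selection in step 4 of the U6 approach can be sketched as a dispatch function. Everything here is a simplified assumption: the trace is a list of dicts rather than an `ExecutionTrace`, the `Reflection` class carries only a few fields plus a hypothetical `source` marker, and the LLM reflector is any callable, so no gateway call is made.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Reflection:
    outcome: str
    insights: list[str] = field(default_factory=list)
    source: str = "rule"  # hypothetical marker: which reflector produced this


def rule_based_reflect(trace: list[dict]) -> Reflection:
    """Hard-coded fallback: flag every step that recorded an error."""
    failed = [step for step in trace if step.get("error")]
    insights = [f"step {s['step']} failed: {s['error']}" for s in failed]
    return Reflection("failure" if failed else "success", insights, source="rule")


def reflect(
    trace: list[dict],
    reflector_type: str = "auto",
    llm_reflect: Callable[[list[dict]], Reflection] | None = None,
) -> Reflection:
    if reflector_type == "rule":
        return rule_based_reflect(trace)
    if reflector_type not in ("llm", "auto"):
        raise ValueError(f"unknown reflector_type: {reflector_type!r}")
    try:
        if llm_reflect is None:
            raise RuntimeError("no LLM reflector configured")
        return llm_reflect(trace)
    except Exception:
        if reflector_type == "llm":
            raise  # strict mode: surface the LLM failure
        return rule_based_reflect(trace)  # auto mode: degrade to rules


def broken_llm(_trace: list[dict]) -> Reflection:
    raise TimeoutError("LLM gateway timeout")


trace = [{"step": 1, "error": None}, {"step": 2, "error": "tool timeout"}]
reflection = reflect(trace, "auto", broken_llm)
```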

---

### U7. Skill orchestration

**Goal**: Reuse PipelineEngine to orchestrate Skills, so that several Skills can be chained and executed as a pipeline.

**Requirements**: R6

**Dependencies**: U4

**Files**:
- Create: `src/agentkit/skills/pipeline.py`: the SkillPipeline adapter layer
- Modify: `src/agentkit/skills/registry.py`: add pipeline registration and lookup
- Modify: `src/agentkit/server/routes/skills.py`: add pipeline API endpoints
- Test: `tests/unit/test_skill_pipeline.py`

**Approach**:
1. `SkillPipeline` class: wraps PipelineEngine and turns each Skill into a pipeline step
2. Each Skill is one step in the pipeline, taking the previous step's output as its input
3. Supports sequential execution, conditional branching (the next step depends on a Skill's output), and parallel execution
4. The pipeline definition format reuses PipelineConfig from `orchestrator/pipeline_schema.py`
5. A SkillPipeline can be defined in YAML or built programmatically
6. Add `register_pipeline()` and `get_pipeline()` methods to SkillRegistry

**Patterns to follow**: the PipelineEngine interface in `src/agentkit/orchestrator/pipeline_engine.py`

**Test scenarios**:
- Happy path: 3 Skills run in sequence and outputs are passed along correctly
- Happy path: conditional branching, where Skill A's output decides whether Skill B or Skill C runs
- Edge case: when a Skill in the pipeline fails, the Skills after it do not run
- Edge case: an empty pipeline (0 Skills) returns an empty result immediately
- Integration: submit a pipeline task through the API and query its execution status

**Verification**: `PYTHONPATH=src pytest tests/unit/test_skill_pipeline.py -v` passes in full
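
The sequential case of U7 (step 2 of the approach, plus the failure and empty-pipeline edge cases) is sketched below. This standalone class does not wrap the real PipelineEngine, whose interface is not shown in this plan; it only illustrates the output-to-input chaining. Skills are modelled as plain `dict -> dict` callables, and the result dictionary shape is an assumption.

```python
from __future__ import annotations

from typing import Callable

Skill = Callable[[dict], dict]  # hypothetical: a skill maps input data to output data


class SkillPipeline:
    def __init__(self) -> None:
        self._steps: list[tuple[str, Skill]] = []

    def add(self, name: str, skill: Skill) -> SkillPipeline:
        self._steps.append((name, skill))
        return self  # allow chained construction

    def run(self, payload: dict) -> dict:
        if not self._steps:
            return {"status": "ok", "output": {}, "completed": []}
        completed: list[str] = []
        for name, skill in self._steps:
            try:
                payload = skill(payload)  # previous output becomes the next input
            except Exception as exc:
                # Stop at the first failure; later skills never run
                return {
                    "status": "failed",
                    "failed_step": name,
                    "error": str(exc),
                    "completed": completed,
                }
            completed.append(name)
        return {"status": "ok", "output": payload, "completed": completed}


pipeline = (
    SkillPipeline()
    .add("content_generator", lambda d: {**d, "content": f"Article about {d['topic']}"})
    .add("deai_agent", lambda d: {**d, "content": d["content"].lower()})
)
result = pipeline.run({"topic": "AI"})
```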
|
||||
|
||||
---

### U8. SKILL.md format + progressive disclosure

**Goal**: Support skill definitions written in SKILL.md format and load them progressively by disclosure level (Level 0 summary / Level 1 full / Level 2 references).

**Requirements**: R8

**Dependencies**: U6

**Files**:
- Create: `src/agentkit/skills/skill_md.py` (SKILL.md parser)
- Modify: `src/agentkit/skills/loader.py` (add a `load_from_skill_md()` method)
- Modify: `src/agentkit/skills/base.py` (add `skill_md_path` and `disclosure_level` fields to SkillConfig)
- Modify: `src/agentkit/cli/skill.py` (add a `skill create` command that generates a SKILL.md template)
- Test: `tests/unit/test_skill_md.py`

**Approach**:
1. SKILL.md format: YAML frontmatter (name/description/intent/quality_gate/execution_mode) + Markdown body (trigger/steps/pitfalls/verification)
2. The parser builds a SkillConfig from the frontmatter and stores the body split into sections by heading
3. Progressive disclosure:
   - Level 0: name + description from the frontmatter (~50 tokens, always loaded)
   - Level 1: the full body (loaded on demand, when IntentRouter matches the skill)
   - Level 2: the `references/` and `templates/` directories (deep load, at skill execution time)
4. SkillLoader gains a `load_from_skill_md(path)` method
5. The CLI `skill create` command generates a SKILL.md template file
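
To make the Level 0 / Level 1 split concrete, here is a rough sketch of a parser that separates frontmatter from body sections. It is an illustration under stated assumptions, not the parser in `skill_md.py`: `ParsedSkillMd` and `parse_skill_md` are hypothetical names, and to stay dependency-free the frontmatter handling reads only top-level `key: value` scalars, whereas a real implementation would presumably use a proper YAML parser. Level 2 (the `references/` and `templates/` directories) is not covered.

```python
import re
from dataclasses import dataclass, field


@dataclass
class ParsedSkillMd:
    """Hypothetical container; the real parser's types may differ."""

    meta: dict[str, str] = field(default_factory=dict)
    sections: dict[str, str] = field(default_factory=dict)

    def level0(self) -> dict[str, str]:
        """Level 0: only name + description, cheap enough to keep resident."""
        return {key: self.meta.get(key, "") for key in ("name", "description")}

    def level1(self) -> dict[str, str]:
        """Level 1: the full body, keyed by lowercase top-level heading."""
        return dict(self.sections)


def parse_skill_md(text: str) -> ParsedSkillMd:
    meta: dict[str, str] = {}
    body = text
    match = re.match(r"^---\s*\n(.*?)\n---\s*\n?(.*)$", text, re.DOTALL)
    if match:
        front, body = match.group(1), match.group(2)
        for line in front.splitlines():
            # Simplification: nested keys (indented lines) are skipped.
            scalar = re.match(r"^([A-Za-z_]\w*):\s*(.+)$", line)
            if scalar:
                meta[scalar.group(1)] = scalar.group(2).strip().strip('"')

    sections: dict[str, str] = {}
    current = None
    for line in body.splitlines():
        heading = re.match(r"^#\s+(.+)$", line)
        if heading:
            current = heading.group(1).strip().lower()
            sections[current] = ""
        elif current is not None:
            sections[current] += line + "\n"
    return ParsedSkillMd(meta, {k: v.strip() for k, v in sections.items()})


SAMPLE = """---
name: demo
description: "Demo skill"
intent:
  keywords: ["demo"]
---

# Trigger
- When the user asks for a demo

# Steps
1. Do the thing
"""

parsed = parse_skill_md(SAMPLE)
no_frontmatter = parse_skill_md("# Steps\n1. Only a body\n")
```

When the frontmatter is missing, `level0()` in this sketch falls back to empty strings. The defaults the real loader applies are not specified here.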

**Patterns to follow**: the `load_from_file()` method in `src/agentkit/skills/loader.py`

**Test scenarios**:
- Happy path: parse a SKILL.md file and produce the correct SkillConfig
- Happy path: Level 0 loads only name + description
- Happy path: Level 1 loads the full steps
- Edge case: use default values when the frontmatter is missing
- Edge case: tolerate a Markdown body that lacks the standard sections
- Integration: SkillLoader loads a skill from SKILL.md and registers it in SkillRegistry

**Verification**: `PYTHONPATH=src pytest tests/unit/test_skill_md.py -v` runs with all tests passing

---

### U9. Context compression and prompt caching

**Goal**: Implement context compression (automatically compress message history in long sessions) and prompt caching (avoid re-rendering the same prompt within a session).

**Requirements**: R9

**Dependencies**: U4

**Files**:
- Create: `src/agentkit/core/compressor.py` (ContextCompressor class)
- Modify: `src/agentkit/prompts/template.py` (add a `render_cached()` method and a caching mechanism)
- Modify: `src/agentkit/core/react.py` (inject compression logic into execute())
- Test: `tests/unit/test_context_compressor.py`

**Approach**:
1. `ContextCompressor`: when the total token count of the messages exceeds a threshold (default 4000), call the LLM to compress the history into a summary
2. Compression strategy: keep the most recent N messages plus an LLM summary of the earlier ones
3. `PromptTemplate.render_cached()`: return the cached result for identical variable input and re-render when the variables change
4. The cache key is a hash of `variables`; the cache lives on the PromptTemplate instance
5. ReActEngine.execute() checks message length before every LLM call and compresses when over the threshold
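
Steps 1 and 2 can be sketched as a pure function, shown below. This is a simplified illustration rather than the `ContextCompressor` class itself: `compress_history` and its `summarize` callback are hypothetical names, the LLM call is replaced by a synchronous callable, and the 4-characters-per-token estimate is an assumed heuristic.

```python
from typing import Callable

Message = dict[str, str]


def estimate_tokens(messages: list[Message]) -> int:
    """Rough heuristic (assumption): about 4 characters per token."""
    return sum(len(m.get("content", "")) // 4 for m in messages)


def compress_history(
    messages: list[Message],
    summarize: Callable[[list[Message]], str],
    max_tokens: int = 4000,
    keep_recent: int = 3,
) -> list[Message]:
    """Keep system messages and the last `keep_recent` turns; summarize the rest."""
    if estimate_tokens(messages) <= max_tokens:
        return messages  # under budget: nothing to do

    system = [m for m in messages if m.get("role") == "system"]
    others = [m for m in messages if m.get("role") != "system"]
    if len(others) <= keep_recent:
        return messages  # too few messages to compress

    split = len(others) - keep_recent
    old, recent = others[:split], others[split:]
    try:
        summary = summarize(old)
    except Exception:
        return messages  # summarizer failed: keep the original history

    note = {"role": "system", "content": f"## Conversation Summary\n{summary}"}
    return system + [note] + recent


history = [{"role": "system", "content": "You are helpful."}] + [
    {"role": "user", "content": "x" * 2000} for _ in range(10)
]
compact = compress_history(history, summarize=lambda old: f"{len(old)} earlier messages")
```

With ten 2000-character messages the estimate is well above 4000, so the seven oldest non-system messages collapse into one summary message and the three most recent stay untouched.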

**Patterns to follow**: Hermes Agent's context compression mechanism (LLM summary + cached snapshot)

**Test scenarios**:
- Happy path: 10 history messages are compressed into a summary + the 3 most recent messages
- Happy path: the token count after compression is below the threshold
- Happy path: identical variable input hits the PromptTemplate cache
- Edge case: compress recursively when the result is still over the threshold
- Edge case: keep the original messages when the LLM compression call fails

**Verification**: `PYTHONPATH=src pytest tests/unit/test_context_compressor.py -v` runs with all tests passing

---

### U10. Observability

**Goal**: Implement structured logging, a metrics endpoint, and an enhanced health check.

**Requirements**: R10

**Dependencies**: U2

**Files**:
- Create: `src/agentkit/core/logging.py` (structured logging configuration)
- Create: `src/agentkit/server/routes/metrics.py` (`/api/v1/metrics` endpoint)
- Modify: `src/agentkit/server/routes/health.py` (enhanced health check: Redis/PG/LLM/AgentPool status)
- Modify: `src/agentkit/server/app.py` (register the metrics route, initialize structured logging)
- Test: `tests/unit/test_observability.py`

**Approach**:
1. Structured logging: use Python `structlog` with JSON output, including trace_id/agent_name/skill_name
2. Metrics endpoint: `GET /api/v1/metrics` returns task counts, success rate, average duration, token usage, and agent pool status
3. Enhanced health check: `GET /api/v1/health` returns Redis connectivity, PG connectivity, LLM provider availability, and AgentPool size
4. Metrics data is aggregated from TaskStore (Redis) and EvolutionStore (SQLite)
5. In the health check, LLM availability is verified with a lightweight ping (send an empty request to confirm the API key is valid)
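
One way the enhanced health check could combine its component probes is sketched below. Everything here is an assumption made for illustration: `aggregate_health`, the `up`/`down` component states, and the `healthy` overall status are hypothetical, and only the `degraded` status comes from the test scenarios in this unit. Each probe is taken to be an async callable returning a bool.

```python
import asyncio
from typing import Awaitable, Callable

# Hypothetical probe signature: returns True when the dependency is reachable.
HealthCheck = Callable[[], Awaitable[bool]]


async def aggregate_health(checks: dict[str, HealthCheck], timeout: float = 2.0) -> dict:
    """Run all probes concurrently; any failure or timeout marks the service degraded."""

    async def run_one(check: HealthCheck) -> str:
        try:
            ok = await asyncio.wait_for(check(), timeout=timeout)
        except Exception:
            return "down"  # errors and timeouts both count as unavailable
        return "up" if ok else "down"

    names = list(checks)
    states = await asyncio.gather(*(run_one(checks[name]) for name in names))
    components = dict(zip(names, states))
    status = "healthy" if all(state == "up" for state in states) else "degraded"
    return {"status": status, "components": components}


async def redis_ok() -> bool:
    return True


async def postgres_down() -> bool:
    raise ConnectionError("connection refused")


report = asyncio.run(aggregate_health({"redis": redis_ok, "postgres": postgres_down}))
```

Running the probes concurrently with a per-probe timeout keeps one slow dependency from stalling the whole endpoint, which matters because the container HEALTHCHECK has its own timeout.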

**Patterns to follow**: the existing health check interface in `src/agentkit/server/routes/health.py`

**Test scenarios**:
- Happy path: structured logs are emitted as JSON and include trace_id
- Happy path: /api/v1/metrics returns the correct task counts and success rate
- Happy path: /api/v1/health checks Redis/PG/LLM status
- Edge case: the health check returns a degraded status when Redis is unavailable
- Edge case: metrics returns zero values when there is no task data

**Verification**: `PYTHONPATH=src pytest tests/unit/test_observability.py -v` runs with all tests passing

---

## Phased Delivery

### Phase A: Infrastructure (U1, U2, U3)

Low-level capabilities with no external dependencies; they are the foundation for every later unit.

- U1: TaskStore persistence → state survives process restarts
- U2: Execution trace recorder → supplies data for reflection and observability
- U3: EvolutionStore persistence → evolution is traceable

### Phase B: Core capabilities (U4, U5, U6, U7)

Core upgrades that depend on Phase A and close the flywheel loop.

- U4: Memory wired into the agent loop → context carries over across sessions
- U5: Episodic vector search → semantic memory recall
- U6: LLM reflector → genuine reflection capability
- U7: Skill orchestration → multi-skill Pipeline

### Phase C: Enhancements (U8, U9, U10)

Improves user experience and production readiness.

- U8: SKILL.md format → compatibility with the open standard
- U9: Context compression and caching → token cost optimization
- U10: Observability → production operations

---
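
## Risks & Mitigations

| Risk | Impact | Mitigation |
|------|--------|------------|
| LLM reflector increases API call cost | Medium | Use an auxiliary (cheaper) model; auto mode falls back to rules |
| pgvector vector search latency | Medium | Hybrid ranking (semantic + time decay); limit the number of results returned |
| Memory injection increases prompt tokens | Medium | Token budget management; truncate when over budget |
| Skill orchestration adds complexity | Low | Reuse the existing PipelineEngine; introduce it incrementally |
| Concurrent writes to the SQLite EvolutionStore | Low | Single-writer, multi-reader mode; lock write operations |
| Backward compatibility breakage | High | All new parameters default to None and do not change existing behavior |

---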

## System-Wide Impact

- **API compatibility**: all new parameters default to None; existing API calls need no changes
- **Configuration changes**: `agentkit.yaml` gains `task_store`/`memory`/`evolution` blocks, all optional
- **Deployment changes**: Redis moves from optional to recommended (default TaskStore backend) and is already configured in docker-compose
- **Dependency changes**: adds `structlog` (observability); `pgvector` vector search requires the pgvector extension
- **Test changes**: 10 new test files, roughly 50+ test cases

---
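
## Open Questions

1. **Embedder choice**: OpenAI Embeddings vs. a local model (e.g. sentence-transformers)? Recommendation: default to OpenAI, with a local model as an option
2. **Auxiliary model for LLM reflection**: use the main model or a cheaper one? Recommendation: default to the main model, configurable via `auxiliary_model`
3. **Coexistence of SKILL.md and the existing YAML format**: is a migration tool needed? Recommendation: keep both formats side by side and let SkillLoader detect the format automatically

---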

## Sources & Research

- Hermes Agent official documentation: https://hermes-agent.nousresearch.com/docs/developer-guide/architecture
- GEPA paper: ICLR 2026 Oral "Reflective Prompt Evolution Can Outperform Reinforcement Learning"
- Hermes Agent memory system: https://hermes-agent.ai/blog/hermes-agent-memory-system
- Hermes Curator: https://hermes-agent.nousresearch.com/docs/user-guide/features/curator
- Existing AgentKit plan: `docs/plans/006-refactor-agentkit-v2-phase2-plan.md`
|
||||
|
|
@ -34,17 +34,75 @@ def serve(
|
|||
workers: int = typer.Option(1, "--workers", help="Number of workers"),
|
||||
reload: bool = typer.Option(False, "--reload", help="Enable auto-reload"),
|
||||
config: Optional[str] = typer.Option(None, "--config", help="Path to agentkit.yaml"),
|
||||
task_store_backend: Optional[str] = typer.Option(None, "--task-store-backend", help="Task store backend: memory or redis"),
|
||||
task_store_redis_url: Optional[str] = typer.Option(None, "--task-store-redis-url", help="Redis URL for task store (only used when backend=redis)"),
|
||||
):
|
||||
"""Start the AgentKit server"""
|
||||
import uvicorn
|
||||
|
||||
rprint(f"[green]Starting AgentKit Server on {host}:{port}[/green]")
|
||||
from agentkit.server.config import ServerConfig, find_config_path
|
||||
|
||||
# Resolve the config file path (explicit --config value or default lookup)
|
||||
config_path = find_config_path(config)
|
||||
|
||||
if config_path:
|
||||
rprint(f"[green]Loading config from {config_path}[/green]")
|
||||
server_config = ServerConfig.from_yaml(config_path)
|
||||
|
||||
# Load .env file for env var resolution
|
||||
from pathlib import Path
|
||||
dotenv = Path(config_path).parent / ".env"
|
||||
server_config.load_dotenv(str(dotenv))
|
||||
|
||||
# Re-load config after .env is loaded (env vars now available)
|
||||
server_config = ServerConfig.from_yaml(config_path)
|
||||
|
||||
# CLI args override config file for task_store
|
||||
if task_store_backend is not None:
|
||||
server_config.task_store["backend"] = task_store_backend
|
||||
if task_store_redis_url is not None:
|
||||
server_config.task_store["redis_url"] = task_store_redis_url
|
||||
|
||||
# CLI args override config file
|
||||
effective_host = host if host != "0.0.0.0" else server_config.host
|
||||
effective_port = port if port != 8001 else server_config.port
|
||||
effective_workers = workers if workers != 1 else server_config.workers
|
||||
|
||||
# Store config for app factory
|
||||
import os
|
||||
import json as _json
|
||||
os.environ["AGENTKIT_CONFIG_PATH"] = config_path
|
||||
# Pass task_store overrides via env var so create_app can read them
|
||||
if server_config.task_store:
|
||||
os.environ["AGENTKIT_TASK_STORE"] = _json.dumps(server_config.task_store)
|
||||
|
||||
rprint(f"[green]LLM providers: {list(server_config.llm_config.providers.keys())}[/green]")
|
||||
rprint(f"[green]Skill paths: {server_config.skill_paths}[/green]")
|
||||
ts_backend = server_config.task_store.get("backend", "memory")
|
||||
rprint(f"[green]Task store backend: {ts_backend}[/green]")
|
||||
else:
|
||||
rprint("[yellow]No agentkit.yaml found, using defaults[/yellow]")
|
||||
effective_host = host
|
||||
effective_port = port
|
||||
effective_workers = workers
|
||||
# Apply CLI task_store overrides even without config file
|
||||
import os
|
||||
import json as _json
|
||||
ts_override: dict = {}
|
||||
if task_store_backend is not None:
|
||||
ts_override["backend"] = task_store_backend
|
||||
if task_store_redis_url is not None:
|
||||
ts_override["redis_url"] = task_store_redis_url
|
||||
if ts_override:
|
||||
os.environ["AGENTKIT_TASK_STORE"] = _json.dumps(ts_override)
|
||||
|
||||
rprint(f"[green]Starting AgentKit Server on {effective_host}:{effective_port}[/green]")
|
||||
|
||||
uvicorn.run(
|
||||
"agentkit.server.app:create_app",
|
||||
host=host,
|
||||
port=port,
|
||||
workers=workers,
|
||||
host=effective_host,
|
||||
port=effective_port,
|
||||
workers=effective_workers,
|
||||
reload=reload,
|
||||
factory=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -82,6 +82,46 @@ def load_skill(
|
|||
raise typer.Exit(code=1)
|
||||
|
||||
|
||||
@skill_app.command("create")
|
||||
def skill_create(
|
||||
name: str = typer.Argument(..., help="Skill name"),
|
||||
output_dir: Optional[str] = typer.Option(".", "--output-dir", "-o", help="Output directory"),
|
||||
):
|
||||
"""Create a new SKILL.md template"""
|
||||
template = f'''---
|
||||
name: {name}
|
||||
description: "Description of {name}"
|
||||
agent_type: {name}
|
||||
execution_mode: react
|
||||
intent:
|
||||
keywords: ["{name}"]
|
||||
description: "Tasks related to {name}"
|
||||
quality_gate:
|
||||
required_fields: []
|
||||
min_word_count: 0
|
||||
---
|
||||
|
||||
# Trigger
|
||||
- When to use this skill
|
||||
|
||||
# Steps
|
||||
1. Step one
|
||||
2. Step two
|
||||
3. Step three
|
||||
|
||||
# Pitfalls
|
||||
- Common mistakes to avoid
|
||||
|
||||
# Verification
|
||||
- How to verify the output
|
||||
'''
|
||||
output_path = os.path.join(output_dir, f"{name}.md")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(template)
|
||||
rprint(f"[green]Created SKILL.md template:[/green] {output_path}")
|
||||
|
||||
|
||||
@skill_app.command("info")
|
||||
def skill_info(
|
||||
name: str = typer.Argument(help="Skill name"),
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ def submit(
|
|||
agent: Optional[str] = typer.Option(None, "--agent", "-a", help="Agent name"),
|
||||
mode: str = typer.Option("sync", "--mode", "-m", help="Execution mode: sync or async"),
|
||||
server_url: Optional[str] = typer.Option(None, "--server-url", help="AgentKit server URL"),
|
||||
config: Optional[str] = typer.Option(None, "--config", help="Path to agentkit.yaml (local mode)"),
|
||||
):
|
||||
"""Submit a task for execution"""
|
||||
# Parse input data
|
||||
|
|
@ -31,11 +32,16 @@ def submit(
|
|||
rprint("[red]Error: Provide --input or --input-file[/red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
if not server_url:
|
||||
rprint("[red]Error: --server-url is required (local mode not yet supported)[/red]")
|
||||
raise typer.Exit(code=1)
|
||||
if server_url:
|
||||
# Remote mode: use AgentKitClient
|
||||
_submit_remote(input_data, skill, agent, mode, server_url)
|
||||
else:
|
||||
# Local mode: execute directly
|
||||
_submit_local(input_data, skill, agent, mode, config)
|
||||
|
||||
# Use AgentKitClient for remote mode
|
||||
|
||||
def _submit_remote(input_data, skill, agent, mode, server_url):
|
||||
"""Submit task to a remote AgentKit server."""
|
||||
from agentkit.server.client import AgentKitClient
|
||||
client = AgentKitClient(base_url=server_url)
|
||||
|
||||
|
|
@ -59,6 +65,64 @@ def submit(
|
|||
rprint(json.dumps(result["output_data"], indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def _submit_local(input_data, skill, agent, mode, config_path):
|
||||
"""Submit task locally without a running server."""
|
||||
from agentkit.server.config import ServerConfig, find_config_path
|
||||
|
||||
# Load config
|
||||
resolved_path = find_config_path(config_path)
|
||||
if resolved_path:
|
||||
server_config = ServerConfig.from_yaml(resolved_path)
|
||||
server_config.load_dotenv()
|
||||
server_config = ServerConfig.from_yaml(resolved_path)
|
||||
else:
|
||||
server_config = None
|
||||
|
||||
# Build app components
|
||||
from agentkit.server.app import create_app
|
||||
app = create_app(server_config=server_config)
|
||||
|
||||
# Execute task through the app's agent pool
|
||||
async def _execute():
|
||||
agent_pool = app.state.agent_pool
|
||||
skill_registry = app.state.skill_registry
|
||||
|
||||
# Determine which skill/agent to use
|
||||
if skill:
|
||||
if not skill_registry.has_skill(skill):
|
||||
rprint(f"[red]Skill '{skill}' not found. Available: {[s.name for s in skill_registry.list_skills()]}[/red]")
|
||||
raise typer.Exit(code=1)
|
||||
skill_obj = skill_registry.get(skill)
|
||||
agent_name = skill_obj.name
|
||||
elif agent:
|
||||
agent_name = agent
|
||||
else:
|
||||
rprint("[red]Error: Provide --skill or --agent[/red]")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
# Create agent and execute
|
||||
agent_instance = agent_pool.get_or_create(agent_name)
|
||||
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
task = TaskMessage(
|
||||
task_id=str(uuid.uuid4()),
|
||||
agent_name=agent_name,
|
||||
task_type="cli_submit",
|
||||
priority=0,
|
||||
input_data=input_data,
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
result = await agent_instance.execute(task)
|
||||
return result
|
||||
|
||||
result = asyncio.run(_execute())
|
||||
rprint("[green]Task completed[/green]")
|
||||
if result.output_data:
|
||||
rprint(json.dumps(result.output_data, indent=2, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
@task_app.command("status")
|
||||
def status(
|
||||
task_id: str = typer.Argument(help="Task ID"),
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ class BaseAgent(ABC):
|
|||
# Pluggable capabilities (injected by subclasses or configuration)
|
||||
self._tools: list["Tool"] = []
|
||||
self._memory: "Memory | None" = None
|
||||
self._memory_retriever: Any | None = None
|
||||
|
||||
# External dependencies (set when start() is called)
|
||||
self._registry = None
|
||||
|
|
@ -175,6 +176,11 @@ class BaseAgent(ABC):
|
|||
self._memory = memory
|
||||
return self
|
||||
|
||||
def use_memory_retriever(self, retriever: Any) -> "BaseAgent":
|
||||
"""Set the memory retriever used for context injection"""
|
||||
self._memory_retriever = retriever
|
||||
return self
|
||||
|
||||
def set_registry(self, registry: Any) -> "BaseAgent":
|
||||
"""Inject the registry"""
|
||||
self._registry = registry
|
||||
|
|
|
|||
|
|
@ -0,0 +1,171 @@
|
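|||
"""ContextCompressor - context compression and prompt caching
|
||||
|
||||
Automatically compresses history in long sessions to keep tokens within budget;
|
||||
avoids re-rendering the same prompt within a session.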
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""Compress long conversation histories to stay within token budgets"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_gateway: Any = None,
|
||||
max_tokens: int = 4000,
|
||||
keep_recent: int = 3,
|
||||
model: str = "default",
|
||||
):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_tokens = max_tokens
|
||||
self._keep_recent = keep_recent
|
||||
self._model = model
|
||||
|
||||
def estimate_tokens(self, messages: list[dict]) -> int:
|
||||
"""Estimate total tokens in message list (rough: 4 chars = 1 token)"""
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
total += len(str(content)) // 4
|
||||
return total
|
||||
|
||||
async def compress(self, messages: list[dict]) -> list[dict]:
|
||||
"""Compress messages if they exceed token budget
|
||||
|
||||
Strategy:
|
||||
1. Keep system messages unchanged
|
||||
2. Keep the most recent N messages unchanged
|
||||
3. Compress older messages into a summary using LLM
|
||||
"""
|
||||
if self.estimate_tokens(messages) <= self._max_tokens:
|
||||
return messages
|
||||
|
||||
# Separate system messages, old messages, and recent messages
|
||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||
non_system = [m for m in messages if m.get("role") != "system"]
|
||||
|
||||
if len(non_system) <= self._keep_recent:
|
||||
return messages # Not enough messages to compress
|
||||
|
||||
old_msgs = non_system[:-self._keep_recent]
|
||||
recent_msgs = non_system[-self._keep_recent:]
|
||||
|
||||
# Compress old messages
|
||||
summary = await self._summarize(old_msgs)
|
||||
|
||||
# Build compressed message list
|
||||
compressed = list(system_msgs)
|
||||
if summary:
|
||||
compressed.append({
|
||||
"role": "system",
|
||||
"content": f"## Conversation Summary\n{summary}",
|
||||
})
|
||||
compressed.extend(recent_msgs)
|
||||
|
||||
# Recursive check: if still over budget, compress again
|
||||
if self.estimate_tokens(compressed) > self._max_tokens:
|
||||
if len(recent_msgs) > 1:
|
||||
# Try keeping fewer recent messages
|
||||
return await self._compress_aggressive(messages)
|
||||
# Last resort: truncate
|
||||
return self._truncate(compressed)
|
||||
|
||||
return compressed
|
||||
|
||||
async def _summarize(self, messages: list[dict]) -> str:
|
||||
"""Summarize a list of messages using LLM"""
|
||||
if not self._llm_gateway:
|
||||
# No LLM available, do simple truncation
|
||||
return self._simple_summary(messages)
|
||||
|
||||
# Build summary prompt
|
||||
conversation_text = "\n".join(
|
||||
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}"
|
||||
for m in messages
|
||||
)
|
||||
|
||||
prompt = (
|
||||
"Summarize the following conversation history concisely, "
|
||||
"preserving key facts, decisions, and context. "
|
||||
"Focus on information that would be needed for continuing the conversation.\n\n"
|
||||
f"{conversation_text}"
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model=self._model,
|
||||
agent_name="compressor",
|
||||
task_type="summarization",
|
||||
)
|
||||
return response.content
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM summarization failed, using simple summary: {e}")
|
||||
return self._simple_summary(messages)
|
||||
|
||||
def _simple_summary(self, messages: list[dict]) -> str:
|
||||
"""Simple truncation-based summary when LLM is unavailable"""
|
||||
parts = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
content = str(msg.get("content", ""))[:200]
|
||||
parts.append(f"[{role}]: {content}...")
|
||||
return "\n".join(parts)
|
||||
|
||||
async def _compress_aggressive(self, messages: list[dict]) -> list[dict]:
|
||||
"""More aggressive compression when standard compression isn't enough"""
|
||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||
non_system = [m for m in messages if m.get("role") != "system"]
|
||||
|
||||
# Keep only the last message
|
||||
if non_system:
|
||||
summary = await self._summarize(non_system[:-1])
|
||||
compressed = list(system_msgs)
|
||||
if summary:
|
||||
compressed.append({
|
||||
"role": "system",
|
||||
"content": f"## Conversation Summary\n{summary}",
|
||||
})
|
||||
compressed.append(non_system[-1])
|
||||
return compressed
|
||||
|
||||
return messages
|
||||
|
||||
def _truncate(self, messages: list[dict]) -> list[dict]:
|
||||
"""Last resort: truncate long messages"""
|
||||
result = []
|
||||
for msg in messages:
|
||||
content = str(msg.get("content", ""))
|
||||
if len(content) > self._max_tokens * 2:
|
||||
msg = {**msg, "content": content[:self._max_tokens * 2] + "...[truncated]"}
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
|
||||
"""Render PromptTemplate with caching - returns cached result for same variables"""
|
||||
cache_key = hashlib.md5(
|
||||
json.dumps(variables or {}, sort_keys=True, default=str).encode()  # default=str keeps non-JSON values hashable
|
||||
).hexdigest()
|
||||
|
||||
if not hasattr(template, '_render_cache'):
|
||||
template._render_cache = {}
|
||||
|
||||
if cache_key in template._render_cache:
|
||||
return template._render_cache[cache_key]
|
||||
|
||||
result = template.render(variables=variables)
|
||||
template._render_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
|
||||
def clear_cache(template) -> None:
|
||||
"""Clear the render cache on a PromptTemplate instance"""
|
||||
if hasattr(template, '_render_cache'):
|
||||
template._render_cache.clear()
|
||||
|
|
@ -199,6 +199,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
llm_client: Any = None,
|
||||
custom_handlers: dict[str, Callable[..., Coroutine]] | None = None,
|
||||
llm_gateway: Any = None, # NEW v2 param: LLMGateway
|
||||
mcp_servers: dict[str, str] | None = None, # NEW v2 param: MCP server URLs
|
||||
):
|
||||
# v2: If SkillConfig, extract skill info
|
||||
from agentkit.skills.base import SkillConfig, Skill
|
||||
|
|
@ -294,6 +295,52 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
# Bind Tools from configuration
|
||||
self._bind_tools()
|
||||
|
||||
# v2: Merge Skill-bound tools into Agent's tool list
|
||||
if self._skill_instance and self._skill_instance.tools:
|
||||
for tool in self._skill_instance.tools:
|
||||
if not any(t.name == tool.name for t in self._tools):
|
||||
self.use_tool(tool)
|
||||
logger.info(f"Merged skill tool '{tool.name}' into agent '{self.name}'")
|
||||
|
||||
# v2: Register MCP tools if mcp_servers provided
|
||||
self._mcp_clients: list[Any] = []
|
||||
self._mcp_servers: dict[str, str] = mcp_servers or {}
|
||||
self._mcp_tools_registered = False
|
||||
|
||||
# Memory integration: instantiate MemoryRetriever automatically from config.memory
|
||||
self._memory_retriever: Any | None = None
|
||||
if config.memory:
|
||||
try:
|
||||
from agentkit.memory.retriever import MemoryRetriever
|
||||
from agentkit.memory.working import WorkingMemory
|
||||
|
||||
working = None
|
||||
episodic = None
|
||||
|
||||
if config.memory.get("working", {}).get("enabled"):
|
||||
import redis.asyncio as aioredis
|
||||
redis_url = config.memory["working"].get("redis_url", "redis://localhost:6379")
|
||||
redis_client = aioredis.from_url(redis_url, decode_responses=True)
|
||||
working = WorkingMemory(redis=redis_client)
|
||||
|
||||
if config.memory.get("episodic", {}).get("enabled"):
|
||||
# EpisodicMemory needs session_factory and model - requires PostgreSQL setup
|
||||
# Will be initialized externally when DB is available
|
||||
pass
|
||||
|
||||
self._memory_retriever = MemoryRetriever(
|
||||
working_memory=working,
|
||||
episodic_memory=episodic,
|
||||
)
|
||||
|
||||
# Inject into BaseAgent
|
||||
self._memory_retriever_ref = self._memory_retriever
|
||||
|
||||
logger.info(f"ConfigDrivenAgent '{self.name}' initialized memory system")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize memory system: {e}")
|
||||
self._memory_retriever = None
|
||||
|
||||
@property
|
||||
def config(self) -> AgentConfig:
|
||||
return self._config
|
||||
|
|
@ -352,6 +399,43 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
f"ConfigDrivenAgent '{self.name}' failed to bind tool '{tool_name}': {e}"
|
||||
)
|
||||
|
||||
async def _register_mcp_tools(self) -> None:
|
||||
"""Lazily register tools from MCP servers as agent tools.
|
||||
|
||||
Called on first task execution to allow async MCP client operations.
|
||||
"""
|
||||
if self._mcp_tools_registered or not self._mcp_servers:
|
||||
return
|
||||
|
||||
self._mcp_tools_registered = True
|
||||
from agentkit.mcp.client import MCPClient
|
||||
|
||||
for server_name, base_url in self._mcp_servers.items():
|
||||
try:
|
||||
client = MCPClient(server_url=base_url)
|
||||
self._mcp_clients.append(client)
|
||||
|
||||
# List available tools from the MCP server
|
||||
tools = await client.list_tools()
|
||||
for tool_info in tools:
|
||||
tool_name = tool_info.get("name", "")
|
||||
tool_desc = tool_info.get("description", "")
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
# Create MCPTool and register it
|
||||
mcp_tool = client.as_tool(tool_name, tool_desc)
|
||||
self.use_tool(mcp_tool)
|
||||
logger.info(
|
||||
f"Agent '{self.name}' registered MCP tool '{tool_name}' "
|
||||
f"from server '{server_name}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Agent '{self.name}' failed to connect to MCP server "
|
||||
f"'{server_name}' at {base_url}: {e}"
|
||||
)
|
||||
|
||||
def get_capabilities(self) -> AgentCapability:
|
||||
return AgentCapability(
|
||||
agent_name=self.name,
|
||||
|
|
@ -365,20 +449,30 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
)
|
||||
|
||||
async def handle_task(self, task: TaskMessage) -> dict:
|
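||||
"""Execute the task according to task_mode
|
||||
"""Execute the task according to execution_mode and task_mode
|
||||
|
||||
v2: if a SkillConfig is present, execution_mode=react, and the ReAct engine is available,
|
||||
use the ReAct engine; otherwise fall back to the legacy mode.
|
||||
v2 execution_mode priority:
|
||||
- react: reason autonomously with the ReAct engine
|
||||
- direct: call the LLM directly (no ReAct loop)
|
||||
- custom: use a custom handler
|
||||
|
||||
If there is no SkillConfig, fall back to the legacy task_mode branches.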
|
||||
"""
|
||||
# v2: ReAct mode
|
||||
if (
|
||||
self._skill_config
|
||||
and self._skill_config.execution_mode == "react"
|
||||
and self._react_engine
|
||||
):
|
||||
return await self._handle_react(task)
|
||||
# Lazy-register MCP tools on first task execution
|
||||
await self._register_mcp_tools()
|
||||
|
||||
# Fall back to existing modes
|
||||
# v2: execution_mode routing (when SkillConfig is present)
|
||||
if self._skill_config:
|
||||
execution_mode = self._skill_config.execution_mode
|
||||
|
||||
if execution_mode == "react" and self._react_engine:
|
||||
return await self._handle_react(task)
|
||||
elif execution_mode == "direct":
|
||||
return await self._handle_direct(task)
|
||||
elif execution_mode == "custom":
|
||||
return await self._handle_custom(task)
|
||||
|
||||
# Fall back to existing task_mode modes
|
||||
if self._config.task_mode == "llm_generate":
|
||||
return await self._handle_llm_generate(task)
|
||||
elif self._config.task_mode == "tool_call":
|
||||
|
|
@ -394,33 +488,75 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
|
|||
|
||||
async def _handle_react(self, task: TaskMessage) -> dict:
|
||||
"""ReAct mode: use ReAct engine for autonomous reasoning"""
|
||||
# Build messages from prompt template
|
||||
# Build variables for prompt rendering
|
||||
variables = task.input_data.copy()
|
||||
variables["task_type"] = task.task_type
|
||||
|
||||
# Use PromptTemplate.render() to get full messages (system + user)
|
||||
if self._prompt_template:
|
||||
messages = self._prompt_template.render(variables=variables)
|
||||
rendered_messages = self._prompt_template.render(variables=variables)
|
||||
else:
|
||||
messages = [{"role": "user", "content": str(task.input_data)}]
|
||||
rendered_messages = [{"role": "user", "content": str(task.input_data)}]
|
||||
|
||||
# Get system prompt from skill config
|
||||
# Separate system_prompt from user messages
|
||||
# PromptTemplate.render() returns [system_msg, user_msg] or [user_msg]
|
||||
system_prompt = None
|
||||
if self._skill_config and self._skill_config.prompt:
|
||||
system_prompt = self._skill_config.prompt.get("identity", "")
|
||||
user_messages = []
|
||||
for msg in rendered_messages:
|
||||
if msg["role"] == "system":
|
||||
system_prompt = msg["content"]
|
||||
else:
|
||||
user_messages.append(msg)
|
||||
|
||||
# If no user messages, add a default one
|
||||
if not user_messages:
|
||||
user_messages.append({"role": "user", "content": str(task.input_data)})
|
||||
|
||||
# Execute ReAct loop
|
||||
result = await self._react_engine.execute(
|
||||
messages=messages,
|
||||
messages=user_messages,
|
||||
tools=self._tools if self._tools else None,
|
||||
model=self._config.llm.get("model", "default") if self._config.llm else "default",
|
||||
agent_name=self.name,
|
||||
task_type=task.task_type,
|
||||
system_prompt=system_prompt,
|
||||
memory_retriever=self._memory_retriever,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
|
||||
# Parse result
|
||||
return self._parse_llm_response(result.output)
|
||||
|
||||
async def _handle_direct(self, task: TaskMessage) -> dict:
|
||||
"""Direct mode: single LLM call without ReAct loop.
|
||||
|
||||
Renders the full prompt template and makes one LLM call via LLMGateway.
|
||||
Falls back to _handle_llm_generate if no LLMGateway is available.
|
||||
"""
|
||||
if not self._llm_gateway:
|
||||
return await self._handle_llm_generate(task)
|
||||
|
||||
# Build variables for prompt rendering
|
||||
variables = task.input_data.copy()
|
||||
variables["task_type"] = task.task_type
|
||||
|
||||
# Use PromptTemplate.render() to get full messages
|
||||
if self._prompt_template:
|
||||
rendered_messages = self._prompt_template.render(variables=variables)
|
||||
else:
|
||||
rendered_messages = [{"role": "user", "content": str(task.input_data)}]
|
||||
|
||||
# Make a single LLM call
|
||||
model = self._config.llm.get("model", "default") if self._config.llm else "default"
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=rendered_messages,
|
||||
model=model,
|
||||
agent_name=self.name,
|
||||
task_type=task.task_type,
|
||||
)
|
||||
|
||||
return self._parse_llm_response(response.content)
|
||||
|
||||
async def handle_task_with_feedback(self, task: TaskMessage, feedback: str) -> dict:
|
||||
"""Re-execute task with quality feedback"""
|
||||
enhanced_input = task.input_data.copy()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,66 @@
|
|||
"""Structured logging configuration for AgentKit.
|
||||
|
||||
Provides JSON-formatted structured logs using Python's built-in logging module.
|
||||
No external dependencies required.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
class StructuredFormatter(logging.Formatter):
|
||||
"""JSON structured log formatter.
|
||||
|
||||
Outputs each log record as a single-line JSON object with standard fields
|
||||
(timestamp, level, logger, message) plus optional structured fields
|
||||
(trace_id, agent_name, skill_name, task_id).
|
||||
"""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
log_entry: dict[str, Any] = {
|
||||
"timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
|
||||
# Add optional structured fields from LogRecord extras
|
||||
for key in ("trace_id", "agent_name", "skill_name", "task_id"):
|
||||
value = getattr(record, key, None)
|
||||
if value is not None:
|
||||
log_entry[key] = value
|
||||
|
||||
# Add exception info
|
||||
if record.exc_info and record.exc_info[1]:
|
||||
log_entry["exception"] = self.formatException(record.exc_info)
|
||||
|
||||
return json.dumps(log_entry, ensure_ascii=False)
|
||||
|
||||
|
||||
def setup_structured_logging(level: int = logging.INFO) -> None:
|
||||
"""Configure structured JSON logging for the agentkit namespace.
|
||||
|
||||
Replaces all existing handlers on the ``agentkit`` logger with a single
|
||||
:class:`StructuredFormatter`-backed stream handler.
|
||||
"""
|
||||
root_logger = logging.getLogger("agentkit")
|
||||
root_logger.setLevel(level)
|
||||
|
||||
# Remove existing handlers to avoid duplicate output
|
||||
root_logger.handlers.clear()
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(StructuredFormatter())
|
||||
root_logger.addHandler(handler)
|
||||
|
||||
|
||||
def get_logger(name: str, **extra: Any) -> logging.LoggerAdapter:
|
||||
"""Get a logger with extra structured fields.
|
||||
|
||||
The returned ``LoggerAdapter`` automatically injects *extra* keyword
|
||||
arguments into every log record so they appear in the JSON output.
|
||||
"""
|
||||
logger = logging.getLogger(f"agentkit.{name}")
|
||||
return logging.LoggerAdapter(logger, extra)
|
||||
|
|
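The module above relies on `logging.LoggerAdapter` to carry fields such as `task_id` onto each record, where the formatter picks them up with `getattr`. The following is a minimal, self-contained sketch of that mechanism using only the standard library. `MiniJsonFormatter` and `demo` are hypothetical names for illustration; the formatter is a condensed stand-in, not the `StructuredFormatter` from this commit.

```python
import io
import json
import logging


class MiniJsonFormatter(logging.Formatter):
    """Condensed stand-in for StructuredFormatter (illustration only)."""

    def format(self, record: logging.LogRecord) -> str:
        entry = {
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # LoggerAdapter extras arrive as plain attributes on the record.
        for key in ("trace_id", "agent_name", "skill_name", "task_id"):
            value = getattr(record, key, None)
            if value:
                entry[key] = value
        return json.dumps(entry, ensure_ascii=False)


def demo() -> dict:
    """Log one line through an adapter and return the parsed JSON entry."""
    buffer = io.StringIO()
    handler = logging.StreamHandler(buffer)
    handler.setFormatter(MiniJsonFormatter())

    base = logging.getLogger("agentkit.demo")
    base.handlers.clear()
    base.addHandler(handler)
    base.setLevel(logging.INFO)
    base.propagate = False  # keep the demo output out of the root logger

    adapter = logging.LoggerAdapter(base, {"task_id": "t-1", "agent_name": "writer"})
    adapter.info("task %s started", "t-1")
    return json.loads(buffer.getvalue())
```

Fields that are absent or falsy (here `trace_id` and `skill_name`) are simply left out of the JSON object, which matches the `if value:` check in the formatter.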
@ -119,9 +119,10 @@ class TaskResult:
|
|||
started_at: datetime
|
||||
completed_at: datetime
|
||||
metrics: dict | None = None
|
||||
trace: Any | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
d = {
|
||||
"task_id": self.task_id,
|
||||
"agent_name": self.agent_name,
|
||||
"status": self.status,
|
||||
|
|
@ -131,6 +132,9 @@ class TaskResult:
|
|||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
"metrics": self.metrics,
|
||||
}
|
||||
if self.trace is not None:
|
||||
d["trace"] = self.trace.to_dict() if hasattr(self.trace, "to_dict") else self.trace
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "TaskResult":
|
||||
|
|
@ -149,6 +153,7 @@ class TaskResult:
|
|||
started_at=started_at or datetime.now(timezone.utc),
|
||||
completed_at=completed_at or datetime.now(timezone.utc),
|
||||
metrics=data.get("metrics"),
|
||||
trace=data.get("trace"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,13 +7,19 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.core.compressor import ContextCompressor
|
||||
from agentkit.core.trace import TraceRecorder
|
||||
from agentkit.memory.retriever import MemoryRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -71,6 +77,10 @@ class ReActEngine:
|
|||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
system_prompt: str | None = None,
|
||||
trace_recorder: "TraceRecorder | None" = None,
|
||||
memory_retriever: "MemoryRetriever | None" = None,
|
||||
task_id: str | None = None,
|
||||
compressor: "ContextCompressor | None" = None,
|
||||
) -> ReActResult:
"""Run the ReAct loop
|
||||
|
||||
|
|
@ -82,21 +92,55 @@ class ReActEngine:
|
|||
tools = tools or []
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
|
||||
# Start trace recording
if trace_recorder is not None:
trace_recorder.start_trace(
task_id=task_id or "",
|
||||
agent_name=agent_name,
|
||||
skill_name=task_type or None,
|
||||
)
|
||||
|
||||
# Memory retrieval: before execution, retrieve relevant context and inject it into system_prompt
|
||||
if memory_retriever:
|
||||
try:
|
||||
query = str(messages[-1].get("content", "")) if messages else ""
|
||||
memory_context = await memory_retriever.get_context_string(
|
||||
query=query,
|
||||
top_k=5,
|
||||
token_budget=2000,
|
||||
)
|
||||
if memory_context:
|
||||
if system_prompt:
|
||||
system_prompt += f"\n\n## Relevant Past Experience\n{memory_context}"
|
||||
else:
|
||||
system_prompt = f"## Relevant Past Experience\n{memory_context}"
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
|
||||
|
||||
# Build the initial messages
|
||||
conversation: list[dict[str, Any]] = []
|
||||
if system_prompt:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
conversation.extend(messages)
|
||||
|
||||
# Context compression: compress an overlong conversation history
|
||||
if compressor:
|
||||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed, continuing with original messages: {e}")
|
||||
|
||||
trajectory: list[ReActStep] = []
|
||||
total_tokens = 0
|
||||
step = 0
|
||||
output = ""
|
||||
trace_outcome = "success"
|
||||
|
||||
while step < self._max_steps:
|
||||
step += 1
|
||||
|
||||
# Think: call the LLM
|
||||
llm_start = time.monotonic()
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=conversation,
|
||||
model=model,
|
||||
|
|
@ -104,12 +148,22 @@ class ReActEngine:
|
|||
task_type=task_type,
|
||||
tools=tool_schemas,
|
||||
)
|
||||
llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
|
||||
|
||||
step_tokens = response.usage.total_tokens
|
||||
total_tokens += step_tokens
|
||||
|
||||
# Check whether the response carries Function Calling tool_calls
|
||||
if response.has_tool_calls:
|
||||
# Record the LLM call step
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="llm_call",
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
# Act: execute the tool calls
# First append the assistant message (including tool_calls) to the conversation history
|
||||
assistant_msg: dict[str, Any] = {
|
||||
|
|
@ -131,7 +185,10 @@ class ReActEngine:
|
|||
|
||||
# Execute each tool call
|
||||
for tc in response.tool_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
|
|
@ -142,6 +199,22 @@ class ReActEngine:
|
|||
)
|
||||
trajectory.append(react_step)
|
||||
|
||||
# Record the tool call step
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
# Observe: append the tool result to the conversation history
|
||||
tool_msg = self._build_tool_result_message(tc.id, tool_result)
|
||||
conversation.append(tool_msg)
|
||||
|
|
@ -150,11 +223,23 @@ class ReActEngine:
|
|||
# Check text parsing mode
|
||||
parsed_calls = self._parse_text_tool_calls(response.content or "")
|
||||
if parsed_calls and tools:
|
||||
# Record the LLM call step
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="llm_call",
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
# Execute tools in text parsing mode
|
||||
conversation.append({"role": "assistant", "content": response.content})
|
||||
|
||||
for pc in parsed_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
|
|
@ -165,6 +250,22 @@ class ReActEngine:
|
|||
)
|
||||
trajectory.append(react_step)
|
||||
|
||||
# Record the tool call step
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=pc["name"],
|
||||
input_data=pc["arguments"],
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
# Append the tool result to the conversation history
|
||||
tool_msg = self._build_tool_result_message(pc.get("id", f"text_tc_{step}"), tool_result)
|
||||
conversation.append(tool_msg)
|
||||
|
|
@ -178,10 +279,21 @@ class ReActEngine:
|
|||
)
|
||||
trajectory.append(react_step)
|
||||
output = response.content or ""
|
||||
|
||||
# Record the final answer step
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="final_answer",
|
||||
output_data={"content": response.content},
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
break
|
||||
|
||||
# When max_steps is reached, return the best output so far
|
||||
if step >= self._max_steps and not output:
|
||||
trace_outcome = "partial"
|
||||
# Use the content of the last step as the output
|
||||
if trajectory and trajectory[-1].content:
|
||||
output = trajectory[-1].content
|
||||
|
|
@ -190,6 +302,22 @@ class ReActEngine:
|
|||
else:
|
||||
output = response.content or ""
|
||||
|
||||
# End trace recording
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome=trace_outcome)
|
||||
|
||||
# Memory storage: after execution, write a summary of the run to EpisodicMemory
|
||||
if memory_retriever and hasattr(memory_retriever, "_episodic") and memory_retriever._episodic:
|
||||
try:
|
||||
summary = output[:500] if output else ""
|
||||
await memory_retriever._episodic.store(
|
||||
key=f"task:{task_id or 'unknown'}",
|
||||
value={"output_summary": summary, "agent_name": agent_name},
|
||||
metadata={"task_type": task_type, "outcome": trace_outcome},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store task result in episodic memory: {e}")
|
||||
|
||||
return ReActResult(
|
||||
output=output,
|
||||
trajectory=trajectory,
|
||||
|
|
@ -205,6 +333,10 @@ class ReActEngine:
|
|||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
system_prompt: str | None = None,
|
||||
trace_recorder: "TraceRecorder | None" = None,
|
||||
memory_retriever: "MemoryRetriever | None" = None,
|
||||
task_id: str | None = None,
|
||||
compressor: "ContextCompressor | None" = None,
|
||||
):
|
||||
"""Execute ReAct loop, yielding ReActEvent objects.
|
||||
|
||||
|
|
@ -214,15 +346,48 @@ class ReActEngine:
|
|||
tools = tools or []
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
|
||||
# Start trace recording
if trace_recorder is not None:
trace_recorder.start_trace(
task_id=task_id or "",
|
||||
agent_name=agent_name,
|
||||
skill_name=task_type or None,
|
||||
)
|
||||
|
||||
# Memory retrieval: before execution, retrieve relevant context and inject it into system_prompt
|
||||
if memory_retriever:
|
||||
try:
|
||||
query = str(messages[-1].get("content", "")) if messages else ""
|
||||
memory_context = await memory_retriever.get_context_string(
|
||||
query=query,
|
||||
top_k=5,
|
||||
token_budget=2000,
|
||||
)
|
||||
if memory_context:
|
||||
if system_prompt:
|
||||
system_prompt += f"\n\n## Relevant Past Experience\n{memory_context}"
|
||||
else:
|
||||
system_prompt = f"## Relevant Past Experience\n{memory_context}"
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
|
||||
|
||||
conversation: list[dict[str, Any]] = []
|
||||
if system_prompt:
|
||||
conversation.append({"role": "system", "content": system_prompt})
|
||||
conversation.extend(messages)
|
||||
|
||||
# Context compression: compress an overlong conversation history
|
||||
if compressor:
|
||||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed, continuing with original messages: {e}")
|
||||
|
||||
trajectory: list[ReActStep] = []
|
||||
total_tokens = 0
|
||||
step = 0
|
||||
output = ""
|
||||
trace_outcome = "success"
|
||||
|
||||
while step < self._max_steps:
|
||||
step += 1
|
||||
|
|
@ -235,6 +400,7 @@ class ReActEngine:
|
|||
)
|
||||
|
||||
# Think: call LLM
|
||||
llm_start = time.monotonic()
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=conversation,
|
||||
model=model,
|
||||
|
|
@ -242,11 +408,21 @@ class ReActEngine:
|
|||
task_type=task_type,
|
||||
tools=tool_schemas,
|
||||
)
|
||||
llm_duration_ms = int((time.monotonic() - llm_start) * 1000)
|
||||
|
||||
step_tokens = response.usage.total_tokens
|
||||
total_tokens += step_tokens
|
||||
|
||||
if response.has_tool_calls:
|
||||
# Record the LLM call step
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="llm_call",
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
# Record assistant message
|
||||
assistant_msg: dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
|
|
@ -273,7 +449,10 @@ class ReActEngine:
|
|||
data={"tool_name": tc.name, "arguments": tc.arguments},
|
||||
)
|
||||
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
|
|
@ -284,6 +463,22 @@ class ReActEngine:
|
|||
)
|
||||
trajectory.append(react_step)
|
||||
|
||||
# Record the tool call step
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
# Yield tool_result event
|
||||
yield ReActEvent(
|
||||
event_type="tool_result",
|
||||
|
|
@ -298,6 +493,15 @@ class ReActEngine:
|
|||
# Check text parsing mode
|
||||
parsed_calls = self._parse_text_tool_calls(response.content or "")
|
||||
if parsed_calls and tools:
|
||||
# Record the LLM call step
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="llm_call",
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
conversation.append({"role": "assistant", "content": response.content})
|
||||
|
||||
for pc in parsed_calls:
|
||||
|
|
@ -306,7 +510,9 @@ class ReActEngine:
|
|||
step=step,
|
||||
data={"tool_name": pc["name"], "arguments": pc["arguments"]},
|
||||
)
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
trajectory.append(ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
|
|
@ -315,6 +521,21 @@ class ReActEngine:
|
|||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
))
|
||||
# Record the tool call step
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=pc["name"],
|
||||
input_data=pc["arguments"],
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
yield ReActEvent(
|
||||
event_type="tool_result",
|
||||
step=step,
|
||||
|
|
@ -334,6 +555,17 @@ class ReActEngine:
|
|||
)
|
||||
trajectory.append(react_step)
|
||||
output = response.content or ""
|
||||
|
||||
# Record the final answer step
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="final_answer",
|
||||
output_data={"content": response.content},
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="final_answer",
|
||||
step=step,
|
||||
|
|
@ -346,12 +578,18 @@ class ReActEngine:
|
|||
break
|
||||
|
||||
if step >= self._max_steps and not output:
|
||||
trace_outcome = "partial"
|
||||
if trajectory and trajectory[-1].content:
|
||||
output = trajectory[-1].content
|
||||
elif trajectory and trajectory[-1].result is not None:
|
||||
output = str(trajectory[-1].result)
|
||||
else:
|
||||
output = response.content or ""
|
||||
|
||||
# End trace recording
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome=trace_outcome)
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="final_answer",
|
||||
step=step,
|
||||
|
|
@ -362,6 +600,22 @@ class ReActEngine:
|
|||
"max_steps_reached": True,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# End trace recording on normal completion
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome=trace_outcome)
|
||||
|
||||
# Memory storage: after execution, write a summary of the run to EpisodicMemory
|
||||
if memory_retriever and hasattr(memory_retriever, "_episodic") and memory_retriever._episodic:
|
||||
try:
|
||||
summary = output[:500] if output else ""
|
||||
await memory_retriever._episodic.store(
|
||||
key=f"task:{task_id or 'unknown'}",
|
||||
value={"output_summary": summary, "agent_name": agent_name},
|
||||
metadata={"task_type": task_type, "outcome": trace_outcome},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store task result in episodic memory: {e}")
|
||||
|
||||
def _build_tool_schemas(self, tools: list[Tool]) -> list[dict]:
"""Convert Tool objects into the OpenAI Function Calling schema format"""
|
||||
|
|
|
|||
|
|
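Both `execute` and its streaming counterpart in the ReActEngine diff repeat the same step: if the retriever returns a non-empty context string, it is appended to the system prompt under a "Relevant Past Experience" heading, or becomes the system prompt when none was given. One way to picture that step in isolation is the small helper below. This is a sketch only: `inject_memory_context` is a hypothetical name and is not part of the commit.

```python
from __future__ import annotations

MEMORY_HEADING = "## Relevant Past Experience"


def inject_memory_context(system_prompt: str | None, memory_context: str) -> str | None:
    """Return the system prompt with retrieved memory appended, if any.

    Mirrors the branch structure in the diff: an empty context leaves the
    prompt untouched, and a missing prompt is replaced by the memory section.
    """
    if not memory_context:
        return system_prompt
    section = f"{MEMORY_HEADING}\n{memory_context}"
    if system_prompt:
        return f"{system_prompt}\n\n{section}"
    return section
```

Keeping the logic in a pure function like this would make the prompt shape easy to unit test without an LLM gateway or a retriever.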
@ -0,0 +1,177 @@
"""Execution trace recorder

Records the complete execution trace while ReActEngine runs (the action,
input/output, duration, and token usage of each step), providing data for
reflection and observability.
"""

import time
import uuid
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TraceStep:
    """A single step of an execution trace"""

    step: int
    action: str  # "tool_call" | "llm_call" | "final_answer"
    tool_name: str | None = None
    input_data: dict | None = None
    output_data: Any = None
    duration_ms: int = 0
    tokens_used: int = 0
    error: str | None = None

    def to_dict(self) -> dict:
        d = {
            "step": self.step,
            "action": self.action,
            "duration_ms": self.duration_ms,
            "tokens_used": self.tokens_used,
        }
        if self.tool_name is not None:
            d["tool_name"] = self.tool_name
        if self.input_data is not None:
            d["input_data"] = self.input_data
        if self.output_data is not None:
            d["output_data"] = self.output_data
        if self.error is not None:
            d["error"] = self.error
        return d


@dataclass
class ExecutionTrace:
    """A complete execution trace"""

    task_id: str
    agent_name: str
    skill_name: str | None = None
    steps: list[TraceStep] = field(default_factory=list)
    total_duration_ms: int = 0
    total_tokens: int = 0
    outcome: str = "success"  # "success" | "failure" | "partial"
    quality_score: float = 1.0  # 0.0 - 1.0

    def to_dict(self) -> dict:
        return {
            "task_id": self.task_id,
            "agent_name": self.agent_name,
            "skill_name": self.skill_name,
            "steps": [s.to_dict() for s in self.steps],
            "total_duration_ms": self.total_duration_ms,
            "total_tokens": self.total_tokens,
            "outcome": self.outcome,
            "quality_score": self.quality_score,
        }


class TraceRecorder:
    """Execution trace recorder

    Usage:
        recorder = TraceRecorder()
        recorder.start_trace(task_id="t1", agent_name="agent1")
        recorder.record_step(step=1, action="llm_call", ...)
        recorder.record_step(step=2, action="tool_call", tool_name="search", ...)
        trace = recorder.end_trace(outcome="success")
    """

    def __init__(
        self,
        task_id: str = "",
        agent_name: str = "",
        skill_name: str | None = None,
    ):
        self._trace: ExecutionTrace | None = None
        self._step_start_time: float = 0
        self._trace_start_time: float = 0
        # If a task_id was provided at construction time, start the trace automatically
        if task_id:
            self.start_trace(task_id=task_id, agent_name=agent_name, skill_name=skill_name)

    def start_trace(
        self,
        task_id: str = "",
        agent_name: str = "",
        skill_name: str | None = None,
    ) -> None:
        """Start recording an execution trace"""
        # An empty task_id falls back to a freshly generated UUID
        tid = task_id or str(uuid.uuid4())
        self._trace = ExecutionTrace(
            task_id=tid,
            agent_name=agent_name,
            skill_name=skill_name,
        )
        self._trace_start_time = time.monotonic()

    def record_step(
        self,
        step: int,
        action: str,
        tool_name: str | None = None,
        input_data: dict | None = None,
        output_data: Any = None,
        duration_ms: int = 0,
        tokens_used: int = 0,
        error: str | None = None,
    ) -> None:
        """Record one execution step (ignored if no trace has been started)"""
        if self._trace is None:
            return

        trace_step = TraceStep(
            step=step,
            action=action,
            tool_name=tool_name,
            input_data=input_data,
            output_data=output_data,
            duration_ms=duration_ms,
            tokens_used=tokens_used,
            error=error,
        )
        self._trace.steps.append(trace_step)

    def end_trace(
        self,
        outcome: str = "success",
        quality_score: float = 1.0,
    ) -> ExecutionTrace:
        """Finish recording and return the ExecutionTrace"""
        if self._trace is None:
            # end_trace was called without start_trace: return an empty default trace
            self._trace = ExecutionTrace(
                task_id="unknown",
                agent_name="",
            )

        self._trace.outcome = outcome
        self._trace.quality_score = quality_score

        # Compute the total duration
        if self._trace_start_time > 0:
            self._trace.total_duration_ms = int(
                (time.monotonic() - self._trace_start_time) * 1000
            )

        # Compute the total token count
        self._trace.total_tokens = sum(s.tokens_used for s in self._trace.steps)

        return self._trace

    def get_trace(self) -> ExecutionTrace | None:
        """Return the current execution trace (None until start_trace is called)"""
        # The trace may still be in progress: totals and outcome are only final after end_trace.
        # The trace object itself is returned so the caller can decide how to use it.
        return self._trace

    def start_step_timer(self) -> None:
        """Start timing the current step"""
        self._step_start_time = time.monotonic()

    def elapsed_ms(self) -> int:
        """Return the milliseconds elapsed since start_step_timer"""
        if self._step_start_time == 0:
            return 0
        return int((time.monotonic() - self._step_start_time) * 1000)
|
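`ExecutionTrace.to_dict()` yields a plain dictionary whose `steps` entries always carry `step`, `action`, `duration_ms`, and `tokens_used`, and carry `tool_name` or `error` only when set. A consumer such as a reflector or a metrics endpoint could roll that structure up as sketched below. `summarize_trace` and the sample data are hypothetical and exist only to illustrate the dictionary shape; they are not part of this commit.

```python
from __future__ import annotations

from typing import Any


def summarize_trace(trace: dict[str, Any]) -> dict[str, Any]:
    """Aggregate a trace dict of the shape produced by ExecutionTrace.to_dict()."""
    steps = trace.get("steps", [])
    tool_steps = [s for s in steps if s.get("action") == "tool_call"]
    return {
        "total_tokens": sum(s.get("tokens_used", 0) for s in steps),
        "tool_calls": len(tool_steps),
        "tool_time_ms": sum(s.get("duration_ms", 0) for s in tool_steps),
        # "error" is only present on a step when the tool reported one.
        "failed_tools": [s.get("tool_name") for s in tool_steps if "error" in s],
    }


# Made-up sample trace for illustration.
SAMPLE_TRACE = {
    "task_id": "t1",
    "agent_name": "agent1",
    "steps": [
        {"step": 1, "action": "llm_call", "duration_ms": 800, "tokens_used": 120},
        {"step": 1, "action": "tool_call", "tool_name": "search",
         "duration_ms": 40, "tokens_used": 0, "error": "timeout"},
        {"step": 2, "action": "final_answer", "duration_ms": 500, "tokens_used": 60},
    ],
}
```

Because the engine records tool steps with `tokens_used=0`, summing over all steps counts each LLM call once, which is the same total `end_trace` computes.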
@ -4,7 +4,12 @@ from agentkit.evolution.reflector import Reflector
|
|||
from agentkit.evolution.prompt_optimizer import PromptOptimizer, Signature, Module
|
||||
from agentkit.evolution.strategy_tuner import StrategyTuner
|
||||
from agentkit.evolution.ab_tester import ABTester
|
||||
from agentkit.evolution.evolution_store import EvolutionStore
|
||||
from agentkit.evolution.evolution_store import (
|
||||
EvolutionStore,
|
||||
InMemoryEvolutionStore,
|
||||
PersistentEvolutionStore,
|
||||
create_evolution_store,
|
||||
)
|
||||
from agentkit.evolution.lifecycle import EvolutionMixin, EvolutionLogEntry
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -15,6 +20,9 @@ __all__ = [
|
|||
"StrategyTuner",
|
||||
"ABTester",
|
||||
"EvolutionStore",
|
||||
"PersistentEvolutionStore",
|
||||
"InMemoryEvolutionStore",
|
||||
"create_evolution_store",
|
||||
"EvolutionMixin",
|
||||
"EvolutionLogEntry",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,29 @
"""EvolutionStore - evolution log storage"""
"""EvolutionStore - evolution log storage

Provides three backend implementations:
- EvolutionStore: based on an externally injected async SQLAlchemy session (the original implementation)
- PersistentEvolutionStore: persistent storage based on SQLite
- InMemoryEvolutionStore: lightweight storage based on in-memory dicts (for testing)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import os
|
||||
import uuid as _uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from agentkit.core.protocol import EvolutionEvent
|
||||
from agentkit.evolution.models import (
|
||||
ABTestResultModel,
|
||||
Base,
|
||||
EvolutionEventModel,
|
||||
SkillVersionModel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -111,3 +130,320 @@ class EvolutionStore:
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to list evolution events: {e}")
|
||||
return []
|
||||
|
||||
|
||||
class PersistentEvolutionStore:
"""SQLite-backed persistent evolution store

Uses synchronous SQLAlchemy + SQLite for persistence and runs each call
through run_in_executor to stay compatible with the async interface.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = "~/.agentkit/evolution.db"):
|
||||
self._db_path = os.path.expanduser(db_path)
|
||||
# Fall back to the current directory when db_path has no directory part
os.makedirs(os.path.dirname(self._db_path) or ".", exist_ok=True)
|
||||
self._engine = create_engine(f"sqlite:///{self._db_path}", echo=False)
|
||||
Base.metadata.create_all(self._engine)
|
||||
self._Session = sessionmaker(bind=self._engine)
|
||||
|
||||
# ── Internal helpers ──────────────────────────────────
|
||||
|
||||
def _run_sync(self, func: Any) -> Any:
|
||||
# Always called from a coroutine, so the running loop is the right one to use
loop = asyncio.get_running_loop()
|
||||
return loop.run_in_executor(None, func)
|
||||
|
||||
# ── Evolution events ──────────────────────────────────
|
||||
|
||||
def _record_sync(self, event: EvolutionEvent) -> str:
|
||||
with self._Session() as session:
|
||||
event_id = str(_uuid.uuid4())
|
||||
entry = EvolutionEventModel(
|
||||
id=event_id,
|
||||
agent_name=event.agent_name,
|
||||
change_type=event.change_type,
|
||||
before=json.dumps(event.before, ensure_ascii=False),
|
||||
after=json.dumps(event.after, ensure_ascii=False),
|
||||
metrics=json.dumps(event.metrics, ensure_ascii=False) if event.metrics else None,
|
||||
status="active",
|
||||
)
|
||||
session.add(entry)
|
||||
session.commit()
|
||||
event.event_id = event_id
|
||||
logger.info(f"Evolution event recorded: {event_id} for agent '{event.agent_name}'")
|
||||
return event_id
|
||||
|
||||
async def record(self, event: EvolutionEvent) -> str:
"""Record an evolution event"""
|
||||
return await self._run_sync(lambda: self._record_sync(event))
|
||||
|
||||
def _rollback_sync(self, event_id: str) -> bool:
|
||||
with self._Session() as session:
|
||||
stmt = select(EvolutionEventModel).where(EvolutionEventModel.id == event_id)
|
||||
entry = session.execute(stmt).scalar_one_or_none()
|
||||
if not entry:
|
||||
logger.error(f"Evolution event {event_id} not found")
|
||||
return False
|
||||
entry.status = "rolled_back"
|
||||
session.commit()
|
||||
logger.info(f"Evolution event {event_id} rolled back")
|
||||
return True
|
||||
|
||||
async def rollback(self, event_id: str) -> bool:
"""Roll back an evolution event"""
|
||||
return await self._run_sync(lambda: self._rollback_sync(event_id))
|
||||
|
||||
def _list_events_sync(
|
||||
self,
|
||||
agent_name: str | None = None,
|
||||
change_type: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> list[dict]:
|
||||
with self._Session() as session:
|
||||
stmt = select(EvolutionEventModel)
|
||||
if agent_name:
|
||||
stmt = stmt.where(EvolutionEventModel.agent_name == agent_name)
|
||||
if change_type:
|
||||
stmt = stmt.where(EvolutionEventModel.change_type == change_type)
|
||||
if status:
|
||||
stmt = stmt.where(EvolutionEventModel.status == status)
|
||||
stmt = stmt.order_by(EvolutionEventModel.created_at.desc())
|
||||
entries = session.execute(stmt).scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": e.id,
|
||||
"agent_name": e.agent_name,
|
||||
"event_type": e.event_type,
|
||||
"change_type": e.change_type,
|
||||
"before": json.loads(e.before) if e.before else None,
|
||||
"after": json.loads(e.after) if e.after else None,
|
||||
"metrics": json.loads(e.metrics) if e.metrics else None,
|
||||
"status": e.status,
|
||||
"created_at": e.created_at.isoformat() if e.created_at else None,
|
||||
}
|
||||
for e in entries
|
||||
]
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
agent_name: str | None = None,
|
||||
change_type: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> list[dict]:
"""List evolution events"""
|
||||
return await self._run_sync(lambda: self._list_events_sync(agent_name, change_type, status))
|
||||
|
||||
# ── Skill versions ────────────────────────────────────
|
||||
|
||||
def _record_skill_version_sync(
|
||||
self, skill_name: str, version: str, content: str, parent_version: str | None = None
|
||||
) -> str:
|
||||
with self._Session() as session:
|
||||
vid = str(_uuid.uuid4())
|
||||
entry = SkillVersionModel(
|
||||
id=vid,
|
||||
skill_name=skill_name,
|
||||
version=version,
|
||||
content=content,
|
||||
parent_version=parent_version,
|
||||
)
|
||||
session.add(entry)
|
||||
session.commit()
|
||||
return vid
|
||||
|
||||
async def record_skill_version(
|
||||
self, skill_name: str, version: str, content: str, parent_version: str | None = None
|
||||
) -> str:
"""Record a skill version"""
|
||||
return await self._run_sync(
|
||||
lambda: self._record_skill_version_sync(skill_name, version, content, parent_version)
|
||||
)
|
||||
|
||||
def _list_skill_versions_sync(self, skill_name: str) -> list[dict]:
|
||||
with self._Session() as session:
|
||||
stmt = (
|
||||
select(SkillVersionModel)
|
||||
.where(SkillVersionModel.skill_name == skill_name)
|
||||
.order_by(SkillVersionModel.created_at.desc())
|
||||
)
|
||||
entries = session.execute(stmt).scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": e.id,
|
||||
"skill_name": e.skill_name,
|
||||
"version": e.version,
|
||||
"content": e.content,
|
||||
"parent_version": e.parent_version,
|
||||
"created_at": e.created_at.isoformat() if e.created_at else None,
|
||||
}
|
||||
for e in entries
|
||||
]
|
||||
|
||||
async def list_skill_versions(self, skill_name: str) -> list[dict]:
"""List the version history of a skill"""
|
||||
return await self._run_sync(lambda: self._list_skill_versions_sync(skill_name))
|
||||
|
||||
# ── A/B test results ──────────────────────────────────
|
||||
|
||||
def _record_ab_test_result_sync(
|
||||
self, test_id: str, variant: str, score: float, sample_count: int = 0
|
||||
) -> str:
|
||||
with self._Session() as session:
|
||||
rid = str(_uuid.uuid4())
|
||||
entry = ABTestResultModel(
|
||||
id=rid,
|
||||
test_id=test_id,
|
||||
variant=variant,
|
||||
score=score,
|
||||
sample_count=sample_count,
|
||||
)
|
||||
session.add(entry)
|
||||
session.commit()
|
||||
return rid
|
||||
|
||||
async def record_ab_test_result(
|
||||
self, test_id: str, variant: str, score: float, sample_count: int = 0
|
||||
) -> str:
"""Record an A/B test result"""
|
||||
return await self._run_sync(
|
||||
lambda: self._record_ab_test_result_sync(test_id, variant, score, sample_count)
|
||||
)
|
||||
|
||||
def _get_ab_test_results_sync(self, test_id: str) -> list[dict]:
|
||||
with self._Session() as session:
|
||||
stmt = select(ABTestResultModel).where(ABTestResultModel.test_id == test_id)
|
||||
entries = session.execute(stmt).scalars().all()
|
||||
return [
|
||||
{
|
||||
"id": e.id,
|
||||
"test_id": e.test_id,
|
||||
"variant": e.variant,
|
||||
"score": e.score,
|
||||
"sample_count": e.sample_count,
|
||||
"created_at": e.created_at.isoformat() if e.created_at else None,
|
||||
}
|
||||
for e in entries
|
||||
]
|
||||
|
||||
async def get_ab_test_results(self, test_id: str) -> list[dict]:
"""Get A/B test results"""
|
||||
return await self._run_sync(lambda: self._get_ab_test_results_sync(test_id))
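`PersistentEvolutionStore` keeps its database access synchronous and exposes coroutines by handing each call to the event loop's default thread pool. The pattern can be sketched on its own with the standard library, as below. `run_blocking`, `insert_and_count`, and `main` are hypothetical names, and the standard `sqlite3` module stands in for the SQLAlchemy session used in the commit.

```python
from __future__ import annotations

import asyncio
import functools
import sqlite3
from typing import Any, Callable


async def run_blocking(func: Callable[..., Any], *args: Any) -> Any:
    """Run a blocking callable in the default executor and await its result."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args))


def insert_and_count(names: list[str]) -> int:
    """Blocking work: open a connection, write rows, and count them.

    The connection is created and closed inside the worker thread, so it is
    never shared across threads.
    """
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE events (name TEXT)")
        conn.executemany("INSERT INTO events (name) VALUES (?)", [(n,) for n in names])
        return conn.execute("SELECT COUNT(*) FROM events").fetchone()[0]
    finally:
        conn.close()


async def main() -> int:
    return await run_blocking(insert_and_count, ["prompt", "strategy", "skill"])
```

Opening one session (or connection) per call, as the `_..._sync` methods do with `with self._Session() as session`, avoids sharing database handles between executor threads.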
|
||||
|
||||
|
||||
class InMemoryEvolutionStore:
"""Evolution store backed by in-memory dicts (for testing and lightweight use)"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._events: dict[str, dict] = {}
|
||||
self._skill_versions: dict[str, list[dict]] = {}
|
||||
self._ab_results: dict[str, list[dict]] = {}
|
||||
|
||||
async def record(self, event: EvolutionEvent) -> str:
"""Record an evolution event"""
|
||||
event_id = str(_uuid.uuid4())
|
||||
event.event_id = event_id
|
||||
self._events[event_id] = {
|
||||
"id": event_id,
|
||||
"agent_name": event.agent_name,
|
||||
"change_type": event.change_type,
|
||||
"before": event.before,
|
||||
"after": event.after,
|
||||
"metrics": event.metrics,
|
||||
"status": "active",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
logger.info(f"Evolution event recorded: {event_id} for agent '{event.agent_name}'")
|
||||
return event_id
|
||||
|
||||
async def rollback(self, event_id: str) -> bool:
"""Roll back an evolution event"""
|
||||
if event_id not in self._events:
|
||||
logger.error(f"Evolution event {event_id} not found")
|
||||
return False
|
||||
self._events[event_id]["status"] = "rolled_back"
|
||||
logger.info(f"Evolution event {event_id} rolled back")
|
||||
return True
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
agent_name: str | None = None,
|
||||
change_type: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""List evolution events."""
|
||||
results = []
|
||||
for e in self._events.values():
|
||||
if agent_name and e["agent_name"] != agent_name:
|
||||
continue
|
||||
if change_type and e["change_type"] != change_type:
|
||||
continue
|
||||
if status and e["status"] != status:
|
||||
continue
|
||||
results.append(e)
|
||||
results.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
return results
|
||||
|
||||
async def record_skill_version(
|
||||
self, skill_name: str, version: str, content: str, parent_version: str | None = None
|
||||
) -> str:
|
||||
"""Record a skill version."""
|
||||
vid = str(_uuid.uuid4())
|
||||
entry = {
|
||||
"id": vid,
|
||||
"skill_name": skill_name,
|
||||
"version": version,
|
||||
"content": content,
|
||||
"parent_version": parent_version,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
self._skill_versions.setdefault(skill_name, []).append(entry)
|
||||
return vid
|
||||
|
||||
async def list_skill_versions(self, skill_name: str) -> list[dict]:
|
||||
"""List the version history of a skill."""
|
||||
versions = self._skill_versions.get(skill_name, [])
|
||||
return sorted(versions, key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
async def record_ab_test_result(
|
||||
self, test_id: str, variant: str, score: float, sample_count: int = 0
|
||||
) -> str:
|
||||
"""Record an A/B test result."""
|
||||
rid = str(_uuid.uuid4())
|
||||
entry = {
|
||||
"id": rid,
|
||||
"test_id": test_id,
|
||||
"variant": variant,
|
||||
"score": score,
|
||||
"sample_count": sample_count,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
self._ab_results.setdefault(test_id, []).append(entry)
|
||||
return rid
|
||||
|
||||
async def get_ab_test_results(self, test_id: str) -> list[dict]:
|
||||
"""Get A/B test results."""
|
||||
return list(self._ab_results.get(test_id, []))  # copy so callers cannot mutate stored results
|
||||
|
||||
|
||||
def create_evolution_store(
|
||||
backend: str = "memory",
|
||||
db_path: str = "~/.agentkit/evolution.db",
|
||||
session_factory: Any = None,
|
||||
evolution_model: Any = None,
|
||||
) -> EvolutionStore | PersistentEvolutionStore | InMemoryEvolutionStore:
|
||||
"""Factory function: create an evolution store instance.
|
||||
|
||||
Args:
|
||||
backend: Storage backend type - "memory" | "sqlite" | "sql"
|
||||
db_path: SQLite database path (used only when backend="sqlite")
|
||||
session_factory: Async SQLAlchemy session factory (used only when backend="sql")
|
||||
evolution_model: SQLAlchemy ORM model class (used only when backend="sql")
|
||||
|
||||
Returns:
|
||||
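The evolution store instance for the selected backend; falls back to the in-memory store when the backend is unknown or the "sql" arguments are missing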
|
||||
"""
|
||||
if backend == "sqlite":
|
||||
return PersistentEvolutionStore(db_path=db_path)
|
||||
elif backend == "sql" and session_factory and evolution_model:
|
||||
return EvolutionStore(session_factory=session_factory, evolution_model=evolution_model)
|
||||
else:
|
||||
return InMemoryEvolutionStore()
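Review note: the factory above returns the in-memory store whenever `backend` is unrecognized or the `"sql"` arguments are missing, which can hide a misconfiguration until data is lost on restart. A minimal sketch of a stricter variant follows. It assumes stand-in classes (`MemoryStore`, `SqliteStore`, `SqlStore` and `create_store_strict` are hypothetical names for illustration, not the agentkit classes).

```python
from typing import Any, Union


class MemoryStore:
    """Hypothetical stand-in for an in-memory store."""


class SqliteStore:
    """Hypothetical stand-in for a SQLite-backed store."""

    def __init__(self, db_path: str) -> None:
        self.db_path = db_path


class SqlStore:
    """Hypothetical stand-in for a store built on an external session factory."""

    def __init__(self, session_factory: Any, model: Any) -> None:
        self.session_factory = session_factory
        self.model = model


def create_store_strict(
    backend: str = "memory",
    db_path: str = "evolution.db",
    session_factory: Any = None,
    model: Any = None,
) -> Union[MemoryStore, SqliteStore, SqlStore]:
    """Create a store, raising instead of silently falling back."""
    if backend == "memory":
        return MemoryStore()
    if backend == "sqlite":
        return SqliteStore(db_path=db_path)
    if backend == "sql":
        # Fail loudly when the required collaborators are missing.
        if session_factory is None or model is None:
            raise ValueError("backend='sql' requires session_factory and model")
        return SqlStore(session_factory=session_factory, model=model)
    raise ValueError(f"Unknown backend: {backend!r}")
```

Whether to raise or to fall back is a design choice; falling back keeps tests simple, while raising surfaces typos such as `backend="sqllite"` at startup.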
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ from typing import Any
|
|||
from agentkit.core.protocol import EvolutionEvent, TaskMessage, TaskResult
|
||||
from agentkit.evolution.ab_tester import ABTestConfig, ABTestResult, ABTester
|
||||
from agentkit.evolution.evolution_store import EvolutionStore
|
||||
from agentkit.evolution.llm_reflector import LLMReflector
|
||||
from agentkit.evolution.prompt_optimizer import Module, PromptOptimizer
|
||||
from agentkit.evolution.reflector import Reflection, Reflector
|
||||
from agentkit.evolution.reflector import Reflection, Reflector, RuleBasedReflector
|
||||
from agentkit.evolution.strategy_tuner import StrategyConfig, StrategyTuner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -41,15 +42,30 @@ class EvolutionMixin:
|
|||
EvolutionMixin.__init__(self, reflector=..., ...)
|
||||
"""
|
||||
|
||||
_UNSET = object()  # Sentinel distinguishing "not passed" from "explicitly passed None"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reflector: Reflector | None = None,
|
||||
reflector: Any = _UNSET,
|
||||
prompt_optimizer: PromptOptimizer | None = None,
|
||||
strategy_tuner: StrategyTuner | None = None,
|
||||
ab_tester: ABTester | None = None,
|
||||
evolution_store: EvolutionStore | None = None,
|
||||
reflector_type: str | None = None,
|
||||
llm_gateway: Any | None = None,
|
||||
auxiliary_model: str | None = None,
|
||||
):
|
||||
self._reflector = reflector
|
||||
if reflector is not EvolutionMixin._UNSET:
|
||||
# reflector was passed explicitly (including None)
|
||||
self._reflector = reflector
|
||||
elif reflector_type is not None:
|
||||
# reflector not passed, but reflector_type was given: create one automatically
|
||||
self._reflector = self._create_reflector(
|
||||
reflector_type, llm_gateway, auxiliary_model
|
||||
)
|
||||
else:
|
||||
# Neither was given: stay backward compatible, reflector is None
|
||||
self._reflector = None
|
||||
self._prompt_optimizer = prompt_optimizer
|
||||
self._strategy_tuner = strategy_tuner
|
||||
self._ab_tester = ab_tester
|
||||
|
|
@ -57,6 +73,39 @@ class EvolutionMixin:
|
|||
self._evolution_log: list[EvolutionLogEntry] = []
|
||||
self._current_module: Module | None = None
|
||||
|
||||
@staticmethod
|
||||
def _create_reflector(
|
||||
reflector_type: str,
|
||||
llm_gateway: Any | None = None,
|
||||
auxiliary_model: str | None = None,
|
||||
) -> Reflector | None:
|
||||
"""Create the reflector that matches reflector_type.
|
||||
|
||||
Args:
|
||||
reflector_type: "llm" / "rule" / "auto"
|
||||
llm_gateway: LLMGateway instance; required for the llm and auto modes
|
||||
auxiliary_model: Name of the model used for LLM reflection
|
||||
"""
|
||||
if reflector_type == "llm":
|
||||
if llm_gateway is None:
|
||||
logger.warning(
|
||||
"reflector_type='llm' but no llm_gateway provided, "
|
||||
"falling back to RuleBasedReflector"
|
||||
)
|
||||
return RuleBasedReflector()
|
||||
model = auxiliary_model or "default"
|
||||
return LLMReflector(llm_gateway=llm_gateway, model=model)
|
||||
|
||||
if reflector_type == "rule":
|
||||
return RuleBasedReflector()
|
||||
|
||||
# "auto" mode (also reached for any unrecognized reflector_type): prefer the LLM reflector, fall back to rules
|
||||
if llm_gateway is not None:
|
||||
model = auxiliary_model or "default"
|
||||
return LLMReflector(llm_gateway=llm_gateway, model=model)
|
||||
|
||||
return RuleBasedReflector()
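Review note: the constructor change relies on a sentinel default so that "argument omitted" and "argument explicitly set to None" can be told apart. A minimal standalone sketch of that precedence follows. The function `resolve_reflector` and its string return values are hypothetical and only illustrate the branching; they are not part of agentkit.

```python
from typing import Any, Optional

_UNSET: Any = object()  # unique marker meaning "argument was not passed"


def resolve_reflector(reflector: Any = _UNSET, reflector_type: Optional[str] = None) -> Any:
    """Pick a reflector following the precedence used in the diff."""
    if reflector is not _UNSET:
        # An explicit argument wins, even when it is None ("disable reflection").
        return reflector
    if reflector_type is not None:
        # Nothing passed explicitly: build one from the type name.
        return f"auto-created:{reflector_type}"
    # Neither given: backward-compatible default of no reflector.
    return None
```

The identity check (`is not _UNSET`) matters here: an equality check could be fooled by objects that define `__eq__`, whereas a fresh `object()` is only identical to itself.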
|
||||
|
||||
async def evolve_after_task(self, task: TaskMessage, result: TaskResult) -> EvolutionLogEntry:
|
||||
"""Run the evolution flow after a task completes.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,145 @@
|
|||
"""LLMReflector - LLM-driven execution reflector
|
||||
|
||||
Uses an LLM to analyze the execution trace and produce a structured reflection, offering deeper insight than RuleBasedReflector.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.trace import ExecutionTrace
|
||||
from agentkit.evolution.reflector import Reflection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMReflector:
|
||||
"""LLM-driven reflector that analyzes the execution trace with an LLM to produce a structured reflection."""
|
||||
|
||||
def __init__(self, llm_gateway: Any, model: str = "default"):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._model = model
|
||||
|
||||
async def reflect(
|
||||
self, task: Any, result: Any, trace: ExecutionTrace | None = None
|
||||
) -> Reflection:
|
||||
"""Analyze the execution trace with an LLM and produce a structured reflection."""
|
||||
prompt = self._build_reflection_prompt(task, result, trace)
|
||||
|
||||
try:
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model=self._model,
|
||||
agent_name="reflector",
|
||||
task_type="reflection",
|
||||
)
|
||||
return self._parse_reflection_response(response.content, task, result)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM reflection failed, returning default: {e}")
|
||||
return Reflection(
|
||||
task_id=getattr(task, "task_id", "unknown"),
|
||||
agent_name=getattr(task, "agent_name", "unknown"),
|
||||
outcome="failure",
|
||||
quality_score=0.0,
|
||||
patterns=[],
|
||||
insights=[f"LLM reflection failed: {str(e)}"],
|
||||
suggestions=["Consider using rule-based reflector as fallback"],
|
||||
)
|
||||
|
||||
def _build_reflection_prompt(
|
||||
self, task: Any, result: Any, trace: ExecutionTrace | None
|
||||
) -> str:
|
||||
"""Build the LLM reflection prompt."""
|
||||
parts = [
|
||||
"Analyze the following task execution and provide a structured reflection.",
|
||||
"",
|
||||
"## Task Information",
|
||||
f"- Task ID: {getattr(task, 'task_id', 'unknown')}",
|
||||
f"- Task Type: {getattr(task, 'task_type', 'unknown')}",
|
||||
f"- Agent: {getattr(task, 'agent_name', 'unknown')}",
|
||||
]
|
||||
|
||||
if trace:
|
||||
parts.append("")
|
||||
parts.append("## Execution Trace")
|
||||
parts.append(f"- Total Steps: {len(trace.steps)}")
|
||||
parts.append(f"- Total Duration: {trace.total_duration_ms}ms")
|
||||
parts.append(f"- Total Tokens: {trace.total_tokens}")
|
||||
parts.append(f"- Outcome: {trace.outcome}")
|
||||
for step in trace.steps:
|
||||
parts.append(f" Step {step.step}: {step.action}")
|
||||
if step.tool_name:
|
||||
parts.append(f" Tool: {step.tool_name}")
|
||||
if step.error:
|
||||
parts.append(f" Error: {step.error}")
|
||||
|
||||
result_status = getattr(result, "status", None)
|
||||
if result_status:
|
||||
parts.append("")
|
||||
parts.append("## Result")
|
||||
parts.append(f"- Status: {result_status}")
|
||||
error = getattr(result, "error_message", None)
|
||||
if error:
|
||||
parts.append(f"- Error: {error}")
|
||||
|
||||
parts.append("")
|
||||
parts.append("## Required Output Format")
|
||||
parts.append("Provide your analysis in the following JSON format:")
|
||||
parts.append(
|
||||
"""```json
|
||||
{
|
||||
"outcome": "success|failure|partial",
|
||||
"quality_score": 0.0-1.0,
|
||||
"patterns": ["pattern1", "pattern2"],
|
||||
"insights": ["insight1", "insight2"],
|
||||
"suggestions": ["suggestion1", "suggestion2"]
|
||||
}
|
||||
```"""
|
||||
)
|
||||
return "\n".join(parts)
|
||||
|
||||
def _parse_reflection_response(
|
||||
self, response_content: str, task: Any, result: Any
|
||||
) -> Reflection:
|
||||
"""Parse the LLM response into a Reflection dataclass."""
|
||||
# Try to extract JSON from a fenced code block
|
||||
json_match = re.search(
|
||||
r"```(?:json)?\s*\n?(.*?)\n?```", response_content, re.DOTALL
|
||||
)
|
||||
if json_match:
|
||||
try:
|
||||
data = json.loads(json_match.group(1))
|
||||
return self._build_reflection_from_data(data, task)
|
||||
except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
|
||||
pass
|
||||
|
||||
# Try to parse the whole response as JSON
|
||||
try:
|
||||
data = json.loads(response_content)
|
||||
return self._build_reflection_from_data(data, task)
|
||||
except (json.JSONDecodeError, ValueError, TypeError, AttributeError):
|
||||
pass
|
||||
|
||||
# Fallback: return a basic reflection
|
||||
return Reflection(
|
||||
task_id=getattr(task, "task_id", "unknown"),
|
||||
agent_name=getattr(task, "agent_name", "unknown"),
|
||||
outcome="partial",
|
||||
quality_score=0.5,
|
||||
patterns=[],
|
||||
insights=["LLM response could not be parsed as structured reflection"],
|
||||
suggestions=["Review LLM output format"],
|
||||
)
|
||||
|
||||
def _build_reflection_from_data(self, data: dict, task: Any) -> Reflection:
|
||||
"""Build a Reflection from the parsed dict."""
|
||||
return Reflection(
|
||||
task_id=getattr(task, "task_id", "unknown"),
|
||||
agent_name=getattr(task, "agent_name", "unknown"),
|
||||
outcome=data.get("outcome", "partial"),
|
||||
quality_score=float(data.get("quality_score", 0.5)),
|
||||
patterns=data.get("patterns", []),
|
||||
insights=data.get("insights", []),
|
||||
suggestions=data.get("suggestions", []),
|
||||
)
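Review note: the parsing path above tries a fenced block first, then the raw reply, then a default. A minimal standalone sketch of that two-step extraction follows, with two additions that are assumptions of this sketch rather than behaviour in the diff: non-dict JSON is rejected, and the score is clamped to the 0.0 to 1.0 range the prompt asks for. `extract_reflection_dict` and `clamp_score` are hypothetical helper names.

```python
import json
import re
from typing import Any, Optional

FENCE = "`" * 3  # built at runtime so this example contains no literal fence
_FENCED_JSON = re.compile(FENCE + r"(?:json)?\s*(.*?)\s*" + FENCE, re.DOTALL)


def extract_reflection_dict(text: str) -> Optional[dict[str, Any]]:
    """Return the first JSON object found in an LLM reply, or None."""
    candidates = []
    match = _FENCED_JSON.search(text)
    if match:
        candidates.append(match.group(1))  # fenced block is tried first
    candidates.append(text)  # then the raw reply
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except (json.JSONDecodeError, TypeError):
            continue
        if isinstance(data, dict):  # ignore lists, numbers, strings
            return data
    return None


def clamp_score(value: Any, default: float = 0.5) -> float:
    """Coerce a model-supplied score into [0.0, 1.0]; use default if invalid."""
    try:
        score = float(value)
    except (TypeError, ValueError):
        return default
    if score != score:  # NaN never equals itself
        return default
    return min(1.0, max(0.0, score))
```

Returning `None` instead of raising lets the caller decide on the fallback reflection in one place.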
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""SQLAlchemy ORM models for evolution persistence (SQLite-backed)."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Column, DateTime, Float, Integer, String, Text, create_engine
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class EvolutionEventModel(Base):
|
||||
"""ORM model for evolution events."""
|
||||
|
||||
__tablename__ = "evolution_events"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
agent_name = Column(String, index=True)
|
||||
event_type = Column(String) # "reflection", "optimization", "ab_test", "apply", "rollback"
|
||||
trace_id = Column(String, nullable=True)
|
||||
reflection_id = Column(String, nullable=True)
|
||||
proposal_id = Column(String, nullable=True)
|
||||
change_type = Column(String, nullable=True)
|
||||
before = Column(Text, nullable=True) # JSON string
|
||||
after = Column(Text, nullable=True) # JSON string
|
||||
metrics = Column(Text, nullable=True) # JSON string
|
||||
status = Column(String, default="active") # "active", "rolled_back"
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class SkillVersionModel(Base):
|
||||
"""ORM model for skill versions."""
|
||||
|
||||
__tablename__ = "skill_versions"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
skill_name = Column(String, index=True)
|
||||
version = Column(String)
|
||||
content = Column(Text) # JSON string of skill config
|
||||
parent_version = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class ABTestResultModel(Base):
|
||||
"""ORM model for A/B test results."""
|
||||
|
||||
__tablename__ = "ab_test_results"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
test_id = Column(String, index=True)
|
||||
variant = Column(String) # "control" or "experiment"
|
||||
score = Column(Float)
|
||||
sample_count = Column(Integer, default=0)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
|
|
@ -26,8 +26,8 @@ class Reflection:
|
|||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class Reflector:
|
||||
"""Execution reflector
|
||||
class RuleBasedReflector:
|
||||
"""Rule-based execution reflector
|
||||
|
||||
Evaluates task results, extracts success/failure patterns, and generates improvement suggestions.
|
||||
"""
|
||||
|
|
@ -145,3 +145,7 @@ class Reflector:
|
|||
suggestions.append("Consider adjusting strategy parameters for faster execution")
|
||||
|
||||
return suggestions
|
||||
|
||||
|
||||
# Backward-compatible alias
|
||||
Reflector = RuleBasedReflector
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ class MCPServer:
|
|||
"""Create the FastAPI application."""
|
||||
try:
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
except ImportError:
|
||||
raise ImportError("MCP Server requires fastapi: pip install fischer-agentkit[mcp]")
|
||||
|
||||
|
|
@ -65,6 +66,67 @@ class MCPServer:
|
|||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.post("/")
|
||||
async def jsonrpc_endpoint(request: Request):
|
||||
"""JSON-RPC 2.0 endpoint for MCP protocol compatibility.
|
||||
|
||||
Handles requests from HTTPTransport which sends JSON-RPC format.
|
||||
"""
|
||||
import json
|
||||
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None}
|
||||
|
||||
method = body.get("method", "")
|
||||
params = body.get("params", {})
|
||||
req_id = body.get("id")
|
||||
|
||||
if method == "initialize":
|
||||
result = {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "agentkit-mcp-server", "version": "2.0.0"},
|
||||
}
|
||||
elif method == "tools/list":
|
||||
if self._tool_registry is None:
|
||||
result = {"tools": []}
|
||||
else:
|
||||
tools = self._tool_registry.list_tools()
|
||||
result = {
|
||||
"tools": [
|
||||
{
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"inputSchema": t.input_schema or {},
|
||||
}
|
||||
for t in tools
|
||||
]
|
||||
}
|
||||
elif method == "tools/call":
|
||||
tool_name = params.get("name", "")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name or self._tool_registry is None:
|
||||
result = {"isError": True, "content": [{"type": "text", "text": "Tool not found"}]}
|
||||
else:
|
||||
try:
|
||||
tool = self._tool_registry.get(tool_name)
|
||||
tool_result = await tool.safe_execute(**arguments)
|
||||
result = {"content": [{"type": "text", "text": str(tool_result)}]}
|
||||
except Exception as e:
|
||||
result = {"isError": True, "content": [{"type": "text", "text": str(e)}]}
|
||||
else:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32601, "message": f"Method not found: {method}"},
|
||||
"id": req_id,
|
||||
}
|
||||
|
||||
response = {"jsonrpc": "2.0", "result": result, "id": req_id}
|
||||
return response
|
||||
|
||||
return app
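Review note: the endpoint above calls `body.get(...)` directly, so a valid JSON payload that is not an object (for example a list) would raise before any JSON-RPC error is produced. A minimal framework-free sketch of the dispatch step follows. The error codes are the standard JSON-RPC 2.0 ones; `dispatch`, `error_response` and the `handlers` mapping are hypothetical names, and batch requests are deliberately out of scope here.

```python
from typing import Any, Callable

PARSE_ERROR = -32700  # listed for completeness; raised by the JSON decoder step
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601


def error_response(code: int, message: str, req_id: Any = None) -> dict[str, Any]:
    """Build a JSON-RPC 2.0 error object."""
    return {"jsonrpc": "2.0", "error": {"code": code, "message": message}, "id": req_id}


def dispatch(body: Any, handlers: dict[str, Callable[[dict[str, Any]], Any]]) -> dict[str, Any]:
    """Route one decoded JSON-RPC 2.0 request to a handler."""
    # Reject anything that is not an object with a string method name.
    if not isinstance(body, dict) or not isinstance(body.get("method"), str):
        return error_response(INVALID_REQUEST, "Invalid Request")
    req_id = body.get("id")
    handler = handlers.get(body["method"])
    if handler is None:
        return error_response(METHOD_NOT_FOUND, f"Method not found: {body['method']}", req_id)
    params = body.get("params") or {}
    return {"jsonrpc": "2.0", "result": handler(params), "id": req_id}
```

Keeping the routing in a plain function like this also makes it testable without starting FastAPI.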
|
||||
|
||||
async def start(self):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,88 @@
|
|||
"""Embedder interface and implementations - text vectorization."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Embedder(ABC):
|
||||
"""Abstract base class for text embedding."""
|
||||
|
||||
@abstractmethod
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
"""Generate the embedding vector for a text."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_dimension(self) -> int:
|
||||
"""Return the dimension of the embedding vectors."""
|
||||
...
|
||||
|
||||
|
||||
class OpenAIEmbedder(Embedder):
|
||||
"""OpenAI Embeddings API implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
model: str = "text-embedding-3-small",
|
||||
base_url: str | None = None,
|
||||
):
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
self._base_url = base_url
|
||||
self._dimension = 1536  # Default dimension of text-embedding-3-small; updated after the first successful embed() call
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
"""Generate an embedding vector with the OpenAI API."""
|
||||
try:
|
||||
import httpx
|
||||
|
||||
api_key = self._api_key or os.environ.get("OPENAI_API_KEY", "")
|
||||
base_url = self._base_url or "https://api.openai.com/v1"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{base_url}/embeddings",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
json={"input": text, "model": self._model},
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
embedding = data["data"][0]["embedding"]
|
||||
self._dimension = len(embedding)
|
||||
return embedding
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI embedding failed: {e}")
|
||||
raise
|
||||
|
||||
def get_dimension(self) -> int:
|
||||
return self._dimension
|
||||
|
||||
|
||||
class MockEmbedder(Embedder):
|
||||
"""Mock embedder - generates deterministic pseudo-embeddings for tests."""
|
||||
|
||||
def __init__(self, dimension: int = 128):
|
||||
self._dimension = dimension
|
||||
|
||||
async def embed(self, text: str) -> list[float]:
|
||||
"""Generate a deterministic pseudo-embedding from the text hash."""
|
||||
hash_bytes = hashlib.sha256(text.encode()).digest()
|
||||
vector = []
|
||||
for i in range(self._dimension):
|
||||
byte_idx = i % len(hash_bytes)
|
||||
vector.append(hash_bytes[byte_idx] / 255.0)
|
||||
# Normalize to a unit vector
|
||||
magnitude = sum(x**2 for x in vector) ** 0.5
|
||||
if magnitude > 0:
|
||||
vector = [x / magnitude for x in vector]
|
||||
return vector
|
||||
|
||||
def get_dimension(self) -> int:
|
||||
return self._dimension
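Review note: the mock embedder is deterministic but not semantic. Identical text yields identical vectors, while similar text yields unrelated ones, and for dimensions above 32 the SHA-256 bytes simply repeat. A minimal standalone sketch of the same scheme follows (`mock_embed` and `main` are hypothetical free functions written for this note, not the class in the diff).

```python
import asyncio
import hashlib
import math


async def mock_embed(text: str, dimension: int = 128) -> list[float]:
    """Deterministic pseudo-embedding: SHA-256 bytes cycled to `dimension`, then L2-normalized."""
    digest = hashlib.sha256(text.encode()).digest()  # always 32 bytes
    vector = [digest[i % len(digest)] / 255.0 for i in range(dimension)]
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else vector


async def main() -> tuple[list[float], list[float], list[float]]:
    """Embed the same text twice and a different text once."""
    first = await mock_embed("hello")
    second = await mock_embed("hello")
    other = await mock_embed("goodbye")
    return first, second, other
```

This makes the mock suitable for testing plumbing (storage, ranking code paths) but not for asserting anything about retrieval quality.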
|
||||
|
|
@ -6,6 +6,7 @@ from datetime import datetime, timezone
|
|||
from typing import Any
|
||||
|
||||
from agentkit.memory.base import Memory, MemoryItem
|
||||
from agentkit.memory.embedder import Embedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -21,8 +22,9 @@ class EpisodicMemory(Memory):
|
|||
self,
|
||||
session_factory: Any,
|
||||
episodic_model: Any,
|
||||
embedder: Any | None = None,
|
||||
embedder: Embedder | None = None,
|
||||
decay_rate: float = 0.01,
|
||||
alpha: float = 0.7,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -30,11 +32,13 @@ class EpisodicMemory(Memory):
|
|||
episodic_model: EpisodicMemory ORM model class
|
||||
embedder: Embedder used to generate vectors
|
||||
decay_rate: Time decay rate (larger values decay faster)
|
||||
alpha: Hybrid scoring weight: alpha * cosine + (1 - alpha) * time_decay
|
||||
"""
|
||||
self._session_factory = session_factory
|
||||
self._episodic_model = episodic_model
|
||||
self._embedder = embedder
|
||||
self._decay_rate = decay_rate
|
||||
self._alpha = alpha
|
||||
|
||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
||||
"""Store a task experience."""
|
||||
|
|
@ -67,8 +71,60 @@ class EpisodicMemory(Memory):
|
|||
raise
|
||||
|
||||
async def retrieve(self, key: str) -> MemoryItem | None:
|
||||
"""Exact lookup by key (Episodic Memory is usually not retrieved by key)."""
|
||||
return None
|
||||
"""Semantic lookup by key (using embedding similarity); only the 50 most recent entries are compared."""
|
||||
if not self._embedder:
|
||||
return None
|
||||
|
||||
async with self._session_factory() as db:
|
||||
try:
|
||||
Model = self._episodic_model
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = select(Model).order_by(Model.created_at.desc()).limit(50)
|
||||
result = await db.execute(stmt)
|
||||
entries = result.scalars().all()
|
||||
|
||||
if not entries:
|
||||
return None
|
||||
|
||||
query_embedding = await self._embedder.embed(key)
|
||||
best_item = None
|
||||
best_score = -1.0
|
||||
|
||||
for entry in entries:
|
||||
entry_embedding = entry.embedding
|
||||
if entry_embedding is None:
|
||||
continue
|
||||
cosine = self._compute_cosine_similarity(query_embedding, entry_embedding)
|
||||
if cosine > best_score:
|
||||
best_score = cosine
|
||||
best_item = entry
|
||||
|
||||
if best_item is None or best_score < 0.1:
|
||||
return None
|
||||
|
||||
return MemoryItem(
|
||||
key=str(best_item.id),
|
||||
value={
|
||||
"input_summary": best_item.input_summary,
|
||||
"output_summary": best_item.output_summary,
|
||||
"outcome": best_item.outcome,
|
||||
"quality_score": best_item.quality_score,
|
||||
"reflection": best_item.reflection,
|
||||
},
|
||||
metadata={
|
||||
"agent_name": best_item.agent_name,
|
||||
"task_type": best_item.task_type,
|
||||
"created_at": best_item.created_at.isoformat() if best_item.created_at else None,
|
||||
"cosine_similarity": best_score,
|
||||
},
|
||||
score=best_score,
|
||||
created_at=best_item.created_at or datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve episodic memory: {e}")
|
||||
return None
|
||||
|
||||
async def search(self, query: str, top_k: int = 5, filters: dict[str, Any] | None = None) -> list[MemoryItem]:
|
||||
"""Semantic search for similar past cases."""
|
||||
|
|
@ -78,7 +134,7 @@ class EpisodicMemory(Memory):
|
|||
filters = filters or {}
|
||||
|
||||
# Build the query
|
||||
from sqlalchemy import select, text as sql_text
|
||||
from sqlalchemy import select
|
||||
stmt = select(Model)
|
||||
|
||||
if filters.get("agent_name"):
|
||||
|
|
@ -93,18 +149,24 @@ class EpisodicMemory(Memory):
|
|||
result = await db.execute(stmt)
|
||||
entries = result.scalars().all()
|
||||
|
||||
# If an embedder is available, rank by vector similarity
|
||||
# If an embedder is available, generate the query embedding
|
||||
query_embedding = None
|
||||
if self._embedder and entries:
|
||||
query_embedding = await self._embedder.embed(query)
|
||||
# TODO: rank with pgvector cosine distance
|
||||
# For now, rank by time decay
|
||||
|
||||
# Rank by time decay
|
||||
# Compute scores and build MemoryItem objects
|
||||
items = []
|
||||
for entry in entries:
|
||||
age_hours = (datetime.now(timezone.utc) - entry.created_at).total_seconds() / 3600 if entry.created_at else 0
|
||||
decay = math.exp(-self._decay_rate * age_hours)
|
||||
score = (entry.quality_score or 0.5) * decay
|
||||
time_decay_score = (entry.quality_score or 0.5) * decay
|
||||
|
||||
# Hybrid score: alpha * cosine + (1 - alpha) * time_decay
|
||||
if self._embedder and query_embedding is not None and entry.embedding is not None:
|
||||
cosine_sim = self._compute_cosine_similarity(query_embedding, entry.embedding)
|
||||
score = self._alpha * cosine_sim + (1 - self._alpha) * time_decay_score
|
||||
else:
|
||||
score = time_decay_score
|
||||
|
||||
items.append(MemoryItem(
|
||||
key=str(entry.id),
|
||||
|
|
@ -147,3 +209,20 @@ class EpisodicMemory(Memory):
|
|||
await db.rollback()
|
||||
logger.error(f"Failed to delete episodic memory: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _compute_cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
|
||||
"""Compute the cosine similarity of two vectors."""
|
||||
if len(vec_a) != len(vec_b):
|
||||
logger.warning(
|
||||
f"Vector dimension mismatch: {len(vec_a)} vs {len(vec_b)}"
|
||||
)
|
||||
return 0.0
|
||||
if not vec_a:
|
||||
return 0.0
|
||||
dot_product = sum(a * b for a, b in zip(vec_a, vec_b))
|
||||
magnitude_a = sum(a**2 for a in vec_a) ** 0.5
|
||||
magnitude_b = sum(b**2 for b in vec_b) ** 0.5
|
||||
if magnitude_a == 0.0 or magnitude_b == 0.0:
|
||||
return 0.0
|
||||
return dot_product / (magnitude_a * magnitude_b)
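Review note: the hybrid ranking in `search` combines cosine similarity with a quality score that decays exponentially with age. A minimal standalone sketch with worked numbers follows, using the defaults from the diff (`alpha=0.7`, `decay_rate=0.01`); `hybrid_score` is a hypothetical helper written for this note.

```python
import math


def cosine_similarity(vec_a: list[float], vec_b: list[float]) -> float:
    """Cosine similarity; 0.0 for mismatched, empty, or zero-length vectors."""
    if len(vec_a) != len(vec_b) or not vec_a:
        return 0.0
    dot = sum(a * b for a, b in zip(vec_a, vec_b))
    norm_a = math.sqrt(sum(a * a for a in vec_a))
    norm_b = math.sqrt(sum(b * b for b in vec_b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def hybrid_score(
    cosine: float,
    quality_score: float,
    age_hours: float,
    alpha: float = 0.7,
    decay_rate: float = 0.01,
) -> float:
    """alpha * cosine + (1 - alpha) * quality_score * exp(-decay_rate * age_hours)."""
    time_decay_score = quality_score * math.exp(-decay_rate * age_hours)
    return alpha * cosine + (1 - alpha) * time_decay_score
```

With these defaults, a fresh entry with quality 0.8 and no semantic match scores 0.3 * 0.8 = 0.24, so a strong cosine match dominates the ranking while age only breaks ties among similar entries.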
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""PromptTemplate - prompt template rendering."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -69,3 +71,31 @@ class PromptTemplate:
|
|||
@property
|
||||
def sections(self) -> PromptSection:
|
||||
return self._sections
|
||||
|
||||
def render_cached(
|
||||
self,
|
||||
variables: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Render with caching - returns cached result for same variables
|
||||
|
||||
Uses MD5 hash of the variables dict as cache key.
|
||||
Same variables will return the previously rendered result.
|
||||
"""
|
||||
cache_key = hashlib.md5(
|
||||
json.dumps(variables or {}, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
if not hasattr(self, "_render_cache"):
|
||||
self._render_cache = {}
|
||||
|
||||
if cache_key in self._render_cache:
|
||||
return self._render_cache[cache_key]
|
||||
|
||||
result = self.render(variables=variables)
|
||||
self._render_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the render cache"""
|
||||
if hasattr(self, "_render_cache"):
|
||||
self._render_cache.clear()
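Review note: `render_cached` keys the cache on a hash of the JSON-serialized variables and returns the cached list itself. Two things follow from that: variables that are not JSON-serializable raise `TypeError`, and a caller that mutates the returned messages also mutates the cache. A minimal standalone sketch that addresses both follows. It swaps MD5 for SHA-256 and adds `default=str` and a deep copy, all of which are choices of this sketch and not what the diff does; `CachedRenderer` is a hypothetical wrapper.

```python
import copy
import hashlib
import json
from typing import Any, Callable, Optional


class CachedRenderer:
    """Hypothetical wrapper that memoizes a render function by its variables."""

    def __init__(self, render: Callable[[dict[str, Any]], list[dict[str, str]]]) -> None:
        self._render = render
        self._cache: dict[str, list[dict[str, str]]] = {}
        self.calls = 0  # number of real renders, kept for illustration

    @staticmethod
    def _key(variables: dict[str, Any]) -> str:
        # sort_keys makes the key independent of dict insertion order;
        # default=str lets non-JSON values (e.g. datetime) still produce a key.
        payload = json.dumps(variables, sort_keys=True, default=str)
        return hashlib.sha256(payload.encode()).hexdigest()

    def render_cached(self, variables: Optional[dict[str, Any]] = None) -> list[dict[str, str]]:
        variables = variables or {}
        key = self._key(variables)
        if key not in self._cache:
            self.calls += 1
            self._cache[key] = self._render(variables)
        # Return a copy so callers cannot mutate the cached messages.
        return copy.deepcopy(self._cache[key])

    def clear_cache(self) -> None:
        self._cache.clear()
```

The cache here is unbounded, as in the diff; a long-running server with many distinct variable sets may want an eviction policy.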
|
||||
|
|
|
|||
|
|
@ -8,15 +8,49 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
|
||||
from agentkit.core.agent_pool import AgentPool
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||
from agentkit.quality.gate import QualityGate
|
||||
from agentkit.quality.output import OutputStandardizer
|
||||
from agentkit.router.intent import IntentRouter
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
from agentkit.server.routes import agents, tasks, skills, llm, health
|
||||
from agentkit.server.config import ServerConfig
|
||||
from agentkit.server.routes import agents, tasks, skills, llm, health, metrics
|
||||
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
|
||||
from agentkit.server.task_store import TaskStore
|
||||
from agentkit.server.task_store import create_task_store
|
||||
from agentkit.server.runner import BackgroundRunner
|
||||
from agentkit.core.logging import setup_structured_logging
|
||||
|
||||
|
||||
def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
||||
"""Build LLMGateway from ServerConfig, registering all providers."""
|
||||
gateway = LLMGateway(config=config.llm_config)
|
||||
|
||||
for name, pconf in config.llm_config.providers.items():
|
||||
if not pconf.api_key:
|
||||
continue # Skip providers without API keys
|
||||
try:
|
||||
provider = OpenAICompatibleProvider(
|
||||
api_key=pconf.api_key,
|
||||
base_url=pconf.base_url,
|
||||
)
|
||||
gateway.register_provider(name, provider)
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"Failed to register LLM provider '{name}': {e}")
|
||||
|
||||
return gateway
|
||||
|
||||
|
||||
def _build_skill_registry(config: ServerConfig) -> SkillRegistry:
|
||||
"""Build SkillRegistry from ServerConfig, loading all skill configs."""
|
||||
registry = SkillRegistry()
|
||||
skill_configs = config.load_skill_configs()
|
||||
for skill_config in skill_configs:
|
||||
skill = Skill(config=skill_config)
|
||||
registry.register(skill)
|
||||
return registry
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -35,10 +69,32 @@ def create_app(
|
|||
tool_registry: ToolRegistry | None = None,
|
||||
api_key: str | None = None,
|
||||
rate_limit: int | None = None,
|
||||
server_config: ServerConfig | None = None,
|
||||
) -> FastAPI:
|
||||
"""Create and configure the FastAPI application"""
|
||||
"""Create and configure the FastAPI application
|
||||
|
||||
When called by uvicorn (factory=True), automatically loads ServerConfig
|
||||
from AGENTKIT_CONFIG_PATH env var if server_config is not provided.
|
||||
"""
|
||||
# Auto-load config from env var if not provided (uvicorn factory mode)
|
||||
if server_config is None:
|
||||
config_path = os.environ.get("AGENTKIT_CONFIG_PATH")
|
||||
if config_path and os.path.exists(config_path):
|
||||
server_config = ServerConfig.from_yaml(config_path)
|
||||
app = FastAPI(title="AgentKit Server", version="2.0.0", lifespan=lifespan)
|
||||
|
||||
# Initialize structured logging
|
||||
setup_structured_logging()
|
||||
|
||||
# Resolve effective API key and rate limit
|
||||
effective_api_key = api_key
|
||||
effective_rate_limit = rate_limit
|
||||
if server_config:
|
||||
if effective_api_key is None:
|
||||
effective_api_key = server_config.api_key
|
||||
if effective_rate_limit is None:
|
||||
effective_rate_limit = server_config.rate_limit
|
||||
|
||||
# CORS configuration
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
|
|
@ -48,15 +104,23 @@ def create_app(
|
|||
)
|
||||
|
||||
# Auth middleware
|
||||
if api_key:
|
||||
os.environ["AGENTKIT_API_KEY"] = api_key
|
||||
if effective_api_key:
|
||||
os.environ["AGENTKIT_API_KEY"] = effective_api_key
|
||||
app.add_middleware(APIKeyAuthMiddleware)
|
||||
|
||||
# Rate limiting middleware
|
||||
if rate_limit is not None:
|
||||
os.environ["AGENTKIT_RATE_LIMIT_PER_MINUTE"] = str(rate_limit)
|
||||
if effective_rate_limit is not None:
|
||||
os.environ["AGENTKIT_RATE_LIMIT_PER_MINUTE"] = str(effective_rate_limit)
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
|
||||
# Build LLM Gateway from config if not provided
|
||||
if llm_gateway is None and server_config:
|
||||
llm_gateway = _build_llm_gateway(server_config)
|
||||
|
||||
# Build Skill Registry from config if not provided
|
||||
if skill_registry is None and server_config:
|
||||
skill_registry = _build_skill_registry(server_config)
|
||||
|
||||
# Initialize shared state
|
||||
app.state.llm_gateway = llm_gateway or LLMGateway()
|
||||
app.state.skill_registry = skill_registry or SkillRegistry()
|
||||
|
|
@ -69,8 +133,45 @@ def create_app(
|
|||
app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway)
|
||||
app.state.quality_gate = QualityGate()
|
||||
app.state.output_standardizer = OutputStandardizer()
|
||||
app.state.task_store = TaskStore()
|
||||
# Initialize task store from config
|
||||
ts_config = server_config.task_store if server_config else {}
|
||||
# Merge CLI overrides from AGENTKIT_TASK_STORE env var
|
||||
ts_env = os.environ.get("AGENTKIT_TASK_STORE")
|
||||
if ts_env:
|
||||
import json as _json
|
||||
try:
|
||||
ts_config = {**ts_config, **_json.loads(ts_env)}
|
||||
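except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"Ignoring invalid AGENTKIT_TASK_STORE override: {e}")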
|
||||
task_store = create_task_store(
|
||||
backend=ts_config.get("backend", "memory"),
|
||||
redis_url=ts_config.get("redis_url", "redis://localhost:6379/0"),
|
||||
ttl_seconds=ts_config.get("ttl_seconds", 3600),
|
||||
max_records=ts_config.get("max_records", 10000),
|
||||
)
|
||||
app.state.task_store = task_store
|
||||
app.state.runner = BackgroundRunner(task_store=app.state.task_store)
|
||||
app.state.server_config = server_config
|
||||
|
||||
# Initialize memory components if configured
|
||||
if server_config and hasattr(server_config, 'memory') and server_config.memory:
|
||||
try:
|
||||
from agentkit.memory.retriever import MemoryRetriever
|
||||
from agentkit.memory.working import WorkingMemory
|
||||
|
||||
working = None
|
||||
if server_config.memory.get("working", {}).get("enabled"):
|
||||
import redis.asyncio as aioredis
|
||||
redis_url = server_config.memory["working"].get("redis_url", "redis://localhost:6379")
|
||||
redis_client = aioredis.from_url(redis_url, decode_responses=True)
|
||||
working = WorkingMemory(redis=redis_client)
|
||||
|
||||
memory_retriever = MemoryRetriever(working_memory=working)
|
||||
app.state.memory_retriever = memory_retriever
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(f"Failed to initialize memory components: {e}")
|
||||
app.state.memory_retriever = None
|
||||
|
||||
# Include routes
|
||||
app.include_router(agents.router, prefix="/api/v1")
|
||||
|
|
@ -78,5 +179,6 @@ def create_app(
|
|||
app.include_router(skills.router, prefix="/api/v1")
|
||||
app.include_router(llm.router, prefix="/api/v1")
|
||||
app.include_router(health.router, prefix="/api/v1")
|
||||
app.include_router(metrics.router, prefix="/api/v1")
|
||||
|
||||
return app
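The `AGENTKIT_TASK_STORE` override in `create_app` merges a JSON object from the environment over the `task_store` section of the config file. A minimal sketch of that merge in isolation, assuming a hypothetical helper `merge_task_store_config` (not part of this commit) that falls back to the file config when the override is malformed or is not a JSON object:

```python
import json
from typing import Optional


def merge_task_store_config(file_config: dict, env_value: Optional[str]) -> dict:
    """Return file_config with JSON overrides from env_value applied on top.

    Hypothetical helper: malformed or non-object overrides are ignored.
    """
    if not env_value:
        return dict(file_config)
    try:
        overrides = json.loads(env_value)
    except json.JSONDecodeError:
        return dict(file_config)
    if not isinstance(overrides, dict):
        return dict(file_config)
    # Later keys win, so environment overrides take precedence over the file.
    return {**file_config, **overrides}


merged = merge_task_store_config(
    {"backend": "memory", "ttl_seconds": 3600},
    '{"backend": "redis", "redis_url": "redis://cache:6379/0"}',
)
ignored = merge_task_store_config({"backend": "memory"}, "not-json")
```

Keys absent from the override (here `ttl_seconds`) keep their file values, which matches the `{**ts_config, **...}` merge in the diff.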
|
||||
|
|
|
|||
|
|
@ -0,0 +1,220 @@
|
|||
"""Server configuration loader - loads agentkit.yaml and .env"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from agentkit.llm.config import LLMConfig, ProviderConfig
|
||||
from agentkit.skills.base import SkillConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default config file name
|
||||
DEFAULT_CONFIG_FILE = "agentkit.yaml"
|
||||
|
||||
|
||||
def _resolve_env_vars(value: Any) -> Any:
|
||||
"""Resolve ${VAR:-default} patterns in string values from environment variables."""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
pattern = re.compile(r"\$\{([^}]+)\}")
|
||||
|
||||
def replacer(match):
|
||||
expr = match.group(1)
|
||||
if ":-" in expr:
|
||||
var_name, default = expr.split(":-", 1)
|
||||
return os.environ.get(var_name, default)
|
||||
return os.environ.get(expr, match.group(0))
|
||||
|
||||
return pattern.sub(replacer, value)
|
||||
|
||||
|
||||
def _deep_resolve(data: Any) -> Any:
|
||||
"""Recursively resolve env vars in nested dicts/lists."""
|
||||
if isinstance(data, dict):
|
||||
return {k: _deep_resolve(v) for k, v in data.items()}
|
||||
if isinstance(data, list):
|
||||
return [_deep_resolve(item) for item in data]
|
||||
if isinstance(data, str):
|
||||
return _resolve_env_vars(data)
|
||||
return data
|
||||
|
||||
|
||||
class ServerConfig:
|
||||
"""Server configuration loaded from agentkit.yaml"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8001,
|
||||
workers: int = 1,
|
||||
api_key: str | None = None,
|
||||
rate_limit: int = 60,
|
||||
llm_config: LLMConfig | None = None,
|
||||
skill_paths: list[str] | None = None,
|
||||
auto_discover_skills: bool = True,
|
||||
log_level: str = "INFO",
|
||||
log_format: str = "text",
|
||||
task_store: dict[str, Any] | None = None,
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.workers = workers
|
||||
self.api_key = api_key
|
||||
self.rate_limit = rate_limit
|
||||
self.llm_config = llm_config or LLMConfig()
|
||||
self.skill_paths = skill_paths or []
|
||||
self.auto_discover_skills = auto_discover_skills
|
||||
self.log_level = log_level
|
||||
self.log_format = log_format
|
||||
self.task_store = task_store or {}
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: str) -> "ServerConfig":
|
||||
"""Load configuration from a YAML file."""
|
||||
with open(path, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
|
||||
# Resolve environment variables
|
||||
data = _deep_resolve(data)
|
||||
|
||||
return cls.from_dict(data)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "ServerConfig":
|
||||
"""Create ServerConfig from a dictionary."""
|
||||
server = data.get("server", {})
|
||||
llm_data = data.get("llm", {})
|
||||
skills_data = data.get("skills", {})
|
||||
logging_data = data.get("logging", {})
|
||||
task_store_data = data.get("task_store", {})
|
||||
|
||||
# Build LLMConfig
|
||||
llm_config = cls._build_llm_config(llm_data)
|
||||
|
||||
# Build skill paths
|
||||
skill_paths = skills_data.get("paths", [])
|
||||
auto_discover = skills_data.get("auto_discover", True)
|
||||
|
||||
return cls(
|
||||
host=server.get("host", "0.0.0.0"),
|
||||
port=server.get("port", 8001),
|
||||
workers=server.get("workers", 1),
|
||||
api_key=server.get("api_key"),
|
||||
rate_limit=server.get("rate_limit", 60),
|
||||
llm_config=llm_config,
|
||||
skill_paths=skill_paths,
|
||||
auto_discover_skills=auto_discover,
|
||||
log_level=logging_data.get("level", "INFO"),
|
||||
log_format=logging_data.get("format", "text"),
|
||||
task_store=task_store_data,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_llm_config(data: dict) -> LLMConfig:
|
||||
"""Build LLMConfig from the llm section of agentkit.yaml."""
|
||||
providers = {}
|
||||
model_aliases = {}
|
||||
|
||||
for name, pconf in data.get("providers", {}).items():
|
||||
api_key = pconf.get("api_key", "")
|
||||
base_url = pconf.get("base_url", "")
|
||||
models = pconf.get("models", {})
|
||||
|
||||
# Build model aliases from alias fields
|
||||
for model_name, model_conf in models.items():
|
||||
alias = model_conf.get("alias") if isinstance(model_conf, dict) else None
|
||||
if alias:
|
||||
model_aliases[alias] = f"{name}/{model_name}"
|
||||
|
||||
providers[name] = ProviderConfig(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
models=models,
|
||||
)
|
||||
|
||||
return LLMConfig(
|
||||
providers=providers,
|
||||
model_aliases=model_aliases,
|
||||
fallbacks=data.get("fallbacks", {}),
|
||||
)
|
||||
|
||||
def load_skill_configs(self) -> list[SkillConfig]:
|
||||
"""Load all SkillConfig from configured skill paths."""
|
||||
configs = []
|
||||
for skill_path in self.skill_paths:
|
||||
path = Path(skill_path)
|
||||
if path.is_file() and path.suffix in (".yaml", ".yml"):
|
||||
try:
|
||||
config = SkillConfig.from_yaml(str(path))
|
||||
configs.append(config)
|
||||
logger.info(f"Loaded skill config: {config.name} from {path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load skill config from {path}: {e}")
|
||||
elif path.is_dir():
|
||||
for yaml_file in sorted(path.glob("*.yaml")):
|
||||
try:
|
||||
config = SkillConfig.from_yaml(str(yaml_file))
|
||||
configs.append(config)
|
||||
logger.info(f"Loaded skill config: {config.name} from {yaml_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load skill config from {yaml_file}: {e}")
|
||||
for yaml_file in sorted(path.glob("*.yml")):
|
||||
try:
|
||||
config = SkillConfig.from_yaml(str(yaml_file))
|
||||
configs.append(config)
|
||||
logger.info(f"Loaded skill config: {config.name} from {yaml_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load skill config from {yaml_file}: {e}")
|
||||
return configs
|
||||
|
||||
def load_dotenv(self, dotenv_path: str = ".env") -> None:
|
||||
"""Load environment variables from a .env file (simple key=value format)."""
|
||||
path = Path(dotenv_path)
|
||||
if not path.exists():
|
||||
return
|
||||
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if "=" not in line:
|
||||
continue
|
||||
key, _, value = line.partition("=")
|
||||
key = key.strip()
|
||||
value = value.strip().strip("\"'")
|
||||
if key and key not in os.environ:
|
||||
os.environ[key] = value
|
||||
|
||||
|
||||
def find_config_path(config_arg: str | None = None) -> str | None:
|
||||
"""Find the agentkit.yaml config file.
|
||||
|
||||
Priority:
|
||||
1. Explicit --config argument
|
||||
2. ./agentkit.yaml in current directory
|
||||
3. ~/.agentkit/agentkit.yaml in home directory
|
||||
"""
|
||||
if config_arg:
|
||||
if Path(config_arg).exists():
|
||||
return config_arg
|
||||
logger.warning(f"Config file not found: {config_arg}")
|
||||
return None
|
||||
|
||||
# Check current directory
|
||||
cwd_config = Path.cwd() / DEFAULT_CONFIG_FILE
|
||||
if cwd_config.exists():
|
||||
return str(cwd_config)
|
||||
|
||||
# Check home directory
|
||||
home_config = Path.home() / ".agentkit" / DEFAULT_CONFIG_FILE
|
||||
if home_config.exists():
|
||||
return str(home_config)
|
||||
|
||||
return None
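The `${VAR:-default}` substitution performed by `_resolve_env_vars` and `_deep_resolve` can be tried out on its own. The sketch below re-implements the same logic under the names `resolve_env_vars` and `deep_resolve`; the `DEMO_*` variables and the sample config are made up for illustration:

```python
import os
import re
from typing import Any

_ENV_PATTERN = re.compile(r"\$\{([^}]+)\}")


def resolve_env_vars(value: Any) -> Any:
    """Expand ${VAR} and ${VAR:-default} inside a string; pass other types through."""
    if not isinstance(value, str):
        return value

    def replacer(match):
        expr = match.group(1)
        if ":-" in expr:
            var_name, default = expr.split(":-", 1)
            return os.environ.get(var_name, default)
        # No default given: leave the placeholder untouched when the variable is unset.
        return os.environ.get(expr, match.group(0))

    return _ENV_PATTERN.sub(replacer, value)


def deep_resolve(data: Any) -> Any:
    """Apply resolve_env_vars to every string in nested dicts and lists."""
    if isinstance(data, dict):
        return {k: deep_resolve(v) for k, v in data.items()}
    if isinstance(data, list):
        return [deep_resolve(item) for item in data]
    return resolve_env_vars(data)


os.environ["DEMO_REDIS_HOST"] = "cache.internal"
os.environ.pop("DEMO_REDIS_PORT", None)
os.environ.pop("DEMO_UNSET", None)

raw = {
    "task_store": {"redis_url": "redis://${DEMO_REDIS_HOST}:${DEMO_REDIS_PORT:-6379}/0"},
    "server": {"api_key": "${DEMO_UNSET}", "port": 8001},
}
resolved = deep_resolve(raw)
```

Two behaviours are worth noting: an unset variable without a default stays as the literal `${NAME}` text, and, unlike a POSIX shell, a variable that is set but empty does not trigger the `:-` default.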
|
||||
|
|
@ -3,16 +3,47 @@
|
|||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
def _load_client_keys(config_dir: str | None = None) -> dict[str, str]:
|
||||
"""Load client API keys from clients.yaml.
|
||||
|
||||
Returns a dict mapping client_name -> api_key.
|
||||
"""
|
||||
if config_dir is None:
|
||||
# Try current directory and home directory
|
||||
for candidate in [Path.cwd(), Path.home() / ".agentkit"]:
|
||||
clients_path = candidate / "clients.yaml"
|
||||
if clients_path.exists():
|
||||
config_dir = str(candidate)
|
||||
break
|
||||
else:
|
||||
return {}
|
||||
|
||||
import yaml
|
||||
clients_path = Path(config_dir) / "clients.yaml"
|
||||
if not clients_path.exists():
|
||||
return {}
|
||||
|
||||
with open(clients_path, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
|
||||
# data is {client_name: {api_key: "...", ...}}
|
||||
return {name: info["api_key"] for name, info in data.items() if isinstance(info, dict) and "api_key" in info}
|
||||
|
||||
|
||||
class APIKeyAuthMiddleware(BaseHTTPMiddleware):
|
||||
"""API Key authentication middleware.
|
||||
|
||||
Validates X-API-Key header against AGENTKIT_API_KEY env var.
|
||||
Skips validation if AGENTKIT_API_KEY is not set (dev mode).
|
||||
Validates X-API-Key header against:
|
||||
1. AGENTKIT_API_KEY env var (global key)
|
||||
2. Client keys from clients.yaml (generated by `agentkit pair`)
|
||||
|
||||
Skips validation if no keys are configured (dev mode).
|
||||
Whitelisted paths (no auth required): /api/v1/health
|
||||
"""
|
||||
|
||||
|
|
@ -23,14 +54,25 @@ class APIKeyAuthMiddleware(BaseHTTPMiddleware):
|
|||
if any(request.url.path.startswith(p) for p in self.WHITELIST_PATHS):
|
||||
return await call_next(request)
|
||||
|
||||
api_key = os.environ.get("AGENTKIT_API_KEY")
|
||||
if not api_key:
|
||||
# Dev mode: skip auth if no API key configured
|
||||
# Collect all valid keys
|
||||
valid_keys = set()
|
||||
|
||||
# Global key from env var
|
||||
global_key = os.environ.get("AGENTKIT_API_KEY")
|
||||
if global_key:
|
||||
valid_keys.add(global_key)
|
||||
|
||||
# Client keys from clients.yaml
|
||||
client_keys = _load_client_keys()
|
||||
valid_keys.update(client_keys.values())
|
||||
|
||||
# No keys configured = dev mode
|
||||
if not valid_keys:
|
||||
return await call_next(request)
|
||||
|
||||
# Check API key from header
|
||||
provided_key = request.headers.get("X-API-Key")
|
||||
if not provided_key or provided_key != api_key:
|
||||
if not provided_key or provided_key not in valid_keys:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"error": "Unauthorized", "message": "Invalid or missing API key"},
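The middleware checks the header with `provided_key not in valid_keys`, which is an ordinary set lookup. Where timing side channels are a concern, the comparison could be done in constant time. A minimal sketch, assuming a hypothetical helper `is_valid_api_key` that is not part of this commit:

```python
import hmac
from typing import Iterable, Optional


def is_valid_api_key(provided: Optional[str], valid_keys: Iterable[str]) -> bool:
    """Check provided against every configured key using constant-time comparison."""
    if not provided:
        return False
    provided_bytes = provided.encode("utf-8")
    matched = False
    for key in valid_keys:
        # compare_digest avoids leaking where the two inputs first differ.
        if hmac.compare_digest(provided_bytes, key.encode("utf-8")):
            matched = True  # keep looping so every key is always compared
    return matched


keys = {"global-secret", "client-a-secret"}
accepted = is_valid_api_key("client-a-secret", keys)
rejected = is_valid_api_key("client-a-secre", keys)
missing = is_valid_api_key(None, keys)
```

Encoding to bytes before comparing sidesteps the ASCII-only restriction that `hmac.compare_digest` places on `str` arguments.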
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""Server route modules"""
|
||||
|
||||
from agentkit.server.routes import agents, tasks, skills, llm, health
|
||||
from agentkit.server.routes import agents, tasks, skills, llm, health, metrics
|
||||
|
||||
__all__ = ["agents", "tasks", "skills", "llm", "health"]
|
||||
__all__ = ["agents", "tasks", "skills", "llm", "health", "metrics"]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,72 @@
|
|||
"""Health check route"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "ok", "version": "2.0.0"}
|
||||
async def health_check(request: Request):
|
||||
"""Enhanced health check with dependency status"""
|
||||
app = request.app
|
||||
checks: dict = {}
|
||||
overall_status = "healthy"
|
||||
|
||||
# Check Redis / TaskStore backend
|
||||
redis_status = "not_configured"
|
||||
try:
|
||||
task_store = getattr(app.state, "task_store", None)
|
||||
if task_store:
|
||||
redis_status = "available" if hasattr(task_store, "_redis") else "not_configured"
|
||||
else:
|
||||
redis_status = "not_configured"
|
||||
except Exception as exc:
|
||||
redis_status = f"error: {str(exc)[:100]}"
|
||||
overall_status = "degraded"
|
||||
checks["redis"] = redis_status
|
||||
|
||||
# Check AgentPool
|
||||
agent_pool = getattr(app.state, "agent_pool", None)
|
||||
pool_size = 0
|
||||
if agent_pool:
|
||||
try:
|
||||
agents = agent_pool.list_agents()
|
||||
pool_size = len(agents)
|
||||
except Exception:
|
||||
pass
|
||||
checks["agent_pool"] = {"status": "available", "size": pool_size}
|
||||
|
||||
# Check LLM Gateway
|
||||
llm_gateway = getattr(app.state, "llm_gateway", None)
|
||||
llm_status = "not_configured"
|
||||
if llm_gateway:
|
||||
llm_status = "configured"
|
||||
try:
|
||||
if hasattr(llm_gateway, "_providers") and llm_gateway._providers:
|
||||
llm_status = "available"
|
||||
else:
|
||||
llm_status = "no_providers"
|
||||
overall_status = "degraded"
|
||||
except Exception:
|
||||
llm_status = "error"
|
||||
overall_status = "degraded"
|
||||
checks["llm_gateway"] = llm_status
|
||||
|
||||
# Check Skill Registry
|
||||
skill_registry = getattr(app.state, "skill_registry", None)
|
||||
skill_count = 0
|
||||
if skill_registry:
|
||||
try:
|
||||
skill_count = len(skill_registry.list_skills())
|
||||
except Exception:
|
||||
pass
|
||||
checks["skill_registry"] = {
|
||||
"status": "available" if skill_registry else "not_configured",
|
||||
"count": skill_count,
|
||||
}
|
||||
|
||||
return {
|
||||
"status": overall_status,
|
||||
"version": "2.0.0",
|
||||
"checks": checks,
|
||||
}
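The Redis check above reports "available" whenever the store has a `_redis` attribute, which does not confirm that the server is reachable. One possible active probe is sketched below, assuming a client object that exposes an async `ping()` (as `redis.asyncio.Redis` does); `probe_redis` and the fake clients are hypothetical and only serve the demonstration:

```python
import asyncio


async def probe_redis(client, timeout: float = 1.0) -> str:
    """Return a status string after pinging the client with a timeout."""
    if client is None:
        return "not_configured"
    try:
        await asyncio.wait_for(client.ping(), timeout=timeout)
        return "available"
    except asyncio.TimeoutError:
        return "error: timeout"
    except Exception as exc:
        # Truncate the message the same way the health route does.
        return f"error: {str(exc)[:100]}"


class FakeHealthyClient:
    async def ping(self):
        return True


class FakeBrokenClient:
    async def ping(self):
        raise ConnectionError("connection refused")


healthy = asyncio.run(probe_redis(FakeHealthyClient()))
broken = asyncio.run(probe_redis(FakeBrokenClient()))
missing = asyncio.run(probe_redis(None))
```

The timeout keeps a hung Redis connection from stalling the health endpoint, which the container `HEALTHCHECK` calls with its own 10 s limit.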
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
"""Metrics route — /api/v1/metrics"""
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
router = APIRouter(tags=["metrics"])
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_metrics(request: Request):
|
||||
"""Get application metrics"""
|
||||
app = request.app
|
||||
|
||||
# Task metrics from TaskStore
|
||||
task_store = getattr(app.state, "task_store", None)
|
||||
task_metrics = {
|
||||
"total_tasks": 0,
|
||||
"completed_tasks": 0,
|
||||
"failed_tasks": 0,
|
||||
"pending_tasks": 0,
|
||||
}
|
||||
if task_store:
|
||||
try:
|
||||
all_tasks = task_store.list_tasks(limit=10000)
|
||||
task_metrics["total_tasks"] = len(all_tasks)
|
||||
task_metrics["completed_tasks"] = len(
|
||||
[t for t in all_tasks if t.status.value == "completed"]
|
||||
)
|
||||
task_metrics["failed_tasks"] = len(
|
||||
[t for t in all_tasks if t.status.value == "failed"]
|
||||
)
|
||||
task_metrics["pending_tasks"] = len(
|
||||
[t for t in all_tasks if t.status.value == "pending"]
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Agent pool metrics
|
||||
agent_pool = getattr(app.state, "agent_pool", None)
|
||||
agent_metrics: dict = {
|
||||
"total_agents": 0,
|
||||
"agent_names": [],
|
||||
}
|
||||
if agent_pool:
|
||||
try:
|
||||
agents = agent_pool.list_agents()
|
||||
agent_metrics["total_agents"] = len(agents)
|
||||
agent_metrics["agent_names"] = [a.get("name", "") for a in agents]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Skill registry metrics
|
||||
skill_registry = getattr(app.state, "skill_registry", None)
|
||||
skill_metrics: dict = {
|
||||
"total_skills": 0,
|
||||
"skill_names": [],
|
||||
}
|
||||
if skill_registry:
|
||||
try:
|
||||
skills = skill_registry.list_skills()
|
||||
skill_metrics["total_skills"] = len(skills)
|
||||
skill_metrics["skill_names"] = [s.name for s in skills]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"tasks": task_metrics,
|
||||
"agents": agent_metrics,
|
||||
"skills": skill_metrics,
|
||||
"version": "2.0.0",
|
||||
}
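`RedisTaskStore.list_tasks` is declared `async def`, while the metrics handler calls `task_store.list_tasks(...)` without `await`. If that call returns a coroutine, `len()` raises and the bare `except` leaves every counter at zero. A minimal sketch of a backend-agnostic call, assuming a hypothetical `maybe_await` helper and stand-in stores that return plain status strings rather than `TaskRecord` objects:

```python
import asyncio
import inspect
from collections import Counter


async def maybe_await(value):
    """Await value if it is awaitable, otherwise return it unchanged."""
    if inspect.isawaitable(value):
        return await value
    return value


class SyncStore:
    def list_tasks(self, limit=100):
        return ["completed", "failed", "completed"][:limit]


class AsyncStore:
    async def list_tasks(self, limit=100):
        return ["pending", "completed"][:limit]


async def count_statuses(store, limit=10000):
    """Count task statuses for either a sync or an async store."""
    tasks = await maybe_await(store.list_tasks(limit=limit))
    return Counter(tasks)


sync_counts = asyncio.run(count_statuses(SyncStore()))
async_counts = asyncio.run(count_statuses(AsyncStore()))
```

A single pass with `Counter` also replaces the three separate list comprehensions used to count completed, failed and pending tasks.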
|
||||
|
|
@ -5,6 +5,7 @@ from pydantic import BaseModel
|
|||
from typing import Any
|
||||
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.pipeline import SkillPipeline
|
||||
|
||||
router = APIRouter(tags=["skills"])
|
||||
|
||||
|
|
@ -13,6 +14,15 @@ class RegisterSkillRequest(BaseModel):
|
|||
config: dict[str, Any]
|
||||
|
||||
|
||||
class CreatePipelineRequest(BaseModel):
|
||||
name: str
|
||||
steps: list[dict[str, Any]]
|
||||
|
||||
|
||||
class ExecutePipelineRequest(BaseModel):
|
||||
input_data: dict[str, Any]
|
||||
|
||||
|
||||
@router.post("/skills", status_code=201)
|
||||
async def register_skill(request: RegisterSkillRequest, req: Request):
|
||||
"""Register a Skill"""
|
||||
|
|
@ -48,3 +58,59 @@ async def list_skills(req: Request):
|
|||
}
|
||||
for s in skills
|
||||
]
|
||||
|
||||
|
||||
# ---- Pipeline endpoints ----
|
||||
|
||||
|
||||
@router.post("/skills/pipelines", status_code=201)
|
||||
async def create_pipeline(request: CreatePipelineRequest, req: Request):
|
||||
"""Create and register a SkillPipeline"""
|
||||
skill_registry = req.app.state.skill_registry
|
||||
|
||||
# Validate step definitions
|
||||
for i, step in enumerate(request.steps):
|
||||
if "skill_name" not in step:
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Step {i} missing required field 'skill_name'",
|
||||
)
|
||||
|
||||
pipeline = SkillPipeline(
|
||||
name=request.name,
|
||||
steps=request.steps,
|
||||
skill_registry=skill_registry,
|
||||
)
|
||||
skill_registry.register_pipeline(pipeline)
|
||||
|
||||
return {
|
||||
"name": pipeline.name,
|
||||
"steps": [
|
||||
{"skill_name": s["skill_name"], "step_index": i}
|
||||
for i, s in enumerate(request.steps)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/skills/pipelines")
|
||||
async def list_pipelines(req: Request):
|
||||
"""List all registered pipelines"""
|
||||
skill_registry = req.app.state.skill_registry
|
||||
return skill_registry.list_pipelines()
|
||||
|
||||
|
||||
@router.post("/skills/pipelines/{name}/execute")
|
||||
async def execute_pipeline(name: str, request: ExecutePipelineRequest, req: Request):
|
||||
"""Execute a registered pipeline"""
|
||||
skill_registry = req.app.state.skill_registry
|
||||
pipeline = skill_registry.get_pipeline(name)
|
||||
|
||||
if pipeline is None:
|
||||
raise HTTPException(status_code=404, detail=f"Pipeline '{name}' not found")
|
||||
|
||||
try:
|
||||
result = await pipeline.execute(input_data=request.input_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Pipeline execution failed: {e}")
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""TaskStore - In-memory task state storage with TTL"""
|
||||
"""TaskStore - Task state storage with TTL (InMemory / Redis backends)"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
|
@ -46,8 +46,27 @@ class TaskRecord:
|
|||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "TaskRecord":
|
||||
"""Reconstruct a TaskRecord from a dict (e.g. deserialized from Redis)."""
|
||||
return cls(
|
||||
task_id=data["task_id"],
|
||||
agent_name=data["agent_name"],
|
||||
skill_name=data.get("skill_name"),
|
||||
input_data=data.get("input_data", {}),
|
||||
status=TaskStatus(data.get("status", "pending")),
|
||||
output_data=data.get("output_data"),
|
||||
error_message=data.get("error_message"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(timezone.utc),
|
||||
started_at=datetime.fromisoformat(data["started_at"]) if data.get("started_at") else None,
|
||||
completed_at=datetime.fromisoformat(data["completed_at"]) if data.get("completed_at") else None,
|
||||
progress=data.get("progress", 0.0),
|
||||
progress_message=data.get("progress_message", ""),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
class TaskStore:
|
||||
|
||||
class InMemoryTaskStore:
|
||||
"""In-memory task state storage with automatic TTL cleanup.
|
||||
|
||||
Stores task records indexed by task_id. Automatically removes
|
||||
|
|
@ -105,7 +124,7 @@ class TaskStore:
|
|||
if len(self._tasks) >= self._max_records:
|
||||
# Remove oldest completed task
|
||||
oldest = None
|
||||
for tid, rec in self._tasks.items():
|
||||
for rec in self._tasks.values():
|
||||
if rec.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
|
||||
if oldest is None or (rec.completed_at and (oldest.completed_at is None or rec.completed_at < oldest.completed_at)):
|
||||
oldest = rec
|
||||
|
|
@ -149,3 +168,206 @@ class TaskStore:
|
|||
@property
|
||||
def size(self) -> int:
|
||||
return len(self._tasks)
|
||||
|
||||
|
||||
# Backward-compatible alias
|
||||
TaskStore = InMemoryTaskStore
|
||||
|
||||
|
||||
class RedisTaskStore:
|
||||
"""Redis-backed task state storage with TTL.
|
||||
|
||||
Stores each task as a JSON string in Redis with key pattern
|
||||
``agentkit:task:{task_id}``. Redis TTL handles automatic cleanup,
|
||||
so start_cleanup / stop_cleanup are no-ops.
|
||||
"""
|
||||
|
||||
KEY_PREFIX = "agentkit:task:"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str = "redis://localhost:6379/0",
|
||||
ttl_seconds: int = 3600,
|
||||
max_records: int = 10000,
|
||||
):
|
||||
self._redis_url = redis_url
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._max_records = max_records
|
||||
self._redis: Any = None # redis.asyncio.Redis, lazy init
|
||||
|
||||
async def _get_redis(self):
|
||||
"""Lazy-initialise the async Redis client."""
|
||||
if self._redis is None:
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
self._redis = aioredis.from_url(
|
||||
self._redis_url,
|
||||
decode_responses=True,
|
||||
)
|
||||
return self._redis
|
||||
|
||||
def _key(self, task_id: str) -> str:
|
||||
return f"{self.KEY_PREFIX}{task_id}"
|
||||
|
||||
# ── lifecycle (no-ops, Redis TTL handles cleanup) ──────────
|
||||
|
||||
async def start_cleanup(self) -> None:
|
||||
"""No-op – Redis TTL handles expiry automatically."""
|
||||
|
||||
async def stop_cleanup(self) -> None:
|
||||
"""Close the Redis connection pool on shutdown."""
|
||||
if self._redis is not None:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
|
||||
# ── CRUD ───────────────────────────────────────────────────
|
||||
|
||||
async def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord:
|
||||
"""Create a new task record in Redis."""
|
||||
redis = await self._get_redis()
|
||||
|
||||
# Enforce max_records by counting existing keys
|
||||
current_size = await self._count_keys(redis)
|
||||
if current_size >= self._max_records:
|
||||
# Try to evict the oldest completed task
|
||||
evicted = await self._evict_oldest_completed(redis)
|
||||
if not evicted:
|
||||
raise RuntimeError("TaskStore is full and no completed tasks to evict")
|
||||
|
||||
record = TaskRecord(
|
||||
task_id=task_id,
|
||||
agent_name=agent_name,
|
||||
skill_name=skill_name,
|
||||
input_data=input_data,
|
||||
)
|
||||
await redis.set(self._key(task_id), json.dumps(record.to_dict()), ex=self._ttl_seconds)
|
||||
return record
|
||||
|
||||
async def get(self, task_id: str) -> TaskRecord | None:
|
||||
"""Get task record by ID."""
|
||||
redis = await self._get_redis()
|
||||
raw = await redis.get(self._key(task_id))
|
||||
if raw is None:
|
||||
return None
|
||||
return TaskRecord.from_dict(json.loads(raw))
|
||||
|
||||
async def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord:
|
||||
"""Update task status and optional fields."""
|
||||
redis = await self._get_redis()
|
||||
raw = await redis.get(self._key(task_id))
|
||||
if raw is None:
|
||||
raise KeyError(f"Task '{task_id}' not found")
|
||||
data = json.loads(raw)
|
||||
data["status"] = status.value
|
||||
for key, value in kwargs.items():
|
||||
if key in data or key in ("started_at", "completed_at", "output_data", "error_message", "progress", "progress_message", "metadata"):
|
||||
# Serialise datetime fields
|
||||
if isinstance(value, datetime):
|
||||
data[key] = value.isoformat()
|
||||
else:
|
||||
data[key] = value
|
||||
await redis.set(self._key(task_id), json.dumps(data), ex=self._ttl_seconds)
|
||||
return TaskRecord.from_dict(data)
|
||||
|
||||
async def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]:
|
||||
"""List tasks, optionally filtered by status, sorted by created_at desc."""
|
||||
redis = await self._get_redis()
|
||||
tasks: list[TaskRecord] = []
|
||||
cursor = 0
|
||||
while True:
|
||||
cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
|
||||
if keys:
|
||||
values = await redis.mget(keys)
|
||||
for raw in values:
|
||||
if raw is None:
|
||||
continue
|
||||
record = TaskRecord.from_dict(json.loads(raw))
|
||||
if status is None or record.status == status:
|
||||
tasks.append(record)
|
||||
if cursor == 0:
|
||||
break
|
||||
tasks.sort(key=lambda t: t.created_at, reverse=True)
|
||||
return tasks[:limit]
|
||||
|
||||
@property
|
||||
async def size(self) -> int:
|
||||
"""Number of task keys currently stored."""
|
||||
redis = await self._get_redis()
|
||||
return await self._count_keys(redis)
|
||||
|
||||
# ── helpers ────────────────────────────────────────────────
|
||||
|
||||
async def _count_keys(self, redis) -> int:
|
||||
"""Count task keys using SCAN (avoid KEYS on large datasets)."""
|
||||
count = 0
|
||||
cursor = 0
|
||||
while True:
|
||||
cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
|
||||
count += len(keys)
|
||||
if cursor == 0:
|
||||
break
|
||||
return count
|
||||
|
||||
async def _evict_oldest_completed(self, redis) -> bool:
|
||||
"""Find and delete the oldest completed/failed/cancelled task.
|
||||
Returns True if a record was evicted, False otherwise.
|
||||
"""
|
||||
tasks: list[TaskRecord] = []
|
||||
cursor = 0
|
||||
while True:
|
||||
cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
|
||||
if keys:
|
||||
values = await redis.mget(keys)
|
||||
for raw in values:
|
||||
if raw is None:
|
||||
continue
|
||||
record = TaskRecord.from_dict(json.loads(raw))
|
||||
if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
|
||||
tasks.append(record)
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
if not tasks:
|
||||
return False
|
||||
|
||||
# Pick the one with the earliest completed_at
|
||||
oldest = min(
|
||||
(t for t in tasks if t.completed_at is not None),
|
||||
key=lambda t: t.completed_at, # type: ignore[arg-type]
|
||||
default=None,
|
||||
)
|
||||
if oldest is None:
|
||||
return False
|
||||
|
||||
await redis.delete(self._key(oldest.task_id))
|
||||
return True
|
||||
|
||||
|
||||
def create_task_store(
|
||||
backend: str = "memory",
|
||||
redis_url: str = "redis://localhost:6379/0",
|
||||
ttl_seconds: int = 3600,
|
||||
max_records: int = 10000,
|
||||
) -> InMemoryTaskStore | RedisTaskStore:
|
||||
"""Factory: create a TaskStore backed by memory or Redis.
|
||||
|
||||
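If ``backend="redis"`` but the ``redis`` package cannot be imported or the
store cannot be constructed, falls back to :class:`InMemoryTaskStore` with a
warning. The Redis connection itself is opened lazily on first use, so
connection failures are not detected here.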
|
||||
"""
|
||||
if backend == "redis":
|
||||
try:
|
||||
import redis.asyncio as aioredis # noqa: F401
|
||||
|
||||
store = RedisTaskStore(
|
||||
redis_url=redis_url,
|
||||
ttl_seconds=ttl_seconds,
|
||||
max_records=max_records,
|
||||
)
|
||||
logger.info(f"TaskStore backend: redis ({redis_url})")
|
||||
return store
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to initialise RedisTaskStore ({exc}), falling back to InMemoryTaskStore")
|
||||
|
||||
store = InMemoryTaskStore(ttl_seconds=ttl_seconds, max_records=max_records)
|
||||
logger.info("TaskStore backend: memory")
|
||||
return store
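`RedisTaskStore` persists each task as a JSON string and rebuilds it through `TaskRecord.from_dict`, so datetimes have to survive an ISO 8601 round trip. The sketch below shows that round trip on a reduced stand-in class, `DemoRecord`, which is hypothetical and carries only a few of the real fields:

```python
import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional


@dataclass
class DemoRecord:
    task_id: str
    status: str = "pending"
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    completed_at: Optional[datetime] = None

    def to_dict(self) -> dict:
        """Serialise to JSON-safe values; datetimes become ISO 8601 strings."""
        return {
            "task_id": self.task_id,
            "status": self.status,
            "created_at": self.created_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DemoRecord":
        """Rebuild a record, parsing ISO strings back into aware datetimes."""
        completed = data.get("completed_at")
        return cls(
            task_id=data["task_id"],
            status=data.get("status", "pending"),
            created_at=datetime.fromisoformat(data["created_at"]),
            completed_at=datetime.fromisoformat(completed) if completed else None,
        )


original = DemoRecord("t1", created_at=datetime(2024, 1, 1, tzinfo=timezone.utc))
payload = json.dumps(original.to_dict())
restored = DemoRecord.from_dict(json.loads(payload))
```

Because `isoformat()` keeps the UTC offset, the restored `created_at` is still timezone-aware and compares equal to the original.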
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from agentkit.skills.base import IntentConfig, QualityGateConfig, Skill, SkillConfig
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
from agentkit.skills.pipeline import SkillPipeline
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -9,6 +10,7 @@ __all__ = [
|
|||
"QualityGateConfig",
|
||||
"SkillConfig",
|
||||
"Skill",
|
||||
"SkillPipeline",
|
||||
"SkillRegistry",
|
||||
"SkillLoader",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ class EvolutionConfig:
|
|||
reflect_on_failure: bool = True # Whether to reflect on failed tasks
|
||||
auto_apply: bool = False # Whether to auto-apply optimizations (without AB test)
|
||||
min_quality_threshold: float = 0.5 # Minimum quality score to trigger optimization
|
||||
reflector_type: str = "auto" # "llm" / "rule" / "auto"
|
||||
auxiliary_model: str | None = None # Model name for LLM reflection
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -70,6 +72,9 @@ class SkillConfig(AgentConfig):
|
|||
execution_mode: str = "react",
|
||||
max_steps: int = 5,
|
||||
evolution: dict[str, Any] | None = None,
|
||||
# New in v3: SKILL.md support
|
||||
skill_md_path: str | None = None,
|
||||
disclosure_level: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
|
|
@ -92,6 +97,8 @@ class SkillConfig(AgentConfig):
|
|||
self.execution_mode = execution_mode
|
||||
self.max_steps = max_steps
|
||||
self.evolution = EvolutionConfig(**(evolution or {}))
|
||||
self.skill_md_path = skill_md_path
|
||||
self.disclosure_level = disclosure_level
|
||||
self._validate_v2()
|
||||
|
||||
def _validate_v2(self) -> None:
|
||||
|
|
@ -129,6 +136,8 @@ class SkillConfig(AgentConfig):
|
|||
execution_mode=data.get("execution_mode", "react"),
|
||||
max_steps=data.get("max_steps", 5),
|
||||
evolution=data.get("evolution"),
|
||||
skill_md_path=data.get("skill_md_path"),
|
||||
disclosure_level=data.get("disclosure_level", 0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -167,7 +176,11 @@ class SkillConfig(AgentConfig):
|
|||
"reflect_on_failure": self.evolution.reflect_on_failure,
|
||||
"auto_apply": self.evolution.auto_apply,
|
||||
"min_quality_threshold": self.evolution.min_quality_threshold,
|
||||
"reflector_type": self.evolution.reflector_type,
|
||||
"auxiliary_model": self.evolution.auxiliary_model,
|
||||
}
|
||||
d["skill_md_path"] = self.skill_md_path
|
||||
d["disclosure_level"] = self.disclosure_level
|
||||
return d
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""SkillLoader - batch-load Skills from a YAML directory"""
|
||||
"""SkillLoader - batch-load Skills from a directory of YAML/SKILL.md files"""
|
||||
|
||||
import glob
|
||||
import logging
|
||||
|
|
@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class SkillLoader:
|
||||
"""Batch-load Skills from a YAML directory and register them with the SkillRegistry"""
|
||||
"""Batch-load Skills from a directory of YAML/SKILL.md files and register them with the SkillRegistry"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -23,14 +23,15 @@ class SkillLoader:
|
|||
self._tool_registry = tool_registry
|
||||
|
||||
def load_from_directory(self, directory: str) -> list[Skill]:
|
||||
"""Load every YAML file in the directory as a Skill and register it with the SkillRegistry
|
||||
"""Load every YAML and SKILL.md file in the directory as a Skill and register it with the SkillRegistry
|
||||
|
||||
Invalid YAML files are skipped and a warning is logged.
|
||||
Invalid files are skipped and a warning is logged.
|
||||
"""
|
||||
skills: list[Skill] = []
|
||||
pattern = os.path.join(directory, "*.yaml")
|
||||
yaml_files = sorted(glob.glob(pattern))
|
||||
|
||||
# Load YAML files
|
||||
yaml_pattern = os.path.join(directory, "*.yaml")
|
||||
yaml_files = sorted(glob.glob(yaml_pattern))
|
||||
for yaml_path in yaml_files:
|
||||
try:
|
||||
skill = self._load_skill_from_file(yaml_path)
|
||||
|
|
@ -38,6 +39,16 @@ class SkillLoader:
|
|||
except Exception as e:
|
||||
logger.warning(f"Skipping invalid YAML file '{yaml_path}': {e}")
|
||||
|
||||
# Load SKILL.md files (note: the glob below matches every *.md file in the directory)
|
||||
md_pattern = os.path.join(directory, "*.md")
|
||||
md_files = sorted(glob.glob(md_pattern))
|
||||
for md_path in md_files:
|
||||
try:
|
||||
skill = self.load_from_skill_md(md_path)
|
||||
skills.append(skill)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping invalid SKILL.md file '{md_path}': {e}")
|
||||
|
||||
return skills
|
||||
|
||||
def load_from_file(self, path: str) -> Skill:
|
||||
|
|
@ -54,6 +65,28 @@ class SkillLoader:
|
|||
logger.info(f"Loaded skill '{skill.name}' from '{path}'")
|
||||
return skill
|
||||
|
||||
def load_from_skill_md(self, path: str, disclosure_level: int = 1) -> Skill:
|
||||
"""Load a SKILL.md file as a Skill and register it with SkillRegistry
|
||||
|
||||
Args:
|
||||
path: path to the SKILL.md file
|
||||
disclosure_level: progressive disclosure level (0=summary, 1=full, 2=reference)
|
||||
|
||||
Returns:
|
||||
The loaded Skill instance
|
||||
"""
|
||||
from agentkit.skills.skill_md import SkillMdParser
|
||||
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
config = SkillMdParser.to_skill_config(
|
||||
frontmatter, sections, path, disclosure_level=disclosure_level,
|
||||
)
|
||||
tools = self._bind_tools(config)
|
||||
skill = Skill(config, tools=tools)
|
||||
self._skill_registry.register(skill)
|
||||
logger.info(f"Loaded skill '{skill.name}' from SKILL.md '{path}' (level={disclosure_level})")
|
||||
return skill
|
||||
|
||||
def _bind_tools(self, config: SkillConfig) -> list:
|
||||
"""Bind tools according to the tools list in the config"""
|
||||
if not self._tool_registry or not config.tools:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,204 @@
|
|||
"""SkillPipeline - skill orchestration: chain multiple Skills into a Pipeline
|
||||
|
||||
Reuses the design of PipelineEngine and supports:
|
||||
- Sequential execution (skill A → skill B → skill C)
|
||||
- Conditional branching (if skill A output contains X, run skill B, else skip)
|
||||
- Output mapping (map fields of the previous step's output onto the next step's input fields)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillPipeline:
|
||||
"""Chain multiple Skills into a Pipeline and run them
|
||||
|
||||
Each step definition contains:
|
||||
- skill_name: str (required): name of the Skill to run
|
||||
- input_mapping: dict | None: maps the previous step's output onto this step's input
|
||||
- condition: str | None: condition expression; the step is skipped when it is not met
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
steps: list[dict[str, Any]],
|
||||
skill_registry: SkillRegistry | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
name: Pipeline name
|
||||
steps: list of step definitions, each with skill_name, input_mapping, and condition
|
||||
skill_registry: registry used to look up Skills
|
||||
"""
|
||||
self.name = name
|
||||
self._steps = steps
|
||||
self._skill_registry = skill_registry
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
input_data: dict[str, Any],
|
||||
agent_factory: Callable[..., Coroutine] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Run all Pipeline steps in order
|
||||
|
||||
Args:
|
||||
input_data: initial input data
|
||||
agent_factory: optional Agent factory function with the signature
|
||||
async (skill_name: str, input_data: dict) -> dict
|
||||
|
||||
Returns:
|
||||
A dict with the pipeline name, the per-step results, and the final output
|
||||
"""
|
||||
current_input: dict[str, Any] = input_data
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
for i, step_def in enumerate(self._steps):
|
||||
skill_name = step_def["skill_name"]
|
||||
|
||||
# Condition check (evaluated against the current input, i.e. the last successful step's output)
|
||||
condition = step_def.get("condition")
|
||||
if condition and not self._evaluate_condition(condition, current_input, results):
|
||||
results.append({
|
||||
"step": i,
|
||||
"skill": skill_name,
|
||||
"status": "skipped",
|
||||
"reason": f"Condition not met: {condition}",
|
||||
})
|
||||
continue
|
||||
|
||||
# Input mapping
|
||||
input_mapping = step_def.get("input_mapping")
|
||||
step_input = (
|
||||
self._map_input(current_input, input_mapping, results)
|
||||
if input_mapping
|
||||
else current_input
|
||||
)
|
||||
|
||||
# Run the Skill
|
||||
try:
|
||||
step_result = await self._execute_skill(skill_name, step_input, agent_factory)
|
||||
results.append({
|
||||
"step": i,
|
||||
"skill": skill_name,
|
||||
"output": step_result,
|
||||
"status": "success",
|
||||
})
|
||||
current_input = step_result
|
||||
except Exception as e:
|
||||
results.append({
|
||||
"step": i,
|
||||
"skill": skill_name,
|
||||
"error": str(e),
|
||||
"status": "failed",
|
||||
})
|
||||
break
|
||||
|
||||
return {
|
||||
"pipeline": self.name,
|
||||
"steps": results,
|
||||
"final_output": current_input,
|
||||
}
|
||||
|
||||
async def _execute_skill(
|
||||
self,
|
||||
skill_name: str,
|
||||
input_data: dict[str, Any],
|
||||
agent_factory: Callable[..., Coroutine] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Run a single Skill
|
||||
|
||||
Prefers agent_factory; otherwise looks the Skill up in SkillRegistry and runs it through a newly created Agent.
|
||||
"""
|
||||
if agent_factory:
|
||||
return await agent_factory(skill_name, input_data)
|
||||
|
||||
if self._skill_registry:
|
||||
try:
|
||||
skill = self._skill_registry.get(skill_name)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Skill '{skill_name}' not found in registry") from e
|
||||
|
||||
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||
from agentkit.core.protocol import TaskMessage
|
||||
from datetime import datetime, timezone
|
||||
|
||||
agent = ConfigDrivenAgent(config=skill.config)
|
||||
task = TaskMessage(
|
||||
task_id=f"pipeline-{skill_name}",
|
||||
agent_name=skill_name,
|
||||
task_type=skill.config.agent_type,
|
||||
priority=0,
|
||||
input_data=input_data,
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
return await agent.handle_task(task)
|
||||
|
||||
raise ValueError(
|
||||
f"Cannot execute skill '{skill_name}': "
|
||||
"no agent_factory or skill_registry provided"
|
||||
)
|
||||
|
||||
def _evaluate_condition(
|
||||
self,
|
||||
condition: str,
|
||||
current_input: dict[str, Any],
|
||||
results: list[dict[str, Any]],
|
||||
) -> bool:
|
||||
"""Evaluate a simple condition expression
|
||||
|
||||
Supported formats (anything else evaluates to False):
|
||||
- "key.path == 'value'": string equality
|
||||
- "key.path > 0.5": numeric greater-than
|
||||
"""
|
||||
try:
|
||||
if "==" in condition:
|
||||
path, value = condition.split("==", 1)
|
||||
path = path.strip()
|
||||
value = value.strip().strip("'\"")
|
||||
actual = self._resolve_path(path, current_input)
|
||||
return str(actual) == value
|
||||
elif ">" in condition:
|
||||
path, value = condition.split(">", 1)
|
||||
path = path.strip()
|
||||
value = float(value.strip())
|
||||
actual = float(self._resolve_path(path, current_input))
|
||||
return actual > value
|
||||
except Exception:
|
||||
return False
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _resolve_path(path: str, data: dict[str, Any]) -> Any:
|
||||
"""Resolve a dotted path such as 'output.score'"""
|
||||
parts = path.split(".")
|
||||
obj: Any = data
|
||||
for part in parts:
|
||||
if isinstance(obj, dict):
|
||||
obj = obj.get(part)
|
||||
else:
|
||||
return None
|
||||
return obj
|
||||
|
||||
def _map_input(
|
||||
self,
|
||||
current_input: dict[str, Any],
|
||||
mapping: dict[str, str],
|
||||
results: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""Map the previous step's output onto the current step's input according to the mapping rules
|
||||
|
||||
mapping format: {"target_key": "source.path"}
|
||||
"""
|
||||
mapped: dict[str, Any] = {}
|
||||
for target_key, source_path in mapping.items():
|
||||
value = self._resolve_path(source_path, current_input)
|
||||
if value is not None:
|
||||
mapped[target_key] = value
|
||||
return mapped
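
The condition and path handling in `SkillPipeline` can be tried in isolation. The block below is a minimal standalone sketch that mirrors `_resolve_path` and `_evaluate_condition` as shown in this diff; the free functions `resolve_path` and `evaluate_condition` and the sample data are hypothetical names for illustration, not part of agentkit.

```python
from typing import Any


def resolve_path(path: str, data: dict[str, Any]) -> Any:
    """Walk a dotted path such as 'output.score'; return None when a hop is missing."""
    obj: Any = data
    for part in path.split("."):
        if not isinstance(obj, dict):
            return None
        obj = obj.get(part)
    return obj


def evaluate_condition(condition: str, data: dict[str, Any]) -> bool:
    """Evaluate 'path == value' or 'path > number'; any other form is False."""
    try:
        if "==" in condition:
            path, expected = condition.split("==", 1)
            expected = expected.strip().strip("'\"")
            # Comparison is done on the string form of the resolved value
            return str(resolve_path(path.strip(), data)) == expected
        if ">" in condition:
            path, threshold = condition.split(">", 1)
            return float(resolve_path(path.strip(), data)) > float(threshold.strip())
    except (TypeError, ValueError):
        # Missing values or non-numeric operands count as "condition not met"
        return False
    return False


# Hypothetical output of a previous pipeline step
previous_output = {"output": {"score": 0.72, "label": "ok"}}
print(evaluate_condition("output.score > 0.5", previous_output))
print(evaluate_condition("output.label == 'ok'", previous_output))
print(evaluate_condition("output.missing > 1", previous_output))
```

Because only `==` and `>` are recognized, an expression such as `output.score >= 0.5` falls into the `>` branch, fails the float conversion, and is treated as not met, so the step would be skipped.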
|
||||
|
|
@ -1,10 +1,16 @@
|
|||
"""SkillRegistry - registry of Skills"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from agentkit.core.exceptions import SkillNotFoundError
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.skills.pipeline import SkillPipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -13,6 +19,7 @@ class SkillRegistry:
|
|||
|
||||
def __init__(self):
|
||||
self._skills: dict[str, Skill] = {}
|
||||
self._pipelines: dict[str, SkillPipeline] = {}
|
||||
|
||||
def register(self, skill: Skill) -> None:
|
||||
"""Register a Skill; an existing Skill with the same name is overwritten"""
|
||||
|
|
@ -48,3 +55,24 @@ class SkillRegistry:
|
|||
def has_skill(self, name: str) -> bool:
|
||||
"""Check whether a Skill is registered"""
|
||||
return name in self._skills
|
||||
|
||||
# ---- Pipeline management ----
|
||||
|
||||
def register_pipeline(self, pipeline: SkillPipeline) -> None:
|
||||
"""Register a SkillPipeline; an existing one with the same name is overwritten"""
|
||||
self._pipelines[pipeline.name] = pipeline
|
||||
logger.info(f"SkillPipeline '{pipeline.name}' registered")
|
||||
|
||||
def get_pipeline(self, name: str) -> SkillPipeline | None:
|
||||
"""Return the registered SkillPipeline, or None if it does not exist"""
|
||||
return self._pipelines.get(name)
|
||||
|
||||
def list_pipelines(self) -> list[str]:
|
||||
"""List the names of all registered Pipelines"""
|
||||
return list(self._pipelines.keys())
|
||||
|
||||
def unregister_pipeline(self, name: str) -> None:
|
||||
"""Unregister a SkillPipeline"""
|
||||
if name in self._pipelines:
|
||||
del self._pipelines[name]
|
||||
logger.info(f"SkillPipeline '{name}' unregistered")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,150 @@
|
|||
"""SKILL.md parser - parse a skill definition from a Markdown file
|
||||
|
||||
Supports progressive, layered loading:
|
||||
- Level 0: summary (name + description)
|
||||
- Level 1: full content (all sections)
|
||||
- Level 2: reference information (including external links and the like)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from agentkit.core.exceptions import ConfigValidationError
|
||||
from agentkit.skills.base import SkillConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillMdParser:
|
||||
"""Parse a SKILL.md file into a SkillConfig
|
||||
|
||||
SKILL.md format:
|
||||
1. YAML frontmatter (delimited by ---): metadata
|
||||
2. Markdown body: sections such as # Trigger / # Steps / # Pitfalls / # Verification
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def parse(file_path: str) -> tuple[dict[str, Any], dict[str, str], str]:
|
||||
"""Parse a SKILL.md file
|
||||
|
||||
Args:
|
||||
file_path: path to the SKILL.md file
|
||||
|
||||
Returns:
|
||||
- frontmatter: dict of YAML metadata
|
||||
- sections: mapping of lowercased section title → content
|
||||
- raw_markdown: the full Markdown content with the frontmatter removed
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Extract the YAML frontmatter (between the --- markers)
|
||||
frontmatter: dict[str, Any] = {}
|
||||
body = content
|
||||
if content.startswith("---"):
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) >= 3:
|
||||
frontmatter = yaml.safe_load(parts[1]) or {}
|
||||
body = parts[2].strip()
|
||||
|
||||
# Split into sections on "# " headings
|
||||
sections: dict[str, str] = {}
|
||||
current_section: str | None = None
|
||||
current_lines: list[str] = []
|
||||
|
||||
for line in body.split("\n"):
|
||||
# Match H1 headings (lines starting with "# ", not "##")
|
||||
h1_match = re.match(r"^# (.+)$", line)
|
||||
if h1_match:
|
||||
# Save the previous section
|
||||
if current_section is not None:
|
||||
sections[current_section] = "\n".join(current_lines).strip()
|
||||
current_section = h1_match.group(1).strip().lower()
|
||||
current_lines = []
|
||||
else:
|
||||
current_lines.append(line)
|
||||
|
||||
# Save the last section
|
||||
if current_section is not None:
|
||||
sections[current_section] = "\n".join(current_lines).strip()
|
||||
|
||||
return frontmatter, sections, body
|
||||
|
||||
@staticmethod
|
||||
def to_skill_config(
|
||||
frontmatter: dict[str, Any],
|
||||
sections: dict[str, str],
|
||||
file_path: str,
|
||||
disclosure_level: int = 1,
|
||||
) -> SkillConfig:
|
||||
"""Convert parsed SKILL.md data into a SkillConfig
|
||||
|
||||
Args:
|
||||
frontmatter: YAML metadata
|
||||
sections: Markdown sections
|
||||
file_path: path of the original file
|
||||
disclosure_level: progressive disclosure level (0=summary, 1=full, 2=reference)
|
||||
|
||||
Returns:
|
||||
A SkillConfig instance
|
||||
"""
|
||||
# Build the IntentConfig data
|
||||
intent_data = frontmatter.get("intent") or {}
|
||||
intent_config_data: dict[str, Any] = {
|
||||
"keywords": intent_data.get("keywords", []),
|
||||
"description": intent_data.get("description", ""),
|
||||
"examples": intent_data.get("examples", []),
|
||||
}
|
||||
|
||||
# Build the QualityGateConfig data
|
||||
qg_data = frontmatter.get("quality_gate") or {}
|
||||
quality_gate_config_data: dict[str, Any] = {
|
||||
"required_fields": qg_data.get("required_fields", []),
|
||||
"min_word_count": qg_data.get("min_word_count", 0),
|
||||
"max_retries": qg_data.get("max_retries", 0),
|
||||
"custom_validator": qg_data.get("custom_validator"),
|
||||
}
|
||||
|
||||
# Build the prompt from the sections
|
||||
prompt: dict[str, str] = {}
|
||||
if sections.get("steps"):
|
||||
prompt["instructions"] = sections["steps"]
|
||||
if sections.get("pitfalls"):
|
||||
prompt["constraints"] = sections["pitfalls"]
|
||||
if sections.get("verification"):
|
||||
prompt["output_format"] = sections["verification"]
|
||||
if sections.get("trigger"):
|
||||
prompt["context"] = sections["trigger"]
|
||||
|
||||
# Level 0: keep only name + description; the prompt contains identity only
|
||||
if disclosure_level == 0:
|
||||
prompt = {"identity": frontmatter.get("description", frontmatter.get("name", ""))}
|
||||
|
||||
# Make sure prompt is non-empty (llm_generate mode requires a prompt config)
|
||||
if not prompt:
|
||||
prompt = {"identity": frontmatter.get("description", frontmatter.get("name", ""))}
|
||||
|
||||
# Validate required fields
|
||||
name = frontmatter.get("name", "")
|
||||
if not name:
|
||||
raise ConfigValidationError(
|
||||
agent_name="unknown",
|
||||
key="name",
|
||||
reason="SKILL.md frontmatter must contain a non-empty 'name' field",
|
||||
)
|
||||
|
||||
return SkillConfig(
|
||||
name=name,
|
||||
agent_type=frontmatter.get("agent_type", frontmatter.get("name", "")),
|
||||
description=frontmatter.get("description", ""),
|
||||
task_mode="llm_generate",
|
||||
prompt=prompt,
|
||||
execution_mode=frontmatter.get("execution_mode", "react"),
|
||||
intent=intent_config_data,
|
||||
quality_gate=quality_gate_config_data,
|
||||
skill_md_path=file_path,
|
||||
disclosure_level=disclosure_level,
|
||||
)
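
The H1 section splitting done in `SkillMdParser.parse` can be illustrated without agentkit or PyYAML. The block below is a minimal standalone sketch under that assumption: `split_h1_sections` and the sample text are hypothetical names for illustration, and the frontmatter step is left out.

```python
import re


def split_h1_sections(body: str) -> dict[str, str]:
    """Split Markdown into {lowercased H1 title: content}; '##' lines stay in the content."""
    sections: dict[str, str] = {}
    current: str | None = None
    lines: list[str] = []
    for line in body.split("\n"):
        match = re.match(r"^# (.+)$", line)
        if match:
            # A new H1 starts: store whatever was collected for the previous one
            if current is not None:
                sections[current] = "\n".join(lines).strip()
            current = match.group(1).strip().lower()
            lines = []
        else:
            lines.append(line)
    if current is not None:
        sections[current] = "\n".join(lines).strip()
    return sections


# Hypothetical SKILL.md body (frontmatter already removed)
sample = (
    "# Trigger\n"
    "User asks for a report\n"
    "\n"
    "# Steps\n"
    "1. Collect data\n"
    "## Detail\n"
    "2. Summarize\n"
)
for title, content in split_h1_sections(sample).items():
    print(f"[{title}] {content!r}")
```

Two consequences of this approach are worth keeping in mind: text that appears before the first H1 is not stored in any section, and a line starting with `# ` inside a fenced code block would also be treated as a heading.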
|
||||
|
|
@ -0,0 +1,434 @@
|
|||
"""Tests for ContextCompressor and PromptTemplate cache"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.compressor import ContextCompressor
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
from agentkit.prompts.section import PromptSection
|
||||
from agentkit.prompts.template import PromptTemplate
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────
|
||||
|
||||
|
||||
def make_mock_gateway(summary_content: str = "Summary of conversation") -> MagicMock:
|
||||
"""Create a mock LLMGateway that returns a summary response"""
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
response = LLMResponse(
|
||||
content=summary_content,
|
||||
model="test-model",
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
|
||||
)
|
||||
gateway.chat = AsyncMock(return_value=response)
|
||||
return gateway
|
||||
|
||||
|
||||
def make_long_messages(count: int = 10, content_length: int = 2000) -> list[dict]:
|
||||
"""Generate a list of long messages for compression tests"""
|
||||
messages = [{"role": "system", "content": "You are a helpful assistant."}]
|
||||
for i in range(count):
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": "x" * content_length + f" message {i}",
|
||||
})
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": "y" * content_length + f" reply {i}",
|
||||
})
|
||||
return messages
|
||||
|
||||
|
||||
# ── ContextCompressor Tests ──────────────────────────
|
||||
|
||||
|
||||
class TestEstimateTokens:
|
||||
"""Basic tests for estimate_tokens"""
|
||||
|
||||
def test_empty_messages(self):
|
||||
compressor = ContextCompressor()
|
||||
assert compressor.estimate_tokens([]) == 0
|
||||
|
||||
def test_single_message(self):
|
||||
compressor = ContextCompressor()
|
||||
messages = [{"role": "user", "content": "a" * 40}]
|
||||
# 40 chars / 4 = 10 tokens
|
||||
assert compressor.estimate_tokens(messages) == 10
|
||||
|
||||
def test_multiple_messages(self):
|
||||
compressor = ContextCompressor()
|
||||
messages = [
|
||||
{"role": "user", "content": "a" * 40},
|
||||
{"role": "assistant", "content": "b" * 80},
|
||||
]
|
||||
# 40/4 + 80/4 = 10 + 20 = 30
|
||||
assert compressor.estimate_tokens(messages) == 30
|
||||
|
||||
def test_missing_content_key(self):
|
||||
compressor = ContextCompressor()
|
||||
messages = [{"role": "user"}]
|
||||
assert compressor.estimate_tokens(messages) == 0
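
The assertions above imply a heuristic of roughly four characters per token. The `ContextCompressor` implementation is not part of this section, so the block below is only a sketch of one estimator that satisfies these tests, assuming floor division applied per message; the free function `estimate_tokens` is a hypothetical stand-in.

```python
def estimate_tokens(messages: list[dict]) -> int:
    """Estimate tokens as len(content) // 4 per message; a missing content counts as 0."""
    return sum(len(message.get("content") or "") // 4 for message in messages)


# Same inputs as the tests above
print(estimate_tokens([]))
print(estimate_tokens([{"role": "user", "content": "a" * 40}]))
print(estimate_tokens([{"role": "user"}]))
```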
|
||||
|
||||
|
||||
class TestNoCompressionWhenUnderBudget:
|
||||
"""No compression while within the token budget"""
|
||||
|
||||
async def test_short_messages_not_compressed(self):
|
||||
compressor = ContextCompressor(max_tokens=10000)
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
result = await compressor.compress(messages)
|
||||
assert result == messages
|
||||
|
||||
async def test_exactly_at_budget_not_compressed(self):
|
||||
# 40 chars = 10 tokens, budget = 10
|
||||
compressor = ContextCompressor(max_tokens=10)
|
||||
messages = [{"role": "user", "content": "a" * 40}]
|
||||
result = await compressor.compress(messages)
|
||||
assert result == messages
|
||||
|
||||
|
||||
class TestCompressionTriggersWhenOverBudget:
|
||||
"""Compression is triggered when the budget is exceeded"""
|
||||
|
||||
async def test_long_messages_get_compressed(self):
|
||||
gateway = make_mock_gateway("Compressed summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
)
|
||||
messages = make_long_messages(count=5, content_length=500)
|
||||
result = await compressor.compress(messages)
|
||||
|
||||
# The result should contain fewer messages than the original
|
||||
assert len(result) < len(messages)
|
||||
# It should contain a system message
|
||||
system_msgs = [m for m in result if m.get("role") == "system"]
|
||||
assert len(system_msgs) >= 1
|
||||
# The most recent messages should be kept
|
||||
assert result[-1]["role"] != "system"
|
||||
|
||||
async def test_compression_preserves_system_messages(self):
|
||||
gateway = make_mock_gateway("Summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": "a" * 2000},
|
||||
{"role": "assistant", "content": "b" * 2000},
|
||||
{"role": "user", "content": "c" * 2000},
|
||||
{"role": "assistant", "content": "d" * 2000},
|
||||
{"role": "user", "content": "Recent question"},
|
||||
{"role": "assistant", "content": "Recent answer"},
|
||||
]
|
||||
result = await compressor.compress(messages)
|
||||
|
||||
# The first message should be the original system message
|
||||
assert result[0]["content"] == "System prompt"
|
||||
assert result[0]["role"] == "system"
|
||||
|
||||
async def test_compression_keeps_recent_messages(self):
|
||||
gateway = make_mock_gateway("Summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": "System"},
|
||||
{"role": "user", "content": "a" * 2000},
|
||||
{"role": "assistant", "content": "b" * 2000},
|
||||
{"role": "user", "content": "Recent question"},
|
||||
{"role": "assistant", "content": "Recent answer"},
|
||||
]
|
||||
result = await compressor.compress(messages)
|
||||
|
||||
# The last two non-system messages should be the original recent messages
|
||||
non_system = [m for m in result if m.get("role") != "system"]
|
||||
assert non_system[-2]["content"] == "Recent question"
|
||||
assert non_system[-1]["content"] == "Recent answer"
|
||||
|
||||
|
||||
class TestSummaryGenerationWithLLM:
|
||||
"""Summary generation with an LLM"""
|
||||
|
||||
async def test_llm_summarization_called(self):
|
||||
gateway = make_mock_gateway("LLM generated summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
)
|
||||
messages = [
|
||||
{"role": "user", "content": "a" * 2000},
|
||||
{"role": "assistant", "content": "b" * 2000},
|
||||
{"role": "user", "content": "Recent"},
|
||||
{"role": "assistant", "content": "Reply"},
|
||||
]
|
||||
result = await compressor.compress(messages)
|
||||
|
||||
# The LLM should have been called
|
||||
gateway.chat.assert_called_once()
|
||||
# The summary should appear in the result
|
||||
summary_msgs = [
|
||||
m for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert len(summary_msgs) == 1
|
||||
assert "LLM generated summary" in summary_msgs[0]["content"]
|
||||
|
||||
|
||||
class TestFallbackToSimpleSummary:
|
||||
"""Fall back to a simple summary when the LLM is unavailable"""
|
||||
|
||||
async def test_no_llm_uses_simple_summary(self):
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=None,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
)
|
||||
messages = [
|
||||
{"role": "user", "content": "a" * 2000},
|
||||
{"role": "assistant", "content": "b" * 2000},
|
||||
{"role": "user", "content": "Recent"},
|
||||
{"role": "assistant", "content": "Reply"},
|
||||
]
|
||||
result = await compressor.compress(messages)
|
||||
|
||||
# There should be a summary message (simple truncation mode)
|
||||
summary_msgs = [
|
||||
m for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert len(summary_msgs) == 1
|
||||
# The simple summary should contain a truncation marker
|
||||
assert "..." in summary_msgs[0]["content"]
|
||||
|
||||
async def test_llm_failure_uses_simple_summary(self):
|
||||
gateway = make_mock_gateway()
|
||||
gateway.chat = AsyncMock(side_effect=Exception("LLM error"))
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
)
|
||||
messages = [
|
||||
{"role": "user", "content": "a" * 2000},
|
||||
{"role": "assistant", "content": "b" * 2000},
|
||||
{"role": "user", "content": "Recent"},
|
||||
{"role": "assistant", "content": "Reply"},
|
||||
]
|
||||
result = await compressor.compress(messages)
|
||||
|
||||
# There should be a summary message (fallback to the simple summary)
|
||||
summary_msgs = [
|
||||
m for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert len(summary_msgs) == 1
|
||||
|
||||
|
||||
class TestAggressiveCompression:
|
||||
"""Aggressive compression when still over budget after standard compression"""
|
||||
|
||||
async def test_aggressive_compression_when_still_over_budget(self):
|
||||
# A tiny budget that is exceeded even after compression
|
||||
gateway = make_mock_gateway("x" * 5000)  # the summary itself is also very long
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=10,
|
||||
keep_recent=2,
|
||||
)
|
||||
messages = [
|
||||
{"role": "user", "content": "a" * 5000},
|
||||
{"role": "assistant", "content": "b" * 5000},
|
||||
{"role": "user", "content": "c" * 5000},
|
||||
{"role": "assistant", "content": "d" * 5000},
|
||||
{"role": "user", "content": "Recent"},
|
||||
{"role": "assistant", "content": "Reply"},
|
||||
]
|
||||
result = await compressor.compress(messages)
|
||||
|
||||
# Aggressive compression should keep only the last non-system message
|
||||
non_system = [m for m in result if m.get("role") != "system"]
|
||||
# After aggressive compression at most 1 non-system message remains
|
||||
assert len(non_system) <= 1
|
||||
|
||||
|
||||
class TestTruncation:
|
||||
"""Truncation as a last resort"""
|
||||
|
||||
def test_truncate_long_messages(self):
|
||||
compressor = ContextCompressor(max_tokens=50)
|
||||
messages = [
|
||||
{"role": "system", "content": "a" * 500},
|
||||
{"role": "user", "content": "b" * 500},
|
||||
]
|
||||
result = compressor._truncate(messages)
|
||||
|
||||
# Long messages should be truncated
|
||||
for msg in result:
|
||||
content = msg.get("content", "")
|
||||
if len(content) > 100 + len("...[truncated]"):
|
||||
# Only over-long messages are truncated
|
||||
assert content.endswith("...[truncated]")
|
||||
|
||||
def test_truncate_preserves_short_messages(self):
|
||||
compressor = ContextCompressor(max_tokens=50)
|
||||
messages = [
|
||||
{"role": "user", "content": "Short message"},
|
||||
]
|
||||
result = compressor._truncate(messages)
|
||||
assert result[0]["content"] == "Short message"
|
||||
|
||||
|
||||
class TestNotEnoughMessagesToCompress:
|
||||
"""Skip compression when there are too few messages"""
|
||||
|
||||
async def test_fewer_than_keep_recent_messages(self):
|
||||
compressor = ContextCompressor(
|
||||
max_tokens=10,
|
||||
keep_recent=5,
|
||||
)
|
||||
messages = [
|
||||
{"role": "user", "content": "a" * 200},
|
||||
{"role": "assistant", "content": "b" * 200},
|
||||
]
|
||||
# Only 2 non-system messages with keep_recent=5, so nothing is compressed
|
||||
result = await compressor.compress(messages)
|
||||
assert result == messages
|
||||
|
||||
|
||||
# ── PromptTemplate Cache Tests ───────────────────────
|
||||
|
||||
|
||||
class TestPromptTemplateRenderCached:
|
||||
"""Cache tests for render_cached()"""
|
||||
|
||||
def test_same_variables_returns_cached_result(self):
|
||||
section = PromptSection(
|
||||
identity="Bot",
|
||||
context="Hello ${name}",
|
||||
)
|
||||
tpl = PromptTemplate(sections=section)
|
||||
|
||||
result1 = tpl.render_cached(variables={"name": "Alice"})
|
||||
result2 = tpl.render_cached(variables={"name": "Alice"})
|
||||
|
||||
assert result1 == result2
|
||||
# It should be the very same object (cache hit)
|
||||
assert result1 is result2
|
||||
|
||||
def test_different_variables_re_renders(self):
|
||||
section = PromptSection(
|
||||
context="Hello ${name}",
|
||||
)
|
||||
tpl = PromptTemplate(sections=section)
|
||||
|
||||
result1 = tpl.render_cached(variables={"name": "Alice"})
|
||||
result2 = tpl.render_cached(variables={"name": "Bob"})
|
||||
|
||||
assert result1 != result2
|
||||
assert "Alice" in result1[0]["content"]
|
||||
assert "Bob" in result2[0]["content"]
|
||||
|
||||
def test_no_variables_cached(self):
|
||||
section = PromptSection(identity="Bot")
|
||||
tpl = PromptTemplate(sections=section)
|
||||
|
||||
result1 = tpl.render_cached()
|
||||
result2 = tpl.render_cached()
|
||||
|
||||
assert result1 is result2
|
||||
|
||||
def test_render_cached_matches_render(self):
|
||||
section = PromptSection(
|
||||
identity="Bot",
|
||||
context="Hello ${name}",
|
||||
)
|
||||
tpl = PromptTemplate(sections=section)
|
||||
|
||||
cached = tpl.render_cached(variables={"name": "Alice"})
|
||||
direct = tpl.render(variables={"name": "Alice"})
|
||||
|
||||
assert cached == direct
|
||||
|
||||
|
||||
class TestPromptTemplateClearCache:
|
||||
"""Tests for clear_cache()"""
|
||||
|
||||
def test_clear_cache_works(self):
|
||||
section = PromptSection(
|
||||
context="Hello ${name}",
|
||||
)
|
||||
tpl = PromptTemplate(sections=section)
|
||||
|
||||
result1 = tpl.render_cached(variables={"name": "Alice"})
|
||||
tpl.clear_cache()
|
||||
result2 = tpl.render_cached(variables={"name": "Alice"})
|
||||
|
||||
# After clearing the cache the template is re-rendered, so it is no longer the same object
|
||||
assert result1 == result2
|
||||
assert result1 is not result2
|
||||
|
||||
def test_clear_cache_on_fresh_template(self):
|
||||
"""Calling clear_cache on a fresh template with no cache does not raise"""
|
||||
section = PromptSection(identity="Bot")
|
||||
tpl = PromptTemplate(sections=section)
|
||||
tpl.clear_cache()  # should not raise
|
||||
|
||||
|
||||
class TestReActEngineWithCompressor:
|
||||
"""Tests for ReActEngine integrated with ContextCompressor"""
|
||||
|
||||
async def test_execute_with_compressor(self):
|
||||
from agentkit.core.compressor import ContextCompressor
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
|
||||
gateway = MagicMock()
|
||||
gateway.chat = AsyncMock(return_value=LLMResponse(
|
||||
content="Final answer",
|
||||
model="test",
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
|
||||
))
|
||||
|
||||
compressor = ContextCompressor(max_tokens=10000)
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
compressor=compressor,
|
||||
)
|
||||
|
||||
assert result.output == "Final answer"
|
||||
|
||||
async def test_execute_without_compressor_backward_compatible(self):
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
|
||||
gateway = MagicMock()
|
||||
gateway.chat = AsyncMock(return_value=LLMResponse(
|
||||
content="Answer",
|
||||
model="test",
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
|
||||
))
|
||||
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
|
||||
# Omitting compressor should still work
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
)
|
||||
|
||||
assert result.output == "Answer"
|
||||
|
|
@ -0,0 +1,562 @@
|
|||
"""Unit tests for EpisodicMemory vector search - cosine similarity + hybrid scoring"""
|
||||
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, DateTime, Float, String
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from agentkit.memory.episodic import EpisodicMemory
|
||||
from agentkit.memory.base import MemoryItem
|
||||
from agentkit.memory.embedder import MockEmbedder
|
||||
|
||||
|
||||
# ── Real SQLAlchemy model (for tests) ─────────────────────
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class MockEpisodicModel(Base):
|
||||
"""Stand-in for the EpisodicMemory ORM model"""
|
||||
|
||||
__tablename__ = "test_episodic_vector_search"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
agent_name = Column(String, default="")
|
||||
task_type = Column(String, default="")
|
||||
input_summary = Column(String, default="")
|
||||
output_summary = Column(String, default="")
|
||||
outcome = Column(String, default="success")
|
||||
quality_score = Column(Float, default=0.5)
|
||||
reflection = Column(String, default="")
|
||||
embedding = Column(String, nullable=True)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
# ── Mock helpers ────────────────────────────────────────
|
||||
|
||||
|
||||
def make_mock_entry(
|
||||
id: uuid.UUID | None = None,
|
||||
agent_name: str = "test_agent",
|
||||
task_type: str = "analysis",
|
||||
input_summary: str = "test input",
|
||||
output_summary: str = "test output",
|
||||
outcome: str = "success",
|
||||
quality_score: float = 0.8,
|
||||
reflection: str = "",
|
||||
embedding: list[float] | None = None,
|
||||
created_at: datetime | None = None,
|
||||
):
|
||||
"""Create a mock ORM entry object"""
|
||||
entry = MockEpisodicModel(
|
||||
id=str(id or uuid.uuid4()),
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
input_summary=input_summary,
|
||||
output_summary=output_summary,
|
||||
outcome=outcome,
|
||||
quality_score=quality_score,
|
||||
reflection=reflection,
|
||||
created_at=created_at or datetime.now(timezone.utc),
|
||||
)
|
||||
# Set the embedding attribute directly (bypassing the Column restriction)
|
||||
entry.embedding = embedding
|
||||
return entry
|
||||
|
||||
|
||||
def make_mock_session_factory(entries: list | None = None):
|
||||
"""Create a mock session_factory"""
|
||||
entries = entries or []
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_scalars = MagicMock()
|
||||
mock_scalars.all.return_value = entries
|
||||
mock_result.scalars.return_value = mock_scalars
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
@asynccontextmanager
|
||||
async def factory():
|
||||
yield mock_session
|
||||
|
||||
return factory, mock_session
|
||||
|
||||
|
||||
# ── Cosine Similarity tests ──────────────────────────────
|
||||
|
||||
|
||||
class TestCosineSimilarity:
|
||||
"""Tests for _compute_cosine_similarity"""
|
||||
|
||||
def test_identical_vectors_return_one(self):
"""Identical vectors have a cosine similarity of 1."""
|
||||
vec = [1.0, 0.0, 0.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec, vec) == pytest.approx(1.0)
|
||||
|
||||
def test_orthogonal_vectors_return_zero(self):
"""Orthogonal vectors have a cosine similarity of 0."""
|
||||
vec_a = [1.0, 0.0]
|
||||
vec_b = [0.0, 1.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(0.0)
|
||||
|
||||
def test_opposite_vectors_return_minus_one(self):
"""Opposite vectors have a cosine similarity of -1."""
|
||||
vec_a = [1.0, 0.0]
|
||||
vec_b = [-1.0, 0.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == pytest.approx(-1.0)
|
||||
|
||||
def test_dimension_mismatch_returns_zero(self):
"""A dimension mismatch returns 0."""
|
||||
vec_a = [1.0, 2.0]
|
||||
vec_b = [1.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0
|
||||
|
||||
def test_empty_vectors_return_zero(self):
"""Empty vectors return 0."""
|
||||
assert EpisodicMemory._compute_cosine_similarity([], []) == 0.0
|
||||
|
||||
def test_zero_vector_returns_zero(self):
"""A zero vector returns 0."""
|
||||
vec_a = [0.0, 0.0]
|
||||
vec_b = [1.0, 2.0]
|
||||
assert EpisodicMemory._compute_cosine_similarity(vec_a, vec_b) == 0.0
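The edge cases covered by these tests can be summarized in a small standalone function. This is a sketch consistent with the assertions above, not the actual body of `EpisodicMemory._compute_cosine_similarity`; the name `cosine_similarity` is hypothetical.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity that returns 0.0 for degenerate input instead of raising."""
    if not a or len(a) != len(b):
        return 0.0  # empty input or mismatched dimensions
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # a zero vector has no direction
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (norm_a * norm_b)
```

Returning 0.0 rather than raising keeps ranking code simple: a malformed or missing embedding just contributes no similarity.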
|
||||
|
||||
|
||||
# ── MockEmbedder tests ──────────────────────────────────
|
||||
|
||||
|
||||
class TestMockEmbedder:
"""Tests for MockEmbedder."""
|
||||
|
||||
async def test_embed_returns_correct_dimension(self):
"""embed returns a vector of the requested dimension."""
|
||||
embedder = MockEmbedder(dimension=64)
|
||||
vec = await embedder.embed("test text")
|
||||
assert len(vec) == 64
|
||||
|
||||
async def test_embed_is_deterministic(self):
"""The same text yields the same vector."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
vec1 = await embedder.embed("hello world")
|
||||
vec2 = await embedder.embed("hello world")
|
||||
assert vec1 == vec2
|
||||
|
||||
async def test_embed_different_text_different_vector(self):
"""Different texts yield different vectors."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
vec1 = await embedder.embed("hello")
|
||||
vec2 = await embedder.embed("world")
|
||||
assert vec1 != vec2
|
||||
|
||||
async def test_embed_produces_unit_vector(self):
"""embed produces a unit vector."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
vec = await embedder.embed("test")
|
||||
magnitude = sum(x**2 for x in vec) ** 0.5
|
||||
assert magnitude == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
def test_get_dimension(self):
"""get_dimension returns the configured dimension."""
|
||||
embedder = MockEmbedder(dimension=256)
|
||||
assert embedder.get_dimension() == 256
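One way to satisfy the properties tested here (fixed dimension, determinism, unit length) is to derive the vector from a hash of the text. The sketch below assumes nothing about how `MockEmbedder` is actually implemented; `HashEmbedder` is a hypothetical class, and its vectors carry no semantic meaning.

```python
import asyncio
import hashlib
import math


class HashEmbedder:
    """Hypothetical deterministic embedder for tests; not semantically meaningful."""

    def __init__(self, dimension: int = 32) -> None:
        self._dimension = dimension

    def get_dimension(self) -> int:
        return self._dimension

    async def embed(self, text: str) -> list[float]:
        values: list[float] = []
        counter = 0
        # Chain SHA-256 digests until there are enough bytes for the dimension.
        while len(values) < self._dimension:
            digest = hashlib.sha256(f"{counter}:{text}".encode("utf-8")).digest()
            # Map each byte from [0, 255] to [-1.0, 1.0]; no byte maps to exactly 0.0.
            values.extend(byte / 127.5 - 1.0 for byte in digest)
            counter += 1
        values = values[: self._dimension]
        norm = math.sqrt(sum(v * v for v in values))
        return [v / norm for v in values]  # normalize to unit length
```

Because the output depends only on the input text, two calls with the same string return identical vectors, which is what makes assertions on ranking order reproducible.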
|
||||
|
||||
|
||||
# ── Store tests ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStoreWithEmbedder:
"""Tests for store() with an embedder."""
|
||||
|
||||
async def test_store_generates_embedding_when_embedder_provided(self):
"""store generates an embedding when an embedder is provided."""
|
||||
factory, mock_session = make_mock_session_factory()
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
)
|
||||
|
||||
await mem.store("key1", "some value", {"agent_name": "test"})
|
||||
|
||||
entry_arg = mock_session.add.call_args[0][0]
|
||||
assert entry_arg.embedding is not None
|
||||
assert len(entry_arg.embedding) == 32
|
||||
|
||||
async def test_store_no_embedding_without_embedder(self):
"""store does not generate an embedding without an embedder."""
|
||||
factory, mock_session = make_mock_session_factory()
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
)
|
||||
|
||||
await mem.store("key1", "some value")
|
||||
|
||||
entry_arg = mock_session.add.call_args[0][0]
|
||||
assert entry_arg.embedding is None
|
||||
|
||||
|
||||
# ── Search: vector retrieval tests ──────────────────────
|
||||
|
||||
|
||||
class TestSearchVectorSearch:
"""Vector retrieval tests for search()."""
|
||||
|
||||
async def test_search_with_embedder_uses_cosine_similarity(self):
"""With an embedder, search ranks results by cosine similarity."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
# Generate embeddings
|
||||
vec_similar = await embedder.embed("financial analysis")
|
||||
vec_different = await embedder.embed("completely unrelated topic xyz")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
similar_entry = make_mock_entry(
|
||||
input_summary="financial analysis report",
|
||||
quality_score=0.5,
|
||||
embedding=vec_similar,
|
||||
created_at=now,
|
||||
)
|
||||
different_entry = make_mock_entry(
|
||||
input_summary="unrelated task",
|
||||
quality_score=0.5,
|
||||
embedding=vec_different,
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
factory, _ = make_mock_session_factory([similar_entry, different_entry])
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
alpha=1.0, # rank purely by cosine similarity
|
||||
)
|
||||
|
||||
results = await mem.search("financial analysis")
|
||||
assert len(results) == 2
|
||||
# The similar entry should rank first
|
||||
assert results[0].value["input_summary"] == "financial analysis report"
|
||||
|
||||
async def test_search_fallback_to_time_decay_without_embedder(self):
"""Without an embedder, search falls back to time-decay ranking."""
|
||||
now = datetime.now(timezone.utc)
|
||||
recent_entry = make_mock_entry(
|
||||
quality_score=0.8,
|
||||
created_at=now - timedelta(hours=1),
|
||||
)
|
||||
old_entry = make_mock_entry(
|
||||
quality_score=0.8,
|
||||
created_at=now - timedelta(hours=100),
|
||||
)
|
||||
|
||||
factory, _ = make_mock_session_factory([recent_entry, old_entry])
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
)
|
||||
|
||||
results = await mem.search("test query")
|
||||
assert len(results) == 2
|
||||
# The more recent entry should rank first (pure time decay)
|
||||
assert results[0].score > results[1].score
|
||||
|
||||
async def test_search_hybrid_scoring_formula(self):
"""Hybrid scoring formula: alpha * cosine + (1 - alpha) * time_decay."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
vec_similar = await embedder.embed("query text")
|
||||
vec_different = await embedder.embed("something else entirely")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# Similar entry, but low quality
|
||||
similar_entry = make_mock_entry(
|
||||
quality_score=0.5,
|
||||
embedding=vec_similar,
|
||||
created_at=now,
|
||||
)
|
||||
# Dissimilar entry, but high quality
|
||||
different_entry = make_mock_entry(
|
||||
quality_score=0.9,
|
||||
embedding=vec_different,
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
factory, _ = make_mock_session_factory([similar_entry, different_entry])
|
||||
|
||||
# alpha=1.0 → rank purely by cosine similarity
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
alpha=1.0,
|
||||
)
|
||||
|
||||
results = await mem.search("query text")
|
||||
# With alpha=1.0, cosine dominates and the similar entry ranks first
|
||||
assert results[0].value["quality_score"] == 0.5 # the similar entry; both entries share the default input_summary, so comparing it cannot detect a wrong order
|
||||
|
||||
async def test_search_alpha_zero_pure_time_decay(self):
"""With alpha=0, ranking uses time decay only."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
vec_similar = await embedder.embed("query text")
|
||||
vec_different = await embedder.embed("something else")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# Similar, but low quality
|
||||
similar_entry = make_mock_entry(
|
||||
quality_score=0.3,
|
||||
embedding=vec_similar,
|
||||
created_at=now,
|
||||
)
|
||||
# Dissimilar, but high quality
|
||||
different_entry = make_mock_entry(
|
||||
quality_score=0.9,
|
||||
embedding=vec_different,
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
factory, _ = make_mock_session_factory([similar_entry, different_entry])
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
alpha=0.0, # pure time decay
|
||||
)
|
||||
|
||||
results = await mem.search("query text")
|
||||
# With alpha=0, time_decay dominates and the high-quality entry ranks first
|
||||
assert results[0].value["quality_score"] == 0.9
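The hybrid formula named in these tests, `alpha * cosine + (1 - alpha) * time_decay`, can be sketched in isolation. The exact definition of the time-decay term is not shown in this diff; the version below assumes it is the quality score halved every `half_life_hours` of age, which matches the behaviour the tests rely on (higher quality and more recent entries rank first). All function names and the cosine values are hypothetical.

```python
from datetime import datetime, timedelta, timezone


def time_decay_score(quality: float, created_at: datetime, now: datetime,
                     half_life_hours: float = 24.0) -> float:
    """Assumed form: the quality score is halved every `half_life_hours` of age."""
    age_hours = max((now - created_at).total_seconds() / 3600.0, 0.0)
    return quality * 0.5 ** (age_hours / half_life_hours)


def hybrid_score(cosine: float, decay: float, alpha: float = 0.7) -> float:
    """alpha * cosine + (1 - alpha) * time_decay, as stated in the test docstring."""
    return alpha * cosine + (1.0 - alpha) * decay


now = datetime.now(timezone.utc)
# Illustrative values: one semantically close but low-quality entry, one the reverse.
similar = {"cosine": 0.9, "decay": time_decay_score(0.3, now, now)}
different = {"cosine": 0.1, "decay": time_decay_score(0.9, now, now)}
old_decay = time_decay_score(0.8, now - timedelta(hours=100), now)


def rank(alpha: float) -> list[str]:
    """Return entry labels ordered from highest to lowest hybrid score."""
    scored = {
        "similar": hybrid_score(similar["cosine"], similar["decay"], alpha),
        "different": hybrid_score(different["cosine"], different["decay"], alpha),
    }
    return sorted(scored, key=scored.get, reverse=True)
```

With `alpha=1.0` only the cosine term counts, so the similar entry wins; with `alpha=0.0` only the decay term counts, so the high-quality entry wins.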
|
||||
|
||||
async def test_search_entry_without_embedding_uses_time_decay(self):
"""With an embedder, an entry that lacks an embedding is scored by time decay."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
entry_with_embedding = make_mock_entry(
|
||||
quality_score=0.5,
|
||||
embedding=await embedder.embed("test"),
|
||||
created_at=now - timedelta(hours=10),
|
||||
)
|
||||
entry_without_embedding = make_mock_entry(
|
||||
quality_score=0.9,
|
||||
embedding=None,
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
factory, _ = make_mock_session_factory([entry_with_embedding, entry_without_embedding])
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
alpha=0.7,
|
||||
)
|
||||
|
||||
results = await mem.search("test query")
|
||||
assert len(results) == 2
|
||||
|
||||
async def test_search_empty_store_returns_empty(self):
"""search on an empty store returns an empty list."""
|
||||
factory, _ = make_mock_session_factory([])
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
)
|
||||
|
||||
results = await mem.search("anything")
|
||||
assert results == []
|
||||
|
||||
|
||||
# ── Retrieve: vector retrieval tests ────────────────────
|
||||
|
||||
|
||||
class TestRetrieveVectorSearch:
"""Vector retrieval tests for retrieve()."""
|
||||
|
||||
async def test_retrieve_with_embedder_returns_best_match(self):
"""With an embedder, retrieve returns the most similar entry."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
vec_similar = await embedder.embed("financial report")
|
||||
vec_different = await embedder.embed("weather forecast")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
similar_entry = make_mock_entry(
|
||||
input_summary="financial report Q4",
|
||||
embedding=vec_similar,
|
||||
created_at=now,
|
||||
)
|
||||
different_entry = make_mock_entry(
|
||||
input_summary="weather forecast today",
|
||||
embedding=vec_different,
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
factory, _ = make_mock_session_factory([similar_entry, different_entry])
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
)
|
||||
|
||||
result = await mem.retrieve("financial report")
|
||||
assert result is not None
|
||||
assert result.value["input_summary"] == "financial report Q4"
|
||||
assert result.metadata["cosine_similarity"] > 0.0
|
||||
|
||||
async def test_retrieve_without_embedder_returns_none(self):
"""Without an embedder, retrieve returns None."""
|
||||
factory, _ = make_mock_session_factory([])
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
)
|
||||
|
||||
result = await mem.retrieve("any key")
|
||||
assert result is None
|
||||
|
||||
async def test_retrieve_empty_store_returns_none(self):
"""retrieve on an empty store returns None."""
|
||||
factory, _ = make_mock_session_factory([])
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
)
|
||||
|
||||
result = await mem.retrieve("any key")
|
||||
assert result is None
|
||||
|
||||
async def test_retrieve_no_entries_with_embedding_returns_none(self):
"""retrieve returns None when no entry has an embedding."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
entry = make_mock_entry(
|
||||
embedding=None,
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
factory, _ = make_mock_session_factory([entry])
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
)
|
||||
|
||||
result = await mem.retrieve("any key")
|
||||
assert result is None
|
||||
|
||||
async def test_retrieve_returns_memory_item(self):
"""retrieve returns a MemoryItem instance."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
vec = await embedder.embed("test query")
|
||||
now = datetime.now(timezone.utc)
|
||||
entry = make_mock_entry(
|
||||
input_summary="test input",
|
||||
output_summary="test output",
|
||||
outcome="success",
|
||||
quality_score=0.9,
|
||||
embedding=vec,
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
factory, _ = make_mock_session_factory([entry])
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
)
|
||||
|
||||
result = await mem.retrieve("test query")
|
||||
assert isinstance(result, MemoryItem)
|
||||
assert result.value["input_summary"] == "test input"
|
||||
assert result.value["output_summary"] == "test output"
|
||||
assert result.value["outcome"] == "success"
|
||||
assert result.score > 0.0
|
||||
|
||||
|
||||
# ── Alpha parameter tests ───────────────────────────────
|
||||
|
||||
|
||||
class TestAlphaParameter:
"""The alpha parameter controls the hybrid scoring balance."""
|
||||
|
||||
async def test_alpha_controls_hybrid_balance(self):
"""alpha balances semantic similarity against time decay."""
|
||||
embedder = MockEmbedder(dimension=32)
|
||||
|
||||
vec_similar = await embedder.embed("machine learning")
|
||||
vec_different = await embedder.embed("cooking recipes")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
similar_entry = make_mock_entry(
|
||||
quality_score=0.3,
|
||||
embedding=vec_similar,
|
||||
created_at=now,
|
||||
)
|
||||
different_entry = make_mock_entry(
|
||||
quality_score=0.9,
|
||||
embedding=vec_different,
|
||||
created_at=now,
|
||||
)
|
||||
|
||||
# alpha=1.0: pure cosine → the similar entry ranks first
|
||||
factory1, _ = make_mock_session_factory([similar_entry, different_entry])
|
||||
mem_high_alpha = EpisodicMemory(
|
||||
session_factory=factory1,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
alpha=1.0,
|
||||
)
|
||||
results_high = await mem_high_alpha.search("machine learning")
|
||||
assert results_high[0].value["quality_score"] == 0.3 # the similar entry
|
||||
|
||||
# alpha=0.0: pure time_decay → the high-quality entry ranks first
|
||||
factory2, _ = make_mock_session_factory([similar_entry, different_entry])
|
||||
mem_low_alpha = EpisodicMemory(
|
||||
session_factory=factory2,
|
||||
episodic_model=MockEpisodicModel,
|
||||
embedder=embedder,
|
||||
alpha=0.0,
|
||||
)
|
||||
results_low = await mem_low_alpha.search("machine learning")
|
||||
assert results_low[0].value["quality_score"] == 0.9 # the high-quality entry
|
||||
|
||||
async def test_default_alpha_is_0_7(self):
"""The default alpha is 0.7."""
|
||||
factory, _ = make_mock_session_factory([])
|
||||
|
||||
mem = EpisodicMemory(
|
||||
session_factory=factory,
|
||||
episodic_model=MockEpisodicModel,
|
||||
)
|
||||
|
||||
assert mem._alpha == 0.7
|
||||
|
|
@ -0,0 +1,374 @@
|
|||
"""Tests for PersistentEvolutionStore - SQLite-backed evolution persistence"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import EvolutionEvent
|
||||
from agentkit.evolution.evolution_store import (
|
||||
InMemoryEvolutionStore,
|
||||
PersistentEvolutionStore,
|
||||
create_evolution_store,
|
||||
)
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_path(tmp_path):
|
||||
"""Provide a temporary SQLite database path."""
|
||||
return str(tmp_path / "test_evolution.db")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(db_path):
|
||||
"""Create a PersistentEvolutionStore with a temporary database."""
|
||||
return PersistentEvolutionStore(db_path=db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_event():
|
||||
"""A sample EvolutionEvent."""
|
||||
return EvolutionEvent(
|
||||
agent_name="test_agent",
|
||||
change_type="prompt",
|
||||
before={"prompt": "old prompt"},
|
||||
after={"prompt": "new prompt"},
|
||||
metrics={"accuracy": 0.9},
|
||||
)
|
||||
|
||||
|
||||
# ── record() + persistence tests ─────────────────────────
|
||||
|
||||
|
||||
class TestRecordAndPersistence:
|
||||
async def test_record_returns_event_id(self, store, sample_event):
|
||||
event_id = await store.record(sample_event)
|
||||
assert event_id is not None
|
||||
assert isinstance(event_id, str)
|
||||
assert len(event_id) > 0
|
||||
|
||||
async def test_record_sets_event_id_on_event(self, store, sample_event):
|
||||
assert sample_event.event_id is None
|
||||
await store.record(sample_event)
|
||||
assert sample_event.event_id is not None
|
||||
|
||||
async def test_record_and_reopen_returns_event(self, db_path, sample_event):
|
||||
"""Persistence test: record → close → reopen → list_events returns the event."""
|
||||
store1 = PersistentEvolutionStore(db_path=db_path)
|
||||
await store1.record(sample_event)
|
||||
event_id = sample_event.event_id
|
||||
del store1 # close
|
||||
|
||||
store2 = PersistentEvolutionStore(db_path=db_path)
|
||||
events = await store2.list_events()
|
||||
assert len(events) == 1
|
||||
assert events[0]["id"] == event_id
|
||||
assert events[0]["agent_name"] == "test_agent"
|
||||
assert events[0]["change_type"] == "prompt"
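The record, close, reopen, list pattern in this persistence test can be illustrated with the standard `sqlite3` module alone. The sketch below is not `PersistentEvolutionStore`; `TinyEventStore`, its table layout, and its synchronous methods are assumptions made for illustration.

```python
import json
import os
import sqlite3
import tempfile
import uuid
from contextlib import closing


class TinyEventStore:
    """Hypothetical stand-in for a SQLite-backed store; not PersistentEvolutionStore."""

    def __init__(self, db_path: str) -> None:
        self._db_path = db_path
        # `closing` closes the connection; the inner `conn` context commits.
        with closing(sqlite3.connect(self._db_path)) as conn, conn:
            conn.execute(
                "CREATE TABLE IF NOT EXISTS events ("
                "id TEXT PRIMARY KEY, agent_name TEXT, payload TEXT, "
                "status TEXT DEFAULT 'active')"
            )

    def record(self, agent_name: str, payload: dict) -> str:
        event_id = str(uuid.uuid4())
        with closing(sqlite3.connect(self._db_path)) as conn, conn:
            conn.execute(
                "INSERT INTO events (id, agent_name, payload) VALUES (?, ?, ?)",
                (event_id, agent_name, json.dumps(payload)),
            )
        return event_id

    def list_events(self) -> list[dict]:
        with closing(sqlite3.connect(self._db_path)) as conn:
            rows = conn.execute(
                "SELECT id, agent_name, payload, status FROM events"
            ).fetchall()
        return [
            {"id": r[0], "agent_name": r[1], "payload": json.loads(r[2]), "status": r[3]}
            for r in rows
        ]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "events.db")
    first = TinyEventStore(path)
    event_id = first.record("test_agent", {"prompt": "new prompt"})
    del first  # each call opened and closed its own connection, so nothing is left open
    reopened = TinyEventStore(path)
    events = reopened.list_events()
```

Opening a short-lived connection per operation means that dropping the store object does not depend on garbage collection to flush or release the database file.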
|
||||
|
||||
async def test_record_event_data_roundtrip(self, store, sample_event):
|
||||
"""Verify before/after/metrics are stored and retrieved correctly."""
|
||||
await store.record(sample_event)
|
||||
events = await store.list_events()
|
||||
assert len(events) == 1
|
||||
e = events[0]
|
||||
assert e["before"] == {"prompt": "old prompt"}
|
||||
assert e["after"] == {"prompt": "new prompt"}
|
||||
assert e["metrics"] == {"accuracy": 0.9}
|
||||
assert e["status"] == "active"
|
||||
assert e["created_at"] is not None
|
||||
|
||||
|
||||
# ── rollback() tests ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestRollback:
|
||||
async def test_rollback_success(self, store, sample_event):
|
||||
event_id = await store.record(sample_event)
|
||||
result = await store.rollback(event_id)
|
||||
assert result is True
|
||||
|
||||
events = await store.list_events()
|
||||
assert len(events) == 1
|
||||
assert events[0]["status"] == "rolled_back"
|
||||
|
||||
async def test_rollback_nonexistent_returns_false(self, store):
|
||||
result = await store.rollback("nonexistent-id")
|
||||
assert result is False
|
||||
|
||||
async def test_rollback_persists_across_reopen(self, db_path, sample_event):
|
||||
"""Rollback status persists after reopening the database."""
|
||||
store1 = PersistentEvolutionStore(db_path=db_path)
|
||||
event_id = await store1.record(sample_event)
|
||||
await store1.rollback(event_id)
|
||||
del store1
|
||||
|
||||
store2 = PersistentEvolutionStore(db_path=db_path)
|
||||
events = await store2.list_events()
|
||||
assert events[0]["status"] == "rolled_back"
|
||||
|
||||
|
||||
# ── list_events() tests ──────────────────────────────────
|
||||
|
||||
|
||||
class TestListEvents:
|
||||
async def test_list_events_empty(self, store):
|
||||
events = await store.list_events()
|
||||
assert events == []
|
||||
|
||||
async def test_list_events_filter_by_agent_name(self, store):
|
||||
event_a = EvolutionEvent(
|
||||
agent_name="agent_a", change_type="prompt", before={}, after={}
|
||||
)
|
||||
event_b = EvolutionEvent(
|
||||
agent_name="agent_b", change_type="prompt", before={}, after={}
|
||||
)
|
||||
await store.record(event_a)
|
||||
await store.record(event_b)
|
||||
|
||||
events = await store.list_events(agent_name="agent_a")
|
||||
assert len(events) == 1
|
||||
assert events[0]["agent_name"] == "agent_a"
|
||||
|
||||
async def test_list_events_filter_by_change_type(self, store):
|
||||
event_prompt = EvolutionEvent(
|
||||
agent_name="test", change_type="prompt", before={}, after={}
|
||||
)
|
||||
event_strategy = EvolutionEvent(
|
||||
agent_name="test", change_type="strategy", before={}, after={}
|
||||
)
|
||||
await store.record(event_prompt)
|
||||
await store.record(event_strategy)
|
||||
|
||||
events = await store.list_events(change_type="strategy")
|
||||
assert len(events) == 1
|
||||
assert events[0]["change_type"] == "strategy"
|
||||
|
||||
async def test_list_events_filter_by_status(self, store):
|
||||
event = EvolutionEvent(
|
||||
agent_name="test", change_type="prompt", before={}, after={}
|
||||
)
|
||||
event_id = await store.record(event)
|
||||
await store.rollback(event_id)
|
||||
|
||||
active_events = await store.list_events(status="active")
|
||||
assert len(active_events) == 0
|
||||
|
||||
rolled_back_events = await store.list_events(status="rolled_back")
|
||||
assert len(rolled_back_events) == 1
|
||||
assert rolled_back_events[0]["status"] == "rolled_back"
|
||||
|
||||
async def test_list_events_multiple_with_combined_filters(self, store):
|
||||
"""Integration: record multiple events, list with filters."""
|
||||
for i in range(3):
|
||||
event = EvolutionEvent(
|
||||
agent_name="agent_a" if i < 2 else "agent_b",
|
||||
change_type="prompt" if i % 2 == 0 else "strategy",
|
||||
before={},
|
||||
after={},
|
||||
)
|
||||
await store.record(event)
|
||||
|
||||
# Filter by agent_name
|
||||
events = await store.list_events(agent_name="agent_a")
|
||||
assert len(events) == 2
|
||||
|
||||
# Filter by change_type
|
||||
events = await store.list_events(change_type="strategy")
|
||||
assert len(events) == 1
|
||||
|
||||
# Combined filter
|
||||
events = await store.list_events(agent_name="agent_a", change_type="prompt")
|
||||
assert len(events) == 1
|
||||
|
||||
async def test_list_events_ordered_by_created_at_desc(self, store):
|
||||
"""Events are returned newest first."""
|
||||
import asyncio
|
||||
|
||||
event1 = EvolutionEvent(
|
||||
agent_name="test", change_type="prompt", before={"v": 1}, after={}
|
||||
)
|
||||
await store.record(event1)
|
||||
await asyncio.sleep(0.01) # ensure different timestamps
|
||||
event2 = EvolutionEvent(
|
||||
agent_name="test", change_type="prompt", before={"v": 2}, after={}
|
||||
)
|
||||
await store.record(event2)
|
||||
|
||||
events = await store.list_events()
|
||||
assert len(events) == 2
|
||||
# Newest first
|
||||
assert events[0]["before"]["v"] == 2
|
||||
assert events[1]["before"]["v"] == 1
|
||||
|
||||
|
||||
# ── Skill version tests ──────────────────────────────────
|
||||
|
||||
|
||||
class TestSkillVersions:
|
||||
async def test_record_and_list_skill_version(self, store):
|
||||
vid = await store.record_skill_version(
|
||||
skill_name="search",
|
||||
version="v1",
|
||||
content='{"prompt": "search for X"}',
|
||||
)
|
||||
assert vid is not None
|
||||
|
||||
versions = await store.list_skill_versions("search")
|
||||
assert len(versions) == 1
|
||||
assert versions[0]["skill_name"] == "search"
|
||||
assert versions[0]["version"] == "v1"
|
||||
assert versions[0]["content"] == '{"prompt": "search for X"}'
|
||||
|
||||
async def test_skill_version_with_parent(self, store):
|
||||
await store.record_skill_version("search", "v1", '{"prompt": "v1"}')
|
||||
await store.record_skill_version(
|
||||
"search", "v2", '{"prompt": "v2"}', parent_version="v1"
|
||||
)
|
||||
|
||||
versions = await store.list_skill_versions("search")
|
||||
assert len(versions) == 2
|
||||
# Newest first
|
||||
assert versions[0]["version"] == "v2"
|
||||
assert versions[0]["parent_version"] == "v1"
|
||||
assert versions[1]["version"] == "v1"
|
||||
assert versions[1]["parent_version"] is None
|
||||
|
||||
async def test_skill_versions_persist_across_reopen(self, db_path):
|
||||
store1 = PersistentEvolutionStore(db_path=db_path)
|
||||
await store1.record_skill_version("search", "v1", '{"prompt": "v1"}')
|
||||
del store1
|
||||
|
||||
store2 = PersistentEvolutionStore(db_path=db_path)
|
||||
versions = await store2.list_skill_versions("search")
|
||||
assert len(versions) == 1
|
||||
assert versions[0]["version"] == "v1"
|
||||
|
||||
async def test_list_skill_versions_empty(self, store):
|
||||
versions = await store.list_skill_versions("nonexistent")
|
||||
assert versions == []
|
||||
|
||||
|
||||
# ── A/B test result tests ────────────────────────────────
|
||||
|
||||
|
||||
class TestABTestResults:
|
||||
async def test_record_and_get_ab_test_result(self, store):
|
||||
rid = await store.record_ab_test_result(
|
||||
test_id="test_001", variant="control", score=0.85, sample_count=10
|
||||
)
|
||||
assert rid is not None
|
||||
|
||||
results = await store.get_ab_test_results("test_001")
|
||||
assert len(results) == 1
|
||||
assert results[0]["test_id"] == "test_001"
|
||||
assert results[0]["variant"] == "control"
|
||||
assert results[0]["score"] == 0.85
|
||||
assert results[0]["sample_count"] == 10
|
||||
|
||||
async def test_ab_test_multiple_variants(self, store):
|
||||
await store.record_ab_test_result("test_001", "control", 0.8, 10)
|
||||
await store.record_ab_test_result("test_001", "experiment", 0.9, 10)
|
||||
|
||||
results = await store.get_ab_test_results("test_001")
|
||||
assert len(results) == 2
|
||||
|
||||
async def test_ab_test_results_persist_across_reopen(self, db_path):
|
||||
store1 = PersistentEvolutionStore(db_path=db_path)
|
||||
await store1.record_ab_test_result("test_001", "control", 0.8, 5)
|
||||
del store1
|
||||
|
||||
store2 = PersistentEvolutionStore(db_path=db_path)
|
||||
results = await store2.get_ab_test_results("test_001")
|
||||
assert len(results) == 1
|
||||
assert results[0]["variant"] == "control"
|
||||
|
||||
async def test_get_ab_test_results_empty(self, store):
|
||||
results = await store.get_ab_test_results("nonexistent")
|
||||
assert results == []
|
||||
|
||||
|
||||
# ── InMemoryEvolutionStore tests ─────────────────────────
|
||||
|
||||
|
||||
class TestInMemoryEvolutionStore:
|
||||
async def test_record_and_list(self):
|
||||
store = InMemoryEvolutionStore()
|
||||
event = EvolutionEvent(
|
||||
agent_name="test", change_type="prompt", before={}, after={}
|
||||
)
|
||||
event_id = await store.record(event)
|
||||
assert event_id is not None
|
||||
|
||||
events = await store.list_events()
|
||||
assert len(events) == 1
|
||||
assert events[0]["agent_name"] == "test"
|
||||
|
||||
async def test_rollback(self):
|
||||
store = InMemoryEvolutionStore()
|
||||
event = EvolutionEvent(
|
||||
agent_name="test", change_type="prompt", before={}, after={}
|
||||
)
|
||||
event_id = await store.record(event)
|
||||
result = await store.rollback(event_id)
|
||||
assert result is True
|
||||
|
||||
events = await store.list_events()
|
||||
assert events[0]["status"] == "rolled_back"
|
||||
|
||||
async def test_rollback_nonexistent(self):
|
||||
store = InMemoryEvolutionStore()
|
||||
result = await store.rollback("nonexistent")
|
||||
assert result is False
|
||||
|
||||
async def test_list_events_with_filters(self):
|
||||
store = InMemoryEvolutionStore()
|
||||
await store.record(
|
||||
EvolutionEvent(agent_name="a", change_type="prompt", before={}, after={})
|
||||
)
|
||||
await store.record(
|
||||
EvolutionEvent(agent_name="b", change_type="strategy", before={}, after={})
|
||||
)
|
||||
|
||||
events = await store.list_events(agent_name="a")
|
||||
assert len(events) == 1
|
||||
|
||||
async def test_skill_versions(self):
|
||||
store = InMemoryEvolutionStore()
|
||||
await store.record_skill_version("skill1", "v1", '{"data": 1}')
|
||||
versions = await store.list_skill_versions("skill1")
|
||||
assert len(versions) == 1
|
||||
assert versions[0]["version"] == "v1"
|
||||
|
||||
async def test_ab_test_results(self):
|
||||
store = InMemoryEvolutionStore()
|
||||
await store.record_ab_test_result("t1", "control", 0.8, 5)
|
||||
results = await store.get_ab_test_results("t1")
|
||||
assert len(results) == 1
|
||||
assert results[0]["variant"] == "control"
|
||||
|
||||
|
||||
# ── create_evolution_store factory tests ──────────────────
|
||||
|
||||
|
||||
class TestCreateEvolutionStore:
|
||||
def test_create_memory_backend(self):
|
||||
store = create_evolution_store(backend="memory")
|
||||
assert isinstance(store, InMemoryEvolutionStore)
|
||||
|
||||
def test_create_sqlite_backend(self, tmp_path):
|
||||
db_path = str(tmp_path / "factory_test.db")
|
||||
store = create_evolution_store(backend="sqlite", db_path=db_path)
|
||||
assert isinstance(store, PersistentEvolutionStore)
|
||||
|
||||
def test_create_default_backend(self):
|
||||
store = create_evolution_store()
|
||||
assert isinstance(store, InMemoryEvolutionStore)
|
||||
|
||||
def test_create_sql_backend_without_params_falls_back(self):
|
||||
"""sql backend without session_factory/evolution_model falls back to memory."""
|
||||
store = create_evolution_store(backend="sql")
|
||||
assert isinstance(store, InMemoryEvolutionStore)
|
||||
|
|
@ -0,0 +1,295 @@
"""Tests for LLMReflector - the LLM-driven execution reflector"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.core.trace import ExecutionTrace, TraceStep
|
||||
from agentkit.evolution.llm_reflector import LLMReflector
|
||||
from agentkit.evolution.reflector import Reflection, Reflector, RuleBasedReflector
|
||||
from agentkit.evolution.lifecycle import EvolutionMixin
|
||||
from agentkit.skills.base import EvolutionConfig
|
||||
|
||||
|
||||
# ── Helper functions ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_task() -> TaskMessage:
|
||||
return TaskMessage(
|
||||
task_id="test-001",
|
||||
agent_name="test_agent",
|
||||
task_type="echo",
|
||||
priority=0,
|
||||
input_data={"query": "hello"},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_result(status: str = TaskStatus.COMPLETED) -> TaskResult:
|
||||
return TaskResult(
|
||||
task_id="test-001",
|
||||
agent_name="test_agent",
|
||||
status=status,
|
||||
output_data={"key": "value"},
|
||||
error_message=None,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={"elapsed_seconds": 5.0},
|
||||
)
|
||||
|
||||
|
||||
def _make_trace() -> ExecutionTrace:
|
||||
return ExecutionTrace(
|
||||
task_id="test-001",
|
||||
agent_name="test_agent",
|
||||
steps=[
|
||||
TraceStep(step=1, action="llm_call", tokens_used=100),
|
||||
TraceStep(
|
||||
step=2,
|
||||
action="tool_call",
|
||||
tool_name="search",
|
||||
duration_ms=200,
|
||||
tokens_used=50,
|
||||
),
|
||||
TraceStep(step=3, action="final_answer", tokens_used=80),
|
||||
],
|
||||
total_duration_ms=500,
|
||||
total_tokens=230,
|
||||
outcome="success",
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_gateway(response_content: str) -> MagicMock:
"""Create a mock LLMGateway that returns the given content."""
|
||||
gateway = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = response_content
|
||||
gateway.chat = AsyncMock(return_value=mock_response)
|
||||
return gateway
|
||||
|
||||
|
||||
# ── LLMReflector basics ───────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_reflector_parses_json_in_code_block():
"""LLMReflector builds a Reflection from JSON inside a code block."""
|
||||
json_data = {
|
||||
"outcome": "success",
|
||||
"quality_score": 0.85,
|
||||
"patterns": ["fast_execution"],
|
||||
"insights": ["Task completed efficiently"],
|
||||
"suggestions": ["Consider caching results"],
|
||||
}
|
||||
response = f"```json\n{json.dumps(json_data)}\n```"
|
||||
gateway = _make_mock_gateway(response)
|
||||
reflector = LLMReflector(llm_gateway=gateway, model="test-model")
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
reflection = await reflector.reflect(task, result)
|
||||
|
||||
assert isinstance(reflection, Reflection)
|
||||
assert reflection.outcome == "success"
|
||||
assert reflection.quality_score == 0.85
|
||||
assert reflection.patterns == ["fast_execution"]
|
||||
assert reflection.insights == ["Task completed efficiently"]
|
||||
assert reflection.suggestions == ["Consider caching results"]
|
||||
assert reflection.task_id == "test-001"
|
||||
assert reflection.agent_name == "test_agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_reflector_parses_raw_json():
"""LLMReflector builds a Reflection from a raw JSON response"""
|
||||
json_data = {
|
||||
"outcome": "failure",
|
||||
"quality_score": 0.2,
|
||||
"patterns": ["slow_execution", "error_type:TimeoutError"],
|
||||
"insights": ["Timeout occurred"],
|
||||
"suggestions": ["Increase timeout"],
|
||||
}
|
||||
gateway = _make_mock_gateway(json.dumps(json_data))
|
||||
reflector = LLMReflector(llm_gateway=gateway, model="test-model")
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result(status=TaskStatus.FAILED)
|
||||
reflection = await reflector.reflect(task, result)
|
||||
|
||||
assert reflection.outcome == "failure"
|
||||
assert reflection.quality_score == 0.2
|
||||
assert "slow_execution" in reflection.patterns
|
||||
assert "Increase timeout" in reflection.suggestions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_reflector_handles_unparseable_response():
"""LLMReflector handles an unparseable LLM response (degraded reflection)"""
|
||||
gateway = _make_mock_gateway("This is not JSON at all, just plain text.")
|
||||
reflector = LLMReflector(llm_gateway=gateway, model="test-model")
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
reflection = await reflector.reflect(task, result)
|
||||
|
||||
assert isinstance(reflection, Reflection)
|
||||
assert reflection.outcome == "partial"
|
||||
assert reflection.quality_score == 0.5
|
||||
assert "LLM response could not be parsed as structured reflection" in reflection.insights
|
||||
assert "Review LLM output format" in reflection.suggestions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_reflector_handles_llm_call_failure():
"""LLMReflector handles an LLM call failure (returns a failure reflection)"""
|
||||
gateway = MagicMock()
|
||||
gateway.chat = AsyncMock(side_effect=Exception("LLM service unavailable"))
|
||||
reflector = LLMReflector(llm_gateway=gateway, model="test-model")
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
reflection = await reflector.reflect(task, result)
|
||||
|
||||
assert isinstance(reflection, Reflection)
|
||||
assert reflection.outcome == "failure"
|
||||
assert reflection.quality_score == 0.0
|
||||
assert any("LLM reflection failed" in i for i in reflection.insights)
|
||||
assert "Consider using rule-based reflector as fallback" in reflection.suggestions
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_reflector_uses_execution_trace():
"""LLMReflector uses ExecutionTrace information"""
|
||||
gateway = _make_mock_gateway('{"outcome": "success", "quality_score": 0.9}')
|
||||
reflector = LLMReflector(llm_gateway=gateway, model="test-model")
|
||||
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
trace = _make_trace()
|
||||
reflection = await reflector.reflect(task, result, trace=trace)
|
||||
|
||||
# Verify the LLM was called and that the prompt contains the trace information
|
||||
call_args = gateway.chat.call_args
|
||||
prompt = call_args.kwargs["messages"][0]["content"]
|
||||
assert "Total Steps: 3" in prompt
|
||||
assert "Total Duration: 500ms" in prompt
|
||||
assert "Total Tokens: 230" in prompt
|
||||
assert "Tool: search" in prompt
|
||||
assert reflection.outcome == "success"
|
||||
|
||||
|
||||
# ── Auto mode ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auto_mode_with_llm_available():
"""Auto mode: uses LLMReflector when an LLM is available"""
|
||||
gateway = MagicMock()
|
||||
mixin = EvolutionMixin(reflector_type="auto", llm_gateway=gateway)
|
||||
assert isinstance(mixin._reflector, LLMReflector)
|
||||
|
||||
|
||||
def test_auto_mode_without_llm_falls_back():
"""Auto mode: falls back to RuleBasedReflector when no LLM is available"""
|
||||
mixin = EvolutionMixin(reflector_type="auto", llm_gateway=None)
|
||||
assert isinstance(mixin._reflector, RuleBasedReflector)
|
||||
|
||||
|
||||
def test_rule_mode_always_uses_rule_based():
"""Rule mode: always uses RuleBasedReflector"""
|
||||
gateway = MagicMock()
|
||||
mixin = EvolutionMixin(reflector_type="rule", llm_gateway=gateway)
|
||||
assert isinstance(mixin._reflector, RuleBasedReflector)
|
||||
|
||||
|
||||
def test_llm_mode_without_gateway_falls_back():
"""LLM mode: falls back to RuleBasedReflector when there is no gateway"""
|
||||
mixin = EvolutionMixin(reflector_type="llm", llm_gateway=None)
|
||||
assert isinstance(mixin._reflector, RuleBasedReflector)
|
||||
|
||||
|
||||
def test_llm_mode_with_gateway():
"""LLM mode: uses LLMReflector when a gateway is provided"""
|
||||
gateway = MagicMock()
|
||||
mixin = EvolutionMixin(reflector_type="llm", llm_gateway=gateway)
|
||||
assert isinstance(mixin._reflector, LLMReflector)
|
||||
|
||||
|
||||
def test_explicit_reflector_overrides_type():
"""An explicitly passed reflector overrides reflector_type"""
|
||||
gateway = MagicMock()
|
||||
rule_reflector = RuleBasedReflector()
|
||||
mixin = EvolutionMixin(
|
||||
reflector=rule_reflector,
|
||||
reflector_type="llm",
|
||||
llm_gateway=gateway,
|
||||
)
|
||||
assert mixin._reflector is rule_reflector
|
||||
|
||||
|
||||
def test_auxiliary_model_passed_to_llm_reflector():
"""auxiliary_model is passed through to LLMReflector"""
|
||||
gateway = MagicMock()
|
||||
mixin = EvolutionMixin(
|
||||
reflector_type="llm",
|
||||
llm_gateway=gateway,
|
||||
auxiliary_model="gpt-4o-mini",
|
||||
)
|
||||
assert isinstance(mixin._reflector, LLMReflector)
|
||||
assert mixin._reflector._model == "gpt-4o-mini"
|
||||
|
||||
|
||||
def test_no_reflector_type_defaults_to_none():
"""When reflector_type is not specified, the reflector is None (backward compatible)"""
|
||||
mixin = EvolutionMixin()
|
||||
assert mixin._reflector is None
|
||||
|
||||
|
||||
# ── EvolutionConfig new fields ──────────────────────────────────
|
||||
|
||||
|
||||
def test_evolution_config_default_values():
"""EvolutionConfig default values"""
|
||||
config = EvolutionConfig()
|
||||
assert config.reflector_type == "auto"
|
||||
assert config.auxiliary_model is None
|
||||
|
||||
|
||||
def test_evolution_config_custom_values():
"""EvolutionConfig custom values"""
|
||||
config = EvolutionConfig(
|
||||
enabled=True,
|
||||
reflector_type="llm",
|
||||
auxiliary_model="gpt-4o-mini",
|
||||
)
|
||||
assert config.reflector_type == "llm"
|
||||
assert config.auxiliary_model == "gpt-4o-mini"
|
||||
|
||||
|
||||
# ── Backward compatibility ─────────────────────────────────────
|
||||
|
||||
|
||||
def test_reflector_alias_still_works():
"""The Reflector alias still works"""
|
||||
assert Reflector is RuleBasedReflector
|
||||
reflector = Reflector()
|
||||
assert isinstance(reflector, RuleBasedReflector)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflector_alias_produces_same_reflection():
"""The Reflector alias produces the same result as RuleBasedReflector"""
|
||||
task = _make_task()
|
||||
result = _make_result()
|
||||
|
||||
r1 = Reflector()
|
||||
r2 = RuleBasedReflector()
|
||||
|
||||
reflection1 = await r1.reflect(task, result)
|
||||
reflection2 = await r2.reflect(task, result)
|
||||
|
||||
assert reflection1.outcome == reflection2.outcome
|
||||
assert reflection1.quality_score == reflection2.quality_score
|
||||
|
|
@@ -0,0 +1,432 @@
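"""U4: Wiring memory into the agent loop - integration tests

Tests the full flow of injecting MemoryRetriever into ReActEngine:
1. Before execution, relevant context is retrieved and injected into system_prompt
2. After execution, a trajectory summary is written to EpisodicMemory
3. A memory retrieval failure does not interrupt task execution
4. ConfigDrivenAgent automatically creates a MemoryRetriever from config.memory
5. The BaseAgent.use_memory_retriever() method
"""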
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.react import ReActEngine, ReActResult
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
|
||||
|
||||
# ── Test Helpers ──────────────────────────────────────────
|
||||
|
||||
|
||||
def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
"""Create a mock LLMGateway that returns the given responses in order"""
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(side_effect=responses)
|
||||
return gateway
|
||||
|
||||
|
||||
def make_response(
|
||||
content: str = "",
|
||||
prompt_tokens: int = 10,
|
||||
completion_tokens: int = 20,
|
||||
) -> LLMResponse:
"""Quickly build an LLMResponse"""
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model="test-model",
|
||||
usage=TokenUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
),
|
||||
tool_calls=[],
|
||||
)
|
||||
|
||||
|
||||
def make_mock_memory_retriever(context_string: str = "past experience data"):
"""Create a mock MemoryRetriever"""
|
||||
retriever = MagicMock()
|
||||
retriever.get_context_string = AsyncMock(return_value=context_string)
|
||||
retriever._episodic = None
|
||||
return retriever
|
||||
|
||||
|
||||
def make_mock_episodic_memory():
"""Create a mock EpisodicMemory"""
|
||||
episodic = MagicMock()
|
||||
episodic.store = AsyncMock()
|
||||
return episodic
|
||||
|
||||
|
||||
# ── Test: Memory context injected into system_prompt ──────────
|
||||
|
||||
|
||||
class TestMemoryContextInjection:
"""Tests for injecting memory context into system_prompt"""
|
||||
|
||||
async def test_memory_context_appended_to_existing_system_prompt(self):
"""When a system_prompt is given, the memory context is appended to its end"""
|
||||
gateway = make_mock_gateway([make_response(content="final answer")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
retriever = make_mock_memory_retriever("Previous task result: success")
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Do something"}],
|
||||
system_prompt="You are a helpful assistant.",
|
||||
memory_retriever=retriever,
|
||||
)
|
||||
|
||||
assert isinstance(result, ReActResult)
|
||||
retriever.get_context_string.assert_awaited_once_with(
|
||||
query="Do something",
|
||||
top_k=5,
|
||||
token_budget=2000,
|
||||
)
|
||||
|
||||
# Verify system_prompt was augmented with memory context
|
||||
call_args = gateway.chat.call_args
|
||||
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||
# The first message should be system with appended context
|
||||
system_msg = messages_sent[0]
|
||||
assert system_msg["role"] == "system"
|
||||
assert "You are a helpful assistant." in system_msg["content"]
|
||||
assert "Relevant Past Experience" in system_msg["content"]
|
||||
assert "Previous task result: success" in system_msg["content"]
|
||||
|
||||
async def test_memory_context_used_as_system_prompt_when_none(self):
"""When there is no system_prompt, the memory context is used as the system_prompt"""
|
||||
gateway = make_mock_gateway([make_response(content="final answer")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
retriever = make_mock_memory_retriever("Past context only")
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
memory_retriever=retriever,
|
||||
)
|
||||
|
||||
assert isinstance(result, ReActResult)
|
||||
call_args = gateway.chat.call_args
|
||||
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||
system_msg = messages_sent[0]
|
||||
assert system_msg["role"] == "system"
|
||||
assert "Relevant Past Experience" in system_msg["content"]
|
||||
assert "Past context only" in system_msg["content"]
|
||||
|
||||
async def test_no_memory_context_when_retriever_is_none(self):
"""When memory_retriever is None, no memory context is injected"""
|
||||
gateway = make_mock_gateway([make_response(content="final answer")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
system_prompt="You are a helper.",
|
||||
memory_retriever=None,
|
||||
)
|
||||
|
||||
assert isinstance(result, ReActResult)
|
||||
call_args = gateway.chat.call_args
|
||||
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||
system_msg = messages_sent[0]
|
||||
assert system_msg["content"] == "You are a helper."
|
||||
assert "Relevant Past Experience" not in system_msg["content"]
|
||||
|
||||
async def test_empty_memory_context_not_injected(self):
"""When the memory context is an empty string, it is not injected"""
|
||||
gateway = make_mock_gateway([make_response(content="final answer")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
retriever = make_mock_memory_retriever(context_string="")
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
system_prompt="You are a helper.",
|
||||
memory_retriever=retriever,
|
||||
)
|
||||
|
||||
assert isinstance(result, ReActResult)
|
||||
call_args = gateway.chat.call_args
|
||||
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||
system_msg = messages_sent[0]
|
||||
assert system_msg["content"] == "You are a helper."
|
||||
assert "Relevant Past Experience" not in system_msg["content"]
|
||||
|
||||
|
||||
# ── Test: Memory retrieval failure doesn't break execution ──────────
|
||||
|
||||
|
||||
class TestMemoryRetrievalFailure:
"""A memory retrieval failure does not interrupt task execution"""
|
||||
|
||||
async def test_retrieval_failure_continues_without_context(self):
"""When memory retrieval raises, the task still executes normally"""
|
||||
gateway = make_mock_gateway([make_response(content="still works")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
retriever = make_mock_memory_retriever()
|
||||
retriever.get_context_string = AsyncMock(side_effect=RuntimeError("Redis down"))
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
system_prompt="You are a helper.",
|
||||
memory_retriever=retriever,
|
||||
)
|
||||
|
||||
# Task should still complete
|
||||
assert isinstance(result, ReActResult)
|
||||
assert result.output == "still works"
|
||||
|
||||
# system_prompt should NOT have memory context
|
||||
call_args = gateway.chat.call_args
|
||||
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||
system_msg = messages_sent[0]
|
||||
assert "Relevant Past Experience" not in system_msg["content"]
|
||||
|
||||
|
||||
# ── Test: Task result stored in episodic memory ──────────
|
||||
|
||||
|
||||
class TestEpisodicMemoryStorage:
"""After execution, a trajectory summary is written to EpisodicMemory"""
|
||||
|
||||
async def test_result_stored_in_episodic_memory(self):
"""After the task completes, the result summary is stored in EpisodicMemory"""
|
||||
gateway = make_mock_gateway([make_response(content="The answer is 42")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
episodic = make_mock_episodic_memory()
|
||||
retriever = make_mock_memory_retriever(context_string="")
|
||||
retriever._episodic = episodic
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "What is the answer?"}],
|
||||
memory_retriever=retriever,
|
||||
task_id="task-123",
|
||||
agent_name="test-agent",
|
||||
task_type="qa",
|
||||
)
|
||||
|
||||
assert isinstance(result, ReActResult)
|
||||
episodic.store.assert_awaited_once()
|
||||
call_kwargs = episodic.store.call_args
|
||||
assert call_kwargs.kwargs.get("key") == "task:task-123" or call_kwargs[1].get("key") == "task:task-123"
|
||||
# Verify metadata
|
||||
metadata = call_kwargs.kwargs.get("metadata") or call_kwargs[1].get("metadata")
|
||||
assert metadata["task_type"] == "qa"
|
||||
assert metadata["outcome"] == "success"
|
||||
|
||||
async def test_no_storage_when_no_episodic_memory(self):
"""No storage is attempted when there is no EpisodicMemory"""
|
||||
gateway = make_mock_gateway([make_response(content="done")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
retriever = make_mock_memory_retriever(context_string="")
|
||||
retriever._episodic = None
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
memory_retriever=retriever,
|
||||
)
|
||||
|
||||
assert isinstance(result, ReActResult)
|
||||
# No exception raised, no store called
|
||||
|
||||
async def test_storage_failure_doesnt_break_execution(self):
"""An EpisodicMemory storage failure does not interrupt the task"""
|
||||
gateway = make_mock_gateway([make_response(content="done")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
episodic = make_mock_episodic_memory()
|
||||
episodic.store = AsyncMock(side_effect=RuntimeError("DB down"))
|
||||
|
||||
retriever = make_mock_memory_retriever(context_string="")
|
||||
retriever._episodic = episodic
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
memory_retriever=retriever,
|
||||
)
|
||||
|
||||
# Task should still complete
|
||||
assert isinstance(result, ReActResult)
|
||||
assert result.output == "done"
|
||||
|
||||
|
||||
# ── Test: execute_stream with memory ──────────
|
||||
|
||||
|
||||
class TestMemoryInStreamMode:
"""Memory integration in execute_stream mode"""
|
||||
|
||||
async def test_stream_injects_memory_context(self):
"""execute_stream also injects the memory context"""
|
||||
gateway = make_mock_gateway([make_response(content="streamed answer")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
retriever = make_mock_memory_retriever("Stream context")
|
||||
|
||||
events = []
|
||||
async for event in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
system_prompt="You are a helper.",
|
||||
memory_retriever=retriever,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Should have events
|
||||
assert len(events) > 0
|
||||
retriever.get_context_string.assert_awaited_once()
|
||||
|
||||
async def test_stream_stores_to_episodic(self):
"""execute_stream also stores to EpisodicMemory on completion"""
|
||||
gateway = make_mock_gateway([make_response(content="streamed answer")])
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
episodic = make_mock_episodic_memory()
|
||||
retriever = make_mock_memory_retriever(context_string="")
|
||||
retriever._episodic = episodic
|
||||
|
||||
events = []
|
||||
async for event in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
memory_retriever=retriever,
|
||||
task_id="stream-task-1",
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
episodic.store.assert_awaited_once()
|
||||
|
||||
|
||||
# ── Test: BaseAgent.use_memory_retriever() ──────────
|
||||
|
||||
|
||||
class TestBaseAgentMemoryRetriever:
"""Tests for the BaseAgent.use_memory_retriever() method"""
|
||||
|
||||
def test_use_memory_retriever_sets_field(self):
"""use_memory_retriever() sets _memory_retriever correctly"""
|
||||
from agentkit.core.base import BaseAgent
|
||||
|
||||
# Create a concrete subclass for testing
|
||||
class TestAgent(BaseAgent):
|
||||
async def handle_task(self, task):
|
||||
return {}
|
||||
|
||||
def get_capabilities(self):
|
||||
from agentkit.core.protocol import AgentCapability
|
||||
return AgentCapability(
|
||||
agent_name=self.name,
|
||||
agent_type=self.agent_type,
|
||||
version=self.version,
|
||||
)
|
||||
|
||||
agent = TestAgent(name="test", agent_type="test")
|
||||
mock_retriever = MagicMock()
|
||||
|
||||
result = agent.use_memory_retriever(mock_retriever)
|
||||
|
||||
# Should return self for chaining
|
||||
assert result is agent
|
||||
assert agent._memory_retriever is mock_retriever
|
||||
|
||||
def test_memory_retriever_default_is_none(self):
"""_memory_retriever defaults to None"""
|
||||
from agentkit.core.base import BaseAgent
|
||||
|
||||
class TestAgent(BaseAgent):
|
||||
async def handle_task(self, task):
|
||||
return {}
|
||||
|
||||
def get_capabilities(self):
|
||||
from agentkit.core.protocol import AgentCapability
|
||||
return AgentCapability(
|
||||
agent_name=self.name,
|
||||
agent_type=self.agent_type,
|
||||
version=self.version,
|
||||
)
|
||||
|
||||
agent = TestAgent(name="test", agent_type="test")
|
||||
assert agent._memory_retriever is None
|
||||
|
||||
|
||||
# ── Test: ConfigDrivenAgent memory integration ──────────
|
||||
|
||||
|
||||
class TestConfigDrivenAgentMemory:
"""ConfigDrivenAgent automatically creates a MemoryRetriever from config.memory"""
|
||||
|
||||
def test_memory_retriever_created_from_config(self):
"""A MemoryRetriever is created automatically when config.memory is set"""
|
||||
from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig
|
||||
|
||||
config = AgentConfig(
|
||||
name="test-agent",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "Test agent"},
|
||||
memory={
|
||||
"working": {"enabled": False},
|
||||
"episodic": {"enabled": False},
|
||||
},
|
||||
)
|
||||
|
||||
with patch("agentkit.core.config_driven.MemoryRetriever", create=True):
|
||||
agent = ConfigDrivenAgent(config=config)
|
||||
# MemoryRetriever should have been created (with no backends since both disabled)
|
||||
assert agent._memory_retriever is not None
|
||||
|
||||
@staticmethod
|
||||
def _patch_memory_imports():
|
||||
"""Helper to handle import patching"""
|
||||
from unittest.mock import patch
|
||||
return patch("agentkit.memory.retriever.MemoryRetriever")
|
||||
|
||||
def test_no_memory_retriever_when_no_config(self):
"""No MemoryRetriever is created when config.memory is absent"""
|
||||
from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig
|
||||
|
||||
config = AgentConfig(
|
||||
name="test-agent",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "Test agent"},
|
||||
)
|
||||
|
||||
agent = ConfigDrivenAgent(config=config)
|
||||
assert agent._memory_retriever is None
|
||||
|
||||
def test_memory_retriever_created_with_empty_memory_dict(self):
"""When config.memory is an empty dict, no MemoryRetriever is created (an empty dict is falsy)"""
|
||||
from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig
|
||||
|
||||
config = AgentConfig(
|
||||
name="test-agent",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "Test agent"},
|
||||
memory={},
|
||||
)
|
||||
|
||||
agent = ConfigDrivenAgent(config=config)
|
||||
# Empty dict is falsy, so no retriever
|
||||
assert agent._memory_retriever is None
|
||||
|
||||
def test_memory_retriever_failure_graceful(self):
"""A memory initialization failure degrades gracefully"""
|
||||
from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig
|
||||
|
||||
config = AgentConfig(
|
||||
name="test-agent",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "Test agent"},
|
||||
memory={"working": {"enabled": True, "redis_url": "redis://nonexistent:6379"}},
|
||||
)
|
||||
|
||||
# Should not raise, just log warning and set _memory_retriever to None
|
||||
agent = ConfigDrivenAgent(config=config)
|
||||
# Either retriever was created or gracefully failed
|
||||
# The key is that no exception is raised
|
||||
|
|
@@ -0,0 +1,308 @@
|
|||
"""Unit tests for observability features: structured logging, metrics, health check"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.core.logging import StructuredFormatter, setup_structured_logging, get_logger
|
||||
from agentkit.core.protocol import TaskStatus
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
from agentkit.server.app import create_app
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
# ── Structured Logging Tests ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStructuredFormatter:
|
||||
"""StructuredFormatter outputs valid JSON with required fields"""
|
||||
|
||||
def test_outputs_valid_json(self):
|
||||
formatter = StructuredFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="agentkit.test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="hello world",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert "timestamp" in data
|
||||
assert data["level"] == "INFO"
|
||||
assert data["logger"] == "agentkit.test"
|
||||
assert data["message"] == "hello world"
|
||||
|
||||
def test_includes_extra_fields(self):
|
||||
formatter = StructuredFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="agentkit.test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="with extras",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
record.trace_id = "abc-123"
|
||||
record.agent_name = "my_agent"
|
||||
record.skill_name = "my_skill"
|
||||
record.task_id = "task-456"
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert data["trace_id"] == "abc-123"
|
||||
assert data["agent_name"] == "my_agent"
|
||||
assert data["skill_name"] == "my_skill"
|
||||
assert data["task_id"] == "task-456"
|
||||
|
||||
def test_omits_empty_extra_fields(self):
|
||||
formatter = StructuredFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="agentkit.test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="no extras",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert "trace_id" not in data
|
||||
assert "agent_name" not in data
|
||||
|
||||
def test_includes_exception_info(self):
|
||||
formatter = StructuredFormatter()
|
||||
try:
|
||||
raise ValueError("test error")
|
||||
except ValueError:
|
||||
import sys
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
record = logging.LogRecord(
|
||||
name="agentkit.test",
|
||||
level=logging.ERROR,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="error occurred",
|
||||
args=(),
|
||||
exc_info=exc_info,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert "exception" in data
|
||||
assert "ValueError" in data["exception"]
|
||||
assert "test error" in data["exception"]
|
||||
|
||||
def test_unicode_message(self):
|
||||
formatter = StructuredFormatter()
|
||||
record = logging.LogRecord(
|
||||
name="agentkit.test",
|
||||
level=logging.INFO,
|
||||
pathname="test.py",
|
||||
lineno=1,
|
||||
msg="中文日志消息",
|
||||
args=(),
|
||||
exc_info=None,
|
||||
)
|
||||
output = formatter.format(record)
|
||||
data = json.loads(output)
|
||||
assert data["message"] == "中文日志消息"
|
||||
|
||||
|
||||
class TestSetupStructuredLogging:
|
||||
"""setup_structured_logging() configures agentkit logger"""
|
||||
|
||||
def test_configures_agentkit_logger(self):
|
||||
setup_structured_logging(level=logging.DEBUG)
|
||||
logger = logging.getLogger("agentkit")
|
||||
assert logger.level == logging.DEBUG
|
||||
assert len(logger.handlers) == 1
|
||||
handler = logger.handlers[0]
|
||||
assert isinstance(handler.formatter, StructuredFormatter)
|
||||
|
||||
def test_clears_existing_handlers(self):
|
||||
logger = logging.getLogger("agentkit")
|
||||
logger.addHandler(logging.StreamHandler())
|
||||
initial_count = len(logger.handlers)
|
||||
|
||||
setup_structured_logging()
|
||||
assert len(logger.handlers) == 1
|
||||
assert len(logger.handlers) < initial_count + 1
|
||||
|
||||
|
||||
class TestGetLogger:
|
||||
"""get_logger() creates logger with extra fields"""
|
||||
|
||||
def test_returns_logger_adapter(self):
|
||||
adapter = get_logger("my_module")
|
||||
assert isinstance(adapter, logging.LoggerAdapter)
|
||||
assert adapter.logger.name == "agentkit.my_module"
|
||||
|
||||
def test_extra_fields_in_adapter(self):
|
||||
adapter = get_logger("test", trace_id="t-1", agent_name="a-1")
|
||||
assert adapter.extra["trace_id"] == "t-1"
|
||||
assert adapter.extra["agent_name"] == "a-1"
|
||||
|
||||
|
||||
# ── Metrics Endpoint Tests ─────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_gateway():
|
||||
gateway = LLMGateway()
|
||||
mock_provider = AsyncMock()
|
||||
mock_provider.chat.return_value = LLMResponse(
|
||||
content='{"result": "mocked"}',
|
||||
model="test-model",
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
|
||||
)
|
||||
gateway.register_provider("test", mock_provider)
|
||||
return gateway
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def skill_registry():
|
||||
return SkillRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_registry():
|
||||
return ToolRegistry()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(mock_llm_gateway, skill_registry, tool_registry):
|
||||
return create_app(
|
||||
llm_gateway=mock_llm_gateway,
|
||||
skill_registry=skill_registry,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestMetricsEndpoint:
|
||||
"""GET /api/v1/metrics"""
|
||||
|
||||
def test_metrics_returns_200(self, client):
|
||||
response = client.get("/api/v1/metrics")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_metrics_has_required_sections(self, client):
|
||||
response = client.get("/api/v1/metrics")
|
||||
data = response.json()
|
||||
assert "tasks" in data
|
||||
assert "agents" in data
|
||||
assert "skills" in data
|
||||
assert "version" in data
|
||||
|
||||
def test_metrics_zero_values_when_empty(self, client):
|
||||
response = client.get("/api/v1/metrics")
|
||||
data = response.json()
|
||||
assert data["tasks"]["total_tasks"] == 0
|
||||
assert data["tasks"]["completed_tasks"] == 0
|
||||
assert data["tasks"]["failed_tasks"] == 0
|
||||
assert data["tasks"]["pending_tasks"] == 0
|
||||
assert data["agents"]["total_agents"] == 0
|
||||
assert data["agents"]["agent_names"] == []
|
||||
assert data["skills"]["total_skills"] == 0
|
||||
assert data["skills"]["skill_names"] == []
|
||||
|
||||
def test_metrics_with_registered_skill(self, client, skill_registry):
|
||||
skill_config = SkillConfig(
|
||||
name="metrics_skill",
|
||||
agent_type="test_type",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "Metrics Skill"},
|
||||
intent={"keywords": ["metrics"], "description": "A metrics skill"},
|
||||
)
|
||||
skill = Skill(config=skill_config)
|
||||
skill_registry.register(skill)
|
||||
|
||||
response = client.get("/api/v1/metrics")
|
||||
data = response.json()
|
||||
assert data["skills"]["total_skills"] == 1
|
||||
assert "metrics_skill" in data["skills"]["skill_names"]
|
||||
|
||||
def test_metrics_version(self, client):
|
||||
response = client.get("/api/v1/metrics")
|
||||
data = response.json()
|
||||
assert data["version"] == "2.0.0"
|
||||
|
||||
|
||||
# ── Enhanced Health Check Tests ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestEnhancedHealthCheck:
|
||||
"""GET /api/v1/health — enhanced with dependency checks"""
|
||||
|
||||
def test_health_returns_200(self, client):
|
||||
response = client.get("/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_health_includes_checks(self, client):
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
assert "checks" in data
|
||||
assert "redis" in data["checks"]
|
||||
assert "agent_pool" in data["checks"]
|
||||
assert "llm_gateway" in data["checks"]
|
||||
assert "skill_registry" in data["checks"]
|
||||
|
||||
def test_health_healthy_with_provider(self, client):
|
||||
"""With a registered LLM provider, status should be healthy"""
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert data["version"] == "2.0.0"
|
||||
|
||||
def test_health_agent_pool_info(self, client):
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
pool_check = data["checks"]["agent_pool"]
|
||||
assert pool_check["status"] == "available"
|
||||
assert pool_check["size"] == 0
|
||||
|
||||
def test_health_skill_registry_info(self, client):
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
registry_check = data["checks"]["skill_registry"]
|
||||
assert registry_check["status"] == "available"
|
||||
assert registry_check["count"] == 0
|
||||
|
||||
def test_health_degraded_without_providers(self, skill_registry, tool_registry):
|
||||
"""Without LLM providers, status should be degraded"""
|
||||
gateway = LLMGateway() # No providers registered
|
||||
app = create_app(
|
||||
llm_gateway=gateway,
|
||||
skill_registry=skill_registry,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
client = TestClient(app)
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
assert data["status"] == "degraded"
|
||||
assert data["checks"]["llm_gateway"] == "no_providers"
|
||||
|
||||
def test_health_redis_not_configured_for_memory_store(self, client):
|
||||
"""In-memory task store should report redis as not_configured"""
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
assert data["checks"]["redis"] == "not_configured"
|
||||
|
||||
def test_health_llm_gateway_available_with_provider(self, client):
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
assert data["checks"]["llm_gateway"] == "available"
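
Editorial sketch, not part of this commit: the health tests above imply a report shape in which a gateway without providers yields "degraded" while an unconfigured Redis backend does not. One way that aggregation might look is shown below. The function name `build_health_report` and its parameters are hypothetical, and the rule that only the LLM gateway affects the overall status is an assumption read off the assertions, not the actual route handler.

```python
def build_health_report(provider_count, pool_size, skill_count,
                        redis_configured=False, version="2.0.0"):
    """Assemble a health payload shaped like the one the tests assert on.

    Hypothetical helper: the real endpoint likely probes live dependencies
    instead of receiving counts as arguments.
    """
    checks = {
        # Assumption: an in-memory task store reports "not_configured".
        "redis": "available" if redis_configured else "not_configured",
        "agent_pool": {"status": "available", "size": pool_size},
        "llm_gateway": "available" if provider_count > 0 else "no_providers",
        "skill_registry": {"status": "available", "count": skill_count},
    }
    # Assumption: only a missing LLM provider degrades the overall status.
    status = "healthy" if checks["llm_gateway"] == "available" else "degraded"
    return {"status": status, "version": version, "checks": checks}
```

Under these assumptions, `build_health_report(0, 0, 0)` corresponds to the degraded case in `test_health_degraded_without_providers`.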
|
||||
|
|
@ -0,0 +1,396 @@
|
|||
"""Tests for ReAct Prompt, Skill/Agent tool sync, MCP bridge, and execution modes"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
def _make_skill_config(execution_mode="react", **kwargs) -> SkillConfig:
|
||||
"""Helper to create a SkillConfig for testing."""
|
||||
defaults = {
|
||||
"name": "test_skill",
|
||||
"agent_type": "test",
|
||||
"task_mode": "llm_generate",
|
||||
"supported_tasks": ["test_task"],
|
||||
"prompt": {
|
||||
"identity": "You are a test assistant.",
|
||||
"context": "Context: ${topic}",
|
||||
"instructions": "Please help with: ${query}",
|
||||
"constraints": "Be concise.",
|
||||
"output_format": "Return JSON.",
|
||||
"examples": "Example: input -> output",
|
||||
},
|
||||
"execution_mode": execution_mode,
|
||||
"max_steps": 3,
|
||||
"intent": {
|
||||
"keywords": ["test", "demo"],
|
||||
"description": "A test skill",
|
||||
},
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SkillConfig.from_dict(defaults)
|
||||
|
||||
|
||||
def _make_task(**kwargs) -> TaskMessage:
|
||||
"""Helper to create a TaskMessage for testing."""
|
||||
defaults = {
|
||||
"task_id": "test-task-1",
|
||||
"agent_name": "test_skill",
|
||||
"task_type": "test_task",
|
||||
"priority": 0,
|
||||
"input_data": {"topic": "AI", "query": "What is AI?"},
|
||||
"callback_url": None,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return TaskMessage(**defaults)
|
||||
|
||||
|
||||
class TestReActPromptFullRendering:
|
||||
"""Test that ReAct mode uses full PromptTemplate.render() output."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_uses_full_prompt_template(self):
|
||||
"""ReAct mode should use PromptTemplate.render() to get all prompt sections,
|
||||
not just identity.
|
||||
|
||||
In ReAct mode, _handle_react() passes system_prompt to ReActEngine.execute(),
|
||||
which prepends it as a system message in the conversation passed to gateway.chat().
|
||||
So we check the 'messages' kwarg for a system message containing all sections.
|
||||
"""
|
||||
config = _make_skill_config(execution_mode="react")
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
# Mock LLMGateway to capture what messages are sent
|
||||
mock_gateway = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps({"answer": "test"})
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.total_tokens = 10
|
||||
mock_response.has_tool_calls = False
|
||||
mock_gateway.chat = AsyncMock(return_value=mock_response)
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
llm_gateway=mock_gateway,
|
||||
)
|
||||
|
||||
task = _make_task()
|
||||
await agent.handle_task(task)
|
||||
|
||||
# Verify the gateway was called
|
||||
mock_gateway.chat.assert_called_once()
|
||||
call_kwargs = mock_gateway.chat.call_args
|
||||
|
||||
# ReActEngine.execute() puts system_prompt as the first message in conversation
|
||||
messages = call_kwargs.kwargs.get("messages", [])
|
||||
assert len(messages) > 0, "No messages sent to gateway"
|
||||
|
||||
# First message should be the system message with all prompt sections
|
||||
system_msg = messages[0]
|
||||
assert system_msg["role"] == "system", f"First message is not system: {system_msg['role']}"
|
||||
system_content = system_msg["content"]
|
||||
assert "test assistant" in system_content, f"Identity missing from system message: {system_content}"
|
||||
assert "AI" in system_content, f"Context variable not resolved in system message: {system_content}"
|
||||
assert "concise" in system_content, f"Constraints missing from system message: {system_content}"
|
||||
|
||||
# Check that user messages contain instructions + output_format + examples
|
||||
user_content = " ".join(m.get("content", "") for m in messages if m["role"] != "system")
|
||||
assert "What is AI?" in user_content, f"Instructions variable not resolved: {user_content}"
|
||||
assert "JSON" in user_content, f"Output format missing: {user_content}"
|
||||
assert "Example" in user_content, f"Examples missing: {user_content}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_without_prompt_template(self):
|
||||
"""ReAct mode without prompt template should use input_data as fallback."""
|
||||
config = SkillConfig(
|
||||
name="no_prompt_skill",
|
||||
agent_type="test",
|
||||
task_mode="tool_call",
|
||||
supported_tasks=["test"],
|
||||
execution_mode="react",
|
||||
tools=["mock_tool"],
|
||||
)
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
async def mock_func(**kwargs):
|
||||
return {"mock": True}
|
||||
|
||||
tool_registry.register(FunctionTool(name="mock_tool", description="Mock", func=mock_func))
|
||||
|
||||
mock_gateway = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = '{"result": "ok"}'
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.total_tokens = 5
|
||||
mock_response.has_tool_calls = False
|
||||
mock_gateway.chat = AsyncMock(return_value=mock_response)
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
llm_gateway=mock_gateway,
|
||||
)
|
||||
|
||||
task = _make_task(input_data={"message": "hello"})
|
||||
result = await agent.handle_task(task)
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
class TestSkillAgentToolSync:
|
||||
"""Test that Skill-bound tools are merged into Agent._tools."""
|
||||
|
||||
def test_skill_tools_merged_into_agent(self):
|
||||
"""When ConfigDrivenAgent receives a SkillConfig with tools,
|
||||
the Skill's bound tools should be merged into Agent._tools."""
|
||||
config = _make_skill_config(
|
||||
execution_mode="react",
|
||||
tools=["tool_a", "tool_b"],
|
||||
)
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
async def mock_func(**kwargs):
|
||||
return {"mock": True}
|
||||
|
||||
tool_registry.register(FunctionTool(name="tool_a", description="Tool A", func=mock_func))
|
||||
tool_registry.register(FunctionTool(name="tool_b", description="Tool B", func=mock_func))
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
|
||||
# Agent should have both tools from the config
|
||||
tool_names = [t.name for t in agent._tools]
|
||||
assert "tool_a" in tool_names, f"tool_a not found in agent tools: {tool_names}"
|
||||
assert "tool_b" in tool_names, f"tool_b not found in agent tools: {tool_names}"
|
||||
|
||||
def test_skill_instance_tools_merged(self):
|
||||
"""When a Skill instance has tools bound via bind_tool(),
|
||||
those tools should be merged into Agent._tools."""
|
||||
config = _make_skill_config(execution_mode="react")
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
|
||||
# Manually bind a tool to the skill instance
|
||||
async def extra_func(**kwargs):
|
||||
return {"extra": True}
|
||||
|
||||
extra_tool = FunctionTool(name="extra_tool", description="Extra", func=extra_func)
|
||||
agent._skill_instance.bind_tool(extra_tool)
|
||||
|
||||
# Simulate re-creating agent (in real flow, tools are merged during __init__)
|
||||
# For this test, verify the merge logic works
|
||||
initial_count = len(agent._tools)
|
||||
for tool in agent._skill_instance.tools:
|
||||
if not any(t.name == tool.name for t in agent._tools):
|
||||
agent.use_tool(tool)
|
||||
|
||||
tool_names = [t.name for t in agent._tools]
|
||||
assert "extra_tool" in tool_names
|
||||
assert len(agent._tools) == initial_count + 1
|
||||
|
||||
|
||||
class TestMCPBridge:
|
||||
"""Test MCP → ReAct bridge."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_servers_parameter_accepted(self):
|
||||
"""ConfigDrivenAgent should accept mcp_servers parameter."""
|
||||
config = _make_skill_config(execution_mode="react")
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
mcp_servers={"test_server": "http://localhost:8080"},
|
||||
)
|
||||
|
||||
assert agent._mcp_servers == {"test_server": "http://localhost:8080"}
|
||||
assert agent._mcp_tools_registered is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_lazy_registration_on_task(self):
|
||||
"""MCP tools should be lazily registered on first task execution."""
|
||||
config = _make_skill_config(execution_mode="react")
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
mock_gateway = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = '{"result": "ok"}'
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.total_tokens = 5
|
||||
mock_response.has_tool_calls = False
|
||||
mock_gateway.chat = AsyncMock(return_value=mock_response)
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
llm_gateway=mock_gateway,
|
||||
mcp_servers={"test_server": "http://localhost:8080"},
|
||||
)
|
||||
|
||||
# Mock MCPClient to avoid real HTTP calls
|
||||
with patch("agentkit.mcp.client.MCPClient") as MockMCPClient:
|
||||
mock_client_instance = MagicMock()
|
||||
mock_client_instance.list_tools = AsyncMock(return_value=[
|
||||
{"name": "remote_tool", "description": "A remote tool"}
|
||||
])
|
||||
mock_mcp_tool = MagicMock()
|
||||
mock_mcp_tool.name = "remote_tool"
|
||||
mock_client_instance.as_tool = MagicMock(return_value=mock_mcp_tool)
|
||||
MockMCPClient.return_value = mock_client_instance
|
||||
|
||||
task = _make_task()
|
||||
await agent.handle_task(task)
|
||||
|
||||
# MCP tools should now be registered
|
||||
assert agent._mcp_tools_registered is True
|
||||
mock_client_instance.list_tools.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_registration_failure_graceful(self):
|
||||
"""MCP registration failure should not prevent task execution."""
|
||||
config = _make_skill_config(execution_mode="react")
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
mock_gateway = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = '{"result": "ok"}'
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.total_tokens = 5
|
||||
mock_response.has_tool_calls = False
|
||||
mock_gateway.chat = AsyncMock(return_value=mock_response)
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
llm_gateway=mock_gateway,
|
||||
mcp_servers={"bad_server": "http://nonexistent:9999"},
|
||||
)
|
||||
|
||||
with patch("agentkit.mcp.client.MCPClient") as MockMCPClient:
|
||||
MockMCPClient.return_value.list_tools = AsyncMock(
|
||||
side_effect=Exception("Connection refused")
|
||||
)
|
||||
|
||||
task = _make_task()
|
||||
result = await agent.handle_task(task)
|
||||
# Should still complete despite MCP failure
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
class TestExecutionModes:
|
||||
"""Test execution_mode=react/direct/custom."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_mode_single_llm_call(self):
|
||||
"""execution_mode=direct should make a single LLM call without ReAct loop."""
|
||||
config = _make_skill_config(execution_mode="direct")
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
mock_gateway = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps({"answer": "direct result"})
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.total_tokens = 15
|
||||
mock_gateway.chat = AsyncMock(return_value=mock_response)
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
llm_gateway=mock_gateway,
|
||||
)
|
||||
|
||||
task = _make_task()
|
||||
result = await agent.handle_task(task)
|
||||
|
||||
# Should call gateway.chat directly (not ReAct engine)
|
||||
mock_gateway.chat.assert_called_once()
|
||||
assert result == {"answer": "direct result"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_mode_with_skill_config(self):
|
||||
"""execution_mode=custom should use custom handler."""
|
||||
config = _make_skill_config(
|
||||
execution_mode="custom",
|
||||
custom_handler="test.handlers.mock_handler",
|
||||
)
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
async def mock_handler(task):
|
||||
return {"custom": True, "task_id": task.task_id}
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
custom_handlers={"test.handlers.mock_handler": mock_handler},
|
||||
)
|
||||
|
||||
task = _make_task()
|
||||
result = await agent.handle_task(task)
|
||||
|
||||
assert result["custom"] is True
|
||||
assert result["task_id"] == "test-task-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_react_mode_uses_react_engine(self):
|
||||
"""execution_mode=react should use ReAct engine."""
|
||||
config = _make_skill_config(execution_mode="react")
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
mock_gateway = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps({"answer": "react result"})
|
||||
mock_response.usage = MagicMock()
|
||||
mock_response.usage.total_tokens = 20
|
||||
mock_response.has_tool_calls = False
|
||||
mock_gateway.chat = AsyncMock(return_value=mock_response)
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
llm_gateway=mock_gateway,
|
||||
)
|
||||
|
||||
task = _make_task()
|
||||
result = await agent.handle_task(task)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_to_task_mode_without_skill_config(self):
|
||||
"""Without SkillConfig, should fall back to task_mode."""
|
||||
config = AgentConfig(
|
||||
name="legacy_agent",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
supported_tasks=["test"],
|
||||
prompt={"identity": "Legacy agent"},
|
||||
)
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
)
|
||||
|
||||
task = _make_task()
|
||||
result = await agent.handle_task(task)
|
||||
|
||||
# Should return rendered prompt (no LLM client)
|
||||
assert "messages" in result or isinstance(result, dict)
|
||||
|
|
@ -0,0 +1,324 @@
|
|||
"""Tests for ServerConfig - configuration loading"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.server.config import ServerConfig, find_config_path, _resolve_env_vars, _deep_resolve
|
||||
|
||||
|
||||
class TestEnvVarResolution:
|
||||
"""Test ${VAR:-default} pattern resolution"""
|
||||
|
||||
def test_resolve_simple_var(self):
|
||||
os.environ["TEST_AK_KEY"] = "sk-123"
|
||||
assert _resolve_env_vars("${TEST_AK_KEY}") == "sk-123"
|
||||
del os.environ["TEST_AK_KEY"]
|
||||
|
||||
def test_resolve_var_with_default(self):
|
||||
# Var not set -> use default
|
||||
assert _resolve_env_vars("${TEST_MISSING_VAR:-fallback}") == "fallback"
|
||||
|
||||
def test_resolve_var_with_default_and_env_set(self):
|
||||
os.environ["TEST_AK_KEY"] = "sk-456"
|
||||
assert _resolve_env_vars("${TEST_AK_KEY:-fallback}") == "sk-456"
|
||||
del os.environ["TEST_AK_KEY"]
|
||||
|
||||
def test_resolve_non_string(self):
|
||||
assert _resolve_env_vars(42) == 42
|
||||
assert _resolve_env_vars(None) is None
|
||||
|
||||
def test_deep_resolve_dict(self):
|
||||
os.environ["TEST_AK_KEY"] = "sk-789"
|
||||
data = {"api_key": "${TEST_AK_KEY}", "port": 8001}
|
||||
result = _deep_resolve(data)
|
||||
assert result["api_key"] == "sk-789"
|
||||
assert result["port"] == 8001
|
||||
del os.environ["TEST_AK_KEY"]
|
||||
|
||||
def test_deep_resolve_nested(self):
|
||||
os.environ["TEST_AK_KEY"] = "sk-nested"
|
||||
data = {"llm": {"providers": {"openai": {"api_key": "${TEST_AK_KEY}"}}}}
|
||||
result = _deep_resolve(data)
|
||||
assert result["llm"]["providers"]["openai"]["api_key"] == "sk-nested"
|
||||
del os.environ["TEST_AK_KEY"]
|
||||
|
||||
|
||||
class TestServerConfigFromYaml:
|
||||
"""Test loading ServerConfig from YAML"""
|
||||
|
||||
def test_load_minimal_config(self):
|
||||
yaml_content = """
|
||||
server:
|
||||
host: "127.0.0.1"
|
||||
port: 9000
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
config = ServerConfig.from_yaml(f.name)
|
||||
|
||||
assert config.host == "127.0.0.1"
|
||||
assert config.port == 9000
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_load_full_config(self):
|
||||
yaml_content = """
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8001
|
||||
workers: 4
|
||||
api_key: "test-key-123"
|
||||
rate_limit: 120
|
||||
|
||||
llm:
|
||||
default_provider: "openai"
|
||||
providers:
|
||||
openai:
|
||||
api_key: "sk-test"
|
||||
base_url: "https://api.openai.com/v1"
|
||||
models:
|
||||
gpt-4o:
|
||||
alias: "default"
|
||||
gpt-4o-mini:
|
||||
alias: "fast"
|
||||
deepseek:
|
||||
api_key: "sk-deepseek"
|
||||
base_url: "https://api.deepseek.com/v1"
|
||||
models:
|
||||
deepseek-chat:
|
||||
alias: "deepseek"
|
||||
|
||||
skills:
|
||||
auto_discover: true
|
||||
paths:
|
||||
- "./skills"
|
||||
|
||||
logging:
|
||||
level: "DEBUG"
|
||||
format: "json"
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
config = ServerConfig.from_yaml(f.name)
|
||||
|
||||
assert config.host == "0.0.0.0"
|
||||
assert config.port == 8001
|
||||
assert config.workers == 4
|
||||
assert config.api_key == "test-key-123"
|
||||
assert config.rate_limit == 120
|
||||
assert "openai" in config.llm_config.providers
|
||||
assert "deepseek" in config.llm_config.providers
|
||||
assert config.llm_config.providers["openai"].api_key == "sk-test"
|
||||
assert config.llm_config.model_aliases["default"] == "openai/gpt-4o"
|
||||
assert config.llm_config.model_aliases["fast"] == "openai/gpt-4o-mini"
|
||||
assert config.skill_paths == ["./skills"]
|
||||
assert config.auto_discover_skills is True
|
||||
assert config.log_level == "DEBUG"
|
||||
assert config.log_format == "json"
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_load_config_with_env_vars(self):
|
||||
os.environ["TEST_AK_OPENAI_KEY"] = "sk-from-env"
|
||||
yaml_content = """
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8001
|
||||
|
||||
llm:
|
||||
providers:
|
||||
openai:
|
||||
api_key: "${TEST_AK_OPENAI_KEY}"
|
||||
base_url: "https://api.openai.com/v1"
|
||||
models:
|
||||
gpt-4o:
|
||||
alias: "default"
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
config = ServerConfig.from_yaml(f.name)
|
||||
|
||||
assert config.llm_config.providers["openai"].api_key == "sk-from-env"
|
||||
del os.environ["TEST_AK_OPENAI_KEY"]
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
class TestServerConfigLoadSkillConfigs:
|
||||
"""Test loading skill configs from skill paths"""
|
||||
|
||||
def test_load_skills_from_directory(self):
|
||||
yaml_content = """
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8001
|
||||
|
||||
skills:
|
||||
paths:
|
||||
- "./skills"
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
skills_dir = Path(tmpdir) / "skills"
|
||||
skills_dir.mkdir()
|
||||
|
||||
# Create a test skill YAML
|
||||
skill_yaml = skills_dir / "test_skill.yaml"
|
||||
skill_yaml.write_text("""
|
||||
name: test_skill
|
||||
agent_type: test
|
||||
task_mode: llm_generate
|
||||
supported_tasks:
|
||||
- test_task
|
||||
prompt:
|
||||
identity: "Test skill"
|
||||
""")
|
||||
# Update yaml_content with absolute path
|
||||
yaml_content_updated = yaml_content.replace("./skills", str(skills_dir))
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False, dir=tmpdir) as f:
|
||||
f.write(yaml_content_updated)
|
||||
f.flush()
|
||||
config = ServerConfig.from_yaml(f.name)
|
||||
|
||||
configs = config.load_skill_configs()
|
||||
assert len(configs) == 1
|
||||
assert configs[0].name == "test_skill"
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_load_skills_from_single_file(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
skill_yaml = Path(tmpdir) / "my_skill.yaml"
|
||||
skill_yaml.write_text("""
|
||||
name: my_skill
|
||||
agent_type: test
|
||||
task_mode: llm_generate
|
||||
supported_tasks:
|
||||
- test_task
|
||||
prompt:
|
||||
identity: "My skill"
|
||||
""")
|
||||
yaml_content = f"""
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8001
|
||||
|
||||
skills:
|
||||
paths:
|
||||
- "{skill_yaml}"
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False, dir=tmpdir) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
config = ServerConfig.from_yaml(f.name)
|
||||
|
||||
configs = config.load_skill_configs()
|
||||
assert len(configs) == 1
|
||||
assert configs[0].name == "my_skill"
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_load_skills_skips_invalid(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
skills_dir = Path(tmpdir) / "skills"
|
||||
skills_dir.mkdir()
|
||||
|
||||
# Valid skill
|
||||
(skills_dir / "valid.yaml").write_text("""
|
||||
name: valid_skill
|
||||
agent_type: test
|
||||
task_mode: llm_generate
|
||||
supported_tasks:
|
||||
- test
|
||||
prompt:
|
||||
identity: "Valid skill"
|
||||
""")
|
||||
# Invalid skill (missing required fields)
|
||||
(skills_dir / "invalid.yaml").write_text("not_a_valid: yaml")
|
||||
|
||||
yaml_content = f"""
|
||||
server:
|
||||
host: "0.0.0.0"
|
||||
port: 8001
|
||||
|
||||
skills:
|
||||
paths:
|
||||
- "{skills_dir}"
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False, dir=tmpdir) as f:
|
||||
f.write(yaml_content)
|
||||
f.flush()
|
||||
config = ServerConfig.from_yaml(f.name)
|
||||
|
||||
configs = config.load_skill_configs()
|
||||
assert len(configs) == 1
|
||||
assert configs[0].name == "valid_skill"
|
||||
os.unlink(f.name)
|
||||
|
||||
|
||||
class TestServerConfigLoadDotenv:
|
||||
"""Test loading .env file"""
|
||||
|
||||
def test_load_dotenv(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_file = Path(tmpdir) / ".env"
|
||||
env_file.write_text("MY_TEST_VAR=hello_world\n# comment\nEMPTY_VAR=\n")
|
||||
|
||||
config = ServerConfig()
|
||||
config.load_dotenv(str(env_file))
|
||||
|
||||
assert os.environ.get("MY_TEST_VAR") == "hello_world"
|
||||
# Cleanup
|
||||
del os.environ["MY_TEST_VAR"]
|
||||
|
||||
def test_load_dotenv_no_overwrite(self):
|
||||
os.environ["EXISTING_VAR"] = "original"
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
env_file = Path(tmpdir) / ".env"
|
||||
env_file.write_text("EXISTING_VAR=should_not_overwrite\n")
|
||||
|
||||
config = ServerConfig()
|
||||
config.load_dotenv(str(env_file))
|
||||
|
||||
assert os.environ["EXISTING_VAR"] == "original"
|
||||
del os.environ["EXISTING_VAR"]
|
||||
|
||||
def test_load_dotenv_missing_file(self):
|
||||
config = ServerConfig()
|
||||
config.load_dotenv("/nonexistent/.env") # Should not raise
|
||||
|
||||
|
||||
class TestFindConfigPath:
|
||||
"""Test config file discovery"""
|
||||
|
||||
def test_explicit_path_exists(self):
|
||||
with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
|
||||
f.write(b"test: true")
|
||||
f.flush()
|
||||
result = find_config_path(f.name)
|
||||
assert result == f.name
|
||||
os.unlink(f.name)
|
||||
|
||||
def test_explicit_path_not_exists(self):
|
||||
result = find_config_path("/nonexistent/agentkit.yaml")
|
||||
assert result is None
|
||||
|
||||
def test_find_in_cwd(self):
|
||||
original_cwd = os.getcwd()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
os.chdir(tmpdir)
|
||||
config_file = Path(tmpdir) / "agentkit.yaml"
|
||||
config_file.write_text("test: true")
|
||||
result = find_config_path()
|
||||
assert result is not None
|
||||
os.chdir(original_cwd)
|
||||
|
||||
def test_no_config_found(self):
|
||||
original_cwd = os.getcwd()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
os.chdir(tmpdir)
|
||||
result = find_config_path()
|
||||
# May find home dir config, so just check it doesn't crash
|
||||
assert result is None or result.endswith("agentkit.yaml")
|
||||
os.chdir(original_cwd)
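
Editorial sketch, not part of this commit: the `TestEnvVarResolution` class earlier in this file exercises `${VAR}` and `${VAR:-default}` expansion, pass-through of non-string values, and recursive resolution of nested dicts. A minimal version of that behaviour might look like the following. The names `resolve_env_vars` and `deep_resolve` are stand-ins for the private helpers imported from `agentkit.server.config`, whose actual implementation does not appear in this diff; expanding an unset variable without a default to an empty string, and handling lists, are assumptions.

```python
import os
import re

# Matches ${NAME} and ${NAME:-default}; group 2 is None when no default is given.
_ENV_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)(?::-([^}]*))?\}")


def resolve_env_vars(value):
    """Expand ${VAR} and ${VAR:-default} in strings; return other types unchanged."""
    if not isinstance(value, str):
        return value

    def _substitute(match):
        name, default = match.group(1), match.group(2)
        # Assumption: an unset variable with no default expands to "".
        return os.environ.get(name, default if default is not None else "")

    return _ENV_PATTERN.sub(_substitute, value)


def deep_resolve(data):
    """Apply resolve_env_vars recursively through dicts and lists."""
    if isinstance(data, dict):
        return {key: deep_resolve(item) for key, item in data.items()}
    if isinstance(data, list):
        return [deep_resolve(item) for item in data]
    return resolve_env_vars(data)
```

Unlike the shell's `:-` operator, this sketch falls back to the default only when the variable is unset, not when it is set to an empty string.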
|
||||
|
|
@ -60,8 +60,9 @@ class TestHealthRoute:
|
|||
response = client.get("/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
- assert data["status"] == "ok"
|
||||
+ assert data["status"] in ("ok", "healthy", "degraded")
|
||||
assert data["version"] == "2.0.0"
|
||||
+ assert "checks" in data
|
||||
|
||||
|
||||
class TestAgentRoutes:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,474 @@
|
|||
"""Unit tests for the SKILL.md parser"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
from agentkit.skills.skill_md import SkillMdParser
|
||||
|
||||
|
||||
# ── SKILL.md content used in tests ──────────────────────────────────
|
||||
|
||||
FULL_SKILL_MD = '''\
|
||||
---
|
||||
name: content-generator
|
||||
description: "Generate high-quality content based on requirements"
|
||||
agent_type: content_generation
|
||||
execution_mode: react
|
||||
intent:
|
||||
keywords: ["generate", "write", "content"]
|
||||
description: "Content generation tasks"
|
||||
examples: ["Write a blog post", "Generate marketing copy"]
|
||||
quality_gate:
|
||||
required_fields: ["content"]
|
||||
min_word_count: 100
|
||||
max_retries: 3
|
||||
custom_validator: "validators.check_quality"
|
||||
---
|
||||
|
||||
# Trigger
|
||||
- User asks to generate content
|
||||
- Keywords: generate, write, create content
|
||||
|
||||
# Steps
|
||||
1. Analyze the user's requirements and target audience
|
||||
2. Research relevant topics and gather information
|
||||
3. Draft the content following best practices
|
||||
4. Review and refine the output
|
||||
|
||||
# Pitfalls
|
||||
- Don't generate overly generic content
|
||||
- Avoid plagiarism by always creating original content
|
||||
- Don't ignore the target audience's preferences
|
||||
|
||||
# Verification
|
||||
- Content meets minimum word count
|
||||
- Content is relevant to the user's request
|
||||
- Output format matches expectations
|
||||
'''
|
||||
|
||||
MINIMAL_SKILL_MD = '''\
|
||||
---
|
||||
name: minimal-skill
|
||||
description: "A minimal skill"
|
||||
agent_type: minimal
|
||||
---
|
||||
|
||||
# Steps
|
||||
1. Do something
|
||||
'''
|
||||
|
||||
NO_FRONTMATTER_MD = '''\
|
||||
# Steps
|
||||
1. Step one
|
||||
2. Step two
|
||||
'''
|
||||
|
||||
EMPTY_FRONTMATTER_MD = '''\
|
||||
---
|
||||
---
|
||||
|
||||
# Steps
|
||||
1. Step one
|
||||
'''
|
||||
|
||||
|
||||
def _write_skill_md(directory: str, filename: str, content: str) -> str:
|
||||
path = os.path.join(directory, filename)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return path
|
||||
|
||||
|
||||
# ── SkillMdParser.parse tests ──────────────────────────────
|
||||
|
||||
|
||||
class TestSkillMdParserParse:
|
||||
"""Parsing tests for SkillMdParser.parse()"""
|
||||
|
||||
def test_parse_full_skill_md(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "content.md", FULL_SKILL_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
|
||||
assert frontmatter["name"] == "content-generator"
|
||||
assert frontmatter["description"] == "Generate high-quality content based on requirements"
|
||||
assert frontmatter["agent_type"] == "content_generation"
|
||||
assert frontmatter["execution_mode"] == "react"
|
||||
assert frontmatter["intent"]["keywords"] == ["generate", "write", "content"]
|
||||
assert frontmatter["quality_gate"]["required_fields"] == ["content"]
|
||||
assert frontmatter["quality_gate"]["min_word_count"] == 100
|
||||
|
||||
assert "trigger" in sections
|
||||
assert "steps" in sections
|
||||
assert "pitfalls" in sections
|
||||
assert "verification" in sections
|
||||
assert "Analyze the user's requirements" in sections["steps"]
|
||||
assert "Don't generate overly generic content" in sections["pitfalls"]
|
||||
|
||||
assert "Trigger" not in body or "# Trigger" in body
|
||||
|
||||
def test_parse_minimal_skill_md(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "minimal.md", MINIMAL_SKILL_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
|
||||
assert frontmatter["name"] == "minimal-skill"
|
||||
assert frontmatter["description"] == "A minimal skill"
|
||||
assert "steps" in sections
|
||||
assert "Do something" in sections["steps"]
|
||||
|
||||
def test_parse_no_frontmatter(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "no_fm.md", NO_FRONTMATTER_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
|
||||
assert frontmatter == {}
|
||||
assert "steps" in sections
|
||||
|
||||
def test_parse_empty_frontmatter(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "empty_fm.md", EMPTY_FRONTMATTER_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
|
||||
assert frontmatter == {}
|
||||
assert "steps" in sections
|
||||
|
||||
def test_parse_missing_sections_graceful(self):
|
||||
content = """\
|
||||
---
|
||||
name: no-sections
|
||||
description: "No body sections"
|
||||
agent_type: test
|
||||
---
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "nosec.md", content)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
|
||||
assert frontmatter["name"] == "no-sections"
|
||||
assert sections == {}
|
||||
|
||||
|
||||
# ── SkillMdParser.to_skill_config tests ────────────────────
|
||||
|
||||
|
||||
class TestSkillMdToSkillConfig:
|
||||
"""SkillMdParser.to_skill_config() conversion tests"""
|
||||
|
||||
def test_to_skill_config_full(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "full.md", FULL_SKILL_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
config = SkillMdParser.to_skill_config(frontmatter, sections, path)
|
||||
|
||||
assert config.name == "content-generator"
|
||||
assert config.agent_type == "content_generation"
|
||||
assert config.description == "Generate high-quality content based on requirements"
|
||||
assert config.execution_mode == "react"
|
||||
assert config.intent.keywords == ["generate", "write", "content"]
|
||||
assert config.intent.description == "Content generation tasks"
|
||||
assert config.intent.examples == ["Write a blog post", "Generate marketing copy"]
|
||||
assert config.quality_gate.required_fields == ["content"]
|
||||
assert config.quality_gate.min_word_count == 100
|
||||
assert config.quality_gate.max_retries == 3
|
||||
assert config.quality_gate.custom_validator == "validators.check_quality"
|
||||
assert config.prompt is not None
|
||||
assert "instructions" in config.prompt
|
||||
assert "constraints" in config.prompt
|
||||
assert "output_format" in config.prompt
|
||||
assert "context" in config.prompt
|
||||
|
||||
def test_to_skill_config_level_0_summary_only(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "level0.md", FULL_SKILL_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
config = SkillMdParser.to_skill_config(
|
||||
frontmatter, sections, path, disclosure_level=0,
|
||||
)
|
||||
|
||||
assert config.name == "content-generator"
|
||||
assert config.description != ""
|
||||
assert config.disclosure_level == 0
|
||||
# Level 0: prompt contains only identity (summary information)
|
||||
assert config.prompt is not None
|
||||
assert "identity" in config.prompt
|
||||
assert "instructions" not in config.prompt
|
||||
|
||||
def test_to_skill_config_level_1_full(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "level1.md", FULL_SKILL_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
config = SkillMdParser.to_skill_config(
|
||||
frontmatter, sections, path, disclosure_level=1,
|
||||
)
|
||||
|
||||
assert config.name == "content-generator"
|
||||
assert config.disclosure_level == 1
|
||||
assert config.prompt is not None
|
||||
assert "instructions" in config.prompt
|
||||
|
||||
def test_to_skill_config_minimal(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "minimal.md", MINIMAL_SKILL_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
config = SkillMdParser.to_skill_config(frontmatter, sections, path)
|
||||
|
||||
assert config.name == "minimal-skill"
|
||||
assert config.agent_type == "minimal"
|
||||
assert config.execution_mode == "react" # default value
|
||||
assert config.intent.keywords == []
|
||||
assert config.quality_gate.required_fields == []
|
||||
|
||||
def test_to_skill_config_no_frontmatter(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "no_fm.md", NO_FRONTMATTER_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
# Without frontmatter, name is empty, so a valid SkillConfig cannot be created
|
||||
# It is enough to verify that the parse result is correct
|
||||
assert frontmatter == {}
|
||||
assert "steps" in sections
|
||||
|
||||
def test_skill_md_path_stored(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "path_test.md", FULL_SKILL_MD)
|
||||
frontmatter, sections, body = SkillMdParser.parse(path)
|
||||
config = SkillMdParser.to_skill_config(frontmatter, sections, path)
|
||||
|
||||
assert config.skill_md_path == path
|
||||
|
||||
|
||||
# ── SkillConfig new field tests ─────────────────────────────────
|
||||
|
||||
|
||||
class TestSkillConfigNewFields:
|
||||
"""Tests for the new SkillConfig fields skill_md_path and disclosure_level"""
|
||||
|
||||
def test_default_skill_md_path_is_none(self):
|
||||
config = SkillConfig(
|
||||
name="test",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "test"},
|
||||
)
|
||||
assert config.skill_md_path is None
|
||||
|
||||
def test_default_disclosure_level_is_zero(self):
|
||||
config = SkillConfig(
|
||||
name="test",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "test"},
|
||||
)
|
||||
assert config.disclosure_level == 0
|
||||
|
||||
def test_skill_md_path_set(self):
|
||||
config = SkillConfig(
|
||||
name="test",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "test"},
|
||||
skill_md_path="/path/to/skill.md",
|
||||
)
|
||||
assert config.skill_md_path == "/path/to/skill.md"
|
||||
|
||||
def test_disclosure_level_set(self):
|
||||
config = SkillConfig(
|
||||
name="test",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "test"},
|
||||
disclosure_level=2,
|
||||
)
|
||||
assert config.disclosure_level == 2
|
||||
|
||||
def test_to_dict_includes_new_fields(self):
|
||||
config = SkillConfig(
|
||||
name="test",
|
||||
agent_type="test",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": "test"},
|
||||
skill_md_path="/path/to/skill.md",
|
||||
disclosure_level=1,
|
||||
)
|
||||
d = config.to_dict()
|
||||
assert d["skill_md_path"] == "/path/to/skill.md"
|
||||
assert d["disclosure_level"] == 1
|
||||
|
||||
def test_from_dict_includes_new_fields(self):
|
||||
data = {
|
||||
"name": "test",
|
||||
"agent_type": "test",
|
||||
"task_mode": "llm_generate",
|
||||
"prompt": {"identity": "test"},
|
||||
"skill_md_path": "/path/to/skill.md",
|
||||
"disclosure_level": 2,
|
||||
}
|
||||
config = SkillConfig.from_dict(data)
|
||||
assert config.skill_md_path == "/path/to/skill.md"
|
||||
assert config.disclosure_level == 2
|
||||
|
||||
def test_from_dict_defaults_new_fields(self):
|
||||
data = {
|
||||
"name": "test",
|
||||
"agent_type": "test",
|
||||
"task_mode": "llm_generate",
|
||||
"prompt": {"identity": "test"},
|
||||
}
|
||||
config = SkillConfig.from_dict(data)
|
||||
assert config.skill_md_path is None
|
||||
assert config.disclosure_level == 0
|
||||
|
||||
|
||||
# ── SkillLoader.load_from_skill_md tests ───────────────────
|
||||
|
||||
|
||||
class TestSkillLoaderFromSkillMd:
|
||||
"""SkillLoader.load_from_skill_md() loading tests"""
|
||||
|
||||
def test_load_from_skill_md_creates_skill(self):
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "content.md", FULL_SKILL_MD)
|
||||
skill = loader.load_from_skill_md(path)
|
||||
|
||||
assert isinstance(skill, Skill)
|
||||
assert skill.name == "content-generator"
|
||||
assert skill.config.agent_type == "content_generation"
|
||||
assert skill.config.skill_md_path == path
|
||||
assert skill.config.disclosure_level == 1 # default level=1
|
||||
|
||||
def test_load_from_skill_md_registers_in_registry(self):
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "content.md", FULL_SKILL_MD)
|
||||
loader.load_from_skill_md(path)
|
||||
|
||||
assert registry.has_skill("content-generator")
|
||||
|
||||
def test_load_from_skill_md_level_0(self):
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "content.md", FULL_SKILL_MD)
|
||||
skill = loader.load_from_skill_md(path, disclosure_level=0)
|
||||
|
||||
assert skill.config.disclosure_level == 0
|
||||
# Level 0: prompt contains only identity, not instructions
|
||||
assert skill.config.prompt is not None
|
||||
assert "identity" in skill.config.prompt
|
||||
assert "instructions" not in skill.config.prompt
|
||||
|
||||
def test_load_from_skill_md_level_1(self):
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = _write_skill_md(tmpdir, "content.md", FULL_SKILL_MD)
|
||||
skill = loader.load_from_skill_md(path, disclosure_level=1)
|
||||
|
||||
assert skill.config.disclosure_level == 1
|
||||
assert skill.config.prompt is not None
|
||||
assert "instructions" in skill.config.prompt
|
||||
|
||||
def test_load_from_directory_includes_md_files(self):
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
_write_skill_md(tmpdir, "skill.md", FULL_SKILL_MD)
|
||||
skills = loader.load_from_directory(tmpdir)
|
||||
|
||||
assert len(skills) == 1
|
||||
assert skills[0].name == "content-generator"
|
||||
|
||||
def test_load_from_directory_mixed_yaml_and_md(self):
|
||||
import yaml
|
||||
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# YAML file
|
||||
yaml_path = os.path.join(tmpdir, "yaml_skill.yaml")
|
||||
with open(yaml_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump({
|
||||
"name": "yaml_skill",
|
||||
"agent_type": "yaml",
|
||||
"task_mode": "llm_generate",
|
||||
"prompt": {"identity": "YAML skill"},
|
||||
}, f)
|
||||
|
||||
# SKILL.md file
|
||||
_write_skill_md(tmpdir, "md_skill.md", FULL_SKILL_MD)
|
||||
|
||||
skills = loader.load_from_directory(tmpdir)
|
||||
assert len(skills) == 2
|
||||
names = [s.name for s in skills]
|
||||
assert "yaml_skill" in names
|
||||
assert "content-generator" in names
|
||||
|
||||
def test_load_from_directory_skips_invalid_md(self, caplog):
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Invalid MD (not a valid SKILL.md format; required fields are missing after YAML parsing)
|
||||
invalid_md = "This is just plain text, not a valid SKILL.md at all."
|
||||
_write_skill_md(tmpdir, "invalid.md", invalid_md)
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
skills = loader.load_from_directory(tmpdir)
|
||||
|
||||
# The invalid file should be skipped (plain text has no frontmatter, so name is empty)
|
||||
assert len(skills) == 0
|
||||
|
||||
|
||||
# ── CLI skill create tests ─────────────────────────────────
|
||||
|
||||
|
||||
class TestCliSkillCreate:
|
||||
"""CLI skill create command tests"""
|
||||
|
||||
def test_create_generates_valid_skill_md(self):
|
||||
from typer.testing import CliRunner
|
||||
from agentkit.cli.skill import skill_app
|
||||
|
||||
runner = CliRunner()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
result = runner.invoke(skill_app, ["create", "my-skill", "--output-dir", tmpdir])
|
||||
assert result.exit_code == 0
|
||||
|
||||
output_path = os.path.join(tmpdir, "my-skill.md")
|
||||
assert os.path.exists(output_path)
|
||||
|
||||
# Verify that the generated file can be parsed
|
||||
frontmatter, sections, body = SkillMdParser.parse(output_path)
|
||||
assert frontmatter["name"] == "my-skill"
|
||||
assert "steps" in sections
|
||||
assert "pitfalls" in sections
|
||||
assert "verification" in sections
|
||||
|
||||
def test_create_template_is_loadable(self):
|
||||
from typer.testing import CliRunner
|
||||
from agentkit.cli.skill import skill_app
|
||||
|
||||
runner = CliRunner()
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
runner.invoke(skill_app, ["create", "loadable-skill", "--output-dir", tmpdir])
|
||||
|
||||
output_path = os.path.join(tmpdir, "loadable-skill.md")
|
||||
registry = SkillRegistry()
|
||||
loader = SkillLoader(skill_registry=registry)
|
||||
skill = loader.load_from_skill_md(output_path)
|
||||
|
||||
assert skill.name == "loadable-skill"
|
||||
|
|
@ -0,0 +1,450 @@
|
|||
"""SkillPipeline unit tests"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.skills.pipeline import SkillPipeline
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
|
||||
# ---- Helpers ----
|
||||
|
||||
|
||||
async def _mock_agent_factory(skill_name: str, input_data: dict) -> dict:
|
||||
"""Mock agent factory: returns a dict containing the skill name and the input data"""
|
||||
return {"skill": skill_name, "processed": True, **input_data}
|
||||
|
||||
|
||||
async def _failing_agent_factory(skill_name: str, input_data: dict) -> dict:
|
||||
"""Mock agent factory: raises an exception for a specific skill"""
|
||||
if skill_name == "failing_skill":
|
||||
raise RuntimeError("Skill execution failed")
|
||||
return {"skill": skill_name, "processed": True, **input_data}
|
||||
|
||||
|
||||
async def _transform_agent_factory(skill_name: str, input_data: dict) -> dict:
|
||||
"""Mock agent factory: applies a different transformation depending on the skill name"""
|
||||
if skill_name == "extract":
|
||||
return {"title": input_data.get("raw_text", "").split()[0], "score": 0.9}
|
||||
if skill_name == "enrich":
|
||||
return {"title": input_data.get("title", ""), "enriched": True}
|
||||
if skill_name == "format":
|
||||
return {"result": f"Formatted: {input_data.get('title', '')}", "enriched": input_data.get("enriched", False)}
|
||||
return {"skill": skill_name, **input_data}
|
||||
|
||||
|
||||
# ---- SkillPipeline core tests ----
|
||||
|
||||
|
||||
class TestSkillPipelineSequential:
|
||||
"""Sequential execution tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_three_skills(self):
|
||||
"""Three skills run sequentially, with output passed between steps"""
|
||||
pipeline = SkillPipeline(
|
||||
name="seq_pipeline",
|
||||
steps=[
|
||||
{"skill_name": "skill_a"},
|
||||
{"skill_name": "skill_b"},
|
||||
{"skill_name": "skill_c"},
|
||||
],
|
||||
)
|
||||
|
||||
result = await pipeline.execute(
|
||||
input_data={"query": "hello"},
|
||||
agent_factory=_mock_agent_factory,
|
||||
)
|
||||
|
||||
assert result["pipeline"] == "seq_pipeline"
|
||||
assert len(result["steps"]) == 3
|
||||
assert result["steps"][0]["status"] == "success"
|
||||
assert result["steps"][0]["skill"] == "skill_a"
|
||||
assert result["steps"][1]["status"] == "success"
|
||||
assert result["steps"][1]["skill"] == "skill_b"
|
||||
assert result["steps"][2]["status"] == "success"
|
||||
assert result["steps"][2]["skill"] == "skill_c"
|
||||
|
||||
# Verify output passing: the second step's input includes the first step's output
|
||||
assert result["steps"][1]["output"]["query"] == "hello"
|
||||
assert result["steps"][1]["output"]["processed"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_output_passes_between_steps(self):
|
||||
"""Output is passed correctly between steps"""
|
||||
pipeline = SkillPipeline(
|
||||
name="transform_pipeline",
|
||||
steps=[
|
||||
{"skill_name": "extract"},
|
||||
{"skill_name": "enrich"},
|
||||
{"skill_name": "format"},
|
||||
],
|
||||
)
|
||||
|
||||
result = await pipeline.execute(
|
||||
input_data={"raw_text": "Hello World"},
|
||||
agent_factory=_transform_agent_factory,
|
||||
)
|
||||
|
||||
# Step 1: extract → {"title": "Hello", "score": 0.9}
|
||||
assert result["steps"][0]["output"]["title"] == "Hello"
|
||||
assert result["steps"][0]["output"]["score"] == 0.9
|
||||
|
||||
# Step 2: enrich → {"title": "Hello", "enriched": True}
|
||||
assert result["steps"][1]["output"]["title"] == "Hello"
|
||||
assert result["steps"][1]["output"]["enriched"] is True
|
||||
|
||||
# Step 3: format → {"result": "Formatted: Hello", "enriched": True}
|
||||
assert result["steps"][2]["output"]["result"] == "Formatted: Hello"
|
||||
assert result["steps"][2]["output"]["enriched"] is True
|
||||
|
||||
# final_output is the output of the last step
|
||||
assert result["final_output"]["result"] == "Formatted: Hello"
|
||||
|
||||
|
||||
class TestSkillPipelineConditional:
|
||||
"""Conditional branch tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_condition_met_executes_step(self):
|
||||
"""The step runs when the condition is met"""
|
||||
pipeline = SkillPipeline(
|
||||
name="cond_pipeline",
|
||||
steps=[
|
||||
{"skill_name": "skill_a"},
|
||||
{"skill_name": "skill_b", "condition": "status == 'ok'"},
|
||||
],
|
||||
)
|
||||
|
||||
async def factory(name, data):
|
||||
if name == "skill_a":
|
||||
return {"status": "ok", "data": "test"}
|
||||
return {"skill": name, **data}
|
||||
|
||||
result = await pipeline.execute(input_data={}, agent_factory=factory)
|
||||
|
||||
assert result["steps"][0]["status"] == "success"
|
||||
assert result["steps"][1]["status"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_condition_not_met_skips_step(self):
|
||||
"""The step is skipped when the condition is not met"""
|
||||
pipeline = SkillPipeline(
|
||||
name="cond_pipeline_skip",
|
||||
steps=[
|
||||
{"skill_name": "skill_a"},
|
||||
{"skill_name": "skill_b", "condition": "status == 'ok'"},
|
||||
],
|
||||
)
|
||||
|
||||
async def factory(name, data):
|
||||
if name == "skill_a":
|
||||
return {"status": "error", "data": "test"}
|
||||
return {"skill": name, **data}
|
||||
|
||||
result = await pipeline.execute(input_data={}, agent_factory=factory)
|
||||
|
||||
assert result["steps"][0]["status"] == "success"
|
||||
assert result["steps"][1]["status"] == "skipped"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_numeric_condition(self):
|
||||
"""Numeric condition evaluation"""
|
||||
pipeline = SkillPipeline(
|
||||
name="num_cond_pipeline",
|
||||
steps=[
|
||||
{"skill_name": "skill_a"},
|
||||
{"skill_name": "skill_b", "condition": "score > 0.5"},
|
||||
],
|
||||
)
|
||||
|
||||
async def factory(name, data):
|
||||
if name == "skill_a":
|
||||
return {"score": 0.9}
|
||||
return {"skill": name, **data}
|
||||
|
||||
result = await pipeline.execute(input_data={}, agent_factory=factory)
|
||||
assert result["steps"][1]["status"] == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_numeric_condition_not_met(self):
|
||||
"""The step is skipped when the numeric condition is not met"""
|
||||
pipeline = SkillPipeline(
|
||||
name="num_cond_pipeline_fail",
|
||||
steps=[
|
||||
{"skill_name": "skill_a"},
|
||||
{"skill_name": "skill_b", "condition": "score > 0.5"},
|
||||
],
|
||||
)
|
||||
|
||||
async def factory(name, data):
|
||||
if name == "skill_a":
|
||||
return {"score": 0.3}
|
||||
return {"skill": name, **data}
|
||||
|
||||
result = await pipeline.execute(input_data={}, agent_factory=factory)
|
||||
assert result["steps"][1]["status"] == "skipped"
|
||||
|
||||
|
||||
class TestSkillPipelineFailure:
|
||||
"""Pipeline failure tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_failure_stops_pipeline(self):
|
||||
"""A failed step aborts the pipeline"""
|
||||
pipeline = SkillPipeline(
|
||||
name="fail_pipeline",
|
||||
steps=[
|
||||
{"skill_name": "skill_a"},
|
||||
{"skill_name": "failing_skill"},
|
||||
{"skill_name": "skill_c"},
|
||||
],
|
||||
)
|
||||
|
||||
result = await pipeline.execute(
|
||||
input_data={"query": "test"},
|
||||
agent_factory=_failing_agent_factory,
|
||||
)
|
||||
|
||||
assert len(result["steps"]) == 2
|
||||
assert result["steps"][0]["status"] == "success"
|
||||
assert result["steps"][1]["status"] == "failed"
|
||||
assert result["steps"][1]["skill"] == "failing_skill"
|
||||
assert "Skill execution failed" in result["steps"][1]["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_registry_no_factory_marks_step_failed(self):
|
||||
"""With neither a registry nor a factory, the step is marked as failed"""
|
||||
pipeline = SkillPipeline(
|
||||
name="no_exec_pipeline",
|
||||
steps=[{"skill_name": "skill_a"}],
|
||||
)
|
||||
|
||||
result = await pipeline.execute(input_data={})
|
||||
|
||||
assert len(result["steps"]) == 1
|
||||
assert result["steps"][0]["status"] == "failed"
|
||||
assert "no agent_factory or skill_registry" in result["steps"][0]["error"]
|
||||
|
||||
|
||||
class TestSkillPipelineEmpty:
|
||||
"""Empty pipeline tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_pipeline(self):
|
||||
"""An empty step list returns an empty result"""
|
||||
pipeline = SkillPipeline(name="empty_pipeline", steps=[])
|
||||
|
||||
result = await pipeline.execute(input_data={"key": "value"})
|
||||
|
||||
assert result["pipeline"] == "empty_pipeline"
|
||||
assert result["steps"] == []
|
||||
assert result["final_output"] == {"key": "value"}
|
||||
|
||||
|
||||
class TestSkillPipelineInputMapping:
|
||||
"""Input mapping tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_mapping(self):
|
||||
"""Map fields from the previous step's output to the next step's input"""
|
||||
pipeline = SkillPipeline(
|
||||
name="mapping_pipeline",
|
||||
steps=[
|
||||
{"skill_name": "extract"},
|
||||
{
|
||||
"skill_name": "enrich",
|
||||
"input_mapping": {"title": "title"},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
result = await pipeline.execute(
|
||||
input_data={"raw_text": "Hello World"},
|
||||
agent_factory=_transform_agent_factory,
|
||||
)
|
||||
|
||||
# Step 1 outputs {"title": "Hello", "score": 0.9}
|
||||
# After mapping, step 2 receives {"title": "Hello"}
|
||||
assert result["steps"][1]["output"]["title"] == "Hello"
|
||||
assert result["steps"][1]["output"]["enriched"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_path_mapping(self):
|
||||
"""Nested path mapping"""
|
||||
pipeline = SkillPipeline(
|
||||
name="nested_mapping",
|
||||
steps=[
|
||||
{"skill_name": "skill_a"},
|
||||
{
|
||||
"skill_name": "skill_b",
|
||||
"input_mapping": {"name": "user.name"},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def factory(name, data):
|
||||
if name == "skill_a":
|
||||
return {"user": {"name": "Alice"}, "age": 30}
|
||||
return {"skill": name, **data}
|
||||
|
||||
result = await pipeline.execute(input_data={}, agent_factory=factory)
|
||||
|
||||
# The second step's input should be {"name": "Alice"}
|
||||
assert result["steps"][1]["output"]["name"] == "Alice"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mapping_missing_field_omitted(self):
|
||||
"""A mapped field that does not exist is omitted"""
|
||||
pipeline = SkillPipeline(
|
||||
name="missing_mapping",
|
||||
steps=[
|
||||
{"skill_name": "skill_a"},
|
||||
{
|
||||
"skill_name": "skill_b",
|
||||
"input_mapping": {"title": "nonexistent.field"},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
async def factory(name, data):
|
||||
if name == "skill_a":
|
||||
return {"other": "data"}
|
||||
return {"skill": name, **data}
|
||||
|
||||
result = await pipeline.execute(input_data={}, agent_factory=factory)
|
||||
|
||||
# The mapped field does not exist, so the second step's input is an empty dict
|
||||
assert result["steps"][1]["status"] == "success"
|
||||
|
||||
|
||||
class TestSkillPipelineRegistry:
|
||||
"""Registering and looking up a SkillPipeline in the SkillRegistry"""
|
||||
|
||||
def test_register_pipeline(self):
|
||||
registry = SkillRegistry()
|
||||
pipeline = SkillPipeline(name="test_pipe", steps=[{"skill_name": "a"}])
|
||||
registry.register_pipeline(pipeline)
|
||||
assert registry.get_pipeline("test_pipe") is pipeline
|
||||
|
||||
def test_get_pipeline_not_found(self):
|
||||
registry = SkillRegistry()
|
||||
assert registry.get_pipeline("nonexistent") is None
|
||||
|
||||
def test_list_pipelines(self):
|
||||
registry = SkillRegistry()
|
||||
registry.register_pipeline(SkillPipeline(name="p1", steps=[]))
|
||||
registry.register_pipeline(SkillPipeline(name="p2", steps=[]))
|
||||
names = registry.list_pipelines()
|
||||
assert "p1" in names
|
||||
assert "p2" in names
|
||||
|
||||
def test_list_pipelines_empty(self):
|
||||
registry = SkillRegistry()
|
||||
assert registry.list_pipelines() == []
|
||||
|
||||
def test_unregister_pipeline(self):
|
||||
registry = SkillRegistry()
|
||||
registry.register_pipeline(SkillPipeline(name="p1", steps=[]))
|
||||
registry.unregister_pipeline("p1")
|
||||
assert registry.get_pipeline("p1") is None
|
||||
|
||||
def test_unregister_pipeline_nonexistent(self):
|
||||
"""Unregistering a nonexistent pipeline does not raise"""
|
||||
registry = SkillRegistry()
|
||||
registry.unregister_pipeline("nonexistent")
|
||||
|
||||
def test_register_pipeline_overwrites(self):
|
||||
"""Registering a pipeline under an existing name overwrites it"""
|
||||
registry = SkillRegistry()
|
||||
p1 = SkillPipeline(name="dup", steps=[{"skill_name": "a"}])
|
||||
p2 = SkillPipeline(name="dup", steps=[{"skill_name": "b"}])
|
||||
registry.register_pipeline(p1)
|
||||
registry.register_pipeline(p2)
|
||||
assert registry.get_pipeline("dup") is p2
|
||||
|
||||
|
||||
class TestSkillPipelineAPI:
|
||||
"""Pipeline API endpoint tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
from agentkit.server.app import create_app
|
||||
|
||||
application = create_app()
|
||||
return application
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, app):
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pipeline(self, client):
|
||||
response = await client.post(
|
||||
"/api/v1/skills/pipelines",
|
||||
json={
|
||||
"name": "test_pipe",
|
||||
"steps": [
|
||||
{"skill_name": "skill_a"},
|
||||
{"skill_name": "skill_b"},
|
||||
],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert data["name"] == "test_pipe"
|
||||
assert len(data["steps"]) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_pipeline_missing_skill_name(self, client):
|
||||
response = await client.post(
|
||||
"/api/v1/skills/pipelines",
|
||||
json={
|
||||
"name": "bad_pipe",
|
||||
"steps": [{"no_skill_name": "oops"}],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_pipelines_empty(self, client):
|
||||
response = await client.get("/api/v1/skills/pipelines")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_pipelines_after_create(self, client):
|
||||
await client.post(
|
||||
"/api/v1/skills/pipelines",
|
||||
json={"name": "pipe1", "steps": [{"skill_name": "a"}]},
|
||||
)
|
||||
response = await client.get("/api/v1/skills/pipelines")
|
||||
assert response.status_code == 200
|
||||
assert "pipe1" in response.json()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_pipeline_not_found(self, client):
|
||||
response = await client.post(
|
||||
"/api/v1/skills/pipelines/nonexistent/execute",
|
||||
json={"input_data": {}},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_pipeline_no_executor(self, client):
|
||||
"""When the pipeline exists but the registry has no matching Skill, the step is marked as failed"""
|
||||
await client.post(
|
||||
"/api/v1/skills/pipelines",
|
||||
json={"name": "exec_pipe", "steps": [{"skill_name": "missing_skill"}]},
|
||||
)
|
||||
response = await client.post(
|
||||
"/api/v1/skills/pipelines/exec_pipe/execute",
|
||||
json={"input_data": {"query": "test"}},
|
||||
)
|
||||
# Pipeline execution returns 200, but the step is marked as failed
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["steps"][0]["status"] == "failed"
|
||||
|
|
@ -0,0 +1,315 @@
|
|||
"""RedisTaskStore unit tests - uses mock Redis (no real Redis required)"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.protocol import TaskStatus
|
||||
from agentkit.server.task_store import (
|
||||
InMemoryTaskStore,
|
||||
RedisTaskStore,
|
||||
TaskRecord,
|
||||
TaskStore,
|
||||
create_task_store,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# Helpers – lightweight fake Redis for unit tests
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class FakeRedis:
|
||||
"""Minimal in-memory fake that satisfies the RedisTaskStore interface."""
|
||||
|
||||
def __init__(self):
|
||||
self._data: dict[str, str] = {}
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url, **kwargs):
|
||||
return cls()
|
||||
|
||||
async def get(self, key):
|
||||
return self._data.get(key)
|
||||
|
||||
async def set(self, key, value, ex=None, **kwargs):
|
||||
self._data[key] = value
|
||||
|
||||
async def delete(self, key):
|
||||
self._data.pop(key, None)
|
||||
|
||||
async def mget(self, keys):
|
||||
return [self._data.get(k) for k in keys]
|
||||
|
||||
async def scan(self, cursor=0, match=None, count=200):
|
||||
"""Simplified SCAN – returns all matching keys in one batch."""
|
||||
import fnmatch
|
||||
|
||||
pattern = match or "*"
|
||||
matched = [k for k in self._data if fnmatch.fnmatch(k, pattern)]
|
||||
# cursor=0 means "done"
|
||||
return (0, matched)
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
|
||||
def _make_redis_store(fake_redis: FakeRedis | None = None) -> RedisTaskStore:
|
||||
"""Build a RedisTaskStore with a FakeRedis injected."""
|
||||
store = RedisTaskStore(redis_url="redis://fake/0")
|
||||
if fake_redis is None:
|
||||
fake_redis = FakeRedis()
|
||||
store._redis = fake_redis
|
||||
return store
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# TaskRecord.from_dict round-trip
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestTaskRecordRoundTrip:
|
||||
"""Verify TaskRecord serialisation / deserialisation."""
|
||||
|
||||
def test_to_dict_from_dict_round_trip(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
record = TaskRecord(
|
||||
task_id="t1",
|
||||
agent_name="agent_a",
|
||||
skill_name="skill_x",
|
||||
input_data={"query": "hello"},
|
||||
status=TaskStatus.RUNNING,
|
||||
output_data={"result": "world"},
|
||||
error_message=None,
|
||||
created_at=now,
|
||||
started_at=now,
|
||||
completed_at=None,
|
||||
progress=0.5,
|
||||
progress_message="Halfway",
|
||||
metadata={"key": "val"},
|
||||
)
|
||||
restored = TaskRecord.from_dict(record.to_dict())
|
||||
assert restored.task_id == record.task_id
|
||||
assert restored.agent_name == record.agent_name
|
||||
assert restored.skill_name == record.skill_name
|
||||
assert restored.input_data == record.input_data
|
||||
assert restored.status == record.status
|
||||
assert restored.output_data == record.output_data
|
||||
assert restored.progress == record.progress
|
||||
assert restored.progress_message == record.progress_message
|
||||
assert restored.metadata == record.metadata
|
||||
|
||||
def test_from_dict_with_none_fields(self):
|
||||
data = {
|
||||
"task_id": "t2",
|
||||
"agent_name": "b",
|
||||
"skill_name": None,
|
||||
"input_data": {},
|
||||
"status": "pending",
|
||||
"output_data": None,
|
||||
"error_message": None,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"started_at": None,
|
||||
"completed_at": None,
|
||||
"progress": 0.0,
|
||||
"progress_message": "",
|
||||
"metadata": {},
|
||||
}
|
||||
record = TaskRecord.from_dict(data)
|
||||
assert record.skill_name is None
|
||||
assert record.started_at is None
|
||||
assert record.completed_at is None
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# RedisTaskStore – happy path
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestRedisTaskStoreHappyPath:
|
||||
"""Core CRUD operations on RedisTaskStore with mock Redis."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get(self):
|
||||
store = _make_redis_store()
|
||||
record = await store.create("t1", "agent_a", {"q": "hello"}, skill_name="skill_x")
|
||||
assert record.task_id == "t1"
|
||||
assert record.agent_name == "agent_a"
|
||||
assert record.skill_name == "skill_x"
|
||||
assert record.input_data == {"q": "hello"}
|
||||
assert record.status == TaskStatus.PENDING
|
||||
|
||||
fetched = await store.get("t1")
|
||||
assert fetched is not None
|
||||
assert fetched.task_id == "t1"
|
||||
assert fetched.agent_name == "agent_a"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_changes_fields(self):
|
||||
store = _make_redis_store()
|
||||
await store.create("t1", "agent_a", {})
|
||||
now = datetime.now(timezone.utc)
|
||||
updated = await store.update_status(
|
||||
"t1", TaskStatus.RUNNING, started_at=now, progress=0.5, progress_message="Halfway",
|
||||
)
|
||||
assert updated.status == TaskStatus.RUNNING
|
||||
assert updated.progress == 0.5
|
||||
assert updated.progress_message == "Halfway"
|
||||
|
||||
# Verify persistence
|
||||
fetched = await store.get("t1")
|
||||
assert fetched is not None
|
||||
assert fetched.status == TaskStatus.RUNNING
|
||||
assert fetched.progress == 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_sorted_by_created_at_desc(self):
|
||||
store = _make_redis_store()
|
||||
await store.create("t1", "agent_a", {})
|
||||
await store.create("t2", "agent_b", {})
|
||||
tasks = await store.list_tasks()
|
||||
assert len(tasks) == 2
|
||||
# Most recent first (t2 created after t1)
|
||||
assert tasks[0].task_id == "t2"
|
||||
assert tasks[1].task_id == "t1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_filtered_by_status(self):
|
||||
store = _make_redis_store()
|
||||
await store.create("t1", "agent_a", {})
|
||||
await store.create("t2", "agent_b", {})
|
||||
await store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
|
||||
tasks = await store.list_tasks(status=TaskStatus.COMPLETED)
|
||||
assert len(tasks) == 1
|
||||
assert tasks[0].task_id == "t1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_respects_limit(self):
|
||||
store = _make_redis_store()
|
||||
for i in range(5):
|
||||
await store.create(f"t{i}", "agent_a", {})
|
||||
tasks = await store.list_tasks(limit=3)
|
||||
assert len(tasks) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_size_returns_count(self):
|
||||
store = _make_redis_store()
|
||||
assert await store.size == 0
|
||||
await store.create("t1", "agent_a", {})
|
||||
assert await store.size == 1
|
||||
await store.create("t2", "agent_b", {})
|
||||
assert await store.size == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_cleanup_is_noop(self):
|
||||
store = _make_redis_store()
|
||||
# Should not raise
|
||||
await store.start_cleanup()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_cleanup_closes_redis(self):
|
||||
fake = FakeRedis()
|
||||
store = _make_redis_store(fake)
|
||||
await store.stop_cleanup()
|
||||
assert store._redis is None
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# RedisTaskStore – error / edge cases
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestRedisTaskStoreErrors:
|
||||
"""Error and edge-case handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_returns_none(self):
|
||||
store = _make_redis_store()
|
||||
result = await store.get("nonexistent")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_nonexistent_raises_keyerror(self):
|
||||
store = _make_redis_store()
|
||||
with pytest.raises(KeyError, match="not found"):
|
||||
await store.update_status("nonexistent", TaskStatus.RUNNING)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_records_evicts_oldest_completed(self):
|
||||
fake = FakeRedis()
|
||||
store = _make_redis_store(fake)
|
||||
store._max_records = 2
|
||||
|
||||
await store.create("t1", "agent_a", {})
|
||||
await store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
|
||||
await store.create("t2", "agent_b", {})
|
||||
# t3 should evict t1 (oldest completed)
|
||||
await store.create("t3", "agent_c", {})
|
||||
assert await store.get("t1") is None
|
||||
assert await store.get("t2") is not None
|
||||
assert await store.get("t3") is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_records_full_no_completed_raises(self):
|
||||
fake = FakeRedis()
|
||||
store = _make_redis_store(fake)
|
||||
store._max_records = 1
|
||||
|
||||
await store.create("t1", "agent_a", {})
|
||||
# All tasks are PENDING, no completed to evict
|
||||
with pytest.raises(RuntimeError, match="full"):
|
||||
await store.create("t2", "agent_b", {})
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# TTL expiry (simulated by removing key from fake Redis)
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestRedisTaskStoreTTL:
|
||||
"""Simulate TTL expiry by manually removing keys from FakeRedis."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_key_returns_none(self):
|
||||
fake = FakeRedis()
|
||||
store = _make_redis_store(fake)
|
||||
await store.create("t1", "agent_a", {})
|
||||
# Simulate TTL expiry: remove key from fake Redis
|
||||
fake._data.pop(store._key("t1"))
|
||||
result = await store.get("t1")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
# create_task_store factory
|
||||
# ═══════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestCreateTaskStore:
|
||||
"""Factory function tests."""
|
||||
|
||||
def test_default_backend_is_memory(self):
|
||||
store = create_task_store()
|
||||
assert isinstance(store, InMemoryTaskStore)
|
||||
|
||||
def test_explicit_memory_backend(self):
|
||||
store = create_task_store(backend="memory")
|
||||
assert isinstance(store, InMemoryTaskStore)
|
||||
|
||||
def test_redis_backend_returns_redis_task_store(self):
|
||||
store = create_task_store(backend="redis", redis_url="redis://localhost:6379/0")
|
||||
assert isinstance(store, RedisTaskStore)
|
||||
|
||||
def test_redis_unavailable_falls_back_to_memory(self):
|
||||
"""If redis.asyncio import fails, factory falls back to InMemoryTaskStore."""
|
||||
with patch.dict("sys.modules", {"redis.asyncio": None}):
|
||||
# Force import failure
|
||||
with patch("builtins.__import__", side_effect=ImportError("no redis")):
|
||||
store = create_task_store(backend="redis")
|
||||
assert isinstance(store, InMemoryTaskStore)
|
||||
|
||||
def test_backward_compat_alias(self):
|
||||
"""TaskStore is an alias for InMemoryTaskStore."""
|
||||
assert TaskStore is InMemoryTaskStore
|
||||
|
|
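The `FakeRedis` helper and `_make_redis_store` used by the tests above are defined outside this excerpt. As a rough sketch only, assuming the store needs nothing beyond `get`, `set`, and `delete`, a hypothetical stand-in (`TinyFakeRedis` and the `task:t1` key are names invented here, not agentkit's) might look like the following. It also shows how removing a key by hand imitates TTL expiry, the same trick `test_expired_key_returns_none` relies on.

```python
import asyncio
from typing import Dict, Optional, Tuple


class TinyFakeRedis:
    """Hypothetical in-memory stand-in for an async Redis client (illustration only)."""

    def __init__(self) -> None:
        self._data: Dict[str, str] = {}

    async def get(self, key: str) -> Optional[str]:
        # Missing keys return None, like a real Redis GET on an absent key.
        return self._data.get(key)

    async def set(self, key: str, value: str, ex: Optional[int] = None) -> bool:
        # `ex` (TTL in seconds) is accepted but never enforced by this fake.
        self._data[key] = value
        return True

    async def delete(self, key: str) -> int:
        # Return the number of keys removed (0 or 1).
        return 1 if self._data.pop(key, None) is not None else 0


async def demo_ttl_expiry() -> Tuple[Optional[str], Optional[str]]:
    fake = TinyFakeRedis()
    await fake.set("task:t1", '{"status": "pending"}', ex=3600)
    before = await fake.get("task:t1")
    # Simulate TTL expiry by dropping the key directly from the backing dict.
    fake._data.pop("task:t1")
    after = await fake.get("task:t1")
    return before, after


before, after = asyncio.run(demo_ttl_expiry())
```

Because the fake never enforces `ex`, expiry has to be simulated explicitly; this keeps the tests deterministic and avoids sleeping for a real TTL.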
@ -0,0 +1,482 @@
|
|||
"""TraceRecorder unit tests"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.trace import ExecutionTrace, TraceRecorder, TraceStep
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
|
||||
# ── Test Helpers ──────────────────────────────────────────
|
||||
|
||||
|
||||
class FakeTool(Tool):
|
||||
"""Fake Tool used for testing"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "fake_tool",
|
||||
description: str = "A fake tool for testing",
|
||||
result: dict | None = None,
|
||||
):
|
||||
super().__init__(name=name, description=description)
|
||||
self._result = result or {"status": "ok"}
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
return self._result
|
||||
|
||||
|
||||
def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
|
||||
"""Create a mock LLMGateway that returns the given responses in order"""
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(side_effect=responses)
|
||||
return gateway
|
||||
|
||||
|
||||
def make_response(
|
||||
content: str = "",
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
prompt_tokens: int = 10,
|
||||
completion_tokens: int = 20,
|
||||
) -> LLMResponse:
|
||||
"""Quickly build an LLMResponse"""
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model="test-model",
|
||||
usage=TokenUsage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
),
|
||||
tool_calls=tool_calls or [],
|
||||
)
|
||||
|
||||
|
||||
# ── TraceStep Tests ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestTraceStep:
|
||||
"""TraceStep dataclass tests"""
|
||||
|
||||
def test_to_dict_with_all_fields(self):
|
||||
step = TraceStep(
|
||||
step=1,
|
||||
action="tool_call",
|
||||
tool_name="search",
|
||||
input_data={"query": "test"},
|
||||
output_data={"results": ["found"]},
|
||||
duration_ms=100,
|
||||
tokens_used=50,
|
||||
error=None,
|
||||
)
|
||||
d = step.to_dict()
|
||||
assert d["step"] == 1
|
||||
assert d["action"] == "tool_call"
|
||||
assert d["tool_name"] == "search"
|
||||
assert d["input_data"] == {"query": "test"}
|
||||
assert d["output_data"] == {"results": ["found"]}
|
||||
assert d["duration_ms"] == 100
|
||||
assert d["tokens_used"] == 50
|
||||
assert "error" not in d
|
||||
|
||||
def test_to_dict_omits_none_fields(self):
|
||||
step = TraceStep(step=1, action="llm_call", duration_ms=50, tokens_used=30)
|
||||
d = step.to_dict()
|
||||
assert "tool_name" not in d
|
||||
assert "input_data" not in d
|
||||
assert "output_data" not in d
|
||||
assert "error" not in d
|
||||
|
||||
def test_to_dict_includes_error_when_present(self):
|
||||
step = TraceStep(step=1, action="tool_call", error="Tool not found")
|
||||
d = step.to_dict()
|
||||
assert d["error"] == "Tool not found"
|
||||
|
||||
|
||||
# ── ExecutionTrace Tests ─────────────────────────────────
|
||||
|
||||
|
||||
class TestExecutionTrace:
|
||||
"""ExecutionTrace dataclass tests"""
|
||||
|
||||
def test_to_dict(self):
|
||||
trace = ExecutionTrace(
|
||||
task_id="t1",
|
||||
agent_name="agent1",
|
||||
skill_name="search_skill",
|
||||
steps=[
|
||||
TraceStep(step=1, action="llm_call", duration_ms=50, tokens_used=30),
|
||||
TraceStep(step=1, action="tool_call", tool_name="search", duration_ms=100, tokens_used=0),
|
||||
],
|
||||
total_duration_ms=150,
|
||||
total_tokens=30,
|
||||
outcome="success",
|
||||
quality_score=0.9,
|
||||
)
|
||||
d = trace.to_dict()
|
||||
assert d["task_id"] == "t1"
|
||||
assert d["agent_name"] == "agent1"
|
||||
assert d["skill_name"] == "search_skill"
|
||||
assert len(d["steps"]) == 2
|
||||
assert d["total_duration_ms"] == 150
|
||||
assert d["total_tokens"] == 30
|
||||
assert d["outcome"] == "success"
|
||||
assert d["quality_score"] == 0.9
|
||||
|
||||
|
||||
# ── TraceRecorder Happy Path Tests ───────────────────────
|
||||
|
||||
|
||||
class TestTraceRecorderHappyPath:
|
||||
"""TraceRecorder happy-path tests"""
|
||||
|
||||
def test_start_record_end_returns_trace(self):
|
||||
recorder = TraceRecorder()
|
||||
recorder.start_trace(task_id="t1", agent_name="agent1")
|
||||
recorder.record_step(
|
||||
step=1,
|
||||
action="llm_call",
|
||||
duration_ms=50,
|
||||
tokens_used=30,
|
||||
)
|
||||
recorder.record_step(
|
||||
step=1,
|
||||
action="tool_call",
|
||||
tool_name="search",
|
||||
input_data={"query": "test"},
|
||||
output_data={"results": ["found"]},
|
||||
duration_ms=100,
|
||||
)
|
||||
trace = recorder.end_trace(outcome="success", quality_score=0.9)
|
||||
|
||||
assert isinstance(trace, ExecutionTrace)
|
||||
assert trace.task_id == "t1"
|
||||
assert trace.agent_name == "agent1"
|
||||
assert trace.outcome == "success"
|
||||
assert trace.quality_score == 0.9
|
||||
assert len(trace.steps) == 2
|
||||
assert trace.steps[0].action == "llm_call"
|
||||
assert trace.steps[1].action == "tool_call"
|
||||
assert trace.steps[1].tool_name == "search"
|
||||
|
||||
def test_multiple_steps_recorded_in_order(self):
|
||||
recorder = TraceRecorder()
|
||||
recorder.start_trace(task_id="t2", agent_name="agent2")
|
||||
recorder.record_step(step=1, action="llm_call", tokens_used=100)
|
||||
recorder.record_step(step=1, action="tool_call", tool_name="calc", tokens_used=0)
|
||||
recorder.record_step(step=2, action="llm_call", tokens_used=80)
|
||||
recorder.record_step(step=2, action="final_answer", tokens_used=0)
|
||||
trace = recorder.end_trace()
|
||||
|
||||
assert len(trace.steps) == 4
|
||||
assert trace.steps[0].action == "llm_call"
|
||||
assert trace.steps[1].action == "tool_call"
|
||||
assert trace.steps[2].action == "llm_call"
|
||||
assert trace.steps[3].action == "final_answer"
|
||||
assert trace.total_tokens == 180 # 100 + 0 + 80 + 0
|
||||
|
||||
def test_total_duration_calculated(self):
|
||||
recorder = TraceRecorder()
|
||||
recorder.start_trace(task_id="t3", agent_name="agent3")
|
||||
recorder.record_step(step=1, action="llm_call", duration_ms=50)
|
||||
recorder.record_step(step=1, action="tool_call", duration_ms=100)
|
||||
trace = recorder.end_trace()
|
||||
|
||||
# total_duration_ms should reflect actual elapsed time (>= 0)
|
||||
assert trace.total_duration_ms >= 0
|
||||
|
||||
def test_constructor_with_params_auto_starts(self):
|
||||
recorder = TraceRecorder(task_id="t4", agent_name="agent4", skill_name="skill1")
|
||||
recorder.record_step(step=1, action="llm_call", duration_ms=10)
|
||||
trace = recorder.end_trace()
|
||||
|
||||
assert trace.task_id == "t4"
|
||||
assert trace.agent_name == "agent4"
|
||||
assert trace.skill_name == "skill1"
|
||||
assert len(trace.steps) == 1
|
||||
|
||||
def test_start_trace_generates_uuid_when_no_task_id(self):
|
||||
recorder = TraceRecorder()
|
||||
recorder.start_trace(agent_name="agent5")
|
||||
trace = recorder.end_trace()
|
||||
|
||||
assert trace.task_id  # should be populated (a generated UUID)
|
||||
assert len(trace.task_id) > 0
|
||||
|
||||
|
||||
# ── TraceRecorder Edge Case Tests ────────────────────────
|
||||
|
||||
|
||||
class TestTraceRecorderEdgeCases:
|
||||
"""TraceRecorder edge-case tests"""
|
||||
|
||||
def test_end_trace_without_start_returns_default(self):
|
||||
recorder = TraceRecorder()
|
||||
trace = recorder.end_trace(outcome="failure")
|
||||
|
||||
assert isinstance(trace, ExecutionTrace)
|
||||
assert trace.task_id == "unknown"
|
||||
assert trace.agent_name == ""
|
||||
assert trace.outcome == "failure"
|
||||
assert len(trace.steps) == 0
|
||||
|
||||
def test_get_trace_returns_trace_after_start(self):
|
||||
recorder = TraceRecorder()
|
||||
recorder.start_trace(task_id="t1", agent_name="a1")
|
||||
trace = recorder.get_trace()
|
||||
|
||||
assert trace is not None
|
||||
assert trace.task_id == "t1"
|
||||
|
||||
def test_get_trace_returns_none_before_start(self):
|
||||
recorder = TraceRecorder()
|
||||
trace = recorder.get_trace()
|
||||
|
||||
assert trace is None
|
||||
|
||||
def test_record_step_without_start_does_nothing(self):
|
||||
recorder = TraceRecorder()
|
||||
# Should not raise
|
||||
recorder.record_step(step=1, action="llm_call")
|
||||
trace = recorder.end_trace()
|
||||
assert len(trace.steps) == 0
|
||||
|
||||
def test_elapsed_ms_without_timer_returns_zero(self):
|
||||
recorder = TraceRecorder()
|
||||
assert recorder.elapsed_ms() == 0
|
||||
|
||||
def test_start_step_timer_and_elapsed_ms(self):
|
||||
recorder = TraceRecorder()
|
||||
recorder.start_step_timer()
|
||||
time.sleep(0.01) # 10ms
|
||||
elapsed = recorder.elapsed_ms()
|
||||
assert elapsed >= 8  # at least 8 ms (allowing for timer precision)
|
||||
|
||||
|
||||
# ── Integration: TraceRecorder with ReActEngine ──────────
|
||||
|
||||
|
||||
class TestTraceRecorderWithReActEngine:
|
||||
"""TraceRecorder and ReActEngine integration tests"""
|
||||
|
||||
async def test_single_step_with_recorder(self):
|
||||
from agentkit.core.react import ReActEngine
|
||||
|
||||
gateway = make_mock_gateway([
|
||||
make_response(content="The answer is 42"),
|
||||
])
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
recorder = TraceRecorder()
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "What is the answer?"}],
|
||||
trace_recorder=recorder,
|
||||
)
|
||||
|
||||
trace = recorder.get_trace()
|
||||
assert trace is not None
|
||||
assert trace.outcome == "success"
|
||||
assert len(trace.steps) == 1
|
||||
assert trace.steps[0].action == "final_answer"
|
||||
assert trace.steps[0].tokens_used > 0
|
||||
|
||||
async def test_two_step_with_recorder(self):
|
||||
from agentkit.core.react import ReActEngine
|
||||
|
||||
tool = FakeTool(name="calculator", result={"value": 42})
|
||||
gateway = make_mock_gateway([
|
||||
make_response(
|
||||
content="",
|
||||
tool_calls=[ToolCall(id="tc_1", name="calculator", arguments={"expr": "6*7"})],
|
||||
),
|
||||
make_response(content="The result is 42"),
|
||||
])
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
recorder = TraceRecorder()
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Calculate 6*7"}],
|
||||
tools=[tool],
|
||||
trace_recorder=recorder,
|
||||
)
|
||||
|
||||
trace = recorder.get_trace()
|
||||
assert trace is not None
|
||||
assert trace.outcome == "success"
|
||||
# Expected steps: llm_call (step 1) + tool_call (step 1) + final_answer (step 2)
|
||||
# Note: in the final_answer branch, the LLM call and the final answer are merged into a single trace step
|
||||
assert len(trace.steps) == 3
|
||||
assert trace.steps[0].action == "llm_call"
|
||||
assert trace.steps[1].action == "tool_call"
|
||||
assert trace.steps[1].tool_name == "calculator"
|
||||
assert trace.steps[1].input_data == {"expr": "6*7"}
|
||||
assert trace.steps[1].output_data == {"value": 42}
|
||||
assert trace.steps[2].action == "final_answer"
|
||||
|
||||
async def test_max_steps_outcome_is_partial(self):
|
||||
from agentkit.core.react import ReActEngine
|
||||
|
||||
tool = FakeTool(name="search", result={"results": ["data"]})
|
||||
always_tool_response = make_response(
|
||||
content="Thinking...",
|
||||
tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
|
||||
)
|
||||
gateway = make_mock_gateway([always_tool_response] * 20)
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
recorder = TraceRecorder()
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Keep searching"}],
|
||||
tools=[tool],
|
||||
trace_recorder=recorder,
|
||||
)
|
||||
|
||||
trace = recorder.get_trace()
|
||||
assert trace is not None
|
||||
assert trace.outcome == "partial"
|
||||
|
||||
async def test_without_recorder_backward_compatible(self):
|
||||
"""Behavior is unchanged when no trace_recorder is passed"""
|
||||
from agentkit.core.react import ReActEngine
|
||||
|
||||
gateway = make_mock_gateway([
|
||||
make_response(content="Direct answer"),
|
||||
])
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
)
|
||||
|
||||
assert result.output == "Direct answer"
|
||||
assert result.total_steps == 1
|
||||
|
||||
async def test_tool_error_recorded_in_trace(self):
|
||||
from agentkit.core.react import ReActEngine
|
||||
|
||||
gateway = make_mock_gateway([
|
||||
make_response(
|
||||
content="",
|
||||
tool_calls=[ToolCall(id="tc_1", name="nonexistent_tool", arguments={})],
|
||||
),
|
||||
make_response(content="Tool not found, here is my answer"),
|
||||
])
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
recorder = TraceRecorder()
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Use unknown tool"}],
|
||||
tools=[],
|
||||
trace_recorder=recorder,
|
||||
)
|
||||
|
||||
trace = recorder.get_trace()
|
||||
assert trace is not None
|
||||
# Find the tool_call step
|
||||
tool_steps = [s for s in trace.steps if s.action == "tool_call"]
|
||||
assert len(tool_steps) == 1
|
||||
assert tool_steps[0].error is not None
|
||||
|
||||
async def test_trace_total_tokens(self):
|
||||
from agentkit.core.react import ReActEngine
|
||||
|
||||
tool = FakeTool(name="search", result={"results": ["data"]})
|
||||
gateway = make_mock_gateway([
|
||||
make_response(
|
||||
content="",
|
||||
tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "test"})],
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
),
|
||||
make_response(
|
||||
content="Final answer",
|
||||
prompt_tokens=200,
|
||||
completion_tokens=30,
|
||||
),
|
||||
])
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
recorder = TraceRecorder()
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Search"}],
|
||||
tools=[tool],
|
||||
trace_recorder=recorder,
|
||||
)
|
||||
|
||||
trace = recorder.get_trace()
|
||||
assert trace is not None
|
||||
assert trace.total_tokens == 380 # 150 + 230
|
||||
|
||||
async def test_agent_name_and_skill_name_in_trace(self):
|
||||
from agentkit.core.react import ReActEngine
|
||||
|
||||
gateway = make_mock_gateway([
|
||||
make_response(content="Done"),
|
||||
])
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
recorder = TraceRecorder()
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
agent_name="test_agent",
|
||||
task_type="search_task",
|
||||
trace_recorder=recorder,
|
||||
)
|
||||
|
||||
trace = recorder.get_trace()
|
||||
assert trace.agent_name == "test_agent"
|
||||
assert trace.skill_name == "search_task"
|
||||
|
||||
|
||||
# ── Integration: TraceRecorder with execute_stream ───────
|
||||
|
||||
|
||||
class TestTraceRecorderWithExecuteStream:
|
||||
"""TraceRecorder and execute_stream integration tests"""
|
||||
|
||||
async def test_stream_with_recorder(self):
|
||||
from agentkit.core.react import ReActEngine, ReActEvent
|
||||
|
||||
tool = FakeTool(name="search", result={"results": ["data"]})
|
||||
gateway = make_mock_gateway([
|
||||
make_response(
|
||||
content="",
|
||||
tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "test"})],
|
||||
),
|
||||
make_response(content="Final answer"),
|
||||
])
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
recorder = TraceRecorder()
|
||||
|
||||
events = []
|
||||
async for event in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "Search"}],
|
||||
tools=[tool],
|
||||
trace_recorder=recorder,
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
trace = recorder.get_trace()
|
||||
assert trace is not None
|
||||
assert trace.outcome == "success"
|
||||
# llm_call (step 1) + tool_call (step 1) + final_answer (step 2)
|
||||
assert len(trace.steps) == 3
|
||||
|
||||
async def test_stream_without_recorder_backward_compatible(self):
|
||||
from agentkit.core.react import ReActEngine
|
||||
|
||||
gateway = make_mock_gateway([
|
||||
make_response(content="Direct answer"),
|
||||
])
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
|
||||
events = []
|
||||
async for event in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert any(e.event_type == "final_answer" for e in events)
|
||||
|
|
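The `TraceStep` tests above expect `to_dict()` to drop fields that are `None` while keeping falsy-but-meaningful values such as `tokens_used=0`. The real class lives in `agentkit.core.trace` and is not shown in this diff; the following is only a minimal sketch of that behavior, assuming a plain dataclass with the field names and defaults implied by the tests (`SketchTraceStep` is a hypothetical name).

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class SketchTraceStep:
    """Hypothetical stand-in mirroring the fields the tests exercise."""

    step: int
    action: str
    tool_name: Optional[str] = None
    input_data: Optional[Dict[str, Any]] = None
    output_data: Optional[Dict[str, Any]] = None
    duration_ms: int = 0
    tokens_used: int = 0
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        # Filter on `is not None` rather than truthiness so that 0 is preserved.
        return {k: v for k, v in asdict(self).items() if v is not None}


llm_step = SketchTraceStep(step=1, action="llm_call", duration_ms=50, tokens_used=30).to_dict()
failed_step = SketchTraceStep(step=1, action="tool_call", error="Tool not found").to_dict()
```

Filtering on `is not None` instead of truthiness is the detail that matters here: a truthiness check would silently drop `tokens_used=0` from tool-call steps and make serialized traces harder to aggregate.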
@ -1,148 +1,156 @@
|
|||
"""U8 GEO adapter layer integration tests
|
||||
|
||||
Verifies YAML config files, ConfigDrivenAgent creation, Custom Handler routing, and more.
|
||||
Tests run in the fischer-agentkit environment and do not depend on GEO business code.
|
||||
Verifies YAML config file loading, ConfigDrivenAgent creation, Custom Handler routing, and more.
|
||||
Uses agentkit's bundled example_skill.yaml and inline configs; does not depend on GEO project paths.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||
from agentkit.skills.base import SkillConfig
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
CONFIGS_DIR = Path(__file__).parent.parent.parent.parent / "geo" / "backend" / "app" / "agent_framework" / "agents" / "configs"
|
||||
# Use agentkit's own skills directory
|
||||
SKILLS_DIR = Path(__file__).parent.parent.parent / "configs" / "skills"
|
||||
|
||||
|
||||
def _make_llm_generate_config() -> dict:
|
||||
"""Inline config for llm_generate mode agent."""
|
||||
return {
|
||||
"name": "test_llm_agent",
|
||||
"agent_type": "test_llm",
|
||||
"task_mode": "llm_generate",
|
||||
"supported_tasks": ["test_llm_task"],
|
||||
"prompt": {
|
||||
"identity": "You are a test assistant.",
|
||||
"instruction": "Respond to the user's request.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _make_tool_call_config() -> dict:
|
||||
"""Inline config for tool_call mode agent."""
|
||||
return {
|
||||
"name": "test_tool_agent",
|
||||
"agent_type": "test_tool",
|
||||
"task_mode": "tool_call",
|
||||
"supported_tasks": ["test_tool_task"],
|
||||
"tools": ["mock_tool_a", "mock_tool_b"],
|
||||
}
|
||||
|
||||
|
||||
def _make_custom_config() -> dict:
|
||||
"""Inline config for custom mode agent."""
|
||||
return {
|
||||
"name": "test_custom_agent",
|
||||
"agent_type": "test_custom",
|
||||
"task_mode": "custom",
|
||||
"supported_tasks": ["test_custom_task"],
|
||||
"custom_handler": "test.handlers.mock_handler",
|
||||
}
|
||||
|
||||
|
||||
class TestYAMLConfigLoading:
|
||||
"""Test YAML config file loading"""
|
||||
"""Test YAML config file loading (uses inline configs, no GEO dependency)"""
|
||||
|
||||
@pytest.mark.parametrize("yaml_file", [
|
||||
"citation_detector.yaml",
|
||||
"content_generator.yaml",
|
||||
"deai_agent.yaml",
|
||||
"geo_optimizer.yaml",
|
||||
"monitor.yaml",
|
||||
"schema_advisor.yaml",
|
||||
"competitor_analyzer.yaml",
|
||||
"trend_agent.yaml",
|
||||
])
|
||||
def test_yaml_file_exists(self, yaml_file):
|
||||
path = CONFIGS_DIR / yaml_file
|
||||
assert path.exists(), f"Config file {yaml_file} not found at {path}"
|
||||
|
||||
@pytest.mark.parametrize("yaml_file", [
|
||||
"citation_detector.yaml",
|
||||
"content_generator.yaml",
|
||||
"deai_agent.yaml",
|
||||
"geo_optimizer.yaml",
|
||||
"monitor.yaml",
|
||||
"schema_advisor.yaml",
|
||||
"competitor_analyzer.yaml",
|
||||
"trend_agent.yaml",
|
||||
])
|
||||
def test_yaml_valid_structure(self, yaml_file):
|
||||
path = CONFIGS_DIR / yaml_file
|
||||
with open(path) as f:
|
||||
data = yaml.safe_load(f)
|
||||
assert isinstance(data, dict)
|
||||
assert "name" in data
|
||||
assert "agent_type" in data
|
||||
assert "task_mode" in data
|
||||
assert "supported_tasks" in data
|
||||
assert data["task_mode"] in {"llm_generate", "tool_call", "custom"}
|
||||
|
||||
@pytest.mark.parametrize("yaml_file", [
|
||||
"citation_detector.yaml",
|
||||
"content_generator.yaml",
|
||||
"deai_agent.yaml",
|
||||
"geo_optimizer.yaml",
|
||||
"monitor.yaml",
|
||||
"schema_advisor.yaml",
|
||||
"competitor_analyzer.yaml",
|
||||
"trend_agent.yaml",
|
||||
])
|
||||
def test_yaml_to_agent_config(self, yaml_file):
|
||||
path = CONFIGS_DIR / yaml_file
|
||||
config = AgentConfig.from_yaml(str(path))
|
||||
assert config.name
|
||||
assert config.agent_type
|
||||
assert config.task_mode
|
||||
def test_llm_generate_config_structure(self):
|
||||
config = AgentConfig.from_dict(_make_llm_generate_config())
|
||||
assert config.name == "test_llm_agent"
|
||||
assert config.agent_type == "test_llm"
|
||||
assert config.task_mode == "llm_generate"
|
||||
assert len(config.supported_tasks) > 0
|
||||
assert config.prompt is not None
|
||||
|
||||
def test_tool_call_config_structure(self):
|
||||
config = AgentConfig.from_dict(_make_tool_call_config())
|
||||
assert config.name == "test_tool_agent"
|
||||
assert config.task_mode == "tool_call"
|
||||
assert len(config.tools) == 2
|
||||
|
||||
def test_custom_config_structure(self):
|
||||
config = AgentConfig.from_dict(_make_custom_config())
|
||||
assert config.name == "test_custom_agent"
|
||||
assert config.task_mode == "custom"
|
||||
assert config.custom_handler == "test.handlers.mock_handler"
|
||||
|
||||
def test_llm_generate_agents_have_prompt(self):
|
||||
llm_agents = ["content_generator.yaml", "deai_agent.yaml", "geo_optimizer.yaml"]
|
||||
for yaml_file in llm_agents:
|
||||
path = CONFIGS_DIR / yaml_file
|
||||
config = AgentConfig.from_yaml(str(path))
|
||||
assert config.prompt, f"{yaml_file}: llm_generate mode requires prompt"
|
||||
assert "identity" in config.prompt
|
||||
config = AgentConfig.from_dict(_make_llm_generate_config())
|
||||
assert config.prompt, "llm_generate mode requires prompt"
|
||||
assert "identity" in config.prompt
|
||||
|
||||
def test_custom_agents_have_handler(self):
|
||||
custom_agents = ["citation_detector.yaml", "monitor.yaml", "schema_advisor.yaml"]
|
||||
for yaml_file in custom_agents:
|
||||
path = CONFIGS_DIR / yaml_file
|
||||
config = AgentConfig.from_yaml(str(path))
|
||||
assert config.custom_handler, f"{yaml_file}: custom mode requires custom_handler"
|
||||
config = AgentConfig.from_dict(_make_custom_config())
|
||||
assert config.custom_handler, "custom mode requires custom_handler"
|
||||
|
||||
def test_tool_call_agents_have_tools(self):
|
||||
tool_agents = ["competitor_analyzer.yaml", "trend_agent.yaml"]
|
||||
for yaml_file in tool_agents:
|
||||
path = CONFIGS_DIR / yaml_file
|
||||
config = AgentConfig.from_yaml(str(path))
|
||||
assert config.tools, f"{yaml_file}: tool_call mode requires tools list"
|
||||
config = AgentConfig.from_dict(_make_tool_call_config())
|
||||
assert config.tools, "tool_call mode requires tools list"
|
||||
|
||||
def test_example_skill_yaml_if_exists(self):
|
||||
"""Test loading example_skill.yaml if it exists in configs/skills/."""
|
||||
example_path = SKILLS_DIR / "example_skill.yaml"
|
||||
if not example_path.exists():
|
||||
pytest.skip("example_skill.yaml not found in configs/skills/")
|
||||
config = SkillConfig.from_yaml(str(example_path))
|
||||
assert config.name
|
||||
assert config.agent_type
|
||||
|
||||
|
||||
class TestConfigDrivenAgentCreation:
|
||||
"""Test creating a ConfigDrivenAgent from YAML"""
|
||||
"""Test creating a ConfigDrivenAgent from config"""
|
||||
|
||||
def test_create_llm_generate_agent(self):
|
||||
config = AgentConfig.from_yaml(str(CONFIGS_DIR / "content_generator.yaml"))
|
||||
config = AgentConfig.from_dict(_make_llm_generate_config())
|
||||
tool_registry = ToolRegistry()
|
||||
agent = ConfigDrivenAgent(config=config, tool_registry=tool_registry)
|
||||
assert agent.name == "content_generator"
|
||||
assert agent.agent_type == "content_generation"
|
||||
assert agent.name == "test_llm_agent"
|
||||
assert agent.agent_type == "test_llm"
|
||||
assert agent.prompt_template is not None
|
||||
|
||||
def test_create_tool_call_agent(self):
|
||||
config = AgentConfig.from_yaml(str(CONFIGS_DIR / "competitor_analyzer.yaml"))
|
||||
config = AgentConfig.from_dict(_make_tool_call_config())
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
async def mock_analyze(**kwargs):
|
||||
async def mock_func(**kwargs):
|
||||
return {"result": "mock"}
|
||||
|
||||
tool_registry.register(
|
||||
FunctionTool(name="competitor_analyze", description="mock", func=mock_analyze)
|
||||
FunctionTool(name="mock_tool_a", description="mock", func=mock_func)
|
||||
)
|
||||
tool_registry.register(
|
||||
FunctionTool(name="competitor_gap_analysis", description="mock", func=mock_analyze)
|
||||
FunctionTool(name="mock_tool_b", description="mock", func=mock_func)
|
||||
)
|
||||
|
||||
agent = ConfigDrivenAgent(config=config, tool_registry=tool_registry)
|
||||
assert agent.name == "competitor_analyzer"
|
||||
assert agent.name == "test_tool_agent"
|
||||
assert len(agent._tools) == 2
|
||||
|
||||
def test_create_custom_agent(self):
|
||||
config = AgentConfig.from_yaml(str(CONFIGS_DIR / "citation_detector.yaml"))
|
||||
config = AgentConfig.from_dict(_make_custom_config())
|
||||
|
||||
async def mock_handler(task):
|
||||
return {"mock": True}
|
||||
|
||||
custom_handlers = {
|
||||
"app.agent_framework.agents.custom_handlers.citation_handler.handle_citation_task": mock_handler,
|
||||
"test.handlers.mock_handler": mock_handler,
|
||||
}
|
||||
|
||||
agent = ConfigDrivenAgent(config=config, custom_handlers=custom_handlers)
|
||||
assert agent.name == "citation_detector"
|
||||
assert agent.name == "test_custom_agent"
|
||||
|
||||
def test_create_all_8_agents(self):
|
||||
"""Verify that all 8 agents can be created successfully"""
|
||||
for yaml_file in CONFIGS_DIR.glob("*.yaml"):
|
||||
config = AgentConfig.from_yaml(str(yaml_file))
|
||||
def test_create_all_mode_agents(self):
|
||||
"""Verify that agents in all three task modes can be created successfully."""
|
||||
configs = [_make_llm_generate_config(), _make_tool_call_config(), _make_custom_config()]
|
||||
|
||||
for cfg_dict in configs:
|
||||
config = AgentConfig.from_dict(cfg_dict)
|
||||
tool_registry = ToolRegistry()
|
||||
|
||||
# Register mock tools for the tool_call mode
|
||||
|
|
@@ -172,8 +180,8 @@ class TestCustomHandlerRouting:
|
|||
"""Test custom handler routing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_citation_handler_routing(self):
|
||||
config = AgentConfig.from_yaml(str(CONFIGS_DIR / "citation_detector.yaml"))
|
||||
async def test_custom_handler_routing(self):
|
||||
config = AgentConfig.from_dict(_make_custom_config())
|
||||
|
||||
call_log = []
|
||||
|
||||
|
|
@@ -182,14 +190,14 @@ class TestCustomHandlerRouting:
|
|||
return {"mock": True, "task_type": task.task_type}
|
||||
|
||||
custom_handlers = {
|
||||
"app.agent_framework.agents.custom_handlers.citation_handler.handle_citation_task": mock_handler,
|
||||
"test.handlers.mock_handler": mock_handler,
|
||||
}
|
||||
|
||||
agent = ConfigDrivenAgent(config=config, custom_handlers=custom_handlers)
|
||||
task = TaskMessage(
|
||||
task_id="test-1",
|
||||
agent_name="citation_detector",
|
||||
task_type="citation_detect",
|
||||
agent_name="test_custom_agent",
|
||||
task_type="test_custom_task",
|
||||
priority=0,
|
||||
input_data={"query_id": "test-qid"},
|
||||
callback_url=None,
|
||||
|
|
@@ -197,66 +205,65 @@ class TestCustomHandlerRouting:
|
|||
)
|
||||
result = await agent.execute(task)
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
assert "citation_detect" in call_log
|
||||
assert "test_custom_task" in call_log
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_handler_routing(self):
|
||||
config = AgentConfig.from_yaml(str(CONFIGS_DIR / "monitor.yaml"))
|
||||
|
||||
async def mock_handler(task):
|
||||
return {"brand_id": task.input_data.get("brand_id"), "reports": []}
|
||||
class TestSkillConfigV2:
|
||||
"""Test the SkillConfig v2 fields."""
|
||||
|
||||
custom_handlers = {
|
||||
"app.agent_framework.agents.custom_handlers.monitor_handler.handle_monitor_task": mock_handler,
|
||||
def test_skill_config_from_dict(self):
|
||||
data = {
|
||||
"name": "test_skill",
|
||||
"agent_type": "test",
|
||||
"task_mode": "llm_generate",
|
||||
"supported_tasks": ["test"],
|
||||
"prompt": {"identity": "test"},
|
||||
"intent": {
|
||||
"keywords": ["test", "demo"],
|
||||
"description": "A test skill",
|
||||
},
|
||||
"quality_gate": {
|
||||
"required_fields": ["output"],
|
||||
"min_word_count": 10,
|
||||
},
|
||||
"execution_mode": "react",
|
||||
"max_steps": 3,
|
||||
"evolution": {
|
||||
"enabled": True,
|
||||
"reflect_on_failure": False,
|
||||
},
|
||||
}
|
||||
config = SkillConfig.from_dict(data)
|
||||
assert config.name == "test_skill"
|
||||
assert config.intent.keywords == ["test", "demo"]
|
||||
assert config.quality_gate.required_fields == ["output"]
|
||||
assert config.execution_mode == "react"
|
||||
assert config.max_steps == 3
|
||||
assert config.evolution.enabled is True
|
||||
|
||||
agent = ConfigDrivenAgent(config=config, custom_handlers=custom_handlers)
|
||||
task = TaskMessage(
|
||||
task_id="test-2",
|
||||
agent_name="monitor",
|
||||
task_type="monitor_track",
|
||||
priority=0,
|
||||
input_data={"brand_id": "test-brand-id"},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
result = await agent.execute(task)
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_handler_routing(self):
|
||||
config = AgentConfig.from_yaml(str(CONFIGS_DIR / "schema_advisor.yaml"))
|
||||
|
||||
async def mock_handler(task):
|
||||
return {"brand_id": task.input_data.get("brand_id"), "suggestions": [], "total": 0}
|
||||
|
||||
custom_handlers = {
|
||||
"app.agent_framework.agents.custom_handlers.schema_handler.handle_schema_task": mock_handler,
|
||||
def test_skill_config_backward_compatible(self):
|
||||
"""Legacy YAML without v2 fields is auto-filled with default values."""
|
||||
data = {
|
||||
"name": "legacy_skill",
|
||||
"agent_type": "legacy",
|
||||
"task_mode": "llm_generate",
|
||||
"supported_tasks": ["legacy"],
|
||||
"prompt": {"identity": "Legacy skill"},
|
||||
}
|
||||
|
||||
agent = ConfigDrivenAgent(config=config, custom_handlers=custom_handlers)
|
||||
task = TaskMessage(
|
||||
task_id="test-3",
|
||||
agent_name="schema_advisor",
|
||||
task_type="schema_advise",
|
||||
priority=0,
|
||||
input_data={"brand_id": "test-brand-id"},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
result = await agent.execute(task)
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
config = SkillConfig.from_dict(data)
|
||||
assert config.name == "legacy_skill"
|
||||
assert config.intent.keywords == []
|
||||
assert config.quality_gate.required_fields == []
|
||||
assert config.execution_mode == "react" # default
|
||||
assert config.evolution.enabled is False # default
|
||||
|
||||
|
||||
class TestToolRegistration:
|
||||
"""Test tool registration completeness."""
|
||||
|
||||
def test_all_yaml_referenced_tools_registered(self):
|
||||
def test_all_referenced_tools_registered(self):
|
||||
registry = ToolRegistry()
|
||||
all_tool_names = set()
|
||||
for yaml_file in CONFIGS_DIR.glob("*.yaml"):
|
||||
config = AgentConfig.from_yaml(str(yaml_file))
|
||||
all_tool_names.update(config.tools)
|
||||
all_tool_names = {"mock_tool_a", "mock_tool_b", "mock_tool_c"}
|
||||
|
||||
for tool_name in all_tool_names:
|
||||
async def mock_func(**kwargs):
|
||||
|
|
@@ -272,43 +279,22 @@ class TestToolRegistration:
|
|||
class TestAdapterCompatibility:
|
||||
"""Test adapter layer compatibility."""
|
||||
|
||||
def test_yaml_configs_count(self):
|
||||
yaml_files = list(CONFIGS_DIR.glob("*.yaml"))
|
||||
assert len(yaml_files) == 8, f"Expected 8 YAML configs, found {len(yaml_files)}"
|
||||
|
||||
def test_all_agent_names_unique(self):
|
||||
names = []
|
||||
for yaml_file in CONFIGS_DIR.glob("*.yaml"):
|
||||
config = AgentConfig.from_yaml(str(yaml_file))
|
||||
names.append(config.name)
|
||||
configs = [_make_llm_generate_config(), _make_tool_call_config(), _make_custom_config()]
|
||||
names = [AgentConfig.from_dict(c).name for c in configs]
|
||||
assert len(names) == len(set(names)), f"Duplicate agent names: {names}"
|
||||
|
||||
def test_all_agent_types_unique(self):
|
||||
types = []
|
||||
for yaml_file in CONFIGS_DIR.glob("*.yaml"):
|
||||
config = AgentConfig.from_yaml(str(yaml_file))
|
||||
types.append(config.agent_type)
|
||||
configs = [_make_llm_generate_config(), _make_tool_call_config(), _make_custom_config()]
|
||||
types = [AgentConfig.from_dict(c).agent_type for c in configs]
|
||||
assert len(types) == len(set(types)), f"Duplicate agent types: {types}"
|
||||
|
||||
def test_supported_tasks_no_overlap(self):
|
||||
configs = [_make_llm_generate_config(), _make_tool_call_config(), _make_custom_config()]
|
||||
all_tasks = {}
|
||||
for yaml_file in CONFIGS_DIR.glob("*.yaml"):
|
||||
config = AgentConfig.from_yaml(str(yaml_file))
|
||||
for cfg_dict in configs:
|
||||
config = AgentConfig.from_dict(cfg_dict)
|
||||
for task in config.supported_tasks:
|
||||
if task in all_tasks:
|
||||
assert False, f"Task '{task}' defined in both '{all_tasks[task]}' and '{config.name}'"
|
||||
all_tasks[task] = config.name
|
||||
|
||||
def test_migration_script_exists(self):
|
||||
migration_path = (
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "geo" / "backend" / "alembic" / "versions" / "b001_agentkit_extension.py"
|
||||
)
|
||||
assert migration_path.exists(), "Migration script not found"
|
||||
|
||||
def test_adapter_module_exists(self):
|
||||
adapter_path = (
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "geo" / "backend" / "app" / "agent_framework" / "adapter.py"
|
||||
)
|
||||
assert adapter_path.exists(), "adapter.py not found"
|
||||
|
|
|
|||