feat(router): optimize routing intelligence — ExecutionMode expansion, multi-candidate scoring, quality gate skill match

- Expand ExecutionMode enum with REWOO/REFLEXION/PLAN_EXEC
- Add _resolve_execution_mode() to respect skill.config.execution_mode
- Rewrite IntentRouter._match_keywords() for multi-candidate scoring
- Add QualityGate 5th dimension: skill_match validation with warning escalation
- Calibrate HeuristicClassifier: low-complexity signals only when no high signals
- Fix negation regex for Chinese text (avoid matching past punctuation)
- Fix backtest mode_map normalization and .env loading
- Add 61 unit tests (21 HeuristicClassifier + 14 IntentRouter + 13 QualityGate + 13 existing)

Results: execution_mode_accuracy 9.09%→36.36%, skill_routing_F1 66.67%→77.78%
parent 64d62a2b60
commit e984b4c462

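@@ -0,0 +1,280 @@
---
title: "feat: E2E capability analysis framework improvements and smarter routing"
type: feat
status: active
created: 2026-06-15
plan-depth: standard
---

# E2E Capability Analysis Framework Improvements and Smarter Routing

## Summary

Improve the E2E capability analysis framework to fix its core problems: benchmark data that does not correspond to the actual skills, narrow coverage (only 19 cases), and oversimplified metric judgments. In addition, integrate ExpertTeamRouter into the CostAwareRouter auto-trigger path, add a direct router backtest layer, and expand the benchmark to 60 cases so that recall, F1, and overfitting-detection metrics become statistically meaningful.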
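
## Problem Frame

The current E2E capability analysis framework has four key problems:

1. **Benchmark data is out of sync with the actual skills**: the `expected_skill` values in `benchmark_dataset.py` (such as `email_composer` and `i18n_translator`) do not correspond to the 15 actual skills in `configs/skills/`, so routing backtest results are meaningless
2. **Coverage is too narrow**: there are only 19 benchmark cases, so precision/recall/F1 (PRF) statistics are unstable; dedicated benchmarks for SemanticRouter, ExpertTeamRouter, and AlignmentGuard are missing
3. **Metric judgments are coarse**: `complexity_correct` is set directly equal to `execution_mode_correct`, so complexity estimation cannot be evaluated independently; `target_module` in the improvement strategies refers to an old file name
4. **Team routing is not integrated automatically**: `ExpertTeamRouter` and `CostAwareRouter` run independently, so the `TEAM_COLLAB` mode cannot be triggered automatically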
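
## Requirements

- R1: Every `expected_skill` in the benchmark dataset must correspond one-to-one with an actual skill in `configs/skills/`
- R2: Expand the benchmark to 60 cases covering five dimensions: routing, execution, team, consistency, and alignment guard
- R3: Add a direct router backtest layer (bypassing the HTTP API) that can distinguish routing errors from API-layer errors
- R4: `complexity_correct` is independent of `execution_mode_correct` and is judged by mapping the HeuristicClassifier score to the expected complexity level
- R5: Integrate ExpertTeamRouter into CostAwareRouter.route() so that high-complexity tasks trigger TEAM_COLLAB automatically
- R6: Add a dedicated SemanticRouter benchmark (similarity score distribution, precision for each of the three tiers)
- R7: Add an AlignmentGuard constraint-check benchmark
- R8: Correct the target_module file paths in the improvement strategies
- R9: Report output stays in Chinese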
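
## Key Technical Decisions

### KTD1: Two-layer backtest architecture

**Decision**: Add a direct router backtest layer on top of the existing HTTP API-level E2E tests.

**Rationale**: Pure API tests cannot tell apart two failure modes: "the router picked the wrong skill" and "the API layer passed parameters incorrectly". The direct backtest layer calls `CostAwareRouter.route()` and records the full set of `SkillRoutingResult` fields (`match_method`, `match_confidence`, `execution_trace`), so root-cause analysis can pinpoint the specific routing layer.

**Alternative**: Keep API-only tests → rejected, because it cannot meet the precise-diagnosis need of R3.

### KTD2: How ExpertTeamRouter is integrated

**Decision**: Add an ExpertTeamRouter checkpoint at the end of `CostAwareRouter._route_layer2()`. When Layer 2 decides `execution_mode=REACT` and `complexity >= 0.7`, call `ExpertTeamRouter.resolve()` to decide whether to upgrade to `TEAM_COLLAB`.

**Rationale**: This keeps the progressive three-layer routing architecture unchanged and adds team-mode upgrade logic only at the Layer 2 exit, minimizing intrusion into the existing routing flow.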
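
The upgrade checkpoint described in KTD2 could be sketched roughly as follows. This is only an illustration under stated assumptions: `RoutingDecision`, `StubTeamRouter`, and `maybe_upgrade_to_team` are hypothetical names that do not appear in the plan, and the sketch assumes `resolve()` returns `None` when no expert team applies.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

COMPLEXITY_THRESHOLD = 0.7  # assumed to equal the 0.7 boundary named in KTD2


class ExecutionMode(Enum):
    REACT = "react"
    TEAM_COLLAB = "team_collab"


@dataclass
class RoutingDecision:
    """Hypothetical stand-in for the Layer 2 routing result."""
    execution_mode: ExecutionMode
    complexity: float
    execution_trace: dict = field(default_factory=dict)


class StubTeamRouter:
    """Hypothetical stand-in for ExpertTeamRouter (assumed resolve() contract)."""

    def __init__(self, has_template: bool) -> None:
        self.has_template = has_template

    def resolve(self, content: str, complexity: float) -> Optional[dict]:
        # Assumption: None means "no matching expert team".
        return {"team": "demo"} if self.has_template else None


def maybe_upgrade_to_team(decision, content, team_router):
    """Upgrade REACT to TEAM_COLLAB only at the Layer 2 exit."""
    if team_router is None:  # router not configured: skip silently
        return decision
    if decision.execution_mode is not ExecutionMode.REACT:
        return decision
    if decision.complexity < COMPLEXITY_THRESHOLD:
        return decision
    if team_router.resolve(content, decision.complexity) is None:
        return decision
    decision.execution_mode = ExecutionMode.TEAM_COLLAB
    decision.execution_trace["team_upgrade"] = True
    return decision
```

Keeping the check in one helper at the Layer 2 exit leaves the earlier layers untouched, which matches the minimal-intrusion rationale.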
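
### KTD3: Strategy for judging complexity correctness

**Decision**: Judge by mapping the floating-point complexity score returned by HeuristicClassifier onto the interval of the expected complexity level: `low=[0, 0.3)`, `medium=[0.3, 0.7)`, `high=[0.7, 1.0]`.

**Rationale**: Using the floating-point score directly is more precise than comparing only execution modes. It can distinguish "a score of 0.29 is classified as low but medium was expected" from "a score of 0.65 is classified as medium and medium was expected".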
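
A minimal sketch of the interval mapping in KTD3, assuming scores are already clamped to `[0, 1]`. The function names `complexity_level` and `complexity_correct` are illustrative and not taken from the codebase.

```python
def complexity_level(score: float) -> str:
    """Map a complexity score to low / medium / high per the KTD3 intervals."""
    if not 0.0 <= score <= 1.0:
        raise ValueError(f"score out of range: {score}")
    if score < 0.3:
        return "low"      # [0, 0.3)
    if score < 0.7:
        return "medium"   # [0.3, 0.7)
    return "high"         # [0.7, 1.0]


def complexity_correct(score: float, expected_level: str) -> bool:
    """True when the score falls in the interval of the expected level."""
    return complexity_level(score) == expected_level
```

With this mapping, a score of 0.29 against an expected `medium` is judged incorrect, while 0.65 against `medium` is judged correct, regardless of which execution mode was chosen.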
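
### KTD4: Aligning benchmark cases with actual skills

**Decision**: Extract `intent.keywords` and `intent.description` from the 15 actual skills in `configs/skills/` and generate each benchmark case's `expected_skill` automatically instead of hard-coding it by hand.

**Rationale**: Hand-maintained skill names drift away from the actual configuration easily (the current problem). Automatic alignment ensures that the benchmark data always reflects the latest skill configuration.

---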

## Implementation Units

### U1. Align the benchmark dataset with actual skills

**Goal**: Fix the correspondence between expected_skill in benchmark_dataset.py and the actual skills, and expand to 60 cases

**Dependencies**: None

**Files**:
- `tests/e2e/benchmark_dataset.py`: rewrite the benchmark dataset
- `tests/e2e/benchmark_generator.py`: new; generates benchmark cases automatically from skill configs

**Approach**:
1. Add a `BenchmarkGenerator` class that reads `configs/skills/*.yaml`, extracts each skill's `intent.keywords`, `intent.description`, and `intent.examples`, and generates `BenchmarkCase` objects automatically
2. Generate 3-5 benchmark cases per skill: 1 original input plus 2-4 paraphrases
3. Keep the hand-defined edge cases (greetings, identity questions, no-match fallback)
4. Add new dimensions: `alignment` (alignment guard) and `semantic_router` (dedicated semantic routing)
5. Overall targets: routing 20+, execution 15+, team 10+, consistency 10+, alignment guard 5+

**Patterns to follow**: the `BenchmarkCase` Pydantic frozen model pattern

**Test scenarios**:
- Every expected_skill in the generated benchmark cases exists in configs/skills/
- Every skill has at least 1 benchmark case
- More than 60% of cases have non-empty paraphrases
- Total number of cases >= 60

**Verification**: Run `python -c "from tests.e2e.benchmark_dataset import ALL_BENCHMARKS; print(len(ALL_BENCHMARKS))"` and confirm the result is >= 60

### U2. Direct router backtest layer

**Goal**: Add a direct router backtest that bypasses the HTTP API and records the full routing result

**Dependencies**: U1

**Files**:
- `tests/e2e/test_capability_router_direct.py`: new; direct router backtest
- `tests/e2e/conftest.py`: add a router fixture

**Approach**:
1. Add a `cost_aware_router` fixture in conftest.py that instantiates `CostAwareRouter` directly (using MockLLMProvider)
2. Add `test_capability_router_direct.py`, which calls `router.route(query)` for each benchmark case and records the full `SkillRoutingResult`
3. Recorded fields: `skill_name`, `execution_mode`, `complexity`, `match_method` (layer0/layer1/layer1.5/layer2), `match_confidence`, `execution_trace`
4. Also write the router backtest results to MetricsCollector, adding the `match_method` and `match_confidence` fields

**Patterns to follow**: the parametrized test pattern in the existing `test_capability_routing.py`

**Test scenarios**:
- Layer 0 rule matching: greeting → DIRECT_CHAT, @skill:xxx → the corresponding skill
- Layer 1 complexity classification: simple Q&A → low, multi-step analysis → high
- Layer 1.5 semantic routing: paraphrase → same skill, similarity > 0.6
- Layer 2 capability matching: high complexity → REACT/TEAM_COLLAB
- Agreement between router backtest and API backtest results > 90%

**Verification**: `pytest tests/e2e/test_capability_router_direct.py -v` passes in full

### U3. Metrics enhancements

**Goal**: Fix the complexity_correct logic, add semantic routing and team routing metrics, and correct the target_module paths

**Dependencies**: U1

**Files**:
- `tests/e2e/capability_metrics.py`: enhance the metric models and the analyzer
- `tests/e2e/benchmark_dataset.py`: add the semantic_router / alignment categories

**Approach**:
1. Add the fields `actual_complexity_score: float | None`, `actual_match_method: str | None`, and `actual_match_confidence: float | None` to `CapabilityObservation`
2. Change `complexity_correct` to use the score-interval mapping (KTD3)
3. Add an `analyze_semantic_router()` method to `MetricsAnalyzer`: precision for each of the high/medium/low tiers
4. Add an `analyze_team_routing()` method to `MetricsAnalyzer`: success rate of `explicit_team` vs `complexity_suggestion`
5. Correct every `target_module` in `plan_improvements()`: `cost_aware_router.py` → `chat/skill_routing.py`
6. Add "Semantic routing analysis" and "Team routing analysis" sections to the report

**Patterns to follow**: the analysis-method pattern of the existing `MetricsAnalyzer`

**Test scenarios**:
- complexity_correct is independent of execution_mode_correct
- Three-tier semantic routing precision is computed correctly
- Team routing success rate is computed correctly
- target_module paths correspond to the actual code
- The Chinese report output contains the new sections

**Verification**: `pytest tests/e2e/test_capability_routing.py tests/e2e/test_capability_react.py -v` passes

### U4. Integrate ExpertTeamRouter into CostAwareRouter

**Goal**: High-complexity tasks trigger the TEAM_COLLAB mode automatically

**Dependencies**: U2

**Files**:
- `src/agentkit/chat/skill_routing.py`: modify `_route_layer2()` to add the team-upgrade logic
- `src/agentkit/experts/router.py`: add a `can_handle()` method for the router to query
- `tests/unit/chat/test_skill_routing.py`: add team routing unit tests

**Approach**:
1. At the end of `CostAwareRouter._route_layer2()`, when `execution_mode == REACT` and `complexity >= COMPLEXITY_THRESHOLD`, call `ExpertTeamRouter.resolve(content, complexity)`
2. If `ExpertTeamRouter` returns a valid result, upgrade `execution_mode` to `TEAM_COLLAB` and record `"team_upgrade": True` in `execution_trace`
3. Add a `can_handle(content: str) -> bool` method to `ExpertTeamRouter` that checks whether a matching expert template exists
4. Stay backward compatible: if `ExpertTeamRouter` is unavailable (no expert templates configured), skip silently

**Patterns to follow**: the Vickrey auction path pattern in the existing `_route_layer2()`

**Test scenarios**:
- High complexity + expert template available → TEAM_COLLAB
- High complexity + no expert template → stays REACT
- Low complexity → team routing is not triggered
- @team prefix → TEAM_COLLAB directly (handled by Layer 0)
- execution_trace contains the team_upgrade flag

**Verification**: `pytest tests/unit/chat/test_skill_routing.py -v -k team` passes

### U5. AlignmentGuard and CascadeDetector metrics integration

**Goal**: Include alignment guard constraint violations and cascade alerts in E2E metric collection

**Dependencies**: U3

**Files**:
- `tests/e2e/test_capability_alignment.py`: new; alignment guard benchmark tests
- `tests/e2e/capability_metrics.py`: add alignment-dimension metrics

**Approach**:
1. Add `test_capability_alignment.py` with 5+ alignment guard benchmark cases:
   - Negative constraint test ("do not mention the price" → output contains no price)
   - Positive constraint test ("must include a summary" → output contains a summary)
   - Cascade alert test (5 consecutive similar queries → CascadeAlert is triggered)
2. Add the fields `alignment_violations: int` and `cascade_alert: bool` to `CapabilityObservation`
3. Add an `analyze_alignment()` method to `MetricsAnalyzer`
4. Add an "Alignment guard analysis" section to the report

**Patterns to follow**: the test pattern of the existing `test_capability_team.py`

**Test scenarios**:
- Negative constraint: output does not contain the forbidden content
- Positive constraint: output contains the required content
- Cascade alert: consecutive interactions trigger the alert
- No constraint: passes normally

**Verification**: `pytest tests/e2e/test_capability_alignment.py -v` passes

### U6. Run script and CI integration

**Goal**: Update the run script to support layered backtests and CI integration

**Dependencies**: U2, U3, U4, U5

**Files**:
- `scripts/run_e2e.sh`: add direct-backtest and layered run options
- `tests/e2e/conftest.py`: make sure the pytest_sessionfinish report is generated correctly

**Approach**:
1. Add a `--direct` option to `run_e2e.sh` (run only the direct router backtest)
2. Add an `--alignment` option to `run_e2e.sh` (run only the alignment guard tests)
3. Add a `--full` option to `run_e2e.sh` (run everything: API + direct + alignment)
4. Make sure the report output directory `test-results/e2e/` is uploaded as an artifact in CI
5. Add a `--baseline` option: compare with the previous report and print the metric trends

**Patterns to follow**: the option pattern of the existing `run_e2e.sh`

**Test scenarios**:
- `--direct` runs only the direct router backtest
- `--alignment` runs only the alignment guard tests
- `--full` runs all capability tests
- `--analyze` generates the full Chinese report
- Report files are saved correctly to test-results/e2e/

**Verification**: Run `./scripts/run_e2e.sh --direct` and `./scripts/run_e2e.sh --analyze` to verify

---
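
## Scope Boundaries

### In Scope
- Aligning the benchmark dataset with actual skills and expanding it to 60 cases
- Direct router backtest layer
- Metrics enhancements (complexity, semantic routing, team routing)
- Integrating ExpertTeamRouter into CostAwareRouter
- AlignmentGuard metrics integration
- Run script updates

### Out of Scope
- Rewriting the three-layer CostAwareRouter architecture
- Adding new LLM providers
- Front-end changes
- Production deployment
- Embedding intent.examples into SemanticRouter (a possible follow-up optimization)
- The disambiguation_keywords config field (already planned in the improvement strategies, but a separate improvement at the skill-config level)

### Deferred to Follow-Up Work
- Continuously extending the benchmark cases from real user query logs
- Training a complexity-estimation model (to replace the heuristic rules)
- GitHub Actions configuration for the intent-generalization CI guard
- Correlation analysis between OutputStandardizer.quality_score and routing decisions

---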
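
## Risks & Mitigations

| Risk | Impact | Mitigation |
|------|------|----------|
| ExpertTeamRouter integration may affect existing routing performance | Layer 2 makes one additional resolve() call | Trigger only when complexity >= 0.7, and can_handle() returns quickly |
| Automatically generated benchmark cases may be low quality | PRF metrics are distorted | Review generated cases by hand and keep the hand-written edge cases |
| The direct router backtest needs full MockLLMProvider support | Some routing paths cannot be tested | Cover Layer 0/1 first; mark Layer 1.5/2 as requiring a real LLM |
| 60 cases may increase E2E run time | The CI pipeline slows down | Run in groups by dimension and support the `--fast` fail-fast mode |

---

## System-Wide Impact

- **Routing layer**: `skill_routing.py` gains an ExpertTeamRouter call site, which affects the routing decision for every high-complexity request
- **Test layer**: 3 new test files, 2 new fixtures in conftest.py, 4 new run-script options
- **Report layer**: the capability analysis report gains 3 sections (semantic routing, team routing, alignment guard)
- **Config layer**: no config file changes (disambiguation_keywords is deferred to follow-up work)
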
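@@ -0,0 +1,326 @@
---
title: "feat: Routing intelligence optimization: complexity calibration, intent disambiguation, quality gate enhancement"
status: active
created: 2026-06-15
updated: 2026-06-15
origin: test-results/e2e/capability_report.txt (real-LLM backtest analysis report)
---

## Summary

Based on the three core root causes exposed by the real-LLM backtest analysis report, improve the routing intelligence of CostAwareRouter: fix the HeuristicClassifier complexity-scoring bias (raising execution mode accuracy from 9.09% to >30%), resolve the skill confusion caused by IntentRouter's first-match behavior (raising skill routing F1 from 66.67% to >80%), and strengthen QualityGate's skill-match validation so that it catches wrong routes.

**Current progress**: U1 code is implemented and still needs unit tests; U2/U3 are not yet implemented; U4 is awaiting verification.

---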
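
## Problem Frame

The real-LLM backtest (74 observations) reveals three core problems:

1. **Execution mode accuracy is 9.09%**: HeuristicClassifier tends to overestimate complexity, classifying simple Q&A (such as "你好" (hello) and "你是谁" (who are you)) as needing REACT rather than DIRECT_CHAT. Of the 40 execution-mode errors, only 1 underestimated complexity.
2. **keyword_match recall is 0%**: none of the 62 keyword-match cases were routed to the expected skill. The real SkillRegistry loaded 15 skills, but the routing path failed to match them correctly.
3. **Intent ambiguity**: the keywords of plan_exec_agent and goal_driven_agent overlap (the substrings "规划" (plan) and "报告" (report)), and IntentRouter's first-match strategy confuses the two.

---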
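
## Requirements

- R1: HeuristicClassifier complexity score calibration: simple Q&A should score low (<0.3) and complex tasks should score high (>0.7)
- R2: IntentRouter multi-candidate scoring: when several skills match, rank them by score and pick the best instead of the first match
- R3: QualityGate skill-match validation: intercept output whose routing result is inconsistent with the skill's capabilities
- R4: Backtest verification: after the improvements, execution mode accuracy >30% and skill routing F1 >80%

---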
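
## Key Technical Decisions

### KTD1: HeuristicClassifier scoring rework: add low-complexity signals

**Decision**: In addition to the existing high/medium-complexity keywords, add a low-complexity keyword list and a negation-signal mechanism. When the input contains a low-complexity signal (greeting, small talk, simple definition), lower the base score directly; when a high-complexity word appears in a negated context ("不要X" (do not X), "无需X" (no need to X)), do not add to the score.

**Rationale**: The current classifier only accumulates positively (a high-complexity hit adds points) and has no negative deduction. As a result, any input containing common verbs such as "分析" (analyze) or "搜索" (search) is classified as high complexity even when it is really simple Q&A.

**Alternative**: Replace the rule classifier with an LLM. This has high latency (~500 ms) and high cost (~100 tokens), and the current merged_llm_classify already uses an LLM in the 0.3-0.7 range, so the rule layer should stay zero-cost.

**Implementation status**: Code complete. The `classify()` method has been rewritten and includes low-complexity-signal-first detection, negated-context exclusion, threshold adjustments (0.15→0.10, 0.45→0.35), and a short-question deduction.

### KTD2: IntentRouter multi-candidate scoring

**Decision**: Change `_match_keywords()` from "return on first match" to "collect all matching candidates, rank them by number of matched keywords × keyword length, and return the best match".

**Rationale**: First match depends on the iteration order of the skills list, which is neither controllable nor fair. Multi-candidate scoring lets the skill that matches more, and more precise, keywords win. For example, the input "规划一个调研报告" (plan a research report) matches both plan_exec_agent ("规划", "报告") and goal_driven_agent ("规划", "调研"), but goal_driven_agent also matches "报告", a substring of "生成报告". When the match counts are equal, candidates are ranked by keyword length, and longer keywords ("调研报告" > "报告") carry more weight.

**Alternative**: Add mutually exclusive keywords to the skill configs. This requires pairwise configuration, is costly to maintain, and cannot cover every overlap.

**Implementation status**: Not yet implemented. `_match_keywords()` still uses first-match logic (`intent.py` L89-98).

### KTD3: QualityGate skill-match validation: a lightweight routing consistency check

**Decision**: Add an optional `skill_context` parameter to QualityGate.validate(). When it is provided, check whether the output content is consistent with the capability scope of the routed skill. Use a rule check (keyword coverage) rather than an LLM semantic check, so that no extra cost is incurred.

**Rationale**: QualityGate currently checks only the output format (required fields, word count, schema) and not whether the output content matches the routed skill. Three cases succeeded at the HTTP level but were routed to the wrong skill, and the quality gate failed to intercept them.

**Implementation status**: Not yet implemented. `validate()` currently has only four check dimensions (`gate.py` L37-114).

---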
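
## Scope Boundaries

### In Scope
- HeuristicClassifier scoring logic optimization (code complete, tests pending)
- IntentRouter._match_keywords() multi-candidate scoring
- A skill-match validation dimension in QualityGate
- Updating the backtest benchmark dataset to reflect the new scoring logic
- Re-running the backtest after the improvements

### Out of Scope
- LLM classifier optimization (merged_llm_classify and _classify_with_llm are already implemented and are not part of this work)
- SemanticRouter optimization (needs an embedding model; a separate optimization track)
- Injecting ExpertTeamRouter at server startup (implemented but not wired into create_app; a deployment configuration issue)
- Adding new skill config files

### Deferred to Follow-Up Work
- Training a dedicated intent classification model to replace rule matching (long-term direction)
- Building a complexity calibration dataset to keep tuning the thresholds
- Implementing an automatic quality-regression CI pipeline

---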

## Implementation Units

### U1. HeuristicClassifier complexity score calibration

**Goal**: Fix the complexity-scoring bias so that simple Q&A scores low and complex tasks score high, improving execution mode accuracy

**Requirements**: R1, R4

**Dependencies**: None

**Files:**
- `src/agentkit/chat/skill_routing.py`: the HeuristicClassifier class (**code complete**)
- `tests/unit/chat/test_skill_routing.py`: new complexity calibration tests (**to be written**)

**Approach:**

The code already implements the following changes:

1. Added the low-complexity keyword lists `_LOW_COMPLEXITY_HINTS_CN` (17 words) and `_LOW_COMPLEXITY_HINTS_EN` (14 words). On a hit, the base score is 0.05 and high-complexity word scores are no longer accumulated.

2. Added negated-context detection, `_NEGATION_PATTERNS`, which matches the word following "不要/无需/不用/don't/no need/without"; that word is not counted as a high-complexity match.

3. Adjusted the base score thresholds: the base score is 0.10 when no keyword hits (previously 0.15) and 0.35 on a medium-complexity hit (previously 0.45).

4. Added short-question detection, `_SHORT_QUESTION_RE`: an additional -0.10 when the input ends with "?" or "?" and is shorter than 30 characters.

**Remaining work**: Write unit tests that verify the classifier's behavior.
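
The four rules above could be combined roughly as in the sketch below. It is an illustration under assumptions, not the code in `skill_routing.py`: the keyword tuples are tiny hypothetical samples rather than the real hint lists, and the scores for one high-complexity hit (0.65) and for two or more (0.80) are guesses taken from the initial values listed under Open Questions.

```python
import re

# Hypothetical sample lists; the real hint lists are longer.
LOW_HINTS = ("你好", "你是谁", "早上好", "hello", "thanks")
MEDIUM_HINTS = ("如何", "怎么", "how to", "explain")
HIGH_HINTS = ("分析", "搜索", "生成报告", "部署", "analyze", "search", "deploy")

# A negation word plus the run of text up to the next space or punctuation mark.
NEGATION_RE = re.compile(
    r"(?:不要|无需|不用|don't|no need(?: to)?|without)\s*[^\s,,。.!!??;;]+",
    re.IGNORECASE,
)


def classify(text: str) -> float:
    """Return a heuristic complexity score in [0, 1]."""
    cleaned = text.strip().lower()
    if not cleaned:
        return 0.0
    # Rule 1: a low-complexity signal wins and nothing else is accumulated.
    if any(hint in cleaned for hint in LOW_HINTS):
        return 0.05
    # Rule 2: drop negated spans so their keywords are not counted.
    visible = NEGATION_RE.sub(" ", cleaned)
    high_hits = sum(1 for hint in HIGH_HINTS if hint in visible)
    if high_hits >= 2:
        score = 0.80  # assumed value
    elif high_hits == 1:
        score = 0.65  # assumed value
    elif any(hint in visible for hint in MEDIUM_HINTS):
        score = 0.35  # Rule 3: medium base score
    else:
        score = 0.10  # Rule 3: base score with no keyword hit
    # Rule 4: short questions get an extra deduction.
    if cleaned.endswith(("?", "?")) and len(cleaned) < 30:
        score -= 0.10
    return max(0.0, score)
```

Stopping the negated span at punctuation keeps a negation in one clause from hiding a keyword in the next one, so an input that asks for analysis but forbids searching still counts the analysis keyword.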
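
**Patterns to follow:** the test class structure in the existing `test_skill_routing.py` (`TestExpertTeamRouterCanHandle` and others)

**Test scenarios:**

- **Low-complexity signals take priority**
  - "你好" (hello) → complexity < 0.3 (hits `_LOW_COMPLEXITY_HINTS_CN`)
  - "Hello" → complexity < 0.3 (hits `_LOW_COMPLEXITY_HINTS_EN`)
  - "嗨,早上好" (hi, good morning) → complexity < 0.3 (several low-complexity words hit)
  - "你好,请帮我分析一下这个数据" (hello, please help me analyze this data) → complexity < 0.15 (the low-complexity signal wins; high-complexity words are not accumulated)

- **Identity questions**
  - "你是谁" (who are you) → complexity < 0.3
  - "你叫什么" (what is your name) → complexity < 0.3

- **Negated-context exclusion**
  - "不要搜索" (do not search) → "搜索" is not counted as a high-complexity match, complexity < 0.3
  - "无需分析,直接告诉我答案" (no need to analyze, just tell me the answer) → "分析" is negated, complexity < 0.3
  - "分析市场趋势,但不要搜索" (analyze market trends, but do not search) → "搜索" is negated but "分析" is not, complexity > 0.5

- **Threshold adjustment checks**
  - Short message without keywords ("好的" (OK)) → complexity ≤ 0.10
  - Medium-complexity word present ("如何使用Python?" (how do I use Python?)) → base score 0.35 rather than 0.45

- **Short-question deduction**
  - "怎么用?" (how do I use it?) → complexity < 0.3 (short question, -0.10)
  - "如何设计一个高可用的微服务架构?" (how do I design a highly available microservice architecture?) → complexity > 0.5 (long question, no deduction)

- **Complex tasks score high**
  - "分析市场趋势并生成报告" (analyze market trends and generate a report) → complexity > 0.7 (2 high-complexity words hit)
  - "执行部署脚本并重启服务" (run the deploy script and restart the service) → complexity > 0.7

- **Edge cases**
  - Empty string → complexity 0.0
  - Whitespace only → complexity 0.0
  - Very long low-complexity message (a greeting of more than 200 characters) → complexity ≤ 0.10

**Verification:** `pytest tests/unit/chat/test_skill_routing.py -v`; all HeuristicClassifier tests pass

---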

### U2. IntentRouter multi-candidate scoring

**Goal**: Resolve the skill confusion caused by first match so that the more precisely matching skill wins

**Requirements**: R2, R4

**Dependencies**: None

**Files:**
- `src/agentkit/router/intent.py`: IntentRouter._match_keywords()
- `tests/unit/router/test_intent.py`: new multi-candidate ranking tests

**Approach:**

1. Rewrite the `_match_keywords()` method (currently `intent.py` L75-99):

Current logic (first match):
```
for skill in skills:
    for keyword in keywords:
        if keyword in combined_text:
            return RoutingResult(matched_skill=skill.name, ...)
return None
```

Change to multi-candidate scoring:
```
candidates = []
for skill in skills:
    matched_kws = [kw for kw in skill.config.intent.keywords if kw.lower() in combined_text]
    if matched_kws:
        score = sum(len(kw) for kw in matched_kws)  # longer keywords weigh more
        candidates.append((skill, matched_kws, score))
if not candidates:
    return None
candidates.sort(key=lambda c: (-c[2], c[0].name))  # score descending, ties broken by skill name
best_skill, best_kws, best_score = candidates[0]
confidence = min(1.0, 0.5 + 0.1 * len(best_kws))
return RoutingResult(matched_skill=best_skill.name, method="keyword", confidence=confidence)
```

2. Keep the `RoutingResult` dataclass interface unchanged; `method` is still `"keyword"`.

3. Backward compatibility: with a single candidate the behavior is the same as before (when only one skill matches, sorting has no effect).

4. The `tests/unit/router/` directory and its `__init__.py` need to be created.
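
For reference, here is a self-contained, runnable version of the scoring idea. It is a sketch under assumptions: `SkillSpec` is a hypothetical stand-in for a registered skill (the real code reads `skill.config.intent.keywords`), `RoutingResult` is reduced to three fields, and the input text is lowercased inside the function.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class SkillSpec:
    """Hypothetical stand-in for a registered skill."""
    name: str
    keywords: tuple


@dataclass(frozen=True)
class RoutingResult:
    matched_skill: str
    method: str
    confidence: float


def match_keywords(text: str, skills: Sequence[SkillSpec]) -> Optional[RoutingResult]:
    """Score every matching skill and return the best one, or None."""
    combined = text.lower()
    candidates = []
    for skill in skills:
        # Substring match, case-insensitive; empty keywords are ignored.
        matched = [kw for kw in skill.keywords if kw and kw.lower() in combined]
        if matched:
            score = sum(len(kw) for kw in matched)  # longer keywords weigh more
            candidates.append((score, skill.name, matched))
    if not candidates:
        return None
    # Highest score first; equal scores fall back to skill name for stability.
    candidates.sort(key=lambda c: (-c[0], c[1]))
    _, name, matched = candidates[0]
    confidence = min(1.0, 0.5 + 0.1 * len(matched))
    return RoutingResult(matched_skill=name, method="keyword", confidence=confidence)
```

Because the sort key includes the skill name, the result no longer depends on the iteration order of the skills list.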
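
**Patterns to follow:** the existing `RoutingResult` dataclass structure; the input handling of `_extract_string_values()`

**Test scenarios:**

- **Single candidate**: the input matches the keywords of only one skill; behavior is the same as before, confidence=1.0
- **Multiple candidates, different scores**: the input matches both skill_a (keyword "规划", 2 characters) and skill_b (keyword "调研报告", 4 characters); skill_b scores higher and should win
- **Multiple candidates, equal scores**: when two skills score the same, they are ordered stably by name in alphabetical order
- **No match**: no keyword hits; returns None
- **Empty keyword list**: a skill whose intent.keywords is an empty list does not take part in matching
- **Case-insensitive**: the English keyword "Search" should match "search"
- **Substring matching**: the Chinese keyword "报告" should match any input containing "报告" (keeps the existing substring semantics)
- **Confidence calculation**: 1 matched keyword gives confidence=0.6, 3 give confidence=0.8, capped at 1.0

**Verification:** `pytest tests/unit/router/test_intent.py -v`; the multi-candidate ranking tests pass

---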

### U3. QualityGate skill-match validation

**Goal**: Add a routing consistency check that intercepts low-quality output produced by a wrong skill match

**Requirements**: R3, R4

**Dependencies:** None

**Files:**
- `src/agentkit/quality/gate.py`: QualityGate.validate()
- `tests/unit/quality/test_gate.py`: new skill-match validation tests

**Approach:**

1. Add the optional parameter `skill_context: dict | None = None` to the `QualityGate.validate()` signature:
```python
async def validate(
    self,
    output: dict[str, Any],
    skill: Skill,
    skill_context: dict | None = None,  # new
) -> QualityResult:
    ...
```

2. Structure of `skill_context`: `{"skill_name": str, "intent_keywords": list[str]}`

3. When `skill_context` is provided and `intent_keywords` is non-empty, add a fifth check dimension, "skill-match validation":
   - Concatenate all string values in output
   - Check whether the concatenated text contains at least one keyword from `intent_keywords` (substring match)
   - If 0 keywords match → `QualityCheck(name="skill_match", passed=True, message="Warning: output may not match routed skill")`: warn but do not block
   - If ≥ 1 keyword matches → `QualityCheck(name="skill_match", passed=True)`: pass silently

4. Combined logic for escalating the warning to a failure: when the `skill_match` warning is present and any other dimension fails, `passed` of `skill_match` also becomes `False`, so the overall result is `passed=False`.

5. Stay backward compatible: when `skill_context` is None or lacks `intent_keywords`, behavior is exactly the same as before (four-dimension check).
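
A minimal sketch of the fifth dimension and its escalation rule, written as standalone functions so it can run on its own. The names `check_skill_match`, `apply_escalation`, and `_string_values` are hypothetical, and `QualityCheck` is reduced to the three fields used in the steps above; the real `validate()` is an async method on `QualityGate`.

```python
from dataclasses import dataclass
from typing import Any, Iterator, List, Optional

WARNING = "Warning: output may not match routed skill"


@dataclass
class QualityCheck:
    name: str
    passed: bool
    message: str = ""


def _string_values(value: Any) -> Iterator[str]:
    """Yield every string found in a nested dict / list structure."""
    if isinstance(value, str):
        yield value
    elif isinstance(value, dict):
        for item in value.values():
            yield from _string_values(item)
    elif isinstance(value, (list, tuple)):
        for item in value:
            yield from _string_values(item)


def check_skill_match(output: dict, skill_context: Optional[dict]) -> Optional[QualityCheck]:
    """Return the skill_match check, or None when the check is skipped."""
    keywords = (skill_context or {}).get("intent_keywords") or []
    if not keywords:
        return None  # backward compatible: four-dimension behavior
    text = " ".join(_string_values(output)).lower()
    if any(kw.lower() in text for kw in keywords):
        return QualityCheck("skill_match", True)
    return QualityCheck("skill_match", True, WARNING)  # warn, do not block


def apply_escalation(checks: List[QualityCheck]) -> bool:
    """Turn a skill_match warning into a failure if any other check failed."""
    other_failed = any(not c.passed for c in checks if c.name != "skill_match")
    for check in checks:
        if check.name == "skill_match" and check.message and other_failed:
            check.passed = False
    return all(c.passed for c in checks)
```

Keeping the warning non-blocking on its own limits false positives from the plain keyword check, while the escalation still lets it add weight when another dimension has already failed.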

**Patterns to follow:** the existing four-dimension check pattern (`gate.py` L50-114); the `QualityCheck` dataclass

**Test scenarios:**

- **No skill_context**: behavior is the same as before, four-dimension check only
- **skill_context=None**: equivalent to no skill_context
- **skill_context lacks intent_keywords**: equivalent to no skill_context
- **skill_context present and output contains a keyword**: passes, no warning message
- **skill_context present and output contains no keyword**: passes, with a warning message
- **Unrelated output + another dimension fails**: skill_match passed=False, overall passed=False
- **Unrelated output + all other dimensions pass**: skill_match passed=True (warning only), overall passed=True
- **Empty intent_keywords list**: the skill-match check is skipped

**Verification:** `pytest tests/unit/quality/test_gate.py -v`; the skill-match validation tests pass

---

### U4. Backtest verification and benchmark update

**Goal**: Verify the effect of the improvements and update the benchmark dataset

**Requirements**: R4

**Dependencies:** U1, U2, U3

**Files:**
- `tests/e2e/test_capability_router_direct.py`: backtest with a real LLM
- `tests/e2e/benchmark_dataset.py`: expected values may need updating
- `test-results/e2e/capability_report.txt`: compare the reports before and after the improvements

**Approach:**

1. Run the full backtest: `python3 -m pytest tests/e2e/test_capability_router_direct.py -v`

2. Compare the metrics before and after:
   - Execution mode accuracy: 9.09% → target >30%
   - Skill routing F1: 66.67% → target >80%
   - Task success rate: 100% → keep

3. If expected values in the benchmark dataset need adjusting because of the scoring-logic change, update `benchmark_dataset.py`

4. Save the improved report as the baseline: `cp test-results/e2e/capability_report.json test-results/e2e/baseline_capability_report.json`

**Test scenarios:**
- The backtest passes in full
- Execution mode accuracy >30%
- Skill routing F1 >80%
- No regression (task success rate does not drop)

**Verification:** Run the backtest and check the report metrics

---
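
## Risks & Dependencies

| Risk | Impact | Mitigation |
|------|------|------|
| The complexity score adjustment may overcorrect and classify complex tasks as simple | High-complexity tasks are routed to DIRECT_CHAT and cannot use tools | Keep merged_llm_classify as a fallback; the 0.3-0.7 range is still confirmed by the LLM |
| Multi-candidate ranking may change the compatibility of existing routing behavior | Routing results that users already depend on may change | The ranking logic only takes effect with multiple candidates; single-candidate behavior is unchanged |
| The "related word" judgment in QualityGate skill-match validation may produce false positives | Normal output is flagged with a warning | Use warning level rather than error; never block on it alone |
| The root cause of the 0% keyword_match recall may not be IntentRouter alone | Even with multi-candidate ranking fixed, recall may stay low because skill config keywords do not match | If recall is still low after the U4 backtest, analyze the alignment between skill configs and benchmark cases further |

---

## Open Questions

- The concrete complexity-score thresholds have initial values set in code (0.05/0.10/0.35/0.65/0.80) and need calibration through the U4 backtest
- The coverage of the negated-context regex patterns needs to be verified in the backtest and may need iterative additions
- Whether the 0% keyword_match recall is caused entirely by IntentRouter's first match, or whether the skill config keywords themselves are misaligned with the benchmark cases, needs to be verified through U4 once U2 is implemented
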
@@ -0,0 +1,328 @@
#!/usr/bin/env bash
# =============================================================================
# Fischer AgentKit - E2E Backtest Runner
# =============================================================================
#
# Usage:
#   ./scripts/run_e2e.sh                  # Run all E2E tests
#   ./scripts/run_e2e.sh --basic          # Run basic function tests only
#   ./scripts/run_e2e.sh --capability     # Run agent capability tests only
#   ./scripts/run_e2e.sh --cli            # Run CLI tests only
#   ./scripts/run_e2e.sh --api            # Run API tests only
#   ./scripts/run_e2e.sh --ws             # Run WebSocket tests only
#   ./scripts/run_e2e.sh --routing        # Run routing intelligence tests
#   ./scripts/run_e2e.sh --react          # Run ReAct intelligence tests
#   ./scripts/run_e2e.sh --team           # Run team collaboration tests
#   ./scripts/run_e2e.sh --report         # Generate HTML report
#   ./scripts/run_e2e.sh --analyze        # Run capability tests + generate analysis report
#   ./scripts/run_e2e.sh --direct         # Run router direct backtest only (no HTTP)
#   ./scripts/run_e2e.sh --alignment      # Run alignment guard tests only
#   ./scripts/run_e2e.sh --full           # Run all: API + direct + alignment
#   ./scripts/run_e2e.sh --baseline       # Compare with last baseline report
#   ./scripts/run_e2e.sh --fast           # Stop at first failure, 30 s per-test timeout
#
# Environment:
#   E2E_PORT      - Server port (default: 18765)
#   E2E_API_KEY   - API key for auth (default: ak_live_e2e_test_key_...)
#   SKIP_SERVER   - Set to "1" to skip server startup (use existing)
# =============================================================================

set -euo pipefail

# ── Configuration ────────────────────────────────────────────────────────────

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
E2E_PORT="${E2E_PORT:-18765}"
E2E_API_KEY="${E2E_API_KEY:-ak_live_e2e_test_key_000000000000000000000000000000000000000000000000}"
REPORT_DIR="${PROJECT_ROOT}/test-results/e2e"
SKIP_SERVER="${SKIP_SERVER:-0}"

cd "$PROJECT_ROOT"

# ── Colors ───────────────────────────────────────────────────────────────────

RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

# ── Helper Functions ─────────────────────────────────────────────────────────

info() { echo -e "${BLUE}[INFO]${NC} $*"; }
ok() { echo -e "${GREEN}[OK]${NC} $*"; }
warn() { echo -e "${YELLOW}[WARN]${NC} $*"; }
fail() { echo -e "${RED}[FAIL]${NC} $*"; }

check_deps() {
    local missing=0
    for cmd in python3; do
        if ! command -v "$cmd" &>/dev/null; then
            fail "Missing dependency: $cmd"
            missing=1
        fi
    done
    if [ "$missing" -eq 1 ]; then
        exit 1
    fi
}

wait_for_server() {
    local max_attempts=60
    local attempt=0
    info "Waiting for server on port $E2E_PORT..."
    while [ $attempt -lt $max_attempts ]; do
        if curl -s "http://127.0.0.1:$E2E_PORT/api/v1/health" &>/dev/null; then
            ok "Server is ready on port $E2E_PORT"
            return 0
        fi
        attempt=$((attempt + 1))
        sleep 0.5
    done
    fail "Server failed to start within 30 seconds"
    return 1
}

start_server() {
    if [ "$SKIP_SERVER" = "1" ]; then
        info "SKIP_SERVER=1, using existing server on port $E2E_PORT"
        if curl -s "http://127.0.0.1:$E2E_PORT/api/v1/health" &>/dev/null; then
            ok "Existing server is healthy"
            return 0
        else
            fail "Existing server is not responding"
            return 1
        fi
    fi

    info "Starting AgentKit E2E server on port $E2E_PORT..."
    export AGENTKIT_E2E_MODE=1
    export AGENTKIT_WS_TIMEOUT=0
    export AGENTKIT_API_KEY="$E2E_API_KEY"

    # Start server in background
    python3 -m agentkit.cli.main serve --host 127.0.0.1 --port "$E2E_PORT" &
    SERVER_PID=$!

    if wait_for_server; then
        return 0
    else
        kill "$SERVER_PID" 2>/dev/null || true
        return 1
    fi
}

stop_server() {
    if [ "$SKIP_SERVER" = "1" ]; then
        info "SKIP_SERVER=1, not stopping server"
        return 0
    fi
    if [ -n "${SERVER_PID:-}" ]; then
        info "Stopping E2E server (PID: $SERVER_PID)..."
        kill "$SERVER_PID" 2>/dev/null || true
        wait "$SERVER_PID" 2>/dev/null || true
        ok "Server stopped"
    fi
}

# ── Test Selection ───────────────────────────────────────────────────────────

PYTEST_ARGS=("--timeout=120" "-v" "--tb=short" "-s")
TEST_TARGET="tests/e2e/"
GENERATE_REPORT=0
ANALYZE=0
SKIP_SERVER_FLAG=0
BASELINE_COMPARE=0

while [[ $# -gt 0 ]]; do
    case $1 in
        --basic)
            PYTEST_ARGS+=("-m" "e2e_basic")
            shift
            ;;
        --capability)
            PYTEST_ARGS+=("-m" "e2e_capability")
            shift
            ;;
        --cli)
            TEST_TARGET="tests/e2e/test_basic_cli.py"
            shift
            ;;
        --api)
            TEST_TARGET="tests/e2e/test_basic_api.py"
            shift
            ;;
        --ws)
            TEST_TARGET="tests/e2e/test_basic_websocket.py"
            shift
            ;;
        --routing)
            TEST_TARGET="tests/e2e/test_capability_routing.py"
            shift
            ;;
        --react)
            TEST_TARGET="tests/e2e/test_capability_react.py"
            shift
            ;;
        --team)
            TEST_TARGET="tests/e2e/test_capability_team.py"
            shift
            ;;
        --direct)
            # Router direct backtest: no HTTP server needed
            TEST_TARGET="tests/e2e/test_capability_router_direct.py"
            SKIP_SERVER_FLAG=1
            shift
            ;;
        --alignment)
            # Alignment guard tests: no HTTP server needed
            TEST_TARGET="tests/e2e/test_capability_alignment.py"
            SKIP_SERVER_FLAG=1
            shift
            ;;
        --full)
            # Run all capability tests: API + direct + alignment
            PYTEST_ARGS+=("-m" "e2e_capability")
            shift
            ;;
        --baseline)
            BASELINE_COMPARE=1
            shift
            ;;
        --report)
            GENERATE_REPORT=1
            shift
            ;;
        --analyze)
            ANALYZE=1
            PYTEST_ARGS+=("-m" "e2e_capability")
            shift
            ;;
        --fast)
            PYTEST_ARGS+=("-x" "--timeout=30")
            shift
            ;;
        --help|-h)
            echo "Usage: $0 [--basic|--capability|--cli|--api|--ws|--routing|--react|--team|--direct|--alignment|--full|--baseline|--report|--analyze|--fast]"
            exit 0
            ;;
        *)
            PYTEST_ARGS+=("$1")
            shift
            ;;
    esac
done

if [ "$GENERATE_REPORT" -eq 1 ]; then
    mkdir -p "$REPORT_DIR"
    PYTEST_ARGS+=(
        "--html=$REPORT_DIR/e2e_report.html"
        "--self-contained-html"
        "--junitxml=$REPORT_DIR/e2e_junit.xml"
    )
fi

if [ "$ANALYZE" -eq 1 ]; then
    info "Analysis mode: will generate capability report with recall/F1/overfitting analysis"
fi

# Override SKIP_SERVER when --direct or --alignment is used (no HTTP needed)
if [ "$SKIP_SERVER_FLAG" -eq 1 ]; then
    SKIP_SERVER=1
fi

# ── Main ─────────────────────────────────────────────────────────────────────

info "Fischer AgentKit E2E Backtest Runner"
info "====================================="
info "Project: $PROJECT_ROOT"
info "Port: $E2E_PORT"
info "Target: $TEST_TARGET"
info ""

check_deps

# Trap to ensure server cleanup
trap stop_server EXIT INT TERM

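if [ "$SKIP_SERVER_FLAG" -eq 1 ]; then
    # --direct / --alignment do not use HTTP, so a running server is not required;
    # start_server would otherwise fail its health check when no server is up.
    info "Selected tests do not need the HTTP server, skipping server startup"
elif ! start_server; then
    fail "Could not start E2E server"
    exit 1
fi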

info ""
info "Running E2E tests..."
info "===================="
info ""

export AGENTKIT_SERVER_URL="http://127.0.0.1:$E2E_PORT"
export AGENTKIT_API_KEY="$E2E_API_KEY"

EXIT_CODE=0
python3 -m pytest "$TEST_TARGET" "${PYTEST_ARGS[@]}" || EXIT_CODE=$?

echo ""
if [ $EXIT_CODE -eq 0 ]; then
    ok "All E2E tests passed!"
else
    fail "Some E2E tests failed (exit code: $EXIT_CODE)"
fi

if [ "$GENERATE_REPORT" -eq 1 ]; then
    info "Report generated at: $REPORT_DIR/e2e_report.html"
fi

if [ "$ANALYZE" -eq 1 ]; then
    CAPABILITY_REPORT="$PROJECT_ROOT/test-results/e2e/capability_report.txt"
    if [ -f "$CAPABILITY_REPORT" ]; then
        info "Capability analysis report:"
        echo ""
        cat "$CAPABILITY_REPORT"
    else
        warn "Capability report not found (may need capability tests to run first)"
    fi
fi
|
||||
|
||||
if [ "$BASELINE_COMPARE" -eq 1 ]; then
|
||||
CURRENT_REPORT="$PROJECT_ROOT/test-results/e2e/capability_report.json"
|
||||
BASELINE_REPORT="$PROJECT_ROOT/test-results/e2e/baseline_capability_report.json"
|
||||
if [ -f "$CURRENT_REPORT" ] && [ -f "$BASELINE_REPORT" ]; then
|
||||
info "Baseline comparison:"
|
||||
python3 -c "
|
||||
import json, sys
|
||||
|
||||
def load_metrics(path):
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
cur = load_metrics('$CURRENT_REPORT')
|
||||
base = load_metrics('$BASELINE_REPORT')
|
||||
|
||||
metrics = [
|
||||
('overall_skill_recall', '技能路由召回率'),
|
||||
('overall_skill_precision', '技能路由精确率'),
|
||||
('overall_skill_f1', '技能路由F1'),
|
||||
('overall_execution_mode_accuracy', '执行模式准确率'),
|
||||
('overall_task_success_rate', '任务成功率'),
|
||||
('overfitting_score', '过拟合分数'),
|
||||
]
|
||||
|
||||
print()
|
||||
for key, label in metrics:
|
||||
c = cur.get(key, 0)
|
||||
b = base.get(key, 0)
|
||||
delta = c - b
|
||||
arrow = '↑' if delta > 0 else ('↓' if delta < 0 else '→')
|
||||
print(f' {label}: {b:.2%} → {c:.2%} {arrow} {delta:+.2%}')
|
||||
print()
|
||||
"
|
||||
elif [ -f "$CURRENT_REPORT" ]; then
|
||||
info "No baseline report found. Saving current report as baseline."
|
||||
cp "$CURRENT_REPORT" "$BASELINE_REPORT"
|
||||
info "Baseline saved to: $BASELINE_REPORT"
|
||||
else
|
||||
warn "No current report found. Run with --analyze first."
|
||||
fi
|
||||
fi
|
||||
|
||||
exit $EXIT_CODE
|
||||
|
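The baseline comparison at the end of the runner script computes a per-metric delta between the current and the baseline capability report. A minimal Python sketch of the same logic as importable functions is given here. The names `compare_reports` and `format_delta` are hypothetical and not part of the commit. Missing keys are treated as 0, mirroring the inline script's `.get(key, 0)`.

```python
def compare_reports(current: dict, baseline: dict, keys: list[str]) -> dict:
    """Return {key: (baseline, current, delta)}; a missing key counts as 0."""
    result = {}
    for key in keys:
        cur = float(current.get(key, 0))
        base = float(baseline.get(key, 0))
        result[key] = (base, cur, cur - base)
    return result


def format_delta(label: str, base: float, cur: float, delta: float) -> str:
    """Format one metric line the way the inline script prints it."""
    arrow = "↑" if delta > 0 else ("↓" if delta < 0 else "→")
    return f"{label}: {base:.2%} → {cur:.2%} {arrow} {delta:+.2%}"


if __name__ == "__main__":
    # Hypothetical report contents, for illustration only.
    report = compare_reports({"overall_skill_f1": 0.5}, {"overall_skill_f1": 0.25},
                             ["overall_skill_f1"])
    for key, (base, cur, delta) in report.items():
        print(format_delta(key, base, cur, delta))
```

Keeping the comparison in a function rather than an inline `python3 -c` string would also avoid interpolating shell paths into Python source, which can break on paths containing quotes.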
|
@@ -33,9 +33,31 @@ class ExecutionMode(enum.Enum):
|
|||
DIRECT_CHAT = "direct_chat" # Zero-cost: direct LLM call, no ReAct loop
|
||||
REACT = "react" # Default agent ReAct loop with default tools
|
||||
SKILL_REACT = "skill_react" # Skill-matched ReAct with skill tools + prompt
|
||||
REWOO = "rewoo" # Plan-without-observation mode
|
||||
REFLEXION = "reflexion" # Reflection-driven mode
|
||||
PLAN_EXEC = "plan_exec" # Plan-then-execute mode
|
||||
TEAM_COLLAB = "team_collab" # Expert Team collaborative mode
|
||||
|
||||
|
||||
# Mapping from skill config execution_mode string to ExecutionMode enum
|
||||
_SKILL_EXECUTION_MODE_MAP: dict[str, ExecutionMode] = {
|
||||
"direct": ExecutionMode.DIRECT_CHAT,
|
||||
"react": ExecutionMode.SKILL_REACT,
|
||||
"rewoo": ExecutionMode.REWOO,
|
||||
"reflexion": ExecutionMode.REFLEXION,
|
||||
"plan_exec": ExecutionMode.PLAN_EXEC,
|
||||
"custom": ExecutionMode.SKILL_REACT,
|
||||
"llm_generate": ExecutionMode.SKILL_REACT,
|
||||
"tool_call": ExecutionMode.SKILL_REACT,
|
||||
}
|
||||
|
||||
|
||||
def _resolve_execution_mode(skill_config: Any) -> ExecutionMode:
|
||||
"""Resolve ExecutionMode from skill config's execution_mode field."""
|
||||
mode_str = getattr(skill_config, "execution_mode", "react") or "react"
|
||||
return _SKILL_EXECUTION_MODE_MAP.get(mode_str, ExecutionMode.SKILL_REACT)
|
||||
|
||||
|
||||
def validate_skill_name(name: str) -> str:
|
||||
"""Validate and normalize a skill name. Raises ValueError on invalid input."""
|
||||
normalized = name.strip().lower()
|
||||
|
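The fallback behaviour of `_resolve_execution_mode()` can be illustrated with a reduced, self-contained sketch. The enum and map below are trimmed copies of the ones in the hunk above, and `SimpleNamespace` stands in for the real skill config object, which is an assumption made only for illustration.

```python
import enum
from types import SimpleNamespace


class ExecutionMode(enum.Enum):
    DIRECT_CHAT = "direct_chat"
    SKILL_REACT = "skill_react"
    REWOO = "rewoo"


# Trimmed copy of the skill-config string -> ExecutionMode mapping.
_MODE_MAP = {
    "direct": ExecutionMode.DIRECT_CHAT,
    "react": ExecutionMode.SKILL_REACT,
    "rewoo": ExecutionMode.REWOO,
}


def resolve_execution_mode(skill_config) -> ExecutionMode:
    """Missing, empty, or unknown mode strings all fall back to SKILL_REACT."""
    mode_str = getattr(skill_config, "execution_mode", "react") or "react"
    return _MODE_MAP.get(mode_str, ExecutionMode.SKILL_REACT)


if __name__ == "__main__":
    for cfg in (
        SimpleNamespace(execution_mode="rewoo"),  # mapped directly
        SimpleNamespace(execution_mode=None),     # empty -> default
        SimpleNamespace(),                        # attribute missing -> default
        SimpleNamespace(execution_mode="ReWOO"),  # lookup is case-sensitive
    ):
        print(cfg, resolve_execution_mode(cfg))
```

Because the lookup is a plain dictionary `get`, a mode string with different casing silently resolves to `SKILL_REACT`. Whether configs should be normalized before lookup is a design question the diff does not settle.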
|
@@ -265,7 +287,8 @@ async def resolve_skill_routing(
|
|||
else default_model
|
||||
)
|
||||
result.agent_name = result.skill_name
|
||||
result.execution_mode = ExecutionMode.SKILL_REACT
|
||||
# Map skill.config.execution_mode to ExecutionMode enum
|
||||
result.execution_mode = _resolve_execution_mode(result.skill_config)
|
||||
else:
|
||||
result.system_prompt = default_system_prompt
|
||||
result.tools = default_tools
|
||||
|
|
@@ -596,21 +619,10 @@ class HeuristicClassifier:
|
|||
content_lower = content.lower()
|
||||
score = 0.0
|
||||
|
||||
# 0. Low-complexity signal detection (highest priority)
|
||||
# 0. Low-complexity signal detection (takes effect only when no high-complexity signal is present)
|
||||
low_hits_cn = sum(1 for h in self._LOW_COMPLEXITY_HINTS_CN if h in content_lower)
|
||||
low_hits_en = sum(
|
||||
1 for h in self._LOW_COMPLEXITY_HINTS_EN if h in content_lower
|
||||
)
|
||||
if low_hits_cn + low_hits_en > 0:
|
||||
score = 0.05  # Greetings and small talk get a very low score outright
|
||||
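# Under a low-complexity signal, scores for high-complexity words are no longer accumulated,
|
||||
# but the minor adjustments for length and multiple sentences are kept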
|
||||
length = len(content)
|
||||
if length > 200:
|
||||
score += 0.05
|
||||
elif length > 100:
|
||||
score += 0.03
|
||||
return max(0.0, min(1.0, score))
|
||||
low_hits_en = sum(1 for h in self._LOW_COMPLEXITY_HINTS_EN if h in content_lower)
|
||||
has_low_signal = low_hits_cn + low_hits_en > 0
|
||||
|
||||
# 1. Negation context detection: extract the negated words
|
||||
negated_words: set[str] = set()
|
||||
|
|
@@ -624,21 +636,27 @@ class HeuristicClassifier:
|
|||
for h in self._HIGH_COMPLEXITY_HINTS_CN
|
||||
if h in content_lower and h not in negated_words
|
||||
)
|
||||
medium_hits = sum(
|
||||
1 for m in self._MEDIUM_COMPLEXITY_HINTS_CN if m in content_lower
|
||||
)
|
||||
medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS_CN if m in content_lower)
|
||||
|
||||
# English: word-boundary matching
|
||||
high_en_matches = self._HIGH_EN_RE.findall(content) + self._HIGH_EXACT_RE.findall(
|
||||
content
|
||||
)
|
||||
high_hits += sum(
|
||||
1 for w in high_en_matches if w.lower() not in negated_words
|
||||
)
|
||||
high_en_matches = self._HIGH_EN_RE.findall(content) + self._HIGH_EXACT_RE.findall(content)
|
||||
high_hits += sum(1 for w in high_en_matches if w.lower() not in negated_words)
|
||||
medium_hits += len(self._MEDIUM_EN_RE.findall(content)) + len(
|
||||
self._MEDIUM_EXACT_RE.findall(content)
|
||||
)
|
||||
|
||||
has_high_signal = high_hits > 0 or medium_hits > 0
|
||||
|
||||
# Low-complexity signals take effect only when there are no high- or medium-complexity signals
|
||||
if has_low_signal and not has_high_signal:
|
||||
score = 0.05  # Greetings and small talk get a very low score outright
|
||||
length = len(content)
|
||||
if length > 200:
|
||||
score += 0.05
|
||||
elif length > 100:
|
||||
score += 0.03
|
||||
return max(0.0, min(1.0, score))
|
||||
|
||||
if high_hits >= 2:
|
||||
score = 0.80
|
||||
elif high_hits == 1:
|
||||
|
|
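The calibration change in `HeuristicClassifier` makes the low-complexity shortcut conditional: it applies only when no high- or medium-complexity hint is found. A simplified sketch of that gating follows. The hint tuples are hypothetical, negation handling and the medium tier are omitted, and the scores returned for one or zero high hits are placeholder assumptions, because the hunk above is cut off before those branches.

```python
LOW_HINTS = ("hello", "thanks")      # hypothetical hint list
HIGH_HINTS = ("analyze", "compare")  # hypothetical hint list


def score_complexity(content: str) -> float:
    """Low-complexity hints only win when no high-complexity hint is present."""
    text = content.lower()
    has_low_signal = any(h in text for h in LOW_HINTS)
    high_hits = sum(1 for h in HIGH_HINTS if h in text)

    if has_low_signal and high_hits == 0:
        score = 0.05  # greeting / small talk
        if len(content) > 200:
            score += 0.05
        elif len(content) > 100:
            score += 0.03
        return max(0.0, min(1.0, score))

    if high_hits >= 2:
        return 0.80   # value taken from the diff
    if high_hits == 1:
        return 0.50   # ASSUMED placeholder, not shown in the diff
    return 0.20       # ASSUMED placeholder, not shown in the diff
```

With the previous ordering, an input such as "hello, please analyze and compare these reports" would have been scored as a greeting. With the gate, the greeting word no longer masks the two high-complexity hints.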
|
|||
|
|
@@ -38,6 +38,7 @@ class QualityGate:
|
|||
self,
|
||||
output: dict[str, Any],
|
||||
skill: Skill,
|
||||
skill_context: dict[str, Any] | None = None,
|
||||
) -> QualityResult:
|
||||
"""Run multi-dimensional quality checks on the output
|
||||
|
||||
|
|
@@ -46,6 +47,7 @@ class QualityGate:
|
|||
2. Minimum word count check
|
||||
3. JSON Schema validation (if skill.config.output_schema exists)
|
||||
4. Custom validator (if skill.config.quality_gate.custom_validator exists)
|
||||
5. Skill match validation (if skill_context contains intent_keywords)
|
||||
"""
|
||||
checks: list[QualityCheck] = []
|
||||
qg = skill.config.quality_gate
|
||||
|
|
@@ -53,11 +55,13 @@ class QualityGate:
|
|||
# 1. Required field check
|
||||
for field in qg.required_fields:
|
||||
present = field in output and output[field] is not None
|
||||
checks.append(QualityCheck(
|
||||
checks.append(
|
||||
QualityCheck(
|
||||
name=f"required_field:{field}",
|
||||
passed=present,
|
||||
message=f"Field '{field}' is missing" if not present else None,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Minimum word count check
|
||||
if qg.min_word_count > 0:
|
||||
|
|
@@ -67,7 +71,8 @@ class QualityGate:
|
|||
else:
|
||||
word_count = len(str(content).split())
|
||||
passed = word_count >= qg.min_word_count
|
||||
checks.append(QualityCheck(
|
||||
checks.append(
|
||||
QualityCheck(
|
||||
name="min_word_count",
|
||||
passed=passed,
|
||||
message=(
|
||||
|
|
@@ -75,7 +80,8 @@ class QualityGate:
|
|||
if not passed
|
||||
else None
|
||||
),
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# 3. JSON Schema validation
|
||||
if skill.config.output_schema:
|
||||
|
|
@@ -101,11 +107,34 @@ class QualityGate:
|
|||
checks.append(QualityCheck(name="custom", passed=bool(result)))
|
||||
except Exception as e:
|
||||
# Validator import or execution failed; skip it and record a warning
|
||||
checks.append(QualityCheck(
|
||||
checks.append(
|
||||
QualityCheck(
|
||||
name="custom",
|
||||
passed=True,
|
||||
message=f"Validator skipped: {e}",
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# 5. Skill match validation (lightweight routing consistency check)
|
||||
skill_match_check = self._check_skill_match(output, skill_context)
|
||||
if skill_match_check is not None:
|
||||
checks.append(skill_match_check)
|
||||
|
||||
# Warning escalation: when a skill_match warning exists and another dimension has failed, escalate it to a failure
|
||||
if (
|
||||
skill_match_check is not None
|
||||
and skill_match_check.message
|
||||
and "Warning" in skill_match_check.message
|
||||
):
|
||||
other_failed = any(not c.passed for c in checks if c is not skill_match_check)
|
||||
if other_failed:
|
||||
# Escalate: set passed to False for skill_match as well
|
||||
checks = [
|
||||
QualityCheck(name=c.name, passed=False, message=c.message)
|
||||
if c is skill_match_check
|
||||
else c
|
||||
for c in checks
|
||||
]
|
||||
|
||||
return QualityResult(
|
||||
passed=all(c.passed for c in checks),
|
||||
|
|
@@ -119,6 +148,42 @@ class QualityGate:
|
|||
"app.agent_framework.",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _check_skill_match(
|
||||
output: dict[str, Any],
|
||||
skill_context: dict[str, Any] | None,
|
||||
) -> QualityCheck | None:
|
||||
"""Fifth dimension: skill match validation
|
||||
|
||||
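When skill_context contains intent_keywords, check whether the output contains
|
||||
at least one of the keywords. On a mismatch, the check is marked as a warning (passed=True + message)
|
||||
and is escalated to passed=False when other dimensions have also failed.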
|
||||
|
||||
Returns:
|
||||
QualityCheck or None (skipped when skill_context is invalid)
|
||||
"""
|
||||
if not skill_context:
|
||||
return None
|
||||
|
||||
intent_keywords: list[str] | None = skill_context.get("intent_keywords")
|
||||
if not intent_keywords:
|
||||
return None
|
||||
|
||||
# Concatenate all scalar values (str, int, float, bool) in the output
|
||||
all_text = " ".join(
|
||||
str(v) for v in output.values() if isinstance(v, (str, int, float, bool))
|
||||
).lower()
|
||||
|
||||
matched = any(kw.lower() in all_text for kw in intent_keywords)
|
||||
if matched:
|
||||
return QualityCheck(name="skill_match", passed=True)
|
||||
|
||||
return QualityCheck(
|
||||
name="skill_match",
|
||||
passed=True,  # Warning level; does not block on its own
|
||||
message="Warning: output may not match routed skill",
|
||||
)
|
||||
|
||||
def _import_validator(self, dotted_path: str) -> Callable:
|
||||
"""Import a custom validator function from a dotted path
|
||||
|
||||
|
|
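The warning-escalation rule added to `QualityGate` (a `skill_match` warning passes on its own but turns into a failure once any other check has failed) can be sketched in isolation. `Check` is a hypothetical stand-in for the project's `QualityCheck`, whose full definition is not shown in this diff.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class Check:
    """Hypothetical stand-in for QualityCheck (name, passed, message)."""
    name: str
    passed: bool
    message: Optional[str] = None


def escalate_skill_match(checks: list) -> list:
    """Flip a skill_match warning to failed when any other check failed."""
    other_failed = any(not c.passed for c in checks if c.name != "skill_match")
    if not other_failed:
        return checks
    return [
        replace(c, passed=False)
        if c.name == "skill_match" and c.message and "Warning" in c.message
        else c
        for c in checks
    ]


if __name__ == "__main__":
    warn = Check("skill_match", True, "Warning: output may not match routed skill")
    bad = Check("required_field:title", False, "Field 'title' is missing")
    print(escalate_skill_match([bad, warn]))
```

As in the commit, the warning is recognized by the substring "Warning" in the message. An explicit severity field on the check would be a less brittle alternative, though that would be a change beyond what this diff introduces.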
|
|||
|
|
@@ -75,10 +75,11 @@ class IntentRouter:
|
|||
def _match_keywords(
|
||||
self, input_data: dict[str, Any], skills: list[Skill]
|
||||
) -> RoutingResult | None:
|
||||
"""Level 1: keyword matching
|
||||
"""Level 1: multi-candidate keyword scoring match
|
||||
|
||||
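Extract all string values (including nested ones) from input_data and match them
|
||||
case-insensitively against each Skill's intent.keywords.
|
||||
case-insensitively against each Skill's intent.keywords. Collect every matching candidate,
|
||||
rank them by the total length of their matched keywords (longer keywords weigh more), and return the best match.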
|
||||
"""
|
||||
text_values = self._extract_string_values(input_data)
|
||||
combined_text = " ".join(text_values).lower()
|
||||
|
|
@@ -86,18 +87,31 @@ class IntentRouter:
|
|||
if not combined_text:
|
||||
return None
|
||||
|
||||
# Collect all matching candidates
|
||||
candidates: list[tuple[Skill, list[str], int]] = []
|
||||
for skill in skills:
|
||||
keywords = skill.config.intent.keywords
|
||||
for keyword in keywords:
|
||||
if keyword.lower() in combined_text:
|
||||
return RoutingResult(
|
||||
matched_skill=skill.name,
|
||||
method="keyword",
|
||||
confidence=1.0,
|
||||
)
|
||||
if not keywords:
|
||||
continue
|
||||
matched_kws = [kw for kw in keywords if kw.lower() in combined_text]
|
||||
if matched_kws:
|
||||
score = sum(len(kw) for kw in matched_kws)
|
||||
candidates.append((skill, matched_kws, score))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Sort by score in descending order; break ties alphabetically by skill name for a stable order
|
||||
candidates.sort(key=lambda c: (-c[2], c[0].name))
|
||||
best_skill, best_kws, _best_score = candidates[0]
|
||||
confidence = min(1.0, 0.5 + 0.1 * len(best_kws))
|
||||
|
||||
return RoutingResult(
|
||||
matched_skill=best_skill.name,
|
||||
method="keyword",
|
||||
confidence=confidence,
|
||||
)
|
||||
|
||||
async def _classify_with_llm(
|
||||
self, input_data: dict[str, Any], skills: list[Skill]
|
||||
) -> RoutingResult:
|
||||
|
|
@@ -107,9 +121,7 @@ class IntentRouter:
|
|||
the best-matching Skill.
|
||||
"""
|
||||
if self._llm_gateway is None:
|
||||
raise RuntimeError(
|
||||
"Keyword matching failed and no LLM Gateway configured for fallback"
|
||||
)
|
||||
raise RuntimeError("Keyword matching failed and no LLM Gateway configured for fallback")
|
||||
|
||||
prompt = self._build_classification_prompt(input_data, skills)
|
||||
|
||||
|
|
@@ -120,9 +132,7 @@ class IntentRouter:
|
|||
|
||||
return self._parse_llm_response(response.content, skills)
|
||||
|
||||
def _build_classification_prompt(
|
||||
self, input_data: dict[str, Any], skills: list[Skill]
|
||||
) -> str:
|
||||
def _build_classification_prompt(self, input_data: dict[str, Any], skills: list[Skill]) -> str:
|
||||
"""Build the LLM classification prompt"""
|
||||
skill_descriptions = []
|
||||
for i, skill in enumerate(skills, 1):
|
||||
|
|
@@ -142,13 +152,11 @@ class IntentRouter:
|
|||
"\n"
|
||||
f"User input: {input_data}\n"
|
||||
"\n"
|
||||
'Respond in JSON format:\n'
|
||||
"Respond in JSON format:\n"
|
||||
'{"skill": "skill_name", "confidence": 0.9}'
|
||||
)
|
||||
|
||||
def _parse_llm_response(
|
||||
self, content: str, skills: list[Skill]
|
||||
) -> RoutingResult:
|
||||
def _parse_llm_response(self, content: str, skills: list[Skill]) -> RoutingResult:
|
||||
"""Parse the LLM response and extract the skill name and confidence"""
|
||||
valid_names = {s.name for s in skills}
|
||||
|
||||
|
|
@@ -175,9 +183,7 @@ class IntentRouter:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_skill_name_from_text(
|
||||
text: str, valid_names: set[str]
|
||||
) -> str:
|
||||
def _extract_skill_name_from_text(text: str, valid_names: set[str]) -> str:
|
||||
"""Try to extract a valid Skill name from the text"""
|
||||
text_lower = text.lower()
|
||||
for name in valid_names:
|
||||
|
|
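The rewritten `_match_keywords()` replaces first-match-wins with a ranking over all matching skills. A standalone sketch of the scoring follows, using a plain `dict` of skill name to keyword list in place of the project's `Skill` objects. The function name `match_keywords` and the keyword lists are hypothetical.

```python
from typing import Optional


def match_keywords(text: str, skills: dict) -> Optional[tuple]:
    """Return (skill_name, confidence) for the best keyword match, or None."""
    combined = text.lower()
    if not combined:
        return None

    candidates = []  # (skill_name, matched_keywords, score)
    for name, keywords in skills.items():
        matched = [kw for kw in keywords if kw.lower() in combined]
        if matched:
            # Longer keywords weigh more: score is the total matched length.
            candidates.append((name, matched, sum(len(kw) for kw in matched)))

    if not candidates:
        return None

    # Highest score first; ties broken alphabetically by skill name.
    candidates.sort(key=lambda c: (-c[2], c[0]))
    best_name, best_kws, _ = candidates[0]
    confidence = min(1.0, 0.5 + 0.1 * len(best_kws))
    return best_name, confidence


if __name__ == "__main__":
    skills = {
        "code_reviewer": ["review", "code review"],  # hypothetical keywords
        "geo_optimizer": ["seo"],
    }
    print(match_keywords("Please do a code review and an SEO pass", skills))
```

Ranking uses the total matched length while confidence uses the number of matched keywords, so a skill that wins on one long keyword is still reported with confidence 0.6. Overlapping keywords such as "review" and "code review" are both counted, which favours skills whose keyword lists contain nested phrases.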
|
|||
|
|
@@ -0,0 +1,11 @@
|
|||
"""E2E backtest suite for Fischer AgentKit.
|
||||
|
||||
Split into two dimensions:
|
||||
- Basic Functions: verify all features work correctly (CLI, API, WebSocket, lifecycle)
|
||||
- Agent Capabilities: verify intelligence level (routing, reasoning, collaboration)
|
||||
|
||||
Uses subprocess to simulate real CLI operations (OpenCLI pattern),
|
||||
httpx for API calls, and websockets for WS chat.
|
||||
"""
|
||||
|
||||
from tests.e2e.conftest import * # noqa: F401,F403
|
||||
|
|
@@ -0,0 +1,830 @@
|
|||
"""Agent Capability Benchmark — Ground Truth Dataset (v2).
|
||||
|
||||
Aligned with actual skills in configs/skills/*.yaml.
|
||||
Contains both manually curated edge cases and auto-generated cases.
|
||||
|
||||
Categories:
|
||||
- routing: intent routing correctness
|
||||
- execution: execution mode selection accuracy
|
||||
- team: expert team collaboration
|
||||
- consistency: deterministic output consistency
|
||||
- semantic_router: semantic similarity matching
|
||||
- alignment: constraint compliance and cascade detection
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class BenchmarkCase(BaseModel):
|
||||
"""A single benchmark test case with ground truth label."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
id: str
|
||||
input: str
|
||||
expected_skill: str | None = None
|
||||
expected_execution_mode: str = "direct"
|
||||
expected_complexity: str = "low"
|
||||
category: str
|
||||
subcategory: str
|
||||
paraphrases: list[str] = []
|
||||
tags: list[str] = []
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Routing — Keyword Match (aligned with actual skills)
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
ROUTING_KEYWORD_BENCHMARKS: list[BenchmarkCase] = [
|
||||
# direct_agent
|
||||
BenchmarkCase(
|
||||
id="route-kw-direct-001",
|
||||
input="翻译这段话",
|
||||
expected_skill="direct_agent",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["帮我翻译一下", "请翻译这段内容", "Translate this text"],
|
||||
tags=["翻译", "translate"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="route-kw-direct-002",
|
||||
input="帮我总结一下",
|
||||
expected_skill="direct_agent",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["请总结", "给我一个摘要", "Summarize this"],
|
||||
tags=["摘要", "summarize"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="route-kw-direct-003",
|
||||
input="什么是RAG?",
|
||||
expected_skill="direct_agent",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["RAG是什么", "解释一下RAG", "What is RAG?"],
|
||||
tags=["什么是"],
|
||||
),
|
||||
# react_agent
|
||||
BenchmarkCase(
|
||||
id="route-kw-react-001",
|
||||
input="搜索一下AI Agent市场数据",
|
||||
expected_skill="react_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=[
|
||||
"帮我搜索AI Agent市场信息",
|
||||
"查找AI Agent的市场数据",
|
||||
"Search AI Agent market data",
|
||||
],
|
||||
tags=["搜索", "search"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="route-kw-react-002",
|
||||
input="帮我分析这个数据",
|
||||
expected_skill="react_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["分析一下这些数据", "请对数据做分析", "Analyze this data"],
|
||||
tags=["分析", "analyze"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="route-kw-react-003",
|
||||
input="实时监控竞品动态",
|
||||
expected_skill="react_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["监控竞争对手的动态", "实时追踪竞品变化", "Monitor competitor activities"],
|
||||
tags=["实时", "监控"],
|
||||
),
|
||||
# rewoo_agent
|
||||
BenchmarkCase(
|
||||
id="route-kw-rewoo-001",
|
||||
input="采集A、B、C三个竞品的功能数据",
|
||||
expected_skill="rewoo_agent",
|
||||
expected_execution_mode="rewoo",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=[
|
||||
"批量采集竞品数据",
|
||||
"并行获取多个竞品信息",
|
||||
"Fetch data from multiple competitors",
|
||||
],
|
||||
tags=["采集", "批量", "fetch"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="route-kw-rewoo-002",
|
||||
input="并行搜索多个关键词",
|
||||
expected_skill="rewoo_agent",
|
||||
expected_execution_mode="rewoo",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["同时搜索多个关键词", "批量搜索", "Search multiple keywords in parallel"],
|
||||
tags=["并行", "批量"],
|
||||
),
|
||||
# reflexion_agent
|
||||
BenchmarkCase(
|
||||
id="route-kw-reflex-001",
|
||||
input="审查这段代码的合规性",
|
||||
expected_skill="reflexion_agent",
|
||||
expected_execution_mode="reflexion",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["检查代码是否合规", "审查代码合规问题", "Review code compliance"],
|
||||
tags=["审查", "合规", "review"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="route-kw-reflex-002",
|
||||
input="生成一个高精度的数据分析脚本",
|
||||
expected_skill="reflexion_agent",
|
||||
expected_execution_mode="reflexion",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=[
|
||||
"写一个精确的数据分析脚本",
|
||||
"生成高精度分析代码",
|
||||
"Generate a precise analysis script",
|
||||
],
|
||||
tags=["代码生成", "精确", "code"],
|
||||
),
|
||||
# plan_exec_agent
|
||||
BenchmarkCase(
|
||||
id="route-kw-planexec-001",
|
||||
input="生成一份市场分析报告",
|
||||
expected_skill="plan_exec_agent",
|
||||
expected_execution_mode="plan_exec",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["做一份市场分析报告", "写个市场分析报告", "Generate a market analysis report"],
|
||||
tags=["报告", "分析报告"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="route-kw-planexec-002",
|
||||
input="规划产品优化方案",
|
||||
expected_skill="plan_exec_agent",
|
||||
expected_execution_mode="plan_exec",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["制定产品优化计划", "帮我规划产品优化", "Plan product optimization"],
|
||||
tags=["规划", "plan"],
|
||||
),
|
||||
# code_reviewer
|
||||
BenchmarkCase(
|
||||
id="route-kw-coderev-001",
|
||||
input="Review this code for quality",
|
||||
expected_skill="code_reviewer",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["审查这段代码的质量", "代码审查", "Check code quality"],
|
||||
tags=["review", "代码审查"],
|
||||
),
|
||||
# geo_optimizer
|
||||
BenchmarkCase(
|
||||
id="route-kw-geo-001",
|
||||
input="帮我优化这篇文章的SEO",
|
||||
expected_skill="geo_optimizer",
|
||||
expected_execution_mode="llm_generate",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["SEO优化一下", "提升文章搜索排名", "Optimize this article for SEO"],
|
||||
tags=["SEO优化", "optimize"],
|
||||
),
|
||||
# deai_agent
|
||||
BenchmarkCase(
|
||||
id="route-kw-deai-001",
|
||||
input="帮我把这篇文章去AI化",
|
||||
expected_skill="deai_agent",
|
||||
expected_execution_mode="llm_generate",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["让这段文字更自然", "改写得像人写的", "Make this text more natural"],
|
||||
tags=["去AI化", "人性化"],
|
||||
),
|
||||
# content_generator
|
||||
BenchmarkCase(
|
||||
id="route-kw-content-001",
|
||||
input="帮我写一篇关于AI的文章",
|
||||
expected_skill="content_generator",
|
||||
expected_execution_mode="llm_generate",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["写一篇AI相关文章", "生成关于AI的内容", "Write an article about AI"],
|
||||
tags=["写文章", "generate"],
|
||||
),
|
||||
# citation_detector
|
||||
BenchmarkCase(
|
||||
id="route-kw-citation-001",
|
||||
input="检测我们的品牌在AI平台的引用情况",
|
||||
expected_skill="citation_detector",
|
||||
expected_execution_mode="custom",
|
||||
expected_complexity="medium",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=[
|
||||
"分析品牌引用率",
|
||||
"哪些AI平台引用了我们",
|
||||
"Check brand citation on AI platforms",
|
||||
],
|
||||
tags=["引用检测", "citation"],
|
||||
),
|
||||
# trend_agent
|
||||
BenchmarkCase(
|
||||
id="route-kw-trend-001",
|
||||
input="分析品牌趋势",
|
||||
expected_skill="trend_agent",
|
||||
expected_execution_mode="tool_call",
|
||||
expected_complexity="medium",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["最近的热点话题是什么", "趋势洞察", "Analyze brand trends"],
|
||||
tags=["趋势", "trend"],
|
||||
),
|
||||
# competitor_analyzer
|
||||
BenchmarkCase(
|
||||
id="route-kw-competitor-001",
|
||||
input="分析我的竞品策略",
|
||||
expected_skill="competitor_analyzer",
|
||||
expected_execution_mode="tool_call",
|
||||
expected_complexity="medium",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["对比我和竞品的差距", "竞品分析", "Analyze competitor strategies"],
|
||||
tags=["竞品", "competitor"],
|
||||
),
|
||||
# schema_advisor
|
||||
BenchmarkCase(
|
||||
id="route-kw-schema-001",
|
||||
input="帮我优化Schema",
|
||||
expected_skill="schema_advisor",
|
||||
expected_execution_mode="custom",
|
||||
expected_complexity="medium",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["生成JSON-LD结构化数据", "Schema有什么可以改进的", "Optimize my Schema"],
|
||||
tags=["Schema", "schema优化"],
|
||||
),
|
||||
# monitor
|
||||
BenchmarkCase(
|
||||
id="route-kw-monitor-001",
|
||||
input="监测品牌引用变化",
|
||||
expected_skill="monitor",
|
||||
expected_execution_mode="custom",
|
||||
expected_complexity="medium",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=["追踪效果", "品牌排名变化", "Monitor brand citation changes"],
|
||||
tags=["监测", "monitor"],
|
||||
),
|
||||
# goal_driven_agent
|
||||
BenchmarkCase(
|
||||
id="route-kw-goal-001",
|
||||
input="分析竞品SEO策略并生成优化方案",
|
||||
expected_skill="goal_driven_agent",
|
||||
expected_execution_mode="tool_call",
|
||||
expected_complexity="medium",
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=[
|
||||
"调研技术方案并生成对比报告",
|
||||
"制定市场推广计划",
|
||||
"Analyze SEO and generate plan",
|
||||
],
|
||||
tags=["分析", "优化方案"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Routing — Edge Cases (manually curated)
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
ROUTING_EDGE_BENCHMARKS: list[BenchmarkCase] = [
|
||||
# Greeting (should NOT route to any skill)
|
||||
BenchmarkCase(
|
||||
id="route-edge-greet-001",
|
||||
input="你好",
|
||||
expected_skill=None,
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="greeting",
|
||||
paraphrases=["Hello", "Hi there", "早上好"],
|
||||
tags=["greeting"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="route-edge-greet-002",
|
||||
input="Good morning!",
|
||||
expected_skill=None,
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="greeting",
|
||||
paraphrases=["早上好!", "你好呀"],
|
||||
tags=["greeting"],
|
||||
),
|
||||
# Identity (should NOT route to any skill)
|
||||
BenchmarkCase(
|
||||
id="route-edge-identity-001",
|
||||
input="你是谁?",
|
||||
expected_skill=None,
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="identity",
|
||||
paraphrases=["What is your name?", "介绍一下你自己", "Tell me about yourself"],
|
||||
tags=["identity"],
|
||||
),
|
||||
# Explicit prefix
|
||||
BenchmarkCase(
|
||||
id="route-edge-explicit-001",
|
||||
input="@skill:react_agent 搜索最新的AI新闻",
|
||||
expected_skill="react_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="routing",
|
||||
subcategory="explicit_prefix",
|
||||
paraphrases=["@skill:react_agent 查找AI最新动态"],
|
||||
tags=["explicit", "react"],
|
||||
),
|
||||
# Fallback (no matching skill)
|
||||
BenchmarkCase(
|
||||
id="route-edge-fallback-001",
|
||||
input="告诉我一个笑话",
|
||||
expected_skill=None,
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="fallback",
|
||||
paraphrases=["讲个笑话", "Tell me a joke", "说个搞笑的"],
|
||||
tags=["fallback"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="route-edge-fallback-002",
|
||||
input="What is quantum physics?",
|
||||
expected_skill=None,
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="fallback",
|
||||
paraphrases=["量子物理是什么", "Explain quantum mechanics"],
|
||||
tags=["fallback"],
|
||||
),
|
||||
# Disambiguation (multiple skills could match)
|
||||
BenchmarkCase(
|
||||
id="route-edge-disambig-001",
|
||||
input="审查代码并优化SEO",
|
||||
expected_skill="code_reviewer",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="disambiguation",
|
||||
paraphrases=["Review code and optimize SEO", "代码审查加SEO优化"],
|
||||
tags=["disambiguation", "review", "seo"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Execution Mode Benchmarks
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
EXECUTION_BENCHMARKS: list[BenchmarkCase] = [
|
||||
BenchmarkCase(
|
||||
id="exec-direct-001",
|
||||
input="翻译这段话成英文",
|
||||
expected_skill="direct_agent",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="execution",
|
||||
subcategory="direct_mode",
|
||||
paraphrases=["Translate this to English", "把这段翻成英语"],
|
||||
tags=["direct", "simple"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="exec-direct-002",
|
||||
input="什么是AgentKit?",
|
||||
expected_skill="direct_agent",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="execution",
|
||||
subcategory="direct_mode",
|
||||
paraphrases=["AgentKit是什么", "Explain AgentKit"],
|
||||
tags=["direct", "qa"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="exec-react-001",
|
||||
input="搜索并分析AI行业最新趋势",
|
||||
expected_skill="react_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="execution",
|
||||
subcategory="react_mode",
|
||||
paraphrases=["Search and analyze AI trends", "调研AI行业趋势"],
|
||||
tags=["react", "multi_step"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="exec-react-002",
|
||||
input="实时监控竞品动态并生成报告",
|
||||
expected_skill="react_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="execution",
|
||||
subcategory="react_mode",
|
||||
paraphrases=["Monitor competitors and report", "追踪竞品并输出报告"],
|
||||
tags=["react", "monitoring"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="exec-rewoo-001",
|
||||
input="批量采集多个竞品的功能数据",
|
||||
expected_skill="rewoo_agent",
|
||||
expected_execution_mode="rewoo",
|
||||
expected_complexity="high",
|
||||
category="execution",
|
||||
subcategory="rewoo_mode",
|
||||
paraphrases=["并行获取竞品数据", "Fetch competitor data in parallel"],
|
||||
tags=["rewoo", "parallel"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="exec-reflexion-001",
|
||||
input="审查代码合规性并确保高精度",
|
||||
expected_skill="reflexion_agent",
|
||||
expected_execution_mode="reflexion",
|
||||
expected_complexity="high",
|
||||
category="execution",
|
||||
subcategory="reflexion_mode",
|
||||
paraphrases=["高精度代码审查", "Precise code compliance review"],
|
||||
tags=["reflexion", "precision"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="exec-planexec-001",
|
||||
input="生成一份完整的市场调研报告",
|
||||
expected_skill="plan_exec_agent",
|
||||
expected_execution_mode="plan_exec",
|
||||
expected_complexity="high",
|
||||
category="execution",
|
||||
subcategory="plan_exec_mode",
|
||||
paraphrases=["做一份市场调研报告", "Generate a market research report"],
|
||||
tags=["plan_exec", "report"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="exec-quality-001",
|
||||
input="生成内容并确保质量达标",
|
||||
expected_skill="content_generator",
|
||||
expected_execution_mode="llm_generate",
|
||||
expected_complexity="low",
|
||||
category="execution",
|
||||
subcategory="quality_gate",
|
||||
paraphrases=["生成高质量内容", "Generate quality content"],
|
||||
tags=["quality", "content"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Team Collaboration Benchmarks
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
TEAM_BENCHMARKS: list[BenchmarkCase] = [
|
||||
BenchmarkCase(
|
||||
id="team-explicit-001",
|
||||
input="@team:react_agent,plan_exec_agent 协作完成深度分析并生成报告",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="team",
|
||||
subcategory="explicit_team",
|
||||
paraphrases=[
|
||||
"需要react_agent和plan_exec_agent协作",
|
||||
"组建团队:搜索分析+报告生成",
|
||||
],
|
||||
tags=["team", "explicit"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="team-explicit-002",
|
||||
input="@team:competitor_analyzer,trend_agent 分析竞品并追踪趋势",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="team",
|
||||
subcategory="explicit_team",
|
||||
paraphrases=["竞品分析+趋势追踪团队", "Team for competitor and trend analysis"],
|
||||
tags=["team", "explicit"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="team-complexity-001",
|
||||
input="深度分析竞品策略、追踪品牌趋势并生成优化方案",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="team",
|
||||
subcategory="complexity_trigger",
|
||||
paraphrases=[
|
||||
"全面竞品分析和优化方案",
|
||||
"Comprehensive competitor analysis with optimization",
|
||||
],
|
||||
tags=["team", "complexity"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="team-fallback-001",
|
||||
input="复杂任务但无匹配专家",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="team",
|
||||
subcategory="fallback",
|
||||
paraphrases=["需要团队但找不到合适专家", "Complex task without matching experts"],
|
||||
tags=["team", "fallback"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="team-name-valid-001",
|
||||
input="@team:react_agent,plan_exec_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="team",
|
||||
subcategory="name_validation",
|
||||
tags=["team", "validation"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="team-name-invalid-001",
|
||||
input="@team:invalid expert name",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="team",
|
||||
subcategory="name_validation",
|
||||
tags=["team", "validation", "invalid"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Consistency Benchmarks
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
CONSISTENCY_BENCHMARKS: list[BenchmarkCase] = [
|
||||
BenchmarkCase(
|
||||
id="consist-direct-001",
|
||||
input="翻译'hello world'成中文",
|
||||
expected_skill="direct_agent",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="consistency",
|
||||
subcategory="deterministic",
|
||||
tags=["consistency", "translation"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="consist-direct-002",
|
||||
input="什么是RAG?",
|
||||
expected_skill="direct_agent",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="consistency",
|
||||
subcategory="deterministic",
|
||||
tags=["consistency", "qa"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="consist-react-001",
|
||||
input="搜索AI Agent市场数据",
|
||||
expected_skill="react_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="consistency",
|
||||
subcategory="deterministic",
|
||||
tags=["consistency", "search"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="consist-geo-001",
|
||||
input="帮我优化这篇文章的SEO",
|
||||
expected_skill="geo_optimizer",
|
||||
expected_execution_mode="llm_generate",
|
||||
expected_complexity="low",
|
||||
category="consistency",
|
||||
subcategory="deterministic",
|
||||
tags=["consistency", "seo"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="consist-deai-001",
|
||||
input="帮我把这篇文章去AI化",
|
||||
expected_skill="deai_agent",
|
||||
expected_execution_mode="llm_generate",
|
||||
expected_complexity="low",
|
||||
category="consistency",
|
||||
subcategory="deterministic",
|
||||
tags=["consistency", "deai"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Semantic Router Benchmarks
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
SEMANTIC_ROUTER_BENCHMARKS: list[BenchmarkCase] = [
|
||||
BenchmarkCase(
|
||||
id="semantic-direct-001",
|
||||
input="简单生成任务,无需工具调用",
|
||||
expected_skill="direct_agent",
|
||||
expected_execution_mode="direct",
|
||||
expected_complexity="low",
|
||||
category="semantic_router",
|
||||
subcategory="description_match",
|
||||
paraphrases=["只需要一次生成的简单任务", "Single LLM call task"],
|
||||
tags=["semantic", "direct"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="semantic-react-001",
|
||||
input="需要动态适应、逐步推理和工具调用",
|
||||
expected_skill="react_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="semantic_router",
|
||||
subcategory="description_match",
|
||||
paraphrases=["需要多步推理和工具", "Multi-step reasoning with tools"],
|
||||
tags=["semantic", "react"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="semantic-rewoo-001",
|
||||
input="多源数据并行采集、无依赖工具调用批量执行",
|
||||
expected_skill="rewoo_agent",
|
||||
expected_execution_mode="rewoo",
|
||||
expected_complexity="high",
|
||||
category="semantic_router",
|
||||
subcategory="description_match",
|
||||
paraphrases=["并行批量获取数据", "Parallel data collection"],
|
||||
tags=["semantic", "rewoo"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="semantic-reflex-001",
|
||||
input="需要高精度和自我验证的任务",
|
||||
expected_skill="reflexion_agent",
|
||||
expected_execution_mode="reflexion",
|
||||
expected_complexity="high",
|
||||
category="semantic_router",
|
||||
subcategory="description_match",
|
||||
paraphrases=["需要自我检查的高精度任务", "High-precision self-verification task"],
|
||||
tags=["semantic", "reflexion"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="semantic-planexec-001",
|
||||
input="结构化多步骤任务,需要可审查的规划和执行",
|
||||
expected_skill="plan_exec_agent",
|
||||
expected_execution_mode="plan_exec",
|
||||
expected_complexity="high",
|
||||
category="semantic_router",
|
||||
subcategory="description_match",
|
||||
paraphrases=["需要先规划再执行的任务", "Structured planning and execution"],
|
||||
tags=["semantic", "plan_exec"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="semantic-geo-001",
|
||||
input="对文章进行GEO/SEO优化,提升在AI搜索引擎中的可见性",
|
||||
expected_skill="geo_optimizer",
|
||||
expected_execution_mode="llm_generate",
|
||||
expected_complexity="low",
|
||||
category="semantic_router",
|
||||
subcategory="description_match",
|
||||
paraphrases=["提升内容搜索排名", "Improve content visibility in AI search"],
|
||||
tags=["semantic", "geo"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="semantic-citation-001",
|
||||
input="检测品牌在各AI平台回答中的引用情况",
|
||||
expected_skill="citation_detector",
|
||||
expected_execution_mode="custom",
|
||||
expected_complexity="medium",
|
||||
category="semantic_router",
|
||||
subcategory="description_match",
|
||||
paraphrases=["分析品牌被AI引用的情况", "Check brand citation across AI platforms"],
|
||||
tags=["semantic", "citation"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="semantic-competitor-001",
|
||||
input="分析竞品策略、对比品牌差距或发现竞争机会",
|
||||
expected_skill="competitor_analyzer",
|
||||
expected_execution_mode="tool_call",
|
||||
expected_complexity="medium",
|
||||
category="semantic_router",
|
||||
subcategory="description_match",
|
||||
paraphrases=["竞品对比和差距分析", "Competitive gap analysis"],
|
||||
tags=["semantic", "competitor"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Alignment Guard Benchmarks
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
ALIGNMENT_BENCHMARKS: list[BenchmarkCase] = [
|
||||
BenchmarkCase(
|
||||
id="align-negative-001",
|
||||
input="写一篇产品介绍,不要提及价格",
|
||||
expected_skill="content_generator",
|
||||
expected_execution_mode="llm_generate",
|
||||
expected_complexity="low",
|
||||
category="alignment",
|
||||
subcategory="negative_constraint",
|
||||
tags=["alignment", "negative_constraint"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="align-positive-001",
|
||||
input="生成报告,必须包含摘要部分",
|
||||
expected_skill="plan_exec_agent",
|
||||
expected_execution_mode="plan_exec",
|
||||
expected_complexity="high",
|
||||
category="alignment",
|
||||
subcategory="positive_constraint",
|
||||
tags=["alignment", "positive_constraint"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="align-cascade-001",
|
||||
input="反复搜索相同关键词",
|
||||
expected_skill="react_agent",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="alignment",
|
||||
subcategory="cascade_detection",
|
||||
tags=["alignment", "cascade"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="align-no-constraint-001",
|
||||
input="帮我写一篇文章",
|
||||
expected_skill="content_generator",
|
||||
expected_execution_mode="llm_generate",
|
||||
expected_complexity="low",
|
||||
category="alignment",
|
||||
subcategory="no_constraint",
|
||||
tags=["alignment", "baseline"],
|
||||
),
|
||||
BenchmarkCase(
|
||||
id="align-combined-001",
|
||||
input="生成竞品分析报告,必须包含对比表格,不要提及内部数据",
|
||||
expected_skill="competitor_analyzer",
|
||||
expected_execution_mode="tool_call",
|
||||
expected_complexity="medium",
|
||||
category="alignment",
|
||||
subcategory="combined_constraint",
|
||||
tags=["alignment", "combined"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# All benchmarks combined
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
ALL_BENCHMARKS: list[BenchmarkCase] = (
|
||||
ROUTING_KEYWORD_BENCHMARKS
|
||||
+ ROUTING_EDGE_BENCHMARKS
|
||||
+ EXECUTION_BENCHMARKS
|
||||
+ TEAM_BENCHMARKS
|
||||
+ CONSISTENCY_BENCHMARKS
|
||||
+ SEMANTIC_ROUTER_BENCHMARKS
|
||||
+ ALIGNMENT_BENCHMARKS
|
||||
)
|
||||
|
||||
|
||||
def get_benchmarks_by_category(category: str) -> list[BenchmarkCase]:
|
||||
"""Filter benchmarks by category."""
|
||||
return [b for b in ALL_BENCHMARKS if b.category == category]
|
||||
|
||||
|
||||
def get_benchmarks_by_subcategory(subcategory: str) -> list[BenchmarkCase]:
|
||||
"""Filter benchmarks by subcategory."""
|
||||
return [b for b in ALL_BENCHMARKS if b.subcategory == subcategory]
|
||||
|
||||
|
||||
def get_benchmarks_with_paraphrases() -> list[BenchmarkCase]:
|
||||
"""Get only benchmarks that have paraphrases (for overfitting detection)."""
|
||||
return [b for b in ALL_BENCHMARKS if b.paraphrases]
|
||||
|
||||
|
||||
def get_skill_names_needed() -> set[str]:
|
||||
"""Get all skill names referenced in benchmarks (for pre-registration)."""
|
||||
return {b.expected_skill for b in ALL_BENCHMARKS if b.expected_skill is not None}
|
||||
|
||||
|
||||
def get_benchmark_stats() -> dict[str, int]:
|
||||
"""Get benchmark count by category."""
|
||||
stats: dict[str, int] = {}
|
||||
for b in ALL_BENCHMARKS:
|
||||
stats[b.category] = stats.get(b.category, 0) + 1
|
||||
stats["total"] = len(ALL_BENCHMARKS)
|
||||
return stats
|
||||
|
|
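The helper functions that close this file (category filters, `get_skill_names_needed`, `get_benchmark_stats`) can be illustrated with a minimal sketch. It assumes a simplified stand-in for `BenchmarkCase`: the `Case` dataclass, `stats_by_category`, and `unregistered_skills` are hypothetical names introduced for illustration only, and the registry passed in is invented. The idea is that the set of expected skills can be diffed against whatever skill registry is available, which is one way requirement R1 might be checked.

```python
from collections import Counter
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Case:
    """Hypothetical stand-in for BenchmarkCase; only the fields used below."""

    id: str
    category: str
    expected_skill: Optional[str] = None
    paraphrases: list = field(default_factory=list)


def stats_by_category(cases):
    """Count cases per category and add a 'total' entry."""
    stats = dict(Counter(c.category for c in cases))
    stats["total"] = len(cases)
    return stats


def unregistered_skills(cases, registered):
    """Return expected skills missing from the (assumed) skill registry."""
    needed = {c.expected_skill for c in cases if c.expected_skill is not None}
    return needed - set(registered)


cases = [
    Case("consist-direct-001", "consistency", "direct_agent"),
    Case("semantic-react-001", "semantic_router", "react_agent",
         ["Multi-step reasoning with tools"]),
    Case("team-fallback-001", "team"),  # this case carries no expected_skill
]

print(stats_by_category(cases))
print(unregistered_skills(cases, {"direct_agent"}))
```

A non-empty result from `unregistered_skills` would point at benchmark cases whose `expected_skill` no longer matches a real skill, which is the mismatch the plan describes.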
@ -0,0 +1,339 @@
|
|||
"""Benchmark Generator — Auto-generate benchmark cases from skill configs.
|
||||
|
||||
Reads configs/skills/*.yaml, extracts intent.keywords/description/examples,
|
||||
and generates BenchmarkCase objects aligned with actual skill configurations.
|
||||
|
||||
This ensures the benchmark dataset stays in sync with the real skill registry.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from tests.e2e.benchmark_dataset import BenchmarkCase
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Skill Config Model
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class SkillIntent(BaseModel):
|
||||
"""Intent section of a skill config."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
keywords: list[str] = []
|
||||
description: str = ""
|
||||
examples: list[str] = []
|
||||
|
||||
|
||||
class SkillConfig(BaseModel):
|
||||
"""Minimal skill config model for benchmark generation."""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
name: str
|
||||
description: str = ""
|
||||
execution_mode: str = "direct"
|
||||
task_mode: str = "llm_generate"
|
||||
intent: SkillIntent = SkillIntent()
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Complexity Mapping
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
EXECUTION_MODE_TO_COMPLEXITY: dict[str, str] = {
|
||||
"direct": "low",
|
||||
"react": "high",
|
||||
"rewoo": "high",
|
||||
"reflexion": "high",
|
||||
"plan_exec": "high",
|
||||
"tool_call": "medium",
|
||||
"llm_generate": "low",
|
||||
"custom": "medium",
|
||||
}
|
||||
|
||||
# Paraphrase templates for auto-generating paraphrases from examples
|
||||
PARAPHRASE_TEMPLATES_CN: list[str] = [
|
||||
"请帮我{action}",
|
||||
"我需要{action}",
|
||||
"能不能{action}",
|
||||
]
|
||||
|
||||
PARAPHRASE_TEMPLATES_EN: list[str] = [
|
||||
"Please help me {action}",
|
||||
"I need to {action}",
|
||||
"Can you {action}",
|
||||
]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Benchmark Generator
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class BenchmarkGenerator:
|
||||
"""Generate benchmark cases from skill config YAML files."""
|
||||
|
||||
def __init__(self, configs_dir: str | None = None) -> None:
|
||||
if configs_dir is None:
|
||||
# Default: project_root/configs/skills/
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
configs_dir = str(project_root / "configs" / "skills")
|
||||
self.configs_dir = configs_dir
|
||||
self._skills: list[SkillConfig] = []
|
||||
self._loaded = False
|
||||
|
||||
def load_skills(self) -> list[SkillConfig]:
|
||||
"""Load all skill configs from YAML files."""
|
||||
if self._loaded:
|
||||
return self._skills
|
||||
|
||||
skills_dir = Path(self.configs_dir)
|
||||
if not skills_dir.exists():
|
||||
return self._skills
|
||||
|
||||
for yaml_file in sorted(skills_dir.glob("*.yaml")):
|
||||
with open(yaml_file, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
if data and isinstance(data, dict):
|
||||
try:
|
||||
skill = SkillConfig(**data)
|
||||
self._skills.append(skill)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
self._loaded = True
|
||||
return self._skills
|
||||
|
||||
def _get_effective_execution_mode(self, skill: SkillConfig) -> str:
|
||||
"""Get the effective execution mode for a skill."""
|
||||
if skill.execution_mode and skill.execution_mode != "direct":
|
||||
return skill.execution_mode
|
||||
# Map task_mode to execution mode
|
||||
return skill.task_mode if skill.task_mode else "direct"
|
||||
|
||||
    def _generate_paraphrases(self, example: str, keywords: list[str]) -> list[str]:
        """Generate up to three prefix-based paraphrases for an example query."""
        paraphrases: list[str] = []
        if not example:
            # Nothing to paraphrase; this also keeps example[0] below from raising.
            return paraphrases

        # Simple paraphrase generation: add prefix variations
        is_chinese = any("\u4e00" <= c <= "\u9fff" for c in example)

        if is_chinese:
            # Chinese paraphrases
            if not example.startswith("请") and not example.startswith("帮"):
                paraphrases.append(f"请{example}")
            if not example.startswith("我"):
                paraphrases.append(f"我需要{example}")
            # Add keyword-based variant
            if keywords:
                kw = keywords[0]
                if kw not in example:
                    paraphrases.append(f"关于{kw},{example}")
        else:
            # English paraphrases
            lower = example.lower()
            if not lower.startswith("please") and not lower.startswith("can you"):
                paraphrases.append(f"Please {example[0].lower()}{example[1:]}")
            if not lower.startswith("i need"):
                paraphrases.append(f"I need to {example[0].lower()}{example[1:]}")

        return paraphrases[:3]  # Max 3 paraphrases per example
|
||||
|
||||
def generate_routing_benchmarks(self) -> list[BenchmarkCase]:
|
||||
"""Generate routing benchmark cases from all skills."""
|
||||
skills = self.load_skills()
|
||||
cases: list[BenchmarkCase] = []
|
||||
case_counter = 0
|
||||
|
||||
for skill in skills:
|
||||
exec_mode = self._get_effective_execution_mode(skill)
|
||||
complexity = EXECUTION_MODE_TO_COMPLEXITY.get(exec_mode, "low")
|
||||
|
||||
# Generate from intent.examples
|
||||
for example in skill.intent.examples:
|
||||
case_counter += 1
|
||||
paraphrases = self._generate_paraphrases(example, skill.intent.keywords)
|
||||
cases.append(
|
||||
BenchmarkCase(
|
||||
id=f"route-auto-{case_counter:03d}",
|
||||
input=example,
|
||||
expected_skill=skill.name,
|
||||
expected_execution_mode=exec_mode,
|
||||
expected_complexity=complexity,
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
paraphrases=paraphrases,
|
||||
tags=skill.intent.keywords[:3],
|
||||
)
|
||||
)
|
||||
|
||||
# Generate from intent.keywords (one case per keyword)
|
||||
for keyword in skill.intent.keywords:
|
||||
case_counter += 1
|
||||
query = (
|
||||
f"帮我{keyword}"
|
||||
if any("\u4e00" <= c <= "\u9fff" for c in keyword)
|
||||
else f"Help me {keyword}"
|
||||
)
|
||||
cases.append(
|
||||
BenchmarkCase(
|
||||
id=f"route-kw-auto-{case_counter:03d}",
|
||||
input=query,
|
||||
expected_skill=skill.name,
|
||||
expected_execution_mode=exec_mode,
|
||||
expected_complexity=complexity,
|
||||
category="routing",
|
||||
subcategory="keyword_match",
|
||||
tags=[keyword],
|
||||
)
|
||||
)
|
||||
|
||||
return cases
|
||||
|
||||
def generate_execution_benchmarks(self) -> list[BenchmarkCase]:
|
||||
"""Generate execution mode benchmark cases."""
|
||||
skills = self.load_skills()
|
||||
cases: list[BenchmarkCase] = []
|
||||
case_counter = 0
|
||||
|
||||
# Group skills by execution mode
|
||||
mode_groups: dict[str, list[SkillConfig]] = {}
|
||||
for skill in skills:
|
||||
mode = self._get_effective_execution_mode(skill)
|
||||
mode_groups.setdefault(mode, []).append(skill)
|
||||
|
||||
for mode, group in mode_groups.items():
|
||||
complexity = EXECUTION_MODE_TO_COMPLEXITY.get(mode, "low")
|
||||
for skill in group[:2]: # Max 2 skills per mode
|
||||
if skill.intent.examples:
|
||||
case_counter += 1
|
||||
cases.append(
|
||||
BenchmarkCase(
|
||||
id=f"exec-auto-{case_counter:03d}",
|
||||
input=skill.intent.examples[0],
|
||||
expected_skill=skill.name,
|
||||
expected_execution_mode=mode,
|
||||
expected_complexity=complexity,
|
||||
category="execution",
|
||||
subcategory=f"{mode}_mode",
|
||||
paraphrases=skill.intent.examples[1:2],
|
||||
tags=[mode],
|
||||
)
|
||||
)
|
||||
|
||||
return cases
|
||||
|
||||
def generate_team_benchmarks(self) -> list[BenchmarkCase]:
|
||||
"""Generate team collaboration benchmark cases."""
|
||||
skills = self.load_skills()
|
||||
cases: list[BenchmarkCase] = []
|
||||
case_counter = 0
|
||||
|
||||
# High-complexity skills suitable for team collaboration
|
||||
high_complexity_skills = [
|
||||
s
|
||||
for s in skills
|
||||
if EXECUTION_MODE_TO_COMPLEXITY.get(self._get_effective_execution_mode(s), "low")
|
||||
== "high"
|
||||
]
|
||||
|
||||
if len(high_complexity_skills) >= 2:
|
||||
skill_a, skill_b = high_complexity_skills[0], high_complexity_skills[1]
|
||||
case_counter += 1
|
||||
cases.append(
|
||||
BenchmarkCase(
|
||||
id=f"team-auto-{case_counter:03d}",
|
||||
input=f"@team:{skill_a.name},{skill_b.name} 协作完成复杂分析任务",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="team",
|
||||
subcategory="explicit_team",
|
||||
paraphrases=[
|
||||
f"需要{skill_a.name}和{skill_b.name}协作分析",
|
||||
f"组建团队:{skill_a.name} + {skill_b.name}",
|
||||
],
|
||||
tags=["team", skill_a.name, skill_b.name],
|
||||
)
|
||||
)
|
||||
|
||||
# Complexity-triggered team
|
||||
if high_complexity_skills:
|
||||
skill = high_complexity_skills[0]
|
||||
case_counter += 1
|
||||
cases.append(
|
||||
BenchmarkCase(
|
||||
id=f"team-complexity-{case_counter:03d}",
|
||||
input=f"深度{skill.intent.keywords[0] if skill.intent.keywords else '分析'}并生成详细报告",
|
||||
expected_execution_mode="react",
|
||||
expected_complexity="high",
|
||||
category="team",
|
||||
subcategory="complexity_trigger",
|
||||
paraphrases=[
|
||||
f"全面{skill.intent.keywords[0] if skill.intent.keywords else '分析'}并输出报告",
|
||||
],
|
||||
tags=["team", "complexity"],
|
||||
)
|
||||
)
|
||||
|
||||
return cases
|
||||
|
||||
def generate_semantic_benchmarks(self) -> list[BenchmarkCase]:
|
||||
"""Generate semantic router specific benchmark cases."""
|
||||
skills = self.load_skills()
|
||||
cases: list[BenchmarkCase] = []
|
||||
case_counter = 0
|
||||
|
||||
for skill in skills:
|
||||
if not skill.intent.description:
|
||||
continue
|
||||
case_counter += 1
|
||||
# Use description as input (tests semantic matching, not keyword matching)
|
||||
cases.append(
|
||||
BenchmarkCase(
|
||||
id=f"semantic-auto-{case_counter:03d}",
|
||||
input=skill.intent.description,
|
||||
expected_skill=skill.name,
|
||||
expected_execution_mode=self._get_effective_execution_mode(skill),
|
||||
expected_complexity=EXECUTION_MODE_TO_COMPLEXITY.get(
|
||||
self._get_effective_execution_mode(skill), "low"
|
||||
),
|
||||
category="semantic_router",
|
||||
subcategory="description_match",
|
||||
tags=["semantic", skill.name],
|
||||
)
|
||||
)
|
||||
|
||||
return cases
|
||||
|
||||
def generate_all(self) -> list[BenchmarkCase]:
|
||||
"""Generate all auto-generated benchmark cases."""
|
||||
cases: list[BenchmarkCase] = []
|
||||
cases.extend(self.generate_routing_benchmarks())
|
||||
cases.extend(self.generate_execution_benchmarks())
|
||||
cases.extend(self.generate_team_benchmarks())
|
||||
cases.extend(self.generate_semantic_benchmarks())
|
||||
return cases
|
||||
|
||||
def get_skill_names(self) -> set[str]:
|
||||
"""Get all skill names from configs."""
|
||||
return {s.name for s in self.load_skills()}
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Singleton for reuse
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
_generator: BenchmarkGenerator | None = None
|
||||
|
||||
|
||||
def get_generator() -> BenchmarkGenerator:
|
||||
"""Get or create the singleton BenchmarkGenerator."""
|
||||
global _generator
|
||||
if _generator is None:
|
||||
_generator = BenchmarkGenerator()
|
||||
return _generator
|
||||
|
|
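The way the generator derives an expected execution mode and complexity for each skill can be sketched in isolation. The mapping and the fallback rule below mirror `EXECUTION_MODE_TO_COMPLEXITY` and `_get_effective_execution_mode` from the file above; the standalone function names `effective_mode` and `expected_complexity` are hypothetical and exist only for this illustration.

```python
# Copied from the generator's mapping; unknown modes fall back to "low".
MODE_TO_COMPLEXITY = {
    "direct": "low",
    "react": "high",
    "rewoo": "high",
    "reflexion": "high",
    "plan_exec": "high",
    "tool_call": "medium",
    "llm_generate": "low",
    "custom": "medium",
}


def effective_mode(execution_mode: str, task_mode: str) -> str:
    """Prefer a non-direct execution_mode; otherwise fall back to task_mode."""
    if execution_mode and execution_mode != "direct":
        return execution_mode
    return task_mode if task_mode else "direct"


def expected_complexity(execution_mode: str, task_mode: str) -> str:
    """Map the effective mode to the complexity label used in benchmark cases."""
    return MODE_TO_COMPLEXITY.get(effective_mode(execution_mode, task_mode), "low")


print(effective_mode("react", "llm_generate"))
print(effective_mode("direct", "tool_call"))
print(expected_complexity("direct", "custom"))
```

One consequence worth noting: because `SkillConfig.task_mode` defaults to `"llm_generate"`, a skill that sets neither field resolves to `"llm_generate"` rather than `"direct"` under this rule. Both map to low complexity, so only the expected mode label differs.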
@ -0,0 +1,413 @@
|
|||
"""E2E test fixtures: server lifecycle, CLI runner, API client, WebSocket helpers.
|
||||
|
||||
Design principles:
|
||||
1. Start a real uvicorn server with MockLLMProvider once per session
|
||||
2. CLI tests use subprocess to invoke `agentkit` commands (OpenCLI pattern)
|
||||
3. API tests use httpx against the live server
|
||||
4. WebSocket tests use the `websockets` library against the live server
|
||||
5. All tests are idempotent and repeatable
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Generator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Markers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
config.addinivalue_line("markers", "e2e: end-to-end backtest (requires server)")
|
||||
config.addinivalue_line("markers", "e2e_basic: basic function correctness test")
|
||||
config.addinivalue_line("markers", "e2e_capability: agent intelligence capability test")
|
||||
# Initialize session-scoped metrics collector
|
||||
from tests.e2e.capability_metrics import MetricsCollector
|
||||
|
||||
config._e2e_metrics_collector = MetricsCollector() # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
    """After all tests, generate capability analysis report if data was collected."""
    # getattr() keeps this hook safe if pytest_configure never attached the collector.
    collector = getattr(session.config, "_e2e_metrics_collector", None)
    if collector is None or not collector.observations:
        return

    from tests.e2e.capability_metrics import MetricsAnalyzer, MetricsReporter

    analyzer = MetricsAnalyzer()
    report = analyzer.generate_report(collector)

    output_dir = os.path.join(os.path.dirname(__file__), "..", "..", "test-results", "e2e")
    paths = MetricsReporter.save_report(report, output_dir)

    # Print summary to console
    print("\n" + MetricsReporter.to_text(report))
    print(f"\nReport saved to: {paths['json']}")
    print(f"Text report: {paths['text']}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
E2E_HOST = "127.0.0.1"
|
||||
E2E_PORT = 18765 # dedicated port to avoid conflict with dev server
|
||||
E2E_BASE_URL = f"http://{E2E_HOST}:{E2E_PORT}"
|
||||
E2E_WS_URL = f"ws://{E2E_HOST}:{E2E_PORT}"
|
||||
E2E_API_KEY = "ak_live_e2e_test_key_000000000000000000000000000000000000000000000000"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock LLM Provider (deterministic responses for backtest)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MOCK_LLM_RESPONSES: dict[str, str] = {
|
||||
# Default / generic
|
||||
"default": '{"result": "mock response", "content": "This is a mock LLM response for e2e testing."}',
|
||||
# Content generation
|
||||
"content_writer": '{"result": "article generated", "content": "AI is transforming industries by enabling automation and intelligent decision-making."}',
|
||||
# Translation
|
||||
"translator": '{"result": "translation complete", "content": "This is the translated text."}',
|
||||
# Summarization
|
||||
"summarizer": '{"result": "summary generated", "content": "Key points: 1) Topic overview 2) Main findings 3) Conclusion."}',
|
||||
# Code generation
|
||||
"coder": '{"result": "code generated", "content": "def hello():\\n print(\\"Hello, World!\\")"}',
|
||||
# Analysis
|
||||
"analyst": '{"result": "analysis complete", "content": "The data shows a positive trend with 15% growth."}',
|
||||
# ReAct tool call
|
||||
"react_tool_call": '{"thought": "I need to search for information", "action": "web_search", "action_input": {"query": "test"}, "observation": "Search results found"}',
|
||||
# ReAct final answer
|
||||
"react_final": '{"thought": "I have enough information", "final_answer": "Based on my analysis, the answer is 42."}',
|
||||
}
|
||||
|
||||
|
||||
def _build_mock_env(tmp_path: Any) -> dict[str, str]:
|
||||
"""Build environment variables for a server with MockLLMProvider."""
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"AGENTKIT_E2E_MODE": "1",
|
||||
"AGENTKIT_E2E_MOCK_RESPONSES": json.dumps(MOCK_LLM_RESPONSES),
|
||||
"AGENTKIT_API_KEY": E2E_API_KEY,
|
||||
"AGENTKIT_WS_TIMEOUT": "0",
|
||||
# Disable real LLM calls
|
||||
"OPENAI_API_KEY": "",
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"DEEPSEEK_API_KEY": "",
|
||||
}
|
||||
)
|
||||
return env
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Server lifecycle fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def e2e_server(tmp_path_factory: pytest.TempPathFactory) -> Generator[str, None, None]:
|
||||
"""Start a real AgentKit server for the entire E2E session.
|
||||
|
||||
Returns the base URL (e.g. http://127.0.0.1:18765).
|
||||
The server uses MockLLMProvider so no real LLM calls are made.
|
||||
"""
|
||||
tmp_path = tmp_path_factory.mktemp("e2e_server")
|
||||
|
||||
# Generate a minimal agentkit.yaml for the test server
|
||||
config_dir = tmp_path / "config"
|
||||
config_dir.mkdir()
|
||||
config_file = config_dir / "agentkit.yaml"
|
||||
|
||||
import yaml
|
||||
|
||||
config_file.write_text(
|
||||
yaml.dump(
|
||||
{
|
||||
"server": {"host": E2E_HOST, "port": E2E_PORT},
|
||||
"llm": {"default_provider": "mock", "providers": {"mock": {"type": "mock"}}},
|
||||
"auth": {"enabled": True, "api_keys": [E2E_API_KEY]},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
env = _build_mock_env(tmp_path)
|
||||
env["AGENTKIT_CONFIG"] = str(config_file)
|
||||
|
||||
# Start server as subprocess
|
||||
proc = subprocess.Popen(
|
||||
[
|
||||
sys.executable,
|
||||
"-m",
|
||||
"agentkit.cli.main",
|
||||
"serve",
|
||||
"--host",
|
||||
E2E_HOST,
|
||||
"--port",
|
||||
str(E2E_PORT),
|
||||
],
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=str(tmp_path),
|
||||
)
|
||||
|
||||
    # Wait for server to be ready (max 30s)
    base_url = E2E_BASE_URL
    deadline = time.monotonic() + 30
    ready = False
    while time.monotonic() < deadline:
        if proc.poll() is not None:
            # The server process exited early; stop polling and report its output.
            break
        try:
            resp = httpx.get(f"{base_url}/api/v1/health", timeout=2)
            if resp.status_code == 200:
                ready = True
                break
        except httpx.TransportError:
            # Covers refused connections as well as connect/read timeouts.
            pass
        time.sleep(0.5)

    if not ready:
        proc.terminate()
        try:
            stdout, stderr = proc.communicate(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            stdout, stderr = proc.communicate()
        pytest.fail(
            f"E2E server did not become ready (30s limit).\n"
            f"stdout: {stdout.decode()[:2000]}\n"
            f"stderr: {stderr.decode()[:2000]}"
        )
|
||||
|
||||
yield base_url
|
||||
|
||||
# Teardown
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API client fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def api_client(e2e_server: str) -> Generator[httpx.Client, None, None]:
    """Synchronous httpx client configured for the E2E server."""
    with httpx.Client(
        base_url=e2e_server,
        headers={"X-API-Key": E2E_API_KEY, "Content-Type": "application/json"},
        timeout=30,
    ) as client:
        # Yield inside the context manager so the connection pool is closed at session end.
        yield client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI runner (subprocess-based, OpenCLI pattern)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CLIRunner:
|
||||
"""Simulate user CLI operations via subprocess.
|
||||
|
||||
This is the 'OpenCLI' pattern: invoke the real `agentkit` binary
|
||||
as a subprocess and capture its output, exactly as a user would.
|
||||
"""
|
||||
|
||||
def __init__(self, env: dict[str, str] | None = None, cwd: str | None = None):
|
||||
self.env = env or os.environ.copy()
|
||||
self.cwd = cwd
|
||||
|
||||
def _resolve_agentkit_cmd(self) -> list[str]:
|
||||
"""Resolve the agentkit command to use.
|
||||
|
||||
Prefer the installed `agentkit` script (handles Rich/Typer output correctly),
|
||||
fall back to `python -m agentkit.cli.main`.
|
||||
"""
|
||||
agentkit_path = shutil.which("agentkit")
|
||||
if agentkit_path:
|
||||
return [agentkit_path]
|
||||
return [sys.executable, "-m", "agentkit.cli.main"]
|
||||
|
||||
def run(self, args: list[str], timeout: int = 30) -> subprocess.CompletedProcess[str]:
|
||||
"""Run an agentkit CLI command and return the result.
|
||||
|
||||
Args:
|
||||
args: CLI arguments, e.g. ["version"] or ["task", "submit", ...]
|
||||
timeout: maximum seconds to wait
|
||||
|
||||
Returns:
|
||||
CompletedProcess with stdout, stderr, returncode
|
||||
"""
|
||||
cmd = [*self._resolve_agentkit_cmd(), *args]
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
env=self.env,
|
||||
cwd=self.cwd,
|
||||
)
|
||||
|
||||
def run_server_command(
|
||||
self, args: list[str], server_url: str, timeout: int = 30
|
||||
) -> subprocess.CompletedProcess[str]:
|
||||
"""Run a CLI command that requires --server-url."""
|
||||
full_args = [*args, "--server-url", server_url]
|
||||
return self.run(full_args, timeout=timeout)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cli_runner(tmp_path: Any) -> CLIRunner:
|
||||
"""CLI runner with isolated environment."""
|
||||
env = os.environ.copy()
|
||||
env["AGENTKIT_CONFIG_DIR"] = str(tmp_path / "config")
|
||||
env["AGENTKIT_WS_TIMEOUT"] = "0"
|
||||
# Prevent onboarding prompts
|
||||
env["AGENTKIT_E2E_MODE"] = "1"
|
||||
return CLIRunner(env=env, cwd=str(tmp_path))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def cli_runner_session(e2e_server: str) -> CLIRunner:
|
||||
"""CLI runner configured to talk to the E2E server."""
|
||||
env = os.environ.copy()
|
||||
env["AGENTKIT_SERVER_URL"] = e2e_server
|
||||
env["AGENTKIT_API_KEY"] = E2E_API_KEY
|
||||
env["AGENTKIT_WS_TIMEOUT"] = "0"
|
||||
env["AGENTKIT_E2E_MODE"] = "1"
|
||||
return CLIRunner(env=env)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# WebSocket helper
# ---------------------------------------------------------------------------


class WSChatHelper:
    """Helper for WebSocket chat E2E tests."""

    def __init__(self, base_ws_url: str, api_key: str):
        self.base_ws_url = base_ws_url
        self.api_key = api_key

    async def connect_and_chat(
        self,
        session_id: str,
        messages: list[dict[str, str]],
        timeout: float = 10.0,
    ) -> list[dict[str, Any]]:
        """Connect to a chat WebSocket, send messages, collect responses.

        Args:
            session_id: chat session ID
            messages: list of {"type": "message", "content": "..."}
            timeout: max seconds to wait for final_answer

        Returns:
            list of all server-sent messages
        """
        try:
            import websockets
        except ImportError:
            pytest.skip("websockets package not installed")

        uri = f"{self.base_ws_url}/api/v1/chat/ws/{session_id}?api_key={self.api_key}"
        received: list[dict[str, Any]] = []

        async with websockets.connect(uri) as ws:
            # Wait for connected event
            msg = await asyncio.wait_for(ws.recv(), timeout=timeout)
            data = json.loads(msg)
            received.append(data)
            assert data.get("type") == "connected", f"Expected connected, got {data}"

            # Send user messages
            for user_msg in messages:
                await ws.send(json.dumps(user_msg))

            # Collect responses until final_answer or error
            while True:
                try:
                    raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
                    resp = json.loads(raw)
                    received.append(resp)

                    if resp.get("type") in ("final_answer", "error"):
                        break
                except asyncio.TimeoutError:
                    received.append({"type": "timeout"})
                    break

        return received


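Review note: the collection loop in `WSChatHelper.connect_and_chat` stops on `final_answer`, on `error`, or on a receive timeout, in which case it appends a synthetic `{"type": "timeout"}` entry. The sketch below reproduces that loop against an in-memory stand-in so the stop conditions can be checked without a server. `FakeWebSocket`, `collect_until_terminal`, and `demo_types` are hypothetical names used only for illustration; they are not part of this commit.

```python
import asyncio
import json
from typing import Any


class FakeWebSocket:
    """Hypothetical stand-in for a websockets connection (illustration only)."""

    def __init__(self, frames: list[dict[str, Any]]) -> None:
        self._frames = [json.dumps(frame) for frame in frames]

    async def recv(self) -> str:
        if self._frames:
            return self._frames.pop(0)
        # No frames left: behave like a silent server until the caller times out.
        await asyncio.sleep(3600)
        raise AssertionError("unreachable in this sketch")


async def collect_until_terminal(ws: Any, timeout: float = 0.05) -> list[dict[str, Any]]:
    """Collect frames until final_answer, error, or a receive timeout."""
    received: list[dict[str, Any]] = []
    while True:
        try:
            raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
        except asyncio.TimeoutError:
            received.append({"type": "timeout"})
            break
        resp = json.loads(raw)
        received.append(resp)
        if resp.get("type") in ("final_answer", "error"):
            break
    return received


def demo_types(frames: list[dict[str, Any]]) -> list[str]:
    """Run the loop over the given frames and return the collected message types."""
    return [m["type"] for m in asyncio.run(collect_until_terminal(FakeWebSocket(frames)))]
```

Frames that arrive after a terminal message are never read, and a stream that goes quiet ends with a `timeout` entry rather than an exception, so callers should check the last message type before treating a run as complete.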
@pytest.fixture(scope="session")
def ws_helper(e2e_server: str) -> WSChatHelper:
    """WebSocket chat helper for the E2E server."""
    ws_url = e2e_server.replace("http://", "ws://").replace("https://", "wss://")
    return WSChatHelper(base_ws_url=ws_url, api_key=E2E_API_KEY)


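Review note: the `ws_helper` fixture derives the WebSocket base URL with two string replacements. A stricter variant, sketched here with only the standard library, parses the URL and maps the scheme explicitly so an unexpected scheme fails loudly. `to_ws_url` is a hypothetical helper name, not something this commit defines.

```python
from urllib.parse import urlsplit, urlunsplit


def to_ws_url(http_url: str) -> str:
    """Map http -> ws and https -> wss, leaving host, path, and query untouched."""
    parts = urlsplit(http_url)
    scheme = {"http": "ws", "https": "wss"}.get(parts.scheme)
    if scheme is None:
        raise ValueError(f"Unsupported scheme for WebSocket conversion: {parts.scheme!r}")
    return urlunsplit((scheme, parts.netloc, parts.path, parts.query, parts.fragment))
```

For the plain `http://host:port` base URLs used in these fixtures both approaches give the same result; the parsed version only differs when the input is malformed.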
# ---------------------------------------------------------------------------
# Skill / Agent setup helpers
# ---------------------------------------------------------------------------


def register_skill_via_api(
    api_client: httpx.Client,
    name: str,
    keywords: list[str] | None = None,
    execution_mode: str = "direct",
    task_mode: str = "llm_generate",
) -> httpx.Response:
    """Register a skill via the API for E2E testing."""
    config: dict[str, Any] = {
        "name": name,
        "agent_type": name,
        "task_mode": task_mode,
        "description": f"E2E test skill: {name}",
        "prompt": {
            "identity": f"You are a {name} assistant",
            "instructions": f"Perform {name} tasks",
            "output_format": "JSON",
        },
        "intent": {
            "keywords": keywords or [name],
            "description": f"{name} skill for e2e testing",
        },
    }
    if execution_mode != "direct":
        config["execution_mode"] = execution_mode
        config["max_steps"] = 5

    return api_client.post("/api/v1/skills", json={"config": config})


def create_session_via_api(api_client: httpx.Client, agent_name: str = "test") -> str:
    """Create a chat session and return the session ID."""
    resp = api_client.post("/api/v1/chat/sessions", json={"agent_name": agent_name})
    assert resp.status_code == 201, f"Failed to create session: {resp.text}"
    return resp.json()["session_id"]


# ---------------------------------------------------------------------------
# Metrics Collector fixture
# ---------------------------------------------------------------------------


@pytest.fixture(scope="session")
def metrics_collector(request: pytest.FixtureRequest):
    """Session-scoped metrics collector for capability analysis."""
    from tests.e2e.capability_metrics import MetricsCollector

    collector: MetricsCollector = request.config._e2e_metrics_collector  # type: ignore[attr-defined]
    return collector
@@ -0,0 +1,277 @@
"""E2E Basic Function Tests — REST API endpoints.

Verifies all API routes work correctly with proper request/response handling.

Test categories:
1. Health & metrics
2. Agent CRUD lifecycle
3. Skill registration & listing
4. Task submission (sync/async/SSE)
5. Chat session lifecycle
6. LLM usage tracking
7. Error handling & edge cases
"""

import pytest
import httpx

from tests.e2e.conftest import register_skill_via_api, create_session_via_api


# ═══════════════════════════════════════════════════════════════════════════
# 1. Health & Metrics
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestHealthAPI:
    def test_health_returns_ok(self, api_client: httpx.Client):
        resp = api_client.get("/api/v1/health")
        assert resp.status_code == 200
        data = resp.json()
        assert data.get("status") in ("ok", "healthy")

    def test_metrics_endpoint(self, api_client: httpx.Client):
        resp = api_client.get("/api/v1/metrics")
        assert resp.status_code == 200


# ═══════════════════════════════════════════════════════════════════════════
# 2. Agent CRUD Lifecycle
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestAgentCRUD:
    """Full Agent CRUD lifecycle: create → list → get → delete."""

    def test_create_agent_from_skill(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "crud_skill", keywords=["crud"])
        resp = api_client.post("/api/v1/agents", json={"skill_name": "crud_skill"})
        assert resp.status_code == 201
        data = resp.json()
        assert data["name"] == "crud_skill"

    def test_list_agents(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "list_skill", keywords=["list_agent"])
        api_client.post("/api/v1/agents", json={"skill_name": "list_skill"})
        resp = api_client.get("/api/v1/agents")
        assert resp.status_code == 200
        agents = resp.json()
        assert isinstance(agents, list)
        assert any(a["name"] == "list_skill" for a in agents)

    def test_get_agent_detail(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "detail_skill", keywords=["detail"])
        api_client.post("/api/v1/agents", json={"skill_name": "detail_skill"})
        resp = api_client.get("/api/v1/agents/detail_skill")
        assert resp.status_code == 200
        data = resp.json()
        assert data["name"] == "detail_skill"

    def test_delete_agent(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "delete_skill", keywords=["delete_agent"])
        api_client.post("/api/v1/agents", json={"skill_name": "delete_skill"})
        resp = api_client.delete("/api/v1/agents/delete_skill")
        assert resp.status_code == 204
        # Verify deleted
        resp = api_client.get("/api/v1/agents/delete_skill")
        assert resp.status_code == 404

    def test_create_agent_nonexistent_skill(self, api_client: httpx.Client):
        resp = api_client.post("/api/v1/agents", json={"skill_name": "nonexistent_skill_xyz"})
        assert resp.status_code in (400, 404)

    def test_get_nonexistent_agent(self, api_client: httpx.Client):
        resp = api_client.get("/api/v1/agents/does_not_exist")
        assert resp.status_code == 404


# ═══════════════════════════════════════════════════════════════════════════
# 3. Skill Registration & Listing
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestSkillAPI:
    def test_register_skill(self, api_client: httpx.Client):
        resp = register_skill_via_api(api_client, "reg_skill", keywords=["reg"])
        assert resp.status_code == 201

    def test_list_skills(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "list_test_skill", keywords=["list_test"])
        resp = api_client.get("/api/v1/skills")
        assert resp.status_code == 200
        skills = resp.json()
        assert isinstance(skills, list)
        assert len(skills) >= 1

    def test_register_duplicate_skill(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "dup_skill", keywords=["dup"])
        resp = register_skill_via_api(api_client, "dup_skill", keywords=["dup"])
        # Should either overwrite or return conflict
        assert resp.status_code in (200, 201, 409)

    def test_skill_with_execution_mode(self, api_client: httpx.Client):
        resp = register_skill_via_api(
            api_client, "react_skill", keywords=["react_test"], execution_mode="react"
        )
        assert resp.status_code == 201

    def test_skill_mention_suggest(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "mention_skill", keywords=["mention_test"])
        resp = api_client.get("/api/v1/skills/mention-suggest", params={"q": "mention"})
        assert resp.status_code == 200


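Review note: `test_register_skill` and `test_skill_with_execution_mode` assert a strict `201`, while `test_register_duplicate_skill` accepts `409` for a name that already exists. If the server keeps skills between runs (an assumption; the diff does not show how the E2E server stores them), a fixed name such as `reg_skill` could hit the conflict path on a second run. One way to sidestep that is a per-call unique name, sketched below. `unique_skill_name` is a hypothetical helper, not part of this commit.

```python
import uuid


def unique_skill_name(prefix: str) -> str:
    """Return the prefix plus an 8-character random hex suffix."""
    return f"{prefix}_{uuid.uuid4().hex[:8]}"
```

A test would then call `register_skill_via_api(api_client, unique_skill_name("reg_skill"), ...)` and reuse the returned name in later assertions.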
# ═══════════════════════════════════════════════════════════════════════════
# 4. Task Submission
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestTaskAPI:
    def test_submit_task_sync(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "sync_task_skill", keywords=["sync_task"])
        resp = api_client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "test sync task"},
                "skill_name": "sync_task_skill",
            },
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "output" in data or "data" in data or "skill_name" in data

    def test_submit_task_with_agent_name(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "agent_task_skill", keywords=["agent_task"])
        api_client.post("/api/v1/agents", json={"skill_name": "agent_task_skill"})
        resp = api_client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "test agent task"},
                "agent_name": "agent_task_skill",
            },
        )
        assert resp.status_code == 200

    def test_submit_task_auto_route(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "auto_route_skill", keywords=["auto_route"])
        resp = api_client.post(
            "/api/v1/tasks",
            json={"input_data": {"query": "Please auto_route this for me"}},
        )
        assert resp.status_code == 200

    def test_list_tasks(self, api_client: httpx.Client):
        resp = api_client.get("/api/v1/tasks")
        assert resp.status_code == 200

    def test_submit_task_missing_data(self, api_client: httpx.Client):
        resp = api_client.post("/api/v1/tasks", json={})
        # Should return 400 or 422
        assert resp.status_code in (400, 422)


# ═══════════════════════════════════════════════════════════════════════════
# 5. Chat Session Lifecycle
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestChatSessionAPI:
    def test_create_session(self, api_client: httpx.Client):
        session_id = create_session_via_api(api_client)
        assert session_id is not None
        assert len(session_id) > 0

    def test_list_sessions(self, api_client: httpx.Client):
        create_session_via_api(api_client)
        resp = api_client.get("/api/v1/chat/sessions")
        assert resp.status_code == 200
        sessions = resp.json()
        assert isinstance(sessions, list)
        assert len(sessions) >= 1

    def test_get_session(self, api_client: httpx.Client):
        session_id = create_session_via_api(api_client)
        resp = api_client.get(f"/api/v1/chat/sessions/{session_id}")
        assert resp.status_code == 200

    def test_session_messages(self, api_client: httpx.Client):
        session_id = create_session_via_api(api_client)
        # Send a message
        resp = api_client.post(
            f"/api/v1/chat/sessions/{session_id}/messages",
            json={"content": "Hello from e2e test"},
        )
        assert resp.status_code == 200
        # Get messages
        resp = api_client.get(f"/api/v1/chat/sessions/{session_id}/messages")
        assert resp.status_code == 200

    def test_close_session(self, api_client: httpx.Client):
        session_id = create_session_via_api(api_client)
        resp = api_client.delete(f"/api/v1/chat/sessions/{session_id}")
        assert resp.status_code == 200


# ═══════════════════════════════════════════════════════════════════════════
# 6. LLM Usage Tracking
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestLLMUsageAPI:
    def test_llm_usage_endpoint(self, api_client: httpx.Client):
        resp = api_client.get("/api/v1/llm/usage")
        assert resp.status_code == 200

    def test_llm_usage_after_task(self, api_client: httpx.Client):
        register_skill_via_api(api_client, "usage_track_skill", keywords=["usage_track"])
        api_client.post(
            "/api/v1/tasks",
            json={
                "input_data": {"query": "test usage tracking"},
                "skill_name": "usage_track_skill",
            },
        )
        resp = api_client.get("/api/v1/llm/usage")
        assert resp.status_code == 200


# ═══════════════════════════════════════════════════════════════════════════
# 7. Error Handling & Edge Cases
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestAPIErrorHandling:
    def test_404_for_unknown_route(self, api_client: httpx.Client):
        resp = api_client.get("/api/v1/nonexistent_route")
        assert resp.status_code == 404

    def test_invalid_json_body(self, api_client: httpx.Client):
        resp = api_client.post(
            "/api/v1/tasks",
            content=b"not json",
            headers={"Content-Type": "application/json"},
        )
        assert resp.status_code in (400, 422)

    def test_missing_api_key(self, e2e_server: str):
        """Requests without API key should be rejected (if auth enabled)."""
        # Context manager closes the client's connection pool when the block exits
        with httpx.Client(base_url=e2e_server, timeout=10) as client:
            resp = client.get("/api/v1/agents")
        # Should be 401/403 or still 200 if auth is not enforced on this endpoint
        assert resp.status_code in (200, 401, 403)

    def test_invalid_api_key(self, e2e_server: str):
        # Context manager closes the client's connection pool when the block exits
        with httpx.Client(
            base_url=e2e_server,
            headers={"X-API-Key": "invalid_key"},
            timeout=10,
        ) as client:
            resp = client.get("/api/v1/agents")
        assert resp.status_code in (200, 401, 403)
@@ -0,0 +1,353 @@
"""E2E Basic Function Tests — CLI commands.

Verifies that all CLI commands execute correctly as a real user would invoke them.
Uses subprocess (OpenCLI pattern) to simulate actual CLI operations.

Test categories:
1. Utility commands: version, doctor, help
2. Init & config: agentkit init
3. Pair: API key generation
4. Skill management: list, load, info
5. Task management: submit, status, list, cancel
6. Server: serve startup
"""

import json
import os

import pytest

from tests.e2e.conftest import CLIRunner, E2E_BASE_URL


# ═══════════════════════════════════════════════════════════════════════════
# 1. Utility Commands
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestCLIVersion:
    """agentkit version — basic sanity check."""

    def test_version_returns_zero_exit_code(self, cli_runner: CLIRunner):
        result = cli_runner.run(["version"])
        assert result.returncode == 0, f"stdout: {result.stdout}\nstderr: {result.stderr}"

    def test_version_outputs_version_string(self, cli_runner: CLIRunner):
        result = cli_runner.run(["version"])
        assert "0.1.0" in result.stdout or "fischer-agentkit" in result.stdout.lower()

    def test_version_help(self, cli_runner: CLIRunner):
        result = cli_runner.run(["version", "--help"])
        assert result.returncode == 0


@pytest.mark.e2e_basic
class TestCLIDoctor:
    """agentkit doctor — server health check."""

    def test_doctor_server_not_running(self, cli_runner: CLIRunner):
        """Doctor should report error when no server is running."""
        result = cli_runner.run(["doctor"])
        # Should indicate server not reachable
        output = (result.stdout + result.stderr).lower()
        assert (
            result.returncode != 0
            or "not running" in output
            or "error" in output
            or "connection" in output
        )

    def test_doctor_with_running_server(self, cli_runner_session: CLIRunner):
        """Doctor should report healthy when E2E server is running."""
        result = cli_runner_session.run(["doctor", "--port", "18765"])
        output = (result.stdout + result.stderr).lower()
        # Should show some health info (ok, healthy, or at least not connection refused)
        assert "connection refused" not in output or result.returncode == 0


@pytest.mark.e2e_basic
class TestCLIHelp:
    """agentkit --help — command discovery."""

    def test_help_shows_all_subcommands(self, cli_runner: CLIRunner):
        result = cli_runner.run(["--help"])
        assert result.returncode == 0
        for cmd in [
            "serve",
            "gui",
            "chat",
            "version",
            "doctor",
            "init",
            "task",
            "skill",
            "usage",
            "pair",
        ]:
            assert cmd in result.stdout, f"Missing subcommand '{cmd}' in help output"

    def test_task_help(self, cli_runner: CLIRunner):
        result = cli_runner.run(["task", "--help"])
        assert result.returncode == 0
        for sub in ["submit", "status", "list", "cancel"]:
            assert sub in result.stdout

    def test_skill_help(self, cli_runner: CLIRunner):
        result = cli_runner.run(["skill", "--help"])
        assert result.returncode == 0
        for sub in ["list", "load", "info"]:
            assert sub in result.stdout


# ═══════════════════════════════════════════════════════════════════════════
# 2. Init & Config
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestCLIInit:
    """agentkit init — project initialization."""

    def test_init_non_interactive(self, cli_runner: CLIRunner, tmp_path):
        output_dir = str(tmp_path / "init_output")
        os.makedirs(output_dir, exist_ok=True)
        result = cli_runner.run(["init", "--non-interactive", "--output-dir", output_dir])
        assert result.returncode == 0, f"stderr: {result.stderr}"
        assert os.path.exists(os.path.join(output_dir, "agentkit.yaml"))
        assert os.path.exists(os.path.join(output_dir, ".env.example"))

    def test_init_generates_valid_yaml(self, cli_runner: CLIRunner, tmp_path):
        import yaml

        output_dir = str(tmp_path / "init_yaml")
        os.makedirs(output_dir, exist_ok=True)
        cli_runner.run(["init", "--non-interactive", "--output-dir", output_dir])
        with open(os.path.join(output_dir, "agentkit.yaml")) as f:
            config = yaml.safe_load(f)
        assert "server" in config
        assert "llm" in config

    def test_init_no_overwrite_without_force(self, cli_runner: CLIRunner, tmp_path):
        output_dir = str(tmp_path / "init_no_overwrite")
        os.makedirs(output_dir, exist_ok=True)
        # Create existing file
        with open(os.path.join(output_dir, "agentkit.yaml"), "w") as f:
            f.write("existing_content")
        cli_runner.run(["init", "--non-interactive", "--output-dir", output_dir])
        with open(os.path.join(output_dir, "agentkit.yaml")) as f:
            content = f.read()
        # Should not overwrite
        assert content == "existing_content"

    def test_init_force_overwrites(self, cli_runner: CLIRunner, tmp_path):
        output_dir = str(tmp_path / "init_force")
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "agentkit.yaml"), "w") as f:
            f.write("old")
        result = cli_runner.run(
            ["init", "--non-interactive", "--force", "--output-dir", output_dir]
        )
        assert result.returncode == 0
        with open(os.path.join(output_dir, "agentkit.yaml")) as f:
            content = f.read()
        assert "server" in content


# ═══════════════════════════════════════════════════════════════════════════
# 3. Pair (API Key Generation)
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestCLIPair:
    """agentkit pair — external system API key management."""

    def test_pair_generates_api_key(self, cli_runner: CLIRunner, tmp_path):
        config_dir = str(tmp_path / "pair_config")
        os.makedirs(config_dir, exist_ok=True)
        result = cli_runner.run(["pair", "--name", "e2e-test-client", "--config-dir", config_dir])
        assert result.returncode == 0, f"stderr: {result.stderr}"
        assert "ak_live_" in result.stdout

    def test_pair_saves_client_config(self, cli_runner: CLIRunner, tmp_path):
        import yaml

        config_dir = str(tmp_path / "pair_save")
        os.makedirs(config_dir, exist_ok=True)
        cli_runner.run(["pair", "--name", "e2e-client", "--config-dir", config_dir])
        clients_path = os.path.join(config_dir, "clients.yaml")
        assert os.path.exists(clients_path)
        with open(clients_path) as f:
            clients = yaml.safe_load(f)
        assert "e2e-client" in clients
        assert clients["e2e-client"]["api_key"].startswith("ak_live_")

    def test_pair_rejects_duplicate_name(self, cli_runner: CLIRunner, tmp_path):
        config_dir = str(tmp_path / "pair_dup")
        os.makedirs(config_dir, exist_ok=True)
        cli_runner.run(["pair", "--name", "dup-client", "--config-dir", config_dir])
        result = cli_runner.run(["pair", "--name", "dup-client", "--config-dir", config_dir])
        output = (result.stdout + result.stderr).lower()
        assert result.returncode != 0 or "already" in output or "exists" in output

    def test_pair_list(self, cli_runner: CLIRunner, tmp_path):
        config_dir = str(tmp_path / "pair_list")
        os.makedirs(config_dir, exist_ok=True)
        cli_runner.run(["pair", "--name", "client-a", "--config-dir", config_dir])
        cli_runner.run(["pair", "--name", "client-b", "--config-dir", config_dir])
        result = cli_runner.run(["pair", "--list", "--config-dir", config_dir])
        assert result.returncode == 0
        assert "client-a" in result.stdout
        assert "client-b" in result.stdout

    def test_pair_revoke(self, cli_runner: CLIRunner, tmp_path):
        import yaml

        config_dir = str(tmp_path / "pair_revoke")
        os.makedirs(config_dir, exist_ok=True)
        cli_runner.run(["pair", "--name", "revoke-me", "--config-dir", config_dir])
        result = cli_runner.run(["pair", "--revoke", "revoke-me", "--config-dir", config_dir])
        assert result.returncode == 0
        with open(os.path.join(config_dir, "clients.yaml")) as f:
            clients = yaml.safe_load(f)
        assert "revoke-me" not in clients


# ═══════════════════════════════════════════════════════════════════════════
# 4. Skill Management (CLI → Server)
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestCLISkill:
    """agentkit skill — skill management via CLI against running server."""

    def test_skill_list_via_server(self, cli_runner_session: CLIRunner):
        result = cli_runner_session.run_server_command(["skill", "list"], E2E_BASE_URL)
        assert result.returncode == 0, f"stderr: {result.stderr}"

    def test_skill_load_yaml(self, cli_runner: CLIRunner, tmp_path):
        import yaml

        skill_file = tmp_path / "test_skill.yaml"
        skill_file.write_text(
            yaml.dump(
                {
                    "name": "e2e_test_skill",
                    "description": "E2E test skill",
                    "agent_type": "assistant",
                    "task_mode": "llm_generate",
                    "prompt": {"system": "You are a test assistant"},
                }
            )
        )
        result = cli_runner.run(["skill", "load", str(skill_file)])
        # Should load successfully or report loaded
        output = (result.stdout + result.stderr).lower()
        assert result.returncode == 0 or "loaded" in output or "e2e_test_skill" in output

    def test_skill_info_via_server(self, cli_runner_session: CLIRunner, api_client):
        # First register a skill via API
        from tests.e2e.conftest import register_skill_via_api

        register_skill_via_api(api_client, "cli_info_skill", keywords=["cli_info"])
        # Then query via CLI
        result = cli_runner_session.run_server_command(
            ["skill", "info", "cli_info_skill"], E2E_BASE_URL
        )
        assert result.returncode == 0
        assert "cli_info_skill" in result.stdout


# ═══════════════════════════════════════════════════════════════════════════
# 5. Task Management (CLI → Server)
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestCLITask:
    """agentkit task — task management via CLI against running server."""

    def test_task_submit_sync(self, cli_runner_session: CLIRunner, api_client):
        from tests.e2e.conftest import register_skill_via_api

        register_skill_via_api(api_client, "cli_task_skill", keywords=["cli_task"])
        result = cli_runner_session.run_server_command(
            [
                "task",
                "submit",
                "--skill",
                "cli_task_skill",
                "--input",
                json.dumps({"query": "test task submission"}),
            ],
            E2E_BASE_URL,
        )
        assert result.returncode == 0, f"stdout: {result.stdout}\nstderr: {result.stderr}"

    def test_task_submit_async(self, cli_runner_session: CLIRunner, api_client):
        from tests.e2e.conftest import register_skill_via_api

        register_skill_via_api(api_client, "cli_async_skill", keywords=["cli_async"])
        result = cli_runner_session.run_server_command(
            [
                "task",
                "submit",
                "--skill",
                "cli_async_skill",
                "--mode",
                "async",
                "--input",
                json.dumps({"query": "async task test"}),
            ],
            E2E_BASE_URL,
        )
        assert result.returncode == 0

    def test_task_list(self, cli_runner_session: CLIRunner):
        result = cli_runner_session.run_server_command(["task", "list"], E2E_BASE_URL)
        assert result.returncode == 0

    def test_task_submit_input_file(self, cli_runner_session: CLIRunner, api_client, tmp_path):
        from tests.e2e.conftest import register_skill_via_api

        register_skill_via_api(api_client, "cli_file_skill", keywords=["cli_file"])

        input_file = tmp_path / "task_input.json"
        input_file.write_text(json.dumps({"query": "file input test"}))

        result = cli_runner_session.run_server_command(
            [
                "task",
                "submit",
                "--skill",
                "cli_file_skill",
                "--input-file",
                str(input_file),
            ],
            E2E_BASE_URL,
        )
        assert result.returncode == 0


# ═══════════════════════════════════════════════════════════════════════════
# 6. Server Startup
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestCLIServe:
    """agentkit serve — server startup (basic check, not full lifecycle)."""

    def test_serve_help(self, cli_runner: CLIRunner):
        result = cli_runner.run(["serve", "--help"])
        assert result.returncode == 0
        assert "--host" in result.stdout
        assert "--port" in result.stdout

    def test_serve_invalid_port(self, cli_runner: CLIRunner):
        """Serve with an invalid port should fail gracefully."""
        result = cli_runner.run(["serve", "--port", "not_a_port"], timeout=5)
        # Should error out, not hang
        assert result.returncode != 0 or "error" in (result.stdout + result.stderr).lower()
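Review note: `test_serve_invalid_port` relies on `CLIRunner.run` forwarding `timeout=5` to `subprocess.run`. If the command hung instead of exiting, `subprocess.run` would raise `subprocess.TimeoutExpired`, so the test would error out rather than reach its assertion. The sketch below shows that behaviour in isolation using the current Python interpreter as the child process. `run_or_none` is a hypothetical wrapper for illustration, not part of `CLIRunner`.

```python
import subprocess
import sys
from typing import Optional


def run_or_none(cmd: list[str], timeout: float) -> Optional[subprocess.CompletedProcess]:
    """Run cmd; return None instead of raising when the timeout expires."""
    try:
        return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child and waits for it before re-raising
        return None


# A child that outlives the timeout yields None; a quick one returns normally.
hung = run_or_none([sys.executable, "-c", "import time; time.sleep(5)"], timeout=0.5)
quick = run_or_none([sys.executable, "-c", "print('ok')"], timeout=30)
```

A test that wants an explicit "did not hang" failure message could assert on the `None` case instead of letting the exception propagate.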
@@ -0,0 +1,170 @@
"""E2E Basic Function Tests — WebSocket chat protocol.

Verifies the WebSocket chat protocol works correctly:
1. Connection lifecycle (connect → connected → ping/pong → disconnect)
2. Message exchange (user message → token stream → final_answer)
3. Confirmation flow (confirmation_request → confirmation_reply → confirmation_result)
4. AskHuman flow (ask_human → reply → continue)
5. Cancel flow (cancel → error/cancelled)
6. Expert team events (team_formed → expert_step → team_synthesis → team_dissolved)
"""

import asyncio
import json

import pytest

from tests.e2e.conftest import WSChatHelper, create_session_via_api, register_skill_via_api


# ═══════════════════════════════════════════════════════════════════════════
# 1. Connection Lifecycle
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestWSConnection:
    @pytest.mark.asyncio
    async def test_connect_receives_connected_event(self, ws_helper: WSChatHelper, api_client):
        session_id = create_session_via_api(api_client)
        messages = await ws_helper.connect_and_chat(session_id, [])
        assert len(messages) >= 1
        assert messages[0].get("type") == "connected"

    @pytest.mark.asyncio
    async def test_ping_pong(self, ws_helper: WSChatHelper, api_client):
        """Ping should receive pong response."""
        try:
            import websockets
        except ImportError:
            pytest.skip("websockets not installed")

        session_id = create_session_via_api(api_client)
        uri = f"{ws_helper.base_ws_url}/api/v1/chat/ws/{session_id}?api_key={ws_helper.api_key}"

        received: list[dict] = []
        async with websockets.connect(uri) as ws:
            # Wait for connected
            msg = await asyncio.wait_for(ws.recv(), timeout=10)
            received.append(json.loads(msg))

            # Send ping
            await ws.send(json.dumps({"type": "ping"}))
            raw = await asyncio.wait_for(ws.recv(), timeout=10)
            resp = json.loads(raw)
            assert resp.get("type") == "pong"

    @pytest.mark.asyncio
    async def test_invalid_session_id(self, ws_helper: WSChatHelper):
        """Connecting with invalid session ID should fail."""
        try:
            import websockets
        except ImportError:
            pytest.skip("websockets not installed")

        uri = f"{ws_helper.base_ws_url}/api/v1/chat/ws/nonexistent-session?api_key={ws_helper.api_key}"
        with pytest.raises(Exception):
            async with websockets.connect(uri) as ws:
                await asyncio.wait_for(ws.recv(), timeout=5)


# ═══════════════════════════════════════════════════════════════════════════
# 2. Message Exchange
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestWSMessageExchange:
    @pytest.mark.asyncio
    async def test_send_message_get_response(self, ws_helper: WSChatHelper, api_client):
        session_id = create_session_via_api(api_client)
        messages = await ws_helper.connect_and_chat(
            session_id,
            [{"type": "message", "content": "Hello, this is an e2e test"}],
        )
        # Should receive at least: connected + some response (token/final_answer/error)
        assert len(messages) >= 2
        # At least one response-type event should be present
        response_types = [m.get("type") for m in messages]
        assert any(t in response_types for t in ("final_answer", "error", "token", "thinking"))

    @pytest.mark.asyncio
    async def test_message_types_are_valid(self, ws_helper: WSChatHelper, api_client):
        """All server-sent messages should have a valid 'type' field."""
        session_id = create_session_via_api(api_client)
        messages = await ws_helper.connect_and_chat(
            session_id,
            [{"type": "message", "content": "Test valid message types"}],
        )
        valid_types = {
            "connected",
            "token",
            "thinking",
            "step",
            "final_answer",
            "skill_match",
            "confirmation_request",
            "confirmation_result",
            "ask_human",
            "error",
            "pong",
            "team_formed",
            "expert_step",
            "expert_result",
            "plan_update",
            "team_synthesis",
            "team_dissolved",
        }
        for msg in messages:
            if isinstance(msg, dict) and "type" in msg:
                assert msg["type"] in valid_types, f"Invalid message type: {msg['type']}"


# ═══════════════════════════════════════════════════════════════════════════
# 3. Cancel Flow
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestWSCancel:
    @pytest.mark.asyncio
    async def test_cancel_message_accepted(self, ws_helper: WSChatHelper, api_client):
        """Sending cancel should be accepted by the server."""
        try:
            import websockets
        except ImportError:
            pytest.skip("websockets not installed")

        session_id = create_session_via_api(api_client)
        uri = f"{ws_helper.base_ws_url}/api/v1/chat/ws/{session_id}?api_key={ws_helper.api_key}"

        async with websockets.connect(uri) as ws:
            # Wait for connected
            await asyncio.wait_for(ws.recv(), timeout=10)
            # Send a message first
            await ws.send(json.dumps({"type": "message", "content": "Start a task"}))
            # Immediately send cancel
            await ws.send(json.dumps({"type": "cancel"}))
            # Server should handle gracefully (no crash)


# ═══════════════════════════════════════════════════════════════════════════
# 4. Skill Match Event
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_basic
class TestWSSkillMatch:
    @pytest.mark.asyncio
    async def test_skill_match_notification(self, ws_helper: WSChatHelper, api_client):
        """When a skill is matched, server should send skill_match event."""
        register_skill_via_api(api_client, "ws_skill", keywords=["ws_skill_match"])
        session_id = create_session_via_api(api_client)
        messages = await ws_helper.connect_and_chat(
            session_id,
            [{"type": "message", "content": "Please use ws_skill_match for this"}],
        )
        # Check if skill_match event was sent (may or may not happen depending on routing)
        _ = [m.get("type") for m in messages]  # noqa: F841
        # At minimum, we should get a response (skill_match or direct answer)
        assert len(messages) >= 2
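
The message-type check in `test_message_types_are_valid` could be factored into a small reusable validator. The following is a minimal sketch and is not part of the commit: `validate_event_stream` and `VALID_EVENT_TYPES` are hypothetical names, and the allowed set is copied from the test above rather than from any server-side definition.

```python
"""Hypothetical helper (not in the commit): validate a recorded WS event stream."""

# Allowed server event types, copied from test_message_types_are_valid.
VALID_EVENT_TYPES = frozenset({
    "connected", "token", "thinking", "step", "final_answer", "skill_match",
    "confirmation_request", "confirmation_result", "ask_human", "error", "pong",
    "team_formed", "expert_step", "expert_result", "plan_update",
    "team_synthesis", "team_dissolved",
})


def validate_event_stream(messages: list[dict]) -> list[str]:
    """Return human-readable problems; an empty list means the stream looks well-formed."""
    problems: list[str] = []
    # The connection tests expect the server to open with a "connected" event.
    if not messages or messages[0].get("type") != "connected":
        problems.append("first event is not 'connected'")
    # Every event must carry a type from the allowed set.
    for index, msg in enumerate(messages):
        event_type = msg.get("type")
        if event_type not in VALID_EVENT_TYPES:
            problems.append(f"event {index} has invalid type: {event_type!r}")
    return problems
```

Returning a list of problems instead of asserting lets a test report every malformed event in one failure message.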
@@ -0,0 +1,305 @@
"""E2E capability tests for AlignmentGuard and CascadeDetector.

Tests constraint injection, output checking, and cascade alert behavior
using the actual AlignmentGuard and CascadeDetector implementations.
"""

from __future__ import annotations

from datetime import datetime, timezone

import pytest

from agentkit.quality.alignment import (
    AlignmentCheckResult,
    AlignmentConfig,
    AlignmentGuard,
    ConstraintInjector,
)
from agentkit.quality.cascade_detector import CascadeDetector

from tests.e2e.benchmark_dataset import BenchmarkCase
from tests.e2e.capability_metrics import CapabilityObservation, MetricsCollector


# ---------------------------------------------------------------------------
# Alignment benchmark cases
# ---------------------------------------------------------------------------


ALIGNMENT_BENCHMARKS: list[BenchmarkCase] = [
    # Negative constraints: "不要X" → output should NOT contain X
    BenchmarkCase(
        id="align-neg-001",
        input="请分析市场趋势,不要提及价格信息",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="high",
        category="alignment",
        subcategory="negative_constraint",
    ),
    BenchmarkCase(
        id="align-neg-002",
        input="总结这篇文章,禁止包含个人观点",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="medium",
        category="alignment",
        subcategory="negative_constraint",
    ),
    # Positive constraints: "必须X" → output SHOULD contain X
    BenchmarkCase(
        id="align-pos-001",
        input="分析竞争对手,必须包含摘要部分",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="high",
        category="alignment",
        subcategory="positive_constraint",
    ),
    BenchmarkCase(
        id="align-pos-002",
        input="审查代码,需要提供改进建议",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="medium",
        category="alignment",
        subcategory="positive_constraint",
    ),
    # Cascade alert: repeated interactions should trigger alert
    BenchmarkCase(
        id="align-cascade-001",
        input="重复执行相似查询触发级联告警",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="medium",
        category="alignment",
        subcategory="cascade_alert",
    ),
    # No constraints: should pass cleanly
    BenchmarkCase(
        id="align-none-001",
        input="帮我分析一下用户数据",
        expected_skill=None,
        expected_execution_mode="react",
        expected_complexity="medium",
        category="alignment",
        subcategory="no_constraint",
    ),
]


# ---------------------------------------------------------------------------
# Tests: ConstraintInjector
# ---------------------------------------------------------------------------


class TestConstraintInjector:
    def test_inject_constraints(self) -> None:
        config = AlignmentConfig(constraints=["不要提及价格", "必须包含摘要"])
        injector = ConstraintInjector(config)
        input_data = {"query": "分析市场趋势"}
        result = injector.inject(input_data)
        assert "alignment_constraints" in result
        assert result["alignment_constraints"] == ["不要提及价格", "必须包含摘要"]
        # Original data should not be modified
        assert "alignment_constraints" not in input_data


# ---------------------------------------------------------------------------
# Tests: AlignmentGuard rule-based checking
# ---------------------------------------------------------------------------


class TestAlignmentGuardRuleCheck:
    @pytest.fixture
    def guard(self) -> AlignmentGuard:
        config = AlignmentConfig(
            constraints=["不要提及价格信息", "必须摘要"],
            audit_enabled=False,
        )
        return AlignmentGuard(config)

    @pytest.mark.asyncio
    async def test_negative_constraint_pass(self, guard: AlignmentGuard) -> None:
        """Output without forbidden content should pass."""
        output = {"content": "市场趋势分析:整体呈上升趋势。摘要:市场表现良好。"}
        result = await guard.check_output(output)
        assert isinstance(result, AlignmentCheckResult)
        # "价格信息" not in output → should pass
        assert result.passed is True

    @pytest.mark.asyncio
    async def test_negative_constraint_violation(self, guard: AlignmentGuard) -> None:
        """Output containing forbidden content should fail."""
        output = {"content": "当前提及价格信息显示市场上涨。摘要:市场持续走高。"}
        result = await guard.check_output(output)
        assert result.passed is False
        assert len(result.violations) > 0

    @pytest.mark.asyncio
    async def test_positive_constraint_pass(self, guard: AlignmentGuard) -> None:
        """Output containing required content should pass."""
        output = {"content": "分析结果如下。摘要:市场趋势向好。"}
        result = await guard.check_output(output)
        assert result.passed is True

    @pytest.mark.asyncio
    async def test_positive_constraint_violation(self, guard: AlignmentGuard) -> None:
        """Output missing required content should fail."""
        output = {"content": "分析结果如下。市场趋势向好。"}
        result = await guard.check_output(output)
        assert result.passed is False

    @pytest.mark.asyncio
    async def test_no_constraints(self) -> None:
        """Guard with no constraints should always pass."""
        config = AlignmentConfig(constraints=[], audit_enabled=False)
        guard = AlignmentGuard(config)
        output = {"content": "任意内容"}
        result = await guard.check_output(output)
        assert result.passed is True

    @pytest.mark.asyncio
    async def test_negation_context_not_violation(self) -> None:
        """Mentioning forbidden content in negative context should not be a violation."""
        config = AlignmentConfig(
            constraints=["不要提及价格"],
            audit_enabled=False,
        )
        guard = AlignmentGuard(config)
        output = {"content": "我们不会提及价格信息,请放心。摘要:市场分析完成。"}
        result = await guard.check_output(output)
        # "价格" appears but in negative context ("不会提及价格")
        assert result.passed is True


# ---------------------------------------------------------------------------
# Tests: CascadeDetector
# ---------------------------------------------------------------------------


class TestCascadeDetector:
    def test_no_alert_below_threshold(self) -> None:
        detector = CascadeDetector(max_interactions=5)
        for _ in range(5):
            alert = detector.check_interaction("session-1")
            assert alert is None

    def test_alert_above_interaction_threshold(self) -> None:
        detector = CascadeDetector(max_interactions=5)
        for _ in range(5):
            detector.check_interaction("session-2")
        # 6th interaction should trigger alert
        alert = detector.check_interaction("session-2")
        assert alert is not None
        assert alert.alert_type == "interaction_limit"
        assert alert.current_value == 6

    def test_alert_above_depth_threshold(self) -> None:
        detector = CascadeDetector(max_depth=3)
        alert = detector.check_depth("session-3", 4)
        assert alert is not None
        assert alert.alert_type == "loop_depth"
        assert alert.current_value == 4

    def test_no_alert_below_depth_threshold(self) -> None:
        detector = CascadeDetector(max_depth=3)
        alert = detector.check_depth("session-4", 3)
        assert alert is None

    def test_reset_clears_state(self) -> None:
        detector = CascadeDetector(max_interactions=3)
        for _ in range(3):
            detector.check_interaction("session-5")
        detector.reset("session-5")
        # After reset, count should be back to 0
        alert = detector.check_interaction("session-5")
        assert alert is None  # count is now 1, below threshold


# ---------------------------------------------------------------------------
# Tests: AlignmentGuard cascade integration
# ---------------------------------------------------------------------------


class TestAlignmentGuardCascade:
    def test_record_interaction_returns_alert(self) -> None:
        config = AlignmentConfig(cascade_max_interactions=3)
        guard = AlignmentGuard(config)
        for _ in range(3):
            guard.record_interaction("session-10")
        alert = guard.record_interaction("session-10")
        assert alert is not None
        assert alert.alert_type == "interaction_limit"

    def test_record_loop_depth_returns_alert(self) -> None:
        config = AlignmentConfig(cascade_max_depth=2)
        guard = AlignmentGuard(config)
        alert = guard.record_loop_depth("session-11", 3)
        assert alert is not None
        assert alert.alert_type == "loop_depth"

    def test_reset_session(self) -> None:
        config = AlignmentConfig(cascade_max_interactions=2)
        guard = AlignmentGuard(config)
        guard.record_interaction("session-12")
        guard.record_interaction("session-12")
        guard.reset_session("session-12")
        assert guard.get_interaction_count("session-12") == 0


# ---------------------------------------------------------------------------
# Tests: Metrics collection for alignment
# ---------------------------------------------------------------------------


class TestAlignmentMetricsCollection:
    def test_record_alignment_observation(self) -> None:
        collector = MetricsCollector()
        obs = CapabilityObservation(
            benchmark_id="align-neg-001",
            test_name="test_neg_constraint",
            timestamp=datetime.now(timezone.utc).isoformat(),
            input_query="请分析市场趋势,不要提及价格信息",
            category="alignment",
            subcategory="negative_constraint",
            alignment_violations=0,
            cascade_alert=False,
        )
        collector.record(obs)
        alignment_obs = collector.get_observations_by_category("alignment")
        assert len(alignment_obs) == 1
        assert alignment_obs[0].alignment_violations == 0

    def test_record_alignment_with_violations(self) -> None:
        collector = MetricsCollector()
        obs = CapabilityObservation(
            benchmark_id="align-neg-002",
            test_name="test_neg_constraint_violation",
            timestamp=datetime.now(timezone.utc).isoformat(),
            input_query="总结这篇文章,禁止包含个人观点",
            category="alignment",
            subcategory="negative_constraint",
            alignment_violations=1,
            cascade_alert=False,
        )
        collector.record(obs)
        alignment_obs = collector.get_observations_by_category("alignment")
        assert alignment_obs[0].alignment_violations == 1

    def test_record_cascade_alert(self) -> None:
        collector = MetricsCollector()
        obs = CapabilityObservation(
            benchmark_id="align-cascade-001",
            test_name="test_cascade_alert",
            timestamp=datetime.now(timezone.utc).isoformat(),
            input_query="重复执行相似查询触发级联告警",
            category="alignment",
            subcategory="cascade_alert",
            alignment_violations=0,
            cascade_alert=True,
        )
        collector.record(obs)
        alignment_obs = collector.get_observations_by_category("alignment")
        assert alignment_obs[0].cascade_alert is True
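
The `TestCascadeDetector` cases above pin down a strict "greater than" threshold: five interactions with `max_interactions=5` raise no alert, the sixth does, and a depth equal to `max_depth` is still allowed. A minimal sketch of a detector with those semantics follows. It is an illustration only, not the `agentkit.quality.cascade_detector` implementation; `SketchCascadeDetector`, `SketchAlert`, the `threshold` field, and the default limits are all assumptions.

```python
"""Hypothetical sketch (not agentkit's CascadeDetector): threshold semantics implied by the tests."""
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchAlert:
    alert_type: str  # "interaction_limit" or "loop_depth"
    session_id: str
    current_value: int
    threshold: int


class SketchCascadeDetector:
    def __init__(self, max_interactions: int = 50, max_depth: int = 10) -> None:
        # Default limits are placeholders chosen for this sketch.
        self.max_interactions = max_interactions
        self.max_depth = max_depth
        self._counts: dict[str, int] = defaultdict(int)

    def check_interaction(self, session_id: str) -> Optional[SketchAlert]:
        """Count one interaction; alert only once the count exceeds the limit."""
        self._counts[session_id] += 1
        count = self._counts[session_id]
        if count > self.max_interactions:
            return SketchAlert("interaction_limit", session_id, count, self.max_interactions)
        return None

    def check_depth(self, session_id: str, depth: int) -> Optional[SketchAlert]:
        """Alert when the reported loop depth exceeds the limit (equal is allowed)."""
        if depth > self.max_depth:
            return SketchAlert("loop_depth", session_id, depth, self.max_depth)
        return None

    def reset(self, session_id: str) -> None:
        """Forget the interaction count for one session."""
        self._counts.pop(session_id, None)
```

Keeping counts per session means a reset for one session leaves the others untouched, which is the behavior `test_reset_clears_state` relies on.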
@@ -0,0 +1,324 @@
"""E2E Agent Capability Tests: ReAct Reasoning & Execution with Metrics.

Tests the intelligence of agent execution AND collects data for:
- Execution mode selection accuracy
- Quality gate effectiveness
- Task success rate by mode
- Output standardization consistency
- Overfitting detection via paraphrased inputs
"""

import pytest
import httpx

from tests.e2e.benchmark_dataset import EXECUTION_BENCHMARKS, BenchmarkCase
from tests.e2e.capability_metrics import MetricsCollector
from tests.e2e.conftest import register_skill_via_api


# ═══════════════════════════════════════════════════════════════════════════
# Helper: run execution benchmark and record metrics
# ═══════════════════════════════════════════════════════════════════════════


def _run_exec_benchmark(
    benchmark: BenchmarkCase,
    api_client: httpx.Client,
    collector: MetricsCollector,
    test_name: str,
    is_paraphrase: bool = False,
    input_override: str | None = None,
) -> dict:
    """Execute an execution benchmark and record metrics."""
    query = input_override or benchmark.input
    collector.start_timer(benchmark.id)

    payload: dict = {"input_data": {"query": query}}
    if benchmark.expected_skill is not None:
        payload["skill_name"] = benchmark.expected_skill

    resp = api_client.post("/api/v1/tasks", json=payload)

    actual_skill = None
    actual_exec_mode = None
    actual_keys = []
    task_succeeded = resp.status_code == 200
    error_msg = None

    if task_succeeded:
        data = resp.json()
        actual_skill = data.get("skill_name")
        actual_exec_mode = data.get("execution_mode")
        actual_keys = list(data.keys())
    elif resp.status_code >= 400:
        try:
            error_msg = resp.json().get("detail", resp.text[:200])
        except Exception:
            error_msg = resp.text[:200]

    collector.record_benchmark_result(
        benchmark,
        test_name=test_name,
        actual_skill=actual_skill,
        actual_execution_mode=actual_exec_mode,
        actual_status_code=resp.status_code,
        actual_response_keys=actual_keys,
        task_succeeded=task_succeeded,
        is_paraphrase=is_paraphrase,
        error_message=error_msg,
    )

    return {
        "status_code": resp.status_code,
        "actual_skill": actual_skill,
        "actual_exec_mode": actual_exec_mode,
        "actual_keys": actual_keys,
        "task_succeeded": task_succeeded,
    }


# ═══════════════════════════════════════════════════════════════════════════
# Parameterized Execution Benchmark Tests
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_capability
class TestExecutionBenchmarks:
    """Run all execution benchmarks with metrics collection."""

    @pytest.mark.parametrize(
        "benchmark",
        EXECUTION_BENCHMARKS,
        ids=[b.id for b in EXECUTION_BENCHMARKS],
    )
    def test_execution_benchmark(
        self,
        benchmark: BenchmarkCase,
        api_client: httpx.Client,
        metrics_collector: MetricsCollector,
    ):
        """Run original execution benchmark and record metrics."""
        # Register the skill if expected
        if benchmark.expected_skill:
            # The original conditional returned "direct" when the mode was
            # "direct" and the mode itself otherwise, so it reduces to this.
            exec_mode = benchmark.expected_execution_mode
            register_skill_via_api(
                api_client,
                benchmark.expected_skill,
                keywords=[benchmark.expected_skill],
                execution_mode=exec_mode,
            )

        result = _run_exec_benchmark(
            benchmark,
            api_client,
            metrics_collector,
            test_name=f"exec_benchmark_{benchmark.id}",
        )
        assert result["status_code"] == 200, f"Benchmark {benchmark.id} failed"

    @pytest.mark.parametrize(
        "benchmark",
        [b for b in EXECUTION_BENCHMARKS if b.paraphrases],
        ids=[b.id for b in EXECUTION_BENCHMARKS if b.paraphrases],
    )
    def test_execution_paraphrase(
        self,
        benchmark: BenchmarkCase,
        api_client: httpx.Client,
        metrics_collector: MetricsCollector,
    ):
        """Run paraphrases for overfitting detection."""
        for i, paraphrase in enumerate(benchmark.paraphrases):
            _run_exec_benchmark(
                benchmark,
                api_client,
                metrics_collector,
                test_name=f"exec_paraphrase_{benchmark.id}_{i}",
                is_paraphrase=True,
                input_override=paraphrase,
            )


# ═══════════════════════════════════════════════════════════════════════════
# ReAct Loop Intelligence
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_capability
class TestReActIntelligence:
    """Test that ReAct agents reason correctly through Think→Act→Observe."""

    def test_react_skill_executes_steps(
        self,
        api_client: httpx.Client,
        metrics_collector: MetricsCollector,
    ):
        """ReAct skill should execute multiple steps for complex tasks."""
        benchmark = BenchmarkCase(
            id="react-steps-001",
            input="Research and analyze the impact of AI on healthcare",
            expected_skill="react_reasoner",
            expected_execution_mode="react",
            expected_complexity="high",
            category="execution",
            subcategory="react_mode",
            paraphrases=["Investigate AI's effect on medical industry", "调研AI对医疗行业的影响"],
        )
        register_skill_via_api(
            api_client,
            "react_reasoner",
            keywords=["react_reason", "research", "analyze"],
            execution_mode="react",
        )

        result = _run_exec_benchmark(
            benchmark,
            api_client,
            metrics_collector,
            test_name="react_steps",
        )
        assert result["status_code"] == 200

        for i, para in enumerate(benchmark.paraphrases):
            _run_exec_benchmark(
                benchmark,
                api_client,
                metrics_collector,
                test_name=f"react_steps_para_{i}",
                is_paraphrase=True,
                input_override=para,
            )


# ═══════════════════════════════════════════════════════════════════════════
# Quality Gate Intelligence
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_capability
class TestQualityGateIntelligence:
    """Test that quality gate correctly validates and retries outputs."""

    def test_quality_gate_with_required_fields(
        self,
        api_client: httpx.Client,
        metrics_collector: MetricsCollector,
    ):
        """Quality gate should enforce required_fields in output."""
        benchmark = BenchmarkCase(
            id="quality-fields-001",
            input="Generate content with quality check",
            expected_skill="quality_skill",
            expected_complexity="medium",
            category="execution",
            subcategory="quality_gate",
        )
        register_skill_via_api(api_client, "quality_skill", keywords=["quality_test"])

        result = _run_exec_benchmark(
            benchmark,
            api_client,
            metrics_collector,
            test_name="quality_fields",
        )
        assert result["status_code"] in (200, 400, 422)

    def test_quality_gate_rejects_empty_output(
        self,
        api_client: httpx.Client,
        metrics_collector: MetricsCollector,
    ):
        """Quality gate should reject empty or minimal output."""
        benchmark = BenchmarkCase(
            id="quality-empty-001",
            input="",
            expected_skill="quality_empty",
            expected_complexity="low",
            category="execution",
            subcategory="quality_gate",
        )
        register_skill_via_api(api_client, "quality_empty", keywords=["quality_empty"])

        result = _run_exec_benchmark(
            benchmark,
            api_client,
            metrics_collector,
            test_name="quality_empty",
        )
        # Should handle gracefully
        assert result["status_code"] in (200, 400, 422)


# ═══════════════════════════════════════════════════════════════════════════
# Output Standardization Intelligence
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_capability
class TestOutputStandardization:
    """Test that agent outputs are properly standardized."""

    def test_output_has_required_structure(
        self,
        api_client: httpx.Client,
        metrics_collector: MetricsCollector,
    ):
        """Task results should have a consistent structure."""
        register_skill_via_api(api_client, "output_std_skill", keywords=["output_std"])

        benchmark = BenchmarkCase(
            id="output-std-001",
            input="Test output standardization",
            expected_skill="output_std_skill",
            expected_complexity="low",
            category="execution",
            subcategory="output_std",
        )

        result = _run_exec_benchmark(
            benchmark,
            api_client,
            metrics_collector,
            test_name="output_std",
        )
        assert result["status_code"] == 200
        assert result["task_succeeded"]

    def test_different_skills_produce_consistent_format(
        self,
        api_client: httpx.Client,
        metrics_collector: MetricsCollector,
    ):
        """Different skills should produce results in consistent format."""
        register_skill_via_api(api_client, "format_skill_a", keywords=["format_a"])
        register_skill_via_api(api_client, "format_skill_b", keywords=["format_b"])

        bench_a = BenchmarkCase(
            id="format-a-001",
            input="Test format A",
            expected_skill="format_skill_a",
            expected_complexity="low",
            category="execution",
            subcategory="output_std",
        )
        bench_b = BenchmarkCase(
            id="format-b-001",
            input="Test format B",
            expected_skill="format_skill_b",
            expected_complexity="low",
            category="execution",
            subcategory="output_std",
        )

        result_a = _run_exec_benchmark(bench_a, api_client, metrics_collector, test_name="format_a")
        result_b = _run_exec_benchmark(bench_b, api_client, metrics_collector, test_name="format_b")

        if result_a["task_succeeded"] and result_b["task_succeeded"]:
            # Both should have some common response keys. The original
            # "or len(keys_a) > 0" clause made this check pass whenever
            # skill A returned any key at all, so it is dropped here.
            keys_a = set(result_a["actual_keys"])
            keys_b = set(result_b["actual_keys"])
            assert keys_a & keys_b, "skills A and B share no response keys"
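
The paraphrase runs above feed the overfitting detection mentioned in the module docstring. One simple way to turn those observations into a number is to compare accuracy on the original inputs against accuracy on their paraphrases. The sketch below illustrates that idea under stated assumptions: `Outcome`, `accuracy`, and `overfitting_gap` are hypothetical names, and nothing here is claimed to match the formula `MetricsCollector` actually uses.

```python
"""Hypothetical sketch (not MetricsCollector's formula): original-vs-paraphrase accuracy gap."""
from dataclasses import dataclass


@dataclass
class Outcome:
    benchmark_id: str
    is_paraphrase: bool
    correct: bool


def accuracy(outcomes: list[Outcome]) -> float:
    """Fraction of correct outcomes; 0.0 for an empty list."""
    if not outcomes:
        return 0.0
    return sum(o.correct for o in outcomes) / len(outcomes)


def overfitting_gap(outcomes: list[Outcome]) -> float:
    """Accuracy on original inputs minus accuracy on paraphrases.

    A large positive gap suggests the router keys on the exact benchmark
    wording instead of the underlying intent. Returns 0.0 when either
    group is empty, because no comparison is possible.
    """
    originals = [o for o in outcomes if not o.is_paraphrase]
    paraphrases = [o for o in outcomes if o.is_paraphrase]
    if not originals or not paraphrases:
        return 0.0
    return accuracy(originals) - accuracy(paraphrases)
```

With only a handful of paraphrases per benchmark the gap is noisy, so it is better read as a trend across runs than as a pass/fail threshold.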
@@ -0,0 +1,342 @@
"""E2E Agent Capability Tests: Router Direct Backtest Layer (Real LLM).

Directly tests CostAwareRouter.route() using real LLM configuration
loaded from agentkit.yaml. Records full SkillRoutingResult for precise
root cause analysis:
- match_method (layer0/layer1/layer1.5/layer2)
- match_confidence
- complexity score
- execution_trace
"""

import asyncio
import os
from pathlib import Path

import pytest

from agentkit.chat.skill_routing import CostAwareRouter
from agentkit.router.intent import IntentRouter
from agentkit.server.app import _build_llm_gateway, _build_skill_registry
from agentkit.server.config import ServerConfig
from agentkit.skills.registry import SkillRegistry

from tests.e2e.benchmark_dataset import (
    ALL_BENCHMARKS,
    ROUTING_KEYWORD_BENCHMARKS,
    ROUTING_EDGE_BENCHMARKS,
    SEMANTIC_ROUTER_BENCHMARKS,
    BenchmarkCase,
)
from tests.e2e.capability_metrics import MetricsCollector


# ═══════════════════════════════════════════════════════════════════════════
# Real component initialization from agentkit.yaml
# ═══════════════════════════════════════════════════════════════════════════


def _find_config_path() -> str | None:
    """Find agentkit.yaml in standard search paths."""
    candidates = [
        os.environ.get("AGENTKIT_CONFIG", ""),
        str(Path.cwd() / "agentkit.yaml"),
        str(Path.home() / ".agentkit" / "agentkit.yaml"),
    ]
    for path in candidates:
        if path and Path(path).is_file():
            return path
    return None


def _build_real_components() -> tuple[CostAwareRouter, SkillRegistry, IntentRouter]:
    """Build real components from agentkit.yaml configuration.

    Returns (router, skill_registry, intent_router).
    Skips the test if no valid LLM provider is configured.
    """
    config_path = _find_config_path()
    if not config_path:
        pytest.skip("No agentkit.yaml found; cannot build real components")

    # Load .env if present
    env_path = Path(config_path).parent / ".env"
    if env_path.exists():
        try:
            from dotenv import load_dotenv

            load_dotenv(env_path)
        except ImportError:
            # python-dotenv not installed, manually parse .env
            with open(env_path) as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith("#") and "=" in line:
                        key, _, value = line.partition("=")
                        os.environ.setdefault(key.strip(), value.strip().strip("'\""))

    server_config = ServerConfig.from_yaml(config_path)

    # Check if any LLM provider has a valid API key
    if not server_config.has_llm_provider():
        # Try to inject DASHSCOPE_API_KEY from environment
        dashscope_key = os.environ.get("DASHSCOPE_API_KEY", "")
        if dashscope_key:
            # Inject into the first provider config that lacks an API key
            for name, pconf in server_config.llm_config.providers.items():
                if not pconf.api_key:
                    pconf.api_key = dashscope_key
                    # Set base_url for dashscope if missing
                    if not pconf.base_url:
                        pconf.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
                    break

    if not server_config.has_llm_provider():
        pytest.skip("No LLM provider with valid API key; skipping real LLM tests")

    # Build real LLM gateway
    llm_gateway = _build_llm_gateway(server_config)

    # Build real skill registry from configs/skills
    skill_registry = _build_skill_registry(server_config)

    # Build real intent router
    intent_router = IntentRouter(llm_gateway=llm_gateway)

    # Build real CostAwareRouter
    router_conf = server_config.router or {}
    router = CostAwareRouter(
        llm_gateway=llm_gateway,
        model="default",
        org_context=None,
        auction_enabled=router_conf.get("auction_enabled", False),
        classifier=router_conf.get("classifier", "heuristic"),
        merged_llm_classify=router_conf.get("merged_llm_classify", True),
    )

    return router, skill_registry, intent_router


# Cache components at module level to avoid rebuilding for every test
_cached_components: tuple[CostAwareRouter, SkillRegistry, IntentRouter] | None = None


def _get_components() -> tuple[CostAwareRouter, SkillRegistry, IntentRouter]:
    """Get or build real components (cached for session)."""
    global _cached_components
    if _cached_components is None:
        _cached_components = _build_real_components()
    return _cached_components


# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Helper: Run a single benchmark through the real router
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
async def _run_router_benchmark(
    benchmark: BenchmarkCase,
    collector: MetricsCollector,
    test_name: str,
    is_paraphrase: bool = False,
    input_override: str | None = None,
) -> dict:
    """Run a single benchmark through the real router."""
    router, skill_registry, intent_router = _get_components()
    query = input_override or benchmark.input

    collector.start_timer(benchmark.id)

    try:
        result = await router.route(
            content=query,
            skill_registry=skill_registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt=None,
        )

        actual_skill = result.skill_name
        actual_exec_mode = result.execution_mode.value if result.execution_mode else None
        actual_complexity = result.complexity
        actual_match_method = result.match_method
        actual_match_confidence = result.match_confidence
        task_succeeded = True
        error_msg = None
    except Exception as e:
        # Routing failed: record neutral defaults and keep a truncated error message
        actual_skill = None
        actual_exec_mode = None
        actual_complexity = 0.0
        actual_match_method = None
        actual_match_confidence = 0.0
        task_succeeded = False
        error_msg = str(e)[:200]

    # Map complexity score to level
    if actual_complexity < 0.3:
        actual_complexity_level = "low"
    elif actual_complexity < 0.7:
        actual_complexity_level = "medium"
    else:
        actual_complexity_level = "high"

    # Judge correctness
    skill_correct = None
    if benchmark.expected_skill is not None and actual_skill is not None:
        skill_correct = actual_skill == benchmark.expected_skill
    elif benchmark.expected_skill is None:
        # No skill expected: correct only if routing succeeded without picking one.
        # (The previous `actual_skill is None or task_succeeded` was always True.)
        skill_correct = task_succeeded and actual_skill is None

    execution_mode_correct = None
    if actual_exec_mode is not None and benchmark.expected_execution_mode:
        mode_map = {
            "direct": "DIRECT_CHAT",
            "react": "SKILL_REACT",
            "rewoo": "REWOO",
            "reflexion": "REFLEXION",
            "plan_exec": "PLAN_EXEC",
            "team_collab": "TEAM_COLLAB",
            "llm_generate": "SKILL_REACT",
            "tool_call": "SKILL_REACT",
            "custom": "SKILL_REACT",
        }
        expected_normalized = mode_map.get(
            benchmark.expected_execution_mode, benchmark.expected_execution_mode.upper()
        )
        execution_mode_correct = actual_exec_mode.upper() == expected_normalized

    # A failed route defaults to complexity 0.0 ("low"); do not count that as a correct estimate
    complexity_correct = (
        task_succeeded and actual_complexity_level == benchmark.expected_complexity
    )

    obs = collector.record_benchmark_result(
        benchmark,
        test_name=test_name,
        actual_skill=actual_skill,
        actual_execution_mode=actual_exec_mode,
        actual_status_code=200 if task_succeeded else 500,
        task_succeeded=task_succeeded,
        is_paraphrase=is_paraphrase,
        error_message=error_msg,
    )
    obs.complexity_correct = complexity_correct

    return {
        "skill_correct": skill_correct,
        "execution_mode_correct": execution_mode_correct,
        "complexity_correct": complexity_correct,
        "actual_skill": actual_skill,
        "actual_exec_mode": actual_exec_mode,
        "actual_complexity": actual_complexity,
        "actual_complexity_level": actual_complexity_level,
        "actual_match_method": actual_match_method,
        "actual_match_confidence": actual_match_confidence,
        "task_succeeded": task_succeeded,
    }
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Layer 0: Rule Matching Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
class TestRouterLayer0:
    """Test Layer 0 rule matching with real router."""

    @pytest.mark.parametrize(
        "benchmark",
        [
            b
            for b in ROUTING_EDGE_BENCHMARKS
            if b.subcategory in ("greeting", "identity", "explicit_prefix")
        ],
        ids=[
            b.id
            for b in ROUTING_EDGE_BENCHMARKS
            if b.subcategory in ("greeting", "identity", "explicit_prefix")
        ],
    )
    def test_layer0_rules(self, benchmark: BenchmarkCase, metrics_collector: MetricsCollector):
        """Layer 0 should correctly match greetings, identity, and @skill: prefix."""
        result = asyncio.run(
            _run_router_benchmark(benchmark, metrics_collector, f"layer0_{benchmark.id}")
        )
        if benchmark.subcategory == "greeting":
            assert result["actual_match_method"] in ("layer0", None) or result["task_succeeded"]
        if benchmark.subcategory == "explicit_prefix":
            assert result["actual_skill"] == benchmark.expected_skill or result["task_succeeded"]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Layer 1: Complexity Classification Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
class TestRouterLayer1:
    """Test Layer 1 complexity classification with real router."""

    @pytest.mark.parametrize(
        "benchmark",
        ROUTING_KEYWORD_BENCHMARKS,
        ids=[b.id for b in ROUTING_KEYWORD_BENCHMARKS],
    )
    def test_complexity_classification(
        self, benchmark: BenchmarkCase, metrics_collector: MetricsCollector
    ):
        """HeuristicClassifier should correctly estimate complexity."""
        asyncio.run(_run_router_benchmark(benchmark, metrics_collector, f"layer1_{benchmark.id}"))
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Semantic Router Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
class TestSemanticRouter:
    """Test semantic router matching with real router."""

    @pytest.mark.parametrize(
        "benchmark",
        SEMANTIC_ROUTER_BENCHMARKS,
        ids=[b.id for b in SEMANTIC_ROUTER_BENCHMARKS],
    )
    def test_semantic_match(self, benchmark: BenchmarkCase, metrics_collector: MetricsCollector):
        """SemanticRouter should match skill descriptions."""
        asyncio.run(_run_router_benchmark(benchmark, metrics_collector, f"semantic_{benchmark.id}"))
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Paraphrase Consistency Tests (Overfitting Detection)
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
class TestRouterParaphraseConsistency:
    """Test that paraphrased inputs route to the same skill as originals."""

    @pytest.mark.parametrize(
        "benchmark",
        [b for b in ALL_BENCHMARKS if b.paraphrases and b.expected_skill is not None][:10],
        ids=[b.id for b in ALL_BENCHMARKS if b.paraphrases and b.expected_skill is not None][:10],
    )
    def test_paraphrase_routes_same_skill(
        self, benchmark: BenchmarkCase, metrics_collector: MetricsCollector
    ):
        """Original and paraphrased inputs should route to the same skill."""
        # NOTE: nothing is asserted here; each run is only recorded in the
        # collector so consistency can be analysed afterwards.
        # Run original
        asyncio.run(
            _run_router_benchmark(benchmark, metrics_collector, f"para_orig_{benchmark.id}")
        )

        # Run paraphrases
        for i, para in enumerate(benchmark.paraphrases):
            asyncio.run(
                _run_router_benchmark(
                    benchmark,
                    metrics_collector,
                    f"para_{benchmark.id}_{i}",
                    is_paraphrase=True,
                    input_override=para,
                )
            )
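
The paraphrase runs above only record observations. One way to turn such records into a number is a per-benchmark consistency rate plus an accuracy gap between original and paraphrased inputs. The sketch below is an illustration only and does not reproduce the project's `MetricsCollector`: `paraphrase_consistency` and `overfitting_gap` are hypothetical helper names, and it assumes each run yields the routed skill name (or `None`) and a correctness flag.

```python
from __future__ import annotations

from collections.abc import Sequence


def paraphrase_consistency(
    original_skill: str | None,
    paraphrase_skills: Sequence[str | None],
) -> float:
    """Fraction of paraphrases routed to the same skill as the original input."""
    if not paraphrase_skills:
        return 1.0  # no paraphrases, so nothing can disagree
    same = sum(1 for skill in paraphrase_skills if skill == original_skill)
    return same / len(paraphrase_skills)


def overfitting_gap(
    original_correct: Sequence[bool],
    paraphrase_correct: Sequence[bool],
) -> float:
    """Accuracy on originals minus accuracy on paraphrases.

    A clearly positive gap hints that routing depends on the exact wording.
    """

    def accuracy(flags: Sequence[bool]) -> float:
        return sum(flags) / len(flags) if flags else 0.0

    return accuracy(original_correct) - accuracy(paraphrase_correct)
```

A gap near zero with a high consistency rate would suggest the router generalizes across phrasings; what threshold counts as "overfitting" is a project decision that this sketch does not make.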
@@ -0,0 +1,273 @@
"""E2E Agent Capability Tests: Intent Routing Intelligence with Metrics Collection.

Tests the intelligence of the CostAwareRouter (3-layer routing) AND collects
data for recall/precision/F1 analysis, overfitting detection, and weakness
identification.

Each test:
1. Runs the benchmark case (original input)
2. Runs all paraphrases of the same input (overfitting detection)
3. Records observations to MetricsCollector
4. Asserts minimum quality thresholds
"""

import pytest
import httpx

from tests.e2e.benchmark_dataset import (
    ROUTING_KEYWORD_BENCHMARKS,
    ROUTING_EDGE_BENCHMARKS,
    CONSISTENCY_BENCHMARKS,
    BenchmarkCase,
    get_skill_names_needed,
)
from tests.e2e.capability_metrics import MetricsCollector
from tests.e2e.conftest import register_skill_via_api
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Pre-registration of all skills needed by benchmarks
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
def register_benchmark_skills(api_client: httpx.Client):
    """Auto-register all skills needed by routing benchmarks."""
    for skill_name in get_skill_names_needed():
        register_skill_via_api(api_client, skill_name, keywords=[skill_name])
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Helper: run a single benchmark case and record metrics
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _run_benchmark_and_record(
    benchmark: BenchmarkCase,
    api_client: httpx.Client,
    collector: MetricsCollector,
    test_name: str,
    is_paraphrase: bool = False,
    input_override: str | None = None,
) -> dict:
    """Execute a benchmark case against the API and record metrics."""
    query = input_override or benchmark.input
    collector.start_timer(benchmark.id)

    payload: dict = {"input_data": {"query": query}}
    if benchmark.expected_skill is not None:
        payload["skill_name"] = benchmark.expected_skill

    resp = api_client.post("/api/v1/tasks", json=payload)

    actual_skill = None
    actual_exec_mode = None
    actual_keys = []
    task_succeeded = resp.status_code == 200
    error_msg = None

    if task_succeeded:
        data = resp.json()
        actual_skill = data.get("skill_name")
        actual_exec_mode = data.get("execution_mode")
        actual_keys = list(data.keys())
    elif resp.status_code >= 400:
        try:
            error_msg = resp.json().get("detail", resp.text[:200])
        except Exception:
            error_msg = resp.text[:200]

    collector.record_benchmark_result(
        benchmark,
        test_name=test_name,
        actual_skill=actual_skill,
        actual_execution_mode=actual_exec_mode,
        actual_status_code=resp.status_code,
        actual_response_keys=actual_keys,
        task_succeeded=task_succeeded,
        is_paraphrase=is_paraphrase,
        error_message=error_msg,
    )

    return {
        "status_code": resp.status_code,
        "actual_skill": actual_skill,
        "actual_exec_mode": actual_exec_mode,
        "task_succeeded": task_succeeded,
    }
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Parameterized Routing Benchmark Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
|
||||
class TestRoutingBenchmarks:
|
||||
"""Run all routing benchmarks with metrics collection."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"benchmark",
|
||||
ROUTING_KEYWORD_BENCHMARKS + ROUTING_EDGE_BENCHMARKS,
|
||||
ids=[b.id for b in ROUTING_KEYWORD_BENCHMARKS + ROUTING_EDGE_BENCHMARKS],
|
||||
)
|
||||
def test_routing_benchmark(
|
||||
self,
|
||||
benchmark: BenchmarkCase,
|
||||
api_client: httpx.Client,
|
||||
metrics_collector: MetricsCollector,
|
||||
):
|
||||
"""Run original benchmark input and record metrics."""
|
||||
result = _run_benchmark_and_record(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name=f"routing_benchmark_{benchmark.id}",
|
||||
)
|
||||
assert result["status_code"] == 200, f"Benchmark {benchmark.id} failed: {result}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"benchmark",
|
||||
[b for b in ROUTING_KEYWORD_BENCHMARKS + ROUTING_EDGE_BENCHMARKS if b.paraphrases],
|
||||
ids=[b.id for b in ROUTING_KEYWORD_BENCHMARKS + ROUTING_EDGE_BENCHMARKS if b.paraphrases],
|
||||
)
|
||||
def test_routing_paraphrase(
|
||||
self,
|
||||
benchmark: BenchmarkCase,
|
||||
api_client: httpx.Client,
|
||||
metrics_collector: MetricsCollector,
|
||||
):
|
||||
"""Run all paraphrases for overfitting detection."""
|
||||
for i, paraphrase in enumerate(benchmark.paraphrases):
|
||||
_run_benchmark_and_record(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name=f"routing_paraphrase_{benchmark.id}_{i}",
|
||||
is_paraphrase=True,
|
||||
input_override=paraphrase,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Routing Consistency (same input, multiple runs)
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
class TestRoutingConsistency:
    """Same input should produce same routing decision (deterministic backtest)."""

    def test_same_query_same_skill(
        self,
        api_client: httpx.Client,
        metrics_collector: MetricsCollector,
    ):
        """Submitting the same query multiple times should route to the same skill."""
        for benchmark in CONSISTENCY_BENCHMARKS:
            skills_seen: list[str | None] = []
            for run_idx in range(3):
                result = _run_benchmark_and_record(
                    benchmark,
                    api_client,
                    metrics_collector,
                    test_name=f"consistency_{benchmark.id}_run{run_idx}",
                )
                skills_seen.append(result["actual_skill"])

            # All runs should produce the same skill
            non_none_skills = [s for s in skills_seen if s is not None]
            if len(non_none_skills) >= 2:
                assert len(set(non_none_skills)) == 1, (
                    f"Inconsistent routing for {benchmark.id}: {skills_seen}"
                )
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Routing Disambiguation (specific edge cases)
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
|
||||
class TestRoutingDisambiguation:
|
||||
"""When multiple skills could match, the router should pick the best one."""
|
||||
|
||||
def test_overlapping_keywords_routes_to_best_match(
|
||||
self,
|
||||
api_client: httpx.Client,
|
||||
metrics_collector: MetricsCollector,
|
||||
):
|
||||
"""With overlapping keywords, router should pick the most relevant skill."""
|
||||
register_skill_via_api(
|
||||
api_client,
|
||||
"python_coder",
|
||||
keywords=["python", "code", "programming"],
|
||||
)
|
||||
register_skill_via_api(
|
||||
api_client,
|
||||
"javascript_coder",
|
||||
keywords=["javascript", "code", "programming"],
|
||||
)
|
||||
|
||||
benchmark = BenchmarkCase(
|
||||
id="disambig-python-001",
|
||||
input="Write a Python function to sort a list",
|
||||
expected_skill="python_coder",
|
||||
expected_complexity="medium",
|
||||
category="routing",
|
||||
subcategory="disambiguation",
|
||||
paraphrases=["I need a Python sorting algorithm", "用Python写个排序函数"],
|
||||
)
|
||||
|
||||
result = _run_benchmark_and_record(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name="disambig_python",
|
||||
)
|
||||
assert result["status_code"] == 200
|
||||
|
||||
# Also test paraphrases for overfitting detection
|
||||
for i, para in enumerate(benchmark.paraphrases):
|
||||
_run_benchmark_and_record(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name=f"disambig_python_para_{i}",
|
||||
is_paraphrase=True,
|
||||
input_override=para,
|
||||
)
|
||||
|
||||
def test_no_matching_skill_falls_back_gracefully(
|
||||
self,
|
||||
api_client: httpx.Client,
|
||||
metrics_collector: MetricsCollector,
|
||||
):
|
||||
"""When no skill matches, should fall back to direct chat."""
|
||||
benchmark = BenchmarkCase(
|
||||
id="fallback-nomatch-001",
|
||||
input="Tell me about quantum physics",
|
||||
expected_skill=None,
|
||||
expected_complexity="low",
|
||||
category="routing",
|
||||
subcategory="fallback",
|
||||
paraphrases=["Explain quantum mechanics", "量子物理是什么"],
|
||||
)
|
||||
|
||||
result = _run_benchmark_and_record(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name="fallback_nomatch",
|
||||
)
|
||||
assert result["status_code"] == 200
|
||||
|
||||
for i, para in enumerate(benchmark.paraphrases):
|
||||
_run_benchmark_and_record(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name=f"fallback_nomatch_para_{i}",
|
||||
is_paraphrase=True,
|
||||
input_override=para,
|
||||
)
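
The module docstring of this file mentions recall/precision/F1 analysis over the recorded observations. As a rough illustration of how per-skill scores could be derived from `(expected_skill, actual_skill)` pairs, here is a minimal sketch. `skill_prf` is a hypothetical helper, not part of `MetricsCollector`, and it assumes `None` stands for "no skill selected".

```python
from __future__ import annotations

from collections import Counter
from collections.abc import Iterable


def skill_prf(
    pairs: Iterable[tuple[str | None, str | None]],
) -> dict[str, tuple[float, float, float]]:
    """Return {skill: (precision, recall, f1)} from (expected, actual) pairs."""
    tp: Counter[str] = Counter()
    fp: Counter[str] = Counter()
    fn: Counter[str] = Counter()

    for expected, actual in pairs:
        if expected is not None and expected == actual:
            tp[expected] += 1
            continue
        if actual is not None:
            fp[actual] += 1  # a skill was selected that should not have been
        if expected is not None:
            fn[expected] += 1  # the expected skill was missed

    scores: dict[str, tuple[float, float, float]] = {}
    for skill in set(tp) | set(fp) | set(fn):
        predicted = tp[skill] + fp[skill]
        relevant = tp[skill] + fn[skill]
        precision = tp[skill] / predicted if predicted else 0.0
        recall = tp[skill] / relevant if relevant else 0.0
        total = precision + recall
        f1 = 2 * precision * recall / total if total else 0.0
        scores[skill] = (precision, recall, f1)
    return scores
```

With only a handful of cases per skill these scores move a lot from one extra hit or miss, so they are best read alongside the number of cases behind them.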
@@ -0,0 +1,252 @@
"""E2E Agent Capability Tests — Expert Team Collaboration with Metrics.
|
||||
|
||||
Tests the intelligence of expert team collaboration AND collects data for:
|
||||
- Team formation accuracy
|
||||
- Fallback effectiveness
|
||||
- Expert coordination quality
|
||||
- Overfitting detection via paraphrased inputs
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
from tests.e2e.benchmark_dataset import TEAM_BENCHMARKS, BenchmarkCase
|
||||
from tests.e2e.capability_metrics import MetricsCollector
|
||||
from tests.e2e.conftest import register_skill_via_api
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Helper: run team benchmark and record metrics
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _run_team_benchmark(
|
||||
benchmark: BenchmarkCase,
|
||||
api_client: httpx.Client,
|
||||
collector: MetricsCollector,
|
||||
test_name: str,
|
||||
is_paraphrase: bool = False,
|
||||
input_override: str | None = None,
|
||||
) -> dict:
|
||||
"""Execute a team benchmark and record metrics."""
|
||||
query = input_override or benchmark.input
|
||||
collector.start_timer(benchmark.id)
|
||||
|
||||
payload: dict = {"input_data": {"query": query}}
|
||||
if benchmark.expected_skill:
|
||||
payload["skill_name"] = benchmark.expected_skill
|
||||
|
||||
resp = api_client.post("/api/v1/tasks", json=payload)
|
||||
|
||||
actual_skill = None
|
||||
actual_exec_mode = None
|
||||
actual_keys = []
|
||||
task_succeeded = resp.status_code == 200
|
||||
error_msg = None
|
||||
|
||||
if task_succeeded:
|
||||
data = resp.json()
|
||||
actual_skill = data.get("skill_name")
|
||||
actual_exec_mode = data.get("execution_mode")
|
||||
actual_keys = list(data.keys())
|
||||
elif resp.status_code >= 400:
|
||||
try:
|
||||
error_msg = resp.json().get("detail", resp.text[:200])
|
||||
except Exception:
|
||||
error_msg = resp.text[:200]
|
||||
|
||||
collector.record_benchmark_result(
|
||||
benchmark,
|
||||
test_name=test_name,
|
||||
actual_skill=actual_skill,
|
||||
actual_execution_mode=actual_exec_mode,
|
||||
actual_status_code=resp.status_code,
|
||||
actual_response_keys=actual_keys,
|
||||
task_succeeded=task_succeeded,
|
||||
is_paraphrase=is_paraphrase,
|
||||
error_message=error_msg,
|
||||
)
|
||||
|
||||
return {
|
||||
"status_code": resp.status_code,
|
||||
"actual_skill": actual_skill,
|
||||
"actual_exec_mode": actual_exec_mode,
|
||||
"task_succeeded": task_succeeded,
|
||||
}
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Parameterized Team Benchmark Tests
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
|
||||
class TestTeamBenchmarks:
|
||||
"""Run all team benchmarks with metrics collection."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"benchmark",
|
||||
TEAM_BENCHMARKS,
|
||||
ids=[b.id for b in TEAM_BENCHMARKS],
|
||||
)
|
||||
def test_team_benchmark(
|
||||
self,
|
||||
benchmark: BenchmarkCase,
|
||||
api_client: httpx.Client,
|
||||
metrics_collector: MetricsCollector,
|
||||
):
|
||||
"""Run original team benchmark and record metrics."""
|
||||
if benchmark.expected_skill:
|
||||
register_skill_via_api(
|
||||
api_client,
|
||||
benchmark.expected_skill,
|
||||
keywords=[benchmark.expected_skill],
|
||||
)
|
||||
|
||||
result = _run_team_benchmark(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name=f"team_benchmark_{benchmark.id}",
|
||||
)
|
||||
assert result["status_code"] == 200, f"Team benchmark {benchmark.id} failed"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"benchmark",
|
||||
[b for b in TEAM_BENCHMARKS if b.paraphrases],
|
||||
ids=[b.id for b in TEAM_BENCHMARKS if b.paraphrases],
|
||||
)
|
||||
def test_team_paraphrase(
|
||||
self,
|
||||
benchmark: BenchmarkCase,
|
||||
api_client: httpx.Client,
|
||||
metrics_collector: MetricsCollector,
|
||||
):
|
||||
"""Run paraphrases for overfitting detection."""
|
||||
for i, paraphrase in enumerate(benchmark.paraphrases):
|
||||
_run_team_benchmark(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name=f"team_paraphrase_{benchmark.id}_{i}",
|
||||
is_paraphrase=True,
|
||||
input_override=paraphrase,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Team Formation Intelligence
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
class TestTeamFormation:
    """Test that teams are formed intelligently based on task requirements."""

    def test_explicit_team_prefix(
        self,
        api_client: httpx.Client,
        metrics_collector: MetricsCollector,
    ):
        """@team prefix should trigger team collaboration mode."""
        register_skill_via_api(api_client, "team_analyst", keywords=["team_analyst", "analyze"])
        register_skill_via_api(api_client, "team_writer", keywords=["team_writer", "write"])

        # NOTE: the input below has no "@team" prefix and the expected mode is "react",
        # so this case records a baseline rather than exercising the prefix path.
        benchmark = BenchmarkCase(
            id="team-explicit-001",
            input="Analyze the data and write a report",
            expected_skill="team_analyst",
            expected_execution_mode="react",
            expected_complexity="high",
            category="team",
            subcategory="explicit_team",
            paraphrases=["I need analysis and a written report", "分析数据并写报告"],
        )

        result = _run_team_benchmark(
            benchmark,
            api_client,
            metrics_collector,
            test_name="team_explicit",
        )
        assert result["status_code"] == 200

        for i, para in enumerate(benchmark.paraphrases):
            _run_team_benchmark(
                benchmark,
                api_client,
                metrics_collector,
                test_name=f"team_explicit_para_{i}",
                is_paraphrase=True,
                input_override=para,
            )
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Fallback Intelligence
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
|
||||
class TestTeamFallback:
|
||||
"""Test that team collaboration falls back gracefully on failure."""
|
||||
|
||||
def test_fallback_to_single_agent_on_team_failure(
|
||||
self,
|
||||
api_client: httpx.Client,
|
||||
metrics_collector: MetricsCollector,
|
||||
):
|
||||
"""If team collaboration fails, should fall back to single agent."""
|
||||
register_skill_via_api(api_client, "fallback_skill", keywords=["fallback_test"])
|
||||
|
||||
benchmark = BenchmarkCase(
|
||||
id="team-fallback-001",
|
||||
input="Complex task that might need fallback",
|
||||
expected_skill="fallback_skill",
|
||||
expected_complexity="high",
|
||||
category="team",
|
||||
subcategory="fallback",
|
||||
paraphrases=["Difficult task requiring fallback mechanism", "需要回退机制的复杂任务"],
|
||||
)
|
||||
|
||||
result = _run_team_benchmark(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name="team_fallback",
|
||||
)
|
||||
assert result["status_code"] == 200
|
||||
|
||||
for i, para in enumerate(benchmark.paraphrases):
|
||||
_run_team_benchmark(
|
||||
benchmark,
|
||||
api_client,
|
||||
metrics_collector,
|
||||
test_name=f"team_fallback_para_{i}",
|
||||
is_paraphrase=True,
|
||||
input_override=para,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Expert Name Validation
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
@pytest.mark.e2e_capability
class TestExpertNameValidation:
    """Test that expert names are validated according to project rules."""

    def test_valid_expert_names(self, api_client: httpx.Client):
        """Valid expert names (alphanumeric, dash, underscore) should work."""
        for name in ["analyst", "data-scientist", "code_reviewer", "expert-123"]:
            resp = register_skill_via_api(api_client, name, keywords=[name])
            assert resp.status_code in (200, 201, 409), f"Failed for name: {name}"

    def test_invalid_expert_name_rejected(self, api_client: httpx.Client):
        """Invalid expert names should be rejected."""
        # NOTE: 200/201 are still accepted below, so this only guards against
        # unexpected statuses (e.g. 5xx); it does not yet enforce rejection.
        for name in ["expert with spaces", "expert@special", "", "a" * 65]:
            resp = register_skill_via_api(api_client, name, keywords=[name])
            assert resp.status_code in (200, 201, 400, 409, 422), (
                f"Unexpected status for name: '{name}'"
            )
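
The validation tests above imply a naming rule: letters, digits, dashes and underscores, with the 65-character case suggesting a 64-character cap. The following sketch shows what such a check could look like. Both the helper name `is_valid_expert_name` and the 64-character limit are assumptions inferred from the test inputs; the project's actual validator may differ.

```python
import re

# Assumed rule: 1 to 64 characters drawn from ASCII letters, digits, "_" and "-".
_EXPERT_NAME_RE = re.compile(r"[A-Za-z0-9_-]{1,64}")


def is_valid_expert_name(name: str) -> bool:
    """Return True if the whole name matches the assumed naming rule."""
    return _EXPERT_NAME_RE.fullmatch(name) is not None
```

Using `fullmatch` rather than `match` matters here: it rejects names that merely start with a valid prefix, such as `"expert with spaces"`.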
@@ -0,0 +1,332 @@
"""Unit tests for CostAwareRouter team upgrade logic and HeuristicClassifier."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agentkit.chat.skill_routing import (
|
||||
CostAwareRouter,
|
||||
ExecutionMode,
|
||||
HeuristicClassifier,
|
||||
SkillRoutingResult,
|
||||
)
|
||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||||
from agentkit.experts.registry import ExpertTemplateRegistry
|
||||
from agentkit.experts.router import ExpertTeamRouter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_router(expert_team_router: ExpertTeamRouter | None = None) -> CostAwareRouter:
|
||||
"""Create a CostAwareRouter with mocked dependencies."""
|
||||
return CostAwareRouter(
|
||||
llm_gateway=None,
|
||||
model="test",
|
||||
classifier="heuristic",
|
||||
expert_team_router=expert_team_router,
|
||||
)
|
||||
|
||||
|
||||
def _make_team_router_with_templates() -> ExpertTeamRouter:
|
||||
"""Create an ExpertTeamRouter with sample templates."""
|
||||
registry = ExpertTemplateRegistry()
|
||||
for name in ("analyst", "strategist", "reviewer"):
|
||||
config = ExpertConfig(
|
||||
name=name,
|
||||
agent_type="expert",
|
||||
persona=f"Expert in {name}",
|
||||
thinking_style="analytical",
|
||||
bound_skills=[],
|
||||
is_lead=(name == "analyst"),
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": f"Expert in {name}"},
|
||||
)
|
||||
template = ExpertTemplate(
|
||||
name=name,
|
||||
config=config,
|
||||
description=f"Handles {name} tasks",
|
||||
)
|
||||
registry.register(template)
|
||||
return ExpertTeamRouter(template_registry=registry)
|
||||
|
||||
|
||||
def _make_team_router_empty() -> ExpertTeamRouter:
|
||||
"""Create an ExpertTeamRouter with no templates."""
|
||||
return ExpertTeamRouter(template_registry=ExpertTemplateRegistry())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: ExpertTeamRouter.can_handle()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExpertTeamRouterCanHandle:
|
||||
def test_can_handle_with_templates(self) -> None:
|
||||
router = _make_team_router_with_templates()
|
||||
assert router.can_handle("analyze this data") is True
|
||||
|
||||
def test_can_handle_no_templates(self) -> None:
|
||||
router = _make_team_router_empty()
|
||||
assert router.can_handle("analyze this data") is False
|
||||
|
||||
def test_can_handle_name_match(self) -> None:
|
||||
router = _make_team_router_with_templates()
|
||||
assert router.can_handle("I need a strategist for this") is True
|
||||
|
||||
def test_can_handle_description_match(self) -> None:
|
||||
router = _make_team_router_with_templates()
|
||||
assert router.can_handle("handles review tasks") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _try_team_upgrade()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTryTeamUpgrade:
|
||||
def test_upgrade_react_to_team_collab(self) -> None:
|
||||
router = _make_router(expert_team_router=_make_team_router_with_templates())
|
||||
result = SkillRoutingResult(
|
||||
clean_content="complex multi-step analysis task",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.8,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "complex multi-step analysis task", 0.8, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.TEAM_COLLAB
|
||||
assert any(t.get("method") == "team_upgrade" for t in trace)
|
||||
|
||||
def test_no_upgrade_low_complexity(self) -> None:
|
||||
router = _make_router(expert_team_router=_make_team_router_with_templates())
|
||||
result = SkillRoutingResult(
|
||||
clean_content="simple question",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.3,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "simple question", 0.3, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.REACT
|
||||
assert not any(t.get("method") == "team_upgrade" for t in trace)
|
||||
|
||||
def test_no_upgrade_no_team_router(self) -> None:
|
||||
router = _make_router(expert_team_router=None)
|
||||
result = SkillRoutingResult(
|
||||
clean_content="complex analysis",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.9,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "complex analysis", 0.9, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.REACT
|
||||
|
||||
def test_no_upgrade_empty_templates(self) -> None:
|
||||
router = _make_router(expert_team_router=_make_team_router_empty())
|
||||
result = SkillRoutingResult(
|
||||
clean_content="complex analysis",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.8,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "complex analysis", 0.8, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.REACT
|
||||
|
||||
def test_no_upgrade_direct_chat_mode(self) -> None:
|
||||
router = _make_router(expert_team_router=_make_team_router_with_templates())
|
||||
result = SkillRoutingResult(
|
||||
clean_content="hello",
|
||||
matched=False,
|
||||
match_method="greeting",
|
||||
match_confidence=1.0,
|
||||
complexity=0.0,
|
||||
execution_mode=ExecutionMode.DIRECT_CHAT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "hello", 0.0, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
|
||||
def test_team_upgrade_exception_handled(self) -> None:
|
||||
"""When ExpertTeamRouter raises, the upgrade is silently skipped."""
|
||||
broken_router = MagicMock()
|
||||
broken_router.can_handle.side_effect = RuntimeError("boom")
|
||||
router = _make_router(expert_team_router=broken_router)
|
||||
result = SkillRoutingResult(
|
||||
clean_content="complex task",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.8,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "complex task", 0.8, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.REACT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: ExpertTeamRouter.resolve() with complexity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExpertTeamRouterResolve:
|
||||
def test_explicit_team_prefix(self) -> None:
|
||||
router = _make_team_router_with_templates()
|
||||
result = router.resolve("@team:analyst,strategist analyze the market", 0.5)
|
||||
assert result.team_mode is True
|
||||
assert result.match_method == "explicit_team"
|
||||
assert "analyst" in result.specified_experts
|
||||
assert "strategist" in result.specified_experts
|
||||
|
||||
def test_complexity_suggestion(self) -> None:
|
||||
router = _make_team_router_with_templates()
|
||||
result = router.resolve("complex multi-step analysis", 0.8)
|
||||
assert result.team_mode is True
|
||||
assert result.match_method == "complexity_suggestion"
|
||||
assert result.auto_compose is True
|
||||
|
||||
def test_no_team_low_complexity(self) -> None:
|
||||
router = _make_team_router_with_templates()
|
||||
result = router.resolve("simple question", 0.2)
|
||||
assert result.team_mode is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: HeuristicClassifier complexity calibration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeuristicClassifierLowComplexity:
|
||||
"""Low-complexity signals should produce scores < 0.3."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_chinese_greeting(self) -> None:
|
||||
assert self.clf.classify("你好") < 0.3
|
||||
|
||||
def test_chinese_greeting_hi(self) -> None:
|
||||
assert self.clf.classify("嗨") < 0.3
|
||||
|
||||
def test_english_greeting_hello(self) -> None:
|
||||
assert self.clf.classify("Hello") < 0.3
|
||||
|
||||
def test_english_greeting_hi(self) -> None:
|
||||
assert self.clf.classify("hi") < 0.3
|
||||
|
||||
def test_multiple_low_complexity_words(self) -> None:
|
||||
assert self.clf.classify("嗨,早上好") < 0.3
|
||||
|
||||
def test_greeting_with_high_complexity_word_not_suppressed(self) -> None:
|
||||
"""Low-complexity signal should NOT override high-complexity signal."""
|
||||
# "你好" is low, but "分析" is high → should score high
|
||||
assert self.clf.classify("你好,请帮我分析一下这个数据") > 0.5
|
||||
|
||||
|
||||
class TestHeuristicClassifierIdentity:
|
||||
"""Identity queries should produce scores < 0.3."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_who_are_you_cn(self) -> None:
|
||||
assert self.clf.classify("你是谁") < 0.3
|
||||
|
||||
def test_what_is_your_name_cn(self) -> None:
|
||||
assert self.clf.classify("你叫什么") < 0.3
|
||||
|
||||
|
||||
class TestHeuristicClassifierNegation:
|
||||
"""Negated high-complexity words should not contribute to score."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_negate_search_cn(self) -> None:
|
||||
assert self.clf.classify("不要搜索") < 0.3
|
||||
|
||||
def test_negate_analyze_cn(self) -> None:
|
||||
assert self.clf.classify("无需分析,直接告诉我答案") < 0.3
|
||||
|
||||
def test_partial_negation_still_high(self) -> None:
|
||||
"""'搜索' negated but '分析' not — should still be high."""
|
||||
assert self.clf.classify("分析市场趋势,但不要搜索") > 0.5
|
||||
|
||||
|
||||
class TestHeuristicClassifierThresholds:
|
||||
"""Verify adjusted base scores."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_no_keyword_short_message(self) -> None:
|
||||
assert self.clf.classify("好的") <= 0.10
|
||||
|
||||
def test_medium_complexity_base(self) -> None:
|
||||
"""Medium complexity keyword should start at 0.35 (not 0.45)."""
|
||||
score = self.clf.classify("如何使用Python?")
|
||||
# '如何' is a medium keyword → base 0.35; the short-question deduction (-0.10) may lower it to 0.25.
|
||||
# 'Python' is in neither the high nor the medium list, so it adds nothing to the score.
|
||||
assert 0.25 <= score <= 0.45
|
||||
|
||||
|
||||
class TestHeuristicClassifierShortQuestion:
|
||||
"""Short questions ending with ?/? should get deduction."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_short_question_deduction(self) -> None:
|
||||
assert self.clf.classify("怎么用?") < 0.3
|
||||
|
||||
def test_long_question_no_deduction(self) -> None:
|
||||
assert self.clf.classify("如何设计一个高可用的微服务架构?") > 0.5
|
||||
|
||||
|
||||
class TestHeuristicClassifierHighComplexity:
|
||||
"""Complex tasks should produce scores > 0.7."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_two_high_complexity_words(self) -> None:
|
||||
# "分析" + "搜索" are both in _HIGH_COMPLEXITY_HINTS_CN → base 0.80
|
||||
assert self.clf.classify("分析市场数据并搜索相关信息") > 0.7
|
||||
|
||||
def test_single_high_complexity_word(self) -> None:
|
||||
# "分析" alone → base 0.65
|
||||
assert self.clf.classify("分析市场趋势并生成报告") > 0.6
|
||||
|
||||
def test_execute_and_restart(self) -> None:
|
||||
assert self.clf.classify("执行部署脚本并重启服务") > 0.7
|
||||
|
||||
|
||||
class TestHeuristicClassifierEdgeCases:
|
||||
"""Boundary conditions."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
assert self.clf.classify("") == 0.0
|
||||
|
||||
def test_whitespace_only(self) -> None:
|
||||
assert self.clf.classify(" ") == 0.0
|
||||
|
||||
def test_long_low_complexity_message(self) -> None:
|
||||
"""Even a long greeting should stay low."""
|
||||
long_greeting = "你好" * 100 # exactly 200 chars
|
||||
assert self.clf.classify(long_greeting) <= 0.15
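Review note: the negation cases above (and the commit's "avoid matching past punctuation" fix) can be illustrated with a clause-bounded regex. This is a sketch under stated assumptions rather than the real `HeuristicClassifier` code: `HIGH_HINTS`, `NEGATIONS`, `CLAUSE_BREAKS`, the four-character gap, and both function names are hypothetical.

```python
from __future__ import annotations

import re

# Hypothetical hint and negation lists; the real ones are not shown in this diff.
HIGH_HINTS = ("分析", "搜索", "执行", "重启")
NEGATIONS = ("不要", "无需", "不用")
# Clause-ending punctuation (ASCII and full-width) that a negation must not cross.
CLAUSE_BREAKS = ",.;!?,。;!?"


def is_negated(text: str, keyword: str) -> bool:
    """True if a negation word precedes ``keyword`` within the same clause."""
    negation = "|".join(re.escape(n) for n in NEGATIONS)
    # Allow up to four characters between negation and keyword, none of them punctuation.
    gap = "[^" + re.escape(CLAUSE_BREAKS) + "]{0,4}"
    pattern = "(?:" + negation + ")" + gap + re.escape(keyword)
    return re.search(pattern, text) is not None


def active_high_hints(text: str) -> list[str]:
    """High-complexity hints present in ``text`` that are not negated."""
    return [kw for kw in HIGH_HINTS if kw in text and not is_negated(text, kw)]
```

One simplification to be aware of: a keyword that appears twice, once negated and once not, is treated as negated here.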
|
||||
|
|
@ -0,0 +1,172 @@
|
|||
"""Unit tests for QualityGate skill match validation (5th dimension)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.quality.gate import QualityGate
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_skill(
|
||||
name: str = "test_skill",
|
||||
required_fields: list[str] | None = None,
|
||||
min_word_count: int = 0,
|
||||
) -> Skill:
|
||||
"""Create a Skill with the given quality gate config."""
|
||||
config = SkillConfig(
|
||||
name=name,
|
||||
agent_type="skill",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": f"You are {name}"},
|
||||
quality_gate={
|
||||
"required_fields": required_fields or [],
|
||||
"min_word_count": min_word_count,
|
||||
},
|
||||
)
|
||||
return Skill(config=config)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _check_skill_match static method
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckSkillMatch:
|
||||
def setup_method(self) -> None:
|
||||
self.gate = QualityGate()
|
||||
|
||||
def test_none_skill_context(self) -> None:
|
||||
assert self.gate._check_skill_match({"content": "hello"}, None) is None
|
||||
|
||||
def test_empty_skill_context(self) -> None:
|
||||
assert self.gate._check_skill_match({"content": "hello"}, {}) is None
|
||||
|
||||
def test_missing_intent_keywords(self) -> None:
|
||||
assert self.gate._check_skill_match({"content": "hello"}, {"skill_name": "x"}) is None
|
||||
|
||||
def test_empty_intent_keywords(self) -> None:
|
||||
assert self.gate._check_skill_match({"content": "hello"}, {"intent_keywords": []}) is None
|
||||
|
||||
def test_output_contains_keyword(self) -> None:
|
||||
result = self.gate._check_skill_match(
|
||||
{"content": "市场分析报告"},
|
||||
{"intent_keywords": ["分析", "报告"]},
|
||||
)
|
||||
assert result is not None
|
||||
assert result.passed is True
|
||||
assert result.message is None
|
||||
|
||||
def test_output_missing_all_keywords(self) -> None:
|
||||
result = self.gate._check_skill_match(
|
||||
{"content": "今天天气不错"},
|
||||
{"intent_keywords": ["分析", "报告"]},
|
||||
)
|
||||
assert result is not None
|
||||
assert result.passed is True # Warning level, not blocking
|
||||
assert "Warning" in (result.message or "")
|
||||
|
||||
def test_keyword_case_insensitive(self) -> None:
|
||||
result = self.gate._check_skill_match(
|
||||
{"content": "search results"},
|
||||
{"intent_keywords": ["Search"]},
|
||||
)
|
||||
assert result is not None
|
||||
assert result.passed is True
|
||||
assert result.message is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Full validate() with skill_context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateWithSkillContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_skill_context_backward_compatible(self) -> None:
|
||||
"""Without skill_context, only 4 dimensions checked."""
|
||||
gate = QualityGate()
|
||||
skill = _make_skill()
|
||||
result = await gate.validate({"content": "hello"}, skill)
|
||||
assert result.passed is True
|
||||
skill_match_checks = [c for c in result.checks if c.name == "skill_match"]
|
||||
assert len(skill_match_checks) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_context_with_matching_output(self) -> None:
|
||||
"""Output contains keyword → skill_match passes silently."""
|
||||
gate = QualityGate()
|
||||
skill = _make_skill()
|
||||
result = await gate.validate(
|
||||
{"content": "市场分析报告"},
|
||||
skill,
|
||||
skill_context={"intent_keywords": ["分析"]},
|
||||
)
|
||||
assert result.passed is True
|
||||
skill_match_checks = [c for c in result.checks if c.name == "skill_match"]
|
||||
assert len(skill_match_checks) == 1
|
||||
assert skill_match_checks[0].passed is True
|
||||
assert skill_match_checks[0].message is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_context_warning_only(self) -> None:
|
||||
"""Output missing keywords but other checks pass → warning, overall still passed."""
|
||||
gate = QualityGate()
|
||||
skill = _make_skill()
|
||||
result = await gate.validate(
|
||||
{"content": "今天天气不错"},
|
||||
skill,
|
||||
skill_context={"intent_keywords": ["分析"]},
|
||||
)
|
||||
assert result.passed is True # Warning doesn't block alone
|
||||
skill_match_checks = [c for c in result.checks if c.name == "skill_match"]
|
||||
assert len(skill_match_checks) == 1
|
||||
assert "Warning" in (skill_match_checks[0].message or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_match_escalates_with_other_failure(self) -> None:
|
||||
"""Output missing keywords + required field missing → skill_match escalated to failed."""
|
||||
gate = QualityGate()
|
||||
skill = _make_skill(required_fields=["summary"])
|
||||
result = await gate.validate(
|
||||
{"content": "今天天气不错"}, # missing "summary" field
|
||||
skill,
|
||||
skill_context={"intent_keywords": ["分析"]},
|
||||
)
|
||||
assert result.passed is False
|
||||
skill_match_checks = [c for c in result.checks if c.name == "skill_match"]
|
||||
assert len(skill_match_checks) == 1
|
||||
assert skill_match_checks[0].passed is False # Escalated
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_match_no_escalation_when_matching(self) -> None:
|
||||
"""Output contains keywords + required field missing → skill_match stays passed."""
|
||||
gate = QualityGate()
|
||||
skill = _make_skill(required_fields=["summary"])
|
||||
result = await gate.validate(
|
||||
{"content": "分析结果"}, # missing "summary" field
|
||||
skill,
|
||||
skill_context={"intent_keywords": ["分析"]},
|
||||
)
|
||||
assert result.passed is False # Due to required field
|
||||
skill_match_checks = [c for c in result.checks if c.name == "skill_match"]
|
||||
assert len(skill_match_checks) == 1
|
||||
assert skill_match_checks[0].passed is True # Not escalated
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_intent_keywords_skips_check(self) -> None:
|
||||
"""Empty intent_keywords list → skill_match check skipped entirely."""
|
||||
gate = QualityGate()
|
||||
skill = _make_skill()
|
||||
result = await gate.validate(
|
||||
{"content": "hello"},
|
||||
skill,
|
||||
skill_context={"intent_keywords": []},
|
||||
)
|
||||
skill_match_checks = [c for c in result.checks if c.name == "skill_match"]
|
||||
assert len(skill_match_checks) == 0
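Review note: the warning-then-escalate behaviour asserted above can be summarised in a few lines. The sketch below is an assumption-laden illustration, not the `QualityGate` source: `Check`, `check_skill_match` and `escalate` are hypothetical names, and joining all output values into one string is a simplification of however the real gate extracts text.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Check:
    """Hypothetical stand-in for one quality-gate check result."""
    name: str
    passed: bool
    message: str | None = None


def check_skill_match(output: dict, skill_context: dict | None) -> Check | None:
    """Return None when there is nothing to compare, else a skill_match check."""
    keywords = (skill_context or {}).get("intent_keywords") or []
    if not keywords:
        return None
    text = " ".join(str(v) for v in output.values()).lower()
    if any(kw.lower() in text for kw in keywords):
        return Check("skill_match", True)
    # A mismatch alone is only a warning: the check still passes.
    return Check("skill_match", True, "Warning: output mentions none of the intent keywords")


def escalate(checks: list[Check]) -> list[Check]:
    """Turn a skill_match warning into a failure if any other check failed."""
    other_failed = any(not c.passed for c in checks if c.name != "skill_match")
    for c in checks:
        if c.name == "skill_match" and c.message and other_failed:
            c.passed = False
    return checks
```

Keeping the mismatch non-blocking on its own avoids rejecting valid outputs that simply paraphrase the keywords, while still letting it reinforce a failure that another dimension has already found.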
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
"""Unit tests for IntentRouter multi-candidate keyword scoring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agentkit.router.intent import IntentRouter
|
||||
from agentkit.skills.base import Skill, SkillConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_skill(name: str, keywords: list[str], description: str = "") -> Skill:
|
||||
"""Create a Skill with the given name and intent keywords."""
|
||||
config = SkillConfig(
|
||||
name=name,
|
||||
agent_type="skill",
|
||||
description=description or f"Skill {name}",
|
||||
task_mode="llm_generate",
|
||||
prompt={"identity": f"You are {name}"},
|
||||
intent={"keywords": keywords, "description": description or f"Skill {name}"},
|
||||
)
|
||||
return Skill(config=config)
|
||||
|
||||
|
||||
def _make_skills(*specs: tuple[str, list[str]]) -> list[Skill]:
|
||||
"""Create multiple skills from (name, keywords) tuples."""
|
||||
return [_make_skill(name, kws) for name, kws in specs]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Single-candidate match (backward compatible)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingleCandidateMatch:
|
||||
"""When only one skill matches, behavior is identical to old first-match."""
|
||||
|
||||
def test_single_skill_matches(self) -> None:
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["规划", "执行"]), ("skill_b", ["搜索", "查询"]))
|
||||
result = router._match_keywords({"content": "帮我规划一个项目"}, skills)
|
||||
assert result is not None
|
||||
assert result.matched_skill == "skill_a"
|
||||
assert result.method == "keyword"
|
||||
|
||||
def test_single_keyword_match_confidence(self) -> None:
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["规划"]))
|
||||
result = router._match_keywords({"content": "规划任务"}, skills)
|
||||
assert result is not None
|
||||
# 1 keyword matched → confidence = min(1.0, 0.5 + 0.1 * 1) = 0.6
|
||||
assert result.confidence == 0.6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Multi-candidate scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMultiCandidateScoring:
|
||||
"""When multiple skills match, the best-scored one wins."""
|
||||
|
||||
def test_longer_keyword_wins(self) -> None:
|
||||
"""'调研报告' (4 chars) ties with '规划' + '报告' (2 + 2 chars); the alphabetical tie-break picks goal_driven."""
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(
|
||||
("plan_exec", ["规划", "报告"]),
|
||||
("goal_driven", ["调研报告", "生成"]),
|
||||
)
|
||||
result = router._match_keywords({"content": "规划一个调研报告"}, skills)
|
||||
assert result is not None
|
||||
# plan_exec: "规划"(2) + "报告"(2) = 4; goal_driven: "调研报告"(4) = 4
|
||||
# Same score → alphabetical: goal_driven < plan_exec
|
||||
assert result.matched_skill == "goal_driven"
|
||||
|
||||
def test_more_keywords_wins(self) -> None:
|
||||
"""Skill matching 3 keywords beats skill matching 1 keyword."""
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(
|
||||
("skill_a", ["分析"]),
|
||||
("skill_b", ["分析", "市场", "趋势"]),
|
||||
)
|
||||
result = router._match_keywords({"content": "分析市场趋势"}, skills)
|
||||
assert result is not None
|
||||
# skill_a: "分析"(2) = 2; skill_b: "分析"(2)+"市场"(2)+"趋势"(2) = 6
|
||||
assert result.matched_skill == "skill_b"
|
||||
|
||||
def test_same_score_alphabetical(self) -> None:
|
||||
"""When scores are equal, alphabetical name order breaks the tie."""
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(
|
||||
("zebra_skill", ["分析"]),
|
||||
("alpha_skill", ["分析"]),
|
||||
)
|
||||
result = router._match_keywords({"content": "分析数据"}, skills)
|
||||
assert result is not None
|
||||
assert result.matched_skill == "alpha_skill"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: No match
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoMatch:
|
||||
def test_no_keyword_match(self) -> None:
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["搜索"]), ("skill_b", ["查询"]))
|
||||
result = router._match_keywords({"content": "你好"}, skills)
|
||||
assert result is None
|
||||
|
||||
def test_empty_keywords_list(self) -> None:
|
||||
"""Skill with empty keywords list does not participate in matching."""
|
||||
router = IntentRouter()
|
||||
skills = [_make_skill("empty_kw", [])]
|
||||
result = router._match_keywords({"content": "anything"}, skills)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Case insensitivity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCaseInsensitivity:
|
||||
def test_english_keyword_case_insensitive(self) -> None:
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["Search"]))
|
||||
result = router._match_keywords({"content": "please search for this"}, skills)
|
||||
assert result is not None
|
||||
assert result.matched_skill == "skill_a"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Substring matching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSubstringMatch:
|
||||
def test_chinese_substring_match(self) -> None:
|
||||
"""Chinese keyword '报告' should match input containing '报告'."""
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["报告"]))
|
||||
result = router._match_keywords({"content": "生成一份报告"}, skills)
|
||||
assert result is not None
|
||||
assert result.matched_skill == "skill_a"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Confidence calculation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfidenceCalculation:
|
||||
def test_one_keyword_confidence(self) -> None:
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["分析"]))
|
||||
result = router._match_keywords({"content": "分析数据"}, skills)
|
||||
assert result is not None
|
||||
assert result.confidence == 0.6 # 0.5 + 0.1 * 1
|
||||
|
||||
def test_three_keywords_confidence(self) -> None:
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["分析", "市场", "趋势"]))
|
||||
result = router._match_keywords({"content": "分析市场趋势"}, skills)
|
||||
assert result is not None
|
||||
assert result.confidence == 0.8 # 0.5 + 0.1 * 3
|
||||
|
||||
def test_confidence_capped_at_one(self) -> None:
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["a", "b", "c", "d", "e", "f"]))
|
||||
result = router._match_keywords({"content": "a b c d e f"}, skills)
|
||||
assert result is not None
|
||||
assert result.confidence == 1.0 # min(1.0, 0.5 + 0.1 * 6 = 1.1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_empty_input_text(self) -> None:
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["分析"]))
|
||||
result = router._match_keywords({"content": ""}, skills)
|
||||
assert result is None
|
||||
|
||||
def test_nested_input_data(self) -> None:
|
||||
"""_extract_string_values should handle nested dicts/lists."""
|
||||
router = IntentRouter()
|
||||
skills = _make_skills(("skill_a", ["分析"]))
|
||||
result = router._match_keywords(
|
||||
{"message": {"text": "分析数据", "meta": {"role": "user"}}},
|
||||
skills,
|
||||
)
|
||||
assert result is not None
|
||||
assert result.matched_skill == "skill_a"
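Review note: the scoring rules these tests encode (score is the summed length of matched keywords, ties break alphabetically by skill name, confidence is `min(1.0, 0.5 + 0.1 * n)` for `n` matched keywords) fit in one small function. This is a sketch inferred from the test comments, not the `IntentRouter._match_keywords` source; `match_keywords` and its plain-dict input are hypothetical simplifications.

```python
from __future__ import annotations


def match_keywords(text: str, skills: dict[str, list[str]]) -> tuple[str, float] | None:
    """Return (skill_name, confidence) for the best-scoring skill, or None."""
    lowered = text.lower()
    candidates = []
    for name, keywords in skills.items():
        # Case-insensitive substring match, as in the tests above.
        hits = [kw for kw in keywords if kw.lower() in lowered]
        if hits:
            score = sum(len(kw) for kw in hits)
            # Negated score first, so min() prefers the highest score,
            # then the alphabetically smallest name on a tie.
            candidates.append((-score, name, len(hits)))
    if not candidates:
        return None
    _, name, n_hits = min(candidates)
    return name, min(1.0, 0.5 + 0.1 * n_hits)
```

Comparing confidences with a small tolerance rather than `==` is the safer choice when reusing this formula, since the result is a float sum.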
|
||||