refactor: eliminate routing layer, align with industry best practices

Phase 1 of architecture optimization (U1/U2/U4/U8):

- U1: Rename SimpleRouter to RequestPreprocessor, route() to preprocess()
  Eliminates misleading routing concept; LLM decides autonomously
  in REACT agent loop (matches Codex/Claude Code/Trae pattern)
- U2: Delete CostAwareRouter, HeuristicClassifier, SemanticRouter
  (~700 lines removed). skill_routing.py: 1688 to 220 lines
- U4: PlanExecEngine defaults to ReActStepExecutor, delete _LLMStepExecutor
  (pure LLM calls without tools = no execution capability)
- U8: ReActEngine defaults to ContextCompressor(keep_recent=10)

Supersedes plans 2026-06-15-002/003/004.
New plan: 2026-06-16-006-refactor-architecture-optimization-evolution-plan.md
chiguyong 2026-06-17 10:44:40 +08:00
parent b54213b3c6
commit 5374bc8501
18 changed files with 1251 additions and 3469 deletions


@ -1,7 +1,10 @@
--- ---
title: "feat: E2E capability analysis framework improvements and smarter routing" title: "feat: E2E capability analysis framework improvements and smarter routing"
type: feat type: feat
status: active status: superseded
superseded_by: "2026-06-16-005-refactor-routing-architecture-plan"
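superseded_reason: "The smarter-routing part has been replaced by the SimpleRouter architecture simplification. The E2E capability analysis framework part can be merged into the U5 backtest-validation unit of 2026-06-16-005."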
closed: 2026-06-16
created: 2026-06-15 created: 2026-06-15
plan-depth: standard plan-depth: standard
--- ---


@ -1,8 +1,9 @@
--- ---
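title: "feat: Smarter routing: complexity calibration, intent disambiguation, quality-gate enhancements" title: "feat: Smarter routing: complexity calibration, intent disambiguation, quality-gate enhancements"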
status: active status: superseded
created: 2026-06-15 superseded_by: "2026-06-16-005-refactor-routing-architecture-plan"
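updated: 2026-06-15 superseded_reason: "SimpleRouter has replaced CostAwareRouter's 4-layer routing architecture. IntentRouter multi-candidate scoring (U2) and QualityGate skill-match validation (U3) belong to the deleted legacy routing layer and no longer need to be implemented. The U1 HeuristicClassifier tests are valuable only for backward compatibility."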
closed: 2026-06-16
origin: test-results/e2e/capability_report.txt (real-LLM backtest analysis report) origin: test-results/e2e/capability_report.txt (real-LLM backtest analysis report)
--- ---


@ -2,9 +2,10 @@
```yaml ```yaml
title: feat: Enable SemanticRouter and upgrade the backtest system title: feat: Enable SemanticRouter and upgrade the backtest system
status: active status: superseded
created: 2026-06-15 superseded_by: "2026-06-16-005-refactor-routing-architecture-plan"
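plan_id: "2026-06-15-004" superseded_reason: "SimpleRouter has replaced CostAwareRouter, so SemanticRouter is no longer needed as a routing-layer component. The LLM sees the full tool descriptions in the REACT agent loop and decides autonomously, so embedding-based intent routing is unnecessary. If the number of tools exceeds 50 in the future, Codex's tool_search (BM25) can serve as a reference for tool discovery."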
closed: 2026-06-16
``` ```
## Summary ## Summary
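@ -187,6 +188,12 @@ SemanticRouter is fully implemented (`src/agentkit/chat/semantic_router.py`), but
- Adversarial input testing - Adversarial input testing
- Intent-classification fine-tuning pipeline - Intent-classification fine-tuning pipeline
- Automatic keyword expansion tool - Automatic keyword expansion tool
- **[TODO] Provide an Embedding API key**: the current Bailian Coding Plan key (sk-sp-*) does not support the embedding API, so one of the following is needed:
  - Option A: enable the embedding service separately in the DashScope console and obtain a standard key with the sk- prefix
  - Option B: configure a local embedding model (BGE-large-zh-v1.5 via Xinference)
  - Option C: use another provider that supports embeddings (Zhipu, Volcengine, etc.)
  - Where to configure: agentkit.yaml → llm.providers, add a dedicated embedding provider
  - Expected effect: SemanticRouter actually works, and colloquial queries no longer need manual keyword expansion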
## Risks ## Risks


@ -0,0 +1,429 @@
```yaml
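title: "refactor: AgentKit architecture optimization and evolution, aligned with industry best practices"
status: active
plan_id: "2026-06-16-006"
created: 2026-06-16
depth: deep
origin: "Optimize the AgentKit architecture based on an in-depth analysis of Codex/Claude Code/Trae/Qoder"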
```
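# AgentKit Architecture Optimization and Evolution Plan
## Summary
Based on an in-depth analysis of Codex CLI, Claude Code, Trae Agent 2.0, and Qoder, this plan evolves and optimizes the AgentKit architecture. Core changes: (1) eliminate the "routing" concept in favor of request preprocessing, (2) simplify the expert team from decentralized collaboration to hub-and-spoke, (3) simplify PlanExec to a Spec-Driven mode, (4) persist chat history in SQLite, (5) delete the legacy routing-layer code, (6) add verifiable execution and leaner tool descriptions.
## Problem Frame
The current AgentKit architecture has three kinds of problems:
1. **Misleading concept**: SimpleRouter still implies that a "routing layer" exists, but it only does @skill prefix parsing plus a greeting fast-path; the core decision has already been handed to the LLM in the REACT agent loop
2. **Over-engineering**: ExpertTeam's decentralized collaboration (CollaborationPlan + HandoffTransport + SharedWorkspace + 3 MergeStrategy variants) goes far beyond industry practice, adding ~1500 lines of code with no corresponding value
3. **Capability gaps**: PlanExecEngine defaults to _LLMStepExecutor (pure LLM calls without tools), chat history is not persisted, and there is no automatic verification loop
## Requirements
- R1: Eliminate the "routing" concept: rename SimpleRouter to RequestPreprocessor, changing the semantics from "routing decision" to "request preprocessing"
- R2: Simplify the expert team to a hub-and-spoke model (Lead Expert + parallel Tasks, depth=1); delete the ExpertTeam-specific logic in CollaborationPlan/HandoffTransport/SharedWorkspace
- R3: PlanExecEngine uses ReActStepExecutor by default; delete _LLMStepExecutor
- R4: Persist chat history in SQLite (modeled on Codex's Thread persistence)
- R5: Delete CostAwareRouter and related code (HeuristicClassifier, IntentRouter, QualityGate, SemanticRouter)
- R6: Add verifiable execution (a Test-and-Verify loop, modeled on Codex Cloud)
- R7: Tiered tool-description injection (full descriptions for core tools + one-line descriptions for extended tools + on-demand lookup via tool_search)
- R8: Enable context compression by default
- R9: Make Spec documents first-class citizens (modeled on Qoder Quest Mode)
- R10: Unified event model (SQ/EQ dual queues, modeled on Codex)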
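## Key Technical Decisions
### KTD1: Rename SimpleRouter to RequestPreprocessor
**Decision**: Rename SimpleRouter to RequestPreprocessor, and rename the route() method to preprocess()
**Rationale**:
- Codex has no routing layer, Claude Code has no routing layer, and Trae removed its Proposal phase: the industry consensus is to have no routing layer
- SimpleRouter's 3 functions (@skill prefix, greeting regex, default REACT) are all preprocessing, not routing decisions
- The "routing" concept misleads developers into believing an intent-prediction layer exists, when in fact the LLM decides autonomously in the agent loop
**Alternative considered**: Delete SimpleRouter entirely and send every request straight into the REACT loop. Rejected: the greeting fast-path saves ~100 tokens + 500ms per request, and the @skill prefix is an explicit user instruction that requires code-level parsing
### KTD2: Hub-and-spoke expert team
**Decision**: Simplify ExpertTeam from decentralized collaboration to Lead Expert + parallel Tasks (depth=1)
**Rationale**:
- Claude Code: the Task tool has depth=1; a sub-agent cannot spawn further sub-agents
- Codex: spawn_agent is hierarchical; results return to the parent agent
- Qoder: multiple experts execute independently in parallel; the main agent aggregates
- Decentralized collaboration has O(N²) communication complexity; hub-and-spoke has O(N)
- Having the same LLM play different "experts" does not produce genuine diversity of viewpoints; it is equivalent to sampling multiple times and merging
**Keep**: ExpertConfig/ExpertTemplate/Registry (which define expert personas) and the BEST merge strategy (the Lead Agent picks the best result)
**Delete**: CollaborationPlan's phase dependency graph, HandoffTransport's inter-agent communication, SharedWorkspace's cross-phase state sharing, and the VOTE/FUSION merge strategies
### KTD3: PlanExec defaults to ReActStepExecutor + Spec-Driven
**Decision**: Use ReActStepExecutor (already implemented) by default, delete _LLMStepExecutor, and add Spec document persistence
**Rationale**:
- _LLMStepExecutor does not support tool calls, which means it has no execution capability
- ReActStepExecutor is already implemented on top of ReActEngine and supports tool calls and multi-step reasoning
- Qoder Quest Mode: Spec First is a contract between the human and the AI; execution starts only after the user confirms
- The current PlanExec plan is invisible to the user, who cannot correct the direction before execution
### KTD4: SQLite persistence for chat history
**Decision**: Use SQLite to persist chat history (modeled on Codex's Thread persistence)
**Rationale**:
- Codex uses SQLite: lightweight, zero-config, cross-platform, with support for resume/fork/archive
- Claude Code uses append-only JSONL: simpler, but weaker at search
- The current ConversationStore is purely in-memory and loses its data when the service restarts
- SQLite supports session search and paginated loading, and needs no extra service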
---
## Implementation Units
### U1. Rename SimpleRouter to RequestPreprocessor
**Goal**: Eliminate the "routing" concept by renaming SimpleRouter to RequestPreprocessor
**Dependencies**: None
**Files**:
- `src/agentkit/chat/simple_router.py` → rename to `src/agentkit/chat/request_preprocessor.py`
- `src/agentkit/chat/skill_routing.py`: update references
- `src/agentkit/server/routes/portal.py`: update imports and calls
- `src/agentkit/server/routes/chat.py`: update imports and calls
- `src/agentkit/server/app.py`: update imports and calls
- `tests/unit/chat/test_simple_router.py` → rename and update
**Approach**:
1. Create `request_preprocessor.py` with the class named `RequestPreprocessor` and the method `route()` renamed to `preprocess()`
2. Rename `_is_direct_chat()` to `_is_trivial_input()`
3. Keep `SkillRoutingResult` (it is a data structure and does not involve the routing concept)
4. Update all call sites
5. Delete the old file
**Patterns to follow**: the existing SimpleRouter code structure
**Test scenarios**:
- An @skill:xxx prefix is parsed correctly into SKILL_REACT mode
- A greeting regex match returns DIRECT_CHAT
- Default input returns REACT mode
- An unknown skill falls back to REACT
- The preprocess() method signature is compatible with route()
**Verification**: `ruff check src/ && pytest tests/unit/chat/ -v`
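The three-branch decision tree exercised by the U1 test scenarios can be sketched as below. This is a minimal illustration, not the code in `request_preprocessor.py`: the `Mode` enum, `PreprocessResult`, both regexes, and the free function `preprocess()` are assumptions made up for the example.

```python
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Mode(Enum):
    SKILL_REACT = "skill_react"
    DIRECT_CHAT = "direct_chat"
    REACT = "react"


@dataclass
class PreprocessResult:
    mode: Mode
    clean_content: str
    skill_name: Optional[str] = None


# Hypothetical patterns; the project's real regexes are not reproduced here.
_SKILL_PREFIX_RE = re.compile(r"^@skill:([\w-]+)\s*(.*)$", re.DOTALL)
_GREETING_RE = re.compile(r"^(hi|hello|hey|你好)[\s!！。.]*$", re.IGNORECASE)


def preprocess(content: str, known_skills: set) -> PreprocessResult:
    text = content.strip()

    # Branch 1: explicit @skill:xxx prefix is a user instruction, parsed in code.
    match = _SKILL_PREFIX_RE.match(text)
    if match:
        name, rest = match.group(1), match.group(2).strip()
        if name in known_skills:
            return PreprocessResult(Mode.SKILL_REACT, rest, name)
        # Unknown skill: fall back to the default REACT loop.
        return PreprocessResult(Mode.REACT, rest)

    # Branch 2: greeting fast-path, no tools and no agent loop.
    if _GREETING_RE.match(text):
        return PreprocessResult(Mode.DIRECT_CHAT, text)

    # Branch 3: everything else goes to the REACT loop, where the LLM decides.
    return PreprocessResult(Mode.REACT, text)
```

Nothing in this sketch predicts intent: each branch is either an explicit user instruction or a cheap regex check.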
---
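### U2. Delete CostAwareRouter and related code
**Goal**: Delete the legacy routing-layer code that SimpleRouter has replaced
**Dependencies**: U1
**Files**:
- Most files under the `src/agentkit/router/` directory (keep __init__.py and the necessary exports)
- `src/agentkit/server/app.py`: clean up commented-out references
- `src/agentkit/chat/cost_aware_router.py`: delete
**Approach**:
1. Confirm that CostAwareRouter has no active references in the code (already commented out in app.py)
2. Delete `cost_aware_router.py`, `heuristic_classifier.py`, `intent.py` (IntentRouter), `quality_gate.py`, `semantic_router.py`
3. Keep `router/__init__.py` exporting the necessary types (such as ExecutionMode, if the frontend depends on it)
4. Clean up the commented-out references in app.py
5. Update the `__init__.py` of the `router/` directory
**Test scenarios**:
- `ruff check` reports no errors after the deletion
- `pytest -m "not integration"` passes in full
- No import errors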
**Verification**: `ruff check src/ && pytest -m "not integration" -x`
---
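### U3. Simplify the expert team to hub-and-spoke
**Goal**: Simplify ExpertTeam from decentralized collaboration to a Lead Expert + parallel Task model
**Dependencies**: None (can run in parallel with U1/U2)
**Files**:
- `src/agentkit/experts/orchestrator.py`: rewrite as hub-and-spoke
- `src/agentkit/experts/team.py`: simplify, remove the CollaborationPlan dependency
- `src/agentkit/experts/plan.py`: simplify, keep MergeStrategy.BEST
- `src/agentkit/core/handoff_transport.py`: remove ExpertTeam-specific logic
- `src/agentkit/core/shared_workspace.py`: remove ExpertTeam-specific logic
- `src/agentkit/experts/router.py`: simplify to @team prefix + RequestPreprocessor integration
- `tests/unit/experts/test_orchestrator.py`: update
**Approach**:
1. Rewrite TeamOrchestrator: the Lead Expert plans autonomously and spawns Tasks in parallel
2. Delete CollaborationPlan's phase dependency graph; the Lead Expert decides the execution order itself
3. Delete HandoffTransport's inter-agent communication; Task results return directly to the Lead Expert
4. Delete SharedWorkspace's cross-phase state sharing; the Lead Expert holds all state
5. Keep MergeStrategy.BEST (the Lead Agent picks the best result); delete VOTE/FUSION
6. Simplify ExpertTeamRouter to trigger on the @team prefix
7. Leave ExpertConfig/ExpertTemplate/Registry unchanged
**Technical design**:
```
New TeamOrchestrator flow:
1. User input → @team:xxx prefix → ExpertTeamMode
2. The Lead Expert receives the task and decomposes it into subtasks on its own
3. Spawn Tasks in parallel (each Task is an independent ReActEngine instance)
4. Wait for all Tasks to finish
5. The Lead Expert aggregates the results (BEST strategy)
6. Return the final result
Constraints:
- Task depth=1 (a Task cannot spawn another Task)
- No communication between Tasks
- The Lead Expert holds all state
```
**Test scenarios**:
- The Lead Expert correctly decomposes the task into subtasks
- Parallel Tasks execute independently and return results
- The Lead Expert aggregates the results
- A single Task failure does not affect the other Tasks
- When all Tasks fail, fall back to the Lead Expert executing alone
- The @team prefix correctly triggers ExpertTeamMode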
**Verification**: `ruff check src/ && pytest tests/unit/experts/ -v`
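The hub-and-spoke flow and its failure handling can be sketched with plain `asyncio`. This is an illustration under assumptions, not the TeamOrchestrator implementation: `run_hub_and_spoke` and its callback parameters (`run_task`, `lead_fallback`, `pick_best`) are hypothetical names standing in for the Lead Expert and its Tasks.

```python
import asyncio
from typing import Awaitable, Callable, List

TaskFn = Callable[[str], Awaitable[str]]


async def run_hub_and_spoke(
    goal: str,
    subtasks: List[str],
    run_task: TaskFn,
    lead_fallback: TaskFn,
    pick_best: Callable[[List[str]], str],
) -> str:
    # Spawn all Tasks in parallel. return_exceptions=True collects failures
    # as values instead of raising the first one, so one failed Task does
    # not hide the results of the others.
    outcomes = await asyncio.gather(
        *(run_task(subtask) for subtask in subtasks),
        return_exceptions=True,
    )
    results = [o for o in outcomes if not isinstance(o, BaseException)]

    # All Tasks failed: the Lead Expert executes the goal on its own.
    if not results:
        return await lead_fallback(goal)

    # BEST strategy: the Lead Expert picks one result.
    return pick_best(results)
```

Depth=1 holds by construction here, because `run_task` receives no handle it could use to spawn further Tasks, and all state stays in the caller.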
---
### U4. PlanExec defaults to ReActStepExecutor + delete _LLMStepExecutor
**Goal**: PlanExecEngine uses ReActStepExecutor by default; delete _LLMStepExecutor
**Dependencies**: None
**Files**:
- `src/agentkit/core/plan_exec_engine.py`: delete _LLMStepExecutor and _LLMStepAgent; default step_executor_type="react"
- `tests/unit/core/test_plan_exec_engine.py`: update
**Approach**:
1. Delete `_LLMStepExecutor` and `_LLMStepAgent`
2. Remove the step_executor_type parameter from `_create_executor()` so that it always uses ReActStepExecutor
3. Clean up the related imports
**Test scenarios**:
- PlanExecEngine creates a ReActStepExecutor by default
- ReActStepExecutor correctly executes steps that involve tool calls
- A step failure triggers replanning
**Verification**: `ruff check src/ && pytest tests/unit/core/test_plan_exec_engine.py -v`
---
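### U5. SQLite persistence for chat history
**Goal**: Persist chat history in SQLite so that it survives a service restart
**Dependencies**: U1 (update call sites after the RequestPreprocessor rename is complete)
**Files**:
- `src/agentkit/chat/sqlite_conversation_store.py`: new
- `src/agentkit/server/routes/portal.py`: replace ConversationStore
- `src/agentkit/chat/conversation_store.py`: keep as the interface / in-memory implementation
**Approach**:
1. Add `SqliteConversationStore`, implementing the same interface as `ConversationStore`
2. SQLite table schema: conversations(id, session_id, role, content, timestamp, metadata)
3. Support querying by session_id, paginated loading, and search
4. Database file path: `~/.agentkit/conversations.db`
5. In portal.py, replace ConversationStore with SqliteConversationStore
6. Keep ConversationStore as the in-memory implementation (for tests)
**Test scenarios**:
- Messages are persisted to SQLite correctly
- Querying by session_id returns the full conversation
- Paginated loading works correctly
- Data survives a service restart
- The SQLite file is created automatically when it does not exist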
**Verification**: `ruff check src/ && pytest tests/unit/chat/test_sqlite_conversation_store.py -v`
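A minimal sketch of such a store, using only the standard-library `sqlite3` module, is shown below. The table columns follow the schema in step 2; the method names (`append`, `get_messages`, `search`, `close`), the column types, and the index are assumptions, since the `ConversationStore` interface is not shown in this plan.

```python
import json
import sqlite3
import time
from pathlib import Path


class SqliteConversationStore:
    """Sketch only: method names are assumed, not the real ConversationStore API."""

    def __init__(self, db_path: str = "~/.agentkit/conversations.db") -> None:
        path = Path(db_path).expanduser()
        # Create the parent directory so the database file can be created on first use.
        path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(str(path))
        # WAL mode lets readers proceed while a write is in progress.
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute(
            """CREATE TABLE IF NOT EXISTS conversations (
                   id INTEGER PRIMARY KEY AUTOINCREMENT,
                   session_id TEXT NOT NULL,
                   role TEXT NOT NULL,
                   content TEXT NOT NULL,
                   timestamp REAL NOT NULL,
                   metadata TEXT NOT NULL DEFAULT '{}'
               )"""
        )
        self._conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_conv_session "
            "ON conversations(session_id, id)"
        )
        self._conn.commit()

    def append(self, session_id: str, role: str, content: str, metadata=None) -> int:
        cur = self._conn.execute(
            "INSERT INTO conversations (session_id, role, content, timestamp, metadata) "
            "VALUES (?, ?, ?, ?, ?)",
            (session_id, role, content, time.time(), json.dumps(metadata or {})),
        )
        self._conn.commit()
        return cur.lastrowid

    def get_messages(self, session_id: str, limit: int = 50, offset: int = 0) -> list:
        # Paginated load in insertion order.
        rows = self._conn.execute(
            "SELECT role, content, metadata FROM conversations "
            "WHERE session_id = ? ORDER BY id LIMIT ? OFFSET ?",
            (session_id, limit, offset),
        ).fetchall()
        return [{"role": r, "content": c, "metadata": json.loads(m)} for r, c, m in rows]

    def search(self, keyword: str) -> list:
        # Plain substring search; full-text search is out of scope for this sketch.
        rows = self._conn.execute(
            "SELECT session_id, role, content FROM conversations WHERE content LIKE ?",
            (f"%{keyword}%",),
        ).fetchall()
        return [{"session_id": s, "role": r, "content": c} for s, r, c in rows]

    def close(self) -> None:
        self._conn.close()
```

Reopening the same file with a new `SqliteConversationStore` instance corresponds to the "survives a service restart" scenario.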
---
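### U6. Verifiable execution (Test-and-Verify loop)
**Goal**: After executing, ReActEngine can optionally run the project's tests automatically to verify the result
**Dependencies**: U4
**Files**:
- `src/agentkit/core/verification_loop.py`: new
- `src/agentkit/core/react.py`: integrate the verification loop
- `src/agentkit/tools/builtin.py`: add the run_tests tool
**Approach**:
1. Add `VerificationLoop`: run pytest/ruff/typecheck after execution and retry automatically on failure
2. The maximum number of retries is configurable (default 2)
3. Attach the verification result to ReActResult
4. Add a built-in `run_tests` tool that the LLM can call proactively
5. The verification loop is off by default and is enabled with the parameter `verification_enabled=True`
**Test scenarios**:
- With the verification loop off, behavior is unchanged
- With the verification loop on, tests run automatically after execution
- Return success when the tests pass
- Retry automatically when the tests fail
- Return a failure result once the maximum number of retries is reached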
**Verification**: `ruff check src/ && pytest tests/unit/core/test_verification_loop.py -v`
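The retry semantics (one initial run plus up to `max_retries` retries, off by default) can be sketched as follows. `run_with_verification`, `VerificationResult`, and the `execute`/`verify` callbacks are hypothetical stand-ins: in the real design `execute` would be a ReActEngine run and `verify` would invoke pytest/ruff/typecheck.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple


@dataclass
class VerificationResult:
    passed: Optional[bool]  # None means verification was disabled
    attempts: int
    logs: List[str] = field(default_factory=list)


def run_with_verification(
    execute: Callable[[Optional[str]], None],
    verify: Callable[[], Tuple[bool, str]],
    max_retries: int = 2,
    verification_enabled: bool = False,
) -> VerificationResult:
    logs: List[str] = []
    feedback: Optional[str] = None
    # One initial attempt plus at most max_retries retries.
    for attempt in range(1, max_retries + 2):
        # On a retry, the previous failure log is passed back as feedback.
        execute(feedback)
        if not verification_enabled:
            # Loop disabled: behave exactly like a single plain execution.
            return VerificationResult(None, attempt)
        ok, log = verify()
        logs.append(log)
        if ok:
            return VerificationResult(True, attempt, logs)
        feedback = log
    return VerificationResult(False, max_retries + 1, logs)
```

Feeding the failure log into the next attempt is one possible design; the plan only states that a failure triggers an automatic retry.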
---
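### U7. Tiered tool-description injection + tool_search
**Goal**: Inject core tools in full and extended tools as name + one-line description only; the LLM can fetch the full description through tool_search
**Dependencies**: None
**Files**:
- `src/agentkit/core/react.py`: modify `_build_tool_use_prompt`
- `src/agentkit/tools/builtin.py`: add the tool_search tool
- `src/agentkit/tools/search.py`: new, BM25 tool search
**Approach**:
1. Split tools into core (read/write/bash/search) and extended (everything else)
2. Inject core tools into the prompt in full
3. Inject only name + one-line description for extended tools
4. Add a `tool_search` tool: BM25 search over tool descriptions, returning the full description
5. The LLM calls tool_search on demand inside the agent loop
**Test scenarios**:
- Core tools appear in the prompt in full
- Extended tools appear with only a name and a one-line description
- tool_search returns the full tool description correctly
- BM25 search ranks results by relevance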
**Verification**: `ruff check src/ && pytest tests/unit/tools/test_tool_search.py -v`
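Both halves of this unit, the tiered prompt and a BM25 lookup, fit in a short standard-library sketch. `ToolSpec`, `build_tool_prompt`, the prompt layout, and the ASCII-only tokenizer are assumptions for illustration; the scoring is the common Okapi BM25 formula with the usual `k1` and `b` parameters.

```python
import math
import re
from collections import Counter
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ToolSpec:
    name: str
    summary: str            # one-line description
    full_description: str
    core: bool = False


def build_tool_prompt(tools: List[ToolSpec]) -> str:
    # Core tools get the full description; extended tools get one line.
    lines = []
    for tool in tools:
        if tool.core:
            lines.append(f"## {tool.name}\n{tool.full_description}")
        else:
            lines.append(f"- {tool.name}: {tool.summary}")
    return "\n".join(lines)


def _tokenize(text: str) -> List[str]:
    # ASCII-only tokenizer; CJK text would need a real segmenter.
    return re.findall(r"[a-z0-9_]+", text.lower())


def tool_search(query: str, tools: List[ToolSpec], top_k: int = 3,
                k1: float = 1.5, b: float = 0.75) -> List[ToolSpec]:
    docs = [_tokenize(f"{t.name} {t.full_description}") for t in tools]
    if not docs:
        return []
    n = len(docs)
    avgdl = sum(len(d) for d in docs) / n
    # Document frequency: in how many tool descriptions each term occurs.
    df = Counter(term for d in docs for term in set(d))
    query_terms = set(_tokenize(query))

    scored = []
    for tool, doc in zip(tools, docs):
        tf = Counter(doc)
        score = 0.0
        for term in query_terms:
            if term not in tf:
                continue
            idf = math.log((n - df[term] + 0.5) / (df[term] + 0.5) + 1.0)
            norm = tf[term] + k1 * (1 - b + b * len(doc) / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        if score > 0:
            scored.append((score, tool))

    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [tool for _, tool in scored[:top_k]]
```

A `tool_search` tool exposed to the LLM would return `full_description` for the matches, so the full text of an extended tool enters the context only when it is requested.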
---
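### U8. Enable context compression by default
**Goal**: ReActEngine enables sliding-window compression by default
**Dependencies**: None
**Files**:
- `src/agentkit/core/react.py`: change the default compressor parameter
- `src/agentkit/core/compressor.py`: confirm the sliding-window implementation
**Approach**:
1. In ReActEngine's `__init__`, change the compressor default from None to SlidingWindowCompressor
2. Keep the most recent N turns + the system prompt + the tool descriptions
3. N is configurable (default 10)
**Test scenarios**:
- Long conversations are compressed automatically
- The system prompt and tool descriptions are preserved after compression
- Compression does not affect the most recent N turns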
**Verification**: `ruff check src/ && pytest tests/unit/core/test_compressor.py -v`
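The sliding-window rule can be sketched in a few lines. `compress_sliding_window` is a hypothetical helper, not the compressor in `compressor.py`. It makes two simplifying assumptions: the tool descriptions live inside the system message, and N counts individual non-system messages rather than user/assistant pairs.

```python
from typing import Dict, List

Message = Dict[str, str]


def compress_sliding_window(messages: List[Message], keep_recent: int = 10) -> List[Message]:
    # System messages (assumed to carry the tool descriptions) are always kept.
    system = [m for m in messages if m["role"] == "system"]
    rest = [m for m in messages if m["role"] != "system"]
    if keep_recent <= 0:
        return system
    # Keep only the most recent keep_recent non-system messages.
    return system + rest[-keep_recent:]
```

Conversations shorter than the window pass through unchanged, so enabling this by default should not alter short sessions.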
---
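### U9. Spec documents as first-class citizens
**Goal**: Persist the plan generated by PlanExec as a Spec document that the user can view, edit, and confirm before execution
**Dependencies**: U4
**Files**:
- `src/agentkit/core/spec_manager.py`: new
- `src/agentkit/core/plan_exec_engine.py`: integrate SpecManager
- `src/agentkit/server/routes/tasks.py`: add Spec-related APIs
**Approach**:
1. Add `SpecManager`: manages CRUD for Spec documents
2. Spec file path: `.agentkit/specs/<plan_id>.yaml`
3. After PlanExecEngine generates a plan, persist it as a Spec first
4. Add APIs: `GET /api/v1/specs`, `GET /api/v1/specs/{id}`, `PUT /api/v1/specs/{id}`, `POST /api/v1/specs/{id}/confirm`
5. Execution starts only after the user confirms
**Test scenarios**:
- The plan is persisted correctly as a Spec file
- The Spec file can be read and edited
- An unconfirmed Spec is not executed
- Confirmation triggers execution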
**Verification**: `ruff check src/ && pytest tests/unit/core/test_spec_manager.py -v`
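The confirm-before-execute gate can be sketched as below. The `Spec` and `SpecManager` here are simplified stand-ins, not the classes in `spec_manager.py`: the `confirmed` field, the method names other than `create`, and `execute_if_confirmed` are assumptions, and the sketch writes JSON rather than the planned YAML so that it needs only the standard library.

```python
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Callable, List


@dataclass
class Spec:
    spec_id: str
    goal: str
    steps: List[str] = field(default_factory=list)
    confirmed: bool = False  # hypothetical flag used for the gate


class SpecManager:
    def __init__(self, root: Path) -> None:
        self._root = root
        self._root.mkdir(parents=True, exist_ok=True)

    def _path(self, spec_id: str) -> Path:
        # A real implementation should validate spec_id before using it in a path.
        return self._root / f"{spec_id}.json"

    def create(self, spec: Spec) -> None:
        text = json.dumps(asdict(spec), ensure_ascii=False, indent=2)
        self._path(spec.spec_id).write_text(text, encoding="utf-8")

    def get(self, spec_id: str) -> Spec:
        return Spec(**json.loads(self._path(spec_id).read_text(encoding="utf-8")))

    def update(self, spec: Spec) -> None:
        # Editing a Spec overwrites the stored document.
        self.create(spec)

    def confirm(self, spec_id: str) -> Spec:
        spec = self.get(spec_id)
        spec.confirmed = True
        self.create(spec)
        return spec


def execute_if_confirmed(manager: SpecManager, spec_id: str,
                         run: Callable[[Spec], None]) -> bool:
    # Always re-read from disk so user edits made before confirmation are honored.
    spec = manager.get(spec_id)
    if not spec.confirmed:
        return False
    run(spec)
    return True
```

An auto-confirm setting could call `confirm()` immediately after `create()`, which keeps automated flows unblocked.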
---
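### U10. Unified event model (SQ/EQ dual queues)
**Goal**: Unify the CLI and WebSocket event models into SQ/EQ dual queues
**Dependencies**: U3
**Files**:
- `src/agentkit/core/protocol.py`: add SQ/EQ event types
- `src/agentkit/server/routes/portal.py`: hook up to the EQ
- `src/agentkit/cli/chat.py`: hook up to the EQ
**Approach**:
1. Define SubmissionQueue (user input) and EventQueue (agent output)
2. Event types: a three-level Session/Task/Turn model
3. The Portal WebSocket and the CLI share the same event stream
4. The frontend can render uniformly
**Test scenarios**:
- The SQ receives user input correctly
- The EQ pushes agent events correctly
- WebSocket and CLI share the event stream
- Event types are classified correctly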
**Verification**: `ruff check src/ && pytest tests/unit/core/test_protocol.py -v`
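One way to picture the dual-queue model is a submission queue feeding a worker that publishes events to every subscriber. The sketch below uses `asyncio.Queue`; `Submission`, `Event`, `EventBus`, `agent_worker`, and the `None` shutdown sentinel are all hypothetical names, and the worker only echoes input where a real agent loop would run.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Submission:
    session_id: str
    text: str


@dataclass
class Event:
    # Session / Task / Turn identify where the event belongs.
    session_id: str
    task_id: str
    turn: int
    kind: str
    data: Dict[str, Any] = field(default_factory=dict)


class EventBus:
    """Fans each agent event out to every subscriber (e.g. WebSocket and CLI)."""

    def __init__(self) -> None:
        self._subscribers: List[asyncio.Queue] = []

    def subscribe(self) -> asyncio.Queue:
        queue: asyncio.Queue = asyncio.Queue()
        self._subscribers.append(queue)
        return queue

    async def publish(self, event: Event) -> None:
        for queue in self._subscribers:
            await queue.put(event)


async def agent_worker(sq: asyncio.Queue, bus: EventBus) -> None:
    turn = 0
    while True:
        submission = await sq.get()
        if submission is None:  # sentinel: shut down the worker
            break
        turn += 1
        task_id = f"task-{turn}"
        await bus.publish(Event(submission.session_id, task_id, turn,
                                "turn_started", {"text": submission.text}))
        # Placeholder for the real agent loop: echo the input in upper case.
        await bus.publish(Event(submission.session_id, task_id, turn,
                                "turn_completed", {"reply": submission.text.upper()}))
```

Because every front end reads the same `Event` objects, rendering logic can be shared between the Portal and the CLI.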
---
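## Scope Boundaries
### In Scope
- The 10 Implementation Units above
- Unit test coverage
### Out of Scope
- DockerComputerUseSession implementation (P3, pending validation of user demand)
- Frontend component updates (a later iteration)
- Agent configuration hot reload (P4)
- Progressive context loading (P4)
- Multi-dimensional triggers for Soul evolution (closed)
- OTel instrumentation (deferred)
### Deferred to Follow-Up Work
- Wire the frontend ExpertTeamView to real data
- SWE-bench end-to-end validation
- Performance monitoring and cost tracking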
---
## Risks & Dependencies
| Risk | Impact | Mitigation |
|------|--------|------------|
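| U3: the expert-team rewrite may affect existing @team functionality | Medium | Keep ExpertConfig/ExpertTemplate/Registry; rewrite only the Orchestrator |
| U5: SQLite may see lock contention under high concurrency | Low | Write frequency is low in chat scenarios; SQLite WAL mode is sufficient |
| U7: tool_search may add LLM call rounds | Medium | Core tools are injected in full; only extended tools need a search |
| U9: Spec documents may add steps for the user | Low | Auto-confirm by default (configurable), so automated flows are not blocked |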
---
## Phased Delivery
**Phase 1 (cleanup and wrap-up)**: U1 → U2 → U4 → U8
**Phase 2 (core capabilities)**: U5 → U6 → U9
**Phase 3 (multi-agent)**: U3 → U7 → U10


@ -1,4 +1,4 @@
"""Simple router — minimal routing layer for unified REACT agent loop. """Request preprocessor — minimal preprocessing layer for unified REACT agent loop.
Replaces the 4-layer CostAwareRouter with a simple approach: Replaces the 4-layer CostAwareRouter with a simple approach:
1. @skill:xxx prefix explicit skill selection 1. @skill:xxx prefix explicit skill selection
@ -53,15 +53,15 @@ _IDENTITY_RE = re.compile(
) )
class SimpleRouter: class RequestPreprocessor:
"""Minimal routing layer: regex fast-path + default REACT. """Minimal preprocessing layer: regex fast-path + default REACT.
Design rationale: Design rationale:
- No HeuristicClassifier: keyword enumeration can never cover all colloquial expressions - No HeuristicClassifier: keyword enumeration can never cover all colloquial expressions
- No IntentRouter: LLM blind-classification without tool context is unreliable - No IntentRouter: LLM blind-classification without tool context is unreliable
- No SemanticRouter: embedding similarity is not intent recognition - No SemanticRouter: embedding similarity is not intent recognition
- LLM in the REACT agent loop sees full tool descriptions and decides autonomously - LLM in the REACT agent loop sees full tool descriptions and decides autonomously
- This matches Codex/Trae/Hermes architecture: unified agent loop, no routing layer - This matches Codex/Trae/Hermes architecture: unified agent loop, no preprocessing layer
""" """
def __init__( def __init__(
@ -78,7 +78,7 @@ class SimpleRouter:
self._default_model = default_model self._default_model = default_model
self._default_agent_name = default_agent_name self._default_agent_name = default_agent_name
async def route( async def preprocess(
self, self,
content: str, content: str,
*, *,
@ -90,7 +90,7 @@ class SimpleRouter:
session_id: str = "", session_id: str = "",
transparency: str = "SILENT", transparency: str = "SILENT",
) -> SkillRoutingResult: ) -> SkillRoutingResult:
"""Route user input to the appropriate execution path. """Preprocess user input to determine the appropriate execution path.
Decision tree: Decision tree:
1. @skill:xxx prefix explicit skill (SKILL_REACT or skill's execution_mode) 1. @skill:xxx prefix explicit skill (SKILL_REACT or skill's execution_mode)
@ -99,21 +99,25 @@ class SimpleRouter:
""" """
registry = skill_registry or self._skill_registry registry = skill_registry or self._skill_registry
tools = default_tools if default_tools is not None else self._default_tools tools = default_tools if default_tools is not None else self._default_tools
sys_prompt = default_system_prompt if default_system_prompt is not None else self._default_system_prompt sys_prompt = (
default_system_prompt
if default_system_prompt is not None
else self._default_system_prompt
)
model = default_model or self._default_model model = default_model or self._default_model
agent_name = default_agent_name or self._default_agent_name agent_name = default_agent_name or self._default_agent_name
# --- Layer 0: @skill:xxx prefix --- # --- Layer 0: @skill:xxx prefix ---
explicit_skill, clean_content = parse_skill_prefix(content) explicit_skill, clean_content = parse_skill_prefix(content)
if explicit_skill and registry is not None: if explicit_skill and registry is not None:
result = self._route_explicit_skill( result = self._resolve_explicit_skill(
explicit_skill, clean_content, registry, model, agent_name explicit_skill, clean_content, registry, model, agent_name
) )
return result return result
# --- Layer 1: Greeting/chitchat/identity regex (<1ms, zero tokens) --- # --- Layer 1: Greeting/chitchat/identity regex (<1ms, zero tokens) ---
stripped = content.strip() stripped = content.strip()
if self._is_direct_chat(stripped): if self._is_trivial_input(stripped):
result = SkillRoutingResult( result = SkillRoutingResult(
clean_content=stripped, clean_content=stripped,
matched=False, matched=False,
@ -141,7 +145,7 @@ class SimpleRouter:
) )
return result return result
def _route_explicit_skill( def _resolve_explicit_skill(
self, self,
skill_name: str, skill_name: str,
clean_content: str, clean_content: str,
@ -149,7 +153,7 @@ class SimpleRouter:
model: str, model: str,
agent_name: str, agent_name: str,
) -> SkillRoutingResult: ) -> SkillRoutingResult:
"""Route to an explicitly specified skill via @skill:xxx prefix.""" """Resolve an explicitly specified skill via @skill:xxx prefix."""
try: try:
skill = registry.get(skill_name) skill = registry.get(skill_name)
except Exception: except Exception:
@ -185,13 +189,11 @@ class SimpleRouter:
) )
@staticmethod @staticmethod
def _is_direct_chat(text: str) -> bool: def _is_trivial_input(text: str) -> bool:
"""Check if the input is a greeting, chitchat, or identity question. """Check if the input is a greeting, chitchat, or identity question.
These are zero-cost direct chat: no tool usage, no ReAct loop needed. These are zero-cost direct chat: no tool usage, no ReAct loop needed.
""" """
return bool( return bool(
_GREETING_RE.match(text) _GREETING_RE.match(text) or _CHAT_MODE_RE.match(text) or _IDENTITY_RE.match(text)
or _CHAT_MODE_RE.match(text)
or _IDENTITY_RE.match(text)
) )


@ -1,224 +0,0 @@
"""Semantic Router — Embedding-based intent routing as Layer 1.5.
Uses pre-computed skill embeddings for zero-cost semantic matching,
inserted between Layer 1 (HeuristicClassifier) and Layer 2 (LLM classification)
in CostAwareRouter.
Design doc: docs/plans/2026-06-14-004-u3-semantic-router.md
"""
import logging
from dataclasses import dataclass
from typing import Any
from agentkit.memory.embedder import Embedder, EmbeddingCache
from agentkit.utils.vector_math import compute_cosine_similarity
logger = logging.getLogger(__name__)
@dataclass
class SemanticRouteResult:
"""Result of semantic routing."""
confidence: str # "high" | "medium" | "low"
skill_name: str | None
similarity: float
class SkillEmbeddingIndex:
"""Pre-computed embedding index for registered skills.
Embeddings are computed at skill registration time and cached.
Query-time search is O(n) cosine similarity scan, which is fast
for <100 skills with 1024-1536 dim vectors.
"""
def __init__(self, embedder: Embedder):
self._embedder = embedder
# skill_name → (embedding, source_text)
self._index: dict[str, tuple[list[float], str]] = {}
async def build(self, skill_registry: Any) -> None:
"""Build index from all registered skills."""
if skill_registry is None:
return
skills = skill_registry.list_skills()
for skill in skills:
await self.update_skill(skill.config.name, skill)
async def update_skill(self, skill_name: str, skill: Any) -> None:
"""Re-embed a single skill (on registration/update)."""
source_text = self._build_source_text(skill)
try:
embedding = await self._embedder.embed(source_text)
self._index[skill_name] = (embedding, source_text)
except Exception as e:
logger.warning(f"Failed to embed skill '{skill_name}': {e}")
def remove_skill(self, skill_name: str) -> None:
"""Remove a skill from the index."""
self._index.pop(skill_name, None)
async def search(self, query_embedding: list[float], top_k: int = 5) -> list[tuple[str, float]]:
"""Search for skills matching the query embedding.
Returns:
List of (skill_name, similarity) sorted by similarity descending.
"""
if not self._index:
return []
results: list[tuple[str, float]] = []
for skill_name, (emb, _) in self._index.items():
sim = compute_cosine_similarity(query_embedding, emb)
results.append((skill_name, sim))
results.sort(key=lambda x: x[1], reverse=True)
return results[:top_k]
@staticmethod
def _build_source_text(skill: Any) -> str:
"""Build embedding source text from skill metadata.
Combines description, intent keywords, and capability tags
for rich semantic representation.
"""
config = skill.config if hasattr(skill, "config") else skill
parts = []
# Description
description = getattr(config, "description", "") or ""
if description:
parts.append(description)
# Intent keywords
intent = getattr(config, "intent", None)
if intent and hasattr(intent, "keywords") and intent.keywords:
parts.append(" ".join(intent.keywords))
# Intent examples (rich semantic signal for short queries)
if intent and hasattr(intent, "examples") and intent.examples:
parts.append(" ".join(intent.examples))
# Capability tags
capabilities = getattr(config, "capabilities", None)
if capabilities:
tags = []
for cap in capabilities:
if isinstance(cap, str):
tags.append(cap)
elif isinstance(cap, dict):
tags.append(cap.get("tag", ""))
elif hasattr(cap, "tag"):
tags.append(cap.tag)
if tags:
parts.append(" ".join(t for t in tags if t))
# Fallback: use skill name if no other text available
if not parts:
parts.append(getattr(config, "name", "unknown"))
return " | ".join(parts)
@property
def size(self) -> int:
"""Number of skills in the index."""
return len(self._index)
class SemanticRouter:
"""Embedding-based semantic routing as Layer 1.5.
Three confidence zones:
- similarity > similarity_high (0.85): HIGH direct skill match, skip Layer 2
- similarity_low (0.4) <= similarity <= similarity_high: MEDIUM skill hint for Layer 2
- similarity < similarity_low (0.4): LOW no semantic signal, normal routing
Short text (<20 chars) uses a lower effective threshold because
brief queries naturally have lower embedding similarity.
"""
_SHORT_TEXT_THRESHOLD = 20 # chars
def __init__(
self,
embedder: Embedder,
similarity_high: float = 0.85,
similarity_low: float = 0.4,
):
self._embedder = embedder
self._similarity_high = similarity_high
self._similarity_low = similarity_low
self._index = SkillEmbeddingIndex(embedder)
self._query_cache = EmbeddingCache(max_size=500, ttl=1800)
async def build_index(self, skill_registry: Any) -> None:
"""Build skill embedding index from registry."""
await self._index.build(skill_registry)
logger.info(f"Semantic router index built: {self._index.size} skills")
async def update_skill(self, skill_name: str, skill: Any) -> None:
"""Update a single skill's embedding."""
await self._index.update_skill(skill_name, skill)
def remove_skill(self, skill_name: str) -> None:
"""Remove a skill from the index."""
self._index.remove_skill(skill_name)
async def route(self, query: str) -> SemanticRouteResult:
"""Route a query using semantic similarity.
Args:
query: User's input text.
Returns:
SemanticRouteResult with confidence, skill_name, and similarity.
"""
if self._index.size == 0:
return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)
if not query or not query.strip():
return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)
try:
# Get query embedding (with cache)
query_embedding = self._query_cache.get(query)
if query_embedding is None:
query_embedding = await self._embedder.embed(query)
self._query_cache.put(query, query_embedding)
# Search skill index
results = await self._index.search(query_embedding, top_k=1)
if not results:
return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)
best_skill, best_sim = results[0]
# Short text uses lower effective threshold
effective_low = self._similarity_low
if len(query) < self._SHORT_TEXT_THRESHOLD:
effective_low = max(0.25, self._similarity_low - 0.15)
if best_sim >= self._similarity_high:
return SemanticRouteResult(
confidence="high",
skill_name=best_skill,
similarity=best_sim,
)
elif best_sim >= effective_low:
return SemanticRouteResult(
confidence="medium",
skill_name=best_skill,
similarity=best_sim,
)
else:
return SemanticRouteResult(
confidence="low",
skill_name=None,
similarity=best_sim,
)
except Exception as e:
logger.warning(f"Semantic routing failed, returning low confidence: {e}")
return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)

File diff suppressed because it is too large


@ -96,11 +96,10 @@ async def _chat_async(
WebCrawlTool(), WebCrawlTool(),
] ]
# ── Load skills and build IntentRouter ─────────────────────── # ── Load skills ────────────────────────────────────────────
from agentkit.tools.registry import ToolRegistry from agentkit.tools.registry import ToolRegistry
from agentkit.skills.registry import SkillRegistry from agentkit.skills.registry import SkillRegistry
from agentkit.skills.loader import SkillLoader from agentkit.skills.loader import SkillLoader
from agentkit.router.intent import IntentRouter
tool_registry = ToolRegistry() tool_registry = ToolRegistry()
for tool in tools: for tool in tools:
@ -123,8 +122,6 @@ async def _chat_async(
except Exception: except Exception:
pass pass
intent_router = IntentRouter(llm_gateway=gateway) if skill_registry.list_skills() else None
# Build system prompt — inject memory into system prompt # Build system prompt — inject memory into system prompt
base_prompt = system_prompt or ( base_prompt = system_prompt or (
"你是一个有帮助的AI助手。请记住我们对话的上下文并在后续对话中引用之前的内容。回答要清晰简洁请使用中文回复。" "你是一个有帮助的AI助手。请记住我们对话的上下文并在后续对话中引用之前的内容。回答要清晰简洁请使用中文回复。"
@ -218,7 +215,6 @@ async def _chat_async(
routing = await resolve_skill_routing( routing = await resolve_skill_routing(
content=user_input, content=user_input,
skill_registry=skill_registry, skill_registry=skill_registry,
intent_router=intent_router,
default_tools=tools, default_tools=tools,
default_system_prompt=effective_system_prompt, default_system_prompt=effective_system_prompt,
default_model=current_model, default_model=current_model,


@ -25,6 +25,7 @@ from agentkit.core.plan_schema import ExecutionPlan, PlanStep, PlanStepStatus
from agentkit.core.protocol import CancellationToken, TaskMessage, TaskResult, TaskStatus from agentkit.core.protocol import CancellationToken, TaskMessage, TaskResult, TaskStatus
from agentkit.core.react import ReActEvent, ReActResult, ReActStep from agentkit.core.react import ReActEvent, ReActResult, ReActStep
from agentkit.core.shared_workspace import SharedWorkspace from agentkit.core.shared_workspace import SharedWorkspace
from agentkit.core.spec_manager import Spec, SpecManager, SpecStep
from agentkit.orchestrator.reflection import PipelineReflector, PipelineReplanner from agentkit.orchestrator.reflection import PipelineReflector, PipelineReplanner
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageResult, StageStatus from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageResult, StageStatus
@ -73,6 +74,7 @@ class PlanExecEngine:
default_timeout: float = 300.0, default_timeout: float = 300.0,
workspace: SharedWorkspace | None = None, workspace: SharedWorkspace | None = None,
step_event_callback: "Callable[[str, dict[str, Any]], Awaitable[None]] | None" = None, step_event_callback: "Callable[[str, dict[str, Any]], Awaitable[None]] | None" = None,
spec_manager: SpecManager | None = None,
): ):
""" """
Args: Args:
@ -81,12 +83,14 @@ class PlanExecEngine:
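default_timeout: default timeout in seconds default_timeout: default timeout in seconds
workspace: SharedWorkspace instance used to pass state between steps workspace: SharedWorkspace instance used to pass state between steps
step_event_callback: step event callback used to push progress during non-streaming execution step_event_callback: step event callback used to push progress during non-streaming execution
spec_manager: SpecManager instance used to persist the execution plan as a Spec document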
""" """
self._llm_gateway = llm_gateway self._llm_gateway = llm_gateway
self._max_replans = max_replans self._max_replans = max_replans
self._default_timeout = default_timeout self._default_timeout = default_timeout
self._workspace = workspace self._workspace = workspace
self._step_event_callback = step_event_callback self._step_event_callback = step_event_callback
self._spec_manager = spec_manager
self._confirmation_handler: Any | None = None self._confirmation_handler: Any | None = None
# Compose sub-components # Compose sub-components
@ -261,6 +265,17 @@ class PlanExecEngine:
tokens=0, tokens=0,
)) ))
# Persist plan as Spec if spec_manager is provided
if self._spec_manager is not None:
spec = self._plan_to_spec(plan)
self._spec_manager.create(spec)
state.step_counter += 1
yield ReActEvent(
event_type="spec_created",
step=state.step_counter,
data={"spec_id": spec.spec_id, "goal": spec.goal, "num_steps": len(spec.steps)},
)
# ── Phase 2 & 3: Execute with optional replanning ── # ── Phase 2 & 3: Execute with optional replanning ──
current_plan = plan current_plan = plan
replan_count = 0 replan_count = 0
@ -509,6 +524,20 @@ class PlanExecEngine:
tokens=0, tokens=0,
)) ))
# Persist plan as Spec if spec_manager is provided
if self._spec_manager is not None:
spec = self._plan_to_spec(plan)
self._spec_manager.create(spec)
if self._step_event_callback:
try:
await self._step_event_callback("spec_created", {
"spec_id": spec.spec_id,
"goal": spec.goal,
"num_steps": len(spec.steps),
})
except Exception as e:
logger.warning(f"Step event callback failed: {e}")
if trace_recorder is not None: if trace_recorder is not None:
trace_recorder.record_step( trace_recorder.record_step(
step=1, step=1,
@ -613,7 +642,7 @@ class PlanExecEngine:
task_id=task_id, task_id=task_id,
) )
# Create PlanExecutor (direct LLM call mode) # Create PlanExecutor
executor = self._create_executor( executor = self._create_executor(
messages=messages, messages=messages,
model=model, model=model,
@ -734,6 +763,25 @@ class PlanExecEngine:
# Helper methods # Helper methods
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@staticmethod
def _plan_to_spec(plan: ExecutionPlan) -> Spec:
"""Convert an ExecutionPlan to a Spec for persistence."""
steps = [
SpecStep(
step_id=s.step_id,
name=s.name,
description=s.description,
dependencies=s.dependencies,
)
for s in plan.steps
]
return Spec(
spec_id=plan.plan_id,
goal=plan.goal,
steps=steps,
metadata=plan.metadata,
)
@staticmethod @staticmethod
def _extract_goal(messages: list[dict[str, str]]) -> str: def _extract_goal(messages: list[dict[str, str]]) -> str:
"""Extract the user's goal from the message list""" """Extract the user's goal from the message list"""
@ -779,23 +827,8 @@ class PlanExecEngine:
model: str, model: str,
system_prompt: str | None, system_prompt: str | None,
tools: list["Tool"] | None, tools: list["Tool"] | None,
step_executor_type: str = "react",
) -> PlanExecutor:
"""Create a PlanExecutor instance """Create a PlanExecutor instance that executes steps with ReActStepExecutor"""
Args:
step_executor_type: "react" uses ReActStepExecutor (default, supports tool calls)
"llm" uses _LLMStepExecutor (pure LLM calls, no tools)
"""
if step_executor_type == "llm":
step_executor: _LLMStepExecutor | ReActStepExecutor = _LLMStepExecutor(
llm_gateway=self._llm_gateway,
messages=messages,
model=model,
system_prompt=system_prompt,
tools=tools,
)
else:
step_executor = ReActStepExecutor( step_executor = ReActStepExecutor(
llm_gateway=self._llm_gateway, llm_gateway=self._llm_gateway,
messages=messages, messages=messages,
@ -937,58 +970,6 @@ class PlanExecEngine:
return "\n\n".join(parts) return "\n\n".join(parts)
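class _LLMStepExecutor:
"""Step executor that calls the LLM directly.
Serves as a drop-in replacement for the PlanExecutor's agent_pool,
so that each PlanStep is executed through a direct LLM call rather than through AgentPool.
"""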
def __init__(
self,
llm_gateway: "LLMGateway | None" = None,
messages: list[dict[str, str]] | None = None,
model: str = "default",
system_prompt: str | None = None,
tools: list["Tool"] | None = None,
):
self._llm_gateway = llm_gateway
self._messages = messages or []
self._model = model
self._system_prompt = system_prompt
self._tools = tools
self._agents: dict[str, _LLMStepAgent] = {}
async def create_agent_from_skill(self, skill_name: str):
"""Create an LLM step agent."""
agent = _LLMStepAgent(
name=skill_name,
llm_gateway=self._llm_gateway,
messages=self._messages,
model=self._model,
system_prompt=self._system_prompt,
tools=self._tools,
)
self._agents[skill_name] = agent
return agent
def get_agent(self, key: str):
"""Return a previously created agent."""
if key in self._agents:
return self._agents[key]
# Fallback: create a default agent
agent = _LLMStepAgent(
name=key,
llm_gateway=self._llm_gateway,
messages=self._messages,
model=self._model,
system_prompt=self._system_prompt,
tools=self._tools,
)
self._agents[key] = agent
return agent
class ReActStepExecutor:
"""Step executor driven by a ReAct loop
@ -1132,69 +1113,3 @@ class _ReActStepAgent:
started_at=now, started_at=now,
completed_at=now, completed_at=now,
) )
class _LLMStepAgent:
"""Step agent that calls the LLM directly.
Sends the PlanStep description to the LLM as the prompt and
returns the LLM response as the step result.
"""
def __init__(
self,
name: str,
llm_gateway: "LLMGateway | None" = None,
messages: list[dict[str, str]] | None = None,
model: str = "default",
system_prompt: str | None = None,
tools: list["Tool"] | None = None,
):
self.name = name
self._llm_gateway = llm_gateway
self._messages = messages or []
self._model = model
self._system_prompt = system_prompt
self._tools = tools
async def execute(self, task_msg: TaskMessage) -> "TaskResult":
"""Execute the step through a direct LLM call."""
if self._llm_gateway is None:
raise RuntimeError(f"No LLM gateway available for step '{task_msg.task_id}'")
# Build the step prompt
input_data = task_msg.input_data
step_name = input_data.get("step_name", task_msg.task_id)
step_description = input_data.get("step_description", "")
dep_results = input_data.get("dependency_results", {})
prompt_parts = [f"Execute the following task step:\n\nStep: {step_name}\nDescription: {step_description}"]
if dep_results:
prompt_parts.append(f"\nResults from previous steps:\n{json.dumps(dep_results, ensure_ascii=False, indent=2)}")
prompt_parts.append("\nProvide a clear, structured result for this step.")
conversation: list[dict[str, Any]] = []
if self._system_prompt:
conversation.append({"role": "system", "content": self._system_prompt})
# Append the original conversation context
for msg in self._messages:
conversation.append(msg)
conversation.append({"role": "user", "content": "\n".join(prompt_parts)})
response = await self._llm_gateway.chat(
messages=conversation,
model=self._model,
)
now = datetime.now(timezone.utc)
return TaskResult(
task_id=task_msg.task_id,
agent_name=self.name,
status=TaskStatus.COMPLETED.value,
output_data={"content": response.content or ""},
error_message=None,
started_at=now,
completed_at=now,
)


@ -74,15 +74,53 @@ class ReActEngine:
Enables the agent to reason autonomously and choose tools to complete the task.
""" """
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool | str = False): # Default core tools that always get full descriptions injected into the
# prompt. ``tool_search`` is included so its full description is always
# available to the LLM when tiered injection is active.
_DEFAULT_CORE_TOOLS: tuple[str, ...] = (
"read_file",
"write_file",
"bash",
"search",
"tool_search",
)
def __init__(
self,
llm_gateway: LLMGateway,
max_steps: int = 10,
default_timeout: float = 300.0,
parallel_tools: bool | str = False,
compressor: "CompressionStrategy | None" = None,
verification_enabled: bool = False,
verification_commands: list[str] | None = None,
core_tool_names: list[str] | None = None,
enable_tool_search: bool = True,
):
if max_steps < 1: if max_steps < 1:
raise ValueError(f"max_steps must be >= 1, got {max_steps}") raise ValueError(f"max_steps must be >= 1, got {max_steps}")
if isinstance(parallel_tools, str) and parallel_tools not in ("auto",): if isinstance(parallel_tools, str) and parallel_tools not in ("auto",):
raise ValueError(f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}") raise ValueError(
f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}"
)
self._llm_gateway = llm_gateway self._llm_gateway = llm_gateway
self._max_steps = max_steps self._max_steps = max_steps
self._default_timeout = default_timeout self._default_timeout = default_timeout
self._parallel_tools = parallel_tools self._parallel_tools = parallel_tools
self._verification_enabled = verification_enabled
self._verification_commands = verification_commands
# Tiered tool description injection config
self._core_tool_names: tuple[str, ...] | None = (
tuple(core_tool_names) if core_tool_names is not None else None
)
self._enable_tool_search = enable_tool_search
# Default context compression: keep last 10 turns
if compressor is not None:
self._compressor = compressor
else:
from agentkit.core.compressor import ContextCompressor
self._compressor = ContextCompressor(llm_gateway=llm_gateway, keep_recent=10)
def reset(self) -> None: def reset(self) -> None:
"""Reset internal state for reuse across conversations. """Reset internal state for reuse across conversations.
@ -120,10 +158,14 @@ class ReActEngine:
4. Return a ReActResult containing the output and the trajectory.
Args:
compressor: Compression strategy; when None, the instance's default compressor is used.
cancellation_token: Cooperative cancellation token, checked on every loop iteration.
timeout_seconds: Timeout in seconds; 0 means no timeout, None falls back to default_timeout.
""" """
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout effective_compressor = compressor if compressor is not None else self._compressor
effective_timeout = (
timeout_seconds if timeout_seconds is not None else self._default_timeout
)
try: try:
if effective_timeout > 0: if effective_timeout > 0:
@ -138,7 +180,7 @@ class ReActEngine:
trace_recorder=trace_recorder, trace_recorder=trace_recorder,
memory_retriever=memory_retriever, memory_retriever=memory_retriever,
task_id=task_id, task_id=task_id,
compressor=compressor, compressor=effective_compressor,
retrieval_config=retrieval_config, retrieval_config=retrieval_config,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
confirmation_handler=confirmation_handler, confirmation_handler=confirmation_handler,
@ -156,7 +198,7 @@ class ReActEngine:
trace_recorder=trace_recorder, trace_recorder=trace_recorder,
memory_retriever=memory_retriever, memory_retriever=memory_retriever,
task_id=task_id, task_id=task_id,
compressor=compressor, compressor=effective_compressor,
retrieval_config=retrieval_config, retrieval_config=retrieval_config,
cancellation_token=cancellation_token, cancellation_token=cancellation_token,
confirmation_handler=confirmation_handler, confirmation_handler=confirmation_handler,
@ -188,6 +230,8 @@ class ReActEngine:
confirmation_handler: Any | None = None, confirmation_handler: Any | None = None,
) -> ReActResult: ) -> ReActResult:
tools = tools or [] tools = tools or []
if tools:
tools = self._maybe_add_tool_search(tools)
tool_schemas = self._build_tool_schemas(tools) if tools else None tool_schemas = self._build_tool_schemas(tools) if tools else None
if tool_schemas: if tool_schemas:
tool_names = [s["function"]["name"] for s in tool_schemas] tool_names = [s["function"]["name"] for s in tool_schemas]
@ -205,7 +249,9 @@ class ReActEngine:
system_prompt = self._build_tool_use_prompt(tools) system_prompt = self._build_tool_use_prompt(tools)
# Telemetry: record agent request # Telemetry: record agent request
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"}) agent_request_counter().add(
1, {"agent.name": agent_name, "agent.type": task_type or "react"}
)
# Start telemetry span for the entire agent execution # Start telemetry span for the entire agent execution
_span_cm = None _span_cm = None
@ -250,7 +296,9 @@ class ReActEngine:
else: else:
system_prompt = f"## 参考信息\n{memory_context}" system_prompt = f"## 参考信息\n{memory_context}"
except Exception as e: except Exception as e:
logger.warning(f"Memory retrieval failed, continuing without context: {e}", exc_info=True) logger.warning(
f"Memory retrieval failed, continuing without context: {e}", exc_info=True
)
# Build the initial messages
conversation: list[dict[str, Any]] = [] conversation: list[dict[str, Any]] = []
@ -263,7 +311,9 @@ class ReActEngine:
try: try:
conversation = await compressor.compress(conversation) conversation = await compressor.compress(conversation)
except Exception as e: except Exception as e:
logger.warning(f"Context compression failed, continuing with original messages: {e}") logger.warning(
f"Context compression failed, continuing with original messages: {e}"
)
trace_outcome = "success" trace_outcome = "success"
step = 0 step = 0
@ -323,9 +373,19 @@ class ReActEngine:
# Execute tool calls
if self._parallel_tools == "auto" and len(response.tool_calls) > 1: if self._parallel_tools == "auto" and len(response.tool_calls) > 1:
# Auto mode: mixed parallel/serial based on _parallelizable flag # Auto mode: mixed parallel/serial based on _parallelizable flag
parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls)) parallelizable_set = set(
serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set] self._get_parallelizable_indices(response.tool_calls)
parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set] )
serial_calls = [
(i, tc)
for i, tc in enumerate(response.tool_calls)
if i not in parallelizable_set
]
parallel_calls = [
(i, tc)
for i, tc in enumerate(response.tool_calls)
if i in parallelizable_set
]
# Result slots indexed by original position # Result slots indexed by original position
all_results: list[Any] = [None] * len(response.tool_calls) all_results: list[Any] = [None] * len(response.tool_calls)
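The "auto" branch above splits tool calls into a serial group and a parallelizable group, then writes every result into a slot keyed by the call's original index so the conversation keeps the order the LLM requested. A minimal sketch of that pattern follows. `run_mixed` and the `execute` callback are hypothetical names, tool calls are simplified to `(name, arguments)` tuples, and the error dictionary shape mirrors the messages seen elsewhere in this diff rather than a confirmed contract.

```python
import asyncio
from typing import Any, Awaitable, Callable

ToolCall = tuple[str, dict[str, Any]]
Executor = Callable[[str, dict[str, Any]], Awaitable[Any]]


async def run_mixed(
    calls: list[ToolCall],
    parallelizable: set[int],
    execute: Executor,
) -> list[Any]:
    """Run non-parallelizable calls one by one, the rest concurrently."""
    results: list[Any] = [None] * len(calls)  # slots indexed by original position
    serial = [(i, c) for i, c in enumerate(calls) if i not in parallelizable]
    parallel = [(i, c) for i, c in enumerate(calls) if i in parallelizable]

    # Serial group first, in request order.
    for i, (name, args) in serial:
        try:
            results[i] = await execute(name, args)
        except Exception as exc:
            results[i] = {"error": f"Tool '{name}' execution failed: {exc}"}

    # Parallel group: return_exceptions=True keeps one failure from
    # cancelling the other calls; failures come back as exception objects.
    gathered = await asyncio.gather(
        *(execute(name, args) for _, (name, args) in parallel),
        return_exceptions=True,
    )
    for (i, (name, _)), outcome in zip(parallel, gathered):
        if isinstance(outcome, BaseException):
            results[i] = {"error": f"Tool '{name}' execution failed: {outcome}"}
        else:
            results[i] = outcome
    return results


async def fake_tool(name: str, args: dict[str, Any]) -> Any:
    # Hypothetical tool runner used only for this demonstration.
    if name == "boom":
        raise ValueError("bad input")
    await asyncio.sleep(0)
    return {"output": name}


mixed = asyncio.run(
    run_mixed([("write", {}), ("boom", {}), ("read", {})], {1, 2}, fake_tool)
)
```

Pre-sizing the result list and filling it by index avoids sorting afterwards and makes a missing result easy to spot as a leftover `None`.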
@ -340,7 +400,10 @@ class ReActEngine:
# Execute parallelizable tools in parallel # Execute parallelizable tools in parallel
if len(parallel_calls) > 1: if len(parallel_calls) > 1:
para_results = await asyncio.gather( para_results = await asyncio.gather(
*[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls], *[
self._execute_tool(tc.name, tc.arguments, tools)
for _, tc in parallel_calls
],
return_exceptions=True, return_exceptions=True,
) )
for j, (i, tc) in enumerate(parallel_calls): for j, (i, tc) in enumerate(parallel_calls):
@ -381,12 +444,17 @@ class ReActEngine:
error=tool_error, error=tool_error,
) )
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) tool_msg = await self._build_tool_result_message(
tc.id, tool_result, compressor, tc.name
)
conversation.append(tool_msg) conversation.append(tool_msg)
elif self._should_execute_parallel(response.tool_calls): elif self._should_execute_parallel(response.tool_calls):
# Execute multiple tool calls in parallel (parallel_tools=True)
tool_results = await asyncio.gather( tool_results = await asyncio.gather(
*[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls], *[
self._execute_tool(tc.name, tc.arguments, tools)
for tc in response.tool_calls
],
return_exceptions=True, return_exceptions=True,
) )
for idx, tc in enumerate(response.tool_calls): for idx, tc in enumerate(response.tool_calls):
@ -419,7 +487,9 @@ class ReActEngine:
error=tool_error, error=tool_error,
) )
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) tool_msg = await self._build_tool_result_message(
tc.id, tool_result, compressor, tc.name
)
conversation.append(tool_msg) conversation.append(tool_msg)
else: else:
# Serial execution (single tool or parallel_tools=False)
@ -428,7 +498,9 @@ class ReActEngine:
tool_result = await self._execute_tool(tc.name, tc.arguments, tools) tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
# Handle confirmation flow # Handle confirmation flow
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"): if isinstance(tool_result, dict) and tool_result.get(
"needs_confirmation"
):
confirmation_id = tool_result["confirmation_id"] confirmation_id = tool_result["confirmation_id"]
command = tool_result.get("command", "") command = tool_result.get("command", "")
reason = tool_result.get("reason", "") reason = tool_result.get("reason", "")
@ -436,28 +508,46 @@ class ReActEngine:
approved = False approved = False
if confirmation_handler is not None: if confirmation_handler is not None:
try: try:
approved = await confirmation_handler(confirmation_id, command, reason) approved = await confirmation_handler(
confirmation_id, command, reason
)
except Exception as e: except Exception as e:
logger.warning(f"Confirmation handler error: {e}") logger.warning(f"Confirmation handler error: {e}")
if approved: if approved:
tool = self._find_tool(tc.name, tools) tool = self._find_tool(tc.name, tools)
if tool and hasattr(tool, '_is_dangerous'): if tool and hasattr(tool, "_is_dangerous"):
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} clean_args = {
k: v
for k, v in tc.arguments.items()
if not k.startswith("_")
}
clean_args["_skip_dangerous_check"] = True clean_args["_skip_dangerous_check"] = True
try: try:
tool_result = await tool.safe_execute(**clean_args) tool_result = await tool.safe_execute(**clean_args)
except Exception as e: except Exception as e:
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} tool_result = {
"error": f"Tool '{tc.name}' execution failed: {e}"
}
else: else:
# Non-dangerous tool: confirmation was for the overall action, # Non-dangerous tool: confirmation was for the overall action,
# re-execute with skip flag to avoid re-triggering confirmation # re-execute with skip flag to avoid re-triggering confirmation
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} clean_args = {
k: v
for k, v in tc.arguments.items()
if not k.startswith("_")
}
clean_args["_skip_dangerous_check"] = True clean_args["_skip_dangerous_check"] = True
try: try:
tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"} tool_result = (
await tool.safe_execute(**clean_args)
if tool
else {"error": f"Tool '{tc.name}' not found"}
)
except Exception as e: except Exception as e:
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} tool_result = {
"error": f"Tool '{tc.name}' execution failed: {e}"
}
else: else:
tool_result = { tool_result = {
"output": "", "output": "",
@ -496,7 +586,9 @@ class ReActEngine:
) )
# Observe: append the tool result to the conversation history
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) tool_msg = await self._build_tool_result_message(
tc.id, tool_result, compressor, tc.name
)
conversation.append(tool_msg) conversation.append(tool_msg)
# Incremental compression: compress conversation if it's getting long # Incremental compression: compress conversation if it's getting long
@ -524,7 +616,9 @@ class ReActEngine:
for pc in parsed_calls: for pc in parsed_calls:
tool_start = time.monotonic() tool_start = time.monotonic()
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools) tool_result = await self._execute_tool(
pc["name"], pc["arguments"], tools
)
tool_duration_ms = int((time.monotonic() - tool_start) * 1000) tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
react_step = ReActStep( react_step = ReActStep(
@ -554,7 +648,9 @@ class ReActEngine:
) )
# Append the tool result to the conversation history
tool_msg = await self._build_tool_result_message(pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"]) tool_msg = await self._build_tool_result_message(
pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"]
)
conversation.append(tool_msg) conversation.append(tool_msg)
# Incremental compression: compress conversation if it's getting long # Incremental compression: compress conversation if it's getting long
@ -585,6 +681,35 @@ class ReActEngine:
) )
break break
# Verification: when enabled, run the tests after the final answer
if self._verification_enabled and output:
try:
from agentkit.core.verification_loop import VerificationLoop
vloop = VerificationLoop(commands=self._verification_commands)
vresult = await vloop.verify()
if not vresult.passed:
# Append the verification failure to the trajectory as a ReActStep
verification_step = ReActStep(
step=step + 1,
action="tool_call",
tool_name="verification",
arguments={"commands": self._verification_commands},
result={
"passed": vresult.passed,
"errors": vresult.errors,
"test_output": vresult.test_output,
},
content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
)
trajectory.append(verification_step)
logger.info(
"Verification failed after final answer, "
"appended feedback to trajectory"
)
except Exception as e:
logger.warning(f"Verification loop failed: {e}")
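The verification block above relies on a result object exposing `passed`, `errors`, and `test_output`. The `VerificationLoop` implementation is not part of this diff, so the following is only a sketch of one way such a check could work, assuming each entry in `verification_commands` is a shell-style command line whose non-zero exit code counts as a failure. `VerificationResult` and `verify` are hypothetical names, and commands are split with `shlex` and run without a shell.

```python
import asyncio
import shlex
import subprocess
import sys
from dataclasses import dataclass, field


@dataclass
class VerificationResult:  # hypothetical; mirrors the fields read in the diff
    passed: bool
    errors: list[str] = field(default_factory=list)
    test_output: str = ""


async def verify(commands: list[str]) -> VerificationResult:
    """Run each command and collect its output; any non-zero exit fails."""
    errors: list[str] = []
    chunks: list[str] = []
    for command in commands:
        # subprocess.run blocks, so push it to a worker thread.
        proc = await asyncio.to_thread(
            subprocess.run,
            shlex.split(command),
            capture_output=True,
            text=True,
        )
        chunks.append(proc.stdout + proc.stderr)
        if proc.returncode != 0:
            errors.append(f"{command!r} exited with code {proc.returncode}")
    return VerificationResult(
        passed=not errors,
        errors=errors,
        test_output="".join(chunks),
    )


# Demonstration commands built from the current interpreter.
ok_cmd = shlex.join([sys.executable, "-c", "print('fine')"])
bad_cmd = shlex.join(
    [sys.executable, "-c", "import sys; print('broken'); sys.exit(3)"]
)
failed = asyncio.run(verify([ok_cmd, bad_cmd]))
clean = asyncio.run(verify([ok_cmd]))
```

In the diff, a failed verification is only recorded in the trajectory (truncated to 2000 characters); the final `output` is left unchanged, so callers that need a hard gate would have to inspect the trajectory themselves.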
# When max_steps is reached, return the best output so far
if step >= self._max_steps and not output: if step >= self._max_steps and not output:
trace_outcome = "partial" trace_outcome = "partial"
@ -599,6 +724,7 @@ class ReActEngine:
# Fallback: make sure output is never an empty string
if not output or not output.strip(): if not output or not output.strip():
from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED
if step >= self._max_steps: if step >= self._max_steps:
output = MAX_STEPS_REACHED output = MAX_STEPS_REACHED
else: else:
@ -660,8 +786,14 @@ class ReActEngine:
Same logic as execute() but yields events at each step instead of Same logic as execute() but yields events at each step instead of
accumulating a result. accumulating a result.
Args:
compressor: Compression strategy; when None, the instance's default compressor is used.
""" """
effective_compressor = compressor if compressor is not None else self._compressor
tools = tools or [] tools = tools or []
if tools:
tools = self._maybe_add_tool_search(tools)
tool_schemas = self._build_tool_schemas(tools) if tools else None tool_schemas = self._build_tool_schemas(tools) if tools else None
if tool_schemas: if tool_schemas:
tool_names = [s["function"]["name"] for s in tool_schemas] tool_names = [s["function"]["name"] for s in tool_schemas]
@ -679,7 +811,9 @@ class ReActEngine:
system_prompt = self._build_tool_use_prompt(tools) system_prompt = self._build_tool_use_prompt(tools)
# Telemetry: record agent request # Telemetry: record agent request
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"}) agent_request_counter().add(
1, {"agent.name": agent_name, "agent.type": task_type or "react"}
)
# Start telemetry span for the entire agent execution # Start telemetry span for the entire agent execution
_span_cm = None _span_cm = None
@ -726,11 +860,13 @@ class ReActEngine:
conversation.extend(messages) conversation.extend(messages)
# Context compression: compress overly long conversation history
if compressor: if effective_compressor:
try: try:
conversation = await compressor.compress(conversation) conversation = await effective_compressor.compress(conversation)
except Exception as e: except Exception as e:
logger.warning(f"Context compression failed, continuing with original messages: {e}") logger.warning(
f"Context compression failed, continuing with original messages: {e}"
)
trajectory: list[ReActStep] = [] trajectory: list[ReActStep] = []
total_tokens = 0 total_tokens = 0
@ -738,7 +874,9 @@ class ReActEngine:
output = "" output = ""
trace_outcome = "success" trace_outcome = "success"
_stream_start = time.monotonic() _stream_start = time.monotonic()
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout effective_timeout = (
timeout_seconds if timeout_seconds is not None else self._default_timeout
)
try: try:
while step < self._max_steps: while step < self._max_steps:
@ -836,19 +974,44 @@ class ReActEngine:
conversation.append(assistant_msg) conversation.append(assistant_msg)
# Execute tool calls with parallel support # Execute tool calls with parallel support
if self._parallel_tools and len(response.tool_calls) > 1 and self._should_execute_parallel(response.tool_calls): if (
self._parallel_tools
and len(response.tool_calls) > 1
and self._should_execute_parallel(response.tool_calls)
):
# Parallel execution path # Parallel execution path
parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls)) if self._parallel_tools == "auto" else set(range(len(response.tool_calls))) parallelizable_set = (
serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set] set(self._get_parallelizable_indices(response.tool_calls))
parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set] if self._parallel_tools == "auto"
else set(range(len(response.tool_calls)))
)
serial_calls = [
(i, tc)
for i, tc in enumerate(response.tool_calls)
if i not in parallelizable_set
]
parallel_calls = [
(i, tc)
for i, tc in enumerate(response.tool_calls)
if i in parallelizable_set
]
all_results: list[Any] = [None] * len(response.tool_calls) all_results: list[Any] = [None] * len(response.tool_calls)
# Execute serial tools first (handles confirmation flow) # Execute serial tools first (handles confirmation flow)
for i, tc in serial_calls: for i, tc in serial_calls:
yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments}) yield ReActEvent(
event_type="tool_call",
step=step,
data={"tool_name": tc.name, "arguments": tc.arguments},
)
tool_start = time.monotonic() tool_start = time.monotonic()
tool_result, confirm_events = await self._execute_tool_with_confirmation(tc, tools, step, confirmation_handler) (
tool_result,
confirm_events,
) = await self._execute_tool_with_confirmation(
tc, tools, step, confirmation_handler
)
for ev in confirm_events: for ev in confirm_events:
yield ev yield ev
tool_duration_ms = int((time.monotonic() - tool_start) * 1000) tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
@ -857,7 +1020,10 @@ class ReActEngine:
# Execute parallelizable tools concurrently # Execute parallelizable tools concurrently
if len(parallel_calls) > 1: if len(parallel_calls) > 1:
para_results = await asyncio.gather( para_results = await asyncio.gather(
*[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls], *[
self._execute_tool(tc.name, tc.arguments, tools)
for _, tc in parallel_calls
],
return_exceptions=True, return_exceptions=True,
) )
for j, (i, tc) in enumerate(parallel_calls): for j, (i, tc) in enumerate(parallel_calls):
@ -873,19 +1039,45 @@ class ReActEngine:
# Process all results in original order # Process all results in original order
for i, tc in enumerate(response.tool_calls): for i, tc in enumerate(response.tool_calls):
tc_obj, tool_result, tool_duration_ms = all_results[i] tc_obj, tool_result, tool_duration_ms = all_results[i]
yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments}) yield ReActEvent(
event_type="tool_call",
step=step,
data={"tool_name": tc.name, "arguments": tc.arguments},
)
react_step = ReActStep(step=step, action="tool_call", tool_name=tc.name, arguments=tc.arguments, result=tool_result, tokens=step_tokens) react_step = ReActStep(
step=step,
action="tool_call",
tool_name=tc.name,
arguments=tc.arguments,
result=tool_result,
tokens=step_tokens,
)
trajectory.append(react_step) trajectory.append(react_step)
if trace_recorder is not None: if trace_recorder is not None:
tool_error = None tool_error = None
if isinstance(tool_result, dict) and "error" in tool_result: if isinstance(tool_result, dict) and "error" in tool_result:
tool_error = tool_result["error"] tool_error = tool_result["error"]
trace_recorder.record_step(step=step, action="tool_call", tool_name=tc.name, input_data=tc.arguments, output_data=tool_result, duration_ms=tool_duration_ms, tokens_used=0, error=tool_error) trace_recorder.record_step(
step=step,
action="tool_call",
tool_name=tc.name,
input_data=tc.arguments,
output_data=tool_result,
duration_ms=tool_duration_ms,
tokens_used=0,
error=tool_error,
)
yield ReActEvent(event_type="tool_result", step=step, data={"tool_name": tc.name, "result": tool_result}) yield ReActEvent(
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) event_type="tool_result",
step=step,
data={"tool_name": tc.name, "result": tool_result},
)
tool_msg = await self._build_tool_result_message(
tc.id, tool_result, effective_compressor, tc.name
)
conversation.append(tool_msg) conversation.append(tool_msg)
else: else:
# Serial execution path (with confirmation flow) # Serial execution path (with confirmation flow)
@ -902,7 +1094,9 @@ class ReActEngine:
tool_duration_ms = int((time.monotonic() - tool_start) * 1000) tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
# Detect a confirmation request returned by the tool
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"): if isinstance(tool_result, dict) and tool_result.get(
"needs_confirmation"
):
confirmation_id = tool_result["confirmation_id"] confirmation_id = tool_result["confirmation_id"]
command = tool_result.get("command", "") command = tool_result.get("command", "")
reason = tool_result.get("reason", "") reason = tool_result.get("reason", "")
@ -923,16 +1117,22 @@ class ReActEngine:
approved = False approved = False
if confirmation_handler is not None: if confirmation_handler is not None:
try: try:
approved = await confirmation_handler(confirmation_id, command, reason) approved = await confirmation_handler(
confirmation_id, command, reason
)
except Exception as e: except Exception as e:
logger.warning(f"Confirmation handler error: {e}") logger.warning(f"Confirmation handler error: {e}")
if approved: if approved:
# User approved execution: bypass the safety check with a per-call override
tool = self._find_tool(tc.name, tools) tool = self._find_tool(tc.name, tools)
if tool and hasattr(tool, '_is_dangerous'): if tool and hasattr(tool, "_is_dangerous"):
# Strip internal metadata and pass skip_dangerous_check flag # Strip internal metadata and pass skip_dangerous_check flag
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} clean_args = {
k: v
for k, v in tc.arguments.items()
if not k.startswith("_")
}
clean_args["_skip_dangerous_check"] = True clean_args["_skip_dangerous_check"] = True
try: try:
tool_result = await tool.safe_execute(**clean_args) tool_result = await tool.safe_execute(**clean_args)
@ -940,12 +1140,22 @@ class ReActEngine:
pass # No shared state mutation needed pass # No shared state mutation needed
else: else:
# Non-dangerous tool: re-execute with skip flag # Non-dangerous tool: re-execute with skip flag
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} clean_args = {
k: v
for k, v in tc.arguments.items()
if not k.startswith("_")
}
clean_args["_skip_dangerous_check"] = True clean_args["_skip_dangerous_check"] = True
try: try:
tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"} tool_result = (
await tool.safe_execute(**clean_args)
if tool
else {"error": f"Tool '{tc.name}' not found"}
)
except Exception as e: except Exception as e:
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} tool_result = {
"error": f"Tool '{tc.name}' execution failed: {e}"
}
yield ReActEvent( yield ReActEvent(
event_type="confirmation_result", event_type="confirmation_result",
@ -964,7 +1174,10 @@ class ReActEngine:
yield ReActEvent( yield ReActEvent(
event_type="confirmation_result", event_type="confirmation_result",
step=step, step=step,
data={"confirmation_id": confirmation_id, "approved": False}, data={
"confirmation_id": confirmation_id,
"approved": False,
},
) )
tool_duration_ms = int((time.monotonic() - tool_start) * 1000) tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
@ -1001,13 +1214,15 @@ class ReActEngine:
data={"tool_name": tc.name, "result": tool_result}, data={"tool_name": tc.name, "result": tool_result},
) )
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) tool_msg = await self._build_tool_result_message(
tc.id, tool_result, effective_compressor, tc.name
)
conversation.append(tool_msg) conversation.append(tool_msg)
# Incremental compression: compress conversation if it's getting long # Incremental compression: compress conversation if it's getting long
if self._should_compress(conversation, compressor): if self._should_compress(conversation, effective_compressor):
try: try:
conversation = await compressor.compress(conversation) conversation = await effective_compressor.compress(conversation)
except Exception as e: except Exception as e:
logger.warning(f"Incremental compression failed: {e}") logger.warning(f"Incremental compression failed: {e}")
@ -1033,16 +1248,20 @@ class ReActEngine:
data={"tool_name": pc["name"], "arguments": pc["arguments"]}, data={"tool_name": pc["name"], "arguments": pc["arguments"]},
) )
tool_start = time.monotonic() tool_start = time.monotonic()
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools) tool_result = await self._execute_tool(
pc["name"], pc["arguments"], tools
)
tool_duration_ms = int((time.monotonic() - tool_start) * 1000) tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
trajectory.append(ReActStep( trajectory.append(
ReActStep(
step=step, step=step,
action="tool_call", action="tool_call",
tool_name=pc["name"], tool_name=pc["name"],
arguments=pc["arguments"], arguments=pc["arguments"],
result=tool_result, result=tool_result,
tokens=step_tokens, tokens=step_tokens,
)) )
)
# Record the tool call step
if trace_recorder is not None: if trace_recorder is not None:
tool_error = None tool_error = None
@ -1064,14 +1283,17 @@ class ReActEngine:
data={"tool_name": pc["name"], "result": tool_result}, data={"tool_name": pc["name"], "result": tool_result},
) )
tool_msg = await self._build_tool_result_message( tool_msg = await self._build_tool_result_message(
pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"] pc.get("id", f"text_tc_{step}"),
tool_result,
effective_compressor,
pc["name"],
) )
conversation.append(tool_msg) conversation.append(tool_msg)
# Incremental compression: compress conversation if it's getting long # Incremental compression: compress conversation if it's getting long
if self._should_compress(conversation, compressor): if self._should_compress(conversation, effective_compressor):
try: try:
conversation = await compressor.compress(conversation) conversation = await effective_compressor.compress(conversation)
except Exception as e: except Exception as e:
logger.warning(f"Incremental compression failed: {e}") logger.warning(f"Incremental compression failed: {e}")
else: else:
@ -1106,6 +1328,46 @@ class ReActEngine:
) )
break break
# Verification: if enabled, run the tests after the final answer
if self._verification_enabled and output:
try:
from agentkit.core.verification_loop import VerificationLoop
vloop = VerificationLoop(commands=self._verification_commands)
vresult = await vloop.verify()
if not vresult.passed:
verification_step = ReActStep(
step=step + 1,
action="tool_call",
tool_name="verification",
arguments={"commands": self._verification_commands},
result={
"passed": vresult.passed,
"errors": vresult.errors,
"test_output": vresult.test_output,
},
content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
)
trajectory.append(verification_step)
yield ReActEvent(
event_type="tool_result",
step=step + 1,
data={
"tool_name": "verification",
"result": {
"passed": vresult.passed,
"errors": vresult.errors,
"test_output": vresult.test_output,
},
},
)
logger.info(
"Verification failed after final answer, "
"appended feedback to trajectory"
)
except Exception as e:
logger.warning(f"Verification loop failed: {e}")
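Review note: the diff does not show how `VerificationLoop.verify()` runs its commands, so the following is only a minimal sketch under assumptions. `VerificationResult` and `run_verification` are hypothetical names, and the command format (a list of argv lists) is assumed; only the `passed`, `errors`, and `test_output` fields come from the hunk above.

```python
import asyncio
import sys
from dataclasses import dataclass, field


@dataclass
class VerificationResult:
    """Hypothetical result shape mirroring the fields read in the diff."""

    passed: bool
    errors: list[str] = field(default_factory=list)
    test_output: str = ""


async def run_verification(commands: list[list[str]]) -> VerificationResult:
    """Run each command in order and stop at the first non-zero exit code."""
    outputs: list[str] = []
    errors: list[str] = []
    for argv in commands:
        proc = await asyncio.create_subprocess_exec(
            *argv,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,  # merge stderr into stdout
        )
        stdout, _ = await proc.communicate()
        outputs.append(stdout.decode(errors="replace"))
        if proc.returncode != 0:
            errors.append(f"{argv[0]} exited with code {proc.returncode}")
            break
    return VerificationResult(
        passed=not errors,
        errors=errors,
        test_output="\n".join(outputs),
    )


if __name__ == "__main__":
    # Assumed usage: verify with a trivial command that always succeeds.
    result = asyncio.run(run_verification([[sys.executable, "-c", "print('ok')"]]))
    print(result.passed)
```

Stopping at the first failing command keeps `test_output` short, which fits the 2000-character truncation applied to the feedback in the hunk above.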
if step >= self._max_steps and not output: if step >= self._max_steps and not output:
trace_outcome = "partial" trace_outcome = "partial"
if trajectory and trajectory[-1].content: if trajectory and trajectory[-1].content:
@ -1129,6 +1391,7 @@ class ReActEngine:
# Fallback: ensure output is never an empty string # Fallback: ensure output is never an empty string
if not output or not output.strip(): if not output or not output.strip():
from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED
if step >= self._max_steps: if step >= self._max_steps:
output = MAX_STEPS_REACHED output = MAX_STEPS_REACHED
else: else:
@ -1187,33 +1450,40 @@ class ReActEngine:
schemas.append(schema) schemas.append(schema)
return schemas return schemas
@staticmethod def _build_tool_use_prompt(self, tools: list[Tool]) -> str:
def _build_tool_use_prompt(tools: list[Tool]) -> str: """Build prompt-based tool calling instructions with tiered injection.
"""Build prompt-based tool calling instructions for LLMs that don't
support native function calling (e.g., Bailian Coding, Qwen).
Instructs the LLM to use <tool_use> XML format for tool invocation. Core tools (defined by ``self._core_tool_names`` or
This follows the Hermes pattern: model-agnostic prompt-based tool calling. :attr:`_DEFAULT_CORE_TOOLS`) get full descriptions (name +
description + parameters). Extended tools get only name + a
one-line description. When ``tool_search`` is present alongside
extended tools, a hint is added telling the LLM to call
``tool_search`` for full parameter details.
Instructs the LLM to use ``<tool_use>`` XML format for tool
invocation (Hermes pattern: model-agnostic prompt-based tool calling).
""" """
tool_descriptions = [] core_names = set(self._core_tool_names or self._DEFAULT_CORE_TOOLS)
for tool in tools: core_tools = [t for t in tools if t.name in core_names]
params_desc = "" extended_tools = [t for t in tools if t.name not in core_names]
if tool.input_schema:
props = tool.input_schema.get("properties", {}) sections: list[str] = []
required = tool.input_schema.get("required", []) if core_tools:
param_parts = [] sections.append(self._render_core_tools(core_tools))
for pname, pinfo in props.items(): if extended_tools:
ptype = pinfo.get("type", "string") sections.append(self._render_extended_tools(extended_tools))
pdesc = pinfo.get("description", "")
req_flag = " (required)" if pname in required else "" tools_text = "\n\n".join(sections)
param_parts.append(f" - {pname}: {ptype}{req_flag}{pdesc}")
if param_parts: has_tool_search = any(t.name == "tool_search" for t in tools)
params_desc = "\n".join(param_parts) search_hint = ""
tool_descriptions.append( if has_tool_search and extended_tools:
f"- {tool.name}: {tool.description}\n{params_desc}" search_hint = (
"\n\n注意:上方「扩展工具」仅显示名称和简短描述。"
'如需使用某个扩展工具,请先调用 tool_search(query="关键词") '
"获取其完整参数说明。"
) )
tools_text = "\n\n".join(tool_descriptions)
return ( return (
"## 可用工具\n\n" "## 可用工具\n\n"
"你可以使用以下工具来完成任务。当需要调用工具时,使用以下格式:\n\n" "你可以使用以下工具来完成任务。当需要调用工具时,使用以下格式:\n\n"
@ -1225,9 +1495,65 @@ class ReActEngine:
"2. 等待工具返回结果后再决定下一步\n" "2. 等待工具返回结果后再决定下一步\n"
"3. 如果不需要工具就能回答,直接回答即可\n" "3. 如果不需要工具就能回答,直接回答即可\n"
"4. 不要在回答中重复工具的输出,而是基于结果给出有用的总结\n\n" "4. 不要在回答中重复工具的输出,而是基于结果给出有用的总结\n\n"
f"工具列表:\n\n{tools_text}" f"工具列表:\n\n{tools_text}{search_hint}"
) )
@staticmethod
def _render_core_tools(tools: list[Tool]) -> str:
"""Render core tools with full descriptions (name + description + parameters)."""
descriptions: list[str] = []
for tool in tools:
params_desc = ""
if tool.input_schema:
props = tool.input_schema.get("properties", {})
required = tool.input_schema.get("required", [])
param_parts: list[str] = []
for pname, pinfo in props.items():
ptype = pinfo.get("type", "string")
pdesc = pinfo.get("description", "")
req_flag = " (required)" if pname in required else ""
param_parts.append(f" - {pname}: {ptype}{req_flag}{pdesc}")
if param_parts:
params_desc = "\n".join(param_parts)
descriptions.append(f"- {tool.name}: {tool.description}\n{params_desc}")
return "### 核心工具(完整描述)\n\n" + "\n\n".join(descriptions)
@staticmethod
def _render_extended_tools(tools: list[Tool]) -> str:
"""Render extended tools with name + one-line description only."""
lines: list[str] = []
for tool in tools:
desc = tool.description.strip().split("\n")[0]
if len(desc) > 100:
desc = desc[:97] + "..."
lines.append(f"- {tool.name}: {desc}")
return "### 扩展工具(仅名称和简短描述,使用 tool_search 获取详情)\n\n" + "\n".join(lines)
def _maybe_add_tool_search(self, tools: list[Tool]) -> list[Tool]:
"""Add ``tool_search`` tool if enabled and there are extended tools.
Builds a :class:`ToolSearchIndex` from the extended tools so the
LLM can discover full tool descriptions on demand via BM25 search.
If all tools are core tools, or ``tool_search`` is already present,
or ``enable_tool_search`` is False, the list is returned unchanged.
"""
if not self._enable_tool_search:
return tools
if any(t.name == "tool_search" for t in tools):
return tools
core_names = set(self._core_tool_names or self._DEFAULT_CORE_TOOLS)
extended_tools = [t for t in tools if t.name not in core_names]
if not extended_tools:
return tools
from agentkit.tools.builtin import ToolSearchTool
from agentkit.tools.search import ToolSearchIndex
index = ToolSearchIndex(extended_tools)
search_tool = ToolSearchTool(search_index=index)
return tools + [search_tool]
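Review note: the tiered injection in this hunk can be illustrated with a small standalone sketch. It assumes a hypothetical two-field `Tool` dataclass in place of the project's `Tool` class, and it replaces the BM25-based `ToolSearchIndex` with a naive keyword match purely for illustration; `split_tiers`, `one_line`, and `search_tools` are names invented here.

```python
from dataclasses import dataclass


@dataclass
class Tool:
    """Hypothetical stand-in for the project's Tool class."""

    name: str
    description: str


def split_tiers(tools: list[Tool], core_names: set[str]) -> tuple[list[Tool], list[Tool]]:
    """Partition tools into (core, extended) by name, preserving order."""
    core = [t for t in tools if t.name in core_names]
    extended = [t for t in tools if t.name not in core_names]
    return core, extended


def one_line(description: str, limit: int = 100) -> str:
    """Keep the first line of a description, truncated to `limit` characters."""
    first = description.strip().split("\n")[0]
    return first if len(first) <= limit else first[: limit - 3] + "..."


def search_tools(extended: list[Tool], query: str) -> list[Tool]:
    """Naive keyword match (assumption: the real index ranks with BM25)."""
    terms = query.lower().split()
    return [
        t
        for t in extended
        if any(term in f"{t.name} {t.description}".lower() for term in terms)
    ]


if __name__ == "__main__":
    tools = [
        Tool("shell", "Run a shell command.\nSupports pipes."),
        Tool("pdf_export", "Export a document to PDF."),
    ]
    core, extended = split_tiers(tools, {"shell"})
    print([t.name for t in core], [one_line(t.description) for t in extended])
```

The point of the split is prompt size: only core tools pay for full parameter listings, while extended tools cost one short line each until the model asks for details.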
@staticmethod @staticmethod
def _build_response_from_stream( def _build_response_from_stream(
content: str, content: str,
@ -1237,6 +1563,7 @@ class ReActEngine:
) -> LLMResponse: ) -> LLMResponse:
"""Build an LLMResponse from accumulated stream chunks.""" """Build an LLMResponse from accumulated stream chunks."""
from agentkit.llm.protocol import LLMResponse, TokenUsage from agentkit.llm.protocol import LLMResponse, TokenUsage
if usage is None: if usage is None:
usage = TokenUsage() usage = TokenUsage()
return LLMResponse( return LLMResponse(
@ -1256,7 +1583,9 @@ class ReActEngine:
# Default token threshold for incremental compression # Default token threshold for incremental compression
_DEFAULT_COMPRESS_THRESHOLD = 8000 _DEFAULT_COMPRESS_THRESHOLD = 8000
def _should_compress(self, conversation: list[dict], compressor: "CompressionStrategy | None") -> bool: def _should_compress(
self, conversation: list[dict], compressor: "CompressionStrategy | None"
) -> bool:
"""Check whether incremental compression is needed.""" """Check whether incremental compression is needed."""
if not compressor: if not compressor:
return False return False
@ -1331,7 +1660,8 @@ class ReActEngine:
command = tool_result.get("command", "") command = tool_result.get("command", "")
reason = tool_result.get("reason", "") reason = tool_result.get("reason", "")
events.append(ReActEvent( events.append(
ReActEvent(
event_type="confirmation_request", event_type="confirmation_request",
step=step, step=step,
data={ data={
@ -1340,7 +1670,8 @@ class ReActEngine:
"command": command, "command": command,
"reason": reason, "reason": reason,
}, },
)) )
)
# Wait for user confirmation # Wait for user confirmation
approved = False approved = False
@ -1353,7 +1684,7 @@ class ReActEngine:
if approved: if approved:
# User approved: re-execute with _skip_dangerous_check # User approved: re-execute with _skip_dangerous_check
tool = self._find_tool(tc.name, tools) tool = self._find_tool(tc.name, tools)
if tool and hasattr(tool, '_is_dangerous'): if tool and hasattr(tool, "_is_dangerous"):
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
clean_args["_skip_dangerous_check"] = True clean_args["_skip_dangerous_check"] = True
try: try:
@ -1365,15 +1696,21 @@ class ReActEngine:
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
clean_args["_skip_dangerous_check"] = True clean_args["_skip_dangerous_check"] = True
try: try:
tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"} tool_result = (
await tool.safe_execute(**clean_args)
if tool
else {"error": f"Tool '{tc.name}' not found"}
)
except Exception as e: except Exception as e:
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
events.append(ReActEvent( events.append(
ReActEvent(
event_type="confirmation_result", event_type="confirmation_result",
step=step, step=step,
data={"confirmation_id": confirmation_id, "approved": True}, data={"confirmation_id": confirmation_id, "approved": True},
)) )
)
else: else:
# User rejected # User rejected
tool_result = { tool_result = {
@ -1383,11 +1720,13 @@ class ReActEngine:
"error_type": "permission_denied", "error_type": "permission_denied",
"message": f"用户拒绝执行命令: {command[:100]}", "message": f"用户拒绝执行命令: {command[:100]}",
} }
events.append(ReActEvent( events.append(
ReActEvent(
event_type="confirmation_result", event_type="confirmation_result",
step=step, step=step,
data={"confirmation_id": confirmation_id, "approved": False}, data={"confirmation_id": confirmation_id, "approved": False},
)) )
)
return tool_result, events return tool_result, events
@ -1418,7 +1757,7 @@ class ReActEngine:
""" """
indices = [] indices = []
for i, tc in enumerate(tool_calls): for i, tc in enumerate(tool_calls):
args = tc.arguments if hasattr(tc, 'arguments') else {} args = tc.arguments if hasattr(tc, "arguments") else {}
if isinstance(args, dict) and args.get("_parallelizable") is True: if isinstance(args, dict) and args.get("_parallelizable") is True:
indices.append(i) indices.append(i)
return indices return indices
@ -1434,9 +1773,7 @@ class ReActEngine:
calls: list[dict[str, Any]] = [] calls: list[dict[str, Any]] = []
# Format 1: Action: tool_name(args) # Format 1: Action: tool_name(args)
action_pattern = re.compile( action_pattern = re.compile(r"Action:\s*(\w+)\((.+?)\)", re.DOTALL)
r"Action:\s*(\w+)\((.+?)\)", re.DOTALL
)
for match in action_pattern.finditer(content): for match in action_pattern.finditer(content):
name = match.group(1) name = match.group(1)
args_str = match.group(2) args_str = match.group(2)
@ -1450,9 +1787,7 @@ class ReActEngine:
return calls return calls
# Format 2: ```tool\n{"name": "...", "arguments": {...}}\n``` # Format 2: ```tool\n{"name": "...", "arguments": {...}}\n```
code_block_pattern = re.compile( code_block_pattern = re.compile(r"```tool\s*\n(.*?)\n\s*```", re.DOTALL)
r"```tool\s*\n(.*?)\n\s*```", re.DOTALL
)
for match in code_block_pattern.finditer(content): for match in code_block_pattern.finditer(content):
json_str = match.group(1).strip() json_str = match.group(1).strip()
try: try:
@ -1469,9 +1804,7 @@ class ReActEngine:
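# Format 3: <tool_use>\n{"name": "...", "arguments": {...}}\n</tool_use> # Format 3: <tool_use>\n{"name": "...", "arguments": {...}}\n</tool_use>
# Compatible with the tool-call format that models such as Anthropic/Qwen emulate in plain text # Compatible with the tool-call format that models such as Anthropic/Qwen emulate in plain text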
tool_use_pattern = re.compile( tool_use_pattern = re.compile(r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL)
r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL
)
for match in tool_use_pattern.finditer(content): for match in tool_use_pattern.finditer(content):
json_str = match.group(1).strip() json_str = match.group(1).strip()
try: try:
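Review note: the hunk is cut off at the `try:`, so the handling of the JSON payload is not visible here. The following standalone sketch reuses the regex from the diff and assumes, without confirmation from the source, that malformed blocks are skipped and that a missing `arguments` key defaults to an empty dict. `parse_tool_use_blocks` is a hypothetical name.

```python
import json
import re
from typing import Any

# Same pattern as format 3 in the diff.
TOOL_USE_PATTERN = re.compile(r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL)


def parse_tool_use_blocks(content: str) -> list[dict[str, Any]]:
    """Extract tool calls from <tool_use> blocks embedded in model text."""
    calls: list[dict[str, Any]] = []
    for match in TOOL_USE_PATTERN.finditer(content):
        try:
            payload = json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            continue  # assumption: malformed blocks are silently skipped
        if isinstance(payload, dict) and isinstance(payload.get("name"), str):
            calls.append(
                {
                    "name": payload["name"],
                    # assumption: missing arguments default to an empty dict
                    "arguments": payload.get("arguments", {}),
                }
            )
    return calls


if __name__ == "__main__":
    text = (
        "Let me check.\n"
        '<tool_use>\n{"name": "shell", "arguments": {"command": "ls"}}\n</tool_use>'
    )
    print(parse_tool_use_blocks(text))
```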


@ -1,4 +1,11 @@
"""Expert Team routing — resolves user input to ExpertTeam configuration.""" """Expert Team routing — resolves @team prefix to ExpertTeam configuration.
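Simplification notes (U3):
- Team mode is triggered only by the @team prefix
- Complexity-based automatic suggestions are removed
- @team:expert1,expert2 is kept for naming specific expert members
- resolve_expert_configs is kept for resolving expert configurations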
"""
import logging import logging
import re import re
@ -21,37 +28,37 @@ MAX_EXPERTS = 10 # Maximum number of experts in a team
@dataclass @dataclass
class ExpertTeamRoutingResult: class ExpertTeamRoutingResult:
"""Result of expert team routing resolution.""" """Result of expert team routing resolution.
In hub-and-spoke mode, routing is triggered exclusively by the @team prefix.
"""
matched: bool = False matched: bool = False
team_mode: bool = False team_mode: bool = False
specified_experts: list[str] = field(default_factory=list) specified_experts: list[str] = field(default_factory=list)
task_content: str = "" task_content: str = ""
auto_compose: bool = False auto_compose: bool = False
complexity: float = 0.0 match_method: str = "" # "explicit_team" | ""
match_method: str = "" # "explicit_team" | "complexity_suggestion"
class ExpertTeamRouter: class ExpertTeamRouter:
"""Routes user input to Expert Team mode. """Routes user input to Expert Team mode via @team prefix.
Supports: Supports:
- @team prefix trigger team mode - @team prefix trigger team mode (auto-compose members)
- @team:analyst,strategist specify team members - @team:analyst,strategist specify team members by name
- High complexity suggest team mode upgrade
""" """
COMPLEXITY_THRESHOLD = 0.7 # Above this, suggest team mode
def __init__(self, template_registry: ExpertTemplateRegistry | None = None): def __init__(self, template_registry: ExpertTemplateRegistry | None = None):
self._registry = template_registry or ExpertTemplateRegistry() self._registry = template_registry or ExpertTemplateRegistry()
def resolve(self, content: str, complexity: float = 0.0) -> ExpertTeamRoutingResult: def resolve(self, content: str) -> ExpertTeamRoutingResult:
"""Resolve user input to an ExpertTeamRoutingResult. """Resolve user input to an ExpertTeamRoutingResult.
Only @team prefix triggers team mode. No complexity-based suggestion.
Args: Args:
content: User's input message content: User's input message
complexity: Pre-computed complexity score (0.0-1.0)
Returns: Returns:
ExpertTeamRoutingResult with routing decision ExpertTeamRoutingResult with routing decision
@ -94,27 +101,15 @@ class ExpertTeamRouter:
return result return result
# Check complexity-based suggestion
if complexity >= self.COMPLEXITY_THRESHOLD:
result.matched = True
result.team_mode = True
result.auto_compose = True
result.complexity = complexity
result.task_content = content
result.match_method = "complexity_suggestion"
return result
# Not a team mode request # Not a team mode request
result.matched = False result.matched = False
result.team_mode = False result.team_mode = False
result.task_content = content result.task_content = content
result.complexity = complexity
return result return result
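Review note: the prefix-parsing part of `resolve()` falls outside the visible hunks, so this is a sketch of the behavior the docstrings describe rather than the actual implementation. The regex and the `TeamRouting` dataclass are assumptions made for illustration; the field names and `MAX_EXPERTS = 10` are taken from the diff.

```python
import re
from dataclasses import dataclass, field

MAX_EXPERTS = 10  # value shown in the hunk header context

# Hypothetical pattern: "@team" or "@team:a,b", optionally followed by task text.
_TEAM_PREFIX = re.compile(r"^@team(?::([\w,\-]+))?(?:\s+(.*))?$", re.DOTALL)


@dataclass
class TeamRouting:
    """Simplified stand-in for ExpertTeamRoutingResult."""

    matched: bool = False
    team_mode: bool = False
    specified_experts: list[str] = field(default_factory=list)
    task_content: str = ""
    auto_compose: bool = False
    match_method: str = ""


def resolve(content: str) -> TeamRouting:
    """Only an explicit @team prefix enables team mode; there is no complexity input."""
    match = _TEAM_PREFIX.match(content.strip())
    if match is None:
        return TeamRouting(task_content=content)
    experts = [e for e in (match.group(1) or "").split(",") if e][:MAX_EXPERTS]
    return TeamRouting(
        matched=True,
        team_mode=True,
        specified_experts=experts,
        task_content=(match.group(2) or "").strip(),
        auto_compose=not experts,  # no names given: let the team be auto-composed
        match_method="explicit_team",
    )


if __name__ == "__main__":
    print(resolve("@team:analyst,strategist review the Q3 plan"))
```

With the complexity branch gone, a plain message can no longer be upgraded to team mode implicitly, so callers that previously passed `complexity=` need to drop that argument.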
def can_handle(self, content: str) -> bool: def can_handle(self, content: str) -> bool:
"""Check whether any registered expert template can handle the given content. """Check whether any registered expert template can handle the given content.
Used by CostAwareRouter to decide whether to upgrade REACT TEAM_COLLAB.
Returns True if at least one template's name or description overlaps with Returns True if at least one template's name or description overlaps with
content tokens, or if any templates exist (auto-compose can always form a team). content tokens, or if any templates exist (auto-compose can always form a team).
""" """


@ -155,6 +155,7 @@ async def lifespan(app: FastAPI):
# Restore conversation history from persistent store (async, in lifespan) # Restore conversation history from persistent store (async, in lifespan)
from agentkit.server.routes.portal import _conversation_store from agentkit.server.routes.portal import _conversation_store
await _conversation_store.restore_from_store() await _conversation_store.restore_from_store()
# In GUI mode, ensure a default chat agent exists with memory + tools # In GUI mode, ensure a default chat agent exists with memory + tools
@ -579,13 +580,13 @@ def create_app(
app.state.quality_gate = QualityGate() app.state.quality_gate = QualityGate()
app.state.output_standardizer = OutputStandardizer() app.state.output_standardizer = OutputStandardizer()
# Initialize SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT) # Initialize RequestPreprocessor (minimal preprocessing: @skill prefix + greeting regex + REACT)
from agentkit.chat.simple_router import SimpleRouter from agentkit.chat.request_preprocessor import RequestPreprocessor
simple_router = SimpleRouter( request_preprocessor = RequestPreprocessor(
skill_registry=app.state.skill_registry, skill_registry=app.state.skill_registry,
) )
app.state.simple_router = simple_router app.state.request_preprocessor = request_preprocessor
# Initialize OrganizationContext from AgentPool + SkillRegistry # Initialize OrganizationContext from AgentPool + SkillRegistry
from agentkit.org.context import OrganizationContext from agentkit.org.context import OrganizationContext
@ -606,39 +607,6 @@ def create_app(
alignment_guard = AlignmentGuard(config=alignment_config, llm_gateway=app.state.llm_gateway) alignment_guard = AlignmentGuard(config=alignment_config, llm_gateway=app.state.llm_gateway)
app.state.alignment_guard = alignment_guard app.state.alignment_guard = alignment_guard
# CostAwareRouter is no longer used by portal/chat routes (replaced by SimpleRouter).
# It is kept on app.state for backward compatibility with any external consumers.
# To re-enable, set router.legacy_cost_aware_router: true in agentkit.yaml.
router_conf = server_config.router if server_config and server_config.router else {}
if router_conf.get("legacy_cost_aware_router"):
from agentkit.chat.skill_routing import CostAwareRouter
auction_enabled = False
if server_config and hasattr(server_config, "marketplace") and server_config.marketplace:
auction_enabled = server_config.marketplace.get("auction_enabled", False)
semantic_router = None
if router_conf.get("semantic", {}).get("enabled"):
try:
from agentkit.chat.semantic_router import SemanticRouter
semantic_router = SemanticRouter(
embedder=app.state.llm_gateway._embedder,
similarity_high=router_conf["semantic"].get("similarity_high", 0.85),
similarity_low=router_conf["semantic"].get("similarity_low", 0.6),
)
except Exception as e:
logger.warning(f"Failed to initialize semantic router: {e}")
cost_aware_router = CostAwareRouter(
llm_gateway=app.state.llm_gateway,
org_context=org_context,
auction_enabled=auction_enabled,
classifier=router_conf.get("classifier", "heuristic"),
merged_llm_classify=router_conf.get("merged_llm_classify", True),
semantic_router=semantic_router,
)
app.state.cost_aware_router = cost_aware_router
# Initialize task store from config # Initialize task store from config
ts_config = server_config.task_store if server_config else {} ts_config = server_config.task_store if server_config else {}
# Merge CLI overrides from AGENTKIT_TASK_STORE env var # Merge CLI overrides from AGENTKIT_TASK_STORE env var
@ -680,8 +648,10 @@ def create_app(
) )
app.state.session_manager = SessionManager(store=session_store) app.state.session_manager = SessionManager(store=session_store)
# Inject SessionManager into Portal's ConversationStore for persistence # Inject SessionManager into Portal's ConversationStore for persistence (legacy only)
from agentkit.server.routes.portal import _conversation_store from agentkit.server.routes.portal import _conversation_store
if hasattr(_conversation_store, "set_session_manager"):
_conversation_store.set_session_manager(app.state.session_manager) _conversation_store.set_session_manager(app.state.session_manager)
# Initialize evolution store if configured # Initialize evolution store if configured


@ -97,12 +97,21 @@ chat_manager = ChatConnectionManager()
# ── Helper ──────────────────────────────────────────────────────────── # ── Helper ────────────────────────────────────────────────────────────
_VALID_TEAM_EVENT_TYPES = frozenset({ _VALID_TEAM_EVENT_TYPES = frozenset(
"team_formed", "expert_step", "expert_result", {
"plan_update", "team_synthesis", "team_dissolved", "team_formed",
"plan_step", "phase_started", "phase_completed", "phase_failed", "expert_step",
"expert_result",
"plan_update",
"team_synthesis",
"team_dissolved",
"plan_step",
"phase_started",
"phase_completed",
"phase_failed",
"replanning", "replanning",
}) }
)
async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) -> None: async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) -> None:
@ -125,10 +134,12 @@ async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) ->
if event_type not in _VALID_TEAM_EVENT_TYPES: if event_type not in _VALID_TEAM_EVENT_TYPES:
logger.warning(f"emit_team_event: invalid event_type '{event_type}'") logger.warning(f"emit_team_event: invalid event_type '{event_type}'")
return return
await websocket.send_json({ await websocket.send_json(
{
"type": event_type, "type": event_type,
"data": data, "data": data,
}) }
)
def _get_session_manager(request: Request) -> SessionManager: def _get_session_manager(request: Request) -> SessionManager:
@ -236,11 +247,15 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
else: else:
react_engine.reset() react_engine.reset()
tools = agent._tool_registry.list_tools() if agent._tool_registry else [] tools = agent._tool_registry.list_tools() if agent._tool_registry else []
system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None) system_prompt = getattr(agent, "_system_prompt", None) or (
agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
)
result = await react_engine.execute( result = await react_engine.execute(
messages=chat_messages, messages=chat_messages,
tools=tools, tools=tools,
model=agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default"), model=agent.get_model()
if hasattr(agent, "get_model")
else getattr(agent, "_llm_model", "default"),
agent_name=agent.name, agent_name=agent.name,
system_prompt=system_prompt, system_prompt=system_prompt,
) )
@ -319,7 +334,9 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
sm: SessionManager = websocket.app.state.session_manager sm: SessionManager = websocket.app.state.session_manager
session = await sm.get_session(session_id) session = await sm.get_session(session_id)
if session is None: if session is None:
await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}}) await websocket.send_json(
{"type": "error", "data": {"message": f"Session '{session_id}' not found"}}
)
await websocket.close(code=1000, reason="Session not found") await websocket.close(code=1000, reason="Session not found")
return return
@ -367,10 +384,14 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
# Clean up completed tasks first # Clean up completed tasks first
active_tasks.difference_update(t for t in active_tasks if t.done()) active_tasks.difference_update(t for t in active_tasks if t.done())
if len(active_tasks) >= _MAX_CONCURRENT_TASKS: if len(active_tasks) >= _MAX_CONCURRENT_TASKS:
await websocket.send_json({ await websocket.send_json(
{
"type": "error", "type": "error",
"data": {"message": "Too many concurrent requests. Please wait for the current task to complete."}, "data": {
}) "message": "Too many concurrent requests. Please wait for the current task to complete."
},
}
)
continue continue
# Run in background task so the WebSocket receive loop stays free # Run in background task so the WebSocket receive loop stays free
@ -378,7 +399,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
# is waiting for user confirmation (otherwise deadlock). # is waiting for user confirmation (otherwise deadlock).
task = asyncio.create_task( task = asyncio.create_task(
_handle_chat_message( _handle_chat_message(
websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations, websocket,
session_id,
content,
sm,
message_token,
pending_replies,
pending_confirmations,
model_override=model, model_override=model,
) )
) )
@ -396,12 +423,16 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
# Reply to confirmation request # Reply to confirmation request
confirmation_id = msg.get("confirmation_id") confirmation_id = msg.get("confirmation_id")
approved = msg.get("approved", False) approved = msg.get("approved", False)
logger.info(f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}") logger.info(
f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}"
)
if confirmation_id and confirmation_id in pending_confirmations: if confirmation_id and confirmation_id in pending_confirmations:
pending_confirmations[confirmation_id].set_result(approved) pending_confirmations[confirmation_id].set_result(approved)
logger.info(f"Confirmation {confirmation_id} set_result({approved})") logger.info(f"Confirmation {confirmation_id} set_result({approved})")
else: else:
logger.warning(f"Confirmation {confirmation_id!r} not found in pending_confirmations") logger.warning(
f"Confirmation {confirmation_id!r} not found in pending_confirmations"
)
elif msg_type == "cancel": elif msg_type == "cancel":
cancellation_token.cancel() cancellation_token.cancel()
@ -441,9 +472,9 @@ async def _handle_chat_message(
) -> None: ) -> None:
"""Handle a user message: append to session, execute Agent, stream events. """Handle a user message: append to session, execute Agent, stream events.
Uses SimpleRouter for minimal routing: @skill prefix + greeting regex + REACT. Uses RequestPreprocessor for minimal preprocessing: @skill prefix + greeting regex + REACT.
""" """
from agentkit.chat.simple_router import SimpleRouter from agentkit.chat.request_preprocessor import RequestPreprocessor
# Resolve Agent first (needed for default tools/prompt) # Resolve Agent first (needed for default tools/prompt)
pool = websocket.app.state.agent_pool pool = websocket.app.state.agent_pool
@ -454,19 +485,27 @@ async def _handle_chat_message(
agent = pool.get_agent(session.agent_name) agent = pool.get_agent(session.agent_name)
if agent is None: if agent is None:
await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}}) await websocket.send_json(
{"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}}
)
return return
# Default execution parameters from agent # Default execution parameters from agent
default_tools = agent._tool_registry.list_tools() if agent._tool_registry else [] default_tools = agent._tool_registry.list_tools() if agent._tool_registry else []
default_system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None) default_system_prompt = getattr(agent, "_system_prompt", None) or (
default_model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default") agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
)
default_model = (
agent.get_model()
if hasattr(agent, "get_model")
else getattr(agent, "_llm_model", "default")
)
# Resolve skill routing using SimpleRouter # Resolve skill routing using RequestPreprocessor
skill_registry = getattr(websocket.app.state, "skill_registry", None) skill_registry = getattr(websocket.app.state, "skill_registry", None)
simple_router: SimpleRouter = websocket.app.state.simple_router request_preprocessor: RequestPreprocessor = websocket.app.state.request_preprocessor
routing = await simple_router.route( routing = await request_preprocessor.preprocess(
content=content, content=content,
skill_registry=skill_registry, skill_registry=skill_registry,
default_tools=default_tools, default_tools=default_tools,
@ -477,7 +516,9 @@ async def _handle_chat_message(
# Debug: log tools that will be passed to ReActEngine # Debug: log tools that will be passed to ReActEngine
tool_names = [t.name for t in routing.tools] tool_names = [t.name for t in routing.tools]
logger.info(f"Chat {session_id}: resolved {len(routing.tools)} tools: {tool_names}, model={routing.model}, skill={routing.skill_name}") logger.info(
f"Chat {session_id}: resolved {len(routing.tools)} tools: {tool_names}, model={routing.model}, skill={routing.skill_name}"
)
# Apply model override from frontend selector # Apply model override from frontend selector
if model_override: if model_override:
@ -485,17 +526,21 @@ async def _handle_chat_message(
# Notify frontend about skill match # Notify frontend about skill match
if routing.matched: if routing.matched:
await websocket.send_json({ await websocket.send_json(
{
"type": "skill_match", "type": "skill_match",
"data": { "data": {
"skill": routing.skill_name, "skill": routing.skill_name,
"method": routing.match_method, "method": routing.match_method,
"confidence": routing.match_confidence, "confidence": routing.match_confidence,
}, },
}) }
)
# Append user message (use clean_content if @skill: prefix was stripped) # Append user message (use clean_content if @skill: prefix was stripped)
await sm.append_message(session_id=session_id, role=MessageRole.USER, content=routing.clean_content) await sm.append_message(
session_id=session_id, role=MessageRole.USER, content=routing.clean_content
)
# Get full conversation history # Get full conversation history
chat_messages = await sm.get_chat_messages(session_id) chat_messages = await sm.get_chat_messages(session_id)
@ -516,12 +561,15 @@ async def _handle_chat_message(
final_content = response.content or "" final_content = response.content or ""
if not final_content or not final_content.strip(): if not final_content or not final_content.strip():
from agentkit.core.fallback import EMPTY_LLM_RESPONSE from agentkit.core.fallback import EMPTY_LLM_RESPONSE
final_content = EMPTY_LLM_RESPONSE final_content = EMPTY_LLM_RESPONSE
await websocket.send_json({ await websocket.send_json(
{
"type": "final_answer", "type": "final_answer",
"content": final_content, "content": final_content,
"is_final": True, "is_final": True,
}) }
)
await sm.append_message( await sm.append_message(
session_id=session_id, session_id=session_id,
role=MessageRole.ASSISTANT, role=MessageRole.ASSISTANT,
@ -557,14 +605,16 @@ async def _handle_chat_message(
async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool: async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool:
"""Send confirmation request to frontend via WebSocket and wait for user reply.""" """Send confirmation request to frontend via WebSocket and wait for user reply."""
# Send confirmation request to frontend # Send confirmation request to frontend
await websocket.send_json({ await websocket.send_json(
{
"type": "confirmation_request", "type": "confirmation_request",
"data": { "data": {
"confirmation_id": confirmation_id, "confirmation_id": confirmation_id,
"command": command, "command": command,
"reason": reason, "reason": reason,
}, },
}) }
)
# Create a Future and wait for the user's reply # Create a Future and wait for the user's reply
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
@ -578,10 +628,12 @@ async def _handle_chat_message(
logger.info(f"Confirmation request {confirmation_id} resolved: {result}") logger.info(f"Confirmation request {confirmation_id} resolved: {result}")
# Immediately notify frontend of the result so the card updates # Immediately notify frontend of the result so the card updates
# without waiting for the tool to re-execute # without waiting for the tool to re-execute
await websocket.send_json({ await websocket.send_json(
{
"type": "confirmation_result", "type": "confirmation_result",
"data": {"confirmation_id": confirmation_id, "approved": result}, "data": {"confirmation_id": confirmation_id, "approved": result},
}) }
)
return result return result
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"Confirmation request {confirmation_id} timed out") logger.warning(f"Confirmation request {confirmation_id} timed out")
@ -592,7 +644,9 @@ async def _handle_chat_message(
finally: finally:
_pending_confirmations.pop(confirmation_id, None) _pending_confirmations.pop(confirmation_id, None)
logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}") logger.info(
f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}"
)
try: try:
final_content = "" final_content = ""
@ -615,12 +669,15 @@ async def _handle_chat_message(
final_content = event.data.get("output", "") final_content = event.data.get("output", "")
if not final_content or not final_content.strip(): if not final_content or not final_content.strip():
from agentkit.core.fallback import EMPTY_LLM_RESPONSE from agentkit.core.fallback import EMPTY_LLM_RESPONSE
final_content = EMPTY_LLM_RESPONSE final_content = EMPTY_LLM_RESPONSE
await websocket.send_json({ await websocket.send_json(
{
"type": "final_answer", "type": "final_answer",
"content": final_content, "content": final_content,
"is_final": True, "is_final": True,
}) }
)
elif event.event_type == "token": elif event.event_type == "token":
# Buffer tokens instead of sending immediately # Buffer tokens instead of sending immediately
token_buffer.append(event.data.get("content", "")) token_buffer.append(event.data.get("content", ""))
@ -640,30 +697,36 @@ async def _handle_chat_message(
buffered_text = "".join(token_buffer) buffered_text = "".join(token_buffer)
token_buffer.clear() token_buffer.clear()
await websocket.send_json({"type": "thinking", "content": buffered_text}) await websocket.send_json({"type": "thinking", "content": buffered_text})
await websocket.send_json({ await websocket.send_json(
{
"type": "step", "type": "step",
"data": { "data": {
"event_type": event.event_type, "event_type": event.event_type,
"step": event.step, "step": event.step,
"data": event.data, "data": event.data,
}, },
}) }
)
elif event.event_type == "confirmation_request": elif event.event_type == "confirmation_request":
pass pass
elif event.event_type == "confirmation_result": elif event.event_type == "confirmation_result":
await websocket.send_json({ await websocket.send_json(
{
"type": "confirmation_result", "type": "confirmation_result",
"data": event.data, "data": event.data,
}) }
)
else: else:
await websocket.send_json({ await websocket.send_json(
{
"type": "step", "type": "step",
"data": { "data": {
"event_type": event.event_type, "event_type": event.event_type,
"step": event.step, "step": event.step,
"data": event.data, "data": event.data,
}, },
}) }
)
# Append assistant reply to session # Append assistant reply to session
if final_content: if final_content:
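
Reviewer note: the `_confirmation_handler` hunks above park an `asyncio` Future per confirmation id and wait for the frontend's reply. A minimal sketch of that pattern follows. The names `pending`, `wait_for_confirmation`, and `resolve` are hypothetical, and treating a timeout as a denial is an assumption, since the return value on that branch is not visible in the hunk.

```python
import asyncio

# Hypothetical registry of in-flight confirmations, keyed by confirmation id.
pending: dict[str, asyncio.Future] = {}


async def wait_for_confirmation(confirmation_id: str, timeout: float) -> bool:
    """Park a Future under confirmation_id and wait until resolve() completes it."""
    loop = asyncio.get_running_loop()
    future: asyncio.Future = loop.create_future()
    pending[confirmation_id] = future
    try:
        return await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        # Assumption: an unanswered request counts as a denial.
        return False
    finally:
        # Always drop the entry so stale ids cannot be resolved later.
        pending.pop(confirmation_id, None)


def resolve(confirmation_id: str, approved: bool) -> None:
    """Deliver the user's reply; unknown or already finished ids are ignored."""
    future = pending.get(confirmation_id)
    if future is not None and not future.done():
        future.set_result(approved)


async def demo() -> tuple[bool, bool]:
    loop = asyncio.get_running_loop()
    # Simulate the frontend answering "c1" shortly after the request is sent.
    loop.call_later(0.01, resolve, "c1", True)
    approved = await wait_for_confirmation("c1", timeout=1.0)
    # Nobody answers "c2", so this one runs into the timeout branch.
    timed_out = await wait_for_confirmation("c2", timeout=0.05)
    return approved, timed_out
```

The `finally` cleanup mirrors the `_pending_confirmations.pop(confirmation_id, None)` call in the diff, which keeps the registry from growing when requests time out or fail.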


@ -1,13 +1,13 @@
"""E2E Agent Capability Tests — SimpleRouter Backtest (Real LLM). """E2E Agent Capability Tests — RequestPreprocessor Backtest (Real LLM).
Tests SimpleRouter.route() using real LLM configuration loaded from Tests RequestPreprocessor.preprocess() using real LLM configuration loaded from
agentkit.yaml. Records full SkillRoutingResult for precise analysis. agentkit.yaml. Records full SkillRoutingResult for precise analysis.
Key differences from old CostAwareRouter backtest: Key differences from old CostAwareRouter backtest:
- No HeuristicClassifier complexity scoring - No HeuristicClassifier complexity scoring
- No IntentRouter LLM classification - No IntentRouter LLM classification
- No SemanticRouter embedding matching - No SemanticRouter embedding matching
- SimpleRouter: @skill prefix + greeting regex + default REACT - RequestPreprocessor: @skill prefix + greeting regex + default REACT
""" """
import asyncio import asyncio
@ -16,7 +16,7 @@ from pathlib import Path
import pytest import pytest
from agentkit.chat.simple_router import SimpleRouter from agentkit.chat.request_preprocessor import RequestPreprocessor
from agentkit.chat.skill_routing import ExecutionMode from agentkit.chat.skill_routing import ExecutionMode
from agentkit.server.app import _build_llm_gateway, _build_skill_registry from agentkit.server.app import _build_llm_gateway, _build_skill_registry
from agentkit.server.config import ServerConfig from agentkit.server.config import ServerConfig
@ -95,7 +95,7 @@ def _find_config_path() -> str | None:
return None return None
def _build_real_components() -> tuple[SimpleRouter, SkillRegistry]: def _build_real_components() -> tuple[RequestPreprocessor, SkillRegistry]:
config_path = _find_config_path() config_path = _find_config_path()
if not config_path: if not config_path:
pytest.skip("No agentkit.yaml found") pytest.skip("No agentkit.yaml found")
@ -132,15 +132,15 @@ def _build_real_components() -> tuple[SimpleRouter, SkillRegistry]:
pytest.skip("No LLM provider with valid API key") pytest.skip("No LLM provider with valid API key")
skill_registry = _build_skill_registry(server_config) skill_registry = _build_skill_registry(server_config)
router = SimpleRouter(skill_registry=skill_registry) preprocessor = RequestPreprocessor(skill_registry=skill_registry)
return router, skill_registry return preprocessor, skill_registry
_cached_components: tuple[SimpleRouter, SkillRegistry] | None = None _cached_components: tuple[RequestPreprocessor, SkillRegistry] | None = None
def _get_components() -> tuple[SimpleRouter, SkillRegistry]: def _get_components() -> tuple[RequestPreprocessor, SkillRegistry]:
global _cached_components global _cached_components
if _cached_components is None: if _cached_components is None:
_cached_components = _build_real_components() _cached_components = _build_real_components()
@ -153,8 +153,8 @@ def _get_components() -> tuple[SimpleRouter, SkillRegistry]:
@pytest.mark.e2e_capability @pytest.mark.e2e_capability
class TestSimpleRouterBasic: class TestRequestPreprocessorBasic:
"""Test SimpleRouter basic routing: greeting → DIRECT_CHAT, others → REACT.""" """Test RequestPreprocessor basic preprocessing: greeting → DIRECT_CHAT, others → REACT."""
@pytest.mark.parametrize( @pytest.mark.parametrize(
"case", "case",
@ -162,9 +162,9 @@ class TestSimpleRouterBasic:
ids=[c["id"] for c in ROUTING_TEST_CASES], ids=[c["id"] for c in ROUTING_TEST_CASES],
) )
def test_routing(self, case: dict): def test_routing(self, case: dict):
router, skill_registry = _get_components() preprocessor, skill_registry = _get_components()
result = asyncio.run( result = asyncio.run(
router.route( preprocessor.preprocess(
content=case["input"], content=case["input"],
skill_registry=skill_registry, skill_registry=skill_registry,
default_tools=["shell", "search", "file_read"], default_tools=["shell", "search", "file_read"],
@ -179,8 +179,8 @@ class TestSimpleRouterBasic:
@pytest.mark.e2e_capability @pytest.mark.e2e_capability
class TestSimpleRouterParaphraseConsistency: class TestRequestPreprocessorParaphraseConsistency:
"""Test that paraphrased inputs route to the same execution mode.""" """Test that paraphrased inputs preprocess to the same execution mode."""
@pytest.mark.parametrize( @pytest.mark.parametrize(
"case", "case",
@ -188,12 +188,12 @@ class TestSimpleRouterParaphraseConsistency:
ids=[c["id"] for c in PARAPHRASE_CASES], ids=[c["id"] for c in PARAPHRASE_CASES],
) )
def test_paraphrase_consistency(self, case: dict): def test_paraphrase_consistency(self, case: dict):
router, skill_registry = _get_components() preprocessor, skill_registry = _get_components()
expected_mode = case["expected_mode"] expected_mode = case["expected_mode"]
# Test original # Test original
result = asyncio.run( result = asyncio.run(
router.route( preprocessor.preprocess(
content=case["original"], content=case["original"],
skill_registry=skill_registry, skill_registry=skill_registry,
default_tools=["shell", "search", "file_read"], default_tools=["shell", "search", "file_read"],
@ -206,7 +206,7 @@ class TestSimpleRouterParaphraseConsistency:
# Test all paraphrases # Test all paraphrases
for para in case["paraphrases"]: for para in case["paraphrases"]:
result = asyncio.run( result = asyncio.run(
router.route( preprocessor.preprocess(
content=para, content=para,
skill_registry=skill_registry, skill_registry=skill_registry,
default_tools=["shell", "search", "file_read"], default_tools=["shell", "search", "file_read"],
@ -218,19 +218,19 @@ class TestSimpleRouterParaphraseConsistency:
@pytest.mark.e2e_capability @pytest.mark.e2e_capability
class TestSimpleRouterMetrics: class TestRequestPreprocessorMetrics:
"""Compute and report routing accuracy metrics.""" """Compute and report preprocessing accuracy metrics."""
def test_accuracy_report(self): def test_accuracy_report(self):
"""Run all test cases and compute accuracy metrics.""" """Run all test cases and compute accuracy metrics."""
router, skill_registry = _get_components() preprocessor, skill_registry = _get_components()
total = len(ROUTING_TEST_CASES) total = len(ROUTING_TEST_CASES)
correct = 0 correct = 0
results = [] results = []
for case in ROUTING_TEST_CASES: for case in ROUTING_TEST_CASES:
result = asyncio.run( result = asyncio.run(
router.route( preprocessor.preprocess(
content=case["input"], content=case["input"],
skill_registry=skill_registry, skill_registry=skill_registry,
default_tools=["shell", "search", "file_read"], default_tools=["shell", "search", "file_read"],
@ -251,7 +251,7 @@ class TestSimpleRouterMetrics:
accuracy = correct / total * 100 accuracy = correct / total * 100
print(f"\n{'='*60}") print(f"\n{'='*60}")
print(f"SimpleRouter Accuracy Report") print(f"RequestPreprocessor Accuracy Report")
print(f"{'='*60}") print(f"{'='*60}")
print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.1f}%") print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.1f}%")
print(f"{'-'*60}") print(f"{'-'*60}")
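
Reviewer note: the module docstring above summarizes the new behaviour as "@skill prefix + greeting regex + default REACT". A rough, self-contained sketch of that three-step decision follows. It is not the `RequestPreprocessor` implementation: `Mode`, `Result`, the two regexes, and the function signature are all hypothetical, and a matched skill always maps to `SKILL_REACT` here, whereas the unit tests in this commit show that the real mode depends on the skill itself.

```python
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional


class Mode(Enum):
    DIRECT_CHAT = "direct_chat"
    SKILL_REACT = "skill_react"
    REACT = "react"


@dataclass
class Result:
    mode: Mode
    match_method: str
    skill_name: Optional[str] = None
    tools: list[str] = field(default_factory=list)


# Hypothetical patterns; the real greeting list is not shown in this diff.
SKILL_PREFIX = re.compile(r"^@skill:(\w+)\s*(.*)$", re.DOTALL)
GREETING = re.compile(
    r"^(你好|谢谢|你是谁|hello|hi|who are you)[!?.。！？]*$", re.IGNORECASE
)


def preprocess(content: str, known_skills: set[str], default_tools: list[str]) -> Result:
    text = content.strip()
    # Step 1: an explicit @skill:<name> prefix wins over everything else.
    match = SKILL_PREFIX.match(text)
    if match:
        name = match.group(1)
        if name in known_skills:
            return Result(Mode.SKILL_REACT, "skill_prefix", name, list(default_tools))
        return Result(Mode.REACT, "skill_not_found_fallback", None, list(default_tools))
    # Step 2: greetings and chit-chat skip the agent loop and carry no tools.
    if GREETING.match(text):
        return Result(Mode.DIRECT_CHAT, "regex_direct")
    # Step 3: everything else goes to REACT and the LLM decides which tools to use.
    return Result(Mode.REACT, "default_react", None, list(default_tools))
```

Empty input falls through to step 3 in this sketch, which lines up with the `test_empty_input` expectation in the unit tests.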


@ -1,10 +1,10 @@
"""Unit tests for SimpleRouter — minimal routing layer.""" """Unit tests for RequestPreprocessor — minimal preprocessing layer."""
from __future__ import annotations from __future__ import annotations
import pytest import pytest
from agentkit.chat.simple_router import SimpleRouter from agentkit.chat.request_preprocessor import RequestPreprocessor
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
@ -51,8 +51,8 @@ def registry() -> MockSkillRegistry:
@pytest.fixture @pytest.fixture
def router(registry: MockSkillRegistry) -> SimpleRouter: def preprocessor(registry: MockSkillRegistry) -> RequestPreprocessor:
return SimpleRouter( return RequestPreprocessor(
skill_registry=registry, skill_registry=registry,
default_tools=["shell", "search", "file_read"], default_tools=["shell", "search", "file_read"],
default_system_prompt="You are a helpful assistant.", default_system_prompt="You are a helpful assistant.",
@ -67,8 +67,8 @@ def router(registry: MockSkillRegistry) -> SimpleRouter:
class TestSkillPrefix: class TestSkillPrefix:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_skill_prefix_routes_to_skill(self, router: SimpleRouter): async def test_skill_prefix_routes_to_skill(self, preprocessor: RequestPreprocessor):
result = await router.route("@skill:shell_agent 查看当前ip") result = await preprocessor.preprocess("@skill:shell_agent 查看当前ip")
assert result.matched is True assert result.matched is True
assert result.skill_name == "shell_agent" assert result.skill_name == "shell_agent"
assert result.match_method == "skill_prefix" assert result.match_method == "skill_prefix"
@ -76,22 +76,22 @@ class TestSkillPrefix:
assert result.execution_mode == ExecutionMode.SKILL_REACT assert result.execution_mode == ExecutionMode.SKILL_REACT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_skill_prefix_direct_mode(self, router: SimpleRouter): async def test_skill_prefix_direct_mode(self, preprocessor: RequestPreprocessor):
result = await router.route("@skill:direct_agent 翻译hello") result = await preprocessor.preprocess("@skill:direct_agent 翻译hello")
assert result.matched is True assert result.matched is True
assert result.skill_name == "direct_agent" assert result.skill_name == "direct_agent"
assert result.execution_mode == ExecutionMode.DIRECT_CHAT assert result.execution_mode == ExecutionMode.DIRECT_CHAT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_skill_prefix_rewoo_mode(self, router: SimpleRouter): async def test_skill_prefix_rewoo_mode(self, preprocessor: RequestPreprocessor):
result = await router.route("@skill:rewoo_agent 重构代码") result = await preprocessor.preprocess("@skill:rewoo_agent 重构代码")
assert result.matched is True assert result.matched is True
assert result.skill_name == "rewoo_agent" assert result.skill_name == "rewoo_agent"
assert result.execution_mode == ExecutionMode.REWOO assert result.execution_mode == ExecutionMode.REWOO
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unknown_skill_falls_back_to_react(self, router: SimpleRouter): async def test_unknown_skill_falls_back_to_react(self, preprocessor: RequestPreprocessor):
result = await router.route("@skill:nonexistent 查询") result = await preprocessor.preprocess("@skill:nonexistent 查询")
assert result.matched is False assert result.matched is False
assert result.match_method == "skill_not_found_fallback" assert result.match_method == "skill_not_found_fallback"
assert result.execution_mode == ExecutionMode.REACT assert result.execution_mode == ExecutionMode.REACT
@ -103,30 +103,30 @@ class TestSkillPrefix:
class TestDirectChat: class TestDirectChat:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_greeting_cn(self, router: SimpleRouter): async def test_greeting_cn(self, preprocessor: RequestPreprocessor):
result = await router.route("你好") result = await preprocessor.preprocess("你好")
assert result.execution_mode == ExecutionMode.DIRECT_CHAT assert result.execution_mode == ExecutionMode.DIRECT_CHAT
assert result.match_method == "regex_direct" assert result.match_method == "regex_direct"
assert result.tools == [] assert result.tools == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_greeting_en(self, router: SimpleRouter): async def test_greeting_en(self, preprocessor: RequestPreprocessor):
result = await router.route("hello") result = await preprocessor.preprocess("hello")
assert result.execution_mode == ExecutionMode.DIRECT_CHAT assert result.execution_mode == ExecutionMode.DIRECT_CHAT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chitchat(self, router: SimpleRouter): async def test_chitchat(self, preprocessor: RequestPreprocessor):
result = await router.route("谢谢") result = await preprocessor.preprocess("谢谢")
assert result.execution_mode == ExecutionMode.DIRECT_CHAT assert result.execution_mode == ExecutionMode.DIRECT_CHAT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_identity_question(self, router: SimpleRouter): async def test_identity_question(self, preprocessor: RequestPreprocessor):
result = await router.route("你是谁") result = await preprocessor.preprocess("你是谁")
assert result.execution_mode == ExecutionMode.DIRECT_CHAT assert result.execution_mode == ExecutionMode.DIRECT_CHAT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_identity_question_en(self, router: SimpleRouter): async def test_identity_question_en(self, preprocessor: RequestPreprocessor):
result = await router.route("who are you") result = await preprocessor.preprocess("who are you")
assert result.execution_mode == ExecutionMode.DIRECT_CHAT assert result.execution_mode == ExecutionMode.DIRECT_CHAT
@ -136,15 +136,15 @@ class TestDirectChat:
class TestDefaultReact: class TestDefaultReact:
@pytest.mark.asyncio @pytest.mark.asyncio
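 async def test_colloquial_tool_query(self, router: SimpleRouter): async def test_colloquial_tool_query(self, preprocessor: RequestPreprocessor):
 """Colloquial tool query: the core scenario the old routing layer used to misclassify""" """Colloquial tool query: the core scenario the old routing layer used to misclassify"""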
result = await router.route("查下ip") result = await preprocessor.preprocess("查下ip")
assert result.execution_mode == ExecutionMode.REACT assert result.execution_mode == ExecutionMode.REACT
assert result.match_method == "default_react" assert result.match_method == "default_react"
assert len(result.tools) > 0 assert len(result.tools) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
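 async def test_various_colloquial_expressions(self, router: SimpleRouter): async def test_various_colloquial_expressions(self, preprocessor: RequestPreprocessor):
 """All colloquial phrasings should go through REACT, leaving the decision to the LLM""" """All colloquial phrasings should go through REACT, leaving the decision to the LLM"""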
queries = [ queries = [
"查看当前ip", "查看当前ip",
@ -157,30 +157,30 @@ class TestDefaultReact:
"检查服务状态", "检查服务状态",
] ]
for query in queries: for query in queries:
result = await router.route(query) result = await preprocessor.preprocess(query)
assert result.execution_mode == ExecutionMode.REACT, f"'{query}' should be REACT, got {result.execution_mode}" assert result.execution_mode == ExecutionMode.REACT, f"'{query}' should be REACT, got {result.execution_mode}"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_complex_query(self, router: SimpleRouter): async def test_complex_query(self, preprocessor: RequestPreprocessor):
result = await router.route("帮我分析一下这个数据并生成报告") result = await preprocessor.preprocess("帮我分析一下这个数据并生成报告")
assert result.execution_mode == ExecutionMode.REACT assert result.execution_mode == ExecutionMode.REACT
@pytest.mark.asyncio @pytest.mark.asyncio
 async def test_translation_goes_react(self, router: SimpleRouter): async def test_translation_goes_react(self, preprocessor: RequestPreprocessor):
 """Translation queries also go through REACT; the LLM decides in the agent loop that no tools are needed""" """Translation queries also go through REACT; the LLM decides in the agent loop that no tools are needed"""
result = await router.route("翻译hello为中文") result = await preprocessor.preprocess("翻译hello为中文")
assert result.execution_mode == ExecutionMode.REACT assert result.execution_mode == ExecutionMode.REACT
# LLM will see tools but decide not to use them # LLM will see tools but decide not to use them
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_default_tools_included(self, router: SimpleRouter): async def test_default_tools_included(self, preprocessor: RequestPreprocessor):
result = await router.route("查下ip") result = await preprocessor.preprocess("查下ip")
assert "shell" in result.tools assert "shell" in result.tools
assert "search" in result.tools assert "search" in result.tools
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_default_system_prompt(self, router: SimpleRouter): async def test_default_system_prompt(self, preprocessor: RequestPreprocessor):
result = await router.route("查下ip") result = await preprocessor.preprocess("查下ip")
assert result.system_prompt == "You are a helpful assistant." assert result.system_prompt == "You are a helpful assistant."
@ -190,31 +190,31 @@ class TestDefaultReact:
class TestEdgeCases: class TestEdgeCases:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_input(self, router: SimpleRouter): async def test_empty_input(self, preprocessor: RequestPreprocessor):
result = await router.route("") result = await preprocessor.preprocess("")
assert result.execution_mode == ExecutionMode.REACT assert result.execution_mode == ExecutionMode.REACT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_whitespace_only(self, router: SimpleRouter): async def test_whitespace_only(self, preprocessor: RequestPreprocessor):
result = await router.route(" ") result = await preprocessor.preprocess(" ")
assert result.execution_mode == ExecutionMode.REACT assert result.execution_mode == ExecutionMode.REACT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_greeting_with_extra_spaces(self, router: SimpleRouter): async def test_greeting_with_extra_spaces(self, preprocessor: RequestPreprocessor):
result = await router.route(" 你好 ") result = await preprocessor.preprocess(" 你好 ")
assert result.execution_mode == ExecutionMode.DIRECT_CHAT assert result.execution_mode == ExecutionMode.DIRECT_CHAT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_registry(self): async def test_no_registry(self):
"""Router without skill registry should still work for non-skill queries""" """Preprocessor without skill registry should still work for non-skill queries"""
router = SimpleRouter(default_tools=["shell"]) preprocessor = RequestPreprocessor(default_tools=["shell"])
result = await router.route("查下ip") result = await preprocessor.preprocess("查下ip")
assert result.execution_mode == ExecutionMode.REACT assert result.execution_mode == ExecutionMode.REACT
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_override_defaults(self, router: SimpleRouter): async def test_override_defaults(self, preprocessor: RequestPreprocessor):
"""Route-time overrides should work""" """Preprocess-time overrides should work"""
result = await router.route( result = await preprocessor.preprocess(
"查下ip", "查下ip",
default_tools=["shell_only"], default_tools=["shell_only"],
default_model="gpt-4o", default_model="gpt-4o",


@ -1,13 +1,11 @@
"""Unit tests for CostAwareRouter team upgrade logic and HeuristicClassifier.""" """Unit tests for ExpertTeamRouter and skill routing utilities."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock from unittest.mock import MagicMock
from agentkit.chat.skill_routing import ( from agentkit.chat.skill_routing import (
CostAwareRouter,
ExecutionMode, ExecutionMode,
HeuristicClassifier,
SkillRoutingResult, SkillRoutingResult,
) )
from agentkit.experts.config import ExpertConfig, ExpertTemplate from agentkit.experts.config import ExpertConfig, ExpertTemplate
@ -20,16 +18,6 @@ from agentkit.experts.router import ExpertTeamRouter
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _make_router(expert_team_router: ExpertTeamRouter | None = None) -> CostAwareRouter:
"""Create a CostAwareRouter with mocked dependencies."""
return CostAwareRouter(
llm_gateway=None,
model="test",
classifier="heuristic",
expert_team_router=expert_team_router,
)
def _make_team_router_with_templates() -> ExpertTeamRouter: def _make_team_router_with_templates() -> ExpertTeamRouter:
"""Create an ExpertTeamRouter with sample templates.""" """Create an ExpertTeamRouter with sample templates."""
registry = ExpertTemplateRegistry() registry = ExpertTemplateRegistry()
@ -82,251 +70,51 @@ class TestExpertTeamRouterCanHandle:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Tests: _try_team_upgrade() # Tests: ExpertTeamRouter.resolve()
# ---------------------------------------------------------------------------
class TestTryTeamUpgrade:
def test_upgrade_react_to_team_collab(self) -> None:
router = _make_router(expert_team_router=_make_team_router_with_templates())
result = SkillRoutingResult(
clean_content="complex multi-step analysis task",
matched=True,
match_method="capability",
match_confidence=0.8,
complexity=0.8,
execution_mode=ExecutionMode.REACT,
)
trace: list[dict] = []
upgraded = router._try_team_upgrade(result, "complex multi-step analysis task", 0.8, trace)
assert upgraded.execution_mode == ExecutionMode.TEAM_COLLAB
assert any(t.get("method") == "team_upgrade" for t in trace)
def test_no_upgrade_low_complexity(self) -> None:
router = _make_router(expert_team_router=_make_team_router_with_templates())
result = SkillRoutingResult(
clean_content="simple question",
matched=True,
match_method="capability",
match_confidence=0.8,
complexity=0.3,
execution_mode=ExecutionMode.REACT,
)
trace: list[dict] = []
upgraded = router._try_team_upgrade(result, "simple question", 0.3, trace)
assert upgraded.execution_mode == ExecutionMode.REACT
assert not any(t.get("method") == "team_upgrade" for t in trace)
def test_no_upgrade_no_team_router(self) -> None:
router = _make_router(expert_team_router=None)
result = SkillRoutingResult(
clean_content="complex analysis",
matched=True,
match_method="capability",
match_confidence=0.8,
complexity=0.9,
execution_mode=ExecutionMode.REACT,
)
trace: list[dict] = []
upgraded = router._try_team_upgrade(result, "complex analysis", 0.9, trace)
assert upgraded.execution_mode == ExecutionMode.REACT
def test_no_upgrade_empty_templates(self) -> None:
router = _make_router(expert_team_router=_make_team_router_empty())
result = SkillRoutingResult(
clean_content="complex analysis",
matched=True,
match_method="capability",
match_confidence=0.8,
complexity=0.8,
execution_mode=ExecutionMode.REACT,
)
trace: list[dict] = []
upgraded = router._try_team_upgrade(result, "complex analysis", 0.8, trace)
assert upgraded.execution_mode == ExecutionMode.REACT
def test_no_upgrade_direct_chat_mode(self) -> None:
router = _make_router(expert_team_router=_make_team_router_with_templates())
result = SkillRoutingResult(
clean_content="hello",
matched=False,
match_method="greeting",
match_confidence=1.0,
complexity=0.0,
execution_mode=ExecutionMode.DIRECT_CHAT,
)
trace: list[dict] = []
upgraded = router._try_team_upgrade(result, "hello", 0.0, trace)
assert upgraded.execution_mode == ExecutionMode.DIRECT_CHAT
def test_team_upgrade_exception_handled(self) -> None:
"""When ExpertTeamRouter raises, the upgrade is silently skipped."""
broken_router = MagicMock()
broken_router.can_handle.side_effect = RuntimeError("boom")
router = _make_router(expert_team_router=broken_router)
result = SkillRoutingResult(
clean_content="complex task",
matched=True,
match_method="capability",
match_confidence=0.8,
complexity=0.8,
execution_mode=ExecutionMode.REACT,
)
trace: list[dict] = []
upgraded = router._try_team_upgrade(result, "complex task", 0.8, trace)
assert upgraded.execution_mode == ExecutionMode.REACT
# ---------------------------------------------------------------------------
# Tests: ExpertTeamRouter.resolve() with complexity
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestExpertTeamRouterResolve: class TestExpertTeamRouterResolve:
def test_explicit_team_prefix(self) -> None: def test_explicit_team_prefix(self) -> None:
router = _make_team_router_with_templates() router = _make_team_router_with_templates()
result = router.resolve("@team:analyst,strategist analyze the market", 0.5) result = router.resolve("@team:analyst,strategist analyze the market")
assert result.team_mode is True assert result.team_mode is True
assert result.match_method == "explicit_team" assert result.match_method == "explicit_team"
assert "analyst" in result.specified_experts assert "analyst" in result.specified_experts
assert "strategist" in result.specified_experts assert "strategist" in result.specified_experts
def test_complexity_suggestion(self) -> None: def test_no_team_without_prefix(self) -> None:
router = _make_team_router_with_templates() router = _make_team_router_with_templates()
result = router.resolve("complex multi-step analysis", 0.8) result = router.resolve("simple question")
assert result.team_mode is True
assert result.match_method == "complexity_suggestion"
assert result.auto_compose is True
def test_no_team_low_complexity(self) -> None:
router = _make_team_router_with_templates()
result = router.resolve("simple question", 0.2)
assert result.team_mode is False assert result.team_mode is False
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Tests: HeuristicClassifier complexity calibration # Tests: SkillRoutingResult data structure
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestHeuristicClassifierLowComplexity: class TestSkillRoutingResult:
"""Low-complexity signals should produce scores < 0.3.""" def test_default_execution_mode(self) -> None:
result = SkillRoutingResult(
clean_content="test",
matched=False,
match_method="default_react",
match_confidence=0.8,
agent_name="default",
model="default",
execution_mode=ExecutionMode.REACT,
)
assert result.execution_mode == ExecutionMode.REACT
def setup_method(self) -> None: def test_direct_chat_mode(self) -> None:
self.clf = HeuristicClassifier() result = SkillRoutingResult(
clean_content="hello",
def test_chinese_greeting(self) -> None: matched=False,
assert self.clf.classify("你好") < 0.3 match_method="regex_direct",
match_confidence=1.0,
def test_chinese_greeting_hi(self) -> None: agent_name="default",
assert self.clf.classify("") < 0.3 model="default",
execution_mode=ExecutionMode.DIRECT_CHAT,
def test_english_greeting_hello(self) -> None: )
assert self.clf.classify("Hello") < 0.3 assert result.execution_mode == ExecutionMode.DIRECT_CHAT
def test_english_greeting_hi(self) -> None:
assert self.clf.classify("hi") < 0.3
def test_multiple_low_complexity_words(self) -> None:
assert self.clf.classify("嗨,早上好") < 0.3
def test_greeting_with_high_complexity_word_not_suppressed(self) -> None:
"""Low-complexity signal should NOT override high-complexity signal."""
# "你好" is low, but "分析" is high → should score high
assert self.clf.classify("你好,请帮我分析一下这个数据") > 0.5
class TestHeuristicClassifierIdentity:
"""Identity queries should produce scores < 0.3."""
def setup_method(self) -> None:
self.clf = HeuristicClassifier()
def test_who_are_you_cn(self) -> None:
assert self.clf.classify("你是谁") < 0.3
def test_what_is_your_name_cn(self) -> None:
assert self.clf.classify("你叫什么") < 0.3
class TestHeuristicClassifierNegation:
"""Negated high-complexity words should not contribute to score."""
def setup_method(self) -> None:
self.clf = HeuristicClassifier()
def test_negate_search_cn(self) -> None:
assert self.clf.classify("不要搜索") < 0.3
def test_negate_analyze_cn(self) -> None:
assert self.clf.classify("无需分析,直接告诉我答案") < 0.3
def test_partial_negation_still_high(self) -> None:
"""'搜索' negated but '分析' not — should still be high."""
assert self.clf.classify("分析市场趋势,但不要搜索") > 0.5
class TestHeuristicClassifierThresholds:
"""Verify adjusted base scores."""
def setup_method(self) -> None:
self.clf = HeuristicClassifier()
def test_no_keyword_short_message(self) -> None:
assert self.clf.classify("好的") <= 0.10
def test_medium_complexity_base(self) -> None:
"""Medium complexity keyword should start at 0.35 (not 0.45)."""
score = self.clf.classify("如何使用Python")
# '如何' is medium → base 0.35, '' short question → -0.10 = 0.25
# but 'Python' is not in high/medium lists, so just medium base
assert 0.25 <= score <= 0.45
class TestHeuristicClassifierShortQuestion:
"""Short questions ending with /? should get deduction."""
def setup_method(self) -> None:
self.clf = HeuristicClassifier()
def test_short_question_deduction(self) -> None:
assert self.clf.classify("怎么用?") < 0.3
def test_long_question_no_deduction(self) -> None:
assert self.clf.classify("如何设计一个高可用的微服务架构?") > 0.5


class TestHeuristicClassifierHighComplexity:
    """Complex tasks should produce scores > 0.7."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_two_high_complexity_words(self) -> None:
        # "分析" + "搜索" are both in _HIGH_COMPLEXITY_HINTS_CN → base 0.80
        assert self.clf.classify("分析市场数据并搜索相关信息") > 0.7

    def test_single_high_complexity_word(self) -> None:
        # "分析" alone → base 0.65
        assert self.clf.classify("分析市场趋势并生成报告") > 0.6

    def test_execute_and_restart(self) -> None:
        assert self.clf.classify("执行部署脚本并重启服务") > 0.7


class TestHeuristicClassifierEdgeCases:
    """Boundary conditions."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_empty_string(self) -> None:
        assert self.clf.classify("") == 0.0

    def test_whitespace_only(self) -> None:
        assert self.clf.classify(" ") == 0.0

    def test_long_low_complexity_message(self) -> None:
        """Even a long greeting should stay low."""
        long_greeting = "你好" * 100  # 200 chars
        assert self.clf.classify(long_greeting) <= 0.15
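
The tests above pin down a few scoring rules: a base of 0.35 for a medium keyword, 0.65 for one high-complexity keyword, 0.80 for two, a 0.10 deduction for short questions, and no contribution from negated keywords. A minimal sketch of those rules follows. It is not the deleted `HeuristicClassifier`: the keyword lists, the negation words, the 0.05 default, and the 10-character "short" cutoff are all assumptions made for illustration, and it does not reproduce every assertion in the suite.

```python
import re

# Hypothetical keyword lists; the real _HIGH_COMPLEXITY_HINTS_CN is not shown in this diff.
HIGH_HINTS = ("分析", "搜索", "执行", "重启")
MEDIUM_HINTS = ("如何", "怎么")
NEGATIONS = ("不要", "无需", "不用")


def _is_negated(text: str, start: int) -> bool:
    """Return True if a negation word ends directly before index `start`."""
    return any(text.endswith(neg, 0, start) for neg in NEGATIONS)


def classify_sketch(text: str) -> float:
    """Score message complexity in [0.0, 1.0] from keyword hits."""
    text = text.strip()
    if not text:
        return 0.0

    # Count distinct high-complexity words with at least one non-negated occurrence.
    high_hits = sum(
        any(not _is_negated(text, m.start()) for m in re.finditer(re.escape(word), text))
        for word in HIGH_HINTS
    )

    if high_hits >= 2:
        score = 0.80
    elif high_hits == 1:
        score = 0.65
    elif any(word in text for word in MEDIUM_HINTS):
        score = 0.35
    else:
        score = 0.05  # assumed floor for messages without keywords

    # Assumed definition of "short question": at most 10 characters, ends with ? or ?.
    if len(text) <= 10 and text.endswith(("?", "?")):
        score -= 0.10

    return max(0.0, min(1.0, score))
```

With these assumed lists, `classify_sketch("分析市场趋势,但不要搜索")` counts only "分析", because "搜索" is directly preceded by "不要".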

@ -1,808 +0,0 @@
"""CostAwareRouter unit tests: three-layer cost-aware routing."""

import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.chat.skill_routing import CostAwareRouter, HeuristicClassifier, SkillRoutingResult, _tokenize_content
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.router.intent import IntentRouter, RoutingResult
from agentkit.skills.base import IntentConfig, Skill, SkillConfig


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_skill(
    name: str,
    keywords: list[str] | None = None,
    description: str = "",
    examples: list[str] | None = None,
) -> Skill:
    """Build a Skill with an intent config for tests."""
    config = SkillConfig(
        name=name,
        agent_type="test",
        task_mode="llm_generate",
        prompt={"system": f"You are a {name} skill."},
        intent={
            "keywords": keywords or [],
            "description": description,
            "examples": examples or [],
        },
    )
    return Skill(config=config)


def _make_llm_gateway(response_content: str) -> MagicMock:
    """Build a mock LLMGateway whose chat() returns the given content."""
    gateway = MagicMock()
    gateway.chat = AsyncMock(
        return_value=LLMResponse(
            content=response_content,
            model="test-model",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
    )
    return gateway


def _make_skill_registry(skills: list[Skill] | None = None) -> MagicMock:
    """Build a mock SkillRegistry."""
    registry = MagicMock()
    _skills = skills or []
    registry.list_skills.return_value = _skills

    def _get(name: str):
        for s in _skills:
            if s.name == name:
                return s
        raise KeyError(f"Skill '{name}' not found")

    registry.get = MagicMock(side_effect=_get)
    return registry


def _make_intent_router() -> IntentRouter:
    """Build an IntentRouter without an LLM (keyword matching only)."""
    return IntentRouter(llm_gateway=None, model="default")


# ---------------------------------------------------------------------------
# Layer 0: Rule-based (zero cost)
# ---------------------------------------------------------------------------


class TestLayer0Greeting:
    """Layer 0: greeting pattern matching."""

    @pytest.mark.asyncio
    async def test_chinese_greeting_hits_layer0(self):
        """'你好' hits the Layer 0 greeting rule at zero token cost."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0
        assert result.agent_name == "default"
        assert result.matched is False

    @pytest.mark.asyncio
    async def test_english_greeting_hits_layer0(self):
        """'hello' hits the Layer 0 greeting rule."""
        router = CostAwareRouter()
        result = await router.route(
            content="hello",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_greeting_with_punctuation(self):
        """'你好!' with punctuation also hits Layer 0."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好!",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"


class TestLayer0ChatMode:
    """Layer 0: simple chat mode."""

    @pytest.mark.asyncio
    async def test_thanks_hits_chat_mode(self):
        """'谢谢' hits the Layer 0 simple chat mode."""
        router = CostAwareRouter()
        result = await router.route(
            content="谢谢",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "chat_mode"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_ok_hits_chat_mode(self):
        """'好的' hits the Layer 0 simple chat mode."""
        router = CostAwareRouter()
        result = await router.route(
            content="好的",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "chat_mode"


class TestLayer0ExplicitSkill:
    """Layer 0: explicit @skill: prefix."""

    @pytest.mark.asyncio
    async def test_skill_prefix_hits_layer0(self):
        """'@skill:search 搜索XX' hits the Layer 0 explicit-skill rule at zero token cost."""
        search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
        registry = _make_skill_registry([search_skill])
        # The IntentRouter needs LLM fallback support
        gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=gateway, model="default")
        router = CostAwareRouter()
        result = await router.route(
            content="@skill:search 搜索XX",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.matched is True
        assert result.skill_name == "search"
        assert result.complexity == 0.0


# ---------------------------------------------------------------------------
# Layer 1: LLM quick classify (~100 tokens)
# ---------------------------------------------------------------------------


class TestLayer1Classification:
    """Layer 1: LLM quick classification."""

    @pytest.mark.asyncio
    async def test_medium_complexity_routes_via_intent_router(self):
        """'分析下这个数据' goes through Layer 1 LLM classification; medium complexity is routed via IntentRouter."""
        # LLM returns medium complexity
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])
        # IntentRouter also needs an LLM
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert 0.3 <= result.complexity <= 0.7

    @pytest.mark.asyncio
    async def test_low_complexity_routes_to_default(self):
        """Low complexity (<0.3) routes to the default agent."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.1}))
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        result = await router.route(
            content="随便聊聊",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity < 0.3
        assert result.match_method == "low_complexity"
        assert result.agent_name == "default"

    @pytest.mark.asyncio
    async def test_no_llm_gateway_defaults_to_medium(self):
        """Without an LLM gateway, quick_classify returns 0.5 (medium complexity)."""
        router = CostAwareRouter(llm_gateway=None)
        complexity = await router.quick_classify("分析下这个数据")
        assert complexity == 0.5

    @pytest.mark.asyncio
    async def test_llm_malformed_response_defaults_to_medium(self):
        """When the LLM returns non-JSON, quick_classify returns 0.5."""
        gateway = _make_llm_gateway("这不是JSON")
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        complexity = await router.quick_classify("分析下这个数据")
        assert complexity == 0.5

    @pytest.mark.asyncio
    async def test_complexity_clamped_to_0_1(self):
        """Complexity is clamped to the [0.0, 1.0] range."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 1.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        complexity = await router.quick_classify("超级复杂任务")
        assert complexity == 1.0
        gateway2 = _make_llm_gateway(json.dumps({"complexity": -0.5}))
        router2 = CostAwareRouter(llm_gateway=gateway2, model="default")
        complexity2 = await router2.quick_classify("简单任务")
        assert complexity2 == 0.0


# ---------------------------------------------------------------------------
# Layer 2: Capability matching / Auction
# ---------------------------------------------------------------------------


class TestLayer2CapabilityMatching:
    """Layer 2: capability matching / auction."""

    @pytest.mark.asyncio
    async def test_high_complexity_triggers_capability_matching(self):
        """'做市场调研+竞品分析' has complexity > 0.7 and triggers capability matching."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="market-researcher")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        assert result.match_method == "capability"
        assert result.agent_name == "market-researcher"
        assert result.matched is True

    @pytest.mark.asyncio
    async def test_layer2_with_org_context_object(self):
        """When org_context.find_best_agent returns an object, its name attribute is used."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.9}))
        agent_obj = MagicMock()
        agent_obj.name = "analyst-agent"
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=agent_obj)
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.agent_name == "analyst-agent"

    @pytest.mark.asyncio
    async def test_layer2_without_org_context_falls_back_to_intent_router(self):
        """Without org_context, Layer 2 falls back to IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=None)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        # Falls back to IntentRouter, which may match a skill or use the default
        assert result.match_method in ("capability", "keyword", "llm", "intent_router_fallback", None)

    @pytest.mark.asyncio
    async def test_layer2_org_context_find_best_agent_returns_none(self):
        """When org_context.find_best_agent returns None, fall back to IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=None)
        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7

    @pytest.mark.asyncio
    async def test_auction_disabled_by_default(self):
        """Auction mode is disabled by default."""
        router = CostAwareRouter()
        assert router._auction_enabled is False

    @pytest.mark.asyncio
    async def test_auction_can_be_enabled(self):
        """Auction mode can be enabled manually."""
        router = CostAwareRouter(auction_enabled=True)
        assert router._auction_enabled is True


# ---------------------------------------------------------------------------
# Transparency
# ---------------------------------------------------------------------------


class TestTransparency:
    """Transparency level switching."""

    @pytest.mark.asyncio
    async def test_silent_mode_no_trace(self):
        """SILENT mode does not expose the routing trace."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="SILENT",
        )
        assert result.execution_trace == []
        assert result.transparency_level == "SILENT"

    @pytest.mark.asyncio
    async def test_verbose_mode_shows_trace(self):
        """VERBOSE mode shows the routing trace."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="VERBOSE",
        )
        assert len(result.execution_trace) > 0
        assert result.execution_trace[0]["layer"] == 0
        assert result.execution_trace[0]["method"] == "greeting"
        assert result.transparency_level == "VERBOSE"

    @pytest.mark.asyncio
    async def test_trace_mode_shows_full_trace(self):
        """TRACE mode shows the full routing trace."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="analyst")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="TRACE",
        )
        assert len(result.execution_trace) > 0
        # Should contain records for Layer 1 quick_classify and Layer 2
        layers = [t["layer"] for t in result.execution_trace]
        assert 1 in layers  # Layer 1 quick_classify
        assert 2 in layers  # Layer 2 capability matching
        assert result.transparency_level == "TRACE"

    @pytest.mark.asyncio
    async def test_default_transparency_is_silent(self):
        """Default transparency is SILENT."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.transparency_level == "SILENT"
        assert result.execution_trace == []


# ---------------------------------------------------------------------------
# SkillRoutingResult new fields
# ---------------------------------------------------------------------------


class TestSkillRoutingResultNewFields:
    """Validation of the new SkillRoutingResult fields."""

    def test_default_transparency_level(self):
        result = SkillRoutingResult()
        assert result.transparency_level == "SILENT"

    def test_default_execution_trace(self):
        result = SkillRoutingResult()
        assert result.execution_trace == []

    def test_default_complexity(self):
        result = SkillRoutingResult()
        assert result.complexity == 0.0

    def test_new_fields_backward_compatible(self):
        """New fields do not break old code that constructs SkillRoutingResult."""
        result = SkillRoutingResult(
            skill_name="test",
            matched=True,
            match_method="keyword",
        )
        assert result.transparency_level == "SILENT"
        assert result.execution_trace == []
        assert result.complexity == 0.0


# ---------------------------------------------------------------------------
# _tokenize_content: enhanced Chinese tokenization
# ---------------------------------------------------------------------------


class TestTokenizeContent:
    """Tests for enhanced Chinese tokenization in _tokenize_content."""

    def test_chinese_content(self):
        """Chinese content: '帮我做数据分析' should contain 2-grams related to '数据分析'."""
        tokens = _tokenize_content("帮我做数据分析")
        # No punctuation splits the text, so 2-grams are generated: 帮我, 我做, 做数, 数据, 据分, 分析
        assert "数据" in tokens or "数据分析" in tokens

    def test_english_content(self):
        """English content: 'help with code generation' should contain 'code', 'generation', or 'code generation'."""
        tokens = _tokenize_content("help with code generation")
        assert "code" in tokens or "generation" in tokens or "code generation" in tokens

    def test_mixed_content(self):
        """Mixed Chinese/English: '用python做data analysis' should contain a 'python'-related token and 'data analysis'."""
        tokens = _tokenize_content("用python做data analysis")
        # Splitting on spaces leaves "用python做data" as one segment, which yields 2-grams
        # "analysis" is its own segment
        assert "analysis" in tokens
        # "用python做data" is longer than 4 characters, so 2-grams are generated, including python-related fragments
        has_python_related = any("python" in t for t in tokens)
        assert has_python_related or "data analysis" in tokens

    def test_stopwords_filtered(self):
        """Stopword filtering: a short phrase of pure stopwords should yield no or very few tokens."""
        tokens = _tokenize_content("我的一个")
        # "我的一个" has length 4 and is kept as a whole (it is not in the stopword set),
        # but stopword 2-grams such as "我的", "的一", "一个" are filtered out
        assert len(tokens) <= 1

    def test_bigram_generation(self):
        """2-gram generation: '机器学习模型训练' should contain every 2-gram."""
        tokens = _tokenize_content("机器学习模型训练")
        expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"]
        for bigram in expected_bigrams:
            assert bigram in tokens, f"Missing 2-gram: {bigram}"


# ---------------------------------------------------------------------------
# HeuristicClassifier
# ---------------------------------------------------------------------------


class TestHeuristicClassifier:
    """Tests for the local heuristic classifier, HeuristicClassifier."""

    def setup_method(self):
        self.classifier = HeuristicClassifier()

    def test_short_greeting_low_complexity(self):
        """Short greeting → low complexity."""
        score = self.classifier.classify("你好呀")
        assert score < 0.3

    def test_simple_question_medium_complexity(self):
        """Simple question containing '如何' → medium complexity."""
        score = self.classifier.classify("如何使用这个功能?")
        assert 0.3 <= score <= 0.7

    def test_tool_request_high_complexity(self):
        """Request with tool keywords → high complexity."""
        score = self.classifier.classify("帮我搜索一下最新的新闻")
        assert score > 0.5

    def test_code_request_high_complexity(self):
        """Code-related request → high complexity."""
        score = self.classifier.classify("写一个Python函数实现快速排序")
        assert score > 0.6

    def test_multi_step_request_high_complexity(self):
        """Multi-step analysis request → high complexity."""
        score = self.classifier.classify("分析这个数据,比较不同方案的优缺点,然后给出推荐")
        assert score > 0.7

    def test_empty_string_zero_complexity(self):
        """Empty string → zero complexity."""
        assert self.classifier.classify("") == 0.0
        assert self.classifier.classify(" ") == 0.0

    def test_long_message_higher_complexity(self):
        """Longer message → higher complexity."""
        short = "帮我查一下"
        long = "帮我查一下" + "关于机器学习和深度学习的最新进展" * 10
        assert self.classifier.classify(long) > self.classifier.classify(short)

    def test_code_patterns_boost_complexity(self):
        """Code patterns (backticks/parentheses) raise complexity."""
        with_code = "运行这段代码 `print('hello')`"
        without_code = "运行这段代码 print hello"
        assert self.classifier.classify(with_code) > self.classifier.classify(without_code)

    def test_score_bounded_0_to_1(self):
        """Complexity always stays within [0.0, 1.0]."""
        test_inputs = [
            "", "你好", "如何做", "帮我搜索并分析数据,设计一个完整的解决方案,包含代码实现和部署配置",
        ]
        for inp in test_inputs:
            score = self.classifier.classify(inp)
            assert 0.0 <= score <= 1.0, f"Score {score} out of range for '{inp}'"


class TestHeuristicClassifierIntegration:
    """Integration tests for HeuristicClassifier inside CostAwareRouter."""

    @pytest.mark.asyncio
    async def test_heuristic_mode_no_llm_call(self):
        """In heuristic mode with merged_llm_classify=False, the LLM is not called."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic", merged_llm_classify=False)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # LLM gateway.chat must not be called (heuristic + merged disabled)
        gateway.chat.assert_not_called()
        # Complexity should come from the heuristic classifier
        assert result.complexity > 0.0

    @pytest.mark.asyncio
    async def test_llm_mode_uses_llm(self):
        """In llm mode, the LLM quick_classify is called."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="llm")
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # LLM gateway.chat should be called
        gateway.chat.assert_called()

    @pytest.mark.asyncio
    async def test_heuristic_greeting_still_layer0(self):
        """In heuristic mode, greetings still go through Layer 0."""
        router = CostAwareRouter(classifier="heuristic")
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_heuristic_default_classifier_mode(self):
        """The default classifier mode is heuristic."""
        router = CostAwareRouter()
        assert router._classifier == "heuristic"


# ---------------------------------------------------------------------------
# U1: Merged LLM Classify
# ---------------------------------------------------------------------------


class TestMergedLLMClassify:
    """Tests for the merged routing LLM call."""

    @pytest.mark.asyncio
    async def test_merged_classify_returns_valid_skill(self):
        """The merged call returns valid JSON with a skill_hint and routes to that skill."""
        merged_response = json.dumps({
            "complexity": 0.6,
            "intent": "code_generation",
            "skill_hint": "search",
        })
        gateway = _make_llm_gateway(merged_response)
        search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
        registry = _make_skill_registry([search_skill])
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我搜索一下最新的新闻",
            skill_registry=registry,
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.matched is True
        assert result.skill_name == "search"
        assert result.match_method == "merged_llm"
        assert result.complexity > 0.3

    @pytest.mark.asyncio
    async def test_merged_classify_malformed_response_fallback(self):
        """A malformed merged response falls back to the default agent."""
        gateway = _make_llm_gateway("这不是JSON")
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "merged_llm_fallback"
        assert result.complexity == 0.5

    @pytest.mark.asyncio
    async def test_merged_classify_low_complexity(self):
        """The merged call returns complexity < 0.3 and takes the low-complexity route."""
        merged_response = json.dumps({
            "complexity": 0.2,
            "intent": "greeting",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="如何使用这个功能?",  # heuristic returns ~0.45, triggers merged call
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Merged LLM returned complexity < 0.3, should route to low complexity
        assert result.complexity < 0.3
        assert "low" in result.match_method or "merged_llm_low" in result.match_method

    @pytest.mark.asyncio
    async def test_merged_classify_high_complexity(self):
        """The merged call returns complexity > 0.7 and goes to Layer 2."""
        merged_response = json.dumps({
            "complexity": 0.85,
            "intent": "research",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="researcher")
        router = CostAwareRouter(
            llm_gateway=gateway, model="default",
            org_context=org_context, merged_llm_classify=True,
        )
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        assert result.match_method == "capability"

    @pytest.mark.asyncio
    async def test_merged_classify_disabled_falls_back_to_intent_router(self):
        """With merged_llm_classify=False, fall back to the standalone IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(
            llm_gateway=gateway, model="default",
            merged_llm_classify=False,
        )
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Should not use merged_llm match_method
        assert result.match_method != "merged_llm"

    @pytest.mark.asyncio
    async def test_merged_classify_no_llm_gateway_falls_back(self):
        """Without an LLM gateway, _classify_merged falls back to IntentRouter."""
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])
        router = CostAwareRouter(llm_gateway=None, merged_llm_classify=True)
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Should not crash, should use IntentRouter fallback
        assert result is not None

    @pytest.mark.asyncio
    async def test_merged_classify_skill_hint_not_found_fallback(self):
        """A skill_hint that is not in the registry triggers the fallback."""
        merged_response = json.dumps({
            "complexity": 0.5,
            "intent": "unknown",
            "skill_hint": "nonexistent_skill",
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Should fallback to default agent (medium complexity, no skill match)
        assert result.matched is False
        assert result.match_method == "merged_llm_medium"

    @pytest.mark.asyncio
    async def test_merged_classify_only_one_llm_call(self):
        """In merged mode, medium complexity produces exactly one LLM call."""
        merged_response = json.dumps({
            "complexity": 0.5,
            "intent": "question",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        await router.route(
            content="如何使用这个功能?",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Only 1 LLM call should have been made (the merged classify)
        assert gateway.chat.call_count == 1
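
The Layer 1 tests above describe three behaviours of `quick_classify`: a malformed LLM reply yields 0.5, out-of-range values are clamped to [0.0, 1.0], and the score is then bucketed at 0.3 and 0.7. A minimal sketch of that parsing and bucketing is shown below. The function names `parse_complexity` and `tier` are hypothetical and the code is not taken from the deleted `CostAwareRouter`; how the real router treats scores exactly equal to 0.3 or 0.7 is also an assumption.

```python
import json


def parse_complexity(raw: str, default: float = 0.5) -> float:
    """Extract a complexity score from an LLM reply, clamped to [0.0, 1.0].

    Falls back to `default` when the reply is not JSON or lacks a usable
    "complexity" field.
    """
    try:
        value = float(json.loads(raw)["complexity"])
    except (ValueError, TypeError, KeyError):
        # json.JSONDecodeError is a subclass of ValueError.
        return default
    return max(0.0, min(1.0, value))


def tier(complexity: float) -> str:
    """Bucket a score: below 0.3 is low, above 0.7 is high, otherwise medium."""
    if complexity < 0.3:
        return "low"
    if complexity > 0.7:
        return "high"
    return "medium"
```

For example, `tier(parse_complexity('{"complexity": 0.85}'))` selects the high bucket, which in the tests corresponds to Layer 2 capability matching.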

@ -1,219 +0,0 @@
"""Unit tests for Semantic Router (U3)."""

import pytest

from agentkit.chat.semantic_router import (
    SemanticRouteResult,
    SkillEmbeddingIndex,
    SemanticRouter,
)
from agentkit.memory.embedder import MockEmbedder


def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]:
    """Create a unit vector for similarity testing."""
    vec = [base_val] * dim
    magnitude = sum(x**2 for x in vec) ** 0.5
    return [x / magnitude for x in vec] if magnitude > 0 else vec


class MockSkill:
    """Mock skill for testing."""

    def __init__(self, name: str, description: str = "", keywords: list[str] | None = None, capabilities: list[str] | None = None):
        self.name = name
        self.config = MockSkillConfig(
            name=name,
            description=description,
            keywords=keywords or [],
            capabilities=capabilities or [],
        )


class MockSkillConfig:
    """Mock skill config for testing."""

    def __init__(self, name: str, description: str = "", keywords: list[str] | None = None, capabilities: list[str] | None = None):
        self.name = name
        self.description = description
        self.intent = MockIntentConfig(keywords=keywords or [])
        self.capabilities = [MockCapabilityTag(tag=t) for t in (capabilities or [])]


class MockIntentConfig:
    def __init__(self, keywords: list[str] | None = None):
        self.keywords = keywords or []


class MockCapabilityTag:
    def __init__(self, tag: str):
        self.tag = tag


class MockSkillRegistry:
    """Mock skill registry for testing."""

    def __init__(self, skills: list[MockSkill] | None = None):
        self._skills = {s.name: s for s in (skills or [])}

    def list_skills(self):
        return list(self._skills.values())

    def get(self, name: str):
        if name not in self._skills:
            raise KeyError(f"Skill '{name}' not found")
        return self._skills[name]


class TestSkillEmbeddingIndex:
    @pytest.mark.asyncio
    async def test_build_from_registry(self):
        embedder = MockEmbedder(dimension=64)
        index = SkillEmbeddingIndex(embedder)
        skills = [
            MockSkill("content_gen", description="生成文章内容", keywords=["写作", "文章"], capabilities=["content"]),
            MockSkill("data_analysis", description="数据分析与可视化", keywords=["分析", "数据"], capabilities=["analytics"]),
        ]
        registry = MockSkillRegistry(skills)
        await index.build(registry)
        assert index.size == 2

    @pytest.mark.asyncio
    async def test_search_returns_results(self):
        embedder = MockEmbedder(dimension=64)
        index = SkillEmbeddingIndex(embedder)
        skill = MockSkill("content_gen", description="生成文章内容")
        await index.update_skill("content_gen", skill)
        # MockEmbedder produces deterministic embeddings based on text hash
        # Different text → different embedding
        query_emb = await embedder.embed("生成文章")
        results = await index.search(query_emb)
        assert len(results) >= 1
        assert results[0][0] == "content_gen"  # skill_name
        assert results[0][1] > 0.0  # similarity

    @pytest.mark.asyncio
    async def test_search_empty_index(self):
        embedder = MockEmbedder(dimension=64)
        index = SkillEmbeddingIndex(embedder)
        query_emb = await embedder.embed("test")
        results = await index.search(query_emb)
        assert results == []

    @pytest.mark.asyncio
    async def test_remove_skill(self):
        embedder = MockEmbedder(dimension=64)
        index = SkillEmbeddingIndex(embedder)
        skill = MockSkill("test_skill", description="Test")
        await index.update_skill("test_skill", skill)
        assert index.size == 1
        index.remove_skill("test_skill")
        assert index.size == 0

    def test_build_source_text_with_description(self):
        skill = MockSkill("test", description="A test skill", keywords=["test"], capabilities=["testing"])
        text = SkillEmbeddingIndex._build_source_text(skill)
        assert "A test skill" in text
        assert "test" in text
        assert "testing" in text

    def test_build_source_text_fallback_to_name(self):
        skill = MockSkill("my_skill", description="", keywords=[], capabilities=[])
        text = SkillEmbeddingIndex._build_source_text(skill)
        assert "my_skill" in text


class TestSemanticRouter:
    @pytest.mark.asyncio
    async def test_high_confidence_match(self):
        """With thresholds lowered to 0.5/0.3, routing returns a well-formed result."""
        embedder = MockEmbedder(dimension=64)
        router = SemanticRouter(embedder, similarity_high=0.5, similarity_low=0.3)
        # Add a skill with known embedding
        skill = MockSkill("content_gen", description="生成文章内容")
        await router.update_skill("content_gen", skill)
        # Query with same text should produce very similar embedding (MockEmbedder is hash-based)
        # With low thresholds, even moderate similarity will be "high"
        result = await router.route("生成文章内容")
        # MockEmbedder may or may not produce high similarity for different text
        # Just verify the result structure
        assert result.confidence in ("high", "medium", "low")
        assert isinstance(result.similarity, float)

    @pytest.mark.asyncio
    async def test_low_confidence_empty_index(self):
        """Empty index returns low confidence."""
        embedder = MockEmbedder(dimension=64)
        router = SemanticRouter(embedder)
        result = await router.route("任何查询")
        assert result.confidence == "low"
        assert result.skill_name is None
        assert result.similarity == 0.0

    @pytest.mark.asyncio
    async def test_medium_confidence_zone(self):
        """With thresholds widened to 0.01/0.99, most matches fall in the medium zone."""
        embedder = MockEmbedder(dimension=64)
        router = SemanticRouter(embedder, similarity_high=0.99, similarity_low=0.01)
        skill = MockSkill("content_gen", description="生成文章内容")
        await router.update_skill("content_gen", skill)
        # With very high similarity_high and very low similarity_low,
        # most matches will be "medium"
        result = await router.route("生成文章")
        # Expected to be medium (high threshold is 0.99), but any band is accepted
        # because MockEmbedder similarity is not guaranteed
        assert result.confidence in ("medium", "low", "high")

    @pytest.mark.asyncio
    async def test_embedder_failure_graceful(self):
        """Embedder failure returns low confidence."""

        class FailingEmbedder(MockEmbedder):
            async def embed(self, text):
                raise RuntimeError("Embedding API failed")

        router = SemanticRouter(FailingEmbedder(dimension=64))
        result = await router.route("test query")
        assert result.confidence == "low"
        assert result.skill_name is None

    @pytest.mark.asyncio
    async def test_build_index_from_registry(self):
        """Build index from skill registry."""
        embedder = MockEmbedder(dimension=64)
        router = SemanticRouter(embedder)
        skills = [
            MockSkill("skill_a", description="Skill A"),
            MockSkill("skill_b", description="Skill B"),
        ]
        registry = MockSkillRegistry(skills)
        await router.build_index(registry)
        assert router._index.size == 2

    @pytest.mark.asyncio
    async def test_chinese_query(self):
        """Chinese query works with semantic router."""
        embedder = MockEmbedder(dimension=64)
        router = SemanticRouter(embedder, similarity_high=0.01, similarity_low=0.001)
        skill = MockSkill("geo_optimizer", description="地理内容优化", keywords=["优化", "SEO", "地理"], capabilities=["optimization"])
        await router.update_skill("geo_optimizer", skill)
        result = await router.route("帮我优化内容")
        # With very low thresholds, should match
        assert result.confidence in ("high", "medium")
        assert result.skill_name == "geo_optimizer"
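
The `SemanticRouter` tests above rely on two thresholds, `similarity_high` and `similarity_low`, that split a similarity score into high, medium, and low confidence. A minimal sketch of that banding follows, assuming cosine similarity between embeddings. The helper names are hypothetical, the 0.85 and 0.6 defaults are taken from the figures quoted in the original test docstrings rather than from the deleted implementation, and whether the real comparison is inclusive at the thresholds is an assumption.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors; 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def confidence_band(similarity: float, high: float = 0.85, low: float = 0.6) -> str:
    """Map a similarity score to "high", "medium", or "low" confidence."""
    if similarity >= high:
        return "high"
    if similarity >= low:
        return "medium"
    return "low"
```

Lowering both thresholds, as `test_chinese_query` does with 0.01 and 0.001, pushes almost any positive similarity into the high band, which is why that test can assert a match with a hash-based mock embedder.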