diff --git a/docs/plans/2026-06-15-002-feat-e2e-capability-improvement-plan.md b/docs/plans/2026-06-15-002-feat-e2e-capability-improvement-plan.md index e8f9c75..a74bf91 100644 --- a/docs/plans/2026-06-15-002-feat-e2e-capability-improvement-plan.md +++ b/docs/plans/2026-06-15-002-feat-e2e-capability-improvement-plan.md @@ -1,7 +1,10 @@ --- title: "feat: E2E能力分析框架改进与路由智能化提升" type: feat -status: active +status: superseded +superseded_by: "2026-06-16-005-refactor-routing-architecture-plan" +superseded_reason: "路由智能化提升部分已被 SimpleRouter 架构简化替代。E2E 能力分析框架部分可合并到 2026-06-16-005 的 U5 回测验证单元。" +closed: 2026-06-16 created: 2026-06-15 plan-depth: standard --- diff --git a/docs/plans/2026-06-15-003-feat-router-intelligence-optimization-plan.md b/docs/plans/2026-06-15-003-feat-router-intelligence-optimization-plan.md index 9b75b62..72e8593 100644 --- a/docs/plans/2026-06-15-003-feat-router-intelligence-optimization-plan.md +++ b/docs/plans/2026-06-15-003-feat-router-intelligence-optimization-plan.md @@ -1,8 +1,9 @@ --- title: "feat: 路由智能化优化 — 复杂度校准、意图消歧、质量门控增强" -status: active -created: 2026-06-15 -updated: 2026-06-15 +status: superseded +superseded_by: "2026-06-16-005-refactor-routing-architecture-plan" +superseded_reason: "SimpleRouter 已替代 CostAwareRouter 的 4 层路由架构。IntentRouter 多候选评分(U2)和 QualityGate 技能匹配验证(U3)属于被删除的旧路由层组件,不再需要实现。U1 HeuristicClassifier 测试仅对向后兼容有价值。" +closed: 2026-06-16 origin: test-results/e2e/capability_report.txt (真实LLM回测分析报告) --- diff --git a/docs/plans/2026-06-15-004-feat-semantic-router-and-benchmark-upgrade-plan.md b/docs/plans/2026-06-15-004-feat-semantic-router-and-benchmark-upgrade-plan.md index fe2f388..4be188d 100644 --- a/docs/plans/2026-06-15-004-feat-semantic-router-and-benchmark-upgrade-plan.md +++ b/docs/plans/2026-06-15-004-feat-semantic-router-and-benchmark-upgrade-plan.md @@ -2,9 +2,10 @@ ```yaml title: feat: SemanticRouter 启用与回测体系升级 -status: active -created: 2026-06-15 -plan_id: "2026-06-15-004" +status: superseded +superseded_by: "2026-06-16-005-refactor-routing-architecture-plan" +superseded_reason: "SimpleRouter 已替代 CostAwareRouter,不再需要 SemanticRouter 作为路由层组件。LLM 在 REACT agent loop 中看到完整工具描述后自主决策,无需 embedding 做意图路由。如未来工具数量 >50,可参考 Codex 的 tool_search(BM25)做工具发现。" +closed: 2026-06-16 ``` ## Summary @@ -187,6 +188,12 @@ SemanticRouter 已完整实现(`src/agentkit/chat/semantic_router.py`),但 - 对抗性输入测试 - 意图分类微调流水线 - 关键词自动扩充工具 +- **[待办] 提供 Embedding API Key** — 当前百炼 Coding Plan key (sk-sp-*) 不支持 embedding API,需要: + - 方案 A:在 DashScope 控制台单独开通 embedding 服务,获取标准 sk- 前缀 key + - 方案 B:配置本地 embedding 模型(BGE-large-zh-v1.5 via Xinference) + - 方案 C:使用其他支持 embedding 的 provider(智谱/火山等) + - 配置位置:agentkit.yaml → llm.providers 增加 embedding 专用 provider + - 预期效果:SemanticRouter 真正工作,口语化查询无需手动扩充关键词 ## Risks diff --git a/docs/plans/2026-06-16-006-refactor-architecture-optimization-evolution-plan.md b/docs/plans/2026-06-16-006-refactor-architecture-optimization-evolution-plan.md new file mode 100644 index 0000000..19f473b --- /dev/null +++ b/docs/plans/2026-06-16-006-refactor-architecture-optimization-evolution-plan.md @@ -0,0 +1,429 @@ +```yaml +title: "refactor: AgentKit 架构优化演进 — 对齐业界最佳实践" +status: active +plan_id: "2026-06-16-006" +created: 2026-06-16 +depth: deep +origin: "基于对 Codex/Claude Code/Trae/Qoder 的深入分析,优化 AgentKit 架构" +``` + +# AgentKit 架构优化演进计划 + +## Summary + +基于对 Codex CLI、Claude Code、Trae Agent 2.0、Qoder 的深入分析,对 AgentKit 进行架构优化演进。核心变更:(1) 消除"路由"概念改为请求预处理,(2) 专家团从去中心化协作简化为 hub-and-spoke,(3) PlanExec 简化为 Spec-Driven 模式,(4) 聊天记录 SQLite 持久化,(5) 删除旧路由层代码,(6) 新增可验证执行和工具描述精简。 + +## Problem Frame + +AgentKit 当前架构存在三类问题: + +1. **概念误导**:SimpleRouter 仍暗示"路由层"存在,但实际只做 @skill 前缀解析 + greeting fast-path,核心决策已交给 REACT agent loop 中的 LLM +2. **过度设计**:ExpertTeam 的去中心化协作(CollaborationPlan + HandoffTransport + SharedWorkspace + 3 种 MergeStrategy)远超业界实践,增加 ~1500 行代码但无对应价值 +3. **能力缺口**:PlanExecEngine 默认使用 _LLMStepExecutor(纯 LLM 调用无工具),聊天记录不持久化,无自动验证循环 + +## Requirements + +- R1: 消除"路由"概念,SimpleRouter 重命名为 RequestPreprocessor,语义从"路由决策"变为"请求预处理" +- R2: 专家团简化为 hub-and-spoke 模式(Lead Expert + 并行 Task,深度=1),删除 CollaborationPlan/HandoffTransport/SharedWorkspace 的 ExpertTeam 专用逻辑 +- R3: PlanExecEngine 默认使用 ReActStepExecutor,删除 _LLMStepExecutor +- R4: 聊天记录 SQLite 持久化(参考 Codex 的 Thread 持久化) +- R5: 删除 CostAwareRouter 及相关代码(HeuristicClassifier、IntentRouter、QualityGate、SemanticRouter) +- R6: 新增可验证执行(Test-and-Verify 循环,参考 Codex Cloud) +- R7: 工具描述分层注入(核心工具全量 + 扩展工具一行描述 + tool_search 按需获取) +- R8: 默认启用上下文压缩 +- R9: 新增 Spec 文档作为一等公民(参考 Qoder Quest Mode) +- R10: 统一事件模型(SQ/EQ 双队列,参考 Codex) + +## Key Technical Decisions + +### KTD1: SimpleRouter 重命名为 RequestPreprocessor + +**决策**:将 SimpleRouter 重命名为 RequestPreprocessor,route() 方法重命名为 preprocess() + +**依据**: +- Codex 无路由层,Claude Code 无路由层,Trae 移除了 Proposal 阶段——业界共识是没有路由层 +- SimpleRouter 的 3 个功能(@skill 前缀、greeting regex、default REACT)都是预处理而非路由决策 +- "路由"概念误导开发者认为存在意图预测层,实际上 LLM 在 agent loop 中自主决策 + +**替代方案**:完全删除 SimpleRouter,所有请求直接进入 REACT loop。被否决——greeting fast-path 每次请求节省 ~100 tokens + 500ms,@skill 前缀是用户显式指令需要代码级解析 + +### KTD2: 专家团 hub-and-spoke 模式 + +**决策**:ExpertTeam 从去中心化协作简化为 Lead Expert + 并行 Task(深度=1) + +**依据**: +- Claude Code:Task 工具深度=1,子 Agent 不能再生子 Agent +- Codex:spawn_agent 层级式,结果返回父 Agent +- Qoder:多专家并行独立执行,主 Agent 汇总 +- 去中心化协作的通信复杂度 O(N²),hub-and-spoke 为 O(N) +- 同一 LLM 扮演不同"专家"不产生真正的观点多样性,等价于多次采样+合并 + +**保留**:ExpertConfig/ExpertTemplate/Registry(定义专家 persona)、BEST 合并策略(Lead Agent 选择最佳结果) + +**删除**:CollaborationPlan 的 phase 依赖图、HandoffTransport 的 Agent 间通信、SharedWorkspace 的跨阶段状态共享、VOTE/FUSION 合并策略 + +### KTD3: PlanExec 默认 ReActStepExecutor + Spec-Driven + +**决策**:默认使用 ReActStepExecutor(已实现),删除 _LLMStepExecutor;新增 Spec 文档持久化 + +**依据**: +- _LLMStepExecutor 不支持工具调用 = 没有执行能力 +- ReActStepExecutor 已实现并使用 ReActEngine,支持工具调用和多步推理 +- Qoder Quest Mode:Spec First 是人和 AI 的契约,用户确认后再执行 +- 当前 PlanExec 的计划对用户不可见,用户无法在执行前纠正方向 + +### KTD4: 聊天记录 SQLite 持久化 + +**决策**:使用 SQLite 做聊天记录持久化(参考 Codex 的 Thread 持久化) + +**依据**: +- Codex 使用 SQLite(轻量、零配置、跨平台),支持 resume/fork/archive +- Claude Code 使用 append-only JSONL(更简单但搜索能力弱) +- 当前 ConversationStore 纯内存,服务重启后丢失 +- SQLite 支持会话搜索、分页加载,且无需额外服务 + +--- + +## Implementation Units + +### U1. SimpleRouter 重命名为 RequestPreprocessor + +**Goal**: 消除"路由"概念,将 SimpleRouter 重命名为 RequestPreprocessor + +**Dependencies**: 无 + +**Files**: +- `src/agentkit/chat/simple_router.py` → 重命名为 `src/agentkit/chat/request_preprocessor.py` +- `src/agentkit/chat/skill_routing.py` — 更新引用 +- `src/agentkit/server/routes/portal.py` — 更新 import 和调用 +- `src/agentkit/server/routes/chat.py` — 更新 import 和调用 +- `src/agentkit/server/app.py` — 更新 import 和调用 +- `tests/unit/chat/test_simple_router.py` → 重命名并更新 + +**Approach**: +1. 创建 `request_preprocessor.py`,类名 `RequestPreprocessor`,方法 `route()` → `preprocess()` +2. `_is_direct_chat()` 重命名为 `_is_trivial_input()` +3. `SkillRoutingResult` 保留(它是数据结构,不涉及路由概念) +4. 更新所有调用点 +5. 删除旧文件 + +**Patterns to follow**: 现有 SimpleRouter 的代码结构 + +**Test scenarios**: +- @skill:xxx 前缀正确解析为 SKILL_REACT 模式 +- Greeting regex 匹配返回 DIRECT_CHAT +- 默认输入返回 REACT 模式 +- 未知 skill 回退到 REACT +- preprocess() 方法签名与 route() 兼容 + +**Verification**: `ruff check src/ && pytest tests/unit/chat/ -v` + +--- + +### U2. 删除 CostAwareRouter 及相关代码 + +**Goal**: 删除已被 SimpleRouter 替代的旧路由层代码 + +**Dependencies**: U1 + +**Files**: +- `src/agentkit/router/` 目录下大部分文件(保留 __init__.py 和必要的导出) +- `src/agentkit/server/app.py` — 清理注释掉的引用 +- `src/agentkit/chat/cost_aware_router.py` — 删除 + +**Approach**: +1. 确认 CostAwareRouter 在代码中无活跃引用(app.py 已注释) +2. 删除 `cost_aware_router.py`、`heuristic_classifier.py`、`intent.py`(IntentRouter)、`quality_gate.py`、`semantic_router.py` +3. 保留 `router/__init__.py` 导出必要的类型(如 ExecutionMode,如果前端依赖) +4. 清理 app.py 中的注释引用 +5. 更新 `router/` 目录的 `__init__.py` + +**Test scenarios**: +- 删除后 `ruff check` 无错误 +- `pytest -m "not integration"` 全部通过 +- 无 import 错误 + +**Verification**: `ruff check src/ && pytest -m "not integration" -x` + +--- + +### U3. 专家团简化为 hub-and-spoke + +**Goal**: 将 ExpertTeam 从去中心化协作简化为 Lead Expert + 并行 Task 模式 + +**Dependencies**: 无(可与 U1/U2 并行) + +**Files**: +- `src/agentkit/experts/orchestrator.py` — 重写为 hub-and-spoke +- `src/agentkit/experts/team.py` — 简化,移除 CollaborationPlan 依赖 +- `src/agentkit/experts/plan.py` — 简化,保留 MergeStrategy.BEST +- `src/agentkit/core/handoff_transport.py` — 移除 ExpertTeam 专用逻辑 +- `src/agentkit/core/shared_workspace.py` — 移除 ExpertTeam 专用逻辑 +- `src/agentkit/experts/router.py` — 简化为 @team 前缀 + RequestPreprocessor 集成 +- `tests/unit/experts/test_orchestrator.py` — 更新 + +**Approach**: +1. 重写 TeamOrchestrator:Lead Expert 自主规划 + 并行 spawn Task +2. 删除 CollaborationPlan 的 phase 依赖图,Lead Expert 自主决定执行顺序 +3. 删除 HandoffTransport 的 Agent 间通信,Task 结果直接返回 Lead Expert +4. 删除 SharedWorkspace 的跨阶段状态共享,Lead Expert 持有所有状态 +5. 保留 MergeStrategy.BEST(Lead Agent 选择最佳结果),删除 VOTE/FUSION +6. 简化 ExpertTeamRouter 为 @team 前缀触发 +7. 保留 ExpertConfig/ExpertTemplate/Registry 不变 + +**Technical design**: + +``` +新 TeamOrchestrator 流程: +1. 用户输入 → @team:xxx 前缀 → ExpertTeamMode +2. Lead Expert 接收任务,自主分解为子任务 +3. 并行 spawn Task(每个 Task 是独立 ReActEngine 实例) +4. 等待所有 Task 完成 +5. Lead Expert 汇总结果(BEST 策略) +6. 返回最终结果 + +约束: +- Task 深度=1(Task 不能再 spawn Task) +- Task 之间无通信 +- Lead Expert 持有所有状态 +``` + +**Test scenarios**: +- Lead Expert 正确分解任务为子任务 +- 并行 Task 独立执行并返回结果 +- Lead Expert 汇总结果 +- 单个 Task 失败不影响其他 Task +- 所有 Task 失败时回退到 Lead Expert 单独执行 +- @team 前缀正确触发 ExpertTeamMode + +**Verification**: `ruff check src/ && pytest tests/unit/experts/ -v` + +--- + +### U4. PlanExec 默认 ReActStepExecutor + 删除 _LLMStepExecutor + +**Goal**: PlanExecEngine 默认使用 ReActStepExecutor,删除 _LLMStepExecutor + +**Dependencies**: 无 + +**Files**: +- `src/agentkit/core/plan_exec_engine.py` — 删除 _LLMStepExecutor 和 _LLMStepAgent,默认 step_executor_type="react" +- `tests/unit/core/test_plan_exec_engine.py` — 更新 + +**Approach**: +1. 删除 `_LLMStepExecutor` 和 `_LLMStepAgent` 类 +2. `_create_executor()` 方法移除 step_executor_type 参数,始终使用 ReActStepExecutor +3. 清理相关 import + +**Test scenarios**: +- PlanExecEngine 默认创建 ReActStepExecutor +- ReActStepExecutor 正确执行带工具调用的步骤 +- 步骤失败时触发重规划 + +**Verification**: `ruff check src/ && pytest tests/unit/core/test_plan_exec_engine.py -v` + +--- + +### U5. 聊天记录 SQLite 持久化 + +**Goal**: 使用 SQLite 持久化聊天记录,服务重启后不丢失 + +**Dependencies**: U1(RequestPreprocessor 重命名完成后更新调用点) + +**Files**: +- `src/agentkit/chat/sqlite_conversation_store.py` — 新建 +- `src/agentkit/server/routes/portal.py` — 替换 ConversationStore +- `src/agentkit/chat/conversation_store.py` — 保留作为接口/内存实现 + +**Approach**: +1. 新建 `SqliteConversationStore`,实现与 `ConversationStore` 相同接口 +2. SQLite 表结构:conversations(id, session_id, role, content, timestamp, metadata) +3. 支持按 session_id 查询、分页加载、搜索 +4. 数据库文件路径:`~/.agentkit/conversations.db` +5. 在 portal.py 中替换 ConversationStore 为 SqliteConversationStore +6. 保留 ConversationStore 作为内存实现(测试用) + +**Test scenarios**: +- 消息正确持久化到 SQLite +- 按 session_id 查询返回完整对话 +- 分页加载正确 +- 服务重启后数据不丢失 +- SQLite 文件不存在时自动创建 + +**Verification**: `ruff check src/ && pytest tests/unit/chat/test_sqlite_conversation_store.py -v` + +--- + +### U6. 可验证执行(Test-and-Verify 循环) + +**Goal**: ReActEngine 执行后可选自动运行项目测试验证结果 + +**Dependencies**: U4 + +**Files**: +- `src/agentkit/core/verification_loop.py` — 新建 +- `src/agentkit/core/react.py` — 集成验证循环 +- `src/agentkit/tools/builtin.py` — 新增 run_tests 工具 + +**Approach**: +1. 新建 `VerificationLoop`:执行后运行 pytest/ruff/typecheck,失败则自动重试 +2. 最大重试次数可配置(默认 2) +3. 验证结果附加到 ReActResult +4. 新增 `run_tests` 内置工具,LLM 可主动调用 +5. 验证循环默认关闭,通过参数 `verification_enabled=True` 启用 + +**Test scenarios**: +- 验证循环关闭时行为不变 +- 验证循环开启时,执行后自动运行测试 +- 测试通过时返回成功 +- 测试失败时自动重试 +- 达到最大重试次数后返回失败结果 + +**Verification**: `ruff check src/ && pytest tests/unit/core/test_verification_loop.py -v` + +--- + +### U7. 工具描述分层注入 + tool_search + +**Goal**: 核心工具全量注入,扩展工具只注入名称+一行描述,LLM 可通过 tool_search 获取完整描述 + +**Dependencies**: 无 + +**Files**: +- `src/agentkit/core/react.py` — 修改 `_build_tool_use_prompt` +- `src/agentkit/tools/builtin.py` — 新增 tool_search 工具 +- `src/agentkit/tools/search.py` — 新建,BM25 工具搜索 + +**Approach**: +1. 工具分为 core(read/write/bash/search)和 extended(其余) +2. core 工具全量注入 prompt +3. extended 工具只注入 name + one-line description +4. 新增 `tool_search` 工具:BM25 搜索工具描述,返回完整描述 +5. LLM 在 agent loop 中按需调用 tool_search + +**Test scenarios**: +- core 工具全量出现在 prompt 中 +- extended 工具只出现名称和一行描述 +- tool_search 正确返回工具完整描述 +- BM25 搜索相关性排序 + +**Verification**: `ruff check src/ && pytest tests/unit/tools/test_tool_search.py -v` + +--- + +### U8. 默认启用上下文压缩 + +**Goal**: ReActEngine 默认启用滑动窗口压缩 + +**Dependencies**: 无 + +**Files**: +- `src/agentkit/core/react.py` — 修改默认 compressor 参数 +- `src/agentkit/core/compressor.py` — 确认滑动窗口实现 + +**Approach**: +1. ReActEngine 的 `__init__` 中 compressor 默认值从 None 改为 SlidingWindowCompressor +2. 保留最近 N 轮 + 系统提示 + 工具描述 +3. N 可配置(默认 10) + +**Test scenarios**: +- 长对话自动压缩 +- 压缩后系统提示和工具描述保留 +- 压缩不影响最近 N 轮对话 + +**Verification**: `ruff check src/ && pytest tests/unit/core/test_compressor.py -v` + +--- + +### U9. Spec 文档作为一等公民 + +**Goal**: PlanExec 生成的计划持久化为 Spec 文档,用户可查看、编辑、确认后再执行 + +**Dependencies**: U4 + +**Files**: +- `src/agentkit/core/spec_manager.py` — 新建 +- `src/agentkit/core/plan_exec_engine.py` — 集成 SpecManager +- `src/agentkit/server/routes/tasks.py` — 新增 Spec 相关 API + +**Approach**: +1. 新建 `SpecManager`:管理 Spec 文档的 CRUD +2. Spec 文件路径:`.agentkit/specs/.yaml` +3. PlanExecEngine 生成计划后,先持久化为 Spec +4. 新增 API:`GET /api/v1/specs`、`GET /api/v1/specs/{id}`、`PUT /api/v1/specs/{id}`、`POST /api/v1/specs/{id}/confirm` +5. 用户确认后才开始执行 + +**Test scenarios**: +- 计划正确持久化为 Spec 文件 +- Spec 文件可读取和编辑 +- 未确认的 Spec 不会执行 +- 确认后触发执行 + +**Verification**: `ruff check src/ && pytest tests/unit/core/test_spec_manager.py -v` + +--- + +### U10. 统一事件模型(SQ/EQ 双队列) + +**Goal**: 统一 CLI 和 WebSocket 的事件模型为 SQ/EQ 双队列 + +**Dependencies**: U3 + +**Files**: +- `src/agentkit/core/protocol.py` — 新增 SQ/EQ 事件类型 +- `src/agentkit/server/routes/portal.py` — 对接 EQ +- `src/agentkit/cli/chat.py` — 对接 EQ + +**Approach**: +1. 定义 SubmissionQueue(用户输入)和 EventQueue(Agent 输出) +2. 事件类型:Session/Task/Turn 三级模型 +3. Portal WebSocket 和 CLI 共享同一事件流 +4. 前端可以统一渲染 + +**Test scenarios**: +- SQ 正确接收用户输入 +- EQ 正确推送 Agent 事件 +- WebSocket 和 CLI 共享事件流 +- 事件类型正确分类 + +**Verification**: `ruff check src/ && pytest tests/unit/core/test_protocol.py -v` + +--- + +## Scope Boundaries + +### In Scope +- 上述 10 个 Implementation Unit +- 单元测试覆盖 + +### Out of Scope +- DockerComputerUseSession 实现(P3,等用户需求验证) +- 前端组件更新(后续迭代) +- Agent 配置热重载(P4) +- 渐进式上下文加载(P4) +- Soul 演变多维度触发(关闭) +- OTel 埋点(延后) + +### Deferred to Follow-Up Work +- 前端 ExpertTeamView 接入真实数据 +- SWE-bench 端到端验证 +- 性能监控和成本追踪 + +--- + +## Risks & Dependencies + +| Risk | Impact | Mitigation | +|------|--------|------------| +| U3 专家团重写可能影响现有 @team 功能 | 中 | 保留 ExpertConfig/ExpertTemplate/Registry,只重写 Orchestrator | +| U5 SQLite 在高并发下可能有锁竞争 | 低 | 聊天场景写频率低,SQLite WAL 模式足够 | +| U7 tool_search 可能增加 LLM 调用轮次 | 中 | 核心工具全量注入,只有扩展工具需要搜索 | +| U9 Spec 文档可能增加用户操作步骤 | 低 | 默认自动确认(可配置),不阻塞自动化流程 | + +--- + +## Phased Delivery + +**Phase 1(清理收尾)**: U1 → U2 → U4 → U8 +**Phase 2(核心能力)**: U5 → U6 → U9 +**Phase 3(多 Agent)**: U3 → U7 → U10 diff --git a/src/agentkit/chat/simple_router.py b/src/agentkit/chat/request_preprocessor.py similarity index 88% rename from src/agentkit/chat/simple_router.py rename to src/agentkit/chat/request_preprocessor.py index 7d1ee83..afa267d 100644 --- a/src/agentkit/chat/simple_router.py +++ b/src/agentkit/chat/request_preprocessor.py @@ -1,4 +1,4 @@ -"""Simple router — minimal routing layer for unified REACT agent loop. +"""Request preprocessor — minimal preprocessing layer for unified REACT agent loop. Replaces the 4-layer CostAwareRouter with a simple approach: 1. @skill:xxx prefix → explicit skill selection @@ -53,15 +53,15 @@ _IDENTITY_RE = re.compile( ) -class SimpleRouter: - """Minimal routing layer: regex fast-path + default REACT. +class RequestPreprocessor: + """Minimal preprocessing layer: regex fast-path + default REACT. Design rationale: - No HeuristicClassifier: keyword enumeration can never cover all colloquial expressions - No IntentRouter: LLM blind-classification without tool context is unreliable - No SemanticRouter: embedding similarity is not intent recognition - LLM in the REACT agent loop sees full tool descriptions and decides autonomously - - This matches Codex/Trae/Hermes architecture: unified agent loop, no routing layer + - This matches Codex/Trae/Hermes architecture: unified agent loop, no preprocessing layer """ def __init__( @@ -78,7 +78,7 @@ class SimpleRouter: self._default_model = default_model self._default_agent_name = default_agent_name - async def route( + async def preprocess( self, content: str, *, @@ -90,7 +90,7 @@ class SimpleRouter: session_id: str = "", transparency: str = "SILENT", ) -> SkillRoutingResult: - """Route user input to the appropriate execution path. + """Preprocess user input to determine the appropriate execution path. Decision tree: 1. @skill:xxx prefix → explicit skill (SKILL_REACT or skill's execution_mode) @@ -99,21 +99,25 @@ class SimpleRouter: """ registry = skill_registry or self._skill_registry tools = default_tools if default_tools is not None else self._default_tools - sys_prompt = default_system_prompt if default_system_prompt is not None else self._default_system_prompt + sys_prompt = ( + default_system_prompt + if default_system_prompt is not None + else self._default_system_prompt + ) model = default_model or self._default_model agent_name = default_agent_name or self._default_agent_name # --- Layer 0: @skill:xxx prefix --- explicit_skill, clean_content = parse_skill_prefix(content) if explicit_skill and registry is not None: - result = self._route_explicit_skill( + result = self._resolve_explicit_skill( explicit_skill, clean_content, registry, model, agent_name ) return result # --- Layer 1: Greeting/chitchat/identity regex (<1ms, zero tokens) --- stripped = content.strip() - if self._is_direct_chat(stripped): + if self._is_trivial_input(stripped): result = SkillRoutingResult( clean_content=stripped, matched=False, @@ -141,7 +145,7 @@ class SimpleRouter: ) return result - def _route_explicit_skill( + def _resolve_explicit_skill( self, skill_name: str, clean_content: str, @@ -149,7 +153,7 @@ class SimpleRouter: model: str, agent_name: str, ) -> SkillRoutingResult: - """Route to an explicitly specified skill via @skill:xxx prefix.""" + """Resolve an explicitly specified skill via @skill:xxx prefix.""" try: skill = registry.get(skill_name) except Exception: @@ -185,13 +189,11 @@ class SimpleRouter: ) @staticmethod - def _is_direct_chat(text: str) -> bool: + def _is_trivial_input(text: str) -> bool: """Check if the input is a greeting, chitchat, or identity question. These are zero-cost direct chat: no tool usage, no ReAct loop needed. """ return bool( - _GREETING_RE.match(text) - or _CHAT_MODE_RE.match(text) - or _IDENTITY_RE.match(text) + _GREETING_RE.match(text) or _CHAT_MODE_RE.match(text) or _IDENTITY_RE.match(text) ) diff --git a/src/agentkit/chat/semantic_router.py b/src/agentkit/chat/semantic_router.py deleted file mode 100644 index 1e4ea8b..0000000 --- a/src/agentkit/chat/semantic_router.py +++ /dev/null @@ -1,224 +0,0 @@ -"""Semantic Router — Embedding-based intent routing as Layer 1.5. - -Uses pre-computed skill embeddings for zero-cost semantic matching, -inserted between Layer 1 (HeuristicClassifier) and Layer 2 (LLM classification) -in CostAwareRouter. - -Design doc: docs/plans/2026-06-14-004-u3-semantic-router.md -""" - -import logging -from dataclasses import dataclass -from typing import Any - -from agentkit.memory.embedder import Embedder, EmbeddingCache -from agentkit.utils.vector_math import compute_cosine_similarity - -logger = logging.getLogger(__name__) - - -@dataclass -class SemanticRouteResult: - """Result of semantic routing.""" - - confidence: str # "high" | "medium" | "low" - skill_name: str | None - similarity: float - - -class SkillEmbeddingIndex: - """Pre-computed embedding index for registered skills. - - Embeddings are computed at skill registration time and cached. - Query-time search is O(n) cosine similarity scan, which is fast - for <100 skills with 1024-1536 dim vectors. - """ - - def __init__(self, embedder: Embedder): - self._embedder = embedder - # skill_name → (embedding, source_text) - self._index: dict[str, tuple[list[float], str]] = {} - - async def build(self, skill_registry: Any) -> None: - """Build index from all registered skills.""" - if skill_registry is None: - return - skills = skill_registry.list_skills() - for skill in skills: - await self.update_skill(skill.config.name, skill) - - async def update_skill(self, skill_name: str, skill: Any) -> None: - """Re-embed a single skill (on registration/update).""" - source_text = self._build_source_text(skill) - try: - embedding = await self._embedder.embed(source_text) - self._index[skill_name] = (embedding, source_text) - except Exception as e: - logger.warning(f"Failed to embed skill '{skill_name}': {e}") - - def remove_skill(self, skill_name: str) -> None: - """Remove a skill from the index.""" - self._index.pop(skill_name, None) - - async def search(self, query_embedding: list[float], top_k: int = 5) -> list[tuple[str, float]]: - """Search for skills matching the query embedding. - - Returns: - List of (skill_name, similarity) sorted by similarity descending. - """ - if not self._index: - return [] - - results: list[tuple[str, float]] = [] - for skill_name, (emb, _) in self._index.items(): - sim = compute_cosine_similarity(query_embedding, emb) - results.append((skill_name, sim)) - - results.sort(key=lambda x: x[1], reverse=True) - return results[:top_k] - - @staticmethod - def _build_source_text(skill: Any) -> str: - """Build embedding source text from skill metadata. - - Combines description, intent keywords, and capability tags - for rich semantic representation. - """ - config = skill.config if hasattr(skill, "config") else skill - parts = [] - - # Description - description = getattr(config, "description", "") or "" - if description: - parts.append(description) - - # Intent keywords - intent = getattr(config, "intent", None) - if intent and hasattr(intent, "keywords") and intent.keywords: - parts.append(" ".join(intent.keywords)) - - # Intent examples (rich semantic signal for short queries) - if intent and hasattr(intent, "examples") and intent.examples: - parts.append(" ".join(intent.examples)) - - # Capability tags - capabilities = getattr(config, "capabilities", None) - if capabilities: - tags = [] - for cap in capabilities: - if isinstance(cap, str): - tags.append(cap) - elif isinstance(cap, dict): - tags.append(cap.get("tag", "")) - elif hasattr(cap, "tag"): - tags.append(cap.tag) - if tags: - parts.append(" ".join(t for t in tags if t)) - - # Fallback: use skill name if no other text available - if not parts: - parts.append(getattr(config, "name", "unknown")) - - return " | ".join(parts) - - @property - def size(self) -> int: - """Number of skills in the index.""" - return len(self._index) - - -class SemanticRouter: - """Embedding-based semantic routing as Layer 1.5. - - Three confidence zones: - - similarity > similarity_high (0.85): HIGH → direct skill match, skip Layer 2 - - similarity_low (0.4) <= similarity <= similarity_high: MEDIUM → skill hint for Layer 2 - - similarity < similarity_low (0.4): LOW → no semantic signal, normal routing - - Short text (<20 chars) uses a lower effective threshold because - brief queries naturally have lower embedding similarity. - """ - - _SHORT_TEXT_THRESHOLD = 20 # chars - - def __init__( - self, - embedder: Embedder, - similarity_high: float = 0.85, - similarity_low: float = 0.4, - ): - self._embedder = embedder - self._similarity_high = similarity_high - self._similarity_low = similarity_low - self._index = SkillEmbeddingIndex(embedder) - self._query_cache = EmbeddingCache(max_size=500, ttl=1800) - - async def build_index(self, skill_registry: Any) -> None: - """Build skill embedding index from registry.""" - await self._index.build(skill_registry) - logger.info(f"Semantic router index built: {self._index.size} skills") - - async def update_skill(self, skill_name: str, skill: Any) -> None: - """Update a single skill's embedding.""" - await self._index.update_skill(skill_name, skill) - - def remove_skill(self, skill_name: str) -> None: - """Remove a skill from the index.""" - self._index.remove_skill(skill_name) - - async def route(self, query: str) -> SemanticRouteResult: - """Route a query using semantic similarity. - - Args: - query: User's input text. - - Returns: - SemanticRouteResult with confidence, skill_name, and similarity. - """ - if self._index.size == 0: - return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0) - - if not query or not query.strip(): - return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0) - - try: - # Get query embedding (with cache) - query_embedding = self._query_cache.get(query) - if query_embedding is None: - query_embedding = await self._embedder.embed(query) - self._query_cache.put(query, query_embedding) - - # Search skill index - results = await self._index.search(query_embedding, top_k=1) - if not results: - return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0) - - best_skill, best_sim = results[0] - - # Short text uses lower effective threshold - effective_low = self._similarity_low - if len(query) < self._SHORT_TEXT_THRESHOLD: - effective_low = max(0.25, self._similarity_low - 0.15) - - if best_sim >= self._similarity_high: - return SemanticRouteResult( - confidence="high", - skill_name=best_skill, - similarity=best_sim, - ) - elif best_sim >= effective_low: - return SemanticRouteResult( - confidence="medium", - skill_name=best_skill, - similarity=best_sim, - ) - else: - return SemanticRouteResult( - confidence="low", - skill_name=None, - similarity=best_sim, - ) - - except Exception as e: - logger.warning(f"Semantic routing failed, returning low confidence: {e}") - return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0) diff --git a/src/agentkit/chat/skill_routing.py b/src/agentkit/chat/skill_routing.py index b524c97..6ee0f22 100644 --- a/src/agentkit/chat/skill_routing.py +++ b/src/agentkit/chat/skill_routing.py @@ -7,15 +7,11 @@ and prompt assembly into a single module used by both chat routes. from __future__ import annotations import enum -import json import logging import re from dataclasses import dataclass, field from typing import Any -from agentkit.marketplace.auction import AuctionHouse, Bid -from agentkit.telemetry.tracer import get_tracer - logger = logging.getLogger(__name__) # Strict validation: only lowercase alphanumeric, hyphens, underscores @@ -117,18 +113,17 @@ def build_skill_system_prompt(skill_config) -> str | None: async def resolve_skill_routing( content: str, skill_registry: Any, - intent_router: Any, default_tools: list, default_system_prompt: str | None, default_model: str = "default", default_agent_name: str = "default", agent_tool_registry: Any = None, session_id: str = "", - force_skill: str | None = None, ) -> SkillRoutingResult: """Resolve skill routing for a user message. - This is the shared entry point used by both GUI WebSocket chat and CLI chat. + This is the shared entry point used by CLI chat. + Uses @skill: prefix matching for explicit skill selection. Returns a SkillRoutingResult with all execution parameters set. """ result = SkillRoutingResult() @@ -159,133 +154,6 @@ async def resolve_skill_routing( result.skill_name = None result.skill_config = None - # Try force_skill match (from semantic router high confidence) - if not result.matched and force_skill and skill_registry: - try: - matched_skill = skill_registry.get(force_skill) - result.skill_name = force_skill - result.skill_config = matched_skill.config - result.skill_tools = matched_skill.tools or [] - result.matched = True - result.match_method = "semantic_force" - result.match_confidence = 1.0 - logger.info(f"Session {session_id}: using force-matched skill '{force_skill}'") - except Exception as e: - logger.warning(f"Session {session_id}: force skill '{force_skill}' not found: {e}") - - # Try IntentRouter if no explicit match - if not result.matched and skill_registry and intent_router: - skills = skill_registry.list_skills() - routable_skills = [s for s in skills if s.config.intent.keywords] - if routable_skills: - try: - routing_result = await intent_router.route( - input_data={"content": clean_content}, - skills=routable_skills, - ) - if routing_result and routing_result.confidence >= 0.5: - skill_name = routing_result.matched_skill - try: - matched_skill = skill_registry.get(skill_name) - skill_config = matched_skill.config - - # Check if matched skill can handle tool-calling tasks. - # Direct-mode agents with no tools cannot execute tasks - # that require tool use (shell, search, etc.). - # If the task content suggests tool needs, fall through - # to default agent which has full tool access. - execution_mode = getattr(skill_config, "execution_mode", "react") - skill_tools = matched_skill.tools or [] - if execution_mode == "direct" and not skill_tools: - # Direct agent matched but has no tools — check if - # the task might need tools. If so, skip this match - # and let it fall through to default agent. - tool_hints = [ - "执行", - "运行", - "命令", - "终端", - "shell", - "bash", - "搜索", - "查找", - "联网", - "search", - "安装", - "部署", - "启动", - "停止", - "重启", - "文件", - "目录", - "创建", - "删除", - "修改", - "查看", - "检查", - "监控", - "测试", - "浏览", - "下载", - "上传", - "读取", - "写入", - "导出", - "导入", - "run", - "execute", - "install", - "deploy", - "start", - "stop", - "restart", - "file", - "check", - "monitor", - "test", - "browse", - "download", - "upload", - "read", - "write", - "export", - "import", - ] - content_lower = clean_content.lower() - needs_tools = any(h in content_lower for h in tool_hints) - if needs_tools: - logger.info( - f"Session {session_id}: skill '{skill_name}' is direct-mode " - f"but task may need tools, falling through to default agent" - ) - # Don't set result.matched, let it fall through - else: - result.skill_name = skill_name - result.skill_config = skill_config - result.skill_tools = skill_tools - result.matched = True - result.match_method = routing_result.method - result.match_confidence = routing_result.confidence - else: - result.skill_name = skill_name - result.skill_config = skill_config - result.skill_tools = skill_tools - result.matched = True - result.match_method = routing_result.method - result.match_confidence = routing_result.confidence - - if result.matched: - logger.info( - f"Session {session_id}: routed to skill '{skill_name}' " - f"via {routing_result.method} (confidence={routing_result.confidence})" - ) - except Exception as e: - logger.warning( - f"Session {session_id}: skill '{skill_name}' found by router but not in registry: {e}" - ) - except Exception as e: - logger.warning(f"Skill routing failed for session {session_id}: {e}") - # Determine execution parameters if result.matched and result.skill_config: skill_prompt = build_skill_system_prompt(result.skill_config) @@ -349,1340 +217,3 @@ def _build_tools_description(tools: list) -> str: if params: lines[-1] += f" (parameters: {', '.join(params)})" return "\n".join(lines) - - -# --------------------------------------------------------------------------- -# CostAwareRouter - 三层成本感知路由 -# --------------------------------------------------------------------------- - -_GREETING_RE = re.compile( - r"^(你好|hi|hello|hey|嗨|哈喽|早上好|下午好|晚上好|good morning|good afternoon|good evening)\s*[!!.。??]*$", - re.IGNORECASE, -) - -_CHAT_MODE_RE = re.compile( - r"^(谢谢|感谢|thanks|thank you|ok|好的|嗯|对|是|不是|没关系|再见|bye|goodbye)\s*[!!.。??]*$", - re.IGNORECASE, -) - -# Simple identity/meta questions — zero-cost direct chat, no skill routing needed -_IDENTITY_RE = re.compile( - r"^(你是谁|你叫什么|你是什么|你是哪个|who are you|what are you|what's your name" - r"|介绍一下你自己|自我介绍|你叫啥|你叫什么名字|你的名字)" - r"\s*[??!!.。]*$", - re.IGNORECASE, -) - -_SENTENCE_SPLIT_RE = re.compile(r"[,。!?;\n,.!?;]") - - -def _tokenize_content(content: str) -> list[str]: - """Tokenize content for capability matching. Supports Chinese and English.""" - # 1. Split by punctuation and whitespace - segments = re.split(r"[\s,,。!?、;:\n]+", content) - - # 2. For long Chinese segments, add 2-gram supplements - tokens = [] - for seg in segments: - if len(seg) <= 4: - tokens.append(seg) - else: - tokens.append(seg) - # Add 2-grams for Chinese compound words - for i in range(len(seg) - 1): - bigram = seg[i : i + 2] - if all("\u4e00" <= c <= "\u9fff" for c in bigram): - tokens.append(bigram) - - # 3. Filter stopwords - stopwords = { - "的", - "了", - "是", - "在", - "和", - "与", - "也", - "都", - "就", - "要", - "会", - "我", - "你", - "他", - "这", - "那", - "有", - "没", - "不", - } - tokens = [t for t in tokens if t not in stopwords and len(t) > 1][:10] - - return tokens - - -class HeuristicClassifier: - """零成本本地启发式分类器,替代 LLM quick_classify。 - - 基于消息长度、关键词密度、工具暗示等特征评估复杂度 (0.0-1.0), - 无需任何 LLM 调用,延迟 <1ms。 - """ - - # 高复杂度暗示词(需要工具或多步推理) - # 中文关键词使用子串匹配(中文无自然词边界) - _HIGH_COMPLEXITY_HINTS_CN = { - "执行", - "运行", - "命令", - "终端", - "安装", - "部署", - "启动", - "停止", - "重启", - "配置", - "搜索", - "查找", - "联网", - "文件", - "目录", - "创建", - "删除", - "修改", - "编辑", - "分析", - "比较", - "对比", - "评估", - "调研", - "研究", - "设计", - "规划", - "方案", - "架构", - "实现", - "开发", - "代码", - "编程", - "函数", - "接口", - "调试", - "重构", - "查看", - "检查", - "监控", - "测试", - "浏览", - "下载", - "上传", - "读取", - "写入", - "导出", - "导入", - } - - # 英文关键词使用词边界匹配(避免子串误匹配如 "profile" 匹配 "file") - _HIGH_COMPLEXITY_HINTS_EN = { - "shell", - "bash", - "script", - "search", - "query", - "directory", - "execute", - "install", - "deploy", - "restart", - "modify", - "analyze", - "compare", - "evaluate", - "research", - "design", - "implement", - "develop", - "refactor", - "debug", - "python", - "javascript", - "typescript", - "sql", - "check", - "monitor", - "test", - "browse", - "download", - "upload", - "export", - "import", - } - - # 英文短词需要精确匹配(避免子串误匹配) - _HIGH_COMPLEXITY_EXACT_EN = { - "run", - "find", - "start", - "stop", - "file", - "create", - "delete", - "plan", - "build", - "code", - "program", - "function", - "class", - "interface", - "api", - } - - # 中等复杂度暗示词(简单问题但需思考) - # 注意:不包含"怎么",因为"怎么样"是闲聊而非工具需求 - _MEDIUM_COMPLEXITY_HINTS_CN = { - "如何", - "怎样", - "为什么", - "什么原因", - "区别", - "推荐", - "建议", - "选择", - "哪个", - } - - _MEDIUM_COMPLEXITY_HINTS_EN = { - "difference", - "explain", - "recommend", - "suggest", - "choose", - } - - # 英文短词精确匹配 - _MEDIUM_COMPLEXITY_EXACT_EN = { - "how", - "why", - "what", - "which", - } - - # 低复杂度暗示词(问候/闲聊/简单定义,不需要工具) - # 注意:不包含"怎么样"、"今天"等通用疑问/时间词,因为它们可搭配高复杂度问题 - _LOW_COMPLEXITY_HINTS_CN = { - "你好", - "嗨", - "早上好", - "下午好", - "晚上好", - "再见", - "谢谢", - "辛苦", - "你是谁", - "你叫什么", - "你是什么", - "自我介绍", - "闲聊", - "聊天", - } - - _LOW_COMPLEXITY_HINTS_EN = { - "hello", - "hi", - "hey", - "good morning", - "good afternoon", - "good evening", - "goodbye", - "thanks", - "who are you", - "what are you", - "your name", - "introduce yourself", - "how are you", - "chat", - } - - # 否定上下文模式("不要X"中的X不计入高复杂度匹配) - # 匹配1-4个中文字符或1个英文单词(避免匹配过长串如"分析,直接告诉我答案") - _NEGATION_PATTERNS = re.compile( - r"(?:不要|无需|不用|不需要|别|don'?t|no need|without|not)\s*" - r"([\u4e00-\u9fff]{1,4}|[a-zA-Z]+)", - re.IGNORECASE, - ) - - # 短疑问句模式(以?或?结尾且长度<30) - _SHORT_QUESTION_RE = re.compile(r"[??]\s*$") - - # 预编译英文词边界正则 - _HIGH_EN_RE = re.compile( - r"\b(" - + "|".join(re.escape(w) for w in sorted(_HIGH_COMPLEXITY_HINTS_EN, key=len, reverse=True)) - + r")\b", - re.IGNORECASE, - ) - _HIGH_EXACT_RE = re.compile( - r"\b(" - + "|".join(re.escape(w) for w in sorted(_HIGH_COMPLEXITY_EXACT_EN, key=len, reverse=True)) - + r")\b", - re.IGNORECASE, - ) - _MEDIUM_EN_RE = re.compile( - r"\b(" - + "|".join(re.escape(w) for w in sorted(_MEDIUM_COMPLEXITY_HINTS_EN, key=len, reverse=True)) - + r")\b", - re.IGNORECASE, - ) - _MEDIUM_EXACT_RE = re.compile( - r"\b(" - + "|".join(re.escape(w) for w in sorted(_MEDIUM_COMPLEXITY_EXACT_EN, key=len, reverse=True)) - + r")\b", - re.IGNORECASE, - ) - - def classify(self, content: str) -> float: - """评估消息复杂度 (0.0-1.0)。 - - 评分规则: - - 低复杂度信号(问候/闲聊/身份查询)→ 0.05 - - 短消息 (<20字符) 且无复杂度暗示 → 0.1 - - 含中等复杂度关键词 → 0.35 - - 含高复杂度关键词 → 0.65-0.8 - - 否定上下文中的高复杂度词不计入匹配 - - 短疑问句额外扣减 - - 多句/长消息 → 额外加成 - - 代码模式 (反引号/括号) → 额外加成 - """ - if not content or not content.strip(): - return 0.0 - - content_lower = content.lower() - score = 0.0 - - # 0. 低复杂度信号检测(仅在无高复杂度信号时生效) - low_hits_cn = sum(1 for h in self._LOW_COMPLEXITY_HINTS_CN if h in content_lower) - low_hits_en = sum(1 for h in self._LOW_COMPLEXITY_HINTS_EN if h in content_lower) - has_low_signal = low_hits_cn + low_hits_en > 0 - - # 1. 否定上下文检测 — 提取被否定的词 - negated_words: set[str] = set() - for match in self._NEGATION_PATTERNS.finditer(content_lower): - negated_words.add(match.group(1).lower()) - - # 2. 关键词匹配(排除否定上下文中的词) - # 中文:子串匹配 - high_hits = sum( - 1 - for h in self._HIGH_COMPLEXITY_HINTS_CN - if h in content_lower and h not in negated_words - ) - medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS_CN if m in content_lower) - - # 英文:词边界匹配 - high_en_matches = self._HIGH_EN_RE.findall(content) + self._HIGH_EXACT_RE.findall(content) - high_hits += sum(1 for w in high_en_matches if w.lower() not in negated_words) - medium_hits += len(self._MEDIUM_EN_RE.findall(content)) + len( - self._MEDIUM_EXACT_RE.findall(content) - ) - - has_non_low_signal = high_hits > 0 or medium_hits > 0 - - # 低复杂度信号仅在无高/中复杂度信号时生效 - if has_low_signal and not has_non_low_signal: - score = 0.05 # 问候/闲聊直接给极低分 - length = len(content) - if length > 200: - score += 0.05 - elif length > 100: - score += 0.03 - return max(0.0, min(1.0, score)) - - if high_hits >= 2: - score = 0.80 - elif high_hits == 1: - score = 0.65 - elif medium_hits >= 1: - score = 0.35 - else: - score = 0.10 - - # 3. 消息长度加成 - length = len(content) - if length > 200: - score += 0.15 - elif length > 100: - score += 0.10 - elif length > 50: - score += 0.05 - - # 4. 多句加成(逗号/句号/换行分隔) - sentence_count = len(_SENTENCE_SPLIT_RE.split(content)) - if sentence_count >= 4: - score += 0.10 - elif sentence_count >= 2: - score += 0.05 - - # 5. 代码模式加成 - if "`" in content or "```" in content: - score += 0.15 - if re.search(r"[\{\}\[\]\(\)]", content): - score += 0.05 - - # 6. 短疑问句扣减(以?或?结尾且长度<30) - if self._SHORT_QUESTION_RE.search(content) and len(content) < 30: - score -= 0.10 - - return max(0.0, min(1.0, score)) - - -class CostAwareRouter: - """三层成本感知路由器。 - - Layer 0: 规则匹配(零成本)— @skill: 前缀 / 问候 / 简单对话 - Layer 1: 复杂度分类 — heuristic(零成本)或 LLM(~100 tokens) - Layer 2: 能力匹配 / 拍卖(可选)— 高复杂度任务委派给最佳 Agent - """ - - def __init__( - self, - llm_gateway: Any = None, - model: str = "default", - org_context: Any = None, - auction_enabled: bool = False, - classifier: str = "heuristic", - merged_llm_classify: bool = True, - semantic_router: Any = None, # SemanticRouter | None - expert_team_router: Any = None, # ExpertTeamRouter | None - ): - self._llm_gateway = llm_gateway - self._model = model - self._org_context = org_context - self._auction_enabled = auction_enabled - self._classifier = classifier - self._merged_llm_classify = merged_llm_classify - self._semantic_router = semantic_router - self._expert_team_router = expert_team_router - self._auction_house = AuctionHouse() if auction_enabled else None - if classifier not in ("heuristic", "llm"): - raise ValueError(f"Invalid classifier: {classifier!r}, must be 'heuristic' or 'llm'") - self._heuristic = HeuristicClassifier() - - # -- Layer 0: Rule-based (zero cost) ------------------------------------ - - def _match_layer0(self, content: str) -> tuple[str | None, str]: - """Layer 0 规则匹配。 - - Returns: - (match_type, clean_content) — match_type 为 None 表示未命中。 - """ - # @skill: 显式前缀 - explicit_skill, clean = parse_skill_prefix(content) - if explicit_skill: - return "explicit_skill", clean - - # 问候模式 - stripped = content.strip() - if _GREETING_RE.match(stripped): - return "greeting", stripped - - # 简单对话模式 - if _CHAT_MODE_RE.match(stripped): - return "chat_mode", stripped - - # 身份/元问题模式("你是谁"等)— 零成本直接对话 - if _IDENTITY_RE.match(stripped): - return "identity", stripped - - return None, stripped - - # -- Layer 1: LLM quick classify (~100 tokens) ------------------------- - - async def quick_classify(self, content: str) -> float: - """使用 LLM 快速评估用户请求的复杂度 (0.0-1.0)。 - - 当 LLM Gateway 不可用或解析失败时,返回默认中等复杂度 0.5。 - """ - if self._llm_gateway is None: - return 0.5 - - prompt = ( - "You are a complexity classifier. Rate the complexity of the user request on a scale of 0.0 to 1.0.\n" - "0.0 = trivial greeting, 0.3 = simple question, 0.5 = moderate task, " - "0.7 = complex multi-step task, 1.0 = very complex research task.\n\n" - "---BEGIN USER REQUEST---\n" - f"{content}\n" - "---END USER REQUEST---\n\n" - 'Respond ONLY with a JSON object: {"complexity": }' - ) - try: - response = await self._llm_gateway.chat( - messages=[{"role": "user", "content": prompt}], - model=self._model, - ) - data = json.loads(response.content.strip()) - complexity = float(data.get("complexity", 0.5)) - return max(0.0, min(1.0, complexity)) - except Exception as e: - logger.warning(f"CostAwareRouter quick_classify failed: {e}") - return 0.5 - - # -- Layer 1.5: Merged LLM classify (complexity + intent in one call) --- - - async def _classify_merged( - self, - content: str, - skill_registry: Any, - intent_router: Any, - default_tools: list, - default_system_prompt: str | None, - default_model: str, - default_agent_name: str, - agent_tool_registry: Any = None, - session_id: str = "", - complexity: float = 0.5, - ) -> SkillRoutingResult: - """合并 LLM 调用:单次 LLM 同时输出 complexity + intent + skill_hint。 - - 当 HeuristicClassifier 返回不确定区间 (0.3-0.7) 时使用, - 替代分别调用 quick_classify() 和 IntentRouter._classify_with_llm(), - 节省 1 次 LLM 调用 (~1-3s)。 - """ - if self._llm_gateway is None or not self._merged_llm_classify: - # Fallback: 使用独立的 IntentRouter 路由 - return await resolve_skill_routing( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - ) - - # Build skill list for the prompt - skill_hints = [] - if skill_registry: - try: - for s in skill_registry.list_skills(): - if s.config.intent and s.config.intent.keywords: - skill_hints.append(s.name) - except Exception: - pass - - skill_list_str = ", ".join(skill_hints) if skill_hints else "none" - - prompt = ( - "You are a routing classifier. Analyze the user request and output:\n" - "1. complexity (0.0-1.0): how complex is this request\n" - "2. intent: the primary intent category\n" - "3. skill_hint: the best matching skill name, or null if none match\n\n" - f"Available skills: [{skill_list_str}]\n\n" - "---BEGIN USER REQUEST---\n" - f"{content}\n" - "---END USER REQUEST---\n\n" - 'Respond ONLY with a JSON object: {"complexity": , "intent": , "skill_hint": }' - ) - - try: - response = await self._llm_gateway.chat( - messages=[{"role": "user", "content": prompt}], - model=self._model, - ) - data = json.loads(response.content.strip()) - merged_complexity = float(data.get("complexity", 0.5)) - merged_complexity = max(0.0, min(1.0, merged_complexity)) - skill_hint = data.get("skill_hint") - - # Validate skill_hint against name pattern before lookup - if skill_hint and skill_registry: - if not _SKILL_NAME_RE.match(str(skill_hint).strip().lower()): - logger.warning(f"Invalid skill_hint from LLM: {skill_hint!r}") - skill_hint = None - try: - matched_skill = skill_registry.get(skill_hint) - result = SkillRoutingResult( - clean_content=content, - skill_name=skill_hint, - skill_config=matched_skill.config, - skill_tools=matched_skill.tools or [], - matched=True, - match_method="merged_llm", - match_confidence=0.7, - complexity=merged_complexity, - execution_mode=_resolve_execution_mode(matched_skill.config), - ) - # Merge tools - agent_tools = ( - agent_tool_registry.list_tools() if agent_tool_registry else default_tools - ) - seen_names = set() - merged_tools = [] - for tool in result.skill_tools + agent_tools: - if tool.name not in seen_names: - seen_names.add(tool.name) - merged_tools.append(tool) - result.tools = merged_tools - result.model = ( - result.skill_config.llm.get("model", default_model) - if result.skill_config.llm - else default_model - ) - result.agent_name = skill_hint - result.system_prompt = ( - build_skill_system_prompt(result.skill_config) or default_system_prompt - ) - # Append available tools to system prompt so LLM knows what it can call - if result.tools: - tools_desc = _build_tools_description(result.tools) - tool_instruction = ( - "\n\n## Tool Usage\n" - "You have access to the following tools. When you need to use a tool, " - "respond with a tool call in the format specified by the system.\n" - "Never make up information or guess answers when you can use a tool to find the answer.\n" - "Always prefer using tools over guessing.\n" - ) - if result.system_prompt: - result.system_prompt += ( - f"{tool_instruction}\n## Available Tools\n{tools_desc}" - ) - logger.info( - f"Session {session_id}: merged LLM classify routed to skill '{skill_hint}' " - f"(complexity={merged_complexity:.2f})" - ) - return result - except Exception as e: - logger.warning( - f"Session {session_id}: merged LLM skill_hint '{skill_hint}' not found: {e}" - ) - - # No valid skill_hint — use complexity to decide routing - if merged_complexity < 0.3: - return SkillRoutingResult( - clean_content=content, - system_prompt=default_system_prompt, - tools=default_tools, - model=default_model, - agent_name=default_agent_name, - matched=False, - match_method="merged_llm_low", - match_confidence=1.0 - merged_complexity, - complexity=merged_complexity, - execution_mode=ExecutionMode.DIRECT_CHAT, - ) - elif merged_complexity > 0.7: - # High complexity — delegate to Layer 2 - return SkillRoutingResult( - clean_content=content, - system_prompt=default_system_prompt, - tools=default_tools, - model=default_model, - agent_name=default_agent_name, - matched=False, - match_method="merged_llm_high", - match_confidence=merged_complexity, - complexity=merged_complexity, - execution_mode=ExecutionMode.REACT, - ) - else: - # Medium complexity, no skill match — default agent - return SkillRoutingResult( - clean_content=content, - system_prompt=default_system_prompt, - tools=default_tools, - model=default_model, - agent_name=default_agent_name, - matched=False, - match_method="merged_llm_medium", - match_confidence=0.5, - complexity=merged_complexity, - execution_mode=ExecutionMode.REACT, - ) - except (json.JSONDecodeError, TypeError, ValueError) as e: - logger.warning( - f"CostAwareRouter _classify_merged parse failed: {e}, falling back to default" - ) - return SkillRoutingResult( - clean_content=content, - system_prompt=default_system_prompt, - tools=default_tools, - model=default_model, - agent_name=default_agent_name, - matched=False, - match_method="merged_llm_fallback", - match_confidence=0.5, - complexity=0.5, - execution_mode=ExecutionMode.REACT, - ) - except Exception as e: - logger.warning(f"CostAwareRouter _classify_merged failed: {e}, falling back to default") - return SkillRoutingResult( - clean_content=content, - system_prompt=default_system_prompt, - tools=default_tools, - model=default_model, - agent_name=default_agent_name, - matched=False, - match_method="merged_llm_fallback", - match_confidence=0.5, - complexity=0.5, - execution_mode=ExecutionMode.REACT, - ) - - # -- Layer 2: Capability matching / Auction (optional) ----------------- - - def _try_team_upgrade( - self, - result: SkillRoutingResult, - content: str, - complexity: float, - trace: list[dict] | None, - ) -> SkillRoutingResult: - """Attempt to upgrade REACT → TEAM_COLLAB when complexity is high and experts are available.""" - if ( - result.execution_mode == ExecutionMode.REACT - and complexity >= 0.7 - and self._expert_team_router is not None - ): - try: - if self._expert_team_router.can_handle(content): - team_result = self._expert_team_router.resolve(content, complexity) - if team_result.team_mode: - result.execution_mode = ExecutionMode.TEAM_COLLAB - if trace is not None: - trace.append( - { - "layer": 2, - "method": "team_upgrade", - "from_mode": "REACT", - "to_mode": "TEAM_COLLAB", - "team_match_method": team_result.match_method, - "complexity": complexity, - } - ) - except Exception as e: - logger.warning(f"CostAwareRouter team upgrade check failed: {e}") - return result - - async def _route_layer2( - self, - content: str, - skill_registry: Any, - intent_router: Any, - default_tools: list, - default_system_prompt: str | None, - default_model: str, - default_agent_name: str, - agent_tool_registry: Any = None, - session_id: str = "", - complexity: float = 0.0, - trace: list[dict] | None = None, - ) -> SkillRoutingResult: - """Layer 2: 高复杂度任务通过拍卖或 org_context.find_best_agent 路由。""" - # Extract capability-like keywords from content for matching - content_words = _tokenize_content(content) - - # --- Vickrey auction path (when enabled) --- - if ( - self._auction_enabled - and self._auction_house is not None - and self._org_context is not None - ): - try: - # Gather candidate agents from org_context - all_agents = ( - self._org_context.list_agents() - if hasattr(self._org_context, "list_agents") - else [] - ) - # Filter agents that have at least one relevant capability - candidate_agents = [] - for agent_profile in all_agents: - if not agent_profile.availability: - continue - # Check if agent has any of the content_words as capabilities - agent_caps_lower = {c.lower() for c in agent_profile.capabilities} - if any(w.lower() in agent_caps_lower for w in content_words): - candidate_agents.append(agent_profile) - - # Also include agents that match via find_best_agent (they have ALL required caps) - best = self._org_context.find_best_agent(required_capabilities=content_words) - if best is not None: - best_name = best if isinstance(best, str) else getattr(best, "name", str(best)) - existing_names = {a.name for a in candidate_agents} - if best_name not in existing_names: - profile = ( - self._org_context.get_agent_profile(best_name) - if hasattr(self._org_context, "get_agent_profile") - else best - ) - if hasattr(profile, "name"): - candidate_agents.append(profile) - - if len(candidate_agents) >= 1: - # Build Bid objects for each candidate - bids = [] - for agent_profile in candidate_agents: - name = ( - agent_profile.name - if hasattr(agent_profile, "name") - else str(agent_profile) - ) - caps = ( - agent_profile.capabilities - if hasattr(agent_profile, "capabilities") - else [] - ) - arch = ( - agent_profile.agent_type - if hasattr(agent_profile, "agent_type") - else "react" - ) - # Use current_load as a proxy for estimated_cost (higher load → higher cost) - estimated_cost = ( - float(agent_profile.current_load + 1) - if hasattr(agent_profile, "current_load") - else 1.0 - ) - bids.append( - Bid( - agent_name=name, - architecture=arch, - estimated_steps=1, - estimated_cost=estimated_cost, - confidence=0.8, - payment_offer=estimated_cost, - capabilities=caps, - ) - ) - - auction_result = await self._auction_house.run_vickrey_auction( - task_description=content, - bidders=bids, - required_capabilities=content_words, - ) - - if auction_result.winner is not None: - winner_name = auction_result.winner.agent_name - result = SkillRoutingResult( - clean_content=content, - matched=True, - match_method="vickrey_auction", - match_confidence=0.8, - agent_name=winner_name, - model=default_model, - system_prompt=default_system_prompt, - tools=default_tools, - complexity=complexity, - execution_mode=ExecutionMode.REACT, - ) - if trace is not None: - trace.append( - { - "layer": 2, - "method": "vickrey_auction", - "agent_name": winner_name, - "complexity": complexity, - "selection_reason": auction_result.selection_reason, - } - ) - return self._try_team_upgrade(result, content, complexity, trace) - # No winner from auction → fall through to capability matching - except Exception as e: - logger.warning(f"CostAwareRouter Layer 2 Vickrey auction failed: {e}") - - # --- Capability matching path (default) --- - if self._org_context is not None and hasattr(self._org_context, "find_best_agent"): - try: - best_agent = self._org_context.find_best_agent(required_capabilities=content_words) - if best_agent is not None: - agent_name = ( - best_agent - if isinstance(best_agent, str) - else getattr(best_agent, "name", str(best_agent)) - ) - result = SkillRoutingResult( - clean_content=content, - matched=True, - match_method="capability", - match_confidence=0.8, - agent_name=agent_name, - model=default_model, - system_prompt=default_system_prompt, - tools=default_tools, - complexity=complexity, - execution_mode=ExecutionMode.REACT, - ) - if trace is not None: - trace.append( - { - "layer": 2, - "method": "capability", - "agent_name": agent_name, - "complexity": complexity, - } - ) - return self._try_team_upgrade(result, content, complexity, trace) - except Exception as e: - logger.warning(f"CostAwareRouter Layer 2 org_context.find_best_agent failed: {e}") - - # Fallback: high complexity with tools → REACT directly (skip IntentRouter - # which tends to misclassify tool-needing queries as direct_agent) - if complexity >= 0.5 and default_tools: - result = SkillRoutingResult( - clean_content=content, - matched=False, - match_method="complexity_heuristic", - match_confidence=0.7, - agent_name=default_agent_name, - model=default_model, - system_prompt=default_system_prompt, - tools=default_tools, - complexity=complexity, - execution_mode=ExecutionMode.REACT, - ) - if trace is not None: - trace.append( - { - "layer": 2, - "method": "complexity_heuristic_react", - "complexity": complexity, - "reason": "high_complexity_with_tools_skip_intent_router", - } - ) - return self._try_team_upgrade(result, content, complexity, trace) - - # Fallback: 使用 IntentRouter - result = await resolve_skill_routing( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - ) - result.complexity = complexity - if trace is not None: - trace.append( - { - "layer": 2, - "method": "intent_router_fallback", - "complexity": complexity, - } - ) - return self._try_team_upgrade(result, content, complexity, trace) - - # -- Main entry point --------------------------------------------------- - - async def route( - self, - content: str, - skill_registry: Any, - intent_router: Any, - default_tools: list, - default_system_prompt: str | None, - default_model: str = "default", - default_agent_name: str = "default", - agent_tool_registry: Any = None, - session_id: str = "", - transparency: str = "SILENT", - ) -> SkillRoutingResult: - """三层成本感知路由主入口。 - - Args: - content: 用户输入内容 - skill_registry: Skill 注册表 - intent_router: IntentRouter 实例 - default_tools: 默认工具列表 - default_system_prompt: 默认系统提示词 - default_model: 默认模型 - default_agent_name: 默认 Agent 名称 - agent_tool_registry: Agent 工具注册表 - session_id: 会话 ID - transparency: 透明度级别 (SILENT / VERBOSE / TRACE) - - Returns: - SkillRoutingResult 包含路由结果和追踪信息 - """ - trace: list[dict] = [] - - tracer = get_tracer() - with tracer.start_span("router.route") as span: - span.set_attribute("input.length", len(content)) - - # ---- Layer 0: Rule-based (zero cost) ---- - match_type, clean_content = self._match_layer0(content) - - if match_type == "explicit_skill": - result = await resolve_skill_routing( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - ) - result.match_method = result.match_method or "explicit_skill" - result.complexity = 0.0 - trace.append( - { - "layer": 0, - "method": "explicit_skill", - "matched": result.matched, - "cost": "zero", - } - ) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - span.set_attribute("route.layer", result.match_method or "explicit_skill") - span.set_attribute("route.target", result.skill_name or "default") - return result - - if match_type in ("greeting", "chat_mode", "identity"): - result = SkillRoutingResult( - clean_content=clean_content, - system_prompt=default_system_prompt, - tools=default_tools, - model=default_model, - agent_name=default_agent_name, - matched=False, - match_method=match_type, - match_confidence=1.0, - complexity=0.0, - execution_mode=ExecutionMode.DIRECT_CHAT, - ) - trace.append( - { - "layer": 0, - "method": match_type, - "matched": False, - "cost": "zero", - } - ) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - span.set_attribute("route.layer", match_type) - span.set_attribute("route.target", "default") - return result - - # ---- Layer 1: Complexity classification ---- - if self._classifier == "heuristic": - complexity = self._heuristic.classify(clean_content) - trace.append( - { - "layer": 1, - "method": "heuristic_classify", - "complexity": complexity, - } - ) - else: - complexity = await self.quick_classify(clean_content) - trace.append( - { - "layer": 1, - "method": "quick_classify", - "complexity": complexity, - } - ) - - # Low complexity → try semantic match, then IntentRouter, then direct chat - if complexity < 0.3: - # Even low-complexity queries may match a skill semantically - if self._semantic_router is not None: - try: - semantic_result = await self._semantic_router.route(clean_content) - if ( - semantic_result.confidence in ("high", "medium") - and semantic_result.skill_name - ): - trace.append( - { - "layer": 1.5, - "method": "semantic_low_complexity_match", - "skill": semantic_result.skill_name, - "similarity": round(semantic_result.similarity, 3), - } - ) - result = await resolve_skill_routing( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - force_skill=semantic_result.skill_name, - ) - result.match_method = "semantic_low_complexity" - result.match_confidence = semantic_result.similarity - result.complexity = complexity - if result.matched: - result.execution_mode = _resolve_execution_mode(result.skill_config) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - span.set_attribute("route.layer", "semantic_low_complexity") - span.set_attribute("route.target", result.skill_name or "default") - return result - except Exception as e: - logger.warning(f"Semantic routing for low-complexity query failed: {e}") - - # Try IntentRouter keyword match before falling back to direct chat - # Low-complexity queries like "翻译这段话" should still match skills - if skill_registry and intent_router: - try: - result = await resolve_skill_routing( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - ) - if result.matched: - result.complexity = complexity - result.match_method = result.match_method or "intent_low_complexity" - trace.append( - { - "layer": 1, - "method": "intent_low_complexity", - "skill": result.skill_name, - "complexity": complexity, - } - ) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - span.set_attribute("route.layer", "intent_low_complexity") - span.set_attribute("route.target", result.skill_name or "default") - return result - except Exception as e: - logger.warning(f"Intent routing for low-complexity query failed: {e}") - - # No semantic or intent match → use REACT if tools available, otherwise direct chat - # Low complexity does NOT mean "no tools needed" — e.g. "查看当前ip" needs shell - result = SkillRoutingResult( - clean_content=clean_content, - system_prompt=default_system_prompt, - tools=default_tools, - model=default_model, - agent_name=default_agent_name, - matched=False, - match_method="low_complexity", - match_confidence=1.0 - complexity, - complexity=complexity, - execution_mode=ExecutionMode.REACT - if default_tools - else ExecutionMode.DIRECT_CHAT, - ) - trace.append( - { - "layer": 1, - "method": "low_complexity", - "complexity": complexity, - "routed_to": "default", - } - ) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - span.set_attribute("route.layer", "low_complexity") - span.set_attribute("route.target", "default") - return result - - # ---- Layer 1.5: Semantic Router (zero LLM cost) ---- - skill_hint = None - if self._semantic_router is not None and complexity >= 0.3: - try: - semantic_result = await self._semantic_router.route(clean_content) - if semantic_result.confidence == "high" and semantic_result.skill_name: - # Direct skill match — skip Layer 2 - trace.append( - { - "layer": 1.5, - "method": "semantic_high", - "skill": semantic_result.skill_name, - "similarity": round(semantic_result.similarity, 3), - "cost": "zero", - } - ) - result = await resolve_skill_routing( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - force_skill=semantic_result.skill_name, - ) - result.match_method = "semantic_high" - result.match_confidence = semantic_result.similarity - result.complexity = complexity - if result.matched: - result.execution_mode = _resolve_execution_mode(result.skill_config) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - span.set_attribute("route.layer", "semantic_high") - span.set_attribute("route.target", result.skill_name or "default") - return result - elif semantic_result.confidence == "medium" and semantic_result.skill_name: - # Pass skill hint to Layer 1.5 merged classify or Layer 2 - skill_hint = semantic_result.skill_name - trace.append( - { - "layer": 1.5, - "method": "semantic_medium", - "skill_hint": skill_hint, - "similarity": round(semantic_result.similarity, 3), - } - ) - except Exception as e: - logger.warning(f"Semantic routing failed, falling through: {e}") - trace.append( - { - "layer": 1.5, - "method": "semantic_error", - "error": str(e), - } - ) - - # Short text fallback: if semantic router returned low confidence - # and text is short (<20 chars), force LLM classify for better routing. - # BUT: skip LLM fallback when HeuristicClassifier already detected - # high-complexity signals (e.g. "查看ip" has "查看" → complexity >= 0.65). - # In that case the routing outcome is already clear (REACT mode), - # and an extra LLM call would only waste 1-3 seconds. - short_text_llm_hint = None - if ( - skill_hint is None - and len(clean_content) < 20 - and self._merged_llm_classify - and self._llm_gateway is not None - and complexity - < 0.5 # Only trigger LLM fallback for truly ambiguous low-complexity queries - ): - short_text_llm_hint = True - trace.append( - { - "layer": 1.5, - "method": "short_text_llm_fallback", - "reason": "semantic_low + short_text", - } - ) - - # Medium complexity → merged LLM classify or IntentRouter - # Short text with no semantic match forces LLM classify - # BUT: if HeuristicClassifier already detected high-complexity signals - # (complexity >= 0.5), LLM classify tends to override correct routing - # with "direct_agent" — skip it and go straight to IntentRouter - if (complexity <= 0.7 and complexity < 0.5) or short_text_llm_hint: - if self._merged_llm_classify and self._llm_gateway is not None: - # Use merged LLM call: complexity + intent in one call - result = await self._classify_merged( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - complexity=complexity, - ) - # If merged classify returned high complexity, delegate to Layer 2 - if ( - result.complexity > 0.7 - and result.match_method - and result.match_method.startswith("merged_llm_high") - ): - trace.append( - { - "layer": 1, - "method": "merged_llm_high", - "complexity": result.complexity, - "delegated_to_layer2": True, - } - ) - layer2_result = await self._route_layer2( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - complexity=result.complexity, - trace=trace, - ) - layer2_result.execution_trace = trace if transparency != "SILENT" else [] - layer2_result.transparency_level = transparency - return layer2_result - else: - # Fallback: use separate IntentRouter - result = await resolve_skill_routing( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - ) - result.complexity = result.complexity if result.complexity > 0 else complexity - trace.append( - { - "layer": 1, - "method": result.match_method or "merged_llm", - "complexity": result.complexity, - "matched": result.matched, - } - ) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - span.set_attribute("route.layer", result.match_method or "merged_llm") - span.set_attribute("route.target", result.skill_name or "default") - return result - - # ---- Layer 2: Capability matching / Auction (high complexity) ---- - trace.append( - { - "layer": 2, - "method": "capability_or_auction", - "complexity": complexity, - "auction_enabled": self._auction_enabled, - } - ) - result = await self._route_layer2( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - complexity=complexity, - trace=trace, - ) - result.execution_trace = trace if transparency != "SILENT" else [] - result.transparency_level = transparency - span.set_attribute("route.layer", result.match_method or "capability") - span.set_attribute("route.target", result.skill_name or result.agent_name or "default") - return result diff --git a/src/agentkit/cli/chat.py b/src/agentkit/cli/chat.py index 40b316b..54c2bea 100644 --- a/src/agentkit/cli/chat.py +++ b/src/agentkit/cli/chat.py @@ -96,11 +96,10 @@ async def _chat_async( WebCrawlTool(), ] - # ── Load skills and build IntentRouter ─────────────────────── + # ── Load skills ──────────────────────────────────────────── from agentkit.tools.registry import ToolRegistry from agentkit.skills.registry import SkillRegistry from agentkit.skills.loader import SkillLoader - from agentkit.router.intent import IntentRouter tool_registry = ToolRegistry() for tool in tools: @@ -123,8 +122,6 @@ async def _chat_async( except Exception: pass - intent_router = IntentRouter(llm_gateway=gateway) if skill_registry.list_skills() else None - # Build system prompt — inject memory into system prompt base_prompt = system_prompt or ( "你是一个有帮助的AI助手。请记住我们对话的上下文,并在后续对话中引用之前的内容。回答要清晰简洁,请使用中文回复。" @@ -218,7 +215,6 @@ async def _chat_async( routing = await resolve_skill_routing( content=user_input, skill_registry=skill_registry, - intent_router=intent_router, default_tools=tools, default_system_prompt=effective_system_prompt, default_model=current_model, diff --git a/src/agentkit/core/plan_exec_engine.py b/src/agentkit/core/plan_exec_engine.py index c00684c..add12f6 100644 --- a/src/agentkit/core/plan_exec_engine.py +++ b/src/agentkit/core/plan_exec_engine.py @@ -25,6 +25,7 @@ from agentkit.core.plan_schema import ExecutionPlan, PlanStep, PlanStepStatus from agentkit.core.protocol import CancellationToken, TaskMessage, TaskResult, TaskStatus from agentkit.core.react import ReActEvent, ReActResult, ReActStep from agentkit.core.shared_workspace import SharedWorkspace +from agentkit.core.spec_manager import Spec, SpecManager, SpecStep from agentkit.orchestrator.reflection import PipelineReflector, PipelineReplanner from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageResult, StageStatus @@ -73,6 +74,7 @@ class PlanExecEngine: default_timeout: float = 300.0, workspace: SharedWorkspace | None = None, step_event_callback: "Callable[[str, dict[str, Any]], Awaitable[None]] | None" = None, + spec_manager: SpecManager | None = None, ): """ Args: @@ -81,12 +83,14 @@ class PlanExecEngine: default_timeout: 默认超时秒数 workspace: SharedWorkspace 实例,用于步骤间状态传递 step_event_callback: 步骤事件回调,用于非流式执行时推送进度 + spec_manager: SpecManager 实例,用于持久化执行计划为 Spec 文档 """ self._llm_gateway = llm_gateway self._max_replans = max_replans self._default_timeout = default_timeout self._workspace = workspace self._step_event_callback = step_event_callback + self._spec_manager = spec_manager self._confirmation_handler: Any | None = None # 组合子组件 @@ -261,6 +265,17 @@ class PlanExecEngine: tokens=0, )) + # Persist plan as Spec if spec_manager is provided + if self._spec_manager is not None: + spec = self._plan_to_spec(plan) + self._spec_manager.create(spec) + state.step_counter += 1 + yield ReActEvent( + event_type="spec_created", + step=state.step_counter, + data={"spec_id": spec.spec_id, "goal": spec.goal, "num_steps": len(spec.steps)}, + ) + # ── Phase 2 & 3: Execute with optional replanning ── current_plan = plan replan_count = 0 @@ -509,6 +524,20 @@ class PlanExecEngine: tokens=0, )) + # Persist plan as Spec if spec_manager is provided + if self._spec_manager is not None: + spec = self._plan_to_spec(plan) + self._spec_manager.create(spec) + if self._step_event_callback: + try: + await self._step_event_callback("spec_created", { + "spec_id": spec.spec_id, + "goal": spec.goal, + "num_steps": len(spec.steps), + }) + except Exception as e: + logger.warning(f"Step event callback failed: {e}") + if trace_recorder is not None: trace_recorder.record_step( step=1, @@ -613,7 +642,7 @@ class PlanExecEngine: task_id=task_id, ) - # 创建 PlanExecutor(使用 LLM 直接调用模式) + # 创建 PlanExecutor executor = self._create_executor( messages=messages, model=model, @@ -734,6 +763,25 @@ class PlanExecEngine: # 辅助方法 # ------------------------------------------------------------------ + @staticmethod + def _plan_to_spec(plan: ExecutionPlan) -> Spec: + """Convert an ExecutionPlan to a Spec for persistence.""" + steps = [ + SpecStep( + step_id=s.step_id, + name=s.name, + description=s.description, + dependencies=s.dependencies, + ) + for s in plan.steps + ] + return Spec( + spec_id=plan.plan_id, + goal=plan.goal, + steps=steps, + metadata=plan.metadata, + ) + @staticmethod def _extract_goal(messages: list[dict[str, str]]) -> str: """从消息列表中提取用户目标""" @@ -779,31 +827,16 @@ class PlanExecEngine: model: str, system_prompt: str | None, tools: list["Tool"] | None, - step_executor_type: str = "react", ) -> PlanExecutor: - """创建 PlanExecutor 实例 - - Args: - step_executor_type: "react" 使用 ReActStepExecutor(默认,支持工具调用), - "llm" 使用 _LLMStepExecutor(纯 LLM 调用,无工具) - """ - if step_executor_type == "llm": - step_executor: _LLMStepExecutor | ReActStepExecutor = _LLMStepExecutor( - llm_gateway=self._llm_gateway, - messages=messages, - model=model, - system_prompt=system_prompt, - tools=tools, - ) - else: - step_executor = ReActStepExecutor( - llm_gateway=self._llm_gateway, - messages=messages, - model=model, - system_prompt=system_prompt, - tools=tools, - confirmation_handler=self._confirmation_handler, - ) + """创建 PlanExecutor 实例,使用 ReActStepExecutor 执行步骤""" + step_executor = ReActStepExecutor( + llm_gateway=self._llm_gateway, + messages=messages, + model=model, + system_prompt=system_prompt, + tools=tools, + confirmation_handler=self._confirmation_handler, + ) return PlanExecutor( agent_pool=step_executor, max_retries=1, @@ -937,58 +970,6 @@ class PlanExecEngine: return "\n\n".join(parts) -class _LLMStepExecutor: - """LLM 直接调用步骤执行器 - - 作为 PlanExecutor 的 agent_pool 替代品, - 使每个 PlanStep 通过 LLM 直接调用执行,而非通过 AgentPool。 - """ - - def __init__( - self, - llm_gateway: "LLMGateway | None" = None, - messages: list[dict[str, str]] | None = None, - model: str = "default", - system_prompt: str | None = None, - tools: list["Tool"] | None = None, - ): - self._llm_gateway = llm_gateway - self._messages = messages or [] - self._model = model - self._system_prompt = system_prompt - self._tools = tools - self._agents: dict[str, _LLMStepAgent] = {} - - async def create_agent_from_skill(self, skill_name: str): - """创建 LLM 步骤 Agent""" - agent = _LLMStepAgent( - name=skill_name, - llm_gateway=self._llm_gateway, - messages=self._messages, - model=self._model, - system_prompt=self._system_prompt, - tools=self._tools, - ) - self._agents[skill_name] = agent - return agent - - def get_agent(self, key: str): - """获取已创建的 Agent""" - if key in self._agents: - return self._agents[key] - # 回退:创建一个默认 Agent - agent = _LLMStepAgent( - name=key, - llm_gateway=self._llm_gateway, - messages=self._messages, - model=self._model, - system_prompt=self._system_prompt, - tools=self._tools, - ) - self._agents[key] = agent - return agent - - class ReActStepExecutor: """ReAct 循环步骤执行器 @@ -1132,69 +1113,3 @@ class _ReActStepAgent: started_at=now, completed_at=now, ) - - -class _LLMStepAgent: - """LLM 直接调用步骤 Agent - - 将 PlanStep 的描述作为 prompt 发送给 LLM, - 返回 LLM 的响应作为步骤结果。 - """ - - def __init__( - self, - name: str, - llm_gateway: "LLMGateway | None" = None, - messages: list[dict[str, str]] | None = None, - model: str = "default", - system_prompt: str | None = None, - tools: list["Tool"] | None = None, - ): - self.name = name - self._llm_gateway = llm_gateway - self._messages = messages or [] - self._model = model - self._system_prompt = system_prompt - self._tools = tools - - async def execute(self, task_msg: TaskMessage) -> "TaskResult": - """执行步骤:通过 LLM 直接调用""" - if self._llm_gateway is None: - raise RuntimeError(f"No LLM gateway available for step '{task_msg.task_id}'") - - # 构建步骤 prompt - input_data = task_msg.input_data - step_name = input_data.get("step_name", task_msg.task_id) - step_description = input_data.get("step_description", "") - dep_results = input_data.get("dependency_results", {}) - - prompt_parts = [f"Execute the following task step:\n\nStep: {step_name}\nDescription: {step_description}"] - - if dep_results: - prompt_parts.append(f"\nResults from previous steps:\n{json.dumps(dep_results, ensure_ascii=False, indent=2)}") - - prompt_parts.append("\nProvide a clear, structured result for this step.") - - conversation: list[dict[str, Any]] = [] - if self._system_prompt: - conversation.append({"role": "system", "content": self._system_prompt}) - # 添加原始对话上下文 - for msg in self._messages: - conversation.append(msg) - conversation.append({"role": "user", "content": "\n".join(prompt_parts)}) - - response = await self._llm_gateway.chat( - messages=conversation, - model=self._model, - ) - - now = datetime.now(timezone.utc) - return TaskResult( - task_id=task_msg.task_id, - agent_name=self.name, - status=TaskStatus.COMPLETED.value, - output_data={"content": response.content or ""}, - error_message=None, - started_at=now, - completed_at=now, - ) diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 01e7dc7..aa7925f 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -74,15 +74,53 @@ class ReActEngine: 使 Agent 能够自主推理并选择工具完成任务。 """ - def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool | str = False): + # Default core tools that always get full descriptions injected into the + # prompt. ``tool_search`` is included so its full description is always + # available to the LLM when tiered injection is active. + _DEFAULT_CORE_TOOLS: tuple[str, ...] = ( + "read_file", + "write_file", + "bash", + "search", + "tool_search", + ) + + def __init__( + self, + llm_gateway: LLMGateway, + max_steps: int = 10, + default_timeout: float = 300.0, + parallel_tools: bool | str = False, + compressor: "CompressionStrategy | None" = None, + verification_enabled: bool = False, + verification_commands: list[str] | None = None, + core_tool_names: list[str] | None = None, + enable_tool_search: bool = True, + ): if max_steps < 1: raise ValueError(f"max_steps must be >= 1, got {max_steps}") if isinstance(parallel_tools, str) and parallel_tools not in ("auto",): - raise ValueError(f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}") + raise ValueError( + f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}" + ) self._llm_gateway = llm_gateway self._max_steps = max_steps self._default_timeout = default_timeout self._parallel_tools = parallel_tools + self._verification_enabled = verification_enabled + self._verification_commands = verification_commands + # Tiered tool description injection config + self._core_tool_names: tuple[str, ...] | None = ( + tuple(core_tool_names) if core_tool_names is not None else None + ) + self._enable_tool_search = enable_tool_search + # Default context compression: keep last 10 turns + if compressor is not None: + self._compressor = compressor + else: + from agentkit.core.compressor import ContextCompressor + + self._compressor = ContextCompressor(llm_gateway=llm_gateway, keep_recent=10) def reset(self) -> None: """Reset internal state for reuse across conversations. @@ -120,10 +158,14 @@ class ReActEngine: 4. 返回 ReActResult 包含输出和轨迹 Args: + compressor: 压缩策略,None 时使用实例默认压缩器 cancellation_token: 协作式取消令牌,每次循环迭代检查是否已取消 timeout_seconds: 超时秒数,0 表示无超时,None 使用 default_timeout """ - effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout + effective_compressor = compressor if compressor is not None else self._compressor + effective_timeout = ( + timeout_seconds if timeout_seconds is not None else self._default_timeout + ) try: if effective_timeout > 0: @@ -138,7 +180,7 @@ class ReActEngine: trace_recorder=trace_recorder, memory_retriever=memory_retriever, task_id=task_id, - compressor=compressor, + compressor=effective_compressor, retrieval_config=retrieval_config, cancellation_token=cancellation_token, confirmation_handler=confirmation_handler, @@ -156,7 +198,7 @@ class ReActEngine: trace_recorder=trace_recorder, memory_retriever=memory_retriever, task_id=task_id, - compressor=compressor, + compressor=effective_compressor, retrieval_config=retrieval_config, cancellation_token=cancellation_token, confirmation_handler=confirmation_handler, @@ -188,6 +230,8 @@ class ReActEngine: confirmation_handler: Any | None = None, ) -> ReActResult: tools = tools or [] + if tools: + tools = self._maybe_add_tool_search(tools) tool_schemas = self._build_tool_schemas(tools) if tools else None if tool_schemas: tool_names = [s["function"]["name"] for s in tool_schemas] @@ -205,7 +249,9 @@ class ReActEngine: system_prompt = self._build_tool_use_prompt(tools) # Telemetry: record agent request - agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"}) + agent_request_counter().add( + 1, {"agent.name": agent_name, "agent.type": task_type or "react"} + ) # Start telemetry span for the entire agent execution _span_cm = None @@ -250,7 +296,9 @@ class ReActEngine: else: system_prompt = f"## 参考信息\n{memory_context}" except Exception as e: - logger.warning(f"Memory retrieval failed, continuing without context: {e}", exc_info=True) + logger.warning( + f"Memory retrieval failed, continuing without context: {e}", exc_info=True + ) # 构建初始消息 conversation: list[dict[str, Any]] = [] @@ -263,7 +311,9 @@ class ReActEngine: try: conversation = await compressor.compress(conversation) except Exception as e: - logger.warning(f"Context compression failed, continuing with original messages: {e}") + logger.warning( + f"Context compression failed, continuing with original messages: {e}" + ) trace_outcome = "success" step = 0 @@ -323,9 +373,19 @@ class ReActEngine: # 执行工具调用 if self._parallel_tools == "auto" and len(response.tool_calls) > 1: # Auto mode: mixed parallel/serial based on _parallelizable flag - parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls)) - serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set] - parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set] + parallelizable_set = set( + self._get_parallelizable_indices(response.tool_calls) + ) + serial_calls = [ + (i, tc) + for i, tc in enumerate(response.tool_calls) + if i not in parallelizable_set + ] + parallel_calls = [ + (i, tc) + for i, tc in enumerate(response.tool_calls) + if i in parallelizable_set + ] # Result slots indexed by original position all_results: list[Any] = [None] * len(response.tool_calls) @@ -340,7 +400,10 @@ class ReActEngine: # Execute parallelizable tools in parallel if len(parallel_calls) > 1: para_results = await asyncio.gather( - *[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls], + *[ + self._execute_tool(tc.name, tc.arguments, tools) + for _, tc in parallel_calls + ], return_exceptions=True, ) for j, (i, tc) in enumerate(parallel_calls): @@ -381,12 +444,17 @@ class ReActEngine: error=tool_error, ) - tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) + tool_msg = await self._build_tool_result_message( + tc.id, tool_result, compressor, tc.name + ) conversation.append(tool_msg) elif self._should_execute_parallel(response.tool_calls): # 并行执行多个工具调用 (parallel_tools=True) tool_results = await asyncio.gather( - *[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls], + *[ + self._execute_tool(tc.name, tc.arguments, tools) + for tc in response.tool_calls + ], return_exceptions=True, ) for idx, tc in enumerate(response.tool_calls): @@ -419,7 +487,9 @@ class ReActEngine: error=tool_error, ) - tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) + tool_msg = await self._build_tool_result_message( + tc.id, tool_result, compressor, tc.name + ) conversation.append(tool_msg) else: # 串行执行(单工具或 parallel_tools=False) @@ -428,7 +498,9 @@ class ReActEngine: tool_result = await self._execute_tool(tc.name, tc.arguments, tools) # Handle confirmation flow - if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"): + if isinstance(tool_result, dict) and tool_result.get( + "needs_confirmation" + ): confirmation_id = tool_result["confirmation_id"] command = tool_result.get("command", "") reason = tool_result.get("reason", "") @@ -436,28 +508,46 @@ class ReActEngine: approved = False if confirmation_handler is not None: try: - approved = await confirmation_handler(confirmation_id, command, reason) + approved = await confirmation_handler( + confirmation_id, command, reason + ) except Exception as e: logger.warning(f"Confirmation handler error: {e}") if approved: tool = self._find_tool(tc.name, tools) - if tool and hasattr(tool, '_is_dangerous'): - clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + if tool and hasattr(tool, "_is_dangerous"): + clean_args = { + k: v + for k, v in tc.arguments.items() + if not k.startswith("_") + } clean_args["_skip_dangerous_check"] = True try: tool_result = await tool.safe_execute(**clean_args) except Exception as e: - tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} + tool_result = { + "error": f"Tool '{tc.name}' execution failed: {e}" + } else: # Non-dangerous tool: confirmation was for the overall action, # re-execute with skip flag to avoid re-triggering confirmation - clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + clean_args = { + k: v + for k, v in tc.arguments.items() + if not k.startswith("_") + } clean_args["_skip_dangerous_check"] = True try: - tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"} + tool_result = ( + await tool.safe_execute(**clean_args) + if tool + else {"error": f"Tool '{tc.name}' not found"} + ) except Exception as e: - tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} + tool_result = { + "error": f"Tool '{tc.name}' execution failed: {e}" + } else: tool_result = { "output": "", @@ -496,7 +586,9 @@ class ReActEngine: ) # Observe: 将工具结果添加到对话历史 - tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) + tool_msg = await self._build_tool_result_message( + tc.id, tool_result, compressor, tc.name + ) conversation.append(tool_msg) # Incremental compression: compress conversation if it's getting long @@ -524,7 +616,9 @@ class ReActEngine: for pc in parsed_calls: tool_start = time.monotonic() - tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools) + tool_result = await self._execute_tool( + pc["name"], pc["arguments"], tools + ) tool_duration_ms = int((time.monotonic() - tool_start) * 1000) react_step = ReActStep( @@ -554,7 +648,9 @@ class ReActEngine: ) # 将工具结果添加到对话历史 - tool_msg = await self._build_tool_result_message(pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"]) + tool_msg = await self._build_tool_result_message( + pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"] + ) conversation.append(tool_msg) # Incremental compression: compress conversation if it's getting long @@ -585,6 +681,35 @@ class ReActEngine: ) break + # Verification: 如果启用验证,在 final answer 后运行测试 + if self._verification_enabled and output: + try: + from agentkit.core.verification_loop import VerificationLoop + + vloop = VerificationLoop(commands=self._verification_commands) + vresult = await vloop.verify() + if not vresult.passed: + # 将验证失败信息作为 ReActStep 添加到轨迹 + verification_step = ReActStep( + step=step + 1, + action="tool_call", + tool_name="verification", + arguments={"commands": self._verification_commands}, + result={ + "passed": vresult.passed, + "errors": vresult.errors, + "test_output": vresult.test_output, + }, + content=(f"Verification failed:\n{vresult.test_output[:2000]}"), + ) + trajectory.append(verification_step) + logger.info( + "Verification failed after final answer, " + "appended feedback to trajectory" + ) + except Exception as e: + logger.warning(f"Verification loop failed: {e}") + # 达到 max_steps 时,返回当前最佳输出 if step >= self._max_steps and not output: trace_outcome = "partial" @@ -599,6 +724,7 @@ class ReActEngine: # 兜底:确保 output 永远不为空字符串 if not output or not output.strip(): from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED + if step >= self._max_steps: output = MAX_STEPS_REACHED else: @@ -660,8 +786,14 @@ class ReActEngine: Same logic as execute() but yields events at each step instead of accumulating a result. + + Args: + compressor: 压缩策略,None 时使用实例默认压缩器 """ + effective_compressor = compressor if compressor is not None else self._compressor tools = tools or [] + if tools: + tools = self._maybe_add_tool_search(tools) tool_schemas = self._build_tool_schemas(tools) if tools else None if tool_schemas: tool_names = [s["function"]["name"] for s in tool_schemas] @@ -679,7 +811,9 @@ class ReActEngine: system_prompt = self._build_tool_use_prompt(tools) # Telemetry: record agent request - agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"}) + agent_request_counter().add( + 1, {"agent.name": agent_name, "agent.type": task_type or "react"} + ) # Start telemetry span for the entire agent execution _span_cm = None @@ -726,11 +860,13 @@ class ReActEngine: conversation.extend(messages) # Context compression: 压缩超长对话历史 - if compressor: + if effective_compressor: try: - conversation = await compressor.compress(conversation) + conversation = await effective_compressor.compress(conversation) except Exception as e: - logger.warning(f"Context compression failed, continuing with original messages: {e}") + logger.warning( + f"Context compression failed, continuing with original messages: {e}" + ) trajectory: list[ReActStep] = [] total_tokens = 0 @@ -738,7 +874,9 @@ class ReActEngine: output = "" trace_outcome = "success" _stream_start = time.monotonic() - effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout + effective_timeout = ( + timeout_seconds if timeout_seconds is not None else self._default_timeout + ) try: while step < self._max_steps: @@ -836,19 +974,44 @@ class ReActEngine: conversation.append(assistant_msg) # Execute tool calls with parallel support - if self._parallel_tools and len(response.tool_calls) > 1 and self._should_execute_parallel(response.tool_calls): + if ( + self._parallel_tools + and len(response.tool_calls) > 1 + and self._should_execute_parallel(response.tool_calls) + ): # Parallel execution path - parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls)) if self._parallel_tools == "auto" else set(range(len(response.tool_calls))) - serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set] - parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set] + parallelizable_set = ( + set(self._get_parallelizable_indices(response.tool_calls)) + if self._parallel_tools == "auto" + else set(range(len(response.tool_calls))) + ) + serial_calls = [ + (i, tc) + for i, tc in enumerate(response.tool_calls) + if i not in parallelizable_set + ] + parallel_calls = [ + (i, tc) + for i, tc in enumerate(response.tool_calls) + if i in parallelizable_set + ] all_results: list[Any] = [None] * len(response.tool_calls) # Execute serial tools first (handles confirmation flow) for i, tc in serial_calls: - yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments}) + yield ReActEvent( + event_type="tool_call", + step=step, + data={"tool_name": tc.name, "arguments": tc.arguments}, + ) tool_start = time.monotonic() - tool_result, confirm_events = await self._execute_tool_with_confirmation(tc, tools, step, confirmation_handler) + ( + tool_result, + confirm_events, + ) = await self._execute_tool_with_confirmation( + tc, tools, step, confirmation_handler + ) for ev in confirm_events: yield ev tool_duration_ms = int((time.monotonic() - tool_start) * 1000) @@ -857,7 +1020,10 @@ class ReActEngine: # Execute parallelizable tools concurrently if len(parallel_calls) > 1: para_results = await asyncio.gather( - *[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls], + *[ + self._execute_tool(tc.name, tc.arguments, tools) + for _, tc in parallel_calls + ], return_exceptions=True, ) for j, (i, tc) in enumerate(parallel_calls): @@ -873,19 +1039,45 @@ class ReActEngine: # Process all results in original order for i, tc in enumerate(response.tool_calls): tc_obj, tool_result, tool_duration_ms = all_results[i] - yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments}) + yield ReActEvent( + event_type="tool_call", + step=step, + data={"tool_name": tc.name, "arguments": tc.arguments}, + ) - react_step = ReActStep(step=step, action="tool_call", tool_name=tc.name, arguments=tc.arguments, result=tool_result, tokens=step_tokens) + react_step = ReActStep( + step=step, + action="tool_call", + tool_name=tc.name, + arguments=tc.arguments, + result=tool_result, + tokens=step_tokens, + ) trajectory.append(react_step) if trace_recorder is not None: tool_error = None if isinstance(tool_result, dict) and "error" in tool_result: tool_error = tool_result["error"] - trace_recorder.record_step(step=step, action="tool_call", tool_name=tc.name, input_data=tc.arguments, output_data=tool_result, duration_ms=tool_duration_ms, tokens_used=0, error=tool_error) + trace_recorder.record_step( + step=step, + action="tool_call", + tool_name=tc.name, + input_data=tc.arguments, + output_data=tool_result, + duration_ms=tool_duration_ms, + tokens_used=0, + error=tool_error, + ) - yield ReActEvent(event_type="tool_result", step=step, data={"tool_name": tc.name, "result": tool_result}) - tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) + yield ReActEvent( + event_type="tool_result", + step=step, + data={"tool_name": tc.name, "result": tool_result}, + ) + tool_msg = await self._build_tool_result_message( + tc.id, tool_result, effective_compressor, tc.name + ) conversation.append(tool_msg) else: # Serial execution path (with confirmation flow) @@ -902,7 +1094,9 @@ class ReActEngine: tool_duration_ms = int((time.monotonic() - tool_start) * 1000) # 检测工具返回的确认请求 - if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"): + if isinstance(tool_result, dict) and tool_result.get( + "needs_confirmation" + ): confirmation_id = tool_result["confirmation_id"] command = tool_result.get("command", "") reason = tool_result.get("reason", "") @@ -923,16 +1117,22 @@ class ReActEngine: approved = False if confirmation_handler is not None: try: - approved = await confirmation_handler(confirmation_id, command, reason) + approved = await confirmation_handler( + confirmation_id, command, reason + ) except Exception as e: logger.warning(f"Confirmation handler error: {e}") if approved: # 用户确认执行:使用 per-call override 绕过安全检查 tool = self._find_tool(tc.name, tools) - if tool and hasattr(tool, '_is_dangerous'): + if tool and hasattr(tool, "_is_dangerous"): # Strip internal metadata and pass skip_dangerous_check flag - clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + clean_args = { + k: v + for k, v in tc.arguments.items() + if not k.startswith("_") + } clean_args["_skip_dangerous_check"] = True try: tool_result = await tool.safe_execute(**clean_args) @@ -940,12 +1140,22 @@ class ReActEngine: pass # No shared state mutation needed else: # Non-dangerous tool: re-execute with skip flag - clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + clean_args = { + k: v + for k, v in tc.arguments.items() + if not k.startswith("_") + } clean_args["_skip_dangerous_check"] = True try: - tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"} + tool_result = ( + await tool.safe_execute(**clean_args) + if tool + else {"error": f"Tool '{tc.name}' not found"} + ) except Exception as e: - tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} + tool_result = { + "error": f"Tool '{tc.name}' execution failed: {e}" + } yield ReActEvent( event_type="confirmation_result", @@ -964,7 +1174,10 @@ class ReActEngine: yield ReActEvent( event_type="confirmation_result", step=step, - data={"confirmation_id": confirmation_id, "approved": False}, + data={ + "confirmation_id": confirmation_id, + "approved": False, + }, ) tool_duration_ms = int((time.monotonic() - tool_start) * 1000) @@ -1001,13 +1214,15 @@ class ReActEngine: data={"tool_name": tc.name, "result": tool_result}, ) - tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) + tool_msg = await self._build_tool_result_message( + tc.id, tool_result, effective_compressor, tc.name + ) conversation.append(tool_msg) # Incremental compression: compress conversation if it's getting long - if self._should_compress(conversation, compressor): + if self._should_compress(conversation, effective_compressor): try: - conversation = await compressor.compress(conversation) + conversation = await effective_compressor.compress(conversation) except Exception as e: logger.warning(f"Incremental compression failed: {e}") @@ -1033,16 +1248,20 @@ class ReActEngine: data={"tool_name": pc["name"], "arguments": pc["arguments"]}, ) tool_start = time.monotonic() - tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools) + tool_result = await self._execute_tool( + pc["name"], pc["arguments"], tools + ) tool_duration_ms = int((time.monotonic() - tool_start) * 1000) - trajectory.append(ReActStep( - step=step, - action="tool_call", - tool_name=pc["name"], - arguments=pc["arguments"], - result=tool_result, - tokens=step_tokens, - )) + trajectory.append( + ReActStep( + step=step, + action="tool_call", + tool_name=pc["name"], + arguments=pc["arguments"], + result=tool_result, + tokens=step_tokens, + ) + ) # 记录工具调用步骤 if trace_recorder is not None: tool_error = None @@ -1064,14 +1283,17 @@ class ReActEngine: data={"tool_name": pc["name"], "result": tool_result}, ) tool_msg = await self._build_tool_result_message( - pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"] + pc.get("id", f"text_tc_{step}"), + tool_result, + effective_compressor, + pc["name"], ) conversation.append(tool_msg) # Incremental compression: compress conversation if it's getting long - if self._should_compress(conversation, compressor): + if self._should_compress(conversation, effective_compressor): try: - conversation = await compressor.compress(conversation) + conversation = await effective_compressor.compress(conversation) except Exception as e: logger.warning(f"Incremental compression failed: {e}") else: @@ -1106,6 +1328,46 @@ class ReActEngine: ) break + # Verification: 如果启用验证,在 final answer 后运行测试 + if self._verification_enabled and output: + try: + from agentkit.core.verification_loop import VerificationLoop + + vloop = VerificationLoop(commands=self._verification_commands) + vresult = await vloop.verify() + if not vresult.passed: + verification_step = ReActStep( + step=step + 1, + action="tool_call", + tool_name="verification", + arguments={"commands": self._verification_commands}, + result={ + "passed": vresult.passed, + "errors": vresult.errors, + "test_output": vresult.test_output, + }, + content=(f"Verification failed:\n{vresult.test_output[:2000]}"), + ) + trajectory.append(verification_step) + yield ReActEvent( + event_type="tool_result", + step=step + 1, + data={ + "tool_name": "verification", + "result": { + "passed": vresult.passed, + "errors": vresult.errors, + "test_output": vresult.test_output, + }, + }, + ) + logger.info( + "Verification failed after final answer, " + "appended feedback to trajectory" + ) + except Exception as e: + logger.warning(f"Verification loop failed: {e}") + if step >= self._max_steps and not output: trace_outcome = "partial" if trajectory and trajectory[-1].content: @@ -1129,6 +1391,7 @@ class ReActEngine: # 兜底:确保 output 永远不为空字符串 if not output or not output.strip(): from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED + if step >= self._max_steps: output = MAX_STEPS_REACHED else: @@ -1187,33 +1450,40 @@ class ReActEngine: schemas.append(schema) return schemas - @staticmethod - def _build_tool_use_prompt(tools: list[Tool]) -> str: - """Build prompt-based tool calling instructions for LLMs that don't - support native function calling (e.g., Bailian Coding, Qwen). + def _build_tool_use_prompt(self, tools: list[Tool]) -> str: + """Build prompt-based tool calling instructions with tiered injection. - Instructs the LLM to use XML format for tool invocation. - This follows the Hermes pattern: model-agnostic prompt-based tool calling. + Core tools (defined by ``self._core_tool_names`` or + :attr:`_DEFAULT_CORE_TOOLS`) get full descriptions (name + + description + parameters). Extended tools get only name + a + one-line description. When ``tool_search`` is present alongside + extended tools, a hint is added telling the LLM to call + ``tool_search`` for full parameter details. + + Instructs the LLM to use ```` XML format for tool + invocation (Hermes pattern: model-agnostic prompt-based tool calling). """ - tool_descriptions = [] - for tool in tools: - params_desc = "" - if tool.input_schema: - props = tool.input_schema.get("properties", {}) - required = tool.input_schema.get("required", []) - param_parts = [] - for pname, pinfo in props.items(): - ptype = pinfo.get("type", "string") - pdesc = pinfo.get("description", "") - req_flag = " (required)" if pname in required else "" - param_parts.append(f" - {pname}: {ptype}{req_flag} — {pdesc}") - if param_parts: - params_desc = "\n".join(param_parts) - tool_descriptions.append( - f"- {tool.name}: {tool.description}\n{params_desc}" + core_names = set(self._core_tool_names or self._DEFAULT_CORE_TOOLS) + core_tools = [t for t in tools if t.name in core_names] + extended_tools = [t for t in tools if t.name not in core_names] + + sections: list[str] = [] + if core_tools: + sections.append(self._render_core_tools(core_tools)) + if extended_tools: + sections.append(self._render_extended_tools(extended_tools)) + + tools_text = "\n\n".join(sections) + + has_tool_search = any(t.name == "tool_search" for t in tools) + search_hint = "" + if has_tool_search and extended_tools: + search_hint = ( + "\n\n注意:上方「扩展工具」仅显示名称和简短描述。" + '如需使用某个扩展工具,请先调用 tool_search(query="关键词") ' + "获取其完整参数说明。" ) - tools_text = "\n\n".join(tool_descriptions) return ( "## 可用工具\n\n" "你可以使用以下工具来完成任务。当需要调用工具时,使用以下格式:\n\n" @@ -1225,9 +1495,65 @@ class ReActEngine: "2. 等待工具返回结果后再决定下一步\n" "3. 如果不需要工具就能回答,直接回答即可\n" "4. 不要在回答中重复工具的输出,而是基于结果给出有用的总结\n\n" - f"工具列表:\n\n{tools_text}" + f"工具列表:\n\n{tools_text}{search_hint}" ) + @staticmethod + def _render_core_tools(tools: list[Tool]) -> str: + """Render core tools with full descriptions (name + description + parameters).""" + descriptions: list[str] = [] + for tool in tools: + params_desc = "" + if tool.input_schema: + props = tool.input_schema.get("properties", {}) + required = tool.input_schema.get("required", []) + param_parts: list[str] = [] + for pname, pinfo in props.items(): + ptype = pinfo.get("type", "string") + pdesc = pinfo.get("description", "") + req_flag = " (required)" if pname in required else "" + param_parts.append(f" - {pname}: {ptype}{req_flag} — {pdesc}") + if param_parts: + params_desc = "\n".join(param_parts) + descriptions.append(f"- {tool.name}: {tool.description}\n{params_desc}") + return "### 核心工具(完整描述)\n\n" + "\n\n".join(descriptions) + + @staticmethod + def _render_extended_tools(tools: list[Tool]) -> str: + """Render extended tools with name + one-line description only.""" + lines: list[str] = [] + for tool in tools: + desc = tool.description.strip().split("\n")[0] + if len(desc) > 100: + desc = desc[:97] + "..." + lines.append(f"- {tool.name}: {desc}") + return "### 扩展工具(仅名称和简短描述,使用 tool_search 获取详情)\n\n" + "\n".join(lines) + + def _maybe_add_tool_search(self, tools: list[Tool]) -> list[Tool]: + """Add ``tool_search`` tool if enabled and there are extended tools. + + Builds a :class:`ToolSearchIndex` from the extended tools so the + LLM can discover full tool descriptions on demand via BM25 search. + If all tools are core tools, or ``tool_search`` is already present, + or ``enable_tool_search`` is False, the list is returned unchanged. + """ + if not self._enable_tool_search: + return tools + if any(t.name == "tool_search" for t in tools): + return tools + + core_names = set(self._core_tool_names or self._DEFAULT_CORE_TOOLS) + extended_tools = [t for t in tools if t.name not in core_names] + if not extended_tools: + return tools + + from agentkit.tools.builtin import ToolSearchTool + from agentkit.tools.search import ToolSearchIndex + + index = ToolSearchIndex(extended_tools) + search_tool = ToolSearchTool(search_index=index) + return tools + [search_tool] + @staticmethod def _build_response_from_stream( content: str, @@ -1237,6 +1563,7 @@ class ReActEngine: ) -> LLMResponse: """Build an LLMResponse from accumulated stream chunks.""" from agentkit.llm.protocol import LLMResponse, TokenUsage + if usage is None: usage = TokenUsage() return LLMResponse( @@ -1256,7 +1583,9 @@ class ReActEngine: # Default token threshold for incremental compression _DEFAULT_COMPRESS_THRESHOLD = 8000 - def _should_compress(self, conversation: list[dict], compressor: "CompressionStrategy | None") -> bool: + def _should_compress( + self, conversation: list[dict], compressor: "CompressionStrategy | None" + ) -> bool: """检查是否需要增量压缩""" if not compressor: return False @@ -1331,16 +1660,18 @@ class ReActEngine: command = tool_result.get("command", "") reason = tool_result.get("reason", "") - events.append(ReActEvent( - event_type="confirmation_request", - step=step, - data={ - "confirmation_id": confirmation_id, - "tool_name": tc.name, - "command": command, - "reason": reason, - }, - )) + events.append( + ReActEvent( + event_type="confirmation_request", + step=step, + data={ + "confirmation_id": confirmation_id, + "tool_name": tc.name, + "command": command, + "reason": reason, + }, + ) + ) # Wait for user confirmation approved = False @@ -1353,7 +1684,7 @@ class ReActEngine: if approved: # User approved: re-execute with _skip_dangerous_check tool = self._find_tool(tc.name, tools) - if tool and hasattr(tool, '_is_dangerous'): + if tool and hasattr(tool, "_is_dangerous"): clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} clean_args["_skip_dangerous_check"] = True try: @@ -1365,15 +1696,21 @@ class ReActEngine: clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} clean_args["_skip_dangerous_check"] = True try: - tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"} + tool_result = ( + await tool.safe_execute(**clean_args) + if tool + else {"error": f"Tool '{tc.name}' not found"} + ) except Exception as e: tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} - events.append(ReActEvent( - event_type="confirmation_result", - step=step, - data={"confirmation_id": confirmation_id, "approved": True}, - )) + events.append( + ReActEvent( + event_type="confirmation_result", + step=step, + data={"confirmation_id": confirmation_id, "approved": True}, + ) + ) else: # User rejected tool_result = { @@ -1383,11 +1720,13 @@ class ReActEngine: "error_type": "permission_denied", "message": f"用户拒绝执行命令: {command[:100]}", } - events.append(ReActEvent( - event_type="confirmation_result", - step=step, - data={"confirmation_id": confirmation_id, "approved": False}, - )) + events.append( + ReActEvent( + event_type="confirmation_result", + step=step, + data={"confirmation_id": confirmation_id, "approved": False}, + ) + ) return tool_result, events @@ -1418,7 +1757,7 @@ class ReActEngine: """ indices = [] for i, tc in enumerate(tool_calls): - args = tc.arguments if hasattr(tc, 'arguments') else {} + args = tc.arguments if hasattr(tc, "arguments") else {} if isinstance(args, dict) and args.get("_parallelizable") is True: indices.append(i) return indices @@ -1434,9 +1773,7 @@ class ReActEngine: calls: list[dict[str, Any]] = [] # 格式 1: Action: tool_name(args) - action_pattern = re.compile( - r"Action:\s*(\w+)\((.+?)\)", re.DOTALL - ) + action_pattern = re.compile(r"Action:\s*(\w+)\((.+?)\)", re.DOTALL) for match in action_pattern.finditer(content): name = match.group(1) args_str = match.group(2) @@ -1450,9 +1787,7 @@ class ReActEngine: return calls # 格式 2: ```tool\n{"name": "...", "arguments": {...}}\n``` - code_block_pattern = re.compile( - r"```tool\s*\n(.*?)\n\s*```", re.DOTALL - ) + code_block_pattern = re.compile(r"```tool\s*\n(.*?)\n\s*```", re.DOTALL) for match in code_block_pattern.finditer(content): json_str = match.group(1).strip() try: @@ -1469,9 +1804,7 @@ class ReActEngine: # 格式 3: \n{"name": "...", "arguments": {...}}\n # 兼容 Anthropic/Qwen 等模型在文本中模拟的工具调用格式 - tool_use_pattern = re.compile( - r"\s*(.*?)\s*", re.DOTALL - ) + tool_use_pattern = re.compile(r"\s*(.*?)\s*", re.DOTALL) for match in tool_use_pattern.finditer(content): json_str = match.group(1).strip() try: diff --git a/src/agentkit/experts/router.py b/src/agentkit/experts/router.py index 8dc3bb4..98ff7ac 100644 --- a/src/agentkit/experts/router.py +++ b/src/agentkit/experts/router.py @@ -1,4 +1,11 @@ -"""Expert Team routing — resolves user input to ExpertTeam configuration.""" +"""Expert Team routing — resolves @team prefix to ExpertTeam configuration. + +简化说明(U3): +- 仅通过 @team 前缀触发团队模式 +- 移除基于复杂度的自动建议 +- 保留 @team:expert1,expert2 指定专家成员 +- 保留 resolve_expert_configs 解析专家配置 +""" import logging import re @@ -21,37 +28,37 @@ MAX_EXPERTS = 10 # Maximum number of experts in a team @dataclass class ExpertTeamRoutingResult: - """Result of expert team routing resolution.""" + """Result of expert team routing resolution. + + In hub-and-spoke mode, routing is triggered exclusively by the @team prefix. + """ matched: bool = False team_mode: bool = False specified_experts: list[str] = field(default_factory=list) task_content: str = "" auto_compose: bool = False - complexity: float = 0.0 - match_method: str = "" # "explicit_team" | "complexity_suggestion" + match_method: str = "" # "explicit_team" | "" class ExpertTeamRouter: - """Routes user input to Expert Team mode. + """Routes user input to Expert Team mode via @team prefix. Supports: - - @team prefix → trigger team mode - - @team:analyst,strategist → specify team members - - High complexity → suggest team mode upgrade + - @team prefix → trigger team mode (auto-compose members) + - @team:analyst,strategist → specify team members by name """ - COMPLEXITY_THRESHOLD = 0.7 # Above this, suggest team mode - def __init__(self, template_registry: ExpertTemplateRegistry | None = None): self._registry = template_registry or ExpertTemplateRegistry() - def resolve(self, content: str, complexity: float = 0.0) -> ExpertTeamRoutingResult: + def resolve(self, content: str) -> ExpertTeamRoutingResult: """Resolve user input to an ExpertTeamRoutingResult. + Only @team prefix triggers team mode. No complexity-based suggestion. + Args: content: User's input message - complexity: Pre-computed complexity score (0.0-1.0) Returns: ExpertTeamRoutingResult with routing decision @@ -94,27 +101,15 @@ class ExpertTeamRouter: return result - # Check complexity-based suggestion - if complexity >= self.COMPLEXITY_THRESHOLD: - result.matched = True - result.team_mode = True - result.auto_compose = True - result.complexity = complexity - result.task_content = content - result.match_method = "complexity_suggestion" - return result - # Not a team mode request result.matched = False result.team_mode = False result.task_content = content - result.complexity = complexity return result def can_handle(self, content: str) -> bool: """Check whether any registered expert template can handle the given content. - Used by CostAwareRouter to decide whether to upgrade REACT → TEAM_COLLAB. Returns True if at least one template's name or description overlaps with content tokens, or if any templates exist (auto-compose can always form a team). """ diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 5ac33d5..9347cd8 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -155,6 +155,7 @@ async def lifespan(app: FastAPI): # Restore conversation history from persistent store (async, in lifespan) from agentkit.server.routes.portal import _conversation_store + await _conversation_store.restore_from_store() # In GUI mode, ensure a default chat agent exists with memory + tools @@ -579,13 +580,13 @@ def create_app( app.state.quality_gate = QualityGate() app.state.output_standardizer = OutputStandardizer() - # Initialize SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT) - from agentkit.chat.simple_router import SimpleRouter + # Initialize RequestPreprocessor (minimal preprocessing: @skill prefix + greeting regex + REACT) + from agentkit.chat.request_preprocessor import RequestPreprocessor - simple_router = SimpleRouter( + request_preprocessor = RequestPreprocessor( skill_registry=app.state.skill_registry, ) - app.state.simple_router = simple_router + app.state.request_preprocessor = request_preprocessor # Initialize OrganizationContext from AgentPool + SkillRegistry from agentkit.org.context import OrganizationContext @@ -606,39 +607,6 @@ def create_app( alignment_guard = AlignmentGuard(config=alignment_config, llm_gateway=app.state.llm_gateway) app.state.alignment_guard = alignment_guard - # CostAwareRouter is no longer used by portal/chat routes (replaced by SimpleRouter). - # It is kept on app.state for backward compatibility with any external consumers. - # To re-enable, set router.legacy_cost_aware_router: true in agentkit.yaml. - router_conf = server_config.router if server_config and server_config.router else {} - if router_conf.get("legacy_cost_aware_router"): - from agentkit.chat.skill_routing import CostAwareRouter - - auction_enabled = False - if server_config and hasattr(server_config, "marketplace") and server_config.marketplace: - auction_enabled = server_config.marketplace.get("auction_enabled", False) - - semantic_router = None - if router_conf.get("semantic", {}).get("enabled"): - try: - from agentkit.chat.semantic_router import SemanticRouter - - semantic_router = SemanticRouter( - embedder=app.state.llm_gateway._embedder, - similarity_high=router_conf["semantic"].get("similarity_high", 0.85), - similarity_low=router_conf["semantic"].get("similarity_low", 0.6), - ) - except Exception as e: - logger.warning(f"Failed to initialize semantic router: {e}") - - cost_aware_router = CostAwareRouter( - llm_gateway=app.state.llm_gateway, - org_context=org_context, - auction_enabled=auction_enabled, - classifier=router_conf.get("classifier", "heuristic"), - merged_llm_classify=router_conf.get("merged_llm_classify", True), - semantic_router=semantic_router, - ) - app.state.cost_aware_router = cost_aware_router # Initialize task store from config ts_config = server_config.task_store if server_config else {} # Merge CLI overrides from AGENTKIT_TASK_STORE env var @@ -680,9 +648,11 @@ def create_app( ) app.state.session_manager = SessionManager(store=session_store) - # Inject SessionManager into Portal's ConversationStore for persistence + # Inject SessionManager into Portal's ConversationStore for persistence (legacy only) from agentkit.server.routes.portal import _conversation_store - _conversation_store.set_session_manager(app.state.session_manager) + + if hasattr(_conversation_store, "set_session_manager"): + _conversation_store.set_session_manager(app.state.session_manager) # Initialize evolution store if configured if server_config and hasattr(server_config, "evolution") and server_config.evolution: diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index bc4dd3b..9546944 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -97,12 +97,21 @@ chat_manager = ChatConnectionManager() # ── Helper ──────────────────────────────────────────────────────────── -_VALID_TEAM_EVENT_TYPES = frozenset({ - "team_formed", "expert_step", "expert_result", - "plan_update", "team_synthesis", "team_dissolved", - "plan_step", "phase_started", "phase_completed", "phase_failed", - "replanning", -}) +_VALID_TEAM_EVENT_TYPES = frozenset( + { + "team_formed", + "expert_step", + "expert_result", + "plan_update", + "team_synthesis", + "team_dissolved", + "plan_step", + "phase_started", + "phase_completed", + "phase_failed", + "replanning", + } +) async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) -> None: @@ -125,10 +134,12 @@ async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) -> if event_type not in _VALID_TEAM_EVENT_TYPES: logger.warning(f"emit_team_event: invalid event_type '{event_type}'") return - await websocket.send_json({ - "type": event_type, - "data": data, - }) + await websocket.send_json( + { + "type": event_type, + "data": data, + } + ) def _get_session_manager(request: Request) -> SessionManager: @@ -236,11 +247,15 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques else: react_engine.reset() tools = agent._tool_registry.list_tools() if agent._tool_registry else [] - system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None) + system_prompt = getattr(agent, "_system_prompt", None) or ( + agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None + ) result = await react_engine.execute( messages=chat_messages, tools=tools, - model=agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default"), + model=agent.get_model() + if hasattr(agent, "get_model") + else getattr(agent, "_llm_model", "default"), agent_name=agent.name, system_prompt=system_prompt, ) @@ -319,7 +334,9 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: sm: SessionManager = websocket.app.state.session_manager session = await sm.get_session(session_id) if session is None: - await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}}) + await websocket.send_json( + {"type": "error", "data": {"message": f"Session '{session_id}' not found"}} + ) await websocket.close(code=1000, reason="Session not found") return @@ -367,10 +384,14 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: # Clean up completed tasks first active_tasks.difference_update(t for t in active_tasks if t.done()) if len(active_tasks) >= _MAX_CONCURRENT_TASKS: - await websocket.send_json({ - "type": "error", - "data": {"message": "Too many concurrent requests. Please wait for the current task to complete."}, - }) + await websocket.send_json( + { + "type": "error", + "data": { + "message": "Too many concurrent requests. Please wait for the current task to complete." + }, + } + ) continue # Run in background task so the WebSocket receive loop stays free @@ -378,7 +399,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: # is waiting for user confirmation (otherwise deadlock). task = asyncio.create_task( _handle_chat_message( - websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations, + websocket, + session_id, + content, + sm, + message_token, + pending_replies, + pending_confirmations, model_override=model, ) ) @@ -396,12 +423,16 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: # Reply to confirmation request confirmation_id = msg.get("confirmation_id") approved = msg.get("approved", False) - logger.info(f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}") + logger.info( + f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}" + ) if confirmation_id and confirmation_id in pending_confirmations: pending_confirmations[confirmation_id].set_result(approved) logger.info(f"Confirmation {confirmation_id} set_result({approved})") else: - logger.warning(f"Confirmation {confirmation_id!r} not found in pending_confirmations") + logger.warning( + f"Confirmation {confirmation_id!r} not found in pending_confirmations" + ) elif msg_type == "cancel": cancellation_token.cancel() @@ -441,9 +472,9 @@ async def _handle_chat_message( ) -> None: """Handle a user message: append to session, execute Agent, stream events. - Uses SimpleRouter for minimal routing: @skill prefix + greeting regex + REACT. + Uses RequestPreprocessor for minimal preprocessing: @skill prefix + greeting regex + REACT. """ - from agentkit.chat.simple_router import SimpleRouter + from agentkit.chat.request_preprocessor import RequestPreprocessor # Resolve Agent first (needed for default tools/prompt) pool = websocket.app.state.agent_pool @@ -454,19 +485,27 @@ async def _handle_chat_message( agent = pool.get_agent(session.agent_name) if agent is None: - await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}}) + await websocket.send_json( + {"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}} + ) return # Default execution parameters from agent default_tools = agent._tool_registry.list_tools() if agent._tool_registry else [] - default_system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None) - default_model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default") + default_system_prompt = getattr(agent, "_system_prompt", None) or ( + agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None + ) + default_model = ( + agent.get_model() + if hasattr(agent, "get_model") + else getattr(agent, "_llm_model", "default") + ) - # Resolve skill routing using SimpleRouter + # Resolve skill routing using RequestPreprocessor skill_registry = getattr(websocket.app.state, "skill_registry", None) - simple_router: SimpleRouter = websocket.app.state.simple_router + request_preprocessor: RequestPreprocessor = websocket.app.state.request_preprocessor - routing = await simple_router.route( + routing = await request_preprocessor.preprocess( content=content, skill_registry=skill_registry, default_tools=default_tools, @@ -477,7 +516,9 @@ async def _handle_chat_message( # Debug: log tools that will be passed to ReActEngine tool_names = [t.name for t in routing.tools] - logger.info(f"Chat {session_id}: resolved {len(routing.tools)} tools: {tool_names}, model={routing.model}, skill={routing.skill_name}") + logger.info( + f"Chat {session_id}: resolved {len(routing.tools)} tools: {tool_names}, model={routing.model}, skill={routing.skill_name}" + ) # Apply model override from frontend selector if model_override: @@ -485,17 +526,21 @@ async def _handle_chat_message( # Notify frontend about skill match if routing.matched: - await websocket.send_json({ - "type": "skill_match", - "data": { - "skill": routing.skill_name, - "method": routing.match_method, - "confidence": routing.match_confidence, - }, - }) + await websocket.send_json( + { + "type": "skill_match", + "data": { + "skill": routing.skill_name, + "method": routing.match_method, + "confidence": routing.match_confidence, + }, + } + ) # Append user message (use clean_content if @skill: prefix was stripped) - await sm.append_message(session_id=session_id, role=MessageRole.USER, content=routing.clean_content) + await sm.append_message( + session_id=session_id, role=MessageRole.USER, content=routing.clean_content + ) # Get full conversation history chat_messages = await sm.get_chat_messages(session_id) @@ -516,12 +561,15 @@ async def _handle_chat_message( final_content = response.content or "" if not final_content or not final_content.strip(): from agentkit.core.fallback import EMPTY_LLM_RESPONSE + final_content = EMPTY_LLM_RESPONSE - await websocket.send_json({ - "type": "final_answer", - "content": final_content, - "is_final": True, - }) + await websocket.send_json( + { + "type": "final_answer", + "content": final_content, + "is_final": True, + } + ) await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, @@ -557,14 +605,16 @@ async def _handle_chat_message( async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool: """Send confirmation request to frontend via WebSocket and wait for user reply.""" # Send confirmation request to frontend - await websocket.send_json({ - "type": "confirmation_request", - "data": { - "confirmation_id": confirmation_id, - "command": command, - "reason": reason, - }, - }) + await websocket.send_json( + { + "type": "confirmation_request", + "data": { + "confirmation_id": confirmation_id, + "command": command, + "reason": reason, + }, + } + ) # Create a Future and wait for the user's reply loop = asyncio.get_running_loop() @@ -578,10 +628,12 @@ async def _handle_chat_message( logger.info(f"Confirmation request {confirmation_id} resolved: {result}") # Immediately notify frontend of the result so the card updates # without waiting for the tool to re-execute - await websocket.send_json({ - "type": "confirmation_result", - "data": {"confirmation_id": confirmation_id, "approved": result}, - }) + await websocket.send_json( + { + "type": "confirmation_result", + "data": {"confirmation_id": confirmation_id, "approved": result}, + } + ) return result except asyncio.TimeoutError: logger.warning(f"Confirmation request {confirmation_id} timed out") @@ -592,7 +644,9 @@ async def _handle_chat_message( finally: _pending_confirmations.pop(confirmation_id, None) - logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}") + logger.info( + f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}" + ) try: final_content = "" @@ -615,12 +669,15 @@ async def _handle_chat_message( final_content = event.data.get("output", "") if not final_content or not final_content.strip(): from agentkit.core.fallback import EMPTY_LLM_RESPONSE + final_content = EMPTY_LLM_RESPONSE - await websocket.send_json({ - "type": "final_answer", - "content": final_content, - "is_final": True, - }) + await websocket.send_json( + { + "type": "final_answer", + "content": final_content, + "is_final": True, + } + ) elif event.event_type == "token": # Buffer tokens instead of sending immediately token_buffer.append(event.data.get("content", "")) @@ -640,30 +697,36 @@ async def _handle_chat_message( buffered_text = "".join(token_buffer) token_buffer.clear() await websocket.send_json({"type": "thinking", "content": buffered_text}) - await websocket.send_json({ - "type": "step", - "data": { - "event_type": event.event_type, - "step": event.step, - "data": event.data, - }, - }) + await websocket.send_json( + { + "type": "step", + "data": { + "event_type": event.event_type, + "step": event.step, + "data": event.data, + }, + } + ) elif event.event_type == "confirmation_request": pass elif event.event_type == "confirmation_result": - await websocket.send_json({ - "type": "confirmation_result", - "data": event.data, - }) - else: - await websocket.send_json({ - "type": "step", - "data": { - "event_type": event.event_type, - "step": event.step, + await websocket.send_json( + { + "type": "confirmation_result", "data": event.data, - }, - }) + } + ) + else: + await websocket.send_json( + { + "type": "step", + "data": { + "event_type": event.event_type, + "step": event.step, + "data": event.data, + }, + } + ) # Append assistant reply to session if final_content: diff --git a/tests/e2e/test_simple_router_backtest.py b/tests/e2e/test_request_preprocessor_backtest.py similarity index 88% rename from tests/e2e/test_simple_router_backtest.py rename to tests/e2e/test_request_preprocessor_backtest.py index 3463566..4b74126 100644 --- a/tests/e2e/test_simple_router_backtest.py +++ b/tests/e2e/test_request_preprocessor_backtest.py @@ -1,13 +1,13 @@ -"""E2E Agent Capability Tests — SimpleRouter Backtest (Real LLM). +"""E2E Agent Capability Tests — RequestPreprocessor Backtest (Real LLM). -Tests SimpleRouter.route() using real LLM configuration loaded from +Tests RequestPreprocessor.preprocess() using real LLM configuration loaded from agentkit.yaml. Records full SkillRoutingResult for precise analysis. Key differences from old CostAwareRouter backtest: - No HeuristicClassifier complexity scoring - No IntentRouter LLM classification - No SemanticRouter embedding matching -- SimpleRouter: @skill prefix + greeting regex + default REACT +- RequestPreprocessor: @skill prefix + greeting regex + default REACT """ import asyncio @@ -16,7 +16,7 @@ from pathlib import Path import pytest -from agentkit.chat.simple_router import SimpleRouter +from agentkit.chat.request_preprocessor import RequestPreprocessor from agentkit.chat.skill_routing import ExecutionMode from agentkit.server.app import _build_llm_gateway, _build_skill_registry from agentkit.server.config import ServerConfig @@ -95,7 +95,7 @@ def _find_config_path() -> str | None: return None -def _build_real_components() -> tuple[SimpleRouter, SkillRegistry]: +def _build_real_components() -> tuple[RequestPreprocessor, SkillRegistry]: config_path = _find_config_path() if not config_path: pytest.skip("No agentkit.yaml found") @@ -132,15 +132,15 @@ def _build_real_components() -> tuple[SimpleRouter, SkillRegistry]: pytest.skip("No LLM provider with valid API key") skill_registry = _build_skill_registry(server_config) - router = SimpleRouter(skill_registry=skill_registry) + preprocessor = RequestPreprocessor(skill_registry=skill_registry) - return router, skill_registry + return preprocessor, skill_registry -_cached_components: tuple[SimpleRouter, SkillRegistry] | None = None +_cached_components: tuple[RequestPreprocessor, SkillRegistry] | None = None -def _get_components() -> tuple[SimpleRouter, SkillRegistry]: +def _get_components() -> tuple[RequestPreprocessor, SkillRegistry]: global _cached_components if _cached_components is None: _cached_components = _build_real_components() @@ -153,8 +153,8 @@ def _get_components() -> tuple[SimpleRouter, SkillRegistry]: @pytest.mark.e2e_capability -class TestSimpleRouterBasic: - """Test SimpleRouter basic routing: greeting → DIRECT_CHAT, others → REACT.""" +class TestRequestPreprocessorBasic: + """Test RequestPreprocessor basic preprocessing: greeting → DIRECT_CHAT, others → REACT.""" @pytest.mark.parametrize( "case", @@ -162,9 +162,9 @@ class TestSimpleRouterBasic: ids=[c["id"] for c in ROUTING_TEST_CASES], ) def test_routing(self, case: dict): - router, skill_registry = _get_components() + preprocessor, skill_registry = _get_components() result = asyncio.run( - router.route( + preprocessor.preprocess( content=case["input"], skill_registry=skill_registry, default_tools=["shell", "search", "file_read"], @@ -179,8 +179,8 @@ class TestSimpleRouterBasic: @pytest.mark.e2e_capability -class TestSimpleRouterParaphraseConsistency: - """Test that paraphrased inputs route to the same execution mode.""" +class TestRequestPreprocessorParaphraseConsistency: + """Test that paraphrased inputs preprocess to the same execution mode.""" @pytest.mark.parametrize( "case", @@ -188,12 +188,12 @@ class TestSimpleRouterParaphraseConsistency: ids=[c["id"] for c in PARAPHRASE_CASES], ) def test_paraphrase_consistency(self, case: dict): - router, skill_registry = _get_components() + preprocessor, skill_registry = _get_components() expected_mode = case["expected_mode"] # Test original result = asyncio.run( - router.route( + preprocessor.preprocess( content=case["original"], skill_registry=skill_registry, default_tools=["shell", "search", "file_read"], @@ -206,7 +206,7 @@ class TestSimpleRouterParaphraseConsistency: # Test all paraphrases for para in case["paraphrases"]: result = asyncio.run( - router.route( + preprocessor.preprocess( content=para, skill_registry=skill_registry, default_tools=["shell", "search", "file_read"], @@ -218,19 +218,19 @@ class TestSimpleRouterParaphraseConsistency: @pytest.mark.e2e_capability -class TestSimpleRouterMetrics: - """Compute and report routing accuracy metrics.""" +class TestRequestPreprocessorMetrics: + """Compute and report preprocessing accuracy metrics.""" def test_accuracy_report(self): """Run all test cases and compute accuracy metrics.""" - router, skill_registry = _get_components() + preprocessor, skill_registry = _get_components() total = len(ROUTING_TEST_CASES) correct = 0 results = [] for case in ROUTING_TEST_CASES: result = asyncio.run( - router.route( + preprocessor.preprocess( content=case["input"], skill_registry=skill_registry, default_tools=["shell", "search", "file_read"], @@ -251,7 +251,7 @@ class TestSimpleRouterMetrics: accuracy = correct / total * 100 print(f"\n{'='*60}") - print(f"SimpleRouter Accuracy Report") + print(f"RequestPreprocessor Accuracy Report") print(f"{'='*60}") print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.1f}%") print(f"{'-'*60}") diff --git a/tests/unit/chat/test_simple_router.py b/tests/unit/chat/test_request_preprocessor.py similarity index 64% rename from tests/unit/chat/test_simple_router.py rename to tests/unit/chat/test_request_preprocessor.py index 4fa43d5..c9cbec5 100644 --- a/tests/unit/chat/test_simple_router.py +++ b/tests/unit/chat/test_request_preprocessor.py @@ -1,10 +1,10 @@ -"""Unit tests for SimpleRouter — minimal routing layer.""" +"""Unit tests for RequestPreprocessor — minimal preprocessing layer.""" from __future__ import annotations import pytest -from agentkit.chat.simple_router import SimpleRouter +from agentkit.chat.request_preprocessor import RequestPreprocessor from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult @@ -51,8 +51,8 @@ def registry() -> MockSkillRegistry: @pytest.fixture -def router(registry: MockSkillRegistry) -> SimpleRouter: - return SimpleRouter( +def preprocessor(registry: MockSkillRegistry) -> RequestPreprocessor: + return RequestPreprocessor( skill_registry=registry, default_tools=["shell", "search", "file_read"], default_system_prompt="You are a helpful assistant.", @@ -67,8 +67,8 @@ def router(registry: MockSkillRegistry) -> SimpleRouter: class TestSkillPrefix: @pytest.mark.asyncio - async def test_skill_prefix_routes_to_skill(self, router: SimpleRouter): - result = await router.route("@skill:shell_agent 查看当前ip") + async def test_skill_prefix_routes_to_skill(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("@skill:shell_agent 查看当前ip") assert result.matched is True assert result.skill_name == "shell_agent" assert result.match_method == "skill_prefix" @@ -76,22 +76,22 @@ class TestSkillPrefix: assert result.execution_mode == ExecutionMode.SKILL_REACT @pytest.mark.asyncio - async def test_skill_prefix_direct_mode(self, router: SimpleRouter): - result = await router.route("@skill:direct_agent 翻译hello") + async def test_skill_prefix_direct_mode(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("@skill:direct_agent 翻译hello") assert result.matched is True assert result.skill_name == "direct_agent" assert result.execution_mode == ExecutionMode.DIRECT_CHAT @pytest.mark.asyncio - async def test_skill_prefix_rewoo_mode(self, router: SimpleRouter): - result = await router.route("@skill:rewoo_agent 重构代码") + async def test_skill_prefix_rewoo_mode(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("@skill:rewoo_agent 重构代码") assert result.matched is True assert result.skill_name == "rewoo_agent" assert result.execution_mode == ExecutionMode.REWOO @pytest.mark.asyncio - async def test_unknown_skill_falls_back_to_react(self, router: SimpleRouter): - result = await router.route("@skill:nonexistent 查询") + async def test_unknown_skill_falls_back_to_react(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("@skill:nonexistent 查询") assert result.matched is False assert result.match_method == "skill_not_found_fallback" assert result.execution_mode == ExecutionMode.REACT @@ -103,30 +103,30 @@ class TestSkillPrefix: class TestDirectChat: @pytest.mark.asyncio - async def test_greeting_cn(self, router: SimpleRouter): - result = await router.route("你好") + async def test_greeting_cn(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("你好") assert result.execution_mode == ExecutionMode.DIRECT_CHAT assert result.match_method == "regex_direct" assert result.tools == [] @pytest.mark.asyncio - async def test_greeting_en(self, router: SimpleRouter): - result = await router.route("hello") + async def test_greeting_en(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("hello") assert result.execution_mode == ExecutionMode.DIRECT_CHAT @pytest.mark.asyncio - async def test_chitchat(self, router: SimpleRouter): - result = await router.route("谢谢") + async def test_chitchat(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("谢谢") assert result.execution_mode == ExecutionMode.DIRECT_CHAT @pytest.mark.asyncio - async def test_identity_question(self, router: SimpleRouter): - result = await router.route("你是谁") + async def test_identity_question(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("你是谁") assert result.execution_mode == ExecutionMode.DIRECT_CHAT @pytest.mark.asyncio - async def test_identity_question_en(self, router: SimpleRouter): - result = await router.route("who are you") + async def test_identity_question_en(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("who are you") assert result.execution_mode == ExecutionMode.DIRECT_CHAT @@ -136,15 +136,15 @@ class TestDirectChat: class TestDefaultReact: @pytest.mark.asyncio - async def test_colloquial_tool_query(self, router: SimpleRouter): + async def test_colloquial_tool_query(self, preprocessor: RequestPreprocessor): """口语化工具查询 — 这是之前路由层误判的核心场景""" - result = await router.route("查下ip") + result = await preprocessor.preprocess("查下ip") assert result.execution_mode == ExecutionMode.REACT assert result.match_method == "default_react" assert len(result.tools) > 0 @pytest.mark.asyncio - async def test_various_colloquial_expressions(self, router: SimpleRouter): + async def test_various_colloquial_expressions(self, preprocessor: RequestPreprocessor): """各种口语化说法都应走 REACT,让 LLM 决定""" queries = [ "查看当前ip", @@ -157,30 +157,30 @@ class TestDefaultReact: "检查服务状态", ] for query in queries: - result = await router.route(query) + result = await preprocessor.preprocess(query) assert result.execution_mode == ExecutionMode.REACT, f"'{query}' should be REACT, got {result.execution_mode}" @pytest.mark.asyncio - async def test_complex_query(self, router: SimpleRouter): - result = await router.route("帮我分析一下这个数据并生成报告") + async def test_complex_query(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("帮我分析一下这个数据并生成报告") assert result.execution_mode == ExecutionMode.REACT @pytest.mark.asyncio - async def test_translation_goes_react(self, router: SimpleRouter): + async def test_translation_goes_react(self, preprocessor: RequestPreprocessor): """翻译类查询也走 REACT — LLM 在 agent loop 中决定不需要工具""" - result = await router.route("翻译hello为中文") + result = await preprocessor.preprocess("翻译hello为中文") assert result.execution_mode == ExecutionMode.REACT # LLM will see tools but decide not to use them @pytest.mark.asyncio - async def test_default_tools_included(self, router: SimpleRouter): - result = await router.route("查下ip") + async def test_default_tools_included(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("查下ip") assert "shell" in result.tools assert "search" in result.tools @pytest.mark.asyncio - async def test_default_system_prompt(self, router: SimpleRouter): - result = await router.route("查下ip") + async def test_default_system_prompt(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("查下ip") assert result.system_prompt == "You are a helpful assistant." @@ -190,31 +190,31 @@ class TestDefaultReact: class TestEdgeCases: @pytest.mark.asyncio - async def test_empty_input(self, router: SimpleRouter): - result = await router.route("") + async def test_empty_input(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess("") assert result.execution_mode == ExecutionMode.REACT @pytest.mark.asyncio - async def test_whitespace_only(self, router: SimpleRouter): - result = await router.route(" ") + async def test_whitespace_only(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess(" ") assert result.execution_mode == ExecutionMode.REACT @pytest.mark.asyncio - async def test_greeting_with_extra_spaces(self, router: SimpleRouter): - result = await router.route(" 你好 ") + async def test_greeting_with_extra_spaces(self, preprocessor: RequestPreprocessor): + result = await preprocessor.preprocess(" 你好 ") assert result.execution_mode == ExecutionMode.DIRECT_CHAT @pytest.mark.asyncio async def test_no_registry(self): - """Router without skill registry should still work for non-skill queries""" - router = SimpleRouter(default_tools=["shell"]) - result = await router.route("查下ip") + """Preprocessor without skill registry should still work for non-skill queries""" + preprocessor = RequestPreprocessor(default_tools=["shell"]) + result = await preprocessor.preprocess("查下ip") assert result.execution_mode == ExecutionMode.REACT @pytest.mark.asyncio - async def test_override_defaults(self, router: SimpleRouter): - """Route-time overrides should work""" - result = await router.route( + async def test_override_defaults(self, preprocessor: RequestPreprocessor): + """Preprocess-time overrides should work""" + result = await preprocessor.preprocess( "查下ip", default_tools=["shell_only"], default_model="gpt-4o", diff --git a/tests/unit/chat/test_skill_routing.py b/tests/unit/chat/test_skill_routing.py index 8303229..3c26753 100644 --- a/tests/unit/chat/test_skill_routing.py +++ b/tests/unit/chat/test_skill_routing.py @@ -1,13 +1,11 @@ -"""Unit tests for CostAwareRouter team upgrade logic and HeuristicClassifier.""" +"""Unit tests for ExpertTeamRouter and skill routing utilities.""" from __future__ import annotations from unittest.mock import MagicMock from agentkit.chat.skill_routing import ( - CostAwareRouter, ExecutionMode, - HeuristicClassifier, SkillRoutingResult, ) from agentkit.experts.config import ExpertConfig, ExpertTemplate @@ -20,16 +18,6 @@ from agentkit.experts.router import ExpertTeamRouter # --------------------------------------------------------------------------- -def _make_router(expert_team_router: ExpertTeamRouter | None = None) -> CostAwareRouter: - """Create a CostAwareRouter with mocked dependencies.""" - return CostAwareRouter( - llm_gateway=None, - model="test", - classifier="heuristic", - expert_team_router=expert_team_router, - ) - - def _make_team_router_with_templates() -> ExpertTeamRouter: """Create an ExpertTeamRouter with sample templates.""" registry = ExpertTemplateRegistry() @@ -82,251 +70,51 @@ class TestExpertTeamRouterCanHandle: # --------------------------------------------------------------------------- -# Tests: _try_team_upgrade() -# --------------------------------------------------------------------------- - - -class TestTryTeamUpgrade: - def test_upgrade_react_to_team_collab(self) -> None: - router = _make_router(expert_team_router=_make_team_router_with_templates()) - result = SkillRoutingResult( - clean_content="complex multi-step analysis task", - matched=True, - match_method="capability", - match_confidence=0.8, - complexity=0.8, - execution_mode=ExecutionMode.REACT, - ) - trace: list[dict] = [] - upgraded = router._try_team_upgrade(result, "complex multi-step analysis task", 0.8, trace) - assert upgraded.execution_mode == ExecutionMode.TEAM_COLLAB - assert any(t.get("method") == "team_upgrade" for t in trace) - - def test_no_upgrade_low_complexity(self) -> None: - router = _make_router(expert_team_router=_make_team_router_with_templates()) - result = SkillRoutingResult( - clean_content="simple question", - matched=True, - match_method="capability", - match_confidence=0.8, - complexity=0.3, - execution_mode=ExecutionMode.REACT, - ) - trace: list[dict] = [] - upgraded = router._try_team_upgrade(result, "simple question", 0.3, trace) - assert upgraded.execution_mode == ExecutionMode.REACT - assert not any(t.get("method") == "team_upgrade" for t in trace) - - def test_no_upgrade_no_team_router(self) -> None: - router = _make_router(expert_team_router=None) - result = SkillRoutingResult( - clean_content="complex analysis", - matched=True, - match_method="capability", - match_confidence=0.8, - complexity=0.9, - execution_mode=ExecutionMode.REACT, - ) - trace: list[dict] = [] - upgraded = router._try_team_upgrade(result, "complex analysis", 0.9, trace) - assert upgraded.execution_mode == ExecutionMode.REACT - - def test_no_upgrade_empty_templates(self) -> None: - router = _make_router(expert_team_router=_make_team_router_empty()) - result = SkillRoutingResult( - clean_content="complex analysis", - matched=True, - match_method="capability", - match_confidence=0.8, - complexity=0.8, - execution_mode=ExecutionMode.REACT, - ) - trace: list[dict] = [] - upgraded = router._try_team_upgrade(result, "complex analysis", 0.8, trace) - assert upgraded.execution_mode == ExecutionMode.REACT - - def test_no_upgrade_direct_chat_mode(self) -> None: - router = _make_router(expert_team_router=_make_team_router_with_templates()) - result = SkillRoutingResult( - clean_content="hello", - matched=False, - match_method="greeting", - match_confidence=1.0, - complexity=0.0, - execution_mode=ExecutionMode.DIRECT_CHAT, - ) - trace: list[dict] = [] - upgraded = router._try_team_upgrade(result, "hello", 0.0, trace) - assert upgraded.execution_mode == ExecutionMode.DIRECT_CHAT - - def test_team_upgrade_exception_handled(self) -> None: - """When ExpertTeamRouter raises, the upgrade is silently skipped.""" - broken_router = MagicMock() - broken_router.can_handle.side_effect = RuntimeError("boom") - router = _make_router(expert_team_router=broken_router) - result = SkillRoutingResult( - clean_content="complex task", - matched=True, - match_method="capability", - match_confidence=0.8, - complexity=0.8, - execution_mode=ExecutionMode.REACT, - ) - trace: list[dict] = [] - upgraded = router._try_team_upgrade(result, "complex task", 0.8, trace) - assert upgraded.execution_mode == ExecutionMode.REACT - - -# --------------------------------------------------------------------------- -# Tests: ExpertTeamRouter.resolve() with complexity +# Tests: ExpertTeamRouter.resolve() # --------------------------------------------------------------------------- class TestExpertTeamRouterResolve: def test_explicit_team_prefix(self) -> None: router = _make_team_router_with_templates() - result = router.resolve("@team:analyst,strategist analyze the market", 0.5) + result = router.resolve("@team:analyst,strategist analyze the market") assert result.team_mode is True assert result.match_method == "explicit_team" assert "analyst" in result.specified_experts assert "strategist" in result.specified_experts - def test_complexity_suggestion(self) -> None: + def test_no_team_without_prefix(self) -> None: router = _make_team_router_with_templates() - result = router.resolve("complex multi-step analysis", 0.8) - assert result.team_mode is True - assert result.match_method == "complexity_suggestion" - assert result.auto_compose is True - - def test_no_team_low_complexity(self) -> None: - router = _make_team_router_with_templates() - result = router.resolve("simple question", 0.2) + result = router.resolve("simple question") assert result.team_mode is False # --------------------------------------------------------------------------- -# Tests: HeuristicClassifier complexity calibration +# Tests: SkillRoutingResult data structure # --------------------------------------------------------------------------- -class TestHeuristicClassifierLowComplexity: - """Low-complexity signals should produce scores < 0.3.""" +class TestSkillRoutingResult: + def test_default_execution_mode(self) -> None: + result = SkillRoutingResult( + clean_content="test", + matched=False, + match_method="default_react", + match_confidence=0.8, + agent_name="default", + model="default", + execution_mode=ExecutionMode.REACT, + ) + assert result.execution_mode == ExecutionMode.REACT - def setup_method(self) -> None: - self.clf = HeuristicClassifier() - - def test_chinese_greeting(self) -> None: - assert self.clf.classify("你好") < 0.3 - - def test_chinese_greeting_hi(self) -> None: - assert self.clf.classify("嗨") < 0.3 - - def test_english_greeting_hello(self) -> None: - assert self.clf.classify("Hello") < 0.3 - - def test_english_greeting_hi(self) -> None: - assert self.clf.classify("hi") < 0.3 - - def test_multiple_low_complexity_words(self) -> None: - assert self.clf.classify("嗨,早上好") < 0.3 - - def test_greeting_with_high_complexity_word_not_suppressed(self) -> None: - """Low-complexity signal should NOT override high-complexity signal.""" - # "你好" is low, but "分析" is high → should score high - assert self.clf.classify("你好,请帮我分析一下这个数据") > 0.5 - - -class TestHeuristicClassifierIdentity: - """Identity queries should produce scores < 0.3.""" - - def setup_method(self) -> None: - self.clf = HeuristicClassifier() - - def test_who_are_you_cn(self) -> None: - assert self.clf.classify("你是谁") < 0.3 - - def test_what_is_your_name_cn(self) -> None: - assert self.clf.classify("你叫什么") < 0.3 - - -class TestHeuristicClassifierNegation: - """Negated high-complexity words should not contribute to score.""" - - def setup_method(self) -> None: - self.clf = HeuristicClassifier() - - def test_negate_search_cn(self) -> None: - assert self.clf.classify("不要搜索") < 0.3 - - def test_negate_analyze_cn(self) -> None: - assert self.clf.classify("无需分析,直接告诉我答案") < 0.3 - - def test_partial_negation_still_high(self) -> None: - """'搜索' negated but '分析' not — should still be high.""" - assert self.clf.classify("分析市场趋势,但不要搜索") > 0.5 - - -class TestHeuristicClassifierThresholds: - """Verify adjusted base scores.""" - - def setup_method(self) -> None: - self.clf = HeuristicClassifier() - - def test_no_keyword_short_message(self) -> None: - assert self.clf.classify("好的") <= 0.10 - - def test_medium_complexity_base(self) -> None: - """Medium complexity keyword should start at 0.35 (not 0.45).""" - score = self.clf.classify("如何使用Python?") - # '如何' is medium → base 0.35, '?' short question → -0.10 = 0.25 - # but 'Python' is not in high/medium lists, so just medium base - assert 0.25 <= score <= 0.45 - - -class TestHeuristicClassifierShortQuestion: - """Short questions ending with ?/? should get deduction.""" - - def setup_method(self) -> None: - self.clf = HeuristicClassifier() - - def test_short_question_deduction(self) -> None: - assert self.clf.classify("怎么用?") < 0.3 - - def test_long_question_no_deduction(self) -> None: - assert self.clf.classify("如何设计一个高可用的微服务架构?") > 0.5 - - -class TestHeuristicClassifierHighComplexity: - """Complex tasks should produce scores > 0.7.""" - - def setup_method(self) -> None: - self.clf = HeuristicClassifier() - - def test_two_high_complexity_words(self) -> None: - # "分析" + "搜索" are both in _HIGH_COMPLEXITY_HINTS_CN → base 0.80 - assert self.clf.classify("分析市场数据并搜索相关信息") > 0.7 - - def test_single_high_complexity_word(self) -> None: - # "分析" alone → base 0.65 - assert self.clf.classify("分析市场趋势并生成报告") > 0.6 - - def test_execute_and_restart(self) -> None: - assert self.clf.classify("执行部署脚本并重启服务") > 0.7 - - -class TestHeuristicClassifierEdgeCases: - """Boundary conditions.""" - - def setup_method(self) -> None: - self.clf = HeuristicClassifier() - - def test_empty_string(self) -> None: - assert self.clf.classify("") == 0.0 - - def test_whitespace_only(self) -> None: - assert self.clf.classify(" ") == 0.0 - - def test_long_low_complexity_message(self) -> None: - """Even a long greeting should stay low.""" - long_greeting = "你好" * 100 # >200 chars - assert self.clf.classify(long_greeting) <= 0.15 + def test_direct_chat_mode(self) -> None: + result = SkillRoutingResult( + clean_content="hello", + matched=False, + match_method="regex_direct", + match_confidence=1.0, + agent_name="default", + model="default", + execution_mode=ExecutionMode.DIRECT_CHAT, + ) + assert result.execution_mode == ExecutionMode.DIRECT_CHAT diff --git a/tests/unit/test_cost_aware_router.py b/tests/unit/test_cost_aware_router.py deleted file mode 100644 index e2220cc..0000000 --- a/tests/unit/test_cost_aware_router.py +++ /dev/null @@ -1,808 +0,0 @@ -"""CostAwareRouter 单元测试 - 三层成本感知路由""" - -import json -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from agentkit.chat.skill_routing import CostAwareRouter, HeuristicClassifier, SkillRoutingResult, _tokenize_content -from agentkit.llm.protocol import LLMResponse, TokenUsage -from agentkit.router.intent import IntentRouter, RoutingResult -from agentkit.skills.base import IntentConfig, Skill, SkillConfig - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _make_skill( - name: str, - keywords: list[str] | None = None, - description: str = "", - examples: list[str] | None = None, -) -> Skill: - """快速构造一个带 intent 配置的 Skill""" - config = SkillConfig( - name=name, - agent_type="test", - task_mode="llm_generate", - prompt={"system": f"You are a {name} skill."}, - intent={ - "keywords": keywords or [], - "description": description, - "examples": examples or [], - }, - ) - return Skill(config=config) - - -def _make_llm_gateway(response_content: str) -> MagicMock: - """构造一个 mock LLMGateway,chat 返回指定 content""" - gateway = MagicMock() - gateway.chat = AsyncMock( - return_value=LLMResponse( - content=response_content, - model="test-model", - usage=TokenUsage(prompt_tokens=10, completion_tokens=20), - ) - ) - return gateway - - -def _make_skill_registry(skills: list[Skill] | None = None) -> MagicMock: - """构造一个 mock SkillRegistry""" - registry = MagicMock() - _skills = skills or [] - registry.list_skills.return_value = _skills - - def _get(name: str): - for s in _skills: - if s.name == name: - return s - raise KeyError(f"Skill '{name}' not found") - - registry.get = MagicMock(side_effect=_get) - return registry - - -def _make_intent_router() -> IntentRouter: - """构造一个无 LLM 的 IntentRouter(仅关键词匹配)""" - return IntentRouter(llm_gateway=None, model="default") - - -# --------------------------------------------------------------------------- -# Layer 0: Rule-based (zero cost) -# --------------------------------------------------------------------------- - - -class TestLayer0Greeting: - """Layer 0: 问候模式匹配""" - - @pytest.mark.asyncio - async def test_chinese_greeting_hits_layer0(self): - """'你好' 命中 Layer 0 问候规则,零 token 成本""" - router = CostAwareRouter() - result = await router.route( - content="你好", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.match_method == "greeting" - assert result.complexity == 0.0 - assert result.agent_name == "default" - assert result.matched is False - - @pytest.mark.asyncio - async def test_english_greeting_hits_layer0(self): - """'hello' 命中 Layer 0 问候规则""" - router = CostAwareRouter() - result = await router.route( - content="hello", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.match_method == "greeting" - assert result.complexity == 0.0 - - @pytest.mark.asyncio - async def test_greeting_with_punctuation(self): - """'你好!' 带标点也命中 Layer 0""" - router = CostAwareRouter() - result = await router.route( - content="你好!", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.match_method == "greeting" - - -class TestLayer0ChatMode: - """Layer 0: 简单对话模式""" - - @pytest.mark.asyncio - async def test_thanks_hits_chat_mode(self): - """'谢谢' 命中 Layer 0 简单对话模式""" - router = CostAwareRouter() - result = await router.route( - content="谢谢", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.match_method == "chat_mode" - assert result.complexity == 0.0 - - @pytest.mark.asyncio - async def test_ok_hits_chat_mode(self): - """'好的' 命中 Layer 0 简单对话模式""" - router = CostAwareRouter() - result = await router.route( - content="好的", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.match_method == "chat_mode" - - -class TestLayer0ExplicitSkill: - """Layer 0: @skill: 显式前缀""" - - @pytest.mark.asyncio - async def test_skill_prefix_hits_layer0(self): - """'@skill:search 搜索XX' 命中 Layer 0 显式 Skill 规则,零 token 成本""" - search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息") - registry = _make_skill_registry([search_skill]) - # 需要 IntentRouter 支持 LLM fallback - gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9})) - intent_router = IntentRouter(llm_gateway=gateway, model="default") - - router = CostAwareRouter() - result = await router.route( - content="@skill:search 搜索XX", - skill_registry=registry, - intent_router=intent_router, - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.matched is True - assert result.skill_name == "search" - assert result.complexity == 0.0 - - -# --------------------------------------------------------------------------- -# Layer 1: LLM quick classify (~100 tokens) -# --------------------------------------------------------------------------- - - -class TestLayer1Classification: - """Layer 1: LLM 快速分类""" - - @pytest.mark.asyncio - async def test_medium_complexity_routes_via_intent_router(self): - """'分析下这个数据' 经过 Layer 1 LLM 分类,中等复杂度走 IntentRouter""" - # LLM 返回中等复杂度 - gateway = _make_llm_gateway(json.dumps({"complexity": 0.5})) - search_skill = _make_skill("search", keywords=["分析"], description="数据分析") - registry = _make_skill_registry([search_skill]) - - # IntentRouter 也需要 LLM - intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9})) - intent_router = IntentRouter(llm_gateway=intent_gateway, model="default") - - router = CostAwareRouter(llm_gateway=gateway, model="default") - result = await router.route( - content="分析下这个数据", - skill_registry=registry, - intent_router=intent_router, - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert 0.3 <= result.complexity <= 0.7 - - @pytest.mark.asyncio - async def test_low_complexity_routes_to_default(self): - """低复杂度 (<0.3) 路由到默认 Agent""" - gateway = _make_llm_gateway(json.dumps({"complexity": 0.1})) - router = CostAwareRouter(llm_gateway=gateway, model="default") - result = await router.route( - content="随便聊聊", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.complexity < 0.3 - assert result.match_method == "low_complexity" - assert result.agent_name == "default" - - @pytest.mark.asyncio - async def test_no_llm_gateway_defaults_to_medium(self): - """无 LLM Gateway 时 quick_classify 返回 0.5(中等复杂度)""" - router = CostAwareRouter(llm_gateway=None) - complexity = await router.quick_classify("分析下这个数据") - assert complexity == 0.5 - - @pytest.mark.asyncio - async def test_llm_malformed_response_defaults_to_medium(self): - """LLM 返回非 JSON 时 quick_classify 返回 0.5""" - gateway = _make_llm_gateway("这不是JSON") - router = CostAwareRouter(llm_gateway=gateway, model="default") - complexity = await router.quick_classify("分析下这个数据") - assert complexity == 0.5 - - @pytest.mark.asyncio - async def test_complexity_clamped_to_0_1(self): - """复杂度值被限制在 [0.0, 1.0] 范围""" - gateway = _make_llm_gateway(json.dumps({"complexity": 1.5})) - router = CostAwareRouter(llm_gateway=gateway, model="default") - complexity = await router.quick_classify("超级复杂任务") - assert complexity == 1.0 - - gateway2 = _make_llm_gateway(json.dumps({"complexity": -0.5})) - router2 = CostAwareRouter(llm_gateway=gateway2, model="default") - complexity2 = await router2.quick_classify("简单任务") - assert complexity2 == 0.0 - - -# --------------------------------------------------------------------------- -# Layer 2: Capability matching / Auction -# --------------------------------------------------------------------------- - - -class TestLayer2CapabilityMatching: - """Layer 2: 能力匹配 / 拍卖""" - - @pytest.mark.asyncio - async def test_high_complexity_triggers_capability_matching(self): - """'做市场调研+竞品分析' 复杂度 > 0.7,触发能力匹配""" - gateway = _make_llm_gateway(json.dumps({"complexity": 0.85})) - org_context = MagicMock() - org_context.find_best_agent = MagicMock(return_value="market-researcher") - - router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context) - result = await router.route( - content="做市场调研+竞品分析", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.complexity > 0.7 - assert result.match_method == "capability" - assert result.agent_name == "market-researcher" - assert result.matched is True - - @pytest.mark.asyncio - async def test_layer2_with_org_context_object(self): - """org_context.find_best_agent 返回对象时提取 name 属性""" - gateway = _make_llm_gateway(json.dumps({"complexity": 0.9})) - agent_obj = MagicMock() - agent_obj.name = "analyst-agent" - org_context = MagicMock() - org_context.find_best_agent = MagicMock(return_value=agent_obj) - - router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context) - result = await router.route( - content="做市场调研+竞品分析", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.agent_name == "analyst-agent" - - @pytest.mark.asyncio - async def test_layer2_without_org_context_falls_back_to_intent_router(self): - """无 org_context 时 Layer 2 回退到 IntentRouter""" - gateway = _make_llm_gateway(json.dumps({"complexity": 0.8})) - search_skill = _make_skill("search", keywords=["调研"], description="市场调研") - registry = _make_skill_registry([search_skill]) - - intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9})) - intent_router = IntentRouter(llm_gateway=intent_gateway, model="default") - - router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=None) - result = await router.route( - content="做市场调研+竞品分析", - skill_registry=registry, - intent_router=intent_router, - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.complexity > 0.7 - # 回退到 IntentRouter,可能匹配到 skill 或走 default - assert result.match_method in ("capability", "keyword", "llm", "intent_router_fallback", None) - - @pytest.mark.asyncio - async def test_layer2_org_context_find_best_agent_returns_none(self): - """org_context.find_best_agent 返回 None 时回退到 IntentRouter""" - gateway = _make_llm_gateway(json.dumps({"complexity": 0.8})) - org_context = MagicMock() - org_context.find_best_agent = MagicMock(return_value=None) - - search_skill = _make_skill("search", keywords=["调研"], description="市场调研") - registry = _make_skill_registry([search_skill]) - intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9})) - intent_router = IntentRouter(llm_gateway=intent_gateway, model="default") - - router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context) - result = await router.route( - content="做市场调研+竞品分析", - skill_registry=registry, - intent_router=intent_router, - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.complexity > 0.7 - - @pytest.mark.asyncio - async def test_auction_disabled_by_default(self): - """拍卖模式默认禁用""" - router = CostAwareRouter() - assert router._auction_enabled is False - - @pytest.mark.asyncio - async def test_auction_can_be_enabled(self): - """拍卖模式可手动启用""" - router = CostAwareRouter(auction_enabled=True) - assert router._auction_enabled is True - - -# --------------------------------------------------------------------------- -# Transparency -# --------------------------------------------------------------------------- - - -class TestTransparency: - """透明度级别切换""" - - @pytest.mark.asyncio - async def test_silent_mode_no_trace(self): - """SILENT 模式不暴露路由追踪""" - router = CostAwareRouter() - result = await router.route( - content="你好", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - transparency="SILENT", - ) - assert result.execution_trace == [] - assert result.transparency_level == "SILENT" - - @pytest.mark.asyncio - async def test_verbose_mode_shows_trace(self): - """VERBOSE 模式显示路由追踪""" - router = CostAwareRouter() - result = await router.route( - content="你好", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - transparency="VERBOSE", - ) - assert len(result.execution_trace) > 0 - assert result.execution_trace[0]["layer"] == 0 - assert result.execution_trace[0]["method"] == "greeting" - assert result.transparency_level == "VERBOSE" - - @pytest.mark.asyncio - async def test_trace_mode_shows_full_trace(self): - """TRACE 模式显示完整路由追踪""" - gateway = _make_llm_gateway(json.dumps({"complexity": 0.85})) - org_context = MagicMock() - org_context.find_best_agent = MagicMock(return_value="analyst") - - router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context) - result = await router.route( - content="做市场调研+竞品分析", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - transparency="TRACE", - ) - assert len(result.execution_trace) > 0 - # 应包含 Layer 1 quick_classify 和 Layer 2 的记录 - layers = [t["layer"] for t in result.execution_trace] - assert 1 in layers # Layer 1 quick_classify - assert 2 in layers # Layer 2 capability matching - assert result.transparency_level == "TRACE" - - @pytest.mark.asyncio - async def test_default_transparency_is_silent(self): - """默认透明度为 SILENT""" - router = CostAwareRouter() - result = await router.route( - content="你好", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.transparency_level == "SILENT" - assert result.execution_trace == [] - - -# --------------------------------------------------------------------------- -# SkillRoutingResult 新字段 -# --------------------------------------------------------------------------- - - -class TestSkillRoutingResultNewFields: - """SkillRoutingResult 新字段验证""" - - def test_default_transparency_level(self): - result = SkillRoutingResult() - assert result.transparency_level == "SILENT" - - def test_default_execution_trace(self): - result = SkillRoutingResult() - assert result.execution_trace == [] - - def test_default_complexity(self): - result = SkillRoutingResult() - assert result.complexity == 0.0 - - def test_new_fields_backward_compatible(self): - """新字段不影响旧代码创建 SkillRoutingResult""" - result = SkillRoutingResult( - skill_name="test", - matched=True, - match_method="keyword", - ) - assert result.transparency_level == "SILENT" - assert result.execution_trace == [] - assert result.complexity == 0.0 - - -# --------------------------------------------------------------------------- -# _tokenize_content: 中文分词增强 -# --------------------------------------------------------------------------- - - -class TestTokenizeContent: - """_tokenize_content 中文分词增强测试""" - - def test_chinese_content(self): - """中文内容:'帮我做数据分析' 应包含 '数据分析' 相关 2-gram""" - tokens = _tokenize_content("帮我做数据分析") - # 整段无标点分隔,生成 2-gram:帮我、我做、做数、数据、据分、分析 - assert "数据" in tokens or "数据分析" in tokens - - def test_english_content(self): - """英文内容:'help with code generation' 应包含 'code', 'generation' 或 'code generation'""" - tokens = _tokenize_content("help with code generation") - assert "code" in tokens or "generation" in tokens or "code generation" in tokens - - def test_mixed_content(self): - """中英混合:'用python做data analysis' 应包含 'python' 相关 token 和 'data analysis'""" - tokens = _tokenize_content("用python做data analysis") - # 按空格分割后 "用python做data" 作为一个 segment,生成 2-gram - # "analysis" 作为独立 segment - assert "analysis" in tokens - # "用python做data" 长度 > 4,会生成 2-gram,其中包含 python 相关片段 - has_python_related = any("python" in t for t in tokens) - assert has_python_related or "data analysis" in tokens - - def test_stopwords_filtered(self): - """停用词过滤:纯停用词短句过滤后应为空或极少 token""" - tokens = _tokenize_content("我的一个") - # "我的一个" 长度 4,作为整体保留(不在停用词集合中) - # 但停用词 "我的" "的一" "一个" 等 2-gram 会被过滤 - assert len(tokens) <= 1 - - def test_bigram_generation(self): - """2-gram 生成:'机器学习模型训练' 应包含各 2-gram""" - tokens = _tokenize_content("机器学习模型训练") - expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"] - for bigram in expected_bigrams: - assert bigram in tokens, f"缺少 2-gram: {bigram}" - - -# --------------------------------------------------------------------------- -# HeuristicClassifier -# --------------------------------------------------------------------------- - - -class TestHeuristicClassifier: - """HeuristicClassifier 本地启发式分类器测试""" - - def setup_method(self): - self.classifier = HeuristicClassifier() - - def test_short_greeting_low_complexity(self): - """短问候语 → 低复杂度""" - score = self.classifier.classify("你好呀") - assert score < 0.3 - - def test_simple_question_medium_complexity(self): - """含'如何'的简单问题 → 中等复杂度""" - score = self.classifier.classify("如何使用这个功能?") - assert 0.3 <= score <= 0.7 - - def test_tool_request_high_complexity(self): - """含工具关键词的请求 → 高复杂度""" - score = self.classifier.classify("帮我搜索一下最新的新闻") - assert score > 0.5 - - def test_code_request_high_complexity(self): - """代码相关请求 → 高复杂度""" - score = self.classifier.classify("写一个Python函数实现快速排序") - assert score > 0.6 - - def test_multi_step_request_high_complexity(self): - """多步分析请求 → 高复杂度""" - score = self.classifier.classify("分析这个数据,比较不同方案的优缺点,然后给出推荐") - assert score > 0.7 - - def test_empty_string_zero_complexity(self): - """空字符串 → 零复杂度""" - assert self.classifier.classify("") == 0.0 - assert self.classifier.classify(" ") == 0.0 - - def test_long_message_higher_complexity(self): - """长消息 → 更高复杂度""" - short = "帮我查一下" - long = "帮我查一下" + "关于机器学习和深度学习的最新进展" * 10 - assert self.classifier.classify(long) > self.classifier.classify(short) - - def test_code_patterns_boost_complexity(self): - """代码模式(反引号/括号)提升复杂度""" - with_code = "运行这段代码 `print('hello')`" - without_code = "运行这段代码 print hello" - assert self.classifier.classify(with_code) > self.classifier.classify(without_code) - - def test_score_bounded_0_to_1(self): - """复杂度值始终在 [0.0, 1.0] 范围""" - test_inputs = [ - "", "你好", "如何做", "帮我搜索并分析数据,设计一个完整的解决方案,包含代码实现和部署配置", - ] - for inp in test_inputs: - score = self.classifier.classify(inp) - assert 0.0 <= score <= 1.0, f"Score {score} out of range for '{inp}'" - - -class TestHeuristicClassifierIntegration: - """HeuristicClassifier 在 CostAwareRouter 中的集成测试""" - - @pytest.mark.asyncio - async def test_heuristic_mode_no_llm_call(self): - """heuristic 模式 + merged_llm_classify=False 时不调用 LLM""" - gateway = _make_llm_gateway(json.dumps({"complexity": 0.5})) - router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic", merged_llm_classify=False) - result = await router.route( - content="帮我分析一下数据", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - # LLM gateway.chat 不应被调用(heuristic + merged disabled) - gateway.chat.assert_not_called() - # 复杂度应来自启发式分类器 - assert result.complexity > 0.0 - - @pytest.mark.asyncio - async def test_llm_mode_uses_llm(self): - """llm 模式下调用 LLM quick_classify""" - gateway = _make_llm_gateway(json.dumps({"complexity": 0.5})) - router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="llm") - result = await router.route( - content="帮我分析一下数据", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - # LLM gateway.chat 应被调用 - gateway.chat.assert_called() - - @pytest.mark.asyncio - async def test_heuristic_greeting_still_layer0(self): - """heuristic 模式下问候仍走 Layer 0""" - router = CostAwareRouter(classifier="heuristic") - result = await router.route( - content="你好", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.match_method == "greeting" - assert result.complexity == 0.0 - - @pytest.mark.asyncio - async def test_heuristic_default_classifier_mode(self): - """默认分类器模式为 heuristic""" - router = CostAwareRouter() - assert router._classifier == "heuristic" - - -# --------------------------------------------------------------------------- -# U1: Merged LLM Classify -# --------------------------------------------------------------------------- - - -class TestMergedLLMClassify: - """合并路由 LLM 调用测试""" - - @pytest.mark.asyncio - async def test_merged_classify_returns_valid_skill(self): - """合并调用返回有效 JSON + skill_hint,正确路由到指定 skill""" - merged_response = json.dumps({ - "complexity": 0.6, - "intent": "code_generation", - "skill_hint": "search", - }) - gateway = _make_llm_gateway(merged_response) - search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息") - registry = _make_skill_registry([search_skill]) - - router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True) - result = await router.route( - content="帮我搜索一下最新的新闻", - skill_registry=registry, - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.matched is True - assert result.skill_name == "search" - assert result.match_method == "merged_llm" - assert result.complexity > 0.3 - - @pytest.mark.asyncio - async def test_merged_classify_malformed_response_fallback(self): - """合并调用返回格式异常,fallback 到默认 Agent""" - gateway = _make_llm_gateway("这不是JSON") - router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True) - result = await router.route( - content="帮我分析一下数据", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.match_method == "merged_llm_fallback" - assert result.complexity == 0.5 - - @pytest.mark.asyncio - async def test_merged_classify_low_complexity(self): - """合并调用返回 complexity < 0.3,走低复杂度路由""" - merged_response = json.dumps({ - "complexity": 0.2, - "intent": "greeting", - "skill_hint": None, - }) - gateway = _make_llm_gateway(merged_response) - router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True) - result = await router.route( - content="如何使用这个功能?", # heuristic returns ~0.45, triggers merged call - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - # Merged LLM returned complexity < 0.3, should route to low complexity - assert result.complexity < 0.3 - assert "low" in result.match_method or "merged_llm_low" in result.match_method - - @pytest.mark.asyncio - async def test_merged_classify_high_complexity(self): - """合并调用返回 complexity > 0.7,走 Layer 2""" - merged_response = json.dumps({ - "complexity": 0.85, - "intent": "research", - "skill_hint": None, - }) - gateway = _make_llm_gateway(merged_response) - org_context = MagicMock() - org_context.find_best_agent = MagicMock(return_value="researcher") - - router = CostAwareRouter( - llm_gateway=gateway, model="default", - org_context=org_context, merged_llm_classify=True, - ) - result = await router.route( - content="做市场调研+竞品分析", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - assert result.complexity > 0.7 - assert result.match_method == "capability" - - @pytest.mark.asyncio - async def test_merged_classify_disabled_falls_back_to_intent_router(self): - """配置 merged_llm_classify=False 时回退到独立 IntentRouter""" - gateway = _make_llm_gateway(json.dumps({"complexity": 0.5})) - search_skill = _make_skill("search", keywords=["分析"], description="数据分析") - registry = _make_skill_registry([search_skill]) - intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9})) - intent_router = IntentRouter(llm_gateway=intent_gateway, model="default") - - router = CostAwareRouter( - llm_gateway=gateway, model="default", - merged_llm_classify=False, - ) - result = await router.route( - content="分析下这个数据", - skill_registry=registry, - intent_router=intent_router, - default_tools=[], - default_system_prompt="You are helpful.", - ) - # Should not use merged_llm match_method - assert result.match_method != "merged_llm" - - @pytest.mark.asyncio - async def test_merged_classify_no_llm_gateway_falls_back(self): - """无 LLM Gateway 时 _classify_merged 回退到 IntentRouter""" - search_skill = _make_skill("search", keywords=["分析"], description="数据分析") - registry = _make_skill_registry([search_skill]) - - router = CostAwareRouter(llm_gateway=None, merged_llm_classify=True) - result = await router.route( - content="分析下这个数据", - skill_registry=registry, - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - # Should not crash, should use IntentRouter fallback - assert result is not None - - @pytest.mark.asyncio - async def test_merged_classify_skill_hint_not_found_fallback(self): - """合并调用返回的 skill_hint 在 registry 中不存在,fallback""" - merged_response = json.dumps({ - "complexity": 0.5, - "intent": "unknown", - "skill_hint": "nonexistent_skill", - }) - gateway = _make_llm_gateway(merged_response) - router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True) - result = await router.route( - content="帮我分析一下数据", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - # Should fallback to default agent (medium complexity, no skill match) - assert result.matched is False - assert result.match_method == "merged_llm_medium" - - @pytest.mark.asyncio - async def test_merged_classify_only_one_llm_call(self): - """合并调用模式下,中等复杂度只产生 1 次 LLM 调用""" - merged_response = json.dumps({ - "complexity": 0.5, - "intent": "question", - "skill_hint": None, - }) - gateway = _make_llm_gateway(merged_response) - router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True) - await router.route( - content="如何使用这个功能?", - skill_registry=_make_skill_registry(), - intent_router=_make_intent_router(), - default_tools=[], - default_system_prompt="You are helpful.", - ) - # Only 1 LLM call should have been made (the merged classify) - assert gateway.chat.call_count == 1 diff --git a/tests/unit/test_semantic_router.py b/tests/unit/test_semantic_router.py deleted file mode 100644 index e1b4589..0000000 --- a/tests/unit/test_semantic_router.py +++ /dev/null @@ -1,219 +0,0 @@ -"""Unit tests for Semantic Router (U3).""" - -import pytest - -from agentkit.chat.semantic_router import ( - SemanticRouteResult, - SkillEmbeddingIndex, - SemanticRouter, -) -from agentkit.memory.embedder import MockEmbedder - - -def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]: - """Create a unit vector for similarity testing.""" - vec = [base_val] * dim - magnitude = sum(x**2 for x in vec) ** 0.5 - return [x / magnitude for x in vec] if magnitude > 0 else vec - - -class MockSkill: - """Mock skill for testing.""" - - def __init__(self, name: str, description: str = "", keywords: list[str] | None = None, capabilities: list[str] | None = None): - self.name = name - self.config = MockSkillConfig( - name=name, - description=description, - keywords=keywords or [], - capabilities=capabilities or [], - ) - - -class MockSkillConfig: - """Mock skill config for testing.""" - - def __init__(self, name: str, description: str = "", keywords: list[str] | None = None, capabilities: list[str] | None = None): - self.name = name - self.description = description - self.intent = MockIntentConfig(keywords=keywords or []) - self.capabilities = [MockCapabilityTag(tag=t) for t in (capabilities or [])] - - -class MockIntentConfig: - def __init__(self, keywords: list[str] | None = None): - self.keywords = keywords or [] - - -class MockCapabilityTag: - def __init__(self, tag: str): - self.tag = tag - - -class MockSkillRegistry: - """Mock skill registry for testing.""" - - def __init__(self, skills: list[MockSkill] | None = None): - self._skills = {s.name: s for s in (skills or [])} - - def list_skills(self): - return list(self._skills.values()) - - def get(self, name: str): - if name not in self._skills: - raise KeyError(f"Skill '{name}' not found") - return self._skills[name] - - -class TestSkillEmbeddingIndex: - @pytest.mark.asyncio - async def test_build_from_registry(self): - embedder = MockEmbedder(dimension=64) - index = SkillEmbeddingIndex(embedder) - - skills = [ - MockSkill("content_gen", description="生成文章内容", keywords=["写作", "文章"], capabilities=["content"]), - MockSkill("data_analysis", description="数据分析与可视化", keywords=["分析", "数据"], capabilities=["analytics"]), - ] - registry = MockSkillRegistry(skills) - await index.build(registry) - - assert index.size == 2 - - @pytest.mark.asyncio - async def test_search_returns_results(self): - embedder = MockEmbedder(dimension=64) - index = SkillEmbeddingIndex(embedder) - - skill = MockSkill("content_gen", description="生成文章内容") - await index.update_skill("content_gen", skill) - - # MockEmbedder produces deterministic embeddings based on text hash - # Different text → different embedding - query_emb = await embedder.embed("生成文章") - results = await index.search(query_emb) - - assert len(results) >= 1 - assert results[0][0] == "content_gen" # skill_name - assert results[0][1] > 0.0 # similarity - - @pytest.mark.asyncio - async def test_search_empty_index(self): - embedder = MockEmbedder(dimension=64) - index = SkillEmbeddingIndex(embedder) - - query_emb = await embedder.embed("test") - results = await index.search(query_emb) - - assert results == [] - - @pytest.mark.asyncio - async def test_remove_skill(self): - embedder = MockEmbedder(dimension=64) - index = SkillEmbeddingIndex(embedder) - - skill = MockSkill("test_skill", description="Test") - await index.update_skill("test_skill", skill) - assert index.size == 1 - - index.remove_skill("test_skill") - assert index.size == 0 - - def test_build_source_text_with_description(self): - skill = MockSkill("test", description="A test skill", keywords=["test"], capabilities=["testing"]) - text = SkillEmbeddingIndex._build_source_text(skill) - assert "A test skill" in text - assert "test" in text - assert "testing" in text - - def test_build_source_text_fallback_to_name(self): - skill = MockSkill("my_skill", description="", keywords=[], capabilities=[]) - text = SkillEmbeddingIndex._build_source_text(skill) - assert "my_skill" in text - - -class TestSemanticRouter: - @pytest.mark.asyncio - async def test_high_confidence_match(self): - """When similarity > 0.85, return high confidence.""" - embedder = MockEmbedder(dimension=64) - router = SemanticRouter(embedder, similarity_high=0.5, similarity_low=0.3) - - # Add a skill with known embedding - skill = MockSkill("content_gen", description="生成文章内容") - await router.update_skill("content_gen", skill) - - # Query with same text should produce very similar embedding (MockEmbedder is hash-based) - # With low thresholds, even moderate similarity will be "high" - result = await router.route("生成文章内容") - # MockEmbedder may or may not produce high similarity for different text - # Just verify the result structure - assert result.confidence in ("high", "medium", "low") - assert isinstance(result.similarity, float) - - @pytest.mark.asyncio - async def test_low_confidence_empty_index(self): - """Empty index returns low confidence.""" - embedder = MockEmbedder(dimension=64) - router = SemanticRouter(embedder) - - result = await router.route("任何查询") - assert result.confidence == "low" - assert result.skill_name is None - assert result.similarity == 0.0 - - @pytest.mark.asyncio - async def test_medium_confidence_zone(self): - """Test medium confidence zone (0.6-0.85).""" - embedder = MockEmbedder(dimension=64) - router = SemanticRouter(embedder, similarity_high=0.99, similarity_low=0.01) - - skill = MockSkill("content_gen", description="生成文章内容") - await router.update_skill("content_gen", skill) - - # With very high similarity_high and very low similarity_low, - # most matches will be "medium" - result = await router.route("生成文章") - # The result should be medium (since threshold is 0.99) - assert result.confidence in ("medium", "low", "high") - - @pytest.mark.asyncio - async def test_embedder_failure_graceful(self): - """Embedder failure returns low confidence.""" - class FailingEmbedder(MockEmbedder): - async def embed(self, text): - raise RuntimeError("Embedding API failed") - - router = SemanticRouter(FailingEmbedder(dimension=64)) - result = await router.route("test query") - assert result.confidence == "low" - assert result.skill_name is None - - @pytest.mark.asyncio - async def test_build_index_from_registry(self): - """Build index from skill registry.""" - embedder = MockEmbedder(dimension=64) - router = SemanticRouter(embedder) - - skills = [ - MockSkill("skill_a", description="Skill A"), - MockSkill("skill_b", description="Skill B"), - ] - registry = MockSkillRegistry(skills) - await router.build_index(registry) - - assert router._index.size == 2 - - @pytest.mark.asyncio - async def test_chinese_query(self): - """Chinese query works with semantic router.""" - embedder = MockEmbedder(dimension=64) - router = SemanticRouter(embedder, similarity_high=0.01, similarity_low=0.001) - - skill = MockSkill("geo_optimizer", description="地理内容优化", keywords=["优化", "SEO", "地理"], capabilities=["optimization"]) - await router.update_skill("geo_optimizer", skill) - - result = await router.route("帮我优化内容") - # With very low thresholds, should match - assert result.confidence in ("high", "medium") - assert result.skill_name == "geo_optimizer"