refactor: eliminate routing layer, align with industry best practices
Phase 1 of architecture optimization (U1/U2/U4/U8):
- U1: Rename SimpleRouter to RequestPreprocessor and route() to preprocess(). This eliminates the misleading routing concept; the LLM decides autonomously in the REACT agent loop (matches the Codex/Claude Code/Trae pattern).
- U2: Delete CostAwareRouter, HeuristicClassifier, and SemanticRouter (~700 lines removed). skill_routing.py shrinks from 1688 to 220 lines.
- U4: PlanExecEngine defaults to ReActStepExecutor; delete _LLMStepExecutor (pure LLM calls without tools = no execution capability).
- U8: ReActEngine defaults to ContextCompressor(keep_recent=10).

Supersedes plans 2026-06-15-002/003/004. New plan: 2026-06-16-006-refactor-architecture-optimization-evolution-plan.md
parent b54213b3c6
commit 5374bc8501
@ -1,7 +1,10 @@
|
|||
---
|
||||
title: "feat: E2E capability analysis framework improvements and smarter routing"
|
||||
type: feat
|
||||
status: active
|
||||
status: superseded
|
||||
superseded_by: "2026-06-16-005-refactor-routing-architecture-plan"
|
||||
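superseded_reason: "The smarter-routing part has been replaced by the SimpleRouter architecture simplification. The E2E capability analysis framework part can be merged into the U5 backtest verification unit of 2026-06-16-005."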
|
||||
closed: 2026-06-16
|
||||
created: 2026-06-15
|
||||
plan-depth: standard
|
||||
---
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
---
|
||||
title: "feat: smarter routing optimization: complexity calibration, intent disambiguation, quality gate enhancement"
|
||||
status: active
|
||||
created: 2026-06-15
|
||||
updated: 2026-06-15
|
||||
status: superseded
|
||||
superseded_by: "2026-06-16-005-refactor-routing-architecture-plan"
|
||||
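superseded_reason: "SimpleRouter has replaced CostAwareRouter's 4-layer routing architecture. IntentRouter multi-candidate scoring (U2) and QualityGate skill-match verification (U3) belong to the deleted old routing layer and no longer need to be implemented. The U1 HeuristicClassifier tests are valuable only for backward compatibility."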
|
||||
closed: 2026-06-16
|
||||
origin: test-results/e2e/capability_report.txt (real-LLM backtest analysis report)
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@
|
|||
|
||||
```yaml
|
||||
title: feat: SemanticRouter enablement and backtest system upgrade
|
||||
status: active
|
||||
created: 2026-06-15
|
||||
plan_id: "2026-06-15-004"
|
||||
status: superseded
|
||||
superseded_by: "2026-06-16-005-refactor-routing-architecture-plan"
|
||||
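superseded_reason: "SimpleRouter has replaced CostAwareRouter, so SemanticRouter is no longer needed as a routing-layer component. The LLM sees the full tool descriptions in the REACT agent loop and decides autonomously, without embeddings for intent routing. If the number of tools exceeds 50 in the future, Codex's tool_search (BM25) can serve as a reference for tool discovery."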
|
||||
closed: 2026-06-16
|
||||
```
|
||||
|
||||
## Summary
|
||||
|
|
@ -187,6 +188,12 @ SemanticRouter is fully implemented (`src/agentkit/chat/semantic_router.py`), but
|
|||
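- Adversarial input testing
- Intent-classification fine-tuning pipeline
- Automatic keyword expansion tool
- **[TODO] Provide an Embedding API key**: the current Bailian Coding Plan key (sk-sp-*) does not support the embedding API. Options:
  - Option A: enable the embedding service separately in the DashScope console and obtain a standard sk- prefixed key
  - Option B: configure a local embedding model (BGE-large-zh-v1.5 via Xinference)
  - Option C: use another provider that supports embeddings (Zhipu, Volcengine, etc.)
  - Configuration location: agentkit.yaml → llm.providers, add a dedicated embedding provider
  - Expected effect: SemanticRouter actually works, and colloquial queries no longer need manual keyword expansion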
|
||||
|
||||
## Risks
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,429 @@
|
|||
```yaml
title: "refactor: AgentKit architecture optimization evolution, aligning with industry best practices"
status: active
plan_id: "2026-06-16-006"
created: 2026-06-16
depth: deep
origin: "Optimize the AgentKit architecture based on an in-depth analysis of Codex/Claude Code/Trae/Qoder"
```
|
||||
|
||||
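# AgentKit Architecture Optimization Evolution Plan

## Summary

Based on an in-depth analysis of Codex CLI, Claude Code, Trae Agent 2.0, and Qoder, this plan evolves the AgentKit architecture. Core changes: (1) replace the "routing" concept with request preprocessing, (2) simplify the expert team from decentralized collaboration to hub-and-spoke, (3) simplify PlanExec to a Spec-Driven mode, (4) persist chat history in SQLite, (5) delete the old routing-layer code, (6) add verifiable execution and slimmer tool descriptions.

## Problem Frame

The current AgentKit architecture has three kinds of problems:

1. **Misleading concept**: SimpleRouter still implies that a "routing layer" exists, but it only does @skill prefix parsing plus a greeting fast-path; the core decision is already delegated to the LLM in the REACT agent loop
2. **Over-engineering**: ExpertTeam's decentralized collaboration (CollaborationPlan + HandoffTransport + SharedWorkspace + 3 MergeStrategy variants) goes far beyond industry practice, adding ~1500 lines of code with no corresponding value
3. **Capability gaps**: PlanExecEngine defaults to _LLMStepExecutor (pure LLM calls without tools), chat history is not persisted, and there is no automatic verification loop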
|
||||
|
||||
## Requirements

- R1: Eliminate the "routing" concept; rename SimpleRouter to RequestPreprocessor, shifting its semantics from "routing decision" to "request preprocessing"
- R2: Simplify the expert team to a hub-and-spoke model (Lead Expert + parallel Tasks, depth=1); delete the ExpertTeam-specific logic in CollaborationPlan/HandoffTransport/SharedWorkspace
- R3: PlanExecEngine defaults to ReActStepExecutor; delete _LLMStepExecutor
- R4: Persist chat history in SQLite (modeled on Codex's Thread persistence)
- R5: Delete CostAwareRouter and related code (HeuristicClassifier, IntentRouter, QualityGate, SemanticRouter)
- R6: Add verifiable execution (Test-and-Verify loop, modeled on Codex Cloud)
- R7: Tiered tool-description injection (core tools in full + one-line descriptions for extended tools + tool_search on demand)
- R8: Enable context compression by default
- R9: Make Spec documents first-class citizens (modeled on Qoder Quest Mode)
- R10: Unified event model (SQ/EQ dual queues, modeled on Codex)
|
||||
|
||||
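## Key Technical Decisions

### KTD1: Rename SimpleRouter to RequestPreprocessor

**Decision**: Rename SimpleRouter to RequestPreprocessor and the route() method to preprocess()

**Rationale**:
- Codex has no routing layer, Claude Code has no routing layer, and Trae removed its Proposal stage; the industry consensus is to have no routing layer
- SimpleRouter's three functions (@skill prefix, greeting regex, default REACT) are all preprocessing rather than routing decisions
- The "routing" concept misleads developers into believing an intent-prediction layer exists, when in fact the LLM decides autonomously inside the agent loop

**Alternative considered**: Delete SimpleRouter entirely and send every request straight into the REACT loop. Rejected: the greeting fast-path saves ~100 tokens + 500ms on each request it handles, and the @skill prefix is an explicit user instruction that needs code-level parsing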
|
||||
|
||||
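### KTD2: Hub-and-spoke expert team

**Decision**: Simplify ExpertTeam from decentralized collaboration to Lead Expert + parallel Tasks (depth=1)

**Rationale**:
- Claude Code: the Task tool has depth=1; a sub-agent cannot spawn further sub-agents
- Codex: spawn_agent is hierarchical, and results return to the parent agent
- Qoder: multiple experts run independently in parallel, and the main agent aggregates
- Decentralized collaboration has O(N²) communication complexity; hub-and-spoke is O(N)
- One LLM playing different "experts" does not produce genuine diversity of viewpoints; it is equivalent to sampling several times and merging

**Keep**: ExpertConfig/ExpertTemplate/Registry (define expert personas), the BEST merge strategy (the Lead Agent picks the best result)

**Delete**: CollaborationPlan's phase dependency graph, HandoffTransport's inter-agent communication, SharedWorkspace's cross-phase state sharing, the VOTE/FUSION merge strategies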
|
||||
|
||||
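### KTD3: PlanExec defaults to ReActStepExecutor + Spec-Driven

**Decision**: Default to ReActStepExecutor (already implemented) and delete _LLMStepExecutor; add Spec document persistence

**Rationale**:
- _LLMStepExecutor does not support tool calls = no execution capability
- ReActStepExecutor is already implemented on top of ReActEngine and supports tool calls and multi-step reasoning
- Qoder Quest Mode: Spec First is the contract between human and AI; execution starts only after the user confirms
- The current PlanExec plan is invisible to the user, who cannot correct course before execution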
|
||||
|
||||
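### KTD4: SQLite persistence for chat history

**Decision**: Use SQLite to persist chat history (modeled on Codex's Thread persistence)

**Rationale**:
- Codex uses SQLite (lightweight, zero-config, cross-platform) and supports resume/fork/archive
- Claude Code uses append-only JSONL (simpler but weaker at search)
- The current ConversationStore is purely in-memory and loses data when the service restarts
- SQLite supports session search and paginated loading, and needs no additional service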
|
||||
|
||||
---
|
||||
|
||||
## Implementation Units

### U1. Rename SimpleRouter to RequestPreprocessor

**Goal**: Eliminate the "routing" concept by renaming SimpleRouter to RequestPreprocessor

**Dependencies**: None

**Files**:
- `src/agentkit/chat/simple_router.py` → rename to `src/agentkit/chat/request_preprocessor.py`
- `src/agentkit/chat/skill_routing.py`: update references
- `src/agentkit/server/routes/portal.py`: update imports and calls
- `src/agentkit/server/routes/chat.py`: update imports and calls
- `src/agentkit/server/app.py`: update imports and calls
- `tests/unit/chat/test_simple_router.py` → rename and update

**Approach**:
1. Create `request_preprocessor.py` with class name `RequestPreprocessor`; method `route()` → `preprocess()`
2. Rename `_is_direct_chat()` to `_is_trivial_input()`
3. Keep `SkillRoutingResult` (it is a data structure and does not involve the routing concept)
4. Update all call sites
5. Delete the old file

**Patterns to follow**: the existing SimpleRouter code structure

**Test scenarios**:
- The @skill:xxx prefix is parsed correctly into SKILL_REACT mode
- A greeting regex match returns DIRECT_CHAT
- Default input returns REACT mode
- An unknown skill falls back to REACT
- The preprocess() method signature is compatible with route()

**Verification**: `ruff check src/ && pytest tests/unit/chat/ -v`
|
||||
|
||||
---
|
||||
|
||||
### U2. Delete CostAwareRouter and related code

**Goal**: Delete the old routing-layer code that SimpleRouter has already replaced

**Dependencies**: U1

**Files**:
- Most files under `src/agentkit/router/` (keep __init__.py and the necessary exports)
- `src/agentkit/server/app.py`: clean up commented-out references
- `src/agentkit/chat/cost_aware_router.py`: delete

**Approach**:
1. Confirm that CostAwareRouter has no active references in the code (already commented out in app.py)
2. Delete `cost_aware_router.py`, `heuristic_classifier.py`, `intent.py` (IntentRouter), `quality_gate.py`, `semantic_router.py`
3. Keep `router/__init__.py` exporting the necessary types (such as ExecutionMode, if the frontend depends on it)
4. Clean up the commented-out references in app.py
5. Update the `__init__.py` of the `router/` directory

**Test scenarios**:
- `ruff check` reports no errors after the deletion
- `pytest -m "not integration"` passes in full
- No import errors

**Verification**: `ruff check src/ && pytest -m "not integration" -x`
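
Step 1 of the approach could be checked mechanically rather than by eye. The sketch below is an illustration, not part of the plan: it parses a module with the standard `ast` module and reports `from ... import` statements that still name a removed class. Commented-out references are ignored because comments never reach the AST. `find_stale_imports` and `REMOVED_NAMES` are hypothetical names.

```python
import ast

# Classes removed by R5; importing any of them after U2 is a stale reference.
REMOVED_NAMES = {
    "CostAwareRouter",
    "HeuristicClassifier",
    "IntentRouter",
    "QualityGate",
    "SemanticRouter",
}


def find_stale_imports(source: str) -> list[tuple[int, str]]:
    """Return (line, name) for every `from x import Name` of a removed class."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.name in REMOVED_NAMES:
                    hits.append((node.lineno, alias.name))
    return sorted(hits)
```

Running this over every file under `src/` before deleting the modules would surface the kind of leftover import that the CLI hunk in this commit removes by hand.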
|
||||
|
||||
---
|
||||
|
||||
### U3. Simplify the expert team to hub-and-spoke

**Goal**: Simplify ExpertTeam from decentralized collaboration to a Lead Expert + parallel Task model

**Dependencies**: None (can run in parallel with U1/U2)

**Files**:
- `src/agentkit/experts/orchestrator.py`: rewrite as hub-and-spoke
- `src/agentkit/experts/team.py`: simplify, remove the CollaborationPlan dependency
- `src/agentkit/experts/plan.py`: simplify, keep MergeStrategy.BEST
- `src/agentkit/core/handoff_transport.py`: remove ExpertTeam-specific logic
- `src/agentkit/core/shared_workspace.py`: remove ExpertTeam-specific logic
- `src/agentkit/experts/router.py`: simplify to @team prefix + RequestPreprocessor integration
- `tests/unit/experts/test_orchestrator.py`: update

**Approach**:
1. Rewrite TeamOrchestrator: the Lead Expert plans autonomously and spawns Tasks in parallel
2. Delete CollaborationPlan's phase dependency graph; the Lead Expert decides execution order itself
3. Delete HandoffTransport's inter-agent communication; Task results return directly to the Lead Expert
4. Delete SharedWorkspace's cross-phase state sharing; the Lead Expert holds all state
5. Keep MergeStrategy.BEST (the Lead Agent picks the best result); delete VOTE/FUSION
6. Simplify ExpertTeamRouter to trigger on the @team prefix
7. Leave ExpertConfig/ExpertTemplate/Registry unchanged

**Technical design**:

```
New TeamOrchestrator flow:
1. User input → @team:xxx prefix → ExpertTeamMode
2. The Lead Expert receives the task and decomposes it into subtasks on its own
3. Spawn Tasks in parallel (each Task is an independent ReActEngine instance)
4. Wait for all Tasks to complete
5. The Lead Expert aggregates the results (BEST strategy)
6. Return the final result

Constraints:
- Task depth=1 (a Task cannot spawn another Task)
- No communication between Tasks
- The Lead Expert holds all state
```

**Test scenarios**:
- The Lead Expert correctly decomposes the task into subtasks
- Parallel Tasks run independently and return results
- The Lead Expert aggregates the results
- A single Task failure does not affect the other Tasks
- When all Tasks fail, fall back to the Lead Expert executing alone
- The @team prefix correctly triggers ExpertTeamMode

**Verification**: `ruff check src/ && pytest tests/unit/experts/ -v`
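
The fan-out and fallback behavior in the flow and test scenarios above might look roughly like this with `asyncio`. This is a sketch under assumptions: a Task is modeled as a plain async callable instead of a ReActEngine instance, and `run_team`, `pick_best`, and `lead_fallback` are hypothetical names that do not come from the codebase.

```python
import asyncio
from collections.abc import Awaitable, Callable

# A Task is modeled as an async callable taking a subtask description.
TaskFn = Callable[[str], Awaitable[str]]


async def run_team(
    subtasks: list[str],
    run_task: TaskFn,
    pick_best: Callable[[list[str]], str],
    lead_fallback: TaskFn,
    original_task: str,
) -> str:
    """Fan out subtasks in parallel, then let the lead pick the best result."""
    outcomes = await asyncio.gather(
        *(run_task(s) for s in subtasks), return_exceptions=True
    )
    # One failed Task must not sink the others: keep only successful results.
    results = [o for o in outcomes if not isinstance(o, BaseException)]
    if not results:
        # Every Task failed: the Lead Expert executes the task alone.
        return await lead_fallback(original_task)
    return pick_best(results)
```

Because Tasks never talk to each other and cannot spawn further Tasks, the only coordination point is the single `gather` call, which is where the O(N) communication cost in KTD2 comes from.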
|
||||
|
||||
---
|
||||
|
||||
### U4. PlanExec defaults to ReActStepExecutor + delete _LLMStepExecutor

**Goal**: PlanExecEngine defaults to ReActStepExecutor; delete _LLMStepExecutor

**Dependencies**: None

**Files**:
- `src/agentkit/core/plan_exec_engine.py`: delete _LLMStepExecutor and _LLMStepAgent; default step_executor_type="react"
- `tests/unit/core/test_plan_exec_engine.py`: update

**Approach**:
1. Delete the `_LLMStepExecutor` and `_LLMStepAgent` classes
2. Remove the step_executor_type parameter from `_create_executor()` and always use ReActStepExecutor
3. Clean up the related imports

**Test scenarios**:
- PlanExecEngine creates a ReActStepExecutor by default
- ReActStepExecutor correctly executes steps that involve tool calls
- A failed step triggers replanning

**Verification**: `ruff check src/ && pytest tests/unit/core/test_plan_exec_engine.py -v`
|
||||
|
||||
---
|
||||
|
||||
### U5. SQLite persistence for chat history

**Goal**: Persist chat history in SQLite so it survives service restarts

**Dependencies**: U1 (update call sites after the RequestPreprocessor rename is complete)

**Files**:
- `src/agentkit/chat/sqlite_conversation_store.py`: new
- `src/agentkit/server/routes/portal.py`: replace ConversationStore
- `src/agentkit/chat/conversation_store.py`: keep as the interface/in-memory implementation

**Approach**:
1. Create `SqliteConversationStore`, implementing the same interface as `ConversationStore`
2. SQLite table schema: conversations(id, session_id, role, content, timestamp, metadata)
3. Support querying by session_id, paginated loading, and search
4. Database file path: `~/.agentkit/conversations.db`
5. Replace ConversationStore with SqliteConversationStore in portal.py
6. Keep ConversationStore as the in-memory implementation (for tests)

**Test scenarios**:
- Messages are persisted to SQLite correctly
- Querying by session_id returns the full conversation
- Paginated loading works correctly
- Data survives a service restart
- The SQLite file is created automatically when it does not exist

**Verification**: `ruff check src/ && pytest tests/unit/chat/test_sqlite_conversation_store.py -v`
|
||||
|
||||
---
|
||||
|
||||
### U6. Verifiable execution (Test-and-Verify loop)

**Goal**: After ReActEngine finishes, optionally run the project's tests automatically to verify the result

**Dependencies**: U4

**Files**:
- `src/agentkit/core/verification_loop.py`: new
- `src/agentkit/core/react.py`: integrate the verification loop
- `src/agentkit/tools/builtin.py`: add the run_tests tool

**Approach**:
1. Create `VerificationLoop`: run pytest/ruff/typecheck after execution and retry automatically on failure
2. The maximum number of retries is configurable (default 2)
3. Attach the verification result to ReActResult
4. Add a `run_tests` built-in tool that the LLM can call proactively
5. The verification loop is off by default; enable it with the parameter `verification_enabled=True`

**Test scenarios**:
- Behavior is unchanged when the verification loop is off
- When the verification loop is on, tests run automatically after execution
- Success is returned when the tests pass
- A test failure triggers an automatic retry
- A failure result is returned once the maximum number of retries is reached

**Verification**: `ruff check src/ && pytest tests/unit/core/test_verification_loop.py -v`
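
The retry semantics above (one initial run plus up to `max_retries` retries, with the failure report fed back into the next attempt) can be sketched as a small synchronous loop. This is illustrative only: the real engine is async, and `run_with_verification`, `VerificationOutcome`, and the `execute`/`verify` callables are hypothetical names.

```python
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class VerificationOutcome:
    passed: bool
    attempts: int  # number of execute+verify rounds that ran
    last_report: str


def run_with_verification(
    execute: Callable[[str], None],
    verify: Callable[[], tuple[bool, str]],
    max_retries: int = 2,
) -> VerificationOutcome:
    """Execute, then verify; on failure feed the report back and retry."""
    feedback = ""
    report = ""
    # 1 initial run + max_retries retries.
    for attempt in range(1, max_retries + 2):
        execute(feedback)
        passed, report = verify()
        if passed:
            return VerificationOutcome(True, attempt, report)
        feedback = report  # the next round sees why verification failed
    return VerificationOutcome(False, max_retries + 1, report)
```

In practice `verify` would wrap something like a pytest or ruff subprocess and return its exit status and output; keeping it as an injected callable makes the loop testable without running real tools.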
|
||||
|
||||
---
|
||||
|
||||
### U7. Tiered tool-description injection + tool_search

**Goal**: Inject core tools in full and only name + one-line description for extended tools; the LLM can fetch full descriptions through tool_search

**Dependencies**: None

**Files**:
- `src/agentkit/core/react.py`: modify `_build_tool_use_prompt`
- `src/agentkit/tools/builtin.py`: add the tool_search tool
- `src/agentkit/tools/search.py`: new, BM25 tool search

**Approach**:
1. Split tools into core (read/write/bash/search) and extended (everything else)
2. Inject core tools into the prompt in full
3. Inject only name + one-line description for extended tools
4. Add the `tool_search` tool: BM25 search over tool descriptions that returns the full description
5. The LLM calls tool_search on demand inside the agent loop

**Test scenarios**:
- Core tools appear in the prompt in full
- Extended tools appear with only a name and a one-line description
- tool_search returns the full tool description correctly
- BM25 search ranks results by relevance

**Verification**: `ruff check src/ && pytest tests/unit/tools/test_tool_search.py -v`
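
For reference, a dependency-free Okapi BM25 ranking over tool descriptions fits in a few dozen lines. The sketch below is one possible shape for `search.py`, not the planned implementation: `bm25_search` is a hypothetical name, the `k1`/`b` defaults are conventional values rather than tuned ones, and the tokenizer only handles ASCII words, so Chinese descriptions would need a proper segmenter.

```python
import math
import re
from collections import Counter


def _tokenize(text: str) -> list[str]:
    # ASCII-only tokenizer; an assumption made to keep the sketch short.
    return re.findall(r"[a-z0-9_]+", text.lower())


def bm25_search(query: str, tools: dict[str, str], top_k: int = 3,
                k1: float = 1.5, b: float = 0.75) -> list[str]:
    """Rank tool names by BM25 score of name + description against `query`."""
    docs = {name: _tokenize(f"{name} {desc}") for name, desc in tools.items()}
    n_docs = len(docs)
    avg_len = sum(len(d) for d in docs.values()) / max(n_docs, 1)
    # Document frequency: in how many tool descriptions each term occurs.
    doc_freq = Counter(term for d in docs.values() for term in set(d))

    scores: dict[str, float] = {}
    for name, tokens in docs.items():
        tf = Counter(tokens)
        score = 0.0
        for term in _tokenize(query):
            if term not in tf:
                continue
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            norm = tf[term] + k1 * (1 - b + b * len(tokens) / avg_len)
            score += idf * tf[term] * (k1 + 1) / norm
        if score > 0:
            scores[name] = score
    # Tools with no matching term are omitted entirely.
    return sorted(scores, key=scores.get, reverse=True)[:top_k]
```

A `tool_search` tool would then look up the returned names and hand the full descriptions back to the LLM.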
|
||||
|
||||
---
|
||||
|
||||
### U8. Enable context compression by default

**Goal**: ReActEngine enables sliding-window compression by default

**Dependencies**: None

**Files**:
- `src/agentkit/core/react.py`: change the default compressor parameter
- `src/agentkit/core/compressor.py`: confirm the sliding-window implementation

**Approach**:
1. In ReActEngine's `__init__`, change the compressor default from None to SlidingWindowCompressor
2. Keep the most recent N turns + the system prompt + the tool descriptions
3. N is configurable (default 10)

**Test scenarios**:
- Long conversations are compressed automatically
- The system prompt and tool descriptions survive compression
- Compression does not affect the most recent N turns

**Verification**: `ruff check src/ && pytest tests/unit/core/test_compressor.py -v`
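
The sliding-window rule in step 2 reduces to a few lines when messages are plain dicts. The sketch below makes two simplifying assumptions that may not hold for the real compressor: each message counts as one unit (rather than a user/assistant turn pair), and tool descriptions live inside the system message. `compress_messages` is a hypothetical name.

```python
def compress_messages(messages: list[dict], keep_recent: int = 10) -> list[dict]:
    """Keep every system message plus the last `keep_recent` other messages."""
    system = [m for m in messages if m["role"] == "system"]
    rest = [m for m in messages if m["role"] != "system"]
    # Guard keep_recent == 0: rest[-0:] would return the whole list.
    recent = rest[-keep_recent:] if keep_recent > 0 else []
    return system + recent
```

A production version would probably also need to avoid cutting between a tool call and its result, which this sketch does not attempt.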
|
||||
|
||||
---
|
||||
|
||||
### U9. Spec documents as first-class citizens

**Goal**: Persist the plan generated by PlanExec as a Spec document that the user can view, edit, and confirm before execution

**Dependencies**: U4

**Files**:
- `src/agentkit/core/spec_manager.py`: new
- `src/agentkit/core/plan_exec_engine.py`: integrate SpecManager
- `src/agentkit/server/routes/tasks.py`: add Spec-related APIs

**Approach**:
1. Create `SpecManager`: manages CRUD for Spec documents
2. Spec file path: `.agentkit/specs/<plan_id>.yaml`
3. After PlanExecEngine generates a plan, persist it as a Spec first
4. New APIs: `GET /api/v1/specs`, `GET /api/v1/specs/{id}`, `PUT /api/v1/specs/{id}`, `POST /api/v1/specs/{id}/confirm`
5. Execution starts only after the user confirms

**Test scenarios**:
- The plan is persisted correctly as a Spec file
- The Spec file can be read and edited
- An unconfirmed Spec is not executed
- Confirmation triggers execution

**Verification**: `ruff check src/ && pytest tests/unit/core/test_spec_manager.py -v`
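
The persist-then-confirm gate can be sketched as follows. To stay within the standard library this version writes JSON, whereas the plan specifies YAML files. The `spec_id`, `goal`, and `steps` fields and the `create` method mirror names visible in this commit's `plan_exec_engine.py` hunks, but `confirmed`, `get`, `confirm`, and `can_execute` are hypothetical, and steps are plain strings here rather than `SpecStep` objects.

```python
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path


@dataclass
class Spec:
    spec_id: str
    goal: str
    steps: list[str] = field(default_factory=list)
    confirmed: bool = False  # hypothetical flag gating execution


class SpecManager:
    def __init__(self, root: Path) -> None:
        self._root = root
        self._root.mkdir(parents=True, exist_ok=True)

    def _path(self, spec_id: str) -> Path:
        return self._root / f"{spec_id}.json"

    def create(self, spec: Spec) -> None:
        """Write (or overwrite) the Spec file for this spec_id."""
        text = json.dumps(asdict(spec), ensure_ascii=False, indent=2)
        self._path(spec.spec_id).write_text(text, encoding="utf-8")

    def get(self, spec_id: str) -> Spec:
        return Spec(**json.loads(self._path(spec_id).read_text(encoding="utf-8")))

    def confirm(self, spec_id: str) -> Spec:
        spec = self.get(spec_id)
        spec.confirmed = True
        self.create(spec)
        return spec


def can_execute(manager: SpecManager, spec_id: str) -> bool:
    """Execution gate: only confirmed Specs may run."""
    return manager.get(spec_id).confirmed
```

Because the Spec lives on disk as a readable file, a user edit between `create` and `confirm` is picked up automatically when the engine re-reads it.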
|
||||
|
||||
---
|
||||
|
||||
### U10. Unified event model (SQ/EQ dual queues)

**Goal**: Unify the CLI and WebSocket event models into SQ/EQ dual queues

**Dependencies**: U3

**Files**:
- `src/agentkit/core/protocol.py`: add SQ/EQ event types
- `src/agentkit/server/routes/portal.py`: connect to EQ
- `src/agentkit/cli/chat.py`: connect to EQ

**Approach**:
1. Define SubmissionQueue (user input) and EventQueue (agent output)
2. Event types: a three-level Session/Task/Turn model
3. The Portal WebSocket and the CLI share the same event stream
4. The frontend can render uniformly

**Test scenarios**:
- SQ receives user input correctly
- EQ pushes agent events correctly
- WebSocket and CLI share the event stream
- Event types are classified correctly

**Verification**: `ruff check src/ && pytest tests/unit/core/test_protocol.py -v`
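
The dual-queue idea maps naturally onto two `asyncio.Queue` objects: frontends push into the submission queue and read from the event queue, and the agent does the reverse. The sketch below only shows that shape. The `Submission` and `Event` classes, the event kind strings, and the use of `None` as an end-of-stream marker are assumptions, and the agent here just echoes its input.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Submission:
    session_id: str
    text: str


@dataclass
class Event:
    session_id: str
    kind: str  # e.g. "turn_started", "turn_completed" (illustrative names)
    payload: str


async def agent_worker(sq: asyncio.Queue, eq: asyncio.Queue) -> None:
    """Drain submissions and emit events; None on the SQ signals shutdown."""
    while (sub := await sq.get()) is not None:
        await eq.put(Event(sub.session_id, "turn_started", sub.text))
        # A real agent loop would run here; this sketch echoes the input.
        await eq.put(Event(sub.session_id, "turn_completed", sub.text.upper()))
    await eq.put(None)  # tell consumers the stream has ended


async def demo() -> list[str]:
    sq: asyncio.Queue = asyncio.Queue()
    eq: asyncio.Queue = asyncio.Queue()
    worker = asyncio.create_task(agent_worker(sq, eq))
    await sq.put(Submission("s1", "hello"))
    await sq.put(None)
    seen = []
    while (event := await eq.get()) is not None:
        seen.append(f"{event.kind}:{event.payload}")
    await worker
    return seen
```

A CLI and a WebSocket handler would both be consumers of the same event type, which is what lets the frontend render them uniformly.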
|
||||
|
||||
---
|
||||
|
||||
## Scope Boundaries

### In Scope
- The 10 Implementation Units above
- Unit test coverage

### Out of Scope
- DockerComputerUseSession implementation (P3, pending validation of user demand)
- Frontend component updates (later iteration)
- Agent configuration hot reload (P4)
- Progressive context loading (P4)
- Multi-dimensional triggers for Soul evolution (closed)
- OTel instrumentation (deferred)

### Deferred to Follow-Up Work
- Connect the frontend ExpertTeamView to real data
- SWE-bench end-to-end validation
- Performance monitoring and cost tracking
|
||||
|
||||
---
|
||||
|
||||
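## Risks & Dependencies

| Risk | Impact | Mitigation |
|------|--------|------------|
| The U3 expert-team rewrite may affect the existing @team feature | Medium | Keep ExpertConfig/ExpertTemplate/Registry and rewrite only the Orchestrator |
| U5 SQLite may see lock contention under high concurrency | Low | Chat workloads write infrequently; SQLite WAL mode is sufficient |
| U7 tool_search may add LLM call rounds | Medium | Core tools are injected in full; only extended tools need search |
| U9 Spec documents may add user steps | Low | Auto-confirm by default (configurable), so automated flows are not blocked |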
|
||||
|
||||
---
|
||||
|
||||
## Phased Delivery

**Phase 1 (cleanup)**: U1 → U2 → U4 → U8
**Phase 2 (core capabilities)**: U5 → U6 → U9
**Phase 3 (multi-agent)**: U3 → U7 → U10
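
As a sanity check, the delivery order can be compared against each unit's stated dependencies. The sketch below transcribes the "Dependencies" lines of U1 through U10 and the three phases into data; `violations` is a hypothetical helper, not part of the plan.

```python
# Dependencies as stated in each unit's "Dependencies" line.
DEPENDS_ON: dict[str, set[str]] = {
    "U1": set(), "U2": {"U1"}, "U3": set(), "U4": set(), "U5": {"U1"},
    "U6": {"U4"}, "U7": set(), "U8": set(), "U9": {"U4"}, "U10": {"U3"},
}

# Phase 1, then Phase 2, then Phase 3, flattened into one sequence.
DELIVERY_ORDER = ["U1", "U2", "U4", "U8", "U5", "U6", "U9", "U3", "U7", "U10"]


def violations(order: list[str], deps: dict[str, set[str]]) -> list[tuple[str, str]]:
    """Return (unit, dependency) pairs where the dependency is scheduled later."""
    position = {unit: i for i, unit in enumerate(order)}
    return [(u, d) for u in order for d in deps[u] if position[d] > position[u]]
```

An empty result from `violations(DELIVERY_ORDER, DEPENDS_ON)` means every unit is scheduled after the units it depends on.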
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Simple router — minimal routing layer for unified REACT agent loop.
|
||||
"""Request preprocessor — minimal preprocessing layer for unified REACT agent loop.
|
||||
|
||||
Replaces the 4-layer CostAwareRouter with a simple approach:
|
||||
1. @skill:xxx prefix → explicit skill selection
|
||||
|
|
@ -53,15 +53,15 @@ _IDENTITY_RE = re.compile(
|
|||
)
|
||||
|
||||
|
||||
class SimpleRouter:
|
||||
"""Minimal routing layer: regex fast-path + default REACT.
|
||||
class RequestPreprocessor:
|
||||
"""Minimal preprocessing layer: regex fast-path + default REACT.
|
||||
|
||||
Design rationale:
|
||||
- No HeuristicClassifier: keyword enumeration can never cover all colloquial expressions
|
||||
- No IntentRouter: LLM blind-classification without tool context is unreliable
|
||||
- No SemanticRouter: embedding similarity is not intent recognition
|
||||
- LLM in the REACT agent loop sees full tool descriptions and decides autonomously
|
||||
- This matches Codex/Trae/Hermes architecture: unified agent loop, no routing layer
|
||||
- This matches Codex/Trae/Hermes architecture: unified agent loop, no preprocessing layer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -78,7 +78,7 @@ class SimpleRouter:
|
|||
self._default_model = default_model
|
||||
self._default_agent_name = default_agent_name
|
||||
|
||||
async def route(
|
||||
async def preprocess(
|
||||
self,
|
||||
content: str,
|
||||
*,
|
||||
|
|
@ -90,7 +90,7 @@ class SimpleRouter:
|
|||
session_id: str = "",
|
||||
transparency: str = "SILENT",
|
||||
) -> SkillRoutingResult:
|
||||
"""Route user input to the appropriate execution path.
|
||||
"""Preprocess user input to determine the appropriate execution path.
|
||||
|
||||
Decision tree:
|
||||
1. @skill:xxx prefix → explicit skill (SKILL_REACT or skill's execution_mode)
|
||||
|
|
@ -99,21 +99,25 @@ class SimpleRouter:
|
|||
"""
|
||||
registry = skill_registry or self._skill_registry
|
||||
tools = default_tools if default_tools is not None else self._default_tools
|
||||
sys_prompt = default_system_prompt if default_system_prompt is not None else self._default_system_prompt
|
||||
sys_prompt = (
|
||||
default_system_prompt
|
||||
if default_system_prompt is not None
|
||||
else self._default_system_prompt
|
||||
)
|
||||
model = default_model or self._default_model
|
||||
agent_name = default_agent_name or self._default_agent_name
|
||||
|
||||
# --- Layer 0: @skill:xxx prefix ---
|
||||
explicit_skill, clean_content = parse_skill_prefix(content)
|
||||
if explicit_skill and registry is not None:
|
||||
result = self._route_explicit_skill(
|
||||
result = self._resolve_explicit_skill(
|
||||
explicit_skill, clean_content, registry, model, agent_name
|
||||
)
|
||||
return result
|
||||
|
||||
# --- Layer 1: Greeting/chitchat/identity regex (<1ms, zero tokens) ---
|
||||
stripped = content.strip()
|
||||
if self._is_direct_chat(stripped):
|
||||
if self._is_trivial_input(stripped):
|
||||
result = SkillRoutingResult(
|
||||
clean_content=stripped,
|
||||
matched=False,
|
||||
|
|
@ -141,7 +145,7 @@ class SimpleRouter:
|
|||
)
|
||||
return result
|
||||
|
||||
def _route_explicit_skill(
|
||||
def _resolve_explicit_skill(
|
||||
self,
|
||||
skill_name: str,
|
||||
clean_content: str,
|
||||
|
|
@ -149,7 +153,7 @@ class SimpleRouter:
|
|||
model: str,
|
||||
agent_name: str,
|
||||
) -> SkillRoutingResult:
|
||||
"""Route to an explicitly specified skill via @skill:xxx prefix."""
|
||||
"""Resolve an explicitly specified skill via @skill:xxx prefix."""
|
||||
try:
|
||||
skill = registry.get(skill_name)
|
||||
except Exception:
|
||||
|
|
@ -185,13 +189,11 @@ class SimpleRouter:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_direct_chat(text: str) -> bool:
|
||||
def _is_trivial_input(text: str) -> bool:
|
||||
"""Check if the input is a greeting, chitchat, or identity question.
|
||||
|
||||
These are zero-cost direct chat: no tool usage, no ReAct loop needed.
|
||||
"""
|
||||
return bool(
|
||||
_GREETING_RE.match(text)
|
||||
or _CHAT_MODE_RE.match(text)
|
||||
or _IDENTITY_RE.match(text)
|
||||
_GREETING_RE.match(text) or _CHAT_MODE_RE.match(text) or _IDENTITY_RE.match(text)
|
||||
)
|
||||
|
|
@ -1,224 +0,0 @@
|
|||
"""Semantic Router — Embedding-based intent routing as Layer 1.5.
|
||||
|
||||
Uses pre-computed skill embeddings for zero-cost semantic matching,
|
||||
inserted between Layer 1 (HeuristicClassifier) and Layer 2 (LLM classification)
|
||||
in CostAwareRouter.
|
||||
|
||||
Design doc: docs/plans/2026-06-14-004-u3-semantic-router.md
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agentkit.memory.embedder import Embedder, EmbeddingCache
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SemanticRouteResult:
|
||||
"""Result of semantic routing."""
|
||||
|
||||
confidence: str # "high" | "medium" | "low"
|
||||
skill_name: str | None
|
||||
similarity: float
|
||||
|
||||
|
||||
class SkillEmbeddingIndex:
|
||||
"""Pre-computed embedding index for registered skills.
|
||||
|
||||
Embeddings are computed at skill registration time and cached.
|
||||
Query-time search is O(n) cosine similarity scan, which is fast
|
||||
for <100 skills with 1024-1536 dim vectors.
|
||||
"""
|
||||
|
||||
def __init__(self, embedder: Embedder):
|
||||
self._embedder = embedder
|
||||
# skill_name → (embedding, source_text)
|
||||
self._index: dict[str, tuple[list[float], str]] = {}
|
||||
|
||||
async def build(self, skill_registry: Any) -> None:
|
||||
"""Build index from all registered skills."""
|
||||
if skill_registry is None:
|
||||
return
|
||||
skills = skill_registry.list_skills()
|
||||
for skill in skills:
|
||||
await self.update_skill(skill.config.name, skill)
|
||||
|
||||
async def update_skill(self, skill_name: str, skill: Any) -> None:
|
||||
"""Re-embed a single skill (on registration/update)."""
|
||||
source_text = self._build_source_text(skill)
|
||||
try:
|
||||
embedding = await self._embedder.embed(source_text)
|
||||
self._index[skill_name] = (embedding, source_text)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to embed skill '{skill_name}': {e}")
|
||||
|
||||
def remove_skill(self, skill_name: str) -> None:
|
||||
"""Remove a skill from the index."""
|
||||
self._index.pop(skill_name, None)
|
||||
|
||||
async def search(self, query_embedding: list[float], top_k: int = 5) -> list[tuple[str, float]]:
|
||||
"""Search for skills matching the query embedding.
|
||||
|
||||
Returns:
|
||||
List of (skill_name, similarity) sorted by similarity descending.
|
||||
"""
|
||||
if not self._index:
|
||||
return []
|
||||
|
||||
results: list[tuple[str, float]] = []
|
||||
for skill_name, (emb, _) in self._index.items():
|
||||
sim = compute_cosine_similarity(query_embedding, emb)
|
||||
results.append((skill_name, sim))
|
||||
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
@staticmethod
|
||||
def _build_source_text(skill: Any) -> str:
|
||||
"""Build embedding source text from skill metadata.
|
||||
|
||||
Combines description, intent keywords, and capability tags
|
||||
for rich semantic representation.
|
||||
"""
|
||||
config = skill.config if hasattr(skill, "config") else skill
|
||||
parts = []
|
||||
|
||||
# Description
|
||||
description = getattr(config, "description", "") or ""
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
# Intent keywords
|
||||
intent = getattr(config, "intent", None)
|
||||
if intent and hasattr(intent, "keywords") and intent.keywords:
|
||||
parts.append(" ".join(intent.keywords))
|
||||
|
||||
# Intent examples (rich semantic signal for short queries)
|
||||
if intent and hasattr(intent, "examples") and intent.examples:
|
||||
parts.append(" ".join(intent.examples))
|
||||
|
||||
# Capability tags
|
||||
capabilities = getattr(config, "capabilities", None)
|
||||
if capabilities:
|
||||
tags = []
|
||||
for cap in capabilities:
|
||||
if isinstance(cap, str):
|
||||
tags.append(cap)
|
||||
elif isinstance(cap, dict):
|
||||
tags.append(cap.get("tag", ""))
|
||||
elif hasattr(cap, "tag"):
|
||||
tags.append(cap.tag)
|
||||
if tags:
|
||||
parts.append(" ".join(t for t in tags if t))
|
||||
|
||||
# Fallback: use skill name if no other text available
|
||||
if not parts:
|
||||
parts.append(getattr(config, "name", "unknown"))
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Number of skills in the index."""
|
||||
return len(self._index)
|
||||
|
||||
|
||||
class SemanticRouter:
|
||||
"""Embedding-based semantic routing as Layer 1.5.
|
||||
|
||||
Three confidence zones:
|
||||
- similarity > similarity_high (0.85): HIGH → direct skill match, skip Layer 2
|
||||
- similarity_low (0.4) <= similarity <= similarity_high: MEDIUM → skill hint for Layer 2
|
||||
- similarity < similarity_low (0.4): LOW → no semantic signal, normal routing
|
||||
|
||||
Short text (<20 chars) uses a lower effective threshold because
|
||||
brief queries naturally have lower embedding similarity.
|
||||
"""
|
||||
|
||||
_SHORT_TEXT_THRESHOLD = 20 # chars
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: Embedder,
|
||||
similarity_high: float = 0.85,
|
||||
similarity_low: float = 0.4,
|
||||
):
|
||||
self._embedder = embedder
|
||||
self._similarity_high = similarity_high
|
||||
self._similarity_low = similarity_low
|
||||
self._index = SkillEmbeddingIndex(embedder)
|
||||
self._query_cache = EmbeddingCache(max_size=500, ttl=1800)
|
||||
|
||||
async def build_index(self, skill_registry: Any) -> None:
|
||||
"""Build skill embedding index from registry."""
|
||||
await self._index.build(skill_registry)
|
||||
logger.info(f"Semantic router index built: {self._index.size} skills")
|
||||
|
||||
async def update_skill(self, skill_name: str, skill: Any) -> None:
|
||||
"""Update a single skill's embedding."""
|
||||
await self._index.update_skill(skill_name, skill)
|
||||
|
||||
def remove_skill(self, skill_name: str) -> None:
|
||||
"""Remove a skill from the index."""
|
||||
self._index.remove_skill(skill_name)
|
||||
|
||||
async def route(self, query: str) -> SemanticRouteResult:
|
||||
"""Route a query using semantic similarity.
|
||||
|
||||
Args:
|
||||
query: User's input text.
|
||||
|
||||
Returns:
|
||||
SemanticRouteResult with confidence, skill_name, and similarity.
|
||||
"""
|
||||
if self._index.size == 0:
|
||||
return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)
|
||||
|
||||
if not query or not query.strip():
|
||||
return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)
|
||||
|
||||
try:
|
||||
# Get query embedding (with cache)
|
||||
query_embedding = self._query_cache.get(query)
|
||||
if query_embedding is None:
|
||||
query_embedding = await self._embedder.embed(query)
|
||||
self._query_cache.put(query, query_embedding)
|
||||
|
||||
# Search skill index
|
||||
results = await self._index.search(query_embedding, top_k=1)
|
||||
if not results:
|
||||
return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)
|
||||
|
||||
best_skill, best_sim = results[0]
|
||||
|
||||
# Short text uses lower effective threshold
|
||||
effective_low = self._similarity_low
|
||||
if len(query) < self._SHORT_TEXT_THRESHOLD:
|
||||
effective_low = max(0.25, self._similarity_low - 0.15)
|
||||
|
||||
if best_sim >= self._similarity_high:
|
||||
return SemanticRouteResult(
|
||||
confidence="high",
|
||||
skill_name=best_skill,
|
||||
similarity=best_sim,
|
||||
)
|
||||
elif best_sim >= effective_low:
|
||||
return SemanticRouteResult(
|
||||
confidence="medium",
|
||||
skill_name=best_skill,
|
||||
similarity=best_sim,
|
||||
)
|
||||
else:
|
||||
return SemanticRouteResult(
|
||||
confidence="low",
|
||||
skill_name=None,
|
||||
similarity=best_sim,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Semantic routing failed, returning low confidence: {e}")
|
||||
return SemanticRouteResult(confidence="low", skill_name=None, similarity=0.0)
|
||||
File diff suppressed because it is too large
|
|
@ -96,11 +96,10 @@ async def _chat_async(
|
|||
WebCrawlTool(),
|
||||
]
|
||||
|
||||
# ── Load skills and build IntentRouter ───────────────────────
|
||||
# ── Load skills ────────────────────────────────────────────
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
from agentkit.router.intent import IntentRouter
|
||||
|
||||
tool_registry = ToolRegistry()
|
||||
for tool in tools:
|
||||
|
|
@ -123,8 +122,6 @@ async def _chat_async(
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
intent_router = IntentRouter(llm_gateway=gateway) if skill_registry.list_skills() else None
|
||||
|
||||
# Build system prompt — inject memory into system prompt
|
||||
base_prompt = system_prompt or (
|
||||
"你是一个有帮助的AI助手。请记住我们对话的上下文,并在后续对话中引用之前的内容。回答要清晰简洁,请使用中文回复。"
|
||||
|
|
@ -218,7 +215,6 @@ async def _chat_async(
|
|||
routing = await resolve_skill_routing(
|
||||
content=user_input,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=tools,
|
||||
default_system_prompt=effective_system_prompt,
|
||||
default_model=current_model,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from agentkit.core.plan_schema import ExecutionPlan, PlanStep, PlanStepStatus
|
|||
from agentkit.core.protocol import CancellationToken, TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.core.react import ReActEvent, ReActResult, ReActStep
|
||||
from agentkit.core.shared_workspace import SharedWorkspace
|
||||
from agentkit.core.spec_manager import Spec, SpecManager, SpecStep
|
||||
from agentkit.orchestrator.reflection import PipelineReflector, PipelineReplanner
|
||||
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineResult, StageResult, StageStatus
|
||||
|
||||
|
|
@ -73,6 +74,7 @@ class PlanExecEngine:
|
|||
default_timeout: float = 300.0,
|
||||
workspace: SharedWorkspace | None = None,
|
||||
step_event_callback: "Callable[[str, dict[str, Any]], Awaitable[None]] | None" = None,
|
||||
spec_manager: SpecManager | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -81,12 +83,14 @@ class PlanExecEngine:
|
|||
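default_timeout: default timeout in seconds
workspace: SharedWorkspace instance, used to pass state between steps
step_event_callback: step event callback, used to push progress during non-streaming execution
spec_manager: SpecManager instance, used to persist the execution plan as a Spec document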
|
||||
"""
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_replans = max_replans
|
||||
self._default_timeout = default_timeout
|
||||
self._workspace = workspace
|
||||
self._step_event_callback = step_event_callback
|
||||
self._spec_manager = spec_manager
|
||||
self._confirmation_handler: Any | None = None
|
||||
|
||||
# Compose sub-components
|
||||
|
|
@ -261,6 +265,17 @@ class PlanExecEngine:
|
|||
tokens=0,
|
||||
))
|
||||
|
||||
# Persist plan as Spec if spec_manager is provided
|
||||
if self._spec_manager is not None:
|
||||
spec = self._plan_to_spec(plan)
|
||||
self._spec_manager.create(spec)
|
||||
state.step_counter += 1
|
||||
yield ReActEvent(
|
||||
event_type="spec_created",
|
||||
step=state.step_counter,
|
||||
data={"spec_id": spec.spec_id, "goal": spec.goal, "num_steps": len(spec.steps)},
|
||||
)
|
||||
|
||||
# ── Phase 2 & 3: Execute with optional replanning ──
|
||||
current_plan = plan
|
||||
replan_count = 0
|
||||
|
|
@ -509,6 +524,20 @@ class PlanExecEngine:
|
|||
tokens=0,
|
||||
))
|
||||
|
||||
# Persist plan as Spec if spec_manager is provided
|
||||
if self._spec_manager is not None:
|
||||
spec = self._plan_to_spec(plan)
|
||||
self._spec_manager.create(spec)
|
||||
if self._step_event_callback:
|
||||
try:
|
||||
await self._step_event_callback("spec_created", {
|
||||
"spec_id": spec.spec_id,
|
||||
"goal": spec.goal,
|
||||
"num_steps": len(spec.steps),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Step event callback failed: {e}")
|
||||
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=1,
|
||||
|
|
@ -613,7 +642,7 @@ class PlanExecEngine:
|
|||
task_id=task_id,
|
||||
)
|
||||
|
||||
# Create the PlanExecutor (direct LLM call mode)
|
||||
# Create the PlanExecutor
|
||||
executor = self._create_executor(
|
||||
messages=messages,
|
||||
model=model,
|
||||
|
|
@ -734,6 +763,25 @@ class PlanExecEngine:
|
|||
# Helper methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _plan_to_spec(plan: ExecutionPlan) -> Spec:
|
||||
"""Convert an ExecutionPlan to a Spec for persistence."""
|
||||
steps = [
|
||||
SpecStep(
|
||||
step_id=s.step_id,
|
||||
name=s.name,
|
||||
description=s.description,
|
||||
dependencies=s.dependencies,
|
||||
)
|
||||
for s in plan.steps
|
||||
]
|
||||
return Spec(
|
||||
spec_id=plan.plan_id,
|
||||
goal=plan.goal,
|
||||
steps=steps,
|
||||
metadata=plan.metadata,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_goal(messages: list[dict[str, str]]) -> str:
|
||||
"""Extract the user's goal from the message list."""
|
||||
|
|
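Reviewer sketch: the `_plan_to_spec` helper added in the hunk above maps each plan step onto a spec step field by field and reuses the plan id as the spec id. The block below reproduces that mapping with stand-in dataclasses so it can be run in isolation. The four classes are hypothetical simplifications; the real `ExecutionPlan`, `PlanStep`, `Spec`, and `SpecStep` in `agentkit.core` may carry more fields than assumed here.

```python
from dataclasses import dataclass, field
from typing import Any


# Stand-in types (assumptions): simplified versions of the agentkit classes.
@dataclass
class PlanStep:
    step_id: str
    name: str
    description: str
    dependencies: list[str] = field(default_factory=list)


@dataclass
class ExecutionPlan:
    plan_id: str
    goal: str
    steps: list[PlanStep]
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class SpecStep:
    step_id: str
    name: str
    description: str
    dependencies: list[str] = field(default_factory=list)


@dataclass
class Spec:
    spec_id: str
    goal: str
    steps: list[SpecStep]
    metadata: dict[str, Any] = field(default_factory=dict)


def plan_to_spec(plan: ExecutionPlan) -> Spec:
    """Convert a plan into a spec, reusing the plan id as the spec id."""
    steps = [
        SpecStep(
            step_id=s.step_id,
            name=s.name,
            description=s.description,
            # Copy the list so later edits to the plan do not leak into the spec.
            dependencies=list(s.dependencies),
        )
        for s in plan.steps
    ]
    return Spec(
        spec_id=plan.plan_id,
        goal=plan.goal,
        steps=steps,
        metadata=dict(plan.metadata),  # shallow copy, same reasoning as above
    )
```

Unlike the diff, which passes `dependencies` and `metadata` through by reference, this sketch copies them. Whether the persisted Spec should be decoupled from later plan mutations (for example during replanning) is a design choice the diff does not state.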
@ -779,31 +827,16 @@ class PlanExecEngine:
|
|||
model: str,
|
||||
system_prompt: str | None,
|
||||
tools: list["Tool"] | None,
|
||||
step_executor_type: str = "react",
|
||||
) -> PlanExecutor:
|
||||
"""Create a PlanExecutor instance.
|
||||
|
||||
Args:
|
||||
step_executor_type: "react" uses ReActStepExecutor (default, supports tool calls),
|
||||
"llm" uses _LLMStepExecutor (plain LLM calls, no tools)
|
||||
"""
|
||||
if step_executor_type == "llm":
|
||||
step_executor: _LLMStepExecutor | ReActStepExecutor = _LLMStepExecutor(
|
||||
llm_gateway=self._llm_gateway,
|
||||
messages=messages,
|
||||
model=model,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
)
|
||||
else:
|
||||
step_executor = ReActStepExecutor(
|
||||
llm_gateway=self._llm_gateway,
|
||||
messages=messages,
|
||||
model=model,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
confirmation_handler=self._confirmation_handler,
|
||||
)
|
||||
"""Create a PlanExecutor instance that runs steps with ReActStepExecutor."""
|
||||
step_executor = ReActStepExecutor(
|
||||
llm_gateway=self._llm_gateway,
|
||||
messages=messages,
|
||||
model=model,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
confirmation_handler=self._confirmation_handler,
|
||||
)
|
||||
return PlanExecutor(
|
||||
agent_pool=step_executor,
|
||||
max_retries=1,
|
||||
|
|
@ -937,58 +970,6 @@ class PlanExecEngine:
|
|||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
class _LLMStepExecutor:
|
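||||
"""Step executor that calls the LLM directly.
|
||||
|
||||
Serves as a drop-in replacement for PlanExecutor's agent_pool,
|
||||
so that each PlanStep is executed by a direct LLM call rather than through an AgentPool.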
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_gateway: "LLMGateway | None" = None,
|
||||
messages: list[dict[str, str]] | None = None,
|
||||
model: str = "default",
|
||||
system_prompt: str | None = None,
|
||||
tools: list["Tool"] | None = None,
|
||||
):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._messages = messages or []
|
||||
self._model = model
|
||||
self._system_prompt = system_prompt
|
||||
self._tools = tools
|
||||
self._agents: dict[str, _LLMStepAgent] = {}
|
||||
|
||||
async def create_agent_from_skill(self, skill_name: str):
|
||||
"""Create an LLM step agent."""
|
||||
agent = _LLMStepAgent(
|
||||
name=skill_name,
|
||||
llm_gateway=self._llm_gateway,
|
||||
messages=self._messages,
|
||||
model=self._model,
|
||||
system_prompt=self._system_prompt,
|
||||
tools=self._tools,
|
||||
)
|
||||
self._agents[skill_name] = agent
|
||||
return agent
|
||||
|
||||
def get_agent(self, key: str):
|
||||
"""Return a previously created agent."""
|
||||
if key in self._agents:
|
||||
return self._agents[key]
|
||||
# Fallback: create a default agent
|
||||
agent = _LLMStepAgent(
|
||||
name=key,
|
||||
llm_gateway=self._llm_gateway,
|
||||
messages=self._messages,
|
||||
model=self._model,
|
||||
system_prompt=self._system_prompt,
|
||||
tools=self._tools,
|
||||
)
|
||||
self._agents[key] = agent
|
||||
return agent
|
||||
|
||||
|
||||
class ReActStepExecutor:
|
||||
"""Step executor that runs a ReAct loop.
|
||||
|
||||
|
|
@ -1132,69 +1113,3 @@ class _ReActStepAgent:
|
|||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
|
||||
|
||||
class _LLMStepAgent:
|
||||
"""Step agent that calls the LLM directly.
|
||||
|
||||
Sends the PlanStep description to the LLM as the prompt
|
||||
and returns the LLM's response as the step result.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
llm_gateway: "LLMGateway | None" = None,
|
||||
messages: list[dict[str, str]] | None = None,
|
||||
model: str = "default",
|
||||
system_prompt: str | None = None,
|
||||
tools: list["Tool"] | None = None,
|
||||
):
|
||||
self.name = name
|
||||
self._llm_gateway = llm_gateway
|
||||
self._messages = messages or []
|
||||
self._model = model
|
||||
self._system_prompt = system_prompt
|
||||
self._tools = tools
|
||||
|
||||
async def execute(self, task_msg: TaskMessage) -> "TaskResult":
|
||||
"""Execute the step with a direct LLM call."""
|
||||
if self._llm_gateway is None:
|
||||
raise RuntimeError(f"No LLM gateway available for step '{task_msg.task_id}'")
|
||||
|
||||
# Build the step prompt
|
||||
input_data = task_msg.input_data
|
||||
step_name = input_data.get("step_name", task_msg.task_id)
|
||||
step_description = input_data.get("step_description", "")
|
||||
dep_results = input_data.get("dependency_results", {})
|
||||
|
||||
prompt_parts = [f"Execute the following task step:\n\nStep: {step_name}\nDescription: {step_description}"]
|
||||
|
||||
if dep_results:
|
||||
prompt_parts.append(f"\nResults from previous steps:\n{json.dumps(dep_results, ensure_ascii=False, indent=2)}")
|
||||
|
||||
prompt_parts.append("\nProvide a clear, structured result for this step.")
|
||||
|
||||
conversation: list[dict[str, Any]] = []
|
||||
if self._system_prompt:
|
||||
conversation.append({"role": "system", "content": self._system_prompt})
|
||||
# Add the original conversation context
|
||||
for msg in self._messages:
|
||||
conversation.append(msg)
|
||||
conversation.append({"role": "user", "content": "\n".join(prompt_parts)})
|
||||
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=conversation,
|
||||
model=self._model,
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
return TaskResult(
|
||||
task_id=task_msg.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.COMPLETED.value,
|
||||
output_data={"content": response.content or ""},
|
||||
error_message=None,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -74,15 +74,53 @@ class ReActEngine:
|
|||
Lets the agent reason autonomously and choose tools to complete the task.
|
||||
"""
|
||||
|
||||
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool | str = False):
|
||||
# Default core tools that always get full descriptions injected into the
|
||||
# prompt. ``tool_search`` is included so its full description is always
|
||||
# available to the LLM when tiered injection is active.
|
||||
_DEFAULT_CORE_TOOLS: tuple[str, ...] = (
|
||||
"read_file",
|
||||
"write_file",
|
||||
"bash",
|
||||
"search",
|
||||
"tool_search",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_gateway: LLMGateway,
|
||||
max_steps: int = 10,
|
||||
default_timeout: float = 300.0,
|
||||
parallel_tools: bool | str = False,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
verification_enabled: bool = False,
|
||||
verification_commands: list[str] | None = None,
|
||||
core_tool_names: list[str] | None = None,
|
||||
enable_tool_search: bool = True,
|
||||
):
|
||||
if max_steps < 1:
|
||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||
if isinstance(parallel_tools, str) and parallel_tools not in ("auto",):
|
||||
raise ValueError(f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}")
|
||||
raise ValueError(
|
||||
f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}"
|
||||
)
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_steps = max_steps
|
||||
self._default_timeout = default_timeout
|
||||
self._parallel_tools = parallel_tools
|
||||
self._verification_enabled = verification_enabled
|
||||
self._verification_commands = verification_commands
|
||||
# Tiered tool description injection config
|
||||
self._core_tool_names: tuple[str, ...] | None = (
|
||||
tuple(core_tool_names) if core_tool_names is not None else None
|
||||
)
|
||||
self._enable_tool_search = enable_tool_search
|
||||
# Default context compression: keep last 10 turns
|
||||
if compressor is not None:
|
||||
self._compressor = compressor
|
||||
else:
|
||||
from agentkit.core.compressor import ContextCompressor
|
||||
|
||||
self._compressor = ContextCompressor(llm_gateway=llm_gateway, keep_recent=10)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for reuse across conversations.
|
||||
|
|
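Reviewer sketch: the comment on `_DEFAULT_CORE_TOOLS` says core tools always get their full descriptions injected, with `tool_search` kept in the core set so the model can look up the rest. The diff does not show how the prompt is actually rendered, so the block below is only one plausible reading of "tiered injection". `render_tool_section` and the hint wording are hypothetical and not part of agentkit.

```python
from typing import Optional

# Mirrors the tuple added in the diff; everything else in this block is assumed.
DEFAULT_CORE_TOOLS: tuple[str, ...] = (
    "read_file",
    "write_file",
    "bash",
    "search",
    "tool_search",
)


def render_tool_section(
    descriptions: dict[str, str],
    core_tool_names: Optional[tuple[str, ...]] = None,
) -> str:
    """Render a prompt section: full text for core tools, name only for the rest."""
    core = set(core_tool_names if core_tool_names is not None else DEFAULT_CORE_TOOLS)
    lines = []
    for name, description in descriptions.items():
        if name in core:
            lines.append(f"- {name}: {description}")
        else:
            # Non-core tools are listed by name; details are fetched on demand.
            lines.append(f"- {name} (call tool_search for details)")
    return "\n".join(lines)
```

Passing `core_tool_names=None` falls back to the default set, which matches how the constructor stores `None` when no override is given.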
@ -120,10 +158,14 @@ class ReActEngine:
|
|||
4. Return a ReActResult containing the output and trajectory
|
||||
|
||||
Args:
|
||||
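compressor: compression strategy; when None, the instance's default compressor is used
|
||||
cancellation_token: cooperative cancellation token, checked on every loop iteration
|
||||
timeout_seconds: timeout in seconds; 0 disables the timeout, None uses default_timeout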
|
||||
"""
|
||||
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
effective_compressor = compressor if compressor is not None else self._compressor
|
||||
effective_timeout = (
|
||||
timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
)
|
||||
|
||||
try:
|
||||
if effective_timeout > 0:
|
||||
|
|
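Reviewer sketch: U8 gives the engine an instance-level default compressor and lets a per-call `compressor` argument override it through `effective_compressor`. The block below isolates that fallback pattern. `KeepRecent` and `Engine` are hypothetical stand-ins; the real `ContextCompressor` takes an `llm_gateway` and its compression behaviour is not shown in this diff, so the truncation here is only a placeholder.

```python
import asyncio
from typing import Any, Optional

Message = dict[str, Any]


class KeepRecent:
    """Toy strategy (assumed): keep system messages plus the last N other messages."""

    def __init__(self, keep_recent: int = 10) -> None:
        self.keep_recent = keep_recent

    async def compress(self, messages: list[Message]) -> list[Message]:
        system = [m for m in messages if m["role"] == "system"]
        rest = [m for m in messages if m["role"] != "system"]
        # Guard keep_recent == 0, because rest[-0:] would return the whole list.
        tail = rest[-self.keep_recent:] if self.keep_recent > 0 else []
        return system + tail


class Engine:
    def __init__(self, compressor: Optional[KeepRecent] = None) -> None:
        # Instance default, used whenever a call does not pass its own compressor.
        self._compressor = compressor if compressor is not None else KeepRecent(10)

    async def run(
        self, messages: list[Message], compressor: Optional[KeepRecent] = None
    ) -> list[Message]:
        effective = compressor if compressor is not None else self._compressor
        try:
            return await effective.compress(messages)
        except Exception:
            # Same policy as the diff: on failure, continue with the original messages.
            return messages
```

One consequence of this pattern is that callers can no longer disable compression by passing `None`; they would need to pass an explicit no-op strategy instead.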
@ -138,7 +180,7 @@ class ReActEngine:
|
|||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
compressor=effective_compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
confirmation_handler=confirmation_handler,
|
||||
|
|
@ -156,7 +198,7 @@ class ReActEngine:
|
|||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
compressor=effective_compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
confirmation_handler=confirmation_handler,
|
||||
|
|
@ -188,6 +230,8 @@ class ReActEngine:
|
|||
confirmation_handler: Any | None = None,
|
||||
) -> ReActResult:
|
||||
tools = tools or []
|
||||
if tools:
|
||||
tools = self._maybe_add_tool_search(tools)
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
if tool_schemas:
|
||||
tool_names = [s["function"]["name"] for s in tool_schemas]
|
||||
|
|
@ -205,7 +249,9 @@ class ReActEngine:
|
|||
system_prompt = self._build_tool_use_prompt(tools)
|
||||
|
||||
# Telemetry: record agent request
|
||||
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
|
||||
agent_request_counter().add(
|
||||
1, {"agent.name": agent_name, "agent.type": task_type or "react"}
|
||||
)
|
||||
|
||||
# Start telemetry span for the entire agent execution
|
||||
_span_cm = None
|
||||
|
|
@ -250,7 +296,9 @@ class ReActEngine:
|
|||
else:
|
||||
system_prompt = f"## 参考信息\n{memory_context}"
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory retrieval failed, continuing without context: {e}", exc_info=True)
|
||||
logger.warning(
|
||||
f"Memory retrieval failed, continuing without context: {e}", exc_info=True
|
||||
)
|
||||
|
||||
# Build the initial messages
|
||||
conversation: list[dict[str, Any]] = []
|
||||
|
|
@ -263,7 +311,9 @@ class ReActEngine:
|
|||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed, continuing with original messages: {e}")
|
||||
logger.warning(
|
||||
f"Context compression failed, continuing with original messages: {e}"
|
||||
)
|
||||
|
||||
trace_outcome = "success"
|
||||
step = 0
|
||||
|
|
@ -323,9 +373,19 @@ class ReActEngine:
|
|||
# Execute tool calls
|
||||
if self._parallel_tools == "auto" and len(response.tool_calls) > 1:
|
||||
# Auto mode: mixed parallel/serial based on _parallelizable flag
|
||||
parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls))
|
||||
serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set]
|
||||
parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set]
|
||||
parallelizable_set = set(
|
||||
self._get_parallelizable_indices(response.tool_calls)
|
||||
)
|
||||
serial_calls = [
|
||||
(i, tc)
|
||||
for i, tc in enumerate(response.tool_calls)
|
||||
if i not in parallelizable_set
|
||||
]
|
||||
parallel_calls = [
|
||||
(i, tc)
|
||||
for i, tc in enumerate(response.tool_calls)
|
||||
if i in parallelizable_set
|
||||
]
|
||||
|
||||
# Result slots indexed by original position
|
||||
all_results: list[Any] = [None] * len(response.tool_calls)
|
||||
|
|
@ -340,7 +400,10 @@ class ReActEngine:
|
|||
# Execute parallelizable tools in parallel
|
||||
if len(parallel_calls) > 1:
|
||||
para_results = await asyncio.gather(
|
||||
*[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls],
|
||||
*[
|
||||
self._execute_tool(tc.name, tc.arguments, tools)
|
||||
for _, tc in parallel_calls
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
for j, (i, tc) in enumerate(parallel_calls):
|
||||
|
|
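Reviewer sketch: in `"auto"` mode the engine splits tool calls into a serial group and a parallelizable group, runs the parallel group with `asyncio.gather(..., return_exceptions=True)`, and writes every result back into a slot indexed by the call's original position. The block below shows that slot technique on its own. `run_mixed`, the `(name, arguments)` tuple shape, and the `is_parallelizable` predicate are assumptions for illustration; the engine itself works on tool-call objects and `_get_parallelizable_indices`.

```python
import asyncio
from typing import Any, Awaitable, Callable

ToolCall = tuple[str, dict[str, Any]]  # (tool name, arguments); hypothetical shape


async def run_mixed(
    calls: list[ToolCall],
    execute: Callable[[str, dict[str, Any]], Awaitable[Any]],
    is_parallelizable: Callable[[str], bool],
) -> list[Any]:
    """Run serial calls one by one, the rest concurrently, preserving call order."""
    parallel_idx = {i for i, (name, _) in enumerate(calls) if is_parallelizable(name)}
    results: list[Any] = [None] * len(calls)  # one slot per original position

    # Serial calls run first, in their original order.
    for i, (name, args) in enumerate(calls):
        if i not in parallel_idx:
            try:
                results[i] = await execute(name, args)
            except Exception as exc:
                results[i] = {"error": str(exc)}

    # Parallelizable calls run concurrently; exceptions come back as values.
    ordered = sorted(parallel_idx)
    gathered = await asyncio.gather(
        *[execute(*calls[i]) for i in ordered], return_exceptions=True
    )
    for i, res in zip(ordered, gathered):
        results[i] = {"error": str(res)} if isinstance(res, Exception) else res
    return results
```

The sketch always gathers the parallel group, whereas the diff only gathers when more than one call is parallelizable. Either way, indexing by original position keeps the tool messages appended to the conversation in the order the model issued them.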
@ -381,12 +444,17 @@ class ReActEngine:
|
|||
error=tool_error,
|
||||
)
|
||||
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
tc.id, tool_result, compressor, tc.name
|
||||
)
|
||||
conversation.append(tool_msg)
|
||||
elif self._should_execute_parallel(response.tool_calls):
|
||||
# Execute multiple tool calls in parallel (parallel_tools=True)
|
||||
tool_results = await asyncio.gather(
|
||||
*[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls],
|
||||
*[
|
||||
self._execute_tool(tc.name, tc.arguments, tools)
|
||||
for tc in response.tool_calls
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
for idx, tc in enumerate(response.tool_calls):
|
||||
|
|
@ -419,7 +487,9 @@ class ReActEngine:
|
|||
error=tool_error,
|
||||
)
|
||||
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
tc.id, tool_result, compressor, tc.name
|
||||
)
|
||||
conversation.append(tool_msg)
|
||||
else:
|
||||
# Serial execution (single tool or parallel_tools=False)
|
||||
|
|
@ -428,7 +498,9 @@ class ReActEngine:
|
|||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
|
||||
# Handle confirmation flow
|
||||
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
|
||||
if isinstance(tool_result, dict) and tool_result.get(
|
||||
"needs_confirmation"
|
||||
):
|
||||
confirmation_id = tool_result["confirmation_id"]
|
||||
command = tool_result.get("command", "")
|
||||
reason = tool_result.get("reason", "")
|
||||
|
|
@ -436,28 +508,46 @@ class ReActEngine:
|
|||
approved = False
|
||||
if confirmation_handler is not None:
|
||||
try:
|
||||
approved = await confirmation_handler(confirmation_id, command, reason)
|
||||
approved = await confirmation_handler(
|
||||
confirmation_id, command, reason
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Confirmation handler error: {e}")
|
||||
|
||||
if approved:
|
||||
tool = self._find_tool(tc.name, tools)
|
||||
if tool and hasattr(tool, '_is_dangerous'):
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
if tool and hasattr(tool, "_is_dangerous"):
|
||||
clean_args = {
|
||||
k: v
|
||||
for k, v in tc.arguments.items()
|
||||
if not k.startswith("_")
|
||||
}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args)
|
||||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
tool_result = {
|
||||
"error": f"Tool '{tc.name}' execution failed: {e}"
|
||||
}
|
||||
else:
|
||||
# Non-dangerous tool: confirmation was for the overall action,
|
||||
# re-execute with skip flag to avoid re-triggering confirmation
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args = {
|
||||
k: v
|
||||
for k, v in tc.arguments.items()
|
||||
if not k.startswith("_")
|
||||
}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"}
|
||||
tool_result = (
|
||||
await tool.safe_execute(**clean_args)
|
||||
if tool
|
||||
else {"error": f"Tool '{tc.name}' not found"}
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
tool_result = {
|
||||
"error": f"Tool '{tc.name}' execution failed: {e}"
|
||||
}
|
||||
else:
|
||||
tool_result = {
|
||||
"output": "",
|
||||
|
|
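Reviewer sketch: after the user approves a confirmation, both branches in the hunk above rebuild the arguments the same way: drop every key that starts with an underscore, then set `_skip_dangerous_check` explicitly. The helper below factors that out. `build_confirmed_args` is a hypothetical name; the diff inlines the comprehension at each call site.

```python
from typing import Any


def build_confirmed_args(arguments: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of the arguments that is safe to re-execute after approval."""
    # Strip internal, underscore-prefixed keys supplied with the original call.
    clean = {k: v for k, v in arguments.items() if not k.startswith("_")}
    # Only the engine sets this flag, and only after the handler approved.
    clean["_skip_dangerous_check"] = True
    return clean
```

One plausible benefit of stripping first is that a `_skip_dangerous_check` value arriving in the model-generated arguments is discarded, so the flag can only originate from the approval path. The diff does not state this as the motivation.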
@ -496,7 +586,9 @@ class ReActEngine:
|
|||
)
|
||||
|
||||
# Observe: append the tool result to the conversation history
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
tc.id, tool_result, compressor, tc.name
|
||||
)
|
||||
conversation.append(tool_msg)
|
||||
|
||||
# Incremental compression: compress conversation if it's getting long
|
||||
|
|
@ -524,7 +616,9 @@ class ReActEngine:
|
|||
|
||||
for pc in parsed_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
|
||||
tool_result = await self._execute_tool(
|
||||
pc["name"], pc["arguments"], tools
|
||||
)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
|
|
@ -554,7 +648,9 @@ class ReActEngine:
|
|||
)
|
||||
|
||||
# Append the tool result to the conversation history
|
||||
tool_msg = await self._build_tool_result_message(pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"])
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"]
|
||||
)
|
||||
conversation.append(tool_msg)
|
||||
|
||||
# Incremental compression: compress conversation if it's getting long
|
||||
|
|
@ -585,6 +681,35 @@ class ReActEngine:
|
|||
)
|
||||
break
|
||||
|
||||
# Verification: if enabled, run the tests after the final answer
|
||||
if self._verification_enabled and output:
|
||||
try:
|
||||
from agentkit.core.verification_loop import VerificationLoop
|
||||
|
||||
vloop = VerificationLoop(commands=self._verification_commands)
|
||||
vresult = await vloop.verify()
|
||||
if not vresult.passed:
|
||||
# Append the verification failure to the trajectory as a ReActStep
|
||||
verification_step = ReActStep(
|
||||
step=step + 1,
|
||||
action="tool_call",
|
||||
tool_name="verification",
|
||||
arguments={"commands": self._verification_commands},
|
||||
result={
|
||||
"passed": vresult.passed,
|
||||
"errors": vresult.errors,
|
||||
"test_output": vresult.test_output,
|
||||
},
|
||||
content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
|
||||
)
|
||||
trajectory.append(verification_step)
|
||||
logger.info(
|
||||
"Verification failed after final answer, "
|
||||
"appended feedback to trajectory"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Verification loop failed: {e}")
|
||||
|
||||
# On reaching max_steps, return the best output so far
|
||||
if step >= self._max_steps and not output:
|
||||
trace_outcome = "partial"
|
||||
|
|
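Reviewer sketch: the verification block above only relies on three things from `VerificationLoop`: a `commands` constructor argument, an awaitable `verify()`, and a result exposing `passed`, `errors`, and `test_output`. The module itself is not part of this diff, so the block below is a guess at a minimal implementation with that surface, not the actual `agentkit.core.verification_loop` code. It runs each command through the shell and stops at the first non-zero exit code.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class VerificationResult:
    # Field names follow what the diff reads; the class itself is assumed.
    passed: bool
    errors: list[str] = field(default_factory=list)
    test_output: str = ""


async def verify(commands: list[str]) -> VerificationResult:
    """Run each shell command in order; fail on the first non-zero exit code."""
    output_parts: list[str] = []
    for cmd in commands:
        proc = await asyncio.create_subprocess_shell(
            cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,  # merge stderr into the captured output
        )
        stdout, _ = await proc.communicate()
        output_parts.append(stdout.decode(errors="replace"))
        if proc.returncode != 0:
            return VerificationResult(
                passed=False,
                errors=[f"{cmd!r} exited with code {proc.returncode}"],
                test_output="".join(output_parts),
            )
    return VerificationResult(passed=True, test_output="".join(output_parts))
```

Note that in the diff a failed verification is only appended to the trajectory and logged. The final `output` is returned unchanged, so callers that need to act on the failure have to inspect the trajectory for the `verification` step.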
@ -599,6 +724,7 @@ class ReActEngine:
|
|||
# Fallback: ensure output is never an empty string
|
||||
if not output or not output.strip():
|
||||
from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED
|
||||
|
||||
if step >= self._max_steps:
|
||||
output = MAX_STEPS_REACHED
|
||||
else:
|
||||
|
|
@ -660,8 +786,14 @@ class ReActEngine:
|
|||
|
||||
Same logic as execute() but yields events at each step instead of
|
||||
accumulating a result.
|
||||
|
||||
Args:
|
||||
compressor: compression strategy; when None, the instance's default compressor is used
|
||||
"""
|
||||
effective_compressor = compressor if compressor is not None else self._compressor
|
||||
tools = tools or []
|
||||
if tools:
|
||||
tools = self._maybe_add_tool_search(tools)
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
if tool_schemas:
|
||||
tool_names = [s["function"]["name"] for s in tool_schemas]
|
||||
|
|
@ -679,7 +811,9 @@ class ReActEngine:
|
|||
system_prompt = self._build_tool_use_prompt(tools)
|
||||
|
||||
# Telemetry: record agent request
|
||||
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
|
||||
agent_request_counter().add(
|
||||
1, {"agent.name": agent_name, "agent.type": task_type or "react"}
|
||||
)
|
||||
|
||||
# Start telemetry span for the entire agent execution
|
||||
_span_cm = None
|
||||
|
|
@ -726,11 +860,13 @@ class ReActEngine:
|
|||
conversation.extend(messages)
|
||||
|
||||
# Context compression: compress an overly long conversation history
|
||||
if compressor:
|
||||
if effective_compressor:
|
||||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
conversation = await effective_compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed, continuing with original messages: {e}")
|
||||
logger.warning(
|
||||
f"Context compression failed, continuing with original messages: {e}"
|
||||
)
|
||||
|
||||
trajectory: list[ReActStep] = []
|
||||
total_tokens = 0
|
||||
|
|
@ -738,7 +874,9 @@ class ReActEngine:
|
|||
output = ""
|
||||
trace_outcome = "success"
|
||||
_stream_start = time.monotonic()
|
||||
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
effective_timeout = (
|
||||
timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
)
|
||||
|
||||
try:
|
||||
while step < self._max_steps:
|
||||
|
|
@ -836,19 +974,44 @@ class ReActEngine:
|
|||
conversation.append(assistant_msg)
|
||||
|
||||
# Execute tool calls with parallel support
|
||||
if self._parallel_tools and len(response.tool_calls) > 1 and self._should_execute_parallel(response.tool_calls):
|
||||
if (
|
||||
self._parallel_tools
|
||||
and len(response.tool_calls) > 1
|
||||
and self._should_execute_parallel(response.tool_calls)
|
||||
):
|
||||
# Parallel execution path
|
||||
parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls)) if self._parallel_tools == "auto" else set(range(len(response.tool_calls)))
|
||||
serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set]
|
||||
parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set]
|
||||
parallelizable_set = (
|
||||
set(self._get_parallelizable_indices(response.tool_calls))
|
||||
if self._parallel_tools == "auto"
|
||||
else set(range(len(response.tool_calls)))
|
||||
)
|
||||
serial_calls = [
|
||||
(i, tc)
|
||||
for i, tc in enumerate(response.tool_calls)
|
||||
if i not in parallelizable_set
|
||||
]
|
||||
parallel_calls = [
|
||||
(i, tc)
|
||||
for i, tc in enumerate(response.tool_calls)
|
||||
if i in parallelizable_set
|
||||
]
|
||||
|
||||
all_results: list[Any] = [None] * len(response.tool_calls)
|
||||
|
||||
# Execute serial tools first (handles confirmation flow)
|
||||
for i, tc in serial_calls:
|
||||
yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments})
|
||||
yield ReActEvent(
|
||||
event_type="tool_call",
|
||||
step=step,
|
||||
data={"tool_name": tc.name, "arguments": tc.arguments},
|
||||
)
|
||||
tool_start = time.monotonic()
|
||||
tool_result, confirm_events = await self._execute_tool_with_confirmation(tc, tools, step, confirmation_handler)
|
||||
(
|
||||
tool_result,
|
||||
confirm_events,
|
||||
) = await self._execute_tool_with_confirmation(
|
||||
tc, tools, step, confirmation_handler
|
||||
)
|
||||
for ev in confirm_events:
|
||||
yield ev
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
|
@ -857,7 +1020,10 @@ class ReActEngine:
|
|||
# Execute parallelizable tools concurrently
|
||||
if len(parallel_calls) > 1:
|
||||
para_results = await asyncio.gather(
|
||||
*[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls],
|
||||
*[
|
||||
self._execute_tool(tc.name, tc.arguments, tools)
|
||||
for _, tc in parallel_calls
|
||||
],
|
||||
return_exceptions=True,
|
||||
)
|
||||
for j, (i, tc) in enumerate(parallel_calls):
|
||||
|
|
@ -873,19 +1039,45 @@ class ReActEngine:
|
|||
# Process all results in original order
|
||||
for i, tc in enumerate(response.tool_calls):
|
||||
tc_obj, tool_result, tool_duration_ms = all_results[i]
|
||||
yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments})
|
||||
yield ReActEvent(
|
||||
event_type="tool_call",
|
||||
step=step,
|
||||
data={"tool_name": tc.name, "arguments": tc.arguments},
|
||||
)
|
||||
|
||||
react_step = ReActStep(step=step, action="tool_call", tool_name=tc.name, arguments=tc.arguments, result=tool_result, tokens=step_tokens)
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
arguments=tc.arguments,
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(step=step, action="tool_call", tool_name=tc.name, input_data=tc.arguments, output_data=tool_result, duration_ms=tool_duration_ms, tokens_used=0, error=tool_error)
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
yield ReActEvent(event_type="tool_result", step=step, data={"tool_name": tc.name, "result": tool_result})
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
yield ReActEvent(
|
||||
event_type="tool_result",
|
||||
step=step,
|
||||
data={"tool_name": tc.name, "result": tool_result},
|
||||
)
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
tc.id, tool_result, effective_compressor, tc.name
|
||||
)
|
||||
conversation.append(tool_msg)
|
||||
else:
|
||||
# Serial execution path (with confirmation flow)
|
||||
|
|
@ -902,7 +1094,9 @@ class ReActEngine:
|
|||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
# Detect a confirmation request returned by the tool
|
||||
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
|
||||
if isinstance(tool_result, dict) and tool_result.get(
|
||||
"needs_confirmation"
|
||||
):
|
||||
confirmation_id = tool_result["confirmation_id"]
|
||||
command = tool_result.get("command", "")
|
||||
reason = tool_result.get("reason", "")
|
||||
|
|
@ -923,16 +1117,22 @@ class ReActEngine:
|
|||
approved = False
|
||||
if confirmation_handler is not None:
|
||||
try:
|
||||
approved = await confirmation_handler(confirmation_id, command, reason)
|
||||
approved = await confirmation_handler(
|
||||
confirmation_id, command, reason
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Confirmation handler error: {e}")
|
||||
|
||||
if approved:
|
||||
# User approved: use a per-call override to bypass the safety check
|
||||
tool = self._find_tool(tc.name, tools)
|
||||
if tool and hasattr(tool, '_is_dangerous'):
|
||||
if tool and hasattr(tool, "_is_dangerous"):
|
||||
# Strip internal metadata and pass skip_dangerous_check flag
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args = {
|
||||
k: v
|
||||
for k, v in tc.arguments.items()
|
||||
if not k.startswith("_")
|
||||
}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args)
|
||||
|
|
@ -940,12 +1140,22 @@ class ReActEngine:
|
|||
pass # No shared state mutation needed
|
||||
else:
|
||||
# Non-dangerous tool: re-execute with skip flag
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args = {
|
||||
k: v
|
||||
for k, v in tc.arguments.items()
|
||||
if not k.startswith("_")
|
||||
}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"}
|
||||
tool_result = (
|
||||
await tool.safe_execute(**clean_args)
|
||||
if tool
|
||||
else {"error": f"Tool '{tc.name}' not found"}
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
tool_result = {
|
||||
"error": f"Tool '{tc.name}' execution failed: {e}"
|
||||
}
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
|
|
@ -964,7 +1174,10 @@ class ReActEngine:
|
|||
yield ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": False},
|
||||
data={
|
||||
"confirmation_id": confirmation_id,
|
||||
"approved": False,
|
||||
},
|
||||
)
|
||||
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
|
@ -1001,13 +1214,15 @@ class ReActEngine:
|
|||
data={"tool_name": tc.name, "result": tool_result},
|
||||
)
|
||||
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
tc.id, tool_result, effective_compressor, tc.name
|
||||
)
|
||||
conversation.append(tool_msg)
|
||||
|
||||
# Incremental compression: compress conversation if it's getting long
|
||||
if self._should_compress(conversation, compressor):
|
||||
if self._should_compress(conversation, effective_compressor):
|
||||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
conversation = await effective_compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Incremental compression failed: {e}")
|
||||
|
||||
|
|
@ -1033,16 +1248,20 @@ class ReActEngine:
|
|||
data={"tool_name": pc["name"], "arguments": pc["arguments"]},
|
||||
)
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(pc["name"], pc["arguments"], tools)
|
||||
tool_result = await self._execute_tool(
|
||||
pc["name"], pc["arguments"], tools
|
||||
)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
trajectory.append(ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=pc["name"],
|
||||
arguments=pc["arguments"],
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
))
|
||||
trajectory.append(
|
||||
ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=pc["name"],
|
||||
arguments=pc["arguments"],
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
)
|
||||
# Record the tool-call step
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
|
|
@ -1064,14 +1283,17 @@ class ReActEngine:
|
|||
data={"tool_name": pc["name"], "result": tool_result},
|
||||
)
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"]
|
||||
pc.get("id", f"text_tc_{step}"),
|
||||
tool_result,
|
||||
effective_compressor,
|
||||
pc["name"],
|
||||
)
|
||||
conversation.append(tool_msg)
|
||||
|
||||
# Incremental compression: compress conversation if it's getting long
|
||||
if self._should_compress(conversation, compressor):
|
||||
if self._should_compress(conversation, effective_compressor):
|
||||
try:
|
||||
conversation = await compressor.compress(conversation)
|
||||
conversation = await effective_compressor.compress(conversation)
|
||||
except Exception as e:
|
||||
logger.warning(f"Incremental compression failed: {e}")
|
||||
else:
|
||||
|
|
@ -1106,6 +1328,46 @@ class ReActEngine:
|
|||
)
|
||||
break
|
||||
|
||||
# Verification: if enabled, run the tests after the final answer
|
||||
if self._verification_enabled and output:
|
||||
try:
|
||||
from agentkit.core.verification_loop import VerificationLoop
|
||||
|
||||
vloop = VerificationLoop(commands=self._verification_commands)
|
||||
vresult = await vloop.verify()
|
||||
if not vresult.passed:
|
||||
verification_step = ReActStep(
|
||||
step=step + 1,
|
||||
action="tool_call",
|
||||
tool_name="verification",
|
||||
arguments={"commands": self._verification_commands},
|
||||
result={
|
||||
"passed": vresult.passed,
|
||||
"errors": vresult.errors,
|
||||
"test_output": vresult.test_output,
|
||||
},
|
||||
content=(f"Verification failed:\n{vresult.test_output[:2000]}"),
|
||||
)
|
||||
trajectory.append(verification_step)
|
||||
yield ReActEvent(
|
||||
event_type="tool_result",
|
||||
step=step + 1,
|
||||
data={
|
||||
"tool_name": "verification",
|
||||
"result": {
|
||||
"passed": vresult.passed,
|
||||
"errors": vresult.errors,
|
||||
"test_output": vresult.test_output,
|
||||
},
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
"Verification failed after final answer, "
|
||||
"appended feedback to trajectory"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Verification loop failed: {e}")
|
||||
|
||||
if step >= self._max_steps and not output:
|
||||
trace_outcome = "partial"
|
||||
if trajectory and trajectory[-1].content:
|
||||
|
|
@ -1129,6 +1391,7 @@ class ReActEngine:
|
|||
# Fallback: ensure output is never an empty string
|
||||
if not output or not output.strip():
|
||||
from agentkit.core.fallback import EMPTY_LLM_RESPONSE, MAX_STEPS_REACHED
|
||||
|
||||
if step >= self._max_steps:
|
||||
output = MAX_STEPS_REACHED
|
||||
else:
|
||||
|
|
@ -1187,33 +1450,40 @@ class ReActEngine:
|
|||
schemas.append(schema)
|
||||
return schemas
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_use_prompt(tools: list[Tool]) -> str:
|
||||
"""Build prompt-based tool calling instructions for LLMs that don't
|
||||
support native function calling (e.g., Bailian Coding, Qwen).
|
||||
def _build_tool_use_prompt(self, tools: list[Tool]) -> str:
|
||||
"""Build prompt-based tool calling instructions with tiered injection.
|
||||
|
||||
Instructs the LLM to use <tool_use> XML format for tool invocation.
|
||||
This follows the Hermes pattern: model-agnostic prompt-based tool calling.
|
||||
Core tools (defined by ``self._core_tool_names`` or
|
||||
:attr:`_DEFAULT_CORE_TOOLS`) get full descriptions (name +
|
||||
description + parameters). Extended tools get only name + a
|
||||
one-line description. When ``tool_search`` is present alongside
|
||||
extended tools, a hint is added telling the LLM to call
|
||||
``tool_search`` for full parameter details.
|
||||
|
||||
Instructs the LLM to use ``<tool_use>`` XML format for tool
|
||||
invocation (Hermes pattern: model-agnostic prompt-based tool calling).
|
||||
"""
|
||||
tool_descriptions = []
|
||||
for tool in tools:
|
||||
params_desc = ""
|
||||
if tool.input_schema:
|
||||
props = tool.input_schema.get("properties", {})
|
||||
required = tool.input_schema.get("required", [])
|
||||
param_parts = []
|
||||
for pname, pinfo in props.items():
|
||||
ptype = pinfo.get("type", "string")
|
||||
pdesc = pinfo.get("description", "")
|
||||
req_flag = " (required)" if pname in required else ""
|
||||
param_parts.append(f" - {pname}: {ptype}{req_flag} — {pdesc}")
|
||||
if param_parts:
|
||||
params_desc = "\n".join(param_parts)
|
||||
tool_descriptions.append(
|
||||
f"- {tool.name}: {tool.description}\n{params_desc}"
|
||||
core_names = set(self._core_tool_names or self._DEFAULT_CORE_TOOLS)
|
||||
core_tools = [t for t in tools if t.name in core_names]
|
||||
extended_tools = [t for t in tools if t.name not in core_names]
|
||||
|
||||
sections: list[str] = []
|
||||
if core_tools:
|
||||
sections.append(self._render_core_tools(core_tools))
|
||||
if extended_tools:
|
||||
sections.append(self._render_extended_tools(extended_tools))
|
||||
|
||||
tools_text = "\n\n".join(sections)
|
||||
|
||||
has_tool_search = any(t.name == "tool_search" for t in tools)
|
||||
search_hint = ""
|
||||
if has_tool_search and extended_tools:
|
||||
search_hint = (
|
||||
"\n\n注意:上方「扩展工具」仅显示名称和简短描述。"
|
||||
'如需使用某个扩展工具,请先调用 tool_search(query="关键词") '
|
||||
"获取其完整参数说明。"
|
||||
)
|
||||
|
||||
tools_text = "\n\n".join(tool_descriptions)
|
||||
return (
|
||||
"## 可用工具\n\n"
|
||||
"你可以使用以下工具来完成任务。当需要调用工具时,使用以下格式:\n\n"
|
||||
|
|
@ -1225,9 +1495,65 @@ class ReActEngine:
|
|||
"2. 等待工具返回结果后再决定下一步\n"
|
||||
"3. 如果不需要工具就能回答,直接回答即可\n"
|
||||
"4. 不要在回答中重复工具的输出,而是基于结果给出有用的总结\n\n"
|
||||
f"工具列表:\n\n{tools_text}"
|
||||
f"工具列表:\n\n{tools_text}{search_hint}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _render_core_tools(tools: list[Tool]) -> str:
|
||||
"""Render core tools with full descriptions (name + description + parameters)."""
|
||||
descriptions: list[str] = []
|
||||
for tool in tools:
|
||||
params_desc = ""
|
||||
if tool.input_schema:
|
||||
props = tool.input_schema.get("properties", {})
|
||||
required = tool.input_schema.get("required", [])
|
||||
param_parts: list[str] = []
|
||||
for pname, pinfo in props.items():
|
||||
ptype = pinfo.get("type", "string")
|
||||
pdesc = pinfo.get("description", "")
|
||||
req_flag = " (required)" if pname in required else ""
|
||||
param_parts.append(f" - {pname}: {ptype}{req_flag} — {pdesc}")
|
||||
if param_parts:
|
||||
params_desc = "\n".join(param_parts)
|
||||
descriptions.append(f"- {tool.name}: {tool.description}\n{params_desc}")
|
||||
return "### 核心工具(完整描述)\n\n" + "\n\n".join(descriptions)
|
||||
|
||||
@staticmethod
|
||||
def _render_extended_tools(tools: list[Tool]) -> str:
|
||||
"""Render extended tools with name + one-line description only."""
|
||||
lines: list[str] = []
|
||||
for tool in tools:
|
||||
desc = tool.description.strip().split("\n")[0]
|
||||
if len(desc) > 100:
|
||||
desc = desc[:97] + "..."
|
||||
lines.append(f"- {tool.name}: {desc}")
|
||||
return "### 扩展工具(仅名称和简短描述,使用 tool_search 获取详情)\n\n" + "\n".join(lines)
|
||||
|
||||
def _maybe_add_tool_search(self, tools: list[Tool]) -> list[Tool]:
|
||||
"""Add ``tool_search`` tool if enabled and there are extended tools.
|
||||
|
||||
Builds a :class:`ToolSearchIndex` from the extended tools so the
|
||||
LLM can discover full tool descriptions on demand via BM25 search.
|
||||
If all tools are core tools, or ``tool_search`` is already present,
|
||||
or ``enable_tool_search`` is False, the list is returned unchanged.
|
||||
"""
|
||||
if not self._enable_tool_search:
|
||||
return tools
|
||||
if any(t.name == "tool_search" for t in tools):
|
||||
return tools
|
||||
|
||||
core_names = set(self._core_tool_names or self._DEFAULT_CORE_TOOLS)
|
||||
extended_tools = [t for t in tools if t.name not in core_names]
|
||||
if not extended_tools:
|
||||
return tools
|
||||
|
||||
from agentkit.tools.builtin import ToolSearchTool
|
||||
from agentkit.tools.search import ToolSearchIndex
|
||||
|
||||
index = ToolSearchIndex(extended_tools)
|
||||
search_tool = ToolSearchTool(search_index=index)
|
||||
return tools + [search_tool]
|
||||
|
||||
@staticmethod
|
||||
def _build_response_from_stream(
|
||||
content: str,
|
||||
|
|
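The tiered injection described in the `_build_tool_use_prompt` docstring can be illustrated with a minimal, self-contained sketch. The `Tool` dataclass, the function names, and the `limit` parameter below are stand-ins invented for illustration and are not the project's API; only the split into core and extended tools and the 100-character cap on the first description line follow the diff.

```python
from dataclasses import dataclass, field


@dataclass
class Tool:
    """Hypothetical stand-in for the project's Tool type (assumed fields)."""

    name: str
    description: str
    input_schema: dict = field(default_factory=dict)


def render_core(tools: list[Tool]) -> str:
    """Full entry: name, description, and one line per parameter."""
    entries = []
    for tool in tools:
        props = tool.input_schema.get("properties", {})
        required = tool.input_schema.get("required", [])
        lines = [f"- {tool.name}: {tool.description}"]
        for pname, pinfo in props.items():
            flag = " (required)" if pname in required else ""
            lines.append(f"  - {pname}: {pinfo.get('type', 'string')}{flag}")
        entries.append("\n".join(lines))
    return "\n\n".join(entries)


def render_extended(tools: list[Tool], limit: int = 100) -> str:
    """Short entry: name plus the first description line, capped at `limit` characters."""
    lines = []
    for tool in tools:
        desc = tool.description.strip().split("\n")[0]
        if len(desc) > limit:
            desc = desc[: limit - 3] + "..."
        lines.append(f"- {tool.name}: {desc}")
    return "\n".join(lines)


def build_prompt(tools: list[Tool], core_names: set[str]) -> str:
    """Join the core section and the extended section, skipping empty ones."""
    core = [t for t in tools if t.name in core_names]
    extended = [t for t in tools if t.name not in core_names]
    sections = []
    if core:
        sections.append(render_core(core))
    if extended:
        sections.append(render_extended(extended))
    return "\n\n".join(sections)
```

Keeping the extended entries to a single line is what keeps the prompt small as the tool list grows; the full parameter list is then fetched on demand, which in the diff is the job of `tool_search`.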
@ -1237,6 +1563,7 @@ class ReActEngine:
|
|||
) -> LLMResponse:
|
||||
"""Build an LLMResponse from accumulated stream chunks."""
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
|
||||
if usage is None:
|
||||
usage = TokenUsage()
|
||||
return LLMResponse(
|
||||
|
|
@ -1256,7 +1583,9 @@ class ReActEngine:
|
|||
# Default token threshold for incremental compression
|
||||
_DEFAULT_COMPRESS_THRESHOLD = 8000
|
||||
|
||||
def _should_compress(self, conversation: list[dict], compressor: "CompressionStrategy | None") -> bool:
|
||||
def _should_compress(
|
||||
self, conversation: list[dict], compressor: "CompressionStrategy | None"
|
||||
) -> bool:
|
||||
"""Check whether incremental compression is needed."""
|
||||
if not compressor:
|
||||
return False
|
||||
|
|
@ -1331,16 +1660,18 @@ class ReActEngine:
|
|||
command = tool_result.get("command", "")
|
||||
reason = tool_result.get("reason", "")
|
||||
|
||||
events.append(ReActEvent(
|
||||
event_type="confirmation_request",
|
||||
step=step,
|
||||
data={
|
||||
"confirmation_id": confirmation_id,
|
||||
"tool_name": tc.name,
|
||||
"command": command,
|
||||
"reason": reason,
|
||||
},
|
||||
))
|
||||
events.append(
|
||||
ReActEvent(
|
||||
event_type="confirmation_request",
|
||||
step=step,
|
||||
data={
|
||||
"confirmation_id": confirmation_id,
|
||||
"tool_name": tc.name,
|
||||
"command": command,
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Wait for user confirmation
|
||||
approved = False
|
||||
|
|
@ -1353,7 +1684,7 @@ class ReActEngine:
|
|||
if approved:
|
||||
# User approved: re-execute with _skip_dangerous_check
|
||||
tool = self._find_tool(tc.name, tools)
|
||||
if tool and hasattr(tool, '_is_dangerous'):
|
||||
if tool and hasattr(tool, "_is_dangerous"):
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
|
|
@ -1365,15 +1696,21 @@ class ReActEngine:
|
|||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args) if tool else {"error": f"Tool '{tc.name}' not found"}
|
||||
tool_result = (
|
||||
await tool.safe_execute(**clean_args)
|
||||
if tool
|
||||
else {"error": f"Tool '{tc.name}' not found"}
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
|
||||
events.append(ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": True},
|
||||
))
|
||||
events.append(
|
||||
ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": True},
|
||||
)
|
||||
)
|
||||
else:
|
||||
# User rejected
|
||||
tool_result = {
|
||||
|
|
@ -1383,11 +1720,13 @@ class ReActEngine:
|
|||
"error_type": "permission_denied",
|
||||
"message": f"用户拒绝执行命令: {command[:100]}",
|
||||
}
|
||||
events.append(ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": False},
|
||||
))
|
||||
events.append(
|
||||
ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": False},
|
||||
)
|
||||
)
|
||||
|
||||
return tool_result, events
|
||||
|
||||
|
|
@ -1418,7 +1757,7 @@ class ReActEngine:
|
|||
"""
|
||||
indices = []
|
||||
for i, tc in enumerate(tool_calls):
|
||||
args = tc.arguments if hasattr(tc, 'arguments') else {}
|
||||
args = tc.arguments if hasattr(tc, "arguments") else {}
|
||||
if isinstance(args, dict) and args.get("_parallelizable") is True:
|
||||
indices.append(i)
|
||||
return indices
|
||||
|
|
@ -1434,9 +1773,7 @@ class ReActEngine:
|
|||
calls: list[dict[str, Any]] = []
|
||||
|
||||
# Format 1: Action: tool_name(args)
|
||||
action_pattern = re.compile(
|
||||
r"Action:\s*(\w+)\((.+?)\)", re.DOTALL
|
||||
)
|
||||
action_pattern = re.compile(r"Action:\s*(\w+)\((.+?)\)", re.DOTALL)
|
||||
for match in action_pattern.finditer(content):
|
||||
name = match.group(1)
|
||||
args_str = match.group(2)
|
||||
|
|
@ -1450,9 +1787,7 @@ class ReActEngine:
|
|||
return calls
|
||||
|
||||
# Format 2: ```tool\n{"name": "...", "arguments": {...}}\n```
|
||||
code_block_pattern = re.compile(
|
||||
r"```tool\s*\n(.*?)\n\s*```", re.DOTALL
|
||||
)
|
||||
code_block_pattern = re.compile(r"```tool\s*\n(.*?)\n\s*```", re.DOTALL)
|
||||
for match in code_block_pattern.finditer(content):
|
||||
json_str = match.group(1).strip()
|
||||
try:
|
||||
|
|
@ -1469,9 +1804,7 @@ class ReActEngine:
|
|||
|
||||
# Format 3: <tool_use>\n{"name": "...", "arguments": {...}}\n</tool_use>
|
||||
# Compatible with the tool-call format that models such as Anthropic/Qwen emulate in plain text
|
||||
tool_use_pattern = re.compile(
|
||||
r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL
|
||||
)
|
||||
tool_use_pattern = re.compile(r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL)
|
||||
for match in tool_use_pattern.finditer(content):
|
||||
json_str = match.group(1).strip()
|
||||
try:
|
||||
|
|
|
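The third text format handled above (a JSON object wrapped in `<tool_use>` tags) can be parsed roughly as follows. This is a sketch under stated assumptions, not the engine's actual method: the function name `parse_tool_use_blocks` is hypothetical, and silently skipping malformed JSON is an assumption, since the diff does not show what the `try` block does on failure. Only the regular expression is taken from the diff.

```python
import json
import re
from typing import Any

# Same pattern as in the diff: non-greedy body, DOTALL so the JSON may span lines.
TOOL_USE_PATTERN = re.compile(r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL)


def parse_tool_use_blocks(content: str) -> list[dict[str, Any]]:
    """Extract {"name", "arguments"} dicts from <tool_use> blocks in model text."""
    calls: list[dict[str, Any]] = []
    for match in TOOL_USE_PATTERN.finditer(content):
        json_str = match.group(1).strip()
        try:
            payload = json.loads(json_str)
        except json.JSONDecodeError:
            # Assumption: a malformed block is ignored rather than raised.
            continue
        if isinstance(payload, dict) and isinstance(payload.get("name"), str):
            calls.append(
                {"name": payload["name"], "arguments": payload.get("arguments", {})}
            )
    return calls
```

The non-greedy `(.*?)` matters when a response contains several blocks: a greedy match would swallow everything between the first opening tag and the last closing tag.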
|||
|
|
@ -1,4 +1,11 @@
|
|||
"""Expert Team routing — resolves user input to ExpertTeam configuration."""
|
||||
"""Expert Team routing — resolves @team prefix to ExpertTeam configuration.
|
||||
|
||||
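Simplification notes (U3):
|
||||
- Team mode is triggered only by the @team prefix
|
||||
- Complexity-based automatic suggestions are removed
|
||||
- @team:expert1,expert2 still specifies expert members
|
||||
- resolve_expert_configs still resolves expert configurations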
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
|
@ -21,37 +28,37 @@ MAX_EXPERTS = 10 # Maximum number of experts in a team
|
|||
|
||||
@dataclass
|
||||
class ExpertTeamRoutingResult:
|
||||
"""Result of expert team routing resolution."""
|
||||
"""Result of expert team routing resolution.
|
||||
|
||||
In hub-and-spoke mode, routing is triggered exclusively by the @team prefix.
|
||||
"""
|
||||
|
||||
matched: bool = False
|
||||
team_mode: bool = False
|
||||
specified_experts: list[str] = field(default_factory=list)
|
||||
task_content: str = ""
|
||||
auto_compose: bool = False
|
||||
complexity: float = 0.0
|
||||
match_method: str = "" # "explicit_team" | "complexity_suggestion"
|
||||
match_method: str = "" # "explicit_team" | ""
|
||||
|
||||
|
||||
class ExpertTeamRouter:
|
||||
"""Routes user input to Expert Team mode.
|
||||
"""Routes user input to Expert Team mode via @team prefix.
|
||||
|
||||
Supports:
|
||||
- @team prefix → trigger team mode
|
||||
- @team:analyst,strategist → specify team members
|
||||
- High complexity → suggest team mode upgrade
|
||||
- @team prefix → trigger team mode (auto-compose members)
|
||||
- @team:analyst,strategist → specify team members by name
|
||||
"""
|
||||
|
||||
COMPLEXITY_THRESHOLD = 0.7 # Above this, suggest team mode
|
||||
|
||||
def __init__(self, template_registry: ExpertTemplateRegistry | None = None):
|
||||
self._registry = template_registry or ExpertTemplateRegistry()
|
||||
|
||||
def resolve(self, content: str, complexity: float = 0.0) -> ExpertTeamRoutingResult:
|
||||
def resolve(self, content: str) -> ExpertTeamRoutingResult:
|
||||
"""Resolve user input to an ExpertTeamRoutingResult.
|
||||
|
||||
Only @team prefix triggers team mode. No complexity-based suggestion.
|
||||
|
||||
Args:
|
||||
content: User's input message
|
||||
complexity: Pre-computed complexity score (0.0-1.0)
|
||||
|
||||
Returns:
|
||||
ExpertTeamRoutingResult with routing decision
|
||||
|
|
@ -94,27 +101,15 @@ class ExpertTeamRouter:
|
|||
|
||||
return result
|
||||
|
||||
# Check complexity-based suggestion
|
||||
if complexity >= self.COMPLEXITY_THRESHOLD:
|
||||
result.matched = True
|
||||
result.team_mode = True
|
||||
result.auto_compose = True
|
||||
result.complexity = complexity
|
||||
result.task_content = content
|
||||
result.match_method = "complexity_suggestion"
|
||||
return result
|
||||
|
||||
# Not a team mode request
|
||||
result.matched = False
|
||||
result.team_mode = False
|
||||
result.task_content = content
|
||||
result.complexity = complexity
|
||||
return result
|
||||
|
||||
def can_handle(self, content: str) -> bool:
|
||||
"""Check whether any registered expert template can handle the given content.
|
||||
|
||||
Used by CostAwareRouter to decide whether to upgrade REACT → TEAM_COLLAB.
|
||||
Returns True if at least one template's name or description overlaps with
|
||||
content tokens, or if any templates exist (auto-compose can always form a team).
|
||||
"""
|
||||
|
|
|
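After this change, team mode is reachable only through an explicit `@team` prefix. The diff does not show the prefix-matching code, so the sketch below is a guess at one way to implement the two documented forms (`@team ...` and `@team:analyst,strategist ...`). The regular expression, the `TeamRouting` dataclass, and the module-level `resolve` function are all hypothetical; only the field names, `MAX_EXPERTS = 10`, and the `"explicit_team"` label come from the diff.

```python
import re
from dataclasses import dataclass, field

MAX_EXPERTS = 10  # value taken from the diff's hunk header context

# Assumed pattern: "@team", optional ":name1,name2", then whitespace or end of input.
_TEAM_PREFIX = re.compile(r"^@team(?::([\w,\-]+))?(?:\s+|$)(.*)", re.DOTALL)


@dataclass
class TeamRouting:
    """Simplified stand-in for ExpertTeamRoutingResult."""

    matched: bool = False
    team_mode: bool = False
    specified_experts: list[str] = field(default_factory=list)
    task_content: str = ""
    auto_compose: bool = False
    match_method: str = ""


def resolve(content: str) -> TeamRouting:
    """Return a team routing result; anything without the prefix is left unmatched."""
    m = _TEAM_PREFIX.match(content.strip())
    if m is None:
        return TeamRouting(task_content=content)
    experts = [e for e in (m.group(1) or "").split(",") if e][:MAX_EXPERTS]
    return TeamRouting(
        matched=True,
        team_mode=True,
        specified_experts=experts,
        task_content=m.group(2).strip(),
        auto_compose=not experts,  # no names given: let the system compose the team
        match_method="explicit_team",
    )
```

Requiring whitespace or end of input after the prefix keeps words such as `@teammate` from triggering team mode by accident.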
|||
|
|
@ -155,6 +155,7 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
# Restore conversation history from persistent store (async, in lifespan)
|
||||
from agentkit.server.routes.portal import _conversation_store
|
||||
|
||||
await _conversation_store.restore_from_store()
|
||||
|
||||
# In GUI mode, ensure a default chat agent exists with memory + tools
|
||||
|
|
@ -579,13 +580,13 @@ def create_app(
|
|||
app.state.quality_gate = QualityGate()
|
||||
app.state.output_standardizer = OutputStandardizer()
|
||||
|
||||
# Initialize SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
|
||||
from agentkit.chat.simple_router import SimpleRouter
|
||||
# Initialize RequestPreprocessor (minimal preprocessing: @skill prefix + greeting regex + REACT)
|
||||
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
||||
|
||||
simple_router = SimpleRouter(
|
||||
request_preprocessor = RequestPreprocessor(
|
||||
skill_registry=app.state.skill_registry,
|
||||
)
|
||||
app.state.simple_router = simple_router
|
||||
app.state.request_preprocessor = request_preprocessor
|
||||
|
||||
# Initialize OrganizationContext from AgentPool + SkillRegistry
|
||||
from agentkit.org.context import OrganizationContext
|
||||
|
|
@ -606,39 +607,6 @@ def create_app(
|
|||
alignment_guard = AlignmentGuard(config=alignment_config, llm_gateway=app.state.llm_gateway)
|
||||
app.state.alignment_guard = alignment_guard
|
||||
|
||||
# CostAwareRouter is no longer used by portal/chat routes (replaced by SimpleRouter).
|
||||
# It is kept on app.state for backward compatibility with any external consumers.
|
||||
# To re-enable, set router.legacy_cost_aware_router: true in agentkit.yaml.
|
||||
router_conf = server_config.router if server_config and server_config.router else {}
|
||||
if router_conf.get("legacy_cost_aware_router"):
|
||||
from agentkit.chat.skill_routing import CostAwareRouter
|
||||
|
||||
auction_enabled = False
|
||||
if server_config and hasattr(server_config, "marketplace") and server_config.marketplace:
|
||||
auction_enabled = server_config.marketplace.get("auction_enabled", False)
|
||||
|
||||
semantic_router = None
|
||||
if router_conf.get("semantic", {}).get("enabled"):
|
||||
try:
|
||||
from agentkit.chat.semantic_router import SemanticRouter
|
||||
|
||||
semantic_router = SemanticRouter(
|
||||
embedder=app.state.llm_gateway._embedder,
|
||||
similarity_high=router_conf["semantic"].get("similarity_high", 0.85),
|
||||
similarity_low=router_conf["semantic"].get("similarity_low", 0.6),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize semantic router: {e}")
|
||||
|
||||
cost_aware_router = CostAwareRouter(
|
||||
llm_gateway=app.state.llm_gateway,
|
||||
org_context=org_context,
|
||||
auction_enabled=auction_enabled,
|
||||
classifier=router_conf.get("classifier", "heuristic"),
|
||||
merged_llm_classify=router_conf.get("merged_llm_classify", True),
|
||||
semantic_router=semantic_router,
|
||||
)
|
||||
app.state.cost_aware_router = cost_aware_router
|
||||
# Initialize task store from config
|
||||
ts_config = server_config.task_store if server_config else {}
|
||||
# Merge CLI overrides from AGENTKIT_TASK_STORE env var
|
||||
|
|
@ -680,9 +648,11 @@ def create_app(
|
|||
)
|
||||
app.state.session_manager = SessionManager(store=session_store)
|
||||
|
||||
# Inject SessionManager into Portal's ConversationStore for persistence
|
||||
# Inject SessionManager into Portal's ConversationStore for persistence (legacy only)
|
||||
from agentkit.server.routes.portal import _conversation_store
|
||||
_conversation_store.set_session_manager(app.state.session_manager)
|
||||
|
||||
if hasattr(_conversation_store, "set_session_manager"):
|
||||
_conversation_store.set_session_manager(app.state.session_manager)
|
||||
|
||||
# Initialize evolution store if configured
|
||||
if server_config and hasattr(server_config, "evolution") and server_config.evolution:
|
||||
|
|
|
|||
|
|
@ -97,12 +97,21 @@ chat_manager = ChatConnectionManager()
|
|||
# ── Helper ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
_VALID_TEAM_EVENT_TYPES = frozenset({
|
||||
"team_formed", "expert_step", "expert_result",
|
||||
"plan_update", "team_synthesis", "team_dissolved",
|
||||
"plan_step", "phase_started", "phase_completed", "phase_failed",
|
||||
"replanning",
|
||||
})
|
||||
_VALID_TEAM_EVENT_TYPES = frozenset(
|
||||
{
|
||||
"team_formed",
|
||||
"expert_step",
|
||||
"expert_result",
|
||||
"plan_update",
|
||||
"team_synthesis",
|
||||
"team_dissolved",
|
||||
"plan_step",
|
||||
"phase_started",
|
||||
"phase_completed",
|
||||
"phase_failed",
|
||||
"replanning",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) -> None:
|
||||
|
|
@ -125,10 +134,12 @@ async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) ->
|
|||
if event_type not in _VALID_TEAM_EVENT_TYPES:
|
||||
logger.warning(f"emit_team_event: invalid event_type '{event_type}'")
|
||||
return
|
||||
await websocket.send_json({
|
||||
"type": event_type,
|
||||
"data": data,
|
||||
})
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": event_type,
|
||||
"data": data,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_session_manager(request: Request) -> SessionManager:
|
||||
|
|
@ -236,11 +247,15 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
|||
else:
|
||||
react_engine.reset()
|
||||
tools = agent._tool_registry.list_tools() if agent._tool_registry else []
|
||||
system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None)
|
||||
system_prompt = getattr(agent, "_system_prompt", None) or (
|
||||
agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
|
||||
)
|
||||
result = await react_engine.execute(
|
||||
messages=chat_messages,
|
||||
tools=tools,
|
||||
model=agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default"),
|
||||
model=agent.get_model()
|
||||
if hasattr(agent, "get_model")
|
||||
else getattr(agent, "_llm_model", "default"),
|
||||
agent_name=agent.name,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
|
@ -319,7 +334,9 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
sm: SessionManager = websocket.app.state.session_manager
|
||||
session = await sm.get_session(session_id)
|
||||
if session is None:
|
||||
await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}})
|
||||
await websocket.send_json(
|
||||
{"type": "error", "data": {"message": f"Session '{session_id}' not found"}}
|
||||
)
|
||||
await websocket.close(code=1000, reason="Session not found")
|
||||
return
|
||||
|
||||
|
|
@ -367,10 +384,14 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
# Clean up completed tasks first
|
||||
active_tasks.difference_update({t for t in active_tasks if t.done()})
|
||||
if len(active_tasks) >= _MAX_CONCURRENT_TASKS:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"message": "Too many concurrent requests. Please wait for the current task to complete."},
|
||||
})
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"data": {
|
||||
"message": "Too many concurrent requests. Please wait for the current task to complete."
|
||||
},
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Run in background task so the WebSocket receive loop stays free
|
||||
|
|
@ -378,7 +399,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
# is waiting for user confirmation (otherwise deadlock).
|
||||
task = asyncio.create_task(
|
||||
_handle_chat_message(
|
||||
websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations,
|
||||
websocket,
|
||||
session_id,
|
||||
content,
|
||||
sm,
|
||||
message_token,
|
||||
pending_replies,
|
||||
pending_confirmations,
|
||||
model_override=model,
|
||||
)
|
||||
)
|
||||
|
|
@ -396,12 +423,16 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
# Reply to confirmation request
|
||||
confirmation_id = msg.get("confirmation_id")
|
||||
approved = msg.get("approved", False)
|
||||
logger.info(f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}")
|
||||
logger.info(
|
||||
f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}"
|
||||
)
|
||||
if confirmation_id and confirmation_id in pending_confirmations:
|
||||
pending_confirmations[confirmation_id].set_result(approved)
|
||||
logger.info(f"Confirmation {confirmation_id} set_result({approved})")
|
||||
else:
|
||||
logger.warning(f"Confirmation {confirmation_id!r} not found in pending_confirmations")
|
||||
logger.warning(
|
||||
f"Confirmation {confirmation_id!r} not found in pending_confirmations"
|
||||
)
|
||||
|
||||
elif msg_type == "cancel":
|
||||
cancellation_token.cancel()
|
||||
|
|
@ -441,9 +472,9 @@ async def _handle_chat_message(
|
|||
) -> None:
|
||||
"""Handle a user message: append to session, execute Agent, stream events.
|
||||
|
||||
Uses SimpleRouter for minimal routing: @skill prefix + greeting regex + REACT.
|
||||
Uses RequestPreprocessor for minimal preprocessing: @skill prefix + greeting regex + REACT.
|
||||
"""
|
||||
from agentkit.chat.simple_router import SimpleRouter
|
||||
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
||||
|
||||
# Resolve Agent first (needed for default tools/prompt)
|
||||
pool = websocket.app.state.agent_pool
|
||||
|
|
@ -454,19 +485,27 @@ async def _handle_chat_message(
|
|||
|
||||
agent = pool.get_agent(session.agent_name)
|
||||
if agent is None:
|
||||
await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}})
|
||||
await websocket.send_json(
|
||||
{"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}}
|
||||
)
|
||||
return
|
||||
|
||||
# Default execution parameters from agent
|
||||
default_tools = agent._tool_registry.list_tools() if agent._tool_registry else []
|
||||
default_system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None)
|
||||
default_model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default")
|
||||
default_system_prompt = getattr(agent, "_system_prompt", None) or (
|
||||
agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
|
||||
)
|
||||
default_model = (
|
||||
agent.get_model()
|
||||
if hasattr(agent, "get_model")
|
||||
else getattr(agent, "_llm_model", "default")
|
||||
)
|
||||
|
||||
# Resolve skill routing using SimpleRouter
|
||||
# Resolve skill routing using RequestPreprocessor
|
||||
skill_registry = getattr(websocket.app.state, "skill_registry", None)
|
||||
simple_router: SimpleRouter = websocket.app.state.simple_router
|
||||
request_preprocessor: RequestPreprocessor = websocket.app.state.request_preprocessor
|
||||
|
||||
routing = await simple_router.route(
|
||||
routing = await request_preprocessor.preprocess(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
default_tools=default_tools,
|
||||
|
|
@ -477,7 +516,9 @@ async def _handle_chat_message(
|
|||
|
||||
# Debug: log tools that will be passed to ReActEngine
|
||||
tool_names = [t.name for t in routing.tools]
|
||||
logger.info(f"Chat {session_id}: resolved {len(routing.tools)} tools: {tool_names}, model={routing.model}, skill={routing.skill_name}")
|
||||
logger.info(
|
||||
f"Chat {session_id}: resolved {len(routing.tools)} tools: {tool_names}, model={routing.model}, skill={routing.skill_name}"
|
||||
)
|
||||
|
||||
# Apply model override from frontend selector
|
||||
if model_override:
|
||||
|
|
@ -485,17 +526,21 @@ async def _handle_chat_message(
|
|||
|
||||
# Notify frontend about skill match
|
||||
if routing.matched:
|
||||
await websocket.send_json({
|
||||
"type": "skill_match",
|
||||
"data": {
|
||||
"skill": routing.skill_name,
|
||||
"method": routing.match_method,
|
||||
"confidence": routing.match_confidence,
|
||||
},
|
||||
})
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "skill_match",
|
||||
"data": {
|
||||
"skill": routing.skill_name,
|
||||
"method": routing.match_method,
|
||||
"confidence": routing.match_confidence,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Append user message (use clean_content if @skill: prefix was stripped)
|
||||
await sm.append_message(session_id=session_id, role=MessageRole.USER, content=routing.clean_content)
|
||||
await sm.append_message(
|
||||
session_id=session_id, role=MessageRole.USER, content=routing.clean_content
|
||||
)
|
||||
|
||||
# Get full conversation history
|
||||
chat_messages = await sm.get_chat_messages(session_id)
|
||||
|
|
@ -516,12 +561,15 @@ async def _handle_chat_message(
|
|||
final_content = response.content or ""
|
||||
if not final_content or not final_content.strip():
|
||||
from agentkit.core.fallback import EMPTY_LLM_RESPONSE
|
||||
|
||||
final_content = EMPTY_LLM_RESPONSE
|
||||
await websocket.send_json({
|
||||
"type": "final_answer",
|
||||
"content": final_content,
|
||||
"is_final": True,
|
||||
})
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "final_answer",
|
||||
"content": final_content,
|
||||
"is_final": True,
|
||||
}
|
||||
)
|
||||
await sm.append_message(
|
||||
session_id=session_id,
|
||||
role=MessageRole.ASSISTANT,
|
||||
|
|
@ -557,14 +605,16 @@ async def _handle_chat_message(
|
|||
async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool:
|
||||
"""Send confirmation request to frontend via WebSocket and wait for user reply."""
|
||||
# Send confirmation request to frontend
|
||||
await websocket.send_json({
|
||||
"type": "confirmation_request",
|
||||
"data": {
|
||||
"confirmation_id": confirmation_id,
|
||||
"command": command,
|
||||
"reason": reason,
|
||||
},
|
||||
})
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "confirmation_request",
|
||||
"data": {
|
||||
"confirmation_id": confirmation_id,
|
||||
"command": command,
|
||||
"reason": reason,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Create a Future and wait for the user's reply
|
||||
loop = asyncio.get_running_loop()
|
||||
|
|
@ -578,10 +628,12 @@ async def _handle_chat_message(
|
|||
logger.info(f"Confirmation request {confirmation_id} resolved: {result}")
|
||||
# Immediately notify frontend of the result so the card updates
|
||||
# without waiting for the tool to re-execute
|
||||
await websocket.send_json({
|
||||
"type": "confirmation_result",
|
||||
"data": {"confirmation_id": confirmation_id, "approved": result},
|
||||
})
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "confirmation_result",
|
||||
"data": {"confirmation_id": confirmation_id, "approved": result},
|
||||
}
|
||||
)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Confirmation request {confirmation_id} timed out")
|
||||
|
|
@ -592,7 +644,9 @@ async def _handle_chat_message(
|
|||
finally:
|
||||
_pending_confirmations.pop(confirmation_id, None)
|
||||
|
||||
logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}")
|
||||
logger.info(
|
||||
f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}"
|
||||
)
|
||||
|
||||
try:
|
||||
final_content = ""
|
||||
|
|
@ -615,12 +669,15 @@ async def _handle_chat_message(
|
|||
final_content = event.data.get("output", "")
|
||||
if not final_content or not final_content.strip():
|
||||
from agentkit.core.fallback import EMPTY_LLM_RESPONSE
|
||||
|
||||
final_content = EMPTY_LLM_RESPONSE
|
||||
await websocket.send_json({
|
||||
"type": "final_answer",
|
||||
"content": final_content,
|
||||
"is_final": True,
|
||||
})
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "final_answer",
|
||||
"content": final_content,
|
||||
"is_final": True,
|
||||
}
|
||||
)
|
||||
elif event.event_type == "token":
|
||||
# Buffer tokens instead of sending immediately
|
||||
token_buffer.append(event.data.get("content", ""))
|
||||
|
|
@ -640,30 +697,36 @@ async def _handle_chat_message(
|
|||
buffered_text = "".join(token_buffer)
|
||||
token_buffer.clear()
|
||||
await websocket.send_json({"type": "thinking", "content": buffered_text})
|
||||
await websocket.send_json({
|
||||
"type": "step",
|
||||
"data": {
|
||||
"event_type": event.event_type,
|
||||
"step": event.step,
|
||||
"data": event.data,
|
||||
},
|
||||
})
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "step",
|
||||
"data": {
|
||||
"event_type": event.event_type,
|
||||
"step": event.step,
|
||||
"data": event.data,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif event.event_type == "confirmation_request":
|
||||
pass
|
||||
elif event.event_type == "confirmation_result":
|
||||
await websocket.send_json({
|
||||
"type": "confirmation_result",
|
||||
"data": event.data,
|
||||
})
|
||||
else:
|
||||
await websocket.send_json({
|
||||
"type": "step",
|
||||
"data": {
|
||||
"event_type": event.event_type,
|
||||
"step": event.step,
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "confirmation_result",
|
||||
"data": event.data,
|
||||
},
|
||||
})
|
||||
}
|
||||
)
|
||||
else:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "step",
|
||||
"data": {
|
||||
"event_type": event.event_type,
|
||||
"step": event.step,
|
||||
"data": event.data,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Append assistant reply to session
|
||||
if final_content:
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""E2E Agent Capability Tests — SimpleRouter Backtest (Real LLM).
|
||||
"""E2E Agent Capability Tests — RequestPreprocessor Backtest (Real LLM).
|
||||
|
||||
Tests SimpleRouter.route() using real LLM configuration loaded from
|
||||
Tests RequestPreprocessor.preprocess() using real LLM configuration loaded from
|
||||
agentkit.yaml. Records full SkillRoutingResult for precise analysis.
|
||||
|
||||
Key differences from old CostAwareRouter backtest:
|
||||
- No HeuristicClassifier complexity scoring
|
||||
- No IntentRouter LLM classification
|
||||
- No SemanticRouter embedding matching
|
||||
- SimpleRouter: @skill prefix + greeting regex + default REACT
|
||||
- RequestPreprocessor: @skill prefix + greeting regex + default REACT
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
@ -16,7 +16,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.simple_router import SimpleRouter
|
||||
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
||||
from agentkit.chat.skill_routing import ExecutionMode
|
||||
from agentkit.server.app import _build_llm_gateway, _build_skill_registry
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
|
@ -95,7 +95,7 @@ def _find_config_path() -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def _build_real_components() -> tuple[SimpleRouter, SkillRegistry]:
|
||||
def _build_real_components() -> tuple[RequestPreprocessor, SkillRegistry]:
|
||||
config_path = _find_config_path()
|
||||
if not config_path:
|
||||
pytest.skip("No agentkit.yaml found")
|
||||
|
|
@ -132,15 +132,15 @@ def _build_real_components() -> tuple[SimpleRouter, SkillRegistry]:
|
|||
pytest.skip("No LLM provider with valid API key")
|
||||
|
||||
skill_registry = _build_skill_registry(server_config)
|
||||
router = SimpleRouter(skill_registry=skill_registry)
|
||||
preprocessor = RequestPreprocessor(skill_registry=skill_registry)
|
||||
|
||||
return router, skill_registry
|
||||
return preprocessor, skill_registry
|
||||
|
||||
|
||||
_cached_components: tuple[SimpleRouter, SkillRegistry] | None = None
|
||||
_cached_components: tuple[RequestPreprocessor, SkillRegistry] | None = None
|
||||
|
||||
|
||||
def _get_components() -> tuple[SimpleRouter, SkillRegistry]:
|
||||
def _get_components() -> tuple[RequestPreprocessor, SkillRegistry]:
|
||||
global _cached_components
|
||||
if _cached_components is None:
|
||||
_cached_components = _build_real_components()
|
||||
|
|
@ -153,8 +153,8 @@ def _get_components() -> tuple[SimpleRouter, SkillRegistry]:
|
|||
|
||||
|
||||
@pytest.mark.e2e_capability
|
||||
class TestSimpleRouterBasic:
|
||||
"""Test SimpleRouter basic routing: greeting → DIRECT_CHAT, others → REACT."""
|
||||
class TestRequestPreprocessorBasic:
|
||||
"""Test RequestPreprocessor basic preprocessing: greeting → DIRECT_CHAT, others → REACT."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
|
|
@ -162,9 +162,9 @@ class TestSimpleRouterBasic:
|
|||
ids=[c["id"] for c in ROUTING_TEST_CASES],
|
||||
)
|
||||
def test_routing(self, case: dict):
|
||||
router, skill_registry = _get_components()
|
||||
preprocessor, skill_registry = _get_components()
|
||||
result = asyncio.run(
|
||||
router.route(
|
||||
preprocessor.preprocess(
|
||||
content=case["input"],
|
||||
skill_registry=skill_registry,
|
||||
default_tools=["shell", "search", "file_read"],
|
||||
|
|
@ -179,8 +179,8 @@ class TestSimpleRouterBasic:
|
|||
|
||||
|
||||
@pytest.mark.e2e_capability
|
||||
class TestSimpleRouterParaphraseConsistency:
|
||||
"""Test that paraphrased inputs route to the same execution mode."""
|
||||
class TestRequestPreprocessorParaphraseConsistency:
|
||||
"""Test that paraphrased inputs preprocess to the same execution mode."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
|
|
@ -188,12 +188,12 @@ class TestSimpleRouterParaphraseConsistency:
|
|||
ids=[c["id"] for c in PARAPHRASE_CASES],
|
||||
)
|
||||
def test_paraphrase_consistency(self, case: dict):
|
||||
router, skill_registry = _get_components()
|
||||
preprocessor, skill_registry = _get_components()
|
||||
expected_mode = case["expected_mode"]
|
||||
|
||||
# Test original
|
||||
result = asyncio.run(
|
||||
router.route(
|
||||
preprocessor.preprocess(
|
||||
content=case["original"],
|
||||
skill_registry=skill_registry,
|
||||
default_tools=["shell", "search", "file_read"],
|
||||
|
|
@ -206,7 +206,7 @@ class TestSimpleRouterParaphraseConsistency:
|
|||
# Test all paraphrases
|
||||
for para in case["paraphrases"]:
|
||||
result = asyncio.run(
|
||||
router.route(
|
||||
preprocessor.preprocess(
|
||||
content=para,
|
||||
skill_registry=skill_registry,
|
||||
default_tools=["shell", "search", "file_read"],
|
||||
|
|
@ -218,19 +218,19 @@ class TestSimpleRouterParaphraseConsistency:
|
|||
|
||||
|
||||
@pytest.mark.e2e_capability
|
||||
class TestSimpleRouterMetrics:
|
||||
"""Compute and report routing accuracy metrics."""
|
||||
class TestRequestPreprocessorMetrics:
|
||||
"""Compute and report preprocessing accuracy metrics."""
|
||||
|
||||
def test_accuracy_report(self):
|
||||
"""Run all test cases and compute accuracy metrics."""
|
||||
router, skill_registry = _get_components()
|
||||
preprocessor, skill_registry = _get_components()
|
||||
total = len(ROUTING_TEST_CASES)
|
||||
correct = 0
|
||||
results = []
|
||||
|
||||
for case in ROUTING_TEST_CASES:
|
||||
result = asyncio.run(
|
||||
router.route(
|
||||
preprocessor.preprocess(
|
||||
content=case["input"],
|
||||
skill_registry=skill_registry,
|
||||
default_tools=["shell", "search", "file_read"],
|
||||
|
|
@ -251,7 +251,7 @@ class TestSimpleRouterMetrics:
|
|||
|
||||
accuracy = correct / total * 100
|
||||
print(f"\n{'='*60}")
|
||||
print(f"SimpleRouter Accuracy Report")
|
||||
print(f"RequestPreprocessor Accuracy Report")
|
||||
print(f"{'='*60}")
|
||||
print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.1f}%")
|
||||
print(f"{'-'*60}")
|
||||
|
|
@ -1,10 +1,10 @@
|
|||
"""Unit tests for SimpleRouter — minimal routing layer."""
|
||||
"""Unit tests for RequestPreprocessor — minimal preprocessing layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.simple_router import SimpleRouter
|
||||
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
||||
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||||
|
||||
|
||||
|
|
@ -51,8 +51,8 @@ def registry() -> MockSkillRegistry:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def router(registry: MockSkillRegistry) -> SimpleRouter:
|
||||
return SimpleRouter(
|
||||
def preprocessor(registry: MockSkillRegistry) -> RequestPreprocessor:
|
||||
return RequestPreprocessor(
|
||||
skill_registry=registry,
|
||||
default_tools=["shell", "search", "file_read"],
|
||||
default_system_prompt="You are a helpful assistant.",
|
||||
|
|
@ -67,8 +67,8 @@ def router(registry: MockSkillRegistry) -> SimpleRouter:
|
|||
|
||||
class TestSkillPrefix:
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_prefix_routes_to_skill(self, router: SimpleRouter):
|
||||
result = await router.route("@skill:shell_agent 查看当前ip")
|
||||
async def test_skill_prefix_routes_to_skill(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("@skill:shell_agent 查看当前ip")
|
||||
assert result.matched is True
|
||||
assert result.skill_name == "shell_agent"
|
||||
assert result.match_method == "skill_prefix"
|
||||
|
|
@ -76,22 +76,22 @@ class TestSkillPrefix:
|
|||
assert result.execution_mode == ExecutionMode.SKILL_REACT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_prefix_direct_mode(self, router: SimpleRouter):
|
||||
result = await router.route("@skill:direct_agent 翻译hello")
|
||||
async def test_skill_prefix_direct_mode(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("@skill:direct_agent 翻译hello")
|
||||
assert result.matched is True
|
||||
assert result.skill_name == "direct_agent"
|
||||
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_prefix_rewoo_mode(self, router: SimpleRouter):
|
||||
result = await router.route("@skill:rewoo_agent 重构代码")
|
||||
async def test_skill_prefix_rewoo_mode(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("@skill:rewoo_agent 重构代码")
|
||||
assert result.matched is True
|
||||
assert result.skill_name == "rewoo_agent"
|
||||
assert result.execution_mode == ExecutionMode.REWOO
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_skill_falls_back_to_react(self, router: SimpleRouter):
|
||||
result = await router.route("@skill:nonexistent 查询")
|
||||
async def test_unknown_skill_falls_back_to_react(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("@skill:nonexistent 查询")
|
||||
assert result.matched is False
|
||||
assert result.match_method == "skill_not_found_fallback"
|
||||
assert result.execution_mode == ExecutionMode.REACT
|
||||
|
|
@ -103,30 +103,30 @@ class TestSkillPrefix:
|
|||
|
||||
class TestDirectChat:
|
||||
@pytest.mark.asyncio
|
||||
async def test_greeting_cn(self, router: SimpleRouter):
|
||||
result = await router.route("你好")
|
||||
async def test_greeting_cn(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("你好")
|
||||
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
assert result.match_method == "regex_direct"
|
||||
assert result.tools == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_greeting_en(self, router: SimpleRouter):
|
||||
result = await router.route("hello")
|
||||
async def test_greeting_en(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("hello")
|
||||
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chitchat(self, router: SimpleRouter):
|
||||
result = await router.route("谢谢")
|
||||
async def test_chitchat(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("谢谢")
|
||||
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identity_question(self, router: SimpleRouter):
|
||||
result = await router.route("你是谁")
|
||||
async def test_identity_question(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("你是谁")
|
||||
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_identity_question_en(self, router: SimpleRouter):
|
||||
result = await router.route("who are you")
|
||||
async def test_identity_question_en(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("who are you")
|
||||
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
|
||||
|
||||
|
|
@ -136,15 +136,15 @@ class TestDirectChat:
|
|||
|
||||
class TestDefaultReact:
|
||||
@pytest.mark.asyncio
|
||||
async def test_colloquial_tool_query(self, router: SimpleRouter):
|
||||
async def test_colloquial_tool_query(self, preprocessor: RequestPreprocessor):
|
||||
"""Colloquial tool query: the core scenario the previous routing layer misclassified."""
|
||||
result = await router.route("查下ip")
|
||||
result = await preprocessor.preprocess("查下ip")
|
||||
assert result.execution_mode == ExecutionMode.REACT
|
||||
assert result.match_method == "default_react"
|
||||
assert len(result.tools) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_various_colloquial_expressions(self, router: SimpleRouter):
|
||||
async def test_various_colloquial_expressions(self, preprocessor: RequestPreprocessor):
|
||||
"""Every colloquial phrasing should go to REACT and let the LLM decide."""
|
||||
queries = [
|
||||
"查看当前ip",
|
||||
|
|
@ -157,30 +157,30 @@ class TestDefaultReact:
|
|||
"检查服务状态",
|
||||
]
|
||||
for query in queries:
|
||||
result = await router.route(query)
|
||||
result = await preprocessor.preprocess(query)
|
||||
assert result.execution_mode == ExecutionMode.REACT, f"'{query}' should be REACT, got {result.execution_mode}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_query(self, router: SimpleRouter):
|
||||
result = await router.route("帮我分析一下这个数据并生成报告")
|
||||
async def test_complex_query(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("帮我分析一下这个数据并生成报告")
|
||||
assert result.execution_mode == ExecutionMode.REACT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_translation_goes_react(self, router: SimpleRouter):
|
||||
async def test_translation_goes_react(self, preprocessor: RequestPreprocessor):
|
||||
"""Translation queries also go to REACT; the LLM decides in the agent loop that no tools are needed."""
|
||||
result = await router.route("翻译hello为中文")
|
||||
result = await preprocessor.preprocess("翻译hello为中文")
|
||||
assert result.execution_mode == ExecutionMode.REACT
|
||||
# LLM will see tools but decide not to use them
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_tools_included(self, router: SimpleRouter):
|
||||
result = await router.route("查下ip")
|
||||
async def test_default_tools_included(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("查下ip")
|
||||
assert "shell" in result.tools
|
||||
assert "search" in result.tools
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_system_prompt(self, router: SimpleRouter):
|
||||
result = await router.route("查下ip")
|
||||
async def test_default_system_prompt(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("查下ip")
|
||||
assert result.system_prompt == "You are a helpful assistant."
|
||||
|
||||
|
||||
|
|
@ -190,31 +190,31 @@ class TestDefaultReact:
|
|||
|
||||
class TestEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_input(self, router: SimpleRouter):
|
||||
result = await router.route("")
|
||||
async def test_empty_input(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess("")
|
||||
assert result.execution_mode == ExecutionMode.REACT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_only(self, router: SimpleRouter):
|
||||
result = await router.route(" ")
|
||||
async def test_whitespace_only(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess(" ")
|
||||
assert result.execution_mode == ExecutionMode.REACT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_greeting_with_extra_spaces(self, router: SimpleRouter):
|
||||
result = await router.route(" 你好 ")
|
||||
async def test_greeting_with_extra_spaces(self, preprocessor: RequestPreprocessor):
|
||||
result = await preprocessor.preprocess(" 你好 ")
|
||||
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_registry(self):
|
||||
"""Router without skill registry should still work for non-skill queries"""
|
||||
router = SimpleRouter(default_tools=["shell"])
|
||||
result = await router.route("查下ip")
|
||||
"""Preprocessor without skill registry should still work for non-skill queries"""
|
||||
preprocessor = RequestPreprocessor(default_tools=["shell"])
|
||||
result = await preprocessor.preprocess("查下ip")
|
||||
assert result.execution_mode == ExecutionMode.REACT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_override_defaults(self, router: SimpleRouter):
|
||||
"""Route-time overrides should work"""
|
||||
result = await router.route(
|
||||
async def test_override_defaults(self, preprocessor: RequestPreprocessor):
|
||||
"""Preprocess-time overrides should work"""
|
||||
result = await preprocessor.preprocess(
|
||||
"查下ip",
|
||||
default_tools=["shell_only"],
|
||||
default_model="gpt-4o",
|
||||
|
|
@ -1,13 +1,11 @@
|
|||
"""Unit tests for CostAwareRouter team upgrade logic and HeuristicClassifier."""
|
||||
"""Unit tests for ExpertTeamRouter and skill routing utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agentkit.chat.skill_routing import (
|
||||
CostAwareRouter,
|
||||
ExecutionMode,
|
||||
HeuristicClassifier,
|
||||
SkillRoutingResult,
|
||||
)
|
||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||||
|
|
@ -20,16 +18,6 @@ from agentkit.experts.router import ExpertTeamRouter
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_router(expert_team_router: ExpertTeamRouter | None = None) -> CostAwareRouter:
|
||||
"""Create a CostAwareRouter with mocked dependencies."""
|
||||
return CostAwareRouter(
|
||||
llm_gateway=None,
|
||||
model="test",
|
||||
classifier="heuristic",
|
||||
expert_team_router=expert_team_router,
|
||||
)
|
||||
|
||||
|
||||
def _make_team_router_with_templates() -> ExpertTeamRouter:
|
||||
"""Create an ExpertTeamRouter with sample templates."""
|
||||
registry = ExpertTemplateRegistry()
|
||||
|
|
@ -82,251 +70,51 @@ class TestExpertTeamRouterCanHandle:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: _try_team_upgrade()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTryTeamUpgrade:
|
||||
def test_upgrade_react_to_team_collab(self) -> None:
|
||||
router = _make_router(expert_team_router=_make_team_router_with_templates())
|
||||
result = SkillRoutingResult(
|
||||
clean_content="complex multi-step analysis task",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.8,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "complex multi-step analysis task", 0.8, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.TEAM_COLLAB
|
||||
assert any(t.get("method") == "team_upgrade" for t in trace)
|
||||
|
||||
def test_no_upgrade_low_complexity(self) -> None:
|
||||
router = _make_router(expert_team_router=_make_team_router_with_templates())
|
||||
result = SkillRoutingResult(
|
||||
clean_content="simple question",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.3,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "simple question", 0.3, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.REACT
|
||||
assert not any(t.get("method") == "team_upgrade" for t in trace)
|
||||
|
||||
def test_no_upgrade_no_team_router(self) -> None:
|
||||
router = _make_router(expert_team_router=None)
|
||||
result = SkillRoutingResult(
|
||||
clean_content="complex analysis",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.9,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "complex analysis", 0.9, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.REACT
|
||||
|
||||
def test_no_upgrade_empty_templates(self) -> None:
|
||||
router = _make_router(expert_team_router=_make_team_router_empty())
|
||||
result = SkillRoutingResult(
|
||||
clean_content="complex analysis",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.8,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "complex analysis", 0.8, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.REACT
|
||||
|
||||
def test_no_upgrade_direct_chat_mode(self) -> None:
|
||||
router = _make_router(expert_team_router=_make_team_router_with_templates())
|
||||
result = SkillRoutingResult(
|
||||
clean_content="hello",
|
||||
matched=False,
|
||||
match_method="greeting",
|
||||
match_confidence=1.0,
|
||||
complexity=0.0,
|
||||
execution_mode=ExecutionMode.DIRECT_CHAT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "hello", 0.0, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
|
||||
def test_team_upgrade_exception_handled(self) -> None:
|
||||
"""When ExpertTeamRouter raises, the upgrade is silently skipped."""
|
||||
broken_router = MagicMock()
|
||||
broken_router.can_handle.side_effect = RuntimeError("boom")
|
||||
router = _make_router(expert_team_router=broken_router)
|
||||
result = SkillRoutingResult(
|
||||
clean_content="complex task",
|
||||
matched=True,
|
||||
match_method="capability",
|
||||
match_confidence=0.8,
|
||||
complexity=0.8,
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
trace: list[dict] = []
|
||||
upgraded = router._try_team_upgrade(result, "complex task", 0.8, trace)
|
||||
assert upgraded.execution_mode == ExecutionMode.REACT
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: ExpertTeamRouter.resolve() with complexity
|
||||
# Tests: ExpertTeamRouter.resolve()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExpertTeamRouterResolve:
|
||||
def test_explicit_team_prefix(self) -> None:
|
||||
router = _make_team_router_with_templates()
|
||||
result = router.resolve("@team:analyst,strategist analyze the market", 0.5)
|
||||
result = router.resolve("@team:analyst,strategist analyze the market")
|
||||
assert result.team_mode is True
|
||||
assert result.match_method == "explicit_team"
|
||||
assert "analyst" in result.specified_experts
|
||||
assert "strategist" in result.specified_experts
|
||||
|
||||
def test_complexity_suggestion(self) -> None:
|
||||
def test_no_team_without_prefix(self) -> None:
|
||||
router = _make_team_router_with_templates()
|
||||
result = router.resolve("complex multi-step analysis", 0.8)
|
||||
assert result.team_mode is True
|
||||
assert result.match_method == "complexity_suggestion"
|
||||
assert result.auto_compose is True
|
||||
|
||||
def test_no_team_low_complexity(self) -> None:
|
||||
router = _make_team_router_with_templates()
|
||||
result = router.resolve("simple question", 0.2)
|
||||
result = router.resolve("simple question")
|
||||
assert result.team_mode is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: HeuristicClassifier complexity calibration
|
||||
# Tests: SkillRoutingResult data structure
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeuristicClassifierLowComplexity:
|
||||
"""Low-complexity signals should produce scores < 0.3."""
|
||||
class TestSkillRoutingResult:
|
||||
def test_default_execution_mode(self) -> None:
|
||||
result = SkillRoutingResult(
|
||||
clean_content="test",
|
||||
matched=False,
|
||||
match_method="default_react",
|
||||
match_confidence=0.8,
|
||||
agent_name="default",
|
||||
model="default",
|
||||
execution_mode=ExecutionMode.REACT,
|
||||
)
|
||||
assert result.execution_mode == ExecutionMode.REACT
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_chinese_greeting(self) -> None:
|
||||
assert self.clf.classify("你好") < 0.3
|
||||
|
||||
def test_chinese_greeting_hi(self) -> None:
|
||||
assert self.clf.classify("嗨") < 0.3
|
||||
|
||||
def test_english_greeting_hello(self) -> None:
|
||||
assert self.clf.classify("Hello") < 0.3
|
||||
|
||||
def test_english_greeting_hi(self) -> None:
|
||||
assert self.clf.classify("hi") < 0.3
|
||||
|
||||
def test_multiple_low_complexity_words(self) -> None:
|
||||
assert self.clf.classify("嗨,早上好") < 0.3
|
||||
|
||||
def test_greeting_with_high_complexity_word_not_suppressed(self) -> None:
|
||||
"""Low-complexity signal should NOT override high-complexity signal."""
|
||||
# "你好" is low, but "分析" is high → should score high
|
||||
assert self.clf.classify("你好,请帮我分析一下这个数据") > 0.5
|
||||
|
||||
|
||||
class TestHeuristicClassifierIdentity:
|
||||
"""Identity queries should produce scores < 0.3."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_who_are_you_cn(self) -> None:
|
||||
assert self.clf.classify("你是谁") < 0.3
|
||||
|
||||
def test_what_is_your_name_cn(self) -> None:
|
||||
assert self.clf.classify("你叫什么") < 0.3
|
||||
|
||||
|
||||
class TestHeuristicClassifierNegation:
|
||||
"""Negated high-complexity words should not contribute to score."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_negate_search_cn(self) -> None:
|
||||
assert self.clf.classify("不要搜索") < 0.3
|
||||
|
||||
def test_negate_analyze_cn(self) -> None:
|
||||
assert self.clf.classify("无需分析,直接告诉我答案") < 0.3
|
||||
|
||||
def test_partial_negation_still_high(self) -> None:
|
||||
"""'搜索' negated but '分析' not — should still be high."""
|
||||
assert self.clf.classify("分析市场趋势,但不要搜索") > 0.5
|
||||
|
||||
|
||||
class TestHeuristicClassifierThresholds:
|
||||
"""Verify adjusted base scores."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_no_keyword_short_message(self) -> None:
|
||||
assert self.clf.classify("好的") <= 0.10
|
||||
|
||||
def test_medium_complexity_base(self) -> None:
|
||||
"""Medium complexity keyword should start at 0.35 (not 0.45)."""
|
||||
score = self.clf.classify("如何使用Python?")
|
||||
# '如何' is medium → base 0.35, '?' short question → -0.10 = 0.25
|
||||
# but 'Python' is not in high/medium lists, so just medium base
|
||||
assert 0.25 <= score <= 0.45
|
||||
|
||||
|
||||
class TestHeuristicClassifierShortQuestion:
|
||||
"""Short questions ending with ?/? should get deduction."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_short_question_deduction(self) -> None:
|
||||
assert self.clf.classify("怎么用?") < 0.3
|
||||
|
||||
def test_long_question_no_deduction(self) -> None:
|
||||
assert self.clf.classify("如何设计一个高可用的微服务架构?") > 0.5
|
||||
|
||||
|
||||
class TestHeuristicClassifierHighComplexity:
|
||||
"""Complex tasks should produce scores > 0.7."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_two_high_complexity_words(self) -> None:
|
||||
# "分析" + "搜索" are both in _HIGH_COMPLEXITY_HINTS_CN → base 0.80
|
||||
assert self.clf.classify("分析市场数据并搜索相关信息") > 0.7
|
||||
|
||||
def test_single_high_complexity_word(self) -> None:
|
||||
# "分析" alone → base 0.65
|
||||
assert self.clf.classify("分析市场趋势并生成报告") > 0.6
|
||||
|
||||
def test_execute_and_restart(self) -> None:
|
||||
assert self.clf.classify("执行部署脚本并重启服务") > 0.7
|
||||
|
||||
|
||||
class TestHeuristicClassifierEdgeCases:
|
||||
"""Boundary conditions."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
self.clf = HeuristicClassifier()
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
assert self.clf.classify("") == 0.0
|
||||
|
||||
def test_whitespace_only(self) -> None:
|
||||
assert self.clf.classify(" ") == 0.0
|
||||
|
||||
def test_long_low_complexity_message(self) -> None:
|
||||
"""Even a long greeting should stay low."""
|
||||
long_greeting = "你好" * 100 # >200 chars
|
||||
assert self.clf.classify(long_greeting) <= 0.15
|
||||
def test_direct_chat_mode(self) -> None:
|
||||
result = SkillRoutingResult(
|
||||
clean_content="hello",
|
||||
matched=False,
|
||||
match_method="regex_direct",
|
||||
match_confidence=1.0,
|
||||
agent_name="default",
|
||||
model="default",
|
||||
execution_mode=ExecutionMode.DIRECT_CHAT,
|
||||
)
|
||||
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||
|
|
|
|||
|
|
@ -1,808 +0,0 @@
|
|||
"""CostAwareRouter unit tests: three-layer cost-aware routing."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, HeuristicClassifier, SkillRoutingResult, _tokenize_content
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
from agentkit.router.intent import IntentRouter, RoutingResult
|
||||
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
# ---------------------------------------------------------------------------


def _make_skill(
    name: str,
    keywords: list[str] | None = None,
    description: str = "",
    examples: list[str] | None = None,
) -> Skill:
    """Quickly build a Skill with an intent config"""
    config = SkillConfig(
        name=name,
        agent_type="test",
        task_mode="llm_generate",
        prompt={"system": f"You are a {name} skill."},
        intent={
            "keywords": keywords or [],
            "description": description,
            "examples": examples or [],
        },
    )
    return Skill(config=config)


def _make_llm_gateway(response_content: str) -> MagicMock:
    """Build a mock LLMGateway whose chat returns the given content"""
    gateway = MagicMock()
    gateway.chat = AsyncMock(
        return_value=LLMResponse(
            content=response_content,
            model="test-model",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
    )
    return gateway


def _make_skill_registry(skills: list[Skill] | None = None) -> MagicMock:
    """Build a mock SkillRegistry"""
    registry = MagicMock()
    _skills = skills or []
    registry.list_skills.return_value = _skills

    def _get(name: str):
        for s in _skills:
            if s.name == name:
                return s
        raise KeyError(f"Skill '{name}' not found")

    registry.get = MagicMock(side_effect=_get)
    return registry


def _make_intent_router() -> IntentRouter:
    """Build an IntentRouter without an LLM (keyword matching only)"""
    return IntentRouter(llm_gateway=None, model="default")


# ---------------------------------------------------------------------------
# Layer 0: Rule-based (zero cost)
# ---------------------------------------------------------------------------


class TestLayer0Greeting:
    """Layer 0: greeting pattern matching"""

    @pytest.mark.asyncio
    async def test_chinese_greeting_hits_layer0(self):
        """'你好' hits the Layer 0 greeting rule at zero token cost"""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0
        assert result.agent_name == "default"
        assert result.matched is False

    @pytest.mark.asyncio
    async def test_english_greeting_hits_layer0(self):
        """'hello' hits the Layer 0 greeting rule"""
        router = CostAwareRouter()
        result = await router.route(
            content="hello",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_greeting_with_punctuation(self):
        """'你好!' with punctuation also hits Layer 0"""
        router = CostAwareRouter()
        result = await router.route(
            content="你好!",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"


class TestLayer0ChatMode:
    """Layer 0: simple chat mode"""

    @pytest.mark.asyncio
    async def test_thanks_hits_chat_mode(self):
        """'谢谢' hits the Layer 0 simple chat mode"""
        router = CostAwareRouter()
        result = await router.route(
            content="谢谢",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "chat_mode"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_ok_hits_chat_mode(self):
        """'好的' hits the Layer 0 simple chat mode"""
        router = CostAwareRouter()
        result = await router.route(
            content="好的",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "chat_mode"


class TestLayer0ExplicitSkill:
    """Layer 0: explicit @skill: prefix"""

    @pytest.mark.asyncio
    async def test_skill_prefix_hits_layer0(self):
        """'@skill:search 搜索XX' hits the Layer 0 explicit Skill rule at zero token cost"""
        search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
        registry = _make_skill_registry([search_skill])
        # Requires an IntentRouter that supports LLM fallback
        gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=gateway, model="default")

        router = CostAwareRouter()
        result = await router.route(
            content="@skill:search 搜索XX",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.matched is True
        assert result.skill_name == "search"
        assert result.complexity == 0.0


# ---------------------------------------------------------------------------
# Layer 1: LLM quick classify (~100 tokens)
# ---------------------------------------------------------------------------


class TestLayer1Classification:
    """Layer 1: LLM quick classification"""

    @pytest.mark.asyncio
    async def test_medium_complexity_routes_via_intent_router(self):
        """'分析下这个数据' goes through Layer 1 LLM classification; medium complexity routes via IntentRouter"""
        # The LLM returns medium complexity
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])

        # The IntentRouter also needs an LLM
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(llm_gateway=gateway, model="default")
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert 0.3 <= result.complexity <= 0.7

    @pytest.mark.asyncio
    async def test_low_complexity_routes_to_default(self):
        """Low complexity (<0.3) routes to the default Agent"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.1}))
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        result = await router.route(
            content="随便聊聊",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity < 0.3
        assert result.match_method == "low_complexity"
        assert result.agent_name == "default"

    @pytest.mark.asyncio
    async def test_no_llm_gateway_defaults_to_medium(self):
        """Without an LLM Gateway, quick_classify returns 0.5 (medium complexity)"""
        router = CostAwareRouter(llm_gateway=None)
        complexity = await router.quick_classify("分析下这个数据")
        assert complexity == 0.5

    @pytest.mark.asyncio
    async def test_llm_malformed_response_defaults_to_medium(self):
        """quick_classify returns 0.5 when the LLM returns non-JSON"""
        gateway = _make_llm_gateway("这不是JSON")
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        complexity = await router.quick_classify("分析下这个数据")
        assert complexity == 0.5

    @pytest.mark.asyncio
    async def test_complexity_clamped_to_0_1(self):
        """Complexity is clamped to the range [0.0, 1.0]"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 1.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        complexity = await router.quick_classify("超级复杂任务")
        assert complexity == 1.0

        gateway2 = _make_llm_gateway(json.dumps({"complexity": -0.5}))
        router2 = CostAwareRouter(llm_gateway=gateway2, model="default")
        complexity2 = await router2.quick_classify("简单任务")
        assert complexity2 == 0.0


# ---------------------------------------------------------------------------
# Layer 2: Capability matching / Auction
# ---------------------------------------------------------------------------


class TestLayer2CapabilityMatching:
    """Layer 2: capability matching / auction"""

    @pytest.mark.asyncio
    async def test_high_complexity_triggers_capability_matching(self):
        """'做市场调研+竞品分析' has complexity > 0.7 and triggers capability matching"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="market-researcher")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        assert result.match_method == "capability"
        assert result.agent_name == "market-researcher"
        assert result.matched is True

    @pytest.mark.asyncio
    async def test_layer2_with_org_context_object(self):
        """When org_context.find_best_agent returns an object, its name attribute is extracted"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.9}))
        agent_obj = MagicMock()
        agent_obj.name = "analyst-agent"
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=agent_obj)

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.agent_name == "analyst-agent"

    @pytest.mark.asyncio
    async def test_layer2_without_org_context_falls_back_to_intent_router(self):
        """Without org_context, Layer 2 falls back to IntentRouter"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])

        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=None)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        # Falls back to IntentRouter; may match a skill or go to default
        assert result.match_method in ("capability", "keyword", "llm", "intent_router_fallback", None)

    @pytest.mark.asyncio
    async def test_layer2_org_context_find_best_agent_returns_none(self):
        """Falls back to IntentRouter when org_context.find_best_agent returns None"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=None)

        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7

    @pytest.mark.asyncio
    async def test_auction_disabled_by_default(self):
        """Auction mode is disabled by default"""
        router = CostAwareRouter()
        assert router._auction_enabled is False

    @pytest.mark.asyncio
    async def test_auction_can_be_enabled(self):
        """Auction mode can be enabled manually"""
        router = CostAwareRouter(auction_enabled=True)
        assert router._auction_enabled is True


# ---------------------------------------------------------------------------
# Transparency
# ---------------------------------------------------------------------------


class TestTransparency:
    """Transparency level switching"""

    @pytest.mark.asyncio
    async def test_silent_mode_no_trace(self):
        """SILENT mode does not expose the routing trace"""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="SILENT",
        )
        assert result.execution_trace == []
        assert result.transparency_level == "SILENT"

    @pytest.mark.asyncio
    async def test_verbose_mode_shows_trace(self):
        """VERBOSE mode shows the routing trace"""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="VERBOSE",
        )
        assert len(result.execution_trace) > 0
        assert result.execution_trace[0]["layer"] == 0
        assert result.execution_trace[0]["method"] == "greeting"
        assert result.transparency_level == "VERBOSE"

    @pytest.mark.asyncio
    async def test_trace_mode_shows_full_trace(self):
        """TRACE mode shows the full routing trace"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="analyst")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="TRACE",
        )
        assert len(result.execution_trace) > 0
        # Should include records for Layer 1 quick_classify and Layer 2
        layers = [t["layer"] for t in result.execution_trace]
        assert 1 in layers # Layer 1 quick_classify
        assert 2 in layers # Layer 2 capability matching
        assert result.transparency_level == "TRACE"

    @pytest.mark.asyncio
    async def test_default_transparency_is_silent(self):
        """Default transparency is SILENT"""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.transparency_level == "SILENT"
        assert result.execution_trace == []


# ---------------------------------------------------------------------------
# SkillRoutingResult new fields
# ---------------------------------------------------------------------------


class TestSkillRoutingResultNewFields:
    """Validation of the new SkillRoutingResult fields"""

    def test_default_transparency_level(self):
        result = SkillRoutingResult()
        assert result.transparency_level == "SILENT"

    def test_default_execution_trace(self):
        result = SkillRoutingResult()
        assert result.execution_trace == []

    def test_default_complexity(self):
        result = SkillRoutingResult()
        assert result.complexity == 0.0

    def test_new_fields_backward_compatible(self):
        """The new fields do not break old code that creates a SkillRoutingResult"""
        result = SkillRoutingResult(
            skill_name="test",
            matched=True,
            match_method="keyword",
        )
        assert result.transparency_level == "SILENT"
        assert result.execution_trace == []
        assert result.complexity == 0.0


# ---------------------------------------------------------------------------
# _tokenize_content: enhanced Chinese tokenization
# ---------------------------------------------------------------------------


class TestTokenizeContent:
    """Tests for enhanced Chinese tokenization in _tokenize_content"""

    def test_chinese_content(self):
        """Chinese content: '帮我做数据分析' should contain 2-grams related to '数据分析'"""
        tokens = _tokenize_content("帮我做数据分析")
        # No punctuation splits the text, so 2-grams are generated: 帮我, 我做, 做数, 数据, 据分, 分析
        assert "数据" in tokens or "数据分析" in tokens

    def test_english_content(self):
        """English content: 'help with code generation' should contain 'code', 'generation', or 'code generation'"""
        tokens = _tokenize_content("help with code generation")
        assert "code" in tokens or "generation" in tokens or "code generation" in tokens

    def test_mixed_content(self):
        """Mixed Chinese and English: '用python做data analysis' should contain 'python'-related tokens and 'data analysis'"""
        tokens = _tokenize_content("用python做data analysis")
        # After splitting on spaces, "用python做data" is one segment and yields 2-grams
        # "analysis" is a separate segment
        assert "analysis" in tokens
        # "用python做data" is longer than 4, so 2-grams are generated, including python-related fragments
        has_python_related = any("python" in t for t in tokens)
        assert has_python_related or "data analysis" in tokens

    def test_stopwords_filtered(self):
        """Stopword filtering: a short phrase made only of stopwords should yield no tokens or very few"""
        tokens = _tokenize_content("我的一个")
        # "我的一个" has length 4 and is kept whole (it is not in the stopword set)
        # but stopword 2-grams such as "我的", "的一", "一个" are filtered out
        assert len(tokens) <= 1

    def test_bigram_generation(self):
        """2-gram generation: '机器学习模型训练' should contain each 2-gram"""
        tokens = _tokenize_content("机器学习模型训练")
        expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"]
        for bigram in expected_bigrams:
            assert bigram in tokens, f"Missing 2-gram: {bigram}"


# ---------------------------------------------------------------------------
# HeuristicClassifier
# ---------------------------------------------------------------------------


class TestHeuristicClassifier:
    """Tests for HeuristicClassifier, the local heuristic classifier"""

    def setup_method(self):
        self.classifier = HeuristicClassifier()

    def test_short_greeting_low_complexity(self):
        """Short greeting → low complexity"""
        score = self.classifier.classify("你好呀")
        assert score < 0.3

    def test_simple_question_medium_complexity(self):
        """Simple question containing '如何' → medium complexity"""
        score = self.classifier.classify("如何使用这个功能?")
        assert 0.3 <= score <= 0.7

    def test_tool_request_high_complexity(self):
        """Request containing tool keywords → high complexity"""
        score = self.classifier.classify("帮我搜索一下最新的新闻")
        assert score > 0.5

    def test_code_request_high_complexity(self):
        """Code-related request → high complexity"""
        score = self.classifier.classify("写一个Python函数实现快速排序")
        assert score > 0.6

    def test_multi_step_request_high_complexity(self):
        """Multi-step analysis request → high complexity"""
        score = self.classifier.classify("分析这个数据,比较不同方案的优缺点,然后给出推荐")
        assert score > 0.7

    def test_empty_string_zero_complexity(self):
        """Empty string → zero complexity"""
        assert self.classifier.classify("") == 0.0
        assert self.classifier.classify(" ") == 0.0

    def test_long_message_higher_complexity(self):
        """Long message → higher complexity"""
        short = "帮我查一下"
        long = "帮我查一下" + "关于机器学习和深度学习的最新进展" * 10
        assert self.classifier.classify(long) > self.classifier.classify(short)

    def test_code_patterns_boost_complexity(self):
        """Code patterns (backticks/parentheses) raise complexity"""
        with_code = "运行这段代码 `print('hello')`"
        without_code = "运行这段代码 print hello"
        assert self.classifier.classify(with_code) > self.classifier.classify(without_code)

    def test_score_bounded_0_to_1(self):
        """Complexity always stays within the range [0.0, 1.0]"""
        test_inputs = [
            "", "你好", "如何做", "帮我搜索并分析数据,设计一个完整的解决方案,包含代码实现和部署配置",
        ]
        for inp in test_inputs:
            score = self.classifier.classify(inp)
            assert 0.0 <= score <= 1.0, f"Score {score} out of range for '{inp}'"


class TestHeuristicClassifierIntegration:
    """Integration tests for HeuristicClassifier inside CostAwareRouter"""

    @pytest.mark.asyncio
    async def test_heuristic_mode_no_llm_call(self):
        """No LLM call in heuristic mode with merged_llm_classify=False"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic", merged_llm_classify=False)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # LLM gateway.chat should not be called (heuristic + merged disabled)
        gateway.chat.assert_not_called()
        # Complexity should come from the heuristic classifier
        assert result.complexity > 0.0

    @pytest.mark.asyncio
    async def test_llm_mode_uses_llm(self):
        """llm mode calls the LLM quick_classify"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="llm")
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # LLM gateway.chat should be called
        gateway.chat.assert_called()

    @pytest.mark.asyncio
    async def test_heuristic_greeting_still_layer0(self):
        """Greetings still go through Layer 0 in heuristic mode"""
        router = CostAwareRouter(classifier="heuristic")
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_heuristic_default_classifier_mode(self):
        """The default classifier mode is heuristic"""
        router = CostAwareRouter()
        assert router._classifier == "heuristic"


# ---------------------------------------------------------------------------
# U1: Merged LLM Classify
# ---------------------------------------------------------------------------


class TestMergedLLMClassify:
    """Tests for the merged routing LLM call"""

    @pytest.mark.asyncio
    async def test_merged_classify_returns_valid_skill(self):
        """The merged call returns valid JSON with a skill_hint and routes to the specified skill"""
        merged_response = json.dumps({
            "complexity": 0.6,
            "intent": "code_generation",
            "skill_hint": "search",
        })
        gateway = _make_llm_gateway(merged_response)
        search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
        registry = _make_skill_registry([search_skill])

        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我搜索一下最新的新闻",
            skill_registry=registry,
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.matched is True
        assert result.skill_name == "search"
        assert result.match_method == "merged_llm"
        assert result.complexity > 0.3

    @pytest.mark.asyncio
    async def test_merged_classify_malformed_response_fallback(self):
        """The merged call returns a malformed response and falls back to the default Agent"""
        gateway = _make_llm_gateway("这不是JSON")
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "merged_llm_fallback"
        assert result.complexity == 0.5

    @pytest.mark.asyncio
    async def test_merged_classify_low_complexity(self):
        """The merged call returns complexity < 0.3 and takes the low-complexity route"""
        merged_response = json.dumps({
            "complexity": 0.2,
            "intent": "greeting",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="如何使用这个功能?", # heuristic returns ~0.45, triggers merged call
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Merged LLM returned complexity < 0.3, should route to low complexity
        assert result.complexity < 0.3
        assert "low" in result.match_method or "merged_llm_low" in result.match_method

    @pytest.mark.asyncio
    async def test_merged_classify_high_complexity(self):
        """The merged call returns complexity > 0.7 and goes to Layer 2"""
        merged_response = json.dumps({
            "complexity": 0.85,
            "intent": "research",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="researcher")

        router = CostAwareRouter(
            llm_gateway=gateway, model="default",
            org_context=org_context, merged_llm_classify=True,
        )
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        assert result.match_method == "capability"

    @pytest.mark.asyncio
    async def test_merged_classify_disabled_falls_back_to_intent_router(self):
        """With merged_llm_classify=False, routing falls back to the standalone IntentRouter"""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(
            llm_gateway=gateway, model="default",
            merged_llm_classify=False,
        )
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Should not use merged_llm match_method
        assert result.match_method != "merged_llm"

    @pytest.mark.asyncio
    async def test_merged_classify_no_llm_gateway_falls_back(self):
        """Without an LLM Gateway, _classify_merged falls back to IntentRouter"""
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])

        router = CostAwareRouter(llm_gateway=None, merged_llm_classify=True)
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Should not crash, should use IntentRouter fallback
        assert result is not None

    @pytest.mark.asyncio
    async def test_merged_classify_skill_hint_not_found_fallback(self):
        """The skill_hint returned by the merged call is not in the registry, so routing falls back"""
        merged_response = json.dumps({
            "complexity": 0.5,
            "intent": "unknown",
            "skill_hint": "nonexistent_skill",
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Should fallback to default agent (medium complexity, no skill match)
        assert result.matched is False
        assert result.match_method == "merged_llm_medium"

    @pytest.mark.asyncio
    async def test_merged_classify_only_one_llm_call(self):
        """In merged-call mode, medium complexity produces only 1 LLM call"""
        merged_response = json.dumps({
            "complexity": 0.5,
            "intent": "question",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        await router.route(
            content="如何使用这个功能?",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Only 1 LLM call should have been made (the merged classify)
        assert gateway.chat.call_count == 1
@@ -1,219 +0,0 @@
"""Unit tests for Semantic Router (U3)."""

import pytest

from agentkit.chat.semantic_router import (
    SemanticRouteResult,
    SkillEmbeddingIndex,
    SemanticRouter,
)
from agentkit.memory.embedder import MockEmbedder


def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]:
    """Create a unit vector for similarity testing."""
    vec = [base_val] * dim
    magnitude = sum(x**2 for x in vec) ** 0.5
    return [x / magnitude for x in vec] if magnitude > 0 else vec


class MockSkill:
    """Mock skill for testing."""

    def __init__(self, name: str, description: str = "", keywords: list[str] | None = None, capabilities: list[str] | None = None):
        self.name = name
        self.config = MockSkillConfig(
            name=name,
            description=description,
            keywords=keywords or [],
            capabilities=capabilities or [],
        )


class MockSkillConfig:
    """Mock skill config for testing."""

    def __init__(self, name: str, description: str = "", keywords: list[str] | None = None, capabilities: list[str] | None = None):
        self.name = name
        self.description = description
        self.intent = MockIntentConfig(keywords=keywords or [])
        self.capabilities = [MockCapabilityTag(tag=t) for t in (capabilities or [])]


class MockIntentConfig:
    def __init__(self, keywords: list[str] | None = None):
        self.keywords = keywords or []


class MockCapabilityTag:
    def __init__(self, tag: str):
        self.tag = tag


class MockSkillRegistry:
    """Mock skill registry for testing."""

    def __init__(self, skills: list[MockSkill] | None = None):
        self._skills = {s.name: s for s in (skills or [])}

    def list_skills(self):
        return list(self._skills.values())

    def get(self, name: str):
        if name not in self._skills:
            raise KeyError(f"Skill '{name}' not found")
        return self._skills[name]


class TestSkillEmbeddingIndex:
    @pytest.mark.asyncio
    async def test_build_from_registry(self):
        embedder = MockEmbedder(dimension=64)
        index = SkillEmbeddingIndex(embedder)

        skills = [
            MockSkill("content_gen", description="生成文章内容", keywords=["写作", "文章"], capabilities=["content"]),
            MockSkill("data_analysis", description="数据分析与可视化", keywords=["分析", "数据"], capabilities=["analytics"]),
        ]
        registry = MockSkillRegistry(skills)
        await index.build(registry)

        assert index.size == 2

    @pytest.mark.asyncio
    async def test_search_returns_results(self):
        embedder = MockEmbedder(dimension=64)
        index = SkillEmbeddingIndex(embedder)

        skill = MockSkill("content_gen", description="生成文章内容")
        await index.update_skill("content_gen", skill)

        # MockEmbedder produces deterministic embeddings based on text hash
        # Different text → different embedding
        query_emb = await embedder.embed("生成文章")
        results = await index.search(query_emb)

        assert len(results) >= 1
        assert results[0][0] == "content_gen" # skill_name
        assert results[0][1] > 0.0 # similarity

    @pytest.mark.asyncio
    async def test_search_empty_index(self):
        embedder = MockEmbedder(dimension=64)
        index = SkillEmbeddingIndex(embedder)

        query_emb = await embedder.embed("test")
        results = await index.search(query_emb)

        assert results == []

    @pytest.mark.asyncio
|
||||
async def test_remove_skill(self):
|
||||
embedder = MockEmbedder(dimension=64)
|
||||
index = SkillEmbeddingIndex(embedder)
|
||||
|
||||
skill = MockSkill("test_skill", description="Test")
|
||||
await index.update_skill("test_skill", skill)
|
||||
assert index.size == 1
|
||||
|
||||
index.remove_skill("test_skill")
|
||||
assert index.size == 0
|
||||
|
||||
def test_build_source_text_with_description(self):
|
||||
skill = MockSkill("test", description="A test skill", keywords=["test"], capabilities=["testing"])
|
||||
text = SkillEmbeddingIndex._build_source_text(skill)
|
||||
assert "A test skill" in text
|
||||
assert "test" in text
|
||||
assert "testing" in text
|
||||
|
||||
def test_build_source_text_fallback_to_name(self):
|
||||
skill = MockSkill("my_skill", description="", keywords=[], capabilities=[])
|
||||
text = SkillEmbeddingIndex._build_source_text(skill)
|
||||
assert "my_skill" in text
|
||||
|
||||
|
||||
class TestSemanticRouter:
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_confidence_match(self):
|
||||
"""Similarity above similarity_high (nominally 0.85, lowered to 0.5 here) yields high confidence."""
|
||||
embedder = MockEmbedder(dimension=64)
|
||||
router = SemanticRouter(embedder, similarity_high=0.5, similarity_low=0.3)
|
||||
|
||||
# Add a skill with known embedding
|
||||
skill = MockSkill("content_gen", description="生成文章内容")
|
||||
await router.update_skill("content_gen", skill)
|
||||
|
||||
# Query with same text should produce very similar embedding (MockEmbedder is hash-based)
|
||||
# With low thresholds, even moderate similarity will be "high"
|
||||
result = await router.route("生成文章内容")
|
||||
# MockEmbedder is not guaranteed to yield a high similarity here,
|
||||
# so only verify the result structure
|
||||
assert result.confidence in ("high", "medium", "low")
|
||||
assert isinstance(result.similarity, float)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_low_confidence_empty_index(self):
|
||||
"""Empty index returns low confidence."""
|
||||
embedder = MockEmbedder(dimension=64)
|
||||
router = SemanticRouter(embedder)
|
||||
|
||||
result = await router.route("任何查询")
|
||||
assert result.confidence == "low"
|
||||
assert result.skill_name is None
|
||||
assert result.similarity == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_medium_confidence_zone(self):
|
||||
"""Medium confidence zone (nominally 0.6-0.85; widened here to 0.01-0.99)."""
|
||||
embedder = MockEmbedder(dimension=64)
|
||||
router = SemanticRouter(embedder, similarity_high=0.99, similarity_low=0.01)
|
||||
|
||||
skill = MockSkill("content_gen", description="生成文章内容")
|
||||
await router.update_skill("content_gen", skill)
|
||||
|
||||
# With very high similarity_high and very low similarity_low,
|
||||
# most matches will be "medium"
|
||||
result = await router.route("生成文章")
|
||||
# Medium is expected given the 0.99/0.01 thresholds, but MockEmbedder similarity is not guaranteed, so any band is accepted
|
||||
assert result.confidence in ("medium", "low", "high")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedder_failure_graceful(self):
|
||||
"""Embedder failure returns low confidence."""
|
||||
class FailingEmbedder(MockEmbedder):
|
||||
async def embed(self, text):
|
||||
raise RuntimeError("Embedding API failed")
|
||||
|
||||
router = SemanticRouter(FailingEmbedder(dimension=64))
|
||||
result = await router.route("test query")
|
||||
assert result.confidence == "low"
|
||||
assert result.skill_name is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_index_from_registry(self):
|
||||
"""Build index from skill registry."""
|
||||
embedder = MockEmbedder(dimension=64)
|
||||
router = SemanticRouter(embedder)
|
||||
|
||||
skills = [
|
||||
MockSkill("skill_a", description="Skill A"),
|
||||
MockSkill("skill_b", description="Skill B"),
|
||||
]
|
||||
registry = MockSkillRegistry(skills)
|
||||
await router.build_index(registry)
|
||||
|
||||
assert router._index.size == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chinese_query(self):
|
||||
"""Chinese query works with semantic router."""
|
||||
embedder = MockEmbedder(dimension=64)
|
||||
router = SemanticRouter(embedder, similarity_high=0.01, similarity_low=0.001)
|
||||
|
||||
skill = MockSkill("geo_optimizer", description="地理内容优化", keywords=["优化", "SEO", "地理"], capabilities=["optimization"])
|
||||
await router.update_skill("geo_optimizer", skill)
|
||||
|
||||
result = await router.route("帮我优化内容")
|
||||
# With very low thresholds, should match
|
||||
assert result.confidence in ("high", "medium")
|
||||
assert result.skill_name == "geo_optimizer"
|
||||
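The confidence bands exercised by the router tests can be illustrated with a minimal sketch. This is not the `SemanticRouter` implementation, which does not appear in this diff. The helper names `cosine_similarity` and `classify_confidence`, the default thresholds of 0.85 and 0.6, and the strict/inclusive boundary handling are all assumptions made for illustration.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def classify_confidence(similarity: float, high: float = 0.85, low: float = 0.6) -> str:
    """Map a similarity score to a band (hypothetical thresholds and boundaries)."""
    if similarity > high:
        return "high"
    if similarity >= low:
        return "medium"
    return "low"


# An empty index has no candidate to score, which corresponds to similarity 0.0
# and therefore the "low" band in this sketch.
best = max((cosine_similarity([1.0, 0.0], v) for v in []), default=0.0)
band = classify_confidence(best)
```

Under this sketch, raising `high` to 0.99 and dropping `low` to 0.01 pushes nearly every non-trivial match into the medium band, which is the setup `test_medium_confidence_zone` relies on.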