merge: feat/simple-router-architecture - Replace 4-layer CostAwareRouter with SimpleRouter + prompt-based tool calling
commit f770d65c7b
@@ -0,0 +1,58 @@

# SimpleRouter Backtest Report

created: 2026-06-16

## Backtest Results

### Routing Accuracy

| Metric | Result |
|--------|--------|
| Total test cases | 24 |
| Passed | 24 |
| Failed | 0 |
| **Accuracy** | **100%** |

### Breakdown by Category

| Category | Cases | Passed | Accuracy |
|----------|-------|--------|----------|
| Greeting/chitchat → DIRECT_CHAT | 4 | 4 | 100% |
| Colloquial tool query → REACT | 5 | 5 | 100% |
| Standard tool query → REACT | 5 | 5 | 100% |
| Translation/knowledge → REACT | 3 | 3 | 100% |
| Complex query → REACT | 3 | 3 | 100% |
| @skill prefix → SKILL_REACT | 1 | 1 | 100% |

### Colloquial Query Coverage (Core Improvement)

| Input | Old architecture | New architecture |
|-------|------------------|------------------|
| "查下ip" ("check the IP") | direct_agent (misrouted) | REACT ✓ |
| "查看当前ip" ("show the current IP") | direct_agent (misrouted) | REACT ✓ |
| "获取ip地址" ("get the IP address") | direct_agent (misrouted) | REACT ✓ |
| "看下ip" ("take a look at the IP") | direct_agent (misrouted) | REACT ✓ |
| "帮我查一下ip" ("help me check the IP") | direct_agent (misrouted) | REACT ✓ |

### Paraphrase Consistency

| Test group | Original phrasing | Paraphrases | Consistency |
|------------|-------------------|-------------|-------------|
| ip_check_variants | "查看当前ip" ("show the current IP") | 5 | 100% |
| search_variants | "搜索golang教程" ("search for a golang tutorial") | 3 | 100% |

## Comparison with the Old Architecture

| Metric | CostAwareRouter (old) | SimpleRouter (new) |
|--------|-----------------------|--------------------|
| Execution-mode accuracy | 40.38% | **100%** |
| Colloquial query success rate | 30% | **100%** |
| LLM calls in the routing layer | 1 per query | **0 per query** |
| Routing-layer latency | ~500 ms | **<1 ms** |
| Routing-layer token usage | ~1000 tokens | **0 tokens** |

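## Notes

1. Translation/knowledge queries now go through REACT (the LLM sees the tools but decides not to use them), which costs about 2000 more tokens than the old architecture (the tool descriptions).
2. This trades tokens for reliability, and the trade-off is acceptable.
3. A later optimization can detect model capabilities: models that support function calling use API tools, and only models that do not fall back to prompt-based calling.
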
@@ -0,0 +1,39 @@

# Code Review Fix Report

Date: 2026-06-16
Branch: feat/simple-router-architecture
Review scope: SimpleRouter + prompt-based tool calling refactor

## Issues Fixed

### HIGH Severity

| # | Issue | Fix | File |
|---|-------|-----|------|
| H1 | execute() lacked prompt-based tool calling injection | Inject _build_tool_use_prompt in execute(), consistent with execute_stream() | react.py |
| H2 | Race condition on agent._routing_result | _resolve_for_chat returns routing_result instead of monkey-patching the agent | portal.py |
| H3 | TEAM_COLLAB/REWOO/REFLEXION silently downgraded to REACT | Handle every execution_mode branch; log a warning when an advanced mode is downgraded | portal.py, chat.py |
| H4 | Overuse of the Any type | SimpleRouter uses the SkillRegistry/Tool types; ConversationStore uses the SessionManager type | simple_router.py, portal.py |

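The H2 fix can be illustrated with a minimal sketch. It assumes a long-lived agent object shared by concurrent requests; `SharedAgent`, `RoutingResult`, and the simplified `resolve_for_chat` below are hypothetical stand-ins, not the actual portal.py code. The point is only that the routing result travels with the request as a return value instead of being written onto shared state.

```python
import asyncio
from dataclasses import dataclass


@dataclass(frozen=True)
class RoutingResult:
    """Hypothetical, reduced stand-in for SkillRoutingResult."""
    mode: str
    content: str


class SharedAgent:
    """Hypothetical long-lived agent shared by every request."""


async def resolve_for_chat(agent: SharedAgent, content: str):
    # Simplified routing rule, for illustration only.
    mode = "DIRECT_CHAT" if content == "你好" else "REACT"
    result = RoutingResult(mode=mode, content=content)
    await asyncio.sleep(0)  # yield control, as real routing or I/O would
    # Return the result alongside the agent; never store it on the agent,
    # because a concurrent request could overwrite it before it is read.
    return agent, result


async def handle(agent: SharedAgent, content: str) -> str:
    _, result = await resolve_for_chat(agent, content)
    return result.mode


async def main() -> list:
    agent = SharedAgent()
    return await asyncio.gather(handle(agent, "你好"), handle(agent, "查下ip"))
```

Each coroutine reads only its own local `result`, so interleaving at the `await` cannot mix up the two requests.
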
### MEDIUM Severity

| # | Issue | Fix | File |
|---|-------|-----|------|
| M1 | default_system_prompt used `or` instead of `is not None` | Changed to `if default_system_prompt is not None` | simple_router.py |
| M2 | CostAwareRouter was dead code | Initialized conditionally (legacy_cost_aware_router config option); not initialized by default | app.py |
| M3 | chat.py did not handle DIRECT_CHAT | Added a DIRECT_CHAT branch: call the LLM directly, bypassing ReActEngine | chat.py |

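The difference behind M1 shows up when the caller passes an empty string on purpose. The sketch below is standalone; `FALLBACK` and both helper names are hypothetical and exist only to contrast the two idioms.

```python
from typing import Optional

FALLBACK = "You are a helpful assistant."  # hypothetical default prompt


def pick_with_or(override: Optional[str]) -> str:
    # `or` treats every falsy value ("" included) as "not provided".
    return override or FALLBACK


def pick_with_is_not_none(override: Optional[str]) -> str:
    # Only None means "not provided"; an explicit "" is respected.
    return override if override is not None else FALLBACK
```

With `or`, a caller cannot request an empty system prompt; with `is not None`, only a missing argument triggers the fallback.
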
## Test Results

- SimpleRouter unit tests: 20/20 passed
- SimpleRouter E2E backtest: 24/24 passed (100% accuracy)
- chat/core unit tests: 153/153 passed
- Ruff lint: All checks passed

## Changed Files

1. `src/agentkit/core/react.py`: added prompt-based tool calling injection to execute()
2. `src/agentkit/chat/simple_router.py`: type annotation fixes (Any → SkillRegistry/Tool) and the `is not None` fix
3. `src/agentkit/server/routes/portal.py`: race condition fix, complete execution_mode branches, type fixes
4. `src/agentkit/server/routes/chat.py`: DIRECT_CHAT handling, warning on execution_mode downgrade
5. `src/agentkit/server/app.py`: conditional CostAwareRouter initialization, removed the dead semantic router build_index code

@@ -0,0 +1,45 @@

# Reasoning Verification and Correction Log

created: 2026-06-16

## Verification Scenarios

### Scenario 1: "查下ip" ("check the IP")

- SimpleRouter: no @skill prefix, not a greeting → REACT + all tools
- ReActEngine: tool descriptions and the <tool_use> format are injected into the system prompt
- The LLM sees the tool descriptions and understands that shell is needed → emits <tool_use> → parsed and executed
- **Conclusion**: correct ✓

### Scenario 2: "你好" ("hello")

- SimpleRouter: matches _GREETING_RE → DIRECT_CHAT
- Direct LLM call, no tools
- **Conclusion**: correct ✓

### Scenario 3: "翻译hello为中文" ("translate hello into Chinese")

- SimpleRouter: no prefix, not a greeting → REACT + all tools
- The LLM sees the tools but judges them unnecessary → translates directly
- **Cost**: about 2000 extra tokens (tool descriptions), in exchange for correct routing
- **Conclusion**: correct ✓ (token cost acceptable)

### Scenario 4: "@skill:shell_agent 查看当前ip"

- SimpleRouter: @skill prefix → SKILL_REACT + shell_agent tools
- **Conclusion**: correct ✓

## Issues Found

### P3: tool_schemas and prompt-based tool descriptions coexist

- **Analysis**: the API tools parameter and the system-prompt tool descriptions are both sent
- **Impact**: models that support function calling take the native path; models that do not take the prompt-based path
- **Already handled by the current code**: has_tool_calls → native path, else → text-parsing path
- **Cost**: about 2000 tokens of redundant tool descriptions
- **Conclusion**: acceptable, no change needed

### P4: ChatMessage timestamp type mismatch (fixed)

- **Problem**: add_message called .isoformat(), but the field type is datetime
- **Fix**: removed .isoformat(); the default datetime.now(timezone.utc) is used

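A minimal sketch of the P4 fix follows. The `ChatMessage` dataclass here is a hypothetical reduction of the real model (its actual fields are not shown in this commit); it only illustrates keeping the field as a timezone-aware `datetime` and converting to a string at the serialization boundary.

```python
from dataclasses import dataclass, field
from datetime import datetime, timezone


@dataclass
class ChatMessage:
    """Hypothetical, reduced version of the real ChatMessage model."""
    role: str
    content: str
    # The field holds a datetime; the default is evaluated per instance.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))


def add_message(history: list, role: str, content: str) -> ChatMessage:
    # No .isoformat() here: passing a str into a datetime field is the bug.
    message = ChatMessage(role=role, content=content)
    history.append(message)
    return message


def to_json_dict(message: ChatMessage) -> dict:
    # Convert to ISO 8601 only when leaving the process (storage, HTTP).
    return {
        "role": message.role,
        "content": message.content,
        "timestamp": message.timestamp.isoformat(),
    }
```
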
## Known Issues Left Unchanged

1. tasks.py still uses IntentRouter: it is outside the Portal path, so there is no impact for now
2. Redundant tool descriptions: can be optimized later (detect whether the model supports function calling)
3. Unused variable user_msg in chat.py: a pre-existing lint warning

@@ -0,0 +1,304 @@

# refactor: Simplify the routing architecture with a unified REACT agent loop

status: active
created: 2026-06-16
depth: Standard

---

## Summary

Simplify the current 4-layer routing architecture (HeuristicClassifier → LLM classify → SemanticRouter → IntentRouter) into a minimal routing layer plus a unified REACT agent loop (Hermes-style prompt-based XML tool calling). The intent-prediction layer is removed; the LLM decides on its own inside the agent loop after seeing the full tool descriptions.

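## Problem Frame

The 4-layer routing architecture of the current CostAwareRouter has fundamental design flaws:

1. **Predicting intent in the routing layer is an anti-pattern**: the LLM cannot see tool context in the routing layer, so it inevitably misclassifies (e.g. "查下ip" is classified as direct_agent)
2. **Enumeration can never be complete**: HeuristicClassifier's keyword list cannot cover every colloquial phrasing
3. **Multi-layer routing adds latency**: 3 LLM calls per query (1 for routing + 2 for REACT), with a 3-5 s response time
4. **The two paths are inconsistent**: Portal REST goes through IntentRouter, WebSocket goes through CostAwareRouter
5. **Incompatible tool format**: Bailian Coding (百炼 Coding) does not support native function calling; the model emits `<tool_use>` text but the engine cannot parse it

**Industry validation**: Codex, Trae, Hermes, and OpenClaw have no separate routing layer; a unified agent loop is the industry standard.
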
## Requirements

- R1: Remove the routing-decision functionality of HeuristicClassifier, IntentRouter, and SemanticRouter
- R2: Keep a minimal routing layer (@skill prefix + greeting/chitchat detection)
- R3: Unified REACT agent loop, with full tool descriptions injected into the system prompt
- R4: Prompt-based XML tool calling (`<tool_use>` format), parsed and executed by the backend
- R5: Portal REST and WebSocket share one routing path
- R6: Chat history persistence (Portal ConversationStore → SessionManager)
- R7: Backtest validation: execution-mode accuracy >85%, tool-call success rate >95%, colloquial query success rate >90%
- R8: Performance targets: response time <3 s (simple queries), at most 2 LLM calls per query

---

## Key Technical Decisions

### KTD-1: Adopt Hermes-style prompt-based XML tool calling

**Decision**: the system prompt defines the `<tool_use>` format, the LLM emits XML tags, and the backend parses and executes them.

**Rationale**:
- Bailian Coding (百炼 Coding, qwen3.7-plus) does not support native function calling
- Screenshots confirm that the model already understands the `<tool_use>` format
- Consistent with the Hermes architecture, and model-agnostic

**Alternatives**:
- Native function calling: incompatible with Bailian Coding
- `Action:` format: less structured than XML

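To make the decision concrete, here is a minimal sketch of backend-side parsing. The plan does not specify the payload inside `<tool_use>`, so the JSON body with `name` and `arguments` keys is an assumption, and `parse_tool_calls` is a hypothetical helper rather than the engine's `_parse_text_tool_calls()`.

```python
import json
import re
from dataclasses import dataclass

# Assumed wire format (not confirmed by the source):
#   <tool_use>{"name": "shell", "arguments": {"command": "..."}}</tool_use>
_TOOL_USE_RE = re.compile(r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL)


@dataclass
class ToolCall:
    name: str
    arguments: dict


def parse_tool_calls(text: str) -> list:
    """Extract every well-formed <tool_use> block from an LLM reply."""
    calls = []
    for payload in _TOOL_USE_RE.findall(text):
        try:
            data = json.loads(payload)
        except json.JSONDecodeError:
            continue  # skip malformed blocks instead of failing the turn
        if isinstance(data, dict) and isinstance(data.get("name"), str):
            args = data.get("arguments")
            calls.append(ToolCall(data["name"], args if isinstance(args, dict) else {}))
    return calls
```

A reply with no `<tool_use>` block yields an empty list, which the loop can treat as the final answer.
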
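### KTD-2: Remove intent prediction from the routing layer, keep a minimal rule layer

**Decision**: keep only @skill prefix routing and greeting/chitchat detection; every other query goes to REACT by default.

**Rationale**:
- Intent prediction in the routing layer is far less accurate than the LLM's decision inside the agent loop
- Removing the routing layer saves 1 LLM call (~500 ms, ~1000 tokens)
- Greeting/chitchat detection is a deterministic rule with zero misclassification
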
### KTD-3: Load all tools (phase 1)

**Decision**: load all 21 tools into the system prompt by default; the @skill prefix provides on-demand loading.

**Rationale**:
- Descriptions for 21 tools take about 2000 tokens, an acceptable cost
- Loading everything guarantees the LLM sees every tool, with zero misclassification
- On-demand loading (regex filtering) is left as a phase-2 optimization

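The sketch below shows one way a full tool listing could be rendered into the system prompt and its size roughly checked. `ToolSpec`, `build_tool_use_prompt`, and the four-characters-per-token heuristic are all assumptions for illustration; they are not agentkit's `Tool` class, its prompt builder, or a real tokenizer.

```python
import json
from dataclasses import dataclass, field


@dataclass
class ToolSpec:
    """Hypothetical stand-in for agentkit's Tool (name, description, parameters)."""
    name: str
    description: str
    parameters: dict = field(default_factory=dict)


def build_tool_use_prompt(tools: list) -> str:
    """Render every tool plus the (assumed) <tool_use> calling convention."""
    if not tools:
        return ""
    lines = ["## Available tools", ""]
    for tool in tools:
        lines.append(f"- {tool.name}: {tool.description}")
        lines.append(f"  parameters: {json.dumps(tool.parameters, ensure_ascii=False)}")
    lines += [
        "",
        "To call a tool, reply with:",
        '<tool_use>{"name": "<tool name>", "arguments": {...}}</tool_use>',
    ]
    return "\n".join(lines)


def rough_token_estimate(text: str) -> int:
    # Crude heuristic (about 4 characters per token); use the model's
    # tokenizer for real budgeting.
    return len(text) // 4
```

Measuring the rendered prompt this way gives a quick check on whether the tool list still fits the budget as tools are added.
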
### KTD-4: Keep the other agent architectures as optional modes in skill configuration

**Decision**: ReWOOAgent, ReflexionAgent, and the others are kept and selected through the `execution_mode` field in the skill YAML.

**Rationale**:
- Different scenarios need different execution modes (ReWOO for code generation, Reflexion for retry after failure)
- Existing investment should not be wasted
- Only the routing changes; the execution modes stay the same

---

## Scope Boundaries

### In Scope
- Simplify CostAwareRouter into a minimal routing layer
- Switch ReActEngine to prompt-based tool calling
- Unify Portal REST/WebSocket routing
- Chat history persistence
- E2E backtest and metric validation

### Out of Scope
- Embedding API integration (pending an API key from the user)
- Frontend GUI changes
- Expert Team mode refactor
- Regex filtering layer for on-demand tool loading (phase 2)

### Deferred to Follow-Up Work
- Downgrade SemanticRouter to an optional plugin
- Grouped loading strategy when there are more than 30 tools
- Streaming response optimization (finer-grained SSE chunks)

---

## High-Level Technical Design

### Target Architecture

```
User input
    ↓
SimpleRouter (minimal routing layer, <1 ms)
    ├─ @skill:xxx → load the specified skill's tools → REACT Agent
    ├─ Greeting/chitchat (regex) → DIRECT_CHAT (no tools, fast path)
    └─ Everything else → load all default tools → REACT Agent
    ↓
REACT Agent Loop
    ├─ System prompt: tool descriptions + <tool_use> format instructions
    ├─ LLM decides a tool is needed → emits <tool_use> → parse and execute → Observation → continue
    └─ LLM decides no tool is needed → answers directly → final_answer
```

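The lower half of the diagram can be sketched as a small loop. This is an illustration under stated assumptions, not ReActEngine itself: `llm` is any callable that maps a message list to reply text, `tools` maps names to plain Python callables, and the JSON payload inside `<tool_use>` is an assumed format.

```python
import json
import re

TOOL_USE_RE = re.compile(r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL)


def run_agent_loop(llm, tools: dict, user_input: str, max_steps: int = 5) -> str:
    """Think → act → observe until the LLM replies without a tool call."""
    messages = [{"role": "user", "content": user_input}]
    for _ in range(max_steps):
        reply = llm(messages)
        match = TOOL_USE_RE.search(reply)
        if match is None:
            return reply  # no tool requested: this is the final answer
        messages.append({"role": "assistant", "content": reply})
        try:
            call = json.loads(match.group(1))
            observation = tools[call["name"]](**call.get("arguments", {}))
        except Exception as exc:  # bad JSON, unknown tool, or tool failure
            observation = f"tool error: {exc}"
        # Feed the result back so the next turn can use it.
        messages.append({"role": "user", "content": f"Observation: {observation}"})
    return "stopped: step limit reached"
```

Errors are reported back to the model as observations rather than raised, so a malformed call costs one step instead of the whole request; the step limit bounds the loop if the model never stops calling tools.
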
### Routing Simplification Comparison

| Component | Current | Target |
|-----------|---------|--------|
| CostAwareRouter.route() | 1688 lines, 4 layers | ~200 lines, 1 layer |
| HeuristicClassifier | 310 lines | removed |
| IntentRouter | 206 lines | routing functionality removed |
| SemanticRouter | 224 lines | routing functionality removed |
| _classify_merged | 200 lines | removed |
| _route_layer2 | 210 lines | removed |

---

## Implementation Units

### U1. Create SimpleRouter to replace CostAwareRouter

**Goal**: implement a minimal routing layer that keeps only the @skill prefix and greeting/chitchat detection

**Requirements**: R1, R2

**Dependencies**: none

**Files**:
- `src/agentkit/chat/simple_router.py` (new)
- `src/agentkit/chat/skill_routing.py` (modified: keeps SkillRoutingResult, ExecutionMode, parse_skill_prefix)
- `tests/unit/chat/test_simple_router.py` (new)

**Approach**:
1. Create a `SimpleRouter` class with a `route()` method
2. `route()` logic: @skill prefix → the specified skill; greeting/chitchat regex → DIRECT_CHAT; everything else → REACT
3. Keep the `SkillRoutingResult` dataclass and the `ExecutionMode` enum
4. Keep the `parse_skill_prefix()` function
5. Keep the `_GREETING_RE` and `_CHAT_MODE_RE` regexes

**Test scenarios**:
- An @skill:shell prefix routes to the shell skill
- "你好" routes to DIRECT_CHAT
- "查看当前ip" routes to REACT
- "查下ip" routes to REACT
- "翻译hello" routes to REACT (the LLM decides no tool is needed)
- A complex query with no prefix and no greeting routes to REACT

**Verification**: all tests pass and SimpleRouter.route() returns the correct ExecutionMode

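The test scenarios above lend themselves to a table-driven check. The sketch below is standalone and deliberately reduced: `Mode`, `classify`, the shortened greeting pattern, and the `@skill:` prefix regex are hypothetical stand-ins for `ExecutionMode`, `SimpleRouter.route()`, `_GREETING_RE`, and `parse_skill_prefix()`, and it ignores the registry lookup the real router performs.

```python
import re
from enum import Enum


class Mode(Enum):
    """Hypothetical stand-in for ExecutionMode."""
    SKILL_REACT = "skill_react"
    DIRECT_CHAT = "direct_chat"
    REACT = "react"


# Assumed prefix syntax; the real parse_skill_prefix() may differ.
SKILL_PREFIX_RE = re.compile(r"^@skill:(\w+)\s*")
# Shortened version of the greeting/chitchat patterns, for illustration.
GREETING_RE = re.compile(r"^(你好|hi|hello|hey|谢谢|thanks|bye)\s*[!.。?]*$", re.IGNORECASE)


def classify(text: str) -> Mode:
    stripped = text.strip()
    if SKILL_PREFIX_RE.match(stripped):
        return Mode.SKILL_REACT
    if GREETING_RE.match(stripped):
        return Mode.DIRECT_CHAT
    return Mode.REACT  # default: let the LLM decide inside the agent loop


CASES = [
    ("@skill:shell 查看当前ip", Mode.SKILL_REACT),
    ("你好", Mode.DIRECT_CHAT),
    ("Hello!", Mode.DIRECT_CHAT),
    ("查看当前ip", Mode.REACT),
    ("查下ip", Mode.REACT),
    ("翻译hello", Mode.REACT),
    ("hello, 帮我查一下ip", Mode.REACT),  # greeting plus a request is not a pure greeting
]


def failing_cases() -> list:
    return [(text, classify(text), want) for text, want in CASES if classify(text) != want]
```

Because the patterns are anchored at both ends, a greeting followed by a real request falls through to REACT instead of being swallowed by the fast path.
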
### U2. Switch ReActEngine to prompt-based XML tool calling

**Goal**: inject full tool descriptions and `<tool_use>` format instructions into the ReActEngine system prompt

**Requirements**: R3, R4

**Dependencies**: U1

**Files**:
- `src/agentkit/core/react.py` (modified)
- `tests/unit/core/test_react_tool_format.py` (new)

**Approach**:
1. Add a `_build_tool_use_system_prompt()` method that generates a system prompt containing the tool descriptions and the `<tool_use>` format instructions
2. In `execute_stream()`, use prompt-based mode when the LLM does not support native function calling
3. Make sure `_parse_text_tool_calls()` correctly parses the `<tool_use>` XML format (already implemented)
4. Add the tool description format: each tool includes name, description, and parameters

**Test scenarios**:
- The system prompt contains every tool description
- The `<tool_use>` format is parsed correctly
- When the LLM uses no tool, final_answer is returned directly
- When the LLM uses a tool, it is executed and an observation is returned
- Multi-step tool calls (think → act → observe → think → answer)

**Verification**: a curl test of "查下ip" executes the shell command correctly

### U3. Unify the Portal REST/WebSocket routing path

**Goal**: Portal REST chat and WebSocket use the same SimpleRouter routing logic

**Requirements**: R5

**Dependencies**: U1

**Files**:
- `src/agentkit/server/routes/portal.py` (modified)
- `src/agentkit/server/app.py` (modified: replaces cost_aware_router with simple_router)

**Approach**:
1. Switch `_resolve_for_chat()` to SimpleRouter
2. Switch the WebSocket `portal_websocket()` to SimpleRouter
3. Both paths go through SimpleRouter.route() → REACT agent loop
4. Keep the DIRECT_CHAT fast path

**Test scenarios**:
- REST "查看当前ip" executes shell correctly
- WebSocket "查看当前ip" executes shell correctly
- REST "你好" takes DIRECT_CHAT
- WebSocket "你好" takes DIRECT_CHAT

**Verification**: both curl and frontend tests pass

### U4. Chat history persistence

**Goal**: connect Portal ConversationStore to the backend SessionManager, with file persistence

**Requirements**: R6

**Dependencies**: U3

**Files**:
- `src/agentkit/server/routes/portal.py` (modified)
- `src/agentkit/session/manager.py` (modified: new methods if needed)
- `tests/unit/server/test_portal_persistence.py` (new)

**Approach**:
1. ConversationStore delegates persistence to SessionManager
2. New messages are written to SessionManager synchronously
3. Sessions are restored from SessionManager on load
4. The in-memory cache stays as the hot path

**Test scenarios**:
- A new message can be read back from SessionManager after it is written
- Session history survives a service restart
- Multi-turn conversation context is correct

**Verification**: chat history is still present after restarting the service

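The four approach steps amount to a write-through cache. The sketch below shows that shape under assumptions: `InMemoryBackend` is a hypothetical stand-in for SessionManager (whose real API is not shown in this commit), and the `ConversationStore` here is reduced to two methods.

```python
from collections import defaultdict


class InMemoryBackend:
    """Hypothetical stand-in for SessionManager's persistent storage."""

    def __init__(self) -> None:
        self._data = defaultdict(list)

    def append(self, session_id: str, message: dict) -> None:
        self._data[session_id].append(dict(message))

    def load(self, session_id: str) -> list:
        return [dict(m) for m in self._data[session_id]]


class ConversationStore:
    """Write-through store: the backend is the source of truth, the cache is the hot path."""

    def __init__(self, backend) -> None:
        self._backend = backend
        self._cache = {}

    def history(self, session_id: str) -> list:
        # Cache miss: restore the session from the backend once.
        if session_id not in self._cache:
            self._cache[session_id] = self._backend.load(session_id)
        return self._cache[session_id]

    def add_message(self, session_id: str, role: str, content: str) -> None:
        cached = self.history(session_id)  # warm the cache before writing
        message = {"role": role, "content": content}
        self._backend.append(session_id, message)  # persist synchronously
        cached.append(message)
```

Warming the cache before the write matters: loading after the backend append would pull the new message in once and then append it a second time.
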
### U5. Update E2E backtest cases and metrics

**Goal**: extend the backtest cases to cover colloquial phrasings, and define and track metrics

**Requirements**: R7, R8

**Dependencies**: U1, U2, U3

**Files**:
- `tests/e2e/test_capability_router_direct.py` (modified)
- `tests/e2e/capability_metrics.py` (modified)
- `docs/plans/2026-06-16-005-refactor-routing-architecture-plan.md` (this document)

**Approach**:
1. Update the backtest cases: add colloquial phrasings ("查下ip", "获取ip", "看下ip", etc.)
2. Update the metrics: add response time, LLM call count, and token usage
3. Define targets: execution-mode accuracy >85%, tool-call success rate >95%, colloquial query success rate >90%
4. Run the backtest and record the results

**Test scenarios**:
- A colloquial query ("查下ip") routes to REACT
- Tool-calling queries execute the tool correctly
- Greetings route to DIRECT_CHAT
- Response time <3 s
- LLM calls ≤2

**Verification**: the backtest report shows every metric meets its target

---

## Success Metrics

| Metric | Current | Target | Measurement |
|--------|---------|--------|-------------|
| Execution-mode accuracy | 40% | >85% | E2E backtest |
| Tool-call success rate | 60% | >95% | E2E backtest |
| Colloquial query success rate | 30% | >90% | E2E backtest |
| Response time (simple query) | 3-5 s | <3 s | curl -w "%{time_total}" |
| Response time (tool call) | 5-8 s | <4 s | curl -w "%{time_total}" |
| LLM calls per query | 3 | ≤2 | Log statistics |
| Token usage per query | ~2400 | <1800 | LLM gateway statistics |

---

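## Risks & Mitigations

| Risk | Impact | Mitigation |
|------|--------|------------|
| Bailian Coding (百炼 Coding) does not understand the `<tool_use>` format | Tool calls fail | Model output of `<tool_use>` already verified; fall back to the `Action:` format |
| Full tool descriptions use too many tokens | Slower responses | 21 tools take about 2000 tokens, acceptable; on-demand loading in phase 2 |
| Skill matching is lost after removing the routing layer | A specific skill is not selected | Explicit selection via the @skill prefix; the LLM matches naturally inside the agent loop |
| Chat history migration is incompatible | Old data is lost | Keep old and new formats compatible; migrate gradually |

---
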
## Open Questions

1. When will the Embedding API key be provided? (Downgrading SemanticRouter to an optional plugin depends on this key.)
2. Should CostAwareRouter be kept as an optional mode? (Backward compatibility.)

@@ -0,0 +1,197 @@
"""Simple router — minimal routing layer for unified REACT agent loop.

Replaces the 4-layer CostAwareRouter with a simple approach:
1. @skill:xxx prefix → explicit skill selection
2. Greeting/chitchat regex → DIRECT_CHAT (fast path)
3. Everything else → REACT (LLM decides tool usage in agent loop)

This follows the Hermes/Trae/Codex pattern: no intent prediction layer,
LLM sees full tool descriptions and decides autonomously.
"""

from __future__ import annotations

import logging
import re
from typing import TYPE_CHECKING

from agentkit.chat.skill_routing import (
    ExecutionMode,
    SkillRoutingResult,
    build_skill_system_prompt,
    parse_skill_prefix,
    _resolve_execution_mode,
)

if TYPE_CHECKING:
    from agentkit.skills.registry import SkillRegistry
    from agentkit.tools.base import Tool

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Regex patterns for zero-cost direct chat (no LLM call needed)
# ---------------------------------------------------------------------------

_GREETING_RE = re.compile(
    r"^(你好|hi|hello|hey|嗨|哈喽|早上好|下午好|晚上好|good morning|good afternoon|good evening)"
    r"\s*[!!.。??]*$",
    re.IGNORECASE,
)

_CHAT_MODE_RE = re.compile(
    r"^(谢谢|感谢|thanks|thank you|ok|好的|嗯|对|是|不是|没关系|再见|bye|goodbye)"
    r"\s*[!!.。??]*$",
    re.IGNORECASE,
)

_IDENTITY_RE = re.compile(
    r"^(你是谁|你叫什么|你是什么|你是哪个|who are you|what are you|what's your name"
    r"|介绍一下你自己|自我介绍|你叫啥|你叫什么名字|你的名字)"
    r"\s*[??!!.。]*$",
    re.IGNORECASE,
)


class SimpleRouter:
    """Minimal routing layer: regex fast-path + default REACT.

    Design rationale:
    - No HeuristicClassifier: keyword enumeration can never cover all colloquial expressions
    - No IntentRouter: LLM blind-classification without tool context is unreliable
    - No SemanticRouter: embedding similarity is not intent recognition
    - LLM in the REACT agent loop sees full tool descriptions and decides autonomously
    - This matches Codex/Trae/Hermes architecture: unified agent loop, no routing layer
    """

    def __init__(
        self,
        skill_registry: SkillRegistry | None = None,
        default_tools: list[Tool] | None = None,
        default_system_prompt: str | None = None,
        default_model: str = "default",
        default_agent_name: str = "default",
    ) -> None:
        self._skill_registry = skill_registry
        self._default_tools = default_tools or []
        self._default_system_prompt = default_system_prompt
        self._default_model = default_model
        self._default_agent_name = default_agent_name

    async def route(
        self,
        content: str,
        *,
        skill_registry: SkillRegistry | None = None,
        default_tools: list[Tool] | None = None,
        default_system_prompt: str | None = None,
        default_model: str | None = None,
        default_agent_name: str | None = None,
        session_id: str = "",
        transparency: str = "SILENT",
    ) -> SkillRoutingResult:
        """Route user input to the appropriate execution path.

        Decision tree:
        1. @skill:xxx prefix → explicit skill (SKILL_REACT or skill's execution_mode)
        2. Greeting/chitchat/identity → DIRECT_CHAT (zero-cost fast path)
        3. Everything else → REACT (LLM decides tool usage in agent loop)
        """
        registry = skill_registry or self._skill_registry
        tools = default_tools if default_tools is not None else self._default_tools
        sys_prompt = default_system_prompt if default_system_prompt is not None else self._default_system_prompt
        model = default_model or self._default_model
        agent_name = default_agent_name or self._default_agent_name

        # --- Layer 0: @skill:xxx prefix ---
        explicit_skill, clean_content = parse_skill_prefix(content)
        if explicit_skill and registry is not None:
            result = self._route_explicit_skill(
                explicit_skill, clean_content, registry, model, agent_name
            )
            return result

        # --- Layer 1: Greeting/chitchat/identity regex (<1ms, zero tokens) ---
        stripped = content.strip()
        if self._is_direct_chat(stripped):
            result = SkillRoutingResult(
                clean_content=stripped,
                matched=False,
                match_method="regex_direct",
                match_confidence=1.0,
                agent_name=agent_name,
                model=model,
                system_prompt=sys_prompt,
                tools=[],
                execution_mode=ExecutionMode.DIRECT_CHAT,
            )
            return result

        # --- Default: REACT (LLM decides tool usage) ---
        result = SkillRoutingResult(
            clean_content=stripped,
            matched=False,
            match_method="default_react",
            match_confidence=0.8,
            agent_name=agent_name,
            model=model,
            system_prompt=sys_prompt,
            tools=tools,
            execution_mode=ExecutionMode.REACT,
        )
        return result

    def _route_explicit_skill(
        self,
        skill_name: str,
        clean_content: str,
        registry: SkillRegistry,
        model: str,
        agent_name: str,
    ) -> SkillRoutingResult:
        """Route to an explicitly specified skill via @skill:xxx prefix."""
        try:
            skill = registry.get(skill_name)
        except Exception:
            logger.warning(f"Skill '{skill_name}' not found, falling back to REACT")
            return SkillRoutingResult(
                clean_content=clean_content,
                matched=False,
                match_method="skill_not_found_fallback",
                match_confidence=0.5,
                agent_name=agent_name,
                model=model,
                execution_mode=ExecutionMode.REACT,
            )

        skill_tools = getattr(skill, "tools", []) or []
        skill_config = getattr(skill, "config", skill)  # Skill wraps SkillConfig
        skill_prompt = build_skill_system_prompt(skill_config)
        execution_mode = _resolve_execution_mode(skill_config)

        return SkillRoutingResult(
            clean_content=clean_content,
            matched=True,
            match_method="skill_prefix",
            match_confidence=1.0,
            skill_name=skill_name,
            skill_config=skill,
            skill_tools=skill_tools,
            agent_name=skill_name,
            model=model,
            system_prompt=skill_prompt,
            tools=skill_tools,
            execution_mode=execution_mode,
        )

    @staticmethod
    def _is_direct_chat(text: str) -> bool:
        """Check if the input is a greeting, chitchat, or identity question.

        These are zero-cost direct chat: no tool usage, no ReAct loop needed.
        """
        return bool(
            _GREETING_RE.match(text)
            or _CHAT_MODE_RE.match(text)
            or _IDENTITY_RE.match(text)
        )

@@ -210,7 +210,6 @@ async def resolve_skill_routing(
 "搜索",
 "查找",
 "联网",
-"搜索",
 "search",
 "安装",
 "部署",
@@ -222,6 +221,17 @@ async def resolve_skill_routing(
 "创建",
 "删除",
 "修改",
+"查看",
+"检查",
+"监控",
+"测试",
+"浏览",
+"下载",
+"上传",
+"读取",
+"写入",
+"导出",
+"导入",
 "run",
 "execute",
 "install",
@@ -230,6 +240,16 @@ async def resolve_skill_routing(
 "stop",
 "restart",
 "file",
+"check",
+"monitor",
+"test",
+"browse",
+"download",
+"upload",
+"read",
+"write",
+"export",
+"import",
 ]
 content_lower = clean_content.lower()
 needs_tools = any(h in content_lower for h in tool_hints)
@@ -297,8 +317,10 @@ async def resolve_skill_routing(
 # No skill matched — if we have tools, use ReAct; otherwise direct chat
 result.execution_mode = ExecutionMode.REACT if default_tools else ExecutionMode.DIRECT_CHAT
 
-# Append available tools to system prompt so LLM knows what it can call
-if result.tools:
+# Append available tools to system prompt only when execution mode supports tool calls
+# DIRECT_CHAT mode has no tool execution loop — injecting tool instructions would
+# cause the LLM to output unparseable tool call JSON as plain text
+if result.tools and result.execution_mode != ExecutionMode.DIRECT_CHAT:
     tools_desc = _build_tools_description(result.tools)
     tool_instruction = (
         "\n\n## Tool Usage\n"
@@ -446,6 +468,17 @@ class HeuristicClassifier:
 "接口",
 "调试",
 "重构",
+"查看",
+"检查",
+"监控",
+"测试",
+"浏览",
+"下载",
+"上传",
+"读取",
+"写入",
+"导出",
+"导入",
 }
 
 # English keywords use word-boundary matching (avoids substring false matches such as "profile" matching "file")
@@ -474,6 +507,14 @@ class HeuristicClassifier:
 "javascript",
 "typescript",
 "sql",
+"check",
+"monitor",
+"test",
+"browse",
+"download",
+"upload",
+"export",
+"import",
 }
 
 # Short English words require exact matching (avoids substring false matches)
@@ -1185,6 +1226,32 @@ class CostAwareRouter:
 except Exception as e:
     logger.warning(f"CostAwareRouter Layer 2 org_context.find_best_agent failed: {e}")
 
+# Fallback: high complexity with tools → REACT directly (skip IntentRouter
+# which tends to misclassify tool-needing queries as direct_agent)
+if complexity >= 0.5 and default_tools:
+    result = SkillRoutingResult(
+        clean_content=content,
+        matched=False,
+        match_method="complexity_heuristic",
+        match_confidence=0.7,
+        agent_name=default_agent_name,
+        model=default_model,
+        system_prompt=default_system_prompt,
+        tools=default_tools,
+        complexity=complexity,
+        execution_mode=ExecutionMode.REACT,
+    )
+    if trace is not None:
+        trace.append(
+            {
+                "layer": 2,
+                "method": "complexity_heuristic_react",
+                "complexity": complexity,
+                "reason": "high_complexity_with_tools_skip_intent_router",
+            }
+        )
+    return self._try_team_upgrade(result, content, complexity, trace)
+
 # Fallback: use IntentRouter
 result = await resolve_skill_routing(
     content=content,
@@ -1401,7 +1468,8 @@ class CostAwareRouter:
 except Exception as e:
     logger.warning(f"Intent routing for low-complexity query failed: {e}")
 
-# No semantic or intent match → direct chat
+# No semantic or intent match → use REACT if tools available, otherwise direct chat
+# Low complexity does NOT mean "no tools needed" — e.g. "查看当前ip" needs shell
 result = SkillRoutingResult(
     clean_content=clean_content,
     system_prompt=default_system_prompt,
@@ -1412,7 +1480,9 @@ class CostAwareRouter:
     match_method="low_complexity",
     match_confidence=1.0 - complexity,
     complexity=complexity,
-    execution_mode=ExecutionMode.DIRECT_CHAT,
+    execution_mode=ExecutionMode.REACT
+    if default_tools
+    else ExecutionMode.DIRECT_CHAT,
 )
 trace.append(
     {
@@ -1488,13 +1558,19 @@ class CostAwareRouter:
 )
 
 # Short text fallback: if semantic router returned low confidence
-# and text is short (<20 chars), force LLM classify for better routing
+# and text is short (<20 chars), force LLM classify for better routing.
+# BUT: skip LLM fallback when HeuristicClassifier already detected
+# high-complexity signals (e.g. "查看ip" has "查看" → complexity >= 0.65).
+# In that case the routing outcome is already clear (REACT mode),
+# and an extra LLM call would only waste 1-3 seconds.
 short_text_llm_hint = None
 if (
     skill_hint is None
     and len(clean_content) < 20
     and self._merged_llm_classify
     and self._llm_gateway is not None
+    and complexity
+    < 0.5  # Only trigger LLM fallback for truly ambiguous low-complexity queries
 ):
     short_text_llm_hint = True
     trace.append(
@ -1507,7 +1583,10 @@ class CostAwareRouter:
|
||||||
|
|
||||||
# Medium complexity → merged LLM classify or IntentRouter
|
# Medium complexity → merged LLM classify or IntentRouter
|
||||||
# Short text with no semantic match forces LLM classify
|
# Short text with no semantic match forces LLM classify
|
||||||
if complexity <= 0.7 or short_text_llm_hint:
|
# BUT: if HeuristicClassifier already detected high-complexity signals
|
||||||
|
# (complexity >= 0.5), LLM classify tends to override correct routing
|
||||||
|
# with "direct_agent" — skip it and go straight to IntentRouter
|
||||||
|
if complexity < 0.5 or short_text_llm_hint:
|
||||||
if self._merged_llm_classify and self._llm_gateway is not None:
|
if self._merged_llm_classify and self._llm_gateway is not None:
|
||||||
# Use merged LLM call: complexity + intent in one call
|
# Use merged LLM call: complexity + intent in one call
|
||||||
result = await self._classify_merged(
|
result = await self._classify_merged(
|
||||||
|
|
|
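The short-text fallback gate described in the comments above can be sketched as a standalone predicate. This is a minimal illustration, not the router's actual code: the function name `should_force_llm_classify` and its parameters are hypothetical, and only the 20-character limit and the 0.5 complexity cutoff are taken from the hunk.

```python
def should_force_llm_classify(
    text: str,
    complexity: float,
    has_skill_hint: bool,
    llm_available: bool,
    max_len: int = 20,               # "short text" limit used in the hunk
    complexity_cutoff: float = 0.5,  # cutoff used in the hunk
) -> bool:
    """Return True only for short, low-complexity text with no skill hint.

    Hypothetical helper: it mirrors the shape of the condition in the diff,
    collapsing the two LLM-related checks into a single `llm_available` flag.
    """
    return (
        not has_skill_hint
        and len(text) < max_len
        and llm_available
        and complexity < complexity_cutoff
    )
```

Under these assumptions a query such as "查看ip", scored at 0.65 by the heuristic classifier, skips the extra LLM call, while a short query scored below 0.5 still triggers it.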
||||||
|
|
@ -18,14 +18,14 @@ from agentkit.core.protocol import CancellationToken
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
from agentkit.llm.protocol import LLMResponse
|
from agentkit.llm.protocol import LLMResponse
|
||||||
from agentkit.tools.base import Tool
|
from agentkit.tools.base import Tool
|
||||||
from agentkit.telemetry.tracing import get_tracer, start_span, _OTEL_AVAILABLE
|
from agentkit.telemetry.tracing import start_span, _OTEL_AVAILABLE
|
||||||
from agentkit.telemetry.metrics import (
|
from agentkit.telemetry.metrics import (
|
||||||
agent_request_counter,
|
agent_request_counter,
|
||||||
agent_duration_histogram,
|
agent_duration_histogram,
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from agentkit.core.compressor import CompressionStrategy, ContextCompressor
|
from agentkit.core.compressor import CompressionStrategy
|
||||||
from agentkit.core.trace import TraceRecorder
|
from agentkit.core.trace import TraceRecorder
|
||||||
from agentkit.memory.retriever import MemoryRetriever
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
|
|
||||||
|
|
@ -195,6 +195,15 @@ class ReActEngine:
|
||||||
else:
|
else:
|
||||||
logger.info("ReActEngine executing with NO tools")
|
logger.info("ReActEngine executing with NO tools")
|
||||||
|
|
||||||
|
# Prompt-based tool calling: inject tool descriptions into system prompt
|
||||||
|
# when tools are available, so LLM can use <tool_use> format even if
|
||||||
|
# the provider doesn't support native function calling.
|
||||||
|
if tools and system_prompt is not None:
|
||||||
|
tool_desc = self._build_tool_use_prompt(tools)
|
||||||
|
system_prompt = f"{system_prompt}\n\n{tool_desc}"
|
||||||
|
elif tools and system_prompt is None:
|
||||||
|
system_prompt = self._build_tool_use_prompt(tools)
|
||||||
|
|
||||||
# Telemetry: record agent request
|
# Telemetry: record agent request
|
||||||
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
|
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
|
||||||
|
|
||||||
|
|
@ -651,6 +660,15 @@ class ReActEngine:
|
||||||
else:
|
else:
|
||||||
logger.info("ReActEngine executing with NO tools")
|
logger.info("ReActEngine executing with NO tools")
|
||||||
|
|
||||||
|
# Prompt-based tool calling: inject tool descriptions into system prompt
|
||||||
|
# when tools are available, so LLM can use <tool_use> format even if
|
||||||
|
# the provider doesn't support native function calling.
|
||||||
|
if tools and system_prompt is not None:
|
||||||
|
tool_desc = self._build_tool_use_prompt(tools)
|
||||||
|
system_prompt = f"{system_prompt}\n\n{tool_desc}"
|
||||||
|
elif tools and system_prompt is None:
|
||||||
|
system_prompt = self._build_tool_use_prompt(tools)
|
||||||
|
|
||||||
# Telemetry: record agent request
|
# Telemetry: record agent request
|
||||||
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
|
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
|
||||||
|
|
||||||
|
|
@ -1141,6 +1159,47 @@ class ReActEngine:
|
||||||
schemas.append(schema)
|
schemas.append(schema)
|
||||||
return schemas
|
return schemas
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_tool_use_prompt(tools: list[Tool]) -> str:
|
||||||
|
"""Build prompt-based tool calling instructions for LLMs that don't
|
||||||
|
support native function calling (e.g., Bailian Coding, Qwen).
|
||||||
|
|
||||||
|
Instructs the LLM to use <tool_use> XML format for tool invocation.
|
||||||
|
This follows the Hermes pattern: model-agnostic prompt-based tool calling.
|
||||||
|
"""
|
||||||
|
tool_descriptions = []
|
||||||
|
for tool in tools:
|
||||||
|
params_desc = ""
|
||||||
|
if tool.input_schema:
|
||||||
|
props = tool.input_schema.get("properties", {})
|
||||||
|
required = tool.input_schema.get("required", [])
|
||||||
|
param_parts = []
|
||||||
|
for pname, pinfo in props.items():
|
||||||
|
ptype = pinfo.get("type", "string")
|
||||||
|
pdesc = pinfo.get("description", "")
|
||||||
|
req_flag = " (required)" if pname in required else ""
|
||||||
|
param_parts.append(f" - {pname}: {ptype}{req_flag} — {pdesc}")
|
||||||
|
if param_parts:
|
||||||
|
params_desc = "\n".join(param_parts)
|
||||||
|
tool_descriptions.append(
|
||||||
|
f"- {tool.name}: {tool.description}\n{params_desc}"
|
||||||
|
)
|
||||||
|
|
||||||
|
tools_text = "\n\n".join(tool_descriptions)
|
||||||
|
return (
|
||||||
|
"## 可用工具\n\n"
|
||||||
|
"你可以使用以下工具来完成任务。当需要调用工具时,使用以下格式:\n\n"
|
||||||
|
"<tool_use>\n"
|
||||||
|
'{"name": "工具名", "arguments": {"参数名": "参数值"}}\n'
|
||||||
|
"</tool_use>\n\n"
|
||||||
|
"重要规则:\n"
|
||||||
|
"1. 每次只调用一个工具\n"
|
||||||
|
"2. 等待工具返回结果后再决定下一步\n"
|
||||||
|
"3. 如果不需要工具就能回答,直接回答即可\n"
|
||||||
|
"4. 不要在回答中重复工具的输出,而是基于结果给出有用的总结\n\n"
|
||||||
|
f"工具列表:\n\n{tools_text}"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_response_from_stream(
|
def _build_response_from_stream(
|
||||||
content: str,
|
content: str,
|
||||||
|
|
@ -1339,9 +1398,10 @@ class ReActEngine:
|
||||||
def _parse_text_tool_calls(self, content: str) -> list[dict[str, Any]]:
|
def _parse_text_tool_calls(self, content: str) -> list[dict[str, Any]]:
|
||||||
"""Parse tool-call patterns from text.
|
"""Parse tool-call patterns from text.
|
||||||
|
|
||||||
Supports two formats:
|
Supported formats:
|
||||||
1. Action: tool_name(args)
|
1. Action: tool_name(args)
|
||||||
2. ```tool\\n{"name": "...", "arguments": {...}}\\n```
|
2. ```tool\n{"name": "...", "arguments": {...}}\n```
|
||||||
|
3. <tool_use>\n{"name": "...", "arguments": {...}}\n</tool_use>
|
||||||
"""
|
"""
|
||||||
calls: list[dict[str, Any]] = []
|
calls: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
|
@ -1376,4 +1436,35 @@ class ReActEngine:
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
logger.warning(f"Failed to parse tool call from text: {json_str}")
|
logger.warning(f"Failed to parse tool call from text: {json_str}")
|
||||||
|
|
||||||
|
if calls:
|
||||||
|
return calls
|
||||||
|
|
||||||
|
# Format 3: <tool_use>\n{"name": "...", "arguments": {...}}\n</tool_use>
|
||||||
|
# Handles the tool-call format that models such as Anthropic/Qwen emit as plain text
|
||||||
|
tool_use_pattern = re.compile(
|
||||||
|
r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL
|
||||||
|
)
|
||||||
|
for match in tool_use_pattern.finditer(content):
|
||||||
|
json_str = match.group(1).strip()
|
||||||
|
try:
|
||||||
|
parsed = json.loads(json_str)
|
||||||
|
name = parsed.get("name", "")
|
||||||
|
arguments = parsed.get("arguments", {})
|
||||||
|
if name:
|
||||||
|
calls.append({"name": name, "arguments": arguments})
|
||||||
|
except (json.JSONDecodeError, TypeError, AttributeError):
|
||||||
|
# Try XML-like inner tags: <name>x</name><arguments>{...}</arguments>
|
||||||
|
name_match = re.search(r"<name>\s*(.*?)\s*</name>", json_str, re.DOTALL)
|
||||||
|
args_match = re.search(r"<arguments>\s*(.*?)\s*</arguments>", json_str, re.DOTALL)
|
||||||
|
if name_match:
|
||||||
|
name = name_match.group(1).strip()
|
||||||
|
args_str = args_match.group(1).strip() if args_match else "{}"
|
||||||
|
try:
|
||||||
|
arguments = json.loads(args_str)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
arguments = {"raw": args_str}
|
||||||
|
calls.append({"name": name, "arguments": arguments})
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to parse tool_use block: {json_str[:200]}")
|
||||||
|
|
||||||
return calls
|
return calls
|
||||||
|
|
|
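The `<tool_use>` parsing added in this file can be exercised in isolation with a sketch like the one below. It is an approximation under stated assumptions: `parse_tool_use_blocks` and `TOOL_USE_RE` are hypothetical names, logging is omitted, and a JSON body that is not an object is skipped rather than handled as the engine would.

```python
import json
import re
from typing import Any

# Same pattern as in the hunk: capture everything between the tags.
TOOL_USE_RE = re.compile(r"<tool_use>\s*(.*?)\s*</tool_use>", re.DOTALL)


def parse_tool_use_blocks(content: str) -> list[dict[str, Any]]:
    """Extract tool calls from <tool_use> blocks (illustrative sketch)."""
    calls: list[dict[str, Any]] = []
    for match in TOOL_USE_RE.finditer(content):
        body = match.group(1).strip()
        try:
            parsed = json.loads(body)
        except (json.JSONDecodeError, TypeError):
            parsed = None
        if isinstance(parsed, dict):
            # JSON form: {"name": ..., "arguments": {...}}
            name = parsed.get("name", "")
            if name:
                calls.append({"name": name, "arguments": parsed.get("arguments", {})})
            continue
        # Fallback: XML-like inner tags <name>..</name><arguments>..</arguments>
        name_match = re.search(r"<name>\s*(.*?)\s*</name>", body, re.DOTALL)
        if not name_match:
            continue  # unparseable block; the real code logs a warning here
        args_match = re.search(r"<arguments>\s*(.*?)\s*</arguments>", body, re.DOTALL)
        args_str = args_match.group(1).strip() if args_match else "{}"
        try:
            arguments = json.loads(args_str)
        except json.JSONDecodeError:
            arguments = {"raw": args_str}  # keep the raw text, as the hunk does
        calls.append({"name": name_match.group(1).strip(), "arguments": arguments})
    return calls
```

For example, `parse_tool_use_blocks('<tool_use>{"name": "get_ip", "arguments": {}}</tool_use>')` yields one call named `get_ip`; the tool name here is made up for illustration.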
||||||
|
|
@ -149,18 +149,6 @@ async def lifespan(app: FastAPI):
|
||||||
# Start MCP servers if configured
|
# Start MCP servers if configured
|
||||||
mcp_manager = getattr(app.state, "mcp_manager", None)
|
mcp_manager = getattr(app.state, "mcp_manager", None)
|
||||||
|
|
||||||
# Build semantic router index after skill registry is populated
|
|
||||||
semantic_router = getattr(
|
|
||||||
getattr(app.state, "cost_aware_router", None), "_semantic_router", None
|
|
||||||
)
|
|
||||||
if semantic_router is not None:
|
|
||||||
try:
|
|
||||||
await semantic_router.build_index(app.state.skill_registry)
|
|
||||||
logger.info(
|
|
||||||
f"Semantic router index built with {len(app.state.skill_registry.list_skills())} skills"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to build semantic router index: {e}")
|
|
||||||
if mcp_manager is not None:
|
if mcp_manager is not None:
|
||||||
await mcp_manager.start_all()
|
await mcp_manager.start_all()
|
||||||
|
|
||||||
|
|
@ -586,6 +574,14 @@ def create_app(
|
||||||
app.state.quality_gate = QualityGate()
|
app.state.quality_gate = QualityGate()
|
||||||
app.state.output_standardizer = OutputStandardizer()
|
app.state.output_standardizer = OutputStandardizer()
|
||||||
|
|
||||||
|
# Initialize SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
|
||||||
|
from agentkit.chat.simple_router import SimpleRouter
|
||||||
|
|
||||||
|
simple_router = SimpleRouter(
|
||||||
|
skill_registry=app.state.skill_registry,
|
||||||
|
)
|
||||||
|
app.state.simple_router = simple_router
|
||||||
|
|
||||||
# Initialize OrganizationContext from AgentPool + SkillRegistry
|
# Initialize OrganizationContext from AgentPool + SkillRegistry
|
||||||
from agentkit.org.context import OrganizationContext
|
from agentkit.org.context import OrganizationContext
|
||||||
|
|
||||||
|
|
@ -605,37 +601,39 @@ def create_app(
|
||||||
alignment_guard = AlignmentGuard(config=alignment_config, llm_gateway=app.state.llm_gateway)
|
alignment_guard = AlignmentGuard(config=alignment_config, llm_gateway=app.state.llm_gateway)
|
||||||
app.state.alignment_guard = alignment_guard
|
app.state.alignment_guard = alignment_guard
|
||||||
|
|
||||||
# Initialize CostAwareRouter
|
# CostAwareRouter is no longer used by portal/chat routes (replaced by SimpleRouter).
|
||||||
from agentkit.chat.skill_routing import CostAwareRouter
|
# It is kept on app.state for backward compatibility with any external consumers.
|
||||||
|
# To re-enable, set router.legacy_cost_aware_router: true in agentkit.yaml.
|
||||||
auction_enabled = False
|
|
||||||
if server_config and hasattr(server_config, "marketplace") and server_config.marketplace:
|
|
||||||
auction_enabled = server_config.marketplace.get("auction_enabled", False)
|
|
||||||
|
|
||||||
# Initialize semantic router if configured
|
|
||||||
semantic_router = None
|
|
||||||
router_conf = server_config.router if server_config and server_config.router else {}
|
router_conf = server_config.router if server_config and server_config.router else {}
|
||||||
if router_conf.get("semantic", {}).get("enabled"):
|
if router_conf.get("legacy_cost_aware_router"):
|
||||||
try:
|
from agentkit.chat.skill_routing import CostAwareRouter
|
||||||
from agentkit.chat.semantic_router import SemanticRouter
|
|
||||||
|
|
||||||
semantic_router = SemanticRouter(
|
auction_enabled = False
|
||||||
embedder=app.state.llm_gateway._embedder,
|
if server_config and hasattr(server_config, "marketplace") and server_config.marketplace:
|
||||||
similarity_high=router_conf["semantic"].get("similarity_high", 0.85),
|
auction_enabled = server_config.marketplace.get("auction_enabled", False)
|
||||||
similarity_low=router_conf["semantic"].get("similarity_low", 0.6),
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to initialize semantic router: {e}")
|
|
||||||
|
|
||||||
cost_aware_router = CostAwareRouter(
|
semantic_router = None
|
||||||
llm_gateway=app.state.llm_gateway,
|
if router_conf.get("semantic", {}).get("enabled"):
|
||||||
org_context=org_context,
|
try:
|
||||||
auction_enabled=auction_enabled,
|
from agentkit.chat.semantic_router import SemanticRouter
|
||||||
classifier=router_conf.get("classifier", "heuristic"),
|
|
||||||
merged_llm_classify=router_conf.get("merged_llm_classify", True),
|
semantic_router = SemanticRouter(
|
||||||
semantic_router=semantic_router,
|
embedder=app.state.llm_gateway._embedder,
|
||||||
)
|
similarity_high=router_conf["semantic"].get("similarity_high", 0.85),
|
||||||
app.state.cost_aware_router = cost_aware_router
|
similarity_low=router_conf["semantic"].get("similarity_low", 0.6),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to initialize semantic router: {e}")
|
||||||
|
|
||||||
|
cost_aware_router = CostAwareRouter(
|
||||||
|
llm_gateway=app.state.llm_gateway,
|
||||||
|
org_context=org_context,
|
||||||
|
auction_enabled=auction_enabled,
|
||||||
|
classifier=router_conf.get("classifier", "heuristic"),
|
||||||
|
merged_llm_classify=router_conf.get("merged_llm_classify", True),
|
||||||
|
semantic_router=semantic_router,
|
||||||
|
)
|
||||||
|
app.state.cost_aware_router = cost_aware_router
|
||||||
# Initialize task store from config
|
# Initialize task store from config
|
||||||
ts_config = server_config.task_store if server_config else {}
|
ts_config = server_config.task_store if server_config else {}
|
||||||
# Merge CLI overrides from AGENTKIT_TASK_STORE env var
|
# Merge CLI overrides from AGENTKIT_TASK_STORE env var
|
||||||
|
|
@ -677,6 +675,10 @@ def create_app(
|
||||||
)
|
)
|
||||||
app.state.session_manager = SessionManager(store=session_store)
|
app.state.session_manager = SessionManager(store=session_store)
|
||||||
|
|
||||||
|
# Inject SessionManager into Portal's ConversationStore for persistence
|
||||||
|
from agentkit.server.routes.portal import _conversation_store
|
||||||
|
_conversation_store.set_session_manager(app.state.session_manager)
|
||||||
|
|
||||||
# Initialize evolution store if configured
|
# Initialize evolution store if configured
|
||||||
if server_config and hasattr(server_config, "evolution") and server_config.evolution:
|
if server_config and hasattr(server_config, "evolution") and server_config.evolution:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from typing import Any
|
||||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Request
|
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from agentkit.chat.skill_routing import ExecutionMode
|
||||||
from agentkit.core.protocol import CancellationToken
|
from agentkit.core.protocol import CancellationToken
|
||||||
from agentkit.core.react import ReActEngine
|
from agentkit.core.react import ReActEngine
|
||||||
from agentkit.session.manager import SessionManager
|
from agentkit.session.manager import SessionManager
|
||||||
|
|
@ -211,7 +212,7 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
||||||
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
|
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
|
||||||
|
|
||||||
# Append user message
|
# Append user message
|
||||||
user_msg = await sm.append_message(
|
await sm.append_message(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
role=MessageRole.USER,
|
role=MessageRole.USER,
|
||||||
content=request.content,
|
content=request.content,
|
||||||
|
|
@ -440,11 +441,9 @@ async def _handle_chat_message(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle a user message: append to session, execute Agent, stream events.
|
"""Handle a user message: append to session, execute Agent, stream events.
|
||||||
|
|
||||||
When skills are registered, attempts to route the user's message to a
|
Uses SimpleRouter for minimal routing: @skill prefix + greeting regex + REACT.
|
||||||
matching skill via IntentRouter. If a skill is matched, the skill's
|
|
||||||
prompt, tools, and execution_mode are used instead of the default agent's.
|
|
||||||
"""
|
"""
|
||||||
from agentkit.chat.skill_routing import resolve_skill_routing
|
from agentkit.chat.simple_router import SimpleRouter
|
||||||
|
|
||||||
# Resolve Agent first (needed for default tools/prompt)
|
# Resolve Agent first (needed for default tools/prompt)
|
||||||
pool = websocket.app.state.agent_pool
|
pool = websocket.app.state.agent_pool
|
||||||
|
|
@ -463,20 +462,17 @@ async def _handle_chat_message(
|
||||||
default_system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None)
|
default_system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None)
|
||||||
default_model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default")
|
default_model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default")
|
||||||
|
|
||||||
# Resolve skill routing using shared module
|
# Resolve skill routing using SimpleRouter
|
||||||
skill_registry = getattr(websocket.app.state, "skill_registry", None)
|
skill_registry = getattr(websocket.app.state, "skill_registry", None)
|
||||||
intent_router = getattr(websocket.app.state, "intent_router", None)
|
simple_router: SimpleRouter = websocket.app.state.simple_router
|
||||||
|
|
||||||
routing = await resolve_skill_routing(
|
routing = await simple_router.route(
|
||||||
content=content,
|
content=content,
|
||||||
skill_registry=skill_registry,
|
skill_registry=skill_registry,
|
||||||
intent_router=intent_router,
|
|
||||||
default_tools=default_tools,
|
default_tools=default_tools,
|
||||||
default_system_prompt=default_system_prompt,
|
default_system_prompt=default_system_prompt,
|
||||||
default_model=default_model,
|
default_model=default_model,
|
||||||
default_agent_name=agent.name,
|
default_agent_name=agent.name,
|
||||||
agent_tool_registry=agent._tool_registry if agent._tool_registry else None,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Debug: log tools that will be passed to ReActEngine
|
# Debug: log tools that will be passed to ReActEngine
|
||||||
|
|
@ -504,6 +500,45 @@ async def _handle_chat_message(
|
||||||
# Get full conversation history
|
# Get full conversation history
|
||||||
chat_messages = await sm.get_chat_messages(session_id)
|
chat_messages = await sm.get_chat_messages(session_id)
|
||||||
|
|
||||||
|
# Handle DIRECT_CHAT: direct LLM call, no ReAct loop
|
||||||
|
if routing.execution_mode == ExecutionMode.DIRECT_CHAT:
|
||||||
|
direct_messages = []
|
||||||
|
if routing.system_prompt:
|
||||||
|
direct_messages.append({"role": "system", "content": routing.system_prompt})
|
||||||
|
direct_messages.extend(chat_messages)
|
||||||
|
try:
|
||||||
|
response = await websocket.app.state.llm_gateway.chat(
|
||||||
|
messages=direct_messages,
|
||||||
|
model=routing.model or "default",
|
||||||
|
agent_name=agent.name,
|
||||||
|
task_type="chat",
|
||||||
|
)
|
||||||
|
final_content = response.content or ""
|
||||||
|
if final_content:
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "final_answer",
|
||||||
|
"content": final_content,
|
||||||
|
"is_final": True,
|
||||||
|
})
|
||||||
|
await sm.append_message(
|
||||||
|
session_id=session_id,
|
||||||
|
role=MessageRole.ASSISTANT,
|
||||||
|
content=final_content,
|
||||||
|
agent_name=agent.name,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}")
|
||||||
|
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
|
||||||
|
# currently fall back to REACT with a warning.
|
||||||
|
if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
|
||||||
|
logger.warning(
|
||||||
|
f"Execution mode {routing.execution_mode.value} not yet supported "
|
||||||
|
"in chat WebSocket, falling back to REACT"
|
||||||
|
)
|
||||||
|
|
||||||
# Execute Agent with streaming
|
# Execute Agent with streaming
|
||||||
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
|
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
|
||||||
react_engine = getattr(agent, "_react_engine", None)
|
react_engine = getattr(agent, "_react_engine", None)
|
||||||
|
|
|
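The routing order named in the docstring above (@skill prefix, then greeting regex, then REACT) might look roughly like the sketch below. The real `agentkit.chat.simple_router` module is not shown in this diff, so everything here is an assumption: the `Mode` enum, both regexes, and the `route` signature are hypothetical stand-ins.

```python
import re
from enum import Enum
from typing import Optional


class Mode(Enum):
    # Hypothetical stand-in for ExecutionMode; only three modes are modelled.
    DIRECT_CHAT = "direct_chat"
    REACT = "react"
    SKILL_REACT = "skill_react"


# Assumed patterns; the greeting list in the real router may differ.
SKILL_PREFIX_RE = re.compile(r"^@(\w[\w\-]*)\s*(.*)$", re.DOTALL)
GREETING_RE = re.compile(
    r"^\s*(hi|hello|hey|你好|您好|早上好|晚上好|谢谢)[\s!！。.~]*$", re.IGNORECASE
)


def route(content: str, known_skills: set[str]) -> tuple[Mode, Optional[str]]:
    """Pick an execution mode without calling an LLM (illustrative sketch)."""
    prefix = SKILL_PREFIX_RE.match(content.strip())
    if prefix and prefix.group(1) in known_skills:
        return Mode.SKILL_REACT, prefix.group(1)  # explicit @skill wins
    if GREETING_RE.match(content):
        return Mode.DIRECT_CHAT, None  # small talk needs no tools
    return Mode.REACT, None  # everything else: let the LLM decide on tools
```

Because every branch is a regex or set lookup, a router of this shape makes no LLM call, which is consistent with the 0 calls per query reported in the backtest.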
||||||
|
|
@ -6,7 +6,6 @@ import os
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
|
|
@ -20,15 +19,16 @@ from fastapi import (
|
||||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from agentkit.core.protocol import TaskMessage
|
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||||
from agentkit.core.react import ReActEngine
|
from agentkit.core.react import ReActEngine
|
||||||
from agentkit.chat.skill_routing import ExecutionMode
|
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||||||
from agentkit.router.intent import IntentRouter
|
from agentkit.chat.simple_router import SimpleRouter
|
||||||
from agentkit.server.routes.evolution_dashboard import (
|
from agentkit.server.routes.evolution_dashboard import (
|
||||||
_experiences as _dashboard_experiences,
|
_experiences as _dashboard_experiences,
|
||||||
DashboardExperience,
|
DashboardExperience,
|
||||||
_broadcast_event as _broadcast_dashboard_event,
|
_broadcast_event as _broadcast_dashboard_event,
|
||||||
)
|
)
|
||||||
|
from agentkit.session.manager import SessionManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -88,9 +88,21 @@ class Conversation:
|
||||||
|
|
||||||
|
|
||||||
class ConversationStore:
|
class ConversationStore:
|
||||||
def __init__(self, max_conversations: int = 1000):
|
"""In-memory conversation store with optional SessionManager persistence.
|
||||||
|
|
||||||
|
When a session_manager is provided, messages are also persisted via
|
||||||
|
SessionManager (which supports file/redis backends). On startup,
|
||||||
|
conversations can be restored from SessionManager.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_conversations: int = 1000, session_manager: SessionManager | None = None):
|
||||||
self._conversations: dict[str, Conversation] = {}
|
self._conversations: dict[str, Conversation] = {}
|
||||||
self._max = max_conversations
|
self._max = max_conversations
|
||||||
|
self._session_manager = session_manager
|
||||||
|
|
||||||
|
def set_session_manager(self, sm: SessionManager | None) -> None:
|
||||||
|
"""Set or update the session manager for persistence."""
|
||||||
|
self._session_manager = sm
|
||||||
|
|
||||||
def get_or_create(self, conversation_id: str | None = None) -> Conversation:
|
def get_or_create(self, conversation_id: str | None = None) -> Conversation:
|
||||||
if conversation_id and conversation_id in self._conversations:
|
if conversation_id and conversation_id in self._conversations:
|
||||||
|
|
@ -107,15 +119,37 @@ class ConversationStore:
|
||||||
del self._conversations[oldest_id]
|
del self._conversations[oldest_id]
|
||||||
return conv
|
return conv
|
||||||
|
|
||||||
def add_message(
|
async def add_message(
|
||||||
self, conversation_id: str, role: str, content: str, metadata: dict | None = None
|
self, conversation_id: str, role: str, content: str, metadata: dict | None = None
|
||||||
) -> ChatMessage:
|
) -> ChatMessage:
|
||||||
|
"""Add a message to conversation, with optional persistence."""
|
||||||
conv = self._conversations.get(conversation_id)
|
conv = self._conversations.get(conversation_id)
|
||||||
if conv is None:
|
if conv is None:
|
||||||
raise KeyError(f"Conversation '{conversation_id}' not found")
|
raise KeyError(f"Conversation '{conversation_id}' not found")
|
||||||
msg = ChatMessage(role=role, content=content, metadata=metadata or {})
|
msg = ChatMessage(
|
||||||
|
role=role,
|
||||||
|
content=content,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
conv.messages.append(msg)
|
conv.messages.append(msg)
|
||||||
conv.updated_at = datetime.now(timezone.utc)
|
conv.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
# Persist to SessionManager if available
|
||||||
|
if self._session_manager is not None:
|
||||||
|
try:
|
||||||
|
from agentkit.session.models import MessageRole
|
||||||
|
|
||||||
|
sm = self._session_manager
|
||||||
|
role_enum = MessageRole.USER if role == "user" else MessageRole.ASSISTANT
|
||||||
|
await sm.append_message(
|
||||||
|
session_id=conversation_id,
|
||||||
|
role=role_enum,
|
||||||
|
content=content,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to persist message to SessionManager: {e}")
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def get_history(self, conversation_id: str, limit: int = 50) -> list[ChatMessage]:
|
def get_history(self, conversation_id: str, limit: int = 50) -> list[ChatMessage]:
|
||||||
|
|
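The best-effort persistence in `add_message` above follows a write-to-memory-first pattern: the in-memory append always succeeds, and a backend failure is swallowed. A minimal sketch of that pattern is shown here; `MemoryStore`, the `Persist` callback type, and `persist_failures` are hypothetical and far simpler than `ConversationStore` and `SessionManager`.

```python
import asyncio
from typing import Awaitable, Callable, Optional

# Hypothetical persistence hook: (conversation_id, role, content) -> None
Persist = Callable[[str, str, str], Awaitable[None]]


class MemoryStore:
    """In-memory message store with optional best-effort persistence."""

    def __init__(self, persist: Optional[Persist] = None) -> None:
        self.messages: dict[str, list[tuple[str, str]]] = {}
        self._persist = persist
        self.persist_failures = 0  # counted here instead of logged

    async def add_message(self, conversation_id: str, role: str, content: str) -> None:
        # Unlike the diff, unknown conversations are created on the fly.
        self.messages.setdefault(conversation_id, []).append((role, content))
        if self._persist is None:
            return
        try:
            await self._persist(conversation_id, role, content)
        except Exception:
            # Persistence is best effort: the in-memory copy stays valid.
            self.persist_failures += 1
```

One consequence of this design is that the in-memory state and the backend can diverge after a failure, so restoring from the backend on startup may miss messages that were only kept in memory.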
@ -257,58 +291,68 @@ class CapabilitiesResponse(BaseModel):
|
||||||
|
|
||||||
async def _resolve_for_chat(
|
async def _resolve_for_chat(
|
||||||
request: ChatRequest, req: Request
|
request: ChatRequest, req: Request
|
||||||
) -> tuple[Any, Any, str | None, str | None, float | None]:
|
) -> tuple[ConfigDrivenAgent | None, SkillRoutingResult | None, str | None, str | None, float | None]:
|
||||||
"""Resolve agent and skill for a chat request.
|
"""Resolve agent and routing for a chat request via SimpleRouter.
|
||||||
|
|
||||||
Returns (agent, skill, matched_skill_name, routing_method, confidence).
|
Returns (agent, routing_result, matched_skill_name, routing_method, confidence).
|
||||||
"""
|
"""
|
||||||
pool = req.app.state.agent_pool
|
pool = req.app.state.agent_pool
|
||||||
skill_registry = req.app.state.skill_registry
|
skill_registry = req.app.state.skill_registry
|
||||||
intent_router: IntentRouter = req.app.state.intent_router
|
simple_router: SimpleRouter = req.app.state.simple_router
|
||||||
|
|
||||||
matched_skill_name: str | None = None
|
matched_skill_name: str | None = None
|
||||||
routing_method: str | None = None
|
routing_method: str | None = None
|
||||||
confidence: float | None = None
|
confidence: float | None = None
|
||||||
|
|
||||||
if request.skill_name:
|
# Get default tools and system prompt
|
||||||
# Use specified skill directly
|
default_tools = []
|
||||||
try:
|
default_system_prompt = None
|
||||||
skill = skill_registry.get(request.skill_name)
|
default_agent = pool.get_agent("default")
|
||||||
except Exception:
|
if default_agent is not None:
|
||||||
raise HTTPException(
|
default_tools = default_agent.get_tools()
|
||||||
status_code=404,
|
default_system_prompt = (
|
||||||
detail=f"Skill '{request.skill_name}' not found",
|
getattr(default_agent, "_system_prompt", None)
|
||||||
)
|
or default_agent.get_system_prompt()
|
||||||
matched_skill_name = request.skill_name
|
|
||||||
routing_method = "direct"
|
|
||||||
confidence = 1.0
|
|
||||||
agent = pool.get_agent(request.skill_name)
|
|
||||||
if agent is None:
|
|
||||||
agent = await pool.create_agent_from_skill(request.skill_name)
|
|
||||||
return agent, skill, matched_skill_name, routing_method, confidence
|
|
||||||
|
|
||||||
# Use IntentRouter
|
|
||||||
all_skills = skill_registry.list_skills()
|
|
||||||
if not all_skills:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="No skills available. Please register skills first.",
|
|
||||||
)
|
)
|
||||||
try:
|
else:
|
||||||
routing_result = await intent_router.route(
|
all_skills = skill_registry.list_skills()
|
||||||
{"query": request.message, "sources": request.sources}, all_skills
|
for skill in all_skills:
|
||||||
)
|
agent = pool.get_agent(skill.name)
|
||||||
matched_skill_name = routing_result.matched_skill
|
if agent is not None:
|
||||||
routing_method = routing_result.method
|
default_tools = agent.get_tools()
|
||||||
confidence = routing_result.confidence
|
default_system_prompt = (
|
||||||
skill = skill_registry.get(matched_skill_name)
|
getattr(agent, "_system_prompt", None) or agent.get_system_prompt()
|
||||||
agent = pool.get_agent(matched_skill_name)
|
)
|
||||||
if agent is None:
|
break
|
||||||
agent = await pool.create_agent_from_skill(matched_skill_name)
|
|
||||||
except (ValueError, RuntimeError) as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
|
|
||||||
return agent, skill, matched_skill_name, routing_method, confidence
|
# Route via SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
|
||||||
|
routing_result = await simple_router.route(
|
||||||
|
content=request.message,
|
||||||
|
skill_registry=skill_registry,
|
||||||
|
default_tools=default_tools,
|
||||||
|
default_system_prompt=default_system_prompt,
|
||||||
|
default_model="default",
|
||||||
|
default_agent_name="default",
|
||||||
|
)
|
||||||
|
|
||||||
|
matched_skill_name = routing_result.skill_name or routing_result.agent_name
|
||||||
|
routing_method = routing_result.match_method
|
||||||
|
confidence = routing_result.match_confidence
|
||||||
|
|
||||||
|
# Get or create agent based on routing result
|
||||||
|
if routing_result.matched and routing_result.skill_name:
|
||||||
|
agent = pool.get_agent(routing_result.skill_name)
|
||||||
|
if agent is None:
|
||||||
|
agent = await pool.create_agent_from_skill(routing_result.skill_name)
|
||||||
|
else:
|
||||||
|
agent = pool.get_agent("default")
|
||||||
|
if agent is None:
|
||||||
|
# Fallback: try to create from first available skill
|
||||||
|
all_skills = skill_registry.list_skills()
|
||||||
|
if all_skills:
|
||||||
|
agent = await pool.create_agent_from_skill(all_skills[0].name)
|
||||||
|
|
||||||
|
return agent, routing_result, matched_skill_name, routing_method, confidence
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -318,95 +362,68 @@ async def _resolve_for_chat(
|
||||||
|
|
||||||
@router.post("/portal/chat", response_model=ChatResponse)
|
@router.post("/portal/chat", response_model=ChatResponse)
|
||||||
async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
async def chat(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
||||||
"""Send a chat message and get a response with intent routing."""
|
"""Send a chat message and get a response with SimpleRouter routing."""
|
||||||
agent, skill, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
|
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
|
||||||
|
|
||||||
# Create or reuse conversation
|
# Create or reuse conversation
|
||||||
conv = _conversation_store.get_or_create(request.conversation_id)
|
conv = _conversation_store.get_or_create(request.conversation_id)
|
||||||
_conversation_store.add_message(conv.id, "user", request.message)
|
await _conversation_store.add_message(conv.id, "user", request.message)
|
||||||
|
|
||||||
# Build task and execute
|
llm_gateway = req.app.state.llm_gateway
|
||||||
task = TaskMessage(
|
|
||||||
task_id=str(uuid.uuid4()),
|
|
||||||
agent_name=agent.name,
|
|
||||||
task_type=agent.agent_type,
|
|
||||||
priority=0,
|
|
||||||
input_data={"query": request.message, "sources": request.sources},
|
|
||||||
callback_url=None,
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
task_result = await agent.execute(task)
|
task_id = str(uuid.uuid4())
|
||||||
|
response_text = ""
|
||||||
|
|
||||||
# Extract response text
|
if routing_result is not None and routing_result.execution_mode == ExecutionMode.DIRECT_CHAT:
|
||||||
if task_result.output_data:
|
# DIRECT_CHAT: direct LLM call, no ReAct loop (same as WebSocket path)
|
||||||
if isinstance(task_result.output_data, dict):
|
chat_messages = []
|
||||||
response_text = (
|
if routing_result.system_prompt:
|
||||||
task_result.output_data.get("result")
|
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
|
||||||
or task_result.output_data.get("output")
|
chat_messages.append({"role": "user", "content": request.message})
|
||||||
or json.dumps(task_result.output_data, ensure_ascii=False)
|
# Inject conversation history
|
||||||
)
|
history_msgs = _build_history_messages(conv.id)
|
||||||
else:
|
for hm in history_msgs:
|
||||||
response_text = str(task_result.output_data)
|
chat_messages.insert(-1, hm)
|
||||||
elif task_result.error_message:
|
response = await llm_gateway.chat(
|
||||||
response_text = task_result.error_message
|
messages=chat_messages,
|
||||||
|
model=routing_result.model or "default",
|
||||||
|
agent_name="default",
|
||||||
|
task_type="chat",
|
||||||
|
)
|
||||||
|
response_text = response.content or ""
|
||||||
else:
|
else:
|
||||||
response_text = ""
|
# REACT / SKILL_REACT / REWOO / REFLEXION / PLAN_EXEC / TEAM_COLLAB
|
||||||
|
# Advanced modes (REWOO, REFLEXION, PLAN_EXEC, TEAM_COLLAB) currently
|
||||||
|
# fall back to REACT with a warning. Full integration is tracked separately.
|
||||||
|
if routing_result is not None and routing_result.execution_mode not in (
|
||||||
|
ExecutionMode.REACT,
|
||||||
|
ExecutionMode.SKILL_REACT,
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"Execution mode {routing_result.execution_mode.value} not yet supported "
|
||||||
|
f"in portal REST, falling back to REACT"
|
||||||
|
)
|
||||||
|
|
||||||
_conversation_store.add_message(conv.id, "assistant", response_text)
|
|
||||||
|
|
||||||
return ChatResponse(
|
|
||||||
conversation_id=conv.id,
|
|
||||||
message=response_text,
|
|
||||||
matched_skill=matched_skill,
|
|
||||||
routing_method=routing_method,
|
|
||||||
confidence=confidence,
|
|
||||||
task_id=task.task_id,
|
|
||||||
status="completed",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/portal/chat/stream")
|
|
||||||
async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
|
||||||
"""Stream chat responses via SSE."""
|
|
||||||
from sse_starlette.sse import EventSourceResponse
|
|
||||||
|
|
||||||
agent, skill, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
|
|
||||||
|
|
||||||
# Create or reuse conversation
|
|
||||||
conv = _conversation_store.get_or_create(request.conversation_id)
|
|
||||||
_conversation_store.add_message(conv.id, "user", request.message)
|
|
||||||
|
|
||||||
async def event_generator():
|
|
||||||
react_config = agent.get_react_config()
|
react_config = agent.get_react_config()
|
||||||
# Reuse agent's ReActEngine if available (aligned with chat.py pattern)
|
|
||||||
react_engine = getattr(agent, "_react_engine", None)
|
react_engine = getattr(agent, "_react_engine", None)
|
||||||
if react_engine is None:
|
if react_engine is None:
|
||||||
react_engine = ReActEngine(
|
react_engine = ReActEngine(
|
||||||
llm_gateway=req.app.state.llm_gateway,
|
llm_gateway=llm_gateway,
|
||||||
max_steps=react_config["max_steps"],
|
max_steps=react_config["max_steps"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
react_engine.reset()
|
react_engine.reset()
|
||||||
|
|
||||||
messages = [{"role": "user", "content": request.message}]
|
messages = [{"role": "user", "content": request.message}]
|
||||||
|
# Inject conversation history
|
||||||
|
history_msgs = _build_history_messages(conv.id)
|
||||||
|
for hm in reversed(history_msgs):
|
||||||
|
messages.insert(0, hm)
|
||||||
tools = agent.get_tools()
|
tools = agent.get_tools()
|
||||||
model = agent.get_model()
|
model = agent.get_model()
|
||||||
system_prompt = getattr(agent, "_system_prompt", None) or agent.get_system_prompt()
|
system_prompt = getattr(agent, "_system_prompt", None) or agent.get_system_prompt()
|
||||||
timeout_seconds = react_config["timeout_seconds"]
|
timeout_seconds = react_config["timeout_seconds"]
|
||||||
|
|
||||||
# Send routing info as first event
|
|
||||||
yield {
|
|
||||||
"event": "routing",
|
|
||||||
"data": json.dumps(
|
|
||||||
{
|
|
||||||
"skill": matched_skill,
|
|
||||||
"method": routing_method,
|
|
||||||
"confidence": confidence,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
collected_output: list[str] = []
|
collected_output: list[str] = []
|
||||||
try:
|
try:
|
||||||
async for event in react_engine.execute_stream(
|
async for event in react_engine.execute_stream(
|
||||||
|
|
@ -419,27 +436,134 @@ async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(
|
||||||
):
|
):
|
||||||
if event.event_type == "final_answer":
|
if event.event_type == "final_answer":
|
||||||
collected_output.append(event.data.get("output", ""))
|
collected_output.append(event.data.get("output", ""))
|
||||||
yield {
|
|
||||||
"event": event.event_type,
|
|
||||||
"data": json.dumps(
|
|
||||||
{
|
|
||||||
"step": event.step,
|
|
||||||
"data": event.data,
|
|
||||||
"timestamp": event.timestamp,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield {
|
response_text = f"Error: {e}"
|
||||||
"event": "error",
|
else:
|
||||||
"data": json.dumps({"error": str(e)}),
|
response_text = "".join(collected_output) if collected_output else ""
|
||||||
}
|
|
||||||
return
|
|
||||||
|
|
||||||
# Save assistant response to conversation
|
await _conversation_store.add_message(conv.id, "assistant", response_text)
|
||||||
response_text = "".join(collected_output) if collected_output else ""
|
|
||||||
if response_text:
|
return ChatResponse(
|
||||||
_conversation_store.add_message(conv.id, "assistant", response_text)
|
conversation_id=conv.id,
|
||||||
|
message=response_text,
|
||||||
|
matched_skill=matched_skill,
|
||||||
|
routing_method=routing_method,
|
||||||
|
confidence=confidence,
|
||||||
|
task_id=task_id,
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/portal/chat/stream")
|
||||||
|
async def chat_stream(request: ChatRequest, req: Request, _auth: None = Depends(_verify_api_key)):
|
||||||
|
"""Stream chat responses via SSE with CostAwareRouter routing."""
|
||||||
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
|
agent, routing_result, matched_skill, routing_method, confidence = await _resolve_for_chat(request, req)
|
||||||
|
|
||||||
|
# Create or reuse conversation
|
||||||
|
conv = _conversation_store.get_or_create(request.conversation_id)
|
||||||
|
await _conversation_store.add_message(conv.id, "user", request.message)
|
||||||
|
|
||||||
|
llm_gateway = req.app.state.llm_gateway
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
# Send routing info as first event
|
||||||
|
yield {
|
||||||
|
"event": "routing",
|
||||||
|
"data": json.dumps(
|
||||||
|
{
|
||||||
|
"skill": matched_skill,
|
||||||
|
"method": routing_method,
|
||||||
|
"confidence": confidence,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
if routing_result is not None and routing_result.execution_mode == ExecutionMode.DIRECT_CHAT:
|
||||||
|
# DIRECT_CHAT: direct LLM call, no ReAct loop
|
||||||
|
chat_messages = []
|
||||||
|
if routing_result.system_prompt:
|
||||||
|
chat_messages.append({"role": "system", "content": routing_result.system_prompt})
|
||||||
|
chat_messages.append({"role": "user", "content": request.message})
|
||||||
|
history_msgs = _build_history_messages(conv.id)
|
||||||
|
for hm in history_msgs:
|
||||||
|
chat_messages.insert(-1, hm)
|
||||||
|
response = await llm_gateway.chat(
|
||||||
|
messages=chat_messages,
|
||||||
|
model=routing_result.model or "default",
|
||||||
|
agent_name="default",
|
||||||
|
task_type="chat",
|
||||||
|
)
|
||||||
|
response_text = response.content or ""
|
||||||
|
if response_text:
|
||||||
|
await _conversation_store.add_message(conv.id, "assistant", response_text)
|
||||||
|
yield {
|
||||||
|
"event": "final_answer",
|
||||||
|
"data": json.dumps(
|
||||||
|
{"step": 0, "data": {"output": response_text}, "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# REACT / SKILL_REACT / REWOO / REFLEXION / PLAN_EXEC / TEAM_COLLAB
|
||||||
|
# Advanced modes fall back to REACT with a warning.
|
||||||
|
if routing_result is not None and routing_result.execution_mode not in (
|
||||||
|
ExecutionMode.REACT,
|
||||||
|
ExecutionMode.SKILL_REACT,
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"Execution mode {routing_result.execution_mode.value} not yet supported "
|
||||||
|
f"in portal SSE, falling back to REACT"
|
||||||
|
)
|
||||||
|
|
||||||
|
react_config = agent.get_react_config()
|
||||||
|
react_engine = getattr(agent, "_react_engine", None)
|
||||||
|
if react_engine is None:
|
||||||
|
react_engine = ReActEngine(
|
||||||
|
llm_gateway=llm_gateway,
|
||||||
|
max_steps=react_config["max_steps"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
react_engine.reset()
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": request.message}]
|
||||||
|
tools = agent.get_tools()
|
||||||
|
model = agent.get_model()
|
||||||
|
system_prompt = getattr(agent, "_system_prompt", None) or agent.get_system_prompt()
|
||||||
|
timeout_seconds = react_config["timeout_seconds"]
|
||||||
|
|
||||||
|
collected_output: list[str] = []
|
||||||
|
try:
|
||||||
|
async for event in react_engine.execute_stream(
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
model=model,
|
||||||
|
agent_name=agent.name,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
timeout_seconds=timeout_seconds,
|
||||||
|
):
|
||||||
|
if event.event_type == "final_answer":
|
||||||
|
collected_output.append(event.data.get("output", ""))
|
||||||
|
yield {
|
||||||
|
"event": event.event_type,
|
||||||
|
"data": json.dumps(
|
||||||
|
{
|
||||||
|
"step": event.step,
|
||||||
|
"data": event.data,
|
||||||
|
"timestamp": event.timestamp,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
yield {
|
||||||
|
"event": "error",
|
||||||
|
"data": json.dumps({"error": str(e)}),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
response_text = "".join(collected_output) if collected_output else ""
|
||||||
|
if response_text:
|
||||||
|
await _conversation_store.add_message(conv.id, "assistant", response_text)
|
||||||
|
|
||||||
return EventSourceResponse(event_generator())
|
return EventSourceResponse(event_generator())
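The hunks above inject conversation history in two ways: the DIRECT_CHAT branch calls `insert(-1, hm)` for each history item, while the REST REACT branch iterates `reversed(history_msgs)` and calls `insert(0, hm)`. Both aim to place earlier turns before the current user message in chronological order. The sketch below illustrates that equivalence, assuming history is a plain list of role/content dicts; the real return type of `_build_history_messages` is not visible in this diff, and the helper names here are hypothetical.

```python
def inject_before_last(messages, history):
    """Insert each history item just before the final (current user) message."""
    out = list(messages)
    for hm in history:
        out.insert(-1, hm)  # lands before the last element, preserving order
    return out


def inject_at_front(messages, history):
    """Prepend history; iterate in reverse so the oldest turn ends up first."""
    out = list(messages)
    for hm in reversed(history):
        out.insert(0, hm)
    return out


# Hypothetical sample data, for illustration only.
h1 = {"role": "user", "content": "q1"}
h2 = {"role": "assistant", "content": "a1"}
system = {"role": "system", "content": "sys"}
current = {"role": "user", "content": "q2"}

direct_chat_messages = inject_before_last([system, current], [h1, h2])
react_messages = inject_at_front([current], [h1, h2])
```

In both results the current user message stays last, and the system prompt (when present) stays first.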
|
||||||
|
|
||||||
|
|
@ -568,7 +692,15 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
msg_type = msg.get("type")
|
msg_type = msg.get("type")
|
||||||
|
|
||||||
if msg_type == "cancel":
|
if msg_type == "cancel":
|
||||||
await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "result",
|
||||||
|
"data": {
|
||||||
|
"status": "cancelled",
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if msg_type == "ping":
|
if msg_type == "ping":
|
||||||
|
|
@ -591,7 +723,7 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
await websocket.send_json({"type": "connected", "conversation_id": conv.id})
|
await websocket.send_json({"type": "connected", "conversation_id": conv.id})
|
||||||
|
|
||||||
# Add user message to conversation
|
# Add user message to conversation
|
||||||
_conversation_store.add_message(conv.id, "user", message_text)
|
await _conversation_store.add_message(conv.id, "user", message_text)
|
||||||
start_time = datetime.now(timezone.utc)
|
start_time = datetime.now(timezone.utc)
|
||||||
|
|
||||||
async def _record_experience(
|
async def _record_experience(
|
||||||
|
|
@ -621,28 +753,25 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to record experience: {e}")
|
logger.warning(f"Failed to record experience: {e}")
|
||||||
|
|
||||||
# Unified routing via CostAwareRouter (handles Layer 0/1/2)
|
# Unified routing via SimpleRouter (minimal: @skill prefix + greeting regex + REACT)
|
||||||
pool = websocket.app.state.agent_pool
|
pool = websocket.app.state.agent_pool
|
||||||
skill_registry = websocket.app.state.skill_registry
|
skill_registry = websocket.app.state.skill_registry
|
||||||
llm_gateway = websocket.app.state.llm_gateway
|
llm_gateway = websocket.app.state.llm_gateway
|
||||||
intent_router: IntentRouter = websocket.app.state.intent_router
|
simple_router: SimpleRouter = websocket.app.state.simple_router
|
||||||
cost_aware_router = websocket.app.state.cost_aware_router
|
|
||||||
|
|
||||||
all_skills = skill_registry.list_skills()
|
all_skills = skill_registry.list_skills()
|
||||||
|
|
||||||
# Get default tools for CostAwareRouter routing (only if default skill exists)
|
# Get default tools for SimpleRouter routing
|
||||||
default_tools = []
|
default_tools = []
|
||||||
default_system_prompt = None
|
default_system_prompt = None
|
||||||
default_agent = pool.get_agent("default")
|
default_agent = pool.get_agent("default")
|
||||||
if default_agent is not None:
|
if default_agent is not None:
|
||||||
default_tools = default_agent.get_tools()
|
default_tools = default_agent.get_tools()
|
||||||
# Prefer _system_prompt (memory-injected) over get_system_prompt() (template)
|
|
||||||
default_system_prompt = (
|
default_system_prompt = (
|
||||||
getattr(default_agent, "_system_prompt", None)
|
getattr(default_agent, "_system_prompt", None)
|
||||||
or default_agent.get_system_prompt()
|
or default_agent.get_system_prompt()
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Fallback to first available skill's tools
|
|
||||||
for skill in all_skills:
|
for skill in all_skills:
|
||||||
agent = pool.get_agent(skill.name)
|
agent = pool.get_agent(skill.name)
|
||||||
if agent is not None:
|
if agent is not None:
|
||||||
|
|
@ -652,17 +781,14 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Route via CostAwareRouter (Layer 0/1/2)
|
# Route via SimpleRouter (minimal routing: @skill prefix + greeting regex + REACT)
|
||||||
routing_result = await cost_aware_router.route(
|
routing_result = await simple_router.route(
|
||||||
content=message_text,
|
content=message_text,
|
||||||
skill_registry=skill_registry,
|
skill_registry=skill_registry,
|
||||||
intent_router=intent_router,
|
|
||||||
default_tools=default_tools,
|
default_tools=default_tools,
|
||||||
default_system_prompt=default_system_prompt,
|
default_system_prompt=default_system_prompt,
|
||||||
default_model=model_override or "default",
|
default_model=model_override or "default",
|
||||||
default_agent_name="default",
|
default_agent_name="default",
|
||||||
session_id=conv.id,
|
|
||||||
transparency="SILENT",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
|
|
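The comments in this hunk describe SimpleRouter as "@skill prefix + greeting regex + REACT". The implementation itself is not part of this chunk, so the following is only a rough sketch of that three-step decision. The `Mode` and `Route` types, the regular expressions, and the behaviour for unknown skill names are all assumptions made for illustration, not the project's actual code.

```python
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Mode(Enum):  # hypothetical stand-in for ExecutionMode
    DIRECT_CHAT = "direct_chat"
    REACT = "react"
    SKILL_REACT = "skill_react"


@dataclass
class Route:  # hypothetical stand-in for SkillRoutingResult
    mode: Mode
    skill_name: Optional[str] = None
    content: str = ""


# Assumed patterns; the real greeting list is not shown in the diff.
SKILL_PREFIX = re.compile(r"^@skill:(\w+)\s*(.*)$", re.DOTALL)
GREETING = re.compile(r"^(你好|hello|hi|谢谢|你是谁)[\s!！。.?？]*$", re.IGNORECASE)


def route(content: str, known_skills: set[str]) -> Route:
    text = content.strip()
    # Step 1: explicit @skill:<name> prefix wins when the skill is registered.
    m = SKILL_PREFIX.match(text)
    if m and m.group(1) in known_skills:
        return Route(Mode.SKILL_REACT, m.group(1), m.group(2))
    # Step 2: short greetings and chitchat skip the tool loop entirely.
    if GREETING.match(text):
        return Route(Mode.DIRECT_CHAT, content=text)
    # Step 3: everything else goes to REACT, where the LLM decides on tools.
    return Route(Mode.REACT, content=text)
```

Because no step calls an LLM, a router shaped like this adds no token cost at the routing layer, which matches the "0 calls per query" figure in the backtest report.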
@ -698,11 +824,15 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
)
|
)
|
||||||
# Store assistant reply for multi-turn context continuity
|
# Store assistant reply for multi-turn context continuity
|
||||||
if response.content:
|
if response.content:
|
||||||
_conversation_store.add_message(conv.id, "assistant", response.content)
|
await _conversation_store.add_message(conv.id, "assistant", response.content)
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"type": "result",
|
"type": "result",
|
||||||
"data": {"status": "completed", "content": response.content},
|
"data": {
|
||||||
|
"status": "completed",
|
||||||
|
"content": response.content,
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
await _record_experience(
|
await _record_experience(
|
||||||
|
|
@ -713,7 +843,17 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# REACT or SKILL_REACT: agent execution
|
# REACT / SKILL_REACT / REWOO / REFLEXION / PLAN_EXEC / TEAM_COLLAB
|
||||||
|
# Advanced modes fall back to REACT with a warning.
|
||||||
|
if routing_result.execution_mode not in (
|
||||||
|
ExecutionMode.REACT,
|
||||||
|
ExecutionMode.SKILL_REACT,
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"Execution mode {routing_result.execution_mode.value} not yet supported "
|
||||||
|
f"in portal WebSocket, falling back to REACT"
|
||||||
|
)
|
||||||
|
|
||||||
agent_name = routing_result.agent_name or "default"
|
agent_name = routing_result.agent_name or "default"
|
||||||
agent = pool.get_agent(agent_name)
|
agent = pool.get_agent(agent_name)
|
||||||
if agent is None:
|
if agent is None:
|
||||||
|
|
@ -748,11 +888,15 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
)
|
)
|
||||||
# Store assistant reply for multi-turn context continuity
|
# Store assistant reply for multi-turn context continuity
|
||||||
if response.content:
|
if response.content:
|
||||||
_conversation_store.add_message(conv.id, "assistant", response.content)
|
await _conversation_store.add_message(conv.id, "assistant", response.content)
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
"type": "result",
|
"type": "result",
|
||||||
"data": {"status": "completed", "content": response.content},
|
"data": {
|
||||||
|
"status": "completed",
|
||||||
|
"content": response.content,
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
await _record_experience(
|
await _record_experience(
|
||||||
|
|
@ -817,10 +961,18 @@ async def portal_websocket(websocket: WebSocket):
|
||||||
|
|
||||||
response_text = "".join(collected_output) if collected_output else ""
|
response_text = "".join(collected_output) if collected_output else ""
|
||||||
if response_text:
|
if response_text:
|
||||||
_conversation_store.add_message(conv.id, "assistant", response_text)
|
await _conversation_store.add_message(conv.id, "assistant", response_text)
|
||||||
|
|
||||||
outcome = "success" if response_text else "failure"
|
outcome = "success" if response_text else "failure"
|
||||||
await websocket.send_json({"type": "result", "data": {"message": response_text}})
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "result",
|
||||||
|
"data": {
|
||||||
|
"message": response_text,
|
||||||
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
await _record_experience(
|
await _record_experience(
|
||||||
routing_result.skill_name or "agent",
|
routing_result.skill_name or "agent",
|
||||||
message_text,
|
message_text,
|
||||||
|
|
|
||||||
|
|
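The same "advanced mode falls back to REACT with a warning" block appears three times in this file's diff (REST, SSE, and WebSocket). One way to express it once is a small helper, sketched below. The enum is a local stand-in: only `direct_chat`, `react`, and `skill_react` values are confirmed by the tests in this commit, and the remaining values plus the helper name `effective_mode` are assumptions.

```python
import logging
from enum import Enum

logger = logging.getLogger(__name__)


class ExecutionMode(Enum):  # hypothetical stand-in for agentkit's enum
    DIRECT_CHAT = "direct_chat"
    REACT = "react"
    SKILL_REACT = "skill_react"
    REWOO = "rewoo"            # value assumed
    REFLEXION = "reflexion"    # value assumed
    PLAN_EXEC = "plan_exec"    # value assumed
    TEAM_COLLAB = "team_collab"  # value assumed


_AGENT_MODES = frozenset({ExecutionMode.REACT, ExecutionMode.SKILL_REACT})


def effective_mode(mode: ExecutionMode, transport: str) -> ExecutionMode:
    """Return the mode to run in the non-DIRECT_CHAT branch.

    Modes the portal cannot run yet are logged and downgraded to REACT.
    """
    if mode in _AGENT_MODES:
        return mode
    logger.warning(
        "Execution mode %s not yet supported in portal %s, falling back to REACT",
        mode.value,
        transport,
    )
    return ExecutionMode.REACT
```

A call such as `effective_mode(routing_result.execution_mode, "SSE")` would then replace each inline `if ... not in (...)` block, keeping the log message consistent across transports.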
@ -0,0 +1,264 @@
|
"""E2E Agent Capability Tests — SimpleRouter Backtest (Real LLM).

Tests SimpleRouter.route() using real LLM configuration loaded from
agentkit.yaml. Records full SkillRoutingResult for precise analysis.

Key differences from old CostAwareRouter backtest:
- No HeuristicClassifier complexity scoring
- No IntentRouter LLM classification
- No SemanticRouter embedding matching
- SimpleRouter: @skill prefix + greeting regex + default REACT
"""

import asyncio
import os
from pathlib import Path

import pytest

from agentkit.chat.simple_router import SimpleRouter
from agentkit.chat.skill_routing import ExecutionMode
from agentkit.server.app import _build_llm_gateway, _build_skill_registry
from agentkit.server.config import ServerConfig
from agentkit.skills.registry import SkillRegistry


# ═══════════════════════════════════════════════════════════════════════════
# Test cases — covering all known problem scenarios
# ═══════════════════════════════════════════════════════════════════════════

ROUTING_TEST_CASES = [
    # --- Greeting/Chitchat → DIRECT_CHAT ---
    {"id": "greeting_cn", "input": "你好", "expected_mode": "direct_chat"},
    {"id": "greeting_en", "input": "hello", "expected_mode": "direct_chat"},
    {"id": "chitchat_thanks", "input": "谢谢", "expected_mode": "direct_chat"},
    {"id": "identity_who", "input": "你是谁", "expected_mode": "direct_chat"},

    # --- Tool-requiring queries → REACT ---
    # These are the core problem scenarios that CostAwareRouter failed on
    {"id": "colloquial_ip_1", "input": "查下ip", "expected_mode": "react"},
    {"id": "colloquial_ip_2", "input": "查看当前ip", "expected_mode": "react"},
    {"id": "colloquial_ip_3", "input": "获取ip地址", "expected_mode": "react"},
    {"id": "colloquial_ip_4", "input": "看下ip", "expected_mode": "react"},
    {"id": "colloquial_ip_5", "input": "帮我查一下ip", "expected_mode": "react"},
    {"id": "tool_search", "input": "搜索golang教程", "expected_mode": "react"},
    {"id": "tool_shell", "input": "执行ls命令", "expected_mode": "react"},
    {"id": "tool_file", "input": "读一下配置文件", "expected_mode": "react"},
    {"id": "tool_monitor", "input": "检查服务状态", "expected_mode": "react"},
    {"id": "tool_download", "input": "下载这个文件", "expected_mode": "react"},

    # --- Translation/knowledge → REACT (LLM decides no tool needed) ---
    {"id": "translation", "input": "翻译hello为中文", "expected_mode": "react"},
    {"id": "knowledge", "input": "什么是机器学习", "expected_mode": "react"},
    {"id": "summarize", "input": "帮我总结一下这段话", "expected_mode": "react"},

    # --- Complex queries → REACT ---
    {"id": "complex_analysis", "input": "帮我分析一下这个数据并生成报告", "expected_mode": "react"},
    {"id": "complex_code", "input": "重构这个函数使其更高效", "expected_mode": "react"},
    {"id": "complex_multi", "input": "搜索最新的AI论文并总结关键发现", "expected_mode": "react"},

    # --- @skill prefix → SKILL_REACT ---
    {"id": "skill_prefix_shell", "input": "@skill:react_agent 查看当前ip", "expected_mode": "skill_react"},
]

# Paraphrase consistency test cases — same intent, different expressions
PARAPHRASE_CASES = [
    {
        "id": "ip_check_variants",
        "original": "查看当前ip",
        "paraphrases": ["查下ip", "获取ip地址", "看下ip", "帮我查一下ip", "ip是什么"],
        "expected_mode": "react",
    },
    {
        "id": "search_variants",
        "original": "搜索golang教程",
        "paraphrases": ["搜一下golang教程", "找下golang学习资料", "帮我搜golang入门"],
        "expected_mode": "react",
    },
]


# ═══════════════════════════════════════════════════════════════════════════
# Real component initialization
# ═══════════════════════════════════════════════════════════════════════════


def _find_config_path() -> str | None:
    candidates = [
        os.environ.get("AGENTKIT_CONFIG", ""),
        str(Path.cwd() / "agentkit.yaml"),
        str(Path.home() / ".agentkit" / "agentkit.yaml"),
    ]
    for path in candidates:
        if path and Path(path).is_file():
            return path
    return None


def _build_real_components() -> tuple[SimpleRouter, SkillRegistry]:
    config_path = _find_config_path()
    if not config_path:
        pytest.skip("No agentkit.yaml found")

    env_path = Path(config_path).parent / ".env"
    if env_path.exists():
        try:
            from dotenv import load_dotenv
            load_dotenv(env_path)
        except ImportError:
            with open(env_path) as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith("#") and "=" in line:
                        key, _, value = line.partition("=")
                        os.environ.setdefault(key.strip(), value.strip().strip("'\""))

    server_config = ServerConfig.from_yaml(config_path)

    if not server_config.has_llm_provider():
        dashscope_key = os.environ.get("DASHSCOPE_API_KEY", "")
        if dashscope_key:
            for name, pconf in server_config.llm_config.providers.items():
                if not pconf.api_key:
                    pconf.api_key = dashscope_key
                    if not pconf.base_url:
                        if dashscope_key.startswith("sk-sp-"):
                            pconf.base_url = "https://coding.dashscope.aliyuncs.com/v1"
                        else:
                            pconf.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
                    break

    if not server_config.has_llm_provider():
        pytest.skip("No LLM provider with valid API key")

    skill_registry = _build_skill_registry(server_config)
    router = SimpleRouter(skill_registry=skill_registry)

    return router, skill_registry


_cached_components: tuple[SimpleRouter, SkillRegistry] | None = None


def _get_components() -> tuple[SimpleRouter, SkillRegistry]:
    global _cached_components
    if _cached_components is None:
        _cached_components = _build_real_components()
    return _cached_components


# ═══════════════════════════════════════════════════════════════════════════
# Test classes
# ═══════════════════════════════════════════════════════════════════════════


@pytest.mark.e2e_capability
class TestSimpleRouterBasic:
    """Test SimpleRouter basic routing: greeting → DIRECT_CHAT, others → REACT."""

    @pytest.mark.parametrize(
        "case",
        ROUTING_TEST_CASES,
        ids=[c["id"] for c in ROUTING_TEST_CASES],
    )
    def test_routing(self, case: dict):
        router, skill_registry = _get_components()
        result = asyncio.run(
            router.route(
                content=case["input"],
                skill_registry=skill_registry,
                default_tools=["shell", "search", "file_read"],
            )
        )
        actual_mode = result.execution_mode.value
        expected_mode = case["expected_mode"]
        assert actual_mode == expected_mode, (
            f"'{case['input']}': expected {expected_mode}, got {actual_mode} "
            f"(method={result.match_method}, confidence={result.match_confidence})"
        )


@pytest.mark.e2e_capability
class TestSimpleRouterParaphraseConsistency:
    """Test that paraphrased inputs route to the same execution mode."""

    @pytest.mark.parametrize(
        "case",
        PARAPHRASE_CASES,
        ids=[c["id"] for c in PARAPHRASE_CASES],
    )
    def test_paraphrase_consistency(self, case: dict):
        router, skill_registry = _get_components()
        expected_mode = case["expected_mode"]

        # Test original
        result = asyncio.run(
            router.route(
                content=case["original"],
                skill_registry=skill_registry,
                default_tools=["shell", "search", "file_read"],
            )
        )
        assert result.execution_mode.value == expected_mode, (
            f"Original '{case['original']}': expected {expected_mode}, got {result.execution_mode.value}"
        )

        # Test all paraphrases
        for para in case["paraphrases"]:
            result = asyncio.run(
                router.route(
                    content=para,
                    skill_registry=skill_registry,
                    default_tools=["shell", "search", "file_read"],
                )
            )
            assert result.execution_mode.value == expected_mode, (
                f"Paraphrase '{para}': expected {expected_mode}, got {result.execution_mode.value}"
            )


@pytest.mark.e2e_capability
class TestSimpleRouterMetrics:
    """Compute and report routing accuracy metrics."""

    def test_accuracy_report(self):
        """Run all test cases and compute accuracy metrics."""
        router, skill_registry = _get_components()
        total = len(ROUTING_TEST_CASES)
        correct = 0
        results = []

        for case in ROUTING_TEST_CASES:
            result = asyncio.run(
                router.route(
                    content=case["input"],
                    skill_registry=skill_registry,
                    default_tools=["shell", "search", "file_read"],
                )
            )
            actual_mode = result.execution_mode.value
            is_correct = actual_mode == case["expected_mode"]
            if is_correct:
                correct += 1
            results.append({
                "id": case["id"],
                "input": case["input"],
                "expected": case["expected_mode"],
                "actual": actual_mode,
                "method": result.match_method,
                "correct": is_correct,
            })

        accuracy = correct / total * 100
        print(f"\n{'='*60}")
        print(f"SimpleRouter Accuracy Report")
        print(f"{'='*60}")
        print(f"Total: {total}, Correct: {correct}, Accuracy: {accuracy:.1f}%")
        print(f"{'-'*60}")
        for r in results:
            status = "✓" if r["correct"] else "✗"
            print(f"  {status} {r['id']}: '{r['input']}' → {r['actual']} (expected {r['expected']})")
        print(f"{'='*60}")

        # Assert minimum accuracy threshold
        assert accuracy >= 85.0, f"Accuracy {accuracy:.1f}% is below 85% threshold"
|
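`test_accuracy_report` prints a single overall accuracy figure, whereas the backtest report in this commit also lists accuracy per category. The sketch below shows one way such a breakdown could be computed. It assumes each result dict carries a `category` key, which the `ROUTING_TEST_CASES` entries above do not have; that key, the function name, and the sample data are hypothetical.

```python
from collections import defaultdict


def accuracy_by_category(results):
    """Return {category: accuracy_percent} for dicts with 'category' and 'correct'."""
    totals = defaultdict(int)
    passed = defaultdict(int)
    for r in results:
        totals[r["category"]] += 1
        passed[r["category"]] += int(r["correct"])
    return {cat: passed[cat] / totals[cat] * 100 for cat in totals}


# Hypothetical sample results, not taken from a real run.
sample_results = [
    {"category": "greeting", "correct": True},
    {"category": "greeting", "correct": True},
    {"category": "tool", "correct": True},
    {"category": "tool", "correct": False},
]
category_report = accuracy_by_category(sample_results)
```

Tagging each test case with a category at definition time would let the report table be regenerated from the test run rather than maintained by hand.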
@ -0,0 +1,223 @@
|
"""Unit tests for SimpleRouter — minimal routing layer."""

from __future__ import annotations

import pytest

from agentkit.chat.simple_router import SimpleRouter
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

class MockSkill:
    """Minimal skill mock for testing."""

    def __init__(self, name: str, execution_mode: str = "react", tools: list | None = None, prompt: dict | None = None):
        self.name = name
        self.execution_mode = execution_mode
        self.tools = tools or []
        self.prompt = prompt or {}


class MockSkillRegistry:
    """Minimal skill registry mock."""

    def __init__(self, skills: dict[str, MockSkill] | None = None):
        self._skills = skills or {}

    def get(self, name: str) -> MockSkill:
        if name not in self._skills:
            raise ValueError(f"Skill '{name}' not found")
        return self._skills[name]

    def list_skills(self) -> list[MockSkill]:
        return list(self._skills.values())


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture
def registry() -> MockSkillRegistry:
    return MockSkillRegistry({
        "shell_agent": MockSkill("shell_agent", execution_mode="react", tools=["shell"]),
        "direct_agent": MockSkill("direct_agent", execution_mode="direct", tools=[]),
        "rewoo_agent": MockSkill("rewoo_agent", execution_mode="rewoo", tools=["planner"]),
    })


@pytest.fixture
def router(registry: MockSkillRegistry) -> SimpleRouter:
    return SimpleRouter(
        skill_registry=registry,
        default_tools=["shell", "search", "file_read"],
        default_system_prompt="You are a helpful assistant.",
        default_model="default",
        default_agent_name="default",
    )


# ---------------------------------------------------------------------------
# Layer 0: @skill:xxx prefix
# ---------------------------------------------------------------------------

class TestSkillPrefix:
    @pytest.mark.asyncio
    async def test_skill_prefix_routes_to_skill(self, router: SimpleRouter):
|
||||||
|
result = await router.route("@skill:shell_agent 查看当前ip")
|
||||||
|
assert result.matched is True
|
||||||
|
assert result.skill_name == "shell_agent"
|
||||||
|
assert result.match_method == "skill_prefix"
|
||||||
|
assert result.match_confidence == 1.0
|
||||||
|
assert result.execution_mode == ExecutionMode.SKILL_REACT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skill_prefix_direct_mode(self, router: SimpleRouter):
|
||||||
|
result = await router.route("@skill:direct_agent 翻译hello")
|
||||||
|
assert result.matched is True
|
||||||
|
assert result.skill_name == "direct_agent"
|
||||||
|
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skill_prefix_rewoo_mode(self, router: SimpleRouter):
|
||||||
|
result = await router.route("@skill:rewoo_agent 重构代码")
|
||||||
|
assert result.matched is True
|
||||||
|
assert result.skill_name == "rewoo_agent"
|
||||||
|
assert result.execution_mode == ExecutionMode.REWOO
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unknown_skill_falls_back_to_react(self, router: SimpleRouter):
|
||||||
|
result = await router.route("@skill:nonexistent 查询")
|
||||||
|
assert result.matched is False
|
||||||
|
assert result.match_method == "skill_not_found_fallback"
|
||||||
|
assert result.execution_mode == ExecutionMode.REACT
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Layer 1: Greeting/chitchat/identity regex
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDirectChat:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_greeting_cn(self, router: SimpleRouter):
|
||||||
|
result = await router.route("你好")
|
||||||
|
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||||
|
assert result.match_method == "regex_direct"
|
||||||
|
assert result.tools == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_greeting_en(self, router: SimpleRouter):
|
||||||
|
result = await router.route("hello")
|
||||||
|
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chitchat(self, router: SimpleRouter):
|
||||||
|
result = await router.route("谢谢")
|
||||||
|
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_identity_question(self, router: SimpleRouter):
|
||||||
|
result = await router.route("你是谁")
|
||||||
|
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_identity_question_en(self, router: SimpleRouter):
|
||||||
|
result = await router.route("who are you")
|
||||||
|
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Default: REACT
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestDefaultReact:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_colloquial_tool_query(self, router: SimpleRouter):
|
||||||
|
"""Colloquial tool query: the core scenario the previous routing layer misclassified."""
|
||||||
|
result = await router.route("查下ip")
|
||||||
|
assert result.execution_mode == ExecutionMode.REACT
|
||||||
|
assert result.match_method == "default_react"
|
||||||
|
assert len(result.tools) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_various_colloquial_expressions(self, router: SimpleRouter):
|
||||||
|
"""Every colloquial phrasing should route to REACT and let the LLM decide."""
|
||||||
|
queries = [
|
||||||
|
"查看当前ip",
|
||||||
|
"获取ip地址",
|
||||||
|
"看下ip",
|
||||||
|
"帮我查一下ip",
|
||||||
|
"搜索golang教程",
|
||||||
|
"执行ls命令",
|
||||||
|
"读一下配置文件",
|
||||||
|
"检查服务状态",
|
||||||
|
]
|
||||||
|
for query in queries:
|
||||||
|
result = await router.route(query)
|
||||||
|
assert result.execution_mode == ExecutionMode.REACT, f"'{query}' should be REACT, got {result.execution_mode}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complex_query(self, router: SimpleRouter):
|
||||||
|
result = await router.route("帮我分析一下这个数据并生成报告")
|
||||||
|
assert result.execution_mode == ExecutionMode.REACT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_translation_goes_react(self, router: SimpleRouter):
|
||||||
|
"""Translation queries also route to REACT; the LLM decides inside the agent loop that no tool is needed."""
|
||||||
|
result = await router.route("翻译hello为中文")
|
||||||
|
assert result.execution_mode == ExecutionMode.REACT
|
||||||
|
# LLM will see tools but decide not to use them
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_default_tools_included(self, router: SimpleRouter):
|
||||||
|
result = await router.route("查下ip")
|
||||||
|
assert "shell" in result.tools
|
||||||
|
assert "search" in result.tools
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_default_system_prompt(self, router: SimpleRouter):
|
||||||
|
result = await router.route("查下ip")
|
||||||
|
assert result.system_prompt == "You are a helpful assistant."
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_input(self, router: SimpleRouter):
|
||||||
|
result = await router.route("")
|
||||||
|
assert result.execution_mode == ExecutionMode.REACT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_whitespace_only(self, router: SimpleRouter):
|
||||||
|
result = await router.route(" ")
|
||||||
|
assert result.execution_mode == ExecutionMode.REACT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_greeting_with_extra_spaces(self, router: SimpleRouter):
|
||||||
|
result = await router.route(" 你好 ")
|
||||||
|
assert result.execution_mode == ExecutionMode.DIRECT_CHAT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_registry(self):
|
||||||
|
"""Router without skill registry should still work for non-skill queries"""
|
||||||
|
router = SimpleRouter(default_tools=["shell"])
|
||||||
|
result = await router.route("查下ip")
|
||||||
|
assert result.execution_mode == ExecutionMode.REACT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_override_defaults(self, router: SimpleRouter):
|
||||||
|
"""Keyword arguments passed to route() should override the constructor defaults."""
|
||||||
|
result = await router.route(
|
||||||
|
"查下ip",
|
||||||
|
default_tools=["shell_only"],
|
||||||
|
default_model="gpt-4o",
|
||||||
|
)
|
||||||
|
assert result.tools == ["shell_only"]
|
||||||
|
assert result.model == "gpt-4o"
|
||||||
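The tests above imply a three-step decision order: an `@skill:` prefix first, then a greeting/chitchat/identity regex, then REACT as the default. The `SimpleRouter` implementation itself does not appear in this part of the diff, so what follows is only a minimal sketch of that order under stated assumptions. `SketchRouter`, `Mode`, `Route`, and both regular expressions are hypothetical stand-ins, not the project's code.

```python
from __future__ import annotations

import asyncio
import re
from dataclasses import dataclass, field
from enum import Enum


class Mode(Enum):  # hypothetical stand-in for ExecutionMode
    DIRECT_CHAT = "direct_chat"
    REACT = "react"
    SKILL_REACT = "skill_react"
    REWOO = "rewoo"


@dataclass
class Route:  # hypothetical stand-in for SkillRoutingResult
    execution_mode: Mode
    match_method: str
    matched: bool = False
    skill_name: str | None = None
    tools: list[str] = field(default_factory=list)


# Assumed patterns; the real router's regexes are not shown in this diff.
_SKILL_PREFIX = re.compile(r"^@skill:(\S+)\s*(.*)$", re.S)
_DIRECT = re.compile(
    r"^(你好|hello|hi|谢谢|thanks|你是谁|who are you)[\s!?。!?]*$", re.I
)
_SKILL_MODES = {"react": Mode.SKILL_REACT, "direct": Mode.DIRECT_CHAT, "rewoo": Mode.REWOO}


class SketchRouter:
    def __init__(self, skills: dict[str, str] | None = None,
                 default_tools: list[str] | None = None):
        self.skills = skills or {}  # skill name -> declared execution mode
        self.default_tools = default_tools or []

    async def route(self, text: str) -> Route:
        text = text.strip()
        # Step 0: an explicit @skill:<name> prefix wins when the skill exists.
        m = _SKILL_PREFIX.match(text)
        if m:
            name = m.group(1)
            if name in self.skills:
                mode = _SKILL_MODES.get(self.skills[name], Mode.SKILL_REACT)
                return Route(mode, "skill_prefix", matched=True, skill_name=name)
            return Route(Mode.REACT, "skill_not_found_fallback",
                         tools=list(self.default_tools))
        # Step 1: greetings, thanks, and identity questions need no tools.
        if _DIRECT.match(text):
            return Route(Mode.DIRECT_CHAT, "regex_direct")
        # Default: hand everything else to the agent loop with the default tools.
        return Route(Mode.REACT, "default_react", tools=list(self.default_tools))


if __name__ == "__main__":
    demo = SketchRouter({"shell_agent": "react"}, ["shell", "search"])
    print(asyncio.run(demo.route("查下ip")).match_method)
```

Because the final step never inspects the query text, colloquial variants such as "查下ip" and "帮我查一下ip" cannot diverge at the routing layer, which is the property `TestDefaultReact` checks.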