fix(review): comprehensive P0-P2 code review fixes
This commit is contained in:
parent
2e55aae775
commit
5ef08a3b30
264
README.md
264
README.md
|
|
@ -6,13 +6,15 @@
|
|||
|
||||
AgentKit 解决的核心问题:**从写 150 行 Agent 代码降为 10-20 行 YAML 配置**。
|
||||
|
||||
传统方式下,每新增一个 Agent 需要编写子类、处理 LLM 调用、管理工具绑定、校验输出质量。AgentKit 将这些能力标准化为 6 个可组合模块,开发者只需编写 YAML 配置即可定义一个完整的 Skill(Prompt + Tool + 质量门禁),框架自动完成 ReAct 推理循环、模型路由降级、产出质量检查和标准化输出。
|
||||
传统方式下,每新增一个 Agent 需要编写子类、处理 LLM 调用、管理工具绑定、校验输出质量。AgentKit 将这些能力标准化为 8 个可组合模块,开发者只需编写 YAML 配置即可定义一个完整的 Skill(Prompt + Tool + 质量门禁),框架自动完成 ReAct 推理循环、模型路由降级、产出质量检查和标准化输出。
|
||||
|
||||
核心定位:
|
||||
|
||||
- **配置驱动** -- YAML 定义 Skill,无需写 Agent 子类
|
||||
- **生产就绪** -- 内置质量门禁、模型降级、用量统计
|
||||
- **两种部署** -- Python 库直接引用,或 FastAPI 独立部署
|
||||
- **三种使用** -- Python 库引用、CLI 聊天、Web GUI 界面
|
||||
- **工具丰富** -- 内置 Shell、搜索、爬虫、记忆等工具,支持 MCP 扩展
|
||||
- **Pipeline 编排** -- 多 Agent 协同、Saga 补偿、动态流水线
|
||||
|
||||
## 核心特性
|
||||
|
||||
|
|
@ -22,7 +24,7 @@ Think -> Act -> Observe 循环。LLM 自主决定是否调用工具、调用哪
|
|||
|
||||
### 2. LLM Gateway
|
||||
|
||||
统一 LLM 调用入口。Provider 注册、模型别名解析(如 `deepseek` -> `deepseek/deepseek-chat`)、Fallback 降级策略、Token 用量和成本追踪。
|
||||
统一 LLM 调用入口。Provider 注册、模型别名解析(如 `default` -> `dashscope/qwen3-coder-plus`)、Fallback 降级策略、Token 用量和成本追踪。支持百炼 DashScope、OpenAI、DeepSeek 等 OpenAI 兼容 API。
|
||||
|
||||
### 3. Skill 系统
|
||||
|
||||
|
|
@ -40,73 +42,102 @@ Skill = SkillConfig + 绑定 Tools。一个 Skill 代表一个可执行技能,
|
|||
|
||||
Schema 验证 + 字段类型归一化(str -> int/float/bool)+ 元数据附加(version、produced_at、quality_score)。所有 Skill 产出统一为 StandardOutput 格式。
|
||||
|
||||
### 7. 内置工具集
|
||||
|
||||
开箱即用的工具插件,覆盖常见 Agent 需求:
|
||||
|
||||
| 工具 | 说明 |
|
||||
|------|------|
|
||||
| `ShellTool` | 执行 Shell 命令,白名单安全机制 + 用户确认 |
|
||||
| `WebSearchTool` | DuckDuckGo / Bing 网页搜索 |
|
||||
| `BaiduSearchTool` | 百度搜索 |
|
||||
| `WebCrawlTool` | 网页抓取与内容提取 |
|
||||
| `MemoryTool` | 短期/长期记忆管理 |
|
||||
| `AskHumanTool` | 向用户提问获取信息 |
|
||||
| `SchemaExtractTool` | 从文本提取结构化数据 |
|
||||
| `SchemaGenerateTool` | 生成 JSON Schema |
|
||||
| `MCPTool` | MCP 协议工具扩展 |
|
||||
|
||||
工具组合:`SequentialChain`(顺序链)、`ParallelFanOut`(并行扇出)、`DynamicSelector`(动态选择)。
|
||||
|
||||
### 8. Pipeline 编排
|
||||
|
||||
多 Agent 协同编排,支持复杂工作流:
|
||||
|
||||
- **PipelineEngine** -- 阶段式流水线执行,支持自适应配置
|
||||
- **SagaOrchestrator** -- 分布式事务补偿,失败自动回滚
|
||||
- **DynamicPipeline** -- 运行时动态调整流水线结构
|
||||
- **PipelineReflector** -- 执行反思与重规划
|
||||
- **HandoffManager** -- Agent 间任务移交
|
||||
|
||||
## 架构图
|
||||
|
||||
```
|
||||
+------------------+
|
||||
| User Request |
|
||||
+--------+---------+
|
||||
|
|
||||
v
|
||||
+-------------+--------------+
|
||||
| IntentRouter |
|
||||
| (keyword -> LLM classify) |
|
||||
+-------------+--------------+
|
||||
|
|
||||
matched_skill
|
||||
|
|
||||
v
|
||||
+-------------+--------------+
|
||||
| ConfigDrivenAgent |
|
||||
| (SkillConfig-driven) |
|
||||
+-------------+--------------+
|
||||
|
|
||||
+------------+------------+
|
||||
| |
|
||||
v v
|
||||
+---------+--------+ +----------+---------+
|
||||
| ReActEngine | | Traditional Mode |
|
||||
| Think->Act->Observe| | llm_generate/ |
|
||||
+---------+--------+ | tool_call/custom |
|
||||
| +--------------------+
|
||||
v
|
||||
+----------+----------+
|
||||
| LLM Gateway |
|
||||
| resolve -> chat |
|
||||
| fallback -> track |
|
||||
+----------+----------+
|
||||
|
|
||||
+------+------+
|
||||
| |
|
||||
v v
|
||||
+-----+----+ +-----+-----+
|
||||
| Provider A| | Provider B| ...
|
||||
+-----+----+ +-----+-----+
|
||||
| |
|
||||
v v
|
||||
+-----+----+ +-----+-----+
|
||||
| Tool 1 | | Tool 2 | ...
|
||||
+-----------+ +-----------+
|
||||
|
||||
|
|
||||
v
|
||||
+----------+----------+
|
||||
| Quality Gate |
|
||||
| required_fields |
|
||||
| min_word_count |
|
||||
| schema validation |
|
||||
| custom validator |
|
||||
+----------+----------+
|
||||
|
|
||||
v
|
||||
+----------+----------+
|
||||
| OutputStandardizer |
|
||||
| schema + normalize |
|
||||
| + metadata |
|
||||
+----------+----------+
|
||||
|
|
||||
v
|
||||
StandardOutput
|
||||
+-------------------+ +-------------------+
|
||||
| Web GUI Chat | | CLI Chat |
|
||||
| (WebSocket) | | (agentkit chat) |
|
||||
+--------+----------+ +--------+----------+
|
||||
| |
|
||||
+----------+----------+
|
||||
|
|
||||
+----------v----------+
|
||||
| Skill Routing |
|
||||
| (keyword -> LLM) |
|
||||
+----------+----------+
|
||||
|
|
||||
matched_skill
|
||||
|
|
||||
+-------------------v-------------------+
|
||||
| ConfigDrivenAgent |
|
||||
| (SkillConfig-driven) |
|
||||
+-------------------+------------------+
|
||||
|
|
||||
+--------------+--------------+
|
||||
| |
|
||||
v v
|
||||
+---------+--------+ +----------+---------+
|
||||
| ReActEngine | | Traditional Mode |
|
||||
| Think->Act->Observe| | llm_generate/ |
|
||||
+---------+--------+ | tool_call/custom |
|
||||
| +---------------------+
|
||||
v
|
||||
+----------+----------+
|
||||
| LLM Gateway |
|
||||
| resolve -> chat |
|
||||
| fallback -> track |
|
||||
+----------+----------+
|
||||
|
|
||||
+------+------+
|
||||
| |
|
||||
v v
|
||||
+-----+----+ +-----+-----+
|
||||
| DashScope | | OpenAI | ...
|
||||
+-----+----+ +-----+-----+
|
||||
|
|
||||
+----------+----------+
|
||||
| Tool Registry |
|
||||
| shell / search / |
|
||||
| crawl / memory / ... |
|
||||
+----------+----------+
|
||||
|
|
||||
v
|
||||
+----------+----------+
|
||||
| Quality Gate |
|
||||
| required_fields |
|
||||
| min_word_count |
|
||||
| schema validation |
|
||||
| custom validator |
|
||||
+----------+----------+
|
||||
|
|
||||
v
|
||||
+----------+----------+
|
||||
| OutputStandardizer |
|
||||
| schema + normalize |
|
||||
| + metadata |
|
||||
+----------+----------+
|
||||
|
|
||||
v
|
||||
StandardOutput
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
|
@ -147,7 +178,13 @@ agentkit version
|
|||
# 初始化项目(生成配置文件)
|
||||
agentkit init
|
||||
|
||||
# 启动 Server
|
||||
# 启动 Web GUI 聊天界面(推荐)
|
||||
agentkit gui --port 8002
|
||||
|
||||
# 启动 CLI 聊天
|
||||
agentkit chat
|
||||
|
||||
# 启动 Server(API 模式)
|
||||
agentkit serve --host 0.0.0.0 --port 8001
|
||||
|
||||
# 健康检查
|
||||
|
|
@ -247,9 +284,9 @@ from agentkit.llm.providers.openai import OpenAIProvider
|
|||
async def main():
|
||||
# 1. 初始化 LLM Gateway
|
||||
gateway = LLMGateway()
|
||||
gateway.register_provider("openai", OpenAIProvider(
|
||||
gateway.register_provider("dashscope", OpenAIProvider(
|
||||
api_key="sk-xxx",
|
||||
base_url="https://api.openai.com/v1",
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
))
|
||||
|
||||
# 2. 定义 Skill
|
||||
|
|
@ -263,7 +300,7 @@ async def main():
|
|||
"instructions": "根据用户需求生成高质量内容",
|
||||
"output_format": "以 JSON 格式输出",
|
||||
},
|
||||
llm={"model": "openai/gpt-4o", "temperature": 0.7},
|
||||
llm={"model": "default", "temperature": 0.7},
|
||||
execution_mode="react",
|
||||
max_steps=5,
|
||||
)
|
||||
|
|
@ -318,9 +355,9 @@ from agentkit import LLMGateway
|
|||
from agentkit.llm.providers.openai import OpenAIProvider
|
||||
|
||||
gateway = LLMGateway()
|
||||
gateway.register_provider("openai", OpenAIProvider(
|
||||
gateway.register_provider("dashscope", OpenAIProvider(
|
||||
api_key="sk-xxx",
|
||||
base_url="https://api.openai.com/v1",
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
))
|
||||
|
||||
app = create_app(llm_gateway=gateway)
|
||||
|
|
@ -352,8 +389,8 @@ from datetime import datetime, timezone
|
|||
async def main():
|
||||
# 初始化 Gateway
|
||||
gateway = LLMGateway()
|
||||
gateway.register_provider("openai", OpenAIProvider(
|
||||
api_key="sk-xxx", base_url="https://api.openai.com/v1",
|
||||
gateway.register_provider("dashscope", OpenAIProvider(
|
||||
api_key="sk-xxx", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
))
|
||||
|
||||
# 定义多个 Skill
|
||||
|
|
@ -366,7 +403,7 @@ async def main():
|
|||
"instructions": "生成 SEO 优化内容",
|
||||
"output_format": "JSON: {content, word_count}",
|
||||
},
|
||||
llm={"model": "openai/gpt-4o"},
|
||||
llm={"model": "default"},
|
||||
intent={
|
||||
"keywords": ["生成", "内容", "写作"],
|
||||
"description": "内容生成与写作",
|
||||
|
|
@ -390,7 +427,7 @@ async def main():
|
|||
"instructions": "优化内容以提升 AI 搜索可见性",
|
||||
"output_format": "JSON: {optimized_content, seo_score, changes}",
|
||||
},
|
||||
llm={"model": "openai/gpt-4o"},
|
||||
llm={"model": "default"},
|
||||
intent={
|
||||
"keywords": ["优化", "GEO", "SEO"],
|
||||
"description": "内容 GEO/SEO 优化",
|
||||
|
|
@ -472,7 +509,7 @@ curl -X POST http://localhost:8000/api/v1/skills \
|
|||
"instructions": "生成高质量内容",
|
||||
"output_format": "JSON: {content, word_count}"
|
||||
},
|
||||
"llm": {"model": "openai/gpt-4o"},
|
||||
"llm": {"model": "default"},
|
||||
"intent": {
|
||||
"keywords": ["生成", "内容"],
|
||||
"description": "内容生成"
|
||||
|
|
@ -546,7 +583,7 @@ async def main():
|
|||
"instructions": "生成高质量内容",
|
||||
"output_format": "JSON: {content, word_count}",
|
||||
},
|
||||
"llm": {"model": "openai/gpt-4o"},
|
||||
"llm": {"model": "default"},
|
||||
"intent": {"keywords": ["生成", "内容"], "description": "内容生成"},
|
||||
"quality_gate": {"required_fields": ["content"], "max_retries": 2},
|
||||
"execution_mode": "react",
|
||||
|
|
@ -621,7 +658,7 @@ prompt:
|
|||
output_format: "JSON: generate_topics 返回 {topics: [{title, reason, keywords}]},generate_article 返回 {content, word_count}"
|
||||
|
||||
llm:
|
||||
model: "deepseek"
|
||||
model: "default"
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
|
||||
|
|
@ -664,34 +701,30 @@ skill = Skill(config=config)
|
|||
|
||||
```yaml
|
||||
providers:
|
||||
dashscope:
|
||||
api_key: "${DASHSCOPE_API_KEY}"
|
||||
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
models:
|
||||
qwen3-coder-plus:
|
||||
max_tokens: 64000
|
||||
cost_per_1k_input: 0.00014
|
||||
cost_per_1k_output: 0.00028
|
||||
openai:
|
||||
api_key: "sk-xxx"
|
||||
api_key: "${OPENAI_API_KEY}"
|
||||
base_url: "https://api.openai.com/v1"
|
||||
models:
|
||||
gpt-4o:
|
||||
cost_per_1k_input: 0.005
|
||||
cost_per_1k_output: 0.015
|
||||
gpt-4o-mini:
|
||||
cost_per_1k_input: 0.00015
|
||||
cost_per_1k_output: 0.0006
|
||||
deepseek:
|
||||
api_key: "sk-xxx"
|
||||
base_url: "https://api.deepseek.com/v1"
|
||||
models:
|
||||
deepseek-chat:
|
||||
cost_per_1k_input: 0.001
|
||||
cost_per_1k_output: 0.002
|
||||
|
||||
model_aliases:
|
||||
default: "deepseek/deepseek-chat"
|
||||
fast: "openai/gpt-4o-mini"
|
||||
default: "dashscope/qwen3-coder-plus"
|
||||
fast: "dashscope/qwen3-coder-plus"
|
||||
powerful: "openai/gpt-4o"
|
||||
|
||||
fallbacks:
|
||||
openai/gpt-4o:
|
||||
- "deepseek/deepseek-chat"
|
||||
deepseek/deepseek-chat:
|
||||
- "openai/gpt-4o-mini"
|
||||
dashscope/qwen3-coder-plus:
|
||||
- "openai/gpt-4o"
|
||||
```
|
||||
|
||||
加载 LLM 配置:
|
||||
|
|
@ -802,7 +835,7 @@ ReActEngine 实现 Think -> Act -> Observe 循环:
|
|||
统一 LLM 调用入口,核心能力:
|
||||
|
||||
- **Provider 注册**: `gateway.register_provider("openai", provider)`
|
||||
- **模型别名**: `"default"` -> `"deepseek/deepseek-chat"`
|
||||
- **模型别名**: `"default"` -> `"dashscope/qwen3-coder-plus"`
|
||||
- **Fallback 降级**: 主模型失败时自动切换到备选模型
|
||||
- **用量追踪**: 按 agent_name、model 统计 Token 用量和成本
|
||||
- **模型解析**: `"provider/model"` 格式自动路由到对应 Provider
|
||||
|
|
@ -878,10 +911,12 @@ v2 增强:接受 SkillConfig 时自动创建 Skill 并启用 ReAct 模式,Qu
|
|||
|
||||
### server -- FastAPI Server
|
||||
|
||||
独立部署模式,提供 RESTful API:
|
||||
独立部署模式,提供 RESTful API 和 Web GUI:
|
||||
|
||||
| 路径 | 方法 | 说明 |
|
||||
|------|------|------|
|
||||
| `/` | GET | Web GUI 聊天界面 |
|
||||
| `/ws/chat` | WebSocket | GUI 实时聊天通道 |
|
||||
| `/api/v1/agents` | POST | 创建 Agent(指定 skill_name 或 config) |
|
||||
| `/api/v1/agents` | GET | 列出所有 Agent |
|
||||
| `/api/v1/agents/{name}` | GET | 获取 Agent 详情 |
|
||||
|
|
@ -892,6 +927,27 @@ v2 增强:接受 SkillConfig 时自动创建 Skill 并启用 ReAct 模式,Qu
|
|||
| `/api/v1/llm/usage` | GET | 查询 LLM 用量 |
|
||||
| `/api/v1/health` | GET | 健康检查 |
|
||||
|
||||
### Web GUI 聊天界面
|
||||
|
||||
通过 `agentkit gui` 启动,特性:
|
||||
|
||||
- **实时对话** -- WebSocket 流式传输,逐 token 显示
|
||||
- **Markdown 渲染** -- 自动检测并渲染标题、列表、代码块、表格等
|
||||
- **工具确认卡片** -- 危险命令(如 `rm`)执行前弹出确认卡片,用户批准后才执行
|
||||
- **Loading 动画** -- 等待 AI 响应时显示思考动画
|
||||
- **Skill 路由** -- 输入 `@skill_name:` 前缀可指定使用特定 Skill
|
||||
- **会话管理** -- 多会话并行,历史记录持久化
|
||||
|
||||
### orchestrator -- Pipeline 编排
|
||||
|
||||
多 Agent 协同编排模块:
|
||||
|
||||
- **PipelineEngine** -- 按 Stage 定义顺序执行,支持自适应配置和反思重规划
|
||||
- **SagaOrchestrator** -- 分布式事务补偿,失败步骤自动执行补偿操作
|
||||
- **DynamicPipeline** -- 运行时根据条件动态调整流水线结构
|
||||
- **HandoffManager** -- Agent 间任务移交,支持上下文传递
|
||||
- **PipelineStateMemory/Redis/PG** -- 流水线状态持久化,支持内存、Redis、PostgreSQL 后端
|
||||
|
||||
## 配置参考
|
||||
|
||||
### SkillConfig
|
||||
|
|
@ -941,8 +997,8 @@ v2 增强:接受 SkillConfig 时自动创建 Skill 并启用 ReAct 模式,Qu
|
|||
| 字段 | 类型 | 默认值 | 说明 |
|
||||
|------|------|--------|------|
|
||||
| `providers` | dict[str, ProviderConfig] | `{}` | Provider 配置,key 为 provider 名称 |
|
||||
| `model_aliases` | dict[str, str] | `{}` | 模型别名映射,如 `default: "deepseek/deepseek-chat"` |
|
||||
| `fallbacks` | dict[str, list[str]] | `{}` | 降级策略,如 `openai/gpt-4o: ["deepseek/deepseek-chat"]` |
|
||||
| `model_aliases` | dict[str, str] | `{}` | 模型别名映射,如 `default: "dashscope/qwen3-coder-plus"` |
|
||||
| `fallbacks` | dict[str, list[str]] | `{}` | 降级策略,如 `dashscope/qwen3-coder-plus: ["openai/gpt-4o"]` |
|
||||
|
||||
#### ProviderConfig
|
||||
|
||||
|
|
@ -977,9 +1033,9 @@ from agentkit import LLMGateway
|
|||
from agentkit.llm.providers.openai import OpenAIProvider
|
||||
|
||||
gateway = LLMGateway()
|
||||
gateway.register_provider("deepseek", OpenAIProvider(
|
||||
gateway.register_provider("dashscope", OpenAIProvider(
|
||||
api_key="sk-xxx",
|
||||
base_url="https://api.deepseek.com/v1",
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
))
|
||||
|
||||
app = create_app(llm_gateway=gateway)
|
||||
|
|
@ -1135,6 +1191,8 @@ tools:
|
|||
- search
|
||||
```
|
||||
|
||||
**ShellTool 安全机制**:ShellTool 内置白名单(`ls`、`cat`、`curl` 等安全命令直接执行),非白名单命令会触发用户确认。在 GUI 中以确认卡片形式展示,用户点击"确认执行"后才运行。
|
||||
|
||||
### 代码风格
|
||||
|
||||
项目使用 Ruff 进行代码检查和格式化:
|
||||
|
|
|
|||
|
|
@ -40,3 +40,4 @@ logging:
|
|||
format: text
|
||||
router:
|
||||
classifier: heuristic
|
||||
auction_enabled: false
|
||||
|
|
|
|||
|
|
@ -2,18 +2,9 @@
|
|||
# 环境变量替换:${VAR_NAME} 在启动时由 LLMConfig.from_yaml() 处理
|
||||
|
||||
providers:
|
||||
deepseek:
|
||||
api_key: "${DEEPSEEK_API_KEY}"
|
||||
base_url: "https://api.deepseek.com/v1"
|
||||
models:
|
||||
deepseek-chat:
|
||||
max_tokens: 64000
|
||||
cost_per_1k_input: 0.00014
|
||||
cost_per_1k_output: 0.00028
|
||||
|
||||
openai:
|
||||
api_key: "${OPENAI_API_KEY}"
|
||||
base_url: "${OPENAI_BASE_URL:-https://coding.dashscope.aliyuncs.com/v1}"
|
||||
dashscope:
|
||||
api_key: "${DASHSCOPE_API_KEY}"
|
||||
base_url: "${DASHSCOPE_BASE_URL:-https://dashscope.aliyuncs.com/compatible-mode/v1}"
|
||||
models:
|
||||
qwen3-coder-plus:
|
||||
max_tokens: 64000
|
||||
|
|
@ -21,13 +12,9 @@ providers:
|
|||
cost_per_1k_output: 0.00028
|
||||
|
||||
model_aliases:
|
||||
default: "openai/qwen3-coder-plus"
|
||||
fast: "openai/qwen3-coder-plus"
|
||||
powerful: "openai/qwen3-coder-plus"
|
||||
|
||||
fallbacks:
|
||||
openai/qwen3-coder-plus:
|
||||
- "deepseek/deepseek-chat"
|
||||
default: "dashscope/qwen3-coder-plus"
|
||||
fast: "dashscope/qwen3-coder-plus"
|
||||
powerful: "dashscope/qwen3-coder-plus"
|
||||
|
||||
# 上下文压缩配置 — 长会话自动压缩历史消息,保持 Token 在预算内
|
||||
# GEO Pipeline 启用后,工具输出(搜索结果、网页抓取等)会自动压缩
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ task_mode: llm_generate
|
|||
execution_mode: rewoo
|
||||
max_steps: 8
|
||||
max_concurrency: 3
|
||||
fallback_strategies:
|
||||
- simplified_rewoo
|
||||
- react
|
||||
- direct
|
||||
|
||||
intent:
|
||||
keywords: ["采集", "批量", "并行", "fetch", "collect", "数据获取", "多源"]
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import re
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from agentkit.marketplace.auction import AuctionHouse, Bid
|
||||
from agentkit.telemetry.tracer import get_tracer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -256,6 +257,8 @@ _CHAT_MODE_RE = re.compile(
|
|||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_SENTENCE_SPLIT_RE = re.compile(r'[,。!?;\n,.!?;]')
|
||||
|
||||
|
||||
def _tokenize_content(content: str) -> list[str]:
|
||||
"""Tokenize content for capability matching. Supports Chinese and English."""
|
||||
|
|
@ -394,7 +397,7 @@ class HeuristicClassifier:
|
|||
score += 0.05
|
||||
|
||||
# 3. 多句加成(逗号/句号/换行分隔)
|
||||
sentence_count = len(re.split(r'[,。!?;\n,.!?;]', content))
|
||||
sentence_count = len(_SENTENCE_SPLIT_RE.split(content))
|
||||
if sentence_count >= 4:
|
||||
score += 0.1
|
||||
elif sentence_count >= 2:
|
||||
|
|
@ -424,12 +427,15 @@ class CostAwareRouter:
|
|||
org_context: Any = None,
|
||||
auction_enabled: bool = False,
|
||||
classifier: str = "heuristic",
|
||||
merged_llm_classify: bool = True,
|
||||
):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._model = model
|
||||
self._org_context = org_context
|
||||
self._auction_enabled = auction_enabled
|
||||
self._classifier = classifier
|
||||
self._merged_llm_classify = merged_llm_classify
|
||||
self._auction_house = AuctionHouse() if auction_enabled else None
|
||||
if classifier not in ("heuristic", "llm"):
|
||||
raise ValueError(f"Invalid classifier: {classifier!r}, must be 'heuristic' or 'llm'")
|
||||
self._heuristic = HeuristicClassifier()
|
||||
|
|
@ -489,6 +495,175 @@ class CostAwareRouter:
|
|||
logger.warning(f"CostAwareRouter quick_classify failed: {e}")
|
||||
return 0.5
|
||||
|
||||
# -- Layer 1.5: Merged LLM classify (complexity + intent in one call) ---
|
||||
|
||||
async def _classify_merged(
|
||||
self,
|
||||
content: str,
|
||||
skill_registry: Any,
|
||||
intent_router: Any,
|
||||
default_tools: list,
|
||||
default_system_prompt: str | None,
|
||||
default_model: str,
|
||||
default_agent_name: str,
|
||||
agent_tool_registry: Any = None,
|
||||
session_id: str = "",
|
||||
complexity: float = 0.5,
|
||||
) -> SkillRoutingResult:
|
||||
"""合并 LLM 调用:单次 LLM 同时输出 complexity + intent + skill_hint。
|
||||
|
||||
当 HeuristicClassifier 返回不确定区间 (0.3-0.7) 时使用,
|
||||
替代分别调用 quick_classify() 和 IntentRouter._classify_with_llm(),
|
||||
节省 1 次 LLM 调用 (~1-3s)。
|
||||
"""
|
||||
if self._llm_gateway is None or not self._merged_llm_classify:
|
||||
# Fallback: 使用独立的 IntentRouter 路由
|
||||
return await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Build skill list for the prompt
|
||||
skill_hints = []
|
||||
if skill_registry:
|
||||
try:
|
||||
for s in skill_registry.list_skills():
|
||||
if s.config.intent and s.config.intent.keywords:
|
||||
skill_hints.append(s.name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
skill_list_str = ", ".join(skill_hints) if skill_hints else "none"
|
||||
|
||||
prompt = (
|
||||
'You are a routing classifier. Analyze the user request and output:\n'
|
||||
'1. complexity (0.0-1.0): how complex is this request\n'
|
||||
'2. intent: the primary intent category\n'
|
||||
'3. skill_hint: the best matching skill name, or null if none match\n\n'
|
||||
f'Available skills: [{skill_list_str}]\n\n'
|
||||
'---BEGIN USER REQUEST---\n'
|
||||
f'{content}\n'
|
||||
'---END USER REQUEST---\n\n'
|
||||
'Respond ONLY with a JSON object: {"complexity": <float>, "intent": <string>, "skill_hint": <string|null>}'
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model=self._model,
|
||||
)
|
||||
data = json.loads(response.content.strip())
|
||||
merged_complexity = float(data.get("complexity", 0.5))
|
||||
merged_complexity = max(0.0, min(1.0, merged_complexity))
|
||||
skill_hint = data.get("skill_hint")
|
||||
|
||||
# If skill_hint provided and valid, route directly to that skill
|
||||
if skill_hint and skill_registry:
|
||||
try:
|
||||
matched_skill = skill_registry.get(skill_hint)
|
||||
result = SkillRoutingResult(
|
||||
clean_content=content,
|
||||
skill_name=skill_hint,
|
||||
skill_config=matched_skill.config,
|
||||
skill_tools=matched_skill.tools or [],
|
||||
matched=True,
|
||||
match_method="merged_llm",
|
||||
match_confidence=0.7,
|
||||
complexity=merged_complexity,
|
||||
)
|
||||
# Merge tools
|
||||
agent_tools = agent_tool_registry.list_tools() if agent_tool_registry else default_tools
|
||||
seen_names = set()
|
||||
merged_tools = []
|
||||
for tool in result.skill_tools + agent_tools:
|
||||
if tool.name not in seen_names:
|
||||
seen_names.add(tool.name)
|
||||
merged_tools.append(tool)
|
||||
result.tools = merged_tools
|
||||
result.model = result.skill_config.llm.get("model", default_model) if result.skill_config.llm else default_model
|
||||
result.agent_name = skill_hint
|
||||
result.system_prompt = build_skill_system_prompt(result.skill_config) or default_system_prompt
|
||||
logger.info(
|
||||
f"Session {session_id}: merged LLM classify routed to skill '{skill_hint}' "
|
||||
f"(complexity={merged_complexity:.2f})"
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Session {session_id}: merged LLM skill_hint '{skill_hint}' not found: {e}")
|
||||
|
||||
# No valid skill_hint — use complexity to decide routing
|
||||
if merged_complexity < 0.3:
|
||||
return SkillRoutingResult(
|
||||
clean_content=content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method="merged_llm_low",
|
||||
match_confidence=1.0 - merged_complexity,
|
||||
complexity=merged_complexity,
|
||||
)
|
||||
elif merged_complexity > 0.7:
|
||||
# High complexity — delegate to Layer 2
|
||||
return SkillRoutingResult(
|
||||
clean_content=content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method="merged_llm_high",
|
||||
match_confidence=merged_complexity,
|
||||
complexity=merged_complexity,
|
||||
)
|
||||
else:
|
||||
# Medium complexity, no skill match — default agent
|
||||
return SkillRoutingResult(
|
||||
clean_content=content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method="merged_llm_medium",
|
||||
match_confidence=0.5,
|
||||
complexity=merged_complexity,
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
logger.warning(f"CostAwareRouter _classify_merged parse failed: {e}, falling back to default")
|
||||
return SkillRoutingResult(
|
||||
clean_content=content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method="merged_llm_fallback",
|
||||
match_confidence=0.5,
|
||||
complexity=0.5,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"CostAwareRouter _classify_merged failed: {e}, falling back to default")
|
||||
return SkillRoutingResult(
|
||||
clean_content=content,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
model=default_model,
|
||||
agent_name=default_agent_name,
|
||||
matched=False,
|
||||
match_method="merged_llm_fallback",
|
||||
match_confidence=0.5,
|
||||
complexity=0.5,
|
||||
)
|
||||
|
||||
# -- Layer 2: Capability matching / Auction (optional) -----------------
|
||||
|
||||
async def _route_layer2(
|
||||
|
|
@ -505,12 +680,89 @@ class CostAwareRouter:
|
|||
complexity: float = 0.0,
|
||||
trace: list[dict] | None = None,
|
||||
) -> SkillRoutingResult:
|
||||
"""Layer 2: 高复杂度任务通过 org_context.find_best_agent 路由。"""
|
||||
"""Layer 2: 高复杂度任务通过拍卖或 org_context.find_best_agent 路由。"""
|
||||
# Extract capability-like keywords from content for matching
|
||||
content_words = _tokenize_content(content)
|
||||
|
||||
# --- Vickrey auction path (when enabled) ---
|
||||
if self._auction_enabled and self._auction_house is not None and self._org_context is not None:
|
||||
try:
|
||||
# Gather candidate agents from org_context
|
||||
all_agents = self._org_context.list_agents() if hasattr(self._org_context, "list_agents") else []
|
||||
# Filter agents that have at least one relevant capability
|
||||
candidate_agents = []
|
||||
for agent_profile in all_agents:
|
||||
if not agent_profile.availability:
|
||||
continue
|
||||
# Check if agent has any of the content_words as capabilities
|
||||
agent_caps_lower = {c.lower() for c in agent_profile.capabilities}
|
||||
if any(w.lower() in agent_caps_lower for w in content_words):
|
||||
candidate_agents.append(agent_profile)
|
||||
|
||||
# Also include agents that match via find_best_agent (they have ALL required caps)
|
||||
best = self._org_context.find_best_agent(required_capabilities=content_words)
|
||||
if best is not None:
|
||||
best_name = best if isinstance(best, str) else getattr(best, "name", str(best))
|
||||
existing_names = {a.name for a in candidate_agents}
|
||||
if best_name not in existing_names:
|
||||
profile = self._org_context.get_agent_profile(best_name) if hasattr(self._org_context, "get_agent_profile") else best
|
||||
if hasattr(profile, "name"):
|
||||
candidate_agents.append(profile)
|
||||
|
||||
if len(candidate_agents) >= 1:
|
||||
# Build Bid objects for each candidate
|
||||
bids = []
|
||||
for agent_profile in candidate_agents:
|
||||
name = agent_profile.name if hasattr(agent_profile, "name") else str(agent_profile)
|
||||
caps = agent_profile.capabilities if hasattr(agent_profile, "capabilities") else []
|
||||
arch = agent_profile.agent_type if hasattr(agent_profile, "agent_type") else "react"
|
||||
# Use current_load as a proxy for estimated_cost (higher load → higher cost)
|
||||
estimated_cost = float(agent_profile.current_load + 1) if hasattr(agent_profile, "current_load") else 1.0
|
||||
bids.append(Bid(
|
||||
agent_name=name,
|
||||
architecture=arch,
|
||||
estimated_steps=1,
|
||||
estimated_cost=estimated_cost,
|
||||
confidence=0.8,
|
||||
payment_offer=estimated_cost,
|
||||
capabilities=caps,
|
||||
))
|
||||
|
||||
auction_result = await self._auction_house.run_vickrey_auction(
|
||||
task_description=content,
|
||||
bidders=bids,
|
||||
required_capabilities=content_words,
|
||||
)
|
||||
|
||||
if auction_result.winner is not None:
|
||||
winner_name = auction_result.winner.agent_name
|
||||
result = SkillRoutingResult(
|
||||
clean_content=content,
|
||||
matched=True,
|
||||
match_method="vickrey_auction",
|
||||
match_confidence=0.8,
|
||||
agent_name=winner_name,
|
||||
model=default_model,
|
||||
system_prompt=default_system_prompt,
|
||||
tools=default_tools,
|
||||
complexity=complexity,
|
||||
)
|
||||
if trace is not None:
|
||||
trace.append({
|
||||
"layer": 2,
|
||||
"method": "vickrey_auction",
|
||||
"agent_name": winner_name,
|
||||
"complexity": complexity,
|
||||
"selection_reason": auction_result.selection_reason,
|
||||
})
|
||||
return result
|
||||
# No winner from auction → fall through to capability matching
|
||||
except Exception as e:
|
||||
logger.warning(f"CostAwareRouter Layer 2 Vickrey auction failed: {e}")
|
||||
|
||||
# --- Capability matching path (default) ---
|
||||
if self._org_context is not None and hasattr(self._org_context, "find_best_agent"):
|
||||
try:
|
||||
# Extract capability-like keywords from content for matching
|
||||
# find_best_agent expects list[str] of required capabilities
|
||||
content_words = _tokenize_content(content)
|
||||
best_agent = self._org_context.find_best_agent(required_capabilities=content_words)
|
||||
if best_agent is not None:
|
||||
agent_name = best_agent if isinstance(best_agent, str) else getattr(best_agent, "name", str(best_agent))
|
||||
|
|
@ -689,29 +941,69 @@ class CostAwareRouter:
|
|||
span.set_attribute("route.target", "default")
|
||||
return result
|
||||
|
||||
# Medium complexity → IntentRouter via resolve_skill_routing
|
||||
# Medium complexity → merged LLM classify or IntentRouter
|
||||
if complexity <= 0.7:
|
||||
result = await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.complexity = complexity
|
||||
if self._merged_llm_classify and self._llm_gateway is not None:
|
||||
# Use merged LLM call: complexity + intent in one call
|
||||
result = await self._classify_merged(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
complexity=complexity,
|
||||
)
|
||||
# If merged classify returned high complexity, delegate to Layer 2
|
||||
if result.complexity > 0.7 and result.match_method and result.match_method.startswith("merged_llm_high"):
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "merged_llm_high",
|
||||
"complexity": result.complexity,
|
||||
"delegated_to_layer2": True,
|
||||
})
|
||||
layer2_result = await self._route_layer2(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
complexity=result.complexity,
|
||||
trace=trace,
|
||||
)
|
||||
layer2_result.execution_trace = trace if transparency != "SILENT" else []
|
||||
layer2_result.transparency_level = transparency
|
||||
return layer2_result
|
||||
else:
|
||||
# Fallback: use separate IntentRouter
|
||||
result = await resolve_skill_routing(
|
||||
content=content,
|
||||
skill_registry=skill_registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=default_tools,
|
||||
default_system_prompt=default_system_prompt,
|
||||
default_model=default_model,
|
||||
default_agent_name=default_agent_name,
|
||||
agent_tool_registry=agent_tool_registry,
|
||||
session_id=session_id,
|
||||
)
|
||||
result.complexity = result.complexity or complexity
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "intent_router",
|
||||
"complexity": complexity,
|
||||
"method": result.match_method or "merged_llm",
|
||||
"complexity": result.complexity,
|
||||
"matched": result.matched,
|
||||
})
|
||||
result.execution_trace = trace if transparency != "SILENT" else []
|
||||
result.transparency_level = transparency
|
||||
span.set_attribute("route.layer", result.match_method or "intent_router")
|
||||
span.set_attribute("route.layer", result.match_method or "merged_llm")
|
||||
span.set_attribute("route.target", result.skill_name or "default")
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any
|
|||
|
||||
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||
from agentkit.core.protocol import AgentStatus
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
|
@ -53,6 +54,16 @@ class AgentPool:
|
|||
compressor=self._compressor,
|
||||
)
|
||||
await agent.start()
|
||||
|
||||
# Bind a reusable ReActEngine instance to the agent
|
||||
max_steps = getattr(config, "max_steps", 10) or 10
|
||||
parallel_tools = getattr(config, "parallel_tools", False) or False
|
||||
agent._react_engine = ReActEngine(
|
||||
llm_gateway=self._llm_gateway,
|
||||
max_steps=max_steps,
|
||||
parallel_tools=parallel_tools,
|
||||
)
|
||||
|
||||
self._agents[config.name] = agent
|
||||
logger.info(f"Agent '{config.name}' created and started in pool")
|
||||
|
||||
|
|
|
|||
|
|
@ -74,14 +74,27 @@ class ReActEngine:
|
|||
使 Agent 能够自主推理并选择工具完成任务。
|
||||
"""
|
||||
|
||||
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = False):
|
||||
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool | str = False):
|
||||
if max_steps < 1:
|
||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||
if isinstance(parallel_tools, str) and parallel_tools not in ("auto",):
|
||||
raise ValueError(f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}")
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_steps = max_steps
|
||||
self._default_timeout = default_timeout
|
||||
self._parallel_tools = parallel_tools
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for reuse across conversations.
|
||||
|
||||
Call this before each execute/execute_stream to ensure clean state.
|
||||
The engine itself (LLM gateway, config) is preserved.
|
||||
"""
|
||||
# ReActEngine is stateless between calls — conversation history,
|
||||
# step counts, and trajectory are local to each execute call.
|
||||
# This method exists for API clarity and future stateful extensions.
|
||||
pass
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
|
|
@ -97,6 +110,7 @@ class ReActEngine:
|
|||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
confirmation_handler: Any | None = None,
|
||||
) -> ReActResult:
|
||||
"""执行 ReAct 循环
|
||||
|
||||
|
|
@ -127,6 +141,7 @@ class ReActEngine:
|
|||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
confirmation_handler=confirmation_handler,
|
||||
),
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
|
|
@ -144,6 +159,7 @@ class ReActEngine:
|
|||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
confirmation_handler=confirmation_handler,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise TaskTimeoutError(
|
||||
|
|
@ -169,6 +185,7 @@ class ReActEngine:
|
|||
compressor: "CompressionStrategy | None" = None,
|
||||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
confirmation_handler: Any | None = None,
|
||||
) -> ReActResult:
|
||||
tools = tools or []
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
|
|
@ -224,7 +241,7 @@ class ReActEngine:
|
|||
else:
|
||||
system_prompt = f"## 参考信息\n{memory_context}"
|
||||
except Exception as e:
|
||||
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
|
||||
logger.warning(f"Memory retrieval failed, continuing without context: {e}", exc_info=True)
|
||||
|
||||
# 构建初始消息
|
||||
conversation: list[dict[str, Any]] = []
|
||||
|
|
@ -295,8 +312,70 @@ class ReActEngine:
|
|||
conversation.append(assistant_msg)
|
||||
|
||||
# 执行工具调用
|
||||
if self._parallel_tools and len(response.tool_calls) > 1:
|
||||
# 并行执行多个工具调用
|
||||
if self._parallel_tools == "auto" and len(response.tool_calls) > 1:
|
||||
# Auto mode: mixed parallel/serial based on _parallelizable flag
|
||||
parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls))
|
||||
serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set]
|
||||
parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set]
|
||||
|
||||
# Result slots indexed by original position
|
||||
all_results: list[Any] = [None] * len(response.tool_calls)
|
||||
|
||||
# Execute serial tools first (in order)
|
||||
for i, tc in serial_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
all_results[i] = (tc, tool_result, tool_duration_ms)
|
||||
|
||||
# Execute parallelizable tools in parallel
|
||||
if len(parallel_calls) > 1:
|
||||
para_results = await asyncio.gather(
|
||||
*[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls],
|
||||
return_exceptions=True,
|
||||
)
|
||||
for j, (i, tc) in enumerate(parallel_calls):
|
||||
tool_result = para_results[j]
|
||||
if isinstance(tool_result, Exception):
|
||||
tool_result = {"error": str(tool_result)}
|
||||
all_results[i] = (tc, tool_result, 0)
|
||||
elif len(parallel_calls) == 1:
|
||||
i, tc = parallel_calls[0]
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
all_results[i] = (tc, tool_result, 0)
|
||||
|
||||
# Process all results in original order
|
||||
for i, tc in enumerate(response.tool_calls):
|
||||
tc_obj, tool_result, tool_duration_ms = all_results[i]
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
arguments=tc.arguments,
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
conversation.append(tool_msg)
|
||||
elif self._should_execute_parallel(response.tool_calls):
|
||||
# 并行执行多个工具调用 (parallel_tools=True)
|
||||
tool_results = await asyncio.gather(
|
||||
*[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls],
|
||||
return_exceptions=True,
|
||||
|
|
@ -338,6 +417,40 @@ class ReActEngine:
|
|||
for tc in response.tool_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
|
||||
# Handle confirmation flow
|
||||
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
|
||||
confirmation_id = tool_result["confirmation_id"]
|
||||
command = tool_result.get("command", "")
|
||||
reason = tool_result.get("reason", "")
|
||||
|
||||
approved = False
|
||||
if confirmation_handler is not None:
|
||||
try:
|
||||
approved = await confirmation_handler(confirmation_id, command, reason)
|
||||
except Exception as e:
|
||||
logger.warning(f"Confirmation handler error: {e}")
|
||||
|
||||
if approved:
|
||||
tool = self._find_tool(tc.name, tools)
|
||||
if tool and hasattr(tool, '_is_dangerous'):
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args)
|
||||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
else:
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
else:
|
||||
tool_result = {
|
||||
"output": "",
|
||||
"exit_code": 126,
|
||||
"is_error": True,
|
||||
"error_type": "permission_denied",
|
||||
"message": f"用户拒绝执行命令: {command[:100]}",
|
||||
}
|
||||
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
|
|
@ -531,6 +644,21 @@ class ReActEngine:
|
|||
else:
|
||||
logger.info("ReActEngine executing with NO tools")
|
||||
|
||||
# Telemetry: record agent request
|
||||
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
|
||||
|
||||
# Start telemetry span for the entire agent execution
|
||||
_span_cm = None
|
||||
_span = None
|
||||
_exec_start = time.monotonic()
|
||||
|
||||
if _OTEL_AVAILABLE:
|
||||
_span_cm = start_span(
|
||||
"agent.execute_stream",
|
||||
attributes={"agent.name": agent_name, "agent.type": task_type or "react"},
|
||||
)
|
||||
_span = _span_cm.__enter__()
|
||||
|
||||
# 启动轨迹记录
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.start_trace(
|
||||
|
|
@ -575,11 +703,26 @@ class ReActEngine:
|
|||
step = 0
|
||||
output = ""
|
||||
trace_outcome = "success"
|
||||
_stream_start = time.monotonic()
|
||||
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout
|
||||
|
||||
try:
|
||||
while step < self._max_steps:
|
||||
step += 1
|
||||
|
||||
# 协作式取消检查
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.check()
|
||||
|
||||
# 超时检查
|
||||
if effective_timeout > 0:
|
||||
elapsed = time.monotonic() - _stream_start
|
||||
if elapsed > effective_timeout:
|
||||
trace_outcome = "timeout"
|
||||
raise asyncio.TimeoutError(
|
||||
f"execute_stream exceeded {effective_timeout}s timeout after {elapsed:.1f}s"
|
||||
)
|
||||
|
||||
# Yield thinking event
|
||||
yield ReActEvent(
|
||||
event_type="thinking",
|
||||
|
|
@ -591,7 +734,7 @@ class ReActEngine:
|
|||
llm_start = time.monotonic()
|
||||
|
||||
# Use streaming for token-by-token output
|
||||
stream_content = ""
|
||||
stream_content_chunks: list[str] = []
|
||||
stream_usage = None
|
||||
stream_tool_calls: list[Any] = []
|
||||
stream_model = model
|
||||
|
|
@ -604,7 +747,7 @@ class ReActEngine:
|
|||
tools=tool_schemas,
|
||||
):
|
||||
if chunk.content:
|
||||
stream_content += chunk.content
|
||||
stream_content_chunks.append(chunk.content)
|
||||
yield ReActEvent(
|
||||
event_type="token",
|
||||
step=step,
|
||||
|
|
@ -618,6 +761,7 @@ class ReActEngine:
|
|||
stream_model = chunk.model
|
||||
|
||||
# Build response-like object from stream
|
||||
stream_content = "".join(stream_content_chunks)
|
||||
response = self._build_response_from_stream(
|
||||
content=stream_content,
|
||||
tool_calls=stream_tool_calls,
|
||||
|
|
@ -657,115 +801,168 @@ class ReActEngine:
|
|||
}
|
||||
conversation.append(assistant_msg)
|
||||
|
||||
for tc in response.tool_calls:
|
||||
# Yield tool_call event
|
||||
yield ReActEvent(
|
||||
event_type="tool_call",
|
||||
step=step,
|
||||
data={"tool_name": tc.name, "arguments": tc.arguments},
|
||||
)
|
||||
# Execute tool calls with parallel support
|
||||
if self._parallel_tools and len(response.tool_calls) > 1 and self._should_execute_parallel(response.tool_calls):
|
||||
# Parallel execution path
|
||||
parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls)) if self._parallel_tools == "auto" else set(range(len(response.tool_calls)))
|
||||
serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set]
|
||||
parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set]
|
||||
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
all_results: list[Any] = [None] * len(response.tool_calls)
|
||||
|
||||
# 检测工具返回的确认请求
|
||||
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
|
||||
confirmation_id = tool_result["confirmation_id"]
|
||||
command = tool_result.get("command", "")
|
||||
reason = tool_result.get("reason", "")
|
||||
# Execute serial tools first (handles confirmation flow)
|
||||
for i, tc in serial_calls:
|
||||
yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments})
|
||||
tool_start = time.monotonic()
|
||||
tool_result, confirm_events = await self._execute_tool_with_confirmation(tc, tools, step, confirmation_handler)
|
||||
for ev in confirm_events:
|
||||
yield ev
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
all_results[i] = (tc, tool_result, tool_duration_ms)
|
||||
|
||||
# Yield 确认请求事件
|
||||
# Execute parallelizable tools concurrently
|
||||
if len(parallel_calls) > 1:
|
||||
para_results = await asyncio.gather(
|
||||
*[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls],
|
||||
return_exceptions=True,
|
||||
)
|
||||
for j, (i, tc) in enumerate(parallel_calls):
|
||||
tool_result = para_results[j]
|
||||
if isinstance(tool_result, Exception):
|
||||
tool_result = {"error": str(tool_result)}
|
||||
all_results[i] = (tc, tool_result, 0)
|
||||
elif len(parallel_calls) == 1:
|
||||
i, tc = parallel_calls[0]
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
all_results[i] = (tc, tool_result, 0)
|
||||
|
||||
# Process all results in original order
|
||||
for i, tc in enumerate(response.tool_calls):
|
||||
tc_obj, tool_result, tool_duration_ms = all_results[i]
|
||||
yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments})
|
||||
|
||||
react_step = ReActStep(step=step, action="tool_call", tool_name=tc.name, arguments=tc.arguments, result=tool_result, tokens=step_tokens)
|
||||
trajectory.append(react_step)
|
||||
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(step=step, action="tool_call", tool_name=tc.name, input_data=tc.arguments, output_data=tool_result, duration_ms=tool_duration_ms, tokens_used=0, error=tool_error)
|
||||
|
||||
yield ReActEvent(event_type="tool_result", step=step, data={"tool_name": tc.name, "result": tool_result})
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
conversation.append(tool_msg)
|
||||
else:
|
||||
# Serial execution path (with confirmation flow)
|
||||
for tc in response.tool_calls:
|
||||
# Yield tool_call event
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_request",
|
||||
event_type="tool_call",
|
||||
step=step,
|
||||
data={
|
||||
"confirmation_id": confirmation_id,
|
||||
"tool_name": tc.name,
|
||||
"command": command,
|
||||
"reason": reason,
|
||||
},
|
||||
data={"tool_name": tc.name, "arguments": tc.arguments},
|
||||
)
|
||||
|
||||
# 等待用户确认
|
||||
approved = False
|
||||
if confirmation_handler is not None:
|
||||
try:
|
||||
approved = await confirmation_handler(confirmation_id, command, reason)
|
||||
except Exception as e:
|
||||
logger.warning(f"Confirmation handler error: {e}")
|
||||
|
||||
if approved:
|
||||
# 用户确认执行:临时绕过安全检查重新执行
|
||||
tool = self._find_tool(tc.name, tools)
|
||||
if tool and hasattr(tool, '_is_dangerous'):
|
||||
# 保存原始 _is_dangerous 并临时禁用
|
||||
original_is_dangerous = tool._is_dangerous
|
||||
tool._is_dangerous = lambda cmd: False
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**tc.arguments)
|
||||
finally:
|
||||
tool._is_dangerous = original_is_dangerous
|
||||
else:
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": True},
|
||||
)
|
||||
else:
|
||||
# 用户拒绝执行
|
||||
tool_result = {
|
||||
"output": "",
|
||||
"exit_code": 126,
|
||||
"is_error": True,
|
||||
"error_type": "permission_denied",
|
||||
"message": f"用户拒绝执行命令: {command[:100]}",
|
||||
}
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": False},
|
||||
)
|
||||
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
arguments=tc.arguments,
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
# 检测工具返回的确认请求
|
||||
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
|
||||
confirmation_id = tool_result["confirmation_id"]
|
||||
command = tool_result.get("command", "")
|
||||
reason = tool_result.get("reason", "")
|
||||
|
||||
# 记录工具调用步骤
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
# Yield 确认请求事件
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_request",
|
||||
step=step,
|
||||
data={
|
||||
"confirmation_id": confirmation_id,
|
||||
"tool_name": tc.name,
|
||||
"command": command,
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
|
||||
# 等待用户确认
|
||||
approved = False
|
||||
if confirmation_handler is not None:
|
||||
try:
|
||||
approved = await confirmation_handler(confirmation_id, command, reason)
|
||||
except Exception as e:
|
||||
logger.warning(f"Confirmation handler error: {e}")
|
||||
|
||||
if approved:
|
||||
# 用户确认执行:使用 per-call override 绕过安全检查
|
||||
tool = self._find_tool(tc.name, tools)
|
||||
if tool and hasattr(tool, '_is_dangerous'):
|
||||
# Strip internal metadata and pass skip_dangerous_check flag
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args)
|
||||
finally:
|
||||
pass # No shared state mutation needed
|
||||
else:
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": True},
|
||||
)
|
||||
else:
|
||||
# 用户拒绝执行
|
||||
tool_result = {
|
||||
"output": "",
|
||||
"exit_code": 126,
|
||||
"is_error": True,
|
||||
"error_type": "permission_denied",
|
||||
"message": f"用户拒绝执行命令: {command[:100]}",
|
||||
}
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": False},
|
||||
)
|
||||
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
arguments=tc.arguments,
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
# Yield tool_result event
|
||||
yield ReActEvent(
|
||||
event_type="tool_result",
|
||||
step=step,
|
||||
data={"tool_name": tc.name, "result": tool_result},
|
||||
)
|
||||
|
||||
# Yield tool_result event
|
||||
yield ReActEvent(
|
||||
event_type="tool_result",
|
||||
step=step,
|
||||
data={"tool_name": tc.name, "result": tool_result},
|
||||
)
|
||||
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
conversation.append(tool_msg)
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
conversation.append(tool_msg)
|
||||
|
||||
# Incremental compression: compress conversation if it's getting long
|
||||
if self._should_compress(conversation, compressor):
|
||||
|
|
@ -893,6 +1090,17 @@ class ReActEngine:
|
|||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome=trace_outcome)
|
||||
|
||||
# Telemetry: end span and record duration — always runs
|
||||
_duration_ms = int((time.monotonic() - _exec_start) * 1000)
|
||||
if _span is not None:
|
||||
_span.set_attribute("agent.total_steps", len(trajectory))
|
||||
_span.set_attribute("agent.total_tokens", total_tokens)
|
||||
_span.set_attribute("agent.outcome", trace_outcome)
|
||||
_span.set_attribute("agent.duration_ms", _duration_ms)
|
||||
if _span_cm is not None:
|
||||
_span_cm.__exit__(None, None, None)
|
||||
agent_duration_histogram().record(_duration_ms, {"agent.name": agent_name})
|
||||
|
||||
# Memory storage: 执行后写入轨迹摘要到 EpisodicMemory
|
||||
if memory_retriever and hasattr(memory_retriever, "store_episode"):
|
||||
try:
|
||||
|
|
@ -988,14 +1196,127 @@ class ReActEngine:
|
|||
logger.warning(error_msg)
|
||||
return {"error": error_msg}
|
||||
|
||||
# Strip internal metadata keys before passing to tool
|
||||
clean_args = {k: v for k, v in arguments.items() if not k.startswith("_")}
|
||||
|
||||
try:
|
||||
result = await tool.safe_execute(**arguments)
|
||||
result = await tool.safe_execute(**clean_args)
|
||||
return result
|
||||
except Exception as e:
|
||||
error_msg = f"Tool '{tool_name}' execution failed: {e}"
|
||||
logger.warning(error_msg)
|
||||
return {"error": error_msg}
|
||||
|
||||
async def _execute_tool_with_confirmation(
|
||||
self,
|
||||
tc: Any,
|
||||
tools: list[Tool],
|
||||
step: int,
|
||||
confirmation_handler: Any,
|
||||
) -> tuple[Any, list[ReActEvent]]:
|
||||
"""Execute a tool call with confirmation flow support.
|
||||
|
||||
Used in the parallel execution path for serial (non-parallelizable) tools
|
||||
that may require user confirmation before execution.
|
||||
|
||||
Returns:
|
||||
Tuple of (tool_result, list of ReActEvents to yield)
|
||||
"""
|
||||
events: list[ReActEvent] = []
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
|
||||
# Check if tool returned a confirmation request
|
||||
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
|
||||
confirmation_id = tool_result["confirmation_id"]
|
||||
command = tool_result.get("command", "")
|
||||
reason = tool_result.get("reason", "")
|
||||
|
||||
events.append(ReActEvent(
|
||||
event_type="confirmation_request",
|
||||
step=step,
|
||||
data={
|
||||
"confirmation_id": confirmation_id,
|
||||
"tool_name": tc.name,
|
||||
"command": command,
|
||||
"reason": reason,
|
||||
},
|
||||
))
|
||||
|
||||
# Wait for user confirmation
|
||||
approved = False
|
||||
if confirmation_handler is not None:
|
||||
try:
|
||||
approved = await confirmation_handler(confirmation_id, command, reason)
|
||||
except Exception as e:
|
||||
logger.warning(f"Confirmation handler error: {e}")
|
||||
|
||||
if approved:
|
||||
# User approved: re-execute with _skip_dangerous_check
|
||||
tool = self._find_tool(tc.name, tools)
|
||||
if tool and hasattr(tool, '_is_dangerous'):
|
||||
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
|
||||
clean_args["_skip_dangerous_check"] = True
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**clean_args)
|
||||
except Exception as e:
|
||||
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
|
||||
else:
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
|
||||
events.append(ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": True},
|
||||
))
|
||||
else:
|
||||
# User rejected
|
||||
tool_result = {
|
||||
"output": "",
|
||||
"exit_code": 126,
|
||||
"is_error": True,
|
||||
"error_type": "permission_denied",
|
||||
"message": f"用户拒绝执行命令: {command[:100]}",
|
||||
}
|
||||
events.append(ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": False},
|
||||
))
|
||||
|
||||
return tool_result, events
|
||||
|
||||
def _should_execute_parallel(self, tool_calls: list[Any]) -> bool:
|
||||
"""Determine if tool calls should be executed in parallel.
|
||||
|
||||
- parallel_tools=True: always parallel (if >1 tool)
|
||||
- parallel_tools=False: never parallel
|
||||
- parallel_tools="auto": parallel if any tool_call has _parallelizable=true in arguments
|
||||
"""
|
||||
if len(tool_calls) <= 1:
|
||||
return False
|
||||
if self._parallel_tools is True:
|
||||
return True
|
||||
if self._parallel_tools is False:
|
||||
return False
|
||||
# "auto" mode: check _parallelizable metadata in tool call arguments
|
||||
if self._parallel_tools == "auto":
|
||||
parallelizable_indices = self._get_parallelizable_indices(tool_calls)
|
||||
return len(parallelizable_indices) > 1
|
||||
return False
|
||||
|
||||
def _get_parallelizable_indices(self, tool_calls: list[Any]) -> list[int]:
|
||||
"""Get indices of tool_calls that have _parallelizable=true in arguments.
|
||||
|
||||
LLM marks parallelizable tools by including _parallelizable: true
|
||||
in the tool_call arguments.
|
||||
"""
|
||||
indices = []
|
||||
for i, tc in enumerate(tool_calls):
|
||||
args = tc.arguments if hasattr(tc, 'arguments') else {}
|
||||
if isinstance(args, dict) and args.get("_parallelizable") is True:
|
||||
indices.append(i)
|
||||
return indices
|
||||
|
||||
def _parse_text_tool_calls(self, content: str) -> list[dict[str, Any]]:
|
||||
"""从文本中解析工具调用模式
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
|
@ -33,6 +34,17 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Internal Exceptions ──────────────────────────────────
|
||||
|
||||
|
||||
class _FallbackFailedError(Exception):
|
||||
"""Internal signal: a fallback strategy failed, try the next one."""
|
||||
|
||||
def __init__(self, strategy: str):
|
||||
self.strategy = strategy
|
||||
super().__init__(f"Fallback strategy '{strategy}' failed")
|
||||
|
||||
|
||||
# ── Data Structures ───────────────────────────────────────
|
||||
|
||||
|
||||
|
|
@ -116,13 +128,25 @@ class ReWOOEngine:
|
|||
"""
|
||||
|
||||
FALLBACK_STRATEGIES = ["simplified_rewoo", "react", "direct"]
|
||||
VALID_STRATEGIES = {"simplified_rewoo", "react", "direct", "plan_exec"}
|
||||
|
||||
def __init__(self, llm_gateway: LLMGateway, max_plan_steps: int = 10, default_timeout: float = 300.0):
|
||||
def __init__(self, llm_gateway: LLMGateway, max_plan_steps: int = 10, default_timeout: float = 300.0, fallback_strategies: list[str] | None = None):
|
||||
if max_plan_steps < 1:
|
||||
raise ValueError(f"max_plan_steps must be >= 1, got {max_plan_steps}")
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_plan_steps = max_plan_steps
|
||||
self._default_timeout = default_timeout
|
||||
# Validate and store fallback strategies
|
||||
raw_strategies = fallback_strategies if fallback_strategies is not None else self.FALLBACK_STRATEGIES
|
||||
self._fallback_strategies: list[str] = []
|
||||
for s in raw_strategies:
|
||||
if s in self.VALID_STRATEGIES:
|
||||
self._fallback_strategies.append(s)
|
||||
else:
|
||||
logger.warning(f"Invalid fallback strategy '{s}', skipping. Valid: {self.VALID_STRATEGIES}")
|
||||
if not self._fallback_strategies:
|
||||
logger.warning("No valid fallback strategies, using defaults")
|
||||
self._fallback_strategies = list(self.FALLBACK_STRATEGIES)
|
||||
# ReActEngine 作为 fallback
|
||||
self._react_engine = ReActEngine(
|
||||
llm_gateway=llm_gateway,
|
||||
|
|
@ -145,6 +169,7 @@ class ReWOOEngine:
|
|||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
confirmation_handler: Any | None = None,
|
||||
) -> ReActResult:
|
||||
"""执行 ReWOO 三阶段流程
|
||||
|
||||
|
|
@ -176,6 +201,7 @@ class ReWOOEngine:
|
|||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
confirmation_handler=confirmation_handler,
|
||||
),
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
|
|
@ -193,6 +219,7 @@ class ReWOOEngine:
|
|||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
confirmation_handler=confirmation_handler,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise TaskTimeoutError(
|
||||
|
|
@ -218,6 +245,7 @@ class ReWOOEngine:
|
|||
compressor: "CompressionStrategy | None" = None,
|
||||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
confirmation_handler: Any | None = None,
|
||||
) -> ReActResult:
|
||||
tools = tools or []
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
|
|
@ -296,109 +324,31 @@ class ReWOOEngine:
|
|||
tokens_used=planning_tokens,
|
||||
)
|
||||
|
||||
# 如果规划失败,尝试渐进式回退
|
||||
# 如果规划失败,按配置的 fallback 策略顺序尝试回退
|
||||
if plan is None:
|
||||
# 尝试简化规划(max_steps=3)
|
||||
logger.warning("ReWOO planning failed, trying simplified planning with max_steps=3")
|
||||
try:
|
||||
plan, simplified_tokens = await self._plan_phase(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_schemas=tool_schemas,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=effective_system_prompt,
|
||||
compressor=compressor,
|
||||
cancellation_token=cancellation_token,
|
||||
max_steps=3,
|
||||
)
|
||||
total_tokens += simplified_tokens
|
||||
if plan is not None and plan.steps:
|
||||
fallback_strategy = "simplified_rewoo"
|
||||
logger.info("Simplified ReWOO planning succeeded")
|
||||
except Exception as e2:
|
||||
logger.warning(f"Simplified ReWOO planning also failed: {e2}")
|
||||
|
||||
if plan is None:
|
||||
# 回退到 ReAct
|
||||
fallback_strategy = "react"
|
||||
logger.warning("ReWOO planning failed, falling back to ReActEngine")
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome="fallback")
|
||||
try:
|
||||
react_result = await self._react_engine.execute(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
timeout_seconds=0, # timeout already handled by outer wrapper
|
||||
)
|
||||
react_result.fallback_strategy = fallback_strategy
|
||||
return react_result
|
||||
except Exception as react_err:
|
||||
# ReAct 也失败,回退到 Direct(简单 LLM 调用)
|
||||
fallback_strategy = "direct"
|
||||
logger.warning(f"ReAct fallback also failed: {react_err}, falling back to direct LLM call")
|
||||
try:
|
||||
direct_messages: list[dict[str, Any]] = []
|
||||
if effective_system_prompt:
|
||||
direct_messages.append({"role": "system", "content": effective_system_prompt})
|
||||
direct_messages.extend(messages)
|
||||
|
||||
if compressor:
|
||||
try:
|
||||
direct_messages = await compressor.compress(direct_messages)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed in direct fallback: {e}")
|
||||
|
||||
direct_response = await self._llm_gateway.chat(
|
||||
messages=direct_messages,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
)
|
||||
total_tokens += direct_response.usage.total_tokens
|
||||
|
||||
direct_step = ReWOOStep(
|
||||
step=1,
|
||||
action="final_answer",
|
||||
content=direct_response.content,
|
||||
tokens=direct_response.usage.total_tokens,
|
||||
plan_step_id=None,
|
||||
)
|
||||
trajectory.append(direct_step)
|
||||
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=1,
|
||||
action="final_answer",
|
||||
output_data={"content": direct_response.content},
|
||||
tokens_used=direct_response.usage.total_tokens,
|
||||
)
|
||||
|
||||
trace_outcome = "success"
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome=trace_outcome)
|
||||
|
||||
return ReActResult(
|
||||
output=direct_response.content or "",
|
||||
trajectory=trajectory,
|
||||
total_steps=len(trajectory),
|
||||
total_tokens=total_tokens,
|
||||
fallback_strategy=fallback_strategy,
|
||||
)
|
||||
except Exception as direct_err:
|
||||
logger.error(f"Direct LLM fallback also failed: {direct_err}")
|
||||
raise
|
||||
fallback_strategy = await self._try_fallback_strategies(
|
||||
strategies=self._fallback_strategies,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
effective_system_prompt=effective_system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
trajectory=trajectory,
|
||||
total_tokens=total_tokens,
|
||||
confirmation_handler=confirmation_handler,
|
||||
)
|
||||
if fallback_strategy is not None:
|
||||
return fallback_strategy
|
||||
# All fallback strategies exhausted
|
||||
raise RuntimeError("All ReWOO fallback strategies exhausted")
|
||||
|
||||
# 如果计划为空(无需工具),直接让 LLM 回答
|
||||
if not plan.steps:
|
||||
|
|
@ -579,6 +529,7 @@ class ReWOOEngine:
|
|||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
confirmation_handler: Any | None = None,
|
||||
):
|
||||
"""Execute ReWOO flow, yielding ReActEvent objects.
|
||||
|
||||
|
|
@ -627,7 +578,6 @@ class ReWOOEngine:
|
|||
trace_outcome = "success"
|
||||
|
||||
try:
|
||||
# ── Phase 1: Planning ──
|
||||
yield ReActEvent(
|
||||
event_type="planning",
|
||||
step=0,
|
||||
|
|
@ -648,83 +598,27 @@ class ReWOOEngine:
|
|||
total_tokens += planning_tokens
|
||||
|
||||
if plan is None:
|
||||
# Try simplified planning
|
||||
logger.warning("ReWOO planning failed in stream mode, trying simplified planning with max_steps=3")
|
||||
try:
|
||||
plan, simplified_tokens = await self._plan_phase(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_schemas=tool_schemas,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=effective_system_prompt,
|
||||
compressor=compressor,
|
||||
cancellation_token=cancellation_token,
|
||||
max_steps=3,
|
||||
)
|
||||
total_tokens += simplified_tokens
|
||||
except Exception as e2:
|
||||
logger.warning(f"Simplified ReWOO planning also failed in stream mode: {e2}")
|
||||
|
||||
if plan is None:
|
||||
# Planning failed, fall back to ReAct streaming
|
||||
logger.warning("ReWOO planning failed in stream mode, falling back to ReActEngine")
|
||||
try:
|
||||
async for event in self._react_engine.execute_stream(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
timeout_seconds=0,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
except Exception as react_err:
|
||||
# ReAct also failed, fall back to direct LLM call
|
||||
logger.warning(f"ReAct fallback also failed in stream mode: {react_err}, falling back to direct LLM call")
|
||||
try:
|
||||
direct_messages: list[dict[str, Any]] = []
|
||||
if effective_system_prompt:
|
||||
direct_messages.append({"role": "system", "content": effective_system_prompt})
|
||||
direct_messages.extend(messages)
|
||||
|
||||
if compressor:
|
||||
try:
|
||||
direct_messages = await compressor.compress(direct_messages)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed in direct fallback: {e}")
|
||||
|
||||
direct_response = await self._llm_gateway.chat(
|
||||
messages=direct_messages,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
)
|
||||
total_tokens += direct_response.usage.total_tokens
|
||||
output = direct_response.content or ""
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="final_answer",
|
||||
step=1,
|
||||
data={
|
||||
"output": output,
|
||||
"total_steps": 1,
|
||||
"total_tokens": total_tokens,
|
||||
},
|
||||
)
|
||||
return
|
||||
except Exception as direct_err:
|
||||
logger.error(f"Direct LLM fallback also failed in stream mode: {direct_err}")
|
||||
raise
|
||||
# Planning failed, try fallback strategies in configured order
|
||||
async for event in self._try_fallback_strategies_stream(
|
||||
strategies=self._fallback_strategies,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
effective_system_prompt=effective_system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
total_tokens=total_tokens,
|
||||
confirmation_handler=confirmation_handler,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="plan_generated",
|
||||
|
|
@ -875,8 +769,11 @@ class ReWOOEngine:
|
|||
"total_tokens": total_tokens,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
trace_outcome = "error"
|
||||
logger.error(f"ReWOO execute_stream failed: {e}")
|
||||
raise
|
||||
finally:
|
||||
# 结束轨迹记录
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.end_trace(outcome=trace_outcome)
|
||||
|
||||
|
|
@ -892,6 +789,582 @@ class ReWOOEngine:
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to store task result in episodic memory: {e}")
|
||||
|
||||
# ── Fallback Strategy Helpers ──────────────────────────
|
||||
|
||||
async def _try_fallback_strategies_stream(
|
||||
self,
|
||||
strategies: list[str],
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Tool] | None = None,
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
system_prompt: str | None = None,
|
||||
effective_system_prompt: str | None = None,
|
||||
trace_recorder: "TraceRecorder | None" = None,
|
||||
memory_retriever: "MemoryRetriever | None" = None,
|
||||
task_id: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
total_tokens: int = 0,
|
||||
confirmation_handler: Any | None = None,
|
||||
):
|
||||
"""Stream version: try fallback strategies in configured order, yielding events from the first successful one.
|
||||
|
||||
If all strategies fail, raises RuntimeError.
|
||||
"""
|
||||
for strategy in strategies:
|
||||
if strategy == "simplified_rewoo":
|
||||
try:
|
||||
async for event in self._fallback_simplified_rewoo_stream(
|
||||
messages=messages, tools=tools, model=model,
|
||||
agent_name=agent_name, task_type=task_type,
|
||||
effective_system_prompt=effective_system_prompt,
|
||||
compressor=compressor, cancellation_token=cancellation_token,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
except _FallbackFailedError:
|
||||
continue
|
||||
|
||||
elif strategy == "react":
|
||||
try:
|
||||
async for event in self._fallback_react_stream(
|
||||
messages=messages, tools=tools, model=model,
|
||||
agent_name=agent_name, task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id, compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
confirmation_handler=confirmation_handler,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
except _FallbackFailedError:
|
||||
continue
|
||||
|
||||
elif strategy == "direct":
|
||||
try:
|
||||
async for event in self._fallback_direct_stream(
|
||||
messages=messages, model=model,
|
||||
agent_name=agent_name, task_type=task_type,
|
||||
effective_system_prompt=effective_system_prompt,
|
||||
compressor=compressor, total_tokens=total_tokens,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
except _FallbackFailedError:
|
||||
continue
|
||||
|
||||
elif strategy == "plan_exec":
|
||||
try:
|
||||
async for event in self._fallback_plan_exec_stream(
|
||||
messages=messages, tools=tools, model=model,
|
||||
agent_name=agent_name, task_type=task_type,
|
||||
effective_system_prompt=effective_system_prompt,
|
||||
compressor=compressor, cancellation_token=cancellation_token,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
except _FallbackFailedError:
|
||||
continue
|
||||
|
||||
raise RuntimeError("All ReWOO fallback strategies exhausted in stream mode")
|
||||
|
||||
async def _fallback_simplified_rewoo_stream(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Tool] | None = None,
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
effective_system_prompt: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
):
|
||||
"""Stream: Simplified ReWOO fallback with max_steps=3"""
|
||||
logger.warning("ReWOO planning failed in stream mode, trying simplified planning with max_steps=3")
|
||||
try:
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
plan, simplified_tokens = await self._plan_phase(
|
||||
messages=messages, tools=tools or [], tool_schemas=tool_schemas,
|
||||
model=model, agent_name=agent_name, task_type=task_type,
|
||||
system_prompt=effective_system_prompt, compressor=compressor,
|
||||
cancellation_token=cancellation_token, max_steps=3,
|
||||
)
|
||||
if plan is not None and plan.steps:
|
||||
logger.info("Simplified ReWOO planning succeeded in stream mode")
|
||||
yield ReActEvent(event_type="plan_generated", step=0, data={
|
||||
"reasoning": plan.reasoning,
|
||||
"steps": [{"step_id": s.step_id, "tool_name": s.tool_name, "arguments": s.arguments, "reasoning": s.reasoning} for s in plan.steps],
|
||||
})
|
||||
tool_results: list[dict[str, Any]] = []
|
||||
for plan_step in plan.steps:
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.check()
|
||||
yield ReActEvent(event_type="tool_call", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "arguments": plan_step.arguments})
|
||||
tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or [])
|
||||
tool_results.append({"step_id": plan_step.step_id, "tool_name": plan_step.tool_name, "arguments": plan_step.arguments, "result": tool_result, "reasoning": plan_step.reasoning})
|
||||
yield ReActEvent(event_type="tool_result", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "result": tool_result})
|
||||
|
||||
yield ReActEvent(event_type="synthesis", step=len(plan.steps) + 1, data={"message": "Synthesizing results..."})
|
||||
output, synthesis_tokens = await self._synthesis_phase(messages=messages, tool_results=tool_results, model=model, agent_name=agent_name, task_type=task_type, system_prompt=effective_system_prompt, compressor=compressor, cancellation_token=cancellation_token)
|
||||
yield ReActEvent(event_type="final_answer", step=len(plan.steps) + 1, data={"output": output, "total_steps": len(plan.steps) + 1, "total_tokens": simplified_tokens + synthesis_tokens})
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Simplified ReWOO planning also failed in stream mode: {e}")
|
||||
# Failed, continue to next strategy by not returning
|
||||
# This signals the caller to try the next strategy
|
||||
# We need a different approach - raise a specific exception
|
||||
raise _FallbackFailedError("simplified_rewoo")
|
||||
|
||||
async def _fallback_react_stream(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Tool] | None = None,
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
system_prompt: str | None = None,
|
||||
trace_recorder: "TraceRecorder | None" = None,
|
||||
memory_retriever: "MemoryRetriever | None" = None,
|
||||
task_id: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
confirmation_handler: Any | None = None,
|
||||
):
|
||||
"""Stream: ReAct fallback"""
|
||||
logger.warning("ReWOO planning failed in stream mode, falling back to ReActEngine")
|
||||
try:
|
||||
async for event in self._react_engine.execute_stream(
|
||||
messages=messages, tools=tools, model=model,
|
||||
agent_name=agent_name, task_type=task_type,
|
||||
system_prompt=system_prompt, trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever, task_id=task_id,
|
||||
compressor=compressor, retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token, timeout_seconds=0,
|
||||
confirmation_handler=confirmation_handler,
|
||||
):
|
||||
yield event
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"ReAct fallback also failed in stream mode: {e}")
|
||||
raise _FallbackFailedError("react")
|
||||
|
||||
async def _fallback_direct_stream(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
effective_system_prompt: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
total_tokens: int = 0,
|
||||
):
|
||||
"""Stream: Direct LLM fallback"""
|
||||
logger.warning("Falling back to direct LLM call in stream mode")
|
||||
try:
|
||||
direct_messages: list[dict[str, Any]] = []
|
||||
if effective_system_prompt:
|
||||
direct_messages.append({"role": "system", "content": effective_system_prompt})
|
||||
direct_messages.extend(messages)
|
||||
if compressor:
|
||||
try:
|
||||
direct_messages = await compressor.compress(direct_messages)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed in direct fallback: {e}")
|
||||
direct_response = await self._llm_gateway.chat(messages=direct_messages, model=model, agent_name=agent_name, task_type=task_type)
|
||||
output = direct_response.content or ""
|
||||
yield ReActEvent(event_type="final_answer", step=1, data={"output": output, "total_steps": 1, "total_tokens": total_tokens + direct_response.usage.total_tokens})
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Direct LLM fallback also failed in stream mode: {e}")
|
||||
raise _FallbackFailedError("direct")
|
||||
|
||||
async def _fallback_plan_exec_stream(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Tool] | None = None,
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
effective_system_prompt: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
):
|
||||
"""Stream: Plan-Exec fallback with max_steps=5"""
|
||||
logger.warning("Falling back to plan-exec mode in stream mode (max_steps=5)")
|
||||
try:
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
plan, plan_tokens = await self._plan_phase(
|
||||
messages=messages, tools=tools or [], tool_schemas=tool_schemas,
|
||||
model=model, agent_name=agent_name, task_type=task_type,
|
||||
system_prompt=effective_system_prompt, compressor=compressor,
|
||||
cancellation_token=cancellation_token, max_steps=5,
|
||||
)
|
||||
if plan is not None and plan.steps:
|
||||
yield ReActEvent(event_type="plan_generated", step=0, data={
|
||||
"reasoning": plan.reasoning,
|
||||
"steps": [{"step_id": s.step_id, "tool_name": s.tool_name, "arguments": s.arguments, "reasoning": s.reasoning} for s in plan.steps],
|
||||
})
|
||||
tool_results: list[dict[str, Any]] = []
|
||||
for plan_step in plan.steps:
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.check()
|
||||
yield ReActEvent(event_type="tool_call", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "arguments": plan_step.arguments})
|
||||
tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or [])
|
||||
tool_results.append({"step_id": plan_step.step_id, "tool_name": plan_step.tool_name, "arguments": plan_step.arguments, "result": tool_result, "reasoning": plan_step.reasoning})
|
||||
yield ReActEvent(event_type="tool_result", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "result": tool_result})
|
||||
|
||||
yield ReActEvent(event_type="synthesis", step=len(plan.steps) + 1, data={"message": "Synthesizing results..."})
|
||||
output, synthesis_tokens = await self._synthesis_phase(messages=messages, tool_results=tool_results, model=model, agent_name=agent_name, task_type=task_type, system_prompt=effective_system_prompt, compressor=compressor, cancellation_token=cancellation_token)
|
||||
yield ReActEvent(event_type="final_answer", step=len(plan.steps) + 1, data={"output": output, "total_steps": len(plan.steps) + 1, "total_tokens": plan_tokens + synthesis_tokens})
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Plan-exec fallback also failed in stream mode: {e}")
|
||||
raise _FallbackFailedError("plan_exec")
|
||||
|
||||
async def _try_fallback_strategies(
|
||||
self,
|
||||
strategies: list[str],
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Tool] | None = None,
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
system_prompt: str | None = None,
|
||||
effective_system_prompt: str | None = None,
|
||||
trace_recorder: "TraceRecorder | None" = None,
|
||||
memory_retriever: "MemoryRetriever | None" = None,
|
||||
task_id: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
trajectory: list[ReActStep] | None = None,
|
||||
total_tokens: int = 0,
|
||||
confirmation_handler: Any | None = None,
|
||||
) -> ReActResult | None:
|
||||
"""按配置的 fallback 策略顺序尝试回退,返回第一个成功的结果
|
||||
|
||||
Returns:
|
||||
ReActResult if a fallback succeeded, None if all strategies exhausted
|
||||
"""
|
||||
for strategy in strategies:
|
||||
if strategy == "simplified_rewoo":
|
||||
result = await self._fallback_simplified_rewoo(
|
||||
messages=messages, tools=tools, model=model,
|
||||
agent_name=agent_name, task_type=task_type,
|
||||
effective_system_prompt=effective_system_prompt,
|
||||
compressor=compressor, cancellation_token=cancellation_token,
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
elif strategy == "react":
|
||||
result = await self._fallback_react(
|
||||
messages=messages, tools=tools, model=model,
|
||||
agent_name=agent_name, task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id, compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
confirmation_handler=confirmation_handler,
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
elif strategy == "direct":
|
||||
result = await self._fallback_direct(
|
||||
messages=messages, model=model,
|
||||
agent_name=agent_name, task_type=task_type,
|
||||
effective_system_prompt=effective_system_prompt,
|
||||
compressor=compressor, cancellation_token=cancellation_token,
|
||||
trajectory=trajectory, total_tokens=total_tokens,
|
||||
trace_recorder=trace_recorder,
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
elif strategy == "plan_exec":
|
||||
result = await self._fallback_plan_exec(
|
||||
messages=messages, tools=tools, model=model,
|
||||
agent_name=agent_name, task_type=task_type,
|
||||
effective_system_prompt=effective_system_prompt,
|
||||
compressor=compressor, cancellation_token=cancellation_token,
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
async def _fallback_simplified_rewoo(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Tool] | None = None,
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
effective_system_prompt: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> ReActResult | None:
|
||||
"""Simplified ReWOO fallback: retry planning with max_steps=3"""
|
||||
logger.warning("ReWOO planning failed, trying simplified planning with max_steps=3")
|
||||
try:
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
plan, simplified_tokens = await self._plan_phase(
|
||||
messages=messages,
|
||||
tools=tools or [],
|
||||
tool_schemas=tool_schemas,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=effective_system_prompt,
|
||||
compressor=compressor,
|
||||
cancellation_token=cancellation_token,
|
||||
max_steps=3,
|
||||
)
|
||||
if plan is not None and plan.steps:
|
||||
logger.info("Simplified ReWOO planning succeeded")
|
||||
# Execute the simplified plan
|
||||
trajectory: list[ReActStep] = []
|
||||
total_tokens = simplified_tokens
|
||||
tool_results: list[dict[str, Any]] = []
|
||||
for plan_step in plan.steps:
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.check()
|
||||
tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or [])
|
||||
rewoo_step = ReWOOStep(
|
||||
step=plan_step.step_id,
|
||||
action="tool_call",
|
||||
tool_name=plan_step.tool_name,
|
||||
arguments=plan_step.arguments,
|
||||
result=tool_result,
|
||||
tokens=0,
|
||||
plan_step_id=plan_step.step_id,
|
||||
)
|
||||
trajectory.append(rewoo_step)
|
||||
tool_results.append({
|
||||
"step_id": plan_step.step_id,
|
||||
"tool_name": plan_step.tool_name,
|
||||
"arguments": plan_step.arguments,
|
||||
"result": tool_result,
|
||||
"reasoning": plan_step.reasoning,
|
||||
})
|
||||
|
||||
output, synthesis_tokens = await self._synthesis_phase(
|
||||
messages=messages, tool_results=tool_results,
|
||||
model=model, agent_name=agent_name, task_type=task_type,
|
||||
system_prompt=effective_system_prompt, compressor=compressor,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
total_tokens += synthesis_tokens
|
||||
trajectory.append(ReWOOStep(
|
||||
step=len(plan.steps) + 1,
|
||||
action="final_answer",
|
||||
content=output,
|
||||
tokens=synthesis_tokens,
|
||||
))
|
||||
return ReActResult(
|
||||
output=output,
|
||||
trajectory=trajectory,
|
||||
total_steps=len(trajectory),
|
||||
total_tokens=total_tokens,
|
||||
fallback_strategy="simplified_rewoo",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Simplified ReWOO planning also failed: {e}")
|
||||
return None
|
||||
|
||||
async def _fallback_react(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Tool] | None = None,
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
system_prompt: str | None = None,
|
||||
trace_recorder: "TraceRecorder | None" = None,
|
||||
memory_retriever: "MemoryRetriever | None" = None,
|
||||
task_id: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
confirmation_handler: Any | None = None,
|
||||
) -> ReActResult | None:
|
||||
"""ReAct fallback: delegate to ReActEngine"""
|
||||
logger.warning("ReWOO planning failed, falling back to ReActEngine")
|
||||
try:
|
||||
react_result = await self._react_engine.execute(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=system_prompt,
|
||||
trace_recorder=trace_recorder,
|
||||
memory_retriever=memory_retriever,
|
||||
task_id=task_id,
|
||||
compressor=compressor,
|
||||
retrieval_config=retrieval_config,
|
||||
cancellation_token=cancellation_token,
|
||||
timeout_seconds=0, # timeout already handled by outer wrapper
|
||||
confirmation_handler=confirmation_handler,
|
||||
)
|
||||
react_result.fallback_strategy = "react"
|
||||
return react_result
|
||||
except Exception as e:
|
||||
logger.warning(f"ReAct fallback also failed: {e}")
|
||||
return None
|
||||
|
||||
async def _fallback_direct(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
effective_system_prompt: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
trajectory: list[ReActStep] | None = None,
|
||||
total_tokens: int = 0,
|
||||
trace_recorder: "TraceRecorder | None" = None,
|
||||
) -> ReActResult | None:
|
||||
"""Direct fallback: simple LLM call without tools"""
|
||||
logger.warning("Falling back to direct LLM call")
|
||||
try:
|
||||
direct_messages: list[dict[str, Any]] = []
|
||||
if effective_system_prompt:
|
||||
direct_messages.append({"role": "system", "content": effective_system_prompt})
|
||||
direct_messages.extend(messages)
|
||||
|
||||
if compressor:
|
||||
try:
|
||||
direct_messages = await compressor.compress(direct_messages)
|
||||
except Exception as e:
|
||||
logger.warning(f"Context compression failed in direct fallback: {e}")
|
||||
|
||||
direct_response = await self._llm_gateway.chat(
|
||||
messages=direct_messages,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
)
|
||||
total_tokens += direct_response.usage.total_tokens
|
||||
|
||||
direct_step = ReWOOStep(
|
||||
step=1,
|
||||
action="final_answer",
|
||||
content=direct_response.content,
|
||||
tokens=direct_response.usage.total_tokens,
|
||||
plan_step_id=None,
|
||||
)
|
||||
if trajectory is not None:
|
||||
trajectory.append(direct_step)
|
||||
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=1,
|
||||
action="final_answer",
|
||||
output_data={"content": direct_response.content},
|
||||
tokens_used=direct_response.usage.total_tokens,
|
||||
)
|
||||
trace_recorder.end_trace(outcome="success")
|
||||
|
||||
return ReActResult(
|
||||
output=direct_response.content or "",
|
||||
trajectory=trajectory or [direct_step],
|
||||
total_steps=len(trajectory or [direct_step]),
|
||||
total_tokens=total_tokens,
|
||||
fallback_strategy="direct",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Direct LLM fallback also failed: {e}")
|
||||
return None
|
||||
|
||||
async def _fallback_plan_exec(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Tool] | None = None,
|
||||
model: str = "default",
|
||||
agent_name: str = "",
|
||||
task_type: str = "",
|
||||
effective_system_prompt: str | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
) -> ReActResult | None:
|
||||
"""Plan-Exec fallback: plan then execute sequentially (like simplified ReWOO but with max_steps=5)"""
|
||||
logger.warning("Falling back to plan-exec mode (max_steps=5)")
|
||||
try:
|
||||
tool_schemas = self._build_tool_schemas(tools) if tools else None
|
||||
plan, plan_tokens = await self._plan_phase(
|
||||
messages=messages,
|
||||
tools=tools or [],
|
||||
tool_schemas=tool_schemas,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
task_type=task_type,
|
||||
system_prompt=effective_system_prompt,
|
||||
compressor=compressor,
|
||||
cancellation_token=cancellation_token,
|
||||
max_steps=5,
|
||||
)
|
||||
if plan is not None and plan.steps:
|
||||
trajectory: list[ReActStep] = []
|
||||
total_tokens = plan_tokens
|
||||
tool_results: list[dict[str, Any]] = []
|
||||
for plan_step in plan.steps:
|
||||
if cancellation_token is not None:
|
||||
cancellation_token.check()
|
||||
tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or [])
|
||||
rewoo_step = ReWOOStep(
|
||||
step=plan_step.step_id,
|
||||
action="tool_call",
|
||||
tool_name=plan_step.tool_name,
|
||||
arguments=plan_step.arguments,
|
||||
result=tool_result,
|
||||
tokens=0,
|
||||
plan_step_id=plan_step.step_id,
|
||||
)
|
||||
trajectory.append(rewoo_step)
|
||||
tool_results.append({
|
||||
"step_id": plan_step.step_id,
|
||||
"tool_name": plan_step.tool_name,
|
||||
"arguments": plan_step.arguments,
|
||||
"result": tool_result,
|
||||
"reasoning": plan_step.reasoning,
|
||||
})
|
||||
|
||||
output, synthesis_tokens = await self._synthesis_phase(
|
||||
messages=messages, tool_results=tool_results,
|
||||
model=model, agent_name=agent_name, task_type=task_type,
|
||||
system_prompt=effective_system_prompt, compressor=compressor,
|
||||
cancellation_token=cancellation_token,
|
||||
)
|
||||
total_tokens += synthesis_tokens
|
||||
trajectory.append(ReWOOStep(
|
||||
step=len(plan.steps) + 1,
|
||||
action="final_answer",
|
||||
content=output,
|
||||
tokens=synthesis_tokens,
|
||||
))
|
||||
return ReActResult(
|
||||
output=output,
|
||||
trajectory=trajectory,
|
||||
total_steps=len(trajectory),
|
||||
total_tokens=total_tokens,
|
||||
fallback_strategy="plan_exec",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Plan-exec fallback also failed: {e}")
|
||||
return None
|
||||
|
||||
# ── Phase Implementations ─────────────────────────────
|
||||
|
||||
async def _plan_phase(
|
||||
|
|
@ -1079,7 +1552,6 @@ class ReWOOEngine:
|
|||
|
||||
# 尝试从 markdown 代码块中提取
|
||||
if "```" in json_str:
|
||||
import re
|
||||
code_block_match = re.search(r"```(?:json)?\s*\n(.*?)\n\s*```", json_str, re.DOTALL)
|
||||
if code_block_match:
|
||||
json_str = code_block_match.group(1).strip()
|
||||
|
|
|
|||
|
|
@ -21,6 +21,14 @@ class Bid:
|
|||
capabilities: list[str] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.estimated_cost < 0:
|
||||
raise ValueError(f"estimated_cost must be non-negative, got {self.estimated_cost}")
|
||||
if not (0.0 <= self.confidence <= 1.0):
|
||||
raise ValueError(f"confidence must be between 0.0 and 1.0, got {self.confidence}")
|
||||
if self.payment_offer < 0:
|
||||
raise ValueError(f"payment_offer must be non-negative, got {self.payment_offer}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuctionResult:
|
||||
|
|
@ -98,3 +106,125 @@ class AuctionHouse:
|
|||
wealth_factor = self._wealth_tracker.get_wealth_factor(bid.agent_name)
|
||||
score = (bid.confidence / max(bid.estimated_cost, 0.001)) * wealth_factor
|
||||
return score
|
||||
|
||||
def filter_by_capabilities(
|
||||
self, bidders: list[Bid], required_capabilities: list[str]
|
||||
) -> list[Bid]:
|
||||
"""Filter bidders whose capabilities include ALL required capabilities.
|
||||
|
||||
Args:
|
||||
bidders: List of Bid objects to filter.
|
||||
required_capabilities: Capabilities that a bidder must ALL possess.
|
||||
|
||||
Returns:
|
||||
List of Bids whose capabilities list includes every required capability.
|
||||
"""
|
||||
if not required_capabilities:
|
||||
return bidders
|
||||
required_set = {cap.lower() for cap in required_capabilities}
|
||||
return [
|
||||
b for b in bidders
|
||||
if required_set.issubset({cap.lower() for cap in b.capabilities})
|
||||
]
|
||||
|
||||
async def run_vickrey_auction(
|
||||
self,
|
||||
task_description: str,
|
||||
bidders: list[Bid],
|
||||
required_capabilities: list[str] | None = None,
|
||||
) -> AuctionResult:
|
||||
"""Run a Vickrey (second-price sealed-bid) auction.
|
||||
|
||||
In a Vickrey auction each bidder submits a sealed bid (their
|
||||
estimated_cost). The lowest estimated_cost wins, but the winner
|
||||
*pays* the second-lowest estimated_cost rather than their own.
|
||||
This is incentive-compatible: agents' dominant strategy is to bid
|
||||
truthfully.
|
||||
|
||||
Steps:
|
||||
1. Filter by required_capabilities (if provided).
|
||||
2. Filter out bankrupt agents.
|
||||
3. If only 1 eligible bidder → wins, pays 0.
|
||||
4. If 2+ eligible bidders → lowest cost wins, pays second-lowest.
|
||||
5. Update WealthTracker: winner earns (payment - cost_estimate).
|
||||
6. Return AuctionResult with selection_reason.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task being auctioned.
|
||||
bidders: List of Bid objects.
|
||||
required_capabilities: Optional list of capabilities that bidders
|
||||
must possess to be eligible.
|
||||
|
||||
Returns:
|
||||
AuctionResult with the winner and Vickrey outcome details.
|
||||
"""
|
||||
# 1. Capability filtering
|
||||
eligible = (
|
||||
self.filter_by_capabilities(bidders, required_capabilities)
|
||||
if required_capabilities
|
||||
else list(bidders)
|
||||
)
|
||||
|
||||
# 2. Filter out bankrupt agents
|
||||
eligible = [
|
||||
b for b in eligible
|
||||
if not self._wealth_tracker.is_bankrupt(b.agent_name)
|
||||
]
|
||||
|
||||
# No bidders at all
|
||||
if not bidders:
|
||||
return AuctionResult(
|
||||
winner=None,
|
||||
all_bids=bidders,
|
||||
selection_reason="No bidders participated",
|
||||
total_bidders=0,
|
||||
)
|
||||
|
||||
# All eligible bidders filtered out (bankrupt or no capabilities)
|
||||
if not eligible:
|
||||
return AuctionResult(
|
||||
winner=None,
|
||||
all_bids=bidders,
|
||||
selection_reason="No eligible bidders (bankrupt or missing capabilities)",
|
||||
total_bidders=len(bidders),
|
||||
)
|
||||
|
||||
# 3. Sort by estimated_cost ascending (lowest cost = best bid)
|
||||
# Apply minimum cost floor to prevent zero-cost bid manipulation
|
||||
MIN_COST = 0.001
|
||||
for b in eligible:
|
||||
b.estimated_cost = max(b.estimated_cost, MIN_COST)
|
||||
eligible = sorted(eligible, key=lambda b: b.estimated_cost)
|
||||
winner = eligible[0]
|
||||
|
||||
# 4. Determine payment (second-price rule)
|
||||
if len(eligible) == 1:
|
||||
# Single bidder: payment equals their own cost (no profit, no loss)
|
||||
payment = winner.estimated_cost
|
||||
else:
|
||||
payment = eligible[1].estimated_cost
|
||||
|
||||
# 5. Update WealthTracker
|
||||
profit = payment - winner.estimated_cost
|
||||
self._wealth_tracker.reward(winner.agent_name, profit)
|
||||
|
||||
# 6. Build selection_reason
|
||||
if len(eligible) == 1:
|
||||
reason = (
|
||||
f"Vickrey auction: Agent '{winner.agent_name}' won as sole eligible bidder "
|
||||
f"(cost={winner.estimated_cost}, payment={payment:.4f}, profit=0)"
|
||||
)
|
||||
else:
|
||||
second = eligible[1]
|
||||
reason = (
|
||||
f"Vickrey auction: Agent '{winner.agent_name}' won with lowest cost "
|
||||
f"({winner.estimated_cost}), pays second-lowest cost ({payment}) from "
|
||||
f"'{second.agent_name}'; profit={profit:.4f}"
|
||||
)
|
||||
|
||||
return AuctionResult(
|
||||
winner=winner,
|
||||
all_bids=bidders,
|
||||
selection_reason=reason,
|
||||
total_bidders=len(bidders),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
|
||||
|
||||
class WealthTracker:
|
||||
"""Track agent wealth for auction mechanism.
|
||||
|
|
@ -9,25 +11,31 @@ class WealthTracker:
|
|||
Agents earn wealth by completing tasks successfully.
|
||||
Agents lose wealth by failing tasks.
|
||||
Bankrupt agents (wealth <= -100) are excluded from auctions.
|
||||
|
||||
Thread-safe: all mutations are protected by a threading.Lock.
|
||||
"""
|
||||
|
||||
def __init__(self, initial_wealth: float = 100.0) -> None:
|
||||
self._balances: dict[str, float] = {}
|
||||
self._initial_wealth = initial_wealth
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def get_wealth(self, agent_name: str) -> float:
|
||||
"""Get agent's current wealth, defaulting to initial_wealth"""
|
||||
return self._balances.get(agent_name, self._initial_wealth)
|
||||
with self._lock:
|
||||
return self._balances.get(agent_name, self._initial_wealth)
|
||||
|
||||
def reward(self, agent_name: str, amount: float) -> None:
|
||||
"""Reward agent for successful task completion"""
|
||||
current = self.get_wealth(agent_name)
|
||||
self._balances[agent_name] = current + amount
|
||||
with self._lock:
|
||||
current = self._balances.get(agent_name, self._initial_wealth)
|
||||
self._balances[agent_name] = current + amount
|
||||
|
||||
def penalize(self, agent_name: str, amount: float) -> None:
|
||||
"""Penalize agent for task failure"""
|
||||
current = self.get_wealth(agent_name)
|
||||
self._balances[agent_name] = current - amount
|
||||
with self._lock:
|
||||
current = self._balances.get(agent_name, self._initial_wealth)
|
||||
self._balances[agent_name] = current - amount
|
||||
|
||||
def is_bankrupt(self, agent_name: str) -> bool:
|
||||
"""Check if agent is bankrupt (wealth <= -100)"""
|
||||
|
|
@ -46,5 +54,5 @@ class WealthTracker:
|
|||
return all_agents
|
||||
|
||||
def get_wealth_factor(self, agent_name: str) -> float:
|
||||
"""Get wealth factor for scoring: 1.0 + (wealth / 1000.0)"""
|
||||
return 1.0 + (self.get_wealth(agent_name) / 1000.0)
|
||||
"""Get wealth factor for scoring: max(0.01, 1.0 + (wealth / 1000.0))"""
|
||||
return max(0.01, 1.0 + (self.get_wealth(agent_name) / 1000.0))
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from agentkit.router.intent import IntentRouter
|
|||
from agentkit.skills.base import Skill, SkillConfig
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
from agentkit.tools.skill_install import SkillInstallTool
|
||||
from agentkit.server.config import ServerConfig
|
||||
from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory, portal, evolution_dashboard, kb_management, skill_management, workflows, chat
|
||||
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
|
||||
|
|
@ -104,6 +105,8 @@ async def lifespan(app: FastAPI):
|
|||
if server_config is not None and server_config._config_path:
|
||||
server_config.on_change = lambda cfg: _on_config_change(app, cfg)
|
||||
server_config.watch_config()
|
||||
# Store event loop reference for thread-safe config reload
|
||||
app.state._event_loop = asyncio.get_running_loop()
|
||||
logger.info("Config hot-reload enabled")
|
||||
|
||||
# Start MCP servers if configured
|
||||
|
|
@ -132,7 +135,10 @@ async def lifespan(app: FastAPI):
|
|||
"你必须先使用搜索工具查找准确和最新的信息,然后再回答。"
|
||||
"中文内容优先使用 baidu_search 工具,英文/国际内容使用 web_search。"
|
||||
"在能够搜索到真相的情况下,绝不猜测或编造答案。"
|
||||
"始终优先搜索而不是给出可能不正确的信息。"
|
||||
"始终优先搜索而不是给出可能不正确的信息。\n\n"
|
||||
"技能安装:当需要安装技能时,使用 skill_install 工具,不要用 shell 执行 npm install。"
|
||||
"skill_install 的 source 参数格式为 owner/repo@skill,例如 vercel-labs/skills@find-skills。"
|
||||
"如果不知道完整 source,先用 shell 执行 `npx skills search <name>` 搜索。"
|
||||
)
|
||||
effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt)
|
||||
|
||||
|
|
@ -156,6 +162,10 @@ async def lifespan(app: FastAPI):
|
|||
}
|
||||
agent._tool_registry.register(MemoryTool(memory_store=memory_store))
|
||||
agent._tool_registry.register(ShellTool())
|
||||
agent._tool_registry.register(SkillInstallTool(
|
||||
skill_registry=app.state.skill_registry,
|
||||
tool_registry=app.state.tool_registry,
|
||||
))
|
||||
agent._tool_registry.register(BaiduSearchTool())
|
||||
agent._tool_registry.register(WebSearchTool(**search_api_keys))
|
||||
agent._tool_registry.register(WebCrawlTool())
|
||||
|
|
@ -238,13 +248,15 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
|
|||
- Config version is incremented for audit tracking
|
||||
|
||||
Uses a lock to prevent concurrent config reloads from racing.
|
||||
Thread-safe: uses threading.Event for cross-thread signaling.
|
||||
Thread-safe: schedules reload onto the asyncio event loop via
|
||||
asyncio.run_coroutine_threadsafe() since watchfiles calls this
|
||||
from a non-asyncio thread.
|
||||
"""
|
||||
import threading
|
||||
|
||||
lock: asyncio.Lock = app.state._config_reload_lock
|
||||
lock = app.state._config_reload_lock
|
||||
|
||||
# Thread-safe: set pending flag via threading.Event or call_soon_threadsafe
|
||||
# Thread-safe: set pending flag via threading.Event
|
||||
if not hasattr(app.state, "_config_reload_event"):
|
||||
app.state._config_reload_event = threading.Event()
|
||||
|
||||
|
|
@ -310,12 +322,18 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
|
|||
|
||||
logger.info(f"Config reload complete (v{current_version})")
|
||||
|
||||
# Schedule the reload as a task (non-blocking for the watcher thread)
|
||||
# Schedule the reload onto the event loop via run_coroutine_threadsafe
|
||||
# (watchfiles calls _on_config_change from a non-asyncio thread)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(_reload())
|
||||
asyncio.run_coroutine_threadsafe(_reload(), loop)
|
||||
except RuntimeError:
|
||||
logger.warning("No running event loop, config reload deferred")
|
||||
# No running loop — try getting the stored event loop reference
|
||||
_loop = getattr(app.state, "_event_loop", None)
|
||||
if _loop is not None and not _loop.is_closed():
|
||||
asyncio.run_coroutine_threadsafe(_reload(), _loop)
|
||||
else:
|
||||
logger.warning("No running event loop, config reload deferred")
|
||||
|
||||
|
||||
def create_app(
|
||||
|
|
@ -482,6 +500,7 @@ def create_app(
|
|||
org_context=org_context,
|
||||
auction_enabled=auction_enabled,
|
||||
classifier=server_config.router.get("classifier", "heuristic") if server_config and server_config.router else "heuristic",
|
||||
merged_llm_classify=server_config.router.get("merged_llm_classify", True) if server_config and server_config.router else True,
|
||||
)
|
||||
app.state.cost_aware_router = cost_aware_router
|
||||
# Initialize task store from config
|
||||
|
|
|
|||
|
|
@ -193,7 +193,12 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
|||
|
||||
# Execute the Agent
|
||||
try:
|
||||
react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
|
||||
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
|
||||
react_engine = getattr(agent, "_react_engine", None)
|
||||
if react_engine is None:
|
||||
react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
|
||||
else:
|
||||
react_engine.reset()
|
||||
tools = agent._tool_registry.list_tools() if agent._tool_registry else []
|
||||
system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None)
|
||||
result = await react_engine.execute(
|
||||
|
|
@ -286,6 +291,10 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
|
||||
cancellation_token = CancellationToken()
|
||||
|
||||
# Per-session concurrency guard to prevent unlimited task creation (DoS mitigation)
|
||||
_MAX_CONCURRENT_TASKS = 4
|
||||
active_tasks: set[asyncio.Task] = set()
|
||||
|
||||
try:
|
||||
await websocket.send_json({"type": "connected", "session_id": session_id})
|
||||
|
||||
|
|
@ -308,9 +317,27 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
content = msg.get("content", "")
|
||||
# Create a fresh CancellationToken for each message
|
||||
message_token = CancellationToken()
|
||||
await _handle_chat_message(
|
||||
websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations
|
||||
|
||||
# Guard against unlimited concurrent tasks
|
||||
# Clean up completed tasks first
|
||||
active_tasks.difference_update(t for t in active_tasks if t.done())
|
||||
if len(active_tasks) >= _MAX_CONCURRENT_TASKS:
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"data": {"message": "Too many concurrent requests. Please wait for the current task to complete."},
|
||||
})
|
||||
continue
|
||||
|
||||
# Run in background task so the WebSocket receive loop stays free
|
||||
# to process confirmation_reply / reply messages while the agent
|
||||
# is waiting for user confirmation (otherwise deadlock).
|
||||
task = asyncio.create_task(
|
||||
_handle_chat_message(
|
||||
websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations
|
||||
)
|
||||
)
|
||||
active_tasks.add(task)
|
||||
task.add_done_callback(active_tasks.discard)
|
||||
|
||||
elif msg_type == "reply":
|
||||
# Reply to AskHumanTool
|
||||
|
|
@ -323,8 +350,12 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
# Reply to confirmation request
|
||||
confirmation_id = msg.get("confirmation_id")
|
||||
approved = msg.get("approved", False)
|
||||
logger.info(f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}")
|
||||
if confirmation_id and confirmation_id in pending_confirmations:
|
||||
pending_confirmations[confirmation_id].set_result(approved)
|
||||
logger.info(f"Confirmation {confirmation_id} set_result({approved})")
|
||||
else:
|
||||
logger.warning(f"Confirmation {confirmation_id!r} not found in pending_confirmations")
|
||||
|
||||
elif msg_type == "cancel":
|
||||
cancellation_token.cancel()
|
||||
|
|
@ -424,10 +455,17 @@ async def _handle_chat_message(
|
|||
chat_messages = await sm.get_chat_messages(session_id)
|
||||
|
||||
# Execute Agent with streaming
|
||||
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
||||
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
|
||||
react_engine = getattr(agent, "_react_engine", None)
|
||||
if react_engine is None:
|
||||
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
||||
else:
|
||||
react_engine.reset()
|
||||
|
||||
# Create confirmation handler that sends request to frontend and waits for reply
|
||||
_pending_confirmations = pending_confirmations or {}
|
||||
# Use the same dict object — do NOT use `or {}` because an empty dict is falsy
|
||||
# and would create a new dict, breaking the shared state with the WS loop.
|
||||
_pending_confirmations = pending_confirmations if pending_confirmations is not None else {}
|
||||
|
||||
async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool:
|
||||
"""Send confirmation request to frontend via WebSocket and wait for user reply."""
|
||||
|
|
@ -442,16 +480,28 @@ async def _handle_chat_message(
|
|||
})
|
||||
|
||||
# Create a Future and wait for the user's reply
|
||||
loop = asyncio.get_event_loop()
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[bool] = loop.create_future()
|
||||
_pending_confirmations[confirmation_id] = future
|
||||
logger.info(f"Confirmation request {confirmation_id} sent, waiting for reply")
|
||||
|
||||
try:
|
||||
# Wait up to 5 minutes for user confirmation
|
||||
return await asyncio.wait_for(future, timeout=300.0)
|
||||
result = await asyncio.wait_for(future, timeout=300.0)
|
||||
logger.info(f"Confirmation request {confirmation_id} resolved: {result}")
|
||||
# Immediately notify frontend of the result so the card updates
|
||||
# without waiting for the tool to re-execute
|
||||
await websocket.send_json({
|
||||
"type": "confirmation_result",
|
||||
"data": {"confirmation_id": confirmation_id, "approved": result},
|
||||
})
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Confirmation request {confirmation_id} timed out")
|
||||
return False
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(f"Confirmation request {confirmation_id} cancelled")
|
||||
return False
|
||||
finally:
|
||||
_pending_confirmations.pop(confirmation_id, None)
|
||||
|
||||
|
|
@ -459,6 +509,7 @@ async def _handle_chat_message(
|
|||
|
||||
try:
|
||||
final_content = ""
|
||||
token_buffer: list[str] = []
|
||||
async for event in react_engine.execute_stream(
|
||||
messages=chat_messages,
|
||||
tools=routing.tools,
|
||||
|
|
@ -469,24 +520,51 @@ async def _handle_chat_message(
|
|||
confirmation_handler=_confirmation_handler,
|
||||
):
|
||||
if event.event_type == "final_answer":
|
||||
# Flush any buffered tokens as a single write
|
||||
if token_buffer:
|
||||
await websocket.send_json({"type": "token", "content": "".join(token_buffer)})
|
||||
token_buffer.clear()
|
||||
# Then send final answer
|
||||
final_content = event.data.get("output", "")
|
||||
await websocket.send_json({
|
||||
"type": "final_answer",
|
||||
"content": final_content,
|
||||
"is_final": True,
|
||||
})
|
||||
elif event.event_type == "token":
|
||||
# Buffer tokens instead of sending immediately
|
||||
token_buffer.append(event.data.get("content", ""))
|
||||
elif event.event_type == "thinking":
|
||||
# If we have buffered tokens, convert them to a thinking event
|
||||
if token_buffer:
|
||||
buffered_text = "".join(token_buffer)
|
||||
token_buffer.clear()
|
||||
await websocket.send_json({"type": "thinking", "content": buffered_text})
|
||||
# Also send the thinking event content
|
||||
thinking_msg = event.data.get("message", "")
|
||||
if thinking_msg:
|
||||
await websocket.send_json({"type": "thinking", "content": thinking_msg})
|
||||
elif event.event_type == "tool_call":
|
||||
# Convert buffered tokens to thinking (they were "thinking" text before tool call)
|
||||
if token_buffer:
|
||||
buffered_text = "".join(token_buffer)
|
||||
token_buffer.clear()
|
||||
await websocket.send_json({"type": "thinking", "content": buffered_text})
|
||||
await websocket.send_json({
|
||||
"type": "step",
|
||||
"data": {
|
||||
"event_type": event.event_type,
|
||||
"step": event.step,
|
||||
"data": event.data,
|
||||
},
|
||||
})
|
||||
elif event.event_type == "confirmation_request":
|
||||
# Already handled by confirmation_handler, just notify frontend
|
||||
pass
|
||||
elif event.event_type == "confirmation_result":
|
||||
await websocket.send_json({
|
||||
"type": "confirmation_result",
|
||||
"data": event.data,
|
||||
})
|
||||
elif event.event_type == "token":
|
||||
await websocket.send_json({
|
||||
"type": "token",
|
||||
"content": event.data.get("content", ""),
|
||||
})
|
||||
else:
|
||||
await websocket.send_json({
|
||||
"type": "step",
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:ital,wght@0,300;0,400;0,500;0,600;0,700;1,400&display=swap" rel="stylesheet">
|
||||
<script src="https://cdn.jsdelivr.net/npm/marked@12/marked.min.js"></script>
|
||||
<style>
|
||||
*,*::before,*::after{box-sizing:border-box;margin:0;padding:0}
|
||||
:root{
|
||||
|
|
@ -88,6 +89,33 @@ html,body{height:100%;font-family:var(--font);background:var(--bg);color:var(--t
|
|||
.msg .bubble{padding:12px 18px;font-size:14px;line-height:1.7;white-space:pre-wrap;word-break:break-word;position:relative}
|
||||
.msg.user .bubble{background:var(--user-bg);color:var(--user-text);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-sm) var(--radius-lg);box-shadow:0 2px 8px rgba(59,91,219,.2)}
|
||||
.msg.agent .bubble{background:var(--surface);color:var(--agent-text);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-lg) var(--radius-sm);border:1px solid var(--border-light);box-shadow:var(--shadow-xs)}
|
||||
|
||||
/* ── Markdown in Agent Bubbles ──────────────────────────────── */
|
||||
.msg.agent .bubble.md{white-space:normal}
|
||||
.msg.agent .bubble.md p{margin:0 0 .6em}
|
||||
.msg.agent .bubble.md p:last-child{margin-bottom:0}
|
||||
.msg.agent .bubble.md h1,.msg.agent .bubble.md h2,.msg.agent .bubble.md h3,.msg.agent .bubble.md h4{margin:.8em 0 .4em;font-weight:700;letter-spacing:-.3px;line-height:1.3}
|
||||
.msg.agent .bubble.md h1{font-size:1.4em}
|
||||
.msg.agent .bubble.md h2{font-size:1.2em}
|
||||
.msg.agent .bubble.md h3{font-size:1.05em}
|
||||
.msg.agent .bubble.md h4{font-size:1em}
|
||||
.msg.agent .bubble.md ul,.msg.agent .bubble.md ol{margin:.4em 0 .6em 1.5em;padding:0}
|
||||
.msg.agent .bubble.md li{margin:.2em 0}
|
||||
.msg.agent .bubble.md li>p{margin:.2em 0}
|
||||
.msg.agent .bubble.md blockquote{margin:.5em 0;padding:.4em 1em;border-left:3px solid var(--primary);background:var(--primary-light);border-radius:0 var(--radius-sm) var(--radius-sm) 0;color:var(--text2)}
|
||||
.msg.agent .bubble.md blockquote p{margin:.2em 0}
|
||||
.msg.agent .bubble.md code{font-family:'SF Mono',SFMono-Regular,Consolas,'Liberation Mono',Menlo,monospace;font-size:.88em;background:var(--surface2);padding:.15em .4em;border-radius:4px;color:var(--text)}
|
||||
.msg.agent .bubble.md pre{margin:.5em 0;padding:12px 14px;background:var(--surface2);border:1px solid var(--border-light);border-radius:var(--radius-sm);overflow-x:auto;line-height:1.5}
|
||||
.msg.agent .bubble.md pre code{background:none;padding:0;font-size:.85em;color:var(--text)}
|
||||
.msg.agent .bubble.md table{border-collapse:collapse;margin:.5em 0;width:100%;font-size:.9em}
|
||||
.msg.agent .bubble.md th,.msg.agent .bubble.md td{border:1px solid var(--border);padding:6px 10px;text-align:left}
|
||||
.msg.agent .bubble.md th{background:var(--surface2);font-weight:600}
|
||||
.msg.agent .bubble.md hr{border:none;border-top:1px solid var(--border);margin:.8em 0}
|
||||
.msg.agent .bubble.md a{color:var(--primary);text-decoration:none}
|
||||
.msg.agent .bubble.md a:hover{text-decoration:underline}
|
||||
.msg.agent .bubble.md img{max-width:100%;border-radius:var(--radius-sm)}
|
||||
.msg.agent .bubble.md strong{font-weight:700}
|
||||
.msg.agent .bubble.md em{font-style:italic}
|
||||
.msg .meta{font-size:11px;color:var(--text3);margin-top:5px;padding:0 4px;font-weight:500}
|
||||
.msg.user .meta{text-align:right}
|
||||
.typing-indicator{display:inline-flex;gap:5px;padding:6px 0}
|
||||
|
|
@ -164,6 +192,10 @@ html,body{height:100%;font-family:var(--font);background:var(--bg);color:var(--t
|
|||
@keyframes fadeIn{from{opacity:0}to{opacity:1}}
|
||||
@keyframes slideInRight{from{opacity:0;transform:translateX(12px)}to{opacity:1;transform:translateX(0)}}
|
||||
|
||||
/* ── Thinking ──────────────────────────────────────────────────── */
|
||||
.thinking-msg{display:flex;flex-direction:column;max-width:72%;animation:msgIn .35s cubic-bezier(.16,1,.3,1);align-self:flex-start}
|
||||
.thinking-msg .bubble{padding:10px 16px;font-size:13px;line-height:1.6;color:var(--text3);font-style:italic;background:var(--surface2);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-lg) var(--radius-sm);border:1px dashed var(--border)}
|
||||
|
||||
/* ── Loading indicator ───────────────────────────────────────── */
|
||||
.loading-msg{display:flex;flex-direction:column;max-width:72%;align-self:flex-start;animation:msgIn .35s cubic-bezier(.16,1,.3,1)}
|
||||
.loading-msg .loading-bubble{padding:8px 20px;background:var(--surface);border:1px solid var(--border-light);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-lg) var(--radius-sm);display:inline-flex;align-items:center;gap:6px;box-shadow:var(--shadow-xs)}
|
||||
|
|
@ -172,6 +204,26 @@ html,body{height:100%;font-family:var(--font);background:var(--bg);color:var(--t
|
|||
.loading-msg .loading-dot:nth-child(3){animation-delay:.4s}
|
||||
@keyframes loadingDot{0%,80%,100%{transform:scale(.4);opacity:.3}40%{transform:scale(1);opacity:1}}
|
||||
|
||||
/* ── Confirmation Card ──────────────────────────────────────── */
|
||||
.confirm-card{display:flex;flex-direction:column;max-width:72%;align-self:flex-start;animation:msgIn .35s cubic-bezier(.16,1,.3,1)}
|
||||
.confirm-card .card-inner{background:var(--surface);border:1px solid var(--warning);border-radius:var(--radius-lg);box-shadow:var(--shadow-md);overflow:hidden}
|
||||
.confirm-card .card-header{display:flex;align-items:center;gap:10px;padding:14px 18px 10px;border-bottom:1px solid var(--border-light)}
|
||||
.confirm-card .card-icon{width:32px;height:32px;border-radius:var(--radius-sm);background:#fef3c7;display:flex;align-items:center;justify-content:center;font-size:16px;flex-shrink:0}
|
||||
.confirm-card .card-title{font-size:14px;font-weight:600;color:var(--text);letter-spacing:-0.2px}
|
||||
.confirm-card .card-body{padding:12px 18px 14px}
|
||||
.confirm-card .card-command{background:var(--surface2);border:1px solid var(--border-light);border-radius:var(--radius-sm);padding:10px 14px;font-family:'SF Mono',SFMono-Regular,Consolas,'Liberation Mono',Menlo,monospace;font-size:13px;color:var(--text);word-break:break-all;line-height:1.6;margin-bottom:10px}
|
||||
.confirm-card .card-reason{font-size:13px;color:var(--text2);line-height:1.5;margin-bottom:4px}
|
||||
.confirm-card .card-actions{display:flex;gap:8px;padding:0 18px 14px}
|
||||
.confirm-card .btn-confirm{flex:1;padding:9px 16px;border-radius:var(--radius-sm);font-size:13px;font-weight:600;cursor:pointer;transition:all .2s;border:1px solid;font-family:var(--font)}
|
||||
.confirm-card .btn-confirm.approve{background:var(--primary-light);color:var(--primary);border-color:var(--primary-subtle)}
|
||||
.confirm-card .btn-confirm.approve:hover{background:var(--primary);color:#fff;border-color:var(--primary)}
|
||||
.confirm-card .btn-confirm.deny{background:var(--danger-light);color:var(--danger);border-color:#fecaca}
|
||||
.confirm-card .btn-confirm.deny:hover{background:var(--danger);color:#fff;border-color:var(--danger)}
|
||||
.confirm-card .btn-confirm:disabled{opacity:.5;cursor:not-allowed;transform:none}
|
||||
.confirm-card .card-status{padding:10px 18px 14px;font-size:13px;font-weight:500;display:flex;align-items:center;gap:6px}
|
||||
.confirm-card .card-status.approved{color:var(--success);background:var(--success-light)}
|
||||
.confirm-card .card-status.denied{color:var(--danger);background:var(--danger-light)}
|
||||
|
||||
/* ── Mobile ──────────────────────────────────────────────────── */
|
||||
@media(max-width:768px){
|
||||
.sidebar{position:fixed;left:-100%;z-index:10;transition:left .3s cubic-bezier(.16,1,.3,1);width:85vw;max-width:320px;box-shadow:var(--shadow-lg)}
|
||||
|
|
@ -270,11 +322,41 @@ let activeSessionId = null;
|
|||
let ws = null;
|
||||
let isStreaming = false;
|
||||
let currentAgentBubble = null;
|
||||
let currentAgentRawText = ''; // accumulate raw text during streaming
|
||||
let currentThinkingBubble = null;
|
||||
let loadingEl = null;
|
||||
let skills = [];
|
||||
const API = '/api/v1/chat';
|
||||
const SKILLS_API = '/api/v1/skills';
|
||||
|
||||
// ── Markdown helpers ──────────────────────────────────────────
|
||||
const _MD_RE = /(?:^#{1,4}\s|^\s*[-*+]\s|^\s*\d+\.\s|```|\*\*|__|\[.*?\]\(|^\s*>)/m;
|
||||
|
||||
function isMarkdown(text) {
|
||||
if (!text) return false;
|
||||
return _MD_RE.test(text);
|
||||
}
|
||||
|
||||
function renderMarkdown(text) {
|
||||
if (!text || !text.trim()) return '';
|
||||
try {
|
||||
return marked.parse(text, { breaks: true, gfm: true });
|
||||
} catch {
|
||||
return esc(text);
|
||||
}
|
||||
}
|
||||
|
||||
function renderBubbleContent(bubble, text, forceMd) {
|
||||
const md = forceMd || isMarkdown(text);
|
||||
if (md) {
|
||||
bubble.classList.add('md');
|
||||
bubble.innerHTML = renderMarkdown(text);
|
||||
} else {
|
||||
bubble.classList.remove('md');
|
||||
bubble.textContent = text;
|
||||
}
|
||||
}
|
||||
|
||||
// ── API helpers ────────────────────────────────────────────────
|
||||
async function api(base, path, opts = {}) {
|
||||
const res = await fetch(base + path, {
|
||||
|
|
@ -393,35 +475,66 @@ function handleWsMessage(msg) {
|
|||
removeLoading();
|
||||
appendStep(`路由: ${msg.skill || 'default'} (${msg.method}, ${Math.round((msg.confidence || 0) * 100)}%)`);
|
||||
break;
|
||||
case 'thinking':
|
||||
removeLoading();
|
||||
if (!currentThinkingBubble) {
|
||||
currentThinkingBubble = appendThinking();
|
||||
}
|
||||
const thinkingContent = msg.content || '';
|
||||
if (currentThinkingBubble) {
|
||||
currentThinkingBubble.textContent += thinkingContent;
|
||||
}
|
||||
scrollToBottom();
|
||||
break;
|
||||
case 'token':
|
||||
removeLoading();
|
||||
// Clear any thinking bubble when real content starts
|
||||
if (currentThinkingBubble) {
|
||||
currentThinkingBubble.remove();
|
||||
currentThinkingBubble = null;
|
||||
}
|
||||
if (!currentAgentBubble) {
|
||||
currentAgentBubble = appendMessage('agent', '');
|
||||
currentAgentRawText = '';
|
||||
isStreaming = true;
|
||||
updateSendBtn();
|
||||
}
|
||||
currentAgentBubble.textContent += msg.content || '';
|
||||
currentAgentRawText += msg.content || '';
|
||||
// During streaming, show raw text for performance
|
||||
currentAgentBubble.textContent = currentAgentRawText;
|
||||
scrollToBottom();
|
||||
break;
|
||||
case 'final_answer':
|
||||
removeLoading();
|
||||
// Clear thinking bubble
|
||||
if (currentThinkingBubble) {
|
||||
currentThinkingBubble.remove();
|
||||
currentThinkingBubble = null;
|
||||
}
|
||||
if (currentAgentBubble) {
|
||||
const current = currentAgentBubble.textContent || '';
|
||||
const final = msg.content || '';
|
||||
if (!current.trim() || final.length > current.length) {
|
||||
currentAgentBubble.textContent = final;
|
||||
}
|
||||
// Use the longer/more complete text
|
||||
const bestText = (!currentAgentRawText.trim() || final.length > currentAgentRawText.length) ? final : currentAgentRawText;
|
||||
renderBubbleContent(currentAgentBubble, bestText);
|
||||
currentAgentBubble = null;
|
||||
currentAgentRawText = '';
|
||||
} else {
|
||||
appendMessage('agent', msg.content || '');
|
||||
}
|
||||
isStreaming = false;
|
||||
updateSendBtn();
|
||||
scrollToBottom();
|
||||
// Auto-refresh skill list after agent finishes (may have installed new skills)
|
||||
loadSkills();
|
||||
break;
|
||||
case 'step':
|
||||
removeLoading();
|
||||
if (msg.data?.event_type === 'tool_call') {
|
||||
// Replace thinking with tool step
|
||||
if (currentThinkingBubble) {
|
||||
currentThinkingBubble.remove();
|
||||
currentThinkingBubble = null;
|
||||
}
|
||||
appendStep(`使用工具: ${msg.data?.data?.tool_name || 'tool'}`);
|
||||
}
|
||||
break;
|
||||
|
|
@ -431,8 +544,20 @@ function handleWsMessage(msg) {
|
|||
appendStep(`技能: ${msg.data.skill} (${msg.data.method}, ${Math.round((msg.data.confidence || 0) * 100)}%)`);
|
||||
}
|
||||
break;
|
||||
case 'confirmation_request':
|
||||
removeLoading();
|
||||
showConfirmationCard(msg.data);
|
||||
break;
|
||||
case 'confirmation_result':
|
||||
updateConfirmationCard(msg.data);
|
||||
break;
|
||||
case 'error':
|
||||
removeLoading();
|
||||
// Clear thinking bubble on error
|
||||
if (currentThinkingBubble) {
|
||||
currentThinkingBubble.remove();
|
||||
currentThinkingBubble = null;
|
||||
}
|
||||
appendMessage('agent', `[错误] ${msg.data?.message || '未知错误'}`);
|
||||
currentAgentBubble = null;
|
||||
isStreaming = false;
|
||||
|
|
@ -462,6 +587,71 @@ function removeLoading() {
|
|||
}
|
||||
}
|
||||
|
||||
// ── Confirmation Card ──────────────────────────────────────────
|
||||
function showConfirmationCard(data) {
|
||||
hideWelcome();
|
||||
const container = document.getElementById('messages');
|
||||
const confirmationId = data.confirmation_id;
|
||||
const command = data.command || '';
|
||||
const reason = data.reason || '';
|
||||
|
||||
const div = document.createElement('div');
|
||||
div.className = 'confirm-card';
|
||||
div.id = `confirm-${confirmationId}`;
|
||||
div.innerHTML = `
|
||||
<div class="card-inner">
|
||||
<div class="card-header">
|
||||
<div class="card-icon">⚠</div>
|
||||
<div class="card-title">操作确认</div>
|
||||
</div>
|
||||
<div class="card-body">
|
||||
<div class="card-command">${esc(command)}</div>
|
||||
${reason ? `<div class="card-reason">${esc(reason)}</div>` : ''}
|
||||
</div>
|
||||
<div class="card-actions" id="confirm-actions-${confirmationId}">
|
||||
<button class="btn-confirm approve" onclick="replyConfirmation('${confirmationId}', true)">确认执行</button>
|
||||
<button class="btn-confirm deny" onclick="replyConfirmation('${confirmationId}', false)">拒绝</button>
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
container.appendChild(div);
|
||||
scrollToBottom();
|
||||
}
|
||||
|
||||
function replyConfirmation(confirmationId, approved) {
|
||||
// Disable buttons immediately
|
||||
const actionsEl = document.getElementById(`confirm-actions-${confirmationId}`);
|
||||
if (actionsEl) {
|
||||
const buttons = actionsEl.querySelectorAll('.btn-confirm');
|
||||
buttons.forEach(btn => { btn.disabled = true; });
|
||||
}
|
||||
|
||||
// Send reply to server
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({
|
||||
type: 'confirmation_reply',
|
||||
confirmation_id: confirmationId,
|
||||
approved: approved,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
function updateConfirmationCard(data) {
|
||||
const confirmationId = data.confirmation_id;
|
||||
const approved = data.approved;
|
||||
const cardEl = document.getElementById(`confirm-${confirmationId}`);
|
||||
if (!cardEl) return;
|
||||
|
||||
// Replace actions with status
|
||||
const actionsEl = document.getElementById(`confirm-actions-${confirmationId}`);
|
||||
if (actionsEl) {
|
||||
const statusClass = approved ? 'approved' : 'denied';
|
||||
const statusText = approved ? '✓ 已确认执行' : '✗ 已拒绝';
|
||||
actionsEl.outerHTML = `<div class="card-status ${statusClass}">${statusText}</div>`;
|
||||
}
|
||||
scrollToBottom();
|
||||
}
|
||||
|
||||
// ── Send message ───────────────────────────────────────────────
|
||||
async function sendMessage() {
|
||||
const input = document.getElementById('input');
|
||||
|
|
@ -523,7 +713,12 @@ function appendMessage(role, content) {
|
|||
div.className = `msg ${cssRole}`;
|
||||
const bubble = document.createElement('div');
|
||||
bubble.className = 'bubble';
|
||||
bubble.textContent = content;
|
||||
// Render markdown for agent messages
|
||||
if (cssRole === 'agent' && content) {
|
||||
renderBubbleContent(bubble, content);
|
||||
} else {
|
||||
bubble.textContent = content;
|
||||
}
|
||||
div.appendChild(bubble);
|
||||
const meta = document.createElement('div');
|
||||
meta.className = 'meta';
|
||||
|
|
@ -548,6 +743,18 @@ function appendStep(text) {
|
|||
scrollToBottom();
|
||||
}
|
||||
|
||||
function appendThinking() {
|
||||
hideWelcome();
|
||||
const container = document.getElementById('messages');
|
||||
const div = document.createElement('div');
|
||||
div.className = 'thinking-msg';
|
||||
const bubble = document.createElement('div');
|
||||
bubble.className = 'bubble';
|
||||
div.appendChild(bubble);
|
||||
container.appendChild(div);
|
||||
return bubble;
|
||||
}
|
||||
|
||||
function renderHistory(msgs) {
|
||||
const container = document.getElementById('messages');
|
||||
container.innerHTML = '';
|
||||
|
|
@ -672,7 +879,7 @@ async function installSkill() {
|
|||
status.textContent = `自动安装失败,正在请求智能体协助...`;
|
||||
|
||||
if (ws && ws.readyState === WebSocket.OPEN) {
|
||||
const installMsg = `请帮我安装一个名为"${name}"的技能。请按以下步骤操作:1. 使用搜索工具在网上搜索 "${name}" 的 YAML 配置文件(可在技能市场、GitHub 等平台搜索);2. 如果找到了,使用 shell 工具将其下载到 configs/skills/${name}.yaml;3. 下载完成后,使用 shell 工具执行 curl 命令调用 API 注册:curl -X POST http://localhost:${location.port}/api/v1/skills/install -H 'Content-Type: application/json' -d '{"name":"${name}","source":"file://configs/skills/${name}.yaml"}';4. 如果找不到这个技能,请告诉我。`;
|
||||
const installMsg = `请帮我安装一个名为"${name}"的技能。请使用 skill_install 工具安装,source 参数格式为 owner/repo@skill。如果不知道完整 source,先用 shell 执行 npx skills search ${name} 搜索。`;
|
||||
appendMessage('user', installMsg);
|
||||
ws.send(JSON.stringify({ type: 'message', content: installMsg }));
|
||||
currentAgentBubble = null;
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque
|
||||
from typing import Any
|
||||
|
||||
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
||||
|
|
@ -31,8 +31,8 @@ class AsyncWriteQueue:
|
|||
self._store = store
|
||||
self._queue: asyncio.Queue[tuple[Message, Session] | None] | None = None
|
||||
self._worker: asyncio.Task | None = None
|
||||
# Pending buffer: session_id -> list of Messages not yet persisted
|
||||
self._pending_buffer: dict[str, list[Message]] = defaultdict(list)
|
||||
# Pending buffer: session_id -> deque of Messages not yet persisted
|
||||
self._pending_buffer: dict[str, deque[Message]] = defaultdict(deque)
|
||||
self._max_buffer_size = max_buffer_size
|
||||
self._pending_count = 0
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ class AsyncWriteQueue:
|
|||
pass
|
||||
if not buf:
|
||||
self._pending_buffer.pop(message.session_id, None)
|
||||
self._pending_count -= 1
|
||||
self._pending_count = max(0, self._pending_count - 1)
|
||||
self._queue.task_done()
|
||||
|
||||
def enqueue(self, message: Message, session: Session) -> None:
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import os
|
|||
import re
|
||||
import shlex
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
|
|
@ -24,46 +25,135 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
# 安全白名单:这些命令前缀不需要确认
|
||||
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
||||
# 文件浏览
|
||||
"cd",
|
||||
"export",
|
||||
"ls",
|
||||
"grep",
|
||||
"find",
|
||||
"pwd",
|
||||
"find",
|
||||
"file",
|
||||
"stat",
|
||||
"tree",
|
||||
"du",
|
||||
"wc",
|
||||
# 文本查看/处理
|
||||
"cat",
|
||||
"head",
|
||||
"tail",
|
||||
"less",
|
||||
"more",
|
||||
"grep",
|
||||
"egrep",
|
||||
"fgrep",
|
||||
"sort",
|
||||
"uniq",
|
||||
"diff",
|
||||
"comm",
|
||||
"cut",
|
||||
"tr",
|
||||
"awk",
|
||||
"sed", # sed without -i is read-only; -i is caught by operator check
|
||||
"tee",
|
||||
"xargs",
|
||||
# 输出
|
||||
"echo",
|
||||
"which",
|
||||
"printf",
|
||||
# 系统信息
|
||||
"whoami",
|
||||
"id",
|
||||
"date",
|
||||
"uname",
|
||||
"hostname",
|
||||
"uptime",
|
||||
"df",
|
||||
"du",
|
||||
"free",
|
||||
"ps",
|
||||
"top",
|
||||
"htop",
|
||||
"ps",
|
||||
"env",
|
||||
"printenv",
|
||||
"type",
|
||||
"file",
|
||||
"stat",
|
||||
"wc",
|
||||
"sort",
|
||||
"uniq",
|
||||
"diff",
|
||||
"sleep",
|
||||
"which",
|
||||
"whereis",
|
||||
"arch",
|
||||
"lscpu",
|
||||
"lsblk",
|
||||
"mount",
|
||||
# 网络(只读查询)
|
||||
"curl", # curl GET is safe; POST/PUT caught by operator check
|
||||
"wget", # wget download is generally safe
|
||||
"ping",
|
||||
"traceroute",
|
||||
"dig",
|
||||
"nslookup",
|
||||
"host",
|
||||
"ifconfig",
|
||||
"ip",
|
||||
"netstat",
|
||||
"ss",
|
||||
"lsof",
|
||||
# 版本/帮助
|
||||
"python --version",
|
||||
"python3 --version",
|
||||
"node --version",
|
||||
"npm --version",
|
||||
"npm list",
|
||||
"npm view",
|
||||
"npm info",
|
||||
"npm search",
|
||||
"npx --version",
|
||||
"pip list",
|
||||
"pip show",
|
||||
"pip search",
|
||||
"java --version",
|
||||
"go version",
|
||||
"rustc --version",
|
||||
"cargo --version",
|
||||
"git --version",
|
||||
"docker --version",
|
||||
"docker ps",
|
||||
"docker images",
|
||||
"docker logs",
|
||||
"docker inspect",
|
||||
# Git 只读
|
||||
"git status",
|
||||
"git log",
|
||||
"git diff",
|
||||
"git branch",
|
||||
"git remote",
|
||||
"pip list",
|
||||
"pip show",
|
||||
"python --version",
|
||||
"python3 --version",
|
||||
"node --version",
|
||||
"npm list",
|
||||
"docker ps",
|
||||
"docker images",
|
||||
"git show",
|
||||
"git tag",
|
||||
"git stash list",
|
||||
"git config --list",
|
||||
# 包管理器(只读查询)
|
||||
"brew list",
|
||||
"brew info",
|
||||
"brew search",
|
||||
"apt list",
|
||||
"apt search",
|
||||
"apt show",
|
||||
"yum list",
|
||||
"yum search",
|
||||
"yum info",
|
||||
# 其他安全命令
|
||||
"export",
|
||||
"sleep",
|
||||
"true",
|
||||
"false",
|
||||
"test",
|
||||
"seq",
|
||||
"basename",
|
||||
"dirname",
|
||||
"realpath",
|
||||
"readlink",
|
||||
"md5sum",
|
||||
"sha256sum",
|
||||
"shasum",
|
||||
"openssl", # openssl dgst / rand are safe
|
||||
# tar 解压查看
|
||||
"tar -tf",
|
||||
"tar --list",
|
||||
"zipinfo",
|
||||
"unzip -l",
|
||||
)
|
||||
|
||||
# 危险命令检测 — 基于精确 token 匹配,避免子串误判
|
||||
|
|
@ -95,7 +185,8 @@ _DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [
|
|||
]
|
||||
|
||||
|
||||
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n')
|
||||
_SHELL_PIPE_OPERATORS = re.compile(r'\|')
|
||||
_SHELL_CHAIN_OPERATORS = re.compile(r'[;&]|\|\||&&|\$\(|\$\{|`|\$<|>|<|\n')
|
||||
|
||||
|
||||
class ShellTool(Tool):
|
||||
|
|
@ -225,21 +316,26 @@ class ShellTool(Tool):
|
|||
session_id = kwargs.get("session_id")
|
||||
interactive = kwargs.get("interactive", False)
|
||||
|
||||
# 安全检查:危险命令需要确认
|
||||
if self._is_dangerous(command):
|
||||
# 安全检查:危险命令需要确认(除非已通过 _skip_dangerous_check 跳过)
|
||||
skip_dangerous = kwargs.get("_skip_dangerous_check", False)
|
||||
if not skip_dangerous and self._is_dangerous(command):
|
||||
confirmed = await self._request_confirmation(command)
|
||||
if not confirmed:
|
||||
self._log_audit(command, None, blocked=True)
|
||||
# 返回确认请求,由上层(ReActEngine/chat)处理
|
||||
confirmation_id = str(uuid.uuid4())
|
||||
return {
|
||||
"output": "",
|
||||
"needs_confirmation": True,
|
||||
"confirmation_id": confirmation_id,
|
||||
"command": command[:500],
|
||||
"reason": "此命令被识别为潜在危险操作,需要用户确认",
|
||||
"suggestions": [
|
||||
"确认执行此命令",
|
||||
"拒绝执行此命令",
|
||||
],
|
||||
"output": f"命令被拒绝: {command[:100]}",
|
||||
"exit_code": 126,
|
||||
"is_error": True,
|
||||
"error_type": "permission_denied",
|
||||
"message": f"危险命令已被拒绝执行: {command[:100]}",
|
||||
"suggestions": [
|
||||
"如需执行此命令,请手动确认",
|
||||
"考虑使用更安全的替代命令",
|
||||
],
|
||||
}
|
||||
|
||||
# 根据模式执行
|
||||
|
|
@ -362,12 +458,33 @@ class ShellTool(Tool):
|
|||
def _is_dangerous(self, command: str) -> bool:
|
||||
"""检查命令是否为危险操作
|
||||
|
||||
白名单命令直接放行,其他命令检查是否匹配危险模式。
|
||||
白名单命令直接放行。管道命令(|)在所有子命令都安全时放行。
|
||||
其他链式操作符(;、&&、||、$()、>、< 等)一律视为危险。
|
||||
"""
|
||||
command_stripped = command.strip()
|
||||
|
||||
# Check for shell operators that chain commands (always dangerous)
|
||||
if _SHELL_OPERATORS.search(command_stripped):
|
||||
# Check for dangerous chain operators (;, &&, ||, $(), backticks, redirections, newlines)
|
||||
if _SHELL_CHAIN_OPERATORS.search(command_stripped):
|
||||
return True
|
||||
|
||||
# Handle pipe commands: split and check each sub-command
|
||||
if _SHELL_PIPE_OPERATORS.search(command_stripped):
|
||||
parts = command_stripped.split('|')
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
if self._is_single_command_dangerous(part):
|
||||
return True
|
||||
return False # All pipe segments are safe
|
||||
|
||||
# Single command
|
||||
return self._is_single_command_dangerous(command_stripped)
|
||||
|
||||
def _is_single_command_dangerous(self, command: str) -> bool:
|
||||
"""Check if a single command (no pipes/chains) is dangerous."""
|
||||
command_stripped = command.strip()
|
||||
if not command_stripped:
|
||||
return True
|
||||
|
||||
# Parse the actual binary being invoked
|
||||
|
|
@ -377,7 +494,6 @@ class ShellTool(Tool):
|
|||
return True
|
||||
binary = os.path.basename(tokens[0])
|
||||
except ValueError:
|
||||
# Unparsable command - treat as dangerous
|
||||
return True
|
||||
|
||||
# Whitelist check: first try full command prefix match, then binary-only match
|
||||
|
|
@ -405,6 +521,9 @@ class ShellTool(Tool):
|
|||
for flag_pattern in _DANGEROUS_BINARY_FLAGS[binary_lower]:
|
||||
if flag_pattern in cmd_str:
|
||||
return True
|
||||
# Binary has dangerous flags but none matched — treat as safe
|
||||
# (e.g., "git add" is safe even though "git push --force" is not)
|
||||
return False
|
||||
|
||||
# 3. Cross-token dangerous patterns (regex)
|
||||
command_lower = command_stripped.lower()
|
||||
|
|
@ -412,7 +531,10 @@ class ShellTool(Tool):
|
|||
if pattern.search(command_lower):
|
||||
return True
|
||||
|
||||
return True # Unknown commands are dangerous by default
|
||||
# 4. Unknown binary — check if it looks like a path or known safe pattern
|
||||
# Commands like /usr/bin/python3, ./script.sh, etc. are not in whitelist
|
||||
# but may be safe. Default to requiring confirmation for truly unknown binaries.
|
||||
return True
|
||||
|
||||
async def _request_confirmation(self, command: str) -> bool:
|
||||
"""请求人工确认危险命令
|
||||
|
|
|
|||
|
|
@ -0,0 +1,169 @@
|
|||
"""SkillInstallTool - Agent 可调用的技能安装工具"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillInstallTool(Tool):
|
||||
"""技能安装工具
|
||||
|
||||
让 Agent 可以通过正确的命令安装技能,而不是用 shell 执行 npm install。
|
||||
使用 `npx skills install <owner/repo@skill>` 安装技能,然后注册到 skill_registry。
|
||||
|
||||
Usage:
|
||||
tool = SkillInstallTool()
|
||||
result = await tool.execute(name="find-skills", source="vercel-labs/skills@find-skills")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "skill_install",
|
||||
description: str = "安装 Agent 技能包。使用 npx skills install 安装指定技能,不要用 npm install。",
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
confirm_callback: Callable[[str], Awaitable[bool]] | None = None,
|
||||
skill_registry=None,
|
||||
tool_registry=None,
|
||||
):
|
||||
schema = input_schema or {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "技能名称,如 find-skills",
|
||||
},
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "技能来源,格式为 owner/repo@skill,如 vercel-labs/skills@find-skills。如果不提供则用 name 搜索",
|
||||
},
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
input_schema=schema,
|
||||
output_schema=output_schema,
|
||||
version=version,
|
||||
tags=tags,
|
||||
)
|
||||
self._confirm_callback = confirm_callback
|
||||
self._skill_registry = skill_registry
|
||||
self._tool_registry = tool_registry
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
name = kwargs.get("name", "").strip()
|
||||
source = kwargs.get("source", "").strip()
|
||||
|
||||
if not name:
|
||||
return {
|
||||
"output": "错误: 必须提供 name 参数",
|
||||
"exit_code": 1,
|
||||
"is_error": True,
|
||||
}
|
||||
|
||||
# Build the install command
|
||||
if source:
|
||||
install_target = source
|
||||
else:
|
||||
install_target = name
|
||||
|
||||
# Request confirmation before installing
|
||||
if self._confirm_callback:
|
||||
confirmed = await self._confirm_callback(f"npx skills install {install_target}")
|
||||
if not confirmed:
|
||||
return {
|
||||
"output": f"技能安装被用户拒绝: {install_target}",
|
||||
"exit_code": 126,
|
||||
"is_error": True,
|
||||
}
|
||||
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"npx", "skills@latest", "install", install_target, "-y",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
|
||||
output = stdout.decode("utf-8", errors="replace")
|
||||
error_output = stderr.decode("utf-8", errors="replace")
|
||||
|
||||
if proc.returncode == 0:
|
||||
# Try to register the skill into skill_registry
|
||||
registration_msg = ""
|
||||
if self._skill_registry:
|
||||
registration_msg = self._try_register_skill(name)
|
||||
|
||||
return {
|
||||
"output": f"技能 {name} 安装成功\n{output}\n{registration_msg}",
|
||||
"exit_code": 0,
|
||||
"is_error": False,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"output": f"技能 {name} 安装失败 (exit={proc.returncode})\n{error_output}\n{output}",
|
||||
"exit_code": proc.returncode,
|
||||
"is_error": True,
|
||||
"suggestions": [
|
||||
"检查技能名称是否正确",
|
||||
"使用 source 参数指定完整路径,如 vercel-labs/skills@find-skills",
|
||||
"运行 npx skills search <name> 搜索可用技能",
|
||||
],
|
||||
}
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"output": f"技能 {name} 安装超时(120s)",
|
||||
"exit_code": -1,
|
||||
"is_error": True,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"output": f"技能 {name} 安装异常: {e}",
|
||||
"exit_code": -1,
|
||||
"is_error": True,
|
||||
}
|
||||
|
||||
def _try_register_skill(self, name: str) -> str:
|
||||
"""Try to find and register the installed skill YAML into skill_registry."""
|
||||
try:
|
||||
from agentkit.skills.loader import SkillLoader
|
||||
|
||||
for search_dir in [os.path.join(os.getcwd(), ".agents", "skills"),
|
||||
os.path.join(os.path.expanduser("~"), ".agents", "skills"),
|
||||
os.path.join(os.getcwd(), "configs", "skills")]:
|
||||
yaml_path = os.path.join(search_dir, f"{name}.yaml")
|
||||
if os.path.exists(yaml_path):
|
||||
loader = SkillLoader(
|
||||
skill_registry=self._skill_registry,
|
||||
tool_registry=self._tool_registry,
|
||||
)
|
||||
loader.load_from_file(yaml_path)
|
||||
return f"技能已注册到系统(来源: {yaml_path})"
|
||||
|
||||
# Also check for directory-based skills
|
||||
for search_dir in [os.path.join(os.getcwd(), ".agents", "skills"),
|
||||
os.path.join(os.path.expanduser("~"), ".agents", "skills")]:
|
||||
skill_dir = os.path.join(search_dir, name)
|
||||
if os.path.isdir(skill_dir):
|
||||
for fname in os.listdir(skill_dir):
|
||||
if fname.endswith((".yaml", ".yml")):
|
||||
yaml_path = os.path.join(skill_dir, fname)
|
||||
loader = SkillLoader(
|
||||
skill_registry=self._skill_registry,
|
||||
tool_registry=self._tool_registry,
|
||||
)
|
||||
loader.load_from_file(yaml_path)
|
||||
return f"技能已注册到系统(来源: {yaml_path})"
|
||||
|
||||
return "技能文件已下载,但未找到 YAML 配置文件进行注册。可能需要重启服务。"
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register skill {name}: {e}")
|
||||
return f"技能文件已下载,但注册失败: {e}"
|
||||
|
|
@ -66,6 +66,7 @@ class TerminalSession:
|
|||
self.session_id = session_id
|
||||
self._cwd = cwd or os.getcwd()
|
||||
self._env: dict[str, str] = dict(env or os.environ)
|
||||
self._env_delta: set[str] = set() # Track only env vars explicitly set in this session
|
||||
self._history: deque[CommandRecord] = deque(maxlen=max_history)
|
||||
self._output_parser = OutputParser()
|
||||
self._created_at = time.time()
|
||||
|
|
@ -147,7 +148,12 @@ class TerminalSession:
|
|||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
proc.kill()
|
||||
try:
|
||||
proc.kill()
|
||||
except ProcessLookupError:
|
||||
pass # Process already exited
|
||||
except OSError:
|
||||
pass
|
||||
await proc.wait()
|
||||
output = f"命令执行超时({timeout}s)"
|
||||
exit_code = -1
|
||||
|
|
@ -196,11 +202,12 @@ class TerminalSession:
|
|||
if self._cwd:
|
||||
parts.append(f"cd {shlex.quote(self._cwd)}")
|
||||
|
||||
# 注入环境变量
|
||||
for key, value in self._env.items():
|
||||
if not _ENV_KEY_PATTERN.match(key):
|
||||
continue # Skip invalid env key names
|
||||
parts.append(f"export {shlex.quote(key)}={shlex.quote(value)}")
|
||||
# 注入环境变量(仅注入 session 中显式设置的增量,避免泄露完整 os.environ)
|
||||
if self._env_delta:
|
||||
for key in self._env_delta:
|
||||
value = self._env.get(key, "")
|
||||
if _ENV_KEY_PATTERN.match(key):
|
||||
parts.append(f"export {shlex.quote(key)}={shlex.quote(value)}")
|
||||
|
||||
parts.append(command)
|
||||
return " && ".join(parts)
|
||||
|
|
@ -269,6 +276,7 @@ class TerminalSession:
|
|||
for key, value in matches:
|
||||
value = value.strip().strip("'\"")
|
||||
self._env[key] = value
|
||||
self._env_delta.add(key)
|
||||
|
||||
def _add_history(self, record: CommandRecord) -> None:
|
||||
"""添加命令记录到历史,deque maxlen 自动淘汰最旧记录"""
|
||||
|
|
|
|||
|
|
@ -1,17 +1,23 @@
|
|||
"""集成测试 - CostAwareRouter → Engine → AlignmentGuard 全链路"""
|
||||
"""集成测试 - CostAwareRouter → Engine → AlignmentGuard 全链路
|
||||
|
||||
包含合并 LLM 分类路由和并行工具执行的集成测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
|
||||
from agentkit.core.react import ReActEngine, ReActResult, ReActStep
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
from agentkit.core.react import ReActEngine, ReActResult, ReActStep, ReActEvent
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.org.context import AgentProfile, OrganizationContext
|
||||
from agentkit.quality.alignment import AlignmentConfig, AlignmentGuard
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -313,3 +319,285 @@ class TestAlignmentGuardConstraintInjection:
|
|||
alert = guard.record_interaction("session-chain-1")
|
||||
assert alert is None
|
||||
assert guard.get_interaction_count("session-chain-1") == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 5: Merged Router → Engine chain (HeuristicClassifier → merged LLM → ReAct)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _SlowTool(Tool):
|
||||
"""带延迟的 Fake Tool,用于验证并行执行"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "slow_tool",
|
||||
description: str = "A slow tool for testing",
|
||||
delay: float = 0.1,
|
||||
result: dict | None = None,
|
||||
):
|
||||
super().__init__(name=name, description=description)
|
||||
self._delay = delay
|
||||
self._result = result or {"status": "ok"}
|
||||
self.call_count = 0
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
self.call_count += 1
|
||||
await asyncio.sleep(self._delay)
|
||||
return self._result
|
||||
|
||||
|
||||
class TestMergedRouterToEngineChain:
|
||||
"""完整链路:用户消息 → HeuristicClassifier → merged LLM classify → ReActEngine → 结果
|
||||
|
||||
当 HeuristicClassifier 返回中等复杂度 (0.3-0.7) 时,使用合并 LLM 调用。
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_medium_complexity_uses_merged_llm_then_react(self):
|
||||
"""中等复杂度触发 merged LLM classify,然后路由到 ReActEngine 执行"""
|
||||
# "如何优化代码" 包含 "如何"(中等)和 "代码"(高),heuristic 给出中等复杂度
|
||||
# merged LLM classify 返回中等复杂度 + 无 skill_hint
|
||||
# → 路由到默认 agent (ReAct)
|
||||
merged_response = _make_llm_response(json.dumps({
|
||||
"complexity": 0.5,
|
||||
"intent": "code_optimization",
|
||||
"skill_hint": None,
|
||||
}))
|
||||
# ReActEngine 最终答案
|
||||
react_final = _make_llm_response("建议使用缓存和异步IO来优化代码性能。")
|
||||
|
||||
gateway = _make_mock_gateway([merged_response, react_final])
|
||||
|
||||
org_context = OrganizationContext()
|
||||
router = CostAwareRouter(
|
||||
llm_gateway=gateway,
|
||||
org_context=org_context,
|
||||
merged_llm_classify=True,
|
||||
)
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = []
|
||||
mock_intent_router = AsyncMock()
|
||||
|
||||
# Step 1: Route
|
||||
route_result = await router.route(
|
||||
content="如何优化代码性能",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
|
||||
# 验证路由结果:中等复杂度,使用 merged_llm 方法
|
||||
assert 0.3 <= route_result.complexity <= 0.7
|
||||
assert route_result.match_method is not None
|
||||
assert "merged_llm" in route_result.match_method
|
||||
|
||||
# Step 2: Execute with ReActEngine
|
||||
react_engine = ReActEngine(llm_gateway=gateway)
|
||||
engine_result = await react_engine.execute(
|
||||
messages=[{"role": "user", "content": route_result.clean_content}],
|
||||
system_prompt=route_result.system_prompt,
|
||||
)
|
||||
|
||||
# Step 3: Verify result
|
||||
assert isinstance(engine_result, ReActResult)
|
||||
assert "优化" in engine_result.output or "缓存" in engine_result.output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_medium_complexity_merged_llm_routes_to_skill_then_react(self):
|
||||
"""中等复杂度 + merged LLM 返回 skill_hint → 路由到 skill → ReAct 执行"""
|
||||
merged_response = _make_llm_response(json.dumps({
|
||||
"complexity": 0.45,
|
||||
"intent": "code_review",
|
||||
"skill_hint": "code_reviewer",
|
||||
}))
|
||||
react_final = _make_llm_response("代码审查完成,发现3个潜在问题。")
|
||||
|
||||
gateway = _make_mock_gateway([merged_response, react_final])
|
||||
|
||||
# 创建包含 code_reviewer skill 的 mock registry
|
||||
mock_skill = MagicMock()
|
||||
mock_skill.name = "code_reviewer"
|
||||
mock_skill.config.intent.keywords = ["code", "review"]
|
||||
mock_skill.config.llm = None
|
||||
mock_skill.config.prompt = None
|
||||
mock_skill.tools = []
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = [mock_skill]
|
||||
mock_skill_registry.get.return_value = mock_skill
|
||||
|
||||
org_context = OrganizationContext()
|
||||
router = CostAwareRouter(
|
||||
llm_gateway=gateway,
|
||||
org_context=org_context,
|
||||
merged_llm_classify=True,
|
||||
)
|
||||
|
||||
route_result = await router.route(
|
||||
content="如何优化代码性能",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=AsyncMock(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
|
||||
# 验证路由到 skill
|
||||
assert route_result.matched is True
|
||||
assert route_result.skill_name == "code_reviewer"
|
||||
assert route_result.match_method == "merged_llm"
|
||||
|
||||
# Execute with ReActEngine
|
||||
react_engine = ReActEngine(llm_gateway=gateway)
|
||||
engine_result = await react_engine.execute(
|
||||
messages=[{"role": "user", "content": route_result.clean_content}],
|
||||
system_prompt=route_result.system_prompt,
|
||||
)
|
||||
|
||||
assert isinstance(engine_result, ReActResult)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merged_llm_high_complexity_delegates_to_layer2(self):
|
||||
"""HeuristicClassifier 中等复杂度 → merged LLM 返回高复杂度 → 委派到 Layer 2"""
|
||||
merged_response = _make_llm_response(json.dumps({
|
||||
"complexity": 0.85,
|
||||
"intent": "deep_analysis",
|
||||
"skill_hint": None,
|
||||
}))
|
||||
|
||||
gateway = _make_mock_gateway([merged_response])
|
||||
|
||||
org_context = OrganizationContext()
|
||||
org_context.register_agent(AgentProfile(
|
||||
name="analyst",
|
||||
agent_type="react",
|
||||
capabilities=["分析", "优化", "代码"],
|
||||
skills=["code_analysis"],
|
||||
current_load=0,
|
||||
))
|
||||
org_context.find_best_agent = MagicMock(
|
||||
return_value=org_context.get_agent_profile("analyst")
|
||||
)
|
||||
|
||||
router = CostAwareRouter(
|
||||
llm_gateway=gateway,
|
||||
org_context=org_context,
|
||||
merged_llm_classify=True,
|
||||
)
|
||||
|
||||
mock_skill_registry = MagicMock()
|
||||
mock_skill_registry.list_skills.return_value = []
|
||||
mock_intent_router = AsyncMock()
|
||||
|
||||
route_result = await router.route(
|
||||
content="如何优化代码性能",
|
||||
skill_registry=mock_skill_registry,
|
||||
intent_router=mock_intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful",
|
||||
default_model="default",
|
||||
default_agent_name="default",
|
||||
)
|
||||
|
||||
# 高复杂度应委派到 Layer 2
|
||||
assert route_result.complexity >= 0.7
|
||||
assert route_result.matched is True
|
||||
assert route_result.agent_name == "analyst"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 6: Parallel Tools Integration (ReActEngine with parallel_tools="auto")
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParallelToolsIntegration:
|
||||
"""ReActEngine + parallel_tools="auto" 在真实场景下的集成测试
|
||||
|
||||
LLM 返回 2 个 tool_calls 且 _parallelizable=true,两者并行执行。
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_parallel_two_tools_realistic(self):
|
||||
"""真实场景:LLM 返回 2 个并行工具调用,并行执行"""
|
||||
tool_a = _SlowTool(name="search_web", delay=0.1, result={"results": ["Python best practices"]})
|
||||
tool_b = _SlowTool(name="search_docs", delay=0.1, result={"docs": ["Official Python docs"]})
|
||||
|
||||
# LLM 返回 2 个并行工具调用
|
||||
tool_call_response = LLMResponse(
|
||||
content="",
|
||||
model="test-model",
|
||||
usage=TokenUsage(prompt_tokens=50, completion_tokens=20),
|
||||
tool_calls=[
|
||||
ToolCall(id="tc_1", name="search_web", arguments={"query": "python", "_parallelizable": True}),
|
||||
ToolCall(id="tc_2", name="search_docs", arguments={"topic": "python", "_parallelizable": True}),
|
||||
],
|
||||
)
|
||||
final_response = _make_llm_response("Based on search results, Python best practices include...")
|
||||
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(side_effect=[tool_call_response, final_response])
|
||||
|
||||
engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")
|
||||
|
||||
start = time.monotonic()
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Search for Python best practices"}],
|
||||
tools=[tool_a, tool_b],
|
||||
)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert isinstance(result, ReActResult)
|
||||
assert tool_a.call_count == 1
|
||||
assert tool_b.call_count == 1
|
||||
# 并行执行应比串行快
|
||||
assert elapsed < 0.25, f"Parallel execution too slow: {elapsed:.2f}s"
|
||||
|
||||
# 验证轨迹包含两个工具调用
|
||||
tool_steps = [s for s in result.trajectory if s.action == "tool_call"]
|
||||
assert len(tool_steps) == 2
|
||||
names = {s.tool_name for s in tool_steps}
|
||||
assert names == {"search_web", "search_docs"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_parallel_with_serial_tool_mixed(self):
|
||||
"""混合场景:1 个串行工具 + 2 个并行工具"""
|
||||
tool_serial = _SlowTool(name="init_context", delay=0.05, result={"context": "ready"})
|
||||
tool_para_a = _SlowTool(name="search_web", delay=0.1, result={"results": ["web result"]})
|
||||
tool_para_b = _SlowTool(name="search_docs", delay=0.1, result={"docs": ["doc result"]})
|
||||
|
||||
tool_call_response = LLMResponse(
|
||||
content="",
|
||||
model="test-model",
|
||||
usage=TokenUsage(prompt_tokens=50, completion_tokens=20),
|
||||
tool_calls=[
|
||||
ToolCall(id="tc_1", name="init_context", arguments={"project": "test"}),
|
||||
ToolCall(id="tc_2", name="search_web", arguments={"query": "test", "_parallelizable": True}),
|
||||
ToolCall(id="tc_3", name="search_docs", arguments={"topic": "test", "_parallelizable": True}),
|
||||
],
|
||||
)
|
||||
final_response = _make_llm_response("Combined result from all tools")
|
||||
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(side_effect=[tool_call_response, final_response])
|
||||
|
||||
engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Initialize and search"}],
|
||||
tools=[tool_serial, tool_para_a, tool_para_b],
|
||||
)
|
||||
|
||||
assert isinstance(result, ReActResult)
|
||||
assert tool_serial.call_count == 1
|
||||
assert tool_para_a.call_count == 1
|
||||
assert tool_para_b.call_count == 1
|
||||
|
||||
# 所有工具结果都在轨迹中
|
||||
tool_steps = [s for s in result.trajectory if s.action == "tool_call"]
|
||||
assert len(tool_steps) == 3
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
"""AuctionHouse 与 WealthTracker 单元测试"""
|
||||
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.marketplace.auction import AuctionHouse, AuctionResult, Bid
|
||||
|
|
@ -288,3 +290,406 @@ class TestAuctionDefaultDisabled:
|
|||
auction_enabled = getattr(marketplace_cfg, "auction_enabled", False)
|
||||
assert auction_enabled is False
|
||||
# If marketplace doesn't exist at all, auction is implicitly disabled
|
||||
|
||||
|
||||
# ---- Vickrey Auction 测试 ----
|
||||
|
||||
|
||||
class TestVickreySingleBidder:
|
||||
"""Vickrey 拍卖:单一竞价者获胜,支付自身成本(利润为0)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_bidder_wins_pays_own_cost(self):
|
||||
tracker = WealthTracker()
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bid = make_bid(agent_name="solo_agent", estimated_cost=5.0)
|
||||
result = await house.run_vickrey_auction("task", [bid])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "solo_agent"
|
||||
assert result.total_bidders == 1
|
||||
# Winner pays own cost, so profit = 0 → wealth unchanged
|
||||
assert tracker.get_wealth("solo_agent") == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_bidder_selection_reason(self):
|
||||
house = AuctionHouse()
|
||||
bid = make_bid(agent_name="solo_agent", estimated_cost=10.0)
|
||||
result = await house.run_vickrey_auction("task", [bid])
|
||||
assert "sole eligible bidder" in result.selection_reason
|
||||
|
||||
|
||||
class TestVickreyTwoBidders:
|
||||
"""Vickrey 拍卖:两个竞价者,最低价赢,支付第二低价"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lowest_cost_wins_pays_second(self):
|
||||
tracker = WealthTracker()
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bid_cheap = make_bid(agent_name="cheap_agent", estimated_cost=5.0)
|
||||
bid_expensive = make_bid(agent_name="expensive_agent", estimated_cost=10.0)
|
||||
result = await house.run_vickrey_auction("task", [bid_cheap, bid_expensive])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "cheap_agent"
|
||||
# Winner pays second-lowest = 10.0, profit = 10.0 - 5.0 = 5.0
|
||||
assert tracker.get_wealth("cheap_agent") == 100.0 + 5.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_price_not_own_price(self):
|
||||
tracker = WealthTracker()
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bid_a = make_bid(agent_name="agent_a", estimated_cost=3.0)
|
||||
bid_b = make_bid(agent_name="agent_b", estimated_cost=7.0)
|
||||
result = await house.run_vickrey_auction("task", [bid_a, bid_b])
|
||||
# agent_a wins, pays 7.0 (not 3.0), profit = 7.0 - 3.0 = 4.0
|
||||
assert tracker.get_wealth("agent_a") == 100.0 + 4.0
|
||||
# Loser pays nothing
|
||||
assert tracker.get_wealth("agent_b") == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_selection_reason_contains_second_price(self):
|
||||
house = AuctionHouse()
|
||||
bid_a = make_bid(agent_name="agent_a", estimated_cost=3.0)
|
||||
bid_b = make_bid(agent_name="agent_b", estimated_cost=7.0)
|
||||
result = await house.run_vickrey_auction("task", [bid_a, bid_b])
|
||||
assert "7.0" in result.selection_reason
|
||||
assert "agent_b" in result.selection_reason
|
||||
|
||||
|
||||
class TestVickreyThreeBidders:
|
||||
"""Vickrey 拍卖:三个竞价者,最低价赢,支付第二低价"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lowest_wins_pays_second_lowest(self):
|
||||
tracker = WealthTracker()
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bid_a = make_bid(agent_name="agent_a", estimated_cost=3.0)
|
||||
bid_b = make_bid(agent_name="agent_b", estimated_cost=7.0)
|
||||
bid_c = make_bid(agent_name="agent_c", estimated_cost=12.0)
|
||||
result = await house.run_vickrey_auction("task", [bid_a, bid_b, bid_c])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "agent_a"
|
||||
# Winner pays second-lowest = 7.0, profit = 7.0 - 3.0 = 4.0
|
||||
assert tracker.get_wealth("agent_a") == 100.0 + 4.0
|
||||
# Third bidder's cost doesn't affect payment
|
||||
assert result.total_bidders == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middle_bidder_wins_when_cheapest_bankrupt(self):
|
||||
tracker = WealthTracker()
|
||||
# Make agent_a bankrupt
|
||||
tracker.penalize("agent_a", 300.0)
|
||||
assert tracker.is_bankrupt("agent_a") is True
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bid_a = make_bid(agent_name="agent_a", estimated_cost=3.0)
|
||||
bid_b = make_bid(agent_name="agent_b", estimated_cost=7.0)
|
||||
bid_c = make_bid(agent_name="agent_c", estimated_cost=12.0)
|
||||
result = await house.run_vickrey_auction("task", [bid_a, bid_b, bid_c])
|
||||
# agent_a is bankrupt, so agent_b wins, pays 12.0
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "agent_b"
|
||||
# profit = 12.0 - 7.0 = 5.0
|
||||
assert tracker.get_wealth("agent_b") == 100.0 + 5.0
|
||||
|
||||
|
||||
class TestVickreyCapabilityFiltering:
|
||||
"""Vickrey 拍卖:能力过滤"""
|
||||
|
||||
def test_filter_by_capabilities_basic(self):
|
||||
house = AuctionHouse()
|
||||
bids = [
|
||||
make_bid(agent_name="a", capabilities=["search", "analysis"]),
|
||||
make_bid(agent_name="b", capabilities=["search"]),
|
||||
make_bid(agent_name="c", capabilities=["analysis", "coding"]),
|
||||
]
|
||||
filtered = house.filter_by_capabilities(bids, ["search"])
|
||||
assert len(filtered) == 2
|
||||
assert {b.agent_name for b in filtered} == {"a", "b"}
|
||||
|
||||
def test_filter_by_capabilities_requires_all(self):
|
||||
house = AuctionHouse()
|
||||
bids = [
|
||||
make_bid(agent_name="a", capabilities=["search", "analysis"]),
|
||||
make_bid(agent_name="b", capabilities=["search"]),
|
||||
make_bid(agent_name="c", capabilities=["analysis", "coding"]),
|
||||
]
|
||||
filtered = house.filter_by_capabilities(bids, ["search", "analysis"])
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0].agent_name == "a"
|
||||
|
||||
def test_filter_by_capabilities_case_insensitive(self):
|
||||
house = AuctionHouse()
|
||||
bids = [
|
||||
make_bid(agent_name="a", capabilities=["Search", "Analysis"]),
|
||||
]
|
||||
filtered = house.filter_by_capabilities(bids, ["search", "analysis"])
|
||||
assert len(filtered) == 1
|
||||
|
||||
def test_filter_by_capabilities_no_match(self):
|
||||
house = AuctionHouse()
|
||||
bids = [
|
||||
make_bid(agent_name="a", capabilities=["search"]),
|
||||
]
|
||||
filtered = house.filter_by_capabilities(bids, ["coding"])
|
||||
assert len(filtered) == 0
|
||||
|
||||
def test_filter_by_capabilities_empty_requirements(self):
|
||||
house = AuctionHouse()
|
||||
bids = [
|
||||
make_bid(agent_name="a", capabilities=["search"]),
|
||||
]
|
||||
filtered = house.filter_by_capabilities(bids, [])
|
||||
assert len(filtered) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vickrey_with_capability_filtering(self):
|
||||
tracker = WealthTracker()
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bids = [
|
||||
make_bid(agent_name="a", estimated_cost=5.0, capabilities=["search", "analysis"]),
|
||||
make_bid(agent_name="b", estimated_cost=3.0, capabilities=["search"]),
|
||||
make_bid(agent_name="c", estimated_cost=8.0, capabilities=["search", "analysis"]),
|
||||
]
|
||||
# Require both "search" and "analysis" → only a and c eligible
|
||||
result = await house.run_vickrey_auction("task", bids, required_capabilities=["search", "analysis"])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "a"
|
||||
# a wins (cost=5.0), pays second-lowest among eligible = c's cost = 8.0
|
||||
assert tracker.get_wealth("a") == 100.0 + (8.0 - 5.0)
|
||||
|
||||
|
||||
class TestVickreyBankruptAgent:
|
||||
"""Vickrey 拍卖:破产 Agent 被排除"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bankrupt_agent_excluded(self):
|
||||
tracker = WealthTracker()
|
||||
tracker.penalize("bankrupt_agent", 300.0) # wealth = -200, bankrupt
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bid_bankrupt = make_bid(agent_name="bankrupt_agent", estimated_cost=1.0)
|
||||
bid_ok = make_bid(agent_name="ok_agent", estimated_cost=10.0)
|
||||
result = await house.run_vickrey_auction("task", [bid_bankrupt, bid_ok])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "ok_agent"
|
||||
# Only 1 eligible bidder → pays own cost, profit = 0
|
||||
assert tracker.get_wealth("ok_agent") == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_bankrupt_returns_none(self):
|
||||
tracker = WealthTracker()
|
||||
tracker.penalize("a", 300.0)
|
||||
tracker.penalize("b", 300.0)
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bids = [
|
||||
make_bid(agent_name="a", estimated_cost=1.0),
|
||||
make_bid(agent_name="b", estimated_cost=2.0),
|
||||
]
|
||||
result = await house.run_vickrey_auction("task", bids)
|
||||
assert result.winner is None
|
||||
assert "bankrupt" in result.selection_reason.lower() or "eligible" in result.selection_reason.lower()
|
||||
|
||||
|
||||
class TestVickreyNoBidders:
|
||||
"""Vickrey 拍卖:无竞价者"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_bidders_returns_none(self):
|
||||
house = AuctionHouse()
|
||||
result = await house.run_vickrey_auction("task", [])
|
||||
assert result.winner is None
|
||||
assert result.total_bidders == 0
|
||||
assert result.all_bids == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_eligible_after_capability_filter(self):
|
||||
house = AuctionHouse()
|
||||
bid = make_bid(agent_name="a", estimated_cost=5.0, capabilities=["search"])
|
||||
result = await house.run_vickrey_auction("task", [bid], required_capabilities=["coding"])
|
||||
assert result.winner is None
|
||||
assert result.total_bidders == 1
|
||||
|
||||
|
||||
class TestVickreyWealthTrackerUpdate:
|
||||
"""Vickrey 拍卖:WealthTracker 正确更新"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_winner_earns_payment_minus_cost(self):
|
||||
tracker = WealthTracker(initial_wealth=50.0)
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bid_a = make_bid(agent_name="a", estimated_cost=4.0)
|
||||
bid_b = make_bid(agent_name="b", estimated_cost=9.0)
|
||||
await house.run_vickrey_auction("task", [bid_a, bid_b])
|
||||
# a wins, pays 9.0, profit = 9.0 - 4.0 = 5.0
|
||||
assert tracker.get_wealth("a") == 50.0 + 5.0
|
||||
# b pays nothing
|
||||
assert tracker.get_wealth("b") == 50.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_bidder_zero_profit(self):
|
||||
tracker = WealthTracker(initial_wealth=100.0)
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
# Single bidder: pays own cost, profit = 0
|
||||
bid = make_bid(agent_name="a", estimated_cost=10.0)
|
||||
await house.run_vickrey_auction("task", [bid])
|
||||
assert tracker.get_wealth("a") == 100.0
|
||||
|
||||
|
||||
class TestVickreyBackwardCompat:
|
||||
"""Vickrey 拍卖:原有 score_bid 方法仍然可用"""
|
||||
|
||||
def test_score_bid_still_works(self):
|
||||
tracker = WealthTracker()
|
||||
house = AuctionHouse(wealth_tracker=tracker)
|
||||
bid = make_bid(agent_name="test_agent", confidence=0.9, estimated_cost=5.0)
|
||||
score = house.score_bid(bid)
|
||||
wealth_factor = 1.0 + (100.0 / 1000.0)
|
||||
expected = (0.9 / 5.0) * wealth_factor
|
||||
assert abs(score - expected) < 0.0001
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_auction_still_works(self):
|
||||
house = AuctionHouse()
|
||||
bid_low = make_bid(agent_name="low_agent", confidence=0.5, estimated_cost=10.0)
|
||||
bid_high = make_bid(agent_name="high_agent", confidence=0.9, estimated_cost=10.0)
|
||||
result = await house.run_auction("do something", [bid_low, bid_high])
|
||||
assert result.winner is not None
|
||||
assert result.winner.agent_name == "high_agent"
|
||||
|
||||
|
||||
# ---- Bid Validation 测试 ----
|
||||
|
||||
|
||||
class TestBidValidation:
|
||||
"""Bid __post_init__ 验证"""
|
||||
|
||||
def test_negative_estimated_cost_raises(self):
|
||||
with pytest.raises(ValueError, match="estimated_cost must be non-negative"):
|
||||
Bid(
|
||||
agent_name="a",
|
||||
architecture="react",
|
||||
estimated_steps=5,
|
||||
estimated_cost=-1.0,
|
||||
confidence=0.8,
|
||||
payment_offer=1.0,
|
||||
capabilities=[],
|
||||
)
|
||||
|
||||
def test_confidence_above_one_raises(self):
|
||||
with pytest.raises(ValueError, match="confidence must be between"):
|
||||
Bid(
|
||||
agent_name="a",
|
||||
architecture="react",
|
||||
estimated_steps=5,
|
||||
estimated_cost=10.0,
|
||||
confidence=1.5,
|
||||
payment_offer=1.0,
|
||||
capabilities=[],
|
||||
)
|
||||
|
||||
def test_confidence_below_zero_raises(self):
|
||||
with pytest.raises(ValueError, match="confidence must be between"):
|
||||
Bid(
|
||||
agent_name="a",
|
||||
architecture="react",
|
||||
estimated_steps=5,
|
||||
estimated_cost=10.0,
|
||||
confidence=-0.1,
|
||||
payment_offer=1.0,
|
||||
capabilities=[],
|
||||
)
|
||||
|
||||
def test_negative_payment_offer_raises(self):
|
||||
with pytest.raises(ValueError, match="payment_offer must be non-negative"):
|
||||
Bid(
|
||||
agent_name="a",
|
||||
architecture="react",
|
||||
estimated_steps=5,
|
||||
estimated_cost=10.0,
|
||||
confidence=0.8,
|
||||
payment_offer=-5.0,
|
||||
capabilities=[],
|
||||
)
|
||||
|
||||
def test_zero_values_allowed(self):
|
||||
bid = Bid(
|
||||
agent_name="a",
|
||||
architecture="react",
|
||||
estimated_steps=5,
|
||||
estimated_cost=0.0,
|
||||
confidence=0.0,
|
||||
payment_offer=0.0,
|
||||
capabilities=[],
|
||||
)
|
||||
assert bid.estimated_cost == 0.0
|
||||
assert bid.confidence == 0.0
|
||||
assert bid.payment_offer == 0.0
|
||||
|
||||
def test_boundary_confidence_one_allowed(self):
|
||||
bid = Bid(
|
||||
agent_name="a",
|
||||
architecture="react",
|
||||
estimated_steps=5,
|
||||
estimated_cost=10.0,
|
||||
confidence=1.0,
|
||||
payment_offer=1.0,
|
||||
capabilities=[],
|
||||
)
|
||||
assert bid.confidence == 1.0
|
||||
|
||||
|
||||
# ---- WealthTracker Thread Safety 测试 ----
|
||||
|
||||
|
||||
class TestWealthTrackerThreadSafety:
|
||||
"""WealthTracker 线程安全"""
|
||||
|
||||
def test_concurrent_reward_penalize(self):
|
||||
tracker = WealthTracker(initial_wealth=1000.0)
|
||||
errors: list[Exception] = []
|
||||
|
||||
def worker(action: str, name: str, amount: float, count: int):
|
||||
try:
|
||||
for _ in range(count):
|
||||
if action == "reward":
|
||||
tracker.reward(name, amount)
|
||||
else:
|
||||
tracker.penalize(name, amount)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=worker, args=("reward", "agent_a", 1.0, 100)),
|
||||
threading.Thread(target=worker, args=("penalize", "agent_a", 1.0, 100)),
|
||||
threading.Thread(target=worker, args=("reward", "agent_b", 2.0, 50)),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert not errors
|
||||
# agent_a: 1000 + 100*1.0 - 100*1.0 = 1000
|
||||
assert tracker.get_wealth("agent_a") == 1000.0
|
||||
# agent_b: 1000 + 50*2.0 = 1100
|
||||
assert tracker.get_wealth("agent_b") == 1100.0
|
||||
|
||||
|
||||
# ---- Wealth Factor Lower Bound 测试 ----
|
||||
|
||||
|
||||
class TestWealthFactorLowerBound:
|
||||
"""get_wealth_factor 下限保护"""
|
||||
|
||||
def test_extremely_negative_wealth_clamped(self):
|
||||
tracker = WealthTracker(initial_wealth=100.0)
|
||||
tracker.penalize("agent_a", 5000.0)
|
||||
# wealth = 100 - 5000 = -4900, factor would be 1.0 + (-4900/1000) = -3.9
|
||||
# But with lower bound: max(0.01, -3.9) = 0.01
|
||||
factor = tracker.get_wealth_factor("agent_a")
|
||||
assert factor == 0.01
|
||||
|
||||
def test_slightly_negative_wealth_not_clamped(self):
|
||||
tracker = WealthTracker(initial_wealth=100.0)
|
||||
tracker.penalize("agent_a", 150.0)
|
||||
# wealth = -50, factor = 1.0 + (-50/1000) = 0.95
|
||||
factor = tracker.get_wealth_factor("agent_a")
|
||||
assert abs(factor - 0.95) < 0.0001
|
||||
|
|
|
|||
|
|
@ -580,9 +580,9 @@ class TestHeuristicClassifierIntegration:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heuristic_mode_no_llm_call(self):
|
||||
"""heuristic 模式下不调用 LLM"""
|
||||
"""heuristic 模式 + merged_llm_classify=False 时不调用 LLM"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic")
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic", merged_llm_classify=False)
|
||||
result = await router.route(
|
||||
content="帮我分析一下数据",
|
||||
skill_registry=_make_skill_registry(),
|
||||
|
|
@ -590,7 +590,7 @@ class TestHeuristicClassifierIntegration:
|
|||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
# LLM gateway.chat 不应被调用
|
||||
# LLM gateway.chat 不应被调用(heuristic + merged disabled)
|
||||
gateway.chat.assert_not_called()
|
||||
# 复杂度应来自启发式分类器
|
||||
assert result.complexity > 0.0
|
||||
|
|
@ -629,3 +629,180 @@ class TestHeuristicClassifierIntegration:
|
|||
"""默认分类器模式为 heuristic"""
|
||||
router = CostAwareRouter()
|
||||
assert router._classifier == "heuristic"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# U1: Merged LLM Classify
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMergedLLMClassify:
|
||||
"""合并路由 LLM 调用测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merged_classify_returns_valid_skill(self):
|
||||
"""合并调用返回有效 JSON + skill_hint,正确路由到指定 skill"""
|
||||
merged_response = json.dumps({
|
||||
"complexity": 0.6,
|
||||
"intent": "code_generation",
|
||||
"skill_hint": "search",
|
||||
})
|
||||
gateway = _make_llm_gateway(merged_response)
|
||||
search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
|
||||
registry = _make_skill_registry([search_skill])
|
||||
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
|
||||
result = await router.route(
|
||||
content="帮我搜索一下最新的新闻",
|
||||
skill_registry=registry,
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.matched is True
|
||||
assert result.skill_name == "search"
|
||||
assert result.match_method == "merged_llm"
|
||||
assert result.complexity > 0.3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merged_classify_malformed_response_fallback(self):
|
||||
"""合并调用返回格式异常,fallback 到默认 Agent"""
|
||||
gateway = _make_llm_gateway("这不是JSON")
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
|
||||
result = await router.route(
|
||||
content="帮我分析一下数据",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.match_method == "merged_llm_fallback"
|
||||
assert result.complexity == 0.5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merged_classify_low_complexity(self):
|
||||
"""合并调用返回 complexity < 0.3,走低复杂度路由"""
|
||||
merged_response = json.dumps({
|
||||
"complexity": 0.2,
|
||||
"intent": "greeting",
|
||||
"skill_hint": None,
|
||||
})
|
||||
gateway = _make_llm_gateway(merged_response)
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
|
||||
result = await router.route(
|
||||
content="如何使用这个功能?", # heuristic returns ~0.45, triggers merged call
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
# Merged LLM returned complexity < 0.3, should route to low complexity
|
||||
assert result.complexity < 0.3
|
||||
assert "low" in result.match_method or "merged_llm_low" in result.match_method
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merged_classify_high_complexity(self):
|
||||
"""合并调用返回 complexity > 0.7,走 Layer 2"""
|
||||
merged_response = json.dumps({
|
||||
"complexity": 0.85,
|
||||
"intent": "research",
|
||||
"skill_hint": None,
|
||||
})
|
||||
gateway = _make_llm_gateway(merged_response)
|
||||
org_context = MagicMock()
|
||||
org_context.find_best_agent = MagicMock(return_value="researcher")
|
||||
|
||||
router = CostAwareRouter(
|
||||
llm_gateway=gateway, model="default",
|
||||
org_context=org_context, merged_llm_classify=True,
|
||||
)
|
||||
result = await router.route(
|
||||
content="做市场调研+竞品分析",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.complexity > 0.7
|
||||
assert result.match_method == "capability"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merged_classify_disabled_falls_back_to_intent_router(self):
|
||||
"""配置 merged_llm_classify=False 时回退到独立 IntentRouter"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
|
||||
search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
|
||||
registry = _make_skill_registry([search_skill])
|
||||
intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
|
||||
intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
|
||||
|
||||
router = CostAwareRouter(
|
||||
llm_gateway=gateway, model="default",
|
||||
merged_llm_classify=False,
|
||||
)
|
||||
result = await router.route(
|
||||
content="分析下这个数据",
|
||||
skill_registry=registry,
|
||||
intent_router=intent_router,
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
# Should not use merged_llm match_method
|
||||
assert result.match_method != "merged_llm"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merged_classify_no_llm_gateway_falls_back(self):
|
||||
"""无 LLM Gateway 时 _classify_merged 回退到 IntentRouter"""
|
||||
search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
|
||||
registry = _make_skill_registry([search_skill])
|
||||
|
||||
router = CostAwareRouter(llm_gateway=None, merged_llm_classify=True)
|
||||
result = await router.route(
|
||||
content="分析下这个数据",
|
||||
skill_registry=registry,
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
# Should not crash, should use IntentRouter fallback
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merged_classify_skill_hint_not_found_fallback(self):
|
||||
"""合并调用返回的 skill_hint 在 registry 中不存在,fallback"""
|
||||
merged_response = json.dumps({
|
||||
"complexity": 0.5,
|
||||
"intent": "unknown",
|
||||
"skill_hint": "nonexistent_skill",
|
||||
})
|
||||
gateway = _make_llm_gateway(merged_response)
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
|
||||
result = await router.route(
|
||||
content="帮我分析一下数据",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
# Should fallback to default agent (medium complexity, no skill match)
|
||||
assert result.matched is False
|
||||
assert result.match_method == "merged_llm_medium"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merged_classify_only_one_llm_call(self):
|
||||
"""合并调用模式下,中等复杂度只产生 1 次 LLM 调用"""
|
||||
merged_response = json.dumps({
|
||||
"complexity": 0.5,
|
||||
"intent": "question",
|
||||
"skill_hint": None,
|
||||
})
|
||||
gateway = _make_llm_gateway(merged_response)
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
|
||||
await router.route(
|
||||
content="如何使用这个功能?",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
# Only 1 LLM call should have been made (the merged classify)
|
||||
assert gateway.chat.call_count == 1
|
||||
|
|
|
|||
|
|
@ -995,3 +995,153 @@ class TestReWOOProgressiveFallback:
|
|||
|
||||
assert result.output == "ReAct answer with tool"
|
||||
assert result.fallback_strategy == "react"
|
||||
|
||||
|
||||
# ── Test: Configurable Fallback Strategies ──────────────────
|
||||
|
||||
|
||||
class TestReWOOConfigurableFallback:
|
||||
"""ReWOO 回退链配置化测试"""
|
||||
|
||||
async def test_default_fallback_strategies(self):
|
||||
"""默认回退链:simplified_rewoo → react → direct"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
engine = ReWOOEngine(llm_gateway=MagicMock(spec=LLMGateway))
|
||||
assert engine._fallback_strategies == ["simplified_rewoo", "react", "direct"]
|
||||
|
||||
async def test_custom_fallback_strategies(self):
|
||||
"""自定义回退链:plan_exec → react → direct"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
engine = ReWOOEngine(
|
||||
llm_gateway=MagicMock(spec=LLMGateway),
|
||||
fallback_strategies=["plan_exec", "react", "direct"],
|
||||
)
|
||||
assert engine._fallback_strategies == ["plan_exec", "react", "direct"]
|
||||
|
||||
async def test_invalid_strategy_name_skipped(self):
|
||||
"""无效策略名跳过并警告"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
engine = ReWOOEngine(
|
||||
llm_gateway=MagicMock(spec=LLMGateway),
|
||||
fallback_strategies=["invalid_strategy", "react", "direct"],
|
||||
)
|
||||
assert engine._fallback_strategies == ["react", "direct"]
|
||||
|
||||
async def test_empty_fallback_strategies_uses_defaults(self):
|
||||
"""空回退链回退到默认"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
engine = ReWOOEngine(
|
||||
llm_gateway=MagicMock(spec=LLMGateway),
|
||||
fallback_strategies=[],
|
||||
)
|
||||
assert engine._fallback_strategies == ["simplified_rewoo", "react", "direct"]
|
||||
|
||||
async def test_all_invalid_strategies_uses_defaults(self):
|
||||
"""全部无效策略名回退到默认"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
engine = ReWOOEngine(
|
||||
llm_gateway=MagicMock(spec=LLMGateway),
|
||||
fallback_strategies=["foo", "bar"],
|
||||
)
|
||||
assert engine._fallback_strategies == ["simplified_rewoo", "react", "direct"]
|
||||
|
||||
async def test_custom_fallback_plan_exec_first(self):
|
||||
"""自定义回退链:plan_exec 优先,规划失败时先走 plan_exec 再走 react"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# Planning fails, plan_exec also fails, react succeeds
|
||||
invalid_plan = make_response(content="Not a plan")
|
||||
plan_exec_fail = make_response(content="Still not a plan")
|
||||
react_response = make_response(content="React answer")
|
||||
|
||||
gateway = make_mock_gateway([invalid_plan, plan_exec_fail, react_response])
|
||||
engine = ReWOOEngine(
|
||||
llm_gateway=gateway,
|
||||
fallback_strategies=["plan_exec", "react"],
|
||||
)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
)
|
||||
|
||||
assert result.output == "React answer"
|
||||
assert result.fallback_strategy == "react"
|
||||
|
||||
async def test_custom_fallback_direct_only(self):
|
||||
"""自定义回退链:仅 direct,跳过 simplified_rewoo 和 react"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# Planning fails, direct succeeds
|
||||
invalid_plan = make_response(content="Not a plan")
|
||||
direct_response = make_response(content="Direct answer")
|
||||
|
||||
gateway = make_mock_gateway([invalid_plan, direct_response])
|
||||
engine = ReWOOEngine(
|
||||
llm_gateway=gateway,
|
||||
fallback_strategies=["direct"],
|
||||
)
|
||||
|
||||
result = await engine.execute(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
)
|
||||
|
||||
assert result.output == "Direct answer"
|
||||
assert result.fallback_strategy == "direct"
|
||||
|
||||
async def test_valid_strategies_constant(self):
|
||||
"""验证 VALID_STRATEGIES 集合"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
assert ReWOOEngine.VALID_STRATEGIES == {"simplified_rewoo", "react", "direct", "plan_exec"}
|
||||
|
||||
async def test_stream_custom_fallback_react_only(self):
|
||||
"""流式模式:自定义回退链仅 react"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# Planning fails, react succeeds
|
||||
invalid_plan = make_response(content="Not a plan")
|
||||
react_response = make_response(content="React stream answer")
|
||||
|
||||
gateway = make_mock_gateway([invalid_plan, react_response])
|
||||
engine = ReWOOEngine(
|
||||
llm_gateway=gateway,
|
||||
fallback_strategies=["react"],
|
||||
)
|
||||
|
||||
events = []
|
||||
async for event in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
event_types = [e.event_type for e in events]
|
||||
assert "final_answer" in event_types
|
||||
|
||||
async def test_all_fallback_exhausted_raises(self):
|
||||
"""所有回退策略耗尽时抛出 RuntimeError"""
|
||||
from agentkit.core.rewoo import ReWOOEngine
|
||||
|
||||
# Planning fails, all fallbacks fail
|
||||
call_count = 0
|
||||
|
||||
async def always_fail(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise RuntimeError("LLM unavailable")
|
||||
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(side_effect=always_fail)
|
||||
engine = ReWOOEngine(
|
||||
llm_gateway=gateway,
|
||||
fallback_strategies=["react", "direct"],
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="All ReWOO fallback strategies exhausted"):
|
||||
await engine.execute(
|
||||
messages=[{"role": "user", "content": "Task"}],
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue