fix(review): comprehensive P0-P2 code review fixes

This commit is contained in:
chiguyong 2026-06-12 22:18:25 +08:00
parent 2e55aae775
commit 5ef08a3b30
21 changed files with 3420 additions and 513 deletions

264
README.md
View File

@ -6,13 +6,15 @@
AgentKit 解决的核心问题:**从写 150 行 Agent 代码降为 10-20 行 YAML 配置**。
传统方式下,每新增一个 Agent 需要编写子类、处理 LLM 调用、管理工具绑定、校验输出质量。AgentKit 将这些能力标准化为 6 个可组合模块,开发者只需编写 YAML 配置即可定义一个完整的 SkillPrompt + Tool + 质量门禁),框架自动完成 ReAct 推理循环、模型路由降级、产出质量检查和标准化输出。
传统方式下,每新增一个 Agent 需要编写子类、处理 LLM 调用、管理工具绑定、校验输出质量。AgentKit 将这些能力标准化为 8 个可组合模块,开发者只需编写 YAML 配置即可定义一个完整的 SkillPrompt + Tool + 质量门禁),框架自动完成 ReAct 推理循环、模型路由降级、产出质量检查和标准化输出。
核心定位:
- **配置驱动** -- YAML 定义 Skill无需写 Agent 子类
- **生产就绪** -- 内置质量门禁、模型降级、用量统计
- **两种部署** -- Python 库直接引用,或 FastAPI 独立部署
- **三种使用** -- Python 库引用、CLI 聊天、Web GUI 界面
- **工具丰富** -- 内置 Shell、搜索、爬虫、记忆等工具支持 MCP 扩展
- **Pipeline 编排** -- 多 Agent 协同、Saga 补偿、动态流水线
## 核心特性
@ -22,7 +24,7 @@ Think -> Act -> Observe 循环。LLM 自主决定是否调用工具、调用哪
### 2. LLM Gateway
统一 LLM 调用入口。Provider 注册、模型别名解析(如 `deepseek` -> `deepseek/deepseek-chat`、Fallback 降级策略、Token 用量和成本追踪。
统一 LLM 调用入口。Provider 注册、模型别名解析(如 `default` -> `dashscope/qwen3-coder-plus`、Fallback 降级策略、Token 用量和成本追踪。支持百炼 DashScope、OpenAI、DeepSeek 等 OpenAI 兼容 API。
### 3. Skill 系统
@ -40,73 +42,102 @@ Skill = SkillConfig + 绑定 Tools。一个 Skill 代表一个可执行技能,
Schema 验证 + 字段类型归一化str -> int/float/bool+ 元数据附加version、produced_at、quality_score。所有 Skill 产出统一为 StandardOutput 格式。
### 7. 内置工具集
开箱即用的工具插件,覆盖常见 Agent 需求:
| 工具 | 说明 |
|------|------|
| `ShellTool` | 执行 Shell 命令,白名单安全机制 + 用户确认 |
| `WebSearchTool` | DuckDuckGo / Bing 网页搜索 |
| `BaiduSearchTool` | 百度搜索 |
| `WebCrawlTool` | 网页抓取与内容提取 |
| `MemoryTool` | 短期/长期记忆管理 |
| `AskHumanTool` | 向用户提问获取信息 |
| `SchemaExtractTool` | 从文本提取结构化数据 |
| `SchemaGenerateTool` | 生成 JSON Schema |
| `MCPTool` | MCP 协议工具扩展 |
工具组合:`SequentialChain`(顺序链)、`ParallelFanOut`(并行扇出)、`DynamicSelector`(动态选择)。
### 8. Pipeline 编排
多 Agent 协同编排,支持复杂工作流:
- **PipelineEngine** -- 阶段式流水线执行,支持自适应配置
- **SagaOrchestrator** -- 分布式事务补偿,失败自动回滚
- **DynamicPipeline** -- 运行时动态调整流水线结构
- **PipelineReflector** -- 执行反思与重规划
- **HandoffManager** -- Agent 间任务移交
## 架构图
```
+------------------+
| User Request |
+--------+---------+
|
v
+-------------+--------------+
| IntentRouter |
| (keyword -> LLM classify) |
+-------------+--------------+
|
matched_skill
|
v
+-------------+--------------+
| ConfigDrivenAgent |
| (SkillConfig-driven) |
+-------------+--------------+
|
+------------+------------+
| |
v v
+---------+--------+ +----------+---------+
| ReActEngine | | Traditional Mode |
| Think->Act->Observe| | llm_generate/ |
+---------+--------+ | tool_call/custom |
| +--------------------+
v
+----------+----------+
| LLM Gateway |
| resolve -> chat |
| fallback -> track |
+----------+----------+
|
+------+------+
| |
v v
+-----+----+ +-----+-----+
| Provider A| | Provider B| ...
+-----+----+ +-----+-----+
| |
v v
+-----+----+ +-----+-----+
| Tool 1 | | Tool 2 | ...
+-----------+ +-----------+
|
v
+----------+----------+
| Quality Gate |
| required_fields |
| min_word_count |
| schema validation |
| custom validator |
+----------+----------+
|
v
+----------+----------+
| OutputStandardizer |
| schema + normalize |
| + metadata |
+----------+----------+
|
v
StandardOutput
+-------------------+ +-------------------+
| Web GUI Chat | | CLI Chat |
| (WebSocket) | | (agentkit chat) |
+--------+----------+ +--------+----------+
| |
+----------+----------+
|
+----------v----------+
| Skill Routing |
| (keyword -> LLM) |
+----------+----------+
|
matched_skill
|
+-------------------v-------------------+
| ConfigDrivenAgent |
| (SkillConfig-driven) |
+-------------------+------------------+
|
+--------------+--------------+
| |
v v
+---------+--------+ +----------+---------+
| ReActEngine | | Traditional Mode |
| Think->Act->Observe| | llm_generate/ |
+---------+--------+ | tool_call/custom |
| +---------------------+
v
+----------+----------+
| LLM Gateway |
| resolve -> chat |
| fallback -> track |
+----------+----------+
|
+------+------+
| |
v v
+-----+----+ +-----+-----+
| DashScope | | OpenAI | ...
+-----+----+ +-----+-----+
|
+----------+----------+
| Tool Registry |
| shell / search / |
| crawl / memory / ... |
+----------+----------+
|
v
+----------+----------+
| Quality Gate |
| required_fields |
| min_word_count |
| schema validation |
| custom validator |
+----------+----------+
|
v
+----------+----------+
| OutputStandardizer |
| schema + normalize |
| + metadata |
+----------+----------+
|
v
StandardOutput
```
## 快速开始
@ -147,7 +178,13 @@ agentkit version
# 初始化项目(生成配置文件)
agentkit init
# 启动 Server
# 启动 Web GUI 聊天界面(推荐)
agentkit gui --port 8002
# 启动 CLI 聊天
agentkit chat
# 启动 ServerAPI 模式)
agentkit serve --host 0.0.0.0 --port 8001
# 健康检查
@ -247,9 +284,9 @@ from agentkit.llm.providers.openai import OpenAIProvider
async def main():
# 1. 初始化 LLM Gateway
gateway = LLMGateway()
gateway.register_provider("openai", OpenAIProvider(
gateway.register_provider("dashscope", OpenAIProvider(
api_key="sk-xxx",
base_url="https://api.openai.com/v1",
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
))
# 2. 定义 Skill
@ -263,7 +300,7 @@ async def main():
"instructions": "根据用户需求生成高质量内容",
"output_format": "以 JSON 格式输出",
},
llm={"model": "openai/gpt-4o", "temperature": 0.7},
llm={"model": "default", "temperature": 0.7},
execution_mode="react",
max_steps=5,
)
@ -318,9 +355,9 @@ from agentkit import LLMGateway
from agentkit.llm.providers.openai import OpenAIProvider
gateway = LLMGateway()
gateway.register_provider("openai", OpenAIProvider(
gateway.register_provider("dashscope", OpenAIProvider(
api_key="sk-xxx",
base_url="https://api.openai.com/v1",
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
))
app = create_app(llm_gateway=gateway)
@ -352,8 +389,8 @@ from datetime import datetime, timezone
async def main():
# 初始化 Gateway
gateway = LLMGateway()
gateway.register_provider("openai", OpenAIProvider(
api_key="sk-xxx", base_url="https://api.openai.com/v1",
gateway.register_provider("dashscope", OpenAIProvider(
api_key="sk-xxx", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
))
# 定义多个 Skill
@ -366,7 +403,7 @@ async def main():
"instructions": "生成 SEO 优化内容",
"output_format": "JSON: {content, word_count}",
},
llm={"model": "openai/gpt-4o"},
llm={"model": "default"},
intent={
"keywords": ["生成", "内容", "写作"],
"description": "内容生成与写作",
@ -390,7 +427,7 @@ async def main():
"instructions": "优化内容以提升 AI 搜索可见性",
"output_format": "JSON: {optimized_content, seo_score, changes}",
},
llm={"model": "openai/gpt-4o"},
llm={"model": "default"},
intent={
"keywords": ["优化", "GEO", "SEO"],
"description": "内容 GEO/SEO 优化",
@ -472,7 +509,7 @@ curl -X POST http://localhost:8000/api/v1/skills \
"instructions": "生成高质量内容",
"output_format": "JSON: {content, word_count}"
},
"llm": {"model": "openai/gpt-4o"},
"llm": {"model": "default"},
"intent": {
"keywords": ["生成", "内容"],
"description": "内容生成"
@ -546,7 +583,7 @@ async def main():
"instructions": "生成高质量内容",
"output_format": "JSON: {content, word_count}",
},
"llm": {"model": "openai/gpt-4o"},
"llm": {"model": "default"},
"intent": {"keywords": ["生成", "内容"], "description": "内容生成"},
"quality_gate": {"required_fields": ["content"], "max_retries": 2},
"execution_mode": "react",
@ -621,7 +658,7 @@ prompt:
output_format: "JSON: generate_topics 返回 {topics: [{title, reason, keywords}]}generate_article 返回 {content, word_count}"
llm:
model: "deepseek"
model: "default"
temperature: 0.7
max_tokens: 4000
@ -664,34 +701,30 @@ skill = Skill(config=config)
```yaml
providers:
dashscope:
api_key: "${DASHSCOPE_API_KEY}"
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
models:
qwen3-coder-plus:
max_tokens: 64000
cost_per_1k_input: 0.00014
cost_per_1k_output: 0.00028
openai:
api_key: "sk-xxx"
api_key: "${OPENAI_API_KEY}"
base_url: "https://api.openai.com/v1"
models:
gpt-4o:
cost_per_1k_input: 0.005
cost_per_1k_output: 0.015
gpt-4o-mini:
cost_per_1k_input: 0.00015
cost_per_1k_output: 0.0006
deepseek:
api_key: "sk-xxx"
base_url: "https://api.deepseek.com/v1"
models:
deepseek-chat:
cost_per_1k_input: 0.001
cost_per_1k_output: 0.002
model_aliases:
default: "deepseek/deepseek-chat"
fast: "openai/gpt-4o-mini"
default: "dashscope/qwen3-coder-plus"
fast: "dashscope/qwen3-coder-plus"
powerful: "openai/gpt-4o"
fallbacks:
openai/gpt-4o:
- "deepseek/deepseek-chat"
deepseek/deepseek-chat:
- "openai/gpt-4o-mini"
dashscope/qwen3-coder-plus:
- "openai/gpt-4o"
```
加载 LLM 配置:
@ -802,7 +835,7 @@ ReActEngine 实现 Think -> Act -> Observe 循环:
统一 LLM 调用入口,核心能力:
- **Provider 注册**: `gateway.register_provider("openai", provider)`
- **模型别名**: `"default"` -> `"deepseek/deepseek-chat"`
- **模型别名**: `"default"` -> `"dashscope/qwen3-coder-plus"`
- **Fallback 降级**: 主模型失败时自动切换到备选模型
- **用量追踪**: 按 agent_name、model 统计 Token 用量和成本
- **模型解析**: `"provider/model"` 格式自动路由到对应 Provider
@ -878,10 +911,12 @@ v2 增强:接受 SkillConfig 时自动创建 Skill 并启用 ReAct 模式Qu
### server -- FastAPI Server
独立部署模式,提供 RESTful API
独立部署模式,提供 RESTful API 和 Web GUI
| 路径 | 方法 | 说明 |
|------|------|------|
| `/` | GET | Web GUI 聊天界面 |
| `/ws/chat` | WebSocket | GUI 实时聊天通道 |
| `/api/v1/agents` | POST | 创建 Agent指定 skill_name 或 config |
| `/api/v1/agents` | GET | 列出所有 Agent |
| `/api/v1/agents/{name}` | GET | 获取 Agent 详情 |
@ -892,6 +927,27 @@ v2 增强:接受 SkillConfig 时自动创建 Skill 并启用 ReAct 模式Qu
| `/api/v1/llm/usage` | GET | 查询 LLM 用量 |
| `/api/v1/health` | GET | 健康检查 |
### Web GUI 聊天界面
通过 `agentkit gui` 启动,特性:
- **实时对话** -- WebSocket 流式传输,逐 token 显示
- **Markdown 渲染** -- 自动检测并渲染标题、列表、代码块、表格等
- **工具确认卡片** -- 危险命令(如 `rm`)执行前弹出确认卡片,用户批准后才执行
- **Loading 动画** -- 等待 AI 响应时显示思考动画
- **Skill 路由** -- 输入 `@skill_name:` 前缀可指定使用特定 Skill
- **会话管理** -- 多会话并行,历史记录持久化
### orchestrator -- Pipeline 编排
多 Agent 协同编排模块:
- **PipelineEngine** -- 按 Stage 定义顺序执行,支持自适应配置和反思重规划
- **SagaOrchestrator** -- 分布式事务补偿,失败步骤自动执行补偿操作
- **DynamicPipeline** -- 运行时根据条件动态调整流水线结构
- **HandoffManager** -- Agent 间任务移交,支持上下文传递
- **PipelineStateMemory/Redis/PG** -- 流水线状态持久化支持内存、Redis、PostgreSQL 后端
## 配置参考
### SkillConfig
@ -941,8 +997,8 @@ v2 增强:接受 SkillConfig 时自动创建 Skill 并启用 ReAct 模式Qu
| 字段 | 类型 | 默认值 | 说明 |
|------|------|--------|------|
| `providers` | dict[str, ProviderConfig] | `{}` | Provider 配置key 为 provider 名称 |
| `model_aliases` | dict[str, str] | `{}` | 模型别名映射,如 `default: "deepseek/deepseek-chat"` |
| `fallbacks` | dict[str, list[str]] | `{}` | 降级策略,如 `openai/gpt-4o: ["deepseek/deepseek-chat"]` |
| `model_aliases` | dict[str, str] | `{}` | 模型别名映射,如 `default: "dashscope/qwen3-coder-plus"` |
| `fallbacks` | dict[str, list[str]] | `{}` | 降级策略,如 `dashscope/qwen3-coder-plus: ["openai/gpt-4o"]` |
#### ProviderConfig
@ -977,9 +1033,9 @@ from agentkit import LLMGateway
from agentkit.llm.providers.openai import OpenAIProvider
gateway = LLMGateway()
gateway.register_provider("deepseek", OpenAIProvider(
gateway.register_provider("dashscope", OpenAIProvider(
api_key="sk-xxx",
base_url="https://api.deepseek.com/v1",
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
))
app = create_app(llm_gateway=gateway)
@ -1135,6 +1191,8 @@ tools:
- search
```
**ShellTool 安全机制**ShellTool 内置白名单(`ls`、`cat`、`curl` 等安全命令直接执行),非白名单命令会触发用户确认。在 GUI 中以确认卡片形式展示,用户点击"确认执行"后才运行。
### 代码风格
项目使用 Ruff 进行代码检查和格式化:

View File

@ -40,3 +40,4 @@ logging:
format: text
router:
classifier: heuristic
auction_enabled: false

View File

@ -2,18 +2,9 @@
# 环境变量替换:${VAR_NAME} 在启动时由 LLMConfig.from_yaml() 处理
providers:
deepseek:
api_key: "${DEEPSEEK_API_KEY}"
base_url: "https://api.deepseek.com/v1"
models:
deepseek-chat:
max_tokens: 64000
cost_per_1k_input: 0.00014
cost_per_1k_output: 0.00028
openai:
api_key: "${OPENAI_API_KEY}"
base_url: "${OPENAI_BASE_URL:-https://coding.dashscope.aliyuncs.com/v1}"
dashscope:
api_key: "${DASHSCOPE_API_KEY}"
base_url: "${DASHSCOPE_BASE_URL:-https://dashscope.aliyuncs.com/compatible-mode/v1}"
models:
qwen3-coder-plus:
max_tokens: 64000
@ -21,13 +12,9 @@ providers:
cost_per_1k_output: 0.00028
model_aliases:
default: "openai/qwen3-coder-plus"
fast: "openai/qwen3-coder-plus"
powerful: "openai/qwen3-coder-plus"
fallbacks:
openai/qwen3-coder-plus:
- "deepseek/deepseek-chat"
default: "dashscope/qwen3-coder-plus"
fast: "dashscope/qwen3-coder-plus"
powerful: "dashscope/qwen3-coder-plus"
# 上下文压缩配置 — 长会话自动压缩历史消息,保持 Token 在预算内
# GEO Pipeline 启用后,工具输出(搜索结果、网页抓取等)会自动压缩

View File

@ -6,6 +6,10 @@ task_mode: llm_generate
execution_mode: rewoo
max_steps: 8
max_concurrency: 3
fallback_strategies:
- simplified_rewoo
- react
- direct
intent:
keywords: ["采集", "批量", "并行", "fetch", "collect", "数据获取", "多源"]

View File

@ -12,6 +12,7 @@ import re
from dataclasses import dataclass, field
from typing import Any
from agentkit.marketplace.auction import AuctionHouse, Bid
from agentkit.telemetry.tracer import get_tracer
logger = logging.getLogger(__name__)
@ -256,6 +257,8 @@ _CHAT_MODE_RE = re.compile(
re.IGNORECASE,
)
_SENTENCE_SPLIT_RE = re.compile(r'[,。!?;\n,.!?;]')
def _tokenize_content(content: str) -> list[str]:
"""Tokenize content for capability matching. Supports Chinese and English."""
@ -394,7 +397,7 @@ class HeuristicClassifier:
score += 0.05
# 3. 多句加成(逗号/句号/换行分隔)
sentence_count = len(re.split(r'[,。!?;\n,.!?;]', content))
sentence_count = len(_SENTENCE_SPLIT_RE.split(content))
if sentence_count >= 4:
score += 0.1
elif sentence_count >= 2:
@ -424,12 +427,15 @@ class CostAwareRouter:
org_context: Any = None,
auction_enabled: bool = False,
classifier: str = "heuristic",
merged_llm_classify: bool = True,
):
self._llm_gateway = llm_gateway
self._model = model
self._org_context = org_context
self._auction_enabled = auction_enabled
self._classifier = classifier
self._merged_llm_classify = merged_llm_classify
self._auction_house = AuctionHouse() if auction_enabled else None
if classifier not in ("heuristic", "llm"):
raise ValueError(f"Invalid classifier: {classifier!r}, must be 'heuristic' or 'llm'")
self._heuristic = HeuristicClassifier()
@ -489,6 +495,175 @@ class CostAwareRouter:
logger.warning(f"CostAwareRouter quick_classify failed: {e}")
return 0.5
# -- Layer 1.5: Merged LLM classify (complexity + intent in one call) ---
async def _classify_merged(
self,
content: str,
skill_registry: Any,
intent_router: Any,
default_tools: list,
default_system_prompt: str | None,
default_model: str,
default_agent_name: str,
agent_tool_registry: Any = None,
session_id: str = "",
complexity: float = 0.5,
) -> SkillRoutingResult:
"""合并 LLM 调用:单次 LLM 同时输出 complexity + intent + skill_hint。
HeuristicClassifier 返回不确定区间 (0.3-0.7) 时使用
替代分别调用 quick_classify() IntentRouter._classify_with_llm()
节省 1 LLM 调用 (~1-3s)
"""
if self._llm_gateway is None or not self._merged_llm_classify:
# Fallback: 使用独立的 IntentRouter 路由
return await resolve_skill_routing(
content=content,
skill_registry=skill_registry,
intent_router=intent_router,
default_tools=default_tools,
default_system_prompt=default_system_prompt,
default_model=default_model,
default_agent_name=default_agent_name,
agent_tool_registry=agent_tool_registry,
session_id=session_id,
)
# Build skill list for the prompt
skill_hints = []
if skill_registry:
try:
for s in skill_registry.list_skills():
if s.config.intent and s.config.intent.keywords:
skill_hints.append(s.name)
except Exception:
pass
skill_list_str = ", ".join(skill_hints) if skill_hints else "none"
prompt = (
'You are a routing classifier. Analyze the user request and output:\n'
'1. complexity (0.0-1.0): how complex is this request\n'
'2. intent: the primary intent category\n'
'3. skill_hint: the best matching skill name, or null if none match\n\n'
f'Available skills: [{skill_list_str}]\n\n'
'---BEGIN USER REQUEST---\n'
f'{content}\n'
'---END USER REQUEST---\n\n'
'Respond ONLY with a JSON object: {"complexity": <float>, "intent": <string>, "skill_hint": <string|null>}'
)
try:
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
model=self._model,
)
data = json.loads(response.content.strip())
merged_complexity = float(data.get("complexity", 0.5))
merged_complexity = max(0.0, min(1.0, merged_complexity))
skill_hint = data.get("skill_hint")
# If skill_hint provided and valid, route directly to that skill
if skill_hint and skill_registry:
try:
matched_skill = skill_registry.get(skill_hint)
result = SkillRoutingResult(
clean_content=content,
skill_name=skill_hint,
skill_config=matched_skill.config,
skill_tools=matched_skill.tools or [],
matched=True,
match_method="merged_llm",
match_confidence=0.7,
complexity=merged_complexity,
)
# Merge tools
agent_tools = agent_tool_registry.list_tools() if agent_tool_registry else default_tools
seen_names = set()
merged_tools = []
for tool in result.skill_tools + agent_tools:
if tool.name not in seen_names:
seen_names.add(tool.name)
merged_tools.append(tool)
result.tools = merged_tools
result.model = result.skill_config.llm.get("model", default_model) if result.skill_config.llm else default_model
result.agent_name = skill_hint
result.system_prompt = build_skill_system_prompt(result.skill_config) or default_system_prompt
logger.info(
f"Session {session_id}: merged LLM classify routed to skill '{skill_hint}' "
f"(complexity={merged_complexity:.2f})"
)
return result
except Exception as e:
logger.warning(f"Session {session_id}: merged LLM skill_hint '{skill_hint}' not found: {e}")
# No valid skill_hint — use complexity to decide routing
if merged_complexity < 0.3:
return SkillRoutingResult(
clean_content=content,
system_prompt=default_system_prompt,
tools=default_tools,
model=default_model,
agent_name=default_agent_name,
matched=False,
match_method="merged_llm_low",
match_confidence=1.0 - merged_complexity,
complexity=merged_complexity,
)
elif merged_complexity > 0.7:
# High complexity — delegate to Layer 2
return SkillRoutingResult(
clean_content=content,
system_prompt=default_system_prompt,
tools=default_tools,
model=default_model,
agent_name=default_agent_name,
matched=False,
match_method="merged_llm_high",
match_confidence=merged_complexity,
complexity=merged_complexity,
)
else:
# Medium complexity, no skill match — default agent
return SkillRoutingResult(
clean_content=content,
system_prompt=default_system_prompt,
tools=default_tools,
model=default_model,
agent_name=default_agent_name,
matched=False,
match_method="merged_llm_medium",
match_confidence=0.5,
complexity=merged_complexity,
)
except (json.JSONDecodeError, TypeError, ValueError) as e:
logger.warning(f"CostAwareRouter _classify_merged parse failed: {e}, falling back to default")
return SkillRoutingResult(
clean_content=content,
system_prompt=default_system_prompt,
tools=default_tools,
model=default_model,
agent_name=default_agent_name,
matched=False,
match_method="merged_llm_fallback",
match_confidence=0.5,
complexity=0.5,
)
except Exception as e:
logger.warning(f"CostAwareRouter _classify_merged failed: {e}, falling back to default")
return SkillRoutingResult(
clean_content=content,
system_prompt=default_system_prompt,
tools=default_tools,
model=default_model,
agent_name=default_agent_name,
matched=False,
match_method="merged_llm_fallback",
match_confidence=0.5,
complexity=0.5,
)
# -- Layer 2: Capability matching / Auction (optional) -----------------
async def _route_layer2(
@ -505,12 +680,89 @@ class CostAwareRouter:
complexity: float = 0.0,
trace: list[dict] | None = None,
) -> SkillRoutingResult:
"""Layer 2: 高复杂度任务通过 org_context.find_best_agent 路由。"""
"""Layer 2: 高复杂度任务通过拍卖或 org_context.find_best_agent 路由。"""
# Extract capability-like keywords from content for matching
content_words = _tokenize_content(content)
# --- Vickrey auction path (when enabled) ---
if self._auction_enabled and self._auction_house is not None and self._org_context is not None:
try:
# Gather candidate agents from org_context
all_agents = self._org_context.list_agents() if hasattr(self._org_context, "list_agents") else []
# Filter agents that have at least one relevant capability
candidate_agents = []
for agent_profile in all_agents:
if not agent_profile.availability:
continue
# Check if agent has any of the content_words as capabilities
agent_caps_lower = {c.lower() for c in agent_profile.capabilities}
if any(w.lower() in agent_caps_lower for w in content_words):
candidate_agents.append(agent_profile)
# Also include agents that match via find_best_agent (they have ALL required caps)
best = self._org_context.find_best_agent(required_capabilities=content_words)
if best is not None:
best_name = best if isinstance(best, str) else getattr(best, "name", str(best))
existing_names = {a.name for a in candidate_agents}
if best_name not in existing_names:
profile = self._org_context.get_agent_profile(best_name) if hasattr(self._org_context, "get_agent_profile") else best
if hasattr(profile, "name"):
candidate_agents.append(profile)
if len(candidate_agents) >= 1:
# Build Bid objects for each candidate
bids = []
for agent_profile in candidate_agents:
name = agent_profile.name if hasattr(agent_profile, "name") else str(agent_profile)
caps = agent_profile.capabilities if hasattr(agent_profile, "capabilities") else []
arch = agent_profile.agent_type if hasattr(agent_profile, "agent_type") else "react"
# Use current_load as a proxy for estimated_cost (higher load → higher cost)
estimated_cost = float(agent_profile.current_load + 1) if hasattr(agent_profile, "current_load") else 1.0
bids.append(Bid(
agent_name=name,
architecture=arch,
estimated_steps=1,
estimated_cost=estimated_cost,
confidence=0.8,
payment_offer=estimated_cost,
capabilities=caps,
))
auction_result = await self._auction_house.run_vickrey_auction(
task_description=content,
bidders=bids,
required_capabilities=content_words,
)
if auction_result.winner is not None:
winner_name = auction_result.winner.agent_name
result = SkillRoutingResult(
clean_content=content,
matched=True,
match_method="vickrey_auction",
match_confidence=0.8,
agent_name=winner_name,
model=default_model,
system_prompt=default_system_prompt,
tools=default_tools,
complexity=complexity,
)
if trace is not None:
trace.append({
"layer": 2,
"method": "vickrey_auction",
"agent_name": winner_name,
"complexity": complexity,
"selection_reason": auction_result.selection_reason,
})
return result
# No winner from auction → fall through to capability matching
except Exception as e:
logger.warning(f"CostAwareRouter Layer 2 Vickrey auction failed: {e}")
# --- Capability matching path (default) ---
if self._org_context is not None and hasattr(self._org_context, "find_best_agent"):
try:
# Extract capability-like keywords from content for matching
# find_best_agent expects list[str] of required capabilities
content_words = _tokenize_content(content)
best_agent = self._org_context.find_best_agent(required_capabilities=content_words)
if best_agent is not None:
agent_name = best_agent if isinstance(best_agent, str) else getattr(best_agent, "name", str(best_agent))
@ -689,29 +941,69 @@ class CostAwareRouter:
span.set_attribute("route.target", "default")
return result
# Medium complexity → IntentRouter via resolve_skill_routing
# Medium complexity → merged LLM classify or IntentRouter
if complexity <= 0.7:
result = await resolve_skill_routing(
content=content,
skill_registry=skill_registry,
intent_router=intent_router,
default_tools=default_tools,
default_system_prompt=default_system_prompt,
default_model=default_model,
default_agent_name=default_agent_name,
agent_tool_registry=agent_tool_registry,
session_id=session_id,
)
result.complexity = complexity
if self._merged_llm_classify and self._llm_gateway is not None:
# Use merged LLM call: complexity + intent in one call
result = await self._classify_merged(
content=content,
skill_registry=skill_registry,
intent_router=intent_router,
default_tools=default_tools,
default_system_prompt=default_system_prompt,
default_model=default_model,
default_agent_name=default_agent_name,
agent_tool_registry=agent_tool_registry,
session_id=session_id,
complexity=complexity,
)
# If merged classify returned high complexity, delegate to Layer 2
if result.complexity > 0.7 and result.match_method and result.match_method.startswith("merged_llm_high"):
trace.append({
"layer": 1,
"method": "merged_llm_high",
"complexity": result.complexity,
"delegated_to_layer2": True,
})
layer2_result = await self._route_layer2(
content=content,
skill_registry=skill_registry,
intent_router=intent_router,
default_tools=default_tools,
default_system_prompt=default_system_prompt,
default_model=default_model,
default_agent_name=default_agent_name,
agent_tool_registry=agent_tool_registry,
session_id=session_id,
complexity=result.complexity,
trace=trace,
)
layer2_result.execution_trace = trace if transparency != "SILENT" else []
layer2_result.transparency_level = transparency
return layer2_result
else:
# Fallback: use separate IntentRouter
result = await resolve_skill_routing(
content=content,
skill_registry=skill_registry,
intent_router=intent_router,
default_tools=default_tools,
default_system_prompt=default_system_prompt,
default_model=default_model,
default_agent_name=default_agent_name,
agent_tool_registry=agent_tool_registry,
session_id=session_id,
)
result.complexity = result.complexity or complexity
trace.append({
"layer": 1,
"method": "intent_router",
"complexity": complexity,
"method": result.match_method or "merged_llm",
"complexity": result.complexity,
"matched": result.matched,
})
result.execution_trace = trace if transparency != "SILENT" else []
result.transparency_level = transparency
span.set_attribute("route.layer", result.match_method or "intent_router")
span.set_attribute("route.layer", result.match_method or "merged_llm")
span.set_attribute("route.target", result.skill_name or "default")
return result

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any
from agentkit.core.config_driven import ConfigDrivenAgent
from agentkit.core.protocol import AgentStatus
from agentkit.core.react import ReActEngine
from agentkit.llm.gateway import LLMGateway
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
@ -53,6 +54,16 @@ class AgentPool:
compressor=self._compressor,
)
await agent.start()
# Bind a reusable ReActEngine instance to the agent
max_steps = getattr(config, "max_steps", 10) or 10
parallel_tools = getattr(config, "parallel_tools", False) or False
agent._react_engine = ReActEngine(
llm_gateway=self._llm_gateway,
max_steps=max_steps,
parallel_tools=parallel_tools,
)
self._agents[config.name] = agent
logger.info(f"Agent '{config.name}' created and started in pool")

View File

@ -74,14 +74,27 @@ class ReActEngine:
使 Agent 能够自主推理并选择工具完成任务
"""
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = False):
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool | str = False):
if max_steps < 1:
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
if isinstance(parallel_tools, str) and parallel_tools not in ("auto",):
raise ValueError(f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}")
self._llm_gateway = llm_gateway
self._max_steps = max_steps
self._default_timeout = default_timeout
self._parallel_tools = parallel_tools
def reset(self) -> None:
"""Reset internal state for reuse across conversations.
Call this before each execute/execute_stream to ensure clean state.
The engine itself (LLM gateway, config) is preserved.
"""
# ReActEngine is stateless between calls — conversation history,
# step counts, and trajectory are local to each execute call.
# This method exists for API clarity and future stateful extensions.
pass
async def execute(
self,
messages: list[dict[str, str]],
@ -97,6 +110,7 @@ class ReActEngine:
retrieval_config: dict[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
timeout_seconds: float | None = None,
confirmation_handler: Any | None = None,
) -> ReActResult:
"""执行 ReAct 循环
@ -127,6 +141,7 @@ class ReActEngine:
compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
confirmation_handler=confirmation_handler,
),
timeout=effective_timeout,
)
@ -144,6 +159,7 @@ class ReActEngine:
compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
confirmation_handler=confirmation_handler,
)
except asyncio.TimeoutError:
raise TaskTimeoutError(
@ -169,6 +185,7 @@ class ReActEngine:
compressor: "CompressionStrategy | None" = None,
retrieval_config: dict[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
confirmation_handler: Any | None = None,
) -> ReActResult:
tools = tools or []
tool_schemas = self._build_tool_schemas(tools) if tools else None
@ -224,7 +241,7 @@ class ReActEngine:
else:
system_prompt = f"## 参考信息\n{memory_context}"
except Exception as e:
logger.warning(f"Memory retrieval failed, continuing without context: {e}")
logger.warning(f"Memory retrieval failed, continuing without context: {e}", exc_info=True)
# 构建初始消息
conversation: list[dict[str, Any]] = []
@ -295,8 +312,70 @@ class ReActEngine:
conversation.append(assistant_msg)
# 执行工具调用
if self._parallel_tools and len(response.tool_calls) > 1:
# 并行执行多个工具调用
if self._parallel_tools == "auto" and len(response.tool_calls) > 1:
# Auto mode: mixed parallel/serial based on _parallelizable flag
parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls))
serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set]
parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set]
# Result slots indexed by original position
all_results: list[Any] = [None] * len(response.tool_calls)
# Execute serial tools first (in order)
for i, tc in serial_calls:
tool_start = time.monotonic()
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
all_results[i] = (tc, tool_result, tool_duration_ms)
# Execute parallelizable tools in parallel
if len(parallel_calls) > 1:
para_results = await asyncio.gather(
*[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls],
return_exceptions=True,
)
for j, (i, tc) in enumerate(parallel_calls):
tool_result = para_results[j]
if isinstance(tool_result, Exception):
tool_result = {"error": str(tool_result)}
all_results[i] = (tc, tool_result, 0)
elif len(parallel_calls) == 1:
i, tc = parallel_calls[0]
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
all_results[i] = (tc, tool_result, 0)
# Process all results in original order
for i, tc in enumerate(response.tool_calls):
tc_obj, tool_result, tool_duration_ms = all_results[i]
react_step = ReActStep(
step=step,
action="tool_call",
tool_name=tc.name,
arguments=tc.arguments,
result=tool_result,
tokens=step_tokens,
)
trajectory.append(react_step)
if trace_recorder is not None:
tool_error = None
if isinstance(tool_result, dict) and "error" in tool_result:
tool_error = tool_result["error"]
trace_recorder.record_step(
step=step,
action="tool_call",
tool_name=tc.name,
input_data=tc.arguments,
output_data=tool_result,
duration_ms=tool_duration_ms,
tokens_used=0,
error=tool_error,
)
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
conversation.append(tool_msg)
elif self._should_execute_parallel(response.tool_calls):
# 并行执行多个工具调用 (parallel_tools=True)
tool_results = await asyncio.gather(
*[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls],
return_exceptions=True,
@ -338,6 +417,40 @@ class ReActEngine:
for tc in response.tool_calls:
tool_start = time.monotonic()
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
# Handle confirmation flow
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
confirmation_id = tool_result["confirmation_id"]
command = tool_result.get("command", "")
reason = tool_result.get("reason", "")
approved = False
if confirmation_handler is not None:
try:
approved = await confirmation_handler(confirmation_id, command, reason)
except Exception as e:
logger.warning(f"Confirmation handler error: {e}")
if approved:
tool = self._find_tool(tc.name, tools)
if tool and hasattr(tool, '_is_dangerous'):
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
clean_args["_skip_dangerous_check"] = True
try:
tool_result = await tool.safe_execute(**clean_args)
except Exception as e:
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
else:
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
else:
tool_result = {
"output": "",
"exit_code": 126,
"is_error": True,
"error_type": "permission_denied",
"message": f"用户拒绝执行命令: {command[:100]}",
}
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
react_step = ReActStep(
@ -531,6 +644,21 @@ class ReActEngine:
else:
logger.info("ReActEngine executing with NO tools")
# Telemetry: record agent request
agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"})
# Start telemetry span for the entire agent execution
_span_cm = None
_span = None
_exec_start = time.monotonic()
if _OTEL_AVAILABLE:
_span_cm = start_span(
"agent.execute_stream",
attributes={"agent.name": agent_name, "agent.type": task_type or "react"},
)
_span = _span_cm.__enter__()
# 启动轨迹记录
if trace_recorder is not None:
trace_recorder.start_trace(
@ -575,11 +703,26 @@ class ReActEngine:
step = 0
output = ""
trace_outcome = "success"
_stream_start = time.monotonic()
effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout
try:
while step < self._max_steps:
step += 1
# 协作式取消检查
if cancellation_token is not None:
cancellation_token.check()
# 超时检查
if effective_timeout > 0:
elapsed = time.monotonic() - _stream_start
if elapsed > effective_timeout:
trace_outcome = "timeout"
raise asyncio.TimeoutError(
f"execute_stream exceeded {effective_timeout}s timeout after {elapsed:.1f}s"
)
# Yield thinking event
yield ReActEvent(
event_type="thinking",
@ -591,7 +734,7 @@ class ReActEngine:
llm_start = time.monotonic()
# Use streaming for token-by-token output
stream_content = ""
stream_content_chunks: list[str] = []
stream_usage = None
stream_tool_calls: list[Any] = []
stream_model = model
@ -604,7 +747,7 @@ class ReActEngine:
tools=tool_schemas,
):
if chunk.content:
stream_content += chunk.content
stream_content_chunks.append(chunk.content)
yield ReActEvent(
event_type="token",
step=step,
@ -618,6 +761,7 @@ class ReActEngine:
stream_model = chunk.model
# Build response-like object from stream
stream_content = "".join(stream_content_chunks)
response = self._build_response_from_stream(
content=stream_content,
tool_calls=stream_tool_calls,
@ -657,115 +801,168 @@ class ReActEngine:
}
conversation.append(assistant_msg)
for tc in response.tool_calls:
# Yield tool_call event
yield ReActEvent(
event_type="tool_call",
step=step,
data={"tool_name": tc.name, "arguments": tc.arguments},
)
# Execute tool calls with parallel support
if self._parallel_tools and len(response.tool_calls) > 1 and self._should_execute_parallel(response.tool_calls):
# Parallel execution path
parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls)) if self._parallel_tools == "auto" else set(range(len(response.tool_calls)))
serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set]
parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set]
tool_start = time.monotonic()
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
all_results: list[Any] = [None] * len(response.tool_calls)
# 检测工具返回的确认请求
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
confirmation_id = tool_result["confirmation_id"]
command = tool_result.get("command", "")
reason = tool_result.get("reason", "")
# Execute serial tools first (handles confirmation flow)
for i, tc in serial_calls:
yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments})
tool_start = time.monotonic()
tool_result, confirm_events = await self._execute_tool_with_confirmation(tc, tools, step, confirmation_handler)
for ev in confirm_events:
yield ev
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
all_results[i] = (tc, tool_result, tool_duration_ms)
# Yield 确认请求事件
# Execute parallelizable tools concurrently
if len(parallel_calls) > 1:
para_results = await asyncio.gather(
*[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls],
return_exceptions=True,
)
for j, (i, tc) in enumerate(parallel_calls):
tool_result = para_results[j]
if isinstance(tool_result, Exception):
tool_result = {"error": str(tool_result)}
all_results[i] = (tc, tool_result, 0)
elif len(parallel_calls) == 1:
i, tc = parallel_calls[0]
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
all_results[i] = (tc, tool_result, 0)
# Process all results in original order
for i, tc in enumerate(response.tool_calls):
tc_obj, tool_result, tool_duration_ms = all_results[i]
yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments})
react_step = ReActStep(step=step, action="tool_call", tool_name=tc.name, arguments=tc.arguments, result=tool_result, tokens=step_tokens)
trajectory.append(react_step)
if trace_recorder is not None:
tool_error = None
if isinstance(tool_result, dict) and "error" in tool_result:
tool_error = tool_result["error"]
trace_recorder.record_step(step=step, action="tool_call", tool_name=tc.name, input_data=tc.arguments, output_data=tool_result, duration_ms=tool_duration_ms, tokens_used=0, error=tool_error)
yield ReActEvent(event_type="tool_result", step=step, data={"tool_name": tc.name, "result": tool_result})
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
conversation.append(tool_msg)
else:
# Serial execution path (with confirmation flow)
for tc in response.tool_calls:
# Yield tool_call event
yield ReActEvent(
event_type="confirmation_request",
event_type="tool_call",
step=step,
data={
"confirmation_id": confirmation_id,
"tool_name": tc.name,
"command": command,
"reason": reason,
},
data={"tool_name": tc.name, "arguments": tc.arguments},
)
# 等待用户确认
approved = False
if confirmation_handler is not None:
try:
approved = await confirmation_handler(confirmation_id, command, reason)
except Exception as e:
logger.warning(f"Confirmation handler error: {e}")
if approved:
# 用户确认执行:临时绕过安全检查重新执行
tool = self._find_tool(tc.name, tools)
if tool and hasattr(tool, '_is_dangerous'):
# 保存原始 _is_dangerous 并临时禁用
original_is_dangerous = tool._is_dangerous
tool._is_dangerous = lambda cmd: False
try:
tool_result = await tool.safe_execute(**tc.arguments)
finally:
tool._is_dangerous = original_is_dangerous
else:
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
yield ReActEvent(
event_type="confirmation_result",
step=step,
data={"confirmation_id": confirmation_id, "approved": True},
)
else:
# 用户拒绝执行
tool_result = {
"output": "",
"exit_code": 126,
"is_error": True,
"error_type": "permission_denied",
"message": f"用户拒绝执行命令: {command[:100]}",
}
yield ReActEvent(
event_type="confirmation_result",
step=step,
data={"confirmation_id": confirmation_id, "approved": False},
)
tool_start = time.monotonic()
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
react_step = ReActStep(
step=step,
action="tool_call",
tool_name=tc.name,
arguments=tc.arguments,
result=tool_result,
tokens=step_tokens,
)
trajectory.append(react_step)
# 检测工具返回的确认请求
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
confirmation_id = tool_result["confirmation_id"]
command = tool_result.get("command", "")
reason = tool_result.get("reason", "")
# 记录工具调用步骤
if trace_recorder is not None:
tool_error = None
if isinstance(tool_result, dict) and "error" in tool_result:
tool_error = tool_result["error"]
trace_recorder.record_step(
# Yield 确认请求事件
yield ReActEvent(
event_type="confirmation_request",
step=step,
data={
"confirmation_id": confirmation_id,
"tool_name": tc.name,
"command": command,
"reason": reason,
},
)
# 等待用户确认
approved = False
if confirmation_handler is not None:
try:
approved = await confirmation_handler(confirmation_id, command, reason)
except Exception as e:
logger.warning(f"Confirmation handler error: {e}")
if approved:
# 用户确认执行:使用 per-call override 绕过安全检查
tool = self._find_tool(tc.name, tools)
if tool and hasattr(tool, '_is_dangerous'):
# Strip internal metadata and pass skip_dangerous_check flag
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
clean_args["_skip_dangerous_check"] = True
try:
tool_result = await tool.safe_execute(**clean_args)
finally:
pass # No shared state mutation needed
else:
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
yield ReActEvent(
event_type="confirmation_result",
step=step,
data={"confirmation_id": confirmation_id, "approved": True},
)
else:
# 用户拒绝执行
tool_result = {
"output": "",
"exit_code": 126,
"is_error": True,
"error_type": "permission_denied",
"message": f"用户拒绝执行命令: {command[:100]}",
}
yield ReActEvent(
event_type="confirmation_result",
step=step,
data={"confirmation_id": confirmation_id, "approved": False},
)
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
react_step = ReActStep(
step=step,
action="tool_call",
tool_name=tc.name,
input_data=tc.arguments,
output_data=tool_result,
duration_ms=tool_duration_ms,
tokens_used=0,
error=tool_error,
arguments=tc.arguments,
result=tool_result,
tokens=step_tokens,
)
trajectory.append(react_step)
if trace_recorder is not None:
tool_error = None
if isinstance(tool_result, dict) and "error" in tool_result:
tool_error = tool_result["error"]
trace_recorder.record_step(
step=step,
action="tool_call",
tool_name=tc.name,
input_data=tc.arguments,
output_data=tool_result,
duration_ms=tool_duration_ms,
tokens_used=0,
error=tool_error,
)
# Yield tool_result event
yield ReActEvent(
event_type="tool_result",
step=step,
data={"tool_name": tc.name, "result": tool_result},
)
# Yield tool_result event
yield ReActEvent(
event_type="tool_result",
step=step,
data={"tool_name": tc.name, "result": tool_result},
)
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
conversation.append(tool_msg)
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
conversation.append(tool_msg)
# Incremental compression: compress conversation if it's getting long
if self._should_compress(conversation, compressor):
@ -893,6 +1090,17 @@ class ReActEngine:
if trace_recorder is not None:
trace_recorder.end_trace(outcome=trace_outcome)
# Telemetry: end span and record duration — always runs
_duration_ms = int((time.monotonic() - _exec_start) * 1000)
if _span is not None:
_span.set_attribute("agent.total_steps", len(trajectory))
_span.set_attribute("agent.total_tokens", total_tokens)
_span.set_attribute("agent.outcome", trace_outcome)
_span.set_attribute("agent.duration_ms", _duration_ms)
if _span_cm is not None:
_span_cm.__exit__(None, None, None)
agent_duration_histogram().record(_duration_ms, {"agent.name": agent_name})
# Memory storage: 执行后写入轨迹摘要到 EpisodicMemory
if memory_retriever and hasattr(memory_retriever, "store_episode"):
try:
@ -988,14 +1196,127 @@ class ReActEngine:
logger.warning(error_msg)
return {"error": error_msg}
# Strip internal metadata keys before passing to tool
clean_args = {k: v for k, v in arguments.items() if not k.startswith("_")}
try:
result = await tool.safe_execute(**arguments)
result = await tool.safe_execute(**clean_args)
return result
except Exception as e:
error_msg = f"Tool '{tool_name}' execution failed: {e}"
logger.warning(error_msg)
return {"error": error_msg}
async def _execute_tool_with_confirmation(
self,
tc: Any,
tools: list[Tool],
step: int,
confirmation_handler: Any,
) -> tuple[Any, list[ReActEvent]]:
"""Execute a tool call with confirmation flow support.
Used in the parallel execution path for serial (non-parallelizable) tools
that may require user confirmation before execution.
Returns:
Tuple of (tool_result, list of ReActEvents to yield)
"""
events: list[ReActEvent] = []
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
# Check if tool returned a confirmation request
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
confirmation_id = tool_result["confirmation_id"]
command = tool_result.get("command", "")
reason = tool_result.get("reason", "")
events.append(ReActEvent(
event_type="confirmation_request",
step=step,
data={
"confirmation_id": confirmation_id,
"tool_name": tc.name,
"command": command,
"reason": reason,
},
))
# Wait for user confirmation
approved = False
if confirmation_handler is not None:
try:
approved = await confirmation_handler(confirmation_id, command, reason)
except Exception as e:
logger.warning(f"Confirmation handler error: {e}")
if approved:
# User approved: re-execute with _skip_dangerous_check
tool = self._find_tool(tc.name, tools)
if tool and hasattr(tool, '_is_dangerous'):
clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")}
clean_args["_skip_dangerous_check"] = True
try:
tool_result = await tool.safe_execute(**clean_args)
except Exception as e:
tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"}
else:
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
events.append(ReActEvent(
event_type="confirmation_result",
step=step,
data={"confirmation_id": confirmation_id, "approved": True},
))
else:
# User rejected
tool_result = {
"output": "",
"exit_code": 126,
"is_error": True,
"error_type": "permission_denied",
"message": f"用户拒绝执行命令: {command[:100]}",
}
events.append(ReActEvent(
event_type="confirmation_result",
step=step,
data={"confirmation_id": confirmation_id, "approved": False},
))
return tool_result, events
def _should_execute_parallel(self, tool_calls: list[Any]) -> bool:
"""Determine if tool calls should be executed in parallel.
- parallel_tools=True: always parallel (if >1 tool)
- parallel_tools=False: never parallel
- parallel_tools="auto": parallel if any tool_call has _parallelizable=true in arguments
"""
if len(tool_calls) <= 1:
return False
if self._parallel_tools is True:
return True
if self._parallel_tools is False:
return False
# "auto" mode: check _parallelizable metadata in tool call arguments
if self._parallel_tools == "auto":
parallelizable_indices = self._get_parallelizable_indices(tool_calls)
return len(parallelizable_indices) > 1
return False
def _get_parallelizable_indices(self, tool_calls: list[Any]) -> list[int]:
"""Get indices of tool_calls that have _parallelizable=true in arguments.
LLM marks parallelizable tools by including _parallelizable: true
in the tool_call arguments.
"""
indices = []
for i, tc in enumerate(tool_calls):
args = tc.arguments if hasattr(tc, 'arguments') else {}
if isinstance(args, dict) and args.get("_parallelizable") is True:
indices.append(i)
return indices
def _parse_text_tool_calls(self, content: str) -> list[dict[str, Any]]:
"""从文本中解析工具调用模式

View File

@ -8,6 +8,7 @@
import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
@ -33,6 +34,17 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# ── Internal Exceptions ──────────────────────────────────
class _FallbackFailedError(Exception):
"""Internal signal: a fallback strategy failed, try the next one."""
def __init__(self, strategy: str):
self.strategy = strategy
super().__init__(f"Fallback strategy '{strategy}' failed")
# ── Data Structures ───────────────────────────────────────
@ -116,13 +128,25 @@ class ReWOOEngine:
"""
FALLBACK_STRATEGIES = ["simplified_rewoo", "react", "direct"]
VALID_STRATEGIES = {"simplified_rewoo", "react", "direct", "plan_exec"}
def __init__(self, llm_gateway: LLMGateway, max_plan_steps: int = 10, default_timeout: float = 300.0):
def __init__(self, llm_gateway: LLMGateway, max_plan_steps: int = 10, default_timeout: float = 300.0, fallback_strategies: list[str] | None = None):
if max_plan_steps < 1:
raise ValueError(f"max_plan_steps must be >= 1, got {max_plan_steps}")
self._llm_gateway = llm_gateway
self._max_plan_steps = max_plan_steps
self._default_timeout = default_timeout
# Validate and store fallback strategies
raw_strategies = fallback_strategies if fallback_strategies is not None else self.FALLBACK_STRATEGIES
self._fallback_strategies: list[str] = []
for s in raw_strategies:
if s in self.VALID_STRATEGIES:
self._fallback_strategies.append(s)
else:
logger.warning(f"Invalid fallback strategy '{s}', skipping. Valid: {self.VALID_STRATEGIES}")
if not self._fallback_strategies:
logger.warning("No valid fallback strategies, using defaults")
self._fallback_strategies = list(self.FALLBACK_STRATEGIES)
# ReActEngine 作为 fallback
self._react_engine = ReActEngine(
llm_gateway=llm_gateway,
@ -145,6 +169,7 @@ class ReWOOEngine:
retrieval_config: dict[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
timeout_seconds: float | None = None,
confirmation_handler: Any | None = None,
) -> ReActResult:
"""执行 ReWOO 三阶段流程
@ -176,6 +201,7 @@ class ReWOOEngine:
compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
confirmation_handler=confirmation_handler,
),
timeout=effective_timeout,
)
@ -193,6 +219,7 @@ class ReWOOEngine:
compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
confirmation_handler=confirmation_handler,
)
except asyncio.TimeoutError:
raise TaskTimeoutError(
@ -218,6 +245,7 @@ class ReWOOEngine:
compressor: "CompressionStrategy | None" = None,
retrieval_config: dict[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
confirmation_handler: Any | None = None,
) -> ReActResult:
tools = tools or []
tool_schemas = self._build_tool_schemas(tools) if tools else None
@ -296,109 +324,31 @@ class ReWOOEngine:
tokens_used=planning_tokens,
)
# 如果规划失败,尝试渐进式回退
# 如果规划失败,按配置的 fallback 策略顺序尝试回退
if plan is None:
# 尝试简化规划max_steps=3
logger.warning("ReWOO planning failed, trying simplified planning with max_steps=3")
try:
plan, simplified_tokens = await self._plan_phase(
messages=messages,
tools=tools,
tool_schemas=tool_schemas,
model=model,
agent_name=agent_name,
task_type=task_type,
system_prompt=effective_system_prompt,
compressor=compressor,
cancellation_token=cancellation_token,
max_steps=3,
)
total_tokens += simplified_tokens
if plan is not None and plan.steps:
fallback_strategy = "simplified_rewoo"
logger.info("Simplified ReWOO planning succeeded")
except Exception as e2:
logger.warning(f"Simplified ReWOO planning also failed: {e2}")
if plan is None:
# 回退到 ReAct
fallback_strategy = "react"
logger.warning("ReWOO planning failed, falling back to ReActEngine")
if trace_recorder is not None:
trace_recorder.end_trace(outcome="fallback")
try:
react_result = await self._react_engine.execute(
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
task_type=task_type,
system_prompt=system_prompt,
trace_recorder=trace_recorder,
memory_retriever=memory_retriever,
task_id=task_id,
compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
timeout_seconds=0, # timeout already handled by outer wrapper
)
react_result.fallback_strategy = fallback_strategy
return react_result
except Exception as react_err:
# ReAct 也失败,回退到 Direct简单 LLM 调用)
fallback_strategy = "direct"
logger.warning(f"ReAct fallback also failed: {react_err}, falling back to direct LLM call")
try:
direct_messages: list[dict[str, Any]] = []
if effective_system_prompt:
direct_messages.append({"role": "system", "content": effective_system_prompt})
direct_messages.extend(messages)
if compressor:
try:
direct_messages = await compressor.compress(direct_messages)
except Exception as e:
logger.warning(f"Context compression failed in direct fallback: {e}")
direct_response = await self._llm_gateway.chat(
messages=direct_messages,
model=model,
agent_name=agent_name,
task_type=task_type,
)
total_tokens += direct_response.usage.total_tokens
direct_step = ReWOOStep(
step=1,
action="final_answer",
content=direct_response.content,
tokens=direct_response.usage.total_tokens,
plan_step_id=None,
)
trajectory.append(direct_step)
if trace_recorder is not None:
trace_recorder.record_step(
step=1,
action="final_answer",
output_data={"content": direct_response.content},
tokens_used=direct_response.usage.total_tokens,
)
trace_outcome = "success"
if trace_recorder is not None:
trace_recorder.end_trace(outcome=trace_outcome)
return ReActResult(
output=direct_response.content or "",
trajectory=trajectory,
total_steps=len(trajectory),
total_tokens=total_tokens,
fallback_strategy=fallback_strategy,
)
except Exception as direct_err:
logger.error(f"Direct LLM fallback also failed: {direct_err}")
raise
fallback_strategy = await self._try_fallback_strategies(
strategies=self._fallback_strategies,
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
task_type=task_type,
system_prompt=system_prompt,
effective_system_prompt=effective_system_prompt,
trace_recorder=trace_recorder,
memory_retriever=memory_retriever,
task_id=task_id,
compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
trajectory=trajectory,
total_tokens=total_tokens,
confirmation_handler=confirmation_handler,
)
if fallback_strategy is not None:
return fallback_strategy
# All fallback strategies exhausted
raise RuntimeError("All ReWOO fallback strategies exhausted")
# 如果计划为空(无需工具),直接让 LLM 回答
if not plan.steps:
@ -579,6 +529,7 @@ class ReWOOEngine:
retrieval_config: dict[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
timeout_seconds: float | None = None,
confirmation_handler: Any | None = None,
):
"""Execute ReWOO flow, yielding ReActEvent objects.
@ -627,7 +578,6 @@ class ReWOOEngine:
trace_outcome = "success"
try:
# ── Phase 1: Planning ──
yield ReActEvent(
event_type="planning",
step=0,
@ -648,83 +598,27 @@ class ReWOOEngine:
total_tokens += planning_tokens
if plan is None:
# Try simplified planning
logger.warning("ReWOO planning failed in stream mode, trying simplified planning with max_steps=3")
try:
plan, simplified_tokens = await self._plan_phase(
messages=messages,
tools=tools,
tool_schemas=tool_schemas,
model=model,
agent_name=agent_name,
task_type=task_type,
system_prompt=effective_system_prompt,
compressor=compressor,
cancellation_token=cancellation_token,
max_steps=3,
)
total_tokens += simplified_tokens
except Exception as e2:
logger.warning(f"Simplified ReWOO planning also failed in stream mode: {e2}")
if plan is None:
# Planning failed, fall back to ReAct streaming
logger.warning("ReWOO planning failed in stream mode, falling back to ReActEngine")
try:
async for event in self._react_engine.execute_stream(
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
task_type=task_type,
system_prompt=system_prompt,
trace_recorder=trace_recorder,
memory_retriever=memory_retriever,
task_id=task_id,
compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
timeout_seconds=0,
):
yield event
return
except Exception as react_err:
# ReAct also failed, fall back to direct LLM call
logger.warning(f"ReAct fallback also failed in stream mode: {react_err}, falling back to direct LLM call")
try:
direct_messages: list[dict[str, Any]] = []
if effective_system_prompt:
direct_messages.append({"role": "system", "content": effective_system_prompt})
direct_messages.extend(messages)
if compressor:
try:
direct_messages = await compressor.compress(direct_messages)
except Exception as e:
logger.warning(f"Context compression failed in direct fallback: {e}")
direct_response = await self._llm_gateway.chat(
messages=direct_messages,
model=model,
agent_name=agent_name,
task_type=task_type,
)
total_tokens += direct_response.usage.total_tokens
output = direct_response.content or ""
yield ReActEvent(
event_type="final_answer",
step=1,
data={
"output": output,
"total_steps": 1,
"total_tokens": total_tokens,
},
)
return
except Exception as direct_err:
logger.error(f"Direct LLM fallback also failed in stream mode: {direct_err}")
raise
# Planning failed, try fallback strategies in configured order
async for event in self._try_fallback_strategies_stream(
strategies=self._fallback_strategies,
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
task_type=task_type,
system_prompt=system_prompt,
effective_system_prompt=effective_system_prompt,
trace_recorder=trace_recorder,
memory_retriever=memory_retriever,
task_id=task_id,
compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
total_tokens=total_tokens,
confirmation_handler=confirmation_handler,
):
yield event
return
yield ReActEvent(
event_type="plan_generated",
@ -875,8 +769,11 @@ class ReWOOEngine:
"total_tokens": total_tokens,
},
)
except Exception as e:
trace_outcome = "error"
logger.error(f"ReWOO execute_stream failed: {e}")
raise
finally:
# 结束轨迹记录
if trace_recorder is not None:
trace_recorder.end_trace(outcome=trace_outcome)
@ -892,6 +789,582 @@ class ReWOOEngine:
except Exception as e:
logger.warning(f"Failed to store task result in episodic memory: {e}")
# ── Fallback Strategy Helpers ──────────────────────────
async def _try_fallback_strategies_stream(
self,
strategies: list[str],
messages: list[dict[str, str]],
tools: list[Tool] | None = None,
model: str = "default",
agent_name: str = "",
task_type: str = "",
system_prompt: str | None = None,
effective_system_prompt: str | None = None,
trace_recorder: "TraceRecorder | None" = None,
memory_retriever: "MemoryRetriever | None" = None,
task_id: str | None = None,
compressor: "CompressionStrategy | None" = None,
retrieval_config: dict[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
total_tokens: int = 0,
confirmation_handler: Any | None = None,
):
"""Stream version: try fallback strategies in configured order, yielding events from the first successful one.
If all strategies fail, raises RuntimeError.
"""
for strategy in strategies:
if strategy == "simplified_rewoo":
try:
async for event in self._fallback_simplified_rewoo_stream(
messages=messages, tools=tools, model=model,
agent_name=agent_name, task_type=task_type,
effective_system_prompt=effective_system_prompt,
compressor=compressor, cancellation_token=cancellation_token,
):
yield event
return
except _FallbackFailedError:
continue
elif strategy == "react":
try:
async for event in self._fallback_react_stream(
messages=messages, tools=tools, model=model,
agent_name=agent_name, task_type=task_type,
system_prompt=system_prompt,
trace_recorder=trace_recorder,
memory_retriever=memory_retriever,
task_id=task_id, compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
confirmation_handler=confirmation_handler,
):
yield event
return
except _FallbackFailedError:
continue
elif strategy == "direct":
try:
async for event in self._fallback_direct_stream(
messages=messages, model=model,
agent_name=agent_name, task_type=task_type,
effective_system_prompt=effective_system_prompt,
compressor=compressor, total_tokens=total_tokens,
):
yield event
return
except _FallbackFailedError:
continue
elif strategy == "plan_exec":
try:
async for event in self._fallback_plan_exec_stream(
messages=messages, tools=tools, model=model,
agent_name=agent_name, task_type=task_type,
effective_system_prompt=effective_system_prompt,
compressor=compressor, cancellation_token=cancellation_token,
):
yield event
return
except _FallbackFailedError:
continue
raise RuntimeError("All ReWOO fallback strategies exhausted in stream mode")
async def _fallback_simplified_rewoo_stream(
self,
messages: list[dict[str, str]],
tools: list[Tool] | None = None,
model: str = "default",
agent_name: str = "",
task_type: str = "",
effective_system_prompt: str | None = None,
compressor: "CompressionStrategy | None" = None,
cancellation_token: CancellationToken | None = None,
):
"""Stream: Simplified ReWOO fallback with max_steps=3"""
logger.warning("ReWOO planning failed in stream mode, trying simplified planning with max_steps=3")
try:
tool_schemas = self._build_tool_schemas(tools) if tools else None
plan, simplified_tokens = await self._plan_phase(
messages=messages, tools=tools or [], tool_schemas=tool_schemas,
model=model, agent_name=agent_name, task_type=task_type,
system_prompt=effective_system_prompt, compressor=compressor,
cancellation_token=cancellation_token, max_steps=3,
)
if plan is not None and plan.steps:
logger.info("Simplified ReWOO planning succeeded in stream mode")
yield ReActEvent(event_type="plan_generated", step=0, data={
"reasoning": plan.reasoning,
"steps": [{"step_id": s.step_id, "tool_name": s.tool_name, "arguments": s.arguments, "reasoning": s.reasoning} for s in plan.steps],
})
tool_results: list[dict[str, Any]] = []
for plan_step in plan.steps:
if cancellation_token is not None:
cancellation_token.check()
yield ReActEvent(event_type="tool_call", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "arguments": plan_step.arguments})
tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or [])
tool_results.append({"step_id": plan_step.step_id, "tool_name": plan_step.tool_name, "arguments": plan_step.arguments, "result": tool_result, "reasoning": plan_step.reasoning})
yield ReActEvent(event_type="tool_result", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "result": tool_result})
yield ReActEvent(event_type="synthesis", step=len(plan.steps) + 1, data={"message": "Synthesizing results..."})
output, synthesis_tokens = await self._synthesis_phase(messages=messages, tool_results=tool_results, model=model, agent_name=agent_name, task_type=task_type, system_prompt=effective_system_prompt, compressor=compressor, cancellation_token=cancellation_token)
yield ReActEvent(event_type="final_answer", step=len(plan.steps) + 1, data={"output": output, "total_steps": len(plan.steps) + 1, "total_tokens": simplified_tokens + synthesis_tokens})
return
except Exception as e:
logger.warning(f"Simplified ReWOO planning also failed in stream mode: {e}")
# Failed, continue to next strategy by not returning
# This signals the caller to try the next strategy
# We need a different approach - raise a specific exception
raise _FallbackFailedError("simplified_rewoo")
async def _fallback_react_stream(
self,
messages: list[dict[str, str]],
tools: list[Tool] | None = None,
model: str = "default",
agent_name: str = "",
task_type: str = "",
system_prompt: str | None = None,
trace_recorder: "TraceRecorder | None" = None,
memory_retriever: "MemoryRetriever | None" = None,
task_id: str | None = None,
compressor: "CompressionStrategy | None" = None,
retrieval_config: dict[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
confirmation_handler: Any | None = None,
):
"""Stream: ReAct fallback"""
logger.warning("ReWOO planning failed in stream mode, falling back to ReActEngine")
try:
async for event in self._react_engine.execute_stream(
messages=messages, tools=tools, model=model,
agent_name=agent_name, task_type=task_type,
system_prompt=system_prompt, trace_recorder=trace_recorder,
memory_retriever=memory_retriever, task_id=task_id,
compressor=compressor, retrieval_config=retrieval_config,
cancellation_token=cancellation_token, timeout_seconds=0,
confirmation_handler=confirmation_handler,
):
yield event
return
except Exception as e:
logger.warning(f"ReAct fallback also failed in stream mode: {e}")
raise _FallbackFailedError("react")
async def _fallback_direct_stream(
self,
messages: list[dict[str, str]],
model: str = "default",
agent_name: str = "",
task_type: str = "",
effective_system_prompt: str | None = None,
compressor: "CompressionStrategy | None" = None,
total_tokens: int = 0,
):
"""Stream: Direct LLM fallback"""
logger.warning("Falling back to direct LLM call in stream mode")
try:
direct_messages: list[dict[str, Any]] = []
if effective_system_prompt:
direct_messages.append({"role": "system", "content": effective_system_prompt})
direct_messages.extend(messages)
if compressor:
try:
direct_messages = await compressor.compress(direct_messages)
except Exception as e:
logger.warning(f"Context compression failed in direct fallback: {e}")
direct_response = await self._llm_gateway.chat(messages=direct_messages, model=model, agent_name=agent_name, task_type=task_type)
output = direct_response.content or ""
yield ReActEvent(event_type="final_answer", step=1, data={"output": output, "total_steps": 1, "total_tokens": total_tokens + direct_response.usage.total_tokens})
return
except Exception as e:
logger.error(f"Direct LLM fallback also failed in stream mode: {e}")
raise _FallbackFailedError("direct")
async def _fallback_plan_exec_stream(
self,
messages: list[dict[str, str]],
tools: list[Tool] | None = None,
model: str = "default",
agent_name: str = "",
task_type: str = "",
effective_system_prompt: str | None = None,
compressor: "CompressionStrategy | None" = None,
cancellation_token: CancellationToken | None = None,
):
"""Stream: Plan-Exec fallback with max_steps=5"""
logger.warning("Falling back to plan-exec mode in stream mode (max_steps=5)")
try:
tool_schemas = self._build_tool_schemas(tools) if tools else None
plan, plan_tokens = await self._plan_phase(
messages=messages, tools=tools or [], tool_schemas=tool_schemas,
model=model, agent_name=agent_name, task_type=task_type,
system_prompt=effective_system_prompt, compressor=compressor,
cancellation_token=cancellation_token, max_steps=5,
)
if plan is not None and plan.steps:
yield ReActEvent(event_type="plan_generated", step=0, data={
"reasoning": plan.reasoning,
"steps": [{"step_id": s.step_id, "tool_name": s.tool_name, "arguments": s.arguments, "reasoning": s.reasoning} for s in plan.steps],
})
tool_results: list[dict[str, Any]] = []
for plan_step in plan.steps:
if cancellation_token is not None:
cancellation_token.check()
yield ReActEvent(event_type="tool_call", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "arguments": plan_step.arguments})
tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or [])
tool_results.append({"step_id": plan_step.step_id, "tool_name": plan_step.tool_name, "arguments": plan_step.arguments, "result": tool_result, "reasoning": plan_step.reasoning})
yield ReActEvent(event_type="tool_result", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "result": tool_result})
yield ReActEvent(event_type="synthesis", step=len(plan.steps) + 1, data={"message": "Synthesizing results..."})
output, synthesis_tokens = await self._synthesis_phase(messages=messages, tool_results=tool_results, model=model, agent_name=agent_name, task_type=task_type, system_prompt=effective_system_prompt, compressor=compressor, cancellation_token=cancellation_token)
yield ReActEvent(event_type="final_answer", step=len(plan.steps) + 1, data={"output": output, "total_steps": len(plan.steps) + 1, "total_tokens": plan_tokens + synthesis_tokens})
return
except Exception as e:
logger.warning(f"Plan-exec fallback also failed in stream mode: {e}")
raise _FallbackFailedError("plan_exec")
async def _try_fallback_strategies(
self,
strategies: list[str],
messages: list[dict[str, str]],
tools: list[Tool] | None = None,
model: str = "default",
agent_name: str = "",
task_type: str = "",
system_prompt: str | None = None,
effective_system_prompt: str | None = None,
trace_recorder: "TraceRecorder | None" = None,
memory_retriever: "MemoryRetriever | None" = None,
task_id: str | None = None,
compressor: "CompressionStrategy | None" = None,
retrieval_config: dict[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
trajectory: list[ReActStep] | None = None,
total_tokens: int = 0,
confirmation_handler: Any | None = None,
) -> ReActResult | None:
"""按配置的 fallback 策略顺序尝试回退,返回第一个成功的结果
Returns:
ReActResult if a fallback succeeded, None if all strategies exhausted
"""
for strategy in strategies:
if strategy == "simplified_rewoo":
result = await self._fallback_simplified_rewoo(
messages=messages, tools=tools, model=model,
agent_name=agent_name, task_type=task_type,
effective_system_prompt=effective_system_prompt,
compressor=compressor, cancellation_token=cancellation_token,
)
if result is not None:
return result
elif strategy == "react":
result = await self._fallback_react(
messages=messages, tools=tools, model=model,
agent_name=agent_name, task_type=task_type,
system_prompt=system_prompt,
trace_recorder=trace_recorder,
memory_retriever=memory_retriever,
task_id=task_id, compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
confirmation_handler=confirmation_handler,
)
if result is not None:
return result
elif strategy == "direct":
result = await self._fallback_direct(
messages=messages, model=model,
agent_name=agent_name, task_type=task_type,
effective_system_prompt=effective_system_prompt,
compressor=compressor, cancellation_token=cancellation_token,
trajectory=trajectory, total_tokens=total_tokens,
trace_recorder=trace_recorder,
)
if result is not None:
return result
elif strategy == "plan_exec":
result = await self._fallback_plan_exec(
messages=messages, tools=tools, model=model,
agent_name=agent_name, task_type=task_type,
effective_system_prompt=effective_system_prompt,
compressor=compressor, cancellation_token=cancellation_token,
)
if result is not None:
return result
return None
async def _fallback_simplified_rewoo(
self,
messages: list[dict[str, str]],
tools: list[Tool] | None = None,
model: str = "default",
agent_name: str = "",
task_type: str = "",
effective_system_prompt: str | None = None,
compressor: "CompressionStrategy | None" = None,
cancellation_token: CancellationToken | None = None,
) -> ReActResult | None:
"""Simplified ReWOO fallback: retry planning with max_steps=3"""
logger.warning("ReWOO planning failed, trying simplified planning with max_steps=3")
try:
tool_schemas = self._build_tool_schemas(tools) if tools else None
plan, simplified_tokens = await self._plan_phase(
messages=messages,
tools=tools or [],
tool_schemas=tool_schemas,
model=model,
agent_name=agent_name,
task_type=task_type,
system_prompt=effective_system_prompt,
compressor=compressor,
cancellation_token=cancellation_token,
max_steps=3,
)
if plan is not None and plan.steps:
logger.info("Simplified ReWOO planning succeeded")
# Execute the simplified plan
trajectory: list[ReActStep] = []
total_tokens = simplified_tokens
tool_results: list[dict[str, Any]] = []
for plan_step in plan.steps:
if cancellation_token is not None:
cancellation_token.check()
tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or [])
rewoo_step = ReWOOStep(
step=plan_step.step_id,
action="tool_call",
tool_name=plan_step.tool_name,
arguments=plan_step.arguments,
result=tool_result,
tokens=0,
plan_step_id=plan_step.step_id,
)
trajectory.append(rewoo_step)
tool_results.append({
"step_id": plan_step.step_id,
"tool_name": plan_step.tool_name,
"arguments": plan_step.arguments,
"result": tool_result,
"reasoning": plan_step.reasoning,
})
output, synthesis_tokens = await self._synthesis_phase(
messages=messages, tool_results=tool_results,
model=model, agent_name=agent_name, task_type=task_type,
system_prompt=effective_system_prompt, compressor=compressor,
cancellation_token=cancellation_token,
)
total_tokens += synthesis_tokens
trajectory.append(ReWOOStep(
step=len(plan.steps) + 1,
action="final_answer",
content=output,
tokens=synthesis_tokens,
))
return ReActResult(
output=output,
trajectory=trajectory,
total_steps=len(trajectory),
total_tokens=total_tokens,
fallback_strategy="simplified_rewoo",
)
except Exception as e:
logger.warning(f"Simplified ReWOO planning also failed: {e}")
return None
async def _fallback_react(
self,
messages: list[dict[str, str]],
tools: list[Tool] | None = None,
model: str = "default",
agent_name: str = "",
task_type: str = "",
system_prompt: str | None = None,
trace_recorder: "TraceRecorder | None" = None,
memory_retriever: "MemoryRetriever | None" = None,
task_id: str | None = None,
compressor: "CompressionStrategy | None" = None,
retrieval_config: dict[str, Any] | None = None,
cancellation_token: CancellationToken | None = None,
confirmation_handler: Any | None = None,
) -> ReActResult | None:
"""ReAct fallback: delegate to ReActEngine"""
logger.warning("ReWOO planning failed, falling back to ReActEngine")
try:
react_result = await self._react_engine.execute(
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
task_type=task_type,
system_prompt=system_prompt,
trace_recorder=trace_recorder,
memory_retriever=memory_retriever,
task_id=task_id,
compressor=compressor,
retrieval_config=retrieval_config,
cancellation_token=cancellation_token,
timeout_seconds=0, # timeout already handled by outer wrapper
confirmation_handler=confirmation_handler,
)
react_result.fallback_strategy = "react"
return react_result
except Exception as e:
logger.warning(f"ReAct fallback also failed: {e}")
return None
async def _fallback_direct(
self,
messages: list[dict[str, str]],
model: str = "default",
agent_name: str = "",
task_type: str = "",
effective_system_prompt: str | None = None,
compressor: "CompressionStrategy | None" = None,
cancellation_token: CancellationToken | None = None,
trajectory: list[ReActStep] | None = None,
total_tokens: int = 0,
trace_recorder: "TraceRecorder | None" = None,
) -> ReActResult | None:
"""Direct fallback: simple LLM call without tools"""
logger.warning("Falling back to direct LLM call")
try:
direct_messages: list[dict[str, Any]] = []
if effective_system_prompt:
direct_messages.append({"role": "system", "content": effective_system_prompt})
direct_messages.extend(messages)
if compressor:
try:
direct_messages = await compressor.compress(direct_messages)
except Exception as e:
logger.warning(f"Context compression failed in direct fallback: {e}")
direct_response = await self._llm_gateway.chat(
messages=direct_messages,
model=model,
agent_name=agent_name,
task_type=task_type,
)
total_tokens += direct_response.usage.total_tokens
direct_step = ReWOOStep(
step=1,
action="final_answer",
content=direct_response.content,
tokens=direct_response.usage.total_tokens,
plan_step_id=None,
)
if trajectory is not None:
trajectory.append(direct_step)
if trace_recorder is not None:
trace_recorder.record_step(
step=1,
action="final_answer",
output_data={"content": direct_response.content},
tokens_used=direct_response.usage.total_tokens,
)
trace_recorder.end_trace(outcome="success")
return ReActResult(
output=direct_response.content or "",
trajectory=trajectory or [direct_step],
total_steps=len(trajectory or [direct_step]),
total_tokens=total_tokens,
fallback_strategy="direct",
)
except Exception as e:
logger.error(f"Direct LLM fallback also failed: {e}")
return None
async def _fallback_plan_exec(
self,
messages: list[dict[str, str]],
tools: list[Tool] | None = None,
model: str = "default",
agent_name: str = "",
task_type: str = "",
effective_system_prompt: str | None = None,
compressor: "CompressionStrategy | None" = None,
cancellation_token: CancellationToken | None = None,
) -> ReActResult | None:
"""Plan-Exec fallback: plan then execute sequentially (like simplified ReWOO but with max_steps=5)"""
logger.warning("Falling back to plan-exec mode (max_steps=5)")
try:
tool_schemas = self._build_tool_schemas(tools) if tools else None
plan, plan_tokens = await self._plan_phase(
messages=messages,
tools=tools or [],
tool_schemas=tool_schemas,
model=model,
agent_name=agent_name,
task_type=task_type,
system_prompt=effective_system_prompt,
compressor=compressor,
cancellation_token=cancellation_token,
max_steps=5,
)
if plan is not None and plan.steps:
trajectory: list[ReActStep] = []
total_tokens = plan_tokens
tool_results: list[dict[str, Any]] = []
for plan_step in plan.steps:
if cancellation_token is not None:
cancellation_token.check()
tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or [])
rewoo_step = ReWOOStep(
step=plan_step.step_id,
action="tool_call",
tool_name=plan_step.tool_name,
arguments=plan_step.arguments,
result=tool_result,
tokens=0,
plan_step_id=plan_step.step_id,
)
trajectory.append(rewoo_step)
tool_results.append({
"step_id": plan_step.step_id,
"tool_name": plan_step.tool_name,
"arguments": plan_step.arguments,
"result": tool_result,
"reasoning": plan_step.reasoning,
})
output, synthesis_tokens = await self._synthesis_phase(
messages=messages, tool_results=tool_results,
model=model, agent_name=agent_name, task_type=task_type,
system_prompt=effective_system_prompt, compressor=compressor,
cancellation_token=cancellation_token,
)
total_tokens += synthesis_tokens
trajectory.append(ReWOOStep(
step=len(plan.steps) + 1,
action="final_answer",
content=output,
tokens=synthesis_tokens,
))
return ReActResult(
output=output,
trajectory=trajectory,
total_steps=len(trajectory),
total_tokens=total_tokens,
fallback_strategy="plan_exec",
)
except Exception as e:
logger.warning(f"Plan-exec fallback also failed: {e}")
return None
# ── Phase Implementations ─────────────────────────────
async def _plan_phase(
@ -1079,7 +1552,6 @@ class ReWOOEngine:
# 尝试从 markdown 代码块中提取
if "```" in json_str:
import re
code_block_match = re.search(r"```(?:json)?\s*\n(.*?)\n\s*```", json_str, re.DOTALL)
if code_block_match:
json_str = code_block_match.group(1).strip()

View File

@ -21,6 +21,14 @@ class Bid:
capabilities: list[str] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
if self.estimated_cost < 0:
raise ValueError(f"estimated_cost must be non-negative, got {self.estimated_cost}")
if not (0.0 <= self.confidence <= 1.0):
raise ValueError(f"confidence must be between 0.0 and 1.0, got {self.confidence}")
if self.payment_offer < 0:
raise ValueError(f"payment_offer must be non-negative, got {self.payment_offer}")
@dataclass
class AuctionResult:
@ -98,3 +106,125 @@ class AuctionHouse:
wealth_factor = self._wealth_tracker.get_wealth_factor(bid.agent_name)
score = (bid.confidence / max(bid.estimated_cost, 0.001)) * wealth_factor
return score
def filter_by_capabilities(
self, bidders: list[Bid], required_capabilities: list[str]
) -> list[Bid]:
"""Filter bidders whose capabilities include ALL required capabilities.
Args:
bidders: List of Bid objects to filter.
required_capabilities: Capabilities that a bidder must ALL possess.
Returns:
List of Bids whose capabilities list includes every required capability.
"""
if not required_capabilities:
return bidders
required_set = {cap.lower() for cap in required_capabilities}
return [
b for b in bidders
if required_set.issubset({cap.lower() for cap in b.capabilities})
]
async def run_vickrey_auction(
self,
task_description: str,
bidders: list[Bid],
required_capabilities: list[str] | None = None,
) -> AuctionResult:
"""Run a Vickrey (second-price sealed-bid) auction.
In a Vickrey auction each bidder submits a sealed bid (their
estimated_cost). The lowest estimated_cost wins, but the winner
*pays* the second-lowest estimated_cost rather than their own.
This is incentive-compatible: agents' dominant strategy is to bid
truthfully.
Steps:
1. Filter by required_capabilities (if provided).
2. Filter out bankrupt agents.
3. If only 1 eligible bidder wins, pays 0.
4. If 2+ eligible bidders lowest cost wins, pays second-lowest.
5. Update WealthTracker: winner earns (payment - cost_estimate).
6. Return AuctionResult with selection_reason.
Args:
task_description: Description of the task being auctioned.
bidders: List of Bid objects.
required_capabilities: Optional list of capabilities that bidders
must possess to be eligible.
Returns:
AuctionResult with the winner and Vickrey outcome details.
"""
# 1. Capability filtering
eligible = (
self.filter_by_capabilities(bidders, required_capabilities)
if required_capabilities
else list(bidders)
)
# 2. Filter out bankrupt agents
eligible = [
b for b in eligible
if not self._wealth_tracker.is_bankrupt(b.agent_name)
]
# No bidders at all
if not bidders:
return AuctionResult(
winner=None,
all_bids=bidders,
selection_reason="No bidders participated",
total_bidders=0,
)
# All eligible bidders filtered out (bankrupt or no capabilities)
if not eligible:
return AuctionResult(
winner=None,
all_bids=bidders,
selection_reason="No eligible bidders (bankrupt or missing capabilities)",
total_bidders=len(bidders),
)
# 3. Sort by estimated_cost ascending (lowest cost = best bid)
# Apply minimum cost floor to prevent zero-cost bid manipulation
MIN_COST = 0.001
for b in eligible:
b.estimated_cost = max(b.estimated_cost, MIN_COST)
eligible = sorted(eligible, key=lambda b: b.estimated_cost)
winner = eligible[0]
# 4. Determine payment (second-price rule)
if len(eligible) == 1:
# Single bidder: payment equals their own cost (no profit, no loss)
payment = winner.estimated_cost
else:
payment = eligible[1].estimated_cost
# 5. Update WealthTracker
profit = payment - winner.estimated_cost
self._wealth_tracker.reward(winner.agent_name, profit)
# 6. Build selection_reason
if len(eligible) == 1:
reason = (
f"Vickrey auction: Agent '{winner.agent_name}' won as sole eligible bidder "
f"(cost={winner.estimated_cost}, payment={payment:.4f}, profit=0)"
)
else:
second = eligible[1]
reason = (
f"Vickrey auction: Agent '{winner.agent_name}' won with lowest cost "
f"({winner.estimated_cost}), pays second-lowest cost ({payment}) from "
f"'{second.agent_name}'; profit={profit:.4f}"
)
return AuctionResult(
winner=winner,
all_bids=bidders,
selection_reason=reason,
total_bidders=len(bidders),
)

View File

@ -2,6 +2,8 @@
from __future__ import annotations
import threading
class WealthTracker:
"""Track agent wealth for auction mechanism.
@ -9,25 +11,31 @@ class WealthTracker:
Agents earn wealth by completing tasks successfully.
Agents lose wealth by failing tasks.
Bankrupt agents (wealth <= -100) are excluded from auctions.
Thread-safe: all mutations are protected by a threading.Lock.
"""
def __init__(self, initial_wealth: float = 100.0) -> None:
self._balances: dict[str, float] = {}
self._initial_wealth = initial_wealth
self._lock = threading.Lock()
def get_wealth(self, agent_name: str) -> float:
"""Get agent's current wealth, defaulting to initial_wealth"""
return self._balances.get(agent_name, self._initial_wealth)
with self._lock:
return self._balances.get(agent_name, self._initial_wealth)
def reward(self, agent_name: str, amount: float) -> None:
"""Reward agent for successful task completion"""
current = self.get_wealth(agent_name)
self._balances[agent_name] = current + amount
with self._lock:
current = self._balances.get(agent_name, self._initial_wealth)
self._balances[agent_name] = current + amount
def penalize(self, agent_name: str, amount: float) -> None:
"""Penalize agent for task failure"""
current = self.get_wealth(agent_name)
self._balances[agent_name] = current - amount
with self._lock:
current = self._balances.get(agent_name, self._initial_wealth)
self._balances[agent_name] = current - amount
def is_bankrupt(self, agent_name: str) -> bool:
"""Check if agent is bankrupt (wealth <= -100)"""
@ -46,5 +54,5 @@ class WealthTracker:
return all_agents
def get_wealth_factor(self, agent_name: str) -> float:
"""Get wealth factor for scoring: 1.0 + (wealth / 1000.0)"""
return 1.0 + (self.get_wealth(agent_name) / 1000.0)
"""Get wealth factor for scoring: max(0.01, 1.0 + (wealth / 1000.0))"""
return max(0.01, 1.0 + (self.get_wealth(agent_name) / 1000.0))

View File

@ -20,6 +20,7 @@ from agentkit.router.intent import IntentRouter
from agentkit.skills.base import Skill, SkillConfig
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
from agentkit.tools.skill_install import SkillInstallTool
from agentkit.server.config import ServerConfig
from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory, portal, evolution_dashboard, kb_management, skill_management, workflows, chat
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
@ -104,6 +105,8 @@ async def lifespan(app: FastAPI):
if server_config is not None and server_config._config_path:
server_config.on_change = lambda cfg: _on_config_change(app, cfg)
server_config.watch_config()
# Store event loop reference for thread-safe config reload
app.state._event_loop = asyncio.get_running_loop()
logger.info("Config hot-reload enabled")
# Start MCP servers if configured
@ -132,7 +135,10 @@ async def lifespan(app: FastAPI):
"你必须先使用搜索工具查找准确和最新的信息,然后再回答。"
"中文内容优先使用 baidu_search 工具,英文/国际内容使用 web_search。"
"在能够搜索到真相的情况下,绝不猜测或编造答案。"
"始终优先搜索而不是给出可能不正确的信息。"
"始终优先搜索而不是给出可能不正确的信息。\n\n"
"技能安装:当需要安装技能时,使用 skill_install 工具,不要用 shell 执行 npm install。"
"skill_install 的 source 参数格式为 owner/repo@skill例如 vercel-labs/skills@find-skills。"
"如果不知道完整 source先用 shell 执行 `npx skills search <name>` 搜索。"
)
effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt)
@ -156,6 +162,10 @@ async def lifespan(app: FastAPI):
}
agent._tool_registry.register(MemoryTool(memory_store=memory_store))
agent._tool_registry.register(ShellTool())
agent._tool_registry.register(SkillInstallTool(
skill_registry=app.state.skill_registry,
tool_registry=app.state.tool_registry,
))
agent._tool_registry.register(BaiduSearchTool())
agent._tool_registry.register(WebSearchTool(**search_api_keys))
agent._tool_registry.register(WebCrawlTool())
@ -238,13 +248,15 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
- Config version is incremented for audit tracking
Uses a lock to prevent concurrent config reloads from racing.
Thread-safe: uses threading.Event for cross-thread signaling.
Thread-safe: schedules reload onto the asyncio event loop via
asyncio.run_coroutine_threadsafe() since watchfiles calls this
from a non-asyncio thread.
"""
import threading
lock: asyncio.Lock = app.state._config_reload_lock
lock = app.state._config_reload_lock
# Thread-safe: set pending flag via threading.Event or call_soon_threadsafe
# Thread-safe: set pending flag via threading.Event
if not hasattr(app.state, "_config_reload_event"):
app.state._config_reload_event = threading.Event()
@ -310,12 +322,18 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None:
logger.info(f"Config reload complete (v{current_version})")
# Schedule the reload as a task (non-blocking for the watcher thread)
# Schedule the reload onto the event loop via run_coroutine_threadsafe
# (watchfiles calls _on_config_change from a non-asyncio thread)
try:
loop = asyncio.get_running_loop()
loop.create_task(_reload())
asyncio.run_coroutine_threadsafe(_reload(), loop)
except RuntimeError:
logger.warning("No running event loop, config reload deferred")
# No running loop — try getting the stored event loop reference
_loop = getattr(app.state, "_event_loop", None)
if _loop is not None and not _loop.is_closed():
asyncio.run_coroutine_threadsafe(_reload(), _loop)
else:
logger.warning("No running event loop, config reload deferred")
def create_app(
@ -482,6 +500,7 @@ def create_app(
org_context=org_context,
auction_enabled=auction_enabled,
classifier=server_config.router.get("classifier", "heuristic") if server_config and server_config.router else "heuristic",
merged_llm_classify=server_config.router.get("merged_llm_classify", True) if server_config and server_config.router else True,
)
app.state.cost_aware_router = cost_aware_router
# Initialize task store from config

View File

@ -193,7 +193,12 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
# Execute the Agent
try:
react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
react_engine = getattr(agent, "_react_engine", None)
if react_engine is None:
react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
else:
react_engine.reset()
tools = agent._tool_registry.list_tools() if agent._tool_registry else []
system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None)
result = await react_engine.execute(
@ -286,6 +291,10 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
cancellation_token = CancellationToken()
# Per-session concurrency guard to prevent unlimited task creation (DoS mitigation)
_MAX_CONCURRENT_TASKS = 4
active_tasks: set[asyncio.Task] = set()
try:
await websocket.send_json({"type": "connected", "session_id": session_id})
@ -308,9 +317,27 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
content = msg.get("content", "")
# Create a fresh CancellationToken for each message
message_token = CancellationToken()
await _handle_chat_message(
websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations
# Guard against unlimited concurrent tasks
# Clean up completed tasks first
active_tasks.difference_update(t for t in active_tasks if t.done())
if len(active_tasks) >= _MAX_CONCURRENT_TASKS:
await websocket.send_json({
"type": "error",
"data": {"message": "Too many concurrent requests. Please wait for the current task to complete."},
})
continue
# Run in background task so the WebSocket receive loop stays free
# to process confirmation_reply / reply messages while the agent
# is waiting for user confirmation (otherwise deadlock).
task = asyncio.create_task(
_handle_chat_message(
websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations
)
)
active_tasks.add(task)
task.add_done_callback(active_tasks.discard)
elif msg_type == "reply":
# Reply to AskHumanTool
@ -323,8 +350,12 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
# Reply to confirmation request
confirmation_id = msg.get("confirmation_id")
approved = msg.get("approved", False)
logger.info(f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}")
if confirmation_id and confirmation_id in pending_confirmations:
pending_confirmations[confirmation_id].set_result(approved)
logger.info(f"Confirmation {confirmation_id} set_result({approved})")
else:
logger.warning(f"Confirmation {confirmation_id!r} not found in pending_confirmations")
elif msg_type == "cancel":
cancellation_token.cancel()
@ -424,10 +455,17 @@ async def _handle_chat_message(
chat_messages = await sm.get_chat_messages(session_id)
# Execute Agent with streaming
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
react_engine = getattr(agent, "_react_engine", None)
if react_engine is None:
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
else:
react_engine.reset()
# Create confirmation handler that sends request to frontend and waits for reply
_pending_confirmations = pending_confirmations or {}
# Use the same dict object — do NOT use `or {}` because an empty dict is falsy
# and would create a new dict, breaking the shared state with the WS loop.
_pending_confirmations = pending_confirmations if pending_confirmations is not None else {}
async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool:
"""Send confirmation request to frontend via WebSocket and wait for user reply."""
@ -442,16 +480,28 @@ async def _handle_chat_message(
})
# Create a Future and wait for the user's reply
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
future: asyncio.Future[bool] = loop.create_future()
_pending_confirmations[confirmation_id] = future
logger.info(f"Confirmation request {confirmation_id} sent, waiting for reply")
try:
# Wait up to 5 minutes for user confirmation
return await asyncio.wait_for(future, timeout=300.0)
result = await asyncio.wait_for(future, timeout=300.0)
logger.info(f"Confirmation request {confirmation_id} resolved: {result}")
# Immediately notify frontend of the result so the card updates
# without waiting for the tool to re-execute
await websocket.send_json({
"type": "confirmation_result",
"data": {"confirmation_id": confirmation_id, "approved": result},
})
return result
except asyncio.TimeoutError:
logger.warning(f"Confirmation request {confirmation_id} timed out")
return False
except asyncio.CancelledError:
logger.warning(f"Confirmation request {confirmation_id} cancelled")
return False
finally:
_pending_confirmations.pop(confirmation_id, None)
@ -459,6 +509,7 @@ async def _handle_chat_message(
try:
final_content = ""
token_buffer: list[str] = []
async for event in react_engine.execute_stream(
messages=chat_messages,
tools=routing.tools,
@ -469,24 +520,51 @@ async def _handle_chat_message(
confirmation_handler=_confirmation_handler,
):
if event.event_type == "final_answer":
# Flush any buffered tokens as a single write
if token_buffer:
await websocket.send_json({"type": "token", "content": "".join(token_buffer)})
token_buffer.clear()
# Then send final answer
final_content = event.data.get("output", "")
await websocket.send_json({
"type": "final_answer",
"content": final_content,
"is_final": True,
})
elif event.event_type == "token":
# Buffer tokens instead of sending immediately
token_buffer.append(event.data.get("content", ""))
elif event.event_type == "thinking":
# If we have buffered tokens, convert them to a thinking event
if token_buffer:
buffered_text = "".join(token_buffer)
token_buffer.clear()
await websocket.send_json({"type": "thinking", "content": buffered_text})
# Also send the thinking event content
thinking_msg = event.data.get("message", "")
if thinking_msg:
await websocket.send_json({"type": "thinking", "content": thinking_msg})
elif event.event_type == "tool_call":
# Convert buffered tokens to thinking (they were "thinking" text before tool call)
if token_buffer:
buffered_text = "".join(token_buffer)
token_buffer.clear()
await websocket.send_json({"type": "thinking", "content": buffered_text})
await websocket.send_json({
"type": "step",
"data": {
"event_type": event.event_type,
"step": event.step,
"data": event.data,
},
})
elif event.event_type == "confirmation_request":
# Already handled by confirmation_handler, just notify frontend
pass
elif event.event_type == "confirmation_result":
await websocket.send_json({
"type": "confirmation_result",
"data": event.data,
})
elif event.event_type == "token":
await websocket.send_json({
"type": "token",
"content": event.data.get("content", ""),
})
else:
await websocket.send_json({
"type": "step",

View File

@ -8,6 +8,7 @@
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:ital,wght@0,300;0,400;0,500;0,600;0,700;1,400&display=swap" rel="stylesheet">
<script src="https://cdn.jsdelivr.net/npm/marked@12/marked.min.js"></script>
<style>
*,*::before,*::after{box-sizing:border-box;margin:0;padding:0}
:root{
@ -88,6 +89,33 @@ html,body{height:100%;font-family:var(--font);background:var(--bg);color:var(--t
.msg .bubble{padding:12px 18px;font-size:14px;line-height:1.7;white-space:pre-wrap;word-break:break-word;position:relative}
.msg.user .bubble{background:var(--user-bg);color:var(--user-text);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-sm) var(--radius-lg);box-shadow:0 2px 8px rgba(59,91,219,.2)}
.msg.agent .bubble{background:var(--surface);color:var(--agent-text);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-lg) var(--radius-sm);border:1px solid var(--border-light);box-shadow:var(--shadow-xs)}
/* ── Markdown in Agent Bubbles ──────────────────────────────── */
.msg.agent .bubble.md{white-space:normal}
.msg.agent .bubble.md p{margin:0 0 .6em}
.msg.agent .bubble.md p:last-child{margin-bottom:0}
.msg.agent .bubble.md h1,.msg.agent .bubble.md h2,.msg.agent .bubble.md h3,.msg.agent .bubble.md h4{margin:.8em 0 .4em;font-weight:700;letter-spacing:-.3px;line-height:1.3}
.msg.agent .bubble.md h1{font-size:1.4em}
.msg.agent .bubble.md h2{font-size:1.2em}
.msg.agent .bubble.md h3{font-size:1.05em}
.msg.agent .bubble.md h4{font-size:1em}
.msg.agent .bubble.md ul,.msg.agent .bubble.md ol{margin:.4em 0 .6em 1.5em;padding:0}
.msg.agent .bubble.md li{margin:.2em 0}
.msg.agent .bubble.md li>p{margin:.2em 0}
.msg.agent .bubble.md blockquote{margin:.5em 0;padding:.4em 1em;border-left:3px solid var(--primary);background:var(--primary-light);border-radius:0 var(--radius-sm) var(--radius-sm) 0;color:var(--text2)}
.msg.agent .bubble.md blockquote p{margin:.2em 0}
.msg.agent .bubble.md code{font-family:'SF Mono',SFMono-Regular,Consolas,'Liberation Mono',Menlo,monospace;font-size:.88em;background:var(--surface2);padding:.15em .4em;border-radius:4px;color:var(--text)}
.msg.agent .bubble.md pre{margin:.5em 0;padding:12px 14px;background:var(--surface2);border:1px solid var(--border-light);border-radius:var(--radius-sm);overflow-x:auto;line-height:1.5}
.msg.agent .bubble.md pre code{background:none;padding:0;font-size:.85em;color:var(--text)}
.msg.agent .bubble.md table{border-collapse:collapse;margin:.5em 0;width:100%;font-size:.9em}
.msg.agent .bubble.md th,.msg.agent .bubble.md td{border:1px solid var(--border);padding:6px 10px;text-align:left}
.msg.agent .bubble.md th{background:var(--surface2);font-weight:600}
.msg.agent .bubble.md hr{border:none;border-top:1px solid var(--border);margin:.8em 0}
.msg.agent .bubble.md a{color:var(--primary);text-decoration:none}
.msg.agent .bubble.md a:hover{text-decoration:underline}
.msg.agent .bubble.md img{max-width:100%;border-radius:var(--radius-sm)}
.msg.agent .bubble.md strong{font-weight:700}
.msg.agent .bubble.md em{font-style:italic}
.msg .meta{font-size:11px;color:var(--text3);margin-top:5px;padding:0 4px;font-weight:500}
.msg.user .meta{text-align:right}
.typing-indicator{display:inline-flex;gap:5px;padding:6px 0}
@ -164,6 +192,10 @@ html,body{height:100%;font-family:var(--font);background:var(--bg);color:var(--t
@keyframes fadeIn{from{opacity:0}to{opacity:1}}
@keyframes slideInRight{from{opacity:0;transform:translateX(12px)}to{opacity:1;transform:translateX(0)}}
/* ── Thinking ──────────────────────────────────────────────────── */
.thinking-msg{display:flex;flex-direction:column;max-width:72%;animation:msgIn .35s cubic-bezier(.16,1,.3,1);align-self:flex-start}
.thinking-msg .bubble{padding:10px 16px;font-size:13px;line-height:1.6;color:var(--text3);font-style:italic;background:var(--surface2);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-lg) var(--radius-sm);border:1px dashed var(--border)}
/* ── Loading indicator ───────────────────────────────────────── */
.loading-msg{display:flex;flex-direction:column;max-width:72%;align-self:flex-start;animation:msgIn .35s cubic-bezier(.16,1,.3,1)}
.loading-msg .loading-bubble{padding:8px 20px;background:var(--surface);border:1px solid var(--border-light);border-radius:var(--radius-lg) var(--radius-lg) var(--radius-lg) var(--radius-sm);display:inline-flex;align-items:center;gap:6px;box-shadow:var(--shadow-xs)}
@ -172,6 +204,26 @@ html,body{height:100%;font-family:var(--font);background:var(--bg);color:var(--t
.loading-msg .loading-dot:nth-child(3){animation-delay:.4s}
@keyframes loadingDot{0%,80%,100%{transform:scale(.4);opacity:.3}40%{transform:scale(1);opacity:1}}
/* ── Confirmation Card ──────────────────────────────────────── */
.confirm-card{display:flex;flex-direction:column;max-width:72%;align-self:flex-start;animation:msgIn .35s cubic-bezier(.16,1,.3,1)}
.confirm-card .card-inner{background:var(--surface);border:1px solid var(--warning);border-radius:var(--radius-lg);box-shadow:var(--shadow-md);overflow:hidden}
.confirm-card .card-header{display:flex;align-items:center;gap:10px;padding:14px 18px 10px;border-bottom:1px solid var(--border-light)}
.confirm-card .card-icon{width:32px;height:32px;border-radius:var(--radius-sm);background:#fef3c7;display:flex;align-items:center;justify-content:center;font-size:16px;flex-shrink:0}
.confirm-card .card-title{font-size:14px;font-weight:600;color:var(--text);letter-spacing:-0.2px}
.confirm-card .card-body{padding:12px 18px 14px}
.confirm-card .card-command{background:var(--surface2);border:1px solid var(--border-light);border-radius:var(--radius-sm);padding:10px 14px;font-family:'SF Mono',SFMono-Regular,Consolas,'Liberation Mono',Menlo,monospace;font-size:13px;color:var(--text);word-break:break-all;line-height:1.6;margin-bottom:10px}
.confirm-card .card-reason{font-size:13px;color:var(--text2);line-height:1.5;margin-bottom:4px}
.confirm-card .card-actions{display:flex;gap:8px;padding:0 18px 14px}
.confirm-card .btn-confirm{flex:1;padding:9px 16px;border-radius:var(--radius-sm);font-size:13px;font-weight:600;cursor:pointer;transition:all .2s;border:1px solid;font-family:var(--font)}
.confirm-card .btn-confirm.approve{background:var(--primary-light);color:var(--primary);border-color:var(--primary-subtle)}
.confirm-card .btn-confirm.approve:hover{background:var(--primary);color:#fff;border-color:var(--primary)}
.confirm-card .btn-confirm.deny{background:var(--danger-light);color:var(--danger);border-color:#fecaca}
.confirm-card .btn-confirm.deny:hover{background:var(--danger);color:#fff;border-color:var(--danger)}
.confirm-card .btn-confirm:disabled{opacity:.5;cursor:not-allowed;transform:none}
.confirm-card .card-status{padding:10px 18px 14px;font-size:13px;font-weight:500;display:flex;align-items:center;gap:6px}
.confirm-card .card-status.approved{color:var(--success);background:var(--success-light)}
.confirm-card .card-status.denied{color:var(--danger);background:var(--danger-light)}
/* ── Mobile ──────────────────────────────────────────────────── */
@media(max-width:768px){
.sidebar{position:fixed;left:-100%;z-index:10;transition:left .3s cubic-bezier(.16,1,.3,1);width:85vw;max-width:320px;box-shadow:var(--shadow-lg)}
@ -270,11 +322,41 @@ let activeSessionId = null;
let ws = null;
let isStreaming = false;
let currentAgentBubble = null;
let currentAgentRawText = ''; // accumulate raw text during streaming
let currentThinkingBubble = null;
let loadingEl = null;
let skills = [];
const API = '/api/v1/chat';
const SKILLS_API = '/api/v1/skills';
// ── Markdown helpers ──────────────────────────────────────────
const _MD_RE = /(?:^#{1,4}\s|^\s*[-*+]\s|^\s*\d+\.\s|```|\*\*|__|\[.*?\]\(|^\s*>)/m;
function isMarkdown(text) {
if (!text) return false;
return _MD_RE.test(text);
}
function renderMarkdown(text) {
if (!text || !text.trim()) return '';
try {
return marked.parse(text, { breaks: true, gfm: true });
} catch {
return esc(text);
}
}
function renderBubbleContent(bubble, text, forceMd) {
const md = forceMd || isMarkdown(text);
if (md) {
bubble.classList.add('md');
bubble.innerHTML = renderMarkdown(text);
} else {
bubble.classList.remove('md');
bubble.textContent = text;
}
}
// ── API helpers ────────────────────────────────────────────────
async function api(base, path, opts = {}) {
const res = await fetch(base + path, {
@ -393,35 +475,66 @@ function handleWsMessage(msg) {
removeLoading();
appendStep(`路由: ${msg.skill || 'default'} (${msg.method}, ${Math.round((msg.confidence || 0) * 100)}%)`);
break;
case 'thinking':
removeLoading();
if (!currentThinkingBubble) {
currentThinkingBubble = appendThinking();
}
const thinkingContent = msg.content || '';
if (currentThinkingBubble) {
currentThinkingBubble.textContent += thinkingContent;
}
scrollToBottom();
break;
case 'token':
removeLoading();
// Clear any thinking bubble when real content starts
if (currentThinkingBubble) {
currentThinkingBubble.remove();
currentThinkingBubble = null;
}
if (!currentAgentBubble) {
currentAgentBubble = appendMessage('agent', '');
currentAgentRawText = '';
isStreaming = true;
updateSendBtn();
}
currentAgentBubble.textContent += msg.content || '';
currentAgentRawText += msg.content || '';
// During streaming, show raw text for performance
currentAgentBubble.textContent = currentAgentRawText;
scrollToBottom();
break;
case 'final_answer':
removeLoading();
// Clear thinking bubble
if (currentThinkingBubble) {
currentThinkingBubble.remove();
currentThinkingBubble = null;
}
if (currentAgentBubble) {
const current = currentAgentBubble.textContent || '';
const final = msg.content || '';
if (!current.trim() || final.length > current.length) {
currentAgentBubble.textContent = final;
}
// Use the longer/more complete text
const bestText = (!currentAgentRawText.trim() || final.length > currentAgentRawText.length) ? final : currentAgentRawText;
renderBubbleContent(currentAgentBubble, bestText);
currentAgentBubble = null;
currentAgentRawText = '';
} else {
appendMessage('agent', msg.content || '');
}
isStreaming = false;
updateSendBtn();
scrollToBottom();
// Auto-refresh skill list after agent finishes (may have installed new skills)
loadSkills();
break;
case 'step':
removeLoading();
if (msg.data?.event_type === 'tool_call') {
// Replace thinking with tool step
if (currentThinkingBubble) {
currentThinkingBubble.remove();
currentThinkingBubble = null;
}
appendStep(`使用工具: ${msg.data?.data?.tool_name || 'tool'}`);
}
break;
@ -431,8 +544,20 @@ function handleWsMessage(msg) {
appendStep(`技能: ${msg.data.skill} (${msg.data.method}, ${Math.round((msg.data.confidence || 0) * 100)}%)`);
}
break;
case 'confirmation_request':
removeLoading();
showConfirmationCard(msg.data);
break;
case 'confirmation_result':
updateConfirmationCard(msg.data);
break;
case 'error':
removeLoading();
// Clear thinking bubble on error
if (currentThinkingBubble) {
currentThinkingBubble.remove();
currentThinkingBubble = null;
}
appendMessage('agent', `[错误] ${msg.data?.message || '未知错误'}`);
currentAgentBubble = null;
isStreaming = false;
@ -462,6 +587,71 @@ function removeLoading() {
}
}
// ── Confirmation Card ──────────────────────────────────────────
function showConfirmationCard(data) {
hideWelcome();
const container = document.getElementById('messages');
const confirmationId = data.confirmation_id;
const command = data.command || '';
const reason = data.reason || '';
const div = document.createElement('div');
div.className = 'confirm-card';
div.id = `confirm-${confirmationId}`;
div.innerHTML = `
<div class="card-inner">
<div class="card-header">
<div class="card-icon">&#9888;</div>
<div class="card-title">操作确认</div>
</div>
<div class="card-body">
<div class="card-command">${esc(command)}</div>
${reason ? `<div class="card-reason">${esc(reason)}</div>` : ''}
</div>
<div class="card-actions" id="confirm-actions-${confirmationId}">
<button class="btn-confirm approve" onclick="replyConfirmation('${confirmationId}', true)">确认执行</button>
<button class="btn-confirm deny" onclick="replyConfirmation('${confirmationId}', false)">拒绝</button>
</div>
</div>
`;
container.appendChild(div);
scrollToBottom();
}
function replyConfirmation(confirmationId, approved) {
// Disable buttons immediately
const actionsEl = document.getElementById(`confirm-actions-${confirmationId}`);
if (actionsEl) {
const buttons = actionsEl.querySelectorAll('.btn-confirm');
buttons.forEach(btn => { btn.disabled = true; });
}
// Send reply to server
if (ws && ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({
type: 'confirmation_reply',
confirmation_id: confirmationId,
approved: approved,
}));
}
}
function updateConfirmationCard(data) {
const confirmationId = data.confirmation_id;
const approved = data.approved;
const cardEl = document.getElementById(`confirm-${confirmationId}`);
if (!cardEl) return;
// Replace actions with status
const actionsEl = document.getElementById(`confirm-actions-${confirmationId}`);
if (actionsEl) {
const statusClass = approved ? 'approved' : 'denied';
const statusText = approved ? '&#10003; 已确认执行' : '&#10007; 已拒绝';
actionsEl.outerHTML = `<div class="card-status ${statusClass}">${statusText}</div>`;
}
scrollToBottom();
}
// ── Send message ───────────────────────────────────────────────
async function sendMessage() {
const input = document.getElementById('input');
@ -523,7 +713,12 @@ function appendMessage(role, content) {
div.className = `msg ${cssRole}`;
const bubble = document.createElement('div');
bubble.className = 'bubble';
bubble.textContent = content;
// Render markdown for agent messages
if (cssRole === 'agent' && content) {
renderBubbleContent(bubble, content);
} else {
bubble.textContent = content;
}
div.appendChild(bubble);
const meta = document.createElement('div');
meta.className = 'meta';
@ -548,6 +743,18 @@ function appendStep(text) {
scrollToBottom();
}
function appendThinking() {
hideWelcome();
const container = document.getElementById('messages');
const div = document.createElement('div');
div.className = 'thinking-msg';
const bubble = document.createElement('div');
bubble.className = 'bubble';
div.appendChild(bubble);
container.appendChild(div);
return bubble;
}
function renderHistory(msgs) {
const container = document.getElementById('messages');
container.innerHTML = '';
@ -672,7 +879,7 @@ async function installSkill() {
status.textContent = `自动安装失败,正在请求智能体协助...`;
if (ws && ws.readyState === WebSocket.OPEN) {
const installMsg = `请帮我安装一个名为"${name}"的技能。请按以下步骤操作1. 使用搜索工具在网上搜索 "${name}" 的 YAML 配置文件可在技能市场、GitHub 等平台搜索2. 如果找到了,使用 shell 工具将其下载到 configs/skills/${name}.yaml3. 下载完成后,使用 shell 工具执行 curl 命令调用 API 注册curl -X POST http://localhost:${location.port}/api/v1/skills/install -H 'Content-Type: application/json' -d '{"name":"${name}","source":"file://configs/skills/${name}.yaml"}'4. 如果找不到这个技能,请告诉我。`;
const installMsg = `请帮我安装一个名为"${name}"的技能。请使用 skill_install 工具安装source 参数格式为 owner/repo@skill。如果不知道完整 source先用 shell 执行 npx skills search ${name} 搜索。`;
appendMessage('user', installMsg);
ws.send(JSON.stringify({ type: 'message', content: installMsg }));
currentAgentBubble = null;

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio
import datetime
import logging
from collections import defaultdict
from collections import defaultdict, deque
from typing import Any
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
@ -31,8 +31,8 @@ class AsyncWriteQueue:
self._store = store
self._queue: asyncio.Queue[tuple[Message, Session] | None] | None = None
self._worker: asyncio.Task | None = None
# Pending buffer: session_id -> list of Messages not yet persisted
self._pending_buffer: dict[str, list[Message]] = defaultdict(list)
# Pending buffer: session_id -> deque of Messages not yet persisted
self._pending_buffer: dict[str, deque[Message]] = defaultdict(deque)
self._max_buffer_size = max_buffer_size
self._pending_count = 0
@ -73,7 +73,7 @@ class AsyncWriteQueue:
pass
if not buf:
self._pending_buffer.pop(message.session_id, None)
self._pending_count -= 1
self._pending_count = max(0, self._pending_count - 1)
self._queue.task_done()
def enqueue(self, message: Message, session: Session) -> None:

View File

@ -12,6 +12,7 @@ import os
import re
import shlex
import time
import uuid
from collections import deque
from typing import Any, Callable, Awaitable
@ -24,46 +25,135 @@ logger = logging.getLogger(__name__)
# 安全白名单:这些命令前缀不需要确认
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
# 文件浏览
"cd",
"export",
"ls",
"grep",
"find",
"pwd",
"find",
"file",
"stat",
"tree",
"du",
"wc",
# 文本查看/处理
"cat",
"head",
"tail",
"less",
"more",
"grep",
"egrep",
"fgrep",
"sort",
"uniq",
"diff",
"comm",
"cut",
"tr",
"awk",
"sed", # sed without -i is read-only; -i is caught by operator check
"tee",
"xargs",
# 输出
"echo",
"which",
"printf",
# 系统信息
"whoami",
"id",
"date",
"uname",
"hostname",
"uptime",
"df",
"du",
"free",
"ps",
"top",
"htop",
"ps",
"env",
"printenv",
"type",
"file",
"stat",
"wc",
"sort",
"uniq",
"diff",
"sleep",
"which",
"whereis",
"arch",
"lscpu",
"lsblk",
"mount",
# 网络(只读查询)
"curl", # curl GET is safe; POST/PUT caught by operator check
"wget", # wget download is generally safe
"ping",
"traceroute",
"dig",
"nslookup",
"host",
"ifconfig",
"ip",
"netstat",
"ss",
"lsof",
# 版本/帮助
"python --version",
"python3 --version",
"node --version",
"npm --version",
"npm list",
"npm view",
"npm info",
"npm search",
"npx --version",
"pip list",
"pip show",
"pip search",
"java --version",
"go version",
"rustc --version",
"cargo --version",
"git --version",
"docker --version",
"docker ps",
"docker images",
"docker logs",
"docker inspect",
# Git 只读
"git status",
"git log",
"git diff",
"git branch",
"git remote",
"pip list",
"pip show",
"python --version",
"python3 --version",
"node --version",
"npm list",
"docker ps",
"docker images",
"git show",
"git tag",
"git stash list",
"git config --list",
# 包管理器(只读查询)
"brew list",
"brew info",
"brew search",
"apt list",
"apt search",
"apt show",
"yum list",
"yum search",
"yum info",
# 其他安全命令
"export",
"sleep",
"true",
"false",
"test",
"seq",
"basename",
"dirname",
"realpath",
"readlink",
"md5sum",
"sha256sum",
"shasum",
"openssl", # openssl dgst / rand are safe
# tar 解压查看
"tar -tf",
"tar --list",
"zipinfo",
"unzip -l",
)
# 危险命令检测 — 基于精确 token 匹配,避免子串误判
@ -95,7 +185,8 @@ _DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [
]
_SHELL_OPERATORS = re.compile(r'[|;&]|&&|\|\||\$\(|\$\{|`|\$<|>|<|\n')
_SHELL_PIPE_OPERATORS = re.compile(r'\|')
_SHELL_CHAIN_OPERATORS = re.compile(r'[;&]|\|\||&&|\$\(|\$\{|`|\$<|>|<|\n')
class ShellTool(Tool):
@ -225,21 +316,26 @@ class ShellTool(Tool):
session_id = kwargs.get("session_id")
interactive = kwargs.get("interactive", False)
# 安全检查:危险命令需要确认
if self._is_dangerous(command):
# 安全检查:危险命令需要确认(除非已通过 _skip_dangerous_check 跳过)
skip_dangerous = kwargs.get("_skip_dangerous_check", False)
if not skip_dangerous and self._is_dangerous(command):
confirmed = await self._request_confirmation(command)
if not confirmed:
self._log_audit(command, None, blocked=True)
# 返回确认请求由上层ReActEngine/chat处理
confirmation_id = str(uuid.uuid4())
return {
"output": "",
"needs_confirmation": True,
"confirmation_id": confirmation_id,
"command": command[:500],
"reason": "此命令被识别为潜在危险操作,需要用户确认",
"suggestions": [
"确认执行此命令",
"拒绝执行此命令",
],
"output": f"命令被拒绝: {command[:100]}",
"exit_code": 126,
"is_error": True,
"error_type": "permission_denied",
"message": f"危险命令已被拒绝执行: {command[:100]}",
"suggestions": [
"如需执行此命令,请手动确认",
"考虑使用更安全的替代命令",
],
}
# 根据模式执行
@ -362,12 +458,33 @@ class ShellTool(Tool):
def _is_dangerous(self, command: str) -> bool:
"""检查命令是否为危险操作
白名单命令直接放行其他命令检查是否匹配危险模式
白名单命令直接放行管道命令|在所有子命令都安全时放行
其他链式操作符;&&||$()>< 一律视为危险
"""
command_stripped = command.strip()
# Check for shell operators that chain commands (always dangerous)
if _SHELL_OPERATORS.search(command_stripped):
# Check for dangerous chain operators (;, &&, ||, $(), backticks, redirections, newlines)
if _SHELL_CHAIN_OPERATORS.search(command_stripped):
return True
# Handle pipe commands: split and check each sub-command
if _SHELL_PIPE_OPERATORS.search(command_stripped):
parts = command_stripped.split('|')
for part in parts:
part = part.strip()
if not part:
continue
if self._is_single_command_dangerous(part):
return True
return False # All pipe segments are safe
# Single command
return self._is_single_command_dangerous(command_stripped)
def _is_single_command_dangerous(self, command: str) -> bool:
"""Check if a single command (no pipes/chains) is dangerous."""
command_stripped = command.strip()
if not command_stripped:
return True
# Parse the actual binary being invoked
@ -377,7 +494,6 @@ class ShellTool(Tool):
return True
binary = os.path.basename(tokens[0])
except ValueError:
# Unparsable command - treat as dangerous
return True
# Whitelist check: first try full command prefix match, then binary-only match
@ -405,6 +521,9 @@ class ShellTool(Tool):
for flag_pattern in _DANGEROUS_BINARY_FLAGS[binary_lower]:
if flag_pattern in cmd_str:
return True
# Binary has dangerous flags but none matched — treat as safe
# (e.g., "git add" is safe even though "git push --force" is not)
return False
# 3. Cross-token dangerous patterns (regex)
command_lower = command_stripped.lower()
@ -412,7 +531,10 @@ class ShellTool(Tool):
if pattern.search(command_lower):
return True
return True # Unknown commands are dangerous by default
# 4. Unknown binary — check if it looks like a path or known safe pattern
# Commands like /usr/bin/python3, ./script.sh, etc. are not in whitelist
# but may be safe. Default to requiring confirmation for truly unknown binaries.
return True
async def _request_confirmation(self, command: str) -> bool:
"""请求人工确认危险命令

View File

@ -0,0 +1,169 @@
"""SkillInstallTool - Agent 可调用的技能安装工具"""
import asyncio
import logging
import os
from typing import Any, Callable, Awaitable
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class SkillInstallTool(Tool):
"""技能安装工具
Agent 可以通过正确的命令安装技能而不是用 shell 执行 npm install
使用 `npx skills install <owner/repo@skill>` 安装技能然后注册到 skill_registry
Usage:
tool = SkillInstallTool()
result = await tool.execute(name="find-skills", source="vercel-labs/skills@find-skills")
"""
def __init__(
self,
name: str = "skill_install",
description: str = "安装 Agent 技能包。使用 npx skills install 安装指定技能,不要用 npm install。",
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
confirm_callback: Callable[[str], Awaitable[bool]] | None = None,
skill_registry=None,
tool_registry=None,
):
schema = input_schema or {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "技能名称,如 find-skills",
},
"source": {
"type": "string",
"description": "技能来源,格式为 owner/repo@skill如 vercel-labs/skills@find-skills。如果不提供则用 name 搜索",
},
},
"required": ["name"],
}
super().__init__(
name=name,
description=description,
input_schema=schema,
output_schema=output_schema,
version=version,
tags=tags,
)
self._confirm_callback = confirm_callback
self._skill_registry = skill_registry
self._tool_registry = tool_registry
async def execute(self, **kwargs) -> dict:
name = kwargs.get("name", "").strip()
source = kwargs.get("source", "").strip()
if not name:
return {
"output": "错误: 必须提供 name 参数",
"exit_code": 1,
"is_error": True,
}
# Build the install command
if source:
install_target = source
else:
install_target = name
# Request confirmation before installing
if self._confirm_callback:
confirmed = await self._confirm_callback(f"npx skills install {install_target}")
if not confirmed:
return {
"output": f"技能安装被用户拒绝: {install_target}",
"exit_code": 126,
"is_error": True,
}
try:
proc = await asyncio.create_subprocess_exec(
"npx", "skills@latest", "install", install_target, "-y",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
output = stdout.decode("utf-8", errors="replace")
error_output = stderr.decode("utf-8", errors="replace")
if proc.returncode == 0:
# Try to register the skill into skill_registry
registration_msg = ""
if self._skill_registry:
registration_msg = self._try_register_skill(name)
return {
"output": f"技能 {name} 安装成功\n{output}\n{registration_msg}",
"exit_code": 0,
"is_error": False,
}
else:
return {
"output": f"技能 {name} 安装失败 (exit={proc.returncode})\n{error_output}\n{output}",
"exit_code": proc.returncode,
"is_error": True,
"suggestions": [
"检查技能名称是否正确",
"使用 source 参数指定完整路径,如 vercel-labs/skills@find-skills",
"运行 npx skills search <name> 搜索可用技能",
],
}
except asyncio.TimeoutError:
return {
"output": f"技能 {name} 安装超时120s",
"exit_code": -1,
"is_error": True,
}
except Exception as e:
return {
"output": f"技能 {name} 安装异常: {e}",
"exit_code": -1,
"is_error": True,
}
def _try_register_skill(self, name: str) -> str:
"""Try to find and register the installed skill YAML into skill_registry."""
try:
from agentkit.skills.loader import SkillLoader
for search_dir in [os.path.join(os.getcwd(), ".agents", "skills"),
os.path.join(os.path.expanduser("~"), ".agents", "skills"),
os.path.join(os.getcwd(), "configs", "skills")]:
yaml_path = os.path.join(search_dir, f"{name}.yaml")
if os.path.exists(yaml_path):
loader = SkillLoader(
skill_registry=self._skill_registry,
tool_registry=self._tool_registry,
)
loader.load_from_file(yaml_path)
return f"技能已注册到系统(来源: {yaml_path}"
# Also check for directory-based skills
for search_dir in [os.path.join(os.getcwd(), ".agents", "skills"),
os.path.join(os.path.expanduser("~"), ".agents", "skills")]:
skill_dir = os.path.join(search_dir, name)
if os.path.isdir(skill_dir):
for fname in os.listdir(skill_dir):
if fname.endswith((".yaml", ".yml")):
yaml_path = os.path.join(skill_dir, fname)
loader = SkillLoader(
skill_registry=self._skill_registry,
tool_registry=self._tool_registry,
)
loader.load_from_file(yaml_path)
return f"技能已注册到系统(来源: {yaml_path}"
return "技能文件已下载,但未找到 YAML 配置文件进行注册。可能需要重启服务。"
except Exception as e:
logger.warning(f"Failed to register skill {name}: {e}")
return f"技能文件已下载,但注册失败: {e}"

View File

@ -66,6 +66,7 @@ class TerminalSession:
self.session_id = session_id
self._cwd = cwd or os.getcwd()
self._env: dict[str, str] = dict(env or os.environ)
self._env_delta: set[str] = set() # Track only env vars explicitly set in this session
self._history: deque[CommandRecord] = deque(maxlen=max_history)
self._output_parser = OutputParser()
self._created_at = time.time()
@ -147,7 +148,12 @@ class TerminalSession:
timeout=timeout,
)
except asyncio.TimeoutError:
proc.kill()
try:
proc.kill()
except ProcessLookupError:
pass # Process already exited
except OSError:
pass
await proc.wait()
output = f"命令执行超时({timeout}s"
exit_code = -1
@ -196,11 +202,12 @@ class TerminalSession:
if self._cwd:
parts.append(f"cd {shlex.quote(self._cwd)}")
# 注入环境变量
for key, value in self._env.items():
if not _ENV_KEY_PATTERN.match(key):
continue # Skip invalid env key names
parts.append(f"export {shlex.quote(key)}={shlex.quote(value)}")
# 注入环境变量(仅注入 session 中显式设置的增量,避免泄露完整 os.environ
if self._env_delta:
for key in self._env_delta:
value = self._env.get(key, "")
if _ENV_KEY_PATTERN.match(key):
parts.append(f"export {shlex.quote(key)}={shlex.quote(value)}")
parts.append(command)
return " && ".join(parts)
@ -269,6 +276,7 @@ class TerminalSession:
for key, value in matches:
value = value.strip().strip("'\"")
self._env[key] = value
self._env_delta.add(key)
def _add_history(self, record: CommandRecord) -> None:
"""添加命令记录到历史deque maxlen 自动淘汰最旧记录"""

View File

@ -1,17 +1,23 @@
"""集成测试 - CostAwareRouter → Engine → AlignmentGuard 全链路"""
"""集成测试 - CostAwareRouter → Engine → AlignmentGuard 全链路
包含合并 LLM 分类路由和并行工具执行的集成测试"""
from __future__ import annotations
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
from agentkit.core.react import ReActEngine, ReActResult, ReActStep
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.core.react import ReActEngine, ReActResult, ReActStep, ReActEvent
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.org.context import AgentProfile, OrganizationContext
from agentkit.quality.alignment import AlignmentConfig, AlignmentGuard
from agentkit.tools.base import Tool
# ---------------------------------------------------------------------------
@ -313,3 +319,285 @@ class TestAlignmentGuardConstraintInjection:
alert = guard.record_interaction("session-chain-1")
assert alert is None
assert guard.get_interaction_count("session-chain-1") == 1
# ---------------------------------------------------------------------------
# Test 5: Merged Router → Engine chain (HeuristicClassifier → merged LLM → ReAct)
# ---------------------------------------------------------------------------
class _SlowTool(Tool):
"""带延迟的 Fake Tool用于验证并行执行"""
def __init__(
self,
name: str = "slow_tool",
description: str = "A slow tool for testing",
delay: float = 0.1,
result: dict | None = None,
):
super().__init__(name=name, description=description)
self._delay = delay
self._result = result or {"status": "ok"}
self.call_count = 0
async def execute(self, **kwargs) -> dict:
self.call_count += 1
await asyncio.sleep(self._delay)
return self._result
class TestMergedRouterToEngineChain:
"""完整链路:用户消息 → HeuristicClassifier → merged LLM classify → ReActEngine → 结果
HeuristicClassifier 返回中等复杂度 (0.3-0.7) 使用合并 LLM 调用
"""
@pytest.mark.asyncio
async def test_medium_complexity_uses_merged_llm_then_react(self):
"""中等复杂度触发 merged LLM classify然后路由到 ReActEngine 执行"""
# "如何优化代码" 包含 "如何"(中等)和 "代码"heuristic 给出中等复杂度
# merged LLM classify 返回中等复杂度 + 无 skill_hint
# → 路由到默认 agent (ReAct)
merged_response = _make_llm_response(json.dumps({
"complexity": 0.5,
"intent": "code_optimization",
"skill_hint": None,
}))
# ReActEngine 最终答案
react_final = _make_llm_response("建议使用缓存和异步IO来优化代码性能。")
gateway = _make_mock_gateway([merged_response, react_final])
org_context = OrganizationContext()
router = CostAwareRouter(
llm_gateway=gateway,
org_context=org_context,
merged_llm_classify=True,
)
mock_skill_registry = MagicMock()
mock_skill_registry.list_skills.return_value = []
mock_intent_router = AsyncMock()
# Step 1: Route
route_result = await router.route(
content="如何优化代码性能",
skill_registry=mock_skill_registry,
intent_router=mock_intent_router,
default_tools=[],
default_system_prompt="You are helpful",
default_model="default",
default_agent_name="default",
)
# 验证路由结果:中等复杂度,使用 merged_llm 方法
assert 0.3 <= route_result.complexity <= 0.7
assert route_result.match_method is not None
assert "merged_llm" in route_result.match_method
# Step 2: Execute with ReActEngine
react_engine = ReActEngine(llm_gateway=gateway)
engine_result = await react_engine.execute(
messages=[{"role": "user", "content": route_result.clean_content}],
system_prompt=route_result.system_prompt,
)
# Step 3: Verify result
assert isinstance(engine_result, ReActResult)
assert "优化" in engine_result.output or "缓存" in engine_result.output
@pytest.mark.asyncio
async def test_medium_complexity_merged_llm_routes_to_skill_then_react(self):
"""中等复杂度 + merged LLM 返回 skill_hint → 路由到 skill → ReAct 执行"""
merged_response = _make_llm_response(json.dumps({
"complexity": 0.45,
"intent": "code_review",
"skill_hint": "code_reviewer",
}))
react_final = _make_llm_response("代码审查完成发现3个潜在问题。")
gateway = _make_mock_gateway([merged_response, react_final])
# 创建包含 code_reviewer skill 的 mock registry
mock_skill = MagicMock()
mock_skill.name = "code_reviewer"
mock_skill.config.intent.keywords = ["code", "review"]
mock_skill.config.llm = None
mock_skill.config.prompt = None
mock_skill.tools = []
mock_skill_registry = MagicMock()
mock_skill_registry.list_skills.return_value = [mock_skill]
mock_skill_registry.get.return_value = mock_skill
org_context = OrganizationContext()
router = CostAwareRouter(
llm_gateway=gateway,
org_context=org_context,
merged_llm_classify=True,
)
route_result = await router.route(
content="如何优化代码性能",
skill_registry=mock_skill_registry,
intent_router=AsyncMock(),
default_tools=[],
default_system_prompt="You are helpful",
default_model="default",
default_agent_name="default",
)
# 验证路由到 skill
assert route_result.matched is True
assert route_result.skill_name == "code_reviewer"
assert route_result.match_method == "merged_llm"
# Execute with ReActEngine
react_engine = ReActEngine(llm_gateway=gateway)
engine_result = await react_engine.execute(
messages=[{"role": "user", "content": route_result.clean_content}],
system_prompt=route_result.system_prompt,
)
assert isinstance(engine_result, ReActResult)
@pytest.mark.asyncio
async def test_merged_llm_high_complexity_delegates_to_layer2(self):
"""HeuristicClassifier 中等复杂度 → merged LLM 返回高复杂度 → 委派到 Layer 2"""
merged_response = _make_llm_response(json.dumps({
"complexity": 0.85,
"intent": "deep_analysis",
"skill_hint": None,
}))
gateway = _make_mock_gateway([merged_response])
org_context = OrganizationContext()
org_context.register_agent(AgentProfile(
name="analyst",
agent_type="react",
capabilities=["分析", "优化", "代码"],
skills=["code_analysis"],
current_load=0,
))
org_context.find_best_agent = MagicMock(
return_value=org_context.get_agent_profile("analyst")
)
router = CostAwareRouter(
llm_gateway=gateway,
org_context=org_context,
merged_llm_classify=True,
)
mock_skill_registry = MagicMock()
mock_skill_registry.list_skills.return_value = []
mock_intent_router = AsyncMock()
route_result = await router.route(
content="如何优化代码性能",
skill_registry=mock_skill_registry,
intent_router=mock_intent_router,
default_tools=[],
default_system_prompt="You are helpful",
default_model="default",
default_agent_name="default",
)
# 高复杂度应委派到 Layer 2
assert route_result.complexity >= 0.7
assert route_result.matched is True
assert route_result.agent_name == "analyst"
# ---------------------------------------------------------------------------
# Test 6: Parallel Tools Integration (ReActEngine with parallel_tools="auto")
# ---------------------------------------------------------------------------
class TestParallelToolsIntegration:
"""ReActEngine + parallel_tools="auto" 在真实场景下的集成测试
LLM 返回 2 tool_calls _parallelizable=true两者并行执行
"""
@pytest.mark.asyncio
async def test_auto_parallel_two_tools_realistic(self):
"""真实场景LLM 返回 2 个并行工具调用,并行执行"""
tool_a = _SlowTool(name="search_web", delay=0.1, result={"results": ["Python best practices"]})
tool_b = _SlowTool(name="search_docs", delay=0.1, result={"docs": ["Official Python docs"]})
# LLM 返回 2 个并行工具调用
tool_call_response = LLMResponse(
content="",
model="test-model",
usage=TokenUsage(prompt_tokens=50, completion_tokens=20),
tool_calls=[
ToolCall(id="tc_1", name="search_web", arguments={"query": "python", "_parallelizable": True}),
ToolCall(id="tc_2", name="search_docs", arguments={"topic": "python", "_parallelizable": True}),
],
)
final_response = _make_llm_response("Based on search results, Python best practices include...")
gateway = MagicMock(spec=LLMGateway)
gateway.chat = AsyncMock(side_effect=[tool_call_response, final_response])
engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")
start = time.monotonic()
result = await engine.execute(
messages=[{"role": "user", "content": "Search for Python best practices"}],
tools=[tool_a, tool_b],
)
elapsed = time.monotonic() - start
assert isinstance(result, ReActResult)
assert tool_a.call_count == 1
assert tool_b.call_count == 1
# 并行执行应比串行快
assert elapsed < 0.25, f"Parallel execution too slow: {elapsed:.2f}s"
# 验证轨迹包含两个工具调用
tool_steps = [s for s in result.trajectory if s.action == "tool_call"]
assert len(tool_steps) == 2
names = {s.tool_name for s in tool_steps}
assert names == {"search_web", "search_docs"}
@pytest.mark.asyncio
async def test_auto_parallel_with_serial_tool_mixed(self):
"""混合场景1 个串行工具 + 2 个并行工具"""
tool_serial = _SlowTool(name="init_context", delay=0.05, result={"context": "ready"})
tool_para_a = _SlowTool(name="search_web", delay=0.1, result={"results": ["web result"]})
tool_para_b = _SlowTool(name="search_docs", delay=0.1, result={"docs": ["doc result"]})
tool_call_response = LLMResponse(
content="",
model="test-model",
usage=TokenUsage(prompt_tokens=50, completion_tokens=20),
tool_calls=[
ToolCall(id="tc_1", name="init_context", arguments={"project": "test"}),
ToolCall(id="tc_2", name="search_web", arguments={"query": "test", "_parallelizable": True}),
ToolCall(id="tc_3", name="search_docs", arguments={"topic": "test", "_parallelizable": True}),
],
)
final_response = _make_llm_response("Combined result from all tools")
gateway = MagicMock(spec=LLMGateway)
gateway.chat = AsyncMock(side_effect=[tool_call_response, final_response])
engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")
result = await engine.execute(
messages=[{"role": "user", "content": "Initialize and search"}],
tools=[tool_serial, tool_para_a, tool_para_b],
)
assert isinstance(result, ReActResult)
assert tool_serial.call_count == 1
assert tool_para_a.call_count == 1
assert tool_para_b.call_count == 1
# 所有工具结果都在轨迹中
tool_steps = [s for s in result.trajectory if s.action == "tool_call"]
assert len(tool_steps) == 3

View File

@ -1,5 +1,7 @@
"""AuctionHouse 与 WealthTracker 单元测试"""
import threading
import pytest
from agentkit.marketplace.auction import AuctionHouse, AuctionResult, Bid
@ -288,3 +290,406 @@ class TestAuctionDefaultDisabled:
auction_enabled = getattr(marketplace_cfg, "auction_enabled", False)
assert auction_enabled is False
# If marketplace doesn't exist at all, auction is implicitly disabled
# ---- Vickrey Auction 测试 ----
class TestVickreySingleBidder:
"""Vickrey 拍卖单一竞价者获胜支付自身成本利润为0"""
@pytest.mark.asyncio
async def test_single_bidder_wins_pays_own_cost(self):
tracker = WealthTracker()
house = AuctionHouse(wealth_tracker=tracker)
bid = make_bid(agent_name="solo_agent", estimated_cost=5.0)
result = await house.run_vickrey_auction("task", [bid])
assert result.winner is not None
assert result.winner.agent_name == "solo_agent"
assert result.total_bidders == 1
# Winner pays own cost, so profit = 0 → wealth unchanged
assert tracker.get_wealth("solo_agent") == 100.0
@pytest.mark.asyncio
async def test_single_bidder_selection_reason(self):
house = AuctionHouse()
bid = make_bid(agent_name="solo_agent", estimated_cost=10.0)
result = await house.run_vickrey_auction("task", [bid])
assert "sole eligible bidder" in result.selection_reason
class TestVickreyTwoBidders:
"""Vickrey 拍卖:两个竞价者,最低价赢,支付第二低价"""
@pytest.mark.asyncio
async def test_lowest_cost_wins_pays_second(self):
tracker = WealthTracker()
house = AuctionHouse(wealth_tracker=tracker)
bid_cheap = make_bid(agent_name="cheap_agent", estimated_cost=5.0)
bid_expensive = make_bid(agent_name="expensive_agent", estimated_cost=10.0)
result = await house.run_vickrey_auction("task", [bid_cheap, bid_expensive])
assert result.winner is not None
assert result.winner.agent_name == "cheap_agent"
# Winner pays second-lowest = 10.0, profit = 10.0 - 5.0 = 5.0
assert tracker.get_wealth("cheap_agent") == 100.0 + 5.0
@pytest.mark.asyncio
async def test_second_price_not_own_price(self):
tracker = WealthTracker()
house = AuctionHouse(wealth_tracker=tracker)
bid_a = make_bid(agent_name="agent_a", estimated_cost=3.0)
bid_b = make_bid(agent_name="agent_b", estimated_cost=7.0)
result = await house.run_vickrey_auction("task", [bid_a, bid_b])
# agent_a wins, pays 7.0 (not 3.0), profit = 7.0 - 3.0 = 4.0
assert tracker.get_wealth("agent_a") == 100.0 + 4.0
# Loser pays nothing
assert tracker.get_wealth("agent_b") == 100.0
@pytest.mark.asyncio
async def test_selection_reason_contains_second_price(self):
house = AuctionHouse()
bid_a = make_bid(agent_name="agent_a", estimated_cost=3.0)
bid_b = make_bid(agent_name="agent_b", estimated_cost=7.0)
result = await house.run_vickrey_auction("task", [bid_a, bid_b])
assert "7.0" in result.selection_reason
assert "agent_b" in result.selection_reason
class TestVickreyThreeBidders:
"""Vickrey 拍卖:三个竞价者,最低价赢,支付第二低价"""
@pytest.mark.asyncio
async def test_lowest_wins_pays_second_lowest(self):
tracker = WealthTracker()
house = AuctionHouse(wealth_tracker=tracker)
bid_a = make_bid(agent_name="agent_a", estimated_cost=3.0)
bid_b = make_bid(agent_name="agent_b", estimated_cost=7.0)
bid_c = make_bid(agent_name="agent_c", estimated_cost=12.0)
result = await house.run_vickrey_auction("task", [bid_a, bid_b, bid_c])
assert result.winner is not None
assert result.winner.agent_name == "agent_a"
# Winner pays second-lowest = 7.0, profit = 7.0 - 3.0 = 4.0
assert tracker.get_wealth("agent_a") == 100.0 + 4.0
# Third bidder's cost doesn't affect payment
assert result.total_bidders == 3
@pytest.mark.asyncio
async def test_middle_bidder_wins_when_cheapest_bankrupt(self):
tracker = WealthTracker()
# Make agent_a bankrupt
tracker.penalize("agent_a", 300.0)
assert tracker.is_bankrupt("agent_a") is True
house = AuctionHouse(wealth_tracker=tracker)
bid_a = make_bid(agent_name="agent_a", estimated_cost=3.0)
bid_b = make_bid(agent_name="agent_b", estimated_cost=7.0)
bid_c = make_bid(agent_name="agent_c", estimated_cost=12.0)
result = await house.run_vickrey_auction("task", [bid_a, bid_b, bid_c])
# agent_a is bankrupt, so agent_b wins, pays 12.0
assert result.winner is not None
assert result.winner.agent_name == "agent_b"
# profit = 12.0 - 7.0 = 5.0
assert tracker.get_wealth("agent_b") == 100.0 + 5.0
class TestVickreyCapabilityFiltering:
"""Vickrey 拍卖:能力过滤"""
def test_filter_by_capabilities_basic(self):
house = AuctionHouse()
bids = [
make_bid(agent_name="a", capabilities=["search", "analysis"]),
make_bid(agent_name="b", capabilities=["search"]),
make_bid(agent_name="c", capabilities=["analysis", "coding"]),
]
filtered = house.filter_by_capabilities(bids, ["search"])
assert len(filtered) == 2
assert {b.agent_name for b in filtered} == {"a", "b"}
def test_filter_by_capabilities_requires_all(self):
house = AuctionHouse()
bids = [
make_bid(agent_name="a", capabilities=["search", "analysis"]),
make_bid(agent_name="b", capabilities=["search"]),
make_bid(agent_name="c", capabilities=["analysis", "coding"]),
]
filtered = house.filter_by_capabilities(bids, ["search", "analysis"])
assert len(filtered) == 1
assert filtered[0].agent_name == "a"
def test_filter_by_capabilities_case_insensitive(self):
house = AuctionHouse()
bids = [
make_bid(agent_name="a", capabilities=["Search", "Analysis"]),
]
filtered = house.filter_by_capabilities(bids, ["search", "analysis"])
assert len(filtered) == 1
def test_filter_by_capabilities_no_match(self):
house = AuctionHouse()
bids = [
make_bid(agent_name="a", capabilities=["search"]),
]
filtered = house.filter_by_capabilities(bids, ["coding"])
assert len(filtered) == 0
def test_filter_by_capabilities_empty_requirements(self):
house = AuctionHouse()
bids = [
make_bid(agent_name="a", capabilities=["search"]),
]
filtered = house.filter_by_capabilities(bids, [])
assert len(filtered) == 1
@pytest.mark.asyncio
async def test_vickrey_with_capability_filtering(self):
tracker = WealthTracker()
house = AuctionHouse(wealth_tracker=tracker)
bids = [
make_bid(agent_name="a", estimated_cost=5.0, capabilities=["search", "analysis"]),
make_bid(agent_name="b", estimated_cost=3.0, capabilities=["search"]),
make_bid(agent_name="c", estimated_cost=8.0, capabilities=["search", "analysis"]),
]
# Require both "search" and "analysis" → only a and c eligible
result = await house.run_vickrey_auction("task", bids, required_capabilities=["search", "analysis"])
assert result.winner is not None
assert result.winner.agent_name == "a"
# a wins (cost=5.0), pays second-lowest among eligible = c's cost = 8.0
assert tracker.get_wealth("a") == 100.0 + (8.0 - 5.0)
class TestVickreyBankruptAgent:
"""Vickrey 拍卖:破产 Agent 被排除"""
@pytest.mark.asyncio
async def test_bankrupt_agent_excluded(self):
tracker = WealthTracker()
tracker.penalize("bankrupt_agent", 300.0) # wealth = -200, bankrupt
house = AuctionHouse(wealth_tracker=tracker)
bid_bankrupt = make_bid(agent_name="bankrupt_agent", estimated_cost=1.0)
bid_ok = make_bid(agent_name="ok_agent", estimated_cost=10.0)
result = await house.run_vickrey_auction("task", [bid_bankrupt, bid_ok])
assert result.winner is not None
assert result.winner.agent_name == "ok_agent"
# Only 1 eligible bidder → pays own cost, profit = 0
assert tracker.get_wealth("ok_agent") == 100.0
@pytest.mark.asyncio
async def test_all_bankrupt_returns_none(self):
tracker = WealthTracker()
tracker.penalize("a", 300.0)
tracker.penalize("b", 300.0)
house = AuctionHouse(wealth_tracker=tracker)
bids = [
make_bid(agent_name="a", estimated_cost=1.0),
make_bid(agent_name="b", estimated_cost=2.0),
]
result = await house.run_vickrey_auction("task", bids)
assert result.winner is None
assert "bankrupt" in result.selection_reason.lower() or "eligible" in result.selection_reason.lower()
class TestVickreyNoBidders:
"""Vickrey 拍卖:无竞价者"""
@pytest.mark.asyncio
async def test_no_bidders_returns_none(self):
house = AuctionHouse()
result = await house.run_vickrey_auction("task", [])
assert result.winner is None
assert result.total_bidders == 0
assert result.all_bids == []
@pytest.mark.asyncio
async def test_no_eligible_after_capability_filter(self):
house = AuctionHouse()
bid = make_bid(agent_name="a", estimated_cost=5.0, capabilities=["search"])
result = await house.run_vickrey_auction("task", [bid], required_capabilities=["coding"])
assert result.winner is None
assert result.total_bidders == 1
class TestVickreyWealthTrackerUpdate:
"""Vickrey 拍卖WealthTracker 正确更新"""
@pytest.mark.asyncio
async def test_winner_earns_payment_minus_cost(self):
tracker = WealthTracker(initial_wealth=50.0)
house = AuctionHouse(wealth_tracker=tracker)
bid_a = make_bid(agent_name="a", estimated_cost=4.0)
bid_b = make_bid(agent_name="b", estimated_cost=9.0)
await house.run_vickrey_auction("task", [bid_a, bid_b])
# a wins, pays 9.0, profit = 9.0 - 4.0 = 5.0
assert tracker.get_wealth("a") == 50.0 + 5.0
# b pays nothing
assert tracker.get_wealth("b") == 50.0
@pytest.mark.asyncio
async def test_single_bidder_zero_profit(self):
tracker = WealthTracker(initial_wealth=100.0)
house = AuctionHouse(wealth_tracker=tracker)
# Single bidder: pays own cost, profit = 0
bid = make_bid(agent_name="a", estimated_cost=10.0)
await house.run_vickrey_auction("task", [bid])
assert tracker.get_wealth("a") == 100.0
class TestVickreyBackwardCompat:
"""Vickrey 拍卖:原有 score_bid 方法仍然可用"""
def test_score_bid_still_works(self):
tracker = WealthTracker()
house = AuctionHouse(wealth_tracker=tracker)
bid = make_bid(agent_name="test_agent", confidence=0.9, estimated_cost=5.0)
score = house.score_bid(bid)
wealth_factor = 1.0 + (100.0 / 1000.0)
expected = (0.9 / 5.0) * wealth_factor
assert abs(score - expected) < 0.0001
@pytest.mark.asyncio
async def test_run_auction_still_works(self):
house = AuctionHouse()
bid_low = make_bid(agent_name="low_agent", confidence=0.5, estimated_cost=10.0)
bid_high = make_bid(agent_name="high_agent", confidence=0.9, estimated_cost=10.0)
result = await house.run_auction("do something", [bid_low, bid_high])
assert result.winner is not None
assert result.winner.agent_name == "high_agent"
# ---- Bid Validation 测试 ----
class TestBidValidation:
"""Bid __post_init__ 验证"""
def test_negative_estimated_cost_raises(self):
with pytest.raises(ValueError, match="estimated_cost must be non-negative"):
Bid(
agent_name="a",
architecture="react",
estimated_steps=5,
estimated_cost=-1.0,
confidence=0.8,
payment_offer=1.0,
capabilities=[],
)
def test_confidence_above_one_raises(self):
with pytest.raises(ValueError, match="confidence must be between"):
Bid(
agent_name="a",
architecture="react",
estimated_steps=5,
estimated_cost=10.0,
confidence=1.5,
payment_offer=1.0,
capabilities=[],
)
def test_confidence_below_zero_raises(self):
with pytest.raises(ValueError, match="confidence must be between"):
Bid(
agent_name="a",
architecture="react",
estimated_steps=5,
estimated_cost=10.0,
confidence=-0.1,
payment_offer=1.0,
capabilities=[],
)
def test_negative_payment_offer_raises(self):
with pytest.raises(ValueError, match="payment_offer must be non-negative"):
Bid(
agent_name="a",
architecture="react",
estimated_steps=5,
estimated_cost=10.0,
confidence=0.8,
payment_offer=-5.0,
capabilities=[],
)
def test_zero_values_allowed(self):
bid = Bid(
agent_name="a",
architecture="react",
estimated_steps=5,
estimated_cost=0.0,
confidence=0.0,
payment_offer=0.0,
capabilities=[],
)
assert bid.estimated_cost == 0.0
assert bid.confidence == 0.0
assert bid.payment_offer == 0.0
def test_boundary_confidence_one_allowed(self):
bid = Bid(
agent_name="a",
architecture="react",
estimated_steps=5,
estimated_cost=10.0,
confidence=1.0,
payment_offer=1.0,
capabilities=[],
)
assert bid.confidence == 1.0
# ---- WealthTracker Thread Safety 测试 ----
class TestWealthTrackerThreadSafety:
"""WealthTracker 线程安全"""
def test_concurrent_reward_penalize(self):
tracker = WealthTracker(initial_wealth=1000.0)
errors: list[Exception] = []
def worker(action: str, name: str, amount: float, count: int):
try:
for _ in range(count):
if action == "reward":
tracker.reward(name, amount)
else:
tracker.penalize(name, amount)
except Exception as e:
errors.append(e)
threads = [
threading.Thread(target=worker, args=("reward", "agent_a", 1.0, 100)),
threading.Thread(target=worker, args=("penalize", "agent_a", 1.0, 100)),
threading.Thread(target=worker, args=("reward", "agent_b", 2.0, 50)),
]
for t in threads:
t.start()
for t in threads:
t.join()
assert not errors
# agent_a: 1000 + 100*1.0 - 100*1.0 = 1000
assert tracker.get_wealth("agent_a") == 1000.0
# agent_b: 1000 + 50*2.0 = 1100
assert tracker.get_wealth("agent_b") == 1100.0
# ---- Wealth Factor Lower Bound 测试 ----
class TestWealthFactorLowerBound:
"""get_wealth_factor 下限保护"""
def test_extremely_negative_wealth_clamped(self):
tracker = WealthTracker(initial_wealth=100.0)
tracker.penalize("agent_a", 5000.0)
# wealth = 100 - 5000 = -4900, factor would be 1.0 + (-4900/1000) = -3.9
# But with lower bound: max(0.01, -3.9) = 0.01
factor = tracker.get_wealth_factor("agent_a")
assert factor == 0.01
def test_slightly_negative_wealth_not_clamped(self):
tracker = WealthTracker(initial_wealth=100.0)
tracker.penalize("agent_a", 150.0)
# wealth = -50, factor = 1.0 + (-50/1000) = 0.95
factor = tracker.get_wealth_factor("agent_a")
assert abs(factor - 0.95) < 0.0001

View File

@ -580,9 +580,9 @@ class TestHeuristicClassifierIntegration:
@pytest.mark.asyncio
async def test_heuristic_mode_no_llm_call(self):
"""heuristic 模式不调用 LLM"""
"""heuristic 模式 + merged_llm_classify=False 时不调用 LLM"""
gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic")
router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic", merged_llm_classify=False)
result = await router.route(
content="帮我分析一下数据",
skill_registry=_make_skill_registry(),
@ -590,7 +590,7 @@ class TestHeuristicClassifierIntegration:
default_tools=[],
default_system_prompt="You are helpful.",
)
# LLM gateway.chat 不应被调用
# LLM gateway.chat 不应被调用heuristic + merged disabled
gateway.chat.assert_not_called()
# 复杂度应来自启发式分类器
assert result.complexity > 0.0
@ -629,3 +629,180 @@ class TestHeuristicClassifierIntegration:
"""默认分类器模式为 heuristic"""
router = CostAwareRouter()
assert router._classifier == "heuristic"
# ---------------------------------------------------------------------------
# U1: Merged LLM Classify
# ---------------------------------------------------------------------------
class TestMergedLLMClassify:
"""合并路由 LLM 调用测试"""
@pytest.mark.asyncio
async def test_merged_classify_returns_valid_skill(self):
"""合并调用返回有效 JSON + skill_hint正确路由到指定 skill"""
merged_response = json.dumps({
"complexity": 0.6,
"intent": "code_generation",
"skill_hint": "search",
})
gateway = _make_llm_gateway(merged_response)
search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
registry = _make_skill_registry([search_skill])
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
result = await router.route(
content="帮我搜索一下最新的新闻",
skill_registry=registry,
intent_router=_make_intent_router(),
default_tools=[],
default_system_prompt="You are helpful.",
)
assert result.matched is True
assert result.skill_name == "search"
assert result.match_method == "merged_llm"
assert result.complexity > 0.3
@pytest.mark.asyncio
async def test_merged_classify_malformed_response_fallback(self):
"""合并调用返回格式异常fallback 到默认 Agent"""
gateway = _make_llm_gateway("这不是JSON")
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
result = await router.route(
content="帮我分析一下数据",
skill_registry=_make_skill_registry(),
intent_router=_make_intent_router(),
default_tools=[],
default_system_prompt="You are helpful.",
)
assert result.match_method == "merged_llm_fallback"
assert result.complexity == 0.5
@pytest.mark.asyncio
async def test_merged_classify_low_complexity(self):
"""合并调用返回 complexity < 0.3,走低复杂度路由"""
merged_response = json.dumps({
"complexity": 0.2,
"intent": "greeting",
"skill_hint": None,
})
gateway = _make_llm_gateway(merged_response)
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
result = await router.route(
content="如何使用这个功能?", # heuristic returns ~0.45, triggers merged call
skill_registry=_make_skill_registry(),
intent_router=_make_intent_router(),
default_tools=[],
default_system_prompt="You are helpful.",
)
# Merged LLM returned complexity < 0.3, should route to low complexity
assert result.complexity < 0.3
assert "low" in result.match_method or "merged_llm_low" in result.match_method
@pytest.mark.asyncio
async def test_merged_classify_high_complexity(self):
"""合并调用返回 complexity > 0.7,走 Layer 2"""
merged_response = json.dumps({
"complexity": 0.85,
"intent": "research",
"skill_hint": None,
})
gateway = _make_llm_gateway(merged_response)
org_context = MagicMock()
org_context.find_best_agent = MagicMock(return_value="researcher")
router = CostAwareRouter(
llm_gateway=gateway, model="default",
org_context=org_context, merged_llm_classify=True,
)
result = await router.route(
content="做市场调研+竞品分析",
skill_registry=_make_skill_registry(),
intent_router=_make_intent_router(),
default_tools=[],
default_system_prompt="You are helpful.",
)
assert result.complexity > 0.7
assert result.match_method == "capability"
@pytest.mark.asyncio
async def test_merged_classify_disabled_falls_back_to_intent_router(self):
"""配置 merged_llm_classify=False 时回退到独立 IntentRouter"""
gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
registry = _make_skill_registry([search_skill])
intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
router = CostAwareRouter(
llm_gateway=gateway, model="default",
merged_llm_classify=False,
)
result = await router.route(
content="分析下这个数据",
skill_registry=registry,
intent_router=intent_router,
default_tools=[],
default_system_prompt="You are helpful.",
)
# Should not use merged_llm match_method
assert result.match_method != "merged_llm"
@pytest.mark.asyncio
async def test_merged_classify_no_llm_gateway_falls_back(self):
"""无 LLM Gateway 时 _classify_merged 回退到 IntentRouter"""
search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
registry = _make_skill_registry([search_skill])
router = CostAwareRouter(llm_gateway=None, merged_llm_classify=True)
result = await router.route(
content="分析下这个数据",
skill_registry=registry,
intent_router=_make_intent_router(),
default_tools=[],
default_system_prompt="You are helpful.",
)
# Should not crash, should use IntentRouter fallback
assert result is not None
@pytest.mark.asyncio
async def test_merged_classify_skill_hint_not_found_fallback(self):
"""合并调用返回的 skill_hint 在 registry 中不存在fallback"""
merged_response = json.dumps({
"complexity": 0.5,
"intent": "unknown",
"skill_hint": "nonexistent_skill",
})
gateway = _make_llm_gateway(merged_response)
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
result = await router.route(
content="帮我分析一下数据",
skill_registry=_make_skill_registry(),
intent_router=_make_intent_router(),
default_tools=[],
default_system_prompt="You are helpful.",
)
# Should fallback to default agent (medium complexity, no skill match)
assert result.matched is False
assert result.match_method == "merged_llm_medium"
@pytest.mark.asyncio
async def test_merged_classify_only_one_llm_call(self):
"""合并调用模式下,中等复杂度只产生 1 次 LLM 调用"""
merged_response = json.dumps({
"complexity": 0.5,
"intent": "question",
"skill_hint": None,
})
gateway = _make_llm_gateway(merged_response)
router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
await router.route(
content="如何使用这个功能?",
skill_registry=_make_skill_registry(),
intent_router=_make_intent_router(),
default_tools=[],
default_system_prompt="You are helpful.",
)
# Only 1 LLM call should have been made (the merged classify)
assert gateway.chat.call_count == 1

View File

@ -995,3 +995,153 @@ class TestReWOOProgressiveFallback:
assert result.output == "ReAct answer with tool"
assert result.fallback_strategy == "react"
# ── Test: Configurable Fallback Strategies ──────────────────
class TestReWOOConfigurableFallback:
"""ReWOO 回退链配置化测试"""
async def test_default_fallback_strategies(self):
"""默认回退链simplified_rewoo → react → direct"""
from agentkit.core.rewoo import ReWOOEngine
engine = ReWOOEngine(llm_gateway=MagicMock(spec=LLMGateway))
assert engine._fallback_strategies == ["simplified_rewoo", "react", "direct"]
async def test_custom_fallback_strategies(self):
"""自定义回退链plan_exec → react → direct"""
from agentkit.core.rewoo import ReWOOEngine
engine = ReWOOEngine(
llm_gateway=MagicMock(spec=LLMGateway),
fallback_strategies=["plan_exec", "react", "direct"],
)
assert engine._fallback_strategies == ["plan_exec", "react", "direct"]
async def test_invalid_strategy_name_skipped(self):
"""无效策略名跳过并警告"""
from agentkit.core.rewoo import ReWOOEngine
engine = ReWOOEngine(
llm_gateway=MagicMock(spec=LLMGateway),
fallback_strategies=["invalid_strategy", "react", "direct"],
)
assert engine._fallback_strategies == ["react", "direct"]
async def test_empty_fallback_strategies_uses_defaults(self):
"""空回退链回退到默认"""
from agentkit.core.rewoo import ReWOOEngine
engine = ReWOOEngine(
llm_gateway=MagicMock(spec=LLMGateway),
fallback_strategies=[],
)
assert engine._fallback_strategies == ["simplified_rewoo", "react", "direct"]
async def test_all_invalid_strategies_uses_defaults(self):
"""全部无效策略名回退到默认"""
from agentkit.core.rewoo import ReWOOEngine
engine = ReWOOEngine(
llm_gateway=MagicMock(spec=LLMGateway),
fallback_strategies=["foo", "bar"],
)
assert engine._fallback_strategies == ["simplified_rewoo", "react", "direct"]
async def test_custom_fallback_plan_exec_first(self):
"""自定义回退链plan_exec 优先,规划失败时先走 plan_exec 再走 react"""
from agentkit.core.rewoo import ReWOOEngine
# Planning fails, plan_exec also fails, react succeeds
invalid_plan = make_response(content="Not a plan")
plan_exec_fail = make_response(content="Still not a plan")
react_response = make_response(content="React answer")
gateway = make_mock_gateway([invalid_plan, plan_exec_fail, react_response])
engine = ReWOOEngine(
llm_gateway=gateway,
fallback_strategies=["plan_exec", "react"],
)
result = await engine.execute(
messages=[{"role": "user", "content": "Task"}],
)
assert result.output == "React answer"
assert result.fallback_strategy == "react"
async def test_custom_fallback_direct_only(self):
"""自定义回退链:仅 direct跳过 simplified_rewoo 和 react"""
from agentkit.core.rewoo import ReWOOEngine
# Planning fails, direct succeeds
invalid_plan = make_response(content="Not a plan")
direct_response = make_response(content="Direct answer")
gateway = make_mock_gateway([invalid_plan, direct_response])
engine = ReWOOEngine(
llm_gateway=gateway,
fallback_strategies=["direct"],
)
result = await engine.execute(
messages=[{"role": "user", "content": "Task"}],
)
assert result.output == "Direct answer"
assert result.fallback_strategy == "direct"
async def test_valid_strategies_constant(self):
"""验证 VALID_STRATEGIES 集合"""
from agentkit.core.rewoo import ReWOOEngine
assert ReWOOEngine.VALID_STRATEGIES == {"simplified_rewoo", "react", "direct", "plan_exec"}
async def test_stream_custom_fallback_react_only(self):
"""流式模式:自定义回退链仅 react"""
from agentkit.core.rewoo import ReWOOEngine
# Planning fails, react succeeds
invalid_plan = make_response(content="Not a plan")
react_response = make_response(content="React stream answer")
gateway = make_mock_gateway([invalid_plan, react_response])
engine = ReWOOEngine(
llm_gateway=gateway,
fallback_strategies=["react"],
)
events = []
async for event in engine.execute_stream(
messages=[{"role": "user", "content": "Task"}],
):
events.append(event)
event_types = [e.event_type for e in events]
assert "final_answer" in event_types
async def test_all_fallback_exhausted_raises(self):
"""所有回退策略耗尽时抛出 RuntimeError"""
from agentkit.core.rewoo import ReWOOEngine
# Planning fails, all fallbacks fail
call_count = 0
async def always_fail(**kwargs):
nonlocal call_count
call_count += 1
raise RuntimeError("LLM unavailable")
gateway = MagicMock(spec=LLMGateway)
gateway.chat = AsyncMock(side_effect=always_fail)
engine = ReWOOEngine(
llm_gateway=gateway,
fallback_strategies=["react", "direct"],
)
with pytest.raises(RuntimeError, match="All ReWOO fallback strategies exhausted"):
await engine.execute(
messages=[{"role": "user", "content": "Task"}],
)