diff --git a/README.md b/README.md index 22d75c8..4480168 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,15 @@ AgentKit 解决的核心问题:**从写 150 行 Agent 代码降为 10-20 行 YAML 配置**。 -传统方式下,每新增一个 Agent 需要编写子类、处理 LLM 调用、管理工具绑定、校验输出质量。AgentKit 将这些能力标准化为 6 个可组合模块,开发者只需编写 YAML 配置即可定义一个完整的 Skill(Prompt + Tool + 质量门禁),框架自动完成 ReAct 推理循环、模型路由降级、产出质量检查和标准化输出。 +传统方式下,每新增一个 Agent 需要编写子类、处理 LLM 调用、管理工具绑定、校验输出质量。AgentKit 将这些能力标准化为 8 个可组合模块,开发者只需编写 YAML 配置即可定义一个完整的 Skill(Prompt + Tool + 质量门禁),框架自动完成 ReAct 推理循环、模型路由降级、产出质量检查和标准化输出。 核心定位: - **配置驱动** -- YAML 定义 Skill,无需写 Agent 子类 - **生产就绪** -- 内置质量门禁、模型降级、用量统计 -- **两种部署** -- Python 库直接引用,或 FastAPI 独立部署 +- **三种使用** -- Python 库引用、CLI 聊天、Web GUI 界面 +- **工具丰富** -- 内置 Shell、搜索、爬虫、记忆等工具,支持 MCP 扩展 +- **Pipeline 编排** -- 多 Agent 协同、Saga 补偿、动态流水线 ## 核心特性 @@ -22,7 +24,7 @@ Think -> Act -> Observe 循环。LLM 自主决定是否调用工具、调用哪 ### 2. LLM Gateway -统一 LLM 调用入口。Provider 注册、模型别名解析(如 `deepseek` -> `deepseek/deepseek-chat`)、Fallback 降级策略、Token 用量和成本追踪。 +统一 LLM 调用入口。Provider 注册、模型别名解析(如 `default` -> `dashscope/qwen3-coder-plus`)、Fallback 降级策略、Token 用量和成本追踪。支持百炼 DashScope、OpenAI、DeepSeek 等 OpenAI 兼容 API。 ### 3. Skill 系统 @@ -40,73 +42,102 @@ Skill = SkillConfig + 绑定 Tools。一个 Skill 代表一个可执行技能, Schema 验证 + 字段类型归一化(str -> int/float/bool)+ 元数据附加(version、produced_at、quality_score)。所有 Skill 产出统一为 StandardOutput 格式。 +### 7. 内置工具集 + +开箱即用的工具插件,覆盖常见 Agent 需求: + +| 工具 | 说明 | +|------|------| +| `ShellTool` | 执行 Shell 命令,白名单安全机制 + 用户确认 | +| `WebSearchTool` | DuckDuckGo / Bing 网页搜索 | +| `BaiduSearchTool` | 百度搜索 | +| `WebCrawlTool` | 网页抓取与内容提取 | +| `MemoryTool` | 短期/长期记忆管理 | +| `AskHumanTool` | 向用户提问获取信息 | +| `SchemaExtractTool` | 从文本提取结构化数据 | +| `SchemaGenerateTool` | 生成 JSON Schema | +| `MCPTool` | MCP 协议工具扩展 | + +工具组合:`SequentialChain`(顺序链)、`ParallelFanOut`(并行扇出)、`DynamicSelector`(动态选择)。 + +### 8. Pipeline 编排 + +多 Agent 协同编排,支持复杂工作流: + +- **PipelineEngine** -- 阶段式流水线执行,支持自适应配置 +- **SagaOrchestrator** -- 分布式事务补偿,失败自动回滚 +- **DynamicPipeline** -- 运行时动态调整流水线结构 +- **PipelineReflector** -- 执行反思与重规划 +- **HandoffManager** -- Agent 间任务移交 + ## 架构图 ``` - +------------------+ - | User Request | - +--------+---------+ - | - v - +-------------+--------------+ - | IntentRouter | - | (keyword -> LLM classify) | - +-------------+--------------+ - | - matched_skill - | - v - +-------------+--------------+ - | ConfigDrivenAgent | - | (SkillConfig-driven) | - +-------------+--------------+ - | - +------------+------------+ - | | - v v - +---------+--------+ +----------+---------+ - | ReActEngine | | Traditional Mode | - | Think->Act->Observe| | llm_generate/ | - +---------+--------+ | tool_call/custom | - | +--------------------+ - v - +----------+----------+ - | LLM Gateway | - | resolve -> chat | - | fallback -> track | - +----------+----------+ - | - +------+------+ - | | - v v - +-----+----+ +-----+-----+ - | Provider A| | Provider B| ... - +-----+----+ +-----+-----+ - | | - v v - +-----+----+ +-----+-----+ - | Tool 1 | | Tool 2 | ... - +-----------+ +-----------+ - - | - v - +----------+----------+ - | Quality Gate | - | required_fields | - | min_word_count | - | schema validation | - | custom validator | - +----------+----------+ - | - v - +----------+----------+ - | OutputStandardizer | - | schema + normalize | - | + metadata | - +----------+----------+ - | - v - StandardOutput + +-------------------+ +-------------------+ + | Web GUI Chat | | CLI Chat | + | (WebSocket) | | (agentkit chat) | + +--------+----------+ +--------+----------+ + | | + +----------+----------+ + | + +----------v----------+ + | Skill Routing | + | (keyword -> LLM) | + +----------+----------+ + | + matched_skill + | + +-------------------v-------------------+ + | ConfigDrivenAgent | + | (SkillConfig-driven) | + +-------------------+------------------+ + | + +--------------+--------------+ + | | + v v + +---------+--------+ +----------+---------+ + | ReActEngine | | Traditional Mode | + | Think->Act->Observe| | llm_generate/ | + +---------+--------+ | tool_call/custom | + | +---------------------+ + v + +----------+----------+ + | LLM Gateway | + | resolve -> chat | + | fallback -> track | + +----------+----------+ + | + +------+------+ + | | + v v + +-----+----+ +-----+-----+ + | DashScope | | OpenAI | ... + +-----+----+ +-----+-----+ + | + +----------+----------+ + | Tool Registry | + | shell / search / | + | crawl / memory / ... | + +----------+----------+ + | + v + +----------+----------+ + | Quality Gate | + | required_fields | + | min_word_count | + | schema validation | + | custom validator | + +----------+----------+ + | + v + +----------+----------+ + | OutputStandardizer | + | schema + normalize | + | + metadata | + +----------+----------+ + | + v + StandardOutput ``` ## 快速开始 @@ -147,7 +178,13 @@ agentkit version # 初始化项目(生成配置文件) agentkit init -# 启动 Server +# 启动 Web GUI 聊天界面(推荐) +agentkit gui --port 8002 + +# 启动 CLI 聊天 +agentkit chat + +# 启动 Server(API 模式) agentkit serve --host 0.0.0.0 --port 8001 # 健康检查 @@ -247,9 +284,9 @@ from agentkit.llm.providers.openai import OpenAIProvider async def main(): # 1. 初始化 LLM Gateway gateway = LLMGateway() - gateway.register_provider("openai", OpenAIProvider( + gateway.register_provider("dashscope", OpenAIProvider( api_key="sk-xxx", - base_url="https://api.openai.com/v1", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", )) # 2. 定义 Skill @@ -263,7 +300,7 @@ async def main(): "instructions": "根据用户需求生成高质量内容", "output_format": "以 JSON 格式输出", }, - llm={"model": "openai/gpt-4o", "temperature": 0.7}, + llm={"model": "default", "temperature": 0.7}, execution_mode="react", max_steps=5, ) @@ -318,9 +355,9 @@ from agentkit import LLMGateway from agentkit.llm.providers.openai import OpenAIProvider gateway = LLMGateway() -gateway.register_provider("openai", OpenAIProvider( +gateway.register_provider("dashscope", OpenAIProvider( api_key="sk-xxx", - base_url="https://api.openai.com/v1", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", )) app = create_app(llm_gateway=gateway) @@ -352,8 +389,8 @@ from datetime import datetime, timezone async def main(): # 初始化 Gateway gateway = LLMGateway() - gateway.register_provider("openai", OpenAIProvider( - api_key="sk-xxx", base_url="https://api.openai.com/v1", + gateway.register_provider("dashscope", OpenAIProvider( + api_key="sk-xxx", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", )) # 定义多个 Skill @@ -366,7 +403,7 @@ async def main(): "instructions": "生成 SEO 优化内容", "output_format": "JSON: {content, word_count}", }, - llm={"model": "openai/gpt-4o"}, + llm={"model": "default"}, intent={ "keywords": ["生成", "内容", "写作"], "description": "内容生成与写作", @@ -390,7 +427,7 @@ async def main(): "instructions": "优化内容以提升 AI 搜索可见性", "output_format": "JSON: {optimized_content, seo_score, changes}", }, - llm={"model": "openai/gpt-4o"}, + llm={"model": "default"}, intent={ "keywords": ["优化", "GEO", "SEO"], "description": "内容 GEO/SEO 优化", @@ -472,7 +509,7 @@ curl -X POST http://localhost:8000/api/v1/skills \ "instructions": "生成高质量内容", "output_format": "JSON: {content, word_count}" }, - "llm": {"model": "openai/gpt-4o"}, + "llm": {"model": "default"}, "intent": { "keywords": ["生成", "内容"], "description": "内容生成" @@ -546,7 +583,7 @@ async def main(): "instructions": "生成高质量内容", "output_format": "JSON: {content, word_count}", }, - "llm": {"model": "openai/gpt-4o"}, + "llm": {"model": "default"}, "intent": {"keywords": ["生成", "内容"], "description": "内容生成"}, "quality_gate": {"required_fields": ["content"], "max_retries": 2}, "execution_mode": "react", @@ -621,7 +658,7 @@ prompt: output_format: "JSON: generate_topics 返回 {topics: [{title, reason, keywords}]},generate_article 返回 {content, word_count}" llm: - model: "deepseek" + model: "default" temperature: 0.7 max_tokens: 4000 @@ -664,34 +701,30 @@ skill = Skill(config=config) ```yaml providers: + dashscope: + api_key: "${DASHSCOPE_API_KEY}" + base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1" + models: + qwen3-coder-plus: + max_tokens: 64000 + cost_per_1k_input: 0.00014 + cost_per_1k_output: 0.00028 openai: - api_key: "sk-xxx" + api_key: "${OPENAI_API_KEY}" base_url: "https://api.openai.com/v1" models: gpt-4o: cost_per_1k_input: 0.005 cost_per_1k_output: 0.015 - gpt-4o-mini: - cost_per_1k_input: 0.00015 - cost_per_1k_output: 0.0006 - deepseek: - api_key: "sk-xxx" - base_url: "https://api.deepseek.com/v1" - models: - deepseek-chat: - cost_per_1k_input: 0.001 - cost_per_1k_output: 0.002 model_aliases: - default: "deepseek/deepseek-chat" - fast: "openai/gpt-4o-mini" + default: "dashscope/qwen3-coder-plus" + fast: "dashscope/qwen3-coder-plus" powerful: "openai/gpt-4o" fallbacks: - openai/gpt-4o: - - "deepseek/deepseek-chat" - deepseek/deepseek-chat: - - "openai/gpt-4o-mini" + dashscope/qwen3-coder-plus: + - "openai/gpt-4o" ``` 加载 LLM 配置: @@ -802,7 +835,7 @@ ReActEngine 实现 Think -> Act -> Observe 循环: 统一 LLM 调用入口,核心能力: - **Provider 注册**: `gateway.register_provider("openai", provider)` -- **模型别名**: `"default"` -> `"deepseek/deepseek-chat"` +- **模型别名**: `"default"` -> `"dashscope/qwen3-coder-plus"` - **Fallback 降级**: 主模型失败时自动切换到备选模型 - **用量追踪**: 按 agent_name、model 统计 Token 用量和成本 - **模型解析**: `"provider/model"` 格式自动路由到对应 Provider @@ -878,10 +911,12 @@ v2 增强:接受 SkillConfig 时自动创建 Skill 并启用 ReAct 模式,Qu ### server -- FastAPI Server -独立部署模式,提供 RESTful API: +独立部署模式,提供 RESTful API 和 Web GUI: | 路径 | 方法 | 说明 | |------|------|------| +| `/` | GET | Web GUI 聊天界面 | +| `/ws/chat` | WebSocket | GUI 实时聊天通道 | | `/api/v1/agents` | POST | 创建 Agent(指定 skill_name 或 config) | | `/api/v1/agents` | GET | 列出所有 Agent | | `/api/v1/agents/{name}` | GET | 获取 Agent 详情 | @@ -892,6 +927,27 @@ v2 增强:接受 SkillConfig 时自动创建 Skill 并启用 ReAct 模式,Qu | `/api/v1/llm/usage` | GET | 查询 LLM 用量 | | `/api/v1/health` | GET | 健康检查 | +### Web GUI 聊天界面 + +通过 `agentkit gui` 启动,特性: + +- **实时对话** -- WebSocket 流式传输,逐 token 显示 +- **Markdown 渲染** -- 自动检测并渲染标题、列表、代码块、表格等 +- **工具确认卡片** -- 危险命令(如 `rm`)执行前弹出确认卡片,用户批准后才执行 +- **Loading 动画** -- 等待 AI 响应时显示思考动画 +- **Skill 路由** -- 输入 `@skill_name:` 前缀可指定使用特定 Skill +- **会话管理** -- 多会话并行,历史记录持久化 + +### orchestrator -- Pipeline 编排 + +多 Agent 协同编排模块: + +- **PipelineEngine** -- 按 Stage 定义顺序执行,支持自适应配置和反思重规划 +- **SagaOrchestrator** -- 分布式事务补偿,失败步骤自动执行补偿操作 +- **DynamicPipeline** -- 运行时根据条件动态调整流水线结构 +- **HandoffManager** -- Agent 间任务移交,支持上下文传递 +- **PipelineStateMemory/Redis/PG** -- 流水线状态持久化,支持内存、Redis、PostgreSQL 后端 + ## 配置参考 ### SkillConfig @@ -941,8 +997,8 @@ v2 增强:接受 SkillConfig 时自动创建 Skill 并启用 ReAct 模式,Qu | 字段 | 类型 | 默认值 | 说明 | |------|------|--------|------| | `providers` | dict[str, ProviderConfig] | `{}` | Provider 配置,key 为 provider 名称 | -| `model_aliases` | dict[str, str] | `{}` | 模型别名映射,如 `default: "deepseek/deepseek-chat"` | -| `fallbacks` | dict[str, list[str]] | `{}` | 降级策略,如 `openai/gpt-4o: ["deepseek/deepseek-chat"]` | +| `model_aliases` | dict[str, str] | `{}` | 模型别名映射,如 `default: "dashscope/qwen3-coder-plus"` | +| `fallbacks` | dict[str, list[str]] | `{}` | 降级策略,如 `dashscope/qwen3-coder-plus: ["openai/gpt-4o"]` | #### ProviderConfig @@ -977,9 +1033,9 @@ from agentkit import LLMGateway from agentkit.llm.providers.openai import OpenAIProvider gateway = LLMGateway() -gateway.register_provider("deepseek", OpenAIProvider( +gateway.register_provider("dashscope", OpenAIProvider( api_key="sk-xxx", - base_url="https://api.deepseek.com/v1", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", )) app = create_app(llm_gateway=gateway) @@ -1135,6 +1191,8 @@ tools: - search ``` +**ShellTool 安全机制**:ShellTool 内置白名单(`ls`、`cat`、`curl` 等安全命令直接执行),非白名单命令会触发用户确认。在 GUI 中以确认卡片形式展示,用户点击"确认执行"后才运行。 + ### 代码风格 项目使用 Ruff 进行代码检查和格式化: diff --git a/agentkit.yaml b/agentkit.yaml index a5b0795..46b1ebb 100644 --- a/agentkit.yaml +++ b/agentkit.yaml @@ -40,3 +40,4 @@ logging: format: text router: classifier: heuristic + auction_enabled: false diff --git a/configs/llm_config.yaml b/configs/llm_config.yaml index c789dbf..ebee7cc 100644 --- a/configs/llm_config.yaml +++ b/configs/llm_config.yaml @@ -2,18 +2,9 @@ # 环境变量替换:${VAR_NAME} 在启动时由 LLMConfig.from_yaml() 处理 providers: - deepseek: - api_key: "${DEEPSEEK_API_KEY}" - base_url: "https://api.deepseek.com/v1" - models: - deepseek-chat: - max_tokens: 64000 - cost_per_1k_input: 0.00014 - cost_per_1k_output: 0.00028 - - openai: - api_key: "${OPENAI_API_KEY}" - base_url: "${OPENAI_BASE_URL:-https://coding.dashscope.aliyuncs.com/v1}" + dashscope: + api_key: "${DASHSCOPE_API_KEY}" + base_url: "${DASHSCOPE_BASE_URL:-https://dashscope.aliyuncs.com/compatible-mode/v1}" models: qwen3-coder-plus: max_tokens: 64000 @@ -21,13 +12,9 @@ providers: cost_per_1k_output: 0.00028 model_aliases: - default: "openai/qwen3-coder-plus" - fast: "openai/qwen3-coder-plus" - powerful: "openai/qwen3-coder-plus" - -fallbacks: - openai/qwen3-coder-plus: - - "deepseek/deepseek-chat" + default: "dashscope/qwen3-coder-plus" + fast: "dashscope/qwen3-coder-plus" + powerful: "dashscope/qwen3-coder-plus" # 上下文压缩配置 — 长会话自动压缩历史消息,保持 Token 在预算内 # GEO Pipeline 启用后,工具输出(搜索结果、网页抓取等)会自动压缩 diff --git a/configs/skills/rewoo_agent.yaml b/configs/skills/rewoo_agent.yaml index 874ffd5..08c5508 100644 --- a/configs/skills/rewoo_agent.yaml +++ b/configs/skills/rewoo_agent.yaml @@ -6,6 +6,10 @@ task_mode: llm_generate execution_mode: rewoo max_steps: 8 max_concurrency: 3 +fallback_strategies: + - simplified_rewoo + - react + - direct intent: keywords: ["采集", "批量", "并行", "fetch", "collect", "数据获取", "多源"] diff --git a/src/agentkit/chat/skill_routing.py b/src/agentkit/chat/skill_routing.py index c8e6e09..74188dd 100644 --- a/src/agentkit/chat/skill_routing.py +++ b/src/agentkit/chat/skill_routing.py @@ -12,6 +12,7 @@ import re from dataclasses import dataclass, field from typing import Any +from agentkit.marketplace.auction import AuctionHouse, Bid from agentkit.telemetry.tracer import get_tracer logger = logging.getLogger(__name__) @@ -256,6 +257,8 @@ _CHAT_MODE_RE = re.compile( re.IGNORECASE, ) +_SENTENCE_SPLIT_RE = re.compile(r'[,。!?;\n,.!?;]') + def _tokenize_content(content: str) -> list[str]: """Tokenize content for capability matching. Supports Chinese and English.""" @@ -394,7 +397,7 @@ class HeuristicClassifier: score += 0.05 # 3. 多句加成(逗号/句号/换行分隔) - sentence_count = len(re.split(r'[,。!?;\n,.!?;]', content)) + sentence_count = len(_SENTENCE_SPLIT_RE.split(content)) if sentence_count >= 4: score += 0.1 elif sentence_count >= 2: @@ -424,12 +427,15 @@ class CostAwareRouter: org_context: Any = None, auction_enabled: bool = False, classifier: str = "heuristic", + merged_llm_classify: bool = True, ): self._llm_gateway = llm_gateway self._model = model self._org_context = org_context self._auction_enabled = auction_enabled self._classifier = classifier + self._merged_llm_classify = merged_llm_classify + self._auction_house = AuctionHouse() if auction_enabled else None if classifier not in ("heuristic", "llm"): raise ValueError(f"Invalid classifier: {classifier!r}, must be 'heuristic' or 'llm'") self._heuristic = HeuristicClassifier() @@ -489,6 +495,175 @@ class CostAwareRouter: logger.warning(f"CostAwareRouter quick_classify failed: {e}") return 0.5 + # -- Layer 1.5: Merged LLM classify (complexity + intent in one call) --- + + async def _classify_merged( + self, + content: str, + skill_registry: Any, + intent_router: Any, + default_tools: list, + default_system_prompt: str | None, + default_model: str, + default_agent_name: str, + agent_tool_registry: Any = None, + session_id: str = "", + complexity: float = 0.5, + ) -> SkillRoutingResult: + """合并 LLM 调用:单次 LLM 同时输出 complexity + intent + skill_hint。 + + 当 HeuristicClassifier 返回不确定区间 (0.3-0.7) 时使用, + 替代分别调用 quick_classify() 和 IntentRouter._classify_with_llm(), + 节省 1 次 LLM 调用 (~1-3s)。 + """ + if self._llm_gateway is None or not self._merged_llm_classify: + # Fallback: 使用独立的 IntentRouter 路由 + return await resolve_skill_routing( + content=content, + skill_registry=skill_registry, + intent_router=intent_router, + default_tools=default_tools, + default_system_prompt=default_system_prompt, + default_model=default_model, + default_agent_name=default_agent_name, + agent_tool_registry=agent_tool_registry, + session_id=session_id, + ) + + # Build skill list for the prompt + skill_hints = [] + if skill_registry: + try: + for s in skill_registry.list_skills(): + if s.config.intent and s.config.intent.keywords: + skill_hints.append(s.name) + except Exception: + pass + + skill_list_str = ", ".join(skill_hints) if skill_hints else "none" + + prompt = ( + 'You are a routing classifier. Analyze the user request and output:\n' + '1. complexity (0.0-1.0): how complex is this request\n' + '2. intent: the primary intent category\n' + '3. skill_hint: the best matching skill name, or null if none match\n\n' + f'Available skills: [{skill_list_str}]\n\n' + '---BEGIN USER REQUEST---\n' + f'{content}\n' + '---END USER REQUEST---\n\n' + 'Respond ONLY with a JSON object: {"complexity": , "intent": , "skill_hint": }' + ) + + try: + response = await self._llm_gateway.chat( + messages=[{"role": "user", "content": prompt}], + model=self._model, + ) + data = json.loads(response.content.strip()) + merged_complexity = float(data.get("complexity", 0.5)) + merged_complexity = max(0.0, min(1.0, merged_complexity)) + skill_hint = data.get("skill_hint") + + # If skill_hint provided and valid, route directly to that skill + if skill_hint and skill_registry: + try: + matched_skill = skill_registry.get(skill_hint) + result = SkillRoutingResult( + clean_content=content, + skill_name=skill_hint, + skill_config=matched_skill.config, + skill_tools=matched_skill.tools or [], + matched=True, + match_method="merged_llm", + match_confidence=0.7, + complexity=merged_complexity, + ) + # Merge tools + agent_tools = agent_tool_registry.list_tools() if agent_tool_registry else default_tools + seen_names = set() + merged_tools = [] + for tool in result.skill_tools + agent_tools: + if tool.name not in seen_names: + seen_names.add(tool.name) + merged_tools.append(tool) + result.tools = merged_tools + result.model = result.skill_config.llm.get("model", default_model) if result.skill_config.llm else default_model + result.agent_name = skill_hint + result.system_prompt = build_skill_system_prompt(result.skill_config) or default_system_prompt + logger.info( + f"Session {session_id}: merged LLM classify routed to skill '{skill_hint}' " + f"(complexity={merged_complexity:.2f})" + ) + return result + except Exception as e: + logger.warning(f"Session {session_id}: merged LLM skill_hint '{skill_hint}' not found: {e}") + + # No valid skill_hint — use complexity to decide routing + if merged_complexity < 0.3: + return SkillRoutingResult( + clean_content=content, + system_prompt=default_system_prompt, + tools=default_tools, + model=default_model, + agent_name=default_agent_name, + matched=False, + match_method="merged_llm_low", + match_confidence=1.0 - merged_complexity, + complexity=merged_complexity, + ) + elif merged_complexity > 0.7: + # High complexity — delegate to Layer 2 + return SkillRoutingResult( + clean_content=content, + system_prompt=default_system_prompt, + tools=default_tools, + model=default_model, + agent_name=default_agent_name, + matched=False, + match_method="merged_llm_high", + match_confidence=merged_complexity, + complexity=merged_complexity, + ) + else: + # Medium complexity, no skill match — default agent + return SkillRoutingResult( + clean_content=content, + system_prompt=default_system_prompt, + tools=default_tools, + model=default_model, + agent_name=default_agent_name, + matched=False, + match_method="merged_llm_medium", + match_confidence=0.5, + complexity=merged_complexity, + ) + except (json.JSONDecodeError, TypeError, ValueError) as e: + logger.warning(f"CostAwareRouter _classify_merged parse failed: {e}, falling back to default") + return SkillRoutingResult( + clean_content=content, + system_prompt=default_system_prompt, + tools=default_tools, + model=default_model, + agent_name=default_agent_name, + matched=False, + match_method="merged_llm_fallback", + match_confidence=0.5, + complexity=0.5, + ) + except Exception as e: + logger.warning(f"CostAwareRouter _classify_merged failed: {e}, falling back to default") + return SkillRoutingResult( + clean_content=content, + system_prompt=default_system_prompt, + tools=default_tools, + model=default_model, + agent_name=default_agent_name, + matched=False, + match_method="merged_llm_fallback", + match_confidence=0.5, + complexity=0.5, + ) + # -- Layer 2: Capability matching / Auction (optional) ----------------- async def _route_layer2( @@ -505,12 +680,89 @@ class CostAwareRouter: complexity: float = 0.0, trace: list[dict] | None = None, ) -> SkillRoutingResult: - """Layer 2: 高复杂度任务通过 org_context.find_best_agent 路由。""" + """Layer 2: 高复杂度任务通过拍卖或 org_context.find_best_agent 路由。""" + # Extract capability-like keywords from content for matching + content_words = _tokenize_content(content) + + # --- Vickrey auction path (when enabled) --- + if self._auction_enabled and self._auction_house is not None and self._org_context is not None: + try: + # Gather candidate agents from org_context + all_agents = self._org_context.list_agents() if hasattr(self._org_context, "list_agents") else [] + # Filter agents that have at least one relevant capability + candidate_agents = [] + for agent_profile in all_agents: + if not agent_profile.availability: + continue + # Check if agent has any of the content_words as capabilities + agent_caps_lower = {c.lower() for c in agent_profile.capabilities} + if any(w.lower() in agent_caps_lower for w in content_words): + candidate_agents.append(agent_profile) + + # Also include agents that match via find_best_agent (they have ALL required caps) + best = self._org_context.find_best_agent(required_capabilities=content_words) + if best is not None: + best_name = best if isinstance(best, str) else getattr(best, "name", str(best)) + existing_names = {a.name for a in candidate_agents} + if best_name not in existing_names: + profile = self._org_context.get_agent_profile(best_name) if hasattr(self._org_context, "get_agent_profile") else best + if hasattr(profile, "name"): + candidate_agents.append(profile) + + if len(candidate_agents) >= 1: + # Build Bid objects for each candidate + bids = [] + for agent_profile in candidate_agents: + name = agent_profile.name if hasattr(agent_profile, "name") else str(agent_profile) + caps = agent_profile.capabilities if hasattr(agent_profile, "capabilities") else [] + arch = agent_profile.agent_type if hasattr(agent_profile, "agent_type") else "react" + # Use current_load as a proxy for estimated_cost (higher load → higher cost) + estimated_cost = float(agent_profile.current_load + 1) if hasattr(agent_profile, "current_load") else 1.0 + bids.append(Bid( + agent_name=name, + architecture=arch, + estimated_steps=1, + estimated_cost=estimated_cost, + confidence=0.8, + payment_offer=estimated_cost, + capabilities=caps, + )) + + auction_result = await self._auction_house.run_vickrey_auction( + task_description=content, + bidders=bids, + required_capabilities=content_words, + ) + + if auction_result.winner is not None: + winner_name = auction_result.winner.agent_name + result = SkillRoutingResult( + clean_content=content, + matched=True, + match_method="vickrey_auction", + match_confidence=0.8, + agent_name=winner_name, + model=default_model, + system_prompt=default_system_prompt, + tools=default_tools, + complexity=complexity, + ) + if trace is not None: + trace.append({ + "layer": 2, + "method": "vickrey_auction", + "agent_name": winner_name, + "complexity": complexity, + "selection_reason": auction_result.selection_reason, + }) + return result + # No winner from auction → fall through to capability matching + except Exception as e: + logger.warning(f"CostAwareRouter Layer 2 Vickrey auction failed: {e}") + + # --- Capability matching path (default) --- if self._org_context is not None and hasattr(self._org_context, "find_best_agent"): try: - # Extract capability-like keywords from content for matching - # find_best_agent expects list[str] of required capabilities - content_words = _tokenize_content(content) best_agent = self._org_context.find_best_agent(required_capabilities=content_words) if best_agent is not None: agent_name = best_agent if isinstance(best_agent, str) else getattr(best_agent, "name", str(best_agent)) @@ -689,29 +941,69 @@ class CostAwareRouter: span.set_attribute("route.target", "default") return result - # Medium complexity → IntentRouter via resolve_skill_routing + # Medium complexity → merged LLM classify or IntentRouter if complexity <= 0.7: - result = await resolve_skill_routing( - content=content, - skill_registry=skill_registry, - intent_router=intent_router, - default_tools=default_tools, - default_system_prompt=default_system_prompt, - default_model=default_model, - default_agent_name=default_agent_name, - agent_tool_registry=agent_tool_registry, - session_id=session_id, - ) - result.complexity = complexity + if self._merged_llm_classify and self._llm_gateway is not None: + # Use merged LLM call: complexity + intent in one call + result = await self._classify_merged( + content=content, + skill_registry=skill_registry, + intent_router=intent_router, + default_tools=default_tools, + default_system_prompt=default_system_prompt, + default_model=default_model, + default_agent_name=default_agent_name, + agent_tool_registry=agent_tool_registry, + session_id=session_id, + complexity=complexity, + ) + # If merged classify returned high complexity, delegate to Layer 2 + if result.complexity > 0.7 and result.match_method and result.match_method.startswith("merged_llm_high"): + trace.append({ + "layer": 1, + "method": "merged_llm_high", + "complexity": result.complexity, + "delegated_to_layer2": True, + }) + layer2_result = await self._route_layer2( + content=content, + skill_registry=skill_registry, + intent_router=intent_router, + default_tools=default_tools, + default_system_prompt=default_system_prompt, + default_model=default_model, + default_agent_name=default_agent_name, + agent_tool_registry=agent_tool_registry, + session_id=session_id, + complexity=result.complexity, + trace=trace, + ) + layer2_result.execution_trace = trace if transparency != "SILENT" else [] + layer2_result.transparency_level = transparency + return layer2_result + else: + # Fallback: use separate IntentRouter + result = await resolve_skill_routing( + content=content, + skill_registry=skill_registry, + intent_router=intent_router, + default_tools=default_tools, + default_system_prompt=default_system_prompt, + default_model=default_model, + default_agent_name=default_agent_name, + agent_tool_registry=agent_tool_registry, + session_id=session_id, + ) + result.complexity = result.complexity or complexity trace.append({ "layer": 1, - "method": "intent_router", - "complexity": complexity, + "method": result.match_method or "merged_llm", + "complexity": result.complexity, "matched": result.matched, }) result.execution_trace = trace if transparency != "SILENT" else [] result.transparency_level = transparency - span.set_attribute("route.layer", result.match_method or "intent_router") + span.set_attribute("route.layer", result.match_method or "merged_llm") span.set_attribute("route.target", result.skill_name or "default") return result diff --git a/src/agentkit/core/agent_pool.py b/src/agentkit/core/agent_pool.py index 1525390..2f936a2 100644 --- a/src/agentkit/core/agent_pool.py +++ b/src/agentkit/core/agent_pool.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any from agentkit.core.config_driven import ConfigDrivenAgent from agentkit.core.protocol import AgentStatus +from agentkit.core.react import ReActEngine from agentkit.llm.gateway import LLMGateway from agentkit.skills.registry import SkillRegistry from agentkit.tools.registry import ToolRegistry @@ -53,6 +54,16 @@ class AgentPool: compressor=self._compressor, ) await agent.start() + + # Bind a reusable ReActEngine instance to the agent + max_steps = getattr(config, "max_steps", 10) or 10 + parallel_tools = getattr(config, "parallel_tools", False) or False + agent._react_engine = ReActEngine( + llm_gateway=self._llm_gateway, + max_steps=max_steps, + parallel_tools=parallel_tools, + ) + self._agents[config.name] = agent logger.info(f"Agent '{config.name}' created and started in pool") diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 1f5246d..26b844b 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -74,14 +74,27 @@ class ReActEngine: 使 Agent 能够自主推理并选择工具完成任务。 """ - def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = False): + def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool | str = False): if max_steps < 1: raise ValueError(f"max_steps must be >= 1, got {max_steps}") + if isinstance(parallel_tools, str) and parallel_tools not in ("auto",): + raise ValueError(f"parallel_tools must be True, False, or 'auto', got {parallel_tools!r}") self._llm_gateway = llm_gateway self._max_steps = max_steps self._default_timeout = default_timeout self._parallel_tools = parallel_tools + def reset(self) -> None: + """Reset internal state for reuse across conversations. + + Call this before each execute/execute_stream to ensure clean state. + The engine itself (LLM gateway, config) is preserved. + """ + # ReActEngine is stateless between calls — conversation history, + # step counts, and trajectory are local to each execute call. + # This method exists for API clarity and future stateful extensions. + pass + async def execute( self, messages: list[dict[str, str]], @@ -97,6 +110,7 @@ class ReActEngine: retrieval_config: dict[str, Any] | None = None, cancellation_token: CancellationToken | None = None, timeout_seconds: float | None = None, + confirmation_handler: Any | None = None, ) -> ReActResult: """执行 ReAct 循环 @@ -127,6 +141,7 @@ class ReActEngine: compressor=compressor, retrieval_config=retrieval_config, cancellation_token=cancellation_token, + confirmation_handler=confirmation_handler, ), timeout=effective_timeout, ) @@ -144,6 +159,7 @@ class ReActEngine: compressor=compressor, retrieval_config=retrieval_config, cancellation_token=cancellation_token, + confirmation_handler=confirmation_handler, ) except asyncio.TimeoutError: raise TaskTimeoutError( @@ -169,6 +185,7 @@ class ReActEngine: compressor: "CompressionStrategy | None" = None, retrieval_config: dict[str, Any] | None = None, cancellation_token: CancellationToken | None = None, + confirmation_handler: Any | None = None, ) -> ReActResult: tools = tools or [] tool_schemas = self._build_tool_schemas(tools) if tools else None @@ -224,7 +241,7 @@ class ReActEngine: else: system_prompt = f"## 参考信息\n{memory_context}" except Exception as e: - logger.warning(f"Memory retrieval failed, continuing without context: {e}") + logger.warning(f"Memory retrieval failed, continuing without context: {e}", exc_info=True) # 构建初始消息 conversation: list[dict[str, Any]] = [] @@ -295,8 +312,70 @@ class ReActEngine: conversation.append(assistant_msg) # 执行工具调用 - if self._parallel_tools and len(response.tool_calls) > 1: - # 并行执行多个工具调用 + if self._parallel_tools == "auto" and len(response.tool_calls) > 1: + # Auto mode: mixed parallel/serial based on _parallelizable flag + parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls)) + serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set] + parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set] + + # Result slots indexed by original position + all_results: list[Any] = [None] * len(response.tool_calls) + + # Execute serial tools first (in order) + for i, tc in serial_calls: + tool_start = time.monotonic() + tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + tool_duration_ms = int((time.monotonic() - tool_start) * 1000) + all_results[i] = (tc, tool_result, tool_duration_ms) + + # Execute parallelizable tools in parallel + if len(parallel_calls) > 1: + para_results = await asyncio.gather( + *[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls], + return_exceptions=True, + ) + for j, (i, tc) in enumerate(parallel_calls): + tool_result = para_results[j] + if isinstance(tool_result, Exception): + tool_result = {"error": str(tool_result)} + all_results[i] = (tc, tool_result, 0) + elif len(parallel_calls) == 1: + i, tc = parallel_calls[0] + tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + all_results[i] = (tc, tool_result, 0) + + # Process all results in original order + for i, tc in enumerate(response.tool_calls): + tc_obj, tool_result, tool_duration_ms = all_results[i] + react_step = ReActStep( + step=step, + action="tool_call", + tool_name=tc.name, + arguments=tc.arguments, + result=tool_result, + tokens=step_tokens, + ) + trajectory.append(react_step) + + if trace_recorder is not None: + tool_error = None + if isinstance(tool_result, dict) and "error" in tool_result: + tool_error = tool_result["error"] + trace_recorder.record_step( + step=step, + action="tool_call", + tool_name=tc.name, + input_data=tc.arguments, + output_data=tool_result, + duration_ms=tool_duration_ms, + tokens_used=0, + error=tool_error, + ) + + tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) + conversation.append(tool_msg) + elif self._should_execute_parallel(response.tool_calls): + # 并行执行多个工具调用 (parallel_tools=True) tool_results = await asyncio.gather( *[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls], return_exceptions=True, @@ -338,6 +417,40 @@ class ReActEngine: for tc in response.tool_calls: tool_start = time.monotonic() tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + + # Handle confirmation flow + if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"): + confirmation_id = tool_result["confirmation_id"] + command = tool_result.get("command", "") + reason = tool_result.get("reason", "") + + approved = False + if confirmation_handler is not None: + try: + approved = await confirmation_handler(confirmation_id, command, reason) + except Exception as e: + logger.warning(f"Confirmation handler error: {e}") + + if approved: + tool = self._find_tool(tc.name, tools) + if tool and hasattr(tool, '_is_dangerous'): + clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + clean_args["_skip_dangerous_check"] = True + try: + tool_result = await tool.safe_execute(**clean_args) + except Exception as e: + tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} + else: + tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + else: + tool_result = { + "output": "", + "exit_code": 126, + "is_error": True, + "error_type": "permission_denied", + "message": f"用户拒绝执行命令: {command[:100]}", + } + tool_duration_ms = int((time.monotonic() - tool_start) * 1000) react_step = ReActStep( @@ -531,6 +644,21 @@ class ReActEngine: else: logger.info("ReActEngine executing with NO tools") + # Telemetry: record agent request + agent_request_counter().add(1, {"agent.name": agent_name, "agent.type": task_type or "react"}) + + # Start telemetry span for the entire agent execution + _span_cm = None + _span = None + _exec_start = time.monotonic() + + if _OTEL_AVAILABLE: + _span_cm = start_span( + "agent.execute_stream", + attributes={"agent.name": agent_name, "agent.type": task_type or "react"}, + ) + _span = _span_cm.__enter__() + # 启动轨迹记录 if trace_recorder is not None: trace_recorder.start_trace( @@ -575,11 +703,26 @@ class ReActEngine: step = 0 output = "" trace_outcome = "success" + _stream_start = time.monotonic() + effective_timeout = timeout_seconds if timeout_seconds is not None else self._default_timeout try: while step < self._max_steps: step += 1 + # 协作式取消检查 + if cancellation_token is not None: + cancellation_token.check() + + # 超时检查 + if effective_timeout > 0: + elapsed = time.monotonic() - _stream_start + if elapsed > effective_timeout: + trace_outcome = "timeout" + raise asyncio.TimeoutError( + f"execute_stream exceeded {effective_timeout}s timeout after {elapsed:.1f}s" + ) + # Yield thinking event yield ReActEvent( event_type="thinking", @@ -591,7 +734,7 @@ class ReActEngine: llm_start = time.monotonic() # Use streaming for token-by-token output - stream_content = "" + stream_content_chunks: list[str] = [] stream_usage = None stream_tool_calls: list[Any] = [] stream_model = model @@ -604,7 +747,7 @@ class ReActEngine: tools=tool_schemas, ): if chunk.content: - stream_content += chunk.content + stream_content_chunks.append(chunk.content) yield ReActEvent( event_type="token", step=step, @@ -618,6 +761,7 @@ class ReActEngine: stream_model = chunk.model # Build response-like object from stream + stream_content = "".join(stream_content_chunks) response = self._build_response_from_stream( content=stream_content, tool_calls=stream_tool_calls, @@ -657,115 +801,168 @@ class ReActEngine: } conversation.append(assistant_msg) - for tc in response.tool_calls: - # Yield tool_call event - yield ReActEvent( - event_type="tool_call", - step=step, - data={"tool_name": tc.name, "arguments": tc.arguments}, - ) + # Execute tool calls with parallel support + if self._parallel_tools and len(response.tool_calls) > 1 and self._should_execute_parallel(response.tool_calls): + # Parallel execution path + parallelizable_set = set(self._get_parallelizable_indices(response.tool_calls)) if self._parallel_tools == "auto" else set(range(len(response.tool_calls))) + serial_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i not in parallelizable_set] + parallel_calls = [(i, tc) for i, tc in enumerate(response.tool_calls) if i in parallelizable_set] - tool_start = time.monotonic() - tool_result = await self._execute_tool(tc.name, tc.arguments, tools) - tool_duration_ms = int((time.monotonic() - tool_start) * 1000) + all_results: list[Any] = [None] * len(response.tool_calls) - # 检测工具返回的确认请求 - if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"): - confirmation_id = tool_result["confirmation_id"] - command = tool_result.get("command", "") - reason = tool_result.get("reason", "") + # Execute serial tools first (handles confirmation flow) + for i, tc in serial_calls: + yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments}) + tool_start = time.monotonic() + tool_result, confirm_events = await self._execute_tool_with_confirmation(tc, tools, step, confirmation_handler) + for ev in confirm_events: + yield ev + tool_duration_ms = int((time.monotonic() - tool_start) * 1000) + all_results[i] = (tc, tool_result, tool_duration_ms) - # Yield 确认请求事件 + # Execute parallelizable tools concurrently + if len(parallel_calls) > 1: + para_results = await asyncio.gather( + *[self._execute_tool(tc.name, tc.arguments, tools) for _, tc in parallel_calls], + return_exceptions=True, + ) + for j, (i, tc) in enumerate(parallel_calls): + tool_result = para_results[j] + if isinstance(tool_result, Exception): + tool_result = {"error": str(tool_result)} + all_results[i] = (tc, tool_result, 0) + elif len(parallel_calls) == 1: + i, tc = parallel_calls[0] + tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + all_results[i] = (tc, tool_result, 0) + + # Process all results in original order + for i, tc in enumerate(response.tool_calls): + tc_obj, tool_result, tool_duration_ms = all_results[i] + yield ReActEvent(event_type="tool_call", step=step, data={"tool_name": tc.name, "arguments": tc.arguments}) + + react_step = ReActStep(step=step, action="tool_call", tool_name=tc.name, arguments=tc.arguments, result=tool_result, tokens=step_tokens) + trajectory.append(react_step) + + if trace_recorder is not None: + tool_error = None + if isinstance(tool_result, dict) and "error" in tool_result: + tool_error = tool_result["error"] + trace_recorder.record_step(step=step, action="tool_call", tool_name=tc.name, input_data=tc.arguments, output_data=tool_result, duration_ms=tool_duration_ms, tokens_used=0, error=tool_error) + + yield ReActEvent(event_type="tool_result", step=step, data={"tool_name": tc.name, "result": tool_result}) + tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) + conversation.append(tool_msg) + else: + # Serial execution path (with confirmation flow) + for tc in response.tool_calls: + # Yield tool_call event yield ReActEvent( - event_type="confirmation_request", + event_type="tool_call", step=step, - data={ - "confirmation_id": confirmation_id, - "tool_name": tc.name, - "command": command, - "reason": reason, - }, + data={"tool_name": tc.name, "arguments": tc.arguments}, ) - # 等待用户确认 - approved = False - if confirmation_handler is not None: - try: - approved = await confirmation_handler(confirmation_id, command, reason) - except Exception as e: - logger.warning(f"Confirmation handler error: {e}") - - if approved: - # 用户确认执行:临时绕过安全检查重新执行 - tool = self._find_tool(tc.name, tools) - if tool and hasattr(tool, '_is_dangerous'): - # 保存原始 _is_dangerous 并临时禁用 - original_is_dangerous = tool._is_dangerous - tool._is_dangerous = lambda cmd: False - try: - tool_result = await tool.safe_execute(**tc.arguments) - finally: - tool._is_dangerous = original_is_dangerous - else: - tool_result = await self._execute_tool(tc.name, tc.arguments, tools) - - yield ReActEvent( - event_type="confirmation_result", - step=step, - data={"confirmation_id": confirmation_id, "approved": True}, - ) - else: - # 用户拒绝执行 - tool_result = { - "output": "", - "exit_code": 126, - "is_error": True, - "error_type": "permission_denied", - "message": f"用户拒绝执行命令: {command[:100]}", - } - yield ReActEvent( - event_type="confirmation_result", - step=step, - data={"confirmation_id": confirmation_id, "approved": False}, - ) - + tool_start = time.monotonic() + tool_result = await self._execute_tool(tc.name, tc.arguments, tools) tool_duration_ms = int((time.monotonic() - tool_start) * 1000) - react_step = ReActStep( - step=step, - action="tool_call", - tool_name=tc.name, - arguments=tc.arguments, - result=tool_result, - tokens=step_tokens, - ) - trajectory.append(react_step) + # 检测工具返回的确认请求 + if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"): + confirmation_id = tool_result["confirmation_id"] + command = tool_result.get("command", "") + reason = tool_result.get("reason", "") - # 记录工具调用步骤 - if trace_recorder is not None: - tool_error = None - if isinstance(tool_result, dict) and "error" in tool_result: - tool_error = tool_result["error"] - trace_recorder.record_step( + # Yield 确认请求事件 + yield ReActEvent( + event_type="confirmation_request", + step=step, + data={ + "confirmation_id": confirmation_id, + "tool_name": tc.name, + "command": command, + "reason": reason, + }, + ) + + # 等待用户确认 + approved = False + if confirmation_handler is not None: + try: + approved = await confirmation_handler(confirmation_id, command, reason) + except Exception as e: + logger.warning(f"Confirmation handler error: {e}") + + if approved: + # 用户确认执行:使用 per-call override 绕过安全检查 + tool = self._find_tool(tc.name, tools) + if tool and hasattr(tool, '_is_dangerous'): + # Strip internal metadata and pass skip_dangerous_check flag + clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + clean_args["_skip_dangerous_check"] = True + try: + tool_result = await tool.safe_execute(**clean_args) + finally: + pass # No shared state mutation needed + else: + tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + + yield ReActEvent( + event_type="confirmation_result", + step=step, + data={"confirmation_id": confirmation_id, "approved": True}, + ) + else: + # 用户拒绝执行 + tool_result = { + "output": "", + "exit_code": 126, + "is_error": True, + "error_type": "permission_denied", + "message": f"用户拒绝执行命令: {command[:100]}", + } + yield ReActEvent( + event_type="confirmation_result", + step=step, + data={"confirmation_id": confirmation_id, "approved": False}, + ) + + tool_duration_ms = int((time.monotonic() - tool_start) * 1000) + + react_step = ReActStep( step=step, action="tool_call", tool_name=tc.name, - input_data=tc.arguments, - output_data=tool_result, - duration_ms=tool_duration_ms, - tokens_used=0, - error=tool_error, + arguments=tc.arguments, + result=tool_result, + tokens=step_tokens, + ) + trajectory.append(react_step) + + if trace_recorder is not None: + tool_error = None + if isinstance(tool_result, dict) and "error" in tool_result: + tool_error = tool_result["error"] + trace_recorder.record_step( + step=step, + action="tool_call", + tool_name=tc.name, + input_data=tc.arguments, + output_data=tool_result, + duration_ms=tool_duration_ms, + tokens_used=0, + error=tool_error, + ) + + # Yield tool_result event + yield ReActEvent( + event_type="tool_result", + step=step, + data={"tool_name": tc.name, "result": tool_result}, ) - # Yield tool_result event - yield ReActEvent( - event_type="tool_result", - step=step, - data={"tool_name": tc.name, "result": tool_result}, - ) - - tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) - conversation.append(tool_msg) + tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name) + conversation.append(tool_msg) # Incremental compression: compress conversation if it's getting long if self._should_compress(conversation, compressor): @@ -893,6 +1090,17 @@ class ReActEngine: if trace_recorder is not None: trace_recorder.end_trace(outcome=trace_outcome) + # Telemetry: end span and record duration — always runs + _duration_ms = int((time.monotonic() - _exec_start) * 1000) + if _span is not None: + _span.set_attribute("agent.total_steps", len(trajectory)) + _span.set_attribute("agent.total_tokens", total_tokens) + _span.set_attribute("agent.outcome", trace_outcome) + _span.set_attribute("agent.duration_ms", _duration_ms) + if _span_cm is not None: + _span_cm.__exit__(None, None, None) + agent_duration_histogram().record(_duration_ms, {"agent.name": agent_name}) + # Memory storage: 执行后写入轨迹摘要到 EpisodicMemory if memory_retriever and hasattr(memory_retriever, "store_episode"): try: @@ -988,14 +1196,127 @@ class ReActEngine: logger.warning(error_msg) return {"error": error_msg} + # Strip internal metadata keys before passing to tool + clean_args = {k: v for k, v in arguments.items() if not k.startswith("_")} + try: - result = await tool.safe_execute(**arguments) + result = await tool.safe_execute(**clean_args) return result except Exception as e: error_msg = f"Tool '{tool_name}' execution failed: {e}" logger.warning(error_msg) return {"error": error_msg} + async def _execute_tool_with_confirmation( + self, + tc: Any, + tools: list[Tool], + step: int, + confirmation_handler: Any, + ) -> tuple[Any, list[ReActEvent]]: + """Execute a tool call with confirmation flow support. + + Used in the parallel execution path for serial (non-parallelizable) tools + that may require user confirmation before execution. + + Returns: + Tuple of (tool_result, list of ReActEvents to yield) + """ + events: list[ReActEvent] = [] + tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + + # Check if tool returned a confirmation request + if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"): + confirmation_id = tool_result["confirmation_id"] + command = tool_result.get("command", "") + reason = tool_result.get("reason", "") + + events.append(ReActEvent( + event_type="confirmation_request", + step=step, + data={ + "confirmation_id": confirmation_id, + "tool_name": tc.name, + "command": command, + "reason": reason, + }, + )) + + # Wait for user confirmation + approved = False + if confirmation_handler is not None: + try: + approved = await confirmation_handler(confirmation_id, command, reason) + except Exception as e: + logger.warning(f"Confirmation handler error: {e}") + + if approved: + # User approved: re-execute with _skip_dangerous_check + tool = self._find_tool(tc.name, tools) + if tool and hasattr(tool, '_is_dangerous'): + clean_args = {k: v for k, v in tc.arguments.items() if not k.startswith("_")} + clean_args["_skip_dangerous_check"] = True + try: + tool_result = await tool.safe_execute(**clean_args) + except Exception as e: + tool_result = {"error": f"Tool '{tc.name}' execution failed: {e}"} + else: + tool_result = await self._execute_tool(tc.name, tc.arguments, tools) + + events.append(ReActEvent( + event_type="confirmation_result", + step=step, + data={"confirmation_id": confirmation_id, "approved": True}, + )) + else: + # User rejected + tool_result = { + "output": "", + "exit_code": 126, + "is_error": True, + "error_type": "permission_denied", + "message": f"用户拒绝执行命令: {command[:100]}", + } + events.append(ReActEvent( + event_type="confirmation_result", + step=step, + data={"confirmation_id": confirmation_id, "approved": False}, + )) + + return tool_result, events + + def _should_execute_parallel(self, tool_calls: list[Any]) -> bool: + """Determine if tool calls should be executed in parallel. + + - parallel_tools=True: always parallel (if >1 tool) + - parallel_tools=False: never parallel + - parallel_tools="auto": parallel if any tool_call has _parallelizable=true in arguments + """ + if len(tool_calls) <= 1: + return False + if self._parallel_tools is True: + return True + if self._parallel_tools is False: + return False + # "auto" mode: check _parallelizable metadata in tool call arguments + if self._parallel_tools == "auto": + parallelizable_indices = self._get_parallelizable_indices(tool_calls) + return len(parallelizable_indices) > 1 + return False + + def _get_parallelizable_indices(self, tool_calls: list[Any]) -> list[int]: + """Get indices of tool_calls that have _parallelizable=true in arguments. + + LLM marks parallelizable tools by including _parallelizable: true + in the tool_call arguments. + """ + indices = [] + for i, tc in enumerate(tool_calls): + args = tc.arguments if hasattr(tc, 'arguments') else {} + if isinstance(args, dict) and args.get("_parallelizable") is True: + indices.append(i) + return indices + def _parse_text_tool_calls(self, content: str) -> list[dict[str, Any]]: """从文本中解析工具调用模式 diff --git a/src/agentkit/core/rewoo.py b/src/agentkit/core/rewoo.py index 39ec7e8..a3fb88c 100644 --- a/src/agentkit/core/rewoo.py +++ b/src/agentkit/core/rewoo.py @@ -8,6 +8,7 @@ import asyncio import json import logging +import re import time from dataclasses import dataclass, field from datetime import datetime, timezone @@ -33,6 +34,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# ── Internal Exceptions ────────────────────────────────── + + +class _FallbackFailedError(Exception): + """Internal signal: a fallback strategy failed, try the next one.""" + + def __init__(self, strategy: str): + self.strategy = strategy + super().__init__(f"Fallback strategy '{strategy}' failed") + + # ── Data Structures ─────────────────────────────────────── @@ -116,13 +128,25 @@ class ReWOOEngine: """ FALLBACK_STRATEGIES = ["simplified_rewoo", "react", "direct"] + VALID_STRATEGIES = {"simplified_rewoo", "react", "direct", "plan_exec"} - def __init__(self, llm_gateway: LLMGateway, max_plan_steps: int = 10, default_timeout: float = 300.0): + def __init__(self, llm_gateway: LLMGateway, max_plan_steps: int = 10, default_timeout: float = 300.0, fallback_strategies: list[str] | None = None): if max_plan_steps < 1: raise ValueError(f"max_plan_steps must be >= 1, got {max_plan_steps}") self._llm_gateway = llm_gateway self._max_plan_steps = max_plan_steps self._default_timeout = default_timeout + # Validate and store fallback strategies + raw_strategies = fallback_strategies if fallback_strategies is not None else self.FALLBACK_STRATEGIES + self._fallback_strategies: list[str] = [] + for s in raw_strategies: + if s in self.VALID_STRATEGIES: + self._fallback_strategies.append(s) + else: + logger.warning(f"Invalid fallback strategy '{s}', skipping. Valid: {self.VALID_STRATEGIES}") + if not self._fallback_strategies: + logger.warning("No valid fallback strategies, using defaults") + self._fallback_strategies = list(self.FALLBACK_STRATEGIES) # ReActEngine 作为 fallback self._react_engine = ReActEngine( llm_gateway=llm_gateway, @@ -145,6 +169,7 @@ class ReWOOEngine: retrieval_config: dict[str, Any] | None = None, cancellation_token: CancellationToken | None = None, timeout_seconds: float | None = None, + confirmation_handler: Any | None = None, ) -> ReActResult: """执行 ReWOO 三阶段流程 @@ -176,6 +201,7 @@ class ReWOOEngine: compressor=compressor, retrieval_config=retrieval_config, cancellation_token=cancellation_token, + confirmation_handler=confirmation_handler, ), timeout=effective_timeout, ) @@ -193,6 +219,7 @@ class ReWOOEngine: compressor=compressor, retrieval_config=retrieval_config, cancellation_token=cancellation_token, + confirmation_handler=confirmation_handler, ) except asyncio.TimeoutError: raise TaskTimeoutError( @@ -218,6 +245,7 @@ class ReWOOEngine: compressor: "CompressionStrategy | None" = None, retrieval_config: dict[str, Any] | None = None, cancellation_token: CancellationToken | None = None, + confirmation_handler: Any | None = None, ) -> ReActResult: tools = tools or [] tool_schemas = self._build_tool_schemas(tools) if tools else None @@ -296,109 +324,31 @@ class ReWOOEngine: tokens_used=planning_tokens, ) - # 如果规划失败,尝试渐进式回退 + # 如果规划失败,按配置的 fallback 策略顺序尝试回退 if plan is None: - # 尝试简化规划(max_steps=3) - logger.warning("ReWOO planning failed, trying simplified planning with max_steps=3") - try: - plan, simplified_tokens = await self._plan_phase( - messages=messages, - tools=tools, - tool_schemas=tool_schemas, - model=model, - agent_name=agent_name, - task_type=task_type, - system_prompt=effective_system_prompt, - compressor=compressor, - cancellation_token=cancellation_token, - max_steps=3, - ) - total_tokens += simplified_tokens - if plan is not None and plan.steps: - fallback_strategy = "simplified_rewoo" - logger.info("Simplified ReWOO planning succeeded") - except Exception as e2: - logger.warning(f"Simplified ReWOO planning also failed: {e2}") - - if plan is None: - # 回退到 ReAct - fallback_strategy = "react" - logger.warning("ReWOO planning failed, falling back to ReActEngine") - if trace_recorder is not None: - trace_recorder.end_trace(outcome="fallback") - try: - react_result = await self._react_engine.execute( - messages=messages, - tools=tools, - model=model, - agent_name=agent_name, - task_type=task_type, - system_prompt=system_prompt, - trace_recorder=trace_recorder, - memory_retriever=memory_retriever, - task_id=task_id, - compressor=compressor, - retrieval_config=retrieval_config, - cancellation_token=cancellation_token, - timeout_seconds=0, # timeout already handled by outer wrapper - ) - react_result.fallback_strategy = fallback_strategy - return react_result - except Exception as react_err: - # ReAct 也失败,回退到 Direct(简单 LLM 调用) - fallback_strategy = "direct" - logger.warning(f"ReAct fallback also failed: {react_err}, falling back to direct LLM call") - try: - direct_messages: list[dict[str, Any]] = [] - if effective_system_prompt: - direct_messages.append({"role": "system", "content": effective_system_prompt}) - direct_messages.extend(messages) - - if compressor: - try: - direct_messages = await compressor.compress(direct_messages) - except Exception as e: - logger.warning(f"Context compression failed in direct fallback: {e}") - - direct_response = await self._llm_gateway.chat( - messages=direct_messages, - model=model, - agent_name=agent_name, - task_type=task_type, - ) - total_tokens += direct_response.usage.total_tokens - - direct_step = ReWOOStep( - step=1, - action="final_answer", - content=direct_response.content, - tokens=direct_response.usage.total_tokens, - plan_step_id=None, - ) - trajectory.append(direct_step) - - if trace_recorder is not None: - trace_recorder.record_step( - step=1, - action="final_answer", - output_data={"content": direct_response.content}, - tokens_used=direct_response.usage.total_tokens, - ) - - trace_outcome = "success" - if trace_recorder is not None: - trace_recorder.end_trace(outcome=trace_outcome) - - return ReActResult( - output=direct_response.content or "", - trajectory=trajectory, - total_steps=len(trajectory), - total_tokens=total_tokens, - fallback_strategy=fallback_strategy, - ) - except Exception as direct_err: - logger.error(f"Direct LLM fallback also failed: {direct_err}") - raise + fallback_strategy = await self._try_fallback_strategies( + strategies=self._fallback_strategies, + messages=messages, + tools=tools, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=system_prompt, + effective_system_prompt=effective_system_prompt, + trace_recorder=trace_recorder, + memory_retriever=memory_retriever, + task_id=task_id, + compressor=compressor, + retrieval_config=retrieval_config, + cancellation_token=cancellation_token, + trajectory=trajectory, + total_tokens=total_tokens, + confirmation_handler=confirmation_handler, + ) + if fallback_strategy is not None: + return fallback_strategy + # All fallback strategies exhausted + raise RuntimeError("All ReWOO fallback strategies exhausted") # 如果计划为空(无需工具),直接让 LLM 回答 if not plan.steps: @@ -579,6 +529,7 @@ class ReWOOEngine: retrieval_config: dict[str, Any] | None = None, cancellation_token: CancellationToken | None = None, timeout_seconds: float | None = None, + confirmation_handler: Any | None = None, ): """Execute ReWOO flow, yielding ReActEvent objects. @@ -627,7 +578,6 @@ class ReWOOEngine: trace_outcome = "success" try: - # ── Phase 1: Planning ── yield ReActEvent( event_type="planning", step=0, @@ -648,83 +598,27 @@ class ReWOOEngine: total_tokens += planning_tokens if plan is None: - # Try simplified planning - logger.warning("ReWOO planning failed in stream mode, trying simplified planning with max_steps=3") - try: - plan, simplified_tokens = await self._plan_phase( - messages=messages, - tools=tools, - tool_schemas=tool_schemas, - model=model, - agent_name=agent_name, - task_type=task_type, - system_prompt=effective_system_prompt, - compressor=compressor, - cancellation_token=cancellation_token, - max_steps=3, - ) - total_tokens += simplified_tokens - except Exception as e2: - logger.warning(f"Simplified ReWOO planning also failed in stream mode: {e2}") - - if plan is None: - # Planning failed, fall back to ReAct streaming - logger.warning("ReWOO planning failed in stream mode, falling back to ReActEngine") - try: - async for event in self._react_engine.execute_stream( - messages=messages, - tools=tools, - model=model, - agent_name=agent_name, - task_type=task_type, - system_prompt=system_prompt, - trace_recorder=trace_recorder, - memory_retriever=memory_retriever, - task_id=task_id, - compressor=compressor, - retrieval_config=retrieval_config, - cancellation_token=cancellation_token, - timeout_seconds=0, - ): - yield event - return - except Exception as react_err: - # ReAct also failed, fall back to direct LLM call - logger.warning(f"ReAct fallback also failed in stream mode: {react_err}, falling back to direct LLM call") - try: - direct_messages: list[dict[str, Any]] = [] - if effective_system_prompt: - direct_messages.append({"role": "system", "content": effective_system_prompt}) - direct_messages.extend(messages) - - if compressor: - try: - direct_messages = await compressor.compress(direct_messages) - except Exception as e: - logger.warning(f"Context compression failed in direct fallback: {e}") - - direct_response = await self._llm_gateway.chat( - messages=direct_messages, - model=model, - agent_name=agent_name, - task_type=task_type, - ) - total_tokens += direct_response.usage.total_tokens - output = direct_response.content or "" - - yield ReActEvent( - event_type="final_answer", - step=1, - data={ - "output": output, - "total_steps": 1, - "total_tokens": total_tokens, - }, - ) - return - except Exception as direct_err: - logger.error(f"Direct LLM fallback also failed in stream mode: {direct_err}") - raise + # Planning failed, try fallback strategies in configured order + async for event in self._try_fallback_strategies_stream( + strategies=self._fallback_strategies, + messages=messages, + tools=tools, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=system_prompt, + effective_system_prompt=effective_system_prompt, + trace_recorder=trace_recorder, + memory_retriever=memory_retriever, + task_id=task_id, + compressor=compressor, + retrieval_config=retrieval_config, + cancellation_token=cancellation_token, + total_tokens=total_tokens, + confirmation_handler=confirmation_handler, + ): + yield event + return yield ReActEvent( event_type="plan_generated", @@ -875,8 +769,11 @@ class ReWOOEngine: "total_tokens": total_tokens, }, ) + except Exception as e: + trace_outcome = "error" + logger.error(f"ReWOO execute_stream failed: {e}") + raise finally: - # 结束轨迹记录 if trace_recorder is not None: trace_recorder.end_trace(outcome=trace_outcome) @@ -892,6 +789,582 @@ class ReWOOEngine: except Exception as e: logger.warning(f"Failed to store task result in episodic memory: {e}") + # ── Fallback Strategy Helpers ────────────────────────── + + async def _try_fallback_strategies_stream( + self, + strategies: list[str], + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + system_prompt: str | None = None, + effective_system_prompt: str | None = None, + trace_recorder: "TraceRecorder | None" = None, + memory_retriever: "MemoryRetriever | None" = None, + task_id: str | None = None, + compressor: "CompressionStrategy | None" = None, + retrieval_config: dict[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + total_tokens: int = 0, + confirmation_handler: Any | None = None, + ): + """Stream version: try fallback strategies in configured order, yielding events from the first successful one. + + If all strategies fail, raises RuntimeError. + """ + for strategy in strategies: + if strategy == "simplified_rewoo": + try: + async for event in self._fallback_simplified_rewoo_stream( + messages=messages, tools=tools, model=model, + agent_name=agent_name, task_type=task_type, + effective_system_prompt=effective_system_prompt, + compressor=compressor, cancellation_token=cancellation_token, + ): + yield event + return + except _FallbackFailedError: + continue + + elif strategy == "react": + try: + async for event in self._fallback_react_stream( + messages=messages, tools=tools, model=model, + agent_name=agent_name, task_type=task_type, + system_prompt=system_prompt, + trace_recorder=trace_recorder, + memory_retriever=memory_retriever, + task_id=task_id, compressor=compressor, + retrieval_config=retrieval_config, + cancellation_token=cancellation_token, + confirmation_handler=confirmation_handler, + ): + yield event + return + except _FallbackFailedError: + continue + + elif strategy == "direct": + try: + async for event in self._fallback_direct_stream( + messages=messages, model=model, + agent_name=agent_name, task_type=task_type, + effective_system_prompt=effective_system_prompt, + compressor=compressor, total_tokens=total_tokens, + ): + yield event + return + except _FallbackFailedError: + continue + + elif strategy == "plan_exec": + try: + async for event in self._fallback_plan_exec_stream( + messages=messages, tools=tools, model=model, + agent_name=agent_name, task_type=task_type, + effective_system_prompt=effective_system_prompt, + compressor=compressor, cancellation_token=cancellation_token, + ): + yield event + return + except _FallbackFailedError: + continue + + raise RuntimeError("All ReWOO fallback strategies exhausted in stream mode") + + async def _fallback_simplified_rewoo_stream( + self, + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + effective_system_prompt: str | None = None, + compressor: "CompressionStrategy | None" = None, + cancellation_token: CancellationToken | None = None, + ): + """Stream: Simplified ReWOO fallback with max_steps=3""" + logger.warning("ReWOO planning failed in stream mode, trying simplified planning with max_steps=3") + try: + tool_schemas = self._build_tool_schemas(tools) if tools else None + plan, simplified_tokens = await self._plan_phase( + messages=messages, tools=tools or [], tool_schemas=tool_schemas, + model=model, agent_name=agent_name, task_type=task_type, + system_prompt=effective_system_prompt, compressor=compressor, + cancellation_token=cancellation_token, max_steps=3, + ) + if plan is not None and plan.steps: + logger.info("Simplified ReWOO planning succeeded in stream mode") + yield ReActEvent(event_type="plan_generated", step=0, data={ + "reasoning": plan.reasoning, + "steps": [{"step_id": s.step_id, "tool_name": s.tool_name, "arguments": s.arguments, "reasoning": s.reasoning} for s in plan.steps], + }) + tool_results: list[dict[str, Any]] = [] + for plan_step in plan.steps: + if cancellation_token is not None: + cancellation_token.check() + yield ReActEvent(event_type="tool_call", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "arguments": plan_step.arguments}) + tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or []) + tool_results.append({"step_id": plan_step.step_id, "tool_name": plan_step.tool_name, "arguments": plan_step.arguments, "result": tool_result, "reasoning": plan_step.reasoning}) + yield ReActEvent(event_type="tool_result", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "result": tool_result}) + + yield ReActEvent(event_type="synthesis", step=len(plan.steps) + 1, data={"message": "Synthesizing results..."}) + output, synthesis_tokens = await self._synthesis_phase(messages=messages, tool_results=tool_results, model=model, agent_name=agent_name, task_type=task_type, system_prompt=effective_system_prompt, compressor=compressor, cancellation_token=cancellation_token) + yield ReActEvent(event_type="final_answer", step=len(plan.steps) + 1, data={"output": output, "total_steps": len(plan.steps) + 1, "total_tokens": simplified_tokens + synthesis_tokens}) + return + except Exception as e: + logger.warning(f"Simplified ReWOO planning also failed in stream mode: {e}") + # Failed, continue to next strategy by not returning + # This signals the caller to try the next strategy + # We need a different approach - raise a specific exception + raise _FallbackFailedError("simplified_rewoo") + + async def _fallback_react_stream( + self, + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + system_prompt: str | None = None, + trace_recorder: "TraceRecorder | None" = None, + memory_retriever: "MemoryRetriever | None" = None, + task_id: str | None = None, + compressor: "CompressionStrategy | None" = None, + retrieval_config: dict[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + confirmation_handler: Any | None = None, + ): + """Stream: ReAct fallback""" + logger.warning("ReWOO planning failed in stream mode, falling back to ReActEngine") + try: + async for event in self._react_engine.execute_stream( + messages=messages, tools=tools, model=model, + agent_name=agent_name, task_type=task_type, + system_prompt=system_prompt, trace_recorder=trace_recorder, + memory_retriever=memory_retriever, task_id=task_id, + compressor=compressor, retrieval_config=retrieval_config, + cancellation_token=cancellation_token, timeout_seconds=0, + confirmation_handler=confirmation_handler, + ): + yield event + return + except Exception as e: + logger.warning(f"ReAct fallback also failed in stream mode: {e}") + raise _FallbackFailedError("react") + + async def _fallback_direct_stream( + self, + messages: list[dict[str, str]], + model: str = "default", + agent_name: str = "", + task_type: str = "", + effective_system_prompt: str | None = None, + compressor: "CompressionStrategy | None" = None, + total_tokens: int = 0, + ): + """Stream: Direct LLM fallback""" + logger.warning("Falling back to direct LLM call in stream mode") + try: + direct_messages: list[dict[str, Any]] = [] + if effective_system_prompt: + direct_messages.append({"role": "system", "content": effective_system_prompt}) + direct_messages.extend(messages) + if compressor: + try: + direct_messages = await compressor.compress(direct_messages) + except Exception as e: + logger.warning(f"Context compression failed in direct fallback: {e}") + direct_response = await self._llm_gateway.chat(messages=direct_messages, model=model, agent_name=agent_name, task_type=task_type) + output = direct_response.content or "" + yield ReActEvent(event_type="final_answer", step=1, data={"output": output, "total_steps": 1, "total_tokens": total_tokens + direct_response.usage.total_tokens}) + return + except Exception as e: + logger.error(f"Direct LLM fallback also failed in stream mode: {e}") + raise _FallbackFailedError("direct") + + async def _fallback_plan_exec_stream( + self, + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + effective_system_prompt: str | None = None, + compressor: "CompressionStrategy | None" = None, + cancellation_token: CancellationToken | None = None, + ): + """Stream: Plan-Exec fallback with max_steps=5""" + logger.warning("Falling back to plan-exec mode in stream mode (max_steps=5)") + try: + tool_schemas = self._build_tool_schemas(tools) if tools else None + plan, plan_tokens = await self._plan_phase( + messages=messages, tools=tools or [], tool_schemas=tool_schemas, + model=model, agent_name=agent_name, task_type=task_type, + system_prompt=effective_system_prompt, compressor=compressor, + cancellation_token=cancellation_token, max_steps=5, + ) + if plan is not None and plan.steps: + yield ReActEvent(event_type="plan_generated", step=0, data={ + "reasoning": plan.reasoning, + "steps": [{"step_id": s.step_id, "tool_name": s.tool_name, "arguments": s.arguments, "reasoning": s.reasoning} for s in plan.steps], + }) + tool_results: list[dict[str, Any]] = [] + for plan_step in plan.steps: + if cancellation_token is not None: + cancellation_token.check() + yield ReActEvent(event_type="tool_call", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "arguments": plan_step.arguments}) + tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or []) + tool_results.append({"step_id": plan_step.step_id, "tool_name": plan_step.tool_name, "arguments": plan_step.arguments, "result": tool_result, "reasoning": plan_step.reasoning}) + yield ReActEvent(event_type="tool_result", step=plan_step.step_id, data={"tool_name": plan_step.tool_name, "result": tool_result}) + + yield ReActEvent(event_type="synthesis", step=len(plan.steps) + 1, data={"message": "Synthesizing results..."}) + output, synthesis_tokens = await self._synthesis_phase(messages=messages, tool_results=tool_results, model=model, agent_name=agent_name, task_type=task_type, system_prompt=effective_system_prompt, compressor=compressor, cancellation_token=cancellation_token) + yield ReActEvent(event_type="final_answer", step=len(plan.steps) + 1, data={"output": output, "total_steps": len(plan.steps) + 1, "total_tokens": plan_tokens + synthesis_tokens}) + return + except Exception as e: + logger.warning(f"Plan-exec fallback also failed in stream mode: {e}") + raise _FallbackFailedError("plan_exec") + + async def _try_fallback_strategies( + self, + strategies: list[str], + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + system_prompt: str | None = None, + effective_system_prompt: str | None = None, + trace_recorder: "TraceRecorder | None" = None, + memory_retriever: "MemoryRetriever | None" = None, + task_id: str | None = None, + compressor: "CompressionStrategy | None" = None, + retrieval_config: dict[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + trajectory: list[ReActStep] | None = None, + total_tokens: int = 0, + confirmation_handler: Any | None = None, + ) -> ReActResult | None: + """按配置的 fallback 策略顺序尝试回退,返回第一个成功的结果 + + Returns: + ReActResult if a fallback succeeded, None if all strategies exhausted + """ + for strategy in strategies: + if strategy == "simplified_rewoo": + result = await self._fallback_simplified_rewoo( + messages=messages, tools=tools, model=model, + agent_name=agent_name, task_type=task_type, + effective_system_prompt=effective_system_prompt, + compressor=compressor, cancellation_token=cancellation_token, + ) + if result is not None: + return result + + elif strategy == "react": + result = await self._fallback_react( + messages=messages, tools=tools, model=model, + agent_name=agent_name, task_type=task_type, + system_prompt=system_prompt, + trace_recorder=trace_recorder, + memory_retriever=memory_retriever, + task_id=task_id, compressor=compressor, + retrieval_config=retrieval_config, + cancellation_token=cancellation_token, + confirmation_handler=confirmation_handler, + ) + if result is not None: + return result + + elif strategy == "direct": + result = await self._fallback_direct( + messages=messages, model=model, + agent_name=agent_name, task_type=task_type, + effective_system_prompt=effective_system_prompt, + compressor=compressor, cancellation_token=cancellation_token, + trajectory=trajectory, total_tokens=total_tokens, + trace_recorder=trace_recorder, + ) + if result is not None: + return result + + elif strategy == "plan_exec": + result = await self._fallback_plan_exec( + messages=messages, tools=tools, model=model, + agent_name=agent_name, task_type=task_type, + effective_system_prompt=effective_system_prompt, + compressor=compressor, cancellation_token=cancellation_token, + ) + if result is not None: + return result + + return None + + async def _fallback_simplified_rewoo( + self, + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + effective_system_prompt: str | None = None, + compressor: "CompressionStrategy | None" = None, + cancellation_token: CancellationToken | None = None, + ) -> ReActResult | None: + """Simplified ReWOO fallback: retry planning with max_steps=3""" + logger.warning("ReWOO planning failed, trying simplified planning with max_steps=3") + try: + tool_schemas = self._build_tool_schemas(tools) if tools else None + plan, simplified_tokens = await self._plan_phase( + messages=messages, + tools=tools or [], + tool_schemas=tool_schemas, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=effective_system_prompt, + compressor=compressor, + cancellation_token=cancellation_token, + max_steps=3, + ) + if plan is not None and plan.steps: + logger.info("Simplified ReWOO planning succeeded") + # Execute the simplified plan + trajectory: list[ReActStep] = [] + total_tokens = simplified_tokens + tool_results: list[dict[str, Any]] = [] + for plan_step in plan.steps: + if cancellation_token is not None: + cancellation_token.check() + tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or []) + rewoo_step = ReWOOStep( + step=plan_step.step_id, + action="tool_call", + tool_name=plan_step.tool_name, + arguments=plan_step.arguments, + result=tool_result, + tokens=0, + plan_step_id=plan_step.step_id, + ) + trajectory.append(rewoo_step) + tool_results.append({ + "step_id": plan_step.step_id, + "tool_name": plan_step.tool_name, + "arguments": plan_step.arguments, + "result": tool_result, + "reasoning": plan_step.reasoning, + }) + + output, synthesis_tokens = await self._synthesis_phase( + messages=messages, tool_results=tool_results, + model=model, agent_name=agent_name, task_type=task_type, + system_prompt=effective_system_prompt, compressor=compressor, + cancellation_token=cancellation_token, + ) + total_tokens += synthesis_tokens + trajectory.append(ReWOOStep( + step=len(plan.steps) + 1, + action="final_answer", + content=output, + tokens=synthesis_tokens, + )) + return ReActResult( + output=output, + trajectory=trajectory, + total_steps=len(trajectory), + total_tokens=total_tokens, + fallback_strategy="simplified_rewoo", + ) + except Exception as e: + logger.warning(f"Simplified ReWOO planning also failed: {e}") + return None + + async def _fallback_react( + self, + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + system_prompt: str | None = None, + trace_recorder: "TraceRecorder | None" = None, + memory_retriever: "MemoryRetriever | None" = None, + task_id: str | None = None, + compressor: "CompressionStrategy | None" = None, + retrieval_config: dict[str, Any] | None = None, + cancellation_token: CancellationToken | None = None, + confirmation_handler: Any | None = None, + ) -> ReActResult | None: + """ReAct fallback: delegate to ReActEngine""" + logger.warning("ReWOO planning failed, falling back to ReActEngine") + try: + react_result = await self._react_engine.execute( + messages=messages, + tools=tools, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=system_prompt, + trace_recorder=trace_recorder, + memory_retriever=memory_retriever, + task_id=task_id, + compressor=compressor, + retrieval_config=retrieval_config, + cancellation_token=cancellation_token, + timeout_seconds=0, # timeout already handled by outer wrapper + confirmation_handler=confirmation_handler, + ) + react_result.fallback_strategy = "react" + return react_result + except Exception as e: + logger.warning(f"ReAct fallback also failed: {e}") + return None + + async def _fallback_direct( + self, + messages: list[dict[str, str]], + model: str = "default", + agent_name: str = "", + task_type: str = "", + effective_system_prompt: str | None = None, + compressor: "CompressionStrategy | None" = None, + cancellation_token: CancellationToken | None = None, + trajectory: list[ReActStep] | None = None, + total_tokens: int = 0, + trace_recorder: "TraceRecorder | None" = None, + ) -> ReActResult | None: + """Direct fallback: simple LLM call without tools""" + logger.warning("Falling back to direct LLM call") + try: + direct_messages: list[dict[str, Any]] = [] + if effective_system_prompt: + direct_messages.append({"role": "system", "content": effective_system_prompt}) + direct_messages.extend(messages) + + if compressor: + try: + direct_messages = await compressor.compress(direct_messages) + except Exception as e: + logger.warning(f"Context compression failed in direct fallback: {e}") + + direct_response = await self._llm_gateway.chat( + messages=direct_messages, + model=model, + agent_name=agent_name, + task_type=task_type, + ) + total_tokens += direct_response.usage.total_tokens + + direct_step = ReWOOStep( + step=1, + action="final_answer", + content=direct_response.content, + tokens=direct_response.usage.total_tokens, + plan_step_id=None, + ) + if trajectory is not None: + trajectory.append(direct_step) + + if trace_recorder is not None: + trace_recorder.record_step( + step=1, + action="final_answer", + output_data={"content": direct_response.content}, + tokens_used=direct_response.usage.total_tokens, + ) + trace_recorder.end_trace(outcome="success") + + return ReActResult( + output=direct_response.content or "", + trajectory=trajectory or [direct_step], + total_steps=len(trajectory or [direct_step]), + total_tokens=total_tokens, + fallback_strategy="direct", + ) + except Exception as e: + logger.error(f"Direct LLM fallback also failed: {e}") + return None + + async def _fallback_plan_exec( + self, + messages: list[dict[str, str]], + tools: list[Tool] | None = None, + model: str = "default", + agent_name: str = "", + task_type: str = "", + effective_system_prompt: str | None = None, + compressor: "CompressionStrategy | None" = None, + cancellation_token: CancellationToken | None = None, + ) -> ReActResult | None: + """Plan-Exec fallback: plan then execute sequentially (like simplified ReWOO but with max_steps=5)""" + logger.warning("Falling back to plan-exec mode (max_steps=5)") + try: + tool_schemas = self._build_tool_schemas(tools) if tools else None + plan, plan_tokens = await self._plan_phase( + messages=messages, + tools=tools or [], + tool_schemas=tool_schemas, + model=model, + agent_name=agent_name, + task_type=task_type, + system_prompt=effective_system_prompt, + compressor=compressor, + cancellation_token=cancellation_token, + max_steps=5, + ) + if plan is not None and plan.steps: + trajectory: list[ReActStep] = [] + total_tokens = plan_tokens + tool_results: list[dict[str, Any]] = [] + for plan_step in plan.steps: + if cancellation_token is not None: + cancellation_token.check() + tool_result = await self._execute_tool(plan_step.tool_name, plan_step.arguments, tools or []) + rewoo_step = ReWOOStep( + step=plan_step.step_id, + action="tool_call", + tool_name=plan_step.tool_name, + arguments=plan_step.arguments, + result=tool_result, + tokens=0, + plan_step_id=plan_step.step_id, + ) + trajectory.append(rewoo_step) + tool_results.append({ + "step_id": plan_step.step_id, + "tool_name": plan_step.tool_name, + "arguments": plan_step.arguments, + "result": tool_result, + "reasoning": plan_step.reasoning, + }) + + output, synthesis_tokens = await self._synthesis_phase( + messages=messages, tool_results=tool_results, + model=model, agent_name=agent_name, task_type=task_type, + system_prompt=effective_system_prompt, compressor=compressor, + cancellation_token=cancellation_token, + ) + total_tokens += synthesis_tokens + trajectory.append(ReWOOStep( + step=len(plan.steps) + 1, + action="final_answer", + content=output, + tokens=synthesis_tokens, + )) + return ReActResult( + output=output, + trajectory=trajectory, + total_steps=len(trajectory), + total_tokens=total_tokens, + fallback_strategy="plan_exec", + ) + except Exception as e: + logger.warning(f"Plan-exec fallback also failed: {e}") + return None + # ── Phase Implementations ───────────────────────────── async def _plan_phase( @@ -1079,7 +1552,6 @@ class ReWOOEngine: # 尝试从 markdown 代码块中提取 if "```" in json_str: - import re code_block_match = re.search(r"```(?:json)?\s*\n(.*?)\n\s*```", json_str, re.DOTALL) if code_block_match: json_str = code_block_match.group(1).strip() diff --git a/src/agentkit/marketplace/auction.py b/src/agentkit/marketplace/auction.py index 4c48ce2..6fbc6c5 100644 --- a/src/agentkit/marketplace/auction.py +++ b/src/agentkit/marketplace/auction.py @@ -21,6 +21,14 @@ class Bid: capabilities: list[str] = field(default_factory=list) metadata: dict[str, Any] = field(default_factory=dict) + def __post_init__(self) -> None: + if self.estimated_cost < 0: + raise ValueError(f"estimated_cost must be non-negative, got {self.estimated_cost}") + if not (0.0 <= self.confidence <= 1.0): + raise ValueError(f"confidence must be between 0.0 and 1.0, got {self.confidence}") + if self.payment_offer < 0: + raise ValueError(f"payment_offer must be non-negative, got {self.payment_offer}") + @dataclass class AuctionResult: @@ -98,3 +106,125 @@ class AuctionHouse: wealth_factor = self._wealth_tracker.get_wealth_factor(bid.agent_name) score = (bid.confidence / max(bid.estimated_cost, 0.001)) * wealth_factor return score + + def filter_by_capabilities( + self, bidders: list[Bid], required_capabilities: list[str] + ) -> list[Bid]: + """Filter bidders whose capabilities include ALL required capabilities. + + Args: + bidders: List of Bid objects to filter. + required_capabilities: Capabilities that a bidder must ALL possess. + + Returns: + List of Bids whose capabilities list includes every required capability. + """ + if not required_capabilities: + return bidders + required_set = {cap.lower() for cap in required_capabilities} + return [ + b for b in bidders + if required_set.issubset({cap.lower() for cap in b.capabilities}) + ] + + async def run_vickrey_auction( + self, + task_description: str, + bidders: list[Bid], + required_capabilities: list[str] | None = None, + ) -> AuctionResult: + """Run a Vickrey (second-price sealed-bid) auction. + + In a Vickrey auction each bidder submits a sealed bid (their + estimated_cost). The lowest estimated_cost wins, but the winner + *pays* the second-lowest estimated_cost rather than their own. + This is incentive-compatible: agents' dominant strategy is to bid + truthfully. + + Steps: + 1. Filter by required_capabilities (if provided). + 2. Filter out bankrupt agents. + 3. If only 1 eligible bidder → wins, pays 0. + 4. If 2+ eligible bidders → lowest cost wins, pays second-lowest. + 5. Update WealthTracker: winner earns (payment - cost_estimate). + 6. Return AuctionResult with selection_reason. + + Args: + task_description: Description of the task being auctioned. + bidders: List of Bid objects. + required_capabilities: Optional list of capabilities that bidders + must possess to be eligible. + + Returns: + AuctionResult with the winner and Vickrey outcome details. + """ + # 1. Capability filtering + eligible = ( + self.filter_by_capabilities(bidders, required_capabilities) + if required_capabilities + else list(bidders) + ) + + # 2. Filter out bankrupt agents + eligible = [ + b for b in eligible + if not self._wealth_tracker.is_bankrupt(b.agent_name) + ] + + # No bidders at all + if not bidders: + return AuctionResult( + winner=None, + all_bids=bidders, + selection_reason="No bidders participated", + total_bidders=0, + ) + + # All eligible bidders filtered out (bankrupt or no capabilities) + if not eligible: + return AuctionResult( + winner=None, + all_bids=bidders, + selection_reason="No eligible bidders (bankrupt or missing capabilities)", + total_bidders=len(bidders), + ) + + # 3. Sort by estimated_cost ascending (lowest cost = best bid) + # Apply minimum cost floor to prevent zero-cost bid manipulation + MIN_COST = 0.001 + for b in eligible: + b.estimated_cost = max(b.estimated_cost, MIN_COST) + eligible = sorted(eligible, key=lambda b: b.estimated_cost) + winner = eligible[0] + + # 4. Determine payment (second-price rule) + if len(eligible) == 1: + # Single bidder: payment equals their own cost (no profit, no loss) + payment = winner.estimated_cost + else: + payment = eligible[1].estimated_cost + + # 5. Update WealthTracker + profit = payment - winner.estimated_cost + self._wealth_tracker.reward(winner.agent_name, profit) + + # 6. Build selection_reason + if len(eligible) == 1: + reason = ( + f"Vickrey auction: Agent '{winner.agent_name}' won as sole eligible bidder " + f"(cost={winner.estimated_cost}, payment={payment:.4f}, profit=0)" + ) + else: + second = eligible[1] + reason = ( + f"Vickrey auction: Agent '{winner.agent_name}' won with lowest cost " + f"({winner.estimated_cost}), pays second-lowest cost ({payment}) from " + f"'{second.agent_name}'; profit={profit:.4f}" + ) + + return AuctionResult( + winner=winner, + all_bids=bidders, + selection_reason=reason, + total_bidders=len(bidders), + ) diff --git a/src/agentkit/marketplace/wealth.py b/src/agentkit/marketplace/wealth.py index 5c5d5cf..ecc43a4 100644 --- a/src/agentkit/marketplace/wealth.py +++ b/src/agentkit/marketplace/wealth.py @@ -2,6 +2,8 @@ from __future__ import annotations +import threading + class WealthTracker: """Track agent wealth for auction mechanism. @@ -9,25 +11,31 @@ class WealthTracker: Agents earn wealth by completing tasks successfully. Agents lose wealth by failing tasks. Bankrupt agents (wealth <= -100) are excluded from auctions. + + Thread-safe: all mutations are protected by a threading.Lock. """ def __init__(self, initial_wealth: float = 100.0) -> None: self._balances: dict[str, float] = {} self._initial_wealth = initial_wealth + self._lock = threading.Lock() def get_wealth(self, agent_name: str) -> float: """Get agent's current wealth, defaulting to initial_wealth""" - return self._balances.get(agent_name, self._initial_wealth) + with self._lock: + return self._balances.get(agent_name, self._initial_wealth) def reward(self, agent_name: str, amount: float) -> None: """Reward agent for successful task completion""" - current = self.get_wealth(agent_name) - self._balances[agent_name] = current + amount + with self._lock: + current = self._balances.get(agent_name, self._initial_wealth) + self._balances[agent_name] = current + amount def penalize(self, agent_name: str, amount: float) -> None: """Penalize agent for task failure""" - current = self.get_wealth(agent_name) - self._balances[agent_name] = current - amount + with self._lock: + current = self._balances.get(agent_name, self._initial_wealth) + self._balances[agent_name] = current - amount def is_bankrupt(self, agent_name: str) -> bool: """Check if agent is bankrupt (wealth <= -100)""" @@ -46,5 +54,5 @@ class WealthTracker: return all_agents def get_wealth_factor(self, agent_name: str) -> float: - """Get wealth factor for scoring: 1.0 + (wealth / 1000.0)""" - return 1.0 + (self.get_wealth(agent_name) / 1000.0) + """Get wealth factor for scoring: max(0.01, 1.0 + (wealth / 1000.0))""" + return max(0.01, 1.0 + (self.get_wealth(agent_name) / 1000.0)) diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 0c09fae..1ca622d 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -20,6 +20,7 @@ from agentkit.router.intent import IntentRouter from agentkit.skills.base import Skill, SkillConfig from agentkit.skills.registry import SkillRegistry from agentkit.tools.registry import ToolRegistry +from agentkit.tools.skill_install import SkillInstallTool from agentkit.server.config import ServerConfig from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory, portal, evolution_dashboard, kb_management, skill_management, workflows, chat from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware @@ -104,6 +105,8 @@ async def lifespan(app: FastAPI): if server_config is not None and server_config._config_path: server_config.on_change = lambda cfg: _on_config_change(app, cfg) server_config.watch_config() + # Store event loop reference for thread-safe config reload + app.state._event_loop = asyncio.get_running_loop() logger.info("Config hot-reload enabled") # Start MCP servers if configured @@ -132,7 +135,10 @@ async def lifespan(app: FastAPI): "你必须先使用搜索工具查找准确和最新的信息,然后再回答。" "中文内容优先使用 baidu_search 工具,英文/国际内容使用 web_search。" "在能够搜索到真相的情况下,绝不猜测或编造答案。" - "始终优先搜索而不是给出可能不正确的信息。" + "始终优先搜索而不是给出可能不正确的信息。\n\n" + "技能安装:当需要安装技能时,使用 skill_install 工具,不要用 shell 执行 npm install。" + "skill_install 的 source 参数格式为 owner/repo@skill,例如 vercel-labs/skills@find-skills。" + "如果不知道完整 source,先用 shell 执行 `npx skills search ` 搜索。" ) effective_system_prompt = memory_store.build_system_prompt(memory_snapshot, base_prompt) @@ -156,6 +162,10 @@ async def lifespan(app: FastAPI): } agent._tool_registry.register(MemoryTool(memory_store=memory_store)) agent._tool_registry.register(ShellTool()) + agent._tool_registry.register(SkillInstallTool( + skill_registry=app.state.skill_registry, + tool_registry=app.state.tool_registry, + )) agent._tool_registry.register(BaiduSearchTool()) agent._tool_registry.register(WebSearchTool(**search_api_keys)) agent._tool_registry.register(WebCrawlTool()) @@ -238,13 +248,15 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None: - Config version is incremented for audit tracking Uses a lock to prevent concurrent config reloads from racing. - Thread-safe: uses threading.Event for cross-thread signaling. + Thread-safe: schedules reload onto the asyncio event loop via + asyncio.run_coroutine_threadsafe() since watchfiles calls this + from a non-asyncio thread. """ import threading - lock: asyncio.Lock = app.state._config_reload_lock + lock = app.state._config_reload_lock - # Thread-safe: set pending flag via threading.Event or call_soon_threadsafe + # Thread-safe: set pending flag via threading.Event if not hasattr(app.state, "_config_reload_event"): app.state._config_reload_event = threading.Event() @@ -310,12 +322,18 @@ def _on_config_change(app: FastAPI, config: ServerConfig) -> None: logger.info(f"Config reload complete (v{current_version})") - # Schedule the reload as a task (non-blocking for the watcher thread) + # Schedule the reload onto the event loop via run_coroutine_threadsafe + # (watchfiles calls _on_config_change from a non-asyncio thread) try: loop = asyncio.get_running_loop() - loop.create_task(_reload()) + asyncio.run_coroutine_threadsafe(_reload(), loop) except RuntimeError: - logger.warning("No running event loop, config reload deferred") + # No running loop — try getting the stored event loop reference + _loop = getattr(app.state, "_event_loop", None) + if _loop is not None and not _loop.is_closed(): + asyncio.run_coroutine_threadsafe(_reload(), _loop) + else: + logger.warning("No running event loop, config reload deferred") def create_app( @@ -482,6 +500,7 @@ def create_app( org_context=org_context, auction_enabled=auction_enabled, classifier=server_config.router.get("classifier", "heuristic") if server_config and server_config.router else "heuristic", + merged_llm_classify=server_config.router.get("merged_llm_classify", True) if server_config and server_config.router else True, ) app.state.cost_aware_router = cost_aware_router # Initialize task store from config diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index e976f80..4106995 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -193,7 +193,12 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques # Execute the Agent try: - react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway) + # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization) + react_engine = getattr(agent, "_react_engine", None) + if react_engine is None: + react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway) + else: + react_engine.reset() tools = agent._tool_registry.list_tools() if agent._tool_registry else [] system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None) result = await react_engine.execute( @@ -286,6 +291,10 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: cancellation_token = CancellationToken() + # Per-session concurrency guard to prevent unlimited task creation (DoS mitigation) + _MAX_CONCURRENT_TASKS = 4 + active_tasks: set[asyncio.Task] = set() + try: await websocket.send_json({"type": "connected", "session_id": session_id}) @@ -308,9 +317,27 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: content = msg.get("content", "") # Create a fresh CancellationToken for each message message_token = CancellationToken() - await _handle_chat_message( - websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations + + # Guard against unlimited concurrent tasks + # Clean up completed tasks first + active_tasks.difference_update(t for t in active_tasks if t.done()) + if len(active_tasks) >= _MAX_CONCURRENT_TASKS: + await websocket.send_json({ + "type": "error", + "data": {"message": "Too many concurrent requests. Please wait for the current task to complete."}, + }) + continue + + # Run in background task so the WebSocket receive loop stays free + # to process confirmation_reply / reply messages while the agent + # is waiting for user confirmation (otherwise deadlock). + task = asyncio.create_task( + _handle_chat_message( + websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations + ) ) + active_tasks.add(task) + task.add_done_callback(active_tasks.discard) elif msg_type == "reply": # Reply to AskHumanTool @@ -323,8 +350,12 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: # Reply to confirmation request confirmation_id = msg.get("confirmation_id") approved = msg.get("approved", False) + logger.info(f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}") if confirmation_id and confirmation_id in pending_confirmations: pending_confirmations[confirmation_id].set_result(approved) + logger.info(f"Confirmation {confirmation_id} set_result({approved})") + else: + logger.warning(f"Confirmation {confirmation_id!r} not found in pending_confirmations") elif msg_type == "cancel": cancellation_token.cancel() @@ -424,10 +455,17 @@ async def _handle_chat_message( chat_messages = await sm.get_chat_messages(session_id) # Execute Agent with streaming - react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) + # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization) + react_engine = getattr(agent, "_react_engine", None) + if react_engine is None: + react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) + else: + react_engine.reset() # Create confirmation handler that sends request to frontend and waits for reply - _pending_confirmations = pending_confirmations or {} + # Use the same dict object — do NOT use `or {}` because an empty dict is falsy + # and would create a new dict, breaking the shared state with the WS loop. + _pending_confirmations = pending_confirmations if pending_confirmations is not None else {} async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool: """Send confirmation request to frontend via WebSocket and wait for user reply.""" @@ -442,16 +480,28 @@ async def _handle_chat_message( }) # Create a Future and wait for the user's reply - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() future: asyncio.Future[bool] = loop.create_future() _pending_confirmations[confirmation_id] = future + logger.info(f"Confirmation request {confirmation_id} sent, waiting for reply") try: # Wait up to 5 minutes for user confirmation - return await asyncio.wait_for(future, timeout=300.0) + result = await asyncio.wait_for(future, timeout=300.0) + logger.info(f"Confirmation request {confirmation_id} resolved: {result}") + # Immediately notify frontend of the result so the card updates + # without waiting for the tool to re-execute + await websocket.send_json({ + "type": "confirmation_result", + "data": {"confirmation_id": confirmation_id, "approved": result}, + }) + return result except asyncio.TimeoutError: logger.warning(f"Confirmation request {confirmation_id} timed out") return False + except asyncio.CancelledError: + logger.warning(f"Confirmation request {confirmation_id} cancelled") + return False finally: _pending_confirmations.pop(confirmation_id, None) @@ -459,6 +509,7 @@ async def _handle_chat_message( try: final_content = "" + token_buffer: list[str] = [] async for event in react_engine.execute_stream( messages=chat_messages, tools=routing.tools, @@ -469,24 +520,51 @@ async def _handle_chat_message( confirmation_handler=_confirmation_handler, ): if event.event_type == "final_answer": + # Flush any buffered tokens as a single write + if token_buffer: + await websocket.send_json({"type": "token", "content": "".join(token_buffer)}) + token_buffer.clear() + # Then send final answer final_content = event.data.get("output", "") await websocket.send_json({ "type": "final_answer", "content": final_content, + "is_final": True, + }) + elif event.event_type == "token": + # Buffer tokens instead of sending immediately + token_buffer.append(event.data.get("content", "")) + elif event.event_type == "thinking": + # If we have buffered tokens, convert them to a thinking event + if token_buffer: + buffered_text = "".join(token_buffer) + token_buffer.clear() + await websocket.send_json({"type": "thinking", "content": buffered_text}) + # Also send the thinking event content + thinking_msg = event.data.get("message", "") + if thinking_msg: + await websocket.send_json({"type": "thinking", "content": thinking_msg}) + elif event.event_type == "tool_call": + # Convert buffered tokens to thinking (they were "thinking" text before tool call) + if token_buffer: + buffered_text = "".join(token_buffer) + token_buffer.clear() + await websocket.send_json({"type": "thinking", "content": buffered_text}) + await websocket.send_json({ + "type": "step", + "data": { + "event_type": event.event_type, + "step": event.step, + "data": event.data, + }, }) elif event.event_type == "confirmation_request": - # Already handled by confirmation_handler, just notify frontend pass elif event.event_type == "confirmation_result": await websocket.send_json({ "type": "confirmation_result", "data": event.data, }) - elif event.event_type == "token": - await websocket.send_json({ - "type": "token", - "content": event.data.get("content", ""), - }) else: await websocket.send_json({ "type": "step", diff --git a/src/agentkit/server/static/index.html b/src/agentkit/server/static/index.html index bd5c30d..cf42de2 100644 --- a/src/agentkit/server/static/index.html +++ b/src/agentkit/server/static/index.html @@ -8,6 +8,7 @@ +