diff --git a/agentkit.yaml b/agentkit.yaml
new file mode 100644
index 0000000..a5b0795
--- /dev/null
+++ b/agentkit.yaml
@@ -0,0 +1,42 @@
+server:
+  host: 0.0.0.0
+  port: 8001
+  workers: 1
+  rate_limit: 60
+llm:
+  providers:
+    bailian-coding:
+      api_key: ${DASHSCOPE_API_KEY}
+      base_url: https://coding.dashscope.aliyuncs.com/v1
+      type: openai
+      models:
+        qwen3.7-plus:
+          alias: default
+        qwen3.6-plus: {}
+        qwen3.5-plus: {}
+        qwen3-max-2026-01-23: {}
+        qwen3-coder-plus:
+          alias: coder
+        qwen3-coder-next: {}
+        kimi-k2.5: {}
+        glm-5: {}
+        glm-4.7: {}
+        MiniMax-M2.5: {}
+  model_aliases:
+    default: bailian-coding/qwen3.7-plus
+    coder: bailian-coding/qwen3-coder-plus
+session:
+  backend: memory
+bus:
+  backend: memory
+task_store:
+  backend: memory
+skills:
+  auto_discover: true
+  paths:
+    - ./configs/skills
+logging:
+  level: INFO
+  format: text
+router:
+  classifier: heuristic
diff --git a/docs/plans/2026-06-12-021-feat-chat-response-speed-optimization-plan.md b/docs/plans/2026-06-12-021-feat-chat-response-speed-optimization-plan.md
new file mode 100644
index 0000000..f8579c4
--- /dev/null
+++ b/docs/plans/2026-06-12-021-feat-chat-response-speed-optimization-plan.md
@@ -0,0 +1,372 @@
+---
+title: "feat: Chat Response Speed Optimization — Sub-1s First Token"
+status: active
+created: 2026-06-12
+plan-type: feat
+depth: standard
+---
+
+# feat: Chat Response Speed Optimization — Sub-1s First Token
+
+## Summary
+
+Optimize the fischer-agentkit conversation response pipeline to achieve sub-1-second first-token latency. The primary bottleneck is 1–2 extra LLM calls in the routing layer before the main ReAct loop. Secondary optimizations include parallel tool execution, async session I/O, and connection pool tuning. All changes are gated by configuration flags for safe rollback.
+
+## Problem Frame
+
+Users experience 5–10 second delays before seeing any response in the chat interface. The root cause is a serial chain of LLM calls: CostAwareRouter.quick_classify() → IntentRouter._classify_with_llm() → ReActEngine LLM Think. The first two calls are routing overhead that adds 2–4 seconds with no user-visible value. The third call is the actual reasoning step and cannot be eliminated, but its perceived latency can be reduced via streaming.
+
+**Current worst-case latency chain:**
+
+```
+User message → quick_classify() [1-2s LLM]
+             → _classify_with_llm() [1-2s LLM]
+             → ReActEngine Think [2-5s LLM]
+             → Tool Act [0.5-5s]
+             → First token visible to user
+```
+
+**Target latency chain:**
+
+```
+User message → Local rule classification [<1ms]
+             → ReActEngine Think (streaming) [first token in <1s]
+             → Tool Act (parallel when possible)
+             → First token visible to user
+```
+
+## Requirements
+
+| ID | Requirement | Priority |
+|----|-------------|----------|
+| R1 | First token latency must be under 1 second for simple conversations (greetings, Q&A) | P0 |
+| R2 | First token latency must be under 1 second for routed conversations when keyword matching succeeds | P0 |
+| R3 | Routing accuracy must not degrade by more than 10% compared to current LLM-based classification | P1 |
+| R4 | All optimizations must be configurable with on/off switches for safe rollback | P0 |
+| R5 | Parallel tool execution must preserve conversation history ordering | P1 |
+| R6 | Async session writes must not lose messages on process crash | P1 |
+
+## Key Technical Decisions
+
+### KTD1: Replace LLM quick_classify with local heuristic
+
+**Decision:** Replace the `CostAwareRouter.quick_classify()` LLM call with a zero-cost local heuristic based on message length, keyword density, and tool-hint detection.
+
+**Rationale:** The LLM classification adds 1–2s latency for a binary decision (simple vs complex). A local heuristic using the same signals already present in the message content (length, presence of tool-related keywords, question marks, etc.) can achieve ~85% accuracy at zero latency cost.
+
+**Alternative considered:** Cache LLM classification results. Rejected because cache hit rate would be near-zero for conversational messages (each is unique).
+
+### KTD2: Merge quick_classify and intent classification into single LLM call
+
+**Decision:** When LLM routing is needed (heuristic uncertainty), combine complexity scoring and intent classification into a single LLM call instead of two serial calls.
+
+**Rationale:** Currently `quick_classify()` and `_classify_with_llm()` are separate LLM calls that could be merged into one prompt returning both complexity score and matched skill. This halves the routing LLM overhead when it cannot be avoided.
+
+### KTD3: Parallel execution of independent tool_calls
+
+**Decision:** Execute multiple tool_calls from a single LLM response in parallel using `asyncio.gather()`, with results appended to conversation in tool_call_id order.
+
+**Rationale:** When the LLM returns multiple tool calls (e.g., search + calculate), they are independent and can run concurrently. Results must be appended in order for the next LLM call to see them correctly.
+
+**Risk:** Some tool calls may have implicit dependencies. Mitigation: the LLM generally does not return dependent calls in a single response (it waits for results before calling the next). Add a config flag `react.parallel_tools: false` to disable if needed.
+
+### KTD4: Fire-and-forget session writes with write-ahead buffer
+
+**Decision:** Make `SessionManager.append_message()` non-blocking by returning immediately after queuing the write, with a background task performing the actual I/O. Add a small in-memory buffer as write-ahead log to prevent message loss.
+
+**Rationale:** Session writes (especially `save_session()` for updated_at) add unnecessary blocking. The user doesn't need to wait for persistence before seeing a response. A write-ahead buffer ensures messages survive brief failures.
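
The queue-then-persist idea in KTD4 can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the project's implementation: the `submit`, `pending`, and `persist` names are hypothetical, and the buffer here lives only in process memory, so it covers slow or briefly failing writes but would not survive a process crash.

```python
import asyncio
from typing import Awaitable, Callable, List, Optional


class AsyncWriteQueue:
    """Hypothetical sketch: accept writes immediately, persist them in the background."""

    def __init__(self, persist: Callable[[str], Awaitable[None]]) -> None:
        self._persist = persist                 # assumed async storage callback
        self._queue: asyncio.Queue = asyncio.Queue()
        self._pending: List[str] = []           # queued but not yet persisted
        self.failed: List[str] = []             # writes whose persist call raised
        self._worker: Optional[asyncio.Task] = None

    def submit(self, message: str) -> None:
        """Queue a write and return without waiting for I/O."""
        if self._worker is None:
            self._worker = asyncio.get_running_loop().create_task(self._run())
        self._pending.append(message)
        self._queue.put_nowait(message)

    def pending(self) -> List[str]:
        """Messages a reader should merge with what the store already holds."""
        return list(self._pending)

    async def _run(self) -> None:
        while True:
            message = await self._queue.get()
            try:
                await self._persist(message)
            except Exception:
                self.failed.append(message)     # keep the worker alive on errors
            finally:
                self._pending.remove(message)
                self._queue.task_done()

    async def flush(self) -> None:
        """Wait until every queued write has been attempted (for shutdown)."""
        await self._queue.join()
```

A caller would `submit()` each message on the hot path and `await flush()` on session close or app shutdown; reads that need uncommitted messages can consult `pending()` first.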
+
+### KTD5: Unified httpx connection pool configuration
+
+**Decision:** Configure explicit `httpx.Limits` on all LLM provider clients with sensible defaults for keepalive and connection pooling.
+
+**Rationale:** Default httpx settings are reasonable but not optimized for high-frequency LLM API calls. Explicit configuration ensures consistent behavior across providers and enables tuning.
+
+## Scope Boundaries
+
+### In Scope
+
+- Routing layer optimization (CostAwareRouter, IntentRouter)
+- ReActEngine parallel tool execution
+- Session I/O async optimization
+- httpx connection pool tuning
+- Configuration flags for all changes
+- Test coverage for new behavior
+
+### Out of Scope
+
+- Frontend rendering optimization (separate concern)
+- LLM provider response time optimization (external dependency)
+- Memory/RAG pipeline optimization (covered by existing plan 009)
+- Compression strategy changes (covered by existing plan 013)
+- New LLM provider implementations
+
+### Deferred to Follow-Up Work
+
+- A/B testing framework for routing accuracy measurement
+- Performance benchmarking CI pipeline
+- WebSocket chat flow test coverage
+- WenxinProvider token-refresh client reuse
+
+---
+
+## Implementation Units
+
+### U1. Local heuristic classifier for CostAwareRouter
+
+**Goal:** Replace the LLM-based `quick_classify()` with a zero-cost local heuristic, gated by config flag `router.classifier: heuristic | llm`.
+
+**Requirements:** R1, R2, R3, R4
+
+**Dependencies:** None
+
+**Files:**
+- `src/agentkit/chat/skill_routing.py` — add `HeuristicClassifier` class, modify `CostAwareRouter.route()`
+- `src/agentkit/server/config.py` — add `router` config section
+- `agentkit.yaml` — add `router` section with defaults
+- `tests/unit/test_cost_aware_router.py` — add heuristic classifier tests
+
+**Approach:**
+
+1. Create `HeuristicClassifier` class with a `classify(content: str) -> float` method that returns a complexity score (0.0–1.0) based on:
+   - Message length: short messages (<20 chars) → low complexity
+   - Question patterns: presence of "为什么", "如何", "怎么", "how", "why", "what" → moderate complexity
+   - Tool hints: presence of tool-related keywords (existing `_tokenize_content` + `tool_hints` list already in code) → high complexity
+   - Multi-sentence: messages with multiple sentences → higher complexity
+   - Code patterns: presence of code-like patterns (backticks, brackets) → higher complexity
+
+2. Modify `CostAwareRouter.__init__` to accept a `classifier_mode` parameter (`"heuristic"` or `"llm"`)
+
+3. Modify `CostAwareRouter.route()` Phase 1 to use `HeuristicClassifier.classify()` when mode is `"heuristic"`
+
+4. Add `router` config section to `ServerConfig` with `classifier` field (default: `"heuristic"`)
+
+5. Wire config through `create_app()` to `CostAwareRouter`
+
+**Patterns to follow:** Existing `CostAwareRouter._match_layer0()` rule-based pattern; existing `_tokenize_content()` for keyword extraction.
+
+**Test scenarios:**
+- Short greeting → complexity < 0.3
+- Single question with "如何" → complexity 0.3–0.7
+- Multi-step request with tool keywords → complexity > 0.7
+- Code-related request → complexity > 0.7
+- Empty string → complexity 0.0
+- Very long message (>500 chars) → complexity > 0.5
+- Config flag `classifier: llm` falls back to LLM classification
+- Config flag `classifier: heuristic` uses local heuristic
+
+**Verification:** All existing `test_cost_aware_router.py` tests pass; new heuristic tests pass; manual test shows first-token latency <1s for simple messages.
+
+---
+
+### U2. Merged routing LLM call
+
+**Goal:** When LLM routing is needed (heuristic uncertain or config forces LLM), combine complexity scoring and intent classification into a single LLM call.
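
For the merged call, the reply has to carry both the skill match and the complexity score, and a malformed reply has to be detectable so the router can fall back. The sketch below assumes the reply is a JSON object with `skill`, `confidence`, and `complexity` fields; `MergedRoute` and `parse_merged_response` are hypothetical names used only for illustration.

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class MergedRoute:
    skill: Optional[str]   # None when no skill matched
    confidence: float      # clamped to [0.0, 1.0]
    complexity: float      # clamped to [0.0, 1.0]


def _clamp(value: float) -> float:
    return max(0.0, min(1.0, value))


def parse_merged_response(raw: str) -> Optional[MergedRoute]:
    """Parse the combined routing reply; return None when it is unusable.

    A None result is the signal for the caller to fall back to another
    routing path instead of trusting a malformed reply.
    """
    try:
        data = json.loads(raw)
        if not isinstance(data, dict):
            return None
        complexity = float(data["complexity"])           # required field
        confidence = float(data.get("confidence", 0.0))  # optional field
    except (ValueError, KeyError, TypeError):
        return None
    skill = data.get("skill")
    if not isinstance(skill, str) or not skill:
        skill = None
    return MergedRoute(skill, _clamp(confidence), _clamp(complexity))
```

Returning `None` rather than raising keeps the fallback decision in the router, which is where the config flag lives.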
+
+**Requirements:** R2, R3, R4
+
+**Dependencies:** U1
+
+**Files:**
+- `src/agentkit/chat/skill_routing.py` — add `MergedRouter` method
+- `src/agentkit/router/intent.py` — add `route_with_complexity()` method
+- `tests/unit/test_cost_aware_router.py` — add merged routing tests
+- `tests/unit/test_intent_router.py` — add merged routing tests
+
+**Approach:**
+
+1. Add `IntentRouter.route_with_complexity()` method that returns both a `RoutingResult` and a complexity score in a single LLM call. The prompt asks the LLM to return `{"skill": "...", "confidence": 0.9, "complexity": 0.5}`.
+
+2. Modify `CostAwareRouter.route()` so that when `classifier` is `"llm"`, it calls `route_with_complexity()` instead of making two separate calls.
+
+3. When `classifier` is `"heuristic"` and the heuristic returns uncertainty (score in 0.3–0.7 range), use `route_with_complexity()` as a single fallback call.
+
+**Patterns to follow:** Existing `_classify_with_llm()` prompt structure; existing `quick_classify()` prompt structure.
+
+**Test scenarios:**
+- Merged call returns both skill match and complexity score
+- Merged call with no matching skill returns complexity only
+- Merged call with invalid LLM response falls back to rule-based evaluation
+- Heuristic uncertain + merged call produces correct routing
+- Config `classifier: llm` uses merged call instead of two separate calls
+
+**Verification:** Existing tests pass; merged routing reduces LLM calls from 2 to 1 when LLM routing is needed.
+
+---
+
+### U3. Parallel tool execution in ReActEngine
+
+**Goal:** Execute multiple independent tool_calls from a single LLM response in parallel, gated by config flag `react.parallel_tools`.
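
The ordering guarantee behind this goal follows from `asyncio.gather()` returning results in the order its awaitables were passed, regardless of which finishes first. A minimal standalone sketch, assuming tool calls arrive as `(tool_call_id, name, arguments)` tuples and tools are plain async callables (both are assumptions for illustration, and `run_tool_calls` is a hypothetical helper, not the engine's method):

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Tuple

ToolFn = Callable[..., Awaitable[Any]]
ToolCall = Tuple[str, str, Dict[str, Any]]  # (tool_call_id, name, arguments)


async def run_tool_calls(calls: List[ToolCall], tools: Dict[str, ToolFn]) -> List[dict]:
    """Run all calls concurrently and return tool messages in request order."""
    results = await asyncio.gather(
        *(tools[name](**arguments) for _call_id, name, arguments in calls),
        return_exceptions=True,  # a failing tool must not cancel its siblings
    )
    messages = []
    # gather() preserves input order, so zip() pairs each call with its result.
    for (call_id, name, _arguments), result in zip(calls, results):
        if isinstance(result, BaseException):
            result = {"error": str(result)}
        messages.append(
            {"role": "tool", "tool_call_id": call_id, "name": name, "content": result}
        )
    return messages
```

Because failures come back as values, one failing tool still leaves the other results intact for the next LLM call.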
+
+**Requirements:** R5, R4
+
+**Dependencies:** None
+
+**Files:**
+- `src/agentkit/core/react.py` — modify `_execute_loop()` and `execute_stream()` to use `asyncio.gather()`
+- `src/agentkit/server/config.py` — add `react.parallel_tools` config
+- `agentkit.yaml` — add `react` section
+- `tests/unit/test_react_engine.py` — add parallel execution tests
+
+**Approach:**
+
+1. Add `parallel_tools: bool = True` parameter to `ReActEngine.__init__`.
+
+2. In `_execute_loop()` and `execute_stream()`, when `response.tool_calls` has more than one item and `parallel_tools` is True:
+   - Execute all tool calls concurrently with `asyncio.gather(*[_execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls], return_exceptions=True)`
+   - Build tool result messages in tool_call_id order
+   - Append all results to conversation in order
+
+3. When `parallel_tools` is False, keep current serial behavior.
+
+4. For `execute_stream()`, yield all `tool_call` events first, then all `tool_result` events after gather completes.
+
+**Patterns to follow:** Existing `_execute_tool()` method; existing `Orchestrator._execute_plan()` parallel group pattern in `orchestrator.py`.
+
+**Test scenarios:**
+- Two independent tools execute in parallel, both results present in conversation
+- Parallel execution preserves tool_call_id ordering in conversation
+- One tool fails, other succeeds — partial results preserved
+- `parallel_tools: false` falls back to serial execution
+- Single tool_call works identically with parallel mode on/off
+- Tool results appended to conversation in correct order for next LLM call
+
+**Verification:** Existing ReAct tests pass; new parallel tests pass; manual test with multi-tool request shows reduced execution time.
+
+---
+
+### U4. Async session writes with write-ahead buffer
+
+**Goal:** Make `SessionManager.append_message()` non-blocking by deferring `save_session()` and making `append_message()` fire-and-forget with a small write-ahead buffer.
+
+**Requirements:** R6, R4
+
+**Dependencies:** None
+
+**Files:**
+- `src/agentkit/session/manager.py` — add async write queue and WAL buffer
+- `tests/unit/test_session_manager.py` — add async write tests
+
+**Approach:**
+
+1. Add an `AsyncWriteQueue` to `SessionManager` that:
+   - Accepts write operations (append_message, save_session) as tasks
+   - Executes them in a background `asyncio.Task`
+   - Maintains a small in-memory buffer of recent writes for crash recovery
+   - Provides `await flush()` for graceful shutdown
+
+2. Modify `append_message()`:
+   - Keep `get_session()` + validation as synchronous (needed for error checking)
+   - Queue `store.append_message()` + `store.save_session()` as a single async task
+   - Return the `Message` object immediately without waiting for persistence
+
+3. Modify `get_chat_messages()` to first check the WAL buffer for uncommitted messages, then fall back to store.
+
+4. Add `flush()` method called during session close and app shutdown.
+
+**Patterns to follow:** Existing `BackgroundRunner` pattern in `server/runner.py`; existing `TaskStore` cleanup pattern.
+
+**Test scenarios:**
+- append_message returns immediately, message persisted asynchronously
+- get_chat_messages includes WAL-buffered messages not yet persisted
+- flush() ensures all pending writes complete
+- Multiple rapid append_messages are batched correctly
+- Session close flushes pending writes
+- App shutdown flushes pending writes
+
+**Verification:** Existing session tests pass; new async write tests pass; no message loss during normal operation.
+
+---
+
+### U5. httpx connection pool configuration
+
+**Goal:** Configure explicit `httpx.Limits` on all LLM provider clients for optimal connection reuse.
+
+**Requirements:** R4
+
+**Dependencies:** None
+
+**Files:**
+- `src/agentkit/llm/providers/openai.py` — add `httpx.Limits` configuration
+- `src/agentkit/llm/providers/anthropic.py` — add `httpx.Limits` configuration
+- `src/agentkit/llm/providers/gemini.py` — add `httpx.Limits` configuration
+- `src/agentkit/llm/config.py` — add connection pool config fields
+- `tests/unit/test_llm_provider.py` — verify connection pool settings
+
+**Approach:**
+
+1. Add `connection_pool` section to `ProviderConfig`:
+   - `max_connections: int = 100`
+   - `max_keepalive_connections: int = 20`
+   - `keepalive_expiry: float = 30.0`
+
+2. Pass `httpx.Limits` to all provider constructors.
+
+3. Configure `httpx.AsyncClient` with explicit limits in each provider.
+
+**Patterns to follow:** Existing `ProviderConfig` dataclass pattern; existing `timeout` parameter pattern.
+
+**Test scenarios:**
+- Provider creates httpx client with configured limits
+- Default limits applied when not configured
+- Custom limits from config override defaults
+- Connection reuse verified via mock
+
+**Verification:** Existing provider tests pass; connection pool settings applied correctly.
+
+---
+
+### U6. Chat route pipeline optimization
+
+**Goal:** Optimize the WebSocket chat handler to overlap I/O operations and reduce serial waits.
+
+**Requirements:** R1, R2
+
+**Dependencies:** U1, U4
+
+**Files:**
+- `src/agentkit/server/routes/chat.py` — parallelize session operations
+- `tests/unit/test_chat_routes.py` — add pipeline optimization tests
+
+**Approach:**
+
+1. In `_handle_chat_message()`, parallelize:
+   - `sm.append_message()` (user message) and `sm.get_chat_messages()` — these can run concurrently since append_message now returns immediately (U4)
+
+2. Move assistant message `append_message()` to fire-and-forget after streaming completes (already non-blocking with U4).
+
+3. Reuse the `ReActEngine` instance per session instead of creating a new one per message.
+
+**Patterns to follow:** Existing `asyncio.gather` pattern in orchestrator.
+
+**Test scenarios:**
+- User message append and chat messages retrieval run concurrently
+- Assistant message persisted after streaming completes
+- ReActEngine reuse across messages in same session
+- Error during parallel operations handled gracefully
+
+**Verification:** Existing chat route tests pass; manual test shows reduced latency.
+
+---
+
+## Risks & Mitigations
+
+| Risk | Likelihood | Impact | Mitigation |
+|------|-----------|--------|------------|
+| Heuristic classifier misroutes requests | Medium | Medium — wrong skill or wrong execution mode | Config flag to revert to LLM; monitor routing accuracy via telemetry |
+| Parallel tool execution breaks implicit dependencies | Low | High — incorrect results | Config flag to disable; LLM rarely returns dependent calls in single response |
+| Async session writes lose messages on crash | Low | Medium — missing conversation history | WAL buffer + flush on shutdown; acceptable trade-off for speed |
+| Merged LLM call prompt confuses the model | Low | Low — falls back to separate calls | Fallback to separate calls on parse failure |
+
+## System-Wide Impact
+
+- **Routing layer:** CostAwareRouter and IntentRouter behavior changes when heuristic mode is active; existing LLM-based routing preserved as fallback
+- **ReAct engine:** Tool execution changes from serial to parallel; conversation history ordering preserved
+- **Session management:** Write operations become asynchronous; read operations check WAL buffer
+- **Configuration:** New `router` and `react` config sections in `agentkit.yaml`
+- **Telemetry:** Existing OpenTelemetry spans continue to work; new spans for heuristic classification
+
+## Open Questions
+
+- What is the actual routing accuracy of the current LLM-based classifier? Need baseline measurement before comparing heuristic accuracy.
+- Should the heuristic classifier be extensible (plugin pattern) or hardcoded? Start with a hardcoded classifier for simplicity; it can be extended later.
diff --git a/src/agentkit/chat/skill_routing.py b/src/agentkit/chat/skill_routing.py
index 4d91212..c95f57a 100644
--- a/src/agentkit/chat/skill_routing.py
+++ b/src/agentkit/chat/skill_routing.py
@@ -282,11 +282,101 @@ def _tokenize_content(content: str) -> list[str]:
     return tokens
 
 
+class HeuristicClassifier:
+    """Zero-cost local heuristic classifier that replaces the LLM quick_classify.
+
+    Scores complexity (0.0-1.0) from message length, keyword density, and tool hints,
+    with no LLM call and under 1 ms of latency.
+    """
+
+    # High-complexity hints (tools or multi-step reasoning needed)
+    _HIGH_COMPLEXITY_HINTS = {
+        # Tools / execution
+        "执行", "运行", "命令", "终端", "shell", "bash", "script",
+        "安装", "部署", "启动", "停止", "重启", "配置",
+        "搜索", "查找", "联网", "search", "find", "query",
+        "文件", "目录", "创建", "删除", "修改", "编辑",
+        "run", "execute", "install", "deploy", "start", "stop",
+        "restart", "file", "directory", "create", "delete", "modify",
+        # Multi-step / analysis
+        "分析", "比较", "对比", "评估", "调研", "研究",
+        "设计", "规划", "方案", "架构", "实现", "开发",
+        "analyze", "compare", "evaluate", "research", "design",
+        "plan", "implement", "develop", "build",
+        # Code
+        "代码", "编程", "函数", "类", "接口", "调试", "重构",
+        "code", "program", "function", "class", "interface", "debug", "refactor",
+        "python", "java", "javascript", "typescript", "sql", "api",
+    }
+
+    # Medium-complexity hints (simple questions that still need some thought)
+    _MEDIUM_COMPLEXITY_HINTS = {
+        "如何", "怎么", "怎样", "为什么", "什么原因", "区别",
+        "how", "why", "what", "difference", "explain",
+        "能", "可以", "是否", "会不会",
+        "推荐", "建议", "选择", "哪个",
+        "recommend", "suggest", "choose", "which",
+    }
+
+    def classify(self, content: str) -> float:
+        """Score message complexity (0.0-1.0).
+
+        Scoring rules:
+        - No complexity hints → base score 0.15
+        - Medium-complexity keyword present → base score 0.45
+        - One high-complexity keyword → 0.65; two or more → 0.8
+        - Long or multi-sentence message → bonus
+        - Code patterns (backticks / brackets) → bonus
+        """
+        if not content or not content.strip():
+            return 0.0
+
+        content_lower = content.lower()
+        score = 0.0
+
+        # 1. Keyword matching
+        high_hits = sum(1 for h in self._HIGH_COMPLEXITY_HINTS if h in content_lower)
+        medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS if m in content_lower)
+
+        if high_hits >= 2:
+            score = 0.8
+        elif high_hits == 1:
+            score = 0.65
+        elif medium_hits >= 1:
+            score = 0.45
+        else:
+            score = 0.15
+
+        # 2. Message-length bonus
+        length = len(content)
+        if length > 200:
+            score += 0.15
+        elif length > 100:
+            score += 0.1
+        elif length > 50:
+            score += 0.05
+
+        # 3. Multi-sentence bonus (split on commas, periods, newlines; empty segments ignored)
+        sentence_count = sum(1 for s in re.split(r'[,。!?;\n,.!?;]', content) if s.strip())
+        if sentence_count >= 4:
+            score += 0.1
+        elif sentence_count >= 2:
+            score += 0.05
+
+        # 4. Code-pattern bonus
+        if '`' in content:
+            score += 0.15
+        if re.search(r'[\{\}\[\]\(\)]', content):
+            score += 0.05
+
+        return max(0.0, min(1.0, score))
+
+
 class CostAwareRouter:
     """Three-layer cost-aware router.
 
     Layer 0: rule matching (zero cost): @skill: prefix / greetings / simple chat
-    Layer 1: LLM quick classification (~100 tokens): complexity scoring + IntentRouter
+    Layer 1: complexity classification: heuristic (zero cost) or LLM (~100 tokens)
     Layer 2: capability matching / auction (optional): delegate high-complexity tasks to the best Agent
     """
 
@@ -296,11 +386,14 @@ class CostAwareRouter:
         model: str = "default",
         org_context: Any = None,
         auction_enabled: bool = False,
+        classifier: str = "heuristic",
     ):
         self._llm_gateway = llm_gateway
         self._model = model
         self._org_context = org_context
         self._auction_enabled = auction_enabled
+        self._classifier = classifier
+        self._heuristic = HeuristicClassifier()
 
     # -- Layer 0: Rule-based (zero cost) ------------------------------------
 
@@ -516,13 +609,21 @@ class CostAwareRouter:
                     span.set_attribute("route.target", "default")
                     return result
 
-            # ---- Layer 1: LLM quick classify (~100 tokens) ----
-            complexity = await self.quick_classify(clean_content)
-            trace.append({
-                "layer": 1,
-                "method": "quick_classify",
-                "complexity": complexity,
-            })
+            # ---- Layer 1: Complexity classification ----
+            if self._classifier == "heuristic":
+                complexity = self._heuristic.classify(clean_content)
+                trace.append({
+                    "layer": 1,
+                    "method": "heuristic_classify",
+                    "complexity": complexity,
+                })
+            else:
+                complexity = await self.quick_classify(clean_content)
+                trace.append({
+                    "layer": 1,
+                    "method": "quick_classify",
+                    "complexity": complexity,
+                })
 
             # Low complexity → default agent
             if complexity < 0.3:
diff --git a/src/agentkit/cli/chat.py b/src/agentkit/cli/chat.py
index d715bf5..c3b8681 100644
--- a/src/agentkit/cli/chat.py
+++ b/src/agentkit/cli/chat.py
@@ -353,6 +353,9 @@ def _build_gateway(server_config: "ServerConfig") -> "LLMGateway":
                     max_tokens=pconf.max_tokens,
                     base_url=pconf.base_url or "https://api.anthropic.com",
                     timeout=pconf.timeout,
+                    max_connections=pconf.max_connections,
+                    max_keepalive_connections=pconf.max_keepalive_connections,
+                    keepalive_expiry=pconf.keepalive_expiry,
                 )
             elif pconf.type == "gemini":
                 provider = GeminiProvider(
@@ -361,11 +364,17 @@ def _build_gateway(server_config: "ServerConfig") -> "LLMGateway":
                     max_output_tokens=pconf.max_tokens,
                     base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
                     timeout=pconf.timeout,
+                    max_connections=pconf.max_connections,
+                    max_keepalive_connections=pconf.max_keepalive_connections,
+                    keepalive_expiry=pconf.keepalive_expiry,
                 )
             else:
                 provider = OpenAICompatibleProvider(
                     api_key=pconf.api_key,
                     base_url=pconf.base_url,
+                    max_connections=pconf.max_connections,
+                    max_keepalive_connections=pconf.max_keepalive_connections,
+                    keepalive_expiry=pconf.keepalive_expiry,
                 )
             gateway.register_provider(name, provider)
         except Exception as e:
diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py
index db9942e..b1ae0f8 100644
--- a/src/agentkit/core/react.py
+++ b/src/agentkit/core/react.py
@@ -61,7 +61,7 @@ class ReActResult:
 class ReActEvent:
     """ReAct execution event"""
 
-    event_type: str  # "thinking", "token", "tool_call", "tool_result", "final_answer", "error"
+    event_type: str  # "thinking", "token", "tool_call", "tool_result", "confirmation_request", "confirmation_result", "final_answer", "error"
     step: int
     data: dict[str, Any] = field(default_factory=dict)
     timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@@ -74,12 +74,13 @@ class ReActEngine:
     Lets the Agent reason autonomously and choose tools to complete a task.
     """
 
-    def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0):
+    def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = True):
         if max_steps < 1:
             raise ValueError(f"max_steps must be >= 1, got {max_steps}")
         self._llm_gateway = llm_gateway
         self._max_steps = max_steps
         self._default_timeout = default_timeout
+        self._parallel_tools = parallel_tools
 
     async def execute(
         self,
@@ -293,41 +294,81 @@ class ReActEngine:
                 }
                 conversation.append(assistant_msg)
 
-                # Execute each tool call
-                for tc in response.tool_calls:
-                    tool_start = time.monotonic()
-                    tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
-                    tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
-
-                    react_step = ReActStep(
-                        step=step,
-                        action="tool_call",
-                        tool_name=tc.name,
-                        arguments=tc.arguments,
-                        result=tool_result,
-                        tokens=step_tokens,
+                # Execute tool calls
+                if self._parallel_tools and len(response.tool_calls) > 1:
+                    # Run multiple tool calls in parallel
+                    tool_results = await asyncio.gather(
+                        *[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls],
+                        return_exceptions=True,
                     )
-                    trajectory.append(react_step)
+                    for idx, tc in enumerate(response.tool_calls):
+                        tool_result = tool_results[idx]
+                        if isinstance(tool_result, Exception):
+                            tool_result = {"error": str(tool_result)}
 
-                    # Record the tool-call step
-                    if trace_recorder is not None:
-                        tool_error = None
-                        if isinstance(tool_result, dict) and "error" in tool_result:
-                            tool_error = tool_result["error"]
-                        trace_recorder.record_step(
+                        react_step = ReActStep(
                             step=step,
                             action="tool_call",
                             tool_name=tc.name,
-                            input_data=tc.arguments,
-                            output_data=tool_result,
-                            duration_ms=tool_duration_ms,
-                            tokens_used=0,
-                            error=tool_error,
+                            arguments=tc.arguments,
+                            result=tool_result,
+                            tokens=step_tokens,
                         )
+                        trajectory.append(react_step)
 
-                    # Observe: append the tool result to the conversation history
-                    tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
-                    conversation.append(tool_msg)
+                        if trace_recorder is not None:
+                            tool_error = None
+                            if isinstance(tool_result, dict) and "error" in tool_result:
+                                tool_error = tool_result["error"]
+                            trace_recorder.record_step(
+                                step=step,
+                                action="tool_call",
+                                tool_name=tc.name,
+                                input_data=tc.arguments,
+                                output_data=tool_result,
+                                duration_ms=0,  # per-tool timing is not measured in parallel mode
+                                tokens_used=0,
+                                error=tool_error,
+                            )
+
+                        tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
+                        conversation.append(tool_msg)
+                else:
+                    # Serial execution (single tool call or parallel_tools=False)
+                    for tc in response.tool_calls:
+                        tool_start = time.monotonic()
+                        tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
+                        tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
+
+                        react_step = ReActStep(
+                            step=step,
+                            action="tool_call",
+                            tool_name=tc.name,
+                            arguments=tc.arguments,
+                            result=tool_result,
+                            tokens=step_tokens,
+                        )
+                        trajectory.append(react_step)
+
+                        # Record the tool-call step
+                        if trace_recorder is not None:
+                            tool_error = None
+                            if isinstance(tool_result, dict) and "error" in tool_result:
+                                tool_error = tool_result["error"]
+                            trace_recorder.record_step(
+                                step=step,
+                                action="tool_call",
+                                tool_name=tc.name,
+                                input_data=tc.arguments,
+                                output_data=tool_result,
+                                duration_ms=tool_duration_ms,
+                                tokens_used=0,
+                                error=tool_error,
+                            )
+
+                        # Observe: append the tool result to the conversation history
+                        tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
+                        conversation.append(tool_msg)
 
                 # Incremental compression: compress conversation if it's getting long
                 if self._should_compress(conversation, compressor):
@@ -475,6 +516,7 @@ class ReActEngine:
         retrieval_config: dict[str, Any] | None = None,
         cancellation_token: CancellationToken | None = None,
         timeout_seconds: float | None = None,
+        confirmation_handler: Any | None = None,
     ):
         """Execute ReAct loop, yielding ReActEvent objects.
 
@@ -627,6 +669,68 @@ class ReActEngine:
                     tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
                     tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
 
+                    # Detect a confirmation request returned by the tool
+                    if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
+                        confirmation_id = tool_result["confirmation_id"]
+                        command = tool_result.get("command", "")
+                        reason = tool_result.get("reason", "")
+
+                        # Yield the confirmation-request event
+                        yield ReActEvent(
+                            event_type="confirmation_request",
+                            step=step,
+                            data={
+                                "confirmation_id": confirmation_id,
+                                "tool_name": tc.name,
+                                "command": command,
+                                "reason": reason,
+                            },
+                        )
+
+                        # Wait for the user's confirmation
+                        approved = False
+                        if confirmation_handler is not None:
+                            try:
+                                approved = await confirmation_handler(confirmation_id, command, reason)
+                            except Exception as e:
+                                logger.warning(f"Confirmation handler error: {e}")
+
+                        if approved:
+                            # User approved: temporarily bypass the safety check and re-run
+                            tool = self._find_tool(tc.name, tools)
+                            if tool and hasattr(tool, '_is_dangerous'):
+                                # Save the original _is_dangerous and disable it temporarily
+                                original_is_dangerous = tool._is_dangerous
+                                tool._is_dangerous = lambda cmd: False
+                                try:
+                                    tool_result = await tool.safe_execute(**tc.arguments)
+                                finally:
+                                    tool._is_dangerous = original_is_dangerous
+                            else:
+                                tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
+
+                            yield ReActEvent(
+                                event_type="confirmation_result",
+                                step=step,
+                                data={"confirmation_id": confirmation_id, "approved": True},
+                            )
+                        else:
+                            # User rejected the command
+                            tool_result = {
+                                "output": "",
+                                "exit_code": 126,
+                                "is_error": True,
+                                "error_type": "permission_denied",
+                                "message": f"用户拒绝执行命令: {command[:100]}",  # "User declined to run command: ..."
+                            }
+                            yield ReActEvent(
+                                event_type="confirmation_result",
+                                step=step,
+                                data={"confirmation_id": confirmation_id, "approved": False},
+                            )
+
+                        tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
+
                     react_step = ReActStep(
                         step=step,
                         action="tool_call",
diff --git a/src/agentkit/llm/config.py b/src/agentkit/llm/config.py
index 91fa3af..67f8a8b 100644
--- a/src/agentkit/llm/config.py
+++ b/src/agentkit/llm/config.py
@@ -18,6 +18,9 @@ class ProviderConfig:
     type: str = "openai"  # "openai" | "anthropic" | "gemini"
     max_tokens: int = 4096  # Anthropic: default max_tokens
     timeout: float = 120.0  # Anthropic: request timeout
+    max_connections: int = 100  # httpx pool: maximum number of connections
+    max_keepalive_connections: int = 20  # httpx pool: maximum number of keep-alive connections
+    keepalive_expiry: float = 30.0  # httpx: keep-alive connection expiry (seconds)
     retry: RetryConfig | None = None
     circuit_breaker: CircuitBreakerConfig | None = None
 
@@ -68,6 +71,9 @@ class LLMConfig:
                 type=pconf.get("type", "openai"),
                 max_tokens=pconf.get("max_tokens", 4096),
                 timeout=pconf.get("timeout", 120.0),
+                max_connections=pconf.get("max_connections", 100),
+                max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
+                keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
                 retry=retry,
                 circuit_breaker=circuit_breaker,
             )
diff --git a/src/agentkit/llm/providers/anthropic.py b/src/agentkit/llm/providers/anthropic.py
index 49a8c0d..a26b94d 100644
--- a/src/agentkit/llm/providers/anthropic.py
+++ b/src/agentkit/llm/providers/anthropic.py
@@ -56,6 +56,9 @@ class AnthropicProvider(LLMProvider):
         thinking_enabled: bool = False,
         retry_config: RetryConfig | None = None,
         circuit_breaker_config: CircuitBreakerConfig | None = None,
+        max_connections: int = 100,
+        max_keepalive_connections: int = 20,
+        keepalive_expiry: float = 30.0,
     ):
         self._api_key = api_key
         self._model = model
@@ -63,6 +66,11 @@ class AnthropicProvider(LLMProvider):
         self._base_url = base_url.rstrip("/")
         self._timeout = timeout
         self._thinking_enabled = thinking_enabled
+        self._limits = httpx.Limits(
+            max_connections=max_connections,
+            max_keepalive_connections=max_keepalive_connections,
+            keepalive_expiry=keepalive_expiry,
+        )
         self._client: httpx.AsyncClient | None = None
         self._retry_policy = RetryPolicy(retry_config) if retry_config else None
         self._circuit_breaker = (
@@ -74,7 +82,7 @@ class AnthropicProvider(LLMProvider):
     def _get_client(self) -> httpx.AsyncClient:
         """Lazy client initialization"""
         if self._client is None:
-            self._client = httpx.AsyncClient(timeout=self._timeout)
+            self._client = httpx.AsyncClient(timeout=self._timeout, limits=self._limits)
         return self._client
 
     async def close(self) -> None:
diff --git a/src/agentkit/llm/providers/gemini.py b/src/agentkit/llm/providers/gemini.py
index a9d4901..0b57efe 100644
--- a/src/agentkit/llm/providers/gemini.py
+++ b/src/agentkit/llm/providers/gemini.py
@@ -53,6 +53,9 @@ class GeminiProvider(LLMProvider):
         safety_settings: list | None = None,
         retry_config: RetryConfig | None = None,
         circuit_breaker_config: CircuitBreakerConfig | None = None,
+        max_connections: int = 100,
+        max_keepalive_connections: int = 20,
+        keepalive_expiry: float = 30.0,
     ):
         self._api_key = api_key
         self._model = model
@@ -60,6 +63,11 @@ class GeminiProvider(LLMProvider):
         self._base_url = base_url.rstrip("/")
         self._timeout = timeout
         self._safety_settings = safety_settings
+        self._limits = httpx.Limits(
+            max_connections=max_connections,
+            max_keepalive_connections=max_keepalive_connections,
+            keepalive_expiry=keepalive_expiry,
+        )
         self._client: httpx.AsyncClient | None = None
         self._retry_policy = RetryPolicy(retry_config) if retry_config else None
         self._circuit_breaker = (
@@ -71,7 +79,7 @@ class GeminiProvider(LLMProvider):
     def _get_client(self) -> httpx.AsyncClient:
         """Lazy client initialization"""
         if self._client is None:
-            self._client = httpx.AsyncClient(timeout=self._timeout)
+            self._client = httpx.AsyncClient(timeout=self._timeout, limits=self._limits)
         return self._client
 
     async def close(self) -> None:
diff --git a/src/agentkit/llm/providers/openai.py b/src/agentkit/llm/providers/openai.py
index 2399583..45a544a 100644
--- a/src/agentkit/llm/providers/openai.py
+++ b/src/agentkit/llm/providers/openai.py
@@ -46,11 +46,19 @@ class OpenAICompatibleProvider(LLMProvider):
         default_model: str = "gpt-4o-mini",
         retry_config: RetryConfig | None = None,
         circuit_breaker_config: CircuitBreakerConfig | None = None,
+        max_connections: int = 100,
+        max_keepalive_connections: int = 20,
+        keepalive_expiry: float = 30.0,
     ):
         self._api_key = api_key
         self._base_url = base_url.rstrip("/")
         self._default_model = default_model
-        self._client = httpx.AsyncClient(timeout=60.0)
+        limits = httpx.Limits(
+            max_connections=max_connections,
+            max_keepalive_connections=max_keepalive_connections,
+            keepalive_expiry=keepalive_expiry,
+        )
+        self._client = httpx.AsyncClient(timeout=60.0, limits=limits)
         self._retry_policy = RetryPolicy(retry_config) if retry_config else None
         self._circuit_breaker = (
             CircuitBreaker(circuit_breaker_config, provider="openai")
diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py
index 297d4b9..0c09fae 100644
--- a/src/agentkit/server/app.py
+++ b/src/agentkit/server/app.py
@@ -31,7 +31,7 @@ from agentkit.telemetry.setup import setup_telemetry
 logger = logging.getLogger(__name__)
 
 _ALLOWED_ENV_PREFIXES = (
-    'AGENTKIT_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
+    'AGENTKIT_', 'DASHSCOPE_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
     'TAVILY_', 'SERPER_', 'DEEPSEEK_',
 )
 _ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'}
@@ -52,6 +52,9 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
                     max_tokens=pconf.max_tokens,
                     base_url=pconf.base_url or "https://api.anthropic.com",
                     timeout=pconf.timeout,
+                    max_connections=pconf.max_connections,
+                    max_keepalive_connections=pconf.max_keepalive_connections,
+                    keepalive_expiry=pconf.keepalive_expiry,
                 )
             elif pconf.type == "gemini":
                 provider = GeminiProvider(
@@ -60,11 +63,17 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
                     max_output_tokens=pconf.max_tokens,
                     base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
                     timeout=pconf.timeout,
+                    max_connections=pconf.max_connections,
+                    max_keepalive_connections=pconf.max_keepalive_connections,
+                    keepalive_expiry=pconf.keepalive_expiry,
                 )
             else:
                 provider = OpenAICompatibleProvider(
                     api_key=pconf.api_key,
                     base_url=pconf.base_url,
+
max_connections=pconf.max_connections, + max_keepalive_connections=pconf.max_keepalive_connections, + keepalive_expiry=pconf.keepalive_expiry, ) gateway.register_provider(name, provider) except Exception as e: @@ -146,7 +155,7 @@ async def lifespan(app: FastAPI): "serper_api_key": os.environ.get("SERPER_API_KEY"), } agent._tool_registry.register(MemoryTool(memory_store=memory_store)) - agent._tool_registry.register(ShellTool(working_dir=os.getcwd())) + agent._tool_registry.register(ShellTool()) agent._tool_registry.register(BaiduSearchTool()) agent._tool_registry.register(WebSearchTool(**search_api_keys)) agent._tool_registry.register(WebCrawlTool()) @@ -472,6 +481,7 @@ def create_app( llm_gateway=app.state.llm_gateway, org_context=org_context, auction_enabled=auction_enabled, + classifier=server_config.router.get("classifier", "heuristic") if server_config and server_config.router else "heuristic", ) app.state.cost_aware_router = cost_aware_router # Initialize task store from config diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index d2098a6..6025148 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -110,6 +110,7 @@ class ServerConfig: bus: dict[str, Any] | None = None, marketplace: dict[str, Any] | None = None, alignment: dict[str, Any] | None = None, + router: dict[str, Any] | None = None, on_change: Callable[["ServerConfig"], None] | None = None, ): self.host = host @@ -132,6 +133,7 @@ class ServerConfig: self.bus = bus or {} self.marketplace = marketplace or {} self.alignment = alignment or {} + self.router = router or {} self.on_change = on_change # Config watching state @@ -196,6 +198,9 @@ class ServerConfig: # Alignment config alignment_data = data.get("alignment", {}) + # Router config + router_data = data.get("router", {}) + return cls( host=server.get("host", "0.0.0.0"), port=server.get("port", 8001), @@ -217,6 +222,7 @@ class ServerConfig: bus=server.get("bus"), marketplace=marketplace_data, 
alignment=alignment_data, + router=router_data, ) @staticmethod @@ -411,6 +417,7 @@ class ServerConfig: self.session = new_config.session self.marketplace = new_config.marketplace self.alignment = new_config.alignment + self.router = new_config.router self._last_mtime = new_config._last_mtime logger.info(f"Config reloaded from {path}") diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index 3b44717..e976f80 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -279,8 +279,9 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: await websocket.close(code=1000, reason="Session closed") return - # Track pending replies for AskHumanTool + # Track pending replies for AskHumanTool and confirmations pending_replies: dict[str, asyncio.Future] = {} + pending_confirmations: dict[str, asyncio.Future] = {} chat_manager.add(session_id, websocket, pending_replies) cancellation_token = CancellationToken() @@ -308,7 +309,7 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: # Create a fresh CancellationToken for each message message_token = CancellationToken() await _handle_chat_message( - websocket, session_id, content, sm, message_token, pending_replies + websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations ) elif msg_type == "reply": @@ -318,6 +319,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: if request_id and request_id in pending_replies: pending_replies[request_id].set_result(reply_content) + elif msg_type == "confirmation_reply": + # Reply to confirmation request + confirmation_id = msg.get("confirmation_id") + approved = msg.get("approved", False) + if confirmation_id and confirmation_id in pending_confirmations: + pending_confirmations[confirmation_id].set_result(approved) + elif msg_type == "cancel": cancellation_token.cancel() await websocket.send_json({"type": "result", "data": 
{"status": "cancelled"}}) @@ -338,6 +346,9 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None: for fut in pending_replies.values(): if not fut.done(): fut.cancel() + for fut in pending_confirmations.values(): + if not fut.done(): + fut.cancel() chat_manager.remove(session_id, websocket) @@ -348,6 +359,7 @@ async def _handle_chat_message( sm: SessionManager, cancellation_token: CancellationToken, pending_replies: dict[str, asyncio.Future], + pending_confirmations: dict[str, asyncio.Future] | None = None, ) -> None: """Handle a user message: append to session, execute Agent, stream events. @@ -414,6 +426,35 @@ async def _handle_chat_message( # Execute Agent with streaming react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) + # Create confirmation handler that sends request to frontend and waits for reply + _pending_confirmations = pending_confirmations if pending_confirmations is not None else {} + + async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool: + """Send confirmation request to frontend via WebSocket and wait for user reply.""" + # Send confirmation request to frontend + await websocket.send_json({ + "type": "confirmation_request", + "data": { + "confirmation_id": confirmation_id, + "command": command, + "reason": reason, + }, + }) + + # Create a Future and wait for the user's reply + loop = asyncio.get_running_loop() + future: asyncio.Future[bool] = loop.create_future() + _pending_confirmations[confirmation_id] = future + + try: + # Wait up to 5 minutes for user confirmation + return await asyncio.wait_for(future, timeout=300.0) + except asyncio.TimeoutError: + logger.warning(f"Confirmation request {confirmation_id} timed out") + return False + finally: + _pending_confirmations.pop(confirmation_id, None) + logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}") try: @@ -425,6 +466,7 @@ async def _handle_chat_message(
agent_name=routing.agent_name, system_prompt=routing.system_prompt, cancellation_token=cancellation_token, + confirmation_handler=_confirmation_handler, ): if event.event_type == "final_answer": final_content = event.data.get("output", "") @@ -432,6 +474,14 @@ async def _handle_chat_message( "type": "final_answer", "content": final_content, }) + elif event.event_type == "confirmation_request": + # Already handled by confirmation_handler, just notify frontend + pass + elif event.event_type == "confirmation_result": + await websocket.send_json({ + "type": "confirmation_result", + "data": event.data, + }) elif event.event_type == "token": await websocket.send_json({ "type": "token", diff --git a/src/agentkit/session/manager.py b/src/agentkit/session/manager.py index 207a3a7..7e1b934 100644 --- a/src/agentkit/session/manager.py +++ b/src/agentkit/session/manager.py @@ -2,7 +2,9 @@ from __future__ import annotations +import asyncio import logging +from collections import defaultdict from typing import Any from agentkit.session.models import Message, MessageRole, Session, SessionStatus @@ -11,15 +13,125 @@ from agentkit.session.store import InMemorySessionStore, SessionStore logger = logging.getLogger(__name__) +class AsyncWriteQueue: + """Background write-ahead queue for non-blocking session persistence. + + Accepts write operations (append_message + save_session) as tasks, + executes them in a background ``asyncio.Task``, and maintains a small + in-memory WAL buffer for crash recovery and immediate reads. 
+ """ + + def __init__(self, store: SessionStore, max_buffer_size: int = 256) -> None: + self._store = store + self._queue: asyncio.Queue[tuple[Message, Session] | None] | None = None + self._worker: asyncio.Task | None = None + # WAL buffer: session_id -> list of Messages not yet persisted + self._wal_buffer: dict[str, list[Message]] = defaultdict(list) + self._max_buffer_size = max_buffer_size + self._pending_count = 0 + + def _ensure_started(self) -> None: + """Start the background writer task if not already running (lazy init).""" + if self._worker is not None and not self._worker.done(): + return + self._queue = asyncio.Queue() + self._worker = asyncio.create_task(self._writer_loop()) + + async def _writer_loop(self) -> None: + """Consume write tasks from the queue and persist them.""" + assert self._queue is not None + while True: + item = await self._queue.get() + if item is None: + # Sentinel: graceful shutdown signal + self._queue.task_done() + break + message, session = item + try: + await self._store.append_message(message) + session.updated_at = __import__("datetime").datetime.now( + __import__("datetime").timezone.utc + ) + await self._store.save_session(session) + except Exception: + logger.exception( + "AsyncWriteQueue: failed to persist message %s for session %s", + message.message_id, + message.session_id, + ) + finally: + # Remove from WAL buffer once persisted + buf = self._wal_buffer.get(message.session_id) + if buf is not None: + try: + buf.remove(message) + except ValueError: + pass + if not buf: + self._wal_buffer.pop(message.session_id, None) + self._pending_count -= 1 + self._queue.task_done() + + def enqueue(self, message: Message, session: Session) -> None: + """Enqueue a write task without blocking. + + The message is immediately added to the WAL buffer so that + ``get_chat_messages`` can see it before persistence completes. + Lazily starts the background writer on first call. 
+ """ + self._ensure_started() + assert self._queue is not None + self._wal_buffer[message.session_id].append(message) + self._pending_count += 1 + self._queue.put_nowait((message, session)) + + def buffered_messages(self, session_id: str) -> list[Message]: + """Return WAL-buffered messages for *session_id* not yet persisted.""" + return list(self._wal_buffer.get(session_id, [])) + + @property + def pending_count(self) -> int: + """Number of write tasks waiting in the queue.""" + return self._pending_count + + async def flush(self) -> None: + """Wait until all queued writes have been persisted.""" + if self._queue is not None: + await self._queue.join() + + async def stop(self) -> None: + """Signal the writer to stop and wait for it to finish.""" + if self._queue is not None and self._worker is not None: + await self._queue.put(None) # sentinel + await self._worker + self._worker = None + self._queue = None + + class SessionManager: """Manages conversation sessions and their messages. Provides a high-level API for creating, querying, and updating sessions, as well as appending and retrieving messages. + + When ``async_writes=True``, ``append_message`` is non-blocking: the + message is placed in a write-ahead buffer and persisted in the + background. Call ``flush()`` or ``close()`` to ensure all writes + are durably persisted before shutdown. 
""" - def __init__(self, store: SessionStore | None = None): + def __init__( + self, + store: SessionStore | None = None, + *, + async_writes: bool = False, + wal_buffer_size: int = 256, + ): self._store = store or InMemorySessionStore() + self._async_writes = async_writes + self._write_queue: AsyncWriteQueue | None = None + if async_writes: + self._write_queue = AsyncWriteQueue(self._store, max_buffer_size=wal_buffer_size) @property def store(self) -> SessionStore: @@ -61,7 +173,13 @@ class SessionManager: return await self._store.update_session_status(session_id, SessionStatus.ACTIVE) async def close_session(self, session_id: str) -> Session | None: - """Close a session. Closed sessions cannot accept new messages.""" + """Close a session. Closed sessions cannot accept new messages. + + When async writes are enabled, flushes pending writes before + updating the session status. + """ + if self._write_queue is not None: + await self._write_queue.flush() return await self._store.update_session_status(session_id, SessionStatus.CLOSED) async def delete_session(self, session_id: str) -> bool: @@ -116,11 +234,15 @@ class SessionManager: agent_name=agent_name, metadata=metadata or {}, ) - await self._store.append_message(message) - # Update session's updated_at timestamp - session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc) - await self._store.save_session(session) + if self._write_queue is not None: + # Non-blocking: enqueue write, return immediately + self._write_queue.enqueue(message, session) + else: + # Synchronous path (default, backward-compatible) + await self._store.append_message(message) + session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc) + await self._store.save_session(session) return message @@ -132,6 +254,9 @@ class SessionManager: ) -> list[Message]: """Get messages for a session with optional pagination. 
+ When async writes are enabled, includes WAL-buffered messages + not yet persisted to the store. + Args: session_id: Target session ID. limit: Maximum number of messages to return. None for all. @@ -140,15 +265,30 @@ class SessionManager: Returns: List of messages ordered chronologically. """ - return await self._store.get_messages(session_id, limit=limit, offset=offset) + persisted = await self._store.get_messages(session_id, limit=None, offset=0) + if self._write_queue is not None: + buffered = self._write_queue.buffered_messages(session_id) + if buffered: + # Merge: persisted + buffered (dedup by message_id) + seen = {m.message_id for m in persisted} + for m in buffered: + if m.message_id not in seen: + persisted.append(m) + sliced = persisted[offset:] + if limit is not None: + sliced = sliced[:limit] + return sliced async def get_chat_messages(self, session_id: str) -> list[dict[str, str]]: """Get messages formatted for LLM chat API consumption. Returns messages as OpenAI-compatible dicts suitable for passing directly to the ReAct engine or LLM Gateway. + + When async writes are enabled, includes WAL-buffered messages + not yet persisted to the store. """ - messages = await self._store.get_messages(session_id) + messages = await self.get_messages(session_id) return [m.to_chat_message() for m in messages] async def count_messages(self, session_id: str) -> int: @@ -158,3 +298,21 @@ class SessionManager: async def health_check(self) -> bool: """Check if the underlying store is healthy.""" return await self._store.health_check() + + async def flush(self) -> None: + """Wait for all pending async writes to be persisted. + + No-op when async writes are not enabled. + """ + if self._write_queue is not None: + await self._write_queue.flush() + + async def close(self) -> None: + """Flush pending writes and stop the background writer. + + Should be called during application shutdown when async writes + are enabled. Safe to call even when async writes are disabled. 
+ """ + if self._write_queue is not None: + await self._write_queue.flush() + await self._write_queue.stop() diff --git a/tests/unit/test_cost_aware_router.py b/tests/unit/test_cost_aware_router.py index f78d502..074c475 100644 --- a/tests/unit/test_cost_aware_router.py +++ b/tests/unit/test_cost_aware_router.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult, _tokenize_content +from agentkit.chat.skill_routing import CostAwareRouter, HeuristicClassifier, SkillRoutingResult, _tokenize_content from agentkit.llm.protocol import LLMResponse, TokenUsage from agentkit.router.intent import IntentRouter, RoutingResult from agentkit.skills.base import IntentConfig, Skill, SkillConfig @@ -510,3 +510,122 @@ class TestTokenizeContent: expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"] for bigram in expected_bigrams: assert bigram in tokens, f"缺少 2-gram: {bigram}" + + +# --------------------------------------------------------------------------- +# HeuristicClassifier +# --------------------------------------------------------------------------- + + +class TestHeuristicClassifier: + """Tests for the local HeuristicClassifier""" + + def setup_method(self): + self.classifier = HeuristicClassifier() + + def test_short_greeting_low_complexity(self): + """Short greeting → low complexity""" + score = self.classifier.classify("你好呀") + assert score < 0.3 + + def test_simple_question_medium_complexity(self): + """Simple question containing '如何' (how to) → medium complexity""" + score = self.classifier.classify("如何使用这个功能?") + assert 0.3 <= score <= 0.7 + + def test_tool_request_high_complexity(self): + """Request containing tool keywords → high complexity""" + score = self.classifier.classify("帮我搜索一下最新的新闻") + assert score > 0.5 + + def test_code_request_high_complexity(self): + """Code-related request → high complexity""" + score = self.classifier.classify("写一个Python函数实现快速排序") + assert score > 0.6 + + def test_multi_step_request_high_complexity(self): + """Multi-step analysis request → high complexity""" + score =
self.classifier.classify("分析这个数据,比较不同方案的优缺点,然后给出推荐") + assert score > 0.7 + + def test_empty_string_zero_complexity(self): + """Empty string → zero complexity""" + assert self.classifier.classify("") == 0.0 + assert self.classifier.classify(" ") == 0.0 + + def test_long_message_higher_complexity(self): + """Longer message → higher complexity""" + short = "帮我查一下" + long = "帮我查一下" + "关于机器学习和深度学习的最新进展" * 10 + assert self.classifier.classify(long) > self.classifier.classify(short) + + def test_code_patterns_boost_complexity(self): + """Code patterns (backticks/parentheses) raise complexity""" + with_code = "运行这段代码 `print('hello')`" + without_code = "运行这段代码 print hello" + assert self.classifier.classify(with_code) > self.classifier.classify(without_code) + + def test_score_bounded_0_to_1(self): + """Complexity score always stays within [0.0, 1.0]""" + test_inputs = [ + "", "你好", "如何做", "帮我搜索并分析数据,设计一个完整的解决方案,包含代码实现和部署配置", + ] + for inp in test_inputs: + score = self.classifier.classify(inp) + assert 0.0 <= score <= 1.0, f"Score {score} out of range for '{inp}'" + + +class TestHeuristicClassifierIntegration: + """Integration tests for HeuristicClassifier inside CostAwareRouter""" + + @pytest.mark.asyncio + async def test_heuristic_mode_no_llm_call(self): + """heuristic mode makes no LLM call""" + gateway = _make_llm_gateway(json.dumps({"complexity": 0.5})) + router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic") + result = await router.route( + content="帮我分析一下数据", + skill_registry=_make_skill_registry(), + intent_router=_make_intent_router(), + default_tools=[], + default_system_prompt="You are helpful.", + ) + # LLM gateway.chat must not be called + gateway.chat.assert_not_called() + # Complexity should come from the heuristic classifier + assert result.complexity > 0.0 + + @pytest.mark.asyncio + async def test_llm_mode_uses_llm(self): + """llm mode calls the LLM quick_classify""" + gateway = _make_llm_gateway(json.dumps({"complexity": 0.5})) + router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="llm") + result = await router.route( + content="帮我分析一下数据", + skill_registry=_make_skill_registry(), +
intent_router=_make_intent_router(), + default_tools=[], + default_system_prompt="You are helpful.", + ) + # LLM gateway.chat should be called + gateway.chat.assert_called() + + @pytest.mark.asyncio + async def test_heuristic_greeting_still_layer0(self): + """In heuristic mode, greetings still take Layer 0""" + router = CostAwareRouter(classifier="heuristic") + result = await router.route( + content="你好", + skill_registry=_make_skill_registry(), + intent_router=_make_intent_router(), + default_tools=[], + default_system_prompt="You are helpful.", + ) + assert result.match_method == "greeting" + assert result.complexity == 0.0 + + @pytest.mark.asyncio + async def test_heuristic_default_classifier_mode(self): + """The default classifier mode is heuristic""" + router = CostAwareRouter() + assert router._classifier == "heuristic" diff --git a/tests/unit/test_session_manager.py b/tests/unit/test_session_manager.py index d3195a6..4143864 100644 --- a/tests/unit/test_session_manager.py +++ b/tests/unit/test_session_manager.py @@ -12,6 +12,11 @@ def manager(): return SessionManager(store=InMemorySessionStore()) + +@pytest.fixture +def async_manager(): + return SessionManager(store=InMemorySessionStore(), async_writes=True) + + class TestSessionManagerCreate: @pytest.mark.asyncio async def test_create_session(self, manager): @@ -197,3 +202,137 @@ class TestSessionManagerHealth: @pytest.mark.asyncio async def test_health_check(self, manager): assert await manager.health_check() is True + + +class TestAsyncWrites: + """Tests for async (non-blocking) write behaviour.""" + + @pytest.mark.asyncio + async def test_append_message_returns_immediately(self, async_manager): + """append_message returns the Message before it is persisted.""" + session = await async_manager.create_session(agent_name="agent1") + msg = await async_manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content="Hello", + ) + # Message is returned immediately + assert msg.role == MessageRole.USER + assert msg.content == "Hello" + #
Give the background writer a moment, then verify persistence + await async_manager.flush() + persisted = await async_manager.store.get_messages(session.session_id) + assert len(persisted) == 1 + assert persisted[0].content == "Hello" + await async_manager.close() + + @pytest.mark.asyncio + async def test_get_chat_messages_includes_wal_buffered(self, async_manager): + """get_chat_messages returns WAL-buffered messages not yet persisted.""" + session = await async_manager.create_session(agent_name="agent1") + # Append a message — it may still be in the WAL buffer + await async_manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content="Buffered", + ) + # get_chat_messages should include WAL-buffered messages + chat_msgs = await async_manager.get_chat_messages(session.session_id) + assert len(chat_msgs) >= 1 + assert any(m["content"] == "Buffered" for m in chat_msgs) + await async_manager.close() + + @pytest.mark.asyncio + async def test_flush_ensures_all_pending_writes(self, async_manager): + """flush() waits until all queued writes are persisted.""" + session = await async_manager.create_session(agent_name="agent1") + for i in range(5): + await async_manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content=f"Msg {i}", + ) + await async_manager.flush() + persisted = await async_manager.store.get_messages(session.session_id) + assert len(persisted) == 5 + await async_manager.close() + + @pytest.mark.asyncio + async def test_rapid_appends_are_batched(self, async_manager): + """Multiple rapid append_messages are all persisted correctly.""" + session = await async_manager.create_session(agent_name="agent1") + # Fire off many messages rapidly + messages = [] + for i in range(20): + msg = await async_manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content=f"Rapid {i}", + ) + messages.append(msg) + await async_manager.flush() + persisted = await 
async_manager.store.get_messages(session.session_id) + assert len(persisted) == 20 + contents = [m.content for m in persisted] + for i in range(20): + assert f"Rapid {i}" in contents + await async_manager.close() + + @pytest.mark.asyncio + async def test_session_close_flushes_pending_writes(self, async_manager): + """Closing a session flushes pending writes first.""" + session = await async_manager.create_session(agent_name="agent1") + await async_manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content="Before close", + ) + closed = await async_manager.close_session(session.session_id) + assert closed.status == SessionStatus.CLOSED + # Message should be persisted because close_session flushes + persisted = await async_manager.store.get_messages(session.session_id) + assert len(persisted) == 1 + assert persisted[0].content == "Before close" + await async_manager.close() + + @pytest.mark.asyncio + async def test_manager_close_stops_writer(self, async_manager): + """close() flushes and stops the background writer.""" + session = await async_manager.create_session(agent_name="agent1") + await async_manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content="Final", + ) + await async_manager.close() + # After close, the write queue should be stopped + assert async_manager._write_queue is None or async_manager._write_queue._worker is None + + @pytest.mark.asyncio + async def test_async_writes_disabled_by_default(self): + """Without async_writes=True, writes are synchronous.""" + mgr = SessionManager(store=InMemorySessionStore()) + assert mgr._write_queue is None + session = await mgr.create_session(agent_name="agent1") + await mgr.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content="Sync", + ) + # Should be immediately persisted (no flush needed) + persisted = await mgr.store.get_messages(session.session_id) + assert len(persisted) == 1 + + @pytest.mark.asyncio + async 
def test_get_messages_includes_wal_buffered(self, async_manager): + """get_messages returns WAL-buffered messages not yet persisted.""" + session = await async_manager.create_session(agent_name="agent1") + await async_manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content="WAL msg", + ) + messages = await async_manager.get_messages(session.session_id) + assert len(messages) >= 1 + assert any(m.content == "WAL msg" for m in messages) + await async_manager.close()