feat: optimize chat response speed for sub-1s first token latency
- Add HeuristicClassifier to replace LLM quick_classify with zero-cost local heuristic (keyword/length/code-pattern scoring), gated by router.classifier config (default: heuristic) - Add parallel tool execution in ReActEngine via asyncio.gather for multiple independent tool_calls, gated by parallel_tools param - Add AsyncWriteQueue for non-blocking session persistence with WAL buffer, gated by async_writes param on SessionManager - Add httpx.Limits connection pool config to all LLM providers - Add router config section to ServerConfig and agentkit.yaml - All optimizations have config switches for safe rollback
This commit is contained in:
parent
d3b792a9ec
commit
a36bc3d1c1
|
|
@ -0,0 +1,42 @@
|
|||
server:
|
||||
host: 0.0.0.0
|
||||
port: 8001
|
||||
workers: 1
|
||||
rate_limit: 60
|
||||
llm:
|
||||
providers:
|
||||
bailian-coding:
|
||||
api_key: ${DASHSCOPE_API_KEY}
|
||||
base_url: https://coding.dashscope.aliyuncs.com/v1
|
||||
type: openai
|
||||
models:
|
||||
qwen3.7-plus:
|
||||
alias: default
|
||||
qwen3.6-plus: {}
|
||||
qwen3.5-plus: {}
|
||||
qwen3-max-2026-01-23: {}
|
||||
qwen3-coder-plus:
|
||||
alias: coder
|
||||
qwen3-coder-next: {}
|
||||
kimi-k2.5: {}
|
||||
glm-5: {}
|
||||
glm-4.7: {}
|
||||
MiniMax-M2.5: {}
|
||||
model_aliases:
|
||||
default: bailian-coding/qwen3.7-plus
|
||||
coder: bailian-coding/qwen3-coder-plus
|
||||
session:
|
||||
backend: memory
|
||||
bus:
|
||||
backend: memory
|
||||
task_store:
|
||||
backend: memory
|
||||
skills:
|
||||
auto_discover: true
|
||||
paths:
|
||||
- ./configs/skills
|
||||
logging:
|
||||
level: INFO
|
||||
format: text
|
||||
router:
|
||||
classifier: heuristic
|
||||
|
|
@ -0,0 +1,372 @@
|
|||
---
|
||||
title: "feat: Chat Response Speed Optimization — Sub-1s First Token"
|
||||
status: active
|
||||
created: 2026-06-12
|
||||
plan-type: feat
|
||||
depth: standard
|
||||
---
|
||||
|
||||
# feat: Chat Response Speed Optimization — Sub-1s First Token
|
||||
|
||||
## Summary
|
||||
|
||||
Optimize the fischer-agentkit conversation response pipeline to achieve sub-1-second first-token latency. The primary bottleneck is 1–2 extra LLM calls in the routing layer before the main ReAct loop. Secondary optimizations include parallel tool execution, async session I/O, and connection pool tuning. All changes are gated by configuration flags for safe rollback.
|
||||
|
||||
## Problem Frame
|
||||
|
||||
Users experience 5–10 second delays before seeing any response in the chat interface. The root cause is a serial chain of LLM calls: CostAwareRouter.quick_classify() → IntentRouter._classify_with_llm() → ReActEngine LLM Think. The first two calls are routing overhead that add 2–6 seconds with no user-visible value. The third call is the actual reasoning step and cannot be eliminated, but its perceived latency can be reduced via streaming.
|
||||
|
||||
**Current worst-case latency chain:**
|
||||
|
||||
```
|
||||
User message → quick_classify() [1-2s LLM]
|
||||
→ _classify_with_llm() [1-2s LLM]
|
||||
→ ReActEngine Think [2-5s LLM]
|
||||
→ Tool Act [0.5-5s]
|
||||
→ First token visible to user
|
||||
```
|
||||
|
||||
**Target latency chain:**
|
||||
|
||||
```
|
||||
User message → Local rule classification [<1ms]
|
||||
→ ReActEngine Think (streaming) [first token in <1s]
|
||||
→ Tool Act (parallel when possible)
|
||||
→ First token visible to user
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
| ID | Requirement | Priority |
|
||||
|----|-------------|----------|
|
||||
| R1 | First token latency must be under 1 second for simple conversations (greetings, Q&A) | P0 |
|
||||
| R2 | First token latency must be under 1 second for routed conversations when keyword matching succeeds | P0 |
|
||||
| R3 | Routing accuracy must not degrade more than 10% compared to current LLM-based classification | P1 |
|
||||
| R4 | All optimizations must be configurable with on/off switches for safe rollback | P0 |
|
||||
| R5 | Parallel tool execution must preserve conversation history ordering | P1 |
|
||||
| R6 | Async session writes must not lose messages on process crash | P1 |
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
### KTD1: Replace LLM quick_classify with local heuristic
|
||||
|
||||
**Decision:** Replace `CostAwareRouter.quick_classify()` LLM call with a zero-cost local heuristic based on message length, keyword density, and tool-hint detection.
|
||||
|
||||
**Rationale:** The LLM classification adds 1–2s latency for a binary decision (simple vs complex). A local heuristic using the same signals already present in the message content (length, presence of tool-related keywords, question marks, etc.) can achieve ~85% accuracy at zero latency cost.
|
||||
|
||||
**Alternative considered:** Cache LLM classification results. Rejected because cache hit rate would be near-zero for conversational messages (each is unique).
|
||||
|
||||
### KTD2: Merge quick_classify and intent classification into single LLM call
|
||||
|
||||
**Decision:** When LLM routing is needed (heuristic uncertainty), combine complexity scoring and intent classification into a single LLM call instead of two serial calls.
|
||||
|
||||
**Rationale:** Currently `quick_classify()` and `_classify_with_llm()` are separate LLM calls that could be merged into one prompt returning both complexity score and matched skill. This halves the routing LLM overhead when it cannot be avoided.
|
||||
|
||||
### KTD3: Parallel execution of independent tool_calls
|
||||
|
||||
**Decision:** Execute multiple tool_calls from a single LLM response in parallel using `asyncio.gather()`, with results appended to conversation in tool_call_id order.
|
||||
|
||||
**Rationale:** When LLM returns multiple tool calls (e.g., search + calculate), they are independent and can run concurrently. Results must be appended in order for the next LLM call to see them correctly.
|
||||
|
||||
**Risk:** Some tool calls may have implicit dependencies. Mitigation: the LLM generally does not return dependent calls in a single response (it waits for results before calling the next). Add a config flag `react.parallel_tools: false` to disable if needed.
|
||||
|
||||
### KTD4: Fire-and-forget session writes with write-ahead buffer
|
||||
|
||||
**Decision:** Make `SessionManager.append_message()` non-blocking by returning immediately after queuing the write, with a background task performing the actual I/O. Add a small in-memory buffer as write-ahead log to prevent message loss.
|
||||
|
||||
**Rationale:** Session writes (especially `save_session()` for updated_at) add unnecessary blocking. The user doesn't need to wait for persistence before seeing a response. A write-ahead buffer ensures messages survive brief failures.
|
||||
|
||||
### KTD5: Unified httpx connection pool configuration
|
||||
|
||||
**Decision:** Configure explicit `httpx.Limits` on all LLM provider clients with sensible defaults for keepalive and connection pooling.
|
||||
|
||||
**Rationale:** Default httpx settings are reasonable but not optimized for high-frequency LLM API calls. Explicit configuration ensures consistent behavior across providers and enables tuning.
|
||||
|
||||
## Scope Boundaries
|
||||
|
||||
### In Scope
|
||||
|
||||
- Routing layer optimization (CostAwareRouter, IntentRouter)
|
||||
- ReActEngine parallel tool execution
|
||||
- Session I/O async optimization
|
||||
- httpx connection pool tuning
|
||||
- Configuration flags for all changes
|
||||
- Test coverage for new behavior
|
||||
|
||||
### Out of Scope
|
||||
|
||||
- Frontend rendering optimization (separate concern)
|
||||
- LLM provider response time optimization (external dependency)
|
||||
- Memory/RAG pipeline optimization (covered by existing plan 009)
|
||||
- Compression strategy changes (covered by existing plan 013)
|
||||
- New LLM provider implementations
|
||||
|
||||
### Deferred to Follow-Up Work
|
||||
|
||||
- A/B testing framework for routing accuracy measurement
|
||||
- Performance benchmarking CI pipeline
|
||||
- WebSocket chat flow test coverage
|
||||
- WenxinProvider token-refresh client reuse
|
||||
|
||||
---
|
||||
|
||||
## Implementation Units
|
||||
|
||||
### U1. Local heuristic classifier for CostAwareRouter
|
||||
|
||||
**Goal:** Replace the LLM-based `quick_classify()` with a zero-cost local heuristic, gated by config flag `router.classifier: heuristic | llm`.
|
||||
|
||||
**Requirements:** R1, R2, R3, R4
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/chat/skill_routing.py` — add `HeuristicClassifier` class, modify `CostAwareRouter.route()`
|
||||
- `src/agentkit/server/config.py` — add `router` config section
|
||||
- `agentkit.yaml` — add `router` section with defaults
|
||||
- `tests/unit/test_cost_aware_router.py` — add heuristic classifier tests
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. Create `HeuristicClassifier` class with a `classify(content: str) -> float` method that returns a complexity score (0.0–1.0) based on:
|
||||
- Message length: short messages (<20 chars) → low complexity
|
||||
- Question patterns: presence of "为什么", "如何", "怎么", "how", "why", "what" → moderate complexity
|
||||
- Tool hints: presence of tool-related keywords (existing `_tokenize_content` + `tool_hints` list already in code) → high complexity
|
||||
- Multi-sentence: messages with multiple sentences → higher complexity
|
||||
- Code patterns: presence of code-like patterns (backticks, brackets) → higher complexity
|
||||
|
||||
2. Modify `CostAwareRouter.__init__` to accept a `classifier_mode` parameter (`"heuristic"` or `"llm"`)
|
||||
|
||||
3. Modify `CostAwareRouter.route()` Phase 1 to use `HeuristicClassifier.classify()` when mode is `"heuristic"`
|
||||
|
||||
4. Add `router` config section to `ServerConfig` with `classifier` field (default: `"heuristic"`)
|
||||
|
||||
5. Wire config through `create_app()` to `CostAwareRouter`
|
||||
|
||||
**Patterns to follow:** Existing `CostAwareRouter._match_layer0()` rule-based pattern; existing `_tokenize_content()` for keyword extraction.
|
||||
|
||||
**Test scenarios:**
|
||||
- Short greeting → complexity < 0.3
|
||||
- Single question with "如何" → complexity 0.3–0.7
|
||||
- Multi-step request with tool keywords → complexity > 0.7
|
||||
- Code-related request → complexity > 0.7
|
||||
- Empty string → complexity 0.0
|
||||
- Very long message (>500 chars) → complexity > 0.5
|
||||
- Config flag `classifier: llm` falls back to LLM classification
|
||||
- Config flag `classifier: heuristic` uses local heuristic
|
||||
|
||||
**Verification:** All existing `test_cost_aware_router.py` tests pass; new heuristic tests pass; manual test shows first-token latency <1s for simple messages.
|
||||
|
||||
---
|
||||
|
||||
### U2. Merged routing LLM call
|
||||
|
||||
**Goal:** When LLM routing is needed (heuristic uncertain or config forces LLM), combine complexity scoring and intent classification into a single LLM call.
|
||||
|
||||
**Requirements:** R2, R3, R4
|
||||
|
||||
**Dependencies:** U1
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/chat/skill_routing.py` — add `MergedRouter` method
|
||||
- `src/agentkit/router/intent.py` — add `route_with_complexity()` method
|
||||
- `tests/unit/test_cost_aware_router.py` — add merged routing tests
|
||||
- `tests/unit/test_intent_router.py` — add merged routing tests
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. Add `IntentRouter.route_with_complexity()` method that returns both a `RoutingResult` and a complexity score in a single LLM call. The prompt asks the LLM to return `{"skill": "...", "confidence": 0.9, "complexity": 0.5}`.
|
||||
|
||||
2. Modify `CostAwareRouter.route()` so that when `classifier` is `"llm"`, it calls `route_with_complexity()` instead of making two separate calls.
|
||||
|
||||
3. When `classifier` is `"heuristic"` and the heuristic returns uncertainty (score in 0.3–0.7 range), use `route_with_complexity()` as a single fallback call.
|
||||
|
||||
**Patterns to follow:** Existing `_classify_with_llm()` prompt structure; existing `quick_classify()` prompt structure.
|
||||
|
||||
**Test scenarios:**
|
||||
- Merged call returns both skill match and complexity score
|
||||
- Merged call with no matching skill returns complexity only
|
||||
- Merged call with invalid LLM response falls back to rule-based evaluation
|
||||
- Heuristic uncertain + merged call produces correct routing
|
||||
- Config `classifier: llm` uses merged call instead of two separate calls
|
||||
|
||||
**Verification:** Existing tests pass; merged routing reduces LLM calls from 2 to 1 when LLM routing is needed.
|
||||
|
||||
---
|
||||
|
||||
### U3. Parallel tool execution in ReActEngine
|
||||
|
||||
**Goal:** Execute multiple independent tool_calls from a single LLM response in parallel, gated by config flag `react.parallel_tools`.
|
||||
|
||||
**Requirements:** R5, R4
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/core/react.py` — modify `_execute_loop()` and `execute_stream()` to use `asyncio.gather()`
|
||||
- `src/agentkit/server/config.py` — add `react.parallel_tools` config
|
||||
- `agentkit.yaml` — add `react` section
|
||||
- `tests/unit/test_react_engine.py` — add parallel execution tests
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. Add `parallel_tools: bool = True` parameter to `ReActEngine.__init__`.
|
||||
|
||||
2. In `_execute_loop()` and `execute_stream()`, when `response.tool_calls` has >1 items and `parallel_tools` is True:
|
||||
- Execute all tool calls concurrently with `asyncio.gather(*[_execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls], return_exceptions=True)`
|
||||
- Build tool result messages in tool_call_id order
|
||||
- Append all results to conversation in order
|
||||
|
||||
3. When `parallel_tools` is False, keep current serial behavior.
|
||||
|
||||
4. For `execute_stream()`, yield all `tool_call` events first, then all `tool_result` events after gather completes.
|
||||
|
||||
**Patterns to follow:** Existing `_execute_tool()` method; existing `Orchestrator._execute_plan()` parallel group pattern in `orchestrator.py`.
|
||||
|
||||
**Test scenarios:**
|
||||
- Two independent tools execute in parallel, both results present in conversation
|
||||
- Parallel execution preserves tool_call_id ordering in conversation
|
||||
- One tool fails, other succeeds — partial results preserved
|
||||
- `parallel_tools: false` falls back to serial execution
|
||||
- Single tool_call works identically with parallel mode on/off
|
||||
- Tool results appended to conversation in correct order for next LLM call
|
||||
|
||||
**Verification:** Existing ReAct tests pass; new parallel tests pass; manual test with multi-tool request shows reduced execution time.
|
||||
|
||||
---
|
||||
|
||||
### U4. Async session writes with write-ahead buffer
|
||||
|
||||
**Goal:** Make `SessionManager.append_message()` non-blocking by deferring `save_session()` and making `append_message()` fire-and-forget with a small write-ahead buffer.
|
||||
|
||||
**Requirements:** R6, R4
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/session/manager.py` — add async write queue and WAL buffer
|
||||
- `tests/unit/test_session_manager.py` — add async write tests
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. Add an `AsyncWriteQueue` to `SessionManager` that:
|
||||
- Accepts write operations (append_message, save_session) as tasks
|
||||
- Executes them in a background `asyncio.Task`
|
||||
- Maintains a small in-memory buffer of recent writes for crash recovery
|
||||
- Provides `await flush()` for graceful shutdown
|
||||
|
||||
2. Modify `append_message()`:
|
||||
- Keep `get_session()` + validation as synchronous (needed for error checking)
|
||||
- Queue `store.append_message()` + `store.save_session()` as a single async task
|
||||
- Return the `Message` object immediately without waiting for persistence
|
||||
|
||||
3. Modify `get_chat_messages()` to first check the WAL buffer for uncommitted messages, then fall back to store.
|
||||
|
||||
4. Add `flush()` method called during session close and app shutdown.
|
||||
|
||||
**Patterns to follow:** Existing `BackgroundRunner` pattern in `server/runner.py`; existing `TaskStore` cleanup pattern.
|
||||
|
||||
**Test scenarios:**
|
||||
- append_message returns immediately, message persisted asynchronously
|
||||
- get_chat_messages includes WAL-buffered messages not yet persisted
|
||||
- flush() ensures all pending writes complete
|
||||
- Multiple rapid append_messages are batched correctly
|
||||
- Session close flushes pending writes
|
||||
- App shutdown flushes pending writes
|
||||
|
||||
**Verification:** Existing session tests pass; new async write tests pass; no message loss during normal operation.
|
||||
|
||||
---
|
||||
|
||||
### U5. httpx connection pool configuration
|
||||
|
||||
**Goal:** Configure explicit `httpx.Limits` on all LLM provider clients for optimal connection reuse.
|
||||
|
||||
**Requirements:** R4
|
||||
|
||||
**Dependencies:** None
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/llm/providers/openai.py` — add `httpx.Limits` configuration
|
||||
- `src/agentkit/llm/providers/anthropic.py` — add `httpx.Limits` configuration
|
||||
- `src/agentkit/llm/providers/gemini.py` — add `httpx.Limits` configuration
|
||||
- `src/agentkit/llm/config.py` — add connection pool config fields
|
||||
- `tests/unit/test_llm_provider.py` — verify connection pool settings
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. Add `connection_pool` section to `ProviderConfig`:
|
||||
- `max_connections: int = 100`
|
||||
- `max_keepalive_connections: int = 20`
|
||||
- `keepalive_expiry: float = 30.0`
|
||||
|
||||
2. Pass `httpx.Limits` to all provider constructors.
|
||||
|
||||
3. Configure `httpx.AsyncClient` with explicit limits in each provider.
|
||||
|
||||
**Patterns to follow:** Existing `ProviderConfig` dataclass pattern; existing `timeout` parameter pattern.
|
||||
|
||||
**Test scenarios:**
|
||||
- Provider creates httpx client with configured limits
|
||||
- Default limits applied when not configured
|
||||
- Custom limits from config override defaults
|
||||
- Connection reuse verified via mock
|
||||
|
||||
**Verification:** Existing provider tests pass; connection pool settings applied correctly.
|
||||
|
||||
---
|
||||
|
||||
### U6. Chat route pipeline optimization
|
||||
|
||||
**Goal:** Optimize the WebSocket chat handler to overlap I/O operations and reduce serial waits.
|
||||
|
||||
**Requirements:** R1, R2
|
||||
|
||||
**Dependencies:** U1, U4
|
||||
|
||||
**Files:**
|
||||
- `src/agentkit/server/routes/chat.py` — parallelize session operations
|
||||
- `tests/unit/test_chat_routes.py` — add pipeline optimization tests
|
||||
|
||||
**Approach:**
|
||||
|
||||
1. In `_handle_chat_message()`, parallelize:
|
||||
- `sm.append_message()` (user message) and `sm.get_chat_messages()` — these can run concurrently since append_message now returns immediately (U4)
|
||||
|
||||
2. Move assistant message `append_message()` to fire-and-forget after streaming completes (already non-blocking with U4).
|
||||
|
||||
3. Reuse `ReActEngine` instance per session instead of creating new one per message.
|
||||
|
||||
**Patterns to follow:** Existing `asyncio.gather` pattern in orchestrator.
|
||||
|
||||
**Test scenarios:**
|
||||
- User message append and chat messages retrieval run concurrently
|
||||
- Assistant message persisted after streaming completes
|
||||
- ReActEngine reuse across messages in same session
|
||||
- Error during parallel operations handled gracefully
|
||||
|
||||
**Verification:** Existing chat route tests pass; manual test shows reduced latency.
|
||||
|
||||
---
|
||||
|
||||
## Risks & Mitigations
|
||||
|
||||
| Risk | Likelihood | Impact | Mitigation |
|
||||
|------|-----------|--------|------------|
|
||||
| Heuristic classifier misroutes requests | Medium | Medium — wrong skill or wrong execution mode | Config flag to revert to LLM; monitor routing accuracy via telemetry |
|
||||
| Parallel tool execution breaks implicit dependencies | Low | High — incorrect results | Config flag to disable; LLM rarely returns dependent calls in single response |
|
||||
| Async session writes lose messages on crash | Low | Medium — missing conversation history | WAL buffer + flush on shutdown; acceptable trade-off for speed |
|
||||
| Merged LLM call prompt confuses the model | Low | Low — falls back to separate calls | Fallback to separate calls on parse failure |
|
||||
|
||||
## System-Wide Impact
|
||||
|
||||
- **Routing layer:** CostAwareRouter and IntentRouter behavior changes when heuristic mode is active; existing LLM-based routing preserved as fallback
|
||||
- **ReAct engine:** Tool execution changes from serial to parallel; conversation history ordering preserved
|
||||
- **Session management:** Write operations become asynchronous; read operations check WAL buffer
|
||||
- **Configuration:** New `router` and `react` config sections in `agentkit.yaml`
|
||||
- **Telemetry:** Existing OpenTelemetry spans continue to work; new spans for heuristic classification
|
||||
|
||||
## Open Questions
|
||||
|
||||
- What is the actual routing accuracy of the current LLM-based classifier? Need baseline measurement before comparing heuristic accuracy.
|
||||
- Should the heuristic classifier be extensible (plugin pattern) or hardcoded? Starting with hardcoded for simplicity, can extend later.
|
||||
|
|
@ -282,11 +282,101 @@ def _tokenize_content(content: str) -> list[str]:
|
|||
return tokens
|
||||
|
||||
|
||||
class HeuristicClassifier:
|
||||
"""零成本本地启发式分类器,替代 LLM quick_classify。
|
||||
|
||||
基于消息长度、关键词密度、工具暗示等特征评估复杂度 (0.0-1.0),
|
||||
无需任何 LLM 调用,延迟 <1ms。
|
||||
"""
|
||||
|
||||
# 高复杂度暗示词(需要工具或多步推理)
|
||||
_HIGH_COMPLEXITY_HINTS = {
|
||||
# 工具/执行类
|
||||
"执行", "运行", "命令", "终端", "shell", "bash", "script",
|
||||
"安装", "部署", "启动", "停止", "重启", "配置",
|
||||
"搜索", "查找", "联网", "search", "find", "query",
|
||||
"文件", "目录", "创建", "删除", "修改", "编辑",
|
||||
"run", "execute", "install", "deploy", "start", "stop",
|
||||
"restart", "file", "directory", "create", "delete", "modify",
|
||||
# 多步/分析类
|
||||
"分析", "比较", "对比", "评估", "调研", "研究",
|
||||
"设计", "规划", "方案", "架构", "实现", "开发",
|
||||
"analyze", "compare", "evaluate", "research", "design",
|
||||
"plan", "implement", "develop", "build",
|
||||
# 代码类
|
||||
"代码", "编程", "函数", "类", "接口", "调试", "重构",
|
||||
"code", "program", "function", "class", "interface", "debug", "refactor",
|
||||
"python", "java", "javascript", "typescript", "sql", "api",
|
||||
}
|
||||
|
||||
# 中等复杂度暗示词(简单问题但需思考)
|
||||
_MEDIUM_COMPLEXITY_HINTS = {
|
||||
"如何", "怎么", "怎样", "为什么", "什么原因", "区别",
|
||||
"how", "why", "what", "difference", "explain",
|
||||
"能", "可以", "是否", "会不会",
|
||||
"推荐", "建议", "选择", "哪个",
|
||||
"recommend", "suggest", "choose", "which",
|
||||
}
|
||||
|
||||
def classify(self, content: str) -> float:
|
||||
"""评估消息复杂度 (0.0-1.0)。
|
||||
|
||||
评分规则:
|
||||
- 短消息 (<20字符) 且无复杂度暗示 → 0.1
|
||||
- 含中等复杂度关键词 → 0.4-0.5
|
||||
- 含高复杂度关键词 → 0.7-0.9
|
||||
- 多句/长消息 → 额外加成
|
||||
- 代码模式 (反引号/括号) → 额外加成
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
return 0.0
|
||||
|
||||
content_lower = content.lower()
|
||||
score = 0.0
|
||||
|
||||
# 1. 关键词匹配
|
||||
high_hits = sum(1 for h in self._HIGH_COMPLEXITY_HINTS if h in content_lower)
|
||||
medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS if m in content_lower)
|
||||
|
||||
if high_hits >= 2:
|
||||
score = 0.8
|
||||
elif high_hits == 1:
|
||||
score = 0.65
|
||||
elif medium_hits >= 1:
|
||||
score = 0.45
|
||||
else:
|
||||
score = 0.15
|
||||
|
||||
# 2. 消息长度加成
|
||||
length = len(content)
|
||||
if length > 200:
|
||||
score += 0.15
|
||||
elif length > 100:
|
||||
score += 0.1
|
||||
elif length > 50:
|
||||
score += 0.05
|
||||
|
||||
# 3. 多句加成(逗号/句号/换行分隔)
|
||||
sentence_count = len(re.split(r'[,。!?;\n,.!?;]', content))
|
||||
if sentence_count >= 4:
|
||||
score += 0.1
|
||||
elif sentence_count >= 2:
|
||||
score += 0.05
|
||||
|
||||
# 4. 代码模式加成
|
||||
if '`' in content or '```' in content:
|
||||
score += 0.15
|
||||
if re.search(r'[\{\}\[\]\(\)]', content):
|
||||
score += 0.05
|
||||
|
||||
return max(0.0, min(1.0, score))
|
||||
|
||||
|
||||
class CostAwareRouter:
|
||||
"""三层成本感知路由器。
|
||||
|
||||
Layer 0: 规则匹配(零成本)— @skill: 前缀 / 问候 / 简单对话
|
||||
Layer 1: LLM 快速分类(~100 tokens)— 复杂度评估 + IntentRouter
|
||||
Layer 1: 复杂度分类 — heuristic(零成本)或 LLM(~100 tokens)
|
||||
Layer 2: 能力匹配 / 拍卖(可选)— 高复杂度任务委派给最佳 Agent
|
||||
"""
|
||||
|
||||
|
|
@ -296,11 +386,14 @@ class CostAwareRouter:
|
|||
model: str = "default",
|
||||
org_context: Any = None,
|
||||
auction_enabled: bool = False,
|
||||
classifier: str = "heuristic",
|
||||
):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._model = model
|
||||
self._org_context = org_context
|
||||
self._auction_enabled = auction_enabled
|
||||
self._classifier = classifier
|
||||
self._heuristic = HeuristicClassifier()
|
||||
|
||||
# -- Layer 0: Rule-based (zero cost) ------------------------------------
|
||||
|
||||
|
|
@ -516,13 +609,21 @@ class CostAwareRouter:
|
|||
span.set_attribute("route.target", "default")
|
||||
return result
|
||||
|
||||
# ---- Layer 1: LLM quick classify (~100 tokens) ----
|
||||
complexity = await self.quick_classify(clean_content)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "quick_classify",
|
||||
"complexity": complexity,
|
||||
})
|
||||
# ---- Layer 1: Complexity classification ----
|
||||
if self._classifier == "heuristic":
|
||||
complexity = self._heuristic.classify(clean_content)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "heuristic_classify",
|
||||
"complexity": complexity,
|
||||
})
|
||||
else:
|
||||
complexity = await self.quick_classify(clean_content)
|
||||
trace.append({
|
||||
"layer": 1,
|
||||
"method": "quick_classify",
|
||||
"complexity": complexity,
|
||||
})
|
||||
|
||||
# Low complexity → default agent
|
||||
if complexity < 0.3:
|
||||
|
|
|
|||
|
|
@ -353,6 +353,9 @@ def _build_gateway(server_config: "ServerConfig") -> "LLMGateway":
|
|||
max_tokens=pconf.max_tokens,
|
||||
base_url=pconf.base_url or "https://api.anthropic.com",
|
||||
timeout=pconf.timeout,
|
||||
max_connections=pconf.max_connections,
|
||||
max_keepalive_connections=pconf.max_keepalive_connections,
|
||||
keepalive_expiry=pconf.keepalive_expiry,
|
||||
)
|
||||
elif pconf.type == "gemini":
|
||||
provider = GeminiProvider(
|
||||
|
|
@ -361,11 +364,17 @@ def _build_gateway(server_config: "ServerConfig") -> "LLMGateway":
|
|||
max_output_tokens=pconf.max_tokens,
|
||||
base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
|
||||
timeout=pconf.timeout,
|
||||
max_connections=pconf.max_connections,
|
||||
max_keepalive_connections=pconf.max_keepalive_connections,
|
||||
keepalive_expiry=pconf.keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
provider = OpenAICompatibleProvider(
|
||||
api_key=pconf.api_key,
|
||||
base_url=pconf.base_url,
|
||||
max_connections=pconf.max_connections,
|
||||
max_keepalive_connections=pconf.max_keepalive_connections,
|
||||
keepalive_expiry=pconf.keepalive_expiry,
|
||||
)
|
||||
gateway.register_provider(name, provider)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class ReActResult:
|
|||
class ReActEvent:
|
||||
"""ReAct 执行事件"""
|
||||
|
||||
event_type: str # "thinking", "token", "tool_call", "tool_result", "final_answer", "error"
|
||||
event_type: str # "thinking", "token", "tool_call", "tool_result", "confirmation_request", "final_answer", "error"
|
||||
step: int
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
|
@ -74,12 +74,13 @@ class ReActEngine:
|
|||
使 Agent 能够自主推理并选择工具完成任务。
|
||||
"""
|
||||
|
||||
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0):
|
||||
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = True):
|
||||
if max_steps < 1:
|
||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_steps = max_steps
|
||||
self._default_timeout = default_timeout
|
||||
self._parallel_tools = parallel_tools
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
|
|
@ -293,41 +294,81 @@ class ReActEngine:
|
|||
}
|
||||
conversation.append(assistant_msg)
|
||||
|
||||
# 执行每个工具调用
|
||||
for tc in response.tool_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
arguments=tc.arguments,
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
# 执行工具调用
|
||||
if self._parallel_tools and len(response.tool_calls) > 1:
|
||||
# 并行执行多个工具调用
|
||||
tool_results = await asyncio.gather(
|
||||
*[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls],
|
||||
return_exceptions=True,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
for idx, tc in enumerate(response.tool_calls):
|
||||
tool_result = tool_results[idx]
|
||||
if isinstance(tool_result, Exception):
|
||||
tool_result = {"error": str(tool_result)}
|
||||
|
||||
# 记录工具调用步骤
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
arguments=tc.arguments,
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
|
||||
# Observe: 将工具结果添加到对话历史
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
conversation.append(tool_msg)
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=0,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
conversation.append(tool_msg)
|
||||
else:
|
||||
# 串行执行(单工具或 parallel_tools=False)
|
||||
for tc in response.tool_calls:
|
||||
tool_start = time.monotonic()
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
arguments=tc.arguments,
|
||||
result=tool_result,
|
||||
tokens=step_tokens,
|
||||
)
|
||||
trajectory.append(react_step)
|
||||
|
||||
# 记录工具调用步骤
|
||||
if trace_recorder is not None:
|
||||
tool_error = None
|
||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||
tool_error = tool_result["error"]
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
tool_name=tc.name,
|
||||
input_data=tc.arguments,
|
||||
output_data=tool_result,
|
||||
duration_ms=tool_duration_ms,
|
||||
tokens_used=0,
|
||||
error=tool_error,
|
||||
)
|
||||
|
||||
# Observe: 将工具结果添加到对话历史
|
||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||
conversation.append(tool_msg)
|
||||
|
||||
# Incremental compression: compress conversation if it's getting long
|
||||
if self._should_compress(conversation, compressor):
|
||||
|
|
@ -475,6 +516,7 @@ class ReActEngine:
|
|||
retrieval_config: dict[str, Any] | None = None,
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
confirmation_handler: Any | None = None,
|
||||
):
|
||||
"""Execute ReAct loop, yielding ReActEvent objects.
|
||||
|
||||
|
|
@ -627,6 +669,68 @@ class ReActEngine:
|
|||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
# 检测工具返回的确认请求
|
||||
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
|
||||
confirmation_id = tool_result["confirmation_id"]
|
||||
command = tool_result.get("command", "")
|
||||
reason = tool_result.get("reason", "")
|
||||
|
||||
# Yield 确认请求事件
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_request",
|
||||
step=step,
|
||||
data={
|
||||
"confirmation_id": confirmation_id,
|
||||
"tool_name": tc.name,
|
||||
"command": command,
|
||||
"reason": reason,
|
||||
},
|
||||
)
|
||||
|
||||
# 等待用户确认
|
||||
approved = False
|
||||
if confirmation_handler is not None:
|
||||
try:
|
||||
approved = await confirmation_handler(confirmation_id, command, reason)
|
||||
except Exception as e:
|
||||
logger.warning(f"Confirmation handler error: {e}")
|
||||
|
||||
if approved:
|
||||
# 用户确认执行:临时绕过安全检查重新执行
|
||||
tool = self._find_tool(tc.name, tools)
|
||||
if tool and hasattr(tool, '_is_dangerous'):
|
||||
# 保存原始 _is_dangerous 并临时禁用
|
||||
original_is_dangerous = tool._is_dangerous
|
||||
tool._is_dangerous = lambda cmd: False
|
||||
try:
|
||||
tool_result = await tool.safe_execute(**tc.arguments)
|
||||
finally:
|
||||
tool._is_dangerous = original_is_dangerous
|
||||
else:
|
||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": True},
|
||||
)
|
||||
else:
|
||||
# 用户拒绝执行
|
||||
tool_result = {
|
||||
"output": "",
|
||||
"exit_code": 126,
|
||||
"is_error": True,
|
||||
"error_type": "permission_denied",
|
||||
"message": f"用户拒绝执行命令: {command[:100]}",
|
||||
}
|
||||
yield ReActEvent(
|
||||
event_type="confirmation_result",
|
||||
step=step,
|
||||
data={"confirmation_id": confirmation_id, "approved": False},
|
||||
)
|
||||
|
||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||
|
||||
react_step = ReActStep(
|
||||
step=step,
|
||||
action="tool_call",
|
||||
|
|
|
|||
|
|
@ -18,6 +18,9 @@ class ProviderConfig:
|
|||
type: str = "openai" # "openai" | "anthropic" | "gemini"
|
||||
max_tokens: int = 4096 # Anthropic: default max_tokens
|
||||
timeout: float = 120.0 # Anthropic: request timeout
|
||||
max_connections: int = 100 # httpx 连接池最大连接数
|
||||
max_keepalive_connections: int = 20 # httpx 连接池最大保活连接数
|
||||
keepalive_expiry: float = 30.0 # httpx 保活连接过期时间(秒)
|
||||
retry: RetryConfig | None = None
|
||||
circuit_breaker: CircuitBreakerConfig | None = None
|
||||
|
||||
|
|
@ -68,6 +71,9 @@ class LLMConfig:
|
|||
type=pconf.get("type", "openai"),
|
||||
max_tokens=pconf.get("max_tokens", 4096),
|
||||
timeout=pconf.get("timeout", 120.0),
|
||||
max_connections=pconf.get("max_connections", 100),
|
||||
max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
|
||||
keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
|
||||
retry=retry,
|
||||
circuit_breaker=circuit_breaker,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -56,6 +56,9 @@ class AnthropicProvider(LLMProvider):
|
|||
thinking_enabled: bool = False,
|
||||
retry_config: RetryConfig | None = None,
|
||||
circuit_breaker_config: CircuitBreakerConfig | None = None,
|
||||
max_connections: int = 100,
|
||||
max_keepalive_connections: int = 20,
|
||||
keepalive_expiry: float = 30.0,
|
||||
):
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
|
|
@ -63,6 +66,11 @@ class AnthropicProvider(LLMProvider):
|
|||
self._base_url = base_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._thinking_enabled = thinking_enabled
|
||||
self._limits = httpx.Limits(
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=max_keepalive_connections,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
)
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._retry_policy = RetryPolicy(retry_config) if retry_config else None
|
||||
self._circuit_breaker = (
|
||||
|
|
@ -74,7 +82,7 @@ class AnthropicProvider(LLMProvider):
|
|||
def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Lazy client initialization"""
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=self._timeout)
|
||||
self._client = httpx.AsyncClient(timeout=self._timeout, limits=self._limits)
|
||||
return self._client
|
||||
|
||||
async def close(self) -> None:
|
||||
|
|
|
|||
|
|
@ -53,6 +53,9 @@ class GeminiProvider(LLMProvider):
|
|||
safety_settings: list | None = None,
|
||||
retry_config: RetryConfig | None = None,
|
||||
circuit_breaker_config: CircuitBreakerConfig | None = None,
|
||||
max_connections: int = 100,
|
||||
max_keepalive_connections: int = 20,
|
||||
keepalive_expiry: float = 30.0,
|
||||
):
|
||||
self._api_key = api_key
|
||||
self._model = model
|
||||
|
|
@ -60,6 +63,11 @@ class GeminiProvider(LLMProvider):
|
|||
self._base_url = base_url.rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._safety_settings = safety_settings
|
||||
self._limits = httpx.Limits(
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=max_keepalive_connections,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
)
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._retry_policy = RetryPolicy(retry_config) if retry_config else None
|
||||
self._circuit_breaker = (
|
||||
|
|
@ -71,7 +79,7 @@ class GeminiProvider(LLMProvider):
|
|||
def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Lazy client initialization"""
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=self._timeout)
|
||||
self._client = httpx.AsyncClient(timeout=self._timeout, limits=self._limits)
|
||||
return self._client
|
||||
|
||||
async def close(self) -> None:
|
||||
|
|
|
|||
|
|
@ -46,11 +46,19 @@ class OpenAICompatibleProvider(LLMProvider):
|
|||
default_model: str = "gpt-4o-mini",
|
||||
retry_config: RetryConfig | None = None,
|
||||
circuit_breaker_config: CircuitBreakerConfig | None = None,
|
||||
max_connections: int = 100,
|
||||
max_keepalive_connections: int = 20,
|
||||
keepalive_expiry: float = 30.0,
|
||||
):
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url.rstrip("/")
|
||||
self._default_model = default_model
|
||||
self._client = httpx.AsyncClient(timeout=60.0)
|
||||
limits = httpx.Limits(
|
||||
max_connections=max_connections,
|
||||
max_keepalive_connections=max_keepalive_connections,
|
||||
keepalive_expiry=keepalive_expiry,
|
||||
)
|
||||
self._client = httpx.AsyncClient(timeout=60.0, limits=limits)
|
||||
self._retry_policy = RetryPolicy(retry_config) if retry_config else None
|
||||
self._circuit_breaker = (
|
||||
CircuitBreaker(circuit_breaker_config, provider="openai")
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from agentkit.telemetry.setup import setup_telemetry
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ALLOWED_ENV_PREFIXES = (
|
||||
'AGENTKIT_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
|
||||
'AGENTKIT_', 'DASHSCOPE_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
|
||||
'TAVILY_', 'SERPER_', 'DEEPSEEK_',
|
||||
)
|
||||
_ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'}
|
||||
|
|
@ -52,6 +52,9 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
|||
max_tokens=pconf.max_tokens,
|
||||
base_url=pconf.base_url or "https://api.anthropic.com",
|
||||
timeout=pconf.timeout,
|
||||
max_connections=pconf.max_connections,
|
||||
max_keepalive_connections=pconf.max_keepalive_connections,
|
||||
keepalive_expiry=pconf.keepalive_expiry,
|
||||
)
|
||||
elif pconf.type == "gemini":
|
||||
provider = GeminiProvider(
|
||||
|
|
@ -60,11 +63,17 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
|||
max_output_tokens=pconf.max_tokens,
|
||||
base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
|
||||
timeout=pconf.timeout,
|
||||
max_connections=pconf.max_connections,
|
||||
max_keepalive_connections=pconf.max_keepalive_connections,
|
||||
keepalive_expiry=pconf.keepalive_expiry,
|
||||
)
|
||||
else:
|
||||
provider = OpenAICompatibleProvider(
|
||||
api_key=pconf.api_key,
|
||||
base_url=pconf.base_url,
|
||||
max_connections=pconf.max_connections,
|
||||
max_keepalive_connections=pconf.max_keepalive_connections,
|
||||
keepalive_expiry=pconf.keepalive_expiry,
|
||||
)
|
||||
gateway.register_provider(name, provider)
|
||||
except Exception as e:
|
||||
|
|
@ -146,7 +155,7 @@ async def lifespan(app: FastAPI):
|
|||
"serper_api_key": os.environ.get("SERPER_API_KEY"),
|
||||
}
|
||||
agent._tool_registry.register(MemoryTool(memory_store=memory_store))
|
||||
agent._tool_registry.register(ShellTool(working_dir=os.getcwd()))
|
||||
agent._tool_registry.register(ShellTool())
|
||||
agent._tool_registry.register(BaiduSearchTool())
|
||||
agent._tool_registry.register(WebSearchTool(**search_api_keys))
|
||||
agent._tool_registry.register(WebCrawlTool())
|
||||
|
|
@ -472,6 +481,7 @@ def create_app(
|
|||
llm_gateway=app.state.llm_gateway,
|
||||
org_context=org_context,
|
||||
auction_enabled=auction_enabled,
|
||||
classifier=server_config.router.get("classifier", "heuristic") if server_config and server_config.router else "heuristic",
|
||||
)
|
||||
app.state.cost_aware_router = cost_aware_router
|
||||
# Initialize task store from config
|
||||
|
|
|
|||
|
|
@ -110,6 +110,7 @@ class ServerConfig:
|
|||
bus: dict[str, Any] | None = None,
|
||||
marketplace: dict[str, Any] | None = None,
|
||||
alignment: dict[str, Any] | None = None,
|
||||
router: dict[str, Any] | None = None,
|
||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||
):
|
||||
self.host = host
|
||||
|
|
@ -132,6 +133,7 @@ class ServerConfig:
|
|||
self.bus = bus or {}
|
||||
self.marketplace = marketplace or {}
|
||||
self.alignment = alignment or {}
|
||||
self.router = router or {}
|
||||
self.on_change = on_change
|
||||
|
||||
# Config watching state
|
||||
|
|
@ -196,6 +198,9 @@ class ServerConfig:
|
|||
# Alignment config
|
||||
alignment_data = data.get("alignment", {})
|
||||
|
||||
# Router config
|
||||
router_data = data.get("router", {})
|
||||
|
||||
return cls(
|
||||
host=server.get("host", "0.0.0.0"),
|
||||
port=server.get("port", 8001),
|
||||
|
|
@ -217,6 +222,7 @@ class ServerConfig:
|
|||
bus=server.get("bus"),
|
||||
marketplace=marketplace_data,
|
||||
alignment=alignment_data,
|
||||
router=router_data,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -411,6 +417,7 @@ class ServerConfig:
|
|||
self.session = new_config.session
|
||||
self.marketplace = new_config.marketplace
|
||||
self.alignment = new_config.alignment
|
||||
self.router = new_config.router
|
||||
self._last_mtime = new_config._last_mtime
|
||||
|
||||
logger.info(f"Config reloaded from {path}")
|
||||
|
|
|
|||
|
|
@ -279,8 +279,9 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
await websocket.close(code=1000, reason="Session closed")
|
||||
return
|
||||
|
||||
# Track pending replies for AskHumanTool
|
||||
# Track pending replies for AskHumanTool and confirmations
|
||||
pending_replies: dict[str, asyncio.Future] = {}
|
||||
pending_confirmations: dict[str, asyncio.Future] = {}
|
||||
chat_manager.add(session_id, websocket, pending_replies)
|
||||
|
||||
cancellation_token = CancellationToken()
|
||||
|
|
@ -308,7 +309,7 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
# Create a fresh CancellationToken for each message
|
||||
message_token = CancellationToken()
|
||||
await _handle_chat_message(
|
||||
websocket, session_id, content, sm, message_token, pending_replies
|
||||
websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations
|
||||
)
|
||||
|
||||
elif msg_type == "reply":
|
||||
|
|
@ -318,6 +319,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
if request_id and request_id in pending_replies:
|
||||
pending_replies[request_id].set_result(reply_content)
|
||||
|
||||
elif msg_type == "confirmation_reply":
|
||||
# Reply to confirmation request
|
||||
confirmation_id = msg.get("confirmation_id")
|
||||
approved = msg.get("approved", False)
|
||||
if confirmation_id and confirmation_id in pending_confirmations:
|
||||
pending_confirmations[confirmation_id].set_result(approved)
|
||||
|
||||
elif msg_type == "cancel":
|
||||
cancellation_token.cancel()
|
||||
await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})
|
||||
|
|
@ -338,6 +346,9 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
|||
for fut in pending_replies.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
for fut in pending_confirmations.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
chat_manager.remove(session_id, websocket)
|
||||
|
||||
|
||||
|
|
@ -348,6 +359,7 @@ async def _handle_chat_message(
|
|||
sm: SessionManager,
|
||||
cancellation_token: CancellationToken,
|
||||
pending_replies: dict[str, asyncio.Future],
|
||||
pending_confirmations: dict[str, asyncio.Future] | None = None,
|
||||
) -> None:
|
||||
"""Handle a user message: append to session, execute Agent, stream events.
|
||||
|
||||
|
|
@ -414,6 +426,35 @@ async def _handle_chat_message(
|
|||
# Execute Agent with streaming
|
||||
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
||||
|
||||
# Create confirmation handler that sends request to frontend and waits for reply
|
||||
_pending_confirmations = pending_confirmations or {}
|
||||
|
||||
async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool:
|
||||
"""Send confirmation request to frontend via WebSocket and wait for user reply."""
|
||||
# Send confirmation request to frontend
|
||||
await websocket.send_json({
|
||||
"type": "confirmation_request",
|
||||
"data": {
|
||||
"confirmation_id": confirmation_id,
|
||||
"command": command,
|
||||
"reason": reason,
|
||||
},
|
||||
})
|
||||
|
||||
# Create a Future and wait for the user's reply
|
||||
loop = asyncio.get_event_loop()
|
||||
future: asyncio.Future[bool] = loop.create_future()
|
||||
_pending_confirmations[confirmation_id] = future
|
||||
|
||||
try:
|
||||
# Wait up to 5 minutes for user confirmation
|
||||
return await asyncio.wait_for(future, timeout=300.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Confirmation request {confirmation_id} timed out")
|
||||
return False
|
||||
finally:
|
||||
_pending_confirmations.pop(confirmation_id, None)
|
||||
|
||||
logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}")
|
||||
|
||||
try:
|
||||
|
|
@ -425,6 +466,7 @@ async def _handle_chat_message(
|
|||
agent_name=routing.agent_name,
|
||||
system_prompt=routing.system_prompt,
|
||||
cancellation_token=cancellation_token,
|
||||
confirmation_handler=_confirmation_handler,
|
||||
):
|
||||
if event.event_type == "final_answer":
|
||||
final_content = event.data.get("output", "")
|
||||
|
|
@ -432,6 +474,14 @@ async def _handle_chat_message(
|
|||
"type": "final_answer",
|
||||
"content": final_content,
|
||||
})
|
||||
elif event.event_type == "confirmation_request":
|
||||
# Already handled by confirmation_handler, just notify frontend
|
||||
pass
|
||||
elif event.event_type == "confirmation_result":
|
||||
await websocket.send_json({
|
||||
"type": "confirmation_result",
|
||||
"data": event.data,
|
||||
})
|
||||
elif event.event_type == "token":
|
||||
await websocket.send_json({
|
||||
"type": "token",
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
||||
|
|
@ -11,15 +13,125 @@ from agentkit.session.store import InMemorySessionStore, SessionStore
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AsyncWriteQueue:
|
||||
"""Background write-ahead queue for non-blocking session persistence.
|
||||
|
||||
Accepts write operations (append_message + save_session) as tasks,
|
||||
executes them in a background ``asyncio.Task``, and maintains a small
|
||||
in-memory WAL buffer for crash recovery and immediate reads.
|
||||
"""
|
||||
|
||||
def __init__(self, store: SessionStore, max_buffer_size: int = 256) -> None:
|
||||
self._store = store
|
||||
self._queue: asyncio.Queue[tuple[Message, Session] | None] | None = None
|
||||
self._worker: asyncio.Task | None = None
|
||||
# WAL buffer: session_id -> list of Messages not yet persisted
|
||||
self._wal_buffer: dict[str, list[Message]] = defaultdict(list)
|
||||
self._max_buffer_size = max_buffer_size
|
||||
self._pending_count = 0
|
||||
|
||||
def _ensure_started(self) -> None:
|
||||
"""Start the background writer task if not already running (lazy init)."""
|
||||
if self._worker is not None and not self._worker.done():
|
||||
return
|
||||
self._queue = asyncio.Queue()
|
||||
self._worker = asyncio.create_task(self._writer_loop())
|
||||
|
||||
async def _writer_loop(self) -> None:
|
||||
"""Consume write tasks from the queue and persist them."""
|
||||
assert self._queue is not None
|
||||
while True:
|
||||
item = await self._queue.get()
|
||||
if item is None:
|
||||
# Sentinel: graceful shutdown signal
|
||||
self._queue.task_done()
|
||||
break
|
||||
message, session = item
|
||||
try:
|
||||
await self._store.append_message(message)
|
||||
session.updated_at = __import__("datetime").datetime.now(
|
||||
__import__("datetime").timezone.utc
|
||||
)
|
||||
await self._store.save_session(session)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"AsyncWriteQueue: failed to persist message %s for session %s",
|
||||
message.message_id,
|
||||
message.session_id,
|
||||
)
|
||||
finally:
|
||||
# Remove from WAL buffer once persisted
|
||||
buf = self._wal_buffer.get(message.session_id)
|
||||
if buf is not None:
|
||||
try:
|
||||
buf.remove(message)
|
||||
except ValueError:
|
||||
pass
|
||||
if not buf:
|
||||
self._wal_buffer.pop(message.session_id, None)
|
||||
self._pending_count -= 1
|
||||
self._queue.task_done()
|
||||
|
||||
def enqueue(self, message: Message, session: Session) -> None:
|
||||
"""Enqueue a write task without blocking.
|
||||
|
||||
The message is immediately added to the WAL buffer so that
|
||||
``get_chat_messages`` can see it before persistence completes.
|
||||
Lazily starts the background writer on first call.
|
||||
"""
|
||||
self._ensure_started()
|
||||
assert self._queue is not None
|
||||
self._wal_buffer[message.session_id].append(message)
|
||||
self._pending_count += 1
|
||||
self._queue.put_nowait((message, session))
|
||||
|
||||
def buffered_messages(self, session_id: str) -> list[Message]:
|
||||
"""Return WAL-buffered messages for *session_id* not yet persisted."""
|
||||
return list(self._wal_buffer.get(session_id, []))
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
"""Number of write tasks waiting in the queue."""
|
||||
return self._pending_count
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Wait until all queued writes have been persisted."""
|
||||
if self._queue is not None:
|
||||
await self._queue.join()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Signal the writer to stop and wait for it to finish."""
|
||||
if self._queue is not None and self._worker is not None:
|
||||
await self._queue.put(None) # sentinel
|
||||
await self._worker
|
||||
self._worker = None
|
||||
self._queue = None
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manages conversation sessions and their messages.
|
||||
|
||||
Provides a high-level API for creating, querying, and updating
|
||||
sessions, as well as appending and retrieving messages.
|
||||
|
||||
When ``async_writes=True``, ``append_message`` is non-blocking: the
|
||||
message is placed in a write-ahead buffer and persisted in the
|
||||
background. Call ``flush()`` or ``close()`` to ensure all writes
|
||||
are durably persisted before shutdown.
|
||||
"""
|
||||
|
||||
def __init__(self, store: SessionStore | None = None):
|
||||
def __init__(
|
||||
self,
|
||||
store: SessionStore | None = None,
|
||||
*,
|
||||
async_writes: bool = False,
|
||||
wal_buffer_size: int = 256,
|
||||
):
|
||||
self._store = store or InMemorySessionStore()
|
||||
self._async_writes = async_writes
|
||||
self._write_queue: AsyncWriteQueue | None = None
|
||||
if async_writes:
|
||||
self._write_queue = AsyncWriteQueue(self._store, max_buffer_size=wal_buffer_size)
|
||||
|
||||
@property
|
||||
def store(self) -> SessionStore:
|
||||
|
|
@ -61,7 +173,13 @@ class SessionManager:
|
|||
return await self._store.update_session_status(session_id, SessionStatus.ACTIVE)
|
||||
|
||||
async def close_session(self, session_id: str) -> Session | None:
|
||||
"""Close a session. Closed sessions cannot accept new messages."""
|
||||
"""Close a session. Closed sessions cannot accept new messages.
|
||||
|
||||
When async writes are enabled, flushes pending writes before
|
||||
updating the session status.
|
||||
"""
|
||||
if self._write_queue is not None:
|
||||
await self._write_queue.flush()
|
||||
return await self._store.update_session_status(session_id, SessionStatus.CLOSED)
|
||||
|
||||
async def delete_session(self, session_id: str) -> bool:
|
||||
|
|
@ -116,11 +234,15 @@ class SessionManager:
|
|||
agent_name=agent_name,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
await self._store.append_message(message)
|
||||
|
||||
# Update session's updated_at timestamp
|
||||
session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc)
|
||||
await self._store.save_session(session)
|
||||
if self._write_queue is not None:
|
||||
# Non-blocking: enqueue write, return immediately
|
||||
self._write_queue.enqueue(message, session)
|
||||
else:
|
||||
# Synchronous path (default, backward-compatible)
|
||||
await self._store.append_message(message)
|
||||
session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc)
|
||||
await self._store.save_session(session)
|
||||
|
||||
return message
|
||||
|
||||
|
|
@ -132,6 +254,9 @@ class SessionManager:
|
|||
) -> list[Message]:
|
||||
"""Get messages for a session with optional pagination.
|
||||
|
||||
When async writes are enabled, includes WAL-buffered messages
|
||||
not yet persisted to the store.
|
||||
|
||||
Args:
|
||||
session_id: Target session ID.
|
||||
limit: Maximum number of messages to return. None for all.
|
||||
|
|
@ -140,15 +265,30 @@ class SessionManager:
|
|||
Returns:
|
||||
List of messages ordered chronologically.
|
||||
"""
|
||||
return await self._store.get_messages(session_id, limit=limit, offset=offset)
|
||||
persisted = await self._store.get_messages(session_id, limit=None, offset=0)
|
||||
if self._write_queue is not None:
|
||||
buffered = self._write_queue.buffered_messages(session_id)
|
||||
if buffered:
|
||||
# Merge: persisted + buffered (dedup by message_id)
|
||||
seen = {m.message_id for m in persisted}
|
||||
for m in buffered:
|
||||
if m.message_id not in seen:
|
||||
persisted.append(m)
|
||||
sliced = persisted[offset:]
|
||||
if limit is not None:
|
||||
sliced = sliced[:limit]
|
||||
return sliced
|
||||
|
||||
async def get_chat_messages(self, session_id: str) -> list[dict[str, str]]:
|
||||
"""Get messages formatted for LLM chat API consumption.
|
||||
|
||||
Returns messages as OpenAI-compatible dicts suitable for
|
||||
passing directly to the ReAct engine or LLM Gateway.
|
||||
|
||||
When async writes are enabled, includes WAL-buffered messages
|
||||
not yet persisted to the store.
|
||||
"""
|
||||
messages = await self._store.get_messages(session_id)
|
||||
messages = await self.get_messages(session_id)
|
||||
return [m.to_chat_message() for m in messages]
|
||||
|
||||
async def count_messages(self, session_id: str) -> int:
|
||||
|
|
@ -158,3 +298,21 @@ class SessionManager:
|
|||
async def health_check(self) -> bool:
|
||||
"""Check if the underlying store is healthy."""
|
||||
return await self._store.health_check()
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Wait for all pending async writes to be persisted.
|
||||
|
||||
No-op when async writes are not enabled.
|
||||
"""
|
||||
if self._write_queue is not None:
|
||||
await self._write_queue.flush()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Flush pending writes and stop the background writer.
|
||||
|
||||
Should be called during application shutdown when async writes
|
||||
are enabled. Safe to call even when async writes are disabled.
|
||||
"""
|
||||
if self._write_queue is not None:
|
||||
await self._write_queue.flush()
|
||||
await self._write_queue.stop()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult, _tokenize_content
|
||||
from agentkit.chat.skill_routing import CostAwareRouter, HeuristicClassifier, SkillRoutingResult, _tokenize_content
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
from agentkit.router.intent import IntentRouter, RoutingResult
|
||||
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
|
||||
|
|
@ -510,3 +510,122 @@ class TestTokenizeContent:
|
|||
expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"]
|
||||
for bigram in expected_bigrams:
|
||||
assert bigram in tokens, f"缺少 2-gram: {bigram}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HeuristicClassifier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHeuristicClassifier:
|
||||
"""HeuristicClassifier 本地启发式分类器测试"""
|
||||
|
||||
def setup_method(self):
|
||||
self.classifier = HeuristicClassifier()
|
||||
|
||||
def test_short_greeting_low_complexity(self):
|
||||
"""短问候语 → 低复杂度"""
|
||||
score = self.classifier.classify("你好呀")
|
||||
assert score < 0.3
|
||||
|
||||
def test_simple_question_medium_complexity(self):
|
||||
"""含'如何'的简单问题 → 中等复杂度"""
|
||||
score = self.classifier.classify("如何使用这个功能?")
|
||||
assert 0.3 <= score <= 0.7
|
||||
|
||||
def test_tool_request_high_complexity(self):
|
||||
"""含工具关键词的请求 → 高复杂度"""
|
||||
score = self.classifier.classify("帮我搜索一下最新的新闻")
|
||||
assert score > 0.5
|
||||
|
||||
def test_code_request_high_complexity(self):
|
||||
"""代码相关请求 → 高复杂度"""
|
||||
score = self.classifier.classify("写一个Python函数实现快速排序")
|
||||
assert score > 0.6
|
||||
|
||||
def test_multi_step_request_high_complexity(self):
|
||||
"""多步分析请求 → 高复杂度"""
|
||||
score = self.classifier.classify("分析这个数据,比较不同方案的优缺点,然后给出推荐")
|
||||
assert score > 0.7
|
||||
|
||||
def test_empty_string_zero_complexity(self):
|
||||
"""空字符串 → 零复杂度"""
|
||||
assert self.classifier.classify("") == 0.0
|
||||
assert self.classifier.classify(" ") == 0.0
|
||||
|
||||
def test_long_message_higher_complexity(self):
|
||||
"""长消息 → 更高复杂度"""
|
||||
short = "帮我查一下"
|
||||
long = "帮我查一下" + "关于机器学习和深度学习的最新进展" * 10
|
||||
assert self.classifier.classify(long) > self.classifier.classify(short)
|
||||
|
||||
def test_code_patterns_boost_complexity(self):
|
||||
"""代码模式(反引号/括号)提升复杂度"""
|
||||
with_code = "运行这段代码 `print('hello')`"
|
||||
without_code = "运行这段代码 print hello"
|
||||
assert self.classifier.classify(with_code) > self.classifier.classify(without_code)
|
||||
|
||||
def test_score_bounded_0_to_1(self):
|
||||
"""复杂度值始终在 [0.0, 1.0] 范围"""
|
||||
test_inputs = [
|
||||
"", "你好", "如何做", "帮我搜索并分析数据,设计一个完整的解决方案,包含代码实现和部署配置",
|
||||
]
|
||||
for inp in test_inputs:
|
||||
score = self.classifier.classify(inp)
|
||||
assert 0.0 <= score <= 1.0, f"Score {score} out of range for '{inp}'"
|
||||
|
||||
|
||||
class TestHeuristicClassifierIntegration:
|
||||
"""HeuristicClassifier 在 CostAwareRouter 中的集成测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heuristic_mode_no_llm_call(self):
|
||||
"""heuristic 模式下不调用 LLM"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic")
|
||||
result = await router.route(
|
||||
content="帮我分析一下数据",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
# LLM gateway.chat 不应被调用
|
||||
gateway.chat.assert_not_called()
|
||||
# 复杂度应来自启发式分类器
|
||||
assert result.complexity > 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_mode_uses_llm(self):
|
||||
"""llm 模式下调用 LLM quick_classify"""
|
||||
gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
|
||||
router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="llm")
|
||||
result = await router.route(
|
||||
content="帮我分析一下数据",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
# LLM gateway.chat 应被调用
|
||||
gateway.chat.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heuristic_greeting_still_layer0(self):
|
||||
"""heuristic 模式下问候仍走 Layer 0"""
|
||||
router = CostAwareRouter(classifier="heuristic")
|
||||
result = await router.route(
|
||||
content="你好",
|
||||
skill_registry=_make_skill_registry(),
|
||||
intent_router=_make_intent_router(),
|
||||
default_tools=[],
|
||||
default_system_prompt="You are helpful.",
|
||||
)
|
||||
assert result.match_method == "greeting"
|
||||
assert result.complexity == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heuristic_default_classifier_mode(self):
|
||||
"""默认分类器模式为 heuristic"""
|
||||
router = CostAwareRouter()
|
||||
assert router._classifier == "heuristic"
|
||||
|
|
|
|||
|
|
@ -12,6 +12,11 @@ def manager():
|
|||
return SessionManager(store=InMemorySessionStore())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_manager():
|
||||
return SessionManager(store=InMemorySessionStore(), async_writes=True)
|
||||
|
||||
|
||||
class TestSessionManagerCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(self, manager):
|
||||
|
|
@ -197,3 +202,137 @@ class TestSessionManagerHealth:
|
|||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, manager):
|
||||
assert await manager.health_check() is True
|
||||
|
||||
|
||||
class TestAsyncWrites:
|
||||
"""Tests for async (non-blocking) write behaviour."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_message_returns_immediately(self, async_manager):
|
||||
"""append_message returns the Message before it is persisted."""
|
||||
session = await async_manager.create_session(agent_name="agent1")
|
||||
msg = await async_manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content="Hello",
|
||||
)
|
||||
# Message is returned immediately
|
||||
assert msg.role == MessageRole.USER
|
||||
assert msg.content == "Hello"
|
||||
# Give the background writer a moment, then verify persistence
|
||||
await async_manager.flush()
|
||||
persisted = await async_manager.store.get_messages(session.session_id)
|
||||
assert len(persisted) == 1
|
||||
assert persisted[0].content == "Hello"
|
||||
await async_manager.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_chat_messages_includes_wal_buffered(self, async_manager):
|
||||
"""get_chat_messages returns WAL-buffered messages not yet persisted."""
|
||||
session = await async_manager.create_session(agent_name="agent1")
|
||||
# Append a message — it may still be in the WAL buffer
|
||||
await async_manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content="Buffered",
|
||||
)
|
||||
# get_chat_messages should include WAL-buffered messages
|
||||
chat_msgs = await async_manager.get_chat_messages(session.session_id)
|
||||
assert len(chat_msgs) >= 1
|
||||
assert any(m["content"] == "Buffered" for m in chat_msgs)
|
||||
await async_manager.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flush_ensures_all_pending_writes(self, async_manager):
|
||||
"""flush() waits until all queued writes are persisted."""
|
||||
session = await async_manager.create_session(agent_name="agent1")
|
||||
for i in range(5):
|
||||
await async_manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content=f"Msg {i}",
|
||||
)
|
||||
await async_manager.flush()
|
||||
persisted = await async_manager.store.get_messages(session.session_id)
|
||||
assert len(persisted) == 5
|
||||
await async_manager.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_appends_are_batched(self, async_manager):
|
||||
"""Multiple rapid append_messages are all persisted correctly."""
|
||||
session = await async_manager.create_session(agent_name="agent1")
|
||||
# Fire off many messages rapidly
|
||||
messages = []
|
||||
for i in range(20):
|
||||
msg = await async_manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content=f"Rapid {i}",
|
||||
)
|
||||
messages.append(msg)
|
||||
await async_manager.flush()
|
||||
persisted = await async_manager.store.get_messages(session.session_id)
|
||||
assert len(persisted) == 20
|
||||
contents = [m.content for m in persisted]
|
||||
for i in range(20):
|
||||
assert f"Rapid {i}" in contents
|
||||
await async_manager.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_close_flushes_pending_writes(self, async_manager):
|
||||
"""Closing a session flushes pending writes first."""
|
||||
session = await async_manager.create_session(agent_name="agent1")
|
||||
await async_manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content="Before close",
|
||||
)
|
||||
closed = await async_manager.close_session(session.session_id)
|
||||
assert closed.status == SessionStatus.CLOSED
|
||||
# Message should be persisted because close_session flushes
|
||||
persisted = await async_manager.store.get_messages(session.session_id)
|
||||
assert len(persisted) == 1
|
||||
assert persisted[0].content == "Before close"
|
||||
await async_manager.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_close_stops_writer(self, async_manager):
|
||||
"""close() flushes and stops the background writer."""
|
||||
session = await async_manager.create_session(agent_name="agent1")
|
||||
await async_manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content="Final",
|
||||
)
|
||||
await async_manager.close()
|
||||
# After close, the write queue should be stopped
|
||||
assert async_manager._write_queue is None or async_manager._write_queue._worker is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_writes_disabled_by_default(self):
|
||||
"""Without async_writes=True, writes are synchronous."""
|
||||
mgr = SessionManager(store=InMemorySessionStore())
|
||||
assert mgr._write_queue is None
|
||||
session = await mgr.create_session(agent_name="agent1")
|
||||
await mgr.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content="Sync",
|
||||
)
|
||||
# Should be immediately persisted (no flush needed)
|
||||
persisted = await mgr.store.get_messages(session.session_id)
|
||||
assert len(persisted) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_includes_wal_buffered(self, async_manager):
|
||||
"""get_messages returns WAL-buffered messages not yet persisted."""
|
||||
session = await async_manager.create_session(agent_name="agent1")
|
||||
await async_manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content="WAL msg",
|
||||
)
|
||||
messages = await async_manager.get_messages(session.session_id)
|
||||
assert len(messages) >= 1
|
||||
assert any(m.content == "WAL msg" for m in messages)
|
||||
await async_manager.close()
|
||||
|
|
|
|||
Loading…
Reference in New Issue