feat: optimize chat response speed for sub-1s first token latency

- Add HeuristicClassifier to replace LLM quick_classify with zero-cost local heuristic (keyword/length/code-pattern scoring), gated by router.classifier config (default: heuristic)
- Add parallel tool execution in ReActEngine via asyncio.gather for multiple independent tool_calls, gated by parallel_tools param
- Add AsyncWriteQueue for non-blocking session persistence with WAL buffer, gated by async_writes param on SessionManager
- Add httpx.Limits connection pool config to all LLM providers
- Add router config section to ServerConfig and agentkit.yaml
- All optimizations have config switches for safe rollback

parent d3b792a9ec
commit a36bc3d1c1

@@ -0,0 +1,42 @@
server:
  host: 0.0.0.0
  port: 8001
  workers: 1
  rate_limit: 60
llm:
  providers:
    bailian-coding:
      api_key: ${DASHSCOPE_API_KEY}
      base_url: https://coding.dashscope.aliyuncs.com/v1
      type: openai
      models:
        qwen3.7-plus:
          alias: default
        qwen3.6-plus: {}
        qwen3.5-plus: {}
        qwen3-max-2026-01-23: {}
        qwen3-coder-plus:
          alias: coder
        qwen3-coder-next: {}
        kimi-k2.5: {}
        glm-5: {}
        glm-4.7: {}
        MiniMax-M2.5: {}
  model_aliases:
    default: bailian-coding/qwen3.7-plus
    coder: bailian-coding/qwen3-coder-plus
session:
  backend: memory
bus:
  backend: memory
task_store:
  backend: memory
skills:
  auto_discover: true
  paths:
    - ./configs/skills
logging:
  level: INFO
  format: text
router:
  classifier: heuristic

@@ -0,0 +1,372 @@
---
title: "feat: Chat Response Speed Optimization - Sub-1s First Token"
status: active
created: 2026-06-12
plan-type: feat
depth: standard
---

# feat: Chat Response Speed Optimization - Sub-1s First Token

## Summary

Optimize the fischer-agentkit conversation response pipeline to achieve sub-1-second first-token latency. The primary bottleneck is 1–2 extra LLM calls in the routing layer before the main ReAct loop. Secondary optimizations include parallel tool execution, async session I/O, and connection pool tuning. All changes are gated by configuration flags for safe rollback.

## Problem Frame

Users experience 5–10 second delays before seeing any response in the chat interface. The root cause is a serial chain of LLM calls: CostAwareRouter.quick_classify() → IntentRouter._classify_with_llm() → ReActEngine LLM Think. The first two calls are routing overhead that add 2–6 seconds with no user-visible value. The third call is the actual reasoning step and cannot be eliminated, but its perceived latency can be reduced via streaming.

**Current worst-case latency chain:**

```
User message → quick_classify()      [1-2s LLM]
             → _classify_with_llm()  [1-2s LLM]
             → ReActEngine Think     [2-5s LLM]
             → Tool Act              [0.5-5s]
             → First token visible to user
```

**Target latency chain:**

```
User message → Local rule classification      [<1ms]
             → ReActEngine Think (streaming)  [first token in <1s]
             → Tool Act (parallel when possible)
             → First token visible to user
```

## Requirements

| ID | Requirement | Priority |
|----|-------------|----------|
| R1 | First token latency must be under 1 second for simple conversations (greetings, Q&A) | P0 |
| R2 | First token latency must be under 1 second for routed conversations when keyword matching succeeds | P0 |
| R3 | Routing accuracy must not degrade more than 10% compared to current LLM-based classification | P1 |
| R4 | All optimizations must be configurable with on/off switches for safe rollback | P0 |
| R5 | Parallel tool execution must preserve conversation history ordering | P1 |
| R6 | Async session writes must not lose messages on process crash | P1 |

## Key Technical Decisions

### KTD1: Replace LLM quick_classify with local heuristic

**Decision:** Replace `CostAwareRouter.quick_classify()` LLM call with a zero-cost local heuristic based on message length, keyword density, and tool-hint detection.

**Rationale:** The LLM classification adds 1–2s latency for a binary decision (simple vs complex). A local heuristic using the same signals already present in the message content (length, presence of tool-related keywords, question marks, etc.) can achieve ~85% accuracy at zero latency cost.

**Alternative considered:** Cache LLM classification results. Rejected because cache hit rate would be near-zero for conversational messages (each is unique).

### KTD2: Merge quick_classify and intent classification into single LLM call

**Decision:** When LLM routing is needed (heuristic uncertainty), combine complexity scoring and intent classification into a single LLM call instead of two serial calls.

**Rationale:** Currently `quick_classify()` and `_classify_with_llm()` are separate LLM calls that could be merged into one prompt returning both complexity score and matched skill. This halves the routing LLM overhead when it cannot be avoided.
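
A minimal sketch of how the merged reply could be parsed, assuming the LLM returns the JSON shape described in U2 (`skill`, `confidence`, `complexity`). `MergedRouting` and `parse_merged_response` are hypothetical names used only for illustration; they are not part of the codebase. Returning `None` on a malformed reply lets the caller apply whatever fallback it prefers.

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class MergedRouting:
    """Hypothetical container for the fields of one merged routing reply."""
    skill: Optional[str]
    confidence: float
    complexity: float


def _clamp(value: float) -> float:
    """Keep a score inside the 0.0-1.0 range used by the router."""
    return max(0.0, min(1.0, value))


def parse_merged_response(raw: str) -> Optional[MergedRouting]:
    """Parse the merged reply; return None if it cannot be used."""
    try:
        data = json.loads(raw)
        if not isinstance(data, dict):
            return None
        skill = data.get("skill")
        return MergedRouting(
            skill=skill if isinstance(skill, str) and skill else None,
            confidence=_clamp(float(data.get("confidence", 0.0))),
            complexity=_clamp(float(data["complexity"])),
        )
    except (ValueError, KeyError, TypeError):
        # Covers invalid JSON, a missing complexity field, and non-numeric scores.
        return None


ok = parse_merged_response('{"skill": "search", "confidence": 0.9, "complexity": 0.5}')
bad = parse_merged_response("not json")
print(ok)   # MergedRouting(skill='search', confidence=0.9, complexity=0.5)
print(bad)  # None
```

Treating `complexity` as required and `skill` as optional matches the "no matching skill returns complexity only" test scenario listed for U2.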

### KTD3: Parallel execution of independent tool_calls

**Decision:** Execute multiple tool_calls from a single LLM response in parallel using `asyncio.gather()`, with results appended to conversation in tool_call_id order.

**Rationale:** When LLM returns multiple tool calls (e.g., search + calculate), they are independent and can run concurrently. Results must be appended in order for the next LLM call to see them correctly.

**Risk:** Some tool calls may have implicit dependencies. Mitigation: the LLM generally does not return dependent calls in a single response (it waits for results before calling the next). Add a config flag `react.parallel_tools: false` to disable if needed.
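
The ordering guarantee this decision relies on can be checked with a small, self-contained sketch. `run_tool` is a hypothetical stand-in for a real tool call; the delays are arbitrary. `asyncio.gather()` returns results in argument order regardless of which coroutine finishes first, and `return_exceptions=True` turns a failure into a value instead of aborting the batch.

```python
import asyncio


async def run_tool(name: str, delay: float, fail: bool = False) -> dict:
    """Hypothetical stand-in for a tool call that takes `delay` seconds."""
    await asyncio.sleep(delay)
    if fail:
        raise RuntimeError(f"{name} failed")
    return {"tool": name, "ok": True}


async def run_parallel(calls: list) -> list:
    """Run all calls concurrently and return results in the order they were issued."""
    results = await asyncio.gather(
        *(run_tool(name, delay, fail) for name, delay, fail in calls),
        return_exceptions=True,
    )
    # gather() preserves argument order, so zip() pairs each call with its own result.
    return [
        {"tool": name, "error": str(res)} if isinstance(res, Exception) else res
        for (name, _, _), res in zip(calls, results)
    ]


calls = [("search", 0.05, False), ("calculate", 0.01, False), ("fetch", 0.02, True)]
ordered = asyncio.run(run_parallel(calls))
print([r["tool"] for r in ordered])  # ['search', 'calculate', 'fetch']
```

The slowest call (`search`) still comes first in the output, and the failing call yields an error entry while the other two results are kept.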

### KTD4: Fire-and-forget session writes with write-ahead buffer

**Decision:** Make `SessionManager.append_message()` non-blocking by returning immediately after queuing the write, with a background task performing the actual I/O. Add a small in-memory buffer as write-ahead log to prevent message loss.

**Rationale:** Session writes (especially `save_session()` for updated_at) add unnecessary blocking. The user doesn't need to wait for persistence before seeing a response. A write-ahead buffer ensures messages survive brief failures.

### KTD5: Unified httpx connection pool configuration

**Decision:** Configure explicit `httpx.Limits` on all LLM provider clients with sensible defaults for keepalive and connection pooling.

**Rationale:** Default httpx settings are reasonable but not optimized for high-frequency LLM API calls. Explicit configuration ensures consistent behavior across providers and enables tuning.

## Scope Boundaries

### In Scope

- Routing layer optimization (CostAwareRouter, IntentRouter)
- ReActEngine parallel tool execution
- Session I/O async optimization
- httpx connection pool tuning
- Configuration flags for all changes
- Test coverage for new behavior

### Out of Scope

- Frontend rendering optimization (separate concern)
- LLM provider response time optimization (external dependency)
- Memory/RAG pipeline optimization (covered by existing plan 009)
- Compression strategy changes (covered by existing plan 013)
- New LLM provider implementations

### Deferred to Follow-Up Work

- A/B testing framework for routing accuracy measurement
- Performance benchmarking CI pipeline
- WebSocket chat flow test coverage
- WenxinProvider token-refresh client reuse

---

## Implementation Units

### U1. Local heuristic classifier for CostAwareRouter

**Goal:** Replace the LLM-based `quick_classify()` with a zero-cost local heuristic, gated by config flag `router.classifier: heuristic | llm`.

**Requirements:** R1, R2, R3, R4

**Dependencies:** None

**Files:**
- `src/agentkit/chat/skill_routing.py`: add `HeuristicClassifier` class, modify `CostAwareRouter.route()`
- `src/agentkit/server/config.py`: add `router` config section
- `agentkit.yaml`: add `router` section with defaults
- `tests/unit/test_cost_aware_router.py`: add heuristic classifier tests

**Approach:**

1. Create `HeuristicClassifier` class with a `classify(content: str) -> float` method that returns a complexity score (0.0–1.0) based on:
   - Message length: short messages (<20 chars) → low complexity
   - Question patterns: presence of "为什么", "如何", "怎么", "how", "why", "what" → moderate complexity
   - Tool hints: presence of tool-related keywords (existing `_tokenize_content` + `tool_hints` list already in code) → high complexity
   - Multi-sentence: messages with multiple sentences → higher complexity
   - Code patterns: presence of code-like patterns (backticks, brackets) → higher complexity

2. Modify `CostAwareRouter.__init__` to accept a `classifier_mode` parameter (`"heuristic"` or `"llm"`)

3. Modify `CostAwareRouter.route()` Phase 1 to use `HeuristicClassifier.classify()` when mode is `"heuristic"`

4. Add `router` config section to `ServerConfig` with `classifier` field (default: `"heuristic"`)

5. Wire config through `create_app()` to `CostAwareRouter`
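
Steps 4 and 5 could look roughly like the sketch below. It assumes a plain dataclass; the real `ServerConfig` may use a different mechanism, and `RouterConfig` and `load_router_config` are hypothetical names. Validating the value at load time means a typo in `agentkit.yaml` fails at startup rather than silently selecting the wrong classifier.

```python
from dataclasses import dataclass

VALID_CLASSIFIERS = ("heuristic", "llm")


@dataclass(frozen=True)
class RouterConfig:
    """Hypothetical `router` section: selects the Layer 1 classifier."""
    classifier: str = "heuristic"

    def __post_init__(self) -> None:
        if self.classifier not in VALID_CLASSIFIERS:
            raise ValueError(
                f"router.classifier must be one of {VALID_CLASSIFIERS}, got {self.classifier!r}"
            )


def load_router_config(raw: dict) -> RouterConfig:
    """Build the section from a parsed YAML mapping; a missing section uses defaults."""
    section = raw.get("router") or {}
    return RouterConfig(classifier=section.get("classifier", "heuristic"))


print(load_router_config({}).classifier)                               # heuristic
print(load_router_config({"router": {"classifier": "llm"}}).classifier)  # llm
```

The resulting `classifier` string would then be passed to the `CostAwareRouter` constructor when the app is created.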

**Patterns to follow:** Existing `CostAwareRouter._match_layer0()` rule-based pattern; existing `_tokenize_content()` for keyword extraction.

**Test scenarios:**
- Short greeting → complexity < 0.3
- Single question with "如何" → complexity 0.3–0.7
- Multi-step request with tool keywords → complexity > 0.7
- Code-related request → complexity > 0.7
- Empty string → complexity 0.0
- Very long message (>500 chars) → complexity > 0.5
- Config flag `classifier: llm` falls back to LLM classification
- Config flag `classifier: heuristic` uses local heuristic

**Verification:** All existing `test_cost_aware_router.py` tests pass; new heuristic tests pass; manual test shows first-token latency <1s for simple messages.

---

### U2. Merged routing LLM call

**Goal:** When LLM routing is needed (heuristic uncertain or config forces LLM), combine complexity scoring and intent classification into a single LLM call.

**Requirements:** R2, R3, R4

**Dependencies:** U1

**Files:**
- `src/agentkit/chat/skill_routing.py`: add `MergedRouter` method
- `src/agentkit/router/intent.py`: add `route_with_complexity()` method
- `tests/unit/test_cost_aware_router.py`: add merged routing tests
- `tests/unit/test_intent_router.py`: add merged routing tests

**Approach:**

1. Add `IntentRouter.route_with_complexity()` method that returns both a `RoutingResult` and a complexity score in a single LLM call. The prompt asks the LLM to return `{"skill": "...", "confidence": 0.9, "complexity": 0.5}`.

2. Modify `CostAwareRouter.route()` so that when `classifier` is `"llm"`, it calls `route_with_complexity()` instead of making two separate calls.

3. When `classifier` is `"heuristic"` and the heuristic returns uncertainty (score in 0.3–0.7 range), use `route_with_complexity()` as a single fallback call.
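
The three-way decision implied by steps 2 and 3 can be sketched as a pure function. The 0.3 and 0.7 thresholds come from the text above; treating both boundaries as part of the uncertain band is an assumption, and `Route` and `next_step` are hypothetical names.

```python
from enum import Enum


class Route(Enum):
    """Hypothetical outcomes of the Layer 1 heuristic score."""
    DEFAULT_AGENT = "default_agent"  # low complexity, no LLM routing call
    MERGED_LLM = "merged_llm"        # uncertain, one merged routing call
    COMPLEX = "complex"              # high complexity, continue to later layers


LOW_THRESHOLD = 0.3
HIGH_THRESHOLD = 0.7


def next_step(score: float) -> Route:
    """Map a heuristic complexity score to the next routing action."""
    if score < LOW_THRESHOLD:
        return Route.DEFAULT_AGENT
    if score <= HIGH_THRESHOLD:
        # Assumption: 0.3 and 0.7 themselves count as uncertain.
        return Route.MERGED_LLM
    return Route.COMPLEX


for score in (0.15, 0.45, 0.8):
    print(score, next_step(score).value)
# 0.15 default_agent
# 0.45 merged_llm
# 0.8 complex
```

Only the middle band pays for an LLM call, which is how the heuristic keeps simple greetings and clearly complex requests off the slow path.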

**Patterns to follow:** Existing `_classify_with_llm()` prompt structure; existing `quick_classify()` prompt structure.

**Test scenarios:**
- Merged call returns both skill match and complexity score
- Merged call with no matching skill returns complexity only
- Merged call with invalid LLM response falls back to rule-based evaluation
- Heuristic uncertain + merged call produces correct routing
- Config `classifier: llm` uses merged call instead of two separate calls

**Verification:** Existing tests pass; merged routing reduces LLM calls from 2 to 1 when LLM routing is needed.

---

### U3. Parallel tool execution in ReActEngine

**Goal:** Execute multiple independent tool_calls from a single LLM response in parallel, gated by config flag `react.parallel_tools`.

**Requirements:** R5, R4

**Dependencies:** None

**Files:**
- `src/agentkit/core/react.py`: modify `_execute_loop()` and `execute_stream()` to use `asyncio.gather()`
- `src/agentkit/server/config.py`: add `react.parallel_tools` config
- `agentkit.yaml`: add `react` section
- `tests/unit/test_react_engine.py`: add parallel execution tests

**Approach:**

1. Add `parallel_tools: bool = True` parameter to `ReActEngine.__init__`.

2. In `_execute_loop()` and `execute_stream()`, when `response.tool_calls` has more than one item and `parallel_tools` is True:
   - Execute all tool calls concurrently with `asyncio.gather(*[_execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls], return_exceptions=True)`
   - Build tool result messages in tool_call_id order
   - Append all results to conversation in order

3. When `parallel_tools` is False, keep current serial behavior.

4. For `execute_stream()`, yield all `tool_call` events first, then all `tool_result` events after gather completes.
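
Step 4 describes an event order that is easy to get wrong, so here is a minimal sketch of it as an async generator. The event dictionaries and the `stream_tool_events` name are hypothetical and much simpler than the real `ReActEvent`; tools are assumed to be async callables taking an arguments dict.

```python
import asyncio
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List

ToolFn = Callable[[dict], Awaitable[Any]]


async def stream_tool_events(
    tool_calls: List[dict], tools: Dict[str, ToolFn]
) -> AsyncIterator[dict]:
    """Yield every tool_call event first, then every tool_result in call order."""
    for tc in tool_calls:
        yield {"event_type": "tool_call", "id": tc["id"], "name": tc["name"]}

    # All tools run concurrently; nothing is yielded until the whole batch is done.
    results = await asyncio.gather(
        *(tools[tc["name"]](tc["arguments"]) for tc in tool_calls),
        return_exceptions=True,
    )

    for tc, res in zip(tool_calls, results):
        if isinstance(res, Exception):
            res = {"error": str(res)}
        yield {"event_type": "tool_result", "id": tc["id"], "result": res}


async def _demo() -> List[str]:
    async def slow(args: dict) -> str:
        await asyncio.sleep(0.02)
        return "slow done"

    async def fast(args: dict) -> str:
        return "fast done"

    calls = [
        {"id": "call_1", "name": "slow", "arguments": {}},
        {"id": "call_2", "name": "fast", "arguments": {}},
    ]
    stream = stream_tool_events(calls, {"slow": slow, "fast": fast})
    return [f'{e["event_type"]}:{e["id"]}' async for e in stream]


events = asyncio.run(_demo())
print(events)
# ['tool_call:call_1', 'tool_call:call_2', 'tool_result:call_1', 'tool_result:call_2']
```

One consequence of this design is that the client sees no `tool_result` until the slowest tool in the batch finishes, which trades per-tool progress updates for a predictable order.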

**Patterns to follow:** Existing `_execute_tool()` method; existing `Orchestrator._execute_plan()` parallel group pattern in `orchestrator.py`.

**Test scenarios:**
- Two independent tools execute in parallel, both results present in conversation
- Parallel execution preserves tool_call_id ordering in conversation
- One tool fails, other succeeds: partial results preserved
- `parallel_tools: false` falls back to serial execution
- Single tool_call works identically with parallel mode on/off
- Tool results appended to conversation in correct order for next LLM call

**Verification:** Existing ReAct tests pass; new parallel tests pass; manual test with multi-tool request shows reduced execution time.

---

### U4. Async session writes with write-ahead buffer

**Goal:** Make `SessionManager.append_message()` non-blocking by deferring `save_session()` and making `append_message()` fire-and-forget with a small write-ahead buffer.

**Requirements:** R6, R4

**Dependencies:** None

**Files:**
- `src/agentkit/session/manager.py`: add async write queue and WAL buffer
- `tests/unit/test_session_manager.py`: add async write tests

**Approach:**

1. Add an `AsyncWriteQueue` to `SessionManager` that:
   - Accepts write operations (append_message, save_session) as tasks
   - Executes them in a background `asyncio.Task`
   - Maintains a small in-memory buffer of recent writes for crash recovery
   - Provides `await flush()` for graceful shutdown

2. Modify `append_message()`:
   - Keep `get_session()` + validation as synchronous (needed for error checking)
   - Queue `store.append_message()` + `store.save_session()` as a single async task
   - Return the `Message` object immediately without waiting for persistence

3. Modify `get_chat_messages()` to first check the WAL buffer for uncommitted messages, then fall back to store.

4. Add `flush()` method called during session close and app shutdown.
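
The queue described in this approach could be sketched as follows. This is an illustration under stated assumptions, not the class added by the commit: the `persist` callback, the `enqueue` and `pending` method names, and the dict message shape are all hypothetical.

```python
import asyncio
from typing import Awaitable, Callable, Optional


class AsyncWriteQueue:
    """Hypothetical sketch: queue writes and keep them buffered until persisted."""

    def __init__(self, persist: Callable[[dict], Awaitable[None]]) -> None:
        self._persist = persist
        self._queue: asyncio.Queue = asyncio.Queue()
        self._pending: list = []  # write-ahead buffer of messages not yet persisted
        self._worker: Optional[asyncio.Task] = None

    def enqueue(self, message: dict) -> None:
        """Buffer the message and return at once; call from a running event loop."""
        self._pending.append(message)
        self._queue.put_nowait(message)
        if self._worker is None or self._worker.done():
            self._worker = asyncio.create_task(self._drain())

    async def _drain(self) -> None:
        while not self._queue.empty():
            message = self._queue.get_nowait()
            try:
                await self._persist(message)
            except Exception:
                pass  # the message stays in the buffer for a later retry
            else:
                self._pending.remove(message)
            finally:
                self._queue.task_done()

    def pending(self) -> list:
        """Messages a reader should merge with what the store returns."""
        return list(self._pending)

    async def flush(self) -> None:
        """Wait until every queued write has been attempted."""
        await self._queue.join()


async def _demo() -> tuple:
    store: list = []

    async def persist(message: dict) -> None:
        await asyncio.sleep(0.01)  # simulated I/O latency
        store.append(message)

    queue = AsyncWriteQueue(persist)
    queue.enqueue({"role": "user", "content": "hi"})
    queue.enqueue({"role": "assistant", "content": "hello"})
    buffered = len(queue.pending())  # nothing has been persisted yet
    await queue.flush()
    return buffered, len(queue.pending()), [m["role"] for m in store]


demo_result = asyncio.run(_demo())
print(demo_result)  # (2, 0, ['user', 'assistant'])
```

Note that a buffer held only in process memory covers failed or delayed writes and shutdown via `flush()`, but it does not by itself survive a hard process crash; meeting R6 in that case would need the buffer to be durable.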

**Patterns to follow:** Existing `BackgroundRunner` pattern in `server/runner.py`; existing `TaskStore` cleanup pattern.

**Test scenarios:**
- append_message returns immediately, message persisted asynchronously
- get_chat_messages includes WAL-buffered messages not yet persisted
- flush() ensures all pending writes complete
- Multiple rapid append_messages are batched correctly
- Session close flushes pending writes
- App shutdown flushes pending writes

**Verification:** Existing session tests pass; new async write tests pass; no message loss during normal operation.

---

### U5. httpx connection pool configuration

**Goal:** Configure explicit `httpx.Limits` on all LLM provider clients for optimal connection reuse.

**Requirements:** R4

**Dependencies:** None

**Files:**
- `src/agentkit/llm/providers/openai.py`: add `httpx.Limits` configuration
- `src/agentkit/llm/providers/anthropic.py`: add `httpx.Limits` configuration
- `src/agentkit/llm/providers/gemini.py`: add `httpx.Limits` configuration
- `src/agentkit/llm/config.py`: add connection pool config fields
- `tests/unit/test_llm_provider.py`: verify connection pool settings

**Approach:**

1. Add `connection_pool` section to `ProviderConfig`:
   - `max_connections: int = 100`
   - `max_keepalive_connections: int = 20`
   - `keepalive_expiry: float = 30.0`

2. Pass `httpx.Limits` to all provider constructors.

3. Configure `httpx.AsyncClient` with explicit limits in each provider.
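
As a rough illustration of the three steps, the defaults above map directly onto the keyword arguments of `httpx.Limits`. `ConnectionPoolConfig` and `build_async_client` are hypothetical names, and the 60-second timeout is an arbitrary placeholder; only `httpx.Limits` and `httpx.AsyncClient` are real httpx APIs.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ConnectionPoolConfig:
    """Hypothetical pool settings; field names match httpx.Limits keyword arguments."""
    max_connections: int = 100
    max_keepalive_connections: int = 20
    keepalive_expiry: float = 30.0


def build_async_client(pool: ConnectionPoolConfig, timeout: float = 60.0):
    """Create an httpx.AsyncClient that uses the configured pool limits."""
    import httpx  # third-party; imported lazily so the config works without it installed

    limits = httpx.Limits(**asdict(pool))
    return httpx.AsyncClient(limits=limits, timeout=timeout)


defaults = asdict(ConnectionPoolConfig())
custom = asdict(ConnectionPoolConfig(max_connections=50))
print(defaults)
# {'max_connections': 100, 'max_keepalive_connections': 20, 'keepalive_expiry': 30.0}
print(custom["max_connections"])  # 50
```

Keeping one long-lived client per provider is what makes these limits matter: connections are only reused within a single client instance.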

**Patterns to follow:** Existing `ProviderConfig` dataclass pattern; existing `timeout` parameter pattern.

**Test scenarios:**
- Provider creates httpx client with configured limits
- Default limits applied when not configured
- Custom limits from config override defaults
- Connection reuse verified via mock

**Verification:** Existing provider tests pass; connection pool settings applied correctly.

---

### U6. Chat route pipeline optimization

**Goal:** Optimize the WebSocket chat handler to overlap I/O operations and reduce serial waits.

**Requirements:** R1, R2

**Dependencies:** U1, U4

**Files:**
- `src/agentkit/server/routes/chat.py`: parallelize session operations
- `tests/unit/test_chat_routes.py`: add pipeline optimization tests

**Approach:**

1. In `_handle_chat_message()`, parallelize:
   - `sm.append_message()` (user message) and `sm.get_chat_messages()`: these can run concurrently since append_message now returns immediately (U4)

2. Move assistant message `append_message()` to fire-and-forget after streaming completes (already non-blocking with U4).

3. Reuse `ReActEngine` instance per session instead of creating new one per message.
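
Step 3 amounts to a per-session cache. A minimal sketch, assuming engines are safe to reuse across messages and are created by a zero-argument factory; `EngineCache` and `FakeEngine` are hypothetical names used only for illustration.

```python
from typing import Callable, Dict, Generic, TypeVar

E = TypeVar("E")


class EngineCache(Generic[E]):
    """Hypothetical cache that keeps one engine instance per session id."""

    def __init__(self, factory: Callable[[], E]) -> None:
        self._factory = factory
        self._engines: Dict[str, E] = {}

    def get(self, session_id: str) -> E:
        """Return the session's engine, creating it on first use."""
        if session_id not in self._engines:
            self._engines[session_id] = self._factory()
        return self._engines[session_id]

    def drop(self, session_id: str) -> None:
        """Forget the engine when the session closes so the cache does not grow forever."""
        self._engines.pop(session_id, None)


class FakeEngine:
    """Stand-in for ReActEngine in this sketch."""


cache = EngineCache(FakeEngine)
first = cache.get("session-a")
print(cache.get("session-a") is first)  # True
print(cache.get("session-b") is first)  # False
```

Calling `drop()` on session close matters for a long-running server, since otherwise every session ever seen would keep an engine alive.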

**Patterns to follow:** Existing `asyncio.gather` pattern in orchestrator.

**Test scenarios:**
- User message append and chat messages retrieval run concurrently
- Assistant message persisted after streaming completes
- ReActEngine reuse across messages in same session
- Error during parallel operations handled gracefully

**Verification:** Existing chat route tests pass; manual test shows reduced latency.

---

## Risks & Mitigations

| Risk | Likelihood | Impact | Mitigation |
|------|-----------|--------|------------|
| Heuristic classifier misroutes requests | Medium | Medium: wrong skill or wrong execution mode | Config flag to revert to LLM; monitor routing accuracy via telemetry |
| Parallel tool execution breaks implicit dependencies | Low | High: incorrect results | Config flag to disable; LLM rarely returns dependent calls in single response |
| Async session writes lose messages on crash | Low | Medium: missing conversation history | WAL buffer + flush on shutdown; acceptable trade-off for speed |
| Merged LLM call prompt confuses the model | Low | Low: falls back to separate calls | Fallback to separate calls on parse failure |

## System-Wide Impact

- **Routing layer:** CostAwareRouter and IntentRouter behavior changes when heuristic mode is active; existing LLM-based routing preserved as fallback
- **ReAct engine:** Tool execution changes from serial to parallel; conversation history ordering preserved
- **Session management:** Write operations become asynchronous; read operations check WAL buffer
- **Configuration:** New `router` and `react` config sections in `agentkit.yaml`
- **Telemetry:** Existing OpenTelemetry spans continue to work; new spans for heuristic classification

## Open Questions

- What is the actual routing accuracy of the current LLM-based classifier? Need baseline measurement before comparing heuristic accuracy.
- Should the heuristic classifier be extensible (plugin pattern) or hardcoded? Starting with hardcoded for simplicity, can extend later.

@@ -282,11 +282,101 @@ def _tokenize_content(content: str) -> list[str]:
     return tokens
 
 
+class HeuristicClassifier:
+    """Zero-cost local heuristic classifier that replaces the LLM quick_classify.
+
+    Estimates complexity (0.0-1.0) from message length, keyword density, tool hints,
+    and similar features, with no LLM call and under 1 ms of latency.
+    """
+
+    # High-complexity hints (tools or multi-step reasoning needed)
+    _HIGH_COMPLEXITY_HINTS = {
+        # Tools / execution
+        "执行", "运行", "命令", "终端", "shell", "bash", "script",
+        "安装", "部署", "启动", "停止", "重启", "配置",
+        "搜索", "查找", "联网", "search", "find", "query",
+        "文件", "目录", "创建", "删除", "修改", "编辑",
+        "run", "execute", "install", "deploy", "start", "stop",
+        "restart", "file", "directory", "create", "delete", "modify",
+        # Multi-step / analysis
+        "分析", "比较", "对比", "评估", "调研", "研究",
+        "设计", "规划", "方案", "架构", "实现", "开发",
+        "analyze", "compare", "evaluate", "research", "design",
+        "plan", "implement", "develop", "build",
+        # Code
+        "代码", "编程", "函数", "类", "接口", "调试", "重构",
+        "code", "program", "function", "class", "interface", "debug", "refactor",
+        "python", "java", "javascript", "typescript", "sql", "api",
+    }
+
+    # Medium-complexity hints (simple questions that still need some thought)
+    _MEDIUM_COMPLEXITY_HINTS = {
+        "如何", "怎么", "怎样", "为什么", "什么原因", "区别",
+        "how", "why", "what", "difference", "explain",
+        "能", "可以", "是否", "会不会",
+        "推荐", "建议", "选择", "哪个",
+        "recommend", "suggest", "choose", "which",
+    }
+
+    def classify(self, content: str) -> float:
+        """Estimate message complexity (0.0-1.0).
+
+        Scoring rules:
+        - No complexity hints → base score 0.15
+        - Medium-complexity keyword → 0.45
+        - High-complexity keyword → 0.65 (one hit) or 0.8 (two or more)
+        - Multiple sentences / long message → extra bonus
+        - Code patterns (backticks / brackets) → extra bonus
+        """
+        if not content or not content.strip():
+            return 0.0
+
+        content_lower = content.lower()
+        score = 0.0
+
+        # 1. Keyword matching
+        high_hits = sum(1 for h in self._HIGH_COMPLEXITY_HINTS if h in content_lower)
+        medium_hits = sum(1 for m in self._MEDIUM_COMPLEXITY_HINTS if m in content_lower)
+
+        if high_hits >= 2:
+            score = 0.8
+        elif high_hits == 1:
+            score = 0.65
+        elif medium_hits >= 1:
+            score = 0.45
+        else:
+            score = 0.15
+
+        # 2. Message-length bonus
+        length = len(content)
+        if length > 200:
+            score += 0.15
+        elif length > 100:
+            score += 0.1
+        elif length > 50:
+            score += 0.05
+
+        # 3. Multi-sentence bonus (split on commas, periods, newlines)
+        sentence_count = len([s for s in re.split(r'[,。!?;\n,.!?;]', content) if s.strip()])  # skip empty segments
+        if sentence_count >= 4:
+            score += 0.1
+        elif sentence_count >= 2:
+            score += 0.05
+
+        # 4. Code-pattern bonus
+        if '`' in content:  # also covers triple-backtick fences
+            score += 0.15
+        if re.search(r'[\{\}\[\]\(\)]', content):
+            score += 0.05
+
+        return max(0.0, min(1.0, score))
+
+
 class CostAwareRouter:
     """Three-layer cost-aware router.
 
     Layer 0: rule matching (zero cost): @skill: prefix / greetings / simple chat
-    Layer 1: LLM quick classification (~100 tokens): complexity scoring + IntentRouter
+    Layer 1: complexity classification: heuristic (zero cost) or LLM (~100 tokens)
     Layer 2: capability matching / auction (optional): high-complexity tasks are delegated to the best Agent
     """
 
@@ -296,11 +386,14 @@ class CostAwareRouter:
         model: str = "default",
         org_context: Any = None,
         auction_enabled: bool = False,
+        classifier: str = "heuristic",
     ):
         self._llm_gateway = llm_gateway
         self._model = model
         self._org_context = org_context
         self._auction_enabled = auction_enabled
+        self._classifier = classifier
+        self._heuristic = HeuristicClassifier()
 
     # -- Layer 0: Rule-based (zero cost) ------------------------------------
 
@@ -516,13 +609,21 @@ class CostAwareRouter:
                 span.set_attribute("route.target", "default")
                 return result
 
-            # ---- Layer 1: LLM quick classify (~100 tokens) ----
-            complexity = await self.quick_classify(clean_content)
-            trace.append({
-                "layer": 1,
-                "method": "quick_classify",
-                "complexity": complexity,
-            })
+            # ---- Layer 1: Complexity classification ----
+            if self._classifier == "heuristic":
+                complexity = self._heuristic.classify(clean_content)
+                trace.append({
+                    "layer": 1,
+                    "method": "heuristic_classify",
+                    "complexity": complexity,
+                })
+            else:
+                complexity = await self.quick_classify(clean_content)
+                trace.append({
+                    "layer": 1,
+                    "method": "quick_classify",
+                    "complexity": complexity,
+                })
 
             # Low complexity → default agent
             if complexity < 0.3:

@@ -353,6 +353,9 @@ def _build_gateway(server_config: "ServerConfig") -> "LLMGateway":
                     max_tokens=pconf.max_tokens,
                     base_url=pconf.base_url or "https://api.anthropic.com",
                     timeout=pconf.timeout,
+                    max_connections=pconf.max_connections,
+                    max_keepalive_connections=pconf.max_keepalive_connections,
+                    keepalive_expiry=pconf.keepalive_expiry,
                 )
             elif pconf.type == "gemini":
                 provider = GeminiProvider(
@@ -361,11 +364,17 @@ def _build_gateway(server_config: "ServerConfig") -> "LLMGateway":
                     max_output_tokens=pconf.max_tokens,
                     base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
                     timeout=pconf.timeout,
+                    max_connections=pconf.max_connections,
+                    max_keepalive_connections=pconf.max_keepalive_connections,
+                    keepalive_expiry=pconf.keepalive_expiry,
                 )
             else:
                 provider = OpenAICompatibleProvider(
                     api_key=pconf.api_key,
                     base_url=pconf.base_url,
+                    max_connections=pconf.max_connections,
+                    max_keepalive_connections=pconf.max_keepalive_connections,
+                    keepalive_expiry=pconf.keepalive_expiry,
                 )
             gateway.register_provider(name, provider)
         except Exception as e:

@ -61,7 +61,7 @@ class ReActResult:
|
||||||
class ReActEvent:
|
class ReActEvent:
|
||||||
"""ReAct 执行事件"""
|
"""ReAct 执行事件"""
|
||||||
|
|
||||||
event_type: str # "thinking", "token", "tool_call", "tool_result", "final_answer", "error"
|
event_type: str # "thinking", "token", "tool_call", "tool_result", "confirmation_request", "final_answer", "error"
|
||||||
step: int
|
step: int
|
||||||
data: dict[str, Any] = field(default_factory=dict)
|
data: dict[str, Any] = field(default_factory=dict)
|
||||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
@ -74,12 +74,13 @@ class ReActEngine:
|
||||||
使 Agent 能够自主推理并选择工具完成任务。
|
使 Agent 能够自主推理并选择工具完成任务。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0):
|
def __init__(self, llm_gateway: LLMGateway, max_steps: int = 10, default_timeout: float = 300.0, parallel_tools: bool = True):
|
||||||
if max_steps < 1:
|
if max_steps < 1:
|
||||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||||
self._llm_gateway = llm_gateway
|
self._llm_gateway = llm_gateway
|
||||||
self._max_steps = max_steps
|
self._max_steps = max_steps
|
||||||
self._default_timeout = default_timeout
|
self._default_timeout = default_timeout
|
||||||
|
self._parallel_tools = parallel_tools
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
|
|
@ -293,41 +294,81 @@ class ReActEngine:
|
||||||
}
|
}
|
||||||
conversation.append(assistant_msg)
|
conversation.append(assistant_msg)
|
||||||
|
|
||||||
# Execute each tool call
|
# Execute tool calls
|
||||||
for tc in response.tool_calls:
|
if self._parallel_tools and len(response.tool_calls) > 1:
|
||||||
tool_start = time.monotonic()
|
# Execute multiple tool calls in parallel
|
||||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
tool_results = await asyncio.gather(
|
||||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
*[self._execute_tool(tc.name, tc.arguments, tools) for tc in response.tool_calls],
|
||||||
|
return_exceptions=True,
|
||||||
react_step = ReActStep(
|
|
||||||
step=step,
|
|
||||||
action="tool_call",
|
|
||||||
tool_name=tc.name,
|
|
||||||
arguments=tc.arguments,
|
|
||||||
result=tool_result,
|
|
||||||
tokens=step_tokens,
|
|
||||||
)
|
)
|
||||||
trajectory.append(react_step)
|
for idx, tc in enumerate(response.tool_calls):
|
||||||
|
tool_result = tool_results[idx]
|
||||||
|
if isinstance(tool_result, Exception):
|
||||||
|
tool_result = {"error": str(tool_result)}
|
||||||
|
|
||||||
# Record the tool call step
|
react_step = ReActStep(
|
||||||
if trace_recorder is not None:
|
|
||||||
tool_error = None
|
|
||||||
if isinstance(tool_result, dict) and "error" in tool_result:
|
|
||||||
tool_error = tool_result["error"]
|
|
||||||
trace_recorder.record_step(
|
|
||||||
step=step,
|
step=step,
|
||||||
action="tool_call",
|
action="tool_call",
|
||||||
tool_name=tc.name,
|
tool_name=tc.name,
|
||||||
input_data=tc.arguments,
|
arguments=tc.arguments,
|
||||||
output_data=tool_result,
|
result=tool_result,
|
||||||
duration_ms=tool_duration_ms,
|
tokens=step_tokens,
|
||||||
tokens_used=0,
|
|
||||||
error=tool_error,
|
|
||||||
)
|
)
|
||||||
|
trajectory.append(react_step)
|
||||||
|
|
||||||
# Observe: append the tool result to the conversation history
|
if trace_recorder is not None:
|
||||||
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
tool_error = None
|
||||||
conversation.append(tool_msg)
|
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||||
|
tool_error = tool_result["error"]
|
||||||
|
trace_recorder.record_step(
|
||||||
|
step=step,
|
||||||
|
action="tool_call",
|
||||||
|
tool_name=tc.name,
|
||||||
|
input_data=tc.arguments,
|
||||||
|
output_data=tool_result,
|
||||||
|
duration_ms=0,
|
||||||
|
tokens_used=0,
|
||||||
|
error=tool_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||||
|
conversation.append(tool_msg)
|
||||||
|
else:
|
||||||
|
# Sequential execution (single tool or parallel_tools=False)
|
||||||
|
for tc in response.tool_calls:
|
||||||
|
tool_start = time.monotonic()
|
||||||
|
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||||
|
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||||
|
|
||||||
|
react_step = ReActStep(
|
||||||
|
step=step,
|
||||||
|
action="tool_call",
|
||||||
|
tool_name=tc.name,
|
||||||
|
arguments=tc.arguments,
|
||||||
|
result=tool_result,
|
||||||
|
tokens=step_tokens,
|
||||||
|
)
|
||||||
|
trajectory.append(react_step)
|
||||||
|
|
||||||
|
# Record the tool call step
|
||||||
|
if trace_recorder is not None:
|
||||||
|
tool_error = None
|
||||||
|
if isinstance(tool_result, dict) and "error" in tool_result:
|
||||||
|
tool_error = tool_result["error"]
|
||||||
|
trace_recorder.record_step(
|
||||||
|
step=step,
|
||||||
|
action="tool_call",
|
||||||
|
tool_name=tc.name,
|
||||||
|
input_data=tc.arguments,
|
||||||
|
output_data=tool_result,
|
||||||
|
duration_ms=tool_duration_ms,
|
||||||
|
tokens_used=0,
|
||||||
|
error=tool_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Observe: append the tool result to the conversation history
|
||||||
|
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
|
||||||
|
conversation.append(tool_msg)
|
||||||
|
|
||||||
# Incremental compression: compress conversation if it's getting long
|
# Incremental compression: compress conversation if it's getting long
|
||||||
if self._should_compress(conversation, compressor):
|
if self._should_compress(conversation, compressor):
|
||||||
|
|
@ -475,6 +516,7 @@ class ReActEngine:
|
||||||
retrieval_config: dict[str, Any] | None = None,
|
retrieval_config: dict[str, Any] | None = None,
|
||||||
cancellation_token: CancellationToken | None = None,
|
cancellation_token: CancellationToken | None = None,
|
||||||
timeout_seconds: float | None = None,
|
timeout_seconds: float | None = None,
|
||||||
|
confirmation_handler: Any | None = None,
|
||||||
):
|
):
|
||||||
"""Execute ReAct loop, yielding ReActEvent objects.
|
"""Execute ReAct loop, yielding ReActEvent objects.
|
||||||
|
|
||||||
|
|
@ -627,6 +669,68 @@ class ReActEngine:
|
||||||
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||||
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||||
|
|
||||||
|
# Detect a confirmation request returned by the tool
|
||||||
|
if isinstance(tool_result, dict) and tool_result.get("needs_confirmation"):
|
||||||
|
confirmation_id = tool_result["confirmation_id"]
|
||||||
|
command = tool_result.get("command", "")
|
||||||
|
reason = tool_result.get("reason", "")
|
||||||
|
|
||||||
|
# Yield the confirmation request event
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="confirmation_request",
|
||||||
|
step=step,
|
||||||
|
data={
|
||||||
|
"confirmation_id": confirmation_id,
|
||||||
|
"tool_name": tc.name,
|
||||||
|
"command": command,
|
||||||
|
"reason": reason,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for user confirmation
|
||||||
|
approved = False
|
||||||
|
if confirmation_handler is not None:
|
||||||
|
try:
|
||||||
|
approved = await confirmation_handler(confirmation_id, command, reason)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Confirmation handler error: {e}")
|
||||||
|
|
||||||
|
if approved:
|
||||||
|
# User approved: re-execute, temporarily bypassing the safety check
|
||||||
|
tool = self._find_tool(tc.name, tools)
|
||||||
|
if tool and hasattr(tool, '_is_dangerous'):
|
||||||
|
# Save the original _is_dangerous and temporarily disable it
|
||||||
|
original_is_dangerous = tool._is_dangerous
|
||||||
|
tool._is_dangerous = lambda cmd: False
|
||||||
|
try:
|
||||||
|
tool_result = await tool.safe_execute(**tc.arguments)
|
||||||
|
finally:
|
||||||
|
tool._is_dangerous = original_is_dangerous
|
||||||
|
else:
|
||||||
|
tool_result = await self._execute_tool(tc.name, tc.arguments, tools)
|
||||||
|
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="confirmation_result",
|
||||||
|
step=step,
|
||||||
|
data={"confirmation_id": confirmation_id, "approved": True},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# User rejected execution
|
||||||
|
tool_result = {
|
||||||
|
"output": "",
|
||||||
|
"exit_code": 126,
|
||||||
|
"is_error": True,
|
||||||
|
"error_type": "permission_denied",
|
||||||
|
"message": f"用户拒绝执行命令: {command[:100]}",
|
||||||
|
}
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="confirmation_result",
|
||||||
|
step=step,
|
||||||
|
data={"confirmation_id": confirmation_id, "approved": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_duration_ms = int((time.monotonic() - tool_start) * 1000)
|
||||||
|
|
||||||
react_step = ReActStep(
|
react_step = ReActStep(
|
||||||
step=step,
|
step=step,
|
||||||
action="tool_call",
|
action="tool_call",
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,9 @@ class ProviderConfig:
|
||||||
type: str = "openai" # "openai" | "anthropic" | "gemini"
|
type: str = "openai" # "openai" | "anthropic" | "gemini"
|
||||||
max_tokens: int = 4096 # Anthropic: default max_tokens
|
max_tokens: int = 4096 # Anthropic: default max_tokens
|
||||||
timeout: float = 120.0 # Anthropic: request timeout
|
timeout: float = 120.0 # Anthropic: request timeout
|
||||||
|
max_connections: int = 100 # httpx connection pool: maximum number of connections
|
||||||
|
max_keepalive_connections: int = 20 # httpx connection pool: maximum number of keep-alive connections
|
||||||
|
keepalive_expiry: float = 30.0 # httpx keep-alive connection expiry (seconds)
|
||||||
retry: RetryConfig | None = None
|
retry: RetryConfig | None = None
|
||||||
circuit_breaker: CircuitBreakerConfig | None = None
|
circuit_breaker: CircuitBreakerConfig | None = None
|
||||||
|
|
||||||
|
|
@ -68,6 +71,9 @@ class LLMConfig:
|
||||||
type=pconf.get("type", "openai"),
|
type=pconf.get("type", "openai"),
|
||||||
max_tokens=pconf.get("max_tokens", 4096),
|
max_tokens=pconf.get("max_tokens", 4096),
|
||||||
timeout=pconf.get("timeout", 120.0),
|
timeout=pconf.get("timeout", 120.0),
|
||||||
|
max_connections=pconf.get("max_connections", 100),
|
||||||
|
max_keepalive_connections=pconf.get("max_keepalive_connections", 20),
|
||||||
|
keepalive_expiry=pconf.get("keepalive_expiry", 30.0),
|
||||||
retry=retry,
|
retry=retry,
|
||||||
circuit_breaker=circuit_breaker,
|
circuit_breaker=circuit_breaker,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,9 @@ class AnthropicProvider(LLMProvider):
|
||||||
thinking_enabled: bool = False,
|
thinking_enabled: bool = False,
|
||||||
retry_config: RetryConfig | None = None,
|
retry_config: RetryConfig | None = None,
|
||||||
circuit_breaker_config: CircuitBreakerConfig | None = None,
|
circuit_breaker_config: CircuitBreakerConfig | None = None,
|
||||||
|
max_connections: int = 100,
|
||||||
|
max_keepalive_connections: int = 20,
|
||||||
|
keepalive_expiry: float = 30.0,
|
||||||
):
|
):
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
self._model = model
|
self._model = model
|
||||||
|
|
@ -63,6 +66,11 @@ class AnthropicProvider(LLMProvider):
|
||||||
self._base_url = base_url.rstrip("/")
|
self._base_url = base_url.rstrip("/")
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
self._thinking_enabled = thinking_enabled
|
self._thinking_enabled = thinking_enabled
|
||||||
|
self._limits = httpx.Limits(
|
||||||
|
max_connections=max_connections,
|
||||||
|
max_keepalive_connections=max_keepalive_connections,
|
||||||
|
keepalive_expiry=keepalive_expiry,
|
||||||
|
)
|
||||||
self._client: httpx.AsyncClient | None = None
|
self._client: httpx.AsyncClient | None = None
|
||||||
self._retry_policy = RetryPolicy(retry_config) if retry_config else None
|
self._retry_policy = RetryPolicy(retry_config) if retry_config else None
|
||||||
self._circuit_breaker = (
|
self._circuit_breaker = (
|
||||||
|
|
@ -74,7 +82,7 @@ class AnthropicProvider(LLMProvider):
|
||||||
def _get_client(self) -> httpx.AsyncClient:
|
def _get_client(self) -> httpx.AsyncClient:
|
||||||
"""Lazy client initialization"""
|
"""Lazy client initialization"""
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
self._client = httpx.AsyncClient(timeout=self._timeout)
|
self._client = httpx.AsyncClient(timeout=self._timeout, limits=self._limits)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,9 @@ class GeminiProvider(LLMProvider):
|
||||||
safety_settings: list | None = None,
|
safety_settings: list | None = None,
|
||||||
retry_config: RetryConfig | None = None,
|
retry_config: RetryConfig | None = None,
|
||||||
circuit_breaker_config: CircuitBreakerConfig | None = None,
|
circuit_breaker_config: CircuitBreakerConfig | None = None,
|
||||||
|
max_connections: int = 100,
|
||||||
|
max_keepalive_connections: int = 20,
|
||||||
|
keepalive_expiry: float = 30.0,
|
||||||
):
|
):
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
self._model = model
|
self._model = model
|
||||||
|
|
@ -60,6 +63,11 @@ class GeminiProvider(LLMProvider):
|
||||||
self._base_url = base_url.rstrip("/")
|
self._base_url = base_url.rstrip("/")
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
self._safety_settings = safety_settings
|
self._safety_settings = safety_settings
|
||||||
|
self._limits = httpx.Limits(
|
||||||
|
max_connections=max_connections,
|
||||||
|
max_keepalive_connections=max_keepalive_connections,
|
||||||
|
keepalive_expiry=keepalive_expiry,
|
||||||
|
)
|
||||||
self._client: httpx.AsyncClient | None = None
|
self._client: httpx.AsyncClient | None = None
|
||||||
self._retry_policy = RetryPolicy(retry_config) if retry_config else None
|
self._retry_policy = RetryPolicy(retry_config) if retry_config else None
|
||||||
self._circuit_breaker = (
|
self._circuit_breaker = (
|
||||||
|
|
@ -71,7 +79,7 @@ class GeminiProvider(LLMProvider):
|
||||||
def _get_client(self) -> httpx.AsyncClient:
|
def _get_client(self) -> httpx.AsyncClient:
|
||||||
"""Lazy client initialization"""
|
"""Lazy client initialization"""
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
self._client = httpx.AsyncClient(timeout=self._timeout)
|
self._client = httpx.AsyncClient(timeout=self._timeout, limits=self._limits)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
|
|
||||||
|
|
@ -46,11 +46,19 @@ class OpenAICompatibleProvider(LLMProvider):
|
||||||
default_model: str = "gpt-4o-mini",
|
default_model: str = "gpt-4o-mini",
|
||||||
retry_config: RetryConfig | None = None,
|
retry_config: RetryConfig | None = None,
|
||||||
circuit_breaker_config: CircuitBreakerConfig | None = None,
|
circuit_breaker_config: CircuitBreakerConfig | None = None,
|
||||||
|
max_connections: int = 100,
|
||||||
|
max_keepalive_connections: int = 20,
|
||||||
|
keepalive_expiry: float = 30.0,
|
||||||
):
|
):
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
self._base_url = base_url.rstrip("/")
|
self._base_url = base_url.rstrip("/")
|
||||||
self._default_model = default_model
|
self._default_model = default_model
|
||||||
self._client = httpx.AsyncClient(timeout=60.0)
|
limits = httpx.Limits(
|
||||||
|
max_connections=max_connections,
|
||||||
|
max_keepalive_connections=max_keepalive_connections,
|
||||||
|
keepalive_expiry=keepalive_expiry,
|
||||||
|
)
|
||||||
|
self._client = httpx.AsyncClient(timeout=60.0, limits=limits)
|
||||||
self._retry_policy = RetryPolicy(retry_config) if retry_config else None
|
self._retry_policy = RetryPolicy(retry_config) if retry_config else None
|
||||||
self._circuit_breaker = (
|
self._circuit_breaker = (
|
||||||
CircuitBreaker(circuit_breaker_config, provider="openai")
|
CircuitBreaker(circuit_breaker_config, provider="openai")
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ from agentkit.telemetry.setup import setup_telemetry
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_ALLOWED_ENV_PREFIXES = (
|
_ALLOWED_ENV_PREFIXES = (
|
||||||
'AGENTKIT_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
|
'AGENTKIT_', 'DASHSCOPE_', 'OPENAI_', 'ANTHROPIC_', 'GEMINI_',
|
||||||
'TAVILY_', 'SERPER_', 'DEEPSEEK_',
|
'TAVILY_', 'SERPER_', 'DEEPSEEK_',
|
||||||
)
|
)
|
||||||
_ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'}
|
_ALLOWED_ENV_EXACT = {'DATABASE_URL', 'REDIS_URL'}
|
||||||
|
|
@ -52,6 +52,9 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
||||||
max_tokens=pconf.max_tokens,
|
max_tokens=pconf.max_tokens,
|
||||||
base_url=pconf.base_url or "https://api.anthropic.com",
|
base_url=pconf.base_url or "https://api.anthropic.com",
|
||||||
timeout=pconf.timeout,
|
timeout=pconf.timeout,
|
||||||
|
max_connections=pconf.max_connections,
|
||||||
|
max_keepalive_connections=pconf.max_keepalive_connections,
|
||||||
|
keepalive_expiry=pconf.keepalive_expiry,
|
||||||
)
|
)
|
||||||
elif pconf.type == "gemini":
|
elif pconf.type == "gemini":
|
||||||
provider = GeminiProvider(
|
provider = GeminiProvider(
|
||||||
|
|
@ -60,11 +63,17 @@ def _build_llm_gateway(config: ServerConfig) -> LLMGateway:
|
||||||
max_output_tokens=pconf.max_tokens,
|
max_output_tokens=pconf.max_tokens,
|
||||||
base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
|
base_url=pconf.base_url or "https://generativelanguage.googleapis.com",
|
||||||
timeout=pconf.timeout,
|
timeout=pconf.timeout,
|
||||||
|
max_connections=pconf.max_connections,
|
||||||
|
max_keepalive_connections=pconf.max_keepalive_connections,
|
||||||
|
keepalive_expiry=pconf.keepalive_expiry,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
provider = OpenAICompatibleProvider(
|
provider = OpenAICompatibleProvider(
|
||||||
api_key=pconf.api_key,
|
api_key=pconf.api_key,
|
||||||
base_url=pconf.base_url,
|
base_url=pconf.base_url,
|
||||||
|
max_connections=pconf.max_connections,
|
||||||
|
max_keepalive_connections=pconf.max_keepalive_connections,
|
||||||
|
keepalive_expiry=pconf.keepalive_expiry,
|
||||||
)
|
)
|
||||||
gateway.register_provider(name, provider)
|
gateway.register_provider(name, provider)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -146,7 +155,7 @@ async def lifespan(app: FastAPI):
|
||||||
"serper_api_key": os.environ.get("SERPER_API_KEY"),
|
"serper_api_key": os.environ.get("SERPER_API_KEY"),
|
||||||
}
|
}
|
||||||
agent._tool_registry.register(MemoryTool(memory_store=memory_store))
|
agent._tool_registry.register(MemoryTool(memory_store=memory_store))
|
||||||
agent._tool_registry.register(ShellTool(working_dir=os.getcwd()))
|
agent._tool_registry.register(ShellTool())
|
||||||
agent._tool_registry.register(BaiduSearchTool())
|
agent._tool_registry.register(BaiduSearchTool())
|
||||||
agent._tool_registry.register(WebSearchTool(**search_api_keys))
|
agent._tool_registry.register(WebSearchTool(**search_api_keys))
|
||||||
agent._tool_registry.register(WebCrawlTool())
|
agent._tool_registry.register(WebCrawlTool())
|
||||||
|
|
@ -472,6 +481,7 @@ def create_app(
|
||||||
llm_gateway=app.state.llm_gateway,
|
llm_gateway=app.state.llm_gateway,
|
||||||
org_context=org_context,
|
org_context=org_context,
|
||||||
auction_enabled=auction_enabled,
|
auction_enabled=auction_enabled,
|
||||||
|
classifier=server_config.router.get("classifier", "heuristic") if server_config and server_config.router else "heuristic",
|
||||||
)
|
)
|
||||||
app.state.cost_aware_router = cost_aware_router
|
app.state.cost_aware_router = cost_aware_router
|
||||||
# Initialize task store from config
|
# Initialize task store from config
|
||||||
|
|
|
||||||
|
|
@ -110,6 +110,7 @@ class ServerConfig:
|
||||||
bus: dict[str, Any] | None = None,
|
bus: dict[str, Any] | None = None,
|
||||||
marketplace: dict[str, Any] | None = None,
|
marketplace: dict[str, Any] | None = None,
|
||||||
alignment: dict[str, Any] | None = None,
|
alignment: dict[str, Any] | None = None,
|
||||||
|
router: dict[str, Any] | None = None,
|
||||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||||
):
|
):
|
||||||
self.host = host
|
self.host = host
|
||||||
|
|
@ -132,6 +133,7 @@ class ServerConfig:
|
||||||
self.bus = bus or {}
|
self.bus = bus or {}
|
||||||
self.marketplace = marketplace or {}
|
self.marketplace = marketplace or {}
|
||||||
self.alignment = alignment or {}
|
self.alignment = alignment or {}
|
||||||
|
self.router = router or {}
|
||||||
self.on_change = on_change
|
self.on_change = on_change
|
||||||
|
|
||||||
# Config watching state
|
# Config watching state
|
||||||
|
|
@ -196,6 +198,9 @@ class ServerConfig:
|
||||||
# Alignment config
|
# Alignment config
|
||||||
alignment_data = data.get("alignment", {})
|
alignment_data = data.get("alignment", {})
|
||||||
|
|
||||||
|
# Router config
|
||||||
|
router_data = data.get("router", {})
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
host=server.get("host", "0.0.0.0"),
|
host=server.get("host", "0.0.0.0"),
|
||||||
port=server.get("port", 8001),
|
port=server.get("port", 8001),
|
||||||
|
|
@ -217,6 +222,7 @@ class ServerConfig:
|
||||||
bus=server.get("bus"),
|
bus=server.get("bus"),
|
||||||
marketplace=marketplace_data,
|
marketplace=marketplace_data,
|
||||||
alignment=alignment_data,
|
alignment=alignment_data,
|
||||||
|
router=router_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -411,6 +417,7 @@ class ServerConfig:
|
||||||
self.session = new_config.session
|
self.session = new_config.session
|
||||||
self.marketplace = new_config.marketplace
|
self.marketplace = new_config.marketplace
|
||||||
self.alignment = new_config.alignment
|
self.alignment = new_config.alignment
|
||||||
|
self.router = new_config.router
|
||||||
self._last_mtime = new_config._last_mtime
|
self._last_mtime = new_config._last_mtime
|
||||||
|
|
||||||
logger.info(f"Config reloaded from {path}")
|
logger.info(f"Config reloaded from {path}")
|
||||||
|
|
|
||||||
|
|
@ -279,8 +279,9 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||||
await websocket.close(code=1000, reason="Session closed")
|
await websocket.close(code=1000, reason="Session closed")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Track pending replies for AskHumanTool
|
# Track pending replies for AskHumanTool and confirmations
|
||||||
pending_replies: dict[str, asyncio.Future] = {}
|
pending_replies: dict[str, asyncio.Future] = {}
|
||||||
|
pending_confirmations: dict[str, asyncio.Future] = {}
|
||||||
chat_manager.add(session_id, websocket, pending_replies)
|
chat_manager.add(session_id, websocket, pending_replies)
|
||||||
|
|
||||||
cancellation_token = CancellationToken()
|
cancellation_token = CancellationToken()
|
||||||
|
|
@ -308,7 +309,7 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||||
# Create a fresh CancellationToken for each message
|
# Create a fresh CancellationToken for each message
|
||||||
message_token = CancellationToken()
|
message_token = CancellationToken()
|
||||||
await _handle_chat_message(
|
await _handle_chat_message(
|
||||||
websocket, session_id, content, sm, message_token, pending_replies
|
websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations
|
||||||
)
|
)
|
||||||
|
|
||||||
elif msg_type == "reply":
|
elif msg_type == "reply":
|
||||||
|
|
@ -318,6 +319,13 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||||
if request_id and request_id in pending_replies:
|
if request_id and request_id in pending_replies:
|
||||||
pending_replies[request_id].set_result(reply_content)
|
pending_replies[request_id].set_result(reply_content)
|
||||||
|
|
||||||
|
elif msg_type == "confirmation_reply":
|
||||||
|
# Reply to confirmation request
|
||||||
|
confirmation_id = msg.get("confirmation_id")
|
||||||
|
approved = msg.get("approved", False)
|
||||||
|
if confirmation_id and confirmation_id in pending_confirmations:
|
||||||
|
pending_confirmations[confirmation_id].set_result(approved)
|
||||||
|
|
||||||
elif msg_type == "cancel":
|
elif msg_type == "cancel":
|
||||||
cancellation_token.cancel()
|
cancellation_token.cancel()
|
||||||
await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})
|
await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})
|
||||||
|
|
@ -338,6 +346,9 @@ async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||||
for fut in pending_replies.values():
|
for fut in pending_replies.values():
|
||||||
if not fut.done():
|
if not fut.done():
|
||||||
fut.cancel()
|
fut.cancel()
|
||||||
|
for fut in pending_confirmations.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
chat_manager.remove(session_id, websocket)
|
chat_manager.remove(session_id, websocket)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -348,6 +359,7 @@ async def _handle_chat_message(
|
||||||
sm: SessionManager,
|
sm: SessionManager,
|
||||||
cancellation_token: CancellationToken,
|
cancellation_token: CancellationToken,
|
||||||
pending_replies: dict[str, asyncio.Future],
|
pending_replies: dict[str, asyncio.Future],
|
||||||
|
pending_confirmations: dict[str, asyncio.Future] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle a user message: append to session, execute Agent, stream events.
|
"""Handle a user message: append to session, execute Agent, stream events.
|
||||||
|
|
||||||
|
|
@ -414,6 +426,35 @@ async def _handle_chat_message(
|
||||||
# Execute Agent with streaming
|
# Execute Agent with streaming
|
||||||
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
||||||
|
|
||||||
|
# Create confirmation handler that sends request to frontend and waits for reply
|
||||||
|
_pending_confirmations = pending_confirmations or {}
|
||||||
|
|
||||||
|
async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool:
|
||||||
|
"""Send confirmation request to frontend via WebSocket and wait for user reply."""
|
||||||
|
# Send confirmation request to frontend
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "confirmation_request",
|
||||||
|
"data": {
|
||||||
|
"confirmation_id": confirmation_id,
|
||||||
|
"command": command,
|
||||||
|
"reason": reason,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
# Create a Future and wait for the user's reply
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
future: asyncio.Future[bool] = loop.create_future()
|
||||||
|
_pending_confirmations[confirmation_id] = future
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wait up to 5 minutes for user confirmation
|
||||||
|
return await asyncio.wait_for(future, timeout=300.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"Confirmation request {confirmation_id} timed out")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
_pending_confirmations.pop(confirmation_id, None)
|
||||||
|
|
||||||
logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}")
|
logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -425,6 +466,7 @@ async def _handle_chat_message(
|
||||||
agent_name=routing.agent_name,
|
agent_name=routing.agent_name,
|
||||||
system_prompt=routing.system_prompt,
|
system_prompt=routing.system_prompt,
|
||||||
cancellation_token=cancellation_token,
|
cancellation_token=cancellation_token,
|
||||||
|
confirmation_handler=_confirmation_handler,
|
||||||
):
|
):
|
||||||
if event.event_type == "final_answer":
|
if event.event_type == "final_answer":
|
||||||
final_content = event.data.get("output", "")
|
final_content = event.data.get("output", "")
|
||||||
|
|
@ -432,6 +474,14 @@ async def _handle_chat_message(
|
||||||
"type": "final_answer",
|
"type": "final_answer",
|
||||||
"content": final_content,
|
"content": final_content,
|
||||||
})
|
})
|
||||||
|
elif event.event_type == "confirmation_request":
|
||||||
|
# Already sent to the frontend by confirmation_handler; nothing more to do here
|
||||||
|
pass
|
||||||
|
elif event.event_type == "confirmation_result":
|
||||||
|
await websocket.send_json({
|
||||||
|
"type": "confirmation_result",
|
||||||
|
"data": event.data,
|
||||||
|
})
|
||||||
elif event.event_type == "token":
|
elif event.event_type == "token":
|
||||||
await websocket.send_json({
|
await websocket.send_json({
|
||||||
"type": "token",
|
"type": "token",
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
||||||
|
|
@ -11,15 +13,125 @@ from agentkit.session.store import InMemorySessionStore, SessionStore
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncWriteQueue:
|
||||||
|
"""Background write-ahead queue for non-blocking session persistence.
|
||||||
|
|
||||||
|
Accepts write operations (append_message + save_session) as tasks,
|
||||||
|
executes them in a background ``asyncio.Task``, and maintains a small
|
||||||
|
in-memory WAL buffer so that pending messages can be read before they are persisted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, store: SessionStore, max_buffer_size: int = 256) -> None:
|
||||||
|
self._store = store
|
||||||
|
self._queue: asyncio.Queue[tuple[Message, Session] | None] | None = None
|
||||||
|
self._worker: asyncio.Task | None = None
|
||||||
|
# WAL buffer: session_id -> list of Messages not yet persisted
|
||||||
|
self._wal_buffer: dict[str, list[Message]] = defaultdict(list)
|
||||||
|
self._max_buffer_size = max_buffer_size
|
||||||
|
self._pending_count = 0
|
||||||
|
|
||||||
|
def _ensure_started(self) -> None:
|
||||||
|
"""Start the background writer task if not already running (lazy init)."""
|
||||||
|
if self._worker is not None and not self._worker.done():
|
||||||
|
return
|
||||||
|
self._queue = asyncio.Queue()
|
||||||
|
self._worker = asyncio.create_task(self._writer_loop())
|
||||||
|
|
||||||
|
async def _writer_loop(self) -> None:
|
||||||
|
"""Consume write tasks from the queue and persist them."""
|
||||||
|
assert self._queue is not None
|
||||||
|
while True:
|
||||||
|
item = await self._queue.get()
|
||||||
|
if item is None:
|
||||||
|
# Sentinel: graceful shutdown signal
|
||||||
|
self._queue.task_done()
|
||||||
|
break
|
||||||
|
message, session = item
|
||||||
|
try:
|
||||||
|
await self._store.append_message(message)
|
||||||
|
session.updated_at = __import__("datetime").datetime.now(
|
||||||
|
__import__("datetime").timezone.utc
|
||||||
|
)
|
||||||
|
await self._store.save_session(session)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"AsyncWriteQueue: failed to persist message %s for session %s",
|
||||||
|
message.message_id,
|
||||||
|
message.session_id,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# Remove from WAL buffer once the write attempt finishes (whether it succeeded or failed)
|
||||||
|
buf = self._wal_buffer.get(message.session_id)
|
||||||
|
if buf is not None:
|
||||||
|
try:
|
||||||
|
buf.remove(message)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
if not buf:
|
||||||
|
self._wal_buffer.pop(message.session_id, None)
|
||||||
|
self._pending_count -= 1
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
def enqueue(self, message: Message, session: Session) -> None:
|
||||||
|
"""Enqueue a write task without blocking.
|
||||||
|
|
||||||
|
The message is immediately added to the WAL buffer so that
|
||||||
|
``get_chat_messages`` can see it before persistence completes.
|
||||||
|
Lazily starts the background writer on first call.
|
||||||
|
"""
|
||||||
|
self._ensure_started()
|
||||||
|
assert self._queue is not None
|
||||||
|
self._wal_buffer[message.session_id].append(message)
|
||||||
|
self._pending_count += 1
|
||||||
|
self._queue.put_nowait((message, session))
|
||||||
|
|
||||||
|
def buffered_messages(self, session_id: str) -> list[Message]:
|
||||||
|
"""Return WAL-buffered messages for *session_id* not yet persisted."""
|
||||||
|
return list(self._wal_buffer.get(session_id, []))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
"""Number of write tasks waiting in the queue."""
|
||||||
|
return self._pending_count
|
||||||
|
|
||||||
|
async def flush(self) -> None:
|
||||||
|
"""Wait until all queued writes have been persisted."""
|
||||||
|
if self._queue is not None:
|
||||||
|
await self._queue.join()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""Signal the writer to stop and wait for it to finish."""
|
||||||
|
if self._queue is not None and self._worker is not None:
|
||||||
|
await self._queue.put(None) # sentinel
|
||||||
|
await self._worker
|
||||||
|
self._worker = None
|
||||||
|
self._queue = None
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
class SessionManager:
|
||||||
"""Manages conversation sessions and their messages.
|
"""Manages conversation sessions and their messages.
|
||||||
|
|
||||||
Provides a high-level API for creating, querying, and updating
|
Provides a high-level API for creating, querying, and updating
|
||||||
sessions, as well as appending and retrieving messages.
|
sessions, as well as appending and retrieving messages.
|
||||||
|
|
||||||
|
When ``async_writes=True``, ``append_message`` is non-blocking: the
|
||||||
|
message is placed in a write-ahead buffer and persisted in the
|
||||||
|
background. Call ``flush()`` or ``close()`` to ensure all writes
|
||||||
|
are durably persisted before shutdown.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, store: SessionStore | None = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
store: SessionStore | None = None,
|
||||||
|
*,
|
||||||
|
async_writes: bool = False,
|
||||||
|
wal_buffer_size: int = 256,
|
||||||
|
):
|
||||||
self._store = store or InMemorySessionStore()
|
self._store = store or InMemorySessionStore()
|
||||||
|
self._async_writes = async_writes
|
||||||
|
self._write_queue: AsyncWriteQueue | None = None
|
||||||
|
if async_writes:
|
||||||
|
self._write_queue = AsyncWriteQueue(self._store, max_buffer_size=wal_buffer_size)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def store(self) -> SessionStore:
|
def store(self) -> SessionStore:
|
||||||
|
|
@ -61,7 +173,13 @@ class SessionManager:
|
||||||
return await self._store.update_session_status(session_id, SessionStatus.ACTIVE)
|
return await self._store.update_session_status(session_id, SessionStatus.ACTIVE)
|
||||||
|
|
||||||
async def close_session(self, session_id: str) -> Session | None:
|
async def close_session(self, session_id: str) -> Session | None:
|
||||||
"""Close a session. Closed sessions cannot accept new messages."""
|
"""Close a session. Closed sessions cannot accept new messages.
|
||||||
|
|
||||||
|
When async writes are enabled, flushes pending writes before
|
||||||
|
updating the session status.
|
||||||
|
"""
|
||||||
|
if self._write_queue is not None:
|
||||||
|
await self._write_queue.flush()
|
||||||
return await self._store.update_session_status(session_id, SessionStatus.CLOSED)
|
return await self._store.update_session_status(session_id, SessionStatus.CLOSED)
|
||||||
|
|
||||||
async def delete_session(self, session_id: str) -> bool:
|
async def delete_session(self, session_id: str) -> bool:
|
||||||
|
|
@ -116,11 +234,15 @@ class SessionManager:
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
)
|
)
|
||||||
await self._store.append_message(message)
|
|
||||||
|
|
||||||
# Update session's updated_at timestamp
|
if self._write_queue is not None:
|
||||||
session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc)
|
# Non-blocking: enqueue write, return immediately
|
||||||
await self._store.save_session(session)
|
self._write_queue.enqueue(message, session)
|
||||||
|
else:
|
||||||
|
# Synchronous path (default, backward-compatible)
|
||||||
|
await self._store.append_message(message)
|
||||||
|
session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc)
|
||||||
|
await self._store.save_session(session)
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
@ -132,6 +254,9 @@ class SessionManager:
|
||||||
) -> list[Message]:
|
) -> list[Message]:
|
||||||
"""Get messages for a session with optional pagination.
|
"""Get messages for a session with optional pagination.
|
||||||
|
|
||||||
|
When async writes are enabled, includes WAL-buffered messages
|
||||||
|
not yet persisted to the store.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id: Target session ID.
|
session_id: Target session ID.
|
||||||
limit: Maximum number of messages to return. None for all.
|
limit: Maximum number of messages to return. None for all.
|
||||||
|
|
@ -140,15 +265,30 @@ class SessionManager:
|
||||||
Returns:
|
Returns:
|
||||||
List of messages ordered chronologically.
|
List of messages ordered chronologically.
|
||||||
"""
|
"""
|
||||||
return await self._store.get_messages(session_id, limit=limit, offset=offset)
|
persisted = await self._store.get_messages(session_id, limit=None, offset=0)
|
||||||
|
if self._write_queue is not None:
|
||||||
|
buffered = self._write_queue.buffered_messages(session_id)
|
||||||
|
if buffered:
|
||||||
|
# Merge: persisted + buffered (dedup by message_id)
|
||||||
|
seen = {m.message_id for m in persisted}
|
||||||
|
for m in buffered:
|
||||||
|
if m.message_id not in seen:
|
||||||
|
persisted.append(m)
|
||||||
|
sliced = persisted[offset:]
|
||||||
|
if limit is not None:
|
||||||
|
sliced = sliced[:limit]
|
||||||
|
return sliced
|
||||||
|
|
||||||
async def get_chat_messages(self, session_id: str) -> list[dict[str, str]]:
|
async def get_chat_messages(self, session_id: str) -> list[dict[str, str]]:
|
||||||
"""Get messages formatted for LLM chat API consumption.
|
"""Get messages formatted for LLM chat API consumption.
|
||||||
|
|
||||||
Returns messages as OpenAI-compatible dicts suitable for
|
Returns messages as OpenAI-compatible dicts suitable for
|
||||||
passing directly to the ReAct engine or LLM Gateway.
|
passing directly to the ReAct engine or LLM Gateway.
|
||||||
|
|
||||||
|
When async writes are enabled, includes WAL-buffered messages
|
||||||
|
not yet persisted to the store.
|
||||||
"""
|
"""
|
||||||
messages = await self._store.get_messages(session_id)
|
messages = await self.get_messages(session_id)
|
||||||
return [m.to_chat_message() for m in messages]
|
return [m.to_chat_message() for m in messages]
|
||||||
|
|
||||||
async def count_messages(self, session_id: str) -> int:
|
async def count_messages(self, session_id: str) -> int:
|
||||||
|
|
@ -158,3 +298,21 @@ class SessionManager:
|
||||||
async def health_check(self) -> bool:
|
async def health_check(self) -> bool:
|
||||||
"""Check if the underlying store is healthy."""
|
"""Check if the underlying store is healthy."""
|
||||||
return await self._store.health_check()
|
return await self._store.health_check()
|
||||||
|
|
||||||
|
async def flush(self) -> None:
|
||||||
|
"""Wait for all pending async writes to be persisted.
|
||||||
|
|
||||||
|
No-op when async writes are not enabled.
|
||||||
|
"""
|
||||||
|
if self._write_queue is not None:
|
||||||
|
await self._write_queue.flush()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Flush pending writes and stop the background writer.
|
||||||
|
|
||||||
|
Should be called during application shutdown when async writes
|
||||||
|
are enabled. Safe to call even when async writes are disabled.
|
||||||
|
"""
|
||||||
|
if self._write_queue is not None:
|
||||||
|
await self._write_queue.flush()
|
||||||
|
await self._write_queue.stop()
|
||||||
|
|
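The write-ahead pattern in the diff above (buffer first, persist in the background, drain on `flush()`) can be sketched in isolation. This is a minimal illustration under stated assumptions, not the project's `AsyncWriteQueue`: `TinyWriteQueue`, `persist`, and `demo` are hypothetical names, and messages are reduced to plain strings.

```python
import asyncio
from collections import defaultdict


class TinyWriteQueue:
    """Toy stand-in for the buffered write queue (hypothetical, illustration only)."""

    def __init__(self, persist):
        self._persist = persist            # async callable: (session_id, text) -> None
        self._queue = asyncio.Queue()
        self._buffer = defaultdict(list)   # session_id -> texts not yet persisted
        self._worker = None

    def enqueue(self, session_id, text):
        # Start the worker lazily, then make the text visible to readers at once.
        if self._worker is None:
            self._worker = asyncio.get_running_loop().create_task(self._run())
        self._buffer[session_id].append(text)
        self._queue.put_nowait((session_id, text))

    def buffered(self, session_id):
        return list(self._buffer.get(session_id, []))

    async def _run(self):
        while True:
            item = await self._queue.get()
            if item is None:               # sentinel: stop the worker
                self._queue.task_done()
                return
            session_id, text = item
            try:
                await self._persist(session_id, text)
            finally:
                # Runs on success and on failure, like the diff's finally block.
                self._buffer[session_id].remove(text)
                self._queue.task_done()

    async def flush(self):
        await self._queue.join()

    async def stop(self):
        if self._worker is not None:
            await self._queue.put(None)
            await self._worker
            self._worker = None


async def demo():
    store = []

    async def persist(session_id, text):
        await asyncio.sleep(0)             # stand-in for real I/O
        store.append((session_id, text))

    q = TinyWriteQueue(persist)
    q.enqueue("s1", "hello")
    pending = q.buffered("s1")             # readable before persistence completes
    await q.flush()
    await q.stop()
    return pending, store, q.buffered("s1")
```

Running `asyncio.run(demo())` shows the message in the buffer before `flush()` and in the store afterwards. Because the buffer entry is removed in `finally`, a failed write also leaves the buffer, so a production version would probably want a retry or dead-letter path.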
|
||||||
|
|
@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult, _tokenize_content
|
from agentkit.chat.skill_routing import CostAwareRouter, HeuristicClassifier, SkillRoutingResult, _tokenize_content
|
||||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||||
from agentkit.router.intent import IntentRouter, RoutingResult
|
from agentkit.router.intent import IntentRouter, RoutingResult
|
||||||
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
|
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
|
||||||
|
|
@ -510,3 +510,122 @@ class TestTokenizeContent:
|
||||||
expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"]
|
expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"]
|
||||||
for bigram in expected_bigrams:
|
for bigram in expected_bigrams:
|
||||||
assert bigram in tokens, f"Missing 2-gram: {bigram}"
|
assert bigram in tokens, f"Missing 2-gram: {bigram}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HeuristicClassifier
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeuristicClassifier:
|
||||||
|
"""Tests for the local HeuristicClassifier."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
self.classifier = HeuristicClassifier()
|
||||||
|
|
||||||
|
def test_short_greeting_low_complexity(self):
|
||||||
|
"""Short greeting → low complexity"""
|
||||||
|
score = self.classifier.classify("你好呀")
|
||||||
|
assert score < 0.3
|
||||||
|
|
||||||
|
def test_simple_question_medium_complexity(self):
|
||||||
|
"""Simple question containing '如何' (how to) → medium complexity"""
|
||||||
|
score = self.classifier.classify("如何使用这个功能?")
|
||||||
|
assert 0.3 <= score <= 0.7
|
||||||
|
|
||||||
|
def test_tool_request_high_complexity(self):
|
||||||
|
"""Request containing tool keywords → high complexity"""
|
||||||
|
score = self.classifier.classify("帮我搜索一下最新的新闻")
|
||||||
|
assert score > 0.5
|
||||||
|
|
||||||
|
def test_code_request_high_complexity(self):
|
||||||
|
"""Code-related request → high complexity"""
|
||||||
|
score = self.classifier.classify("写一个Python函数实现快速排序")
|
||||||
|
assert score > 0.6
|
||||||
|
|
||||||
|
def test_multi_step_request_high_complexity(self):
|
||||||
|
"""Multi-step analysis request → high complexity"""
|
||||||
|
score = self.classifier.classify("分析这个数据,比较不同方案的优缺点,然后给出推荐")
|
||||||
|
assert score > 0.7
|
||||||
|
|
||||||
|
def test_empty_string_zero_complexity(self):
|
||||||
|
"""Empty string → zero complexity"""
|
||||||
|
assert self.classifier.classify("") == 0.0
|
||||||
|
assert self.classifier.classify(" ") == 0.0
|
||||||
|
|
||||||
|
def test_long_message_higher_complexity(self):
|
||||||
|
"""Longer message → higher complexity"""
|
||||||
|
short = "帮我查一下"
|
||||||
|
long = "帮我查一下" + "关于机器学习和深度学习的最新进展" * 10
|
||||||
|
assert self.classifier.classify(long) > self.classifier.classify(short)
|
||||||
|
|
||||||
|
def test_code_patterns_boost_complexity(self):
|
||||||
|
"""Code patterns (backticks/parentheses) raise complexity"""
|
||||||
|
with_code = "运行这段代码 `print('hello')`"
|
||||||
|
without_code = "运行这段代码 print hello"
|
||||||
|
assert self.classifier.classify(with_code) > self.classifier.classify(without_code)
|
||||||
|
|
||||||
|
def test_score_bounded_0_to_1(self):
|
||||||
|
"""Complexity score always stays within [0.0, 1.0]"""
|
||||||
|
test_inputs = [
|
||||||
|
"", "你好", "如何做", "帮我搜索并分析数据,设计一个完整的解决方案,包含代码实现和部署配置",
|
||||||
|
]
|
||||||
|
for inp in test_inputs:
|
||||||
|
score = self.classifier.classify(inp)
|
||||||
|
assert 0.0 <= score <= 1.0, f"Score {score} out of range for '{inp}'"
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeuristicClassifierIntegration:
|
||||||
|
"""Integration tests for HeuristicClassifier inside CostAwareRouter"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_heuristic_mode_no_llm_call(self):
|
||||||
|
"""heuristic mode makes no LLM call"""
|
||||||
|
gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
|
||||||
|
router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic")
|
||||||
|
result = await router.route(
|
||||||
|
content="帮我分析一下数据",
|
||||||
|
skill_registry=_make_skill_registry(),
|
||||||
|
intent_router=_make_intent_router(),
|
||||||
|
default_tools=[],
|
||||||
|
default_system_prompt="You are helpful.",
|
||||||
|
)
|
||||||
|
# gateway.chat must not be called
|
||||||
|
gateway.chat.assert_not_called()
|
||||||
|
# Complexity should come from the heuristic classifier
|
||||||
|
assert result.complexity > 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_mode_uses_llm(self):
|
||||||
|
"""llm mode calls the LLM quick_classify"""
|
||||||
|
gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
|
||||||
|
router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="llm")
|
||||||
|
result = await router.route(
|
||||||
|
content="帮我分析一下数据",
|
||||||
|
skill_registry=_make_skill_registry(),
|
||||||
|
intent_router=_make_intent_router(),
|
||||||
|
default_tools=[],
|
||||||
|
default_system_prompt="You are helpful.",
|
||||||
|
)
|
||||||
|
# gateway.chat should be called
|
||||||
|
gateway.chat.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_heuristic_greeting_still_layer0(self):
|
||||||
|
"""Greetings still take Layer 0 in heuristic mode"""
|
||||||
|
router = CostAwareRouter(classifier="heuristic")
|
||||||
|
result = await router.route(
|
||||||
|
content="你好",
|
||||||
|
skill_registry=_make_skill_registry(),
|
||||||
|
intent_router=_make_intent_router(),
|
||||||
|
default_tools=[],
|
||||||
|
default_system_prompt="You are helpful.",
|
||||||
|
)
|
||||||
|
assert result.match_method == "greeting"
|
||||||
|
assert result.complexity == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_heuristic_default_classifier_mode(self):
|
||||||
|
"""The default classifier mode is heuristic"""
|
||||||
|
router = CostAwareRouter()
|
||||||
|
assert router._classifier == "heuristic"
|
||||||
|
|
|
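The tests above pin down the behaviour expected from a local classifier: empty input scores zero, longer input scores higher, tool keywords and code patterns raise the score, and the result stays in [0.0, 1.0]. A rough sketch of such a scorer follows. The keyword lists, weights, and the `classify` function are assumptions made for illustration, use English keywords, and are not the `HeuristicClassifier` implementation from this commit.

```python
import re

# Hypothetical keyword lists and weights, chosen only for illustration.
TOOL_KEYWORDS = ("search", "look up", "fetch", "run")
REASONING_KEYWORDS = ("analyze", "compare", "design", "recommend")
# Inline code in backticks, or something that looks like a call: name(...)
CODE_PATTERN = re.compile(r"`[^`]+`|\w+\(.*?\)")


def classify(text: str) -> float:
    """Return a rough complexity score in [0.0, 1.0] without calling an LLM."""
    stripped = text.strip()
    if not stripped:
        return 0.0
    lowered = stripped.lower()
    score = min(len(stripped) / 200, 0.3)                # length adds up to 0.3
    if any(k in lowered for k in TOOL_KEYWORDS):
        score += 0.3                                     # tool-style request
    score += 0.2 * sum(k in lowered for k in REASONING_KEYWORDS)
    if CODE_PATTERN.search(stripped):
        score += 0.2                                     # code-looking content
    return min(score, 1.0)
```

Because the scorer is pure string matching, it adds no network round trip before the first token, which is the motivation stated in the commit message. The trade-off is that substring matching is crude, so thresholds would need tuning against real traffic.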
||||||
|
|
@ -12,6 +12,11 @@ def manager():
|
||||||
return SessionManager(store=InMemorySessionStore())
|
return SessionManager(store=InMemorySessionStore())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def async_manager():
|
||||||
|
return SessionManager(store=InMemorySessionStore(), async_writes=True)
|
||||||
|
|
||||||
|
|
||||||
class TestSessionManagerCreate:
|
class TestSessionManagerCreate:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_session(self, manager):
|
async def test_create_session(self, manager):
|
||||||
|
|
@ -197,3 +202,137 @@ class TestSessionManagerHealth:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_health_check(self, manager):
|
async def test_health_check(self, manager):
|
||||||
assert await manager.health_check() is True
|
assert await manager.health_check() is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncWrites:
|
||||||
|
"""Tests for async (non-blocking) write behaviour."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_append_message_returns_immediately(self, async_manager):
|
||||||
|
"""append_message returns the Message before it is persisted."""
|
||||||
|
session = await async_manager.create_session(agent_name="agent1")
|
||||||
|
msg = await async_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.USER,
|
||||||
|
content="Hello",
|
||||||
|
)
|
||||||
|
# Message is returned immediately
|
||||||
|
assert msg.role == MessageRole.USER
|
||||||
|
assert msg.content == "Hello"
|
||||||
|
# Wait for the background writer to drain the queue, then verify persistence
|
||||||
|
await async_manager.flush()
|
||||||
|
persisted = await async_manager.store.get_messages(session.session_id)
|
||||||
|
assert len(persisted) == 1
|
||||||
|
assert persisted[0].content == "Hello"
|
||||||
|
await async_manager.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_chat_messages_includes_wal_buffered(self, async_manager):
|
||||||
|
"""get_chat_messages returns WAL-buffered messages not yet persisted."""
|
||||||
|
session = await async_manager.create_session(agent_name="agent1")
|
||||||
|
# Append a message — it may still be in the WAL buffer
|
||||||
|
await async_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.USER,
|
||||||
|
content="Buffered",
|
||||||
|
)
|
||||||
|
# get_chat_messages should include WAL-buffered messages
|
||||||
|
chat_msgs = await async_manager.get_chat_messages(session.session_id)
|
||||||
|
assert len(chat_msgs) >= 1
|
||||||
|
assert any(m["content"] == "Buffered" for m in chat_msgs)
|
||||||
|
await async_manager.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flush_ensures_all_pending_writes(self, async_manager):
|
||||||
|
"""flush() waits until all queued writes are persisted."""
|
||||||
|
session = await async_manager.create_session(agent_name="agent1")
|
||||||
|
for i in range(5):
|
||||||
|
await async_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.USER,
|
||||||
|
content=f"Msg {i}",
|
||||||
|
)
|
||||||
|
await async_manager.flush()
|
||||||
|
persisted = await async_manager.store.get_messages(session.session_id)
|
||||||
|
assert len(persisted) == 5
|
||||||
|
await async_manager.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rapid_appends_are_batched(self, async_manager):
|
||||||
|
"""Multiple rapid append_messages are all persisted correctly."""
|
||||||
|
session = await async_manager.create_session(agent_name="agent1")
|
||||||
|
# Fire off many messages rapidly
|
||||||
|
messages = []
|
||||||
|
for i in range(20):
|
||||||
|
msg = await async_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.USER,
|
||||||
|
content=f"Rapid {i}",
|
||||||
|
)
|
||||||
|
messages.append(msg)
|
||||||
|
await async_manager.flush()
|
||||||
|
persisted = await async_manager.store.get_messages(session.session_id)
|
||||||
|
assert len(persisted) == 20
|
||||||
|
contents = [m.content for m in persisted]
|
||||||
|
for i in range(20):
|
||||||
|
assert f"Rapid {i}" in contents
|
||||||
|
await async_manager.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_close_flushes_pending_writes(self, async_manager):
|
||||||
|
"""Closing a session flushes pending writes first."""
|
||||||
|
session = await async_manager.create_session(agent_name="agent1")
|
||||||
|
await async_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.USER,
|
||||||
|
content="Before close",
|
||||||
|
)
|
||||||
|
closed = await async_manager.close_session(session.session_id)
|
||||||
|
assert closed.status == SessionStatus.CLOSED
|
||||||
|
# Message should be persisted because close_session flushes
|
||||||
|
persisted = await async_manager.store.get_messages(session.session_id)
|
||||||
|
assert len(persisted) == 1
|
||||||
|
assert persisted[0].content == "Before close"
|
||||||
|
await async_manager.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_close_stops_writer(self, async_manager):
|
||||||
|
"""close() flushes and stops the background writer."""
|
||||||
|
session = await async_manager.create_session(agent_name="agent1")
|
||||||
|
await async_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.USER,
|
||||||
|
content="Final",
|
||||||
|
)
|
||||||
|
await async_manager.close()
|
||||||
|
# After close, the write queue should be stopped
|
||||||
|
assert async_manager._write_queue is None or async_manager._write_queue._worker is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_writes_disabled_by_default(self):
|
||||||
|
"""Without async_writes=True, writes are synchronous."""
|
||||||
|
mgr = SessionManager(store=InMemorySessionStore())
|
||||||
|
assert mgr._write_queue is None
|
||||||
|
session = await mgr.create_session(agent_name="agent1")
|
||||||
|
await mgr.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.USER,
|
||||||
|
content="Sync",
|
||||||
|
)
|
||||||
|
# Should be immediately persisted (no flush needed)
|
||||||
|
persisted = await mgr.store.get_messages(session.session_id)
|
||||||
|
assert len(persisted) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_messages_includes_wal_buffered(self, async_manager):
|
||||||
|
"""get_messages returns WAL-buffered messages not yet persisted."""
|
||||||
|
session = await async_manager.create_session(agent_name="agent1")
|
||||||
|
await async_manager.append_message(
|
||||||
|
session_id=session.session_id,
|
||||||
|
role=MessageRole.USER,
|
||||||
|
content="WAL msg",
|
||||||
|
)
|
||||||
|
messages = await async_manager.get_messages(session.session_id)
|
||||||
|
assert len(messages) >= 1
|
||||||
|
assert any(m.content == "WAL msg" for m in messages)
|
||||||
|
await async_manager.close()
|
||||||
|
|
|
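The WAL-visibility tests above rely on `get_messages` merging persisted and buffered messages before it applies `offset` and `limit`. A small sketch of that merge step, assuming messages are plain dicts with a `message_id` key (`merge_and_page` is a hypothetical helper, not part of the codebase):

```python
def merge_and_page(persisted, buffered, offset=0, limit=None):
    """Merge persisted and buffered messages, then apply offset/limit."""
    seen = {m["message_id"] for m in persisted}
    merged = list(persisted)               # copy so the caller's list is untouched
    for m in buffered:
        if m["message_id"] not in seen:    # skip messages already persisted
            merged.append(m)
    page = merged[offset:]
    return page if limit is None else page[:limit]
```

Pagination has to happen after the merge, which is why the diff fetches the full persisted list with `limit=None, offset=0`. For long sessions that means reading every message on each call, a cost worth measuring before enabling `async_writes` with a remote store.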
||||||