feat(memory): RAG pipeline optimization — 5 Implementation Units
U1: QueryTransformer — LLM/rule-based query rewriting + sub-query decomposition
U2: HttpRAGService enhanced_search() — rerank + compression via /bases/{kb_id}/retrieve
U3: Structured context injection — source attribution headers in RAG results
U4: RetrieveKnowledgeTool — built-in tool for mid-reasoning knowledge retrieval
U5: Configurable retrieval params + per-KB weights + CJK token estimation
Config example:
  memory:
    retrieval:
      top_k: 5
      token_budget: 2000
      context_template: structured
    query_transform:
      enabled: true
      strategy: llm
    semantic:
      search_mode: enhanced
      use_rerank: true
      kb_weights:
        industry-kb-id: 1.2
        enterprise-kb-id: 0.8
Tests: 1037 passed, 18 skipped, 0 failed
parent cd5b39087e
commit e33dc25ad3
@@ -0,0 +1,341 @@
---
title: "feat: AgentKit RAG Pipeline Optimization"
status: active
created: 2026-06-06
plan-type: feat
origin: RAG scenario problem analysis (6 issues: P0×2, P1×3, P2×1)
---

# feat: AgentKit RAG Pipeline Optimization

## Summary

Optimize the AgentKit RAG pipeline to improve retrieval quality and LLM answer accuracy. The current pipeline passes raw user queries directly to the knowledge base, lacks reranking, injects context without source attribution, and has no mechanism for iterative retrieval during ReAct reasoning. This plan addresses the 6 identified issues across 5 implementation units.

## Problem Frame

AgentKit's RAG integration works end-to-end but has critical quality gaps:

1. **Query quality**: Raw user queries (often vague or conversational) are sent directly to the knowledge base, resulting in poor recall.
2. **Retrieval quality**: The `/search` endpoint bypasses GEO's EnhancedRAG (rerank + compression), so results are returned without reranking.
3. **Context injection**: Knowledge base results are injected as a flat text block without source attribution, making it hard for the LLM to assess credibility.
4. **Iterative retrieval**: Only one retrieval happens before the ReAct loop; the LLM cannot request more information mid-reasoning.
5. **Configurability**: `top_k` and `token_budget` are hardcoded in `ReActEngine.execute()`.
6. **Source differentiation**: All knowledge bases are treated equally regardless of authority or recency.

## Requirements

| ID | Requirement | Priority |
|----|-------------|----------|
| R1 | Query rewriting: transform vague user queries into structured retrieval queries before searching | P0 |
| R2 | Enhanced retrieval: call GEO's `/bases/{kb_id}/retrieve` endpoint with rerank+compression support | P0 |
| R3 | Structured context injection: format RAG results with source attribution (title, score, kb type) | P1 |
| R4 | Iterative retrieval: register `retrieve_knowledge` as a built-in Tool for mid-reasoning search | P1 |
| R5 | Configurable retrieval parameters: `top_k`, `token_budget`, `retrieval_strategy` from config | P1 |
| R6 | Per-knowledge-base weight differentiation: industry vs enterprise weights | P2 |

## Key Technical Decisions

### KTD-1: Query rewriting via LLM vs rule-based

**Decision**: LLM-based query rewriting with a lightweight prompt, falling back to rule-based rewriting when no LLM gateway is available.

**Rationale**: Rule-based rewriting (keyword extraction, synonym expansion) is fast but limited. LLM rewriting can decompose complex queries, infer intent, and generate multiple sub-queries. The cost is one additional LLM call per task, which is acceptable given the expected improvement in retrieval quality. The fallback ensures the system works without an LLM gateway.

**Alternative considered**: Pure rule-based rewriting. Rejected because it cannot handle the diverse query patterns in the GEO/SEO domain. For example, "帮我分析一下竞品的SEO策略" ("help me analyze competitors' SEO strategy") needs to be decomposed into "竞品SEO策略分析" ("competitor SEO strategy analysis") and "行业SEO最佳实践" ("industry SEO best practices").

### KTD-2: Enhanced retrieval via new endpoint vs extending existing

**Decision**: Add an `enhanced_search()` method to `HttpRAGService` that calls GEO's `/bases/{kb_id}/retrieve` endpoint, keeping the existing `search()` method for backward compatibility.

**Rationale**: The GEO backend already exposes `EnhancedRAG.retrieve_with_rerank()` at `POST /bases/{kb_id}/retrieve`. Adding a new method avoids breaking existing consumers while enabling rerank+compression. The config controls which method is used.

### KTD-3: RAG Tool as built-in vs skill-defined

**Decision**: Register `retrieve_knowledge` as a built-in Tool in `MemoryRetriever`, auto-registered when semantic memory is configured.

**Rationale**: Making RAG retrieval a Tool (rather than only a pre-execution step) lets the LLM trigger additional searches during ReAct reasoning. Auto-registration when semantic memory is configured means zero configuration for the common case. The Tool is created by `MemoryRetriever` and injected into the agent's tool list.

### KTD-4: Context injection format

**Decision**: Use structured markdown with source blocks instead of flat text.

**Rationale**: The current `## Relevant Past Experience\n{raw_text}` format gives the LLM no way to distinguish high-quality knowledge base results from episodic memories, or to cite sources. Structured blocks with headers such as `[来源: 行业库 | 置信度: 0.92 | 文档: 行业报告]` (source: industry KB | confidence: 0.92 | document: industry report) let the LLM assess credibility and cite appropriately.

### KTD-5: Per-knowledge-base weight via filters

**Decision**: Extend `MemoryRetriever` weights to support per-source-type multipliers, configured via `memory.semantic.kb_weights` in the YAML config.

**Rationale**: Industry knowledge bases (curated, authoritative) should carry more weight than enterprise-specific ones (narrow, potentially outdated). A simple multiplier per `kb_id` is sufficient; complex authority scoring is not needed.

---

## Implementation Units

### U1. QueryTransformer: Query Rewriting and Expansion

**Goal**: Transform raw user queries into structured retrieval queries before searching the knowledge base, with a target of raising recall from roughly 30% to 70% or higher.

**Requirements**: R1

**Dependencies**: None

**Files**:
- `src/agentkit/memory/query_transformer.py` (create)
- `tests/unit/test_query_transformer.py` (create)

**Approach**:
- Create a `QueryTransformer` class with two strategies:
  - `LLMQueryTransformer`: Uses the LLM gateway to rewrite queries. The prompt instructs the LLM to (a) extract the core intent, (b) decompose complex queries into 1-3 sub-queries, and (c) add domain-specific terms. Returns a `TransformedQuery` with `main_query` and `sub_queries`.
  - `RuleQueryTransformer`: Fallback that applies rule-based transformations: strip filler words, extract noun phrases, and add domain synonyms from a configurable map.
- `TransformedQuery` dataclass: `main_query: str`, `sub_queries: list[str]`, `original_query: str`.
- `QueryTransformer` is called by `MemoryRetriever.retrieve()` before dispatching to memory layers.
- Config: `memory.query_transform.enabled: bool`, `memory.query_transform.strategy: "llm" | "rule"`, `memory.query_transform.max_sub_queries: int = 3`.

**Patterns to follow**: `agentkit/memory/embedder.py` (abstract base + concrete implementations pattern).

**Test scenarios**:
- LLM transformer: mock LLM gateway, verify prompt construction and response parsing
- LLM transformer: verify fallback to original query on LLM error
- Rule transformer: verify filler word removal and synonym expansion
- Rule transformer: verify no-op when query is already well-formed
- Integration: verify `MemoryRetriever.retrieve()` calls transformer before search
- Integration: verify sub-queries are searched in parallel and results merged

**Verification**: All tests pass. With query transform enabled, `MemoryRetriever` issues different (rewritten) search calls than it does with the transform disabled.

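The rule-based fallback described in U1 can be illustrated with a minimal sketch. This is not the committed `query_transformer.py`; the filler-word list, the synonym map, and the lower-casing step are assumptions made for illustration, and noun-phrase extraction is omitted.

```python
import asyncio
import re
from dataclasses import dataclass


@dataclass
class TransformedQuery:
    main_query: str
    sub_queries: list[str]
    original_query: str


# Hypothetical filler words and synonym map; the real lists are not shown in this plan.
FILLER_WORDS = ("please", "help me", "can you", "i want to know")
SYNONYMS = {"seo": ["search engine optimization"]}


class RuleQueryTransformer:
    """Rule-based fallback: strip filler words, then expand domain synonyms."""

    def __init__(self, max_sub_queries: int = 3) -> None:
        self.max_sub_queries = max_sub_queries

    async def transform(self, query: str) -> TransformedQuery:
        # Lower-case so that filler and synonym matching is case-insensitive.
        cleaned = query.lower()
        for filler in FILLER_WORDS:
            cleaned = re.sub(rf"\b{re.escape(filler)}\b", " ", cleaned)
        # Collapse whitespace; keep the original query if nothing is left.
        cleaned = re.sub(r"\s+", " ", cleaned).strip() or query
        sub_queries: list[str] = []
        for term, expansions in SYNONYMS.items():
            if re.search(rf"\b{re.escape(term)}\b", cleaned):
                sub_queries.extend(cleaned.replace(term, e) for e in expansions)
        return TransformedQuery(cleaned, sub_queries[: self.max_sub_queries], query)


if __name__ == "__main__":
    print(asyncio.run(RuleQueryTransformer().transform("Please help me analyze competitor SEO strategy")))
```

Keeping `original_query` on the result lets a caller fall back to the untouched query if the rewritten one returns nothing.
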
---

### U2. HttpRAGService Enhanced Search: Enhanced Retrieval Endpoint

**Goal**: Enable AgentKit to call GEO's EnhancedRAG endpoint with rerank and compression, with a target of raising retrieval precision from roughly 50% to 80% or higher.

**Requirements**: R2

**Dependencies**: None

**Files**:
- `src/agentkit/memory/http_rag.py` (modify)
- `src/agentkit/memory/semantic.py` (modify)
- `src/agentkit/server/config.py` (modify)
- `tests/unit/test_http_rag_service.py` (modify)

**Approach**:
- Add an `enhanced_search()` method to `HttpRAGService`:
  - Calls `POST /bases/{kb_id}/retrieve` for each configured knowledge base
  - Passes `use_rerank` and `use_compression` parameters
  - Merges results from multiple KBs and re-sorts them by reranked relevance score
- Add a `search_mode: "standard" | "enhanced"` parameter to `SemanticMemory.search()`:
  - `"standard"`: calls `rag_service.search()` (current behavior, backward compatible)
  - `"enhanced"`: calls `rag_service.enhanced_search()` with rerank+compression
- Config additions under `memory.semantic`:
  - `search_mode: "enhanced"` (default: `"standard"`)
  - `use_rerank: true` (default: true when enhanced)
  - `use_compression: false` (default: false)
- `SemanticMemory.search()` passes `filters` through to `HttpRAGService` to allow per-query override.

**Patterns to follow**: Existing `search()` method in `http_rag.py` (same HTTP client pattern, same error handling, same response normalization).

**Test scenarios**:
- `enhanced_search()` with rerank enabled: verify correct endpoint and payload
- `enhanced_search()` with compression enabled: verify payload includes `use_compression: true`
- `enhanced_search()` with multiple KBs: verify parallel calls and result merging
- `enhanced_search()` HTTP error: verify graceful fallback to empty results
- `SemanticMemory.search()` with `search_mode="enhanced"`: verify delegation to `enhanced_search()`
- `SemanticMemory.search()` with `search_mode="standard"`: verify existing behavior unchanged
- Config parsing: verify `search_mode`, `use_rerank`, `use_compression` from YAML

**Verification**: All tests pass. `enhanced_search()` returns reranked results when the GEO backend supports it.

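The test scenarios above call for parallel calls across knowledge bases, whereas the `enhanced_search()` hunk in this commit awaits each knowledge base in turn inside a `for` loop. One possible way to fan the calls out concurrently is sketched below. `merge_enhanced_results` and `fetch_one` are hypothetical names; `fetch_one` stands in for the per-KB `POST /bases/{kb_id}/retrieve` call and is assumed to return already-normalized result dicts.

```python
import asyncio
from typing import Any, Awaitable, Callable

Result = dict[str, Any]
# Hypothetical signature: (kb_id, query) -> normalized results for that knowledge base.
FetchOne = Callable[[str, str], Awaitable[list[Result]]]


async def merge_enhanced_results(
    fetch_one: FetchOne, query: str, kb_ids: list[str], top_k: int = 5
) -> list[Result]:
    """Query every knowledge base concurrently and keep the top_k results by score."""
    # return_exceptions=True keeps one failing KB from discarding the others.
    batches = await asyncio.gather(
        *(fetch_one(kb_id, query) for kb_id in kb_ids), return_exceptions=True
    )
    merged: list[Result] = []
    for kb_id, batch in zip(kb_ids, batches):
        if isinstance(batch, BaseException):
            continue  # a real service would log the failure here
        merged.extend({**item, "knowledge_base_id": kb_id} for item in batch)
    merged.sort(key=lambda item: item["score"], reverse=True)
    return merged[:top_k]
```

Note the behavioral difference this sketch assumes: a failure in one knowledge base is skipped rather than causing the whole call to return an empty list.
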
---

### U3. Structured Context Injection

**Goal**: Format RAG results with source attribution so the LLM can assess credibility and cite sources.

**Requirements**: R3

**Dependencies**: U1 (query transformer affects what results are returned)

**Files**:
- `src/agentkit/memory/retriever.py` (modify)
- `src/agentkit/core/react.py` (modify)
- `tests/unit/test_memory_integration.py` (modify)

**Approach**:
- Replace `MemoryRetriever.get_context_string()` with `get_context_messages()` that returns structured context:
  ```
  ### 知识库参考 [来源: 行业库 | 相关度: 0.92 | 文档: AI行业趋势报告]
  AI行业在2025年呈现三大趋势...

  ### 过往经验 [来源: 情景记忆 | 任务类型: seo_analysis]
  上次分析竞品SEO策略时发现...
  ```
  In English: a "knowledge base reference" block whose header carries source (industry KB), relevance (0.92), and document title (AI industry trends report), followed by a "past experience" block whose header carries source (episodic memory) and task type.
- Each `MemoryItem` is rendered with its metadata: `source` (rag/graph/episodic/working), `score`, `document_title`, `kb_type`.
- `ReActEngine.execute()` calls `get_context_messages()` instead of `get_context_string()`.
- The injection heading changes from `## Relevant Past Experience` to `## 参考信息` ("Reference Information"), which is bilingual-friendly.
- Add a `context_template: "structured" | "flat"` config option (default: `"structured"`).

**Patterns to follow**: Current `get_context_string()` in `retriever.py` (same token budget logic, same parallel retrieval).

**Test scenarios**:
- Structured format: verify each result has source header with metadata
- Flat format: verify backward-compatible plain text output
- Token budget: verify long results are truncated within budget
- Mixed sources: verify RAG results and episodic memories are formatted differently
- ReActEngine integration: verify system_prompt contains structured context
- Empty results: verify no context section added when no results found

**Verification**: LLM receives structured context with source attribution. Backward compatible with `context_template: "flat"`.

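The structured rendering in U3 might look roughly like the sketch below. The `MemoryItem` shape, the English header wording, and the `render_structured` name are assumptions for illustration. The sketch also drops items that do not fit the budget instead of truncating them, which is simpler than the truncation behavior the test scenarios describe.

```python
from dataclasses import dataclass


@dataclass
class MemoryItem:
    # Hypothetical shape; only the field names come from the plan text.
    content: str
    source: str  # "rag", "graph", "episodic", or "working"
    score: float = 0.0
    document_title: str = ""
    kb_type: str = ""


def estimate_tokens(text: str) -> int:
    """Rough token estimate using the heuristic proposed in U5."""
    return max(len(text) // 3, len(text.split()))


def render_structured(items: list[MemoryItem], token_budget: int = 2000) -> str:
    """Render items as source-attributed blocks, highest score first, within budget."""
    blocks: list[str] = []
    used = 0
    for item in sorted(items, key=lambda i: i.score, reverse=True):
        if item.source == "rag":
            header = (
                f"### Knowledge base [source: {item.kb_type or 'rag'} | "
                f"score: {item.score:.2f} | document: {item.document_title}]"
            )
        else:
            header = f"### Past experience [source: {item.source}]"
        block = f"{header}\n{item.content}"
        cost = estimate_tokens(block)
        if used + cost > token_budget:
            break  # stop at the first block that would exceed the budget
        blocks.append(block)
        used += cost
    # An empty string signals that no context section should be injected.
    return "\n\n".join(blocks)
```

Returning an empty string for empty input matches the "no context section added when no results found" scenario, since the caller can simply skip injection when the string is falsy.
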
---

### U4. RetrieveKnowledge Tool: Secondary Retrieval Inside the ReAct Loop

**Goal**: Enable the LLM to trigger additional knowledge base searches during ReAct reasoning by registering `retrieve_knowledge` as a built-in Tool.

**Requirements**: R4

**Dependencies**: U1, U3

**Files**:
- `src/agentkit/memory/retriever.py` (modify)
- `src/agentkit/core/config_driven.py` (modify)
- `src/agentkit/server/app.py` (modify)
- `tests/unit/test_retrieve_knowledge_tool.py` (create)

**Approach**:
- Create a `RetrieveKnowledgeTool(Tool)` inner class within `MemoryRetriever`:
  - `name: "retrieve_knowledge"`
  - `description: "Search the knowledge base for additional information. Use when you need more context or facts."`
  - `input_schema: {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}`
  - `execute(query)`: calls `self._retriever.retrieve(query)` and returns formatted results
- Add a `create_retrieve_tool() -> Tool | None` method to `MemoryRetriever`:
  - Returns a `RetrieveKnowledgeTool` instance if semantic memory is configured
  - Returns `None` if no semantic memory is configured (tool not available)
- Auto-register the tool in `ConfigDrivenAgent.__init__()` and `app.py` when `memory_retriever` is created:
  - `if memory_retriever and (tool := memory_retriever.create_retrieve_tool()): agent.use_tool(tool)`
- The tool uses the same `MemoryRetriever.retrieve()` pipeline, so query transformation (U1) and structured formatting (U3) apply automatically.

**Patterns to follow**: `agentkit/tools/base.py` (Tool subclass pattern with `execute()` and `safe_execute()`).

**Test scenarios**:
- Tool creation: verify `create_retrieve_tool()` returns a Tool when semantic memory is configured
- Tool creation: verify `create_retrieve_tool()` returns None when no semantic memory
- Tool execution: verify `execute(query="AI趋势")` calls `MemoryRetriever.retrieve()` with the query
- Tool execution: verify results are formatted as structured text
- Tool schema: verify `input_schema` has `query` field
- Auto-registration: verify ConfigDrivenAgent with semantic memory has `retrieve_knowledge` in its tool list
- Auto-registration: verify agent without semantic memory does NOT have the tool
- ReAct integration: verify LLM can call `retrieve_knowledge` during ReAct loop

**Verification**: Agent with semantic memory has the `retrieve_knowledge` tool. LLM can call it during reasoning. Results are formatted with source attribution.

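A standalone sketch of the tool described in U4 follows. It does not subclass AgentKit's real `Tool` base class, whose interface is not shown here; the constructor arguments and the refusal message are assumptions. The call cap mirrors the `max_retrieval_calls` mitigation (default 3) listed under Risks and Mitigations.

```python
import asyncio
from typing import Awaitable, Callable


class RetrieveKnowledgeTool:
    """Standalone stand-in for the built-in tool; the real one would extend Tool."""

    name = "retrieve_knowledge"
    description = (
        "Search the knowledge base for additional information. "
        "Use when you need more context or facts."
    )
    input_schema = {
        "type": "object",
        "properties": {"query": {"type": "string", "description": "Search query"}},
        "required": ["query"],
    }

    def __init__(
        self,
        retrieve: Callable[[str], Awaitable[str]],  # hypothetical retrieval callback
        max_retrieval_calls: int = 3,
    ) -> None:
        self._retrieve = retrieve
        self._max_calls = max_retrieval_calls
        self._calls = 0

    async def execute(self, query: str) -> str:
        # Refuse further searches once the cap is reached, to avoid retrieval loops.
        if self._calls >= self._max_calls:
            return "Retrieval limit reached; answer with the information already gathered."
        self._calls += 1
        return await self._retrieve(query)
```

Returning a plain message at the cap, rather than raising, keeps the ReAct loop running and tells the LLM to stop searching.
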
---

### U5. Configurable Retrieval + Per-KB Weights

**Goal**: Make retrieval parameters configurable and support per-knowledge-base weight differentiation.

**Requirements**: R5, R6

**Dependencies**: U2, U3

**Files**:
- `src/agentkit/core/react.py` (modify)
- `src/agentkit/memory/retriever.py` (modify)
- `src/agentkit/server/config.py` (modify)
- `src/agentkit/core/config_driven.py` (modify)
- `tests/unit/test_memory_integration.py` (modify)

**Approach**:
- **Configurable retrieval parameters**:
  - Add a `retrieval` sub-section to the `memory` config:
    ```yaml
    memory:
      retrieval:
        top_k: 5
        token_budget: 2000
        context_template: "structured"
    ```
  - `ReActEngine.execute()` reads these from `SkillConfig.memory.retrieval` or falls back to defaults.
  - Pass `retrieval_config` through `ConfigDrivenAgent._handle_react()` to `ReActEngine.execute()`.
- **Per-KB weights**:
  - Add `kb_weights` to the `memory.semantic` config:
    ```yaml
    memory:
      semantic:
        kb_weights:
          "industry-kb-id": 1.2    # industry KB gets a higher weight
          "enterprise-kb-id": 0.8  # enterprise KB gets a lower weight
    ```
  - `SemanticMemory.search()` applies `kb_weights` as score multipliers after retrieval.
  - `MemoryRetriever` passes `kb_weights` through `filters` to `SemanticMemory.search()`.
- **Token estimation improvement**:
  - Replace `len(text) // 4` with a slightly better heuristic for mixed Chinese/English content: `max(len(text) // 3, len(text.split()))`. It is still approximate, but it underestimates CJK text less than `// 4` does.

**Patterns to follow**: Existing config pattern in `ServerConfig.from_dict()` (same dict-based config with env var resolution).

**Test scenarios**:
- Config parsing: verify `retrieval.top_k`, `retrieval.token_budget`, `retrieval.context_template` from YAML
- Config parsing: verify `semantic.kb_weights` from YAML
- ReActEngine: verify configurable `top_k` and `token_budget` are used instead of hardcoded values
- Per-KB weights: verify industry KB results get higher scores than enterprise KB results
- Per-KB weights: verify unweighted KBs get default score (1.0 multiplier)
- Token estimation: verify improved heuristic for Chinese text
- Backward compatibility: verify defaults match current hardcoded values when config is absent

**Verification**: Retrieval parameters are configurable via YAML. Per-KB weights are applied. No behavior change when config is absent.

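The two small calculations in U5, per-KB score multipliers and the token heuristic, can be sketched as follows. `apply_kb_weights` is a hypothetical helper name, and the result dicts are assumed to carry `knowledge_base_id` and `score` keys as in the normalized `enhanced_search()` output.

```python
from typing import Any, Optional


def apply_kb_weights(
    results: list[dict[str, Any]], kb_weights: Optional[dict[str, float]]
) -> list[dict[str, Any]]:
    """Multiply each score by its knowledge base weight (default 1.0), then re-sort."""
    weights = kb_weights or {}
    weighted = [
        {**r, "score": r["score"] * weights.get(r.get("knowledge_base_id", ""), 1.0)}
        for r in results
    ]
    weighted.sort(key=lambda r: r["score"], reverse=True)
    return weighted


def estimate_tokens(text: str) -> int:
    """Heuristic from the plan: characters // 3, but never below the word count."""
    return max(len(text) // 3, len(text.split()))
```

With the example weights, an enterprise result scored 0.9 drops to about 0.72 and an industry result scored 0.7 rises to about 0.84, so the industry result ranks first after weighting. For unspaced CJK text `text.split()` yields a single element, so the character-based term dominates.
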
---

## Scope Boundaries

### In Scope
- Query rewriting (LLM + rule-based)
- Enhanced retrieval with rerank/compression
- Structured context injection with source attribution
- `retrieve_knowledge` Tool for iterative retrieval
- Configurable retrieval parameters
- Per-knowledge-base weight differentiation

### Deferred to Follow-Up Work
- Cross-encoder reranking model (GEO currently uses LLM-based reranking, which is sufficient)
- Full-text search upgrade (moving GEO from `ILIKE` to `tsvector` is a backend-only change)
- Semantic memory protocol formalization (ABC for `rag_service`)
- Caching layer for frequent queries
- Multi-hop retrieval (retrieval → extraction → retrieval chains)
- Retrieval metrics and observability (hit rate, latency tracking)

---

## Risks and Mitigations

| Risk | Impact | Mitigation |
|------|--------|------------|
| LLM query rewriting adds latency (~500ms per task) | Medium | Async execution; fallback to rule-based when LLM unavailable; configurable on/off |
| Enhanced retrieval endpoint may not exist on all backends | Low | `search_mode: "standard"` is default; `enhanced_search()` falls back to `search()` on 404 |
| `retrieve_knowledge` tool may cause infinite retrieval loops | Medium | ReAct `max_steps` already limits total iterations; add `max_retrieval_calls` config (default: 3) |
| Per-KB weights require knowing KB IDs at config time | Low | Weights are optional; unweighted KBs use default multiplier (1.0) |

---

## System-Wide Impact

- **ReActEngine**: New parameters for configurable retrieval; context injection format change
- **MemoryRetriever**: Query transformation pipeline; structured context output; tool creation
- **HttpRAGService**: New `enhanced_search()` method
- **SemanticMemory**: `search_mode` parameter; `kb_weights` support
- **ConfigDrivenAgent**: Auto-registration of `retrieve_knowledge` tool; config-driven retrieval parameters
- **ServerConfig**: New config sections for `memory.retrieval` and `memory.semantic.kb_weights`
- **GEO backend**: No changes required, because the `EnhancedRAG` endpoints already exist

---

## Phased Delivery

| Phase | Units | Focus |
|-------|-------|-------|
| Phase A: Query Quality | U1, U2 | Query rewriting + enhanced retrieval |
| Phase B: Context Quality | U3, U4 | Structured injection + iterative retrieval |
| Phase C: Configurability | U5 | Configurable parameters + per-KB weights |

@@ -342,6 +342,10 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
                 semantic = SemanticMemory(
                     rag_service=rag_service,
                     knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
+                    search_mode=sem_conf.get("search_mode", "standard"),
+                    use_rerank=sem_conf.get("use_rerank", True),
+                    use_compression=sem_conf.get("use_compression", False),
+                    kb_weights=sem_conf.get("kb_weights"),
                 )
 
             self._memory_retriever = MemoryRetriever(
@@ -358,6 +362,12 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
             logger.warning(f"Failed to initialize memory system: {e}")
             self._memory_retriever = None
 
+        # Auto-register retrieve_knowledge tool if semantic memory is configured
+        if self._memory_retriever:
+            retrieve_tool = self._memory_retriever.create_retrieve_tool()
+            if retrieve_tool:
+                self.use_tool(retrieve_tool)
+
     @property
     def config(self) -> AgentConfig:
         return self._config
@@ -530,6 +540,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
         user_messages.append({"role": "user", "content": str(task.input_data)})
 
         # Execute ReAct loop
+        retrieval_config = self._config.memory.get("retrieval", {}) if self._config.memory else {}
         result = await self._react_engine.execute(
             messages=user_messages,
             tools=self._tools if self._tools else None,
@@ -539,6 +550,7 @@ class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
             system_prompt=system_prompt,
             memory_retriever=self._memory_retriever,
             task_id=task.task_id,
+            retrieval_config=retrieval_config or None,
         )
 
         # Parse result

@@ -81,6 +81,7 @@ class ReActEngine:
         memory_retriever: "MemoryRetriever | None" = None,
         task_id: str | None = None,
         compressor: "ContextCompressor | None" = None,
+        retrieval_config: dict[str, Any] | None = None,
     ) -> ReActResult:
         """Execute the ReAct loop
 
@@ -104,16 +105,18 @@ class ReActEngine:
         if memory_retriever:
             try:
                 query = str(messages[-1].get("content", "")) if messages else ""
+                top_k = (retrieval_config or {}).get("top_k", 5)
+                token_budget = (retrieval_config or {}).get("token_budget", 2000)
                 memory_context = await memory_retriever.get_context_string(
                     query=query,
-                    top_k=5,
-                    token_budget=2000,
+                    top_k=top_k,
+                    token_budget=token_budget,
                 )
                 if memory_context:
                     if system_prompt:
-                        system_prompt += f"\n\n## Relevant Past Experience\n{memory_context}"
+                        system_prompt += f"\n\n## 参考信息\n{memory_context}"
                     else:
-                        system_prompt = f"## Relevant Past Experience\n{memory_context}"
+                        system_prompt = f"## 参考信息\n{memory_context}"
             except Exception as e:
                 logger.warning(f"Memory retrieval failed, continuing without context: {e}")
 
@@ -337,6 +340,7 @@ class ReActEngine:
         memory_retriever: "MemoryRetriever | None" = None,
         task_id: str | None = None,
         compressor: "ContextCompressor | None" = None,
+        retrieval_config: dict[str, Any] | None = None,
     ):
         """Execute ReAct loop, yielding ReActEvent objects.
 
@@ -358,16 +362,18 @@ class ReActEngine:
         if memory_retriever:
             try:
                 query = str(messages[-1].get("content", "")) if messages else ""
+                top_k = (retrieval_config or {}).get("top_k", 5)
+                token_budget = (retrieval_config or {}).get("token_budget", 2000)
                 memory_context = await memory_retriever.get_context_string(
                     query=query,
-                    top_k=5,
-                    token_budget=2000,
+                    top_k=top_k,
+                    token_budget=token_budget,
                 )
                 if memory_context:
                     if system_prompt:
-                        system_prompt += f"\n\n## Relevant Past Experience\n{memory_context}"
+                        system_prompt += f"\n\n## 参考信息\n{memory_context}"
                     else:
-                        system_prompt = f"## Relevant Past Experience\n{memory_context}"
+                        system_prompt = f"## 参考信息\n{memory_context}"
             except Exception as e:
                 logger.warning(f"Memory retrieval failed, continuing without context: {e}")
 

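Both `ReActEngine` hunks above resolve `top_k` and `token_budget` with the same `(retrieval_config or {}).get(...)` pattern. One way to centralize those defaults is sketched below; `RetrievalSettings` is a hypothetical helper and is not part of this commit.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class RetrievalSettings:
    """Hypothetical holder for the memory.retrieval options and their defaults."""

    top_k: int = 5
    token_budget: int = 2000
    context_template: str = "structured"

    @classmethod
    def from_config(cls, retrieval_config: Optional[dict[str, Any]]) -> "RetrievalSettings":
        # A missing or empty config falls back to the previously hardcoded values.
        conf = retrieval_config or {}
        return cls(
            top_k=int(conf.get("top_k", cls.top_k)),
            token_budget=int(conf.get("token_budget", cls.token_budget)),
            context_template=str(conf.get("context_template", cls.context_template)),
        )
```

Resolving the settings once would let the blocking and streaming code paths share a single set of defaults instead of repeating the literals `5` and `2000`.
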
@@ -6,6 +6,14 @@ from agentkit.memory.episodic import EpisodicMemory
 from agentkit.memory.semantic import SemanticMemory
 from agentkit.memory.http_rag import HttpRAGService
 from agentkit.memory.retriever import MemoryRetriever
+from agentkit.memory.query_transformer import (
+    QueryTransformerBase,
+    LLMQueryTransformer,
+    RuleQueryTransformer,
+    NoOpQueryTransformer,
+    TransformedQuery,
+    create_query_transformer,
+)
 
 __all__ = [
     "Memory",
@@ -15,4 +23,10 @@ __all__ = [
     "SemanticMemory",
     "HttpRAGService",
     "MemoryRetriever",
+    "QueryTransformerBase",
+    "LLMQueryTransformer",
+    "RuleQueryTransformer",
+    "NoOpQueryTransformer",
+    "TransformedQuery",
+    "create_query_transformer",
 ]

@ -129,6 +129,90 @@ class HttpRAGService:
|
||||||
logger.error(f"RAG search unexpected error: {e}")
|
logger.error(f"RAG search unexpected error: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def enhanced_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
knowledge_base_ids: list[str] | None = None,
|
||||||
|
top_k: int = 5,
|
||||||
|
use_rerank: bool = True,
|
||||||
|
use_compression: bool = False,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""增强语义检索知识库(支持 rerank 和 compression)
|
||||||
|
|
||||||
|
对每个知识库分别调用 /bases/{kb_id}/retrieve 接口,
|
||||||
|
合并结果后按 score 降序返回 top_k 条。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 检索查询
|
||||||
|
knowledge_base_ids: 知识库 ID 列表(默认使用配置值)
|
||||||
|
top_k: 返回结果数量
|
||||||
|
use_rerank: 是否启用 rerank 重排序
|
||||||
|
use_compression: 是否启用上下文压缩
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索结果列表,每项包含 content/score/document_id 等字段
|
||||||
|
"""
|
||||||
|
kb_ids = knowledge_base_ids or self._knowledge_base_ids
|
||||||
|
if not kb_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"query": query,
|
||||||
|
"top_k": top_k,
|
||||||
|
"use_rerank": use_rerank,
|
||||||
|
"use_compression": use_compression,
|
||||||
|
}
|
||||||
|
|
||||||
|
client = self._get_client()
|
||||||
|
all_results: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for kb_id in kb_ids:
|
||||||
|
try:
|
||||||
|
resp = await client.post(f"/bases/{kb_id}/retrieve", json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
|
||||||
|
# 兼容两种响应格式
|
||||||
|
if isinstance(data, dict) and "results" in data:
|
||||||
|
results = data["results"]
|
||||||
|
elif isinstance(data, list):
|
||||||
|
results = data
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unexpected enhanced_search response format: {type(data)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 标准化
|
||||||
|
for r in results:
|
||||||
|
if isinstance(r, dict):
|
||||||
|
all_results.append({
|
||||||
|
"id": r.get("chunk_id", r.get("id", "")),
|
||||||
|
"content": r.get("content", ""),
|
||||||
|
"score": float(r.get("score", 0.0)),
|
||||||
|
"source": r.get("source", "rag"),
|
||||||
|
"document_id": r.get("document_id", ""),
|
||||||
|
"document_title": r.get("document_title", ""),
|
||||||
|
"knowledge_base_id": kb_id,
|
||||||
|
"metadata": r.get("metadata", {}),
|
||||||
|
})
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
if e.response.status_code == 404:
|
||||||
|
# The backend does not support the enhanced retrieval endpoint; fall back to standard search
|
||||||
|
logger.info("Enhanced search endpoint not found (404), falling back to standard search")
|
||||||
|
return await self.search(query, knowledge_base_ids=kb_ids, top_k=top_k)
|
||||||
|
logger.error(f"RAG enhanced_search HTTP error: {e.response.status_code} — {e.response.text[:200]}")
|
||||||
|
return []
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"RAG enhanced_search request error: {e}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"RAG enhanced_search unexpected error: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Sort by score in descending order and return the top_k
|
||||||
|
all_results.sort(key=lambda x: x["score"], reverse=True)
|
||||||
|
return all_results[:top_k]
|
||||||
|
|
||||||
async def ingest(
|
async def ingest(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
|
|
|
||||||
|
|
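The merge step at the end of `enhanced_search()` (collect results from every knowledge base, sort by score, keep `top_k`) can be illustrated in isolation. The following is a minimal sketch, not the code in this commit: `merge_top_k` and its input shape are hypothetical, and it uses `heapq.nlargest` in place of a full sort.

```python
import heapq
from typing import Any


def merge_top_k(
    per_kb_results: dict[str, list[dict[str, Any]]],
    top_k: int,
) -> list[dict[str, Any]]:
    """Merge per-knowledge-base result lists and keep the top_k by score.

    Hypothetical helper: the keys of per_kb_results are knowledge base IDs.
    """
    merged = [
        # Tag each result with its knowledge base and coerce the score to float.
        {**r, "knowledge_base_id": kb_id, "score": float(r.get("score", 0.0))}
        for kb_id, results in per_kb_results.items()
        for r in results
        if isinstance(r, dict)  # skip malformed entries
    ]
    # nlargest returns the items ordered from highest to lowest score.
    return heapq.nlargest(top_k, merged, key=lambda r: r["score"])
```

For small result sets this behaves the same as sorting and slicing; the heap only matters when many knowledge bases return long lists.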
@ -0,0 +1,175 @@
|
||||||
|
"""QueryTransformer - RAG query rewriting
|
||||||
|
|
||||||
|
Rewrites the user's raw query into a form better suited to knowledge base retrieval:
|
||||||
|
- LLMQueryTransformer: LLM-based rewriting
|
||||||
|
- RuleQueryTransformer: rule-based rewriting (stop-word removal, synonym expansion)
|
||||||
|
- NoOpQueryTransformer: no rewriting; returns the query unchanged
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TransformedQuery:
|
||||||
|
"""A rewritten query."""
|
||||||
|
|
||||||
|
main_query: str
|
||||||
|
sub_queries: list[str]
|
||||||
|
original_query: str
|
||||||
|
|
||||||
|
|
||||||
|
class QueryTransformerBase(ABC):
|
||||||
|
"""Abstract base class for query rewriting."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def transform(self, query: str) -> TransformedQuery:
|
||||||
|
"""Rewrite the query."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class LLMQueryTransformer(QueryTransformerBase):
|
||||||
|
"""LLM-based query rewriting.
|
||||||
|
|
||||||
|
Uses an LLM to extract the core intent, decompose the query into sub-queries, and add domain terms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, llm_gateway, max_sub_queries: int = 3):
|
||||||
|
self._llm_gateway = llm_gateway
|
||||||
|
self._max_sub_queries = max_sub_queries
|
||||||
|
|
||||||
|
async def transform(self, query: str) -> TransformedQuery:
|
||||||
|
"""Rewrite the query using an LLM."""
|
||||||
|
prompt = (
|
||||||
|
"You are a query rewriting assistant for a knowledge base retrieval system.\n"
|
||||||
|
"Given a user query, your task is to:\n"
|
||||||
|
"1. Extract the core intent of the query\n"
|
||||||
|
"2. If the query is complex, decompose it into simpler sub-queries\n"
|
||||||
|
"3. Add domain-specific terms that may improve retrieval\n\n"
|
||||||
|
f"Original query: {query}\n\n"
|
||||||
|
'Respond ONLY with a JSON object in this exact format: {"main_query": "...", "sub_queries": [...]}\n'
|
||||||
|
"The main_query should be a concise, retrieval-optimized version of the original query.\n"
|
||||||
|
"The sub_queries should be a list of simpler queries (0-3 items) that cover different aspects.\n"
|
||||||
|
"Do not include any other text or explanation."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._llm_gateway.chat(
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
model="default",
|
||||||
|
)
|
||||||
|
data = json.loads(response.content)
|
||||||
|
main_query = str(data.get("main_query", query))
|
||||||
|
sub_queries = list(data.get("sub_queries", []))[: self._max_sub_queries]
|
||||||
|
return TransformedQuery(
|
||||||
|
main_query=main_query,
|
||||||
|
sub_queries=sub_queries,
|
||||||
|
original_query=query,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("LLM query transformation failed, falling back to original query", exc_info=True)
|
||||||
|
return TransformedQuery(
|
||||||
|
main_query=query,
|
||||||
|
sub_queries=[],
|
||||||
|
original_query=query,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RuleQueryTransformer(QueryTransformerBase):
|
||||||
|
"""Rule-based query rewriting.
|
||||||
|
|
||||||
|
Removes filler words, extracts key noun phrases, and expands synonyms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_FILLER_WORDS_CN: list[str] = [
|
||||||
|
"帮我", "请", "一下", "分析", "看看", "告诉我", "想知道", "请问",
|
||||||
|
]
|
||||||
|
_FILLER_WORDS_EN: list[str] = [
|
||||||
|
"please", "can you", "help me", "could you", "i want to", "i need to",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
synonyms: dict[str, list[str]] | None = None,
|
||||||
|
max_sub_queries: int = 3,
|
||||||
|
):
|
||||||
|
self._synonyms = synonyms or {}
|
||||||
|
self._max_sub_queries = max_sub_queries
|
||||||
|
# Pre-compile filler patterns
|
||||||
|
self._filler_patterns_cn = [
|
||||||
|
re.compile(re.escape(w)) for w in self._FILLER_WORDS_CN
|
||||||
|
]
|
||||||
|
self._filler_patterns_en = [
|
||||||
|
re.compile(re.escape(w), re.IGNORECASE) for w in self._FILLER_WORDS_EN
|
||||||
|
]
|
||||||
|
|
||||||
|
async def transform(self, query: str) -> TransformedQuery:
|
||||||
|
"""Rewrite the query using rules."""
|
||||||
|
cleaned = query
|
||||||
|
|
||||||
|
# Remove Chinese filler words
|
||||||
|
for pattern in self._filler_patterns_cn:
|
||||||
|
cleaned = pattern.sub("", cleaned)
|
||||||
|
|
||||||
|
# Remove English filler words
|
||||||
|
for pattern in self._filler_patterns_en:
|
||||||
|
cleaned = pattern.sub("", cleaned)
|
||||||
|
|
||||||
|
# Collapse whitespace
|
||||||
|
cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
||||||
|
|
||||||
|
# If nothing left after cleaning, use original
|
||||||
|
if not cleaned:
|
||||||
|
cleaned = query
|
||||||
|
|
||||||
|
# Synonym expansion
|
||||||
|
sub_queries: list[str] = []
|
||||||
|
for term, expansions in self._synonyms.items():
|
||||||
|
if term in cleaned:
|
||||||
|
for expansion in expansions:
|
||||||
|
if expansion != term:  # skip no-op expansions that would only duplicate the main query
|
||||||
|
sub_queries.append(cleaned.replace(term, expansion))
|
||||||
|
if len(sub_queries) >= self._max_sub_queries:
|
||||||
|
break
|
||||||
|
if len(sub_queries) >= self._max_sub_queries:
|
||||||
|
break
|
||||||
|
|
||||||
|
return TransformedQuery(
|
||||||
|
main_query=cleaned,
|
||||||
|
sub_queries=sub_queries,
|
||||||
|
original_query=query,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NoOpQueryTransformer(QueryTransformerBase):
|
||||||
|
"""Performs no rewriting; returns the query unchanged."""
|
||||||
|
|
||||||
|
async def transform(self, query: str) -> TransformedQuery:
|
||||||
|
return TransformedQuery(
|
||||||
|
main_query=query,
|
||||||
|
sub_queries=[],
|
||||||
|
original_query=query,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_query_transformer(
|
||||||
|
strategy: str = "none",
|
||||||
|
llm_gateway=None,
|
||||||
|
synonyms: dict[str, list[str]] | None = None,
|
||||||
|
max_sub_queries: int = 3,
|
||||||
|
) -> QueryTransformerBase:
|
||||||
|
"""Factory function: create a query transformer for the given strategy."""
|
||||||
|
if strategy == "llm":
|
||||||
|
if llm_gateway is None:
|
||||||
|
logger.warning("LLM strategy requested but no llm_gateway provided, falling back to NoOp")
|
||||||
|
return NoOpQueryTransformer()
|
||||||
|
return LLMQueryTransformer(llm_gateway, max_sub_queries=max_sub_queries)
|
||||||
|
elif strategy == "rule":
|
||||||
|
return RuleQueryTransformer(synonyms=synonyms, max_sub_queries=max_sub_queries)
|
||||||
|
else:
|
||||||
|
return NoOpQueryTransformer()
|
||||||
|
|
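One detail of `RuleQueryTransformer` worth illustrating: the English filler patterns are compiled from `re.escape(w)` alone, so a filler such as "please" would also match inside a longer word like "pleased". The sketch below shows one possible variant that anchors the English fillers on word boundaries. It is an assumption-laden illustration, not the code in this commit: `FILLERS_EN` is a hypothetical subset and `strip_fillers` is a hypothetical helper. It covers English only, since `\b` is not a useful boundary for the Chinese filler list.

```python
import re

# Hypothetical subset of the English filler list, for illustration only.
FILLERS_EN = ["please", "can you", "help me"]

# One alternation anchored on word boundaries, so "please" does not match
# inside "pleased".
_FILLER_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(w) for w in FILLERS_EN) + r")\b",
    re.IGNORECASE,
)


def strip_fillers(query: str) -> str:
    """Remove English filler phrases and collapse leftover whitespace."""
    cleaned = _FILLER_PATTERN.sub("", query)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    # Mirror the fallback in the diff: if nothing is left, keep the original.
    return cleaned or query
```

For example, `strip_fillers("Please help me find pleased customers")` keeps the word "pleased" intact while dropping the two filler phrases.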
@ -3,6 +3,8 @@
|
||||||
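Queries the three memory layers in parallel and merges the results by weighted score.
|
Queries the three memory layers in parallel and merges the results by weighted score.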
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
|
@ -14,10 +16,27 @@ from agentkit.memory.base import Memory, MemoryItem
|
||||||
from agentkit.memory.working import WorkingMemory
|
from agentkit.memory.working import WorkingMemory
|
||||||
from agentkit.memory.episodic import EpisodicMemory
|
from agentkit.memory.episodic import EpisodicMemory
|
||||||
from agentkit.memory.semantic import SemanticMemory
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
|
from agentkit.memory.query_transformer import QueryTransformerBase
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_tokens(text: str) -> int:
|
||||||
|
"""Estimate token count for mixed Chinese/English text.
|
||||||
|
|
||||||
|
Chinese characters typically use 1-2 tokens each.
|
||||||
|
English words typically use 1 token each.
|
||||||
|
"""
|
||||||
|
cjk_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
|
||||||
|
non_cjk = text
|
||||||
|
for c in text:
|
||||||
|
if '\u4e00' <= c <= '\u9fff':
|
||||||
|
non_cjk = non_cjk.replace(c, ' ')
|
||||||
|
word_count = len(non_cjk.split())
|
||||||
|
return cjk_count * 2 + word_count
|
||||||
|
|
||||||
|
|
||||||
class MemoryRetriever:
|
class MemoryRetriever:
|
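||||||
"""Hybrid retriever - queries the three memory layers in parallel and merges the results by weighted score
|
"""Hybrid retriever - queries the three memory layers in parallel and merges the results by weighted score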
|
||||||
|
|
||||||
|
|
@ -34,6 +53,8 @@ class MemoryRetriever:
|
||||||
episodic_memory: EpisodicMemory | None = None,
|
episodic_memory: EpisodicMemory | None = None,
|
||||||
semantic_memory: SemanticMemory | None = None,
|
semantic_memory: SemanticMemory | None = None,
|
||||||
weights: dict[str, float] | None = None,
|
weights: dict[str, float] | None = None,
|
||||||
|
query_transformer: QueryTransformerBase | None = None,
|
||||||
|
context_template: str = "structured",
|
||||||
):
|
):
|
||||||
self._working = working_memory
|
self._working = working_memory
|
||||||
self._episodic = episodic_memory
|
self._episodic = episodic_memory
|
||||||
|
|
@ -43,6 +64,8 @@ class MemoryRetriever:
|
||||||
"episodic": 0.4,
|
"episodic": 0.4,
|
||||||
"semantic": 0.4,
|
"semantic": 0.4,
|
||||||
}
|
}
|
||||||
|
self._query_transformer = query_transformer
|
||||||
|
self._context_template = context_template
|
||||||
|
|
||||||
async def retrieve(
|
async def retrieve(
|
||||||
self,
|
self,
|
||||||
|
|
@ -52,6 +75,62 @@ class MemoryRetriever:
|
||||||
filters: dict[str, Any] | None = None,
|
filters: dict[str, Any] | None = None,
|
||||||
) -> list[MemoryItem]:
|
) -> list[MemoryItem]:
|
||||||
"""Hybrid retrieval across the three memory layers."""
|
"""Hybrid retrieval across the three memory layers."""
|
||||||
|
# Query transformation
|
||||||
|
if self._query_transformer is not None:
|
||||||
|
transformed = await self._query_transformer.transform(query)
|
||||||
|
search_query = transformed.main_query
|
||||||
|
sub_queries = transformed.sub_queries
|
||||||
|
else:
|
||||||
|
search_query = query
|
||||||
|
sub_queries = []
|
||||||
|
|
||||||
|
# Primary search with main query
|
||||||
|
all_items = await self._search_layers(search_query, top_k, filters)
|
||||||
|
|
||||||
|
# Sub-query search in parallel
|
||||||
|
if sub_queries:
|
||||||
|
sub_tasks = [
|
||||||
|
self._search_layers(sq, top_k, filters) for sq in sub_queries
|
||||||
|
]
|
||||||
|
sub_results = await asyncio.gather(*sub_tasks, return_exceptions=True)
|
||||||
|
for result in sub_results:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.warning(f"Sub-query search failed: {result}")
|
||||||
|
continue
|
||||||
|
all_items.extend(result)
|
||||||
|
|
||||||
|
# Deduplicate by key (keep highest score)
|
||||||
|
seen: dict[str, MemoryItem] = {}
|
||||||
|
for item in all_items:
|
||||||
|
if item.key not in seen or item.score > seen[item.key].score:
|
||||||
|
seen[item.key] = item
|
||||||
|
all_items = list(seen.values())
|
||||||
|
|
||||||
|
# Sort by score
|
||||||
|
all_items.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
|
||||||
|
# Token budget management
|
||||||
|
selected = []
|
||||||
|
total_tokens = 0
|
||||||
|
for item in all_items:
|
||||||
|
text = str(item.value)
|
||||||
|
estimated_tokens = _estimate_tokens(text)
|
||||||
|
if total_tokens + estimated_tokens > token_budget:
|
||||||
|
continue
|
||||||
|
selected.append(item)
|
||||||
|
total_tokens += estimated_tokens
|
||||||
|
if len(selected) >= top_k:
|
||||||
|
break
|
||||||
|
|
||||||
|
return selected
|
||||||
|
|
||||||
|
async def _search_layers(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
) -> list[MemoryItem]:
|
||||||
|
"""Search all configured memory layers with a single query"""
|
||||||
tasks = []
|
tasks = []
|
||||||
layer_names = []
|
layer_names = []
|
||||||
|
|
||||||
|
|
@ -82,23 +161,7 @@ class MemoryRetriever:
|
||||||
weighted = replace(item, score=item.score * weight)
|
weighted = replace(item, score=item.score * weight)
|
||||||
all_items.append(weighted)
|
all_items.append(weighted)
|
||||||
|
|
||||||
# Sort by score
|
return all_items
|
||||||
all_items.sort(key=lambda x: x.score, reverse=True)
|
|
||||||
|
|
||||||
# Token budget management
|
|
||||||
selected = []
|
|
||||||
total_tokens = 0
|
|
||||||
for item in all_items:
|
|
||||||
text = str(item.value)
|
|
||||||
estimated_tokens = len(text) // 4
|
|
||||||
if total_tokens + estimated_tokens > token_budget:
|
|
||||||
continue
|
|
||||||
selected.append(item)
|
|
||||||
total_tokens += estimated_tokens
|
|
||||||
if len(selected) >= top_k:
|
|
||||||
break
|
|
||||||
|
|
||||||
return selected
|
|
||||||
|
|
||||||
async def get_context_string(
|
async def get_context_string(
|
||||||
self,
|
self,
|
||||||
|
|
@ -106,13 +169,59 @@ class MemoryRetriever:
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
token_budget: int = 3000,
|
token_budget: int = 3000,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Get the formatted context string."""
|
"""Get the formatted context string.
|
||||||
|
|
||||||
|
Selects the output format based on context_template:
|
||||||
|
- "structured": structured format with source attribution
|
||||||
|
- "flat": plain text concatenation (backward compatible)
|
||||||
|
"""
|
||||||
items = await self.retrieve(query, top_k, token_budget)
|
items = await self.retrieve(query, top_k, token_budget)
|
||||||
parts = []
|
|
||||||
for item in items:
|
if not items:
|
||||||
parts.append(str(item.value))
|
return ""
|
||||||
|
|
||||||
|
if self._context_template == "flat":
|
||||||
|
parts = [str(item.value) for item in items]
|
||||||
return "\n\n".join(parts)
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
# Structured format
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in items:
|
||||||
|
header = self._format_structured_header(item)
|
||||||
|
parts.append(f"{header}\n{item.value}")
|
||||||
|
|
||||||
|
result = "\n\n".join(parts)
|
||||||
|
|
||||||
|
# Respect token budget — truncate if formatted output exceeds it
|
||||||
|
estimated_tokens = _estimate_tokens(result)
|
||||||
|
# Safety limit: also check character count as a ceiling.
|
||||||
|
# This handles edge cases like very long unbroken strings.
|
||||||
|
max_chars = token_budget * 4
|
||||||
|
if estimated_tokens > token_budget or len(result) > max_chars:
|
||||||
|
result = result[:max_chars]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_structured_header(item: MemoryItem) -> str:
|
||||||
|
"""Build a structured header line from the MemoryItem's metadata."""
|
||||||
|
source = item.metadata.get("source", "")
|
||||||
|
score = item.score
|
||||||
|
|
||||||
|
if source == "rag":
|
||||||
|
kb_type = item.metadata.get("kb_type", "知识库")
|
||||||
|
document_title = item.metadata.get("document_title", "未知文档")
|
||||||
|
return f"### 知识库参考 [来源: {kb_type} | 相关度: {score:.2f} | 文档: {document_title}]"
|
||||||
|
elif source == "graph":
|
||||||
|
return f"### 知识图谱 [实体: {item.key} | 相关度: {score:.2f}]"
|
||||||
|
elif source == "episodic":
|
||||||
|
task_type = item.metadata.get("task_type", "未知")
|
||||||
|
return f"### 过往经验 [来源: 情景记忆 | 任务类型: {task_type}]"
|
||||||
|
elif source == "working":
|
||||||
|
return f"### 工作记忆 [键: {item.key}]"
|
||||||
|
else:
|
||||||
|
return f"### 参考 [来源: {source} | 相关度: {score:.2f}]"
|
||||||
|
|
||||||
async def store_episode(
|
async def store_episode(
|
||||||
self, key: str, value: Any, metadata: dict[str, Any] | None = None
|
self, key: str, value: Any, metadata: dict[str, Any] | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -123,3 +232,59 @@ class MemoryRetriever:
|
||||||
"""
|
"""
|
||||||
if self._episodic is not None:
|
if self._episodic is not None:
|
||||||
await self._episodic.store(key, value, metadata)
|
await self._episodic.store(key, value, metadata)
|
||||||
|
|
||||||
|
def create_retrieve_tool(self, max_calls: int = 3) -> Tool | None:
|
||||||
|
"""Create a retrieve_knowledge tool if semantic memory is configured.
|
||||||
|
|
||||||
|
Returns None if no semantic memory is available (tool not applicable).
|
||||||
|
"""
|
||||||
|
if self._semantic is None:
|
||||||
|
return None
|
||||||
|
return RetrieveKnowledgeTool(retriever=self, max_calls=max_calls)
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieveKnowledgeTool(Tool):
|
||||||
|
"""Built-in tool for knowledge base retrieval during ReAct reasoning."""
|
||||||
|
|
||||||
|
def __init__(self, retriever: MemoryRetriever, max_calls: int = 3):
|
||||||
|
super().__init__(
|
||||||
|
name="retrieve_knowledge",
|
||||||
|
description="Search the knowledge base for additional information. Use this tool when you need more context, facts, or details to answer a question accurately.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The search query to find relevant information in the knowledge base",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._retriever = retriever
|
||||||
|
self._max_calls = max_calls
|
||||||
|
self._call_count = 0
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
query = kwargs.get("query", "")
|
||||||
|
if not query:
|
||||||
|
return {"error": "query is required", "results": []}
|
||||||
|
|
||||||
|
if self._call_count >= self._max_calls:
|
||||||
|
return {"error": f"Maximum retrieval calls ({self._max_calls}) reached", "results": []}
|
||||||
|
|
||||||
|
self._call_count += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
items = await self._retriever.retrieve(query, top_k=5)
|
||||||
|
results = []
|
||||||
|
for item in items:
|
||||||
|
results.append({
|
||||||
|
"content": item.value,
|
||||||
|
"score": item.score,
|
||||||
|
"source": item.metadata.get("source", "unknown"),
|
||||||
|
"document_title": item.metadata.get("document_title", ""),
|
||||||
|
})
|
||||||
|
return {"query": query, "results": results, "call_count": self._call_count}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e), "results": []}
|
||||||
|
|
|
||||||
|
|
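The `_estimate_tokens` helper in this file counts each CJK character as two tokens and each whitespace-separated run of other text as one token. The same rule can be written with a precompiled regular expression, which avoids calling `str.replace` once per CJK character. This is a sketch of an equivalent formulation under that reading of the diff, not a replacement proposed by the commit; the name `estimate_tokens` is hypothetical.

```python
import re

# CJK Unified Ideographs, the same range checked in the diff.
_CJK = re.compile(r"[\u4e00-\u9fff]")


def estimate_tokens(text: str) -> int:
    """Estimate tokens for mixed Chinese/English text.

    Each CJK character counts as 2 tokens; each whitespace-separated
    run of non-CJK text counts as 1 token.
    """
    cjk_count = len(_CJK.findall(text))
    # Replace CJK characters with spaces so they also act as word separators.
    word_count = len(_CJK.sub(" ", text).split())
    return cjk_count * 2 + word_count
```

Both versions are heuristics; a real tokenizer would give different counts, so the budget check should be treated as approximate.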
@ -22,16 +22,28 @@ class SemanticMemory(Memory):
|
||||||
rag_service: Any = None,
|
rag_service: Any = None,
|
||||||
graph_service: Any = None,
|
graph_service: Any = None,
|
||||||
knowledge_base_ids: list[str] | None = None,
|
knowledge_base_ids: list[str] | None = None,
|
||||||
|
search_mode: str = "standard",
|
||||||
|
use_rerank: bool = True,
|
||||||
|
use_compression: bool = False,
|
||||||
|
kb_weights: dict[str, float] | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
rag_service: RAG retrieval service (must provide a search method)
|
rag_service: RAG retrieval service (must provide a search method)
|
||||||
graph_service: knowledge graph service (must provide a query method)
|
graph_service: knowledge graph service (must provide a query method)
|
||||||
knowledge_base_ids: default list of knowledge base IDs to search
|
knowledge_base_ids: default list of knowledge base IDs to search
|
||||||
|
search_mode: retrieval mode, "standard" or "enhanced"
|
||||||
|
use_rerank: enable reranking (only takes effect in enhanced mode)
|
||||||
|
use_compression: enable context compression (only takes effect in enhanced mode)
|
||||||
|
kb_weights: mapping of knowledge base weights; keys are knowledge base IDs and values are weight multipliers
|
||||||
"""
|
"""
|
||||||
self._rag_service = rag_service
|
self._rag_service = rag_service
|
||||||
self._graph_service = graph_service
|
self._graph_service = graph_service
|
||||||
self._knowledge_base_ids = knowledge_base_ids or []
|
self._knowledge_base_ids = knowledge_base_ids or []
|
||||||
|
self._search_mode = search_mode
|
||||||
|
self._use_rerank = use_rerank
|
||||||
|
self._use_compression = use_compression
|
||||||
|
self._kb_weights = kb_weights
|
||||||
|
|
||||||
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
async def store(self, key: str, value: Any, metadata: dict[str, Any] | None = None) -> None:
|
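||||||
"""Semantic Memory is usually read-only; writes are delegated to the RAG service's ingest method."""
|
"""Semantic Memory is usually read-only; writes are delegated to the RAG service's ingest method."""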
|
||||||
|
|
@ -52,17 +64,32 @@ class SemanticMemory(Memory):
|
||||||
if self._rag_service:
|
if self._rag_service:
|
||||||
try:
|
try:
|
||||||
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
|
kb_ids = (filters or {}).get("knowledge_base_ids", self._knowledge_base_ids)
|
||||||
|
if self._search_mode == "enhanced" and hasattr(self._rag_service, "enhanced_search"):
|
||||||
|
results = await self._rag_service.enhanced_search(
|
||||||
|
query,
|
||||||
|
knowledge_base_ids=kb_ids,
|
||||||
|
top_k=top_k,
|
||||||
|
use_rerank=self._use_rerank,
|
||||||
|
use_compression=self._use_compression,
|
||||||
|
)
|
||||||
|
else:
|
||||||
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k)
|
results = await self._rag_service.search(query, knowledge_base_ids=kb_ids, top_k=top_k)
|
||||||
for r in results:
|
for r in results:
|
||||||
|
kb_id = r.get("knowledge_base_id", "")
|
||||||
|
score = r.get("score", 0.0)
|
||||||
|
# Apply per-KB weights
|
||||||
|
if self._kb_weights and kb_id in self._kb_weights:
|
||||||
|
score *= self._kb_weights[kb_id]
|
||||||
items.append(MemoryItem(
|
items.append(MemoryItem(
|
||||||
key=r.get("id", ""),
|
key=r.get("id", ""),
|
||||||
value=r.get("content", ""),
|
value=r.get("content", ""),
|
||||||
metadata={
|
metadata={
|
||||||
"source": r.get("source", "rag"),
|
"source": r.get("source", "rag"),
|
||||||
"score": r.get("score", 0.0),
|
"score": score,
|
||||||
"document_id": r.get("document_id"),
|
"document_id": r.get("document_id"),
|
||||||
|
"knowledge_base_id": kb_id,
|
||||||
},
|
},
|
||||||
score=r.get("score", 0.0),
|
score=score,
|
||||||
))
|
))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"RAG search failed: {e}")
|
logger.error(f"RAG search failed: {e}")
|
||||||
|
|
|
||||||
|
|
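The per-KB weighting added to `SemanticMemory.search()` multiplies each result's score by the weight configured for its knowledge base, leaving unlisted knowledge bases unchanged. A minimal standalone sketch of that idea follows. `apply_kb_weights` is a hypothetical helper, the weights reuse the values from the commit message's config example, and the re-sort at the end is an addition for illustration (in the diff, ordering is handled later by `MemoryRetriever`).

```python
from __future__ import annotations

from typing import Any


def apply_kb_weights(
    results: list[dict[str, Any]],
    kb_weights: dict[str, float] | None,
) -> list[dict[str, Any]]:
    """Return copies of results with scores scaled by their knowledge base weight."""
    weights = kb_weights or {}
    weighted = [
        {
            **r,
            # Knowledge bases without a configured weight keep a multiplier of 1.0.
            "score": float(r.get("score", 0.0))
            * weights.get(r.get("knowledge_base_id", ""), 1.0),
        }
        for r in results
    ]
    # Weighting can change the relative order across knowledge bases.
    weighted.sort(key=lambda r: r["score"], reverse=True)
    return weighted
```

With weights `{"industry-kb-id": 1.2, "enterprise-kb-id": 0.8}`, an industry result scored 0.7 would rank above an enterprise result scored 0.9.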
@ -189,6 +189,10 @@ def create_app(
|
||||||
semantic = SemanticMemory(
|
semantic = SemanticMemory(
|
||||||
rag_service=rag_service,
|
rag_service=rag_service,
|
||||||
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
|
||||||
|
search_mode=sem_conf.get("search_mode", "standard"),
|
||||||
|
use_rerank=sem_conf.get("use_rerank", True),
|
||||||
|
use_compression=sem_conf.get("use_compression", False),
|
||||||
|
kb_weights=sem_conf.get("kb_weights"),
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_retriever = MemoryRetriever(
|
memory_retriever = MemoryRetriever(
|
||||||
|
|
@ -197,6 +201,12 @@ def create_app(
|
||||||
semantic_memory=semantic,
|
semantic_memory=semantic,
|
||||||
)
|
)
|
||||||
app.state.memory_retriever = memory_retriever
|
app.state.memory_retriever = memory_retriever
|
||||||
|
|
||||||
|
# Auto-register retrieve_knowledge tool if semantic memory is configured
|
||||||
|
if memory_retriever:
|
||||||
|
retrieve_tool = memory_retriever.create_retrieve_tool()
|
||||||
|
if retrieve_tool:
|
||||||
|
app.state.retrieve_knowledge_tool = retrieve_tool
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger(__name__).warning(f"Failed to initialize memory components: {e}")
|
logging.getLogger(__name__).warning(f"Failed to initialize memory components: {e}")
|
||||||
|
|
|
||||||
|
|
@ -470,3 +470,321 @@ memory:
|
||||||
config = ServerConfig.from_yaml(f.name)
|
config = ServerConfig.from_yaml(f.name)
|
||||||
|
|
||||||
assert config.memory["semantic"]["api_key"] == "sk-default"
|
assert config.memory["semantic"]["api_key"] == "sk-default"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HttpRAGService enhanced_search tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestHttpRAGServiceEnhancedSearch:
|
||||||
|
"""HttpRAGService.enhanced_search: enhanced semantic search."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def svc(self):
|
||||||
|
return HttpRAGService(
|
||||||
|
base_url="http://localhost:8000/api/knowledge",
|
||||||
|
api_key="test-key",
|
||||||
|
knowledge_base_ids=["kb-1", "kb-2"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enhanced_search_single_kb(self, svc):
|
||||||
|
"""Single-KB enhanced search; verifies the payload includes use_rerank and use_compression."""
|
||||||
|
svc._knowledge_base_ids = ["kb-1"]
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"results": [
|
||||||
|
{"chunk_id": "c1", "content": "AI 趋势", "score": 0.95, "document_id": "d1"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
results = await svc.enhanced_search("AI 趋势", top_k=5)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["content"] == "AI 趋势"
|
||||||
|
assert results[0]["score"] == 0.95
|
||||||
|
|
||||||
|
# Verify endpoint and payload
|
||||||
|
call_args = mock_client.post.call_args
|
||||||
|
assert call_args[0][0] == "/bases/kb-1/retrieve"
|
||||||
|
payload = call_args[1]["json"]
|
||||||
|
assert payload["query"] == "AI 趋势"
|
||||||
|
assert payload["top_k"] == 5
|
||||||
|
assert payload["use_rerank"] is True
|
||||||
|
assert payload["use_compression"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enhanced_search_multiple_kbs(self, svc):
|
||||||
|
"""Multi-KB enhanced search; results are merged and sorted by score in descending order."""
|
||||||
|
# First KB returns one result
|
||||||
|
resp1 = MagicMock()
|
||||||
|
resp1.status_code = 200
|
||||||
|
resp1.raise_for_status = MagicMock()
|
||||||
|
resp1.json.return_value = {
|
||||||
|
"results": [
|
||||||
|
{"chunk_id": "c1", "content": "KB1 结果", "score": 0.8, "document_id": "d1"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Second KB returns one result with higher score
|
||||||
|
resp2 = MagicMock()
|
||||||
|
resp2.status_code = 200
|
||||||
|
resp2.raise_for_status = MagicMock()
|
||||||
|
resp2.json.return_value = {
|
||||||
|
"results": [
|
||||||
|
{"chunk_id": "c2", "content": "KB2 结果", "score": 0.95, "document_id": "d2"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(side_effect=[resp1, resp2])
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
results = await svc.enhanced_search("test query", top_k=5)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
# Merged results sorted by score descending
|
||||||
|
assert results[0]["content"] == "KB2 结果"
|
||||||
|
assert results[0]["score"] == 0.95
|
||||||
|
assert results[1]["content"] == "KB1 结果"
|
||||||
|
assert results[1]["score"] == 0.8
|
||||||
|
|
||||||
|
# Verify both KB endpoints were called
|
||||||
|
calls = mock_client.post.call_args_list
|
||||||
|
assert calls[0][0][0] == "/bases/kb-1/retrieve"
|
||||||
|
assert calls[1][0][0] == "/bases/kb-2/retrieve"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enhanced_search_404_fallback(self, svc):
|
||||||
|
"""A 404 response falls back to the standard search method."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 404
|
||||||
|
mock_resp.text = "Not Found"
|
||||||
|
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
"404", request=MagicMock(), response=mock_resp
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
# Mock the standard search method
|
||||||
|
svc.search = AsyncMock(return_value=[{"id": "fallback", "content": "fallback result", "score": 0.5}])
|
||||||
|
|
||||||
|
results = await svc.enhanced_search("test query")
|
||||||
|
|
||||||
|
# Should have fallen back to search()
|
||||||
|
svc.search.assert_called_once_with("test query", knowledge_base_ids=["kb-1", "kb-2"], top_k=5)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["id"] == "fallback"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enhanced_search_http_error(self, svc):
|
||||||
|
"""Non-404 HTTP errors return an empty list."""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 500
|
||||||
|
mock_resp.text = "Internal Server Error"
|
||||||
|
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
"500", request=MagicMock(), response=mock_resp
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
results = await svc.enhanced_search("test query")
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enhanced_search_with_compression(self, svc):
|
||||||
|
"""Verifies that use_compression: true is in the payload."""
|
||||||
|
svc._knowledge_base_ids = ["kb-1"]
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = {"results": []}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
await svc.enhanced_search("test", use_compression=True)
|
||||||
|
|
||||||
|
payload = mock_client.post.call_args[1]["json"]
|
||||||
|
assert payload["use_compression"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enhanced_search_without_rerank(self, svc):
|
||||||
|
"""Verifies that use_rerank: false is in the payload."""
|
||||||
|
svc._knowledge_base_ids = ["kb-1"]
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.raise_for_status = MagicMock()
|
||||||
|
mock_resp.json.return_value = {"results": []}
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||||
|
svc._get_client = MagicMock(return_value=mock_client)
|
||||||
|
|
||||||
|
await svc.enhanced_search("test", use_rerank=False)
|
||||||
|
|
||||||
|
payload = mock_client.post.call_args[1]["json"]
|
||||||
|
assert payload["use_rerank"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SemanticMemory enhanced search mode tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemanticMemoryEnhancedSearch:
|
||||||
|
"""SemanticMemory search_mode: enhanced search mode."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_mode_enhanced(self):
|
||||||
|
"""Calls enhanced_search when search_mode="enhanced"."""
|
||||||
|
rag = HttpRAGService(
|
||||||
|
base_url="http://localhost:8000/api/knowledge",
|
||||||
|
knowledge_base_ids=["kb-1"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock enhanced_search
|
||||||
|
rag.enhanced_search = AsyncMock(return_value=[
|
||||||
|
{"id": "c1", "content": "enhanced result", "score": 0.9, "source": "rag", "document_id": "d1"},
|
||||||
|
])
|
||||||
|
|
||||||
|
semantic = SemanticMemory(
|
||||||
|
rag_service=rag,
|
||||||
|
knowledge_base_ids=["kb-1"],
|
||||||
|
search_mode="enhanced",
|
||||||
|
use_rerank=True,
|
||||||
|
use_compression=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
items = await semantic.search("test query", top_k=3)
|
||||||
|
|
||||||
|
rag.enhanced_search.assert_called_once_with(
|
||||||
|
"test query",
|
||||||
|
knowledge_base_ids=["kb-1"],
|
||||||
|
top_k=3,
|
||||||
|
use_rerank=True,
|
||||||
|
use_compression=False,
|
||||||
|
)
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0].value == "enhanced result"
|
||||||
|
|
||||||
|
    @pytest.mark.asyncio
    async def test_search_mode_standard(self):
        """With search_mode="standard", the standard search is called."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )

        # Mock standard search
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.raise_for_status = MagicMock()
        mock_resp.json.return_value = {
            "results": [
                {"chunk_id": "c1", "content": "standard result", "score": 0.8, "document_id": "d1"},
            ]
        }

        mock_client = AsyncMock()
        mock_client.post = AsyncMock(return_value=mock_resp)
        rag._get_client = MagicMock(return_value=mock_client)

        semantic = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="standard",
        )

        items = await semantic.search("test query", top_k=3)

        assert len(items) == 1
        assert items[0].value == "standard result"
        # Verify standard /search endpoint was called, not /bases/{kb_id}/retrieve
        call_args = mock_client.post.call_args
        assert call_args[0][0] == "/search"
|
||||||
|
|
||||||
|
    @pytest.mark.asyncio
    async def test_search_mode_enhanced_fallback(self):
        """With search_mode="enhanced" but no enhanced_search on rag_service, fall back to search."""

        class SimpleRAGService:
            """A RAG service without enhanced_search"""

            async def search(self, query, knowledge_base_ids=None, top_k=5):
                return [{"id": "c1", "content": "simple result", "score": 0.7, "source": "rag", "document_id": "d1"}]

        rag = SimpleRAGService()
        semantic = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
        )

        items = await semantic.search("test query", top_k=3)

        assert len(items) == 1
        assert items[0].value == "simple result"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Config enhanced search tests
# ---------------------------------------------------------------------------


class TestConfigEnhancedSearch:
    """ServerConfig parsing of enhanced-search settings."""

    def test_config_search_mode(self):
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1"],
                    "search_mode": "enhanced",
                },
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"]["search_mode"] == "enhanced"

    def test_config_use_rerank(self):
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1"],
                    "use_rerank": False,
                    "use_compression": True,
                },
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"]["use_rerank"] is False
        assert config.memory["semantic"]["use_compression"] is True
|
||||||
|
|
|
||||||
|
|
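The fallback test above implies that the caller checks whether the RAG service exposes `enhanced_search` before using it. A minimal sketch of that dispatch follows, assuming a free function named `dispatch_search` and two stand-in service classes; all three names are hypothetical and are not part of this commit. The actual logic in `SemanticMemory.search` may differ.

```python
import asyncio
from typing import Any


async def dispatch_search(
    rag_service: Any,
    query: str,
    *,
    knowledge_base_ids: list[str],
    top_k: int = 5,
    search_mode: str = "standard",
    use_rerank: bool = True,
    use_compression: bool = False,
) -> list[dict]:
    """Use enhanced_search when requested and available, otherwise search."""
    enhanced = getattr(rag_service, "enhanced_search", None)
    if search_mode == "enhanced" and callable(enhanced):
        return await enhanced(
            query,
            knowledge_base_ids=knowledge_base_ids,
            top_k=top_k,
            use_rerank=use_rerank,
            use_compression=use_compression,
        )
    # Fallback: the service only implements the standard search contract.
    return await rag_service.search(
        query, knowledge_base_ids=knowledge_base_ids, top_k=top_k
    )


class SimpleService:
    """Hypothetical service with only the standard search method."""

    async def search(self, query, knowledge_base_ids=None, top_k=5):
        return [{"id": "c1", "content": "simple result"}]


class EnhancedService(SimpleService):
    """Hypothetical service that also offers enhanced_search."""

    async def enhanced_search(self, query, knowledge_base_ids=None, top_k=5,
                              use_rerank=True, use_compression=False):
        return [{"id": "c2", "content": "enhanced result"}]


fallback = asyncio.run(
    dispatch_search(SimpleService(), "q", knowledge_base_ids=["kb-1"], search_mode="enhanced")
)
enhanced = asyncio.run(
    dispatch_search(EnhancedService(), "q", knowledge_base_ids=["kb-1"], search_mode="enhanced")
)
```

Probing the method with `getattr` keeps services that only implement `search` working when the configuration asks for enhanced mode, which is the behavior `test_search_mode_enhanced_fallback` checks.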
@@ -93,7 +93,7 @@ class TestMemoryContextInjection:
|
||||||
system_msg = messages_sent[0]
|
system_msg = messages_sent[0]
|
||||||
assert system_msg["role"] == "system"
|
assert system_msg["role"] == "system"
|
||||||
assert "You are a helpful assistant." in system_msg["content"]
|
assert "You are a helpful assistant." in system_msg["content"]
|
||||||
assert "Relevant Past Experience" in system_msg["content"]
|
assert "参考信息" in system_msg["content"]
|
||||||
assert "Previous task result: success" in system_msg["content"]
|
assert "Previous task result: success" in system_msg["content"]
|
||||||
|
|
||||||
async def test_memory_context_used_as_system_prompt_when_none(self):
|
async def test_memory_context_used_as_system_prompt_when_none(self):
|
||||||
|
|
@@ -113,7 +113,7 @@ class TestMemoryContextInjection:
|
||||||
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||||
system_msg = messages_sent[0]
|
system_msg = messages_sent[0]
|
||||||
assert system_msg["role"] == "system"
|
assert system_msg["role"] == "system"
|
||||||
assert "Relevant Past Experience" in system_msg["content"]
|
assert "参考信息" in system_msg["content"]
|
||||||
assert "Past context only" in system_msg["content"]
|
assert "Past context only" in system_msg["content"]
|
||||||
|
|
||||||
async def test_no_memory_context_when_retriever_is_none(self):
|
async def test_no_memory_context_when_retriever_is_none(self):
|
||||||
|
|
@@ -132,7 +132,7 @@ class TestMemoryContextInjection:
|
||||||
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||||
system_msg = messages_sent[0]
|
system_msg = messages_sent[0]
|
||||||
assert system_msg["content"] == "You are a helper."
|
assert system_msg["content"] == "You are a helper."
|
||||||
assert "Relevant Past Experience" not in system_msg["content"]
|
assert "参考信息" not in system_msg["content"]
|
||||||
|
|
||||||
async def test_empty_memory_context_not_injected(self):
|
async def test_empty_memory_context_not_injected(self):
|
||||||
"""When the memory context is an empty string, nothing is injected."""
|
"""When the memory context is an empty string, nothing is injected."""
|
||||||
|
|
@@ -152,7 +152,7 @@ class TestMemoryContextInjection:
|
||||||
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||||
system_msg = messages_sent[0]
|
system_msg = messages_sent[0]
|
||||||
assert system_msg["content"] == "You are a helper."
|
assert system_msg["content"] == "You are a helper."
|
||||||
assert "Relevant Past Experience" not in system_msg["content"]
|
assert "参考信息" not in system_msg["content"]
|
||||||
|
|
||||||
|
|
||||||
# ── Test: Memory retrieval failure doesn't break execution ──────────
|
# ── Test: Memory retrieval failure doesn't break execution ──────────
|
||||||
|
|
@@ -183,7 +183,7 @@ class TestMemoryRetrievalFailure:
|
||||||
call_args = gateway.chat.call_args
|
call_args = gateway.chat.call_args
|
||||||
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
|
||||||
system_msg = messages_sent[0]
|
system_msg = messages_sent[0]
|
||||||
assert "Relevant Past Experience" not in system_msg["content"]
|
assert "参考信息" not in system_msg["content"]
|
||||||
|
|
||||||
|
|
||||||
# ── Test: Task result stored in episodic memory ──────────
|
# ── Test: Task result stored in episodic memory ──────────
|
||||||
|
|
@@ -428,3 +428,275 @@ class TestConfigDrivenAgentMemory:
|
||||||
agent = ConfigDrivenAgent(config=config)
|
agent = ConfigDrivenAgent(config=config)
|
||||||
# Either retriever was created or gracefully failed
|
# Either retriever was created or gracefully failed
|
||||||
# The key is that no exception is raised
|
# The key is that no exception is raised
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: Structured Context Injection ──────────


class TestStructuredContextInjection:
    """U3: structured context injection tests."""

    async def test_structured_format_with_rag_results(self):
        """Structured format: RAG results carry a knowledge-base reference header."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        # Mock retrieve to return RAG items
        rag_item = MemoryItem(
            key="doc-1",
            value="AI行业在2025年呈现三大趋势...",
            metadata={"source": "rag", "kb_type": "行业库", "document_title": "AI行业趋势报告"},
            score=0.92,
        )
        retriever.retrieve = AsyncMock(return_value=[rag_item])

        result = await retriever.get_context_string(query="AI trends", top_k=5, token_budget=3000)

        # Header fields: source (来源), relevance (相关度), document (文档)
        assert "### 知识库参考 [来源: 行业库 | 相关度: 0.92 | 文档: AI行业趋势报告]" in result
        assert "AI行业在2025年呈现三大趋势..." in result
|
||||||
|
|
||||||
|
    async def test_structured_format_with_episodic_results(self):
        """Structured format: episodic-memory results carry a past-experience header."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        episodic_item = MemoryItem(
            key="task:seo-001",
            value="上次分析竞品SEO策略时发现...",
            metadata={"source": "episodic", "task_type": "seo_analysis"},
            score=0.85,
        )
        retriever.retrieve = AsyncMock(return_value=[episodic_item])

        result = await retriever.get_context_string(query="SEO analysis", top_k=5, token_budget=3000)

        # Header fields: source (来源) is episodic memory (情景记忆), task type (任务类型)
        assert "### 过往经验 [来源: 情景记忆 | 任务类型: seo_analysis]" in result
        assert "上次分析竞品SEO策略时发现..." in result
|
||||||
|
|
||||||
|
    async def test_structured_format_with_mixed_sources(self):
        """Structured format: different sources produce different headers."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        items = [
            MemoryItem(
                key="doc-1",
                value="RAG content here",
                metadata={"source": "rag", "kb_type": "行业库", "document_title": "报告A"},
                score=0.90,
            ),
            MemoryItem(
                key="task:ep-1",
                value="Episodic content here",
                metadata={"source": "episodic", "task_type": "analysis"},
                score=0.80,
            ),
            MemoryItem(
                key="entity-1",
                value="Graph content here",
                metadata={"source": "graph"},
                score=0.75,
            ),
            MemoryItem(
                key="ctx-1",
                value="Working memory content",
                metadata={"source": "working"},
                score=0.60,
            ),
            MemoryItem(
                key="other-1",
                value="Unknown source content",
                metadata={"source": "custom"},
                score=0.50,
            ),
        ]
        retriever.retrieve = AsyncMock(return_value=items)

        result = await retriever.get_context_string(query="test", top_k=5, token_budget=3000)

        assert "### 知识库参考" in result  # knowledge-base reference (rag)
        assert "### 过往经验" in result  # past experience (episodic)
        assert "### 知识图谱" in result  # knowledge graph (graph)
        assert "### 工作记忆" in result  # working memory (working)
        assert "### 参考 [来源: custom" in result  # generic reference for unknown sources
|
||||||
|
|
||||||
|
    async def test_flat_format_backward_compatible(self):
        """Flat format: plain-text concatenation with no header lines."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="flat")

        items = [
            MemoryItem(
                key="doc-1",
                value="First result",
                metadata={"source": "rag"},
                score=0.9,
            ),
            MemoryItem(
                key="ep-1",
                value="Second result",
                metadata={"source": "episodic"},
                score=0.8,
            ),
        ]
        retriever.retrieve = AsyncMock(return_value=items)

        result = await retriever.get_context_string(query="test", top_k=5, token_budget=3000)

        # No structured headers
        assert "### 知识库参考" not in result
        assert "### 过往经验" not in result
        # Just plain text values joined by double newline
        assert "First result" in result
        assert "Second result" in result
        assert result == "First result\n\nSecond result"
|
||||||
|
|
||||||
|
    async def test_token_budget_truncation_in_structured_format(self):
        """Structured format: overlong results are truncated to fit the token budget."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        # Create a very long content item
        long_value = "A" * 20000
        item = MemoryItem(
            key="doc-1",
            value=long_value,
            metadata={"source": "rag", "kb_type": "知识库", "document_title": "大文档"},
            score=0.9,
        )
        retriever.retrieve = AsyncMock(return_value=[item])

        # Very small token budget
        result = await retriever.get_context_string(query="test", top_k=5, token_budget=100)

        # Result should be truncated (100 tokens * 4 chars = 400 chars max)
        assert len(result) <= 400
|
||||||
|
|
||||||
|
    async def test_empty_results_returns_empty_string(self):
        """Empty results: an empty string is returned."""
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")
        retriever.retrieve = AsyncMock(return_value=[])

        result = await retriever.get_context_string(query="test", top_k=5, token_budget=3000)

        assert result == ""
|
||||||
|
|
||||||
|
    async def test_context_template_parameter(self):
        """context_template parameter: flat mode yields plain text, structured mode adds headers."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        # Test with flat template
        retriever_flat = MemoryRetriever(context_template="flat")
        item = MemoryItem(
            key="doc-1",
            value="Flat content",
            metadata={"source": "rag"},
            score=0.9,
        )
        retriever_flat.retrieve = AsyncMock(return_value=[item])

        result_flat = await retriever_flat.get_context_string(query="test")
        assert "### 知识库参考" not in result_flat
        assert "Flat content" in result_flat

        # Test with structured template (default)
        retriever_structured = MemoryRetriever(context_template="structured")
        retriever_structured.retrieve = AsyncMock(return_value=[item])

        result_structured = await retriever_structured.get_context_string(query="test")
        assert "### 知识库参考" in result_structured
|
||||||
|
|
||||||
|
    async def test_structured_format_default_kb_type(self):
        """Structured format: RAG results without kb_type use the default value."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        item = MemoryItem(
            key="doc-1",
            value="Content without kb_type",
            metadata={"source": "rag", "document_title": "报告B"},
            score=0.88,
        )
        retriever.retrieve = AsyncMock(return_value=[item])

        result = await retriever.get_context_string(query="test")
        # The default kb_type label is "知识库" (knowledge base)
        assert "### 知识库参考 [来源: 知识库 | 相关度: 0.88 | 文档: 报告B]" in result
|
||||||
|
|
||||||
|
    async def test_structured_format_default_task_type(self):
        """Structured format: episodic results without task_type use the default value."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        item = MemoryItem(
            key="ep-1",
            value="Content without task_type",
            metadata={"source": "episodic"},
            score=0.75,
        )
        retriever.retrieve = AsyncMock(return_value=[item])

        result = await retriever.get_context_string(query="test")
        # The default task_type label is "未知" (unknown)
        assert "### 过往经验 [来源: 情景记忆 | 任务类型: 未知]" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: ReAct Context Injection Format ──────────


class TestReActContextInjectionFormat:
    """U3: ReActEngine uses the new heading format."""

    async def test_react_uses_new_heading(self):
        """ReActEngine uses the '## 参考信息' heading (not the old heading)."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("Some context data")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        call_args = gateway.chat.call_args
        messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
        system_msg = messages_sent[0]
        assert "## 参考信息" in system_msg["content"]
        assert "Relevant Past Experience" not in system_msg["content"]

    async def test_react_new_heading_when_no_system_prompt(self):
        """Without a system_prompt, the new heading opens the system prompt."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("Context only")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        call_args = gateway.chat.call_args
        messages_sent = call_args.kwargs.get("messages") or call_args[1].get("messages")
        system_msg = messages_sent[0]
        assert system_msg["content"].startswith("## 参考信息")
        assert "Relevant Past Experience" not in system_msg["content"]
|
||||||
|
|
|
||||||
|
|
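The header strings asserted in `TestStructuredContextInjection` suggest a per-source header builder. Below is a minimal sketch that reproduces exactly those strings, assuming a helper named `format_header` (hypothetical, not taken from this commit). Omitting the document field when `document_title` is missing is also an assumption, since the tests only check the header prefix in that case.

```python
from typing import Any


def format_header(metadata: dict[str, Any], score: float) -> str:
    """Build a source-attribution header for one retrieved item."""
    source = metadata.get("source", "")
    if source == "rag":
        # Fields: source (来源), relevance (相关度), document (文档).
        parts = [
            f"来源: {metadata.get('kb_type', '知识库')}",
            f"相关度: {score:.2f}",
        ]
        title = metadata.get("document_title")
        if title:  # assumption: skip the field when no title is known
            parts.append(f"文档: {title}")
        return f"### 知识库参考 [{' | '.join(parts)}]"
    if source == "episodic":
        task_type = metadata.get("task_type", "未知")
        return f"### 过往经验 [来源: 情景记忆 | 任务类型: {task_type}]"
    if source == "graph":
        return "### 知识图谱"
    if source == "working":
        return "### 工作记忆"
    # Any other source gets a generic reference header.
    return f"### 参考 [来源: {source}]"


rag_header = format_header(
    {"source": "rag", "kb_type": "行业库", "document_title": "AI行业趋势报告"}, 0.92
)
episodic_header = format_header({"source": "episodic"}, 0.75)
custom_header = format_header({"source": "custom"}, 0.50)
```

Attaching the knowledge-base type, score, and document title to each header gives the LLM a basis for weighing one retrieved passage against another, which is the stated goal of U3.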
@@ -0,0 +1,335 @@
|
||||||
"""Unit tests for QueryTransformer."""

import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.query_transformer import (
    LLMQueryTransformer,
    NoOpQueryTransformer,
    QueryTransformerBase,
    RuleQueryTransformer,
    TransformedQuery,
    create_query_transformer,
)


# ── In-memory Memory implementation (for tests) ────────────────────


class InMemoryMemory(Memory):
    """In-memory Memory implementation used for tests."""

    def __init__(self):
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        self._store[key] = MemoryItem(
            key=key, value=value, metadata=metadata or {}, score=1.0
        )

    async def retrieve(self, key: str) -> MemoryItem | None:
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        # Case-insensitive substring match against both value and key
        results = []
        for item in self._store.values():
            if query.lower() in str(item.value).lower() or query.lower() in item.key.lower():
                results.append(item)
        return results[:top_k]

    async def delete(self, key: str) -> bool:
        return self._store.pop(key, None) is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestTransformedQuery ──────────────────────────────────


class TestTransformedQuery:
    """Tests for the TransformedQuery dataclass."""

    def test_creation_and_field_access(self):
        tq = TransformedQuery(
            main_query="SEO策略",
            sub_queries=["搜索引擎优化策略"],
            original_query="帮我分析一下SEO策略",
        )
        assert tq.main_query == "SEO策略"
        assert tq.sub_queries == ["搜索引擎优化策略"]
        assert tq.original_query == "帮我分析一下SEO策略"

    def test_empty_sub_queries(self):
        tq = TransformedQuery(main_query="AI趋势", sub_queries=[], original_query="AI趋势")
        assert tq.sub_queries == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestLLMQueryTransformer ───────────────────────────────


class TestLLMQueryTransformer:
    """Tests for LLMQueryTransformer."""

    async def test_successful_transformation(self):
        """The LLM returns valid JSON; verify main_query and sub_queries."""
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(
            content=json.dumps({
                "main_query": "SEO optimization strategies",
                "sub_queries": ["search engine ranking", "keyword research"],
            })
        )

        transformer = LLMQueryTransformer(gateway)
        result = await transformer.transform("How to improve SEO?")

        assert result.main_query == "SEO optimization strategies"
        assert len(result.sub_queries) == 2
        assert "search engine ranking" in result.sub_queries
        assert result.original_query == "How to improve SEO?"

    async def test_llm_error_fallback(self):
        """The LLM raises an exception; fall back to the original query."""
        gateway = AsyncMock()
        gateway.chat.side_effect = Exception("LLM service unavailable")

        transformer = LLMQueryTransformer(gateway)
        result = await transformer.transform("test query")

        assert result.main_query == "test query"
        assert result.sub_queries == []
        assert result.original_query == "test query"

    async def test_invalid_json_response(self):
        """The LLM returns non-JSON text; fall back to the original query."""
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(content="This is not JSON")

        transformer = LLMQueryTransformer(gateway)
        result = await transformer.transform("test query")

        assert result.main_query == "test query"
        assert result.sub_queries == []

    async def test_max_sub_queries_limit(self):
        """The LLM returns 5 sub_queries but max_sub_queries=3, so only 3 are kept."""
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(
            content=json.dumps({
                "main_query": "query",
                "sub_queries": ["sq1", "sq2", "sq3", "sq4", "sq5"],
            })
        )

        transformer = LLMQueryTransformer(gateway, max_sub_queries=3)
        result = await transformer.transform("test")

        assert len(result.sub_queries) == 3
        assert result.sub_queries == ["sq1", "sq2", "sq3"]

    async def test_prompt_contains_original_query(self):
        """Verify that the prompt sent to the LLM contains the original query."""
        gateway = AsyncMock()
        gateway.chat.return_value = MagicMock(
            content=json.dumps({"main_query": "q", "sub_queries": []})
        )

        transformer = LLMQueryTransformer(gateway)
        await transformer.transform("my original query")

        call_args = gateway.chat.call_args
        messages = call_args.kwargs.get("messages") or call_args[1].get("messages") or call_args[0][0]
        # The prompt should contain the original query
        prompt_text = messages[0]["content"]
        assert "my original query" in prompt_text
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestRuleQueryTransformer ──────────────────────────────


class TestRuleQueryTransformer:
    """Tests for RuleQueryTransformer."""

    async def test_chinese_filler_word_removal(self):
        """Strip Chinese filler words: '帮我分析一下SEO策略' → main_query contains 'SEO策略'."""
        transformer = RuleQueryTransformer()
        result = await transformer.transform("帮我分析一下SEO策略")

        assert "SEO策略" in result.main_query
        assert "帮我" not in result.main_query
        assert "一下" not in result.main_query
        assert result.original_query == "帮我分析一下SEO策略"

    async def test_english_filler_word_removal(self):
        """Strip English filler words: 'Please help me analyze' → main_query contains 'analyze'."""
        transformer = RuleQueryTransformer()
        result = await transformer.transform("Please help me analyze")

        assert "analyze" in result.main_query
        assert "Please" not in result.main_query
        assert "help me" not in result.main_query

    async def test_synonym_expansion(self):
        """Synonym expansion: SEO → 搜索引擎优化, Search Engine Optimization."""
        synonyms = {"SEO": ["搜索引擎优化", "Search Engine Optimization"]}
        transformer = RuleQueryTransformer(synonyms=synonyms)

        result = await transformer.transform("SEO策略")

        assert "SEO策略" in result.main_query
        assert len(result.sub_queries) == 2
        assert any("搜索引擎优化" in sq for sq in result.sub_queries)
        assert any("Search Engine Optimization" in sq for sq in result.sub_queries)

    async def test_no_op_for_clean_query(self):
        """A clean query is returned as-is: 'AI行业趋势' → unchanged."""
        transformer = RuleQueryTransformer()
        result = await transformer.transform("AI行业趋势")

        assert result.main_query == "AI行业趋势"
        assert result.sub_queries == []

    async def test_max_sub_queries_limit(self):
        """Synonym expansion is capped by max_sub_queries."""
        synonyms = {"AI": ["人工智能", "Artificial Intelligence", "machine intelligence", "ML"]}
        transformer = RuleQueryTransformer(synonyms=synonyms, max_sub_queries=2)

        result = await transformer.transform("AI trends")

        assert len(result.sub_queries) <= 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestNoOpQueryTransformer ──────────────────────────────


class TestNoOpQueryTransformer:
    """Tests for NoOpQueryTransformer."""

    async def test_returns_original_query_unchanged(self):
        """The original query is returned unchanged."""
        transformer = NoOpQueryTransformer()
        result = await transformer.transform("帮我分析一下SEO策略")

        assert result.main_query == "帮我分析一下SEO策略"
        assert result.sub_queries == []
        assert result.original_query == "帮我分析一下SEO策略"
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestCreateQueryTransformer ────────────────────────────


class TestCreateQueryTransformer:
    """Tests for the create_query_transformer factory function."""

    def test_llm_strategy(self):
        """strategy='llm' creates an LLMQueryTransformer."""
        gateway = AsyncMock()
        transformer = create_query_transformer(strategy="llm", llm_gateway=gateway)
        assert isinstance(transformer, LLMQueryTransformer)

    def test_rule_strategy(self):
        """strategy='rule' creates a RuleQueryTransformer."""
        transformer = create_query_transformer(strategy="rule")
        assert isinstance(transformer, RuleQueryTransformer)

    def test_none_strategy(self):
        """strategy='none' creates a NoOpQueryTransformer."""
        transformer = create_query_transformer(strategy="none")
        assert isinstance(transformer, NoOpQueryTransformer)

    def test_unknown_strategy_defaults_to_noop(self):
        """An unknown strategy defaults to NoOpQueryTransformer."""
        transformer = create_query_transformer(strategy="unknown")
        assert isinstance(transformer, NoOpQueryTransformer)

    def test_llm_strategy_without_gateway_falls_back(self):
        """strategy='llm' without a gateway falls back to NoOp."""
        transformer = create_query_transformer(strategy="llm", llm_gateway=None)
        assert isinstance(transformer, NoOpQueryTransformer)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestMemoryRetrieverWithTransformer ────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryRetrieverWithTransformer:
|
||||||
|
"""MemoryRetriever 集成 QueryTransformer 测试"""
|
||||||
|
|
||||||
|
async def test_retrieve_calls_transformer_before_search(self):
|
||||||
|
"""retrieve() 在搜索前调用 transformer"""
|
||||||
|
memory = InMemoryMemory()
|
||||||
|
await memory.store("k1", "SEO optimization content")
|
||||||
|
|
||||||
|
transformer = AsyncMock(spec=QueryTransformerBase)
|
||||||
|
transformer.transform.return_value = TransformedQuery(
|
||||||
|
main_query="SEO optimization",
|
||||||
|
sub_queries=[],
|
||||||
|
original_query="帮我分析一下SEO",
|
||||||
|
)
|
||||||
|
|
||||||
|
retriever = MemoryRetriever(
|
||||||
|
working_memory=memory,
|
||||||
|
query_transformer=transformer,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await retriever.retrieve("帮我分析一下SEO")
|
||||||
|
|
||||||
|
transformer.transform.assert_called_once_with("帮我分析一下SEO")
|
||||||
|
assert len(results) >= 1
|
||||||
|
|
||||||
|
async def test_sub_queries_searched_in_parallel(self):
|
||||||
|
"""子查询被并行搜索"""
|
||||||
|
memory = InMemoryMemory()
|
||||||
|
await memory.store("k1", "SEO optimization content")
|
||||||
|
await memory.store("k2", "Search engine ranking factors")
|
||||||
|
|
||||||
|
transformer = AsyncMock(spec=QueryTransformerBase)
|
||||||
|
transformer.transform.return_value = TransformedQuery(
|
||||||
|
main_query="SEO optimization",
|
||||||
|
sub_queries=["search engine ranking"],
|
||||||
|
original_query="SEO",
|
||||||
|
)
|
||||||
|
|
||||||
|
retriever = MemoryRetriever(
|
||||||
|
working_memory=memory,
|
||||||
|
query_transformer=transformer,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await retriever.retrieve("SEO")
|
||||||
|
# Both main query and sub-query results should be present
|
||||||
|
assert len(results) >= 1

    async def test_results_deduplicated_by_key(self):
        """Sub-query results are deduplicated by key, keeping the highest score"""
        memory = InMemoryMemory()
        await memory.store("k1", "SEO optimization content")

        # The same key appears in both main and sub-query results
        transformer = AsyncMock(spec=QueryTransformerBase)
        transformer.transform.return_value = TransformedQuery(
            main_query="SEO",
            sub_queries=["SEO"],  # Same query → same key match
            original_query="SEO",
        )

        retriever = MemoryRetriever(
            working_memory=memory,
            query_transformer=transformer,
        )

        results = await retriever.retrieve("SEO")
        # Should not have duplicate keys
        keys = [r.key for r in results]
        assert len(keys) == len(set(keys))

    async def test_without_transformer_backward_compatible(self):
        """Behavior is unchanged when no transformer is set (backward compatible)"""
        memory = InMemoryMemory()
        await memory.store("k1", "AI research content")

        retriever = MemoryRetriever(working_memory=memory)
        results = await retriever.retrieve("AI")

        assert len(results) >= 1
        assert results[0].key == "k1"
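
The two behaviours exercised by these tests, concurrent sub-query search and per-key deduplication that keeps the highest score, can be sketched in isolation. This is a minimal illustration under stated assumptions, not the actual `MemoryRetriever` implementation: `Hit`, `search_all`, and `fake_search` are hypothetical names, and each backend search is assumed to be an awaitable that returns scored hits.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Hit:
    """Hypothetical stand-in for a retrieved item (not AgentKit's MemoryItem)."""
    key: str
    score: float


async def search_all(search, queries, top_k=5):
    """Run every query concurrently, then keep the best-scoring hit per key."""
    batches = await asyncio.gather(*(search(q) for q in queries))
    best = {}
    for batch in batches:
        for hit in batch:
            if hit.key not in best or hit.score > best[hit.key].score:
                best[hit.key] = hit
    # Highest score first, truncated to top_k
    return sorted(best.values(), key=lambda h: h.score, reverse=True)[:top_k]


async def fake_search(query):
    # Hypothetical backend: "k1" matches both queries, with different scores.
    table = {
        "SEO": [Hit("k1", 0.6)],
        "SEO optimization": [Hit("k1", 0.9), Hit("k2", 0.5)],
    }
    return table.get(query, [])


merged = asyncio.run(search_all(fake_search, ["SEO", "SEO optimization"]))
```

Keying the merge on `key` means a document matched by several sub-queries appears once, with the score of its best match.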
@ -0,0 +1,438 @@
"""U5: Configurable Retrieval Parameters + Per-KB Weights

Tests for:
1. ReActEngine uses configurable top_k/token_budget from retrieval_config
2. ConfigDrivenAgent passes retrieval_config from memory config
3. SemanticMemory applies per-KB weight multipliers to scores
4. Improved token estimation for mixed Chinese/English text
5. ServerConfig parsing with memory.retrieval and memory.semantic.kb_weights
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.core.react import ReActEngine, ReActResult
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.memory.base import MemoryItem
from agentkit.memory.retriever import MemoryRetriever, _estimate_tokens
from agentkit.memory.semantic import SemanticMemory


# ── Test Helpers ──────────────────────────────────────────


def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
    gateway = MagicMock(spec=LLMGateway)
    gateway.chat = AsyncMock(side_effect=responses)
    return gateway


def make_response(
    content: str = "",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    return LLMResponse(
        content=content,
        model="test-model",
        usage=TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        ),
        tool_calls=[],
    )


def make_mock_memory_retriever(context_string: str = "past experience data"):
    retriever = MagicMock()
    retriever.get_context_string = AsyncMock(return_value=context_string)
    retriever._episodic = None
    retriever.store_episode = AsyncMock()
    return retriever


# ── Test: Configurable Retrieval Parameters ──────────


class TestConfigurableRetrievalParameters:
    """ReActEngine uses configurable top_k/token_budget from retrieval_config"""

    async def test_default_top_k_when_no_config(self):
        """ReActEngine uses default top_k=5 when no config provided"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        retriever.get_context_string.assert_awaited_once_with(
            query="Hello",
            top_k=5,
            token_budget=2000,
        )

    async def test_configured_top_k(self):
        """ReActEngine uses configured top_k from retrieval_config"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            retrieval_config={"top_k": 10, "token_budget": 4000},
        )

        retriever.get_context_string.assert_awaited_once_with(
            query="Hello",
            top_k=10,
            token_budget=4000,
        )

    async def test_configured_token_budget(self):
        """ReActEngine uses configured token_budget from retrieval_config"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            retrieval_config={"token_budget": 5000},
        )

        call_kwargs = retriever.get_context_string.call_args
        assert call_kwargs.kwargs.get("token_budget") == 5000

    async def test_backward_compatibility_no_config(self):
        """No config = same behavior as before (top_k=5, token_budget=2000)"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        call_kwargs = retriever.get_context_string.call_args.kwargs
        assert call_kwargs["top_k"] == 5
        assert call_kwargs["token_budget"] == 2000

    async def test_stream_uses_retrieval_config(self):
        """execute_stream also uses retrieval_config"""
        gateway = make_mock_gateway([make_response(content="streamed answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("context")

        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            retrieval_config={"top_k": 8, "token_budget": 3000},
        ):
            events.append(event)

        call_kwargs = retriever.get_context_string.call_args.kwargs
        assert call_kwargs["top_k"] == 8
        assert call_kwargs["token_budget"] == 3000

    async def test_partial_config_uses_defaults(self):
        """Partial config: only top_k specified, token_budget falls back to default"""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("context")

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            retrieval_config={"top_k": 3},
        )

        call_kwargs = retriever.get_context_string.call_args.kwargs
        assert call_kwargs["top_k"] == 3
        assert call_kwargs["token_budget"] == 2000  # default
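
The default-and-override behaviour these tests pin down (`top_k=5` and `token_budget=2000` unless `retrieval_config` says otherwise) could be implemented along the following lines. This is only a sketch: `DEFAULT_RETRIEVAL` and `resolve_retrieval_config` are hypothetical names, and how `ReActEngine` actually resolves the values is not shown in this diff.

```python
# Defaults taken from the assertions in the tests (top_k=5, token_budget=2000).
DEFAULT_RETRIEVAL = {"top_k": 5, "token_budget": 2000}


def resolve_retrieval_config(retrieval_config=None):
    """Overlay caller-supplied values on the defaults.

    Keys missing from retrieval_config fall back to the default; keys that
    are not known retrieval parameters are ignored.
    """
    overrides = retrieval_config or {}
    return {
        name: overrides.get(name, default)
        for name, default in DEFAULT_RETRIEVAL.items()
    }


no_config = resolve_retrieval_config()
partial = resolve_retrieval_config({"top_k": 3})
full = resolve_retrieval_config({"top_k": 10, "token_budget": 4000})
```

Resolving per key rather than replacing the whole dict is what lets a partial config such as `{"top_k": 3}` keep the default token budget.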


class TestConfigDrivenAgentRetrievalConfig:
    """ConfigDrivenAgent passes retrieval_config from memory config"""

    async def test_retrieval_config_passed_to_react_engine(self):
        """ConfigDrivenAgent extracts retrieval config and passes to ReActEngine"""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig
        from agentkit.skills.base import SkillConfig

        config = SkillConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            execution_mode="react",
            prompt={"identity": "Test agent"},
            memory={
                "retrieval": {"top_k": 10, "token_budget": 5000},
                "working": {"enabled": False},
                "episodic": {"enabled": False},
            },
        )

        gateway = MagicMock(spec=LLMGateway)
        gateway.chat = AsyncMock(return_value=make_response(content="done"))

        agent = ConfigDrivenAgent(config=config, llm_gateway=gateway)

        # Verify the agent keeps the retrieval section of its memory config
        # (the hand-off to ReActEngine itself is not exercised here)
        assert agent._config.memory.get("retrieval") == {"top_k": 10, "token_budget": 5000}


# ── Test: Per-KB Weights ──────────────────────────────────


class TestPerKBWeights:
    """SemanticMemory with kb_weights applies multipliers to scores"""

    async def test_kb_weights_applied_to_scores(self):
        """kb_weights multiplies scores for matching KB IDs"""
        rag_service = MagicMock()
        rag_service.search = AsyncMock(return_value=[
            {"id": "1", "content": "Industry data", "score": 0.9, "source": "rag", "document_id": "d1", "knowledge_base_id": "industry-kb"},
            {"id": "2", "content": "Enterprise data", "score": 0.9, "source": "rag", "document_id": "d2", "knowledge_base_id": "enterprise-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["industry-kb", "enterprise-kb"],
            kb_weights={"industry-kb": 1.2, "enterprise-kb": 0.8},
        )

        results = await memory.search("test query")

        # Industry KB result should have higher score
        industry_item = next(r for r in results if r.metadata.get("knowledge_base_id") == "industry-kb")
        enterprise_item = next(r for r in results if r.metadata.get("knowledge_base_id") == "enterprise-kb")

        assert industry_item.score == pytest.approx(0.9 * 1.2)
        assert enterprise_item.score == pytest.approx(0.9 * 0.8)

    async def test_industry_kb_scores_higher_than_enterprise(self):
        """Industry KB (weight 1.2) results score higher than enterprise KB (weight 0.8)"""
        rag_service = MagicMock()
        rag_service.search = AsyncMock(return_value=[
            {"id": "1", "content": "Enterprise result", "score": 0.9, "source": "rag", "document_id": "d1", "knowledge_base_id": "enterprise-kb"},
            {"id": "2", "content": "Industry result", "score": 0.9, "source": "rag", "document_id": "d2", "knowledge_base_id": "industry-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["industry-kb", "enterprise-kb"],
            kb_weights={"industry-kb": 1.2, "enterprise-kb": 0.8},
        )

        results = await memory.search("test query")

        # After sorting by score, industry should be first
        assert results[0].metadata.get("knowledge_base_id") == "industry-kb"
        assert results[0].score > results[1].score

    async def test_unweighted_kb_gets_default_score(self):
        """Unweighted KBs get default score (1.0 multiplier)"""
        rag_service = MagicMock()
        rag_service.search = AsyncMock(return_value=[
            {"id": "1", "content": "Unweighted result", "score": 0.8, "source": "rag", "document_id": "d1", "knowledge_base_id": "unweighted-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["unweighted-kb"],
            kb_weights={"industry-kb": 1.5},  # no weight for unweighted-kb
        )

        results = await memory.search("test query")
        assert len(results) == 1
        assert results[0].score == pytest.approx(0.8)  # unchanged

    async def test_kb_weights_none_no_modification(self):
        """kb_weights=None: no score modification"""
        rag_service = MagicMock()
        rag_service.search = AsyncMock(return_value=[
            {"id": "1", "content": "Result", "score": 0.75, "source": "rag", "document_id": "d1", "knowledge_base_id": "some-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["some-kb"],
            kb_weights=None,
        )

        results = await memory.search("test query")
        assert results[0].score == pytest.approx(0.75)

    async def test_empty_kb_weights_no_modification(self):
        """Empty kb_weights dict: no score modification"""
        rag_service = MagicMock()
        rag_service.search = AsyncMock(return_value=[
            {"id": "1", "content": "Result", "score": 0.75, "source": "rag", "document_id": "d1", "knowledge_base_id": "some-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["some-kb"],
            kb_weights={},
        )

        results = await memory.search("test query")
        assert results[0].score == pytest.approx(0.75)

    async def test_kb_id_propagated_to_metadata(self):
        """knowledge_base_id is propagated to MemoryItem metadata"""
        rag_service = MagicMock()
        rag_service.search = AsyncMock(return_value=[
            {"id": "1", "content": "Result", "score": 0.9, "source": "rag", "document_id": "d1", "knowledge_base_id": "my-kb"},
        ])

        memory = SemanticMemory(
            rag_service=rag_service,
            knowledge_base_ids=["my-kb"],
        )

        results = await memory.search("test query")
        assert results[0].metadata["knowledge_base_id"] == "my-kb"
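
The weighting rule these tests describe (multiply each score by its knowledge base's weight, default to 1.0 for unlisted KBs, then sort by the adjusted score) can be written as a small pure function. This is a sketch assuming results arrive as plain dicts with `score` and `knowledge_base_id` keys, as in the mocked `rag_service.search` return values; `apply_kb_weights` is a hypothetical helper, not a `SemanticMemory` method.

```python
def apply_kb_weights(results, kb_weights=None):
    """Scale each result's score by its KB weight and sort best-first.

    KBs without an entry in kb_weights keep their score (multiplier 1.0);
    kb_weights=None or {} leaves every score unchanged.
    """
    weights = kb_weights or {}
    weighted = [
        {**r, "score": r["score"] * weights.get(r.get("knowledge_base_id"), 1.0)}
        for r in results
    ]
    return sorted(weighted, key=lambda r: r["score"], reverse=True)


raw = [
    {"id": "1", "score": 0.9, "knowledge_base_id": "enterprise-kb"},
    {"id": "2", "score": 0.9, "knowledge_base_id": "industry-kb"},
]
ranked = apply_kb_weights(raw, {"industry-kb": 1.2, "enterprise-kb": 0.8})
unchanged = apply_kb_weights(raw, None)
```

With equal raw scores, the weights alone decide the order, which is the situation `test_industry_kb_scores_higher_than_enterprise` sets up.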


# ── Test: Token Estimation ────────────────────────────────


class TestTokenEstimation:
    """Improved token estimation for mixed Chinese/English text"""

    def test_pure_english_text(self):
        """Pure English text: ~1 token per word"""
        text = "Hello world this is a test"
        result = _estimate_tokens(text)
        # 6 words * 1 = 6 tokens
        assert result == 6

    def test_pure_chinese_text(self):
        """Pure Chinese text: ~2 tokens per character"""
        text = "你好世界测试"
        result = _estimate_tokens(text)
        # 6 CJK chars * 2 = 12 tokens
        assert result == 12

    def test_mixed_chinese_english_text(self):
        """Mixed Chinese/English text"""
        text = "你好world测试test"
        result = _estimate_tokens(text)
        # 4 CJK chars * 2 = 8, plus 2 English words = 2, total = 10
        assert result == 10

    def test_more_accurate_than_old_for_chinese(self):
        """New estimation is more accurate than len(text)//4 for Chinese text"""
        text = "人工智能技术在近年来取得了巨大突破"
        new_estimate = _estimate_tokens(text)
        old_estimate = len(text) // 4

        # For Chinese text, the old method underestimates
        # 17 CJK chars * 2 = 34 tokens (new)
        # 17 chars // 4 = 4 tokens (old), way too low
        assert new_estimate > old_estimate
        assert new_estimate == 34

    def test_empty_string(self):
        """Empty string: 0 tokens"""
        assert _estimate_tokens("") == 0

    def test_whitespace_only(self):
        """Whitespace only: 0 tokens"""
        assert _estimate_tokens(" ") == 0

    def test_english_with_punctuation(self):
        """English with punctuation"""
        text = "Hello, world! How are you?"
        result = _estimate_tokens(text)
        # "Hello," "world!" "How" "are" "you?" = 5 words
        assert result == 5
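
One estimator that satisfies every expectation in `TestTokenEstimation` counts 2 tokens per CJK character and 1 token per whitespace-separated run of other text. The sketch below is an assumption about how such a heuristic might look, not the body of `_estimate_tokens`: the name `estimate_tokens` and the choice of the CJK Unified Ideographs range (U+4E00 to U+9FFF) are illustrative.

```python
import re

# CJK Unified Ideographs block; an illustrative choice that covers common
# Chinese characters but not every CJK script or extension block.
_CJK = re.compile(r"[\u4e00-\u9fff]")


def estimate_tokens(text):
    """Rough token estimate: 2 per CJK character, 1 per non-CJK word."""
    cjk_chars = len(_CJK.findall(text))
    # Replace CJK characters with spaces so adjacent Latin runs split cleanly,
    # then count whitespace-separated words (punctuation stays attached).
    words = len(_CJK.sub(" ", text).split())
    return cjk_chars * 2 + words


english = estimate_tokens("Hello world this is a test")
chinese = estimate_tokens("你好世界测试")
mixed = estimate_tokens("你好world测试test")
```

Compared with `len(text) // 4`, this avoids the large underestimate on Chinese text, where each character usually costs at least one token.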


# ── Test: Config Parsing ──────────────────────────────────


class TestConfigParsing:
    """ServerConfig.from_dict() with memory.retrieval and memory.semantic.kb_weights"""

    def test_memory_retrieval_section(self):
        """ServerConfig.from_dict() preserves memory.retrieval section"""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "retrieval": {
                    "top_k": 10,
                    "token_budget": 5000,
                },
            },
        }

        config = ServerConfig.from_dict(data)
        assert config.memory["retrieval"]["top_k"] == 10
        assert config.memory["retrieval"]["token_budget"] == 5000

    def test_memory_semantic_kb_weights_section(self):
        """ServerConfig.from_dict() preserves memory.semantic.kb_weights section"""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://localhost:8000",
                    "kb_weights": {
                        "industry-kb": 1.2,
                        "enterprise-kb": 0.8,
                    },
                },
            },
        }

        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"]["kb_weights"]["industry-kb"] == 1.2
        assert config.memory["semantic"]["kb_weights"]["enterprise-kb"] == 0.8

    def test_memory_config_without_retrieval(self):
        """ServerConfig.from_dict() works without memory.retrieval section"""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {"enabled": False},
            },
        }

        config = ServerConfig.from_dict(data)
        assert config.memory.get("retrieval") is None

    def test_memory_config_without_kb_weights(self):
        """ServerConfig.from_dict() works without kb_weights section"""
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://localhost:8000",
                },
            },
        }

        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"].get("kb_weights") is None
@ -0,0 +1,362 @@
"""U4 tests: RetrieveKnowledgeTool, the built-in tool for the RAG pipeline

Tests creation, execution, auto-registration, and integration of the retrieve_knowledge tool.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.retriever import MemoryRetriever, RetrieveKnowledgeTool
from agentkit.tools.base import Tool


# ── In-memory Memory implementation (for tests) ───────────


class InMemoryMemory(Memory):
    """In-memory Memory implementation for tests"""

    def __init__(self):
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        self._store[key] = MemoryItem(
            key=key, value=value, metadata=metadata or {}, score=1.0
        )

    async def retrieve(self, key: str) -> MemoryItem | None:
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        results = []
        for item in self._store.values():
            if query.lower() in str(item.value).lower() or query.lower() in item.key.lower():
                results.append(item)
        return results[:top_k]

    async def delete(self, key: str) -> bool:
        return self._store.pop(key, None) is not None


# ── TestRetrieveKnowledgeToolCreation ──────────────────────


class TestRetrieveKnowledgeToolCreation:
    """RetrieveKnowledgeTool creation tests"""

    def test_create_retrieve_tool_returns_tool_when_semantic_configured(self):
        """create_retrieve_tool() returns a Tool when semantic memory is configured"""
        semantic = InMemoryMemory()
        retriever = MemoryRetriever(semantic_memory=semantic)

        tool = retriever.create_retrieve_tool()

        assert tool is not None
        assert isinstance(tool, Tool)

    def test_create_retrieve_tool_returns_none_when_no_semantic(self):
        """create_retrieve_tool() returns None when there is no semantic memory"""
        retriever = MemoryRetriever()

        tool = retriever.create_retrieve_tool()

        assert tool is None

    def test_create_retrieve_tool_with_working_only_returns_none(self):
        """Returns None when only working memory is present"""
        working = InMemoryMemory()
        retriever = MemoryRetriever(working_memory=working)

        tool = retriever.create_retrieve_tool()

        assert tool is None

    def test_tool_has_correct_name(self):
        """The tool is named retrieve_knowledge"""
        semantic = InMemoryMemory()
        retriever = MemoryRetriever(semantic_memory=semantic)

        tool = retriever.create_retrieve_tool()

        assert tool.name == "retrieve_knowledge"

    def test_tool_has_description(self):
        """The tool has a description"""
        semantic = InMemoryMemory()
        retriever = MemoryRetriever(semantic_memory=semantic)

        tool = retriever.create_retrieve_tool()

        assert isinstance(tool.description, str)
        assert len(tool.description) > 0

    def test_tool_has_input_schema(self):
        """The tool has an input_schema"""
        semantic = InMemoryMemory()
        retriever = MemoryRetriever(semantic_memory=semantic)

        tool = retriever.create_retrieve_tool()

        assert tool.input_schema is not None
        assert tool.input_schema["type"] == "object"
        assert "query" in tool.input_schema["properties"]
        assert "query" in tool.input_schema["required"]

    def test_tool_is_retrieve_knowledge_tool_instance(self):
        """The tool is a RetrieveKnowledgeTool instance"""
        semantic = InMemoryMemory()
        retriever = MemoryRetriever(semantic_memory=semantic)

        tool = retriever.create_retrieve_tool()

        assert isinstance(tool, RetrieveKnowledgeTool)


# ── TestRetrieveKnowledgeToolExecution ─────────────────────


class TestRetrieveKnowledgeToolExecution:
    """RetrieveKnowledgeTool execution tests"""

    async def test_execute_calls_retriever_retrieve(self):
        """execute() calls MemoryRetriever.retrieve()"""
        semantic = InMemoryMemory()
        await semantic.store("s1", "AI趋势报告", metadata={"source": "report.pdf"})
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI趋势")

        assert "results" in result
        assert len(result["results"]) >= 1

    async def test_execute_results_formatted_correctly(self):
        """Results contain content, score, source, document_title"""
        semantic = InMemoryMemory()
        await semantic.store(
            "s1",
            "AI趋势报告内容",
            metadata={"source": "report.pdf", "document_title": "2024 AI Report"},
        )
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI趋势")

        assert "results" in result
        for item in result["results"]:
            assert "content" in item
            assert "score" in item
            assert "source" in item
            assert "document_title" in item

    async def test_execute_empty_query_returns_error(self):
        """An empty query returns an error"""
        semantic = InMemoryMemory()
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="")

        assert "error" in result
        assert result["results"] == []

    async def test_execute_max_calls_limit(self):
        """Returns an error once the max_calls limit is exceeded"""
        semantic = InMemoryMemory()
        await semantic.store("s1", "Some content")
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool(max_calls=3)

        # The first 3 calls should succeed
        for i in range(3):
            result = await tool.execute(query="content")
            assert "error" not in result or result.get("call_count") == i + 1

        # The 4th call should return an error
        result = await tool.execute(query="content")
        assert "error" in result
        assert "Maximum retrieval calls" in result["error"]
        assert result["results"] == []

    async def test_execute_call_count_tracking(self):
        """call_count is tracked correctly in the response"""
        semantic = InMemoryMemory()
        await semantic.store("s1", "Some content")
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool(max_calls=5)

        for i in range(1, 4):
            result = await tool.execute(query="content")
            assert result["call_count"] == i

    async def test_execute_exception_handling(self):
        """Returns an error response when the retriever raises an exception"""
        retriever = MemoryRetriever(semantic_memory=InMemoryMemory())
        tool = retriever.create_retrieve_tool()

        # Mock retriever.retrieve to raise exception
        tool._retriever.retrieve = AsyncMock(side_effect=Exception("Service unavailable"))

        result = await tool.execute(query="test")

        assert "error" in result
        assert "Service unavailable" in result["error"]
        assert result["results"] == []

    async def test_execute_returns_query_in_response(self):
        """The response includes the original query"""
        semantic = InMemoryMemory()
        await semantic.store("s1", "Some content")
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI趋势")

        assert result["query"] == "AI趋势"
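
The response contract these execution tests rely on (a dict with `query`, `results`, and `call_count`, plus an `error` key for empty queries, exhausted budgets, and backend failures) can be illustrated with a small stand-alone class. This is a hedged sketch, not `RetrieveKnowledgeTool` itself: `BudgetedRetrieveTool` and its injected `retrieve` coroutine are hypothetical, and the error strings are modelled only loosely on the substrings the tests check.

```python
import asyncio


class BudgetedRetrieveTool:
    """Hypothetical retrieval tool that enforces a per-instance call budget."""

    def __init__(self, retrieve, max_calls=3):
        self._retrieve = retrieve  # async callable: query -> list of result dicts
        self._max_calls = max_calls
        self._call_count = 0

    async def execute(self, query):
        if not query.strip():
            # Design choice in this sketch: empty queries do not consume budget.
            return {"query": query, "results": [], "error": "Query must not be empty"}
        if self._call_count >= self._max_calls:
            return {
                "query": query,
                "results": [],
                "error": f"Maximum retrieval calls ({self._max_calls}) reached",
            }
        self._call_count += 1
        try:
            results = await self._retrieve(query)
        except Exception as exc:  # surface backend failures as data, not exceptions
            return {"query": query, "results": [], "error": str(exc),
                    "call_count": self._call_count}
        return {"query": query, "results": results, "call_count": self._call_count}


async def _demo():
    async def retrieve(query):
        return [{"content": f"doc about {query}", "score": 1.0}]

    tool = BudgetedRetrieveTool(retrieve, max_calls=2)
    return [await tool.execute(query="content") for _ in range(3)]


responses = asyncio.run(_demo())
```

Returning errors as part of the payload keeps the tool safe to call from a ReAct loop, where an uncaught exception would otherwise abort the reasoning step.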


# ── TestRetrieveKnowledgeToolAutoRegistration ──────────────


class TestRetrieveKnowledgeToolAutoRegistration:
    """RetrieveKnowledgeTool auto-registration tests"""

    def test_agent_with_semantic_memory_has_tool(self):
        """ConfigDrivenAgent auto-registers retrieve_knowledge when semantic memory is configured"""
        from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent

        config = AgentConfig.from_dict({
            "name": "test_agent",
            "agent_type": "test",
            "task_mode": "llm_generate",
            "prompt": {
                "identity": "Test agent",
                "instructions": "Test",
            },
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://localhost:8080",
                    "knowledge_base_ids": ["kb1"],
                },
            },
        })

        # Patch imports inside the try block of ConfigDrivenAgent.__init__
        with patch("agentkit.memory.http_rag.HttpRAGService") as mock_rag, \
             patch("agentkit.memory.semantic.SemanticMemory") as mock_sem:
            mock_sem.return_value = InMemoryMemory()
            agent = ConfigDrivenAgent(config=config)

        tool_names = [t.name for t in agent._tools]
        assert "retrieve_knowledge" in tool_names

    def test_agent_without_semantic_memory_does_not_have_tool(self):
        """ConfigDrivenAgent does not register retrieve_knowledge when semantic memory is not configured"""
        from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent

        config = AgentConfig.from_dict({
            "name": "test_agent",
            "agent_type": "test",
            "task_mode": "llm_generate",
            "prompt": {
                "identity": "Test agent",
                "instructions": "Test",
            },
        })

        agent = ConfigDrivenAgent(config=config)

        tool_names = [t.name for t in agent._tools]
        assert "retrieve_knowledge" not in tool_names

    def test_auto_registered_tool_is_retrieve_knowledge_instance(self):
        """The auto-registered tool is a RetrieveKnowledgeTool instance"""
        from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent

        config = AgentConfig.from_dict({
            "name": "test_agent",
            "agent_type": "test",
            "task_mode": "llm_generate",
            "prompt": {
                "identity": "Test agent",
                "instructions": "Test",
            },
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://localhost:8080",
                    "knowledge_base_ids": ["kb1"],
                },
            },
        })

        with patch("agentkit.memory.http_rag.HttpRAGService"), \
             patch("agentkit.memory.semantic.SemanticMemory") as mock_sem:
            mock_sem.return_value = InMemoryMemory()
            agent = ConfigDrivenAgent(config=config)

        retrieve_tools = [t for t in agent._tools if t.name == "retrieve_knowledge"]
        assert len(retrieve_tools) == 1
        assert isinstance(retrieve_tools[0], RetrieveKnowledgeTool)


# ── TestRetrieveKnowledgeToolIntegration ───────────────────


class TestRetrieveKnowledgeToolIntegration:
    """RetrieveKnowledgeTool integration tests"""

    async def test_tool_works_with_query_transformer(self):
        """The tool works together with a query transformer"""
        from agentkit.memory.query_transformer import QueryTransformerBase, TransformedQuery

        class SimpleTransformer(QueryTransformerBase):
            async def transform(self, query: str) -> TransformedQuery:
                return TransformedQuery(
                    main_query=f"enhanced: {query}",
                    sub_queries=[],
                )

        semantic = InMemoryMemory()
        await semantic.store("s1", "enhanced: AI trends data")
        retriever = MemoryRetriever(
            semantic_memory=semantic,
            query_transformer=SimpleTransformer(),
        )
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI")

        assert "results" in result

    async def test_tool_returns_structured_results_for_llm(self):
        """The tool returns structured results usable by the LLM"""
        semantic = InMemoryMemory()
        await semantic.store(
            "s1",
            "GEO optimization improves brand visibility",
            metadata={"source": "guide.md", "document_title": "GEO Guide"},
        )
        await semantic.store(
            "s2",
            "Another relevant document about SEO",
            metadata={"source": "seo.md", "document_title": "SEO Basics"},
        )
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="optimization")

        assert isinstance(result, dict)
        assert "query" in result
        assert "results" in result
        assert "call_count" in result
        assert isinstance(result["results"], list)
        for item in result["results"]:
            assert isinstance(item, dict)
            assert "content" in item
            assert "score" in item