diff --git a/src/agentkit/memory/contextual_retrieval.py b/src/agentkit/memory/contextual_retrieval.py
new file mode 100644
index 0000000..93eb47f
--- /dev/null
+++ b/src/agentkit/memory/contextual_retrieval.py
@@ -0,0 +1,210 @@
+"""ContextualChunker - 上下文增强分块
+
+在嵌入前为每个文档块添加 LLM 生成的上下文前缀,
+解决分块后上下文丢失问题(Anthropic Contextual Retrieval)。
+"""
+
+from __future__ import annotations
+
+import hashlib
+import logging
+from dataclasses import dataclass
+from typing import Any
+
+from agentkit.memory.embedder import EmbeddingCache
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ContextualChunk:
+ """带上下文前缀的文档块"""
+
+ original_content: str
+ context_prefix: str
+ enhanced_content: str
+ chunk_index: int
+ metadata: dict[str, Any]
+
+ @property
+ def content(self) -> str:
+ """获取增强后的完整内容"""
+ return self.enhanced_content
+
+
+CONTEXT_PROMPT_TEMPLATE = """\
+Given the full document below and a specific chunk from it, write a brief context that helps someone understand what this chunk is about in the broader document. Output ONLY the context, no explanations.
+
+
+{document}
+
+
+
+{chunk}
+
+
+Context:"""
+
+
+class ContextualChunker:
+ """上下文增强分块器
+
+ 为每个文档块生成 LLM 上下文前缀,增强检索质量。
+
+ 工作流程:
+ 1. 接收文档和分块列表
+ 2. 对每个块,调用 LLM 生成简洁上下文语句
+ 3. 将上下文前缀添加到原始内容前
+ 4. 缓存结果避免重复计算
+
+ 成本优化:
+ - 文档级 Prompt Caching(同一文档的多个块共享文档前缀)
+ - EmbeddingCache 缓存上下文生成结果
+ - 批处理(batch_size)
+ """
+
+ def __init__(
+ self,
+ llm_gateway: Any = None,
+ cache: EmbeddingCache | None = None,
+ batch_size: int = 8,
+ max_context_length: int = 200,
+ prompt_template: str = CONTEXT_PROMPT_TEMPLATE,
+ ):
+ """
+ Args:
+ llm_gateway: LLM Gateway 实例,用于生成上下文
+ cache: 嵌入缓存,用于缓存上下文生成结果
+ batch_size: 批处理大小
+ max_context_length: 上下文最大字符长度
+ prompt_template: 上下文生成 prompt 模板
+ """
+ self._llm_gateway = llm_gateway
+ self._cache = cache
+ self._batch_size = batch_size
+ self._max_context_length = max_context_length
+ self._prompt_template = prompt_template
+ self._context_cache: dict[str, str] = {}
+
+ async def enhance_chunks(
+ self,
+ document: str,
+ chunks: list[str],
+ metadata: dict[str, Any] | None = None,
+ ) -> list[ContextualChunk]:
+ """为文档块添加上下文前缀
+
+ Args:
+ document: 完整文档内容
+ chunks: 文档分块列表
+ metadata: 附加元数据
+
+ Returns:
+ 增强后的 ContextualChunk 列表
+ """
+ if not chunks:
+ return []
+
+ if not self._llm_gateway:
+ # No LLM available — return chunks without context
+ logger.info("No LLM gateway configured, skipping contextual enhancement")
+ return [
+ ContextualChunk(
+ original_content=chunk,
+ context_prefix="",
+ enhanced_content=chunk,
+ chunk_index=i,
+ metadata=metadata or {},
+ )
+ for i, chunk in enumerate(chunks)
+ ]
+
+ result: list[ContextualChunk] = []
+
+ # Process in batches
+ for batch_start in range(0, len(chunks), self._batch_size):
+ batch = chunks[batch_start : batch_start + self._batch_size]
+ batch_results = await self._process_batch(document, batch, batch_start, metadata)
+ result.extend(batch_results)
+
+ return result
+
+ async def _process_batch(
+ self,
+ document: str,
+ chunks: list[str],
+ start_index: int,
+ metadata: dict[str, Any] | None,
+ ) -> list[ContextualChunk]:
+ """处理一批文档块"""
+ results: list[ContextualChunk] = []
+
+ for i, chunk in enumerate(chunks):
+ chunk_index = start_index + i
+ chunk_meta = dict(metadata or {})
+ chunk_meta["chunk_index"] = chunk_index
+
+ # Check cache
+ cache_key = self._make_cache_key(document, chunk)
+ if cache_key in self._context_cache:
+ context = self._context_cache[cache_key]
+ else:
+ context = await self._generate_context(document, chunk)
+ self._context_cache[cache_key] = context
+
+ # Truncate context if too long
+ if len(context) > self._max_context_length:
+ context = context[: self._max_context_length]
+
+ # Build enhanced content
+ if context:
+ enhanced = f"{context}\n{chunk}"
+ else:
+ enhanced = chunk
+
+ chunk_meta["context_prefix"] = context
+ chunk_meta["has_context"] = bool(context)
+
+ results.append(
+ ContextualChunk(
+ original_content=chunk,
+ context_prefix=context,
+ enhanced_content=enhanced,
+ chunk_index=chunk_index,
+ metadata=chunk_meta,
+ )
+ )
+
+ return results
+
+ async def _generate_context(self, document: str, chunk: str) -> str:
+ """使用 LLM 为单个块生成上下文"""
+ # Truncate document for prompt efficiency
+ doc_preview = document[:3000] if len(document) > 3000 else document
+ chunk_preview = chunk[:1000] if len(chunk) > 1000 else chunk
+
+ prompt = self._prompt_template.format(
+ document=doc_preview,
+ chunk=chunk_preview,
+ )
+
+ try:
+ response = await self._llm_gateway.chat(
+ messages=[{"role": "user", "content": prompt}],
+ model="default",
+ )
+ context = response.content.strip()
+ return context
+ except Exception as e:
+ logger.warning(f"Context generation failed for chunk: {e}")
+ return ""
+
+ @staticmethod
+ def _make_cache_key(document: str, chunk: str) -> str:
+ """生成缓存键"""
+ content = f"{document[:500]}:{chunk[:500]}"
+ return hashlib.sha256(content.encode()).hexdigest()[:16]
+
+ def clear_cache(self) -> None:
+ """清除上下文缓存"""
+ self._context_cache.clear()
diff --git a/src/agentkit/memory/http_rag.py b/src/agentkit/memory/http_rag.py
index b0ed246..2e4d94f 100644
--- a/src/agentkit/memory/http_rag.py
+++ b/src/agentkit/memory/http_rag.py
@@ -29,6 +29,7 @@ class HttpRAGService:
- "industry-kb-id"
- "enterprise-kb-id"
timeout: 30
+ contextual_chunking: false
"""
def __init__(
@@ -37,6 +38,8 @@ class HttpRAGService:
api_key: str | None = None,
knowledge_base_ids: list[str] | None = None,
timeout: int = 30,
+ contextual_chunking: bool = False,
+ llm_gateway: Any = None,
):
"""
Args:
@@ -50,6 +53,8 @@ class HttpRAGService:
self._knowledge_base_ids = knowledge_base_ids or []
self._timeout = timeout
self._client: httpx.AsyncClient | None = None
+ self._contextual_chunking = contextual_chunking
+ self._llm_gateway = llm_gateway
def _get_client(self) -> httpx.AsyncClient:
"""懒初始化 httpx 客户端"""
@@ -232,6 +237,9 @@ class HttpRAGService:
) -> dict[str, Any] | None:
"""写入文档到知识库(可选操作)
+ When contextual_chunking is enabled and llm_gateway is configured,
+ the document content is enhanced with contextual prefixes before ingestion.
+
Args:
key: 文档标题或标识
value: 文档内容
@@ -245,9 +253,25 @@ class HttpRAGService:
logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
return None
+ content = str(value)
+
+ # Apply contextual chunking if enabled
+ if self._contextual_chunking and self._llm_gateway:
+ from agentkit.memory.contextual_retrieval import ContextualChunker
+
+ chunker = ContextualChunker(llm_gateway=self._llm_gateway)
+ # Simple chunking: split by paragraphs
+ raw_chunks = [c.strip() for c in content.split("\n\n") if c.strip()]
+ if raw_chunks:
+ enhanced = await chunker.enhance_chunks(
+ document=content, chunks=raw_chunks, metadata=metadata
+ )
+ # Rejoin enhanced chunks
+ content = "\n\n".join(chunk.enhanced_content for chunk in enhanced)
+
payload = {
"title": key,
- "content": str(value),
+ "content": content,
"source_type": "text",
"metadata": metadata or {},
}
diff --git a/tests/unit/test_contextual_retrieval.py b/tests/unit/test_contextual_retrieval.py
new file mode 100644
index 0000000..e139222
--- /dev/null
+++ b/tests/unit/test_contextual_retrieval.py
@@ -0,0 +1,190 @@
+"""Tests for ContextualChunker"""
+
+import pytest
+
+from agentkit.memory.contextual_retrieval import (
+ ContextualChunker,
+ ContextualChunk,
+ CONTEXT_PROMPT_TEMPLATE,
+)
+
+
+class MockLLMGateway:
+ """Mock LLM Gateway for testing"""
+
+ def __init__(self, responses: list[str] | None = None):
+ self._responses = responses or ["This chunk discusses revenue growth."]
+ self._call_count = 0
+ self._last_messages = None
+
+ async def chat(self, messages, model="default", **kwargs):
+ self._call_count += 1
+ self._last_messages = messages
+
+ class MockResponse:
+ content = self._responses[min(self._call_count - 1, len(self._responses) - 1)]
+
+ return MockResponse()
+
+
+class TestContextualChunk:
+ """ContextualChunk dataclass tests"""
+
+ def test_content_property(self):
+ chunk = ContextualChunk(
+ original_content="Revenue grew 3%",
+ context_prefix="From Acme Q2 2023 report",
+ enhanced_content="From Acme Q2 2023 report\nRevenue grew 3%",
+ chunk_index=0,
+ metadata={},
+ )
+ assert chunk.content == "From Acme Q2 2023 report\nRevenue grew 3%"
+
+ def test_empty_context(self):
+ chunk = ContextualChunk(
+ original_content="Some text",
+ context_prefix="",
+ enhanced_content="Some text",
+ chunk_index=0,
+ metadata={},
+ )
+ assert chunk.content == "Some text"
+
+
+class TestContextualChunker:
+ """ContextualChunker unit tests"""
+
+ @pytest.mark.asyncio
+ async def test_enhance_chunks_with_llm(self):
+ """Chunks should be enhanced with LLM-generated context"""
+ llm = MockLLMGateway(responses=["From the financial report section"])
+ chunker = ContextualChunker(llm_gateway=llm)
+
+ document = "Acme Corp Q2 2023 Report\n\nRevenue grew 3%.\n\nProfit increased 5%."
+ chunks = ["Revenue grew 3%.", "Profit increased 5%."]
+
+ result = await chunker.enhance_chunks(document, chunks)
+
+ assert len(result) == 2
+ assert result[0].original_content == "Revenue grew 3%."
+ assert result[0].context_prefix == "From the financial report section"
+ assert "From the financial report section" in result[0].enhanced_content
+ assert "Revenue grew 3%." in result[0].enhanced_content
+ assert result[0].chunk_index == 0
+ assert result[0].metadata["has_context"] is True
+
+ @pytest.mark.asyncio
+ async def test_enhance_chunks_without_llm(self):
+ """Without LLM, chunks should be returned without context"""
+ chunker = ContextualChunker(llm_gateway=None)
+
+ document = "Test document"
+ chunks = ["Chunk 1", "Chunk 2"]
+
+ result = await chunker.enhance_chunks(document, chunks)
+
+ assert len(result) == 2
+ assert result[0].context_prefix == ""
+ assert result[0].enhanced_content == "Chunk 1"
+ assert result[0].metadata.get("has_context") is not True
+
+ @pytest.mark.asyncio
+ async def test_enhance_empty_chunks(self):
+ """Empty chunks list should return empty result"""
+ chunker = ContextualChunker(llm_gateway=MockLLMGateway())
+ result = await chunker.enhance_chunks("document", [])
+ assert result == []
+
+ @pytest.mark.asyncio
+ async def test_context_caching(self):
+ """Same document+chunk should use cached context"""
+ llm = MockLLMGateway(responses=["Context A", "Context B"])
+ chunker = ContextualChunker(llm_gateway=llm)
+
+ document = "Test document"
+ chunks = ["Chunk 1"]
+
+ # First call
+ result1 = await chunker.enhance_chunks(document, chunks)
+ assert result1[0].context_prefix == "Context A"
+ assert llm._call_count == 1
+
+ # Second call with same input — should use cache
+ result2 = await chunker.enhance_chunks(document, chunks)
+ assert result2[0].context_prefix == "Context A"
+ assert llm._call_count == 1 # No additional LLM call
+
+ @pytest.mark.asyncio
+ async def test_context_truncation(self):
+ """Long context should be truncated"""
+ long_context = "A" * 500
+ llm = MockLLMGateway(responses=[long_context])
+ chunker = ContextualChunker(llm_gateway=llm, max_context_length=100)
+
+ result = await chunker.enhance_chunks("doc", ["chunk"])
+ assert len(result[0].context_prefix) <= 100
+
+ @pytest.mark.asyncio
+ async def test_llm_failure_returns_empty_context(self):
+ """LLM failure should result in empty context, not error"""
+ class FailingLLM:
+ async def chat(self, messages, model="default", **kwargs):
+ raise RuntimeError("LLM unavailable")
+
+ chunker = ContextualChunker(llm_gateway=FailingLLM())
+ result = await chunker.enhance_chunks("doc", ["chunk"])
+
+ assert len(result) == 1
+ assert result[0].context_prefix == ""
+ assert result[0].enhanced_content == "chunk"
+
+ @pytest.mark.asyncio
+ async def test_batch_processing(self):
+ """Large number of chunks should be processed in batches"""
+ llm = MockLLMGateway(responses=["Context"])
+ chunker = ContextualChunker(llm_gateway=llm, batch_size=3)
+
+ chunks = [f"Chunk {i}" for i in range(7)]
+ result = await chunker.enhance_chunks("doc", chunks)
+
+ assert len(result) == 7
+ for i, chunk in enumerate(result):
+ assert chunk.chunk_index == i
+
+ @pytest.mark.asyncio
+ async def test_metadata_preserved(self):
+ """Metadata should be preserved and enhanced"""
+ llm = MockLLMGateway(responses=["Context"])
+ chunker = ContextualChunker(llm_gateway=llm)
+
+ result = await chunker.enhance_chunks(
+ "doc", ["chunk"], metadata={"source": "test", "doc_id": "123"}
+ )
+
+ assert result[0].metadata["source"] == "test"
+ assert result[0].metadata["doc_id"] == "123"
+ assert result[0].metadata["chunk_index"] == 0
+ assert "context_prefix" in result[0].metadata
+
+ @pytest.mark.asyncio
+ async def test_clear_cache(self):
+ """clear_cache should reset the context cache"""
+ llm = MockLLMGateway(responses=["Context A", "Context B"])
+ chunker = ContextualChunker(llm_gateway=llm)
+
+ await chunker.enhance_chunks("doc", ["chunk"])
+ assert llm._call_count == 1
+
+ chunker.clear_cache()
+
+ await chunker.enhance_chunks("doc", ["chunk"])
+ assert llm._call_count == 2 # Cache was cleared, new LLM call
+
+ def test_prompt_template_format(self):
+ """Prompt template should be formattable with document and chunk"""
+ formatted = CONTEXT_PROMPT_TEMPLATE.format(
+ document="Test document", chunk="Test chunk"
+ )
+ assert "Test document" in formatted
+ assert "Test chunk" in formatted
+ assert "Context:" in formatted