feat(memory): U2 Contextual Retrieval - LLM-generated context prefixes for chunks

- ContextualChunker: generates context prefixes per chunk via LLM
- Integrated into HttpRAGService ingest with contextual_chunking option
- Caching, batch processing, graceful LLM failure handling
- 12 tests passing
This commit is contained in:
chiguyong 2026-06-06 22:19:02 +08:00
parent a6c9babfdc
commit f16dcb5ebe
3 changed files with 425 additions and 1 deletions

View File

@ -0,0 +1,210 @@
"""ContextualChunker - 上下文增强分块
在嵌入前为每个文档块添加 LLM 生成的上下文前缀
解决分块后上下文丢失问题Anthropic Contextual Retrieval
"""
from __future__ import annotations
import hashlib
import logging
from dataclasses import dataclass
from typing import Any
from agentkit.memory.embedder import EmbeddingCache
logger = logging.getLogger(__name__)
@dataclass
class ContextualChunk:
    """A document chunk together with its generated context prefix."""

    # Chunk text exactly as it was supplied by the caller.
    original_content: str
    # Context statement for the chunk; empty when none is available.
    context_prefix: str
    # Text intended for embedding (prefix plus original, as assembled by
    # ContextualChunker; equals the original when the prefix is empty).
    enhanced_content: str
    # Position of the chunk within the list of chunks it came from.
    chunk_index: int
    # Free-form metadata attached to the chunk.
    metadata: dict[str, Any]

    @property
    def content(self) -> str:
        """Return the full enhanced text (alias of ``enhanced_content``)."""
        enhanced_text = self.enhanced_content
        return enhanced_text
# Default prompt used to ask the LLM for a chunk's context statement.
# The `{document}` and `{chunk}` placeholders are filled in with str.format()
# by ContextualChunker._generate_context.
CONTEXT_PROMPT_TEMPLATE = """\
Given the full document below and a specific chunk from it, write a brief context that helps someone understand what this chunk is about in the broader document. Output ONLY the context, no explanations.
<document>
{document}
</document>
<chunk>
{chunk}
</chunk>
Context:"""
class ContextualChunker:
    """Prefix document chunks with LLM-generated context before embedding.

    Workflow:
        1. Receive a document and the list of chunks cut from it.
        2. Ask the LLM for a short context statement for each chunk.
        3. Prepend that statement to the chunk's original text.
        4. Memoise generated contexts in a per-instance dict so a repeated
           (document, chunk) pair does not trigger another LLM call.

    Notes:
        * Chunks are walked in slices of ``batch_size``, but the LLM calls
          are awaited one after another; batching adds no concurrency.
        * The ``cache`` argument is stored but not consulted by this class.
        * Failed or empty generations are never memoised, so a transient
          LLM outage does not permanently strip context from a chunk.
    """

    # Character budgets applied to the prompt inputs. The cache key is
    # derived from the same truncated views, so two inputs share a cache
    # entry only when they would produce the very same prompt.
    _MAX_DOCUMENT_CHARS = 3000
    _MAX_CHUNK_CHARS = 1000

    def __init__(
        self,
        llm_gateway: Any = None,
        cache: EmbeddingCache | None = None,
        batch_size: int = 8,
        max_context_length: int = 200,
        prompt_template: str = CONTEXT_PROMPT_TEMPLATE,
    ):
        """Configure the chunker.

        Args:
            llm_gateway: Object exposing an awaitable
                ``chat(messages=..., model=...)`` whose result has a
                ``content`` string. When falsy, chunks are returned
                without context.
            cache: Embedding cache instance; kept on the instance but
                currently unused here.
            batch_size: Number of chunks handled per batch (>= 1).
            max_context_length: Maximum number of characters kept from a
                generated context (>= 0).
            prompt_template: ``str.format`` template with ``{document}``
                and ``{chunk}`` placeholders.

        Raises:
            ValueError: If ``batch_size`` < 1 or ``max_context_length`` < 0.
        """
        # A non-positive batch size would either crash range() later or
        # silently yield no chunks at all, so reject it up front.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # A negative limit would chop characters off the end of every
        # context instead of capping its length.
        if max_context_length < 0:
            raise ValueError(
                f"max_context_length must be >= 0, got {max_context_length}"
            )
        self._llm_gateway = llm_gateway
        self._cache = cache
        self._batch_size = batch_size
        self._max_context_length = max_context_length
        self._prompt_template = prompt_template
        self._context_cache: dict[str, str] = {}

    async def enhance_chunks(
        self,
        document: str,
        chunks: list[str],
        metadata: dict[str, Any] | None = None,
    ) -> list[ContextualChunk]:
        """Attach a context prefix to every chunk of a document.

        Args:
            document: Full document text.
            chunks: Chunks cut from the document, in order.
            metadata: Extra metadata copied onto every returned chunk.

        Returns:
            One ``ContextualChunk`` per input chunk, in input order. Each
            chunk owns its own metadata dict, extended with
            ``chunk_index``, ``context_prefix`` and ``has_context``.
        """
        if not chunks:
            return []
        if not self._llm_gateway:
            # No LLM available: pass chunks through with an empty prefix,
            # using the same metadata shape as the LLM path.
            logger.info("No LLM gateway configured, skipping contextual enhancement")
            return [
                self._build_chunk(chunk, "", index, metadata)
                for index, chunk in enumerate(chunks)
            ]
        result: list[ContextualChunk] = []
        for batch_start in range(0, len(chunks), self._batch_size):
            batch = chunks[batch_start : batch_start + self._batch_size]
            result.extend(
                await self._process_batch(document, batch, batch_start, metadata)
            )
        return result

    async def _process_batch(
        self,
        document: str,
        chunks: list[str],
        start_index: int,
        metadata: dict[str, Any] | None,
    ) -> list[ContextualChunk]:
        """Enhance one batch of chunks, reusing memoised contexts.

        Args:
            document: Full document text.
            chunks: The chunks belonging to this batch.
            start_index: Index of the batch's first chunk in the full list.
            metadata: Extra metadata copied onto every returned chunk.

        Returns:
            The enhanced chunks of this batch, in order.
        """
        results: list[ContextualChunk] = []
        for offset, chunk in enumerate(chunks):
            cache_key = self._make_cache_key(document, chunk)
            context = self._context_cache.get(cache_key)
            if context is None:
                context = await self._generate_context(document, chunk)
                # An empty string means the call failed or produced
                # nothing useful; leave it uncached so a later call retries.
                if context:
                    self._context_cache[cache_key] = context
            # Cap the prefix length (a no-op for short contexts).
            context = context[: self._max_context_length]
            results.append(
                self._build_chunk(chunk, context, start_index + offset, metadata)
            )
        return results

    @staticmethod
    def _build_chunk(
        chunk: str,
        context: str,
        chunk_index: int,
        metadata: dict[str, Any] | None,
    ) -> ContextualChunk:
        """Assemble a ``ContextualChunk`` with its own metadata copy."""
        # Copy so chunks never share (or mutate) the caller's dict.
        chunk_meta = dict(metadata or {})
        chunk_meta["chunk_index"] = chunk_index
        chunk_meta["context_prefix"] = context
        chunk_meta["has_context"] = bool(context)
        enhanced = f"{context}\n{chunk}" if context else chunk
        return ContextualChunk(
            original_content=chunk,
            context_prefix=context,
            enhanced_content=enhanced,
            chunk_index=chunk_index,
            metadata=chunk_meta,
        )

    async def _generate_context(self, document: str, chunk: str) -> str:
        """Ask the LLM for the context of a single chunk.

        Only the leading ``_MAX_DOCUMENT_CHARS`` of the document and
        ``_MAX_CHUNK_CHARS`` of the chunk are placed in the prompt.

        Returns:
            The stripped LLM response, or ``""`` if the call fails.
        """
        prompt = self._prompt_template.format(
            document=document[: self._MAX_DOCUMENT_CHARS],
            chunk=chunk[: self._MAX_CHUNK_CHARS],
        )
        try:
            response = await self._llm_gateway.chat(
                messages=[{"role": "user", "content": prompt}],
                model="default",
            )
            return response.content.strip()
        except Exception as e:
            # Deliberate best-effort: a failed generation degrades to
            # "no context" instead of aborting the whole enhancement.
            logger.warning("Context generation failed for chunk: %s", e)
            return ""

    @staticmethod
    def _make_cache_key(document: str, chunk: str) -> str:
        """Build the memoisation key for a (document, chunk) pair.

        The key covers exactly the text that ends up in the prompt, so
        inputs that differ anywhere inside the prompt window get distinct
        keys. The document view is length-prefixed so the boundary between
        document and chunk cannot be shifted to forge an equal key.
        """
        doc_view = document[: ContextualChunker._MAX_DOCUMENT_CHARS]
        chunk_view = chunk[: ContextualChunker._MAX_CHUNK_CHARS]
        digest = hashlib.sha256()
        digest.update(f"{len(doc_view)}:".encode("utf-8"))
        digest.update(doc_view.encode("utf-8"))
        digest.update(chunk_view.encode("utf-8"))
        return digest.hexdigest()

    def clear_cache(self) -> None:
        """Drop every memoised context."""
        self._context_cache.clear()

View File

@ -29,6 +29,7 @@ class HttpRAGService:
- "industry-kb-id" - "industry-kb-id"
- "enterprise-kb-id" - "enterprise-kb-id"
timeout: 30 timeout: 30
contextual_chunking: false
""" """
def __init__( def __init__(
@ -37,6 +38,8 @@ class HttpRAGService:
api_key: str | None = None, api_key: str | None = None,
knowledge_base_ids: list[str] | None = None, knowledge_base_ids: list[str] | None = None,
timeout: int = 30, timeout: int = 30,
contextual_chunking: bool = False,
llm_gateway: Any = None,
): ):
""" """
Args: Args:
@ -50,6 +53,8 @@ class HttpRAGService:
self._knowledge_base_ids = knowledge_base_ids or [] self._knowledge_base_ids = knowledge_base_ids or []
self._timeout = timeout self._timeout = timeout
self._client: httpx.AsyncClient | None = None self._client: httpx.AsyncClient | None = None
self._contextual_chunking = contextual_chunking
self._llm_gateway = llm_gateway
def _get_client(self) -> httpx.AsyncClient: def _get_client(self) -> httpx.AsyncClient:
"""懒初始化 httpx 客户端""" """懒初始化 httpx 客户端"""
@ -232,6 +237,9 @@ class HttpRAGService:
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""写入文档到知识库(可选操作) """写入文档到知识库(可选操作)
When contextual_chunking is enabled and llm_gateway is configured,
the document content is enhanced with contextual prefixes before ingestion.
Args: Args:
key: 文档标题或标识 key: 文档标题或标识
value: 文档内容 value: 文档内容
@ -245,9 +253,25 @@ class HttpRAGService:
logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured") logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
return None return None
content = str(value)
# Apply contextual chunking if enabled
if self._contextual_chunking and self._llm_gateway:
from agentkit.memory.contextual_retrieval import ContextualChunker
chunker = ContextualChunker(llm_gateway=self._llm_gateway)
# Simple chunking: split by paragraphs
raw_chunks = [c.strip() for c in content.split("\n\n") if c.strip()]
if raw_chunks:
enhanced = await chunker.enhance_chunks(
document=content, chunks=raw_chunks, metadata=metadata
)
# Rejoin enhanced chunks
content = "\n\n".join(chunk.enhanced_content for chunk in enhanced)
payload = { payload = {
"title": key, "title": key,
"content": str(value), "content": content,
"source_type": "text", "source_type": "text",
"metadata": metadata or {}, "metadata": metadata or {},
} }

View File

@ -0,0 +1,190 @@
"""Tests for ContextualChunker"""
import pytest
from agentkit.memory.contextual_retrieval import (
ContextualChunker,
ContextualChunk,
CONTEXT_PROMPT_TEMPLATE,
)
class MockLLMGateway:
"""Mock LLM Gateway for testing"""
def __init__(self, responses: list[str] | None = None):
self._responses = responses or ["This chunk discusses revenue growth."]
self._call_count = 0
self._last_messages = None
async def chat(self, messages, model="default", **kwargs):
self._call_count += 1
self._last_messages = messages
class MockResponse:
content = self._responses[min(self._call_count - 1, len(self._responses) - 1)]
return MockResponse()
class TestContextualChunk:
    """Checks for the ContextualChunk dataclass."""

    def test_content_property(self):
        """``content`` returns the enhanced text when a prefix is present."""
        fields = {
            "original_content": "Revenue grew 3%",
            "context_prefix": "From Acme Q2 2023 report",
            "enhanced_content": "From Acme Q2 2023 report\nRevenue grew 3%",
            "chunk_index": 0,
            "metadata": {},
        }
        sample = ContextualChunk(**fields)
        assert sample.content == "From Acme Q2 2023 report\nRevenue grew 3%"

    def test_empty_context(self):
        """``content`` returns the bare text when the prefix is empty."""
        fields = {
            "original_content": "Some text",
            "context_prefix": "",
            "enhanced_content": "Some text",
            "chunk_index": 0,
            "metadata": {},
        }
        sample = ContextualChunk(**fields)
        assert sample.content == "Some text"
class TestContextualChunker:
    """Unit tests covering ContextualChunker behaviour."""

    @pytest.mark.asyncio
    async def test_enhance_chunks_with_llm(self):
        """Chunks should be enhanced with LLM-generated context."""
        expected_context = "From the financial report section"
        gateway = MockLLMGateway(responses=[expected_context])
        chunker = ContextualChunker(llm_gateway=gateway)
        document = "Acme Corp Q2 2023 Report\n\nRevenue grew 3%.\n\nProfit increased 5%."
        enhanced = await chunker.enhance_chunks(
            document, ["Revenue grew 3%.", "Profit increased 5%."]
        )
        assert len(enhanced) == 2
        first = enhanced[0]
        assert first.original_content == "Revenue grew 3%."
        assert first.context_prefix == expected_context
        assert expected_context in first.enhanced_content
        assert "Revenue grew 3%." in first.enhanced_content
        assert first.chunk_index == 0
        assert first.metadata["has_context"] is True

    @pytest.mark.asyncio
    async def test_enhance_chunks_without_llm(self):
        """Without an LLM, chunks come back without any context."""
        chunker = ContextualChunker(llm_gateway=None)
        enhanced = await chunker.enhance_chunks("Test document", ["Chunk 1", "Chunk 2"])
        assert len(enhanced) == 2
        first = enhanced[0]
        assert first.context_prefix == ""
        assert first.enhanced_content == "Chunk 1"
        assert first.metadata.get("has_context") is not True

    @pytest.mark.asyncio
    async def test_enhance_empty_chunks(self):
        """An empty chunk list produces an empty result."""
        chunker = ContextualChunker(llm_gateway=MockLLMGateway())
        assert await chunker.enhance_chunks("document", []) == []

    @pytest.mark.asyncio
    async def test_context_caching(self):
        """The same document+chunk pair is served from the cache."""
        gateway = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=gateway)
        document, chunks = "Test document", ["Chunk 1"]

        # First call reaches the LLM.
        first_run = await chunker.enhance_chunks(document, chunks)
        assert first_run[0].context_prefix == "Context A"
        assert gateway._call_count == 1

        # Identical second call must not trigger another LLM request.
        second_run = await chunker.enhance_chunks(document, chunks)
        assert second_run[0].context_prefix == "Context A"
        assert gateway._call_count == 1

    @pytest.mark.asyncio
    async def test_context_truncation(self):
        """An over-long context is cut down to the configured maximum."""
        gateway = MockLLMGateway(responses=["A" * 500])
        chunker = ContextualChunker(llm_gateway=gateway, max_context_length=100)
        enhanced = await chunker.enhance_chunks("doc", ["chunk"])
        assert len(enhanced[0].context_prefix) <= 100

    @pytest.mark.asyncio
    async def test_llm_failure_returns_empty_context(self):
        """An LLM failure yields an empty context rather than an error."""

        class FailingLLM:
            async def chat(self, messages, model="default", **kwargs):
                raise RuntimeError("LLM unavailable")

        chunker = ContextualChunker(llm_gateway=FailingLLM())
        enhanced = await chunker.enhance_chunks("doc", ["chunk"])
        assert len(enhanced) == 1
        only = enhanced[0]
        assert only.context_prefix == ""
        assert only.enhanced_content == "chunk"

    @pytest.mark.asyncio
    async def test_batch_processing(self):
        """Chunk indices stay sequential across batch boundaries."""
        gateway = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=gateway, batch_size=3)
        inputs = [f"Chunk {i}" for i in range(7)]
        enhanced = await chunker.enhance_chunks("doc", inputs)
        assert len(enhanced) == 7
        assert [item.chunk_index for item in enhanced] == list(range(7))

    @pytest.mark.asyncio
    async def test_metadata_preserved(self):
        """Caller metadata is kept and extended with chunk details."""
        gateway = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=gateway)
        caller_meta = {"source": "test", "doc_id": "123"}
        enhanced = await chunker.enhance_chunks("doc", ["chunk"], metadata=caller_meta)
        meta = enhanced[0].metadata
        assert meta["source"] == "test"
        assert meta["doc_id"] == "123"
        assert meta["chunk_index"] == 0
        assert "context_prefix" in meta

    @pytest.mark.asyncio
    async def test_clear_cache(self):
        """clear_cache() forces the next call back to the LLM."""
        gateway = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=gateway)
        await chunker.enhance_chunks("doc", ["chunk"])
        assert gateway._call_count == 1
        chunker.clear_cache()
        await chunker.enhance_chunks("doc", ["chunk"])
        # The cache was emptied, so the LLM is consulted again.
        assert gateway._call_count == 2

    def test_prompt_template_format(self):
        """The prompt template accepts ``document`` and ``chunk`` fields."""
        rendered = CONTEXT_PROMPT_TEMPLATE.format(
            document="Test document", chunk="Test chunk"
        )
        for expected_piece in ("Test document", "Test chunk", "Context:"):
            assert expected_piece in rendered