feat(memory): U2 Contextual Retrieval - LLM-generated context prefixes for chunks
- ContextualChunker: generates context prefixes per chunk via LLM
- Integrated into HttpRAGService ingest with contextual_chunking option
- Caching, batch processing, graceful LLM failure handling
- 12 tests passing
This commit is contained in:
parent
a6c9babfdc
commit
f16dcb5ebe
|
|
@ -0,0 +1,210 @@
|
|||
"""ContextualChunker - 上下文增强分块
|
||||
|
||||
在嵌入前为每个文档块添加 LLM 生成的上下文前缀,
|
||||
解决分块后上下文丢失问题(Anthropic Contextual Retrieval)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agentkit.memory.embedder import EmbeddingCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ContextualChunk:
    """A document chunk bundled with its context prefix.

    Attributes:
        original_content: The chunk text before any prefix is added.
        context_prefix: The context text for this chunk; an empty string
            means the chunk has no context.
        enhanced_content: The chunk text with the context prefix applied.
        chunk_index: Index of the chunk within its chunk list.
        metadata: Free-form metadata attached to the chunk.
    """

    original_content: str
    context_prefix: str
    enhanced_content: str
    chunk_index: int
    metadata: dict[str, Any]

    @property
    def content(self) -> str:
        """Return the full enhanced text (same value as ``enhanced_content``)."""
        return self.enhanced_content
|
||||
|
||||
|
||||
CONTEXT_PROMPT_TEMPLATE = """\
|
||||
Given the full document below and a specific chunk from it, write a brief context that helps someone understand what this chunk is about in the broader document. Output ONLY the context, no explanations.
|
||||
|
||||
<document>
|
||||
{document}
|
||||
</document>
|
||||
|
||||
<chunk>
|
||||
{chunk}
|
||||
</chunk>
|
||||
|
||||
Context:"""
|
||||
|
||||
|
||||
class ContextualChunker:
    """Adds an LLM-generated context prefix to each chunk of a document.

    Workflow:
        1. Receive the full document and its list of chunks.
        2. For each chunk, ask the LLM gateway for a short context statement.
        3. Prepend that context to the chunk text.
        4. Remember successful results so the same prompt is not sent twice.

    Cost notes:
        - Generated contexts are memoised in a per-instance, in-memory dict
          that grows until ``clear_cache()`` is called.
        - Chunks are walked in slices of ``batch_size``; the LLM calls
          themselves are awaited one after another.
        - The ``cache`` argument is stored but is not consulted by this class.
    """

    # Character budgets applied to the document and the chunk before they are
    # substituted into the prompt. The cache key is derived from the same
    # truncated text, so "same key" always means "same prompt".
    _DOC_PREVIEW_CHARS = 3000
    _CHUNK_PREVIEW_CHARS = 1000

    def __init__(
        self,
        llm_gateway: Any = None,
        cache: EmbeddingCache | None = None,
        batch_size: int = 8,
        max_context_length: int = 200,
        prompt_template: str = CONTEXT_PROMPT_TEMPLATE,
    ):
        """Configure the chunker.

        Args:
            llm_gateway: Object exposing an async ``chat(messages=..., model=...)``
                method whose result has a ``content`` attribute. When falsy,
                chunks are returned without any context.
            cache: Embedding cache instance. Stored on the instance only; it
                is not read or written by this class.
            batch_size: Number of chunks handled per batch. Values below 1 are
                treated as 1.
            max_context_length: Maximum number of characters kept from a
                generated context.
            prompt_template: ``str.format`` template with ``{document}`` and
                ``{chunk}`` placeholders.
        """
        self._llm_gateway = llm_gateway
        self._cache = cache
        self._batch_size = batch_size
        self._max_context_length = max_context_length
        self._prompt_template = prompt_template
        self._context_cache: dict[str, str] = {}

    async def enhance_chunks(
        self,
        document: str,
        chunks: list[str],
        metadata: dict[str, Any] | None = None,
    ) -> list[ContextualChunk]:
        """Attach a context prefix to every chunk of ``document``.

        Args:
            document: Full document text.
            chunks: Chunks taken from the document, in order.
            metadata: Extra metadata copied into every returned chunk.

        Returns:
            One ``ContextualChunk`` per input chunk, in input order. Each
            chunk carries its own metadata dict containing ``chunk_index``,
            ``context_prefix`` and ``has_context`` in addition to the
            caller-supplied entries.
        """
        if not chunks:
            return []

        if not self._llm_gateway:
            # No LLM available: return the chunks unchanged, but still give
            # each one an independent metadata dict with the standard keys.
            logger.info("No LLM gateway configured, skipping contextual enhancement")
            return [
                self._build_chunk(chunk, "", index, metadata)
                for index, chunk in enumerate(chunks)
            ]

        # A step of 0 would make range() raise and a negative step would
        # silently yield nothing, so clamp to at least one chunk per batch.
        step = max(1, self._batch_size)

        enhanced: list[ContextualChunk] = []
        for batch_start in range(0, len(chunks), step):
            batch = chunks[batch_start : batch_start + step]
            enhanced.extend(
                await self._process_batch(document, batch, batch_start, metadata)
            )
        return enhanced

    async def _process_batch(
        self,
        document: str,
        chunks: list[str],
        start_index: int,
        metadata: dict[str, Any] | None,
    ) -> list[ContextualChunk]:
        """Enhance one batch of chunks; ``start_index`` is the first chunk's index."""
        results: list[ContextualChunk] = []
        for offset, chunk in enumerate(chunks):
            context = await self._get_context(document, chunk)
            results.append(
                self._build_chunk(chunk, context, start_index + offset, metadata)
            )
        return results

    async def _get_context(self, document: str, chunk: str) -> str:
        """Return the (length-limited) context for a chunk, using the cache."""
        cache_key = self._make_cache_key(document, chunk)
        context = self._context_cache.get(cache_key)
        if context is None:
            context = await self._generate_context(document, chunk)
            # Only successful, non-empty results are cached. An empty string
            # is what a failed LLM call returns, and caching it would turn a
            # transient failure into a permanent "no context" for this chunk.
            if context:
                self._context_cache[cache_key] = context

        # The cache keeps the untruncated text; the limit is applied on read.
        return context[: self._max_context_length]

    @staticmethod
    def _build_chunk(
        chunk: str,
        context: str,
        chunk_index: int,
        metadata: dict[str, Any] | None,
    ) -> ContextualChunk:
        """Assemble a ``ContextualChunk`` with its own copy of the metadata."""
        # Copy so chunks never share (or mutate) the caller's dict.
        chunk_meta = dict(metadata or {})
        chunk_meta["chunk_index"] = chunk_index
        chunk_meta["context_prefix"] = context
        chunk_meta["has_context"] = bool(context)

        return ContextualChunk(
            original_content=chunk,
            context_prefix=context,
            enhanced_content=f"{context}\n{chunk}" if context else chunk,
            chunk_index=chunk_index,
            metadata=chunk_meta,
        )

    async def _generate_context(self, document: str, chunk: str) -> str:
        """Ask the LLM gateway for the context of a single chunk.

        Returns:
            The stripped response text, or an empty string when the gateway
            call (or reading its response) raises.
        """
        # Only a bounded prefix of the document and chunk goes into the prompt.
        prompt = self._prompt_template.format(
            document=document[: self._DOC_PREVIEW_CHARS],
            chunk=chunk[: self._CHUNK_PREVIEW_CHARS],
        )

        try:
            response = await self._llm_gateway.chat(
                messages=[{"role": "user", "content": prompt}],
                model="default",
            )
            return response.content.strip()
        except Exception as e:
            # Deliberate best-effort: a failed call degrades to "no context"
            # instead of aborting the whole enhancement run.
            logger.warning("Context generation failed for chunk: %s", e)
            return ""

    @staticmethod
    def _make_cache_key(document: str, chunk: str) -> str:
        """Build a cache key from exactly the text that reaches the prompt.

        The key hashes the same truncated document and chunk that
        ``_generate_context`` substitutes into the prompt. Each part is
        length-prefixed so that moving characters across the document/chunk
        boundary cannot produce the same byte stream.
        """
        parts = (
            document[: ContextualChunker._DOC_PREVIEW_CHARS],
            chunk[: ContextualChunker._CHUNK_PREVIEW_CHARS],
        )
        digest = hashlib.sha256()
        for part in parts:
            encoded = part.encode()
            digest.update(len(encoded).to_bytes(8, "big"))
            digest.update(encoded)
        return digest.hexdigest()

    def clear_cache(self) -> None:
        """Drop every cached context."""
        self._context_cache.clear()
|
||||
|
|
@ -29,6 +29,7 @@ class HttpRAGService:
|
|||
- "industry-kb-id"
|
||||
- "enterprise-kb-id"
|
||||
timeout: 30
|
||||
contextual_chunking: false
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -37,6 +38,8 @@ class HttpRAGService:
|
|||
api_key: str | None = None,
|
||||
knowledge_base_ids: list[str] | None = None,
|
||||
timeout: int = 30,
|
||||
contextual_chunking: bool = False,
|
||||
llm_gateway: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -50,6 +53,8 @@ class HttpRAGService:
|
|||
self._knowledge_base_ids = knowledge_base_ids or []
|
||||
self._timeout = timeout
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
self._contextual_chunking = contextual_chunking
|
||||
self._llm_gateway = llm_gateway
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
"""懒初始化 httpx 客户端"""
|
||||
|
|
@ -232,6 +237,9 @@ class HttpRAGService:
|
|||
) -> dict[str, Any] | None:
|
||||
"""写入文档到知识库(可选操作)
|
||||
|
||||
When contextual_chunking is enabled and llm_gateway is configured,
|
||||
the document content is enhanced with contextual prefixes before ingestion.
|
||||
|
||||
Args:
|
||||
key: 文档标题或标识
|
||||
value: 文档内容
|
||||
|
|
@ -245,9 +253,25 @@ class HttpRAGService:
|
|||
logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
|
||||
return None
|
||||
|
||||
content = str(value)
|
||||
|
||||
# Apply contextual chunking if enabled
|
||||
if self._contextual_chunking and self._llm_gateway:
|
||||
from agentkit.memory.contextual_retrieval import ContextualChunker
|
||||
|
||||
chunker = ContextualChunker(llm_gateway=self._llm_gateway)
|
||||
# Simple chunking: split by paragraphs
|
||||
raw_chunks = [c.strip() for c in content.split("\n\n") if c.strip()]
|
||||
if raw_chunks:
|
||||
enhanced = await chunker.enhance_chunks(
|
||||
document=content, chunks=raw_chunks, metadata=metadata
|
||||
)
|
||||
# Rejoin enhanced chunks
|
||||
content = "\n\n".join(chunk.enhanced_content for chunk in enhanced)
|
||||
|
||||
payload = {
|
||||
"title": key,
|
||||
"content": str(value),
|
||||
"content": content,
|
||||
"source_type": "text",
|
||||
"metadata": metadata or {},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,190 @@
|
|||
"""Tests for ContextualChunker"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.memory.contextual_retrieval import (
|
||||
ContextualChunker,
|
||||
ContextualChunk,
|
||||
CONTEXT_PROMPT_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
class MockLLMGateway:
    """Test double for the LLM gateway: replays canned replies and records calls."""

    def __init__(self, responses: list[str] | None = None):
        # A missing or empty list falls back to one default reply.
        self._responses = responses if responses else ["This chunk discusses revenue growth."]
        self._call_count = 0
        self._last_messages = None

    async def chat(self, messages, model="default", **kwargs):
        self._last_messages = messages
        self._call_count += 1

        # Replies are handed out in order; once they run out, the last one
        # is repeated for every further call.
        reply_index = min(self._call_count, len(self._responses)) - 1
        reply = self._responses[reply_index]

        class MockResponse:
            content = reply

        return MockResponse()
|
||||
|
||||
|
||||
class TestContextualChunk:
    """Checks for the ContextualChunk dataclass."""

    def test_content_property(self):
        """``content`` exposes the enhanced text when a prefix is present."""
        fields = dict(
            original_content="Revenue grew 3%",
            context_prefix="From Acme Q2 2023 report",
            enhanced_content="From Acme Q2 2023 report\nRevenue grew 3%",
            chunk_index=0,
            metadata={},
        )
        chunk = ContextualChunk(**fields)
        assert chunk.content == "From Acme Q2 2023 report\nRevenue grew 3%"

    def test_empty_context(self):
        """``content`` is just the chunk text when there is no prefix."""
        fields = dict(
            original_content="Some text",
            context_prefix="",
            enhanced_content="Some text",
            chunk_index=0,
            metadata={},
        )
        chunk = ContextualChunk(**fields)
        assert chunk.content == "Some text"
|
||||
|
||||
|
||||
class TestContextualChunker:
    """Unit tests for ContextualChunker."""

    @pytest.mark.asyncio
    async def test_enhance_chunks_with_llm(self):
        """The LLM-produced context is attached to each chunk."""
        gateway = MockLLMGateway(responses=["From the financial report section"])
        chunker = ContextualChunker(llm_gateway=gateway)

        chunks = ["Revenue grew 3%.", "Profit increased 5%."]
        document = "Acme Corp Q2 2023 Report\n\nRevenue grew 3%.\n\nProfit increased 5%."

        enhanced = await chunker.enhance_chunks(document, chunks)

        assert len(enhanced) == 2
        first = enhanced[0]
        assert first.original_content == "Revenue grew 3%."
        assert first.context_prefix == "From the financial report section"
        assert "From the financial report section" in first.enhanced_content
        assert "Revenue grew 3%." in first.enhanced_content
        assert first.chunk_index == 0
        assert first.metadata["has_context"] is True

    @pytest.mark.asyncio
    async def test_enhance_chunks_without_llm(self):
        """With no gateway configured, chunks come back without context."""
        chunker = ContextualChunker(llm_gateway=None)

        enhanced = await chunker.enhance_chunks("Test document", ["Chunk 1", "Chunk 2"])

        assert len(enhanced) == 2
        first = enhanced[0]
        assert first.context_prefix == ""
        assert first.enhanced_content == "Chunk 1"
        assert first.metadata.get("has_context") is not True

    @pytest.mark.asyncio
    async def test_enhance_empty_chunks(self):
        """An empty chunk list produces an empty result."""
        chunker = ContextualChunker(llm_gateway=MockLLMGateway())
        enhanced = await chunker.enhance_chunks("document", [])
        assert enhanced == []

    @pytest.mark.asyncio
    async def test_context_caching(self):
        """A repeated document+chunk pair is served from the cache."""
        gateway = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=gateway)

        document, chunks = "Test document", ["Chunk 1"]

        # First pass reaches the gateway.
        first_pass = await chunker.enhance_chunks(document, chunks)
        assert first_pass[0].context_prefix == "Context A"
        assert gateway._call_count == 1

        # Second pass with identical input must not call the gateway again.
        second_pass = await chunker.enhance_chunks(document, chunks)
        assert second_pass[0].context_prefix == "Context A"
        assert gateway._call_count == 1  # No additional LLM call

    @pytest.mark.asyncio
    async def test_context_truncation(self):
        """An over-long context is cut down to the configured limit."""
        gateway = MockLLMGateway(responses=["A" * 500])
        chunker = ContextualChunker(llm_gateway=gateway, max_context_length=100)

        enhanced = await chunker.enhance_chunks("doc", ["chunk"])
        assert len(enhanced[0].context_prefix) <= 100

    @pytest.mark.asyncio
    async def test_llm_failure_returns_empty_context(self):
        """A failing gateway yields an empty context instead of an error."""

        class FailingLLM:
            async def chat(self, messages, model="default", **kwargs):
                raise RuntimeError("LLM unavailable")

        chunker = ContextualChunker(llm_gateway=FailingLLM())
        enhanced = await chunker.enhance_chunks("doc", ["chunk"])

        assert len(enhanced) == 1
        only = enhanced[0]
        assert only.context_prefix == ""
        assert only.enhanced_content == "chunk"

    @pytest.mark.asyncio
    async def test_batch_processing(self):
        """Chunks spanning several batches keep their global indices."""
        gateway = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=gateway, batch_size=3)

        chunks = [f"Chunk {i}" for i in range(7)]
        enhanced = await chunker.enhance_chunks("doc", chunks)

        assert len(enhanced) == 7
        assert [item.chunk_index for item in enhanced] == list(range(7))

    @pytest.mark.asyncio
    async def test_metadata_preserved(self):
        """Caller metadata is kept and extended with chunk-level keys."""
        gateway = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=gateway)

        enhanced = await chunker.enhance_chunks(
            "doc", ["chunk"], metadata={"source": "test", "doc_id": "123"}
        )

        meta = enhanced[0].metadata
        assert meta["source"] == "test"
        assert meta["doc_id"] == "123"
        assert meta["chunk_index"] == 0
        assert "context_prefix" in meta

    @pytest.mark.asyncio
    async def test_clear_cache(self):
        """After clear_cache, the same input triggers a new gateway call."""
        gateway = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=gateway)

        await chunker.enhance_chunks("doc", ["chunk"])
        assert gateway._call_count == 1

        chunker.clear_cache()

        await chunker.enhance_chunks("doc", ["chunk"])
        assert gateway._call_count == 2  # Cache was cleared, new LLM call

    def test_prompt_template_format(self):
        """The default template accepts ``document`` and ``chunk`` fields."""
        rendered = CONTEXT_PROMPT_TEMPLATE.format(
            document="Test document", chunk="Test chunk"
        )
        for expected in ("Test document", "Test chunk", "Context:"):
            assert expected in rendered
|
||||
Loading…
Reference in New Issue