feat(memory): U2 Contextual Retrieval - LLM-generated context prefixes for chunks
- ContextualChunker: generates context prefixes per chunk via LLM - Integrated into HttpRAGService ingest with contextual_chunking option - Caching, batch processing, graceful LLM failure handling - 12 tests passing
This commit is contained in:
parent
a6c9babfdc
commit
f16dcb5ebe
|
|
@ -0,0 +1,210 @@
|
||||||
|
"""ContextualChunker - 上下文增强分块
|
||||||
|
|
||||||
|
在嵌入前为每个文档块添加 LLM 生成的上下文前缀,
|
||||||
|
解决分块后上下文丢失问题(Anthropic Contextual Retrieval)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.memory.embedder import EmbeddingCache
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ContextualChunk:
|
||||||
|
"""带上下文前缀的文档块"""
|
||||||
|
|
||||||
|
original_content: str
|
||||||
|
context_prefix: str
|
||||||
|
enhanced_content: str
|
||||||
|
chunk_index: int
|
||||||
|
metadata: dict[str, Any]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
"""获取增强后的完整内容"""
|
||||||
|
return self.enhanced_content
|
||||||
|
|
||||||
|
|
||||||
|
CONTEXT_PROMPT_TEMPLATE = """\
|
||||||
|
Given the full document below and a specific chunk from it, write a brief context that helps someone understand what this chunk is about in the broader document. Output ONLY the context, no explanations.
|
||||||
|
|
||||||
|
<document>
|
||||||
|
{document}
|
||||||
|
</document>
|
||||||
|
|
||||||
|
<chunk>
|
||||||
|
{chunk}
|
||||||
|
</chunk>
|
||||||
|
|
||||||
|
Context:"""
|
||||||
|
|
||||||
|
|
||||||
|
class ContextualChunker:
|
||||||
|
"""上下文增强分块器
|
||||||
|
|
||||||
|
为每个文档块生成 LLM 上下文前缀,增强检索质量。
|
||||||
|
|
||||||
|
工作流程:
|
||||||
|
1. 接收文档和分块列表
|
||||||
|
2. 对每个块,调用 LLM 生成简洁上下文语句
|
||||||
|
3. 将上下文前缀添加到原始内容前
|
||||||
|
4. 缓存结果避免重复计算
|
||||||
|
|
||||||
|
成本优化:
|
||||||
|
- 文档级 Prompt Caching(同一文档的多个块共享文档前缀)
|
||||||
|
- EmbeddingCache 缓存上下文生成结果
|
||||||
|
- 批处理(batch_size)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_gateway: Any = None,
|
||||||
|
cache: EmbeddingCache | None = None,
|
||||||
|
batch_size: int = 8,
|
||||||
|
max_context_length: int = 200,
|
||||||
|
prompt_template: str = CONTEXT_PROMPT_TEMPLATE,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
llm_gateway: LLM Gateway 实例,用于生成上下文
|
||||||
|
cache: 嵌入缓存,用于缓存上下文生成结果
|
||||||
|
batch_size: 批处理大小
|
||||||
|
max_context_length: 上下文最大字符长度
|
||||||
|
prompt_template: 上下文生成 prompt 模板
|
||||||
|
"""
|
||||||
|
self._llm_gateway = llm_gateway
|
||||||
|
self._cache = cache
|
||||||
|
self._batch_size = batch_size
|
||||||
|
self._max_context_length = max_context_length
|
||||||
|
self._prompt_template = prompt_template
|
||||||
|
self._context_cache: dict[str, str] = {}
|
||||||
|
|
||||||
|
async def enhance_chunks(
|
||||||
|
self,
|
||||||
|
document: str,
|
||||||
|
chunks: list[str],
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> list[ContextualChunk]:
|
||||||
|
"""为文档块添加上下文前缀
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document: 完整文档内容
|
||||||
|
chunks: 文档分块列表
|
||||||
|
metadata: 附加元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
增强后的 ContextualChunk 列表
|
||||||
|
"""
|
||||||
|
if not chunks:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not self._llm_gateway:
|
||||||
|
# No LLM available — return chunks without context
|
||||||
|
logger.info("No LLM gateway configured, skipping contextual enhancement")
|
||||||
|
return [
|
||||||
|
ContextualChunk(
|
||||||
|
original_content=chunk,
|
||||||
|
context_prefix="",
|
||||||
|
enhanced_content=chunk,
|
||||||
|
chunk_index=i,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
for i, chunk in enumerate(chunks)
|
||||||
|
]
|
||||||
|
|
||||||
|
result: list[ContextualChunk] = []
|
||||||
|
|
||||||
|
# Process in batches
|
||||||
|
for batch_start in range(0, len(chunks), self._batch_size):
|
||||||
|
batch = chunks[batch_start : batch_start + self._batch_size]
|
||||||
|
batch_results = await self._process_batch(document, batch, batch_start, metadata)
|
||||||
|
result.extend(batch_results)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _process_batch(
|
||||||
|
self,
|
||||||
|
document: str,
|
||||||
|
chunks: list[str],
|
||||||
|
start_index: int,
|
||||||
|
metadata: dict[str, Any] | None,
|
||||||
|
) -> list[ContextualChunk]:
|
||||||
|
"""处理一批文档块"""
|
||||||
|
results: list[ContextualChunk] = []
|
||||||
|
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
chunk_index = start_index + i
|
||||||
|
chunk_meta = dict(metadata or {})
|
||||||
|
chunk_meta["chunk_index"] = chunk_index
|
||||||
|
|
||||||
|
# Check cache
|
||||||
|
cache_key = self._make_cache_key(document, chunk)
|
||||||
|
if cache_key in self._context_cache:
|
||||||
|
context = self._context_cache[cache_key]
|
||||||
|
else:
|
||||||
|
context = await self._generate_context(document, chunk)
|
||||||
|
self._context_cache[cache_key] = context
|
||||||
|
|
||||||
|
# Truncate context if too long
|
||||||
|
if len(context) > self._max_context_length:
|
||||||
|
context = context[: self._max_context_length]
|
||||||
|
|
||||||
|
# Build enhanced content
|
||||||
|
if context:
|
||||||
|
enhanced = f"{context}\n{chunk}"
|
||||||
|
else:
|
||||||
|
enhanced = chunk
|
||||||
|
|
||||||
|
chunk_meta["context_prefix"] = context
|
||||||
|
chunk_meta["has_context"] = bool(context)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
ContextualChunk(
|
||||||
|
original_content=chunk,
|
||||||
|
context_prefix=context,
|
||||||
|
enhanced_content=enhanced,
|
||||||
|
chunk_index=chunk_index,
|
||||||
|
metadata=chunk_meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _generate_context(self, document: str, chunk: str) -> str:
|
||||||
|
"""使用 LLM 为单个块生成上下文"""
|
||||||
|
# Truncate document for prompt efficiency
|
||||||
|
doc_preview = document[:3000] if len(document) > 3000 else document
|
||||||
|
chunk_preview = chunk[:1000] if len(chunk) > 1000 else chunk
|
||||||
|
|
||||||
|
prompt = self._prompt_template.format(
|
||||||
|
document=doc_preview,
|
||||||
|
chunk=chunk_preview,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._llm_gateway.chat(
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
model="default",
|
||||||
|
)
|
||||||
|
context = response.content.strip()
|
||||||
|
return context
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Context generation failed for chunk: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_cache_key(document: str, chunk: str) -> str:
|
||||||
|
"""生成缓存键"""
|
||||||
|
content = f"{document[:500]}:{chunk[:500]}"
|
||||||
|
return hashlib.sha256(content.encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""清除上下文缓存"""
|
||||||
|
self._context_cache.clear()
|
||||||
|
|
@ -29,6 +29,7 @@ class HttpRAGService:
|
||||||
- "industry-kb-id"
|
- "industry-kb-id"
|
||||||
- "enterprise-kb-id"
|
- "enterprise-kb-id"
|
||||||
timeout: 30
|
timeout: 30
|
||||||
|
contextual_chunking: false
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -37,6 +38,8 @@ class HttpRAGService:
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
knowledge_base_ids: list[str] | None = None,
|
knowledge_base_ids: list[str] | None = None,
|
||||||
timeout: int = 30,
|
timeout: int = 30,
|
||||||
|
contextual_chunking: bool = False,
|
||||||
|
llm_gateway: Any = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -50,6 +53,8 @@ class HttpRAGService:
|
||||||
self._knowledge_base_ids = knowledge_base_ids or []
|
self._knowledge_base_ids = knowledge_base_ids or []
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
self._client: httpx.AsyncClient | None = None
|
self._client: httpx.AsyncClient | None = None
|
||||||
|
self._contextual_chunking = contextual_chunking
|
||||||
|
self._llm_gateway = llm_gateway
|
||||||
|
|
||||||
def _get_client(self) -> httpx.AsyncClient:
|
def _get_client(self) -> httpx.AsyncClient:
|
||||||
"""懒初始化 httpx 客户端"""
|
"""懒初始化 httpx 客户端"""
|
||||||
|
|
@ -232,6 +237,9 @@ class HttpRAGService:
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""写入文档到知识库(可选操作)
|
"""写入文档到知识库(可选操作)
|
||||||
|
|
||||||
|
When contextual_chunking is enabled and llm_gateway is configured,
|
||||||
|
the document content is enhanced with contextual prefixes before ingestion.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: 文档标题或标识
|
key: 文档标题或标识
|
||||||
value: 文档内容
|
value: 文档内容
|
||||||
|
|
@ -245,9 +253,25 @@ class HttpRAGService:
|
||||||
logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
|
logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
content = str(value)
|
||||||
|
|
||||||
|
# Apply contextual chunking if enabled
|
||||||
|
if self._contextual_chunking and self._llm_gateway:
|
||||||
|
from agentkit.memory.contextual_retrieval import ContextualChunker
|
||||||
|
|
||||||
|
chunker = ContextualChunker(llm_gateway=self._llm_gateway)
|
||||||
|
# Simple chunking: split by paragraphs
|
||||||
|
raw_chunks = [c.strip() for c in content.split("\n\n") if c.strip()]
|
||||||
|
if raw_chunks:
|
||||||
|
enhanced = await chunker.enhance_chunks(
|
||||||
|
document=content, chunks=raw_chunks, metadata=metadata
|
||||||
|
)
|
||||||
|
# Rejoin enhanced chunks
|
||||||
|
content = "\n\n".join(chunk.enhanced_content for chunk in enhanced)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"title": key,
|
"title": key,
|
||||||
"content": str(value),
|
"content": content,
|
||||||
"source_type": "text",
|
"source_type": "text",
|
||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,190 @@
|
||||||
|
"""Tests for ContextualChunker"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.memory.contextual_retrieval import (
|
||||||
|
ContextualChunker,
|
||||||
|
ContextualChunk,
|
||||||
|
CONTEXT_PROMPT_TEMPLATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockLLMGateway:
|
||||||
|
"""Mock LLM Gateway for testing"""
|
||||||
|
|
||||||
|
def __init__(self, responses: list[str] | None = None):
|
||||||
|
self._responses = responses or ["This chunk discusses revenue growth."]
|
||||||
|
self._call_count = 0
|
||||||
|
self._last_messages = None
|
||||||
|
|
||||||
|
async def chat(self, messages, model="default", **kwargs):
|
||||||
|
self._call_count += 1
|
||||||
|
self._last_messages = messages
|
||||||
|
|
||||||
|
class MockResponse:
|
||||||
|
content = self._responses[min(self._call_count - 1, len(self._responses) - 1)]
|
||||||
|
|
||||||
|
return MockResponse()
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextualChunk:
    """Checks for the ``content`` property of the ContextualChunk dataclass."""

    def test_content_property(self):
        # ``content`` must expose the enhanced text, prefix included.
        expected = "From Acme Q2 2023 report\nRevenue grew 3%"
        with_prefix = ContextualChunk(
            original_content="Revenue grew 3%",
            context_prefix="From Acme Q2 2023 report",
            enhanced_content=expected,
            chunk_index=0,
            metadata={},
        )
        assert with_prefix.content == expected

    def test_empty_context(self):
        # Without a prefix the enhanced text is just the original text.
        fields = {
            "original_content": "Some text",
            "context_prefix": "",
            "enhanced_content": "Some text",
            "chunk_index": 0,
            "metadata": {},
        }
        without_prefix = ContextualChunk(**fields)
        assert without_prefix.content == "Some text"
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextualChunker:
    """Behavioural tests for ContextualChunker.enhance_chunks and its cache."""

    @pytest.mark.asyncio
    async def test_enhance_chunks_with_llm(self):
        """The gateway's reply becomes the context prefix of each chunk."""
        expected_context = "From the financial report section"
        gateway = MockLLMGateway(responses=[expected_context])
        chunker = ContextualChunker(llm_gateway=gateway)
        document = "Acme Corp Q2 2023 Report\n\nRevenue grew 3%.\n\nProfit increased 5%."

        enhanced = await chunker.enhance_chunks(
            document, ["Revenue grew 3%.", "Profit increased 5%."]
        )

        assert len(enhanced) == 2
        first = enhanced[0]
        assert first.original_content == "Revenue grew 3%."
        assert first.context_prefix == expected_context
        assert expected_context in first.enhanced_content
        assert "Revenue grew 3%." in first.enhanced_content
        assert first.chunk_index == 0
        assert first.metadata["has_context"] is True

    @pytest.mark.asyncio
    async def test_enhance_chunks_without_llm(self):
        """With no gateway the chunks come back without any context."""
        chunker = ContextualChunker(llm_gateway=None)

        enhanced = await chunker.enhance_chunks("Test document", ["Chunk 1", "Chunk 2"])

        assert len(enhanced) == 2
        first = enhanced[0]
        assert first.context_prefix == ""
        assert first.enhanced_content == "Chunk 1"
        assert first.metadata.get("has_context") is not True

    @pytest.mark.asyncio
    async def test_enhance_empty_chunks(self):
        """An empty chunk list yields an empty result."""
        chunker = ContextualChunker(llm_gateway=MockLLMGateway())
        assert await chunker.enhance_chunks("document", []) == []

    @pytest.mark.asyncio
    async def test_context_caching(self):
        """A repeated document+chunk pair is served from the cache."""
        gateway = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=gateway)
        document, chunks = "Test document", ["Chunk 1"]

        # First pass reaches the gateway.
        initial = await chunker.enhance_chunks(document, chunks)
        assert initial[0].context_prefix == "Context A"
        assert gateway._call_count == 1

        # Second pass with identical input must not reach the gateway again.
        repeated = await chunker.enhance_chunks(document, chunks)
        assert repeated[0].context_prefix == "Context A"
        assert gateway._call_count == 1

    @pytest.mark.asyncio
    async def test_context_truncation(self):
        """An over-long context is cut down to max_context_length."""
        gateway = MockLLMGateway(responses=["A" * 500])
        chunker = ContextualChunker(llm_gateway=gateway, max_context_length=100)

        enhanced = await chunker.enhance_chunks("doc", ["chunk"])

        assert len(enhanced[0].context_prefix) <= 100

    @pytest.mark.asyncio
    async def test_llm_failure_returns_empty_context(self):
        """A failing gateway produces an empty context instead of raising."""

        class FailingLLM:
            async def chat(self, messages, model="default", **kwargs):
                raise RuntimeError("LLM unavailable")

        chunker = ContextualChunker(llm_gateway=FailingLLM())

        enhanced = await chunker.enhance_chunks("doc", ["chunk"])

        assert len(enhanced) == 1
        only = enhanced[0]
        assert only.context_prefix == ""
        assert only.enhanced_content == "chunk"

    @pytest.mark.asyncio
    async def test_batch_processing(self):
        """Chunk indices stay sequential across batch boundaries."""
        gateway = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=gateway, batch_size=3)
        chunks = [f"Chunk {i}" for i in range(7)]

        enhanced = await chunker.enhance_chunks("doc", chunks)

        assert len(enhanced) == 7
        assert [item.chunk_index for item in enhanced] == list(range(7))

    @pytest.mark.asyncio
    async def test_metadata_preserved(self):
        """Caller metadata is kept and extended with chunk-level keys."""
        gateway = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=gateway)
        caller_metadata = {"source": "test", "doc_id": "123"}

        enhanced = await chunker.enhance_chunks(
            "doc", ["chunk"], metadata=caller_metadata
        )

        chunk_metadata = enhanced[0].metadata
        assert chunk_metadata["source"] == "test"
        assert chunk_metadata["doc_id"] == "123"
        assert chunk_metadata["chunk_index"] == 0
        assert "context_prefix" in chunk_metadata

    @pytest.mark.asyncio
    async def test_clear_cache(self):
        """After clear_cache the same input reaches the gateway again."""
        gateway = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=gateway)

        await chunker.enhance_chunks("doc", ["chunk"])
        assert gateway._call_count == 1

        chunker.clear_cache()

        await chunker.enhance_chunks("doc", ["chunk"])
        assert gateway._call_count == 2  # cache was emptied, so a new call is made

    def test_prompt_template_format(self):
        """The default template accepts ``document`` and ``chunk`` fields."""
        rendered = CONTEXT_PROMPT_TEMPLATE.format(
            document="Test document", chunk="Test chunk"
        )
        for fragment in ("Test document", "Test chunk", "Context:"):
            assert fragment in rendered
|
||||||
Loading…
Reference in New Issue