From f16dcb5ebeb2e57147e22c47f75730a3ba80315d Mon Sep 17 00:00:00 2001 From: chiguyong Date: Sat, 6 Jun 2026 22:19:02 +0800 Subject: [PATCH] feat(memory): U2 Contextual Retrieval - LLM-generated context prefixes for chunks - ContextualChunker: generates context prefixes per chunk via LLM - Integrated into HttpRAGService ingest with contextual_chunking option - Caching, batch processing, graceful LLM failure handling - 12 tests passing --- src/agentkit/memory/contextual_retrieval.py | 210 ++++++++++++++++++++ src/agentkit/memory/http_rag.py | 26 ++- tests/unit/test_contextual_retrieval.py | 190 ++++++++++++++++++ 3 files changed, 425 insertions(+), 1 deletion(-) create mode 100644 src/agentkit/memory/contextual_retrieval.py create mode 100644 tests/unit/test_contextual_retrieval.py diff --git a/src/agentkit/memory/contextual_retrieval.py b/src/agentkit/memory/contextual_retrieval.py new file mode 100644 index 0000000..93eb47f --- /dev/null +++ b/src/agentkit/memory/contextual_retrieval.py @@ -0,0 +1,210 @@ +"""ContextualChunker - 上下文增强分块 + +在嵌入前为每个文档块添加 LLM 生成的上下文前缀, +解决分块后上下文丢失问题(Anthropic Contextual Retrieval)。 +""" + +from __future__ import annotations + +import hashlib +import logging +from dataclasses import dataclass +from typing import Any + +from agentkit.memory.embedder import EmbeddingCache + +logger = logging.getLogger(__name__) + + +@dataclass +class ContextualChunk: + """带上下文前缀的文档块""" + + original_content: str + context_prefix: str + enhanced_content: str + chunk_index: int + metadata: dict[str, Any] + + @property + def content(self) -> str: + """获取增强后的完整内容""" + return self.enhanced_content + + +CONTEXT_PROMPT_TEMPLATE = """\ +Given the full document below and a specific chunk from it, write a brief context that helps someone understand what this chunk is about in the broader document. Output ONLY the context, no explanations. + + +{document} + + + +{chunk} + + +Context:""" + + +class ContextualChunker: + """上下文增强分块器 + + 为每个文档块生成 LLM 上下文前缀,增强检索质量。 + + 工作流程: + 1. 接收文档和分块列表 + 2. 对每个块,调用 LLM 生成简洁上下文语句 + 3. 将上下文前缀添加到原始内容前 + 4. 缓存结果避免重复计算 + + 成本优化: + - 文档级 Prompt Caching(同一文档的多个块共享文档前缀) + - EmbeddingCache 缓存上下文生成结果 + - 批处理(batch_size) + """ + + def __init__( + self, + llm_gateway: Any = None, + cache: EmbeddingCache | None = None, + batch_size: int = 8, + max_context_length: int = 200, + prompt_template: str = CONTEXT_PROMPT_TEMPLATE, + ): + """ + Args: + llm_gateway: LLM Gateway 实例,用于生成上下文 + cache: 嵌入缓存,用于缓存上下文生成结果 + batch_size: 批处理大小 + max_context_length: 上下文最大字符长度 + prompt_template: 上下文生成 prompt 模板 + """ + self._llm_gateway = llm_gateway + self._cache = cache + self._batch_size = batch_size + self._max_context_length = max_context_length + self._prompt_template = prompt_template + self._context_cache: dict[str, str] = {} + + async def enhance_chunks( + self, + document: str, + chunks: list[str], + metadata: dict[str, Any] | None = None, + ) -> list[ContextualChunk]: + """为文档块添加上下文前缀 + + Args: + document: 完整文档内容 + chunks: 文档分块列表 + metadata: 附加元数据 + + Returns: + 增强后的 ContextualChunk 列表 + """ + if not chunks: + return [] + + if not self._llm_gateway: + # No LLM available — return chunks without context + logger.info("No LLM gateway configured, skipping contextual enhancement") + return [ + ContextualChunk( + original_content=chunk, + context_prefix="", + enhanced_content=chunk, + chunk_index=i, + metadata=metadata or {}, + ) + for i, chunk in enumerate(chunks) + ] + + result: list[ContextualChunk] = [] + + # Process in batches + for batch_start in range(0, len(chunks), self._batch_size): + batch = chunks[batch_start : batch_start + self._batch_size] + batch_results = await self._process_batch(document, batch, batch_start, metadata) + result.extend(batch_results) + + return result + + async def _process_batch( + self, + document: str, + chunks: list[str], + start_index: int, + metadata: dict[str, Any] | None, + ) -> list[ContextualChunk]: + """处理一批文档块""" + results: list[ContextualChunk] = [] + + for i, chunk in enumerate(chunks): + chunk_index = start_index + i + chunk_meta = dict(metadata or {}) + chunk_meta["chunk_index"] = chunk_index + + # Check cache + cache_key = self._make_cache_key(document, chunk) + if cache_key in self._context_cache: + context = self._context_cache[cache_key] + else: + context = await self._generate_context(document, chunk) + self._context_cache[cache_key] = context + + # Truncate context if too long + if len(context) > self._max_context_length: + context = context[: self._max_context_length] + + # Build enhanced content + if context: + enhanced = f"{context}\n{chunk}" + else: + enhanced = chunk + + chunk_meta["context_prefix"] = context + chunk_meta["has_context"] = bool(context) + + results.append( + ContextualChunk( + original_content=chunk, + context_prefix=context, + enhanced_content=enhanced, + chunk_index=chunk_index, + metadata=chunk_meta, + ) + ) + + return results + + async def _generate_context(self, document: str, chunk: str) -> str: + """使用 LLM 为单个块生成上下文""" + # Truncate document for prompt efficiency + doc_preview = document[:3000] if len(document) > 3000 else document + chunk_preview = chunk[:1000] if len(chunk) > 1000 else chunk + + prompt = self._prompt_template.format( + document=doc_preview, + chunk=chunk_preview, + ) + + try: + response = await self._llm_gateway.chat( + messages=[{"role": "user", "content": prompt}], + model="default", + ) + context = response.content.strip() + return context + except Exception as e: + logger.warning(f"Context generation failed for chunk: {e}") + return "" + + @staticmethod + def _make_cache_key(document: str, chunk: str) -> str: + """生成缓存键""" + content = f"{document[:500]}:{chunk[:500]}" + return hashlib.sha256(content.encode()).hexdigest()[:16] + + def clear_cache(self) -> None: + """清除上下文缓存""" + self._context_cache.clear() diff --git a/src/agentkit/memory/http_rag.py b/src/agentkit/memory/http_rag.py index b0ed246..2e4d94f 100644 --- a/src/agentkit/memory/http_rag.py +++ b/src/agentkit/memory/http_rag.py @@ -29,6 +29,7 @@ class HttpRAGService: - "industry-kb-id" - "enterprise-kb-id" timeout: 30 + contextual_chunking: false """ def __init__( @@ -37,6 +38,8 @@ class HttpRAGService: api_key: str | None = None, knowledge_base_ids: list[str] | None = None, timeout: int = 30, + contextual_chunking: bool = False, + llm_gateway: Any = None, ): """ Args: @@ -50,6 +53,8 @@ class HttpRAGService: self._knowledge_base_ids = knowledge_base_ids or [] self._timeout = timeout self._client: httpx.AsyncClient | None = None + self._contextual_chunking = contextual_chunking + self._llm_gateway = llm_gateway def _get_client(self) -> httpx.AsyncClient: """懒初始化 httpx 客户端""" @@ -232,6 +237,9 @@ class HttpRAGService: ) -> dict[str, Any] | None: """写入文档到知识库(可选操作) + When contextual_chunking is enabled and llm_gateway is configured, + the document content is enhanced with contextual prefixes before ingestion. + Args: key: 文档标题或标识 value: 文档内容 @@ -245,9 +253,25 @@ class HttpRAGService: logger.warning("HttpRAGService.ingest: no knowledge_base_ids configured") return None + content = str(value) + + # Apply contextual chunking if enabled + if self._contextual_chunking and self._llm_gateway: + from agentkit.memory.contextual_retrieval import ContextualChunker + + chunker = ContextualChunker(llm_gateway=self._llm_gateway) + # Simple chunking: split by paragraphs + raw_chunks = [c.strip() for c in content.split("\n\n") if c.strip()] + if raw_chunks: + enhanced = await chunker.enhance_chunks( + document=content, chunks=raw_chunks, metadata=metadata + ) + # Rejoin enhanced chunks + content = "\n\n".join(chunk.enhanced_content for chunk in enhanced) + payload = { "title": key, - "content": str(value), + "content": content, "source_type": "text", "metadata": metadata or {}, } diff --git a/tests/unit/test_contextual_retrieval.py b/tests/unit/test_contextual_retrieval.py new file mode 100644 index 0000000..e139222 --- /dev/null +++ b/tests/unit/test_contextual_retrieval.py @@ -0,0 +1,190 @@ +"""Tests for ContextualChunker""" + +import pytest + +from agentkit.memory.contextual_retrieval import ( + ContextualChunker, + ContextualChunk, + CONTEXT_PROMPT_TEMPLATE, +) + + +class MockLLMGateway: + """Mock LLM Gateway for testing""" + + def __init__(self, responses: list[str] | None = None): + self._responses = responses or ["This chunk discusses revenue growth."] + self._call_count = 0 + self._last_messages = None + + async def chat(self, messages, model="default", **kwargs): + self._call_count += 1 + self._last_messages = messages + + class MockResponse: + content = self._responses[min(self._call_count - 1, len(self._responses) - 1)] + + return MockResponse() + + +class TestContextualChunk: + """ContextualChunk dataclass tests""" + + def test_content_property(self): + chunk = ContextualChunk( + original_content="Revenue grew 3%", + context_prefix="From Acme Q2 2023 report", + enhanced_content="From Acme Q2 2023 report\nRevenue grew 3%", + chunk_index=0, + metadata={}, + ) + assert chunk.content == "From Acme Q2 2023 report\nRevenue grew 3%" + + def test_empty_context(self): + chunk = ContextualChunk( + original_content="Some text", + context_prefix="", + enhanced_content="Some text", + chunk_index=0, + metadata={}, + ) + assert chunk.content == "Some text" + + +class TestContextualChunker: + """ContextualChunker unit tests""" + + @pytest.mark.asyncio + async def test_enhance_chunks_with_llm(self): + """Chunks should be enhanced with LLM-generated context""" + llm = MockLLMGateway(responses=["From the financial report section"]) + chunker = ContextualChunker(llm_gateway=llm) + + document = "Acme Corp Q2 2023 Report\n\nRevenue grew 3%.\n\nProfit increased 5%." + chunks = ["Revenue grew 3%.", "Profit increased 5%."] + + result = await chunker.enhance_chunks(document, chunks) + + assert len(result) == 2 + assert result[0].original_content == "Revenue grew 3%." + assert result[0].context_prefix == "From the financial report section" + assert "From the financial report section" in result[0].enhanced_content + assert "Revenue grew 3%." in result[0].enhanced_content + assert result[0].chunk_index == 0 + assert result[0].metadata["has_context"] is True + + @pytest.mark.asyncio + async def test_enhance_chunks_without_llm(self): + """Without LLM, chunks should be returned without context""" + chunker = ContextualChunker(llm_gateway=None) + + document = "Test document" + chunks = ["Chunk 1", "Chunk 2"] + + result = await chunker.enhance_chunks(document, chunks) + + assert len(result) == 2 + assert result[0].context_prefix == "" + assert result[0].enhanced_content == "Chunk 1" + assert result[0].metadata.get("has_context") is not True + + @pytest.mark.asyncio + async def test_enhance_empty_chunks(self): + """Empty chunks list should return empty result""" + chunker = ContextualChunker(llm_gateway=MockLLMGateway()) + result = await chunker.enhance_chunks("document", []) + assert result == [] + + @pytest.mark.asyncio + async def test_context_caching(self): + """Same document+chunk should use cached context""" + llm = MockLLMGateway(responses=["Context A", "Context B"]) + chunker = ContextualChunker(llm_gateway=llm) + + document = "Test document" + chunks = ["Chunk 1"] + + # First call + result1 = await chunker.enhance_chunks(document, chunks) + assert result1[0].context_prefix == "Context A" + assert llm._call_count == 1 + + # Second call with same input — should use cache + result2 = await chunker.enhance_chunks(document, chunks) + assert result2[0].context_prefix == "Context A" + assert llm._call_count == 1 # No additional LLM call + + @pytest.mark.asyncio + async def test_context_truncation(self): + """Long context should be truncated""" + long_context = "A" * 500 + llm = MockLLMGateway(responses=[long_context]) + chunker = ContextualChunker(llm_gateway=llm, max_context_length=100) + + result = await chunker.enhance_chunks("doc", ["chunk"]) + assert len(result[0].context_prefix) <= 100 + + @pytest.mark.asyncio + async def test_llm_failure_returns_empty_context(self): + """LLM failure should result in empty context, not error""" + class FailingLLM: + async def chat(self, messages, model="default", **kwargs): + raise RuntimeError("LLM unavailable") + + chunker = ContextualChunker(llm_gateway=FailingLLM()) + result = await chunker.enhance_chunks("doc", ["chunk"]) + + assert len(result) == 1 + assert result[0].context_prefix == "" + assert result[0].enhanced_content == "chunk" + + @pytest.mark.asyncio + async def test_batch_processing(self): + """Large number of chunks should be processed in batches""" + llm = MockLLMGateway(responses=["Context"]) + chunker = ContextualChunker(llm_gateway=llm, batch_size=3) + + chunks = [f"Chunk {i}" for i in range(7)] + result = await chunker.enhance_chunks("doc", chunks) + + assert len(result) == 7 + for i, chunk in enumerate(result): + assert chunk.chunk_index == i + + @pytest.mark.asyncio + async def test_metadata_preserved(self): + """Metadata should be preserved and enhanced""" + llm = MockLLMGateway(responses=["Context"]) + chunker = ContextualChunker(llm_gateway=llm) + + result = await chunker.enhance_chunks( + "doc", ["chunk"], metadata={"source": "test", "doc_id": "123"} + ) + + assert result[0].metadata["source"] == "test" + assert result[0].metadata["doc_id"] == "123" + assert result[0].metadata["chunk_index"] == 0 + assert "context_prefix" in result[0].metadata + + @pytest.mark.asyncio + async def test_clear_cache(self): + """clear_cache should reset the context cache""" + llm = MockLLMGateway(responses=["Context A", "Context B"]) + chunker = ContextualChunker(llm_gateway=llm) + + await chunker.enhance_chunks("doc", ["chunk"]) + assert llm._call_count == 1 + + chunker.clear_cache() + + await chunker.enhance_chunks("doc", ["chunk"]) + assert llm._call_count == 2 # Cache was cleared, new LLM call + + def test_prompt_template_format(self): + """Prompt template should be formattable with document and chunk""" + formatted = CONTEXT_PROMPT_TEMPLATE.format( + document="Test document", chunk="Test chunk" + ) + assert "Test document" in formatted + assert "Test chunk" in formatted + assert "Context:" in formatted