"""Tests for ContextualChunker""" import pytest from agentkit.memory.contextual_retrieval import ( ContextualChunker, ContextualChunk, CONTEXT_PROMPT_TEMPLATE, ) class MockLLMGateway: """Mock LLM Gateway for testing""" def __init__(self, responses: list[str] | None = None): self._responses = responses or ["This chunk discusses revenue growth."] self._call_count = 0 self._last_messages = None async def chat(self, messages, model="default", **kwargs): self._call_count += 1 self._last_messages = messages class MockResponse: content = self._responses[min(self._call_count - 1, len(self._responses) - 1)] return MockResponse() class TestContextualChunk: """ContextualChunk dataclass tests""" def test_content_property(self): chunk = ContextualChunk( original_content="Revenue grew 3%", context_prefix="From Acme Q2 2023 report", enhanced_content="From Acme Q2 2023 report\nRevenue grew 3%", chunk_index=0, metadata={}, ) assert chunk.content == "From Acme Q2 2023 report\nRevenue grew 3%" def test_empty_context(self): chunk = ContextualChunk( original_content="Some text", context_prefix="", enhanced_content="Some text", chunk_index=0, metadata={}, ) assert chunk.content == "Some text" class TestContextualChunker: """ContextualChunker unit tests""" @pytest.mark.asyncio async def test_enhance_chunks_with_llm(self): """Chunks should be enhanced with LLM-generated context""" llm = MockLLMGateway(responses=["From the financial report section"]) chunker = ContextualChunker(llm_gateway=llm) document = "Acme Corp Q2 2023 Report\n\nRevenue grew 3%.\n\nProfit increased 5%." chunks = ["Revenue grew 3%.", "Profit increased 5%."] result = await chunker.enhance_chunks(document, chunks) assert len(result) == 2 assert result[0].original_content == "Revenue grew 3%." assert result[0].context_prefix == "From the financial report section" assert "From the financial report section" in result[0].enhanced_content assert "Revenue grew 3%." in result[0].enhanced_content assert result[0].chunk_index == 0 assert result[0].metadata["has_context"] is True @pytest.mark.asyncio async def test_enhance_chunks_without_llm(self): """Without LLM, chunks should be returned without context""" chunker = ContextualChunker(llm_gateway=None) document = "Test document" chunks = ["Chunk 1", "Chunk 2"] result = await chunker.enhance_chunks(document, chunks) assert len(result) == 2 assert result[0].context_prefix == "" assert result[0].enhanced_content == "Chunk 1" assert result[0].metadata.get("has_context") is not True @pytest.mark.asyncio async def test_enhance_empty_chunks(self): """Empty chunks list should return empty result""" chunker = ContextualChunker(llm_gateway=MockLLMGateway()) result = await chunker.enhance_chunks("document", []) assert result == [] @pytest.mark.asyncio async def test_context_caching(self): """Same document+chunk should use cached context""" llm = MockLLMGateway(responses=["Context A", "Context B"]) chunker = ContextualChunker(llm_gateway=llm) document = "Test document" chunks = ["Chunk 1"] # First call result1 = await chunker.enhance_chunks(document, chunks) assert result1[0].context_prefix == "Context A" assert llm._call_count == 1 # Second call with same input — should use cache result2 = await chunker.enhance_chunks(document, chunks) assert result2[0].context_prefix == "Context A" assert llm._call_count == 1 # No additional LLM call @pytest.mark.asyncio async def test_context_truncation(self): """Long context should be truncated""" long_context = "A" * 500 llm = MockLLMGateway(responses=[long_context]) chunker = ContextualChunker(llm_gateway=llm, max_context_length=100) result = await chunker.enhance_chunks("doc", ["chunk"]) assert len(result[0].context_prefix) <= 100 @pytest.mark.asyncio async def test_llm_failure_returns_empty_context(self): """LLM failure should result in empty context, not error""" class FailingLLM: async def chat(self, messages, model="default", **kwargs): raise RuntimeError("LLM unavailable") chunker = ContextualChunker(llm_gateway=FailingLLM()) result = await chunker.enhance_chunks("doc", ["chunk"]) assert len(result) == 1 assert result[0].context_prefix == "" assert result[0].enhanced_content == "chunk" @pytest.mark.asyncio async def test_batch_processing(self): """Large number of chunks should be processed in batches""" llm = MockLLMGateway(responses=["Context"]) chunker = ContextualChunker(llm_gateway=llm, batch_size=3) chunks = [f"Chunk {i}" for i in range(7)] result = await chunker.enhance_chunks("doc", chunks) assert len(result) == 7 for i, chunk in enumerate(result): assert chunk.chunk_index == i @pytest.mark.asyncio async def test_metadata_preserved(self): """Metadata should be preserved and enhanced""" llm = MockLLMGateway(responses=["Context"]) chunker = ContextualChunker(llm_gateway=llm) result = await chunker.enhance_chunks( "doc", ["chunk"], metadata={"source": "test", "doc_id": "123"} ) assert result[0].metadata["source"] == "test" assert result[0].metadata["doc_id"] == "123" assert result[0].metadata["chunk_index"] == 0 assert "context_prefix" in result[0].metadata @pytest.mark.asyncio async def test_clear_cache(self): """clear_cache should reset the context cache""" llm = MockLLMGateway(responses=["Context A", "Context B"]) chunker = ContextualChunker(llm_gateway=llm) await chunker.enhance_chunks("doc", ["chunk"]) assert llm._call_count == 1 chunker.clear_cache() await chunker.enhance_chunks("doc", ["chunk"]) assert llm._call_count == 2 # Cache was cleared, new LLM call def test_prompt_template_format(self): """Prompt template should be formattable with document and chunk""" formatted = CONTEXT_PROMPT_TEMPLATE.format( document="Test document", chunk="Test chunk" ) assert "Test document" in formatted assert "Test chunk" in formatted assert "Context:" in formatted