# fischer-agentkit/tests/unit/test_contextual_retrieval.py
# (source listing metadata: 191 lines, 6.7 KiB, Python)

"""Tests for ContextualChunker"""
import pytest
from agentkit.memory.contextual_retrieval import (
ContextualChunker,
ContextualChunk,
CONTEXT_PROMPT_TEMPLATE,
)
class MockLLMGateway:
"""Mock LLM Gateway for testing"""
def __init__(self, responses: list[str] | None = None):
self._responses = responses or ["This chunk discusses revenue growth."]
self._call_count = 0
self._last_messages = None
async def chat(self, messages, model="default", **kwargs):
self._call_count += 1
self._last_messages = messages
class MockResponse:
content = self._responses[min(self._call_count - 1, len(self._responses) - 1)]
return MockResponse()
class TestContextualChunk:
    """ContextualChunk dataclass tests"""

    @staticmethod
    def _build(original: str, prefix: str, enhanced: str) -> ContextualChunk:
        """Construct a ContextualChunk at index 0 with a fresh, empty metadata dict."""
        return ContextualChunk(
            original_content=original,
            context_prefix=prefix,
            enhanced_content=enhanced,
            chunk_index=0,
            metadata={},
        )

    def test_content_property(self):
        """``content`` should return the enhanced (context + original) text."""
        expected = "From Acme Q2 2023 report\nRevenue grew 3%"
        chunk = self._build("Revenue grew 3%", "From Acme Q2 2023 report", expected)
        assert chunk.content == expected

    def test_empty_context(self):
        """With an empty prefix, ``content`` matches the supplied enhanced text."""
        plain = "Some text"
        chunk = self._build(plain, "", plain)
        assert chunk.content == plain
class TestContextualChunker:
    """ContextualChunker unit tests.

    Covers LLM-backed enhancement, the no-LLM fallback, caching, truncation,
    failure handling, batching, metadata propagation and the prompt template.
    """

    @pytest.mark.asyncio
    async def test_enhance_chunks_with_llm(self):
        """Chunks should be enhanced with LLM-generated context"""
        llm = MockLLMGateway(responses=["From the financial report section"])
        chunker = ContextualChunker(llm_gateway=llm)
        document = "Acme Corp Q2 2023 Report\n\nRevenue grew 3%.\n\nProfit increased 5%."
        chunks = ["Revenue grew 3%.", "Profit increased 5%."]

        result = await chunker.enhance_chunks(document, chunks)

        assert len(result) == 2
        assert result[0].original_content == "Revenue grew 3%."
        assert result[0].context_prefix == "From the financial report section"
        assert "From the financial report section" in result[0].enhanced_content
        assert "Revenue grew 3%." in result[0].enhanced_content
        assert result[0].chunk_index == 0
        assert result[0].metadata["has_context"] is True
        # Every chunk is enhanced, not just the first; the mock repeats its
        # single reply, so the second chunk gets the same prefix.
        assert result[1].original_content == "Profit increased 5%."
        assert result[1].context_prefix == "From the financial report section"
        assert "Profit increased 5%." in result[1].enhanced_content
        assert result[1].chunk_index == 1
        assert result[1].metadata["has_context"] is True

    @pytest.mark.asyncio
    async def test_enhance_chunks_without_llm(self):
        """Without LLM, chunks should be returned without context"""
        chunker = ContextualChunker(llm_gateway=None)
        document = "Test document"
        chunks = ["Chunk 1", "Chunk 2"]

        result = await chunker.enhance_chunks(document, chunks)

        assert len(result) == 2
        for index, (chunk, text) in enumerate(zip(result, chunks)):
            assert chunk.context_prefix == ""
            assert chunk.enhanced_content == text
            assert chunk.chunk_index == index
            assert chunk.metadata.get("has_context") is not True

    @pytest.mark.asyncio
    async def test_enhance_empty_chunks(self):
        """Empty chunks list should return empty result"""
        llm = MockLLMGateway()
        chunker = ContextualChunker(llm_gateway=llm)

        result = await chunker.enhance_chunks("document", [])

        assert result == []
        # Nothing to enhance, so the gateway must not be called at all.
        assert llm._call_count == 0

    @pytest.mark.asyncio
    async def test_context_caching(self):
        """Same document+chunk should use cached context"""
        llm = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=llm)
        document = "Test document"
        chunks = ["Chunk 1"]

        # First call
        result1 = await chunker.enhance_chunks(document, chunks)
        assert result1[0].context_prefix == "Context A"
        assert llm._call_count == 1

        # Second call with same input — should use cache. A cache miss would
        # surface as "Context B" from the mock's second reply.
        result2 = await chunker.enhance_chunks(document, chunks)
        assert result2[0].context_prefix == "Context A"
        assert llm._call_count == 1  # No additional LLM call

    @pytest.mark.asyncio
    async def test_context_truncation(self):
        """Long context should be truncated"""
        long_context = "A" * 500
        llm = MockLLMGateway(responses=[long_context])
        chunker = ContextualChunker(llm_gateway=llm, max_context_length=100)

        result = await chunker.enhance_chunks("doc", ["chunk"])

        prefix = result[0].context_prefix
        assert len(prefix) <= 100
        # Truncation must shorten the LLM output, not discard it; an empty
        # prefix would otherwise satisfy the length bound vacuously.
        assert prefix
        assert prefix.startswith("A")

    @pytest.mark.asyncio
    async def test_llm_failure_returns_empty_context(self):
        """LLM failure should result in empty context, not error"""

        class FailingLLM:
            async def chat(self, messages, model="default", **kwargs):
                raise RuntimeError("LLM unavailable")

        chunker = ContextualChunker(llm_gateway=FailingLLM())

        result = await chunker.enhance_chunks("doc", ["chunk"])

        assert len(result) == 1
        assert result[0].context_prefix == ""
        assert result[0].enhanced_content == "chunk"
        # Mirrors the no-LLM case: no context was produced, so none is claimed.
        assert result[0].metadata.get("has_context") is not True

    @pytest.mark.asyncio
    async def test_batch_processing(self):
        """Large number of chunks should be processed in batches"""
        llm = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=llm, batch_size=3)
        chunks = [f"Chunk {i}" for i in range(7)]

        result = await chunker.enhance_chunks("doc", chunks)

        assert len(result) == 7
        # 7 chunks with batch_size=3 span a partial final batch. Order and
        # content must survive batching, and every chunk must get context.
        for i, chunk in enumerate(result):
            assert chunk.chunk_index == i
            assert chunk.original_content == f"Chunk {i}"
            assert chunk.context_prefix == "Context"
        # Chunks are distinct, so none can be served from the cache.
        assert llm._call_count == 7

    @pytest.mark.asyncio
    async def test_metadata_preserved(self):
        """Metadata should be preserved and enhanced"""
        llm = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=llm)

        result = await chunker.enhance_chunks(
            "doc", ["chunk"], metadata={"source": "test", "doc_id": "123"}
        )

        assert result[0].metadata["source"] == "test"
        assert result[0].metadata["doc_id"] == "123"
        assert result[0].metadata["chunk_index"] == 0
        assert "context_prefix" in result[0].metadata

    @pytest.mark.asyncio
    async def test_clear_cache(self):
        """clear_cache should reset the context cache"""
        llm = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=llm)

        before = await chunker.enhance_chunks("doc", ["chunk"])
        assert llm._call_count == 1
        assert before[0].context_prefix == "Context A"

        chunker.clear_cache()

        after = await chunker.enhance_chunks("doc", ["chunk"])
        assert llm._call_count == 2  # Cache was cleared, new LLM call
        # The fresh call's reply is used, not a stale cached value.
        assert after[0].context_prefix == "Context B"

    def test_prompt_template_format(self):
        """Prompt template should be formattable with document and chunk"""
        formatted = CONTEXT_PROMPT_TEMPLATE.format(
            document="Test document", chunk="Test chunk"
        )
        assert "Test document" in formatted
        assert "Test chunk" in formatted
        assert "Context:" in formatted