191 lines
6.7 KiB
Python
191 lines
6.7 KiB
Python
"""Tests for ContextualChunker"""
|
|
|
|
import pytest
|
|
|
|
from agentkit.memory.contextual_retrieval import (
|
|
ContextualChunker,
|
|
ContextualChunk,
|
|
CONTEXT_PROMPT_TEMPLATE,
|
|
)
|
|
|
|
|
|
class MockLLMGateway:
|
|
"""Mock LLM Gateway for testing"""
|
|
|
|
def __init__(self, responses: list[str] | None = None):
|
|
self._responses = responses or ["This chunk discusses revenue growth."]
|
|
self._call_count = 0
|
|
self._last_messages = None
|
|
|
|
async def chat(self, messages, model="default", **kwargs):
|
|
self._call_count += 1
|
|
self._last_messages = messages
|
|
|
|
class MockResponse:
|
|
content = self._responses[min(self._call_count - 1, len(self._responses) - 1)]
|
|
|
|
return MockResponse()
|
|
|
|
|
|
class TestContextualChunk:
    """Checks on the ``content`` attribute of ``ContextualChunk``."""

    def test_content_property(self):
        """``content`` matches the text supplied as ``enhanced_content``."""
        combined_text = "From Acme Q2 2023 report\nRevenue grew 3%"
        fields = {
            "original_content": "Revenue grew 3%",
            "context_prefix": "From Acme Q2 2023 report",
            "enhanced_content": combined_text,
            "chunk_index": 0,
            "metadata": {},
        }
        chunk = ContextualChunk(**fields)
        assert chunk.content == "From Acme Q2 2023 report\nRevenue grew 3%"

    def test_empty_context(self):
        """With an empty prefix, ``content`` is just the plain chunk text."""
        fields = {
            "original_content": "Some text",
            "context_prefix": "",
            "enhanced_content": "Some text",
            "chunk_index": 0,
            "metadata": {},
        }
        chunk = ContextualChunk(**fields)
        assert chunk.content == "Some text"
|
|
|
|
|
|
class TestContextualChunker:
    """Unit tests exercising ``ContextualChunker`` against mock gateways."""

    @pytest.mark.asyncio
    async def test_enhance_chunks_with_llm(self):
        """The gateway's reply becomes the prefix of each enhanced chunk."""
        document = "Acme Corp Q2 2023 Report\n\nRevenue grew 3%.\n\nProfit increased 5%."
        chunks = ["Revenue grew 3%.", "Profit increased 5%."]
        gateway = MockLLMGateway(responses=["From the financial report section"])
        chunker = ContextualChunker(llm_gateway=gateway)

        enhanced = await chunker.enhance_chunks(document, chunks)

        assert len(enhanced) == 2
        first = enhanced[0]
        assert first.original_content == "Revenue grew 3%."
        assert first.context_prefix == "From the financial report section"
        assert "From the financial report section" in first.enhanced_content
        assert "Revenue grew 3%." in first.enhanced_content
        assert first.chunk_index == 0
        assert first.metadata["has_context"] is True

    @pytest.mark.asyncio
    async def test_enhance_chunks_without_llm(self):
        """With no gateway configured, chunks come back without a prefix."""
        chunker = ContextualChunker(llm_gateway=None)

        enhanced = await chunker.enhance_chunks("Test document", ["Chunk 1", "Chunk 2"])

        assert len(enhanced) == 2
        first = enhanced[0]
        assert first.context_prefix == ""
        assert first.enhanced_content == "Chunk 1"
        assert first.metadata.get("has_context") is not True

    @pytest.mark.asyncio
    async def test_enhance_empty_chunks(self):
        """An empty chunk list yields an empty result list."""
        gateway = MockLLMGateway()
        chunker = ContextualChunker(llm_gateway=gateway)
        enhanced = await chunker.enhance_chunks("document", [])
        assert enhanced == []

    @pytest.mark.asyncio
    async def test_context_caching(self):
        """Repeating the same document and chunk must not hit the gateway again."""
        gateway = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=gateway)
        document, chunks = "Test document", ["Chunk 1"]

        # Initial request: the gateway is consulted exactly once.
        initial = await chunker.enhance_chunks(document, chunks)
        assert initial[0].context_prefix == "Context A"
        assert gateway._call_count == 1

        # Identical request: the earlier context is reused, so the call
        # counter stays put and "Context B" is never served.
        repeated = await chunker.enhance_chunks(document, chunks)
        assert repeated[0].context_prefix == "Context A"
        assert gateway._call_count == 1

    @pytest.mark.asyncio
    async def test_context_truncation(self):
        """A reply longer than ``max_context_length`` is cut down to fit."""
        oversized_reply = "A" * 500
        gateway = MockLLMGateway(responses=[oversized_reply])
        chunker = ContextualChunker(llm_gateway=gateway, max_context_length=100)

        enhanced = await chunker.enhance_chunks("doc", ["chunk"])
        prefix = enhanced[0].context_prefix
        assert len(prefix) <= 100

    @pytest.mark.asyncio
    async def test_llm_failure_returns_empty_context(self):
        """A raising gateway produces an empty prefix rather than an error."""

        class FailingLLM:
            async def chat(self, messages, model="default", **kwargs):
                raise RuntimeError("LLM unavailable")

        broken_gateway = FailingLLM()
        chunker = ContextualChunker(llm_gateway=broken_gateway)
        enhanced = await chunker.enhance_chunks("doc", ["chunk"])

        assert len(enhanced) == 1
        only = enhanced[0]
        assert only.context_prefix == ""
        assert only.enhanced_content == "chunk"

    @pytest.mark.asyncio
    async def test_batch_processing(self):
        """More chunks than ``batch_size`` still come back complete and in order."""
        gateway = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=gateway, batch_size=3)

        chunks = [f"Chunk {i}" for i in range(7)]
        enhanced = await chunker.enhance_chunks("doc", chunks)

        assert len(enhanced) == 7
        for position, item in enumerate(enhanced):
            assert item.chunk_index == position

    @pytest.mark.asyncio
    async def test_metadata_preserved(self):
        """Caller metadata is kept and chunk-level keys are added to it."""
        gateway = MockLLMGateway(responses=["Context"])
        chunker = ContextualChunker(llm_gateway=gateway)
        caller_metadata = {"source": "test", "doc_id": "123"}

        enhanced = await chunker.enhance_chunks(
            "doc", ["chunk"], metadata=caller_metadata
        )

        chunk_metadata = enhanced[0].metadata
        assert chunk_metadata["source"] == "test"
        assert chunk_metadata["doc_id"] == "123"
        assert chunk_metadata["chunk_index"] == 0
        assert "context_prefix" in chunk_metadata

    @pytest.mark.asyncio
    async def test_clear_cache(self):
        """After ``clear_cache`` the same request reaches the gateway again."""
        gateway = MockLLMGateway(responses=["Context A", "Context B"])
        chunker = ContextualChunker(llm_gateway=gateway)

        await chunker.enhance_chunks("doc", ["chunk"])
        assert gateway._call_count == 1

        chunker.clear_cache()

        # With the cache emptied, the repeated request costs a second call.
        await chunker.enhance_chunks("doc", ["chunk"])
        assert gateway._call_count == 2

    def test_prompt_template_format(self):
        """The template accepts ``document`` and ``chunk`` and keeps its marker."""
        rendered = CONTEXT_PROMPT_TEMPLATE.format(
            document="Test document",
            chunk="Test chunk",
        )
        for fragment in ("Test document", "Test chunk", "Context:"):
            assert fragment in rendered
|