fischer-agentkit/tests/unit/test_rag_loop.py

338 lines
12 KiB
Python

"""Tests for RelevanceScorer and RAGSelfCorrectionLoop"""
import pytest
from agentkit.memory.base import MemoryItem
from agentkit.memory.relevance_scorer import (
RelevanceScorer,
RelevanceScore,
RelevanceVerdict,
RetrievalEvaluation,
)
from agentkit.memory.rag_loop import (
RAGSelfCorrectionLoop,
RAGLoopResult,
LoopState,
)
# --- RelevanceScorer Tests ---
class TestRelevanceScorer:
    """Unit tests for RelevanceScorer.

    Covers per-item scoring (``score_item``), batch evaluation
    (``evaluate``), configurable thresholds, and the private similarity and
    tokenization helpers (``_jaccard_similarity``, ``_query_coverage``,
    ``_tokenize``).
    """

    def setup_method(self):
        # pytest calls this before each test, so every test gets a fresh
        # scorer built with the default thresholds.
        self.scorer = RelevanceScorer()

    def test_score_highly_relevant_item(self):
        """Highly relevant document should score high"""
        query = "Python web framework Django Flask"
        item = MemoryItem(
            key="doc1",
            value="Django and Flask are popular Python web frameworks for building web applications",
            score=0.9,
        )
        result = self.scorer.score_item(query, item)
        assert result.score > 0.5
        assert result.verdict in (RelevanceVerdict.CORRECT, RelevanceVerdict.AMBIGUOUS)

    def test_score_irrelevant_item(self):
        """Completely irrelevant document should score low"""
        query = "Python web framework"
        item = MemoryItem(
            key="doc2",
            value="The weather is sunny today and the birds are singing in the garden",
            score=0.1,
        )
        result = self.scorer.score_item(query, item)
        assert result.score < 0.5
        assert result.verdict == RelevanceVerdict.INCORRECT

    def test_score_chinese_relevant_item(self):
        """Chinese text relevance scoring"""
        query = "GEO优化策略"
        item = MemoryItem(
            key="doc3",
            value="GEO优化策略包括内容结构化、Schema标记、AI平台适配等多个方面",
            score=0.85,
        )
        result = self.scorer.score_item(query, item)
        assert result.score > 0.3  # Chinese bigrams should match

    def test_score_short_document_penalty(self):
        """Very short documents should be penalized"""
        query = "machine learning algorithms"
        short_item = MemoryItem(key="short", value="ML", score=0.9)
        normal_item = MemoryItem(
            key="normal",
            value="Machine learning algorithms include supervised and unsupervised learning methods",
            score=0.9,
        )
        short_result = self.scorer.score_item(query, short_item)
        normal_result = self.scorer.score_item(query, normal_item)
        assert normal_result.score > short_result.score

    def test_evaluate_empty_results(self):
        """Empty retrieval results should be INCORRECT"""
        evaluation = self.scorer.evaluate("test query", [])
        assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT
        assert evaluation.avg_score == 0.0
        assert evaluation.total_count == 0

    def test_evaluate_mixed_results(self):
        """Mixed quality results should be AMBIGUOUS or CORRECT"""
        query = "Python web framework"
        items = [
            MemoryItem(key="good", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="bad", value="Weather forecast for today", score=0.1),
        ]
        evaluation = self.scorer.evaluate(query, items)
        assert evaluation.total_count == 2
        assert evaluation.relevant_count >= 1

    def test_evaluate_all_correct(self):
        """All relevant results should give CORRECT verdict"""
        query = "Python Django"
        items = [
            MemoryItem(key="d1", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="d2", value="Django REST framework for API development", score=0.85),
        ]
        evaluation = self.scorer.evaluate(query, items)
        assert evaluation.overall_verdict == RelevanceVerdict.CORRECT

    def test_evaluate_all_incorrect(self):
        """All irrelevant results should give INCORRECT verdict"""
        query = "quantum computing"
        items = [
            MemoryItem(key="d1", value="Cooking recipes for beginners", score=0.1),
            MemoryItem(key="d2", value="Gardening tips for spring", score=0.05),
        ]
        evaluation = self.scorer.evaluate(query, items)
        assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT

    def test_custom_thresholds(self):
        """Custom thresholds should affect verdict"""
        scorer = RelevanceScorer(correct_threshold=0.9, ambiguous_threshold=0.7)
        query = "test"
        item = MemoryItem(key="d1", value="test document with some content", score=0.5)
        result = scorer.score_item(query, item)
        # With high thresholds, this should be INCORRECT
        assert result.verdict == RelevanceVerdict.INCORRECT

    def test_jaccard_similarity(self):
        """Jaccard similarity calculation"""
        set_a = {"python", "web", "framework"}
        set_b = {"python", "web", "server"}
        similarity = RelevanceScorer._jaccard_similarity(set_a, set_b)
        assert 0.0 < similarity < 1.0
        # 2 common / 4 unique = 0.5
        assert abs(similarity - 0.5) < 0.01

    def test_jaccard_empty_sets(self):
        """Jaccard with empty sets returns 0"""
        assert RelevanceScorer._jaccard_similarity(set(), {"a"}) == 0.0
        assert RelevanceScorer._jaccard_similarity({"a"}, set()) == 0.0

    def test_query_coverage(self):
        """Query term coverage calculation"""
        query_terms = {"python", "django", "flask"}
        doc_terms = {"python", "django", "web", "framework"}
        coverage = RelevanceScorer._query_coverage(query_terms, doc_terms)
        # 2 out of 3 query terms covered
        assert abs(coverage - 2 / 3) < 0.01

    def test_tokenize_chinese(self):
        """Chinese tokenization includes single characters and bigrams"""
        tokens = RelevanceScorer._tokenize("机器学习算法")
        # Expect both individual CJK characters and adjacent-character bigrams.
        # (The original asserted the empty string here; its characters had been
        # lost, which made the check contradict this comment.)
        assert "机" in tokens  # single character
        assert "器" in tokens  # single character
        assert "机器" in tokens  # bigram

    def test_tokenize_english(self):
        """English tokenization"""
        tokens = RelevanceScorer._tokenize("Python Web Framework")
        assert "python" in tokens
        assert "web" in tokens
        assert "framework" in tokens
# --- RAGSelfCorrectionLoop Tests ---
class MockRetriever:
    """In-memory stand-in for a retriever, used to drive RAGSelfCorrectionLoop.

    Maps exact query strings to canned result lists and records every call,
    so tests can assert how many retrievals happened and with which queries.
    """

    def __init__(self, items_by_query: dict[str, list[MemoryItem]] | None = None):
        # query -> canned results. Insertion order matters: the first entry
        # doubles as the fallback for queries that are not registered.
        self._items = items_by_query or {}
        # Number of retrieve() calls made so far.
        self.call_count = 0
        # Every query string passed to retrieve(), in call order.
        self.queries: list[str] = []

    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters=None,
        _skip_correction: bool = False,
    ) -> list[MemoryItem]:
        """Record the call and return the canned items for *query*.

        ``top_k``, ``token_budget``, ``filters`` and ``_skip_correction`` are
        ignored; presumably they exist only to mirror the real retriever's
        call signature.

        Returns the list registered for *query* if there is one. Otherwise it
        returns the list registered under the first key, or ``[]`` when
        nothing was configured.
        """
        self.call_count += 1
        self.queries.append(query)
        # Return items for exact query match, or default items
        if query in self._items:
            return self._items[query]
        # Fall back to the first configured entry. Check whether the dict is
        # empty rather than whether its first key is truthy, so a "" key still
        # acts as the default.
        if self._items:
            return next(iter(self._items.values()))
        return []
class TestRAGSelfCorrectionLoop:
    """Unit tests for RAGSelfCorrectionLoop.

    Each test wires the loop to a MockRetriever. The mock returns canned
    results for exact query matches and otherwise falls back to its first
    configured entry. The tests then check correction, retry and degradation
    behavior through the returned result and the mock's call log.
    """

    @pytest.mark.asyncio
    async def test_correct_retrieval_skips_correction(self):
        """High-quality retrieval should not trigger correction"""
        items = [
            MemoryItem(
                key="d1",
                value="Django is a Python web framework for building web applications quickly",
                score=0.9,
            ),
            MemoryItem(
                key="d2",
                value="Flask is a lightweight Python web framework for small applications",
                score=0.85,
            ),
        ]
        mock = MockRetriever({"Python web framework": items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)
        result = await loop.retrieve_with_correction("Python web framework")
        assert not result.degraded
        assert len(result.items) == 2
        assert result.total_retries == 0

    @pytest.mark.asyncio
    async def test_poor_retrieval_triggers_correction(self):
        """Poor retrieval should trigger query rewriting"""
        poor_items = [
            MemoryItem(key="d1", value="Weather forecast for today", score=0.1),
        ]
        good_items = [
            MemoryItem(
                key="d2",
                value="Python Django web framework tutorial and best practices",
                score=0.9,
            ),
        ]
        # NOTE(review): "improved query" is only hit if a rewrite yields that
        # exact string. Any other rewritten query falls back to the first entry
        # (poor_items), so this test only checks that retries happen.
        mock = MockRetriever({
            "Python web framework": poor_items,
            "improved query": good_items,
        })
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)
        result = await loop.retrieve_with_correction("Python web framework")
        assert mock.call_count >= 2  # At least initial + 1 retry
        assert len(result.attempts) >= 2

    @pytest.mark.asyncio
    async def test_max_retries_causes_degradation(self):
        """Exceeding max retries should cause degradation"""
        poor_items = [
            MemoryItem(key="d1", value="Unrelated content about weather", score=0.05),
        ]
        # Every query falls back to the single poor entry.
        mock = MockRetriever({"any query": poor_items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)
        result = await loop.retrieve_with_correction("Python web framework")
        assert result.degraded
        assert result.total_retries >= 2
        # Items should be marked low_confidence
        assert any(
            item.metadata.get("low_confidence", False) for item in result.items
        )

    @pytest.mark.asyncio
    async def test_empty_retrieval_triggers_correction(self):
        """Empty retrieval results should trigger correction"""
        # "test query" is not registered, so the mock falls back to the empty list.
        mock = MockRetriever({"query": []})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)
        result = await loop.retrieve_with_correction("test query")
        assert result.degraded
        assert result.total_retries >= 1

    @pytest.mark.asyncio
    async def test_loop_result_tracks_attempts(self):
        """Loop result should track all correction attempts"""
        items = [
            MemoryItem(key="d1", value="Relevant Python content", score=0.9),
        ]
        mock = MockRetriever({"test": items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)
        result = await loop.retrieve_with_correction("test")
        assert len(result.attempts) >= 1
        assert result.attempts[0].query == "test"
        assert result.attempts[0].state in (
            LoopState.GENERATE,
            LoopState.CORRECT,
            LoopState.DEGRADE,
        )

    @pytest.mark.asyncio
    async def test_correction_with_query_transformer(self):
        """Query transformer should be used during correction"""
        from agentkit.memory.query_transformer import TransformedQuery, QueryTransformerBase

        class MockTransformer(QueryTransformerBase):
            def __init__(self):
                self.transform_count = 0

            async def transform(self, query: str) -> TransformedQuery:
                self.transform_count += 1
                return TransformedQuery(
                    main_query=f"improved {query}",
                    sub_queries=[f"sub-{query}"],
                    original_query=query,
                )

        poor_items = [
            MemoryItem(key="d1", value="Unrelated", score=0.05),
        ]
        good_items = [
            MemoryItem(key="d2", value="Relevant Python content", score=0.9),
        ]
        mock = MockRetriever({
            "test": poor_items,
            "sub-test": good_items,
        })
        transformer = MockTransformer()
        loop = RAGSelfCorrectionLoop(
            retriever=mock,
            query_transformer=transformer,
            max_retries=3,
        )
        result = await loop.retrieve_with_correction("test")
        assert transformer.transform_count >= 1
        # The result was previously bound but never checked; at minimum the
        # loop must have recorded the attempts it made.
        assert len(result.attempts) >= 1

    @pytest.mark.asyncio
    async def test_degraded_result_filters_irrelevant(self):
        """Degraded result should prefer relevant items over irrelevant"""
        mixed_items = [
            MemoryItem(key="good", value="Python Django framework", score=0.8),
            MemoryItem(key="bad", value="Weather forecast", score=0.05),
        ]
        mock = MockRetriever({"query": mixed_items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=1)
        result = await loop.retrieve_with_correction("Python framework")
        # Even if degraded, should prefer relevant items.
        # NOTE(review): when the loop does not degrade, this test asserts nothing.
        if result.degraded:
            relevant_keys = [item.key for item in result.items]
            assert "good" in relevant_keys