338 lines
12 KiB
Python
338 lines
12 KiB
Python
"""Tests for RelevanceScorer and RAGSelfCorrectionLoop"""
|
|
|
|
import pytest
|
|
|
|
from agentkit.memory.base import MemoryItem
|
|
from agentkit.memory.relevance_scorer import (
|
|
RelevanceScorer,
|
|
RelevanceScore,
|
|
RelevanceVerdict,
|
|
RetrievalEvaluation,
|
|
)
|
|
from agentkit.memory.rag_loop import (
|
|
RAGSelfCorrectionLoop,
|
|
RAGLoopResult,
|
|
LoopState,
|
|
)
|
|
|
|
|
|
# --- RelevanceScorer Tests ---
|
|
|
|
|
|
class TestRelevanceScorer:
    """Unit tests covering RelevanceScorer scoring, evaluation and helpers."""

    def setup_method(self):
        # A fresh default-configured scorer for every test method.
        self.scorer = RelevanceScorer()

    def test_score_highly_relevant_item(self):
        """A document sharing the query's terms is expected to score above 0.5."""
        doc = MemoryItem(
            key="doc1",
            value="Django and Flask are popular Python web frameworks for building web applications",
            score=0.9,
        )
        scored = self.scorer.score_item("Python web framework Django Flask", doc)

        acceptable_verdicts = (RelevanceVerdict.CORRECT, RelevanceVerdict.AMBIGUOUS)
        assert scored.score > 0.5
        assert scored.verdict in acceptable_verdicts

    def test_score_irrelevant_item(self):
        """A document unrelated to the query is expected to be judged INCORRECT."""
        doc = MemoryItem(
            key="doc2",
            value="The weather is sunny today and the birds are singing in the garden",
            score=0.1,
        )
        scored = self.scorer.score_item("Python web framework", doc)

        assert scored.score < 0.5
        assert scored.verdict == RelevanceVerdict.INCORRECT

    def test_score_chinese_relevant_item(self):
        """Chinese query and document with overlapping text clear the 0.3 mark."""
        doc = MemoryItem(
            key="doc3",
            value="GEO优化策略包括内容结构化、Schema标记、AI平台适配等多个方面",
            score=0.85,
        )
        scored = self.scorer.score_item("GEO优化策略", doc)

        # The overlap is expected to come from matching Chinese bigrams.
        assert scored.score > 0.3

    def test_score_short_document_penalty(self):
        """A two-character document must score below a full-sentence one."""
        question = "machine learning algorithms"
        tiny_doc = MemoryItem(key="short", value="ML", score=0.9)
        full_doc = MemoryItem(
            key="normal",
            value="Machine learning algorithms include supervised and unsupervised learning methods",
            score=0.9,
        )

        tiny_scored = self.scorer.score_item(question, tiny_doc)
        full_scored = self.scorer.score_item(question, full_doc)

        assert full_scored.score > tiny_scored.score

    def test_evaluate_empty_results(self):
        """Evaluating no documents yields INCORRECT with zeroed statistics."""
        outcome = self.scorer.evaluate("test query", [])

        assert outcome.overall_verdict == RelevanceVerdict.INCORRECT
        assert outcome.avg_score == 0.0
        assert outcome.total_count == 0

    def test_evaluate_mixed_results(self):
        """One on-topic and one off-topic document: both counted, at least one relevant."""
        docs = [
            MemoryItem(key="good", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="bad", value="Weather forecast for today", score=0.1),
        ]
        outcome = self.scorer.evaluate("Python web framework", docs)

        assert outcome.total_count == 2
        assert outcome.relevant_count >= 1

    def test_evaluate_all_correct(self):
        """A result set of on-topic documents is expected to be judged CORRECT."""
        docs = [
            MemoryItem(key="d1", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="d2", value="Django REST framework for API development", score=0.85),
        ]
        outcome = self.scorer.evaluate("Python Django", docs)

        assert outcome.overall_verdict == RelevanceVerdict.CORRECT

    def test_evaluate_all_incorrect(self):
        """A result set of off-topic documents is expected to be judged INCORRECT."""
        docs = [
            MemoryItem(key="d1", value="Cooking recipes for beginners", score=0.1),
            MemoryItem(key="d2", value="Gardening tips for spring", score=0.05),
        ]
        outcome = self.scorer.evaluate("quantum computing", docs)

        assert outcome.overall_verdict == RelevanceVerdict.INCORRECT

    def test_custom_thresholds(self):
        """Raising both thresholds should push a middling document to INCORRECT."""
        strict_scorer = RelevanceScorer(correct_threshold=0.9, ambiguous_threshold=0.7)
        doc = MemoryItem(key="d1", value="test document with some content", score=0.5)

        scored = strict_scorer.score_item("test", doc)

        # Under the stricter thresholds this document is expected to fall short.
        assert scored.verdict == RelevanceVerdict.INCORRECT

    def test_jaccard_similarity(self):
        """Two three-term sets sharing two terms have Jaccard similarity 0.5."""
        left = {"python", "web", "framework"}
        right = {"python", "web", "server"}

        similarity = RelevanceScorer._jaccard_similarity(left, right)

        assert 0.0 < similarity < 1.0
        # Intersection of 2 over a union of 4 distinct terms.
        assert abs(similarity - 0.5) < 0.01

    def test_jaccard_empty_sets(self):
        """An empty set on either side gives a similarity of exactly 0."""
        jaccard = RelevanceScorer._jaccard_similarity

        assert jaccard(set(), {"a"}) == 0.0
        assert jaccard({"a"}, set()) == 0.0

    def test_query_coverage(self):
        """Coverage is the share of query terms present in the document terms."""
        wanted = {"python", "django", "flask"}
        present = {"python", "django", "web", "framework"}

        coverage = RelevanceScorer._query_coverage(wanted, present)

        # "python" and "django" are covered, "flask" is not: 2 of 3.
        assert abs(coverage - 2 / 3) < 0.01

    def test_tokenize_chinese(self):
        """Chinese text tokenizes into single characters plus bigrams."""
        tokens = RelevanceScorer._tokenize("机器学习算法")

        # Two single characters, then the bigram formed from them.
        for expected in ("机", "器", "机器"):
            assert expected in tokens

    def test_tokenize_english(self):
        """English text tokenizes into lower-cased words."""
        tokens = RelevanceScorer._tokenize("Python Web Framework")

        for expected in ("python", "web", "framework"):
            assert expected in tokens
|
|
|
|
|
|
# --- RAGSelfCorrectionLoop Tests ---
|
|
|
|
|
|
class MockRetriever:
    """In-memory retriever double that records every call made to it.

    Results are looked up in a query -> items mapping supplied at
    construction time. Queries that are not registered fall back to the
    items of the first registered query, so a single-entry mapping acts as
    a "return this for anything" stub.

    Attributes:
        call_count: Number of times ``retrieve`` has been awaited.
        queries: Every query string passed to ``retrieve``, in call order.
    """

    def __init__(self, items_by_query: dict[str, list[MemoryItem]] | None = None):
        """Store the canned results and reset the call bookkeeping.

        Args:
            items_by_query: Mapping from query string to the items to return
                for it. ``None`` (or an empty mapping) means every retrieval
                returns an empty list.
        """
        self._items = items_by_query or {}
        self.call_count = 0
        self.queries: list[str] = []

    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters=None,
        _skip_correction: bool = False,
    ) -> list[MemoryItem]:
        """Return the canned items for ``query`` and record the call.

        Only ``query`` influences the result; ``top_k``, ``token_budget``,
        ``filters`` and ``_skip_correction`` are accepted and ignored.

        Args:
            query: The query string to look up.

        Returns:
            The items registered for ``query``; otherwise the items of the
            first registered query; otherwise an empty list when nothing is
            registered.
        """
        self.call_count += 1
        self.queries.append(query)

        if query in self._items:
            return self._items[query]

        # Fall back to the first registered entry. Taking the first *value*
        # (instead of truth-testing the first key) keeps the fallback working
        # when that key is falsy, e.g. the empty string.
        return next(iter(self._items.values()), [])
|
|
|
|
|
|
class TestRAGSelfCorrectionLoop:
    """Unit tests for RAGSelfCorrectionLoop driven by a MockRetriever."""

    @pytest.mark.asyncio
    async def test_correct_retrieval_skips_correction(self):
        """Good first-pass results are returned as-is with zero retries."""
        question = "Python web framework"
        docs = [
            MemoryItem(
                key="d1",
                value="Django is a Python web framework for building web applications quickly",
                score=0.9,
            ),
            MemoryItem(
                key="d2",
                value="Flask is a lightweight Python web framework for small applications",
                score=0.85,
            ),
        ]
        retriever = MockRetriever({question: docs})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=3)

        outcome = await loop.retrieve_with_correction(question)

        assert not outcome.degraded
        assert len(outcome.items) == 2
        assert outcome.total_retries == 0

    @pytest.mark.asyncio
    async def test_poor_retrieval_triggers_correction(self):
        """Off-topic first-pass results lead to at least one more retrieval."""
        question = "Python web framework"
        off_topic = [
            MemoryItem(key="d1", value="Weather forecast for today", score=0.1),
        ]
        on_topic = [
            MemoryItem(
                key="d2",
                value="Python Django web framework tutorial and best practices",
                score=0.9,
            ),
        ]
        # Insertion order matters: the mock falls back to the first entry.
        canned = {
            question: off_topic,
            "improved query": on_topic,
        }
        retriever = MockRetriever(canned)
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=3)

        outcome = await loop.retrieve_with_correction(question)

        # The initial retrieval plus at least one retry.
        assert retriever.call_count >= 2
        assert len(outcome.attempts) >= 2

    @pytest.mark.asyncio
    async def test_max_retries_causes_degradation(self):
        """When every attempt is off-topic the result is flagged as degraded."""
        off_topic = [
            MemoryItem(key="d1", value="Unrelated content about weather", score=0.05),
        ]
        retriever = MockRetriever({"any query": off_topic})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=2)

        outcome = await loop.retrieve_with_correction("Python web framework")

        assert outcome.degraded
        assert outcome.total_retries >= 2
        # At least one returned item carries the low_confidence marker.
        assert any(
            doc.metadata.get("low_confidence", False) for doc in outcome.items
        )

    @pytest.mark.asyncio
    async def test_empty_retrieval_triggers_correction(self):
        """A retriever that only ever returns nothing ends in degradation."""
        retriever = MockRetriever({"query": []})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=2)

        outcome = await loop.retrieve_with_correction("test query")

        assert outcome.degraded
        assert outcome.total_retries >= 1

    @pytest.mark.asyncio
    async def test_loop_result_tracks_attempts(self):
        """The result exposes its attempts, starting with the original query."""
        docs = [
            MemoryItem(key="d1", value="Relevant Python content", score=0.9),
        ]
        retriever = MockRetriever({"test": docs})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=3)

        outcome = await loop.retrieve_with_correction("test")

        assert len(outcome.attempts) >= 1
        first_attempt = outcome.attempts[0]
        assert first_attempt.query == "test"
        assert first_attempt.state in (
            LoopState.GENERATE,
            LoopState.CORRECT,
            LoopState.DEGRADE,
        )

    @pytest.mark.asyncio
    async def test_correction_with_query_transformer(self):
        """A supplied query transformer is invoked at least once on poor results."""
        from agentkit.memory.query_transformer import TransformedQuery, QueryTransformerBase

        class MockTransformer(QueryTransformerBase):
            """Transformer double that counts how often it is asked to rewrite."""

            def __init__(self):
                self.transform_count = 0

            async def transform(self, query: str) -> TransformedQuery:
                self.transform_count += 1
                return TransformedQuery(
                    main_query=f"improved {query}",
                    sub_queries=[f"sub-{query}"],
                    original_query=query,
                )

        off_topic = [
            MemoryItem(key="d1", value="Unrelated", score=0.05),
        ]
        on_topic = [
            MemoryItem(key="d2", value="Relevant Python content", score=0.9),
        ]
        # "sub-test" is the sub-query the transformer above emits for "test".
        canned = {
            "test": off_topic,
            "sub-test": on_topic,
        }
        retriever = MockRetriever(canned)
        transformer = MockTransformer()
        loop = RAGSelfCorrectionLoop(
            retriever=retriever,
            query_transformer=transformer,
            max_retries=3,
        )

        await loop.retrieve_with_correction("test")

        assert transformer.transform_count >= 1

    @pytest.mark.asyncio
    async def test_degraded_result_filters_irrelevant(self):
        """If the loop degrades, the on-topic item must still be in the result."""
        docs = [
            MemoryItem(key="good", value="Python Django framework", score=0.8),
            MemoryItem(key="bad", value="Weather forecast", score=0.05),
        ]
        retriever = MockRetriever({"query": docs})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=1)

        outcome = await loop.retrieve_with_correction("Python framework")

        # Nothing to check unless the loop actually fell back to degradation.
        if not outcome.degraded:
            return

        returned_keys = [doc.key for doc in outcome.items]
        assert "good" in returned_keys
|