"""Tests for RelevanceScorer and RAGSelfCorrectionLoop"""

import pytest

from agentkit.memory.base import MemoryItem
from agentkit.memory.relevance_scorer import (
    RelevanceScorer,
    RelevanceScore,
    RelevanceVerdict,
    RetrievalEvaluation,
)
from agentkit.memory.rag_loop import (
    RAGSelfCorrectionLoop,
    RAGLoopResult,
    LoopState,
)


# --- RelevanceScorer Tests ---


class TestRelevanceScorer:
    """RelevanceScorer unit tests"""

    def setup_method(self):
        """Create a scorer with default thresholds before each test."""
        self.scorer = RelevanceScorer()

    def test_score_highly_relevant_item(self):
        """Highly relevant document should score high"""
        query = "Python web framework Django Flask"
        item = MemoryItem(
            key="doc1",
            value="Django and Flask are popular Python web frameworks for building web applications",
            score=0.9,
        )

        result = self.scorer.score_item(query, item)

        assert result.score > 0.5
        assert result.verdict in (RelevanceVerdict.CORRECT, RelevanceVerdict.AMBIGUOUS)

    def test_score_irrelevant_item(self):
        """Completely irrelevant document should score low"""
        query = "Python web framework"
        item = MemoryItem(
            key="doc2",
            value="The weather is sunny today and the birds are singing in the garden",
            score=0.1,
        )

        result = self.scorer.score_item(query, item)

        assert result.score < 0.5
        assert result.verdict == RelevanceVerdict.INCORRECT

    def test_score_chinese_relevant_item(self):
        """Chinese text relevance scoring"""
        query = "GEO优化策略"
        item = MemoryItem(
            key="doc3",
            value="GEO优化策略包括内容结构化、Schema标记、AI平台适配等多个方面",
            score=0.85,
        )

        result = self.scorer.score_item(query, item)

        assert result.score > 0.3  # Chinese bigrams should match

    def test_score_short_document_penalty(self):
        """Very short documents should be penalized"""
        query = "machine learning algorithms"
        short_item = MemoryItem(key="short", value="ML", score=0.9)
        normal_item = MemoryItem(
            key="normal",
            value="Machine learning algorithms include supervised and unsupervised learning methods",
            score=0.9,
        )

        short_result = self.scorer.score_item(query, short_item)
        normal_result = self.scorer.score_item(query, normal_item)

        assert normal_result.score > short_result.score

    def test_evaluate_empty_results(self):
        """Empty retrieval results should be INCORRECT"""
        evaluation = self.scorer.evaluate("test query", [])

        assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT
        assert evaluation.avg_score == 0.0
        assert evaluation.total_count == 0

    def test_evaluate_mixed_results(self):
        """Mixed quality results should be AMBIGUOUS or CORRECT"""
        query = "Python web framework"
        items = [
            MemoryItem(key="good", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="bad", value="Weather forecast for today", score=0.1),
        ]

        evaluation = self.scorer.evaluate(query, items)

        assert evaluation.total_count == 2
        assert evaluation.relevant_count >= 1

    def test_evaluate_all_correct(self):
        """All relevant results should give CORRECT verdict"""
        query = "Python Django"
        items = [
            MemoryItem(key="d1", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="d2", value="Django REST framework for API development", score=0.85),
        ]

        evaluation = self.scorer.evaluate(query, items)

        assert evaluation.overall_verdict == RelevanceVerdict.CORRECT

    def test_evaluate_all_incorrect(self):
        """All irrelevant results should give INCORRECT verdict"""
        query = "quantum computing"
        items = [
            MemoryItem(key="d1", value="Cooking recipes for beginners", score=0.1),
            MemoryItem(key="d2", value="Gardening tips for spring", score=0.05),
        ]

        evaluation = self.scorer.evaluate(query, items)

        assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT

    def test_custom_thresholds(self):
        """Custom thresholds should affect verdict"""
        scorer = RelevanceScorer(correct_threshold=0.9, ambiguous_threshold=0.7)
        query = "test"
        item = MemoryItem(key="d1", value="test document with some content", score=0.5)

        result = scorer.score_item(query, item)

        # With high thresholds, this should be INCORRECT
        assert result.verdict == RelevanceVerdict.INCORRECT

    def test_jaccard_similarity(self):
        """Jaccard similarity calculation"""
        set_a = {"python", "web", "framework"}
        set_b = {"python", "web", "server"}

        similarity = RelevanceScorer._jaccard_similarity(set_a, set_b)

        assert 0.0 < similarity < 1.0
        # 2 common / 4 unique = 0.5
        assert similarity == pytest.approx(0.5, abs=0.01)

    def test_jaccard_empty_sets(self):
        """Jaccard with empty sets returns 0"""
        assert RelevanceScorer._jaccard_similarity(set(), {"a"}) == 0.0
        assert RelevanceScorer._jaccard_similarity({"a"}, set()) == 0.0

    def test_query_coverage(self):
        """Query term coverage calculation"""
        query_terms = {"python", "django", "flask"}
        doc_terms = {"python", "django", "web", "framework"}

        coverage = RelevanceScorer._query_coverage(query_terms, doc_terms)

        # 2 out of 3 query terms covered
        assert coverage == pytest.approx(2 / 3, abs=0.01)

    def test_tokenize_chinese(self):
        """Chinese tokenization includes bigrams"""
        tokens = RelevanceScorer._tokenize("机器学习算法")

        # Should include individual chars and bigrams
        assert "机" in tokens
        assert "器" in tokens
        assert "机器" in tokens  # bigram

    def test_tokenize_english(self):
        """English tokenization"""
        tokens = RelevanceScorer._tokenize("Python Web Framework")

        assert "python" in tokens
        assert "web" in tokens
        assert "framework" in tokens


# --- RAGSelfCorrectionLoop Tests ---


class MockRetriever:
    """Mock retriever for testing

    Serves canned results from a ``query -> items`` mapping and records
    every call so tests can assert on how the loop used the retriever.
    """

    def __init__(self, items_by_query: dict[str, list[MemoryItem]] | None = None):
        """Store the canned results and reset the call-tracking state."""
        self._items = items_by_query or {}
        self.call_count = 0
        self.queries: list[str] = []

    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters=None,
        _skip_correction: bool = False,
    ) -> list[MemoryItem]:
        """Return the canned items for *query* and record the call.

        An exact key match wins. Otherwise the items of the first
        configured entry are returned, and an empty list when nothing is
        configured. ``top_k``, ``token_budget``, ``filters`` and
        ``_skip_correction`` are accepted but ignored by this mock.
        """
        self.call_count += 1
        self.queries.append(query)

        # Return items for exact query match, or default items
        if query in self._items:
            return self._items[query]

        # Fall back to the first configured entry. Going through the values
        # (rather than truth-testing the first key) keeps the fallback
        # working when that key is falsy, e.g. an empty-string query.
        return next(iter(self._items.values()), [])


class TestRAGSelfCorrectionLoop:
    """RAGSelfCorrectionLoop unit tests"""

    @pytest.mark.asyncio
    async def test_correct_retrieval_skips_correction(self):
        """High-quality retrieval should not trigger correction"""
        items = [
            MemoryItem(
                key="d1",
                value="Django is a Python web framework for building web applications quickly",
                score=0.9,
            ),
            MemoryItem(
                key="d2",
                value="Flask is a lightweight Python web framework for small applications",
                score=0.85,
            ),
        ]
        mock = MockRetriever({"Python web framework": items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)

        result = await loop.retrieve_with_correction("Python web framework")

        assert not result.degraded
        assert len(result.items) == 2
        assert result.total_retries == 0

    @pytest.mark.asyncio
    async def test_poor_retrieval_triggers_correction(self):
        """Poor retrieval should trigger query rewriting"""
        poor_items = [
            MemoryItem(key="d1", value="Weather forecast for today", score=0.1),
        ]
        good_items = [
            MemoryItem(
                key="d2",
                value="Python Django web framework tutorial and best practices",
                score=0.9,
            ),
        ]
        mock = MockRetriever({
            "Python web framework": poor_items,
            "improved query": good_items,
        })
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)

        result = await loop.retrieve_with_correction("Python web framework")

        assert mock.call_count >= 2  # At least initial + 1 retry
        assert len(result.attempts) >= 2

    @pytest.mark.asyncio
    async def test_max_retries_causes_degradation(self):
        """Exceeding max retries should cause degradation"""
        poor_items = [
            MemoryItem(key="d1", value="Unrelated content about weather", score=0.05),
        ]
        # The key never matches the query, so every call hits the mock's
        # first-entry fallback and keeps returning the poor items.
        mock = MockRetriever({"any query": poor_items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)

        result = await loop.retrieve_with_correction("Python web framework")

        assert result.degraded
        assert result.total_retries >= 2
        # Items should be marked low_confidence
        assert any(
            item.metadata.get("low_confidence", False) for item in result.items
        )

    @pytest.mark.asyncio
    async def test_empty_retrieval_triggers_correction(self):
        """Empty retrieval results should trigger correction"""
        mock = MockRetriever({"query": []})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)

        result = await loop.retrieve_with_correction("test query")

        assert result.degraded
        assert result.total_retries >= 1

    @pytest.mark.asyncio
    async def test_loop_result_tracks_attempts(self):
        """Loop result should track all correction attempts"""
        items = [
            MemoryItem(key="d1", value="Relevant Python content", score=0.9),
        ]
        mock = MockRetriever({"test": items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)

        result = await loop.retrieve_with_correction("test")

        assert len(result.attempts) >= 1
        assert result.attempts[0].query == "test"
        assert result.attempts[0].state in (
            LoopState.GENERATE,
            LoopState.CORRECT,
            LoopState.DEGRADE,
        )

    @pytest.mark.asyncio
    async def test_correction_with_query_transformer(self):
        """Query transformer should be used during correction"""
        from agentkit.memory.query_transformer import TransformedQuery, QueryTransformerBase

        class MockTransformer(QueryTransformerBase):
            """Transformer stub that counts how often it is invoked."""

            # NOTE(review): QueryTransformerBase.__init__ is not invoked here;
            # confirm the base class needs no initialisation.
            def __init__(self):
                self.transform_count = 0

            async def transform(self, query: str) -> TransformedQuery:
                self.transform_count += 1
                return TransformedQuery(
                    main_query=f"improved {query}",
                    sub_queries=[f"sub-{query}"],
                    original_query=query,
                )

        poor_items = [
            MemoryItem(key="d1", value="Unrelated", score=0.05),
        ]
        good_items = [
            MemoryItem(key="d2", value="Relevant Python content", score=0.9),
        ]
        mock = MockRetriever({
            "test": poor_items,
            "sub-test": good_items,
        })
        transformer = MockTransformer()
        loop = RAGSelfCorrectionLoop(
            retriever=mock,
            query_transformer=transformer,
            max_retries=3,
        )

        # Only the transformer's call counter is asserted on, so the loop
        # result is intentionally not bound to a name.
        await loop.retrieve_with_correction("test")

        assert transformer.transform_count >= 1

    @pytest.mark.asyncio
    async def test_degraded_result_filters_irrelevant(self):
        """Degraded result should prefer relevant items over irrelevant"""
        mixed_items = [
            MemoryItem(key="good", value="Python Django framework", score=0.8),
            MemoryItem(key="bad", value="Weather forecast", score=0.05),
        ]
        mock = MockRetriever({"query": mixed_items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=1)

        result = await loop.retrieve_with_correction("Python framework")

        # Even if degraded, should prefer relevant items
        # NOTE(review): when the loop does not degrade, this test asserts
        # nothing. Consider pinning the expected `degraded` value once the
        # intended behaviour for this input is confirmed.
        if result.degraded:
            relevant_keys = [item.key for item in result.items]
            assert "good" in relevant_keys