logger = logging.getLogger(__name__)


class LoopState(str, Enum):
    """States of the self-correction loop state machine."""

    RETRIEVE = "retrieve"
    EVALUATE = "evaluate"
    CORRECT = "correct"
    DEGRADE = "degrade"
    GENERATE = "generate"


@dataclass
class CorrectionAttempt:
    """Record of one retrieve-and-evaluate pass of the loop."""

    query: str  # query text sent to the retriever on this pass
    evaluation: RetrievalEvaluation  # scorer output for that retrieval
    state: LoopState  # decision taken after the evaluation


@dataclass
class RAGLoopResult:
    """Final outcome of a self-correction run."""

    items: list[MemoryItem]  # items handed back to the caller
    evaluation: RetrievalEvaluation  # evaluation of the last retrieval
    attempts: list[CorrectionAttempt]  # one entry per retrieval, in order
    corrected: bool  # True when a rewritten query produced the result
    degraded: bool  # True when the retry budget ran out
    total_retries: int  # number of re-retrievals actually performed


class RAGSelfCorrectionLoop:
    """Corrective-RAG (CRAG) self-correction loop.

    A small state machine around a retriever:

    1. RETRIEVE: fetch items through ``retriever.retrieve``.
    2. EVALUATE: grade the items with the ``RelevanceScorer``.
    3. CORRECT: quality is insufficient, rewrite the query and retry.
    4. DEGRADE: the retry budget is exhausted, return a degraded result.
    5. GENERATE: quality is sufficient, return the items.

    Circuit breaker: at most ``max_retries`` re-retrievals are performed
    (default 3). After that the last result is returned with every item
    flagged ``low_confidence`` in its metadata.
    """

    def __init__(
        self,
        retriever: Any,  # MemoryRetriever
        scorer: RelevanceScorer | None = None,
        query_transformer: QueryTransformerBase | None = None,
        max_retries: int = 3,
        min_items_for_correct: int = 1,
    ):
        """Configure the loop.

        Args:
            retriever: Object exposing an async ``retrieve(query, top_k=...,
                token_budget=..., filters=..., _skip_correction=...)``.
            scorer: Relevance evaluator; a default ``RelevanceScorer`` is
                created when omitted.
            query_transformer: Rewrites queries between attempts; defaults
                to ``NoOpQueryTransformer``.
            max_retries: Maximum number of re-retrievals. Negative values
                are treated as 0 so that one retrieval always runs.
            min_items_for_correct: Relevant-item count consulted by
                ``_determine_next_state``.
        """
        self._retriever = retriever
        self._scorer = scorer or RelevanceScorer()
        self._query_transformer = query_transformer or NoOpQueryTransformer()
        # A negative budget used to skip the loop body entirely and crash on
        # unbound locals; clamp so the initial retrieval always happens.
        self._max_retries = max(0, max_retries)
        self._min_items_for_correct = min_items_for_correct

    async def retrieve_with_correction(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters: dict[str, Any] | None = None,
    ) -> RAGLoopResult:
        """Retrieve with automatic query correction.

        Args:
            query: Original user query.
            top_k: Maximum number of items to return.
            token_budget: Token budget forwarded to the retriever.
            filters: Optional filters forwarded to the retriever.

        Returns:
            RAGLoopResult with the items, the last evaluation and the
            per-attempt history. ``total_retries`` counts the re-retrievals
            that were actually executed (``len(attempts) - 1``).
        """
        attempts: list[CorrectionAttempt] = []
        current_query = query
        last_index = self._max_retries

        for attempt_index in range(last_index + 1):
            # RETRIEVE — _skip_correction stops the retriever from
            # re-entering this loop.
            items = await self._retriever.retrieve(
                current_query,
                top_k=top_k,
                token_budget=token_budget,
                filters=filters,
                _skip_correction=True,
            )

            # EVALUATE
            evaluation = self._scorer.evaluate(current_query, items)
            state = self._determine_next_state(evaluation, items)
            if state != LoopState.GENERATE and attempt_index == last_index:
                # No budget left for another rewrite: record the real outcome
                # instead of a CORRECT step that never happens.
                state = LoopState.DEGRADE

            attempts.append(
                CorrectionAttempt(
                    query=current_query,
                    evaluation=evaluation,
                    state=state,
                )
            )
            logger.info(
                "RAG loop attempt %d: query='%s...', verdict=%s, "
                "avg_score=%.2f, state=%s",
                attempt_index + 1,
                current_query[:50],
                evaluation.overall_verdict.value,
                evaluation.avg_score,
                state.value,
            )

            # GENERATE — quality is sufficient
            if state == LoopState.GENERATE:
                return RAGLoopResult(
                    items=items,
                    evaluation=evaluation,
                    attempts=attempts,
                    corrected=attempt_index > 0,
                    degraded=False,
                    total_retries=attempt_index,
                )

            # DEGRADE — retry budget exhausted
            if state == LoopState.DEGRADE:
                break

            # CORRECT — rewrite the query and go around again
            current_query = await self._rewrite_query(
                query, current_query, evaluation
            )

        # Degraded result: keep only items not judged INCORRECT; if that
        # leaves nothing, fall back to the full last retrieval.
        relevant_items = [
            s.item
            for s in evaluation.scores
            if s.verdict != RelevanceVerdict.INCORRECT
        ]
        result_items = relevant_items if relevant_items else items

        # NOTE(review): this mutates the retrieved item objects in place;
        # confirm the memory stores hand out copies rather than shared
        # instances.
        for item in result_items:
            item.metadata["low_confidence"] = True

        return RAGLoopResult(
            items=result_items,
            evaluation=evaluation,
            attempts=attempts,
            corrected=False,
            degraded=True,
            # One initial retrieval plus N retries -> N = len(attempts) - 1.
            total_retries=len(attempts) - 1,
        )

    def _determine_next_state(
        self, evaluation: RetrievalEvaluation, items: list[MemoryItem]
    ) -> LoopState:
        """Map an evaluation to the next loop state.

        Returns ``GENERATE`` when the overall verdict is CORRECT and either
        enough relevant items were found or any items were returned at all;
        every other case (AMBIGUOUS, INCORRECT, or CORRECT with no items)
        returns ``CORRECT`` to request a query rewrite.
        """
        if evaluation.overall_verdict != RelevanceVerdict.CORRECT:
            # AMBIGUOUS and INCORRECT both ask for a rewritten query.
            return LoopState.CORRECT

        enough_relevant = (
            evaluation.relevant_count >= self._min_items_for_correct
        )
        # A CORRECT verdict with too few relevant items still generates as
        # long as something was retrieved.
        if enough_relevant or items:
            return LoopState.GENERATE
        return LoopState.CORRECT

    async def _rewrite_query(
        self,
        original_query: str,
        current_query: str,
        evaluation: RetrievalEvaluation,
    ) -> str:
        """Produce the query for the next retrieval attempt.

        Strategy, in order of precedence:

        1. If the transformer returns sub-queries, use the first one.
        2. Otherwise use the transformer's ``main_query`` when it differs
           from ``current_query``.
        3. Otherwise, if the evaluation contains INCORRECT items, fall back
           to ``original_query`` when the current query already differs from
           it, else append an exclusion hint to the current query.
        4. Otherwise return the query unchanged.
        """
        transformed = await self._query_transformer.transform(current_query)
        new_query = transformed.main_query

        if new_query == current_query:
            # Transformer was a no-op; use the failed evaluation as a cue.
            has_incorrect = any(
                score.verdict == RelevanceVerdict.INCORRECT
                for score in evaluation.scores
            )
            if has_incorrect and original_query != current_query:
                # Try the original query as a fallback.
                new_query = original_query
            elif has_incorrect:
                # Add "NOT" context to help filter.
                new_query = f"{current_query} (excluding irrelevant results)"

        if transformed.sub_queries:
            # The first sub-query explores a different aspect of the
            # question and takes precedence over everything above.
            new_query = transformed.sub_queries[0]

        logger.info(
            "Query rewritten: '%s...' -> '%s...'",
            current_query[:50],
            new_query[:50],
        )
        return new_query
logger = logging.getLogger(__name__)


class RelevanceVerdict(str, Enum):
    """Relevance verdict for a single document or a whole retrieval."""

    CORRECT = "correct"
    AMBIGUOUS = "ambiguous"
    INCORRECT = "incorrect"


@dataclass
class RelevanceScore:
    """Relevance score of one retrieved document."""

    item: MemoryItem
    score: float  # 0.0 ~ 1.0
    verdict: RelevanceVerdict
    reason: str = ""  # per-component breakdown for debugging


@dataclass
class RetrievalEvaluation:
    """Aggregate evaluation of one retrieval."""

    scores: list[RelevanceScore]
    overall_verdict: RelevanceVerdict
    avg_score: float
    relevant_count: int  # documents whose verdict is not INCORRECT
    total_count: int


class RelevanceScorer:
    """Lightweight relevance evaluator for retrieved documents.

    Scores each document against the query without any LLM call, using a
    weighted sum of four components:

    1. Keyword overlap (Jaccard similarity of the token sets).
    2. Query term coverage (share of query tokens found in the document).
    3. The original retrieval score carried by the item.
    4. A length score that penalises very short or very long documents.

    The combined score is clamped to ``[0.0, 1.0]`` so the documented range
    holds even when custom weights do not sum to 1.
    """

    # Runs of ASCII letters or runs of digits. Keeping digits lets version
    # numbers, error codes and years take part in matching; letter runs are
    # tokenised exactly as before.
    _WORD_RE = re.compile(r"[a-z]+|[0-9]+")
    # Single CJK unified ideographs.
    _CJK_RE = re.compile(r"[\u4e00-\u9fff]")

    # Share of non-INCORRECT documents required for the overall verdicts.
    _CORRECT_MIN_RELEVANT_RATIO = 0.5
    _AMBIGUOUS_MIN_RELEVANT_RATIO = 0.3

    def __init__(
        self,
        correct_threshold: float = 0.6,
        ambiguous_threshold: float = 0.35,
        keyword_weight: float = 0.3,
        coverage_weight: float = 0.3,
        retrieval_weight: float = 0.3,
        length_weight: float = 0.1,
        min_doc_length: int = 20,
        max_doc_length: int = 5000,
    ):
        """Store thresholds, component weights and length bounds.

        Args:
            correct_threshold: Minimum score for a CORRECT verdict.
            ambiguous_threshold: Minimum score for an AMBIGUOUS verdict.
            keyword_weight: Weight of the Jaccard keyword overlap.
            coverage_weight: Weight of the query term coverage.
            retrieval_weight: Weight of the item's own retrieval score.
            length_weight: Weight of the length score.
            min_doc_length: Documents shorter than this many characters are
                penalised.
            max_doc_length: Documents longer than this many characters are
                penalised.
        """
        self._correct_threshold = correct_threshold
        self._ambiguous_threshold = ambiguous_threshold
        self._keyword_weight = keyword_weight
        self._coverage_weight = coverage_weight
        self._retrieval_weight = retrieval_weight
        self._length_weight = length_weight
        self._min_doc_length = min_doc_length
        self._max_doc_length = max_doc_length

    def score_item(self, query: str, item: MemoryItem) -> RelevanceScore:
        """Score one retrieved item against the query.

        Args:
            query: Query text.
            item: Retrieved item; ``item.value`` is stringified for text
                matching and ``item.score`` is used as the retrieval score.

        Returns:
            RelevanceScore with the clamped combined score, its verdict and
            a per-component breakdown in ``reason``.
        """
        doc_text = str(item.value)
        query_terms = self._tokenize(query)
        doc_terms = self._tokenize(doc_text)

        keyword_score = self._jaccard_similarity(query_terms, doc_terms)
        coverage_score = self._query_coverage(query_terms, doc_terms)
        retrieval_score = self._normalize_retrieval_score(item.score)
        length_score = self._length_score(len(doc_text))

        weighted_sum = (
            keyword_score * self._keyword_weight
            + coverage_score * self._coverage_weight
            + retrieval_score * self._retrieval_weight
            + length_score * self._length_weight
        )
        # Custom weights may not sum to 1; keep the score in [0.0, 1.0].
        final_score = min(1.0, max(0.0, weighted_sum))

        reason = (
            f"keyword={keyword_score:.2f}, coverage={coverage_score:.2f}, "
            f"retrieval={retrieval_score:.2f}, length={length_score:.2f}"
        )
        return RelevanceScore(
            item=item,
            score=final_score,
            verdict=self._determine_verdict(final_score),
            reason=reason,
        )

    def evaluate(
        self, query: str, items: list[MemoryItem]
    ) -> RetrievalEvaluation:
        """Evaluate the overall quality of one retrieval.

        An empty result list is INCORRECT with an average score of 0.0.
        Otherwise the overall verdict is CORRECT when the average score
        reaches the correct threshold and at least half of the documents
        are not INCORRECT; AMBIGUOUS when the average reaches the ambiguous
        threshold or at least 30% of the documents are not INCORRECT; and
        INCORRECT in every other case.
        """
        if not items:
            return RetrievalEvaluation(
                scores=[],
                overall_verdict=RelevanceVerdict.INCORRECT,
                avg_score=0.0,
                relevant_count=0,
                total_count=0,
            )

        scores = [self.score_item(query, item) for item in items]
        relevant_count = sum(
            1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
        )
        avg_score = sum(s.score for s in scores) / len(scores)
        relevant_ratio = relevant_count / len(scores)

        if (
            avg_score >= self._correct_threshold
            and relevant_ratio >= self._CORRECT_MIN_RELEVANT_RATIO
        ):
            overall_verdict = RelevanceVerdict.CORRECT
        elif (
            avg_score >= self._ambiguous_threshold
            or relevant_ratio >= self._AMBIGUOUS_MIN_RELEVANT_RATIO
        ):
            overall_verdict = RelevanceVerdict.AMBIGUOUS
        else:
            overall_verdict = RelevanceVerdict.INCORRECT

        return RetrievalEvaluation(
            scores=scores,
            overall_verdict=overall_verdict,
            avg_score=avg_score,
            relevant_count=relevant_count,
            total_count=len(scores),
        )

    def _determine_verdict(self, score: float) -> RelevanceVerdict:
        """Map a score to a verdict using the configured thresholds."""
        if score >= self._correct_threshold:
            return RelevanceVerdict.CORRECT
        if score >= self._ambiguous_threshold:
            return RelevanceVerdict.AMBIGUOUS
        return RelevanceVerdict.INCORRECT

    @staticmethod
    def _normalize_retrieval_score(raw_score: Any) -> float:
        """Clamp a retrieval score to [0.0, 1.0]; None and NaN become 0.0."""
        if raw_score is None:
            return 0.0
        value = float(raw_score)
        if math.isnan(value):
            # NaN would otherwise propagate into the score and the average.
            return 0.0
        return min(max(value, 0.0), 1.0)

    @staticmethod
    def _tokenize(text: str) -> set[str]:
        """Tokenise text into a set of lowercase terms.

        ASCII letter runs and digit runs become separate tokens
        ("python3" -> "python", "3"). CJK ideographs are emitted both as
        single characters and as bigrams of consecutive extracted
        ideographs.
        """
        tokens: set[str] = set(RelevanceScorer._WORD_RE.findall(text.lower()))
        cn_chars = RelevanceScorer._CJK_RE.findall(text)
        tokens.update(cn_chars)
        # Bigrams give multi-character Chinese words a chance to match.
        tokens.update(a + b for a, b in zip(cn_chars, cn_chars[1:]))
        return tokens

    @staticmethod
    def _jaccard_similarity(set_a: set[str], set_b: set[str]) -> float:
        """Return |A ∩ B| / |A ∪ B|, or 0.0 when either set is empty."""
        if not set_a or not set_b:
            return 0.0
        return len(set_a & set_b) / len(set_a | set_b)

    @staticmethod
    def _query_coverage(query_terms: set[str], doc_terms: set[str]) -> float:
        """Return the share of query terms present in the document."""
        if not query_terms:
            return 0.0
        return len(query_terms & doc_terms) / len(query_terms)

    def _length_score(self, length: int) -> float:
        """Score a document length; too short or too long lowers the score.

        Lengths inside ``[min_doc_length, max_doc_length]`` score 1.0.
        Shorter documents scale linearly from 0.0 up to 0.5; longer ones
        lose 0.5 per ``max_doc_length`` of excess, floored at 0.3.
        """
        if length < self._min_doc_length:
            # Too short — likely insufficient context
            return (length / self._min_doc_length) * 0.5
        if length > self._max_doc_length:
            # Too long — may contain irrelevant information
            excess = (length - self._max_doc_length) / self._max_doc_length
            return max(0.3, 1.0 - excess * 0.5)
        return 1.0
logging.getLogger(__name__) @@ -55,6 +57,8 @@ class MemoryRetriever: weights: dict[str, float] | None = None, query_transformer: QueryTransformerBase | None = None, context_template: str = "structured", + enable_self_correction: bool = False, + max_correction_retries: int = 3, ): self._working = working_memory self._episodic = episodic_memory @@ -66,6 +70,15 @@ class MemoryRetriever: } self._query_transformer = query_transformer self._context_template = context_template + self._enable_self_correction = enable_self_correction + self._correction_loop: RAGSelfCorrectionLoop | None = None + if enable_self_correction: + self._correction_loop = RAGSelfCorrectionLoop( + retriever=self, + scorer=RelevanceScorer(), + query_transformer=query_transformer, + max_retries=max_correction_retries, + ) async def retrieve( self, @@ -73,8 +86,31 @@ class MemoryRetriever: top_k: int = 5, token_budget: int = 3000, filters: dict[str, Any] | None = None, + _skip_correction: bool = False, ) -> list[MemoryItem]: - """混合检索三层记忆""" + """混合检索三层记忆 + + Args: + query: 检索查询 + top_k: 返回最大结果数 + token_budget: token 预算 + filters: 过滤条件 + _skip_correction: 内部参数,CRAG 循环内部调用时跳过自纠正 + """ + # Self-correction loop (CRAG) + if ( + self._enable_self_correction + and self._correction_loop is not None + and not _skip_correction + ): + result = await self._correction_loop.retrieve_with_correction( + query, top_k=top_k, token_budget=token_budget, filters=filters + ) + if result.degraded: + logger.warning( + f"RAG self-correction degraded after {result.total_retries} retries" + ) + return result.items # Query transformation if self._query_transformer is not None: transformed = await self._query_transformer.transform(query) diff --git a/tests/unit/test_rag_loop.py b/tests/unit/test_rag_loop.py new file mode 100644 index 0000000..332565f --- /dev/null +++ b/tests/unit/test_rag_loop.py @@ -0,0 +1,337 @@ +"""Tests for RelevanceScorer and RAGSelfCorrectionLoop""" + +import pytest + +from agentkit.memory.base import 
class TestRelevanceScorer:
    """Unit tests for RelevanceScorer."""

    def setup_method(self):
        # A fresh default-configured scorer for every test.
        self.scorer = RelevanceScorer()

    def test_score_highly_relevant_item(self):
        """A document sharing most query terms scores high."""
        doc = MemoryItem(
            key="doc1",
            value="Django and Flask are popular Python web frameworks for building web applications",
            score=0.9,
        )
        outcome = self.scorer.score_item("Python web framework Django Flask", doc)
        assert outcome.score > 0.5
        assert outcome.verdict in (RelevanceVerdict.CORRECT, RelevanceVerdict.AMBIGUOUS)

    def test_score_irrelevant_item(self):
        """A document with no query terms scores low and is INCORRECT."""
        doc = MemoryItem(
            key="doc2",
            value="The weather is sunny today and the birds are singing in the garden",
            score=0.1,
        )
        outcome = self.scorer.score_item("Python web framework", doc)
        assert outcome.score < 0.5
        assert outcome.verdict == RelevanceVerdict.INCORRECT

    def test_score_chinese_relevant_item(self):
        """Chinese text is matched through characters and bigrams."""
        doc = MemoryItem(
            key="doc3",
            value="GEO优化策略包括内容结构化、Schema标记、AI平台适配等多个方面",
            score=0.85,
        )
        outcome = self.scorer.score_item("GEO优化策略", doc)
        assert outcome.score > 0.3  # Chinese bigrams should match

    def test_score_short_document_penalty(self):
        """A very short document scores below a normal-length one."""
        question = "machine learning algorithms"
        tiny = MemoryItem(key="short", value="ML", score=0.9)
        regular = MemoryItem(
            key="normal",
            value="Machine learning algorithms include supervised and unsupervised learning methods",
            score=0.9,
        )
        tiny_outcome = self.scorer.score_item(question, tiny)
        regular_outcome = self.scorer.score_item(question, regular)
        assert regular_outcome.score > tiny_outcome.score

    def test_evaluate_empty_results(self):
        """An empty retrieval evaluates to INCORRECT with zero counts."""
        report = self.scorer.evaluate("test query", [])
        assert report.overall_verdict == RelevanceVerdict.INCORRECT
        assert report.avg_score == 0.0
        assert report.total_count == 0

    def test_evaluate_mixed_results(self):
        """A mix of good and bad documents keeps at least one relevant."""
        docs = [
            MemoryItem(key="good", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="bad", value="Weather forecast for today", score=0.1),
        ]
        report = self.scorer.evaluate("Python web framework", docs)
        assert report.total_count == 2
        assert report.relevant_count >= 1

    def test_evaluate_all_correct(self):
        """Uniformly relevant documents yield an overall CORRECT verdict."""
        docs = [
            MemoryItem(key="d1", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="d2", value="Django REST framework for API development", score=0.85),
        ]
        report = self.scorer.evaluate("Python Django", docs)
        assert report.overall_verdict == RelevanceVerdict.CORRECT

    def test_evaluate_all_incorrect(self):
        """Uniformly irrelevant documents yield an overall INCORRECT verdict."""
        docs = [
            MemoryItem(key="d1", value="Cooking recipes for beginners", score=0.1),
            MemoryItem(key="d2", value="Gardening tips for spring", score=0.05),
        ]
        report = self.scorer.evaluate("quantum computing", docs)
        assert report.overall_verdict == RelevanceVerdict.INCORRECT

    def test_custom_thresholds(self):
        """Raising both thresholds turns a middling score into INCORRECT."""
        strict = RelevanceScorer(correct_threshold=0.9, ambiguous_threshold=0.7)
        doc = MemoryItem(key="d1", value="test document with some content", score=0.5)
        outcome = strict.score_item("test", doc)
        # With high thresholds, this should be INCORRECT
        assert outcome.verdict == RelevanceVerdict.INCORRECT

    def test_jaccard_similarity(self):
        """Jaccard of two 3-element sets sharing 2 elements is 0.5."""
        left = {"python", "web", "framework"}
        right = {"python", "web", "server"}
        value = RelevanceScorer._jaccard_similarity(left, right)
        assert 0.0 < value < 1.0
        # 2 common / 4 unique = 0.5
        assert abs(value - 0.5) < 0.01

    def test_jaccard_empty_sets(self):
        """Jaccard with an empty operand on either side is 0."""
        assert RelevanceScorer._jaccard_similarity(set(), {"a"}) == 0.0
        assert RelevanceScorer._jaccard_similarity({"a"}, set()) == 0.0

    def test_query_coverage(self):
        """Coverage is the fraction of query terms found in the document."""
        wanted = {"python", "django", "flask"}
        present = {"python", "django", "web", "framework"}
        value = RelevanceScorer._query_coverage(wanted, present)
        # 2 out of 3 query terms covered
        assert abs(value - 2 / 3) < 0.01

    def test_tokenize_chinese(self):
        """Chinese tokenization yields single characters and bigrams."""
        terms = RelevanceScorer._tokenize("机器学习算法")
        # Should include individual chars and bigrams
        assert "机" in terms
        assert "器" in terms
        assert "机器" in terms  # bigram

    def test_tokenize_english(self):
        """English tokenization lowercases each word."""
        terms = RelevanceScorer._tokenize("Python Web Framework")
        assert "python" in terms
        assert "web" in terms
        assert "framework" in terms
class MockRetriever:
    """In-memory stand-in for MemoryRetriever that records every call."""

    def __init__(self, items_by_query: dict[str, list[MemoryItem]] | None = None):
        self._items = items_by_query or {}
        self.call_count = 0
        self.queries: list[str] = []

    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters=None,
        _skip_correction: bool = False,
    ) -> list[MemoryItem]:
        """Return the items registered for ``query``.

        Falls back to the items of the first registered query, and to an
        empty list when nothing is registered.
        """
        self.call_count += 1
        self.queries.append(query)
        try:
            return self._items[query]
        except KeyError:
            pass
        # Unknown query: serve the first registered entry as the default.
        default_key = next(iter(self._items), None)
        if default_key:
            return self._items[default_key]
        return []


class TestRAGSelfCorrectionLoop:
    """Unit tests for RAGSelfCorrectionLoop."""

    @pytest.mark.asyncio
    async def test_correct_retrieval_skips_correction(self):
        """A high-quality first retrieval returns without any retry."""
        docs = [
            MemoryItem(
                key="d1",
                value="Django is a Python web framework for building web applications quickly",
                score=0.9,
            ),
            MemoryItem(
                key="d2",
                value="Flask is a lightweight Python web framework for small applications",
                score=0.85,
            ),
        ]
        retriever = MockRetriever({"Python web framework": docs})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=3)

        outcome = await loop.retrieve_with_correction("Python web framework")
        assert not outcome.degraded
        assert len(outcome.items) == 2
        assert outcome.total_retries == 0

    @pytest.mark.asyncio
    async def test_poor_retrieval_triggers_correction(self):
        """A poor first retrieval leads to at least one rewritten query."""
        weak_docs = [
            MemoryItem(key="d1", value="Weather forecast for today", score=0.1),
        ]
        strong_docs = [
            MemoryItem(
                key="d2",
                value="Python Django web framework tutorial and best practices",
                score=0.9,
            ),
        ]
        retriever = MockRetriever({
            "Python web framework": weak_docs,
            "improved query": strong_docs,
        })
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=3)

        outcome = await loop.retrieve_with_correction("Python web framework")
        assert retriever.call_count >= 2  # At least initial + 1 retry
        assert len(outcome.attempts) >= 2

    @pytest.mark.asyncio
    async def test_max_retries_causes_degradation(self):
        """Running out of retries produces a degraded, flagged result."""
        weak_docs = [
            MemoryItem(key="d1", value="Unrelated content about weather", score=0.05),
        ]
        retriever = MockRetriever({"any query": weak_docs})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=2)

        outcome = await loop.retrieve_with_correction("Python web framework")
        assert outcome.degraded
        assert outcome.total_retries >= 2
        # Items should be marked low_confidence
        assert any(
            doc.metadata.get("low_confidence", False) for doc in outcome.items
        )

    @pytest.mark.asyncio
    async def test_empty_retrieval_triggers_correction(self):
        """An empty retrieval is retried and finally degraded."""
        retriever = MockRetriever({"query": []})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=2)

        outcome = await loop.retrieve_with_correction("test query")
        assert outcome.degraded
        assert outcome.total_retries >= 1

    @pytest.mark.asyncio
    async def test_loop_result_tracks_attempts(self):
        """The result records every attempt, starting with the original query."""
        docs = [
            MemoryItem(key="d1", value="Relevant Python content", score=0.9),
        ]
        retriever = MockRetriever({"test": docs})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=3)

        outcome = await loop.retrieve_with_correction("test")
        assert len(outcome.attempts) >= 1
        first = outcome.attempts[0]
        assert first.query == "test"
        assert first.state in (
            LoopState.GENERATE,
            LoopState.CORRECT,
            LoopState.DEGRADE,
        )

    @pytest.mark.asyncio
    async def test_correction_with_query_transformer(self):
        """The configured query transformer is invoked during correction."""
        from agentkit.memory.query_transformer import TransformedQuery, QueryTransformerBase

        class MockTransformer(QueryTransformerBase):
            def __init__(self):
                self.transform_count = 0

            async def transform(self, query: str) -> TransformedQuery:
                self.transform_count += 1
                return TransformedQuery(
                    main_query=f"improved {query}",
                    sub_queries=[f"sub-{query}"],
                    original_query=query,
                )

        weak_docs = [
            MemoryItem(key="d1", value="Unrelated", score=0.05),
        ]
        strong_docs = [
            MemoryItem(key="d2", value="Relevant Python content", score=0.9),
        ]
        retriever = MockRetriever({
            "test": weak_docs,
            "sub-test": strong_docs,
        })
        transformer = MockTransformer()
        loop = RAGSelfCorrectionLoop(
            retriever=retriever,
            query_transformer=transformer,
            max_retries=3,
        )

        outcome = await loop.retrieve_with_correction("test")
        assert transformer.transform_count >= 1

    @pytest.mark.asyncio
    async def test_degraded_result_filters_irrelevant(self):
        """A degraded result keeps the relevant item."""
        docs = [
            MemoryItem(key="good", value="Python Django framework", score=0.8),
            MemoryItem(key="bad", value="Weather forecast", score=0.05),
        ]
        retriever = MockRetriever({"query": docs})
        loop = RAGSelfCorrectionLoop(retriever=retriever, max_retries=1)

        outcome = await loop.retrieve_with_correction("Python framework")
        # Even if degraded, should prefer relevant items
        if outcome.degraded:
            kept_keys = [doc.key for doc in outcome.items]
            assert "good" in kept_keys