feat(memory): U1 RAG self-correction loop (CRAG)

- RelevanceScorer: keyword overlap + query coverage + retrieval score + length penalty
- RAGSelfCorrectionLoop: state machine driven retrieve-evaluate-correct-degrade cycle
- Integrated into MemoryRetriever with enable_self_correction option
- 21 tests passing
This commit is contained in:
chiguyong 2026-06-06 22:16:23 +08:00
parent 468dfd71e8
commit a6c9babfdc
4 changed files with 826 additions and 1 deletions

View File

@ -0,0 +1,237 @@
"""RAGSelfCorrectionLoop - CRAG 自纠正循环
实现 Corrective RAG 模式检索评估纠正/降级生成
当检索结果质量不足时自动改写查询重新检索形成自纠正闭环
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from agentkit.memory.base import MemoryItem
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
from agentkit.memory.relevance_scorer import (
RelevanceScorer,
RelevanceVerdict,
RetrievalEvaluation,
)
logger = logging.getLogger(__name__)
class LoopState(str, Enum):
"""自纠正循环状态"""
RETRIEVE = "retrieve"
EVALUATE = "evaluate"
CORRECT = "correct"
DEGRADE = "degrade"
GENERATE = "generate"
@dataclass
class CorrectionAttempt:
"""一次纠正尝试的记录"""
query: str
evaluation: RetrievalEvaluation
state: LoopState
@dataclass
class RAGLoopResult:
"""自纠正循环的最终结果"""
items: list[MemoryItem]
evaluation: RetrievalEvaluation
attempts: list[CorrectionAttempt]
corrected: bool
degraded: bool
total_retries: int
class RAGSelfCorrectionLoop:
"""CRAG 自纠正循环
状态机驱动的检索-评估-纠正循环
1. RETRIEVE: 使用 MemoryRetriever 检索
2. EVALUATE: RelevanceScorer 评估检索质量
3. CORRECT: 质量不足时改写查询重新检索
4. DEGRADE: 超过重试次数返回降级结果
5. GENERATE: 质量足够返回结果
熔断机制
- max_retries: 最大重试次数默认 3
- 超过重试次数后强制降级标记 low_confidence
"""
def __init__(
self,
retriever: Any, # MemoryRetriever
scorer: RelevanceScorer | None = None,
query_transformer: QueryTransformerBase | None = None,
max_retries: int = 3,
min_items_for_correct: int = 1,
):
self._retriever = retriever
self._scorer = scorer or RelevanceScorer()
self._query_transformer = query_transformer or NoOpQueryTransformer()
self._max_retries = max_retries
self._min_items_for_correct = min_items_for_correct
async def retrieve_with_correction(
self,
query: str,
top_k: int = 5,
token_budget: int = 3000,
filters: dict[str, Any] | None = None,
) -> RAGLoopResult:
"""执行带自纠正的检索
Args:
query: 原始查询
top_k: 返回的最大结果数
token_budget: token 预算
filters: 过滤条件
Returns:
RAGLoopResult: 包含检索结果评估尝试记录
"""
attempts: list[CorrectionAttempt] = []
current_query = query
retry_count = 0
while retry_count <= self._max_retries:
# RETRIEVE
items = await self._retriever.retrieve(
current_query, top_k=top_k, token_budget=token_budget,
filters=filters, _skip_correction=True,
)
# EVALUATE
evaluation = self._scorer.evaluate(current_query, items)
state = self._determine_next_state(evaluation, items)
attempt = CorrectionAttempt(
query=current_query,
evaluation=evaluation,
state=state,
)
attempts.append(attempt)
logger.info(
f"RAG loop attempt {retry_count + 1}: "
f"query='{current_query[:50]}...', "
f"verdict={evaluation.overall_verdict.value}, "
f"avg_score={evaluation.avg_score:.2f}, "
f"state={state.value}"
)
# GENERATE — quality is sufficient
if state == LoopState.GENERATE:
return RAGLoopResult(
items=items,
evaluation=evaluation,
attempts=attempts,
corrected=retry_count > 0,
degraded=False,
total_retries=retry_count,
)
# CORRECT — rewrite query and retry
retry_count += 1
if retry_count <= self._max_retries:
current_query = await self._rewrite_query(
query, current_query, evaluation
)
continue
# DEGRADE — exceeded max retries
break
# Degraded result: filter to relevant items and mark low confidence
relevant_items = [
s.item
for s in evaluation.scores
if s.verdict != RelevanceVerdict.INCORRECT
]
result_items = relevant_items if relevant_items else items
for item in result_items:
item.metadata["low_confidence"] = True
return RAGLoopResult(
items=result_items,
evaluation=evaluation,
attempts=attempts,
corrected=False,
degraded=True,
total_retries=retry_count,
)
def _determine_next_state(
self, evaluation: RetrievalEvaluation, items: list[MemoryItem]
) -> LoopState:
"""根据评估结果确定下一个状态"""
verdict = evaluation.overall_verdict
if verdict == RelevanceVerdict.CORRECT:
if evaluation.relevant_count >= self._min_items_for_correct:
return LoopState.GENERATE
# Correct verdict but not enough items — still try to generate
if items:
return LoopState.GENERATE
return LoopState.CORRECT
if verdict == RelevanceVerdict.AMBIGUOUS:
# Some relevant results — could improve but not terrible
return LoopState.CORRECT
# INCORRECT — definitely need correction
return LoopState.CORRECT
async def _rewrite_query(
self,
original_query: str,
current_query: str,
evaluation: RetrievalEvaluation,
) -> str:
"""改写查询以改善检索质量
策略
1. 使用 QueryTransformer 改写
2. 从评估结果中提取改进线索
3. 追加失败模式提示
"""
# Use query transformer for rewriting
transformed = await self._query_transformer.transform(current_query)
new_query = transformed.main_query
# If transformer didn't change the query, try with original
if new_query == current_query:
# Add context from failed evaluation to help next retrieval
failed_terms = []
for score in evaluation.scores:
if score.verdict == RelevanceVerdict.INCORRECT:
# Extract key terms from low-scoring items to avoid
doc_text = str(score.item.value)[:100]
failed_terms.append(doc_text)
if failed_terms and original_query != current_query:
# Try original query as fallback
new_query = original_query
elif failed_terms:
# Add "NOT" context to help filter
new_query = f"{current_query} (excluding irrelevant results)"
# Add sub-queries if available
if transformed.sub_queries:
# Use the first sub-query as the new primary query
# This explores different aspects of the original question
new_query = transformed.sub_queries[0]
logger.info(f"Query rewritten: '{current_query[:50]}...' -> '{new_query[:50]}...'")
return new_query

View File

@ -0,0 +1,215 @@
"""RelevanceScorer - 检索结果相关性自动评估
对检索结果逐文档评估与查询的相关性用于 CRAG 自纠正循环的评估阶段
"""
from __future__ import annotations
import logging
import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Any
from agentkit.memory.base import MemoryItem
logger = logging.getLogger(__name__)
class RelevanceVerdict(str, Enum):
"""相关性判定结果"""
CORRECT = "correct"
AMBIGUOUS = "ambiguous"
INCORRECT = "incorrect"
@dataclass
class RelevanceScore:
"""单个文档的相关性评分"""
item: MemoryItem
score: float # 0.0 ~ 1.0
verdict: RelevanceVerdict
reason: str = ""
@dataclass
class RetrievalEvaluation:
"""一次检索的整体评估结果"""
scores: list[RelevanceScore]
overall_verdict: RelevanceVerdict
avg_score: float
relevant_count: int
total_count: int
class RelevanceScorer:
"""检索结果相关性评估器
基于查询-文档语义相似度和关键词重叠的轻量级评估器
不依赖 LLM 调用适用于生产环境的低延迟评估
评分策略
1. 关键词重叠率Jaccard 相似度
2. 查询词覆盖率query term coverage
3. 原始检索分数加权
4. 长度惩罚过短或过长的文档降分
"""
def __init__(
self,
correct_threshold: float = 0.6,
ambiguous_threshold: float = 0.35,
keyword_weight: float = 0.3,
coverage_weight: float = 0.3,
retrieval_weight: float = 0.3,
length_weight: float = 0.1,
min_doc_length: int = 20,
max_doc_length: int = 5000,
):
self._correct_threshold = correct_threshold
self._ambiguous_threshold = ambiguous_threshold
self._keyword_weight = keyword_weight
self._coverage_weight = coverage_weight
self._retrieval_weight = retrieval_weight
self._length_weight = length_weight
self._min_doc_length = min_doc_length
self._max_doc_length = max_doc_length
def score_item(self, query: str, item: MemoryItem) -> RelevanceScore:
"""评估单个检索结果与查询的相关性"""
doc_text = str(item.value)
# 1. Keyword overlap (Jaccard similarity)
query_terms = self._tokenize(query)
doc_terms = self._tokenize(doc_text)
keyword_score = self._jaccard_similarity(query_terms, doc_terms)
# 2. Query term coverage
coverage_score = self._query_coverage(query_terms, doc_terms)
# 3. Original retrieval score
retrieval_score = min(max(item.score, 0.0), 1.0)
# 4. Length penalty
length_score = self._length_score(len(doc_text))
# Weighted combination
final_score = (
keyword_score * self._keyword_weight
+ coverage_score * self._coverage_weight
+ retrieval_score * self._retrieval_weight
+ length_score * self._length_weight
)
# Determine verdict
verdict = self._determine_verdict(final_score)
reason = (
f"keyword={keyword_score:.2f}, coverage={coverage_score:.2f}, "
f"retrieval={retrieval_score:.2f}, length={length_score:.2f}"
)
return RelevanceScore(
item=item,
score=final_score,
verdict=verdict,
reason=reason,
)
def evaluate(
self, query: str, items: list[MemoryItem]
) -> RetrievalEvaluation:
"""评估一次检索的整体质量"""
if not items:
return RetrievalEvaluation(
scores=[],
overall_verdict=RelevanceVerdict.INCORRECT,
avg_score=0.0,
relevant_count=0,
total_count=0,
)
scores = [self.score_item(query, item) for item in items]
relevant_count = sum(
1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
)
avg_score = sum(s.score for s in scores) / len(scores)
# Overall verdict based on average score and relevant ratio
relevant_ratio = relevant_count / len(scores)
if avg_score >= self._correct_threshold and relevant_ratio >= 0.5:
overall_verdict = RelevanceVerdict.CORRECT
elif avg_score >= self._ambiguous_threshold or relevant_ratio >= 0.3:
overall_verdict = RelevanceVerdict.AMBIGUOUS
else:
overall_verdict = RelevanceVerdict.INCORRECT
return RetrievalEvaluation(
scores=scores,
overall_verdict=overall_verdict,
avg_score=avg_score,
relevant_count=relevant_count,
total_count=len(scores),
)
def _determine_verdict(self, score: float) -> RelevanceVerdict:
"""根据分数判定相关性"""
if score >= self._correct_threshold:
return RelevanceVerdict.CORRECT
elif score >= self._ambiguous_threshold:
return RelevanceVerdict.AMBIGUOUS
else:
return RelevanceVerdict.INCORRECT
@staticmethod
def _tokenize(text: str) -> set[str]:
"""分词:中文按字符,英文按空格,统一小写"""
tokens: set[str] = set()
# Extract English words
en_words = re.findall(r"[a-zA-Z]+", text.lower())
tokens.update(en_words)
# Extract Chinese characters (individual chars + bigrams)
cn_chars = re.findall(r"[\u4e00-\u9fff]", text)
tokens.update(cn_chars)
# Add Chinese bigrams for better matching
for i in range(len(cn_chars) - 1):
tokens.add(cn_chars[i] + cn_chars[i + 1])
return tokens
@staticmethod
def _jaccard_similarity(set_a: set[str], set_b: set[str]) -> float:
"""Jaccard 相似度"""
if not set_a or not set_b:
return 0.0
intersection = len(set_a & set_b)
union = len(set_a | set_b)
if union == 0:
return 0.0
return intersection / union
@staticmethod
def _query_coverage(query_terms: set[str], doc_terms: set[str]) -> float:
"""查询词覆盖率:文档中出现的查询词比例"""
if not query_terms:
return 0.0
covered = len(query_terms & doc_terms)
return covered / len(query_terms)
def _length_score(self, length: int) -> float:
"""长度评分:过短或过长的文档降分"""
if length < self._min_doc_length:
# Too short — likely insufficient context
ratio = length / self._min_doc_length
return ratio * 0.5
elif length > self._max_doc_length:
# Too long — may contain irrelevant information
excess = (length - self._max_doc_length) / self._max_doc_length
return max(0.3, 1.0 - excess * 0.5)
else:
# Good length range
return 1.0

View File

@ -17,6 +17,8 @@ from agentkit.memory.working import WorkingMemory
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.query_transformer import QueryTransformerBase
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
from agentkit.memory.relevance_scorer import RelevanceScorer
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
@ -55,6 +57,8 @@ class MemoryRetriever:
weights: dict[str, float] | None = None,
query_transformer: QueryTransformerBase | None = None,
context_template: str = "structured",
enable_self_correction: bool = False,
max_correction_retries: int = 3,
):
self._working = working_memory
self._episodic = episodic_memory
@ -66,6 +70,15 @@ class MemoryRetriever:
}
self._query_transformer = query_transformer
self._context_template = context_template
self._enable_self_correction = enable_self_correction
self._correction_loop: RAGSelfCorrectionLoop | None = None
if enable_self_correction:
self._correction_loop = RAGSelfCorrectionLoop(
retriever=self,
scorer=RelevanceScorer(),
query_transformer=query_transformer,
max_retries=max_correction_retries,
)
async def retrieve(
self,
@ -73,8 +86,31 @@ class MemoryRetriever:
top_k: int = 5,
token_budget: int = 3000,
filters: dict[str, Any] | None = None,
_skip_correction: bool = False,
) -> list[MemoryItem]:
"""混合检索三层记忆"""
"""混合检索三层记忆
Args:
query: 检索查询
top_k: 返回最大结果数
token_budget: token 预算
filters: 过滤条件
_skip_correction: 内部参数CRAG 循环内部调用时跳过自纠正
"""
# Self-correction loop (CRAG)
if (
self._enable_self_correction
and self._correction_loop is not None
and not _skip_correction
):
result = await self._correction_loop.retrieve_with_correction(
query, top_k=top_k, token_budget=token_budget, filters=filters
)
if result.degraded:
logger.warning(
f"RAG self-correction degraded after {result.total_retries} retries"
)
return result.items
# Query transformation
if self._query_transformer is not None:
transformed = await self._query_transformer.transform(query)

337
tests/unit/test_rag_loop.py Normal file
View File

@ -0,0 +1,337 @@
"""Tests for RelevanceScorer and RAGSelfCorrectionLoop"""
import pytest
from agentkit.memory.base import MemoryItem
from agentkit.memory.relevance_scorer import (
RelevanceScorer,
RelevanceScore,
RelevanceVerdict,
RetrievalEvaluation,
)
from agentkit.memory.rag_loop import (
RAGSelfCorrectionLoop,
RAGLoopResult,
LoopState,
)
# --- RelevanceScorer Tests ---
class TestRelevanceScorer:
"""RelevanceScorer unit tests"""
def setup_method(self):
self.scorer = RelevanceScorer()
def test_score_highly_relevant_item(self):
"""Highly relevant document should score high"""
query = "Python web framework Django Flask"
item = MemoryItem(
key="doc1",
value="Django and Flask are popular Python web frameworks for building web applications",
score=0.9,
)
result = self.scorer.score_item(query, item)
assert result.score > 0.5
assert result.verdict in (RelevanceVerdict.CORRECT, RelevanceVerdict.AMBIGUOUS)
def test_score_irrelevant_item(self):
"""Completely irrelevant document should score low"""
query = "Python web framework"
item = MemoryItem(
key="doc2",
value="The weather is sunny today and the birds are singing in the garden",
score=0.1,
)
result = self.scorer.score_item(query, item)
assert result.score < 0.5
assert result.verdict == RelevanceVerdict.INCORRECT
def test_score_chinese_relevant_item(self):
"""Chinese text relevance scoring"""
query = "GEO优化策略"
item = MemoryItem(
key="doc3",
value="GEO优化策略包括内容结构化、Schema标记、AI平台适配等多个方面",
score=0.85,
)
result = self.scorer.score_item(query, item)
assert result.score > 0.3 # Chinese bigrams should match
def test_score_short_document_penalty(self):
"""Very short documents should be penalized"""
query = "machine learning algorithms"
short_item = MemoryItem(key="short", value="ML", score=0.9)
normal_item = MemoryItem(
key="normal",
value="Machine learning algorithms include supervised and unsupervised learning methods",
score=0.9,
)
short_result = self.scorer.score_item(query, short_item)
normal_result = self.scorer.score_item(query, normal_item)
assert normal_result.score > short_result.score
def test_evaluate_empty_results(self):
"""Empty retrieval results should be INCORRECT"""
evaluation = self.scorer.evaluate("test query", [])
assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT
assert evaluation.avg_score == 0.0
assert evaluation.total_count == 0
def test_evaluate_mixed_results(self):
"""Mixed quality results should be AMBIGUOUS or CORRECT"""
query = "Python web framework"
items = [
MemoryItem(key="good", value="Django is a Python web framework", score=0.9),
MemoryItem(key="bad", value="Weather forecast for today", score=0.1),
]
evaluation = self.scorer.evaluate(query, items)
assert evaluation.total_count == 2
assert evaluation.relevant_count >= 1
def test_evaluate_all_correct(self):
"""All relevant results should give CORRECT verdict"""
query = "Python Django"
items = [
MemoryItem(key="d1", value="Django is a Python web framework", score=0.9),
MemoryItem(key="d2", value="Django REST framework for API development", score=0.85),
]
evaluation = self.scorer.evaluate(query, items)
assert evaluation.overall_verdict == RelevanceVerdict.CORRECT
def test_evaluate_all_incorrect(self):
"""All irrelevant results should give INCORRECT verdict"""
query = "quantum computing"
items = [
MemoryItem(key="d1", value="Cooking recipes for beginners", score=0.1),
MemoryItem(key="d2", value="Gardening tips for spring", score=0.05),
]
evaluation = self.scorer.evaluate(query, items)
assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT
def test_custom_thresholds(self):
"""Custom thresholds should affect verdict"""
scorer = RelevanceScorer(correct_threshold=0.9, ambiguous_threshold=0.7)
query = "test"
item = MemoryItem(key="d1", value="test document with some content", score=0.5)
result = scorer.score_item(query, item)
# With high thresholds, this should be INCORRECT
assert result.verdict == RelevanceVerdict.INCORRECT
def test_jaccard_similarity(self):
"""Jaccard similarity calculation"""
set_a = {"python", "web", "framework"}
set_b = {"python", "web", "server"}
similarity = RelevanceScorer._jaccard_similarity(set_a, set_b)
assert 0.0 < similarity < 1.0
# 2 common / 4 unique = 0.5
assert abs(similarity - 0.5) < 0.01
def test_jaccard_empty_sets(self):
"""Jaccard with empty sets returns 0"""
assert RelevanceScorer._jaccard_similarity(set(), {"a"}) == 0.0
assert RelevanceScorer._jaccard_similarity({"a"}, set()) == 0.0
def test_query_coverage(self):
"""Query term coverage calculation"""
query_terms = {"python", "django", "flask"}
doc_terms = {"python", "django", "web", "framework"}
coverage = RelevanceScorer._query_coverage(query_terms, doc_terms)
# 2 out of 3 query terms covered
assert abs(coverage - 2 / 3) < 0.01
def test_tokenize_chinese(self):
"""Chinese tokenization includes bigrams"""
tokens = RelevanceScorer._tokenize("机器学习算法")
# Should include individual chars and bigrams
assert "" in tokens
assert "" in tokens
assert "机器" in tokens # bigram
def test_tokenize_english(self):
"""English tokenization"""
tokens = RelevanceScorer._tokenize("Python Web Framework")
assert "python" in tokens
assert "web" in tokens
assert "framework" in tokens
# --- RAGSelfCorrectionLoop Tests ---
class MockRetriever:
"""Mock retriever for testing"""
def __init__(self, items_by_query: dict[str, list[MemoryItem]] | None = None):
self._items = items_by_query or {}
self.call_count = 0
self.queries: list[str] = []
async def retrieve(
self,
query: str,
top_k: int = 5,
token_budget: int = 3000,
filters=None,
_skip_correction: bool = False,
) -> list[MemoryItem]:
self.call_count += 1
self.queries.append(query)
# Return items for exact query match, or default items
if query in self._items:
return self._items[query]
# Return default items for any query
default_key = next(iter(self._items), None)
if default_key:
return self._items[default_key]
return []
class TestRAGSelfCorrectionLoop:
"""RAGSelfCorrectionLoop unit tests"""
@pytest.mark.asyncio
async def test_correct_retrieval_skips_correction(self):
"""High-quality retrieval should not trigger correction"""
items = [
MemoryItem(
key="d1",
value="Django is a Python web framework for building web applications quickly",
score=0.9,
),
MemoryItem(
key="d2",
value="Flask is a lightweight Python web framework for small applications",
score=0.85,
),
]
mock = MockRetriever({"Python web framework": items})
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)
result = await loop.retrieve_with_correction("Python web framework")
assert not result.degraded
assert len(result.items) == 2
assert result.total_retries == 0
@pytest.mark.asyncio
async def test_poor_retrieval_triggers_correction(self):
"""Poor retrieval should trigger query rewriting"""
poor_items = [
MemoryItem(key="d1", value="Weather forecast for today", score=0.1),
]
good_items = [
MemoryItem(
key="d2",
value="Python Django web framework tutorial and best practices",
score=0.9,
),
]
mock = MockRetriever({
"Python web framework": poor_items,
"improved query": good_items,
})
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)
result = await loop.retrieve_with_correction("Python web framework")
assert mock.call_count >= 2 # At least initial + 1 retry
assert len(result.attempts) >= 2
@pytest.mark.asyncio
async def test_max_retries_causes_degradation(self):
"""Exceeding max retries should cause degradation"""
poor_items = [
MemoryItem(key="d1", value="Unrelated content about weather", score=0.05),
]
mock = MockRetriever({"any query": poor_items})
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)
result = await loop.retrieve_with_correction("Python web framework")
assert result.degraded
assert result.total_retries >= 2
# Items should be marked low_confidence
assert any(
item.metadata.get("low_confidence", False) for item in result.items
)
@pytest.mark.asyncio
async def test_empty_retrieval_triggers_correction(self):
"""Empty retrieval results should trigger correction"""
mock = MockRetriever({"query": []})
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)
result = await loop.retrieve_with_correction("test query")
assert result.degraded
assert result.total_retries >= 1
@pytest.mark.asyncio
async def test_loop_result_tracks_attempts(self):
"""Loop result should track all correction attempts"""
items = [
MemoryItem(key="d1", value="Relevant Python content", score=0.9),
]
mock = MockRetriever({"test": items})
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)
result = await loop.retrieve_with_correction("test")
assert len(result.attempts) >= 1
assert result.attempts[0].query == "test"
assert result.attempts[0].state in (
LoopState.GENERATE,
LoopState.CORRECT,
LoopState.DEGRADE,
)
@pytest.mark.asyncio
async def test_correction_with_query_transformer(self):
"""Query transformer should be used during correction"""
from agentkit.memory.query_transformer import TransformedQuery, QueryTransformerBase
class MockTransformer(QueryTransformerBase):
def __init__(self):
self.transform_count = 0
async def transform(self, query: str) -> TransformedQuery:
self.transform_count += 1
return TransformedQuery(
main_query=f"improved {query}",
sub_queries=[f"sub-{query}"],
original_query=query,
)
poor_items = [
MemoryItem(key="d1", value="Unrelated", score=0.05),
]
good_items = [
MemoryItem(key="d2", value="Relevant Python content", score=0.9),
]
mock = MockRetriever({
"test": poor_items,
"sub-test": good_items,
})
transformer = MockTransformer()
loop = RAGSelfCorrectionLoop(
retriever=mock,
query_transformer=transformer,
max_retries=3,
)
result = await loop.retrieve_with_correction("test")
assert transformer.transform_count >= 1
@pytest.mark.asyncio
async def test_degraded_result_filters_irrelevant(self):
"""Degraded result should prefer relevant items over irrelevant"""
mixed_items = [
MemoryItem(key="good", value="Python Django framework", score=0.8),
MemoryItem(key="bad", value="Weather forecast", score=0.05),
]
mock = MockRetriever({"query": mixed_items})
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=1)
result = await loop.retrieve_with_correction("Python framework")
# Even if degraded, should prefer relevant items
if result.degraded:
relevant_keys = [item.key for item in result.items]
assert "good" in relevant_keys