"""RAGSelfCorrectionLoop - CRAG self-correction loop.

Implements the Corrective RAG pattern: retrieve -> evaluate -> correct/degrade
-> generate. When the retrieved results are judged insufficient, the query is
rewritten and retrieval is repeated, forming a self-correcting loop.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING

from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
from agentkit.memory.relevance_scorer import (
    RelevanceScorer,
    RelevanceVerdict,
    RetrievalEvaluation,
)

if TYPE_CHECKING:
    # Imported for type hints only, to avoid a runtime circular import with
    # retriever.py.
    from agentkit.memory.retriever import MemoryRetriever

logger = logging.getLogger(__name__)


class LoopState(str, Enum):
    """States of the self-correction loop.

    ``RAGSelfCorrectionLoop._determine_next_state`` only ever returns
    ``GENERATE`` or ``CORRECT``; the other members name the conceptual phases
    listed in the ``RAGSelfCorrectionLoop`` docstring.
    """

    RETRIEVE = "retrieve"
    EVALUATE = "evaluate"
    CORRECT = "correct"
    DEGRADE = "degrade"
    GENERATE = "generate"


@dataclass
class CorrectionAttempt:
    """Record of a single retrieve/evaluate pass."""

    # Query string used for this pass.
    query: str
    # Scorer's evaluation of the items retrieved for ``query``.
    evaluation: RetrievalEvaluation
    # Next state decided after evaluating this pass.
    state: LoopState


@dataclass
class RAGLoopResult:
    """Final outcome of the self-correction loop."""

    # Items handed back to the caller.
    items: list[MemoryItem]
    # Evaluation of the last retrieval pass.
    evaluation: RetrievalEvaluation
    # Every pass in order; the first entry used the caller's original query.
    attempts: list[CorrectionAttempt]
    # True when a retry (rewritten query) produced the accepted result.
    corrected: bool
    # True when retries were exhausted and a low-confidence result was returned.
    degraded: bool
    # Retries performed after the initial attempt, i.e. len(attempts) - 1.
    total_retries: int


class RAGSelfCorrectionLoop:
    """CRAG self-correction loop.

    A state-machine-driven retrieve/evaluate/correct loop:

    1. RETRIEVE: retrieve with the ``MemoryRetriever``.
    2. EVALUATE: score retrieval quality with the ``RelevanceScorer``.
    3. CORRECT: when quality is insufficient, rewrite the query and retrieve again.
    4. DEGRADE: once retries are exhausted, return a degraded result.
    5. GENERATE: when quality is sufficient, return the result.

    Circuit breaker:
    - ``max_retries``: maximum number of retries (default 3).
    - Once retries are exhausted the loop is forced to degrade and every
      returned item gets ``metadata["low_confidence"] = True``.
    """

    def __init__(
        self,
        retriever: MemoryRetriever,
        scorer: RelevanceScorer | None = None,
        query_transformer: QueryTransformerBase | None = None,
        max_retries: int = 3,
        min_items_for_correct: int = 1,
    ):
        """Create the loop.

        Args:
            retriever: Retriever queried on every pass (called with
                ``_skip_correction=True``).
            scorer: Relevance scorer; ``RelevanceScorer()`` when None.
            query_transformer: Rewrites queries between passes;
                ``NoOpQueryTransformer()`` when None.
            max_retries: Maximum number of retries after the initial attempt,
                so at most ``max_retries + 1`` retrievals are performed.
            min_items_for_correct: Relevant-item count at which a CORRECT
                verdict is accepted (see ``_determine_next_state``).

        Raises:
            ValueError: If ``max_retries`` is negative.
        """
        if max_retries < 0:
            # A negative budget would skip the loop entirely and leave no
            # retrieval result to return.
            raise ValueError(f"max_retries must be >= 0, got {max_retries}")
        self._retriever = retriever
        self._scorer = scorer or RelevanceScorer()
        self._query_transformer = query_transformer or NoOpQueryTransformer()
        self._max_retries = max_retries
        self._min_items_for_correct = min_items_for_correct

    async def retrieve_with_correction(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters: MetadataDict | None = None,
    ) -> RAGLoopResult:
        """Run retrieval with self-correction.

        Args:
            query: The caller's original query.
            top_k: Maximum number of results per retrieval.
            token_budget: Token budget passed to the retriever.
            filters: Metadata filters passed to the retriever.

        Returns:
            RAGLoopResult with the returned items, the last evaluation and
            the record of every attempt.
        """
        attempts: list[CorrectionAttempt] = []
        current_query = query

        # One initial attempt plus up to ``max_retries`` retries; __init__
        # guarantees at least one iteration, so ``items``/``evaluation`` are
        # always bound after the loop.
        for retry_index in range(self._max_retries + 1):
            # RETRIEVE
            items = await self._retriever.retrieve(
                current_query,
                top_k=top_k,
                token_budget=token_budget,
                filters=filters,
                _skip_correction=True,
            )

            # EVALUATE
            evaluation = self._scorer.evaluate(current_query, items)
            state = self._determine_next_state(evaluation, items)
            attempts.append(
                CorrectionAttempt(query=current_query, evaluation=evaluation, state=state)
            )

            logger.info(
                "RAG loop attempt %d: query='%s...', verdict=%s, avg_score=%.2f, state=%s",
                retry_index + 1,
                current_query[:50],
                evaluation.overall_verdict.value,
                evaluation.avg_score,
                state.value,
            )

            # GENERATE — quality is sufficient
            if state == LoopState.GENERATE:
                return RAGLoopResult(
                    items=items,
                    evaluation=evaluation,
                    attempts=attempts,
                    corrected=retry_index > 0,
                    degraded=False,
                    total_retries=retry_index,
                )

            # CORRECT — rewrite the query only if another pass will use it.
            if retry_index < self._max_retries:
                current_query = await self._rewrite_query(query, current_query, evaluation)

        # DEGRADE — retries exhausted. Keep the items the scorer did not mark
        # INCORRECT (falling back to everything retrieved if none remain) and
        # flag them as low confidence.
        relevant_items = [
            score.item
            for score in evaluation.scores
            if score.verdict != RelevanceVerdict.INCORRECT
        ]
        result_items = relevant_items or items
        for item in result_items:
            # NOTE(review): this mutates the item objects in place; if the
            # retriever returns shared/stored instances, the flag outlives
            # this call — confirm that is intended.
            item.metadata["low_confidence"] = True

        return RAGLoopResult(
            items=result_items,
            evaluation=evaluation,
            attempts=attempts,
            corrected=False,
            degraded=True,
            # Retries actually performed (the initial attempt is not a retry).
            total_retries=len(attempts) - 1,
        )

    def _determine_next_state(
        self, evaluation: RetrievalEvaluation, items: list[MemoryItem]
    ) -> LoopState:
        """Decide the next loop state from an evaluation.

        Returns GENERATE when the overall verdict is CORRECT and either at
        least ``min_items_for_correct`` items are relevant or anything was
        retrieved. Otherwise (AMBIGUOUS, INCORRECT, or CORRECT with nothing
        retrieved) returns CORRECT, i.e. "rewrite and retry".
        """
        if evaluation.overall_verdict == RelevanceVerdict.CORRECT:
            # NOTE(review): any non-empty result also generates, so
            # min_items_for_correct only matters when nothing was retrieved —
            # confirm this is intended.
            if evaluation.relevant_count >= self._min_items_for_correct or items:
                return LoopState.GENERATE
        # AMBIGUOUS / INCORRECT (or CORRECT with nothing retrieved): retry.
        return LoopState.CORRECT

    async def _rewrite_query(
        self,
        original_query: str,
        current_query: str,
        evaluation: RetrievalEvaluation,
    ) -> str:
        """Rewrite the query to improve the next retrieval.

        Strategy:
        1. Ask the query transformer to rewrite ``current_query``.
        2. If it returned the query unchanged and some results were judged
           INCORRECT, fall back to ``original_query`` when the current query
           has drifted from it, otherwise append an exclusion hint.
        3. If the transformer produced sub-queries, the first one becomes the
           new query (overriding steps 1-2).

        Returns:
            The query to use for the next retrieval pass.
        """
        transformed = await self._query_transformer.transform(current_query)
        new_query = transformed.main_query

        if new_query == current_query:
            has_incorrect = any(
                score.verdict == RelevanceVerdict.INCORRECT
                for score in evaluation.scores
            )
            if has_incorrect:
                if original_query != current_query:
                    # Earlier rewrites drifted; retry the caller's original query.
                    new_query = original_query
                else:
                    new_query = f"{current_query} (excluding irrelevant results)"

        if transformed.sub_queries:
            # Explore a different aspect of the question via the first sub-query.
            new_query = transformed.sub_queries[0]

        logger.info(
            "Query rewritten: '%s...' -> '%s...'", current_query[:50], new_query[:50]
        )
        return new_query