241 lines
7.6 KiB
Python
241 lines
7.6 KiB
Python
"""RAGSelfCorrectionLoop - CRAG 自纠正循环

实现 Corrective RAG 模式:检索→评估→纠正/降级→生成
当检索结果质量不足时,自动改写查询重新检索,形成自纠正闭环。
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
from typing import TYPE_CHECKING
|
||
|
||
from agentkit.memory.base import MemoryItem, MetadataDict
|
||
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
|
||
from agentkit.memory.relevance_scorer import (
|
||
RelevanceScorer,
|
||
RelevanceVerdict,
|
||
RetrievalEvaluation,
|
||
)
|
||
|
||
if TYPE_CHECKING:
|
||
# 避免与 retriever.py 形成运行时循环导入。
|
||
from agentkit.memory.retriever import MemoryRetriever
|
||
|
||
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class LoopState(str, Enum):
    """States of the self-correction loop.

    The ``str`` mix-in makes each member compare equal to its plain string
    value (for example ``LoopState.RETRIEVE == "retrieve"``).
    """

    # Fetch candidate items from the retriever.
    RETRIEVE = "retrieve"
    # Score the quality of what was retrieved.
    EVALUATE = "evaluate"
    # Quality is insufficient: rewrite the query and retrieve again.
    CORRECT = "correct"
    # Retry budget exhausted: hand back a degraded result.
    DEGRADE = "degrade"
    # Quality is sufficient: hand back the result.
    GENERATE = "generate"
|
||
|
||
|
||
@dataclass
class CorrectionAttempt:
    """Record of one retrieve-and-evaluate pass of the correction loop."""

    # Query text used for this pass.
    query: str
    # Evaluation produced for the items retrieved with that query.
    evaluation: RetrievalEvaluation
    # Loop state chosen from that evaluation.
    state: LoopState
|
||
|
||
|
||
@dataclass
class RAGLoopResult:
    """Final outcome of a self-correcting retrieval run."""

    # Items handed back to the caller.
    items: list[MemoryItem]
    # Evaluation belonging to the last retrieval pass that was executed.
    evaluation: RetrievalEvaluation
    # One record per retrieval pass, in execution order.
    attempts: list[CorrectionAttempt]
    # True when an accepted result was only reached after rewriting the query.
    corrected: bool
    # True when the retry budget ran out and a fallback result is returned.
    degraded: bool
    # Retry counter reported by the loop when it finished.
    total_retries: int
|
||
|
||
|
||
class RAGSelfCorrectionLoop:
    """Corrective RAG (CRAG) self-correction loop.

    Drives retrieval as a small state machine:

    1. RETRIEVE: fetch items through the ``MemoryRetriever``.
    2. EVALUATE: let the ``RelevanceScorer`` judge the retrieved items.
    3. CORRECT: quality is insufficient, so rewrite the query and retrieve again.
    4. DEGRADE: the retry budget is spent; return a degraded result.
    5. GENERATE: quality is sufficient; return the result.

    Circuit breaker:

    - ``max_retries`` bounds the number of re-retrievals (default 3).
    - Once the budget is exhausted the loop degrades and tags every returned
      item with ``metadata["low_confidence"] = True``.
    """

    def __init__(
        self,
        retriever: MemoryRetriever,
        scorer: RelevanceScorer | None = None,
        query_transformer: QueryTransformerBase | None = None,
        max_retries: int = 3,
        min_items_for_correct: int = 1,
    ):
        """Store the collaborators and the retry policy.

        Args:
            retriever: Object whose awaitable ``retrieve`` method performs the search.
            scorer: Evaluator for retrieved items; a default ``RelevanceScorer()``
                is created when ``None``.
            query_transformer: Rewrites queries between attempts; a
                ``NoOpQueryTransformer()`` is created when ``None``.
            max_retries: Maximum number of re-retrievals after the initial one.
                Negative values are treated as 0 so that at least one retrieval
                always runs.
            min_items_for_correct: Relevant-item count at which a CORRECT verdict
                is accepted outright.
        """
        self._retriever = retriever
        # ``is None`` (not truthiness) so a falsy-but-valid collaborator is kept.
        self._scorer = RelevanceScorer() if scorer is None else scorer
        self._query_transformer = (
            NoOpQueryTransformer() if query_transformer is None else query_transformer
        )
        # Clamp: a negative budget would otherwise skip the loop entirely and
        # leave the result variables unbound.
        self._max_retries = max(0, max_retries)
        self._min_items_for_correct = min_items_for_correct

    async def retrieve_with_correction(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters: MetadataDict | None = None,
    ) -> RAGLoopResult:
        """Retrieve items, rewriting the query and retrying while quality is poor.

        Args:
            query: The original query.
            top_k: Maximum number of results, forwarded to the retriever.
            token_budget: Token budget, forwarded to the retriever.
            filters: Optional filter conditions, forwarded to the retriever.

        Returns:
            RAGLoopResult: the items, the last evaluation and the per-attempt log.
            ``total_retries`` is the number of re-retrievals performed, i.e.
            ``len(attempts) - 1``, on both the success and the degraded path.
        """
        attempts: list[CorrectionAttempt] = []
        current_query = query

        # One initial retrieval plus at most ``max_retries`` re-retrievals.
        for retry_count in range(self._max_retries + 1):
            # RETRIEVE
            # NOTE(review): ``_skip_correction=True`` presumably stops the
            # retriever from re-entering this loop -- confirm in retriever.py.
            items = await self._retriever.retrieve(
                current_query,
                top_k=top_k,
                token_budget=token_budget,
                filters=filters,
                _skip_correction=True,
            )

            # EVALUATE
            evaluation = self._scorer.evaluate(current_query, items)
            state = self._determine_next_state(evaluation, items)
            attempts.append(
                CorrectionAttempt(query=current_query, evaluation=evaluation, state=state)
            )

            logger.info(
                "RAG loop attempt %d: query='%s...', verdict=%s, avg_score=%.2f, state=%s",
                retry_count + 1,
                current_query[:50],
                evaluation.overall_verdict.value,
                evaluation.avg_score,
                state.value,
            )

            # GENERATE -- quality is sufficient
            if state == LoopState.GENERATE:
                return RAGLoopResult(
                    items=items,
                    evaluation=evaluation,
                    attempts=attempts,
                    corrected=retry_count > 0,
                    degraded=False,
                    total_retries=retry_count,
                )

            # CORRECT -- rewrite the query only if another attempt will follow
            if retry_count < self._max_retries:
                current_query = await self._rewrite_query(query, current_query, evaluation)

        # DEGRADE -- retry budget exhausted. Prefer the items that were not
        # judged INCORRECT; fall back to everything from the last retrieval.
        # NOTE: no attempt is ever recorded with ``LoopState.DEGRADE``; the last
        # failed attempt keeps the state returned by ``_determine_next_state``.
        relevant_items = [
            score.item
            for score in evaluation.scores
            if score.verdict != RelevanceVerdict.INCORRECT
        ]
        result_items = relevant_items or items

        # The flag is written onto the retrieved items in place.
        for item in result_items:
            item.metadata["low_confidence"] = True

        return RAGLoopResult(
            items=result_items,
            evaluation=evaluation,
            attempts=attempts,
            corrected=False,
            degraded=True,
            # Retries actually performed: every attempt after the first one.
            total_retries=len(attempts) - 1,
        )

    def _determine_next_state(
        self, evaluation: RetrievalEvaluation, items: list[MemoryItem]
    ) -> LoopState:
        """Map an evaluation onto the next loop state.

        Returns ``LoopState.GENERATE`` when the overall verdict is CORRECT and
        either enough relevant items were found or at least one item was
        retrieved at all; every other case returns ``LoopState.CORRECT``.
        """
        # AMBIGUOUS and INCORRECT verdicts both call for a corrected query.
        if evaluation.overall_verdict != RelevanceVerdict.CORRECT:
            return LoopState.CORRECT

        # A CORRECT verdict with too few relevant items is still accepted as
        # long as something was retrieved.
        if evaluation.relevant_count >= self._min_items_for_correct or items:
            return LoopState.GENERATE
        return LoopState.CORRECT

    async def _rewrite_query(
        self,
        original_query: str,
        current_query: str,
        evaluation: RetrievalEvaluation,
    ) -> str:
        """Produce the query for the next retrieval attempt.

        Strategy, in order of precedence:

        1. If the transformer yields sub-queries, use the first one.
        2. Otherwise use the transformer's main query.
        3. If that equals the current query and some item was judged INCORRECT,
           fall back to the original query, or, when the current query already
           is the original, append an "excluding irrelevant results" hint.

        Args:
            original_query: The query the caller started with.
            current_query: The query used by the attempt that just failed.
            evaluation: Evaluation of that failed attempt.

        Returns:
            The rewritten query; may equal ``current_query`` when nothing applies.
        """
        transformed = await self._query_transformer.transform(current_query)

        if transformed.sub_queries:
            # The first sub-query becomes the new primary query so the next
            # attempt explores a different aspect of the question.
            new_query = transformed.sub_queries[0]
        else:
            new_query = transformed.main_query
            has_incorrect = any(
                score.verdict == RelevanceVerdict.INCORRECT for score in evaluation.scores
            )
            if new_query == current_query and has_incorrect:
                if original_query != current_query:
                    # Transformer made no progress: retry the original query.
                    new_query = original_query
                else:
                    # Add "NOT" context to help filter.
                    new_query = f"{current_query} (excluding irrelevant results)"

        logger.info(
            "Query rewritten: '%s...' -> '%s...'", current_query[:50], new_query[:50]
        )
        return new_query
|