fischer-agentkit/src/agentkit/memory/rag_loop.py

241 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""RAGSelfCorrectionLoop - CRAG 自纠正循环
实现 Corrective RAG 模式:检索→评估→纠正/降级→生成
当检索结果质量不足时,自动改写查询重新检索,形成自纠正闭环。
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING
from agentkit.memory.base import MemoryItem, MetadataDict
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
from agentkit.memory.relevance_scorer import (
RelevanceScorer,
RelevanceVerdict,
RetrievalEvaluation,
)
if TYPE_CHECKING:
# 避免与 retriever.py 形成运行时循环导入。
from agentkit.memory.retriever import MemoryRetriever
# Module-level logger; the correction loop below reports each attempt and each query rewrite through it.
logger = logging.getLogger(__name__)
class LoopState(str, Enum):
    """States of the CRAG self-correction loop.

    Mixes in ``str``, so each member also compares equal to its string value
    (e.g. ``LoopState.GENERATE == "generate"``).
    """

    RETRIEVE = "retrieve"  # fetch candidates with the memory retriever
    EVALUATE = "evaluate"  # score the retrieved candidates for relevance
    CORRECT = "correct"  # quality insufficient: rewrite the query and retry
    DEGRADE = "degrade"  # retry budget exhausted: return a degraded result
    GENERATE = "generate"  # quality sufficient: return the result as-is
@dataclass
class CorrectionAttempt:
    """Record of one retrieve-and-evaluate pass of the correction loop."""

    # Query text used for this pass (the original query or a rewrite of it).
    query: str
    # Relevance evaluation of the items retrieved for ``query``.
    evaluation: RetrievalEvaluation
    # State the loop decided on after evaluating this pass.
    state: LoopState
@dataclass
class RAGLoopResult:
    """Final outcome of the self-correction loop."""

    # Items returned to the caller; in a degraded result each one carries
    # metadata["low_confidence"] = True.
    items: list[MemoryItem]
    # Evaluation computed for the last attempt's items.
    evaluation: RetrievalEvaluation
    # Every attempt made, in execution order.
    attempts: list[CorrectionAttempt]
    # True only for a successful (non-degraded) result reached after at least
    # one query rewrite; always False for degraded results.
    corrected: bool
    # True when the retry budget ran out before retrieval quality was sufficient.
    degraded: bool
    # Retry counter reported by the loop (for successful results, the number
    # of retrievals performed after the first one).
    total_retries: int
class RAGSelfCorrectionLoop:
    """Corrective-RAG (CRAG) self-correction loop.

    A small state machine wrapped around retrieval:

    1. RETRIEVE: fetch candidates with the ``MemoryRetriever``.
    2. EVALUATE: score the candidates with the ``RelevanceScorer``.
    3. CORRECT: if quality is insufficient, rewrite the query and retry.
    4. DEGRADE: once the retry budget is spent, return a best-effort result
       whose items are flagged ``low_confidence``.
    5. GENERATE: quality is sufficient, return the result as-is.

    Circuit breaker: at most ``max_retries`` query rewrites are attempted after
    the initial retrieval (default 3), i.e. at most ``max_retries + 1``
    retrievals per call.
    """

    def __init__(
        self,
        retriever: MemoryRetriever,
        scorer: RelevanceScorer | None = None,
        query_transformer: QueryTransformerBase | None = None,
        max_retries: int = 3,
        min_items_for_correct: int = 1,
    ):
        """Configure the loop.

        Args:
            retriever: Retriever queried on every attempt.
            scorer: Relevance scorer; ``RelevanceScorer()`` is used when omitted.
            query_transformer: Query rewriter; ``NoOpQueryTransformer()`` is
                used when omitted.
            max_retries: Maximum number of retries (query rewrites) after the
                first retrieval. Must be >= 0.
            min_items_for_correct: Relevant-item count at which a CORRECT
                verdict is accepted outright.

        Raises:
            ValueError: If ``max_retries`` is negative.
        """
        if max_retries < 0:
            # A negative budget would skip retrieval entirely and leave no
            # evaluation/items to build a result from.
            raise ValueError(f"max_retries must be >= 0, got {max_retries}")
        self._retriever = retriever
        self._scorer = scorer or RelevanceScorer()
        self._query_transformer = query_transformer or NoOpQueryTransformer()
        self._max_retries = max_retries
        self._min_items_for_correct = min_items_for_correct

    async def retrieve_with_correction(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters: MetadataDict | None = None,
    ) -> RAGLoopResult:
        """Retrieve for ``query``, rewriting and retrying while quality is poor.

        Args:
            query: The original user query.
            top_k: Maximum number of items to retrieve per attempt.
            token_budget: Token budget passed to the retriever.
            filters: Optional metadata filters passed to the retriever.

        Returns:
            RAGLoopResult holding the final items, the final evaluation and a
            record of every attempt. ``total_retries`` is the number of
            retrievals performed after the first one (``len(attempts) - 1``)
            for both successful and degraded results.
        """
        attempts: list[CorrectionAttempt] = []
        current_query = query
        retry_count = 0

        # Terminates: retry_count grows every CORRECT iteration, and once it
        # reaches max_retries a CORRECT outcome is turned into DEGRADE.
        while True:
            # RETRIEVE. _skip_correction presumably stops the retriever from
            # re-entering this loop recursively — confirm in retriever.py.
            items = await self._retriever.retrieve(
                current_query,
                top_k=top_k,
                token_budget=token_budget,
                filters=filters,
                _skip_correction=True,
            )

            # EVALUATE
            evaluation = self._scorer.evaluate(current_query, items)
            state = self._determine_next_state(evaluation, items)
            if state == LoopState.CORRECT and retry_count >= self._max_retries:
                # Retry budget exhausted: this attempt ends the loop in DEGRADE
                # instead of triggering another rewrite.
                state = LoopState.DEGRADE

            attempts.append(
                CorrectionAttempt(query=current_query, evaluation=evaluation, state=state)
            )
            logger.info(
                "RAG loop attempt %d: query='%s...', verdict=%s, avg_score=%.2f, state=%s",
                retry_count + 1,
                current_query[:50],
                evaluation.overall_verdict.value,
                evaluation.avg_score,
                state.value,
            )

            # GENERATE — quality is sufficient.
            if state == LoopState.GENERATE:
                return RAGLoopResult(
                    items=items,
                    evaluation=evaluation,
                    attempts=attempts,
                    corrected=retry_count > 0,
                    degraded=False,
                    total_retries=retry_count,
                )

            # DEGRADE — out of retries; return the last attempt, flagged.
            if state == LoopState.DEGRADE:
                return self._degraded_result(items, evaluation, attempts, retry_count)

            # CORRECT — rewrite the query and go around again.
            retry_count += 1
            current_query = await self._rewrite_query(query, current_query, evaluation)

    def _degraded_result(
        self,
        items: list[MemoryItem],
        evaluation: RetrievalEvaluation,
        attempts: list[CorrectionAttempt],
        retry_count: int,
    ) -> RAGLoopResult:
        """Build the DEGRADE result from the final attempt.

        Keeps the scored items not judged INCORRECT, falling back to every
        retrieved item when none qualify, and sets
        ``metadata["low_confidence"] = True`` on each returned item.
        """
        relevant_items = [
            score.item
            for score in evaluation.scores
            if score.verdict != RelevanceVerdict.INCORRECT
        ]
        result_items = relevant_items or items
        # NOTE(review): this mutates the item objects in place; if the
        # retriever hands out shared/cached MemoryItem instances, the flag
        # will persist beyond this call — confirm.
        for item in result_items:
            item.metadata["low_confidence"] = True
        return RAGLoopResult(
            items=result_items,
            evaluation=evaluation,
            attempts=attempts,
            corrected=False,
            degraded=True,
            total_retries=retry_count,
        )

    def _determine_next_state(
        self, evaluation: RetrievalEvaluation, items: list[MemoryItem]
    ) -> LoopState:
        """Map an evaluation to the next loop state (GENERATE or CORRECT).

        A CORRECT verdict generates when at least ``min_items_for_correct``
        items are relevant or, as a lenient fallback, whenever any item was
        retrieved at all (so the threshold only matters for an empty result).
        AMBIGUOUS and INCORRECT verdicts always request correction.
        """
        if evaluation.overall_verdict == RelevanceVerdict.CORRECT and (
            evaluation.relevant_count >= self._min_items_for_correct or items
        ):
            return LoopState.GENERATE
        return LoopState.CORRECT

    async def _rewrite_query(
        self,
        original_query: str,
        current_query: str,
        evaluation: RetrievalEvaluation,
    ) -> str:
        """Produce the query for the next retrieval attempt.

        Strategy:
        1. Ask the query transformer to rewrite ``current_query``.
        2. If it returned the query unchanged and some items were judged
           INCORRECT, fall back to ``original_query`` (when the current query
           has drifted from it) or append an exclusion hint.
        3. If the transformer produced sub-queries, the first one overrides
           either of the above, to explore a different facet of the question.

        NOTE(review): when none of these change the query (e.g. a no-op
        transformer with only AMBIGUOUS scores), the next attempt re-runs the
        identical retrieval.
        """
        transformed = await self._query_transformer.transform(current_query)
        new_query = transformed.main_query

        if new_query == current_query:
            # Only whether any item was judged INCORRECT matters here.
            has_incorrect = any(
                score.verdict == RelevanceVerdict.INCORRECT
                for score in evaluation.scores
            )
            if has_incorrect and original_query != current_query:
                # Drifted away and still failing: retry the original wording.
                new_query = original_query
            elif has_incorrect:
                new_query = f"{current_query} (excluding irrelevant results)"

        if transformed.sub_queries:
            new_query = transformed.sub_queries[0]

        logger.info(
            "Query rewritten: '%s...' -> '%s...'", current_query[:50], new_query[:50]
        )
        return new_query