feat(memory): U1 RAG self-correction loop (CRAG)
- RelevanceScorer: keyword overlap + query coverage + retrieval score + length penalty
- RAGSelfCorrectionLoop: state-machine-driven retrieve-evaluate-correct-degrade cycle
- Integrated into MemoryRetriever via the enable_self_correction option
- 21 tests passing
This commit is contained in:
parent
468dfd71e8
commit
a6c9babfdc
|
|
@ -0,0 +1,237 @@
|
||||||
|
"""RAGSelfCorrectionLoop - CRAG 自纠正循环
|
||||||
|
|
||||||
|
实现 Corrective RAG 模式:检索→评估→纠正/降级→生成
|
||||||
|
当检索结果质量不足时,自动改写查询重新检索,形成自纠正闭环。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.memory.base import MemoryItem
|
||||||
|
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
|
||||||
|
from agentkit.memory.relevance_scorer import (
|
||||||
|
RelevanceScorer,
|
||||||
|
RelevanceVerdict,
|
||||||
|
RetrievalEvaluation,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LoopState(str, Enum):
    """States of the self-correction loop.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value.
    """

    # Fetch candidate items for the current query.
    RETRIEVE = "retrieve"
    # Score the fetched items against the query.
    EVALUATE = "evaluate"
    # Quality is insufficient: rewrite the query and retrieve again.
    CORRECT = "correct"
    # Retry budget exhausted: return a degraded result.
    DEGRADE = "degrade"
    # Quality is sufficient: return the items.
    GENERATE = "generate"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CorrectionAttempt:
    """Record of one retrieve-and-evaluate pass of the correction loop."""

    # Query text that was sent to the retriever on this pass.
    query: str
    # Scorer output for the items that query returned.
    evaluation: RetrievalEvaluation
    # Loop state chosen after evaluating this pass.
    state: LoopState
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RAGLoopResult:
    """Final outcome of a self-correction loop run."""

    # Items handed back to the caller.
    items: list[MemoryItem]
    # Evaluation of the last retrieval that was performed.
    evaluation: RetrievalEvaluation
    # One record per retrieve/evaluate pass, in execution order.
    attempts: list[CorrectionAttempt]
    # True when an acceptable result was only reached after a query rewrite.
    corrected: bool
    # True when the loop gave up and returned a low-confidence result.
    degraded: bool
    # Retry counter reported by the loop when it finished.
    total_retries: int
|
||||||
|
|
||||||
|
|
||||||
|
class RAGSelfCorrectionLoop:
    """Corrective-RAG (CRAG) self-correction loop.

    Drives a retrieve -> evaluate -> correct/degrade cycle:

    1. RETRIEVE: fetch items through the wrapped retriever.
    2. EVALUATE: score them with a ``RelevanceScorer``.
    3. CORRECT: quality is insufficient, so rewrite the query and retrieve again.
    4. DEGRADE: the retry budget is spent; return a low-confidence result.
    5. GENERATE: quality is sufficient; return the items.

    Circuit breaker:
    - ``max_retries`` caps the number of query rewrites (default 3).
    - Once the cap is reached the loop degrades and tags every returned item
      with ``metadata["low_confidence"] = True``.
    """

    def __init__(
        self,
        retriever: Any,  # MemoryRetriever
        scorer: RelevanceScorer | None = None,
        query_transformer: QueryTransformerBase | None = None,
        max_retries: int = 3,
        min_items_for_correct: int = 1,
    ):
        """Configure the loop.

        Args:
            retriever: Object exposing an async ``retrieve(query, top_k=...,
                token_budget=..., filters=..., _skip_correction=...)`` method.
            scorer: Relevance scorer; a default ``RelevanceScorer`` is created
                when omitted.
            query_transformer: Query rewriter; a ``NoOpQueryTransformer`` is
                used when omitted.
            max_retries: Maximum number of query rewrites. Negative values are
                treated as 0 so that at least one retrieval always happens.
            min_items_for_correct: Number of relevant items that satisfies a
                CORRECT verdict outright.
        """
        self._retriever = retriever
        self._scorer = scorer if scorer is not None else RelevanceScorer()
        self._query_transformer = (
            query_transformer if query_transformer is not None else NoOpQueryTransformer()
        )
        # A negative budget would skip the loop body entirely and leave the
        # result-building code without any evaluation to report.
        self._max_retries = max(0, max_retries)
        self._min_items_for_correct = min_items_for_correct

    async def retrieve_with_correction(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters: dict[str, Any] | None = None,
    ) -> RAGLoopResult:
        """Retrieve with automatic query correction.

        Args:
            query: The caller's original query.
            top_k: Maximum number of items to request from the retriever.
            token_budget: Token budget forwarded to the retriever.
            filters: Optional filter conditions forwarded to the retriever.

        Returns:
            RAGLoopResult holding the returned items, the last evaluation and
            one ``CorrectionAttempt`` per retrieval. ``total_retries`` is the
            number of query rewrites that were actually executed.
        """
        attempts: list[CorrectionAttempt] = []
        current_query = query

        # retry_count is the number of rewrites performed before this pass.
        for retry_count in range(self._max_retries + 1):
            # RETRIEVE: _skip_correction asks the retriever for a plain lookup
            # so that it does not hand the call straight back to this loop.
            items = await self._retriever.retrieve(
                current_query,
                top_k=top_k,
                token_budget=token_budget,
                filters=filters,
                _skip_correction=True,
            )

            # EVALUATE
            evaluation = self._scorer.evaluate(current_query, items)
            state = self._determine_next_state(evaluation, items)
            if state == LoopState.CORRECT and retry_count >= self._max_retries:
                # No rewrites left: record what the loop really does next.
                state = LoopState.DEGRADE

            attempts.append(
                CorrectionAttempt(query=current_query, evaluation=evaluation, state=state)
            )

            logger.info(
                "RAG loop attempt %d: query='%s...', verdict=%s, avg_score=%.2f, state=%s",
                retry_count + 1,
                current_query[:50],
                evaluation.overall_verdict.value,
                evaluation.avg_score,
                state.value,
            )

            # GENERATE: quality is sufficient
            if state == LoopState.GENERATE:
                return RAGLoopResult(
                    items=items,
                    evaluation=evaluation,
                    attempts=attempts,
                    corrected=retry_count > 0,
                    degraded=False,
                    total_retries=retry_count,
                )

            # DEGRADE: retry budget exhausted
            if state == LoopState.DEGRADE:
                break

            # CORRECT: rewrite the query and go around again
            current_query = await self._rewrite_query(query, current_query, evaluation)

        return self._build_degraded_result(items, evaluation, attempts)

    def _build_degraded_result(
        self,
        items: list[MemoryItem],
        evaluation: RetrievalEvaluation,
        attempts: list[CorrectionAttempt],
    ) -> RAGLoopResult:
        """Build the low-confidence result returned once retries run out."""
        # Prefer items the scorer did not reject; if it rejected everything,
        # fall back to the raw retrieval rather than returning nothing.
        relevant_items = [
            s.item for s in evaluation.scores if s.verdict != RelevanceVerdict.INCORRECT
        ]
        result_items = relevant_items or items

        # NOTE: the flag is written onto the item objects the retriever
        # returned (no copy is made), so anything sharing those objects sees it.
        for item in result_items:
            item.metadata["low_confidence"] = True

        return RAGLoopResult(
            items=result_items,
            evaluation=evaluation,
            attempts=attempts,
            corrected=False,
            degraded=True,
            # Every attempt after the first one was a retry.
            total_retries=len(attempts) - 1,
        )

    def _determine_next_state(
        self, evaluation: RetrievalEvaluation, items: list[MemoryItem]
    ) -> LoopState:
        """Pick the next loop state from an evaluation.

        Returns GENERATE only for a CORRECT overall verdict that either has at
        least ``min_items_for_correct`` relevant items or has any items at
        all; every other case (AMBIGUOUS, INCORRECT, or CORRECT with nothing
        retrieved) returns CORRECT, i.e. "rewrite and retry".
        """
        if evaluation.overall_verdict != RelevanceVerdict.CORRECT:
            return LoopState.CORRECT

        if evaluation.relevant_count >= self._min_items_for_correct or items:
            return LoopState.GENERATE
        return LoopState.CORRECT

    async def _rewrite_query(
        self,
        original_query: str,
        current_query: str,
        evaluation: RetrievalEvaluation,
    ) -> str:
        """Produce the query for the next retrieval attempt.

        Order of precedence:
        1. If the transformer returns sub-queries, use the first one.
        2. Otherwise use the transformer's ``main_query`` when it differs from
           ``current_query``.
        3. Otherwise, if the last evaluation rejected at least one item, fall
           back to ``original_query`` when it differs from ``current_query``,
           or else append an "(excluding irrelevant results)" hint.
        4. Otherwise the query is returned unchanged.
        """
        transformed = await self._query_transformer.transform(current_query)
        new_query = transformed.main_query

        if new_query == current_query:
            # The transformer was a no-op; use the evaluation as the signal.
            has_rejected_item = any(
                s.verdict == RelevanceVerdict.INCORRECT for s in evaluation.scores
            )
            if has_rejected_item:
                if original_query != current_query:
                    new_query = original_query
                else:
                    new_query = f"{current_query} (excluding irrelevant results)"

        # A sub-query explores a different aspect of the question and wins
        # over everything chosen above.
        if transformed.sub_queries:
            new_query = transformed.sub_queries[0]

        logger.info(
            "Query rewritten: '%s...' -> '%s...'", current_query[:50], new_query[:50]
        )
        return new_query
|
||||||
|
|
@ -0,0 +1,215 @@
|
||||||
|
"""RelevanceScorer - 检索结果相关性自动评估
|
||||||
|
|
||||||
|
对检索结果逐文档评估与查询的相关性,用于 CRAG 自纠正循环的评估阶段。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.memory.base import MemoryItem
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RelevanceVerdict(str, Enum):
    """Relevance judgement for a document or for a whole retrieval.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value.
    """

    # Judged relevant.
    CORRECT = "correct"
    # Between the two thresholds: neither clearly relevant nor clearly not.
    AMBIGUOUS = "ambiguous"
    # Judged not relevant.
    INCORRECT = "incorrect"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RelevanceScore:
    """Relevance score of a single retrieved document."""

    # The retrieved item that was scored.
    item: MemoryItem
    # Combined relevance score, 0.0 ~ 1.0.
    score: float
    # Verdict derived from the score.
    verdict: RelevanceVerdict
    # Optional free-text explanation of how the score came about.
    reason: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RetrievalEvaluation:
    """Aggregate evaluation of one retrieval."""

    # Per-item scores, in the order the items were evaluated.
    scores: list[RelevanceScore]
    # Verdict for the retrieval as a whole.
    overall_verdict: RelevanceVerdict
    # Mean of the per-item scores.
    avg_score: float
    # Number of items counted as relevant.
    relevant_count: int
    # Number of items evaluated.
    total_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class RelevanceScorer:
    """Lightweight relevance evaluator for retrieval results.

    Scores each document against the query using lexical overlap plus the
    retriever's own score. It makes no LLM calls, which keeps evaluation
    cheap enough for the request path.

    The per-document score is a weighted sum of:
    1. Keyword overlap (Jaccard similarity of the token sets).
    2. Query term coverage (share of query tokens present in the document).
    3. The original retrieval score, clamped to [0, 1].
    4. A length score that penalises very short or very long documents.
    """

    # Share of non-INCORRECT items needed (together with the average-score
    # threshold) for an overall CORRECT verdict.
    _CORRECT_MIN_RELEVANT_RATIO = 0.5
    # Share of non-INCORRECT items that is enough on its own for AMBIGUOUS.
    _AMBIGUOUS_MIN_RELEVANT_RATIO = 0.3

    def __init__(
        self,
        correct_threshold: float = 0.6,
        ambiguous_threshold: float = 0.35,
        keyword_weight: float = 0.3,
        coverage_weight: float = 0.3,
        retrieval_weight: float = 0.3,
        length_weight: float = 0.1,
        min_doc_length: int = 20,
        max_doc_length: int = 5000,
    ):
        """Configure thresholds and component weights.

        Args:
            correct_threshold: Minimum score for a CORRECT verdict.
            ambiguous_threshold: Minimum score for an AMBIGUOUS verdict.
            keyword_weight: Weight of the Jaccard keyword overlap.
            coverage_weight: Weight of the query term coverage.
            retrieval_weight: Weight of the clamped retrieval score.
            length_weight: Weight of the length score.
            min_doc_length: Documents shorter than this (in characters) are
                penalised.
            max_doc_length: Documents longer than this (in characters) are
                penalised.
        """
        self._correct_threshold = correct_threshold
        self._ambiguous_threshold = ambiguous_threshold
        self._keyword_weight = keyword_weight
        self._coverage_weight = coverage_weight
        self._retrieval_weight = retrieval_weight
        self._length_weight = length_weight
        self._min_doc_length = min_doc_length
        self._max_doc_length = max_doc_length

    def score_item(self, query: str, item: MemoryItem) -> RelevanceScore:
        """Score how relevant a single retrieved item is to ``query``.

        The item's ``value`` is stringified before tokenising. The returned
        ``reason`` lists the four component scores.
        """
        doc_text = str(item.value)

        query_terms = self._tokenize(query)
        doc_terms = self._tokenize(doc_text)

        # 1. Keyword overlap (Jaccard similarity)
        keyword_score = self._jaccard_similarity(query_terms, doc_terms)
        # 2. Query term coverage
        coverage_score = self._query_coverage(query_terms, doc_terms)
        # 3. Original retrieval score, clamped into [0, 1]
        retrieval_score = min(max(item.score, 0.0), 1.0)
        # 4. Length penalty
        length_score = self._length_score(len(doc_text))

        final_score = (
            keyword_score * self._keyword_weight
            + coverage_score * self._coverage_weight
            + retrieval_score * self._retrieval_weight
            + length_score * self._length_weight
        )

        reason = (
            f"keyword={keyword_score:.2f}, coverage={coverage_score:.2f}, "
            f"retrieval={retrieval_score:.2f}, length={length_score:.2f}"
        )

        return RelevanceScore(
            item=item,
            score=final_score,
            verdict=self._determine_verdict(final_score),
            reason=reason,
        )

    def evaluate(self, query: str, items: list[MemoryItem]) -> RetrievalEvaluation:
        """Evaluate the overall quality of one retrieval.

        An empty ``items`` list yields an INCORRECT evaluation with zero
        counts. Otherwise the overall verdict is:

        - CORRECT when the average score reaches ``correct_threshold`` and at
          least half of the items are not INCORRECT;
        - AMBIGUOUS when the average score reaches ``ambiguous_threshold`` or
          at least 30% of the items are not INCORRECT;
        - INCORRECT in every other case.
        """
        if not items:
            return RetrievalEvaluation(
                scores=[],
                overall_verdict=RelevanceVerdict.INCORRECT,
                avg_score=0.0,
                relevant_count=0,
                total_count=0,
            )

        scores = [self.score_item(query, item) for item in items]
        relevant_count = sum(
            1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
        )
        avg_score = sum(s.score for s in scores) / len(scores)
        relevant_ratio = relevant_count / len(scores)

        if (
            avg_score >= self._correct_threshold
            and relevant_ratio >= self._CORRECT_MIN_RELEVANT_RATIO
        ):
            overall_verdict = RelevanceVerdict.CORRECT
        elif (
            avg_score >= self._ambiguous_threshold
            or relevant_ratio >= self._AMBIGUOUS_MIN_RELEVANT_RATIO
        ):
            overall_verdict = RelevanceVerdict.AMBIGUOUS
        else:
            overall_verdict = RelevanceVerdict.INCORRECT

        return RetrievalEvaluation(
            scores=scores,
            overall_verdict=overall_verdict,
            avg_score=avg_score,
            relevant_count=relevant_count,
            total_count=len(scores),
        )

    def _determine_verdict(self, score: float) -> RelevanceVerdict:
        """Map a score onto a verdict using the two configured thresholds."""
        if score >= self._correct_threshold:
            return RelevanceVerdict.CORRECT
        if score >= self._ambiguous_threshold:
            return RelevanceVerdict.AMBIGUOUS
        return RelevanceVerdict.INCORRECT

    @staticmethod
    def _tokenize(text: str) -> set[str]:
        """Split text into a set of match tokens.

        Runs of ASCII letters are lower-cased and kept as words. CJK
        ideographs (U+4E00..U+9FFF) contribute each single character plus
        every bigram of adjacent characters. Digits and punctuation produce
        no tokens.
        """
        tokens: set[str] = set(re.findall(r"[a-zA-Z]+", text.lower()))

        # Work on contiguous runs so that bigrams are only formed from
        # characters that are actually adjacent in the text, never across
        # punctuation, whitespace or embedded Latin words.
        for run in re.findall(r"[\u4e00-\u9fff]+", text):
            tokens.update(run)
            tokens.update(left + right for left, right in zip(run, run[1:]))
        return tokens

    @staticmethod
    def _jaccard_similarity(set_a: set[str], set_b: set[str]) -> float:
        """Return |A ∩ B| / |A ∪ B|, or 0.0 when either set is empty."""
        if not set_a or not set_b:
            return 0.0
        return len(set_a & set_b) / len(set_a | set_b)

    @staticmethod
    def _query_coverage(query_terms: set[str], doc_terms: set[str]) -> float:
        """Return the share of query terms found in the document (0.0 if no terms)."""
        if not query_terms:
            return 0.0
        return len(query_terms & doc_terms) / len(query_terms)

    def _length_score(self, length: int) -> float:
        """Score a document length in characters.

        Lengths inside ``[min_doc_length, max_doc_length]`` score 1.0. Shorter
        documents scale linearly from 0.0 up to just under 0.5. Longer
        documents lose 0.5 per ``max_doc_length`` of excess, floored at 0.3.
        """
        if length < self._min_doc_length:
            # Too short: likely insufficient context
            return (length / self._min_doc_length) * 0.5
        if length > self._max_doc_length:
            # Too long: may contain irrelevant information
            excess = (length - self._max_doc_length) / self._max_doc_length
            return max(0.3, 1.0 - excess * 0.5)
        return 1.0
|
||||||
|
|
@ -17,6 +17,8 @@ from agentkit.memory.working import WorkingMemory
|
||||||
from agentkit.memory.episodic import EpisodicMemory
|
from agentkit.memory.episodic import EpisodicMemory
|
||||||
from agentkit.memory.semantic import SemanticMemory
|
from agentkit.memory.semantic import SemanticMemory
|
||||||
from agentkit.memory.query_transformer import QueryTransformerBase
|
from agentkit.memory.query_transformer import QueryTransformerBase
|
||||||
|
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
|
||||||
|
from agentkit.memory.relevance_scorer import RelevanceScorer
|
||||||
from agentkit.tools.base import Tool
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -55,6 +57,8 @@ class MemoryRetriever:
|
||||||
weights: dict[str, float] | None = None,
|
weights: dict[str, float] | None = None,
|
||||||
query_transformer: QueryTransformerBase | None = None,
|
query_transformer: QueryTransformerBase | None = None,
|
||||||
context_template: str = "structured",
|
context_template: str = "structured",
|
||||||
|
enable_self_correction: bool = False,
|
||||||
|
max_correction_retries: int = 3,
|
||||||
):
|
):
|
||||||
self._working = working_memory
|
self._working = working_memory
|
||||||
self._episodic = episodic_memory
|
self._episodic = episodic_memory
|
||||||
|
|
@ -66,6 +70,15 @@ class MemoryRetriever:
|
||||||
}
|
}
|
||||||
self._query_transformer = query_transformer
|
self._query_transformer = query_transformer
|
||||||
self._context_template = context_template
|
self._context_template = context_template
|
||||||
|
self._enable_self_correction = enable_self_correction
|
||||||
|
self._correction_loop: RAGSelfCorrectionLoop | None = None
|
||||||
|
if enable_self_correction:
|
||||||
|
self._correction_loop = RAGSelfCorrectionLoop(
|
||||||
|
retriever=self,
|
||||||
|
scorer=RelevanceScorer(),
|
||||||
|
query_transformer=query_transformer,
|
||||||
|
max_retries=max_correction_retries,
|
||||||
|
)
|
||||||
|
|
||||||
async def retrieve(
|
async def retrieve(
|
||||||
self,
|
self,
|
||||||
|
|
@ -73,8 +86,31 @@ class MemoryRetriever:
|
||||||
top_k: int = 5,
|
top_k: int = 5,
|
||||||
token_budget: int = 3000,
|
token_budget: int = 3000,
|
||||||
filters: dict[str, Any] | None = None,
|
filters: dict[str, Any] | None = None,
|
||||||
|
_skip_correction: bool = False,
|
||||||
) -> list[MemoryItem]:
|
) -> list[MemoryItem]:
|
||||||
"""混合检索三层记忆"""
|
"""混合检索三层记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 检索查询
|
||||||
|
top_k: 返回最大结果数
|
||||||
|
token_budget: token 预算
|
||||||
|
filters: 过滤条件
|
||||||
|
_skip_correction: 内部参数,CRAG 循环内部调用时跳过自纠正
|
||||||
|
"""
|
||||||
|
# Self-correction loop (CRAG)
|
||||||
|
if (
|
||||||
|
self._enable_self_correction
|
||||||
|
and self._correction_loop is not None
|
||||||
|
and not _skip_correction
|
||||||
|
):
|
||||||
|
result = await self._correction_loop.retrieve_with_correction(
|
||||||
|
query, top_k=top_k, token_budget=token_budget, filters=filters
|
||||||
|
)
|
||||||
|
if result.degraded:
|
||||||
|
logger.warning(
|
||||||
|
f"RAG self-correction degraded after {result.total_retries} retries"
|
||||||
|
)
|
||||||
|
return result.items
|
||||||
# Query transformation
|
# Query transformation
|
||||||
if self._query_transformer is not None:
|
if self._query_transformer is not None:
|
||||||
transformed = await self._query_transformer.transform(query)
|
transformed = await self._query_transformer.transform(query)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,337 @@
|
||||||
|
"""Tests for RelevanceScorer and RAGSelfCorrectionLoop"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.memory.base import MemoryItem
|
||||||
|
from agentkit.memory.relevance_scorer import (
|
||||||
|
RelevanceScorer,
|
||||||
|
RelevanceScore,
|
||||||
|
RelevanceVerdict,
|
||||||
|
RetrievalEvaluation,
|
||||||
|
)
|
||||||
|
from agentkit.memory.rag_loop import (
|
||||||
|
RAGSelfCorrectionLoop,
|
||||||
|
RAGLoopResult,
|
||||||
|
LoopState,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- RelevanceScorer Tests ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestRelevanceScorer:
    """Unit tests for RelevanceScorer."""

    def setup_method(self):
        # A fresh scorer with default thresholds and weights for every test.
        self.scorer = RelevanceScorer()

    def test_score_highly_relevant_item(self):
        """A highly relevant document scores high."""
        query = "Python web framework Django Flask"
        item = MemoryItem(
            key="doc1",
            value="Django and Flask are popular Python web frameworks for building web applications",
            score=0.9,
        )
        result = self.scorer.score_item(query, item)
        assert result.score > 0.5
        assert result.verdict in (RelevanceVerdict.CORRECT, RelevanceVerdict.AMBIGUOUS)

    def test_score_irrelevant_item(self):
        """A completely irrelevant document scores low."""
        query = "Python web framework"
        item = MemoryItem(
            key="doc2",
            value="The weather is sunny today and the birds are singing in the garden",
            score=0.1,
        )
        result = self.scorer.score_item(query, item)
        assert result.score < 0.5
        assert result.verdict == RelevanceVerdict.INCORRECT

    def test_score_chinese_relevant_item(self):
        """Chinese text is scored via character and bigram tokens."""
        query = "GEO优化策略"
        item = MemoryItem(
            key="doc3",
            value="GEO优化策略包括内容结构化、Schema标记、AI平台适配等多个方面",
            score=0.85,
        )
        result = self.scorer.score_item(query, item)
        assert result.score > 0.3  # Chinese bigrams should match

    def test_score_short_document_penalty(self):
        """Very short documents are penalised relative to normal ones."""
        query = "machine learning algorithms"
        short_item = MemoryItem(key="short", value="ML", score=0.9)
        normal_item = MemoryItem(
            key="normal",
            value="Machine learning algorithms include supervised and unsupervised learning methods",
            score=0.9,
        )
        short_result = self.scorer.score_item(query, short_item)
        normal_result = self.scorer.score_item(query, normal_item)
        assert normal_result.score > short_result.score

    def test_evaluate_empty_results(self):
        """An empty retrieval is INCORRECT with zeroed statistics."""
        evaluation = self.scorer.evaluate("test query", [])
        assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT
        assert evaluation.avg_score == 0.0
        assert evaluation.total_count == 0

    def test_evaluate_mixed_results(self):
        """Mixed-quality results are AMBIGUOUS or CORRECT, never INCORRECT."""
        query = "Python web framework"
        items = [
            MemoryItem(key="good", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="bad", value="Weather forecast for today", score=0.1),
        ]
        evaluation = self.scorer.evaluate(query, items)
        assert evaluation.total_count == 2
        assert evaluation.relevant_count >= 1
        # Pin the verdict the test name promises, not just the counts.
        assert evaluation.overall_verdict in (
            RelevanceVerdict.AMBIGUOUS,
            RelevanceVerdict.CORRECT,
        )

    def test_evaluate_all_correct(self):
        """All-relevant results give a CORRECT verdict."""
        query = "Python Django"
        items = [
            MemoryItem(key="d1", value="Django is a Python web framework", score=0.9),
            MemoryItem(key="d2", value="Django REST framework for API development", score=0.85),
        ]
        evaluation = self.scorer.evaluate(query, items)
        assert evaluation.overall_verdict == RelevanceVerdict.CORRECT

    def test_evaluate_all_incorrect(self):
        """All-irrelevant results give an INCORRECT verdict."""
        query = "quantum computing"
        items = [
            MemoryItem(key="d1", value="Cooking recipes for beginners", score=0.1),
            MemoryItem(key="d2", value="Gardening tips for spring", score=0.05),
        ]
        evaluation = self.scorer.evaluate(query, items)
        assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT

    def test_custom_thresholds(self):
        """Custom thresholds change the verdict."""
        scorer = RelevanceScorer(correct_threshold=0.9, ambiguous_threshold=0.7)
        query = "test"
        item = MemoryItem(key="d1", value="test document with some content", score=0.5)
        result = scorer.score_item(query, item)
        # With high thresholds, this should be INCORRECT
        assert result.verdict == RelevanceVerdict.INCORRECT

    def test_jaccard_similarity(self):
        """Jaccard similarity of two overlapping sets."""
        set_a = {"python", "web", "framework"}
        set_b = {"python", "web", "server"}
        similarity = RelevanceScorer._jaccard_similarity(set_a, set_b)
        assert 0.0 < similarity < 1.0
        # 2 common / 4 unique = 0.5
        assert abs(similarity - 0.5) < 0.01

    def test_jaccard_empty_sets(self):
        """Jaccard with an empty set on either side is 0."""
        assert RelevanceScorer._jaccard_similarity(set(), {"a"}) == 0.0
        assert RelevanceScorer._jaccard_similarity({"a"}, set()) == 0.0

    def test_query_coverage(self):
        """Query term coverage is the covered share of query terms."""
        query_terms = {"python", "django", "flask"}
        doc_terms = {"python", "django", "web", "framework"}
        coverage = RelevanceScorer._query_coverage(query_terms, doc_terms)
        # 2 out of 3 query terms covered
        assert abs(coverage - 2 / 3) < 0.01

    def test_tokenize_chinese(self):
        """Chinese tokenization yields single characters and bigrams."""
        tokens = RelevanceScorer._tokenize("机器学习算法")
        assert "机" in tokens
        assert "器" in tokens
        assert "机器" in tokens  # bigram

    def test_tokenize_english(self):
        """English tokenization lower-cases words."""
        tokens = RelevanceScorer._tokenize("Python Web Framework")
        assert "python" in tokens
        assert "web" in tokens
        assert "framework" in tokens
|
||||||
|
|
||||||
|
|
||||||
|
# --- RAGSelfCorrectionLoop Tests ---
|
||||||
|
|
||||||
|
|
||||||
|
class MockRetriever:
    """In-memory stand-in for the retriever used by the loop tests.

    Records every call in ``call_count`` and ``queries``. A query that matches
    a key exactly gets that key's items; any other query gets the items of the
    first registered key; with no registered keys it gets an empty list.
    """

    def __init__(self, items_by_query: dict[str, list[MemoryItem]] | None = None):
        self._items = items_by_query or {}
        self.call_count = 0
        self.queries: list[str] = []

    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        token_budget: int = 3000,
        filters=None,
        _skip_correction: bool = False,
    ) -> list[MemoryItem]:
        """Return canned items for ``query``; the other arguments are ignored."""
        self.call_count += 1
        self.queries.append(query)

        if query in self._items:
            return self._items[query]

        # Fall back to the first registered entry. Taking it from the values
        # (rather than truth-testing the key) keeps a falsy key such as ""
        # usable as the default.
        return next(iter(self._items.values()), [])
|
||||||
|
|
||||||
|
|
||||||
|
class TestRAGSelfCorrectionLoop:
    """Unit tests for RAGSelfCorrectionLoop."""

    @pytest.mark.asyncio
    async def test_correct_retrieval_skips_correction(self):
        """High-quality retrieval does not trigger correction."""
        items = [
            MemoryItem(
                key="d1",
                value="Django is a Python web framework for building web applications quickly",
                score=0.9,
            ),
            MemoryItem(
                key="d2",
                value="Flask is a lightweight Python web framework for small applications",
                score=0.85,
            ),
        ]
        mock = MockRetriever({"Python web framework": items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)

        result = await loop.retrieve_with_correction("Python web framework")
        assert not result.degraded
        assert len(result.items) == 2
        assert result.total_retries == 0

    @pytest.mark.asyncio
    async def test_poor_retrieval_triggers_correction(self):
        """Poor retrieval triggers at least one more retrieval."""
        poor_items = [
            MemoryItem(key="d1", value="Weather forecast for today", score=0.1),
        ]
        good_items = [
            MemoryItem(
                key="d2",
                value="Python Django web framework tutorial and best practices",
                score=0.9,
            ),
        ]
        mock = MockRetriever({
            "Python web framework": poor_items,
            "improved query": good_items,
        })
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)

        result = await loop.retrieve_with_correction("Python web framework")
        assert mock.call_count >= 2  # At least initial + 1 retry
        assert len(result.attempts) >= 2
        # The first attempt always uses the caller's query unchanged.
        assert mock.queries[0] == "Python web framework"

    @pytest.mark.asyncio
    async def test_max_retries_causes_degradation(self):
        """Exhausting the retry budget degrades the result."""
        poor_items = [
            MemoryItem(key="d1", value="Unrelated content about weather", score=0.05),
        ]
        mock = MockRetriever({"any query": poor_items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)

        result = await loop.retrieve_with_correction("Python web framework")
        assert result.degraded
        assert result.total_retries >= 2
        # Items should be marked low_confidence
        assert any(
            item.metadata.get("low_confidence", False) for item in result.items
        )

    @pytest.mark.asyncio
    async def test_empty_retrieval_triggers_correction(self):
        """Empty retrieval results trigger correction and finally degrade."""
        mock = MockRetriever({"query": []})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)

        result = await loop.retrieve_with_correction("test query")
        assert result.degraded
        assert result.total_retries >= 1

    @pytest.mark.asyncio
    async def test_loop_result_tracks_attempts(self):
        """The loop result records every attempt, starting with the original query."""
        items = [
            MemoryItem(key="d1", value="Relevant Python content", score=0.9),
        ]
        mock = MockRetriever({"test": items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)

        result = await loop.retrieve_with_correction("test")
        assert len(result.attempts) >= 1
        assert result.attempts[0].query == "test"
        assert result.attempts[0].state in (
            LoopState.GENERATE,
            LoopState.CORRECT,
            LoopState.DEGRADE,
        )

    @pytest.mark.asyncio
    async def test_correction_with_query_transformer(self):
        """The query transformer drives the rewritten query."""
        from agentkit.memory.query_transformer import TransformedQuery, QueryTransformerBase

        class MockTransformer(QueryTransformerBase):
            def __init__(self):
                self.transform_count = 0

            async def transform(self, query: str) -> TransformedQuery:
                self.transform_count += 1
                return TransformedQuery(
                    main_query=f"improved {query}",
                    sub_queries=[f"sub-{query}"],
                    original_query=query,
                )

        poor_items = [
            MemoryItem(key="d1", value="Unrelated", score=0.05),
        ]
        good_items = [
            MemoryItem(key="d2", value="Relevant Python content", score=0.9),
        ]
        mock = MockRetriever({
            "test": poor_items,
            "sub-test": good_items,
        })
        transformer = MockTransformer()
        loop = RAGSelfCorrectionLoop(
            retriever=mock,
            query_transformer=transformer,
            max_retries=3,
        )

        result = await loop.retrieve_with_correction("test")
        assert transformer.transform_count >= 1
        # The first sub-query must be what the second retrieval actually used.
        assert mock.queries[:2] == ["test", "sub-test"]
        # One recorded attempt per retrieval.
        assert len(result.attempts) == mock.call_count

    @pytest.mark.asyncio
    async def test_degraded_result_filters_irrelevant(self):
        """A degraded result keeps relevant items and drops irrelevant ones."""
        mixed_items = [
            MemoryItem(key="good", value="Python Django framework", score=0.8),
            MemoryItem(key="bad", value="Weather forecast", score=0.05),
        ]
        mock = MockRetriever({"query": mixed_items})
        loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=1)

        result = await loop.retrieve_with_correction("Python framework")
        # Even if degraded, should prefer relevant items
        if result.degraded:
            relevant_keys = [item.key for item in result.items]
            assert "good" in relevant_keys
            assert "bad" not in relevant_keys
|
||||||
Loading…
Reference in New Issue