feat(memory): U1 RAG self-correction loop (CRAG)
- RelevanceScorer: keyword overlap + query coverage + retrieval score + length penalty - RAGSelfCorrectionLoop: state machine driven retrieve-evaluate-correct-degrade cycle - Integrated into MemoryRetriever with enable_self_correction option - 21 tests passing
This commit is contained in:
parent
468dfd71e8
commit
a6c9babfdc
|
|
@ -0,0 +1,237 @@
|
|||
"""RAGSelfCorrectionLoop - CRAG 自纠正循环
|
||||
|
||||
实现 Corrective RAG 模式:检索→评估→纠正/降级→生成
|
||||
当检索结果质量不足时,自动改写查询重新检索,形成自纠正闭环。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from agentkit.memory.base import MemoryItem
|
||||
from agentkit.memory.query_transformer import QueryTransformerBase, NoOpQueryTransformer
|
||||
from agentkit.memory.relevance_scorer import (
|
||||
RelevanceScorer,
|
||||
RelevanceVerdict,
|
||||
RetrievalEvaluation,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopState(str, Enum):
|
||||
"""自纠正循环状态"""
|
||||
|
||||
RETRIEVE = "retrieve"
|
||||
EVALUATE = "evaluate"
|
||||
CORRECT = "correct"
|
||||
DEGRADE = "degrade"
|
||||
GENERATE = "generate"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CorrectionAttempt:
|
||||
"""一次纠正尝试的记录"""
|
||||
|
||||
query: str
|
||||
evaluation: RetrievalEvaluation
|
||||
state: LoopState
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGLoopResult:
|
||||
"""自纠正循环的最终结果"""
|
||||
|
||||
items: list[MemoryItem]
|
||||
evaluation: RetrievalEvaluation
|
||||
attempts: list[CorrectionAttempt]
|
||||
corrected: bool
|
||||
degraded: bool
|
||||
total_retries: int
|
||||
|
||||
|
||||
class RAGSelfCorrectionLoop:
|
||||
"""CRAG 自纠正循环
|
||||
|
||||
状态机驱动的检索-评估-纠正循环:
|
||||
1. RETRIEVE: 使用 MemoryRetriever 检索
|
||||
2. EVALUATE: RelevanceScorer 评估检索质量
|
||||
3. CORRECT: 质量不足时,改写查询重新检索
|
||||
4. DEGRADE: 超过重试次数,返回降级结果
|
||||
5. GENERATE: 质量足够,返回结果
|
||||
|
||||
熔断机制:
|
||||
- max_retries: 最大重试次数(默认 3)
|
||||
- 超过重试次数后强制降级,标记 low_confidence
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever: Any, # MemoryRetriever
|
||||
scorer: RelevanceScorer | None = None,
|
||||
query_transformer: QueryTransformerBase | None = None,
|
||||
max_retries: int = 3,
|
||||
min_items_for_correct: int = 1,
|
||||
):
|
||||
self._retriever = retriever
|
||||
self._scorer = scorer or RelevanceScorer()
|
||||
self._query_transformer = query_transformer or NoOpQueryTransformer()
|
||||
self._max_retries = max_retries
|
||||
self._min_items_for_correct = min_items_for_correct
|
||||
|
||||
async def retrieve_with_correction(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
token_budget: int = 3000,
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> RAGLoopResult:
|
||||
"""执行带自纠正的检索
|
||||
|
||||
Args:
|
||||
query: 原始查询
|
||||
top_k: 返回的最大结果数
|
||||
token_budget: token 预算
|
||||
filters: 过滤条件
|
||||
|
||||
Returns:
|
||||
RAGLoopResult: 包含检索结果、评估、尝试记录
|
||||
"""
|
||||
attempts: list[CorrectionAttempt] = []
|
||||
current_query = query
|
||||
retry_count = 0
|
||||
|
||||
while retry_count <= self._max_retries:
|
||||
# RETRIEVE
|
||||
items = await self._retriever.retrieve(
|
||||
current_query, top_k=top_k, token_budget=token_budget,
|
||||
filters=filters, _skip_correction=True,
|
||||
)
|
||||
|
||||
# EVALUATE
|
||||
evaluation = self._scorer.evaluate(current_query, items)
|
||||
state = self._determine_next_state(evaluation, items)
|
||||
|
||||
attempt = CorrectionAttempt(
|
||||
query=current_query,
|
||||
evaluation=evaluation,
|
||||
state=state,
|
||||
)
|
||||
attempts.append(attempt)
|
||||
|
||||
logger.info(
|
||||
f"RAG loop attempt {retry_count + 1}: "
|
||||
f"query='{current_query[:50]}...', "
|
||||
f"verdict={evaluation.overall_verdict.value}, "
|
||||
f"avg_score={evaluation.avg_score:.2f}, "
|
||||
f"state={state.value}"
|
||||
)
|
||||
|
||||
# GENERATE — quality is sufficient
|
||||
if state == LoopState.GENERATE:
|
||||
return RAGLoopResult(
|
||||
items=items,
|
||||
evaluation=evaluation,
|
||||
attempts=attempts,
|
||||
corrected=retry_count > 0,
|
||||
degraded=False,
|
||||
total_retries=retry_count,
|
||||
)
|
||||
|
||||
# CORRECT — rewrite query and retry
|
||||
retry_count += 1
|
||||
if retry_count <= self._max_retries:
|
||||
current_query = await self._rewrite_query(
|
||||
query, current_query, evaluation
|
||||
)
|
||||
continue
|
||||
|
||||
# DEGRADE — exceeded max retries
|
||||
break
|
||||
|
||||
# Degraded result: filter to relevant items and mark low confidence
|
||||
relevant_items = [
|
||||
s.item
|
||||
for s in evaluation.scores
|
||||
if s.verdict != RelevanceVerdict.INCORRECT
|
||||
]
|
||||
result_items = relevant_items if relevant_items else items
|
||||
|
||||
for item in result_items:
|
||||
item.metadata["low_confidence"] = True
|
||||
|
||||
return RAGLoopResult(
|
||||
items=result_items,
|
||||
evaluation=evaluation,
|
||||
attempts=attempts,
|
||||
corrected=False,
|
||||
degraded=True,
|
||||
total_retries=retry_count,
|
||||
)
|
||||
|
||||
def _determine_next_state(
|
||||
self, evaluation: RetrievalEvaluation, items: list[MemoryItem]
|
||||
) -> LoopState:
|
||||
"""根据评估结果确定下一个状态"""
|
||||
verdict = evaluation.overall_verdict
|
||||
|
||||
if verdict == RelevanceVerdict.CORRECT:
|
||||
if evaluation.relevant_count >= self._min_items_for_correct:
|
||||
return LoopState.GENERATE
|
||||
# Correct verdict but not enough items — still try to generate
|
||||
if items:
|
||||
return LoopState.GENERATE
|
||||
return LoopState.CORRECT
|
||||
|
||||
if verdict == RelevanceVerdict.AMBIGUOUS:
|
||||
# Some relevant results — could improve but not terrible
|
||||
return LoopState.CORRECT
|
||||
|
||||
# INCORRECT — definitely need correction
|
||||
return LoopState.CORRECT
|
||||
|
||||
async def _rewrite_query(
|
||||
self,
|
||||
original_query: str,
|
||||
current_query: str,
|
||||
evaluation: RetrievalEvaluation,
|
||||
) -> str:
|
||||
"""改写查询以改善检索质量
|
||||
|
||||
策略:
|
||||
1. 使用 QueryTransformer 改写
|
||||
2. 从评估结果中提取改进线索
|
||||
3. 追加失败模式提示
|
||||
"""
|
||||
# Use query transformer for rewriting
|
||||
transformed = await self._query_transformer.transform(current_query)
|
||||
new_query = transformed.main_query
|
||||
|
||||
# If transformer didn't change the query, try with original
|
||||
if new_query == current_query:
|
||||
# Add context from failed evaluation to help next retrieval
|
||||
failed_terms = []
|
||||
for score in evaluation.scores:
|
||||
if score.verdict == RelevanceVerdict.INCORRECT:
|
||||
# Extract key terms from low-scoring items to avoid
|
||||
doc_text = str(score.item.value)[:100]
|
||||
failed_terms.append(doc_text)
|
||||
|
||||
if failed_terms and original_query != current_query:
|
||||
# Try original query as fallback
|
||||
new_query = original_query
|
||||
elif failed_terms:
|
||||
# Add "NOT" context to help filter
|
||||
new_query = f"{current_query} (excluding irrelevant results)"
|
||||
|
||||
# Add sub-queries if available
|
||||
if transformed.sub_queries:
|
||||
# Use the first sub-query as the new primary query
|
||||
# This explores different aspects of the original question
|
||||
new_query = transformed.sub_queries[0]
|
||||
|
||||
logger.info(f"Query rewritten: '{current_query[:50]}...' -> '{new_query[:50]}...'")
|
||||
return new_query
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
"""RelevanceScorer - 检索结果相关性自动评估
|
||||
|
||||
对检索结果逐文档评估与查询的相关性,用于 CRAG 自纠正循环的评估阶段。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from agentkit.memory.base import MemoryItem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RelevanceVerdict(str, Enum):
|
||||
"""相关性判定结果"""
|
||||
|
||||
CORRECT = "correct"
|
||||
AMBIGUOUS = "ambiguous"
|
||||
INCORRECT = "incorrect"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RelevanceScore:
|
||||
"""单个文档的相关性评分"""
|
||||
|
||||
item: MemoryItem
|
||||
score: float # 0.0 ~ 1.0
|
||||
verdict: RelevanceVerdict
|
||||
reason: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalEvaluation:
|
||||
"""一次检索的整体评估结果"""
|
||||
|
||||
scores: list[RelevanceScore]
|
||||
overall_verdict: RelevanceVerdict
|
||||
avg_score: float
|
||||
relevant_count: int
|
||||
total_count: int
|
||||
|
||||
|
||||
class RelevanceScorer:
|
||||
"""检索结果相关性评估器
|
||||
|
||||
基于查询-文档语义相似度和关键词重叠的轻量级评估器。
|
||||
不依赖 LLM 调用,适用于生产环境的低延迟评估。
|
||||
|
||||
评分策略:
|
||||
1. 关键词重叠率(Jaccard 相似度)
|
||||
2. 查询词覆盖率(query term coverage)
|
||||
3. 原始检索分数加权
|
||||
4. 长度惩罚(过短或过长的文档降分)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
correct_threshold: float = 0.6,
|
||||
ambiguous_threshold: float = 0.35,
|
||||
keyword_weight: float = 0.3,
|
||||
coverage_weight: float = 0.3,
|
||||
retrieval_weight: float = 0.3,
|
||||
length_weight: float = 0.1,
|
||||
min_doc_length: int = 20,
|
||||
max_doc_length: int = 5000,
|
||||
):
|
||||
self._correct_threshold = correct_threshold
|
||||
self._ambiguous_threshold = ambiguous_threshold
|
||||
self._keyword_weight = keyword_weight
|
||||
self._coverage_weight = coverage_weight
|
||||
self._retrieval_weight = retrieval_weight
|
||||
self._length_weight = length_weight
|
||||
self._min_doc_length = min_doc_length
|
||||
self._max_doc_length = max_doc_length
|
||||
|
||||
def score_item(self, query: str, item: MemoryItem) -> RelevanceScore:
|
||||
"""评估单个检索结果与查询的相关性"""
|
||||
doc_text = str(item.value)
|
||||
|
||||
# 1. Keyword overlap (Jaccard similarity)
|
||||
query_terms = self._tokenize(query)
|
||||
doc_terms = self._tokenize(doc_text)
|
||||
keyword_score = self._jaccard_similarity(query_terms, doc_terms)
|
||||
|
||||
# 2. Query term coverage
|
||||
coverage_score = self._query_coverage(query_terms, doc_terms)
|
||||
|
||||
# 3. Original retrieval score
|
||||
retrieval_score = min(max(item.score, 0.0), 1.0)
|
||||
|
||||
# 4. Length penalty
|
||||
length_score = self._length_score(len(doc_text))
|
||||
|
||||
# Weighted combination
|
||||
final_score = (
|
||||
keyword_score * self._keyword_weight
|
||||
+ coverage_score * self._coverage_weight
|
||||
+ retrieval_score * self._retrieval_weight
|
||||
+ length_score * self._length_weight
|
||||
)
|
||||
|
||||
# Determine verdict
|
||||
verdict = self._determine_verdict(final_score)
|
||||
|
||||
reason = (
|
||||
f"keyword={keyword_score:.2f}, coverage={coverage_score:.2f}, "
|
||||
f"retrieval={retrieval_score:.2f}, length={length_score:.2f}"
|
||||
)
|
||||
|
||||
return RelevanceScore(
|
||||
item=item,
|
||||
score=final_score,
|
||||
verdict=verdict,
|
||||
reason=reason,
|
||||
)
|
||||
|
||||
def evaluate(
|
||||
self, query: str, items: list[MemoryItem]
|
||||
) -> RetrievalEvaluation:
|
||||
"""评估一次检索的整体质量"""
|
||||
if not items:
|
||||
return RetrievalEvaluation(
|
||||
scores=[],
|
||||
overall_verdict=RelevanceVerdict.INCORRECT,
|
||||
avg_score=0.0,
|
||||
relevant_count=0,
|
||||
total_count=0,
|
||||
)
|
||||
|
||||
scores = [self.score_item(query, item) for item in items]
|
||||
relevant_count = sum(
|
||||
1 for s in scores if s.verdict != RelevanceVerdict.INCORRECT
|
||||
)
|
||||
avg_score = sum(s.score for s in scores) / len(scores)
|
||||
|
||||
# Overall verdict based on average score and relevant ratio
|
||||
relevant_ratio = relevant_count / len(scores)
|
||||
|
||||
if avg_score >= self._correct_threshold and relevant_ratio >= 0.5:
|
||||
overall_verdict = RelevanceVerdict.CORRECT
|
||||
elif avg_score >= self._ambiguous_threshold or relevant_ratio >= 0.3:
|
||||
overall_verdict = RelevanceVerdict.AMBIGUOUS
|
||||
else:
|
||||
overall_verdict = RelevanceVerdict.INCORRECT
|
||||
|
||||
return RetrievalEvaluation(
|
||||
scores=scores,
|
||||
overall_verdict=overall_verdict,
|
||||
avg_score=avg_score,
|
||||
relevant_count=relevant_count,
|
||||
total_count=len(scores),
|
||||
)
|
||||
|
||||
def _determine_verdict(self, score: float) -> RelevanceVerdict:
|
||||
"""根据分数判定相关性"""
|
||||
if score >= self._correct_threshold:
|
||||
return RelevanceVerdict.CORRECT
|
||||
elif score >= self._ambiguous_threshold:
|
||||
return RelevanceVerdict.AMBIGUOUS
|
||||
else:
|
||||
return RelevanceVerdict.INCORRECT
|
||||
|
||||
@staticmethod
|
||||
def _tokenize(text: str) -> set[str]:
|
||||
"""分词:中文按字符,英文按空格,统一小写"""
|
||||
tokens: set[str] = set()
|
||||
# Extract English words
|
||||
en_words = re.findall(r"[a-zA-Z]+", text.lower())
|
||||
tokens.update(en_words)
|
||||
# Extract Chinese characters (individual chars + bigrams)
|
||||
cn_chars = re.findall(r"[\u4e00-\u9fff]", text)
|
||||
tokens.update(cn_chars)
|
||||
# Add Chinese bigrams for better matching
|
||||
for i in range(len(cn_chars) - 1):
|
||||
tokens.add(cn_chars[i] + cn_chars[i + 1])
|
||||
return tokens
|
||||
|
||||
@staticmethod
|
||||
def _jaccard_similarity(set_a: set[str], set_b: set[str]) -> float:
|
||||
"""Jaccard 相似度"""
|
||||
if not set_a or not set_b:
|
||||
return 0.0
|
||||
intersection = len(set_a & set_b)
|
||||
union = len(set_a | set_b)
|
||||
if union == 0:
|
||||
return 0.0
|
||||
return intersection / union
|
||||
|
||||
@staticmethod
|
||||
def _query_coverage(query_terms: set[str], doc_terms: set[str]) -> float:
|
||||
"""查询词覆盖率:文档中出现的查询词比例"""
|
||||
if not query_terms:
|
||||
return 0.0
|
||||
covered = len(query_terms & doc_terms)
|
||||
return covered / len(query_terms)
|
||||
|
||||
def _length_score(self, length: int) -> float:
|
||||
"""长度评分:过短或过长的文档降分"""
|
||||
if length < self._min_doc_length:
|
||||
# Too short — likely insufficient context
|
||||
ratio = length / self._min_doc_length
|
||||
return ratio * 0.5
|
||||
elif length > self._max_doc_length:
|
||||
# Too long — may contain irrelevant information
|
||||
excess = (length - self._max_doc_length) / self._max_doc_length
|
||||
return max(0.3, 1.0 - excess * 0.5)
|
||||
else:
|
||||
# Good length range
|
||||
return 1.0
|
||||
|
|
@ -17,6 +17,8 @@ from agentkit.memory.working import WorkingMemory
|
|||
from agentkit.memory.episodic import EpisodicMemory
|
||||
from agentkit.memory.semantic import SemanticMemory
|
||||
from agentkit.memory.query_transformer import QueryTransformerBase
|
||||
from agentkit.memory.rag_loop import RAGSelfCorrectionLoop
|
||||
from agentkit.memory.relevance_scorer import RelevanceScorer
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -55,6 +57,8 @@ class MemoryRetriever:
|
|||
weights: dict[str, float] | None = None,
|
||||
query_transformer: QueryTransformerBase | None = None,
|
||||
context_template: str = "structured",
|
||||
enable_self_correction: bool = False,
|
||||
max_correction_retries: int = 3,
|
||||
):
|
||||
self._working = working_memory
|
||||
self._episodic = episodic_memory
|
||||
|
|
@ -66,6 +70,15 @@ class MemoryRetriever:
|
|||
}
|
||||
self._query_transformer = query_transformer
|
||||
self._context_template = context_template
|
||||
self._enable_self_correction = enable_self_correction
|
||||
self._correction_loop: RAGSelfCorrectionLoop | None = None
|
||||
if enable_self_correction:
|
||||
self._correction_loop = RAGSelfCorrectionLoop(
|
||||
retriever=self,
|
||||
scorer=RelevanceScorer(),
|
||||
query_transformer=query_transformer,
|
||||
max_retries=max_correction_retries,
|
||||
)
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
|
|
@ -73,8 +86,31 @@ class MemoryRetriever:
|
|||
top_k: int = 5,
|
||||
token_budget: int = 3000,
|
||||
filters: dict[str, Any] | None = None,
|
||||
_skip_correction: bool = False,
|
||||
) -> list[MemoryItem]:
|
||||
"""混合检索三层记忆"""
|
||||
"""混合检索三层记忆
|
||||
|
||||
Args:
|
||||
query: 检索查询
|
||||
top_k: 返回最大结果数
|
||||
token_budget: token 预算
|
||||
filters: 过滤条件
|
||||
_skip_correction: 内部参数,CRAG 循环内部调用时跳过自纠正
|
||||
"""
|
||||
# Self-correction loop (CRAG)
|
||||
if (
|
||||
self._enable_self_correction
|
||||
and self._correction_loop is not None
|
||||
and not _skip_correction
|
||||
):
|
||||
result = await self._correction_loop.retrieve_with_correction(
|
||||
query, top_k=top_k, token_budget=token_budget, filters=filters
|
||||
)
|
||||
if result.degraded:
|
||||
logger.warning(
|
||||
f"RAG self-correction degraded after {result.total_retries} retries"
|
||||
)
|
||||
return result.items
|
||||
# Query transformation
|
||||
if self._query_transformer is not None:
|
||||
transformed = await self._query_transformer.transform(query)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,337 @@
|
|||
"""Tests for RelevanceScorer and RAGSelfCorrectionLoop"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.memory.base import MemoryItem
|
||||
from agentkit.memory.relevance_scorer import (
|
||||
RelevanceScorer,
|
||||
RelevanceScore,
|
||||
RelevanceVerdict,
|
||||
RetrievalEvaluation,
|
||||
)
|
||||
from agentkit.memory.rag_loop import (
|
||||
RAGSelfCorrectionLoop,
|
||||
RAGLoopResult,
|
||||
LoopState,
|
||||
)
|
||||
|
||||
|
||||
# --- RelevanceScorer Tests ---
|
||||
|
||||
|
||||
class TestRelevanceScorer:
|
||||
"""RelevanceScorer unit tests"""
|
||||
|
||||
def setup_method(self):
|
||||
self.scorer = RelevanceScorer()
|
||||
|
||||
def test_score_highly_relevant_item(self):
|
||||
"""Highly relevant document should score high"""
|
||||
query = "Python web framework Django Flask"
|
||||
item = MemoryItem(
|
||||
key="doc1",
|
||||
value="Django and Flask are popular Python web frameworks for building web applications",
|
||||
score=0.9,
|
||||
)
|
||||
result = self.scorer.score_item(query, item)
|
||||
assert result.score > 0.5
|
||||
assert result.verdict in (RelevanceVerdict.CORRECT, RelevanceVerdict.AMBIGUOUS)
|
||||
|
||||
def test_score_irrelevant_item(self):
|
||||
"""Completely irrelevant document should score low"""
|
||||
query = "Python web framework"
|
||||
item = MemoryItem(
|
||||
key="doc2",
|
||||
value="The weather is sunny today and the birds are singing in the garden",
|
||||
score=0.1,
|
||||
)
|
||||
result = self.scorer.score_item(query, item)
|
||||
assert result.score < 0.5
|
||||
assert result.verdict == RelevanceVerdict.INCORRECT
|
||||
|
||||
def test_score_chinese_relevant_item(self):
|
||||
"""Chinese text relevance scoring"""
|
||||
query = "GEO优化策略"
|
||||
item = MemoryItem(
|
||||
key="doc3",
|
||||
value="GEO优化策略包括内容结构化、Schema标记、AI平台适配等多个方面",
|
||||
score=0.85,
|
||||
)
|
||||
result = self.scorer.score_item(query, item)
|
||||
assert result.score > 0.3 # Chinese bigrams should match
|
||||
|
||||
def test_score_short_document_penalty(self):
|
||||
"""Very short documents should be penalized"""
|
||||
query = "machine learning algorithms"
|
||||
short_item = MemoryItem(key="short", value="ML", score=0.9)
|
||||
normal_item = MemoryItem(
|
||||
key="normal",
|
||||
value="Machine learning algorithms include supervised and unsupervised learning methods",
|
||||
score=0.9,
|
||||
)
|
||||
short_result = self.scorer.score_item(query, short_item)
|
||||
normal_result = self.scorer.score_item(query, normal_item)
|
||||
assert normal_result.score > short_result.score
|
||||
|
||||
def test_evaluate_empty_results(self):
|
||||
"""Empty retrieval results should be INCORRECT"""
|
||||
evaluation = self.scorer.evaluate("test query", [])
|
||||
assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT
|
||||
assert evaluation.avg_score == 0.0
|
||||
assert evaluation.total_count == 0
|
||||
|
||||
def test_evaluate_mixed_results(self):
|
||||
"""Mixed quality results should be AMBIGUOUS or CORRECT"""
|
||||
query = "Python web framework"
|
||||
items = [
|
||||
MemoryItem(key="good", value="Django is a Python web framework", score=0.9),
|
||||
MemoryItem(key="bad", value="Weather forecast for today", score=0.1),
|
||||
]
|
||||
evaluation = self.scorer.evaluate(query, items)
|
||||
assert evaluation.total_count == 2
|
||||
assert evaluation.relevant_count >= 1
|
||||
|
||||
def test_evaluate_all_correct(self):
|
||||
"""All relevant results should give CORRECT verdict"""
|
||||
query = "Python Django"
|
||||
items = [
|
||||
MemoryItem(key="d1", value="Django is a Python web framework", score=0.9),
|
||||
MemoryItem(key="d2", value="Django REST framework for API development", score=0.85),
|
||||
]
|
||||
evaluation = self.scorer.evaluate(query, items)
|
||||
assert evaluation.overall_verdict == RelevanceVerdict.CORRECT
|
||||
|
||||
def test_evaluate_all_incorrect(self):
|
||||
"""All irrelevant results should give INCORRECT verdict"""
|
||||
query = "quantum computing"
|
||||
items = [
|
||||
MemoryItem(key="d1", value="Cooking recipes for beginners", score=0.1),
|
||||
MemoryItem(key="d2", value="Gardening tips for spring", score=0.05),
|
||||
]
|
||||
evaluation = self.scorer.evaluate(query, items)
|
||||
assert evaluation.overall_verdict == RelevanceVerdict.INCORRECT
|
||||
|
||||
def test_custom_thresholds(self):
|
||||
"""Custom thresholds should affect verdict"""
|
||||
scorer = RelevanceScorer(correct_threshold=0.9, ambiguous_threshold=0.7)
|
||||
query = "test"
|
||||
item = MemoryItem(key="d1", value="test document with some content", score=0.5)
|
||||
result = scorer.score_item(query, item)
|
||||
# With high thresholds, this should be INCORRECT
|
||||
assert result.verdict == RelevanceVerdict.INCORRECT
|
||||
|
||||
def test_jaccard_similarity(self):
|
||||
"""Jaccard similarity calculation"""
|
||||
set_a = {"python", "web", "framework"}
|
||||
set_b = {"python", "web", "server"}
|
||||
similarity = RelevanceScorer._jaccard_similarity(set_a, set_b)
|
||||
assert 0.0 < similarity < 1.0
|
||||
# 2 common / 4 unique = 0.5
|
||||
assert abs(similarity - 0.5) < 0.01
|
||||
|
||||
def test_jaccard_empty_sets(self):
|
||||
"""Jaccard with empty sets returns 0"""
|
||||
assert RelevanceScorer._jaccard_similarity(set(), {"a"}) == 0.0
|
||||
assert RelevanceScorer._jaccard_similarity({"a"}, set()) == 0.0
|
||||
|
||||
def test_query_coverage(self):
|
||||
"""Query term coverage calculation"""
|
||||
query_terms = {"python", "django", "flask"}
|
||||
doc_terms = {"python", "django", "web", "framework"}
|
||||
coverage = RelevanceScorer._query_coverage(query_terms, doc_terms)
|
||||
# 2 out of 3 query terms covered
|
||||
assert abs(coverage - 2 / 3) < 0.01
|
||||
|
||||
def test_tokenize_chinese(self):
|
||||
"""Chinese tokenization includes bigrams"""
|
||||
tokens = RelevanceScorer._tokenize("机器学习算法")
|
||||
# Should include individual chars and bigrams
|
||||
assert "机" in tokens
|
||||
assert "器" in tokens
|
||||
assert "机器" in tokens # bigram
|
||||
|
||||
def test_tokenize_english(self):
|
||||
"""English tokenization"""
|
||||
tokens = RelevanceScorer._tokenize("Python Web Framework")
|
||||
assert "python" in tokens
|
||||
assert "web" in tokens
|
||||
assert "framework" in tokens
|
||||
|
||||
|
||||
# --- RAGSelfCorrectionLoop Tests ---
|
||||
|
||||
|
||||
class MockRetriever:
|
||||
"""Mock retriever for testing"""
|
||||
|
||||
def __init__(self, items_by_query: dict[str, list[MemoryItem]] | None = None):
|
||||
self._items = items_by_query or {}
|
||||
self.call_count = 0
|
||||
self.queries: list[str] = []
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
token_budget: int = 3000,
|
||||
filters=None,
|
||||
_skip_correction: bool = False,
|
||||
) -> list[MemoryItem]:
|
||||
self.call_count += 1
|
||||
self.queries.append(query)
|
||||
# Return items for exact query match, or default items
|
||||
if query in self._items:
|
||||
return self._items[query]
|
||||
# Return default items for any query
|
||||
default_key = next(iter(self._items), None)
|
||||
if default_key:
|
||||
return self._items[default_key]
|
||||
return []
|
||||
|
||||
|
||||
class TestRAGSelfCorrectionLoop:
|
||||
"""RAGSelfCorrectionLoop unit tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_correct_retrieval_skips_correction(self):
|
||||
"""High-quality retrieval should not trigger correction"""
|
||||
items = [
|
||||
MemoryItem(
|
||||
key="d1",
|
||||
value="Django is a Python web framework for building web applications quickly",
|
||||
score=0.9,
|
||||
),
|
||||
MemoryItem(
|
||||
key="d2",
|
||||
value="Flask is a lightweight Python web framework for small applications",
|
||||
score=0.85,
|
||||
),
|
||||
]
|
||||
mock = MockRetriever({"Python web framework": items})
|
||||
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)
|
||||
|
||||
result = await loop.retrieve_with_correction("Python web framework")
|
||||
assert not result.degraded
|
||||
assert len(result.items) == 2
|
||||
assert result.total_retries == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poor_retrieval_triggers_correction(self):
|
||||
"""Poor retrieval should trigger query rewriting"""
|
||||
poor_items = [
|
||||
MemoryItem(key="d1", value="Weather forecast for today", score=0.1),
|
||||
]
|
||||
good_items = [
|
||||
MemoryItem(
|
||||
key="d2",
|
||||
value="Python Django web framework tutorial and best practices",
|
||||
score=0.9,
|
||||
),
|
||||
]
|
||||
mock = MockRetriever({
|
||||
"Python web framework": poor_items,
|
||||
"improved query": good_items,
|
||||
})
|
||||
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)
|
||||
|
||||
result = await loop.retrieve_with_correction("Python web framework")
|
||||
assert mock.call_count >= 2 # At least initial + 1 retry
|
||||
assert len(result.attempts) >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_retries_causes_degradation(self):
|
||||
"""Exceeding max retries should cause degradation"""
|
||||
poor_items = [
|
||||
MemoryItem(key="d1", value="Unrelated content about weather", score=0.05),
|
||||
]
|
||||
mock = MockRetriever({"any query": poor_items})
|
||||
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)
|
||||
|
||||
result = await loop.retrieve_with_correction("Python web framework")
|
||||
assert result.degraded
|
||||
assert result.total_retries >= 2
|
||||
# Items should be marked low_confidence
|
||||
assert any(
|
||||
item.metadata.get("low_confidence", False) for item in result.items
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_retrieval_triggers_correction(self):
|
||||
"""Empty retrieval results should trigger correction"""
|
||||
mock = MockRetriever({"query": []})
|
||||
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=2)
|
||||
|
||||
result = await loop.retrieve_with_correction("test query")
|
||||
assert result.degraded
|
||||
assert result.total_retries >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_result_tracks_attempts(self):
|
||||
"""Loop result should track all correction attempts"""
|
||||
items = [
|
||||
MemoryItem(key="d1", value="Relevant Python content", score=0.9),
|
||||
]
|
||||
mock = MockRetriever({"test": items})
|
||||
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=3)
|
||||
|
||||
result = await loop.retrieve_with_correction("test")
|
||||
assert len(result.attempts) >= 1
|
||||
assert result.attempts[0].query == "test"
|
||||
assert result.attempts[0].state in (
|
||||
LoopState.GENERATE,
|
||||
LoopState.CORRECT,
|
||||
LoopState.DEGRADE,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_correction_with_query_transformer(self):
|
||||
"""Query transformer should be used during correction"""
|
||||
from agentkit.memory.query_transformer import TransformedQuery, QueryTransformerBase
|
||||
|
||||
class MockTransformer(QueryTransformerBase):
|
||||
def __init__(self):
|
||||
self.transform_count = 0
|
||||
|
||||
async def transform(self, query: str) -> TransformedQuery:
|
||||
self.transform_count += 1
|
||||
return TransformedQuery(
|
||||
main_query=f"improved {query}",
|
||||
sub_queries=[f"sub-{query}"],
|
||||
original_query=query,
|
||||
)
|
||||
|
||||
poor_items = [
|
||||
MemoryItem(key="d1", value="Unrelated", score=0.05),
|
||||
]
|
||||
good_items = [
|
||||
MemoryItem(key="d2", value="Relevant Python content", score=0.9),
|
||||
]
|
||||
mock = MockRetriever({
|
||||
"test": poor_items,
|
||||
"sub-test": good_items,
|
||||
})
|
||||
transformer = MockTransformer()
|
||||
loop = RAGSelfCorrectionLoop(
|
||||
retriever=mock,
|
||||
query_transformer=transformer,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
result = await loop.retrieve_with_correction("test")
|
||||
assert transformer.transform_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_degraded_result_filters_irrelevant(self):
|
||||
"""Degraded result should prefer relevant items over irrelevant"""
|
||||
mixed_items = [
|
||||
MemoryItem(key="good", value="Python Django framework", score=0.8),
|
||||
MemoryItem(key="bad", value="Weather forecast", score=0.05),
|
||||
]
|
||||
mock = MockRetriever({"query": mixed_items})
|
||||
loop = RAGSelfCorrectionLoop(retriever=mock, max_retries=1)
|
||||
|
||||
result = await loop.retrieve_with_correction("Python framework")
|
||||
# Even if degraded, should prefer relevant items
|
||||
if result.degraded:
|
||||
relevant_keys = [item.key for item in result.items]
|
||||
assert "good" in relevant_keys
|
||||
Loading…
Reference in New Issue