fischer-agentkit/tests/unit/rag_platform/test_rerank.py

345 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

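"""U5 tests — rerank model integration (Cohere / BGE / none).

Scenarios:
1. With provider="none", results are returned unchanged.
2. An empty result list is returned unchanged.
3. Cohere rerank calls LlamaIndex CohereRerank and reorders results by relevance.
4. BGE rerank calls SentenceTransformerRerank and reorders results by relevance.
5. When reranking fails, the original results are returned as a fallback.
6. After reranking, result scores are replaced with the rerank scores.
7. RerankConfig defaults and field validation.
"""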
from __future__ import annotations
import sys
from unittest.mock import MagicMock
from agentkit.rag_platform.models import QueryResult
from agentkit.rag_platform.rerank import RerankConfig, Reranker
# ---------------------------------------------------------------------------
# llama_index 模块 mock — 测试环境可能未安装 llama_index
# ---------------------------------------------------------------------------
def _setup_llama_index_mocks() -> None:
    """Inject mock ``llama_index`` modules into ``sys.modules`` so imports succeed.

    Mocks are installed only when the real ``llama_index.core.schema`` cannot
    be imported, so environments with LlamaIndex installed are not polluted.
    The probe targets ``llama_index.core.schema`` (the module this test file
    imports from) rather than the bare ``llama_index`` package, because
    ``llama_index`` can be importable as a namespace package even when the
    core distribution is missing.

    Entries already present in ``sys.modules`` are left untouched. Each mock
    that is installed is also attached as an attribute of its parent module,
    so attribute access such as ``llama_index.core.schema`` resolves to the
    same configured mock as ``from llama_index.core.schema import ...``.
    """
    try:
        import llama_index.core.schema  # noqa: F401
    except ImportError:
        pass
    else:
        return  # Real module available; no mocks needed.

    class MockTextNode:
        """Minimal stand-in for ``llama_index.core.schema.TextNode``."""

        def __init__(self, id_=None, text="", metadata=None):
            self.node_id = id_
            self.text = text
            self.metadata = metadata or {}

        def get_content(self):
            return self.text

    class MockNodeWithScore:
        """Minimal stand-in for ``llama_index.core.schema.NodeWithScore``."""

        def __init__(self, node=None, score=None):
            self.node = node
            self.score = score

    mock_schema = MagicMock()
    mock_schema.TextNode = MockTextNode
    mock_schema.NodeWithScore = MockNodeWithScore

    # CohereRerank / SentenceTransformerRerank are configured by individual tests.
    mock_cohere = MagicMock()
    mock_cohere.CohereRerank = MagicMock()
    mock_st = MagicMock()
    mock_st.SentenceTransformerRerank = MagicMock()

    # Parents are listed before children so each child can be bound to its parent.
    mocks = {
        "llama_index": MagicMock(),
        "llama_index.core": MagicMock(),
        "llama_index.core.schema": mock_schema,
        "llama_index.postprocessor": MagicMock(),
        "llama_index.postprocessor.cohere_rerank": mock_cohere,
        "llama_index.postprocessor.sentence_transformers_rerank": mock_st,
    }
    for name, module in mocks.items():
        installed = sys.modules.setdefault(name, module)
        if installed is not module:
            continue  # Keep a module that was already loaded (e.g. by the probe).
        parent_name, _, child_name = name.rpartition(".")
        if parent_name:
            # Pre-seeding sys.modules bypasses the import system's usual
            # parent-attribute binding, so do it explicitly.
            setattr(sys.modules[parent_name], child_name, module)


_setup_llama_index_mocks()
# ---------------------------------------------------------------------------
# 测试辅助函数
# ---------------------------------------------------------------------------
def _make_query_result(
    chunk_id: str,
    content: str,
    score: float = 0.5,
    document_id: str = "doc1",
    kb_id: str = "kb1",
) -> QueryResult:
    """Build a ``QueryResult`` fixture for the rerank tests.

    ``document_id`` and ``kb_id`` are set both as top-level fields and inside
    ``metadata``, so tests can check either location.
    """
    source_ids = {"document_id": document_id, "kb_id": kb_id}
    return QueryResult(
        chunk_id=chunk_id,
        content=content,
        score=score,
        metadata=dict(source_ids),
        **source_ids,
    )
def _make_mock_reranker(reranked_order: list[tuple[str, float]]) -> MagicMock:
    """Create a mock LlamaIndex reranker that reports a fixed ranking.

    Args:
        reranked_order: ``[(node_id, new_score), ...]`` — the order and scores
            the reranker should return.

    Returns:
        A ``MagicMock`` whose ``postprocess_nodes`` method returns the matching
        list of ``NodeWithScore`` objects. The same method mock is also exposed
        as ``postprocessnodes``, so a caller using either name gets the
        configured result.
    """
    from llama_index.core.schema import NodeWithScore, TextNode

    ranked_nodes = [
        NodeWithScore(node=TextNode(id_=node_id, text=""), score=score)
        for node_id, score in reranked_order
    ]
    postprocess = MagicMock(return_value=ranked_nodes)

    mock = MagicMock()
    # LlamaIndex node postprocessors expose ``postprocess_nodes``; the alias
    # keeps the previously mocked name working.
    # NOTE(review): confirm which name Reranker calls and drop the alias once settled.
    mock.postprocess_nodes = postprocess
    mock.postprocessnodes = postprocess
    return mock
# ---------------------------------------------------------------------------
# RerankConfig 测试
# ---------------------------------------------------------------------------
class TestRerankConfig:
    """Tests for the ``RerankConfig`` model."""

    def test_defaults(self):
        """Defaults to the "none" provider, which carries no data-export risk."""
        cfg = RerankConfig()
        assert cfg.provider == "none"
        for unset in ("api_key", "base_url"):
            assert getattr(cfg, unset) is None
        assert cfg.top_n == 5
        assert cfg.data_export_warning is False

    def test_cohere_config(self):
        """Cohere config — the data-export risk must be flagged."""
        settings = {"provider": "cohere", "api_key": "test-key", "top_n": 3}
        cfg = RerankConfig(**settings, data_export_warning=True)
        for name, value in settings.items():
            assert getattr(cfg, name) == value
        assert cfg.data_export_warning is True

    def test_bge_config(self):
        """BGE config — deployed locally, so no data leaves the deployment."""
        settings = {
            "provider": "bge",
            "base_url": "http://localhost:9997",
            "model_name": "bge-reranker-base",
            "top_n": 10,
        }
        cfg = RerankConfig(**settings)
        for name, value in settings.items():
            assert getattr(cfg, name) == value
# ---------------------------------------------------------------------------
# Reranker 测试
# ---------------------------------------------------------------------------
class TestRerankerNone:
    """Tests for the pass-through paths: provider="none" and empty input."""

    async def test_none_provider_returns_as_is(self):
        """With provider="none", results come back unchanged (no reranker is used)."""
        reranker = Reranker(RerankConfig(provider="none"))
        results = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]

        reranked = await reranker.rerank("query", results)

        assert len(reranked) == 2
        # Order and scores are untouched.
        assert [item.chunk_id for item in reranked] == ["c1", "c2"]
        assert [item.score for item in reranked] == [0.9, 0.5]

    async def test_empty_results_returns_empty(self):
        """An empty result list comes back as an empty list."""
        reranker = Reranker(RerankConfig(provider="cohere", api_key="key"))
        assert await reranker.rerank("query", []) == []
class TestRerankerCohere:
    """Tests for the Cohere rerank provider."""

    async def test_rerank_reorders_results(self):
        """Reranking reorders results by relevance, overriding the original order."""
        # Original order: c1 (score=0.9), c2 (score=0.5).
        # Reranked:       c2 (score=0.95), c1 (score=0.80) — c2 is more relevant.
        cfg = RerankConfig(provider="cohere", api_key="test-key", top_n=2)
        reranker = Reranker(cfg)
        # Inject a mock reranker so no real Cohere call is made.
        reranker._reranker = _make_mock_reranker([("c2", 0.95), ("c1", 0.80)])
        results = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]

        reranked = await reranker.rerank("query", results)

        assert len(reranked) == 2
        # c2 now comes first.
        assert reranked[0].chunk_id == "c2"
        assert reranked[1].chunk_id == "c1"
        # Scores are replaced with the rerank scores.
        assert reranked[0].score == 0.95
        assert reranked[1].score == 0.80

    async def test_rerank_preserves_metadata(self):
        """Reranking keeps the original QueryResult metadata (document_id, kb_id)."""
        cfg = RerankConfig(provider="cohere", api_key="test-key", top_n=1)
        reranker = Reranker(cfg)
        reranker._reranker = _make_mock_reranker([("c1", 0.99)])
        results = [
            _make_query_result("c1", "content 1", 0.5, document_id="doc-99", kb_id="kb-99"),
        ]

        reranked = await reranker.rerank("query", results)

        assert len(reranked) == 1
        assert reranked[0].document_id == "doc-99"
        assert reranked[0].kb_id == "kb-99"
        assert reranked[0].metadata["document_id"] == "doc-99"
        assert reranked[0].metadata["kb_id"] == "kb-99"

    async def test_rerank_failure_falls_back(self):
        """When the reranker raises, the original results are returned unchanged."""
        cfg = RerankConfig(provider="cohere", api_key="test-key", top_n=5)
        reranker = Reranker(cfg)
        # Both method names raise, so the failure path is exercised whichever
        # one Reranker calls (LlamaIndex's API name is ``postprocess_nodes``).
        failing_call = MagicMock(side_effect=RuntimeError("API error"))
        mock_reranker = MagicMock()
        mock_reranker.postprocess_nodes = failing_call
        mock_reranker.postprocessnodes = failing_call
        reranker._reranker = mock_reranker
        results = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]

        reranked = await reranker.rerank("query", results)

        # Fallback: original order and original scores.
        assert len(reranked) == 2
        assert [item.chunk_id for item in reranked] == ["c1", "c2"]
        assert [item.score for item in reranked] == [0.9, 0.5]

    async def test_cohere_requires_api_key(self):
        """The Cohere provider raises ValueError when api_key is missing."""
        cfg = RerankConfig(provider="cohere", api_key=None)
        reranker = Reranker(cfg)
        results = [_make_query_result("c1", "content", 0.5)]
        try:
            await reranker.rerank("query", results)
        except ValueError as exc:
            assert "api_key" in str(exc)
        else:
            raise AssertionError("Expected ValueError")
class TestRerankerBGE:
    """Tests for the BGE rerank provider."""

    async def test_bge_rerank_reorders_results(self):
        """BGE reranking also reorders results by relevance."""
        reranker = Reranker(RerankConfig(provider="bge", top_n=2))
        reranker._reranker = _make_mock_reranker([("c2", 0.88), ("c1", 0.55)])
        candidates = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]

        reranked = await reranker.rerank("query", candidates)

        assert len(reranked) == 2
        expected = [("c2", 0.88), ("c1", 0.55)]
        for item, (chunk_id, score) in zip(reranked, expected):
            assert item.chunk_id == chunk_id
            assert item.score == score

    async def test_bge_no_api_key_required(self):
        """A locally deployed BGE reranker works without an api_key."""
        reranker = Reranker(RerankConfig(provider="bge", top_n=5))
        # Inject a mock reranker; the test checks that nothing raises.
        reranker._reranker = _make_mock_reranker([("c1", 0.9)])
        candidates = [_make_query_result("c1", "content", 0.5)]

        reranked = await reranker.rerank("query", candidates)

        assert len(reranked) == 1
        assert reranked[0].score == 0.9
class TestRerankerUnsupportedProvider:
    """Tests for a provider name outside the supported set."""

    async def test_unsupported_provider_raises(self):
        """An unsupported provider raises ValueError."""
        reranker = Reranker(RerankConfig(provider="invalid_provider"))
        candidates = [_make_query_result("c1", "content", 0.5)]
        try:
            await reranker.rerank("query", candidates)
        except ValueError as exc:
            assert "Unsupported" in str(exc)
        else:
            raise AssertionError("Expected ValueError")