345 lines
11 KiB
Python
345 lines
11 KiB
Python
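"""U5 tests — rerank model integration (Cohere / BGE / none).

Scenarios covered:
1. provider="none" returns the results unchanged
2. An empty result list is returned unchanged
3. Cohere rerank calls LlamaIndex CohereRerank and reorders by relevance
4. BGE rerank calls SentenceTransformerRerank and reorders by relevance
5. A rerank failure degrades to returning the original results
6. After reranking, result scores are replaced by the rerank scores
7. RerankConfig defaults and field validation
"""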
|
||
|
||
from __future__ import annotations
|
||
|
||
import sys
|
||
from unittest.mock import MagicMock
|
||
|
||
from agentkit.rag_platform.models import QueryResult
|
||
from agentkit.rag_platform.rerank import RerankConfig, Reranker
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# llama_index module mocks — llama_index may not be installed in the test env
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _setup_llama_index_mocks():
    """Register stand-in ``llama_index`` modules in ``sys.modules``.

    When the real ``llama_index`` package can be imported this function does
    nothing, so an environment with the real package installed is never
    polluted with mocks. Otherwise ``MagicMock`` modules are registered for::

        llama_index
        llama_index.core
        llama_index.core.schema
        llama_index.postprocessor
        llama_index.postprocessor.cohere_rerank
        llama_index.postprocessor.sentence_transformers_rerank

    ``llama_index.core.schema`` exposes small ``TextNode`` / ``NodeWithScore``
    stand-in classes; ``CohereRerank`` and ``SentenceTransformerRerank`` are
    plain ``MagicMock`` objects that individual tests may configure.

    Each submodule mock is also bound as an attribute of its parent mock, the
    way the real import system binds submodules. Without that binding,
    ``from llama_index.core import schema`` (or ``import a.b.c as x``) would
    resolve through ``getattr`` on a ``MagicMock`` and silently receive a
    fresh auto-created child instead of the module registered here.
    """
    try:
        import llama_index  # noqa: F401
    except ImportError:
        pass  # not installed: fall through and install the mocks
    else:
        return  # the real package is available, no mocking needed

    class MockTextNode:
        """Minimal ``TextNode`` stand-in holding an id, text and metadata."""

        def __init__(self, id_=None, text="", metadata=None):
            self.node_id = id_
            self.text = text
            self.metadata = metadata or {}

        def get_content(self):
            """Return the node text."""
            return self.text

    class MockNodeWithScore:
        """Minimal ``NodeWithScore`` stand-in pairing a node with a score."""

        def __init__(self, node=None, score=None):
            self.node = node
            self.score = score

    package = MagicMock()
    core = MagicMock()
    schema = MagicMock()
    postprocessor = MagicMock()
    cohere_rerank = MagicMock()
    st_rerank = MagicMock()

    schema.TextNode = MockTextNode
    schema.NodeWithScore = MockNodeWithScore

    # The reranker classes stay generic mocks; tests configure them as needed.
    cohere_rerank.CohereRerank = MagicMock()
    st_rerank.SentenceTransformerRerank = MagicMock()

    # Bind every submodule to its parent so attribute-based imports reach the
    # same objects that are registered in sys.modules below.
    package.core = core
    core.schema = schema
    package.postprocessor = postprocessor
    postprocessor.cohere_rerank = cohere_rerank
    postprocessor.sentence_transformers_rerank = st_rerank

    sys.modules["llama_index"] = package
    sys.modules["llama_index.core"] = core
    sys.modules["llama_index.core.schema"] = schema
    sys.modules["llama_index.postprocessor"] = postprocessor
    sys.modules["llama_index.postprocessor.cohere_rerank"] = cohere_rerank
    sys.modules["llama_index.postprocessor.sentence_transformers_rerank"] = st_rerank
|
||
|
||
|
||
# Install the llama_index mocks at module import time, before any test runs.
# NOTE(review): this call comes after the `agentkit.rag_platform.rerank` import
# above, so it presumably relies on that module importing llama_index lazily
# rather than at its own import time — confirm.
_setup_llama_index_mocks()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test helper functions
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_query_result(
    chunk_id: str,
    content: str,
    score: float = 0.5,
    document_id: str = "doc1",
    kb_id: str = "kb1",
) -> QueryResult:
    """Build a ``QueryResult`` for use in tests.

    The document id and knowledge-base id are supplied twice: as top-level
    fields and inside the ``metadata`` dict.
    """
    origin = {"document_id": document_id, "kb_id": kb_id}
    return QueryResult(
        chunk_id=chunk_id,
        content=content,
        score=score,
        metadata=dict(origin),
        **origin,
    )
|
||
|
||
|
||
def _make_mock_reranker(reranked_order: list[tuple[str, float]]):
    """Build a mock LlamaIndex reranker with a canned reranking outcome.

    Args:
        reranked_order: ``(node_id, new_score)`` pairs, listed in the order
            the reranker should report them.

    Returns:
        A ``MagicMock`` whose ``postprocessnodes`` attribute is a mock that
        returns one ``NodeWithScore`` per pair, in the given order. Each one
        wraps a ``TextNode`` with that id and empty text.
    """
    from llama_index.core.schema import NodeWithScore, TextNode

    scored_nodes = []
    for node_id, new_score in reranked_order:
        text_node = TextNode(id_=node_id, text="")
        scored_nodes.append(NodeWithScore(node=text_node, score=new_score))

    fake_reranker = MagicMock()
    # NOTE(review): LlamaIndex postprocessors expose `postprocess_nodes`;
    # confirm `postprocessnodes` is the attribute `Reranker` actually calls.
    fake_reranker.postprocessnodes = MagicMock(return_value=scored_nodes)
    return fake_reranker
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# RerankConfig tests
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestRerankConfig:
    """Tests for the ``RerankConfig`` model."""

    def test_defaults(self):
        """Defaults: the "none" provider, with no data-export risk flagged."""
        config = RerankConfig()

        assert config.provider == "none"
        assert config.top_n == 5
        # Optional connection settings are unset by default.
        assert config.api_key is None
        assert config.base_url is None
        assert config.data_export_warning is False

    def test_cohere_config(self):
        """Cohere config — the data-export risk has to be flagged."""
        settings = {
            "provider": "cohere",
            "api_key": "test-key",
            "top_n": 3,
            "data_export_warning": True,
        }
        config = RerankConfig(**settings)

        assert config.provider == "cohere"
        assert config.api_key == "test-key"
        assert config.top_n == 3
        assert config.data_export_warning is True

    def test_bge_config(self):
        """BGE config — local deployment, no data leaves the environment."""
        settings = {
            "provider": "bge",
            "base_url": "http://localhost:9997",
            "model_name": "bge-reranker-base",
            "top_n": 10,
        }
        config = RerankConfig(**settings)

        # Every value passed in must be readable back from the model.
        for field_name, expected in settings.items():
            assert getattr(config, field_name) == expected
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Reranker tests
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestRerankerNone:
    """Tests for provider="none" and for empty input."""

    async def test_none_provider_returns_as_is(self):
        """With provider="none" the results come back unchanged (no reranker call)."""
        reranker = Reranker(RerankConfig(provider="none"))
        candidates = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]

        output = await reranker.rerank("query", candidates)

        assert len(output) == 2
        # Order is untouched ...
        assert [item.chunk_id for item in output] == ["c1", "c2"]
        # ... and so are the scores.
        assert [item.score for item in output] == [0.9, 0.5]

    async def test_empty_results_returns_empty(self):
        """An empty result list is returned as an empty list."""
        reranker = Reranker(RerankConfig(provider="cohere", api_key="key"))

        output = await reranker.rerank("query", [])

        assert output == []
|
||
|
||
|
||
class TestRerankerCohere:
    """Tests for the Cohere rerank path."""

    async def test_rerank_reorders_results(self):
        """Reranking reorders results by relevance, changing the input order."""
        # Input order:    c1 (score=0.9),  c2 (score=0.5)
        # Reranked order: c2 (score=0.95), c1 (score=0.80), c2 being more relevant
        reranker = Reranker(
            RerankConfig(provider="cohere", api_key="test-key", top_n=2)
        )
        # Inject the mock in place of a real LlamaIndex reranker.
        reranker._reranker = _make_mock_reranker([("c2", 0.95), ("c1", 0.80)])

        candidates = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]
        output = await reranker.rerank("query", candidates)

        assert len(output) == 2
        # c2 now comes first.
        assert [item.chunk_id for item in output] == ["c2", "c1"]
        # Scores were replaced by the rerank scores.
        assert [item.score for item in output] == [0.95, 0.80]

    async def test_rerank_preserves_metadata(self):
        """Reranking keeps the original QueryResult metadata (document_id, kb_id)."""
        reranker = Reranker(
            RerankConfig(provider="cohere", api_key="test-key", top_n=1)
        )
        reranker._reranker = _make_mock_reranker([("c1", 0.99)])

        candidate = _make_query_result(
            "c1", "content 1", 0.5, document_id="doc-99", kb_id="kb-99"
        )
        output = await reranker.rerank("query", [candidate])

        assert len(output) == 1
        survivor = output[0]
        assert survivor.document_id == "doc-99"
        assert survivor.kb_id == "kb-99"
        assert survivor.metadata["document_id"] == "doc-99"

    async def test_rerank_failure_falls_back(self):
        """When the reranker raises, the original results are returned instead."""
        reranker = Reranker(
            RerankConfig(provider="cohere", api_key="test-key", top_n=5)
        )

        # A reranker whose call always fails.
        # NOTE(review): LlamaIndex postprocessors expose `postprocess_nodes`;
        # confirm `postprocessnodes` is the attribute `Reranker` actually calls.
        failing = MagicMock()
        failing.postprocessnodes = MagicMock(side_effect=RuntimeError("API error"))
        reranker._reranker = failing

        candidates = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]
        output = await reranker.rerank("query", candidates)

        # Fallback: the original results, in their original order.
        assert len(output) == 2
        assert [item.chunk_id for item in output] == ["c1", "c2"]

    async def test_cohere_requires_api_key(self):
        """The Cohere provider raises ValueError when api_key is missing."""
        reranker = Reranker(RerankConfig(provider="cohere", api_key=None))
        candidates = [_make_query_result("c1", "content", 0.5)]

        try:
            await reranker.rerank("query", candidates)
        except ValueError as err:
            assert "api_key" in str(err)
        else:
            raise AssertionError("Expected ValueError")
|
||
|
||
|
||
class TestRerankerBGE:
    """Tests for the BGE rerank path."""

    async def test_bge_rerank_reorders_results(self):
        """BGE reranking reorders results by relevance as well."""
        reranker = Reranker(RerankConfig(provider="bge", top_n=2))
        reranker._reranker = _make_mock_reranker([("c2", 0.88), ("c1", 0.55)])

        candidates = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]
        output = await reranker.rerank("query", candidates)

        assert len(output) == 2
        top, runner_up = output[0], output[1]
        assert top.chunk_id == "c2"
        assert top.score == 0.88
        assert runner_up.chunk_id == "c1"
        assert runner_up.score == 0.55

    async def test_bge_no_api_key_required(self):
        """A locally deployed BGE reranker needs no api_key."""
        reranker = Reranker(RerankConfig(provider="bge", top_n=5))
        # Inject the mock reranker; the call below must not raise.
        reranker._reranker = _make_mock_reranker([("c1", 0.9)])

        output = await reranker.rerank(
            "query", [_make_query_result("c1", "content", 0.5)]
        )

        assert len(output) == 1
        assert output[0].score == 0.9
|
||
|
||
|
||
class TestRerankerUnsupportedProvider:
    """Tests for an unsupported provider value."""

    async def test_unsupported_provider_raises(self):
        """An unsupported provider raises ValueError."""
        reranker = Reranker(RerankConfig(provider="invalid_provider"))
        candidates = [_make_query_result("c1", "content", 0.5)]

        try:
            await reranker.rerank("query", candidates)
        except ValueError as err:
            assert "Unsupported" in str(err)
        else:
            raise AssertionError("Expected ValueError")
|