"""U5 tests — rerank model integration (Cohere / BGE / none).

Scenarios covered:

1. ``provider="none"`` returns the results unchanged.
2. An empty result list is returned unchanged.
3. Cohere rerank (LlamaIndex ``CohereRerank``, replaced by a mock here)
   reorders results by relevance.
4. BGE rerank (LlamaIndex ``SentenceTransformerRerank``, replaced by a mock
   here) reorders results by relevance.
5. A failing reranker degrades to returning the original results.
6. After reranking, result scores are replaced by the rerank scores.
7. ``RerankConfig`` defaults and field values.

NOTE(review): the async tests carry no explicit asyncio marker; presumably
the project runs pytest with an asyncio auto mode — confirm in the pytest
configuration.
"""

from __future__ import annotations

import sys
from unittest.mock import MagicMock

from agentkit.rag_platform.models import QueryResult
from agentkit.rag_platform.rerank import RerankConfig, Reranker


# ---------------------------------------------------------------------------
# llama_index module mocks — llama_index may not be installed in the test env
# ---------------------------------------------------------------------------


def _setup_llama_index_mocks():
    """Register stand-in ``llama_index`` modules in ``sys.modules``.

    The stand-ins are installed only when the real ``llama_index`` package
    cannot be imported, so an environment that has the real package is left
    untouched.
    """
    try:
        import llama_index  # noqa: F401
    except ImportError:
        pass
    else:
        # The real package is importable: nothing to mock.
        return

    # Minimal stand-ins for TextNode / NodeWithScore.
    class MockTextNode:
        def __init__(self, id_=None, text="", metadata=None):
            self.node_id = id_
            self.text = text
            self.metadata = metadata or {}

        def get_content(self):
            return self.text

    class MockNodeWithScore:
        def __init__(self, node=None, score=None):
            self.node = node
            self.score = score

    schema_module = MagicMock()
    schema_module.TextNode = MockTextNode
    schema_module.NodeWithScore = MockNodeWithScore

    # CohereRerank / SentenceTransformerRerank are plain mocks that individual
    # tests can configure as needed.
    cohere_module = MagicMock()
    cohere_module.CohereRerank = MagicMock()

    sentence_transformer_module = MagicMock()
    sentence_transformer_module.SentenceTransformerRerank = MagicMock()

    sys.modules.update(
        {
            "llama_index": MagicMock(),
            "llama_index.core": MagicMock(),
            "llama_index.core.schema": schema_module,
            "llama_index.postprocessor": MagicMock(),
            "llama_index.postprocessor.cohere_rerank": cohere_module,
            "llama_index.postprocessor.sentence_transformers_rerank": (
                sentence_transformer_module
            ),
        }
    )


_setup_llama_index_mocks()


# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------


def _make_query_result(
    chunk_id: str,
    content: str,
    score: float = 0.5,
    document_id: str = "doc1",
    kb_id: str = "kb1",
) -> QueryResult:
    """Build a ``QueryResult`` for tests.

    ``document_id`` and ``kb_id`` are stored both as top-level fields and
    inside ``metadata``.
    """
    meta = {"document_id": document_id, "kb_id": kb_id}
    return QueryResult(
        chunk_id=chunk_id,
        content=content,
        score=score,
        metadata=meta,
        document_id=document_id,
        kb_id=kb_id,
    )


def _make_mock_reranker(reranked_order: list[tuple[str, float]]):
    """Build a mock LlamaIndex reranker.

    Args:
        reranked_order: ``[(node_id, new_score), ...]`` — the order and the
            scores the mock reports after reranking.

    Returns:
        A ``MagicMock`` whose ``postprocessnodes`` attribute is a mock that
        returns one ``NodeWithScore`` per entry of ``reranked_order``, in the
        same order.
    """
    from llama_index.core.schema import NodeWithScore, TextNode

    scored_nodes = []
    for node_id, new_score in reranked_order:
        text_node = TextNode(id_=node_id, text="")
        scored_nodes.append(NodeWithScore(node=text_node, score=new_score))

    # NOTE(review): LlamaIndex's public postprocessor method is spelled
    # `postprocess_nodes`; confirm that `Reranker` really calls the attribute
    # name configured here.
    fake_reranker = MagicMock()
    fake_reranker.postprocessnodes = MagicMock(return_value=scored_nodes)
    return fake_reranker


# ---------------------------------------------------------------------------
# RerankConfig tests
# ---------------------------------------------------------------------------


class TestRerankConfig:
    """Tests for the ``RerankConfig`` model."""

    def test_defaults(self):
        """Defaults: ``none`` provider, no data-export warning."""
        config = RerankConfig()

        assert config.provider == "none"
        assert config.api_key is None
        assert config.base_url is None
        assert config.top_n == 5
        assert config.data_export_warning is False

    def test_cohere_config(self):
        """Cohere configuration, flagged with the data-export warning."""
        config = RerankConfig(
            provider="cohere",
            api_key="test-key",
            top_n=3,
            data_export_warning=True,
        )

        assert config.provider == "cohere"
        assert config.api_key == "test-key"
        assert config.top_n == 3
        assert config.data_export_warning is True

    def test_bge_config(self):
        """BGE configuration pointing at a local endpoint."""
        config = RerankConfig(
            provider="bge",
            base_url="http://localhost:9997",
            model_name="bge-reranker-base",
            top_n=10,
        )

        assert config.provider == "bge"
        assert config.base_url == "http://localhost:9997"
        assert config.model_name == "bge-reranker-base"
        assert config.top_n == 10


# ---------------------------------------------------------------------------
# Reranker tests
# ---------------------------------------------------------------------------


class TestRerankerNone:
    """Tests for ``provider="none"`` and for empty input."""

    async def test_none_provider_returns_as_is(self):
        """``provider="none"`` returns the input unchanged: same order, same scores."""
        reranker = Reranker(RerankConfig(provider="none"))
        candidates = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]

        output = await reranker.rerank("query", candidates)

        assert len(output) == 2
        assert output[0].chunk_id == "c1"
        assert output[1].chunk_id == "c2"
        # Scores are untouched.
        assert output[0].score == 0.9
        assert output[1].score == 0.5

    async def test_empty_results_returns_empty(self):
        """An empty result list comes back as an empty list."""
        reranker = Reranker(RerankConfig(provider="cohere", api_key="key"))

        output = await reranker.rerank("query", [])

        assert output == []


class TestRerankerCohere:
    """Cohere rerank tests."""

    async def test_rerank_reorders_results(self):
        """Reranking reorders the results and replaces their scores."""
        # Input order:    c1 (score=0.9),  c2 (score=0.5)
        # Reranked order: c2 (score=0.95), c1 (score=0.80)
        reranker = Reranker(
            RerankConfig(provider="cohere", api_key="test-key", top_n=2)
        )
        # Inject the mock in place of the real LlamaIndex reranker.
        reranker._reranker = _make_mock_reranker([("c2", 0.95), ("c1", 0.80)])
        candidates = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]

        output = await reranker.rerank("query", candidates)

        assert len(output) == 2
        # c2 now comes first.
        assert output[0].chunk_id == "c2"
        assert output[1].chunk_id == "c1"
        # Scores are the rerank scores.
        assert output[0].score == 0.95
        assert output[1].score == 0.80

    async def test_rerank_preserves_metadata(self):
        """Reranking keeps ``document_id``, ``kb_id`` and ``metadata`` of the input."""
        reranker = Reranker(
            RerankConfig(provider="cohere", api_key="test-key", top_n=1)
        )
        reranker._reranker = _make_mock_reranker([("c1", 0.99)])
        candidates = [
            _make_query_result(
                "c1", "content 1", 0.5, document_id="doc-99", kb_id="kb-99"
            ),
        ]

        output = await reranker.rerank("query", candidates)

        assert len(output) == 1
        only_result = output[0]
        assert only_result.document_id == "doc-99"
        assert only_result.kb_id == "kb-99"
        assert only_result.metadata["document_id"] == "doc-99"

    async def test_rerank_failure_falls_back(self):
        """When the reranker raises, the original results are returned."""
        reranker = Reranker(
            RerankConfig(provider="cohere", api_key="test-key", top_n=5)
        )
        # Mock reranker configured to raise.
        failing_reranker = MagicMock()
        failing_reranker.postprocessnodes = MagicMock(
            side_effect=RuntimeError("API error")
        )
        reranker._reranker = failing_reranker
        candidates = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]

        output = await reranker.rerank("query", candidates)

        # Fallback: original results, original order.
        assert len(output) == 2
        assert output[0].chunk_id == "c1"
        assert output[1].chunk_id == "c2"

    async def test_cohere_requires_api_key(self):
        """Cohere provider without ``api_key`` raises ``ValueError``."""
        reranker = Reranker(RerankConfig(provider="cohere", api_key=None))
        candidates = [_make_query_result("c1", "content", 0.5)]

        try:
            await reranker.rerank("query", candidates)
        except ValueError as e:
            assert "api_key" in str(e)
        else:
            raise AssertionError("Expected ValueError")


class TestRerankerBGE:
    """BGE rerank tests."""

    async def test_bge_rerank_reorders_results(self):
        """BGE rerank reorders the results and replaces their scores."""
        reranker = Reranker(RerankConfig(provider="bge", top_n=2))
        reranker._reranker = _make_mock_reranker([("c2", 0.88), ("c1", 0.55)])
        candidates = [
            _make_query_result("c1", "content 1", 0.9),
            _make_query_result("c2", "content 2", 0.5),
        ]

        output = await reranker.rerank("query", candidates)

        assert len(output) == 2
        assert output[0].chunk_id == "c2"
        assert output[0].score == 0.88
        assert output[1].chunk_id == "c1"
        assert output[1].score == 0.55

    async def test_bge_no_api_key_required(self):
        """BGE works with a config that has no ``api_key``."""
        reranker = Reranker(RerankConfig(provider="bge", top_n=5))
        # Inject the mock reranker; the call must not raise.
        reranker._reranker = _make_mock_reranker([("c1", 0.9)])
        candidates = [_make_query_result("c1", "content", 0.5)]

        output = await reranker.rerank("query", candidates)

        assert len(output) == 1
        assert output[0].score == 0.9


class TestRerankerUnsupportedProvider:
    """Tests for an unsupported provider."""

    async def test_unsupported_provider_raises(self):
        """An unsupported provider raises ``ValueError``."""
        reranker = Reranker(RerankConfig(provider="invalid_provider"))
        candidates = [_make_query_result("c1", "content", 0.5)]

        try:
            await reranker.rerank("query", candidates)
        except ValueError as e:
            assert "Unsupported" in str(e)
        else:
            raise AssertionError("Expected ValueError")