"""U4 测试 — 双索引检索引擎(embedding/keywords/blend)。 测试场景: 1. embedding 模式:语义检索返回相关结果,按 kb_id 过滤 2. keywords 模式:中文全文检索返回包含关键词的结果 3. blend 模式:合并语义+全文结果,去重排序 4. 查询无结果时返回空列表(非报错) 5. RetrievalRequest 模型字段 """ from __future__ import annotations import sys from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock from agentkit.rag_platform.models import QueryMode, QueryResult, RetrievalRequest from agentkit.rag_platform.retrieval import RetrievalEngine, _normalize_scores # --------------------------------------------------------------------------- # llama_index 模块 mock — 测试环境可能未安装 llama_index # --------------------------------------------------------------------------- def _setup_llama_index_mocks(): """注入 mock llama_index 模块到 sys.modules,使 import 成功。 仅当真实 llama_index 无法导入时才注入 mock — 避免污染已安装真实模块的环境。 """ try: import llama_index # noqa: F401 return # 真实模块可用,无需 mock except ImportError: pass # 创建 mock 模块层级 mock_li = MagicMock() mock_li_core = MagicMock() mock_li_vector_stores = MagicMock() mock_li_core_vector_stores_types = MagicMock() mock_li_core_embeddings = MagicMock() # VectorStoreQuery — 简单的可调用 mock mock_li_core_vector_stores_types.VectorStoreQuery = MagicMock() sys.modules["llama_index"] = mock_li sys.modules["llama_index.core"] = mock_li_core sys.modules["llama_index.vector_stores"] = mock_li_vector_stores sys.modules["llama_index.core.vector_stores"] = mock_li_core sys.modules["llama_index.core.vector_stores.types"] = mock_li_core_vector_stores_types sys.modules["llama_index.core.embeddings"] = mock_li_core_embeddings _setup_llama_index_mocks() # --------------------------------------------------------------------------- # 测试辅助函数 # --------------------------------------------------------------------------- def _make_mock_embed_model(): """创建 mock embedding 模型。""" mock = MagicMock() mock.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536) return mock def _make_mock_text_node(node_id: str, content: str, metadata: dict | None = None): """创建 mock LlamaIndex TextNode。""" node = MagicMock() node.node_id = node_id node.get_content.return_value = content node.metadata = metadata or {} return node def _make_mock_vector_store(nodes=None, similarities=None): """创建 mock PGVectorStore。""" mock = MagicMock() mock.aquery = AsyncMock() if nodes is not None: mock_result = MagicMock() mock_result.nodes = nodes mock_result.similarities = similarities or [0.9] * len(nodes) mock.aquery.return_value = mock_result return mock def _make_mock_session_factory(execute_result_rows=None): """创建 mock session factory,用于 keywords 模式测试。 Args: execute_result_rows: list of tuples — SQL 查询返回的行 每行格式: (id, text, metadata_, document_id, score) """ mock_session = AsyncMock() mock_result = MagicMock() mock_result.all.return_value = execute_result_rows or [] mock_session.execute = AsyncMock(return_value=mock_result) @asynccontextmanager async def factory(): yield mock_session return factory, mock_session # --------------------------------------------------------------------------- # RetrievalRequest 模型测试 # --------------------------------------------------------------------------- class TestRetrievalRequest: """RetrievalRequest 模型测试。""" def test_defaults(self): """默认值正确。""" req = RetrievalRequest(query="test", kb_ids=["kb1"]) assert req.query == "test" assert req.kb_ids == ["kb1"] assert req.retrieval_mode is None assert req.hit_processing_mode is None assert req.top_k == 5 assert req.user_id is None def test_explicit_values(self): """显式赋值正确。""" req = RetrievalRequest( query="hello", kb_ids=["kb1", "kb2"], retrieval_mode=QueryMode.embedding, hit_processing_mode="direct", top_k=10, user_id="user1", ) assert req.retrieval_mode == QueryMode.embedding assert req.hit_processing_mode == "direct" assert req.top_k == 10 assert req.user_id == "user1" # --------------------------------------------------------------------------- # _normalize_scores 测试 # --------------------------------------------------------------------------- class TestNormalizeScores: """分数归一化测试。""" def test_empty_list(self): """空列表返回空列表。""" assert _normalize_scores([]) == [] def test_constant_list_returns_ones(self): """常数列表(含单元素)返回全 1.0 — 等价于最高相关度。""" assert _normalize_scores([0.5, 0.5, 0.5]) == [1.0, 1.0, 1.0] assert _normalize_scores([0.9]) == [1.0] def test_min_max_normalization(self): """min-max 归一化到 [0, 1]。""" result = _normalize_scores([1.0, 2.0, 3.0]) assert result[0] == 0.0 # 最小值 assert result[2] == 1.0 # 最大值 assert 0.0 < result[1] < 1.0 # --------------------------------------------------------------------------- # embedding 模式测试 # --------------------------------------------------------------------------- class TestEmbeddingRetrieval: """embedding 模式检索测试。""" async def test_returns_relevant_results(self): """embedding 模式返回相关结果。""" nodes = [ _make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}), _make_mock_text_node("n2", "chunk 2", {"kb_id": "kb1", "document_id": "d1"}), ] mock_vs = _make_mock_vector_store(nodes, [0.95, 0.80]) mock_embed = _make_mock_embed_model() sf, _ = _make_mock_session_factory() engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5) assert len(results) == 2 assert all(isinstance(r, QueryResult) for r in results) assert results[0].chunk_id == "n1" assert results[0].score == 0.95 assert results[0].kb_id == "kb1" async def test_filters_by_kb_id(self): """embedding 模式按 kb_id 过滤结果。""" # n1 属于 kb1,n2 属于 kb2(不在查询范围) nodes = [ _make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"}), _make_mock_text_node("n2", "chunk 2", {"kb_id": "kb2"}), ] mock_vs = _make_mock_vector_store(nodes, [0.9, 0.9]) mock_embed = _make_mock_embed_model() sf, _ = _make_mock_session_factory() engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5) assert len(results) == 1 assert results[0].kb_id == "kb1" async def test_empty_results_no_error(self): """无结果时返回空列表(非报错)。""" mock_vs = _make_mock_vector_store([], []) mock_embed = _make_mock_embed_model() sf, _ = _make_mock_session_factory() engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5) assert results == [] async def test_empty_kb_ids_returns_empty(self): """kb_ids 为空时直接返回空列表。""" mock_vs = _make_mock_vector_store() mock_embed = _make_mock_embed_model() sf, _ = _make_mock_session_factory() engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", [], QueryMode.embedding, top_k=5) assert results == [] mock_vs.aquery.assert_not_awaited() async def test_no_embed_model_raises(self): """embedding 模式未配置 embed_model 抛出异常。""" mock_vs = _make_mock_vector_store() sf, _ = _make_mock_session_factory() engine = RetrievalEngine(mock_vs, sf, embed_model=None) try: await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5) raise AssertionError("Expected ValueError") except ValueError as e: assert "embed_model" in str(e) # --------------------------------------------------------------------------- # keywords 模式测试 # --------------------------------------------------------------------------- class TestKeywordsRetrieval: """keywords 模式检索测试。""" async def test_returns_matching_results(self): """keywords 模式返回包含关键词的结果。""" rows = [ ("c1", "自然语言处理内容", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.8), ("c2", "另一段文本", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.5), ] sf, _ = _make_mock_session_factory(rows) mock_vs = _make_mock_vector_store() engine = RetrievalEngine(mock_vs, sf, embed_model=None) results = await engine.retrieve("自然语言", ["kb1"], QueryMode.keywords, top_k=5) assert len(results) == 2 assert results[0].chunk_id == "c1" assert results[0].score == 0.8 assert results[0].kb_id == "kb1" assert "自然语言" in results[0].content async def test_empty_results_no_error(self): """无结果时返回空列表(非报错)。""" sf, _ = _make_mock_session_factory([]) mock_vs = _make_mock_vector_store() engine = RetrievalEngine(mock_vs, sf, embed_model=None) results = await engine.retrieve("query", ["kb1"], QueryMode.keywords, top_k=5) assert results == [] async def test_empty_query_returns_empty(self): """空查询(分词后无 token)返回空列表,不执行 SQL。""" sf, mock_session = _make_mock_session_factory([]) mock_vs = _make_mock_vector_store() engine = RetrievalEngine(mock_vs, sf, embed_model=None) results = await engine.retrieve(" ", ["kb1"], QueryMode.keywords, top_k=5) assert results == [] mock_session.execute.assert_not_awaited() async def test_passes_kb_ids_to_sql(self): """kb_ids 作为参数传递给 SQL 查询。""" rows = [ ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5), ] sf, mock_session = _make_mock_session_factory(rows) mock_vs = _make_mock_vector_store() engine = RetrievalEngine(mock_vs, sf, embed_model=None) await engine.retrieve("test", ["kb1", "kb2"], QueryMode.keywords, top_k=3) call_args = mock_session.execute.await_args params = call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs assert params["kb_ids"] == ["kb1", "kb2"] assert params["top_k"] == 3 async def test_handles_none_metadata(self): """metadata 为 None 时正常处理。""" rows = [ ("c1", "content", None, "d1", 0.5), ] sf, _ = _make_mock_session_factory(rows) mock_vs = _make_mock_vector_store() engine = RetrievalEngine(mock_vs, sf, embed_model=None) results = await engine.retrieve("test", ["kb1"], QueryMode.keywords, top_k=5) assert len(results) == 1 assert results[0].metadata == {} assert results[0].kb_id == "" # --------------------------------------------------------------------------- # blend 模式测试 # --------------------------------------------------------------------------- class TestBlendRetrieval: """blend 模式检索测试。""" async def test_merges_and_deduplicates(self): """blend 模式合并语义+全文结果,按 chunk_id 去重。""" # embedding 结果 nodes = [ _make_mock_text_node("c1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}), _make_mock_text_node("c2", "chunk 2", {"kb_id": "kb1", "document_id": "d1"}), ] mock_vs = _make_mock_vector_store(nodes, [0.9, 0.7]) mock_embed = _make_mock_embed_model() # keywords 结果 — c1 重复,c3 新增 rows = [ ("c1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.8), ("c3", "chunk 3", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.6), ] sf, _ = _make_mock_session_factory(rows) engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5) # 去重后应有 3 个结果(c1, c2, c3) chunk_ids = {r.chunk_id for r in results} assert chunk_ids == {"c1", "c2", "c3"} # 结果按分数降序 scores = [r.score for r in results] assert scores == sorted(scores, reverse=True) async def test_dedup_takes_highest_score(self): """去重时同 chunk_id 取最高分。""" # embedding: c1 score=0.9 (归一化后 1.0 * 0.6 = 0.6) nodes = [ _make_mock_text_node("c1", "chunk 1", {"kb_id": "kb1"}), ] mock_vs = _make_mock_vector_store(nodes, [0.9]) mock_embed = _make_mock_embed_model() # keywords: c1 score=0.8 (归一化后 1.0 * 0.4 = 0.4) rows = [ ("c1", "chunk 1", {"kb_id": "kb1"}, "d1", 0.8), ] sf, _ = _make_mock_session_factory(rows) engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5) # c1 只出现一次,取 embedding 的分数(0.6 > 0.4) assert len(results) == 1 assert results[0].chunk_id == "c1" assert results[0].score == 0.6 # 1.0 * 0.6 async def test_empty_results_no_error(self): """两种检索都无结果时返回空列表。""" mock_vs = _make_mock_vector_store([], []) mock_embed = _make_mock_embed_model() sf, _ = _make_mock_session_factory([]) engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5) assert results == [] async def test_respects_top_k(self): """blend 模式遵守 top_k 限制。""" # embedding 返回 3 个结果 nodes = [_make_mock_text_node(f"n{i}", f"chunk {i}", {"kb_id": "kb1"}) for i in range(3)] mock_vs = _make_mock_vector_store(nodes, [0.9, 0.8, 0.7]) mock_embed = _make_mock_embed_model() sf, _ = _make_mock_session_factory([]) # keywords 无结果 engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=2) assert len(results) == 2 async def test_fallback_when_embedding_fails(self): """embedding 检索失败时降级为纯关键词检索。""" # embedding 模式会抛异常(mock 设置) mock_vs = MagicMock() mock_vs.aquery = AsyncMock(side_effect=RuntimeError("connection failed")) mock_embed = _make_mock_embed_model() # keywords 返回结果 rows = [ ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5), ] sf, _ = _make_mock_session_factory(rows) engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5) # 应降级为纯关键词结果 assert len(results) == 1 assert results[0].chunk_id == "c1" async def test_fallback_when_no_embed_model(self): """未配置 embed_model 时 blend 降级为纯关键词检索。""" mock_vs = _make_mock_vector_store() mock_embed = None rows = [ ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5), ] sf, _ = _make_mock_session_factory(rows) engine = RetrievalEngine(mock_vs, sf, mock_embed) results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5) assert len(results) == 1 assert results[0].chunk_id == "c1" # --------------------------------------------------------------------------- # 不支持的 mode 测试 # --------------------------------------------------------------------------- class TestUnsupportedMode: """不支持的检索模式测试。""" async def test_unsupported_mode_raises(self): """不支持的 mode 抛出 ValueError。""" mock_vs = _make_mock_vector_store() sf, _ = _make_mock_session_factory() engine = RetrievalEngine(mock_vs, sf, embed_model=None) try: await engine.retrieve("query", ["kb1"], "invalid_mode", top_k=5) # type: ignore[arg-type] raise AssertionError("Expected ValueError") except ValueError as e: assert "Unsupported" in str(e)