468 lines
17 KiB
Python
468 lines
17 KiB
Python
"""U4 测试 — 双索引检索引擎(embedding/keywords/blend)。

测试场景:
1. embedding 模式:语义检索返回相关结果,按 kb_id 过滤
2. keywords 模式:中文全文检索返回包含关键词的结果
3. blend 模式:合并语义+全文结果,去重排序
4. 查询无结果时返回空列表(非报错)
5. RetrievalRequest 模型字段
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import sys
|
||
from contextlib import asynccontextmanager
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
from agentkit.rag_platform.models import QueryMode, QueryResult, RetrievalRequest
|
||
from agentkit.rag_platform.retrieval import RetrievalEngine, _normalize_scores
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# llama_index 模块 mock — 测试环境可能未安装 llama_index
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _setup_llama_index_mocks():
    """Inject mock ``llama_index`` modules into ``sys.modules`` so imports succeed.

    Mocks are injected only when the real ``llama_index`` cannot be imported,
    so an environment with the real package installed is left untouched.

    Each mocked module name gets its own ``MagicMock``, and every child is
    attached as an attribute of its parent. That way a module reached through
    attribute access (``llama_index.core.vector_stores.types``) is the very
    same object as the ``sys.modules`` entry registered under that name.
    """
    try:
        import llama_index  # noqa: F401
    except ImportError:
        pass
    else:
        return  # The real package is importable; no mocks needed.

    # One mock per module in the hierarchy.
    mock_li = MagicMock()
    mock_li_core = MagicMock()
    mock_li_vector_stores = MagicMock()
    mock_li_core_vector_stores = MagicMock()
    mock_li_core_vector_stores_types = MagicMock()
    mock_li_core_embeddings = MagicMock()

    # VectorStoreQuery only needs to be a callable stand-in.
    mock_li_core_vector_stores_types.VectorStoreQuery = MagicMock()

    # Link children to parents; otherwise MagicMock would fabricate unrelated
    # child mocks on attribute access.
    mock_li.core = mock_li_core
    mock_li.vector_stores = mock_li_vector_stores
    mock_li_core.vector_stores = mock_li_core_vector_stores
    mock_li_core.embeddings = mock_li_core_embeddings
    mock_li_core_vector_stores.types = mock_li_core_vector_stores_types

    sys.modules["llama_index"] = mock_li
    sys.modules["llama_index.core"] = mock_li_core
    sys.modules["llama_index.vector_stores"] = mock_li_vector_stores
    sys.modules["llama_index.core.vector_stores"] = mock_li_core_vector_stores
    sys.modules["llama_index.core.vector_stores.types"] = mock_li_core_vector_stores_types
    sys.modules["llama_index.core.embeddings"] = mock_li_core_embeddings


_setup_llama_index_mocks()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 测试辅助函数
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_mock_embed_model():
    """Build a mock embedding model.

    Its async ``aget_text_embedding`` always resolves to the same
    1536-element vector of ``0.1`` values.
    """
    fixed_vector = [0.1] * 1536
    return MagicMock(aget_text_embedding=AsyncMock(return_value=fixed_vector))
|
||
|
||
|
||
def _make_mock_text_node(node_id: str, content: str, metadata: dict | None = None):
|
||
"""创建 mock LlamaIndex TextNode。"""
|
||
node = MagicMock()
|
||
node.node_id = node_id
|
||
node.get_content.return_value = content
|
||
node.metadata = metadata or {}
|
||
return node
|
||
|
||
|
||
def _make_mock_vector_store(nodes=None, similarities=None):
    """Build a mock standing in for a PGVectorStore.

    Args:
        nodes: Nodes the async ``aquery`` result should expose. When ``None``,
            ``aquery`` is left as an unconfigured ``AsyncMock``.
        similarities: Scores paired with ``nodes``. A falsy value is replaced
            by ``0.9`` for every node.

    Returns:
        A ``MagicMock`` whose ``aquery`` attribute is an ``AsyncMock``.
    """
    store = MagicMock()
    store.aquery = AsyncMock()
    if nodes is None:
        return store

    query_result = MagicMock()
    query_result.nodes = nodes
    query_result.similarities = similarities if similarities else [0.9] * len(nodes)
    store.aquery.return_value = query_result
    return store
|
||
|
||
|
||
def _make_mock_session_factory(execute_result_rows=None):
    """Build a mock session factory for the keywords-mode tests.

    Args:
        execute_result_rows: List of tuples returned by the SQL query, each
            shaped as ``(id, text, metadata_, document_id, score)``. A falsy
            value yields an empty list.

    Returns:
        A ``(factory, session)`` pair. ``factory()`` is an async context
        manager yielding ``session``; awaiting ``session.execute(...)`` gives
        a result whose ``.all()`` returns the configured rows.
    """
    rows = execute_result_rows if execute_result_rows else []

    query_result = MagicMock()
    query_result.all.return_value = rows

    session = AsyncMock()
    session.execute = AsyncMock(return_value=query_result)

    @asynccontextmanager
    async def factory():
        yield session

    return factory, session
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# RetrievalRequest 模型测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestRetrievalRequest:
    """Field-level checks for the RetrievalRequest model."""

    def test_defaults(self):
        """Fields that are not supplied take their default values."""
        request = RetrievalRequest(query="test", kb_ids=["kb1"])

        # Supplied fields are stored as given.
        assert request.query == "test"
        assert request.kb_ids == ["kb1"]

        # Omitted fields fall back to defaults.
        assert request.retrieval_mode is None
        assert request.hit_processing_mode is None
        assert request.top_k == 5
        assert request.user_id is None

    def test_explicit_values(self):
        """Explicitly supplied values are stored unchanged."""
        explicit = {
            "query": "hello",
            "kb_ids": ["kb1", "kb2"],
            "retrieval_mode": QueryMode.embedding,
            "hit_processing_mode": "direct",
            "top_k": 10,
            "user_id": "user1",
        }
        request = RetrievalRequest(**explicit)

        assert request.retrieval_mode == QueryMode.embedding
        assert request.hit_processing_mode == "direct"
        assert request.top_k == 10
        assert request.user_id == "user1"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# _normalize_scores 测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestNormalizeScores:
    """Tests for score normalization via ``_normalize_scores``."""

    def test_empty_list(self):
        """An empty list comes back as an empty list."""
        normalized = _normalize_scores([])
        assert normalized == []

    def test_constant_list_returns_ones(self):
        """A constant list (including a single element) maps to all 1.0."""
        constant = _normalize_scores([0.5, 0.5, 0.5])
        single = _normalize_scores([0.9])

        assert constant == [1.0, 1.0, 1.0]
        assert single == [1.0]

    def test_min_max_normalization(self):
        """Distinct scores are min-max normalized into [0, 1]."""
        normalized = _normalize_scores([1.0, 2.0, 3.0])

        assert normalized[0] == 0.0  # smallest input
        assert normalized[2] == 1.0  # largest input
        assert 0.0 < normalized[1] < 1.0
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# embedding 模式测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestEmbeddingRetrieval:
    """Tests for retrieval in ``QueryMode.embedding``."""

    async def test_returns_relevant_results(self):
        """Embedding mode returns the vector-store hits as QueryResult objects."""
        shared_meta = {"kb_id": "kb1", "document_id": "d1"}
        vector_store = _make_mock_vector_store(
            [
                _make_mock_text_node("n1", "chunk 1", dict(shared_meta)),
                _make_mock_text_node("n2", "chunk 2", dict(shared_meta)),
            ],
            [0.95, 0.80],
        )
        embedder = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory()

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)

        assert len(hits) == 2
        assert all(isinstance(hit, QueryResult) for hit in hits)
        top_hit = hits[0]
        assert top_hit.chunk_id == "n1"
        assert top_hit.score == 0.95
        assert top_hit.kb_id == "kb1"

    async def test_filters_by_kb_id(self):
        """Embedding mode drops hits whose kb_id is outside the requested set."""
        # n1 belongs to kb1; n2 belongs to kb2, which is not being queried.
        in_scope = _make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"})
        out_of_scope = _make_mock_text_node("n2", "chunk 2", {"kb_id": "kb2"})
        vector_store = _make_mock_vector_store([in_scope, out_of_scope], [0.9, 0.9])
        embedder = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory()

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)

        assert len(hits) == 1
        assert hits[0].kb_id == "kb1"

    async def test_empty_results_no_error(self):
        """No hits yields an empty list rather than an error."""
        vector_store = _make_mock_vector_store([], [])
        embedder = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory()

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)

        assert hits == []

    async def test_empty_kb_ids_returns_empty(self):
        """An empty kb_ids list returns [] without querying the vector store."""
        vector_store = _make_mock_vector_store()
        embedder = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory()

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", [], QueryMode.embedding, top_k=5)

        assert hits == []
        vector_store.aquery.assert_not_awaited()

    async def test_no_embed_model_raises(self):
        """Embedding mode without an embed_model raises ValueError."""
        vector_store = _make_mock_vector_store()
        session_factory, _ = _make_mock_session_factory()

        engine = RetrievalEngine(vector_store, session_factory, embed_model=None)
        try:
            await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)
        except ValueError as exc:
            assert "embed_model" in str(exc)
        else:
            raise AssertionError("Expected ValueError")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# keywords 模式测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestKeywordsRetrieval:
    """Tests for retrieval in ``QueryMode.keywords``."""

    async def test_returns_matching_results(self):
        """Keywords mode returns the rows that contain the keyword."""
        sql_rows = [
            ("c1", "自然语言处理内容", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.8),
            ("c2", "另一段文本", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.5),
        ]
        session_factory, _ = _make_mock_session_factory(sql_rows)
        vector_store = _make_mock_vector_store()

        engine = RetrievalEngine(vector_store, session_factory, embed_model=None)
        hits = await engine.retrieve("自然语言", ["kb1"], QueryMode.keywords, top_k=5)

        assert len(hits) == 2
        top_hit = hits[0]
        assert top_hit.chunk_id == "c1"
        assert top_hit.score == 0.8
        assert top_hit.kb_id == "kb1"
        assert "自然语言" in top_hit.content

    async def test_empty_results_no_error(self):
        """No matching rows yields an empty list rather than an error."""
        session_factory, _ = _make_mock_session_factory([])
        vector_store = _make_mock_vector_store()

        engine = RetrievalEngine(vector_store, session_factory, embed_model=None)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.keywords, top_k=5)

        assert hits == []

    async def test_empty_query_returns_empty(self):
        """A blank query (no tokens after tokenizing) returns [] and runs no SQL."""
        session_factory, session = _make_mock_session_factory([])
        vector_store = _make_mock_vector_store()

        engine = RetrievalEngine(vector_store, session_factory, embed_model=None)
        hits = await engine.retrieve(" ", ["kb1"], QueryMode.keywords, top_k=5)

        assert hits == []
        session.execute.assert_not_awaited()

    async def test_passes_kb_ids_to_sql(self):
        """kb_ids and top_k are forwarded as parameters of the SQL query."""
        sql_rows = [
            ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
        ]
        session_factory, session = _make_mock_session_factory(sql_rows)
        vector_store = _make_mock_vector_store()

        engine = RetrievalEngine(vector_store, session_factory, embed_model=None)
        await engine.retrieve("test", ["kb1", "kb2"], QueryMode.keywords, top_k=3)

        # Parameters may arrive as the second positional argument or as kwargs.
        awaited_with = session.execute.await_args
        if len(awaited_with.args) > 1:
            bound_params = awaited_with.args[1]
        else:
            bound_params = awaited_with.kwargs
        assert bound_params["kb_ids"] == ["kb1", "kb2"]
        assert bound_params["top_k"] == 3

    async def test_handles_none_metadata(self):
        """A row whose metadata is None is handled without error."""
        sql_rows = [
            ("c1", "content", None, "d1", 0.5),
        ]
        session_factory, _ = _make_mock_session_factory(sql_rows)
        vector_store = _make_mock_vector_store()

        engine = RetrievalEngine(vector_store, session_factory, embed_model=None)
        hits = await engine.retrieve("test", ["kb1"], QueryMode.keywords, top_k=5)

        assert len(hits) == 1
        only_hit = hits[0]
        assert only_hit.metadata == {}
        assert only_hit.kb_id == ""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# blend 模式测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestBlendRetrieval:
    """Tests for retrieval in ``QueryMode.blend``."""

    async def test_merges_and_deduplicates(self):
        """Blend mode merges semantic and full-text hits, de-duplicated by chunk_id."""
        shared_meta = {"kb_id": "kb1", "document_id": "d1"}

        # Embedding side: c1 and c2.
        vector_store = _make_mock_vector_store(
            [
                _make_mock_text_node("c1", "chunk 1", dict(shared_meta)),
                _make_mock_text_node("c2", "chunk 2", dict(shared_meta)),
            ],
            [0.9, 0.7],
        )
        embedder = _make_mock_embed_model()

        # Keywords side: c1 is a duplicate, c3 is new.
        sql_rows = [
            ("c1", "chunk 1", dict(shared_meta), "d1", 0.8),
            ("c3", "chunk 3", dict(shared_meta), "d1", 0.6),
        ]
        session_factory, _ = _make_mock_session_factory(sql_rows)

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        # After de-duplication exactly three chunks remain (c1, c2, c3).
        assert {hit.chunk_id for hit in hits} == {"c1", "c2", "c3"}

        # Hits are ordered by descending score.
        ordered_scores = [hit.score for hit in hits]
        assert ordered_scores == sorted(ordered_scores, reverse=True)

    async def test_dedup_takes_highest_score(self):
        """When a chunk_id appears on both sides, the higher score is kept."""
        # Embedding: c1 score=0.9 (expected 1.0 * 0.6 = 0.6 after normalization).
        vector_store = _make_mock_vector_store(
            [_make_mock_text_node("c1", "chunk 1", {"kb_id": "kb1"})],
            [0.9],
        )
        embedder = _make_mock_embed_model()

        # Keywords: c1 score=0.8 (expected 1.0 * 0.4 = 0.4 after normalization).
        sql_rows = [
            ("c1", "chunk 1", {"kb_id": "kb1"}, "d1", 0.8),
        ]
        session_factory, _ = _make_mock_session_factory(sql_rows)

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        # c1 appears once, carrying the embedding-side score (0.6 > 0.4).
        assert len(hits) == 1
        only_hit = hits[0]
        assert only_hit.chunk_id == "c1"
        assert only_hit.score == 0.6  # 1.0 * 0.6

    async def test_empty_results_no_error(self):
        """An empty list is returned when neither side produces hits."""
        vector_store = _make_mock_vector_store([], [])
        embedder = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory([])

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        assert hits == []

    async def test_respects_top_k(self):
        """Blend mode returns no more than top_k hits."""
        # The embedding side returns three hits.
        embedding_nodes = [
            _make_mock_text_node(f"n{i}", f"chunk {i}", {"kb_id": "kb1"})
            for i in range(3)
        ]
        vector_store = _make_mock_vector_store(embedding_nodes, [0.9, 0.8, 0.7])
        embedder = _make_mock_embed_model()
        # The keywords side returns nothing.
        session_factory, _ = _make_mock_session_factory([])

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=2)

        assert len(hits) == 2

    async def test_fallback_when_embedding_fails(self):
        """If embedding retrieval fails, blend falls back to keywords only."""
        # The vector-store query is set up to raise.
        vector_store = MagicMock()
        vector_store.aquery = AsyncMock(side_effect=RuntimeError("connection failed"))
        embedder = _make_mock_embed_model()

        # The keywords side returns one row.
        sql_rows = [
            ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
        ]
        session_factory, _ = _make_mock_session_factory(sql_rows)

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        # Only the keyword hit should come back.
        assert len(hits) == 1
        assert hits[0].chunk_id == "c1"

    async def test_fallback_when_no_embed_model(self):
        """Without an embed_model, blend falls back to keywords only."""
        vector_store = _make_mock_vector_store()
        embedder = None

        sql_rows = [
            ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
        ]
        session_factory, _ = _make_mock_session_factory(sql_rows)

        engine = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        assert len(hits) == 1
        assert hits[0].chunk_id == "c1"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 不支持的 mode 测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestUnsupportedMode:
    """Tests for an unsupported retrieval mode."""

    async def test_unsupported_mode_raises(self):
        """An unsupported mode raises ValueError."""
        vector_store = _make_mock_vector_store()
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(vector_store, session_factory, embed_model=None)

        try:
            await engine.retrieve("query", ["kb1"], "invalid_mode", top_k=5)  # type: ignore[arg-type]
        except ValueError as exc:
            assert "Unsupported" in str(exc)
        else:
            raise AssertionError("Expected ValueError")
|