# fischer-agentkit/tests/unit/rag_platform/test_retrieval.py

"""U4 测试 — 双索引检索引擎(embedding/keywords/blend)
测试场景:
1. embedding 模式:语义检索返回相关结果,按 kb_id 过滤
2. keywords 模式:中文全文检索返回包含关键词的结果
3. blend 模式:合并语义+全文结果,去重排序;embedding 不可用时降级为纯关键词检索
4. 查询无结果时返回空列表(非报错)
5. RetrievalRequest 模型字段的默认值与显式赋值
6. _normalize_scores 分数归一化
7. 不支持的检索模式抛出 ValueError
"""
from __future__ import annotations
import sys
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.models import QueryMode, QueryResult, RetrievalRequest
from agentkit.rag_platform.retrieval import RetrievalEngine, _normalize_scores
# ---------------------------------------------------------------------------
# llama_index 模块 mock — 测试环境可能未安装 llama_index
# ---------------------------------------------------------------------------
def _setup_llama_index_mocks():
    """Inject mock ``llama_index`` modules into ``sys.modules`` so imports succeed.

    Mocks are injected only when the real ``llama_index`` package cannot be
    imported, so environments that have the real library installed are not
    polluted.

    NOTE(review): this runs after ``agentkit.rag_platform.retrieval`` has
    already been imported at the top of the file, so presumably that module
    imports ``llama_index`` lazily (inside functions) — TODO confirm.
    """
    try:
        import llama_index  # noqa: F401
    except ImportError:
        pass
    else:
        return  # Real package is importable; leave sys.modules untouched.

    # One dedicated mock per dotted module path (parents listed before
    # children). Each child is also bound as an attribute of its parent so
    # that attribute access (``llama_index.core.embeddings``) and the
    # ``sys.modules`` entry resolve to the same object — the import system
    # does not perform that binding for modules pre-seeded in sys.modules.
    module_names = (
        "llama_index",
        "llama_index.core",
        "llama_index.vector_stores",
        "llama_index.core.vector_stores",
        "llama_index.core.vector_stores.types",
        "llama_index.core.embeddings",
    )
    mocks = {}
    for name in module_names:
        module = MagicMock()
        mocks[name] = module
        parent_name, _, child_name = name.rpartition(".")
        if parent_name:
            setattr(mocks[parent_name], child_name, module)

    # VectorStoreQuery — a plain callable mock.
    mocks["llama_index.core.vector_stores.types"].VectorStoreQuery = MagicMock()

    sys.modules.update(mocks)


_setup_llama_index_mocks()
# ---------------------------------------------------------------------------
# 测试辅助函数
# ---------------------------------------------------------------------------
def _make_mock_embed_model():
    """Build a stand-in embedding model.

    Returns:
        A ``MagicMock`` whose async ``aget_text_embedding`` resolves to a
        1536-element vector filled with ``0.1``.
    """
    fake_vector = [0.1] * 1536
    model = MagicMock()
    model.aget_text_embedding = AsyncMock(return_value=fake_vector)
    return model
def _make_mock_text_node(node_id: str, content: str, metadata: dict | None = None):
    """Build a stand-in for a LlamaIndex ``TextNode``.

    Args:
        node_id: Exposed as ``node.node_id``.
        content: Returned by ``node.get_content()``.
        metadata: Exposed as ``node.metadata``; any falsy value (including the
            default ``None``) is replaced by a fresh empty dict.
    """
    node = MagicMock()
    node.configure_mock(node_id=node_id, metadata=metadata if metadata else {})
    node.get_content.return_value = content
    return node
def _make_mock_vector_store(nodes=None, similarities=None):
    """Build a stand-in ``PGVectorStore`` whose ``aquery`` is an ``AsyncMock``.

    Args:
        nodes: Nodes returned (as ``result.nodes``) by ``await aquery(...)``.
            When ``None``, ``aquery`` is left unconfigured and resolves to a
            default ``MagicMock`` — intended for tests that do not expect the
            vector store to return meaningful hits.
        similarities: Scores returned as ``result.similarities``. When omitted,
            every node scores ``0.9``; an explicitly passed sequence (including
            an empty one) is used exactly as given.
    """
    store = MagicMock()
    store.aquery = AsyncMock()
    if nodes is not None:
        result = MagicMock()
        result.nodes = nodes
        # `is None` instead of `or`: an explicit empty list must not be
        # silently replaced by default scores.
        result.similarities = (
            similarities if similarities is not None else [0.9] * len(nodes)
        )
        store.aquery.return_value = result
    return store
def _make_mock_session_factory(execute_result_rows=None):
    """Build an async session-factory stub for keywords-mode tests.

    Args:
        execute_result_rows: List of tuples returned by ``result.all()``;
            each row is ``(id, text, metadata_, document_id, score)``.

    Returns:
        ``(factory, session)`` — ``factory()`` is an async context manager
        yielding ``session``; ``await session.execute(...)`` returns a result
        whose ``all()`` gives the configured rows (an empty list when none or
        a falsy value was supplied).
    """
    rows = execute_result_rows or []
    execute_result = MagicMock()
    execute_result.all.return_value = rows

    session = AsyncMock()
    session.execute = AsyncMock(return_value=execute_result)

    @asynccontextmanager
    async def factory():
        yield session

    return factory, session
# ---------------------------------------------------------------------------
# RetrievalRequest 模型测试
# ---------------------------------------------------------------------------
class TestRetrievalRequest:
    """Field defaults and explicit assignment on ``RetrievalRequest``."""

    def test_defaults(self):
        """Unspecified optional fields fall back to their defaults."""
        req = RetrievalRequest(query="test", kb_ids=["kb1"])
        assert (req.query, req.kb_ids, req.top_k) == ("test", ["kb1"], 5)
        for field in ("retrieval_mode", "hit_processing_mode", "user_id"):
            assert getattr(req, field) is None, field

    def test_explicit_values(self):
        """Explicitly supplied values are stored unchanged."""
        overrides = {
            "retrieval_mode": QueryMode.embedding,
            "hit_processing_mode": "direct",
            "top_k": 10,
            "user_id": "user1",
        }
        req = RetrievalRequest(query="hello", kb_ids=["kb1", "kb2"], **overrides)
        for field, expected in overrides.items():
            assert getattr(req, field) == expected, field
# ---------------------------------------------------------------------------
# _normalize_scores 测试
# ---------------------------------------------------------------------------
class TestNormalizeScores:
    """Tests for the ``_normalize_scores`` helper."""

    def test_empty_list(self):
        """An empty input produces an empty output."""
        normalized = _normalize_scores([])
        assert normalized == []

    def test_constant_list_returns_ones(self):
        """Constant inputs (including a single score) map to all 1.0 — top relevance."""
        cases = (
            ([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
            ([0.9], [1.0]),
        )
        for scores, expected in cases:
            assert _normalize_scores(scores) == expected

    def test_min_max_normalization(self):
        """Scores are min-max scaled into [0, 1]."""
        normalized = _normalize_scores([1.0, 2.0, 3.0])
        lowest, middle, highest = normalized[0], normalized[1], normalized[2]
        assert lowest == 0.0  # the minimum maps to 0
        assert highest == 1.0  # the maximum maps to 1
        assert 0.0 < middle < 1.0
# ---------------------------------------------------------------------------
# embedding 模式测试
# ---------------------------------------------------------------------------
class TestEmbeddingRetrieval:
    """Retrieval tests for ``QueryMode.embedding``."""

    async def test_returns_relevant_results(self):
        """Embedding mode returns vector-store hits as ``QueryResult`` objects."""
        hits = [
            _make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}),
            _make_mock_text_node("n2", "chunk 2", {"kb_id": "kb1", "document_id": "d1"}),
        ]
        vector_store = _make_mock_vector_store(hits, [0.95, 0.80])
        embed_model = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(vector_store, session_factory, embed_model)

        found = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)

        assert len(found) == 2
        assert all(isinstance(item, QueryResult) for item in found)
        best = found[0]
        assert best.chunk_id == "n1"
        assert best.score == 0.95
        assert best.kb_id == "kb1"

    async def test_filters_by_kb_id(self):
        """Hits whose metadata ``kb_id`` is outside the requested KBs are dropped."""
        # n1 belongs to kb1; n2 belongs to kb2, which is not part of the query.
        hits = [
            _make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"}),
            _make_mock_text_node("n2", "chunk 2", {"kb_id": "kb2"}),
        ]
        vector_store = _make_mock_vector_store(hits, [0.9, 0.9])
        embed_model = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(vector_store, session_factory, embed_model)

        found = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)

        assert [item.kb_id for item in found] == ["kb1"]

    async def test_empty_results_no_error(self):
        """No hits yields an empty list rather than an error."""
        vector_store = _make_mock_vector_store([], [])
        embed_model = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(vector_store, session_factory, embed_model)

        found = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)

        assert found == []

    async def test_empty_kb_ids_returns_empty(self):
        """An empty ``kb_ids`` short-circuits without querying the vector store."""
        vector_store = _make_mock_vector_store()
        embed_model = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(vector_store, session_factory, embed_model)

        found = await engine.retrieve("query", [], QueryMode.embedding, top_k=5)

        assert found == []
        vector_store.aquery.assert_not_awaited()

    async def test_no_embed_model_raises(self):
        """Embedding mode without an ``embed_model`` raises ``ValueError``."""
        vector_store = _make_mock_vector_store()
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(vector_store, session_factory, embed_model=None)
        try:
            await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)
        except ValueError as exc:
            assert "embed_model" in str(exc)
        else:
            raise AssertionError("Expected ValueError")
# ---------------------------------------------------------------------------
# keywords 模式测试
# ---------------------------------------------------------------------------
class TestKeywordsRetrieval:
    """Retrieval tests for ``QueryMode.keywords`` (full-text search)."""

    async def test_returns_matching_results(self):
        """Keywords mode returns the rows produced by the full-text SQL query."""
        rows = [
            ("c1", "自然语言处理内容", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.8),
            ("c2", "另一段文本", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.5),
        ]
        sf, _ = _make_mock_session_factory(rows)
        mock_vs = _make_mock_vector_store()
        engine = RetrievalEngine(mock_vs, sf, embed_model=None)
        results = await engine.retrieve("自然语言", ["kb1"], QueryMode.keywords, top_k=5)
        assert len(results) == 2
        assert results[0].chunk_id == "c1"
        assert results[0].score == 0.8
        assert results[0].kb_id == "kb1"
        assert "自然语言" in results[0].content

    async def test_empty_results_no_error(self):
        """No matching rows yields an empty list rather than an error."""
        sf, _ = _make_mock_session_factory([])
        mock_vs = _make_mock_vector_store()
        engine = RetrievalEngine(mock_vs, sf, embed_model=None)
        results = await engine.retrieve("query", ["kb1"], QueryMode.keywords, top_k=5)
        assert results == []

    async def test_empty_query_returns_empty(self):
        """A query with no tokens after tokenization returns [] without running SQL."""
        sf, mock_session = _make_mock_session_factory([])
        mock_vs = _make_mock_vector_store()
        engine = RetrievalEngine(mock_vs, sf, embed_model=None)
        results = await engine.retrieve(" ", ["kb1"], QueryMode.keywords, top_k=5)
        assert results == []
        mock_session.execute.assert_not_awaited()

    async def test_passes_kb_ids_to_sql(self):
        """``kb_ids`` and ``top_k`` are forwarded to the SQL query as bound parameters."""
        rows = [
            ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
        ]
        sf, mock_session = _make_mock_session_factory(rows)
        mock_vs = _make_mock_vector_store()
        engine = RetrievalEngine(mock_vs, sf, embed_model=None)
        await engine.retrieve("test", ["kb1", "kb2"], QueryMode.keywords, top_k=3)
        # Fail with a clear message instead of an AttributeError on a None
        # `await_args` if the query was never executed.
        mock_session.execute.assert_awaited()
        call_args = mock_session.execute.await_args
        if len(call_args.args) > 1:
            # session.execute(statement, params)
            params = call_args.args[1]
        else:
            # session.execute(statement, params={...}); fall back to the raw
            # kwargs in case parameters were passed as bare keywords.
            params = call_args.kwargs.get("params", call_args.kwargs)
        assert params["kb_ids"] == ["kb1", "kb2"]
        assert params["top_k"] == 3

    async def test_handles_none_metadata(self):
        """A NULL metadata column is mapped to an empty dict and an empty kb_id."""
        rows = [
            ("c1", "content", None, "d1", 0.5),
        ]
        sf, _ = _make_mock_session_factory(rows)
        mock_vs = _make_mock_vector_store()
        engine = RetrievalEngine(mock_vs, sf, embed_model=None)
        results = await engine.retrieve("test", ["kb1"], QueryMode.keywords, top_k=5)
        assert len(results) == 1
        assert results[0].metadata == {}
        assert results[0].kb_id == ""
# ---------------------------------------------------------------------------
# blend 模式测试
# ---------------------------------------------------------------------------
class TestBlendRetrieval:
    """Retrieval tests for ``QueryMode.blend`` (semantic + full-text)."""

    async def test_merges_and_deduplicates(self):
        """Blend merges embedding and keyword hits and de-duplicates by ``chunk_id``."""
        # Embedding hits.
        nodes = [
            _make_mock_text_node("c1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}),
            _make_mock_text_node("c2", "chunk 2", {"kb_id": "kb1", "document_id": "d1"}),
        ]
        mock_vs = _make_mock_vector_store(nodes, [0.9, 0.7])
        mock_embed = _make_mock_embed_model()
        # Keyword hits: c1 duplicates an embedding hit, c3 is new.
        rows = [
            ("c1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.8),
            ("c3", "chunk 3", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.6),
        ]
        sf, _ = _make_mock_session_factory(rows)
        engine = RetrievalEngine(mock_vs, sf, mock_embed)
        results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
        # After de-duplication exactly three results remain: c1, c2, c3.
        # The length check is essential — the set below collapses duplicates,
        # so on its own it would let a broken de-duplication pass.
        assert len(results) == 3
        chunk_ids = {r.chunk_id for r in results}
        assert chunk_ids == {"c1", "c2", "c3"}
        # Results are ordered by descending score.
        scores = [r.score for r in results]
        assert scores == sorted(scores, reverse=True)

    async def test_dedup_takes_highest_score(self):
        """When a chunk appears in both result sets, its highest blended score wins."""
        # Embedding: c1 score=0.9 (expected: normalised to 1.0, weighted 1.0 * 0.6 = 0.6).
        nodes = [
            _make_mock_text_node("c1", "chunk 1", {"kb_id": "kb1"}),
        ]
        mock_vs = _make_mock_vector_store(nodes, [0.9])
        mock_embed = _make_mock_embed_model()
        # Keywords: c1 score=0.8 (expected: normalised to 1.0, weighted 1.0 * 0.4 = 0.4).
        rows = [
            ("c1", "chunk 1", {"kb_id": "kb1"}, "d1", 0.8),
        ]
        sf, _ = _make_mock_session_factory(rows)
        engine = RetrievalEngine(mock_vs, sf, mock_embed)
        results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
        # c1 appears once and keeps the embedding score (0.6 > 0.4).
        assert len(results) == 1
        assert results[0].chunk_id == "c1"
        assert results[0].score == 0.6  # 1.0 * 0.6

    async def test_empty_results_no_error(self):
        """When both retrievers find nothing, an empty list is returned."""
        mock_vs = _make_mock_vector_store([], [])
        mock_embed = _make_mock_embed_model()
        sf, _ = _make_mock_session_factory([])
        engine = RetrievalEngine(mock_vs, sf, mock_embed)
        results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
        assert results == []

    async def test_respects_top_k(self):
        """Blend mode truncates the merged results to ``top_k``."""
        # The embedding retriever returns 3 hits.
        nodes = [_make_mock_text_node(f"n{i}", f"chunk {i}", {"kb_id": "kb1"}) for i in range(3)]
        mock_vs = _make_mock_vector_store(nodes, [0.9, 0.8, 0.7])
        mock_embed = _make_mock_embed_model()
        sf, _ = _make_mock_session_factory([])  # the keyword retriever finds nothing
        engine = RetrievalEngine(mock_vs, sf, mock_embed)
        results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=2)
        assert len(results) == 2

    async def test_fallback_when_embedding_fails(self):
        """If embedding retrieval raises, blend falls back to keywords only."""
        # The vector store query raises (configured on the mock).
        mock_vs = MagicMock()
        mock_vs.aquery = AsyncMock(side_effect=RuntimeError("connection failed"))
        mock_embed = _make_mock_embed_model()
        # The keyword retriever returns one row.
        rows = [
            ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
        ]
        sf, _ = _make_mock_session_factory(rows)
        engine = RetrievalEngine(mock_vs, sf, mock_embed)
        results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
        # Expect only the keyword result.
        assert len(results) == 1
        assert results[0].chunk_id == "c1"

    async def test_fallback_when_no_embed_model(self):
        """Without an ``embed_model``, blend falls back to keywords only."""
        mock_vs = _make_mock_vector_store()
        mock_embed = None
        rows = [
            ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
        ]
        sf, _ = _make_mock_session_factory(rows)
        engine = RetrievalEngine(mock_vs, sf, mock_embed)
        results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
        assert len(results) == 1
        assert results[0].chunk_id == "c1"
# ---------------------------------------------------------------------------
# 不支持的 mode 测试
# ---------------------------------------------------------------------------
class TestUnsupportedMode:
    """Behaviour when ``retrieve`` receives an unknown mode."""

    async def test_unsupported_mode_raises(self):
        """An unrecognised mode raises ``ValueError`` mentioning 'Unsupported'."""
        vector_store = _make_mock_vector_store()
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(vector_store, session_factory, embed_model=None)
        try:
            await engine.retrieve("query", ["kb1"], "invalid_mode", top_k=5)  # type: ignore[arg-type]
        except ValueError as exc:
            assert "Unsupported" in str(exc)
        else:
            raise AssertionError("Expected ValueError")