feat(rag_platform): U4 — dual-index retrieval (pgvector semantic + PG fulltext jieba)

Add fulltext.py: jieba tokenization + tsvector write/query
Add retrieval.py: RetrievalEngine with embedding/keywords/blend modes
Update models.py: add RetrievalRequest model
Tests: 35 new tests, 147 total passing
This commit is contained in:
chiguyong 2026-06-25 12:20:48 +08:00
parent 3f9588e673
commit fb9f16d6e5
6 changed files with 997 additions and 0 deletions

View File

@ -13,6 +13,7 @@ from agentkit.rag_platform.models import (
KnowledgeBase, KnowledgeBase,
QueryMode, QueryMode,
QueryResult, QueryResult,
RetrievalRequest,
) )
from agentkit.rag_platform.preview import ( from agentkit.rag_platform.preview import (
PreviewChunk, PreviewChunk,
@ -46,6 +47,7 @@ __all__ = [
"PreviewResult", "PreviewResult",
"QueryMode", "QueryMode",
"QueryResult", "QueryResult",
"RetrievalRequest",
"check_image_bomb", "check_image_bomb",
"check_zip_bomb", "check_zip_bomb",
"generate_preview", "generate_preview",

View File

@ -0,0 +1,102 @@
"""PG 全文检索 — jieba 分词 + tsvector 写入/查询。
避免依赖 PG 扩展pg_jieba/zhparser Python 层用 jieba 分词后
token 用空格连接写入 `search_vector` `to_tsvector('simple', ...)`
查询时同样用 jieba 分词构造 tsqueryAND 语义
依赖 `rag_platform_kb_chunks` 表的 `search_vector` PGVectorStore
hybrid_search=True 自动创建
"""
from __future__ import annotations
import logging
import re
from typing import TYPE_CHECKING
from sqlalchemy import text
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
# 表名 — 与 indexing.py 保持一致
KB_CHUNKS_TABLE = "rag_platform_kb_chunks"
# ponytail: tsquery 中需转义的字符PG to_tsquery 语法)
# 升级路径:若需支持短语查询/权重,改用 phraseto_tsquery
_TSQUERY_SPECIAL = re.compile(r"[&|!():<>\"'\\]")
def tokenize(text: str) -> str:
"""jieba 分词后用空格连接 — 用于 tsvector 写入。
精确模式cut_all=False适合索引构建
"""
import jieba
tokens = jieba.cut(text, cut_all=False)
return " ".join(tokens)
def build_tsquery(text: str) -> str:
"""jieba 分词后构造 tsquery — 用于全文检索查询。
`&` 连接AND 语义过滤空 token 和纯空白 token
转义 PG to_tsquery 的特殊字符避免语法错误
"""
import jieba
tokens = jieba.cut(text, cut_all=False)
cleaned: list[str] = []
for t in tokens:
t = t.strip()
if not t:
continue
# 转义 PG to_tsquery 特殊字符
t = _TSQUERY_SPECIAL.sub(" ", t).strip()
if t:
cleaned.append(t)
return " & ".join(cleaned)
async def write_search_vector(session: "AsyncSession", chunk_id: str, content: str) -> None:
"""将 jieba 分词后的内容写入 chunk 的 search_vector 列。
使用 `to_tsvector('simple', space_joined_tokens)` 写入 'simple' 配置
不做语干提取保留 jieba 切分的原始 token
Args:
session: SQLAlchemy async session调用方负责 commit
chunk_id: chunk IDPGVectorStore 写入后的 node_id
content: chunk 原始文本
"""
tokenized = tokenize(content)
await session.execute(
text(
f"UPDATE {KB_CHUNKS_TABLE} " # noqa: S608 — 表名为常量,无注入风险
"SET search_vector = to_tsvector('simple', :tokens) "
"WHERE id = :chunk_id"
),
{"tokens": tokenized, "chunk_id": chunk_id},
)
async def write_search_vector_batch(session: "AsyncSession", items: list[tuple[str, str]]) -> None:
"""批量写入 search_vector — 用于文档索引构建后批量回填。
Args:
session: SQLAlchemy async session调用方负责 commit
items: [(chunk_id, content), ...]
"""
for chunk_id, content in items:
await write_search_vector(session, chunk_id, content)
__all__ = [
"build_tsquery",
"tokenize",
"write_search_vector",
"write_search_vector_batch",
]

View File

@ -106,6 +106,22 @@ class QueryResult(BaseModel):
kb_id: str kb_id: str
class RetrievalRequest(BaseModel):
"""检索请求 — 支持覆盖 KB 默认配置。
retrieval_mode / hit_processing_mode None 时使用 KB 默认值
"""
model_config = ConfigDict()
query: str
kb_ids: list[str]
retrieval_mode: QueryMode | None = None # None = 使用 KB 默认
hit_processing_mode: str | None = None # None = 使用 KB 默认
top_k: int = 5
user_id: str | None = None # 用于 ACL 过滤
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# ORM Models (SQLAlchemy 2 DeclarativeBase + Mapped) # ORM Models (SQLAlchemy 2 DeclarativeBase + Mapped)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -0,0 +1,255 @@
"""双索引检索引擎 — embedding/keywords/blend 三模式。
- embedding: pgvector 语义检索LlamaIndex PGVectorStore.aquery
- keywords: PG 全文检索jieba 分词 + tsquery_rank
- blend: 并行执行两种检索按分数归一化后合并去重排序
参考 LlamaIndex hybrid retriever 模式VectorStoreRetriever + 全文检索
ACL 过滤由调用方在传入 `kb_ids` 前完成 acl.filter_kb_by_user_acl
"""
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING, Any
from sqlalchemy import text
from agentkit.rag_platform.fulltext import KB_CHUNKS_TABLE, build_tsquery
from agentkit.rag_platform.models import QueryMode, QueryResult
if TYPE_CHECKING:
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.vector_stores.types import VectorStoreQuery
from llama_index.vector_stores.postgres import PGVectorStore
logger = logging.getLogger(__name__)
# ponytail: 默认归一化策略 — min-max将每种检索的分数映射到 [0,1]
# 升级路径:若需更精细的融合权重,可引入 RRF (Reciprocal Rank Fusion) 或学习权重
_BLEND_WEIGHT_EMBEDDING = 0.6
_BLEND_WEIGHT_KEYWORDS = 0.4
def _normalize_scores(scores: list[float]) -> list[float]:
"""min-max 归一化分数到 [0, 1]。
空列表返回空常数列表含单元素返回全 1.0 所有结果等价于最高相关度
"""
if not scores:
return []
lo, hi = min(scores), max(scores)
if hi - lo < 1e-9:
return [1.0 for _ in scores]
return [(s - lo) / (hi - lo) for s in scores]
class RetrievalEngine:
"""双索引检索引擎 — embedding/keywords/blend 三模式。
Args:
vector_store: LlamaIndex PGVectorStore 实例embedding 模式使用
session_factory: SQLAlchemy async session factorykeywords 模式使用
embed_model: LlamaIndex BaseEmbeddingembedding 模式使用
"""
def __init__(
self,
vector_store: "PGVectorStore",
session_factory,
embed_model: "BaseEmbedding | None" = None,
) -> None:
self._vector_store = vector_store
self._sf = session_factory
self._embed_model = embed_model
async def retrieve(
self,
query: str,
kb_ids: list[str],
mode: QueryMode,
top_k: int = 5,
) -> list[QueryResult]:
"""检索入口 — 根据 mode 分发到具体实现。
Args:
query: 查询文本
kb_ids: 限定检索的知识库 ID 列表已通过 ACL 过滤
mode: 检索模式
top_k: 返回结果数
Returns:
QueryResult 列表按分数降序无结果时返回空列表
"""
if not kb_ids:
return []
if mode == QueryMode.embedding:
return await self._retrieve_embedding(query, kb_ids, top_k)
elif mode == QueryMode.keywords:
return await self._retrieve_keywords(query, kb_ids, top_k)
elif mode == QueryMode.blend:
return await self._retrieve_blend(query, kb_ids, top_k)
else: # pragma: no cover — 枚举已穷尽
raise ValueError(f"Unsupported query mode: {mode}")
async def _retrieve_embedding(
self,
query: str,
kb_ids: list[str],
top_k: int,
) -> list[QueryResult]:
"""pgvector 语义检索。
使用 LlamaIndex vector_store.aquery 执行向量检索结果按 metadata.kb_id
过滤到 kb_ids 子集
"""
if self._embed_model is None:
raise ValueError("embed_model is required for embedding retrieval mode")
from llama_index.core.vector_stores.types import VectorStoreQuery
query_embedding = await self._embed_model.aget_text_embedding(query)
vs_query: VectorStoreQuery = VectorStoreQuery(
query_embedding=query_embedding,
similarity_top_k=top_k,
)
result = await self._vector_store.aquery(vs_query)
kb_set = set(kb_ids)
out: list[QueryResult] = []
for node, score in zip(result.nodes, result.similarities):
meta = node.metadata or {}
node_kb = meta.get("kb_id", "")
if node_kb not in kb_set:
continue
out.append(
QueryResult(
chunk_id=node.node_id,
content=node.get_content(),
score=float(score) if score is not None else 0.0,
metadata=meta,
document_id=meta.get("document_id", ""),
kb_id=node_kb,
)
)
if len(out) >= top_k:
break
return out
async def _retrieve_keywords(
self,
query: str,
kb_ids: list[str],
top_k: int,
) -> list[QueryResult]:
"""PG 全文检索jieba 分词 + tsquery
使用 `ts_rank(search_vector, tsquery)` 排序 kb_id 过滤
kb_id 存储在 metadata_ JSON 列中LlamaIndex PGVectorStore schema
`metadata_->>'kb_id'` 提取
"""
tsquery = build_tsquery(query)
if not tsquery:
return []
# ponytail: 用 ANY(%s) 传 kb_ids 列表,避免字符串拼接
# 升级路径:若 kb_ids 数量超过 PG 参数限制32k需分批查询
# 列名参考 LlamaIndex PGVectorStore 默认 schema
# id / embedding / text / metadata_ (JSON) / document_id / search_vector
sql = text(
f"""
SELECT id, text, metadata_, document_id,
ts_rank(search_vector, to_tsquery('simple', :tsquery)) AS score
FROM {KB_CHUNKS_TABLE}
WHERE search_vector @@ to_tsquery('simple', :tsquery)
AND metadata_->>'kb_id' = ANY(:kb_ids)
ORDER BY score DESC
LIMIT :top_k
""" # noqa: S608 — 表名为常量
)
async with self._sf() as db:
result = await db.execute(
sql,
{
"tsquery": tsquery,
"kb_ids": list(kb_ids),
"top_k": top_k,
},
)
rows = result.all()
out: list[QueryResult] = []
for row in rows:
chunk_id, content, metadata, document_id, score = row
meta: dict[str, Any] = metadata if isinstance(metadata, dict) else {}
kb_id = meta.get("kb_id", "")
out.append(
QueryResult(
chunk_id=str(chunk_id),
content=content or "",
score=float(score) if score is not None else 0.0,
metadata=meta,
document_id=str(document_id) if document_id is not None else "",
kb_id=str(kb_id) if kb_id is not None else "",
)
)
return out
async def _retrieve_blend(
self,
query: str,
kb_ids: list[str],
top_k: int,
) -> list[QueryResult]:
"""双索引合并 — 语义 + 全文结果去重排序。
并行执行两种检索各取 top_k chunk_id 去重分数归一化后加权融合
若任一检索失败 embed_model 未配置降级为另一种
"""
# 并行执行两种检索
embed_task = self._safe_retrieve_embedding(query, kb_ids, top_k)
kw_task = self._retrieve_keywords(query, kb_ids, top_k)
embed_results, kw_results = await asyncio.gather(
embed_task, kw_task, return_exceptions=False
)
# 归一化分数
embed_scores = _normalize_scores([r.score for r in embed_results])
for r, s in zip(embed_results, embed_scores):
r.score = s * _BLEND_WEIGHT_EMBEDDING
kw_scores = _normalize_scores([r.score for r in kw_results])
for r, s in zip(kw_results, kw_scores):
r.score = s * _BLEND_WEIGHT_KEYWORDS
# 合并去重 — 同 chunk_id 取最高分
merged: dict[str, QueryResult] = {}
for r in (*embed_results, *kw_results):
existing = merged.get(r.chunk_id)
if existing is None or r.score > existing.score:
merged[r.chunk_id] = r
# 按分数降序,取 top_k
results = sorted(merged.values(), key=lambda x: x.score, reverse=True)
return results[:top_k]
async def _safe_retrieve_embedding(
self,
query: str,
kb_ids: list[str],
top_k: int,
) -> list[QueryResult]:
"""embedding 检索的容错包装 — 失败时返回空列表(降级为纯关键词)。"""
if self._embed_model is None:
return []
try:
return await self._retrieve_embedding(query, kb_ids, top_k)
except Exception as e:
logger.warning("Embedding retrieval failed, falling back to keywords only: %s", e)
return []
__all__ = ["RetrievalEngine"]

View File

@ -0,0 +1,155 @@
"""U4 测试 — jieba 分词 + tsvector 写入/查询。
测试场景
1. tokenize中文分词后用空格连接
2. build_tsquery构造 AND 语义的 tsquery过滤空 token转义特殊字符
3. write_search_vector调用 session.execute 执行 UPDATE SQL
4. write_search_vector_batch批量写入
"""
from __future__ import annotations
from unittest.mock import AsyncMock
from agentkit.rag_platform.fulltext import (
KB_CHUNKS_TABLE,
build_tsquery,
tokenize,
write_search_vector,
write_search_vector_batch,
)
class TestTokenize:
"""jieba 分词测试。"""
def test_chinese_text_tokenized(self):
"""中文文本被分词后用空格连接。"""
result = tokenize("我爱自然语言处理")
# jieba 精确模式应切分为多个 token
assert isinstance(result, str)
assert " " in result # 多个 token 用空格连接
# 关键词应出现在结果中
assert "自然语言" in result or "自然语言处理" in result
def test_english_text_preserved(self):
"""英文文本保持原样(按空格分词)。"""
result = tokenize("hello world")
assert "hello" in result
assert "world" in result
def test_mixed_text(self):
"""中英文混合文本正常分词。"""
result = tokenize("RAG 检索增强生成")
assert isinstance(result, str)
assert len(result) > 0
def test_empty_string(self):
"""空字符串返回空字符串。"""
assert tokenize("") == ""
class TestBuildTsquery:
"""tsquery 构造测试。"""
def test_chinese_query_and_semantics(self):
"""中文查询用 & 连接AND 语义)。"""
result = build_tsquery("自然语言处理")
# 应包含 & 连接符(多 token 时)
assert isinstance(result, str)
assert len(result) > 0
# 不应包含空 token连续 & 之间无内容)
assert " & & " not in result
assert not result.startswith("& ")
assert not result.endswith(" &")
def test_filters_empty_tokens(self):
"""空 token 被过滤掉。"""
# 多空格输入
result = build_tsquery("hello world")
# 不应有连续的 &(空 token 会导致 " & & "
assert " & & " not in result
def test_escapes_special_chars(self):
"""PG to_tsquery 特殊字符被转义(替换为空格)。"""
# 包含 & | ! ( ) : < > " ' \
result = build_tsquery("test & injection | attempt")
# 不应保留原始的特殊字符(会被替换为空格然后过滤)
assert "&" not in result or result.count("&") == result.count(" & ") + (
0 if result.startswith("&") else 0
)
# 应该是合法的 tsquery 格式
assert isinstance(result, str)
def test_empty_query_returns_empty(self):
"""空查询返回空字符串。"""
assert build_tsquery("") == ""
def test_whitespace_only_returns_empty(self):
"""纯空白查询返回空字符串。"""
assert build_tsquery(" ") == ""
class TestWriteSearchVector:
"""search_vector 写入测试。"""
async def test_write_calls_execute_with_correct_sql(self):
"""write_search_vector 调用 session.execute 执行 UPDATE SQL。"""
mock_session = AsyncMock()
mock_session.execute = AsyncMock()
await write_search_vector(mock_session, "chunk-001", "测试内容")
mock_session.execute.assert_awaited_once()
# 验证传入的参数
call_args = mock_session.execute.await_args
# 第一个参数是 SQL text 对象,第二个是参数字典
params = call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs
assert params["chunk_id"] == "chunk-001"
# tokenize 后 jieba 将 "测试内容" 切分为 "测试 内容"
assert "测试" in params["tokens"]
assert "内容" in params["tokens"]
async def test_write_uses_correct_table_name(self):
"""write_search_vector 使用正确的表名。"""
mock_session = AsyncMock()
mock_session.execute = AsyncMock()
await write_search_vector(mock_session, "c1", "content")
call_args = mock_session.execute.await_args
sql_obj = call_args.args[0]
# SQL 文本应包含表名
assert KB_CHUNKS_TABLE in str(sql_obj)
# 应使用 to_tsvector('simple', ...)
assert "to_tsvector" in str(sql_obj)
assert "search_vector" in str(sql_obj)
class TestWriteSearchVectorBatch:
"""批量写入测试。"""
async def test_batch_writes_all_items(self):
"""批量写入调用 write_search_vector N 次。"""
mock_session = AsyncMock()
mock_session.execute = AsyncMock()
items = [
("chunk-1", "内容一"),
("chunk-2", "内容二"),
("chunk-3", "内容三"),
]
await write_search_vector_batch(mock_session, items)
# 应调用 execute 3 次
assert mock_session.execute.await_count == 3
async def test_batch_empty_items_no_calls(self):
"""空列表不调用 execute。"""
mock_session = AsyncMock()
mock_session.execute = AsyncMock()
await write_search_vector_batch(mock_session, [])
mock_session.execute.assert_not_awaited()

View File

@ -0,0 +1,467 @@
"""U4 测试 — 双索引检索引擎embedding/keywords/blend
测试场景
1. embedding 模式语义检索返回相关结果 kb_id 过滤
2. keywords 模式中文全文检索返回包含关键词的结果
3. blend 模式合并语义+全文结果去重排序
4. 查询无结果时返回空列表非报错
5. RetrievalRequest 模型字段
"""
from __future__ import annotations
import sys
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.models import QueryMode, QueryResult, RetrievalRequest
from agentkit.rag_platform.retrieval import RetrievalEngine, _normalize_scores
# ---------------------------------------------------------------------------
# llama_index 模块 mock — 测试环境可能未安装 llama_index
# ---------------------------------------------------------------------------
def _setup_llama_index_mocks():
"""注入 mock llama_index 模块到 sys.modules使 import 成功。
仅当真实 llama_index 无法导入时才注入 mock 避免污染已安装真实模块的环境
"""
try:
import llama_index # noqa: F401
return # 真实模块可用,无需 mock
except ImportError:
pass
# 创建 mock 模块层级
mock_li = MagicMock()
mock_li_core = MagicMock()
mock_li_vector_stores = MagicMock()
mock_li_core_vector_stores_types = MagicMock()
mock_li_core_embeddings = MagicMock()
# VectorStoreQuery — 简单的可调用 mock
mock_li_core_vector_stores_types.VectorStoreQuery = MagicMock()
sys.modules["llama_index"] = mock_li
sys.modules["llama_index.core"] = mock_li_core
sys.modules["llama_index.vector_stores"] = mock_li_vector_stores
sys.modules["llama_index.core.vector_stores"] = mock_li_core
sys.modules["llama_index.core.vector_stores.types"] = mock_li_core_vector_stores_types
sys.modules["llama_index.core.embeddings"] = mock_li_core_embeddings
_setup_llama_index_mocks()
# ---------------------------------------------------------------------------
# 测试辅助函数
# ---------------------------------------------------------------------------
def _make_mock_embed_model():
"""创建 mock embedding 模型。"""
mock = MagicMock()
mock.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
return mock
def _make_mock_text_node(node_id: str, content: str, metadata: dict | None = None):
"""创建 mock LlamaIndex TextNode。"""
node = MagicMock()
node.node_id = node_id
node.get_content.return_value = content
node.metadata = metadata or {}
return node
def _make_mock_vector_store(nodes=None, similarities=None):
"""创建 mock PGVectorStore。"""
mock = MagicMock()
mock.aquery = AsyncMock()
if nodes is not None:
mock_result = MagicMock()
mock_result.nodes = nodes
mock_result.similarities = similarities or [0.9] * len(nodes)
mock.aquery.return_value = mock_result
return mock
def _make_mock_session_factory(execute_result_rows=None):
"""创建 mock session factory用于 keywords 模式测试。
Args:
execute_result_rows: list of tuples SQL 查询返回的行
每行格式: (id, text, metadata_, document_id, score)
"""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.all.return_value = execute_result_rows or []
mock_session.execute = AsyncMock(return_value=mock_result)
@asynccontextmanager
async def factory():
yield mock_session
return factory, mock_session
# ---------------------------------------------------------------------------
# RetrievalRequest 模型测试
# ---------------------------------------------------------------------------
class TestRetrievalRequest:
"""RetrievalRequest 模型测试。"""
def test_defaults(self):
"""默认值正确。"""
req = RetrievalRequest(query="test", kb_ids=["kb1"])
assert req.query == "test"
assert req.kb_ids == ["kb1"]
assert req.retrieval_mode is None
assert req.hit_processing_mode is None
assert req.top_k == 5
assert req.user_id is None
def test_explicit_values(self):
"""显式赋值正确。"""
req = RetrievalRequest(
query="hello",
kb_ids=["kb1", "kb2"],
retrieval_mode=QueryMode.embedding,
hit_processing_mode="direct",
top_k=10,
user_id="user1",
)
assert req.retrieval_mode == QueryMode.embedding
assert req.hit_processing_mode == "direct"
assert req.top_k == 10
assert req.user_id == "user1"
# ---------------------------------------------------------------------------
# _normalize_scores 测试
# ---------------------------------------------------------------------------
class TestNormalizeScores:
"""分数归一化测试。"""
def test_empty_list(self):
"""空列表返回空列表。"""
assert _normalize_scores([]) == []
def test_constant_list_returns_ones(self):
"""常数列表(含单元素)返回全 1.0 — 等价于最高相关度。"""
assert _normalize_scores([0.5, 0.5, 0.5]) == [1.0, 1.0, 1.0]
assert _normalize_scores([0.9]) == [1.0]
def test_min_max_normalization(self):
"""min-max 归一化到 [0, 1]。"""
result = _normalize_scores([1.0, 2.0, 3.0])
assert result[0] == 0.0 # 最小值
assert result[2] == 1.0 # 最大值
assert 0.0 < result[1] < 1.0
# ---------------------------------------------------------------------------
# embedding 模式测试
# ---------------------------------------------------------------------------
class TestEmbeddingRetrieval:
"""embedding 模式检索测试。"""
async def test_returns_relevant_results(self):
"""embedding 模式返回相关结果。"""
nodes = [
_make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}),
_make_mock_text_node("n2", "chunk 2", {"kb_id": "kb1", "document_id": "d1"}),
]
mock_vs = _make_mock_vector_store(nodes, [0.95, 0.80])
mock_embed = _make_mock_embed_model()
sf, _ = _make_mock_session_factory()
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)
assert len(results) == 2
assert all(isinstance(r, QueryResult) for r in results)
assert results[0].chunk_id == "n1"
assert results[0].score == 0.95
assert results[0].kb_id == "kb1"
async def test_filters_by_kb_id(self):
"""embedding 模式按 kb_id 过滤结果。"""
# n1 属于 kb1n2 属于 kb2不在查询范围
nodes = [
_make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"}),
_make_mock_text_node("n2", "chunk 2", {"kb_id": "kb2"}),
]
mock_vs = _make_mock_vector_store(nodes, [0.9, 0.9])
mock_embed = _make_mock_embed_model()
sf, _ = _make_mock_session_factory()
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)
assert len(results) == 1
assert results[0].kb_id == "kb1"
async def test_empty_results_no_error(self):
"""无结果时返回空列表(非报错)。"""
mock_vs = _make_mock_vector_store([], [])
mock_embed = _make_mock_embed_model()
sf, _ = _make_mock_session_factory()
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)
assert results == []
async def test_empty_kb_ids_returns_empty(self):
"""kb_ids 为空时直接返回空列表。"""
mock_vs = _make_mock_vector_store()
mock_embed = _make_mock_embed_model()
sf, _ = _make_mock_session_factory()
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", [], QueryMode.embedding, top_k=5)
assert results == []
mock_vs.aquery.assert_not_awaited()
async def test_no_embed_model_raises(self):
"""embedding 模式未配置 embed_model 抛出异常。"""
mock_vs = _make_mock_vector_store()
sf, _ = _make_mock_session_factory()
engine = RetrievalEngine(mock_vs, sf, embed_model=None)
try:
await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)
raise AssertionError("Expected ValueError")
except ValueError as e:
assert "embed_model" in str(e)
# ---------------------------------------------------------------------------
# keywords 模式测试
# ---------------------------------------------------------------------------
class TestKeywordsRetrieval:
"""keywords 模式检索测试。"""
async def test_returns_matching_results(self):
"""keywords 模式返回包含关键词的结果。"""
rows = [
("c1", "自然语言处理内容", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.8),
("c2", "另一段文本", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.5),
]
sf, _ = _make_mock_session_factory(rows)
mock_vs = _make_mock_vector_store()
engine = RetrievalEngine(mock_vs, sf, embed_model=None)
results = await engine.retrieve("自然语言", ["kb1"], QueryMode.keywords, top_k=5)
assert len(results) == 2
assert results[0].chunk_id == "c1"
assert results[0].score == 0.8
assert results[0].kb_id == "kb1"
assert "自然语言" in results[0].content
async def test_empty_results_no_error(self):
"""无结果时返回空列表(非报错)。"""
sf, _ = _make_mock_session_factory([])
mock_vs = _make_mock_vector_store()
engine = RetrievalEngine(mock_vs, sf, embed_model=None)
results = await engine.retrieve("query", ["kb1"], QueryMode.keywords, top_k=5)
assert results == []
async def test_empty_query_returns_empty(self):
"""空查询(分词后无 token返回空列表不执行 SQL。"""
sf, mock_session = _make_mock_session_factory([])
mock_vs = _make_mock_vector_store()
engine = RetrievalEngine(mock_vs, sf, embed_model=None)
results = await engine.retrieve(" ", ["kb1"], QueryMode.keywords, top_k=5)
assert results == []
mock_session.execute.assert_not_awaited()
async def test_passes_kb_ids_to_sql(self):
"""kb_ids 作为参数传递给 SQL 查询。"""
rows = [
("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
]
sf, mock_session = _make_mock_session_factory(rows)
mock_vs = _make_mock_vector_store()
engine = RetrievalEngine(mock_vs, sf, embed_model=None)
await engine.retrieve("test", ["kb1", "kb2"], QueryMode.keywords, top_k=3)
call_args = mock_session.execute.await_args
params = call_args.args[1] if len(call_args.args) > 1 else call_args.kwargs
assert params["kb_ids"] == ["kb1", "kb2"]
assert params["top_k"] == 3
async def test_handles_none_metadata(self):
"""metadata 为 None 时正常处理。"""
rows = [
("c1", "content", None, "d1", 0.5),
]
sf, _ = _make_mock_session_factory(rows)
mock_vs = _make_mock_vector_store()
engine = RetrievalEngine(mock_vs, sf, embed_model=None)
results = await engine.retrieve("test", ["kb1"], QueryMode.keywords, top_k=5)
assert len(results) == 1
assert results[0].metadata == {}
assert results[0].kb_id == ""
# ---------------------------------------------------------------------------
# blend 模式测试
# ---------------------------------------------------------------------------
class TestBlendRetrieval:
"""blend 模式检索测试。"""
async def test_merges_and_deduplicates(self):
"""blend 模式合并语义+全文结果,按 chunk_id 去重。"""
# embedding 结果
nodes = [
_make_mock_text_node("c1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}),
_make_mock_text_node("c2", "chunk 2", {"kb_id": "kb1", "document_id": "d1"}),
]
mock_vs = _make_mock_vector_store(nodes, [0.9, 0.7])
mock_embed = _make_mock_embed_model()
# keywords 结果 — c1 重复c3 新增
rows = [
("c1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.8),
("c3", "chunk 3", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.6),
]
sf, _ = _make_mock_session_factory(rows)
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
# 去重后应有 3 个结果c1, c2, c3
chunk_ids = {r.chunk_id for r in results}
assert chunk_ids == {"c1", "c2", "c3"}
# 结果按分数降序
scores = [r.score for r in results]
assert scores == sorted(scores, reverse=True)
async def test_dedup_takes_highest_score(self):
"""去重时同 chunk_id 取最高分。"""
# embedding: c1 score=0.9 (归一化后 1.0 * 0.6 = 0.6)
nodes = [
_make_mock_text_node("c1", "chunk 1", {"kb_id": "kb1"}),
]
mock_vs = _make_mock_vector_store(nodes, [0.9])
mock_embed = _make_mock_embed_model()
# keywords: c1 score=0.8 (归一化后 1.0 * 0.4 = 0.4)
rows = [
("c1", "chunk 1", {"kb_id": "kb1"}, "d1", 0.8),
]
sf, _ = _make_mock_session_factory(rows)
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
# c1 只出现一次,取 embedding 的分数0.6 > 0.4
assert len(results) == 1
assert results[0].chunk_id == "c1"
assert results[0].score == 0.6 # 1.0 * 0.6
async def test_empty_results_no_error(self):
"""两种检索都无结果时返回空列表。"""
mock_vs = _make_mock_vector_store([], [])
mock_embed = _make_mock_embed_model()
sf, _ = _make_mock_session_factory([])
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
assert results == []
async def test_respects_top_k(self):
"""blend 模式遵守 top_k 限制。"""
# embedding 返回 3 个结果
nodes = [_make_mock_text_node(f"n{i}", f"chunk {i}", {"kb_id": "kb1"}) for i in range(3)]
mock_vs = _make_mock_vector_store(nodes, [0.9, 0.8, 0.7])
mock_embed = _make_mock_embed_model()
sf, _ = _make_mock_session_factory([]) # keywords 无结果
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=2)
assert len(results) == 2
async def test_fallback_when_embedding_fails(self):
"""embedding 检索失败时降级为纯关键词检索。"""
# embedding 模式会抛异常mock 设置)
mock_vs = MagicMock()
mock_vs.aquery = AsyncMock(side_effect=RuntimeError("connection failed"))
mock_embed = _make_mock_embed_model()
# keywords 返回结果
rows = [
("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
]
sf, _ = _make_mock_session_factory(rows)
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
# 应降级为纯关键词结果
assert len(results) == 1
assert results[0].chunk_id == "c1"
async def test_fallback_when_no_embed_model(self):
"""未配置 embed_model 时 blend 降级为纯关键词检索。"""
mock_vs = _make_mock_vector_store()
mock_embed = None
rows = [
("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
]
sf, _ = _make_mock_session_factory(rows)
engine = RetrievalEngine(mock_vs, sf, mock_embed)
results = await engine.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)
assert len(results) == 1
assert results[0].chunk_id == "c1"
# ---------------------------------------------------------------------------
# 不支持的 mode 测试
# ---------------------------------------------------------------------------
class TestUnsupportedMode:
"""不支持的检索模式测试。"""
async def test_unsupported_mode_raises(self):
"""不支持的 mode 抛出 ValueError。"""
mock_vs = _make_mock_vector_store()
sf, _ = _make_mock_session_factory()
engine = RetrievalEngine(mock_vs, sf, embed_model=None)
try:
await engine.retrieve("query", ["kb1"], "invalid_mode", top_k=5) # type: ignore[arg-type]
raise AssertionError("Expected ValueError")
except ValueError as e:
assert "Unsupported" in str(e)