feat(rag_platform): U4 — dual-index retrieval (pgvector semantic + PG fulltext jieba)
Add fulltext.py: jieba tokenization + tsvector write/query. Add retrieval.py: RetrievalEngine with embedding/keywords/blend modes. Update models.py: add RetrievalRequest model. Tests: 35 new tests, 147 total passing.
This commit is contained in:
parent
3f9588e673
commit
fb9f16d6e5
|
|
@ -13,6 +13,7 @@ from agentkit.rag_platform.models import (
|
||||||
KnowledgeBase,
|
KnowledgeBase,
|
||||||
QueryMode,
|
QueryMode,
|
||||||
QueryResult,
|
QueryResult,
|
||||||
|
RetrievalRequest,
|
||||||
)
|
)
|
||||||
from agentkit.rag_platform.preview import (
|
from agentkit.rag_platform.preview import (
|
||||||
PreviewChunk,
|
PreviewChunk,
|
||||||
|
|
@ -46,6 +47,7 @@ __all__ = [
|
||||||
"PreviewResult",
|
"PreviewResult",
|
||||||
"QueryMode",
|
"QueryMode",
|
||||||
"QueryResult",
|
"QueryResult",
|
||||||
|
"RetrievalRequest",
|
||||||
"check_image_bomb",
|
"check_image_bomb",
|
||||||
"check_zip_bomb",
|
"check_zip_bomb",
|
||||||
"generate_preview",
|
"generate_preview",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,102 @@
|
||||||
|
"""PG 全文检索 — jieba 分词 + tsvector 写入/查询。
|
||||||
|
|
||||||
|
避免依赖 PG 扩展(pg_jieba/zhparser):在 Python 层用 jieba 分词后,
|
||||||
|
将 token 用空格连接写入 `search_vector` 列(`to_tsvector('simple', ...)`)。
|
||||||
|
查询时同样用 jieba 分词构造 tsquery(AND 语义)。
|
||||||
|
|
||||||
|
依赖 `rag_platform_kb_chunks` 表的 `search_vector` 列(由 PGVectorStore
|
||||||
|
hybrid_search=True 自动创建)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 表名 — 与 indexing.py 保持一致
|
||||||
|
KB_CHUNKS_TABLE = "rag_platform_kb_chunks"
|
||||||
|
|
||||||
|
# ponytail: tsquery 中需转义的字符(PG to_tsquery 语法)
|
||||||
|
# 升级路径:若需支持短语查询/权重,改用 phraseto_tsquery
|
||||||
|
_TSQUERY_SPECIAL = re.compile(r"[&|!():<>\"'\\]")
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize(text: str) -> str:
    """Segment ``text`` with jieba and return the tokens joined by spaces.

    The space-joined string is what gets written through ``to_tsvector``
    when the index is built. jieba runs in accurate mode
    (``cut_all=False``).
    """
    # jieba is imported at call time, not at module import time.
    import jieba

    return " ".join(jieba.cut(text, cut_all=False))
|
||||||
|
|
||||||
|
|
||||||
|
def build_tsquery(text: str) -> str:
    """Build a ``to_tsquery`` expression from ``text`` using jieba tokens.

    Tokens are joined with ``&`` (AND semantics). Empty and whitespace-only
    tokens are dropped, and characters that are special in the
    ``to_tsquery`` syntax are removed so they cannot cause a syntax error.

    Returns:
        The tsquery string, or ``""`` when no usable token remains.
    """
    import jieba

    operands: list[str] = []
    for token in jieba.cut(text, cut_all=False):
        # Special characters are replaced by spaces and the result is split
        # on whitespace. A token with an embedded special character (for
        # example a user-dictionary word such as "AT&T") therefore becomes
        # several operands instead of a single operand containing a space,
        # which to_tsquery rejects as a syntax error.
        operands.extend(_TSQUERY_SPECIAL.sub(" ", token).split())
    return " & ".join(operands)
|
||||||
|
|
||||||
|
|
||||||
|
async def write_search_vector(session: "AsyncSession", chunk_id: str, content: str) -> None:
    """Store the jieba-tokenized ``content`` in a chunk's ``search_vector`` column.

    The value is written as ``to_tsvector('simple', space_joined_tokens)``;
    the 'simple' configuration does no stemming, so the tokens produced by
    jieba are kept as they are.

    Args:
        session: SQLAlchemy async session (the caller is responsible for commit).
        chunk_id: ID of the chunk row (described by the original author as
            the node_id written by PGVectorStore).
        content: Raw text of the chunk.
    """
    # NOTE(review): the row is matched on the `id` column while chunk_id is
    # described as a node_id — confirm against the table schema.
    statement = text(
        f"UPDATE {KB_CHUNKS_TABLE} "  # noqa: S608 — table name is a constant, no injection risk
        "SET search_vector = to_tsvector('simple', :tokens) "
        "WHERE id = :chunk_id"
    )
    bind_params = {"tokens": tokenize(content), "chunk_id": chunk_id}
    await session.execute(statement, bind_params)
|
||||||
|
|
||||||
|
|
||||||
|
async def write_search_vector_batch(session: "AsyncSession", items: list[tuple[str, str]]) -> None:
    """Write ``search_vector`` for many chunks, e.g. to backfill after indexing.

    The rows are written one after another on the given session.

    Args:
        session: SQLAlchemy async session (the caller is responsible for commit).
        items: ``[(chunk_id, content), ...]``
    """
    for entry in items:
        chunk_id, content = entry
        await write_search_vector(session, chunk_id, content)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"build_tsquery",
|
||||||
|
"tokenize",
|
||||||
|
"write_search_vector",
|
||||||
|
"write_search_vector_batch",
|
||||||
|
]
|
||||||
|
|
@ -106,6 +106,22 @@ class QueryResult(BaseModel):
|
||||||
kb_id: str
|
kb_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalRequest(BaseModel):
    """Retrieval request — allows overriding a knowledge base's default settings.

    When ``retrieval_mode`` / ``hit_processing_mode`` is ``None`` the KB
    default is used.
    """

    model_config = ConfigDict()

    query: str
    kb_ids: list[str]
    retrieval_mode: QueryMode | None = None  # None = use the KB default
    hit_processing_mode: str | None = None  # None = use the KB default
    top_k: int = 5
    user_id: str | None = None  # used for ACL filtering
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# ORM Models (SQLAlchemy 2 DeclarativeBase + Mapped)
|
# ORM Models (SQLAlchemy 2 DeclarativeBase + Mapped)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,255 @@
|
||||||
|
"""Dual-index retrieval engine — embedding / keywords / blend modes.

- embedding: pgvector semantic search (LlamaIndex PGVectorStore.aquery)
- keywords: PostgreSQL full-text search (jieba tokenization + ts_rank)
- blend: run both retrievals concurrently, min-max normalize the scores,
  then merge, de-duplicate and sort

Modeled on the LlamaIndex hybrid retriever pattern (VectorStoreRetriever +
full-text search).
ACL filtering is done by the caller before `kb_ids` is passed in
(see acl.filter_kb_by_user_acl).
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from agentkit.rag_platform.fulltext import KB_CHUNKS_TABLE, build_tsquery
|
||||||
|
from agentkit.rag_platform.models import QueryMode, QueryResult
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from llama_index.core.embeddings import BaseEmbedding
|
||||||
|
from llama_index.core.vector_stores.types import VectorStoreQuery
|
||||||
|
from llama_index.vector_stores.postgres import PGVectorStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ponytail: 默认归一化策略 — min-max,将每种检索的分数映射到 [0,1]
|
||||||
|
# 升级路径:若需更精细的融合权重,可引入 RRF (Reciprocal Rank Fusion) 或学习权重
|
||||||
|
_BLEND_WEIGHT_EMBEDDING = 0.6
|
||||||
|
_BLEND_WEIGHT_KEYWORDS = 0.4
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_scores(scores: list[float]) -> list[float]:
    """Min-max normalize ``scores`` into the range [0, 1].

    An empty list yields an empty list. A constant list (including a single
    element) yields all ``1.0`` — every result is treated as maximally
    relevant.
    """
    if len(scores) == 0:
        return []
    floor = min(scores)
    ceiling = max(scores)
    if ceiling - floor < 1e-9:
        # Degenerate range: avoid dividing by (almost) zero.
        return [1.0] * len(scores)
    return [(value - floor) / (ceiling - floor) for value in scores]
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalEngine:
    """Dual-index retrieval engine with embedding / keywords / blend modes.

    Args:
        vector_store: LlamaIndex PGVectorStore instance (used by embedding mode).
        session_factory: Callable that returns an async context manager
            yielding a SQLAlchemy async session (used by keywords mode).
        embed_model: LlamaIndex BaseEmbedding (used by embedding mode). May be
            ``None`` when only keyword retrieval is needed.
    """

    def __init__(
        self,
        vector_store: "PGVectorStore",
        session_factory,
        embed_model: "BaseEmbedding | None" = None,
    ) -> None:
        self._vector_store = vector_store
        self._sf = session_factory
        self._embed_model = embed_model

    async def retrieve(
        self,
        query: str,
        kb_ids: list[str],
        mode: QueryMode,
        top_k: int = 5,
    ) -> list[QueryResult]:
        """Entry point — dispatch to the implementation selected by ``mode``.

        Args:
            query: Query text.
            kb_ids: Knowledge base IDs to search (already ACL-filtered by the caller).
            mode: Retrieval mode.
            top_k: Maximum number of results to return.

        Returns:
            A list of QueryResult. Keyword and blend results are sorted by
            score descending; embedding results keep the order returned by
            the vector store. An empty list is returned when there is nothing
            to search (no ``kb_ids``), when ``top_k`` is not positive, or when
            nothing matches.

        Raises:
            ValueError: If ``mode`` is not a supported QueryMode, or if
                embedding mode is requested without an ``embed_model``.
        """
        # A non-positive top_k can never produce results; returning early also
        # keeps a negative value from reaching the SQL LIMIT clause.
        if not kb_ids or top_k <= 0:
            return []

        if mode == QueryMode.embedding:
            return await self._retrieve_embedding(query, kb_ids, top_k)
        if mode == QueryMode.keywords:
            return await self._retrieve_keywords(query, kb_ids, top_k)
        if mode == QueryMode.blend:
            return await self._retrieve_blend(query, kb_ids, top_k)
        # pragma: no cover — the enum is exhaustive
        raise ValueError(f"Unsupported query mode: {mode}")

    async def _retrieve_embedding(
        self,
        query: str,
        kb_ids: list[str],
        top_k: int,
    ) -> list[QueryResult]:
        """Semantic retrieval through the vector store.

        The query is embedded and sent to ``vector_store.aquery`` with a
        metadata filter restricting ``kb_id`` to ``kb_ids``, so that chunks
        from other knowledge bases do not use up the ``top_k`` slots. The
        returned nodes are additionally checked against ``kb_ids`` before
        being converted to QueryResult.

        Raises:
            ValueError: If no ``embed_model`` was configured.
        """
        if self._embed_model is None:
            raise ValueError("embed_model is required for embedding retrieval mode")

        # Imported lazily so the module can be imported without llama_index.
        from llama_index.core.vector_stores.types import (
            FilterOperator,
            MetadataFilter,
            MetadataFilters,
            VectorStoreQuery,
        )

        # NOTE(review): LlamaIndex also offers aget_query_embedding for query
        # text; confirm aget_text_embedding is intended for the model in use.
        query_embedding = await self._embed_model.aget_text_embedding(query)
        kb_filter = MetadataFilters(
            filters=[
                MetadataFilter(key="kb_id", value=list(kb_ids), operator=FilterOperator.IN),
            ]
        )
        vs_query = VectorStoreQuery(
            query_embedding=query_embedding,
            similarity_top_k=top_k,
            filters=kb_filter,
        )
        result = await self._vector_store.aquery(vs_query)

        kb_set = set(kb_ids)
        out: list[QueryResult] = []
        # nodes / similarities may be None on a query result; treat as empty.
        for node, score in zip(result.nodes or [], result.similarities or []):
            meta = node.metadata or {}
            node_kb = meta.get("kb_id", "")
            # Defensive re-check in case the store ignored the filter.
            if node_kb not in kb_set:
                continue
            out.append(
                QueryResult(
                    chunk_id=node.node_id,
                    content=node.get_content(),
                    score=float(score) if score is not None else 0.0,
                    metadata=meta,
                    document_id=meta.get("document_id", ""),
                    kb_id=node_kb,
                )
            )
            if len(out) >= top_k:
                break
        return out

    async def _retrieve_keywords(
        self,
        query: str,
        kb_ids: list[str],
        top_k: int,
    ) -> list[QueryResult]:
        """PostgreSQL full-text retrieval (jieba tokenization + tsquery).

        Rows are ranked with ``ts_rank(search_vector, tsquery)`` and filtered
        by the ``kb_id`` value read from the ``metadata_`` column via
        ``metadata_->>'kb_id'``. Returns an empty list without running SQL
        when the query yields no tokens.
        """
        tsquery = build_tsquery(query)
        if not tsquery:
            return []

        # kb_ids is passed as a bound list parameter (= ANY(:kb_ids)) rather
        # than being concatenated into the SQL string.
        # Upgrade path: batch the query if the number of kb_ids ever becomes
        # too large for a single statement.
        # NOTE(review): the column names (id / text / metadata_ /
        # document_id / search_vector) are assumed to match the chunks table
        # created by the indexing code — confirm against the actual schema.
        sql = text(
            f"""
            SELECT id, text, metadata_, document_id,
            ts_rank(search_vector, to_tsquery('simple', :tsquery)) AS score
            FROM {KB_CHUNKS_TABLE}
            WHERE search_vector @@ to_tsquery('simple', :tsquery)
            AND metadata_->>'kb_id' = ANY(:kb_ids)
            ORDER BY score DESC
            LIMIT :top_k
            """  # noqa: S608 — table name is a constant
        )

        async with self._sf() as db:
            result = await db.execute(
                sql,
                {
                    "tsquery": tsquery,
                    "kb_ids": list(kb_ids),
                    "top_k": top_k,
                },
            )
            rows = result.all()

        out: list[QueryResult] = []
        for row in rows:
            chunk_id, content, metadata, document_id, score = row
            # Anything that is not a dict (e.g. NULL metadata) becomes {}.
            meta: dict[str, Any] = metadata if isinstance(metadata, dict) else {}
            kb_id = meta.get("kb_id", "")
            out.append(
                QueryResult(
                    chunk_id=str(chunk_id),
                    content=content or "",
                    score=float(score) if score is not None else 0.0,
                    metadata=meta,
                    document_id=str(document_id) if document_id is not None else "",
                    kb_id=str(kb_id) if kb_id is not None else "",
                )
            )
        return out

    async def _retrieve_blend(
        self,
        query: str,
        kb_ids: list[str],
        top_k: int,
    ) -> list[QueryResult]:
        """Merge semantic and full-text results, de-duplicated and sorted.

        Both retrievals run concurrently (each asked for ``top_k``). Scores
        are min-max normalized per retriever and multiplied by that
        retriever's blend weight. Results are de-duplicated by ``chunk_id``,
        keeping the entry with the higher weighted score.

        Only the embedding side degrades gracefully: if it fails or no
        ``embed_model`` is configured, the result is keyword-only. An error
        raised by the keyword retrieval propagates to the caller.
        """
        embed_results, kw_results = await asyncio.gather(
            self._safe_retrieve_embedding(query, kb_ids, top_k),
            self._retrieve_keywords(query, kb_ids, top_k),
        )

        # Normalize each result list separately, then apply its weight.
        embed_scores = _normalize_scores([r.score for r in embed_results])
        for r, s in zip(embed_results, embed_scores):
            r.score = s * _BLEND_WEIGHT_EMBEDDING

        kw_scores = _normalize_scores([r.score for r in kw_results])
        for r, s in zip(kw_results, kw_scores):
            r.score = s * _BLEND_WEIGHT_KEYWORDS

        # Merge and de-duplicate — for the same chunk_id keep the higher score.
        merged: dict[str, QueryResult] = {}
        for r in (*embed_results, *kw_results):
            existing = merged.get(r.chunk_id)
            if existing is None or r.score > existing.score:
                merged[r.chunk_id] = r

        # Sort by score descending and keep the first top_k.
        results = sorted(merged.values(), key=lambda x: x.score, reverse=True)
        return results[:top_k]

    async def _safe_retrieve_embedding(
        self,
        query: str,
        kb_ids: list[str],
        top_k: int,
    ) -> list[QueryResult]:
        """Fault-tolerant wrapper around embedding retrieval.

        Returns an empty list when no ``embed_model`` is configured or when
        the embedding retrieval raises, so the blend degrades to keyword-only.
        """
        if self._embed_model is None:
            return []
        try:
            return await self._retrieve_embedding(query, kb_ids, top_k)
        except Exception as e:
            # Deliberate best-effort: log (with traceback) and degrade.
            logger.warning(
                "Embedding retrieval failed, falling back to keywords only: %s",
                e,
                exc_info=True,
            )
            return []
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["RetrievalEngine"]
|
||||||
|
|
@ -0,0 +1,155 @@
|
||||||
|
"""U4 测试 — jieba 分词 + tsvector 写入/查询。
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
1. tokenize:中文分词后用空格连接
|
||||||
|
2. build_tsquery:构造 AND 语义的 tsquery,过滤空 token,转义特殊字符
|
||||||
|
3. write_search_vector:调用 session.execute 执行 UPDATE SQL
|
||||||
|
4. write_search_vector_batch:批量写入
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from agentkit.rag_platform.fulltext import (
|
||||||
|
KB_CHUNKS_TABLE,
|
||||||
|
build_tsquery,
|
||||||
|
tokenize,
|
||||||
|
write_search_vector,
|
||||||
|
write_search_vector_batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenize:
    """Tests for jieba tokenization."""

    def test_chinese_text_tokenized(self):
        """Chinese text is segmented and the tokens are joined with spaces."""
        joined = tokenize("我爱自然语言处理")
        assert isinstance(joined, str)
        # Accurate mode is expected to produce several space-separated tokens.
        assert " " in joined
        # The key term must appear in one of its two possible segmentations.
        assert any(term in joined for term in ("自然语言", "自然语言处理"))

    def test_english_text_preserved(self):
        """English words are kept as they are."""
        joined = tokenize("hello world")
        for word in ("hello", "world"):
            assert word in joined

    def test_mixed_text(self):
        """Mixed Chinese/English text is tokenized without error."""
        joined = tokenize("RAG 检索增强生成")
        assert isinstance(joined, str)
        assert len(joined) > 0

    def test_empty_string(self):
        """An empty string yields an empty string."""
        assert tokenize("") == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildTsquery:
    """Tests for tsquery construction."""

    def test_chinese_query_and_semantics(self):
        """A Chinese query is joined with ``&`` (AND semantics)."""
        result = build_tsquery("自然语言处理")
        assert isinstance(result, str)
        assert len(result) > 0
        # No empty operands: no doubled, leading or trailing "&".
        assert " & & " not in result
        assert not result.startswith("& ")
        assert not result.endswith(" &")

    def test_filters_empty_tokens(self):
        """Empty tokens are filtered out."""
        result = build_tsquery("hello world")
        # An empty token would show up as " & & ".
        assert " & & " not in result

    def test_escapes_special_chars(self):
        """Characters special to to_tsquery are removed from the operands."""
        # The input contains the tsquery operators & and |.
        result = build_tsquery("test & injection | attempt")
        # Every "&" left in the output must be a " & " joiner added by
        # build_tsquery, never a raw "&" carried over from the input.
        assert result.count("&") == result.count(" & ")
        assert isinstance(result, str)

    def test_empty_query_returns_empty(self):
        """An empty query yields an empty string."""
        assert build_tsquery("") == ""

    def test_whitespace_only_returns_empty(self):
        """A whitespace-only query yields an empty string."""
        assert build_tsquery(" ") == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestWriteSearchVector:
    """Tests for writing a single search_vector."""

    async def test_write_calls_execute_with_correct_sql(self):
        """write_search_vector awaits session.execute with the UPDATE parameters."""
        session = AsyncMock()
        session.execute = AsyncMock()

        await write_search_vector(session, "chunk-001", "测试内容")

        session.execute.assert_awaited_once()
        recorded = session.execute.await_args
        # First positional argument is the SQL text object, second the params.
        positional = recorded.args
        bound = positional[1] if len(positional) > 1 else recorded.kwargs
        assert bound["chunk_id"] == "chunk-001"
        # Both parts of the content are expected among the tokens.
        for expected in ("测试", "内容"):
            assert expected in bound["tokens"]

    async def test_write_uses_correct_table_name(self):
        """write_search_vector targets the chunks table and search_vector column."""
        session = AsyncMock()
        session.execute = AsyncMock()

        await write_search_vector(session, "c1", "content")

        rendered_sql = str(session.execute.await_args.args[0])
        assert KB_CHUNKS_TABLE in rendered_sql
        assert "to_tsvector" in rendered_sql
        assert "search_vector" in rendered_sql
|
||||||
|
|
||||||
|
|
||||||
|
class TestWriteSearchVectorBatch:
    """Tests for batch writes."""

    async def test_batch_writes_all_items(self):
        """One execute call is awaited per item."""
        session = AsyncMock()
        session.execute = AsyncMock()
        batch = [
            ("chunk-1", "内容一"),
            ("chunk-2", "内容二"),
            ("chunk-3", "内容三"),
        ]

        await write_search_vector_batch(session, batch)

        assert session.execute.await_count == 3

    async def test_batch_empty_items_no_calls(self):
        """An empty list does not call execute."""
        session = AsyncMock()
        session.execute = AsyncMock()

        await write_search_vector_batch(session, [])

        session.execute.assert_not_awaited()
|
||||||
|
|
@ -0,0 +1,467 @@
|
||||||
|
"""U4 测试 — 双索引检索引擎(embedding/keywords/blend)。
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
1. embedding 模式:语义检索返回相关结果,按 kb_id 过滤
|
||||||
|
2. keywords 模式:中文全文检索返回包含关键词的结果
|
||||||
|
3. blend 模式:合并语义+全文结果,去重排序
|
||||||
|
4. 查询无结果时返回空列表(非报错)
|
||||||
|
5. RetrievalRequest 模型字段
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from agentkit.rag_platform.models import QueryMode, QueryResult, RetrievalRequest
|
||||||
|
from agentkit.rag_platform.retrieval import RetrievalEngine, _normalize_scores
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# llama_index 模块 mock — 测试环境可能未安装 llama_index
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_llama_index_mocks():
    """Register stand-in ``llama_index`` modules in ``sys.modules``.

    The mocks are installed only when the real ``llama_index`` cannot be
    imported, so an environment with the real package is left untouched.
    Each dotted module name gets its own mock, and every parent mock exposes
    its registered child as an attribute, so ``from pkg.sub import name`` and
    attribute access resolve to the same objects.
    """
    try:
        import llama_index  # noqa: F401
    except ImportError:
        pass
    else:
        return  # the real package is available, nothing to mock

    mock_li = MagicMock()
    mock_li_core = MagicMock()
    mock_li_vector_stores = MagicMock()
    mock_li_core_vector_stores = MagicMock()
    mock_li_core_vector_stores_types = MagicMock()
    mock_li_core_embeddings = MagicMock()

    # VectorStoreQuery — a plain callable mock.
    mock_li_core_vector_stores_types.VectorStoreQuery = MagicMock()

    # Keep attribute access consistent with the sys.modules registrations.
    mock_li.core = mock_li_core
    mock_li.vector_stores = mock_li_vector_stores
    mock_li_core.vector_stores = mock_li_core_vector_stores
    mock_li_core.embeddings = mock_li_core_embeddings
    mock_li_core_vector_stores.types = mock_li_core_vector_stores_types

    sys.modules["llama_index"] = mock_li
    sys.modules["llama_index.core"] = mock_li_core
    sys.modules["llama_index.vector_stores"] = mock_li_vector_stores
    sys.modules["llama_index.core.vector_stores"] = mock_li_core_vector_stores
    sys.modules["llama_index.core.vector_stores.types"] = mock_li_core_vector_stores_types
    sys.modules["llama_index.core.embeddings"] = mock_li_core_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
_setup_llama_index_mocks()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 测试辅助函数
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_embed_model():
    """Return a mock embedding model whose async embed call yields a fixed vector."""
    fixed_vector = [0.1] * 1536
    return MagicMock(aget_text_embedding=AsyncMock(return_value=fixed_vector))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_text_node(node_id: str, content: str, metadata: dict | None = None):
    """Return a mock standing in for a LlamaIndex TextNode.

    The mock exposes ``node_id``, ``metadata`` (``{}`` when none is given)
    and a ``get_content()`` that returns ``content``.
    """
    text_node = MagicMock(node_id=node_id, metadata=metadata or {})
    text_node.get_content.return_value = content
    return text_node
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_vector_store(nodes=None, similarities=None):
    """Return a mock PGVectorStore with an awaitable ``aquery``.

    When ``nodes`` is given, ``aquery`` resolves to a result object carrying
    those nodes and ``similarities`` (defaulting to 0.9 for every node).
    """
    store = MagicMock()
    store.aquery = AsyncMock()
    if nodes is None:
        return store

    query_result = MagicMock()
    query_result.nodes = nodes
    query_result.similarities = similarities or [0.9] * len(nodes)
    store.aquery.return_value = query_result
    return store
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_session_factory(execute_result_rows=None):
    """Build a mock session factory for the keywords-mode tests.

    Args:
        execute_result_rows: Rows the SQL query should return, as a list of
            tuples shaped ``(id, text, metadata_, document_id, score)``.

    Returns:
        ``(factory, session)`` — ``factory()`` is an async context manager
        that yields ``session``.
    """
    query_result = MagicMock()
    query_result.all.return_value = execute_result_rows or []

    session = AsyncMock()
    session.execute = AsyncMock(return_value=query_result)

    @asynccontextmanager
    async def factory():
        yield session

    return factory, session
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RetrievalRequest 模型测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetrievalRequest:
    """Tests for the RetrievalRequest model."""

    def test_defaults(self):
        """Fields that are not supplied take their default values."""
        request = RetrievalRequest(query="test", kb_ids=["kb1"])
        assert (request.query, request.kb_ids) == ("test", ["kb1"])
        assert request.retrieval_mode is None
        assert request.hit_processing_mode is None
        assert request.top_k == 5
        assert request.user_id is None

    def test_explicit_values(self):
        """Explicitly supplied values are stored as given."""
        supplied = {
            "query": "hello",
            "kb_ids": ["kb1", "kb2"],
            "retrieval_mode": QueryMode.embedding,
            "hit_processing_mode": "direct",
            "top_k": 10,
            "user_id": "user1",
        }
        request = RetrievalRequest(**supplied)
        assert request.retrieval_mode == QueryMode.embedding
        assert request.hit_processing_mode == "direct"
        assert request.top_k == 10
        assert request.user_id == "user1"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _normalize_scores 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeScores:
    """Tests for score normalization."""

    def test_empty_list(self):
        """An empty list yields an empty list."""
        assert _normalize_scores([]) == []

    def test_constant_list_returns_ones(self):
        """A constant list (including one element) yields all 1.0."""
        assert _normalize_scores([0.5, 0.5, 0.5]) == [1.0, 1.0, 1.0]
        assert _normalize_scores([0.9]) == [1.0]

    def test_min_max_normalization(self):
        """Scores are min-max normalized into [0, 1]."""
        lowest, middle, highest = _normalize_scores([1.0, 2.0, 3.0])
        assert lowest == 0.0  # minimum maps to 0
        assert highest == 1.0  # maximum maps to 1
        assert 0.0 < middle < 1.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# embedding 模式测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingRetrieval:
    """Tests for embedding-mode retrieval."""

    async def test_returns_relevant_results(self):
        """Embedding mode returns the nodes supplied by the vector store."""
        shared_meta = {"kb_id": "kb1", "document_id": "d1"}
        store = _make_mock_vector_store(
            [
                _make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}),
                _make_mock_text_node("n2", "chunk 2", dict(shared_meta)),
            ],
            [0.95, 0.80],
        )
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(store, session_factory, _make_mock_embed_model())

        found = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)

        assert len(found) == 2
        assert all(isinstance(item, QueryResult) for item in found)
        first = found[0]
        assert first.chunk_id == "n1"
        assert first.score == 0.95
        assert first.kb_id == "kb1"

    async def test_filters_by_kb_id(self):
        """Embedding mode drops nodes whose kb_id is outside the requested set."""
        # n1 belongs to kb1, n2 belongs to kb2 (not part of the query scope).
        store = _make_mock_vector_store(
            [
                _make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"}),
                _make_mock_text_node("n2", "chunk 2", {"kb_id": "kb2"}),
            ],
            [0.9, 0.9],
        )
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(store, session_factory, _make_mock_embed_model())

        found = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)

        assert len(found) == 1
        assert found[0].kb_id == "kb1"

    async def test_empty_results_no_error(self):
        """No hits yields an empty list rather than an error."""
        store = _make_mock_vector_store([], [])
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(store, session_factory, _make_mock_embed_model())

        found = await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)

        assert found == []

    async def test_empty_kb_ids_returns_empty(self):
        """An empty kb_ids list short-circuits to an empty result."""
        store = _make_mock_vector_store()
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(store, session_factory, _make_mock_embed_model())

        found = await engine.retrieve("query", [], QueryMode.embedding, top_k=5)

        assert found == []
        store.aquery.assert_not_awaited()

    async def test_no_embed_model_raises(self):
        """Embedding mode without an embed_model raises ValueError."""
        store = _make_mock_vector_store()
        session_factory, _ = _make_mock_session_factory()
        engine = RetrievalEngine(store, session_factory, embed_model=None)

        try:
            await engine.retrieve("query", ["kb1"], QueryMode.embedding, top_k=5)
        except ValueError as err:
            assert "embed_model" in str(err)
        else:
            raise AssertionError("Expected ValueError")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# keywords 模式测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeywordsRetrieval:
    """Tests for keywords-mode retrieval."""

    async def test_returns_matching_results(self):
        """Keywords mode converts the returned rows into results."""
        session_factory, _ = _make_mock_session_factory(
            [
                ("c1", "自然语言处理内容", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.8),
                ("c2", "另一段文本", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.5),
            ]
        )
        engine = RetrievalEngine(_make_mock_vector_store(), session_factory, embed_model=None)

        found = await engine.retrieve("自然语言", ["kb1"], QueryMode.keywords, top_k=5)

        assert len(found) == 2
        top = found[0]
        assert top.chunk_id == "c1"
        assert top.score == 0.8
        assert top.kb_id == "kb1"
        assert "自然语言" in top.content

    async def test_empty_results_no_error(self):
        """No rows yields an empty list rather than an error."""
        session_factory, _ = _make_mock_session_factory([])
        engine = RetrievalEngine(_make_mock_vector_store(), session_factory, embed_model=None)

        found = await engine.retrieve("query", ["kb1"], QueryMode.keywords, top_k=5)

        assert found == []

    async def test_empty_query_returns_empty(self):
        """A query with no tokens returns an empty list without running SQL."""
        session_factory, session = _make_mock_session_factory([])
        engine = RetrievalEngine(_make_mock_vector_store(), session_factory, embed_model=None)

        found = await engine.retrieve(" ", ["kb1"], QueryMode.keywords, top_k=5)

        assert found == []
        session.execute.assert_not_awaited()

    async def test_passes_kb_ids_to_sql(self):
        """kb_ids and top_k are passed to the SQL query as parameters."""
        session_factory, session = _make_mock_session_factory(
            [("c1", "content", {"kb_id": "kb1"}, "d1", 0.5)]
        )
        engine = RetrievalEngine(_make_mock_vector_store(), session_factory, embed_model=None)

        await engine.retrieve("test", ["kb1", "kb2"], QueryMode.keywords, top_k=3)

        recorded = session.execute.await_args
        positional = recorded.args
        bound = positional[1] if len(positional) > 1 else recorded.kwargs
        assert bound["kb_ids"] == ["kb1", "kb2"]
        assert bound["top_k"] == 3

    async def test_handles_none_metadata(self):
        """A row whose metadata is None is handled without error."""
        session_factory, _ = _make_mock_session_factory(
            [("c1", "content", None, "d1", 0.5)]
        )
        engine = RetrievalEngine(_make_mock_vector_store(), session_factory, embed_model=None)

        found = await engine.retrieve("test", ["kb1"], QueryMode.keywords, top_k=5)

        assert len(found) == 1
        only = found[0]
        assert only.metadata == {}
        assert only.kb_id == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# blend 模式测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlendRetrieval:
    """Tests for ``RetrievalEngine.retrieve`` in ``QueryMode.blend`` mode."""

    async def test_merges_and_deduplicates(self):
        """Blend merges semantic and full-text hits, de-duplicated by chunk id."""
        # Semantic (embedding) side: c1 and c2.
        semantic_nodes = [
            _make_mock_text_node("c1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}),
            _make_mock_text_node("c2", "chunk 2", {"kb_id": "kb1", "document_id": "d1"}),
        ]
        vector_store = _make_mock_vector_store(semantic_nodes, [0.9, 0.7])
        embedder = _make_mock_embed_model()

        # Full-text side: c1 overlaps with the semantic hits, c3 is new.
        keyword_rows = [
            ("c1", "chunk 1", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.8),
            ("c3", "chunk 3", {"kb_id": "kb1", "document_id": "d1"}, "d1", 0.6),
        ]
        session_factory, _ = _make_mock_session_factory(keyword_rows)

        retriever = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await retriever.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        # After de-duplication exactly three distinct chunks remain.
        seen_ids = {hit.chunk_id for hit in hits}
        assert seen_ids == {"c1", "c2", "c3"}
        # Results must be ordered by descending score.
        ordered_scores = [hit.score for hit in hits]
        assert ordered_scores == sorted(ordered_scores, reverse=True)

    async def test_dedup_takes_highest_score(self):
        """When both sides return the same chunk, the higher score is kept."""
        # Semantic side: c1 with raw score 0.9. Per the original author's
        # note this normalizes to 1.0 and is weighted by 0.6 -> 0.6.
        semantic_nodes = [
            _make_mock_text_node("c1", "chunk 1", {"kb_id": "kb1"}),
        ]
        vector_store = _make_mock_vector_store(semantic_nodes, [0.9])
        embedder = _make_mock_embed_model()

        # Full-text side: c1 with raw score 0.8. Per the original author's
        # note this normalizes to 1.0 and is weighted by 0.4 -> 0.4.
        keyword_rows = [
            ("c1", "chunk 1", {"kb_id": "kb1"}, "d1", 0.8),
        ]
        session_factory, _ = _make_mock_session_factory(keyword_rows)

        retriever = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await retriever.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        # c1 appears once and carries the semantic-side score (0.6 > 0.4).
        assert len(hits) == 1
        survivor = hits[0]
        assert survivor.chunk_id == "c1"
        assert survivor.score == 0.6  # 1.0 * 0.6

    async def test_empty_results_no_error(self):
        """Blend returns an empty list when neither side produces a hit."""
        vector_store = _make_mock_vector_store([], [])
        embedder = _make_mock_embed_model()
        session_factory, _ = _make_mock_session_factory([])

        retriever = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await retriever.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        assert hits == []

    async def test_respects_top_k(self):
        """Blend truncates the merged result list to ``top_k`` entries."""
        # Semantic side yields three hits.
        semantic_nodes = [
            _make_mock_text_node(f"n{idx}", f"chunk {idx}", {"kb_id": "kb1"})
            for idx in range(3)
        ]
        vector_store = _make_mock_vector_store(semantic_nodes, [0.9, 0.8, 0.7])
        embedder = _make_mock_embed_model()
        # Full-text side yields nothing.
        session_factory, _ = _make_mock_session_factory([])

        retriever = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await retriever.retrieve("query", ["kb1"], QueryMode.blend, top_k=2)

        assert len(hits) == 2

    async def test_fallback_when_embedding_fails(self):
        """If the vector query raises, blend degrades to keyword-only results."""
        # Vector store whose async query always fails.
        vector_store = MagicMock()
        vector_store.aquery = AsyncMock(side_effect=RuntimeError("connection failed"))
        embedder = _make_mock_embed_model()

        # Full-text side still returns one row.
        keyword_rows = [
            ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
        ]
        session_factory, _ = _make_mock_session_factory(keyword_rows)

        retriever = RetrievalEngine(vector_store, session_factory, embedder)
        hits = await retriever.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        # Only the keyword hit should come back.
        assert len(hits) == 1
        assert hits[0].chunk_id == "c1"

    async def test_fallback_when_no_embed_model(self):
        """Without an embed model, blend degrades to keyword-only results."""
        vector_store = _make_mock_vector_store()

        keyword_rows = [
            ("c1", "content", {"kb_id": "kb1"}, "d1", 0.5),
        ]
        session_factory, _ = _make_mock_session_factory(keyword_rows)

        # No embed model is configured: pass None in the third position.
        retriever = RetrievalEngine(vector_store, session_factory, None)
        hits = await retriever.retrieve("query", ["kb1"], QueryMode.blend, top_k=5)

        assert len(hits) == 1
        assert hits[0].chunk_id == "c1"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 不支持的 mode 测试
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUnsupportedMode:
    """Tests for how ``RetrievalEngine.retrieve`` reacts to an unknown mode."""

    async def test_unsupported_mode_raises(self):
        """An unknown mode value raises ``ValueError`` mentioning "Unsupported"."""
        vector_store = _make_mock_vector_store()
        session_factory, _ = _make_mock_session_factory()

        retriever = RetrievalEngine(vector_store, session_factory, embed_model=None)
        try:
            await retriever.retrieve("query", ["kb1"], "invalid_mode", top_k=5)  # type: ignore[arg-type]
        except ValueError as err:
            assert "Unsupported" in str(err)
        else:
            # Reaching this branch means no exception was raised at all.
            raise AssertionError("Expected ValueError")
|
||||||
Loading…
Reference in New Issue