242 lines
7.8 KiB
Python
242 lines
7.8 KiB
Python
"""
|
||
HybridRetriever: 混合检索器
|
||
策略:pgvector 向量搜索 + PostgreSQL ILIKE 关键词搜索 + RRF(Reciprocal Rank Fusion)融合排序
|
||
"""
|
||
import logging
|
||
import uuid
|
||
from typing import TYPE_CHECKING
|
||
|
||
from sqlalchemy import text
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
if TYPE_CHECKING:
|
||
from .embedder import EmbeddingService
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class HybridRetriever:
    """Hybrid retriever: vector search + keyword search, fused with RRF.

    A pgvector cosine-distance query and a PostgreSQL ILIKE substring query
    are run against the same knowledge bases; the two ranked lists are then
    merged with weighted Reciprocal Rank Fusion.
    """

    def __init__(
        self,
        embedder: "EmbeddingService",
        vector_weight: float = 0.7,
        keyword_weight: float = 0.3,
    ):
        """Store the embedder and the per-source RRF weights.

        Args:
            embedder: Object whose awaited ``embed(query)`` result is used as
                the query vector.
            vector_weight: Weight applied to vector-search ranks during fusion.
            keyword_weight: Weight applied to keyword-search ranks during fusion.
        """
        self.embedder = embedder
        self.vector_weight = vector_weight
        self.keyword_weight = keyword_weight

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def search(
        self,
        session: AsyncSession,
        query: str,
        knowledge_base_ids: list[str],
        top_k: int = 10,
    ) -> list[dict]:
        """Run the hybrid search and return the fused top results.

        Steps:
            1. Embed the query.
            2. Vector search (pgvector cosine distance).
            3. Keyword search (ILIKE).
            4. Merge both lists with RRF.

        Args:
            session: Async SQLAlchemy session used for both queries.
            query: User query text.
            knowledge_base_ids: Knowledge base ids (UUID strings) to search in.
            top_k: Maximum number of fused results to return.

        Returns:
            A list of dicts, best match first, each shaped like::

                {
                    "chunk_id": str,
                    "content": str,
                    "score": float,
                    "document_id": str,
                    "document_title": str,
                    "metadata": dict,
                }

            An empty list is returned for an empty/blank query, no (valid)
            knowledge base ids, or a non-positive ``top_k``.
        """
        if not query or not query.strip() or not knowledge_base_ids or top_k <= 0:
            return []

        # Validate ids up front so we do not pay for an embedding call when
        # nothing can match.
        kb_ids = self._valid_uuid_strs(knowledge_base_ids)
        dropped = len(knowledge_base_ids) - len(kb_ids)
        if dropped:
            logger.warning("Ignoring %d malformed knowledge base id(s)", dropped)
        if not kb_ids:
            return []

        # The two searches are awaited one after the other: both use the same
        # DB session, so they are not run concurrently.
        query_embedding = await self.embedder.embed(query)
        # Over-fetch from each source so fusion has candidates to reorder.
        candidate_limit = top_k * 2
        vector_results = await self._vector_search(
            session, query_embedding, kb_ids, candidate_limit
        )
        keyword_results = await self._keyword_search(
            session, query, kb_ids, candidate_limit
        )

        fused = self._rrf_fusion(vector_results, keyword_results)
        return fused[:top_k]

    # ------------------------------------------------------------------
    # Internal: shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _valid_uuid_strs(kb_ids) -> list[str]:
        """Return the canonical string form of every well-formed UUID in *kb_ids*.

        Malformed entries are skipped. Validating here keeps stray ``,`` or
        ``}`` characters out of the PostgreSQL array literal built from the ids.
        """
        valid: list[str] = []
        for raw in kb_ids:
            try:
                valid.append(str(uuid.UUID(str(raw))))
            except ValueError:
                continue
        return valid

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE/ILIKE metacharacters in *term*.

        Backslash is PostgreSQL's default LIKE escape character, so it is
        doubled first; ``%`` and ``_`` are then prefixed with a backslash so
        they match literally instead of acting as wildcards.
        """
        return term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")

    @staticmethod
    def _row_to_result(row, score: float) -> dict:
        """Convert one result row (a mapping) into the public result dict."""
        return {
            "chunk_id": str(row["chunk_id"]),
            "content": row["content"],
            "score": score,
            "document_id": str(row["document_id"]),
            "document_title": row["document_title"],
            "metadata": row["metadata"] or {},
        }

    @staticmethod
    async def _fetch_rows(session: AsyncSession, sql, params: dict, label: str) -> list:
        """Execute *sql* best-effort and return its rows as mappings.

        The statement runs inside a SAVEPOINT (``begin_nested``): if it fails,
        only the savepoint is rolled back, so the enclosing transaction stays
        usable for the next query and for the caller. Any failure is logged
        and turned into an empty list.
        """
        try:
            async with session.begin_nested():
                result = await session.execute(sql, params)
                return result.mappings().all()
        except Exception as e:
            logger.warning(
                "%s search failed: %s; falling back to empty results",
                label,
                e,
                exc_info=True,
            )
            return []

    # ------------------------------------------------------------------
    # Internal: vector search
    # ------------------------------------------------------------------

    async def _vector_search(
        self,
        session: AsyncSession,
        query_embedding: list[float],
        kb_ids: list[str],
        top_k: int,
    ) -> list[dict]:
        """Vector search using the pgvector cosine-distance operator ``<=>``.

        Raw SQL is used so the pgvector operator can be written directly.
        The returned ``score`` is ``1 - distance`` (``0.0`` when NULL).
        Returns an empty list when no id is a valid UUID or the query fails.
        """
        kb_uuid_strs = self._valid_uuid_strs(kb_ids)
        if not kb_uuid_strs:
            return []

        # Serialise the embedding into pgvector's text format: "[v1,v2,...]".
        embedding_str = "[" + ",".join(str(v) for v in query_embedding) + "]"

        sql = text(
            """
            SELECT
                kc.id AS chunk_id,
                kc.content AS content,
                kd.id AS document_id,
                kd.title AS document_title,
                kc.metadata AS metadata,
                1 - (kc.embedding <=> CAST(:query_vec AS vector)) AS score
            FROM knowledge_chunks kc
            JOIN knowledge_documents kd ON kc.document_id = kd.id
            WHERE kd.knowledge_base_id = ANY(CAST(:kb_ids AS uuid[]))
              AND kd.status = 'ready'
              AND kc.embedding IS NOT NULL
            ORDER BY kc.embedding <=> CAST(:query_vec AS vector)
            LIMIT :top_k
            """
        )
        params = {
            "query_vec": embedding_str,
            # PostgreSQL array literal, cast to uuid[] inside the query.
            "kb_ids": "{" + ",".join(kb_uuid_strs) + "}",
            "top_k": top_k,
        }

        rows = await self._fetch_rows(session, sql, params, "Vector")
        return [
            self._row_to_result(
                row, float(row["score"]) if row["score"] is not None else 0.0
            )
            for row in rows
        ]

    # ------------------------------------------------------------------
    # Internal: keyword search
    # ------------------------------------------------------------------

    async def _keyword_search(
        self,
        session: AsyncSession,
        query: str,
        kb_ids: list[str],
        top_k: int,
    ) -> list[dict]:
        """Keyword search (case-insensitive substring match via ILIKE).

        May later be upgraded to tsvector full-text search. Every hit gets a
        flat score of ``1.0``. Returns an empty list when no id is a valid
        UUID or the query fails.
        """
        kb_uuid_strs = self._valid_uuid_strs(kb_ids)
        if not kb_uuid_strs:
            return []

        # Escape wildcards so the user's text is matched literally.
        like_pattern = f"%{self._escape_like(query)}%"

        # NOTE(review): there is no ORDER BY, so the row order (and therefore
        # the RRF rank of keyword hits) is whatever PostgreSQL returns.
        sql = text(
            """
            SELECT
                kc.id AS chunk_id,
                kc.content AS content,
                kd.id AS document_id,
                kd.title AS document_title,
                kc.metadata AS metadata,
                1.0 AS score
            FROM knowledge_chunks kc
            JOIN knowledge_documents kd ON kc.document_id = kd.id
            WHERE kd.knowledge_base_id = ANY(CAST(:kb_ids AS uuid[]))
              AND kd.status = 'ready'
              AND kc.content ILIKE :pattern
            LIMIT :top_k
            """
        )
        params = {
            "kb_ids": "{" + ",".join(kb_uuid_strs) + "}",
            "pattern": like_pattern,
            "top_k": top_k,
        }

        rows = await self._fetch_rows(session, sql, params, "Keyword")
        return [self._row_to_result(row, 1.0) for row in rows]

    # ------------------------------------------------------------------
    # Internal: RRF fusion
    # ------------------------------------------------------------------

    def _rrf_fusion(
        self,
        vector_results: list[dict],
        keyword_results: list[dict],
        k: int = 60,
    ) -> list[dict]:
        """Merge two ranked lists with weighted Reciprocal Rank Fusion (RRF).

        For each chunk: ``score = sum(weight_i / (k + rank_i))`` over the lists
        it appears in, with ranks starting at 1.

        Args:
            vector_results: Vector-search hits, best first.
            keyword_results: Keyword-search hits, in rank order.
            k: RRF smoothing constant added to every rank.

        Returns:
            New result dicts sorted by fused score (descending); ``score`` is
            replaced by the fused value rounded to 6 decimals. When a chunk is
            in both lists, the remaining fields come from the vector hit.
            Input dicts are not mutated.
        """
        scores: dict[str, float] = {}
        chunk_data: dict[str, dict] = {}

        weighted_lists = (
            (vector_results, self.vector_weight),
            (keyword_results, self.keyword_weight),
        )
        for results, weight in weighted_lists:
            for rank, item in enumerate(results, start=1):
                cid = item["chunk_id"]
                scores[cid] = scores.get(cid, 0.0) + weight / (k + rank)
                # Keep the first payload seen for each chunk.
                chunk_data.setdefault(cid, item)

        # Highest fused score first; sorted() is stable, so ties keep
        # first-seen order.
        ranked_ids = sorted(scores, key=scores.get, reverse=True)
        return [
            {**chunk_data[cid], "score": round(scores[cid], 6)} for cid in ranked_ids
        ]
|