geo/backend/app/services/knowledge/retriever.py

242 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
HybridRetriever: 混合检索器
策略pgvector 向量搜索 + PostgreSQL ILIKE 关键词搜索 + RRFReciprocal Rank Fusion融合排序
"""
import logging
import uuid
from typing import TYPE_CHECKING
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
if TYPE_CHECKING:
from .embedder import EmbeddingService
logger = logging.getLogger(__name__)
class HybridRetriever:
    """Hybrid retriever: pgvector similarity search + ILIKE keyword search, fused with RRF.

    Two candidate lists are fetched from PostgreSQL (one ordered by cosine
    distance, one matched by case-insensitive substring) and merged with
    weighted Reciprocal Rank Fusion.
    """

    def __init__(
        self,
        embedder: "EmbeddingService",
        vector_weight: float = 0.7,
        keyword_weight: float = 0.3,
    ):
        """Store the embedding service and the per-list RRF weights.

        Args:
            embedder: Object whose awaitable ``embed(query)`` returns the query
                embedding.
            vector_weight: RRF weight applied to the vector-search ranking.
            keyword_weight: RRF weight applied to the keyword-search ranking.
        """
        self.embedder = embedder
        self.vector_weight = vector_weight
        self.keyword_weight = keyword_weight

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    async def search(
        self,
        session: "AsyncSession",
        query: str,
        knowledge_base_ids: list[str],
        top_k: int = 10,
    ) -> list[dict]:
        """Run the hybrid search and return the ``top_k`` fused hits.

        Steps:
            1. Embed the query.
            2. Vector search (pgvector cosine distance).
            3. Keyword search (ILIKE).
            4. Fuse both rankings with RRF.

        Args:
            session: Async SQLAlchemy session used for both queries.
            query: Free-text user query.
            knowledge_base_ids: Knowledge bases to search in.
            top_k: Maximum number of fused hits to return.

        Returns:
            A list of dicts, best first::

                {
                    "chunk_id": str,
                    "content": str,
                    "score": float,   # fused RRF score
                    "document_id": str,
                    "document_title": str,
                    "metadata": dict,
                }

            An empty list is returned for a blank query, an empty
            ``knowledge_base_ids`` or a non-positive ``top_k``.
        """
        # A whitespace-only query would embed noise and ILIKE-match nearly
        # everything; a non-positive top_k would produce an invalid LIMIT.
        if not query or not query.strip() or not knowledge_base_ids or top_k <= 0:
            return []

        query_embedding = await self.embedder.embed(query)

        # The two lookups are awaited one after the other: they share a single
        # DB session, which cannot run statements concurrently.
        # Each list is over-fetched (2x) so fusion has candidates to reorder.
        candidate_limit = top_k * 2
        vector_results = await self._vector_search(
            session, query_embedding, knowledge_base_ids, candidate_limit
        )
        keyword_results = await self._keyword_search(
            session, query, knowledge_base_ids, candidate_limit
        )

        fused = self._rrf_fusion(vector_results, keyword_results)
        return fused[:top_k]

    # ------------------------------------------------------------------
    # Internal: shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _kb_array_literal(kb_ids: list[str]) -> str:
        """Render ``kb_ids`` as a PostgreSQL array literal (``{a,b,c}``).

        The literal is passed as a bound parameter and cast to ``uuid[]`` in
        SQL, so malformed ids are rejected by the database cast.
        """
        return "{" + ",".join(str(uid) for uid in kb_ids) + "}"

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE/ILIKE metacharacters so ``term`` matches literally.

        Backslash is PostgreSQL's default LIKE escape character, so it must be
        doubled first, before ``%`` and ``_`` are prefixed with it.
        """
        return (
            term.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )

    @staticmethod
    def _to_hit(row, score: float) -> dict:
        """Convert one result-row mapping into the public hit dict."""
        return {
            "chunk_id": str(row["chunk_id"]),
            "content": row["content"],
            "score": score,
            "document_id": str(row["document_id"]),
            "document_title": row["document_title"],
            "metadata": row["metadata"] or {},
        }

    @staticmethod
    async def _run_query(session: "AsyncSession", sql, params: dict, label: str) -> list:
        """Execute ``sql`` best-effort and return its rows as mappings.

        The statement runs inside a SAVEPOINT. On PostgreSQL a failed
        statement aborts the enclosing transaction, which would make every
        later statement on the session fail as well; rolling back to the
        savepoint keeps the session usable, so the "fall back to empty
        results" behaviour actually works.

        Args:
            session: Async SQLAlchemy session.
            sql: Statement to execute.
            params: Bound parameters for ``sql``.
            label: Human-readable name used in the warning log line.

        Returns:
            The row mappings, or ``[]`` if anything went wrong.
        """
        try:
            async with session.begin_nested():
                result = await session.execute(sql, params)
                return result.mappings().all()
        except Exception as exc:  # deliberate best-effort: log and degrade
            logger.warning(
                "%s failed: %s; falling back to empty results", label, exc
            )
            return []

    # ------------------------------------------------------------------
    # Internal: vector search
    # ------------------------------------------------------------------
    async def _vector_search(
        self,
        session: "AsyncSession",
        query_embedding: list[float],
        kb_ids: list[str],
        top_k: int,
    ) -> list[dict]:
        """pgvector similarity search using the cosine-distance operator ``<=>``.

        Raw SQL is used so the pgvector operator can be expressed directly.
        Each hit's ``score`` is ``1 - cosine_distance``. Returns ``[]`` if the
        query fails.
        """
        # pgvector accepts a vector as the text form "[v1,v2,...]".
        embedding_str = "[" + ",".join(str(v) for v in query_embedding) + "]"

        sql = text(
            """
            SELECT
                kc.id AS chunk_id,
                kc.content AS content,
                kd.id AS document_id,
                kd.title AS document_title,
                kc.metadata AS metadata,
                1 - (kc.embedding <=> CAST(:query_vec AS vector)) AS score
            FROM knowledge_chunks kc
            JOIN knowledge_documents kd ON kc.document_id = kd.id
            WHERE kd.knowledge_base_id = ANY(CAST(:kb_ids AS uuid[]))
              AND kd.status = 'ready'
              AND kc.embedding IS NOT NULL
            ORDER BY kc.embedding <=> CAST(:query_vec AS vector)
            LIMIT :top_k
            """
        )
        rows = await self._run_query(
            session,
            sql,
            {
                "query_vec": embedding_str,
                "kb_ids": self._kb_array_literal(kb_ids),
                "top_k": top_k,
            },
            "Vector search",
        )
        return [
            self._to_hit(
                row,
                float(row["score"]) if row["score"] is not None else 0.0,
            )
            for row in rows
        ]

    # ------------------------------------------------------------------
    # Internal: keyword search
    # ------------------------------------------------------------------
    async def _keyword_search(
        self,
        session: "AsyncSession",
        query: str,
        kb_ids: list[str],
        top_k: int,
    ) -> list[dict]:
        """Keyword search via case-insensitive substring match (ILIKE).

        The query is matched literally: ``%``, ``_`` and ``\\`` typed by the
        user are escaped and do not act as wildcards. Every hit gets a
        constant ``score`` of 1.0. Returns ``[]`` if the query fails.

        May later be upgraded to ts_vector full-text search.
        """
        like_pattern = f"%{self._escape_like(query)}%"

        # NOTE(review): there is no ORDER BY, so which rows survive LIMIT and
        # the order they arrive in (and therefore their RRF rank) is whatever
        # the planner produces. Confirm whether a relevance ordering is wanted.
        sql = text(
            """
            SELECT
                kc.id AS chunk_id,
                kc.content AS content,
                kd.id AS document_id,
                kd.title AS document_title,
                kc.metadata AS metadata,
                1.0 AS score
            FROM knowledge_chunks kc
            JOIN knowledge_documents kd ON kc.document_id = kd.id
            WHERE kd.knowledge_base_id = ANY(CAST(:kb_ids AS uuid[]))
              AND kd.status = 'ready'
              AND kc.content ILIKE :pattern
            LIMIT :top_k
            """
        )
        rows = await self._run_query(
            session,
            sql,
            {
                "kb_ids": self._kb_array_literal(kb_ids),
                "pattern": like_pattern,
                "top_k": top_k,
            },
            "Keyword search",
        )
        return [self._to_hit(row, 1.0) for row in rows]

    # ------------------------------------------------------------------
    # Internal: RRF fusion
    # ------------------------------------------------------------------
    def _rrf_fusion(
        self,
        vector_results: list[dict],
        keyword_results: list[dict],
        k: int = 60,
    ) -> list[dict]:
        """Merge two ranked lists with weighted Reciprocal Rank Fusion.

        ``score(chunk) = sum(weight_i / (k + rank_i))`` over every list the
        chunk appears in, with ranks starting at 1.

        Args:
            vector_results: Hits ordered best-first by vector similarity.
            keyword_results: Hits from the keyword search, in result order.
            k: RRF smoothing constant.

        Returns:
            New hit dicts (inputs are not mutated) sorted by fused score,
            descending; ``score`` is the fused value rounded to 6 decimals.
            Ties keep first-seen order, vector hits first.
        """
        scores: dict[str, float] = {}
        first_seen: dict[str, dict] = {}

        ranked_lists = (
            (vector_results, self.vector_weight),
            (keyword_results, self.keyword_weight),
        )
        for results, weight in ranked_lists:
            for rank, item in enumerate(results, start=1):
                cid = item["chunk_id"]
                scores[cid] = scores.get(cid, 0.0) + weight / (k + rank)
                # Payload comes from the first list that returned the chunk.
                first_seen.setdefault(cid, item)

        # sorted() is stable, so equal scores keep insertion order.
        ordered_ids = sorted(scores, key=scores.get, reverse=True)
        return [
            {**first_seen[cid], "score": round(scores[cid], 6)}
            for cid in ordered_ids
        ]