"""Enhanced RAG service with re-ranking and context compression."""

import logging
import re
from typing import Optional

from app.services.llm.factory import LLMFactory


class EnhancedRAG:
    """Retrieval service that adds LLM re-ranking and context compression.

    Attributes set by ``__init__``:
        rag: Base retrieval service; this class awaits
            ``rag.search(session, query, kb_ids, top_k=...)``.
        embedder: Stored as given; not used by any method of this class.
        llm: Client returned by ``LLMFactory.create()``; this class awaits
            ``llm.generate(prompt)`` and expects a string back.
    """

    _logger = logging.getLogger(__name__)

    # When re-ranking, fetch this many times ``top_k`` so the re-ranker has
    # a wider candidate pool to choose from.
    _RERANK_OVERSAMPLE = 4
    # Only the first N characters of a chunk are shown to the re-ranker.
    _RERANK_SNIPPET_CHARS = 500
    # Length of the truncated excerpt returned when LLM compression fails.
    _FALLBACK_SNIPPET_CHARS = 500
    # Score used when the LLM reply contains no usable number.
    _DEFAULT_RELEVANCE = 0.5

    # A number with a decimal point is preferred ("0.85", ".85", "1.0");
    # a bare integer ("1", "0") is only used when no decimal is present.
    _DECIMAL_PATTERN = re.compile(r"\d*\.\d+")
    _INTEGER_PATTERN = re.compile(r"\d+")

    def __init__(self, rag_service, embedder):
        """Store the collaborators and create the LLM client."""
        self.rag = rag_service
        self.embedder = embedder
        self.llm = LLMFactory.create()

    async def retrieve_with_rerank(
        self,
        session,
        query: str,
        kb_ids: list[str],
        top_k: int = 5,
        use_rerank: bool = True,
        use_compression: bool = False,
    ) -> list[dict]:
        """Run the enhanced retrieval pipeline.

        Steps:
            1. Initial search (with an enlarged candidate pool when
               re-ranking is enabled).
            2. Optional LLM re-ranking.
            3. Optional context compression.

        Args:
            session: Passed through unchanged to ``rag.search``.
            query: The user query.
            kb_ids: Knowledge-base identifiers passed to ``rag.search``.
            top_k: Maximum number of results to return.
            use_rerank: Re-rank candidates with the LLM when more than
                ``top_k`` candidates were retrieved.
            use_compression: Apply ``_compress`` to the final candidates.

        Returns:
            At most ``top_k`` candidate dicts; an empty list when ``top_k``
            is not positive or the search returns nothing.
        """
        if top_k <= 0:
            return []

        # Step 1: initial retrieval.
        initial_k = top_k * self._RERANK_OVERSAMPLE if use_rerank else top_k
        candidates = await self.rag.search(
            session, query, kb_ids, top_k=initial_k
        )
        if not candidates:
            return []

        # Step 2: optional re-ranking (pointless when nothing will be cut).
        if use_rerank and len(candidates) > top_k:
            candidates = await self._rerank(query, candidates, top_k)

        # Trim before compressing so no LLM call is spent on a chunk that
        # would be discarded anyway; the returned list is the same as
        # compressing first and slicing afterwards.
        candidates = candidates[:top_k]

        # Step 3: optional context compression.
        if use_compression:
            candidates = await self._compress(candidates, query)

        return candidates

    async def _rerank(
        self,
        query: str,
        candidates: list[dict],
        top_k: int,
    ) -> list[dict]:
        """Score every candidate against the query with the LLM and sort.

        Each candidate dict is mutated in place: a ``"relevance_score"`` key
        is added. One LLM call is made per candidate, sequentially.

        Args:
            query: The user query.
            candidates: Candidate dicts; ``"content"`` is read, and
                ``"score"`` is used as a fallback when the LLM call fails.
            top_k: Accepted for interface compatibility; the full sorted
                list is returned and the caller does the truncation.

        Returns:
            All candidates, sorted by ``"relevance_score"`` descending
            (ties keep their original relative order).
        """
        reranked = []

        for item in candidates:
            # ``or ""`` also covers a "content" key that is present but None.
            snippet = (item.get("content") or "")[: self._RERANK_SNIPPET_CHARS]
            prompt = (
                "评估以下查询与文档片段的相关性。\n"
                "\n"
                f"查询:{query}\n"
                "\n"
                "文档片段:\n"
                f"{snippet}\n"
                "\n"
                "请只返回一个0到1之间的小数,表示相关性分数。"
                "0表示完全不相关,1表示完全相关。只返回数字:"
            )

            try:
                response = await self.llm.generate(prompt)
            except Exception:
                # Best effort: a failed LLM call must not abort retrieval,
                # so fall back to the retriever's own score if it is usable.
                self._logger.warning(
                    "LLM re-ranking call failed; using retrieval score",
                    exc_info=True,
                )
                fallback = item.get("score")
                if isinstance(fallback, (int, float)):
                    relevance = fallback
                else:
                    relevance = self._DEFAULT_RELEVANCE
            else:
                parsed = self._parse_relevance(response)
                relevance = (
                    self._DEFAULT_RELEVANCE if parsed is None else parsed
                )

            item["relevance_score"] = relevance
            reranked.append(item)

        # Highest relevance first.
        reranked.sort(key=lambda x: x["relevance_score"], reverse=True)
        return reranked

    @classmethod
    def _parse_relevance(cls, response) -> Optional[float]:
        """Extract a relevance score in [0, 1] from an LLM reply.

        Returns ``None`` when the reply is not a string or holds no number.
        """
        if not isinstance(response, str):
            return None
        match = cls._DECIMAL_PATTERN.search(
            response
        ) or cls._INTEGER_PATTERN.search(response)
        if match is None:
            return None
        # Clamp: the model is asked for 0..1 but may answer out of range.
        return min(1.0, max(0.0, float(match.group())))

    async def _compress(
        self,
        candidates: list[dict],
        query: str,
        max_context_tokens: int = 2000,
    ) -> list[dict]:
        """Fit candidates into a token budget, compressing the overflow chunk.

        Chunks are taken in order while the running token estimate stays
        within ``max_context_tokens``. The first chunk that would exceed the
        budget is compressed with the LLM and kept; every chunk after it is
        dropped. Chunk dicts are mutated in place: ``"compressed"`` is set on
        each kept chunk and ``"compressed_content"`` on the compressed one
        (its ``"content"`` is left untouched).

        Args:
            candidates: Candidate dicts; ``"content"`` is read.
            query: The user query, used to guide compression.
            max_context_tokens: Budget for the estimated token total.

        Returns:
            The kept chunks, in their original order.
        """
        compressed = []
        current_tokens = 0

        for chunk in candidates:
            content = chunk.get("content") or ""

            # Rough estimate: about two characters per token.
            est_tokens = len(content) // 2

            if current_tokens + est_tokens <= max_context_tokens:
                chunk["compressed"] = False
                compressed.append(chunk)
                current_tokens += est_tokens
                continue

            # Budget exceeded: compress this chunk and stop.
            chunk["compressed"] = True
            chunk["compressed_content"] = await self._compress_chunk(
                content, query
            )
            compressed.append(chunk)
            break

        return compressed

    async def _compress_chunk(
        self,
        content: str,
        query: str,
    ) -> str:
        """Extract the query-relevant part of a single chunk with the LLM.

        Args:
            content: The chunk text.
            query: The user query.

        Returns:
            The stripped LLM reply. If the LLM call fails, the original text
            when it is short enough, otherwise its first characters followed
            by ``"..."``.
        """
        prompt = (
            "从以下文本中提取与问题最相关的部分,保持原文的表达方式,不要总结或改写。\n"
            "\n"
            f"问题:{query}\n"
            "\n"
            "原文:\n"
            f"{content}\n"
            "\n"
            "直接返回提取的内容(不要解释):"
        )

        try:
            result = await self.llm.generate(prompt)
            return result.strip()
        except Exception:
            # Best effort: fall back to a plain excerpt of the original.
            self._logger.warning(
                "LLM chunk compression failed; returning an excerpt",
                exc_info=True,
            )
            limit = self._FALLBACK_SNIPPET_CHARS
            if len(content) <= limit:
                return content
            # Only mark truncation when something was actually cut off.
            return content[:limit] + "..."