geo/backend/app/services/knowledge/enhanced_rag.py

149 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
增强版RAG服务 - 包含重排序和上下文压缩
"""
import logging
import re
from typing import Optional

from app.services.llm.factory import LLMFactory
class EnhancedRAG:
    """Retrieval wrapper that adds LLM re-ranking and context compression.

    The wrapped ``rag_service`` performs the initial search.  This class can
    then re-score the hits with an LLM and, optionally, compress the chunk
    that overflows the context budget.
    """

    # Over-fetch factor applied to ``top_k`` when re-ranking is enabled, so
    # the re-ranker has a wider candidate pool to choose from.
    RERANK_CANDIDATE_MULTIPLIER = 4
    # Maximum characters of a chunk shown to the LLM when scoring, and kept
    # from the original text when LLM compression fails.
    SNIPPET_MAX_CHARS = 500
    # Relevance used when the LLM reply contains no number at all.
    DEFAULT_RELEVANCE = 0.5
    # Rough heuristic used to estimate token usage from the character count.
    CHARS_PER_TOKEN = 2

    # First integer or decimal number in a reply ("1", "0.85", ".5", "1.0").
    _SCORE_PATTERN = re.compile(r"\d+(?:\.\d+)?|\.\d+")
    _log = logging.getLogger(__name__)

    def __init__(self, rag_service, embedder):
        """Store the collaborators and create the LLM client.

        Args:
            rag_service: Object exposing an awaitable
                ``search(session, query, kb_ids, top_k=...)``.
            embedder: Stored on the instance; not used by the methods of
                this class.
        """
        self.rag = rag_service
        self.embedder = embedder
        self.llm = LLMFactory.create()

    async def retrieve_with_rerank(
        self,
        session,
        query: str,
        kb_ids: list[str],
        top_k: int = 5,
        use_rerank: bool = True,
        use_compression: bool = False,
    ) -> list[dict]:
        """Run the enhanced retrieval pipeline.

        Steps:
            1. Initial search (over-fetching when re-ranking is enabled).
            2. Optional LLM re-ranking.
            3. Optional context compression.

        Args:
            session: Passed through unchanged to ``rag_service.search``.
            query: The user query.
            kb_ids: Knowledge-base identifiers to search in.
            top_k: Maximum number of results to return.
            use_rerank: Re-score the candidates with the LLM.
            use_compression: Apply the context budget / compression step.

        Returns:
            At most ``top_k`` candidate dicts; an empty list when the search
            yields nothing.  Re-ranked candidates carry ``relevance_score``;
            compressed runs add ``compressed`` (and ``compressed_content``
            on the chunk that was compressed).
        """
        # Step 1: initial retrieval.
        if use_rerank:
            initial_k = top_k * self.RERANK_CANDIDATE_MULTIPLIER
        else:
            initial_k = top_k
        candidates = await self.rag.search(
            session, query, kb_ids, top_k=initial_k
        )
        if not candidates:
            return []

        # Step 2: re-ranking only pays off when there is something to drop.
        if use_rerank and len(candidates) > top_k:
            candidates = await self._rerank(query, candidates, top_k)

        # Step 3: optional context compression.
        if use_compression:
            candidates = await self._compress(candidates, query)

        return candidates[:top_k]

    @classmethod
    def _parse_relevance(cls, response: str) -> float:
        """Extract a relevance score in ``[0, 1]`` from an LLM reply.

        The first number found in ``response`` is used and clamped to the
        valid range; ``DEFAULT_RELEVANCE`` is returned when there is none.
        """
        match = cls._SCORE_PATTERN.search(response)
        if match is None:
            return cls.DEFAULT_RELEVANCE
        return min(1.0, max(0.0, float(match.group())))

    @classmethod
    def _fallback_relevance(cls, item: dict) -> float:
        """Return the candidate's own ``score`` as a float, or the default."""
        try:
            return float(item.get("score", cls.DEFAULT_RELEVANCE))
        except (TypeError, ValueError):
            # Missing/None/non-numeric scores must not break the sort.
            return cls.DEFAULT_RELEVANCE

    @classmethod
    def _truncate(cls, text: str) -> str:
        """Cut ``text`` to ``SNIPPET_MAX_CHARS``, marking the cut with '...'."""
        if len(text) <= cls.SNIPPET_MAX_CHARS:
            return text
        return text[: cls.SNIPPET_MAX_CHARS] + "..."

    async def _rerank(
        self,
        query: str,
        candidates: list[dict],
        top_k: int,
    ) -> list[dict]:
        """Re-rank candidates by LLM-judged relevance to the query.

        Each candidate is scored with one LLM call (issued sequentially) and
        the score is stored on the candidate dict under ``relevance_score``.
        If the call or the parsing fails, the candidate's own ``score`` is
        used instead.

        Args:
            query: The user query.
            candidates: Candidate dicts; ``content`` is read, and
                ``relevance_score`` is written in place.
            top_k: Number of best candidates to keep.

        Returns:
            The ``top_k`` candidates with the highest relevance, best first.
            Ties keep their original relative order.
        """
        for item in candidates:
            # Only a bounded snippet is sent to keep the prompt small.
            snippet = (item.get("content") or "")[: self.SNIPPET_MAX_CHARS]
            prompt = (
                "评估以下查询与文档片段的相关性。\n"
                f"查询:{query}\n"
                "文档片段:\n"
                f"{snippet}\n"
                "请只返回一个0到1之间的小数表示相关性分数。"
                "0表示完全不相关1表示完全相关。只返回数字"
            )
            try:
                response = await self.llm.generate(prompt)
                relevance = self._parse_relevance(response)
            except Exception:
                # Best effort: a failing LLM call must not fail retrieval.
                self._log.warning(
                    "LLM relevance scoring failed; using retrieval score",
                    exc_info=True,
                )
                relevance = self._fallback_relevance(item)
            item["relevance_score"] = relevance

        # sorted() is stable, so equal scores keep the retrieval order.
        ranked = sorted(
            candidates, key=lambda c: c["relevance_score"], reverse=True
        )
        return ranked[:top_k]

    async def _compress(
        self,
        candidates: list[dict],
        query: str,
        max_context_tokens: int = 2000,
    ) -> list[dict]:
        """Fit candidates into a token budget, compressing the overflow chunk.

        Candidates are taken in order while their estimated token count fits
        into ``max_context_tokens``.  The first candidate that does not fit
        is compressed with the LLM and kept, and iteration stops there, so
        any later candidates are dropped.

        Args:
            candidates: Candidate dicts; ``content`` is read, ``compressed``
                (and possibly ``compressed_content``) is written in place.
            query: The user query, used to guide compression.
            max_context_tokens: Estimated token budget for whole chunks.

        Returns:
            The kept candidates in their original order.  A compressed chunk
            keeps its original ``content`` and carries the shortened text in
            ``compressed_content``.
        """
        kept = []
        used_tokens = 0
        for chunk in candidates:
            content = chunk.get("content") or ""
            # Cheap estimate; no tokenizer is involved.
            est_tokens = len(content) // self.CHARS_PER_TOKEN

            if used_tokens + est_tokens > max_context_tokens:
                # Budget exhausted: compress this chunk and stop.
                chunk["compressed"] = True
                chunk["compressed_content"] = await self._compress_chunk(
                    content, query
                )
                kept.append(chunk)
                break

            chunk["compressed"] = False
            kept.append(chunk)
            used_tokens += est_tokens
        return kept

    async def _compress_chunk(
        self,
        content: str,
        query: str,
    ) -> str:
        """Ask the LLM to extract the query-relevant part of one chunk.

        Args:
            content: Full text of the chunk.
            query: The user query.

        Returns:
            The stripped LLM reply.  If the LLM call fails, the original
            text cut to ``SNIPPET_MAX_CHARS`` characters, with ``...``
            appended only when something was actually cut off.
        """
        prompt = (
            "从以下文本中提取与问题最相关的部分,保持原文的表达方式,不要总结或改写。\n"
            f"问题:{query}\n"
            "原文:\n"
            f"{content}\n"
            "直接返回提取的内容(不要解释):"
        )
        try:
            result = await self.llm.generate(prompt)
            return result.strip()
        except Exception:
            # Best effort: fall back to a truncated copy of the original.
            self._log.warning(
                "LLM chunk compression failed; returning truncated original",
                exc_info=True,
            )
            return self._truncate(content)