geo/backend/app/services/knowledge/rag_service.py

213 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RAGService: unified entry point for the RAG service.
Handles document ingestion (chunking → embedding → storage) and knowledge retrieval.
"""
import hashlib
import logging
import uuid
from typing import Optional
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.knowledge import KnowledgeChunk, KnowledgeDocument
from .chunker import RecursiveChunker
from .embedder import EmbeddingService, MockEmbedder
from .retriever import HybridRetriever
logger = logging.getLogger(__name__)
class RAGService:
    """Unified entry point for the RAG pipeline.

    Wires together a chunker, an embedder and a hybrid retriever, and exposes
    document ingestion, chunk deletion and knowledge search on top of them.
    """

    def __init__(self, embedder: Optional[EmbeddingService] = None):
        """Build the service and pick an embedding backend.

        Args:
            embedder: Embedding backend to use. When omitted, an
                ``OpenAIEmbedder`` is created if ``settings.OPENAI_API_KEY``
                is set; otherwise a ``MockEmbedder`` is used and a warning is
                logged.
        """
        self.chunker = RecursiveChunker()
        if embedder is not None:
            self.embedder = embedder
        else:
            # Settings are only consulted when no embedder was injected.
            from app.config import settings

            if settings.OPENAI_API_KEY:
                from app.services.knowledge.embedder import OpenAIEmbedder

                self.embedder = OpenAIEmbedder(api_key=settings.OPENAI_API_KEY)
            else:
                logger.warning("未配置 OPENAI_API_KEY知识库将使用 MockEmbedder仅适用于开发环境")
                self.embedder = MockEmbedder()
        self.retriever = HybridRetriever(self.embedder)

    # ------------------------------------------------------------------
    # Ingest
    # ------------------------------------------------------------------
    async def ingest_document(self, session: AsyncSession, document_id: str) -> int:
        """Chunk, embed and store a single document.

        Steps:
            1. Load the document row (its current status is not validated).
            2. Delete its existing chunks so that re-ingestion is idempotent.
            3. Split the content into chunks.
            4. Embed all chunk texts in one batch.
            5. Insert the ``KnowledgeChunk`` rows.
            6. Set ``chunk_count`` and ``status = 'ready'`` on the document.

        If anything fails after the document has been loaded, the transaction
        is rolled back, the document is marked ``failed`` with the error text
        (best effort), and the original exception is re-raised.

        Args:
            session: Async SQLAlchemy session; this method commits on it.
            document_id: Identifier of the document, parseable as a UUID.

        Returns:
            The number of chunks stored for the document.

        Raises:
            ValueError: If ``document_id`` is not a valid UUID or no such
                document exists.
            RuntimeError: If the embedder returns a different number of
                vectors than there are chunks.
        """
        doc_uuid = uuid.UUID(str(document_id))

        # 1. Load the document.
        stmt = select(KnowledgeDocument).where(KnowledgeDocument.id == doc_uuid)
        result = await session.execute(stmt)
        document: Optional[KnowledgeDocument] = result.scalar_one_or_none()
        if document is None:
            raise ValueError(f"Document {document_id} not found")

        try:
            # 2. Drop old chunks (same transaction, so a failure restores them).
            await self._delete_chunks_by_document_id(session, doc_uuid)

            # 3. Chunk the content.
            chunk_dicts = self.chunker.chunk(
                document.content,
                metadata={"document_id": str(document_id), "title": document.title},
            )
            if not chunk_dicts:
                # Nothing to embed: mark the document ready with zero chunks.
                await self._update_document_status(
                    session, doc_uuid, "ready", chunk_count=0
                )
                await session.commit()
                return 0

            # 4. Embed every chunk text in a single batch call.
            texts = [c["content"] for c in chunk_dicts]
            embeddings = await self.embedder.embed_batch(texts)

            # 5. Build and stage the chunk rows. Pairing is length-checked so
            #    a short embedder response cannot silently drop chunks.
            chunks_to_insert = [
                KnowledgeChunk(
                    id=uuid.uuid4(),
                    document_id=doc_uuid,
                    content=chunk_dict["content"],
                    embedding=embedding,
                    chunk_index=chunk_dict["chunk_index"],
                    token_count=chunk_dict["token_count"],
                    extra_metadata=chunk_dict.get("metadata", {}),
                )
                for chunk_dict, embedding in self._pair_chunks_with_embeddings(
                    chunk_dicts, embeddings
                )
            ]
            session.add_all(chunks_to_insert)

            # 6. Record the result on the document and commit everything.
            await self._update_document_status(
                session, doc_uuid, "ready", chunk_count=len(chunks_to_insert)
            )
            await session.commit()
            logger.info(
                "Ingested document %s: %d chunks", document_id, len(chunks_to_insert)
            )
            return len(chunks_to_insert)
        except Exception as exc:
            await session.rollback()
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Failed to ingest document %s: %s", document_id, exc)
            # Best effort: record the failure on the document row.
            try:
                await self._update_document_status(
                    session,
                    doc_uuid,
                    "failed",
                    error_message=str(exc),
                )
                await session.commit()
            except Exception as inner_exc:
                logger.exception("Failed to update document status: %s", inner_exc)
            # Re-raise the original ingestion error for the caller.
            raise

    @staticmethod
    def _pair_chunks_with_embeddings(
        chunk_dicts: list[dict], embeddings: list
    ) -> list[tuple]:
        """Pair each chunk dict with its embedding, enforcing equal lengths.

        Args:
            chunk_dicts: Chunk dictionaries in chunk order.
            embeddings: One embedding per chunk, in the same order.

        Returns:
            A list of ``(chunk_dict, embedding)`` tuples.

        Raises:
            RuntimeError: If the two sequences differ in length.
        """
        vectors = list(embeddings)
        if len(vectors) != len(chunk_dicts):
            raise RuntimeError(
                f"Embedder returned {len(vectors)} embeddings "
                f"for {len(chunk_dicts)} chunks"
            )
        return list(zip(chunk_dicts, vectors))

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------
    async def search(
        self,
        session: AsyncSession,
        query: str,
        knowledge_base_ids: list[str],
        top_k: int = 10,
    ) -> list[dict]:
        """Run a knowledge search by delegating to the ``HybridRetriever``.

        Args:
            session: Async SQLAlchemy session passed through to the retriever.
            query: Query text.
            knowledge_base_ids: Knowledge bases to search in.
            top_k: Maximum number of results requested from the retriever.

        Returns:
            Whatever the retriever returns; per its contract, a list of dicts
            shaped like ``{chunk_id, content, score, document_id,
            document_title, metadata}``.
        """
        return await self.retriever.search(session, query, knowledge_base_ids, top_k)

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------
    async def delete_document_chunks(
        self, session: AsyncSession, document_id: str
    ) -> int:
        """Delete every chunk of a document and commit.

        Args:
            session: Async SQLAlchemy session; this method commits on it.
            document_id: Identifier of the document, parseable as a UUID.

        Returns:
            The number of deleted chunks as reported by the database.

        Raises:
            ValueError: If ``document_id`` is not a valid UUID.
        """
        doc_uuid = uuid.UUID(str(document_id))
        deleted = await self._delete_chunks_by_document_id(session, doc_uuid)
        await session.commit()
        return deleted

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------
    @staticmethod
    def compute_content_hash(content: str) -> str:
        """Return the hex SHA-256 digest of ``content`` (UTF-8), for dedup."""
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    async def _delete_chunks_by_document_id(
        self, session: AsyncSession, doc_uuid: uuid.UUID
    ) -> int:
        """Delete all chunks of one document without committing.

        Returns:
            The ``rowcount`` reported for the DELETE statement.
        """
        stmt = delete(KnowledgeChunk).where(KnowledgeChunk.document_id == doc_uuid)
        result = await session.execute(stmt)
        return result.rowcount  # type: ignore[return-value]

    async def _update_document_status(
        self,
        session: AsyncSession,
        doc_uuid: uuid.UUID,
        status: str,
        chunk_count: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> None:
        """Update status fields on a document row without committing.

        Args:
            session: Async SQLAlchemy session used to run the UPDATE.
            doc_uuid: Primary key of the document to update.
            status: New value for the ``status`` column.
            chunk_count: New ``chunk_count``; the column is left untouched
                when this is ``None``.
            error_message: New ``error_message``; the column is left
                untouched when this is ``None``.
        """
        values: dict = {"status": status}
        if chunk_count is not None:
            values["chunk_count"] = chunk_count
        if error_message is not None:
            values["error_message"] = error_message
        stmt = (
            update(KnowledgeDocument)
            .where(KnowledgeDocument.id == doc_uuid)
            .values(**values)
        )
        await session.execute(stmt)