"""
RAGService: unified RAG service entry point.

Handles document ingestion (chunking -> embedding -> storage) and knowledge
retrieval.
"""

import hashlib
import logging
import uuid
from typing import Optional

from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.knowledge import KnowledgeChunk, KnowledgeDocument

from .chunker import RecursiveChunker
from .embedder import EmbeddingService, MockEmbedder
from .retriever import HybridRetriever

logger = logging.getLogger(__name__)


class RAGService:
    """Unified RAG service entry point.

    Combines a ``RecursiveChunker``, an embedder and a ``HybridRetriever``.
    All database work goes through the ``AsyncSession`` passed to each method.
    """

    def __init__(self, embedder: Optional[EmbeddingService] = None):
        """Create the service.

        Args:
            embedder: Embedding backend to use. When omitted, an
                ``OpenAIEmbedder`` is created if ``settings.OPENAI_API_KEY``
                is set; otherwise a ``MockEmbedder`` is used and a warning is
                logged.
        """
        self.chunker = RecursiveChunker()
        if embedder is not None:
            self.embedder = embedder
        else:
            # Deferred import: settings are only consulted when no embedder
            # was injected by the caller.
            from app.config import settings

            if settings.OPENAI_API_KEY:
                from app.services.knowledge.embedder import OpenAIEmbedder

                self.embedder = OpenAIEmbedder(api_key=settings.OPENAI_API_KEY)
            else:
                logger.warning("未配置 OPENAI_API_KEY,知识库将使用 MockEmbedder(仅适用于开发环境)")
                self.embedder = MockEmbedder()
        self.retriever = HybridRetriever(self.embedder)

    # ------------------------------------------------------------------
    # Ingest
    # ------------------------------------------------------------------

    async def ingest_document(self, session: AsyncSession, document_id: str) -> int:
        """Process a document: chunk -> embed -> store.

        Steps:
            1. Load the document from the DB (its current status is not
               checked here).
            2. Delete any existing chunks so re-processing does not
               duplicate rows.
            3. Split the document text into chunks.
            4. Embed all chunks in one batch.
            5. Insert the rows into ``knowledge_chunks``.
            6. Set ``chunk_count`` and ``status = 'ready'`` on the document.

        If anything fails after the document has been loaded, the
        transaction is rolled back. The document is then marked ``'failed'``
        with the error message (best effort), and the original exception is
        re-raised.

        Args:
            session: Async DB session. This method commits or rolls it back.
            document_id: Document id; anything whose ``str()`` is a UUID.

        Returns:
            Number of chunks generated (0 when the chunker yields nothing).

        Raises:
            ValueError: ``document_id`` is not a valid UUID, the document does
                not exist, or the embedder returned a different number of
                vectors than there are chunks.
        """
        doc_uuid = uuid.UUID(str(document_id))

        # 1. Load the document
        stmt = select(KnowledgeDocument).where(KnowledgeDocument.id == doc_uuid)
        result = await session.execute(stmt)
        document: Optional[KnowledgeDocument] = result.scalar_one_or_none()
        if document is None:
            raise ValueError(f"Document {document_id} not found")

        try:
            # 2. Remove old chunks (idempotent re-processing)
            await self._delete_chunks_by_document_id(session, doc_uuid)

            # 3. Chunk the content
            chunk_dicts = self.chunker.chunk(
                document.content,
                metadata={"document_id": str(document_id), "title": document.title},
            )
            if not chunk_dicts:
                # Nothing to embed -> mark ready immediately
                await self._update_document_status(session, doc_uuid, "ready", chunk_count=0)
                await session.commit()
                return 0

            # 4. Batch embedding
            texts = [c["content"] for c in chunk_dicts]
            # Materialize so the count check below works for any iterable result.
            embeddings = list(await self.embedder.embed_batch(texts))
            # zip() would silently drop unmatched chunks and record a wrong
            # chunk_count; fail loudly instead (handled by the except below).
            if len(embeddings) != len(chunk_dicts):
                raise ValueError(
                    f"Embedding count mismatch for document {document_id}: "
                    f"expected {len(chunk_dicts)}, got {len(embeddings)}"
                )

            # 5. Build and stage knowledge_chunks rows
            chunks_to_insert = [
                KnowledgeChunk(
                    id=uuid.uuid4(),
                    document_id=doc_uuid,
                    content=chunk_dict["content"],
                    embedding=embedding,
                    chunk_index=chunk_dict["chunk_index"],
                    token_count=chunk_dict["token_count"],
                    extra_metadata=chunk_dict.get("metadata", {}),
                )
                for chunk_dict, embedding in zip(chunk_dicts, embeddings)
            ]
            session.add_all(chunks_to_insert)

            # 6. Mark the document ready with the final chunk count
            await self._update_document_status(
                session, doc_uuid, "ready", chunk_count=len(chunks_to_insert)
            )
            await session.commit()
            logger.info(
                "Ingested document %s: %d chunks", document_id, len(chunks_to_insert)
            )
            return len(chunks_to_insert)

        except Exception as exc:
            await session.rollback()
            # logger.exception keeps the traceback for diagnosis.
            logger.exception("Failed to ingest document %s: %s", document_id, exc)

            # Best effort: record the failure on the document itself.
            try:
                await self._update_document_status(
                    session,
                    doc_uuid,
                    "failed",
                    error_message=str(exc),
                )
                await session.commit()
            except Exception as inner_exc:
                logger.error("Failed to update document status: %s", inner_exc)
            # Bare raise re-raises the original ingestion error (exc).
            raise

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    async def search(
        self,
        session: AsyncSession,
        query: str,
        knowledge_base_ids: list[str],
        top_k: int = 10,
    ) -> list[dict]:
        """Search the knowledge base (delegates to ``HybridRetriever.search``).

        Args:
            session: Async DB session passed through to the retriever.
            query: Query text.
            knowledge_base_ids: Knowledge bases to search.
            top_k: Maximum number of results requested from the retriever.

        Returns:
            Whatever the retriever returns. Documented shape: a list of dicts
            with keys chunk_id, content, score, document_id, document_title
            and metadata (not validated here).
        """
        return await self.retriever.search(session, query, knowledge_base_ids, top_k)

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    async def delete_document_chunks(
        self, session: AsyncSession, document_id: str
    ) -> int:
        """Delete all chunks belonging to a document and commit.

        Args:
            session: Async DB session; committed by this method.
            document_id: Document id; anything whose ``str()`` is a UUID.

        Returns:
            The driver-reported number of deleted chunks.

        Raises:
            ValueError: ``document_id`` is not a valid UUID.
        """
        doc_uuid = uuid.UUID(str(document_id))
        deleted = await self._delete_chunks_by_document_id(session, doc_uuid)
        await session.commit()
        return deleted

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    @staticmethod
    def compute_content_hash(content: str) -> str:
        """Return the hex SHA-256 of ``content`` (UTF-8), used for de-duplication."""
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _delete_chunks_by_document_id(
        self, session: AsyncSession, doc_uuid: uuid.UUID
    ) -> int:
        """Delete every chunk of ``doc_uuid`` (no commit); return the rowcount."""
        stmt = delete(KnowledgeChunk).where(KnowledgeChunk.document_id == doc_uuid)
        result = await session.execute(stmt)
        return result.rowcount  # type: ignore[return-value]

    async def _update_document_status(
        self,
        session: AsyncSession,
        doc_uuid: uuid.UUID,
        status: str,
        chunk_count: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> None:
        """Update the document's status fields (no commit).

        ``chunk_count`` and ``error_message`` are only written when they are
        not None; other columns are left untouched.
        """
        values: dict = {"status": status}
        if chunk_count is not None:
            values["chunk_count"] = chunk_count
        if error_message is not None:
            values["error_message"] = error_message

        stmt = (
            update(KnowledgeDocument)
            .where(KnowledgeDocument.id == doc_uuid)
            .values(**values)
        )
        await session.execute(stmt)