213 lines
7.3 KiB
Python
213 lines
7.3 KiB
Python
"""
|
||
RAGService: 统一 RAG 服务入口
|
||
负责文档摄入(分块→embedding→存储)与知识检索。
|
||
"""
|
||
import hashlib
|
||
import logging
|
||
import uuid
|
||
from typing import Optional
|
||
|
||
from sqlalchemy import delete, select, update
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models.knowledge import KnowledgeChunk, KnowledgeDocument
|
||
from .chunker import RecursiveChunker
|
||
from .embedder import EmbeddingService, MockEmbedder
|
||
from .retriever import HybridRetriever
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RAGService:
    """Unified entry point for the RAG service.

    Combines a chunker, an embedder and a hybrid retriever to ingest
    documents (chunk -> embed -> store) and to search the stored chunks.
    """

    def __init__(self, embedder: Optional[EmbeddingService] = None):
        """Wire up the chunker, embedder and retriever.

        Args:
            embedder: Embedding backend to use. When omitted, an
                ``OpenAIEmbedder`` is created if ``settings.OPENAI_API_KEY``
                is set; otherwise a ``MockEmbedder`` is used and a warning
                is logged.
        """
        self.chunker = RecursiveChunker()
        if embedder is not None:
            self.embedder = embedder
        else:
            # Settings are only consulted when no embedder was injected.
            from app.config import settings

            if settings.OPENAI_API_KEY:
                from app.services.knowledge.embedder import OpenAIEmbedder

                self.embedder = OpenAIEmbedder(api_key=settings.OPENAI_API_KEY)
            else:
                logger.warning("未配置 OPENAI_API_KEY,知识库将使用 MockEmbedder(仅适用于开发环境)")
                self.embedder = MockEmbedder()
        self.retriever = HybridRetriever(self.embedder)

    # ------------------------------------------------------------------
    # Ingest
    # ------------------------------------------------------------------

    async def ingest_document(self, session: AsyncSession, document_id: str) -> int:
        """Process a document: chunk -> embed -> store.

        Steps:
            1. Load the document row by id.
            2. Delete its existing chunks, so re-ingesting replaces them
               instead of duplicating them.
            3. Split the document content into chunks.
            4. Embed all chunk texts in one batch.
            5. Insert the new ``KnowledgeChunk`` rows.
            6. Set ``chunk_count`` and ``status = 'ready'`` on the document.

        On any failure after the document was loaded, the transaction is
        rolled back, the document is marked ``failed`` (best effort) and the
        original exception is re-raised.

        NOTE(review): the document's current status is not checked here;
        callers that require a particular status must enforce it themselves.

        Args:
            session: Async SQLAlchemy session; this method commits it.
            document_id: Document id, parseable as a UUID.

        Returns:
            The number of chunks stored (0 when chunking yields nothing).

        Raises:
            ValueError: ``document_id`` is not a valid UUID, or no such
                document exists.
            RuntimeError: the embedder returned a different number of
                embeddings than the number of chunks.
        """
        doc_uuid = uuid.UUID(str(document_id))

        # 1. Load the document.
        stmt = select(KnowledgeDocument).where(KnowledgeDocument.id == doc_uuid)
        result = await session.execute(stmt)
        document: Optional[KnowledgeDocument] = result.scalar_one_or_none()
        if document is None:
            raise ValueError(f"Document {document_id} not found")

        try:
            # 2. Drop chunks left over from a previous ingestion.
            await self._delete_chunks_by_document_id(session, doc_uuid)

            # 3. Chunk the content.
            chunk_dicts = self.chunker.chunk(
                document.content,
                metadata={"document_id": str(document_id), "title": document.title},
            )
            if not chunk_dicts:
                # Nothing to embed: mark the document ready with zero chunks.
                await self._update_document_status(
                    session, doc_uuid, "ready", chunk_count=0
                )
                await session.commit()
                return 0

            # 4. Embed all chunk texts in one batch.
            texts = [c["content"] for c in chunk_dicts]
            embeddings = list(await self.embedder.embed_batch(texts))
            # zip() would silently drop chunks on a short result, so verify first.
            self._ensure_embedding_count(len(chunk_dicts), embeddings)

            # 5. Build and stage the chunk rows.
            chunks_to_insert = [
                KnowledgeChunk(
                    id=uuid.uuid4(),
                    document_id=doc_uuid,
                    content=chunk_dict["content"],
                    embedding=embedding,
                    chunk_index=chunk_dict["chunk_index"],
                    token_count=chunk_dict["token_count"],
                    extra_metadata=chunk_dict.get("metadata", {}),
                )
                for chunk_dict, embedding in zip(chunk_dicts, embeddings)
            ]
            session.add_all(chunks_to_insert)

            # 6. Update the document row.
            await self._update_document_status(
                session, doc_uuid, "ready", chunk_count=len(chunks_to_insert)
            )

            await session.commit()
            logger.info(
                "Ingested document %s: %d chunks", document_id, len(chunks_to_insert)
            )
            return len(chunks_to_insert)

        except Exception as exc:
            await session.rollback()
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Failed to ingest document %s: %s", document_id, exc)
            await self._mark_document_failed(session, doc_uuid, exc)
            raise

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    async def search(
        self,
        session: AsyncSession,
        query: str,
        knowledge_base_ids: list[str],
        top_k: int = 10,
    ) -> list[dict]:
        """Search the knowledge base by delegating to the ``HybridRetriever``.

        Args:
            session: Async SQLAlchemy session passed through to the retriever.
            query: Query text.
            knowledge_base_ids: Knowledge bases to search in.
            top_k: Result limit passed through to the retriever.

        Returns:
            Whatever the retriever returns; expected to be a list of dicts
            shaped like
            ``{chunk_id, content, score, document_id, document_title, metadata}``.
        """
        return await self.retriever.search(session, query, knowledge_base_ids, top_k)

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    async def delete_document_chunks(
        self, session: AsyncSession, document_id: str
    ) -> int:
        """Delete every chunk of a document and commit.

        Args:
            session: Async SQLAlchemy session; this method commits it.
            document_id: Document id, parseable as a UUID.

        Returns:
            The number of chunks deleted.

        Raises:
            ValueError: ``document_id`` is not a valid UUID.
        """
        doc_uuid = uuid.UUID(str(document_id))
        deleted = await self._delete_chunks_by_document_id(session, doc_uuid)
        await session.commit()
        return deleted

    # ------------------------------------------------------------------
    # Utility
    # ------------------------------------------------------------------

    @staticmethod
    def compute_content_hash(content: str) -> str:
        """Return the SHA-256 hex digest of ``content`` encoded as UTF-8 (for de-duplication)."""
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ensure_embedding_count(expected: int, embeddings: list) -> None:
        """Raise ``RuntimeError`` unless exactly ``expected`` embeddings were returned."""
        actual = len(embeddings)
        if actual != expected:
            raise RuntimeError(
                f"Embedder returned {actual} embeddings for {expected} chunks"
            )

    async def _mark_document_failed(
        self, session: AsyncSession, doc_uuid: uuid.UUID, exc: Exception
    ) -> None:
        """Best-effort: store ``failed`` and the error text on the document.

        Errors raised while doing so are logged and swallowed so that the
        caller can re-raise the original ingestion error.
        """
        try:
            await self._update_document_status(
                session,
                doc_uuid,
                "failed",
                error_message=str(exc),
            )
            await session.commit()
        except Exception as inner_exc:
            logger.exception("Failed to update document status: %s", inner_exc)
            # Leave the session usable for the caller after the failed write.
            try:
                await session.rollback()
            except Exception:
                logger.exception(
                    "Rollback failed while marking document %s as failed", doc_uuid
                )

    async def _delete_chunks_by_document_id(
        self, session: AsyncSession, doc_uuid: uuid.UUID
    ) -> int:
        """Delete all chunks of the given document (without committing).

        Returns:
            The ``rowcount`` reported for the DELETE statement.
        """
        stmt = delete(KnowledgeChunk).where(KnowledgeChunk.document_id == doc_uuid)
        result = await session.execute(stmt)
        return result.rowcount  # type: ignore[return-value]

    async def _update_document_status(
        self,
        session: AsyncSession,
        doc_uuid: uuid.UUID,
        status: str,
        chunk_count: Optional[int] = None,
        error_message: Optional[str] = None,
    ) -> None:
        """Update status fields on the document row (without committing).

        ``chunk_count`` and ``error_message`` are only written when they are
        not ``None``; existing column values are otherwise left untouched.
        """
        values: dict = {"status": status}
        if chunk_count is not None:
            values["chunk_count"] = chunk_count
        if error_message is not None:
            values["error_message"] = error_message

        stmt = (
            update(KnowledgeDocument)
            .where(KnowledgeDocument.id == doc_uuid)
            .values(**values)
        )
        await session.execute(stmt)
|