"""Incremental indexing service."""
import hashlib
from typing import Optional

from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.knowledge import KnowledgeChunk, KnowledgeDocument


class IncrementalIndexService:
    """Incremental indexing service supporting document add, update and delete.

    All database work goes through the ``AsyncSession`` handed to each method.
    This class never calls ``commit`` or ``rollback`` itself; whether the
    injected ``rag_service.ingest_document`` does is not visible from here.
    """

    def __init__(self, rag_service):
        """Store the RAG service whose ``ingest_document`` performs ingestion."""
        self.rag = rag_service

    async def add_document(
        self,
        session: AsyncSession,
        kb_id: str,
        document_id: str,
    ) -> dict:
        """Index a single document without rebuilding the whole index.

        Args:
            session: Async database session used for the lookup and ingestion.
            kb_id: Knowledge base identifier. Not used by this method; kept
                for interface compatibility.
            document_id: Identifier of the document to index.

        Returns:
            ``{"action": "skip", "reason": "already_indexed"}`` when the
            document exists with status ``"ready"``. Otherwise a dict with
            ``action="indexed"``, the ``document_id`` and the ``chunk_count``
            returned by ``rag.ingest_document``.
        """
        existing = await self._get_document_status(session, document_id)
        if existing and existing.get("status") == "ready":
            return {"action": "skip", "reason": "already_indexed"}

        chunk_count = await self.rag.ingest_document(session, document_id)
        return {
            "action": "indexed",
            "document_id": document_id,
            "chunk_count": chunk_count,
        }

    async def update_document(
        self,
        session: AsyncSession,
        document_id: str,
        new_content: str,
    ) -> dict:
        """Re-index a document only if its content actually changed.

        Strategy:
            1. Hash the new content with SHA-256.
            2. If the document does not exist, or the hash is unchanged, skip.
            3. Otherwise delete the old chunks, store the new content and
               hash, and ingest the document again.

        Args:
            session: Async database session.
            document_id: Identifier of the document to update.
            new_content: Replacement text for the document.

        Returns:
            ``{"action": "skip", "reason": "not_found"}`` when no such
            document exists, ``{"action": "skip", "reason":
            "content_unchanged"}`` when the stored hash matches, otherwise a
            dict with ``action="updated"``, ``document_id``,
            ``deleted_chunks`` and ``new_chunks``.
        """
        new_hash = hashlib.sha256(new_content.encode()).hexdigest()

        # Look the document up directly so that "missing document" can be
        # told apart from "document without a stored hash". Previously a
        # missing document fell through to chunk deletion and ingestion.
        current = await self._get_document_status(session, document_id)
        if current is None:
            return {"action": "skip", "reason": "not_found"}

        if new_hash == current.get("content_hash"):
            return {"action": "skip", "reason": "content_unchanged"}

        deleted = await self._delete_document_chunks(session, document_id)
        await self._update_document_content(
            session, document_id, new_content, new_hash
        )
        chunk_count = await self.rag.ingest_document(session, document_id)

        return {
            "action": "updated",
            "document_id": document_id,
            "deleted_chunks": deleted,
            "new_chunks": chunk_count,
        }

    async def delete_document(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> dict:
        """Delete a document together with all of its chunks.

        Chunks are removed first, then the document record itself. A missing
        document is not an error: the result simply reports the number of
        chunks that were counted for that id.

        Returns:
            A dict with ``action="deleted"``, ``document_id`` and
            ``deleted_chunks``.
        """
        deleted = await self._delete_document_chunks(session, document_id)
        await self._delete_document(session, document_id)

        return {
            "action": "deleted",
            "document_id": document_id,
            "deleted_chunks": deleted,
        }

    async def rebuild_knowledge_base(
        self,
        session: AsyncSession,
        kb_id: str,
        force: bool = False,
    ) -> dict:
        """Rebuild the index of every document in a knowledge base.

        Args:
            session: Async database session.
            kb_id: Identifier of the knowledge base to rebuild.
            force: When true, documents whose status is ``"ready"`` are
                rebuilt as well instead of being skipped.

        Returns:
            Statistics dict with the keys ``total``, ``processed``,
            ``skipped``, ``failed`` and ``errors`` (a list of
            ``{"document_id", "error"}`` entries).
        """
        stmt = select(KnowledgeDocument).where(
            KnowledgeDocument.knowledge_base_id == kb_id
        )
        result = await session.execute(stmt)

        # Snapshot the plain values while the rows are freshly loaded. ORM
        # instances can be expired by a commit or rollback during ingestion,
        # and touching an expired attribute under asyncio raises instead of
        # lazily reloading -- including inside the error handler below.
        targets = [(str(doc.id), doc.status) for doc in result.scalars().all()]

        stats = {
            "total": len(targets),
            "processed": 0,
            "skipped": 0,
            "failed": 0,
            "errors": [],
        }

        for doc_id, status in targets:
            if status == "ready" and not force:
                stats["skipped"] += 1
                continue

            try:
                # Drop the stale chunks, then ingest the document again.
                await self._delete_document_chunks(session, doc_id)
                await self.rag.ingest_document(session, doc_id)
            except Exception as exc:
                # Deliberately best-effort: one failing document must not
                # abort the rebuild of the remaining ones.
                # NOTE(review): the session is not rolled back here, so a
                # database-level failure may leave it unusable for the
                # following documents -- confirm who owns the transaction.
                stats["failed"] += 1
                stats["errors"].append({
                    "document_id": doc_id,
                    "error": str(exc),
                })
            else:
                stats["processed"] += 1

        return stats

    # ------------------------------------------------------------------
    # Helper methods
    # ------------------------------------------------------------------

    async def _get_document_status(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> Optional[dict]:
        """Return ``{"status", "content_hash"}`` for a document, or ``None``.

        ``content_hash`` is read with ``getattr`` so that a model without
        that attribute yields ``None`` instead of raising.
        """
        stmt = select(KnowledgeDocument).where(
            KnowledgeDocument.id == document_id
        )
        result = await session.execute(stmt)
        doc = result.scalar_one_or_none()

        if not doc:
            return None

        return {
            "status": doc.status,
            "content_hash": getattr(doc, "content_hash", None),
        }

    async def _get_content_hash(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> Optional[str]:
        """Return the stored content hash, or ``None`` if there is none."""
        status = await self._get_document_status(session, document_id)
        return status.get("content_hash") if status else None

    async def _delete_document_chunks(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> int:
        """Delete every chunk of a document and return how many were counted.

        The count is taken with a separate query immediately before the
        DELETE statement is executed.
        """
        count_stmt = select(func.count()).where(
            KnowledgeChunk.document_id == document_id
        )
        count_result = await session.execute(count_stmt)
        count = count_result.scalar() or 0

        delete_stmt = delete(KnowledgeChunk).where(
            KnowledgeChunk.document_id == document_id
        )
        await session.execute(delete_stmt)

        return count

    async def _update_document_content(
        self,
        session: AsyncSession,
        document_id: str,
        content: str,
        content_hash: str,
    ) -> None:
        """Store new content and hash, and reset the status to ``"pending"``."""
        stmt = (
            update(KnowledgeDocument)
            .where(KnowledgeDocument.id == document_id)
            .values(
                content=content,
                content_hash=content_hash,
                status="pending",  # mark the document as awaiting processing
            )
        )
        await session.execute(stmt)

    async def _delete_document(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> None:
        """Delete the document record if it exists; do nothing otherwise."""
        stmt = select(KnowledgeDocument).where(
            KnowledgeDocument.id == document_id
        )
        result = await session.execute(stmt)
        doc = result.scalar_one_or_none()

        if doc:
            await session.delete(doc)