geo/backend/app/services/knowledge/incremental_index.py

242 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
增量索引服务
"""
import hashlib
from typing import Optional
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.knowledge import KnowledgeChunk, KnowledgeDocument
class IncrementalIndexService:
    """Incremental indexing service: add, update, and delete single documents.

    Operates on one document at a time instead of rebuilding a whole
    knowledge base; chunk generation is delegated to
    ``rag_service.ingest_document``. No method in this class calls
    ``commit``; transaction control is left to the caller (and to whatever
    the RAG service does internally — not visible here).
    """

    def __init__(self, rag_service):
        # NOTE(review): rag_service must provide an async
        # ``ingest_document(session, document_id)``; its return value is
        # reported as the chunk count below — confirm against the RAG service.
        self.rag = rag_service

    async def add_document(
        self,
        session: AsyncSession,
        kb_id: str,
        document_id: str,
    ) -> dict:
        """Index a single document without rebuilding the full index.

        Args:
            session: Active async database session.
            kb_id: Knowledge-base id. Not used by this method; kept for
                API compatibility.
            document_id: Id of the document to ingest.

        Returns:
            ``{"action": "skip", "reason": "already_indexed"}`` when the
            document's status is already ``"ready"``; otherwise
            ``{"action": "indexed", "document_id": ..., "chunk_count": ...}``.
        """
        existing = await self._get_document_status(session, document_id)
        if existing and existing.get("status") == "ready":
            return {"action": "skip", "reason": "already_indexed"}

        # A document that does not exist (existing is None) is still handed
        # to the RAG service, which decides how to report it.
        chunk_count = await self.rag.ingest_document(session, document_id)
        return {
            "action": "indexed",
            "document_id": document_id,
            "chunk_count": chunk_count,
        }

    async def update_document(
        self,
        session: AsyncSession,
        document_id: str,
        new_content: str,
    ) -> dict:
        """Replace a document's content and re-index it if the content changed.

        Strategy:
            1. Compute the SHA-256 hex digest of the new content (UTF-8).
            2. If it equals the stored hash, skip.
            3. Otherwise delete the old chunks, store the new content and
               hash (status reset to ``"pending"``), and ingest again.

        Returns:
            ``{"action": "skip", "reason": "content_unchanged"}`` or
            ``{"action": "updated", "document_id": ..., "deleted_chunks": ...,
            "new_chunks": ...}``.
        """
        new_hash = hashlib.sha256(new_content.encode("utf-8")).hexdigest()
        old_hash = await self._get_content_hash(session, document_id)
        if new_hash == old_hash:
            return {"action": "skip", "reason": "content_unchanged"}

        deleted = await self._delete_document_chunks(session, document_id)
        await self._update_document_content(
            session, document_id, new_content, new_hash
        )
        chunk_count = await self.rag.ingest_document(session, document_id)
        return {
            "action": "updated",
            "document_id": document_id,
            "deleted_chunks": deleted,
            "new_chunks": chunk_count,
        }

    async def delete_document(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> dict:
        """Delete a document together with all of its chunks.

        Returns:
            ``{"action": "deleted", "document_id": ..., "deleted_chunks": ...}``.
            The same shape is returned even if no document row existed.
        """
        deleted = await self._delete_document_chunks(session, document_id)
        await self._delete_document(session, document_id)
        return {
            "action": "deleted",
            "document_id": document_id,
            "deleted_chunks": deleted,
        }

    async def rebuild_knowledge_base(
        self,
        session: AsyncSession,
        kb_id: str,
        force: bool = False,
    ) -> dict:
        """Re-ingest every document of a knowledge base.

        Args:
            session: Active async database session.
            kb_id: Knowledge-base id whose documents are processed.
            force: Re-ingest even documents whose status is ``"ready"``.

        Returns:
            Stats dict with ``total``, ``processed``, ``skipped``, ``failed``
            counts and an ``errors`` list of ``{"document_id", "error"}``.
        """
        # Snapshot (id, status) as plain column values up front rather than
        # iterating ORM instances. If ingest_document commits (expiring the
        # instances) or leaves the session in a failed state, reading
        # doc.status / doc.id afterwards would trigger an implicit refresh,
        # which raises under AsyncSession.
        stmt = select(KnowledgeDocument.id, KnowledgeDocument.status).where(
            KnowledgeDocument.knowledge_base_id == kb_id
        )
        rows = (await session.execute(stmt)).all()

        stats = {
            "total": len(rows),
            "processed": 0,
            "skipped": 0,
            "failed": 0,
            "errors": [],
        }
        for doc_id, status in rows:
            document_id = str(doc_id)
            if status == "ready" and not force:
                stats["skipped"] += 1
                continue
            try:
                # Drop stale chunks, then ingest again from scratch.
                await self._delete_document_chunks(session, document_id)
                await self.rag.ingest_document(session, document_id)
            except Exception as e:
                # Best-effort: record the failure and continue with the rest.
                # NOTE(review): after a database error the session is likely
                # unusable until rolled back, so later documents may fail too;
                # consider a per-document savepoint once ingest_document's
                # commit behaviour is confirmed.
                stats["failed"] += 1
                stats["errors"].append({
                    "document_id": document_id,
                    "error": str(e),
                })
            else:
                stats["processed"] += 1
        return stats

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    async def _get_document_status(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> Optional[dict]:
        """Return ``{"status", "content_hash"}`` for a document, or None if absent.

        ``content_hash`` is None when the model has no such attribute.
        """
        stmt = select(KnowledgeDocument).where(
            KnowledgeDocument.id == document_id
        )
        result = await session.execute(stmt)
        doc = result.scalar_one_or_none()
        if not doc:
            return None
        return {
            "status": doc.status,
            "content_hash": getattr(doc, "content_hash", None),
        }

    async def _get_content_hash(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> Optional[str]:
        """Return the stored content hash, or None if the document is absent."""
        status = await self._get_document_status(session, document_id)
        return status.get("content_hash") if status else None

    async def _delete_document_chunks(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> int:
        """Delete all chunks of a document and return how many there were.

        The count is taken just before the DELETE in a separate statement, so
        it is not atomic with it (concurrent writers could make them differ).
        """
        count_stmt = (
            select(func.count())
            .select_from(KnowledgeChunk)
            .where(KnowledgeChunk.document_id == document_id)
        )
        count_result = await session.execute(count_stmt)
        count = count_result.scalar() or 0

        delete_stmt = delete(KnowledgeChunk).where(
            KnowledgeChunk.document_id == document_id
        )
        await session.execute(delete_stmt)
        return count

    async def _update_document_content(
        self,
        session: AsyncSession,
        document_id: str,
        content: str,
        content_hash: str,
    ) -> None:
        """Store new content and hash, and mark the document for processing."""
        stmt = (
            update(KnowledgeDocument)
            .where(KnowledgeDocument.id == document_id)
            .values(
                content=content,
                content_hash=content_hash,
                status="pending",  # mark as awaiting (re-)ingestion
            )
        )
        await session.execute(stmt)

    async def _delete_document(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> None:
        """Delete the document row if it exists; no-op otherwise."""
        stmt = select(KnowledgeDocument).where(
            KnowledgeDocument.id == document_id
        )
        result = await session.execute(stmt)
        doc = result.scalar_one_or_none()
        if doc:
            await session.delete(doc)