655 lines
20 KiB
Python
655 lines
20 KiB
Python
"""
|
|
Knowledge Base API
|
|
CRUD for KnowledgeBase / KnowledgeDocument, plus chunk preview, RAG search, incremental re-indexing, and enhanced (reranked) retrieval endpoints.
|
|
"""
|
|
import asyncio
import ipaddress
import logging
import re
import socket
import time
import uuid
from html import unescape as html_unescape
from typing import Optional
from urllib.parse import urlsplit

import httpx
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.database import get_db
from app.models.knowledge import (
    KnowledgeBase,
    KnowledgeChunk,
    KnowledgeDocument,
    KnowledgeSearchLog,
)
from app.models.user import User
from app.schemas.knowledge import (
    ChunkPreview,
    DocumentResponse,
    DocumentUpload,
    KnowledgeBaseCreate,
    KnowledgeBaseResponse,
    KnowledgeSearchRequest,
    RetrieveRequest,
    SearchResponse,
    SearchResultItem,
    UpdateDocumentRequest,
)
from app.services.knowledge import MockEmbedder, RAGService
from app.services.knowledge.chunker import ChunkerFactory
from app.services.knowledge.enhanced_rag import EnhancedRAG
from app.services.knowledge.incremental_index import IncrementalIndexService
|
|
|
|
router = APIRouter()
logger = logging.getLogger(__name__)

# Single RAG service instance shared by every endpoint in this module.
# It is built with MockEmbedder for now; the plan is to inject a real embedder
# (e.g. OpenAIEmbedder) via dependency injection later.
_rag_service = RAGService(embedder=MockEmbedder())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _strip_html(html: str) -> str:
|
|
"""Very simple HTML tag stripper using regex."""
|
|
text = re.sub(r"<[^>]+>", " ", html)
|
|
text = re.sub(r"\s+", " ", text)
|
|
return text.strip()
|
|
|
|
|
|
async def _fetch_url_content(url: str) -> str:
|
|
"""Fetch a URL and return plain text (strips HTML tags)."""
|
|
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
|
resp = await client.get(url, headers={"User-Agent": "GEO-Bot/1.0"})
|
|
resp.raise_for_status()
|
|
return _strip_html(resp.text)
|
|
|
|
|
|
async def _get_kb(
    db: AsyncSession, kb_id: uuid.UUID, org_id: uuid.UUID
) -> KnowledgeBase:
    """Load a knowledge base that belongs to ``org_id``.

    Used as the tenant-ownership check by the endpoints in this module: a
    knowledge base owned by another organization is indistinguishable from one
    that does not exist.

    Raises:
        HTTPException(404): no knowledge base with ``kb_id`` in ``org_id``.
    """
    query = select(KnowledgeBase).where(
        KnowledgeBase.id == kb_id,
        KnowledgeBase.organization_id == org_id,
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Knowledge Base CRUD
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.post(
    "/bases",
    response_model=KnowledgeBaseResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_knowledge_base(
    body: KnowledgeBaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a knowledge base owned by the caller's organization.

    The new knowledge base records the caller as its creator and starts with
    status ``"active"`` and ``document_count`` 0.

    Raises:
        HTTPException(400): the caller does not belong to an organization.
    """
    organization_id = current_user.organization_id
    if not organization_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="User has no organization. Create or join an organization first.",
        )

    new_kb = KnowledgeBase(
        organization_id=organization_id,
        created_by=current_user.id,
        name=body.name,
        type=body.type,
        description=body.description,
        status="active",
        document_count=0,
    )
    db.add(new_kb)
    await db.commit()
    # Reload the persisted row before building the response from it.
    await db.refresh(new_kb)

    return KnowledgeBaseResponse(
        id=str(new_kb.id),
        name=new_kb.name,
        type=new_kb.type,
        description=new_kb.description,
        document_count=new_kb.document_count,
        status=new_kb.status,
        created_at=new_kb.created_at,
    )
|
|
|
|
|
|
@router.get("/bases", response_model=list[KnowledgeBaseResponse])
async def list_knowledge_bases(
    type: Optional[str] = Query(default=None, pattern="^(industry|enterprise)$"),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the caller's organization's knowledge bases, newest first.

    Args:
        type: Optional filter, ``"industry"`` or ``"enterprise"``.

    Returns:
        An empty list when the caller has no organization.
    """
    org_id = current_user.organization_id
    if not org_id:
        return []

    conditions = [KnowledgeBase.organization_id == org_id]
    if type:
        conditions.append(KnowledgeBase.type == type)
    query = (
        select(KnowledgeBase)
        .where(*conditions)
        .order_by(KnowledgeBase.created_at.desc())
    )

    knowledge_bases = (await db.execute(query)).scalars().all()
    return [
        KnowledgeBaseResponse(
            id=str(base.id),
            name=base.name,
            type=base.type,
            description=base.description,
            document_count=base.document_count,
            status=base.status,
            created_at=base.created_at,
        )
        for base in knowledge_bases
    ]
|
|
|
|
|
|
@router.get("/bases/{kb_id}", response_model=KnowledgeBaseResponse)
async def get_knowledge_base(
    kb_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one knowledge base owned by the caller's organization.

    A caller without an organization receives the same 404 response as for
    an unknown or foreign ``kb_id``.
    """
    org_id = current_user.organization_id
    if org_id:
        found = await _get_kb(db, kb_id, org_id)
        return KnowledgeBaseResponse(
            id=str(found.id),
            name=found.name,
            type=found.type,
            description=found.description,
            document_count=found.document_count,
            status=found.status,
            created_at=found.created_at,
        )
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
|
|
|
|
|
|
@router.delete("/bases/{kb_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_knowledge_base(
    kb_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a knowledge base together with all of its documents and chunks.

    Rows are removed chunks -> documents -> knowledge base and committed in a
    single transaction.

    Raises:
        HTTPException(404): the caller has no organization, or the knowledge
            base does not belong to it.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    kb = await _get_kb(db, kb_id, org_id)

    # Use the KB's document ids as a SQL subquery instead of loading them into
    # Python first: saves a round-trip and avoids an unbounded IN (...) list of
    # bind parameters, which drivers cap (asyncpg, for example, at 32767).
    kb_document_ids = select(KnowledgeDocument.id).where(
        KnowledgeDocument.knowledge_base_id == kb_id
    )
    await db.execute(
        delete(KnowledgeChunk).where(KnowledgeChunk.document_id.in_(kb_document_ids))
    )
    await db.execute(
        delete(KnowledgeDocument).where(KnowledgeDocument.knowledge_base_id == kb_id)
    )

    await db.delete(kb)
    await db.commit()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Document Management
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _resolve_document_content(body: DocumentUpload) -> str:
    """Return the text to store and index for an upload request.

    ``source_type == "url"`` fetches ``source_url`` via ``_fetch_url_content``.
    Every other source type must supply inline ``content``.

    Raises:
        HTTPException(400): a required field is missing, or fetching the URL
            failed (the underlying error text is included in the detail).
    """
    if body.source_type == "url":
        if not body.source_url:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="source_url is required for url type documents",
            )
        try:
            return await _fetch_url_content(body.source_url)
        except Exception as exc:
            # Surface any fetch failure as a 400 with the reason, keeping the
            # original exception chained for server-side debugging.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Failed to fetch URL content: {exc}",
            ) from exc

    if not body.content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="content is required for text/markdown type documents",
        )
    return body.content


@router.post(
    "/bases/{kb_id}/documents",
    response_model=DocumentResponse,
    status_code=status.HTTP_201_CREATED,
)
async def upload_document(
    kb_id: uuid.UUID,
    body: DocumentUpload,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Add a document to a knowledge base and index it.

    Steps:

    1. Ownership check and content resolution (inline text or fetched URL).
    2. The document is stored with status ``"processing"``, and the parent's
       ``document_count`` is incremented in the same commit.
    3. The document is ingested by the shared RAG service within this request.

    An ingest failure does not fail the request. The stored document is
    returned either way, and its ``status``/``error_message`` fields describe
    the outcome.

    Raises:
        HTTPException(400): the caller has no organization, content is missing,
            or the URL could not be fetched.
        HTTPException(404): the knowledge base is not in the caller's
            organization.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User has no organization")

    await _get_kb(db, kb_id, org_id)  # ownership check

    content = await _resolve_document_content(body)

    doc = KnowledgeDocument(
        knowledge_base_id=kb_id,
        title=body.title,
        source_type=body.source_type,
        source_url=body.source_url,
        content=content,
        content_hash=RAGService.compute_content_hash(content),
        status="processing",
        chunk_count=0,
    )
    db.add(doc)
    await db.flush()  # assign doc.id before ingest

    await db.execute(
        update(KnowledgeBase)
        .where(KnowledgeBase.id == kb_id)
        .values(document_count=KnowledgeBase.document_count + 1)
    )

    await db.commit()
    await db.refresh(doc)

    # Ingestion runs inline, inside this request. Moving it to a background
    # task is a planned optimization.
    # Capture the id up front so the log line never has to read an attribute
    # of an instance that a failed ingest may have left expired.
    doc_id = str(doc.id)
    try:
        await _rag_service.ingest_document(db, doc_id)
        await db.refresh(doc)
    except Exception:
        # Best effort: the document row already exists, so report it instead of
        # failing the request. Per the original note, ingest_document marks the
        # document 'failed' when it raises.
        logger.exception("Ingest failed for document %s", doc_id)

    return DocumentResponse(
        id=str(doc.id),
        title=doc.title,
        source_type=doc.source_type,
        source_url=doc.source_url,
        chunk_count=doc.chunk_count,
        status=doc.status,
        error_message=doc.error_message,
        created_at=doc.created_at,
    )
|
|
|
|
|
|
@router.get("/bases/{kb_id}/documents", response_model=list[DocumentResponse])
async def list_documents(
    kb_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the documents of a knowledge base, newest first.

    Raises:
        HTTPException(404): the caller has no organization, or the knowledge
            base does not belong to it.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, org_id)

    newest_first = KnowledgeDocument.created_at.desc()
    query = (
        select(KnowledgeDocument)
        .where(KnowledgeDocument.knowledge_base_id == kb_id)
        .order_by(newest_first)
    )
    documents = (await db.execute(query)).scalars().all()

    return [
        DocumentResponse(
            id=str(document.id),
            title=document.title,
            source_type=document.source_type,
            source_url=document.source_url,
            chunk_count=document.chunk_count,
            status=document.status,
            error_message=document.error_message,
            created_at=document.created_at,
        )
        for document in documents
    ]
|
|
|
|
|
|
@router.delete(
    "/bases/{kb_id}/documents/{doc_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_document(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete one document, and its chunks, from a knowledge base.

    The knowledge base must belong to the caller's organization and the
    document must belong to that knowledge base. The parent's
    ``document_count`` is decremented in the same commit.

    Raises:
        HTTPException(404): unknown/foreign knowledge base or document.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, org_id)

    lookup = await db.execute(
        select(KnowledgeDocument).where(
            KnowledgeDocument.id == doc_id,
            KnowledgeDocument.knowledge_base_id == kb_id,
        )
    )
    document = lookup.scalar_one_or_none()
    if document is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    # Chunks go first, explicitly. The original note says a model-level cascade
    # would also remove them.
    await _rag_service.delete_document_chunks(db, str(document.id))
    await db.delete(document)

    await db.execute(
        update(KnowledgeBase)
        .where(KnowledgeBase.id == kb_id)
        .values(document_count=KnowledgeBase.document_count - 1)
    )
    await db.commit()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Chunk Preview
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.get(
    "/bases/{kb_id}/documents/{doc_id}/chunks",
    response_model=list[ChunkPreview],
)
async def list_chunks(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return all chunks of a document, ordered by ``chunk_index``.

    Raises:
        HTTPException(404): the knowledge base is not in the caller's
            organization, or the document does not belong to it.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, org_id)

    # Only the id is needed to prove the document belongs to this KB. Selecting
    # the whole entity would also load columns such as ``content`` (the full
    # document text) unless the model defers them.
    ownership = await db.execute(
        select(KnowledgeDocument.id).where(
            KnowledgeDocument.id == doc_id,
            KnowledgeDocument.knowledge_base_id == kb_id,
        )
    )
    if ownership.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    chunk_result = await db.execute(
        select(KnowledgeChunk)
        .where(KnowledgeChunk.document_id == doc_id)
        .order_by(KnowledgeChunk.chunk_index)
    )

    return [
        ChunkPreview(
            id=str(chunk.id),
            content=chunk.content,
            chunk_index=chunk.chunk_index,
            token_count=chunk.token_count,
        )
        for chunk in chunk_result.scalars().all()
    ]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Knowledge Search
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _resolve_search_scope(
    db: AsyncSession, org_id: uuid.UUID, requested_ids: list[str]
) -> list[str]:
    """Validate the requested knowledge-base ids and return the ids to search.

    - Each id must parse as a UUID; otherwise 400.
    - Every id must belong to ``org_id``; otherwise 404 listing the offenders.
    - An empty request resolves to all of the organization's knowledge bases.
      The RAG service is not given the organization, so an unscoped request
      must never reach it.

    Returns:
        Canonical UUID strings, de-duplicated, in request order. The list is
        empty only if the organization owns no knowledge bases.
    """
    kb_uuids: list[uuid.UUID] = []
    for raw_id in requested_ids:
        try:
            kb_uuids.append(uuid.UUID(raw_id))
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid knowledge_base_id: {raw_id}",
            ) from None

    if not kb_uuids:
        owned = await db.execute(
            select(KnowledgeBase.id).where(KnowledgeBase.organization_id == org_id)
        )
        return [str(kb_id) for kb_id in owned.scalars().all()]

    owned = await db.execute(
        select(KnowledgeBase.id).where(
            KnowledgeBase.id.in_(kb_uuids),
            KnowledgeBase.organization_id == org_id,
        )
    )
    found_ids = set(owned.scalars().all())
    missing = sorted(str(m) for m in set(kb_uuids) - found_ids)
    if missing:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Knowledge bases not found or not accessible: {missing}",
        )
    return list(dict.fromkeys(str(kb_uuid) for kb_uuid in kb_uuids))


@router.post("/search", response_model=SearchResponse)
async def knowledge_search(
    body: KnowledgeSearchRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Run a RAG search over knowledge bases owned by the caller's organization.

    When ``knowledge_base_ids`` is empty, every knowledge base of the
    organization is searched. Each search is recorded in
    ``KnowledgeSearchLog`` with the effective knowledge-base ids, the result
    count and the search latency.

    Raises:
        HTTPException(400): no organization, or a malformed knowledge-base id.
        HTTPException(404): a requested knowledge base is not accessible.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="User has no organization",
        )

    kb_ids = await _resolve_search_scope(db, org_id, body.knowledge_base_ids or [])

    t0 = time.monotonic()
    if kb_ids:
        raw_results = await _rag_service.search(
            db,
            query=body.query,
            knowledge_base_ids=kb_ids,
            top_k=body.top_k,
        )
    else:
        # Never hand the RAG service an empty scope. It is not passed an
        # organization id here, so it cannot limit such a search to this tenant.
        raw_results = []
    latency_ms = int((time.monotonic() - t0) * 1000)

    db.add(
        KnowledgeSearchLog(
            organization_id=org_id,
            user_id=current_user.id,
            query=body.query,
            knowledge_base_ids=kb_ids,
            results_count=len(raw_results),
            latency_ms=latency_ms,
        )
    )
    await db.commit()

    items = [
        SearchResultItem(
            chunk_id=r["chunk_id"],
            content=r["content"],
            score=r["score"],
            document_id=r["document_id"],
            document_title=r["document_title"],
            metadata=r.get("metadata", {}),
        )
        for r in raw_results
    ]

    return SearchResponse(
        results=items,
        total=len(items),
        latency_ms=latency_ms,
    )
|
|
|
|
|
|
@router.post("/bases/{kb_id}/chunks/preview")
async def preview_chunks(
    kb_id: uuid.UUID,
    text: str,
    strategy: str = "recursive",
    chunk_size: int = 500,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Preview how ``text`` would be split by a chunking strategy.

    Like every other endpoint here, this requires an authenticated user whose
    organization owns ``kb_id``. Previously the route had no auth dependency
    at all.

    The requested size is written to ``chunker.STRATEGY.chunk_size`` only for
    the duration of the preview and then restored, so a preview cannot change
    the chunk size used by other code. NOTE(review): ``STRATEGY`` looks like
    shared class-level state; confirm in ``ChunkerFactory``.

    Returns:
        A dict with the strategy name, the number of preview chunks (capped at
        10 by ``max_chunks``), the preview itself, and the catalogue of
        available strategies with their default sizes.

    Raises:
        HTTPException(404): the knowledge base is not in the caller's
            organization.
        HTTPException(400): ``chunk_size`` is not positive.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
    await _get_kb(db, kb_id, org_id)

    if chunk_size <= 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="chunk_size must be a positive integer",
        )

    chunker = ChunkerFactory.create(strategy)

    # Per-strategy size override. Semantic chunks may be larger; keep it an int.
    size_overrides = {
        "recursive": chunk_size,
        "semantic": int(chunk_size * 1.5),
        "fixed": chunk_size,
    }
    override = size_overrides.get(strategy)

    if override is None:
        preview = chunker.preview(text, max_chunks=10)
    else:
        previous_size = chunker.STRATEGY.chunk_size
        chunker.STRATEGY.chunk_size = override
        try:
            preview = chunker.preview(text, max_chunks=10)
        finally:
            # preview() is called synchronously (no await), so the temporary
            # value is not held across a suspension point before restoring it.
            chunker.STRATEGY.chunk_size = previous_size

    return {
        "strategy": strategy,
        "chunk_count": len(preview),
        "preview": preview,
        "strategies": [
            {
                "name": s.name,
                "description": s.description,
                "recommended_size": s.chunk_size,
            }
            for s in ChunkerFactory.list_strategies()
        ],
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Incremental Indexing API
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@router.post("/bases/{kb_id}/documents/{doc_id}/reindex")
async def reindex_document(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Re-index a single document via ``IncrementalIndexService.add_document``.

    Returns the service's result unchanged.

    Raises:
        HTTPException(404): the knowledge base is not in the caller's
            organization, or the document does not belong to it.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, org_id)

    # Nothing visible here shows the index service checking that doc_id belongs
    # to kb_id. Without this check, a user owning any KB could re-index another
    # organization's document.
    ownership = await db.execute(
        select(KnowledgeDocument.id).where(
            KnowledgeDocument.id == doc_id,
            KnowledgeDocument.knowledge_base_id == kb_id,
        )
    )
    if ownership.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    index_service = IncrementalIndexService(_rag_service)
    return await index_service.add_document(db, str(kb_id), str(doc_id))
|
|
|
|
|
|
@router.post("/bases/{kb_id}/documents/{doc_id}/update")
async def update_document_content(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    request: UpdateDocumentRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Replace a document's content and update its index incrementally.

    Delegates to ``IncrementalIndexService.update_document`` and returns its
    result unchanged.

    Raises:
        HTTPException(404): the knowledge base is not in the caller's
            organization, or the document does not belong to it.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, org_id)

    # update_document() is keyed by doc_id alone. Confirm the document belongs
    # to this ownership-checked KB, or any user owning some KB could overwrite
    # another organization's document.
    ownership = await db.execute(
        select(KnowledgeDocument.id).where(
            KnowledgeDocument.id == doc_id,
            KnowledgeDocument.knowledge_base_id == kb_id,
        )
    )
    if ownership.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    index_service = IncrementalIndexService(_rag_service)
    return await index_service.update_document(db, str(doc_id), request.content)
|
|
|
|
|
|
# NOTE(review): DELETE /bases/{kb_id}/documents/{doc_id} is also registered by
# delete_document. Starlette dispatches to the first matching route, so this
# handler is not reachable over HTTP (while the generated OpenAPI entry for the
# path shows this handler instead). Consider removing it or giving it a
# distinct path.
@router.delete("/bases/{kb_id}/documents/{doc_id}")
async def delete_document_incremental(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a document via ``IncrementalIndexService.delete_document``.

    Returns the service's result unchanged.

    Raises:
        HTTPException(404): the knowledge base is not in the caller's
            organization, or the document does not belong to it.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, org_id)

    # delete_document() is keyed by doc_id alone. Confirm the document belongs
    # to this ownership-checked KB so another organization's document cannot be
    # deleted.
    ownership = await db.execute(
        select(KnowledgeDocument.id).where(
            KnowledgeDocument.id == doc_id,
            KnowledgeDocument.knowledge_base_id == kb_id,
        )
    )
    if ownership.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    index_service = IncrementalIndexService(_rag_service)
    return await index_service.delete_document(db, str(doc_id))
|
|
|
|
|
|
@router.post("/bases/{kb_id}/rebuild")
async def rebuild_knowledge_base(
    kb_id: uuid.UUID,
    force: bool = False,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Rebuild the index of a knowledge base.

    ``force`` is forwarded unchanged to
    ``IncrementalIndexService.rebuild_knowledge_base``, which defines its exact
    meaning. The service's result is returned as-is.

    Raises:
        HTTPException(404): the caller has no organization, or the knowledge
            base does not belong to it.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, org_id)

    return await IncrementalIndexService(_rag_service).rebuild_knowledge_base(
        db, str(kb_id), force
    )
|
|
|
|
|
|
@router.post("/bases/{kb_id}/retrieve")
async def enhanced_retrieve(
    kb_id: uuid.UUID,
    request: RetrieveRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Retrieve from one knowledge base with optional reranking and compression.

    Reranking and compression are toggled by ``request.use_rerank`` and
    ``request.use_compression``. A falsy ``request.top_k`` (unset or 0) falls
    back to 5.

    Returns:
        ``{"results": <EnhancedRAG results>, "query": <the query string>}``.

    Raises:
        HTTPException(404): the caller has no organization, or the knowledge
            base does not belong to it.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, org_id)

    retriever = EnhancedRAG(_rag_service, _rag_service.embedder)
    limit = request.top_k or 5
    hits = await retriever.retrieve_with_rerank(
        db,
        request.query,
        [str(kb_id)],
        top_k=limit,
        use_rerank=request.use_rerank,
        use_compression=request.use_compression,
    )
    return {"results": hits, "query": request.query}
|