geo/backend/app/api/knowledge.py

655 lines
20 KiB
Python

"""
Knowledge Base API
CRUD for KnowledgeBase / KnowledgeDocument + RAG search endpoint.
"""
import re
import time
import uuid
import logging
from typing import Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.knowledge import (
KnowledgeBase,
KnowledgeChunk,
KnowledgeDocument,
KnowledgeSearchLog,
)
from app.models.user import User
from app.schemas.knowledge import (
ChunkPreview,
DocumentResponse,
DocumentUpload,
KnowledgeBaseCreate,
KnowledgeBaseResponse,
KnowledgeSearchRequest,
RetrieveRequest,
SearchResponse,
SearchResultItem,
UpdateDocumentRequest,
)
from app.services.knowledge import MockEmbedder, RAGService
from app.services.knowledge.enhanced_rag import EnhancedRAG
from app.services.knowledge.incremental_index import IncrementalIndexService
from app.services.knowledge.chunker import ChunkerFactory
# Module logger; used to report ingest failures in this router.
logger = logging.getLogger(__name__)
# Router collecting every knowledge-base endpoint defined below.
router = APIRouter()
# Shared RAG service instance, created once at import time and reused by all
# endpoints here (directly, and as the backing service handed to
# IncrementalIndexService / EnhancedRAG). Uses MockEmbedder for now.
# TODO: swap in a real embedder (e.g. OpenAIEmbedder) via dependency injection.
_rag_service = RAGService(embedder=MockEmbedder())
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _strip_html(html: str) -> str:
"""Very simple HTML tag stripper using regex."""
text = re.sub(r"<[^>]+>", " ", html)
text = re.sub(r"\s+", " ", text)
return text.strip()
async def _fetch_url_content(url: str) -> str:
    """Fetch *url* and return its textual content.

    Redirects are followed and the client uses a 15 second timeout. HTML
    responses, and responses without a ``Content-Type`` header, are passed
    through ``_strip_html``. Other types (e.g. ``text/plain``,
    ``text/markdown``) are returned verbatim, so literal ``<`` / ``>`` in
    them are not mistaken for markup and deleted.

    Raises:
        httpx.HTTPStatusError: when the final response is not 2xx.
        Other ``httpx`` exceptions on transport or URL problems.
    """
    # NOTE(review): *url* comes straight from the API request body and nothing
    # here restricts the target host, so this can be pointed at internal
    # services (SSRF). Consider a scheme/host allow-list or a private-IP check.
    async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
        resp = await client.get(url, headers={"User-Agent": "GEO-Bot/1.0"})
        resp.raise_for_status()
        content_type = resp.headers.get("content-type", "").lower()
        body_text = resp.text
    if content_type and "html" not in content_type:
        return body_text
    return _strip_html(body_text)
async def _get_kb(
    db: AsyncSession, kb_id: uuid.UUID, org_id: uuid.UUID
) -> KnowledgeBase:
    """Load a knowledge base, scoped to the given organization.

    Both the id and the owning organization must match. A knowledge base that
    belongs to another organization is therefore reported exactly like a
    missing one.

    Raises:
        HTTPException: 404 when no matching row exists.
    """
    found = (
        await db.execute(
            select(KnowledgeBase)
            .where(KnowledgeBase.id == kb_id)
            .where(KnowledgeBase.organization_id == org_id)
        )
    ).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
# ---------------------------------------------------------------------------
# Knowledge Base CRUD
# ---------------------------------------------------------------------------
@router.post(
    "/bases",
    response_model=KnowledgeBaseResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_knowledge_base(
    body: KnowledgeBaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a knowledge base owned by the caller's organization.

    The base is stored as ``active`` with a ``document_count`` of zero, and
    the caller is recorded as its creator.

    Raises:
        HTTPException: 400 if the caller does not belong to an organization.
    """
    organization_id = current_user.organization_id
    if not organization_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="User has no organization. Create or join an organization first.",
        )

    new_base = KnowledgeBase(
        organization_id=organization_id,
        name=body.name,
        type=body.type,
        description=body.description,
        created_by=current_user.id,
        status="active",
        document_count=0,
    )
    db.add(new_base)
    await db.commit()
    # Reload the persisted row before building the response from it.
    await db.refresh(new_base)

    return KnowledgeBaseResponse(
        id=str(new_base.id),
        name=new_base.name,
        type=new_base.type,
        description=new_base.description,
        document_count=new_base.document_count,
        status=new_base.status,
        created_at=new_base.created_at,
    )
@router.get("/bases", response_model=list[KnowledgeBaseResponse])
async def list_knowledge_bases(
type: Optional[str] = Query(default=None, pattern="^(industry|enterprise)$"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
org_id = current_user.organization_id
if not org_id:
return []
stmt = select(KnowledgeBase).where(KnowledgeBase.organization_id == org_id)
if type:
stmt = stmt.where(KnowledgeBase.type == type)
stmt = stmt.order_by(KnowledgeBase.created_at.desc())
result = await db.execute(stmt)
bases = result.scalars().all()
return [
KnowledgeBaseResponse(
id=str(kb.id),
name=kb.name,
type=kb.type,
description=kb.description,
document_count=kb.document_count,
status=kb.status,
created_at=kb.created_at,
)
for kb in bases
]
@router.get("/bases/{kb_id}", response_model=KnowledgeBaseResponse)
async def get_knowledge_base(
kb_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
org_id = current_user.organization_id
if not org_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
kb = await _get_kb(db, kb_id, org_id)
return KnowledgeBaseResponse(
id=str(kb.id),
name=kb.name,
type=kb.type,
description=kb.description,
document_count=kb.document_count,
status=kb.status,
created_at=kb.created_at,
)
@router.delete("/bases/{kb_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_knowledge_base(
kb_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
org_id = current_user.organization_id
if not org_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
kb = await _get_kb(db, kb_id, org_id)
# Cascade: delete chunks → documents → knowledge base
# 1. Find all document IDs under this KB
doc_stmt = select(KnowledgeDocument.id).where(KnowledgeDocument.knowledge_base_id == kb_id)
doc_result = await db.execute(doc_stmt)
doc_ids = [row[0] for row in doc_result.all()]
if doc_ids:
# 2. Delete all chunks for those documents
chunk_del = delete(KnowledgeChunk).where(KnowledgeChunk.document_id.in_(doc_ids))
await db.execute(chunk_del)
# 3. Delete all documents
doc_del = delete(KnowledgeDocument).where(KnowledgeDocument.knowledge_base_id == kb_id)
await db.execute(doc_del)
# 4. Delete KB
await db.delete(kb)
await db.commit()
# ---------------------------------------------------------------------------
# Document Management
# ---------------------------------------------------------------------------
async def _resolve_upload_content(body: DocumentUpload) -> str:
    """Return the text to index for *body*, fetching it when ``source_type`` is ``"url"``.

    Raises:
        HTTPException: 400 when the field required by ``source_type`` is
            missing or the URL cannot be fetched.
    """
    if body.source_type == "url":
        if not body.source_url:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="source_url is required for url type documents",
            )
        try:
            return await _fetch_url_content(body.source_url)
        except Exception as exc:
            # Any fetch problem (network, HTTP status, bad URL) is reported as
            # a 400 like the other input errors; keep the cause chained.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Failed to fetch URL content: {exc}",
            ) from exc
    if not body.content:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="content is required for text/markdown type documents",
        )
    return body.content


@router.post(
    "/bases/{kb_id}/documents",
    response_model=DocumentResponse,
    status_code=status.HTTP_201_CREATED,
)
async def upload_document(
    kb_id: uuid.UUID,
    body: DocumentUpload,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a document in a knowledge base and index it within the request.

    The document row is inserted with status ``processing``, and the base's
    ``document_count`` is incremented in the same commit. Ingestion then runs
    synchronously. An ingest failure is logged rather than raised, and the
    returned payload reflects the state persisted for the document.

    Raises:
        HTTPException: 400 when the caller has no organization or the content
            cannot be resolved; 404 when the knowledge base is not accessible.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User has no organization")
    await _get_kb(db, kb_id, org_id)  # ownership check

    content = await _resolve_upload_content(body)

    doc = KnowledgeDocument(
        knowledge_base_id=kb_id,
        title=body.title,
        source_type=body.source_type,
        source_url=body.source_url,
        content=content,
        content_hash=RAGService.compute_content_hash(content),
        status="processing",
        chunk_count=0,
    )
    db.add(doc)
    await db.flush()  # assigns doc.id before the counter update / ingest
    await db.execute(
        update(KnowledgeBase)
        .where(KnowledgeBase.id == kb_id)
        .values(document_count=KnowledgeBase.document_count + 1)
    )
    await db.commit()
    await db.refresh(doc)
    # Capture the id now. Later commits or rollbacks may expire ``doc``, and on
    # an AsyncSession touching an expired attribute raises instead of loading.
    doc_id = doc.id

    # Ingestion runs inside the request; moving it to a background task is a
    # possible later optimisation.
    try:
        await _rag_service.ingest_document(db, str(doc_id))
    except Exception:
        logger.exception("Ingest failed for document %s", doc_id)
        # The failure may leave the session mid-transaction; roll back so the
        # refresh below can read the persisted row.
        # NOTE(review): relies on ingest_document committing the 'failed'
        # status itself before raising — confirm; an uncommitted status change
        # would be discarded by this rollback.
        await db.rollback()
    # Reload on both paths so the response shows the persisted status and
    # chunk_count rather than a stale or expired in-memory copy.
    await db.refresh(doc)

    return DocumentResponse(
        id=str(doc.id),
        title=doc.title,
        source_type=doc.source_type,
        source_url=doc.source_url,
        chunk_count=doc.chunk_count,
        status=doc.status,
        error_message=doc.error_message,
        created_at=doc.created_at,
    )
@router.get("/bases/{kb_id}/documents", response_model=list[DocumentResponse])
async def list_documents(
kb_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
org_id = current_user.organization_id
if not org_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
await _get_kb(db, kb_id, org_id)
stmt = (
select(KnowledgeDocument)
.where(KnowledgeDocument.knowledge_base_id == kb_id)
.order_by(KnowledgeDocument.created_at.desc())
)
result = await db.execute(stmt)
docs = result.scalars().all()
return [
DocumentResponse(
id=str(d.id),
title=d.title,
source_type=d.source_type,
source_url=d.source_url,
chunk_count=d.chunk_count,
status=d.status,
error_message=d.error_message,
created_at=d.created_at,
)
for d in docs
]
@router.delete(
    "/bases/{kb_id}/documents/{doc_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_document(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete one document, its chunks, and decrement the base's document count.

    Raises:
        HTTPException: 404 when the caller has no organization, the base is
            not accessible, or the document is not in that base.
    """
    organization_id = current_user.organization_id
    if not organization_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
    await _get_kb(db, kb_id, organization_id)

    lookup = await db.execute(
        select(KnowledgeDocument)
        .where(KnowledgeDocument.id == doc_id)
        .where(KnowledgeDocument.knowledge_base_id == kb_id)
    )
    document = lookup.scalar_one_or_none()
    if document is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    # Remove chunks explicitly before the document row rather than relying on
    # a cascade to do it.
    await _rag_service.delete_document_chunks(db, str(document.id))
    await db.delete(document)

    await db.execute(
        update(KnowledgeBase)
        .where(KnowledgeBase.id == kb_id)
        .values(document_count=KnowledgeBase.document_count - 1)
    )
    await db.commit()
# ---------------------------------------------------------------------------
# Chunk Preview
# ---------------------------------------------------------------------------
@router.get(
    "/bases/{kb_id}/documents/{doc_id}/chunks",
    response_model=list[ChunkPreview],
)
async def list_chunks(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return every chunk of a document in ``chunk_index`` order.

    Raises:
        HTTPException: 404 when the caller has no organization, the base is
            not accessible, or the document is not in that base.
    """
    organization_id = current_user.organization_id
    if not organization_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
    await _get_kb(db, kb_id, organization_id)

    owner_check = await db.execute(
        select(KnowledgeDocument).where(
            KnowledgeDocument.id == doc_id,
            KnowledgeDocument.knowledge_base_id == kb_id,
        )
    )
    if owner_check.scalar_one_or_none() is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")

    ordered_chunks = (
        await db.execute(
            select(KnowledgeChunk)
            .where(KnowledgeChunk.document_id == doc_id)
            .order_by(KnowledgeChunk.chunk_index)
        )
    ).scalars().all()

    previews = []
    for chunk in ordered_chunks:
        previews.append(
            ChunkPreview(
                id=str(chunk.id),
                content=chunk.content,
                chunk_index=chunk.chunk_index,
                token_count=chunk.token_count,
            )
        )
    return previews
# ---------------------------------------------------------------------------
# Knowledge Search
# ---------------------------------------------------------------------------
@router.post("/search", response_model=SearchResponse)
async def knowledge_search(
body: KnowledgeSearchRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
org_id = current_user.organization_id
if not org_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User has no organization",
)
# Verify all requested KB IDs belong to this org
kb_uuids = []
for kb_id_str in body.knowledge_base_ids:
try:
kb_uuids.append(uuid.UUID(kb_id_str))
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid knowledge_base_id: {kb_id_str}",
)
if kb_uuids:
check_stmt = select(KnowledgeBase.id).where(
KnowledgeBase.id.in_(kb_uuids),
KnowledgeBase.organization_id == org_id,
)
check_result = await db.execute(check_stmt)
found_ids = {row[0] for row in check_result.all()}
missing = set(kb_uuids) - found_ids
if missing:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Knowledge bases not found or not accessible: {[str(m) for m in missing]}",
)
t0 = time.monotonic()
raw_results = await _rag_service.search(
db,
query=body.query,
knowledge_base_ids=body.knowledge_base_ids,
top_k=body.top_k,
)
latency_ms = int((time.monotonic() - t0) * 1000)
# Log search
search_log = KnowledgeSearchLog(
organization_id=org_id,
user_id=current_user.id,
query=body.query,
knowledge_base_ids=body.knowledge_base_ids,
results_count=len(raw_results),
latency_ms=latency_ms,
)
db.add(search_log)
await db.commit()
items = [
SearchResultItem(
chunk_id=r["chunk_id"],
content=r["content"],
score=r["score"],
document_id=r["document_id"],
document_title=r["document_title"],
metadata=r.get("metadata", {}),
)
for r in raw_results
]
return SearchResponse(
results=items,
total=len(items),
latency_ms=latency_ms,
)
@router.post("/bases/{kb_id}/chunks/preview")
async def preview_chunks(
kb_id: uuid.UUID,
text: str,
strategy: str = "recursive",
chunk_size: int = 500,
):
"""预览分块效果"""
chunker = ChunkerFactory.create(strategy)
# 临时修改chunk_size
if strategy == "recursive":
chunker.STRATEGY.chunk_size = chunk_size
elif strategy == "semantic":
chunker.STRATEGY.chunk_size = chunk_size * 1.5 # 语义块可以更大
elif strategy == "fixed":
chunker.STRATEGY.chunk_size = chunk_size
preview = chunker.preview(text, max_chunks=10)
return {
"strategy": strategy,
"chunk_count": len(preview),
"preview": preview,
"strategies": [
{
"name": s.name,
"description": s.description,
"recommended_size": s.chunk_size,
}
for s in ChunkerFactory.list_strategies()
],
}
# ---------------------------------------------------------------------------
# Incremental indexing API
# ---------------------------------------------------------------------------
@router.post("/bases/{kb_id}/documents/{doc_id}/reindex")
async def reindex_document(
kb_id: uuid.UUID,
doc_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""重新索引单个文档"""
org_id = current_user.organization_id
if not org_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
await _get_kb(db, kb_id, org_id)
index_service = IncrementalIndexService(_rag_service)
result = await index_service.add_document(
db, str(kb_id), str(doc_id)
)
return result
@router.post("/bases/{kb_id}/documents/{doc_id}/update")
async def update_document_content(
kb_id: uuid.UUID,
doc_id: uuid.UUID,
request: UpdateDocumentRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""更新文档内容(增量)"""
org_id = current_user.organization_id
if not org_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
await _get_kb(db, kb_id, org_id)
index_service = IncrementalIndexService(_rag_service)
result = await index_service.update_document(
db, str(doc_id), request.content
)
return result
@router.delete("/bases/{kb_id}/documents/{doc_id}")
async def delete_document_incremental(
kb_id: uuid.UUID,
doc_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""删除文档"""
org_id = current_user.organization_id
if not org_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
await _get_kb(db, kb_id, org_id)
index_service = IncrementalIndexService(_rag_service)
result = await index_service.delete_document(db, str(doc_id))
return result
@router.post("/bases/{kb_id}/rebuild")
async def rebuild_knowledge_base(
kb_id: uuid.UUID,
force: bool = False,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""重建知识库索引"""
org_id = current_user.organization_id
if not org_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
await _get_kb(db, kb_id, org_id)
index_service = IncrementalIndexService(_rag_service)
result = await index_service.rebuild_knowledge_base(
db, str(kb_id), force
)
return result
@router.post("/bases/{kb_id}/retrieve")
async def enhanced_retrieve(
kb_id: uuid.UUID,
request: RetrieveRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""增强检索(支持重排序和压缩)"""
org_id = current_user.organization_id
if not org_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
await _get_kb(db, kb_id, org_id)
enhanced_rag = EnhancedRAG(_rag_service, _rag_service.embedder)
results = await enhanced_rag.retrieve_with_rerank(
db,
request.query,
[str(kb_id)],
top_k=request.top_k or 5,
use_rerank=request.use_rerank,
use_compression=request.use_compression,
)
return {"results": results, "query": request.query}