""" Knowledge Base API CRUD for KnowledgeBase / KnowledgeDocument + RAG search endpoint. """ import re import time import uuid import logging from typing import Optional import httpx from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.database import get_db from app.models.knowledge import ( KnowledgeBase, KnowledgeChunk, KnowledgeDocument, KnowledgeSearchLog, ) from app.models.user import User from app.schemas.knowledge import ( ChunkPreview, DocumentResponse, DocumentUpload, KnowledgeBaseCreate, KnowledgeBaseResponse, KnowledgeSearchRequest, RetrieveRequest, SearchResponse, SearchResultItem, UpdateDocumentRequest, ) from app.services.knowledge import MockEmbedder, RAGService from app.services.knowledge.enhanced_rag import EnhancedRAG from app.services.knowledge.incremental_index import IncrementalIndexService from app.services.knowledge.chunker import ChunkerFactory logger = logging.getLogger(__name__) router = APIRouter() # Shared RAG service instance (MockEmbedder by default; swap in OpenAIEmbedder via DI later) _rag_service = RAGService(embedder=MockEmbedder()) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _strip_html(html: str) -> str: """Very simple HTML tag stripper using regex.""" text = re.sub(r"<[^>]+>", " ", html) text = re.sub(r"\s+", " ", text) return text.strip() async def _fetch_url_content(url: str) -> str: """Fetch a URL and return plain text (strips HTML tags).""" async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client: resp = await client.get(url, headers={"User-Agent": "GEO-Bot/1.0"}) resp.raise_for_status() return _strip_html(resp.text) async def _get_kb( db: AsyncSession, kb_id: uuid.UUID, org_id: uuid.UUID ) -> KnowledgeBase: stmt = 
select(KnowledgeBase).where( KnowledgeBase.id == kb_id, KnowledgeBase.organization_id == org_id, ) result = await db.execute(stmt) kb = result.scalar_one_or_none() if kb is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") return kb # --------------------------------------------------------------------------- # Knowledge Base CRUD # --------------------------------------------------------------------------- @router.post( "/bases", response_model=KnowledgeBaseResponse, status_code=status.HTTP_201_CREATED, ) async def create_knowledge_base( body: KnowledgeBaseCreate, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): org_id = current_user.organization_id if not org_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="User has no organization. Create or join an organization first.", ) kb = KnowledgeBase( organization_id=org_id, name=body.name, type=body.type, description=body.description, created_by=current_user.id, status="active", document_count=0, ) db.add(kb) await db.commit() await db.refresh(kb) return KnowledgeBaseResponse( id=str(kb.id), name=kb.name, type=kb.type, description=kb.description, document_count=kb.document_count, status=kb.status, created_at=kb.created_at, ) @router.get("/bases", response_model=list[KnowledgeBaseResponse]) async def list_knowledge_bases( type: Optional[str] = Query(default=None, pattern="^(industry|enterprise)$"), db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): org_id = current_user.organization_id if not org_id: return [] stmt = select(KnowledgeBase).where(KnowledgeBase.organization_id == org_id) if type: stmt = stmt.where(KnowledgeBase.type == type) stmt = stmt.order_by(KnowledgeBase.created_at.desc()) result = await db.execute(stmt) bases = result.scalars().all() return [ KnowledgeBaseResponse( id=str(kb.id), name=kb.name, type=kb.type, description=kb.description, 
document_count=kb.document_count, status=kb.status, created_at=kb.created_at, ) for kb in bases ] @router.get("/bases/{kb_id}", response_model=KnowledgeBaseResponse) async def get_knowledge_base( kb_id: uuid.UUID, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): org_id = current_user.organization_id if not org_id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") kb = await _get_kb(db, kb_id, org_id) return KnowledgeBaseResponse( id=str(kb.id), name=kb.name, type=kb.type, description=kb.description, document_count=kb.document_count, status=kb.status, created_at=kb.created_at, ) @router.delete("/bases/{kb_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_knowledge_base( kb_id: uuid.UUID, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): org_id = current_user.organization_id if not org_id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") kb = await _get_kb(db, kb_id, org_id) # Cascade: delete chunks → documents → knowledge base # 1. Find all document IDs under this KB doc_stmt = select(KnowledgeDocument.id).where(KnowledgeDocument.knowledge_base_id == kb_id) doc_result = await db.execute(doc_stmt) doc_ids = [row[0] for row in doc_result.all()] if doc_ids: # 2. Delete all chunks for those documents chunk_del = delete(KnowledgeChunk).where(KnowledgeChunk.document_id.in_(doc_ids)) await db.execute(chunk_del) # 3. Delete all documents doc_del = delete(KnowledgeDocument).where(KnowledgeDocument.knowledge_base_id == kb_id) await db.execute(doc_del) # 4. 
Delete KB await db.delete(kb) await db.commit() # --------------------------------------------------------------------------- # Document Management # --------------------------------------------------------------------------- @router.post( "/bases/{kb_id}/documents", response_model=DocumentResponse, status_code=status.HTTP_201_CREATED, ) async def upload_document( kb_id: uuid.UUID, body: DocumentUpload, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): org_id = current_user.organization_id if not org_id: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="User has no organization") await _get_kb(db, kb_id, org_id) # ownership check # Resolve content if body.source_type == "url": if not body.source_url: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="source_url is required for url type documents", ) try: content = await _fetch_url_content(body.source_url) except Exception as exc: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Failed to fetch URL content: {exc}", ) else: if not body.content: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="content is required for text/markdown type documents", ) content = body.content content_hash = RAGService.compute_content_hash(content) doc = KnowledgeDocument( knowledge_base_id=kb_id, title=body.title, source_type=body.source_type, source_url=body.source_url, content=content, content_hash=content_hash, status="processing", chunk_count=0, ) db.add(doc) await db.flush() # get doc.id before ingest # Update document_count on KB await db.execute( update(KnowledgeBase) .where(KnowledgeBase.id == kb_id) .values(document_count=KnowledgeBase.document_count + 1) ) await db.commit() await db.refresh(doc) # Asynchronously ingest (same request; background task optimization later) try: await _rag_service.ingest_document(db, str(doc.id)) await db.refresh(doc) except Exception as exc: logger.error(f"Ingest failed for document 
{doc.id}: {exc}") # Status already set to 'failed' by ingest_document on exception return DocumentResponse( id=str(doc.id), title=doc.title, source_type=doc.source_type, source_url=doc.source_url, chunk_count=doc.chunk_count, status=doc.status, error_message=doc.error_message, created_at=doc.created_at, ) @router.get("/bases/{kb_id}/documents", response_model=list[DocumentResponse]) async def list_documents( kb_id: uuid.UUID, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): org_id = current_user.organization_id if not org_id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") await _get_kb(db, kb_id, org_id) stmt = ( select(KnowledgeDocument) .where(KnowledgeDocument.knowledge_base_id == kb_id) .order_by(KnowledgeDocument.created_at.desc()) ) result = await db.execute(stmt) docs = result.scalars().all() return [ DocumentResponse( id=str(d.id), title=d.title, source_type=d.source_type, source_url=d.source_url, chunk_count=d.chunk_count, status=d.status, error_message=d.error_message, created_at=d.created_at, ) for d in docs ] @router.delete( "/bases/{kb_id}/documents/{doc_id}", status_code=status.HTTP_204_NO_CONTENT, ) async def delete_document( kb_id: uuid.UUID, doc_id: uuid.UUID, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): org_id = current_user.organization_id if not org_id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") await _get_kb(db, kb_id, org_id) stmt = select(KnowledgeDocument).where( KnowledgeDocument.id == doc_id, KnowledgeDocument.knowledge_base_id == kb_id, ) result = await db.execute(stmt) doc = result.scalar_one_or_none() if doc is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found") # Delete chunks first (cascade also handles this, but explicit for clarity) await _rag_service.delete_document_chunks(db, str(doc.id)) await db.delete(doc) # 
Decrement document_count await db.execute( update(KnowledgeBase) .where(KnowledgeBase.id == kb_id) .values(document_count=KnowledgeBase.document_count - 1) ) await db.commit() # --------------------------------------------------------------------------- # Chunk Preview # --------------------------------------------------------------------------- @router.get( "/bases/{kb_id}/documents/{doc_id}/chunks", response_model=list[ChunkPreview], ) async def list_chunks( kb_id: uuid.UUID, doc_id: uuid.UUID, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): org_id = current_user.organization_id if not org_id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") await _get_kb(db, kb_id, org_id) # Verify document belongs to KB doc_stmt = select(KnowledgeDocument).where( KnowledgeDocument.id == doc_id, KnowledgeDocument.knowledge_base_id == kb_id, ) doc_result = await db.execute(doc_stmt) doc = doc_result.scalar_one_or_none() if doc is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found") chunk_stmt = ( select(KnowledgeChunk) .where(KnowledgeChunk.document_id == doc_id) .order_by(KnowledgeChunk.chunk_index) ) chunk_result = await db.execute(chunk_stmt) chunks = chunk_result.scalars().all() return [ ChunkPreview( id=str(c.id), content=c.content, chunk_index=c.chunk_index, token_count=c.token_count, ) for c in chunks ] # --------------------------------------------------------------------------- # Knowledge Search # --------------------------------------------------------------------------- @router.post("/search", response_model=SearchResponse) async def knowledge_search( body: KnowledgeSearchRequest, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): org_id = current_user.organization_id if not org_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="User has no organization", ) # Verify all requested KB IDs 
belong to this org kb_uuids = [] for kb_id_str in body.knowledge_base_ids: try: kb_uuids.append(uuid.UUID(kb_id_str)) except ValueError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid knowledge_base_id: {kb_id_str}", ) if kb_uuids: check_stmt = select(KnowledgeBase.id).where( KnowledgeBase.id.in_(kb_uuids), KnowledgeBase.organization_id == org_id, ) check_result = await db.execute(check_stmt) found_ids = {row[0] for row in check_result.all()} missing = set(kb_uuids) - found_ids if missing: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Knowledge bases not found or not accessible: {[str(m) for m in missing]}", ) t0 = time.monotonic() raw_results = await _rag_service.search( db, query=body.query, knowledge_base_ids=body.knowledge_base_ids, top_k=body.top_k, ) latency_ms = int((time.monotonic() - t0) * 1000) # Log search search_log = KnowledgeSearchLog( organization_id=org_id, user_id=current_user.id, query=body.query, knowledge_base_ids=body.knowledge_base_ids, results_count=len(raw_results), latency_ms=latency_ms, ) db.add(search_log) await db.commit() items = [ SearchResultItem( chunk_id=r["chunk_id"], content=r["content"], score=r["score"], document_id=r["document_id"], document_title=r["document_title"], metadata=r.get("metadata", {}), ) for r in raw_results ] return SearchResponse( results=items, total=len(items), latency_ms=latency_ms, ) @router.post("/bases/{kb_id}/chunks/preview") async def preview_chunks( kb_id: uuid.UUID, text: str, strategy: str = "recursive", chunk_size: int = 500, ): """预览分块效果""" chunker = ChunkerFactory.create(strategy) # 临时修改chunk_size if strategy == "recursive": chunker.STRATEGY.chunk_size = chunk_size elif strategy == "semantic": chunker.STRATEGY.chunk_size = chunk_size * 1.5 # 语义块可以更大 elif strategy == "fixed": chunker.STRATEGY.chunk_size = chunk_size preview = chunker.preview(text, max_chunks=10) return { "strategy": strategy, "chunk_count": len(preview), "preview": preview, 
"strategies": [ { "name": s.name, "description": s.description, "recommended_size": s.chunk_size, } for s in ChunkerFactory.list_strategies() ], } # --------------------------------------------------------------------------- # 增量索引 API # --------------------------------------------------------------------------- @router.post("/bases/{kb_id}/documents/{doc_id}/reindex") async def reindex_document( kb_id: uuid.UUID, doc_id: uuid.UUID, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): """重新索引单个文档""" org_id = current_user.organization_id if not org_id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") await _get_kb(db, kb_id, org_id) index_service = IncrementalIndexService(_rag_service) result = await index_service.add_document( db, str(kb_id), str(doc_id) ) return result @router.post("/bases/{kb_id}/documents/{doc_id}/update") async def update_document_content( kb_id: uuid.UUID, doc_id: uuid.UUID, request: UpdateDocumentRequest, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): """更新文档内容(增量)""" org_id = current_user.organization_id if not org_id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") await _get_kb(db, kb_id, org_id) index_service = IncrementalIndexService(_rag_service) result = await index_service.update_document( db, str(doc_id), request.content ) return result @router.delete("/bases/{kb_id}/documents/{doc_id}") async def delete_document_incremental( kb_id: uuid.UUID, doc_id: uuid.UUID, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): """删除文档""" org_id = current_user.organization_id if not org_id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") await _get_kb(db, kb_id, org_id) index_service = IncrementalIndexService(_rag_service) result = await index_service.delete_document(db, str(doc_id)) return result 
@router.post("/bases/{kb_id}/rebuild")
async def rebuild_knowledge_base(
    kb_id: uuid.UUID,
    force: bool = False,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Rebuild the index of a knowledge base.

    Responds 404 when the caller has no organization or the knowledge base
    is not owned by it; otherwise returns whatever the incremental index
    service reports for the rebuild.
    """
    owner_org = current_user.organization_id
    if not owner_org:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found",
        )

    # Ownership check; the returned row itself is not needed.
    await _get_kb(db, kb_id, owner_org)

    indexer = IncrementalIndexService(_rag_service)
    return await indexer.rebuild_knowledge_base(db, str(kb_id), force)


@router.post("/bases/{kb_id}/retrieve")
async def enhanced_retrieve(
    kb_id: uuid.UUID,
    request: RetrieveRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Enhanced retrieval (supports re-ranking and compression).

    Responds 404 when the caller has no organization or the knowledge base
    is not owned by it. ``top_k`` falls back to 5 when the request leaves it
    unset (or falsy).
    """
    owner_org = current_user.organization_id
    if not owner_org:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Knowledge base not found",
        )

    # Ownership check; the returned row itself is not needed.
    await _get_kb(db, kb_id, owner_org)

    retriever = EnhancedRAG(_rag_service, _rag_service.embedder)
    limit = request.top_k or 5
    hits = await retriever.retrieve_with_rerank(
        db,
        request.query,
        [str(kb_id)],
        top_k=limit,
        use_rerank=request.use_rerank,
        use_compression=request.use_compression,
    )
    return {"results": hits, "query": request.query}