diff --git a/src/agentkit/server/routes/kb_management.py b/src/agentkit/server/routes/kb_management.py index 680a601..0a66be2 100644 --- a/src/agentkit/server/routes/kb_management.py +++ b/src/agentkit/server/routes/kb_management.py @@ -4,7 +4,10 @@ from __future__ import annotations import hmac import logging +import os +import tempfile import uuid +import zipfile from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path @@ -14,6 +17,14 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Security, Upload from fastapi.security import APIKeyHeader, APIKeyQuery from pydantic import BaseModel +from agentkit.rag_platform.document_processor import DocumentProcessor +from agentkit.rag_platform.preview import generate_preview +from agentkit.rag_platform.sanitize import ( + MAX_FILE_SIZE, + check_zip_bomb, + validate_file_size, + validate_file_type, +) from agentkit.server.admin.context import DepartmentContext, get_department_context from agentkit.server.admin.filtering import filter_kb_sources_by_department from agentkit.server.auth.dependencies import require_permission @@ -26,6 +37,9 @@ router = APIRouter(tags=["kb-management"]) MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB +# ZIP-based 文档格式(本质是 ZIP,需要 zip bomb 检测) +_ZIP_BASED_TYPES = frozenset({"docx", "xlsx", "pptx"}) + # --------------------------------------------------------------------------- # API Key Authentication @@ -430,64 +444,157 @@ async def upload_document( _auth: None = Depends(_verify_api_key), _user=Depends(require_permission(Permission.KB_WRITE)), ): - """Upload a document to the knowledge base.""" + """Upload a document to the knowledge base. + + U3+U7: 文件类型白名单 + 大小限制 + ZIP bomb 检测 + 内容净化 + 分段。 + 返回 chunk 预览供前端确认,文档状态为 segmenting(待向量化)。 + """ if not file.filename: raise HTTPException(status_code=422, detail="Filename is required") - # Try to use DocumentLoader if available - chunks = 1 + # U7: 文件类型白名单校验 try: - from agentkit.memory.document_loader import DocumentLoader + file_type = validate_file_type(file.filename) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) from e - content = await file.read(MAX_UPLOAD_SIZE + 1) - if len(content) > MAX_UPLOAD_SIZE: - raise HTTPException( - status_code=413, - detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB", + # 读取文件内容并校验大小(U7) + content = await file.read(MAX_FILE_SIZE + 1) + try: + validate_file_size(len(content)) + except ValueError as e: + raise HTTPException(status_code=413, detail=str(e)) from e + + # 保存到临时文件 + tmp_path: str | None = None + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp: + tmp.write(content) + tmp_path = tmp.name + + # U7: ZIP bomb 检测(.docx/.xlsx/.pptx 本质是 ZIP) + if file_type in _ZIP_BASED_TYPES: + try: + check_zip_bomb(tmp_path) + except (ValueError, zipfile.BadZipFile) as e: + raise HTTPException(status_code=422, detail=str(e)) from e + + # U3: 解析(含内容净化)+ 分段 + processor = DocumentProcessor() + try: + text = processor.parse(tmp_path, file_type) + chunks = processor.segment(text) + except Exception as e: + logger.warning("Document parsing failed: %s", e) + raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e + + # 确定source_id — 默认 "local" + effective_source_id = source_id or "local" + if effective_source_id == "local": + local_sources = [s for s in _source_store.list_sources() if s.type == "local"] + if not local_sources: + _source_store.add_source("本地文档", "local", {}) + + uploaded = UploadedDocument( + document_id=str(uuid.uuid4()), + filename=file.filename, + source_id=effective_source_id, + chunks=len(chunks), + status="segmenting", + ) + _source_store.add_document(uploaded) + + return { + "document_id": uploaded.document_id, + "filename": uploaded.filename, + "source_id": uploaded.source_id, + "chunks": uploaded.chunks, + "status": uploaded.status, + "created_at": uploaded.created_at, + "total_chunks": len(chunks), + "chunks_preview": [{"index": i, "content": c} for i, c in enumerate(chunks)], + } + finally: + if tmp_path: + try: + os.unlink(tmp_path) + except OSError: + pass + + +@router.post("/kb-management/documents/preview") +async def preview_document( + file: UploadFile = File(...), + chunk_size: int = 512, + chunk_overlap: int = 50, + _auth: None = Depends(_verify_api_key), + _user=Depends(require_permission(Permission.KB_WRITE)), +): + """预览文档分段结果(不创建文档记录,不向量化)。 + + U3: 上传前预览分段效果,可调整 chunk_size/chunk_overlap 参数。 + """ + if not file.filename: + raise HTTPException(status_code=422, detail="Filename is required") + + try: + file_type = validate_file_type(file.filename) + except ValueError as e: + raise HTTPException(status_code=422, detail=str(e)) from e + + content = await file.read(MAX_FILE_SIZE + 1) + try: + validate_file_size(len(content)) + except ValueError as e: + raise HTTPException(status_code=413, detail=str(e)) from e + + tmp_path: str | None = None + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp: + tmp.write(content) + tmp_path = tmp.name + + if file_type in _ZIP_BASED_TYPES: + try: + check_zip_bomb(tmp_path) + except (ValueError, zipfile.BadZipFile) as e: + raise HTTPException(status_code=422, detail=str(e)) from e + + try: + result = generate_preview( + tmp_path, + file_type, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, ) - loader = DocumentLoader() - doc = loader.load_bytes(content, file.filename) - # Estimate chunks based on content length (rough approximation) - chunks = max(1, len(doc.content) // 500) - except ImportError: - # DocumentLoader not available, use basic estimation - content = await file.read(MAX_UPLOAD_SIZE + 1) - if len(content) > MAX_UPLOAD_SIZE: - raise HTTPException( - status_code=413, - detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB", - ) - chunks = max(1, len(content) // 500) - except Exception as e: - logger.warning(f"Document parsing failed: {e}") - chunks = 1 + except Exception as e: + logger.warning("Document preview failed: %s", e) + raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e - # Determine source_id - use "local" default if not provided - effective_source_id = source_id or "local" + return result.model_dump() + finally: + if tmp_path: + try: + os.unlink(tmp_path) + except OSError: + pass - # Ensure a local source exists - if effective_source_id == "local": - local_sources = [s for s in _source_store.list_sources() if s.type == "local"] - if not local_sources: - _source_store.add_source("本地文档", "local", {}) - uploaded = UploadedDocument( - document_id=str(uuid.uuid4()), - filename=file.filename, - source_id=effective_source_id, - chunks=chunks, - status="indexed", +@router.post("/kb-management/documents/{document_id}/vectorize") +async def vectorize_document( + document_id: str, + _auth: None = Depends(_verify_api_key), + _user=Depends(require_permission(Permission.KB_WRITE)), +): + """触发文档向量化。 + + U3: 当前为占位实现 — 完整异步向量化需要 PG-backed KBStore (U8)。 + 旧 KnowledgeSourceStore 不跟踪文件路径,无法重新解析已上传文档。 + """ + raise HTTPException( + status_code=503, + detail="Vectorization requires PG-backed KBStore (U8)", ) - _source_store.add_document(uploaded) - - return { - "document_id": uploaded.document_id, - "filename": uploaded.filename, - "source_id": uploaded.source_id, - "chunks": uploaded.chunks, - "status": uploaded.status, - "created_at": uploaded.created_at if hasattr(uploaded, "created_at") else "", - } @router.post("/kb-management/search") diff --git a/tests/unit/server/test_kb_management.py b/tests/unit/server/test_kb_management.py index 63e9ac1..93187a8 100644 --- a/tests/unit/server/test_kb_management.py +++ b/tests/unit/server/test_kb_management.py @@ -251,7 +251,7 @@ class TestUploadDocument: data = response.json() assert data["filename"] == "test.txt" assert data["document_id"] is not None - assert data["status"] == "indexed" + assert data["status"] == "segmenting" assert data["chunks"] >= 1 def test_upload_document_with_source_id(self, client):