feat(rag_platform): U3+U7 — rewrite upload endpoint with sanitization + pipeline

Rewrite upload_document() to use rag_platform sanitize + DocumentProcessor:
- File type whitelist validation (8 allowed types, reject .exe/.sh)
- File size limit (50MB) + zip bomb detection for ZIP-based formats
- DocumentProcessor.parse() (with content sanitization) + segment()
- Return chunks preview, status="segmenting" (pending vectorization)

Add POST /kb-management/documents/preview endpoint:
- Pre-upload preview with adjustable chunk_size/chunk_overlap
- Same security validation as upload, no document record created

Add POST /kb-management/documents/{id}/vectorize placeholder:
- Returns 503 — full async vectorization deferred to U8 (TaskIQ)

Test: update test_upload_document assertion (status "indexed" → "segmenting")
This commit is contained in:
chiguyong 2026-06-25 12:06:16 +08:00
parent b55c896794
commit 3f9588e673
2 changed files with 156 additions and 49 deletions

View File

@ -4,7 +4,10 @@ from __future__ import annotations
import hmac import hmac
import logging import logging
import os
import tempfile
import uuid import uuid
import zipfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
@ -14,6 +17,14 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Security, Upload
from fastapi.security import APIKeyHeader, APIKeyQuery from fastapi.security import APIKeyHeader, APIKeyQuery
from pydantic import BaseModel from pydantic import BaseModel
from agentkit.rag_platform.document_processor import DocumentProcessor
from agentkit.rag_platform.preview import generate_preview
from agentkit.rag_platform.sanitize import (
MAX_FILE_SIZE,
check_zip_bomb,
validate_file_size,
validate_file_type,
)
from agentkit.server.admin.context import DepartmentContext, get_department_context from agentkit.server.admin.context import DepartmentContext, get_department_context
from agentkit.server.admin.filtering import filter_kb_sources_by_department from agentkit.server.admin.filtering import filter_kb_sources_by_department
from agentkit.server.auth.dependencies import require_permission from agentkit.server.auth.dependencies import require_permission
@ -26,6 +37,9 @@ router = APIRouter(tags=["kb-management"])
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
# ZIP-based document formats (ZIP archives under the hood, so they need zip bomb detection)
_ZIP_BASED_TYPES = frozenset({"docx", "xlsx", "pptx"})
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# API Key Authentication # API Key Authentication
@ -430,42 +444,52 @@ async def upload_document(
_auth: None = Depends(_verify_api_key), _auth: None = Depends(_verify_api_key),
_user=Depends(require_permission(Permission.KB_WRITE)), _user=Depends(require_permission(Permission.KB_WRITE)),
): ):
"""Upload a document to the knowledge base.""" """Upload a document to the knowledge base.
U3+U7: 文件类型白名单 + 大小限制 + ZIP bomb 检测 + 内容净化 + 分段
返回 chunk 预览供前端确认文档状态为 segmenting待向量化
"""
if not file.filename: if not file.filename:
raise HTTPException(status_code=422, detail="Filename is required") raise HTTPException(status_code=422, detail="Filename is required")
# Try to use DocumentLoader if available # U7: 文件类型白名单校验
chunks = 1
try: try:
from agentkit.memory.document_loader import DocumentLoader file_type = validate_file_type(file.filename)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e)) from e
content = await file.read(MAX_UPLOAD_SIZE + 1) # 读取文件内容并校验大小U7
if len(content) > MAX_UPLOAD_SIZE: content = await file.read(MAX_FILE_SIZE + 1)
raise HTTPException( try:
status_code=413, validate_file_size(len(content))
detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB", except ValueError as e:
) raise HTTPException(status_code=413, detail=str(e)) from e
loader = DocumentLoader()
doc = loader.load_bytes(content, file.filename) # 保存到临时文件
# Estimate chunks based on content length (rough approximation) tmp_path: str | None = None
chunks = max(1, len(doc.content) // 500) try:
except ImportError: with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
# DocumentLoader not available, use basic estimation tmp.write(content)
content = await file.read(MAX_UPLOAD_SIZE + 1) tmp_path = tmp.name
if len(content) > MAX_UPLOAD_SIZE:
raise HTTPException( # U7: ZIP bomb 检测(.docx/.xlsx/.pptx 本质是 ZIP
status_code=413, if file_type in _ZIP_BASED_TYPES:
detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB", try:
) check_zip_bomb(tmp_path)
chunks = max(1, len(content) // 500) except (ValueError, zipfile.BadZipFile) as e:
raise HTTPException(status_code=422, detail=str(e)) from e
# U3: 解析(含内容净化)+ 分段
processor = DocumentProcessor()
try:
text = processor.parse(tmp_path, file_type)
chunks = processor.segment(text)
except Exception as e: except Exception as e:
logger.warning(f"Document parsing failed: {e}") logger.warning("Document parsing failed: %s", e)
chunks = 1 raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e
# Determine source_id - use "local" default if not provided # 确定source_id — 默认 "local"
effective_source_id = source_id or "local" effective_source_id = source_id or "local"
# Ensure a local source exists
if effective_source_id == "local": if effective_source_id == "local":
local_sources = [s for s in _source_store.list_sources() if s.type == "local"] local_sources = [s for s in _source_store.list_sources() if s.type == "local"]
if not local_sources: if not local_sources:
@ -475,8 +499,8 @@ async def upload_document(
document_id=str(uuid.uuid4()), document_id=str(uuid.uuid4()),
filename=file.filename, filename=file.filename,
source_id=effective_source_id, source_id=effective_source_id,
chunks=chunks, chunks=len(chunks),
status="indexed", status="segmenting",
) )
_source_store.add_document(uploaded) _source_store.add_document(uploaded)
@ -486,8 +510,91 @@ async def upload_document(
"source_id": uploaded.source_id, "source_id": uploaded.source_id,
"chunks": uploaded.chunks, "chunks": uploaded.chunks,
"status": uploaded.status, "status": uploaded.status,
"created_at": uploaded.created_at if hasattr(uploaded, "created_at") else "", "created_at": uploaded.created_at,
"total_chunks": len(chunks),
"chunks_preview": [{"index": i, "content": c} for i, c in enumerate(chunks)],
} }
finally:
if tmp_path:
try:
os.unlink(tmp_path)
except OSError:
pass
@router.post("/kb-management/documents/preview")
async def preview_document(
    file: UploadFile = File(...),
    chunk_size: int = 512,
    chunk_overlap: int = 50,
    _auth: None = Depends(_verify_api_key),
    _user=Depends(require_permission(Permission.KB_WRITE)),
):
    """Preview how a document would be segmented.

    No document record is created and nothing is vectorized.

    U3: lets the caller inspect the chunking result before uploading, with
    adjustable ``chunk_size`` / ``chunk_overlap``. The file goes through the
    same checks as the upload endpoint: file type validation, size limit and,
    for ZIP-based formats, zip bomb detection.

    Args:
        file: The uploaded file; its filename determines the file type.
        chunk_size: Requested chunk size, forwarded to ``generate_preview``.
            Must be positive.
        chunk_overlap: Requested overlap between chunks, forwarded to
            ``generate_preview``. Must be >= 0 and smaller than ``chunk_size``.

    Returns:
        The preview result serialized with ``model_dump()``.

    Raises:
        HTTPException: 422 for a missing filename, a rejected file type,
            invalid chunk parameters, a failed zip bomb check or a failed
            preview; 413 when the file size validation fails.
    """
    if not file.filename:
        raise HTTPException(status_code=422, detail="Filename is required")

    # Reject nonsensical chunking parameters up front, before the upload is
    # buffered in memory and written to disk.
    if chunk_size <= 0:
        raise HTTPException(
            status_code=422,
            detail="chunk_size must be a positive integer",
        )
    if chunk_overlap < 0 or chunk_overlap >= chunk_size:
        raise HTTPException(
            status_code=422,
            detail="chunk_overlap must be >= 0 and smaller than chunk_size",
        )

    try:
        file_type = validate_file_type(file.filename)
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e

    # Read one byte past the limit so an oversized upload is detectable
    # without buffering the whole thing.
    content = await file.read(MAX_FILE_SIZE + 1)
    try:
        validate_file_size(len(content))
    except ValueError as e:
        raise HTTPException(status_code=413, detail=str(e)) from e

    tmp_path: str | None = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
            # Record the path BEFORE writing: the file is created with
            # delete=False, so if the write fails the finally block below
            # must still be able to remove it.
            tmp_path = tmp.name
            tmp.write(content)

        # .docx/.xlsx/.pptx are ZIP containers and need zip bomb detection.
        if file_type in _ZIP_BASED_TYPES:
            try:
                check_zip_bomb(tmp_path)
            except (ValueError, zipfile.BadZipFile) as e:
                raise HTTPException(status_code=422, detail=str(e)) from e

        try:
            result = generate_preview(
                tmp_path,
                file_type,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
        except Exception as e:
            # Boundary handler: any preview failure is reported as a 422.
            logger.warning("Document preview failed: %s", e)
            raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e

        return result.model_dump()
    finally:
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup: a failure to remove the temp file must
                # not mask the response or the original error.
                pass
@router.post("/kb-management/documents/{document_id}/vectorize")
async def vectorize_document(
    document_id: str,
    _auth: None = Depends(_verify_api_key),
    _user=Depends(require_permission(Permission.KB_WRITE)),
):
    """Trigger vectorization of a document (placeholder).

    U3: this is a placeholder implementation. Full asynchronous vectorization
    requires the PG-backed KBStore (U8); KnowledgeSourceStore does not track
    file paths, so an already-uploaded document cannot be re-parsed.

    Raises:
        HTTPException: always, with status 503.
    """
    not_available = HTTPException(
        status_code=503,
        detail="Vectorization requires PG-backed KBStore (U8)",
    )
    raise not_available
@router.post("/kb-management/search") @router.post("/kb-management/search")

View File

@ -251,7 +251,7 @@ class TestUploadDocument:
data = response.json() data = response.json()
assert data["filename"] == "test.txt" assert data["filename"] == "test.txt"
assert data["document_id"] is not None assert data["document_id"] is not None
assert data["status"] == "indexed" assert data["status"] == "segmenting"
assert data["chunks"] >= 1 assert data["chunks"] >= 1
def test_upload_document_with_source_id(self, client): def test_upload_document_with_source_id(self, client):