feat(rag_platform): U3+U7 — rewrite upload endpoint with sanitization + pipeline

Rewrite upload_document() to use rag_platform sanitize + DocumentProcessor:
- File type whitelist validation (8 allowed types, reject .exe/.sh)
- File size limit (50MB) + zip bomb detection for ZIP-based formats
- DocumentProcessor.parse() (with content sanitization) + segment()
- Return chunks preview, status="segmenting" (pending vectorization)

Add POST /kb-management/documents/preview endpoint:
- Pre-upload preview with adjustable chunk_size/chunk_overlap
- Same security validation as upload, no document record created

Add POST /kb-management/documents/{id}/vectorize placeholder:
- Returns 503 — full async vectorization deferred to U8 (TaskIQ)

Test: update test_upload_document assertion (status "indexed" → "segmenting")
This commit is contained in:
chiguyong 2026-06-25 12:06:16 +08:00
parent b55c896794
commit 3f9588e673
2 changed files with 156 additions and 49 deletions

View File

@ -4,7 +4,10 @@ from __future__ import annotations
import hmac
import logging
import os
import tempfile
import uuid
import zipfile
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
@ -14,6 +17,14 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Security, Upload
from fastapi.security import APIKeyHeader, APIKeyQuery
from pydantic import BaseModel
from agentkit.rag_platform.document_processor import DocumentProcessor
from agentkit.rag_platform.preview import generate_preview
from agentkit.rag_platform.sanitize import (
MAX_FILE_SIZE,
check_zip_bomb,
validate_file_size,
validate_file_type,
)
from agentkit.server.admin.context import DepartmentContext, get_department_context
from agentkit.server.admin.filtering import filter_kb_sources_by_department
from agentkit.server.auth.dependencies import require_permission
@ -26,6 +37,9 @@ router = APIRouter(tags=["kb-management"])
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
# ZIP-based 文档格式(本质是 ZIP需要 zip bomb 检测)
_ZIP_BASED_TYPES = frozenset({"docx", "xlsx", "pptx"})
# ---------------------------------------------------------------------------
# API Key Authentication
@ -430,64 +444,157 @@ async def upload_document(
_auth: None = Depends(_verify_api_key),
_user=Depends(require_permission(Permission.KB_WRITE)),
):
"""Upload a document to the knowledge base."""
"""Upload a document to the knowledge base.
U3+U7: 文件类型白名单 + 大小限制 + ZIP bomb 检测 + 内容净化 + 分段
返回 chunk 预览供前端确认文档状态为 segmenting待向量化
"""
if not file.filename:
raise HTTPException(status_code=422, detail="Filename is required")
# Try to use DocumentLoader if available
chunks = 1
# U7: 文件类型白名单校验
try:
from agentkit.memory.document_loader import DocumentLoader
file_type = validate_file_type(file.filename)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e)) from e
content = await file.read(MAX_UPLOAD_SIZE + 1)
if len(content) > MAX_UPLOAD_SIZE:
raise HTTPException(
status_code=413,
detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB",
# 读取文件内容并校验大小U7
content = await file.read(MAX_FILE_SIZE + 1)
try:
validate_file_size(len(content))
except ValueError as e:
raise HTTPException(status_code=413, detail=str(e)) from e
# 保存到临时文件
tmp_path: str | None = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
tmp.write(content)
tmp_path = tmp.name
# U7: ZIP bomb 检测(.docx/.xlsx/.pptx 本质是 ZIP
if file_type in _ZIP_BASED_TYPES:
try:
check_zip_bomb(tmp_path)
except (ValueError, zipfile.BadZipFile) as e:
raise HTTPException(status_code=422, detail=str(e)) from e
# U3: 解析(含内容净化)+ 分段
processor = DocumentProcessor()
try:
text = processor.parse(tmp_path, file_type)
chunks = processor.segment(text)
except Exception as e:
logger.warning("Document parsing failed: %s", e)
raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e
# 确定source_id — 默认 "local"
effective_source_id = source_id or "local"
if effective_source_id == "local":
local_sources = [s for s in _source_store.list_sources() if s.type == "local"]
if not local_sources:
_source_store.add_source("本地文档", "local", {})
uploaded = UploadedDocument(
document_id=str(uuid.uuid4()),
filename=file.filename,
source_id=effective_source_id,
chunks=len(chunks),
status="segmenting",
)
_source_store.add_document(uploaded)
return {
"document_id": uploaded.document_id,
"filename": uploaded.filename,
"source_id": uploaded.source_id,
"chunks": uploaded.chunks,
"status": uploaded.status,
"created_at": uploaded.created_at,
"total_chunks": len(chunks),
"chunks_preview": [{"index": i, "content": c} for i, c in enumerate(chunks)],
}
finally:
if tmp_path:
try:
os.unlink(tmp_path)
except OSError:
pass
@router.post("/kb-management/documents/preview")
async def preview_document(
file: UploadFile = File(...),
chunk_size: int = 512,
chunk_overlap: int = 50,
_auth: None = Depends(_verify_api_key),
_user=Depends(require_permission(Permission.KB_WRITE)),
):
"""预览文档分段结果(不创建文档记录,不向量化)。
U3: 上传前预览分段效果可调整 chunk_size/chunk_overlap 参数
"""
if not file.filename:
raise HTTPException(status_code=422, detail="Filename is required")
try:
file_type = validate_file_type(file.filename)
except ValueError as e:
raise HTTPException(status_code=422, detail=str(e)) from e
content = await file.read(MAX_FILE_SIZE + 1)
try:
validate_file_size(len(content))
except ValueError as e:
raise HTTPException(status_code=413, detail=str(e)) from e
tmp_path: str | None = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
tmp.write(content)
tmp_path = tmp.name
if file_type in _ZIP_BASED_TYPES:
try:
check_zip_bomb(tmp_path)
except (ValueError, zipfile.BadZipFile) as e:
raise HTTPException(status_code=422, detail=str(e)) from e
try:
result = generate_preview(
tmp_path,
file_type,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
loader = DocumentLoader()
doc = loader.load_bytes(content, file.filename)
# Estimate chunks based on content length (rough approximation)
chunks = max(1, len(doc.content) // 500)
except ImportError:
# DocumentLoader not available, use basic estimation
content = await file.read(MAX_UPLOAD_SIZE + 1)
if len(content) > MAX_UPLOAD_SIZE:
raise HTTPException(
status_code=413,
detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB",
)
chunks = max(1, len(content) // 500)
except Exception as e:
logger.warning(f"Document parsing failed: {e}")
chunks = 1
except Exception as e:
logger.warning("Document preview failed: %s", e)
raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e
# Determine source_id - use "local" default if not provided
effective_source_id = source_id or "local"
return result.model_dump()
finally:
if tmp_path:
try:
os.unlink(tmp_path)
except OSError:
pass
# Ensure a local source exists
if effective_source_id == "local":
local_sources = [s for s in _source_store.list_sources() if s.type == "local"]
if not local_sources:
_source_store.add_source("本地文档", "local", {})
uploaded = UploadedDocument(
document_id=str(uuid.uuid4()),
filename=file.filename,
source_id=effective_source_id,
chunks=chunks,
status="indexed",
@router.post("/kb-management/documents/{document_id}/vectorize")
async def vectorize_document(
document_id: str,
_auth: None = Depends(_verify_api_key),
_user=Depends(require_permission(Permission.KB_WRITE)),
):
"""触发文档向量化。
U3: 当前为占位实现 完整异步向量化需要 PG-backed KBStore (U8)
KnowledgeSourceStore 不跟踪文件路径无法重新解析已上传文档
"""
raise HTTPException(
status_code=503,
detail="Vectorization requires PG-backed KBStore (U8)",
)
_source_store.add_document(uploaded)
return {
"document_id": uploaded.document_id,
"filename": uploaded.filename,
"source_id": uploaded.source_id,
"chunks": uploaded.chunks,
"status": uploaded.status,
"created_at": uploaded.created_at if hasattr(uploaded, "created_at") else "",
}
@router.post("/kb-management/search")

View File

@ -251,7 +251,7 @@ class TestUploadDocument:
data = response.json()
assert data["filename"] == "test.txt"
assert data["document_id"] is not None
assert data["status"] == "indexed"
assert data["status"] == "segmenting"
assert data["chunks"] >= 1
def test_upload_document_with_source_id(self, client):