feat(rag_platform): U3+U7 — rewrite upload endpoint with sanitization + pipeline
Rewrite upload_document() to use rag_platform sanitize + DocumentProcessor:
- File type whitelist validation (8 allowed types, reject .exe/.sh)
- File size limit (50MB) + zip bomb detection for ZIP-based formats
- DocumentProcessor.parse() (with content sanitization) + segment()
- Return chunks preview, status="segmenting" (pending vectorization)
Add POST /kb-management/documents/preview endpoint:
- Pre-upload preview with adjustable chunk_size/chunk_overlap
- Same security validation as upload, no document record created
Add POST /kb-management/documents/{id}/vectorize placeholder:
- Returns 503 — full async vectorization deferred to U8 (TaskIQ)
Test: update test_upload_document assertion (status "indexed" → "segmenting")
This commit is contained in:
parent
b55c896794
commit
3f9588e673
|
|
@ -4,7 +4,10 @@ from __future__ import annotations
|
|||
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
|
@ -14,6 +17,14 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Security, Upload
|
|||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agentkit.rag_platform.document_processor import DocumentProcessor
|
||||
from agentkit.rag_platform.preview import generate_preview
|
||||
from agentkit.rag_platform.sanitize import (
|
||||
MAX_FILE_SIZE,
|
||||
check_zip_bomb,
|
||||
validate_file_size,
|
||||
validate_file_type,
|
||||
)
|
||||
from agentkit.server.admin.context import DepartmentContext, get_department_context
|
||||
from agentkit.server.admin.filtering import filter_kb_sources_by_department
|
||||
from agentkit.server.auth.dependencies import require_permission
|
||||
|
|
@ -26,6 +37,9 @@ router = APIRouter(tags=["kb-management"])
|
|||
|
||||
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
|
||||
# ZIP-based 文档格式(本质是 ZIP,需要 zip bomb 检测)
|
||||
_ZIP_BASED_TYPES = frozenset({"docx", "xlsx", "pptx"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API Key Authentication
|
||||
|
|
@ -430,42 +444,52 @@ async def upload_document(
|
|||
_auth: None = Depends(_verify_api_key),
|
||||
_user=Depends(require_permission(Permission.KB_WRITE)),
|
||||
):
|
||||
"""Upload a document to the knowledge base."""
|
||||
"""Upload a document to the knowledge base.
|
||||
|
||||
U3+U7: 文件类型白名单 + 大小限制 + ZIP bomb 检测 + 内容净化 + 分段。
|
||||
返回 chunk 预览供前端确认,文档状态为 segmenting(待向量化)。
|
||||
"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=422, detail="Filename is required")
|
||||
|
||||
# Try to use DocumentLoader if available
|
||||
chunks = 1
|
||||
# U7: 文件类型白名单校验
|
||||
try:
|
||||
from agentkit.memory.document_loader import DocumentLoader
|
||||
file_type = validate_file_type(file.filename)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e)) from e
|
||||
|
||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||
if len(content) > MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB",
|
||||
)
|
||||
loader = DocumentLoader()
|
||||
doc = loader.load_bytes(content, file.filename)
|
||||
# Estimate chunks based on content length (rough approximation)
|
||||
chunks = max(1, len(doc.content) // 500)
|
||||
except ImportError:
|
||||
# DocumentLoader not available, use basic estimation
|
||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
||||
if len(content) > MAX_UPLOAD_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB",
|
||||
)
|
||||
chunks = max(1, len(content) // 500)
|
||||
# 读取文件内容并校验大小(U7)
|
||||
content = await file.read(MAX_FILE_SIZE + 1)
|
||||
try:
|
||||
validate_file_size(len(content))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=413, detail=str(e)) from e
|
||||
|
||||
# 保存到临时文件
|
||||
tmp_path: str | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
# U7: ZIP bomb 检测(.docx/.xlsx/.pptx 本质是 ZIP)
|
||||
if file_type in _ZIP_BASED_TYPES:
|
||||
try:
|
||||
check_zip_bomb(tmp_path)
|
||||
except (ValueError, zipfile.BadZipFile) as e:
|
||||
raise HTTPException(status_code=422, detail=str(e)) from e
|
||||
|
||||
# U3: 解析(含内容净化)+ 分段
|
||||
processor = DocumentProcessor()
|
||||
try:
|
||||
text = processor.parse(tmp_path, file_type)
|
||||
chunks = processor.segment(text)
|
||||
except Exception as e:
|
||||
logger.warning(f"Document parsing failed: {e}")
|
||||
chunks = 1
|
||||
logger.warning("Document parsing failed: %s", e)
|
||||
raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e
|
||||
|
||||
# Determine source_id - use "local" default if not provided
|
||||
# 确定source_id — 默认 "local"
|
||||
effective_source_id = source_id or "local"
|
||||
|
||||
# Ensure a local source exists
|
||||
if effective_source_id == "local":
|
||||
local_sources = [s for s in _source_store.list_sources() if s.type == "local"]
|
||||
if not local_sources:
|
||||
|
|
@ -475,8 +499,8 @@ async def upload_document(
|
|||
document_id=str(uuid.uuid4()),
|
||||
filename=file.filename,
|
||||
source_id=effective_source_id,
|
||||
chunks=chunks,
|
||||
status="indexed",
|
||||
chunks=len(chunks),
|
||||
status="segmenting",
|
||||
)
|
||||
_source_store.add_document(uploaded)
|
||||
|
||||
|
|
@ -486,8 +510,91 @@ async def upload_document(
|
|||
"source_id": uploaded.source_id,
|
||||
"chunks": uploaded.chunks,
|
||||
"status": uploaded.status,
|
||||
"created_at": uploaded.created_at if hasattr(uploaded, "created_at") else "",
|
||||
"created_at": uploaded.created_at,
|
||||
"total_chunks": len(chunks),
|
||||
"chunks_preview": [{"index": i, "content": c} for i, c in enumerate(chunks)],
|
||||
}
|
||||
finally:
|
||||
if tmp_path:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/kb-management/documents/preview")
|
||||
async def preview_document(
|
||||
file: UploadFile = File(...),
|
||||
chunk_size: int = 512,
|
||||
chunk_overlap: int = 50,
|
||||
_auth: None = Depends(_verify_api_key),
|
||||
_user=Depends(require_permission(Permission.KB_WRITE)),
|
||||
):
|
||||
"""预览文档分段结果(不创建文档记录,不向量化)。
|
||||
|
||||
U3: 上传前预览分段效果,可调整 chunk_size/chunk_overlap 参数。
|
||||
"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=422, detail="Filename is required")
|
||||
|
||||
try:
|
||||
file_type = validate_file_type(file.filename)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=422, detail=str(e)) from e
|
||||
|
||||
content = await file.read(MAX_FILE_SIZE + 1)
|
||||
try:
|
||||
validate_file_size(len(content))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=413, detail=str(e)) from e
|
||||
|
||||
tmp_path: str | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
|
||||
if file_type in _ZIP_BASED_TYPES:
|
||||
try:
|
||||
check_zip_bomb(tmp_path)
|
||||
except (ValueError, zipfile.BadZipFile) as e:
|
||||
raise HTTPException(status_code=422, detail=str(e)) from e
|
||||
|
||||
try:
|
||||
result = generate_preview(
|
||||
tmp_path,
|
||||
file_type,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Document preview failed: %s", e)
|
||||
raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e
|
||||
|
||||
return result.model_dump()
|
||||
finally:
|
||||
if tmp_path:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/kb-management/documents/{document_id}/vectorize")
|
||||
async def vectorize_document(
|
||||
document_id: str,
|
||||
_auth: None = Depends(_verify_api_key),
|
||||
_user=Depends(require_permission(Permission.KB_WRITE)),
|
||||
):
|
||||
"""触发文档向量化。
|
||||
|
||||
U3: 当前为占位实现 — 完整异步向量化需要 PG-backed KBStore (U8)。
|
||||
旧 KnowledgeSourceStore 不跟踪文件路径,无法重新解析已上传文档。
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Vectorization requires PG-backed KBStore (U8)",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/kb-management/search")
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ class TestUploadDocument:
|
|||
data = response.json()
|
||||
assert data["filename"] == "test.txt"
|
||||
assert data["document_id"] is not None
|
||||
assert data["status"] == "indexed"
|
||||
assert data["status"] == "segmenting"
|
||||
assert data["chunks"] >= 1
|
||||
|
||||
def test_upload_document_with_source_id(self, client):
|
||||
|
|
|
|||
Loading…
Reference in New Issue