feat(rag_platform): U3+U7 — rewrite upload endpoint with sanitization + pipeline
Rewrite upload_document() to use rag_platform sanitize + DocumentProcessor:
- File type whitelist validation (8 allowed types, reject .exe/.sh)
- File size limit (50MB) + zip bomb detection for ZIP-based formats
- DocumentProcessor.parse() (with content sanitization) + segment()
- Return chunks preview, status="segmenting" (pending vectorization)
Add POST /kb-management/documents/preview endpoint:
- Pre-upload preview with adjustable chunk_size/chunk_overlap
- Same security validation as upload, no document record created
Add POST /kb-management/documents/{id}/vectorize placeholder:
- Returns 503 — full async vectorization deferred to U8 (TaskIQ)
Test: update test_upload_document assertion (status "indexed" → "segmenting")
This commit is contained in:
parent
b55c896794
commit
3f9588e673
|
|
@ -4,7 +4,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
|
import zipfile
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -14,6 +17,14 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Security, Upload
|
||||||
from fastapi.security import APIKeyHeader, APIKeyQuery
|
from fastapi.security import APIKeyHeader, APIKeyQuery
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from agentkit.rag_platform.document_processor import DocumentProcessor
|
||||||
|
from agentkit.rag_platform.preview import generate_preview
|
||||||
|
from agentkit.rag_platform.sanitize import (
|
||||||
|
MAX_FILE_SIZE,
|
||||||
|
check_zip_bomb,
|
||||||
|
validate_file_size,
|
||||||
|
validate_file_type,
|
||||||
|
)
|
||||||
from agentkit.server.admin.context import DepartmentContext, get_department_context
|
from agentkit.server.admin.context import DepartmentContext, get_department_context
|
||||||
from agentkit.server.admin.filtering import filter_kb_sources_by_department
|
from agentkit.server.admin.filtering import filter_kb_sources_by_department
|
||||||
from agentkit.server.auth.dependencies import require_permission
|
from agentkit.server.auth.dependencies import require_permission
|
||||||
|
|
@ -26,6 +37,9 @@ router = APIRouter(tags=["kb-management"])
|
||||||
|
|
||||||
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
|
MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50MB
|
||||||
|
|
||||||
|
# ZIP-based 文档格式(本质是 ZIP,需要 zip bomb 检测)
|
||||||
|
_ZIP_BASED_TYPES = frozenset({"docx", "xlsx", "pptx"})
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# API Key Authentication
|
# API Key Authentication
|
||||||
|
|
@ -430,42 +444,52 @@ async def upload_document(
|
||||||
_auth: None = Depends(_verify_api_key),
|
_auth: None = Depends(_verify_api_key),
|
||||||
_user=Depends(require_permission(Permission.KB_WRITE)),
|
_user=Depends(require_permission(Permission.KB_WRITE)),
|
||||||
):
|
):
|
||||||
"""Upload a document to the knowledge base."""
|
"""Upload a document to the knowledge base.
|
||||||
|
|
||||||
|
U3+U7: 文件类型白名单 + 大小限制 + ZIP bomb 检测 + 内容净化 + 分段。
|
||||||
|
返回 chunk 预览供前端确认,文档状态为 segmenting(待向量化)。
|
||||||
|
"""
|
||||||
if not file.filename:
|
if not file.filename:
|
||||||
raise HTTPException(status_code=422, detail="Filename is required")
|
raise HTTPException(status_code=422, detail="Filename is required")
|
||||||
|
|
||||||
# Try to use DocumentLoader if available
|
# U7: 文件类型白名单校验
|
||||||
chunks = 1
|
|
||||||
try:
|
try:
|
||||||
from agentkit.memory.document_loader import DocumentLoader
|
file_type = validate_file_type(file.filename)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=422, detail=str(e)) from e
|
||||||
|
|
||||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
# 读取文件内容并校验大小(U7)
|
||||||
if len(content) > MAX_UPLOAD_SIZE:
|
content = await file.read(MAX_FILE_SIZE + 1)
|
||||||
raise HTTPException(
|
try:
|
||||||
status_code=413,
|
validate_file_size(len(content))
|
||||||
detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB",
|
except ValueError as e:
|
||||||
)
|
raise HTTPException(status_code=413, detail=str(e)) from e
|
||||||
loader = DocumentLoader()
|
|
||||||
doc = loader.load_bytes(content, file.filename)
|
# 保存到临时文件
|
||||||
# Estimate chunks based on content length (rough approximation)
|
tmp_path: str | None = None
|
||||||
chunks = max(1, len(doc.content) // 500)
|
try:
|
||||||
except ImportError:
|
with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
|
||||||
# DocumentLoader not available, use basic estimation
|
tmp.write(content)
|
||||||
content = await file.read(MAX_UPLOAD_SIZE + 1)
|
tmp_path = tmp.name
|
||||||
if len(content) > MAX_UPLOAD_SIZE:
|
|
||||||
raise HTTPException(
|
# U7: ZIP bomb 检测(.docx/.xlsx/.pptx 本质是 ZIP)
|
||||||
status_code=413,
|
if file_type in _ZIP_BASED_TYPES:
|
||||||
detail=f"File too large. Maximum size is {MAX_UPLOAD_SIZE // (1024 * 1024)}MB",
|
try:
|
||||||
)
|
check_zip_bomb(tmp_path)
|
||||||
chunks = max(1, len(content) // 500)
|
except (ValueError, zipfile.BadZipFile) as e:
|
||||||
|
raise HTTPException(status_code=422, detail=str(e)) from e
|
||||||
|
|
||||||
|
# U3: 解析(含内容净化)+ 分段
|
||||||
|
processor = DocumentProcessor()
|
||||||
|
try:
|
||||||
|
text = processor.parse(tmp_path, file_type)
|
||||||
|
chunks = processor.segment(text)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Document parsing failed: {e}")
|
logger.warning("Document parsing failed: %s", e)
|
||||||
chunks = 1
|
raise HTTPException(status_code=422, detail=f"Document parsing failed: {e}") from e
|
||||||
|
|
||||||
# Determine source_id - use "local" default if not provided
|
# 确定source_id — 默认 "local"
|
||||||
effective_source_id = source_id or "local"
|
effective_source_id = source_id or "local"
|
||||||
|
|
||||||
# Ensure a local source exists
|
|
||||||
if effective_source_id == "local":
|
if effective_source_id == "local":
|
||||||
local_sources = [s for s in _source_store.list_sources() if s.type == "local"]
|
local_sources = [s for s in _source_store.list_sources() if s.type == "local"]
|
||||||
if not local_sources:
|
if not local_sources:
|
||||||
|
|
@ -475,8 +499,8 @@ async def upload_document(
|
||||||
document_id=str(uuid.uuid4()),
|
document_id=str(uuid.uuid4()),
|
||||||
filename=file.filename,
|
filename=file.filename,
|
||||||
source_id=effective_source_id,
|
source_id=effective_source_id,
|
||||||
chunks=chunks,
|
chunks=len(chunks),
|
||||||
status="indexed",
|
status="segmenting",
|
||||||
)
|
)
|
||||||
_source_store.add_document(uploaded)
|
_source_store.add_document(uploaded)
|
||||||
|
|
||||||
|
|
@ -486,8 +510,91 @@ async def upload_document(
|
||||||
"source_id": uploaded.source_id,
|
"source_id": uploaded.source_id,
|
||||||
"chunks": uploaded.chunks,
|
"chunks": uploaded.chunks,
|
||||||
"status": uploaded.status,
|
"status": uploaded.status,
|
||||||
"created_at": uploaded.created_at if hasattr(uploaded, "created_at") else "",
|
"created_at": uploaded.created_at,
|
||||||
|
"total_chunks": len(chunks),
|
||||||
|
"chunks_preview": [{"index": i, "content": c} for i, c in enumerate(chunks)],
|
||||||
}
|
}
|
||||||
|
finally:
|
||||||
|
if tmp_path:
|
||||||
|
try:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/kb-management/documents/preview")
async def preview_document(
    file: UploadFile = File(...),
    chunk_size: int = 512,
    chunk_overlap: int = 50,
    _auth: None = Depends(_verify_api_key),
    _user=Depends(require_permission(Permission.KB_WRITE)),
):
    """Preview how a document would be segmented, without persisting anything.

    U3: pre-upload preview of the segmentation result; ``chunk_size`` and
    ``chunk_overlap`` can be adjusted by the caller. No document record is
    created and nothing is vectorized.

    Args:
        file: The uploaded file; its filename is used for type validation.
        chunk_size: Forwarded to ``generate_preview`` as ``chunk_size``.
        chunk_overlap: Forwarded to ``generate_preview`` as ``chunk_overlap``.

    Returns:
        The ``model_dump()`` of the object returned by ``generate_preview``.

    Raises:
        HTTPException: 422 when the filename is missing, when
            ``validate_file_type`` rejects it, when ``check_zip_bomb`` rejects
            a ZIP-based file (or it is not a valid ZIP), or when
            ``generate_preview`` fails; 413 when ``validate_file_size``
            rejects the payload.
    """
    if not file.filename:
        raise HTTPException(status_code=422, detail="Filename is required")

    try:
        file_type = validate_file_type(file.filename)
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e

    # Read one byte past the limit so an oversized upload is detectable
    # without buffering the whole stream.
    content = await file.read(MAX_FILE_SIZE + 1)
    try:
        validate_file_size(len(content))
    except ValueError as e:
        raise HTTPException(status_code=413, detail=str(e)) from e

    tmp_path: str | None = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file_type}") as tmp:
            # Record the path BEFORE writing: with delete=False the file is
            # not removed on close, so if write() raises (e.g. disk full) the
            # finally-block below must still know what to unlink.
            tmp_path = tmp.name
            tmp.write(content)

        # .docx/.xlsx/.pptx are ZIP containers and get the zip bomb check.
        if file_type in _ZIP_BASED_TYPES:
            try:
                check_zip_bomb(tmp_path)
            except (ValueError, zipfile.BadZipFile) as e:
                raise HTTPException(status_code=422, detail=str(e)) from e

        try:
            result = generate_preview(
                tmp_path,
                file_type,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
        except Exception as e:
            logger.warning("Document preview failed: %s", e)
            raise HTTPException(status_code=422, detail=f"Document preview failed: {e}") from e

        return result.model_dump()
    finally:
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup: a failed unlink must not mask the
                # response or the exception already in flight.
                pass
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/kb-management/documents/{document_id}/vectorize")
async def vectorize_document(
    document_id: str,
    _auth: None = Depends(_verify_api_key),
    _user=Depends(require_permission(Permission.KB_WRITE)),
):
    """Trigger vectorization of a document (placeholder).

    U3: this is currently a placeholder implementation — full asynchronous
    vectorization requires the PG-backed KBStore (U8). The legacy
    KnowledgeSourceStore does not track file paths, so an already-uploaded
    document cannot be re-parsed.

    Raises:
        HTTPException: always, with status 503.
    """
    unavailable = HTTPException(
        status_code=503,
        detail="Vectorization requires PG-backed KBStore (U8)",
    )
    raise unavailable
|
||||||
|
|
||||||
|
|
||||||
@router.post("/kb-management/search")
|
@router.post("/kb-management/search")
|
||||||
|
|
|
||||||
|
|
@ -251,7 +251,7 @@ class TestUploadDocument:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["filename"] == "test.txt"
|
assert data["filename"] == "test.txt"
|
||||||
assert data["document_id"] is not None
|
assert data["document_id"] is not None
|
||||||
assert data["status"] == "indexed"
|
assert data["status"] == "segmenting"
|
||||||
assert data["chunks"] >= 1
|
assert data["chunks"] >= 1
|
||||||
|
|
||||||
def test_upload_document_with_source_id(self, client):
|
def test_upload_document_with_source_id(self, client):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue