feat(rag_platform): U3+U7 — document processing pipeline + upload security
U3: Document processing pipeline (document_processor.py) - DocumentProcessor class wrapping parse → segment → vectorize - parse() uses memory/document_loader.py for multi-format extraction - segment() uses LlamaIndex SentenceSplitter - preview() returns chunks for read-only preview (no vectorization) - vectorize() embeds chunks and stores in pgvector (all-or-nothing) - process() orchestrates full pipeline with status transitions: pending → parsing → segmenting → vectorizing → indexed | failed U7: Upload security & content sanitization (sanitize.py) - ALLOWED_FILE_TYPES whitelist (pdf/docx/xlsx/pptx/txt/md/csv/html) - MAX_FILE_SIZE 50MB limit - validate_file_type() / validate_file_size() guards - check_zip_bomb() for ZIP-based formats (ratio > 100:1 or > 500MB) - check_image_bomb() for pixel count > 100MP (PNG/JPEG/GIF header parsing) - is_safe_ip() SSRF protection (loopback/RFC1918/link-local/ULA denied) - sanitize_markdown() removes dangerous HTML tags (script/iframe/object/embed) - sanitize_content() main entry point for text format sanitization - parse_xml_safe() XXE protection (forbid_dtd/forbid_entities/forbid_external) Preview API (preview.py) - PreviewChunk / PreviewResult Pydantic models - generate_preview() returns read-only segmentation preview Tests: 112 tests passing (45 new + 67 existing) - test_sanitize.py: file type/size, markdown sanitization, SSRF, zip/image bomb - test_document_processor.py: parse/segment, preview, vectorize, failure status
This commit is contained in:
parent
c1a21f57a1
commit
b55c896794
|
|
@ -1,5 +1,10 @@
|
|||
"""RAG 平台模块 — 企业知识库场景的工业级 RAG 管道。"""
|
||||
|
||||
from agentkit.rag_platform.document_processor import (
|
||||
DEFAULT_CHUNK_OVERLAP,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DocumentProcessor,
|
||||
)
|
||||
from agentkit.rag_platform.models import (
|
||||
Chunk,
|
||||
Document,
|
||||
|
|
@ -9,13 +14,44 @@ from agentkit.rag_platform.models import (
|
|||
QueryMode,
|
||||
QueryResult,
|
||||
)
|
||||
from agentkit.rag_platform.preview import (
|
||||
PreviewChunk,
|
||||
PreviewResult,
|
||||
generate_preview,
|
||||
)
|
||||
from agentkit.rag_platform.sanitize import (
|
||||
ALLOWED_FILE_TYPES,
|
||||
MAX_FILE_SIZE,
|
||||
check_image_bomb,
|
||||
check_zip_bomb,
|
||||
is_safe_ip,
|
||||
sanitize_content,
|
||||
sanitize_markdown,
|
||||
validate_file_size,
|
||||
validate_file_type,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ALLOWED_FILE_TYPES",
|
||||
"DEFAULT_CHUNK_OVERLAP",
|
||||
"DEFAULT_CHUNK_SIZE",
|
||||
"Chunk",
|
||||
"Document",
|
||||
"DocumentProcessor",
|
||||
"DocumentStatus",
|
||||
"KBStatus",
|
||||
"KnowledgeBase",
|
||||
"MAX_FILE_SIZE",
|
||||
"PreviewChunk",
|
||||
"PreviewResult",
|
||||
"QueryMode",
|
||||
"QueryResult",
|
||||
"check_image_bomb",
|
||||
"check_zip_bomb",
|
||||
"generate_preview",
|
||||
"is_safe_ip",
|
||||
"sanitize_content",
|
||||
"sanitize_markdown",
|
||||
"validate_file_size",
|
||||
"validate_file_type",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,250 @@
|
|||
"""文档处理管道 — 解析 → 分段 → 向量化。
|
||||
|
||||
封装 DocumentLoader(多格式解析)+ LlamaIndex SentenceSplitter(分段)+
|
||||
PGVectorStore(向量化存储)。
|
||||
|
||||
状态机:pending → parsing → segmenting → vectorizing → indexed | failed
|
||||
失败语义:vectorize 是 all-or-nothing — 失败时抛出异常,调用方负责将状态置为 failed。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agentkit.memory.document_loader import DocumentLoader
|
||||
from agentkit.rag_platform.models import DocumentStatus
|
||||
from agentkit.rag_platform.sanitize import sanitize_content
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
|
||||
from agentkit.rag_platform.store import KBStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 512
|
||||
DEFAULT_CHUNK_OVERLAP = 50
|
||||
|
||||
|
||||
class DocumentProcessor:
|
||||
"""文档处理管道:解析 → 分段 → 向量化。
|
||||
|
||||
Args:
|
||||
chunk_size: 分段大小(token 数),默认 512
|
||||
chunk_overlap: 分段重叠(token 数),默认 50
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
||||
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
|
||||
) -> None:
|
||||
self._loader = DocumentLoader()
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
|
||||
def parse(self, file_path: str, file_type: str) -> str:
|
||||
"""从文件中提取文本。
|
||||
|
||||
使用 memory/document_loader.py 的 DocumentLoader 进行多格式解析
|
||||
(PDF/DOCX/XLSX/Markdown/HTML/纯文本)。
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
file_type: 文件类型(扩展名,不含点)— 用于净化
|
||||
|
||||
Returns:
|
||||
提取的文本内容(已净化)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 文件不存在
|
||||
ValueError: 内容超过大小限制
|
||||
"""
|
||||
doc = self._loader.load(file_path)
|
||||
# 对文本格式应用内容净化
|
||||
return sanitize_content(doc.content, file_type)
|
||||
|
||||
def segment(
|
||||
self,
|
||||
text: str,
|
||||
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
||||
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
|
||||
) -> list[str]:
|
||||
"""将文本分段为 chunk。
|
||||
|
||||
使用 LlamaIndex SentenceSplitter 进行分段。
|
||||
|
||||
Args:
|
||||
text: 待分段文本
|
||||
chunk_size: 分段大小(token 数)
|
||||
chunk_overlap: 分段重叠(token 数)
|
||||
|
||||
Returns:
|
||||
chunk 文本列表
|
||||
"""
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
splitter = SentenceSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
chunks = splitter.split_text(text)
|
||||
logger.debug(
|
||||
"Segmented text into %d chunks (chunk_size=%d, overlap=%d)",
|
||||
len(chunks),
|
||||
chunk_size,
|
||||
chunk_overlap,
|
||||
)
|
||||
return chunks
|
||||
|
||||
def preview(
|
||||
self,
|
||||
file_path: str,
|
||||
file_type: str,
|
||||
chunk_size: int | None = None,
|
||||
chunk_overlap: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""解析 + 分段,返回 chunk 用于只读预览(不向量化)。
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
file_type: 文件类型(扩展名,不含点)
|
||||
chunk_size: 分段大小(默认使用构造器值)
|
||||
chunk_overlap: 分段重叠(默认使用构造器值)
|
||||
|
||||
Returns:
|
||||
chunk 字典列表:[{index, content, metadata}, ...]
|
||||
"""
|
||||
cs = chunk_size or self._chunk_size
|
||||
co = chunk_overlap or self._chunk_overlap
|
||||
|
||||
text = self.parse(file_path, file_type)
|
||||
chunks = self.segment(text, cs, co)
|
||||
|
||||
return [{"index": i, "content": c, "metadata": {}} for i, c in enumerate(chunks)]
|
||||
|
||||
async def vectorize(
|
||||
self,
|
||||
chunks: list[str] | list[dict[str, Any]],
|
||||
kb_id: str,
|
||||
document_id: str,
|
||||
vector_store: "PGVectorStore",
|
||||
embed_model: "BaseEmbedding",
|
||||
) -> list["TextNode"]:
|
||||
"""将 chunk 向量化并存储到 pgvector。
|
||||
|
||||
All-or-nothing 语义:任一 chunk 向量化失败则整体失败(抛出异常)。
|
||||
|
||||
Args:
|
||||
chunks: chunk 文本列表或 chunk 字典列表
|
||||
kb_id: 知识库 ID
|
||||
document_id: 文档 ID
|
||||
vector_store: LlamaIndex PGVectorStore 实例
|
||||
embed_model: LlamaIndex embedding 模型
|
||||
|
||||
Returns:
|
||||
已写入 vector store 的 TextNode 列表
|
||||
|
||||
Raises:
|
||||
Exception: 向量化或存储失败
|
||||
"""
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
# 统一 chunk 格式为字符串
|
||||
chunk_texts: list[str] = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
if isinstance(chunk, str):
|
||||
chunk_texts.append(chunk)
|
||||
elif isinstance(chunk, dict):
|
||||
chunk_texts.append(chunk.get("content", ""))
|
||||
else:
|
||||
chunk_texts.append(str(chunk))
|
||||
|
||||
# 构建 TextNode
|
||||
nodes: list[TextNode] = [
|
||||
TextNode(
|
||||
text=text,
|
||||
metadata={
|
||||
"kb_id": kb_id,
|
||||
"document_id": document_id,
|
||||
"chunk_index": i,
|
||||
},
|
||||
)
|
||||
for i, text in enumerate(chunk_texts)
|
||||
]
|
||||
|
||||
# 向量化 — all-or-nothing:任一失败则抛出异常
|
||||
for node in nodes:
|
||||
node.embedding = await embed_model.aget_text_embedding(node.text)
|
||||
|
||||
# 存储到 vector store
|
||||
await vector_store.async_add(nodes)
|
||||
|
||||
logger.info(
|
||||
"Vectorized %d chunks for document=%s kb=%s",
|
||||
len(nodes),
|
||||
document_id,
|
||||
kb_id,
|
||||
)
|
||||
return nodes
|
||||
|
||||
async def process(
|
||||
self,
|
||||
file_path: str,
|
||||
file_type: str,
|
||||
kb_id: str,
|
||||
document_id: str,
|
||||
vector_store: "PGVectorStore",
|
||||
embed_model: "BaseEmbedding",
|
||||
store: "KBStore",
|
||||
) -> None:
|
||||
"""完整管道:parse → segment → vectorize,含状态转换。
|
||||
|
||||
状态转换:pending → parsing → segmenting → vectorizing → indexed
|
||||
失败时:状态置为 failed,error_message 记录异常信息,异常重新抛出。
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
file_type: 文件类型(扩展名,不含点)
|
||||
kb_id: 知识库 ID
|
||||
document_id: 文档 ID
|
||||
vector_store: LlamaIndex PGVectorStore 实例
|
||||
embed_model: LlamaIndex embedding 模型
|
||||
store: KBStore 实例(用于状态更新)
|
||||
|
||||
Raises:
|
||||
Exception: 管道任一阶段失败(状态已置为 failed)
|
||||
"""
|
||||
try:
|
||||
# parsing
|
||||
await store.update_document_status(document_id, DocumentStatus.parsing)
|
||||
text = self.parse(file_path, file_type)
|
||||
|
||||
# segmenting
|
||||
await store.update_document_status(document_id, DocumentStatus.segmenting)
|
||||
chunks = self.segment(text, self._chunk_size, self._chunk_overlap)
|
||||
|
||||
# vectorizing
|
||||
await store.update_document_status(document_id, DocumentStatus.vectorizing)
|
||||
await self.vectorize(chunks, kb_id, document_id, vector_store, embed_model)
|
||||
|
||||
# indexed
|
||||
await store.update_document_status(document_id, DocumentStatus.indexed)
|
||||
logger.info("Document %s processed successfully", document_id)
|
||||
except Exception as e:
|
||||
# failed — 记录错误信息并重新抛出
|
||||
await store.update_document_status(
|
||||
document_id, DocumentStatus.failed, error_message=str(e)
|
||||
)
|
||||
logger.error("Document %s processing failed: %s", document_id, e)
|
||||
raise
|
||||
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_CHUNK_OVERLAP",
|
||||
"DEFAULT_CHUNK_SIZE",
|
||||
"DocumentProcessor",
|
||||
]
|
||||
|
|
@ -0,0 +1,86 @@
|
|||
"""分段预览 API — 只读预览文档分段结果(不向量化)。
|
||||
|
||||
用户上传文档前可预览分段效果,调整 chunk_size/chunk_overlap 参数后再正式提交。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from agentkit.rag_platform.document_processor import (
|
||||
DEFAULT_CHUNK_OVERLAP,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DocumentProcessor,
|
||||
)
|
||||
|
||||
|
||||
class PreviewChunk(BaseModel):
|
||||
"""预览 chunk 条目。"""
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
index: int
|
||||
content: str
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class PreviewResult(BaseModel):
|
||||
"""分段预览结果。"""
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
document_id: str = "" # 预览阶段文档尚未创建,默认空
|
||||
chunks: list[PreviewChunk]
|
||||
total_chunks: int
|
||||
|
||||
|
||||
def generate_preview(
|
||||
file_path: str,
|
||||
file_type: str,
|
||||
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
||||
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
|
||||
) -> PreviewResult:
|
||||
"""生成分段预览(只读,不向量化)。
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
file_type: 文件类型(扩展名,不含点),如 "pdf"、"md"
|
||||
chunk_size: 分段大小(token 数),默认 512
|
||||
chunk_overlap: 分段重叠(token 数),默认 50
|
||||
|
||||
Returns:
|
||||
PreviewResult,包含 chunk 列表和总数
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 文件不存在
|
||||
ValueError: 解析失败
|
||||
"""
|
||||
processor = DocumentProcessor(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
chunks = processor.preview(
|
||||
file_path,
|
||||
file_type,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
return PreviewResult(
|
||||
document_id="",
|
||||
chunks=[
|
||||
PreviewChunk(
|
||||
index=c["index"],
|
||||
content=c["content"],
|
||||
metadata=c.get("metadata", {}),
|
||||
)
|
||||
for c in chunks
|
||||
],
|
||||
total_chunks=len(chunks),
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PreviewChunk",
|
||||
"PreviewResult",
|
||||
"generate_preview",
|
||||
]
|
||||
|
|
@ -0,0 +1,350 @@
|
|||
"""内容净化与上传安全 — 文件类型白名单、大小限制、zip/image bomb 检测、SSRF 防护、Markdown 净化。
|
||||
|
||||
安全边界:
|
||||
- 文件类型白名单(pdf/docx/xlsx/pptx/txt/md/csv/html)
|
||||
- 文件大小上限(50MB 默认)
|
||||
- ZIP bomb 检测(.docx/.xlsx/.pptx 本质是 ZIP)
|
||||
- Image bomb 检测(像素数 > 100MP)
|
||||
- SSRF 防护(拒绝 loopback/RFC1918/link-local/ULA)
|
||||
- Markdown 危险 HTML 标签净化(script/iframe/object/embed)
|
||||
- XXE 防护(forbid_dtd XML 解析)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import re
|
||||
import struct
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# 文件类型白名单:扩展名 → MIME 类型
|
||||
ALLOWED_FILE_TYPES: dict[str, str] = {
|
||||
".pdf": "application/pdf",
|
||||
".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
".txt": "text/plain",
|
||||
".md": "text/markdown",
|
||||
".csv": "text/csv",
|
||||
".html": "text/html",
|
||||
}
|
||||
|
||||
# 文件大小上限:50MB
|
||||
MAX_FILE_SIZE = 50 * 1024 * 1024
|
||||
|
||||
# ZIP bomb 检测阈值
|
||||
MAX_ZIP_RATIO = 100 # 压缩比上限 100:1
|
||||
MAX_ZIP_DECOMPRESSED_SIZE = 500 * 1024 * 1024 # 解压后上限 500MB
|
||||
|
||||
# Image bomb 检测阈值
|
||||
MAX_IMAGE_PIXELS = 100_000_000 # 100MP
|
||||
|
||||
# 需要净化的文本格式
|
||||
_TEXT_FORMATS = frozenset({"md", "html", "txt", "csv"})
|
||||
|
||||
# 危险 HTML 标签(含内容整体移除)
|
||||
_DANGEROUS_TAG_RE = re.compile(
|
||||
r"<\s*(script|iframe|object|embed|style|form|input|textarea|button|link|meta)"
|
||||
r"[^>]*>.*?<\s*/\s*\1\s*>",
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
# 危险自闭合/孤立标签
|
||||
_DANGEROUS_SOLO_RE = re.compile(
|
||||
r"<\s*(script|iframe|object|embed|style|form|input|textarea|button|link|meta)[^>]*/?>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
# 危险闭合标签
|
||||
_DANGEROUS_CLOSE_RE = re.compile(
|
||||
r"<\s*/\s*(script|iframe|object|embed|style|form|input|textarea|button|link|meta)\s*>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
# javascript: 协议
|
||||
_JS_URL_RE = re.compile(r"javascript\s*:", re.IGNORECASE)
|
||||
# 事件处理器属性 (onclick=, onload=, onerror=, ...)
|
||||
_EVENT_HANDLER_RE = re.compile(
|
||||
r"\s+on\w+\s*=\s*(?:\"[^\"]*\"|'[^']*'|[^\s>]+)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# SSRF 防护:拒绝的网络段
|
||||
_DENIED_NETWORKS = [
|
||||
ipaddress.ip_network("127.0.0.0/8"), # IPv4 loopback
|
||||
ipaddress.ip_network("10.0.0.0/8"), # RFC1918
|
||||
ipaddress.ip_network("172.16.0.0/12"), # RFC1918
|
||||
ipaddress.ip_network("192.168.0.0/16"), # RFC1918
|
||||
ipaddress.ip_network("169.254.0.0/16"), # link-local
|
||||
ipaddress.ip_network("0.0.0.0/8"), # "this network"
|
||||
ipaddress.ip_network("::1/128"), # IPv6 loopback
|
||||
ipaddress.ip_network("fc00::/7"), # IPv6 ULA
|
||||
ipaddress.ip_network("fe80::/10"), # IPv6 link-local
|
||||
ipaddress.ip_network("::/128"), # IPv6 unspecified
|
||||
]
|
||||
|
||||
|
||||
def validate_file_type(filename: str, file_type: str | None = None) -> str:
|
||||
"""校验文件类型是否在白名单内。
|
||||
|
||||
Args:
|
||||
filename: 文件名(用于提取扩展名)
|
||||
file_type: 可选 MIME 类型(当前仅用于日志,不做强制校验)
|
||||
|
||||
Returns:
|
||||
归一化文件类型(扩展名,不含点),如 "pdf"、"md"
|
||||
|
||||
Raises:
|
||||
ValueError: 文件类型不在白名单
|
||||
"""
|
||||
ext = Path(filename).suffix.lower()
|
||||
if ext not in ALLOWED_FILE_TYPES:
|
||||
allowed = ", ".join(sorted(ALLOWED_FILE_TYPES.keys()))
|
||||
raise ValueError(f"File type '{ext}' is not allowed. Allowed types: {allowed}")
|
||||
return ext.lstrip(".")
|
||||
|
||||
|
||||
def validate_file_size(file_size: int) -> None:
|
||||
"""校验文件大小是否在上限内。
|
||||
|
||||
Args:
|
||||
file_size: 文件大小(字节)
|
||||
|
||||
Raises:
|
||||
ValueError: 文件超过 MAX_FILE_SIZE
|
||||
"""
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
raise ValueError(
|
||||
f"File size {file_size} bytes exceeds limit {MAX_FILE_SIZE} bytes "
|
||||
f"({MAX_FILE_SIZE // (1024 * 1024)}MB)"
|
||||
)
|
||||
if file_size <= 0:
|
||||
raise ValueError(f"File size must be positive, got {file_size}")
|
||||
|
||||
|
||||
def check_zip_bomb(file_path: str) -> None:
|
||||
"""检测 ZIP bomb — 适用于 .docx/.xlsx/.pptx(本质是 ZIP 格式)。
|
||||
|
||||
拒绝条件:
|
||||
- 单文件压缩比 > 100:1
|
||||
- 总解压大小 > 500MB
|
||||
|
||||
Args:
|
||||
file_path: ZIP 格式文件路径
|
||||
|
||||
Raises:
|
||||
ValueError: 检测到 ZIP bomb
|
||||
zipfile.BadZipFile: 文件不是有效 ZIP
|
||||
"""
|
||||
with zipfile.ZipFile(file_path) as zf:
|
||||
total_compressed = 0
|
||||
total_uncompressed = 0
|
||||
|
||||
for info in zf.infolist():
|
||||
total_compressed += info.compress_size
|
||||
total_uncompressed += info.file_size
|
||||
|
||||
# 单文件压缩比检查
|
||||
if info.compress_size > 0:
|
||||
ratio = info.file_size / info.compress_size
|
||||
if ratio > MAX_ZIP_RATIO:
|
||||
raise ValueError(
|
||||
f"Zip bomb detected: file '{info.filename}' has compression "
|
||||
f"ratio {ratio:.1f}:1 (max {MAX_ZIP_RATIO}:1)"
|
||||
)
|
||||
|
||||
# 总解压大小检查
|
||||
if total_uncompressed > MAX_ZIP_DECOMPRESSED_SIZE:
|
||||
raise ValueError(
|
||||
f"Zip bomb detected: total uncompressed size "
|
||||
f"{total_uncompressed} bytes exceeds limit "
|
||||
f"{MAX_ZIP_DECOMPRESSED_SIZE} bytes"
|
||||
)
|
||||
|
||||
# 总体压缩比检查
|
||||
if total_compressed > 0:
|
||||
overall_ratio = total_uncompressed / total_compressed
|
||||
if overall_ratio > MAX_ZIP_RATIO:
|
||||
raise ValueError(
|
||||
f"Zip bomb detected: overall compression ratio "
|
||||
f"{overall_ratio:.1f}:1 (max {MAX_ZIP_RATIO}:1)"
|
||||
)
|
||||
|
||||
|
||||
def check_image_bomb(image_path: str) -> None:
|
||||
"""检测 image bomb — 像素数超过 100MP 则拒绝。
|
||||
|
||||
支持 PNG、JPEG、GIF 格式(通过文件头解析,不依赖 PIL)。
|
||||
无法识别的格式视为安全(返回不抛异常)。
|
||||
|
||||
Args:
|
||||
image_path: 图片文件路径
|
||||
|
||||
Raises:
|
||||
ValueError: 像素数超过 MAX_IMAGE_PIXELS
|
||||
"""
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
header = f.read(32)
|
||||
except OSError:
|
||||
return # 无法读取,视为安全
|
||||
|
||||
width, height = _read_image_dimensions(header, image_path)
|
||||
if width <= 0 or height <= 0:
|
||||
return # 无法识别尺寸,视为安全
|
||||
|
||||
pixels = width * height
|
||||
if pixels > MAX_IMAGE_PIXELS:
|
||||
raise ValueError(
|
||||
f"Image bomb detected: {width}x{height} = {pixels} pixels "
|
||||
f"exceeds limit {MAX_IMAGE_PIXELS} pixels (100MP)"
|
||||
)
|
||||
|
||||
|
||||
def _read_image_dimensions(header: bytes, file_path: str) -> tuple[int, int]:
|
||||
"""从文件头读取图片尺寸(PNG/GIF/JPEG)。"""
|
||||
# PNG: \x89PNG\r\n\x1a\n + IHDR (width@16, height@20, big-endian)
|
||||
if len(header) >= 24 and header[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
width = struct.unpack(">I", header[16:20])[0]
|
||||
height = struct.unpack(">I", header[20:24])[0]
|
||||
return width, height
|
||||
|
||||
# GIF: GIF87a/GIF89a (width@6, height@8, little-endian)
|
||||
if len(header) >= 10 and header[:6] in (b"GIF87a", b"GIF89a"):
|
||||
width = struct.unpack("<H", header[6:8])[0]
|
||||
height = struct.unpack("<H", header[8:10])[0]
|
||||
return width, height
|
||||
|
||||
# JPEG: \xff\xd8 — 需要扫描 SOF0/SOF2 标记
|
||||
if len(header) >= 2 and header[:2] == b"\xff\xd8":
|
||||
return _read_jpeg_dimensions(file_path)
|
||||
|
||||
return 0, 0
|
||||
|
||||
|
||||
def _read_jpeg_dimensions(file_path: str) -> tuple[int, int]:
|
||||
"""扫描 JPEG 文件查找 SOF0/SOF2 标记以获取尺寸。"""
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
f.read(2) # 跳过 SOI 标记
|
||||
while True:
|
||||
marker = f.read(2)
|
||||
if len(marker) < 2:
|
||||
return 0, 0
|
||||
if marker[0] != 0xFF:
|
||||
continue
|
||||
# SOF0 (0xC0) 或 SOF2 (0xC2) 包含尺寸信息
|
||||
if marker[1] in (0xC0, 0xC2):
|
||||
f.read(3) # 跳过 length(2) + precision(1)
|
||||
height = struct.unpack(">H", f.read(2))[0]
|
||||
width = struct.unpack(">H", f.read(2))[0]
|
||||
return width, height
|
||||
else:
|
||||
# 跳过此 marker 段
|
||||
length_bytes = f.read(2)
|
||||
if len(length_bytes) < 2:
|
||||
return 0, 0
|
||||
length = struct.unpack(">H", length_bytes)[0]
|
||||
f.read(length - 2)
|
||||
except (OSError, struct.error):
|
||||
return 0, 0
|
||||
|
||||
|
||||
def is_safe_ip(ip_str: str) -> bool:
|
||||
"""SSRF 防护:检查 IP 是否安全(非 loopback/RFC1918/link-local/ULA)。
|
||||
|
||||
Args:
|
||||
ip_str: IP 地址字符串
|
||||
|
||||
Returns:
|
||||
True 如果 IP 是安全的公网地址,False 如果在拒绝的网络段内
|
||||
"""
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return False # 无效 IP,拒绝
|
||||
|
||||
for network in _DENIED_NETWORKS:
|
||||
if ip in network:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def sanitize_markdown(content: str) -> str:
|
||||
"""净化 Markdown 内容 — 移除危险 HTML 标签和属性。
|
||||
|
||||
移除内容:
|
||||
- 危险标签(script/iframe/object/embed/style/form/input/textarea/button/link/meta)及其内容
|
||||
- javascript: 协议
|
||||
- 事件处理器属性(onclick/onload/onerror 等)
|
||||
|
||||
Args:
|
||||
content: 原始 Markdown/HTML 文本
|
||||
|
||||
Returns:
|
||||
净化后的文本
|
||||
"""
|
||||
# 移除危险标签及其内容(如 <script>...</script>)
|
||||
result = _DANGEROUS_TAG_RE.sub("", content)
|
||||
# 移除孤立的危险标签(如 <script>、<iframe ...>)
|
||||
result = _DANGEROUS_SOLO_RE.sub("", result)
|
||||
# 移除孤立的闭合标签(如 </script>)
|
||||
result = _DANGEROUS_CLOSE_RE.sub("", result)
|
||||
# 移除 javascript: 协议
|
||||
result = _JS_URL_RE.sub("", result)
|
||||
# 移除事件处理器属性
|
||||
result = _EVENT_HANDLER_RE.sub("", result)
|
||||
return result
|
||||
|
||||
|
||||
def sanitize_content(content: str, file_type: str) -> str:
|
||||
"""内容净化主入口 — 对文本格式应用 Markdown 净化。
|
||||
|
||||
Args:
|
||||
content: 文本内容
|
||||
file_type: 文件类型(扩展名,不含点),如 "md"、"html"、"pdf"
|
||||
|
||||
Returns:
|
||||
净化后的内容(二进制格式原样返回)
|
||||
"""
|
||||
if file_type in _TEXT_FORMATS:
|
||||
return sanitize_markdown(content)
|
||||
return content
|
||||
|
||||
|
||||
def parse_xml_safe(content: bytes) -> Any:
|
||||
"""安全解析 XML — 禁止 DTD/实体以防止 XXE 攻击。
|
||||
|
||||
Args:
|
||||
content: XML 字节内容
|
||||
|
||||
Returns:
|
||||
ElementTree Element
|
||||
|
||||
Raises:
|
||||
xml.etree.ElementTree.ParseError: 解析失败或包含 DTD
|
||||
"""
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
parser = ET.XMLParser(
|
||||
forbid_dtd=True,
|
||||
forbid_entities=True,
|
||||
forbid_external=True,
|
||||
)
|
||||
return ET.fromstring(content, parser=parser)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ALLOWED_FILE_TYPES",
|
||||
"MAX_FILE_SIZE",
|
||||
"MAX_IMAGE_PIXELS",
|
||||
"MAX_ZIP_DECOMPRESSED_SIZE",
|
||||
"MAX_ZIP_RATIO",
|
||||
"check_image_bomb",
|
||||
"check_zip_bomb",
|
||||
"is_safe_ip",
|
||||
"parse_xml_safe",
|
||||
"sanitize_content",
|
||||
"sanitize_markdown",
|
||||
"validate_file_size",
|
||||
"validate_file_type",
|
||||
]
|
||||
|
|
@ -0,0 +1,382 @@
|
|||
"""U3+U7 测试 — 文档处理管道。
|
||||
|
||||
测试场景:
|
||||
1. parse + segment 管道(mock 文件 I/O)
|
||||
2. preview 返回 chunk 列表
|
||||
3. vectorize 调用 embed model + vector store
|
||||
4. 失败时设置 error 状态
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.rag_platform.document_processor import DocumentProcessor
|
||||
from agentkit.rag_platform.models import DocumentStatus
|
||||
from agentkit.rag_platform.preview import PreviewResult, generate_preview
|
||||
|
||||
|
||||
class TestParseAndSegment:
|
||||
"""parse + segment 管道测试。"""
|
||||
|
||||
def test_parse_extracts_text(self, tmp_path):
|
||||
"""parse 从文件提取文本。"""
|
||||
# 创建测试文件
|
||||
file_path = tmp_path / "test.txt"
|
||||
file_path.write_text("Hello, world!\nThis is a test document.", encoding="utf-8")
|
||||
|
||||
processor = DocumentProcessor()
|
||||
text = processor.parse(str(file_path), "txt")
|
||||
|
||||
assert "Hello, world!" in text
|
||||
assert "test document" in text
|
||||
|
||||
def test_parse_applies_sanitization(self, tmp_path):
|
||||
"""parse 对文本格式应用内容净化。"""
|
||||
file_path = tmp_path / "test.md"
|
||||
file_path.write_text(
|
||||
"Hello <script>alert(1)</script> world",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
processor = DocumentProcessor()
|
||||
text = processor.parse(str(file_path), "md")
|
||||
|
||||
assert "<script>" not in text
|
||||
assert "Hello" in text
|
||||
assert "world" in text
|
||||
|
||||
def test_segment_splits_text(self):
|
||||
"""segment 将文本分段。"""
|
||||
processor = DocumentProcessor(chunk_size=100, chunk_overlap=10)
|
||||
text = "This is a sentence. " * 50 # 足够长以触发分段
|
||||
|
||||
with patch("llama_index.core.node_parser.SentenceSplitter") as mock_splitter_cls:
|
||||
mock_splitter = MagicMock()
|
||||
mock_splitter.split_text.return_value = ["chunk1", "chunk2", "chunk3"]
|
||||
mock_splitter_cls.return_value = mock_splitter
|
||||
|
||||
chunks = processor.segment(text, chunk_size=100, chunk_overlap=10)
|
||||
|
||||
assert chunks == ["chunk1", "chunk2", "chunk3"]
|
||||
mock_splitter.split_text.assert_called_once_with(text)
|
||||
|
||||
def test_segment_uses_custom_params(self):
|
||||
"""segment 使用自定义 chunk_size 和 chunk_overlap。"""
|
||||
processor = DocumentProcessor()
|
||||
|
||||
with patch("llama_index.core.node_parser.SentenceSplitter") as mock_splitter_cls:
|
||||
mock_splitter = MagicMock()
|
||||
mock_splitter.split_text.return_value = ["chunk"]
|
||||
mock_splitter_cls.return_value = mock_splitter
|
||||
|
||||
processor.segment("text", chunk_size=256, chunk_overlap=20)
|
||||
|
||||
# 验证 SentenceSplitter 使用了自定义参数
|
||||
call_kwargs = mock_splitter_cls.call_args.kwargs
|
||||
assert call_kwargs["chunk_size"] == 256
|
||||
assert call_kwargs["chunk_overlap"] == 20
|
||||
|
||||
def test_parse_file_not_found(self):
|
||||
"""parse 文件不存在时抛出 FileNotFoundError。"""
|
||||
processor = DocumentProcessor()
|
||||
with pytest.raises(FileNotFoundError):
|
||||
processor.parse("/nonexistent/file.txt", "txt")
|
||||
|
||||
|
||||
class TestPreview:
|
||||
"""preview 方法测试。"""
|
||||
|
||||
def test_preview_returns_chunks(self, tmp_path):
|
||||
"""preview 返回 chunk 字典列表。"""
|
||||
file_path = tmp_path / "test.txt"
|
||||
file_path.write_text("Test content for preview.", encoding="utf-8")
|
||||
|
||||
processor = DocumentProcessor(chunk_size=100, chunk_overlap=10)
|
||||
|
||||
with patch("llama_index.core.node_parser.SentenceSplitter") as mock_splitter_cls:
|
||||
mock_splitter = MagicMock()
|
||||
mock_splitter.split_text.return_value = ["chunk1", "chunk2"]
|
||||
mock_splitter_cls.return_value = mock_splitter
|
||||
|
||||
chunks = processor.preview(str(file_path), "txt")
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0]["index"] == 0
|
||||
assert chunks[0]["content"] == "chunk1"
|
||||
assert chunks[1]["index"] == 1
|
||||
assert chunks[1]["content"] == "chunk2"
|
||||
assert chunks[0]["metadata"] == {}
|
||||
|
||||
def test_preview_uses_default_params(self, tmp_path):
|
||||
"""preview 使用构造器默认参数。"""
|
||||
file_path = tmp_path / "test.txt"
|
||||
file_path.write_text("Test content.", encoding="utf-8")
|
||||
|
||||
processor = DocumentProcessor(chunk_size=512, chunk_overlap=50)
|
||||
|
||||
with patch("llama_index.core.node_parser.SentenceSplitter") as mock_splitter_cls:
|
||||
mock_splitter = MagicMock()
|
||||
mock_splitter.split_text.return_value = ["chunk"]
|
||||
mock_splitter_cls.return_value = mock_splitter
|
||||
|
||||
processor.preview(str(file_path), "txt")
|
||||
|
||||
call_kwargs = mock_splitter_cls.call_args.kwargs
|
||||
assert call_kwargs["chunk_size"] == 512
|
||||
assert call_kwargs["chunk_overlap"] == 50
|
||||
|
||||
|
||||
class TestGeneratePreview:
|
||||
"""generate_preview 函数测试。"""
|
||||
|
||||
def test_generate_preview_returns_preview_result(self, tmp_path):
|
||||
"""generate_preview 返回 PreviewResult。"""
|
||||
file_path = tmp_path / "test.txt"
|
||||
file_path.write_text("Test content for preview.", encoding="utf-8")
|
||||
|
||||
with patch("llama_index.core.node_parser.SentenceSplitter") as mock_splitter_cls:
|
||||
mock_splitter = MagicMock()
|
||||
mock_splitter.split_text.return_value = ["chunk1", "chunk2", "chunk3"]
|
||||
mock_splitter_cls.return_value = mock_splitter
|
||||
|
||||
result = generate_preview(str(file_path), "txt")
|
||||
|
||||
assert isinstance(result, PreviewResult)
|
||||
assert result.total_chunks == 3
|
||||
assert len(result.chunks) == 3
|
||||
assert result.chunks[0].index == 0
|
||||
assert result.chunks[0].content == "chunk1"
|
||||
assert result.chunks[2].index == 2
|
||||
assert result.chunks[2].content == "chunk3"
|
||||
assert result.document_id == "" # 预览阶段无 document_id
|
||||
|
||||
def test_generate_preview_with_custom_params(self, tmp_path):
|
||||
"""generate_preview 使用自定义分段参数。"""
|
||||
file_path = tmp_path / "test.txt"
|
||||
file_path.write_text("Test content.", encoding="utf-8")
|
||||
|
||||
with patch("llama_index.core.node_parser.SentenceSplitter") as mock_splitter_cls:
|
||||
mock_splitter = MagicMock()
|
||||
mock_splitter.split_text.return_value = ["chunk"]
|
||||
mock_splitter_cls.return_value = mock_splitter
|
||||
|
||||
generate_preview(str(file_path), "txt", chunk_size=256, chunk_overlap=20)
|
||||
|
||||
call_kwargs = mock_splitter_cls.call_args.kwargs
|
||||
assert call_kwargs["chunk_size"] == 256
|
||||
assert call_kwargs["chunk_overlap"] == 20
|
||||
|
||||
|
||||
class TestVectorize:
|
||||
"""vectorize 方法测试。"""
|
||||
|
||||
def _make_mock_embed_model(self):
|
||||
"""创建 mock embedding 模型。"""
|
||||
mock = MagicMock()
|
||||
mock.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
|
||||
return mock
|
||||
|
||||
def _make_mock_vector_store(self):
|
||||
"""创建 mock vector store。"""
|
||||
mock = MagicMock()
|
||||
mock.async_add = AsyncMock()
|
||||
return mock
|
||||
|
||||
async def test_vectorize_calls_embed_model(self):
|
||||
"""vectorize 调用 embedding 模型。"""
|
||||
processor = DocumentProcessor()
|
||||
mock_embed = self._make_mock_embed_model()
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
|
||||
chunks = ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
with patch("llama_index.core.schema.TextNode") as mock_node_cls:
|
||||
mock_nodes = [MagicMock() for _ in chunks]
|
||||
mock_node_cls.side_effect = mock_nodes
|
||||
|
||||
await processor.vectorize(chunks, "kb-1", "doc-1", mock_vs, mock_embed)
|
||||
|
||||
# 每个 chunk 都调用了 embed
|
||||
assert mock_embed.aget_text_embedding.await_count == 3
|
||||
|
||||
async def test_vectorize_calls_vector_store(self):
|
||||
"""vectorize 调用 vector store 的 async_add。"""
|
||||
processor = DocumentProcessor()
|
||||
mock_embed = self._make_mock_embed_model()
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
|
||||
chunks = ["chunk1", "chunk2"]
|
||||
|
||||
with patch("llama_index.core.schema.TextNode") as mock_node_cls:
|
||||
mock_nodes = [MagicMock() for _ in chunks]
|
||||
mock_node_cls.side_effect = mock_nodes
|
||||
|
||||
await processor.vectorize(chunks, "kb-1", "doc-1", mock_vs, mock_embed)
|
||||
|
||||
mock_vs.async_add.assert_awaited_once()
|
||||
# 验证传入的 nodes 数量正确
|
||||
added_nodes = mock_vs.async_add.call_args.args[0]
|
||||
assert len(added_nodes) == 2
|
||||
|
||||
async def test_vectorize_accepts_dict_chunks(self):
|
||||
"""vectorize 接受字典格式的 chunk。"""
|
||||
processor = DocumentProcessor()
|
||||
mock_embed = self._make_mock_embed_model()
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
|
||||
chunks = [
|
||||
{"index": 0, "content": "chunk1", "metadata": {}},
|
||||
{"index": 1, "content": "chunk2", "metadata": {}},
|
||||
]
|
||||
|
||||
with patch("llama_index.core.schema.TextNode") as mock_node_cls:
|
||||
mock_nodes = [MagicMock() for _ in chunks]
|
||||
mock_node_cls.side_effect = mock_nodes
|
||||
|
||||
await processor.vectorize(chunks, "kb-1", "doc-1", mock_vs, mock_embed)
|
||||
|
||||
assert mock_embed.aget_text_embedding.await_count == 2
|
||||
|
||||
async def test_vectorize_failure_propagates(self):
|
||||
"""vectorize 失败时抛出异常。"""
|
||||
processor = DocumentProcessor()
|
||||
mock_embed = MagicMock()
|
||||
mock_embed.aget_text_embedding = AsyncMock(side_effect=RuntimeError("embed failed"))
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
|
||||
chunks = ["chunk1", "chunk2"]
|
||||
|
||||
with patch("llama_index.core.schema.TextNode"):
|
||||
with pytest.raises(RuntimeError, match="embed failed"):
|
||||
await processor.vectorize(chunks, "kb-1", "doc-1", mock_vs, mock_embed)
|
||||
|
||||
async def test_vectorize_sets_metadata(self):
|
||||
"""vectorize 为每个 node 设置正确的 metadata。"""
|
||||
processor = DocumentProcessor()
|
||||
mock_embed = self._make_mock_embed_model()
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
|
||||
chunks = ["chunk1", "chunk2"]
|
||||
|
||||
with patch("llama_index.core.schema.TextNode") as mock_node_cls:
|
||||
mock_nodes = [MagicMock() for _ in chunks]
|
||||
mock_node_cls.side_effect = mock_nodes
|
||||
|
||||
await processor.vectorize(chunks, "kb-1", "doc-1", mock_vs, mock_embed)
|
||||
|
||||
# 验证 TextNode 被创建时使用了正确的 metadata
|
||||
assert mock_node_cls.call_count == 2
|
||||
first_call_kwargs = mock_node_cls.call_args_list[0].kwargs
|
||||
assert first_call_kwargs["metadata"]["kb_id"] == "kb-1"
|
||||
assert first_call_kwargs["metadata"]["document_id"] == "doc-1"
|
||||
assert first_call_kwargs["metadata"]["chunk_index"] == 0
|
||||
|
||||
|
||||
class TestProcess:
|
||||
"""process 方法(完整管道 + 状态转换)测试。"""
|
||||
|
||||
def _make_mock_store(self):
|
||||
"""创建 mock KBStore。"""
|
||||
store = MagicMock()
|
||||
store.update_document_status = AsyncMock()
|
||||
return store
|
||||
|
||||
async def test_process_success_transitions(self, tmp_path):
|
||||
"""process 成功时状态转换:parsing → segmenting → vectorizing → indexed。"""
|
||||
file_path = tmp_path / "test.txt"
|
||||
file_path.write_text("Test content.", encoding="utf-8")
|
||||
|
||||
processor = DocumentProcessor()
|
||||
mock_embed = MagicMock()
|
||||
mock_embed.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
|
||||
mock_vs = MagicMock()
|
||||
mock_vs.async_add = AsyncMock()
|
||||
mock_store = self._make_mock_store()
|
||||
|
||||
with patch("llama_index.core.node_parser.SentenceSplitter") as mock_splitter_cls:
|
||||
mock_splitter = MagicMock()
|
||||
mock_splitter.split_text.return_value = ["chunk1"]
|
||||
mock_splitter_cls.return_value = mock_splitter
|
||||
|
||||
with patch("llama_index.core.schema.TextNode"):
|
||||
await processor.process(
|
||||
str(file_path),
|
||||
"txt",
|
||||
"kb-1",
|
||||
"doc-1",
|
||||
mock_vs,
|
||||
mock_embed,
|
||||
mock_store,
|
||||
)
|
||||
|
||||
# 验证状态转换序列
|
||||
status_calls = mock_store.update_document_status.call_args_list
|
||||
assert len(status_calls) == 4
|
||||
assert status_calls[0].args[1] == DocumentStatus.parsing
|
||||
assert status_calls[1].args[1] == DocumentStatus.segmenting
|
||||
assert status_calls[2].args[1] == DocumentStatus.vectorizing
|
||||
assert status_calls[3].args[1] == DocumentStatus.indexed
|
||||
|
||||
async def test_process_failure_sets_failed_status(self, tmp_path):
|
||||
"""process 失败时状态置为 failed。"""
|
||||
file_path = tmp_path / "test.txt"
|
||||
file_path.write_text("Test content.", encoding="utf-8")
|
||||
|
||||
processor = DocumentProcessor()
|
||||
# embed 模型失败
|
||||
mock_embed = MagicMock()
|
||||
mock_embed.aget_text_embedding = AsyncMock(side_effect=RuntimeError("embed error"))
|
||||
mock_vs = MagicMock()
|
||||
mock_vs.async_add = AsyncMock()
|
||||
mock_store = self._make_mock_store()
|
||||
|
||||
with patch("llama_index.core.node_parser.SentenceSplitter") as mock_splitter_cls:
|
||||
mock_splitter = MagicMock()
|
||||
mock_splitter.split_text.return_value = ["chunk1"]
|
||||
mock_splitter_cls.return_value = mock_splitter
|
||||
|
||||
with patch("llama_index.core.schema.TextNode"):
|
||||
with pytest.raises(RuntimeError, match="embed error"):
|
||||
await processor.process(
|
||||
str(file_path),
|
||||
"txt",
|
||||
"kb-1",
|
||||
"doc-1",
|
||||
mock_vs,
|
||||
mock_embed,
|
||||
mock_store,
|
||||
)
|
||||
|
||||
# 验证最后状态为 failed,且包含 error_message
|
||||
status_calls = mock_store.update_document_status.call_args_list
|
||||
last_call = status_calls[-1]
|
||||
assert last_call.args[1] == DocumentStatus.failed
|
||||
assert "embed error" in last_call.kwargs.get("error_message", "")
|
||||
|
||||
async def test_process_parse_failure_sets_failed_status(self, tmp_path):
|
||||
"""parse 阶段失败时状态置为 failed。"""
|
||||
processor = DocumentProcessor()
|
||||
mock_embed = MagicMock()
|
||||
mock_vs = MagicMock()
|
||||
mock_store = self._make_mock_store()
|
||||
|
||||
# 文件不存在导致 parse 失败
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await processor.process(
|
||||
"/nonexistent/file.txt",
|
||||
"txt",
|
||||
"kb-1",
|
||||
"doc-1",
|
||||
mock_vs,
|
||||
mock_embed,
|
||||
mock_store,
|
||||
)
|
||||
|
||||
# 验证状态为 failed
|
||||
status_calls = mock_store.update_document_status.call_args_list
|
||||
assert len(status_calls) >= 2 # 至少 parsing + failed
|
||||
last_call = status_calls[-1]
|
||||
assert last_call.args[1] == DocumentStatus.failed
|
||||
|
|
@ -0,0 +1,346 @@
|
|||
"""U3+U7 测试 — 内容净化与上传安全。
|
||||
|
||||
测试场景:
|
||||
1. 文件类型白名单(允许类型通过,.exe/.sh 拒绝)
|
||||
2. 文件大小限制(超限拒绝)
|
||||
3. Markdown 净化(script 标签移除)
|
||||
4. SSRF IP 过滤(私有 IP 拒绝,公网 IP 允许)
|
||||
5. ZIP bomb 检测(高压缩比拒绝)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
import zipfile
|
||||
import zlib
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.rag_platform.sanitize import (
|
||||
ALLOWED_FILE_TYPES,
|
||||
MAX_FILE_SIZE,
|
||||
check_image_bomb,
|
||||
check_zip_bomb,
|
||||
is_safe_ip,
|
||||
sanitize_content,
|
||||
sanitize_markdown,
|
||||
validate_file_size,
|
||||
validate_file_type,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateFileType:
|
||||
"""文件类型白名单测试。"""
|
||||
|
||||
@pytest.mark.parametrize("filename", ["doc.pdf", "doc.docx", "doc.xlsx"])
|
||||
def test_allowed_office_types_pass(self, filename: str):
|
||||
"""允许的 Office 格式通过。"""
|
||||
assert validate_file_type(filename) == Path(filename).suffix[1:].lower()
|
||||
|
||||
@pytest.mark.parametrize("filename", ["notes.md", "data.csv", "page.html", "readme.txt"])
|
||||
def test_allowed_text_types_pass(self, filename: str):
|
||||
"""允许的文本格式通过。"""
|
||||
result = validate_file_type(filename)
|
||||
assert result == Path(filename).suffix[1:].lower()
|
||||
|
||||
def test_pdf_returns_pdf(self):
|
||||
"""PDF 文件返回 'pdf'。"""
|
||||
assert validate_file_type("report.pdf") == "pdf"
|
||||
|
||||
@pytest.mark.parametrize("filename", ["malware.exe", "script.sh", "shell.bat"])
|
||||
def test_dangerous_types_rejected(self, filename: str):
|
||||
"""危险文件类型被拒绝。"""
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
validate_file_type(filename)
|
||||
|
||||
def test_exe_rejected(self):
|
||||
""".exe 文件被拒绝。"""
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
validate_file_type("program.exe")
|
||||
|
||||
def test_sh_rejected(self):
|
||||
""".sh 文件被拒绝。"""
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
validate_file_type("script.sh")
|
||||
|
||||
def test_case_insensitive_extension(self):
|
||||
"""扩展名大小写不敏感。"""
|
||||
assert validate_file_type("DOC.PDF") == "pdf"
|
||||
assert validate_file_type("doc.MD") == "md"
|
||||
|
||||
def test_no_extension_rejected(self):
|
||||
"""无扩展名文件被拒绝。"""
|
||||
with pytest.raises(ValueError, match="not allowed"):
|
||||
validate_file_type("noextension")
|
||||
|
||||
def test_all_allowed_types_in_whitelist(self):
|
||||
"""白名单包含 8 种类型。"""
|
||||
expected = {".pdf", ".docx", ".xlsx", ".pptx", ".txt", ".md", ".csv", ".html"}
|
||||
assert set(ALLOWED_FILE_TYPES.keys()) == expected
|
||||
|
||||
|
||||
class TestValidateFileSize:
|
||||
"""文件大小限制测试。"""
|
||||
|
||||
def test_small_file_passes(self):
|
||||
"""小文件通过。"""
|
||||
validate_file_size(1024)
|
||||
validate_file_size(1)
|
||||
|
||||
def test_exact_limit_passes(self):
|
||||
"""刚好等于上限的文件通过。"""
|
||||
validate_file_size(MAX_FILE_SIZE)
|
||||
|
||||
def test_oversized_rejected(self):
|
||||
"""超限文件被拒绝。"""
|
||||
with pytest.raises(ValueError, match="exceeds limit"):
|
||||
validate_file_size(MAX_FILE_SIZE + 1)
|
||||
|
||||
def test_zero_size_rejected(self):
|
||||
"""零字节文件被拒绝。"""
|
||||
with pytest.raises(ValueError, match="must be positive"):
|
||||
validate_file_size(0)
|
||||
|
||||
def test_negative_size_rejected(self):
|
||||
"""负大小文件被拒绝。"""
|
||||
with pytest.raises(ValueError, match="must be positive"):
|
||||
validate_file_size(-1)
|
||||
|
||||
|
||||
class TestSanitizeMarkdown:
|
||||
"""Markdown 净化测试。"""
|
||||
|
||||
def test_script_tag_removed(self):
|
||||
"""script 标签及其内容被移除。"""
|
||||
content = "Hello <script>alert('xss')</script> world"
|
||||
result = sanitize_markdown(content)
|
||||
assert "<script>" not in result
|
||||
assert "alert" not in result
|
||||
assert "Hello" in result
|
||||
assert "world" in result
|
||||
|
||||
def test_iframe_tag_removed(self):
|
||||
"""iframe 标签被移除。"""
|
||||
content = "Text <iframe src='evil.com'></iframe> more"
|
||||
result = sanitize_markdown(content)
|
||||
assert "<iframe" not in result
|
||||
assert "evil.com" not in result
|
||||
|
||||
def test_object_tag_removed(self):
|
||||
"""object 标签被移除。"""
|
||||
content = "<object data='evil.swf'></object>"
|
||||
result = sanitize_markdown(content)
|
||||
assert "<object" not in result
|
||||
|
||||
def test_embed_tag_removed(self):
|
||||
"""embed 标签被移除。"""
|
||||
content = "<embed src='evil.swf'>"
|
||||
result = sanitize_markdown(content)
|
||||
assert "<embed" not in result
|
||||
|
||||
def test_safe_content_preserved(self):
|
||||
"""安全内容保留。"""
|
||||
content = "# Title\n\nThis is **bold** and *italic*."
|
||||
result = sanitize_markdown(content)
|
||||
assert result == content
|
||||
|
||||
def test_javascript_protocol_removed(self):
|
||||
"""javascript: 协议被移除。"""
|
||||
content = "<a href='javascript:alert(1)'>click</a>"
|
||||
result = sanitize_markdown(content)
|
||||
assert "javascript:" not in result.lower()
|
||||
|
||||
def test_event_handler_removed(self):
|
||||
"""事件处理器属性被移除。"""
|
||||
content = "<div onclick='alert(1)'>text</div>"
|
||||
result = sanitize_markdown(content)
|
||||
assert "onclick" not in result
|
||||
|
||||
def test_multiple_dangerous_tags_removed(self):
|
||||
"""多个危险标签同时移除。"""
|
||||
content = (
|
||||
"<script>bad()</script><iframe src='x'></iframe>safe text<object data='y'></object>"
|
||||
)
|
||||
result = sanitize_markdown(content)
|
||||
assert "<script" not in result
|
||||
assert "<iframe" not in result
|
||||
assert "<object" not in result
|
||||
assert "safe text" in result
|
||||
|
||||
def test_multiline_script_removed(self):
|
||||
"""多行 script 标签被移除。"""
|
||||
content = "before\n<script>\nline1\nline2\n</script>\nafter"
|
||||
result = sanitize_markdown(content)
|
||||
assert "<script>" not in result
|
||||
assert "line1" not in result
|
||||
assert "line2" not in result
|
||||
assert "before" in result
|
||||
assert "after" in result
|
||||
|
||||
|
||||
class TestSanitizeContent:
|
||||
"""sanitize_content 主入口测试。"""
|
||||
|
||||
def test_markdown_sanitized(self):
|
||||
"""Markdown 格式应用净化。"""
|
||||
content = "Hello <script>alert(1)</script>"
|
||||
result = sanitize_content(content, "md")
|
||||
assert "<script>" not in result
|
||||
|
||||
def test_html_sanitized(self):
|
||||
"""HTML 格式应用净化。"""
|
||||
content = "<iframe src='evil'></iframe>"
|
||||
result = sanitize_content(content, "html")
|
||||
assert "<iframe" not in result
|
||||
|
||||
def test_text_sanitized(self):
|
||||
"""纯文本格式应用净化。"""
|
||||
content = "text <script>x</script>"
|
||||
result = sanitize_content(content, "txt")
|
||||
assert "<script>" not in result
|
||||
|
||||
def test_binary_format_not_sanitized(self):
|
||||
"""二进制格式(pdf)不应用净化。"""
|
||||
content = "raw <script>content</script>"
|
||||
result = sanitize_content(content, "pdf")
|
||||
assert result == content # 原样返回
|
||||
|
||||
|
||||
class TestIsSafeIp:
|
||||
"""SSRF IP 过滤测试。"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ip",
|
||||
[
|
||||
"127.0.0.1",
|
||||
"127.0.1.1",
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"172.31.255.255",
|
||||
"192.168.1.1",
|
||||
"169.254.1.1",
|
||||
"0.0.0.0",
|
||||
],
|
||||
)
|
||||
def test_private_ips_blocked(self, ip: str):
|
||||
"""私有/loopback IP 被拒绝。"""
|
||||
assert is_safe_ip(ip) is False
|
||||
|
||||
def test_ipv6_loopback_blocked(self):
|
||||
"""IPv6 loopback 被拒绝。"""
|
||||
assert is_safe_ip("::1") is False
|
||||
|
||||
def test_ipv6_ula_blocked(self):
|
||||
"""IPv6 ULA 被拒绝。"""
|
||||
assert is_safe_ip("fd00::1") is False
|
||||
|
||||
def test_ipv6_link_local_blocked(self):
|
||||
"""IPv6 link-local 被拒绝。"""
|
||||
assert is_safe_ip("fe80::1") is False
|
||||
|
||||
@pytest.mark.parametrize("ip", ["8.8.8.8", "1.1.1.1", "203.0.113.1"])
|
||||
def test_public_ips_allowed(self, ip: str):
|
||||
"""公网 IP 允许。"""
|
||||
assert is_safe_ip(ip) is True
|
||||
|
||||
def test_invalid_ip_blocked(self):
|
||||
"""无效 IP 被拒绝。"""
|
||||
assert is_safe_ip("not-an-ip") is False
|
||||
|
||||
def test_empty_ip_blocked(self):
|
||||
"""空字符串被拒绝。"""
|
||||
assert is_safe_ip("") is False
|
||||
|
||||
|
||||
class TestCheckZipBomb:
|
||||
"""ZIP bomb 检测测试。"""
|
||||
|
||||
def test_normal_zip_passes(self, tmp_path):
|
||||
"""正常 ZIP 文件通过。"""
|
||||
zip_path = tmp_path / "normal.zip"
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr("file.txt", b"Hello, world! " * 100)
|
||||
|
||||
# 不应抛出异常
|
||||
check_zip_bomb(str(zip_path))
|
||||
|
||||
def test_high_ratio_rejected(self, tmp_path):
|
||||
"""高压缩比 ZIP 被拒绝(ZIP bomb)。"""
|
||||
zip_path = tmp_path / "bomb.zip"
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
# 10MB 全零数据 — 压缩比极高
|
||||
zf.writestr("bomb.txt", b"\x00" * 10_000_000)
|
||||
|
||||
with pytest.raises(ValueError, match="Zip bomb"):
|
||||
check_zip_bomb(str(zip_path))
|
||||
|
||||
def test_large_uncompressed_rejected_via_mock(self, tmp_path):
|
||||
"""解压后总大小超限被拒绝(使用 mock 避免实际写入大文件)。"""
|
||||
zip_path = str(tmp_path / "large.zip")
|
||||
|
||||
# 创建 mock ZipFile 返回超大文件信息
|
||||
mock_info = MagicMock()
|
||||
mock_info.filename = "large.bin"
|
||||
mock_info.compress_size = 10 * 1024 * 1024 # 10MB compressed
|
||||
mock_info.file_size = 600 * 1024 * 1024 # 600MB uncompressed (> 500MB limit)
|
||||
|
||||
mock_zipfile = MagicMock()
|
||||
mock_zipfile.__enter__ = MagicMock(return_value=mock_zipfile)
|
||||
mock_zipfile.__exit__ = MagicMock(return_value=False)
|
||||
mock_zipfile.infolist.return_value = [mock_info]
|
||||
|
||||
with patch("agentkit.rag_platform.sanitize.zipfile.ZipFile", return_value=mock_zipfile):
|
||||
with pytest.raises(ValueError, match="Zip bomb"):
|
||||
check_zip_bomb(zip_path)
|
||||
|
||||
def test_docx_zip_format_works(self, tmp_path):
|
||||
""".docx 格式(本质是 ZIP)能被检测。"""
|
||||
zip_path = tmp_path / "fake.docx"
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
zf.writestr("document.xml", "<doc/>")
|
||||
|
||||
# 正常 docx 不应触发 zip bomb
|
||||
check_zip_bomb(str(zip_path))
|
||||
|
||||
|
||||
class TestCheckImageBomb:
|
||||
"""Image bomb 检测测试。"""
|
||||
|
||||
def _make_png(self, path: Path, width: int, height: int) -> None:
|
||||
"""创建一个最小 PNG 文件(仅 IHDR)。"""
|
||||
png_header = b"\x89PNG\r\n\x1a\n"
|
||||
ihdr_data = struct.pack(">II", width, height) + b"\x08\x02\x00\x00\x00"
|
||||
ihdr_crc = zlib.crc32(b"IHDR" + ihdr_data) & 0xFFFFFFFF
|
||||
ihdr_chunk = struct.pack(">I", 13) + b"IHDR" + ihdr_data + struct.pack(">I", ihdr_crc)
|
||||
path.write_bytes(png_header + ihdr_chunk)
|
||||
|
||||
def test_small_png_passes(self, tmp_path):
|
||||
"""小 PNG 图片通过。"""
|
||||
img_path = tmp_path / "small.png"
|
||||
self._make_png(img_path, 1, 1)
|
||||
|
||||
# 不应抛出异常
|
||||
check_image_bomb(str(img_path))
|
||||
|
||||
def test_large_png_rejected(self, tmp_path):
|
||||
"""超大 PNG(像素数 > 100MP)被拒绝。"""
|
||||
img_path = tmp_path / "huge.png"
|
||||
# 20000x20000 = 400MP > 100MP
|
||||
self._make_png(img_path, 20000, 20000)
|
||||
|
||||
with pytest.raises(ValueError, match="Image bomb"):
|
||||
check_image_bomb(str(img_path))
|
||||
|
||||
def test_unknown_format_passes(self, tmp_path):
|
||||
"""无法识别的格式视为安全。"""
|
||||
img_path = tmp_path / "unknown.bin"
|
||||
img_path.write_bytes(b"\x00" * 32)
|
||||
|
||||
# 不应抛出异常
|
||||
check_image_bomb(str(img_path))
|
||||
|
||||
def test_nonexistent_file_passes(self, tmp_path):
|
||||
"""不存在的文件视为安全(不抛异常)。"""
|
||||
check_image_bomb(str(tmp_path / "nonexistent.png"))
|
||||
Loading…
Reference in New Issue