feat(rag_platform): U3+U7 — document processing pipeline + upload security

U3: Document processing pipeline (document_processor.py)
- DocumentProcessor class wrapping parse → segment → vectorize
- parse() uses memory/document_loader.py for multi-format extraction
- segment() uses LlamaIndex SentenceSplitter
- preview() returns chunks for read-only preview (no vectorization)
- vectorize() embeds chunks and stores in pgvector (all-or-nothing)
- process() orchestrates full pipeline with status transitions:
  pending → parsing → segmenting → vectorizing → indexed | failed

U7: Upload security & content sanitization (sanitize.py)
- ALLOWED_FILE_TYPES whitelist (pdf/docx/xlsx/pptx/txt/md/csv/html)
- MAX_FILE_SIZE 50MB limit
- validate_file_type() / validate_file_size() guards
- check_zip_bomb() for ZIP-based formats (ratio > 100:1 or > 500MB)
- check_image_bomb() for pixel count > 100MP (PNG/JPEG/GIF header parsing)
- is_safe_ip() SSRF protection (loopback/RFC1918/link-local/ULA denied)
- sanitize_markdown() removes dangerous HTML tags (script/iframe/object/embed)
- sanitize_content() main entry point for text format sanitization
- parse_xml_safe() XXE protection (forbid_dtd/forbid_entities/forbid_external)

Preview API (preview.py)
- PreviewChunk / PreviewResult Pydantic models
- generate_preview() returns read-only segmentation preview

Tests: 112 tests passing (45 new + 67 existing)
- test_sanitize.py: file type/size, markdown sanitization, SSRF, zip/image bomb
- test_document_processor.py: parse/segment, preview, vectorize, failure status
This commit is contained in:
chiguyong 2026-06-25 11:21:42 +08:00
parent c1a21f57a1
commit b55c896794
6 changed files with 1450 additions and 0 deletions

View File

@ -1,5 +1,10 @@
"""RAG 平台模块 — 企业知识库场景的工业级 RAG 管道。"""
from agentkit.rag_platform.document_processor import (
DEFAULT_CHUNK_OVERLAP,
DEFAULT_CHUNK_SIZE,
DocumentProcessor,
)
from agentkit.rag_platform.models import (
Chunk,
Document,
@ -9,13 +14,44 @@ from agentkit.rag_platform.models import (
QueryMode,
QueryResult,
)
from agentkit.rag_platform.preview import (
PreviewChunk,
PreviewResult,
generate_preview,
)
from agentkit.rag_platform.sanitize import (
ALLOWED_FILE_TYPES,
MAX_FILE_SIZE,
check_image_bomb,
check_zip_bomb,
is_safe_ip,
sanitize_content,
sanitize_markdown,
validate_file_size,
validate_file_type,
)
__all__ = [
"ALLOWED_FILE_TYPES",
"DEFAULT_CHUNK_OVERLAP",
"DEFAULT_CHUNK_SIZE",
"Chunk",
"Document",
"DocumentProcessor",
"DocumentStatus",
"KBStatus",
"KnowledgeBase",
"MAX_FILE_SIZE",
"PreviewChunk",
"PreviewResult",
"QueryMode",
"QueryResult",
"check_image_bomb",
"check_zip_bomb",
"generate_preview",
"is_safe_ip",
"sanitize_content",
"sanitize_markdown",
"validate_file_size",
"validate_file_type",
]

View File

@ -0,0 +1,250 @@
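"""Document processing pipeline: parse → segment → vectorize.

Wraps DocumentLoader (multi-format parsing), LlamaIndex SentenceSplitter
(segmentation), and PGVectorStore (vector storage).

State machine: pending → parsing → segmenting → vectorizing → indexed | failed
Failure semantics: vectorize is all-or-nothing; on failure it raises, and the
caller is responsible for setting the status to failed.
"""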
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from agentkit.memory.document_loader import DocumentLoader
from agentkit.rag_platform.models import DocumentStatus
from agentkit.rag_platform.sanitize import sanitize_content
if TYPE_CHECKING:
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.postgres import PGVectorStore
from agentkit.rag_platform.store import KBStore
logger = logging.getLogger(__name__)

# Default segmentation parameters (documented by the class as token counts).
DEFAULT_CHUNK_SIZE = 512
DEFAULT_CHUNK_OVERLAP = 50


class DocumentProcessor:
    """Document pipeline: parse -> segment -> vectorize.

    Args:
        chunk_size: Segment size (tokens), default 512.
        chunk_overlap: Overlap between adjacent segments (tokens), default 50.
    """

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    ) -> None:
        self._loader = DocumentLoader()
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap

    def parse(self, file_path: str, file_type: str) -> str:
        """Extract text from a file and sanitize it.

        Loading is delegated to ``DocumentLoader.load``; the loaded content is
        then passed through ``sanitize_content`` together with ``file_type``.

        Args:
            file_path: Path of the file to load.
            file_type: File type (extension without the dot); selects whether
                sanitization is applied.

        Returns:
            The extracted, sanitized text.

        Raises:
            FileNotFoundError: The file does not exist (raised by the loader).
            ValueError: Presumably raised by the loader when the content is
                over its size limit -- the loader is not visible here.
        """
        doc = self._loader.load(file_path)
        return sanitize_content(doc.content, file_type)

    def segment(
        self,
        text: str,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    ) -> list[str]:
        """Split text into chunks with LlamaIndex ``SentenceSplitter``.

        Note that the defaults here are the module constants, not the values
        given to the constructor; ``preview`` and ``process`` pass the
        instance values explicitly.

        Args:
            text: Text to split.
            chunk_size: Segment size (tokens).
            chunk_overlap: Overlap between adjacent segments (tokens).

        Returns:
            The list of chunk texts.
        """
        # Imported lazily so the module can be imported without llama_index.
        from llama_index.core.node_parser import SentenceSplitter

        splitter = SentenceSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        chunks = splitter.split_text(text)
        logger.debug(
            "Segmented text into %d chunks (chunk_size=%d, overlap=%d)",
            len(chunks),
            chunk_size,
            chunk_overlap,
        )
        return chunks

    def preview(
        self,
        file_path: str,
        file_type: str,
        chunk_size: int | None = None,
        chunk_overlap: int | None = None,
    ) -> list[dict[str, Any]]:
        """Parse and segment a file for a read-only preview (no vectorization).

        Args:
            file_path: Path of the file to preview.
            file_type: File type (extension without the dot).
            chunk_size: Segment size; ``None`` falls back to the constructor value.
            chunk_overlap: Segment overlap; ``None`` falls back to the
                constructor value. An explicit ``0`` is honoured.

        Returns:
            A list of chunk dicts: ``[{index, content, metadata}, ...]``.
        """
        # Compare against None rather than using ``or``: 0 is a legitimate
        # overlap and must not be replaced by the constructor default.
        cs = self._chunk_size if chunk_size is None else chunk_size
        co = self._chunk_overlap if chunk_overlap is None else chunk_overlap
        text = self.parse(file_path, file_type)
        chunks = self.segment(text, cs, co)
        return [{"index": i, "content": c, "metadata": {}} for i, c in enumerate(chunks)]

    async def vectorize(
        self,
        chunks: list[str] | list[dict[str, Any]],
        kb_id: str,
        document_id: str,
        vector_store: "PGVectorStore",
        embed_model: "BaseEmbedding",
    ) -> list["TextNode"]:
        """Embed chunks and store them in the vector store.

        All-or-nothing: every chunk is embedded before anything is written, so
        an embedding failure raises before ``vector_store.async_add`` is called.

        Args:
            chunks: Chunk texts, or chunk dicts whose ``"content"`` key holds the text.
            kb_id: Knowledge base ID, recorded in each node's metadata.
            document_id: Document ID, recorded in each node's metadata.
            vector_store: LlamaIndex PGVectorStore instance.
            embed_model: LlamaIndex embedding model.

        Returns:
            The TextNode list passed to the vector store.

        Raises:
            Exception: Embedding or storage failed.
        """
        from llama_index.core.schema import TextNode

        # Normalise every chunk to plain text.
        chunk_texts: list[str] = []
        for chunk in chunks:
            if isinstance(chunk, str):
                chunk_texts.append(chunk)
            elif isinstance(chunk, dict):
                chunk_texts.append(chunk.get("content", ""))
            else:
                chunk_texts.append(str(chunk))

        nodes: list[TextNode] = [
            TextNode(
                text=text,
                metadata={
                    "kb_id": kb_id,
                    "document_id": document_id,
                    "chunk_index": i,
                },
            )
            for i, text in enumerate(chunk_texts)
        ]

        # Embed first; any failure propagates before the store is touched.
        for node in nodes:
            node.embedding = await embed_model.aget_text_embedding(node.text)

        await vector_store.async_add(nodes)
        logger.info(
            "Vectorized %d chunks for document=%s kb=%s",
            len(nodes),
            document_id,
            kb_id,
        )
        return nodes

    async def process(
        self,
        file_path: str,
        file_type: str,
        kb_id: str,
        document_id: str,
        vector_store: "PGVectorStore",
        embed_model: "BaseEmbedding",
        store: "KBStore",
    ) -> None:
        """Run the full pipeline (parse -> segment -> vectorize) with status updates.

        Status sequence on success: parsing -> segmenting -> vectorizing -> indexed.
        On failure the status is set to failed with the error text, and the
        original exception is re-raised.

        Args:
            file_path: Path of the file to process.
            file_type: File type (extension without the dot).
            kb_id: Knowledge base ID.
            document_id: Document ID.
            vector_store: LlamaIndex PGVectorStore instance.
            embed_model: LlamaIndex embedding model.
            store: KBStore instance used for status updates.

        Raises:
            Exception: Any stage failed; a best-effort attempt to record the
                failed status has been made.
        """
        try:
            await store.update_document_status(document_id, DocumentStatus.parsing)
            text = self.parse(file_path, file_type)

            await store.update_document_status(document_id, DocumentStatus.segmenting)
            chunks = self.segment(text, self._chunk_size, self._chunk_overlap)

            await store.update_document_status(document_id, DocumentStatus.vectorizing)
            await self.vectorize(chunks, kb_id, document_id, vector_store, embed_model)

            await store.update_document_status(document_id, DocumentStatus.indexed)
            logger.info("Document %s processed successfully", document_id)
        except Exception as e:
            logger.error("Document %s processing failed: %s", document_id, e)
            # Recording the failure is best-effort: if the store is itself the
            # thing that broke, the pipeline error must still reach the caller.
            try:
                await store.update_document_status(
                    document_id,
                    DocumentStatus.failed,
                    # Some exceptions stringify to "", so fall back to repr.
                    error_message=str(e) or repr(e),
                )
            except Exception:
                logger.exception("Could not mark document %s as failed", document_id)
            raise
__all__ = [
"DEFAULT_CHUNK_OVERLAP",
"DEFAULT_CHUNK_SIZE",
"DocumentProcessor",
]

View File

@ -0,0 +1,86 @@
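"""Segmentation preview API: read-only preview of how a document is split (no vectorization).

Before uploading a document, users can preview the segmentation result and
adjust chunk_size/chunk_overlap before submitting it for real.
"""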
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from agentkit.rag_platform.document_processor import (
DEFAULT_CHUNK_OVERLAP,
DEFAULT_CHUNK_SIZE,
DocumentProcessor,
)
class PreviewChunk(BaseModel):
    """A single chunk entry of a segmentation preview."""

    model_config = ConfigDict()

    # Position of the chunk in the preview list.
    index: int
    # Text of the chunk.
    content: str
    # Free-form metadata; each instance gets its own empty dict by default.
    metadata: dict[str, Any] = Field(default_factory=dict)
class PreviewResult(BaseModel):
    """Result of a segmentation preview."""

    model_config = ConfigDict()

    # Empty by default: no document has been created at preview time.
    document_id: str = ""
    # The previewed chunks.
    chunks: list[PreviewChunk]
    # Number of chunks in the preview.
    total_chunks: int
def generate_preview(
    file_path: str,
    file_type: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> PreviewResult:
    """Build a read-only segmentation preview (nothing is vectorized).

    Args:
        file_path: Path of the file to preview.
        file_type: File type (extension without the dot), e.g. "pdf", "md".
        chunk_size: Segment size (tokens), default 512.
        chunk_overlap: Segment overlap (tokens), default 50.

    Returns:
        A PreviewResult holding the chunk list and the chunk count.

    Raises:
        FileNotFoundError: The file does not exist.
        ValueError: Parsing failed.
    """
    segmenter = DocumentProcessor(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    raw_chunks = segmenter.preview(
        file_path,
        file_type,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

    # Convert the plain dicts returned by the processor into typed models.
    preview_items: list[PreviewChunk] = []
    for raw in raw_chunks:
        preview_items.append(
            PreviewChunk(
                index=raw["index"],
                content=raw["content"],
                metadata=raw.get("metadata", {}),
            )
        )

    return PreviewResult(
        document_id="",
        chunks=preview_items,
        total_chunks=len(raw_chunks),
    )
__all__ = [
"PreviewChunk",
"PreviewResult",
"generate_preview",
]

View File

@ -0,0 +1,350 @@
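"""Content sanitization and upload security: file-type whitelist, size limit, zip/image bomb detection, SSRF protection, Markdown sanitization.

Security boundaries:
- File-type whitelist: pdf/docx/xlsx/pptx/txt/md/csv/html
- File size limit: 50MB by default
- ZIP bomb detection (.docx/.xlsx/.pptx are ZIP containers)
- Image bomb detection: pixel count > 100MP
- SSRF protection: loopback/RFC1918/link-local/ULA addresses are rejected
- Markdown sanitization of dangerous HTML tags (script/iframe/object/embed)
- XXE protection: XML parsing with DTDs forbidden (forbid_dtd)
"""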
from __future__ import annotations
import ipaddress
import re
import struct
import zipfile
from pathlib import Path
from typing import Any
# File-type whitelist: extension -> MIME type
ALLOWED_FILE_TYPES: dict[str, str] = {
    ".pdf": "application/pdf",
    ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ".txt": "text/plain",
    ".md": "text/markdown",
    ".csv": "text/csv",
    ".html": "text/html",
}
# File size limit: 50MB
MAX_FILE_SIZE = 50 * 1024 * 1024
# ZIP bomb detection thresholds
MAX_ZIP_RATIO = 100  # maximum compression ratio, 100:1
MAX_ZIP_DECOMPRESSED_SIZE = 500 * 1024 * 1024  # maximum uncompressed size, 500MB
# Image bomb detection threshold
MAX_IMAGE_PIXELS = 100_000_000  # 100MP
# Text formats that are sanitized (extensions without the dot)
_TEXT_FORMATS = frozenset({"md", "html", "txt", "csv"})
# Dangerous HTML elements, removed together with their content.
# NOTE(review): there is no word boundary after the tag name, so longer names
# that merely start with one of these (e.g. <metadata>) also match.
_DANGEROUS_TAG_RE = re.compile(
    r"<\s*(script|iframe|object|embed|style|form|input|textarea|button|link|meta)"
    r"[^>]*>.*?<\s*/\s*\1\s*>",
    re.IGNORECASE | re.DOTALL,
)
# Dangerous self-closing / unpaired opening tags
_DANGEROUS_SOLO_RE = re.compile(
    r"<\s*(script|iframe|object|embed|style|form|input|textarea|button|link|meta)[^>]*/?>",
    re.IGNORECASE,
)
# Dangerous closing tags
_DANGEROUS_CLOSE_RE = re.compile(
    r"<\s*/\s*(script|iframe|object|embed|style|form|input|textarea|button|link|meta)\s*>",
    re.IGNORECASE,
)
# The javascript: URL scheme
_JS_URL_RE = re.compile(r"javascript\s*:", re.IGNORECASE)
# Event-handler attributes (onclick=, onload=, onerror=, ...)
# NOTE(review): this is not limited to text inside a tag, so plain text such
# as " only=1" matches as well.
_EVENT_HANDLER_RE = re.compile(
    r"\s+on\w+\s*=\s*(?:\"[^\"]*\"|'[^']*'|[^\s>]+)",
    re.IGNORECASE,
)
# SSRF protection: networks that are rejected
_DENIED_NETWORKS = [
    ipaddress.ip_network("127.0.0.0/8"),  # IPv4 loopback
    ipaddress.ip_network("10.0.0.0/8"),  # RFC1918
    ipaddress.ip_network("172.16.0.0/12"),  # RFC1918
    ipaddress.ip_network("192.168.0.0/16"),  # RFC1918
    ipaddress.ip_network("169.254.0.0/16"),  # link-local
    ipaddress.ip_network("0.0.0.0/8"),  # "this network"
    ipaddress.ip_network("::1/128"),  # IPv6 loopback
    ipaddress.ip_network("fc00::/7"),  # IPv6 ULA
    ipaddress.ip_network("fe80::/10"),  # IPv6 link-local
    ipaddress.ip_network("::/128"),  # IPv6 unspecified
]
def validate_file_type(filename: str, file_type: str | None = None) -> str:
    """Check that a file's extension is on the whitelist.

    Args:
        filename: File name; only its lower-cased suffix is examined.
        file_type: Optional MIME type. It is accepted but not read by this
            function.

    Returns:
        The normalized file type: the extension without the dot, e.g. "pdf".

    Raises:
        ValueError: The extension is not in ALLOWED_FILE_TYPES.
    """
    suffix = Path(filename).suffix.lower()
    if suffix in ALLOWED_FILE_TYPES:
        return suffix.lstrip(".")

    permitted = ", ".join(sorted(ALLOWED_FILE_TYPES))
    raise ValueError(f"File type '{suffix}' is not allowed. Allowed types: {permitted}")
def validate_file_size(file_size: int) -> None:
    """Check that a file size is positive and within MAX_FILE_SIZE.

    Args:
        file_size: File size in bytes.

    Raises:
        ValueError: The size is zero, negative, or larger than MAX_FILE_SIZE.
    """
    # The two conditions are mutually exclusive, so their order is irrelevant.
    if file_size <= 0:
        raise ValueError(f"File size must be positive, got {file_size}")
    if file_size > MAX_FILE_SIZE:
        message = (
            f"File size {file_size} bytes exceeds limit {MAX_FILE_SIZE} bytes "
            f"({MAX_FILE_SIZE // (1024 * 1024)}MB)"
        )
        raise ValueError(message)
def check_zip_bomb(file_path: str) -> None:
    """Detect ZIP bombs in ZIP-based files (.docx/.xlsx/.pptx are ZIP containers).

    The check reads only the sizes declared in the archive's entry list; it
    does not decompress anything.

    Rejection conditions:
    - an entry declares uncompressed data but zero compressed bytes
    - a single entry's compression ratio > MAX_ZIP_RATIO
    - total uncompressed size > MAX_ZIP_DECOMPRESSED_SIZE
    - overall compression ratio > MAX_ZIP_RATIO

    Args:
        file_path: Path of the ZIP-format file.

    Raises:
        ValueError: A ZIP bomb was detected.
        zipfile.BadZipFile: The file is not a valid ZIP.
    """
    with zipfile.ZipFile(file_path) as zf:
        total_compressed = 0
        total_uncompressed = 0
        for info in zf.infolist():
            total_compressed += info.compress_size
            total_uncompressed += info.file_size

            if info.compress_size > 0:
                ratio = info.file_size / info.compress_size
                if ratio > MAX_ZIP_RATIO:
                    raise ValueError(
                        f"Zip bomb detected: file '{info.filename}' has compression "
                        f"ratio {ratio:.1f}:1 (max {MAX_ZIP_RATIO}:1)"
                    )
            elif info.file_size > 0:
                # The ratio is undefined here, so without this branch such an
                # entry would slip past the per-entry check altogether.
                raise ValueError(
                    f"Zip bomb detected: file '{info.filename}' declares "
                    f"{info.file_size} uncompressed bytes but zero compressed bytes"
                )

        if total_uncompressed > MAX_ZIP_DECOMPRESSED_SIZE:
            raise ValueError(
                f"Zip bomb detected: total uncompressed size "
                f"{total_uncompressed} bytes exceeds limit "
                f"{MAX_ZIP_DECOMPRESSED_SIZE} bytes"
            )

        if total_compressed > 0:
            overall_ratio = total_uncompressed / total_compressed
            if overall_ratio > MAX_ZIP_RATIO:
                raise ValueError(
                    f"Zip bomb detected: overall compression ratio "
                    f"{overall_ratio:.1f}:1 (max {MAX_ZIP_RATIO}:1)"
                )
def check_image_bomb(image_path: str) -> None:
    """Reject images whose pixel count exceeds MAX_IMAGE_PIXELS.

    Dimensions are taken from the file header (PNG, JPEG, GIF) without PIL.
    Files that cannot be read, or whose dimensions cannot be determined, are
    treated as safe: the function returns without raising.

    Args:
        image_path: Path of the image file.

    Raises:
        ValueError: The pixel count exceeds MAX_IMAGE_PIXELS.
    """
    try:
        with open(image_path, "rb") as stream:
            head = stream.read(32)
    except OSError:
        # Unreadable file: treated as safe.
        return

    width, height = _read_image_dimensions(head, image_path)
    if min(width, height) <= 0:
        # Dimensions unknown: treated as safe.
        return

    pixels = width * height
    if pixels <= MAX_IMAGE_PIXELS:
        return
    raise ValueError(
        f"Image bomb detected: {width}x{height} = {pixels} pixels "
        f"exceeds limit {MAX_IMAGE_PIXELS} pixels (100MP)"
    )
def _read_image_dimensions(header: bytes, file_path: str) -> tuple[int, int]:
"""从文件头读取图片尺寸PNG/GIF/JPEG"""
# PNG: \x89PNG\r\n\x1a\n + IHDR (width@16, height@20, big-endian)
if len(header) >= 24 and header[:8] == b"\x89PNG\r\n\x1a\n":
width = struct.unpack(">I", header[16:20])[0]
height = struct.unpack(">I", header[20:24])[0]
return width, height
# GIF: GIF87a/GIF89a (width@6, height@8, little-endian)
if len(header) >= 10 and header[:6] in (b"GIF87a", b"GIF89a"):
width = struct.unpack("<H", header[6:8])[0]
height = struct.unpack("<H", header[8:10])[0]
return width, height
# JPEG: \xff\xd8 — 需要扫描 SOF0/SOF2 标记
if len(header) >= 2 and header[:2] == b"\xff\xd8":
return _read_jpeg_dimensions(file_path)
return 0, 0
def _read_jpeg_dimensions(file_path: str) -> tuple[int, int]:
"""扫描 JPEG 文件查找 SOF0/SOF2 标记以获取尺寸。"""
try:
with open(file_path, "rb") as f:
f.read(2) # 跳过 SOI 标记
while True:
marker = f.read(2)
if len(marker) < 2:
return 0, 0
if marker[0] != 0xFF:
continue
# SOF0 (0xC0) 或 SOF2 (0xC2) 包含尺寸信息
if marker[1] in (0xC0, 0xC2):
f.read(3) # 跳过 length(2) + precision(1)
height = struct.unpack(">H", f.read(2))[0]
width = struct.unpack(">H", f.read(2))[0]
return width, height
else:
# 跳过此 marker 段
length_bytes = f.read(2)
if len(length_bytes) < 2:
return 0, 0
length = struct.unpack(">H", length_bytes)[0]
f.read(length - 2)
except (OSError, struct.error):
return 0, 0
def is_safe_ip(ip_str: str) -> bool:
    """SSRF guard: report whether an IP address is outside every denied network.

    Args:
        ip_str: IP address string (IPv4 or IPv6).

    Returns:
        True if the address is in none of _DENIED_NETWORKS; False if it is in
        one of them or is not a valid IP address.
    """
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        return False  # not a valid IP: reject

    # An IPv4-mapped IPv6 address (::ffff:a.b.c.d) would otherwise match none
    # of the IPv4 entries, so e.g. ::ffff:127.0.0.1 would pass. Check the
    # embedded IPv4 address instead.
    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped

    return not any(ip in network for network in _DENIED_NETWORKS)
def sanitize_markdown(content: str) -> str:
    """Sanitize Markdown/HTML text by removing dangerous tags and attributes.

    Removed:
    - dangerous elements (script/iframe/object/embed/style/form/input/
      textarea/button/link/meta) together with their content
    - unpaired dangerous opening and closing tags
    - the ``javascript:`` URL scheme
    - event-handler attributes (onclick/onload/onerror, ...)

    The removal passes are repeated until the text stops changing, because a
    removal can splice the surrounding characters into a new dangerous
    construct (e.g. ``<scr<scr<script></script>ipt>ipt>``).

    Args:
        content: Raw Markdown/HTML text.

    Returns:
        The sanitized text.
    """
    result = content
    previous = None
    # Each pass only deletes characters, so the loop is bound to terminate.
    while result != previous:
        previous = result
        # Elements with their content, e.g. <script>...</script>
        result = _DANGEROUS_TAG_RE.sub("", result)
        # Unpaired opening tags, e.g. <script>, <iframe ...>
        result = _DANGEROUS_SOLO_RE.sub("", result)
        # Unpaired closing tags, e.g. </script>
        result = _DANGEROUS_CLOSE_RE.sub("", result)
        # javascript: URL scheme
        result = _JS_URL_RE.sub("", result)
        # Event-handler attributes
        result = _EVENT_HANDLER_RE.sub("", result)
    return result
def sanitize_content(content: str, file_type: str) -> str:
    """Main sanitization entry point: sanitize text formats, pass others through.

    Args:
        content: Text content.
        file_type: File type (extension without the dot), e.g. "md", "html", "pdf".

    Returns:
        The sanitized content when file_type is in _TEXT_FORMATS; otherwise
        the content unchanged.
    """
    return sanitize_markdown(content) if file_type in _TEXT_FORMATS else content
def parse_xml_safe(content: bytes) -> Any:
    """Parse XML with DTDs, entity declarations and external entities forbidden.

    The standard-library ``ElementTree.XMLParser`` has no ``forbid_*``
    options, so this drives expat directly and raises as soon as the document
    declares a DOCTYPE or an entity, or references an external entity.

    Args:
        content: XML content as bytes.

    Returns:
        The root ElementTree Element. Namespaced names use the usual
        ``{uri}local`` form.

    Raises:
        xml.etree.ElementTree.ParseError: The document is malformed or
            contains a forbidden construct.
    """
    import xml.etree.ElementTree as ET
    from xml.parsers import expat

    def _forbid(kind: str):
        # Build an expat handler that rejects the document outright.
        def _handler(*_args: Any) -> None:
            raise ET.ParseError(f"{kind} is forbidden in XML input")

        return _handler

    def _qualify(name: str) -> str:
        # expat reports namespaced names as "uri}local"; ElementTree uses
        # "{uri}local".
        return "{" + name if "}" in name else name

    builder = ET.TreeBuilder()
    parser = expat.ParserCreate(None, "}")
    parser.buffer_text = True
    parser.StartElementHandler = lambda tag, attrs: builder.start(
        _qualify(tag), {_qualify(key): value for key, value in attrs.items()}
    )
    parser.EndElementHandler = lambda tag: builder.end(_qualify(tag))
    parser.CharacterDataHandler = builder.data
    parser.StartDoctypeDeclHandler = _forbid("DOCTYPE declaration")
    parser.EntityDeclHandler = _forbid("Entity declaration")
    parser.UnparsedEntityDeclHandler = _forbid("Unparsed entity declaration")
    parser.ExternalEntityRefHandler = _forbid("External entity reference")

    try:
        parser.Parse(content, True)
    except expat.ExpatError as exc:
        # Surface malformed input as the exception type callers expect.
        error = ET.ParseError(str(exc))
        error.code = exc.code
        error.position = (exc.lineno, exc.offset)
        raise error from exc
    return builder.close()
__all__ = [
"ALLOWED_FILE_TYPES",
"MAX_FILE_SIZE",
"MAX_IMAGE_PIXELS",
"MAX_ZIP_DECOMPRESSED_SIZE",
"MAX_ZIP_RATIO",
"check_image_bomb",
"check_zip_bomb",
"is_safe_ip",
"parse_xml_safe",
"sanitize_content",
"sanitize_markdown",
"validate_file_size",
"validate_file_type",
]

View File

@ -0,0 +1,382 @@
"""U3+U7 tests: document processing pipeline.

Scenarios:
1. parse + segment pipeline (file I/O mocked)
2. preview returns a chunk list
3. vectorize calls the embed model and the vector store
4. the error status is set on failure
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.rag_platform.document_processor import DocumentProcessor
from agentkit.rag_platform.models import DocumentStatus
from agentkit.rag_platform.preview import PreviewResult, generate_preview
class TestParseAndSegment:
    """Tests for DocumentProcessor.parse and DocumentProcessor.segment."""

    def test_parse_extracts_text(self, tmp_path):
        """parse returns the text stored in the file."""
        source = tmp_path / "test.txt"
        source.write_text("Hello, world!\nThis is a test document.", encoding="utf-8")

        extracted = DocumentProcessor().parse(str(source), "txt")

        for fragment in ("Hello, world!", "test document"):
            assert fragment in extracted

    def test_parse_applies_sanitization(self, tmp_path):
        """parse sanitizes the content of text formats."""
        source = tmp_path / "test.md"
        source.write_text("Hello <script>alert(1)</script> world", encoding="utf-8")

        extracted = DocumentProcessor().parse(str(source), "md")

        assert "<script>" not in extracted
        assert "Hello" in extracted
        assert "world" in extracted

    def test_segment_splits_text(self):
        """segment returns whatever the splitter produces for the text."""
        pipeline = DocumentProcessor(chunk_size=100, chunk_overlap=10)
        long_text = "This is a sentence. " * 50  # long enough to need splitting

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter = splitter_cls.return_value
            splitter.split_text.return_value = ["chunk1", "chunk2", "chunk3"]
            produced = pipeline.segment(long_text, chunk_size=100, chunk_overlap=10)

        assert produced == ["chunk1", "chunk2", "chunk3"]
        splitter.split_text.assert_called_once_with(long_text)

    def test_segment_uses_custom_params(self):
        """segment forwards custom chunk_size and chunk_overlap to the splitter."""
        pipeline = DocumentProcessor()

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk"]
            pipeline.segment("text", chunk_size=256, chunk_overlap=20)

        # The splitter must have been built with the custom parameters.
        ctor_kwargs = splitter_cls.call_args.kwargs
        assert ctor_kwargs["chunk_size"] == 256
        assert ctor_kwargs["chunk_overlap"] == 20

    def test_parse_file_not_found(self):
        """parse raises FileNotFoundError for a missing file."""
        pipeline = DocumentProcessor()
        with pytest.raises(FileNotFoundError):
            pipeline.parse("/nonexistent/file.txt", "txt")
class TestPreview:
    """Tests for DocumentProcessor.preview."""

    def test_preview_returns_chunks(self, tmp_path):
        """preview returns a list of chunk dicts."""
        source = tmp_path / "test.txt"
        source.write_text("Test content for preview.", encoding="utf-8")
        pipeline = DocumentProcessor(chunk_size=100, chunk_overlap=10)

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk1", "chunk2"]
            produced = pipeline.preview(str(source), "txt")

        assert len(produced) == 2
        first, second = produced
        assert first["index"] == 0
        assert first["content"] == "chunk1"
        assert second["index"] == 1
        assert second["content"] == "chunk2"
        assert first["metadata"] == {}

    def test_preview_uses_default_params(self, tmp_path):
        """preview falls back to the constructor's parameters."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")
        pipeline = DocumentProcessor(chunk_size=512, chunk_overlap=50)

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk"]
            pipeline.preview(str(source), "txt")

        ctor_kwargs = splitter_cls.call_args.kwargs
        assert ctor_kwargs["chunk_size"] == 512
        assert ctor_kwargs["chunk_overlap"] == 50
class TestGeneratePreview:
    """Tests for the generate_preview function."""

    def test_generate_preview_returns_preview_result(self, tmp_path):
        """generate_preview returns a populated PreviewResult."""
        source = tmp_path / "test.txt"
        source.write_text("Test content for preview.", encoding="utf-8")

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = [
                "chunk1",
                "chunk2",
                "chunk3",
            ]
            outcome = generate_preview(str(source), "txt")

        assert isinstance(outcome, PreviewResult)
        assert outcome.total_chunks == 3
        assert len(outcome.chunks) == 3
        head, tail = outcome.chunks[0], outcome.chunks[2]
        assert head.index == 0
        assert head.content == "chunk1"
        assert tail.index == 2
        assert tail.content == "chunk3"
        assert outcome.document_id == ""  # no document exists at preview time

    def test_generate_preview_with_custom_params(self, tmp_path):
        """generate_preview forwards custom segmentation parameters."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk"]
            generate_preview(str(source), "txt", chunk_size=256, chunk_overlap=20)

        ctor_kwargs = splitter_cls.call_args.kwargs
        assert ctor_kwargs["chunk_size"] == 256
        assert ctor_kwargs["chunk_overlap"] == 20
class TestVectorize:
    """Tests for DocumentProcessor.vectorize."""

    def _make_mock_embed_model(self):
        """Build a mock embedding model that returns a fixed vector."""
        return MagicMock(aget_text_embedding=AsyncMock(return_value=[0.1] * 1536))

    def _make_mock_vector_store(self):
        """Build a mock vector store with an awaitable async_add."""
        return MagicMock(async_add=AsyncMock())

    async def test_vectorize_calls_embed_model(self):
        """vectorize embeds every chunk."""
        pipeline = DocumentProcessor()
        embedder = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()
        texts = ["chunk1", "chunk2", "chunk3"]

        with patch("llama_index.core.schema.TextNode") as node_cls:
            node_cls.side_effect = [MagicMock() for _ in texts]
            await pipeline.vectorize(texts, "kb-1", "doc-1", vector_store, embedder)

        # One embedding call per chunk.
        assert embedder.aget_text_embedding.await_count == 3

    async def test_vectorize_calls_vector_store(self):
        """vectorize hands the nodes to the vector store's async_add."""
        pipeline = DocumentProcessor()
        embedder = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()
        texts = ["chunk1", "chunk2"]

        with patch("llama_index.core.schema.TextNode") as node_cls:
            node_cls.side_effect = [MagicMock() for _ in texts]
            await pipeline.vectorize(texts, "kb-1", "doc-1", vector_store, embedder)

        vector_store.async_add.assert_awaited_once()
        # The number of stored nodes matches the number of chunks.
        stored = vector_store.async_add.call_args.args[0]
        assert len(stored) == 2

    async def test_vectorize_accepts_dict_chunks(self):
        """vectorize accepts chunks given as dicts."""
        pipeline = DocumentProcessor()
        embedder = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()
        dict_chunks = [
            {"index": 0, "content": "chunk1", "metadata": {}},
            {"index": 1, "content": "chunk2", "metadata": {}},
        ]

        with patch("llama_index.core.schema.TextNode") as node_cls:
            node_cls.side_effect = [MagicMock() for _ in dict_chunks]
            await pipeline.vectorize(dict_chunks, "kb-1", "doc-1", vector_store, embedder)

        assert embedder.aget_text_embedding.await_count == 2

    async def test_vectorize_failure_propagates(self):
        """vectorize lets an embedding failure propagate."""
        pipeline = DocumentProcessor()
        embedder = MagicMock(
            aget_text_embedding=AsyncMock(side_effect=RuntimeError("embed failed"))
        )
        vector_store = self._make_mock_vector_store()
        texts = ["chunk1", "chunk2"]

        node_patch = patch("llama_index.core.schema.TextNode")
        with node_patch, pytest.raises(RuntimeError, match="embed failed"):
            await pipeline.vectorize(texts, "kb-1", "doc-1", vector_store, embedder)

    async def test_vectorize_sets_metadata(self):
        """vectorize builds each node with the expected metadata."""
        pipeline = DocumentProcessor()
        embedder = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()
        texts = ["chunk1", "chunk2"]

        with patch("llama_index.core.schema.TextNode") as node_cls:
            node_cls.side_effect = [MagicMock() for _ in texts]
            await pipeline.vectorize(texts, "kb-1", "doc-1", vector_store, embedder)

        # TextNode was constructed once per chunk with the right metadata.
        assert node_cls.call_count == 2
        first_metadata = node_cls.call_args_list[0].kwargs["metadata"]
        assert first_metadata["kb_id"] == "kb-1"
        assert first_metadata["document_id"] == "doc-1"
        assert first_metadata["chunk_index"] == 0
class TestProcess:
    """Tests for DocumentProcessor.process (full pipeline plus status updates)."""

    def _make_mock_store(self):
        """Build a mock KBStore with an awaitable update_document_status."""
        return MagicMock(update_document_status=AsyncMock())

    async def test_process_success_transitions(self, tmp_path):
        """On success the status goes parsing -> segmenting -> vectorizing -> indexed."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")
        pipeline = DocumentProcessor()
        embedder = MagicMock(aget_text_embedding=AsyncMock(return_value=[0.1] * 1536))
        vector_store = MagicMock(async_add=AsyncMock())
        kb_store = self._make_mock_store()

        splitter_patch = patch("llama_index.core.node_parser.SentenceSplitter")
        node_patch = patch("llama_index.core.schema.TextNode")
        with splitter_patch as splitter_cls, node_patch:
            splitter_cls.return_value.split_text.return_value = ["chunk1"]
            await pipeline.process(
                str(source),
                "txt",
                "kb-1",
                "doc-1",
                vector_store,
                embedder,
                kb_store,
            )

        # Exactly four updates, in pipeline order.
        observed = [call.args[1] for call in kb_store.update_document_status.call_args_list]
        assert observed == [
            DocumentStatus.parsing,
            DocumentStatus.segmenting,
            DocumentStatus.vectorizing,
            DocumentStatus.indexed,
        ]

    async def test_process_failure_sets_failed_status(self, tmp_path):
        """A failing embed model leaves the document in the failed status."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")
        pipeline = DocumentProcessor()
        # The embedding model fails.
        embedder = MagicMock(
            aget_text_embedding=AsyncMock(side_effect=RuntimeError("embed error"))
        )
        vector_store = MagicMock(async_add=AsyncMock())
        kb_store = self._make_mock_store()

        splitter_patch = patch("llama_index.core.node_parser.SentenceSplitter")
        node_patch = patch("llama_index.core.schema.TextNode")
        expect_error = pytest.raises(RuntimeError, match="embed error")
        with splitter_patch as splitter_cls, node_patch, expect_error:
            splitter_cls.return_value.split_text.return_value = ["chunk1"]
            await pipeline.process(
                str(source),
                "txt",
                "kb-1",
                "doc-1",
                vector_store,
                embedder,
                kb_store,
            )

        # The final update is "failed" and carries the error message.
        final_update = kb_store.update_document_status.call_args_list[-1]
        assert final_update.args[1] == DocumentStatus.failed
        assert "embed error" in final_update.kwargs.get("error_message", "")

    async def test_process_parse_failure_sets_failed_status(self, tmp_path):
        """A failure in the parse stage leaves the document in the failed status."""
        pipeline = DocumentProcessor()
        embedder = MagicMock()
        vector_store = MagicMock()
        kb_store = self._make_mock_store()

        # The missing file makes parse fail.
        with pytest.raises(FileNotFoundError):
            await pipeline.process(
                "/nonexistent/file.txt",
                "txt",
                "kb-1",
                "doc-1",
                vector_store,
                embedder,
                kb_store,
            )

        updates = kb_store.update_document_status.call_args_list
        assert len(updates) >= 2  # at least parsing + failed
        assert updates[-1].args[1] == DocumentStatus.failed

View File

@ -0,0 +1,346 @@
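"""U3+U7 tests: content sanitization and upload security.

Scenarios:
1. File-type whitelist: allowed types pass, .exe/.sh are rejected
2. File size limit: oversized files are rejected
3. Markdown sanitization: script tags are removed
4. SSRF IP filtering: private IPs are rejected, public IPs are allowed
5. ZIP bomb detection: high compression ratios are rejected
"""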
from __future__ import annotations
import struct
import zipfile
import zlib
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from agentkit.rag_platform.sanitize import (
ALLOWED_FILE_TYPES,
MAX_FILE_SIZE,
check_image_bomb,
check_zip_bomb,
is_safe_ip,
sanitize_content,
sanitize_markdown,
validate_file_size,
validate_file_type,
)
class TestValidateFileType:
    """Tests for the file-type whitelist."""

    @pytest.mark.parametrize("filename", ["doc.pdf", "doc.docx", "doc.xlsx"])
    def test_allowed_office_types_pass(self, filename: str):
        """Allowed Office formats pass."""
        expected = Path(filename).suffix[1:].lower()
        assert validate_file_type(filename) == expected

    @pytest.mark.parametrize("filename", ["notes.md", "data.csv", "page.html", "readme.txt"])
    def test_allowed_text_types_pass(self, filename: str):
        """Allowed text formats pass."""
        expected = Path(filename).suffix[1:].lower()
        assert validate_file_type(filename) == expected

    def test_pdf_returns_pdf(self):
        """A PDF file yields 'pdf'."""
        assert validate_file_type("report.pdf") == "pdf"

    @pytest.mark.parametrize("filename", ["malware.exe", "script.sh", "shell.bat"])
    def test_dangerous_types_rejected(self, filename: str):
        """Dangerous file types are rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type(filename)

    def test_exe_rejected(self):
        """.exe files are rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("program.exe")

    def test_sh_rejected(self):
        """.sh files are rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("script.sh")

    def test_case_insensitive_extension(self):
        """Extension matching ignores case."""
        assert validate_file_type("DOC.PDF") == "pdf"
        assert validate_file_type("doc.MD") == "md"

    def test_no_extension_rejected(self):
        """A file name with no extension is rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("noextension")

    def test_all_allowed_types_in_whitelist(self):
        """The whitelist holds exactly the eight expected extensions."""
        assert set(ALLOWED_FILE_TYPES) == {
            ".pdf",
            ".docx",
            ".xlsx",
            ".pptx",
            ".txt",
            ".md",
            ".csv",
            ".html",
        }
class TestValidateFileSize:
    """Tests for the file size limit."""

    def test_small_file_passes(self):
        """Small files pass."""
        for size in (1024, 1):
            validate_file_size(size)

    def test_exact_limit_passes(self):
        """A file exactly at the limit passes."""
        validate_file_size(MAX_FILE_SIZE)

    def test_oversized_rejected(self):
        """A file over the limit is rejected."""
        too_big = MAX_FILE_SIZE + 1
        with pytest.raises(ValueError, match="exceeds limit"):
            validate_file_size(too_big)

    def test_zero_size_rejected(self):
        """A zero-byte file is rejected."""
        with pytest.raises(ValueError, match="must be positive"):
            validate_file_size(0)

    def test_negative_size_rejected(self):
        """A negative size is rejected."""
        with pytest.raises(ValueError, match="must be positive"):
            validate_file_size(-1)
class TestSanitizeMarkdown:
    """Tests for ``sanitize_markdown`` stripping dangerous HTML from text."""

    def test_script_tag_removed(self):
        """A script element is removed together with its content."""
        cleaned = sanitize_markdown("Hello <script>alert('xss')</script> world")
        for forbidden in ("<script>", "alert"):
            assert forbidden not in cleaned
        # Text surrounding the removed element must survive.
        for kept in ("Hello", "world"):
            assert kept in cleaned

    def test_iframe_tag_removed(self):
        """An iframe element, including its ``src`` value, is removed."""
        cleaned = sanitize_markdown("Text <iframe src='evil.com'></iframe> more")
        for forbidden in ("<iframe", "evil.com"):
            assert forbidden not in cleaned

    def test_object_tag_removed(self):
        """An object element is removed."""
        cleaned = sanitize_markdown("<object data='evil.swf'></object>")
        assert "<object" not in cleaned

    def test_embed_tag_removed(self):
        """An embed element (no closing tag) is removed."""
        cleaned = sanitize_markdown("<embed src='evil.swf'>")
        assert "<embed" not in cleaned

    def test_safe_content_preserved(self):
        """Plain Markdown without HTML is returned unchanged."""
        benign = "# Title\n\nThis is **bold** and *italic*."
        assert sanitize_markdown(benign) == benign

    def test_javascript_protocol_removed(self):
        """The ``javascript:`` URL scheme does not survive sanitization."""
        cleaned = sanitize_markdown("<a href='javascript:alert(1)'>click</a>")
        # Compare case-insensitively so mixed-case output is caught as well.
        assert "javascript:" not in cleaned.lower()

    def test_event_handler_removed(self):
        """Inline event-handler attributes such as ``onclick`` are removed."""
        cleaned = sanitize_markdown("<div onclick='alert(1)'>text</div>")
        assert "onclick" not in cleaned

    def test_multiple_dangerous_tags_removed(self):
        """Several dangerous elements in one input are all removed."""
        mixed = (
            "<script>bad()</script>"
            "<iframe src='x'></iframe>"
            "safe text"
            "<object data='y'></object>"
        )
        cleaned = sanitize_markdown(mixed)
        for forbidden in ("<script", "<iframe", "<object"):
            assert forbidden not in cleaned
        assert "safe text" in cleaned

    def test_multiline_script_removed(self):
        """A script element spanning several lines is removed entirely."""
        cleaned = sanitize_markdown("before\n<script>\nline1\nline2\n</script>\nafter")
        for forbidden in ("<script>", "line1", "line2"):
            assert forbidden not in cleaned
        for kept in ("before", "after"):
            assert kept in cleaned
class TestSanitizeContent:
    """Tests for ``sanitize_content``, the format-aware sanitization entry point."""

    def test_markdown_sanitized(self):
        """Content declared as ``md`` is sanitized."""
        cleaned = sanitize_content("Hello <script>alert(1)</script>", "md")
        assert "<script>" not in cleaned

    def test_html_sanitized(self):
        """Content declared as ``html`` is sanitized."""
        cleaned = sanitize_content("<iframe src='evil'></iframe>", "html")
        assert "<iframe" not in cleaned

    def test_text_sanitized(self):
        """Content declared as ``txt`` is sanitized."""
        cleaned = sanitize_content("text <script>x</script>", "txt")
        assert "<script>" not in cleaned

    def test_binary_format_not_sanitized(self):
        """Content declared as a binary format (``pdf``) is not sanitized."""
        raw = "raw <script>content</script>"
        # The input must come back unchanged.
        assert sanitize_content(raw, "pdf") == raw
class TestIsSafeIp:
    """Tests for ``is_safe_ip``, the SSRF address filter."""

    @pytest.mark.parametrize(
        "ip",
        [
            "127.0.0.1",
            "127.0.1.1",
            "10.0.0.1",
            "172.16.0.1",
            "172.31.255.255",
            "192.168.1.1",
            "169.254.1.1",
            "0.0.0.0",
        ],
    )
    def test_private_ips_blocked(self, ip: str):
        """Loopback, RFC 1918, link-local and unspecified IPv4 addresses are denied."""
        assert is_safe_ip(ip) is False

    def test_ipv6_loopback_blocked(self):
        """The IPv6 loopback address is denied."""
        assert is_safe_ip("::1") is False

    def test_ipv6_ula_blocked(self):
        """An IPv6 unique-local address (ULA) is denied."""
        assert is_safe_ip("fd00::1") is False

    def test_ipv6_link_local_blocked(self):
        """An IPv6 link-local address is denied."""
        assert is_safe_ip("fe80::1") is False

    # Only genuinely routable addresses belong here. The former sample
    # 203.0.113.1 lies in TEST-NET-3 (RFC 5737, documentation only), which
    # Python's ``ipaddress`` classifies as private / non-global; asserting it
    # is "public" would pin an implementation detail of ``is_safe_ip``.
    @pytest.mark.parametrize("ip", ["8.8.8.8", "1.1.1.1", "9.9.9.9"])
    def test_public_ips_allowed(self, ip: str):
        """Publicly routable IPv4 addresses are allowed."""
        assert is_safe_ip(ip) is True

    def test_invalid_ip_blocked(self):
        """A string that is not an IP address is denied."""
        assert is_safe_ip("not-an-ip") is False

    def test_empty_ip_blocked(self):
        """The empty string is denied."""
        assert is_safe_ip("") is False
class TestCheckZipBomb:
    """Tests for ``check_zip_bomb`` (zip-bomb detection)."""

    def test_normal_zip_passes(self, tmp_path):
        """An ordinary ZIP archive passes the check."""
        archive = tmp_path / "normal.zip"
        payload = b"Hello, world! " * 100
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("file.txt", payload)
        # Must not raise.
        check_zip_bomb(str(archive))

    def test_high_ratio_rejected(self, tmp_path):
        """An archive with an extreme compression ratio is rejected."""
        archive = tmp_path / "bomb.zip"
        # 10 MB of zero bytes deflates to almost nothing, giving a very
        # high compression ratio.
        zeros = b"\x00" * 10_000_000
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("bomb.txt", zeros)
        with pytest.raises(ValueError, match="Zip bomb"):
            check_zip_bomb(str(archive))

    def test_large_uncompressed_rejected_via_mock(self, tmp_path):
        """An archive whose total uncompressed size is over the limit is rejected.

        ``ZipFile`` is mocked so that no large file has to be written.
        """
        mib = 1024 * 1024
        archive = str(tmp_path / "large.zip")

        # A single entry reporting 10 MB compressed and 600 MB uncompressed
        # (above the 500 MB limit).
        oversized_entry = MagicMock(
            filename="large.bin",
            compress_size=10 * mib,
            file_size=600 * mib,
        )

        # Make the fake archive usable as a context manager yielding itself.
        fake_archive = MagicMock()
        fake_archive.__enter__ = MagicMock(return_value=fake_archive)
        fake_archive.__exit__ = MagicMock(return_value=False)
        fake_archive.infolist.return_value = [oversized_entry]

        patched_zipfile = patch(
            "agentkit.rag_platform.sanitize.zipfile.ZipFile",
            return_value=fake_archive,
        )
        with patched_zipfile, pytest.raises(ValueError, match="Zip bomb"):
            check_zip_bomb(archive)

    def test_docx_zip_format_works(self, tmp_path):
        """A ``.docx`` file (a ZIP container) can be checked like any archive."""
        archive = tmp_path / "fake.docx"
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("document.xml", "<doc/>")
        # A normal docx must not be reported as a zip bomb.
        check_zip_bomb(str(archive))
class TestCheckImageBomb:
    """Tests for ``check_image_bomb`` (image-bomb detection)."""

    def _make_png(self, path: Path, width: int, height: int) -> None:
        """Write a minimal PNG (signature plus a single IHDR chunk) to *path*."""
        signature = b"\x89PNG\r\n\x1a\n"
        # IHDR payload: big-endian width and height followed by the five
        # single-byte fields 8, 2, 0, 0, 0 -- 13 bytes in total.
        payload = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
        checksum = zlib.crc32(b"IHDR" + payload) & 0xFFFFFFFF
        chunk = b"".join(
            (
                struct.pack(">I", len(payload)),
                b"IHDR",
                payload,
                struct.pack(">I", checksum),
            )
        )
        path.write_bytes(signature + chunk)

    def test_small_png_passes(self, tmp_path):
        """A 1x1 PNG passes the check."""
        image = tmp_path / "small.png"
        self._make_png(image, 1, 1)
        # Must not raise.
        check_image_bomb(str(image))

    def test_large_png_rejected(self, tmp_path):
        """A PNG declaring more than 100 megapixels is rejected."""
        image = tmp_path / "huge.png"
        # 20000 x 20000 = 400 MP, above the 100 MP limit.
        side = 20000
        self._make_png(image, side, side)
        with pytest.raises(ValueError, match="Image bomb"):
            check_image_bomb(str(image))

    def test_unknown_format_passes(self, tmp_path):
        """A file with an unrecognised header is treated as safe."""
        blob = tmp_path / "unknown.bin"
        blob.write_bytes(b"\x00" * 32)
        # Must not raise.
        check_image_bomb(str(blob))

    def test_nonexistent_file_passes(self, tmp_path):
        """A missing file is treated as safe (no exception is raised)."""
        missing = tmp_path / "nonexistent.png"
        check_image_bomb(str(missing))