From b55c8967940ab027c1485ffc30e227fdd17ae660 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 25 Jun 2026 11:21:42 +0800 Subject: [PATCH] =?UTF-8?q?feat(rag=5Fplatform):=20U3+U7=20=E2=80=94=20doc?= =?UTF-8?q?ument=20processing=20pipeline=20+=20upload=20security?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit U3: Document processing pipeline (document_processor.py) - DocumentProcessor class wrapping parse → segment → vectorize - parse() uses memory/document_loader.py for multi-format extraction - segment() uses LlamaIndex SentenceSplitter - preview() returns chunks for read-only preview (no vectorization) - vectorize() embeds chunks and stores in pgvector (all-or-nothing) - process() orchestrates full pipeline with status transitions: pending → parsing → segmenting → vectorizing → indexed | failed U7: Upload security & content sanitization (sanitize.py) - ALLOWED_FILE_TYPES whitelist (pdf/docx/xlsx/pptx/txt/md/csv/html) - MAX_FILE_SIZE 50MB limit - validate_file_type() / validate_file_size() guards - check_zip_bomb() for ZIP-based formats (ratio > 100:1 or > 500MB) - check_image_bomb() for pixel count > 100MP (PNG/JPEG/GIF header parsing) - is_safe_ip() SSRF protection (loopback/RFC1918/link-local/ULA denied) - sanitize_markdown() removes dangerous HTML tags (script/iframe/object/embed) - sanitize_content() main entry point for text format sanitization - parse_xml_safe() XXE protection (forbid_dtd/forbid_entities/forbid_external) Preview API (preview.py) - PreviewChunk / PreviewResult Pydantic models - generate_preview() returns read-only segmentation preview Tests: 112 tests passing (45 new + 67 existing) - test_sanitize.py: file type/size, markdown sanitization, SSRF, zip/image bomb - test_document_processor.py: parse/segment, preview, vectorize, failure status --- src/agentkit/rag_platform/__init__.py | 36 ++ .../rag_platform/document_processor.py | 250 ++++++++++++ src/agentkit/rag_platform/preview.py | 86 ++++ 
src/agentkit/rag_platform/sanitize.py | 350 ++++++++++++++++ .../rag_platform/test_document_processor.py | 382 ++++++++++++++++++ tests/unit/rag_platform/test_sanitize.py | 346 ++++++++++++++++ 6 files changed, 1450 insertions(+) create mode 100644 src/agentkit/rag_platform/document_processor.py create mode 100644 src/agentkit/rag_platform/preview.py create mode 100644 src/agentkit/rag_platform/sanitize.py create mode 100644 tests/unit/rag_platform/test_document_processor.py create mode 100644 tests/unit/rag_platform/test_sanitize.py diff --git a/src/agentkit/rag_platform/__init__.py b/src/agentkit/rag_platform/__init__.py index 8d451aa..1340c23 100644 --- a/src/agentkit/rag_platform/__init__.py +++ b/src/agentkit/rag_platform/__init__.py @@ -1,5 +1,10 @@ """RAG 平台模块 — 企业知识库场景的工业级 RAG 管道。""" +from agentkit.rag_platform.document_processor import ( + DEFAULT_CHUNK_OVERLAP, + DEFAULT_CHUNK_SIZE, + DocumentProcessor, +) from agentkit.rag_platform.models import ( Chunk, Document, @@ -9,13 +14,44 @@ from agentkit.rag_platform.models import ( QueryMode, QueryResult, ) +from agentkit.rag_platform.preview import ( + PreviewChunk, + PreviewResult, + generate_preview, +) +from agentkit.rag_platform.sanitize import ( + ALLOWED_FILE_TYPES, + MAX_FILE_SIZE, + check_image_bomb, + check_zip_bomb, + is_safe_ip, + sanitize_content, + sanitize_markdown, + validate_file_size, + validate_file_type, +) __all__ = [ + "ALLOWED_FILE_TYPES", + "DEFAULT_CHUNK_OVERLAP", + "DEFAULT_CHUNK_SIZE", "Chunk", "Document", + "DocumentProcessor", "DocumentStatus", "KBStatus", "KnowledgeBase", + "MAX_FILE_SIZE", + "PreviewChunk", + "PreviewResult", "QueryMode", "QueryResult", + "check_image_bomb", + "check_zip_bomb", + "generate_preview", + "is_safe_ip", + "sanitize_content", + "sanitize_markdown", + "validate_file_size", + "validate_file_type", ] diff --git a/src/agentkit/rag_platform/document_processor.py b/src/agentkit/rag_platform/document_processor.py new file mode 100644 index 0000000..be55200 --- 
"""Document processing pipeline: parse -> segment -> vectorize.

Wraps DocumentLoader (multi-format parsing), the LlamaIndex SentenceSplitter
(segmentation) and PGVectorStore (vector storage).

State machine: pending -> parsing -> segmenting -> vectorizing -> indexed | failed
Failure semantics: vectorize() is all-or-nothing. It raises on failure and the
caller is responsible for moving the document to ``failed``.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from agentkit.memory.document_loader import DocumentLoader
from agentkit.rag_platform.models import DocumentStatus
from agentkit.rag_platform.sanitize import sanitize_content

if TYPE_CHECKING:
    from llama_index.core.embeddings import BaseEmbedding
    from llama_index.core.schema import TextNode
    from llama_index.vector_stores.postgres import PGVectorStore

    from agentkit.rag_platform.store import KBStore

logger = logging.getLogger(__name__)

DEFAULT_CHUNK_SIZE = 512
DEFAULT_CHUNK_OVERLAP = 50


class DocumentProcessor:
    """Document pipeline: parse -> segment -> vectorize.

    Args:
        chunk_size: Segment size in tokens, 512 by default.
        chunk_overlap: Overlap between segments in tokens, 50 by default.
    """

    def __init__(
        self,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    ) -> None:
        self._loader = DocumentLoader()
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap

    def parse(self, file_path: str, file_type: str) -> str:
        """Extract the text of a file and sanitize it.

        Loading is delegated to ``DocumentLoader.load``; the loaded content is
        then passed through ``sanitize_content`` together with ``file_type``.

        Args:
            file_path: Path of the file to load.
            file_type: File type (extension without the dot); selects whether
                sanitization applies.

        Returns:
            The extracted, sanitized text.

        Raises:
            FileNotFoundError: The file does not exist (per the original
                contract of the loader; not verifiable from this module).
            ValueError: The content exceeds the size limit (same caveat).
        """
        doc = self._loader.load(file_path)
        return sanitize_content(doc.content, file_type)

    def segment(
        self,
        text: str,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    ) -> list[str]:
        """Split text into chunks with the LlamaIndex SentenceSplitter.

        Note that the defaults here are the module-level constants, not the
        values given to the constructor; pass them explicitly to use the
        instance configuration.

        Args:
            text: Text to split.
            chunk_size: Segment size in tokens.
            chunk_overlap: Overlap between segments in tokens.

        Returns:
            The chunk texts, in document order.
        """
        # Imported lazily so that importing this module does not pull in
        # llama_index.
        from llama_index.core.node_parser import SentenceSplitter

        splitter = SentenceSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        chunks = splitter.split_text(text)
        logger.debug(
            "Segmented text into %d chunks (chunk_size=%d, overlap=%d)",
            len(chunks),
            chunk_size,
            chunk_overlap,
        )
        return chunks

    def preview(
        self,
        file_path: str,
        file_type: str,
        chunk_size: int | None = None,
        chunk_overlap: int | None = None,
    ) -> list[dict[str, Any]]:
        """Parse and segment a file for a read-only preview (no vectorization).

        Args:
            file_path: Path of the file to preview.
            file_type: File type (extension without the dot).
            chunk_size: Segment size; ``None`` means the constructor value.
            chunk_overlap: Segment overlap; ``None`` means the constructor value.

        Returns:
            A list of chunk dicts: ``[{index, content, metadata}, ...]``.
        """
        # Compare against None rather than relying on truthiness: an explicit
        # 0 (for example "no overlap") is a real value and must not fall back
        # to the instance default.
        size = self._chunk_size if chunk_size is None else chunk_size
        overlap = self._chunk_overlap if chunk_overlap is None else chunk_overlap

        text = self.parse(file_path, file_type)
        chunks = self.segment(text, size, overlap)

        return [{"index": i, "content": c, "metadata": {}} for i, c in enumerate(chunks)]

    @staticmethod
    def _chunk_text(chunk: Any) -> str:
        """Return the text of one chunk given as a str, a dict or anything else."""
        if isinstance(chunk, str):
            return chunk
        if isinstance(chunk, dict):
            # A dict without a "content" key yields an empty string.
            return chunk.get("content", "")
        return str(chunk)

    async def vectorize(
        self,
        chunks: list[str] | list[dict[str, Any]],
        kb_id: str,
        document_id: str,
        vector_store: PGVectorStore,
        embed_model: BaseEmbedding,
    ) -> list[TextNode]:
        """Embed chunks and store them in the vector store.

        All-or-nothing: every chunk is embedded before anything is written, so
        an embedding failure raises before the store is touched.

        Args:
            chunks: Chunk texts, or chunk dicts carrying a ``content`` key.
            kb_id: Knowledge base ID, recorded in each node's metadata.
            document_id: Document ID, recorded in each node's metadata.
            vector_store: LlamaIndex PGVectorStore instance.
            embed_model: LlamaIndex embedding model.

        Returns:
            The TextNode objects that were passed to the vector store.

        Raises:
            Exception: Embedding or storage failed.
        """
        from llama_index.core.schema import TextNode

        nodes: list[TextNode] = [
            TextNode(
                text=self._chunk_text(chunk),
                metadata={
                    "kb_id": kb_id,
                    "document_id": document_id,
                    "chunk_index": index,
                },
            )
            for index, chunk in enumerate(chunks)
        ]

        # Embed everything first; the first failure propagates and nothing is
        # stored.
        for node in nodes:
            node.embedding = await embed_model.aget_text_embedding(node.text)

        await vector_store.async_add(nodes)

        logger.info(
            "Vectorized %d chunks for document=%s kb=%s",
            len(nodes),
            document_id,
            kb_id,
        )
        return nodes

    async def process(
        self,
        file_path: str,
        file_type: str,
        kb_id: str,
        document_id: str,
        vector_store: PGVectorStore,
        embed_model: BaseEmbedding,
        store: KBStore,
    ) -> None:
        """Run the full pipeline (parse -> segment -> vectorize) with status updates.

        Status transitions: parsing -> segmenting -> vectorizing -> indexed.
        On failure the status is set to ``failed`` with the error text and the
        original exception is re-raised.

        Args:
            file_path: Path of the file to process.
            file_type: File type (extension without the dot).
            kb_id: Knowledge base ID.
            document_id: Document ID.
            vector_store: LlamaIndex PGVectorStore instance.
            embed_model: LlamaIndex embedding model.
            store: KBStore instance used for status updates.

        Raises:
            Exception: Any pipeline stage failed.
        """
        # NOTE(review): parse() and segment() are synchronous and run on the
        # event loop thread here; confirm that is acceptable for large files.
        try:
            await store.update_document_status(document_id, DocumentStatus.parsing)
            text = self.parse(file_path, file_type)

            await store.update_document_status(document_id, DocumentStatus.segmenting)
            chunks = self.segment(text, self._chunk_size, self._chunk_overlap)

            await store.update_document_status(document_id, DocumentStatus.vectorizing)
            await self.vectorize(chunks, kb_id, document_id, vector_store, embed_model)

            await store.update_document_status(document_id, DocumentStatus.indexed)
            logger.info("Document %s processed successfully", document_id)
        except Exception as e:
            # Recording the failure must not replace the real error: if the
            # status update itself fails, log it and still re-raise the
            # original pipeline exception.
            try:
                await store.update_document_status(
                    document_id, DocumentStatus.failed, error_message=str(e)
                )
            except Exception:
                logger.exception(
                    "Could not mark document %s as failed", document_id
                )
            logger.error("Document %s processing failed: %s", document_id, e)
            raise


__all__ = [
    "DEFAULT_CHUNK_OVERLAP",
    "DEFAULT_CHUNK_SIZE",
    "DocumentProcessor",
]
"""Segmentation preview API: a read-only view of how a document would be chunked.

Nothing is vectorized. Callers can inspect the result, adjust
chunk_size/chunk_overlap and only then submit the document for real.
"""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, ConfigDict, Field

from agentkit.rag_platform.document_processor import (
    DEFAULT_CHUNK_OVERLAP,
    DEFAULT_CHUNK_SIZE,
    DocumentProcessor,
)


class PreviewChunk(BaseModel):
    """One chunk of a preview."""

    model_config = ConfigDict()

    index: int
    content: str
    metadata: dict[str, Any] = Field(default_factory=dict)


class PreviewResult(BaseModel):
    """Result of a segmentation preview."""

    model_config = ConfigDict()

    # No document exists yet at preview time, hence the empty default.
    document_id: str = ""
    chunks: list[PreviewChunk]
    total_chunks: int


def generate_preview(
    file_path: str,
    file_type: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> PreviewResult:
    """Build a read-only segmentation preview (nothing is vectorized).

    Args:
        file_path: Path of the file to preview.
        file_type: File type (extension without the dot), e.g. "pdf" or "md".
        chunk_size: Segment size in tokens, 512 by default.
        chunk_overlap: Overlap between segments in tokens, 50 by default.

    Returns:
        A PreviewResult holding the chunk list and the chunk count.

    Raises:
        FileNotFoundError: The file does not exist.
        ValueError: Parsing failed.
    """
    raw_chunks = DocumentProcessor(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    ).preview(
        file_path,
        file_type,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

    items: list[PreviewChunk] = []
    for entry in raw_chunks:
        items.append(
            PreviewChunk(
                index=entry["index"],
                content=entry["content"],
                metadata=entry.get("metadata", {}),
            )
        )

    return PreviewResult(
        document_id="",
        chunks=items,
        total_chunks=len(raw_chunks),
    )


__all__ = [
    "PreviewChunk",
    "PreviewResult",
    "generate_preview",
]
"""Content sanitization and upload security.

Security boundaries enforced here:
- file-type whitelist (pdf/docx/xlsx/pptx/txt/md/csv/html)
- file-size ceiling (50MB by default)
- ZIP bomb detection (.docx/.xlsx/.pptx are ZIP containers)
- image bomb detection (more than 100MP)
- SSRF protection (loopback/RFC1918/CGNAT/link-local/ULA are denied)
- removal of dangerous HTML tags from Markdown
- XXE protection (XML parsing with DTDs and entities forbidden)
"""

from __future__ import annotations

import ipaddress
import re
import struct
import zipfile
from pathlib import Path
from typing import Any

# File-type whitelist: extension -> MIME type.
ALLOWED_FILE_TYPES: dict[str, str] = {
    ".pdf": "application/pdf",
    ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ".txt": "text/plain",
    ".md": "text/markdown",
    ".csv": "text/csv",
    ".html": "text/html",
}

# Upload size ceiling: 50MB.
MAX_FILE_SIZE = 50 * 1024 * 1024

# ZIP bomb thresholds.
MAX_ZIP_RATIO = 100  # maximum compression ratio, 100:1
MAX_ZIP_DECOMPRESSED_SIZE = 500 * 1024 * 1024  # maximum total size once unpacked, 500MB

# Image bomb threshold.
MAX_IMAGE_PIXELS = 100_000_000  # 100MP

# Text formats that get sanitized.
_TEXT_FORMATS = frozenset({"md", "html", "txt", "csv"})

# Upper bound on sanitize_markdown() passes; ordinary content settles in one
# or two.
_MAX_SANITIZE_PASSES = 32

# Dangerous HTML elements, removed together with their content. The \b after
# the tag name keeps look-alikes such as <metadata> or <formula> intact.
_DANGEROUS_TAG_RE = re.compile(
    r"<\s*(script|iframe|object|embed|style|form|input|textarea|button|link|meta)\b"
    r"[^>]*>.*?<\s*/\s*\1\s*>",
    re.IGNORECASE | re.DOTALL,
)
# Dangerous self-closing or unmatched opening tags.
_DANGEROUS_SOLO_RE = re.compile(
    r"<\s*(script|iframe|object|embed|style|form|input|textarea|button|link|meta)\b[^>]*/?>",
    re.IGNORECASE,
)
# Dangerous closing tags left on their own.
_DANGEROUS_CLOSE_RE = re.compile(
    r"<\s*/\s*(script|iframe|object|embed|style|form|input|textarea|button|link|meta)\s*>",
    re.IGNORECASE,
)
# The javascript: URL scheme.
_JS_URL_RE = re.compile(r"javascript\s*:", re.IGNORECASE)
# Event-handler attributes (onclick=, onload=, onerror=, ...).
_EVENT_HANDLER_RE = re.compile(
    r"\s+on\w+\s*=\s*(?:\"[^\"]*\"|'[^']*'|[^\s>]+)",
    re.IGNORECASE,
)

# SSRF protection: networks that must never be contacted.
_DENIED_NETWORKS = [
    ipaddress.ip_network("127.0.0.0/8"),  # IPv4 loopback
    ipaddress.ip_network("10.0.0.0/8"),  # RFC1918
    ipaddress.ip_network("172.16.0.0/12"),  # RFC1918
    ipaddress.ip_network("192.168.0.0/16"),  # RFC1918
    ipaddress.ip_network("100.64.0.0/10"),  # RFC6598 carrier-grade NAT (shared address space)
    ipaddress.ip_network("169.254.0.0/16"),  # link-local
    ipaddress.ip_network("0.0.0.0/8"),  # "this network"
    ipaddress.ip_network("::1/128"),  # IPv6 loopback
    ipaddress.ip_network("fc00::/7"),  # IPv6 ULA
    ipaddress.ip_network("fe80::/10"),  # IPv6 link-local
    ipaddress.ip_network("::/128"),  # IPv6 unspecified
]

# JPEG start-of-frame markers: 0xC0-0xCF except DHT (0xC4), JPG (0xC8) and
# DAC (0xCC). Every one of them carries the frame dimensions.
_JPEG_SOF_MARKERS = frozenset(range(0xC0, 0xD0)) - {0xC4, 0xC8, 0xCC}
# Markers that stand alone, with no length field: TEM, RST0-RST7, SOI.
_JPEG_STANDALONE_MARKERS = frozenset({0x01, 0xD8, *range(0xD0, 0xD8)})
# EOI and SOS: past these there is no frame header left to find.
_JPEG_TERMINAL_MARKERS = frozenset({0xD9, 0xDA})


def validate_file_type(filename: str, file_type: str | None = None) -> str:
    """Check that a file's extension is on the whitelist.

    Args:
        filename: File name; only its extension is inspected.
        file_type: Optional MIME type. It is accepted but not used by this
            function.

    Returns:
        The normalized file type: lower-case extension without the dot,
        e.g. "pdf" or "md".

    Raises:
        ValueError: The extension is not whitelisted.
    """
    ext = Path(filename).suffix.lower()
    if ext not in ALLOWED_FILE_TYPES:
        allowed = ", ".join(sorted(ALLOWED_FILE_TYPES.keys()))
        raise ValueError(f"File type '{ext}' is not allowed. Allowed types: {allowed}")
    return ext.lstrip(".")


def validate_file_size(file_size: int) -> None:
    """Check that a file size is positive and within MAX_FILE_SIZE.

    Args:
        file_size: Size in bytes.

    Raises:
        ValueError: The size exceeds MAX_FILE_SIZE or is not positive.
    """
    if file_size > MAX_FILE_SIZE:
        raise ValueError(
            f"File size {file_size} bytes exceeds limit {MAX_FILE_SIZE} bytes "
            f"({MAX_FILE_SIZE // (1024 * 1024)}MB)"
        )
    if file_size <= 0:
        raise ValueError(f"File size must be positive, got {file_size}")


def check_zip_bomb(file_path: str) -> None:
    """Detect ZIP bombs; applies to .docx/.xlsx/.pptx, which are ZIP containers.

    Only the sizes declared in the archive directory are inspected; nothing is
    decompressed. A file is rejected when:
    - any single entry has a compression ratio above 100:1,
    - the total uncompressed size exceeds 500MB, or
    - the overall compression ratio exceeds 100:1.

    Args:
        file_path: Path of a ZIP-format file.

    Raises:
        ValueError: A ZIP bomb was detected.
        zipfile.BadZipFile: The file is not a valid ZIP archive.
    """
    with zipfile.ZipFile(file_path) as zf:
        total_compressed = 0
        total_uncompressed = 0

        for info in zf.infolist():
            total_compressed += info.compress_size
            total_uncompressed += info.file_size

            # Per-entry ratio; entries with no compressed bytes are skipped to
            # avoid dividing by zero.
            if info.compress_size > 0:
                ratio = info.file_size / info.compress_size
                if ratio > MAX_ZIP_RATIO:
                    raise ValueError(
                        f"Zip bomb detected: file '{info.filename}' has compression "
                        f"ratio {ratio:.1f}:1 (max {MAX_ZIP_RATIO}:1)"
                    )

            # Running total, checked per entry so a huge archive fails early.
            if total_uncompressed > MAX_ZIP_DECOMPRESSED_SIZE:
                raise ValueError(
                    f"Zip bomb detected: total uncompressed size "
                    f"{total_uncompressed} bytes exceeds limit "
                    f"{MAX_ZIP_DECOMPRESSED_SIZE} bytes"
                )

        # Overall ratio across the whole archive.
        if total_compressed > 0:
            overall_ratio = total_uncompressed / total_compressed
            if overall_ratio > MAX_ZIP_RATIO:
                raise ValueError(
                    f"Zip bomb detected: overall compression ratio "
                    f"{overall_ratio:.1f}:1 (max {MAX_ZIP_RATIO}:1)"
                )


def check_image_bomb(image_path: str) -> None:
    """Reject images whose declared pixel count exceeds 100MP.

    PNG, JPEG and GIF are recognised by parsing the file header (no PIL
    dependency). Unreadable files and unrecognised formats are treated as
    safe: the function returns without raising.

    Args:
        image_path: Path of the image file.

    Raises:
        ValueError: The pixel count exceeds MAX_IMAGE_PIXELS.
    """
    try:
        with open(image_path, "rb") as f:
            header = f.read(32)
    except OSError:
        return  # unreadable: deliberately treated as safe

    width, height = _read_image_dimensions(header, image_path)
    if width <= 0 or height <= 0:
        return  # dimensions unknown: deliberately treated as safe

    pixels = width * height
    if pixels > MAX_IMAGE_PIXELS:
        raise ValueError(
            f"Image bomb detected: {width}x{height} = {pixels} pixels "
            f"exceeds limit {MAX_IMAGE_PIXELS} pixels (100MP)"
        )


def _read_image_dimensions(header: bytes, file_path: str) -> tuple[int, int]:
    """Read (width, height) from an image header; (0, 0) when unrecognised."""
    # PNG: 8-byte signature, then IHDR with big-endian width at offset 16 and
    # height at offset 20.
    if len(header) >= 24 and header[:8] == b"\x89PNG\r\n\x1a\n":
        width = struct.unpack(">I", header[16:20])[0]
        height = struct.unpack(">I", header[20:24])[0]
        return width, height

    # GIF: GIF87a/GIF89a, little-endian width at offset 6 and height at 8.
    if len(header) >= 10 and header[:6] in (b"GIF87a", b"GIF89a"):
        width = struct.unpack("<H", header[6:8])[0]
        height = struct.unpack("<H", header[8:10])[0]
        return width, height

    # JPEG: starts with the SOI marker; the dimensions sit in a later
    # start-of-frame segment, so the file itself has to be scanned.
    if len(header) >= 2 and header[:2] == b"\xff\xd8":
        return _read_jpeg_dimensions(file_path)

    return 0, 0


def _read_jpeg_dimensions(file_path: str) -> tuple[int, int]:
    """Scan a JPEG for its first start-of-frame segment and return (width, height).

    Returns (0, 0) when no frame header is found before the scan data, when
    the file is truncated or malformed, or when it cannot be read.
    """
    try:
        with open(file_path, "rb") as f:
            f.seek(2)  # skip the SOI marker
            while True:
                byte = f.read(1)
                if not byte:
                    return 0, 0
                if byte != b"\xff":
                    continue  # resynchronise one byte at a time

                # Any number of 0xFF fill bytes may precede the marker code.
                code = f.read(1)
                while code == b"\xff":
                    code = f.read(1)
                if not code:
                    return 0, 0
                marker = code[0]

                # 0x00 is a stuffed byte, not a marker; standalone markers
                # carry no length field.
                if marker == 0x00 or marker in _JPEG_STANDALONE_MARKERS:
                    continue
                if marker in _JPEG_TERMINAL_MARKERS:
                    return 0, 0

                length_bytes = f.read(2)
                if len(length_bytes) < 2:
                    return 0, 0
                (length,) = struct.unpack(">H", length_bytes)
                if length < 2:
                    return 0, 0  # the length field counts itself; < 2 is corrupt

                if marker in _JPEG_SOF_MARKERS:
                    frame = f.read(5)  # precision(1) + height(2) + width(2)
                    if len(frame) < 5:
                        return 0, 0
                    _precision, height, width = struct.unpack(">BHH", frame)
                    return width, height

                # Hop over the segment without loading it into memory.
                f.seek(length - 2, 1)
    except (OSError, struct.error):
        return 0, 0


def is_safe_ip(ip_str: str) -> bool:
    """SSRF guard: tell whether an IP address is outside every denied network.

    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are unwrapped first, so an
    internal IPv4 target cannot be smuggled through its IPv6 spelling.

    Args:
        ip_str: IP address string.

    Returns:
        True when the address is in none of the denied networks; False when
        it is, or when the string is not a valid IP address.
    """
    try:
        ip = ipaddress.ip_address(ip_str)
    except ValueError:
        return False  # not an IP address: deny

    mapped = getattr(ip, "ipv4_mapped", None)
    if mapped is not None:
        ip = mapped

    return not any(ip in network for network in _DENIED_NETWORKS)


def sanitize_markdown(content: str) -> str:
    """Strip dangerous HTML tags and attributes from Markdown/HTML text.

    Removed:
    - dangerous elements (script/iframe/object/embed/style/form/input/
      textarea/button/link/meta) together with their content
    - unmatched dangerous opening and closing tags
    - the ``javascript:`` scheme
    - event-handler attributes (onclick/onload/onerror, ...)

    The passes are repeated until the text stops changing, because removing
    one fragment can splice the text around it into a new dangerous fragment
    (``<scr<script>ipt>`` becomes ``<script>``).

    This is a regex filter, not an HTML parser; treat it as defence in depth
    and still escape the text wherever it is rendered.

    Args:
        content: Raw Markdown/HTML text.

    Returns:
        The sanitized text.

    Raises:
        ValueError: The text was still changing after the maximum number of
            passes, which only deliberately nested input can cause.
    """
    result = content
    for _ in range(_MAX_SANITIZE_PASSES):
        previous = result
        # Whole elements with their content, e.g. a script element.
        result = _DANGEROUS_TAG_RE.sub("", result)
        # Opening tags left without a matching close.
        result = _DANGEROUS_SOLO_RE.sub("", result)
        # Closing tags left without a matching open.
        result = _DANGEROUS_CLOSE_RE.sub("", result)
        # The javascript: scheme.
        result = _JS_URL_RE.sub("", result)
        # Event-handler attributes.
        result = _EVENT_HANDLER_RE.sub("", result)
        if result == previous:
            return result
    raise ValueError(
        f"Content could not be sanitized: markup still changing after "
        f"{_MAX_SANITIZE_PASSES} passes"
    )


def sanitize_content(content: str, file_type: str) -> str:
    """Main sanitization entry point: sanitize text formats, pass others through.

    Args:
        content: Text content.
        file_type: File type such as "md", "html" or "pdf". Case and a
            leading dot are ignored, so "MD" and ".md" behave like "md".

    Returns:
        The sanitized content for md/html/txt/csv; the content unchanged for
        every other type.
    """
    normalized = file_type.strip().lower().lstrip(".")
    if normalized in _TEXT_FORMATS:
        return sanitize_markdown(content)
    return content


def parse_xml_safe(content: bytes) -> Any:
    """Parse XML with DTDs and entities forbidden, to prevent XXE attacks.

    The standard-library ``xml.etree.ElementTree.XMLParser`` has no
    ``forbid_*`` options, so the document is fed to expat directly with
    handlers that abort on any DOCTYPE, entity declaration or external entity
    reference, and the tree is assembled with ``TreeBuilder``.

    Args:
        content: XML document as bytes.

    Returns:
        The root ``xml.etree.ElementTree.Element``.

    Raises:
        xml.etree.ElementTree.ParseError: The document is malformed or
            contains a DTD or entity declaration.
    """
    import xml.etree.ElementTree as ET
    from xml.parsers import expat

    builder = ET.TreeBuilder()

    def _qualify(name: str) -> str:
        # expat reports namespaced names as "uri}local"; ElementTree spells
        # them "{uri}local".
        return "{" + name if "}" in name else name

    def _start(tag: str, attrs: dict[str, str]) -> None:
        builder.start(_qualify(tag), {_qualify(k): v for k, v in attrs.items()})

    def _end(tag: str) -> None:
        builder.end(_qualify(tag))

    def _forbid(what: str) -> Any:
        def _handler(*_args: Any) -> None:
            raise ET.ParseError(f"XML {what} is forbidden")

        return _handler

    parser = expat.ParserCreate(namespace_separator="}")
    parser.buffer_text = True
    parser.StartElementHandler = _start
    parser.EndElementHandler = _end
    parser.CharacterDataHandler = builder.data
    # Entities can only be declared inside a DTD, so refusing the DOCTYPE
    # already blocks XXE and entity-expansion bombs; the other handlers are
    # belt and braces.
    parser.StartDoctypeDeclHandler = _forbid("DTD")
    parser.EntityDeclHandler = _forbid("entity declaration")
    parser.UnparsedEntityDeclHandler = _forbid("unparsed entity declaration")
    parser.ExternalEntityRefHandler = _forbid("external entity reference")

    try:
        parser.Parse(content, True)
    except expat.ExpatError as exc:
        # Surface malformed input as ParseError, as ElementTree itself does.
        error = ET.ParseError(str(exc))
        error.code = exc.code
        error.position = (exc.lineno, exc.offset)
        raise error from exc
    return builder.close()


__all__ = [
    "ALLOWED_FILE_TYPES",
    "MAX_FILE_SIZE",
    "MAX_IMAGE_PIXELS",
    "MAX_ZIP_DECOMPRESSED_SIZE",
    "MAX_ZIP_RATIO",
    "check_image_bomb",
    "check_zip_bomb",
    "is_safe_ip",
    "parse_xml_safe",
    "sanitize_content",
    "sanitize_markdown",
    "validate_file_size",
    "validate_file_type",
]
失败时设置 error 状态 +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.rag_platform.document_processor import DocumentProcessor +from agentkit.rag_platform.models import DocumentStatus +from agentkit.rag_platform.preview import PreviewResult, generate_preview + + +class TestParseAndSegment: + """parse + segment 管道测试。""" + + def test_parse_extracts_text(self, tmp_path): + """parse 从文件提取文本。""" + # 创建测试文件 + file_path = tmp_path / "test.txt" + file_path.write_text("Hello, world!\nThis is a test document.", encoding="utf-8") + + processor = DocumentProcessor() + text = processor.parse(str(file_path), "txt") + + assert "Hello, world!" in text + assert "test document" in text + + def test_parse_applies_sanitization(self, tmp_path): + """parse 对文本格式应用内容净化。""" + file_path = tmp_path / "test.md" + file_path.write_text( + "Hello world", + encoding="utf-8", + ) + + processor = DocumentProcessor() + text = processor.parse(str(file_path), "md") + + assert " world" + result = sanitize_markdown(content) + assert "safe text" + ) + result = sanitize_markdown(content) + assert "" not in result + assert "line1" not in result + assert "line2" not in result + assert "before" in result + assert "after" in result + + +class TestSanitizeContent: + """sanitize_content 主入口测试。""" + + def test_markdown_sanitized(self): + """Markdown 格式应用净化。""" + content = "Hello " + result = sanitize_content(content, "md") + assert "" + result = sanitize_content(content, "txt") + assert "" + result = sanitize_content(content, "pdf") + assert result == content # 原样返回 + + +class TestIsSafeIp: + """SSRF IP 过滤测试。""" + + @pytest.mark.parametrize( + "ip", + [ + "127.0.0.1", + "127.0.1.1", + "10.0.0.1", + "172.16.0.1", + "172.31.255.255", + "192.168.1.1", + "169.254.1.1", + "0.0.0.0", + ], + ) + def test_private_ips_blocked(self, ip: str): + """私有/loopback IP 被拒绝。""" + assert is_safe_ip(ip) is False + + def test_ipv6_loopback_blocked(self): 
+ """IPv6 loopback 被拒绝。""" + assert is_safe_ip("::1") is False + + def test_ipv6_ula_blocked(self): + """IPv6 ULA 被拒绝。""" + assert is_safe_ip("fd00::1") is False + + def test_ipv6_link_local_blocked(self): + """IPv6 link-local 被拒绝。""" + assert is_safe_ip("fe80::1") is False + + @pytest.mark.parametrize("ip", ["8.8.8.8", "1.1.1.1", "203.0.113.1"]) + def test_public_ips_allowed(self, ip: str): + """公网 IP 允许。""" + assert is_safe_ip(ip) is True + + def test_invalid_ip_blocked(self): + """无效 IP 被拒绝。""" + assert is_safe_ip("not-an-ip") is False + + def test_empty_ip_blocked(self): + """空字符串被拒绝。""" + assert is_safe_ip("") is False + + +class TestCheckZipBomb: + """ZIP bomb 检测测试。""" + + def test_normal_zip_passes(self, tmp_path): + """正常 ZIP 文件通过。""" + zip_path = tmp_path / "normal.zip" + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + zf.writestr("file.txt", b"Hello, world! " * 100) + + # 不应抛出异常 + check_zip_bomb(str(zip_path)) + + def test_high_ratio_rejected(self, tmp_path): + """高压缩比 ZIP 被拒绝(ZIP bomb)。""" + zip_path = tmp_path / "bomb.zip" + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + # 10MB 全零数据 — 压缩比极高 + zf.writestr("bomb.txt", b"\x00" * 10_000_000) + + with pytest.raises(ValueError, match="Zip bomb"): + check_zip_bomb(str(zip_path)) + + def test_large_uncompressed_rejected_via_mock(self, tmp_path): + """解压后总大小超限被拒绝(使用 mock 避免实际写入大文件)。""" + zip_path = str(tmp_path / "large.zip") + + # 创建 mock ZipFile 返回超大文件信息 + mock_info = MagicMock() + mock_info.filename = "large.bin" + mock_info.compress_size = 10 * 1024 * 1024 # 10MB compressed + mock_info.file_size = 600 * 1024 * 1024 # 600MB uncompressed (> 500MB limit) + + mock_zipfile = MagicMock() + mock_zipfile.__enter__ = MagicMock(return_value=mock_zipfile) + mock_zipfile.__exit__ = MagicMock(return_value=False) + mock_zipfile.infolist.return_value = [mock_info] + + with patch("agentkit.rag_platform.sanitize.zipfile.ZipFile", return_value=mock_zipfile): + with 
pytest.raises(ValueError, match="Zip bomb"): + check_zip_bomb(zip_path) + + def test_docx_zip_format_works(self, tmp_path): + """.docx 格式(本质是 ZIP)能被检测。""" + zip_path = tmp_path / "fake.docx" + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + zf.writestr("document.xml", "") + + # 正常 docx 不应触发 zip bomb + check_zip_bomb(str(zip_path)) + + +class TestCheckImageBomb: + """Image bomb 检测测试。""" + + def _make_png(self, path: Path, width: int, height: int) -> None: + """创建一个最小 PNG 文件(仅 IHDR)。""" + png_header = b"\x89PNG\r\n\x1a\n" + ihdr_data = struct.pack(">II", width, height) + b"\x08\x02\x00\x00\x00" + ihdr_crc = zlib.crc32(b"IHDR" + ihdr_data) & 0xFFFFFFFF + ihdr_chunk = struct.pack(">I", 13) + b"IHDR" + ihdr_data + struct.pack(">I", ihdr_crc) + path.write_bytes(png_header + ihdr_chunk) + + def test_small_png_passes(self, tmp_path): + """小 PNG 图片通过。""" + img_path = tmp_path / "small.png" + self._make_png(img_path, 1, 1) + + # 不应抛出异常 + check_image_bomb(str(img_path)) + + def test_large_png_rejected(self, tmp_path): + """超大 PNG(像素数 > 100MP)被拒绝。""" + img_path = tmp_path / "huge.png" + # 20000x20000 = 400MP > 100MP + self._make_png(img_path, 20000, 20000) + + with pytest.raises(ValueError, match="Image bomb"): + check_image_bomb(str(img_path)) + + def test_unknown_format_passes(self, tmp_path): + """无法识别的格式视为安全。""" + img_path = tmp_path / "unknown.bin" + img_path.write_bytes(b"\x00" * 32) + + # 不应抛出异常 + check_image_bomb(str(img_path)) + + def test_nonexistent_file_passes(self, tmp_path): + """不存在的文件视为安全(不抛异常)。""" + check_image_bomb(str(tmp_path / "nonexistent.png"))