"""U3+U7 测试 — 内容净化与上传安全。
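Test scenarios:
1. File-type whitelist (allowed types pass; .exe/.sh are rejected)
2. File-size limit (oversized files are rejected)
3. Markdown sanitization (script tags are removed)
4. SSRF IP filtering (private IPs are rejected, public IPs are allowed)
5. ZIP bomb detection (high compression ratios are rejected)
6. Image bomb detection (images with an excessive pixel count are rejected)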
"""
from __future__ import annotations
import struct
import zipfile
import zlib
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from agentkit.rag_platform.sanitize import (
ALLOWED_FILE_TYPES,
MAX_FILE_SIZE,
check_image_bomb,
check_zip_bomb,
is_safe_ip,
sanitize_content,
sanitize_markdown,
validate_file_size,
validate_file_type,
)
class TestValidateFileType:
    """Tests for the file-type whitelist enforced by ``validate_file_type``."""

    @pytest.mark.parametrize("filename", ["doc.pdf", "doc.docx", "doc.xlsx"])
    def test_allowed_office_types_pass(self, filename: str):
        """Allowed document formats pass and yield the lowercase extension without the dot."""
        assert validate_file_type(filename) == Path(filename).suffix[1:].lower()

    @pytest.mark.parametrize("filename", ["notes.md", "data.csv", "page.html", "readme.txt"])
    def test_allowed_text_types_pass(self, filename: str):
        """Allowed text formats pass and yield the lowercase extension without the dot."""
        result = validate_file_type(filename)
        assert result == Path(filename).suffix[1:].lower()

    def test_pdf_returns_pdf(self):
        """A PDF file name returns ``'pdf'``."""
        assert validate_file_type("report.pdf") == "pdf"

    @pytest.mark.parametrize("filename", ["malware.exe", "script.sh", "shell.bat"])
    def test_dangerous_types_rejected(self, filename: str):
        """Dangerous file types are rejected with a ``ValueError``."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type(filename)

    def test_exe_rejected(self):
        """``.exe`` files are rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("program.exe")

    def test_sh_rejected(self):
        """``.sh`` files are rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("script.sh")

    def test_case_insensitive_extension(self):
        """Extension matching ignores case."""
        assert validate_file_type("DOC.PDF") == "pdf"
        assert validate_file_type("doc.MD") == "md"

    def test_no_extension_rejected(self):
        """A file name without an extension is rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("noextension")

    def test_all_allowed_types_in_whitelist(self):
        """The whitelist contains exactly the eight expected extensions."""
        expected = {".pdf", ".docx", ".xlsx", ".pptx", ".txt", ".md", ".csv", ".html"}
        assert set(ALLOWED_FILE_TYPES.keys()) == expected
class TestValidateFileSize:
    """Tests for the upload size limit enforced by ``validate_file_size``."""

    def test_small_file_passes(self):
        """Small positive sizes are accepted without raising."""
        validate_file_size(1024)
        validate_file_size(1)

    def test_exact_limit_passes(self):
        """A size exactly equal to ``MAX_FILE_SIZE`` is accepted."""
        validate_file_size(MAX_FILE_SIZE)

    def test_oversized_rejected(self):
        """One byte over the limit is rejected."""
        with pytest.raises(ValueError, match="exceeds limit"):
            validate_file_size(MAX_FILE_SIZE + 1)

    def test_zero_size_rejected(self):
        """A zero-byte size is rejected."""
        with pytest.raises(ValueError, match="must be positive"):
            validate_file_size(0)

    def test_negative_size_rejected(self):
        """A negative size is rejected."""
        with pytest.raises(ValueError, match="must be positive"):
            validate_file_size(-1)
class TestSanitizeMarkdown:
    """Tests for Markdown / content sanitization.

    NOTE(review): the checked-in text of this class was corrupted — embedded
    HTML payloads and most assertions were stripped, leaving fragments such
    as ``assert ""`` and an unmatched ``)``. The tests below are reconstructed
    from what survived (the docstring, the calls to ``sanitize_markdown`` and
    ``sanitize_content``, and the pass-through comment). Confirm the exact
    payloads and expectations against the sanitizer before relying on them.
    """

    def test_script_tag_removed(self):
        """The script tag and its content are removed."""
        content = "Hello <script>alert('xss')</script> world"
        result = sanitize_markdown(content)
        assert "<script" not in result
        assert "alert" not in result

    def test_safe_text_preserved(self):
        """Text outside the script tag survives sanitization."""
        # NOTE(review): reconstructed; only the trailing "safe text" literal
        # and the closing parenthesis of this payload survived.
        content = (
            "<script>alert('xss')</script>"
            "safe text"
        )
        result = sanitize_markdown(content)
        assert "<script" not in result
        assert "safe text" in result

    def test_sanitize_content_md(self):
        """Markdown content is sanitized when routed through ``sanitize_content``."""
        content = "<script>alert('xss')</script>"
        result = sanitize_content(content, "md")
        assert "<script" not in result

    def test_sanitize_content_txt(self):
        """Plain-text content is sanitized when routed through ``sanitize_content``."""
        # NOTE(review): presumably "txt" is sanitized like "md" (the surviving
        # fragments of both tests are identical) — verify.
        content = "<script>alert('xss')</script>"
        result = sanitize_content(content, "txt")
        assert "<script" not in result

    def test_sanitize_content_pdf_passthrough(self):
        """Content declared as ``pdf`` is returned unchanged."""
        content = "<script>alert('xss')</script>"
        result = sanitize_content(content, "pdf")
        assert result == content  # returned as-is
class TestIsSafeIp:
    """Tests for the SSRF IP filter ``is_safe_ip``."""

    @pytest.mark.parametrize(
        "ip",
        [
            "127.0.0.1",
            "127.0.1.1",
            "10.0.0.1",
            "172.16.0.1",
            "172.31.255.255",
            "192.168.1.1",
            "169.254.1.1",
            "0.0.0.0",
        ],
    )
    def test_private_ips_blocked(self, ip: str):
        """Private, loopback, link-local and unspecified IPv4 addresses are rejected."""
        assert is_safe_ip(ip) is False

    def test_ipv6_loopback_blocked(self):
        """The IPv6 loopback address is rejected."""
        assert is_safe_ip("::1") is False

    def test_ipv6_ula_blocked(self):
        """An IPv6 unique-local address (ULA) is rejected."""
        assert is_safe_ip("fd00::1") is False

    def test_ipv6_link_local_blocked(self):
        """An IPv6 link-local address is rejected."""
        assert is_safe_ip("fe80::1") is False

    # NOTE(review): 203.0.113.1 belongs to TEST-NET-3 (RFC 5737, reserved for
    # documentation). The stdlib ``ipaddress`` module reports it as
    # ``is_private``, so this case only passes if ``is_safe_ip`` uses its own
    # blocklist — confirm that treating it as public is intended.
    @pytest.mark.parametrize("ip", ["8.8.8.8", "1.1.1.1", "203.0.113.1"])
    def test_public_ips_allowed(self, ip: str):
        """Public IPv4 addresses are allowed."""
        assert is_safe_ip(ip) is True

    def test_invalid_ip_blocked(self):
        """A string that is not an IP address is rejected."""
        assert is_safe_ip("not-an-ip") is False

    def test_empty_ip_blocked(self):
        """An empty string is rejected."""
        assert is_safe_ip("") is False
class TestCheckZipBomb:
    """Tests for ZIP bomb detection via ``check_zip_bomb``."""

    def test_normal_zip_passes(self, tmp_path):
        """An ordinary ZIP archive passes the check."""
        zip_path = tmp_path / "normal.zip"
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("file.txt", b"Hello, world! " * 100)
        # Must not raise.
        check_zip_bomb(str(zip_path))

    def test_high_ratio_rejected(self, tmp_path):
        """An archive with an extreme compression ratio is rejected as a ZIP bomb."""
        zip_path = tmp_path / "bomb.zip"
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
            # 10 MB of zero bytes compresses extremely well.
            zf.writestr("bomb.txt", b"\x00" * 10_000_000)
        with pytest.raises(ValueError, match="Zip bomb"):
            check_zip_bomb(str(zip_path))

    def test_large_uncompressed_rejected_via_mock(self, tmp_path):
        """An oversized total uncompressed size is rejected.

        ``ZipFile`` is mocked so that no large file has to be written to disk.
        """
        zip_path = str(tmp_path / "large.zip")
        # Fake archive entry reporting a huge uncompressed size.
        mock_info = MagicMock()
        mock_info.filename = "large.bin"
        mock_info.compress_size = 10 * 1024 * 1024  # 10MB compressed
        mock_info.file_size = 600 * 1024 * 1024  # 600MB uncompressed (> 500MB limit)
        # The mock doubles as its own context manager so that
        # ``with ZipFile(...) as zf`` yields it.
        mock_zipfile = MagicMock()
        mock_zipfile.__enter__ = MagicMock(return_value=mock_zipfile)
        mock_zipfile.__exit__ = MagicMock(return_value=False)
        mock_zipfile.infolist.return_value = [mock_info]
        with patch("agentkit.rag_platform.sanitize.zipfile.ZipFile", return_value=mock_zipfile):
            with pytest.raises(ValueError, match="Zip bomb"):
                check_zip_bomb(zip_path)

    def test_docx_zip_format_works(self, tmp_path):
        """A ``.docx`` file (a ZIP container) can be inspected."""
        zip_path = tmp_path / "fake.docx"
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
            # NOTE(review): the payload is an empty string; if XML markup was
            # intended here it may have been lost — confirm.
            zf.writestr("document.xml", "")
        # A normal docx must not trigger the zip-bomb check.
        check_zip_bomb(str(zip_path))
class TestCheckImageBomb:
    """Tests for image bomb detection via ``check_image_bomb``."""

    def _make_png(self, path: Path, width: int, height: int) -> None:
        """Write a minimal PNG to *path*: the signature plus a single IHDR chunk.

        The IHDR payload is the big-endian width and height followed by the
        five bytes ``08 02 00 00 00``; the chunk CRC covers the chunk type
        and payload.
        """
        png_header = b"\x89PNG\r\n\x1a\n"
        ihdr_data = struct.pack(">II", width, height) + b"\x08\x02\x00\x00\x00"
        ihdr_crc = zlib.crc32(b"IHDR" + ihdr_data) & 0xFFFFFFFF
        # Chunk layout: 4-byte length (13), type, payload, 4-byte CRC.
        ihdr_chunk = struct.pack(">I", 13) + b"IHDR" + ihdr_data + struct.pack(">I", ihdr_crc)
        path.write_bytes(png_header + ihdr_chunk)

    def test_small_png_passes(self, tmp_path):
        """A small PNG passes the check."""
        img_path = tmp_path / "small.png"
        self._make_png(img_path, 1, 1)
        # Must not raise.
        check_image_bomb(str(img_path))

    def test_large_png_rejected(self, tmp_path):
        """A PNG whose pixel count exceeds 100 MP is rejected."""
        img_path = tmp_path / "huge.png"
        # 20000x20000 = 400MP > 100MP
        self._make_png(img_path, 20000, 20000)
        with pytest.raises(ValueError, match="Image bomb"):
            check_image_bomb(str(img_path))

    def test_unknown_format_passes(self, tmp_path):
        """An unrecognized file format is treated as safe."""
        img_path = tmp_path / "unknown.bin"
        img_path.write_bytes(b"\x00" * 32)
        # Must not raise.
        check_image_bomb(str(img_path))

    def test_nonexistent_file_passes(self, tmp_path):
        """A missing file is treated as safe (no exception is raised)."""
        check_image_bomb(str(tmp_path / "nonexistent.png"))