"""U3+U7 测试 — 内容净化与上传安全。 测试场景: 1. 文件类型白名单(允许类型通过,.exe/.sh 拒绝) 2. 文件大小限制(超限拒绝) 3. Markdown 净化(script 标签移除) 4. SSRF IP 过滤(私有 IP 拒绝,公网 IP 允许) 5. ZIP bomb 检测(高压缩比拒绝) """ from __future__ import annotations import struct import zipfile import zlib from pathlib import Path from unittest.mock import MagicMock, patch import pytest from agentkit.rag_platform.sanitize import ( ALLOWED_FILE_TYPES, MAX_FILE_SIZE, check_image_bomb, check_zip_bomb, is_safe_ip, sanitize_content, sanitize_markdown, validate_file_size, validate_file_type, ) class TestValidateFileType: """文件类型白名单测试。""" @pytest.mark.parametrize("filename", ["doc.pdf", "doc.docx", "doc.xlsx"]) def test_allowed_office_types_pass(self, filename: str): """允许的 Office 格式通过。""" assert validate_file_type(filename) == Path(filename).suffix[1:].lower() @pytest.mark.parametrize("filename", ["notes.md", "data.csv", "page.html", "readme.txt"]) def test_allowed_text_types_pass(self, filename: str): """允许的文本格式通过。""" result = validate_file_type(filename) assert result == Path(filename).suffix[1:].lower() def test_pdf_returns_pdf(self): """PDF 文件返回 'pdf'。""" assert validate_file_type("report.pdf") == "pdf" @pytest.mark.parametrize("filename", ["malware.exe", "script.sh", "shell.bat"]) def test_dangerous_types_rejected(self, filename: str): """危险文件类型被拒绝。""" with pytest.raises(ValueError, match="not allowed"): validate_file_type(filename) def test_exe_rejected(self): """.exe 文件被拒绝。""" with pytest.raises(ValueError, match="not allowed"): validate_file_type("program.exe") def test_sh_rejected(self): """.sh 文件被拒绝。""" with pytest.raises(ValueError, match="not allowed"): validate_file_type("script.sh") def test_case_insensitive_extension(self): """扩展名大小写不敏感。""" assert validate_file_type("DOC.PDF") == "pdf" assert validate_file_type("doc.MD") == "md" def test_no_extension_rejected(self): """无扩展名文件被拒绝。""" with pytest.raises(ValueError, match="not allowed"): validate_file_type("noextension") def test_all_allowed_types_in_whitelist(self): """白名单包含 8 种类型。""" expected = {".pdf", ".docx", ".xlsx", ".pptx", ".txt", ".md", ".csv", ".html"} assert set(ALLOWED_FILE_TYPES.keys()) == expected class TestValidateFileSize: """文件大小限制测试。""" def test_small_file_passes(self): """小文件通过。""" validate_file_size(1024) validate_file_size(1) def test_exact_limit_passes(self): """刚好等于上限的文件通过。""" validate_file_size(MAX_FILE_SIZE) def test_oversized_rejected(self): """超限文件被拒绝。""" with pytest.raises(ValueError, match="exceeds limit"): validate_file_size(MAX_FILE_SIZE + 1) def test_zero_size_rejected(self): """零字节文件被拒绝。""" with pytest.raises(ValueError, match="must be positive"): validate_file_size(0) def test_negative_size_rejected(self): """负大小文件被拒绝。""" with pytest.raises(ValueError, match="must be positive"): validate_file_size(-1) class TestSanitizeMarkdown: """Markdown 净化测试。""" def test_script_tag_removed(self): """script 标签及其内容被移除。""" content = "Hello world" result = sanitize_markdown(content) assert "safe text" ) result = sanitize_markdown(content) assert "" not in result assert "line1" not in result assert "line2" not in result assert "before" in result assert "after" in result class TestSanitizeContent: """sanitize_content 主入口测试。""" def test_markdown_sanitized(self): """Markdown 格式应用净化。""" content = "Hello " result = sanitize_content(content, "md") assert "" result = sanitize_content(content, "txt") assert "" result = sanitize_content(content, "pdf") assert result == content # 原样返回 class TestIsSafeIp: """SSRF IP 过滤测试。""" @pytest.mark.parametrize( "ip", [ "127.0.0.1", "127.0.1.1", "10.0.0.1", "172.16.0.1", "172.31.255.255", "192.168.1.1", "169.254.1.1", "0.0.0.0", ], ) def test_private_ips_blocked(self, ip: str): """私有/loopback IP 被拒绝。""" assert is_safe_ip(ip) is False def test_ipv6_loopback_blocked(self): """IPv6 loopback 被拒绝。""" assert is_safe_ip("::1") is False def test_ipv6_ula_blocked(self): """IPv6 ULA 被拒绝。""" assert is_safe_ip("fd00::1") is False def test_ipv6_link_local_blocked(self): """IPv6 link-local 被拒绝。""" assert is_safe_ip("fe80::1") is False @pytest.mark.parametrize("ip", ["8.8.8.8", "1.1.1.1", "203.0.113.1"]) def test_public_ips_allowed(self, ip: str): """公网 IP 允许。""" assert is_safe_ip(ip) is True def test_invalid_ip_blocked(self): """无效 IP 被拒绝。""" assert is_safe_ip("not-an-ip") is False def test_empty_ip_blocked(self): """空字符串被拒绝。""" assert is_safe_ip("") is False class TestCheckZipBomb: """ZIP bomb 检测测试。""" def test_normal_zip_passes(self, tmp_path): """正常 ZIP 文件通过。""" zip_path = tmp_path / "normal.zip" with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: zf.writestr("file.txt", b"Hello, world! " * 100) # 不应抛出异常 check_zip_bomb(str(zip_path)) def test_high_ratio_rejected(self, tmp_path): """高压缩比 ZIP 被拒绝(ZIP bomb)。""" zip_path = tmp_path / "bomb.zip" with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: # 10MB 全零数据 — 压缩比极高 zf.writestr("bomb.txt", b"\x00" * 10_000_000) with pytest.raises(ValueError, match="Zip bomb"): check_zip_bomb(str(zip_path)) def test_large_uncompressed_rejected_via_mock(self, tmp_path): """解压后总大小超限被拒绝(使用 mock 避免实际写入大文件)。""" zip_path = str(tmp_path / "large.zip") # 创建 mock ZipFile 返回超大文件信息 mock_info = MagicMock() mock_info.filename = "large.bin" mock_info.compress_size = 10 * 1024 * 1024 # 10MB compressed mock_info.file_size = 600 * 1024 * 1024 # 600MB uncompressed (> 500MB limit) mock_zipfile = MagicMock() mock_zipfile.__enter__ = MagicMock(return_value=mock_zipfile) mock_zipfile.__exit__ = MagicMock(return_value=False) mock_zipfile.infolist.return_value = [mock_info] with patch("agentkit.rag_platform.sanitize.zipfile.ZipFile", return_value=mock_zipfile): with pytest.raises(ValueError, match="Zip bomb"): check_zip_bomb(zip_path) def test_docx_zip_format_works(self, tmp_path): """.docx 格式(本质是 ZIP)能被检测。""" zip_path = tmp_path / "fake.docx" with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: zf.writestr("document.xml", "") # 正常 docx 不应触发 zip bomb check_zip_bomb(str(zip_path)) class TestCheckImageBomb: """Image bomb 检测测试。""" def _make_png(self, path: Path, width: int, height: int) -> None: """创建一个最小 PNG 文件(仅 IHDR)。""" png_header = b"\x89PNG\r\n\x1a\n" ihdr_data = struct.pack(">II", width, height) + b"\x08\x02\x00\x00\x00" ihdr_crc = zlib.crc32(b"IHDR" + ihdr_data) & 0xFFFFFFFF ihdr_chunk = struct.pack(">I", 13) + b"IHDR" + ihdr_data + struct.pack(">I", ihdr_crc) path.write_bytes(png_header + ihdr_chunk) def test_small_png_passes(self, tmp_path): """小 PNG 图片通过。""" img_path = tmp_path / "small.png" self._make_png(img_path, 1, 1) # 不应抛出异常 check_image_bomb(str(img_path)) def test_large_png_rejected(self, tmp_path): """超大 PNG(像素数 > 100MP)被拒绝。""" img_path = tmp_path / "huge.png" # 20000x20000 = 400MP > 100MP self._make_png(img_path, 20000, 20000) with pytest.raises(ValueError, match="Image bomb"): check_image_bomb(str(img_path)) def test_unknown_format_passes(self, tmp_path): """无法识别的格式视为安全。""" img_path = tmp_path / "unknown.bin" img_path.write_bytes(b"\x00" * 32) # 不应抛出异常 check_image_bomb(str(img_path)) def test_nonexistent_file_passes(self, tmp_path): """不存在的文件视为安全(不抛异常)。""" check_image_bomb(str(tmp_path / "nonexistent.png"))