fischer-agentkit/tests/unit/rag_platform/test_sanitize.py

347 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U3+U7 测试 — 内容净化与上传安全。
测试场景:
1. 文件类型白名单(允许类型通过,.exe/.sh 拒绝)
2. 文件大小限制(超限拒绝)
3. Markdown 净化script 标签移除)
4. SSRF IP 过滤(私有 IP 拒绝,公网 IP 允许)
5. ZIP bomb 检测(高压缩比拒绝)
"""
from __future__ import annotations
import struct
import zipfile
import zlib
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from agentkit.rag_platform.sanitize import (
ALLOWED_FILE_TYPES,
MAX_FILE_SIZE,
check_image_bomb,
check_zip_bomb,
is_safe_ip,
sanitize_content,
sanitize_markdown,
validate_file_size,
validate_file_type,
)
class TestValidateFileType:
    """Tests for the upload file-type whitelist (``validate_file_type``)."""

    @pytest.mark.parametrize("filename", ["doc.pdf", "doc.docx", "doc.xlsx"])
    def test_allowed_office_types_pass(self, filename: str):
        """Whitelisted office formats are accepted."""
        # The expected value is the lower-cased extension without its dot.
        expected_type = Path(filename).suffix.lower()[1:]
        assert validate_file_type(filename) == expected_type

    @pytest.mark.parametrize("filename", ["notes.md", "data.csv", "page.html", "readme.txt"])
    def test_allowed_text_types_pass(self, filename: str):
        """Whitelisted text formats are accepted."""
        expected_type = Path(filename).suffix.lower()[1:]
        detected_type = validate_file_type(filename)
        assert detected_type == expected_type

    def test_pdf_returns_pdf(self):
        """A PDF file name yields the type string 'pdf'."""
        detected_type = validate_file_type("report.pdf")
        assert detected_type == "pdf"

    @pytest.mark.parametrize("filename", ["malware.exe", "script.sh", "shell.bat"])
    def test_dangerous_types_rejected(self, filename: str):
        """Dangerous file types are rejected with a ValueError."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type(filename)

    def test_exe_rejected(self):
        """A .exe file is rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("program.exe")

    def test_sh_rejected(self):
        """A .sh file is rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("script.sh")

    def test_case_insensitive_extension(self):
        """Extension matching ignores letter case."""
        cases = (("DOC.PDF", "pdf"), ("doc.MD", "md"))
        for filename, expected_type in cases:
            assert validate_file_type(filename) == expected_type

    def test_no_extension_rejected(self):
        """A file name without an extension is rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("noextension")

    def test_all_allowed_types_in_whitelist(self):
        """The whitelist contains exactly the eight expected extensions."""
        whitelist_keys = set(ALLOWED_FILE_TYPES.keys())
        assert whitelist_keys == {
            ".pdf",
            ".docx",
            ".xlsx",
            ".pptx",
            ".txt",
            ".md",
            ".csv",
            ".html",
        }
class TestValidateFileSize:
    """Tests for the upload file-size limit (``validate_file_size``)."""

    def test_small_file_passes(self):
        """Small files are accepted without raising."""
        for size in (1024, 1):
            validate_file_size(size)

    def test_exact_limit_passes(self):
        """A file exactly at the limit is accepted."""
        size_at_limit = MAX_FILE_SIZE
        validate_file_size(size_at_limit)

    def test_oversized_rejected(self):
        """A file one byte over the limit is rejected."""
        size_over_limit = MAX_FILE_SIZE + 1
        with pytest.raises(ValueError, match="exceeds limit"):
            validate_file_size(size_over_limit)

    def test_zero_size_rejected(self):
        """A zero-byte file is rejected."""
        with pytest.raises(ValueError, match="must be positive"):
            validate_file_size(0)

    def test_negative_size_rejected(self):
        """A negative size is rejected."""
        with pytest.raises(ValueError, match="must be positive"):
            validate_file_size(-1)
class TestSanitizeMarkdown:
    """Tests for Markdown sanitization (``sanitize_markdown``)."""

    def test_script_tag_removed(self):
        """A script tag and its content are removed; surrounding text stays."""
        cleaned = sanitize_markdown("Hello <script>alert('xss')</script> world")
        for forbidden in ("<script>", "alert"):
            assert forbidden not in cleaned
        for kept in ("Hello", "world"):
            assert kept in cleaned

    def test_iframe_tag_removed(self):
        """An iframe tag is removed together with its src value."""
        cleaned = sanitize_markdown("Text <iframe src='evil.com'></iframe> more")
        assert "<iframe" not in cleaned
        assert "evil.com" not in cleaned

    def test_object_tag_removed(self):
        """An object tag is removed."""
        cleaned = sanitize_markdown("<object data='evil.swf'></object>")
        assert "<object" not in cleaned

    def test_embed_tag_removed(self):
        """An embed tag is removed."""
        cleaned = sanitize_markdown("<embed src='evil.swf'>")
        assert "<embed" not in cleaned

    def test_safe_content_preserved(self):
        """Harmless Markdown is returned unchanged."""
        markdown = "# Title\n\nThis is **bold** and *italic*."
        assert sanitize_markdown(markdown) == markdown

    def test_javascript_protocol_removed(self):
        """The javascript: protocol is removed (checked case-insensitively)."""
        cleaned = sanitize_markdown("<a href='javascript:alert(1)'>click</a>")
        assert "javascript:" not in cleaned.lower()

    def test_event_handler_removed(self):
        """Event-handler attributes are removed."""
        cleaned = sanitize_markdown("<div onclick='alert(1)'>text</div>")
        assert "onclick" not in cleaned

    def test_multiple_dangerous_tags_removed(self):
        """Several dangerous tags in one input are all removed."""
        markup = (
            "<script>bad()</script>"
            "<iframe src='x'></iframe>"
            "safe text"
            "<object data='y'></object>"
        )
        cleaned = sanitize_markdown(markup)
        for opening in ("<script", "<iframe", "<object"):
            assert opening not in cleaned
        assert "safe text" in cleaned

    def test_multiline_script_removed(self):
        """A script tag spanning several lines is removed with its body."""
        markup = "before\n<script>\nline1\nline2\n</script>\nafter"
        cleaned = sanitize_markdown(markup)
        for forbidden in ("<script>", "line1", "line2"):
            assert forbidden not in cleaned
        for kept in ("before", "after"):
            assert kept in cleaned
class TestSanitizeContent:
    """Tests for the ``sanitize_content`` entry point."""

    def test_markdown_sanitized(self):
        """Content declared as Markdown is sanitized."""
        cleaned = sanitize_content("Hello <script>alert(1)</script>", "md")
        assert "<script>" not in cleaned

    def test_html_sanitized(self):
        """Content declared as HTML is sanitized."""
        cleaned = sanitize_content("<iframe src='evil'></iframe>", "html")
        assert "<iframe" not in cleaned

    def test_text_sanitized(self):
        """Content declared as plain text is sanitized."""
        cleaned = sanitize_content("text <script>x</script>", "txt")
        assert "<script>" not in cleaned

    def test_binary_format_not_sanitized(self):
        """Content declared as a binary format (pdf) is not sanitized."""
        raw = "raw <script>content</script>"
        # The input must come back exactly as given.
        assert sanitize_content(raw, "pdf") == raw
class TestIsSafeIp:
    """Tests for SSRF IP filtering (``is_safe_ip``)."""

    @pytest.mark.parametrize(
        "ip",
        [
            "127.0.0.1",
            "127.0.1.1",
            "10.0.0.1",
            "172.16.0.1",
            "172.31.255.255",
            "192.168.1.1",
            "169.254.1.1",
            "0.0.0.0",
        ],
    )
    def test_private_ips_blocked(self, ip: str):
        """Private and loopback addresses are rejected."""
        verdict = is_safe_ip(ip)
        assert verdict is False

    def test_ipv6_loopback_blocked(self):
        """The IPv6 loopback address is rejected."""
        verdict = is_safe_ip("::1")
        assert verdict is False

    def test_ipv6_ula_blocked(self):
        """An IPv6 unique-local address (ULA) is rejected."""
        verdict = is_safe_ip("fd00::1")
        assert verdict is False

    def test_ipv6_link_local_blocked(self):
        """An IPv6 link-local address is rejected."""
        verdict = is_safe_ip("fe80::1")
        assert verdict is False

    # NOTE(review): 203.0.113.1 lies in the RFC 5737 TEST-NET-3 documentation
    # range; confirm that treating it as a public address is intentional.
    @pytest.mark.parametrize("ip", ["8.8.8.8", "1.1.1.1", "203.0.113.1"])
    def test_public_ips_allowed(self, ip: str):
        """Public addresses are allowed."""
        verdict = is_safe_ip(ip)
        assert verdict is True

    def test_invalid_ip_blocked(self):
        """A string that is not an IP address is rejected."""
        verdict = is_safe_ip("not-an-ip")
        assert verdict is False

    def test_empty_ip_blocked(self):
        """An empty string is rejected."""
        verdict = is_safe_ip("")
        assert verdict is False
class TestCheckZipBomb:
    """Tests for ZIP bomb detection (``check_zip_bomb``)."""

    def test_normal_zip_passes(self, tmp_path):
        """An ordinary ZIP archive is accepted."""
        archive = tmp_path / "normal.zip"
        payload = b"Hello, world! " * 100
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("file.txt", payload)
        # Must not raise.
        check_zip_bomb(str(archive))

    def test_high_ratio_rejected(self, tmp_path):
        """A ZIP archive with a very high compression ratio is rejected."""
        archive = tmp_path / "bomb.zip"
        # 10 MB of zero bytes compresses extremely well.
        zeros = b"\x00" * 10_000_000
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("bomb.txt", zeros)
        with pytest.raises(ValueError, match="Zip bomb"):
            check_zip_bomb(str(archive))

    def test_large_uncompressed_rejected_via_mock(self, tmp_path):
        """An archive whose total uncompressed size is too large is rejected.

        ZipFile is mocked so that no large file has to be written to disk.
        """
        archive = str(tmp_path / "large.zip")
        mib = 1024 * 1024
        # Fake archive entry: 10 MB compressed, 600 MB uncompressed
        # (above the 500 MB limit noted in the original test).
        fake_entry = MagicMock(
            filename="large.bin",
            compress_size=10 * mib,
            file_size=600 * mib,
        )
        fake_zip = MagicMock()
        # Make the mock usable as a context manager that yields itself.
        fake_zip.__enter__.return_value = fake_zip
        fake_zip.__exit__.return_value = False
        fake_zip.infolist.return_value = [fake_entry]
        with patch(
            "agentkit.rag_platform.sanitize.zipfile.ZipFile", return_value=fake_zip
        ), pytest.raises(ValueError, match="Zip bomb"):
            check_zip_bomb(archive)

    def test_docx_zip_format_works(self, tmp_path):
        """A .docx file (a ZIP container) can be checked and passes."""
        archive = tmp_path / "fake.docx"
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.writestr("document.xml", "<doc/>")
        # A normal docx must not be flagged as a zip bomb.
        check_zip_bomb(str(archive))
class TestCheckImageBomb:
    """Tests for image bomb detection (``check_image_bomb``)."""

    def _make_png(self, path: Path, width: int, height: int) -> None:
        """Write a minimal PNG file (signature plus IHDR chunk only) to *path*."""
        signature = b"\x89PNG\r\n\x1a\n"
        chunk_type = b"IHDR"
        # Big-endian width and height followed by the five single-byte
        # fields 0x08, 0x02, 0x00, 0x00, 0x00.
        payload = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
        checksum = zlib.crc32(chunk_type + payload) & 0xFFFFFFFF
        chunk = b"".join(
            (
                struct.pack(">I", 13),
                chunk_type,
                payload,
                struct.pack(">I", checksum),
            )
        )
        path.write_bytes(signature + chunk)

    def test_small_png_passes(self, tmp_path):
        """A small PNG image is accepted."""
        image = tmp_path / "small.png"
        self._make_png(image, 1, 1)
        # Must not raise.
        check_image_bomb(str(image))

    def test_large_png_rejected(self, tmp_path):
        """A huge PNG (pixel count above 100 MP) is rejected."""
        image = tmp_path / "huge.png"
        # 20000 x 20000 = 400 MP, above the 100 MP threshold.
        self._make_png(image, 20000, 20000)
        with pytest.raises(ValueError, match="Image bomb"):
            check_image_bomb(str(image))

    def test_unknown_format_passes(self, tmp_path):
        """A file in an unrecognized format is treated as safe."""
        image = tmp_path / "unknown.bin"
        image.write_bytes(b"\x00" * 32)
        # Must not raise.
        check_image_bomb(str(image))

    def test_nonexistent_file_passes(self, tmp_path):
        """A missing file is treated as safe (no exception is raised)."""
        missing = tmp_path / "nonexistent.png"
        check_image_bomb(str(missing))