347 lines
12 KiB
Python
347 lines
12 KiB
Python
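"""U3+U7 tests — content sanitization and upload security.

Scenarios covered:

1. File-type whitelist (allowed types pass; .exe/.sh are rejected)
2. File-size limit (oversized files are rejected)
3. Markdown sanitization (script tags are removed)
4. SSRF IP filtering (private IPs are rejected, public IPs are allowed)
5. ZIP bomb detection (high compression ratios are rejected)
6. Image bomb detection (images with an oversized pixel count are rejected)
"""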
|
||
|
||
from __future__ import annotations
|
||
|
||
import struct
|
||
import zipfile
|
||
import zlib
|
||
from pathlib import Path
|
||
from unittest.mock import MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.rag_platform.sanitize import (
|
||
ALLOWED_FILE_TYPES,
|
||
MAX_FILE_SIZE,
|
||
check_image_bomb,
|
||
check_zip_bomb,
|
||
is_safe_ip,
|
||
sanitize_content,
|
||
sanitize_markdown,
|
||
validate_file_size,
|
||
validate_file_type,
|
||
)
|
||
|
||
|
||
class TestValidateFileType:
    """Whitelist checks for ``validate_file_type``.

    Whitelisted extensions must be accepted and returned lower-cased without the
    leading dot; anything else (executables, shell scripts, extension-less names)
    must raise ``ValueError`` whose message contains "not allowed".
    """

    # Covers every binary document format in the whitelist, including .pptx,
    # which test_all_allowed_types_in_whitelist lists but nothing else exercised.
    @pytest.mark.parametrize("filename", ["doc.pdf", "doc.docx", "doc.xlsx", "doc.pptx"])
    def test_allowed_office_types_pass(self, filename: str):
        """Whitelisted document formats return their bare lower-case extension."""
        assert validate_file_type(filename) == Path(filename).suffix[1:].lower()

    @pytest.mark.parametrize("filename", ["notes.md", "data.csv", "page.html", "readme.txt"])
    def test_allowed_text_types_pass(self, filename: str):
        """Whitelisted text formats return their bare lower-case extension."""
        result = validate_file_type(filename)
        assert result == Path(filename).suffix[1:].lower()

    def test_pdf_returns_pdf(self):
        """A PDF file name yields the literal string 'pdf'."""
        assert validate_file_type("report.pdf") == "pdf"

    @pytest.mark.parametrize("filename", ["malware.exe", "script.sh", "shell.bat"])
    def test_dangerous_types_rejected(self, filename: str):
        """Executable and script extensions are rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type(filename)

    def test_exe_rejected(self):
        """A .exe file is rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("program.exe")

    def test_sh_rejected(self):
        """A .sh file is rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("script.sh")

    def test_case_insensitive_extension(self):
        """Extension matching ignores case and the result is lower-cased."""
        assert validate_file_type("DOC.PDF") == "pdf"
        assert validate_file_type("doc.MD") == "md"

    def test_no_extension_rejected(self):
        """A file name without any extension is rejected."""
        with pytest.raises(ValueError, match="not allowed"):
            validate_file_type("noextension")

    def test_all_allowed_types_in_whitelist(self):
        """The whitelist contains exactly these 8 extensions."""
        expected = {".pdf", ".docx", ".xlsx", ".pptx", ".txt", ".md", ".csv", ".html"}
        assert set(ALLOWED_FILE_TYPES.keys()) == expected
|
||
|
||
|
||
class TestValidateFileSize:
    """Upload size limit checks for ``validate_file_size``."""

    def test_small_file_passes(self):
        """Sizes well below the limit are accepted without raising."""
        for size in (1024, 1):
            validate_file_size(size)

    def test_exact_limit_passes(self):
        """A size exactly equal to MAX_FILE_SIZE is still accepted (inclusive bound)."""
        limit = MAX_FILE_SIZE
        validate_file_size(limit)

    def test_oversized_rejected(self):
        """One byte over MAX_FILE_SIZE is rejected."""
        too_big = MAX_FILE_SIZE + 1
        with pytest.raises(ValueError, match="exceeds limit"):
            validate_file_size(too_big)

    def test_zero_size_rejected(self):
        """An empty (zero-byte) file is rejected."""
        with pytest.raises(ValueError, match="must be positive"):
            validate_file_size(0)

    def test_negative_size_rejected(self):
        """A negative size is rejected."""
        with pytest.raises(ValueError, match="must be positive"):
            validate_file_size(-1)
|
||
|
||
|
||
class TestSanitizeMarkdown:
    """Removal of dangerous HTML from Markdown text via ``sanitize_markdown``."""

    def test_script_tag_removed(self):
        """A <script> element is dropped along with its body; surrounding text survives."""
        cleaned = sanitize_markdown("Hello <script>alert('xss')</script> world")
        for banned in ("<script>", "alert"):
            assert banned not in cleaned
        for kept in ("Hello", "world"):
            assert kept in cleaned

    def test_iframe_tag_removed(self):
        """An <iframe> element is dropped together with its src target."""
        cleaned = sanitize_markdown("Text <iframe src='evil.com'></iframe> more")
        for banned in ("<iframe", "evil.com"):
            assert banned not in cleaned

    def test_object_tag_removed(self):
        """An <object> element is dropped."""
        assert "<object" not in sanitize_markdown("<object data='evil.swf'></object>")

    def test_embed_tag_removed(self):
        """An unclosed <embed> element is dropped."""
        assert "<embed" not in sanitize_markdown("<embed src='evil.swf'>")

    def test_safe_content_preserved(self):
        """Plain Markdown containing no HTML comes back unchanged."""
        markdown = "# Title\n\nThis is **bold** and *italic*."
        assert sanitize_markdown(markdown) == markdown

    def test_javascript_protocol_removed(self):
        """``javascript:`` URLs are stripped (checked case-insensitively)."""
        cleaned = sanitize_markdown("<a href='javascript:alert(1)'>click</a>")
        assert "javascript:" not in cleaned.lower()

    def test_event_handler_removed(self):
        """An inline onclick event-handler attribute is stripped."""
        cleaned = sanitize_markdown("<div onclick='alert(1)'>text</div>")
        assert "onclick" not in cleaned

    def test_multiple_dangerous_tags_removed(self):
        """Several different dangerous elements in one input are all removed."""
        cleaned = sanitize_markdown(
            "<script>bad()</script><iframe src='x'></iframe>safe text<object data='y'></object>"
        )
        for banned in ("<script", "<iframe", "<object"):
            assert banned not in cleaned
        assert "safe text" in cleaned

    def test_multiline_script_removed(self):
        """A <script> element spanning several lines is removed in full."""
        cleaned = sanitize_markdown("before\n<script>\nline1\nline2\n</script>\nafter")
        for banned in ("<script>", "line1", "line2"):
            assert banned not in cleaned
        for kept in ("before", "after"):
            assert kept in cleaned
|
||
|
||
|
||
class TestSanitizeContent:
    """Format-based dispatch of the ``sanitize_content`` entry point."""

    def test_markdown_sanitized(self):
        """Content tagged ``md`` is sanitized."""
        cleaned = sanitize_content("Hello <script>alert(1)</script>", "md")
        assert "<script>" not in cleaned

    def test_html_sanitized(self):
        """Content tagged ``html`` is sanitized."""
        cleaned = sanitize_content("<iframe src='evil'></iframe>", "html")
        assert "<iframe" not in cleaned

    def test_text_sanitized(self):
        """Content tagged ``txt`` is sanitized."""
        cleaned = sanitize_content("text <script>x</script>", "txt")
        assert "<script>" not in cleaned

    def test_binary_format_not_sanitized(self):
        """Content tagged ``pdf`` (a binary format) is not sanitized."""
        raw = "raw <script>content</script>"
        # Returned verbatim, script tag and all.
        assert sanitize_content(raw, "pdf") == raw
|
||
|
||
|
||
class TestIsSafeIp:
    """SSRF address filtering via ``is_safe_ip``.

    Non-public destinations (loopback, private, link-local, unspecified) and
    unparsable input must be reported unsafe (``False``); publicly routable
    addresses must be reported safe (``True``).
    """

    @pytest.mark.parametrize(
        "ip",
        [
            "127.0.0.1",
            "127.0.1.1",
            "10.0.0.1",
            "172.16.0.1",
            "172.31.255.255",
            "192.168.1.1",
            "169.254.1.1",
            "0.0.0.0",
        ],
    )
    def test_private_ips_blocked(self, ip: str):
        """Loopback, RFC 1918 private, link-local and unspecified IPv4 addresses are rejected."""
        assert is_safe_ip(ip) is False

    def test_ipv6_loopback_blocked(self):
        """The IPv6 loopback address is rejected."""
        assert is_safe_ip("::1") is False

    def test_ipv6_ula_blocked(self):
        """An IPv6 unique-local (ULA) address is rejected."""
        assert is_safe_ip("fd00::1") is False

    def test_ipv6_link_local_blocked(self):
        """An IPv6 link-local address is rejected."""
        assert is_safe_ip("fe80::1") is False

    # Only genuinely globally routable addresses belong here. The previous
    # entry 203.0.113.1 is TEST-NET-3 (RFC 5737), a reserved documentation
    # range that Python's ``ipaddress`` classifies as private/non-global;
    # asserting it is "safe" pinned a gap in the SSRF filter instead of a
    # desired behavior.
    @pytest.mark.parametrize("ip", ["8.8.8.8", "1.1.1.1", "9.9.9.9"])
    def test_public_ips_allowed(self, ip: str):
        """Publicly routable IPv4 addresses are allowed."""
        assert is_safe_ip(ip) is True

    def test_invalid_ip_blocked(self):
        """A string that is not an IP address is rejected."""
        assert is_safe_ip("not-an-ip") is False

    def test_empty_ip_blocked(self):
        """The empty string is rejected."""
        assert is_safe_ip("") is False
|
||
|
||
|
||
class TestCheckZipBomb:
    """Zip-bomb detection via ``check_zip_bomb``."""

    @staticmethod
    def _build_zip(path, member, payload) -> str:
        """Write a single-member DEFLATE archive at *path* and return the path as a str."""
        with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as archive:
            archive.writestr(member, payload)
        return str(path)

    def test_normal_zip_passes(self, tmp_path):
        """An ordinary, modestly compressible archive is accepted."""
        archive_path = self._build_zip(
            tmp_path / "normal.zip", "file.txt", b"Hello, world! " * 100
        )
        check_zip_bomb(archive_path)  # must not raise

    def test_high_ratio_rejected(self, tmp_path):
        """An archive with an extreme compression ratio is rejected as a zip bomb."""
        # 10 MB of zero bytes deflates to almost nothing, giving a huge ratio.
        archive_path = self._build_zip(tmp_path / "bomb.zip", "bomb.txt", b"\x00" * 10_000_000)
        with pytest.raises(ValueError, match="Zip bomb"):
            check_zip_bomb(archive_path)

    def test_large_uncompressed_rejected_via_mock(self, tmp_path):
        """A huge declared uncompressed size is rejected (ZipFile is mocked, so nothing large is written)."""
        archive_path = str(tmp_path / "large.zip")

        fake_entry = MagicMock()
        fake_entry.filename = "large.bin"
        fake_entry.compress_size = 10 * 1024 * 1024  # 10 MB compressed
        fake_entry.file_size = 600 * 1024 * 1024  # 600 MB uncompressed, over the 500 MB limit

        # Make the mock usable as ``with ZipFile(...) as zf`` and yield itself.
        fake_archive = MagicMock()
        fake_archive.__enter__ = MagicMock(return_value=fake_archive)
        fake_archive.__exit__ = MagicMock(return_value=False)
        fake_archive.infolist.return_value = [fake_entry]

        with patch(
            "agentkit.rag_platform.sanitize.zipfile.ZipFile", return_value=fake_archive
        ), pytest.raises(ValueError, match="Zip bomb"):
            check_zip_bomb(archive_path)

    def test_docx_zip_format_works(self, tmp_path):
        """A .docx (a ZIP container) is inspected, and a normal one is accepted."""
        archive_path = self._build_zip(tmp_path / "fake.docx", "document.xml", "<doc/>")
        check_zip_bomb(archive_path)  # a regular docx must not trip the detector
|
||
|
||
|
||
class TestCheckImageBomb:
    """Decompression-bomb detection for images via ``check_image_bomb``."""

    def _make_png(self, path: Path, width: int, height: int) -> None:
        """Write a minimal PNG: the 8-byte signature followed by a lone IHDR chunk.

        No IDAT/IEND chunks are emitted; the file only declares its dimensions.
        """
        signature = b"\x89PNG\r\n\x1a\n"
        # IHDR fields: width, height, bit depth 8, colour type 2 (truecolour),
        # compression 0, filter 0, interlace 0 -> 13 bytes in total.
        fields = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
        typed = b"IHDR" + fields
        checksum = zlib.crc32(typed) & 0xFFFFFFFF
        chunk = b"".join((struct.pack(">I", len(fields)), typed, struct.pack(">I", checksum)))
        path.write_bytes(signature + chunk)

    def test_small_png_passes(self, tmp_path):
        """A 1x1 PNG is accepted."""
        target = tmp_path / "small.png"
        self._make_png(target, 1, 1)
        check_image_bomb(str(target))  # must not raise

    def test_large_png_rejected(self, tmp_path):
        """A PNG declaring 20000x20000 pixels (400 MP, above the 100 MP cap) is rejected."""
        target = tmp_path / "huge.png"
        self._make_png(target, 20000, 20000)
        with pytest.raises(ValueError, match="Image bomb"):
            check_image_bomb(str(target))

    def test_unknown_format_passes(self, tmp_path):
        """Bytes matching no recognised image format are treated as safe."""
        target = tmp_path / "unknown.bin"
        target.write_bytes(bytes(32))
        check_image_bomb(str(target))  # must not raise

    def test_nonexistent_file_passes(self, tmp_path):
        """A missing file is treated as safe rather than raising."""
        missing = tmp_path / "nonexistent.png"
        check_image_bomb(str(missing))
|