"""Unit tests for HeadroomCompressor.

All tests use a mocked headroom module, so headroom-ai does not need to be
installed.
"""
|
||
|
||
import hashlib
import json
import re
import time
from collections import OrderedDict
from unittest.mock import MagicMock, patch

import pytest

from agentkit.core.headroom_compressor import (
    HeadroomCompressor,
    _is_code_content,
    _is_json_content,
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _make_headroom_compress_mock(return_content="compressed"):
    """Build the mock *result object* that a patched ``headroom_compress`` returns.

    Args:
        return_content: Text placed in the ``content`` field of the single
            message carried by the result.

    Returns:
        A ``MagicMock`` whose ``messages`` attribute is a one-element list
        holding ``{"role": "user", "content": return_content}``.
    """
    result = MagicMock()
    result.messages = [{"role": "user", "content": return_content}]
    return result
|
||
|
||
|
||
def _long_json_content():
    """Return a JSON document (50 item objects) longer than ``min_length``.

    Returns:
        The ``json.dumps`` serialization of ``{"items": [...]}`` where each
        item has ``id``, ``name`` and ``description`` fields.
    """
    items = [
        {
            "id": i,
            "name": f"item_{i}",
            "description": f"description for item {i}",
        }
        for i in range(50)
    ]
    return json.dumps({"items": items})
|
||
|
||
|
||
def _long_code_content():
    """Return generated code-like text longer than ``min_length``.

    Returns:
        50 three-line function definitions (150 lines in total) joined with
        newlines, without a trailing newline.
    """
    lines = []
    for i in range(50):
        lines.extend(
            (
                f"def function_{i}():",
                f" result = process_data({i})",
                # Plain literal: this line has nothing to interpolate.
                " return result",
            )
        )
    return "\n".join(lines)
|
||
|
||
|
||
def _long_text_content():
    """Return plain prose (one sentence repeated 100 times) longer than ``min_length``."""
    return "This is plain text content. " * 100
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TestHeadroomAvailability
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestHeadroomAvailability:
    """Tests for headroom-ai availability detection."""

    def test_is_available_false_when_not_installed(self):
        """is_available() returns False while _HEADROOM_AVAILABLE is False."""
        compressor = HeadroomCompressor({})
        # The assertion must run while the module flag is patched.
        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", False):
            assert compressor.is_available() is False

    def test_is_available_true_when_installed(self):
        """is_available() returns True while _HEADROOM_AVAILABLE is True."""
        compressor = HeadroomCompressor({})
        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True):
            assert compressor.is_available() is True
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TestContentTypeDetection
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestContentTypeDetection:
    """Tests for the content-type detection helpers."""

    def test_json_content_detected(self):
        """A valid JSON object is detected as JSON."""
        assert _is_json_content('{"key": "value"}') is True

    def test_json_array_detected(self):
        """A valid JSON array is detected as JSON."""
        assert _is_json_content('[1, 2, 3]') is True

    def test_non_json_content(self):
        """Ordinary text is not detected as JSON."""
        assert _is_json_content("hello world") is False

    def test_invalid_json_start(self):
        """Text that starts with ``{`` but is not valid JSON is rejected."""
        assert _is_json_content("{invalid") is False

    def test_code_content_detected(self):
        """Text containing def/class/import statements is detected as code."""
        code = "def hello():\n pass\n\nclass Foo:\n pass\nimport os\nfrom sys import path"
        assert _is_code_content(code) is True

    def test_non_code_content(self):
        """Plain prose is not detected as code."""
        text = "This is just a regular paragraph of text with no code keywords at all."
        assert _is_code_content(text) is False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TestCompressToolResult
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestCompressToolResult:
    """Tests for the compress_tool_result method."""

    @pytest.mark.asyncio
    async def test_short_content_not_compressed(self):
        """Content shorter than min_length is returned unchanged."""
        compressor = HeadroomCompressor({"min_length": 500})
        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True):
            result = await compressor.compress_tool_result("test_tool", "short content")
        assert result == "short content"

    @pytest.mark.asyncio
    async def test_json_content_compressed_with_smart_crusher(self):
        """Long JSON content is compressed and tagged with a CCR hash marker."""
        compressor = HeadroomCompressor({
            "min_length": 100,
            "compressors": ["smart_crusher", "code_compressor"],
        })
        json_content = _long_json_content()
        mock_fn = MagicMock(return_value=_make_headroom_compress_mock("compressed json"))

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True), \
                patch("agentkit.core.headroom_compressor.headroom_compress", mock_fn):
            result = await compressor.compress_tool_result("json_tool", json_content)

        assert "compressed json" in result
        assert "<!-- CCR:hash=" in result
        mock_fn.assert_called_once()

    @pytest.mark.asyncio
    async def test_code_content_compressed_with_code_compressor(self):
        """Long code content is compressed and tagged with a CCR hash marker."""
        compressor = HeadroomCompressor({
            "min_length": 100,
            "compressors": ["smart_crusher", "code_compressor"],
        })
        code_content = _long_code_content()
        mock_fn = MagicMock(return_value=_make_headroom_compress_mock("compressed code"))

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True), \
                patch("agentkit.core.headroom_compressor.headroom_compress", mock_fn):
            result = await compressor.compress_tool_result("code_tool", code_content)

        assert "compressed code" in result
        assert "<!-- CCR:hash=" in result

    @pytest.mark.asyncio
    async def test_text_content_not_compressed(self):
        """Plain text is returned unchanged when only the JSON/code compressors are enabled."""
        compressor = HeadroomCompressor({
            "min_length": 100,
            "compressors": ["smart_crusher", "code_compressor"],
        })
        text_content = _long_text_content()

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True):
            result = await compressor.compress_tool_result("text_tool", text_content)

        assert result == text_content

    @pytest.mark.asyncio
    async def test_compression_includes_ccr_hash(self):
        """Compressed output carries the CCR hash marker."""
        compressor = HeadroomCompressor({
            "min_length": 100,
            "compressors": ["smart_crusher"],
        })
        json_content = _long_json_content()
        mock_fn = MagicMock(return_value=_make_headroom_compress_mock("small"))

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True), \
                patch("agentkit.core.headroom_compressor.headroom_compress", mock_fn):
            result = await compressor.compress_tool_result("tool", json_content)

        assert "<!-- CCR:hash=" in result

    @pytest.mark.asyncio
    async def test_headroom_not_available_returns_original(self):
        """The original content is returned when headroom-ai is unavailable."""
        compressor = HeadroomCompressor({"min_length": 100})
        json_content = _long_json_content()

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", False):
            result = await compressor.compress_tool_result("tool", json_content)

        assert result == json_content

    @pytest.mark.asyncio
    async def test_compression_failure_falls_back(self):
        """The original content is returned when headroom_compress raises."""
        compressor = HeadroomCompressor({
            "min_length": 100,
            "compressors": ["smart_crusher"],
        })
        json_content = _long_json_content()
        mock_fn = MagicMock(side_effect=RuntimeError("headroom error"))

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True), \
                patch("agentkit.core.headroom_compressor.headroom_compress", mock_fn):
            result = await compressor.compress_tool_result("tool", json_content)

        assert result == json_content
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TestCompress
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestCompress:
    """Tests for the compress method (message-list compression)."""

    @pytest.mark.asyncio
    async def test_tool_messages_compressed(self):
        """A sufficiently long role=tool message is compressed; roles and count are kept."""
        compressor = HeadroomCompressor({"min_length": 100})
        long_content = "x" * 500
        messages = [
            {"role": "system", "content": "system prompt"},
            {"role": "tool", "content": long_content, "tool_call_id": "call_1"},
        ]
        mock_fn = MagicMock(return_value=_make_headroom_compress_mock("compressed"))

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True), \
                patch("agentkit.core.headroom_compressor.headroom_compress", mock_fn):
            result = await compressor.compress(messages)

        assert len(result) == 2
        assert result[0]["role"] == "system"
        assert result[1]["role"] == "tool"
        assert "compressed" in result[1]["content"]
        assert "<!-- CCR:hash=" in result[1]["content"]

    @pytest.mark.asyncio
    async def test_non_tool_messages_preserved(self):
        """system/user/assistant messages come back unchanged, even when long."""
        compressor = HeadroomCompressor({"min_length": 100})
        messages = [
            {"role": "system", "content": "system prompt"},
            {"role": "user", "content": "x" * 500},
            {"role": "assistant", "content": "y" * 500},
        ]

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True):
            result = await compressor.compress(messages)

        assert result == messages

    @pytest.mark.asyncio
    async def test_short_tool_messages_not_compressed(self):
        """A tool message shorter than min_length keeps its content."""
        compressor = HeadroomCompressor({"min_length": 500})
        messages = [
            {"role": "tool", "content": "short", "tool_call_id": "call_1"},
        ]

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True):
            result = await compressor.compress(messages)

        assert result[0]["content"] == "short"

    @pytest.mark.asyncio
    async def test_headroom_not_available_returns_original(self):
        """The original message list is returned when headroom-ai is unavailable."""
        compressor = HeadroomCompressor({"min_length": 100})
        messages = [
            {"role": "tool", "content": "x" * 500, "tool_call_id": "call_1"},
        ]

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", False):
            result = await compressor.compress(messages)

        assert result == messages
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TestCCRRetrieve
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestCCRRetrieve:
    """Tests for retrieving original content from the CCR cache."""

    def test_retrieve_by_hash(self):
        """Stored content can be fetched back by the hash _store_ccr returned."""
        compressor = HeadroomCompressor({})
        original = "this is the original content"
        ccr_hash = compressor._store_ccr(original)
        assert ccr_hash is not None

        result = compressor.retrieve(ccr_hash=ccr_hash)
        assert result["success"] is True
        assert result["content"] == original
        assert result["ccr_hash"] == ccr_hash

    def test_retrieve_by_query(self):
        """A keyword query returns at least one cached entry containing the keyword."""
        compressor = HeadroomCompressor({})
        compressor._store_ccr("Python is a programming language")
        compressor._store_ccr("Rust is a systems programming language")

        result = compressor.retrieve(query="Python")
        assert result["success"] is True
        assert len(result["results"]) >= 1
        assert any("Python" in r["content"] for r in result["results"])

    def test_retrieve_not_found(self):
        """An unknown hash yields success=False together with an error entry."""
        compressor = HeadroomCompressor({})
        result = compressor.retrieve(ccr_hash="a" * 64)  # Full SHA-256 length
        assert result["success"] is False
        assert "error" in result

    @pytest.mark.asyncio
    async def test_ccr_hash_in_compressed_output(self):
        """The hash embedded in compressed output retrieves the original content."""
        compressor = HeadroomCompressor({
            "min_length": 100,
            "compressors": ["smart_crusher"],
        })
        json_content = _long_json_content()
        mock_fn = MagicMock(return_value=_make_headroom_compress_mock("compressed"))

        with patch("agentkit.core.headroom_compressor._HEADROOM_AVAILABLE", True), \
                patch("agentkit.core.headroom_compressor.headroom_compress", mock_fn):
            result = await compressor.compress_tool_result("tool", json_content)

        # Extract the hash from the marker in the output.
        match = re.search(r"CCR:hash=([a-f0-9]+)", result)
        assert match is not None
        ccr_hash = match.group(1)

        # Retrieve the original via that hash.
        retrieved = compressor.retrieve(ccr_hash=ccr_hash)
        assert retrieved["success"] is True
        assert retrieved["content"] == json_content
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TestHeadroomCompressorConfig
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestHeadroomCompressorConfig:
    """Tests for configuration handling."""

    def test_default_config(self):
        """An empty config yields the default compressors, TTL, min length and model."""
        compressor = HeadroomCompressor({})
        assert compressor._compressors == ["smart_crusher", "code_compressor"]
        assert compressor._ccr_ttl == 300
        assert compressor._min_length == 500
        assert compressor._model == "default"

    def test_custom_config(self):
        """Values supplied in the config override the defaults."""
        config = {
            "compressors": ["smart_crusher"],
            "ccr_ttl": 600,
            "min_length": 1000,
            "model": "gpt-4",
        }
        compressor = HeadroomCompressor(config)
        assert compressor._compressors == ["smart_crusher"]
        assert compressor._ccr_ttl == 600
        assert compressor._min_length == 1000
        assert compressor._model == "gpt-4"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TestCCRCacheLRU (P0 fix: unbounded growth)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestCCRCacheLRU:
    """Tests for LRU eviction in the CCR cache."""

    def test_lru_evicts_oldest_when_full(self):
        """Storing beyond max_entries evicts the least recently used entry."""
        compressor = HeadroomCompressor({"max_entries": 3})
        h1 = compressor._store_ccr("content_1")
        h2 = compressor._store_ccr("content_2")
        h3 = compressor._store_ccr("content_3")
        # The fourth entry should push out the first one.
        h4 = compressor._store_ccr("content_4")
        assert h1 is not None
        assert h4 is not None
        # h1 should have been evicted.
        result = compressor.retrieve(ccr_hash=h1)
        assert result["success"] is False
        # h2, h3 and h4 should still be cached.
        assert compressor.retrieve(ccr_hash=h2)["success"] is True
        assert compressor.retrieve(ccr_hash=h3)["success"] is True
        assert compressor.retrieve(ccr_hash=h4)["success"] is True

    def test_lru_access_renews_entry(self):
        """retrieve() marks an entry as recently used so it survives eviction."""
        compressor = HeadroomCompressor({"max_entries": 3})
        h1 = compressor._store_ccr("content_1")
        h2 = compressor._store_ccr("content_2")
        h3 = compressor._store_ccr("content_3")
        # Touch h1 so that it becomes the most recently used entry.
        compressor.retrieve(ccr_hash=h1)
        # Inserting a new entry should now evict h2 (least recently used).
        h4 = compressor._store_ccr("content_4")
        assert compressor.retrieve(ccr_hash=h1)["success"] is True
        assert compressor.retrieve(ccr_hash=h2)["success"] is False
        # Only one slot had to be freed, so the other two entries remain.
        assert compressor.retrieve(ccr_hash=h3)["success"] is True
        assert compressor.retrieve(ccr_hash=h4)["success"] is True

    def test_default_max_entries(self):
        """max_entries defaults to 1000."""
        compressor = HeadroomCompressor({})
        assert compressor._max_entries == 1000

    def test_custom_max_entries(self):
        """max_entries is taken from the config when provided."""
        compressor = HeadroomCompressor({"max_entries": 50})
        assert compressor._max_entries == 50

    def test_cache_uses_ordered_dict(self):
        """The CCR cache is an OrderedDict."""
        compressor = HeadroomCompressor({})
        assert isinstance(compressor._ccr_cache, OrderedDict)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TestCCRCacheTTL (P0 fix: TTL enforcement)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestCCRCacheTTL:
    """Tests for TTL-based expiry of CCR cache entries.

    Entries are aged by rewriting their cached timestamp instead of sleeping,
    which keeps the tests fast and independent of wall-clock timing.
    """

    @staticmethod
    def _backdate(compressor, ccr_hash, seconds):
        """Move a cached entry's timestamp ``seconds`` into the past.

        NOTE(review): assumes cache values are ``(content, timestamp)``
        2-tuples keyed by the CCR hash, with the timestamp taken from
        ``time.time()`` — confirm against HeadroomCompressor.
        """
        content, stored_at = compressor._ccr_cache[ccr_hash]
        # Re-assigning an existing key keeps its position in the ordering.
        compressor._ccr_cache[ccr_hash] = (content, stored_at - seconds)

    def test_expired_entry_not_retrieved(self):
        """An entry older than ccr_ttl cannot be retrieved."""
        compressor = HeadroomCompressor({"ccr_ttl": 1})
        h = compressor._store_ccr("content")
        self._backdate(compressor, h, 2)
        result = compressor.retrieve(ccr_hash=h)
        assert result["success"] is False

    def test_fresh_entry_retrieved(self):
        """An entry younger than ccr_ttl is retrieved normally."""
        compressor = HeadroomCompressor({"ccr_ttl": 300})
        h = compressor._store_ccr("content")
        result = compressor.retrieve(ccr_hash=h)
        assert result["success"] is True
        assert result["content"] == "content"

    def test_ttl_zero_means_no_expiry(self):
        """ccr_ttl=0 disables expiry, even for an entry that is an hour old."""
        compressor = HeadroomCompressor({"ccr_ttl": 0})
        h = compressor._store_ccr("content")
        # Without ageing the entry this test could not observe expiry at all.
        self._backdate(compressor, h, 3600)
        result = compressor.retrieve(ccr_hash=h)
        assert result["success"] is True

    def test_evict_expired_on_store(self):
        """_store_ccr purges entries that have outlived ccr_ttl."""
        compressor = HeadroomCompressor({"ccr_ttl": 1, "max_entries": 100})
        h1 = compressor._store_ccr("old_content")
        self._backdate(compressor, h1, 2)
        # Storing a new entry should trigger the cleanup of expired ones.
        h2 = compressor._store_ccr("new_content")
        assert h2 is not None
        # Check the cache directly: retrieve() alone would also reject an
        # expired entry and could not prove that the store evicted it.
        assert h1 not in compressor._ccr_cache
        result = compressor.retrieve(ccr_hash=h1)
        assert result["success"] is False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# TestCCRCacheCollision (P0 fix: hash collision detection)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestCCRCacheCollision:
    """Tests for hash handling and collision detection in the CCR cache."""

    def test_full_sha256_hash_length(self):
        """_store_ccr returns a 64-character hash (full SHA-256 hex digest length)."""
        compressor = HeadroomCompressor({})
        h = compressor._store_ccr("some content")
        assert h is not None
        assert len(h) == 64  # Full SHA-256 hex digest

    def test_same_content_returns_same_hash(self):
        """Storing identical content twice returns the same hash."""
        compressor = HeadroomCompressor({})
        h1 = compressor._store_ccr("identical content")
        h2 = compressor._store_ccr("identical content")
        assert h1 == h2

    def test_collision_detected_returns_none(self):
        """_store_ccr returns None when its hash already maps to different content."""
        compressor = HeadroomCompressor({})
        # Simulate a collision: pre-populate the slot that "collision content"
        # hashes to with some other content.
        collision_hash = hashlib.sha256("collision content".encode()).hexdigest()
        compressor._ccr_cache[collision_hash] = ("different content", time.time())
        # Storing the real content for that hash must now be refused.
        result = compressor._store_ccr("collision content")
        assert result is None

    def test_no_collision_same_content_overwrite(self):
        """Re-storing the same content is not treated as a collision."""
        compressor = HeadroomCompressor({})
        h1 = compressor._store_ccr("same content")
        h2 = compressor._store_ccr("same content")
        assert h1 is not None
        assert h2 is not None
        assert h1 == h2
|