fischer-agentkit/tests/unit/test_compression_strategy.py

188 lines
7.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Tests for CompressionStrategy Protocol and create_compressor factory"""
from unittest.mock import MagicMock, patch
import pytest
from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor
# ── CompressionStrategy Protocol Tests ────────────────
class TestCompressionStrategyProtocol:
    """Structural conformance checks for the CompressionStrategy protocol."""

    def test_context_compressor_satisfies_protocol(self):
        """A default-constructed ContextCompressor is recognised as a CompressionStrategy."""
        conforms = isinstance(ContextCompressor(), CompressionStrategy)
        assert conforms

    def test_protocol_requires_compress_method(self):
        """An object without ``compress`` does not satisfy the protocol."""

        class NoCompress:
            async def compress_tool_result(self, tool_name: str, result) -> str:
                return str(result)

            def is_available(self) -> bool:
                return True

        conforms = isinstance(NoCompress(), CompressionStrategy)
        assert not conforms

    def test_protocol_requires_compress_tool_result_method(self):
        """An object without ``compress_tool_result`` does not satisfy the protocol."""

        class NoCompressToolResult:
            async def compress(self, messages: list[dict]) -> list[dict]:
                return messages

            def is_available(self) -> bool:
                return True

        conforms = isinstance(NoCompressToolResult(), CompressionStrategy)
        assert not conforms

    def test_protocol_requires_is_available_method(self):
        """An object without ``is_available`` does not satisfy the protocol."""

        class NoIsAvailable:
            async def compress(self, messages: list[dict]) -> list[dict]:
                return messages

            async def compress_tool_result(self, tool_name: str, result) -> str:
                return str(result)

        conforms = isinstance(NoIsAvailable(), CompressionStrategy)
        assert not conforms
# ── create_compressor Factory Tests ───────────────────
class TestCreateCompressor:
    """Tests for the ``create_compressor`` factory function."""

    # Module path the factory loads the Headroom backend from. Overriding this
    # sys.modules entry lets a test decide whether that backend is usable.
    _HEADROOM_MODULE = "agentkit.core.headroom_compressor"

    @staticmethod
    def _unavailable_headroom_module() -> MagicMock:
        """Return a stand-in Headroom module whose compressor reports ``is_available() == False``."""
        stub_compressor = MagicMock()
        stub_compressor.is_available.return_value = False
        stub_module = MagicMock()
        stub_module.HeadroomCompressor.return_value = stub_compressor
        return stub_module

    def test_none_config_returns_none(self):
        """A ``None`` config yields no compressor."""
        assert create_compressor(None) is None

    def test_empty_config_returns_none(self):
        """An empty config yields no compressor."""
        assert create_compressor({}) is None

    def test_disabled_config_returns_none(self):
        """``enabled=False`` yields no compressor."""
        assert create_compressor({"enabled": False}) is None

    def test_summary_provider_returns_context_compressor(self):
        """``provider="summary"`` yields a ContextCompressor."""
        compressor = create_compressor({"enabled": True, "provider": "summary"})
        assert isinstance(compressor, ContextCompressor)

    def test_default_provider_returns_context_compressor(self):
        """Omitting ``provider`` defaults to a ContextCompressor."""
        compressor = create_compressor({"enabled": True})
        assert isinstance(compressor, ContextCompressor)

    def test_headroom_provider_falls_back_when_not_installed(self):
        """``provider="headroom"`` falls back to ContextCompressor when Headroom is unusable.

        The backend module is replaced with a stub that reports itself
        unavailable. The result therefore does not depend on whether the
        optional Headroom dependency is installed where the suite runs.
        """
        fake_backend = {self._HEADROOM_MODULE: self._unavailable_headroom_module()}
        with patch.dict("sys.modules", fake_backend):
            compressor = create_compressor({"enabled": True, "provider": "headroom"})
        assert isinstance(compressor, ContextCompressor)

    def test_summary_config_passed_to_context_compressor(self):
        """``max_tokens`` and ``keep_recent`` are forwarded to ContextCompressor."""
        compressor = create_compressor({
            "enabled": True,
            "provider": "summary",
            "max_tokens": 8000,
            "keep_recent": 5,
        })
        assert isinstance(compressor, ContextCompressor)
        assert compressor._max_tokens == 8000
        assert compressor._keep_recent == 5

    def test_headroom_fallback_config_passed_to_context_compressor(self):
        """When Headroom falls back, the config is still forwarded to ContextCompressor.

        Headroom is forced unavailable via a stub module so the fallback path
        is exercised regardless of the environment.
        """
        fake_backend = {self._HEADROOM_MODULE: self._unavailable_headroom_module()}
        with patch.dict("sys.modules", fake_backend):
            compressor = create_compressor({
                "enabled": True,
                "provider": "headroom",
                "max_tokens": 6000,
                "keep_recent": 4,
            })
        assert isinstance(compressor, ContextCompressor)
        assert compressor._max_tokens == 6000
        assert compressor._keep_recent == 4

    def test_default_config_values(self):
        """Defaults are ``max_tokens=4000`` and ``keep_recent=3``."""
        compressor = create_compressor({"enabled": True})
        assert isinstance(compressor, ContextCompressor)
        assert compressor._max_tokens == 4000
        assert compressor._keep_recent == 3
# ── ContextCompressor New Methods Tests ───────────────
class TestContextCompressorNewMethods:
    """Tests for the methods ContextCompressor provides for the CompressionStrategy protocol."""

    # The coroutine tests carry an explicit asyncio marker. Without it they run
    # only when pytest-asyncio is configured with asyncio_mode="auto"; in its
    # default "strict" mode, unmarked `async def` tests are not executed.
    @pytest.mark.asyncio
    async def test_compress_tool_result_default(self):
        """``compress_tool_result`` returns ``str(result)`` for a dict result."""
        compressor = ContextCompressor()
        result = await compressor.compress_tool_result("search", {"key": "value"})
        assert result == str({"key": "value"})

    @pytest.mark.asyncio
    async def test_compress_tool_result_string_input(self):
        """``compress_tool_result`` returns a string result unchanged."""
        compressor = ContextCompressor()
        result = await compressor.compress_tool_result("search", "hello world")
        assert result == "hello world"

    @pytest.mark.asyncio
    async def test_compress_tool_result_numeric_input(self):
        """``compress_tool_result`` returns the string form of a numeric result."""
        compressor = ContextCompressor()
        result = await compressor.compress_tool_result("calculator", 42)
        assert result == "42"

    def test_is_available(self):
        """A default ContextCompressor reports itself available."""
        compressor = ContextCompressor()
        assert compressor.is_available() is True

    def test_is_available_with_gateway(self):
        """ContextCompressor is still available when constructed with an LLM gateway."""
        gateway = MagicMock()
        compressor = ContextCompressor(llm_gateway=gateway)
        assert compressor.is_available() is True
# ── Headroom Import Mock Tests ────────────────────────
class TestHeadroomImportMock:
    """Simulate a HeadroomCompressor backend that is either usable or unusable."""

    # sys.modules key the factory resolves the Headroom backend from.
    _BACKEND_MODULE = "agentkit.core.headroom_compressor"

    @staticmethod
    def _fake_backend(available):
        """Build ``(module, compressor)`` mocks whose compressor's ``is_available()`` returns *available*."""
        compressor = MagicMock()
        compressor.is_available.return_value = available
        module = MagicMock()
        module.HeadroomCompressor.return_value = compressor
        return module, compressor

    def test_headroom_available_returns_headroom_instance(self):
        """If the Headroom backend reports itself available, the factory returns that instance."""
        module, headroom = self._fake_backend(available=True)
        with patch.dict("sys.modules", {self._BACKEND_MODULE: module}):
            result = create_compressor({"enabled": True, "provider": "headroom"})
            assert result is headroom

    def test_headroom_not_available_falls_back(self):
        """If HeadroomCompressor.is_available() is False, the factory falls back to ContextCompressor."""
        module, _ = self._fake_backend(available=False)
        with patch.dict("sys.modules", {self._BACKEND_MODULE: module}):
            result = create_compressor({"enabled": True, "provider": "headroom"})
            assert isinstance(result, ContextCompressor)