188 lines
7.1 KiB
Python
188 lines
7.1 KiB
Python
"""Tests for CompressionStrategy Protocol and create_compressor factory"""
|
||
|
||
from unittest.mock import MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor
|
||
|
||
|
||
# ── CompressionStrategy Protocol Tests ────────────────
|
||
|
||
|
||
class TestCompressionStrategyProtocol:
    """Structural-conformance tests for the CompressionStrategy protocol.

    One positive test checks ContextCompressor; each negative test defines a
    stand-in class that omits exactly one of ``compress``,
    ``compress_tool_result`` or ``is_available`` and checks that
    ``isinstance`` rejects it.
    """

    def test_context_compressor_satisfies_protocol(self):
        """ContextCompressor is recognized as a CompressionStrategy."""
        assert isinstance(ContextCompressor(), CompressionStrategy)

    def test_protocol_requires_compress_method(self):
        """A class lacking ``compress`` does not satisfy the protocol."""

        class MissingCompress:
            # Provides the other two members, deliberately omits compress().
            async def compress_tool_result(self, tool_name: str, result) -> str:
                return str(result)

            def is_available(self) -> bool:
                return True

        candidate = MissingCompress()
        assert not isinstance(candidate, CompressionStrategy)

    def test_protocol_requires_compress_tool_result_method(self):
        """A class lacking ``compress_tool_result`` does not satisfy the protocol."""

        class MissingCompressToolResult:
            # Provides the other two members, deliberately omits compress_tool_result().
            async def compress(self, messages: list[dict]) -> list[dict]:
                return messages

            def is_available(self) -> bool:
                return True

        candidate = MissingCompressToolResult()
        assert not isinstance(candidate, CompressionStrategy)

    def test_protocol_requires_is_available_method(self):
        """A class lacking ``is_available`` does not satisfy the protocol."""

        class MissingIsAvailable:
            # Provides the two async members, deliberately omits is_available().
            async def compress(self, messages: list[dict]) -> list[dict]:
                return messages

            async def compress_tool_result(self, tool_name: str, result) -> str:
                return str(result)

        candidate = MissingIsAvailable()
        assert not isinstance(candidate, CompressionStrategy)
|
||
|
||
|
||
# ── create_compressor Factory Tests ───────────────────
|
||
|
||
|
||
class TestCreateCompressor:
    """Tests for the ``create_compressor`` factory function."""

    @staticmethod
    def _unavailable_headroom_modules():
        """Return a ``sys.modules`` mapping whose HeadroomCompressor reports unavailable.

        The headroom fallback tests patch this in so their outcome no longer
        depends on whether the optional headroom backend happens to be
        installed in the environment running the tests.

        Returns:
            A dict suitable for ``patch.dict("sys.modules", ...)`` that maps the
            headroom module path to a mock whose ``HeadroomCompressor()``
            instance returns False from ``is_available()``.
        """
        unavailable = MagicMock()
        unavailable.is_available.return_value = False
        fake_module = MagicMock()
        fake_module.HeadroomCompressor.return_value = unavailable
        return {"agentkit.core.headroom_compressor": fake_module}

    def test_none_config_returns_none(self):
        """A ``None`` config yields no compressor."""
        assert create_compressor(None) is None

    def test_empty_config_returns_none(self):
        """An empty config yields no compressor."""
        assert create_compressor({}) is None

    def test_disabled_config_returns_none(self):
        """``enabled=False`` yields no compressor."""
        assert create_compressor({"enabled": False}) is None

    def test_summary_provider_returns_context_compressor(self):
        """``provider="summary"`` returns a ContextCompressor."""
        compressor = create_compressor({"enabled": True, "provider": "summary"})
        assert isinstance(compressor, ContextCompressor)

    def test_default_provider_returns_context_compressor(self):
        """Omitting ``provider`` defaults to a ContextCompressor."""
        compressor = create_compressor({"enabled": True})
        assert isinstance(compressor, ContextCompressor)

    def test_headroom_provider_falls_back_when_not_installed(self):
        """``provider="headroom"`` falls back to ContextCompressor when headroom is unavailable."""
        # Force the "unavailable" path instead of relying on the real environment.
        with patch.dict("sys.modules", self._unavailable_headroom_modules()):
            compressor = create_compressor({"enabled": True, "provider": "headroom"})
        assert isinstance(compressor, ContextCompressor)

    def test_summary_config_passed_to_context_compressor(self):
        """``max_tokens`` and ``keep_recent`` are forwarded to ContextCompressor."""
        compressor = create_compressor({
            "enabled": True,
            "provider": "summary",
            "max_tokens": 8000,
            "keep_recent": 5,
        })
        assert isinstance(compressor, ContextCompressor)
        assert compressor._max_tokens == 8000
        assert compressor._keep_recent == 5

    def test_headroom_fallback_config_passed_to_context_compressor(self):
        """When headroom falls back, the config is still forwarded to ContextCompressor."""
        # Force the "unavailable" path instead of relying on the real environment.
        with patch.dict("sys.modules", self._unavailable_headroom_modules()):
            compressor = create_compressor({
                "enabled": True,
                "provider": "headroom",
                "max_tokens": 6000,
                "keep_recent": 4,
            })
        assert isinstance(compressor, ContextCompressor)
        assert compressor._max_tokens == 6000
        assert compressor._keep_recent == 4

    def test_default_config_values(self):
        """Defaults are ``max_tokens=4000`` and ``keep_recent=3``."""
        compressor = create_compressor({"enabled": True})
        assert isinstance(compressor, ContextCompressor)
        assert compressor._max_tokens == 4000
        assert compressor._keep_recent == 3
|
||
|
||
|
||
# ── ContextCompressor New Methods Tests ───────────────
|
||
|
||
|
||
class TestContextCompressorNewMethods:
    """Tests for the methods ContextCompressor gained for the CompressionStrategy protocol."""

    # Explicit asyncio markers make these coroutine tests run under
    # pytest-asyncio's strict mode as well as auto mode; without them,
    # strict mode does not execute the async tests.
    @pytest.mark.asyncio
    async def test_compress_tool_result_default(self):
        """``compress_tool_result`` returns ``str(result)`` for a dict result."""
        compressor = ContextCompressor()
        result = await compressor.compress_tool_result("search", {"key": "value"})
        assert result == str({"key": "value"})

    @pytest.mark.asyncio
    async def test_compress_tool_result_string_input(self):
        """``compress_tool_result`` returns a string input unchanged."""
        compressor = ContextCompressor()
        result = await compressor.compress_tool_result("search", "hello world")
        assert result == "hello world"

    @pytest.mark.asyncio
    async def test_compress_tool_result_numeric_input(self):
        """``compress_tool_result`` returns the string form of a numeric input."""
        compressor = ContextCompressor()
        result = await compressor.compress_tool_result("calculator", 42)
        assert result == "42"

    def test_is_available(self):
        """A default ContextCompressor reports itself as available."""
        compressor = ContextCompressor()
        assert compressor.is_available() is True

    def test_is_available_with_gateway(self):
        """ContextCompressor is still available when constructed with an LLM gateway."""
        gateway = MagicMock()
        compressor = ContextCompressor(llm_gateway=gateway)
        assert compressor.is_available() is True
|
||
|
||
|
||
# ── Headroom Import Mock Tests ────────────────────────
|
||
|
||
|
||
class TestHeadroomImportMock:
    """Exercise the headroom branch of create_compressor with a patched module.

    The real ``agentkit.core.headroom_compressor`` module is replaced in
    ``sys.modules`` by a mock, so the factory receives a mock
    HeadroomCompressor whose availability the test controls.
    """

    @staticmethod
    def _fake_headroom(available):
        """Build a stand-in headroom module and the compressor it constructs.

        Args:
            available: Value the fake compressor's ``is_available()`` returns.

        Returns:
            A ``(modules_patch, fake_compressor)`` pair, where
            ``modules_patch`` is meant for ``patch.dict("sys.modules", ...)``.
        """
        fake_compressor = MagicMock()
        fake_compressor.is_available.return_value = available
        fake_module = MagicMock()
        fake_module.HeadroomCompressor.return_value = fake_compressor
        return {"agentkit.core.headroom_compressor": fake_module}, fake_compressor

    def test_headroom_available_returns_headroom_instance(self):
        """When HeadroomCompressor is available, the factory returns that instance."""
        modules_patch, fake_compressor = self._fake_headroom(True)
        with patch.dict("sys.modules", modules_patch):
            produced = create_compressor({"enabled": True, "provider": "headroom"})
        assert produced is fake_compressor

    def test_headroom_not_available_falls_back(self):
        """When HeadroomCompressor.is_available() is False, the factory falls back to ContextCompressor."""
        modules_patch, _ = self._fake_headroom(False)
        with patch.dict("sys.modules", modules_patch):
            produced = create_compressor({"enabled": True, "provider": "headroom"})
        assert isinstance(produced, ContextCompressor)
|