diff --git a/src/agentkit/core/compressor.py b/src/agentkit/core/compressor.py index 0c8fc28..b7818da 100644 --- a/src/agentkit/core/compressor.py +++ b/src/agentkit/core/compressor.py @@ -7,11 +7,28 @@ import hashlib import json import logging -from typing import Any +from typing import Any, Protocol, runtime_checkable logger = logging.getLogger(__name__) +@runtime_checkable +class CompressionStrategy(Protocol): + """压缩策略协议 — 所有压缩器必须实现此接口""" + + async def compress(self, messages: list[dict]) -> list[dict]: + """压缩消息列表""" + ... + + async def compress_tool_result(self, tool_name: str, result: Any) -> str: + """压缩单个工具输出结果,返回压缩后的字符串""" + ... + + def is_available(self) -> bool: + """检查压缩器是否可用""" + ... + + class ContextCompressor: """Compress long conversation histories to stay within token budgets""" @@ -156,6 +173,61 @@ class ContextCompressor: result.append(msg) return result + async def compress_tool_result(self, tool_name: str, result: Any) -> str: + """默认实现:不做压缩,直接返回字符串表示""" + return str(result) + + def is_available(self) -> bool: + """ContextCompressor 始终可用""" + return True + + +def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrategy | None: + """根据配置创建压缩器实例 + + Args: + config: 压缩配置字典,支持以下字段: + - enabled: bool, 是否启用压缩(默认 False) + - provider: "headroom" | "summary", 压缩提供者 + - max_tokens: int, token 预算(summary 模式) + - keep_recent: int, 保留最近 N 条消息(summary 模式) + - 其他 provider 特定配置 + + Returns: + CompressionStrategy 实例,或 None(未启用时) + """ + if not config or not config.get("enabled", False): + return None + + provider = config.get("provider", "summary") + + if provider == "headroom": + try: + from agentkit.core.headroom_compressor import HeadroomCompressor + compressor = HeadroomCompressor(config) + if compressor.is_available(): + return compressor + logger.warning( + "HeadroomCompressor not available (headroom-ai not installed?). " + "Falling back to ContextCompressor." + ) + except ImportError: + logger.warning( + "HeadroomCompressor module not available. " + "Falling back to ContextCompressor." + ) + # Fallback to summary compressor + return ContextCompressor( + max_tokens=config.get("max_tokens", 4000), + keep_recent=config.get("keep_recent", 3), + ) + + # Default: summary-based compression + return ContextCompressor( + max_tokens=config.get("max_tokens", 4000), + keep_recent=config.get("keep_recent", 3), + ) + def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]: """Render PromptTemplate with caching - returns cached result for same variables""" diff --git a/tests/unit/test_compression_strategy.py b/tests/unit/test_compression_strategy.py new file mode 100644 index 0000000..58f212d --- /dev/null +++ b/tests/unit/test_compression_strategy.py @@ -0,0 +1,187 @@ +"""Tests for CompressionStrategy Protocol and create_compressor factory""" + +from unittest.mock import MagicMock, patch + +import pytest + +from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor + + +# ── CompressionStrategy Protocol Tests ──────────────── + + +class TestCompressionStrategyProtocol: + """CompressionStrategy 协议满足性测试""" + + def test_context_compressor_satisfies_protocol(self): + """ContextCompressor 实现了 CompressionStrategy 协议""" + compressor = ContextCompressor() + assert isinstance(compressor, CompressionStrategy) + + def test_protocol_requires_compress_method(self): + """协议要求 compress 方法""" + + class MissingCompress: + async def compress_tool_result(self, tool_name: str, result) -> str: + return str(result) + + def is_available(self) -> bool: + return True + + assert not isinstance(MissingCompress(), CompressionStrategy) + + def test_protocol_requires_compress_tool_result_method(self): + """协议要求 compress_tool_result 方法""" + + class MissingCompressToolResult: + async def compress(self, messages: list[dict]) -> list[dict]: + return messages + + def is_available(self) -> bool: + return True + + assert not isinstance(MissingCompressToolResult(), CompressionStrategy) + + def test_protocol_requires_is_available_method(self): + """协议要求 is_available 方法""" + + class MissingIsAvailable: + async def compress(self, messages: list[dict]) -> list[dict]: + return messages + + async def compress_tool_result(self, tool_name: str, result) -> str: + return str(result) + + assert not isinstance(MissingIsAvailable(), CompressionStrategy) + + +# ── create_compressor Factory Tests ─────────────────── + + +class TestCreateCompressor: + """create_compressor 工厂函数测试""" + + def test_none_config_returns_none(self): + """config 为 None 时返回 None""" + assert create_compressor(None) is None + + def test_empty_config_returns_none(self): + """空 config 时返回 None""" + assert create_compressor({}) is None + + def test_disabled_config_returns_none(self): + """enabled=False 时返回 None""" + assert create_compressor({"enabled": False}) is None + + def test_summary_provider_returns_context_compressor(self): + """provider=summary 返回 ContextCompressor""" + compressor = create_compressor({"enabled": True, "provider": "summary"}) + assert isinstance(compressor, ContextCompressor) + + def test_default_provider_returns_context_compressor(self): + """不指定 provider 默认返回 ContextCompressor""" + compressor = create_compressor({"enabled": True}) + assert isinstance(compressor, ContextCompressor) + + def test_headroom_provider_falls_back_when_not_installed(self): + """provider=headroom 但未安装时回退到 ContextCompressor""" + compressor = create_compressor({"enabled": True, "provider": "headroom"}) + assert isinstance(compressor, ContextCompressor) + + def test_summary_config_passed_to_context_compressor(self): + """max_tokens 和 keep_recent 传递给 ContextCompressor""" + compressor = create_compressor({ + "enabled": True, + "provider": "summary", + "max_tokens": 8000, + "keep_recent": 5, + }) + assert isinstance(compressor, ContextCompressor) + assert compressor._max_tokens == 8000 + assert compressor._keep_recent == 5 + + def test_headroom_fallback_config_passed_to_context_compressor(self): + """headroom 回退时配置也传递给 ContextCompressor""" + compressor = create_compressor({ + "enabled": True, + "provider": "headroom", + "max_tokens": 6000, + "keep_recent": 4, + }) + assert isinstance(compressor, ContextCompressor) + assert compressor._max_tokens == 6000 + assert compressor._keep_recent == 4 + + def test_default_config_values(self): + """默认 max_tokens=4000, keep_recent=3""" + compressor = create_compressor({"enabled": True}) + assert isinstance(compressor, ContextCompressor) + assert compressor._max_tokens == 4000 + assert compressor._keep_recent == 3 + + +# ── ContextCompressor New Methods Tests ─────────────── + + +class TestContextCompressorNewMethods: + """ContextCompressor 新增方法测试""" + + async def test_compress_tool_result_default(self): + """compress_tool_result 默认返回 str(result)""" + compressor = ContextCompressor() + result = await compressor.compress_tool_result("search", {"key": "value"}) + assert result == str({"key": "value"}) + + async def test_compress_tool_result_string_input(self): + """compress_tool_result 对字符串输入直接返回""" + compressor = ContextCompressor() + result = await compressor.compress_tool_result("search", "hello world") + assert result == "hello world" + + async def test_compress_tool_result_numeric_input(self): + """compress_tool_result 对数字输入返回字符串表示""" + compressor = ContextCompressor() + result = await compressor.compress_tool_result("calculator", 42) + assert result == "42" + + def test_is_available(self): + """ContextCompressor 始终可用""" + compressor = ContextCompressor() + assert compressor.is_available() is True + + def test_is_available_with_gateway(self): + """即使有 LLMGateway,ContextCompressor 也可用""" + gateway = MagicMock() + compressor = ContextCompressor(llm_gateway=gateway) + assert compressor.is_available() is True + + +# ── Headroom Import Mock Tests ──────────────────────── + + +class TestHeadroomImportMock: + """模拟 HeadroomCompressor 导入成功/失败的场景""" + + def test_headroom_available_returns_headroom_instance(self): + """HeadroomCompressor 可用时返回其实例""" + mock_compressor = MagicMock() + mock_compressor.is_available.return_value = True + + mock_module = MagicMock() + mock_module.HeadroomCompressor.return_value = mock_compressor + + with patch.dict("sys.modules", {"agentkit.core.headroom_compressor": mock_module}): + compressor = create_compressor({"enabled": True, "provider": "headroom"}) + assert compressor is mock_compressor + + def test_headroom_not_available_falls_back(self): + """HeadroomCompressor is_available()=False 时回退到 ContextCompressor""" + mock_compressor = MagicMock() + mock_compressor.is_available.return_value = False + + mock_module = MagicMock() + mock_module.HeadroomCompressor.return_value = mock_compressor + + with patch.dict("sys.modules", {"agentkit.core.headroom_compressor": mock_module}): + compressor = create_compressor({"enabled": True, "provider": "headroom"}) + assert isinstance(compressor, ContextCompressor)