feat(compression): U1 CompressionStrategy Protocol and create_compressor factory

Add runtime-checkable CompressionStrategy Protocol with compress(),
compress_tool_result(), and is_available() methods. Add compress_tool_result
and is_available to existing ContextCompressor. Add create_compressor()
factory function with headroom/summary provider routing and ImportError
fallback.
This commit is contained in:
chiguyong 2026-06-07 18:19:27 +08:00
parent 80a505b1c1
commit 5d3a5f2bf3
2 changed files with 260 additions and 1 deletions

View File

@ -7,11 +7,28 @@
import hashlib import hashlib
import json import json
import logging import logging
from typing import Any from typing import Any, Protocol, runtime_checkable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@runtime_checkable
class CompressionStrategy(Protocol):
"""压缩策略协议 — 所有压缩器必须实现此接口"""
async def compress(self, messages: list[dict]) -> list[dict]:
"""压缩消息列表"""
...
async def compress_tool_result(self, tool_name: str, result: Any) -> str:
"""压缩单个工具输出结果,返回压缩后的字符串"""
...
def is_available(self) -> bool:
"""检查压缩器是否可用"""
...
class ContextCompressor: class ContextCompressor:
"""Compress long conversation histories to stay within token budgets""" """Compress long conversation histories to stay within token budgets"""
@ -156,6 +173,61 @@ class ContextCompressor:
result.append(msg) result.append(msg)
return result return result
async def compress_tool_result(self, tool_name: str, result: Any) -> str:
"""默认实现:不做压缩,直接返回字符串表示"""
return str(result)
def is_available(self) -> bool:
"""ContextCompressor 始终可用"""
return True
def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrategy | None:
"""根据配置创建压缩器实例
Args:
config: 压缩配置字典支持以下字段
- enabled: bool, 是否启用压缩默认 False
- provider: "headroom" | "summary", 压缩提供者
- max_tokens: int, token 预算summary 模式
- keep_recent: int, 保留最近 N 条消息summary 模式
- 其他 provider 特定配置
Returns:
CompressionStrategy 实例 None未启用时
"""
if not config or not config.get("enabled", False):
return None
provider = config.get("provider", "summary")
if provider == "headroom":
try:
from agentkit.core.headroom_compressor import HeadroomCompressor
compressor = HeadroomCompressor(config)
if compressor.is_available():
return compressor
logger.warning(
"HeadroomCompressor not available (headroom-ai not installed?). "
"Falling back to ContextCompressor."
)
except ImportError:
logger.warning(
"HeadroomCompressor module not available. "
"Falling back to ContextCompressor."
)
# Fallback to summary compressor
return ContextCompressor(
max_tokens=config.get("max_tokens", 4000),
keep_recent=config.get("keep_recent", 3),
)
# Default: summary-based compression
return ContextCompressor(
max_tokens=config.get("max_tokens", 4000),
keep_recent=config.get("keep_recent", 3),
)
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]: def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
"""Render PromptTemplate with caching - returns cached result for same variables""" """Render PromptTemplate with caching - returns cached result for same variables"""

View File

@ -0,0 +1,187 @@
"""Tests for CompressionStrategy Protocol and create_compressor factory"""
from unittest.mock import MagicMock, patch
import pytest
from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor
# ── CompressionStrategy Protocol Tests ────────────────
class TestCompressionStrategyProtocol:
"""CompressionStrategy 协议满足性测试"""
def test_context_compressor_satisfies_protocol(self):
"""ContextCompressor 实现了 CompressionStrategy 协议"""
compressor = ContextCompressor()
assert isinstance(compressor, CompressionStrategy)
def test_protocol_requires_compress_method(self):
"""协议要求 compress 方法"""
class MissingCompress:
async def compress_tool_result(self, tool_name: str, result) -> str:
return str(result)
def is_available(self) -> bool:
return True
assert not isinstance(MissingCompress(), CompressionStrategy)
def test_protocol_requires_compress_tool_result_method(self):
"""协议要求 compress_tool_result 方法"""
class MissingCompressToolResult:
async def compress(self, messages: list[dict]) -> list[dict]:
return messages
def is_available(self) -> bool:
return True
assert not isinstance(MissingCompressToolResult(), CompressionStrategy)
def test_protocol_requires_is_available_method(self):
"""协议要求 is_available 方法"""
class MissingIsAvailable:
async def compress(self, messages: list[dict]) -> list[dict]:
return messages
async def compress_tool_result(self, tool_name: str, result) -> str:
return str(result)
assert not isinstance(MissingIsAvailable(), CompressionStrategy)
# ── create_compressor Factory Tests ───────────────────
class TestCreateCompressor:
"""create_compressor 工厂函数测试"""
def test_none_config_returns_none(self):
"""config 为 None 时返回 None"""
assert create_compressor(None) is None
def test_empty_config_returns_none(self):
"""空 config 时返回 None"""
assert create_compressor({}) is None
def test_disabled_config_returns_none(self):
"""enabled=False 时返回 None"""
assert create_compressor({"enabled": False}) is None
def test_summary_provider_returns_context_compressor(self):
"""provider=summary 返回 ContextCompressor"""
compressor = create_compressor({"enabled": True, "provider": "summary"})
assert isinstance(compressor, ContextCompressor)
def test_default_provider_returns_context_compressor(self):
"""不指定 provider 默认返回 ContextCompressor"""
compressor = create_compressor({"enabled": True})
assert isinstance(compressor, ContextCompressor)
def test_headroom_provider_falls_back_when_not_installed(self):
"""provider=headroom 但未安装时回退到 ContextCompressor"""
compressor = create_compressor({"enabled": True, "provider": "headroom"})
assert isinstance(compressor, ContextCompressor)
def test_summary_config_passed_to_context_compressor(self):
"""max_tokens 和 keep_recent 传递给 ContextCompressor"""
compressor = create_compressor({
"enabled": True,
"provider": "summary",
"max_tokens": 8000,
"keep_recent": 5,
})
assert isinstance(compressor, ContextCompressor)
assert compressor._max_tokens == 8000
assert compressor._keep_recent == 5
def test_headroom_fallback_config_passed_to_context_compressor(self):
"""headroom 回退时配置也传递给 ContextCompressor"""
compressor = create_compressor({
"enabled": True,
"provider": "headroom",
"max_tokens": 6000,
"keep_recent": 4,
})
assert isinstance(compressor, ContextCompressor)
assert compressor._max_tokens == 6000
assert compressor._keep_recent == 4
def test_default_config_values(self):
"""默认 max_tokens=4000, keep_recent=3"""
compressor = create_compressor({"enabled": True})
assert isinstance(compressor, ContextCompressor)
assert compressor._max_tokens == 4000
assert compressor._keep_recent == 3
# ── ContextCompressor New Methods Tests ───────────────
class TestContextCompressorNewMethods:
"""ContextCompressor 新增方法测试"""
async def test_compress_tool_result_default(self):
"""compress_tool_result 默认返回 str(result)"""
compressor = ContextCompressor()
result = await compressor.compress_tool_result("search", {"key": "value"})
assert result == str({"key": "value"})
async def test_compress_tool_result_string_input(self):
"""compress_tool_result 对字符串输入直接返回"""
compressor = ContextCompressor()
result = await compressor.compress_tool_result("search", "hello world")
assert result == "hello world"
async def test_compress_tool_result_numeric_input(self):
"""compress_tool_result 对数字输入返回字符串表示"""
compressor = ContextCompressor()
result = await compressor.compress_tool_result("calculator", 42)
assert result == "42"
def test_is_available(self):
"""ContextCompressor 始终可用"""
compressor = ContextCompressor()
assert compressor.is_available() is True
def test_is_available_with_gateway(self):
"""即使有 LLMGatewayContextCompressor 也可用"""
gateway = MagicMock()
compressor = ContextCompressor(llm_gateway=gateway)
assert compressor.is_available() is True
# ── Headroom Import Mock Tests ────────────────────────
class TestHeadroomImportMock:
"""模拟 HeadroomCompressor 导入成功/失败的场景"""
def test_headroom_available_returns_headroom_instance(self):
"""HeadroomCompressor 可用时返回其实例"""
mock_compressor = MagicMock()
mock_compressor.is_available.return_value = True
mock_module = MagicMock()
mock_module.HeadroomCompressor.return_value = mock_compressor
with patch.dict("sys.modules", {"agentkit.core.headroom_compressor": mock_module}):
compressor = create_compressor({"enabled": True, "provider": "headroom"})
assert compressor is mock_compressor
def test_headroom_not_available_falls_back(self):
"""HeadroomCompressor is_available()=False 时回退到 ContextCompressor"""
mock_compressor = MagicMock()
mock_compressor.is_available.return_value = False
mock_module = MagicMock()
mock_module.HeadroomCompressor.return_value = mock_compressor
with patch.dict("sys.modules", {"agentkit.core.headroom_compressor": mock_module}):
compressor = create_compressor({"enabled": True, "provider": "headroom"})
assert isinstance(compressor, ContextCompressor)