feat(compression): U1 CompressionStrategy Protocol and create_compressor factory
Add runtime-checkable CompressionStrategy Protocol with compress(), compress_tool_result(), and is_available() methods. Add compress_tool_result and is_available to existing ContextCompressor. Add create_compressor() factory function with headroom/summary provider routing and ImportError fallback.
This commit is contained in:
parent
80a505b1c1
commit
5d3a5f2bf3
|
|
@ -7,11 +7,28 @@
|
|||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CompressionStrategy(Protocol):
|
||||
"""压缩策略协议 — 所有压缩器必须实现此接口"""
|
||||
|
||||
async def compress(self, messages: list[dict]) -> list[dict]:
|
||||
"""压缩消息列表"""
|
||||
...
|
||||
|
||||
async def compress_tool_result(self, tool_name: str, result: Any) -> str:
|
||||
"""压缩单个工具输出结果,返回压缩后的字符串"""
|
||||
...
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查压缩器是否可用"""
|
||||
...
|
||||
|
||||
|
||||
class ContextCompressor:
|
||||
"""Compress long conversation histories to stay within token budgets"""
|
||||
|
||||
|
|
@ -156,6 +173,61 @@ class ContextCompressor:
|
|||
result.append(msg)
|
||||
return result
|
||||
|
||||
async def compress_tool_result(self, tool_name: str, result: Any) -> str:
|
||||
"""默认实现:不做压缩,直接返回字符串表示"""
|
||||
return str(result)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""ContextCompressor 始终可用"""
|
||||
return True
|
||||
|
||||
|
||||
def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrategy | None:
|
||||
"""根据配置创建压缩器实例
|
||||
|
||||
Args:
|
||||
config: 压缩配置字典,支持以下字段:
|
||||
- enabled: bool, 是否启用压缩(默认 False)
|
||||
- provider: "headroom" | "summary", 压缩提供者
|
||||
- max_tokens: int, token 预算(summary 模式)
|
||||
- keep_recent: int, 保留最近 N 条消息(summary 模式)
|
||||
- 其他 provider 特定配置
|
||||
|
||||
Returns:
|
||||
CompressionStrategy 实例,或 None(未启用时)
|
||||
"""
|
||||
if not config or not config.get("enabled", False):
|
||||
return None
|
||||
|
||||
provider = config.get("provider", "summary")
|
||||
|
||||
if provider == "headroom":
|
||||
try:
|
||||
from agentkit.core.headroom_compressor import HeadroomCompressor
|
||||
compressor = HeadroomCompressor(config)
|
||||
if compressor.is_available():
|
||||
return compressor
|
||||
logger.warning(
|
||||
"HeadroomCompressor not available (headroom-ai not installed?). "
|
||||
"Falling back to ContextCompressor."
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"HeadroomCompressor module not available. "
|
||||
"Falling back to ContextCompressor."
|
||||
)
|
||||
# Fallback to summary compressor
|
||||
return ContextCompressor(
|
||||
max_tokens=config.get("max_tokens", 4000),
|
||||
keep_recent=config.get("keep_recent", 3),
|
||||
)
|
||||
|
||||
# Default: summary-based compression
|
||||
return ContextCompressor(
|
||||
max_tokens=config.get("max_tokens", 4000),
|
||||
keep_recent=config.get("keep_recent", 3),
|
||||
)
|
||||
|
||||
|
||||
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
|
||||
"""Render PromptTemplate with caching - returns cached result for same variables"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for CompressionStrategy Protocol and create_compressor factory"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor
|
||||
|
||||
|
||||
# ── CompressionStrategy Protocol Tests ────────────────
|
||||
|
||||
|
||||
class TestCompressionStrategyProtocol:
|
||||
"""CompressionStrategy 协议满足性测试"""
|
||||
|
||||
def test_context_compressor_satisfies_protocol(self):
|
||||
"""ContextCompressor 实现了 CompressionStrategy 协议"""
|
||||
compressor = ContextCompressor()
|
||||
assert isinstance(compressor, CompressionStrategy)
|
||||
|
||||
def test_protocol_requires_compress_method(self):
|
||||
"""协议要求 compress 方法"""
|
||||
|
||||
class MissingCompress:
|
||||
async def compress_tool_result(self, tool_name: str, result) -> str:
|
||||
return str(result)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return True
|
||||
|
||||
assert not isinstance(MissingCompress(), CompressionStrategy)
|
||||
|
||||
def test_protocol_requires_compress_tool_result_method(self):
|
||||
"""协议要求 compress_tool_result 方法"""
|
||||
|
||||
class MissingCompressToolResult:
|
||||
async def compress(self, messages: list[dict]) -> list[dict]:
|
||||
return messages
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return True
|
||||
|
||||
assert not isinstance(MissingCompressToolResult(), CompressionStrategy)
|
||||
|
||||
def test_protocol_requires_is_available_method(self):
|
||||
"""协议要求 is_available 方法"""
|
||||
|
||||
class MissingIsAvailable:
|
||||
async def compress(self, messages: list[dict]) -> list[dict]:
|
||||
return messages
|
||||
|
||||
async def compress_tool_result(self, tool_name: str, result) -> str:
|
||||
return str(result)
|
||||
|
||||
assert not isinstance(MissingIsAvailable(), CompressionStrategy)
|
||||
|
||||
|
||||
# ── create_compressor Factory Tests ───────────────────
|
||||
|
||||
|
||||
class TestCreateCompressor:
|
||||
"""create_compressor 工厂函数测试"""
|
||||
|
||||
def test_none_config_returns_none(self):
|
||||
"""config 为 None 时返回 None"""
|
||||
assert create_compressor(None) is None
|
||||
|
||||
def test_empty_config_returns_none(self):
|
||||
"""空 config 时返回 None"""
|
||||
assert create_compressor({}) is None
|
||||
|
||||
def test_disabled_config_returns_none(self):
|
||||
"""enabled=False 时返回 None"""
|
||||
assert create_compressor({"enabled": False}) is None
|
||||
|
||||
def test_summary_provider_returns_context_compressor(self):
|
||||
"""provider=summary 返回 ContextCompressor"""
|
||||
compressor = create_compressor({"enabled": True, "provider": "summary"})
|
||||
assert isinstance(compressor, ContextCompressor)
|
||||
|
||||
def test_default_provider_returns_context_compressor(self):
|
||||
"""不指定 provider 默认返回 ContextCompressor"""
|
||||
compressor = create_compressor({"enabled": True})
|
||||
assert isinstance(compressor, ContextCompressor)
|
||||
|
||||
def test_headroom_provider_falls_back_when_not_installed(self):
|
||||
"""provider=headroom 但未安装时回退到 ContextCompressor"""
|
||||
compressor = create_compressor({"enabled": True, "provider": "headroom"})
|
||||
assert isinstance(compressor, ContextCompressor)
|
||||
|
||||
def test_summary_config_passed_to_context_compressor(self):
|
||||
"""max_tokens 和 keep_recent 传递给 ContextCompressor"""
|
||||
compressor = create_compressor({
|
||||
"enabled": True,
|
||||
"provider": "summary",
|
||||
"max_tokens": 8000,
|
||||
"keep_recent": 5,
|
||||
})
|
||||
assert isinstance(compressor, ContextCompressor)
|
||||
assert compressor._max_tokens == 8000
|
||||
assert compressor._keep_recent == 5
|
||||
|
||||
def test_headroom_fallback_config_passed_to_context_compressor(self):
|
||||
"""headroom 回退时配置也传递给 ContextCompressor"""
|
||||
compressor = create_compressor({
|
||||
"enabled": True,
|
||||
"provider": "headroom",
|
||||
"max_tokens": 6000,
|
||||
"keep_recent": 4,
|
||||
})
|
||||
assert isinstance(compressor, ContextCompressor)
|
||||
assert compressor._max_tokens == 6000
|
||||
assert compressor._keep_recent == 4
|
||||
|
||||
def test_default_config_values(self):
|
||||
"""默认 max_tokens=4000, keep_recent=3"""
|
||||
compressor = create_compressor({"enabled": True})
|
||||
assert isinstance(compressor, ContextCompressor)
|
||||
assert compressor._max_tokens == 4000
|
||||
assert compressor._keep_recent == 3
|
||||
|
||||
|
||||
# ── ContextCompressor New Methods Tests ───────────────
|
||||
|
||||
|
||||
class TestContextCompressorNewMethods:
|
||||
"""ContextCompressor 新增方法测试"""
|
||||
|
||||
async def test_compress_tool_result_default(self):
|
||||
"""compress_tool_result 默认返回 str(result)"""
|
||||
compressor = ContextCompressor()
|
||||
result = await compressor.compress_tool_result("search", {"key": "value"})
|
||||
assert result == str({"key": "value"})
|
||||
|
||||
async def test_compress_tool_result_string_input(self):
|
||||
"""compress_tool_result 对字符串输入直接返回"""
|
||||
compressor = ContextCompressor()
|
||||
result = await compressor.compress_tool_result("search", "hello world")
|
||||
assert result == "hello world"
|
||||
|
||||
async def test_compress_tool_result_numeric_input(self):
|
||||
"""compress_tool_result 对数字输入返回字符串表示"""
|
||||
compressor = ContextCompressor()
|
||||
result = await compressor.compress_tool_result("calculator", 42)
|
||||
assert result == "42"
|
||||
|
||||
def test_is_available(self):
|
||||
"""ContextCompressor 始终可用"""
|
||||
compressor = ContextCompressor()
|
||||
assert compressor.is_available() is True
|
||||
|
||||
def test_is_available_with_gateway(self):
|
||||
"""即使有 LLMGateway,ContextCompressor 也可用"""
|
||||
gateway = MagicMock()
|
||||
compressor = ContextCompressor(llm_gateway=gateway)
|
||||
assert compressor.is_available() is True
|
||||
|
||||
|
||||
# ── Headroom Import Mock Tests ────────────────────────
|
||||
|
||||
|
||||
class TestHeadroomImportMock:
|
||||
"""模拟 HeadroomCompressor 导入成功/失败的场景"""
|
||||
|
||||
def test_headroom_available_returns_headroom_instance(self):
|
||||
"""HeadroomCompressor 可用时返回其实例"""
|
||||
mock_compressor = MagicMock()
|
||||
mock_compressor.is_available.return_value = True
|
||||
|
||||
mock_module = MagicMock()
|
||||
mock_module.HeadroomCompressor.return_value = mock_compressor
|
||||
|
||||
with patch.dict("sys.modules", {"agentkit.core.headroom_compressor": mock_module}):
|
||||
compressor = create_compressor({"enabled": True, "provider": "headroom"})
|
||||
assert compressor is mock_compressor
|
||||
|
||||
def test_headroom_not_available_falls_back(self):
|
||||
"""HeadroomCompressor is_available()=False 时回退到 ContextCompressor"""
|
||||
mock_compressor = MagicMock()
|
||||
mock_compressor.is_available.return_value = False
|
||||
|
||||
mock_module = MagicMock()
|
||||
mock_module.HeadroomCompressor.return_value = mock_compressor
|
||||
|
||||
with patch.dict("sys.modules", {"agentkit.core.headroom_compressor": mock_module}):
|
||||
compressor = create_compressor({"enabled": True, "provider": "headroom"})
|
||||
assert isinstance(compressor, ContextCompressor)
|
||||
Loading…
Reference in New Issue