257 lines
9.7 KiB
Python
257 lines
9.7 KiB
Python
"""HeadroomCompressor — 基于 headroom-ai 的上下文压缩器

在工具输出拼装到对话历史前进行智能压缩,减少 60-90% token 消耗。
使用 headroom-ai Library 模式集成,支持 SmartCrusher (JSON) 和 CodeCompressor (代码)。
CCR 可逆压缩保证原始数据不丢失。
"""
|
||
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
import re
|
||
import time
|
||
from collections import OrderedDict
|
||
from typing import Any
|
||
|
||
from agentkit.core.compressor import CompressionStrategy
|
||
|
||
logger = logging.getLogger(__name__)

# headroom-ai is an optional dependency: record whether it could be imported
# so callers can fall back to returning content unchanged.
try:
    from headroom import compress as headroom_compress
except ImportError:
    headroom_compress = None  # type: ignore[misc,assignment]
    _HEADROOM_AVAILABLE = False
else:
    _HEADROOM_AVAILABLE = True
|
||
|
||
|
||
def _is_json_content(text: str) -> bool:
    """Return True if *text*, ignoring surrounding whitespace, is a JSON object or array.

    Only payloads that start with ``{`` or ``[`` are considered; such a payload
    must then parse completely with :func:`json.loads`.
    """
    payload = text.strip()
    if not payload.startswith(("{", "[")):
        return False
    try:
        json.loads(payload)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        return False
    return True
|
||
|
||
|
||
def _is_code_content(text: str) -> bool:
    """Heuristically decide whether *text* looks like source code.

    Only the first 20 lines (split on ``"\\n"``) are inspected. A line counts as
    code if it matches any indicator pattern below. The text is classified as
    code when the number of such lines exceeds ``min(6, 0.3 * line_count)``,
    i.e. more than 30% of the first ``min(20, line_count)`` lines, where
    ``line_count`` includes a trailing empty piece after a final newline.
    """
    indicator_patterns = (
        r"^\s*(def |class |import |from |func |fn |pub |package |#include )",  # Python/Go/Rust/Java/C
        r"^\s*(function |const |let |var |export |import )",  # JS/TS
        r"```[a-z]",  # Markdown fenced code blocks
        r"^\s*(if |for |while |try |catch |switch )",  # Control flow
    )
    all_lines = text.split("\n")
    matching_lines = sum(
        1
        for line in all_lines[:20]
        if any(re.search(pattern, line, re.MULTILINE) for pattern in indicator_patterns)
    )
    return matching_lines > min(6, len(all_lines) * 0.3)
|
||
|
||
|
||
class HeadroomCompressor:
|
||
"""基于 headroom-ai 的上下文压缩器
|
||
|
||
支持 SmartCrusher (JSON) 和 CodeCompressor (代码) 两种压缩策略。
|
||
CCR 可逆压缩保证原始数据可通过 headroom_retrieve 取回。
|
||
|
||
配置项:
|
||
enabled: bool — 开关
|
||
compressors: list[str] — 启用的压缩器 ["smart_crusher", "code_compressor"]
|
||
ccr_ttl: int — CCR 缓存 TTL(秒),默认 300;0 表示永不过期
|
||
max_entries: int — CCR 缓存最大条目数,默认 1000
|
||
min_length: int — 最小压缩长度(字符),默认 500
|
||
model: str — 传给 headroom 的模型名
|
||
"""
|
||
|
||
def __init__(self, config: dict[str, Any]):
|
||
self._config = config
|
||
self._compressors = config.get("compressors", ["smart_crusher", "code_compressor"])
|
||
self._ccr_ttl = config.get("ccr_ttl", 300)
|
||
self._max_entries = config.get("max_entries", 1000)
|
||
self._min_length = config.get("min_length", 500)
|
||
self._model = config.get("model", "default")
|
||
# CCR cache: hash -> (content, insert_timestamp) with LRU ordering
|
||
self._ccr_cache: OrderedDict[str, tuple[str, float]] = OrderedDict()
|
||
|
||
def is_available(self) -> bool:
|
||
"""检查 headroom-ai 是否已安装"""
|
||
return _HEADROOM_AVAILABLE
|
||
|
||
async def compress(self, messages: list[dict]) -> list[dict]:
|
||
"""压缩消息列表中 role=tool 的消息"""
|
||
if not _HEADROOM_AVAILABLE:
|
||
return messages
|
||
|
||
compressed = []
|
||
for msg in messages:
|
||
if msg.get("role") == "tool" and len(str(msg.get("content", ""))) >= self._min_length:
|
||
try:
|
||
original_content = str(msg.get("content", ""))
|
||
# Use headroom compress on the tool message
|
||
result = headroom_compress(
|
||
[msg],
|
||
model=self._model,
|
||
)
|
||
# result.messages contains the compressed messages
|
||
if hasattr(result, "messages") and result.messages:
|
||
compressed_msg = result.messages[0]
|
||
# Store original in CCR cache
|
||
ccr_hash = self._store_ccr(original_content)
|
||
# Append CCR hash to compressed content
|
||
content = compressed_msg.get("content", original_content)
|
||
if ccr_hash:
|
||
content += f"\n<!-- CCR:hash={ccr_hash} -->"
|
||
compressed.append({**msg, "content": content})
|
||
else:
|
||
compressed.append(msg)
|
||
except Exception as e:
|
||
logger.warning(f"Headroom compression failed for tool message: {e}")
|
||
compressed.append(msg)
|
||
else:
|
||
compressed.append(msg)
|
||
|
||
return compressed
|
||
|
||
async def compress_tool_result(self, tool_name: str, result: Any) -> str:
|
||
"""压缩单个工具输出结果"""
|
||
content = str(result)
|
||
|
||
if not _HEADROOM_AVAILABLE:
|
||
return content
|
||
|
||
if len(content) < self._min_length:
|
||
return content
|
||
|
||
try:
|
||
# Route by content type
|
||
content_type = self._detect_content_type(content)
|
||
|
||
if content_type == "json" and "smart_crusher" in self._compressors:
|
||
compressed = self._compress_with_headroom(content, "smart_crusher")
|
||
elif content_type == "code" and "code_compressor" in self._compressors:
|
||
compressed = self._compress_with_headroom(content, "code_compressor")
|
||
else:
|
||
# No applicable compressor
|
||
return content
|
||
|
||
if compressed and len(compressed) < len(content):
|
||
ccr_hash = self._store_ccr(content)
|
||
if ccr_hash:
|
||
compressed += f"\n<!-- CCR:hash={ccr_hash} -->"
|
||
return compressed
|
||
|
||
return content
|
||
except Exception as e:
|
||
logger.warning(f"Tool result compression failed for '{tool_name}': {e}")
|
||
return content
|
||
|
||
def _detect_content_type(self, content: str) -> str:
|
||
"""检测内容类型"""
|
||
if _is_json_content(content):
|
||
return "json"
|
||
if _is_code_content(content):
|
||
return "code"
|
||
return "text"
|
||
|
||
def _compress_with_headroom(self, content: str, compressor: str) -> str | None:
|
||
"""使用 headroom 压缩内容"""
|
||
try:
|
||
msg = [{"role": "user", "content": content}]
|
||
result = headroom_compress(msg, model=self._model)
|
||
if hasattr(result, "messages") and result.messages:
|
||
return result.messages[0].get("content", content)
|
||
return None
|
||
except Exception as e:
|
||
logger.warning(f"Headroom {compressor} compression failed: {e}")
|
||
return None
|
||
|
||
def _store_ccr(self, original: str) -> str | None:
|
||
"""存储原始内容到 CCR 缓存,返回哈希
|
||
|
||
使用完整 SHA-256 防止碰撞。碰撞时拒绝覆盖并返回 None。
|
||
超过 max_entries 时淘汰最久未访问的条目(LRU)。
|
||
"""
|
||
ccr_hash = hashlib.sha256(original.encode()).hexdigest()
|
||
|
||
# Collision detection: if hash exists with different content, reject
|
||
if ccr_hash in self._ccr_cache:
|
||
cached_content, _ = self._ccr_cache[ccr_hash]
|
||
if cached_content != original:
|
||
logger.warning(
|
||
"CCR hash collision detected for hash=%s... "
|
||
"Rejecting overwrite to prevent data loss.",
|
||
ccr_hash[:16],
|
||
)
|
||
return None
|
||
# Same content: idempotent update (renew timestamp + LRU position)
|
||
self._ccr_cache.move_to_end(ccr_hash)
|
||
self._ccr_cache[ccr_hash] = (original, time.monotonic())
|
||
return ccr_hash
|
||
|
||
# Evict expired entries before inserting
|
||
self._evict_expired()
|
||
|
||
# LRU eviction: if at capacity, remove oldest entry
|
||
while len(self._ccr_cache) >= self._max_entries:
|
||
self._ccr_cache.popitem(last=False)
|
||
|
||
self._ccr_cache[ccr_hash] = (original, time.monotonic())
|
||
return ccr_hash
|
||
|
||
def _evict_expired(self) -> None:
|
||
"""清理过期的 CCR 缓存条目"""
|
||
if self._ccr_ttl <= 0:
|
||
return # TTL=0 means no expiry
|
||
now = time.monotonic()
|
||
expired_keys = [
|
||
k for k, (_, ts) in self._ccr_cache.items()
|
||
if now - ts > self._ccr_ttl
|
||
]
|
||
for k in expired_keys:
|
||
del self._ccr_cache[k]
|
||
|
||
def retrieve(self, ccr_hash: str | None = None, query: str | None = None) -> dict:
|
||
"""从 CCR 缓存检索原始数据"""
|
||
if ccr_hash and ccr_hash in self._ccr_cache:
|
||
content, ts = self._ccr_cache[ccr_hash]
|
||
# Check TTL
|
||
if self._ccr_ttl > 0:
|
||
if time.monotonic() - ts > self._ccr_ttl:
|
||
del self._ccr_cache[ccr_hash]
|
||
return {
|
||
"error": f"CCR hash '{ccr_hash}' expired",
|
||
"success": False,
|
||
}
|
||
# Renew LRU position on access
|
||
self._ccr_cache.move_to_end(ccr_hash)
|
||
return {
|
||
"content": content,
|
||
"ccr_hash": ccr_hash,
|
||
"success": True,
|
||
}
|
||
|
||
if query:
|
||
# Simple keyword search in cached content
|
||
results = []
|
||
for h, (content, _) in self._ccr_cache.items():
|
||
if query.lower() in content.lower():
|
||
results.append({"ccr_hash": h, "content": content[:500]})
|
||
if results:
|
||
return {"results": results, "success": True}
|
||
|
||
return {
|
||
"error": f"CCR hash '{ccr_hash}' not found in cache",
|
||
"success": False,
|
||
}
|