fischer-agentkit/src/agentkit/core/headroom_compressor.py

257 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""HeadroomCompressor — 基于 headroom-ai 的上下文压缩器
在工具输出拼装到对话历史前进行智能压缩,减少 60-90% token 消耗。
使用 headroom-ai Library 模式集成,支持 SmartCrusher (JSON) 和 CodeCompressor (代码)。
CCR 可逆压缩保证原始数据不丢失。
"""
import hashlib
import json
import logging
import re
import time
from collections import OrderedDict
from typing import Any
from agentkit.core.compressor import CompressionStrategy
logger = logging.getLogger(__name__)
# Optional dependency detection: headroom-ai may not be installed.
# Code in this module checks _HEADROOM_AVAILABLE before calling
# headroom_compress, so a missing package never raises at import time.
try:
    from headroom import compress as headroom_compress
except ImportError:
    # Keep the name bound so references to it do not raise NameError.
    headroom_compress = None  # type: ignore[misc,assignment]
    _HEADROOM_AVAILABLE = False
else:
    _HEADROOM_AVAILABLE = True
def _is_json_content(text: str) -> bool:
    """Return True when *text* parses as a JSON object or array.

    Surrounding whitespace is ignored. Text that does not start with ``{``
    or ``[`` is rejected without attempting to parse it, so bare JSON
    scalars such as ``42`` or ``"abc"`` are not treated as JSON content.
    """
    candidate = text.strip()
    if not candidate.startswith(("{", "[")):
        return False
    try:
        json.loads(candidate)
    except (json.JSONDecodeError, ValueError):
        return False
    return True
def _is_code_content(text: str) -> bool:
    """Heuristically decide whether *text* looks like source code.

    Only the first 20 lines are inspected. A line counts as code when it
    matches at least one indicator pattern, and the text is classified as
    code when that count exceeds ``min(6, 0.3 * total_line_count)``.
    """
    indicator_patterns = (
        # Python / Go / Rust / Java / C declarations
        r"^\s*(def |class |import |from |func |fn |pub |package |#include )",
        # JavaScript / TypeScript declarations
        r"^\s*(function |const |let |var |export |import )",
        # Opening fence of a Markdown code block with a language tag
        r"```[a-z]",
        # Control-flow keywords
        r"^\s*(if |for |while |try |catch |switch )",
    )
    all_lines = text.split("\n")
    # Each line is counted at most once, however many patterns it matches.
    hits = sum(
        1
        for candidate in all_lines[:20]
        if any(
            re.search(pattern, candidate, re.MULTILINE)
            for pattern in indicator_patterns
        )
    )
    return hits > min(6, len(all_lines) * 0.3)
class HeadroomCompressor:
"""基于 headroom-ai 的上下文压缩器
支持 SmartCrusher (JSON) 和 CodeCompressor (代码) 两种压缩策略。
CCR 可逆压缩保证原始数据可通过 headroom_retrieve 取回。
配置项:
enabled: bool — 开关
compressors: list[str] — 启用的压缩器 ["smart_crusher", "code_compressor"]
ccr_ttl: int — CCR 缓存 TTL默认 3000 表示永不过期
max_entries: int — CCR 缓存最大条目数,默认 1000
min_length: int — 最小压缩长度(字符),默认 500
model: str — 传给 headroom 的模型名
"""
def __init__(self, config: dict[str, Any]):
self._config = config
self._compressors = config.get("compressors", ["smart_crusher", "code_compressor"])
self._ccr_ttl = config.get("ccr_ttl", 300)
self._max_entries = config.get("max_entries", 1000)
self._min_length = config.get("min_length", 500)
self._model = config.get("model", "default")
# CCR cache: hash -> (content, insert_timestamp) with LRU ordering
self._ccr_cache: OrderedDict[str, tuple[str, float]] = OrderedDict()
def is_available(self) -> bool:
"""检查 headroom-ai 是否已安装"""
return _HEADROOM_AVAILABLE
async def compress(self, messages: list[dict]) -> list[dict]:
"""压缩消息列表中 role=tool 的消息"""
if not _HEADROOM_AVAILABLE:
return messages
compressed = []
for msg in messages:
if msg.get("role") == "tool" and len(str(msg.get("content", ""))) >= self._min_length:
try:
original_content = str(msg.get("content", ""))
# Use headroom compress on the tool message
result = headroom_compress(
[msg],
model=self._model,
)
# result.messages contains the compressed messages
if hasattr(result, "messages") and result.messages:
compressed_msg = result.messages[0]
# Store original in CCR cache
ccr_hash = self._store_ccr(original_content)
# Append CCR hash to compressed content
content = compressed_msg.get("content", original_content)
if ccr_hash:
content += f"\n<!-- CCR:hash={ccr_hash} -->"
compressed.append({**msg, "content": content})
else:
compressed.append(msg)
except Exception as e:
logger.warning(f"Headroom compression failed for tool message: {e}")
compressed.append(msg)
else:
compressed.append(msg)
return compressed
async def compress_tool_result(self, tool_name: str, result: Any) -> str:
"""压缩单个工具输出结果"""
content = str(result)
if not _HEADROOM_AVAILABLE:
return content
if len(content) < self._min_length:
return content
try:
# Route by content type
content_type = self._detect_content_type(content)
if content_type == "json" and "smart_crusher" in self._compressors:
compressed = self._compress_with_headroom(content, "smart_crusher")
elif content_type == "code" and "code_compressor" in self._compressors:
compressed = self._compress_with_headroom(content, "code_compressor")
else:
# No applicable compressor
return content
if compressed and len(compressed) < len(content):
ccr_hash = self._store_ccr(content)
if ccr_hash:
compressed += f"\n<!-- CCR:hash={ccr_hash} -->"
return compressed
return content
except Exception as e:
logger.warning(f"Tool result compression failed for '{tool_name}': {e}")
return content
def _detect_content_type(self, content: str) -> str:
"""检测内容类型"""
if _is_json_content(content):
return "json"
if _is_code_content(content):
return "code"
return "text"
def _compress_with_headroom(self, content: str, compressor: str) -> str | None:
"""使用 headroom 压缩内容"""
try:
msg = [{"role": "user", "content": content}]
result = headroom_compress(msg, model=self._model)
if hasattr(result, "messages") and result.messages:
return result.messages[0].get("content", content)
return None
except Exception as e:
logger.warning(f"Headroom {compressor} compression failed: {e}")
return None
def _store_ccr(self, original: str) -> str | None:
"""存储原始内容到 CCR 缓存,返回哈希
使用完整 SHA-256 防止碰撞。碰撞时拒绝覆盖并返回 None。
超过 max_entries 时淘汰最久未访问的条目LRU
"""
ccr_hash = hashlib.sha256(original.encode()).hexdigest()
# Collision detection: if hash exists with different content, reject
if ccr_hash in self._ccr_cache:
cached_content, _ = self._ccr_cache[ccr_hash]
if cached_content != original:
logger.warning(
"CCR hash collision detected for hash=%s... "
"Rejecting overwrite to prevent data loss.",
ccr_hash[:16],
)
return None
# Same content: idempotent update (renew timestamp + LRU position)
self._ccr_cache.move_to_end(ccr_hash)
self._ccr_cache[ccr_hash] = (original, time.monotonic())
return ccr_hash
# Evict expired entries before inserting
self._evict_expired()
# LRU eviction: if at capacity, remove oldest entry
while len(self._ccr_cache) >= self._max_entries:
self._ccr_cache.popitem(last=False)
self._ccr_cache[ccr_hash] = (original, time.monotonic())
return ccr_hash
def _evict_expired(self) -> None:
"""清理过期的 CCR 缓存条目"""
if self._ccr_ttl <= 0:
return # TTL=0 means no expiry
now = time.monotonic()
expired_keys = [
k for k, (_, ts) in self._ccr_cache.items()
if now - ts > self._ccr_ttl
]
for k in expired_keys:
del self._ccr_cache[k]
def retrieve(self, ccr_hash: str | None = None, query: str | None = None) -> dict:
"""从 CCR 缓存检索原始数据"""
if ccr_hash and ccr_hash in self._ccr_cache:
content, ts = self._ccr_cache[ccr_hash]
# Check TTL
if self._ccr_ttl > 0:
if time.monotonic() - ts > self._ccr_ttl:
del self._ccr_cache[ccr_hash]
return {
"error": f"CCR hash '{ccr_hash}' expired",
"success": False,
}
# Renew LRU position on access
self._ccr_cache.move_to_end(ccr_hash)
return {
"content": content,
"ccr_hash": ccr_hash,
"success": True,
}
if query:
# Simple keyword search in cached content
results = []
for h, (content, _) in self._ccr_cache.items():
if query.lower() in content.lower():
results.append({"ccr_hash": h, "content": content[:500]})
if results:
return {"results": results, "success": True}
return {
"error": f"CCR hash '{ccr_hash}' not found in cache",
"success": False,
}