fischer-agentkit/src/agentkit/core/compressor.py

363 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ContextCompressor - 上下文压缩与 Prompt 缓存
长会话自动压缩历史消息,保持 Token 在预算内;
会话内 Prompt 不重复渲染。
"""
import hashlib
import json
import logging
from typing import Any, Protocol, runtime_checkable
logger = logging.getLogger(__name__)
def _is_cjk(char: str) -> bool:
"""Check if a character is CJK (1 token ≈ 1 char).
Covers CJK Unified Ideographs, Hiragana, Katakana, and Hangul Syllables.
"""
cp = ord(char)
return (
0x4E00 <= cp <= 0x9FFF # CJK Unified Ideographs
or 0x3040 <= cp <= 0x30FF # Hiragana + Katakana
or 0xAC00 <= cp <= 0xD7AF # Hangul Syllables
)
def estimate_text_tokens(text: str) -> int:
    """Estimate the token count of ``text``: CJK 1:1, everything else 4:1.

    Each character recognised by ``_is_cjk`` counts as one token; all
    remaining characters are counted together and integer-divided by four.
    This avoids the roughly 4x under-estimate that a plain
    ``len(text) // 4`` gives for CJK-heavy conversations.

    Known limitation (carried over from the original author's note): pure
    CJK text may still be under-estimated by roughly 10-20%, which the
    default ``headroom_threshold=0.8`` is expected to absorb. Upgrade path:
    ``litellm.token_counter`` or a provider-specific tokenizer.

    Args:
        text: The text to measure.

    Returns:
        The estimated number of tokens (0 for an empty string).
    """
    cjk_chars = sum(1 for ch in text if _is_cjk(ch))
    other_chars = len(text) - cjk_chars
    return cjk_chars + other_chars // 4
@runtime_checkable
class CompressionStrategy(Protocol):
    """Structural interface that every compressor implementation must satisfy."""

    async def compress(self, messages: list[dict]) -> list[dict]:
        """Compress a list of chat messages and return the resulting list."""
        ...

    async def compress_tool_result(self, tool_name: str, result: Any) -> str:
        """Compress a single tool output and return it as a string."""
        ...

    def is_available(self) -> bool:
        """Report whether this compressor can currently be used."""
        ...
class ContextCompressor:
"""Compress long conversation histories to stay within token budgets"""
def __init__(
self,
llm_gateway: Any = None,
max_tokens: int = 4000,
keep_recent: int = 3,
model: str = "default",
model_context_limit: int = 128_000,
headroom_threshold: float = 0.8,
min_tokens: int = 8_000,
auxiliary_model: str | None = None,
):
self._llm_gateway = llm_gateway
self._max_tokens = max_tokens
self._keep_recent = keep_recent
self._model = model
# U3: Headroom-based compression trigger — predict context overflow
# before the single-request limit is hit.
self._model_context_limit = model_context_limit
self._headroom_threshold = headroom_threshold
self._min_tokens = min_tokens
# G4/U1: Auxiliary model for cost-sensitive summarization (e.g. "fast" alias).
# When set and differs from main model, _summarize tries auxiliary first,
# falls back to main model on failure OR empty content (Finding 4 anti-pattern).
# ponytail: ceiling — auxiliary is best-effort; main model is authoritative fallback.
self._auxiliary_model = auxiliary_model
def should_compress(self, messages: list[dict]) -> bool:
"""Check if compression should be triggered based on headroom ratio.
Triggers when either:
1. estimated_tokens / model_context_limit > headroom_threshold (headroom)
2. estimated_tokens > min_tokens (fixed fallback, preserves old behavior)
"""
estimated = self.estimate_tokens(messages)
if estimated / self._model_context_limit > self._headroom_threshold:
return True
if estimated > self._min_tokens:
return True
return False
def estimate_tokens(self, messages: list[dict]) -> int:
"""Estimate total tokens in message list (CJK 1:1, ASCII 4:1)"""
return sum(estimate_text_tokens(str(m.get("content", ""))) for m in messages)
async def compress(self, messages: list[dict]) -> list[dict]:
"""Compress messages if they exceed token budget.
Linear flow: summarize -> aggressive -> truncate.
Each step only fires if the previous didn't bring tokens under budget.
"""
tokens_before = self.estimate_tokens(messages)
if tokens_before <= self._max_tokens:
return messages
# Separate system messages, old messages, and recent messages
system_msgs = [m for m in messages if m.get("role") == "system"]
non_system = [m for m in messages if m.get("role") != "system"]
if len(non_system) <= self._keep_recent:
return messages # Not enough messages to compress
old_msgs = non_system[: -self._keep_recent]
recent_msgs = non_system[-self._keep_recent :]
# Step 1: Summarize old messages
summary = await self._summarize(old_msgs)
compressed = list(system_msgs)
if summary:
compressed.append(
{
"role": "system",
"content": f"## Conversation Summary\n{summary}",
}
)
compressed.extend(recent_msgs)
# Step 2: If still over budget, aggressive compress
# F-010: pass original `messages` (not `compressed`) to avoid summary-of-summary
strategy = "summary"
if self.estimate_tokens(compressed) > self._max_tokens:
compressed = await self._compress_aggressive(messages)
strategy = "aggressive"
# Step 3: If still over budget, truncate as last resort
if self.estimate_tokens(compressed) > self._max_tokens:
compressed = self._truncate(compressed)
strategy = "truncate"
# Step 4: Log compression result
tokens_after = self.estimate_tokens(compressed)
self._log_compression(tokens_before, tokens_after, len(messages), len(compressed), strategy)
return compressed
def _log_compression(
self,
tokens_before: int,
tokens_after: int,
msg_count_before: int,
msg_count_after: int,
strategy: str,
) -> None:
"""Log structured compression info (tokens_before/after/ratio/msg_count)."""
ratio = tokens_after / tokens_before if tokens_before > 0 else 0.0
logger.info(
"context compressed: %d -> %d tokens (%.1f%%), messages: %d -> %d, strategy: %s",
tokens_before,
tokens_after,
ratio * 100,
msg_count_before,
msg_count_after,
strategy,
)
async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str:
"""Summarize a list of messages using LLM.
G4/U1: When ``auxiliary_model`` is configured and differs from the main
model, try auxiliary first (cost-optimization). On auxiliary failure OR
empty content (Finding 4 anti-pattern — "did not throw is not succeeded"),
fall back to main model. Existing ``_simple_summary`` degradation
preserved as the final tier when main model also fails.
"""
if not self._llm_gateway:
# No LLM available, do simple truncation
return self._simple_summary(messages)
# Build summary prompt
conversation_text = "\n".join(
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" for m in messages
)
# Pre-truncate if conversation_text exceeds safe token threshold
estimated_tokens = estimate_text_tokens(conversation_text)
if estimated_tokens > max_input_tokens:
# CJK-aware char limit: max_input_tokens chars is exact for CJK (1:1),
# conservative for ASCII (4:1, truncates to 1/4 budget but safe).
# Review fix #1: old `* 4` allowed 4x token budget for CJK text.
max_chars = max_input_tokens
conversation_text = conversation_text[:max_chars] + "\n...[truncated]"
prompt = (
"Summarize the following conversation history concisely, "
"preserving key facts, decisions, and context. "
"Focus on information that would be needed for continuing the conversation.\n\n"
f"{conversation_text}"
)
# G4: Try auxiliary model first when configured (cheap route).
if self._auxiliary_model and self._auxiliary_model != self._model:
try:
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
model=self._auxiliary_model,
agent_name="compressor",
task_type="summarization",
)
# Finding 4: empty content is a failure, not a success.
if response.content and response.content.strip():
return response.content
logger.info("Auxiliary model returned empty content, falling back to main model")
except Exception as e:
logger.info(
f"Auxiliary model summarization failed, falling back to main model: {e}"
)
# Main model path (or auxiliary fallback).
try:
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
model=self._model,
agent_name="compressor",
task_type="summarization",
)
return response.content
except Exception as e:
logger.warning(f"LLM summarization failed, using simple summary: {e}")
return self._simple_summary(messages)
def _simple_summary(self, messages: list[dict]) -> str:
"""Simple truncation-based summary when LLM is unavailable"""
parts = []
for msg in messages:
role = msg.get("role", "unknown")
content = str(msg.get("content", ""))[:200]
parts.append(f"[{role}]: {content}...")
return "\n".join(parts)
async def _compress_aggressive(self, messages: list[dict]) -> list[dict]:
"""Aggressive compression: keep only last message + summary of the rest."""
system_msgs = [m for m in messages if m.get("role") == "system"]
non_system = [m for m in messages if m.get("role") != "system"]
# Keep only the last message
if non_system:
summary = await self._summarize(non_system[:-1])
compressed = list(system_msgs)
if summary:
compressed.append(
{
"role": "system",
"content": f"## Conversation Summary\n{summary}",
}
)
compressed.append(non_system[-1])
return compressed
return messages
def _truncate(self, messages: list[dict]) -> list[dict]:
"""Last resort: truncate long messages"""
result = []
for msg in messages:
content = str(msg.get("content", ""))
if len(content) > self._max_tokens * 4:
msg = {**msg, "content": content[: self._max_tokens * 4] + "...[truncated]"}
result.append(msg)
return result
async def compress_tool_result(self, tool_name: str, result: Any) -> str:
"""默认实现:不做压缩,直接返回字符串表示"""
return str(result)
def is_available(self) -> bool:
"""ContextCompressor 始终可用"""
return True
def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrategy | None:
    """Create a compressor instance from a configuration dict.

    Args:
        config: Compression settings. Keys read here:
            - enabled: bool, whether compression is on (default False)
            - provider: "headroom" | "summary", which compressor to use
              (default "summary")
            - max_tokens: int, token budget (summary mode, default 4000)
            - keep_recent: int, number of recent messages kept verbatim
              (summary mode, default 3)
            For the "headroom" provider the whole dict is handed to
            ``HeadroomCompressor``.

    Returns:
        A ``CompressionStrategy`` instance, or None when compression is not
        enabled. A requested "headroom" provider that cannot be imported or
        reports itself unavailable falls back to ``ContextCompressor``.
    """
    if not (config and config.get("enabled", False)):
        return None

    if config.get("provider", "summary") == "headroom":
        try:
            from agentkit.core.headroom_compressor import HeadroomCompressor

            headroom = HeadroomCompressor(config)
            if headroom.is_available():
                return headroom
            logger.warning(
                "HeadroomCompressor not available (headroom-ai not installed?). "
                "Falling back to ContextCompressor."
            )
        except ImportError:
            logger.warning(
                "HeadroomCompressor module not available. Falling back to ContextCompressor."
            )

    # Summary-based compression: the default provider and the headroom fallback.
    return ContextCompressor(
        max_tokens=config.get("max_tokens", 4000),
        keep_recent=config.get("keep_recent", 3),
    )
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
"""Render PromptTemplate with caching - returns cached result for same variables"""
cache_key = hashlib.md5(json.dumps(variables or {}, sort_keys=True).encode()).hexdigest()
if not hasattr(template, "_render_cache"):
template._render_cache = {}
if cache_key in template._render_cache:
return template._render_cache[cache_key]
result = template.render(variables=variables)
template._render_cache[cache_key] = result
return result
def clear_cache(template) -> None:
    """Empty the ``_render_cache`` stored on ``template``, if it has one.

    Templates without a ``_render_cache`` attribute are left untouched.
    """
    missing = object()
    cache = getattr(template, "_render_cache", missing)
    if cache is missing:
        return
    cache.clear()