fischer-agentkit/src/agentkit/core/compressor.py

312 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ContextCompressor - 上下文压缩与 Prompt 缓存
长会话自动压缩历史消息,保持 Token 在预算内;
会话内 Prompt 不重复渲染。
"""
import hashlib
import json
import logging
from typing import Any, Protocol, runtime_checkable
logger = logging.getLogger(__name__)
@runtime_checkable
class CompressionStrategy(Protocol):
    """Structural interface that every compressor implementation must satisfy.

    The protocol is ``runtime_checkable``, so ``isinstance(obj, CompressionStrategy)``
    succeeds for any object exposing the three members below.
    """

    async def compress(self, messages: list[dict]) -> list[dict]:
        """Compress a list of chat messages and return the compressed list."""
        ...

    async def compress_tool_result(self, tool_name: str, result: Any) -> str:
        """Compress a single tool output and return it as a string."""
        ...

    def is_available(self) -> bool:
        """Report whether this compressor can currently be used."""
        ...
class ContextCompressor:
"""Compress long conversation histories to stay within token budgets"""
def __init__(
self,
llm_gateway: Any = None,
max_tokens: int = 4000,
keep_recent: int = 3,
model: str = "default",
model_context_limit: int = 128_000,
headroom_threshold: float = 0.8,
min_tokens: int = 8_000,
auxiliary_model: str | None = None,
):
self._llm_gateway = llm_gateway
self._max_tokens = max_tokens
self._keep_recent = keep_recent
self._model = model
# U3: Headroom-based compression trigger — predict context overflow
# before the single-request limit is hit.
self._model_context_limit = model_context_limit
self._headroom_threshold = headroom_threshold
self._min_tokens = min_tokens
# G4/U1: Auxiliary model for cost-sensitive summarization (e.g. "fast" alias).
# When set and differs from main model, _summarize tries auxiliary first,
# falls back to main model on failure OR empty content (Finding 4 anti-pattern).
# ponytail: ceiling — auxiliary is best-effort; main model is authoritative fallback.
self._auxiliary_model = auxiliary_model
def should_compress(self, messages: list[dict]) -> bool:
"""Check if compression should be triggered based on headroom ratio.
Triggers when either:
1. estimated_tokens / model_context_limit > headroom_threshold (headroom)
2. estimated_tokens > min_tokens (fixed fallback, preserves old behavior)
"""
estimated = self.estimate_tokens(messages)
if estimated / self._model_context_limit > self._headroom_threshold:
return True
if estimated > self._min_tokens:
return True
return False
def estimate_tokens(self, messages: list[dict]) -> int:
"""Estimate total tokens in message list (rough: 4 chars = 1 token)"""
total = 0
for msg in messages:
content = msg.get("content", "")
total += len(str(content)) // 4
return total
async def compress(self, messages: list[dict], _compression_depth: int = 0) -> list[dict]:
"""Compress messages if they exceed token budget
Strategy:
1. Keep system messages unchanged
2. Keep the most recent N messages unchanged
3. Compress older messages into a summary using LLM
"""
if self.estimate_tokens(messages) <= self._max_tokens:
return messages
# Separate system messages, old messages, and recent messages
system_msgs = [m for m in messages if m.get("role") == "system"]
non_system = [m for m in messages if m.get("role") != "system"]
if len(non_system) <= self._keep_recent:
return messages # Not enough messages to compress
old_msgs = non_system[: -self._keep_recent]
recent_msgs = non_system[-self._keep_recent :]
# Compress old messages
summary = await self._summarize(old_msgs)
# Build compressed message list
compressed = list(system_msgs)
if summary:
compressed.append(
{
"role": "system",
"content": f"## Conversation Summary\n{summary}",
}
)
compressed.extend(recent_msgs)
# Recursive check: if still over budget, compress again
if self.estimate_tokens(compressed) > self._max_tokens:
if _compression_depth >= 1:
# Depth guard: force truncation instead of infinite recursion
return self._truncate(compressed)
if len(recent_msgs) > 1:
# Try keeping fewer recent messages
return await self._compress_aggressive(
messages, _compression_depth=_compression_depth + 1
)
# Last resort: truncate
return self._truncate(compressed)
return compressed
async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str:
"""Summarize a list of messages using LLM.
G4/U1: When ``auxiliary_model`` is configured and differs from the main
model, try auxiliary first (cost-optimization). On auxiliary failure OR
empty content (Finding 4 anti-pattern — "did not throw is not succeeded"),
fall back to main model. Existing ``_simple_summary`` degradation
preserved as the final tier when main model also fails.
"""
if not self._llm_gateway:
# No LLM available, do simple truncation
return self._simple_summary(messages)
# Build summary prompt
conversation_text = "\n".join(
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" for m in messages
)
# Pre-truncate if conversation_text exceeds safe token threshold
estimated_tokens = len(conversation_text) // 4
if estimated_tokens > max_input_tokens:
max_chars = max_input_tokens * 4
conversation_text = conversation_text[:max_chars] + "\n...[truncated]"
prompt = (
"Summarize the following conversation history concisely, "
"preserving key facts, decisions, and context. "
"Focus on information that would be needed for continuing the conversation.\n\n"
f"{conversation_text}"
)
# G4: Try auxiliary model first when configured (cheap route).
if self._auxiliary_model and self._auxiliary_model != self._model:
try:
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
model=self._auxiliary_model,
agent_name="compressor",
task_type="summarization",
)
# Finding 4: empty content is a failure, not a success.
if response.content and response.content.strip():
return response.content
logger.info("Auxiliary model returned empty content, falling back to main model")
except Exception as e:
logger.info(
f"Auxiliary model summarization failed, falling back to main model: {e}"
)
# Main model path (or auxiliary fallback).
try:
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
model=self._model,
agent_name="compressor",
task_type="summarization",
)
return response.content
except Exception as e:
logger.warning(f"LLM summarization failed, using simple summary: {e}")
return self._simple_summary(messages)
def _simple_summary(self, messages: list[dict]) -> str:
"""Simple truncation-based summary when LLM is unavailable"""
parts = []
for msg in messages:
role = msg.get("role", "unknown")
content = str(msg.get("content", ""))[:200]
parts.append(f"[{role}]: {content}...")
return "\n".join(parts)
async def _compress_aggressive(
self, messages: list[dict], _compression_depth: int = 0
) -> list[dict]:
"""More aggressive compression when standard compression isn't enough"""
system_msgs = [m for m in messages if m.get("role") == "system"]
non_system = [m for m in messages if m.get("role") != "system"]
# Keep only the last message
if non_system:
summary = await self._summarize(non_system[:-1])
compressed = list(system_msgs)
if summary:
compressed.append(
{
"role": "system",
"content": f"## Conversation Summary\n{summary}",
}
)
compressed.append(non_system[-1])
return compressed
return messages
def _truncate(self, messages: list[dict]) -> list[dict]:
"""Last resort: truncate long messages"""
result = []
for msg in messages:
content = str(msg.get("content", ""))
if len(content) > self._max_tokens * 4:
msg = {**msg, "content": content[: self._max_tokens * 4] + "...[truncated]"}
result.append(msg)
return result
async def compress_tool_result(self, tool_name: str, result: Any) -> str:
"""默认实现:不做压缩,直接返回字符串表示"""
return str(result)
def is_available(self) -> bool:
"""ContextCompressor 始终可用"""
return True
def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrategy | None:
    """Create a compressor instance from a configuration dict.

    Args:
        config: Compression settings. Recognised keys:
            - enabled: bool, whether compression is enabled (default False)
            - provider: "headroom" | "summary", which compressor to use
            - max_tokens: int, token budget (summary mode)
            - keep_recent: int, number of recent messages to keep (summary mode)
            - any other provider-specific settings

    Returns:
        A CompressionStrategy instance, or None when compression is not enabled.
    """
    if not (config and config.get("enabled", False)):
        return None

    if config.get("provider", "summary") == "headroom":
        try:
            from agentkit.core.headroom_compressor import HeadroomCompressor

            headroom = HeadroomCompressor(config)
            if headroom.is_available():
                return headroom
            logger.warning(
                "HeadroomCompressor not available (headroom-ai not installed?). "
                "Falling back to ContextCompressor."
            )
        except ImportError:
            logger.warning(
                "HeadroomCompressor module not available. Falling back to ContextCompressor."
            )

    # Summary-based compression: the default provider and the fallback when
    # the headroom provider cannot be used.
    return ContextCompressor(
        max_tokens=config.get("max_tokens", 4000),
        keep_recent=config.get("keep_recent", 3),
    )
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
"""Render PromptTemplate with caching - returns cached result for same variables"""
cache_key = hashlib.md5(json.dumps(variables or {}, sort_keys=True).encode()).hexdigest()
if not hasattr(template, "_render_cache"):
template._render_cache = {}
if cache_key in template._render_cache:
return template._render_cache[cache_key]
result = template.render(variables=variables)
template._render_cache[cache_key] = result
return result
def clear_cache(template) -> None:
    """Empty the render cache stored on a PromptTemplate instance, if any.

    Templates that carry no ``_render_cache`` attribute are left untouched.
    """
    if not hasattr(template, "_render_cache"):
        return
    template._render_cache.clear()