312 lines
12 KiB
Python
312 lines
12 KiB
Python
"""ContextCompressor - 上下文压缩与 Prompt 缓存
|
||
|
||
长会话自动压缩历史消息,保持 Token 在预算内;
|
||
会话内 Prompt 不重复渲染。
|
||
"""
|
||
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
from typing import Any, Protocol, runtime_checkable
|
||
|
||
# Module-level logger, named after this module's import path (``__name__``).
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@runtime_checkable
class CompressionStrategy(Protocol):
    """Structural interface that every context compressor must satisfy.

    The protocol is ``runtime_checkable``, so ``isinstance(obj,
    CompressionStrategy)`` succeeds for any object exposing the three members
    below. Only their presence is checked, not their signatures.
    """

    async def compress(self, messages: list[dict]) -> list[dict]:
        """Return a (possibly shortened) version of the message list ``messages``."""

    async def compress_tool_result(self, tool_name: str, result: Any) -> str:
        """Return the compressed string form of a single tool's output ``result``."""

    def is_available(self) -> bool:
        """Report whether this compressor can currently be used."""
|
||
|
||
|
||
class ContextCompressor:
|
||
"""Compress long conversation histories to stay within token budgets"""
|
||
|
||
def __init__(
|
||
self,
|
||
llm_gateway: Any = None,
|
||
max_tokens: int = 4000,
|
||
keep_recent: int = 3,
|
||
model: str = "default",
|
||
model_context_limit: int = 128_000,
|
||
headroom_threshold: float = 0.8,
|
||
min_tokens: int = 8_000,
|
||
auxiliary_model: str | None = None,
|
||
):
|
||
self._llm_gateway = llm_gateway
|
||
self._max_tokens = max_tokens
|
||
self._keep_recent = keep_recent
|
||
self._model = model
|
||
# U3: Headroom-based compression trigger — predict context overflow
|
||
# before the single-request limit is hit.
|
||
self._model_context_limit = model_context_limit
|
||
self._headroom_threshold = headroom_threshold
|
||
self._min_tokens = min_tokens
|
||
# G4/U1: Auxiliary model for cost-sensitive summarization (e.g. "fast" alias).
|
||
# When set and differs from main model, _summarize tries auxiliary first,
|
||
# falls back to main model on failure OR empty content (Finding 4 anti-pattern).
|
||
# ponytail: ceiling — auxiliary is best-effort; main model is authoritative fallback.
|
||
self._auxiliary_model = auxiliary_model
|
||
|
||
def should_compress(self, messages: list[dict]) -> bool:
|
||
"""Check if compression should be triggered based on headroom ratio.
|
||
|
||
Triggers when either:
|
||
1. estimated_tokens / model_context_limit > headroom_threshold (headroom)
|
||
2. estimated_tokens > min_tokens (fixed fallback, preserves old behavior)
|
||
"""
|
||
estimated = self.estimate_tokens(messages)
|
||
if estimated / self._model_context_limit > self._headroom_threshold:
|
||
return True
|
||
if estimated > self._min_tokens:
|
||
return True
|
||
return False
|
||
|
||
def estimate_tokens(self, messages: list[dict]) -> int:
|
||
"""Estimate total tokens in message list (rough: 4 chars = 1 token)"""
|
||
total = 0
|
||
for msg in messages:
|
||
content = msg.get("content", "")
|
||
total += len(str(content)) // 4
|
||
return total
|
||
|
||
async def compress(self, messages: list[dict], _compression_depth: int = 0) -> list[dict]:
|
||
"""Compress messages if they exceed token budget
|
||
|
||
Strategy:
|
||
1. Keep system messages unchanged
|
||
2. Keep the most recent N messages unchanged
|
||
3. Compress older messages into a summary using LLM
|
||
"""
|
||
if self.estimate_tokens(messages) <= self._max_tokens:
|
||
return messages
|
||
|
||
# Separate system messages, old messages, and recent messages
|
||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||
non_system = [m for m in messages if m.get("role") != "system"]
|
||
|
||
if len(non_system) <= self._keep_recent:
|
||
return messages # Not enough messages to compress
|
||
|
||
old_msgs = non_system[: -self._keep_recent]
|
||
recent_msgs = non_system[-self._keep_recent :]
|
||
|
||
# Compress old messages
|
||
summary = await self._summarize(old_msgs)
|
||
|
||
# Build compressed message list
|
||
compressed = list(system_msgs)
|
||
if summary:
|
||
compressed.append(
|
||
{
|
||
"role": "system",
|
||
"content": f"## Conversation Summary\n{summary}",
|
||
}
|
||
)
|
||
compressed.extend(recent_msgs)
|
||
|
||
# Recursive check: if still over budget, compress again
|
||
if self.estimate_tokens(compressed) > self._max_tokens:
|
||
if _compression_depth >= 1:
|
||
# Depth guard: force truncation instead of infinite recursion
|
||
return self._truncate(compressed)
|
||
if len(recent_msgs) > 1:
|
||
# Try keeping fewer recent messages
|
||
return await self._compress_aggressive(
|
||
messages, _compression_depth=_compression_depth + 1
|
||
)
|
||
# Last resort: truncate
|
||
return self._truncate(compressed)
|
||
|
||
return compressed
|
||
|
||
async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str:
|
||
"""Summarize a list of messages using LLM.
|
||
|
||
G4/U1: When ``auxiliary_model`` is configured and differs from the main
|
||
model, try auxiliary first (cost-optimization). On auxiliary failure OR
|
||
empty content (Finding 4 anti-pattern — "did not throw is not succeeded"),
|
||
fall back to main model. Existing ``_simple_summary`` degradation
|
||
preserved as the final tier when main model also fails.
|
||
"""
|
||
if not self._llm_gateway:
|
||
# No LLM available, do simple truncation
|
||
return self._simple_summary(messages)
|
||
|
||
# Build summary prompt
|
||
conversation_text = "\n".join(
|
||
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" for m in messages
|
||
)
|
||
|
||
# Pre-truncate if conversation_text exceeds safe token threshold
|
||
estimated_tokens = len(conversation_text) // 4
|
||
if estimated_tokens > max_input_tokens:
|
||
max_chars = max_input_tokens * 4
|
||
conversation_text = conversation_text[:max_chars] + "\n...[truncated]"
|
||
|
||
prompt = (
|
||
"Summarize the following conversation history concisely, "
|
||
"preserving key facts, decisions, and context. "
|
||
"Focus on information that would be needed for continuing the conversation.\n\n"
|
||
f"{conversation_text}"
|
||
)
|
||
|
||
# G4: Try auxiliary model first when configured (cheap route).
|
||
if self._auxiliary_model and self._auxiliary_model != self._model:
|
||
try:
|
||
response = await self._llm_gateway.chat(
|
||
messages=[{"role": "user", "content": prompt}],
|
||
model=self._auxiliary_model,
|
||
agent_name="compressor",
|
||
task_type="summarization",
|
||
)
|
||
# Finding 4: empty content is a failure, not a success.
|
||
if response.content and response.content.strip():
|
||
return response.content
|
||
logger.info("Auxiliary model returned empty content, falling back to main model")
|
||
except Exception as e:
|
||
logger.info(
|
||
f"Auxiliary model summarization failed, falling back to main model: {e}"
|
||
)
|
||
|
||
# Main model path (or auxiliary fallback).
|
||
try:
|
||
response = await self._llm_gateway.chat(
|
||
messages=[{"role": "user", "content": prompt}],
|
||
model=self._model,
|
||
agent_name="compressor",
|
||
task_type="summarization",
|
||
)
|
||
return response.content
|
||
except Exception as e:
|
||
logger.warning(f"LLM summarization failed, using simple summary: {e}")
|
||
return self._simple_summary(messages)
|
||
|
||
def _simple_summary(self, messages: list[dict]) -> str:
|
||
"""Simple truncation-based summary when LLM is unavailable"""
|
||
parts = []
|
||
for msg in messages:
|
||
role = msg.get("role", "unknown")
|
||
content = str(msg.get("content", ""))[:200]
|
||
parts.append(f"[{role}]: {content}...")
|
||
return "\n".join(parts)
|
||
|
||
async def _compress_aggressive(
|
||
self, messages: list[dict], _compression_depth: int = 0
|
||
) -> list[dict]:
|
||
"""More aggressive compression when standard compression isn't enough"""
|
||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||
non_system = [m for m in messages if m.get("role") != "system"]
|
||
|
||
# Keep only the last message
|
||
if non_system:
|
||
summary = await self._summarize(non_system[:-1])
|
||
compressed = list(system_msgs)
|
||
if summary:
|
||
compressed.append(
|
||
{
|
||
"role": "system",
|
||
"content": f"## Conversation Summary\n{summary}",
|
||
}
|
||
)
|
||
compressed.append(non_system[-1])
|
||
return compressed
|
||
|
||
return messages
|
||
|
||
def _truncate(self, messages: list[dict]) -> list[dict]:
|
||
"""Last resort: truncate long messages"""
|
||
result = []
|
||
for msg in messages:
|
||
content = str(msg.get("content", ""))
|
||
if len(content) > self._max_tokens * 4:
|
||
msg = {**msg, "content": content[: self._max_tokens * 4] + "...[truncated]"}
|
||
result.append(msg)
|
||
return result
|
||
|
||
async def compress_tool_result(self, tool_name: str, result: Any) -> str:
|
||
"""默认实现:不做压缩,直接返回字符串表示"""
|
||
return str(result)
|
||
|
||
def is_available(self) -> bool:
|
||
"""ContextCompressor 始终可用"""
|
||
return True
|
||
|
||
|
||
def create_compressor(
    config: dict[str, Any] | None = None,
    *,
    llm_gateway: Any = None,
) -> CompressionStrategy | None:
    """Create a compressor instance from a configuration dict.

    Args:
        config: Compression settings. Recognized keys:
            - ``enabled`` (bool, default ``False``): whether compression is on.
            - ``provider`` (``"headroom"`` | ``"summary"``, default ``"summary"``).
            - ``max_tokens`` (default 4000) and ``keep_recent`` (default 3)
              for the summary compressor.
            - ``model``, ``model_context_limit``, ``headroom_threshold``,
              ``min_tokens``, ``auxiliary_model``: forwarded to
              ``ContextCompressor`` when present; otherwise its defaults apply.
            - Other provider-specific keys. For ``"headroom"`` the whole dict
              is passed to ``HeadroomCompressor``.
        llm_gateway: Optional gateway forwarded to ``ContextCompressor`` so the
            summary compressor can summarize with an LLM. ``None`` keeps the
            non-LLM simple summary.

    Returns:
        A ``CompressionStrategy`` instance, or ``None`` when compression is
        disabled. The ``"headroom"`` provider falls back to
        ``ContextCompressor`` when ``HeadroomCompressor`` cannot be imported or
        reports itself unavailable. Unknown providers also use
        ``ContextCompressor``, with a warning.
    """
    if not config or not config.get("enabled", False):
        return None

    provider = config.get("provider", "summary")

    if provider == "headroom":
        try:
            from agentkit.core.headroom_compressor import HeadroomCompressor

            compressor = HeadroomCompressor(config)
            if compressor.is_available():
                return compressor
            logger.warning(
                "HeadroomCompressor not available (headroom-ai not installed?). "
                "Falling back to ContextCompressor."
            )
        except ImportError:
            logger.warning(
                "HeadroomCompressor module not available. Falling back to ContextCompressor."
            )
    elif provider != "summary":
        logger.warning("Unknown compression provider %r; using ContextCompressor.", provider)

    # Summary-based compression (default, and fallback for headroom).
    # Forward optional ContextCompressor settings only when configured so the
    # constructor's own defaults apply otherwise.
    optional_keys = (
        "model",
        "model_context_limit",
        "headroom_threshold",
        "min_tokens",
        "auxiliary_model",
    )
    extra = {key: config[key] for key in optional_keys if key in config}
    return ContextCompressor(
        llm_gateway=llm_gateway,
        max_tokens=config.get("max_tokens", 4000),
        keep_recent=config.get("keep_recent", 3),
        **extra,
    )
|
||
|
||
|
||
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
|
||
"""Render PromptTemplate with caching - returns cached result for same variables"""
|
||
cache_key = hashlib.md5(json.dumps(variables or {}, sort_keys=True).encode()).hexdigest()
|
||
|
||
if not hasattr(template, "_render_cache"):
|
||
template._render_cache = {}
|
||
|
||
if cache_key in template._render_cache:
|
||
return template._render_cache[cache_key]
|
||
|
||
result = template.render(variables=variables)
|
||
template._render_cache[cache_key] = result
|
||
return result
|
||
|
||
|
||
def clear_cache(template) -> None:
    """Empty the render cache that ``render_cached`` attaches to ``template``.

    A template that carries no ``_render_cache`` attribute is left untouched.
    """
    try:
        cache = template._render_cache
    except AttributeError:
        return
    cache.clear()
|