fischer-agentkit/src/agentkit/core/compressor.py

172 lines
5.9 KiB
Python

"""ContextCompressor - 上下文压缩与 Prompt 缓存
长会话自动压缩历史消息,保持 Token 在预算内;
会话内 Prompt 不重复渲染。
"""
import copy
import hashlib
import json
import logging
from typing import Any
# Module-level logger; ContextCompressor uses it to report LLM summarization failures.
logger: logging.Logger = logging.getLogger(__name__)
class ContextCompressor:
    """Compress long conversation histories to stay within a token budget.

    Messages are dicts with a ``"role"`` key and (usually) a ``"content"``
    key. Token counts come from a character-length heuristic (see
    :meth:`estimate_tokens`), not a real tokenizer.
    """

    def __init__(
        self,
        llm_gateway: Any = None,
        max_tokens: int = 4000,
        keep_recent: int = 3,
        model: str = "default",
    ) -> None:
        """Configure the compressor.

        Args:
            llm_gateway: Object with an async ``chat(messages=..., model=...,
                agent_name=..., task_type=...)`` method whose result has a
                ``.content`` attribute; used to summarize old messages. When
                falsy, a truncation-based summary is used instead.
            max_tokens: Token budget, as measured by :meth:`estimate_tokens`.
            keep_recent: How many of the most recent non-system messages are
                kept verbatim. Values below 0 are treated as 0.
            model: Model name passed through to ``llm_gateway.chat``.
        """
        self._llm_gateway = llm_gateway
        self._max_tokens = max_tokens
        self._keep_recent = keep_recent
        self._model = model

    @staticmethod
    def _content_text(msg: dict) -> str:
        """Return a message's content as text; missing or ``None`` content is ``""``."""
        content = msg.get("content")
        return "" if content is None else str(content)

    @staticmethod
    def _assemble(system_msgs: list[dict], summary: str, tail: list[dict]) -> list[dict]:
        """Build ``system_msgs`` + an optional summary system message + ``tail``."""
        assembled = list(system_msgs)
        if summary:
            assembled.append({
                "role": "system",
                "content": f"## Conversation Summary\n{summary}",
            })
        assembled.extend(tail)
        return assembled

    def estimate_tokens(self, messages: list[dict]) -> int:
        """Estimate the total tokens in ``messages`` (rough: 4 chars = 1 token).

        Each message's content is stringified and floor-divided by 4
        separately; roles and per-message overhead are not counted, and
        missing/``None`` content counts as 0.
        """
        return sum(len(self._content_text(msg)) // 4 for msg in messages)

    async def compress(self, messages: list[dict]) -> list[dict]:
        """Return ``messages`` compressed to fit the token budget, if needed.

        Strategy:
        1. If the estimate is within ``max_tokens``, return ``messages``
           unchanged (the same list object).
        2. Otherwise put all system messages first (in their original relative
           order), summarize every non-system message except the most recent
           ``keep_recent`` into one system summary message, and keep those
           recent messages verbatim.
        3. If that is still over budget and more than one recent message was
           kept, re-summarize so only the last non-system message is verbatim.
        4. If still over budget, cut each message's content to a fixed length
           (best effort; the result may still exceed the budget).

        If there are no more than ``keep_recent`` non-system messages there is
        nothing older to summarize, and ``messages`` is returned unchanged even
        when it is over budget.
        """
        if self.estimate_tokens(messages) <= self._max_tokens:
            return messages

        system_msgs = [m for m in messages if m.get("role") == "system"]
        non_system = [m for m in messages if m.get("role") != "system"]

        keep = max(0, self._keep_recent)
        if len(non_system) <= keep:
            # Nothing older than the protected window to summarize.
            return messages

        # Split by explicit index: ``non_system[-0:]`` is the whole list, so
        # negative slicing would summarize nothing when keep == 0.
        split = len(non_system) - keep
        old_msgs, recent_msgs = non_system[:split], non_system[split:]

        summary = await self._summarize(old_msgs)
        compressed = self._assemble(system_msgs, summary, recent_msgs)
        if self.estimate_tokens(compressed) <= self._max_tokens:
            return compressed

        if len(recent_msgs) > 1:
            # Try again keeping only the final non-system message verbatim.
            compressed = await self._compress_aggressive(messages)
            if self.estimate_tokens(compressed) <= self._max_tokens:
                return compressed

        # Last resort, applied to whichever result is still over budget.
        return self._truncate(compressed)

    async def _summarize(self, messages: list[dict]) -> str:
        """Summarize ``messages`` into text.

        Returns ``""`` for an empty list without calling the LLM. Without a
        gateway, returns :meth:`_simple_summary`. With a gateway, asks it for a
        summary and falls back to :meth:`_simple_summary` (logging a warning)
        if the call or reading ``.content`` raises.
        """
        if not messages:
            return ""
        if not self._llm_gateway:
            return self._simple_summary(messages)

        conversation_text = "\n".join(
            f"[{m.get('role', 'unknown')}]: {self._content_text(m)}"
            for m in messages
        )
        prompt = (
            "Summarize the following conversation history concisely, "
            "preserving key facts, decisions, and context. "
            "Focus on information that would be needed for continuing the conversation.\n\n"
            f"{conversation_text}"
        )
        try:
            response = await self._llm_gateway.chat(
                messages=[{"role": "user", "content": prompt}],
                model=self._model,
                agent_name="compressor",
                task_type="summarization",
            )
            return response.content or ""
        except Exception as e:  # best-effort: any gateway failure degrades to the simple summary
            logger.warning("LLM summarization failed, using simple summary: %s", e)
            return self._simple_summary(messages)

    def _simple_summary(self, messages: list[dict]) -> str:
        """Build a summary without an LLM: one ``[role]: content`` line per message.

        Content longer than 200 characters is cut and marked with ``"..."``.
        """
        lines = []
        for msg in messages:
            role = msg.get("role", "unknown")
            text = self._content_text(msg)
            if len(text) > 200:
                text = text[:200] + "..."
            lines.append(f"[{role}]: {text}")
        return "\n".join(lines)

    async def _compress_aggressive(self, messages: list[dict]) -> list[dict]:
        """Summarize every non-system message except the last one.

        Returns system messages + optional summary + the final non-system
        message, or ``messages`` unchanged when it has no non-system messages.
        """
        system_msgs = [m for m in messages if m.get("role") == "system"]
        non_system = [m for m in messages if m.get("role") != "system"]
        if not non_system:
            return messages
        summary = await self._summarize(non_system[:-1])
        return self._assemble(system_msgs, summary, non_system[-1:])

    def _truncate(self, messages: list[dict]) -> list[dict]:
        """Cut each message's content to at most ``2 * max_tokens`` characters.

        Over-long content is cut and suffixed with ``"...[truncated]"`` in a
        new dict; other messages are passed through as the same objects. This
        is best effort: with several messages the total can still exceed the
        budget.
        """
        limit = max(0, self._max_tokens * 2)
        result = []
        for msg in messages:
            text = self._content_text(msg)
            if len(text) > limit:
                msg = {**msg, "content": text[:limit] + "...[truncated]"}
            result.append(msg)
        return result
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
"""Render PromptTemplate with caching - returns cached result for same variables"""
cache_key = hashlib.md5(
json.dumps(variables or {}, sort_keys=True).encode()
).hexdigest()
if not hasattr(template, '_render_cache'):
template._render_cache = {}
if cache_key in template._render_cache:
return template._render_cache[cache_key]
result = template.render(variables=variables)
template._render_cache[cache_key] = result
return result
def clear_cache(template) -> None:
    """Empty the memoized renders stored on *template* by ``render_cached``.

    A template that has never gone through ``render_cached`` (no
    ``_render_cache`` attribute) is left untouched.
    """
    try:
        cache = template._render_cache
    except AttributeError:
        return  # nothing has been cached on this template
    cache.clear()