172 lines
5.9 KiB
Python
172 lines
5.9 KiB
Python
"""ContextCompressor - 上下文压缩与 Prompt 缓存
|
|
|
|
长会话自动压缩历史消息,保持 Token 在预算内;
|
|
会话内 Prompt 不重复渲染。
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ContextCompressor:
    """Compress long conversation histories to stay within a token budget.

    Messages are plain dicts read through their ``"role"`` and ``"content"``
    keys. When the estimated token count exceeds ``max_tokens``, older
    non-system messages are folded into a single summary system message. The
    summary comes from ``llm_gateway`` when one is configured, otherwise from
    a simple per-message clipping summary.
    """

    def __init__(
        self,
        llm_gateway: Any = None,
        max_tokens: int = 4000,
        keep_recent: int = 3,
        model: str = "default",
    ):
        """
        Args:
            llm_gateway: Object with an async ``chat(messages=..., model=...,
                agent_name=..., task_type=...)`` method whose result exposes a
                ``content`` attribute. A falsy value disables LLM summaries.
            max_tokens: Token budget, as measured by :meth:`estimate_tokens`.
            keep_recent: Number of most recent non-system messages kept
                verbatim. Negative values are treated as 0.
            model: Model name forwarded to ``llm_gateway.chat``.
        """
        self._llm_gateway = llm_gateway
        self._max_tokens = max_tokens
        self._keep_recent = keep_recent
        self._model = model

    def estimate_tokens(self, messages: list[dict]) -> int:
        """Roughly estimate the token count of ``messages``.

        Uses ~4 characters per token on each message's stringified
        ``content``. Each message is floored separately, so a message shorter
        than 4 characters counts as 0 tokens.
        """
        return sum(len(str(msg.get("content", ""))) // 4 for msg in messages)

    async def compress(self, messages: list[dict]) -> list[dict]:
        """Return ``messages`` compressed to fit the token budget if needed.

        Strategy:
            1. Return ``messages`` unchanged if already within budget.
            2. Keep all system messages (moved to the front, original order).
            3. Keep the most recent ``keep_recent`` non-system messages verbatim.
            4. Replace the older non-system messages with one summary message.
            5. If still over budget, summarize everything except the final
               message; as a last resort, clip overly long message contents.

        The input list and its message dicts are not mutated.
        """
        if self.estimate_tokens(messages) <= self._max_tokens:
            return messages

        system_msgs = [m for m in messages if m.get("role") == "system"]
        non_system = [m for m in messages if m.get("role") != "system"]

        keep_recent = max(0, self._keep_recent)
        if len(non_system) <= keep_recent:
            return messages  # Nothing older than the protected tail to summarize

        # Split by explicit index: with keep_recent == 0 the slices
        # ``[:-0]`` / ``[-0:]`` would treat *every* message as recent.
        split = len(non_system) - keep_recent
        old_msgs = non_system[:split]
        recent_msgs = non_system[split:]

        summary = await self._summarize(old_msgs)
        compressed = self._with_summary(system_msgs, summary, recent_msgs)
        if self.estimate_tokens(compressed) <= self._max_tokens:
            return compressed

        if len(recent_msgs) > 1:
            # Retry keeping only the final message verbatim.
            compressed = await self._compress_aggressive(messages)
            if self.estimate_tokens(compressed) <= self._max_tokens:
                return compressed

        # Last resort: clip long message contents.
        return self._truncate(compressed)

    @staticmethod
    def _with_summary(system_msgs: list[dict], summary: str, tail: list[dict]) -> list[dict]:
        """Build ``system_msgs + [summary message] + tail``; omit the summary if empty."""
        compressed = list(system_msgs)
        if summary:
            compressed.append({
                "role": "system",
                "content": f"## Conversation Summary\n{summary}",
            })
        compressed.extend(tail)
        return compressed

    async def _summarize(self, messages: list[dict]) -> str:
        """Summarize ``messages`` into a single string.

        Returns ``""`` for an empty list without calling the LLM. Uses the
        LLM gateway when configured, and falls back to :meth:`_simple_summary`
        when there is no gateway or the gateway call raises.
        """
        if not messages:
            return ""
        if not self._llm_gateway:
            # No LLM available: use clipping-based summary.
            return self._simple_summary(messages)

        conversation_text = "\n".join(
            f"[{m.get('role', 'unknown')}]: {m.get('content', '')}"
            for m in messages
        )
        prompt = (
            "Summarize the following conversation history concisely, "
            "preserving key facts, decisions, and context. "
            "Focus on information that would be needed for continuing the conversation.\n\n"
            f"{conversation_text}"
        )

        try:
            response = await self._llm_gateway.chat(
                messages=[{"role": "user", "content": prompt}],
                model=self._model,
                agent_name="compressor",
                task_type="summarization",
            )
            return response.content
        except Exception as e:
            # Deliberate best-effort: a failing gateway degrades to a local
            # summary instead of breaking the conversation.
            logger.warning("LLM summarization failed, using simple summary: %s", e)
            return self._simple_summary(messages)

    def _simple_summary(self, messages: list[dict], max_chars: int = 200) -> str:
        """Summarize without an LLM by clipping each message.

        Each message becomes ``"[role]: content"``, with content cut to
        ``max_chars`` characters. ``"..."`` is appended only when the content
        was actually cut, so untruncated text is not misrepresented.
        """
        lines = []
        for msg in messages:
            role = msg.get("role", "unknown")
            content = str(msg.get("content", ""))
            if len(content) > max_chars:
                content = content[:max_chars] + "..."
            lines.append(f"[{role}]: {content}")
        return "\n".join(lines)

    async def _compress_aggressive(self, messages: list[dict]) -> list[dict]:
        """Summarize every non-system message except the last one.

        Returns system messages, then the summary (if any), then the final
        non-system message. Returns ``messages`` unchanged when it contains no
        non-system messages.
        """
        system_msgs = [m for m in messages if m.get("role") == "system"]
        non_system = [m for m in messages if m.get("role") != "system"]
        if not non_system:
            return messages

        summary = await self._summarize(non_system[:-1])
        return self._with_summary(system_msgs, summary, non_system[-1:])

    def _truncate(self, messages: list[dict]) -> list[dict]:
        """Clip any message content longer than ``2 * max_tokens`` characters.

        Clipped messages are replaced by new dicts (originals are not
        mutated); other messages pass through as-is. The result is not
        guaranteed to fit the budget.
        """
        limit = self._max_tokens * 2
        result = []
        for msg in messages:
            content = str(msg.get("content", ""))
            if len(content) > limit:
                msg = {**msg, "content": content[:limit] + "...[truncated]"}
            result.append(msg)
        return result
|
|
|
|
|
|
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
|
|
"""Render PromptTemplate with caching - returns cached result for same variables"""
|
|
cache_key = hashlib.md5(
|
|
json.dumps(variables or {}, sort_keys=True).encode()
|
|
).hexdigest()
|
|
|
|
if not hasattr(template, '_render_cache'):
|
|
template._render_cache = {}
|
|
|
|
if cache_key in template._render_cache:
|
|
return template._render_cache[cache_key]
|
|
|
|
result = template.render(variables=variables)
|
|
template._render_cache[cache_key] = result
|
|
return result
|
|
|
|
|
|
def clear_cache(template) -> None:
    """Empty the render cache stored on ``template`` as ``_render_cache``.

    Does nothing when the template has no ``_render_cache`` attribute. The
    existing cache dict is emptied in place rather than removed.
    """
    try:
        render_cache = template._render_cache
    except AttributeError:
        return
    render_cache.clear()
|