# fischer-agentkit/tests/unit/test_context_compressor.py
# (712 lines, 26 KiB, Python)
#
# NOTE: this file intentionally contains non-ASCII characters (CJK test data
# and box-drawing characters in section comments). Source viewers may flag
# them as "ambiguous Unicode characters"; that warning can be ignored.

"""Tests for ContextCompressor and PromptTemplate cache"""
import inspect
import logging
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.compressor import ContextCompressor, estimate_text_tokens
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.prompts.section import PromptSection
from agentkit.prompts.template import PromptTemplate
# ── Helpers ──────────────────────────────────────────
def make_mock_gateway(summary_content: str = "Summary of conversation") -> MagicMock:
    """Create a mock LLMGateway whose ``chat`` coroutine returns a canned summary.

    Args:
        summary_content: Text placed in the ``content`` field of the
            ``LLMResponse`` returned by every ``chat`` call.

    Returns:
        A ``MagicMock`` spec'd against ``LLMGateway``, with ``chat`` replaced
        by an ``AsyncMock`` that always resolves to the same response object.
    """
    from agentkit.llm.gateway import LLMGateway

    canned_reply = LLMResponse(
        content=summary_content,
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
    )
    fake_gateway = MagicMock(spec=LLMGateway)
    fake_gateway.chat = AsyncMock(return_value=canned_reply)
    return fake_gateway
def make_long_messages(count: int = 10, content_length: int = 2000) -> list[dict]:
    """Build a long conversation for exercising compression.

    The result begins with one system message, followed by ``count``
    user/assistant pairs. For pair ``i``, the user message is
    ``content_length`` ``"x"`` characters followed by `` message {i}``, and
    the assistant message is ``content_length`` ``"y"`` characters followed
    by `` reply {i}``.
    """
    system_prompt = {"role": "system", "content": "You are a helpful assistant."}
    turns = [
        message
        for idx in range(count)
        for message in (
            {"role": "user", "content": "x" * content_length + f" message {idx}"},
            {"role": "assistant", "content": "y" * content_length + f" reply {idx}"},
        )
    ]
    return [system_prompt, *turns]
# ── ContextCompressor Tests ──────────────────────────
class TestEstimateTokens:
    """Basic tests for ContextCompressor.estimate_tokens on ASCII content."""

    def test_empty_messages(self):
        assert ContextCompressor().estimate_tokens([]) == 0

    def test_single_message(self):
        # 40 ASCII chars / 4 = 10 tokens
        single = [{"role": "user", "content": "a" * 40}]
        assert ContextCompressor().estimate_tokens(single) == 10

    def test_multiple_messages(self):
        # 40/4 + 80/4 = 10 + 20 = 30 tokens
        conversation = [
            {"role": "user", "content": "a" * 40},
            {"role": "assistant", "content": "b" * 80},
        ]
        assert ContextCompressor().estimate_tokens(conversation) == 30

    def test_missing_content_key(self):
        # A message with no "content" key contributes nothing
        no_content = [{"role": "user"}]
        assert ContextCompressor().estimate_tokens(no_content) == 0
class TestEstimateTextTokensCJK:
    """CJK-aware tests for estimate_text_tokens (U1).

    The expectations below encode the estimator contract exercised here:
    CJK characters (Han, Kana, Hangul) count as one token each, while other
    text counts at 4 characters per token, floored.
    """

    def test_pure_cjk_chinese(self):
        # 4 Han chars = 4 tokens (1:1)
        assert estimate_text_tokens("你好世界") == 4

    def test_pure_ascii(self):
        # 11 chars / 4 = 2.75, floored to 2
        assert estimate_text_tokens("hello world") == 2

    def test_pure_cjk_japanese_kana(self):
        # 5 Hiragana chars = 5 tokens (1:1)
        assert estimate_text_tokens("こんにちは") == 5

    def test_pure_cjk_korean_hangul(self):
        # 5 Hangul chars = 5 tokens (1:1)
        assert estimate_text_tokens("안녕하세요") == 5

    def test_mixed_cjk_and_ascii(self):
        # "你好" (2 CJK = 2 tokens) + " world" (6 ASCII chars = 1 token) = 3
        assert estimate_text_tokens("你好 world") == 3

    def test_empty_string(self):
        assert estimate_text_tokens("") == 0

    def test_estimate_tokens_with_cjk_messages(self):
        """estimate_tokens() no longer underestimates CJK messages by 4x."""
        compressor = ContextCompressor()
        messages = [{"role": "user", "content": "你好世界"}]  # 4 CJK = 4 tokens
        assert compressor.estimate_tokens(messages) == 4

    def test_estimate_tokens_mixed_messages(self):
        """estimate_tokens() gives a sensible total for mixed-script messages."""
        compressor = ContextCompressor()
        messages = [
            {"role": "user", "content": "你好"},  # 2 CJK = 2
            {"role": "assistant", "content": "hello"},  # 5 ASCII = 1
        ]
        assert compressor.estimate_tokens(messages) == 3

    def test_cjk_not_underestimated_4x(self):
        """AE1: for CJK text, estimate_tokens is >= 4x the old ``len // 4`` estimate."""
        compressor = ContextCompressor()
        cjk_msg = [{"role": "user", "content": "你好" * 50}]  # 100 CJK chars
        new_estimate = compressor.estimate_tokens(cjk_msg)
        old_estimate = len("你好" * 50) // 4  # old heuristic: len // 4
        assert new_estimate >= old_estimate * 4

    async def test_summarize_cjk_pre_truncation(self):
        """Review fix #2: _summarize() pre-truncates CJK text correctly.

        Builds CJK text whose estimate_text_tokens exceeds max_input_tokens
        while len(text) is still below max_input_tokens * 4. This is the old
        bug: the ``* 4`` character-budget assumption let such text through
        at up to 4x the token budget.
        """
        gateway = make_mock_gateway("Summary result")
        compressor = ContextCompressor(llm_gateway=gateway)
        # Must be a real CJK character. An empty string ("" * 4000 == "")
        # yields zero tokens, so truncation could never trigger.
        cjk_char = "中"
        # 4000 CJK chars = 4000 tokens (1:1), which is > max_input_tokens=3200.
        # But len=4000 < 3200 * 4 = 12800, so the old `* 4` limit wouldn't truncate.
        cjk_content = cjk_char * 4000
        messages = [{"role": "user", "content": cjk_content}]
        await compressor._summarize(messages, max_input_tokens=3200)
        # The LLM must be called with truncated text, not the full 4000 chars
        call_messages = gateway.chat.call_args.kwargs["messages"]
        prompt_content = call_messages[0]["content"]
        # The conversation text in the prompt is cut and tagged with the
        # truncation marker instead of carrying all 4000 chars
        assert "...[truncated]" in prompt_content
        # The CJK content was actually shortened
        assert prompt_content.count(cjk_char) < 4000
class TestNoCompressionWhenUnderBudget:
    """Conversations within the token budget are returned unchanged."""

    async def test_short_messages_not_compressed(self):
        conversation = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        compressed = await ContextCompressor(max_tokens=10000).compress(conversation)
        assert compressed == conversation

    async def test_exactly_at_budget_not_compressed(self):
        # 40 ASCII chars -> 10 tokens, exactly equal to the budget of 10
        conversation = [{"role": "user", "content": "a" * 40}]
        compressed = await ContextCompressor(max_tokens=10).compress(conversation)
        assert compressed == conversation
class TestCompressionTriggersWhenOverBudget:
    """Compression is triggered once the conversation exceeds the budget."""

    @staticmethod
    def _compressor(summary: str) -> ContextCompressor:
        # Small budget (100 tokens); the 2 most recent messages are kept verbatim.
        return ContextCompressor(
            llm_gateway=make_mock_gateway(summary),
            max_tokens=100,
            keep_recent=2,
        )

    async def test_long_messages_get_compressed(self):
        compressor = self._compressor("Compressed summary")
        original = make_long_messages(count=5, content_length=500)
        result = await compressor.compress(original)
        # The result is shorter than the input
        assert len(result) < len(original)
        # At least one system message survives
        assert sum(1 for m in result if m.get("role") == "system") >= 1
        # The conversation still ends on a recent (non-system) message
        assert result[-1]["role"] != "system"

    async def test_compression_preserves_system_messages(self):
        compressor = self._compressor("Summary")
        original = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "c" * 2000},
            {"role": "assistant", "content": "d" * 2000},
            {"role": "user", "content": "Recent question"},
            {"role": "assistant", "content": "Recent answer"},
        ]
        result = await compressor.compress(original)
        # The original system message stays in first position
        head = result[0]
        assert head["content"] == "System prompt"
        assert head["role"] == "system"

    async def test_compression_keeps_recent_messages(self):
        compressor = self._compressor("Summary")
        original = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent question"},
            {"role": "assistant", "content": "Recent answer"},
        ]
        result = await compressor.compress(original)
        # The final two non-system messages are the untouched recent turns
        conversational = [m for m in result if m.get("role") != "system"]
        assert conversational[-2]["content"] == "Recent question"
        assert conversational[-1]["content"] == "Recent answer"
class TestSummaryGenerationWithLLM:
    """When a gateway is available, older turns are summarized by the LLM."""

    async def test_llm_summarization_called(self):
        gateway = make_mock_gateway("LLM generated summary")
        conversation = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        result = await ContextCompressor(
            llm_gateway=gateway, max_tokens=100, keep_recent=2
        ).compress(conversation)

        # Exactly one summarization request reaches the LLM
        gateway.chat.assert_called_once()

        # The summary appears as a single system message
        summaries = list(
            filter(
                lambda m: m.get("role") == "system"
                and "Conversation Summary" in m.get("content", ""),
                result,
            )
        )
        assert len(summaries) == 1
        assert "LLM generated summary" in summaries[0]["content"]
class TestFallbackToSimpleSummary:
    """Without a working LLM, compression falls back to a simple summary."""

    @staticmethod
    def _conversation() -> list[dict]:
        # Two long early turns plus a short recent exchange
        return [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

    @staticmethod
    def _summary_messages(result: list[dict]) -> list[dict]:
        # Summary messages are system messages carrying the summary header
        return [
            m
            for m in result
            if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
        ]

    async def test_no_llm_uses_simple_summary(self):
        compressor = ContextCompressor(llm_gateway=None, max_tokens=100, keep_recent=2)
        result = await compressor.compress(self._conversation())
        # One summary message is produced in simple-truncation mode ...
        summaries = self._summary_messages(result)
        assert len(summaries) == 1
        # ... and it carries an ellipsis marker from the truncation
        assert "..." in summaries[0]["content"]

    async def test_llm_failure_uses_simple_summary(self):
        failing_gateway = make_mock_gateway()
        failing_gateway.chat = AsyncMock(side_effect=Exception("LLM error"))
        compressor = ContextCompressor(
            llm_gateway=failing_gateway, max_tokens=100, keep_recent=2
        )
        result = await compressor.compress(self._conversation())
        # The LLM error is absorbed and a fallback summary is still emitted
        assert len(self._summary_messages(result)) == 1
class TestAggressiveCompression:
    """Aggressive compression kicks in when standard compression is still over budget."""

    async def test_aggressive_compression_when_still_over_budget(self):
        # The budget is tiny and the summary itself is huge, so even the
        # standard compressed result stays over budget.
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("x" * 5000),
            max_tokens=10,
            keep_recent=2,
        )
        bulky_turns = [
            {"role": role, "content": filler * 5000}
            for role, filler in (
                ("user", "a"),
                ("assistant", "b"),
                ("user", "c"),
                ("assistant", "d"),
            )
        ]
        conversation = bulky_turns + [
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        result = await compressor.compress(conversation)
        # At most one non-system message survives aggressive compression
        survivors = [m for m in result if m.get("role") != "system"]
        assert len(survivors) <= 1
class TestTruncation:
    """_truncate() as the last-resort size reduction."""

    def test_truncate_long_messages(self):
        compressor = ContextCompressor(max_tokens=50)
        messages = [
            {"role": "system", "content": "a" * 500},
            {"role": "user", "content": "b" * 500},
        ]
        result = compressor._truncate(messages)
        marker = "...[truncated]"
        # Truncation must actually happen. Without this check, an empty or
        # otherwise shrunken result would make the loop below pass vacuously.
        assert any(str(msg.get("content", "")).endswith(marker) for msg in result)
        # Any message still longer than the 100-char threshold used here
        # (plus the marker) must have been cut and tagged with the marker.
        for msg in result:
            content = msg.get("content", "")
            if len(content) > 100 + len(marker):
                assert content.endswith(marker)

    def test_truncate_preserves_short_messages(self):
        compressor = ContextCompressor(max_tokens=50)
        messages = [
            {"role": "user", "content": "Short message"},
        ]
        result = compressor._truncate(messages)
        assert result[0]["content"] == "Short message"
class TestCompressLinearFlow:
    """U3: linear compress() pipeline and signature changes."""

    @staticmethod
    def _llm_response(content: str) -> LLMResponse:
        # Minimal LLMResponse used as a scripted gateway.chat result
        return LLMResponse(
            content=content,
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )

    def test_compress_signature_no_compression_depth(self):
        """compress() no longer accepts a _compression_depth parameter."""
        params = inspect.signature(ContextCompressor.compress).parameters
        assert "_compression_depth" not in params

    def test_compress_aggressive_signature_no_compression_depth(self):
        """_compress_aggressive() no longer accepts a _compression_depth parameter."""
        params = inspect.signature(ContextCompressor._compress_aggressive).parameters
        assert "_compression_depth" not in params

    async def test_short_messages_not_compressed_linear(self):
        """Short conversations pass through untouched (linear-flow check)."""
        conversation = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        compressed = await ContextCompressor(max_tokens=10000).compress(conversation)
        assert compressed == conversation

    async def test_aggressive_receives_original_messages(self):
        """F-010: _compress_aggressive receives the original messages, not the compressed ones."""
        # The first summary is huge (forcing the aggressive path); the second is short.
        gateway = MagicMock()
        gateway.chat = AsyncMock(
            side_effect=[
                self._llm_response("x" * 5000),
                self._llm_response("short summary"),
            ]
        )
        compressor = ContextCompressor(llm_gateway=gateway, max_tokens=10, keep_recent=2)
        conversation = [
            {"role": "user", "content": "ORIGINAL_MARKER_a" * 2000},
            {"role": "assistant", "content": "ORIGINAL_MARKER_b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        await compressor.compress(conversation)

        # The second (aggressive) LLM call must see the original content,
        # not the first summary ("x" * 5000).
        assert gateway.chat.call_count == 2
        aggressive_prompt = gateway.chat.call_args_list[1].kwargs["messages"][0]["content"]
        assert "ORIGINAL_MARKER" in aggressive_prompt
        assert "xxxx" not in aggressive_prompt

    async def test_truncate_triggered_when_aggressive_insufficient(self):
        """If aggressive compression is still over budget, truncation is forced."""
        # Both summaries are huge, leaving truncation as the last resort.
        huge = self._llm_response("z" * 5000)
        gateway = MagicMock()
        gateway.chat = AsyncMock(side_effect=[huge, huge])
        compressor = ContextCompressor(llm_gateway=gateway, max_tokens=10, keep_recent=2)
        conversation = [
            {"role": "user", "content": "a" * 5000},
            {"role": "assistant", "content": "b" * 5000},
            {"role": "user", "content": "c" * 5000},
            {"role": "assistant", "content": "d" * 5000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        result = await compressor.compress(conversation)

        def total_chars(msgs):
            return sum(len(str(m.get("content", ""))) for m in msgs)

        # Truncation cut the overall content ...
        assert total_chars(result) < total_chars(conversation)
        # ... and review fix #7: the truncation marker proves truncate ran
        assert any("...[truncated]" in str(m.get("content", "")) for m in result)
class TestCompressionLogging:
    """U3: structured logging emitted by _log_compression."""

    @staticmethod
    def _logged(caplog) -> list[str]:
        # Rendered message text of every captured log record
        return [record.message for record in caplog.records]

    async def test_log_compression_outputs_structured_info(self, caplog):
        """_log_compression logs structured info (tokens / ratio / strategy)."""
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("Summary"),
            max_tokens=100,
            keep_recent=2,
        )
        conversation = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        with caplog.at_level(logging.INFO, logger="agentkit.core.compressor"):
            await compressor.compress(conversation)
        lines = self._logged(caplog)
        # The log announces the compression and its strategy, and reports
        # both token and message counts.
        for needle in ("context compressed", "strategy: summary", "tokens", "messages:"):
            assert any(needle in line for line in lines)

    async def test_no_log_when_not_compressed(self, caplog):
        """No compression log is emitted when compression is not triggered."""
        compressor = ContextCompressor(max_tokens=10000)
        with caplog.at_level(logging.INFO, logger="agentkit.core.compressor"):
            await compressor.compress([{"role": "user", "content": "Hello"}])
        assert all("context compressed" not in line for line in self._logged(caplog))

    async def test_log_strategy_aggressive(self, caplog):
        """The aggressive strategy is reported correctly in the log."""

        def reply(text: str) -> LLMResponse:
            return LLMResponse(
                content=text,
                model="test",
                usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
            )

        gateway = MagicMock()
        gateway.chat = AsyncMock(side_effect=[reply("x" * 5000), reply("short")])
        compressor = ContextCompressor(llm_gateway=gateway, max_tokens=10, keep_recent=2)
        conversation = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        with caplog.at_level(logging.INFO, logger="agentkit.core.compressor"):
            await compressor.compress(conversation)
        assert any("strategy: aggressive" in line for line in self._logged(caplog))
class TestNotEnoughMessagesToCompress:
    """Compression is skipped when there are too few messages to compress."""

    async def test_fewer_than_keep_recent_messages(self):
        conversation = [
            {"role": "user", "content": "a" * 200},
            {"role": "assistant", "content": "b" * 200},
        ]
        # Only 2 non-system messages while keep_recent=5, so nothing is compressed
        compressor = ContextCompressor(max_tokens=10, keep_recent=5)
        assert await compressor.compress(conversation) == conversation
# ── PromptTemplate Cache Tests ───────────────────────
class TestPromptTemplateRenderCached:
    """render_cached() memoization tests."""

    def test_same_variables_returns_cached_result(self):
        template = PromptTemplate(
            sections=PromptSection(identity="Bot", context="Hello ${name}")
        )
        first = template.render_cached(variables={"name": "Alice"})
        second = template.render_cached(variables={"name": "Alice"})
        assert first == second
        # A cache hit returns the very same object
        assert first is second

    def test_different_variables_re_renders(self):
        template = PromptTemplate(sections=PromptSection(context="Hello ${name}"))
        for_alice = template.render_cached(variables={"name": "Alice"})
        for_bob = template.render_cached(variables={"name": "Bob"})
        assert for_alice != for_bob
        assert "Alice" in for_alice[0]["content"]
        assert "Bob" in for_bob[0]["content"]

    def test_no_variables_cached(self):
        template = PromptTemplate(sections=PromptSection(identity="Bot"))
        first = template.render_cached()
        second = template.render_cached()
        assert first is second

    def test_render_cached_matches_render(self):
        template = PromptTemplate(
            sections=PromptSection(identity="Bot", context="Hello ${name}")
        )
        via_cache = template.render_cached(variables={"name": "Alice"})
        via_render = template.render(variables={"name": "Alice"})
        assert via_cache == via_render
class TestPromptTemplateClearCache:
    """clear_cache() tests."""

    def test_clear_cache_works(self):
        template = PromptTemplate(sections=PromptSection(context="Hello ${name}"))
        before = template.render_cached(variables={"name": "Alice"})
        template.clear_cache()
        after = template.render_cached(variables={"name": "Alice"})
        # After clearing, the prompt is re-rendered: equal content, new object
        assert before == after
        assert before is not after

    def test_clear_cache_on_fresh_template(self):
        """Calling clear_cache() on a template with an empty cache does not raise."""
        template = PromptTemplate(sections=PromptSection(identity="Bot"))
        template.clear_cache()  # must not raise
class TestReActEngineWithCompressor:
    """Integration tests for ReActEngine with a ContextCompressor."""

    async def test_execute_with_compressor(self):
        # ContextCompressor, LLMResponse and TokenUsage come from the
        # module-level imports; only ReActEngine is imported lazily here.
        from agentkit.core.react import ReActEngine

        gateway = MagicMock()
        gateway.chat = AsyncMock(
            return_value=LLMResponse(
                content="Final answer",
                model="test",
                usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
            )
        )
        compressor = ContextCompressor(max_tokens=10000)
        engine = ReActEngine(llm_gateway=gateway)
        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            compressor=compressor,
        )
        assert result.output == "Final answer"

    async def test_execute_without_compressor_backward_compatible(self):
        from agentkit.core.react import ReActEngine

        gateway = MagicMock()
        gateway.chat = AsyncMock(
            return_value=LLMResponse(
                content="Answer",
                model="test",
                usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
            )
        )
        engine = ReActEngine(llm_gateway=gateway)
        # Omitting the compressor must keep working (backward compatibility)
        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
        )
        assert result.output == "Answer"

    async def test_should_compress_cjk_fallback_path(self):
        """Review fix #5: _should_compress() CJK fallback for compressors
        without a should_compress() method (e.g. HeadroomCompressor).

        Verifies R7: the react.py fallback uses estimate_text_tokens, so long
        CJK conversations correctly trigger compression.
        """
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(llm_gateway=MagicMock())
        # Mock compressor WITHOUT should_compress(), simulating
        # HeadroomCompressor, which doesn't implement it.
        mock_compressor = MagicMock(spec=["is_available", "compress", "compress_tool_result"])
        mock_compressor.is_available.return_value = True
        # Long CJK conversation: 10000 CJK chars = 10000 tokens > 8000 threshold.
        # The content must be real CJK text: empty strings ("" * 5000) carry
        # zero tokens and would never trigger compression.
        cjk_conversation = [
            {"role": "user", "content": "中" * 5000},
            {"role": "assistant", "content": "文" * 5000},
        ]
        assert engine._should_compress(cjk_conversation, mock_compressor) is True
        # A short ASCII conversation must not trigger compression
        ascii_short = [{"role": "user", "content": "Hello"}]
        assert engine._should_compress(ascii_short, mock_compressor) is False