fischer-agentkit/tests/unit/test_context_compressor.py

435 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Tests for ContextCompressor and PromptTemplate cache"""
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.compressor import ContextCompressor
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.prompts.section import PromptSection
from agentkit.prompts.template import PromptTemplate
# ── Helpers ──────────────────────────────────────────
def make_mock_gateway(summary_content: str = "Summary of conversation") -> MagicMock:
    """Build a mock LLMGateway whose ``chat`` coroutine returns a summary response.

    Args:
        summary_content: Text used as the ``content`` of the canned response.

    Returns:
        A ``MagicMock`` specced against ``LLMGateway`` whose ``chat`` attribute
        is an ``AsyncMock`` returning the canned ``LLMResponse``.
    """
    from agentkit.llm.gateway import LLMGateway

    canned_usage = TokenUsage(prompt_tokens=10, completion_tokens=10)
    canned_reply = LLMResponse(
        content=summary_content,
        model="test-model",
        usage=canned_usage,
    )

    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = AsyncMock(return_value=canned_reply)
    return mock_gateway
def make_long_messages(count: int = 10, content_length: int = 2000) -> list[dict]:
    """Generate a long message list for exercising compression.

    The list starts with one system message, followed by ``count``
    user/assistant pairs. Each body is ``content_length`` filler characters
    plus a short suffix carrying the turn index.

    Args:
        count: Number of user/assistant pairs to generate.
        content_length: Number of filler characters in each message body.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    user_filler = "x" * content_length
    assistant_filler = "y" * content_length

    conversation = [{"role": "system", "content": "You are a helpful assistant."}]
    for turn in range(count):
        conversation.extend((
            {"role": "user", "content": f"{user_filler} message {turn}"},
            {"role": "assistant", "content": f"{assistant_filler} reply {turn}"},
        ))
    return conversation
# ── ContextCompressor Tests ──────────────────────────
class TestEstimateTokens:
    """Basic checks for ``ContextCompressor.estimate_tokens``."""

    def test_empty_messages(self):
        estimator = ContextCompressor()
        assert estimator.estimate_tokens([]) == 0

    def test_single_message(self):
        estimator = ContextCompressor()
        single = [{"role": "user", "content": "a" * 40}]
        # 40 chars / 4 = 10 tokens
        assert estimator.estimate_tokens(single) == 10

    def test_multiple_messages(self):
        estimator = ContextCompressor()
        user_turn = {"role": "user", "content": "a" * 40}
        assistant_turn = {"role": "assistant", "content": "b" * 80}
        # 40/4 + 80/4 = 10 + 20 = 30
        assert estimator.estimate_tokens([user_turn, assistant_turn]) == 30

    def test_missing_content_key(self):
        estimator = ContextCompressor()
        # A message without a "content" key is expected to count as zero tokens.
        without_content = [{"role": "user"}]
        assert estimator.estimate_tokens(without_content) == 0
class TestNoCompressionWhenUnderBudget:
    """Messages within the token budget are returned uncompressed.

    NOTE(review): these coroutine tests carry no explicit async marker, so
    they presumably rely on the project's pytest async plugin configuration
    -- confirm.
    """

    async def test_short_messages_not_compressed(self):
        conversation = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        compressor = ContextCompressor(max_tokens=10000)

        outcome = await compressor.compress(conversation)

        assert outcome == conversation

    async def test_exactly_at_budget_not_compressed(self):
        # 40 chars = 10 tokens, budget = 10
        conversation = [{"role": "user", "content": "a" * 40}]
        compressor = ContextCompressor(max_tokens=10)

        outcome = await compressor.compress(conversation)

        assert outcome == conversation
class TestCompressionTriggersWhenOverBudget:
    """Compression is triggered once the token budget is exceeded."""

    async def test_long_messages_get_compressed(self):
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("Compressed summary"),
            max_tokens=100,
            keep_recent=2,
        )
        original = make_long_messages(count=5, content_length=500)

        compressed = await compressor.compress(original)

        # The result should contain fewer messages than the input.
        assert len(compressed) < len(original)
        # At least one system message should be present.
        system_entries = [m for m in compressed if m.get("role") == "system"]
        assert len(system_entries) >= 1
        # The recent messages should be kept at the tail.
        assert compressed[-1]["role"] != "system"

    async def test_compression_preserves_system_messages(self):
        original = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "c" * 2000},
            {"role": "assistant", "content": "d" * 2000},
            {"role": "user", "content": "Recent question"},
            {"role": "assistant", "content": "Recent answer"},
        ]
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("Summary"),
            max_tokens=100,
            keep_recent=2,
        )

        compressed = await compressor.compress(original)

        # The first message should be the original system message.
        head = compressed[0]
        assert head["content"] == "System prompt"
        assert head["role"] == "system"

    async def test_compression_keeps_recent_messages(self):
        original = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent question"},
            {"role": "assistant", "content": "Recent answer"},
        ]
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("Summary"),
            max_tokens=100,
            keep_recent=2,
        )

        compressed = await compressor.compress(original)

        # The last two non-system messages should be the original recent ones.
        dialogue = [m for m in compressed if m.get("role") != "system"]
        assert dialogue[-2]["content"] == "Recent question"
        assert dialogue[-1]["content"] == "Recent answer"
class TestSummaryGenerationWithLLM:
    """Summary generation through the LLM gateway."""

    async def test_llm_summarization_called(self):
        mock_gateway = make_mock_gateway("LLM generated summary")
        compressor = ContextCompressor(
            llm_gateway=mock_gateway,
            max_tokens=100,
            keep_recent=2,
        )
        conversation = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

        compressed = await compressor.compress(conversation)

        # The LLM should have been called exactly once.
        mock_gateway.chat.assert_called_once()

        # The summary should show up in the result as a system message.
        summaries = [
            entry
            for entry in compressed
            if entry.get("role") == "system"
            and "Conversation Summary" in entry.get("content", "")
        ]
        assert len(summaries) == 1
        assert "LLM generated summary" in summaries[0]["content"]
class TestFallbackToSimpleSummary:
    """Fallback to the simple summary when the LLM is unavailable."""

    async def test_no_llm_uses_simple_summary(self):
        conversation = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        compressor = ContextCompressor(
            llm_gateway=None,
            max_tokens=100,
            keep_recent=2,
        )

        compressed = await compressor.compress(conversation)

        # A summary message should exist (simple truncation mode).
        summaries = [
            entry
            for entry in compressed
            if entry.get("role") == "system"
            and "Conversation Summary" in entry.get("content", "")
        ]
        assert len(summaries) == 1
        # The simple summary should contain the truncation marker.
        assert "..." in summaries[0]["content"]

    async def test_llm_failure_uses_simple_summary(self):
        failing_gateway = make_mock_gateway()
        # Make every chat call raise so the compressor has to fall back.
        failing_gateway.chat = AsyncMock(side_effect=Exception("LLM error"))
        compressor = ContextCompressor(
            llm_gateway=failing_gateway,
            max_tokens=100,
            keep_recent=2,
        )
        conversation = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

        compressed = await compressor.compress(conversation)

        # A summary message should exist (fallback to the simple summary).
        summaries = [
            entry
            for entry in compressed
            if entry.get("role") == "system"
            and "Conversation Summary" in entry.get("content", "")
        ]
        assert len(summaries) == 1
class TestAggressiveCompression:
    """Aggressive compression when standard compression is still over budget."""

    async def test_aggressive_compression_when_still_over_budget(self):
        # The budget is tiny, so the result exceeds it even after compression;
        # the mocked summary is itself very long.
        oversized_summary = "x" * 5000
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway(oversized_summary),
            max_tokens=10,
            keep_recent=2,
        )
        conversation = [
            {"role": "user", "content": "a" * 5000},
            {"role": "assistant", "content": "b" * 5000},
            {"role": "user", "content": "c" * 5000},
            {"role": "assistant", "content": "d" * 5000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

        compressed = await compressor.compress(conversation)

        # Aggressive compression should keep only the last non-system message,
        # so at most one non-system message may remain.
        dialogue = [m for m in compressed if m.get("role") != "system"]
        assert len(dialogue) <= 1
class TestTruncation:
    """Truncation as the last resort."""

    def test_truncate_long_messages(self):
        conversation = [
            {"role": "system", "content": "a" * 500},
            {"role": "user", "content": "b" * 500},
        ]
        compressor = ContextCompressor(max_tokens=50)

        truncated = compressor._truncate(conversation)

        # Long messages should have been truncated.
        # NOTE(review): if _truncate cuts to exactly 100 characters plus the
        # marker, the guard below is never true and nothing is asserted --
        # confirm the intended limit and consider an unconditional check.
        for entry in truncated:
            text = entry.get("content", "")
            if len(text) <= 100 + len("...[truncated]"):
                continue
            # Only over-long messages are expected to carry the marker.
            assert text.endswith("...[truncated]")

    def test_truncate_preserves_short_messages(self):
        conversation = [
            {"role": "user", "content": "Short message"},
        ]
        compressor = ContextCompressor(max_tokens=50)

        truncated = compressor._truncate(conversation)

        assert truncated[0]["content"] == "Short message"
class TestNotEnoughMessagesToCompress:
    """Compression is skipped when there are too few messages."""

    async def test_fewer_than_keep_recent_messages(self):
        conversation = [
            {"role": "user", "content": "a" * 200},
            {"role": "assistant", "content": "b" * 200},
        ]
        compressor = ContextCompressor(
            max_tokens=10,
            keep_recent=5,
        )

        # Only 2 non-system messages with keep_recent=5: nothing is compressed.
        outcome = await compressor.compress(conversation)

        assert outcome == conversation
# ── PromptTemplate Cache Tests ───────────────────────
class TestPromptTemplateRenderCached:
    """Caching behaviour of ``PromptTemplate.render_cached()``."""

    def test_same_variables_returns_cached_result(self):
        template = PromptTemplate(
            sections=PromptSection(
                identity="Bot",
                context="Hello ${name}",
            )
        )

        first = template.render_cached(variables={"name": "Alice"})
        second = template.render_cached(variables={"name": "Alice"})

        assert first == second
        # A cache hit should hand back the very same object.
        assert first is second

    def test_different_variables_re_renders(self):
        template = PromptTemplate(
            sections=PromptSection(
                context="Hello ${name}",
            )
        )

        for_alice = template.render_cached(variables={"name": "Alice"})
        for_bob = template.render_cached(variables={"name": "Bob"})

        assert for_alice != for_bob
        assert "Alice" in for_alice[0]["content"]
        assert "Bob" in for_bob[0]["content"]

    def test_no_variables_cached(self):
        template = PromptTemplate(sections=PromptSection(identity="Bot"))

        first = template.render_cached()
        second = template.render_cached()

        assert first is second

    def test_render_cached_matches_render(self):
        template = PromptTemplate(
            sections=PromptSection(
                identity="Bot",
                context="Hello ${name}",
            )
        )

        via_cache = template.render_cached(variables={"name": "Alice"})
        via_render = template.render(variables={"name": "Alice"})

        assert via_cache == via_render
class TestPromptTemplateClearCache:
    """Behaviour of ``PromptTemplate.clear_cache()``."""

    def test_clear_cache_works(self):
        template = PromptTemplate(
            sections=PromptSection(
                context="Hello ${name}",
            )
        )

        before_clear = template.render_cached(variables={"name": "Alice"})
        template.clear_cache()
        after_clear = template.render_cached(variables={"name": "Alice"})

        # After clearing, the template should re-render: the content is equal
        # but the result is no longer the same object.
        assert before_clear == after_clear
        assert before_clear is not after_clear

    def test_clear_cache_on_fresh_template(self):
        """Calling clear_cache on a new template with no cache must not fail."""
        template = PromptTemplate(sections=PromptSection(identity="Bot"))
        template.clear_cache()  # should not raise
class TestReActEngineWithCompressor:
    """Integration of ``ReActEngine`` with ``ContextCompressor``."""

    async def test_execute_with_compressor(self):
        from agentkit.core.react import ReActEngine

        canned_reply = LLMResponse(
            content="Final answer",
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )
        mock_gateway = MagicMock()
        mock_gateway.chat = AsyncMock(return_value=canned_reply)

        budgeted_compressor = ContextCompressor(max_tokens=10000)
        engine = ReActEngine(llm_gateway=mock_gateway)

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            compressor=budgeted_compressor,
        )

        assert outcome.output == "Final answer"

    async def test_execute_without_compressor_backward_compatible(self):
        from agentkit.core.react import ReActEngine

        canned_reply = LLMResponse(
            content="Answer",
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )
        mock_gateway = MagicMock()
        mock_gateway.chat = AsyncMock(return_value=canned_reply)
        engine = ReActEngine(llm_gateway=mock_gateway)

        # Omitting the compressor argument should still work.
        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
        )

        assert outcome.output == "Answer"