"""Tests for ContextCompressor and PromptTemplate cache."""

import inspect
import logging
from unittest.mock import AsyncMock, MagicMock

from agentkit.core.compressor import ContextCompressor, estimate_text_tokens
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.prompts.section import PromptSection
from agentkit.prompts.template import PromptTemplate


# ── Helpers ──────────────────────────────────────────


def make_mock_gateway(summary_content: str = "Summary of conversation") -> MagicMock:
    """Create a mock ``LLMGateway`` whose ``chat`` returns a fixed summary.

    Args:
        summary_content: Text used as ``LLMResponse.content`` for every
            ``chat`` call.

    Returns:
        A ``MagicMock`` spec'd on ``LLMGateway`` whose ``chat`` attribute is an
        ``AsyncMock`` returning that response.
    """
    # Imported lazily; only the spec is needed here.
    from agentkit.llm.gateway import LLMGateway

    gateway = MagicMock(spec=LLMGateway)
    response = LLMResponse(
        content=summary_content,
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
    )
    gateway.chat = AsyncMock(return_value=response)
    return gateway


def make_long_messages(count: int = 10, content_length: int = 2000) -> list[dict]:
    """Build a system message followed by ``count`` long user/assistant pairs.

    Used to push a conversation over the compressor's token budget.

    Args:
        count: Number of user/assistant pairs to append.
        content_length: Number of filler characters per message body.

    Returns:
        ``1 + 2 * count`` chat messages, system message first.
    """
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for i in range(count):
        messages.append(
            {
                "role": "user",
                "content": "x" * content_length + f" message {i}",
            }
        )
        messages.append(
            {
                "role": "assistant",
                "content": "y" * content_length + f" reply {i}",
            }
        )
    return messages


# ── ContextCompressor Tests ──────────────────────────


class TestEstimateTokens:
    """Basic tests for ContextCompressor.estimate_tokens."""

    def test_empty_messages(self):
        compressor = ContextCompressor()
        assert compressor.estimate_tokens([]) == 0

    def test_single_message(self):
        compressor = ContextCompressor()
        messages = [{"role": "user", "content": "a" * 40}]
        # 40 chars / 4 = 10 tokens
        assert compressor.estimate_tokens(messages) == 10

    def test_multiple_messages(self):
        compressor = ContextCompressor()
        messages = [
            {"role": "user", "content": "a" * 40},
            {"role": "assistant", "content": "b" * 80},
        ]
        # 40/4 + 80/4 = 10 + 20 = 30
        assert compressor.estimate_tokens(messages) == 30

    def test_missing_content_key(self):
        compressor = ContextCompressor()
        messages = [{"role": "user"}]
        assert compressor.estimate_tokens(messages) == 0


class TestEstimateTextTokensCJK:
    """CJK-aware estimation tests for estimate_text_tokens (U1)."""

    def test_pure_cjk_chinese(self):
        # 4 CJK chars = 4 tokens (1:1)
        assert estimate_text_tokens("你好世界") == 4

    def test_pure_ascii(self):
        # 11 chars / 4 = 2.75, floor = 2
        assert estimate_text_tokens("hello world") == 2

    def test_pure_cjk_japanese_kana(self):
        # 5 Hiragana chars = 5 tokens (1:1)
        assert estimate_text_tokens("こんにちは") == 5

    def test_pure_cjk_korean_hangul(self):
        # 5 Hangul chars = 5 tokens (1:1)
        assert estimate_text_tokens("안녕하세요") == 5

    def test_mixed_cjk_and_ascii(self):
        # "你好" (2 CJK = 2 tokens) + " world" (6 ASCII = 1 token) = 3
        assert estimate_text_tokens("你好 world") == 3

    def test_empty_string(self):
        assert estimate_text_tokens("") == 0

    def test_estimate_tokens_with_cjk_messages(self):
        """estimate_tokens() no longer underestimates CJK messages by 4x."""
        compressor = ContextCompressor()
        messages = [{"role": "user", "content": "你好世界"}]
        # 4 CJK = 4 tokens
        assert compressor.estimate_tokens(messages) == 4

    def test_estimate_tokens_mixed_messages(self):
        """estimate_tokens() gives a sensible total for mixed CJK/ASCII messages."""
        compressor = ContextCompressor()
        messages = [
            {"role": "user", "content": "你好"},  # 2 CJK = 2
            {"role": "assistant", "content": "hello"},  # 5 ASCII = 1
        ]
        assert compressor.estimate_tokens(messages) == 3

    def test_cjk_not_underestimated_4x(self):
        """AE1: for 100 CJK chars, estimate_tokens is >= 4x the old len // 4 estimate."""
        compressor = ContextCompressor()
        cjk_msg = [{"role": "user", "content": "你好" * 50}]  # 100 CJK chars
        new_estimate = compressor.estimate_tokens(cjk_msg)
        old_estimate = len("你好" * 50) // 4  # old: len // 4
        assert new_estimate >= old_estimate * 4


class TestNoCompressionWhenUnderBudget:
    """No compression happens while the conversation fits the token budget."""

    async def test_short_messages_not_compressed(self):
        compressor = ContextCompressor(max_tokens=10000)
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        result = await compressor.compress(messages)
        assert result == messages

    async def test_exactly_at_budget_not_compressed(self):
        # 40 chars = 10 tokens, budget = 10
        compressor = ContextCompressor(max_tokens=10)
        messages = [{"role": "user", "content": "a" * 40}]
        result = await compressor.compress(messages)
        assert result == messages


class TestCompressionTriggersWhenOverBudget:
    """Compression kicks in once the token budget is exceeded."""

    async def test_long_messages_get_compressed(self):
        gateway = make_mock_gateway("Compressed summary")
        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
        )
        messages = make_long_messages(count=5, content_length=500)
        result = await compressor.compress(messages)

        # The result holds fewer messages than the input.
        assert len(result) < len(messages)
        # At least one system message survives.
        system_msgs = [m for m in result if m.get("role") == "system"]
        assert len(system_msgs) >= 1
        # The conversation still ends with a recent (non-system) message.
        assert result[-1]["role"] != "system"

    async def test_compression_preserves_system_messages(self):
        gateway = make_mock_gateway("Summary")
        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
        )
        messages = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "c" * 2000},
            {"role": "assistant", "content": "d" * 2000},
            {"role": "user", "content": "Recent question"},
            {"role": "assistant", "content": "Recent answer"},
        ]
        result = await compressor.compress(messages)

        # The first message is still the original system prompt.
        assert result[0]["content"] == "System prompt"
        assert result[0]["role"] == "system"

    async def test_compression_keeps_recent_messages(self):
        gateway = make_mock_gateway("Summary")
        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
        )
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent question"},
            {"role": "assistant", "content": "Recent answer"},
        ]
        result = await compressor.compress(messages)

        # The last two non-system messages are the original recent ones.
        non_system = [m for m in result if m.get("role") != "system"]
        assert non_system[-2]["content"] == "Recent question"
        assert non_system[-1]["content"] == "Recent answer"


class TestSummaryGenerationWithLLM:
    """Summary generation through the LLM gateway."""

    async def test_llm_summarization_called(self):
        gateway = make_mock_gateway("LLM generated summary")
        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
        )
        messages = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        result = await compressor.compress(messages)

        # The LLM is consulted exactly once.
        gateway.chat.assert_called_once()
        # Its summary is injected as a system message.
        summary_msgs = [
            m
            for m in result
            if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
        ]
        assert len(summary_msgs) == 1
        assert "LLM generated summary" in summary_msgs[0]["content"]


class TestFallbackToSimpleSummary:
    """Fall back to a simple (truncating) summary when the LLM is unavailable."""

    async def test_no_llm_uses_simple_summary(self):
        compressor = ContextCompressor(
            llm_gateway=None,
            max_tokens=100,
            keep_recent=2,
        )
        messages = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        result = await compressor.compress(messages)

        # A summary message exists (simple truncation mode).
        summary_msgs = [
            m
            for m in result
            if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
        ]
        assert len(summary_msgs) == 1
        # The simple summary carries an ellipsis truncation marker.
        assert "..." in summary_msgs[0]["content"]

    async def test_llm_failure_uses_simple_summary(self):
        gateway = make_mock_gateway()
        gateway.chat = AsyncMock(side_effect=Exception("LLM error"))
        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
        )
        messages = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        result = await compressor.compress(messages)

        # The LLM was actually tried; otherwise this would only duplicate the
        # no-LLM test above.
        gateway.chat.assert_called()
        # A summary message exists (fallback to the simple summary).
        summary_msgs = [
            m
            for m in result
            if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
        ]
        assert len(summary_msgs) == 1
        # Same marker as the no-LLM path, proving the fallback was used.
        assert "..." in summary_msgs[0]["content"]


class TestAggressiveCompression:
    """Aggressive compression when standard compression is still over budget."""

    async def test_aggressive_compression_when_still_over_budget(self):
        # Tiny budget: still exceeded after standard compression, because
        # the summary itself is long.
        gateway = make_mock_gateway("x" * 5000)
        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=10,
            keep_recent=2,
        )
        messages = [
            {"role": "user", "content": "a" * 5000},
            {"role": "assistant", "content": "b" * 5000},
            {"role": "user", "content": "c" * 5000},
            {"role": "assistant", "content": "d" * 5000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        result = await compressor.compress(messages)

        # Aggressive compression keeps at most the last non-system message.
        non_system = [m for m in result if m.get("role") != "system"]
        assert len(non_system) <= 1


class TestTruncation:
    """Truncation as the last resort."""

    def test_truncate_long_messages(self):
        compressor = ContextCompressor(max_tokens=50)
        messages = [
            {"role": "system", "content": "a" * 500},
            {"role": "user", "content": "b" * 500},
        ]
        result = compressor._truncate(messages)

        # Positive check: the guarded loop below is vacuous when truncation
        # shortens everything, so require that truncation really happened.
        assert any(
            msg.get("content", "").endswith("...[truncated]") for msg in result
        )
        # Any message that is still long must carry the truncation marker.
        for msg in result:
            content = msg.get("content", "")
            if len(content) > 100 + len("...[truncated]"):
                # Only over-long messages are truncated.
                assert content.endswith("...[truncated]")

    def test_truncate_preserves_short_messages(self):
        compressor = ContextCompressor(max_tokens=50)
        messages = [
            {"role": "user", "content": "Short message"},
        ]
        result = compressor._truncate(messages)
        assert result[0]["content"] == "Short message"


class TestCompressLinearFlow:
    """U3: compress() linear flow and signature-change tests."""

    def test_compress_signature_no_compression_depth(self):
        """compress() no longer accepts a _compression_depth parameter."""
        sig = inspect.signature(ContextCompressor.compress)
        assert "_compression_depth" not in sig.parameters

    def test_compress_aggressive_signature_no_compression_depth(self):
        """_compress_aggressive() no longer accepts a _compression_depth parameter."""
        sig = inspect.signature(ContextCompressor._compress_aggressive)
        assert "_compression_depth" not in sig.parameters

    async def test_short_messages_not_compressed_linear(self):
        """Short messages are left untouched (linear-flow check)."""
        compressor = ContextCompressor(max_tokens=10000)
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        result = await compressor.compress(messages)
        assert result == messages

    async def test_aggressive_receives_original_messages(self):
        """F-010: _compress_aggressive receives the original messages, not the compressed ones."""
        # First summary is very long (triggers aggressive), second is short.
        long_summary = LLMResponse(
            content="x" * 5000,
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )
        short_summary = LLMResponse(
            content="short summary",
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )
        gateway = MagicMock()
        gateway.chat = AsyncMock(side_effect=[long_summary, short_summary])

        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=10,
            keep_recent=2,
        )
        messages = [
            {"role": "user", "content": "ORIGINAL_MARKER_a" * 2000},
            {"role": "assistant", "content": "ORIGINAL_MARKER_b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        await compressor.compress(messages)

        # The second (aggressive) call must see the original message content,
        # not the first summary ("x" * 5000).
        assert gateway.chat.call_count == 2
        second_call_content = gateway.chat.call_args_list[1].kwargs["messages"][0]["content"]
        assert "ORIGINAL_MARKER" in second_call_content
        # The first summary must NOT leak into the aggressive call.
        assert "xxxx" not in second_call_content

    async def test_truncate_triggered_when_aggressive_insufficient(self):
        """Still over budget after aggressive compression -> truncate cuts content."""
        # Both summaries are very long, forcing truncate as the last resort.
        long_summary = LLMResponse(
            content="z" * 5000,
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )
        gateway = MagicMock()
        gateway.chat = AsyncMock(side_effect=[long_summary, long_summary])

        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=10,
            keep_recent=2,
        )
        messages = [
            {"role": "user", "content": "a" * 5000},
            {"role": "assistant", "content": "b" * 5000},
            {"role": "user", "content": "c" * 5000},
            {"role": "assistant", "content": "d" * 5000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        result = await compressor.compress(messages)

        # Truncation must have cut the total amount of content.
        total_chars = sum(len(str(m.get("content", ""))) for m in result)
        assert total_chars < sum(len(str(m.get("content", ""))) for m in messages)


class TestCompressionLogging:
    """U3: structured logging emitted by _log_compression."""

    async def test_log_compression_outputs_structured_info(self, caplog):
        """_log_compression emits a structured line (tokens / ratio / strategy)."""
        gateway = make_mock_gateway("Summary")
        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
        )
        messages = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        with caplog.at_level(logging.INFO, logger="agentkit.core.compressor"):
            await compressor.compress(messages)

        # The structured log reports the compression and its strategy.
        log_messages = [record.message for record in caplog.records]
        assert any("context compressed" in msg for msg in log_messages)
        assert any("strategy: summary" in msg for msg in log_messages)
        # It also reports token counts and message counts.
        assert any("tokens" in msg for msg in log_messages)
        assert any("messages:" in msg for msg in log_messages)

    async def test_no_log_when_not_compressed(self, caplog):
        """Nothing is logged when no compression happens."""
        compressor = ContextCompressor(max_tokens=10000)
        messages = [
            {"role": "user", "content": "Hello"},
        ]
        with caplog.at_level(logging.INFO, logger="agentkit.core.compressor"):
            await compressor.compress(messages)

        log_messages = [record.message for record in caplog.records]
        assert not any("context compressed" in msg for msg in log_messages)

    async def test_log_strategy_aggressive(self, caplog):
        """The aggressive strategy is reported correctly in the log."""
        long_summary = LLMResponse(
            content="x" * 5000,
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )
        short_summary = LLMResponse(
            content="short",
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )
        gateway = MagicMock()
        gateway.chat = AsyncMock(side_effect=[long_summary, short_summary])

        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=10,
            keep_recent=2,
        )
        messages = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]
        with caplog.at_level(logging.INFO, logger="agentkit.core.compressor"):
            await compressor.compress(messages)

        log_messages = [record.message for record in caplog.records]
        assert any("strategy: aggressive" in msg for msg in log_messages)


class TestNotEnoughMessagesToCompress:
    """Compression is skipped when there are too few messages."""

    async def test_fewer_than_keep_recent_messages(self):
        compressor = ContextCompressor(
            max_tokens=10,
            keep_recent=5,
        )
        messages = [
            {"role": "user", "content": "a" * 200},
            {"role": "assistant", "content": "b" * 200},
        ]
        # Only 2 non-system messages with keep_recent=5, so nothing is compressed.
        result = await compressor.compress(messages)
        assert result == messages


# ── PromptTemplate Cache Tests ───────────────────────


class TestPromptTemplateRenderCached:
    """render_cached() caching tests."""

    def test_same_variables_returns_cached_result(self):
        section = PromptSection(
            identity="Bot",
            context="Hello ${name}",
        )
        tpl = PromptTemplate(sections=section)
        result1 = tpl.render_cached(variables={"name": "Alice"})
        result2 = tpl.render_cached(variables={"name": "Alice"})
        assert result1 == result2
        # Cache hit: the very same object comes back.
        assert result1 is result2

    def test_different_variables_re_renders(self):
        section = PromptSection(
            context="Hello ${name}",
        )
        tpl = PromptTemplate(sections=section)
        result1 = tpl.render_cached(variables={"name": "Alice"})
        result2 = tpl.render_cached(variables={"name": "Bob"})
        assert result1 != result2
        assert "Alice" in result1[0]["content"]
        assert "Bob" in result2[0]["content"]

    def test_no_variables_cached(self):
        section = PromptSection(identity="Bot")
        tpl = PromptTemplate(sections=section)
        result1 = tpl.render_cached()
        result2 = tpl.render_cached()
        assert result1 is result2

    def test_render_cached_matches_render(self):
        section = PromptSection(
            identity="Bot",
            context="Hello ${name}",
        )
        tpl = PromptTemplate(sections=section)
        cached = tpl.render_cached(variables={"name": "Alice"})
        direct = tpl.render(variables={"name": "Alice"})
        assert cached == direct


class TestPromptTemplateClearCache:
    """clear_cache() tests."""

    def test_clear_cache_works(self):
        section = PromptSection(
            context="Hello ${name}",
        )
        tpl = PromptTemplate(sections=section)
        result1 = tpl.render_cached(variables={"name": "Alice"})
        tpl.clear_cache()
        result2 = tpl.render_cached(variables={"name": "Alice"})
        # After clearing, the template re-renders: equal value, new object.
        assert result1 == result2
        assert result1 is not result2

    def test_clear_cache_on_fresh_template(self):
        """Calling clear_cache on a template with nothing cached does not raise."""
        section = PromptSection(identity="Bot")
        tpl = PromptTemplate(sections=section)
        tpl.clear_cache()  # must not raise


class TestReActEngineWithCompressor:
    """ReActEngine integration with ContextCompressor."""

    async def test_execute_with_compressor(self):
        # ContextCompressor / LLMResponse / TokenUsage come from module scope.
        from agentkit.core.react import ReActEngine

        gateway = MagicMock()
        gateway.chat = AsyncMock(
            return_value=LLMResponse(
                content="Final answer",
                model="test",
                usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
            )
        )
        compressor = ContextCompressor(max_tokens=10000)
        engine = ReActEngine(llm_gateway=gateway)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            compressor=compressor,
        )
        assert result.output == "Final answer"

    async def test_execute_without_compressor_backward_compatible(self):
        from agentkit.core.react import ReActEngine

        gateway = MagicMock()
        gateway.chat = AsyncMock(
            return_value=LLMResponse(
                content="Answer",
                model="test",
                usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
            )
        )
        engine = ReActEngine(llm_gateway=gateway)

        # Omitting compressor must keep working.
        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
        )
        assert result.output == "Answer"