657 lines
24 KiB
Python
657 lines
24 KiB
Python
"""Tests for ContextCompressor and PromptTemplate cache"""
|
||
|
||
import inspect
|
||
import logging
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
from agentkit.core.compressor import ContextCompressor, estimate_text_tokens
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||
from agentkit.prompts.section import PromptSection
|
||
from agentkit.prompts.template import PromptTemplate
|
||
|
||
|
||
# ── Helpers ──────────────────────────────────────────
|
||
|
||
|
||
def make_mock_gateway(summary_content: str = "Summary of conversation") -> MagicMock:
    """Build a mock ``LLMGateway`` whose ``chat`` coroutine returns a summary.

    Args:
        summary_content: Text placed in the ``content`` field of the canned
            ``LLMResponse``.

    Returns:
        A ``MagicMock`` created with ``spec=LLMGateway``; its ``chat`` attribute
        is an ``AsyncMock`` that always resolves to the same ``LLMResponse``.
    """
    # Function-scope import: the gateway module is only loaded when this
    # helper is actually called.
    from agentkit.llm.gateway import LLMGateway

    canned_reply = LLMResponse(
        content=summary_content,
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
    )
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = AsyncMock(return_value=canned_reply)
    return mock_gateway
|
||
|
||
|
||
def make_long_messages(count: int = 10, content_length: int = 2000) -> list[dict]:
    """Build a long conversation used to exercise compression.

    The result starts with a single system message, followed by ``count``
    user/assistant pairs. Each user message is ``content_length`` ``"x"``
    characters plus ``" message {i}"``. Each assistant reply is
    ``content_length`` ``"y"`` characters plus ``" reply {i}"``.

    Args:
        count: Number of user/assistant pairs to generate.
        content_length: Number of filler characters at the start of each turn.

    Returns:
        A list of ``1 + 2 * count`` message dicts with ``role``/``content`` keys
        (just the system message when ``count`` is not positive).
    """
    conversation = [{"role": "system", "content": "You are a helpful assistant."}]
    for idx in range(count):
        conversation += [
            {"role": "user", "content": "x" * content_length + f" message {idx}"},
            {"role": "assistant", "content": "y" * content_length + f" reply {idx}"},
        ]
    return conversation
|
||
|
||
|
||
# ── ContextCompressor Tests ──────────────────────────
|
||
|
||
|
||
class TestEstimateTokens:
    """Basic tests for ``ContextCompressor.estimate_tokens``."""

    def test_empty_messages(self):
        assert ContextCompressor().estimate_tokens([]) == 0

    def test_single_message(self):
        estimator = ContextCompressor()
        # 40 ASCII chars / 4 = 10 tokens
        assert estimator.estimate_tokens([{"role": "user", "content": "a" * 40}]) == 10

    def test_multiple_messages(self):
        estimator = ContextCompressor()
        conversation = [
            {"role": "user", "content": "a" * 40},
            {"role": "assistant", "content": "b" * 80},
        ]
        # 40/4 + 80/4 = 10 + 20 = 30
        assert estimator.estimate_tokens(conversation) == 30

    def test_missing_content_key(self):
        # A message without a "content" key contributes zero tokens.
        estimator = ContextCompressor()
        assert estimator.estimate_tokens([{"role": "user"}]) == 0
|
||
|
||
|
||
class TestEstimateTextTokensCJK:
    """CJK-aware token estimation tests for ``estimate_text_tokens`` (U1)."""

    def test_pure_cjk_chinese(self):
        # 4 CJK chars -> 4 tokens (1:1)
        assert estimate_text_tokens("你好世界") == 4

    def test_pure_ascii(self):
        # 11 ASCII chars / 4 = 2.75, floored to 2
        assert estimate_text_tokens("hello world") == 2

    def test_pure_cjk_japanese_kana(self):
        # 5 Hiragana chars -> 5 tokens (1:1)
        assert estimate_text_tokens("こんにちは") == 5

    def test_pure_cjk_korean_hangul(self):
        # 5 Hangul chars -> 5 tokens (1:1)
        assert estimate_text_tokens("안녕하세요") == 5

    def test_mixed_cjk_and_ascii(self):
        # "你好" (2 CJK -> 2 tokens) + " world" (6 ASCII -> 1 token) = 3
        assert estimate_text_tokens("你好 world") == 3

    def test_empty_string(self):
        assert estimate_text_tokens("") == 0

    def test_estimate_tokens_with_cjk_messages(self):
        """estimate_tokens() no longer underestimates CJK messages by 4x."""
        cjk_conversation = [{"role": "user", "content": "你好世界"}]  # 4 CJK -> 4 tokens
        assert ContextCompressor().estimate_tokens(cjk_conversation) == 4

    def test_estimate_tokens_mixed_messages(self):
        """estimate_tokens() gives a sensible total for mixed-script messages."""
        conversation = [
            {"role": "user", "content": "你好"},  # 2 CJK -> 2
            {"role": "assistant", "content": "hello"},  # 5 ASCII -> 1
        ]
        assert ContextCompressor().estimate_tokens(conversation) == 3

    def test_cjk_not_underestimated_4x(self):
        """AE1: the estimate for 100 CJK chars is >= 4x the legacy estimate."""
        text = "你好" * 50  # 100 CJK chars
        current = ContextCompressor().estimate_tokens([{"role": "user", "content": text}])
        legacy = len(text) // 4  # legacy heuristic: len // 4
        assert current >= legacy * 4
|
||
|
||
|
||
class TestNoCompressionWhenUnderBudget:
    """Messages within the token budget are returned without compression."""

    async def test_short_messages_not_compressed(self):
        history = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        compressor = ContextCompressor(max_tokens=10000)
        assert await compressor.compress(history) == history

    async def test_exactly_at_budget_not_compressed(self):
        # 40 chars = 10 tokens, exactly the budget of 10: must not compress.
        history = [{"role": "user", "content": "a" * 40}]
        compressor = ContextCompressor(max_tokens=10)
        assert await compressor.compress(history) == history
|
||
|
||
|
||
class TestCompressionTriggersWhenOverBudget:
    """Exceeding the token budget triggers compression."""

    async def test_long_messages_get_compressed(self):
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("Compressed summary"),
            max_tokens=100,
            keep_recent=2,
        )
        history = make_long_messages(count=5, content_length=500)

        compressed = await compressor.compress(history)

        # The compressed list is shorter than the original ...
        assert len(compressed) < len(history)
        # ... still contains at least one system message ...
        assert any(m.get("role") == "system" for m in compressed)
        # ... and ends on a non-system (conversation) message.
        assert compressed[-1]["role"] != "system"

    async def test_compression_preserves_system_messages(self):
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("Summary"),
            max_tokens=100,
            keep_recent=2,
        )
        history = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "c" * 2000},
            {"role": "assistant", "content": "d" * 2000},
            {"role": "user", "content": "Recent question"},
            {"role": "assistant", "content": "Recent answer"},
        ]

        compressed = await compressor.compress(history)

        # The original system prompt must remain the first message.
        head = compressed[0]
        assert head["content"] == "System prompt"
        assert head["role"] == "system"

    async def test_compression_keeps_recent_messages(self):
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("Summary"),
            max_tokens=100,
            keep_recent=2,
        )
        history = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent question"},
            {"role": "assistant", "content": "Recent answer"},
        ]

        compressed = await compressor.compress(history)

        # The last two non-system messages are the original recent turns.
        conversational = [m for m in compressed if m.get("role") != "system"]
        assert conversational[-2]["content"] == "Recent question"
        assert conversational[-1]["content"] == "Recent answer"
|
||
|
||
|
||
class TestSummaryGenerationWithLLM:
    """When a gateway is configured, the summary is generated by the LLM."""

    async def test_llm_summarization_called(self):
        gateway = make_mock_gateway("LLM generated summary")
        compressor = ContextCompressor(llm_gateway=gateway, max_tokens=100, keep_recent=2)
        history = [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

        compressed = await compressor.compress(history)

        # The LLM is asked for a summary exactly once ...
        gateway.chat.assert_called_once()
        # ... and that summary shows up in one "Conversation Summary" system message.
        summaries = [
            msg
            for msg in compressed
            if msg.get("role") == "system" and "Conversation Summary" in msg.get("content", "")
        ]
        assert len(summaries) == 1
        assert "LLM generated summary" in summaries[0]["content"]
|
||
|
||
|
||
class TestFallbackToSimpleSummary:
    """Fall back to a simple (non-LLM) summary when the LLM is unavailable."""

    @staticmethod
    def _history():
        """Return a fresh over-budget conversation shared by both tests."""
        return [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

    @staticmethod
    def _summary_messages(result):
        """Return the system messages in ``result`` that contain "Conversation Summary"."""
        return [
            msg
            for msg in result
            if msg.get("role") == "system" and "Conversation Summary" in msg.get("content", "")
        ]

    async def test_no_llm_uses_simple_summary(self):
        compressor = ContextCompressor(llm_gateway=None, max_tokens=100, keep_recent=2)

        summaries = self._summary_messages(await compressor.compress(self._history()))

        # Exactly one summary message is produced (simple truncation mode) ...
        assert len(summaries) == 1
        # ... and it carries the "..." truncation marker.
        assert "..." in summaries[0]["content"]

    async def test_llm_failure_uses_simple_summary(self):
        gateway = make_mock_gateway()
        gateway.chat = AsyncMock(side_effect=Exception("LLM error"))
        compressor = ContextCompressor(llm_gateway=gateway, max_tokens=100, keep_recent=2)

        summaries = self._summary_messages(await compressor.compress(self._history()))

        # A failing LLM call still yields one (fallback) summary message.
        assert len(summaries) == 1
|
||
|
||
|
||
class TestAggressiveCompression:
    """Aggressive compression when standard compression is still over budget."""

    async def test_aggressive_compression_when_still_over_budget(self):
        # A tiny budget plus a very long mock summary means the context stays
        # over budget even after the standard summary pass.
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("x" * 5000),
            max_tokens=10,
            keep_recent=2,
        )
        history = [
            {"role": "user", "content": "a" * 5000},
            {"role": "assistant", "content": "b" * 5000},
            {"role": "user", "content": "c" * 5000},
            {"role": "assistant", "content": "d" * 5000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

        compressed = await compressor.compress(history)

        # After aggressive compression, at most one non-system message survives.
        survivors = [m for m in compressed if m.get("role") != "system"]
        assert len(survivors) <= 1
|
||
|
||
|
||
class TestTruncation:
    """``_truncate`` as the last-resort reduction step."""

    def test_truncate_long_messages(self):
        compressor = ContextCompressor(max_tokens=50)
        marker = "...[truncated]"

        truncated = compressor._truncate(
            [
                {"role": "system", "content": "a" * 500},
                {"role": "user", "content": "b" * 500},
            ]
        )

        # Any content still longer than 100 chars plus the marker must end
        # with the marker. An untouched 500-char message would fail here.
        for msg in truncated:
            text = msg.get("content", "")
            if len(text) > 100 + len(marker):
                assert text.endswith(marker)

    def test_truncate_preserves_short_messages(self):
        compressor = ContextCompressor(max_tokens=50)
        kept = compressor._truncate([{"role": "user", "content": "Short message"}])
        assert kept[0]["content"] == "Short message"
|
||
|
||
|
||
class TestCompressLinearFlow:
    """U3: linear ``compress()`` flow and signature-change tests."""

    @staticmethod
    def _response(content):
        """Build a canned ``LLMResponse`` carrying ``content`` with fixed model/usage."""
        return LLMResponse(
            content=content,
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )

    def test_compress_signature_no_compression_depth(self):
        """compress() no longer accepts a ``_compression_depth`` parameter."""
        params = inspect.signature(ContextCompressor.compress).parameters
        assert "_compression_depth" not in params

    def test_compress_aggressive_signature_no_compression_depth(self):
        """_compress_aggressive() no longer accepts a ``_compression_depth`` parameter."""
        params = inspect.signature(ContextCompressor._compress_aggressive).parameters
        assert "_compression_depth" not in params

    async def test_short_messages_not_compressed_linear(self):
        """Short messages pass through unchanged (linear-flow check)."""
        history = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        assert await ContextCompressor(max_tokens=10000).compress(history) == history

    async def test_aggressive_receives_original_messages(self):
        """F-010: _compress_aggressive gets the original messages, not the compressed ones."""
        # The first summary is very long (forcing the aggressive pass); the second is short.
        first_reply = self._response("x" * 5000)
        second_reply = self._response("short summary")
        gateway = MagicMock()
        gateway.chat = AsyncMock(side_effect=[first_reply, second_reply])
        compressor = ContextCompressor(llm_gateway=gateway, max_tokens=10, keep_recent=2)
        history = [
            {"role": "user", "content": "ORIGINAL_MARKER_a" * 2000},
            {"role": "assistant", "content": "ORIGINAL_MARKER_b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

        await compressor.compress(history)

        # The second (aggressive) call must see the original content, not the
        # first summary ("x" * 5000).
        assert gateway.chat.call_count == 2
        aggressive_prompt = gateway.chat.call_args_list[1].kwargs["messages"][0]["content"]
        assert "ORIGINAL_MARKER" in aggressive_prompt
        assert "xxxx" not in aggressive_prompt

    async def test_truncate_triggered_when_aggressive_insufficient(self):
        """Still over budget after the aggressive pass -> truncate is forced."""
        # Both summaries are very long, which leaves truncation as the last resort.
        oversized = self._response("z" * 5000)
        gateway = MagicMock()
        gateway.chat = AsyncMock(side_effect=[oversized, oversized])
        compressor = ContextCompressor(llm_gateway=gateway, max_tokens=10, keep_recent=2)
        history = [
            {"role": "user", "content": "a" * 5000},
            {"role": "assistant", "content": "b" * 5000},
            {"role": "user", "content": "c" * 5000},
            {"role": "assistant", "content": "d" * 5000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

        compressed = await compressor.compress(history)

        # Truncation must have reduced the total amount of message content.
        compressed_chars = sum(len(str(m.get("content", ""))) for m in compressed)
        original_chars = sum(len(str(m.get("content", ""))) for m in history)
        assert compressed_chars < original_chars
|
||
|
||
|
||
class TestCompressionLogging:
    """U3: structured log output from ``_log_compression``."""

    COMPRESSOR_LOGGER = "agentkit.core.compressor"

    @staticmethod
    def _history():
        """Return a fresh over-budget conversation used by the compression tests."""
        return [
            {"role": "user", "content": "a" * 2000},
            {"role": "assistant", "content": "b" * 2000},
            {"role": "user", "content": "Recent"},
            {"role": "assistant", "content": "Reply"},
        ]

    @staticmethod
    def _response(content):
        """Build a canned ``LLMResponse`` carrying ``content`` with fixed model/usage."""
        return LLMResponse(
            content=content,
            model="test",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        )

    async def test_log_compression_outputs_structured_info(self, caplog):
        """_log_compression emits structured info (tokens / ratio / strategy)."""
        compressor = ContextCompressor(
            llm_gateway=make_mock_gateway("Summary"),
            max_tokens=100,
            keep_recent=2,
        )
        history = self._history()

        with caplog.at_level(logging.INFO, logger=self.COMPRESSOR_LOGGER):
            await compressor.compress(history)

        logged = [rec.message for rec in caplog.records]
        # The record announces compression and names the strategy used ...
        assert any("context compressed" in line for line in logged)
        assert any("strategy: summary" in line for line in logged)
        # ... and reports token and message counts.
        assert any("tokens" in line for line in logged)
        assert any("messages:" in line for line in logged)

    async def test_no_log_when_not_compressed(self, caplog):
        """No compression log is emitted when compression is not triggered."""
        compressor = ContextCompressor(max_tokens=10000)
        history = [{"role": "user", "content": "Hello"}]

        with caplog.at_level(logging.INFO, logger=self.COMPRESSOR_LOGGER):
            await compressor.compress(history)

        logged = [rec.message for rec in caplog.records]
        assert not any("context compressed" in line for line in logged)

    async def test_log_strategy_aggressive(self, caplog):
        """The log names the aggressive strategy when it is used."""
        long_reply = self._response("x" * 5000)
        short_reply = self._response("short")
        gateway = MagicMock()
        gateway.chat = AsyncMock(side_effect=[long_reply, short_reply])
        compressor = ContextCompressor(llm_gateway=gateway, max_tokens=10, keep_recent=2)
        history = self._history()

        with caplog.at_level(logging.INFO, logger=self.COMPRESSOR_LOGGER):
            await compressor.compress(history)

        logged = [rec.message for rec in caplog.records]
        assert any("strategy: aggressive" in line for line in logged)
|
||
|
||
|
||
class TestNotEnoughMessagesToCompress:
    """Compression is skipped when there are too few messages to compress."""

    async def test_fewer_than_keep_recent_messages(self):
        history = [
            {"role": "user", "content": "a" * 200},
            {"role": "assistant", "content": "b" * 200},
        ]
        compressor = ContextCompressor(max_tokens=10, keep_recent=5)
        # There are only 2 non-system messages and keep_recent=5, so nothing is
        # compressed even though the budget is exceeded.
        assert await compressor.compress(history) == history
|
||
|
||
|
||
# ── PromptTemplate Cache Tests ───────────────────────
|
||
|
||
|
||
class TestPromptTemplateRenderCached:
    """Caching behaviour of ``PromptTemplate.render_cached()``."""

    def test_same_variables_returns_cached_result(self):
        tpl = PromptTemplate(sections=PromptSection(identity="Bot", context="Hello ${name}"))

        first = tpl.render_cached(variables={"name": "Alice"})
        second = tpl.render_cached(variables={"name": "Alice"})

        assert first == second
        # A cache hit returns the very same object.
        assert first is second

    def test_different_variables_re_renders(self):
        tpl = PromptTemplate(sections=PromptSection(context="Hello ${name}"))

        for_alice = tpl.render_cached(variables={"name": "Alice"})
        for_bob = tpl.render_cached(variables={"name": "Bob"})

        assert for_alice != for_bob
        assert "Alice" in for_alice[0]["content"]
        assert "Bob" in for_bob[0]["content"]

    def test_no_variables_cached(self):
        tpl = PromptTemplate(sections=PromptSection(identity="Bot"))
        assert tpl.render_cached() is tpl.render_cached()

    def test_render_cached_matches_render(self):
        tpl = PromptTemplate(sections=PromptSection(identity="Bot", context="Hello ${name}"))

        via_cache = tpl.render_cached(variables={"name": "Alice"})
        via_render = tpl.render(variables={"name": "Alice"})

        assert via_cache == via_render
|
||
|
||
|
||
class TestPromptTemplateClearCache:
    """Tests for ``PromptTemplate.clear_cache()``."""

    def test_clear_cache_works(self):
        tpl = PromptTemplate(sections=PromptSection(context="Hello ${name}"))

        before = tpl.render_cached(variables={"name": "Alice"})
        tpl.clear_cache()
        after = tpl.render_cached(variables={"name": "Alice"})

        # After clearing, the prompt is re-rendered: equal content, new object.
        assert before == after
        assert before is not after

    def test_clear_cache_on_fresh_template(self):
        """clear_cache() on a template that has never cached anything must not raise."""
        PromptTemplate(sections=PromptSection(identity="Bot")).clear_cache()
|
||
|
||
|
||
class TestReActEngineWithCompressor:
    """Integration of ``ContextCompressor`` with ``ReActEngine``."""

    @staticmethod
    def _gateway_answering(content):
        """Return a mock gateway whose ``chat`` always resolves to ``content``.

        No ``spec`` is set on the ``MagicMock``; only ``chat`` is configured,
        as an ``AsyncMock`` returning a fixed ``LLMResponse``.
        """
        gateway = MagicMock()
        gateway.chat = AsyncMock(
            return_value=LLMResponse(
                content=content,
                model="test",
                usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
            )
        )
        return gateway

    async def test_execute_with_compressor(self):
        # ReActEngine stays a function-scope import. ContextCompressor,
        # LLMResponse and TokenUsage already come from the module-level
        # imports, so they are not re-imported here.
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(llm_gateway=self._gateway_answering("Final answer"))

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            compressor=ContextCompressor(max_tokens=10000),
        )

        assert result.output == "Final answer"

    async def test_execute_without_compressor_backward_compatible(self):
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(llm_gateway=self._gateway_answering("Answer"))

        # Omitting ``compressor`` must keep working (backward compatibility).
        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
        )

        assert result.output == "Answer"
|