# fischer-agentkit/tests/unit/test_react_compression.py
#
# Repository-viewer metadata: 455 lines, 18 KiB, Python.
#
# NOTE: The repository viewer flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# If that is intentional the warning can safely be ignored; the viewer's
# Escape button reveals them.

"""Tests for ReAct engine compression integration (U3)"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.core.compressor import CompressionStrategy, ContextCompressor
from agentkit.core.react import ReActEngine
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall
# ── Helpers ──────────────────────────────────────────
def make_mock_gateway() -> MagicMock:
    """Build a mock ``LLMGateway`` whose ``chat`` always answers "Final answer".

    Returns:
        A ``MagicMock`` spec'd on ``LLMGateway``. Its ``chat`` attribute is an
        ``AsyncMock`` that returns one fixed ``LLMResponse`` reporting
        10 prompt tokens and 10 completion tokens.
    """
    # Imported locally, as in the original helper.
    from agentkit.llm.gateway import LLMGateway

    canned_reply = LLMResponse(
        content="Final answer",
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
    )
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = AsyncMock(return_value=canned_reply)
    return mock_gateway
def make_mock_gateway_with_tool_call() -> MagicMock:
    """Build a mock ``LLMGateway`` that requests one tool call, then answers.

    Both ``chat`` and ``chat_stream`` are stubbed so that the ``execute`` and
    ``execute_stream`` paths can be driven through the same tool round-trip:

    * ``chat`` is an ``AsyncMock`` whose side-effect list holds exactly two
      responses: first one carrying a single ``search`` tool call, then the
      final answer "Final answer after tool".
    * ``chat_stream`` is an async generator function. Its first invocation
      yields one final chunk carrying the same tool call; every later
      invocation yields one final chunk with the final answer.

    Returns:
        A ``MagicMock`` spec'd on ``LLMGateway`` with the two stubs attached.
    """
    # Imported locally, as in the original helper.
    from agentkit.llm.gateway import LLMGateway

    model_name = "test-model"
    final_text = "Final answer after tool"

    def _usage() -> TokenUsage:
        # Fresh instance per message so no two messages share a usage object.
        return TokenUsage(prompt_tokens=10, completion_tokens=10)

    def _search_call() -> ToolCall:
        # Single definition of the tool call used by both chat and chat_stream,
        # so the two paths cannot drift apart.
        return ToolCall(id="call_1", name="search", arguments={"query": "test"})

    gateway = MagicMock(spec=LLMGateway)

    # Non-streaming path: first await -> tool call, second await -> final answer.
    gateway.chat = AsyncMock(side_effect=[
        LLMResponse(
            content="",
            model=model_name,
            usage=_usage(),
            tool_calls=[_search_call()],
        ),
        LLMResponse(
            content=final_text,
            model=model_name,
            usage=_usage(),
        ),
    ])

    # Streaming path: StreamChunk equivalents of the two chat responses.
    tool_chunk = StreamChunk(
        content="",
        model=model_name,
        tool_calls=[_search_call()],
        usage=_usage(),
        is_final=True,
    )
    final_chunk = StreamChunk(
        content=final_text,
        model=model_name,
        usage=_usage(),
        is_final=True,
    )

    stream_calls = 0  # closure counter: how many streams have started iterating

    async def _stream(*args, **kwargs):
        # Accept any call shape, like the AsyncMock used for ``chat``; the
        # arguments themselves are ignored.
        nonlocal stream_calls
        stream_calls += 1
        yield tool_chunk if stream_calls == 1 else final_chunk

    gateway.chat_stream = _stream
    return gateway
def make_long_messages(count: int = 10, content_length: int = 2000) -> list[dict]:
    """Generate a long message list for exercising compression.

    Args:
        count: Number of user/assistant exchanges appended after the system
            message.
        content_length: Number of filler characters at the start of every
            user and assistant message.

    Returns:
        A list that starts with one system message and is followed by
        ``count`` pairs of user ("x" filler + " message i") and assistant
        ("y" filler + " reply i") messages.
    """
    user_filler = "x" * content_length
    assistant_filler = "y" * content_length

    transcript = [{"role": "system", "content": "You are a helpful assistant."}]
    for turn in range(count):
        transcript.extend((
            {"role": "user", "content": f"{user_filler} message {turn}"},
            {"role": "assistant", "content": f"{assistant_filler} reply {turn}"},
        ))
    return transcript
def make_mock_compressor() -> MagicMock:
    """Build a mock ``CompressionStrategy`` with canned results.

    Returns:
        A ``MagicMock`` spec'd on ``CompressionStrategy`` where

        * ``compress`` is an ``AsyncMock`` returning a single-message list
          (``{"role": "user", "content": "compressed"}``),
        * ``compress_tool_result`` is an ``AsyncMock`` returning
          "compressed tool result",
        * ``is_available`` is a ``MagicMock`` returning ``True``.
    """
    strategy = MagicMock(spec=CompressionStrategy)
    strategy.configure_mock(
        compress=AsyncMock(return_value=[{"role": "user", "content": "compressed"}]),
        compress_tool_result=AsyncMock(return_value="compressed tool result"),
        is_available=MagicMock(return_value=True),
    )
    return strategy
# ── TestBuildToolResultMessage ────────────────────────
class TestBuildToolResultMessage:
    """Tests for ``ReActEngine._build_tool_result_message``."""

    async def test_no_compressor_returns_original(self):
        """Without a compressor the message content is the dict's ``str()`` form."""
        react = ReActEngine(llm_gateway=make_mock_gateway())

        message = await react._build_tool_result_message("tc_1", {"key": "value"})

        expected = {
            "role": "tool",
            "tool_call_id": "tc_1",
            "content": "{'key': 'value'}",
        }
        assert message == expected

    async def test_with_compressor_calls_compress_tool_result(self):
        """With a compressor and a tool name, the compressed text becomes the content."""
        strategy = make_mock_compressor()
        react = ReActEngine(llm_gateway=make_mock_gateway())
        raw_result = {"key": "value"}

        message = await react._build_tool_result_message(
            "tc_1", raw_result, compressor=strategy, tool_name="search"
        )

        strategy.compress_tool_result.assert_called_once_with("search", {"key": "value"})
        assert message["content"] == "compressed tool result"
        assert message["role"] == "tool"
        assert message["tool_call_id"] == "tc_1"

    async def test_compressor_failure_falls_back(self):
        """If ``compress_tool_result`` raises, the content falls back to ``str(result)``."""
        strategy = make_mock_compressor()
        strategy.compress_tool_result = AsyncMock(
            side_effect=RuntimeError("compression error")
        )
        react = ReActEngine(llm_gateway=make_mock_gateway())

        message = await react._build_tool_result_message(
            "tc_1", {"key": "value"}, compressor=strategy, tool_name="search"
        )

        # Expected fallback: the uncompressed str() form of the tool result.
        assert message["content"] == "{'key': 'value'}"
        assert message["role"] == "tool"

    async def test_compressor_receives_tool_name(self):
        """The tool name is forwarded as the first argument to the compressor."""
        strategy = make_mock_compressor()
        react = ReActEngine(llm_gateway=make_mock_gateway())

        await react._build_tool_result_message(
            "tc_1", "some result", compressor=strategy, tool_name="web_crawl"
        )

        strategy.compress_tool_result.assert_called_once_with("web_crawl", "some result")

    async def test_compressor_without_tool_name_skips_compression(self):
        """A compressor is present but ``tool_name`` is ``None``: no compression happens."""
        strategy = make_mock_compressor()
        react = ReActEngine(llm_gateway=make_mock_gateway())

        message = await react._build_tool_result_message(
            "tc_1", "some result", compressor=strategy, tool_name=None
        )

        strategy.compress_tool_result.assert_not_called()
        assert message["content"] == "some result"
# ── TestShouldCompress ───────────────────────────────
class TestShouldCompress:
    """Tests for the ``ReActEngine._should_compress`` helper."""

    def test_no_compressor_returns_false(self):
        """A ``None`` compressor yields ``False`` even for a 100000-char message."""
        react = ReActEngine(llm_gateway=make_mock_gateway())
        huge_turn = {"role": "user", "content": "x" * 100000}
        assert react._should_compress([huge_turn], None) is False

    def test_short_conversation_returns_false(self):
        """A two-message greeting does not trigger compression."""
        strategy = make_mock_compressor()
        react = ReActEngine(llm_gateway=make_mock_gateway())
        greeting = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
        ]
        assert react._should_compress(greeting, strategy) is False

    def test_long_conversation_returns_true(self):
        """A 40000-char message triggers compression."""
        strategy = make_mock_compressor()
        react = ReActEngine(llm_gateway=make_mock_gateway())
        # Per the original author's note: 8000 tokens * 4 chars/token means
        # 32000 chars are needed, so 40000 is comfortably above.
        long_turn = {"role": "user", "content": "x" * 40000}
        assert react._should_compress([long_turn], strategy) is True

    def test_boundary_at_threshold(self):
        """Exactly 32000 chars must NOT trigger compression."""
        strategy = make_mock_compressor()
        react = ReActEngine(llm_gateway=make_mock_gateway())
        # Per the original author's note: exactly 8000 tokens = 32000 chars,
        # and the comparison is strict (> rather than >=).
        boundary_turn = {"role": "user", "content": "x" * 32000}
        assert react._should_compress([boundary_turn], strategy) is False

    def test_just_above_threshold(self):
        """32004 chars is the first length expected to trigger compression."""
        strategy = make_mock_compressor()
        react = ReActEngine(llm_gateway=make_mock_gateway())
        # Per the original author's note:
        #   32001 chars -> 32001 // 4 = 8000 tokens, still not > 8000
        #   32004 chars -> 32004 // 4 = 8001 tokens, > 8000
        above_turn = {"role": "user", "content": "x" * 32004}
        assert react._should_compress([above_turn], strategy) is True
# ── TestReActLoopCompression ─────────────────────────
class TestReActLoopCompression:
    """Integration tests for compression inside the ReAct loop."""

    @staticmethod
    def _make_search_tool(return_value: str) -> MagicMock:
        """Return a mock tool named ``search`` whose ``safe_execute`` yields ``return_value``."""
        tool = MagicMock()
        tool.name = "search"
        tool.safe_execute = AsyncMock(return_value=return_value)
        return tool

    async def test_tool_results_compressed_in_loop(self):
        """``execute`` passes the tool result through ``compress_tool_result``."""
        compressor = make_mock_compressor()
        compressor.compress_tool_result = AsyncMock(return_value="COMPRESSED:search_result")
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        await engine.execute(
            messages=[{"role": "user", "content": "Search for test"}],
            tools=[self._make_search_tool("original search result")],
            compressor=compressor,
        )

        # The compressor must have seen the tool name and the raw tool output.
        compressor.compress_tool_result.assert_called_once_with(
            "search", "original search result"
        )

    async def test_incremental_compression_triggered(self):
        """A long conversation causes ``compressor.compress()`` to be called."""
        compressor = make_mock_compressor()
        # compress() hands back a short, already-summarised conversation.
        compressor.compress = AsyncMock(return_value=[
            {"role": "system", "content": "Summary"},
            {"role": "user", "content": "recent"},
        ])
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        # One very long initial message to push the conversation over the threshold.
        long_content = "x" * 40000

        await engine.execute(
            messages=[{"role": "user", "content": long_content}],
            tools=[self._make_search_tool("result")],
            compressor=compressor,
        )

        # NOTE(review): the original note says compress() may run once up front
        # and again after the tool result; ">= 1" is satisfied by either call, so
        # this does not by itself prove the incremental call happened. Confirm
        # whether a stricter count is intended.
        assert compressor.compress.call_count >= 1

    async def test_incremental_compression_failure_handled(self):
        """The loop keeps going when a later ``compressor.compress()`` call raises."""
        compressor = make_mock_compressor()
        calls_seen = 0

        async def compress_side_effect(messages):
            # First call succeeds (returns the messages unchanged); every
            # later call fails, simulating a broken incremental compression.
            nonlocal calls_seen
            calls_seen += 1
            if calls_seen > 1:
                raise RuntimeError("Incremental compression failed")
            return messages

        compressor.compress = AsyncMock(side_effect=compress_side_effect)
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        # Long message so that compression is attempted.
        long_content = "x" * 40000

        # Must not raise.
        result = await engine.execute(
            messages=[{"role": "user", "content": long_content}],
            tools=[self._make_search_tool("result")],
            compressor=compressor,
        )

        # The run still ends with the normal final answer.
        assert result.output == "Final answer after tool"

    async def test_no_compressor_backward_compatible(self):
        """With ``compressor=None`` the tool round-trip completes normally."""
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        result = await engine.execute(
            messages=[{"role": "user", "content": "Search for test"}],
            tools=[self._make_search_tool("search result data")],
            compressor=None,
        )

        assert result.output == "Final answer after tool"
        assert result.status == "success"
        assert len(result.trajectory) == 2  # 1 tool_call + 1 final_answer

    async def test_execute_stream_with_compressor(self):
        """Tool-result compression also works through ``execute_stream``."""
        compressor = make_mock_compressor()
        compressor.compress_tool_result = AsyncMock(return_value="COMPRESSED:result")
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        events = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "Search for test"}],
                tools=[self._make_search_tool("original result")],
                compressor=compressor,
            )
        ]

        # The compressor must have seen the tool name and the raw tool output.
        compressor.compress_tool_result.assert_called_once_with("search", "original result")

        # The event stream must contain every stage of the round-trip.
        event_types = [e.event_type for e in events]
        assert "thinking" in event_types
        assert "tool_call" in event_types
        assert "tool_result" in event_types
        assert "final_answer" in event_types

    async def test_execute_stream_incremental_compression(self):
        """A long conversation causes ``compress()`` to be called via ``execute_stream``."""
        compressor = make_mock_compressor()
        compressor.compress = AsyncMock(return_value=[
            {"role": "system", "content": "Summary"},
            {"role": "user", "content": "recent"},
        ])
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())
        long_content = "x" * 40000

        # Drain the stream; the events themselves are not inspected here.
        async for _ in engine.execute_stream(
            messages=[{"role": "user", "content": long_content}],
            tools=[self._make_search_tool("result")],
            compressor=compressor,
        ):
            pass

        assert compressor.compress.call_count >= 1

    async def test_context_compressor_satisfies_protocol(self):
        """``ContextCompressor`` passes an ``isinstance`` check against ``CompressionStrategy``."""
        compressor = ContextCompressor()
        assert isinstance(compressor, CompressionStrategy)

    async def test_context_compressor_compress_tool_result_default(self):
        """``ContextCompressor.compress_tool_result`` returns ``str(result)`` by default."""
        compressor = ContextCompressor()
        result = await compressor.compress_tool_result("search", {"key": "value"})
        assert result == "{'key': 'value'}"
# ── TestOTelSpanLifecycle (P0 fix: span leak) ──────────────────
class TestOTelSpanLifecycle:
    """OTel span lifecycle tests: the span must be closed even when execution raises."""

    @staticmethod
    def _span_and_context_manager():
        """Return ``(span, span_cm)`` where ``span_cm.__enter__()`` yields ``span``.

        ``__exit__`` returns ``False`` so exceptions raised inside the ``with``
        block are not suppressed by the mock.
        """
        span = MagicMock()
        span_cm = MagicMock()
        span_cm.__enter__ = MagicMock(return_value=span)
        span_cm.__exit__ = MagicMock(return_value=False)
        return span, span_cm

    async def test_span_closed_on_success(self):
        """The span context manager is exited after a normal run."""
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        _, span_cm = self._span_and_context_manager()

        with patch("agentkit.core.react.start_span", return_value=span_cm), \
                patch("agentkit.core.react._OTEL_AVAILABLE", True):
            await engine.execute(messages=[{"role": "user", "content": "hello"}])

        # __exit__ should have been called
        span_cm.__exit__.assert_called_once()

    async def test_span_closed_on_exception(self):
        """The span context manager is still exited when the LLM call raises."""
        failing_gateway = make_mock_gateway()
        failing_gateway.chat = AsyncMock(side_effect=RuntimeError("LLM error"))
        engine = ReActEngine(llm_gateway=failing_gateway)
        _, span_cm = self._span_and_context_manager()

        with patch("agentkit.core.react.start_span", return_value=span_cm), \
                patch("agentkit.core.react._OTEL_AVAILABLE", True):
            with pytest.raises(RuntimeError, match="LLM error"):
                await engine.execute(messages=[{"role": "user", "content": "hello"}])

        # __exit__ must have been called even though exception was raised
        span_cm.__exit__.assert_called_once()

    async def test_span_attributes_set_on_success(self):
        """Step count, token total and outcome are recorded on the span after a normal run."""
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        span, span_cm = self._span_and_context_manager()

        with patch("agentkit.core.react.start_span", return_value=span_cm), \
                patch("agentkit.core.react._OTEL_AVAILABLE", True):
            await engine.execute(messages=[{"role": "user", "content": "hello"}])

        # Verify span attributes were set (10 prompt + 10 completion tokens = 20).
        span.set_attribute.assert_any_call("agent.total_steps", 1)
        span.set_attribute.assert_any_call("agent.total_tokens", 20)
        span.set_attribute.assert_any_call("agent.outcome", "success")

    async def test_no_span_when_otel_unavailable(self):
        """With ``_OTEL_AVAILABLE`` patched to ``False`` no span is started."""
        engine = ReActEngine(llm_gateway=make_mock_gateway())

        with patch("agentkit.core.react._OTEL_AVAILABLE", False), \
                patch("agentkit.core.react.start_span") as start_span_mock:
            await engine.execute(messages=[{"role": "user", "content": "hello"}])

        # start_span should not be called when OTel is unavailable
        start_span_mock.assert_not_called()