"""Tests for ReAct engine compression integration (U3)"""

from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.compressor import CompressionStrategy, ContextCompressor
|
||
from agentkit.core.react import ReActEngine
|
||
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall


# ── Helpers ──────────────────────────────────────────


def make_mock_gateway() -> MagicMock:
    """Build a mock ``LLMGateway`` whose ``chat`` always returns one final answer.

    ``chat`` is an ``AsyncMock`` resolving to an ``LLMResponse`` with content
    ``"Final answer"``, model ``"test-model"`` and 10 prompt + 10 completion
    tokens; no tool calls are supplied on the response.
    """
    # Imported inside the helper rather than at module level.
    from agentkit.llm.gateway import LLMGateway

    final = LLMResponse(
        content="Final answer",
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
    )
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = AsyncMock(return_value=final)
    return mock_gateway
|
||
|
||
|
||
def make_mock_gateway_with_tool_call() -> MagicMock:
    """Build a mock ``LLMGateway`` that asks for one ``search`` tool call, then answers.

    Both ``chat`` and ``chat_stream`` are scripted so the ``execute`` and
    ``execute_stream`` paths exercise the same tool round-trip:

    * ``chat`` is an ``AsyncMock``. The 1st call returns an ``LLMResponse``
      carrying a ``search`` tool call, and the 2nd returns
      ``"Final answer after tool"``. Any further call raises, because the
      side-effect list is exhausted.
    * ``chat_stream`` yields one final ``StreamChunk`` per call: the tool-call
      chunk for the first stream that is iterated, and the final-answer chunk
      for every later one. It accepts any positional/keyword arguments and is
      wrapped in a ``MagicMock``, so its calls can be inspected like ``chat``'s.
    """
    from agentkit.llm.gateway import LLMGateway

    gateway = MagicMock(spec=LLMGateway)

    # Non-streaming path: first a tool call, then the final answer.
    tool_response = LLMResponse(
        content="",
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        tool_calls=[
            ToolCall(id="call_1", name="search", arguments={"query": "test"}),
        ],
    )
    final_response = LLMResponse(
        content="Final answer after tool",
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
    )
    gateway.chat = AsyncMock(side_effect=[tool_response, final_response])

    # Streaming path: StreamChunk equivalents of the two responses above.
    tool_chunk = StreamChunk(
        content="",
        model="test-model",
        tool_calls=[ToolCall(id="call_1", name="search", arguments={"query": "test"})],
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        is_final=True,
    )
    final_chunk = StreamChunk(
        content="Final answer after tool",
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
        is_final=True,
    )
    # One-shot iterator: the first stream consumed gets tool_chunk; once it is
    # exhausted, next() falls back to final_chunk for every later stream.
    pending_chunks = iter([tool_chunk])

    async def _stream(*args, **kwargs):
        # Runs lazily, when the caller starts iterating this call's stream.
        yield next(pending_chunks, final_chunk)

    # MagicMock(side_effect=...) returns the async generator produced by _stream
    # while recording the call arguments.
    gateway.chat_stream = MagicMock(side_effect=_stream)
    return gateway
|
||
|
||
|
||
def make_long_messages(count: int = 10, content_length: int = 2000) -> list[dict]:
    """Build a long chat history for compression tests.

    Returns a system prompt followed by *count* user/assistant pairs. Each
    message body is *content_length* padding characters (``"x"`` for the user,
    ``"y"`` for the assistant) plus a suffix naming the turn index. For a
    non-negative *count* the result has ``1 + 2 * count`` messages.
    """
    system_prompt = {"role": "system", "content": "You are a helpful assistant."}
    turns = [
        message
        for idx in range(count)
        for message in (
            {"role": "user", "content": "x" * content_length + f" message {idx}"},
            {"role": "assistant", "content": "y" * content_length + f" reply {idx}"},
        )
    ]
    return [system_prompt, *turns]
|
||
|
||
|
||
def make_mock_compressor() -> MagicMock:
    """Build a ``CompressionStrategy`` mock with canned results.

    * ``compress`` (async) resolves to ``[{"role": "user", "content": "compressed"}]``
    * ``compress_tool_result`` (async) resolves to ``"compressed tool result"``
    * ``is_available()`` returns ``True``
    """
    mock = MagicMock(spec=CompressionStrategy)
    mock.is_available = MagicMock(return_value=True)
    mock.compress_tool_result = AsyncMock(return_value="compressed tool result")
    mock.compress = AsyncMock(return_value=[{"role": "user", "content": "compressed"}])
    return mock


# ── TestBuildToolResultMessage ────────────────────────


class TestBuildToolResultMessage:
    """Tests for ``ReActEngine._build_tool_result_message``."""

    async def test_no_compressor_returns_original(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        result = await engine._build_tool_result_message("tc_1", {"key": "value"})
        assert result == {
            "role": "tool",
            "tool_call_id": "tc_1",
            "content": "{'key': 'value'}",
        }

    async def test_with_compressor_calls_compress_tool_result(self):
        compressor = make_mock_compressor()
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        result = await engine._build_tool_result_message(
            "tc_1", {"key": "value"}, compressor=compressor, tool_name="search"
        )
        compressor.compress_tool_result.assert_called_once_with("search", {"key": "value"})
        assert result["content"] == "compressed tool result"
        assert result["role"] == "tool"
        assert result["tool_call_id"] == "tc_1"

    async def test_compressor_failure_falls_back(self):
        compressor = make_mock_compressor()
        compressor.compress_tool_result = AsyncMock(side_effect=RuntimeError("compression error"))
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        result = await engine._build_tool_result_message(
            "tc_1", {"key": "value"}, compressor=compressor, tool_name="search"
        )
        # Prove the failing compressor was actually tried; otherwise the
        # fallback assertions below could pass without exercising the error path.
        compressor.compress_tool_result.assert_called_once_with("search", {"key": "value"})
        # The message must fall back to str(result) and keep its identity fields.
        assert result["content"] == "{'key': 'value'}"
        assert result["role"] == "tool"
        assert result["tool_call_id"] == "tc_1"

    async def test_compressor_receives_tool_name(self):
        compressor = make_mock_compressor()
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        result = await engine._build_tool_result_message(
            "tc_1", "some result", compressor=compressor, tool_name="web_crawl"
        )
        compressor.compress_tool_result.assert_called_once_with("web_crawl", "some result")
        # The compressor's output, not the raw result, becomes the message content.
        assert result["content"] == "compressed tool result"

    async def test_compressor_without_tool_name_skips_compression(self):
        """With a compressor but tool_name=None, no compression is attempted."""
        compressor = make_mock_compressor()
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        result = await engine._build_tool_result_message(
            "tc_1", "some result", compressor=compressor, tool_name=None
        )
        compressor.compress_tool_result.assert_not_called()
        assert result["content"] == "some result"


# ── TestShouldCompress ───────────────────────────────


class TestShouldCompress:
    """Tests for the ``ReActEngine._should_compress`` helper.

    The boundary cases encode the expected rule: about 4 characters per
    estimated token, an 8000-token threshold, and a strict ``>`` comparison.
    """

    @staticmethod
    def _engine() -> ReActEngine:
        """Build an engine backed by the single-answer gateway mock."""
        return ReActEngine(llm_gateway=make_mock_gateway())

    def test_no_compressor_returns_false(self):
        oversized = [{"role": "user", "content": "x" * 100000}]
        assert self._engine()._should_compress(oversized, None) is False

    def test_short_conversation_returns_false(self):
        compressor = make_mock_compressor()
        short_chat = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
        ]
        assert self._engine()._should_compress(short_chat, compressor) is False

    def test_long_conversation_returns_true(self):
        compressor = make_mock_compressor()
        # 8000 tokens * 4 chars/token = 32000 chars, so 40000 chars is well above.
        long_chat = [{"role": "user", "content": "x" * 40000}]
        assert self._engine()._should_compress(long_chat, compressor) is True

    def test_boundary_at_threshold(self):
        compressor = make_mock_compressor()
        # Exactly 32000 chars ~ 8000 tokens: must NOT trigger (> rather than >=).
        at_limit = [{"role": "user", "content": "x" * 32000}]
        assert self._engine()._should_compress(at_limit, compressor) is False

    def test_just_above_threshold(self):
        compressor = make_mock_compressor()
        # 32004 chars -> 32004 // 4 = 8001 tokens, which is > 8000.
        # (32001 chars would still estimate to 8000 tokens and not trigger.)
        over_limit = [{"role": "user", "content": "x" * 32004}]
        assert self._engine()._should_compress(over_limit, compressor) is True


# ── TestReActLoopCompression ─────────────────────────


class TestReActLoopCompression:
    """Integration tests for compression inside the ReAct loop."""

    @staticmethod
    def _make_search_tool(result: str) -> MagicMock:
        """Return a mock tool named ``search`` whose ``safe_execute`` resolves to *result*."""
        tool = MagicMock()
        tool.name = "search"
        tool.safe_execute = AsyncMock(return_value=result)
        return tool

    @staticmethod
    def _has_tool_message(messages) -> bool:
        """Return True if *messages* already holds a ``role == "tool"`` dict message."""
        return any(isinstance(m, dict) and m.get("role") == "tool" for m in messages)

    @classmethod
    def _make_pass_through_compressor(cls) -> tuple[MagicMock, list[bool]]:
        """Return ``(compressor, saw_tool_message)``.

        ``compressor.compress`` hands its input back unchanged, so a long
        conversation stays above the threshold for the whole run. On every call
        it appends to ``saw_tool_message`` whether a tool-result message was
        present *at call time*. The value is recorded immediately because the
        engine may mutate the same list object after the call returns.
        """
        compressor = make_mock_compressor()
        saw_tool_message: list[bool] = []

        async def _pass_through(messages):
            saw_tool_message.append(cls._has_tool_message(messages))
            return messages

        compressor.compress = AsyncMock(side_effect=_pass_through)
        return compressor, saw_tool_message

    async def test_tool_results_compressed_in_loop(self):
        """Tool results are compressed inside the loop before being added to the conversation."""
        compressor = make_mock_compressor()
        compressor.compress_tool_result = AsyncMock(return_value="COMPRESSED:search_result")
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        result = await engine.execute(
            messages=[{"role": "user", "content": "Search for test"}],
            tools=[self._make_search_tool("original search result")],
            compressor=compressor,
        )

        compressor.compress_tool_result.assert_called_once_with("search", "original search result")
        # The run must still complete after the compressed result is fed back.
        assert result.output == "Final answer after tool"

    async def test_incremental_compression_triggered(self):
        """A long conversation triggers compressor.compress() again after a tool result."""
        # A pass-through compressor keeps the conversation long; a compressor that
        # returned a short summary would drop it below the threshold, so the
        # in-loop path could never fire.
        compressor, saw_tool_message = self._make_pass_through_compressor()
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        # 40000 chars is above the ~8000-token threshold (see TestShouldCompress).
        result = await engine.execute(
            messages=[{"role": "user", "content": "x" * 40000}],
            tools=[self._make_search_tool("result")],
            compressor=compressor,
        )

        # At least one compress() call saw the appended tool result, i.e. the
        # incremental (in-loop) path ran, not only a compression of the input.
        assert any(saw_tool_message)
        assert result.output == "Final answer after tool"

    async def test_incremental_compression_failure_handled(self):
        """The loop keeps going when compressor.compress() raises during incremental compression."""
        compressor = make_mock_compressor()
        failed_calls = 0

        async def compress_side_effect(messages):
            nonlocal failed_calls
            # Compressing the raw input succeeds. Any call made after a tool
            # result has been appended (the incremental path) fails.
            if self._has_tool_message(messages):
                failed_calls += 1
                raise RuntimeError("Incremental compression failed")
            return messages

        compressor.compress = AsyncMock(side_effect=compress_side_effect)
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        # Must not raise.
        result = await engine.execute(
            messages=[{"role": "user", "content": "x" * 40000}],
            tools=[self._make_search_tool("result")],
            compressor=compressor,
        )

        # The failing path was actually exercised, and the run still finished.
        assert failed_calls >= 1
        assert result.output == "Final answer after tool"

    async def test_no_compressor_backward_compatible(self):
        """With compressor=None, behaviour is unchanged from before compression existed."""
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        result = await engine.execute(
            messages=[{"role": "user", "content": "Search for test"}],
            tools=[self._make_search_tool("search result data")],
            compressor=None,
        )

        assert result.output == "Final answer after tool"
        assert result.status == "success"
        assert len(result.trajectory) == 2  # 1 tool_call + 1 final_answer

    async def test_execute_stream_with_compressor(self):
        """Tool-result compression also works in execute_stream mode."""
        compressor = make_mock_compressor()
        compressor.compress_tool_result = AsyncMock(return_value="COMPRESSED:result")
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "Search for test"}],
            tools=[self._make_search_tool("original result")],
            compressor=compressor,
        ):
            events.append(event)

        compressor.compress_tool_result.assert_called_once_with("search", "original result")

        # The event stream must be complete.
        event_types = [e.event_type for e in events]
        assert "thinking" in event_types
        assert "tool_call" in event_types
        assert "tool_result" in event_types
        assert "final_answer" in event_types

    async def test_execute_stream_incremental_compression(self):
        """Incremental compression is also triggered in execute_stream mode."""
        compressor, saw_tool_message = self._make_pass_through_compressor()
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())

        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "x" * 40000}],
            tools=[self._make_search_tool("result")],
            compressor=compressor,
        ):
            events.append(event)

        # compress() ran after the tool result was appended, and the stream finished.
        assert any(saw_tool_message)
        assert "final_answer" in [e.event_type for e in events]

    async def test_context_compressor_satisfies_protocol(self):
        """ContextCompressor satisfies the CompressionStrategy Protocol."""
        compressor = ContextCompressor()
        assert isinstance(compressor, CompressionStrategy)

    async def test_context_compressor_compress_tool_result_default(self):
        """ContextCompressor.compress_tool_result returns str(result) by default."""
        compressor = ContextCompressor()
        result = await compressor.compress_tool_result("search", {"key": "value"})
        assert result == "{'key': 'value'}"


# ── TestOTelSpanLifecycle (P0 fix: span leak) ──────────────────


class TestOTelSpanLifecycle:
    """OTel span lifecycle: the span context manager must be exited, even when an exception is raised."""

    @staticmethod
    def _span_context() -> tuple[MagicMock, MagicMock]:
        """Return ``(span_cm, span)``, a context-manager mock whose ``__enter__`` yields *span*.

        ``__exit__`` returns ``False`` so exceptions propagate out of the ``with``.
        """
        span = MagicMock()
        span_cm = MagicMock()
        span_cm.__enter__ = MagicMock(return_value=span)
        span_cm.__exit__ = MagicMock(return_value=False)
        return span_cm, span

    async def test_span_closed_on_success(self):
        """On a normal run the span context manager is exited."""
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        span_cm, _ = self._span_context()

        with patch("agentkit.core.react.start_span", return_value=span_cm):
            with patch("agentkit.core.react._OTEL_AVAILABLE", True):
                await engine.execute(messages=[{"role": "user", "content": "hello"}])

        span_cm.__exit__.assert_called_once()

    async def test_span_closed_on_exception(self):
        """When the LLM raises, the span context manager is still exited."""
        failing_gateway = make_mock_gateway()
        failing_gateway.chat = AsyncMock(side_effect=RuntimeError("LLM error"))
        engine = ReActEngine(llm_gateway=failing_gateway)
        span_cm, _ = self._span_context()

        with patch("agentkit.core.react.start_span", return_value=span_cm):
            with patch("agentkit.core.react._OTEL_AVAILABLE", True):
                with pytest.raises(RuntimeError, match="LLM error"):
                    await engine.execute(messages=[{"role": "user", "content": "hello"}])

        # __exit__ ran even though the exception escaped execute().
        span_cm.__exit__.assert_called_once()

    async def test_span_attributes_set_on_success(self):
        """On a normal run the summary attributes are recorded on the span."""
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        span_cm, span = self._span_context()

        with patch("agentkit.core.react.start_span", return_value=span_cm):
            with patch("agentkit.core.react._OTEL_AVAILABLE", True):
                await engine.execute(messages=[{"role": "user", "content": "hello"}])

        for key, expected in (
            ("agent.total_steps", 1),
            ("agent.total_tokens", 20),
            ("agent.outcome", "success"),
        ):
            span.set_attribute.assert_any_call(key, expected)

    async def test_no_span_when_otel_unavailable(self):
        """No span is started when _OTEL_AVAILABLE is False."""
        engine = ReActEngine(llm_gateway=make_mock_gateway())

        with patch("agentkit.core.react._OTEL_AVAILABLE", False):
            with patch("agentkit.core.react.start_span") as start_span_mock:
                await engine.execute(messages=[{"role": "user", "content": "hello"}])

        start_span_mock.assert_not_called()
|