feat(compression): U3 ReAct engine tool result compression and incremental compress

Extend _build_tool_result_message to accept a compressor parameter for
tool output compression. Add a _should_compress helper for token budget
checking. Add incremental compression within the ReAct loop when the
conversation exceeds the threshold.
This commit is contained in:
chiguyong 2026-06-07 18:19:53 +08:00
parent ea705b979b
commit fcb4fb33f3
2 changed files with 409 additions and 8 deletions

View File

@ -24,7 +24,7 @@ from agentkit.telemetry.metrics import (
)
if TYPE_CHECKING:
from agentkit.core.compressor import ContextCompressor
from agentkit.core.compressor import CompressionStrategy, ContextCompressor
from agentkit.core.trace import TraceRecorder
from agentkit.memory.retriever import MemoryRetriever
@ -311,9 +311,16 @@ class ReActEngine:
)
# Observe: 将工具结果添加到对话历史
tool_msg = self._build_tool_result_message(tc.id, tool_result)
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
conversation.append(tool_msg)
# Incremental compression: compress conversation if it's getting long
if self._should_compress(conversation, compressor):
try:
conversation = await compressor.compress(conversation)
except Exception as e:
logger.warning(f"Incremental compression failed: {e}")
else:
# 检查文本解析模式
parsed_calls = self._parse_text_tool_calls(response.content or "")
@ -362,8 +369,15 @@ class ReActEngine:
)
# 将工具结果添加到对话历史
tool_msg = self._build_tool_result_message(pc.get("id", f"text_tc_{step}"), tool_result)
tool_msg = await self._build_tool_result_message(pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"])
conversation.append(tool_msg)
# Incremental compression: compress conversation if it's getting long
if self._should_compress(conversation, compressor):
try:
conversation = await compressor.compress(conversation)
except Exception as e:
logger.warning(f"Incremental compression failed: {e}")
else:
# Final answer: LLM 没有调用工具,返回最终答案
react_step = ReActStep(
@ -596,9 +610,16 @@ class ReActEngine:
data={"tool_name": tc.name, "result": tool_result},
)
tool_msg = self._build_tool_result_message(tc.id, tool_result)
tool_msg = await self._build_tool_result_message(tc.id, tool_result, compressor, tc.name)
conversation.append(tool_msg)
# Incremental compression: compress conversation if it's getting long
if self._should_compress(conversation, compressor):
try:
conversation = await compressor.compress(conversation)
except Exception as e:
logger.warning(f"Incremental compression failed: {e}")
else:
# Check text parsing mode
parsed_calls = self._parse_text_tool_calls(response.content or "")
@ -651,10 +672,17 @@ class ReActEngine:
step=step,
data={"tool_name": pc["name"], "result": tool_result},
)
tool_msg = self._build_tool_result_message(
pc.get("id", f"text_tc_{step}"), tool_result
tool_msg = await self._build_tool_result_message(
pc.get("id", f"text_tc_{step}"), tool_result, compressor, pc["name"]
)
conversation.append(tool_msg)
# Incremental compression: compress conversation if it's getting long
if self._should_compress(conversation, compressor):
try:
conversation = await compressor.compress(conversation)
except Exception as e:
logger.warning(f"Incremental compression failed: {e}")
else:
# Final answer
react_step = ReActStep(
@ -745,12 +773,34 @@ class ReActEngine:
return tool
return None
def _build_tool_result_message(self, tool_call_id: str, result: Any) -> dict:
def _should_compress(self, conversation: list[dict], compressor: "CompressionStrategy | None") -> bool:
"""检查是否需要增量压缩"""
if not compressor:
return False
# Estimate tokens in conversation
total_chars = sum(len(str(m.get("content", ""))) for m in conversation)
estimated_tokens = total_chars // 4
return estimated_tokens > 8000 # Threshold: ~8000 tokens
async def _build_tool_result_message(
self,
tool_call_id: str,
result: Any,
compressor: "CompressionStrategy | None" = None,
tool_name: str | None = None,
) -> dict:
"""构建工具结果消息用于对话历史"""
content = str(result)
if compressor and tool_name:
try:
content = await compressor.compress_tool_result(tool_name, result)
except Exception as e:
logger.warning(f"Tool result compression failed for '{tool_name}': {e}")
content = str(result)
return {
"role": "tool",
"tool_call_id": tool_call_id,
"content": str(result),
"content": content,
}
async def _execute_tool(

View File

@ -0,0 +1,351 @@
"""Tests for ReAct engine compression integration (U3)"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.core.compressor import CompressionStrategy, ContextCompressor
from agentkit.core.react import ReActEngine
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
# ── Helpers ──────────────────────────────────────────
def make_mock_gateway() -> MagicMock:
    """Create a mock LLMGateway whose ``chat`` always returns a final answer."""
    from agentkit.llm.gateway import LLMGateway

    final_answer = LLMResponse(
        content="Final answer",
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
    )
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = AsyncMock(return_value=final_answer)
    return mock_gateway
def make_mock_gateway_with_tool_call() -> MagicMock:
    """Create a mock LLMGateway that returns a tool call first, then a final answer."""
    from agentkit.llm.gateway import LLMGateway

    usage = {"prompt_tokens": 10, "completion_tokens": 10}
    # Scripted replies, consumed in order by successive ``chat`` calls.
    scripted_replies = [
        LLMResponse(
            content="",
            model="test-model",
            usage=TokenUsage(**usage),
            tool_calls=[
                ToolCall(id="call_1", name="search", arguments={"query": "test"}),
            ],
        ),
        LLMResponse(
            content="Final answer after tool",
            model="test-model",
            usage=TokenUsage(**usage),
        ),
    ]
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = AsyncMock(side_effect=scripted_replies)
    return mock_gateway
def make_long_messages(count: int = 10, content_length: int = 2000) -> list[dict]:
    """Build a system prompt followed by ``count`` long user/assistant turns.

    Each user message is ``content_length`` ``x`` characters plus a
    `` message <i>`` suffix; each assistant message is ``content_length``
    ``y`` characters plus a `` reply <i>`` suffix.
    """
    user_padding = "x" * content_length
    assistant_padding = "y" * content_length
    history = [{"role": "system", "content": "You are a helpful assistant."}]
    for turn in range(count):
        history += [
            {"role": "user", "content": f"{user_padding} message {turn}"},
            {"role": "assistant", "content": f"{assistant_padding} reply {turn}"},
        ]
    return history
def make_mock_compressor() -> MagicMock:
    """Create a mock CompressionStrategy with canned async results."""
    mock_compressor = MagicMock(spec=CompressionStrategy)
    mock_compressor.is_available = MagicMock(return_value=True)
    mock_compressor.compress_tool_result = AsyncMock(
        return_value="compressed tool result"
    )
    compressed_history = [{"role": "user", "content": "compressed"}]
    mock_compressor.compress = AsyncMock(return_value=compressed_history)
    return mock_compressor
# ── TestBuildToolResultMessage ────────────────────────
class TestBuildToolResultMessage:
    """Tests for ReActEngine._build_tool_result_message."""

    async def test_no_compressor_returns_original(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())

        message = await engine._build_tool_result_message("tc_1", {"key": "value"})

        expected = {
            "role": "tool",
            "tool_call_id": "tc_1",
            "content": "{'key': 'value'}",
        }
        assert message == expected

    async def test_with_compressor_calls_compress_tool_result(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        compressor = make_mock_compressor()
        payload = {"key": "value"}

        message = await engine._build_tool_result_message(
            "tc_1", payload, compressor=compressor, tool_name="search"
        )

        compressor.compress_tool_result.assert_called_once_with("search", {"key": "value"})
        assert message["content"] == "compressed tool result"
        assert message["role"] == "tool"
        assert message["tool_call_id"] == "tc_1"

    async def test_compressor_failure_falls_back(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        compressor = make_mock_compressor()
        compressor.compress_tool_result = AsyncMock(
            side_effect=RuntimeError("compression error")
        )

        message = await engine._build_tool_result_message(
            "tc_1", {"key": "value"}, compressor=compressor, tool_name="search"
        )

        # The message must fall back to str(result).
        assert message["content"] == "{'key': 'value'}"
        assert message["role"] == "tool"

    async def test_compressor_receives_tool_name(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        compressor = make_mock_compressor()

        await engine._build_tool_result_message(
            "tc_1", "some result", compressor=compressor, tool_name="web_crawl"
        )

        compressor.compress_tool_result.assert_called_once_with("web_crawl", "some result")

    async def test_compressor_without_tool_name_skips_compression(self):
        """No compression happens when a compressor is given but tool_name is None."""
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        compressor = make_mock_compressor()

        message = await engine._build_tool_result_message(
            "tc_1", "some result", compressor=compressor, tool_name=None
        )

        compressor.compress_tool_result.assert_not_called()
        assert message["content"] == "some result"
# ── TestShouldCompress ───────────────────────────────
class TestShouldCompress:
    """Tests for the ReActEngine._should_compress helper."""

    def test_no_compressor_returns_false(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        huge = [{"role": "user", "content": "x" * 100000}]

        assert engine._should_compress(huge, None) is False

    def test_short_conversation_returns_false(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        compressor = make_mock_compressor()
        short = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
        ]

        assert engine._should_compress(short, compressor) is False

    def test_long_conversation_returns_true(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        compressor = make_mock_compressor()
        # 8000 tokens * 4 chars/token = 32000 chars are needed to reach the threshold.
        long_history = [{"role": "user", "content": "x" * 40000}]

        assert engine._should_compress(long_history, compressor) is True

    def test_boundary_at_threshold(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        compressor = make_mock_compressor()
        # Exactly 8000 tokens = 32000 chars -> must NOT trigger (> rather than >=).
        at_threshold = [{"role": "user", "content": "x" * 32000}]

        assert engine._should_compress(at_threshold, compressor) is False

    def test_just_above_threshold(self):
        engine = ReActEngine(llm_gateway=make_mock_gateway())
        compressor = make_mock_compressor()
        # 32001 chars -> 32001 // 4 = 8000 tokens, still not > 8000.
        # 32004 chars -> 32004 // 4 = 8001 tokens, which is > 8000.
        above_threshold = [{"role": "user", "content": "x" * 32004}]

        assert engine._should_compress(above_threshold, compressor) is True
# ── TestReActLoopCompression ─────────────────────────
class TestReActLoopCompression:
    """Integration tests for compression inside the ReAct loop.

    Incremental compression is the ``compressor.compress`` call made after a
    tool result has been appended, so these tests recognise it by the presence
    of a ``tool`` role message in the conversation at call time. Counting
    ``compress`` calls alone cannot tell it apart from any compression that
    runs before the loop starts.
    """

    @staticmethod
    def _make_search_tool(output: str) -> MagicMock:
        """Return a mock tool named ``search`` whose ``safe_execute`` yields ``output``."""
        tool = MagicMock()
        tool.name = "search"
        tool.safe_execute = AsyncMock(return_value=output)
        return tool

    @staticmethod
    def _has_tool_message(conversation) -> bool:
        """Return True when any message in ``conversation`` has role ``tool``."""
        return any(message.get("role") == "tool" for message in conversation)

    async def test_tool_results_compressed_in_loop(self):
        """Tool results are passed through compress_tool_result inside the loop."""
        compressor = make_mock_compressor()
        compressor.compress_tool_result = AsyncMock(return_value="COMPRESSED:search_result")
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())
        mock_tool = self._make_search_tool("original search result")

        await engine.execute(
            messages=[{"role": "user", "content": "Search for test"}],
            tools=[mock_tool],
            compressor=compressor,
        )

        compressor.compress_tool_result.assert_called_once_with(
            "search", "original search result"
        )

    async def test_incremental_compression_triggered(self):
        """A long conversation triggers compressor.compress() after a tool result."""
        compressor = make_mock_compressor()
        summary = [
            {"role": "system", "content": "Summary"},
            {"role": "user", "content": "recent"},
        ]
        incremental_calls = []

        async def compress_side_effect(messages, *args, **kwargs):
            # Inspect at call time: the engine may keep mutating the same list.
            if self._has_tool_message(messages):
                incremental_calls.append(len(messages))
                return summary
            # Pass anything else through unchanged so the conversation stays
            # above the threshold until the tool result has been appended.
            return messages

        compressor.compress = AsyncMock(side_effect=compress_side_effect)
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())
        mock_tool = self._make_search_tool("result")

        await engine.execute(
            messages=[{"role": "user", "content": "x" * 40000}],
            tools=[mock_tool],
            compressor=compressor,
        )

        assert incremental_calls, "compress() was never called after a tool result"

    async def test_incremental_compression_failure_handled(self):
        """The loop keeps going when incremental compressor.compress() raises."""
        compressor = make_mock_compressor()
        failures = []

        async def compress_side_effect(messages, *args, **kwargs):
            # Fail only the incremental call (a tool result is present); any
            # earlier compression passes the messages through unchanged.
            if self._has_tool_message(messages):
                failures.append(len(messages))
                raise RuntimeError("Incremental compression failed")
            return messages

        compressor.compress = AsyncMock(side_effect=compress_side_effect)
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())
        mock_tool = self._make_search_tool("result")

        # Must not raise even though the incremental compression fails.
        result = await engine.execute(
            messages=[{"role": "user", "content": "x" * 40000}],
            tools=[mock_tool],
            compressor=compressor,
        )

        assert failures, "the failing incremental compression path was never exercised"
        assert result.output == "Final answer after tool"

    async def test_no_compressor_backward_compatible(self):
        """With compressor=None the engine behaves as it did before compression existed."""
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())
        mock_tool = self._make_search_tool("search result data")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Search for test"}],
            tools=[mock_tool],
            compressor=None,
        )

        assert result.output == "Final answer after tool"
        assert result.status == "success"
        assert len(result.trajectory) == 2  # 1 tool_call + 1 final_answer

    async def test_execute_stream_with_compressor(self):
        """Tool result compression also works in execute_stream mode."""
        compressor = make_mock_compressor()
        compressor.compress_tool_result = AsyncMock(return_value="COMPRESSED:result")
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())
        mock_tool = self._make_search_tool("original result")

        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "Search for test"}],
            tools=[mock_tool],
            compressor=compressor,
        ):
            events.append(event)

        compressor.compress_tool_result.assert_called_once_with("search", "original result")
        # The event stream must still contain every stage.
        event_types = [e.event_type for e in events]
        assert "thinking" in event_types
        assert "tool_call" in event_types
        assert "tool_result" in event_types
        assert "final_answer" in event_types

    async def test_execute_stream_incremental_compression(self):
        """Incremental compression is triggered in execute_stream mode."""
        compressor = make_mock_compressor()
        summary = [
            {"role": "system", "content": "Summary"},
            {"role": "user", "content": "recent"},
        ]
        incremental_calls = []

        async def compress_side_effect(messages, *args, **kwargs):
            # Inspect at call time: the engine may keep mutating the same list.
            if self._has_tool_message(messages):
                incremental_calls.append(len(messages))
                return summary
            return messages

        compressor.compress = AsyncMock(side_effect=compress_side_effect)
        engine = ReActEngine(llm_gateway=make_mock_gateway_with_tool_call())
        mock_tool = self._make_search_tool("result")

        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "x" * 40000}],
            tools=[mock_tool],
            compressor=compressor,
        ):
            events.append(event)

        assert incremental_calls, "compress() was never called after a tool result"

    async def test_context_compressor_satisfies_protocol(self):
        """ContextCompressor satisfies the CompressionStrategy protocol."""
        compressor = ContextCompressor()
        assert isinstance(compressor, CompressionStrategy)

    async def test_context_compressor_compress_tool_result_default(self):
        """ContextCompressor.compress_tool_result returns str(result) by default."""
        compressor = ContextCompressor()
        result = await compressor.compress_tool_result("search", {"key": "value"})
        assert result == "{'key': 'value'}"