"""Tests for ReAct engine compression integration (U3)""" from unittest.mock import AsyncMock, MagicMock, patch import pytest from agentkit.core.compressor import CompressionStrategy, ContextCompressor from agentkit.core.react import ReActEngine from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall # ── Helpers ────────────────────────────────────────── def make_mock_gateway() -> MagicMock: """创建一个 mock LLMGateway""" from agentkit.llm.gateway import LLMGateway gateway = MagicMock(spec=LLMGateway) response = LLMResponse( content="Final answer", model="test-model", usage=TokenUsage(prompt_tokens=10, completion_tokens=10), ) gateway.chat = AsyncMock(return_value=response) return gateway def make_mock_gateway_with_tool_call() -> MagicMock: """创建一个返回 tool_call 的 mock LLMGateway,第二次调用返回最终答案 同时设置 chat 和 chat_stream,使 execute 和 execute_stream 路径都能正常工作。 """ from agentkit.llm.gateway import LLMGateway gateway = MagicMock(spec=LLMGateway) # 第一次调用返回 tool_call,第二次返回最终答案 tool_response = LLMResponse( content="", model="test-model", usage=TokenUsage(prompt_tokens=10, completion_tokens=10), tool_calls=[ ToolCall(id="call_1", name="search", arguments={"query": "test"}), ], ) final_response = LLMResponse( content="Final answer after tool", model="test-model", usage=TokenUsage(prompt_tokens=10, completion_tokens=10), ) gateway.chat = AsyncMock(side_effect=[tool_response, final_response]) # ponytail: chat_stream yields StreamChunk equivalents of the chat responses # so execute_stream (which uses chat_stream) exercises the same tool path. tool_chunk = StreamChunk( content="", model="test-model", tool_calls=[ToolCall(id="call_1", name="search", arguments={"query": "test"})], usage=TokenUsage(prompt_tokens=10, completion_tokens=10), is_final=True, ) final_chunk = StreamChunk( content="Final answer after tool", model="test-model", usage=TokenUsage(prompt_tokens=10, completion_tokens=10), is_final=True, ) async def _stream(**kwargs): # Closure state tracks which response to yield (1st call=tool, 2nd=final) _stream._call_count = getattr(_stream, "_call_count", 0) + 1 if _stream._call_count == 1: yield tool_chunk else: yield final_chunk gateway.chat_stream = _stream return gateway def make_long_messages(count: int = 10, content_length: int = 2000) -> list[dict]: """生成长消息列表用于测试压缩""" messages = [{"role": "system", "content": "You are a helpful assistant."}] for i in range(count): messages.append({ "role": "user", "content": "x" * content_length + f" message {i}", }) messages.append({ "role": "assistant", "content": "y" * content_length + f" reply {i}", }) return messages def make_mock_compressor() -> MagicMock: """创建一个 mock CompressionStrategy""" compressor = MagicMock(spec=CompressionStrategy) compressor.compress = AsyncMock(return_value=[{"role": "user", "content": "compressed"}]) compressor.compress_tool_result = AsyncMock(return_value="compressed tool result") compressor.is_available = MagicMock(return_value=True) return compressor # ── TestBuildToolResultMessage ──────────────────────── class TestBuildToolResultMessage: """_build_tool_result_message 方法测试""" async def test_no_compressor_returns_original(self): engine = ReActEngine(llm_gateway=make_mock_gateway()) result = await engine._build_tool_result_message("tc_1", {"key": "value"}) assert result == { "role": "tool", "tool_call_id": "tc_1", "content": "{'key': 'value'}", } async def test_with_compressor_calls_compress_tool_result(self): compressor = make_mock_compressor() engine = ReActEngine(llm_gateway=make_mock_gateway()) result = await engine._build_tool_result_message( "tc_1", {"key": "value"}, compressor=compressor, tool_name="search" ) compressor.compress_tool_result.assert_called_once_with("search", {"key": "value"}) assert result["content"] == "compressed tool result" assert result["role"] == "tool" assert result["tool_call_id"] == "tc_1" async def test_compressor_failure_falls_back(self): compressor = make_mock_compressor() compressor.compress_tool_result = AsyncMock(side_effect=RuntimeError("compression error")) engine = ReActEngine(llm_gateway=make_mock_gateway()) result = await engine._build_tool_result_message( "tc_1", {"key": "value"}, compressor=compressor, tool_name="search" ) # 应该回退到 str(result) assert result["content"] == "{'key': 'value'}" assert result["role"] == "tool" async def test_compressor_receives_tool_name(self): compressor = make_mock_compressor() engine = ReActEngine(llm_gateway=make_mock_gateway()) await engine._build_tool_result_message( "tc_1", "some result", compressor=compressor, tool_name="web_crawl" ) compressor.compress_tool_result.assert_called_once_with("web_crawl", "some result") async def test_compressor_without_tool_name_skips_compression(self): """compressor 存在但 tool_name 为 None 时不压缩""" compressor = make_mock_compressor() engine = ReActEngine(llm_gateway=make_mock_gateway()) result = await engine._build_tool_result_message( "tc_1", "some result", compressor=compressor, tool_name=None ) compressor.compress_tool_result.assert_not_called() assert result["content"] == "some result" # ── TestShouldCompress ─────────────────────────────── class TestShouldCompress: """_should_compress 辅助方法测试""" def test_no_compressor_returns_false(self): engine = ReActEngine(llm_gateway=make_mock_gateway()) conversation = [{"role": "user", "content": "x" * 100000}] assert engine._should_compress(conversation, None) is False def test_short_conversation_returns_false(self): compressor = make_mock_compressor() engine = ReActEngine(llm_gateway=make_mock_gateway()) conversation = [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there"}, ] assert engine._should_compress(conversation, compressor) is False def test_long_conversation_returns_true(self): compressor = make_mock_compressor() engine = ReActEngine(llm_gateway=make_mock_gateway()) # 8000 tokens * 4 chars/token = 32000 chars needed conversation = [{"role": "user", "content": "x" * 40000}] assert engine._should_compress(conversation, compressor) is True def test_boundary_at_threshold(self): compressor = make_mock_compressor() engine = ReActEngine(llm_gateway=make_mock_gateway()) # Exactly 8000 tokens = 32000 chars → should NOT trigger (> not >=) conversation = [{"role": "user", "content": "x" * 32000}] assert engine._should_compress(conversation, compressor) is False def test_just_above_threshold(self): compressor = make_mock_compressor() engine = ReActEngine(llm_gateway=make_mock_gateway()) # 32001 chars → 32001//4 = 8000 tokens, still not > 8000 # 32004 chars → 32004//4 = 8001 tokens, > 8000 conversation = [{"role": "user", "content": "x" * 32004}] assert engine._should_compress(conversation, compressor) is True # ── TestReActLoopCompression ───────────────────────── class TestReActLoopCompression: """ReAct 循环内压缩集成测试""" async def test_tool_results_compressed_in_loop(self): """工具结果在循环中被压缩后拼入 conversation""" compressor = make_mock_compressor() compressor.compress_tool_result = AsyncMock(return_value="COMPRESSED:search_result") gateway = make_mock_gateway_with_tool_call() engine = ReActEngine(llm_gateway=gateway) # 注册一个 mock tool mock_tool = MagicMock() mock_tool.name = "search" mock_tool.safe_execute = AsyncMock(return_value="original search result") result = await engine.execute( messages=[{"role": "user", "content": "Search for test"}], tools=[mock_tool], compressor=compressor, ) # 验证 compress_tool_result 被调用 compressor.compress_tool_result.assert_called_once_with("search", "original search result") async def test_incremental_compression_triggered(self): """长对话触发增量压缩 compressor.compress()""" compressor = make_mock_compressor() # 让 compress 返回压缩后的短对话 compressor.compress = AsyncMock(return_value=[ {"role": "system", "content": "Summary"}, {"role": "user", "content": "recent"}, ]) gateway = make_mock_gateway_with_tool_call() engine = ReActEngine(llm_gateway=gateway) # 构造一个很长的初始消息,使 conversation 超过阈值 long_content = "x" * 40000 mock_tool = MagicMock() mock_tool.name = "search" mock_tool.safe_execute = AsyncMock(return_value="result") result = await engine.execute( messages=[{"role": "user", "content": long_content}], tools=[mock_tool], compressor=compressor, ) # 验证增量压缩被触发(compress 被调用了至少一次) # 初始 compress 在 L218-222,增量 compress 在工具结果后 assert compressor.compress.call_count >= 1 async def test_incremental_compression_failure_handled(self): """compressor.compress() 异常时循环继续""" compressor = make_mock_compressor() # 第一次 compress 调用成功(初始压缩),增量压缩时失败 call_count = 0 async def compress_side_effect(messages): nonlocal call_count call_count += 1 if call_count > 1: raise RuntimeError("Incremental compression failed") return messages compressor.compress = AsyncMock(side_effect=compress_side_effect) gateway = make_mock_gateway_with_tool_call() engine = ReActEngine(llm_gateway=gateway) # 构造长消息触发增量压缩 long_content = "x" * 40000 mock_tool = MagicMock() mock_tool.name = "search" mock_tool.safe_execute = AsyncMock(return_value="result") # 不应该抛出异常 result = await engine.execute( messages=[{"role": "user", "content": long_content}], tools=[mock_tool], compressor=compressor, ) # 应该正常返回结果 assert result.output == "Final answer after tool" async def test_no_compressor_backward_compatible(self): """compressor=None 时行为与之前完全一致""" gateway = make_mock_gateway_with_tool_call() engine = ReActEngine(llm_gateway=gateway) mock_tool = MagicMock() mock_tool.name = "search" mock_tool.safe_execute = AsyncMock(return_value="search result data") result = await engine.execute( messages=[{"role": "user", "content": "Search for test"}], tools=[mock_tool], compressor=None, ) assert result.output == "Final answer after tool" assert result.status == "success" assert len(result.trajectory) == 2 # 1 tool_call + 1 final_answer async def test_execute_stream_with_compressor(self): """execute_stream 模式下压缩正常工作""" compressor = make_mock_compressor() compressor.compress_tool_result = AsyncMock(return_value="COMPRESSED:result") gateway = make_mock_gateway_with_tool_call() engine = ReActEngine(llm_gateway=gateway) mock_tool = MagicMock() mock_tool.name = "search" mock_tool.safe_execute = AsyncMock(return_value="original result") events = [] async for event in engine.execute_stream( messages=[{"role": "user", "content": "Search for test"}], tools=[mock_tool], compressor=compressor, ): events.append(event) # 验证 compress_tool_result 被调用 compressor.compress_tool_result.assert_called_once_with("search", "original result") # 验证事件流完整 event_types = [e.event_type for e in events] assert "thinking" in event_types assert "tool_call" in event_types assert "tool_result" in event_types assert "final_answer" in event_types async def test_execute_stream_incremental_compression(self): """execute_stream 模式下增量压缩触发""" compressor = make_mock_compressor() compressor.compress = AsyncMock(return_value=[ {"role": "system", "content": "Summary"}, {"role": "user", "content": "recent"}, ]) gateway = make_mock_gateway_with_tool_call() engine = ReActEngine(llm_gateway=gateway) long_content = "x" * 40000 mock_tool = MagicMock() mock_tool.name = "search" mock_tool.safe_execute = AsyncMock(return_value="result") events = [] async for event in engine.execute_stream( messages=[{"role": "user", "content": long_content}], tools=[mock_tool], compressor=compressor, ): events.append(event) # 验证增量压缩被触发 assert compressor.compress.call_count >= 1 async def test_context_compressor_satisfies_protocol(self): """ContextCompressor 满足 CompressionStrategy Protocol""" compressor = ContextCompressor() assert isinstance(compressor, CompressionStrategy) async def test_context_compressor_compress_tool_result_default(self): """ContextCompressor.compress_tool_result 默认返回 str(result)""" compressor = ContextCompressor() result = await compressor.compress_tool_result("search", {"key": "value"}) assert result == "{'key': 'value'}" # ── TestOTelSpanLifecycle (P0 fix: span leak) ────────────────── class TestOTelSpanLifecycle: """测试 OTel span 生命周期 — 异常时 span 必须正确关闭""" async def test_span_closed_on_success(self): """正常执行时 span 被正确关闭""" gateway = make_mock_gateway() engine = ReActEngine(llm_gateway=gateway) mock_span = MagicMock() mock_span_cm = MagicMock() mock_span_cm.__enter__ = MagicMock(return_value=mock_span) mock_span_cm.__exit__ = MagicMock(return_value=False) with patch("agentkit.core.react.start_span", return_value=mock_span_cm), \ patch("agentkit.core.react._OTEL_AVAILABLE", True): await engine.execute(messages=[{"role": "user", "content": "hello"}]) # __exit__ should have been called mock_span_cm.__exit__.assert_called_once() async def test_span_closed_on_exception(self): """LLM 抛出异常时 span 仍被正确关闭""" gateway = make_mock_gateway() gateway.chat = AsyncMock(side_effect=RuntimeError("LLM error")) engine = ReActEngine(llm_gateway=gateway) mock_span = MagicMock() mock_span_cm = MagicMock() mock_span_cm.__enter__ = MagicMock(return_value=mock_span) mock_span_cm.__exit__ = MagicMock(return_value=False) with patch("agentkit.core.react.start_span", return_value=mock_span_cm), \ patch("agentkit.core.react._OTEL_AVAILABLE", True): with pytest.raises(RuntimeError, match="LLM error"): await engine.execute(messages=[{"role": "user", "content": "hello"}]) # __exit__ must have been called even though exception was raised mock_span_cm.__exit__.assert_called_once() async def test_span_attributes_set_on_success(self): """正常执行时 span 属性被设置""" gateway = make_mock_gateway() engine = ReActEngine(llm_gateway=gateway) mock_span = MagicMock() mock_span_cm = MagicMock() mock_span_cm.__enter__ = MagicMock(return_value=mock_span) mock_span_cm.__exit__ = MagicMock(return_value=False) with patch("agentkit.core.react.start_span", return_value=mock_span_cm), \ patch("agentkit.core.react._OTEL_AVAILABLE", True): await engine.execute(messages=[{"role": "user", "content": "hello"}]) # Verify span attributes were set mock_span.set_attribute.assert_any_call("agent.total_steps", 1) mock_span.set_attribute.assert_any_call("agent.total_tokens", 20) mock_span.set_attribute.assert_any_call("agent.outcome", "success") async def test_no_span_when_otel_unavailable(self): """_OTEL_AVAILABLE=False 时不创建 span""" gateway = make_mock_gateway() engine = ReActEngine(llm_gateway=gateway) with patch("agentkit.core.react._OTEL_AVAILABLE", False), \ patch("agentkit.core.react.start_span") as mock_start_span: await engine.execute(messages=[{"role": "user", "content": "hello"}]) # start_span should not be called when OTel is unavailable mock_start_span.assert_not_called()