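# Listing metadata: 342 lines, 12 KiB, Python
"""Integration tests for parallel tool execution in ReActEngine.

Covers the parallel, serial and mixed execution modes selected through the
``parallel_tools`` setting.
Only LLMGateway (the external API) is mocked; a real ReActEngine instance is used.
"""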
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import time
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.react import ReActEngine, ReActResult, ReActEvent
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||
from agentkit.tools.base import Tool
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class SlowTool(Tool):
    """Fake tool that sleeps before answering, used to observe execution overlap.

    Every call to ``execute`` bumps ``call_count`` and appends the monotonic
    start time to ``call_times``, so tests can check both how often and when
    the tool was started.
    """

    def __init__(
        self,
        name: str = "slow_tool",
        description: str = "A slow tool for testing parallel execution",
        delay: float = 0.1,
        result: dict | None = None,
    ):
        """Configure the fake tool.

        Args:
            name: Tool name forwarded to ``Tool.__init__``.
            description: Tool description forwarded to ``Tool.__init__``.
            delay: Seconds to sleep inside ``execute``.
            result: Payload returned by ``execute``; ``{"status": "ok"}`` when
                omitted.
        """
        super().__init__(name=name, description=description)
        self._delay = delay
        # Compare against None instead of relying on truthiness: an explicitly
        # supplied empty dict is a legitimate payload and must not be replaced
        # by the default.
        self._result = {"status": "ok"} if result is None else result
        # Number of times execute() has been entered.
        self.call_count = 0
        # time.monotonic() reading taken at the start of each execute() call.
        self.call_times: list[float] = []

    async def execute(self, **kwargs) -> dict:
        """Record the call, sleep for the configured delay, return the payload.

        The keyword arguments are accepted and ignored.
        """
        self.call_count += 1
        self.call_times.append(time.monotonic())
        await asyncio.sleep(self._delay)
        return self._result
|
||
|
||
|
||
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build an ``LLMResponse`` for the fixed model name ``"test-model"``.

    Args:
        content: Text content of the response.
        tool_calls: Tool calls carried by the response; a falsy value
            (``None`` or an empty list) becomes a fresh empty list.
        prompt_tokens: Prompt token count placed in the usage record.
        completion_tokens: Completion token count placed in the usage record.

    Returns:
        The assembled ``LLMResponse``.
    """
    if not tool_calls:
        tool_calls = []
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    return LLMResponse(
        content=content,
        model="test-model",
        usage=token_usage,
        tool_calls=tool_calls,
    )
|
||
|
||
|
||
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Return a ``MagicMock`` specced on ``LLMGateway`` with a scripted ``chat``.

    ``chat`` is an ``AsyncMock`` whose ``side_effect`` is ``responses``, so
    each awaited call yields the next entry of that list in order.
    """
    scripted_chat = AsyncMock(side_effect=responses)
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = scripted_chat
    return mock_gateway
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test 1: parallel_tools=True — all tools run in parallel
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestParallelToolsTrue:
    """Checks that ``parallel_tools=True`` runs the tool calls of one turn in parallel."""

    @pytest.mark.asyncio
    async def test_all_tools_execute_in_parallel(self):
        """Two slow tools requested in one turn must overlap in time."""
        delay = 0.15
        tool_a = SlowTool(name="tool_a", delay=delay)
        tool_b = SlowTool(name="tool_b", delay=delay)

        # The scripted LLM first asks for both tools, then gives the final answer.
        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
            ],
        )
        final_response = make_response(content="Done")

        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools=True)

        start = time.monotonic()
        result = await engine.execute(
            messages=[{"role": "user", "content": "Run both tools"}],
            tools=[tool_a, tool_b],
        )
        elapsed = time.monotonic() - start

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
        assert tool_b.call_count == 1

        # Running the tools back to back costs about 2 * delay (0.30s). The
        # bound must therefore not exceed that total, otherwise a fully
        # serial engine would still satisfy it.
        serial_total = 2 * delay
        assert elapsed < serial_total, f"Parallel execution too slow: {elapsed:.2f}s"

        # Start timestamps give a load-independent signal: tools started
        # together differ by far less than one delay, while a tool started
        # after the other one finished lags by roughly a full delay.
        start_gap = abs(tool_a.call_times[0] - tool_b.call_times[0])
        assert start_gap < delay / 2, (
            f"Parallel tools didn't start together: diff={start_gap:.3f}s"
        )

    @pytest.mark.asyncio
    async def test_parallel_result_in_trajectory(self):
        """Both parallel tool calls must show up as steps in the trajectory."""
        tool_a = SlowTool(name="tool_a", delay=0.05, result={"value": "A"})
        tool_b = SlowTool(name="tool_b", delay=0.05, result={"value": "B"})

        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
            ],
        )
        final_response = make_response(content="Combined result")

        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools=True)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Run both"}],
            tools=[tool_a, tool_b],
        )

        # Only the steps recorded as tool calls are of interest here.
        tool_steps = [s for s in result.trajectory if s.action == "tool_call"]
        assert len(tool_steps) == 2
        names = {s.tool_name for s in tool_steps}
        assert names == {"tool_a", "tool_b"}
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test 2: parallel_tools="auto" + LLM marks _parallelizable=true -> parallel
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestParallelToolsAuto:
    """Checks that ``parallel_tools="auto"`` follows the ``_parallelizable`` argument flag."""

    @pytest.mark.asyncio
    async def test_auto_parallelizable_tools_execute_in_parallel(self):
        """Tool calls flagged ``_parallelizable=True`` must overlap in auto mode."""
        delay = 0.15
        tool_a = SlowTool(name="tool_a", delay=delay)
        tool_b = SlowTool(name="tool_b", delay=delay)

        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1, "_parallelizable": True}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2, "_parallelizable": True}),
            ],
        )
        final_response = make_response(content="Auto parallel done")

        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")

        start = time.monotonic()
        result = await engine.execute(
            messages=[{"role": "user", "content": "Run in parallel"}],
            tools=[tool_a, tool_b],
        )
        elapsed = time.monotonic() - start

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
        assert tool_b.call_count == 1

        # Running the tools back to back costs about 2 * delay (0.30s). The
        # bound must therefore not exceed that total, otherwise a fully
        # serial engine would still satisfy it.
        serial_total = 2 * delay
        assert elapsed < serial_total, f"Auto parallel execution too slow: {elapsed:.2f}s"

        # Start timestamps give a load-independent signal: tools started
        # together differ by far less than one delay.
        start_gap = abs(tool_a.call_times[0] - tool_b.call_times[0])
        assert start_gap < delay / 2, (
            f"Parallel tools didn't start together: diff={start_gap:.3f}s"
        )

    @pytest.mark.asyncio
    async def test_auto_mixed_parallelizable_serial_then_parallel(self):
        """With mixed flags the unflagged tool starts first, the flagged ones start together."""
        tool_serial = SlowTool(name="tool_serial", delay=0.1)
        tool_para_a = SlowTool(name="tool_para_a", delay=0.1)
        tool_para_b = SlowTool(name="tool_para_b", delay=0.1)

        tool_call_response = make_response(
            content="",
            tool_calls=[
                # Unflagged call: expected to run serially.
                ToolCall(id="tc_1", name="tool_serial", arguments={"x": 1}),
                # Flagged calls: expected to run in parallel.
                ToolCall(id="tc_2", name="tool_para_a", arguments={"y": 2, "_parallelizable": True}),
                ToolCall(id="tc_3", name="tool_para_b", arguments={"z": 3, "_parallelizable": True}),
            ],
        )
        final_response = make_response(content="Mixed result")

        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Mixed execution"}],
            tools=[tool_serial, tool_para_a, tool_para_b],
        )

        assert isinstance(result, ReActResult)
        assert tool_serial.call_count == 1
        assert tool_para_a.call_count == 1
        assert tool_para_b.call_count == 1

        # The serial tool must have been started before the parallel group.
        assert tool_serial.call_times[0] < tool_para_a.call_times[0]

        # The two flagged tools must have been started at nearly the same time.
        para_diff = abs(tool_para_a.call_times[0] - tool_para_b.call_times[0])
        assert para_diff < 0.05, f"Parallel tools didn't start together: diff={para_diff:.3f}s"

    @pytest.mark.asyncio
    async def test_auto_no_parallelizable_tools_execute_serially(self):
        """Without any ``_parallelizable`` flag, auto mode must run the tools serially."""
        tool_a = SlowTool(name="tool_a", delay=0.1)
        tool_b = SlowTool(name="tool_b", delay=0.1)

        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
            ],
        )
        final_response = make_response(content="Serial done")

        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")

        start = time.monotonic()
        result = await engine.execute(
            messages=[{"role": "user", "content": "Run serially"}],
            tools=[tool_a, tool_b],
        )
        elapsed = time.monotonic() - start

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
        assert tool_b.call_count == 1
        # Serial run: about 2 * 0.1s in total, so well above one and a half delays.
        assert elapsed >= 0.15, f"Serial execution unexpectedly fast: {elapsed:.2f}s"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test 3: parallel_tools=False — everything runs serially (default behavior)
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestParallelToolsFalse:
    """Checks that ``parallel_tools=False`` runs the tool calls one after another."""

    @pytest.mark.asyncio
    async def test_all_tools_execute_serially(self):
        """Explicit ``parallel_tools=False``: two tools take about two delays in total."""
        tool_a = SlowTool(name="tool_a", delay=0.1)
        tool_b = SlowTool(name="tool_b", delay=0.1)

        # Scripted LLM turns: one turn requesting both tools, then the answer.
        requested_calls = [
            ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
            ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
        ]
        scripted_turns = [
            make_response(content="", tool_calls=requested_calls),
            make_response(content="Serial result"),
        ]
        engine = ReActEngine(
            llm_gateway=make_mock_gateway(scripted_turns),
            parallel_tools=False,
        )

        user_messages = [{"role": "user", "content": "Run serially"}]
        began = time.monotonic()
        result = await engine.execute(messages=user_messages, tools=[tool_a, tool_b])
        elapsed = time.monotonic() - began

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
        assert tool_b.call_count == 1
        # Serial run: about 2 * 0.1s in total, so well above one and a half delays.
        assert elapsed >= 0.15, f"Serial execution unexpectedly fast: {elapsed:.2f}s"

    @pytest.mark.asyncio
    async def test_default_is_serial(self):
        """An engine built without ``parallel_tools`` is expected to run serially."""
        tool_a = SlowTool(name="tool_a", delay=0.1)
        tool_b = SlowTool(name="tool_b", delay=0.1)

        requested_calls = [
            ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
            ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
        ]
        scripted_turns = [
            make_response(content="", tool_calls=requested_calls),
            make_response(content="Default serial"),
        ]
        # No parallel_tools argument: this exercises the constructor default.
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted_turns))

        user_messages = [{"role": "user", "content": "Default mode"}]
        began = time.monotonic()
        result = await engine.execute(messages=user_messages, tools=[tool_a, tool_b])
        elapsed = time.monotonic() - began

        assert isinstance(result, ReActResult)
        assert elapsed >= 0.15, f"Default mode should be serial, but was fast: {elapsed:.2f}s"

    @pytest.mark.asyncio
    async def test_single_tool_call_is_serial(self):
        """A single tool call runs exactly once even with ``parallel_tools=True``."""
        tool_a = SlowTool(name="tool_a", delay=0.05)

        requested_calls = [ToolCall(id="tc_1", name="tool_a", arguments={"x": 1})]
        scripted_turns = [
            make_response(content="", tool_calls=requested_calls),
            make_response(content="Single tool done"),
        ]
        engine = ReActEngine(
            llm_gateway=make_mock_gateway(scripted_turns),
            parallel_tools=True,
        )

        result = await engine.execute(
            messages=[{"role": "user", "content": "Single tool"}],
            tools=[tool_a],
        )

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
|