fischer-agentkit/tests/integration/test_parallel_tools.py

342 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""集成测试 - ReActEngine 并行工具执行
测试 parallel_tools 配置下的并行/串行/混合执行模式。
仅 mock LLMGateway(外部 API),使用真实 ReActEngine 实例。
"""
from __future__ import annotations
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.react import ReActEngine, ReActResult, ReActEvent
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.tools.base import Tool
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class SlowTool(Tool):
    """Fake tool that sleeps before returning, used to verify parallel execution.

    Every call increments ``call_count`` and appends its *start* timestamp
    (``time.monotonic()``) to ``call_times``. Tests compare these timestamps
    to decide whether concurrent calls overlapped.
    """

    def __init__(
        self,
        name: str = "slow_tool",
        description: str = "A slow tool for testing parallel execution",
        delay: float = 0.1,
        result: dict | None = None,
    ):
        super().__init__(name=name, description=description)
        self._delay = delay
        # Compare against None explicitly: `result or default` would silently
        # replace an intentionally empty result ({}) with the default.
        self._result = result if result is not None else {"status": "ok"}
        self.call_count = 0
        self.call_times: list[float] = []

    async def execute(self, **kwargs) -> dict:
        """Record the call, sleep ``delay`` seconds, then return the configured result.

        ``kwargs`` (the tool-call arguments) are accepted and ignored.
        """
        self.call_count += 1
        # Timestamp is taken before sleeping, i.e. when this call started.
        self.call_times.append(time.monotonic())
        await asyncio.sleep(self._delay)
        return self._result
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build a canned ``LLMResponse`` from the fixed model "test-model".

    A falsy ``tool_calls`` (None or an empty list) becomes an empty list, so
    the response carries no tool calls and reads as a final answer.
    """
    usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model="test-model",
        usage=usage,
        tool_calls=calls,
    )
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Return a ``LLMGateway``-spec'd mock whose async ``chat`` replays ``responses``.

    ``chat`` is an ``AsyncMock`` with ``responses`` as its ``side_effect``, so
    successive awaited calls yield the canned responses in order.
    """
    chat_mock = AsyncMock(side_effect=responses)
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = chat_mock
    return mock_gateway
# ---------------------------------------------------------------------------
# Test 1: parallel_tools=True — 所有工具并行执行
# ---------------------------------------------------------------------------
class TestParallelToolsTrue:
    """parallel_tools=True: all tool calls from one LLM turn run concurrently."""

    @pytest.mark.asyncio
    async def test_all_tools_execute_in_parallel(self):
        """Parallel execution must overlap the two calls instead of running them back to back."""
        delay = 0.15
        tool_a = SlowTool(name="tool_a", delay=delay)
        tool_b = SlowTool(name="tool_b", delay=delay)

        # The LLM returns 2 tool_calls in a single turn.
        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
            ],
        )
        final_response = make_response(content="Done")
        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools=True)

        start = time.monotonic()
        result = await engine.execute(
            messages=[{"role": "user", "content": "Run both tools"}],
            tools=[tool_a, tool_b],
        )
        elapsed = time.monotonic() - start

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
        assert tool_b.call_count == 1
        # Run serially, the second call would start only after the first had
        # slept for `delay` seconds; in parallel both start (almost) together.
        start_diff = abs(tool_a.call_times[0] - tool_b.call_times[0])
        assert start_diff < 0.05, f"Parallel tools didn't start together: diff={start_diff:.3f}s"
        # Serial execution needs at least 2 * delay (0.30s). The former 0.35s
        # bound was above that and so could not tell the two modes apart.
        assert elapsed < 2 * delay, f"Parallel execution too slow: {elapsed:.2f}s"

    @pytest.mark.asyncio
    async def test_parallel_result_in_trajectory(self):
        """Each parallel tool call is recorded as its own tool_call step in the trajectory."""
        tool_a = SlowTool(name="tool_a", delay=0.05, result={"value": "A"})
        tool_b = SlowTool(name="tool_b", delay=0.05, result={"value": "B"})

        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
            ],
        )
        final_response = make_response(content="Combined result")
        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools=True)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Run both"}],
            tools=[tool_a, tool_b],
        )

        tool_steps = [s for s in result.trajectory if s.action == "tool_call"]
        assert len(tool_steps) == 2
        # Completion order is not deterministic under parallelism, so compare as a set.
        names = {s.tool_name for s in tool_steps}
        assert names == {"tool_a", "tool_b"}
# ---------------------------------------------------------------------------
# Test 2: parallel_tools="auto" + LLM 标记 _parallelizable=true → 并行
# ---------------------------------------------------------------------------
class TestParallelToolsAuto:
    """parallel_tools="auto": the per-call ``_parallelizable`` argument decides parallelism."""

    @pytest.mark.asyncio
    async def test_auto_parallelizable_tools_execute_in_parallel(self):
        """In auto mode, calls marked _parallelizable=True overlap in time."""
        delay = 0.15
        tool_a = SlowTool(name="tool_a", delay=delay)
        tool_b = SlowTool(name="tool_b", delay=delay)

        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1, "_parallelizable": True}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2, "_parallelizable": True}),
            ],
        )
        final_response = make_response(content="Auto parallel done")
        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")

        start = time.monotonic()
        result = await engine.execute(
            messages=[{"role": "user", "content": "Run in parallel"}],
            tools=[tool_a, tool_b],
        )
        elapsed = time.monotonic() - start

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
        assert tool_b.call_count == 1
        # Run serially, the second call would start `delay` seconds after the first.
        start_diff = abs(tool_a.call_times[0] - tool_b.call_times[0])
        assert start_diff < 0.05, f"Parallel tools didn't start together: diff={start_diff:.3f}s"
        # Serial execution needs at least 2 * delay, so this bound (unlike the
        # former 0.35s) actually separates parallel from serial.
        assert elapsed < 2 * delay, f"Auto parallel execution too slow: {elapsed:.2f}s"

    @pytest.mark.asyncio
    async def test_auto_mixed_parallelizable_serial_then_parallel(self):
        """Mixed marks in auto mode: unmarked calls run first, then marked calls run together."""
        serial_delay = 0.1
        tool_serial = SlowTool(name="tool_serial", delay=serial_delay)
        tool_para_a = SlowTool(name="tool_para_a", delay=0.1)
        tool_para_b = SlowTool(name="tool_para_b", delay=0.1)

        tool_call_response = make_response(
            content="",
            tool_calls=[
                # Serial (unmarked) tool
                ToolCall(id="tc_1", name="tool_serial", arguments={"x": 1}),
                # Parallelizable tools
                ToolCall(id="tc_2", name="tool_para_a", arguments={"y": 2, "_parallelizable": True}),
                ToolCall(id="tc_3", name="tool_para_b", arguments={"z": 3, "_parallelizable": True}),
            ],
        )
        final_response = make_response(content="Mixed result")
        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Mixed execution"}],
            tools=[tool_serial, tool_para_a, tool_para_b],
        )

        assert isinstance(result, ReActResult)
        assert tool_serial.call_count == 1
        assert tool_para_a.call_count == 1
        assert tool_para_b.call_count == 1

        # The parallel batch must start only after the serial call has finished.
        # A bare `serial_start < para_start` check would also pass if all three
        # ran concurrently, because concurrent calls are still started in order.
        serial_start = tool_serial.call_times[0]
        first_para_start = min(tool_para_a.call_times[0], tool_para_b.call_times[0])
        serial_gap = first_para_start - serial_start
        assert serial_gap >= serial_delay / 2, (
            f"Parallel tools started before serial tool finished: gap={serial_gap:.3f}s"
        )

        # The parallel tools should start at (almost) the same time.
        para_diff = abs(tool_para_a.call_times[0] - tool_para_b.call_times[0])
        assert para_diff < 0.05, f"Parallel tools didn't start together: diff={para_diff:.3f}s"

    @pytest.mark.asyncio
    async def test_auto_no_parallelizable_tools_execute_serially(self):
        """In auto mode, calls without a _parallelizable mark run one after another."""
        delay = 0.1
        tool_a = SlowTool(name="tool_a", delay=delay)
        tool_b = SlowTool(name="tool_b", delay=delay)

        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
            ],
        )
        final_response = make_response(content="Serial done")
        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools="auto")

        start = time.monotonic()
        result = await engine.execute(
            messages=[{"role": "user", "content": "Run serially"}],
            tools=[tool_a, tool_b],
        )
        elapsed = time.monotonic() - start

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
        assert tool_b.call_count == 1
        # Serially, the second call starts only after the first slept `delay` seconds.
        start_gap = abs(tool_b.call_times[0] - tool_a.call_times[0])
        assert start_gap >= delay / 2, f"Serial tools overlapped: gap={start_gap:.3f}s"
        # Serial execution: total time should be close to twice the delay.
        assert elapsed >= 0.15, f"Serial execution unexpectedly fast: {elapsed:.2f}s"
# ---------------------------------------------------------------------------
# Test 3: parallel_tools=False — 全部串行(默认行为)
# ---------------------------------------------------------------------------
class TestParallelToolsFalse:
    """parallel_tools=False (and the default): every tool call runs serially."""

    @pytest.mark.asyncio
    async def test_all_tools_execute_serially(self):
        """With parallel_tools=False the two calls run back to back."""
        delay = 0.1
        tool_a = SlowTool(name="tool_a", delay=delay)
        tool_b = SlowTool(name="tool_b", delay=delay)

        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
            ],
        )
        final_response = make_response(content="Serial result")
        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools=False)

        start = time.monotonic()
        result = await engine.execute(
            messages=[{"role": "user", "content": "Run serially"}],
            tools=[tool_a, tool_b],
        )
        elapsed = time.monotonic() - start

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
        assert tool_b.call_count == 1
        # Serially, the second call starts only after the first slept `delay` seconds.
        start_gap = abs(tool_b.call_times[0] - tool_a.call_times[0])
        assert start_gap >= delay / 2, f"Serial tools overlapped: gap={start_gap:.3f}s"
        # Serial execution: total time should be close to twice the delay.
        assert elapsed >= 0.15, f"Serial execution unexpectedly fast: {elapsed:.2f}s"

    @pytest.mark.asyncio
    async def test_default_is_serial(self):
        """Omitting parallel_tools behaves like parallel_tools=False."""
        delay = 0.1
        tool_a = SlowTool(name="tool_a", delay=delay)
        tool_b = SlowTool(name="tool_b", delay=delay)

        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
                ToolCall(id="tc_2", name="tool_b", arguments={"y": 2}),
            ],
        )
        final_response = make_response(content="Default serial")
        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway)  # default parallel_tools=False

        start = time.monotonic()
        result = await engine.execute(
            messages=[{"role": "user", "content": "Default mode"}],
            tools=[tool_a, tool_b],
        )
        elapsed = time.monotonic() - start

        assert isinstance(result, ReActResult)
        # Both tools must actually run, as checked in the sibling tests.
        assert tool_a.call_count == 1
        assert tool_b.call_count == 1
        start_gap = abs(tool_b.call_times[0] - tool_a.call_times[0])
        assert start_gap >= delay / 2, f"Default mode tools overlapped: gap={start_gap:.3f}s"
        assert elapsed >= 0.15, f"Default mode should be serial, but was fast: {elapsed:.2f}s"

    @pytest.mark.asyncio
    async def test_single_tool_call_is_serial(self):
        """A single tool call runs exactly once even when parallel_tools=True."""
        tool_a = SlowTool(name="tool_a", delay=0.05)

        tool_call_response = make_response(
            content="",
            tool_calls=[
                ToolCall(id="tc_1", name="tool_a", arguments={"x": 1}),
            ],
        )
        final_response = make_response(content="Single tool done")
        gateway = make_mock_gateway([tool_call_response, final_response])
        engine = ReActEngine(llm_gateway=gateway, parallel_tools=True)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Single tool"}],
            tools=[tool_a],
        )

        assert isinstance(result, ReActResult)
        assert tool_a.call_count == 1
        # Exactly one tool_call step should be recorded for the single call.
        tool_steps = [s for s in result.trajectory if s.action == "tool_call"]
        assert len(tool_steps) == 1