294 lines
11 KiB
Python
294 lines
11 KiB
Python
"""集成测试 - ReWOO 渐进式回退链
|
||
|
||
测试 ReWOOEngine 的 planning → simplified_rewoo → react → direct 回退策略。
|
||
仅 mock LLMGateway(外部 API),使用真实 ReWOOEngine 实例。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.rewoo import ReWOOEngine, ReWOOStep
|
||
from agentkit.core.react import ReActResult, ReActStep
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||
from agentkit.tools.base import Tool
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class FakeTool(Tool):
    """Fake tool for tests: counts invocations and returns a canned result.

    Args:
        name: Tool name forwarded to ``Tool.__init__``.
        description: Tool description forwarded to ``Tool.__init__``.
        result: Payload returned by ``execute``. The default
            ``{"status": "ok"}`` is used only when this is ``None``; an
            explicitly supplied empty dict is returned as-is.
        should_fail: When true, ``execute`` raises ``RuntimeError`` instead
            of returning the payload.
    """

    def __init__(
        self,
        name: str = "fake_tool",
        description: str = "A fake tool for testing",
        result: dict | None = None,
        should_fail: bool = False,
    ):
        super().__init__(
            name=name,
            description=description,
        )
        # Compare against None instead of relying on truthiness so that a
        # caller passing ``result={}`` really gets an empty dict back.
        self._result = {"status": "ok"} if result is None else result
        self._should_fail = should_fail
        self.call_count = 0

    async def execute(self, **kwargs) -> dict:
        """Record the call, then raise or return the canned result.

        The call counter is incremented before the failure check, so failed
        invocations are counted as well.

        Raises:
            RuntimeError: If the tool was created with ``should_fail=True``.
        """
        self.call_count += 1
        if self._should_fail:
            # NOTE(review): ``self.name`` is presumably set by Tool.__init__.
            raise RuntimeError(f"Tool '{self.name}' execution failed")
        return self._result
|
||
|
||
|
||
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build an ``LLMResponse`` for the fixed model name ``"test-model"``.

    Args:
        content: Text content of the response.
        tool_calls: Tool calls to attach; a falsy value becomes an empty list.
        prompt_tokens: Prompt token count stored in the usage record.
        completion_tokens: Completion token count stored in the usage record.

    Returns:
        The constructed ``LLMResponse``.
    """
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    attached_calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model="test-model",
        usage=token_usage,
        tool_calls=attached_calls,
    )
|
||
|
||
|
||
def make_plan_response(
    steps: list[dict],
    reasoning: str = "Plan reasoning",
) -> LLMResponse:
    """Wrap a plan (reasoning plus steps) as the JSON content of a response.

    Args:
        steps: Plan step dictionaries, serialized under the ``"steps"`` key.
        reasoning: Text serialized under the ``"reasoning"`` key.

    Returns:
        The result of ``make_response`` with the JSON document as content.
    """
    plan_payload = {"reasoning": reasoning, "steps": steps}
    serialized_plan = json.dumps(plan_payload)
    return make_response(content=serialized_plan)
|
||
|
||
|
||
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Create a ``MagicMock`` specced on ``LLMGateway`` with a scripted ``chat``.

    Args:
        responses: Values handed to ``AsyncMock(side_effect=...)``, so each
            awaited ``chat`` call yields the next item in order.

    Returns:
        The mock gateway.
    """
    scripted_chat = AsyncMock(side_effect=responses)
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = scripted_chat
    return mock_gateway
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Test 1: Planning succeeds → no fallback
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestReWOOPlanningSucceeds:
    """Planning succeeds, so no fallback is used."""

    @pytest.mark.asyncio
    async def test_planning_succeeds_no_fallback(self):
        """A valid one-step plan runs the tool once and reports no fallback."""
        calculator = FakeTool(name="calculator", result={"value": 42})
        planned_steps = [
            {
                "step_id": 1,
                "tool_name": "calculator",
                "arguments": {"expr": "6*7"},
                "reasoning": "Calculate",
            },
        ]
        # Scripted gateway replies, in the order the mock hands them out:
        # the plan first, then the synthesis text.
        scripted_replies = [
            make_plan_response(planned_steps),
            make_response(content="The result is 42"),
        ]
        engine = ReWOOEngine(llm_gateway=make_mock_gateway(scripted_replies))

        user_messages = [{"role": "user", "content": "Calculate 6*7"}]
        result = await engine.execute(messages=user_messages, tools=[calculator])

        assert isinstance(result, ReActResult)
        assert result.output == "The result is 42"
        assert result.fallback_strategy is None
        assert result.status == "success"
        # 1 tool_call + 1 final_answer = 2 steps
        assert result.total_steps == 2
        assert calculator.call_count == 1

        # The single tool-call entry must be a ReWOOStep carrying plan_step_id.
        tool_steps = [
            step for step in result.trajectory if step.action == "tool_call"
        ]
        assert len(tool_steps) == 1
        first_tool_step = tool_steps[0]
        assert isinstance(first_tool_step, ReWOOStep)
        assert first_tool_step.plan_step_id == 1
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Test 2: Planning fails → simplified planning succeeds → fallback_strategy="simplified_rewoo"
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestReWOOSimplifiedFallback:
    """Planning fails, simplified planning succeeds."""

    @pytest.mark.asyncio
    async def test_planning_fails_simplified_succeeds(self):
        """An unparsable first plan followed by a valid one yields ``simplified_rewoo``."""
        search_tool = FakeTool(name="search", result={"results": ["found"]})
        simplified_steps = [
            {
                "step_id": 1,
                "tool_name": "search",
                "arguments": {"query": "test"},
                "reasoning": "Simplified search",
            },
        ]
        # First reply is not JSON, the second (simplified) is a valid plan,
        # and the third is the synthesis text.
        scripted_replies = [
            make_response(content="I cannot create a plan for this task."),
            make_plan_response(simplified_steps),
            make_response(content="Simplified result"),
        ]
        engine = ReWOOEngine(llm_gateway=make_mock_gateway(scripted_replies))

        user_messages = [{"role": "user", "content": "Complex task"}]
        result = await engine.execute(messages=user_messages, tools=[search_tool])

        assert result.output == "Simplified result"
        assert result.fallback_strategy == "simplified_rewoo"
        assert search_tool.call_count == 1

        # The trajectory must still contain exactly one well-formed tool call.
        tool_steps = [
            step for step in result.trajectory if step.action == "tool_call"
        ]
        assert len(tool_steps) == 1
        assert tool_steps[0].tool_name == "search"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Test 3: All planning fails → ReAct fallback → fallback_strategy="react"
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestReWOOReActFallback:
    """All planning attempts fail, so the engine falls back to ReAct."""

    @pytest.mark.asyncio
    async def test_planning_and_simplified_fail_react_succeeds(self):
        """Two non-JSON plan replies followed by a plain answer yield ``react``."""
        # Both plan attempts fail (invalid JSON), ReAct succeeds
        scripted_replies = [
            make_response(content="Not a plan"),
            make_response(content="Still not a plan"),
            make_response(content="ReAct fallback answer"),
        ]
        engine = ReWOOEngine(llm_gateway=make_mock_gateway(scripted_replies))

        user_messages = [{"role": "user", "content": "Complex task"}]
        result = await engine.execute(messages=user_messages)

        assert result.output == "ReAct fallback answer"
        assert result.fallback_strategy == "react"

    @pytest.mark.asyncio
    async def test_react_fallback_with_tool_calls(self):
        """ReAct fallback that performs a tool call before answering."""
        search_tool = FakeTool(name="search", result={"results": ["found"]})
        search_call = ToolCall(id="tc_1", name="search", arguments={"query": "test"})
        # Two failed plans, then a tool-call reply, then the final answer.
        scripted_replies = [
            make_response(content="Cannot plan"),
            make_response(content="Still cannot plan"),
            make_response(content="", tool_calls=[search_call]),
            make_response(content="ReAct answer with tool"),
        ]
        engine = ReWOOEngine(llm_gateway=make_mock_gateway(scripted_replies))

        user_messages = [{"role": "user", "content": "Search task"}]
        result = await engine.execute(messages=user_messages, tools=[search_tool])

        assert result.output == "ReAct answer with tool"
        assert result.fallback_strategy == "react"
        assert search_tool.call_count == 1

    @pytest.mark.asyncio
    async def test_malformed_json_triggers_react_fallback(self):
        """Malformed JSON triggers the ReAct fallback."""
        truncated_plan = '{"reasoning": "plan", "steps": [invalid json'
        scripted_replies = [
            make_response(content=truncated_plan),
            make_response(content="Also not a plan"),
            make_response(content="ReAct answer"),
        ]
        engine = ReWOOEngine(llm_gateway=make_mock_gateway(scripted_replies))

        user_messages = [{"role": "user", "content": "Task"}]
        result = await engine.execute(messages=user_messages)

        assert result.output == "ReAct answer"
        assert result.fallback_strategy == "react"

    @pytest.mark.asyncio
    async def test_missing_steps_key_triggers_react_fallback(self):
        """A plan without the ``steps`` key triggers the ReAct fallback."""
        plan_without_steps = '{"reasoning": "no steps here"}'
        scripted_replies = [
            make_response(content=plan_without_steps),
            make_response(content="Also no steps"),
            make_response(content="ReAct fallback"),
        ]
        engine = ReWOOEngine(llm_gateway=make_mock_gateway(scripted_replies))

        user_messages = [{"role": "user", "content": "Task"}]
        result = await engine.execute(messages=user_messages)

        assert result.output == "ReAct fallback"
        assert result.fallback_strategy == "react"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Test: Multi-step plan with fallback chain integration
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestReWOOMultiStepWithFallbackChain:
    """Integration tests for multi-step plans and the fallback chain."""

    @pytest.mark.asyncio
    async def test_three_step_plan_no_fallback(self):
        """A three-step plan succeeds without any fallback."""
        search_tool = FakeTool(name="search", result={"results": ["Python is great"]})
        calc_tool = FakeTool(name="calculator", result={"value": 100})
        weather_tool = FakeTool(name="weather", result={"temp": 25, "city": "Shanghai"})

        plan_response = make_plan_response([
            {"step_id": 1, "tool_name": "search", "arguments": {"query": "Python"}, "reasoning": "Search first"},
            {"step_id": 2, "tool_name": "calculator", "arguments": {"expr": "10*10"}, "reasoning": "Calculate"},
            {"step_id": 3, "tool_name": "weather", "arguments": {"city": "Shanghai"}, "reasoning": "Check weather"},
        ])
        synthesis_response = make_response(content="Based on search, calculation (100), and weather (25°C)")

        # The mock gateway hands out the plan first, then the synthesis text.
        gateway = make_mock_gateway([plan_response, synthesis_response])
        engine = ReWOOEngine(llm_gateway=gateway)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Search, calculate and check weather"}],
            tools=[search_tool, calc_tool, weather_tool],
        )

        assert result.fallback_strategy is None
        assert result.total_steps == 4  # 3 tool_calls + 1 final_answer
        # Each planned tool must have been invoked exactly once.
        assert search_tool.call_count == 1
        assert calc_tool.call_count == 1
        assert weather_tool.call_count == 1
        assert "100" in result.output
        assert "25" in result.output

    def test_fallback_strategies_constant(self):
        """Verify the ``FALLBACK_STRATEGIES`` constant.

        This is a plain synchronous test: it only compares a class attribute
        and awaits nothing, so no asyncio marker or coroutine is needed.
        """
        assert ReWOOEngine.FALLBACK_STRATEGIES == ["simplified_rewoo", "react", "direct"]
|