414 lines
14 KiB
Python
414 lines
14 KiB
Python
"""集成测试 - ReWOOEngine 可配置回退链
|
||
|
||
测试 fallback_strategies 参数控制下的自定义回退顺序。
|
||
仅 mock LLMGateway(外部 API),使用真实 ReWOOEngine 实例。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.rewoo import ReWOOEngine
|
||
from agentkit.core.react import ReActResult, ReActEvent
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||
from agentkit.tools.base import Tool
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class FakeTool(Tool):
    """Fake tool used in tests.

    Counts how many times ``execute`` is called and either returns a fixed
    result or raises ``RuntimeError``.
    """

    def __init__(
        self,
        name: str = "fake_tool",
        description: str = "A fake tool for testing",
        result: dict | None = None,
        should_fail: bool = False,
    ):
        """Create the fake tool.

        Args:
            name: Tool name, forwarded to ``Tool.__init__``.
            description: Tool description, forwarded to ``Tool.__init__``.
            result: Value returned by ``execute``. ``None`` selects the default
                ``{"status": "ok"}``; any other value (including an empty
                dict) is returned exactly as given.
            should_fail: When true, ``execute`` raises ``RuntimeError``.
        """
        super().__init__(name=name, description=description)
        # Compare against None explicitly: ``result or default`` would silently
        # swap an intentionally empty dict for the default payload.
        self._result = {"status": "ok"} if result is None else result
        self._should_fail = should_fail
        # Number of execute() invocations, including ones that raise.
        self.call_count = 0

    async def execute(self, **kwargs) -> dict:
        """Record the call, then raise or return the configured result.

        Raises:
            RuntimeError: if the tool was created with ``should_fail=True``.
        """
        self.call_count += 1
        if self._should_fail:
            raise RuntimeError(f"Tool '{self.name}' execution failed")
        return self._result
|
||
|
||
|
||
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build a canned ``LLMResponse`` attributed to the model ``test-model``.

    A falsy ``tool_calls`` (e.g. ``None``) is replaced by an empty list.
    """
    usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model="test-model",
        usage=usage,
        tool_calls=calls,
    )
|
||
|
||
|
||
def make_plan_response(
    steps: list[dict],
    reasoning: str = "Plan reasoning",
) -> LLMResponse:
    """Return an LLM response whose content is a JSON plan (reasoning + steps)."""
    payload = {"reasoning": reasoning, "steps": steps}
    serialized = json.dumps(payload)
    return make_response(content=serialized)
|
||
|
||
|
||
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Return an ``LLMGateway`` mock whose ``chat`` coroutine replays ``responses``.

    Each await of ``gateway.chat`` yields the next item of ``responses`` in order.
    """
    chat = AsyncMock(side_effect=responses)
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = chat
    return mock_gateway
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Test 1: Custom fallback ["plan_exec", "react"] — planning 失败,plan_exec 失败,react 成功
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestCustomFallbackPlanExecReact:
    """Custom fallback chain ["plan_exec", "react"]."""

    @pytest.mark.asyncio
    async def test_plan_exec_fails_react_succeeds(self):
        """Planning fails -> plan_exec fails -> react succeeds."""
        search_tool = FakeTool(name="search", result={"results": ["found"]})

        scripted = [
            # Initial ReWOO planning: not valid plan JSON.
            make_response(content="Cannot create a plan"),
            # plan_exec fallback: its planning call is unusable as well.
            make_response(content="Also cannot plan"),
            # react fallback, turn 1: request the search tool.
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="search", arguments={"query": "test"})],
            ),
            # react fallback, turn 2: final answer.
            make_response(content="ReAct fallback answer"),
        ]
        engine = ReWOOEngine(
            llm_gateway=make_mock_gateway(scripted),
            fallback_strategies=["plan_exec", "react"],
        )

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Complex task"}],
            tools=[search_tool],
        )

        assert isinstance(outcome, ReActResult)
        assert outcome.output == "ReAct fallback answer"
        assert outcome.fallback_strategy == "react"
        assert search_tool.call_count == 1

    @pytest.mark.asyncio
    async def test_plan_exec_succeeds(self):
        """Planning fails -> plan_exec succeeds."""
        search_tool = FakeTool(name="search", result={"results": ["found"]})

        one_step_plan = make_plan_response([
            {"step_id": 1, "tool_name": "search", "arguments": {"query": "test"}, "reasoning": "Search"},
        ])
        gateway = make_mock_gateway([
            make_response(content="Cannot create a plan"),  # initial planning fails
            one_step_plan,                                   # plan_exec planning succeeds
            make_response(content="Plan-exec result"),       # synthesis
        ])
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["plan_exec", "react"],
        )

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Complex task"}],
            tools=[search_tool],
        )

        assert outcome.output == "Plan-exec result"
        assert outcome.fallback_strategy == "plan_exec"
        assert search_tool.call_count == 1
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Test 2: Custom fallback ["direct"] — planning 失败,direct 成功
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestCustomFallbackDirect:
    """Custom fallback chain ["direct"]."""

    @pytest.mark.asyncio
    async def test_direct_fallback_succeeds(self):
        """Planning fails -> a direct LLM call succeeds."""
        invalid_plan = make_response(content="Cannot plan this task")
        direct_response = make_response(content="Direct LLM answer")

        gateway = make_mock_gateway([invalid_plan, direct_response])
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["direct"],
        )

        result = await engine.execute(
            messages=[{"role": "user", "content": "Complex task"}],
        )

        assert isinstance(result, ReActResult)
        assert result.output == "Direct LLM answer"
        assert result.fallback_strategy == "direct"

    @pytest.mark.asyncio
    async def test_direct_fallback_with_system_prompt(self):
        """The direct fallback forwards ``system_prompt`` to the LLM."""
        system_prompt = "You are a helpful assistant"
        invalid_plan = make_response(content="Cannot plan")
        direct_response = make_response(content="Answer with context")

        gateway = make_mock_gateway([invalid_plan, direct_response])
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["direct"],
        )

        result = await engine.execute(
            messages=[{"role": "user", "content": "Complex task"}],
            system_prompt=system_prompt,
        )

        assert result.output == "Answer with context"
        assert result.fallback_strategy == "direct"
        # Actually verify what the docstring claims: the last awaited gateway
        # call (the direct fallback) must carry the system prompt text. The
        # check searches the repr of its arguments so it does not depend on
        # the exact message schema the engine uses.
        assert system_prompt in repr(gateway.chat.await_args)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Test 3: Default fallback — 所有策略失败 → RuntimeError
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestDefaultFallbackAllFail:
    """RuntimeError is raised when every strategy in the default chain fails."""

    @pytest.mark.asyncio
    async def test_all_fallback_strategies_exhausted(self):
        """All fallback strategies exhausted -> RuntimeError."""
        # Default fallback chain: simplified_rewoo -> react -> direct.
        # Gateway script:
        #   call 1 - initial planning: returns invalid plan JSON
        #   call 2 - simplified_rewoo planning: returns invalid plan JSON
        #   call 3+ - every later call (react, direct) raises
        # so no strategy can produce an answer.
        call_count = 0

        class FailingGateway:
            def __init__(self):
                self.chat = AsyncMock(side_effect=self._fail_after_second)

            async def _fail_after_second(self, **kwargs):
                nonlocal call_count
                call_count += 1
                if call_count == 1:
                    # First call: planning fails (invalid JSON)
                    return make_response(content="Not a plan")
                if call_count == 2:
                    # Second call: simplified_rewoo planning also fails
                    return make_response(content="Still not a plan")
                # All subsequent calls fail
                raise RuntimeError("LLM service unavailable")

        gateway = FailingGateway()
        engine = ReWOOEngine(llm_gateway=gateway)

        with pytest.raises(RuntimeError, match="All ReWOO fallback strategies exhausted"):
            await engine.execute(
                messages=[{"role": "user", "content": "Impossible task"}],
            )

        # Guard against a vacuous pass: the strategies after simplified_rewoo
        # must actually have reached the (now failing) gateway.
        assert call_count >= 3

    @pytest.mark.asyncio
    async def test_default_fallback_strategies_order(self):
        """The default fallback order is simplified_rewoo -> react -> direct."""
        engine = ReWOOEngine(
            llm_gateway=MagicMock(spec=LLMGateway),
        )
        assert engine._fallback_strategies == ["simplified_rewoo", "react", "direct"]
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Test 4: Stream mode with custom fallback
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestStreamModeCustomFallback:
    """Custom fallback chains in streaming mode."""

    @pytest.mark.asyncio
    async def test_stream_direct_fallback(self):
        """Streaming: planning fails -> direct fallback."""
        invalid_plan = make_response(content="Cannot plan")
        direct_response = make_response(content="Stream direct answer")

        gateway = make_mock_gateway([invalid_plan, direct_response])
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["direct"],
        )

        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "Stream task"}],
        ):
            events.append(event)

        # Expect both a planning event and a final_answer event.
        event_types = [e.event_type for e in events]
        assert "planning" in event_types
        assert "final_answer" in event_types

        # The final answer comes from the direct fallback.
        final_events = [e for e in events if e.event_type == "final_answer"]
        assert len(final_events) >= 1
        assert final_events[0].data["output"] == "Stream direct answer"

    @pytest.mark.asyncio
    async def test_stream_react_fallback(self):
        """Streaming: planning fails -> react fallback."""
        tool = FakeTool(name="search", result={"results": ["found"]})

        invalid_plan = make_response(content="Cannot plan")

        # ReAct stream mode needs chat_stream as an async generator
        from agentkit.llm.protocol import StreamChunk

        async def mock_chat_stream(**kwargs):
            # First chunk: tool call
            yield StreamChunk(
                content="",
                model="test-model",
                tool_calls=[ToolCall(id="tc_1", name="search", arguments={"query": "test"})],
                usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
            )
            # Second chunk: final answer
            yield StreamChunk(
                content="Stream react answer",
                model="test-model",
                tool_calls=[],
                usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
                is_final=True,
            )

        gateway = MagicMock(spec=LLMGateway)
        gateway.chat = AsyncMock(side_effect=[invalid_plan])
        gateway.chat_stream = mock_chat_stream

        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["react"],
        )

        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "Stream task"}],
            tools=[tool],
        ):
            events.append(event)

        event_types = [e.event_type for e in events]
        assert "planning" in event_types
        assert "final_answer" in event_types

    @pytest.mark.asyncio
    async def test_stream_all_fallbacks_exhausted(self):
        """Streaming: all fallback strategies exhausted -> RuntimeError."""
        call_count = 0

        class FailingStreamGateway:
            def __init__(self):
                self.chat = AsyncMock(side_effect=self._respond_or_fail)
                # chat_stream must be an async-generator function, like
                # mock_chat_stream in test_stream_react_fallback. An AsyncMock
                # returns a coroutine when called. If the engine consumes
                # chat_stream with ``async for``, that coroutine would fail with
                # TypeError (and a "never awaited" warning) instead of the
                # intended "Stream unavailable" error.
                self.chat_stream = self._stream_fail

            async def _respond_or_fail(self, **kwargs):
                nonlocal call_count
                call_count += 1
                if call_count == 1:
                    return make_response(content="Not a plan")
                raise RuntimeError("LLM unavailable")

            async def _stream_fail(self, **kwargs):
                raise RuntimeError("Stream unavailable")
                yield  # unreachable; its presence makes this an async generator

        gateway = FailingStreamGateway()
        engine = ReWOOEngine(llm_gateway=gateway)

        with pytest.raises(RuntimeError, match="All ReWOO fallback strategies exhausted"):
            async for _ in engine.execute_stream(
                messages=[{"role": "user", "content": "Impossible task"}],
            ):
                pass
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Test 5: Invalid strategy validation
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestFallbackStrategyValidation:
    """Validation of the ``fallback_strategies`` constructor argument."""

    @pytest.mark.asyncio
    async def test_invalid_strategy_skipped(self):
        """Unknown strategies are dropped; the valid ones still run."""
        gateway = make_mock_gateway([
            make_response(content="Cannot plan"),
            make_response(content="Direct answer"),
        ])
        # "invalid_strategy" should be discarded, leaving only "direct".
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["invalid_strategy", "direct"],
        )
        assert engine._fallback_strategies == ["direct"]

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Task"}],
        )

        assert outcome.output == "Direct answer"
        assert outcome.fallback_strategy == "direct"

    @pytest.mark.asyncio
    async def test_all_invalid_strategies_uses_defaults(self):
        """When every supplied strategy is invalid, the default chain is used."""
        expected_defaults = ["simplified_rewoo", "react", "direct"]
        engine = ReWOOEngine(
            llm_gateway=MagicMock(spec=LLMGateway),
            fallback_strategies=["invalid1", "invalid2"],
        )
        assert engine._fallback_strategies == expected_defaults
|