# fischer-agentkit/tests/integration/test_rewoo_configurable_fal...
# (Header of a web file view: 414 lines, 14 KiB, Python.)
# The viewer warned that this file contains ambiguous Unicode characters
# that might be confused with other characters.

"""集成测试 - ReWOOEngine 可配置回退链
测试 fallback_strategies 参数控制下的自定义回退顺序。
仅 mock LLMGateway(外部 API),使用真实 ReWOOEngine 实例。
"""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.rewoo import ReWOOEngine
from agentkit.core.react import ReActResult, ReActEvent
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.tools.base import Tool
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class FakeTool(Tool):
    """Fake tool used by the tests.

    Returns a fixed result (or raises) and counts how many times it was
    executed, so tests can assert whether the engine actually invoked it.
    """

    def __init__(
        self,
        name: str = "fake_tool",
        description: str = "A fake tool for testing",
        result: dict | None = None,
        should_fail: bool = False,
    ):
        """Create the fake tool.

        Args:
            name: Tool name passed to the ``Tool`` base class.
            description: Tool description passed to the ``Tool`` base class.
            result: Value returned by ``execute``. ``None`` selects the
                default ``{"status": "ok"}``; an explicit empty dict is kept.
            should_fail: When true, ``execute`` raises ``RuntimeError``.
        """
        super().__init__(name=name, description=description)
        # Compare against None instead of using `or`, so an intentionally
        # empty result ({}) is not silently replaced by the default.
        self._result = {"status": "ok"} if result is None else result
        self._should_fail = should_fail
        self.call_count = 0

    async def execute(self, **kwargs) -> dict:
        """Record the call, then return the configured result or raise.

        Raises:
            RuntimeError: If the tool was created with ``should_fail=True``.
        """
        self.call_count += 1
        if self._should_fail:
            raise RuntimeError(f"Tool '{self.name}' execution failed")
        return self._result
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build a canned ``LLMResponse`` for the mocked gateway.

    Args:
        content: Text content of the response.
        tool_calls: Tool calls to attach; a falsy value yields ``[]``.
        prompt_tokens: Reported prompt token count.
        completion_tokens: Reported completion token count.

    Returns:
        An ``LLMResponse`` whose model is always ``"test-model"``.
    """
    usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model="test-model",
        usage=usage,
        tool_calls=calls,
    )
def make_plan_response(
    steps: list[dict],
    reasoning: str = "Plan reasoning",
) -> LLMResponse:
    """Build a response whose content is a JSON-encoded ReWOO plan.

    The encoded object has a ``reasoning`` string followed by the ``steps``
    list, and is wrapped via ``make_response``.
    """
    payload = {"reasoning": reasoning, "steps": steps}
    return make_response(content=json.dumps(payload))
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Return an ``LLMGateway``-spec'd mock whose ``chat`` replays *responses*.

    Each awaited ``chat`` call yields the next item of *responses* in order
    (``AsyncMock`` ``side_effect`` iterable semantics).
    """
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = AsyncMock(side_effect=responses)
    return mock_gateway
# ---------------------------------------------------------------------------
# Test 1: Custom fallback ["plan_exec", "react"] — planning 失败 → plan_exec 失败 → react 成功
# ---------------------------------------------------------------------------
class TestCustomFallbackPlanExecReact:
    """Custom fallback chain ["plan_exec", "react"]."""

    @pytest.mark.asyncio
    async def test_plan_exec_fails_react_succeeds(self):
        """Planning fails → plan_exec fails → react succeeds."""
        search_tool = FakeTool(name="search", result={"results": ["found"]})
        # Scripted LLM replies, consumed in order:
        #   1. initial planning  -> not valid plan JSON
        #   2. plan_exec planning -> also not valid plan JSON
        #   3. ReAct turn 1       -> calls the "search" tool
        #   4. ReAct turn 2       -> final answer
        scripted = [
            make_response(content="Cannot create a plan"),
            make_response(content="Also cannot plan"),
            make_response(
                content="",
                tool_calls=[
                    ToolCall(id="tc_1", name="search", arguments={"query": "test"}),
                ],
            ),
            make_response(content="ReAct fallback answer"),
        ]
        engine = ReWOOEngine(
            llm_gateway=make_mock_gateway(scripted),
            fallback_strategies=["plan_exec", "react"],
        )

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Complex task"}],
            tools=[search_tool],
        )

        assert isinstance(outcome, ReActResult)
        assert outcome.output == "ReAct fallback answer"
        assert outcome.fallback_strategy == "react"
        assert search_tool.call_count == 1

    @pytest.mark.asyncio
    async def test_plan_exec_succeeds(self):
        """Planning fails → plan_exec succeeds."""
        search_tool = FakeTool(name="search", result={"results": ["found"]})
        # Scripted LLM replies, consumed in order:
        #   1. initial planning   -> not valid plan JSON
        #   2. plan_exec planning -> a one-step plan calling "search"
        #   3. synthesis          -> final answer
        plan_step = {
            "step_id": 1,
            "tool_name": "search",
            "arguments": {"query": "test"},
            "reasoning": "Search",
        }
        scripted = [
            make_response(content="Cannot create a plan"),
            make_plan_response([plan_step]),
            make_response(content="Plan-exec result"),
        ]
        engine = ReWOOEngine(
            llm_gateway=make_mock_gateway(scripted),
            fallback_strategies=["plan_exec", "react"],
        )

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Complex task"}],
            tools=[search_tool],
        )

        assert outcome.output == "Plan-exec result"
        assert outcome.fallback_strategy == "plan_exec"
        assert search_tool.call_count == 1
# ---------------------------------------------------------------------------
# Test 2: Custom fallback ["direct"] — planning 失败 → direct 成功
# ---------------------------------------------------------------------------
class TestCustomFallbackDirect:
    """Custom fallback chain ["direct"]."""

    @pytest.mark.asyncio
    async def test_direct_fallback_succeeds(self):
        """Planning fails → a direct LLM call succeeds."""
        invalid_plan = make_response(content="Cannot plan this task")
        direct_response = make_response(content="Direct LLM answer")
        gateway = make_mock_gateway([invalid_plan, direct_response])
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["direct"],
        )

        result = await engine.execute(
            messages=[{"role": "user", "content": "Complex task"}],
        )

        assert isinstance(result, ReActResult)
        assert result.output == "Direct LLM answer"
        assert result.fallback_strategy == "direct"
        # One planning call plus one direct call; no extra LLM round-trips.
        assert gateway.chat.await_count == 2

    @pytest.mark.asyncio
    async def test_direct_fallback_with_system_prompt(self):
        """The direct fallback forwards ``system_prompt`` to the LLM."""
        system_prompt = "You are a helpful assistant"
        invalid_plan = make_response(content="Cannot plan")
        direct_response = make_response(content="Answer with context")
        gateway = make_mock_gateway([invalid_plan, direct_response])
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["direct"],
        )

        result = await engine.execute(
            messages=[{"role": "user", "content": "Complex task"}],
            system_prompt=system_prompt,
        )

        assert result.output == "Answer with context"
        assert result.fallback_strategy == "direct"
        # Checking only the output cannot show the system prompt was used.
        # Verify it actually reaches the last (direct-fallback) chat call,
        # without depending on how the engine shapes that call's arguments.
        assert system_prompt in str(gateway.chat.await_args)
# ---------------------------------------------------------------------------
# Test 3: Default fallback — 所有策略失败 → RuntimeError
# ---------------------------------------------------------------------------
class TestDefaultFallbackAllFail:
    """With the default fallback chain, RuntimeError is raised once every strategy fails."""

    @pytest.mark.asyncio
    async def test_all_fallback_strategies_exhausted(self):
        """All fallback strategies exhausted → RuntimeError."""
        # Default fallback chain: simplified_rewoo → react → direct.
        #   call 1:  initial planning            -> invalid plan (not JSON)
        #   call 2:  simplified_rewoo planning   -> invalid plan again
        #   call 3+: every later LLM call (react, direct) raises, so the
        #            remaining strategies fail as well.
        call_count = 0

        class FailingGateway:
            """Minimal gateway: two non-plan replies, then every call raises."""

            def __init__(self):
                self.chat = AsyncMock(side_effect=self._fail_after_second)

            async def _fail_after_second(self, **kwargs):
                nonlocal call_count
                call_count += 1
                if call_count == 1:
                    return make_response(content="Not a plan")
                if call_count == 2:
                    return make_response(content="Still not a plan")
                raise RuntimeError("LLM service unavailable")

        engine = ReWOOEngine(llm_gateway=FailingGateway())

        with pytest.raises(RuntimeError, match="All ReWOO fallback strategies exhausted"):
            await engine.execute(
                messages=[{"role": "user", "content": "Impossible task"}],
            )
        # The error must come from walking the whole chain: both canned
        # replies were consumed and at least one later (failing) call was made.
        assert call_count >= 3

    def test_default_fallback_strategies_order(self):
        """The default fallback order is simplified_rewoo → react → direct."""
        # Pure constructor check: nothing is awaited, so no coroutine/asyncio
        # marker is needed.
        engine = ReWOOEngine(
            llm_gateway=MagicMock(spec=LLMGateway),
        )
        assert engine._fallback_strategies == ["simplified_rewoo", "react", "direct"]
# ---------------------------------------------------------------------------
# Test 4: Stream mode with custom fallback
# ---------------------------------------------------------------------------
class TestStreamModeCustomFallback:
    """Custom fallback chains in streaming mode."""

    @pytest.mark.asyncio
    async def test_stream_direct_fallback(self):
        """Streaming: planning fails → direct fallback."""
        invalid_plan = make_response(content="Cannot plan")
        direct_response = make_response(content="Stream direct answer")
        gateway = make_mock_gateway([invalid_plan, direct_response])
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["direct"],
        )

        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "Stream task"}],
        ):
            events.append(event)

        # Expect both a planning event and a final_answer event.
        event_types = [e.event_type for e in events]
        assert "planning" in event_types
        assert "final_answer" in event_types
        # The final answer comes from the direct fallback.
        final_events = [e for e in events if e.event_type == "final_answer"]
        assert len(final_events) >= 1
        assert final_events[0].data["output"] == "Stream direct answer"

    @pytest.mark.asyncio
    async def test_stream_react_fallback(self):
        """Streaming: planning fails → react fallback runs the tool and answers."""
        from agentkit.llm.protocol import StreamChunk

        tool = FakeTool(name="search", result={"results": ["found"]})
        invalid_plan = make_response(content="Cannot plan")

        # ReAct streaming iterates gateway.chat_stream(...) with `async for`,
        # one call per ReAct turn. Script two turns: the first requests the
        # "search" tool, later turns return the final answer. A stream that
        # yields both chunks on every call would put a tool call in every
        # turn, so ReAct would never see a plain answer.
        stream_calls = 0

        async def mock_chat_stream(**kwargs):
            nonlocal stream_calls
            stream_calls += 1
            if stream_calls == 1:
                # Turn 1: tool call.
                yield StreamChunk(
                    content="",
                    model="test-model",
                    tool_calls=[
                        ToolCall(id="tc_1", name="search", arguments={"query": "test"}),
                    ],
                    usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
                )
                return
            # Turn 2+: final answer.
            yield StreamChunk(
                content="Stream react answer",
                model="test-model",
                tool_calls=[],
                usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
                is_final=True,
            )

        gateway = MagicMock(spec=LLMGateway)
        gateway.chat = AsyncMock(side_effect=[invalid_plan])
        gateway.chat_stream = mock_chat_stream
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["react"],
        )

        events = []
        async for event in engine.execute_stream(
            messages=[{"role": "user", "content": "Stream task"}],
            tools=[tool],
        ):
            events.append(event)

        event_types = [e.event_type for e in events]
        assert "planning" in event_types
        assert "final_answer" in event_types
        # The react fallback actually executed the tool once...
        assert tool.call_count == 1
        # ...and surfaced the second turn's answer (key-agnostic check).
        final_events = [e for e in events if e.event_type == "final_answer"]
        assert any("Stream react answer" in str(e.data) for e in final_events)

    @pytest.mark.asyncio
    async def test_stream_all_fallbacks_exhausted(self):
        """Streaming: all fallback strategies exhausted → RuntimeError."""
        call_count = 0

        class FailingStreamGateway:
            """Gateway whose first ``chat`` returns a non-plan and every other call fails."""

            def __init__(self):
                self.chat = AsyncMock(side_effect=self._respond_or_fail)
                # chat_stream is consumed with `async for`, so it must be an
                # async generator function. An AsyncMock would return a
                # coroutine; `async for` over it raises TypeError (leaving an
                # un-awaited coroutine) instead of the intended RuntimeError.
                self.chat_stream = self._stream_fail

            async def _respond_or_fail(self, **kwargs):
                nonlocal call_count
                call_count += 1
                if call_count == 1:
                    return make_response(content="Not a plan")
                raise RuntimeError("LLM unavailable")

            async def _stream_fail(self, **kwargs):
                raise RuntimeError("Stream unavailable")
                yield  # unreachable; the `yield` makes this an async generator

        gateway = FailingStreamGateway()
        engine = ReWOOEngine(llm_gateway=gateway)

        with pytest.raises(RuntimeError, match="All ReWOO fallback strategies exhausted"):
            async for _ in engine.execute_stream(
                messages=[{"role": "user", "content": "Impossible task"}],
            ):
                pass
# ---------------------------------------------------------------------------
# Test 5: Invalid strategy validation
# ---------------------------------------------------------------------------
class TestFallbackStrategyValidation:
    """Validation of the ``fallback_strategies`` argument."""

    @pytest.mark.asyncio
    async def test_invalid_strategy_skipped(self):
        """Unknown strategies are skipped; valid ones still run."""
        invalid_plan = make_response(content="Cannot plan")
        direct_response = make_response(content="Direct answer")
        gateway = make_mock_gateway([invalid_plan, direct_response])
        # "invalid_strategy" should be skipped and "direct" should run.
        engine = ReWOOEngine(
            llm_gateway=gateway,
            fallback_strategies=["invalid_strategy", "direct"],
        )
        # The unknown name is dropped, leaving only "direct".
        assert engine._fallback_strategies == ["direct"]

        result = await engine.execute(
            messages=[{"role": "user", "content": "Task"}],
        )

        assert result.output == "Direct answer"
        assert result.fallback_strategy == "direct"

    def test_all_invalid_strategies_uses_defaults(self):
        """When every strategy is invalid, the default chain is used instead."""
        # Pure constructor check: nothing is awaited, so no coroutine/asyncio
        # marker is needed.
        engine = ReWOOEngine(
            llm_gateway=MagicMock(spec=LLMGateway),
            fallback_strategies=["invalid1", "invalid2"],
        )
        assert engine._fallback_strategies == ["simplified_rewoo", "react", "direct"]