# File: fischer-agentkit/tests/unit/test_reflexion_main_flow.py (654 lines, 23 KiB, Python)

"""U5/R4: Reflexion in main flow — verify fail -> reflect -> retry tests.
Extends the existing reinjection loop (U4) with LLM-generated reflection
after reinjections exhaust. Mirrors ReflexionEngine._reflect() call shape
but drives it from within ReActEngine's _execute_loop.
Test scenarios:
- AE1 happy path: verify fails -> reflect -> retry passes verify -> status "success"
- Edge: max_reflections=2 -> 2 retries -> gave_up_after_reflections
- Edge: _reset_loop_detector() between attempts preserves budgets
- Edge: reflect quota 0 -> no retry, return best result (verify_failed)
- Error: reflect LLM call fails -> skip reflection, retry with errors
- Error: all retries fail -> gave_up_after_reflections propagates
- Integration: DIRECT_CHAT/REACT unaffected (max_reflections=0 default)
- Integration: Recovery layer skips gave_up_after_reflections (no double-reflexion)
- Integration: RuleBasedReflector treats gave_up_after_reflections as failure
- Streaming: execute_stream follows the same verify fail -> reflect -> retry path
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.core.react import ReActEngine
from agentkit.core.verification_loop import VerificationResult
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
# ── Helpers (mirrors test_verify_reinjection.py) ──────────────
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Build an ``LLMGateway`` mock whose ``chat`` yields *responses* in order.

    Each await of ``chat`` consumes the next item of *responses*.
    ``get_provider_name_for_model`` is stubbed to return ``None``.
    """
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.get_provider_name_for_model = MagicMock(return_value=None)
    mock_gateway.chat = AsyncMock(side_effect=responses)
    return mock_gateway
def make_response(content: str = "") -> LLMResponse:
    """Return an ``LLMResponse`` carrying *content* and no tool calls.

    The model name and token usage are fixed placeholders.
    """
    usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
    return LLMResponse(
        content=content,
        model="test-model",
        usage=usage,
        tool_calls=[],
    )
def make_verify_result(passed: bool, errors: list[str] | None = None) -> VerificationResult:
    """Build a single-attempt ``VerificationResult``.

    Args:
        passed: Whether the simulated verification succeeded.
        errors: Error strings to report. ``None`` selects a default:
            ``[]`` when *passed*, otherwise ``["test_x.py::test_failed"]``.
            An explicitly supplied list (including an empty one) is used as-is.

    Returns:
        A ``VerificationResult`` with ``attempts=1`` and canned pytest output
        matching *passed*.
    """
    # Compare against None (not truthiness) so an explicit ``errors=[]`` is honored.
    if errors is None:
        errors = [] if passed else ["test_x.py::test_failed"]
    return VerificationResult(
        passed=passed,
        attempts=1,
        test_output="$ pytest\nOK" if passed else "$ pytest\nFAILED test_x.py",
        errors=errors,
    )
def make_mock_vloop(verify_results: list[VerificationResult]) -> MagicMock:
    """Return a stand-in for ``VerificationLoop``.

    Successive awaits of ``verify()`` yield the items of *verify_results*
    in order.
    """
    return MagicMock(verify=AsyncMock(side_effect=verify_results))
# ── AE1: Happy path — verify fail -> reflect -> retry passes ──
class TestReflexionHappyPath:
    """AE1: verify fails -> reflect -> retry within quota; retry passes verify."""

    @staticmethod
    def _messages_sent(call) -> list[dict]:
        """Return the ``messages`` recorded on one ``gateway.chat`` await.

        Prefers the ``messages`` keyword argument. Otherwise it falls back to
        the first positional argument, which is assumed to be ``messages``
        (TODO confirm against ``LLMGateway.chat``'s signature).
        """
        messages = call.kwargs.get("messages")
        if messages is None and call.args:
            messages = call.args[0]
        assert messages is not None, "gateway.chat was awaited without messages"
        return messages

    async def test_verify_fail_reflect_retry_passes(self):
        """verify fail -> reinjections exhausted -> reflect -> retry passes verify."""
        # gateway.chat awaits, in order: main attempt 1, reflection, main attempt 2.
        gateway = make_mock_gateway(
            [
                make_response("bad answer"),
                make_response("reflection: fix the bug"),
                make_response("good answer"),
            ]
        )
        engine = ReActEngine(
            llm_gateway=gateway,
            max_steps=10,
            verification_enabled=True,
            verification_commands=["pytest"],
            # No reinjections, so the first verify failure goes straight to reflection.
            max_reinjections=0,
            max_reflections=2,
        )
        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=make_mock_vloop(
                [
                    make_verify_result(passed=False, errors=["AssertionError"]),
                    make_verify_result(passed=True),
                ]
            ),
        ):
            result = await engine.execute(
                messages=[{"role": "user", "content": "write code"}],
            )
        # 3 chat calls: main1 + reflect + main2
        assert gateway.chat.await_count == 3
        assert result.output == "good answer"
        assert result.status == "success"
        assert engine._reflection_count == 1

    async def test_reflection_text_injected_into_conversation(self):
        """The reflection text appears in the conversation for the retry call."""
        gateway = make_mock_gateway(
            [
                make_response("bad"),
                make_response("you forgot to handle None"),
                make_response("good"),
            ]
        )
        engine = ReActEngine(
            llm_gateway=gateway,
            max_steps=10,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=0,
            max_reflections=2,
        )
        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=make_mock_vloop(
                [
                    make_verify_result(passed=False),
                    make_verify_result(passed=True),
                ]
            ),
        ):
            await engine.execute(
                messages=[{"role": "user", "content": "write code"}],
            )
        # The 3rd chat call (main attempt 2) must carry the reflection text.
        msgs_sent = self._messages_sent(gateway.chat.await_args_list[2])
        reflection_msgs = [
            m
            for m in msgs_sent
            # content may be missing or None on some messages; treat that as "".
            if "Reflection from Previous Attempt" in (m.get("content") or "")
        ]
        assert len(reflection_msgs) >= 1
        assert "you forgot to handle None" in reflection_msgs[-1]["content"]
# ── Edge: max_reflections=2 -> 2 retries -> gave_up_after_reflections ──
class TestReflexionExhaustion:
    """Reflection quota: retry until ``max_reflections`` is spent, then stop."""

    async def test_two_reflections_then_gave_up(self):
        """With max_reflections=2, two failed reflection retries end in
        gave_up_after_reflections."""
        # Five chat awaits alternate main attempts and reflections:
        # main1, reflect1, main2, reflect2, main3.
        scripted = ("bad1", "reflection1", "bad2", "reflection2", "bad3")
        gw = make_mock_gateway([make_response(text) for text in scripted])
        react = ReActEngine(
            llm_gateway=gw,
            max_steps=20,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=0,
            max_reflections=2,
        )
        every_verify_fails = make_mock_vloop(
            [make_verify_result(passed=False) for _ in range(3)]
        )
        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=every_verify_fails,
        ):
            outcome = await react.execute(
                messages=[{"role": "user", "content": "write code"}],
            )
        assert gw.chat.await_count == 5  # three main attempts + two reflections
        assert outcome.status == "gave_up_after_reflections"
        assert outcome.output == "bad3"
        assert react._reflection_count == 2

    async def test_reflect_quota_zero_no_retry(self):
        """With max_reflections=0, a failed verify returns verify_failed
        without any retry."""
        gw = make_mock_gateway([make_response("bad answer")])
        react = ReActEngine(
            llm_gateway=gw,
            max_steps=5,
            verification_enabled=True,
            verification_commands=["false"],
            max_reinjections=0,
            max_reflections=0,
        )
        single_failure = make_mock_vloop([make_verify_result(passed=False)])
        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=single_failure,
        ):
            outcome = await react.execute(
                messages=[{"role": "user", "content": "do something"}],
            )
        assert gw.chat.await_count == 1  # the main attempt only; no reflection call
        assert outcome.status == "verify_failed"
        assert outcome.output == "bad answer"
        assert react._reflection_count == 0
# ── Edge: _reset_loop_detector preserves budgets ──
class TestResetLoopDetectorPreservesBudgets:
    """_reset_loop_detector() between reflection attempts clears the loop window
    but preserves budget counters (KTD-9)."""

    async def test_loop_detector_reset_budgets_preserved(self):
        """Between reflection retries the loop detector is reset, but budget
        counters keep accumulating (checked here: _reflection_count and
        _verify_count)."""
        gateway = make_mock_gateway(
            [
                make_response("bad1"),
                make_response("reflection1"),
                make_response("bad2"),
                make_response("reflection2"),
                make_response("bad3"),
            ]
        )
        engine = ReActEngine(
            llm_gateway=gateway,
            max_steps=20,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=0,
            max_reflections=2,
        )
        # Spy on _reset_loop_detector while keeping its real behavior.
        with patch.object(
            engine, "_reset_loop_detector", wraps=engine._reset_loop_detector
        ) as spy_reset:
            with patch(
                "agentkit.core.verification_loop.VerificationLoop",
                return_value=make_mock_vloop(
                    [
                        make_verify_result(passed=False),
                        make_verify_result(passed=False),
                        make_verify_result(passed=False),
                    ]
                ),
            ):
                result = await engine.execute(
                    messages=[{"role": "user", "content": "write code"}],
                )
        # _reset_loop_detector called at least twice (once per reflection).
        assert spy_reset.call_count >= 2
        # Budget counters preserved (not reset to 0).
        assert engine._reflection_count == 2
        assert engine._verify_count >= 2  # at least 2 verify attempts
        assert result.status == "gave_up_after_reflections"

    async def test_loop_window_cleared_between_reflections(self):
        """A reflection retry resets the loop detector, leaving _loop_window empty.

        The spy asserts that the reset actually ran. An empty window on its own
        could also mean the window was never populated, because the scripted
        responses carry no tool calls.
        """
        gateway = make_mock_gateway(
            [
                make_response("bad1"),
                make_response("reflection1"),
                make_response("good"),
            ]
        )
        engine = ReActEngine(
            llm_gateway=gateway,
            max_steps=10,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=0,
            max_reflections=2,
        )
        with patch.object(
            engine, "_reset_loop_detector", wraps=engine._reset_loop_detector
        ) as spy_reset:
            with patch(
                "agentkit.core.verification_loop.VerificationLoop",
                return_value=make_mock_vloop(
                    [
                        make_verify_result(passed=False),
                        make_verify_result(passed=True),
                    ]
                ),
            ):
                await engine.execute(
                    messages=[{"role": "user", "content": "write code"}],
                )
        # The single reflection retry must have triggered at least one reset.
        assert spy_reset.call_count >= 1
        # After execution, the loop window is clear.
        assert len(engine._loop_window) == 0
# ── Error: reflect LLM call fails ──
class TestReflectLLMFailure:
    """Reflect LLM call fails -> skip reflection text, retry with verify errors."""

    @staticmethod
    def _messages_sent(call) -> list[dict]:
        """Return the ``messages`` recorded on one ``gateway.chat`` await.

        Prefers the ``messages`` keyword argument. Otherwise it falls back to
        the first positional argument, which is assumed to be ``messages``
        (TODO confirm against ``LLMGateway.chat``'s signature).
        """
        messages = call.kwargs.get("messages")
        if messages is None and call.args:
            messages = call.args[0]
        assert messages is not None, "gateway.chat was awaited without messages"
        return messages

    async def test_reflect_call_fails_retries_with_errors(self):
        """When the reflect LLM call raises, no reflection text is injected;
        the verify errors are injected instead and the retry still happens."""
        # gateway.chat awaits: main1, reflect (raises), main2.
        # Built by hand instead of via make_mock_gateway because the scripted
        # side effects include an exception, not only LLMResponse objects.
        gateway = MagicMock(spec=LLMGateway)
        gateway.chat = AsyncMock(
            side_effect=[
                make_response("bad1"),
                RuntimeError("reflect LLM unavailable"),
                make_response("bad2"),
            ]
        )
        gateway.get_provider_name_for_model = MagicMock(return_value=None)
        engine = ReActEngine(
            llm_gateway=gateway,
            max_steps=10,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=0,
            max_reflections=1,
        )
        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=make_mock_vloop(
                [
                    make_verify_result(passed=False, errors=["err1"]),
                    make_verify_result(passed=False, errors=["err2"]),
                ]
            ),
        ):
            result = await engine.execute(
                messages=[{"role": "user", "content": "write code"}],
            )
        # 3 chat calls: main1 + reflect (fails) + main2
        assert gateway.chat.await_count == 3
        # _reflection_count is incremented even though the reflect call failed.
        assert engine._reflection_count == 1
        # A reflection was attempted, so the final status is gave_up_after_reflections.
        assert result.status == "gave_up_after_reflections"
        # The 3rd call (main2) should carry the injected verify errors rather than
        # reflection text; that message contains "验证失败" ("verification failed").
        msgs_sent = self._messages_sent(gateway.chat.await_args_list[2])
        error_msgs = [m for m in msgs_sent if "验证失败" in (m.get("content") or "")]
        assert len(error_msgs) >= 1
# ── Integration: DIRECT_CHAT/REACT unaffected ──
class TestDirectChatUnaffected:
    """The default max_reflections=0 leaves DIRECT_CHAT/REACT behavior unchanged."""

    def test_default_max_reflections_is_zero(self):
        """A freshly built ReActEngine has reflection disabled (max_reflections=0)."""
        default_engine = ReActEngine(llm_gateway=make_mock_gateway([]))
        assert default_engine._max_reflections == 0

    async def test_no_reflection_without_max_reflections(self):
        """Leaving max_reflections unset turns a failed verify into verify_failed,
        never into gave_up_after_reflections."""
        gw = make_mock_gateway([make_response("bad answer")])
        react = ReActEngine(
            llm_gateway=gw,
            max_steps=5,
            verification_enabled=True,
            verification_commands=["false"],
            max_reinjections=0,
            # max_reflections is intentionally omitted so the default is exercised.
        )
        failing_loop = make_mock_vloop([make_verify_result(passed=False)])
        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=failing_loop,
        ):
            outcome = await react.execute(
                messages=[{"role": "user", "content": "do something"}],
            )
        assert gw.chat.await_count == 1
        assert outcome.status == "verify_failed"
        assert react._reflection_count == 0

    async def test_verification_disabled_no_reflection(self):
        """With verification off nothing can fail verify, so no reflection runs,
        even when a reflection quota is configured."""
        gw = make_mock_gateway([make_response("answer")])
        react = ReActEngine(
            llm_gateway=gw,
            max_steps=5,
            verification_enabled=False,
            max_reflections=2,
        )
        outcome = await react.execute(
            messages=[{"role": "user", "content": "do something"}],
        )
        assert gw.chat.await_count == 1
        assert outcome.status == "success"
        assert react._reflection_count == 0
# ── Integration: Recovery layer — no double-reflexion ──
class TestRecoveryNoDoubleReflexion:
    """The Recovery tier in _fallback_chain.py must not run reflexion again once
    the main flow has returned gave_up_after_reflections."""

    @staticmethod
    def _main_engine_returning(*, output, total_steps, total_tokens, status):
        """Return a mock main engine whose ``execute`` resolves to a ``ReActResult``."""
        from agentkit.core.react import ReActResult

        canned = ReActResult(
            output=output,
            trajectory=[],
            total_steps=total_steps,
            total_tokens=total_tokens,
            status=status,
        )
        main_engine = MagicMock()
        main_engine.execute = AsyncMock(return_value=canned)
        return main_engine

    async def test_gave_up_after_reflections_skips_recovery(self):
        """gave_up_after_reflections from main -> Recovery skipped -> Emergency fires."""
        from agentkit.server._fallback_chain import (
            _REFLEXION_EXHAUSTED_STATUSES,
            execute_with_fallback_chain,
        )

        # The status must be registered as "reflexion already exhausted".
        assert "gave_up_after_reflections" in _REFLEXION_EXHAUSTED_STATUSES

        main_engine = self._main_engine_returning(
            output="bad output",
            total_steps=3,
            total_tokens=100,
            status="gave_up_after_reflections",
        )
        gateway = MagicMock(spec=LLMGateway)
        # Patch ReflexionEngine so that any Recovery attempt would be recorded.
        with patch("agentkit.server._fallback_chain.ReflexionEngine") as reflexion_cls:
            outcome = await execute_with_fallback_chain(
                react_engine=main_engine,
                llm_gateway=gateway,
                messages=[{"role": "user", "content": "test"}],
                tools=None,
                model="test",
                agent_name="test",
                system_prompt=None,
            )
        reflexion_cls.assert_not_called()
        assert outcome.status == "emergency"

    async def test_verify_failed_still_triggers_recovery(self):
        """Plain verify_failed still goes through Recovery (no regression)."""
        from agentkit.server._fallback_chain import execute_with_fallback_chain

        main_engine = self._main_engine_returning(
            output="bad",
            total_steps=1,
            total_tokens=50,
            status="verify_failed",
        )
        gateway = MagicMock(spec=LLMGateway)
        with patch("agentkit.server._fallback_chain.ReflexionEngine") as reflexion_cls:
            recovery_result = MagicMock(status="success", output="recovered")
            reflexion_cls.return_value = MagicMock(
                execute=AsyncMock(return_value=recovery_result)
            )
            outcome = await execute_with_fallback_chain(
                react_engine=main_engine,
                llm_gateway=gateway,
                messages=[{"role": "user", "content": "test"}],
                tools=None,
                model="test",
                agent_name="test",
                system_prompt=None,
            )
        # Recovery (ReflexionEngine) is expected to be constructed exactly once.
        reflexion_cls.assert_called_once()
        assert outcome.status == "recovered"
# ── Integration: RuleBasedReflector treats gave_up as failure ──
class TestEvolutionTreatsGaveUpAsFailure:
    """RuleBasedReflector treats gave_up_after_reflections as failure."""

    @staticmethod
    def _make_task(task_id: str, created_at):
        """Build the minimal ``TaskMessage`` that both reflector tests feed in."""
        from agentkit.core.protocol import TaskMessage

        return TaskMessage(
            task_id=task_id,
            agent_name="test",
            input_data={"query": "test"},
            task_type="test",
            priority=1,
            callback_url=None,
            created_at=created_at,
        )

    async def test_rule_based_reflector_gave_up_is_failure(self):
        """RuleBasedReflector.outcome == 'failure' for non-COMPLETED status."""
        from datetime import datetime, timezone

        from agentkit.core.protocol import TaskResult, TaskStatus
        from agentkit.evolution.reflector import RuleBasedReflector

        reflector = RuleBasedReflector()
        now = datetime.now(timezone.utc)
        task = self._make_task("test-1", now)
        # This test assumes gave_up_after_reflections is surfaced as a FAILED
        # TaskResult (not COMPLETED), with the status string in error_message.
        result = TaskResult(
            task_id="test-1",
            agent_name="test",
            status=TaskStatus.FAILED,
            output_data=None,
            error_message="gave_up_after_reflections",
            started_at=now,
            completed_at=now,
        )
        reflection = await reflector.reflect(task, result)
        assert reflection.outcome == "failure"
        assert reflection.quality_score == 0.0

    async def test_rule_based_reflector_completed_is_success(self):
        """RuleBasedReflector.outcome == 'success' for COMPLETED status (control)."""
        from datetime import datetime, timezone

        from agentkit.core.protocol import TaskResult, TaskStatus
        from agentkit.evolution.reflector import RuleBasedReflector

        reflector = RuleBasedReflector()
        now = datetime.now(timezone.utc)
        task = self._make_task("test-2", now)
        # Reuse the single timestamp so started_at/completed_at cannot drift apart.
        result = TaskResult(
            task_id="test-2",
            agent_name="test",
            status=TaskStatus.COMPLETED,
            output_data={"text": "good"},
            error_message=None,
            started_at=now,
            completed_at=now,
        )
        reflection = await reflector.reflect(task, result)
        assert reflection.outcome == "success"
# ── Streaming path ──
class TestReflexionStreamPath:
    """Streaming variant: execute_stream also runs verify fail -> reflect -> retry."""

    async def test_stream_reflect_retry_passes(self):
        """In stream mode a failed verify triggers a reflection, and the retry passes verify."""
        from agentkit.llm.protocol import StreamChunk

        async def two_chunk_stream(text: str):
            # Yield *text* as two chunks: its first half, then its second half.
            half = len(text) // 2
            yield StreamChunk(content=text[:half], model="test-model")
            yield StreamChunk(content=text[half:], model="test-model")

        gateway = MagicMock(spec=LLMGateway)
        # Main attempts are streamed: each chat_stream call hands out the next generator.
        gateway.chat_stream = MagicMock(
            side_effect=[
                two_chunk_stream("bad code"),
                two_chunk_stream("fixed code"),
            ]
        )
        # The reflection step uses the non-streaming chat API.
        gateway.chat = AsyncMock(return_value=make_response("reflection text"))
        gateway.get_provider_name_for_model = MagicMock(return_value=None)
        engine = ReActEngine(
            llm_gateway=gateway,
            max_steps=10,
            verification_enabled=True,
            verification_commands=["pytest"],
            max_reinjections=0,
            max_reflections=2,
        )
        verify_then_pass = make_mock_vloop(
            [
                make_verify_result(passed=False),
                make_verify_result(passed=True),
            ]
        )
        with patch(
            "agentkit.core.verification_loop.VerificationLoop",
            return_value=verify_then_pass,
        ):
            events = [
                event
                async for event in engine.execute_stream(
                    messages=[{"role": "user", "content": "write code"}],
                )
            ]
        # Two streamed main attempts plus one non-streamed reflection.
        assert gateway.chat_stream.call_count == 2
        assert gateway.chat.await_count == 1
        final_answers = [e for e in events if e.event_type == "final_answer"]
        assert len(final_answers) >= 1
        assert "fixed code" in final_answers[-1].data.get("output", "")
        # A final_result event is optional here; check its status only if one was emitted.
        final_results = [e for e in events if e.event_type == "final_result"]
        if final_results:
            assert final_results[-1].data["result"].status == "success"