620 lines
26 KiB
Python
620 lines
26 KiB
Python
"""Unit tests for ReActEngine phase enforcement (G6 wiring, R24).
|
|
|
|
Per plan U3 Execution note: characterization-first — verify that
|
|
`ReActEngine(phase_policy=None)` behaves identically to pre-change (no
|
|
enforcement, no advance_phase tool, no _current_phase mutation). Then add
|
|
enforcement tests.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.phase import PhasePolicy, PhaseState, default_policy
|
|
from agentkit.core.react import ReActEngine
|
|
from agentkit.tools.advance_phase import AdvancePhaseTool
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Characterization — phase_policy=None preserves existing behavior
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCharacterizationNoPolicy:
    """Characterization tests for an engine constructed without ``phase_policy``.

    Each test pins the no-policy path: phase bookkeeping stays inert and tool
    dispatch is not gated, matching pre-Wave-3 behavior.
    """

    def test_init_without_phase_policy(self):
        """The constructor leaves every phase attribute at its inert default."""
        # Only the constructor is exercised, so a bare MagicMock gateway is enough.
        engine = ReActEngine(llm_gateway=MagicMock())
        assert engine._phase_policy is None
        assert engine._current_phase is None
        assert engine._steps_in_phase == 0
        assert engine.current_phase is None

    @pytest.mark.asyncio
    async def test_execute_tool_dispatches_without_phase_check(self):
        """Tool dispatch proceeds normally when no policy is set."""
        engine = ReActEngine(llm_gateway=MagicMock())

        # `name` is a Mock constructor argument that labels the mock's repr
        # rather than creating a `.name` attribute. Instead of relying on a
        # name-based lookup, patch _find_tool to hand back the stub directly.
        fake_tool = MagicMock()
        fake_tool.safe_execute = AsyncMock(return_value={"output": "ok"})
        fake_tool.input_schema = None
        engine._find_tool = lambda name, tools: fake_tool

        result = await engine._execute_tool("any_tool", {"x": 1}, [fake_tool])

        assert result == {"output": "ok"}
        # Arguments are forwarded to the tool as keyword arguments.
        fake_tool.safe_execute.assert_awaited_once_with(x=1)

    def test_advance_phase_returns_none_without_policy(self):
        """``advance_phase()`` is a no-op returning None when no policy is set.

        ``advance_phase`` is called synchronously, so this test needs no event
        loop and is a plain function like the other ``advance_phase`` tests in
        this module.
        """
        engine = ReActEngine(llm_gateway=MagicMock())
        assert engine.advance_phase() is None

    def test_reset_does_not_touch_phase_state_when_no_policy(self):
        """``reset()`` leaves ``_current_phase`` as None when no policy is set."""
        engine = ReActEngine(llm_gateway=MagicMock())
        engine.reset()
        assert engine._current_phase is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Initialization with phase_policy
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPhasePolicyInitialization:
    """Engine construction and ``reset()`` when a phase policy is supplied."""

    @staticmethod
    def _policy_engine():
        """Build an engine wired with the default policy and a stub gateway."""
        stub_gateway = MagicMock()
        return ReActEngine(llm_gateway=stub_gateway, phase_policy=default_policy())

    def test_phase_policy_set_initializes_current_phase(self):
        """A fresh policy-backed engine starts in PLANNING with a zero step count."""
        subject = self._policy_engine()

        assert subject._phase_policy is not None
        assert subject._current_phase == PhaseState.PLANNING
        assert subject._steps_in_phase == 0

    def test_reset_resets_phase_to_start(self):
        """``reset()`` rewinds both the phase and the per-phase step counter."""
        subject = self._policy_engine()

        # Simulate progress made during an execute run: PLANNING -> BUILDING,
        # plus a non-zero step counter.
        subject.advance_phase()
        assert subject._current_phase == PhaseState.BUILDING
        subject._steps_in_phase = 5

        subject.reset()

        assert subject._current_phase == PhaseState.PLANNING
        assert subject._steps_in_phase == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# advance_phase() transitions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAdvancePhase:
    """Transitions produced by successive ``advance_phase()`` calls."""

    @pytest.fixture
    def engine(self):
        """Policy-backed engine; the tests expect it to start in PLANNING."""
        return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())

    @staticmethod
    def _skip_ahead(engine, hops):
        """Call ``advance_phase()`` ``hops`` times, discarding the results."""
        for _ in range(hops):
            engine.advance_phase()

    def test_planning_to_building(self, engine):
        """The first advance moves PLANNING to BUILDING and resets the counter."""
        reached = engine.advance_phase()

        assert reached == PhaseState.BUILDING
        assert engine.current_phase == PhaseState.BUILDING
        # The per-phase step counter restarts on every transition.
        assert engine._steps_in_phase == 0

    def test_building_to_verification(self, engine):
        """The second advance moves BUILDING to VERIFICATION."""
        self._skip_ahead(engine, 1)  # now in BUILDING

        reached = engine.advance_phase()

        assert reached == PhaseState.VERIFICATION
        assert engine.current_phase == PhaseState.VERIFICATION

    def test_verification_to_delivery(self, engine):
        """The third advance moves VERIFICATION to DELIVERY."""
        self._skip_ahead(engine, 2)  # now in VERIFICATION

        reached = engine.advance_phase()

        assert reached == PhaseState.DELIVERY
        assert engine.current_phase == PhaseState.DELIVERY

    def test_delivery_returns_none(self, engine):
        """Advancing from DELIVERY returns None and leaves the phase unchanged."""
        self._skip_ahead(engine, 3)  # now in DELIVERY

        outcome = engine.advance_phase()

        assert outcome is None
        assert engine.current_phase == PhaseState.DELIVERY
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _check_phase_permission — whitelist enforcement
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPhasePermission:
    """Whitelist and shell-command checks done by ``_check_phase_permission``.

    The method is expected to return None when a call is allowed and an error
    dict when it is blocked.
    """

    @pytest.fixture
    def engine(self):
        """Policy-backed engine; the tests expect it to start in PLANNING."""
        return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())

    def test_search_allowed_in_planning(self, engine):
        """``search`` passes the check in PLANNING."""
        verdict = engine._check_phase_permission("search", {})
        assert verdict is None

    def test_write_file_blocked_in_planning(self, engine):
        """``write_file`` in PLANNING yields a phase_violation error dict."""
        verdict = engine._check_phase_permission("write_file", {})

        assert verdict is not None
        assert verdict["error"] == "phase_violation"
        assert "write_file" in verdict["message"]
        assert verdict["current_phase"] == "planning"

    def test_write_file_allowed_in_building(self, engine):
        """After one advance (BUILDING), ``write_file`` passes the check."""
        engine.advance_phase()
        verdict = engine._check_phase_permission("write_file", {})
        assert verdict is None

    def test_any_tool_allowed_in_delivery(self, engine):
        """After three advances (DELIVERY), an arbitrary tool name passes."""
        for _ in range(3):  # BUILDING, VERIFICATION, DELIVERY
            engine.advance_phase()
        verdict = engine._check_phase_permission("literally_anything", {})
        assert verdict is None

    def test_bash_command_filter_blocks_rm_in_planning(self, engine):
        """A shell ``rm`` command in PLANNING is rejected as a phase violation."""
        verdict = engine._check_phase_permission("shell", {"command": "rm -rf /tmp"})

        assert verdict is not None
        assert verdict["error"] == "phase_violation"
        message = verdict["message"]
        # Either wording is accepted for the rejection message.
        assert "rm" in message or "Bash command" in message

    def test_bash_command_filter_allows_safe_in_planning(self, engine):
        """Read-only commands (`ls`, `git status`) are not blocked in PLANNING."""
        listing = engine._check_phase_permission("shell", {"command": "ls -la"})
        assert listing is None
        status = engine._check_phase_permission("shell", {"command": "git status"})
        assert status is None

    def test_bash_command_filter_no_restriction_in_building(self, engine):
        """In BUILDING the same kind of ``rm`` command passes the check."""
        engine.advance_phase()
        verdict = engine._check_phase_permission("shell", {"command": "rm -rf build/"})
        assert verdict is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _execute_tool integration — phase enforcement actually blocks dispatch
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestExecuteToolPhaseEnforcement:
    """``_execute_tool`` integration: a blocked tool must never be dispatched."""

    @pytest.fixture
    def engine_with_tools(self):
        """Return ``(engine, [search_tool, write_tool])`` with stubbed lookup.

        The tests expect ``search`` to be allowed in PLANNING and ``write_file``
        to be blocked there.
        """

        def stub_tool(payload):
            # Bare mock exposing only what the dispatch path is given here:
            # a null input schema and an awaitable safe_execute.
            tool = MagicMock()
            tool.input_schema = None
            tool.safe_execute = AsyncMock(return_value=payload)
            return tool

        search_tool = stub_tool({"results": []})
        write_tool = stub_tool({"written": True})

        engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())

        # `name` is a Mock constructor argument (it labels the repr), so rather
        # than depend on name-based lookup over mocks, replace _find_tool with
        # a dict-backed resolver.
        registry = {"search": search_tool, "write_file": write_tool}
        engine._find_tool = lambda name, tools: registry.get(name)

        return engine, [search_tool, write_tool]

    @pytest.mark.asyncio
    async def test_blocked_tool_returns_phase_violation_and_skips_dispatch(self, engine_with_tools):
        """``write_file`` in PLANNING returns the error and never hits the tool."""
        engine, tools = engine_with_tools
        _, write_tool = tools

        outcome = await engine._execute_tool("write_file", {"path": "/x"}, tools)

        assert outcome["error"] == "phase_violation"
        assert outcome["current_phase"] == "planning"
        # The blocked tool's safe_execute must never have been invoked.
        write_tool.safe_execute.assert_not_called()

    @pytest.mark.asyncio
    async def test_allowed_tool_dispatches_normally(self, engine_with_tools):
        """``search`` in PLANNING is dispatched with its arguments as kwargs."""
        engine, tools = engine_with_tools
        search_tool, _ = tools

        outcome = await engine._execute_tool("search", {"query": "foo"}, tools)

        assert outcome == {"results": []}
        search_tool.safe_execute.assert_awaited_once_with(query="foo")

    @pytest.mark.asyncio
    async def test_after_advance_phase_blocked_tool_now_dispatches(self, engine_with_tools):
        """A tool blocked in PLANNING dispatches once the engine is in BUILDING."""
        engine, tools = engine_with_tools
        call_args = {"path": "/x"}

        # PLANNING: the call is rejected.
        blocked = await engine._execute_tool("write_file", call_args, tools)
        assert blocked["error"] == "phase_violation"

        # BUILDING: the identical call now goes through.
        engine.advance_phase()
        allowed = await engine._execute_tool("write_file", call_args, tools)
        assert allowed == {"written": True}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auto-advance safety net (KTD6)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAutoAdvance:
    """Auto-advance safety net driven by ``auto_advance_after_steps`` (KTD6)."""

    def test_auto_advance_after_threshold(self):
        """With a threshold of 2, the second counted step triggers the advance."""
        per_phase_tools = {
            PhaseState.PLANNING: frozenset({"search"}),
            PhaseState.BUILDING: frozenset({"write_file"}),
            PhaseState.VERIFICATION: frozenset({"shell"}),
            PhaseState.DELIVERY: frozenset({"*"}),
        }
        two_step_policy = PhasePolicy(
            whitelist=per_phase_tools,
            auto_advance_after_steps=2,
        )
        engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=two_step_policy)
        assert engine.current_phase == PhaseState.PLANNING

        # First step: counter is 1, below the threshold, so nothing happens.
        engine._steps_in_phase += 1
        advanced = engine._maybe_auto_advance()
        assert advanced is False
        assert engine.current_phase == PhaseState.PLANNING

        # Second step: counter reaches 2 and the engine moves on.
        engine._steps_in_phase += 1
        advanced = engine._maybe_auto_advance()
        assert advanced is True
        assert engine.current_phase == PhaseState.BUILDING
        # The counter restarts once the phase changes.
        assert engine._steps_in_phase == 0

    def test_auto_advance_none_default(self):
        """With the default policy, a large step count does not auto-advance."""
        # This test relies on default_policy() leaving auto-advance disabled.
        engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
        engine._steps_in_phase = 100

        advanced = engine._maybe_auto_advance()

        assert advanced is False
        assert engine.current_phase == PhaseState.PLANNING
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# AdvancePhaseTool integration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAdvancePhaseTool:
    """``AdvancePhaseTool`` bound to a ``ReActEngine`` instance."""

    @staticmethod
    def _policy_engine():
        """Build an engine wired with the default policy and a stub gateway."""
        return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())

    @pytest.mark.asyncio
    async def test_advance_phase_tool_transitions_engine(self):
        """Executing the tool once moves the engine from PLANNING to BUILDING."""
        engine = self._policy_engine()

        outcome = await AdvancePhaseTool(engine=engine).execute()

        assert outcome["is_error"] is False
        assert outcome["current_phase"] == "building"
        assert engine.current_phase == PhaseState.BUILDING

    @pytest.mark.asyncio
    async def test_advance_phase_tool_at_delivery_returns_error(self):
        """At DELIVERY the tool reports ``already_at_final_phase``."""
        engine = self._policy_engine()
        # Three advances walk PLANNING -> BUILDING -> VERIFICATION -> DELIVERY.
        for _ in range(3):
            engine.advance_phase()

        outcome = await AdvancePhaseTool(engine=engine).execute()

        assert outcome["is_error"] is True
        assert outcome["error"] == "already_at_final_phase"
        assert outcome["current_phase"] == "delivery"

    @pytest.mark.asyncio
    async def test_advance_phase_tool_without_policy_returns_error(self):
        """Without a phase policy the tool reports ``no_phase_policy``."""
        policy_free_engine = ReActEngine(llm_gateway=MagicMock())

        outcome = await AdvancePhaseTool(engine=policy_free_engine).execute()

        assert outcome["is_error"] is True
        assert outcome["error"] == "no_phase_policy"

    def test_tool_schema_accepts_no_arguments(self):
        """The input schema declares no properties and forbids extras."""
        schema = AdvancePhaseTool(engine=self._policy_engine()).input_schema

        assert schema["properties"] == {}
        assert schema["additionalProperties"] is False

    def test_tool_bypasses_phase_check(self):
        """`advance_phase` is the LLM's escape hatch — must never be blocked.

        This is an indirect sanity check only: it confirms ``advance_phase``
        appears in no phase whitelist of the default policy, so any dispatch
        of it has to come from a separate bypass rather than the whitelist.
        The bypass itself is not exercised here.
        """
        for phase, permitted in default_policy().whitelist.items():
            assert "advance_phase" not in permitted, (
                f"advance_phase must not be in {phase.value} whitelist"
            )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Wave 4 U2 — phase_violation accumulator + drain
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPhaseViolationAccumulator:
    """_check_phase_permission records violations; _drain_phase_violations
    yields them as ReActEvents and clears the accumulator."""

    @pytest.fixture
    def engine(self):
        """Policy-backed engine; the tests expect it to start in PLANNING."""
        return ReActEngine(
            llm_gateway=MagicMock(),
            phase_policy=default_policy(),
        )

    def test_violation_appended_on_tool_block(self, engine):
        """A blocked tool call records one ``tool_not_allowed`` violation."""
        # write_file is blocked in PLANNING.
        engine._check_phase_permission("write_file", {})
        assert len(engine._phase_violations) == 1
        v = engine._phase_violations[0]
        assert v["error"] == "phase_violation"
        assert v["tool"] == "write_file"
        assert v["current_phase"] == "planning"
        assert v["violation_kind"] == "tool_not_allowed"

    def test_violation_appended_on_bash_block(self, engine):
        """A blocked shell command records a ``bash_command_blocked`` violation."""
        engine._check_phase_permission("shell", {"command": "rm -rf /tmp"})
        assert len(engine._phase_violations) == 1
        v = engine._phase_violations[0]
        assert v["violation_kind"] == "bash_command_blocked"
        assert v["tool"] == "shell"
        assert v["command_preview"] == "rm -rf /tmp"

    def test_no_violation_when_allowed(self, engine):
        """An allowed call leaves the accumulator empty."""
        # search is allowed in PLANNING.
        engine._check_phase_permission("search", {})
        assert engine._phase_violations == []

    def test_no_violation_without_policy(self):
        """Without a policy nothing is ever recorded."""
        engine = ReActEngine(llm_gateway=MagicMock())  # no policy
        engine._check_phase_permission("anything", {})
        assert engine._phase_violations == []

    def test_drain_returns_events_and_clears(self, engine):
        """Draining yields one event per violation, then empties the accumulator."""
        # Trigger two violations.
        engine._check_phase_permission("write_file", {"path": "/a"})
        engine._check_phase_permission("write_file", {"path": "/b"})
        assert len(engine._phase_violations) == 2

        events = engine._drain_phase_violations(step=3)
        assert len(events) == 2
        assert all(e.event_type == "phase_violation" for e in events)
        # Every event carries the step number passed to the drain call.
        assert all(e.step == 3 for e in events)
        assert events[0].data["tool"] == "write_file"
        # Accumulator cleared after drain.
        assert engine._phase_violations == []

    def test_drain_empty_returns_empty(self, engine):
        """Draining with nothing recorded returns an empty list."""
        assert engine._drain_phase_violations(step=1) == []

    def test_drain_returns_shallow_copy(self, engine):
        """Drained event data must not alias the original violation dict —
        mutating one must not mutate the other."""
        engine._check_phase_permission("write_file", {})
        # Hold a reference to the stored dict BEFORE draining. Without it the
        # original object is unreachable after the drain and aliasing could
        # never be observed.
        recorded = engine._phase_violations[0]

        events = engine._drain_phase_violations(step=1)

        # The event must carry its own dict, not the stored one.
        assert events[0].data is not recorded
        # Mutating the drained event data must leave the stored dict untouched.
        events[0].data["tool"] = "MUTATED"
        assert recorded["tool"] == "write_file"

        # A fresh violation recorded after the drain is likewise unaffected.
        engine._check_phase_permission("write_file", {})
        new_violations = engine._phase_violations
        assert new_violations[0]["tool"] == "write_file"  # not "MUTATED"

    def test_reset_clears_violations(self, engine):
        """``reset()`` discards any violations that were not drained."""
        engine._check_phase_permission("write_file", {})
        assert len(engine._phase_violations) == 1
        engine.reset()
        assert engine._phase_violations == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Wave 4 U2 — execute_stream yields phase_violation events
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestExecuteStreamPhaseViolationEvents:
    """``execute_stream`` must surface blocked tool calls as ``phase_violation``
    ReActEvents, emitted after the corresponding ``tool_result`` event."""

    @staticmethod
    async def _collect_events(engine, tools):
        """Run ``execute_stream`` on a one-message conversation; list its events."""
        from agentkit.core.react import ReActEvent

        collected: list[ReActEvent] = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "test"}],
                tools=tools,
            )
        ]
        return collected

    @staticmethod
    def _stub_tool(payload):
        """Mock tool with a null input schema and an awaitable ``safe_execute``."""
        tool = MagicMock()
        tool.input_schema = None
        tool.safe_execute = AsyncMock(return_value=payload)
        return tool

    @pytest.mark.asyncio
    async def test_stream_yields_phase_violation_on_tool_block(self):
        """A tool blocked by the policy produces tool_call, tool_result (with
        the error dict) and phase_violation events, in that relative order."""
        scripted_gateway = llm_mock_gateway_with_response(
            tool_calls=[{"name": "write_file", "arguments": {"path": "/x"}}],
            content=None,
        )
        engine = ReActEngine(
            llm_gateway=scripted_gateway,
            phase_policy=default_policy(),
            max_steps=1,
        )
        # No real tools are registered; write_file is expected to be rejected
        # by the phase policy before any lookup matters.
        engine._find_tool = lambda name, tools: None

        events = await self._collect_events(engine, tools=[])

        # Expected flow: thinking, tool_call, tool_result, phase_violation,
        # final_answer.
        kinds = [event.event_type for event in events]
        for expected_kind in ("tool_call", "tool_result", "phase_violation"):
            assert expected_kind in kinds

        # The violation must be reported AFTER the tool result.
        result_at = kinds.index("tool_result")
        violation_at = kinds.index("phase_violation")
        assert violation_at > result_at

        # Payload of the first violation event.
        details = events[violation_at].data
        assert details["error"] == "phase_violation"
        assert details["tool"] == "write_file"
        assert details["current_phase"] == "planning"
        assert details["violation_kind"] == "tool_not_allowed"

    @pytest.mark.asyncio
    async def test_stream_yields_phase_violation_on_bash_block(self):
        """A dangerous shell command in PLANNING produces exactly one
        phase_violation event with violation_kind = bash_command_blocked."""
        scripted_gateway = llm_mock_gateway_with_response(
            tool_calls=[{"name": "shell", "arguments": {"command": "rm -rf /tmp"}}],
            content=None,
        )
        engine = ReActEngine(
            llm_gateway=scripted_gateway,
            phase_policy=default_policy(),
            max_steps=1,
        )
        engine._find_tool = lambda name, tools: None

        events = await self._collect_events(engine, tools=[])

        violations = [event for event in events if event.event_type == "phase_violation"]
        assert len(violations) == 1
        details = violations[0].data
        assert details["violation_kind"] == "bash_command_blocked"
        assert details["tool"] == "shell"
        assert "rm -rf /tmp" in details["command_preview"]

    @pytest.mark.asyncio
    async def test_stream_no_violation_when_tool_allowed(self):
        """When the LLM calls an allowed tool, no phase_violation event is yielded."""
        scripted_gateway = llm_mock_gateway_with_response(
            tool_calls=[{"name": "search", "arguments": {"query": "foo"}}],
            content=None,
        )
        engine = ReActEngine(
            llm_gateway=scripted_gateway,
            phase_policy=default_policy(),
            max_steps=1,
        )
        # search is allowed in PLANNING, but dispatch still needs _find_tool
        # to resolve to a tool object.
        search_tool = self._stub_tool({"results": []})
        engine._find_tool = lambda name, tools: search_tool

        events = await self._collect_events(engine, tools=[search_tool])

        assert not any(event.event_type == "phase_violation" for event in events)

    @pytest.mark.asyncio
    async def test_stream_no_violation_without_policy(self):
        """Without a phase_policy, no phase_violation events are yielded —
        characterization of the no-policy path."""
        scripted_gateway = llm_mock_gateway_with_response(
            tool_calls=[{"name": "any_tool", "arguments": {}}],
            content=None,
        )
        engine = ReActEngine(
            llm_gateway=scripted_gateway,
            max_steps=1,
        )
        any_tool = self._stub_tool({"output": "ok"})
        engine._find_tool = lambda name, tools: any_tool

        events = await self._collect_events(engine, tools=[any_tool])

        assert not any(event.event_type == "phase_violation" for event in events)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers — minimal LLM gateway mocks for execute_stream tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def llm_mock_gateway():
    """Build a fresh, unconfigured MagicMock standing in for the LLM gateway.

    Enough for tests that only construct the engine and never stream from it.
    """
    gateway = MagicMock()
    return gateway
|
|
|
|
|
|
def llm_mock_gateway_with_response(tool_calls: list[dict], content: str | None):
    """Build a MagicMock gateway whose ``chat_stream`` replays scripted chunks.

    With a non-empty ``tool_calls`` the stream yields two chunks: one carrying
    the tool calls (content None) and then a ``"done"`` content chunk;
    ``content`` is ignored in that case. Otherwise it yields a single content
    chunk, using ``"final answer"`` when ``content`` is falsy.

    Each chunk exposes ``content``, ``tool_calls``, ``usage`` and ``model``
    attributes; each tool call exposes ``name``, ``arguments`` and ``id``.
    Every call to ``chat_stream`` replays the same chunk objects and ignores
    its arguments.
    """

    class _StreamChunk:
        # Plain attribute bag standing in for a gateway stream chunk.
        def __init__(self, content, tool_calls, usage=None, model="default"):
            self.content, self.tool_calls = content, tool_calls
            self.usage, self.model = usage, model

    class _ToolCall:
        # Attribute-style view over a raw tool-call dict, with defaults for
        # any missing key.
        def __init__(self, spec):
            self.name = spec.get("name", "")
            self.arguments = spec.get("arguments", {})
            self.id = spec.get("id", "tc_test")

    if not tool_calls:
        # Final-answer path: a single content-only chunk.
        scripted = [_StreamChunk(content=content or "final answer", tool_calls=None)]
    else:
        # Tool-call path: the calls first, then a trailing content chunk so the
        # stream ends on a final answer.
        scripted = [
            _StreamChunk(content=None, tool_calls=[_ToolCall(spec) for spec in tool_calls]),
            _StreamChunk(content="done", tool_calls=None),
        ]

    async def _fake_chat_stream(*args, **kwargs):
        for chunk in scripted:
            yield chunk

    gateway = MagicMock()
    gateway.chat_stream = _fake_chat_stream
    return gateway
|