feat(U2): emit phase_violation WS event alongside LLM reinjection
Wave 3 only injected the violation error dict back to the LLM as a tool
result. Wave 4 U2 adds a parallel WS event so the frontend PhaseIndicator
can surface violations to the user.
- ReActEngine: add _phase_violations accumulator (list[dict]). Cleared in
reset(). _check_phase_permission appends a structured violation dict
(with new violation_kind field: tool_not_allowed | bash_command_blocked)
before returning the error.
- Add _drain_phase_violations(step) helper that pops pending violations
and returns ReActEvent(event_type="phase_violation", ...) list. Events
carry a shallow copy of the violation dict so callers can't mutate the
accumulator.
- execute_stream: drain after each tool_result yield at all 3 tool
execution sites (parallel, serial-with-confirmation, parsed_calls).
Non-streaming execute() ignores the accumulator (the LLM reinjection
via the error dict is the only signal there).
- chat.py WS handler: new elif branch forwards phase_violation ReActEvents
to the client as {"type": "phase_violation", "data": ...} WS messages.
- Tests: 12 new engine tests (8 accumulator + 4 execute_stream) covering
accumulator lifecycle, drain semantics, shallow-copy isolation, and
execute_stream event emission for both tool_block and bash_block paths.
2 new WS forwarding tests pin the chat.py path (forward + characterization
for REACT mode).
This commit is contained in:
parent
9e28ab315e
commit
4dc58c24bc
|
|
@ -224,6 +224,12 @@ class ReActEngine:
|
||||||
)
|
)
|
||||||
# Steps taken in the current phase (for auto-advance safety net).
|
# Steps taken in the current phase (for auto-advance safety net).
|
||||||
self._steps_in_phase: int = 0
|
self._steps_in_phase: int = 0
|
||||||
|
# Wave 4 U2: phase violation accumulator. _check_phase_permission
|
||||||
|
# appends here when a tool is blocked; execute_stream drains after each
|
||||||
|
# step and yields phase_violation ReActEvents. Non-streaming execute()
|
||||||
|
# simply ignores the accumulator (the error dict returned to the LLM is
|
||||||
|
# the only signal there).
|
||||||
|
self._phase_violations: list[dict[str, Any]] = []
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset internal state for reuse across conversations.
|
"""Reset internal state for reuse across conversations.
|
||||||
|
|
@ -241,6 +247,8 @@ class ReActEngine:
|
||||||
if self._phase_policy is not None:
|
if self._phase_policy is not None:
|
||||||
self._current_phase = self._phase_policy.start_phase
|
self._current_phase = self._phase_policy.start_phase
|
||||||
self._steps_in_phase = 0
|
self._steps_in_phase = 0
|
||||||
|
# Wave 4 U2: clear any pending violations from a prior run.
|
||||||
|
self._phase_violations = []
|
||||||
|
|
||||||
# ── U3/G6: phase state machine ────────────────────────────────────
|
# ── U3/G6: phase state machine ────────────────────────────────────
|
||||||
|
|
||||||
|
|
@ -299,11 +307,16 @@ class ReActEngine:
|
||||||
AdvancePhaseTool or pick a different tool).
|
AdvancePhaseTool or pick a different tool).
|
||||||
|
|
||||||
Also applies the bash_command_filter for `bash` tool calls.
|
Also applies the bash_command_filter for `bash` tool calls.
|
||||||
|
|
||||||
|
Wave 4 U2: when blocked, the violation is also appended to
|
||||||
|
`self._phase_violations` so `execute_stream` can drain and yield
|
||||||
|
`phase_violation` ReActEvents to the WS layer (alongside the LLM
|
||||||
|
reinjection that the returned dict provides).
|
||||||
"""
|
"""
|
||||||
if self._phase_policy is None or self._current_phase is None:
|
if self._phase_policy is None or self._current_phase is None:
|
||||||
return None
|
return None
|
||||||
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
|
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
|
||||||
return {
|
violation = {
|
||||||
"error": "phase_violation",
|
"error": "phase_violation",
|
||||||
"message": (
|
"message": (
|
||||||
f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. "
|
f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. "
|
||||||
|
|
@ -312,12 +325,15 @@ class ReActEngine:
|
||||||
"current_phase": self._current_phase.value,
|
"current_phase": self._current_phase.value,
|
||||||
"tool": tool_name,
|
"tool": tool_name,
|
||||||
"is_error": True,
|
"is_error": True,
|
||||||
|
"violation_kind": "tool_not_allowed",
|
||||||
}
|
}
|
||||||
|
self._phase_violations.append(violation)
|
||||||
|
return violation
|
||||||
# Bash command filter (only applies to shell tool — registered as "shell").
|
# Bash command filter (only applies to shell tool — registered as "shell").
|
||||||
if tool_name == "shell":
|
if tool_name == "shell":
|
||||||
command = str(arguments.get("command", ""))
|
command = str(arguments.get("command", ""))
|
||||||
if not self._phase_policy.is_bash_command_allowed(command, self._current_phase):
|
if not self._phase_policy.is_bash_command_allowed(command, self._current_phase):
|
||||||
return {
|
violation = {
|
||||||
"error": "phase_violation",
|
"error": "phase_violation",
|
||||||
"message": (
|
"message": (
|
||||||
f"Bash command blocked in {self._current_phase.value} phase "
|
f"Bash command blocked in {self._current_phase.value} phase "
|
||||||
|
|
@ -327,7 +343,11 @@ class ReActEngine:
|
||||||
"current_phase": self._current_phase.value,
|
"current_phase": self._current_phase.value,
|
||||||
"tool": tool_name,
|
"tool": tool_name,
|
||||||
"is_error": True,
|
"is_error": True,
|
||||||
|
"violation_kind": "bash_command_blocked",
|
||||||
|
"command_preview": command[:200],
|
||||||
}
|
}
|
||||||
|
self._phase_violations.append(violation)
|
||||||
|
return violation
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
|
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
|
||||||
|
|
@ -351,6 +371,28 @@ class ReActEngine:
|
||||||
return hash_to_name.get(h)
|
return hash_to_name.get(h)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _drain_phase_violations(self, step: int) -> list[ReActEvent]:
    """Turn pending phase violations into events and empty the accumulator.

    Wave 4 U2: ``_check_phase_permission`` appends a violation dict to
    ``self._phase_violations`` each time it blocks a call; ``execute_stream``
    invokes this helper after yielding a ``tool_result`` so those violations
    can be forwarded as ``phase_violation`` events (in addition to the error
    dict that is already handed back to the LLM).

    Args:
        step: Step number stamped on every event produced by this drain.

    Returns:
        One ``ReActEvent(event_type="phase_violation")`` per pending
        violation, in recording order; an empty list when nothing is pending.
    """
    pending = self._phase_violations
    if not pending:
        return []

    drained = []
    for violation in pending:
        drained.append(
            ReActEvent(
                event_type="phase_violation",
                step=step,
                # Shallow copy: consumers of the event get their own dict
                # rather than the one stored on the engine.
                data=dict(violation),
            )
        )

    # Rebind to a fresh list only after every event was built successfully.
    self._phase_violations = []
    return drained
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
messages: list[dict[str, str]],
|
messages: list[dict[str, str]],
|
||||||
|
|
@ -1514,6 +1556,10 @@ class ReActEngine:
|
||||||
step=step,
|
step=step,
|
||||||
data={"tool_name": tc.name, "result": tool_result},
|
data={"tool_name": tc.name, "result": tool_result},
|
||||||
)
|
)
|
||||||
|
# Wave 4 U2: drain phase violations recorded by
|
||||||
|
# _check_phase_permission during this tool call.
|
||||||
|
for _ev in self._drain_phase_violations(step):
|
||||||
|
yield _ev
|
||||||
tool_msg = await self._build_tool_result_message(
|
tool_msg = await self._build_tool_result_message(
|
||||||
tc.id, tool_result, effective_compressor, tc.name
|
tc.id, tool_result, effective_compressor, tc.name
|
||||||
)
|
)
|
||||||
|
|
@ -1652,6 +1698,9 @@ class ReActEngine:
|
||||||
step=step,
|
step=step,
|
||||||
data={"tool_name": tc.name, "result": tool_result},
|
data={"tool_name": tc.name, "result": tool_result},
|
||||||
)
|
)
|
||||||
|
# Wave 4 U2: drain phase violations.
|
||||||
|
for _ev in self._drain_phase_violations(step):
|
||||||
|
yield _ev
|
||||||
|
|
||||||
tool_msg = await self._build_tool_result_message(
|
tool_msg = await self._build_tool_result_message(
|
||||||
tc.id, tool_result, effective_compressor, tc.name
|
tc.id, tool_result, effective_compressor, tc.name
|
||||||
|
|
@ -1721,6 +1770,9 @@ class ReActEngine:
|
||||||
step=step,
|
step=step,
|
||||||
data={"tool_name": pc["name"], "result": tool_result},
|
data={"tool_name": pc["name"], "result": tool_result},
|
||||||
)
|
)
|
||||||
|
# Wave 4 U2: drain phase violations.
|
||||||
|
for _ev in self._drain_phase_violations(step):
|
||||||
|
yield _ev
|
||||||
tool_msg = await self._build_tool_result_message(
|
tool_msg = await self._build_tool_result_message(
|
||||||
pc.get("id", f"text_tc_{step}"),
|
pc.get("id", f"text_tc_{step}"),
|
||||||
tool_result,
|
tool_result,
|
||||||
|
|
|
||||||
|
|
@ -1280,6 +1280,17 @@ async def _handle_chat_message(
|
||||||
"data": event.data,
|
"data": event.data,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
elif event.event_type == "phase_violation":
|
||||||
|
# Wave 4 U2: forward phase violations to the client so the
|
||||||
|
# frontend can surface them in the PhaseIndicator UI (alongside
|
||||||
|
# the LLM reinjection that already happens via the tool_result
|
||||||
|
# error dict).
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "phase_violation",
|
||||||
|
"data": event.data,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -529,3 +529,145 @@ async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat):
|
||||||
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||||
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
|
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
|
||||||
assert len(phase_events) == 0
|
assert len(phase_events) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Wave 4 U2 — phase_violation event forwarding
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_phase_violation_event_forwarded_to_client(app_with_chat):
    """A ``phase_violation`` ReActEvent must be forwarded over the WebSocket.

    The chat.py handler is driven with a stub engine whose stream contains a
    single ``phase_violation`` event. The handler is expected to pass exactly
    one ``{"type": "phase_violation", "data": ...}`` message to ``send_json``
    so the frontend PhaseIndicator can react.
    """
    from agentkit.server.routes import chat as chat_module

    app_with_chat.state.server_config.plan_exec = {}

    agent_mock = _make_agent_mock()
    plan_exec_routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
    _setup_routing(app_with_chat, plan_exec_routing, agent_mock)

    session_manager = _make_session_manager_mock()
    session_manager.get_chat_messages = AsyncMock(
        return_value=[{"role": "user", "content": "go"}]
    )
    websocket = _make_websocket_mock(app_with_chat)

    class _StubEngine:
        """Stand-in for ReActEngine that replays a fixed event script."""

        def __init__(self, **kwargs):
            self._phase_policy = kwargs.get("phase_policy")
            self._current_phase = PhaseState.PLANNING

        @property
        def current_phase(self):
            return self._current_phase

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            from agentkit.core.react import ReActEvent

            # Simulate: tool_call → tool_result (blocked) → phase_violation,
            # then a final answer. Payloads are rebuilt on every call.
            script = (
                (
                    "tool_call",
                    1,
                    {"tool_name": "write_file", "arguments": {"path": "/x"}},
                ),
                (
                    "tool_result",
                    1,
                    {
                        "tool_name": "write_file",
                        "result": {"error": "phase_violation", "is_error": True},
                    },
                ),
                (
                    "phase_violation",
                    1,
                    {
                        "error": "phase_violation",
                        "message": "Tool 'write_file' not allowed in planning phase.",
                        "current_phase": "planning",
                        "tool": "write_file",
                        "is_error": True,
                        "violation_kind": "tool_not_allowed",
                    },
                ),
                ("final_answer", 2, {"output": "done"}),
            )
            for kind, step_no, payload in script:
                yield ReActEvent(event_type=kind, step=step_no, data=payload)

    with pytest.MonkeyPatch().context() as patcher:
        patcher.setattr(chat_module, "ReActEngine", _StubEngine)

        await chat_module._handle_chat_message(
            websocket=websocket,
            session_id="test-session",
            content="go",
            sm=session_manager,
            cancellation_token=MagicMock(),
            pending_replies={},
            pending_confirmations=None,
        )

    forwarded = []
    for call in websocket.send_json.call_args_list:
        message = call.args[0]
        if message.get("type") == "phase_violation":
            forwarded.append(message)

    assert len(forwarded) == 1
    payload = forwarded[0]["data"]
    assert payload["error"] == "phase_violation"
    assert payload["tool"] == "write_file"
    assert payload["current_phase"] == "planning"
    assert payload["violation_kind"] == "tool_not_allowed"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_no_phase_violation_event_when_not_plan_exec(app_with_chat):
    """Characterization: in REACT mode no ``phase_violation`` message is sent.

    The stub engine has no phase policy and its stream contains only a
    ``final_answer`` event, so nothing of type ``phase_violation`` should be
    passed to ``send_json``.
    """
    from agentkit.server.routes import chat as chat_module

    agent_mock = _make_agent_mock()
    react_routing = _make_routing(execution_mode=ExecutionMode.REACT)
    _setup_routing(app_with_chat, react_routing, agent_mock)

    session_manager = _make_session_manager_mock()
    session_manager.get_chat_messages = AsyncMock(
        return_value=[{"role": "user", "content": "hi"}]
    )
    websocket = _make_websocket_mock(app_with_chat)

    class _StubEngine:
        """Stand-in for ReActEngine without any phase policy."""

        def __init__(self, **kwargs):
            self._phase_policy = None
            self._current_phase = None

        @property
        def current_phase(self):
            return None

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            from agentkit.core.react import ReActEvent

            # REACT mode: no phase_violation events yielded.
            yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"})

    with pytest.MonkeyPatch().context() as patcher:
        patcher.setattr(chat_module, "ReActEngine", _StubEngine)

        await chat_module._handle_chat_message(
            websocket=websocket,
            session_id="test-session",
            content="hi",
            sm=session_manager,
            cancellation_token=MagicMock(),
            pending_replies={},
            pending_confirmations=None,
        )

    forwarded = []
    for call in websocket.send_json.call_args_list:
        message = call.args[0]
        if message.get("type") == "phase_violation":
            forwarded.append(message)

    assert len(forwarded) == 0
|
||||||
|
|
|
||||||
|
|
@ -337,3 +337,283 @@ class TestAdvancePhaseTool:
|
||||||
assert "advance_phase" not in allowed, (
|
assert "advance_phase" not in allowed, (
|
||||||
f"advance_phase must not be in {phase.value} whitelist"
|
f"advance_phase must not be in {phase.value} whitelist"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Wave 4 U2 — phase_violation accumulator + drain
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhaseViolationAccumulator:
    """Lifecycle of the engine's phase-violation accumulator.

    ``_check_phase_permission`` records a violation dict on
    ``engine._phase_violations`` whenever it blocks a call;
    ``_drain_phase_violations`` converts the pending dicts into ReActEvents
    and empties the accumulator; ``reset`` empties it as well.
    """

    @pytest.fixture
    def engine(self):
        """Engine with the default phase policy; the assertions below expect
        it to start in the planning phase."""
        return ReActEngine(
            llm_gateway=MagicMock(),
            phase_policy=default_policy(),
        )

    def test_violation_appended_on_tool_block(self, engine):
        """A tool outside the phase whitelist records a ``tool_not_allowed``
        violation."""
        # write_file is blocked in PLANNING.
        engine._check_phase_permission("write_file", {})
        assert len(engine._phase_violations) == 1
        v = engine._phase_violations[0]
        assert v["error"] == "phase_violation"
        assert v["tool"] == "write_file"
        assert v["current_phase"] == "planning"
        assert v["violation_kind"] == "tool_not_allowed"

    def test_violation_appended_on_bash_block(self, engine):
        """A filtered shell command records a ``bash_command_blocked``
        violation that carries a preview of the command."""
        engine._check_phase_permission("shell", {"command": "rm -rf /tmp"})
        assert len(engine._phase_violations) == 1
        v = engine._phase_violations[0]
        assert v["violation_kind"] == "bash_command_blocked"
        assert v["tool"] == "shell"
        assert v["command_preview"] == "rm -rf /tmp"

    def test_no_violation_when_allowed(self, engine):
        """An allowed tool leaves the accumulator empty."""
        # search is allowed in PLANNING.
        engine._check_phase_permission("search", {})
        assert engine._phase_violations == []

    def test_no_violation_without_policy(self):
        """Without a phase policy nothing is ever recorded."""
        engine = ReActEngine(llm_gateway=MagicMock())  # no policy
        engine._check_phase_permission("anything", {})
        assert engine._phase_violations == []

    def test_drain_returns_events_and_clears(self, engine):
        """Drain yields one event per pending violation, stamped with the
        given step, and leaves the accumulator empty."""
        # Trigger two violations.
        engine._check_phase_permission("write_file", {"path": "/a"})
        engine._check_phase_permission("write_file", {"path": "/b"})
        assert len(engine._phase_violations) == 2

        events = engine._drain_phase_violations(step=3)
        assert len(events) == 2
        assert all(e.event_type == "phase_violation" for e in events)
        assert all(e.step == 3 for e in events)
        # Event data carries the recorded violation fields.
        assert events[0].data["tool"] == "write_file"
        # Accumulator cleared after drain.
        assert engine._phase_violations == []

    def test_drain_empty_returns_empty(self, engine):
        """Draining with nothing pending returns an empty list."""
        assert engine._drain_phase_violations(step=1) == []

    def test_drain_returns_shallow_copy(self, engine):
        """Drained event data must not alias the recorded violation dict —
        mutating one must not mutate the other."""
        engine._check_phase_permission("write_file", {})
        # Keep a handle on the recorded dict before draining; once the
        # accumulator is emptied this reference is the only way to observe
        # whether the event data aliases it.
        recorded = engine._phase_violations[0]

        events = engine._drain_phase_violations(step=1)
        assert events[0].data == recorded
        assert events[0].data is not recorded

        # Mutate the drained event data; the recorded dict must be untouched.
        events[0].data["tool"] = "MUTATED"
        assert recorded["tool"] == "write_file"

        # A fresh violation recorded after the drain is unaffected as well.
        engine._check_phase_permission("write_file", {})
        new_violations = engine._phase_violations
        assert new_violations[0]["tool"] == "write_file"  # not "MUTATED"

    def test_reset_clears_violations(self, engine):
        """``reset`` discards violations that were never drained."""
        engine._check_phase_permission("write_file", {})
        assert len(engine._phase_violations) == 1
        engine.reset()
        assert engine._phase_violations == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Wave 4 U2 — execute_stream yields phase_violation events
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteStreamPhaseViolationEvents:
    """execute_stream must yield phase_violation ReActEvents when a tool is
    blocked by _check_phase_permission. The events are drained after each
    tool_result yield."""

    @staticmethod
    async def _collect_events(engine, tools):
        """Run ``engine.execute_stream`` on a one-message conversation and
        return every yielded event, in order.

        Shared by all tests in this class so the streaming boilerplate lives
        in one place.
        """
        return [
            ev
            async for ev in engine.execute_stream(
                messages=[{"role": "user", "content": "test"}],
                tools=tools,
            )
        ]

    @pytest.mark.asyncio
    async def test_stream_yields_phase_violation_on_tool_block(self):
        """When the LLM calls a tool blocked by the phase policy, execute_stream
        yields a tool_call event, a tool_result event (with the error dict),
        and a phase_violation event."""
        engine = ReActEngine(
            llm_gateway=llm_mock_gateway_with_response(
                tool_calls=[{"name": "write_file", "arguments": {"path": "/x"}}],
                content=None,
            ),
            phase_policy=default_policy(),
            max_steps=1,
        )
        # Patch _find_tool so we don't need real tools registered. write_file
        # should be blocked by phase_policy before _find_tool is called.
        engine._find_tool = lambda name, tools: None

        events = await self._collect_events(engine, tools=[])

        # Expect: thinking → tool_call → tool_result → phase_violation → final_answer
        event_types = [e.event_type for e in events]
        assert "tool_call" in event_types
        assert "tool_result" in event_types
        assert "phase_violation" in event_types

        # The phase_violation event must come AFTER tool_result.
        tool_result_idx = next(i for i, e in enumerate(events) if e.event_type == "tool_result")
        violation_idx = next(i for i, e in enumerate(events) if e.event_type == "phase_violation")
        assert violation_idx > tool_result_idx

        # Verify event data.
        violation = events[violation_idx]
        assert violation.data["error"] == "phase_violation"
        assert violation.data["tool"] == "write_file"
        assert violation.data["current_phase"] == "planning"
        assert violation.data["violation_kind"] == "tool_not_allowed"

    @pytest.mark.asyncio
    async def test_stream_yields_phase_violation_on_bash_block(self):
        """When the LLM calls shell with a dangerous command in PLANNING,
        execute_stream yields a phase_violation event with violation_kind
        = bash_command_blocked."""
        engine = ReActEngine(
            llm_gateway=llm_mock_gateway_with_response(
                tool_calls=[{"name": "shell", "arguments": {"command": "rm -rf /tmp"}}],
                content=None,
            ),
            phase_policy=default_policy(),
            max_steps=1,
        )
        engine._find_tool = lambda name, tools: None

        events = await self._collect_events(engine, tools=[])

        violation_events = [e for e in events if e.event_type == "phase_violation"]
        assert len(violation_events) == 1
        v = violation_events[0].data
        assert v["violation_kind"] == "bash_command_blocked"
        assert v["tool"] == "shell"
        assert "rm -rf /tmp" in v["command_preview"]

    @pytest.mark.asyncio
    async def test_stream_no_violation_when_tool_allowed(self):
        """When the LLM calls an allowed tool, no phase_violation event is yielded."""
        engine = ReActEngine(
            llm_gateway=llm_mock_gateway_with_response(
                tool_calls=[{"name": "search", "arguments": {"query": "foo"}}],
                content=None,
            ),
            phase_policy=default_policy(),
            max_steps=1,
        )
        # search is allowed in PLANNING; we still need _find_tool to return a
        # tool object so dispatch proceeds.
        search_tool = MagicMock()
        search_tool.input_schema = None
        search_tool.safe_execute = AsyncMock(return_value={"results": []})
        engine._find_tool = lambda name, tools: search_tool

        events = await self._collect_events(engine, tools=[search_tool])

        assert not any(e.event_type == "phase_violation" for e in events)

    @pytest.mark.asyncio
    async def test_stream_no_violation_without_policy(self):
        """Without a phase_policy, no phase_violation events are yielded —
        characterization of the no-policy path."""
        engine = ReActEngine(
            llm_gateway=llm_mock_gateway_with_response(
                tool_calls=[{"name": "any_tool", "arguments": {}}],
                content=None,
            ),
            max_steps=1,
        )
        any_tool = MagicMock()
        any_tool.input_schema = None
        any_tool.safe_execute = AsyncMock(return_value={"output": "ok"})
        engine._find_tool = lambda name, tools: any_tool

        events = await self._collect_events(engine, tools=[any_tool])

        assert not any(e.event_type == "phase_violation" for e in events)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers — minimal LLM gateway mocks for execute_stream tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def llm_mock_gateway():
    """Build a bare ``MagicMock`` to stand in for an LLM gateway.

    Enough for tests that only construct an engine and never stream from it.
    """
    gateway = MagicMock()
    return gateway
|
||||||
|
|
||||||
|
|
||||||
|
def llm_mock_gateway_with_response(tool_calls: list[dict], content: str | None):
|
||||||
|
"""Return a MagicMock LLM gateway whose chat_stream yields a single chunk
|
||||||
|
containing the given tool_calls (or content for a final-answer response).
|
||||||
|
|
||||||
|
The mock is shaped to match what execute_stream expects from
|
||||||
|
LLMGateway.chat_stream — an async iterable of chunks with attributes
|
||||||
|
`content`, `tool_calls`, `usage`, `model`.
|
||||||
|
"""
|
||||||
|
gateway = MagicMock()
|
||||||
|
|
||||||
|
# Build a fake chunk. execute_stream reads chunk.content, chunk.tool_calls,
|
||||||
|
# chunk.usage, chunk.model. The first three are typically accessed via
|
||||||
|
# attribute access; we make a small dataclass-like object.
|
||||||
|
class _Chunk:
|
||||||
|
def __init__(self, content, tool_calls, usage=None, model="default"):
|
||||||
|
self.content = content
|
||||||
|
self.tool_calls = tool_calls
|
||||||
|
self.usage = usage
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
# If tool_calls provided, emit a chunk with tool_calls (non-streaming path).
|
||||||
|
# Otherwise, emit a chunk with content (final answer path).
|
||||||
|
if tool_calls:
|
||||||
|
# Convert raw dicts to objects with .name/.arguments/.id attributes
|
||||||
|
# (LLMGateway normally returns tool_call objects).
|
||||||
|
class _TC:
|
||||||
|
def __init__(self, d):
|
||||||
|
self.name = d.get("name", "")
|
||||||
|
self.arguments = d.get("arguments", {})
|
||||||
|
self.id = d.get("id", "tc_test")
|
||||||
|
|
||||||
|
chunks = [_Chunk(content=None, tool_calls=[_TC(tc) for tc in tool_calls])]
|
||||||
|
# Follow with a final-answer chunk so execute_stream's loop exits
|
||||||
|
# cleanly after the tool call.
|
||||||
|
chunks.append(_Chunk(content="done", tool_calls=None))
|
||||||
|
else:
|
||||||
|
chunks = [_Chunk(content=content or "final answer", tool_calls=None)]
|
||||||
|
|
||||||
|
async def _fake_chat_stream(*args, **kwargs):
|
||||||
|
for c in chunks:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
gateway.chat_stream = _fake_chat_stream
|
||||||
|
return gateway
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue