From 4dc58c24bc453b380a4de863a9f2095266225ad4 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Tue, 30 Jun 2026 10:48:35 +0800 Subject: [PATCH] feat(U2): emit phase_violation WS event alongside LLM reinjection Wave 3 only injected the violation error dict back to the LLM as a tool result. Wave 4 U2 adds a parallel WS event so the frontend PhaseIndicator can surface violations to the user. - ReActEngine: add _phase_violations accumulator (list[dict]). Cleared in reset(). _check_phase_permission appends a structured violation dict (with new violation_kind field: tool_not_allowed | bash_command_blocked) before returning the error. - Add _drain_phase_violations(step) helper that pops pending violations and returns ReActEvent(event_type="phase_violation", ...) list. Events carry a shallow copy of the violation dict so callers can't mutate the accumulator. - execute_stream: drain after each tool_result yield at all 3 tool execution sites (parallel, serial-with-confirmation, parsed_calls). Non-streaming execute() ignores the accumulator (the LLM reinjection via the error dict is the only signal there). - chat.py WS handler: new elif branch forwards phase_violation ReActEvents to the client as {"type": "phase_violation", "data": ...} WS messages. - Tests: 11 new tests covering accumulator lifecycle, drain semantics, shallow-copy isolation, and execute_stream event emission for both tool_block and bash_block paths. 2 new WS forwarding tests pin the chat.py path (forward + characterization for REACT mode). --- src/agentkit/core/react.py | 56 ++++- src/agentkit/server/routes/chat.py | 11 + tests/unit/test_chat_plan_exec_ws.py | 142 +++++++++++ tests/unit/test_react_phase_enforcement.py | 280 +++++++++++++++++++++ 4 files changed, 487 insertions(+), 2 deletions(-) diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 166a5e3..8716df9 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -224,6 +224,12 @@ class ReActEngine: ) # Steps taken in the current phase (for auto-advance safety net). self._steps_in_phase: int = 0 + # Wave 4 U2: phase violation accumulator. _check_phase_permission + # appends here when a tool is blocked; execute_stream drains after each + # step and yields phase_violation ReActEvents. Non-streaming execute() + # simply ignores the accumulator (the error dict returned to the LLM is + # the only signal there). + self._phase_violations: list[dict[str, Any]] = [] def reset(self) -> None: """Reset internal state for reuse across conversations. @@ -241,6 +247,8 @@ class ReActEngine: if self._phase_policy is not None: self._current_phase = self._phase_policy.start_phase self._steps_in_phase = 0 + # Wave 4 U2: clear any pending violations from a prior run. + self._phase_violations = [] # ── U3/G6: phase state machine ──────────────────────────────────── @@ -299,11 +307,16 @@ class ReActEngine: AdvancePhaseTool or pick a different tool). Also applies the bash_command_filter for `bash` tool calls. + + Wave 4 U2: when blocked, the violation is also appended to + `self._phase_violations` so `execute_stream` can drain and yield + `phase_violation` ReActEvents to the WS layer (alongside the LLM + reinjection that the returned dict provides). """ if self._phase_policy is None or self._current_phase is None: return None if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase): - return { + violation = { "error": "phase_violation", "message": ( f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. " @@ -312,12 +325,15 @@ class ReActEngine: "current_phase": self._current_phase.value, "tool": tool_name, "is_error": True, + "violation_kind": "tool_not_allowed", } + self._phase_violations.append(violation) + return violation # Bash command filter (only applies to shell tool — registered as "shell"). if tool_name == "shell": command = str(arguments.get("command", "")) if not self._phase_policy.is_bash_command_allowed(command, self._current_phase): - return { + violation = { "error": "phase_violation", "message": ( f"Bash command blocked in {self._current_phase.value} phase " @@ -327,7 +343,11 @@ class ReActEngine: "current_phase": self._current_phase.value, "tool": tool_name, "is_error": True, + "violation_kind": "bash_command_blocked", + "command_preview": command[:200], } + self._phase_violations.append(violation) + return violation return None def _check_tool_loop(self, tool_calls: list[Any]) -> str | None: @@ -351,6 +371,28 @@ class ReActEngine: return hash_to_name.get(h) return None + def _drain_phase_violations(self, step: int) -> list[ReActEvent]: + """Pop and return ReActEvents for phase violations recorded by + ``_check_phase_permission`` since the last drain. + + Wave 4 U2: execute_stream calls this after each tool_result yield so + the WS layer can surface ``phase_violation`` events to the client + (alongside the LLM reinjection that the returned error dict provides). + Returns an empty list if no violations are pending. + """ + if not self._phase_violations: + return [] + events = [ + ReActEvent( + event_type="phase_violation", + step=step, + data=dict(v), + ) + for v in self._phase_violations + ] + self._phase_violations = [] + return events + async def execute( self, messages: list[dict[str, str]], @@ -1514,6 +1556,10 @@ class ReActEngine: step=step, data={"tool_name": tc.name, "result": tool_result}, ) + # Wave 4 U2: drain phase violations recorded by + # _check_phase_permission during this tool call. + for _ev in self._drain_phase_violations(step): + yield _ev tool_msg = await self._build_tool_result_message( tc.id, tool_result, effective_compressor, tc.name ) @@ -1652,6 +1698,9 @@ class ReActEngine: step=step, data={"tool_name": tc.name, "result": tool_result}, ) + # Wave 4 U2: drain phase violations. + for _ev in self._drain_phase_violations(step): + yield _ev tool_msg = await self._build_tool_result_message( tc.id, tool_result, effective_compressor, tc.name @@ -1721,6 +1770,9 @@ class ReActEngine: step=step, data={"tool_name": pc["name"], "result": tool_result}, ) + # Wave 4 U2: drain phase violations. + for _ev in self._drain_phase_violations(step): + yield _ev tool_msg = await self._build_tool_result_message( pc.get("id", f"text_tc_{step}"), tool_result, diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index 4b7be7f..41422a2 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -1280,6 +1280,17 @@ async def _handle_chat_message( "data": event.data, } ) + elif event.event_type == "phase_violation": + # Wave 4 U2: forward phase violations to the client so the + # frontend can surface them in the PhaseIndicator UI (alongside + # the LLM reinjection that already happens via the tool_result + # error dict). + await websocket.send_json( + { + "type": "phase_violation", + "data": event.data, + } + ) else: await websocket.send_json( { diff --git a/tests/unit/test_chat_plan_exec_ws.py b/tests/unit/test_chat_plan_exec_ws.py index 84fe358..aec66f4 100644 --- a/tests/unit/test_chat_plan_exec_ws.py +++ b/tests/unit/test_chat_plan_exec_ws.py @@ -529,3 +529,145 @@ async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat): sent_messages = [call.args[0] for call in ws.send_json.call_args_list] phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"] assert len(phase_events) == 0 + + +# --------------------------------------------------------------------------- +# Wave 4 U2 — phase_violation event forwarding +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_phase_violation_event_forwarded_to_client(app_with_chat): + """When ReActEngine yields a phase_violation ReActEvent, chat.py WS handler + must forward it as a `{"type": "phase_violation", "data": ...}` WS message + so the frontend PhaseIndicator can react.""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}]) + ws = _make_websocket_mock(app_with_chat) + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = PhaseState.PLANNING + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + from agentkit.core.react import ReActEvent + + # Simulate: tool_call → tool_result (blocked) → phase_violation + yield ReActEvent( + event_type="tool_call", + step=1, + data={"tool_name": "write_file", "arguments": {"path": "/x"}}, + ) + yield ReActEvent( + event_type="tool_result", + step=1, + data={ + "tool_name": "write_file", + "result": {"error": "phase_violation", "is_error": True}, + }, + ) + yield ReActEvent( + event_type="phase_violation", + step=1, + data={ + "error": "phase_violation", + "message": "Tool 'write_file' not allowed in planning phase.", + "current_phase": "planning", + "tool": "write_file", + "is_error": True, + "violation_kind": "tool_not_allowed", + }, + ) + yield ReActEvent( + event_type="final_answer", + step=2, + data={"output": "done"}, + ) + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="go", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + violation_messages = [m for m in sent_messages if m.get("type") == "phase_violation"] + assert len(violation_messages) == 1 + v = violation_messages[0]["data"] + assert v["error"] == "phase_violation" + assert v["tool"] == "write_file" + assert v["current_phase"] == "planning" + assert v["violation_kind"] == "tool_not_allowed" + + +@pytest.mark.asyncio +async def test_no_phase_violation_event_when_not_plan_exec(app_with_chat): + """Characterization: REACT mode → no phase_violation events forwarded + (the engine never yields them without a phase_policy).""" + from agentkit.server.routes import chat as chat_module + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.REACT) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}]) + ws = _make_websocket_mock(app_with_chat) + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = None + self._current_phase = None + + @property + def current_phase(self): + return None + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + from agentkit.core.react import ReActEvent + + # REACT mode: no phase_violation events yielded. + yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"}) + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="hi", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + violation_messages = [m for m in sent_messages if m.get("type") == "phase_violation"] + assert len(violation_messages) == 0 diff --git a/tests/unit/test_react_phase_enforcement.py b/tests/unit/test_react_phase_enforcement.py index ac9d681..f363395 100644 --- a/tests/unit/test_react_phase_enforcement.py +++ b/tests/unit/test_react_phase_enforcement.py @@ -337,3 +337,283 @@ class TestAdvancePhaseTool: assert "advance_phase" not in allowed, ( f"advance_phase must not be in {phase.value} whitelist" ) + + +# --------------------------------------------------------------------------- +# Wave 4 U2 — phase_violation accumulator + drain +# --------------------------------------------------------------------------- + + +class TestPhaseViolationAccumulator: + """_check_phase_permission records violations; _drain_phase_violations + yields them as ReActEvents and clears the accumulator.""" + + @pytest.fixture + def engine(self): + return ReActEngine( + llm_gateway=MagicMock(), + phase_policy=default_policy(), + ) + + def test_violation_appended_on_tool_block(self, engine): + # write_file is blocked in PLANNING. + engine._check_phase_permission("write_file", {}) + assert len(engine._phase_violations) == 1 + v = engine._phase_violations[0] + assert v["error"] == "phase_violation" + assert v["tool"] == "write_file" + assert v["current_phase"] == "planning" + assert v["violation_kind"] == "tool_not_allowed" + + def test_violation_appended_on_bash_block(self, engine): + engine._check_phase_permission("shell", {"command": "rm -rf /tmp"}) + assert len(engine._phase_violations) == 1 + v = engine._phase_violations[0] + assert v["violation_kind"] == "bash_command_blocked" + assert v["tool"] == "shell" + assert v["command_preview"] == "rm -rf /tmp" + + def test_no_violation_when_allowed(self, engine): + # search is allowed in PLANNING. + engine._check_phase_permission("search", {}) + assert engine._phase_violations == [] + + def test_no_violation_without_policy(self): + engine = ReActEngine(llm_gateway=MagicMock()) # no policy + engine._check_phase_permission("anything", {}) + assert engine._phase_violations == [] + + def test_drain_returns_events_and_clears(self, engine): + # Trigger two violations. + engine._check_phase_permission("write_file", {"path": "/a"}) + engine._check_phase_permission("write_file", {"path": "/b"}) + assert len(engine._phase_violations) == 2 + + events = engine._drain_phase_violations(step=3) + assert len(events) == 2 + assert all(e.event_type == "phase_violation" for e in events) + assert all(e.step == 3 for e in events) + # Each event data is a copy (caller can't mutate the accumulator). + assert events[0].data["tool"] == "write_file" + # Accumulator cleared after drain. + assert engine._phase_violations == [] + + def test_drain_empty_returns_empty(self, engine): + assert engine._drain_phase_violations(step=1) == [] + + def test_drain_returns_shallow_copy(self, engine): + """Drained event data must not alias the original violation dict — + mutating one must not mutate the other.""" + engine._check_phase_permission("write_file", {}) + events = engine._drain_phase_violations(step=1) + # Mutate the drained event data. + events[0].data["tool"] = "MUTATED" + # Original accumulator (now empty) is unaffected — but more importantly, + # a fresh violation recorded after drain is unaffected. + engine._check_phase_permission("write_file", {}) + new_violations = engine._phase_violations + assert new_violations[0]["tool"] == "write_file" # not "MUTATED" + + def test_reset_clears_violations(self, engine): + engine._check_phase_permission("write_file", {}) + assert len(engine._phase_violations) == 1 + engine.reset() + assert engine._phase_violations == [] + + +# --------------------------------------------------------------------------- +# Wave 4 U2 — execute_stream yields phase_violation events +# --------------------------------------------------------------------------- + + +class TestExecuteStreamPhaseViolationEvents: + """execute_stream must yield phase_violation ReActEvents when a tool is + blocked by _check_phase_permission. The events are drained after each + tool_result yield.""" + + @pytest.mark.asyncio + async def test_stream_yields_phase_violation_on_tool_block(self): + """When the LLM calls a tool blocked by the phase policy, execute_stream + yields a tool_call event, a tool_result event (with the error dict), + and a phase_violation event.""" + from agentkit.core.react import ReActEvent + + engine = ReActEngine( + llm_gateway=llm_mock_gateway_with_response( + tool_calls=[{"name": "write_file", "arguments": {"path": "/x"}}], + content=None, + ), + phase_policy=default_policy(), + max_steps=1, + ) + # Patch _find_tool so we don't need real tools registered. write_file + # should be blocked by phase_policy before _find_tool is called. + engine._find_tool = lambda name, tools: None + + events: list[ReActEvent] = [] + async for ev in engine.execute_stream( + messages=[{"role": "user", "content": "test"}], + tools=[], + ): + events.append(ev) + + # Expect: thinking → tool_call → tool_result → phase_violation → final_answer + event_types = [e.event_type for e in events] + assert "tool_call" in event_types + assert "tool_result" in event_types + assert "phase_violation" in event_types + + # The phase_violation event must come AFTER tool_result. + tool_result_idx = next(i for i, e in enumerate(events) if e.event_type == "tool_result") + violation_idx = next(i for i, e in enumerate(events) if e.event_type == "phase_violation") + assert violation_idx > tool_result_idx + + # Verify event data. + violation = events[violation_idx] + assert violation.data["error"] == "phase_violation" + assert violation.data["tool"] == "write_file" + assert violation.data["current_phase"] == "planning" + assert violation.data["violation_kind"] == "tool_not_allowed" + + @pytest.mark.asyncio + async def test_stream_yields_phase_violation_on_bash_block(self): + """When the LLM calls shell with a dangerous command in PLANNING, + execute_stream yields a phase_violation event with violation_kind + = bash_command_blocked.""" + from agentkit.core.react import ReActEvent + + engine = ReActEngine( + llm_gateway=llm_mock_gateway_with_response( + tool_calls=[{"name": "shell", "arguments": {"command": "rm -rf /tmp"}}], + content=None, + ), + phase_policy=default_policy(), + max_steps=1, + ) + engine._find_tool = lambda name, tools: None + + events: list[ReActEvent] = [] + async for ev in engine.execute_stream( + messages=[{"role": "user", "content": "test"}], + tools=[], + ): + events.append(ev) + + violation_events = [e for e in events if e.event_type == "phase_violation"] + assert len(violation_events) == 1 + v = violation_events[0].data + assert v["violation_kind"] == "bash_command_blocked" + assert v["tool"] == "shell" + assert "rm -rf /tmp" in v["command_preview"] + + @pytest.mark.asyncio + async def test_stream_no_violation_when_tool_allowed(self): + """When the LLM calls an allowed tool, no phase_violation event is yielded.""" + from agentkit.core.react import ReActEvent + + engine = ReActEngine( + llm_gateway=llm_mock_gateway_with_response( + tool_calls=[{"name": "search", "arguments": {"query": "foo"}}], + content=None, + ), + phase_policy=default_policy(), + max_steps=1, + ) + # search is allowed in PLANNING; we still need _find_tool to return a + # tool object so dispatch proceeds. + search_tool = MagicMock() + search_tool.input_schema = None + search_tool.safe_execute = AsyncMock(return_value={"results": []}) + engine._find_tool = lambda name, tools: search_tool + + events: list[ReActEvent] = [] + async for ev in engine.execute_stream( + messages=[{"role": "user", "content": "test"}], + tools=[search_tool], + ): + events.append(ev) + + assert not any(e.event_type == "phase_violation" for e in events) + + @pytest.mark.asyncio + async def test_stream_no_violation_without_policy(self): + """Without a phase_policy, no phase_violation events are yielded — + characterization of the no-policy path.""" + from agentkit.core.react import ReActEvent + + engine = ReActEngine( + llm_gateway=llm_mock_gateway_with_response( + tool_calls=[{"name": "any_tool", "arguments": {}}], + content=None, + ), + max_steps=1, + ) + any_tool = MagicMock() + any_tool.input_schema = None + any_tool.safe_execute = AsyncMock(return_value={"output": "ok"}) + engine._find_tool = lambda name, tools: any_tool + + events: list[ReActEvent] = [] + async for ev in engine.execute_stream( + messages=[{"role": "user", "content": "test"}], + tools=[any_tool], + ): + events.append(ev) + + assert not any(e.event_type == "phase_violation" for e in events) + + +# --------------------------------------------------------------------------- +# Helpers — minimal LLM gateway mocks for execute_stream tests +# --------------------------------------------------------------------------- + + +def llm_mock_gateway(): + """Return a MagicMock LLM gateway (sufficient for constructor tests).""" + return MagicMock() + + +def llm_mock_gateway_with_response(tool_calls: list[dict], content: str | None): + """Return a MagicMock LLM gateway whose chat_stream yields a single chunk + containing the given tool_calls (or content for a final-answer response). + + The mock is shaped to match what execute_stream expects from + LLMGateway.chat_stream — an async iterable of chunks with attributes + `content`, `tool_calls`, `usage`, `model`. + """ + gateway = MagicMock() + + # Build a fake chunk. execute_stream reads chunk.content, chunk.tool_calls, + # chunk.usage, chunk.model. The first three are typically accessed via + # attribute access; we make a small dataclass-like object. + class _Chunk: + def __init__(self, content, tool_calls, usage=None, model="default"): + self.content = content + self.tool_calls = tool_calls + self.usage = usage + self.model = model + + # If tool_calls provided, emit a chunk with tool_calls (non-streaming path). + # Otherwise, emit a chunk with content (final answer path). + if tool_calls: + # Convert raw dicts to objects with .name/.arguments/.id attributes + # (LLMGateway normally returns tool_call objects). + class _TC: + def __init__(self, d): + self.name = d.get("name", "") + self.arguments = d.get("arguments", {}) + self.id = d.get("id", "tc_test") + + chunks = [_Chunk(content=None, tool_calls=[_TC(tc) for tc in tool_calls])] + # Follow with a final-answer chunk so execute_stream's loop exits + # cleanly after the tool call. + chunks.append(_Chunk(content="done", tool_calls=None)) + else: + chunks = [_Chunk(content=content or "final answer", tool_calls=None)] + + async def _fake_chat_stream(*args, **kwargs): + for c in chunks: + yield c + + gateway.chat_stream = _fake_chat_stream + return gateway