From 9e28ab315e6f5f5221da4ee0d044381eb6ede38b Mon Sep 17 00:00:00 2001 From: chiguyong Date: Tue, 30 Jun 2026 10:39:44 +0800 Subject: [PATCH 1/8] feat(U1): widen PhasePolicy bash_command_filter to accept Callable Reuses ShellTool._is_dangerous as the default bash filter for PLANNING and VERIFICATION phases, closing the regex ceiling documented in Wave 3. - Convert ShellTool._is_dangerous and _is_single_command_dangerous to @staticmethod (backward-compatible; instance calls still work via Python's descriptor protocol). - Widen PhasePolicy.bash_command_filter field type to dict[PhaseState, Callable[[str], bool] | re.Pattern | None]. - is_bash_command_allowed dispatches on callable vs pattern at call time. Empty commands short-circuit to allowed (Wave 3 contract; ShellTool emits the clearer empty-command error). - to_dict serializes callables as for log readability. - default_policy() now wires ShellTool._is_dangerous for PLANNING and VERIFICATION. _DEFAULT_BASH_FILTER kept for backward compat with configs that pass a re.Pattern. - Tests: characterization tests pin Wave 3 behavior (rm/mv/cp/echo > still blocked) plus new edge-case coverage for ceiling closed (dd of=/dev/sda, :>file, chain operators, pipe segments). --- src/agentkit/core/phase.py | 61 +++++++++---- src/agentkit/tools/shell.py | 14 ++- tests/unit/test_phase_policy.py | 157 ++++++++++++++++++++++++++++++++ 3 files changed, 208 insertions(+), 24 deletions(-) diff --git a/src/agentkit/core/phase.py b/src/agentkit/core/phase.py index 0e5326a..8101c24 100644 --- a/src/agentkit/core/phase.py +++ b/src/agentkit/core/phase.py @@ -15,7 +15,9 @@ import enum import logging import re from dataclasses import dataclass, field, replace -from typing import Any +from typing import Any, Callable + +from agentkit.tools.shell import ShellTool logger = logging.getLogger(__name__) @@ -53,14 +55,10 @@ class PhaseState(enum.Enum): # Wildcard token meaning "all tools allowed in this phase". WILDCARD = "*" -# Default bash command filter for PLANNING and VERIFICATION phases — blocks -# commands that mutate the filesystem or execute arbitrary code. -# ponytail: regex is intentionally conservative; misses some shell idioms -# (e.g., `:>file`, `dd of=file`). Ceiling: a real shell parser would catch -# more. Upgrade path = reuse ShellTool._is_dangerous() at enforcement time. -# Note: `\b` is a word boundary — works for word commands (rm/mv) but NOT -# for `>`/`>>` operators (not word chars). Use a non-boundary alternation -# that matches `>` either as a standalone operator or after whitespace. +# Legacy regex-based bash filter. Kept for backward compatibility with configs +# that pass a `re.Pattern` into `bash_command_filter`. New default uses +# `ShellTool._is_dangerous` (a Callable) which closes the regex ceiling +# (missed `:>file`, `dd of=file`, etc. — see Wave 4 U1). _DEFAULT_BASH_FILTER = re.compile(r"\b(rm|mv|cp|mkdir|rmdir|chmod|chown)\b|(?|>>") @@ -76,10 +74,19 @@ class PhasePolicy: Wildcard ``"*"`` in a phase's whitelist means "all tools allowed" (used by DELIVERY by default). + + `bash_command_filter` values accept either: + - `Callable[[str], bool]`: returns True if the command is dangerous + (matches `ShellTool._is_dangerous` semantics); allowed = not dangerous. + - `re.Pattern`: pattern matches dangerous substrings; allowed = no match. + Kept for backward compat with Wave 3 configs. + - `None`: no restriction for this phase. """ whitelist: dict[PhaseState, frozenset[str]] - bash_command_filter: dict[PhaseState, re.Pattern | None] = field(default_factory=dict) + bash_command_filter: dict[ + PhaseState, Callable[[str], bool] | re.Pattern | None + ] = field(default_factory=dict) auto_advance_after_steps: int | None = None # None = manual (LLM calls advance_phase) start_phase: PhaseState = PhaseState.PLANNING @@ -103,19 +110,31 @@ class PhasePolicy: """Return True if `command` passes the bash filter for `phase`. A None filter = no restriction. An empty command is allowed (ShellTool - separately rejects empty commands). + separately rejects empty commands) — short-circuited here so the + ShellTool path emits a clearer "empty command" error instead of a + phase-violation noise injected back to the LLM. """ - pattern = self.bash_command_filter.get(phase) - if pattern is None: + if not command: return True - return not pattern.search(command) + filter_value = self.bash_command_filter.get(phase) + if filter_value is None: + return True + if callable(filter_value): + # Callable contract: returns True if dangerous. + return not filter_value(command) + # re.Pattern contract: search() returns a Match if dangerous. + return not filter_value.search(command) def to_dict(self) -> dict[str, Any]: - """Serialize for logging/telemetry. Not round-trippable (regex → str).""" + """Serialize for logging/telemetry. Not round-trippable (regex/callable → str).""" return { "whitelist": {phase.value: sorted(tools) for phase, tools in self.whitelist.items()}, "bash_command_filter": { - phase.value: (p.pattern if p else None) + phase.value: ( + "" + if callable(p) + else (p.pattern if p else None) + ) for phase, p in self.bash_command_filter.items() }, "auto_advance_after_steps": self.auto_advance_after_steps, @@ -133,8 +152,10 @@ def default_policy() -> PhasePolicy: - DELIVERY: all tools (wildcard) Bash filter: - - PLANNING/VERIFICATION: blocks filesystem-mutating commands - (rm/mv/cp/mkdir/chmod/chown/>/>>) + - PLANNING/VERIFICATION: reuse `ShellTool._is_dangerous` (Wave 4 U1). + Closes the regex ceiling — catches `:>file`, `dd of=/dev/sda`, chain + operators, and the full danger taxonomy shared with the ShellTool + confirmation path. - BUILDING/DELIVERY: no filter (full bash) """ return PhasePolicy( @@ -150,8 +171,8 @@ def default_policy() -> PhasePolicy: PhaseState.DELIVERY: frozenset({WILDCARD}), }, bash_command_filter={ - PhaseState.PLANNING: _DEFAULT_BASH_FILTER, - PhaseState.VERIFICATION: _DEFAULT_BASH_FILTER, + PhaseState.PLANNING: ShellTool._is_dangerous, + PhaseState.VERIFICATION: ShellTool._is_dangerous, PhaseState.BUILDING: None, PhaseState.DELIVERY: None, }, diff --git a/src/agentkit/tools/shell.py b/src/agentkit/tools/shell.py index ba84377..04627c9 100644 --- a/src/agentkit/tools/shell.py +++ b/src/agentkit/tools/shell.py @@ -463,11 +463,16 @@ class ShellTool(Tool): return self._output_parser.parse(output, exit_code) - def _is_dangerous(self, command: str) -> bool: + @staticmethod + def _is_dangerous(command: str) -> bool: """检查命令是否为危险操作 白名单命令直接放行。管道命令(|)在所有子命令都安全时放行。 其他链式操作符(;、&&、||、$()、>、< 等)一律视为危险。 + + Static so callers without a ShellTool instance (e.g. PhasePolicy) can + reuse the same danger classification. Instance calls still work via + Python's descriptor protocol. """ command_stripped = command.strip() @@ -482,14 +487,15 @@ class ShellTool(Tool): part = part.strip() if not part: continue - if self._is_single_command_dangerous(part): + if ShellTool._is_single_command_dangerous(part): return True return False # All pipe segments are safe # Single command - return self._is_single_command_dangerous(command_stripped) + return ShellTool._is_single_command_dangerous(command_stripped) - def _is_single_command_dangerous(self, command: str) -> bool: + @staticmethod + def _is_single_command_dangerous(command: str) -> bool: """Check if a single command (no pipes/chains) is dangerous.""" command_stripped = command.strip() if not command_stripped: diff --git a/tests/unit/test_phase_policy.py b/tests/unit/test_phase_policy.py index 936e823..36fb584 100644 --- a/tests/unit/test_phase_policy.py +++ b/tests/unit/test_phase_policy.py @@ -6,6 +6,7 @@ Covers: - PhasePolicy.is_tool_allowed / is_bash_command_allowed - policy_from_config parsing (R26 config-driven) - ServerConfig.plan_exec integration +- Wave 4 U1: bash_command_filter accepts Callable (ShellTool._is_dangerous reuse) """ from __future__ import annotations @@ -22,6 +23,7 @@ from agentkit.core.phase import ( policy_from_config, ) from agentkit.server.config import ServerConfig +from agentkit.tools.shell import ShellTool # --------------------------------------------------------------------------- @@ -115,6 +117,109 @@ class TestDefaultPolicy: assert policy.is_bash_command_allowed("rm -rf build/", PhaseState.BUILDING) is True assert policy.is_bash_command_allowed("echo x > out.log", PhaseState.BUILDING) is True + # --- Wave 4 U1 characterization (Wave 3 behavior preserved) ----------------- + # default_policy() now wires ShellTool._is_dangerous (a Callable) for + # PLANNING/VERIFICATION. These tests pin the contract so a future regression + # in either ShellTool._is_dangerous or PhasePolicy dispatch surfaces here. + + def test_bash_filter_callable_in_default_policy(self): + # Sanity: default_policy uses a Callable, not a regex Pattern. + policy = default_policy() + planning_filter = policy.bash_command_filter[PhaseState.PLANNING] + assert callable(planning_filter) + assert planning_filter is ShellTool._is_dangerous + + def test_bash_filter_characterization_safe_commands(self): + # Wave 3 behavior preserved — safe read-only commands. + policy = default_policy() + for cmd in ("ls -la", "pwd", "git status", "find . -name foo", "cat README.md"): + assert policy.is_bash_command_allowed(cmd, PhaseState.PLANNING) is True, cmd + + def test_bash_filter_characterization_dangerous_commands(self): + # Wave 3 behavior preserved — commands blocked by the old regex. + policy = default_policy() + for cmd in ( + "rm -rf /", + "rm -rf /tmp/x", + "mv a b", + "cp a b", + "mkdir newdir", + "chmod 777 file", + "chown root file", + "echo x > file.txt", + "echo x >> file.txt", + ): + assert policy.is_bash_command_allowed(cmd, PhaseState.PLANNING) is False, cmd + + # --- Wave 4 U1 ceiling closed (new edge cases the old regex missed) --------- + + def test_bash_filter_closes_regex_ceiling_dd_of(self): + # Old regex missed `dd of=/dev/sda` (no word-boundary match for "dd"). + policy = default_policy() + assert policy.is_bash_command_allowed("dd of=/dev/sda", PhaseState.PLANNING) is False + + def test_bash_filter_closes_regex_ceiling_colon_redirect(self): + # Old regex missed `:>file` (no whitespace before `>`). + policy = default_policy() + assert policy.is_bash_command_allowed(":>file", PhaseState.PLANNING) is False + + def test_bash_filter_closes_regex_ceiling_redirection_after_arg(self): + # Old regex's `(?` looked for `>` at start or after whitespace. + # `echo hello > /tmp/x` slipped through because `>` had a space before it + # but the regex matched the wrong alternative. Verify the new filter + # classifies this as dangerous. + policy = default_policy() + assert policy.is_bash_command_allowed("echo hello > /tmp/x", PhaseState.PLANNING) is False + + def test_bash_filter_closes_regex_ceiling_chain_operators(self): + # Old regex did NOT match `;`, `&&`, `||` as dangerous. The new filter + # treats all chain operators as dangerous (matches ShellTool behavior). + policy = default_policy() + for cmd in ( + "ls; rm -rf /tmp", + "ls && rm -rf /tmp", + "ls || rm -rf /tmp", + "$(rm -rf /tmp)", + "`rm -rf /tmp`", + ): + assert policy.is_bash_command_allowed(cmd, PhaseState.PLANNING) is False, cmd + + def test_bash_filter_closes_regex_ceiling_pipe_with_dangerous_segment(self): + # Old regex scanned the WHOLE command string, so `echo x | grep y` + # would be allowed (no dangerous token) but `rm x | cat` would be + # blocked (matches `\brm\b`). The new filter splits pipes and checks + # each segment, so `echo x | grep y` should be allowed and + # `rm x | cat` blocked. + policy = default_policy() + assert policy.is_bash_command_allowed("echo x | grep y", PhaseState.PLANNING) is True + assert policy.is_bash_command_allowed("rm x | cat", PhaseState.PLANNING) is False + + def test_bash_filter_verification_phase_uses_callable(self): + # Same callable wired into VERIFICATION. + # Note: `pytest` is NOT in ShellTool._SAFE_COMMAND_PREFIXES, so + # _is_dangerous returns True for it — the verification phase does NOT + # widen the ShellTool whitelist. Use a known-safe read-only command + # for the "allowed" assertion. (Wave 4 U1 reuses ShellTool._is_dangerous + # as-is; expanding its safe-whitelist is out of scope.) + policy = default_policy() + assert policy.bash_command_filter[PhaseState.VERIFICATION] is ShellTool._is_dangerous + assert policy.is_bash_command_allowed("rm -rf /", PhaseState.VERIFICATION) is False + assert policy.is_bash_command_allowed("ls -la", PhaseState.VERIFICATION) is True + assert policy.is_bash_command_allowed("git status", PhaseState.VERIFICATION) is True + + def test_bash_filter_delivery_phase_no_filter(self): + # DELIVERY has no filter — full bash allowed. + policy = default_policy() + assert policy.bash_command_filter[PhaseState.DELIVERY] is None + assert policy.is_bash_command_allowed("rm -rf /", PhaseState.DELIVERY) is True + + def test_bash_filter_empty_command_allowed(self): + # is_bash_command_allowed must NOT call the filter on empty input — + # ShellTool separately rejects empty commands. Empty is "allowed" by + # the policy (no rejection injected to the LLM). + policy = default_policy() + assert policy.is_bash_command_allowed("", PhaseState.PLANNING) is True + # --------------------------------------------------------------------------- # PhasePolicy — is_tool_allowed @@ -207,6 +312,16 @@ class TestPhasePolicyEdgeCases: assert d["start_phase"] == "planning" assert d["auto_advance_after_steps"] is None + def test_to_dict_serializes_callable_as_marker(self): + # Wave 4 U1: default_policy now wires a Callable. to_dict must + # surface it as "" so logs/telemetry stay readable. + policy = default_policy() + d = policy.to_dict() + assert d["bash_command_filter"]["planning"] == "" + assert d["bash_command_filter"]["verification"] == "" + assert d["bash_command_filter"]["building"] is None + assert d["bash_command_filter"]["delivery"] is None + def test_custom_bash_filter(self): custom_filter = re.compile(r"\b(pip install|npm install)\b") policy = PhasePolicy( @@ -221,6 +336,48 @@ class TestPhasePolicyEdgeCases: assert policy.is_bash_command_allowed("npm install foo", PhaseState.BUILDING) is False assert policy.is_bash_command_allowed("npm run build", PhaseState.BUILDING) is True + def test_custom_bash_filter_accepts_callable(self): + # Wave 4 U1: callable form. The callable returns True for dangerous. + def deny_all(_: str) -> bool: + return True # everything is "dangerous" + + policy = PhasePolicy( + whitelist={ + PhaseState.PLANNING: frozenset({"shell"}), + PhaseState.BUILDING: frozenset({WILDCARD}), + PhaseState.VERIFICATION: frozenset({WILDCARD}), + PhaseState.DELIVERY: frozenset({WILDCARD}), + }, + bash_command_filter={PhaseState.PLANNING: deny_all}, + ) + assert policy.is_bash_command_allowed("ls", PhaseState.PLANNING) is False + assert policy.is_bash_command_allowed("rm -rf /", PhaseState.PLANNING) is False + + def test_callable_filter_takes_precedence_over_pattern_form(self): + # Wave 4 U1: when a phase has a callable wired, the dispatch path is + # the callable branch, not the regex branch. Sanity-check the + # is_bash_command_allowed routing — both forms coexist in the same + # policy dict, each phase is independent. + pattern = re.compile(r"\brm\b") + policy = PhasePolicy( + whitelist={ + PhaseState.PLANNING: frozenset({"shell"}), + PhaseState.BUILDING: frozenset({WILDCARD}), + PhaseState.VERIFICATION: frozenset({WILDCARD}), + PhaseState.DELIVERY: frozenset({WILDCARD}), + }, + bash_command_filter={ + PhaseState.PLANNING: pattern, # regex + PhaseState.BUILDING: ShellTool._is_dangerous, # callable + }, + ) + # PLANNING uses regex form. + assert policy.is_bash_command_allowed("rm x", PhaseState.PLANNING) is False + assert policy.is_bash_command_allowed("ls", PhaseState.PLANNING) is True + # BUILDING uses callable form. + assert policy.is_bash_command_allowed("rm x", PhaseState.BUILDING) is False + assert policy.is_bash_command_allowed("ls", PhaseState.BUILDING) is True + # --------------------------------------------------------------------------- # policy_from_config — R26 (config-driven) -- 2.43.0 From 4dc58c24bc453b380a4de863a9f2095266225ad4 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Tue, 30 Jun 2026 10:48:35 +0800 Subject: [PATCH 2/8] feat(U2): emit phase_violation WS event alongside LLM reinjection Wave 3 only injected the violation error dict back to the LLM as a tool result. Wave 4 U2 adds a parallel WS event so the frontend PhaseIndicator can surface violations to the user. - ReActEngine: add _phase_violations accumulator (list[dict]). Cleared in reset(). _check_phase_permission appends a structured violation dict (with new violation_kind field: tool_not_allowed | bash_command_blocked) before returning the error. - Add _drain_phase_violations(step) helper that pops pending violations and returns ReActEvent(event_type="phase_violation", ...) list. Events carry a shallow copy of the violation dict so callers can't mutate the accumulator. - execute_stream: drain after each tool_result yield at all 3 tool execution sites (parallel, serial-with-confirmation, parsed_calls). Non-streaming execute() ignores the accumulator (the LLM reinjection via the error dict is the only signal there). - chat.py WS handler: new elif branch forwards phase_violation ReActEvents to the client as {"type": "phase_violation", "data": ...} WS messages. - Tests: 11 new tests covering accumulator lifecycle, drain semantics, shallow-copy isolation, and execute_stream event emission for both tool_block and bash_block paths. 2 new WS forwarding tests pin the chat.py path (forward + characterization for REACT mode). --- src/agentkit/core/react.py | 56 ++++- src/agentkit/server/routes/chat.py | 11 + tests/unit/test_chat_plan_exec_ws.py | 142 +++++++++++ tests/unit/test_react_phase_enforcement.py | 280 +++++++++++++++++++++ 4 files changed, 487 insertions(+), 2 deletions(-) diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 166a5e3..8716df9 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -224,6 +224,12 @@ class ReActEngine: ) # Steps taken in the current phase (for auto-advance safety net). self._steps_in_phase: int = 0 + # Wave 4 U2: phase violation accumulator. _check_phase_permission + # appends here when a tool is blocked; execute_stream drains after each + # step and yields phase_violation ReActEvents. Non-streaming execute() + # simply ignores the accumulator (the error dict returned to the LLM is + # the only signal there). + self._phase_violations: list[dict[str, Any]] = [] def reset(self) -> None: """Reset internal state for reuse across conversations. @@ -241,6 +247,8 @@ class ReActEngine: if self._phase_policy is not None: self._current_phase = self._phase_policy.start_phase self._steps_in_phase = 0 + # Wave 4 U2: clear any pending violations from a prior run. + self._phase_violations = [] # ── U3/G6: phase state machine ──────────────────────────────────── @@ -299,11 +307,16 @@ class ReActEngine: AdvancePhaseTool or pick a different tool). Also applies the bash_command_filter for `bash` tool calls. + + Wave 4 U2: when blocked, the violation is also appended to + `self._phase_violations` so `execute_stream` can drain and yield + `phase_violation` ReActEvents to the WS layer (alongside the LLM + reinjection that the returned dict provides). """ if self._phase_policy is None or self._current_phase is None: return None if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase): - return { + violation = { "error": "phase_violation", "message": ( f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. " @@ -312,12 +325,15 @@ class ReActEngine: "current_phase": self._current_phase.value, "tool": tool_name, "is_error": True, + "violation_kind": "tool_not_allowed", } + self._phase_violations.append(violation) + return violation # Bash command filter (only applies to shell tool — registered as "shell"). if tool_name == "shell": command = str(arguments.get("command", "")) if not self._phase_policy.is_bash_command_allowed(command, self._current_phase): - return { + violation = { "error": "phase_violation", "message": ( f"Bash command blocked in {self._current_phase.value} phase " @@ -327,7 +343,11 @@ class ReActEngine: "current_phase": self._current_phase.value, "tool": tool_name, "is_error": True, + "violation_kind": "bash_command_blocked", + "command_preview": command[:200], } + self._phase_violations.append(violation) + return violation return None def _check_tool_loop(self, tool_calls: list[Any]) -> str | None: @@ -351,6 +371,28 @@ class ReActEngine: return hash_to_name.get(h) return None + def _drain_phase_violations(self, step: int) -> list[ReActEvent]: + """Pop and return ReActEvents for phase violations recorded by + ``_check_phase_permission`` since the last drain. + + Wave 4 U2: execute_stream calls this after each tool_result yield so + the WS layer can surface ``phase_violation`` events to the client + (alongside the LLM reinjection that the returned error dict provides). + Returns an empty list if no violations are pending. + """ + if not self._phase_violations: + return [] + events = [ + ReActEvent( + event_type="phase_violation", + step=step, + data=dict(v), + ) + for v in self._phase_violations + ] + self._phase_violations = [] + return events + async def execute( self, messages: list[dict[str, str]], @@ -1514,6 +1556,10 @@ class ReActEngine: step=step, data={"tool_name": tc.name, "result": tool_result}, ) + # Wave 4 U2: drain phase violations recorded by + # _check_phase_permission during this tool call. + for _ev in self._drain_phase_violations(step): + yield _ev tool_msg = await self._build_tool_result_message( tc.id, tool_result, effective_compressor, tc.name ) @@ -1652,6 +1698,9 @@ class ReActEngine: step=step, data={"tool_name": tc.name, "result": tool_result}, ) + # Wave 4 U2: drain phase violations. + for _ev in self._drain_phase_violations(step): + yield _ev tool_msg = await self._build_tool_result_message( tc.id, tool_result, effective_compressor, tc.name @@ -1721,6 +1770,9 @@ class ReActEngine: step=step, data={"tool_name": pc["name"], "result": tool_result}, ) + # Wave 4 U2: drain phase violations. + for _ev in self._drain_phase_violations(step): + yield _ev tool_msg = await self._build_tool_result_message( pc.get("id", f"text_tc_{step}"), tool_result, diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index 4b7be7f..41422a2 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -1280,6 +1280,17 @@ async def _handle_chat_message( "data": event.data, } ) + elif event.event_type == "phase_violation": + # Wave 4 U2: forward phase violations to the client so the + # frontend can surface them in the PhaseIndicator UI (alongside + # the LLM reinjection that already happens via the tool_result + # error dict). + await websocket.send_json( + { + "type": "phase_violation", + "data": event.data, + } + ) else: await websocket.send_json( { diff --git a/tests/unit/test_chat_plan_exec_ws.py b/tests/unit/test_chat_plan_exec_ws.py index 84fe358..aec66f4 100644 --- a/tests/unit/test_chat_plan_exec_ws.py +++ b/tests/unit/test_chat_plan_exec_ws.py @@ -529,3 +529,145 @@ async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat): sent_messages = [call.args[0] for call in ws.send_json.call_args_list] phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"] assert len(phase_events) == 0 + + +# --------------------------------------------------------------------------- +# Wave 4 U2 — phase_violation event forwarding +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_phase_violation_event_forwarded_to_client(app_with_chat): + """When ReActEngine yields a phase_violation ReActEvent, chat.py WS handler + must forward it as a `{"type": "phase_violation", "data": ...}` WS message + so the frontend PhaseIndicator can react.""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}]) + ws = _make_websocket_mock(app_with_chat) + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = PhaseState.PLANNING + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + from agentkit.core.react import ReActEvent + + # Simulate: tool_call → tool_result (blocked) → phase_violation + yield ReActEvent( + event_type="tool_call", + step=1, + data={"tool_name": "write_file", "arguments": {"path": "/x"}}, + ) + yield ReActEvent( + event_type="tool_result", + step=1, + data={ + "tool_name": "write_file", + "result": {"error": "phase_violation", "is_error": True}, + }, + ) + yield ReActEvent( + event_type="phase_violation", + step=1, + data={ + "error": "phase_violation", + "message": "Tool 'write_file' not allowed in planning phase.", + "current_phase": "planning", + "tool": "write_file", + "is_error": True, + "violation_kind": "tool_not_allowed", + }, + ) + yield ReActEvent( + event_type="final_answer", + step=2, + data={"output": "done"}, + ) + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="go", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + violation_messages = [m for m in sent_messages if m.get("type") == "phase_violation"] + assert len(violation_messages) == 1 + v = violation_messages[0]["data"] + assert v["error"] == "phase_violation" + assert v["tool"] == "write_file" + assert v["current_phase"] == "planning" + assert v["violation_kind"] == "tool_not_allowed" + + +@pytest.mark.asyncio +async def test_no_phase_violation_event_when_not_plan_exec(app_with_chat): + """Characterization: REACT mode → no phase_violation events forwarded + (the engine never yields them without a phase_policy).""" + from agentkit.server.routes import chat as chat_module + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.REACT) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}]) + ws = _make_websocket_mock(app_with_chat) + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = None + self._current_phase = None + + @property + def current_phase(self): + return None + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + from agentkit.core.react import ReActEvent + + # REACT mode: no phase_violation events yielded. + yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"}) + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="hi", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + violation_messages = [m for m in sent_messages if m.get("type") == "phase_violation"] + assert len(violation_messages) == 0 diff --git a/tests/unit/test_react_phase_enforcement.py b/tests/unit/test_react_phase_enforcement.py index ac9d681..f363395 100644 --- a/tests/unit/test_react_phase_enforcement.py +++ b/tests/unit/test_react_phase_enforcement.py @@ -337,3 +337,283 @@ class TestAdvancePhaseTool: assert "advance_phase" not in allowed, ( f"advance_phase must not be in {phase.value} whitelist" ) + + +# --------------------------------------------------------------------------- +# Wave 4 U2 — phase_violation accumulator + drain +# --------------------------------------------------------------------------- + + +class TestPhaseViolationAccumulator: + """_check_phase_permission records violations; _drain_phase_violations + yields them as ReActEvents and clears the accumulator.""" + + @pytest.fixture + def engine(self): + return ReActEngine( + llm_gateway=MagicMock(), + phase_policy=default_policy(), + ) + + def test_violation_appended_on_tool_block(self, engine): + # write_file is blocked in PLANNING. + engine._check_phase_permission("write_file", {}) + assert len(engine._phase_violations) == 1 + v = engine._phase_violations[0] + assert v["error"] == "phase_violation" + assert v["tool"] == "write_file" + assert v["current_phase"] == "planning" + assert v["violation_kind"] == "tool_not_allowed" + + def test_violation_appended_on_bash_block(self, engine): + engine._check_phase_permission("shell", {"command": "rm -rf /tmp"}) + assert len(engine._phase_violations) == 1 + v = engine._phase_violations[0] + assert v["violation_kind"] == "bash_command_blocked" + assert v["tool"] == "shell" + assert v["command_preview"] == "rm -rf /tmp" + + def test_no_violation_when_allowed(self, engine): + # search is allowed in PLANNING. + engine._check_phase_permission("search", {}) + assert engine._phase_violations == [] + + def test_no_violation_without_policy(self): + engine = ReActEngine(llm_gateway=MagicMock()) # no policy + engine._check_phase_permission("anything", {}) + assert engine._phase_violations == [] + + def test_drain_returns_events_and_clears(self, engine): + # Trigger two violations. + engine._check_phase_permission("write_file", {"path": "/a"}) + engine._check_phase_permission("write_file", {"path": "/b"}) + assert len(engine._phase_violations) == 2 + + events = engine._drain_phase_violations(step=3) + assert len(events) == 2 + assert all(e.event_type == "phase_violation" for e in events) + assert all(e.step == 3 for e in events) + # Each event data is a copy (caller can't mutate the accumulator). + assert events[0].data["tool"] == "write_file" + # Accumulator cleared after drain. + assert engine._phase_violations == [] + + def test_drain_empty_returns_empty(self, engine): + assert engine._drain_phase_violations(step=1) == [] + + def test_drain_returns_shallow_copy(self, engine): + """Drained event data must not alias the original violation dict — + mutating one must not mutate the other.""" + engine._check_phase_permission("write_file", {}) + events = engine._drain_phase_violations(step=1) + # Mutate the drained event data. + events[0].data["tool"] = "MUTATED" + # Original accumulator (now empty) is unaffected — but more importantly, + # a fresh violation recorded after drain is unaffected. + engine._check_phase_permission("write_file", {}) + new_violations = engine._phase_violations + assert new_violations[0]["tool"] == "write_file" # not "MUTATED" + + def test_reset_clears_violations(self, engine): + engine._check_phase_permission("write_file", {}) + assert len(engine._phase_violations) == 1 + engine.reset() + assert engine._phase_violations == [] + + +# --------------------------------------------------------------------------- +# Wave 4 U2 — execute_stream yields phase_violation events +# --------------------------------------------------------------------------- + + +class TestExecuteStreamPhaseViolationEvents: + """execute_stream must yield phase_violation ReActEvents when a tool is + blocked by _check_phase_permission. The events are drained after each + tool_result yield.""" + + @pytest.mark.asyncio + async def test_stream_yields_phase_violation_on_tool_block(self): + """When the LLM calls a tool blocked by the phase policy, execute_stream + yields a tool_call event, a tool_result event (with the error dict), + and a phase_violation event.""" + from agentkit.core.react import ReActEvent + + engine = ReActEngine( + llm_gateway=llm_mock_gateway_with_response( + tool_calls=[{"name": "write_file", "arguments": {"path": "/x"}}], + content=None, + ), + phase_policy=default_policy(), + max_steps=1, + ) + # Patch _find_tool so we don't need real tools registered. write_file + # should be blocked by phase_policy before _find_tool is called. + engine._find_tool = lambda name, tools: None + + events: list[ReActEvent] = [] + async for ev in engine.execute_stream( + messages=[{"role": "user", "content": "test"}], + tools=[], + ): + events.append(ev) + + # Expect: thinking → tool_call → tool_result → phase_violation → final_answer + event_types = [e.event_type for e in events] + assert "tool_call" in event_types + assert "tool_result" in event_types + assert "phase_violation" in event_types + + # The phase_violation event must come AFTER tool_result. + tool_result_idx = next(i for i, e in enumerate(events) if e.event_type == "tool_result") + violation_idx = next(i for i, e in enumerate(events) if e.event_type == "phase_violation") + assert violation_idx > tool_result_idx + + # Verify event data. + violation = events[violation_idx] + assert violation.data["error"] == "phase_violation" + assert violation.data["tool"] == "write_file" + assert violation.data["current_phase"] == "planning" + assert violation.data["violation_kind"] == "tool_not_allowed" + + @pytest.mark.asyncio + async def test_stream_yields_phase_violation_on_bash_block(self): + """When the LLM calls shell with a dangerous command in PLANNING, + execute_stream yields a phase_violation event with violation_kind + = bash_command_blocked.""" + from agentkit.core.react import ReActEvent + + engine = ReActEngine( + llm_gateway=llm_mock_gateway_with_response( + tool_calls=[{"name": "shell", "arguments": {"command": "rm -rf /tmp"}}], + content=None, + ), + phase_policy=default_policy(), + max_steps=1, + ) + engine._find_tool = lambda name, tools: None + + events: list[ReActEvent] = [] + async for ev in engine.execute_stream( + messages=[{"role": "user", "content": "test"}], + tools=[], + ): + events.append(ev) + + violation_events = [e for e in events if e.event_type == "phase_violation"] + assert len(violation_events) == 1 + v = violation_events[0].data + assert v["violation_kind"] == "bash_command_blocked" + assert v["tool"] == "shell" + assert "rm -rf /tmp" in v["command_preview"] + + @pytest.mark.asyncio + async def test_stream_no_violation_when_tool_allowed(self): + """When the LLM calls an allowed tool, no phase_violation event is yielded.""" + from agentkit.core.react import ReActEvent + + engine = ReActEngine( + llm_gateway=llm_mock_gateway_with_response( + tool_calls=[{"name": "search", "arguments": {"query": "foo"}}], + content=None, + ), + phase_policy=default_policy(), + max_steps=1, + ) + # search is allowed in PLANNING; we still need _find_tool to return a + # tool object so dispatch proceeds. + search_tool = MagicMock() + search_tool.input_schema = None + search_tool.safe_execute = AsyncMock(return_value={"results": []}) + engine._find_tool = lambda name, tools: search_tool + + events: list[ReActEvent] = [] + async for ev in engine.execute_stream( + messages=[{"role": "user", "content": "test"}], + tools=[search_tool], + ): + events.append(ev) + + assert not any(e.event_type == "phase_violation" for e in events) + + @pytest.mark.asyncio + async def test_stream_no_violation_without_policy(self): + """Without a phase_policy, no phase_violation events are yielded — + characterization of the no-policy path.""" + from agentkit.core.react import ReActEvent + + engine = ReActEngine( + llm_gateway=llm_mock_gateway_with_response( + tool_calls=[{"name": "any_tool", "arguments": {}}], + content=None, + ), + max_steps=1, + ) + any_tool = MagicMock() + any_tool.input_schema = None + any_tool.safe_execute = AsyncMock(return_value={"output": "ok"}) + engine._find_tool = lambda name, tools: any_tool + + events: list[ReActEvent] = [] + async for ev in engine.execute_stream( + messages=[{"role": "user", "content": "test"}], + tools=[any_tool], + ): + events.append(ev) + + assert not any(e.event_type == "phase_violation" for e in events) + + +# --------------------------------------------------------------------------- +# Helpers — minimal LLM gateway mocks for execute_stream tests +# --------------------------------------------------------------------------- + + +def llm_mock_gateway(): + """Return a MagicMock LLM gateway (sufficient for constructor tests).""" + return MagicMock() + + +def llm_mock_gateway_with_response(tool_calls: list[dict], content: str | None): + """Return a MagicMock LLM gateway whose chat_stream yields a single chunk + containing the given tool_calls (or content for a final-answer response). + + The mock is shaped to match what execute_stream expects from + LLMGateway.chat_stream — an async iterable of chunks with attributes + `content`, `tool_calls`, `usage`, `model`. + """ + gateway = MagicMock() + + # Build a fake chunk. execute_stream reads chunk.content, chunk.tool_calls, + # chunk.usage, chunk.model. The first three are typically accessed via + # attribute access; we make a small dataclass-like object. + class _Chunk: + def __init__(self, content, tool_calls, usage=None, model="default"): + self.content = content + self.tool_calls = tool_calls + self.usage = usage + self.model = model + + # If tool_calls provided, emit a chunk with tool_calls (non-streaming path). + # Otherwise, emit a chunk with content (final answer path). + if tool_calls: + # Convert raw dicts to objects with .name/.arguments/.id attributes + # (LLMGateway normally returns tool_call objects). + class _TC: + def __init__(self, d): + self.name = d.get("name", "") + self.arguments = d.get("arguments", {}) + self.id = d.get("id", "tc_test") + + chunks = [_Chunk(content=None, tool_calls=[_TC(tc) for tc in tool_calls])] + # Follow with a final-answer chunk so execute_stream's loop exits + # cleanly after the tool call. + chunks.append(_Chunk(content="done", tool_calls=None)) + else: + chunks = [_Chunk(content=content or "final answer", tool_calls=None)] + + async def _fake_chat_stream(*args, **kwargs): + for c in chunks: + yield c + + gateway.chat_stream = _fake_chat_stream + return gateway -- 2.43.0 From b032e08866334e1c4017bec8951dd166682819c5 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Tue, 30 Jun 2026 10:59:43 +0800 Subject: [PATCH 3/8] feat(U3): extract _build_phase_engine helper + wire REST PLAN_EXEC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract the WS path's inline phase_policy construction into a shared _build_phase_engine helper so the REST send_message endpoint can reuse it. Replace the former 501 stub with actual PLAN_EXEC execution: - REST POST /chat/sessions/{id}/messages with execution_mode=plan_exec now builds a phase-policy-backed ReActEngine, calls execute() (non-streaming), and returns a MessageResponse. - KTD5: PLAN_EXEC bypasses execute_with_fallback_chain — phase policy and fallback chain are mutually exclusive. - When plan_exec.enabled=False, REST falls through to the REACT path (matching WS behavior). - WS path refactored to call the same helper; behavior unchanged. Tests: - Replace TestRestPlanExec501 with TestRestPlanExec (happy path, bad config → 500, disabled → falls through to REACT, REACT mode unchanged). - Add TestBuildPhaseEngineHelper covering all return branches: not-PLAN_EXEC, disabled, empty-config, invalid-config, tool append, default-policy fallback. - All 109 tests pass across the three PLAN_EXEC test files. --- src/agentkit/server/routes/chat.py | 183 ++++++++++++++------ tests/unit/test_chat_plan_exec_ws.py | 244 +++++++++++++++++++++++++-- 2 files changed, 368 insertions(+), 59 deletions(-) diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index 41422a2..54cfb12 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -25,7 +25,7 @@ from fastapi.responses import FileResponse from pydantic import BaseModel from agentkit.chat.skill_routing import ExecutionMode -from agentkit.core.phase import PhasePolicy, default_policy, policy_from_config +from agentkit.core.phase import default_policy, policy_from_config from agentkit.core.protocol import CancellationToken from agentkit.core.react import ReActEngine from agentkit.server._fallback_chain import execute_with_fallback_chain @@ -534,6 +534,69 @@ def _message_to_response(msg) -> MessageResponse: ) +def _build_phase_engine( + *, + server_config: Any, + llm_gateway: Any, + execution_mode: ExecutionMode, + base_tools: list, + session_id: str = "", +) -> tuple[ReActEngine | None, list | None, str | None]: + """Build a PLAN_EXEC engine with PhasePolicy + AdvancePhaseTool. + + Encapsulates the WS path's phase_policy construction so the REST path + can reuse it without duplicating config-lookup + policy_from_config + + AdvancePhaseTool registration. KTD5: PLAN_EXEC bypasses the fallback + chain — callers must NOT route the returned engine through + ``execute_with_fallback_chain``. + + Args: + server_config: ``app.state.server_config`` (or None for tests). + llm_gateway: ``app.state.llm_gateway``. + execution_mode: routing.execution_mode (WS) or PLAN_EXEC (REST). + base_tools: routing.tools (WS) or agent tool list (REST). + session_id: included in log lines for traceability only. + + Returns ``(engine, tools_with_advance_phase, error_message)``: + - execution_mode != PLAN_EXEC → ``(None, None, None)`` (fall back to REACT). + - plan_exec.enabled=False → ``(None, None, None)`` (fall back to REACT). + - phase policy construction failed → ``(None, None, error_message)``. + - PLAN_EXEC engaged → ``(engine, tools_with_advance_phase, None)``. + """ + if execution_mode != ExecutionMode.PLAN_EXEC: + return (None, None, None) + + plan_exec_cfg = getattr(server_config, "plan_exec", None) or {} + if plan_exec_cfg.get("enabled", True) is False: + logger.info( + "PLAN_EXEC disabled by config (plan_exec.enabled=False), " + "falling back to REACT for session %s", + session_id, + ) + return (None, None, None) + + try: + phase_policy = policy_from_config(plan_exec_cfg) + if phase_policy is None: + # Empty config (no `plan_exec:` section) → use KTD5 defaults. + phase_policy = default_policy() + except Exception as e: + logger.error( + "PLAN_EXEC phase policy construction failed for session %s: %s", + session_id, + e, + ) + return (None, None, f"phase policy error: {str(e)[:200]}") + + engine = ReActEngine( + llm_gateway=llm_gateway, + phase_policy=phase_policy, + ) + advance_phase_tool = AdvancePhaseTool(engine=engine) + tools_with_advance_phase = list(base_tools) + [advance_phase_tool] + return (engine, tools_with_advance_phase, None) + + # ── REST endpoints ──────────────────────────────────────────────────── @@ -587,12 +650,58 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques if session.status == SessionStatus.CLOSED: raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed") - # KTD4: PLAN_EXEC is wired only at the WebSocket path. REST raises 501. + # U3: PLAN_EXEC via REST — non-streaming, bypasses the fallback chain + # (KTD5: PLAN_EXEC and execute_with_fallback_chain are mutually exclusive). + # When plan_exec is disabled by config, falls through to the REACT path below. if request.execution_mode == "plan_exec": - raise HTTPException( - status_code=501, - detail="PLAN_EXEC via REST not yet supported; use WebSocket", + # Resolve the Agent early — PLAN_EXEC needs its tool list + system prompt. + pool = req.app.state.agent_pool + agent = pool.get_agent(session.agent_name) + if agent is None: + raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found") + + plan_exec_engine, plan_exec_tools, plan_exec_error = _build_phase_engine( + server_config=getattr(req.app.state, "server_config", None), + llm_gateway=req.app.state.llm_gateway, + execution_mode=ExecutionMode.PLAN_EXEC, + base_tools=agent._tool_registry.list_tools() if agent._tool_registry else [], + session_id=session_id, ) + if plan_exec_error is not None: + raise HTTPException(status_code=500, detail=plan_exec_error) + if plan_exec_engine is not None: + # PLAN_EXEC engaged — append user msg, execute non-streaming, return. + await sm.append_message( + session_id=session_id, + role=MessageRole.USER, + content=request.content, + ) + chat_messages = await sm.get_chat_messages(session_id) + system_prompt = getattr(agent, "_system_prompt", None) or ( + agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None + ) + try: + plan_exec_result = await plan_exec_engine.execute( + messages=chat_messages, + tools=plan_exec_tools, + model=agent.get_model() + if hasattr(agent, "get_model") + else getattr(agent, "_llm_model", "default"), + agent_name=agent.name, + system_prompt=system_prompt, + ) + except Exception as e: + logger.error(f"PLAN_EXEC execution error for session {session_id}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + assistant_msg = await sm.append_message( + session_id=session_id, + role=MessageRole.ASSISTANT, + content=plan_exec_result.output, + agent_name=agent.name, + ) + return _message_to_response(assistant_msg) + # else: plan_exec.enabled=False → fall through to REACT path below. # Append user message await sm.append_message( @@ -1090,42 +1199,27 @@ async def _handle_chat_message( await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}}) return - # U4/G6: PLAN_EXEC — build PhasePolicy from server config (KTD4: WebSocket only). + # U4/G6: PLAN_EXEC — build PhasePolicy from server config. # KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC — phase policy and # fallback chain are mutually exclusive. PLAN_EXEC uses its own engine. - phase_policy: PhasePolicy | None = None - if routing.execution_mode == ExecutionMode.PLAN_EXEC: - server_config = getattr(websocket.app.state, "server_config", None) - plan_exec_cfg = getattr(server_config, "plan_exec", None) or {} - - if plan_exec_cfg.get("enabled", True) is False: - # Explicit opt-out → fall back to REACT. - logger.info( - "PLAN_EXEC disabled by config (plan_exec.enabled=False), " - "falling back to REACT for session %s", - session_id, - ) - else: - try: - phase_policy = policy_from_config(plan_exec_cfg) - if phase_policy is None: - # Empty config (no `plan_exec:` section) → use KTD5 defaults. - phase_policy = default_policy() - except Exception as e: - logger.error( - "PLAN_EXEC phase policy construction failed for session %s: %s", - session_id, - e, - ) - await websocket.send_json( - { - "type": "error", - # Truncate to 200 chars to match nearby error paths and - # avoid leaking config internals (see chat.py:1090, 1320). - "data": {"message": f"phase policy error: {str(e)[:200]}"}, - } - ) - return + # U3: logic extracted into _build_phase_engine so REST can reuse it. + plan_exec_engine, plan_exec_tools, plan_exec_error = _build_phase_engine( + server_config=getattr(websocket.app.state, "server_config", None), + llm_gateway=websocket.app.state.llm_gateway, + execution_mode=routing.execution_mode, + base_tools=routing.tools, + session_id=session_id, + ) + if plan_exec_error is not None: + await websocket.send_json( + { + "type": "error", + # Truncate to 200 chars to match nearby error paths and + # avoid leaking config internals (see chat.py:1090, 1320). + "data": {"message": plan_exec_error}, + } + ) + return # Handle advanced execution modes: REWOO/REFLEXION/TEAM_COLLAB # still fall back to REACT with a warning. PLAN_EXEC is handled above. @@ -1143,14 +1237,9 @@ async def _handle_chat_message( # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization). # PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the # agent's _react_engine — it has no policy). - if phase_policy is not None: - react_engine = ReActEngine( - llm_gateway=websocket.app.state.llm_gateway, - phase_policy=phase_policy, - ) - # Register AdvancePhaseTool bound to this engine (LLM's escape hatch). - advance_phase_tool = AdvancePhaseTool(engine=react_engine) - routing.tools = list(routing.tools) + [advance_phase_tool] + if plan_exec_engine is not None: + react_engine = plan_exec_engine + routing.tools = plan_exec_tools else: react_engine = getattr(agent, "_react_engine", None) if react_engine is None: diff --git a/tests/unit/test_chat_plan_exec_ws.py b/tests/unit/test_chat_plan_exec_ws.py index aec66f4..09c8750 100644 --- a/tests/unit/test_chat_plan_exec_ws.py +++ b/tests/unit/test_chat_plan_exec_ws.py @@ -1,10 +1,12 @@ -"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4). +"""Unit tests for PLAN_EXEC wiring at chat.py REST + WebSocket paths (G6, U3, U4). Per plan U4 Execution note: characterization-first — verify that existing REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning (no regression). Then add PLAN_EXEC wiring tests. -KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501. +U3: PLAN_EXEC is now wired at both REST and WebSocket paths. REST returns +a non-streaming MessageResponse; WS streams phase_violation events alongside +the LLM reinjection. KTD5: PLAN_EXEC bypasses the fallback chain. """ from __future__ import annotations @@ -109,13 +111,60 @@ def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None: # --------------------------------------------------------------------------- -# REST — PLAN_EXEC raises 501 (KTD4) +# REST — PLAN_EXEC wired (U3, replaces former 501 path) # --------------------------------------------------------------------------- -class TestRestPlanExec501: - def test_rest_plan_exec_returns_501(self, client): - """REST send_message with execution_mode=plan_exec → 501.""" +class TestRestPlanExec: + """U3: REST send_message with execution_mode=plan_exec now executes + PLAN_EXEC (non-streaming) instead of raising 501.""" + + def test_rest_plan_exec_returns_assistant_message(self, app_with_chat, monkeypatch): + """REST PLAN_EXEC happy path → 200 with assistant message.""" + from agentkit.server.routes import chat as chat_module + + # Patch ReActEngine with a stub whose execute() returns a ReActResult-like. + class _StubResult: + output = "PLAN_EXEC completed" + status = "success" + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = ( + kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None + ) + + async def execute(self, **kwargs): + return _StubResult() + + monkeypatch.setattr(chat_module, "ReActEngine", _StubEngine) + + # Wire agent_pool with a mock agent that has _tool_registry. + agent = _make_agent_mock() + app_with_chat.state.agent_pool.get_agent.return_value = agent + + client = TestClient(app_with_chat) + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + msg_resp = client.post( + f"/api/v1/chat/sessions/{session_id}/messages", + json={"content": "Build me a hello world", "execution_mode": "plan_exec"}, + ) + assert msg_resp.status_code == 200 + body = msg_resp.json() + assert body["content"] == "PLAN_EXEC completed" + assert body["role"] == "assistant" + + def test_rest_plan_exec_bad_config_returns_500(self, app_with_chat): + """REST PLAN_EXEC with invalid phase config → 500 with error detail.""" + app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"} + + agent = _make_agent_mock() + app_with_chat.state.agent_pool.get_agent.return_value = agent + + client = TestClient(app_with_chat) create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) session_id = create_resp.json()["session_id"] @@ -123,20 +172,71 @@ class TestRestPlanExec501: f"/api/v1/chat/sessions/{session_id}/messages", json={"content": "Hello", "execution_mode": "plan_exec"}, ) - assert msg_resp.status_code == 501 - assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"] + assert msg_resp.status_code == 500 + assert "phase policy error" in msg_resp.json()["detail"] - def test_rest_react_mode_still_works(self, client): - """REST send_message without execution_mode doesn't 501.""" + def test_rest_plan_exec_disabled_falls_through_to_react(self, app_with_chat, monkeypatch): + """REST PLAN_EXEC with enabled=False → falls through to REACT path.""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {"enabled": False} + + # Track which engine constructor fires. + constructed: list = [] + + class _StubResult: + output = "REACT fallback ok" + status = "success" + + class _StubEngine: + def __init__(self, **kwargs): + constructed.append(kwargs) + self._phase_policy = kwargs.get("phase_policy") + + async def execute(self, **kwargs): + return _StubResult() + + monkeypatch.setattr(chat_module, "ReActEngine", _StubEngine) + # execute_with_fallback_chain also constructs ReflexionEngine internally; + # patch it to return a ChatExecutionResult-like directly. + from agentkit.server._fallback_chain import ChatExecutionResult + + async def _stub_chain(**kwargs): + return ChatExecutionResult(output="REACT fallback ok", status="success") + + monkeypatch.setattr(chat_module, "execute_with_fallback_chain", _stub_chain) + + agent = _make_agent_mock() + app_with_chat.state.agent_pool.get_agent.return_value = agent + + client = TestClient(app_with_chat) create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) session_id = create_resp.json()["session_id"] - # No execution_mode field → should NOT trigger 501. + msg_resp = client.post( + f"/api/v1/chat/sessions/{session_id}/messages", + json={"content": "Hello", "execution_mode": "plan_exec"}, + ) + assert msg_resp.status_code == 200 + assert msg_resp.json()["content"] == "REACT fallback ok" + # No engine should have been constructed with phase_policy — PLAN_EXEC + # was disabled and the REACT path doesn't set phase_policy. + assert all(kw.get("phase_policy") is None for kw in constructed) + + def test_rest_react_mode_still_works(self, client): + """REST send_message without execution_mode doesn't 500.""" + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + # No execution_mode field → should NOT trigger PLAN_EXEC path. + # Will likely 500 due to mock llm_gateway, but must NOT be a PLAN_EXEC error. msg_resp = client.post( f"/api/v1/chat/sessions/{session_id}/messages", json={"content": "Hello"}, ) - assert msg_resp.status_code != 501 + # 500 is acceptable (mock gateway), but it must NOT be the PLAN_EXEC error. + if msg_resp.status_code == 500: + assert "phase policy error" not in msg_resp.json().get("detail", "") # --------------------------------------------------------------------------- @@ -671,3 +771,123 @@ async def test_no_phase_violation_event_when_not_plan_exec(app_with_chat): sent_messages = [call.args[0] for call in ws.send_json.call_args_list] violation_messages = [m for m in sent_messages if m.get("type") == "phase_violation"] assert len(violation_messages) == 0 + + +# --------------------------------------------------------------------------- +# _build_phase_engine helper (U3) +# --------------------------------------------------------------------------- + + +class TestBuildPhaseEngineHelper: + """Direct unit tests for the _build_phase_engine helper extracted in U3.""" + + def test_returns_none_when_not_plan_exec(self): + from agentkit.server.routes.chat import _build_phase_engine + + engine, tools, err = _build_phase_engine( + server_config=None, + llm_gateway=MagicMock(), + execution_mode=ExecutionMode.REACT, + base_tools=[], + ) + assert engine is None + assert tools is None + assert err is None + + def test_returns_none_when_plan_exec_disabled_by_config(self): + from agentkit.server.routes.chat import _build_phase_engine + + server_config = MagicMock() + server_config.plan_exec = {"enabled": False} + + engine, tools, err = _build_phase_engine( + server_config=server_config, + llm_gateway=MagicMock(), + execution_mode=ExecutionMode.PLAN_EXEC, + base_tools=[], + ) + assert engine is None + assert tools is None + assert err is None + + def test_returns_none_when_plan_exec_section_absent(self): + """Empty plan_exec config → default_policy() used, engine built.""" + from agentkit.server.routes.chat import _build_phase_engine + + server_config = MagicMock() + server_config.plan_exec = {} + + engine, tools, err = _build_phase_engine( + server_config=server_config, + llm_gateway=MagicMock(), + execution_mode=ExecutionMode.PLAN_EXEC, + base_tools=[], + ) + assert engine is not None + assert tools is not None + assert err is None + # Default policy: PLANNING allows search, blocks write_file + assert "search" in engine._phase_policy.whitelist[PhaseState.PLANNING] + assert "write_file" not in engine._phase_policy.whitelist[PhaseState.PLANNING] + + def test_returns_error_when_phase_policy_invalid(self): + from agentkit.server.routes.chat import _build_phase_engine + + server_config = MagicMock() + server_config.plan_exec = {"start_phase": "invalid_phase_name"} + + engine, tools, err = _build_phase_engine( + server_config=server_config, + llm_gateway=MagicMock(), + execution_mode=ExecutionMode.PLAN_EXEC, + base_tools=[], + ) + assert engine is None + assert tools is None + assert err is not None + assert "phase policy error" in err + + def test_appends_advance_phase_tool_to_tools(self): + from agentkit.server.routes.chat import _build_phase_engine + + server_config = MagicMock() + server_config.plan_exec = {} + + base_tool = MagicMock() + engine, tools, err = _build_phase_engine( + server_config=server_config, + llm_gateway=MagicMock(), + execution_mode=ExecutionMode.PLAN_EXEC, + base_tools=[base_tool], + ) + assert err is None + assert engine is not None + assert tools is not None + # base_tool preserved + AdvancePhaseTool appended + assert len(tools) == 2 + assert tools[0] is base_tool + assert isinstance(tools[1], AdvancePhaseTool) + + def test_engine_uses_default_policy_when_config_returns_none(self, monkeypatch): + """policy_from_config returning None → default_policy() used.""" + from agentkit.server.routes import chat as chat_module + + def _stub_policy_from_config(cfg): + return None + + monkeypatch.setattr(chat_module, "policy_from_config", _stub_policy_from_config) + + server_config = MagicMock() + server_config.plan_exec = {"enabled": True} + + engine, tools, err = chat_module._build_phase_engine( + server_config=server_config, + llm_gateway=MagicMock(), + execution_mode=ExecutionMode.PLAN_EXEC, + base_tools=[], + ) + assert err is None + assert engine is not None + assert engine._phase_policy is not None + # Default policy's start phase is PLANNING + assert engine._current_phase == PhaseState.PLANNING -- 2.43.0 From 2abe7c9e49d2734519f9bc868dabc89838c584ca Mon Sep 17 00:00:00 2001 From: chiguyong Date: Tue, 30 Jun 2026 11:11:03 +0800 Subject: [PATCH 4/8] feat(U4): frontend phase_violation handling + PhaseIndicator component Extend the frontend to surface PLAN_EXEC phase lifecycle events to the user: - WsServerMessage union (types.ts) gains two branches: `phase_changed` and `phase_violation` (matching backend U2 emission). - chat.ts Pinia store gains a phase state slice: `currentPhase`, `phaseViolations` (capped at 5), `isPlanExec` computed, and `resetPlanExecState()`. - handleWsMessage adds `case "phase_changed"` (sets currentPhase + appends a milestone step) and `case "phase_violation"` (sets currentPhase from violation data, appends to violations, fires an ant-design-vue message.warning toast, appends an error step). - `result` handler calls `resetPlanExecState()` to clear the indicator when the conversation completes. - New `PhaseIndicator.vue` component: compact badge + 4 dots (PLANNING/BUILDING/VERIFICATION/DELIVERY) with the current phase highlighted + violation counter. Renders nothing when `!isPlanExec` (graceful degradation). - Mounted in `ChatView.vue` alongside ExpertTeamView and BoardStatusView. Tests: - New `tests/unit/stores/chat-phase.test.ts` verifies the phase state slice is exposed with correct initial values and `isPlanExec` derives from `currentPhase`. - `npm run typecheck` clean. - Pre-existing `tauri-auth.test.ts` failure is unrelated (fails in isolation on main). --- src/agentkit/server/frontend/src/api/types.ts | 3 + .../src/components/chat/PhaseIndicator.vue | 142 ++++++++++++++++++ .../server/frontend/src/stores/chat.ts | 81 ++++++++++ .../server/frontend/src/views/ChatView.vue | 2 + .../tests/unit/stores/chat-phase.test.ts | 77 ++++++++++ 5 files changed, 305 insertions(+) create mode 100644 src/agentkit/server/frontend/src/components/chat/PhaseIndicator.vue create mode 100644 src/agentkit/server/frontend/tests/unit/stores/chat-phase.test.ts diff --git a/src/agentkit/server/frontend/src/api/types.ts b/src/agentkit/server/frontend/src/api/types.ts index c5a3599..9ff3421 100644 --- a/src/agentkit/server/frontend/src/api/types.ts +++ b/src/agentkit/server/frontend/src/api/types.ts @@ -146,6 +146,9 @@ export type WsServerMessage = | { type: 'phase_started'; data: { phase_id: string; phase_name: string; assigned_expert: string; depends_on: string[] } } | { type: 'phase_completed'; data: { phase_id: string; phase_name: string; result_summary: string } } | { type: 'phase_failed'; data: { phase_id: string; phase_name: string; error: string } } + // PLAN_EXEC (U4) — phase lifecycle events emitted by ReActEngine. + | { type: 'phase_changed'; data: { phase: string; previous: string } } + | { type: 'phase_violation'; data: { current_phase: string; tool: string; message: string; violation_kind: string; command_preview?: string } } | { type: 'team_synthesis'; data: { content: string } } | { type: 'team_dissolved'; data: { team_id: string } } // Board Meeting 模式事件 diff --git a/src/agentkit/server/frontend/src/components/chat/PhaseIndicator.vue b/src/agentkit/server/frontend/src/components/chat/PhaseIndicator.vue new file mode 100644 index 0000000..9a08107 --- /dev/null +++ b/src/agentkit/server/frontend/src/components/chat/PhaseIndicator.vue @@ -0,0 +1,142 @@ + + + + + diff --git a/src/agentkit/server/frontend/src/stores/chat.ts b/src/agentkit/server/frontend/src/stores/chat.ts index 728b441..b49e5be 100644 --- a/src/agentkit/server/frontend/src/stores/chat.ts +++ b/src/agentkit/server/frontend/src/stores/chat.ts @@ -174,6 +174,27 @@ export const useChatStore = defineStore("chat", () => { () => boardState.value !== null && boardState.value.status === "discussing", ); + // PLAN_EXEC phase state (U4) — tracks current phase + violations for the + // PhaseIndicator component. Set when the first phase_* event arrives. + // Reset on conversation switch or final_answer. + const currentPhase = ref(null); + const phaseViolations = ref< + Array<{ + phase: string; + tool: string; + message: string; + violation_kind: string; + command_preview?: string; + ts: number; + }> + >([]); + const isPlanExec = computed(() => currentPhase.value !== null); + + function resetPlanExecState(): void { + currentPhase.value = null; + phaseViolations.value = []; + } + // Debate state (transient, only active during a debate collaboration) const debateState = ref<{ topic: string; @@ -1096,6 +1117,8 @@ export const useChatStore = defineStore("chat", () => { // across multiple interactions. The UI has already transitioned // to showing the final assistant message. clearConvSteps(conversationId); + // Reset PLAN_EXEC phase state — the conversation is done. + resetPlanExecState(); break; } @@ -1390,6 +1413,60 @@ export const useChatStore = defineStore("chat", () => { break; } + // ── PLAN_EXEC (U4) — phase lifecycle events from ReActEngine ──────── + + case "phase_changed": { + currentPhase.value = data.data.phase; + const cid = resolveIncomingConvId(); + if (cid) { + appendStep( + { + type: "milestone", + label: "阶段切换", + detail: `${data.data.previous || "—"} → ${data.data.phase}`, + status: "success", + }, + cid, + ); + } + break; + } + + case "phase_violation": { + // Track current phase from violation data (backend doesn't emit + // phase_changed yet — U4 frontend is forward-compatible). + currentPhase.value = data.data.current_phase; + const violation = { + phase: data.data.current_phase, + tool: data.data.tool, + message: data.data.message, + violation_kind: data.data.violation_kind, + command_preview: data.data.command_preview, + ts: Date.now(), + }; + phaseViolations.value = [...phaseViolations.value, violation].slice(-5); + // Toast notification via ant-design-vue message. + import("ant-design-vue").then(({ message }) => { + message.warning( + `[${data.data.current_phase}] 工具 ${data.data.tool} 被拦截: ${data.data.message}`, + 5, + ); + }); + const cid = resolveIncomingConvId(); + if (cid) { + appendStep( + { + type: "team_event", + label: "阶段违规", + detail: `${data.data.current_phase} · ${data.data.tool}`, + status: "error", + }, + cid, + ); + } + break; + } + // ── Board Meeting 模式事件 ──────────────────────────────────────── case "board_started": { @@ -1920,6 +1997,10 @@ export const useChatStore = defineStore("chat", () => { boardState, debateState, collaborationState, + // PLAN_EXEC (U4) + currentPhase, + phaseViolations, + isPlanExec, // Legacy aliases (derive from current conversation for backward compat). // New code should use `isCurrentLoading` / `currentStreamingSteps` instead. isLoading: isCurrentLoading, diff --git a/src/agentkit/server/frontend/src/views/ChatView.vue b/src/agentkit/server/frontend/src/views/ChatView.vue index 81b9cc5..b22ad64 100644 --- a/src/agentkit/server/frontend/src/views/ChatView.vue +++ b/src/agentkit/server/frontend/src/views/ChatView.vue @@ -20,6 +20,7 @@