feat(agent): Wave 4 PLAN_EXEC Hardening (U1-U5) #7
17
CONCEPTS.md
|
|
@@ -59,3 +59,20 @@ The auth store's cold-start state machine with values `valid` / `invalid` / `err
|
|||
|
||||
### Service Broadcast Callback
|
||||
The convention for pushing backend state changes to the user's open frontend tabs in real time without coupling domain services to the WebSocket transport. A service accepts an optional async callback at construction; after a successful mutation it best-effort invokes the callback with a typed message envelope. Delivery failure is logged but never rolls back the mutation — the persisted state is the source of truth, the broadcast is a latency optimization. The callback is wired at the composition root (app lifespan) to the portal's per-user fan-out primitive, so the service stays layer-pure. The same callback shape is shared by CRUD services, reminder dispatchers, and sync providers, giving all real-time updates a single exit point.
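A minimal sketch of this convention, assuming a hypothetical `NoteService`, `Envelope` type, and `fan_out` callback (none of these names come from the codebase). It only illustrates the ordering: commit the mutation first, then notify on a best-effort basis.

```python
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional

logger = logging.getLogger(__name__)


# Hypothetical envelope and callback shapes; the real types are not shown here.
@dataclass(frozen=True)
class Envelope:
    type: str
    user_id: str
    data: dict[str, Any]


BroadcastCallback = Callable[[Envelope], Awaitable[None]]


class NoteService:
    """Illustrative domain service; it knows nothing about WebSockets."""

    def __init__(self, broadcast: Optional[BroadcastCallback] = None) -> None:
        self._broadcast = broadcast
        self._notes: dict[str, str] = {}  # stands in for persisted state

    async def rename(self, user_id: str, note_id: str, title: str) -> None:
        self._notes[note_id] = title  # the mutation is committed first
        await self._notify(Envelope("note_renamed", user_id, {"id": note_id, "title": title}))

    async def _notify(self, envelope: Envelope) -> None:
        if self._broadcast is None:
            return
        try:
            await self._broadcast(envelope)
        except Exception:
            # Best effort: log and move on; never roll back the mutation.
            logger.exception("broadcast failed for %s", envelope.type)


async def demo() -> tuple[dict[str, str], list[Envelope]]:
    sent: list[Envelope] = []

    async def fan_out(envelope: Envelope) -> None:
        sent.append(envelope)
        raise ConnectionError("simulated transport failure")

    service = NoteService(broadcast=fan_out)  # wiring belongs at the composition root
    await service.rename("u1", "n1", "Groceries")
    return service._notes, sent
```

In `demo()`, the simulated transport failure is logged, and the renamed note is still present in the service's state afterwards.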
|
||||
|
||||
## PLAN_EXEC
|
||||
|
||||
### Phase State Machine
|
||||
The four-phase lifecycle (`PLANNING → BUILDING → VERIFICATION → DELIVERY`) that constrains which tools an agent may call at each step of a PLAN_EXEC run. Each phase has a tool whitelist (e.g., PLANNING allows `search`/`read_file`; BUILDING allows `write_file`); `WILDCARD = "*"` means all tools are allowed (used by DELIVERY by default). Transitions are LLM-driven via `AdvancePhaseTool`: the LLM decides when to advance, and there is no implicit timer. `auto_advance_after_steps` opts into a step-count trigger as an alternative to explicit `advance_phase` calls. State lives in `ReActEngine.current_phase`; `PhaseState` is an enum with a `from_string()` parser.
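The lifecycle and whitelist check can be sketched roughly as follows. The enum values, the `successor()` helper, and the whitelist contents are assumptions for illustration; only the phase names, `WILDCARD`, and `from_string()` are named in the text above.

```python
from enum import Enum

WILDCARD = "*"


class PhaseState(Enum):
    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"

    @classmethod
    def from_string(cls, raw: str) -> "PhaseState":
        # Raises ValueError for an unknown phase name.
        return cls(raw.strip().lower())

    def successor(self) -> "PhaseState":
        # Hypothetical helper: DELIVERY is terminal, so it maps to itself.
        order = list(PhaseState)
        idx = order.index(self)
        return order[min(idx + 1, len(order) - 1)]


# Illustrative whitelists only; the real defaults are not listed in this document.
WHITELIST: dict[PhaseState, frozenset[str]] = {
    PhaseState.PLANNING: frozenset({"search", "read_file"}),
    PhaseState.BUILDING: frozenset({"search", "read_file", "write_file"}),
    PhaseState.VERIFICATION: frozenset({"read_file", "shell"}),
    PhaseState.DELIVERY: frozenset({WILDCARD}),
}


def is_tool_allowed(phase: PhaseState, tool: str) -> bool:
    allowed = WHITELIST[phase]
    return WILDCARD in allowed or tool in allowed
```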
|
||||
|
||||
### PhasePolicy
|
||||
The dataclass holding the per-phase whitelist (`dict[PhaseState, frozenset[str]]`) and bash command filter (`dict[PhaseState, Callable[[str], bool] | re.Pattern | None]`). Constructed via `default_policy()` (hardcoded KTD5 defaults) or `policy_from_config(plan_exec_cfg)` (YAML-driven). `policy_from_config` returns `None` for empty/disabled config — signaling "opt out, fall back to REACT" (not "use defaults"). `is_bash_command_allowed` accepts either a `Callable` (returns True if dangerous; `ShellTool._is_dangerous` is the default) or a legacy `re.Pattern` (matches dangerous substrings). The `Callable` path closes the regex ceiling — regex missed `:>file`, `dd of=file`, and unknown binaries.
|
||||
|
||||
### Phase Violation
|
||||
The structured event emitted when `_check_phase_permission` blocks a tool call in PLAN_EXEC mode. Two emission paths: (1) the violation is re-injected into the LLM as a tool result so the loop continues (the LLM can switch tools or call `advance_phase`); (2) a `phase_violation` WS event is forwarded to the client for UI feedback (PhaseIndicator component). Violations carry `current_phase`, `tool`, `violation_kind` (`tool_not_allowed` / `bash_command_blocked`), `message`, and `command_preview`. The engine accumulates violations in `_phase_violations` and drains them via `_drain_phase_violations` after each tool dispatch — drains are the caller's responsibility at three sites in `execute_stream`.
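The two emission paths might look like the sketch below. `ViolationBuffer` is a hypothetical stand-in for the engine's `_phase_violations` list and `_drain_phase_violations` method; the field names follow the list above, but the exact shape of the tool result is an assumption.

```python
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass(frozen=True)
class PhaseViolation:
    current_phase: str
    tool: str
    violation_kind: str  # "tool_not_allowed" or "bash_command_blocked"
    message: str
    command_preview: Optional[str] = None


class ViolationBuffer:
    """Hypothetical stand-in for the engine's violation bookkeeping."""

    def __init__(self) -> None:
        self._pending: list[PhaseViolation] = []

    def record(self, violation: PhaseViolation) -> dict:
        self._pending.append(violation)
        # Path 1: the tool result re-injected into the LLM conversation.
        return {"error": violation.message, "violation_kind": violation.violation_kind}

    def drain(self) -> list[dict]:
        # Path 2: events handed to the caller for forwarding over the WebSocket.
        events = [{"type": "phase_violation", "data": asdict(v)} for v in self._pending]
        self._pending.clear()
        return events
```

Because draining is the caller's responsibility, a call site that forgets to drain would leave violations buffered rather than lose them.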
|
||||
|
||||
### AdvancePhaseTool
|
||||
The state-transition tool that moves the engine between phases. Calls `engine.advance_phase()` which transitions `current_phase` to the next enum value (PLANNING→BUILDING→VERIFICATION→DELIVERY; DELIVERY is terminal). Returns `{"previous_phase", "current_phase", "message"}`. Bypasses `_check_phase_permission` (always permitted) — phase advancement is not subject to the whitelist it enforces. Known limitation: the default `_loop_threshold=2` fires on the 2nd identical `advance_phase({})` call because all transitions produce the same argument hash — needs exemption from loop detection.
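The known limitation can be reproduced with a toy loop detector. The hashing scheme, `first_loop_hit`, and the `LOOP_EXEMPT_TOOLS` set are assumptions made for this sketch; the engine's actual loop-detection code is not shown in this document.

```python
import hashlib
import json
from collections import Counter
from typing import Optional

LOOP_THRESHOLD = 2  # mirrors the `_loop_threshold=2` default named above
LOOP_EXEMPT_TOOLS = frozenset({"advance_phase"})  # hypothetical exemption list


def call_signature(tool: str, args: dict) -> str:
    # Identical tool name and arguments always produce the same signature.
    payload = json.dumps({"tool": tool, "args": args}, sort_keys=True)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def first_loop_hit(
    calls: list[tuple[str, dict]],
    exempt: frozenset[str] = frozenset(),
) -> Optional[int]:
    """Return the index of the call that trips loop detection, or None."""
    seen: Counter = Counter()
    for index, (tool, args) in enumerate(calls):
        if tool in exempt:
            continue  # exempt tools are never counted
        sig = call_signature(tool, args)
        seen[sig] += 1
        if seen[sig] >= LOOP_THRESHOLD:
            return index
    return None
```

With no exemption, a legitimate PLANNING→BUILDING→VERIFICATION sequence trips the detector on its second `advance_phase({})` call. Exempting the tool by name is one possible fix.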
|
||||
|
||||
### _build_phase_engine
|
||||
The chat.py helper that consolidates PLAN_EXEC engine construction for both WS and REST paths. Returns `(engine, tools_with_advance_phase, error_message)`. Returns `(None, None, None)` when `execution_mode != PLAN_EXEC` or `plan_exec.enabled=False` (both signal "fall back to REACT"). Returns `(None, None, error_message)` on policy construction failure. The helper ensures WS and REST paths share identical PhasePolicy + AdvancePhaseTool wiring — previously the REST path returned 501 because it had no construction logic.
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,428 @@
|
|||
---
|
||||
title: "feat: Agent Wave 4 PLAN_EXEC hardening (REST wiring + frontend events + bash filter upgrade + e2e tests)"
|
||||
date: 2026-06-30
|
||||
type: feat
|
||||
status: draft
|
||||
origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md (Wave 3 deferred items + Wave 3 code review residual risks)
|
||||
execution: code
|
||||
---
|
||||
|
||||
# Wave 4 PLAN_EXEC Hardening — REST Symmetry + Frontend Visibility + Filter Upgrade + E2E Coverage
|
||||
|
||||
## Summary
|
||||
|
||||
Wave 3 closed G5 (function-level sharding) and G6 (SOLO four-stage state machine) on the **WebSocket path only**. Three concrete gaps remain before PLAN_EXEC is production-ready:
|
||||
|
||||
1. **REST asymmetry** — `POST /api/v1/chat/{session_id}/send_message` with `execution_mode="plan_exec"` returns HTTP 501 (chat.py:590-595). WebSocket has a real handler; REST does not.
|
||||
2. **No frontend visibility** — `phase_changed` events are emitted on the WS socket (chat.py:1295-1309) but `frontend/src/api/types.ts:135` `WsServerMessage` has no `phase_changed`/`phase_violation` branch. `phase_violation` is not emitted to the client at all (only injected back to the LLM as a tool error at react.py:2196-2203). The Vue UI cannot show what phase the agent is in or surface policy violations.
|
||||
3. **Bash filter ceiling** — `core/phase.py:56-60` ponytail comment names the ceiling explicitly: regex misses `:>file`, `dd of=file`. Upgrade path = reuse `ShellTool._is_dangerous()` (shell.py:466) at enforcement time.
|
||||
|
||||
Wave 4 closes all three and ships an E2E integration test that exercises the full PLAN_EXEC path (LLM → phase transition → tool dispatch → WS event). It does **not** migrate to tree-sitter (the Wave 3 KTD1 upgrade path remains for Wave 5+) or add phase persistence across session resume (U7 checkpoint scope, a separate concern).
|
||||
|
||||
## Problem Frame
|
||||
|
||||
Wave 3 ships PLAN_EXEC behind a feature flag (`agentkit.yaml` `plan_exec.enabled`, commented out by default, so it is opt-in). To turn it on in production, three conditions must hold:
|
||||
|
||||
- **Symmetric entry points**: callers using REST (CLI `agentkit task submit`, external integrations) and callers using WebSocket (Vue chat UI) must both be able to invoke PLAN_EXEC. Today only WS works.
|
||||
- **Observable behavior**: when the agent transitions phases or rejects a tool call, the user must see it. Today the WS event is emitted but the frontend silently drops it; violations are invisible to the user.
|
||||
- **Hardened safety boundary**: the bash filter is the only thing preventing a Planning-phase LLM from running `rm -rf`. The regex is incomplete by the author's own admission (`:>file` slips through). Production use requires the same safety guarantee `ShellTool` already provides.
|
||||
|
||||
Wave 3's "Out of Scope (Deferred to Follow-Up Work)" section explicitly lists "REST `send_message` PLAN_EXEC wiring" and "Tool-filter UI in the frontend" — Wave 4 executes those deferrals. The bash filter upgrade was surfaced by Wave 3 code review (ponytail ceiling) rather than the brainstorm.
|
||||
|
||||
## Requirements
|
||||
|
||||
Carried forward from Wave 3 plan's deferred items + Wave 3 code review residuals:
|
||||
|
||||
- **R28**: REST `POST /api/v1/chat/{session_id}/send_message` with `execution_mode="plan_exec"` invokes the same PLAN_EXEC handler logic as the WebSocket path, returning a `ChatResult` with phase events recorded (mirrors WS event sequence).
|
||||
- **R29**: REST PLAN_EXEC bypasses the Wave 2 fallback chain (mutually exclusive with phase policy — chat.py:1093-1095 documents the constraint).
|
||||
- **R30**: WebSocket emits a `phase_violation` event to the client when `ReActEngine._check_phase_permission` blocks a tool call (currently only returned to the LLM).
|
||||
- **R31**: Frontend `WsServerMessage` union (types.ts:135) includes `phase_changed` and `phase_violation` cases; `handleWsMessage` (chat.ts:800) dispatches both to a phase state slice.
|
||||
- **R32**: A compact phase indicator renders the current phase + violation toasts in the Vue chat view.
|
||||
- **R33**: `PhasePolicy.bash_command_filter` accepts a `Callable[[str], bool]` callback in addition to `re.Pattern`; default policy wires `ShellTool._is_dangerous` so `:>file` and `dd of=file` are blocked.
|
||||
- **R34**: An E2E integration test exercises the full PLAN_EXEC path through a scripted LLM mock: planning → `advance_phase` → building → `write_file` → verification → delivery, asserting WS events and tool dispatches.
|
||||
|
||||
Cross-cutting:
|
||||
|
||||
- **R26** (inherited): all configuration via `agentkit.yaml` `plan_exec` section, parsed by `ServerConfig.from_dict`.
|
||||
- **R27** (inherited): each unit ships a minimal self-check (ponytail rule).
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
### KTD1: Extract `_build_phase_engine` helper shared by WS and REST
|
||||
|
||||
**Decision**: Refactor `chat.py:1093-1153` (the WS PLAN_EXEC engine construction block) into a private `_build_phase_engine(server_config, agent, tools, ...) -> tuple[ReActEngine | None, list[Tool]]` helper; the engine is `None` when PLAN_EXEC is disabled. Both `_execute_plan_exec_ws` and the new `_execute_plan_exec_rest` call it.
|
||||
|
||||
**Rationale**:
|
||||
- The WS block currently inlines policy construction + engine instantiation + AdvancePhaseTool registration. REST needs the same assembly.
|
||||
- Single source of truth for "how to build a phase-enforced engine" prevents drift between WS and REST paths.
|
||||
- Helper is private (underscore prefix) — not a public API; test access goes through the routes.
|
||||
|
||||
### KTD2: REST PLAN_EXEC returns non-streaming `ChatResult`; SSE streaming is deferred
|
||||
|
||||
**Decision**: `_execute_plan_exec_rest()` returns a regular `SendMessageResponse` (matching the existing REST send_message shape). The `phase_changed`/`phase_violation` events are captured into a `phase_events: list` field on the response payload.
|
||||
|
||||
**Rationale**:
|
||||
- Existing REST `send_message` is non-streaming (chat.py:580). Streaming REST (SSE) is a separate concern owned by the `/api/v1/llm/chat` gateway route (llm_gateway.py), not chat.py.
|
||||
- First version ships parity with existing REST shape; SSE streaming for PLAN_EXEC is a follow-up if users request it.
|
||||
- The phase events list lets REST clients render phase progression after-the-fact (CLI `agentkit task submit` shows them in terminal output).
|
||||
|
||||
### KTD3: `phase_violation` event emitted to WS alongside LLM injection
|
||||
|
||||
**Decision**: In `react.py:_execute_loop`, when `_check_phase_permission` blocks a tool call, the existing structured error is injected to the LLM conversation (unchanged), AND a `phase_violation` event is emitted through the engine's event stream. `chat.py` WS handler forwards it to the client.
|
||||
|
||||
**Rationale**:
|
||||
- Wave 3 returns the violation only to the LLM (gives the model a chance to self-correct by calling `advance_phase`). That stays.
|
||||
- Adding the WS event gives the user visibility into "the LLM tried to call `write_file` in Planning, was rejected, and will retry" — without this, the UI shows the LLM thinking silently which looks like a hang.
|
||||
- Event payload: `{"type": "phase_violation", "data": {"tool": "write_file", "phase": "planning", "hint": "call advance_phase"}}`.
|
||||
|
||||
### KTD4: `bash_command_filter` accepts `Callable[[str], bool] | re.Pattern | None`
|
||||
|
||||
**Decision**: Change `PhasePolicy.bash_command_filter` field type from `dict[PhaseState, re.Pattern | None]` to `dict[PhaseState, Callable[[str], bool] | re.Pattern | None]`. `is_bash_command_allowed` detects callable vs pattern at call time. `default_policy()` injects `ShellTool._is_dangerous` as the callable for PLANNING/VERIFICATION.
|
||||
|
||||
**Rationale**:
|
||||
- `ShellTool._is_dangerous` (shell.py:466) is already battle-tested against `_DANGEROUS_BINARIES`, `_DANGEROUS_BINARY_FLAGS`, `_DANGEROUS_ARG_PATTERNS`, shell-chain operators, and pipe operators. Reusing it eliminates the regex ceiling the ponytail comment named.
|
||||
- The `re.Pattern` form stays for backward compat (config-supplied regex patterns still work).
|
||||
- `PhaseState` enum and `PhasePolicy` API stay stable; only the field type widens.
|
||||
|
||||
**Alternative considered**: Move the filter to `ShellTool` itself (gateway-level). Rejected because phase enforcement is per-step in ReActEngine, not per-shell-call — different lifecycle.
|
||||
|
||||
### KTD5: Phase indicator UI is compact, optional, and degrades gracefully
|
||||
|
||||
**Decision**: Add a `PhaseIndicator.vue` component (badge + progress dots for 4 phases + transient toast for `phase_violation`). Mount it in the chat view header only when the current session has `execution_mode="plan_exec"`; otherwise render nothing.
|
||||
|
||||
**Rationale**:
|
||||
- Most chat sessions are REACT/SKILL_REACT — phase indicator is noise for them. Conditionally render only for PLAN_EXEC sessions.
|
||||
- Compact form (badge + dots) avoids competing with the existing `PlanVisualization.vue` (team mode, different concept — don't unify them).
|
||||
- Toast pattern matches existing `useMessage` from `ant-design-vue` used elsewhere in the frontend.
|
||||
|
||||
## Scope Boundaries
|
||||
|
||||
### In Scope
|
||||
|
||||
- Refactor WS PLAN_EXEC engine construction into `_build_phase_engine` shared helper.
|
||||
- New `_execute_plan_exec_rest` for REST send_message; remove the 501 at chat.py:590-595.
|
||||
- Emit `phase_violation` event from `ReActEngine._execute_loop` through the WS handler.
|
||||
- Frontend `WsServerMessage` union extension + `handleWsMessage` cases + new `PhaseIndicator.vue`.
|
||||
- `PhasePolicy.bash_command_filter` type widening + `default_policy()` wiring to `ShellTool._is_dangerous`.
|
||||
- E2E integration test with scripted LLM mock covering full PLAN_EXEC lifecycle.
|
||||
|
||||
### Out of Scope (Deferred to Follow-Up Work)
|
||||
|
||||
- SSE streaming for REST PLAN_EXEC (KTD2 — non-streaming first; SSE follow-up if requested).
|
||||
- `tree-sitter` integration for symbol extraction (Wave 3 KTD1 upgrade path; Wave 5+ candidate).
|
||||
- Phase persistence across session resume (depends on U7 checkpoint deeper changes).
|
||||
- Phase-aware prompt engineering (per-phase system prompt templates — prompt-engineering concern, not code).
|
||||
- Phase rollback on `Building → Planning` regression (UX/prompt concern; Wave 2 G9 file-level rollback already handles file state).
|
||||
- `config_sync.py` exposure of `plan_exec` to frontend (frontend reads phase events from WS, not config — config exposure only needed if the UI wants to render phase whitelists, which is out of scope).
|
||||
- Recovery/Emergency layer integration with PLAN_EXEC (mutually exclusive by design — chat.py:1093-1095 documents this; integrating would require ReflexionEngine to understand phase state, separate Wave).
|
||||
|
||||
### Outside This Product's Identity
|
||||
|
||||
- Replacing the existing ReAct loop with LangGraph (inherited from brainstorm).
|
||||
- Disc-based file system à la DeerFlow (inherited).
|
||||
- Docker sandbox (inherited; only command-level safety via `bash_command_filter`).
|
||||
|
||||
## Implementation Units
|
||||
|
||||
### U1. Bash filter upgrade — reuse `ShellTool._is_dangerous()` (G6 hardening)
|
||||
|
||||
**Goal**: Widen `PhasePolicy.bash_command_filter` to accept `Callable[[str], bool]` callbacks and wire `ShellTool._is_dangerous` as the default filter for PLANNING/VERIFICATION phases. Eliminate the ponytail ceiling at `core/phase.py:56-60`.
|
||||
|
||||
**Requirements**: R33, R27.
|
||||
|
||||
**Dependencies**: none.
|
||||
|
||||
**Files**:
|
||||
- `src/agentkit/core/phase.py` (modify — widen field type; inject `ShellTool._is_dangerous` in `default_policy()`; update `is_bash_command_allowed` to handle callable).
|
||||
- `tests/unit/test_phase_policy.py` (modify — add cases for `:>file`, `dd of=file`, callable vs pattern).
|
||||
|
||||
**Approach**:
|
||||
- Field type: `bash_command_filter: dict[PhaseState, Callable[[str], bool] | re.Pattern | None]`.
|
||||
- `is_bash_command_allowed(command, phase)`:
  - `filter = self.bash_command_filter.get(phase)`
  - `if filter is None: return True`
  - `if callable(filter): return not filter(command)`
  - `if isinstance(filter, re.Pattern): return not filter.search(command)`
|
||||
- `default_policy()` replaces `_DEFAULT_BASH_FILTER` regex with `ShellTool._is_dangerous` method reference (bound method, callable).
|
||||
- Keep `_DEFAULT_BASH_FILTER` regex as a module constant for tests and config-supplied patterns; `default_policy()` no longer uses it.
|
||||
- Remove the ponytail comment at `core/phase.py:56-60` (ceiling is closed).
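The dispatch described in this approach can be sketched as a runnable example. `looks_dangerous` is a deliberately crude, hypothetical stand-in for `ShellTool._is_dangerous` (whose real rule set is not reproduced here), and the `PhasePolicy` below models only the bash filter field.

```python
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Optional, Union


class PhaseState(Enum):
    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"


BashFilter = Optional[Union[Callable[[str], bool], re.Pattern]]


def looks_dangerous(command: str) -> bool:
    """Hypothetical stand-in for ShellTool._is_dangerous; NOT the real rule set."""
    tokens = command.split()
    if not tokens:
        return False  # empty commands are rejected elsewhere, at execution time
    return tokens[0] in {"rm", "dd"} or ">" in command


@dataclass
class PhasePolicy:
    bash_command_filter: dict = field(default_factory=dict)  # PhaseState -> BashFilter

    def is_bash_command_allowed(self, command: str, phase: PhaseState) -> bool:
        bash_filter = self.bash_command_filter.get(phase)
        if bash_filter is None:
            return True  # no restriction for this phase
        if isinstance(bash_filter, re.Pattern):
            return bash_filter.search(command) is None  # legacy: a match means dangerous
        if callable(bash_filter):
            return not bash_filter(command)  # callable returns True when dangerous
        raise TypeError(f"unsupported bash filter: {bash_filter!r}")


def default_policy() -> PhasePolicy:
    return PhasePolicy({
        PhaseState.PLANNING: looks_dangerous,
        PhaseState.VERIFICATION: looks_dangerous,
    })
```

The local variable is named `bash_filter` rather than `filter` to avoid shadowing the built-in. The explicit `TypeError` is an addition of this sketch, so that a misconfigured entry fails loudly instead of falling through and returning `None`.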
|
||||
|
||||
**Execution note**: characterization-first — test that `default_policy().is_bash_command_allowed("rm -rf /", PLANNING)` still returns False (preserves Wave 3 behavior) before adding new edge-case coverage.
|
||||
|
||||
**Patterns to follow**:
|
||||
- `src/agentkit/core/phase.py:default_policy()` (Wave 3 — same factory pattern).
|
||||
- `src/agentkit/tools/shell.py:_is_dangerous` (Wave 3 — already the canonical safety check).
|
||||
|
||||
**Test scenarios** (covers R33):
|
||||
- **Characterization (Wave 3 preserved)**:
|
||||
- `default_policy().is_bash_command_allowed("rm -rf /", PLANNING)` → False.
|
||||
- `default_policy().is_bash_command_allowed("ls -la", PLANNING)` → True.
|
||||
- `default_policy().is_bash_command_allowed("git status", PLANNING)` → True.
|
||||
- **Happy paths (new ceiling closed)**:
|
||||
- `:>file` in PLANNING → False (was True before — `ShellTool._is_dangerous` catches redirect-to-empty).
|
||||
- `dd of=/dev/sda` in PLANNING → False (was True before — caught by `_DANGEROUS_BINARIES`).
|
||||
- `echo hello > /tmp/x` in PLANNING → False (was True before — `ShellTool` catches `>` redirect).
|
||||
- **Edge cases**:
|
||||
- `re.Pattern` form still works when supplied via config (`whitelist_override`-adjacent — config-supplied regex pattern is honored).
|
||||
- `callable` form takes precedence over `re.Pattern` when both somehow present (defensive — shouldn't happen).
|
||||
- **Error paths**:
|
||||
- Empty command in PLANNING → True (ShellTool separately rejects empty commands at execution time; filter only gates dangerous patterns).
|
||||
- None filter for BUILDING → True (no restriction).
|
||||
|
||||
**Verification**:
|
||||
- `python3 -m pytest tests/unit/test_phase_policy.py -q` passes.
|
||||
- `ruff check src/agentkit/core/phase.py` clean.
|
||||
- Ponytail comment at `core/phase.py:56-60` is removed (ceiling closed, not just documented).
|
||||
|
||||
---
|
||||
|
||||
### U2. Emit `phase_violation` WS event from `ReActEngine`
|
||||
|
||||
**Goal**: When `_check_phase_permission` blocks a tool call, emit a `phase_violation` event through the engine's event stream so `chat.py` WS handler can forward it to the client. Today the violation is only injected back to the LLM (react.py:2196-2203), invisible to the user.
|
||||
|
||||
**Requirements**: R30.
|
||||
|
||||
**Dependencies**: none (independent of U1 — violation emission doesn't depend on filter implementation).
|
||||
|
||||
**Files**:
|
||||
- `src/agentkit/core/react.py` (modify — emit event alongside the existing LLM injection).
|
||||
- `src/agentkit/server/routes/chat.py` (modify — forward `phase_violation` events from `execute_stream` to the WS client).
|
||||
- `tests/unit/test_react_phase_enforcement.py` (modify — assert event emission).
|
||||
- `tests/unit/test_chat_plan_exec_ws.py` (modify — assert WS client receives `phase_violation` event).
|
||||
|
||||
**Approach**:
|
||||
- `ReActEngine.execute_stream` already yields events for `tool_call`/`tool_result`/`thinking`/`token`. Add a new event type `phase_violation` yielded before the structured error is injected to the LLM conversation.
|
||||
- Event payload: `{"type": "phase_violation", "data": {"tool": "<tool_name>", "phase": "<current_phase>", "hint": "call advance_phase"}}`.
|
||||
- `chat.py` WS handler (around chat.py:1218 `async for event in react_engine.execute_stream(...)`) adds an `elif event["type"] == "phase_violation":` branch that sends the event to the client via `websocket.send_json`.
|
||||
- Existing LLM-injection path is unchanged — the LLM still gets the structured error to react to.
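The event flow in this approach can be illustrated with toy stand-ins. `execute_stream` here is a simplified placeholder for `ReActEngine.execute_stream`, and `FakeWebSocket` and `forward_events` are hypothetical names; only the event ordering and payload shape are taken from the text.

```python
import asyncio
from typing import Any, AsyncIterator


async def execute_stream(
    tool_calls: list[str], allowed: frozenset[str], phase: str
) -> AsyncIterator[dict[str, Any]]:
    """Toy stand-in for the engine stream; only the event ordering matters."""
    for tool in tool_calls:
        if tool not in allowed:
            yield {
                "type": "phase_violation",
                "data": {"tool": tool, "phase": phase, "hint": "call advance_phase"},
            }
            continue  # the real engine would also inject a structured error for the LLM
        yield {"type": "tool_call", "data": {"tool": tool}}


class FakeWebSocket:
    """Records payloads instead of sending them over a socket."""

    def __init__(self) -> None:
        self.sent: list[dict[str, Any]] = []

    async def send_json(self, payload: dict[str, Any]) -> None:
        self.sent.append(payload)


async def forward_events(websocket: FakeWebSocket, events: AsyncIterator[dict[str, Any]]) -> None:
    async for event in events:
        if event["type"] == "tool_call":
            await websocket.send_json(event)
        elif event["type"] == "phase_violation":
            await websocket.send_json(event)  # the new branch


async def demo() -> list[dict[str, Any]]:
    ws = FakeWebSocket()
    stream = execute_stream(["write_file", "search"], frozenset({"search", "read_file"}), "planning")
    await forward_events(ws, stream)
    return ws.sent
```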
|
||||
|
||||
**Execution note**: characterization-first — assert that `phase_policy=None` (no enforcement) yields zero `phase_violation` events (preserves Wave 3 behavior) before adding the positive-path test.
|
||||
|
||||
**Patterns to follow**:
|
||||
- `src/agentkit/core/react.py` existing event emission (e.g., `tool_call` event emission pattern).
|
||||
- `src/agentkit/server/routes/chat.py:1295-1309` `phase_changed` event forwarding (same shape).
|
||||
|
||||
**Test scenarios** (covers R30):
|
||||
- **Characterization (no policy)**:
|
||||
- `ReActEngine(phase_policy=None)` executing a full loop yields zero `phase_violation` events.
|
||||
- **Happy paths**:
|
||||
- PLANNING phase, LLM calls `write_file` → engine yields `phase_violation` event with `tool="write_file"`, `phase="planning"`, `hint="call advance_phase"`.
|
||||
- WS handler forwards `phase_violation` to client connection (assert `websocket.send_json` called with `{"type": "phase_violation", ...}`).
|
||||
- LLM still receives the structured error in conversation (regression — Wave 3 behavior preserved).
|
||||
- **Edge cases**:
|
||||
- Multiple violations in a row (LLM retries same tool) → multiple `phase_violation` events emitted (one per attempt).
|
||||
- Violation followed by `advance_phase` followed by same tool now allowed → exactly one `phase_violation` event, then a `tool_call` event.
|
||||
- **Error paths**:
|
||||
- Phase policy construction failure → existing 500 error path, no `phase_violation` emitted (engine not constructed).
|
||||
- **Integration scenarios**:
|
||||
- Full WS path: client connects, sends PLAN_EXEC request, LLM mock emits `write_file` in PLANNING → client receives `phase_violation` event before any `tool_call` event.
|
||||
|
||||
**Verification**:
|
||||
- `python3 -m pytest tests/unit/test_react_phase_enforcement.py tests/unit/test_chat_plan_exec_ws.py -q` passes.
|
||||
- `ruff check src/agentkit/core/react.py src/agentkit/server/routes/chat.py` clean.
|
||||
|
||||
---
|
||||
|
||||
### U3. Refactor `_build_phase_engine` helper + REST PLAN_EXEC wiring
|
||||
|
||||
**Goal**: Extract the WS PLAN_EXEC engine construction (chat.py:1093-1153) into a private `_build_phase_engine(server_config, agent, tools, ...) -> tuple[ReActEngine | None, list[Tool]]` helper. Add `_execute_plan_exec_rest()` for REST `send_message`; replace the 501 at chat.py:590-595.
|
||||
|
||||
**Requirements**: R28, R29, R26.
|
||||
|
||||
**Dependencies**: U1 (uses the hardened `default_policy()`).
|
||||
|
||||
**Files**:
|
||||
- `src/agentkit/server/routes/chat.py` (modify — extract helper; add REST handler; remove 501).
|
||||
- `tests/unit/test_chat_plan_exec_ws.py` (modify — add REST PLAN_EXEC test cases).
|
||||
- `tests/unit/test_chat_rest_plan_exec.py` (new — REST-specific coverage).
|
||||
|
||||
**Approach**:
|
||||
- New private `_build_phase_engine(server_config, agent, tools, system_prompt, model) -> tuple[ReActEngine | None, list[Tool]]`:
  1. Read `server_config.plan_exec` (default `{}`).
  2. If `enabled is False`, return `(None, tools)` (caller falls back to REACT).
  3. Build `PhasePolicy` via `policy_from_config`; if it returns None (empty config), fall back to `default_policy()`. An invalid config (e.g., a bad phase name) is not masked by the default and surfaces through the existing 500 error path.
  4. Construct `ReActEngine(..., phase_policy=policy)`.
  5. Register `AdvancePhaseTool` bound to the engine; return `(engine, tools + [advance_phase])`.
|
||||
- WS path: `_execute_plan_exec_ws` calls `_build_phase_engine`; if engine is None, falls back to REACT (existing behavior at chat.py:1101-1107).
|
||||
- REST path: `_execute_plan_exec_rest(request, session_id, ...)`:
  1. Calls `_build_phase_engine`.
  2. If engine is None, delegates to `execute_with_fallback_chain` (REST keeps fallback chain for non-PLAN_EXEC).
  3. Otherwise calls `engine.execute(...)` (non-streaming, single-shot; matches existing REST send_message shape).
  4. Collects `phase_changed`/`phase_violation` events into a `phase_events: list[dict]` field on the response payload.
  5. Returns `SendMessageResponse` extended with optional `phase_events` field.
|
||||
- Replace chat.py:590-595 with a branch: if `routing.execution_mode == PLAN_EXEC`, call `_execute_plan_exec_rest`; else continue with existing fallback chain.
|
||||
- `SendMessageResponse` model gains an optional `phase_events: list[dict] | None = None` field (default None keeps backward compat for non-PLAN_EXEC responses).
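A rough sketch of the helper's control flow follows, using stand-in classes (`FakeEngine`, `FakeAdvancePhaseTool`) and a hypothetical `whitelist` config key. It takes one reading of step 3: an empty config falls back to the default policy, while a malformed config raises, in line with the "500 with error message" test scenario listed for this unit.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeEngine:  # stands in for ReActEngine
    phase_policy: Any
    current_phase: str = "planning"


@dataclass
class FakeAdvancePhaseTool:  # stands in for AdvancePhaseTool
    engine: FakeEngine
    name: str = "advance_phase"


def default_policy() -> dict:
    return {"source": "default"}


def policy_from_config(cfg: dict) -> Optional[dict]:
    # Hypothetical config shape: returns None when no whitelist is configured,
    # raises ValueError on an unknown phase name.
    whitelist = cfg.get("whitelist")
    if not whitelist:
        return None
    known = {"planning", "building", "verification", "delivery"}
    unknown = set(whitelist) - known
    if unknown:
        raise ValueError(f"unknown phase name(s): {sorted(unknown)}")
    return {"source": "config", "whitelist": whitelist}


def build_phase_engine(plan_exec_cfg: dict, tools: list) -> tuple[Optional[FakeEngine], list]:
    if plan_exec_cfg.get("enabled") is False:
        return None, tools  # caller falls back to REACT
    policy = policy_from_config(plan_exec_cfg) or default_policy()
    engine = FakeEngine(phase_policy=policy)
    return engine, tools + [FakeAdvancePhaseTool(engine)]


@dataclass
class SendMessageResponse:  # only the field added by this unit is modeled
    content: str
    phase_events: Optional[list[dict]] = None
```

Keeping `phase_events` defaulted to `None` means existing non-PLAN_EXEC clients see an unchanged response shape.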
|
||||
|
||||
**Execution note**: characterization-first — assert that REST send_message with `execution_mode="react"` (or None) still goes through `execute_with_fallback_chain` (Wave 2 behavior unchanged) before adding PLAN_EXEC branch.
|
||||
|
||||
**Patterns to follow**:
|
||||
- `src/agentkit/server/routes/chat.py:1093-1153` (existing WS PLAN_EXEC block — the code being extracted).
|
||||
- `src/agentkit/server/_fallback_chain.py:execute_with_fallback_chain` (Wave 2 — REST non-PLAN_EXEC path stays here).
|
||||
|
||||
**Test scenarios** (covers R28, R29):
|
||||
- **Characterization (REST non-PLAN_EXEC preserved)**:
|
||||
- REST `send_message` with `execution_mode="react"` → calls `execute_with_fallback_chain` (Wave 2 path unchanged).
|
||||
- REST `send_message` with `execution_mode=None` → defaults to REACT, fallback chain applies.
|
||||
- **Happy paths**:
|
||||
- REST `send_message` with `execution_mode="plan_exec"` → returns 200 (not 501).
|
||||
- Response includes `phase_events: list` with at least one `phase_changed` entry when the engine transitions.
|
||||
- REST with empty `plan_exec` config → uses `default_policy()` (KTD5 default whitelist).
|
||||
- **Edge cases**:
|
||||
- REST with `plan_exec.enabled=False` → falls back to REACT, response has `phase_events=None`.
|
||||
- REST with bad `plan_exec` config (invalid phase name) → 500 with error message naming the bad value.
|
||||
- REST PLAN_EXEC with phase violation → `phase_events` includes a `phase_violation` entry.
|
||||
- **Error paths**:
|
||||
- REST PLAN_EXEC when session is closed → 400 (existing path, no change).
|
||||
- REST PLAN_EXEC with non-existent session → 404 (existing path).
|
||||
- **Integration scenarios**:
|
||||
- REST PLAN_EXEC bypasses fallback chain: assert `execute_with_fallback_chain` is NOT called when `execution_mode="plan_exec"` (mutual exclusion per R29).
|
||||
|
||||
**Verification**:
|
||||
- `python3 -m pytest tests/unit/test_chat_plan_exec_ws.py tests/unit/test_chat_rest_plan_exec.py -q` passes.
|
||||
- `ruff check src/agentkit/server/routes/chat.py` clean.
|
||||
- The 501 at chat.py:590-595 is removed.
|
||||
|
||||
---
|
||||
|
||||
### U4. Frontend phase event pipeline + `PhaseIndicator.vue`
|
||||
|
||||
**Goal**: Extend `WsServerMessage` union with `phase_changed` and `phase_violation` event types; add `handleWsMessage` cases that update a phase state slice; add a compact `PhaseIndicator.vue` component mounted only for PLAN_EXEC sessions.
|
||||
|
||||
**Requirements**: R31, R32.
|
||||
|
||||
**Dependencies**: U2 (frontend renders `phase_violation` events emitted by backend).
|
||||
|
||||
**Files**:
|
||||
- `src/agentkit/server/frontend/src/api/types.ts` (modify — extend `WsServerMessage` union).
|
||||
- `src/agentkit/server/frontend/src/stores/chat.ts` (modify — add `phase` state slice; add cases in `handleWsMessage`).
|
||||
- `src/agentkit/server/frontend/src/components/PhaseIndicator.vue` (new — badge + dots + toast).
|
||||
- `src/agentkit/server/frontend/src/views/AgentChatView.vue` (modify — mount `PhaseIndicator` conditionally).
|
||||
- `src/agentkit/server/frontend/tests/unit/PhaseIndicator.spec.ts` (new — component test).
|
||||
- `src/agentkit/server/frontend/src/api/types.ts` (verify — `PlanExecutionMode` type already covers `"plan_exec"`).
|
||||
|
||||
**Approach**:
|
||||
- `WsServerMessage` union gains two branches: `{ type: "phase_changed"; data: { phase: string; previous: string } }` and `{ type: "phase_violation"; data: { tool: string; phase: string; hint: string } }`.
|
||||
- `chat.ts` Pinia store gains: `currentPhase: Ref<string | null>`, `phaseViolations: Ref<PhaseViolation[]>`, `isPlanExec: ComputedRef<boolean>` (derived from session's `execution_mode`).
|
||||
- `handleWsMessage` adds `case "phase_changed": currentPhase.value = data.phase;` and `case "phase_violation": phaseViolations.value.push(data);` (capped at last 5 to bound memory).
|
||||
- `PhaseIndicator.vue`:
  - 4 dots representing PLANNING/BUILDING/VERIFICATION/DELIVERY; current phase highlighted.
  - On `phase_violation`, show an `ant-design-vue` `message.warning(...)` toast with the violation hint.
  - Renders nothing when `!isPlanExec` (graceful degradation).
|
||||
- `AgentChatView.vue` mounts `<PhaseIndicator />` in the chat header slot, conditional on `chatStore.isPlanExec`.
|
||||
|
||||
**Execution note**: characterization-first — assert that `handleWsMessage` with `data.type="token"` (existing) still updates message content unchanged, before adding new cases.
|
||||
|
||||
**Patterns to follow**:
|
||||
- `src/agentkit/server/frontend/src/stores/chat.ts:1325-1391` (existing team event handling — `phase_started`/`phase_completed` cases shape the new cases).
|
||||
- `src/agentkit/server/frontend/src/components/PlanVisualization.vue` (existing team mode component — different domain but same "compact badge + state" pattern).
|
||||
|
||||
**Test scenarios** (covers R31, R32):
|
||||
- **Characterization (existing events preserved)**:
|
||||
- `handleWsMessage({type: "token", data: ...})` still appends to message content.
|
||||
- `handleWsMessage({type: "team_formed", ...})` still routes to team store.
|
||||
- **Happy paths**:
|
||||
- `handleWsMessage({type: "phase_changed", data: {phase: "building", previous: "planning"}})` → `currentPhase.value === "building"`.
|
||||
- `handleWsMessage({type: "phase_violation", data: {tool: "write_file", phase: "planning", hint: "..."}})` → `phaseViolations.value` length increases by 1.
|
||||
- `PhaseIndicator.vue` with `currentPhase="building"` → renders 4 dots with the 2nd highlighted.
|
||||
- **Edge cases**:
|
||||
- `PhaseIndicator.vue` with `isPlanExec=false` → renders nothing (returns `null` or empty `<template>`).
|
||||
- `phaseViolations` capped at 5 entries (6th violation pushes oldest out).
|
||||
- `phase_changed` event with `previous=""` (initial transition) → no error, `currentPhase` updates.
|
||||
- **Integration scenarios**:
|
||||
- Full mount: `<PhaseIndicator />` mounted in `AgentChatView.vue` with `isPlanExec=true` and `currentPhase="planning"` → renders correctly; `message.warning` toast appears when `phase_violation` received.
|
||||
|
||||
**Verification**:
|
||||
- `cd src/agentkit/server/frontend && npm run typecheck` clean.
|
||||
- `npm run test:unit -- PhaseIndicator` passes.
|
||||
- `npm run lint` clean.
|
||||
|
||||
---
|
||||
|
||||
### U5. E2E integration test for full PLAN_EXEC lifecycle
|
||||
|
||||
**Goal**: A single E2E test that exercises the full PLAN_EXEC path through a scripted LLM mock: PLANNING (search) → `advance_phase` → BUILDING (`write_file`) → `advance_phase` → VERIFICATION (`shell` with `pytest`) → `advance_phase` → DELIVERY (final answer). Asserts the WS event sequence, phase transitions, tool dispatches, and `phase_violation` rejection when the LLM attempts an out-of-phase tool.
|
||||
|
||||
**Requirements**: R34, R27.
|
||||
|
||||
**Dependencies**: U1, U2, U3, U4 (all backend pieces must be in place).
|
||||
|
||||
**Files**:
|
||||
- `tests/integration/test_plan_exec_e2e.py` (new).
|
||||
|
||||
**Approach**:
|
||||
- Mock LLM gateway: returns scripted responses in sequence (deterministic, no real API call):
|
||||
1. `search` tool call (PLANNING-allowed) → tool dispatched.
|
||||
2. `advance_phase` tool call → `phase_changed` event emitted.
|
||||
3. `write_file` tool call (BUILDING-allowed) → tool dispatched.
|
||||
4. `advance_phase` tool call → `phase_changed` event emitted.
|
||||
5. `shell` tool call with `pytest` (VERIFICATION-allowed) → tool dispatched.
|
||||
6. `advance_phase` tool call → `phase_changed` event emitted.
|
||||
7. Final answer text.
|
||||
- Negative path: insert an out-of-phase `write_file` call in step 1 (PLANNING) → assert `phase_violation` event emitted, tool NOT dispatched, LLM receives structured error.
|
||||
- Test asserts:
|
||||
- WS event sequence includes exactly 3 `phase_changed` events (planning→building, building→verification, verification→delivery).
|
||||
- Exactly 1 `phase_violation` event (in the negative path).
|
||||
- Tool dispatch count matches allowed tool calls.
|
||||
- Final `final_answer` event received.
|
||||
|
||||
**Execution note**: This is a characterization test for the wired-up system, not a unit test. Mock the LLM gateway at the `LLMGateway` boundary; use real `ReActEngine`, real `PhasePolicy`, real WS handler (or a WS test client).
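
The scripted mock from the Approach could be sketched roughly as follows. This is a minimal illustration, not the repository's test helper: `ScriptedLLMGateway`, `ScriptedResponse`, `tool_call`, the `chat` method name, and the tool argument names other than `command` are hypothetical stand-ins, because the real `LLMGateway` interface is not shown in this plan.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ScriptedResponse:
    """Hypothetical response shape: either tool calls or final text."""

    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    content: str = ""


class ScriptedLLMGateway:
    """Index-based mock: call N returns response N, regardless of input."""

    def __init__(self, script: list[ScriptedResponse]) -> None:
        self._script = script
        self.call_count = 0

    async def chat(self, messages: list[dict[str, Any]], **kwargs: Any) -> ScriptedResponse:
        # `chat` is an assumed method name; adapt it to the real gateway API.
        if self.call_count >= len(self._script):
            raise AssertionError(
                f"LLM called {self.call_count + 1} times, "
                f"but the script has only {len(self._script)} responses"
            )
        response = self._script[self.call_count]
        self.call_count += 1
        return response


def tool_call(name: str, **arguments: Any) -> ScriptedResponse:
    """Build a scripted response containing a single tool call."""
    return ScriptedResponse(tool_calls=[{"name": name, "arguments": arguments}])


# The seven scripted steps of the happy path, in order.
HAPPY_PATH_SCRIPT = [
    tool_call("search", query="existing tests"),             # 1. PLANNING
    tool_call("advance_phase"),                              # 2. to BUILDING
    tool_call("write_file", path="out.py", content="pass"),  # 3. BUILDING
    tool_call("advance_phase"),                              # 4. to VERIFICATION
    tool_call("shell", command="pytest"),                    # 5. VERIFICATION
    tool_call("advance_phase"),                              # 6. to DELIVERY
    ScriptedResponse(content="Done."),                       # 7. final answer
]
```

Because the mock indexes by call count rather than inspecting conversation state, an unexpected extra LLM call fails loudly instead of silently replaying a response, which is the determinism property the plan relies on.
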

**Patterns to follow**:
- `tests/unit/test_chat_plan_exec_ws.py` (Wave 3; same WS test client pattern).
- `tests/integration/test_api_coverage.py` (existing; integration test patterns in the repo).

**Test scenarios** (covers R34):
- **Happy path (full lifecycle)**:
  - Scripted LLM completes all 4 phases in order → 3 `phase_changed` events, 3 `advance_phase` tool calls dispatched, allowed tools dispatched in each phase, `final_answer` event received.
- **Negative path (violation then recovery)**:
  - LLM attempts `write_file` in PLANNING → `phase_violation` event emitted, `write_file` NOT dispatched (assert `write_file.execute` call count is 0 at this point), LLM receives structured error in conversation.
  - LLM then calls `advance_phase` → transitions to BUILDING, `write_file` now dispatched successfully.
- **Edge cases**:
  - `plan_exec.enabled=False` config → test asserts path falls back to REACT (no phase events emitted).
  - LLM never calls `advance_phase` and `auto_advance_after_steps=2` → phase auto-advances after 2 steps (asserts safety net).
- **Error paths**:
  - LLM raises (LLM call fails) → existing error event emitted; phase state unchanged.

**Verification**:
- `python3 -m pytest tests/integration/test_plan_exec_e2e.py -v` passes.
- Test runs without real LLM API call (mocked).

---

## Risks & Dependencies

### Risks

1. **REST non-streaming shape mismatch (medium)**: REST clients expecting SSE for PLAN_EXEC will not get streaming phase events; they receive the events as a list after the fact. Mitigation: KTD2 documents this as the intentional first version; the SSE follow-up is tracked as deferred.
2. **Frontend state slice bloat (low)**: Adding `currentPhase` + `phaseViolations` to the chat store adds reactive state. Mitigation: `phaseViolations` is capped at 5 entries and `currentPhase` is a single string, so the memory cost is negligible.
3. **`ShellTool._is_dangerous` import cycle (low)**: `core/phase.py` importing from `tools/shell.py` could create a cycle if `shell.py` imports from `core/`. Mitigation: verify the import direction at implementation time; if a cycle exists, lift `_is_dangerous` to a shared `tools/_safety.py` module (one-function extraction).
4. **E2E test flakiness from mock sequencing (low)**: The scripted LLM mock must match the exact call sequence. Mitigation: use an index-based mock (call N returns response N) rather than a state-based one, so the test stays deterministic.
5. **Backward compat for `re.Pattern` config (low)**: Existing config-supplied regex patterns must still work after the type widening. Mitigation: KTD4 preserves the `re.Pattern` branch in `is_bash_command_allowed`; a characterization test in U1 covers it.
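
To make risk 5 concrete, the sketch below exercises both filter contracts side by side. The regex is the Wave 3 default copied from `core/phase.py`; the free function `is_bash_command_allowed` mirrors the method's dispatch logic in standalone form, and `toy_is_dangerous` is a deliberately simplistic, hypothetical stand-in for `ShellTool._is_dangerous`, whose real logic is not shown here.

```python
import re
from typing import Callable, Optional, Union

BashFilter = Optional[Union[Callable[[str], bool], re.Pattern]]

# Wave 3 default regex, copied from core/phase.py.
LEGACY_FILTER = re.compile(r"\b(rm|mv|cp|mkdir|rmdir|chmod|chown)\b|(?<!\S)>|>>")


def is_bash_command_allowed(command: str, filter_value: BashFilter) -> bool:
    """Standalone version of the dual-contract check."""
    if not command or filter_value is None:
        return True
    if callable(filter_value):
        # Callable contract: returns True if the command is dangerous.
        return not filter_value(command)
    # re.Pattern contract: search() returns a Match if dangerous.
    return not filter_value.search(command)


def toy_is_dangerous(command: str) -> bool:
    """Hypothetical stand-in for ShellTool._is_dangerous (NOT the real logic)."""
    if ">" in command:
        return True
    tokens = command.split()
    mutating = {"rm", "mv", "cp", "dd", "mkdir", "rmdir", "chmod", "chown"}
    return bool(tokens) and tokens[0] in mutating


for cmd in ["rm -rf build", ":>file", "dd of=file", "ls -la"]:
    print(
        f"{cmd!r}: regex allows={is_bash_command_allowed(cmd, LEGACY_FILTER)}, "
        f"callable allows={is_bash_command_allowed(cmd, toy_is_dangerous)}"
    )
```

Compiled patterns are not callable in Python, so the `callable()` check cleanly separates the two branches and a legacy `re.Pattern` value keeps its Wave 3 behavior, including the `:>file` and `dd of=file` misses that motivated the change.
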

### Dependencies

- Wave 3 (PR #6 merged): `PhasePolicy`, `PhaseState`, `default_policy()`, `AdvancePhaseTool`, WS PLAN_EXEC handler all in place.
- Wave 2 (PR #5 merged): `execute_with_fallback_chain` for REST non-PLAN_EXEC path.
- No external library dependencies.

## System-Wide Impact

- **REST `send_message` callers**: gain PLAN_EXEC support; existing REACT/SKILL_REACT callers unchanged (fallback chain preserved).
- **WebSocket clients**: gain `phase_violation` event type; existing event types unchanged. Vue frontend renders the new events; other WS clients (CLI) ignore them silently.
- **`agentkit.yaml`**: no new config section (Wave 3 `plan_exec` section is reused; Wave 4 only changes how the policy is constructed internally).
- **`PhasePolicy` API**: `bash_command_filter` field type widens (`re.Pattern | None` → `Callable[[str], bool] | re.Pattern | None`); `is_bash_command_allowed` signature unchanged. Backward compatible.
- **Frontend chat store**: gains `currentPhase` + `phaseViolations` reactive state; existing state unchanged.

## Sources & Research

- Origin brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 3 deferred items + KTD6/KTD7).
- Wave 3 plan: `docs/plans/2026-06-29-004-feat-agent-wave3-strategic-plan.md` (Out of Scope section enumerates the deferrals Wave 4 executes).
- Wave 3 code review: `/tmp/compound-engineering/ce-code-review/20260630-015548-c44a5245/` (ponytail ceiling at `core/phase.py:56-60` flagged by correctness+reliability reviewers).
- Codebase research (2026-06-30): `frontend/src/stores/chat.ts:800`, `frontend/src/api/types.ts:135`, `src/agentkit/server/routes/chat.py:580,1093-1153`, `src/agentkit/tools/shell.py:466`, `src/agentkit/core/react.py:2196-2203`, `src/agentkit/server/_fallback_chain.py:90`.

## Deferred to Implementation

- Exact `SendMessageResponse` schema for `phase_events: list[dict]`: the design above gives the shape; the implementer finalizes field names based on existing response models.
- `PhaseIndicator.vue` visual design (dot vs pill vs progress bar): the implementer picks based on the existing Ant Design Vue component inventory.
- Mock LLM response sequence length in U5: the implementer sizes it based on whether the test asserts every step or samples key transitions.
- Whether `_build_phase_engine` returns `None` or raises on opt-out (`enabled=False`): the design above returns `None` and the caller falls back; the implementer may switch to an explicit enum return if cleaner.
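
If the explicit enum return mentioned in the last item is preferred, one possible shape is sketched below. Everything here is hypothetical: `PhaseEngineOutcome`, `PhaseEngineResult`, and `resolve_outcome` are illustrative names, and the sketch only covers the opt-out decision, not engine construction.

```python
import enum
from dataclasses import dataclass
from typing import Any, Optional


class PhaseEngineOutcome(enum.Enum):
    """Hypothetical outcomes for building a PLAN_EXEC engine."""

    ENGAGED = "engaged"              # engine built, PLAN_EXEC runs
    NOT_REQUESTED = "not_requested"  # execution mode is not PLAN_EXEC
    DISABLED = "disabled"            # plan_exec.enabled is False
    POLICY_ERROR = "policy_error"    # policy construction failed


@dataclass(frozen=True)
class PhaseEngineResult:
    outcome: PhaseEngineOutcome
    engine: Optional[Any] = None
    error: Optional[str] = None

    @property
    def falls_back_to_react(self) -> bool:
        # Only the two opt-out cases fall back; a policy error does not.
        return self.outcome in (
            PhaseEngineOutcome.NOT_REQUESTED,
            PhaseEngineOutcome.DISABLED,
        )


def resolve_outcome(execution_mode: str, plan_exec_cfg: dict[str, Any]) -> PhaseEngineOutcome:
    """Decide the opt-out cases before any engine is constructed."""
    if execution_mode != "plan_exec":
        return PhaseEngineOutcome.NOT_REQUESTED
    if plan_exec_cfg.get("enabled", True) is False:
        return PhaseEngineOutcome.DISABLED
    return PhaseEngineOutcome.ENGAGED
```

Compared with a bare `None`, a named outcome lets the caller tell "not requested" apart from "disabled by config" in logs and tests, at the cost of one extra type.
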
|
||||
|
|
@ -15,7 +15,9 @@ import enum
|
|||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from agentkit.tools.shell import ShellTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -53,16 +55,6 @@ class PhaseState(enum.Enum):
|
|||
# Wildcard token meaning "all tools allowed in this phase".
|
||||
WILDCARD = "*"
|
||||
|
||||
# Default bash command filter for PLANNING and VERIFICATION phases — blocks
|
||||
# commands that mutate the filesystem or execute arbitrary code.
|
||||
# ponytail: regex is intentionally conservative; misses some shell idioms
|
||||
# (e.g., `:>file`, `dd of=file`). Ceiling: a real shell parser would catch
|
||||
# more. Upgrade path = reuse ShellTool._is_dangerous() at enforcement time.
|
||||
# Note: `\b` is a word boundary — works for word commands (rm/mv) but NOT
|
||||
# for `>`/`>>` operators (not word chars). Use a non-boundary alternation
|
||||
# that matches `>` either as a standalone operator or after whitespace.
|
||||
_DEFAULT_BASH_FILTER = re.compile(r"\b(rm|mv|cp|mkdir|rmdir|chmod|chown)\b|(?<!\S)>|>>")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PhasePolicy:
|
||||
|
|
@ -76,10 +68,19 @@ class PhasePolicy:
|
|||
|
||||
Wildcard ``"*"`` in a phase's whitelist means "all tools allowed"
|
||||
(used by DELIVERY by default).
|
||||
|
||||
`bash_command_filter` values accept either:
|
||||
- `Callable[[str], bool]`: returns True if the command is dangerous
|
||||
(matches `ShellTool._is_dangerous` semantics); allowed = not dangerous.
|
||||
- `re.Pattern`: pattern matches dangerous substrings; allowed = no match.
|
||||
Kept for backward compat with Wave 3 configs.
|
||||
- `None`: no restriction for this phase.
|
||||
"""
|
||||
|
||||
whitelist: dict[PhaseState, frozenset[str]]
|
||||
bash_command_filter: dict[PhaseState, re.Pattern | None] = field(default_factory=dict)
|
||||
bash_command_filter: dict[
|
||||
PhaseState, Callable[[str], bool] | re.Pattern | None
|
||||
] = field(default_factory=dict)
|
||||
auto_advance_after_steps: int | None = None # None = manual (LLM calls advance_phase)
|
||||
start_phase: PhaseState = PhaseState.PLANNING
|
||||
|
||||
|
|
@ -103,19 +104,31 @@ class PhasePolicy:
|
|||
"""Return True if `command` passes the bash filter for `phase`.
|
||||
|
||||
A None filter = no restriction. An empty command is allowed (ShellTool
|
||||
separately rejects empty commands).
|
||||
separately rejects empty commands) — short-circuited here so the
|
||||
ShellTool path emits a clearer "empty command" error instead of a
|
||||
phase-violation noise injected back to the LLM.
|
||||
"""
|
||||
pattern = self.bash_command_filter.get(phase)
|
||||
if pattern is None:
|
||||
if not command:
|
||||
return True
|
||||
return not pattern.search(command)
|
||||
filter_value = self.bash_command_filter.get(phase)
|
||||
if filter_value is None:
|
||||
return True
|
||||
if callable(filter_value):
|
||||
# Callable contract: returns True if dangerous.
|
||||
return not filter_value(command)
|
||||
# re.Pattern contract: search() returns a Match if dangerous.
|
||||
return not filter_value.search(command)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize for logging/telemetry. Not round-trippable (regex → str)."""
|
||||
"""Serialize for logging/telemetry. Not round-trippable (regex/callable → str)."""
|
||||
return {
|
||||
"whitelist": {phase.value: sorted(tools) for phase, tools in self.whitelist.items()},
|
||||
"bash_command_filter": {
|
||||
phase.value: (p.pattern if p else None)
|
||||
phase.value: (
|
||||
"<callable>"
|
||||
if callable(p)
|
||||
else (p.pattern if p else None)
|
||||
)
|
||||
for phase, p in self.bash_command_filter.items()
|
||||
},
|
||||
"auto_advance_after_steps": self.auto_advance_after_steps,
|
||||
|
|
@ -133,8 +146,10 @@ def default_policy() -> PhasePolicy:
|
|||
- DELIVERY: all tools (wildcard)
|
||||
|
||||
Bash filter:
|
||||
- PLANNING/VERIFICATION: blocks filesystem-mutating commands
|
||||
(rm/mv/cp/mkdir/chmod/chown/>/>>)
|
||||
- PLANNING/VERIFICATION: reuse `ShellTool._is_dangerous` (Wave 4 U1).
|
||||
Closes the regex ceiling — catches `:>file`, `dd of=/dev/sda`, chain
|
||||
operators, and the full danger taxonomy shared with the ShellTool
|
||||
confirmation path.
|
||||
- BUILDING/DELIVERY: no filter (full bash)
|
||||
"""
|
||||
return PhasePolicy(
|
||||
|
|
@ -150,8 +165,8 @@ def default_policy() -> PhasePolicy:
|
|||
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||
},
|
||||
bash_command_filter={
|
||||
PhaseState.PLANNING: _DEFAULT_BASH_FILTER,
|
||||
PhaseState.VERIFICATION: _DEFAULT_BASH_FILTER,
|
||||
PhaseState.PLANNING: ShellTool._is_dangerous,
|
||||
PhaseState.VERIFICATION: ShellTool._is_dangerous,
|
||||
PhaseState.BUILDING: None,
|
||||
PhaseState.DELIVERY: None,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -224,6 +224,12 @@ class ReActEngine:
|
|||
)
|
||||
# Steps taken in the current phase (for auto-advance safety net).
|
||||
self._steps_in_phase: int = 0
|
||||
# Wave 4 U2: phase violation accumulator. _check_phase_permission
|
||||
# appends here when a tool is blocked; execute_stream drains after each
|
||||
# step and yields phase_violation ReActEvents. Non-streaming execute()
|
||||
# simply ignores the accumulator (the error dict returned to the LLM is
|
||||
# the only signal there).
|
||||
self._phase_violations: list[dict[str, Any]] = []
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for reuse across conversations.
|
||||
|
|
@ -241,6 +247,8 @@ class ReActEngine:
|
|||
if self._phase_policy is not None:
|
||||
self._current_phase = self._phase_policy.start_phase
|
||||
self._steps_in_phase = 0
|
||||
# Wave 4 U2: clear any pending violations from a prior run.
|
||||
self._phase_violations = []
|
||||
|
||||
# ── U3/G6: phase state machine ────────────────────────────────────
|
||||
|
||||
|
|
@ -299,11 +307,16 @@ class ReActEngine:
|
|||
AdvancePhaseTool or pick a different tool).
|
||||
|
||||
Also applies the bash_command_filter for `bash` tool calls.
|
||||
|
||||
Wave 4 U2: when blocked, the violation is also appended to
|
||||
`self._phase_violations` so `execute_stream` can drain and yield
|
||||
`phase_violation` ReActEvents to the WS layer (alongside the LLM
|
||||
reinjection that the returned dict provides).
|
||||
"""
|
||||
if self._phase_policy is None or self._current_phase is None:
|
||||
return None
|
||||
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
|
||||
return {
|
||||
violation = {
|
||||
"error": "phase_violation",
|
||||
"message": (
|
||||
f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. "
|
||||
|
|
@ -312,12 +325,15 @@ class ReActEngine:
|
|||
"current_phase": self._current_phase.value,
|
||||
"tool": tool_name,
|
||||
"is_error": True,
|
||||
"violation_kind": "tool_not_allowed",
|
||||
}
|
||||
self._phase_violations.append(violation)
|
||||
return violation
|
||||
# Bash command filter (only applies to shell tool — registered as "shell").
|
||||
if tool_name == "shell":
|
||||
command = str(arguments.get("command", ""))
|
||||
if not self._phase_policy.is_bash_command_allowed(command, self._current_phase):
|
||||
return {
|
||||
violation = {
|
||||
"error": "phase_violation",
|
||||
"message": (
|
||||
f"Bash command blocked in {self._current_phase.value} phase "
|
||||
|
|
@ -327,7 +343,11 @@ class ReActEngine:
|
|||
"current_phase": self._current_phase.value,
|
||||
"tool": tool_name,
|
||||
"is_error": True,
|
||||
"violation_kind": "bash_command_blocked",
|
||||
"command_preview": command[:200],
|
||||
}
|
||||
self._phase_violations.append(violation)
|
||||
return violation
|
||||
return None
|
||||
|
||||
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
|
||||
|
|
@ -351,6 +371,28 @@ class ReActEngine:
|
|||
return hash_to_name.get(h)
|
||||
return None
|
||||
|
||||
    def _drain_phase_violations(self, step: int) -> list[ReActEvent]:
        """Pop and return ReActEvents for phase violations recorded by
        ``_check_phase_permission`` since the last drain.

        Wave 4 U2: execute_stream calls this after each tool_result yield so
        the WS layer can surface ``phase_violation`` events to the client
        (alongside the LLM reinjection that the returned error dict provides).
        Returns an empty list if no violations are pending.
        """
        if not self._phase_violations:
            return []
        events = [
            ReActEvent(
                event_type="phase_violation",
                step=step,
                data=dict(v),
            )
            for v in self._phase_violations
        ]
        self._phase_violations = []
        return events
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
|
|
@ -1514,6 +1556,10 @@ class ReActEngine:
|
|||
step=step,
|
||||
data={"tool_name": tc.name, "result": tool_result},
|
||||
)
|
||||
# Wave 4 U2: drain phase violations recorded by
|
||||
# _check_phase_permission during this tool call.
|
||||
for _ev in self._drain_phase_violations(step):
|
||||
yield _ev
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
tc.id, tool_result, effective_compressor, tc.name
|
||||
)
|
||||
|
|
@ -1652,6 +1698,9 @@ class ReActEngine:
|
|||
step=step,
|
||||
data={"tool_name": tc.name, "result": tool_result},
|
||||
)
|
||||
# Wave 4 U2: drain phase violations.
|
||||
for _ev in self._drain_phase_violations(step):
|
||||
yield _ev
|
||||
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
tc.id, tool_result, effective_compressor, tc.name
|
||||
|
|
@ -1721,6 +1770,9 @@ class ReActEngine:
|
|||
step=step,
|
||||
data={"tool_name": pc["name"], "result": tool_result},
|
||||
)
|
||||
# Wave 4 U2: drain phase violations.
|
||||
for _ev in self._drain_phase_violations(step):
|
||||
yield _ev
|
||||
tool_msg = await self._build_tool_result_message(
|
||||
pc.get("id", f"text_tc_{step}"),
|
||||
tool_result,
|
||||
|
|
|
|||
|
|
@ -146,6 +146,9 @@ export type WsServerMessage =
|
|||
| { type: 'phase_started'; data: { phase_id: string; phase_name: string; assigned_expert: string; depends_on: string[] } }
|
||||
| { type: 'phase_completed'; data: { phase_id: string; phase_name: string; result_summary: string } }
|
||||
| { type: 'phase_failed'; data: { phase_id: string; phase_name: string; error: string } }
|
||||
// PLAN_EXEC (U4) — phase lifecycle events emitted by ReActEngine.
|
||||
| { type: 'phase_changed'; data: { phase: string; previous: string } }
|
||||
| { type: 'phase_violation'; data: { current_phase: string; tool: string; message: string; violation_kind: string; command_preview?: string } }
|
||||
| { type: 'team_synthesis'; data: { content: string } }
|
||||
| { type: 'team_dissolved'; data: { team_id: string } }
|
||||
// Board Meeting mode events
|
||||
|
|
|
|||
|
|
@ -0,0 +1,142 @@
|
|||
<template>
|
||||
<div v-if="chatStore.isPlanExec" class="phase-indicator">
|
||||
<span class="phase-indicator__badge">PLAN_EXEC</span>
|
||||
<span class="phase-indicator__label">{{ currentLabel }}</span>
|
||||
<ul class="phase-indicator__dots">
|
||||
<li
|
||||
v-for="p in phases"
|
||||
:key="p.key"
|
||||
class="phase-indicator__dot"
|
||||
:class="{
|
||||
'phase-indicator__dot--active': p.key === activeKey,
|
||||
'phase-indicator__dot--done': p.done,
|
||||
}"
|
||||
:title="p.label"
|
||||
>
|
||||
<span class="phase-indicator__dot-inner" />
|
||||
</li>
|
||||
</ul>
|
||||
<span v-if="chatStore.phaseViolations.length > 0" class="phase-indicator__violations">
|
||||
{{ chatStore.phaseViolations.length }} 违规
|
||||
</span>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
import { useChatStore } from '@/stores/chat'
|
||||
|
||||
const chatStore = useChatStore()
|
||||
|
||||
interface PhaseMeta {
|
||||
key: string
|
||||
label: string
|
||||
done: boolean
|
||||
}
|
||||
|
||||
const phases = computed<PhaseMeta[]>(() => {
|
||||
const order = ['planning', 'building', 'verification', 'delivery']
|
||||
const labels: Record<string, string> = {
|
||||
planning: '规划',
|
||||
building: '构建',
|
||||
verification: '验证',
|
||||
delivery: '交付',
|
||||
}
|
||||
const current = chatStore.currentPhase
|
||||
const currentIndex = current ? order.indexOf(current) : -1
|
||||
return order.map((key, idx) => ({
|
||||
key,
|
||||
label: labels[key] ?? key,
|
||||
done: currentIndex > idx,
|
||||
}))
|
||||
})
|
||||
|
||||
const activeKey = computed(() => chatStore.currentPhase ?? '')
|
||||
|
||||
const currentLabel = computed(() => {
|
||||
const labels: Record<string, string> = {
|
||||
planning: '规划阶段',
|
||||
building: '构建阶段',
|
||||
verification: '验证阶段',
|
||||
delivery: '交付阶段',
|
||||
}
|
||||
return labels[chatStore.currentPhase ?? ''] ?? 'PLAN_EXEC'
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.phase-indicator {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 4px 12px;
|
||||
background: var(--color-bg-elevated, #fafafa);
|
||||
border-bottom: 1px solid var(--color-border, #f0f0f0);
|
||||
font-size: 12px;
|
||||
color: var(--color-text-secondary, #888);
|
||||
}
|
||||
|
||||
.phase-indicator__badge {
|
||||
display: inline-block;
|
||||
padding: 1px 6px;
|
||||
background: #722ed1;
|
||||
color: #fff;
|
||||
border-radius: 3px;
|
||||
font-size: 10px;
|
||||
font-weight: 600;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.phase-indicator__label {
|
||||
color: var(--color-text-primary, #333);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.phase-indicator__dots {
|
||||
list-style: none;
|
||||
display: flex;
|
||||
gap: 6px;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.phase-indicator__dot {
|
||||
width: 10px;
|
||||
height: 10px;
|
||||
border-radius: 50%;
|
||||
background: #d9d9d9;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.phase-indicator__dot-inner {
|
||||
width: 4px;
|
||||
height: 4px;
|
||||
border-radius: 50%;
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.phase-indicator__dot--active {
|
||||
background: #722ed1;
|
||||
box-shadow: 0 0 0 2px rgba(114, 46, 209, 0.2);
|
||||
}
|
||||
|
||||
.phase-indicator__dot--active .phase-indicator__dot-inner {
|
||||
background: #fff;
|
||||
}
|
||||
|
||||
.phase-indicator__dot--done {
|
||||
background: #52c41a;
|
||||
}
|
||||
|
||||
.phase-indicator__violations {
|
||||
margin-left: auto;
|
||||
padding: 1px 6px;
|
||||
background: #fff1f0;
|
||||
color: #cf1322;
|
||||
border-radius: 3px;
|
||||
font-size: 11px;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -174,6 +174,27 @@ export const useChatStore = defineStore("chat", () => {
|
|||
() => boardState.value !== null && boardState.value.status === "discussing",
|
||||
);
|
||||
|
||||
// PLAN_EXEC phase state (U4) — tracks current phase + violations for the
|
||||
// PhaseIndicator component. Set when the first phase_* event arrives.
|
||||
// Reset on conversation switch or final_answer.
|
||||
const currentPhase = ref<string | null>(null);
|
||||
const phaseViolations = ref<
|
||||
Array<{
|
||||
phase: string;
|
||||
tool: string;
|
||||
message: string;
|
||||
violation_kind: string;
|
||||
command_preview?: string;
|
||||
ts: number;
|
||||
}>
|
||||
>([]);
|
||||
const isPlanExec = computed(() => currentPhase.value !== null);
|
||||
|
||||
function resetPlanExecState(): void {
|
||||
currentPhase.value = null;
|
||||
phaseViolations.value = [];
|
||||
}
|
||||
|
||||
// Debate state (transient, only active during a debate collaboration)
|
||||
const debateState = ref<{
|
||||
topic: string;
|
||||
|
|
@ -1096,6 +1117,8 @@ export const useChatStore = defineStore("chat", () => {
|
|||
// across multiple interactions. The UI has already transitioned
|
||||
// to showing the final assistant message.
|
||||
clearConvSteps(conversationId);
|
||||
// Reset PLAN_EXEC phase state — the conversation is done.
|
||||
resetPlanExecState();
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -1390,6 +1413,59 @@ export const useChatStore = defineStore("chat", () => {
|
|||
break;
|
||||
}
|
||||
|
||||
// ── PLAN_EXEC (U4) — phase lifecycle events from ReActEngine ────────
|
||||
|
||||
case "phase_changed": {
|
||||
currentPhase.value = data.data.phase;
|
||||
const cid = resolveIncomingConvId();
|
||||
if (cid) {
|
||||
appendStep(
|
||||
{
|
||||
type: "milestone",
|
||||
label: "阶段切换",
|
||||
detail: `${data.data.previous || "—"} → ${data.data.phase}`,
|
||||
status: "success",
|
||||
},
|
||||
cid,
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case "phase_violation": {
|
||||
// Track current phase from violation data.
|
||||
currentPhase.value = data.data.current_phase;
|
||||
const violation = {
|
||||
phase: data.data.current_phase,
|
||||
tool: data.data.tool,
|
||||
message: data.data.message,
|
||||
violation_kind: data.data.violation_kind,
|
||||
command_preview: data.data.command_preview,
|
||||
ts: Date.now(),
|
||||
};
|
||||
phaseViolations.value = [...phaseViolations.value, violation].slice(-5);
|
||||
// Toast notification via ant-design-vue message.
|
||||
import("ant-design-vue").then(({ message }) => {
|
||||
message.warning(
|
||||
`[${data.data.current_phase}] 工具 ${data.data.tool} 被拦截: ${data.data.message}`,
|
||||
5,
|
||||
);
|
||||
});
|
||||
const cid = resolveIncomingConvId();
|
||||
if (cid) {
|
||||
appendStep(
|
||||
{
|
||||
type: "team_event",
|
||||
label: "阶段违规",
|
||||
detail: `${data.data.current_phase} · ${data.data.tool}`,
|
||||
status: "error",
|
||||
},
|
||||
cid,
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// ── Board Meeting mode events ────────────────────────────────────────
|
||||
|
||||
case "board_started": {
|
||||
|
|
@ -1920,6 +1996,10 @@ export const useChatStore = defineStore("chat", () => {
|
|||
boardState,
|
||||
debateState,
|
||||
collaborationState,
|
||||
// PLAN_EXEC (U4)
|
||||
currentPhase,
|
||||
phaseViolations,
|
||||
isPlanExec,
|
||||
// Legacy aliases (derive from current conversation for backward compat).
|
||||
// New code should use `isCurrentLoading` / `currentStreamingSteps` instead.
|
||||
isLoading: isCurrentLoading,
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
<template v-else>
|
||||
<ExpertTeamView />
|
||||
<BoardStatusView />
|
||||
<PhaseIndicator />
|
||||
<div class="chat-view__content" ref="messagesContainer">
|
||||
<div class="chat-view__content-inner">
|
||||
<div v-if="chatStore.currentMessages.length === 0" class="chat-view__welcome">
|
||||
|
|
@ -108,6 +109,7 @@ import ChatMessage from '@/components/chat/ChatMessage.vue'
|
|||
import ChatInput from '@/components/chat/ChatInput.vue'
|
||||
import ExpertTeamView from '@/components/chat/ExpertTeamView.vue'
|
||||
import BoardStatusView from '@/components/chat/BoardStatusView.vue'
|
||||
import PhaseIndicator from '@/components/chat/PhaseIndicator.vue'
|
||||
|
||||
const ATypographyText = ATypography.Text
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,78 @@
|
|||
/**
|
||||
* Unit tests for chat store PLAN_EXEC phase state (U4).
|
||||
*
|
||||
* Verifies the phase state slice is exposed with correct initial values
|
||||
* and that `isPlanExec` derives from `currentPhase`. The full event
|
||||
* handling is covered by the E2E test in U5.
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { setActivePinia, createPinia } from 'pinia'
|
||||
|
||||
// Mock the API client so the store doesn't touch the network.
|
||||
vi.mock('@/api/client', () => ({
|
||||
apiClient: {
|
||||
get: vi.fn(),
|
||||
post: vi.fn(),
|
||||
put: vi.fn(),
|
||||
delete: vi.fn(),
|
||||
patch: vi.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock team/documents/calendar stores to avoid pulling their deps.
|
||||
vi.mock('@/stores/team', () => ({
|
||||
useTeamStore: vi.fn(() => null),
|
||||
}))
|
||||
vi.mock('@/stores/documents', () => ({
|
||||
useDocumentsStore: vi.fn(() => null),
|
||||
}))
|
||||
vi.mock('@/stores/calendar', () => ({
|
||||
useCalendarStore: vi.fn(() => null),
|
||||
}))
|
||||
vi.mock('@/api/documents', () => ({
|
||||
isDocumentMeta: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('chat store — PLAN_EXEC phase state (U4)', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
})
|
||||
|
||||
it('exposes currentPhase, phaseViolations, isPlanExec with initial values', async () => {
|
||||
const { useChatStore } = await import('@/stores/chat')
|
||||
const store = useChatStore()
|
||||
expect(store.currentPhase).toBeNull()
|
||||
expect(store.phaseViolations).toEqual([])
|
||||
expect(store.isPlanExec).toBe(false)
|
||||
})
|
||||
|
||||
it('isPlanExec is true when currentPhase is set', async () => {
|
||||
const { useChatStore } = await import('@/stores/chat')
|
||||
const store = useChatStore()
|
||||
store.currentPhase = 'planning'
|
||||
expect(store.isPlanExec).toBe(true)
|
||||
})
|
||||
|
||||
it('phaseViolations array is directly mutable for test fixtures', async () => {
|
||||
// Direct mutation bypasses the capping logic in handleWsMessage;
|
||||
// the cap is enforced inside the case handler, not as a setter.
|
||||
// This test verifies the array is accessible; the cap-at-5 behavior
|
||||
// is exercised through handleWsMessage in the U5 E2E test.
|
||||
const { useChatStore } = await import('@/stores/chat')
|
||||
const store = useChatStore()
|
||||
for (let i = 0; i < 7; i++) {
|
||||
store.phaseViolations = [
|
||||
...store.phaseViolations,
|
||||
{
|
||||
phase: 'planning',
|
||||
tool: `tool_${i}`,
|
||||
message: 'blocked',
|
||||
violation_kind: 'tool_not_allowed',
|
||||
ts: Date.now(),
|
||||
},
|
||||
]
|
||||
}
|
||||
expect(store.phaseViolations.length).toBe(7)
|
||||
})
|
||||
})
|
||||
|
|
@ -6,7 +6,7 @@ import asyncio
|
|||
import hmac
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import os
|
||||
import uuid
|
||||
|
|
@ -25,7 +25,7 @@ from fastapi.responses import FileResponse
|
|||
from pydantic import BaseModel
|
||||
|
||||
from agentkit.chat.skill_routing import ExecutionMode
|
||||
from agentkit.core.phase import PhasePolicy, default_policy, policy_from_config
|
||||
from agentkit.core.phase import default_policy, policy_from_config
|
||||
from agentkit.core.protocol import CancellationToken
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.server._fallback_chain import execute_with_fallback_chain
|
||||
|
|
@ -33,6 +33,11 @@ from agentkit.session.manager import SessionManager
|
|||
from agentkit.session.models import MessageRole, SessionStatus
|
||||
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.server.config import ServerConfig
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
|
@ -534,6 +539,69 @@ def _message_to_response(msg) -> MessageResponse:
|
|||
)
|
||||
|
||||
|
||||
def _build_phase_engine(
|
||||
*,
|
||||
server_config: ServerConfig | None,
|
||||
llm_gateway: LLMGateway,
|
||||
execution_mode: ExecutionMode,
|
||||
base_tools: list[Tool],
|
||||
session_id: str = "",
|
||||
) -> tuple[ReActEngine | None, list[Tool] | None, str | None]:
|
||||
"""Build a PLAN_EXEC engine with PhasePolicy + AdvancePhaseTool.
|
||||
|
||||
Encapsulates the WS path's phase_policy construction so the REST path
|
||||
can reuse it without duplicating config-lookup + policy_from_config +
|
||||
AdvancePhaseTool registration. KTD5: PLAN_EXEC bypasses the fallback
|
||||
chain — callers must NOT route the returned engine through
|
||||
``execute_with_fallback_chain``.
|
||||
|
||||
Args:
|
||||
server_config: ``app.state.server_config`` (or None for tests).
|
||||
llm_gateway: ``app.state.llm_gateway``.
|
||||
execution_mode: routing.execution_mode (WS) or PLAN_EXEC (REST).
|
||||
base_tools: routing.tools (WS) or agent tool list (REST).
|
||||
session_id: included in log lines for traceability only.
|
||||
|
||||
Returns ``(engine, tools_with_advance_phase, error_message)``:
|
||||
- execution_mode != PLAN_EXEC → ``(None, None, None)`` (fall back to REACT).
|
||||
- plan_exec.enabled=False → ``(None, None, None)`` (fall back to REACT).
|
||||
- phase policy construction failed → ``(None, None, error_message)``.
|
||||
- PLAN_EXEC engaged → ``(engine, tools_with_advance_phase, None)``.
|
||||
"""
|
||||
if execution_mode != ExecutionMode.PLAN_EXEC:
|
||||
return (None, None, None)
|
||||
|
||||
plan_exec_cfg = getattr(server_config, "plan_exec", None) or {}
|
||||
if plan_exec_cfg.get("enabled", True) is False:
|
||||
logger.info(
|
||||
"PLAN_EXEC disabled by config (plan_exec.enabled=False), "
|
||||
"falling back to REACT for session %s",
|
||||
session_id,
|
||||
)
|
||||
return (None, None, None)
|
||||
|
||||
try:
|
||||
phase_policy = policy_from_config(plan_exec_cfg)
|
||||
if phase_policy is None:
|
||||
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
|
||||
phase_policy = default_policy()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"PLAN_EXEC phase policy construction failed for session %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
)
|
||||
return (None, None, f"phase policy error: {str(e)[:200]}")
|
||||
|
||||
engine = ReActEngine(
|
||||
llm_gateway=llm_gateway,
|
||||
phase_policy=phase_policy,
|
||||
)
|
||||
advance_phase_tool = AdvancePhaseTool(engine=engine)
|
||||
tools_with_advance_phase = list(base_tools) + [advance_phase_tool]
|
||||
return (engine, tools_with_advance_phase, None)
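The return contract documented above distinguishes three outcomes: an error, an engaged PLAN_EXEC engine, or a silent fallback to REACT. A minimal sketch of how a caller might branch on it follows. Everything in it is a hypothetical stand-in (`Mode`, `build_phase_engine_stub`, `route`, the string placeholders, and the invented `phases` config key); it does not use the real `ReActEngine`, `ExecutionMode`, or `policy_from_config`.

```python
from enum import Enum
from typing import List, Optional, Tuple


class Mode(Enum):
    """Hypothetical stand-in for ExecutionMode."""

    REACT = "react"
    PLAN_EXEC = "plan_exec"


def build_phase_engine_stub(
    mode: Mode, cfg: Optional[dict]
) -> Tuple[Optional[str], Optional[List[str]], Optional[str]]:
    """Mimic the (engine, tools, error) contract using plain strings."""
    if mode is not Mode.PLAN_EXEC:
        return (None, None, None)  # not requested: caller uses REACT
    cfg = cfg or {}
    if cfg.get("enabled", True) is False:
        return (None, None, None)  # explicit opt-out: caller uses REACT
    if "phases" in cfg and not isinstance(cfg["phases"], dict):
        # "phases" is an invented key standing in for any construction failure.
        return (None, None, "phase policy error: phases must be a mapping")
    return ("plan-exec-engine", ["write_file", "advance_phase"], None)


def route(mode: Mode, cfg: Optional[dict]) -> str:
    """Check the error first, then the engine, then fall back."""
    engine, tools, error = build_phase_engine_stub(mode, cfg)
    if error is not None:
        return f"error: {error}"
    if engine is not None:
        return f"plan_exec with {len(tools)} tools"
    return "react"
```

Checking `error` before `engine` mirrors the order used by both the REST and WebSocket callers in this diff, so a failed policy build is never mistaken for a quiet fallback.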
|
||||
|
||||
|
||||
# ── REST endpoints ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
|
@ -587,12 +655,58 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
|||
if session.status == SessionStatus.CLOSED:
|
||||
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
|
||||
|
||||
# KTD4: PLAN_EXEC is wired only at the WebSocket path. REST raises 501.
|
||||
# U3: PLAN_EXEC via REST — non-streaming, bypasses the fallback chain
|
||||
# (KTD5: PLAN_EXEC and execute_with_fallback_chain are mutually exclusive).
|
||||
# When plan_exec is disabled by config, falls through to the REACT path below.
|
||||
if request.execution_mode == "plan_exec":
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="PLAN_EXEC via REST not yet supported; use WebSocket",
|
||||
# Resolve the Agent early — PLAN_EXEC needs its tool list + system prompt.
|
||||
pool = req.app.state.agent_pool
|
||||
agent = pool.get_agent(session.agent_name)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found")
|
||||
|
||||
plan_exec_engine, plan_exec_tools, plan_exec_error = _build_phase_engine(
|
||||
server_config=getattr(req.app.state, "server_config", None),
|
||||
llm_gateway=req.app.state.llm_gateway,
|
||||
execution_mode=ExecutionMode.PLAN_EXEC,
|
||||
base_tools=agent._tool_registry.list_tools() if agent._tool_registry else [],
|
||||
session_id=session_id,
|
||||
)
|
||||
if plan_exec_error is not None:
|
||||
raise HTTPException(status_code=500, detail=plan_exec_error)
|
||||
if plan_exec_engine is not None:
|
||||
# PLAN_EXEC engaged — append user msg, execute non-streaming, return.
|
||||
await sm.append_message(
|
||||
session_id=session_id,
|
||||
role=MessageRole.USER,
|
||||
content=request.content,
|
||||
)
|
||||
chat_messages = await sm.get_chat_messages(session_id)
|
||||
system_prompt = getattr(agent, "_system_prompt", None) or (
|
||||
agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
|
||||
)
|
||||
try:
|
||||
plan_exec_result = await plan_exec_engine.execute(
|
||||
messages=chat_messages,
|
||||
tools=plan_exec_tools,
|
||||
model=agent.get_model()
|
||||
if hasattr(agent, "get_model")
|
||||
else getattr(agent, "_llm_model", "default"),
|
||||
agent_name=agent.name,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("PLAN_EXEC execution error for session %s: %s", session_id, e)
|
||||
raise HTTPException(status_code=500, detail=str(e)[:200]) from e
|
||||
|
||||
assistant_msg = await sm.append_message(
|
||||
session_id=session_id,
|
||||
role=MessageRole.ASSISTANT,
|
||||
content=plan_exec_result.output,
|
||||
agent_name=agent.name,
|
||||
)
|
||||
return _message_to_response(assistant_msg)
|
||||
# else: plan_exec.enabled=False → fall through to REACT path below.
|
||||
|
||||
# Append user message
|
||||
await sm.append_message(
|
||||
|
|
@ -1090,39 +1204,24 @@ async def _handle_chat_message(
|
|||
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
|
||||
return
|
||||
|
||||
# U4/G6: PLAN_EXEC — build PhasePolicy from server config (KTD4: WebSocket only).
|
||||
# U4/G6: PLAN_EXEC — build PhasePolicy from server config.
|
||||
# KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC — phase policy and
|
||||
# fallback chain are mutually exclusive. PLAN_EXEC uses its own engine.
|
||||
phase_policy: PhasePolicy | None = None
|
||||
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
|
||||
server_config = getattr(websocket.app.state, "server_config", None)
|
||||
plan_exec_cfg = getattr(server_config, "plan_exec", None) or {}
|
||||
|
||||
if plan_exec_cfg.get("enabled", True) is False:
|
||||
# Explicit opt-out → fall back to REACT.
|
||||
logger.info(
|
||||
"PLAN_EXEC disabled by config (plan_exec.enabled=False), "
|
||||
"falling back to REACT for session %s",
|
||||
session_id,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
phase_policy = policy_from_config(plan_exec_cfg)
|
||||
if phase_policy is None:
|
||||
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
|
||||
phase_policy = default_policy()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"PLAN_EXEC phase policy construction failed for session %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
# U3: logic extracted into _build_phase_engine so REST can reuse it.
|
||||
plan_exec_engine, plan_exec_tools, plan_exec_error = _build_phase_engine(
|
||||
server_config=getattr(websocket.app.state, "server_config", None),
|
||||
llm_gateway=websocket.app.state.llm_gateway,
|
||||
execution_mode=routing.execution_mode,
|
||||
base_tools=routing.tools,
|
||||
session_id=session_id,
|
||||
)
|
||||
if plan_exec_error is not None:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
# Truncate to 200 chars to match nearby error paths and
|
||||
# avoid leaking config internals (see chat.py:1090, 1320).
|
||||
"data": {"message": f"phase policy error: {str(e)[:200]}"},
|
||||
"data": {"message": plan_exec_error},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
|
@ -1143,14 +1242,9 @@ async def _handle_chat_message(
|
|||
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization).
|
||||
# PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the
|
||||
# agent's _react_engine — it has no policy).
|
||||
if phase_policy is not None:
|
||||
react_engine = ReActEngine(
|
||||
llm_gateway=websocket.app.state.llm_gateway,
|
||||
phase_policy=phase_policy,
|
||||
)
|
||||
# Register AdvancePhaseTool bound to this engine (LLM's escape hatch).
|
||||
advance_phase_tool = AdvancePhaseTool(engine=react_engine)
|
||||
routing.tools = list(routing.tools) + [advance_phase_tool]
|
||||
if plan_exec_engine is not None:
|
||||
react_engine = plan_exec_engine
|
||||
routing.tools = plan_exec_tools
|
||||
else:
|
||||
react_engine = getattr(agent, "_react_engine", None)
|
||||
if react_engine is None:
|
||||
|
|
@ -1280,6 +1374,17 @@ async def _handle_chat_message(
|
|||
"data": event.data,
|
||||
}
|
||||
)
|
||||
elif event.event_type == "phase_violation":
|
||||
# Wave 4 U2: forward phase violations to the client so the
|
||||
# frontend can surface them in the PhaseIndicator UI (alongside
|
||||
# the LLM reinjection that already happens via the tool_result
|
||||
# error dict).
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "phase_violation",
|
||||
"data": event.data,
|
||||
}
|
||||
)
|
||||
else:
|
||||
await websocket.send_json(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from typing import Any, Callable, Awaitable
|
|||
|
||||
from agentkit.tools.base import Tool
|
||||
from agentkit.tools.output_parser import OutputParser, ParsedOutput
|
||||
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager
|
||||
from agentkit.tools.terminal_session import TerminalSessionManager
|
||||
from agentkit.tools.pty_session import PTYSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -159,10 +159,22 @@ _SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
|||
# Dangerous command detection: exact token matching, to avoid substring false positives
|
||||
|
||||
# Binaries that are always dangerous (regardless of arguments)
|
||||
_DANGEROUS_BINARIES: frozenset[str] = frozenset({
|
||||
"rm", "rmdir", "mkfs", "dd", "format", "shutdown", "reboot",
|
||||
"halt", "killall", "chown", "fdisk", "parted",
|
||||
})
|
||||
_DANGEROUS_BINARIES: frozenset[str] = frozenset(
|
||||
{
|
||||
"rm",
|
||||
"rmdir",
|
||||
"mkfs",
|
||||
"dd",
|
||||
"format",
|
||||
"shutdown",
|
||||
"reboot",
|
||||
"halt",
|
||||
"killall",
|
||||
"chown",
|
||||
"fdisk",
|
||||
"parted",
|
||||
}
|
||||
)
|
||||
|
||||
# Binaries that are dangerous only with specific arguments: binary → set of dangerous flags/subcommands
|
||||
_DANGEROUS_BINARY_FLAGS: dict[str, set[str]] = {
|
||||
|
|
@ -183,13 +195,16 @@ _DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [
|
|||
re.compile(r"drop\s+database", re.IGNORECASE),
|
||||
re.compile(r"truncate\s+table", re.IGNORECASE),
|
||||
# curl/wget data exfiltration: POST/PUT/upload flags
|
||||
re.compile(r"\bcurl\b.*(-X\s*(POST|PUT|PATCH|DELETE)|--data|--data-binary|--data-raw|--data-urlencode|-d\s|--post\d)", re.IGNORECASE),
|
||||
re.compile(
|
||||
r"\bcurl\b.*(-X\s*(POST|PUT|PATCH|DELETE)|--data|--data-binary|--data-raw|--data-urlencode|-d\s|--post\d)",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
re.compile(r"\bwget\b.*(--post-data|--post-file)", re.IGNORECASE),
|
||||
]
|
||||
|
||||
|
||||
_SHELL_PIPE_OPERATORS = re.compile(r'\|')
|
||||
_SHELL_CHAIN_OPERATORS = re.compile(r'[;&]|\|\||&&|\$\(|\$\{|`|\$<|>|<|\n')
|
||||
_SHELL_PIPE_OPERATORS = re.compile(r"\|")
|
||||
_SHELL_CHAIN_OPERATORS = re.compile(r"[;&]|\|\||&&|\$\(|\$\{|`|\$<|>|<|\n")
|
||||
|
||||
|
||||
class ShellTool(Tool):
|
||||
|
|
@ -360,6 +375,7 @@ class ShellTool(Tool):
|
|||
# Ensure non-empty output for successful commands (all execution modes)
|
||||
if result.exit_code == 0 and not output.strip():
|
||||
from agentkit.core.fallback import SHELL_NO_OUTPUT
|
||||
|
||||
output = SHELL_NO_OUTPUT
|
||||
|
||||
return {
|
||||
|
|
@ -383,7 +399,6 @@ class ShellTool(Tool):
|
|||
if interactive:
|
||||
return await self._execute_with_pty(command, timeout, working_dir)
|
||||
|
||||
start = time.monotonic()
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
|
|
@ -430,9 +445,7 @@ class ShellTool(Tool):
|
|||
)
|
||||
|
||||
if interactive:
|
||||
return await self._execute_with_pty(
|
||||
command, timeout, session.cwd, session.env
|
||||
)
|
||||
return await self._execute_with_pty(command, timeout, session.cwd, session.env)
|
||||
|
||||
return await session.execute(command, timeout=timeout)
|
||||
|
||||
|
|
@ -463,11 +476,16 @@ class ShellTool(Tool):
|
|||
|
||||
return self._output_parser.parse(output, exit_code)
|
||||
|
||||
def _is_dangerous(self, command: str) -> bool:
|
||||
@staticmethod
|
||||
def _is_dangerous(command: str) -> bool:
|
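"""Check whether a command is a dangerous operation.

Whitelisted commands pass directly. Piped commands (|) pass only when every sub-command is safe.
All other chaining operators (;, &&, ||, $(), >, <, etc.) are always treated as dangerous.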
|
||||
|
||||
Static so callers without a ShellTool instance (e.g. PhasePolicy) can
|
||||
reuse the same danger classification. Instance calls still work via
|
||||
Python's descriptor protocol.
|
||||
"""
|
||||
command_stripped = command.strip()
|
||||
|
||||
|
|
@ -477,19 +495,20 @@ class ShellTool(Tool):
|
|||
|
||||
# Handle pipe commands: split and check each sub-command
|
||||
if _SHELL_PIPE_OPERATORS.search(command_stripped):
|
||||
parts = command_stripped.split('|')
|
||||
parts = command_stripped.split("|")
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
if self._is_single_command_dangerous(part):
|
||||
if ShellTool._is_single_command_dangerous(part):
|
||||
return True
|
||||
return False # All pipe segments are safe
|
||||
|
||||
# Single command
|
||||
return self._is_single_command_dangerous(command_stripped)
|
||||
return ShellTool._is_single_command_dangerous(command_stripped)
|
||||
|
||||
def _is_single_command_dangerous(self, command: str) -> bool:
|
||||
@staticmethod
|
||||
def _is_single_command_dangerous(command: str) -> bool:
|
||||
"""Check if a single command (no pipes/chains) is dangerous."""
|
||||
command_stripped = command.strip()
|
||||
if not command_stripped:
|
||||
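Making the danger check static means a policy object can hold it as a plain `Callable[[str], bool]` without constructing a shell tool. The sketch below illustrates the pipe-splitting idea under simplifying assumptions: `Checker`, `SAFE`, `DANGEROUS`, and `CHAIN` are invented for illustration and are far smaller than the real lists in the shell tool, and unknown binaries are treated as dangerous by default.

```python
import re
import shlex
from typing import Callable

SAFE = frozenset({"ls", "cat", "grep", "wc"})  # illustrative allowlist
DANGEROUS = frozenset({"rm", "dd", "mkfs"})    # illustrative denylist
# Chaining and redirection operators other than a single pipe.
CHAIN = re.compile(r"[;&]|\|\||\$\(|`|>|<|\n")


class Checker:
    @staticmethod
    def is_dangerous(command: str) -> bool:
        """Pipes pass only if every segment is safe; other chaining is dangerous."""
        stripped = command.strip()
        if CHAIN.search(stripped):
            return True
        segments = [s.strip() for s in stripped.split("|") if s.strip()]
        return any(Checker._segment_dangerous(s) for s in segments)

    @staticmethod
    def _segment_dangerous(segment: str) -> bool:
        try:
            tokens = shlex.split(segment)
        except ValueError:
            return True  # unbalanced quotes: refuse rather than guess
        if not tokens:
            return False
        binary = tokens[0]
        # Anything outside the allowlist counts as dangerous.
        return binary in DANGEROUS or binary not in SAFE


# The static method can be passed around like any other function.
check: Callable[[str], bool] = Checker.is_dangerous
```

Because the method is static, both `Checker.is_dangerous("ls")` and `Checker().is_dangerous("ls")` work, which is the property the docstring in this hunk relies on for existing instance callers.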
|
|
|
|||
|
|
@ -0,0 +1,491 @@
|
|||
"""E2E integration test for PLAN_EXEC lifecycle (Wave 4 U5, R34).
|
||||
|
||||
Exercises the full PLAN_EXEC path through a scripted LLM mock:
|
||||
PLANNING (search) → advance_phase → BUILDING (write_file) →
|
||||
advance_phase → VERIFICATION (shell pytest) → advance_phase →
|
||||
DELIVERY (final answer)
|
||||
|
||||
Also covers the negative path (write_file blocked in PLANNING), an
|
||||
edge case (auto_advance_after_steps safety net), and the error path
|
||||
(LLM raises mid-stream).
|
||||
|
||||
Mock boundary: ``LLMGateway.chat_stream`` — yields scripted ``StreamChunk``
|
||||
objects deterministically. Real ``ReActEngine``, real ``PhasePolicy``
|
||||
(``default_policy()``), real ``AdvancePhaseTool``, real WS handler in
|
||||
``chat._handle_chat_message``. No real LLM API call is made.
|
||||
|
||||
Patterns followed:
|
||||
- ``tests/unit/test_react_token_streaming.py`` — scripted gateway pattern.
|
||||
- ``tests/unit/test_chat_plan_exec_ws.py`` — WS handler test fixture pattern.
|
||||
- ``tests/integration/test_react_loop.py`` — stub-tool + LLMResponse pattern.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.phase import PhaseState, default_policy, policy_from_config
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import StreamChunk, TokenUsage, ToolCall
|
||||
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
# Integration marker matches the rest of tests/integration/. This test does
|
||||
# NOT require docker (LLM is mocked) — the marker is for filtering only.
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — scripted LLM gateway + stub tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StubTool(Tool):
|
||||
"""A minimal tool that records its invocations and returns a fixed result."""
|
||||
|
||||
def __init__(self, name: str, result: dict[str, Any] | None = None) -> None:
|
||||
super().__init__(name=name, description=f"Stub {name}")
|
||||
self._result = result or {"ok": True, "tool": name}
|
||||
self.call_count: int = 0
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
|
||||
async def execute(self, **kwargs) -> dict[str, Any]:
|
||||
self.call_count += 1
|
||||
self.calls.append(kwargs)
|
||||
return self._result
|
||||
|
||||
|
||||
def _tool_call_chunk(
|
||||
tool_call: ToolCall, *, model: str = "test-model"
|
||||
) -> StreamChunk:
|
||||
"""A final-chunk carrying exactly one tool call (no content).
|
||||
|
||||
Engine reads ``chunk.tool_calls`` with REPLACE semantics (not extend)
|
||||
at react.py:1369 — so a single final chunk must hold the full list.
|
||||
"""
|
||||
return StreamChunk(
|
||||
content="",
|
||||
model=model,
|
||||
tool_calls=[tool_call],
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=5),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
|
||||
def _final_answer_chunk(text: str, *, model: str = "test-model") -> StreamChunk:
|
||||
"""A final-chunk carrying the final answer text."""
|
||||
return StreamChunk(
|
||||
content=text,
|
||||
model=model,
|
||||
usage=TokenUsage(prompt_tokens=20, completion_tokens=30),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
|
||||
def _make_scripted_gateway(script: list[list[StreamChunk]]) -> MagicMock:
|
||||
"""Create a mock LLMGateway whose ``chat_stream`` pops one scripted step.
|
||||
|
||||
Each ``chat_stream()`` invocation yields the next inner list of
|
||||
``StreamChunk`` objects. Raises ``IndexError`` if the script is
|
||||
exhausted (test fixture misconfiguration — fail loud, not silent).
|
||||
"""
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
_state = {"idx": 0}
|
||||
|
||||
async def _stream(**kwargs: Any) -> Any:
|
||||
i = _state["idx"]
|
||||
if i >= len(script):
|
||||
raise IndexError(
|
||||
f"Scripted gateway exhausted: call {i + 1} but only "
|
||||
f"{len(script)} steps scripted"
|
||||
)
|
||||
_state["idx"] += 1
|
||||
for chunk in script[i]:
|
||||
yield chunk
|
||||
|
||||
gateway.chat_stream = _stream
|
||||
gateway.get_provider_name_for_model = MagicMock(return_value=None)
|
||||
return gateway
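The scripted-gateway helper above can be reduced to a dependency-free sketch that shows the core mechanism: each call to the stream function consumes one scripted step and yields its chunks, and an exhausted script fails loudly. The names `make_scripted_stream` and `collect` are hypothetical, and plain strings stand in for `StreamChunk` objects.

```python
import asyncio
from typing import Any, AsyncIterator, Callable, List


def make_scripted_stream(script: List[List[str]]) -> Callable[..., AsyncIterator[str]]:
    """Return an async-generator function that replays one scripted step per call."""
    state = {"idx": 0}

    async def stream(**kwargs: Any) -> AsyncIterator[str]:
        i = state["idx"]
        if i >= len(script):
            raise IndexError(
                f"script exhausted: call {i + 1} but only {len(script)} steps"
            )
        state["idx"] += 1
        for chunk in script[i]:
            yield chunk

    return stream


async def collect(stream: Callable[..., AsyncIterator[str]]) -> List[str]:
    """Drain one call of the stream into a list."""
    return [chunk async for chunk in stream()]
```

Assigning a real async-generator function (rather than an `AsyncMock`) keeps `async for` semantics intact, which is presumably why the helper replaces `chat_stream` directly on the mock.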
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy path — full PLAN_EXEC lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPlanExecE2EHappyPath:
|
||||
"""PLANNING → BUILDING → VERIFICATION → DELIVERY via advance_phase."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_lifecycle_emits_expected_events(self) -> None:
|
||||
# 7-step script: search → advance → write_file → advance →
|
||||
# shell(pytest) → advance → final answer
|
||||
script: list[list[StreamChunk]] = [
|
||||
# Step 1 (PLANNING): `search` is in PLANNING whitelist → dispatched.
|
||||
[_tool_call_chunk(ToolCall(id="tc1", name="search", arguments={"query": "docs"}))],
|
||||
# Step 2 (PLANNING): `advance_phase` bypasses phase check → dispatched.
|
||||
[_tool_call_chunk(ToolCall(id="tc2", name="advance_phase", arguments={}))],
|
||||
# Step 3 (BUILDING): `write_file` allowed → dispatched.
|
||||
[_tool_call_chunk(
|
||||
ToolCall(id="tc3", name="write_file", arguments={"path": "/tmp/x", "content": "hi"})
|
||||
)],
|
||||
# Step 4 (BUILDING): advance_phase → transitions to VERIFICATION.
|
||||
[_tool_call_chunk(ToolCall(id="tc4", name="advance_phase", arguments={}))],
|
||||
# Step 5 (VERIFICATION): `shell` with `ls tests/unit/` — read-only,
|
||||
# passes ShellTool._is_dangerous (ls is whitelisted). The plan's
|
||||
# example used `pytest tests/unit/ -q`, but pytest is not in
|
||||
# _SAFE_COMMAND_PREFIXES → flagged dangerous by default. Verifying
|
||||
# the test files exist via `ls` is the realistic VERIFICATION-phase
|
||||
# shell call that the default policy actually allows.
|
||||
[_tool_call_chunk(
|
||||
ToolCall(id="tc5", name="shell", arguments={"command": "ls tests/unit/"})
|
||||
)],
|
||||
# Step 6 (VERIFICATION): advance_phase → transitions to DELIVERY.
|
||||
[_tool_call_chunk(ToolCall(id="tc6", name="advance_phase", arguments={}))],
|
||||
# Step 7 (DELIVERY): final answer text (no tool_calls) → loop exits.
|
||||
[_final_answer_chunk("Delivered: hello world")],
|
||||
]
|
||||
gateway = _make_scripted_gateway(script)
|
||||
|
||||
engine = ReActEngine(llm_gateway=gateway, phase_policy=default_policy())
|
||||
# ponytail: bump loop threshold so 3 legitimate `advance_phase` calls
|
||||
# (always `{}` args) don't trigger the loop detector. PLAN_EXEC's
|
||||
# 4-phase lifecycle needs 3 transitions; default threshold=2 fires on
|
||||
# the 2nd identical call. This is a known PLAN_EXEC production concern
|
||||
# tracked separately — U5 only validates the lifecycle end-to-end.
|
||||
engine._loop_threshold = 99 # noqa: SLF001
|
||||
|
||||
# Real stub tools + real AdvancePhaseTool (bound to engine).
|
||||
search = _StubTool("search", {"results": ["doc1", "doc2"]})
|
||||
write_file = _StubTool("write_file", {"bytes_written": 2})
|
||||
shell = _StubTool("shell", {"exit_code": 0, "stdout": "all tests passed"})
|
||||
advance = AdvancePhaseTool(engine=engine)
|
||||
tools: list[Tool] = [search, write_file, shell, advance]
|
||||
|
||||
events = []
|
||||
async for ev in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "Build and verify hello world"}],
|
||||
tools=tools,
|
||||
):
|
||||
events.append(ev)
|
||||
|
||||
# Final answer emitted exactly once.
|
||||
finals = [e for e in events if e.event_type == "final_answer"]
|
||||
assert len(finals) == 1
|
||||
assert "Delivered" in finals[0].data["output"]
|
||||
|
||||
# Tool dispatch counts: search=1, write_file=1, shell=1, advance=3.
|
||||
assert search.call_count == 1
|
||||
assert write_file.call_count == 1
|
||||
assert shell.call_count == 1
|
||||
# advance_phase is a real AdvancePhaseTool (not a _StubTool) — count
|
||||
# via tool_call events in the event stream. 3 calls = 3 transitions
|
||||
# (PLANNING → BUILDING → VERIFICATION → DELIVERY).
|
||||
advance_calls = [
|
||||
e for e in events
|
||||
if e.event_type == "tool_call" and e.data.get("tool_name") == "advance_phase"
|
||||
]
|
||||
assert len(advance_calls) == 3
|
||||
|
||||
# No phase_violation events in happy path.
|
||||
violations = [e for e in events if e.event_type == "phase_violation"]
|
||||
assert len(violations) == 0
|
||||
|
||||
# Engine ended at DELIVERY.
|
||||
assert engine.current_phase == PhaseState.DELIVERY
|
||||
|
||||
# tool_call / tool_result event counts match dispatched tools (6 of each).
|
||||
tool_calls = [e for e in events if e.event_type == "tool_call"]
|
||||
tool_results = [e for e in events if e.event_type == "tool_result"]
|
||||
assert len(tool_calls) == 6
|
||||
assert len(tool_results) == 6
|
||||
|
||||
# Ordering: tool_call must precede its matching tool_result.
|
||||
first_tc_idx = next(
|
||||
i for i, e in enumerate(events) if e.event_type == "tool_call"
|
||||
)
|
||||
first_tr_idx = next(
|
||||
i for i, e in enumerate(events) if e.event_type == "tool_result"
|
||||
)
|
||||
assert first_tc_idx < first_tr_idx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Negative path — violation then recovery via advance_phase
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPlanExecE2ENegativePath:
|
||||
"""Out-of-phase tool blocked → LLM calls advance_phase → tool succeeds."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_violation_then_recovery(self) -> None:
|
||||
script: list[list[StreamChunk]] = [
|
||||
# Step 1 (PLANNING): `write_file` NOT in PLANNING whitelist → blocked.
|
||||
[_tool_call_chunk(
|
||||
ToolCall(id="tc1", name="write_file", arguments={"path": "/x", "content": "y"})
|
||||
)],
|
||||
# Step 2 (PLANNING): LLM reacts to violation by calling advance_phase.
|
||||
[_tool_call_chunk(ToolCall(id="tc2", name="advance_phase", arguments={}))],
|
||||
# Step 3 (BUILDING): `write_file` now allowed → dispatched.
|
||||
[_tool_call_chunk(
|
||||
ToolCall(id="tc3", name="write_file", arguments={"path": "/x", "content": "y"})
|
||||
)],
|
||||
# Step 4: final answer.
|
||||
[_final_answer_chunk("Recovered and built")],
|
||||
]
|
||||
gateway = _make_scripted_gateway(script)
|
||||
|
||||
engine = ReActEngine(llm_gateway=gateway, phase_policy=default_policy())
|
||||
# See happy path test for the loop threshold rationale.
|
||||
engine._loop_threshold = 99 # noqa: SLF001
|
||||
write_file = _StubTool("write_file", {"bytes_written": 1})
|
||||
advance = AdvancePhaseTool(engine=engine)
|
||||
tools: list[Tool] = [write_file, advance]
|
||||
|
||||
events = []
|
||||
async for ev in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "Build x"}],
|
||||
tools=tools,
|
||||
):
|
||||
events.append(ev)
|
||||
|
||||
# Exactly one phase_violation event — from step 1.
|
||||
violations = [e for e in events if e.event_type == "phase_violation"]
|
||||
assert len(violations) == 1
|
||||
v = violations[0].data
|
||||
assert v["tool"] == "write_file"
|
||||
assert v["current_phase"] == "planning"
|
||||
assert v["violation_kind"] == "tool_not_allowed"
|
||||
assert "advance_phase" in v["message"]
|
||||
|
||||
# write_file dispatched exactly once (during BUILDING, NOT during PLANNING).
|
||||
assert write_file.call_count == 1
|
||||
|
||||
# Engine ended at BUILDING (advance_phase was called once).
|
||||
assert engine.current_phase == PhaseState.BUILDING
|
||||
|
||||
# Final answer emitted despite the violation.
|
||||
finals = [e for e in events if e.event_type == "final_answer"]
|
||||
assert len(finals) == 1
|
||||
assert "Recovered" in finals[0].data["output"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases — auto-advance safety net + plan_exec.enabled=False
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPlanExecE2EEdgeCases:
|
||||
"""auto_advance_after_steps triggers transition without explicit advance_phase,
|
||||
and policy_from_config(enabled=False) returns None (PLAN_EXEC disabled)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_advance_after_two_steps(self) -> None:
|
||||
"""With auto_advance_after_steps=2, after 2 LLM calls in PLANNING
|
||||
the engine auto-advances to BUILDING — even without an explicit
|
||||
advance_phase tool call."""
|
||||
# Custom policy: auto-advance after 2 steps per phase.
|
||||
policy = default_policy()
|
||||
# ponytail: dataclass(slots=True) — use __setattr__ via object.__setattr__
|
||||
# or rebuild via dataclasses.replace. Replace is the clean path.
|
||||
from dataclasses import replace
|
||||
|
||||
policy = replace(policy, auto_advance_after_steps=2)
|
||||
|
||||
# Script: LLM calls `search` 3 times then final answer.
|
||||
# Expected: step 1 (PLANNING, search), step 2 (PLANNING, search) →
|
||||
# auto-advance fires after step 2 → step 3 (BUILDING, search still
|
||||
# allowed), then final answer.
|
||||
script: list[list[StreamChunk]] = [
|
||||
[_tool_call_chunk(ToolCall(id="tc1", name="search", arguments={"query": "a"}))],
|
||||
[_tool_call_chunk(ToolCall(id="tc2", name="search", arguments={"query": "b"}))],
|
||||
[_tool_call_chunk(ToolCall(id="tc3", name="search", arguments={"query": "c"}))],
|
||||
[_final_answer_chunk("Done after auto-advance")],
|
||||
]
|
||||
gateway = _make_scripted_gateway(script)
|
||||
engine = ReActEngine(llm_gateway=gateway, phase_policy=policy)
|
||||
# See happy path test for the loop threshold rationale.
|
||||
engine._loop_threshold = 99 # noqa: SLF001
|
||||
search = _StubTool("search", {"results": []})
|
||||
tools: list[Tool] = [search]
|
||||
|
||||
events = []
|
||||
async for ev in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "Search stuff"}],
|
||||
tools=tools,
|
||||
):
|
||||
events.append(ev)
|
||||
|
||||
# Engine should have transitioned out of PLANNING (auto-advance fired).
|
||||
# Weak assertion: auto_advance_after_steps=2 may fire multiple times
|
||||
# (PLANNING→BUILDING→VERIFICATION), so we only assert it left PLANNING.
|
||||
assert engine.current_phase != PhaseState.PLANNING
|
||||
# All 3 search calls dispatched (search is allowed in both PLANNING and BUILDING).
|
||||
assert search.call_count == 3
|
||||
# Final answer emitted.
|
||||
finals = [e for e in events if e.event_type == "final_answer"]
|
||||
assert len(finals) == 1
|
||||
|
||||
def test_policy_from_config_returns_none_when_disabled(self) -> None:
|
"""Edge: plan_exec.enabled=False → policy_from_config returns None.
_build_phase_engine checks ``enabled`` itself before building a policy,
so the request falls back to REACT."""
|
||||
result = policy_from_config({"enabled": False})
|
||||
assert result is None
|
||||
|
||||
def test_policy_from_config_returns_none_when_section_absent(self) -> None:
|
||||
"""Edge: empty plan_exec config → policy_from_config returns None
|
||||
(opt-out), so _build_phase_engine falls back to default_policy()."""
|
||||
result = policy_from_config({})
|
||||
assert result is None
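The auto-advance test in this class derives a tuned policy with `dataclasses.replace` instead of mutating the default one. A small sketch of why that is the clean path follows. `DemoPolicy` is a hypothetical stand-in, and marking it `frozen=True` is an assumption made for illustration; the diff does not show how the real `PhasePolicy` is declared.

```python
from dataclasses import FrozenInstanceError, dataclass, replace
from typing import FrozenSet, Optional


@dataclass(frozen=True)
class DemoPolicy:
    """Hypothetical stand-in for PhasePolicy, not the real class."""

    allowed: FrozenSet[str]
    auto_advance_after_steps: Optional[int] = None


base = DemoPolicy(allowed=frozenset({"search"}))

# replace() builds a new instance, copying every field not overridden.
tuned = replace(base, auto_advance_after_steps=2)
```

The original object is left untouched, so a shared default policy cannot be altered by one test and leak into another.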
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error path — LLM raises mid-stream
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPlanExecE2EErrorPath:
|
||||
"""LLM call failure propagates; phase state is left untouched."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_raises_propagates_and_phase_unchanged(self) -> None:
|
||||
"""If chat_stream raises, the exception propagates out of execute_stream
|
||||
and the engine's phase state remains at its starting phase."""
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
|
||||
async def _stream_raises(**kwargs: Any) -> Any:
|
||||
raise RuntimeError("LLM service down")
|
||||
yield # pragma: no cover — async generator marker
|
||||
|
||||
gateway.chat_stream = _stream_raises
|
||||
gateway.get_provider_name_for_model = MagicMock(return_value=None)
|
||||
|
||||
engine = ReActEngine(llm_gateway=gateway, phase_policy=default_policy())
|
||||
search = _StubTool("search")
|
||||
tools: list[Tool] = [search]
|
||||
|
||||
with pytest.raises(RuntimeError, match="LLM service down"):
|
||||
async for _ in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=tools,
|
||||
):
|
||||
pass
|
||||
|
||||
# Phase state unchanged — no transition was triggered.
|
||||
assert engine.current_phase == PhaseState.PLANNING
|
||||
# Tool was never dispatched.
|
||||
assert search.call_count == 0
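The error-path test relies on a small Python idiom: a function body containing `yield` is an async generator even when the `yield` is unreachable, so the exception surfaces on the first iteration rather than at call time. A standalone sketch, with `failing_stream` and `consume` as hypothetical names:

```python
import asyncio
import inspect
from typing import Any, AsyncIterator


async def failing_stream(**kwargs: Any) -> AsyncIterator[str]:
    raise RuntimeError("LLM service down")
    yield  # unreachable, but its presence makes this an async generator


async def consume() -> str:
    """Iterate the stream and report what happened."""
    try:
        async for _ in failing_stream():
            pass
    except RuntimeError as exc:
        return f"caught: {exc}"
    return "no error"
```

Without the trailing `yield`, the function would be a plain coroutine and `async for` would fail with a `TypeError` instead of the intended `RuntimeError`.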
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WS handler integration — phase_changed events emitted to client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPlanExecE2EWSHandler:
|
||||
"""Full WS path: _handle_chat_message emits phase_changed + phase_violation
|
||||
events to the client WebSocket as the engine transitions phases."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ws_handler_emits_phase_changed_and_violation(self) -> None:
|
||||
from fastapi import FastAPI
|
||||
|
||||
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||||
from agentkit.server.routes import chat as chat_module
|
||||
from agentkit.server.routes.chat import router
|
||||
from agentkit.session.manager import SessionManager
|
||||
from agentkit.session.store import InMemorySessionStore
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
app.state.session_manager = SessionManager(store=InMemorySessionStore())
|
||||
app.state.agent_pool = MagicMock()
|
||||
app.state.server_config = MagicMock()
|
||||
app.state.server_config.api_key = None
|
||||
app.state.server_config.plan_exec = {}
|
||||
|
||||
# Scripted gateway: write_file in PLANNING (blocked) → advance_phase →
|
||||
# write_file in BUILDING (allowed) → final answer.
|
||||
script: list[list[StreamChunk]] = [
|
||||
[_tool_call_chunk(
|
||||
ToolCall(id="tc1", name="write_file", arguments={"path": "/x", "content": "y"})
|
||||
)],
|
||||
[_tool_call_chunk(ToolCall(id="tc2", name="advance_phase", arguments={}))],
|
||||
[_tool_call_chunk(
|
||||
ToolCall(id="tc3", name="write_file", arguments={"path": "/x", "content": "y"})
|
||||
)],
|
||||
[_final_answer_chunk("Done via WS")],
|
||||
]
|
||||
gateway = _make_scripted_gateway(script)
|
||||
app.state.llm_gateway = gateway
|
||||
|
||||
# Agent mock: exposes only the base tools (no AdvancePhaseTool).
|
||||
# _build_phase_engine appends the real AdvancePhaseTool bound to the
|
||||
# real ReActEngine, so we only need to provide the base tools here.
|
||||
write_file = _StubTool("write_file", {"bytes_written": 1})
|
||||
agent = MagicMock()
|
||||
agent.name = "test-agent"
|
||||
agent._tool_registry = MagicMock()
|
||||
agent._tool_registry.list_tools.return_value = [write_file]
|
||||
agent._system_prompt = None
|
||||
agent._react_engine = None
|
||||
agent.get_model.return_value = "default"
|
||||
app.state.agent_pool.get_agent.return_value = agent
|
||||
|
||||
routing = SkillRoutingResult(
|
||||
execution_mode=ExecutionMode.PLAN_EXEC,
|
||||
tools=[write_file],
|
||||
clean_content="build x",
|
||||
model="default",
|
||||
agent_name="test-agent",
|
||||
system_prompt=None,
|
||||
skill_name=None,
|
||||
)
|
||||
app.state.request_preprocessor = MagicMock()
|
||||
app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing)
|
||||
|
||||
sm = app.state.session_manager
|
||||
# Pre-create the session so get_session succeeds (create_session
|
||||
# generates the session_id internally and returns the Session).
|
||||
session = await sm.create_session(agent_name="test-agent")
|
||||
session_id = session.session_id
|
||||
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "build x"}])
|
||||
|
||||
ws = MagicMock()
|
||||
ws.app = app
|
||||
ws.send_json = AsyncMock()
|
||||
|
||||
await chat_module._handle_chat_message(
|
||||
websocket=ws,
|
||||
session_id=session_id,
|
||||
content="build x",
|
||||
sm=sm,
|
||||
cancellation_token=MagicMock(),
|
||||
pending_replies={},
|
||||
pending_confirmations=None,
|
||||
)
|
||||
|
||||
sent = [call.args[0] for call in ws.send_json.call_args_list]
|
||||
|
||||
# phase_violation forwarded exactly once (from step 1: write_file in PLANNING).
|
||||
violations = [m for m in sent if m.get("type") == "phase_violation"]
|
||||
assert len(violations) == 1
|
||||
assert violations[0]["data"]["tool"] == "write_file"
|
||||
assert violations[0]["data"]["current_phase"] == "planning"
|
||||
|
||||
# phase_changed forwarded at least once (PLANNING → BUILDING transition).
|
||||
changed = [m for m in sent if m.get("type") == "phase_changed"]
|
||||
assert len(changed) >= 1
|
||||
first_change = changed[0]["data"]
|
||||
assert first_change["phase"] == "building"
|
||||
assert first_change["previous"] == "planning"
|
||||
|
||||
# final_answer emitted.
|
||||
finals = [m for m in sent if m.get("type") == "final_answer"]
|
||||
assert len(finals) == 1
|
||||
assert "Done via WS" in finals[0]["content"]
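
The scripted scenario in this test (a `write_file` call rejected in PLANNING, an `advance_phase` call, then the same `write_file` accepted in BUILDING) can be reproduced without the server stack. The following is a minimal sketch, not the project's engine: `ToyPhaseEngine`, `TOY_WHITELIST`, and `run_script` are hypothetical names, and the whitelist contents are assumptions made for illustration.

```python
from enum import Enum


class Phase(Enum):
    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"


WILDCARD = "*"

# Assumed whitelist for illustration only; the real defaults may differ.
TOY_WHITELIST = {
    Phase.PLANNING: frozenset({"search", "read_file"}),
    Phase.BUILDING: frozenset({"search", "read_file", "write_file"}),
    Phase.VERIFICATION: frozenset({"search", "read_file"}),
    Phase.DELIVERY: frozenset({WILDCARD}),
}
PHASE_ORDER = list(Phase)


class ToyPhaseEngine:
    """Hypothetical stand-in that gates tool calls by the current phase."""

    def __init__(self) -> None:
        self.phase = Phase.PLANNING
        self.messages = []  # envelopes shaped like the WS messages in the test

    def call(self, tool: str) -> None:
        if tool == "advance_phase":
            previous = self.phase
            # Move to the next phase; stay put once DELIVERY is reached.
            index = min(PHASE_ORDER.index(previous) + 1, len(PHASE_ORDER) - 1)
            self.phase = PHASE_ORDER[index]
            self.messages.append({
                "type": "phase_changed",
                "data": {"phase": self.phase.value, "previous": previous.value},
            })
            return
        allowed = TOY_WHITELIST[self.phase]
        if WILDCARD not in allowed and tool not in allowed:
            # Blocked: report the violation instead of running the tool.
            self.messages.append({
                "type": "phase_violation",
                "data": {"tool": tool, "current_phase": self.phase.value},
            })
            return
        self.messages.append({"type": "tool_result", "data": {"tool": tool}})


def run_script(tools):
    """Feed a list of tool names through a fresh engine and return its messages."""
    engine = ToyPhaseEngine()
    for tool in tools:
        engine.call(tool)
    return engine.messages


messages = run_script(["write_file", "advance_phase", "write_file"])
print([m["type"] for m in messages])
# ['phase_violation', 'phase_changed', 'tool_result']
```

The sketch collapses the LLM gateway into a plain list of tool names, which is enough to show why the test expects exactly one violation and a `planning` to `building` transition.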

@ -1,10 +1,12 @@
"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4).
"""Unit tests for PLAN_EXEC wiring at chat.py REST + WebSocket paths (G6, U3, U4).

Per plan U4 Execution note: characterization-first — verify that existing
REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning
(no regression). Then add PLAN_EXEC wiring tests.

KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501.
U3: PLAN_EXEC is now wired at both REST and WebSocket paths. REST returns
a non-streaming MessageResponse; WS streams phase_violation events alongside
the LLM reinjection. KTD5: PLAN_EXEC bypasses the fallback chain.
"""

from __future__ import annotations
@ -109,13 +111,60 @@ def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None:


# ---------------------------------------------------------------------------
# REST — PLAN_EXEC raises 501 (KTD4)
# REST — PLAN_EXEC wired (U3, replaces former 501 path)
# ---------------------------------------------------------------------------


class TestRestPlanExec501:
    def test_rest_plan_exec_returns_501(self, client):
        """REST send_message with execution_mode=plan_exec → 501."""
class TestRestPlanExec:
    """U3: REST send_message with execution_mode=plan_exec now executes
    PLAN_EXEC (non-streaming) instead of raising 501."""

    def test_rest_plan_exec_returns_assistant_message(self, app_with_chat, monkeypatch):
        """REST PLAN_EXEC happy path → 200 with assistant message."""
        from agentkit.server.routes import chat as chat_module

        # Patch ReActEngine with a stub whose execute() returns a ReActResult-like.
        class _StubResult:
            output = "PLAN_EXEC completed"
            status = "success"

        class _StubEngine:
            def __init__(self, **kwargs):
                self._phase_policy = kwargs.get("phase_policy")
                self._current_phase = (
                    kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
                )

            async def execute(self, **kwargs):
                return _StubResult()

        monkeypatch.setattr(chat_module, "ReActEngine", _StubEngine)

        # Wire agent_pool with a mock agent that has _tool_registry.
        agent = _make_agent_mock()
        app_with_chat.state.agent_pool.get_agent.return_value = agent

        client = TestClient(app_with_chat)
        create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
        session_id = create_resp.json()["session_id"]

        msg_resp = client.post(
            f"/api/v1/chat/sessions/{session_id}/messages",
            json={"content": "Build me a hello world", "execution_mode": "plan_exec"},
        )
        assert msg_resp.status_code == 200
        body = msg_resp.json()
        assert body["content"] == "PLAN_EXEC completed"
        assert body["role"] == "assistant"

    def test_rest_plan_exec_bad_config_returns_500(self, app_with_chat):
        """REST PLAN_EXEC with invalid phase config → 500 with error detail."""
        app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"}

        agent = _make_agent_mock()
        app_with_chat.state.agent_pool.get_agent.return_value = agent

        client = TestClient(app_with_chat)
        create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
        session_id = create_resp.json()["session_id"]

@ -123,20 +172,72 @@ class TestRestPlanExec501:
            f"/api/v1/chat/sessions/{session_id}/messages",
            json={"content": "Hello", "execution_mode": "plan_exec"},
        )
        assert msg_resp.status_code == 501
        assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"]
        assert msg_resp.status_code == 500
        assert "phase policy error" in msg_resp.json()["detail"]

    def test_rest_react_mode_still_works(self, client):
        """REST send_message without execution_mode doesn't 501."""
    def test_rest_plan_exec_disabled_falls_through_to_react(self, app_with_chat, monkeypatch):
        """REST PLAN_EXEC with enabled=False → falls through to REACT path."""
        from agentkit.server.routes import chat as chat_module

        app_with_chat.state.server_config.plan_exec = {"enabled": False}

        # Track which engine constructor fires.
        constructed: list = []

        class _StubResult:
            output = "REACT fallback ok"
            status = "success"

        class _StubEngine:
            def __init__(self, **kwargs):
                constructed.append(kwargs)
                self._phase_policy = kwargs.get("phase_policy")

            async def execute(self, **kwargs):
                return _StubResult()

        monkeypatch.setattr(chat_module, "ReActEngine", _StubEngine)
        # execute_with_fallback_chain also constructs ReflexionEngine internally;
        # patch it to return a ChatExecutionResult-like directly.
        from agentkit.server._fallback_chain import ChatExecutionResult

        async def _stub_chain(**kwargs):
            return ChatExecutionResult(output="REACT fallback ok", status="success")

        monkeypatch.setattr(chat_module, "execute_with_fallback_chain", _stub_chain)

        agent = _make_agent_mock()
        app_with_chat.state.agent_pool.get_agent.return_value = agent

        client = TestClient(app_with_chat)
        create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
        session_id = create_resp.json()["session_id"]

        # No execution_mode field → should NOT trigger 501.
        msg_resp = client.post(
            f"/api/v1/chat/sessions/{session_id}/messages",
            json={"content": "Hello", "execution_mode": "plan_exec"},
        )
        assert msg_resp.status_code == 200
        assert msg_resp.json()["content"] == "REACT fallback ok"
        # No engine should have been constructed with phase_policy — PLAN_EXEC
        # was disabled and the REACT path doesn't set phase_policy.
        assert all(kw.get("phase_policy") is None for kw in constructed)

    def test_rest_react_mode_still_works(self, client):
        """REST send_message without execution_mode doesn't 500."""
        create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
        session_id = create_resp.json()["session_id"]

        # No execution_mode field → should NOT trigger PLAN_EXEC path.
        # Will likely 500 due to mock llm_gateway, but must NOT be a PLAN_EXEC error.
        msg_resp = client.post(
            f"/api/v1/chat/sessions/{session_id}/messages",
            json={"content": "Hello"},
        )
        assert msg_resp.status_code != 501
        # 500 is acceptable (mock gateway), but it must NOT be the PLAN_EXEC error.
        assert msg_resp.status_code != 501, "REACT fallback should not return 501"
        if msg_resp.status_code == 500:
            assert "phase policy error" not in msg_resp.json().get("detail", "")


# ---------------------------------------------------------------------------
@ -529,3 +630,265 @@ async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat):
    sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
    phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
    assert len(phase_events) == 0


# ---------------------------------------------------------------------------
# Wave 4 U2 — phase_violation event forwarding
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_phase_violation_event_forwarded_to_client(app_with_chat):
    """When ReActEngine yields a phase_violation ReActEvent, chat.py WS handler
    must forward it as a `{"type": "phase_violation", "data": ...}` WS message
    so the frontend PhaseIndicator can react."""
    from agentkit.server.routes import chat as chat_module

    app_with_chat.state.server_config.plan_exec = {}

    agent = _make_agent_mock()
    routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
    _setup_routing(app_with_chat, routing, agent)

    sm = _make_session_manager_mock()
    sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}])
    ws = _make_websocket_mock(app_with_chat)

    class _StubEngine:
        def __init__(self, **kwargs):
            self._phase_policy = kwargs.get("phase_policy")
            self._current_phase = PhaseState.PLANNING

        @property
        def current_phase(self):
            return self._current_phase

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            from agentkit.core.react import ReActEvent

            # Simulate: tool_call → tool_result (blocked) → phase_violation
            yield ReActEvent(
                event_type="tool_call",
                step=1,
                data={"tool_name": "write_file", "arguments": {"path": "/x"}},
            )
            yield ReActEvent(
                event_type="tool_result",
                step=1,
                data={
                    "tool_name": "write_file",
                    "result": {"error": "phase_violation", "is_error": True},
                },
            )
            yield ReActEvent(
                event_type="phase_violation",
                step=1,
                data={
                    "error": "phase_violation",
                    "message": "Tool 'write_file' not allowed in planning phase.",
                    "current_phase": "planning",
                    "tool": "write_file",
                    "is_error": True,
                    "violation_kind": "tool_not_allowed",
                },
            )
            yield ReActEvent(
                event_type="final_answer",
                step=2,
                data={"output": "done"},
            )

    with pytest.MonkeyPatch().context() as mp:
        mp.setattr(chat_module, "ReActEngine", _StubEngine)

        await chat_module._handle_chat_message(
            websocket=ws,
            session_id="test-session",
            content="go",
            sm=sm,
            cancellation_token=MagicMock(),
            pending_replies={},
            pending_confirmations=None,
        )

    sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
    violation_messages = [m for m in sent_messages if m.get("type") == "phase_violation"]
    assert len(violation_messages) == 1
    v = violation_messages[0]["data"]
    assert v["error"] == "phase_violation"
    assert v["tool"] == "write_file"
    assert v["current_phase"] == "planning"
    assert v["violation_kind"] == "tool_not_allowed"


@pytest.mark.asyncio
async def test_no_phase_violation_event_when_not_plan_exec(app_with_chat):
    """Characterization: REACT mode → no phase_violation events forwarded
    (the engine never yields them without a phase_policy)."""
    from agentkit.server.routes import chat as chat_module

    agent = _make_agent_mock()
    routing = _make_routing(execution_mode=ExecutionMode.REACT)
    _setup_routing(app_with_chat, routing, agent)

    sm = _make_session_manager_mock()
    sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}])
    ws = _make_websocket_mock(app_with_chat)

    class _StubEngine:
        def __init__(self, **kwargs):
            self._phase_policy = None
            self._current_phase = None

        @property
        def current_phase(self):
            return None

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            from agentkit.core.react import ReActEvent

            # REACT mode: no phase_violation events yielded.
            yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"})

    with pytest.MonkeyPatch().context() as mp:
        mp.setattr(chat_module, "ReActEngine", _StubEngine)

        await chat_module._handle_chat_message(
            websocket=ws,
            session_id="test-session",
            content="hi",
            sm=sm,
            cancellation_token=MagicMock(),
            pending_replies={},
            pending_confirmations=None,
        )

    sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
    violation_messages = [m for m in sent_messages if m.get("type") == "phase_violation"]
    assert len(violation_messages) == 0


# ---------------------------------------------------------------------------
# _build_phase_engine helper (U3)
# ---------------------------------------------------------------------------


class TestBuildPhaseEngineHelper:
    """Direct unit tests for the _build_phase_engine helper extracted in U3."""

    def test_returns_none_when_not_plan_exec(self):
        from agentkit.server.routes.chat import _build_phase_engine

        engine, tools, err = _build_phase_engine(
            server_config=None,
            llm_gateway=MagicMock(),
            execution_mode=ExecutionMode.REACT,
            base_tools=[],
        )
        assert engine is None
        assert tools is None
        assert err is None

    def test_returns_none_when_plan_exec_disabled_by_config(self):
        from agentkit.server.routes.chat import _build_phase_engine

        server_config = MagicMock()
        server_config.plan_exec = {"enabled": False}

        engine, tools, err = _build_phase_engine(
            server_config=server_config,
            llm_gateway=MagicMock(),
            execution_mode=ExecutionMode.PLAN_EXEC,
            base_tools=[],
        )
        assert engine is None
        assert tools is None
        assert err is None

    def test_builds_engine_with_default_policy_when_plan_exec_section_empty(self):
        """Empty plan_exec config → default_policy() used, engine built."""
        from agentkit.server.routes.chat import _build_phase_engine

        server_config = MagicMock()
        server_config.plan_exec = {}

        engine, tools, err = _build_phase_engine(
            server_config=server_config,
            llm_gateway=MagicMock(),
            execution_mode=ExecutionMode.PLAN_EXEC,
            base_tools=[],
        )
        assert engine is not None
        assert tools is not None
        assert err is None
        # Default policy: PLANNING allows search, blocks write_file
        assert "search" in engine._phase_policy.whitelist[PhaseState.PLANNING]
        assert "write_file" not in engine._phase_policy.whitelist[PhaseState.PLANNING]

    def test_returns_error_when_phase_policy_invalid(self):
        from agentkit.server.routes.chat import _build_phase_engine

        server_config = MagicMock()
        server_config.plan_exec = {"start_phase": "invalid_phase_name"}

        engine, tools, err = _build_phase_engine(
            server_config=server_config,
            llm_gateway=MagicMock(),
            execution_mode=ExecutionMode.PLAN_EXEC,
            base_tools=[],
        )
        assert engine is None
        assert tools is None
        assert err is not None
        assert "phase policy error" in err

    def test_appends_advance_phase_tool_to_tools(self):
        from agentkit.server.routes.chat import _build_phase_engine

        server_config = MagicMock()
        server_config.plan_exec = {}

        base_tool = MagicMock()
        engine, tools, err = _build_phase_engine(
            server_config=server_config,
            llm_gateway=MagicMock(),
            execution_mode=ExecutionMode.PLAN_EXEC,
            base_tools=[base_tool],
        )
        assert err is None
        assert engine is not None
        assert tools is not None
        # base_tool preserved + AdvancePhaseTool appended
        assert len(tools) == 2
        assert tools[0] is base_tool
        assert isinstance(tools[1], AdvancePhaseTool)

    def test_engine_uses_default_policy_when_config_returns_none(self, monkeypatch):
        """policy_from_config returning None → default_policy() used."""
        from agentkit.server.routes import chat as chat_module

        def _stub_policy_from_config(cfg):
            return None

        monkeypatch.setattr(chat_module, "policy_from_config", _stub_policy_from_config)

        server_config = MagicMock()
        server_config.plan_exec = {"enabled": True}

        engine, tools, err = chat_module._build_phase_engine(
            server_config=server_config,
            llm_gateway=MagicMock(),
            execution_mode=ExecutionMode.PLAN_EXEC,
            base_tools=[],
        )
        assert err is None
        assert engine is not None
        assert engine._phase_policy is not None
        # Default policy's start phase is PLANNING
        assert engine._current_phase == PhaseState.PLANNING
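
The tests in this class pin a three-way return contract `(engine, tools, err)`: all `None` means "not PLAN_EXEC, keep the normal path", a non-`None` `err` means "misconfigured", and a non-`None` `engine` means "run PLAN_EXEC with the extended tool list". A rough sketch of how a caller might branch on that triple follows. `build_phase_engine_sketch` and `choose_path` are hypothetical stand-ins written for this illustration, and the dict used as an "engine" is a placeholder, not the real `ReActEngine`.

```python
from typing import Any, List, Optional, Tuple

VALID_PHASES = ("planning", "building", "verification", "delivery")


def build_phase_engine_sketch(
    mode: str, plan_exec_cfg: Optional[dict], base_tools: List[Any]
) -> Tuple[Optional[dict], Optional[List[Any]], Optional[str]]:
    """Toy stand-in for the helper under test; returns (engine, tools, err)."""
    if mode != "plan_exec":
        return None, None, None  # not requested: caller keeps its normal path
    cfg = plan_exec_cfg or {}
    if cfg.get("enabled") is False:
        return None, None, None  # explicitly disabled: same fallback signal
    start = cfg.get("start_phase", "planning")
    if start not in VALID_PHASES:
        return None, None, f"phase policy error: unknown start_phase {start!r}"
    engine = {"current_phase": start}  # placeholder for a real engine object
    # Base tools are preserved in order; the phase-advance tool goes last.
    return engine, [*base_tools, "advance_phase"], None


def choose_path(mode: str, plan_exec_cfg: Optional[dict], base_tools: List[Any]):
    """Branch on the triple the way a route handler might."""
    engine, tools, err = build_phase_engine_sketch(mode, plan_exec_cfg, base_tools)
    if err is not None:
        return "error", err
    if engine is None:
        return "fallback", base_tools
    return "plan_exec", tools


print(choose_path("plan_exec", {}, ["write_file"]))
# ('plan_exec', ['write_file', 'advance_phase'])
```

Checking `err` before `engine` matters in this sketch because both the error case and the fallback case return `engine is None`.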

@ -6,6 +6,7 @@ Covers:
- PhasePolicy.is_tool_allowed / is_bash_command_allowed
- policy_from_config parsing (R26 config-driven)
- ServerConfig.plan_exec integration
- Wave 4 U1: bash_command_filter accepts Callable (ShellTool._is_dangerous reuse)
"""

from __future__ import annotations
@ -22,6 +23,7 @@ from agentkit.core.phase import (
    policy_from_config,
)
from agentkit.server.config import ServerConfig
from agentkit.tools.shell import ShellTool


# ---------------------------------------------------------------------------
@ -115,6 +117,109 @@ class TestDefaultPolicy:
        assert policy.is_bash_command_allowed("rm -rf build/", PhaseState.BUILDING) is True
        assert policy.is_bash_command_allowed("echo x > out.log", PhaseState.BUILDING) is True

    # --- Wave 4 U1 characterization (Wave 3 behavior preserved) -----------------
    # default_policy() now wires ShellTool._is_dangerous (a Callable) for
    # PLANNING/VERIFICATION. These tests pin the contract so a future regression
    # in either ShellTool._is_dangerous or PhasePolicy dispatch surfaces here.

    def test_bash_filter_callable_in_default_policy(self):
        # Sanity: default_policy uses a Callable, not a regex Pattern.
        policy = default_policy()
        planning_filter = policy.bash_command_filter[PhaseState.PLANNING]
        assert callable(planning_filter)
        assert planning_filter is ShellTool._is_dangerous

    def test_bash_filter_characterization_safe_commands(self):
        # Wave 3 behavior preserved — safe read-only commands.
        policy = default_policy()
        for cmd in ("ls -la", "pwd", "git status", "find . -name foo", "cat README.md"):
            assert policy.is_bash_command_allowed(cmd, PhaseState.PLANNING) is True, cmd

    def test_bash_filter_characterization_dangerous_commands(self):
        # Wave 3 behavior preserved — commands blocked by the old regex.
        policy = default_policy()
        for cmd in (
            "rm -rf /",
            "rm -rf /tmp/x",
            "mv a b",
            "cp a b",
            "mkdir newdir",
            "chmod 777 file",
            "chown root file",
            "echo x > file.txt",
            "echo x >> file.txt",
        ):
            assert policy.is_bash_command_allowed(cmd, PhaseState.PLANNING) is False, cmd

    # --- Wave 4 U1 ceiling closed (new edge cases the old regex missed) ---------

    def test_bash_filter_closes_regex_ceiling_dd_of(self):
        # Old regex missed `dd of=/dev/sda` (no word-boundary match for "dd").
        policy = default_policy()
        assert policy.is_bash_command_allowed("dd of=/dev/sda", PhaseState.PLANNING) is False

    def test_bash_filter_closes_regex_ceiling_colon_redirect(self):
        # Old regex missed `:>file` (no whitespace before `>`).
        policy = default_policy()
        assert policy.is_bash_command_allowed(":>file", PhaseState.PLANNING) is False

    def test_bash_filter_closes_regex_ceiling_redirection_after_arg(self):
        # The old regex's `(?<!\S)>` alternative matched `>` only at the start
        # of the command or after whitespace. Pin that a redirection following
        # an argument, such as `echo hello > /tmp/x`, is classified as
        # dangerous by the new filter.
        policy = default_policy()
        assert policy.is_bash_command_allowed("echo hello > /tmp/x", PhaseState.PLANNING) is False

    def test_bash_filter_closes_regex_ceiling_chain_operators(self):
        # Old regex did NOT match `;`, `&&`, `||` as dangerous. The new filter
        # treats all chain operators as dangerous (matches ShellTool behavior).
        policy = default_policy()
        for cmd in (
            "ls; rm -rf /tmp",
            "ls && rm -rf /tmp",
            "ls || rm -rf /tmp",
            "$(rm -rf /tmp)",
            "`rm -rf /tmp`",
        ):
            assert policy.is_bash_command_allowed(cmd, PhaseState.PLANNING) is False, cmd

    def test_bash_filter_closes_regex_ceiling_pipe_with_dangerous_segment(self):
        # Old regex scanned the WHOLE command string, so `echo x | grep y`
        # would be allowed (no dangerous token) but `rm x | cat` would be
        # blocked (matches `\brm\b`). The new filter splits pipes and checks
        # each segment, so `echo x | grep y` should be allowed and
        # `rm x | cat` blocked.
        policy = default_policy()
        assert policy.is_bash_command_allowed("echo x | grep y", PhaseState.PLANNING) is True
        assert policy.is_bash_command_allowed("rm x | cat", PhaseState.PLANNING) is False

    def test_bash_filter_verification_phase_uses_callable(self):
        # Same callable wired into VERIFICATION.
        # Note: `pytest` is NOT in ShellTool._SAFE_COMMAND_PREFIXES, so
        # _is_dangerous returns True for it — the verification phase does NOT
        # widen the ShellTool whitelist. Use a known-safe read-only command
        # for the "allowed" assertion. (Wave 4 U1 reuses ShellTool._is_dangerous
        # as-is; expanding its safe-whitelist is out of scope.)
        policy = default_policy()
        assert policy.bash_command_filter[PhaseState.VERIFICATION] is ShellTool._is_dangerous
        assert policy.is_bash_command_allowed("rm -rf /", PhaseState.VERIFICATION) is False
        assert policy.is_bash_command_allowed("ls -la", PhaseState.VERIFICATION) is True
        assert policy.is_bash_command_allowed("git status", PhaseState.VERIFICATION) is True

    def test_bash_filter_delivery_phase_no_filter(self):
        # DELIVERY has no filter — full bash allowed.
        policy = default_policy()
        assert policy.bash_command_filter[PhaseState.DELIVERY] is None
        assert policy.is_bash_command_allowed("rm -rf /", PhaseState.DELIVERY) is True

    def test_bash_filter_empty_command_allowed(self):
        # is_bash_command_allowed must NOT call the filter on empty input —
        # ShellTool separately rejects empty commands. Empty is "allowed" by
        # the policy (no rejection injected to the LLM).
        policy = default_policy()
        assert policy.is_bash_command_allowed("", PhaseState.PLANNING) is True


# ---------------------------------------------------------------------------
# PhasePolicy — is_tool_allowed
@ -207,6 +312,16 @@ class TestPhasePolicyEdgeCases:
        assert d["start_phase"] == "planning"
        assert d["auto_advance_after_steps"] is None

    def test_to_dict_serializes_callable_as_marker(self):
        # Wave 4 U1: default_policy now wires a Callable. to_dict must
        # surface it as "<callable>" so logs/telemetry stay readable.
        policy = default_policy()
        d = policy.to_dict()
        assert d["bash_command_filter"]["planning"] == "<callable>"
        assert d["bash_command_filter"]["verification"] == "<callable>"
        assert d["bash_command_filter"]["building"] is None
        assert d["bash_command_filter"]["delivery"] is None

    def test_custom_bash_filter(self):
        custom_filter = re.compile(r"\b(pip install|npm install)\b")
        policy = PhasePolicy(
@ -221,6 +336,48 @@ class TestPhasePolicyEdgeCases:
        assert policy.is_bash_command_allowed("npm install foo", PhaseState.BUILDING) is False
        assert policy.is_bash_command_allowed("npm run build", PhaseState.BUILDING) is True

    def test_custom_bash_filter_accepts_callable(self):
        # Wave 4 U1: callable form. The callable returns True for dangerous.
        def deny_all(_: str) -> bool:
            return True  # everything is "dangerous"

        policy = PhasePolicy(
            whitelist={
                PhaseState.PLANNING: frozenset({"shell"}),
                PhaseState.BUILDING: frozenset({WILDCARD}),
                PhaseState.VERIFICATION: frozenset({WILDCARD}),
                PhaseState.DELIVERY: frozenset({WILDCARD}),
            },
            bash_command_filter={PhaseState.PLANNING: deny_all},
        )
        assert policy.is_bash_command_allowed("ls", PhaseState.PLANNING) is False
        assert policy.is_bash_command_allowed("rm -rf /", PhaseState.PLANNING) is False

    def test_callable_filter_takes_precedence_over_pattern_form(self):
        # Wave 4 U1: when a phase has a callable wired, the dispatch path is
        # the callable branch, not the regex branch. Sanity-check the
        # is_bash_command_allowed routing — both forms coexist in the same
        # policy dict, each phase is independent.
        pattern = re.compile(r"\brm\b")
        policy = PhasePolicy(
            whitelist={
                PhaseState.PLANNING: frozenset({"shell"}),
                PhaseState.BUILDING: frozenset({WILDCARD}),
                PhaseState.VERIFICATION: frozenset({WILDCARD}),
                PhaseState.DELIVERY: frozenset({WILDCARD}),
            },
            bash_command_filter={
                PhaseState.PLANNING: pattern,  # regex
                PhaseState.BUILDING: ShellTool._is_dangerous,  # callable
            },
        )
        # PLANNING uses regex form.
        assert policy.is_bash_command_allowed("rm x", PhaseState.PLANNING) is False
        assert policy.is_bash_command_allowed("ls", PhaseState.PLANNING) is True
        # BUILDING uses callable form.
        assert policy.is_bash_command_allowed("rm x", PhaseState.BUILDING) is False
        assert policy.is_bash_command_allowed("ls", PhaseState.BUILDING) is True
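
The dispatch these tests exercise (a per-phase filter that is either a callable predicate or a legacy compiled regex, with an empty command or a missing filter always allowed) might look roughly like the sketch below. It is an illustration under assumptions, not the `PhasePolicy` source: `is_command_allowed` and `toy_is_dangerous` are hypothetical names, and the use of `Pattern.search` for the regex branch is assumed rather than confirmed.

```python
import re
from typing import Callable, Optional, Union

BashFilter = Optional[Union[Callable[[str], bool], re.Pattern]]


def is_command_allowed(command: str, bash_filter: BashFilter) -> bool:
    """Return True when the command passes the filter (hypothetical helper)."""
    if bash_filter is None or not command:
        # No filter for this phase, or nothing to inspect.
        return True
    if isinstance(bash_filter, re.Pattern):
        # Legacy form: a match anywhere marks the command as dangerous.
        return bash_filter.search(command) is None
    # Callable form: the predicate returns True for dangerous commands.
    return not bash_filter(command)


def toy_is_dangerous(command: str) -> bool:
    """Illustrative predicate only; far simpler than a real shell classifier."""
    return any(token in command for token in (";", "&&", "||", ">", "rm "))


legacy = re.compile(r"\brm\b")
print(is_command_allowed(":>file", legacy))            # True: this regex misses it
print(is_command_allowed(":>file", toy_is_dangerous))  # False: the predicate catches it
```

Because the callable receives the whole command string, it can apply logic a single regex struggles with, such as splitting on pipes and checking each segment.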


# ---------------------------------------------------------------------------
# policy_from_config — R26 (config-driven)
@ -337,3 +337,283 @@ class TestAdvancePhaseTool:
|
|||
assert "advance_phase" not in allowed, (
|
||||
f"advance_phase must not be in {phase.value} whitelist"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wave 4 U2 — phase_violation accumulator + drain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPhaseViolationAccumulator:
|
||||
"""_check_phase_permission records violations; _drain_phase_violations
|
||||
yields them as ReActEvents and clears the accumulator."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
return ReActEngine(
|
||||
llm_gateway=MagicMock(),
|
||||
phase_policy=default_policy(),
|
||||
)
|
||||
|
||||
def test_violation_appended_on_tool_block(self, engine):
|
||||
# write_file is blocked in PLANNING.
|
||||
engine._check_phase_permission("write_file", {})
|
||||
assert len(engine._phase_violations) == 1
|
||||
v = engine._phase_violations[0]
|
||||
assert v["error"] == "phase_violation"
|
||||
assert v["tool"] == "write_file"
|
||||
assert v["current_phase"] == "planning"
|
||||
assert v["violation_kind"] == "tool_not_allowed"
|
||||
|
||||
def test_violation_appended_on_bash_block(self, engine):
|
||||
engine._check_phase_permission("shell", {"command": "rm -rf /tmp"})
|
||||
assert len(engine._phase_violations) == 1
|
||||
v = engine._phase_violations[0]
|
||||
assert v["violation_kind"] == "bash_command_blocked"
|
||||
assert v["tool"] == "shell"
|
||||
assert v["command_preview"] == "rm -rf /tmp"
|
||||
|
||||
def test_no_violation_when_allowed(self, engine):
|
||||
# search is allowed in PLANNING.
|
||||
engine._check_phase_permission("search", {})
|
||||
assert engine._phase_violations == []
|
||||
|
||||
def test_no_violation_without_policy(self):
|
||||
engine = ReActEngine(llm_gateway=MagicMock()) # no policy
|
||||
engine._check_phase_permission("anything", {})
|
||||
assert engine._phase_violations == []
|
||||
|
||||
def test_drain_returns_events_and_clears(self, engine):
|
||||
# Trigger two violations.
|
||||
engine._check_phase_permission("write_file", {"path": "/a"})
|
||||
engine._check_phase_permission("write_file", {"path": "/b"})
|
||||
assert len(engine._phase_violations) == 2
|
||||
|
||||
events = engine._drain_phase_violations(step=3)
|
||||
assert len(events) == 2
|
||||
assert all(e.event_type == "phase_violation" for e in events)
|
||||
assert all(e.step == 3 for e in events)
|
||||
# Each event data is a copy (caller can't mutate the accumulator).
|
||||
assert events[0].data["tool"] == "write_file"
|
||||
# Accumulator cleared after drain.
|
||||
assert engine._phase_violations == []
|
||||
|
||||
def test_drain_empty_returns_empty(self, engine):
|
||||
assert engine._drain_phase_violations(step=1) == []
|
||||
|
||||
def test_drain_returns_shallow_copy(self, engine):
|
||||
"""Drained event data must not alias the original violation dict —
|
||||
mutating one must not mutate the other."""
|
||||
engine._check_phase_permission("write_file", {})
# Keep a handle on the recorded dict so aliasing can be checked directly.
original = engine._phase_violations[0]
events = engine._drain_phase_violations(step=1)
# Mutate the drained event data.
events[0].data["tool"] = "MUTATED"
# The originally recorded dict must be unaffected by the mutation.
assert original["tool"] == "write_file"
# A fresh violation recorded after drain is unaffected as well.
engine._check_phase_permission("write_file", {})
new_violations = engine._phase_violations
assert new_violations[0]["tool"] == "write_file" # not "MUTATED"
|
||||
|
||||
def test_reset_clears_violations(self, engine):
|
||||
engine._check_phase_permission("write_file", {})
|
||||
assert len(engine._phase_violations) == 1
|
||||
engine.reset()
|
||||
assert engine._phase_violations == []
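Review note: the drain tests above pin down three behaviours: events are stamped with the current step, event data does not alias the recorded dict, and the accumulator is empty afterwards. A minimal sketch of an accumulator with those properties follows. `SketchEvent` and `ViolationAccumulator` are hypothetical names used only for illustration; this is not the actual `ReActEngine` code, which this diff does not show.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SketchEvent:
    """Hypothetical stand-in for ReActEvent; only the fields the tests read."""
    event_type: str
    step: int
    data: dict[str, Any]


@dataclass
class ViolationAccumulator:
    """Hypothetical accumulator illustrating the drain contract."""
    _violations: list[dict[str, Any]] = field(default_factory=list)

    def record(self, violation: dict[str, Any]) -> None:
        # Store the violation dict as given.
        self._violations.append(violation)

    def drain(self, step: int) -> list[SketchEvent]:
        # dict(v) makes a shallow copy, so a consumer mutating event.data
        # cannot change the dict that was originally recorded.
        events = [
            SketchEvent("phase_violation", step, dict(v))
            for v in self._violations
        ]
        # Clear after building the events so a second drain returns [].
        self._violations.clear()
        return events
```

A shallow copy is enough only while the violation payload holds flat values (strings, numbers); nested containers would still be shared between the event and the recorded dict.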
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wave 4 U2 — execute_stream yields phase_violation events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecuteStreamPhaseViolationEvents:
|
||||
"""execute_stream must yield phase_violation ReActEvents when a tool is
|
||||
blocked by _check_phase_permission. The events are drained after each
|
||||
tool_result yield."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_yields_phase_violation_on_tool_block(self):
|
||||
"""When the LLM calls a tool blocked by the phase policy, execute_stream
|
||||
yields a tool_call event, a tool_result event (with the error dict),
|
||||
and a phase_violation event."""
|
||||
from agentkit.core.react import ReActEvent
|
||||
|
||||
engine = ReActEngine(
|
||||
llm_gateway=llm_mock_gateway_with_response(
|
||||
tool_calls=[{"name": "write_file", "arguments": {"path": "/x"}}],
|
||||
content=None,
|
||||
),
|
||||
phase_policy=default_policy(),
|
||||
max_steps=1,
|
||||
)
|
||||
# Patch _find_tool so we don't need real tools registered. write_file
|
||||
# should be blocked by phase_policy before _find_tool is called.
|
||||
engine._find_tool = lambda name, tools: None
|
||||
|
||||
events: list[ReActEvent] = []
|
||||
async for ev in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
tools=[],
|
||||
):
|
||||
events.append(ev)
|
||||
|
||||
# Expect: thinking → tool_call → tool_result → phase_violation → final_answer
|
||||
event_types = [e.event_type for e in events]
|
||||
assert "tool_call" in event_types
|
||||
assert "tool_result" in event_types
|
||||
assert "phase_violation" in event_types
|
||||
|
||||
# The phase_violation event must come AFTER tool_result.
|
||||
tool_result_idx = next(i for i, e in enumerate(events) if e.event_type == "tool_result")
|
||||
violation_idx = next(i for i, e in enumerate(events) if e.event_type == "phase_violation")
|
||||
assert violation_idx > tool_result_idx
|
||||
|
||||
# Verify event data.
|
||||
violation = events[violation_idx]
|
||||
assert violation.data["error"] == "phase_violation"
|
||||
assert violation.data["tool"] == "write_file"
|
||||
assert violation.data["current_phase"] == "planning"
|
||||
assert violation.data["violation_kind"] == "tool_not_allowed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_yields_phase_violation_on_bash_block(self):
|
||||
"""When the LLM calls shell with a dangerous command in PLANNING,
|
||||
execute_stream yields a phase_violation event with violation_kind
|
||||
= bash_command_blocked."""
|
||||
from agentkit.core.react import ReActEvent
|
||||
|
||||
engine = ReActEngine(
|
||||
llm_gateway=llm_mock_gateway_with_response(
|
||||
tool_calls=[{"name": "shell", "arguments": {"command": "rm -rf /tmp"}}],
|
||||
content=None,
|
||||
),
|
||||
phase_policy=default_policy(),
|
||||
max_steps=1,
|
||||
)
|
||||
engine._find_tool = lambda name, tools: None
|
||||
|
||||
events: list[ReActEvent] = []
|
||||
async for ev in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
tools=[],
|
||||
):
|
||||
events.append(ev)
|
||||
|
||||
violation_events = [e for e in events if e.event_type == "phase_violation"]
|
||||
assert len(violation_events) == 1
|
||||
v = violation_events[0].data
|
||||
assert v["violation_kind"] == "bash_command_blocked"
|
||||
assert v["tool"] == "shell"
|
||||
assert "rm -rf /tmp" in v["command_preview"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_no_violation_when_tool_allowed(self):
|
||||
"""When the LLM calls an allowed tool, no phase_violation event is yielded."""
|
||||
from agentkit.core.react import ReActEvent
|
||||
|
||||
engine = ReActEngine(
|
||||
llm_gateway=llm_mock_gateway_with_response(
|
||||
tool_calls=[{"name": "search", "arguments": {"query": "foo"}}],
|
||||
content=None,
|
||||
),
|
||||
phase_policy=default_policy(),
|
||||
max_steps=1,
|
||||
)
|
||||
# search is allowed in PLANNING; we still need _find_tool to return a
|
||||
# tool object so dispatch proceeds.
|
||||
search_tool = MagicMock()
|
||||
search_tool.input_schema = None
|
||||
search_tool.safe_execute = AsyncMock(return_value={"results": []})
|
||||
engine._find_tool = lambda name, tools: search_tool
|
||||
|
||||
events: list[ReActEvent] = []
|
||||
async for ev in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
tools=[search_tool],
|
||||
):
|
||||
events.append(ev)
|
||||
|
||||
assert not any(e.event_type == "phase_violation" for e in events)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_no_violation_without_policy(self):
|
||||
"""Without a phase_policy, no phase_violation events are yielded —
|
||||
characterization of the no-policy path."""
|
||||
from agentkit.core.react import ReActEvent
|
||||
|
||||
engine = ReActEngine(
|
||||
llm_gateway=llm_mock_gateway_with_response(
|
||||
tool_calls=[{"name": "any_tool", "arguments": {}}],
|
||||
content=None,
|
||||
),
|
||||
max_steps=1,
|
||||
)
|
||||
any_tool = MagicMock()
|
||||
any_tool.input_schema = None
|
||||
any_tool.safe_execute = AsyncMock(return_value={"output": "ok"})
|
||||
engine._find_tool = lambda name, tools: any_tool
|
||||
|
||||
events: list[ReActEvent] = []
|
||||
async for ev in engine.execute_stream(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
tools=[any_tool],
|
||||
):
|
||||
events.append(ev)
|
||||
|
||||
assert not any(e.event_type == "phase_violation" for e in events)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers — minimal LLM gateway mocks for execute_stream tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def llm_mock_gateway():
|
||||
"""Return a MagicMock LLM gateway (sufficient for constructor tests)."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
def llm_mock_gateway_with_response(tool_calls: list[dict], content: str | None):
|
||||
"""Return a MagicMock LLM gateway whose chat_stream yields a single chunk
|
||||
containing the given tool_calls, followed by a final "done" content chunk (or only a content chunk when tool_calls is empty).
|
||||
|
||||
The mock is shaped to match what execute_stream expects from
|
||||
LLMGateway.chat_stream — an async iterable of chunks with attributes
|
||||
`content`, `tool_calls`, `usage`, `model`.
|
||||
"""
|
||||
gateway = MagicMock()
|
||||
|
||||
# Build a fake chunk. execute_stream reads chunk.content, chunk.tool_calls,
|
||||
# chunk.usage, chunk.model. The first three are typically accessed via
|
||||
# attribute access; we make a small dataclass-like object.
|
||||
class _Chunk:
|
||||
def __init__(self, content, tool_calls, usage=None, model="default"):
|
||||
self.content = content
|
||||
self.tool_calls = tool_calls
|
||||
self.usage = usage
|
||||
self.model = model
|
||||
|
||||
# If tool_calls provided, emit a chunk with tool_calls (non-streaming path).
|
||||
# Otherwise, emit a chunk with content (final answer path).
|
||||
if tool_calls:
|
||||
# Convert raw dicts to objects with .name/.arguments/.id attributes
|
||||
# (LLMGateway normally returns tool_call objects).
|
||||
class _TC:
|
||||
def __init__(self, d):
|
||||
self.name = d.get("name", "")
|
||||
self.arguments = d.get("arguments", {})
|
||||
self.id = d.get("id", "tc_test")
|
||||
|
||||
chunks = [_Chunk(content=None, tool_calls=[_TC(tc) for tc in tool_calls])]
|
||||
# Follow with a final-answer chunk so execute_stream's loop exits
|
||||
# cleanly after the tool call.
|
||||
chunks.append(_Chunk(content="done", tool_calls=None))
|
||||
else:
|
||||
chunks = [_Chunk(content=content or "final answer", tool_calls=None)]
|
||||
|
||||
async def _fake_chat_stream(*args, **kwargs):
|
||||
for c in chunks:
|
||||
yield c
|
||||
|
||||
gateway.chat_stream = _fake_chat_stream
|
||||
return gateway
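Review note: the helper above relies on replacing a `MagicMock` attribute with a plain async generator function. The sketch below isolates that pattern so its behaviour can be checked without the engine. `make_fake_gateway` and `collect` are hypothetical names introduced for illustration only.

```python
import asyncio
from unittest.mock import MagicMock


def make_fake_gateway(chunks):
    """Hypothetical helper: a gateway whose chat_stream replays `chunks`."""
    gateway = MagicMock()

    async def _fake_chat_stream(*args, **kwargs):
        # Each call creates a fresh async generator over the same list.
        for chunk in chunks:
            yield chunk

    # Assigning a real async generator function replaces the auto-created
    # MagicMock attribute, so the call result supports `async for`.
    gateway.chat_stream = _fake_chat_stream
    return gateway


async def collect(gateway):
    # Drain one chat_stream call into a list.
    return [chunk async for chunk in gateway.chat_stream(messages=[])]


fake = make_fake_gateway(["tool_call_chunk", "done"])
first_call = asyncio.run(collect(fake))
second_call = asyncio.run(collect(fake))
```

Because the generator function closes over the list, every call to `chat_stream` replays the same chunks. A test that lets the engine make more than one LLM call will therefore see the tool call again on each call, which is presumably why the tests above cap the run with `max_steps=1`.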
|
||||