diff --git a/agentkit.yaml b/agentkit.yaml index 9588b70..a0b2f43 100644 --- a/agentkit.yaml +++ b/agentkit.yaml @@ -51,6 +51,19 @@ fallback_chain: max_retries: 1 # ReflexionEngine max_reflections override emergency: enabled: true +# G6/U2: PLAN_EXEC phase policy — SOLO four-stage state machine. +# When `enabled: true`, chat WebSocket PLAN_EXEC requests build a PhasePolicy +# (Planning → Building → Verification → Delivery) and enforce per-step tool +# whitelists (R24). Transitions are LLM-driven via AdvancePhaseTool; set +# `auto_advance_after_steps` to auto-advance as a safety net (KTD6). +# Commented to preserve default behavior — uncomment to enable. +# plan_exec: +# enabled: true +# auto_advance_after_steps: 5 # optional, default = manual (LLM calls advance_phase) +# start_phase: planning # optional, default = planning +# whitelist_override: # optional, merges with default (override wins) +# planning: [search, read_file, shell] +# building: [write_file, shell, read_file] session: {backend: memory} bus: {backend: memory} task_store: {backend: memory} diff --git a/docs/plans/2026-06-29-004-feat-agent-wave3-strategic-plan.md b/docs/plans/2026-06-29-004-feat-agent-wave3-strategic-plan.md new file mode 100644 index 0000000..cef1c54 --- /dev/null +++ b/docs/plans/2026-06-29-004-feat-agent-wave3-strategic-plan.md @@ -0,0 +1,436 @@ +--- +title: "feat: Agent Wave 3 strategic coupling (G5/G6)" +date: 2026-06-29 +type: feat +status: draft +origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md +execution: code +--- + +# Wave 3 Strategic Coupling — G5 Function-level Sharding + G6 SOLO Phase Constraints + +## Summary + +Wave 3 of the advanced-agent gap optimization closes two strategic gaps deferred from Waves 1-2: + +- **G5 — Function-level code sharding** (R22, R23): file reading gains an optional `symbol` parameter for symbol/function-granularity slicing, backward compatible with full-file reads. +- **G6 — SOLO four-stage state machine** (R24, R25): ReAct loop enforces per-phase tool whitelists (Planning → Building → Verification → Delivery). Extends existing `ExecutionMode.PLAN_EXEC` rather than introducing a new mode. + +Wave 1 (G1/G2/G3/G8 — PR #4 merged) and Wave 2 (G4/G7/G9 — PR #5 open) shipped independently. Wave 3 is the **strategic-risk** wave: it introduces a new tool (G5) and touches ReAct core (G6). Per the brainstorm's KTD6/KTD7 locked decisions, G5 integration approach is decided here (not deferred further), and G6 extends PLAN_EXEC rather than adding a new mode. + +## Problem Frame + +The brainstorm (`docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md`) identified nine gaps across three dimensions. Waves 1-2 closed seven (G1-G4, G7-G9). The remaining two gaps are strategic: + +- **G5 (Long-context cost)**: Large files (e.g. 5000-line modules) blow context budget when the agent only needs one function. Today the agent shells out to `cat file.py` or greps; both pull the whole file into context. A symbol-aware slice would cut context cost 10-50x for typical edits. +- **G6 (Unsafe tool sequencing)**: ReAct loop lets the LLM call `write_file` during early exploration before committing to a plan. This wastes tokens on premature edits, causes half-baked refactors, and breaks the "plan-then-build" discipline that production agents (Qoder, Trae Work) enforce via phase state machines. + +KTD6 (brainstorm) defers the G5 integration decision to this plan. KTD7 locks G6 to extend PLAN_EXEC rather than introduce a new mode. + +## Requirements + +Carried forward from the brainstorm, Wave 3 section: + +- **R22**: File reading supports symbol/function granularity sharding. +- **R23**: Sharding capability exposed as a tool parameter (`symbol="function_name"`), backward compatible with full-file reads. +- **R24**: ReAct loop enforces phase constraints — Planning phase only allows `think`/`search`; Building phase only allows `write_file` (and similar write tools). +- **R25**: Phase state is configurable; extends `ExecutionMode.PLAN_EXEC`, does NOT introduce a new mode. + +Cross-cutting (brainstorm): + +- **R26**: All optimizations configurable via `agentkit.yaml` (follow `ServerConfig.from_dict` pattern established in Waves 1-2). +- **R27**: Each optimization ships a minimal self-check test (ponytail rule). + +Acceptance examples relevant to Wave 3: + +- **AE5 already covered by Wave 2 (G9)** — not in Wave 3 scope. +- No explicit AE for G5/G6 in the brainstorm — the plan below specifies test scenarios as the acceptance contract. + +## Key Technical Decisions + +### KTD1: G5 uses Python `ast` + language-aware regex, NOT tree-sitter + +**Decision**: Implement symbol extraction with the Python stdlib `ast` module for Python files, and a small regex-based extractor for TypeScript/JavaScript/Go/Rust/Java. No new dependency. + +**Rationale**: +- `tree-sitter` requires native compilation + per-language grammar files (~30MB installed) — violates the ponytail rule "no new dependency if it can be avoided" and AGENTS.md "禁止使用 any 类型" → prefers minimal stack. +- `ast` is stdlib, always available, parses Python accurately. +- Regex extractor covers 80% case for TS/JS/Go/Rust/Java (function/class/struct declarations); falls back to "no symbols found → read full file" gracefully. +- If a future Wave 4 needs more accurate multi-language parsing, `tree-sitter` can replace the regex layer behind the same `SymbolExtractor` interface. + +**Upgrade path**: replace `RegexSymbolExtractor` with `TreeSitterSymbolExtractor` implementing the same `SymbolExtractor` protocol; no caller changes. + +### KTD2: G5 adds a new `ReadFileTool`, does not extend `ShellTool` + +**Decision**: Add a dedicated `ReadFileTool` in `src/agentkit/tools/file_read.py` with `path` + optional `symbol` + optional `start_line`/`end_line` parameters. + +**Rationale**: +- `ShellTool` is for shell execution; grafting symbol extraction onto it muddies the contract. +- A dedicated tool gives the LLM a clear schema (`{"path": "...", "symbol": "function_name"}`) and a focused system-prompt description. +- Aligns with the existing `_DEFAULT_CORE_TOOLS` list in `core/react.py:148` which already references `read_file` — the name is reserved but the implementation is missing. + +### KTD3: G6 phase state machine lives in ReActEngine, not in the skill config + +**Decision**: Phase state (Planning/Building/Verification/Delivery) is tracked as a mutable field on `ReActEngine` instance. Transitions are driven by LLM-detected phase-completion signals (e.g., the LLM emits `Phase: Building` in its thinking) OR by an explicit `advance_phase` tool. + +**Rationale**: +- Skill config declares the policy (which tools per phase, auto-advance vs manual); the engine enforces it per-step. This matches R24 ("ReAct 循环加阶段约束"). +- Alternative considered: phase state in the agent instance (not engine). Rejected because ReActEngine already owns `max_steps`/`verification_enabled` etc.; phase state belongs with the loop that enforces it. + +### KTD4: PLAN_EXEC mode is wired at chat.py WebSocket path (REST already has fallback chain from Wave 2) + +**Decision**: chat.py:1084 (currently warns "not yet supported, falling back to REACT") will route `ExecutionMode.PLAN_EXEC` to a new `_execute_plan_exec_ws` handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and passes it to `ReActEngine.execute`. + +**Rationale**: +- REST `send_message` already uses the Wave 2 three-tier fallback chain; PLAN_EXEC at REST would also need that wrapper. **Out of scope for Wave 3** — only WebSocket path is wired. REST PLAN_EXEC remains "not yet supported" and explicitly raises if invoked. +- Single integration point keeps Wave 3 scope bounded; REST wiring is a one-line follow-up once WebSocket path is proven. + +### KTD5: Default phase whitelist matches brainstorm R24 + +**Decision**: Default whitelist: +- `Planning`: `search`, `tool_search`, `read_file`, `bash` (read-only commands like `git status`, `ls`) +- `Building`: `write_file`, `bash` (write commands), `read_file`, `search` +- `Verification`: `bash` (test commands), `read_file`, `search` +- `Delivery`: all tools (final synthesis) + +**Rationale**: +- R24 explicitly names `think`/`search` for Planning and `write_file` for Building. +- `bash` is split: read-only in Planning, full in Building. Enforced by adding a `bash_command_filter` callback (regex-based, blocks `rm`/`mv`/`>`/`>>` in Planning/Verification). +- `Delivery` allows all tools to support last-mile formatting/cleanup. + +### KTD6: Phase transitions are LLM-driven via `advance_phase` tool (opt-in auto-advance) + +**Decision**: Add an `AdvancePhaseTool` that the LLM can call to transition Planning→Building→Verification→Delivery. Auto-advance (after N steps in current phase) is opt-in via `plan_exec.auto_advance_after_steps`. + +**Rationale**: +- LLM-driven transitions match Qoder/Trae Work pattern: LLM declares "planning done" explicitly. +- Auto-advance is a safety net for LLMs that forget to call `advance_phase`; default off (ponytail: less code is better). + +## Scope Boundaries + +### In Scope + +- New `ReadFileTool` with `symbol` parameter (G5, R22/R23). +- `SymbolExtractor` protocol + `AstSymbolExtractor` (Python) + `RegexSymbolExtractor` (TS/JS/Go/Rust/Java). +- `PhasePolicy` dataclass + `PhaseState` enum + per-step tool whitelist enforcement in `ReActEngine.execute` (G6, R24/R25). +- `AdvancePhaseTool` for LLM-driven phase transitions. +- WebSocket chat path routes `PLAN_EXEC` to new `_execute_plan_exec_ws` handler (KTD4). +- `plan_exec` config section in `agentkit.yaml` + `ServerConfig.from_dict` extension (R26). +- Tests for each new module (R27). + +### Out of Scope (Deferred to Follow-Up Work) + +- REST `send_message` PLAN_EXEC wiring — once WebSocket path is proven, REST wiring is a follow-up commit. +- `tree-sitter` integration for more accurate multi-language parsing (KTD1 upgrade path). +- Phase-aware prompt engineering (per-phase system prompt templates) — current plan keeps a single system prompt; phase-specific guidance is a prompt-engineering concern, not a code change. +- Phase persistence across session resume (U7 checkpoint already saves plan state; phase state restoration is a separate concern). +- Phase rollback on `Building` → `Planning` regression (Wave 2 G9 rollback handles file-level rollback; phase regression is a UX/prompt concern). +- Tool-filter UI in the frontend (Wave 3 ships backend-only; frontend surfaces phase via existing event channel if needed in a follow-up). + +### Outside This Product's Identity + +- Replacing the existing ReAct loop with LangGraph (inherited from brainstorm). +- Disc-based file system à la DeerFlow (inherited). +- Docker sandbox (inherited; only command-level safety via `bash_command_filter`). + +## High-Level Technical Design + +```mermaid +flowchart LR + subgraph G5[Function Sharding] + RF[ReadFileTool] --> SE[SymbolExtractor protocol] + SE --> AST[AstSymbolExtractor
Python stdlib ast] + SE --> RX[RegexSymbolExtractor
TS/JS/Go/Rust/Java] + end + + subgraph G6[Phase State Machine] + PP[PhasePolicy config] --> PS[PhaseState enum
Planning/Building/Verify/Delivery] + PS --> Filt[Tool filter per step] + Filt --> RE[ReActEngine.execute] + AP[AdvancePhaseTool] -->|transitions| PS + end + + RF -->|file content for symbol| RE + RE -->|enforces| Filt +``` + +The two subsystems compose at the ReAct engine boundary: `ReadFileTool` is one of the tools the LLM can call during any phase (filtered by `PhasePolicy`); `PhaseState` is enforced at the tool-call step before dispatch. + +## Implementation Units + +### U1. SymbolExtractor + ReadFileTool (G5) + +**Goal**: Add `ReadFileTool` with optional `symbol` parameter; implement `SymbolExtractor` protocol with `AstSymbolExtractor` (Python) and `RegexSymbolExtractor` (TS/JS/Go/Rust/Java). + +**Requirements**: R22, R23, R27. + +**Dependencies**: none. + +**Files**: +- `src/agentkit/tools/file_read.py` (new) +- `src/agentkit/tools/symbol_extractor.py` (new) +- `src/agentkit/tools/__init__.py` (modify — register `ReadFileTool`) +- `tests/unit/test_symbol_extractor.py` (new) +- `tests/unit/test_read_file_tool.py` (new) + +**Approach**: +- `SymbolExtractor` is a `Protocol` with one method: `extract_symbols(content: str, language: str) -> list[SymbolSpan]`. `SymbolSpan` carries `name`, `kind` (function/class/method/struct), `start_line`, `end_line`. +- `AstSymbolExtractor` walks `ast.parse(content)`; for each `FunctionDef`/`AsyncFunctionDef`/`ClassDef` collects name + line range. Uses `ast.get_source_segment` style (line-based, not node-based, to keep the API simple). +- `RegexSymbolExtractor` ships patterns for TS/JS (`function X`, `const X = (...) =>`, `class X`), Go (`func X`), Rust (`fn X`, `struct X`, `impl X`), Java (`public ... X(...)`). Falls back to "no symbols" if no pattern matches. +- `ReadFileTool.execute(path, symbol=None, start_line=None, end_line=None)`: + - `symbol=None` → read full file (backward compat with the existing `_FakeTool` benchmark shape). + - `symbol="foo"` → detect language from extension; call `extract_symbols`; return the line range of the first matching symbol; if no match, return an error result with available symbol names listed (so the LLM can retry). + - `start_line`/`end_line` overrides symbol; allows manual slicing. +- Tool registered as `read_file` (matches the reserved name in `core/react.py:148`). + +**Execution note**: characterization-first — write a test that asserts the tool returns the full file content when `symbol=None` (matches pre-existing benchmark `_FakeTool` shape) before adding symbol-extraction behavior. + +**Patterns to follow**: +- `src/agentkit/tools/document_tool.py` for tool structure (dataclass, `Tool` base class, `input_schema`). +- `src/agentkit/tools/schema_tools.py:SchemaExtractTool` for "extract-from-source" pattern. + +**Test scenarios** (covers R22, R23): +- **Happy paths**: + - Python file, `symbol="MyClass"` → returns class body only (lines from `class MyClass:` through end of class). + - Python file, `symbol="my_func"` → returns function body only. + - TypeScript file, `symbol="renderComponent"` → returns arrow/function body. + - Go file, `symbol="HandleRequest"` → returns func body. +- **Edge cases**: + - `symbol=None` → returns full file content (characterization). + - `symbol="nonexistent"` → returns error result listing available symbols ("Available symbols: foo, bar, baz"). + - Unsupported file extension (`.md`, `.txt`) → returns full file with `note: symbol extraction not supported for .md`. + - Empty file → returns empty content. + - File with nested classes → outer class symbol returns including inner class. +- **Error paths**: + - Path does not exist → raises `FileNotFoundError` (or returns error result matching other tools' convention). + - Path is a directory → returns error result. + - Permission denied → returns error result. +- **Integration scenarios**: + - Symbol extraction + line slicing: `symbol="foo"`, `end_line=50` truncates at line 50 even if symbol extends further. + - Round-trip: extract symbol, write back via `ShellTool` `sed` (not in scope for tool — just verify extracted range is well-formed). + +**Verification**: +- `python3 -m pytest tests/unit/test_symbol_extractor.py tests/unit/test_read_file_tool.py -q` passes. +- `ruff check src/agentkit/tools/file_read.py src/agentkit/tools/symbol_extractor.py` clean. +- `ReadFileTool` appears in `ToolRegistry.list_tools()` after registration. + +--- + +### U2. PhasePolicy + PhaseState + ServerConfig (G6 core) + +**Goal**: Add `PhasePolicy` dataclass, `PhaseState` enum, default whitelist config. Extend `ServerConfig.from_dict` with `plan_exec` section. Wire config to `agentkit.yaml`. + +**Requirements**: R25, R26. + +**Dependencies**: none. + +**Files**: +- `src/agentkit/core/phase.py` (new) — `PhaseState` enum, `PhasePolicy` dataclass, `default_policy()` factory. +- `src/agentkit/server/config.py` (modify — add `plan_exec` field + `from_dict` parsing). +- `agentkit.yaml` (modify — document `plan_exec:` section). +- `tests/unit/test_phase_policy.py` (new). + +**Approach**: +- `PhaseState = enum("planning building verification delivery")`. +- `PhasePolicy` carries: + - `whitelist: dict[PhaseState, set[str]]` — tool names allowed per phase. + - `bash_command_filter: dict[PhaseState, re.Pattern | None]` — regex that bash args must NOT match (e.g., `r"\b(rm|mv|>|>>)\b"` in Planning). + - `auto_advance_after_steps: int | None` — None = manual (LLM calls `advance_phase`); int = auto-advance after N steps. + - `start_phase: PhaseState = PhaseState.PLANNING`. +- `default_policy()` returns the KTD5 whitelist above. +- `ServerConfig.from_dict` reads `plan_exec` section: `enabled`, `whitelist_override` (dict), `auto_advance_after_steps`. +- `agentkit.yaml` gains a commented-out `plan_exec:` block (commented to preserve default behavior — opt-in). + +**Patterns to follow**: +- `src/agentkit/core/fallback.py` for dataclass + classmethod factory pattern. +- `src/agentkit/server/config.py` `from_dict` extension template (established in Wave 1 for `prompt_cache`/`streaming`/`verification`; Wave 2 added `rollback`/`fallback_chain`; Wave 3 adds `plan_exec`). + +**Test scenarios** (covers R25, R26): +- **Happy paths**: + - `default_policy()` returns policy with all four phases; Planning whitelist contains `search`, `read_file`; Building contains `write_file`. + - `PhasePolicy.is_tool_allowed("search", PhaseState.PLANNING)` returns True. + - `PhasePolicy.is_tool_allowed("write_file", PhaseState.PLANNING)` returns False. + - `PhasePolicy.is_tool_allowed("write_file", PhaseState.BUILDING)` returns True. +- **Edge cases**: + - Empty whitelist for a phase → all tools rejected (raises `ValueError` at construction time — fail-fast). + - `Delivery` phase whitelist contains `"*"` (wildcard) → all tools allowed. + - Custom whitelist override merges with default (override wins on conflict). +- **Error paths**: + - Invalid phase name in config → `ValueError` with message naming the bad value. + - `bash_command_filter` regex compile failure → `ValueError`. +- **Config integration**: + - `ServerConfig.from_dict({"plan_exec": {"enabled": True, "auto_advance_after_steps": 5}})` populates fields correctly. + - `ServerConfig.from_dict({})` → `plan_exec = {}` (default). + +**Verification**: +- `python3 -m pytest tests/unit/test_phase_policy.py -q` passes. +- `ruff check src/agentkit/core/phase.py` clean. + +--- + +### U3. AdvancePhaseTool + ReActEngine phase enforcement (G6 wiring) + +**Goal**: Add `AdvancePhaseTool`. Wire `PhasePolicy` into `ReActEngine.execute` so each tool-call step checks `is_tool_allowed(tool_name, current_phase)` before dispatch; blocked calls return a structured error to the LLM ("Tool 'write_file' not allowed in Planning phase — call advance_phase first"). + +**Requirements**: R24. + +**Dependencies**: U2. + +**Files**: +- `src/agentkit/tools/advance_phase.py` (new) — `AdvancePhaseTool` calls `react_engine.advance_phase()`. +- `src/agentkit/core/react.py` (modify — add `phase_policy` param to `__init__` + `execute`; add `_current_phase` field; add `advance_phase()` method; enforce in `_execute_loop`). +- `tests/unit/test_react_phase_enforcement.py` (new). + +**Approach**: +- `ReActEngine.__init__` accepts `phase_policy: PhasePolicy | None = None`. None = no enforcement (backward compat — all existing callers unaffected). +- `_current_phase: PhaseState | None` initialized from `phase_policy.start_phase` if policy set, else None. +- `advance_phase()` advances `_current_phase` to next enum value; raises `ValueError` if already at `DELIVERY`. +- In `_execute_loop`, before dispatching a tool call: + ```python + if self._phase_policy is not None and self._current_phase is not None: + if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase): + # Inject structured error into conversation, do NOT dispatch tool. + # This counts as a "step" for max_steps purposes. + observation = { + "error": "phase_violation", + "message": f"Tool '{tool_name}' not allowed in {self._current_phase.value} phase", + "current_phase": self._current_phase.value, + "hint": "Call advance_phase to move to Building phase" + } + continue # next loop iteration + ``` +- Auto-advance: if `phase_policy.auto_advance_after_steps` is set and `_steps_in_phase >= auto_advance_after_steps`, call `advance_phase()` automatically. +- `AdvancePhaseTool.execute()` calls the bound engine's `advance_phase()` and returns the new phase name. Registered only when `phase_policy` is not None. + +**Execution note**: characterization-first — test that `ReActEngine` with `phase_policy=None` behaves identically to pre-change (no enforcement, no `advance_phase` tool, no `_current_phase` mutation). Then add enforcement tests. + +**Patterns to follow**: +- `src/agentkit/core/react.py` `verification_enabled` pattern (feature flag + step-level check). +- `src/agentkit/tools/ask_human.py` for tool that interacts with engine state. + +**Test scenarios** (covers R24): +- **Characterization (no policy)**: + - `ReActEngine(phase_policy=None)` — all tools allowed in all steps; no `advance_phase` tool registered; behavior matches pre-change. +- **Happy paths**: + - Planning phase: LLM calls `search` → executes; LLM calls `advance_phase` → phase becomes Building. + - Building phase: LLM calls `write_file` → executes; LLM calls `advance_phase` → phase becomes Verification. + - Verification phase: LLM calls `bash` with `pytest` → executes; LLM calls `advance_phase` → phase becomes Delivery. + - Delivery phase: LLM calls any tool → executes (wildcard). +- **Edge cases**: + - `advance_phase` called at Delivery → returns error "Already at final phase". + - Auto-advance after 3 steps in Planning → phase transitions automatically on 4th step. + - `bash` command in Planning contains `rm file` → blocked by `bash_command_filter`. +- **Error paths**: + - LLM calls `write_file` in Planning → tool NOT dispatched; structured error returned to LLM; loop continues. + - LLM calls non-existent tool → existing error path (not phase-related). +- **Integration scenarios**: + - Phase transition emits a `phase_changed` event (use existing `_broadcast_event` pattern from `experts/orchestrator.py`). + - `max_steps` reached mid-phase → `ReActResult.status = "max_steps_reached"` (existing path, no change). + +**Verification**: +- `python3 -m pytest tests/unit/test_react_phase_enforcement.py -q` passes. +- Existing `tests/unit/test_react_engine.py` still passes (characterization — no policy = no change). +- `ruff check src/agentkit/core/react.py src/agentkit/tools/advance_phase.py` clean. + +--- + +### U4. Wire PLAN_EXEC at chat.py WebSocket path (G6 chat integration) + +**Goal**: Replace the `chat.py:1084` "not yet supported, falling back to REACT" warning with a real PLAN_EXEC handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and dispatches to `ReActEngine.execute` with the policy set. + +**Requirements**: R24, R25 (end-to-end wiring). + +**Dependencies**: U2, U3. + +**Files**: +- `src/agentkit/server/routes/chat.py` (modify — add `_execute_plan_exec_ws` handler; branch on `ExecutionMode.PLAN_EXEC`). +- `tests/unit/test_chat_plan_exec_ws.py` (new). + +**Approach**: +- New helper `_execute_plan_exec_ws(websocket, agent, routing, messages, ...)`: + 1. Read `server_config.plan_exec` (may be `{}` if not configured → use `default_policy()`). + 2. Build `PhasePolicy` from config (apply overrides). + 3. Construct `ReActEngine(..., phase_policy=policy)`. + 4. Register `AdvancePhaseTool` bound to this engine. + 5. Call `engine.execute_stream(...)` — reuses existing streaming path. + 6. Emit `phase_changed` events through the WebSocket (frontend can render phase indicator). +- chat.py:1084 changes from `if execution_mode not in (REACT, SKILL_REACT): warn + fall back` to: + ```python + if routing.execution_mode == ExecutionMode.PLAN_EXEC: + await _execute_plan_exec_ws(websocket, agent, routing, ...) + return + if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT): + # existing warning for REWOO/REFLEXION/TEAM_COLLAB + ... + ``` +- REST `send_message` path: explicitly raise `HTTPException(501, "PLAN_EXEC via REST not yet supported; use WebSocket")` — Wave 3 does NOT wire REST (KTD4). + +**Execution note**: characterization-first — test that existing REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning (no regression). Then add PLAN_EXEC wiring. + +**Patterns to follow**: +- `src/agentkit/server/routes/chat.py` existing WebSocket handler structure (lines 1082-1100). +- `src/agentkit/server/_fallback_chain.py` (Wave 2 U3) for "construct engine per-request with config" pattern. + +**Test scenarios** (covers end-to-end): +- **Characterization**: + - `ExecutionMode.REWOO` via WebSocket → still falls back to REACT with warning (existing behavior unchanged). + - `ExecutionMode.REFLEXION` → same. + - `ExecutionMode.TEAM_COLLAB` → same. +- **Happy paths**: + - `ExecutionMode.PLAN_EXEC` via WebSocket → `_execute_plan_exec_ws` invoked; `ReActEngine` constructed with `phase_policy`; `AdvancePhaseTool` registered. + - Planning phase: LLM emits `search` tool call → executed; tool result streamed. + - LLM emits `advance_phase` → `phase_changed` event sent to WebSocket client; subsequent `write_file` call now allowed. +- **Edge cases**: + - `plan_exec` config absent → `default_policy()` used; behavior matches KTD5 whitelist. + - `plan_exec.enabled=False` → falls back to REACT (opt-out). + - Phase violation: LLM calls `write_file` in Planning → structured error returned; loop continues; `phase_violation` event emitted. +- **Error paths**: + - REST `send_message` with PLAN_EXEC → 501 error. + - Phase policy construction fails (bad config) → 500 error with message. +- **Integration scenarios**: + - Existing fallback chain (Wave 2 U3) NOT applied to PLAN_EXEC — phase policy and fallback chain are mutually exclusive (KTD5 from Wave 2 plan: chain only wraps REACT/SKILL_REACT at REST). Document this in chat.py comment. + +**Verification**: +- `python3 -m pytest tests/unit/test_chat_plan_exec_ws.py -q` passes. +- `ruff check src/agentkit/server/routes/chat.py` clean. +- Manual test: `agentkit chat` with `@skill:plan_exec_demo` skill config → WebSocket stream includes `phase_changed` events. + +--- + +## Risks & Dependencies + +### Risks + +1. **ReAct core modification risk (high)**: U3 modifies `ReActEngine._execute_loop`. Mitigation: characterization-first tests (U3 Execution note); `phase_policy=None` default preserves all existing behavior; full `test_react_engine.py` regression. +2. **Symbol extraction accuracy (medium)**: Regex extractor may miss edge cases (decorated functions, nested generics, multi-line signatures). Mitigation: fall back to "no symbols found → read full file" gracefully; never raise on extraction failure. +3. **PLAN_EXEC phase deadlock (medium)**: LLM may never call `advance_phase`, leaving the agent stuck in Planning. Mitigation: `auto_advance_after_steps` config (default 5); timeout via existing `max_steps`. +4. **Tool name drift (low)**: Phase whitelist references tool names (`write_file`, `search`, etc.) that may be renamed in future. Mitigation: whitelist is config-driven; rename only requires config update. + +### Dependencies + +- Wave 2 PR #5 (`feat/agent-wave2-medium-coupling`) should be merged first — Wave 3 builds on the `ServerConfig.from_dict` extension pattern and the `_fallback_chain.py` integration shape established there. If PR #5 is still open, Wave 3 branches from `feat/agent-wave2-medium-coupling` rather than `main`. +- No external library dependencies (KTD1). + +## System-Wide Impact + +- **Agents using PLAN_EXEC mode**: gain phase enforcement. Existing REACT/SKILL_REACT/DIRECT_CHAT agents: zero change (phase_policy defaults to None). +- **Tool registry**: gains two new tools (`read_file`, `advance_phase`). Frontend tool list display may need updating to show the new icons — out of scope for Wave 3 (frontend follows up). +- **`agentkit.yaml`**: gains `plan_exec:` section (commented by default). Existing configs unaffected. +- **WebSocket clients**: gain `phase_changed` event type. Existing clients ignore unknown event types (verified in Wave 2 — `phase_rollback_*` events follow the same pattern). + +## Sources & Research + +- Origin brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 3 section, KTD6/KTD7). +- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (PR #4 merged). +- Wave 2 plan: `docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md` (PR #5 open). +- Trae Work architecture research (cited in brainstorm): SOLO four-stage state machine pattern. +- Qoder architecture research (cited in brainstorm): Spec→Coding→Verify closed loop. +- Codebase: `src/agentkit/core/react.py:148` reserves `read_file`/`write_file` tool names in `_DEFAULT_CORE_TOOLS` — Wave 3 U1 delivers the missing `read_file` implementation. +- Codebase: `src/agentkit/server/routes/chat.py:1084` documents that PLAN_EXEC is "not yet supported" — Wave 3 U4 closes this gap. + +## Deferred to Implementation + +- Exact regex patterns for non-Python symbol extraction (U1) — design above gives the shape; implementer finalizes patterns based on real-world test fixtures. +- `bash_command_filter` regex precision (U2) — defaults block `rm`/`mv`/`>`/`>>`; implementer may add more based on test scenarios. +- `phase_changed` event payload shape (U3/U4) — minimal viable shape: `{"phase": "building", "previous": "planning"}`; frontend rendering concerns are out of scope. +- Whether `AdvancePhaseTool` accepts a `target_phase` argument for skipping phases (e.g., Planning → Verification) — default no (sequential only); add if test scenarios reveal a need. diff --git a/src/agentkit/core/phase.py b/src/agentkit/core/phase.py new file mode 100644 index 0000000..0e5326a --- /dev/null +++ b/src/agentkit/core/phase.py @@ -0,0 +1,206 @@ +"""Phase state machine for PLAN_EXEC mode (G6, R24/R25). + +Four sequential phases enforce per-step tool whitelists: + PLANNING → BUILDING → VERIFICATION → DELIVERY + +KTD3 (Wave 3 plan): state machine lives in ReActEngine, not skill config. +KTD5: default whitelist matches brainstorm R24 (Planning: think/search; + Building: write_file; etc.). +KTD6: transitions are LLM-driven via AdvancePhaseTool; auto-advance is opt-in. +""" + +from __future__ import annotations + +import enum +import logging +import re +from dataclasses import dataclass, field, replace +from typing import Any + +logger = logging.getLogger(__name__) + + +class PhaseState(enum.Enum): + """Phases of the SOLO state machine (extends ExecutionMode.PLAN_EXEC).""" + + PLANNING = "planning" + BUILDING = "building" + VERIFICATION = "verification" + DELIVERY = "delivery" + + @classmethod + def next_of(cls, current: "PhaseState") -> "PhaseState | None": + """Return the phase after `current`, or None if `current` is the last.""" + order = [cls.PLANNING, cls.BUILDING, cls.VERIFICATION, cls.DELIVERY] + try: + idx = order.index(current) + except ValueError: + return None + if idx + 1 >= len(order): + return None + return order[idx + 1] + + @classmethod + def from_string(cls, value: str) -> "PhaseState": + """Parse from string (case-insensitive). Raises ValueError on unknown.""" + try: + return cls(value.lower()) + except ValueError as e: + valid = ", ".join(p.value for p in cls) + raise ValueError(f"Invalid phase name {value!r}. Valid: {valid}") from e + + +# Wildcard token meaning "all tools allowed in this phase". +WILDCARD = "*" + +# Default bash command filter for PLANNING and VERIFICATION phases — blocks +# commands that mutate the filesystem or execute arbitrary code. +# ponytail: regex is intentionally conservative; misses some shell idioms +# (e.g., `:>file`, `dd of=file`). Ceiling: a real shell parser would catch +# more. Upgrade path = reuse ShellTool._is_dangerous() at enforcement time. +# Note: `\b` is a word boundary — works for word commands (rm/mv) but NOT +# for `>`/`>>` operators (not word chars). Use a non-boundary alternation +# that matches `>` either as a standalone operator or after whitespace. +_DEFAULT_BASH_FILTER = re.compile(r"\b(rm|mv|cp|mkdir|rmdir|chmod|chown)\b|(?|>>") + + +@dataclass(slots=True) +class PhasePolicy: + """Per-phase tool whitelist + bash command filter for PLAN_EXEC mode. + + The policy is enforced by ReActEngine._execute_loop before each tool + dispatch. A tool not in the current phase's whitelist is rejected with + a structured error returned to the LLM (the loop continues — the LLM + gets to react to the rejection and either switch tools or call + AdvancePhaseTool). + + Wildcard ``"*"`` in a phase's whitelist means "all tools allowed" + (used by DELIVERY by default). + """ + + whitelist: dict[PhaseState, frozenset[str]] + bash_command_filter: dict[PhaseState, re.Pattern | None] = field(default_factory=dict) + auto_advance_after_steps: int | None = None # None = manual (LLM calls advance_phase) + start_phase: PhaseState = PhaseState.PLANNING + + def __post_init__(self) -> None: + # Fail-fast: empty whitelist for a non-wildcard phase = bug. + for phase, tools in self.whitelist.items(): + if not tools: + raise ValueError( + f"Phase {phase.value!r} has an empty whitelist — set ['*'] for " + f"'all tools allowed' or list specific tool names." + ) + + def is_tool_allowed(self, tool_name: str, phase: PhaseState) -> bool: + """Return True if `tool_name` is allowed in `phase`.""" + allowed = self.whitelist.get(phase, frozenset()) + if WILDCARD in allowed: + return True + return tool_name in allowed + + def is_bash_command_allowed(self, command: str, phase: PhaseState) -> bool: + """Return True if `command` passes the bash filter for `phase`. + + A None filter = no restriction. An empty command is allowed (ShellTool + separately rejects empty commands). + """ + pattern = self.bash_command_filter.get(phase) + if pattern is None: + return True + return not pattern.search(command) + + def to_dict(self) -> dict[str, Any]: + """Serialize for logging/telemetry. Not round-trippable (regex → str).""" + return { + "whitelist": {phase.value: sorted(tools) for phase, tools in self.whitelist.items()}, + "bash_command_filter": { + phase.value: (p.pattern if p else None) + for phase, p in self.bash_command_filter.items() + }, + "auto_advance_after_steps": self.auto_advance_after_steps, + "start_phase": self.start_phase.value, + } + + +def default_policy() -> PhasePolicy: + """Return the KTD5 default PhasePolicy. + + Whitelist (R24): + - PLANNING: search, tool_search, read_file, shell (read-only) + - BUILDING: write_file, shell (full), read_file, search + - VERIFICATION: shell (test commands), read_file, search + - DELIVERY: all tools (wildcard) + + Bash filter: + - PLANNING/VERIFICATION: blocks filesystem-mutating commands + (rm/mv/cp/mkdir/chmod/chown/>/>>) + - BUILDING/DELIVERY: no filter (full bash) + """ + return PhasePolicy( + whitelist={ + # Tool name is "shell" (ShellTool default); bash_command_filter + # gates on the same name. Using "bash" here would make the filter + # dead code and block the LLM from shell access. + PhaseState.PLANNING: frozenset({"search", "tool_search", "read_file", "shell"}), + PhaseState.BUILDING: frozenset( + {"write_file", "shell", "read_file", "search", "tool_search"} + ), + PhaseState.VERIFICATION: frozenset({"shell", "read_file", "search"}), + PhaseState.DELIVERY: frozenset({WILDCARD}), + }, + bash_command_filter={ + PhaseState.PLANNING: _DEFAULT_BASH_FILTER, + PhaseState.VERIFICATION: _DEFAULT_BASH_FILTER, + PhaseState.BUILDING: None, + PhaseState.DELIVERY: None, + }, + auto_advance_after_steps=None, # manual by default + start_phase=PhaseState.PLANNING, + ) + + +def policy_from_config(config: dict[str, Any]) -> PhasePolicy | None: + """Build a PhasePolicy from the `plan_exec` config section. + + Returns None if `config` is empty or `enabled` is False (opt-out). + + Config shape: + plan_exec: + enabled: true # default true if section present + auto_advance_after_steps: 5 # optional + start_phase: planning # optional, default planning + whitelist_override: # optional, merges with default + planning: [search, read_file] + building: [write_file, bash] + """ + if not config: + return None + if config.get("enabled", True) is False: + return None + + policy = default_policy() + + # Start phase + start_phase_str = config.get("start_phase") + if start_phase_str: + policy = replace(policy, start_phase=PhaseState.from_string(start_phase_str)) + + # Auto-advance override + if "auto_advance_after_steps" in config: + policy = replace(policy, auto_advance_after_steps=config["auto_advance_after_steps"]) + + # Whitelist override — merge with default (override wins on conflict) + override = config.get("whitelist_override") or {} + if override: + new_whitelist = dict(policy.whitelist) + for phase_name, tools in override.items(): + phase = PhaseState.from_string(phase_name) + if not isinstance(tools, list): + raise ValueError( + f"whitelist_override[{phase_name!r}] must be a list, got {type(tools).__name__}" + ) + new_whitelist[phase] = frozenset(str(t) for t in tools) + policy = replace(policy, whitelist=new_whitelist) + + return policy diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 588bb75..166a5e3 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -28,6 +28,7 @@ from agentkit.telemetry.metrics import ( if TYPE_CHECKING: from agentkit.core.compressor import CompressionStrategy from agentkit.core.middleware import MiddlewareChain + from agentkit.core.phase import PhasePolicy, PhaseState from agentkit.core.trace import TraceRecorder from agentkit.memory.retriever import MemoryRetriever @@ -168,6 +169,9 @@ class ReActEngine: prompt_cache_enable: bool = True, flush_interval_ms: int = 0, max_reinjections: int = 1, + # U3/G6: PLAN_EXEC phase policy (opt-in). None = no enforcement + # (backward compat — all existing callers unaffected). + phase_policy: "PhasePolicy | None" = None, ): if max_steps < 1: raise ValueError(f"max_steps must be >= 1, got {max_steps}") @@ -211,6 +215,15 @@ class ReActEngine: self._loop_corrected: bool = False # U6: Middleware chain (parallel integration, feature flag controlled) self._middleware_chain = middleware_chain + # U3/G6: PLAN_EXEC phase state. None = no enforcement (default). + # When set, _execute_loop checks each tool call against the current + # phase's whitelist before dispatch. + self._phase_policy = phase_policy + self._current_phase: "PhaseState | None" = ( + phase_policy.start_phase if phase_policy is not None else None + ) + # Steps taken in the current phase (for auto-advance safety net). + self._steps_in_phase: int = 0 def reset(self) -> None: """Reset internal state for reuse across conversations. @@ -223,6 +236,99 @@ class ReActEngine: # This method exists for API clarity and future stateful extensions. self._loop_window.clear() self._loop_corrected = False + # U3/G6: reset phase state to start_phase (if policy set). Each + # execute() call begins a fresh PLANNING phase. + if self._phase_policy is not None: + self._current_phase = self._phase_policy.start_phase + self._steps_in_phase = 0 + + # ── U3/G6: phase state machine ──────────────────────────────────── + + def advance_phase(self) -> "PhaseState | None": + """Advance to the next phase. Returns the new phase, or None if + already at DELIVERY (final phase). + + Called by AdvancePhaseTool when the LLM explicitly signals phase + completion. Also called by the auto-advance safety net when + ``steps_in_phase >= auto_advance_after_steps``. + + Returns None if no phase_policy is set (no-op). + """ + if self._phase_policy is None or self._current_phase is None: + return None + from agentkit.core.phase import PhaseState + + nxt = PhaseState.next_of(self._current_phase) + if nxt is None: + # Already at DELIVERY — return None to signal no transition. + return None + previous = self._current_phase + self._current_phase = nxt + self._steps_in_phase = 0 + logger.info( + "Phase transition: %s → %s", + previous.value, + nxt.value, + ) + return nxt + + @property + def current_phase(self) -> "PhaseState | None": + """Current phase (None if no phase_policy set).""" + return self._current_phase + + def _maybe_auto_advance(self) -> bool: + """Auto-advance phase if step budget exhausted. Returns True if advanced.""" + if self._phase_policy is None or self._current_phase is None: + return False + threshold = self._phase_policy.auto_advance_after_steps + if threshold is None: + return False + if self._steps_in_phase >= threshold: + self.advance_phase() + return True + return False + + def _check_phase_permission( + self, tool_name: str, arguments: dict[str, Any] + ) -> dict[str, Any] | None: + """Return None if tool is allowed; return a structured error dict if blocked. + + The error dict replaces what `_execute_tool` would have returned — + the loop continues, so the LLM can react to the rejection (call + AdvancePhaseTool or pick a different tool). + + Also applies the bash_command_filter for `bash` tool calls. + """ + if self._phase_policy is None or self._current_phase is None: + return None + if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase): + return { + "error": "phase_violation", + "message": ( + f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. " + f"Call `advance_phase` to move to the next phase." + ), + "current_phase": self._current_phase.value, + "tool": tool_name, + "is_error": True, + } + # Bash command filter (only applies to shell tool — registered as "shell"). + if tool_name == "shell": + command = str(arguments.get("command", "")) + if not self._phase_policy.is_bash_command_allowed(command, self._current_phase): + return { + "error": "phase_violation", + "message": ( + f"Bash command blocked in {self._current_phase.value} phase " + f"(filesystem-mutating operations not allowed during " + f"planning/verification). Command: {command[:100]}" + ), + "current_phase": self._current_phase.value, + "tool": tool_name, + "is_error": True, + } + return None def _check_tool_loop(self, tool_calls: list[Any]) -> str | None: """检测重复工具调用模式。 @@ -498,6 +604,14 @@ class ReActEngine: if cancellation_token is not None: cancellation_token.check() + # U3/G6: phase auto-advance safety net. + # Incremented per step (LLM call), not per tool_call. When + # auto_advance_after_steps is set, advance the phase after + # the LLM has been stuck in the same phase for N steps. + if self._phase_policy is not None: + self._steps_in_phase += 1 + self._maybe_auto_advance() + # Think: 调用 LLM llm_start = time.monotonic() response = await self._llm_gateway.chat( @@ -1148,6 +1262,11 @@ class ReActEngine: if cancellation_token is not None: cancellation_token.check() + # U3/G6: phase auto-advance safety net (mirrors _execute_loop). + if self._phase_policy is not None: + self._steps_in_phase += 1 + self._maybe_auto_advance() + # 超时检查 if effective_timeout > 0: elapsed = time.monotonic() - _stream_start @@ -2069,6 +2188,20 @@ class ReActEngine: self, tool_name: str, arguments: dict[str, Any], tools: list[Tool] ) -> dict: """执行工具调用,处理成功和失败情况""" + # U3/G6: phase enforcement — check before dispatch. If the tool is + # blocked, return a structured error instead of dispatching. The loop + # still counts this as a step (the LLM gets to react to the rejection). + # `advance_phase` tool bypasses the check (it's the LLM's escape hatch). + if tool_name != "advance_phase": + block = self._check_phase_permission(tool_name, arguments) + if block is not None: + logger.info( + "Phase violation: tool %r blocked in %s phase", + tool_name, + self._current_phase.value if self._current_phase else "?", + ) + return block + tool = self._find_tool(tool_name, tools) if tool is None: error_msg = f"Tool '{tool_name}' not found" diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index fcfe979..da2ee38 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -121,6 +121,9 @@ class ServerConfig: verification: dict[str, Any] | None = None, rollback: dict[str, Any] | None = None, fallback_chain: dict[str, Any] | None = None, + # G6/U2: PLAN_EXEC phase policy config (opt-in — None = disabled). + # Parsed via PhasePolicy.policy_from_config() at chat.py wiring time. + plan_exec: dict[str, Any] | None = None, on_change: Callable[["ServerConfig"], None] | None = None, ): self.host = host @@ -161,6 +164,10 @@ class ServerConfig: # G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries} # controls three-tier chain at chat.py REST send_message (KTD5). self.fallback_chain = fallback_chain or {} + # G6/U2: plan_exec phase policy config (opt-in — empty dict = disabled). + # Resolved to PhasePolicy via agentkit.core.phase.policy_from_config() + # at chat.py WebSocket wiring time (U4). + self.plan_exec = plan_exec or {} self.on_change = on_change # Config watching state @@ -252,6 +259,8 @@ class ServerConfig: rollback_data = data.get("rollback", {}) # G7/U3: fallback_chain 配置 (从 YAML 读取) fallback_chain_data = data.get("fallback_chain", {}) + # G6/U2: plan_exec phase policy 配置 (从 YAML 读取, opt-in) + plan_exec_data = data.get("plan_exec", {}) return cls( host=server.get("host", "0.0.0.0"), @@ -285,6 +294,7 @@ class ServerConfig: verification=verification_data, rollback=rollback_data, fallback_chain=fallback_chain_data, + plan_exec=plan_exec_data, ) @staticmethod diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index 6737740..4b7be7f 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -25,11 +25,13 @@ from fastapi.responses import FileResponse from pydantic import BaseModel from agentkit.chat.skill_routing import ExecutionMode +from agentkit.core.phase import PhasePolicy, default_policy, policy_from_config from agentkit.core.protocol import CancellationToken from agentkit.core.react import ReActEngine from agentkit.server._fallback_chain import execute_with_fallback_chain from agentkit.session.manager import SessionManager from agentkit.session.models import MessageRole, SessionStatus +from agentkit.tools.advance_phase import AdvancePhaseTool logger = logging.getLogger(__name__) @@ -47,6 +49,8 @@ class CreateSessionRequest(BaseModel): class SendMessageRequest(BaseModel): content: str role: str = "user" + # Optional execution mode override. "plan_exec" → 501 (KTD4: WebSocket only). + execution_mode: str | None = None class SessionResponse(BaseModel): @@ -583,6 +587,13 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques if session.status == SessionStatus.CLOSED: raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed") + # KTD4: PLAN_EXEC is wired only at the WebSocket path. REST raises 501. + if request.execution_mode == "plan_exec": + raise HTTPException( + status_code=501, + detail="PLAN_EXEC via REST not yet supported; use WebSocket", + ) + # Append user message await sm.append_message( session_id=session_id, @@ -1079,21 +1090,73 @@ async def _handle_chat_message( await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}}) return - # Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB - # currently fall back to REACT with a warning. - if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT): + # U4/G6: PLAN_EXEC — build PhasePolicy from server config (KTD4: WebSocket only). + # KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC — phase policy and + # fallback chain are mutually exclusive. PLAN_EXEC uses its own engine. + phase_policy: PhasePolicy | None = None + if routing.execution_mode == ExecutionMode.PLAN_EXEC: + server_config = getattr(websocket.app.state, "server_config", None) + plan_exec_cfg = getattr(server_config, "plan_exec", None) or {} + + if plan_exec_cfg.get("enabled", True) is False: + # Explicit opt-out → fall back to REACT. + logger.info( + "PLAN_EXEC disabled by config (plan_exec.enabled=False), " + "falling back to REACT for session %s", + session_id, + ) + else: + try: + phase_policy = policy_from_config(plan_exec_cfg) + if phase_policy is None: + # Empty config (no `plan_exec:` section) → use KTD5 defaults. + phase_policy = default_policy() + except Exception as e: + logger.error( + "PLAN_EXEC phase policy construction failed for session %s: %s", + session_id, + e, + ) + await websocket.send_json( + { + "type": "error", + # Truncate to 200 chars to match nearby error paths and + # avoid leaking config internals (see chat.py:1090, 1320). + "data": {"message": f"phase policy error: {str(e)[:200]}"}, + } + ) + return + + # Handle advanced execution modes: REWOO/REFLEXION/TEAM_COLLAB + # still fall back to REACT with a warning. PLAN_EXEC is handled above. + if routing.execution_mode not in ( + ExecutionMode.REACT, + ExecutionMode.SKILL_REACT, + ExecutionMode.PLAN_EXEC, + ): logger.warning( f"Execution mode {routing.execution_mode.value} not yet supported " f"in chat WebSocket, falling back to REACT" ) # Execute Agent with streaming - # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization) - react_engine = getattr(agent, "_react_engine", None) - if react_engine is None: - react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) + # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization). + # PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the + # agent's _react_engine — it has no policy). + if phase_policy is not None: + react_engine = ReActEngine( + llm_gateway=websocket.app.state.llm_gateway, + phase_policy=phase_policy, + ) + # Register AdvancePhaseTool bound to this engine (LLM's escape hatch). + advance_phase_tool = AdvancePhaseTool(engine=react_engine) + routing.tools = list(routing.tools) + [advance_phase_tool] else: - react_engine.reset() + react_engine = getattr(agent, "_react_engine", None) + if react_engine is None: + react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) + else: + react_engine.reset() # Create confirmation handler that sends request to frontend and waits for reply # Use the same dict object — do NOT use `or {}` because an empty dict is falsy @@ -1149,6 +1212,9 @@ async def _handle_chat_message( try: final_content = "" token_buffer: list[str] = [] + # Track phase transitions for phase_changed events (PLAN_EXEC only). + # For non-PLAN_EXEC modes, current_phase is always None → no events. + prev_phase = react_engine.current_phase async for event in react_engine.execute_stream( messages=chat_messages, tools=routing.tools, @@ -1226,6 +1292,22 @@ async def _handle_chat_message( } ) + # U4/G6: emit phase_changed event when the phase state machine + # transitions (PLAN_EXEC only). For non-PLAN_EXEC modes, + # current_phase is always None → this branch never fires. + curr_phase = react_engine.current_phase + if curr_phase != prev_phase: + await websocket.send_json( + { + "type": "phase_changed", + "data": { + "phase": curr_phase.value if curr_phase else None, + "previous": prev_phase.value if prev_phase else None, + }, + } + ) + prev_phase = curr_phase + # Append assistant reply to session if final_content: await sm.append_message( diff --git a/src/agentkit/tools/__init__.py b/src/agentkit/tools/__init__.py index 315b63e..ea54dcf 100644 --- a/src/agentkit/tools/__init__.py +++ b/src/agentkit/tools/__init__.py @@ -18,6 +18,8 @@ from agentkit.tools.memory_tool import MemoryTool from agentkit.tools.web_search import WebSearchTool from agentkit.tools.builtin import RunTestsTool, ToolSearchTool from agentkit.tools.search import ToolSearchIndex +from agentkit.tools.file_read import ReadFileTool +from agentkit.tools.advance_phase import AdvancePhaseTool # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor try: @@ -52,4 +54,6 @@ __all__ = [ "OutputParser", "ParsedOutput", "ErrorType", + "ReadFileTool", + "AdvancePhaseTool", ] diff --git a/src/agentkit/tools/advance_phase.py b/src/agentkit/tools/advance_phase.py new file mode 100644 index 0000000..b750d3f --- /dev/null +++ b/src/agentkit/tools/advance_phase.py @@ -0,0 +1,82 @@ +"""AdvancePhaseTool — LLM-driven phase transition (G6, KTD6). + +Registered alongside other tools when ReActEngine has a phase_policy set. +The LLM calls this tool to signal "I'm done planning, move to building". +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from agentkit.tools.base import Tool + +if TYPE_CHECKING: + from agentkit.core.react import ReActEngine + +logger = logging.getLogger(__name__) + + +class AdvancePhaseTool(Tool): + """Tool that advances the ReActEngine's current phase. + + KTD6: LLM-driven phase transitions. Auto-advance is opt-in via + ``plan_exec.auto_advance_after_steps``; this tool is the manual path. + + The tool holds a weak reference to the engine (via bound method + ``engine.advance_phase``) — registered only when phase_policy is set. + """ + + def __init__( + self, + engine: "ReActEngine", + name: str = "advance_phase", + description: str | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description + or ( + "Advance the PLAN_EXEC phase state machine to the next phase " + "(Planning → Building → Verification → Delivery). Call this " + "when you have finished the current phase's work and are ready " + "to move on. Returns the new phase name or an error if you " + "are already at the final (Delivery) phase." + ), + input_schema={ + "type": "object", + "properties": {}, + "additionalProperties": False, + }, + version=version, + tags=tags or ["phase", "control"], + ) + self._engine = engine + + async def execute(self, **kwargs) -> dict[str, Any]: + # Capture previous phase before transition (engine is single-threaded per request). + previous = self._engine.current_phase + new_phase = self._engine.advance_phase() + if new_phase is None: + # Either no policy set, or already at DELIVERY. + current = self._engine.current_phase + if current is None: + return { + "is_error": True, + "error": "no_phase_policy", + "message": "No phase policy is set — advance_phase is a no-op.", + } + return { + "is_error": True, + "error": "already_at_final_phase", + "message": (f"Already at final phase ({current.value}). Cannot advance further."), + "current_phase": current.value, + } + return { + "is_error": False, + "previous_phase": previous.value if previous else "", + "current_phase": new_phase.value, + "message": f"Phase advanced to {new_phase.value}.", + } diff --git a/src/agentkit/tools/file_read.py b/src/agentkit/tools/file_read.py new file mode 100644 index 0000000..1f70c96 --- /dev/null +++ b/src/agentkit/tools/file_read.py @@ -0,0 +1,262 @@ +"""ReadFileTool — file reading with optional symbol-level sharding (G5, R22/R23). + +Backward compatible with the pre-existing `_FakeTool` benchmark shape — when +`symbol=None`, returns the full file content. When `symbol="foo"`, returns +the line range of the first matching symbol via `SymbolExtractor`. + +KTD2 (Wave 3 plan): dedicated tool, does NOT extend ShellTool — keeps the +file-reading contract clean and gives the LLM a focused schema. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any + +from agentkit.tools.base import Tool +from agentkit.tools.symbol_extractor import ( + SymbolSpan, + extract_symbols_from_file, + language_for_extension, +) + +logger = logging.getLogger(__name__) + + +class ReadFileTool(Tool): + """Read a file from the filesystem, optionally sliced to a single symbol. + + Tool name `read_file` matches the reserved entry in + `core/react.py:_DEFAULT_CORE_TOOLS` (which previously had no real + implementation — only `_FakeTool` stubs in `cli/benchmark.py`). + + Backward-compat contract: `symbol=None` returns the full file content, + matching the shape `{"path": ...}` that downstream callers (benchmark, + phase whitelist) already expect. + """ + + def __init__( + self, + name: str = "read_file", + description: str | None = None, + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description + or ( + "Read a file from the filesystem. By default returns the full file " + "content. Pass `symbol` (function/class/struct name) to slice to just " + "that symbol's line range — saves context when you only need one " + "function from a large file. Pass `start_line`/`end_line` for manual " + "slicing. If `symbol` is set but not found, returns the available " + "symbol names so you can retry." + ), + input_schema=input_schema or self._default_input_schema(), + output_schema=output_schema or self._default_output_schema(), + version=version, + tags=tags or ["io", "file", "read"], + ) + + @staticmethod + def _default_input_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the file to read (absolute or relative to cwd).", + }, + "symbol": { + "type": "string", + "description": ( + "Optional: name of a function/class/struct/method to slice to. " + "When set, returns only the line range of the first matching " + "symbol. Supported languages: py, ts/tsx, js/jsx, go, rs, java." + ), + }, + "start_line": { + "type": "integer", + "description": "Optional 1-based start line for manual slicing. Overrides `symbol`.", + "minimum": 1, + }, + "end_line": { + "type": "integer", + "description": "Optional 1-based end line (inclusive) for manual slicing. Overrides `symbol`.", + "minimum": 1, + }, + }, + "required": ["path"], + "additionalProperties": False, + } + + @staticmethod + def _default_output_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "content": {"type": "string"}, + "path": {"type": "string"}, + "start_line": {"type": "integer"}, + "end_line": {"type": "integer"}, + "symbol": {"type": "string"}, + "available_symbols": { + "type": "array", + "items": {"type": "string"}, + "description": "Populated when `symbol` is set but not found.", + }, + "note": {"type": "string"}, + "is_error": {"type": "boolean"}, + "error": {"type": "string"}, + }, + } + + async def execute(self, **kwargs) -> dict[str, Any]: + raw_path = kwargs.get("path") + if not raw_path: + return self._error("`path` is required") + + path = Path(raw_path) + if not path.is_absolute(): + path = path.resolve() + + symbol = kwargs.get("symbol") + start_line = kwargs.get("start_line") + end_line = kwargs.get("end_line") + + # Validate/sanitize line overrides. + if start_line is not None and (not isinstance(start_line, int) or start_line < 1): + return self._error(f"`start_line` must be a positive integer, got {start_line!r}") + if end_line is not None and (not isinstance(end_line, int) or end_line < 1): + return self._error(f"`end_line` must be a positive integer, got {end_line!r}") + if start_line is not None and end_line is not None and end_line < start_line: + return self._error(f"`end_line` ({end_line}) must be >= `start_line` ({start_line})") + + # Filesystem checks. + if not path.exists(): + return self._error(f"File not found: {path}", path=str(path)) + if path.is_dir(): + return self._error(f"Path is a directory, not a file: {path}", path=str(path)) + + try: + content = path.read_text(encoding="utf-8", errors="replace") + except PermissionError as e: + return self._error(f"Permission denied: {path}", path=str(path), detail=str(e)) + except OSError as e: + return self._error(f"Failed to read {path}: {e}", path=str(path)) + + lines = content.splitlines() + total_lines = len(lines) + + # Manual slicing takes precedence over symbol (per plan U1 Approach). + if start_line is not None or end_line is not None: + s = max(1, start_line or 1) + e = min(total_lines, end_line or total_lines) + sliced = "\n".join(lines[s - 1 : e]) + return { + "content": sliced, + "path": str(path), + "start_line": s, + "end_line": e, + "total_lines": total_lines, + "is_error": False, + } + + # Symbol slicing. + if symbol: + ext = path.suffix.lower() + language = language_for_extension(ext) + if not language: + # Unsupported extension: return full file with note (per plan U1 Edge case). + return { + "content": content, + "path": str(path), + "start_line": 1, + "end_line": total_lines, + "total_lines": total_lines, + "note": f"symbol extraction not supported for {ext or 'unknown extension'}", + "is_error": False, + } + + spans, _lang = extract_symbols_from_file(path) + # Re-extract using the content we already read so we don't read the file twice. + if not spans: + # Try extraction from in-memory content (path-based extraction may + # have failed silently on OSError; we already read it successfully). + from agentkit.tools.symbol_extractor import get_extractor + + extractor = get_extractor(language) + if extractor is not None: + spans = extractor.extract_symbols(content, language) + + match = _find_symbol(spans, symbol) + if match is None: + available = sorted({s.name for s in spans}) + return { + "content": "", + "path": str(path), + "symbol": symbol, + "available_symbols": available, + "is_error": False, + "note": ( + f"Symbol {symbol!r} not found in {path.name}. " + f"Available: {', '.join(available) if available else '(none)'}" + ), + } + + s = match.start_line + e = min(match.end_line, total_lines) + sliced = "\n".join(lines[s - 1 : e]) + return { + "content": sliced, + "path": str(path), + "symbol": symbol, + "symbol_kind": match.kind, + "start_line": s, + "end_line": e, + "total_lines": total_lines, + "is_error": False, + } + + # Default: full file (characterization baseline — matches _FakeTool shape). + return { + "content": content, + "path": str(path), + "start_line": 1, + "end_line": total_lines, + "total_lines": total_lines, + "is_error": False, + } + + @staticmethod + def _error( + message: str, *, path: str | None = None, detail: str | None = None + ) -> dict[str, Any]: + result: dict[str, Any] = { + "content": "", + "is_error": True, + "error": message, + } + if path is not None: + result["path"] = path + if detail is not None: + result["detail"] = detail + return result + + +def _find_symbol(spans: list[SymbolSpan], name: str) -> SymbolSpan | None: + """Find the first symbol matching `name`. Case-sensitive. + + ponytail: linear scan is fine for typical file symbol counts (<100). The + extractor already returns symbols sorted by start_line; first match wins + for ambiguous overloads (e.g., Python classes with same name in different + modules — not relevant within one file). + """ + for span in spans: + if span.name == name: + return span + return None diff --git a/src/agentkit/tools/symbol_extractor.py b/src/agentkit/tools/symbol_extractor.py new file mode 100644 index 0000000..93f64d3 --- /dev/null +++ b/src/agentkit/tools/symbol_extractor.py @@ -0,0 +1,278 @@ +"""Symbol extraction — locate code symbols (functions/classes/structs) by name. + +KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex +for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The +`SymbolExtractor` protocol is the upgrade seam — a future TreeSitterSymbolExtractor +can replace RegexSymbolExtractor behind the same interface. + +ponytail: regex extractor covers ~80% case (top-level function/class/struct +declarations). Ceiling: misses nested signatures inside JSX/TSX generics, +multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter. +""" + +from __future__ import annotations + +import ast +import logging +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Protocol, runtime_checkable + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, slots=True) +class SymbolSpan: + """A located symbol — name, kind, and 1-based inclusive line range.""" + + name: str + kind: str # "function" | "class" | "method" | "struct" | "impl" + start_line: int # 1-based, inclusive + end_line: int # 1-based, inclusive + + +@runtime_checkable +class SymbolExtractor(Protocol): + """Protocol for symbol extractors — runtime_checkable for isinstance/issubclass.""" + + def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: + """Return all symbols found in `content`. + + `language` is the file extension without leading dot (e.g. "py", "ts"). + Implementations must never raise on extraction failure — return [] on + parse errors and let the caller decide the fallback (full-file read). + """ + ... + + +# --------------------------------------------------------------------------- +# Python — stdlib ast +# --------------------------------------------------------------------------- + + +class AstSymbolExtractor: + """Python symbol extractor using the stdlib `ast` module. + + Captures top-level FunctionDef/AsyncFunctionDef/ClassDef and methods/nested + functions inside classes. The end_line is the last line of the node's + source segment (decorator-inclusive). + """ + + def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: + if language != "py": + return [] + try: + tree = ast.parse(content) + except SyntaxError as e: + logger.debug("ast.parse failed: %s", e) + return [] + + lines = content.splitlines() + spans: list[SymbolSpan] = [] + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + kind = "method" if _is_method(node) else "function" + spans.append(_span_from_node(node, kind, lines)) + elif isinstance(node, ast.ClassDef): + spans.append(_span_from_node(node, "class", lines)) + return spans + + +def _is_method(node: ast.AST) -> bool: + """A FunctionDef is a method if its parent is a ClassDef. + + `ast.walk` doesn't expose parentage, so we approximate by checking the + node's col_offset == 4 (indented inside a class body). ponytail: this + misses methods in deeply nested classes — ceiling noted; upgrade path = + ast.NodeVisitor with parent tracking. + """ + return getattr(node, "col_offset", 0) > 0 + + +def _span_from_node( + node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef, + kind: str, + lines: list[str], +) -> SymbolSpan: + # ast line numbers are 1-based; start at decorator if present (lineno points + # to the def/class keyword, decorators are above). Use node.lineno for start + # so the returned range matches what the user sees at the def keyword. + start = node.lineno + # node.end_lineno is the last line of the node body (None on old Pythons). + end = node.end_lineno or start + # Clamp to actual file length (defensive — ast should not exceed, but + # malformed files with no trailing newline can confuse end_lineno). + if end > len(lines): + end = len(lines) + return SymbolSpan(name=node.name, kind=kind, start_line=start, end_line=end) + + +# --------------------------------------------------------------------------- +# Regex extractor — TS/JS/Go/Rust/Java +# --------------------------------------------------------------------------- + +# Each pattern matches a declaration and captures the symbol name in group 1. +# Patterns use re.MULTILINE so ^ matches line starts. +_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = { + "ts": [ + ( + "function", + re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE), + ), + ("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)), + ( + "function", + re.compile( + r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", + re.MULTILINE, + ), + ), + ], + "js": [ + ("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)), + ("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)), + ( + "function", + re.compile( + r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE + ), + ), + ], + "go": [ + ("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)), + ("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)), + ], + "rs": [ + ("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)), + ("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)), + ("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)), + ], + "java": [ + ( + "function", + re.compile( + r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{", + re.MULTILINE, + ), + ), + ( + "class", + re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE), + ), + ], +} + + +class RegexSymbolExtractor: + """Language-aware regex symbol extractor for TS/JS/Go/Rust/Java. + + Returns SymbolSpans whose end_line is approximated by the next blank line + or next-symbol start (whichever comes first). ponytail: this is an + approximation — true block-end requires language-aware brace matching. + Ceiling: deeply nested blocks may over-extend the range. Upgrade path = + tree-sitter. + """ + + def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: + patterns = _REGEX_PATTERNS.get(language) + if not patterns: + return [] + + lines = content.splitlines() + # Collect (line_no, name, kind) tuples first, then compute end_line + # as the line before the next symbol starts (or EOF). + raw_hits: list[tuple[int, str, str]] = [] + for kind, pattern in patterns: + for m in pattern.finditer(content): + # Convert match offset to 1-based line number. + line_no = content[: m.start()].count("\n") + 1 + raw_hits.append((line_no, m.group(1), kind)) + + if not raw_hits: + return [] + + # Deduplicate: same (line_no, name) may appear for overlapping patterns. + seen: set[tuple[int, str]] = set() + unique: list[tuple[int, str, str]] = [] + for line_no, name, kind in raw_hits: + key = (line_no, name) + if key in seen: + continue + seen.add(key) + unique.append((line_no, name, kind)) + + unique.sort(key=lambda x: x[0]) + + spans: list[SymbolSpan] = [] + for i, (start_line, name, kind) in enumerate(unique): + if i + 1 < len(unique): + # End at line before next symbol starts, capped at file length. + end_line = unique[i + 1][0] - 1 + else: + end_line = len(lines) + if end_line < start_line: + end_line = start_line + spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line)) + return spans + + +# --------------------------------------------------------------------------- +# Dispatch by file extension +# --------------------------------------------------------------------------- + +_EXTENSION_LANGUAGE: dict[str, str] = { + ".py": "py", + ".ts": "ts", + ".tsx": "ts", + ".js": "js", + ".jsx": "js", + ".mjs": "js", + ".cjs": "js", + ".go": "go", + ".rs": "rs", + ".java": "java", +} + +_DEFAULT_EXTRACTOR = AstSymbolExtractor() +_REGEX_EXTRACTOR = RegexSymbolExtractor() + + +def language_for_extension(ext: str) -> str: + """Return the language key for a file extension (with or without leading dot). + + Returns "" for unsupported extensions. + """ + if not ext.startswith("."): + ext = "." + ext + return _EXTENSION_LANGUAGE.get(ext.lower(), "") + + +def get_extractor(language: str) -> SymbolExtractor | None: + """Return the appropriate extractor for `language`, or None if unsupported.""" + if language == "py": + return _DEFAULT_EXTRACTOR + if language in _REGEX_PATTERNS: + return _REGEX_EXTRACTOR + return None + + +def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]: + """Read a file and return (symbols, language). + + Returns ([], "") if the extension is unsupported or the file cannot be read. + Never raises — callers use this for fallback routing. + """ + ext = path.suffix.lower() + language = language_for_extension(ext) + if not language: + return [], "" + try: + content = path.read_text(encoding="utf-8", errors="replace") + except OSError as e: + logger.debug("read failed for %s: %s", path, e) + return [], language + extractor = get_extractor(language) + if extractor is None: + return [], language + return extractor.extract_symbols(content, language), language diff --git a/tests/unit/test_chat_plan_exec_ws.py b/tests/unit/test_chat_plan_exec_ws.py new file mode 100644 index 0000000..84fe358 --- /dev/null +++ b/tests/unit/test_chat_plan_exec_ws.py @@ -0,0 +1,531 @@ +"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4). + +Per plan U4 Execution note: characterization-first — verify that existing +REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning +(no regression). Then add PLAN_EXEC wiring tests. + +KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi.testclient import TestClient + +from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult +from agentkit.core.phase import PhaseState +from agentkit.tools.advance_phase import AdvancePhaseTool + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def app_with_chat(): + """Create a FastAPI app with Chat routes and mocked dependencies.""" + from fastapi import FastAPI + + from agentkit.server.routes.chat import router + + app = FastAPI() + app.include_router(router, prefix="/api/v1") + + from agentkit.session.manager import SessionManager + from agentkit.session.store import InMemorySessionStore + + app.state.session_manager = SessionManager(store=InMemorySessionStore()) + app.state.llm_gateway = MagicMock() + app.state.agent_pool = MagicMock() + app.state.server_config = MagicMock() + app.state.server_config.api_key = None + app.state.server_config.plan_exec = {} + return app + + +@pytest.fixture +def client(app_with_chat): + return TestClient(app_with_chat) + + +def _make_routing( + execution_mode: ExecutionMode = ExecutionMode.REACT, + tools: list | None = None, +) -> SkillRoutingResult: + """Build a minimal SkillRoutingResult for testing.""" + return SkillRoutingResult( + execution_mode=execution_mode, + tools=tools or [], + clean_content="test message", + model="default", + agent_name="test-agent", + system_prompt=None, + skill_name=None, + ) + + +def _make_websocket_mock(app) -> MagicMock: + """Build a mock WebSocket with app.state and async send_json.""" + ws = MagicMock() + ws.app = app + ws.send_json = AsyncMock() + return ws + + +def _make_agent_mock() -> MagicMock: + """Build a mock Agent with _tool_registry and _react_engine.""" + agent = MagicMock() + agent.name = "test-agent" + agent._tool_registry = MagicMock() + agent._tool_registry.list_tools.return_value = [] + agent._system_prompt = None + # _react_engine is None to force the code path that creates a new engine + agent._react_engine = None + agent.get_model.return_value = "default" + return agent + + +def _make_session_manager_mock() -> MagicMock: + """Build a mock SessionManager with async methods.""" + sm = MagicMock() + # get_session returns a mock session with agent_name="test-agent" + session = MagicMock() + session.agent_name = "test-agent" + session.status = "active" + sm.get_session = AsyncMock(return_value=session) + sm.get_chat_messages = AsyncMock(return_value=[]) + sm.append_message = AsyncMock() + return sm + + +def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None: + """Wire up app.state so _handle_chat_message finds the right routing.""" + app.state.agent_pool.get_agent.return_value = agent + app.state.request_preprocessor = MagicMock() + app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing) + + +# --------------------------------------------------------------------------- +# REST — PLAN_EXEC raises 501 (KTD4) +# --------------------------------------------------------------------------- + + +class TestRestPlanExec501: + def test_rest_plan_exec_returns_501(self, client): + """REST send_message with execution_mode=plan_exec → 501.""" + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + msg_resp = client.post( + f"/api/v1/chat/sessions/{session_id}/messages", + json={"content": "Hello", "execution_mode": "plan_exec"}, + ) + assert msg_resp.status_code == 501 + assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"] + + def test_rest_react_mode_still_works(self, client): + """REST send_message without execution_mode doesn't 501.""" + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + # No execution_mode field → should NOT trigger 501. + msg_resp = client.post( + f"/api/v1/chat/sessions/{session_id}/messages", + json={"content": "Hello"}, + ) + assert msg_resp.status_code != 501 + + +# --------------------------------------------------------------------------- +# Characterization — REWOO still falls back to REACT (no regression) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_rewoo_still_falls_back_to_react_without_phase_policy(app_with_chat): + """Characterization: REWOO via WebSocket → no phase_policy (falls back to REACT).""" + from agentkit.server.routes import chat as chat_module + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.REWOO) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + ws = _make_websocket_mock(app_with_chat) + + captured_engine_kwargs: dict = {} + + class _StubEngine: + def __init__(self, **kwargs): + captured_engine_kwargs.update(kwargs) + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = None + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + return + yield # async generator marker + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + # REWOO should NOT build a phase_policy + assert captured_engine_kwargs.get("phase_policy") is None + + +# --------------------------------------------------------------------------- +# Happy path — PLAN_EXEC builds phase policy + registers AdvancePhaseTool +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_plan_exec_builds_phase_policy_and_registers_advance_phase_tool( + app_with_chat, +): + """PLAN_EXEC via WebSocket → engine with phase_policy, AdvancePhaseTool registered.""" + from agentkit.server.routes import chat as chat_module + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "test"}]) + ws = _make_websocket_mock(app_with_chat) + + captured_engine: list = [] + captured_tools: list = [] + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = ( + kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None + ) + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + captured_tools.extend(kwargs.get("tools", [])) + captured_engine.append(self) + return + yield + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + assert len(captured_engine) == 1 + engine = captured_engine[0] + assert engine._phase_policy is not None + assert engine._current_phase == PhaseState.PLANNING + # AdvancePhaseTool was registered in the tools list + assert any(isinstance(t, AdvancePhaseTool) for t in captured_tools) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_plan_exec_empty_config_uses_default_policy(app_with_chat): + """Edge: plan_exec config absent (empty dict) → default_policy() used.""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + ws = _make_websocket_mock(app_with_chat) + + captured_policy: list = [] + + class _StubEngine: + def __init__(self, **kwargs): + captured_policy.append(kwargs.get("phase_policy")) + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = ( + kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None + ) + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + return + yield + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + assert len(captured_policy) == 1 + assert captured_policy[0] is not None + # Default policy: PLANNING allows search but not write_file + assert "search" in captured_policy[0].whitelist[PhaseState.PLANNING] + assert "write_file" not in captured_policy[0].whitelist[PhaseState.PLANNING] + + +@pytest.mark.asyncio +async def test_plan_exec_disabled_falls_back_to_react(app_with_chat): + """Edge: plan_exec.enabled=False → falls back to REACT (no phase_policy).""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {"enabled": False} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + ws = _make_websocket_mock(app_with_chat) + + captured_engine_kwargs: dict = {} + + class _StubEngine: + def __init__(self, **kwargs): + captured_engine_kwargs.update(kwargs) + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = None + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + return + yield + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + # enabled=False → no phase_policy (falls back to REACT) + assert captured_engine_kwargs.get("phase_policy") is None + + +# --------------------------------------------------------------------------- +# Error path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_plan_exec_bad_config_sends_error_and_returns(app_with_chat): + """Error: phase policy construction fails → error event sent, returns early.""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + ws = _make_websocket_mock(app_with_chat) + + engine_constructor_called = [] + + class _StubEngine: + def __init__(self, **kwargs): + engine_constructor_called.append(kwargs) + + async def execute_stream(self, **kwargs): + yield + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + error_messages = [m for m in sent_messages if m.get("type") == "error"] + assert len(error_messages) == 1 + assert "phase policy error" in error_messages[0]["data"]["message"] + # Engine constructor was NOT called (returned early) + assert len(engine_constructor_called) == 0 + + +# --------------------------------------------------------------------------- +# phase_changed event emission +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_phase_changed_event_emitted_on_transition(app_with_chat): + """phase_changed event sent when current_phase changes during execute_stream.""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}]) + ws = _make_websocket_mock(app_with_chat) + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = PhaseState.PLANNING + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + from agentkit.core.react import ReActEvent + + yield ReActEvent( + event_type="tool_call", + step=1, + data={"tool": "search", "output": "ok"}, + ) + # Simulate phase transition (as if AdvancePhaseTool was called) + self._current_phase = PhaseState.BUILDING + yield ReActEvent( + event_type="final_answer", + step=2, + data={"output": "done"}, + ) + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="go", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"] + assert len(phase_events) == 1 + assert phase_events[0]["data"]["phase"] == "building" + assert phase_events[0]["data"]["previous"] == "planning" + + +@pytest.mark.asyncio +async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat): + """Characterization: REACT mode → no phase_changed events.""" + from agentkit.server.routes import chat as chat_module + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.REACT) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}]) + ws = _make_websocket_mock(app_with_chat) + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = None + self._current_phase = None + + @property + def current_phase(self): + return None + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + from agentkit.core.react import ReActEvent + + yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"}) + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="hi", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"] + assert len(phase_events) == 0 diff --git a/tests/unit/test_phase_policy.py b/tests/unit/test_phase_policy.py new file mode 100644 index 0000000..936e823 --- /dev/null +++ b/tests/unit/test_phase_policy.py @@ -0,0 +1,348 @@ +"""Unit tests for PhasePolicy + PhaseState (G6 core, R24/R25/R26). + +Covers: +- PhaseState enum (next_of, from_string) +- default_policy() KTD5 whitelist +- PhasePolicy.is_tool_allowed / is_bash_command_allowed +- policy_from_config parsing (R26 config-driven) +- ServerConfig.plan_exec integration +""" + +from __future__ import annotations + +import re + +import pytest + +from agentkit.core.phase import ( + WILDCARD, + PhasePolicy, + PhaseState, + default_policy, + policy_from_config, +) +from agentkit.server.config import ServerConfig + + +# --------------------------------------------------------------------------- +# PhaseState enum +# --------------------------------------------------------------------------- + + +class TestPhaseState: + def test_values(self): + assert PhaseState.PLANNING.value == "planning" + assert PhaseState.BUILDING.value == "building" + assert PhaseState.VERIFICATION.value == "verification" + assert PhaseState.DELIVERY.value == "delivery" + + def test_next_of(self): + assert PhaseState.next_of(PhaseState.PLANNING) == PhaseState.BUILDING + assert PhaseState.next_of(PhaseState.BUILDING) == PhaseState.VERIFICATION + assert PhaseState.next_of(PhaseState.VERIFICATION) == PhaseState.DELIVERY + assert PhaseState.next_of(PhaseState.DELIVERY) is None + + def test_from_string_case_insensitive(self): + assert PhaseState.from_string("planning") == PhaseState.PLANNING + assert PhaseState.from_string("PLANNING") == PhaseState.PLANNING + assert PhaseState.from_string("Building") == PhaseState.BUILDING + + def test_from_string_invalid_raises(self): + with pytest.raises(ValueError, match="Invalid phase name"): + PhaseState.from_string("unknown") + with pytest.raises(ValueError, match="Valid:"): + PhaseState.from_string("exploration") + + +# --------------------------------------------------------------------------- +# default_policy() — KTD5 whitelist +# --------------------------------------------------------------------------- + + +class TestDefaultPolicy: + def test_has_all_four_phases(self): + policy = default_policy() + assert PhaseState.PLANNING in policy.whitelist + assert PhaseState.BUILDING in policy.whitelist + assert PhaseState.VERIFICATION in policy.whitelist + assert PhaseState.DELIVERY in policy.whitelist + + def test_planning_whitelist_matches_r24(self): + policy = default_policy() + allowed = policy.whitelist[PhaseState.PLANNING] + assert "search" in allowed + assert "read_file" in allowed + assert "shell" in allowed + assert "tool_search" in allowed + # Planning must NOT allow write_file. + assert "write_file" not in allowed + + def test_building_whitelist_includes_write_file(self): + policy = default_policy() + allowed = policy.whitelist[PhaseState.BUILDING] + assert "write_file" in allowed + assert "shell" in allowed + assert "read_file" in allowed + + def test_verification_whitelist_excludes_write(self): + policy = default_policy() + allowed = policy.whitelist[PhaseState.VERIFICATION] + assert "shell" in allowed + assert "read_file" in allowed + assert "write_file" not in allowed + + def test_delivery_wildcard(self): + policy = default_policy() + allowed = policy.whitelist[PhaseState.DELIVERY] + assert WILDCARD in allowed + + def test_start_phase_default_planning(self): + assert default_policy().start_phase == PhaseState.PLANNING + + def test_auto_advance_default_none(self): + # KTD6: manual by default. + assert default_policy().auto_advance_after_steps is None + + def test_bash_filter_blocks_rm_in_planning(self): + policy = default_policy() + assert policy.is_bash_command_allowed("ls -la", PhaseState.PLANNING) is True + assert policy.is_bash_command_allowed("git status", PhaseState.PLANNING) is True + assert policy.is_bash_command_allowed("rm -rf /tmp/x", PhaseState.PLANNING) is False + assert policy.is_bash_command_allowed("echo x > file.txt", PhaseState.PLANNING) is False + + def test_bash_filter_no_restriction_in_building(self): + policy = default_policy() + assert policy.is_bash_command_allowed("rm -rf build/", PhaseState.BUILDING) is True + assert policy.is_bash_command_allowed("echo x > out.log", PhaseState.BUILDING) is True + + +# --------------------------------------------------------------------------- +# PhasePolicy — is_tool_allowed +# --------------------------------------------------------------------------- + + +class TestIsToolAllowed: + def test_planning_allows_search(self): + policy = default_policy() + assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True + + def test_planning_blocks_write_file(self): + policy = default_policy() + assert policy.is_tool_allowed("write_file", PhaseState.PLANNING) is False + + def test_building_allows_write_file(self): + policy = default_policy() + assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True + + def test_delivery_wildcard_allows_anything(self): + policy = default_policy() + assert policy.is_tool_allowed("any_random_tool", PhaseState.DELIVERY) is True + assert policy.is_tool_allowed("write_file", PhaseState.DELIVERY) is True + + def test_unknown_phase_returns_false(self): + # ponytail: unknown phase → empty whitelist → no tool allowed. + # We can't construct an unknown PhaseState (enum), but if a phase + # were missing from the whitelist dict, is_tool_allowed should + # return False (defensive). + policy = PhasePolicy( + whitelist={ + PhaseState.PLANNING: frozenset({"search"}), + PhaseState.BUILDING: frozenset({"write_file"}), + PhaseState.VERIFICATION: frozenset({"shell"}), + PhaseState.DELIVERY: frozenset({WILDCARD}), + } + ) + # BUILDING is in whitelist, so allowed checks work normally. + assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True + # Phase missing from whitelist would return False (defensive .get default). + # We test this by constructing a minimal policy. + minimal = PhasePolicy( + whitelist={ + PhaseState.PLANNING: frozenset({WILDCARD}), + PhaseState.BUILDING: frozenset({WILDCARD}), + PhaseState.VERIFICATION: frozenset({WILDCARD}), + PhaseState.DELIVERY: frozenset({WILDCARD}), + } + ) + # VERIFICATION is in whitelist — wildcard allows all. + assert minimal.is_tool_allowed("anything", PhaseState.VERIFICATION) is True + + +# --------------------------------------------------------------------------- +# PhasePolicy — edge cases & errors +# --------------------------------------------------------------------------- + + +class TestPhasePolicyEdgeCases: + def test_empty_whitelist_raises(self): + # Fail-fast: an empty whitelist for a non-wildcard phase is a bug. + with pytest.raises(ValueError, match="empty whitelist"): + PhasePolicy( + whitelist={ + PhaseState.PLANNING: frozenset(), # empty! + PhaseState.BUILDING: frozenset({WILDCARD}), + PhaseState.VERIFICATION: frozenset({WILDCARD}), + PhaseState.DELIVERY: frozenset({WILDCARD}), + } + ) + + def test_wildcard_only_does_not_raise(self): + # Wildcard-only whitelist is valid (means "all tools"). + policy = PhasePolicy( + whitelist={ + PhaseState.PLANNING: frozenset({WILDCARD}), + PhaseState.BUILDING: frozenset({WILDCARD}), + PhaseState.VERIFICATION: frozenset({WILDCARD}), + PhaseState.DELIVERY: frozenset({WILDCARD}), + } + ) + assert policy.is_tool_allowed("anything", PhaseState.PLANNING) is True + + def test_to_dict_serializable(self): + policy = default_policy() + d = policy.to_dict() + assert "whitelist" in d + assert "planning" in d["whitelist"] + assert "delivery" in d["whitelist"] + assert d["start_phase"] == "planning" + assert d["auto_advance_after_steps"] is None + + def test_custom_bash_filter(self): + custom_filter = re.compile(r"\b(pip install|npm install)\b") + policy = PhasePolicy( + whitelist={ + PhaseState.PLANNING: frozenset({"shell"}), + PhaseState.BUILDING: frozenset({"shell"}), + PhaseState.VERIFICATION: frozenset({"shell"}), + PhaseState.DELIVERY: frozenset({WILDCARD}), + }, + bash_command_filter={PhaseState.BUILDING: custom_filter}, + ) + assert policy.is_bash_command_allowed("npm install foo", PhaseState.BUILDING) is False + assert policy.is_bash_command_allowed("npm run build", PhaseState.BUILDING) is True + + +# --------------------------------------------------------------------------- +# policy_from_config — R26 (config-driven) +# --------------------------------------------------------------------------- + + +class TestPolicyFromConfig: + def test_empty_config_returns_none(self): + assert policy_from_config({}) is None + + def test_enabled_false_returns_none(self): + # Opt-out — explicit `enabled: false` disables policy. + result = policy_from_config({"enabled": False}) + assert result is None + + def test_enabled_default_true_when_section_present(self): + # When section is present but `enabled` is missing, default is True. + result = policy_from_config({"auto_advance_after_steps": 3}) + assert result is not None + assert result.auto_advance_after_steps == 3 + + def test_auto_advance_after_steps(self): + policy = policy_from_config({"enabled": True, "auto_advance_after_steps": 5}) + assert policy is not None + assert policy.auto_advance_after_steps == 5 + + def test_start_phase_custom(self): + policy = policy_from_config({"enabled": True, "start_phase": "building"}) + assert policy is not None + assert policy.start_phase == PhaseState.BUILDING + + def test_start_phase_invalid_raises(self): + with pytest.raises(ValueError, match="Invalid phase name"): + policy_from_config({"enabled": True, "start_phase": "unknown"}) + + def test_whitelist_override_merges_with_default(self): + policy = policy_from_config( + { + "enabled": True, + "whitelist_override": { + "planning": ["search", "read_file"], # removes shell from default + }, + } + ) + assert policy is not None + # Override wins — shell should be removed from planning. + assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True + assert policy.is_tool_allowed("read_file", PhaseState.PLANNING) is True + assert policy.is_tool_allowed("shell", PhaseState.PLANNING) is False + # Other phases unchanged. + assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True + + def test_whitelist_override_invalid_phase_raises(self): + with pytest.raises(ValueError, match="Invalid phase name"): + policy_from_config( + { + "enabled": True, + "whitelist_override": {"unknown_phase": ["tool"]}, + } + ) + + def test_whitelist_override_non_list_raises(self): + with pytest.raises(ValueError, match="must be a list"): + policy_from_config( + { + "enabled": True, + "whitelist_override": {"planning": "not a list"}, + } + ) + + def test_to_dict_round_trip_via_default(self): + # Sanity: default policy serializes to a dict with expected keys. + policy = default_policy() + d = policy.to_dict() + assert set(d["whitelist"].keys()) == { + "planning", + "building", + "verification", + "delivery", + } + + +# --------------------------------------------------------------------------- +# ServerConfig.plan_exec integration (R26) +# --------------------------------------------------------------------------- + + +class TestServerConfigPlanExec: + def test_default_plan_exec_empty(self): + config = ServerConfig.from_dict({}) + assert config.plan_exec == {} + + def test_plan_exec_loaded_from_dict(self): + config = ServerConfig.from_dict( + { + "plan_exec": { + "enabled": True, + "auto_advance_after_steps": 5, + } + } + ) + assert config.plan_exec == {"enabled": True, "auto_advance_after_steps": 5} + + def test_plan_exec_empty_dict_default(self): + config = ServerConfig.from_dict({"plan_exec": {}}) + assert config.plan_exec == {} + + def test_plan_exec_resolved_to_policy(self): + # Wire the config dict through policy_from_config to verify integration. + config = ServerConfig.from_dict( + { + "plan_exec": { + "enabled": True, + "auto_advance_after_steps": 3, + } + } + ) + policy = policy_from_config(config.plan_exec) + assert policy is not None + assert policy.auto_advance_after_steps == 3 + + def test_plan_exec_disabled_via_config(self): + config = ServerConfig.from_dict({"plan_exec": {"enabled": False}}) + policy = policy_from_config(config.plan_exec) + assert policy is None diff --git a/tests/unit/test_react_phase_enforcement.py b/tests/unit/test_react_phase_enforcement.py new file mode 100644 index 0000000..ac9d681 --- /dev/null +++ b/tests/unit/test_react_phase_enforcement.py @@ -0,0 +1,339 @@ +"""Unit tests for ReActEngine phase enforcement (G6 wiring, R24). + +Per plan U3 Execution note: characterization-first — verify that +`ReActEngine(phase_policy=None)` behaves identically to pre-change (no +enforcement, no advance_phase tool, no _current_phase mutation). Then add +enforcement tests. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentkit.core.phase import PhasePolicy, PhaseState, default_policy +from agentkit.core.react import ReActEngine +from agentkit.tools.advance_phase import AdvancePhaseTool + + +# --------------------------------------------------------------------------- +# Characterization — phase_policy=None preserves existing behavior +# --------------------------------------------------------------------------- + + +class TestCharacterizationNoPolicy: + """When phase_policy=None, no enforcement happens and behavior matches + pre-Wave-3.""" + + def test_init_without_phase_policy(self): + # Minimal stub LLM gateway — we're only testing constructor. + gateway = MagicMock() + engine = ReActEngine(llm_gateway=gateway) + assert engine._phase_policy is None + assert engine._current_phase is None + assert engine._steps_in_phase == 0 + assert engine.current_phase is None + + @pytest.mark.asyncio + async def test_execute_tool_dispatches_without_phase_check(self): + """Tool dispatch proceeds normally when no policy set.""" + gateway = MagicMock() + engine = ReActEngine(llm_gateway=gateway) + + # MagicMock.name is a special attribute used internally by Mock for + # repr — setting it post-construction does not make mock.name == "x" + # hold. Patch _find_tool directly to bypass the name lookup. + fake_tool = MagicMock() + fake_tool.safe_execute = AsyncMock(return_value={"output": "ok"}) + fake_tool.input_schema = None + engine._find_tool = lambda name, tools: fake_tool + + result = await engine._execute_tool("any_tool", {"x": 1}, [fake_tool]) + assert result == {"output": "ok"} + fake_tool.safe_execute.assert_awaited_once_with(x=1) + + @pytest.mark.asyncio + async def test_advance_phase_returns_none_without_policy(self): + gateway = MagicMock() + engine = ReActEngine(llm_gateway=gateway) + assert engine.advance_phase() is None + + def test_reset_does_not_touch_phase_state_when_no_policy(self): + gateway = MagicMock() + engine = ReActEngine(llm_gateway=gateway) + engine.reset() + assert engine._current_phase is None + + +# --------------------------------------------------------------------------- +# Initialization with phase_policy +# --------------------------------------------------------------------------- + + +class TestPhasePolicyInitialization: + def test_phase_policy_set_initializes_current_phase(self): + gateway = MagicMock() + engine = ReActEngine( + llm_gateway=gateway, + phase_policy=default_policy(), + ) + assert engine._phase_policy is not None + assert engine._current_phase == PhaseState.PLANNING + assert engine._steps_in_phase == 0 + + def test_reset_resets_phase_to_start(self): + gateway = MagicMock() + engine = ReActEngine( + llm_gateway=gateway, + phase_policy=default_policy(), + ) + # Manually move phase forward (simulating execute progress). + engine.advance_phase() # PLANNING → BUILDING + assert engine._current_phase == PhaseState.BUILDING + engine._steps_in_phase = 5 + + engine.reset() + assert engine._current_phase == PhaseState.PLANNING + assert engine._steps_in_phase == 0 + + +# --------------------------------------------------------------------------- +# advance_phase() transitions +# --------------------------------------------------------------------------- + + +class TestAdvancePhase: + @pytest.fixture + def engine(self): + return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy()) + + def test_planning_to_building(self, engine): + new_phase = engine.advance_phase() + assert new_phase == PhaseState.BUILDING + assert engine.current_phase == PhaseState.BUILDING + assert engine._steps_in_phase == 0 # counter reset on transition + + def test_building_to_verification(self, engine): + engine.advance_phase() # → BUILDING + new_phase = engine.advance_phase() + assert new_phase == PhaseState.VERIFICATION + assert engine.current_phase == PhaseState.VERIFICATION + + def test_verification_to_delivery(self, engine): + engine.advance_phase() # → BUILDING + engine.advance_phase() # → VERIFICATION + new_phase = engine.advance_phase() + assert new_phase == PhaseState.DELIVERY + assert engine.current_phase == PhaseState.DELIVERY + + def test_delivery_returns_none(self, engine): + engine.advance_phase() # → BUILDING + engine.advance_phase() # → VERIFICATION + engine.advance_phase() # → DELIVERY + result = engine.advance_phase() + assert result is None + assert engine.current_phase == PhaseState.DELIVERY + + +# --------------------------------------------------------------------------- +# _check_phase_permission — whitelist enforcement +# --------------------------------------------------------------------------- + + +class TestPhasePermission: + @pytest.fixture + def engine(self): + return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy()) + + def test_search_allowed_in_planning(self, engine): + assert engine._check_phase_permission("search", {}) is None + + def test_write_file_blocked_in_planning(self, engine): + result = engine._check_phase_permission("write_file", {}) + assert result is not None + assert result["error"] == "phase_violation" + assert "write_file" in result["message"] + assert result["current_phase"] == "planning" + + def test_write_file_allowed_in_building(self, engine): + engine.advance_phase() # → BUILDING + assert engine._check_phase_permission("write_file", {}) is None + + def test_any_tool_allowed_in_delivery(self, engine): + engine.advance_phase() # → BUILDING + engine.advance_phase() # → VERIFICATION + engine.advance_phase() # → DELIVERY + assert engine._check_phase_permission("literally_anything", {}) is None + + def test_bash_command_filter_blocks_rm_in_planning(self, engine): + result = engine._check_phase_permission("shell", {"command": "rm -rf /tmp"}) + assert result is not None + assert result["error"] == "phase_violation" + assert "rm" in result["message"] or "Bash command" in result["message"] + + def test_bash_command_filter_allows_safe_in_planning(self, engine): + # `ls` and `git status` are not blocked. + assert engine._check_phase_permission("shell", {"command": "ls -la"}) is None + assert engine._check_phase_permission("shell", {"command": "git status"}) is None + + def test_bash_command_filter_no_restriction_in_building(self, engine): + engine.advance_phase() # → BUILDING + # `rm` is allowed in building phase. + assert engine._check_phase_permission("shell", {"command": "rm -rf build/"}) is None + + +# --------------------------------------------------------------------------- +# _execute_tool integration — phase enforcement actually blocks dispatch +# --------------------------------------------------------------------------- + + +class TestExecuteToolPhaseEnforcement: + @pytest.fixture + def engine_with_tools(self): + engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy()) + # Two fake tools: one allowed in PLANNING (search), one not (write_file). + # MagicMock.name can't be set post-construction (special attribute), + # so we patch _find_tool with a dict-based lookup. + search_tool = MagicMock() + search_tool.input_schema = None + search_tool.safe_execute = AsyncMock(return_value={"results": []}) + + write_tool = MagicMock() + write_tool.input_schema = None + write_tool.safe_execute = AsyncMock(return_value={"written": True}) + + tools_by_name = {"search": search_tool, "write_file": write_tool} + engine._find_tool = lambda name, tools: tools_by_name.get(name) + + return engine, [search_tool, write_tool] + + @pytest.mark.asyncio + async def test_blocked_tool_returns_phase_violation_and_skips_dispatch(self, engine_with_tools): + engine, tools = engine_with_tools + # write_file in PLANNING should be blocked — write_tool.safe_execute + # should NEVER be called. + result = await engine._execute_tool("write_file", {"path": "/x"}, tools) + assert result["error"] == "phase_violation" + assert result["current_phase"] == "planning" + write_tool = tools[1] + write_tool.safe_execute.assert_not_called() + + @pytest.mark.asyncio + async def test_allowed_tool_dispatches_normally(self, engine_with_tools): + engine, tools = engine_with_tools + result = await engine._execute_tool("search", {"query": "foo"}, tools) + assert result == {"results": []} + search_tool = tools[0] + search_tool.safe_execute.assert_awaited_once_with(query="foo") + + @pytest.mark.asyncio + async def test_after_advance_phase_blocked_tool_now_dispatches(self, engine_with_tools): + engine, tools = engine_with_tools + # First: write_file blocked in PLANNING. + result = await engine._execute_tool("write_file", {"path": "/x"}, tools) + assert result["error"] == "phase_violation" + # Advance to BUILDING. + engine.advance_phase() + # Now: write_file allowed. + result = await engine._execute_tool("write_file", {"path": "/x"}, tools) + assert result == {"written": True} + + +# --------------------------------------------------------------------------- +# Auto-advance safety net (KTD6) +# --------------------------------------------------------------------------- + + +class TestAutoAdvance: + def test_auto_advance_after_threshold(self): + # Custom policy with auto-advance after 2 steps. + policy = PhasePolicy( + whitelist={ + PhaseState.PLANNING: frozenset({"search"}), + PhaseState.BUILDING: frozenset({"write_file"}), + PhaseState.VERIFICATION: frozenset({"shell"}), + PhaseState.DELIVERY: frozenset({"*"}), + }, + auto_advance_after_steps=2, + ) + engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=policy) + assert engine.current_phase == PhaseState.PLANNING + + # Step 1: counter goes to 1, no advance yet. + engine._steps_in_phase += 1 + assert engine._maybe_auto_advance() is False + assert engine.current_phase == PhaseState.PLANNING + + # Step 2: counter hits 2, advance triggered. + engine._steps_in_phase += 1 + assert engine._maybe_auto_advance() is True + assert engine.current_phase == PhaseState.BUILDING + assert engine._steps_in_phase == 0 # reset on advance + + def test_auto_advance_none_default(self): + # default_policy has auto_advance_after_steps=None — no auto-advance. + engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy()) + engine._steps_in_phase = 100 + assert engine._maybe_auto_advance() is False + assert engine.current_phase == PhaseState.PLANNING + + +# --------------------------------------------------------------------------- +# AdvancePhaseTool integration +# --------------------------------------------------------------------------- + + +class TestAdvancePhaseTool: + @pytest.mark.asyncio + async def test_advance_phase_tool_transitions_engine(self): + engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy()) + tool = AdvancePhaseTool(engine=engine) + result = await tool.execute() + assert result["is_error"] is False + assert result["current_phase"] == "building" + assert engine.current_phase == PhaseState.BUILDING + + @pytest.mark.asyncio + async def test_advance_phase_tool_at_delivery_returns_error(self): + engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy()) + # Walk through all phases. + engine.advance_phase() # PLANNING → BUILDING + engine.advance_phase() # BUILDING → VERIFICATION + engine.advance_phase() # VERIFICATION → DELIVERY + tool = AdvancePhaseTool(engine=engine) + result = await tool.execute() + assert result["is_error"] is True + assert result["error"] == "already_at_final_phase" + assert result["current_phase"] == "delivery" + + @pytest.mark.asyncio + async def test_advance_phase_tool_without_policy_returns_error(self): + engine = ReActEngine(llm_gateway=MagicMock()) # no policy + tool = AdvancePhaseTool(engine=engine) + result = await tool.execute() + assert result["is_error"] is True + assert result["error"] == "no_phase_policy" + + def test_tool_schema_accepts_no_arguments(self): + engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy()) + tool = AdvancePhaseTool(engine=engine) + # input_schema has empty properties + additionalProperties:false — + # no arguments expected. + assert tool.input_schema["properties"] == {} + assert tool.input_schema["additionalProperties"] is False + + def test_tool_bypasses_phase_check(self): + """`advance_phase` is the LLM's escape hatch — must never be blocked.""" + # _check_phase_permission should NOT block advance_phase even in PLANNING. + # The bypass is implemented in _execute_tool by name check. + # We verify the bypass indirectly: tool dispatches normally even in + # PLANNING (where only search/read_file/bash/tool_search are allowed). + # advance_phase is not in the whitelist, but the name-based bypass + # in _execute_tool lets it through. + # (Direct unit test of the bypass would require mocking _find_tool.) + # Sanity: advance_phase is not in any whitelist. + for phase, allowed in default_policy().whitelist.items(): + assert "advance_phase" not in allowed, ( + f"advance_phase must not be in {phase.value} whitelist" + ) diff --git a/tests/unit/test_read_file_tool.py b/tests/unit/test_read_file_tool.py new file mode 100644 index 0000000..87363e4 --- /dev/null +++ b/tests/unit/test_read_file_tool.py @@ -0,0 +1,367 @@ +"""Unit tests for ReadFileTool — G5 (R22, R23) + characterization baseline. + +Per plan U1 Execution note: characterization-first — assert that +`symbol=None` returns the full file content (matches pre-existing benchmark +`_FakeTool` shape) before adding symbol-extraction behavior. +""" + +from __future__ import annotations + +import textwrap + +import pytest + +from agentkit.tools.file_read import ReadFileTool + + +# --------------------------------------------------------------------------- +# Schema +# --------------------------------------------------------------------------- + + +class TestReadFileToolSchema: + def test_name_is_read_file(self): + tool = ReadFileTool() + assert tool.name == "read_file" + + def test_required_path(self): + tool = ReadFileTool() + assert "path" in tool.input_schema["required"] + assert "path" in tool.input_schema["properties"] + + def test_optional_symbol_and_lines(self): + tool = ReadFileTool() + props = tool.input_schema["properties"] + assert "symbol" in props + assert "start_line" in props + assert "end_line" in props + # None of the optional fields should be in `required`. + required = set(tool.input_schema["required"]) + assert required == {"path"} + + def test_additional_properties_false(self): + # LLM tool-call schemas should reject unknown args (Wave 1 U3 pattern). + tool = ReadFileTool() + assert tool.input_schema.get("additionalProperties") is False + + def test_tags_contain_io_and_read(self): + tool = ReadFileTool() + assert "io" in tool.tags + assert "read" in tool.tags + + +# --------------------------------------------------------------------------- +# Characterization — symbol=None returns full file +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_py_file(tmp_path): + path = tmp_path / "sample.py" + path.write_text( + textwrap.dedent(''' + """Sample module.""" + + def my_func(): + return 42 + + + class MyClass: + attr = 1 + + def method_a(self): + return self.attr + ''').lstrip(), + encoding="utf-8", + ) + return path + + +@pytest.fixture +def sample_ts_file(tmp_path): + path = tmp_path / "sample.ts" + path.write_text( + textwrap.dedent(''' + export function renderComponent(): JSX.Element { + return
; + } + + export class BaseService { + abstract run(): void; + } + ''').lstrip(), + encoding="utf-8", + ) + return path + + +class TestCharacterizationFullFile: + """symbol=None returns the whole file (matches _FakeTool baseline).""" + + @pytest.mark.asyncio + async def test_full_file_returned_when_symbol_none(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file)) + + assert result["is_error"] is False + assert result["path"] == str(sample_py_file) + assert result["start_line"] == 1 + assert result["end_line"] == result["total_lines"] + assert "def my_func" in result["content"] + assert "class MyClass" in result["content"] + assert result["content"].startswith('"""Sample module."""') + + @pytest.mark.asyncio + async def test_full_file_includes_all_lines(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file)) + assert result["total_lines"] >= 8 + assert result["content"].count("\n") >= result["total_lines"] - 1 + + +# --------------------------------------------------------------------------- +# Symbol slicing — happy paths +# --------------------------------------------------------------------------- + + +class TestSymbolSlicing: + @pytest.mark.asyncio + async def test_python_function(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), symbol="my_func") + + assert result["is_error"] is False + assert result["symbol"] == "my_func" + assert result["symbol_kind"] == "function" + assert "def my_func" in result["content"] + assert "return 42" in result["content"] + # Should NOT include the class below. + assert "class MyClass" not in result["content"] + + @pytest.mark.asyncio + async def test_python_class_includes_method(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), symbol="MyClass") + + assert result["is_error"] is False + assert result["symbol"] == "MyClass" + assert result["symbol_kind"] == "class" + assert "class MyClass" in result["content"] + assert "def method_a" in result["content"] # method included + + @pytest.mark.asyncio + async def test_python_method_directly(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), symbol="method_a") + + assert result["is_error"] is False + assert result["symbol"] == "method_a" + assert result["symbol_kind"] == "method" + assert "def method_a" in result["content"] + + @pytest.mark.asyncio + async def test_typescript_function(self, sample_ts_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_ts_file), symbol="renderComponent") + + assert result["is_error"] is False + assert result["symbol"] == "renderComponent" + assert "renderComponent" in result["content"] + # Should not include the class below. + assert "BaseService" not in result["content"] + + @pytest.mark.asyncio + async def test_typescript_class(self, sample_ts_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_ts_file), symbol="BaseService") + + assert result["is_error"] is False + assert result["symbol"] == "BaseService" + assert result["symbol_kind"] == "class" + assert "BaseService" in result["content"] + + +# --------------------------------------------------------------------------- +# Symbol slicing — edge cases +# --------------------------------------------------------------------------- + + +class TestSymbolSlicingEdgeCases: + @pytest.mark.asyncio + async def test_symbol_not_found_lists_available(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), symbol="nonexistent") + + assert result["is_error"] is False # soft error, not hard + assert result["content"] == "" + assert result["symbol"] == "nonexistent" + available = result["available_symbols"] + assert "my_func" in available + assert "MyClass" in available + assert "method_a" in available + assert "nonexistent" not in result["content"] + + @pytest.mark.asyncio + async def test_unsupported_extension_returns_full_with_note(self, tmp_path): + path = tmp_path / "notes.md" + path.write_text("# Hello\nworld\n", encoding="utf-8") + tool = ReadFileTool() + result = await tool.execute(path=str(path), symbol="anything") + + assert result["is_error"] is False + assert result["content"] == "# Hello\nworld\n" + assert "symbol extraction not supported" in result["note"] + assert ".md" in result["note"] + + @pytest.mark.asyncio + async def test_empty_file(self, tmp_path): + path = tmp_path / "empty.py" + path.write_text("", encoding="utf-8") + tool = ReadFileTool() + result = await tool.execute(path=str(path)) + + assert result["is_error"] is False + assert result["content"] == "" + assert result["total_lines"] == 0 + + @pytest.mark.asyncio + async def test_file_with_no_symbols(self, tmp_path): + path = tmp_path / "data.py" + path.write_text("# just a comment\nPI = 3.14\n", encoding="utf-8") + tool = ReadFileTool() + result = await tool.execute(path=str(path), symbol="PI") + + # PI is not a def/class — extractor finds no symbols; soft error lists available. + assert result["is_error"] is False + assert result["content"] == "" + assert result["available_symbols"] == [] + + +# --------------------------------------------------------------------------- +# Error paths +# --------------------------------------------------------------------------- + + +class TestReadFileToolErrors: + @pytest.mark.asyncio + async def test_path_required(self): + tool = ReadFileTool() + result = await tool.execute() + assert result["is_error"] is True + assert "path" in result["error"].lower() + + @pytest.mark.asyncio + async def test_path_empty_string(self): + tool = ReadFileTool() + result = await tool.execute(path="") + assert result["is_error"] is True + + @pytest.mark.asyncio + async def test_file_not_found(self, tmp_path): + tool = ReadFileTool() + result = await tool.execute(path=str(tmp_path / "missing.py")) + assert result["is_error"] is True + assert "not found" in result["error"].lower() + + @pytest.mark.asyncio + async def test_path_is_directory(self, tmp_path): + tool = ReadFileTool() + result = await tool.execute(path=str(tmp_path)) + assert result["is_error"] is True + assert "directory" in result["error"].lower() + + +# --------------------------------------------------------------------------- +# Manual line slicing +# --------------------------------------------------------------------------- + + +class TestManualLineSlicing: + @pytest.mark.asyncio + async def test_start_and_end_line(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute( + path=str(sample_py_file), + start_line=3, + end_line=5, + ) + assert result["is_error"] is False + assert result["start_line"] == 3 + assert result["end_line"] == 5 + # Lines 3-5 of the sample file: + # line 3: "def my_func():" + # line 4: " return 42" + # line 5: "" (blank) + assert "def my_func" in result["content"] + assert "return 42" in result["content"] + + @pytest.mark.asyncio + async def test_start_line_only_extends_to_eof(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), start_line=8) + assert result["is_error"] is False + assert result["start_line"] == 8 + assert result["end_line"] == result["total_lines"] + + @pytest.mark.asyncio + async def test_end_line_only_starts_at_one(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), end_line=2) + assert result["is_error"] is False + assert result["start_line"] == 1 + assert result["end_line"] == 2 + + @pytest.mark.asyncio + async def test_invalid_start_line_zero(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), start_line=0) + assert result["is_error"] is True + assert "start_line" in result["error"].lower() + + @pytest.mark.asyncio + async def test_end_before_start(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute( + path=str(sample_py_file), + start_line=5, + end_line=3, + ) + assert result["is_error"] is True + assert "end_line" in result["error"].lower() + + @pytest.mark.asyncio + async def test_manual_lines_override_symbol(self, sample_py_file): + # Per plan U1 Approach: "start_line/end_line overrides symbol". + tool = ReadFileTool() + result = await tool.execute( + path=str(sample_py_file), + symbol="my_func", + start_line=1, + end_line=1, + ) + assert result["is_error"] is False + # Manual slicing won — symbol field absent. + assert "symbol" not in result or result.get("symbol") is None + assert result["start_line"] == 1 + assert result["end_line"] == 1 + + +# --------------------------------------------------------------------------- +# Integration — tool registry discovery +# --------------------------------------------------------------------------- + + +class TestToolRegistryDiscovery: + def test_instantiable_without_args(self): + # Default constructor — matches the convention used by ToolRegistry + # to instantiate tools by class. + tool = ReadFileTool() + assert tool.name == "read_file" + + def test_to_dict_serializable(self): + tool = ReadFileTool() + d = tool.to_dict() + assert d["name"] == "read_file" + assert "input_schema" in d + assert "output_schema" in d + assert d["tags"] == ["io", "file", "read"] diff --git a/tests/unit/test_symbol_extractor.py b/tests/unit/test_symbol_extractor.py new file mode 100644 index 0000000..d7b444f --- /dev/null +++ b/tests/unit/test_symbol_extractor.py @@ -0,0 +1,359 @@ +"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor. + +Covers R22 (file reading supports symbol/function granularity) and KTD1 +(Python ast + language-aware regex, no tree-sitter dependency). +""" + +from __future__ import annotations + +import textwrap + +import pytest + +from agentkit.tools.symbol_extractor import ( + AstSymbolExtractor, + RegexSymbolExtractor, + SymbolSpan, + extract_symbols_from_file, + get_extractor, + language_for_extension, +) + + +# --------------------------------------------------------------------------- +# language_for_extension +# --------------------------------------------------------------------------- + + +class TestLanguageForExtension: + def test_python_extensions(self): + assert language_for_extension("py") == "py" + assert language_for_extension(".py") == "py" + assert language_for_extension(".PY") == "py" # case-insensitive + + def test_typescript_javascript(self): + assert language_for_extension(".ts") == "ts" + assert language_for_extension(".tsx") == "ts" + assert language_for_extension(".js") == "js" + assert language_for_extension(".jsx") == "js" + assert language_for_extension(".mjs") == "js" + assert language_for_extension(".cjs") == "js" + + def test_go_rust_java(self): + assert language_for_extension(".go") == "go" + assert language_for_extension(".rs") == "rs" + assert language_for_extension(".java") == "java" + + def test_unsupported_returns_empty(self): + assert language_for_extension(".md") == "" + assert language_for_extension(".txt") == "" + assert language_for_extension("") == "" + assert language_for_extension(".unknown") == "" + + +# --------------------------------------------------------------------------- +# AstSymbolExtractor — Python +# --------------------------------------------------------------------------- + + +class TestAstSymbolExtractor: + extractor = AstSymbolExtractor() + + def test_unsupported_language_returns_empty(self): + assert self.extractor.extract_symbols("function foo() {}", "ts") == [] + + def test_syntax_error_returns_empty(self): + # Never raises — callers rely on this for fallback routing. + result = self.extractor.extract_symbols("def broken(:\n pass", "py") + assert result == [] + + def test_top_level_function(self): + content = "def my_func():\n return 42\n" + spans = self.extractor.extract_symbols(content, "py") + assert len(spans) == 1 + span = spans[0] + assert span.name == "my_func" + assert span.kind == "function" + assert span.start_line == 1 + assert span.end_line == 2 + + def test_async_function(self): + content = "async def fetch():\n return 1\n" + spans = self.extractor.extract_symbols(content, "py") + assert len(spans) == 1 + assert spans[0].name == "fetch" + assert spans[0].kind == "function" + + def test_top_level_class(self): + content = textwrap.dedent(''' + class MyClass: + """docstring""" + + def method_a(self): + return 1 + + async def method_b(self): + return 2 + ''').strip() + spans = self.extractor.extract_symbols(content, "py") + names = [s.name for s in spans] + assert "MyClass" in names + assert "method_a" in names + assert "method_b" in names + + cls = next(s for s in spans if s.name == "MyClass") + assert cls.kind == "class" + assert cls.start_line == 1 + # Class body extends through the last method's end_lineno. + assert cls.end_line >= 7 + + def test_methods_classified_as_methods(self): + content = textwrap.dedent(''' + class Foo: + def bar(self): + pass + + def top_level(): + pass + ''').strip() + spans = self.extractor.extract_symbols(content, "py") + by_name = {s.name: s for s in spans} + assert by_name["bar"].kind == "method" + assert by_name["top_level"].kind == "function" + + def test_decorated_function(self): + content = textwrap.dedent(''' + @staticmethod + def helper(): + return "hi" + ''').strip() + spans = self.extractor.extract_symbols(content, "py") + # Note: extractor uses node.lineno (def line) — decorators above are + # excluded by design (matches user-visible symbol start at `def`). + assert any(s.name == "helper" for s in spans) + span = next(s for s in spans if s.name == "helper") + assert span.start_line == 2 # the `def` line + + def test_nested_function(self): + content = textwrap.dedent(''' + def outer(): + def inner(): + return 1 + return inner() + ''').strip() + spans = self.extractor.extract_symbols(content, "py") + names = {s.name for s in spans} + assert "outer" in names + assert "inner" in names + + def test_empty_file(self): + assert self.extractor.extract_symbols("", "py") == [] + + def test_no_symbols_in_docstring_only_file(self): + content = '"""just a docstring"""\n' + assert self.extractor.extract_symbols(content, "py") == [] + + +# --------------------------------------------------------------------------- +# RegexSymbolExtractor — TS/JS/Go/Rust/Java +# --------------------------------------------------------------------------- + + +class TestRegexSymbolExtractor: + extractor = RegexSymbolExtractor() + + def test_unsupported_language_returns_empty(self): + assert self.extractor.extract_symbols("def foo(): pass", "py") == [] + assert self.extractor.extract_symbols("function foo() {}", "rb") == [] + + def test_typescript_function_declaration(self): + content = textwrap.dedent(''' + export function renderComponent(props: Props): JSX.Element { + return
; + } + ''').strip() + spans = self.extractor.extract_symbols(content, "ts") + assert any(s.name == "renderComponent" and s.kind == "function" for s in spans) + + def test_typescript_async_function(self): + content = "async function fetchData() {\n return await fetch();\n}\n" + spans = self.extractor.extract_symbols(content, "ts") + assert any(s.name == "fetchData" for s in spans) + + def test_typescript_arrow_function_const(self): + content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n" + spans = self.extractor.extract_symbols(content, "ts") + assert any(s.name == "handleClick" for s in spans) + + def test_typescript_class(self): + content = textwrap.dedent(''' + export abstract class BaseService { + abstract run(): void; + } + ''').strip() + spans = self.extractor.extract_symbols(content, "ts") + assert any(s.name == "BaseService" and s.kind == "class" for s in spans) + + def test_javascript_function(self): + content = "function foo() {\n return 1;\n}\n" + spans = self.extractor.extract_symbols(content, "js") + assert any(s.name == "foo" for s in spans) + + def test_javascript_arrow_const(self): + content = "const bar = () => 42;\n" + spans = self.extractor.extract_symbols(content, "js") + assert any(s.name == "bar" for s in spans) + + def test_go_function(self): + content = textwrap.dedent(''' + package main + + func HandleRequest(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + } + + func (s *Server) Start() { + // method + } + ''').strip() + spans = self.extractor.extract_symbols(content, "go") + names = {s.name for s in spans} + assert "HandleRequest" in names + assert "Start" in names # method receiver pattern + + def test_go_struct(self): + content = "type Server struct {\n Addr string\n}\n" + spans = self.extractor.extract_symbols(content, "go") + assert any(s.name == "Server" and s.kind == "struct" for s in spans) + + def test_rust_function(self): + content = textwrap.dedent(''' + pub fn process(input: &str) -> Result { + Ok(input.len()) + } + + async fn fetch() -> Bytes { + unimplemented!() + } + ''').strip() + spans = self.extractor.extract_symbols(content, "rs") + names = {s.name for s in spans} + assert "process" in names + assert "fetch" in names + + def test_rust_struct(self): + content = "pub struct Config {\n pub path: String,\n}\n" + spans = self.extractor.extract_symbols(content, "rs") + assert any(s.name == "Config" and s.kind == "struct" for s in spans) + + def test_rust_impl(self): + content = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n" + spans = self.extractor.extract_symbols(content, "rs") + assert any(s.name == "Config" and s.kind == "impl" for s in spans) + + def test_java_class(self): + content = textwrap.dedent(''' + package com.example; + + public class UserService { + public User findById(long id) { + return null; + } + } + ''').strip() + spans = self.extractor.extract_symbols(content, "java") + assert any(s.name == "UserService" and s.kind == "class" for s in spans) + + def test_java_method(self): + content = "public User findById(long id) {\n return null;\n}\n" + spans = self.extractor.extract_symbols(content, "java") + assert any(s.name == "findById" and s.kind == "function" for s in spans) + + def test_end_line_extends_to_next_symbol(self): + # First symbol's end_line is the line before the second symbol starts. + content = textwrap.dedent(''' + function first() { + return 1; + } + + function second() { + return 2; + } + ''').strip() + spans = self.extractor.extract_symbols(content, "js") + spans.sort(key=lambda s: s.start_line) + first = spans[0] + second = spans[1] + assert first.name == "first" + assert second.name == "second" + assert first.end_line == second.start_line - 1 + + def test_last_symbol_end_line_is_eof(self): + content = "function only() {\n return 1;\n}\n" + spans = self.extractor.extract_symbols(content, "js") + assert len(spans) == 1 + assert spans[0].end_line == len(content.splitlines()) + + +# --------------------------------------------------------------------------- +# get_extractor + integration +# --------------------------------------------------------------------------- + + +class TestGetExtractor: + def test_python_returns_ast_extractor(self): + ext = get_extractor("py") + assert ext is not None + assert isinstance(ext, AstSymbolExtractor) + + def test_typescript_returns_regex_extractor(self): + ext = get_extractor("ts") + assert ext is not None + assert isinstance(ext, RegexSymbolExtractor) + + def test_unsupported_returns_none(self): + assert get_extractor("md") is None + assert get_extractor("") is None + assert get_extractor("unknown") is None + + +class TestExtractSymbolsFromFile: + def test_python_file(self, tmp_path): + path = tmp_path / "module.py" + path.write_text("def hello():\n return 'world'\n", encoding="utf-8") + spans, lang = extract_symbols_from_file(path) + assert lang == "py" + assert any(s.name == "hello" for s in spans) + + def test_unsupported_extension(self, tmp_path): + path = tmp_path / "notes.md" + path.write_text("# Hello\n", encoding="utf-8") + spans, lang = extract_symbols_from_file(path) + assert lang == "" + assert spans == [] + + def test_missing_file_returns_empty(self, tmp_path): + path = tmp_path / "nonexistent.py" + spans, lang = extract_symbols_from_file(path) + # lang is detected from extension even if read fails. + assert lang == "py" + assert spans == [] + + +# --------------------------------------------------------------------------- +# SymbolSpan dataclass +# --------------------------------------------------------------------------- + + +class TestSymbolSpan: + def test_frozen_dataclass(self): + span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3) + assert span.name == "foo" + with pytest.raises(Exception): + span.name = "bar" # type: ignore[misc] — frozen + + def test_equality(self): + a = SymbolSpan("foo", "function", 1, 3) + b = SymbolSpan("foo", "function", 1, 3) + assert a == b + assert hash(a) == hash(b)