feat(agent): Wave 3 strategic coupling (G5/G6) (#6)

Fischer 2026-06-30 09:17:19 +08:00
parent a2dcde01b8
commit 2b8a7d8909
15 changed files with 3458 additions and 8 deletions


@ -51,6 +51,19 @@ fallback_chain:
max_retries: 1 # ReflexionEngine max_reflections override
emergency:
enabled: true
# G6/U2: PLAN_EXEC phase policy — SOLO four-stage state machine.
# When `enabled: true`, chat WebSocket PLAN_EXEC requests build a PhasePolicy
# (Planning → Building → Verification → Delivery) and enforce per-step tool
# whitelists (R24). Transitions are LLM-driven via AdvancePhaseTool; set
# `auto_advance_after_steps` to auto-advance as a safety net (KTD6).
# Commented to preserve default behavior — uncomment to enable.
# plan_exec:
#   enabled: true
#   auto_advance_after_steps: 5 # optional, default = manual (LLM calls advance_phase)
#   start_phase: planning # optional, default = planning
#   whitelist_override: # optional, merges with default (override wins)
#     planning: [search, read_file, shell]
#     building: [write_file, shell, read_file]
session: {backend: memory}
bus: {backend: memory}
task_store: {backend: memory}


@ -0,0 +1,436 @@
---
title: "feat: Agent Wave 3 strategic coupling (G5/G6)"
date: 2026-06-29
type: feat
status: draft
origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md
execution: code
---
# Wave 3 Strategic Coupling — G5 Function-level Sharding + G6 SOLO Phase Constraints
## Summary
Wave 3 of the advanced-agent gap optimization closes two strategic gaps deferred from Waves 1-2:
- **G5 — Function-level code sharding** (R22, R23): file reading gains an optional `symbol` parameter for symbol/function-granularity slicing, backward compatible with full-file reads.
- **G6 — SOLO four-stage state machine** (R24, R25): ReAct loop enforces per-phase tool whitelists (Planning → Building → Verification → Delivery). Extends existing `ExecutionMode.PLAN_EXEC` rather than introducing a new mode.
Wave 1 (G1/G2/G3/G8 — PR #4 merged) and Wave 2 (G4/G7/G9 — PR #5 open) shipped independently. Wave 3 is the **strategic-risk** wave: it introduces a new tool (G5) and touches ReAct core (G6). Per the brainstorm's KTD6/KTD7 locked decisions, G5 integration approach is decided here (not deferred further), and G6 extends PLAN_EXEC rather than adding a new mode.
## Problem Frame
The brainstorm (`docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md`) identified nine gaps across three dimensions. Waves 1-2 closed seven (G1-G4, G7-G9). The remaining two gaps are strategic:
- **G5 (Long-context cost)**: Large files (e.g. 5000-line modules) blow context budget when the agent only needs one function. Today the agent shells out to `cat file.py` or greps; both pull the whole file into context. A symbol-aware slice would cut context cost 10-50x for typical edits.
- **G6 (Unsafe tool sequencing)**: ReAct loop lets the LLM call `write_file` during early exploration before committing to a plan. This wastes tokens on premature edits, causes half-baked refactors, and breaks the "plan-then-build" discipline that production agents (Qoder, Trae Work) enforce via phase state machines.
KTD6 (brainstorm) defers the G5 integration decision to this plan. KTD7 locks G6 to extend PLAN_EXEC rather than introduce a new mode.
## Requirements
Carried forward from the brainstorm, Wave 3 section:
- **R22**: File reading supports symbol/function granularity sharding.
- **R23**: Sharding capability exposed as a tool parameter (`symbol="function_name"`), backward compatible with full-file reads.
- **R24**: ReAct loop enforces phase constraints — Planning phase only allows `think`/`search`; Building phase only allows `write_file` (and similar write tools).
- **R25**: Phase state is configurable; extends `ExecutionMode.PLAN_EXEC`, does NOT introduce a new mode.
Cross-cutting (brainstorm):
- **R26**: All optimizations configurable via `agentkit.yaml` (follow `ServerConfig.from_dict` pattern established in Waves 1-2).
- **R27**: Each optimization ships a minimal self-check test (ponytail rule).
Acceptance examples relevant to Wave 3:
- **AE5 already covered by Wave 2 (G9)** — not in Wave 3 scope.
- No explicit AE for G5/G6 in the brainstorm — the plan below specifies test scenarios as the acceptance contract.
## Key Technical Decisions
### KTD1: G5 uses Python `ast` + language-aware regex, NOT tree-sitter
**Decision**: Implement symbol extraction with the Python stdlib `ast` module for Python files, and a small regex-based extractor for TypeScript/JavaScript/Go/Rust/Java. No new dependency.
**Rationale**:
- `tree-sitter` requires native compilation + per-language grammar files (~30MB installed), which violates the ponytail rule "no new dependency if it can be avoided" and the AGENTS.md rule "use of the `any` type is prohibited" → prefers minimal stack.
- `ast` is stdlib, always available, parses Python accurately.
- Regex extractor covers 80% case for TS/JS/Go/Rust/Java (function/class/struct declarations); falls back to "no symbols found → read full file" gracefully.
- If a future Wave 4 needs more accurate multi-language parsing, `tree-sitter` can replace the regex layer behind the same `SymbolExtractor` interface.
**Upgrade path**: replace `RegexSymbolExtractor` with `TreeSitterSymbolExtractor` implementing the same `SymbolExtractor` protocol; no caller changes.
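A minimal sketch of the extractor seam this decision relies on, assuming a `SymbolSpan` record with the fields the plan names (`name`, `kind`, `start_line`, `end_line`). The `kind` labels, the decorator handling, and the empty-list result on a syntax error are illustrative assumptions, not the shipped implementation.

```python
import ast
from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class SymbolSpan:
    name: str
    kind: str        # "function" or "class" (labels are an assumption)
    start_line: int  # 1-indexed, inclusive
    end_line: int    # 1-indexed, inclusive


class SymbolExtractor(Protocol):
    """Any extractor (ast, regex, tree-sitter) satisfies this shape."""

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: ...


class AstSymbolExtractor:
    """Python-only extractor built on the stdlib ast module."""

    def extract_symbols(self, content: str, language: str = "python") -> list[SymbolSpan]:
        try:
            tree = ast.parse(content)
        except SyntaxError:
            return []  # assumed behavior: caller falls back to a full-file read
        spans: list[SymbolSpan] = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                kind = "function"
            elif isinstance(node, ast.ClassDef):
                kind = "class"
            else:
                continue
            # Start at the first decorator so the slice is self-contained.
            first = min([node.lineno] + [d.lineno for d in node.decorator_list])
            spans.append(SymbolSpan(node.name, kind, first, node.end_lineno or node.lineno))
        return sorted(spans, key=lambda s: s.start_line)
```

Because callers only depend on `extract_symbols`, a tree-sitter backed class could later be swapped in without touching the tool that consumes the spans.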
### KTD2: G5 adds a new `ReadFileTool`, does not extend `ShellTool`
**Decision**: Add a dedicated `ReadFileTool` in `src/agentkit/tools/file_read.py` with `path` + optional `symbol` + optional `start_line`/`end_line` parameters.
**Rationale**:
- `ShellTool` is for shell execution; grafting symbol extraction onto it muddies the contract.
- A dedicated tool gives the LLM a clear schema (`{"path": "...", "symbol": "function_name"}`) and a focused system-prompt description.
- Aligns with the existing `_DEFAULT_CORE_TOOLS` list in `core/react.py:148` which already references `read_file` — the name is reserved but the implementation is missing.
### KTD3: G6 phase state machine lives in ReActEngine, not in the skill config
**Decision**: Phase state (Planning/Building/Verification/Delivery) is tracked as a mutable field on `ReActEngine` instance. Transitions are driven by LLM-detected phase-completion signals (e.g., the LLM emits `Phase: Building` in its thinking) OR by an explicit `advance_phase` tool.
**Rationale**:
- Skill config declares the policy (which tools per phase, auto-advance vs manual); the engine enforces it per-step. This matches R24 ("add phase constraints to the ReAct loop").
- Alternative considered: phase state in the agent instance (not engine). Rejected because ReActEngine already owns `max_steps`/`verification_enabled` etc.; phase state belongs with the loop that enforces it.
### KTD4: PLAN_EXEC mode is wired at chat.py WebSocket path (REST already has fallback chain from Wave 2)
**Decision**: chat.py:1084 (currently warns "not yet supported, falling back to REACT") will route `ExecutionMode.PLAN_EXEC` to a new `_execute_plan_exec_ws` handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and passes it to `ReActEngine.execute`.
**Rationale**:
- REST `send_message` already uses the Wave 2 three-tier fallback chain; PLAN_EXEC at REST would also need that wrapper. **Out of scope for Wave 3** — only WebSocket path is wired. REST PLAN_EXEC remains "not yet supported" and explicitly raises if invoked.
- Single integration point keeps Wave 3 scope bounded; REST wiring is a one-line follow-up once WebSocket path is proven.
### KTD5: Default phase whitelist matches brainstorm R24
**Decision**: Default whitelist:
- `Planning`: `search`, `tool_search`, `read_file`, `bash` (read-only commands like `git status`, `ls`)
- `Building`: `write_file`, `bash` (write commands), `read_file`, `search`
- `Verification`: `bash` (test commands), `read_file`, `search`
- `Delivery`: all tools (final synthesis)
**Rationale**:
- R24 explicitly names `think`/`search` for Planning and `write_file` for Building.
- `bash` is split: read-only in Planning, full in Building. Enforced by adding a `bash_command_filter` callback (regex-based, blocks `rm`/`mv`/`>`/`>>` in Planning/Verification).
- `Delivery` allows all tools to support last-mile formatting/cleanup.
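The regex-based command filter has one easy pitfall worth illustrating: `\b` is a word boundary, so it behaves as expected around `rm`/`mv` but not around the `>` operator. The sketch below compares a naive pattern that wraps every alternative in `\b` with an assumed stricter one; both patterns and the `is_read_only` helper are hypothetical, not the project's actual filter.

```python
import re

# Naive shape: `\b` wraps every alternative, including the redirect operators.
NAIVE = re.compile(r"\b(rm|mv|>|>>)\b")

# Assumed alternative: word boundaries only around word commands; any `>` blocks.
STRICTER = re.compile(r"\b(?:rm|mv)\b|>")


def is_read_only(command: str, pattern: "re.Pattern[str]") -> bool:
    """Return True when the command does not trip the filter."""
    return pattern.search(command) is None


for cmd in ("git status", "rm -rf build", "echo hi > out.txt", "grep -rn format src"):
    print(f"{cmd!r}: naive={is_read_only(cmd, NAIVE)} stricter={is_read_only(cmd, STRICTER)}")
```

With the naive pattern, `echo hi > out.txt` passes as read-only because there is no word boundary between a space and `>`. The stricter variant is deliberately conservative: it also rejects harmless redirects such as `2>&1`.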
### KTD6: Phase transitions are LLM-driven via `advance_phase` tool (opt-in auto-advance)
**Decision**: Add an `AdvancePhaseTool` that the LLM can call to transition Planning→Building→Verification→Delivery. Auto-advance (after N steps in current phase) is opt-in via `plan_exec.auto_advance_after_steps`.
**Rationale**:
- LLM-driven transitions match Qoder/Trae Work pattern: LLM declares "planning done" explicitly.
- Auto-advance is a safety net for LLMs that forget to call `advance_phase`; default off (ponytail: less code is better).
## Scope Boundaries
### In Scope
- New `ReadFileTool` with `symbol` parameter (G5, R22/R23).
- `SymbolExtractor` protocol + `AstSymbolExtractor` (Python) + `RegexSymbolExtractor` (TS/JS/Go/Rust/Java).
- `PhasePolicy` dataclass + `PhaseState` enum + per-step tool whitelist enforcement in `ReActEngine.execute` (G6, R24/R25).
- `AdvancePhaseTool` for LLM-driven phase transitions.
- WebSocket chat path routes `PLAN_EXEC` to new `_execute_plan_exec_ws` handler (KTD4).
- `plan_exec` config section in `agentkit.yaml` + `ServerConfig.from_dict` extension (R26).
- Tests for each new module (R27).
### Out of Scope (Deferred to Follow-Up Work)
- REST `send_message` PLAN_EXEC wiring — once WebSocket path is proven, REST wiring is a follow-up commit.
- `tree-sitter` integration for more accurate multi-language parsing (KTD1 upgrade path).
- Phase-aware prompt engineering (per-phase system prompt templates) — current plan keeps a single system prompt; phase-specific guidance is a prompt-engineering concern, not a code change.
- Phase persistence across session resume (U7 checkpoint already saves plan state; phase state restoration is a separate concern).
- Phase rollback on `Building` → `Planning` regression (Wave 2 G9 rollback handles file-level rollback; phase regression is a UX/prompt concern).
- Tool-filter UI in the frontend (Wave 3 ships backend-only; frontend surfaces phase via existing event channel if needed in a follow-up).
### Outside This Product's Identity
- Replacing the existing ReAct loop with LangGraph (inherited from brainstorm).
- Disc-based file system à la DeerFlow (inherited).
- Docker sandbox (inherited; only command-level safety via `bash_command_filter`).
## High-Level Technical Design
```mermaid
flowchart LR
subgraph G5[Function Sharding]
RF[ReadFileTool] --> SE[SymbolExtractor protocol]
SE --> AST[AstSymbolExtractor<br/>Python stdlib ast]
SE --> RX[RegexSymbolExtractor<br/>TS/JS/Go/Rust/Java]
end
subgraph G6[Phase State Machine]
PP[PhasePolicy config] --> PS[PhaseState enum<br/>Planning/Building/Verify/Delivery]
PS --> Filt[Tool filter per step]
Filt --> RE[ReActEngine.execute]
AP[AdvancePhaseTool] -->|transitions| PS
end
RF -->|file content for symbol| RE
RE -->|enforces| Filt
```
The two subsystems compose at the ReAct engine boundary: `ReadFileTool` is one of the tools the LLM can call during any phase (filtered by `PhasePolicy`); `PhaseState` is enforced at the tool-call step before dispatch.
## Implementation Units
### U1. SymbolExtractor + ReadFileTool (G5)
**Goal**: Add `ReadFileTool` with optional `symbol` parameter; implement `SymbolExtractor` protocol with `AstSymbolExtractor` (Python) and `RegexSymbolExtractor` (TS/JS/Go/Rust/Java).
**Requirements**: R22, R23, R27.
**Dependencies**: none.
**Files**:
- `src/agentkit/tools/file_read.py` (new)
- `src/agentkit/tools/symbol_extractor.py` (new)
- `src/agentkit/tools/__init__.py` (modify — register `ReadFileTool`)
- `tests/unit/test_symbol_extractor.py` (new)
- `tests/unit/test_read_file_tool.py` (new)
**Approach**:
- `SymbolExtractor` is a `Protocol` with one method: `extract_symbols(content: str, language: str) -> list[SymbolSpan]`. `SymbolSpan` carries `name`, `kind` (function/class/method/struct), `start_line`, `end_line`.
- `AstSymbolExtractor` walks `ast.parse(content)`; for each `FunctionDef`/`AsyncFunctionDef`/`ClassDef` collects name + line range. Uses `ast.get_source_segment` style (line-based, not node-based, to keep the API simple).
- `RegexSymbolExtractor` ships patterns for TS/JS (`function X`, `const X = (...) =>`, `class X`), Go (`func X`), Rust (`fn X`, `struct X`, `impl X`), Java (`public ... X(...)`). Falls back to "no symbols" if no pattern matches.
- `ReadFileTool.execute(path, symbol=None, start_line=None, end_line=None)`:
  - `symbol=None` → read full file (backward compat with the existing `_FakeTool` benchmark shape).
  - `symbol="foo"` → detect language from extension; call `extract_symbols`; return the line range of the first matching symbol; if no match, return an error result with available symbol names listed (so the LLM can retry).
  - `start_line`/`end_line` override the symbol's range; this allows manual slicing.
- Tool registered as `read_file` (matches the reserved name in `core/react.py:148`).
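The slicing rules above can be sketched as a pure function, separate from file I/O. This assumes symbol spans have already been extracted; `resolve_slice`, the trimmed-down `SymbolSpan`, and the dict-shaped result are hypothetical names chosen for illustration, and an explicit `start_line`/`end_line` is read here as overriding only the corresponding bound of the symbol's range.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SymbolSpan:
    name: str
    start_line: int  # 1-indexed, inclusive
    end_line: int    # 1-indexed, inclusive


def resolve_slice(
    content: str,
    spans: list,
    symbol: Optional[str] = None,
    start_line: Optional[int] = None,
    end_line: Optional[int] = None,
) -> dict:
    """Pick the lines to return for a read_file-style call."""
    if symbol is None and start_line is None and end_line is None:
        return {"content": content}  # backward-compatible full-file read

    lines = content.splitlines()
    lo, hi = 1, len(lines)
    if symbol is not None:
        match = next((s for s in spans if s.name == symbol), None)
        if match is None:
            names = ", ".join(s.name for s in spans) or "(none)"
            return {"error": f"Symbol {symbol!r} not found. Available symbols: {names}"}
        lo, hi = match.start_line, match.end_line
    # Explicit bounds win over the symbol's own range.
    if start_line is not None:
        lo = max(1, start_line)
    if end_line is not None:
        hi = min(len(lines), end_line)
    return {"content": "\n".join(lines[lo - 1:hi]), "start_line": lo, "end_line": hi}
```

Keeping this logic free of file access makes the edge cases in the test scenarios (unknown symbol, truncation by `end_line`) cheap to cover with unit tests.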
**Execution note**: characterization-first — write a test that asserts the tool returns the full file content when `symbol=None` (matches pre-existing benchmark `_FakeTool` shape) before adding symbol-extraction behavior.
**Patterns to follow**:
- `src/agentkit/tools/document_tool.py` for tool structure (dataclass, `Tool` base class, `input_schema`).
- `src/agentkit/tools/schema_tools.py:SchemaExtractTool` for "extract-from-source" pattern.
**Test scenarios** (covers R22, R23):
- **Happy paths**:
- Python file, `symbol="MyClass"` → returns class body only (lines from `class MyClass:` through end of class).
- Python file, `symbol="my_func"` → returns function body only.
- TypeScript file, `symbol="renderComponent"` → returns arrow/function body.
- Go file, `symbol="HandleRequest"` → returns func body.
- **Edge cases**:
- `symbol=None` → returns full file content (characterization).
- `symbol="nonexistent"` → returns error result listing available symbols ("Available symbols: foo, bar, baz").
- Unsupported file extension (`.md`, `.txt`) → returns full file with `note: symbol extraction not supported for .md`.
- Empty file → returns empty content.
- File with nested classes → outer class symbol returns including inner class.
- **Error paths**:
- Path does not exist → raises `FileNotFoundError` (or returns error result matching other tools' convention).
- Path is a directory → returns error result.
- Permission denied → returns error result.
- **Integration scenarios**:
- Symbol extraction + line slicing: `symbol="foo"`, `end_line=50` truncates at line 50 even if symbol extends further.
- Round-trip: extract symbol, write back via `ShellTool` `sed` (not in scope for tool — just verify extracted range is well-formed).
**Verification**:
- `python3 -m pytest tests/unit/test_symbol_extractor.py tests/unit/test_read_file_tool.py -q` passes.
- `ruff check src/agentkit/tools/file_read.py src/agentkit/tools/symbol_extractor.py` clean.
- `ReadFileTool` appears in `ToolRegistry.list_tools()` after registration.
---
### U2. PhasePolicy + PhaseState + ServerConfig (G6 core)
**Goal**: Add `PhasePolicy` dataclass, `PhaseState` enum, default whitelist config. Extend `ServerConfig.from_dict` with `plan_exec` section. Wire config to `agentkit.yaml`.
**Requirements**: R25, R26.
**Dependencies**: none.
**Files**:
- `src/agentkit/core/phase.py` (new) — `PhaseState` enum, `PhasePolicy` dataclass, `default_policy()` factory.
- `src/agentkit/server/config.py` (modify — add `plan_exec` field + `from_dict` parsing).
- `agentkit.yaml` (modify — document `plan_exec:` section).
- `tests/unit/test_phase_policy.py` (new).
**Approach**:
- `PhaseState` is an `enum.Enum` with members `PLANNING`, `BUILDING`, `VERIFICATION`, `DELIVERY` (values are the lowercase names).
- `PhasePolicy` carries:
  - `whitelist: dict[PhaseState, set[str]]`: tool names allowed per phase.
  - `bash_command_filter: dict[PhaseState, re.Pattern | None]`: regex that bash args must NOT match (e.g., `r"\b(rm|mv)\b|>"` in Planning; `\b` only works around word commands, so `>`/`>>` need their own alternative).
  - `auto_advance_after_steps: int | None`: None = manual (LLM calls `advance_phase`); int = auto-advance after N steps.
  - `start_phase: PhaseState = PhaseState.PLANNING`.
- `default_policy()` returns the KTD5 whitelist above.
- `ServerConfig.from_dict` reads `plan_exec` section: `enabled`, `whitelist_override` (dict), `auto_advance_after_steps`.
- `agentkit.yaml` gains a commented-out `plan_exec:` block (commented to preserve default behavior — opt-in).
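One way the `whitelist_override` merge could work, as a hedged sketch: "override wins" is interpreted here as per-phase replacement (a union of default and override would be another valid reading). `DEFAULT_WHITELIST` abbreviates the KTD5 defaults, and `merge_whitelist` is a hypothetical helper name.

```python
import enum
from typing import Mapping


class PhaseState(enum.Enum):
    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"


# KTD5 defaults, abbreviated; "*" means all tools are allowed.
DEFAULT_WHITELIST = {
    PhaseState.PLANNING: frozenset({"search", "tool_search", "read_file", "bash"}),
    PhaseState.BUILDING: frozenset({"write_file", "bash", "read_file", "search"}),
    PhaseState.VERIFICATION: frozenset({"bash", "read_file", "search"}),
    PhaseState.DELIVERY: frozenset({"*"}),
}


def merge_whitelist(override: Mapping[str, list]) -> dict:
    """Replace the default entry for each phase named in `override`."""
    merged = dict(DEFAULT_WHITELIST)
    for name, tools in override.items():
        try:
            phase = PhaseState(name.lower())
        except ValueError:
            valid = ", ".join(p.value for p in PhaseState)
            raise ValueError(f"Invalid phase name {name!r}. Valid: {valid}") from None
        if not tools:
            # Fail fast: an empty list would silently block every tool.
            raise ValueError(f"Phase {phase.value!r} has an empty whitelist")
        merged[phase] = frozenset(tools)
    return merged
```

Phases absent from the override keep their defaults, so a config that only tightens Planning leaves Building, Verification, and Delivery untouched.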
**Patterns to follow**:
- `src/agentkit/core/fallback.py` for dataclass + classmethod factory pattern.
- `src/agentkit/server/config.py` `from_dict` extension template (established in Wave 1 for `prompt_cache`/`streaming`/`verification`; Wave 2 added `rollback`/`fallback_chain`; Wave 3 adds `plan_exec`).
**Test scenarios** (covers R25, R26):
- **Happy paths**:
- `default_policy()` returns policy with all four phases; Planning whitelist contains `search`, `read_file`; Building contains `write_file`.
- `PhasePolicy.is_tool_allowed("search", PhaseState.PLANNING)` returns True.
- `PhasePolicy.is_tool_allowed("write_file", PhaseState.PLANNING)` returns False.
- `PhasePolicy.is_tool_allowed("write_file", PhaseState.BUILDING)` returns True.
- **Edge cases**:
- Empty whitelist for a phase → all tools rejected (raises `ValueError` at construction time — fail-fast).
- `Delivery` phase whitelist contains `"*"` (wildcard) → all tools allowed.
- Custom whitelist override merges with default (override wins on conflict).
- **Error paths**:
- Invalid phase name in config → `ValueError` with message naming the bad value.
- `bash_command_filter` regex compile failure → `ValueError`.
- **Config integration**:
- `ServerConfig.from_dict({"plan_exec": {"enabled": True, "auto_advance_after_steps": 5}})` populates fields correctly.
- `ServerConfig.from_dict({})` → `plan_exec = {}` (default).
**Verification**:
- `python3 -m pytest tests/unit/test_phase_policy.py -q` passes.
- `ruff check src/agentkit/core/phase.py` clean.
---
### U3. AdvancePhaseTool + ReActEngine phase enforcement (G6 wiring)
**Goal**: Add `AdvancePhaseTool`. Wire `PhasePolicy` into `ReActEngine.execute` so each tool-call step checks `is_tool_allowed(tool_name, current_phase)` before dispatch; blocked calls return a structured error to the LLM ("Tool 'write_file' not allowed in Planning phase — call advance_phase first").
**Requirements**: R24.
**Dependencies**: U2.
**Files**:
- `src/agentkit/tools/advance_phase.py` (new) — `AdvancePhaseTool` calls `react_engine.advance_phase()`.
- `src/agentkit/core/react.py` (modify — add `phase_policy` param to `__init__` + `execute`; add `_current_phase` field; add `advance_phase()` method; enforce in `_execute_loop`).
- `tests/unit/test_react_phase_enforcement.py` (new).
**Approach**:
- `ReActEngine.__init__` accepts `phase_policy: PhasePolicy | None = None`. None = no enforcement (backward compat — all existing callers unaffected).
- `_current_phase: PhaseState | None` initialized from `phase_policy.start_phase` if policy set, else None.
- `advance_phase()` advances `_current_phase` to next enum value; raises `ValueError` if already at `DELIVERY`.
- In `_execute_loop`, before dispatching a tool call:
```python
if self._phase_policy is not None and self._current_phase is not None:
    if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
        # Inject structured error into conversation, do NOT dispatch tool.
        # This counts as a "step" for max_steps purposes.
        observation = {
            "error": "phase_violation",
            "message": f"Tool '{tool_name}' not allowed in {self._current_phase.value} phase",
            "current_phase": self._current_phase.value,
            "hint": "Call advance_phase to move to the next phase",
        }
        continue  # next loop iteration
```
- Auto-advance: if `phase_policy.auto_advance_after_steps` is set and `_steps_in_phase >= auto_advance_after_steps`, call `advance_phase()` automatically.
- `AdvancePhaseTool.execute()` calls the bound engine's `advance_phase()` and returns the new phase name. Registered only when `phase_policy` is not None.
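The phase bookkeeping described above can be isolated from the engine for testing. The sketch below uses a hypothetical `PhaseTracker` helper (the plan keeps this state directly on `ReActEngine`) and assumes the step counter resets on every transition and that auto-advance never fires in the final phase.

```python
import enum
from typing import Optional


class PhaseState(enum.Enum):
    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"


_ORDER = list(PhaseState)  # Enum iteration follows definition order


class PhaseTracker:
    """Hypothetical helper holding the current phase and a per-phase step count."""

    def __init__(
        self,
        start: PhaseState = PhaseState.PLANNING,
        auto_advance_after_steps: Optional[int] = None,
    ) -> None:
        self.current = start
        self.auto_advance_after_steps = auto_advance_after_steps
        self.steps_in_phase = 0

    def advance_phase(self) -> PhaseState:
        """Move to the next phase; raise if already at the last one."""
        idx = _ORDER.index(self.current)
        if idx + 1 >= len(_ORDER):
            raise ValueError("Already at final phase")
        self.current = _ORDER[idx + 1]
        self.steps_in_phase = 0
        return self.current

    def record_step(self) -> None:
        """Count one loop step; auto-advance once the threshold is reached."""
        self.steps_in_phase += 1
        limit = self.auto_advance_after_steps
        if (
            limit is not None
            and self.steps_in_phase >= limit
            and self.current is not PhaseState.DELIVERY
        ):
            self.advance_phase()
```

With `auto_advance_after_steps=3`, the third recorded step triggers the transition, so the fourth step already runs under the next phase's whitelist.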
**Execution note**: characterization-first — test that `ReActEngine` with `phase_policy=None` behaves identically to pre-change (no enforcement, no `advance_phase` tool, no `_current_phase` mutation). Then add enforcement tests.
**Patterns to follow**:
- `src/agentkit/core/react.py` `verification_enabled` pattern (feature flag + step-level check).
- `src/agentkit/tools/ask_human.py` for tool that interacts with engine state.
**Test scenarios** (covers R24):
- **Characterization (no policy)**:
- `ReActEngine(phase_policy=None)` — all tools allowed in all steps; no `advance_phase` tool registered; behavior matches pre-change.
- **Happy paths**:
- Planning phase: LLM calls `search` → executes; LLM calls `advance_phase` → phase becomes Building.
- Building phase: LLM calls `write_file` → executes; LLM calls `advance_phase` → phase becomes Verification.
- Verification phase: LLM calls `bash` with `pytest` → executes; LLM calls `advance_phase` → phase becomes Delivery.
- Delivery phase: LLM calls any tool → executes (wildcard).
- **Edge cases**:
- `advance_phase` called at Delivery → returns error "Already at final phase".
- Auto-advance after 3 steps in Planning → phase transitions automatically on 4th step.
- `bash` command in Planning contains `rm file` → blocked by `bash_command_filter`.
- **Error paths**:
- LLM calls `write_file` in Planning → tool NOT dispatched; structured error returned to LLM; loop continues.
- LLM calls non-existent tool → existing error path (not phase-related).
- **Integration scenarios**:
- Phase transition emits a `phase_changed` event (use existing `_broadcast_event` pattern from `experts/orchestrator.py`).
- `max_steps` reached mid-phase → `ReActResult.status = "max_steps_reached"` (existing path, no change).
**Verification**:
- `python3 -m pytest tests/unit/test_react_phase_enforcement.py -q` passes.
- Existing `tests/unit/test_react_engine.py` still passes (characterization — no policy = no change).
- `ruff check src/agentkit/core/react.py src/agentkit/tools/advance_phase.py` clean.
---
### U4. Wire PLAN_EXEC at chat.py WebSocket path (G6 chat integration)
**Goal**: Replace the `chat.py:1084` "not yet supported, falling back to REACT" warning with a real PLAN_EXEC handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and dispatches to `ReActEngine.execute` with the policy set.
**Requirements**: R24, R25 (end-to-end wiring).
**Dependencies**: U2, U3.
**Files**:
- `src/agentkit/server/routes/chat.py` (modify — add `_execute_plan_exec_ws` handler; branch on `ExecutionMode.PLAN_EXEC`).
- `tests/unit/test_chat_plan_exec_ws.py` (new).
**Approach**:
- New helper `_execute_plan_exec_ws(websocket, agent, routing, messages, ...)`:
  1. Read `server_config.plan_exec` (may be `{}` if not configured → use `default_policy()`).
  2. Build `PhasePolicy` from config (apply overrides).
  3. Construct `ReActEngine(..., phase_policy=policy)`.
  4. Register `AdvancePhaseTool` bound to this engine.
  5. Call `engine.execute_stream(...)`, reusing the existing streaming path.
  6. Emit `phase_changed` events through the WebSocket (frontend can render phase indicator).
- chat.py:1084 changes from `if execution_mode not in (REACT, SKILL_REACT): warn + fall back` to:
```python
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
    await _execute_plan_exec_ws(websocket, agent, routing, ...)
    return
if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
    # existing warning for REWOO/REFLEXION/TEAM_COLLAB
    ...
```
- REST `send_message` path: explicitly raise `HTTPException(501, "PLAN_EXEC via REST not yet supported; use WebSocket")` — Wave 3 does NOT wire REST (KTD4).
**Execution note**: characterization-first — test that existing REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning (no regression). Then add PLAN_EXEC wiring.
**Patterns to follow**:
- `src/agentkit/server/routes/chat.py` existing WebSocket handler structure (lines 1082-1100).
- `src/agentkit/server/_fallback_chain.py` (Wave 2 U3) for "construct engine per-request with config" pattern.
**Test scenarios** (covers end-to-end):
- **Characterization**:
- `ExecutionMode.REWOO` via WebSocket → still falls back to REACT with warning (existing behavior unchanged).
- `ExecutionMode.REFLEXION` → same.
- `ExecutionMode.TEAM_COLLAB` → same.
- **Happy paths**:
- `ExecutionMode.PLAN_EXEC` via WebSocket → `_execute_plan_exec_ws` invoked; `ReActEngine` constructed with `phase_policy`; `AdvancePhaseTool` registered.
- Planning phase: LLM emits `search` tool call → executed; tool result streamed.
- LLM emits `advance_phase` → `phase_changed` event sent to WebSocket client; subsequent `write_file` call now allowed.
- **Edge cases**:
- `plan_exec` config absent → `default_policy()` used; behavior matches KTD5 whitelist.
- `plan_exec.enabled=False` → falls back to REACT (opt-out).
- Phase violation: LLM calls `write_file` in Planning → structured error returned; loop continues; `phase_violation` event emitted.
- **Error paths**:
- REST `send_message` with PLAN_EXEC → 501 error.
- Phase policy construction fails (bad config) → 500 error with message.
- **Integration scenarios**:
- Existing fallback chain (Wave 2 U3) NOT applied to PLAN_EXEC — phase policy and fallback chain are mutually exclusive (KTD5 from Wave 2 plan: chain only wraps REACT/SKILL_REACT at REST). Document this in chat.py comment.
**Verification**:
- `python3 -m pytest tests/unit/test_chat_plan_exec_ws.py -q` passes.
- `ruff check src/agentkit/server/routes/chat.py` clean.
- Manual test: `agentkit chat` with `@skill:plan_exec_demo` skill config → WebSocket stream includes `phase_changed` events.
---
## Risks & Dependencies
### Risks
1. **ReAct core modification risk (high)**: U3 modifies `ReActEngine._execute_loop`. Mitigation: characterization-first tests (U3 Execution note); `phase_policy=None` default preserves all existing behavior; full `test_react_engine.py` regression.
2. **Symbol extraction accuracy (medium)**: Regex extractor may miss edge cases (decorated functions, nested generics, multi-line signatures). Mitigation: fall back to "no symbols found → read full file" gracefully; never raise on extraction failure.
3. **PLAN_EXEC phase deadlock (medium)**: LLM may never call `advance_phase`, leaving the agent stuck in Planning. Mitigation: opt-in `auto_advance_after_steps` config (off by default per KTD6; e.g. 5); timeout via existing `max_steps`.
4. **Tool name drift (low)**: Phase whitelist references tool names (`write_file`, `search`, etc.) that may be renamed in future. Mitigation: whitelist is config-driven; rename only requires config update.
### Dependencies
- Wave 2 PR #5 (`feat/agent-wave2-medium-coupling`) should be merged first — Wave 3 builds on the `ServerConfig.from_dict` extension pattern and the `_fallback_chain.py` integration shape established there. If PR #5 is still open, Wave 3 branches from `feat/agent-wave2-medium-coupling` rather than `main`.
- No external library dependencies (KTD1).
## System-Wide Impact
- **Agents using PLAN_EXEC mode**: gain phase enforcement. Existing REACT/SKILL_REACT/DIRECT_CHAT agents: zero change (phase_policy defaults to None).
- **Tool registry**: gains two new tools (`read_file`, `advance_phase`). Frontend tool list display may need updating to show the new icons — out of scope for Wave 3 (frontend follows up).
- **`agentkit.yaml`**: gains `plan_exec:` section (commented by default). Existing configs unaffected.
- **WebSocket clients**: gain `phase_changed` event type. Existing clients ignore unknown event types (verified in Wave 2 — `phase_rollback_*` events follow the same pattern).
## Sources & Research
- Origin brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 3 section, KTD6/KTD7).
- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (PR #4 merged).
- Wave 2 plan: `docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md` (PR #5 open).
- Trae Work architecture research (cited in brainstorm): SOLO four-stage state machine pattern.
- Qoder architecture research (cited in brainstorm): Spec→Coding→Verify closed loop.
- Codebase: `src/agentkit/core/react.py:148` reserves `read_file`/`write_file` tool names in `_DEFAULT_CORE_TOOLS` — Wave 3 U1 delivers the missing `read_file` implementation.
- Codebase: `src/agentkit/server/routes/chat.py:1084` documents that PLAN_EXEC is "not yet supported" — Wave 3 U4 closes this gap.
## Deferred to Implementation
- Exact regex patterns for non-Python symbol extraction (U1) — design above gives the shape; implementer finalizes patterns based on real-world test fixtures.
- `bash_command_filter` regex precision (U2) — defaults block `rm`/`mv`/`>`/`>>`; implementer may add more based on test scenarios.
- `phase_changed` event payload shape (U3/U4) — minimal viable shape: `{"phase": "building", "previous": "planning"}`; frontend rendering concerns are out of scope.
- Whether `AdvancePhaseTool` accepts a `target_phase` argument for skipping phases (e.g., Planning → Verification) — default no (sequential only); add if test scenarios reveal a need.

src/agentkit/core/phase.py Normal file

@ -0,0 +1,206 @@
"""Phase state machine for PLAN_EXEC mode (G6, R24/R25).

Four sequential phases enforce per-step tool whitelists:
    PLANNING → BUILDING → VERIFICATION → DELIVERY

KTD3 (Wave 3 plan): state machine lives in ReActEngine, not skill config.
KTD5: default whitelist matches brainstorm R24 (Planning: think/search;
Building: write_file; etc.).
KTD6: transitions are LLM-driven via AdvancePhaseTool; auto-advance is opt-in.
"""

from __future__ import annotations

import enum
import logging
import re
from dataclasses import dataclass, field, replace
from typing import Any

logger = logging.getLogger(__name__)


class PhaseState(enum.Enum):
    """Phases of the SOLO state machine (extends ExecutionMode.PLAN_EXEC)."""

    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"

    @classmethod
    def next_of(cls, current: "PhaseState") -> "PhaseState | None":
        """Return the phase after `current`, or None if `current` is the last."""
        order = [cls.PLANNING, cls.BUILDING, cls.VERIFICATION, cls.DELIVERY]
        try:
            idx = order.index(current)
        except ValueError:
            return None
        if idx + 1 >= len(order):
            return None
        return order[idx + 1]

    @classmethod
    def from_string(cls, value: str) -> "PhaseState":
        """Parse from string (case-insensitive). Raises ValueError on unknown."""
        try:
            return cls(value.lower())
        except ValueError as e:
            valid = ", ".join(p.value for p in cls)
            raise ValueError(f"Invalid phase name {value!r}. Valid: {valid}") from e


# Wildcard token meaning "all tools allowed in this phase".
WILDCARD = "*"
# Default bash command filter for PLANNING and VERIFICATION phases — blocks
# commands that mutate the filesystem or execute arbitrary code.
# ponytail: regex is intentionally conservative; misses some shell idioms
# (e.g., `:>file`, `dd of=file`). Ceiling: a real shell parser would catch
# more. Upgrade path = reuse ShellTool._is_dangerous() at enforcement time.
# Note: `\b` is a word boundary — works for word commands (rm/mv) but NOT
# for `>`/`>>` operators (not word chars). Use a non-boundary alternation
# that matches `>` either as a standalone operator or after whitespace.
_DEFAULT_BASH_FILTER = re.compile(r"\b(rm|mv|cp|mkdir|rmdir|chmod|chown)\b|(?<!\S)>|>>")
@dataclass(slots=True)
class PhasePolicy:
"""Per-phase tool whitelist + bash command filter for PLAN_EXEC mode.
The policy is enforced by ReActEngine._execute_loop before each tool
dispatch. A tool not in the current phase's whitelist is rejected with
a structured error returned to the LLM. The loop continues, so the LLM
can react to the rejection and either switch tools or call
AdvancePhaseTool.
Wildcard ``"*"`` in a phase's whitelist means "all tools allowed"
(used by DELIVERY by default).
"""
whitelist: dict[PhaseState, frozenset[str]]
bash_command_filter: dict[PhaseState, re.Pattern | None] = field(default_factory=dict)
auto_advance_after_steps: int | None = None # None = manual (LLM calls advance_phase)
start_phase: PhaseState = PhaseState.PLANNING
def __post_init__(self) -> None:
# Fail-fast: empty whitelist for a non-wildcard phase = bug.
for phase, tools in self.whitelist.items():
if not tools:
raise ValueError(
f"Phase {phase.value!r} has an empty whitelist — set ['*'] for "
f"'all tools allowed' or list specific tool names."
)
def is_tool_allowed(self, tool_name: str, phase: PhaseState) -> bool:
"""Return True if `tool_name` is allowed in `phase`."""
allowed = self.whitelist.get(phase, frozenset())
if WILDCARD in allowed:
return True
return tool_name in allowed
def is_bash_command_allowed(self, command: str, phase: PhaseState) -> bool:
"""Return True if `command` passes the bash filter for `phase`.
A None filter = no restriction. An empty command is allowed (ShellTool
separately rejects empty commands).
"""
pattern = self.bash_command_filter.get(phase)
if pattern is None:
return True
return not pattern.search(command)
def to_dict(self) -> dict[str, Any]:
"""Serialize for logging/telemetry. Not round-trippable (regex → str)."""
return {
"whitelist": {phase.value: sorted(tools) for phase, tools in self.whitelist.items()},
"bash_command_filter": {
phase.value: (p.pattern if p else None)
for phase, p in self.bash_command_filter.items()
},
"auto_advance_after_steps": self.auto_advance_after_steps,
"start_phase": self.start_phase.value,
}
def default_policy() -> PhasePolicy:
"""Return the KTD5 default PhasePolicy.
Whitelist (R24):
- PLANNING: search, tool_search, read_file, shell (read-only)
- BUILDING: write_file, shell (full), read_file, search, tool_search
- VERIFICATION: shell (test commands), read_file, search
- DELIVERY: all tools (wildcard)
Bash filter:
- PLANNING/VERIFICATION: blocks filesystem-mutating commands
(rm/mv/cp/mkdir/rmdir/chmod/chown/>/>>)
- BUILDING/DELIVERY: no filter (full bash)
"""
return PhasePolicy(
whitelist={
# Tool name is "shell" (ShellTool default); bash_command_filter
# gates on the same name. Using "bash" here would make the filter
# dead code and block the LLM from shell access.
PhaseState.PLANNING: frozenset({"search", "tool_search", "read_file", "shell"}),
PhaseState.BUILDING: frozenset(
{"write_file", "shell", "read_file", "search", "tool_search"}
),
PhaseState.VERIFICATION: frozenset({"shell", "read_file", "search"}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
},
bash_command_filter={
PhaseState.PLANNING: _DEFAULT_BASH_FILTER,
PhaseState.VERIFICATION: _DEFAULT_BASH_FILTER,
PhaseState.BUILDING: None,
PhaseState.DELIVERY: None,
},
auto_advance_after_steps=None, # manual by default
start_phase=PhaseState.PLANNING,
)
def policy_from_config(config: dict[str, Any]) -> PhasePolicy | None:
"""Build a PhasePolicy from the `plan_exec` config section.
Returns None if `config` is empty or `enabled` is False (opt-out).
Config shape:
plan_exec:
enabled: true # default true if section present
auto_advance_after_steps: 5 # optional
start_phase: planning # optional, default planning
whitelist_override: # optional; a listed phase replaces its default list
planning: [search, read_file]
building: [write_file, shell]
"""
if not config:
return None
if config.get("enabled", True) is False:
return None
policy = default_policy()
# Start phase
start_phase_str = config.get("start_phase")
if start_phase_str:
policy = replace(policy, start_phase=PhaseState.from_string(start_phase_str))
# Auto-advance override
if "auto_advance_after_steps" in config:
policy = replace(policy, auto_advance_after_steps=config["auto_advance_after_steps"])
# Whitelist override: a listed phase replaces its default tool list; unlisted phases keep the default.
override = config.get("whitelist_override") or {}
if override:
new_whitelist = dict(policy.whitelist)
for phase_name, tools in override.items():
phase = PhaseState.from_string(phase_name)
if not isinstance(tools, list):
raise ValueError(
f"whitelist_override[{phase_name!r}] must be a list, got {type(tools).__name__}"
)
new_whitelist[phase] = frozenset(str(t) for t in tools)
policy = replace(policy, whitelist=new_whitelist)
return policy
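
The default bash filter is easier to reason about with a few concrete commands. The sketch below copies the `_DEFAULT_BASH_FILTER` pattern from the module above into a standalone script; the helper name `is_blocked` and the sample commands are illustrative assumptions, not part of the commit. It shows both a gap (redirection with no preceding whitespace passes) and a false positive (`cp` used as a search term is rejected), which matches the "intentionally conservative" caveat in the source comment.

```python
import re

# Pattern copied verbatim from _DEFAULT_BASH_FILTER in the diff above.
BASH_FILTER = re.compile(r"\b(rm|mv|cp|mkdir|rmdir|chmod|chown)\b|(?<!\S)>|>>")


def is_blocked(command: str) -> bool:
    """Hypothetical helper: True if the filter would reject `command`."""
    return BASH_FILTER.search(command) is not None


if __name__ == "__main__":
    samples = [
        "ls -la",            # read-only, no match
        "rm -rf build",      # word command match
        "echo hi > out.txt", # `>` preceded by whitespace
        "echo hi >> log",    # append redirection
        "echo hi>out.txt",   # `>` glued to a word: not matched by this pattern
        "grep -rn cp src",   # `cp` as a plain word: matched even though harmless
    ]
    for cmd in samples:
        print(f"{cmd!r}: blocked={is_blocked(cmd)}")
```

If the glued-redirection case matters in practice, the upgrade path named in the source comment (reusing `ShellTool._is_dangerous()` at enforcement time) is the place to close it, rather than growing the regex.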


@ -28,6 +28,7 @@ from agentkit.telemetry.metrics import (
if TYPE_CHECKING:
from agentkit.core.compressor import CompressionStrategy
from agentkit.core.middleware import MiddlewareChain
from agentkit.core.phase import PhasePolicy, PhaseState
from agentkit.core.trace import TraceRecorder
from agentkit.memory.retriever import MemoryRetriever
@ -168,6 +169,9 @@ class ReActEngine:
prompt_cache_enable: bool = True,
flush_interval_ms: int = 0,
max_reinjections: int = 1,
# U3/G6: PLAN_EXEC phase policy (opt-in). None = no enforcement
# (backward compat — all existing callers unaffected).
phase_policy: "PhasePolicy | None" = None,
):
if max_steps < 1:
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
@ -211,6 +215,15 @@ class ReActEngine:
self._loop_corrected: bool = False
# U6: Middleware chain (parallel integration, feature flag controlled)
self._middleware_chain = middleware_chain
# U3/G6: PLAN_EXEC phase state. None = no enforcement (default).
# When set, _execute_loop checks each tool call against the current
# phase's whitelist before dispatch.
self._phase_policy = phase_policy
self._current_phase: "PhaseState | None" = (
phase_policy.start_phase if phase_policy is not None else None
)
# Steps taken in the current phase (for auto-advance safety net).
self._steps_in_phase: int = 0
def reset(self) -> None:
"""Reset internal state for reuse across conversations.
@ -223,6 +236,99 @@ class ReActEngine:
# This method exists for API clarity and future stateful extensions.
self._loop_window.clear()
self._loop_corrected = False
# U3/G6: reset phase state to start_phase (if policy set). Each
# execute() call begins again at start_phase (PLANNING by default).
if self._phase_policy is not None:
self._current_phase = self._phase_policy.start_phase
self._steps_in_phase = 0
# ── U3/G6: phase state machine ────────────────────────────────────
def advance_phase(self) -> "PhaseState | None":
"""Advance to the next phase. Returns the new phase, or None if
already at DELIVERY (final phase).
Called by AdvancePhaseTool when the LLM explicitly signals phase
completion. Also called by the auto-advance safety net when
``steps_in_phase >= auto_advance_after_steps``.
Returns None if no phase_policy is set (no-op).
"""
if self._phase_policy is None or self._current_phase is None:
return None
from agentkit.core.phase import PhaseState
nxt = PhaseState.next_of(self._current_phase)
if nxt is None:
# Already at DELIVERY — return None to signal no transition.
return None
previous = self._current_phase
self._current_phase = nxt
self._steps_in_phase = 0
logger.info(
"Phase transition: %s -> %s",
previous.value,
nxt.value,
)
return nxt
@property
def current_phase(self) -> "PhaseState | None":
"""Current phase (None if no phase_policy set)."""
return self._current_phase
def _maybe_auto_advance(self) -> bool:
"""Auto-advance phase if step budget exhausted. Returns True if advanced."""
if self._phase_policy is None or self._current_phase is None:
return False
threshold = self._phase_policy.auto_advance_after_steps
if threshold is None:
return False
if self._steps_in_phase >= threshold:
self.advance_phase()
return True
return False
def _check_phase_permission(
self, tool_name: str, arguments: dict[str, Any]
) -> dict[str, Any] | None:
"""Return None if tool is allowed; return a structured error dict if blocked.
The error dict replaces what `_execute_tool` would have returned. The
loop continues, so the LLM can react to the rejection (call
AdvancePhaseTool or pick a different tool).
Also applies the bash_command_filter to `shell` tool calls.
"""
if self._phase_policy is None or self._current_phase is None:
return None
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
return {
"error": "phase_violation",
"message": (
f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. "
f"Call `advance_phase` to move to the next phase."
),
"current_phase": self._current_phase.value,
"tool": tool_name,
"is_error": True,
}
# Bash command filter (only applies to shell tool — registered as "shell").
if tool_name == "shell":
command = str(arguments.get("command", ""))
if not self._phase_policy.is_bash_command_allowed(command, self._current_phase):
return {
"error": "phase_violation",
"message": (
f"Bash command blocked in {self._current_phase.value} phase "
f"(filesystem-mutating operations not allowed during "
f"planning/verification). Command: {command[:100]}"
),
"current_phase": self._current_phase.value,
"tool": tool_name,
"is_error": True,
}
return None
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
"""Detect repeated tool-call patterns.
@ -498,6 +604,14 @@ class ReActEngine:
if cancellation_token is not None:
cancellation_token.check()
# U3/G6: phase auto-advance safety net.
# Incremented per step (LLM call), not per tool_call. When
# auto_advance_after_steps is set, advance the phase after
# the LLM has been stuck in the same phase for N steps.
if self._phase_policy is not None:
self._steps_in_phase += 1
self._maybe_auto_advance()
# Think: call the LLM
llm_start = time.monotonic()
response = await self._llm_gateway.chat(
@ -1148,6 +1262,11 @@ class ReActEngine:
if cancellation_token is not None:
cancellation_token.check()
# U3/G6: phase auto-advance safety net (mirrors _execute_loop).
if self._phase_policy is not None:
self._steps_in_phase += 1
self._maybe_auto_advance()
# Timeout check
if effective_timeout > 0:
elapsed = time.monotonic() - _stream_start
@ -2069,6 +2188,20 @@ class ReActEngine:
self, tool_name: str, arguments: dict[str, Any], tools: list[Tool]
) -> dict:
"""Execute a tool call, handling both success and failure."""
# U3/G6: phase enforcement — check before dispatch. If the tool is
# blocked, return a structured error instead of dispatching. The loop
# still counts this as a step (the LLM gets to react to the rejection).
# `advance_phase` tool bypasses the check (it's the LLM's escape hatch).
if tool_name != "advance_phase":
block = self._check_phase_permission(tool_name, arguments)
if block is not None:
logger.info(
"Phase violation: tool %r blocked in %s phase",
tool_name,
self._current_phase.value if self._current_phase else "?",
)
return block
tool = self._find_tool(tool_name, tools)
if tool is None:
error_msg = f"Tool '{tool_name}' not found"
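
The auto-advance hook above increments `_steps_in_phase` before the LLM call and then checks the threshold, so the phase switch happens before the Nth call rather than after it. The following is a minimal standalone model of that ordering, not the engine itself; `Phase`, `ORDER`, and `simulate` are hypothetical names introduced for illustration.

```python
import enum


class Phase(enum.Enum):
    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"


ORDER = list(Phase)  # enum members iterate in definition order


def simulate(threshold: int, total_steps: int) -> list[str]:
    """Return the phase each LLM call would run in (increment, then check)."""
    phase = Phase.PLANNING
    steps_in_phase = 0
    seen: list[str] = []
    for _ in range(total_steps):
        steps_in_phase += 1
        if steps_in_phase >= threshold:
            idx = ORDER.index(phase)
            if idx + 1 < len(ORDER):
                # Mirrors advance_phase(): move on and reset the counter.
                phase = ORDER[idx + 1]
                steps_in_phase = 0
        seen.append(phase.value)  # the LLM call happens after the check
    return seen


if __name__ == "__main__":
    print(simulate(threshold=2, total_steps=6))
```

Under this model, with a threshold of 2 the start phase receives a single LLM call while later phases receive two each. Whether that asymmetry is intended is a question for the author; the model only makes the ordering visible.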


@ -121,6 +121,9 @@ class ServerConfig:
verification: dict[str, Any] | None = None,
rollback: dict[str, Any] | None = None,
fallback_chain: dict[str, Any] | None = None,
# G6/U2: PLAN_EXEC phase policy config (opt-in — None = disabled).
# Parsed via agentkit.core.phase.policy_from_config() at chat.py wiring time.
plan_exec: dict[str, Any] | None = None,
on_change: Callable[["ServerConfig"], None] | None = None,
):
self.host = host
@ -161,6 +164,10 @@ class ServerConfig:
# G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries}
# controls three-tier chain at chat.py REST send_message (KTD5).
self.fallback_chain = fallback_chain or {}
# G6/U2: plan_exec phase policy config (opt-in — empty dict = disabled).
# Resolved to PhasePolicy via agentkit.core.phase.policy_from_config()
# at chat.py WebSocket wiring time (U4).
self.plan_exec = plan_exec or {}
self.on_change = on_change
# Config watching state
@ -252,6 +259,8 @@ class ServerConfig:
rollback_data = data.get("rollback", {})
# G7/U3: fallback_chain config (read from YAML)
fallback_chain_data = data.get("fallback_chain", {})
# G6/U2: plan_exec phase policy config (read from YAML, opt-in)
plan_exec_data = data.get("plan_exec", {})
return cls(
host=server.get("host", "0.0.0.0"),
@ -285,6 +294,7 @@ class ServerConfig:
verification=verification_data,
rollback=rollback_data,
fallback_chain=fallback_chain_data,
plan_exec=plan_exec_data,
)
@staticmethod @staticmethod
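
The `plan_exec` dict stored here is later resolved by `policy_from_config()`, where `whitelist_override` works per phase: a listed phase replaces its default tool set entirely, and unlisted phases keep theirs. The sketch below reproduces that behavior on plain dicts so it runs without `agentkit`; `DEFAULTS` copies the default whitelist from the diff, while `apply_override` is a hypothetical stand-in for the loop inside `policy_from_config()`.

```python
# Default whitelist copied from default_policy() in the diff, keyed by phase name.
DEFAULTS: dict[str, frozenset[str]] = {
    "planning": frozenset({"search", "tool_search", "read_file", "shell"}),
    "building": frozenset({"write_file", "shell", "read_file", "search", "tool_search"}),
    "verification": frozenset({"shell", "read_file", "search"}),
    "delivery": frozenset({"*"}),
}


def apply_override(
    defaults: dict[str, frozenset[str]], override: dict[str, list[str]]
) -> dict[str, frozenset[str]]:
    """Hypothetical stand-in for the whitelist_override loop."""
    merged = dict(defaults)
    for phase_name, tools in override.items():
        if not isinstance(tools, list):
            raise ValueError(f"whitelist_override[{phase_name!r}] must be a list")
        # The whole set for this phase is replaced, not unioned with the default.
        merged[phase_name] = frozenset(str(t) for t in tools)
    return merged


if __name__ == "__main__":
    merged = apply_override(DEFAULTS, {"planning": ["search", "read_file"]})
    print(sorted(merged["planning"]))   # shell and tool_search are gone
    print(sorted(merged["building"]))   # untouched default
```

One consequence worth noting: an override that omits `shell` for a phase removes shell access in that phase, so the bash filter for it is never consulted.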


@ -25,11 +25,13 @@ from fastapi.responses import FileResponse
from pydantic import BaseModel
from agentkit.chat.skill_routing import ExecutionMode
from agentkit.core.phase import PhasePolicy, default_policy, policy_from_config
from agentkit.core.protocol import CancellationToken
from agentkit.core.react import ReActEngine
from agentkit.server._fallback_chain import execute_with_fallback_chain
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
from agentkit.tools.advance_phase import AdvancePhaseTool
logger = logging.getLogger(__name__)
@ -47,6 +49,8 @@ class CreateSessionRequest(BaseModel):
class SendMessageRequest(BaseModel):
content: str
role: str = "user"
# Optional execution mode override. "plan_exec" → 501 (KTD4: WebSocket only).
execution_mode: str | None = None
class SessionResponse(BaseModel):
@ -583,6 +587,13 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
if session.status == SessionStatus.CLOSED:
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
# KTD4: PLAN_EXEC is wired only at the WebSocket path. REST raises 501.
if request.execution_mode == "plan_exec":
raise HTTPException(
status_code=501,
detail="PLAN_EXEC via REST not yet supported; use WebSocket",
)
# Append user message
await sm.append_message(
session_id=session_id,
@ -1079,21 +1090,73 @@ async def _handle_chat_message(
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
return
-# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
-# currently fall back to REACT with a warning.
-if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
+# U4/G6: PLAN_EXEC. Build PhasePolicy from server config (KTD4: WebSocket only).
+# KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC; phase policy and
+# fallback chain are mutually exclusive. PLAN_EXEC uses its own engine.
phase_policy: PhasePolicy | None = None
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
server_config = getattr(websocket.app.state, "server_config", None)
plan_exec_cfg = getattr(server_config, "plan_exec", None) or {}
if plan_exec_cfg.get("enabled", True) is False:
# Explicit opt-out → fall back to REACT.
logger.info(
"PLAN_EXEC disabled by config (plan_exec.enabled=False), "
"falling back to REACT for session %s",
session_id,
)
else:
try:
phase_policy = policy_from_config(plan_exec_cfg)
if phase_policy is None:
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
phase_policy = default_policy()
except Exception as e:
logger.error(
"PLAN_EXEC phase policy construction failed for session %s: %s",
session_id,
e,
)
await websocket.send_json(
{
"type": "error",
# Truncate to 200 chars to match nearby error paths and
# avoid leaking config internals (see chat.py:1090, 1320).
"data": {"message": f"phase policy error: {str(e)[:200]}"},
}
)
return
# Handle advanced execution modes: REWOO/REFLEXION/TEAM_COLLAB
# still fall back to REACT with a warning. PLAN_EXEC is handled above.
if routing.execution_mode not in (
ExecutionMode.REACT,
ExecutionMode.SKILL_REACT,
ExecutionMode.PLAN_EXEC,
):
logger.warning(
f"Execution mode {routing.execution_mode.value} not yet supported "
f"in chat WebSocket, falling back to REACT"
)
# Execute Agent with streaming
-# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
-react_engine = getattr(agent, "_react_engine", None)
-if react_engine is None:
-react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
+# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization).
+# PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the
+# agent's _react_engine, which has no policy).
+if phase_policy is not None:
react_engine = ReActEngine(
llm_gateway=websocket.app.state.llm_gateway,
phase_policy=phase_policy,
)
# Register AdvancePhaseTool bound to this engine (LLM's escape hatch).
advance_phase_tool = AdvancePhaseTool(engine=react_engine)
routing.tools = list(routing.tools) + [advance_phase_tool]
else:
-react_engine.reset()
+react_engine = getattr(agent, "_react_engine", None)
if react_engine is None:
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
else:
react_engine.reset()
# Create confirmation handler that sends request to frontend and waits for reply
# Use the same dict object: do NOT use `or {}` because an empty dict is falsy
@ -1149,6 +1212,9 @@ async def _handle_chat_message(
try:
final_content = ""
token_buffer: list[str] = []
# Track phase transitions for phase_changed events (PLAN_EXEC only).
# For non-PLAN_EXEC modes, current_phase is always None → no events.
prev_phase = react_engine.current_phase
async for event in react_engine.execute_stream(
messages=chat_messages,
tools=routing.tools,
@ -1226,6 +1292,22 @@ async def _handle_chat_message(
}
)
# U4/G6: emit phase_changed event when the phase state machine
# transitions (PLAN_EXEC only). For non-PLAN_EXEC modes,
# current_phase is always None → this branch never fires.
curr_phase = react_engine.current_phase
if curr_phase != prev_phase:
await websocket.send_json(
{
"type": "phase_changed",
"data": {
"phase": curr_phase.value if curr_phase else None,
"previous": prev_phase.value if prev_phase else None,
},
}
)
prev_phase = curr_phase
# Append assistant reply to session
if final_content:
await sm.append_message(
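
The `phase_changed` emission above is a plain edge detector: compare the engine's current phase with the last one seen, and send an event only on a change. A minimal sketch of that pattern, detached from the WebSocket and the engine, might look like the following; `phase_events` and its list-based input are illustrative assumptions, and the payload shape is copied from the diff.

```python
from typing import Optional


def phase_events(observed: list[Optional[str]], start: Optional[str]) -> list[dict]:
    """Return the phase_changed payloads for a sequence of observed phases.

    `observed` stands in for reading `react_engine.current_phase` once per
    streamed event; `start` is the phase captured before the stream began.
    """
    prev = start
    events: list[dict] = []
    for curr in observed:
        if curr != prev:
            events.append(
                {"type": "phase_changed", "data": {"phase": curr, "previous": prev}}
            )
            prev = curr
    return events


if __name__ == "__main__":
    stream = ["planning", "planning", "building", "building", "verification"]
    for ev in phase_events(stream, start="planning"):
        print(ev)
    # A non-PLAN_EXEC run always reports None, so nothing is emitted.
    print(phase_events([None, None], start=None))
```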


@ -18,6 +18,8 @@ from agentkit.tools.memory_tool import MemoryTool
from agentkit.tools.web_search import WebSearchTool
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
from agentkit.tools.search import ToolSearchIndex
from agentkit.tools.file_read import ReadFileTool
from agentkit.tools.advance_phase import AdvancePhaseTool
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
try:
@ -52,4 +54,6 @@ __all__ = [
"OutputParser",
"ParsedOutput",
"ErrorType",
"ReadFileTool",
"AdvancePhaseTool",
]


@ -0,0 +1,82 @@
"""AdvancePhaseTool — LLM-driven phase transition (G6, KTD6).
Registered alongside other tools when ReActEngine has a phase_policy set.
The LLM calls this tool to signal "I'm done planning, move to building".
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from agentkit.tools.base import Tool
if TYPE_CHECKING:
from agentkit.core.react import ReActEngine
logger = logging.getLogger(__name__)
class AdvancePhaseTool(Tool):
"""Tool that advances the ReActEngine's current phase.
KTD6: LLM-driven phase transitions. Auto-advance is opt-in via
``plan_exec.auto_advance_after_steps``; this tool is the manual path.
The tool holds a direct (strong) reference to the engine and calls
``engine.advance_phase``; it is registered only when phase_policy is set.
"""
def __init__(
self,
engine: "ReActEngine",
name: str = "advance_phase",
description: str | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description
or (
"Advance the PLAN_EXEC phase state machine to the next phase "
"(Planning → Building → Verification → Delivery). Call this "
"when you have finished the current phase's work and are ready "
"to move on. Returns the new phase name or an error if you "
"are already at the final (Delivery) phase."
),
input_schema={
"type": "object",
"properties": {},
"additionalProperties": False,
},
version=version,
tags=tags or ["phase", "control"],
)
self._engine = engine
async def execute(self, **kwargs) -> dict[str, Any]:
# Capture previous phase before transition (engine is single-threaded per request).
previous = self._engine.current_phase
new_phase = self._engine.advance_phase()
if new_phase is None:
# Either no policy set, or already at DELIVERY.
current = self._engine.current_phase
if current is None:
return {
"is_error": True,
"error": "no_phase_policy",
"message": "No phase policy is set — advance_phase is a no-op.",
}
return {
"is_error": True,
"error": "already_at_final_phase",
"message": (f"Already at final phase ({current.value}). Cannot advance further."),
"current_phase": current.value,
}
return {
"is_error": False,
"previous_phase": previous.value if previous else "",
"current_phase": new_phase.value,
"message": f"Phase advanced to {new_phase.value}.",
}
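
To see how the rejection and the escape hatch fit together, here is a small end-to-end model: a tool is refused in the planning phase, the phase is advanced, and the same tool is then accepted. Everything in it is a stand-in written for illustration (`MiniEngine`, `WHITELIST`, and the free function `advance` are hypothetical and do not import `agentkit`); only the result-dict keys follow the diff.

```python
import asyncio
import enum


class Phase(enum.Enum):
    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"


# Trimmed whitelist for the sketch; "*" means all tools allowed.
WHITELIST = {
    Phase.PLANNING: {"search", "read_file", "shell"},
    Phase.BUILDING: {"write_file", "shell", "read_file", "search"},
    Phase.VERIFICATION: {"shell", "read_file", "search"},
    Phase.DELIVERY: {"*"},
}


class MiniEngine:
    """Stand-in for the phase-related parts of ReActEngine."""

    def __init__(self) -> None:
        self.current_phase = Phase.PLANNING

    def advance_phase(self):
        order = list(Phase)
        idx = order.index(self.current_phase)
        if idx + 1 >= len(order):
            return None  # already at the final phase
        self.current_phase = order[idx + 1]
        return self.current_phase

    def check(self, tool_name: str):
        allowed = WHITELIST[self.current_phase]
        if "*" in allowed or tool_name in allowed:
            return None
        return {
            "error": "phase_violation",
            "tool": tool_name,
            "current_phase": self.current_phase.value,
            "is_error": True,
        }


async def advance(engine: MiniEngine) -> dict:
    """Stand-in for AdvancePhaseTool.execute()."""
    previous = engine.current_phase
    new_phase = engine.advance_phase()
    if new_phase is None:
        return {
            "is_error": True,
            "error": "already_at_final_phase",
            "current_phase": engine.current_phase.value,
        }
    return {
        "is_error": False,
        "previous_phase": previous.value,
        "current_phase": new_phase.value,
    }


if __name__ == "__main__":
    engine = MiniEngine()
    print(engine.check("write_file"))    # rejected while planning
    print(asyncio.run(advance(engine)))  # planning -> building
    print(engine.check("write_file"))    # None: now allowed
```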


@ -0,0 +1,262 @@
"""ReadFileTool — file reading with optional symbol-level sharding (G5, R22/R23).
Backward compatible with the pre-existing `_FakeTool` benchmark shape: when
`symbol=None`, returns the full file content. When `symbol="foo"`, returns
the line range of the first matching symbol via `SymbolExtractor`.
KTD2 (Wave 3 plan): dedicated tool, does NOT extend ShellTool. This keeps the
file-reading contract clean and gives the LLM a focused schema.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from agentkit.tools.base import Tool
from agentkit.tools.symbol_extractor import (
SymbolSpan,
extract_symbols_from_file,
language_for_extension,
)
logger = logging.getLogger(__name__)
class ReadFileTool(Tool):
"""Read a file from the filesystem, optionally sliced to a single symbol.
Tool name `read_file` matches the reserved entry in
`core/react.py:_DEFAULT_CORE_TOOLS` (which previously had no real
implementation, only `_FakeTool` stubs in `cli/benchmark.py`).
Backward-compat contract: `symbol=None` returns the full file content,
matching the shape `{"path": ...}` that downstream callers (benchmark,
phase whitelist) already expect.
"""
def __init__(
self,
name: str = "read_file",
description: str | None = None,
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description
or (
"Read a file from the filesystem. By default returns the full file "
"content. Pass `symbol` (function/class/struct name) to slice to just "
"that symbol's line range — saves context when you only need one "
"function from a large file. Pass `start_line`/`end_line` for manual "
"slicing. If `symbol` is set but not found, returns the available "
"symbol names so you can retry."
),
input_schema=input_schema or self._default_input_schema(),
output_schema=output_schema or self._default_output_schema(),
version=version,
tags=tags or ["io", "file", "read"],
)
@staticmethod
def _default_input_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to read (absolute or relative to cwd).",
},
"symbol": {
"type": "string",
"description": (
"Optional: name of a function/class/struct/method to slice to. "
"When set, returns only the line range of the first matching "
"symbol. Supported languages: py, ts/tsx, js/jsx, go, rs, java."
),
},
"start_line": {
"type": "integer",
"description": "Optional 1-based start line for manual slicing. Overrides `symbol`.",
"minimum": 1,
},
"end_line": {
"type": "integer",
"description": "Optional 1-based end line (inclusive) for manual slicing. Overrides `symbol`.",
"minimum": 1,
},
},
"required": ["path"],
"additionalProperties": False,
}
@staticmethod
def _default_output_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"content": {"type": "string"},
"path": {"type": "string"},
"start_line": {"type": "integer"},
"end_line": {"type": "integer"},
"symbol": {"type": "string"},
"available_symbols": {
"type": "array",
"items": {"type": "string"},
"description": "Populated when `symbol` is set but not found.",
},
"note": {"type": "string"},
"is_error": {"type": "boolean"},
"error": {"type": "string"},
},
}
async def execute(self, **kwargs) -> dict[str, Any]:
raw_path = kwargs.get("path")
if not raw_path:
return self._error("`path` is required")
path = Path(raw_path)
if not path.is_absolute():
path = path.resolve()
symbol = kwargs.get("symbol")
start_line = kwargs.get("start_line")
end_line = kwargs.get("end_line")
# Validate/sanitize line overrides.
if start_line is not None and (not isinstance(start_line, int) or start_line < 1):
return self._error(f"`start_line` must be a positive integer, got {start_line!r}")
if end_line is not None and (not isinstance(end_line, int) or end_line < 1):
return self._error(f"`end_line` must be a positive integer, got {end_line!r}")
if start_line is not None and end_line is not None and end_line < start_line:
return self._error(f"`end_line` ({end_line}) must be >= `start_line` ({start_line})")
# Filesystem checks.
if not path.exists():
return self._error(f"File not found: {path}", path=str(path))
if path.is_dir():
return self._error(f"Path is a directory, not a file: {path}", path=str(path))
try:
content = path.read_text(encoding="utf-8", errors="replace")
except PermissionError as e:
return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
except OSError as e:
return self._error(f"Failed to read {path}: {e}", path=str(path))
lines = content.splitlines()
total_lines = len(lines)
# Manual slicing takes precedence over symbol (per plan U1 Approach).
if start_line is not None or end_line is not None:
s = max(1, start_line or 1)
e = min(total_lines, end_line or total_lines)
sliced = "\n".join(lines[s - 1 : e])
return {
"content": sliced,
"path": str(path),
"start_line": s,
"end_line": e,
"total_lines": total_lines,
"is_error": False,
}
# Symbol slicing.
if symbol:
ext = path.suffix.lower()
language = language_for_extension(ext)
if not language:
# Unsupported extension: return full file with note (per plan U1 Edge case).
return {
"content": content,
"path": str(path),
"start_line": 1,
"end_line": total_lines,
"total_lines": total_lines,
"note": f"symbol extraction not supported for {ext or 'unknown extension'}",
"is_error": False,
}
spans, _lang = extract_symbols_from_file(path)
if not spans:
# Path-based extraction returned nothing (it may have failed silently
# on OSError). Retry on the content already in memory, which was read
# successfully above.
from agentkit.tools.symbol_extractor import get_extractor
extractor = get_extractor(language)
if extractor is not None:
spans = extractor.extract_symbols(content, language)
match = _find_symbol(spans, symbol)
if match is None:
available = sorted({s.name for s in spans})
return {
"content": "",
"path": str(path),
"symbol": symbol,
"available_symbols": available,
"is_error": False,
"note": (
f"Symbol {symbol!r} not found in {path.name}. "
f"Available: {', '.join(available) if available else '(none)'}"
),
}
s = match.start_line
e = min(match.end_line, total_lines)
sliced = "\n".join(lines[s - 1 : e])
return {
"content": sliced,
"path": str(path),
"symbol": symbol,
"symbol_kind": match.kind,
"start_line": s,
"end_line": e,
"total_lines": total_lines,
"is_error": False,
}
# Default: full file (characterization baseline — matches _FakeTool shape).
return {
"content": content,
"path": str(path),
"start_line": 1,
"end_line": total_lines,
"total_lines": total_lines,
"is_error": False,
}
@staticmethod
def _error(
message: str, *, path: str | None = None, detail: str | None = None
) -> dict[str, Any]:
result: dict[str, Any] = {
"content": "",
"is_error": True,
"error": message,
}
if path is not None:
result["path"] = path
if detail is not None:
result["detail"] = detail
return result
def _find_symbol(spans: list[SymbolSpan], name: str) -> SymbolSpan | None:
"""Find the first symbol matching `name`. Case-sensitive.
ponytail: linear scan is fine for typical file symbol counts (<100). The
extractor already returns symbols sorted by start_line; first match wins
for ambiguous names (e.g., same-named Python classes in different
modules, which is not relevant within one file).
"""
for span in spans:
if span.name == name:
return span
return None
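
The manual slicing branch of `ReadFileTool.execute()` uses 1-based, inclusive line numbers and clamps them to the file length. The sketch below isolates that arithmetic on an in-memory string so the edge cases are easy to try; `slice_lines` is a hypothetical helper, not a function in the commit.

```python
from typing import Optional


def slice_lines(
    content: str, start_line: Optional[int] = None, end_line: Optional[int] = None
) -> tuple[str, int, int]:
    """Return (sliced_text, start, end) using 1-based inclusive bounds."""
    lines = content.splitlines()
    total = len(lines)
    s = max(1, start_line or 1)
    e = min(total, end_line or total)
    # Python slices are 0-based and end-exclusive, hence s - 1 and e.
    return "\n".join(lines[s - 1 : e]), s, e


if __name__ == "__main__":
    text = "a\nb\nc\nd\n"
    print(slice_lines(text, 2, 3))   # middle of the file
    print(slice_lines(text, 3, 99))  # end_line clamped to the last line
    print(slice_lines(text, 10))     # start past EOF: empty content, start > end
```

The last case returns an empty slice with `start_line` greater than `end_line` rather than an error, which mirrors what the arithmetic in the diff would produce; callers may want to treat that combination as out of range.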


@ -0,0 +1,278 @@
"""Symbol extraction — locate code symbols (functions/classes/structs) by name.
KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex
for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The
`SymbolExtractor` protocol is the upgrade seam: a future TreeSitterSymbolExtractor
can replace RegexSymbolExtractor behind the same interface.
ponytail: regex extractor covers the ~80% case (top-level function/class/struct
declarations). Ceiling: misses nested signatures inside JSX/TSX generics,
multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter.
"""
from __future__ import annotations
import ast
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Protocol, runtime_checkable
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class SymbolSpan:
"""A located symbol — name, kind, and 1-based inclusive line range."""
name: str
kind: str # "function" | "class" | "method" | "struct" | "impl"
start_line: int # 1-based, inclusive
end_line: int # 1-based, inclusive
@runtime_checkable
class SymbolExtractor(Protocol):
"""Protocol for symbol extractors — runtime_checkable for isinstance/issubclass."""
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
"""Return all symbols found in `content`.
`language` is the file extension without leading dot (e.g. "py", "ts").
Implementations must never raise on extraction failure; return [] on
parse errors and let the caller decide the fallback (full-file read).
"""
...
# ---------------------------------------------------------------------------
# Python — stdlib ast
# ---------------------------------------------------------------------------
class AstSymbolExtractor:
"""Python symbol extractor using the stdlib `ast` module.
Captures every FunctionDef/AsyncFunctionDef/ClassDef at any nesting depth
(`ast.walk` visits the whole tree), including methods and nested functions.
start_line is the `def`/`class` line, so decorators above it are not included;
end_line is the node's end_lineno.
"""
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
if language != "py":
return []
try:
tree = ast.parse(content)
except SyntaxError as e:
logger.debug("ast.parse failed: %s", e)
return []
lines = content.splitlines()
spans: list[SymbolSpan] = []
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
kind = "method" if _is_method(node) else "function"
spans.append(_span_from_node(node, kind, lines))
elif isinstance(node, ast.ClassDef):
spans.append(_span_from_node(node, "class", lines))
return spans
def _is_method(node: ast.AST) -> bool:
    """Approximate whether a FunctionDef is a method.

    `ast.walk` doesn't expose parentage, so we approximate: any def with
    col_offset > 0 (indented inside some enclosing body) is treated as a
    method. ponytail: this also labels nested functions and defs inside
    `if`/`try` blocks as methods; ceiling noted. Upgrade path =
    ast.NodeVisitor with parent tracking.
    """
    return getattr(node, "col_offset", 0) > 0
def _span_from_node(
    node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
    kind: str,
    lines: list[str],
) -> SymbolSpan:
    # ast line numbers are 1-based. node.lineno points to the def/class keyword;
    # decorators sit on the lines above it and are not part of the span, so the
    # returned range starts at what the user sees at the def keyword.
    start = node.lineno
    # node.end_lineno is the last line of the node body (None on old Pythons).
    end = node.end_lineno or start
    # Clamp to actual file length (defensive: ast should not exceed it, but
    # malformed files with no trailing newline can confuse end_lineno).
    if end > len(lines):
        end = len(lines)
    return SymbolSpan(name=node.name, kind=kind, start_line=start, end_line=end)
# ---------------------------------------------------------------------------
# Regex extractor — TS/JS/Go/Rust/Java
# ---------------------------------------------------------------------------
# Each pattern matches a declaration and captures the symbol name in group 1.
# Patterns use re.MULTILINE so ^ matches line starts.
_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = {
"ts": [
(
"function",
re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE),
),
("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)),
(
"function",
re.compile(
r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
re.MULTILINE,
),
),
],
"js": [
("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)),
("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)),
(
"function",
re.compile(
r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE
),
),
],
"go": [
("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)),
("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)),
],
"rs": [
("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)),
("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)),
("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)),
],
"java": [
(
"function",
re.compile(
r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{",
re.MULTILINE,
),
),
(
"class",
re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE),
),
],
}
class RegexSymbolExtractor:
"""Language-aware regex symbol extractor for TS/JS/Go/Rust/Java.
Returns SymbolSpans whose end_line is approximated as the line before the
next symbol starts (or end of file for the last symbol). ponytail: this is an
approximation; true block-end requires language-aware brace matching.
Ceiling: deeply nested blocks may over-extend the range. Upgrade path =
tree-sitter.
"""
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
patterns = _REGEX_PATTERNS.get(language)
if not patterns:
return []
lines = content.splitlines()
# Collect (line_no, name, kind) tuples first, then compute end_line
# as the line before the next symbol starts (or EOF).
raw_hits: list[tuple[int, str, str]] = []
for kind, pattern in patterns:
for m in pattern.finditer(content):
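# Convert the offset of the captured name to a 1-based line number.
# m.start() is not used here: `\s` also matches newlines, so a pattern
# that begins with `^\s*` can start its match on a blank line above
# the declaration.
line_no = content[: m.start(1)].count("\n") + 1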
raw_hits.append((line_no, m.group(1), kind))
if not raw_hits:
return []
# Deduplicate: same (line_no, name) may appear for overlapping patterns.
seen: set[tuple[int, str]] = set()
unique: list[tuple[int, str, str]] = []
for line_no, name, kind in raw_hits:
key = (line_no, name)
if key in seen:
continue
seen.add(key)
unique.append((line_no, name, kind))
unique.sort(key=lambda x: x[0])
spans: list[SymbolSpan] = []
for i, (start_line, name, kind) in enumerate(unique):
if i + 1 < len(unique):
# End at the line before the next symbol starts.
end_line = unique[i + 1][0] - 1
else:
end_line = len(lines)
if end_line < start_line:
end_line = start_line
spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line))
return spans
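
One subtlety in mapping a regex match to a line number: `\s` also matches newlines, so under `re.MULTILINE` a pattern that begins with `^\s*` can start its match on a blank line above the declaration. The sketch below compares the line computed from the match start with the line computed from the captured name. Its pattern is simplified and only assumed to resemble the ones in `_REGEX_PATTERNS`, and `line_of` is a hypothetical helper.

```python
import re

# Simplified, illustrative pattern; not copied from the commit.
PATTERN = re.compile(r"^\s*function\s+(\w+)\s*\(", re.MULTILINE)

# Line 1 is a statement, line 2 is blank, line 3 holds the declaration.
SOURCE = "const a = 1;\n\nfunction foo() {}\n"


def line_of(content: str, offset: int) -> int:
    """Convert a character offset into a 1-based line number."""
    return content.count("\n", 0, offset) + 1


match = PATTERN.search(SOURCE)
assert match is not None

# Offset of the whole match: may sit on a blank line before the declaration.
line_from_match_start = line_of(SOURCE, match.start())
# Offset of the captured name: always on the declaration line for this pattern.
line_from_name = line_of(SOURCE, match.start(1))
```

With this input the two values differ by one line, which is why anchoring on the captured group is the safer choice when the pattern allows leading whitespace.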
# ---------------------------------------------------------------------------
# Dispatch by file extension
# ---------------------------------------------------------------------------
_EXTENSION_LANGUAGE: dict[str, str] = {
".py": "py",
".ts": "ts",
".tsx": "ts",
".js": "js",
".jsx": "js",
".mjs": "js",
".cjs": "js",
".go": "go",
".rs": "rs",
".java": "java",
}
_DEFAULT_EXTRACTOR = AstSymbolExtractor()
_REGEX_EXTRACTOR = RegexSymbolExtractor()
def language_for_extension(ext: str) -> str:
"""Return the language key for a file extension (with or without leading dot).
Returns "" for unsupported extensions.
"""
if not ext.startswith("."):
ext = "." + ext
return _EXTENSION_LANGUAGE.get(ext.lower(), "")
def get_extractor(language: str) -> SymbolExtractor | None:
"""Return the appropriate extractor for `language`, or None if unsupported."""
if language == "py":
return _DEFAULT_EXTRACTOR
if language in _REGEX_PATTERNS:
return _REGEX_EXTRACTOR
return None
def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]:
"""Read a file and return (symbols, language).
Returns ([], "") if the extension is unsupported, and ([], language) if the
file cannot be read. Never raises; callers use this for fallback routing.
"""
ext = path.suffix.lower()
language = language_for_extension(ext)
if not language:
return [], ""
try:
content = path.read_text(encoding="utf-8", errors="replace")
except OSError as e:
logger.debug("read failed for %s: %s", path, e)
return [], language
extractor = get_extractor(language)
if extractor is None:
return [], language
return extractor.extract_symbols(content, language), language
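
The `_is_method` docstring names `ast.NodeVisitor` with parent tracking as the upgrade path. A minimal sketch of that idea follows, using only the stdlib `ast` module. `ParentTrackingVisitor` and `classify` are hypothetical names, and the real extractor may end up structured differently.

```python
from __future__ import annotations

import ast


class ParentTrackingVisitor(ast.NodeVisitor):
    """Classify defs as class/method/function from the nearest enclosing scope."""

    def __init__(self) -> None:
        # Stack of enclosing ClassDef/FunctionDef nodes (other nodes are skipped).
        self._stack: list[ast.AST] = []
        self.found: list[tuple[str, str]] = []  # (name, kind)

    def _visit_scoped(self, node, kind: str) -> None:
        self.found.append((node.name, kind))
        self._stack.append(node)
        self.generic_visit(node)  # descend with this node as the enclosing scope
        self._stack.pop()

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self._visit_scoped(node, "class")

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        parent = self._stack[-1] if self._stack else None
        # A def is a method only when its nearest enclosing scope is a class.
        kind = "method" if isinstance(parent, ast.ClassDef) else "function"
        self._visit_scoped(node, kind)

    # Async defs follow the same rule.
    visit_AsyncFunctionDef = visit_FunctionDef


def classify(source: str) -> list[tuple[str, str]]:
    visitor = ParentTrackingVisitor()
    visitor.visit(ast.parse(source))
    return visitor.found


SOURCE = (
    "class A:\n"
    "    def m(self):\n"
    "        def inner():\n"
    "            pass\n"
    "\n"
    "def top():\n"
    "    pass\n"
)
RESULT = classify(SOURCE)
```

Unlike the `col_offset` heuristic, this labels `inner` as a function because its nearest enclosing scope is `m`, not the class.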


@ -0,0 +1,531 @@
"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4).
Per plan U4 Execution note: characterization-first. Verify that existing
REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning
(no regression), then add PLAN_EXEC wiring tests.
KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
from agentkit.core.phase import PhaseState
from agentkit.tools.advance_phase import AdvancePhaseTool
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def app_with_chat():
"""Create a FastAPI app with Chat routes and mocked dependencies."""
from fastapi import FastAPI
from agentkit.server.routes.chat import router
app = FastAPI()
app.include_router(router, prefix="/api/v1")
from agentkit.session.manager import SessionManager
from agentkit.session.store import InMemorySessionStore
app.state.session_manager = SessionManager(store=InMemorySessionStore())
app.state.llm_gateway = MagicMock()
app.state.agent_pool = MagicMock()
app.state.server_config = MagicMock()
app.state.server_config.api_key = None
app.state.server_config.plan_exec = {}
return app
@pytest.fixture
def client(app_with_chat):
return TestClient(app_with_chat)
def _make_routing(
execution_mode: ExecutionMode = ExecutionMode.REACT,
tools: list | None = None,
) -> SkillRoutingResult:
"""Build a minimal SkillRoutingResult for testing."""
return SkillRoutingResult(
execution_mode=execution_mode,
tools=tools or [],
clean_content="test message",
model="default",
agent_name="test-agent",
system_prompt=None,
skill_name=None,
)
def _make_websocket_mock(app) -> MagicMock:
"""Build a mock WebSocket with app.state and async send_json."""
ws = MagicMock()
ws.app = app
ws.send_json = AsyncMock()
return ws
def _make_agent_mock() -> MagicMock:
"""Build a mock Agent with _tool_registry and _react_engine."""
agent = MagicMock()
agent.name = "test-agent"
agent._tool_registry = MagicMock()
agent._tool_registry.list_tools.return_value = []
agent._system_prompt = None
# _react_engine is None to force the code path that creates a new engine
agent._react_engine = None
agent.get_model.return_value = "default"
return agent
def _make_session_manager_mock() -> MagicMock:
"""Build a mock SessionManager with async methods."""
sm = MagicMock()
# get_session returns a mock session with agent_name="test-agent"
session = MagicMock()
session.agent_name = "test-agent"
session.status = "active"
sm.get_session = AsyncMock(return_value=session)
sm.get_chat_messages = AsyncMock(return_value=[])
sm.append_message = AsyncMock()
return sm
def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None:
"""Wire up app.state so _handle_chat_message finds the right routing."""
app.state.agent_pool.get_agent.return_value = agent
app.state.request_preprocessor = MagicMock()
app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing)
# ---------------------------------------------------------------------------
# REST — PLAN_EXEC raises 501 (KTD4)
# ---------------------------------------------------------------------------
class TestRestPlanExec501:
def test_rest_plan_exec_returns_501(self, client):
"""REST send_message with execution_mode=plan_exec → 501."""
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
session_id = create_resp.json()["session_id"]
msg_resp = client.post(
f"/api/v1/chat/sessions/{session_id}/messages",
json={"content": "Hello", "execution_mode": "plan_exec"},
)
assert msg_resp.status_code == 501
assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"]
def test_rest_react_mode_still_works(self, client):
"""REST send_message without execution_mode does not return 501."""
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
session_id = create_resp.json()["session_id"]
# No execution_mode field → should NOT trigger 501.
msg_resp = client.post(
f"/api/v1/chat/sessions/{session_id}/messages",
json={"content": "Hello"},
)
assert msg_resp.status_code != 501
# ---------------------------------------------------------------------------
# Characterization — REWOO still falls back to REACT (no regression)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_rewoo_still_falls_back_to_react_without_phase_policy(app_with_chat):
"""Characterization: REWOO via WebSocket → no phase_policy (falls back to REACT)."""
from agentkit.server.routes import chat as chat_module
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.REWOO)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
ws = _make_websocket_mock(app_with_chat)
captured_engine_kwargs: dict = {}
class _StubEngine:
def __init__(self, **kwargs):
captured_engine_kwargs.update(kwargs)
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = None
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
return
yield # async generator marker
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
# REWOO should NOT build a phase_policy
assert captured_engine_kwargs.get("phase_policy") is None
# ---------------------------------------------------------------------------
# Happy path — PLAN_EXEC builds phase policy + registers AdvancePhaseTool
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plan_exec_builds_phase_policy_and_registers_advance_phase_tool(
app_with_chat,
):
"""PLAN_EXEC via WebSocket → engine with phase_policy, AdvancePhaseTool registered."""
from agentkit.server.routes import chat as chat_module
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "test"}])
ws = _make_websocket_mock(app_with_chat)
captured_engine: list = []
captured_tools: list = []
class _StubEngine:
def __init__(self, **kwargs):
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = (
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
)
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
captured_tools.extend(kwargs.get("tools", []))
captured_engine.append(self)
return
yield
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
assert len(captured_engine) == 1
engine = captured_engine[0]
assert engine._phase_policy is not None
assert engine._current_phase == PhaseState.PLANNING
# AdvancePhaseTool was registered in the tools list
assert any(isinstance(t, AdvancePhaseTool) for t in captured_tools)
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plan_exec_empty_config_uses_default_policy(app_with_chat):
"""Edge: plan_exec config absent (empty dict) → default_policy() used."""
from agentkit.server.routes import chat as chat_module
app_with_chat.state.server_config.plan_exec = {}
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
ws = _make_websocket_mock(app_with_chat)
captured_policy: list = []
class _StubEngine:
def __init__(self, **kwargs):
captured_policy.append(kwargs.get("phase_policy"))
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = (
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
)
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
return
yield
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
assert len(captured_policy) == 1
assert captured_policy[0] is not None
# Default policy: PLANNING allows search but not write_file
assert "search" in captured_policy[0].whitelist[PhaseState.PLANNING]
assert "write_file" not in captured_policy[0].whitelist[PhaseState.PLANNING]
@pytest.mark.asyncio
async def test_plan_exec_disabled_falls_back_to_react(app_with_chat):
"""Edge: plan_exec.enabled=False → falls back to REACT (no phase_policy)."""
from agentkit.server.routes import chat as chat_module
app_with_chat.state.server_config.plan_exec = {"enabled": False}
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
ws = _make_websocket_mock(app_with_chat)
captured_engine_kwargs: dict = {}
class _StubEngine:
def __init__(self, **kwargs):
captured_engine_kwargs.update(kwargs)
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = None
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
return
yield
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
# enabled=False → no phase_policy (falls back to REACT)
assert captured_engine_kwargs.get("phase_policy") is None
# ---------------------------------------------------------------------------
# Error path
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plan_exec_bad_config_sends_error_and_returns(app_with_chat):
"""Error: phase policy construction fails → error event sent, returns early."""
from agentkit.server.routes import chat as chat_module
app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"}
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
ws = _make_websocket_mock(app_with_chat)
engine_constructor_called = []
class _StubEngine:
def __init__(self, **kwargs):
engine_constructor_called.append(kwargs)
async def execute_stream(self, **kwargs):
yield
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
error_messages = [m for m in sent_messages if m.get("type") == "error"]
assert len(error_messages) == 1
assert "phase policy error" in error_messages[0]["data"]["message"]
# Engine constructor was NOT called (returned early)
assert len(engine_constructor_called) == 0
# ---------------------------------------------------------------------------
# phase_changed event emission
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_phase_changed_event_emitted_on_transition(app_with_chat):
"""phase_changed event sent when current_phase changes during execute_stream."""
from agentkit.server.routes import chat as chat_module
app_with_chat.state.server_config.plan_exec = {}
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}])
ws = _make_websocket_mock(app_with_chat)
class _StubEngine:
def __init__(self, **kwargs):
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = PhaseState.PLANNING
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
from agentkit.core.react import ReActEvent
yield ReActEvent(
event_type="tool_call",
step=1,
data={"tool": "search", "output": "ok"},
)
# Simulate phase transition (as if AdvancePhaseTool was called)
self._current_phase = PhaseState.BUILDING
yield ReActEvent(
event_type="final_answer",
step=2,
data={"output": "done"},
)
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="go",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
assert len(phase_events) == 1
assert phase_events[0]["data"]["phase"] == "building"
assert phase_events[0]["data"]["previous"] == "planning"
@pytest.mark.asyncio
async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat):
"""Characterization: REACT mode → no phase_changed events."""
from agentkit.server.routes import chat as chat_module
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.REACT)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}])
ws = _make_websocket_mock(app_with_chat)
class _StubEngine:
def __init__(self, **kwargs):
self._phase_policy = None
self._current_phase = None
@property
def current_phase(self):
return None
def reset(self):
pass
async def execute_stream(self, **kwargs):
from agentkit.core.react import ReActEvent
yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"})
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="hi",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
assert len(phase_events) == 0
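
Several stubs above end `execute_stream` with `return` followed by an unreachable `yield`. The bare `yield` is what makes the function an async generator, and the early `return` makes it finish without producing any events. A small self-contained sketch of that pattern follows. `StubEngine` and `collect` are illustrative names that only mirror the `_StubEngine` classes in these tests.

```python
import asyncio
import inspect


class StubEngine:
    async def execute_stream(self, **kwargs):
        return
        yield  # unreachable, but its presence makes this an async generator


async def collect() -> list:
    # `async for` works because execute_stream() returns an async generator,
    # even though that generator never yields.
    return [event async for event in StubEngine().execute_stream(tools=[])]


IS_ASYNC_GEN = inspect.isasyncgenfunction(StubEngine.execute_stream)
EVENTS = asyncio.run(collect())
```

Without the `yield`, `execute_stream` would be a plain coroutine function and iterating it with `async for` would raise `TypeError`.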


@ -0,0 +1,348 @@
"""Unit tests for PhasePolicy + PhaseState (G6 core, R24/R25/R26).
Covers:
- PhaseState enum (next_of, from_string)
- default_policy() KTD5 whitelist
- PhasePolicy.is_tool_allowed / is_bash_command_allowed
- policy_from_config parsing (R26 config-driven)
- ServerConfig.plan_exec integration
"""
from __future__ import annotations
import re
import pytest
from agentkit.core.phase import (
WILDCARD,
PhasePolicy,
PhaseState,
default_policy,
policy_from_config,
)
from agentkit.server.config import ServerConfig
# ---------------------------------------------------------------------------
# PhaseState enum
# ---------------------------------------------------------------------------
class TestPhaseState:
def test_values(self):
assert PhaseState.PLANNING.value == "planning"
assert PhaseState.BUILDING.value == "building"
assert PhaseState.VERIFICATION.value == "verification"
assert PhaseState.DELIVERY.value == "delivery"
def test_next_of(self):
assert PhaseState.next_of(PhaseState.PLANNING) == PhaseState.BUILDING
assert PhaseState.next_of(PhaseState.BUILDING) == PhaseState.VERIFICATION
assert PhaseState.next_of(PhaseState.VERIFICATION) == PhaseState.DELIVERY
assert PhaseState.next_of(PhaseState.DELIVERY) is None
def test_from_string_case_insensitive(self):
assert PhaseState.from_string("planning") == PhaseState.PLANNING
assert PhaseState.from_string("PLANNING") == PhaseState.PLANNING
assert PhaseState.from_string("Building") == PhaseState.BUILDING
def test_from_string_invalid_raises(self):
with pytest.raises(ValueError, match="Invalid phase name"):
PhaseState.from_string("unknown")
with pytest.raises(ValueError, match="Valid:"):
PhaseState.from_string("exploration")
# ---------------------------------------------------------------------------
# default_policy() — KTD5 whitelist
# ---------------------------------------------------------------------------
class TestDefaultPolicy:
def test_has_all_four_phases(self):
policy = default_policy()
assert PhaseState.PLANNING in policy.whitelist
assert PhaseState.BUILDING in policy.whitelist
assert PhaseState.VERIFICATION in policy.whitelist
assert PhaseState.DELIVERY in policy.whitelist
def test_planning_whitelist_matches_r24(self):
policy = default_policy()
allowed = policy.whitelist[PhaseState.PLANNING]
assert "search" in allowed
assert "read_file" in allowed
assert "shell" in allowed
assert "tool_search" in allowed
# Planning must NOT allow write_file.
assert "write_file" not in allowed
def test_building_whitelist_includes_write_file(self):
policy = default_policy()
allowed = policy.whitelist[PhaseState.BUILDING]
assert "write_file" in allowed
assert "shell" in allowed
assert "read_file" in allowed
def test_verification_whitelist_excludes_write(self):
policy = default_policy()
allowed = policy.whitelist[PhaseState.VERIFICATION]
assert "shell" in allowed
assert "read_file" in allowed
assert "write_file" not in allowed
def test_delivery_wildcard(self):
policy = default_policy()
allowed = policy.whitelist[PhaseState.DELIVERY]
assert WILDCARD in allowed
def test_start_phase_default_planning(self):
assert default_policy().start_phase == PhaseState.PLANNING
def test_auto_advance_default_none(self):
# KTD6: manual by default.
assert default_policy().auto_advance_after_steps is None
def test_bash_filter_blocks_rm_in_planning(self):
policy = default_policy()
assert policy.is_bash_command_allowed("ls -la", PhaseState.PLANNING) is True
assert policy.is_bash_command_allowed("git status", PhaseState.PLANNING) is True
assert policy.is_bash_command_allowed("rm -rf /tmp/x", PhaseState.PLANNING) is False
assert policy.is_bash_command_allowed("echo x > file.txt", PhaseState.PLANNING) is False
def test_bash_filter_no_restriction_in_building(self):
policy = default_policy()
assert policy.is_bash_command_allowed("rm -rf build/", PhaseState.BUILDING) is True
assert policy.is_bash_command_allowed("echo x > out.log", PhaseState.BUILDING) is True
# ---------------------------------------------------------------------------
# PhasePolicy — is_tool_allowed
# ---------------------------------------------------------------------------
class TestIsToolAllowed:
def test_planning_allows_search(self):
policy = default_policy()
assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True
def test_planning_blocks_write_file(self):
policy = default_policy()
assert policy.is_tool_allowed("write_file", PhaseState.PLANNING) is False
def test_building_allows_write_file(self):
policy = default_policy()
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
def test_delivery_wildcard_allows_anything(self):
policy = default_policy()
assert policy.is_tool_allowed("any_random_tool", PhaseState.DELIVERY) is True
assert policy.is_tool_allowed("write_file", PhaseState.DELIVERY) is True
def test_unknown_phase_returns_false(self):
# ponytail: unknown phase → empty whitelist → no tool allowed.
# We can't construct an unknown PhaseState (enum), but if a phase
# were missing from the whitelist dict, is_tool_allowed should
# return False (defensive).
policy = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({"search"}),
PhaseState.BUILDING: frozenset({"write_file"}),
PhaseState.VERIFICATION: frozenset({"shell"}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
}
)
# BUILDING is in whitelist, so allowed checks work normally.
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
# Phase missing from whitelist would return False (defensive .get default).
# We test this by constructing a minimal policy.
minimal = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({WILDCARD}),
PhaseState.BUILDING: frozenset({WILDCARD}),
PhaseState.VERIFICATION: frozenset({WILDCARD}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
}
)
# VERIFICATION is in whitelist — wildcard allows all.
assert minimal.is_tool_allowed("anything", PhaseState.VERIFICATION) is True
# ---------------------------------------------------------------------------
# PhasePolicy — edge cases & errors
# ---------------------------------------------------------------------------
class TestPhasePolicyEdgeCases:
def test_empty_whitelist_raises(self):
# Fail-fast: an empty whitelist for a non-wildcard phase is a bug.
with pytest.raises(ValueError, match="empty whitelist"):
PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset(), # empty!
PhaseState.BUILDING: frozenset({WILDCARD}),
PhaseState.VERIFICATION: frozenset({WILDCARD}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
}
)
def test_wildcard_only_does_not_raise(self):
# Wildcard-only whitelist is valid (means "all tools").
policy = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({WILDCARD}),
PhaseState.BUILDING: frozenset({WILDCARD}),
PhaseState.VERIFICATION: frozenset({WILDCARD}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
}
)
assert policy.is_tool_allowed("anything", PhaseState.PLANNING) is True
def test_to_dict_serializable(self):
policy = default_policy()
d = policy.to_dict()
assert "whitelist" in d
assert "planning" in d["whitelist"]
assert "delivery" in d["whitelist"]
assert d["start_phase"] == "planning"
assert d["auto_advance_after_steps"] is None
def test_custom_bash_filter(self):
custom_filter = re.compile(r"\b(pip install|npm install)\b")
policy = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({"shell"}),
PhaseState.BUILDING: frozenset({"shell"}),
PhaseState.VERIFICATION: frozenset({"shell"}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
},
bash_command_filter={PhaseState.BUILDING: custom_filter},
)
assert policy.is_bash_command_allowed("npm install foo", PhaseState.BUILDING) is False
assert policy.is_bash_command_allowed("npm run build", PhaseState.BUILDING) is True
# ---------------------------------------------------------------------------
# policy_from_config — R26 (config-driven)
# ---------------------------------------------------------------------------
class TestPolicyFromConfig:
def test_empty_config_returns_none(self):
assert policy_from_config({}) is None
def test_enabled_false_returns_none(self):
# Opt-out — explicit `enabled: false` disables policy.
result = policy_from_config({"enabled": False})
assert result is None
def test_enabled_default_true_when_section_present(self):
# When section is present but `enabled` is missing, default is True.
result = policy_from_config({"auto_advance_after_steps": 3})
assert result is not None
assert result.auto_advance_after_steps == 3
def test_auto_advance_after_steps(self):
policy = policy_from_config({"enabled": True, "auto_advance_after_steps": 5})
assert policy is not None
assert policy.auto_advance_after_steps == 5
def test_start_phase_custom(self):
policy = policy_from_config({"enabled": True, "start_phase": "building"})
assert policy is not None
assert policy.start_phase == PhaseState.BUILDING
def test_start_phase_invalid_raises(self):
with pytest.raises(ValueError, match="Invalid phase name"):
policy_from_config({"enabled": True, "start_phase": "unknown"})
def test_whitelist_override_merges_with_default(self):
policy = policy_from_config(
{
"enabled": True,
"whitelist_override": {
"planning": ["search", "read_file"], # removes shell from default
},
}
)
assert policy is not None
# Override wins — shell should be removed from planning.
assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True
assert policy.is_tool_allowed("read_file", PhaseState.PLANNING) is True
assert policy.is_tool_allowed("shell", PhaseState.PLANNING) is False
# Other phases unchanged.
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
def test_whitelist_override_invalid_phase_raises(self):
with pytest.raises(ValueError, match="Invalid phase name"):
policy_from_config(
{
"enabled": True,
"whitelist_override": {"unknown_phase": ["tool"]},
}
)
def test_whitelist_override_non_list_raises(self):
with pytest.raises(ValueError, match="must be a list"):
policy_from_config(
{
"enabled": True,
"whitelist_override": {"planning": "not a list"},
}
)
def test_to_dict_round_trip_via_default(self):
# Sanity: default policy serializes to a dict with expected keys.
policy = default_policy()
d = policy.to_dict()
assert set(d["whitelist"].keys()) == {
"planning",
"building",
"verification",
"delivery",
}
# ---------------------------------------------------------------------------
# ServerConfig.plan_exec integration (R26)
# ---------------------------------------------------------------------------
class TestServerConfigPlanExec:
def test_default_plan_exec_empty(self):
config = ServerConfig.from_dict({})
assert config.plan_exec == {}
def test_plan_exec_loaded_from_dict(self):
config = ServerConfig.from_dict(
{
"plan_exec": {
"enabled": True,
"auto_advance_after_steps": 5,
}
}
)
assert config.plan_exec == {"enabled": True, "auto_advance_after_steps": 5}
def test_plan_exec_empty_dict_default(self):
config = ServerConfig.from_dict({"plan_exec": {}})
assert config.plan_exec == {}
def test_plan_exec_resolved_to_policy(self):
# Wire the config dict through policy_from_config to verify integration.
config = ServerConfig.from_dict(
{
"plan_exec": {
"enabled": True,
"auto_advance_after_steps": 3,
}
}
)
policy = policy_from_config(config.plan_exec)
assert policy is not None
assert policy.auto_advance_after_steps == 3
def test_plan_exec_disabled_via_config(self):
config = ServerConfig.from_dict({"plan_exec": {"enabled": False}})
policy = policy_from_config(config.plan_exec)
assert policy is None
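
The config tests above pin down a small contract: an empty section or `enabled: false` yields no policy, a present section defaults to enabled, and a `whitelist_override` entry replaces the default set for that phase only. A minimal, self-contained sketch of that contract follows. It is not the `agentkit.core.phase` implementation: the default whitelist contents, the helper `_parse_phase`, and the exact error wording are assumptions made for illustration.

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Any, Mapping, Optional


class PhaseState(Enum):
    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"


WILDCARD = "*"

# Hypothetical defaults: the real default whitelist is not shown in this diff.
_DEFAULT_WHITELIST = {
    PhaseState.PLANNING: frozenset({"search", "read_file", "shell"}),
    PhaseState.BUILDING: frozenset({"write_file", "shell", "read_file"}),
    PhaseState.VERIFICATION: frozenset({"shell", "read_file"}),
    PhaseState.DELIVERY: frozenset({WILDCARD}),
}


@dataclass(frozen=True)
class PhasePolicy:
    whitelist: Mapping[PhaseState, frozenset]
    start_phase: PhaseState = PhaseState.PLANNING
    auto_advance_after_steps: Optional[int] = None

    def is_tool_allowed(self, tool: str, phase: PhaseState) -> bool:
        allowed = self.whitelist.get(phase, frozenset())
        return WILDCARD in allowed or tool in allowed


def _parse_phase(name: str) -> PhaseState:
    # Hypothetical helper: map a config string to a PhaseState.
    try:
        return PhaseState(name)
    except ValueError:
        raise ValueError(f"Invalid phase name: {name!r}") from None


def policy_from_config(cfg: Mapping[str, Any]) -> Optional[PhasePolicy]:
    # Absent section or explicit opt-out: no policy at all.
    if not cfg or not cfg.get("enabled", True):
        return None
    whitelist = dict(_DEFAULT_WHITELIST)
    for phase_name, tools in (cfg.get("whitelist_override") or {}).items():
        phase = _parse_phase(phase_name)
        if not isinstance(tools, list):
            raise ValueError(f"whitelist_override[{phase_name!r}] must be a list")
        # The override replaces that phase's set; other phases keep defaults.
        whitelist[phase] = frozenset(tools)
    return PhasePolicy(
        whitelist=whitelist,
        start_phase=_parse_phase(cfg.get("start_phase", "planning")),
        auto_advance_after_steps=cfg.get("auto_advance_after_steps"),
    )
```

Replacing the whole set per phase, rather than unioning it with the default, is what lets an override remove `shell` from planning, which is the behavior `test_whitelist_override_merges_with_default` asserts.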


@ -0,0 +1,339 @@
"""Unit tests for ReActEngine phase enforcement (G6 wiring, R24).
Per plan U3 Execution note: characterization-first verify that
`ReActEngine(phase_policy=None)` behaves identically to pre-change (no
enforcement, no advance_phase tool, no _current_phase mutation). Then add
enforcement tests.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.phase import PhasePolicy, PhaseState, default_policy
from agentkit.core.react import ReActEngine
from agentkit.tools.advance_phase import AdvancePhaseTool
# ---------------------------------------------------------------------------
# Characterization — phase_policy=None preserves existing behavior
# ---------------------------------------------------------------------------
class TestCharacterizationNoPolicy:
"""When phase_policy=None, no enforcement happens and behavior matches
pre-Wave-3."""
def test_init_without_phase_policy(self):
# Minimal stub LLM gateway — we're only testing constructor.
gateway = MagicMock()
engine = ReActEngine(llm_gateway=gateway)
assert engine._phase_policy is None
assert engine._current_phase is None
assert engine._steps_in_phase == 0
assert engine.current_phase is None
@pytest.mark.asyncio
async def test_execute_tool_dispatches_without_phase_check(self):
"""Tool dispatch proceeds normally when no policy set."""
gateway = MagicMock()
engine = ReActEngine(llm_gateway=gateway)
# `name` is a Mock constructor argument (it only labels the repr), so
# MagicMock(name="x") does not make mock.name == "x". Patch _find_tool
# directly to bypass the name lookup.
fake_tool = MagicMock()
fake_tool.safe_execute = AsyncMock(return_value={"output": "ok"})
fake_tool.input_schema = None
engine._find_tool = lambda name, tools: fake_tool
result = await engine._execute_tool("any_tool", {"x": 1}, [fake_tool])
assert result == {"output": "ok"}
fake_tool.safe_execute.assert_awaited_once_with(x=1)
@pytest.mark.asyncio
async def test_advance_phase_returns_none_without_policy(self):
gateway = MagicMock()
engine = ReActEngine(llm_gateway=gateway)
assert engine.advance_phase() is None
def test_reset_does_not_touch_phase_state_when_no_policy(self):
gateway = MagicMock()
engine = ReActEngine(llm_gateway=gateway)
engine.reset()
assert engine._current_phase is None
# ---------------------------------------------------------------------------
# Initialization with phase_policy
# ---------------------------------------------------------------------------
class TestPhasePolicyInitialization:
def test_phase_policy_set_initializes_current_phase(self):
gateway = MagicMock()
engine = ReActEngine(
llm_gateway=gateway,
phase_policy=default_policy(),
)
assert engine._phase_policy is not None
assert engine._current_phase == PhaseState.PLANNING
assert engine._steps_in_phase == 0
def test_reset_resets_phase_to_start(self):
gateway = MagicMock()
engine = ReActEngine(
llm_gateway=gateway,
phase_policy=default_policy(),
)
# Manually move phase forward (simulating execute progress).
engine.advance_phase() # PLANNING → BUILDING
assert engine._current_phase == PhaseState.BUILDING
engine._steps_in_phase = 5
engine.reset()
assert engine._current_phase == PhaseState.PLANNING
assert engine._steps_in_phase == 0
# ---------------------------------------------------------------------------
# advance_phase() transitions
# ---------------------------------------------------------------------------
class TestAdvancePhase:
@pytest.fixture
def engine(self):
return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
def test_planning_to_building(self, engine):
new_phase = engine.advance_phase()
assert new_phase == PhaseState.BUILDING
assert engine.current_phase == PhaseState.BUILDING
assert engine._steps_in_phase == 0 # counter reset on transition
def test_building_to_verification(self, engine):
engine.advance_phase() # → BUILDING
new_phase = engine.advance_phase()
assert new_phase == PhaseState.VERIFICATION
assert engine.current_phase == PhaseState.VERIFICATION
def test_verification_to_delivery(self, engine):
engine.advance_phase() # → BUILDING
engine.advance_phase() # → VERIFICATION
new_phase = engine.advance_phase()
assert new_phase == PhaseState.DELIVERY
assert engine.current_phase == PhaseState.DELIVERY
def test_delivery_returns_none(self, engine):
engine.advance_phase() # → BUILDING
engine.advance_phase() # → VERIFICATION
engine.advance_phase() # → DELIVERY
result = engine.advance_phase()
assert result is None
assert engine.current_phase == PhaseState.DELIVERY
# ---------------------------------------------------------------------------
# _check_phase_permission — whitelist enforcement
# ---------------------------------------------------------------------------
class TestPhasePermission:
@pytest.fixture
def engine(self):
return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
def test_search_allowed_in_planning(self, engine):
assert engine._check_phase_permission("search", {}) is None
def test_write_file_blocked_in_planning(self, engine):
result = engine._check_phase_permission("write_file", {})
assert result is not None
assert result["error"] == "phase_violation"
assert "write_file" in result["message"]
assert result["current_phase"] == "planning"
def test_write_file_allowed_in_building(self, engine):
engine.advance_phase() # → BUILDING
assert engine._check_phase_permission("write_file", {}) is None
def test_any_tool_allowed_in_delivery(self, engine):
engine.advance_phase() # → BUILDING
engine.advance_phase() # → VERIFICATION
engine.advance_phase() # → DELIVERY
assert engine._check_phase_permission("literally_anything", {}) is None
def test_bash_command_filter_blocks_rm_in_planning(self, engine):
result = engine._check_phase_permission("shell", {"command": "rm -rf /tmp"})
assert result is not None
assert result["error"] == "phase_violation"
assert "rm" in result["message"] or "Bash command" in result["message"]
def test_bash_command_filter_allows_safe_in_planning(self, engine):
# `ls` and `git status` are not blocked.
assert engine._check_phase_permission("shell", {"command": "ls -la"}) is None
assert engine._check_phase_permission("shell", {"command": "git status"}) is None
def test_bash_command_filter_no_restriction_in_building(self, engine):
engine.advance_phase() # → BUILDING
# `rm` is allowed in building phase.
assert engine._check_phase_permission("shell", {"command": "rm -rf build/"}) is None
# ---------------------------------------------------------------------------
# _execute_tool integration — phase enforcement actually blocks dispatch
# ---------------------------------------------------------------------------
class TestExecuteToolPhaseEnforcement:
@pytest.fixture
def engine_with_tools(self):
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
# Two fake tools: one allowed in PLANNING (search), one not (write_file).
# `name` is a Mock constructor argument (it only labels the repr), so
# MagicMock(name=...) does not set mock.name; patch _find_tool with a
# dict-based lookup instead.
search_tool = MagicMock()
search_tool.input_schema = None
search_tool.safe_execute = AsyncMock(return_value={"results": []})
write_tool = MagicMock()
write_tool.input_schema = None
write_tool.safe_execute = AsyncMock(return_value={"written": True})
tools_by_name = {"search": search_tool, "write_file": write_tool}
engine._find_tool = lambda name, tools: tools_by_name.get(name)
return engine, [search_tool, write_tool]
@pytest.mark.asyncio
async def test_blocked_tool_returns_phase_violation_and_skips_dispatch(self, engine_with_tools):
engine, tools = engine_with_tools
# write_file in PLANNING should be blocked — write_tool.safe_execute
# should NEVER be called.
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
assert result["error"] == "phase_violation"
assert result["current_phase"] == "planning"
write_tool = tools[1]
write_tool.safe_execute.assert_not_called()
@pytest.mark.asyncio
async def test_allowed_tool_dispatches_normally(self, engine_with_tools):
engine, tools = engine_with_tools
result = await engine._execute_tool("search", {"query": "foo"}, tools)
assert result == {"results": []}
search_tool = tools[0]
search_tool.safe_execute.assert_awaited_once_with(query="foo")
@pytest.mark.asyncio
async def test_after_advance_phase_blocked_tool_now_dispatches(self, engine_with_tools):
engine, tools = engine_with_tools
# First: write_file blocked in PLANNING.
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
assert result["error"] == "phase_violation"
# Advance to BUILDING.
engine.advance_phase()
# Now: write_file allowed.
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
assert result == {"written": True}
# ---------------------------------------------------------------------------
# Auto-advance safety net (KTD6)
# ---------------------------------------------------------------------------
class TestAutoAdvance:
def test_auto_advance_after_threshold(self):
# Custom policy with auto-advance after 2 steps.
policy = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({"search"}),
PhaseState.BUILDING: frozenset({"write_file"}),
PhaseState.VERIFICATION: frozenset({"shell"}),
PhaseState.DELIVERY: frozenset({"*"}),
},
auto_advance_after_steps=2,
)
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=policy)
assert engine.current_phase == PhaseState.PLANNING
# Step 1: counter goes to 1, no advance yet.
engine._steps_in_phase += 1
assert engine._maybe_auto_advance() is False
assert engine.current_phase == PhaseState.PLANNING
# Step 2: counter hits 2, advance triggered.
engine._steps_in_phase += 1
assert engine._maybe_auto_advance() is True
assert engine.current_phase == PhaseState.BUILDING
assert engine._steps_in_phase == 0 # reset on advance
def test_auto_advance_none_default(self):
# default_policy has auto_advance_after_steps=None — no auto-advance.
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
engine._steps_in_phase = 100
assert engine._maybe_auto_advance() is False
assert engine.current_phase == PhaseState.PLANNING
# ---------------------------------------------------------------------------
# AdvancePhaseTool integration
# ---------------------------------------------------------------------------
class TestAdvancePhaseTool:
@pytest.mark.asyncio
async def test_advance_phase_tool_transitions_engine(self):
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
tool = AdvancePhaseTool(engine=engine)
result = await tool.execute()
assert result["is_error"] is False
assert result["current_phase"] == "building"
assert engine.current_phase == PhaseState.BUILDING
@pytest.mark.asyncio
async def test_advance_phase_tool_at_delivery_returns_error(self):
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
# Walk through all phases.
engine.advance_phase() # PLANNING → BUILDING
engine.advance_phase() # BUILDING → VERIFICATION
engine.advance_phase() # VERIFICATION → DELIVERY
tool = AdvancePhaseTool(engine=engine)
result = await tool.execute()
assert result["is_error"] is True
assert result["error"] == "already_at_final_phase"
assert result["current_phase"] == "delivery"
@pytest.mark.asyncio
async def test_advance_phase_tool_without_policy_returns_error(self):
engine = ReActEngine(llm_gateway=MagicMock()) # no policy
tool = AdvancePhaseTool(engine=engine)
result = await tool.execute()
assert result["is_error"] is True
assert result["error"] == "no_phase_policy"
def test_tool_schema_accepts_no_arguments(self):
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
tool = AdvancePhaseTool(engine=engine)
# input_schema has empty properties + additionalProperties:false —
# no arguments expected.
assert tool.input_schema["properties"] == {}
assert tool.input_schema["additionalProperties"] is False
def test_tool_bypasses_phase_check(self):
"""`advance_phase` is the LLM's escape hatch — must never be blocked."""
# The bypass itself lives in _execute_tool as a name check, so
# advance_phase dispatches even in PLANNING (where only
# search/read_file/shell/tool_search are allowed).
# Exercising that path directly would require mocking _find_tool, so
# this test only pins the precondition that makes the bypass necessary:
# advance_phase is not in any whitelist.
for phase, allowed in default_policy().whitelist.items():
assert "advance_phase" not in allowed, (
f"advance_phase must not be in {phase.value} whitelist"
)
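
Taken together, the enforcement tests describe a linear four-phase state machine with a per-phase tool gate, a name-based bypass for `advance_phase`, and an optional step-count safety net. The sketch below shows one way such a gate could be structured. It is a standalone illustration, not `ReActEngine`: the class `PhaseGate`, its method names, and the whitelist contents are all hypothetical.

```python
from __future__ import annotations

from enum import Enum
from typing import Optional


class Phase(Enum):
    PLANNING = "planning"
    BUILDING = "building"
    VERIFICATION = "verification"
    DELIVERY = "delivery"


_ORDER = list(Phase)  # Enum iteration follows definition order.

# Hypothetical whitelist, for illustration only.
_WHITELIST = {
    Phase.PLANNING: {"search", "read_file"},
    Phase.BUILDING: {"write_file", "read_file"},
    Phase.VERIFICATION: {"shell"},
    Phase.DELIVERY: {"*"},
}


class PhaseGate:
    def __init__(self, auto_advance_after_steps: Optional[int] = None) -> None:
        self.current = Phase.PLANNING
        self.steps_in_phase = 0
        self.auto_advance_after_steps = auto_advance_after_steps

    def advance(self) -> Optional[Phase]:
        """Move to the next phase; return None when already at the last one."""
        idx = _ORDER.index(self.current)
        if idx + 1 >= len(_ORDER):
            return None
        self.current = _ORDER[idx + 1]
        self.steps_in_phase = 0  # counter resets on every transition
        return self.current

    def check(self, tool_name: str) -> Optional[dict]:
        """Return None if the tool may run, else a structured violation."""
        if tool_name == "advance_phase":  # escape hatch is never blocked
            return None
        allowed = _WHITELIST[self.current]
        if "*" in allowed or tool_name in allowed:
            return None
        return {
            "error": "phase_violation",
            "message": f"Tool {tool_name!r} is not allowed in {self.current.value}",
            "current_phase": self.current.value,
        }

    def record_step(self) -> bool:
        """Count one step; auto-advance once the threshold is reached."""
        self.steps_in_phase += 1
        limit = self.auto_advance_after_steps
        if limit is not None and self.steps_in_phase >= limit:
            return self.advance() is not None
        return False
```

Returning a violation payload instead of raising keeps the block visible to the LLM as an ordinary tool result, which matches the `result["error"] == "phase_violation"` shape the tests assert.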


@ -0,0 +1,367 @@
"""Unit tests for ReadFileTool — G5 (R22, R23) + characterization baseline.
Per plan U1 Execution note: characterization-first assert that
`symbol=None` returns the full file content (matches pre-existing benchmark
`_FakeTool` shape) before adding symbol-extraction behavior.
"""
from __future__ import annotations
import textwrap
import pytest
from agentkit.tools.file_read import ReadFileTool
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
class TestReadFileToolSchema:
def test_name_is_read_file(self):
tool = ReadFileTool()
assert tool.name == "read_file"
def test_required_path(self):
tool = ReadFileTool()
assert "path" in tool.input_schema["required"]
assert "path" in tool.input_schema["properties"]
def test_optional_symbol_and_lines(self):
tool = ReadFileTool()
props = tool.input_schema["properties"]
assert "symbol" in props
assert "start_line" in props
assert "end_line" in props
# None of the optional fields should be in `required`.
required = set(tool.input_schema["required"])
assert required == {"path"}
def test_additional_properties_false(self):
# LLM tool-call schemas should reject unknown args (Wave 1 U3 pattern).
tool = ReadFileTool()
assert tool.input_schema.get("additionalProperties") is False
def test_tags_contain_io_and_read(self):
tool = ReadFileTool()
assert "io" in tool.tags
assert "read" in tool.tags
# ---------------------------------------------------------------------------
# Characterization — symbol=None returns full file
# ---------------------------------------------------------------------------
@pytest.fixture
def sample_py_file(tmp_path):
    path = tmp_path / "sample.py"
    path.write_text(
        textwrap.dedent('''
            """Sample module."""

            def my_func():
                return 42


            class MyClass:
                attr = 1

                def method_a(self):
                    return self.attr
        ''').lstrip(),
        encoding="utf-8",
    )
    return path
@pytest.fixture
def sample_ts_file(tmp_path):
    path = tmp_path / "sample.ts"
    path.write_text(
        textwrap.dedent('''
            export function renderComponent(): JSX.Element {
                return <div/>;
            }

            export class BaseService {
                abstract run(): void;
            }
        ''').lstrip(),
        encoding="utf-8",
    )
    return path
class TestCharacterizationFullFile:
"""symbol=None returns the whole file (matches _FakeTool baseline)."""
@pytest.mark.asyncio
async def test_full_file_returned_when_symbol_none(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file))
assert result["is_error"] is False
assert result["path"] == str(sample_py_file)
assert result["start_line"] == 1
assert result["end_line"] == result["total_lines"]
assert "def my_func" in result["content"]
assert "class MyClass" in result["content"]
assert result["content"].startswith('"""Sample module."""')
@pytest.mark.asyncio
async def test_full_file_includes_all_lines(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file))
assert result["total_lines"] >= 8
assert result["content"].count("\n") >= result["total_lines"] - 1
# ---------------------------------------------------------------------------
# Symbol slicing — happy paths
# ---------------------------------------------------------------------------
class TestSymbolSlicing:
@pytest.mark.asyncio
async def test_python_function(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="my_func")
assert result["is_error"] is False
assert result["symbol"] == "my_func"
assert result["symbol_kind"] == "function"
assert "def my_func" in result["content"]
assert "return 42" in result["content"]
# Should NOT include the class below.
assert "class MyClass" not in result["content"]
@pytest.mark.asyncio
async def test_python_class_includes_method(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="MyClass")
assert result["is_error"] is False
assert result["symbol"] == "MyClass"
assert result["symbol_kind"] == "class"
assert "class MyClass" in result["content"]
assert "def method_a" in result["content"] # method included
@pytest.mark.asyncio
async def test_python_method_directly(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="method_a")
assert result["is_error"] is False
assert result["symbol"] == "method_a"
assert result["symbol_kind"] == "method"
assert "def method_a" in result["content"]
@pytest.mark.asyncio
async def test_typescript_function(self, sample_ts_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_ts_file), symbol="renderComponent")
assert result["is_error"] is False
assert result["symbol"] == "renderComponent"
assert "renderComponent" in result["content"]
# Should not include the class below.
assert "BaseService" not in result["content"]
@pytest.mark.asyncio
async def test_typescript_class(self, sample_ts_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_ts_file), symbol="BaseService")
assert result["is_error"] is False
assert result["symbol"] == "BaseService"
assert result["symbol_kind"] == "class"
assert "BaseService" in result["content"]
# ---------------------------------------------------------------------------
# Symbol slicing — edge cases
# ---------------------------------------------------------------------------
class TestSymbolSlicingEdgeCases:
@pytest.mark.asyncio
async def test_symbol_not_found_lists_available(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="nonexistent")
assert result["is_error"] is False # soft error, not hard
assert result["content"] == ""
assert result["symbol"] == "nonexistent"
available = result["available_symbols"]
assert "my_func" in available
assert "MyClass" in available
assert "method_a" in available
assert "nonexistent" not in result["content"]
@pytest.mark.asyncio
async def test_unsupported_extension_returns_full_with_note(self, tmp_path):
path = tmp_path / "notes.md"
path.write_text("# Hello\nworld\n", encoding="utf-8")
tool = ReadFileTool()
result = await tool.execute(path=str(path), symbol="anything")
assert result["is_error"] is False
assert result["content"] == "# Hello\nworld\n"
assert "symbol extraction not supported" in result["note"]
assert ".md" in result["note"]
@pytest.mark.asyncio
async def test_empty_file(self, tmp_path):
path = tmp_path / "empty.py"
path.write_text("", encoding="utf-8")
tool = ReadFileTool()
result = await tool.execute(path=str(path))
assert result["is_error"] is False
assert result["content"] == ""
assert result["total_lines"] == 0
@pytest.mark.asyncio
async def test_file_with_no_symbols(self, tmp_path):
path = tmp_path / "data.py"
path.write_text("# just a comment\nPI = 3.14\n", encoding="utf-8")
tool = ReadFileTool()
result = await tool.execute(path=str(path), symbol="PI")
# PI is not a def/class — extractor finds no symbols; soft error lists available.
assert result["is_error"] is False
assert result["content"] == ""
assert result["available_symbols"] == []
# ---------------------------------------------------------------------------
# Error paths
# ---------------------------------------------------------------------------
class TestReadFileToolErrors:
@pytest.mark.asyncio
async def test_path_required(self):
tool = ReadFileTool()
result = await tool.execute()
assert result["is_error"] is True
assert "path" in result["error"].lower()
@pytest.mark.asyncio
async def test_path_empty_string(self):
tool = ReadFileTool()
result = await tool.execute(path="")
assert result["is_error"] is True
@pytest.mark.asyncio
async def test_file_not_found(self, tmp_path):
tool = ReadFileTool()
result = await tool.execute(path=str(tmp_path / "missing.py"))
assert result["is_error"] is True
assert "not found" in result["error"].lower()
@pytest.mark.asyncio
async def test_path_is_directory(self, tmp_path):
tool = ReadFileTool()
result = await tool.execute(path=str(tmp_path))
assert result["is_error"] is True
assert "directory" in result["error"].lower()
# ---------------------------------------------------------------------------
# Manual line slicing
# ---------------------------------------------------------------------------
class TestManualLineSlicing:
@pytest.mark.asyncio
async def test_start_and_end_line(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(
path=str(sample_py_file),
start_line=3,
end_line=5,
)
assert result["is_error"] is False
assert result["start_line"] == 3
assert result["end_line"] == 5
# Lines 3-5 of the sample file:
# line 3: "def my_func():"
# line 4: " return 42"
# line 5: "" (blank)
assert "def my_func" in result["content"]
assert "return 42" in result["content"]
@pytest.mark.asyncio
async def test_start_line_only_extends_to_eof(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), start_line=8)
assert result["is_error"] is False
assert result["start_line"] == 8
assert result["end_line"] == result["total_lines"]
@pytest.mark.asyncio
async def test_end_line_only_starts_at_one(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), end_line=2)
assert result["is_error"] is False
assert result["start_line"] == 1
assert result["end_line"] == 2
@pytest.mark.asyncio
async def test_invalid_start_line_zero(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), start_line=0)
assert result["is_error"] is True
assert "start_line" in result["error"].lower()
@pytest.mark.asyncio
async def test_end_before_start(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(
path=str(sample_py_file),
start_line=5,
end_line=3,
)
assert result["is_error"] is True
assert "end_line" in result["error"].lower()
@pytest.mark.asyncio
async def test_manual_lines_override_symbol(self, sample_py_file):
# Per plan U1 Approach: "start_line/end_line overrides symbol".
tool = ReadFileTool()
result = await tool.execute(
path=str(sample_py_file),
symbol="my_func",
start_line=1,
end_line=1,
)
assert result["is_error"] is False
# Manual slicing won — symbol field absent.
assert "symbol" not in result or result.get("symbol") is None
assert result["start_line"] == 1
assert result["end_line"] == 1
# ---------------------------------------------------------------------------
# Integration — tool registry discovery
# ---------------------------------------------------------------------------
class TestToolRegistryDiscovery:
def test_instantiable_without_args(self):
# Default constructor — matches the convention used by ToolRegistry
# to instantiate tools by class.
tool = ReadFileTool()
assert tool.name == "read_file"
def test_to_dict_serializable(self):
tool = ReadFileTool()
d = tool.to_dict()
assert d["name"] == "read_file"
assert "input_schema" in d
assert "output_schema" in d
assert d["tags"] == ["io", "file", "read"]
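
The symbol-slicing tests above imply a two-step flow for Python files: collect `(name, kind, start_line, end_line)` spans from the AST, then return only the matching line range, or an empty result listing the available symbols when nothing matches. A minimal sketch under those assumptions follows. `Span`, `extract_spans`, and `read_symbol` are hypothetical names for illustration, not the `ReadFileTool` or `AstSymbolExtractor` API.

```python
from __future__ import annotations

import ast
from typing import NamedTuple


class Span(NamedTuple):
    name: str
    kind: str  # "function", "method", or "class"
    start_line: int  # 1-indexed `def`/`class` line (decorators excluded)
    end_line: int  # 1-indexed, inclusive


def extract_spans(source: str) -> list[Span]:
    """Collect def/class spans; return [] instead of raising on bad syntax."""
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return []
    spans: list[Span] = []

    def visit(node: ast.AST, in_class: bool) -> None:
        for child in ast.iter_child_nodes(node):
            if isinstance(child, ast.ClassDef):
                spans.append(Span(child.name, "class", child.lineno, child.end_lineno))
                visit(child, in_class=True)
            elif isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                kind = "method" if in_class else "function"
                spans.append(Span(child.name, kind, child.lineno, child.end_lineno))
                visit(child, in_class=False)  # nested defs count as functions
            else:
                visit(child, in_class)

    visit(tree, in_class=False)
    return spans


def read_symbol(source: str, symbol: str) -> dict:
    """Return the symbol's lines, or a soft miss listing what is available."""
    spans = extract_spans(source)
    for span in spans:
        if span.name == symbol:
            lines = source.splitlines(keepends=True)
            body = "".join(lines[span.start_line - 1 : span.end_line])
            return {"symbol": symbol, "symbol_kind": span.kind, "content": body}
    return {
        "symbol": symbol,
        "content": "",
        "available_symbols": [s.name for s in spans],
    }
```

Listing the available symbols on a miss gives the caller enough information to retry with a valid name without re-reading the whole file.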


@ -0,0 +1,359 @@
"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor.
Covers R22 (file reading supports symbol/function granularity) and KTD1
(Python ast + language-aware regex, no tree-sitter dependency).
"""
from __future__ import annotations
import textwrap
import pytest
from agentkit.tools.symbol_extractor import (
AstSymbolExtractor,
RegexSymbolExtractor,
SymbolSpan,
extract_symbols_from_file,
get_extractor,
language_for_extension,
)
# ---------------------------------------------------------------------------
# language_for_extension
# ---------------------------------------------------------------------------
class TestLanguageForExtension:
def test_python_extensions(self):
assert language_for_extension("py") == "py"
assert language_for_extension(".py") == "py"
assert language_for_extension(".PY") == "py" # case-insensitive
def test_typescript_javascript(self):
assert language_for_extension(".ts") == "ts"
assert language_for_extension(".tsx") == "ts"
assert language_for_extension(".js") == "js"
assert language_for_extension(".jsx") == "js"
assert language_for_extension(".mjs") == "js"
assert language_for_extension(".cjs") == "js"
def test_go_rust_java(self):
assert language_for_extension(".go") == "go"
assert language_for_extension(".rs") == "rs"
assert language_for_extension(".java") == "java"
def test_unsupported_returns_empty(self):
assert language_for_extension(".md") == ""
assert language_for_extension(".txt") == ""
assert language_for_extension("") == ""
assert language_for_extension(".unknown") == ""
# ---------------------------------------------------------------------------
# AstSymbolExtractor — Python
# ---------------------------------------------------------------------------
class TestAstSymbolExtractor:
extractor = AstSymbolExtractor()
def test_unsupported_language_returns_empty(self):
assert self.extractor.extract_symbols("function foo() {}", "ts") == []
def test_syntax_error_returns_empty(self):
# Never raises — callers rely on this for fallback routing.
result = self.extractor.extract_symbols("def broken(:\n pass", "py")
assert result == []
def test_top_level_function(self):
content = "def my_func():\n return 42\n"
spans = self.extractor.extract_symbols(content, "py")
assert len(spans) == 1
span = spans[0]
assert span.name == "my_func"
assert span.kind == "function"
assert span.start_line == 1
assert span.end_line == 2
def test_async_function(self):
content = "async def fetch():\n return 1\n"
spans = self.extractor.extract_symbols(content, "py")
assert len(spans) == 1
assert spans[0].name == "fetch"
assert spans[0].kind == "function"
    def test_top_level_class(self):
        content = textwrap.dedent('''
            class MyClass:
                """docstring"""

                def method_a(self):
                    return 1

                async def method_b(self):
                    return 2
        ''').strip()
        spans = self.extractor.extract_symbols(content, "py")
        names = [s.name for s in spans]
        assert "MyClass" in names
        assert "method_a" in names
        assert "method_b" in names
        cls = next(s for s in spans if s.name == "MyClass")
        assert cls.kind == "class"
        assert cls.start_line == 1
        # Class body extends through the last method's end_lineno.
        assert cls.end_line >= 7
def test_methods_classified_as_methods(self):
content = textwrap.dedent('''
class Foo:
def bar(self):
pass
def top_level():
pass
''').strip()
spans = self.extractor.extract_symbols(content, "py")
by_name = {s.name: s for s in spans}
assert by_name["bar"].kind == "method"
assert by_name["top_level"].kind == "function"
def test_decorated_function(self):
content = textwrap.dedent('''
@staticmethod
def helper():
return "hi"
''').strip()
spans = self.extractor.extract_symbols(content, "py")
# Note: extractor uses node.lineno (def line) — decorators above are
# excluded by design (matches user-visible symbol start at `def`).
assert any(s.name == "helper" for s in spans)
span = next(s for s in spans if s.name == "helper")
assert span.start_line == 2 # the `def` line
def test_nested_function(self):
content = textwrap.dedent('''
def outer():
def inner():
return 1
return inner()
''').strip()
spans = self.extractor.extract_symbols(content, "py")
names = {s.name for s in spans}
assert "outer" in names
assert "inner" in names
def test_empty_file(self):
assert self.extractor.extract_symbols("", "py") == []
def test_no_symbols_in_docstring_only_file(self):
content = '"""just a docstring"""\n'
assert self.extractor.extract_symbols(content, "py") == []


# ---------------------------------------------------------------------------
# RegexSymbolExtractor: TS/JS/Go/Rust/Java
# ---------------------------------------------------------------------------
class TestRegexSymbolExtractor:
    extractor = RegexSymbolExtractor()

    def test_unsupported_language_returns_empty(self):
        assert self.extractor.extract_symbols("def foo(): pass", "py") == []
        assert self.extractor.extract_symbols("function foo() {}", "rb") == []

    def test_typescript_function_declaration(self):
        content = textwrap.dedent('''
            export function renderComponent(props: Props): JSX.Element {
                return <div/>;
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "ts")
        assert any(s.name == "renderComponent" and s.kind == "function" for s in spans)

    def test_typescript_async_function(self):
        content = "async function fetchData() {\n return await fetch();\n}\n"
        spans = self.extractor.extract_symbols(content, "ts")
        assert any(s.name == "fetchData" for s in spans)

    def test_typescript_arrow_function_const(self):
        content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n"
        spans = self.extractor.extract_symbols(content, "ts")
        assert any(s.name == "handleClick" for s in spans)

    def test_typescript_class(self):
        content = textwrap.dedent('''
            export abstract class BaseService {
                abstract run(): void;
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "ts")
        assert any(s.name == "BaseService" and s.kind == "class" for s in spans)

    def test_javascript_function(self):
        content = "function foo() {\n return 1;\n}\n"
        spans = self.extractor.extract_symbols(content, "js")
        assert any(s.name == "foo" for s in spans)

    def test_javascript_arrow_const(self):
        content = "const bar = () => 42;\n"
        spans = self.extractor.extract_symbols(content, "js")
        assert any(s.name == "bar" for s in spans)

    def test_go_function(self):
        content = textwrap.dedent('''
            package main
            func HandleRequest(w http.ResponseWriter, r *http.Request) {
                w.WriteHeader(200)
            }
            func (s *Server) Start() {
                // method
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "go")
        names = {s.name for s in spans}
        assert "HandleRequest" in names
        assert "Start" in names  # method receiver pattern

    def test_go_struct(self):
        content = "type Server struct {\n Addr string\n}\n"
        spans = self.extractor.extract_symbols(content, "go")
        assert any(s.name == "Server" and s.kind == "struct" for s in spans)

    def test_rust_function(self):
        content = textwrap.dedent('''
            pub fn process(input: &str) -> Result<usize, Error> {
                Ok(input.len())
            }
            async fn fetch() -> Bytes {
                unimplemented!()
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "rs")
        names = {s.name for s in spans}
        assert "process" in names
        assert "fetch" in names

    def test_rust_struct(self):
        content = "pub struct Config {\n pub path: String,\n}\n"
        spans = self.extractor.extract_symbols(content, "rs")
        assert any(s.name == "Config" and s.kind == "struct" for s in spans)

    def test_rust_impl(self):
        content = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
        spans = self.extractor.extract_symbols(content, "rs")
        assert any(s.name == "Config" and s.kind == "impl" for s in spans)

    def test_java_class(self):
        content = textwrap.dedent('''
            package com.example;
            public class UserService {
                public User findById(long id) {
                    return null;
                }
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "java")
        assert any(s.name == "UserService" and s.kind == "class" for s in spans)

    def test_java_method(self):
        content = "public User findById(long id) {\n return null;\n}\n"
        spans = self.extractor.extract_symbols(content, "java")
        assert any(s.name == "findById" and s.kind == "function" for s in spans)

    def test_end_line_extends_to_next_symbol(self):
        # First symbol's end_line is the line before the second symbol starts.
        content = textwrap.dedent('''
            function first() {
                return 1;
            }
            function second() {
                return 2;
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "js")
        spans.sort(key=lambda s: s.start_line)
        first = spans[0]
        second = spans[1]
        assert first.name == "first"
        assert second.name == "second"
        assert first.end_line == second.start_line - 1

    def test_last_symbol_end_line_is_eof(self):
        content = "function only() {\n return 1;\n}\n"
        spans = self.extractor.extract_symbols(content, "js")
        assert len(spans) == 1
        assert spans[0].end_line == len(content.splitlines())


# ---------------------------------------------------------------------------
# get_extractor + integration
# ---------------------------------------------------------------------------
class TestGetExtractor:
    def test_python_returns_ast_extractor(self):
        ext = get_extractor("py")
        assert ext is not None
        assert isinstance(ext, AstSymbolExtractor)

    def test_typescript_returns_regex_extractor(self):
        ext = get_extractor("ts")
        assert ext is not None
        assert isinstance(ext, RegexSymbolExtractor)

    def test_unsupported_returns_none(self):
        assert get_extractor("md") is None
        assert get_extractor("") is None
        assert get_extractor("unknown") is None


class TestExtractSymbolsFromFile:
    def test_python_file(self, tmp_path):
        path = tmp_path / "module.py"
        path.write_text("def hello():\n return 'world'\n", encoding="utf-8")
        spans, lang = extract_symbols_from_file(path)
        assert lang == "py"
        assert any(s.name == "hello" for s in spans)

    def test_unsupported_extension(self, tmp_path):
        path = tmp_path / "notes.md"
        path.write_text("# Hello\n", encoding="utf-8")
        spans, lang = extract_symbols_from_file(path)
        assert lang == ""
        assert spans == []

    def test_missing_file_returns_empty(self, tmp_path):
        path = tmp_path / "nonexistent.py"
        spans, lang = extract_symbols_from_file(path)
        # lang is detected from extension even if read fails.
        assert lang == "py"
        assert spans == []


# ---------------------------------------------------------------------------
# SymbolSpan dataclass
# ---------------------------------------------------------------------------
class TestSymbolSpan:
    def test_frozen_dataclass(self):
        span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
        assert span.name == "foo"
        with pytest.raises(Exception):
            # Assigning to a field of a frozen dataclass must raise.
            span.name = "bar"  # type: ignore[misc]

    def test_equality(self):
        a = SymbolSpan("foo", "function", 1, 3)
        b = SymbolSpan("foo", "function", 1, 3)
        assert a == b
        assert hash(a) == hash(b)