feat(agent): Wave 3 strategic coupling (G5/G6) #6
|
|
@ -51,6 +51,19 @@ fallback_chain:
|
||||||
max_retries: 1 # ReflexionEngine max_reflections override
|
max_retries: 1 # ReflexionEngine max_reflections override
|
||||||
emergency:
|
emergency:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
# G6/U2: PLAN_EXEC phase policy — SOLO four-stage state machine.
|
||||||
|
# When `enabled: true`, chat WebSocket PLAN_EXEC requests build a PhasePolicy
|
||||||
|
# (Planning → Building → Verification → Delivery) and enforce per-step tool
|
||||||
|
# whitelists (R24). Transitions are LLM-driven via AdvancePhaseTool; set
|
||||||
|
# `auto_advance_after_steps` to auto-advance as a safety net (KTD6).
|
||||||
|
# Commented to preserve default behavior — uncomment to enable.
|
||||||
|
# plan_exec:
|
||||||
|
# enabled: true
|
||||||
|
# auto_advance_after_steps: 5 # optional, default = manual (LLM calls advance_phase)
|
||||||
|
# start_phase: planning # optional, default = planning
|
||||||
|
# whitelist_override: # optional, merges with default (override wins)
|
||||||
|
# planning: [search, read_file, shell]
|
||||||
|
# building: [write_file, shell, read_file]
|
||||||
session: {backend: memory}
|
session: {backend: memory}
|
||||||
bus: {backend: memory}
|
bus: {backend: memory}
|
||||||
task_store: {backend: memory}
|
task_store: {backend: memory}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,436 @@
|
||||||
|
---
|
||||||
|
title: "feat: Agent Wave 3 strategic coupling (G5/G6)"
|
||||||
|
date: 2026-06-29
|
||||||
|
type: feat
|
||||||
|
status: draft
|
||||||
|
origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md
|
||||||
|
execution: code
|
||||||
|
---
|
||||||
|
|
||||||
|
# Wave 3 Strategic Coupling — G5 Function-level Sharding + G6 SOLO Phase Constraints
|
||||||
|
|
||||||
|
## Summary
|
||||||
|
|
||||||
|
Wave 3 of the advanced-agent gap optimization closes two strategic gaps deferred from Waves 1-2:
|
||||||
|
|
||||||
|
- **G5 — Function-level code sharding** (R22, R23): file reading gains an optional `symbol` parameter for symbol/function-granularity slicing, backward compatible with full-file reads.
|
||||||
|
- **G6 — SOLO four-stage state machine** (R24, R25): ReAct loop enforces per-phase tool whitelists (Planning → Building → Verification → Delivery). Extends existing `ExecutionMode.PLAN_EXEC` rather than introducing a new mode.
|
||||||
|
|
||||||
|
Wave 1 (G1/G2/G3/G8 — PR #4 merged) and Wave 2 (G4/G7/G9 — PR #5 open) shipped independently. Wave 3 is the **strategic-risk** wave: it introduces a new tool (G5) and touches ReAct core (G6). Per the brainstorm's KTD6/KTD7 locked decisions, G5 integration approach is decided here (not deferred further), and G6 extends PLAN_EXEC rather than adding a new mode.
|
||||||
|
|
||||||
|
## Problem Frame
|
||||||
|
|
||||||
|
The brainstorm (`docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md`) identified nine gaps across three dimensions. Waves 1-2 closed seven (G1-G4, G7-G9). The remaining two gaps are strategic:
|
||||||
|
|
||||||
|
- **G5 (Long-context cost)**: Large files (e.g. 5000-line modules) blow context budget when the agent only needs one function. Today the agent shells out to `cat file.py` or greps; both pull the whole file into context. A symbol-aware slice would cut context cost 10-50x for typical edits.
|
||||||
|
- **G6 (Unsafe tool sequencing)**: ReAct loop lets the LLM call `write_file` during early exploration before committing to a plan. This wastes tokens on premature edits, causes half-baked refactors, and breaks the "plan-then-build" discipline that production agents (Qoder, Trae Work) enforce via phase state machines.
|
||||||
|
|
||||||
|
KTD6 (brainstorm) defers the G5 integration decision to this plan. KTD7 locks G6 to extend PLAN_EXEC rather than introduce a new mode.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
Carried forward from the brainstorm, Wave 3 section:
|
||||||
|
|
||||||
|
- **R22**: File reading supports symbol/function granularity sharding.
|
||||||
|
- **R23**: Sharding capability exposed as a tool parameter (`symbol="function_name"`), backward compatible with full-file reads.
|
||||||
|
- **R24**: ReAct loop enforces phase constraints — Planning phase only allows `think`/`search`; Building phase only allows `write_file` (and similar write tools).
|
||||||
|
- **R25**: Phase state is configurable; extends `ExecutionMode.PLAN_EXEC`, does NOT introduce a new mode.
|
||||||
|
|
||||||
|
Cross-cutting (brainstorm):
|
||||||
|
|
||||||
|
- **R26**: All optimizations configurable via `agentkit.yaml` (follow `ServerConfig.from_dict` pattern established in Waves 1-2).
|
||||||
|
- **R27**: Each optimization ships a minimal self-check test (ponytail rule).
|
||||||
|
|
||||||
|
Acceptance examples relevant to Wave 3:
|
||||||
|
|
||||||
|
- **AE5 already covered by Wave 2 (G9)** — not in Wave 3 scope.
|
||||||
|
- No explicit AE for G5/G6 in the brainstorm — the plan below specifies test scenarios as the acceptance contract.
|
||||||
|
|
||||||
|
## Key Technical Decisions
|
||||||
|
|
||||||
|
### KTD1: G5 uses Python `ast` + language-aware regex, NOT tree-sitter
|
||||||
|
|
||||||
|
**Decision**: Implement symbol extraction with the Python stdlib `ast` module for Python files, and a small regex-based extractor for TypeScript/JavaScript/Go/Rust/Java. No new dependency.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- `tree-sitter` requires native compilation + per-language grammar files (~30MB installed) — violates the ponytail rule "no new dependency if it can be avoided" and AGENTS.md "禁止使用 any 类型" → prefers minimal stack.
|
||||||
|
- `ast` is stdlib, always available, parses Python accurately.
|
||||||
|
- Regex extractor covers 80% case for TS/JS/Go/Rust/Java (function/class/struct declarations); falls back to "no symbols found → read full file" gracefully.
|
||||||
|
- If a future Wave 4 needs more accurate multi-language parsing, `tree-sitter` can replace the regex layer behind the same `SymbolExtractor` interface.
|
||||||
|
|
||||||
|
**Upgrade path**: replace `RegexSymbolExtractor` with `TreeSitterSymbolExtractor` implementing the same `SymbolExtractor` protocol; no caller changes.
|
||||||
|
|
||||||
|
### KTD2: G5 adds a new `ReadFileTool`, does not extend `ShellTool`
|
||||||
|
|
||||||
|
**Decision**: Add a dedicated `ReadFileTool` in `src/agentkit/tools/file_read.py` with `path` + optional `symbol` + optional `start_line`/`end_line` parameters.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- `ShellTool` is for shell execution; grafting symbol extraction onto it muddies the contract.
|
||||||
|
- A dedicated tool gives the LLM a clear schema (`{"path": "...", "symbol": "function_name"}`) and a focused system-prompt description.
|
||||||
|
- Aligns with the existing `_DEFAULT_CORE_TOOLS` list in `core/react.py:148` which already references `read_file` — the name is reserved but the implementation is missing.
|
||||||
|
|
||||||
|
### KTD3: G6 phase state machine lives in ReActEngine, not in the skill config
|
||||||
|
|
||||||
|
**Decision**: Phase state (Planning/Building/Verification/Delivery) is tracked as a mutable field on `ReActEngine` instance. Transitions are driven by LLM-detected phase-completion signals (e.g., the LLM emits `Phase: Building` in its thinking) OR by an explicit `advance_phase` tool.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- Skill config declares the policy (which tools per phase, auto-advance vs manual); the engine enforces it per-step. This matches R24 ("ReAct 循环加阶段约束").
|
||||||
|
- Alternative considered: phase state in the agent instance (not engine). Rejected because ReActEngine already owns `max_steps`/`verification_enabled` etc.; phase state belongs with the loop that enforces it.
|
||||||
|
|
||||||
|
### KTD4: PLAN_EXEC mode is wired at chat.py WebSocket path (REST already has fallback chain from Wave 2)
|
||||||
|
|
||||||
|
**Decision**: chat.py:1084 (currently warns "not yet supported, falling back to REACT") will route `ExecutionMode.PLAN_EXEC` to a new `_execute_plan_exec_ws` handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and passes it to `ReActEngine.execute`.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- REST `send_message` already uses the Wave 2 three-tier fallback chain; PLAN_EXEC at REST would also need that wrapper. **Out of scope for Wave 3** — only WebSocket path is wired. REST PLAN_EXEC remains "not yet supported" and explicitly raises if invoked.
|
||||||
|
- Single integration point keeps Wave 3 scope bounded; REST wiring is a one-line follow-up once WebSocket path is proven.
|
||||||
|
|
||||||
|
### KTD5: Default phase whitelist matches brainstorm R24
|
||||||
|
|
||||||
|
**Decision**: Default whitelist:
|
||||||
|
- `Planning`: `search`, `tool_search`, `read_file`, `bash` (read-only commands like `git status`, `ls`)
|
||||||
|
- `Building`: `write_file`, `bash` (write commands), `read_file`, `search`
|
||||||
|
- `Verification`: `bash` (test commands), `read_file`, `search`
|
||||||
|
- `Delivery`: all tools (final synthesis)
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- R24 explicitly names `think`/`search` for Planning and `write_file` for Building.
|
||||||
|
- `bash` is split: read-only in Planning, full in Building. Enforced by adding a `bash_command_filter` callback (regex-based, blocks `rm`/`mv`/`>`/`>>` in Planning/Verification).
|
||||||
|
- `Delivery` allows all tools to support last-mile formatting/cleanup.
|
||||||
|
|
||||||
|
### KTD6: Phase transitions are LLM-driven via `advance_phase` tool (opt-in auto-advance)
|
||||||
|
|
||||||
|
**Decision**: Add an `AdvancePhaseTool` that the LLM can call to transition Planning→Building→Verification→Delivery. Auto-advance (after N steps in current phase) is opt-in via `plan_exec.auto_advance_after_steps`.
|
||||||
|
|
||||||
|
**Rationale**:
|
||||||
|
- LLM-driven transitions match Qoder/Trae Work pattern: LLM declares "planning done" explicitly.
|
||||||
|
- Auto-advance is a safety net for LLMs that forget to call `advance_phase`; default off (ponytail: less code is better).
|
||||||
|
|
||||||
|
## Scope Boundaries
|
||||||
|
|
||||||
|
### In Scope
|
||||||
|
|
||||||
|
- New `ReadFileTool` with `symbol` parameter (G5, R22/R23).
|
||||||
|
- `SymbolExtractor` protocol + `AstSymbolExtractor` (Python) + `RegexSymbolExtractor` (TS/JS/Go/Rust/Java).
|
||||||
|
- `PhasePolicy` dataclass + `PhaseState` enum + per-step tool whitelist enforcement in `ReActEngine.execute` (G6, R24/R25).
|
||||||
|
- `AdvancePhaseTool` for LLM-driven phase transitions.
|
||||||
|
- WebSocket chat path routes `PLAN_EXEC` to new `_execute_plan_exec_ws` handler (KTD4).
|
||||||
|
- `plan_exec` config section in `agentkit.yaml` + `ServerConfig.from_dict` extension (R26).
|
||||||
|
- Tests for each new module (R27).
|
||||||
|
|
||||||
|
### Out of Scope (Deferred to Follow-Up Work)
|
||||||
|
|
||||||
|
- REST `send_message` PLAN_EXEC wiring — once WebSocket path is proven, REST wiring is a follow-up commit.
|
||||||
|
- `tree-sitter` integration for more accurate multi-language parsing (KTD1 upgrade path).
|
||||||
|
- Phase-aware prompt engineering (per-phase system prompt templates) — current plan keeps a single system prompt; phase-specific guidance is a prompt-engineering concern, not a code change.
|
||||||
|
- Phase persistence across session resume (U7 checkpoint already saves plan state; phase state restoration is a separate concern).
|
||||||
|
- Phase rollback on `Building` → `Planning` regression (Wave 2 G9 rollback handles file-level rollback; phase regression is a UX/prompt concern).
|
||||||
|
- Tool-filter UI in the frontend (Wave 3 ships backend-only; frontend surfaces phase via existing event channel if needed in a follow-up).
|
||||||
|
|
||||||
|
### Outside This Product's Identity
|
||||||
|
|
||||||
|
- Replacing the existing ReAct loop with LangGraph (inherited from brainstorm).
|
||||||
|
- Disc-based file system à la DeerFlow (inherited).
|
||||||
|
- Docker sandbox (inherited; only command-level safety via `bash_command_filter`).
|
||||||
|
|
||||||
|
## High-Level Technical Design
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart LR
|
||||||
|
subgraph G5[Function Sharding]
|
||||||
|
RF[ReadFileTool] --> SE[SymbolExtractor protocol]
|
||||||
|
SE --> AST[AstSymbolExtractor<br/>Python stdlib ast]
|
||||||
|
SE --> RX[RegexSymbolExtractor<br/>TS/JS/Go/Rust/Java]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph G6[Phase State Machine]
|
||||||
|
PP[PhasePolicy config] --> PS[PhaseState enum<br/>Planning/Building/Verify/Delivery]
|
||||||
|
PS --> Filt[Tool filter per step]
|
||||||
|
Filt --> RE[ReActEngine.execute]
|
||||||
|
AP[AdvancePhaseTool] -->|transitions| PS
|
||||||
|
end
|
||||||
|
|
||||||
|
RF -->|file content for symbol| RE
|
||||||
|
RE -->|enforces| Filt
|
||||||
|
```
|
||||||
|
|
||||||
|
The two subsystems compose at the ReAct engine boundary: `ReadFileTool` is one of the tools the LLM can call during any phase (filtered by `PhasePolicy`); `PhaseState` is enforced at the tool-call step before dispatch.
|
||||||
|
|
||||||
|
## Implementation Units
|
||||||
|
|
||||||
|
### U1. SymbolExtractor + ReadFileTool (G5)
|
||||||
|
|
||||||
|
**Goal**: Add `ReadFileTool` with optional `symbol` parameter; implement `SymbolExtractor` protocol with `AstSymbolExtractor` (Python) and `RegexSymbolExtractor` (TS/JS/Go/Rust/Java).
|
||||||
|
|
||||||
|
**Requirements**: R22, R23, R27.
|
||||||
|
|
||||||
|
**Dependencies**: none.
|
||||||
|
|
||||||
|
**Files**:
|
||||||
|
- `src/agentkit/tools/file_read.py` (new)
|
||||||
|
- `src/agentkit/tools/symbol_extractor.py` (new)
|
||||||
|
- `src/agentkit/tools/__init__.py` (modify — register `ReadFileTool`)
|
||||||
|
- `tests/unit/test_symbol_extractor.py` (new)
|
||||||
|
- `tests/unit/test_read_file_tool.py` (new)
|
||||||
|
|
||||||
|
**Approach**:
|
||||||
|
- `SymbolExtractor` is a `Protocol` with one method: `extract_symbols(content: str, language: str) -> list[SymbolSpan]`. `SymbolSpan` carries `name`, `kind` (function/class/method/struct), `start_line`, `end_line`.
|
||||||
|
- `AstSymbolExtractor` walks `ast.parse(content)`; for each `FunctionDef`/`AsyncFunctionDef`/`ClassDef` collects name + line range. Uses `ast.get_source_segment` style (line-based, not node-based, to keep the API simple).
|
||||||
|
- `RegexSymbolExtractor` ships patterns for TS/JS (`function X`, `const X = (...) =>`, `class X`), Go (`func X`), Rust (`fn X`, `struct X`, `impl X`), Java (`public ... X(...)`). Falls back to "no symbols" if no pattern matches.
|
||||||
|
- `ReadFileTool.execute(path, symbol=None, start_line=None, end_line=None)`:
|
||||||
|
- `symbol=None` → read full file (backward compat with the existing `_FakeTool` benchmark shape).
|
||||||
|
- `symbol="foo"` → detect language from extension; call `extract_symbols`; return the line range of the first matching symbol; if no match, return an error result with available symbol names listed (so the LLM can retry).
|
||||||
|
- `start_line`/`end_line` overrides symbol; allows manual slicing.
|
||||||
|
- Tool registered as `read_file` (matches the reserved name in `core/react.py:148`).
|
||||||
|
|
||||||
|
**Execution note**: characterization-first — write a test that asserts the tool returns the full file content when `symbol=None` (matches pre-existing benchmark `_FakeTool` shape) before adding symbol-extraction behavior.
|
||||||
|
|
||||||
|
**Patterns to follow**:
|
||||||
|
- `src/agentkit/tools/document_tool.py` for tool structure (dataclass, `Tool` base class, `input_schema`).
|
||||||
|
- `src/agentkit/tools/schema_tools.py:SchemaExtractTool` for "extract-from-source" pattern.
|
||||||
|
|
||||||
|
**Test scenarios** (covers R22, R23):
|
||||||
|
- **Happy paths**:
|
||||||
|
- Python file, `symbol="MyClass"` → returns class body only (lines from `class MyClass:` through end of class).
|
||||||
|
- Python file, `symbol="my_func"` → returns function body only.
|
||||||
|
- TypeScript file, `symbol="renderComponent"` → returns arrow/function body.
|
||||||
|
- Go file, `symbol="HandleRequest"` → returns func body.
|
||||||
|
- **Edge cases**:
|
||||||
|
- `symbol=None` → returns full file content (characterization).
|
||||||
|
- `symbol="nonexistent"` → returns error result listing available symbols ("Available symbols: foo, bar, baz").
|
||||||
|
- Unsupported file extension (`.md`, `.txt`) → returns full file with `note: symbol extraction not supported for .md`.
|
||||||
|
- Empty file → returns empty content.
|
||||||
|
- File with nested classes → outer class symbol returns including inner class.
|
||||||
|
- **Error paths**:
|
||||||
|
- Path does not exist → raises `FileNotFoundError` (or returns error result matching other tools' convention).
|
||||||
|
- Path is a directory → returns error result.
|
||||||
|
- Permission denied → returns error result.
|
||||||
|
- **Integration scenarios**:
|
||||||
|
- Symbol extraction + line slicing: `symbol="foo"`, `end_line=50` truncates at line 50 even if symbol extends further.
|
||||||
|
- Round-trip: extract symbol, write back via `ShellTool` `sed` (not in scope for tool — just verify extracted range is well-formed).
|
||||||
|
|
||||||
|
**Verification**:
|
||||||
|
- `python3 -m pytest tests/unit/test_symbol_extractor.py tests/unit/test_read_file_tool.py -q` passes.
|
||||||
|
- `ruff check src/agentkit/tools/file_read.py src/agentkit/tools/symbol_extractor.py` clean.
|
||||||
|
- `ReadFileTool` appears in `ToolRegistry.list_tools()` after registration.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### U2. PhasePolicy + PhaseState + ServerConfig (G6 core)
|
||||||
|
|
||||||
|
**Goal**: Add `PhasePolicy` dataclass, `PhaseState` enum, default whitelist config. Extend `ServerConfig.from_dict` with `plan_exec` section. Wire config to `agentkit.yaml`.
|
||||||
|
|
||||||
|
**Requirements**: R25, R26.
|
||||||
|
|
||||||
|
**Dependencies**: none.
|
||||||
|
|
||||||
|
**Files**:
|
||||||
|
- `src/agentkit/core/phase.py` (new) — `PhaseState` enum, `PhasePolicy` dataclass, `default_policy()` factory.
|
||||||
|
- `src/agentkit/server/config.py` (modify — add `plan_exec` field + `from_dict` parsing).
|
||||||
|
- `agentkit.yaml` (modify — document `plan_exec:` section).
|
||||||
|
- `tests/unit/test_phase_policy.py` (new).
|
||||||
|
|
||||||
|
**Approach**:
|
||||||
|
- `PhaseState = enum("planning building verification delivery")`.
|
||||||
|
- `PhasePolicy` carries:
|
||||||
|
- `whitelist: dict[PhaseState, set[str]]` — tool names allowed per phase.
|
||||||
|
- `bash_command_filter: dict[PhaseState, re.Pattern | None]` — regex that bash args must NOT match (e.g., `r"\b(rm|mv|>|>>)\b"` in Planning).
|
||||||
|
- `auto_advance_after_steps: int | None` — None = manual (LLM calls `advance_phase`); int = auto-advance after N steps.
|
||||||
|
- `start_phase: PhaseState = PhaseState.PLANNING`.
|
||||||
|
- `default_policy()` returns the KTD5 whitelist above.
|
||||||
|
- `ServerConfig.from_dict` reads `plan_exec` section: `enabled`, `whitelist_override` (dict), `auto_advance_after_steps`.
|
||||||
|
- `agentkit.yaml` gains a commented-out `plan_exec:` block (commented to preserve default behavior — opt-in).
|
||||||
|
|
||||||
|
**Patterns to follow**:
|
||||||
|
- `src/agentkit/core/fallback.py` for dataclass + classmethod factory pattern.
|
||||||
|
- `src/agentkit/server/config.py` `from_dict` extension template (established in Wave 1 for `prompt_cache`/`streaming`/`verification`; Wave 2 added `rollback`/`fallback_chain`; Wave 3 adds `plan_exec`).
|
||||||
|
|
||||||
|
**Test scenarios** (covers R25, R26):
|
||||||
|
- **Happy paths**:
|
||||||
|
- `default_policy()` returns policy with all four phases; Planning whitelist contains `search`, `read_file`; Building contains `write_file`.
|
||||||
|
- `PhasePolicy.is_tool_allowed("search", PhaseState.PLANNING)` returns True.
|
||||||
|
- `PhasePolicy.is_tool_allowed("write_file", PhaseState.PLANNING)` returns False.
|
||||||
|
- `PhasePolicy.is_tool_allowed("write_file", PhaseState.BUILDING)` returns True.
|
||||||
|
- **Edge cases**:
|
||||||
|
- Empty whitelist for a phase → all tools rejected (raises `ValueError` at construction time — fail-fast).
|
||||||
|
- `Delivery` phase whitelist contains `"*"` (wildcard) → all tools allowed.
|
||||||
|
- Custom whitelist override merges with default (override wins on conflict).
|
||||||
|
- **Error paths**:
|
||||||
|
- Invalid phase name in config → `ValueError` with message naming the bad value.
|
||||||
|
- `bash_command_filter` regex compile failure → `ValueError`.
|
||||||
|
- **Config integration**:
|
||||||
|
- `ServerConfig.from_dict({"plan_exec": {"enabled": True, "auto_advance_after_steps": 5}})` populates fields correctly.
|
||||||
|
- `ServerConfig.from_dict({})` → `plan_exec = {}` (default).
|
||||||
|
|
||||||
|
**Verification**:
|
||||||
|
- `python3 -m pytest tests/unit/test_phase_policy.py -q` passes.
|
||||||
|
- `ruff check src/agentkit/core/phase.py` clean.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### U3. AdvancePhaseTool + ReActEngine phase enforcement (G6 wiring)
|
||||||
|
|
||||||
|
**Goal**: Add `AdvancePhaseTool`. Wire `PhasePolicy` into `ReActEngine.execute` so each tool-call step checks `is_tool_allowed(tool_name, current_phase)` before dispatch; blocked calls return a structured error to the LLM ("Tool 'write_file' not allowed in Planning phase — call advance_phase first").
|
||||||
|
|
||||||
|
**Requirements**: R24.
|
||||||
|
|
||||||
|
**Dependencies**: U2.
|
||||||
|
|
||||||
|
**Files**:
|
||||||
|
- `src/agentkit/tools/advance_phase.py` (new) — `AdvancePhaseTool` calls `react_engine.advance_phase()`.
|
||||||
|
- `src/agentkit/core/react.py` (modify — add `phase_policy` param to `__init__` + `execute`; add `_current_phase` field; add `advance_phase()` method; enforce in `_execute_loop`).
|
||||||
|
- `tests/unit/test_react_phase_enforcement.py` (new).
|
||||||
|
|
||||||
|
**Approach**:
|
||||||
|
- `ReActEngine.__init__` accepts `phase_policy: PhasePolicy | None = None`. None = no enforcement (backward compat — all existing callers unaffected).
|
||||||
|
- `_current_phase: PhaseState | None` initialized from `phase_policy.start_phase` if policy set, else None.
|
||||||
|
- `advance_phase()` advances `_current_phase` to next enum value; raises `ValueError` if already at `DELIVERY`.
|
||||||
|
- In `_execute_loop`, before dispatching a tool call:
|
||||||
|
```python
|
||||||
|
if self._phase_policy is not None and self._current_phase is not None:
|
||||||
|
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
|
||||||
|
# Inject structured error into conversation, do NOT dispatch tool.
|
||||||
|
# This counts as a "step" for max_steps purposes.
|
||||||
|
observation = {
|
||||||
|
"error": "phase_violation",
|
||||||
|
"message": f"Tool '{tool_name}' not allowed in {self._current_phase.value} phase",
|
||||||
|
"current_phase": self._current_phase.value,
|
||||||
|
"hint": "Call advance_phase to move to Building phase"
|
||||||
|
}
|
||||||
|
continue # next loop iteration
|
||||||
|
```
|
||||||
|
- Auto-advance: if `phase_policy.auto_advance_after_steps` is set and `_steps_in_phase >= auto_advance_after_steps`, call `advance_phase()` automatically.
|
||||||
|
- `AdvancePhaseTool.execute()` calls the bound engine's `advance_phase()` and returns the new phase name. Registered only when `phase_policy` is not None.
|
||||||
|
|
||||||
|
**Execution note**: characterization-first — test that `ReActEngine` with `phase_policy=None` behaves identically to pre-change (no enforcement, no `advance_phase` tool, no `_current_phase` mutation). Then add enforcement tests.
|
||||||
|
|
||||||
|
**Patterns to follow**:
|
||||||
|
- `src/agentkit/core/react.py` `verification_enabled` pattern (feature flag + step-level check).
|
||||||
|
- `src/agentkit/tools/ask_human.py` for tool that interacts with engine state.
|
||||||
|
|
||||||
|
**Test scenarios** (covers R24):
|
||||||
|
- **Characterization (no policy)**:
|
||||||
|
- `ReActEngine(phase_policy=None)` — all tools allowed in all steps; no `advance_phase` tool registered; behavior matches pre-change.
|
||||||
|
- **Happy paths**:
|
||||||
|
- Planning phase: LLM calls `search` → executes; LLM calls `advance_phase` → phase becomes Building.
|
||||||
|
- Building phase: LLM calls `write_file` → executes; LLM calls `advance_phase` → phase becomes Verification.
|
||||||
|
- Verification phase: LLM calls `bash` with `pytest` → executes; LLM calls `advance_phase` → phase becomes Delivery.
|
||||||
|
- Delivery phase: LLM calls any tool → executes (wildcard).
|
||||||
|
- **Edge cases**:
|
||||||
|
- `advance_phase` called at Delivery → returns error "Already at final phase".
|
||||||
|
- Auto-advance after 3 steps in Planning → phase transitions automatically on 4th step.
|
||||||
|
- `bash` command in Planning contains `rm file` → blocked by `bash_command_filter`.
|
||||||
|
- **Error paths**:
|
||||||
|
- LLM calls `write_file` in Planning → tool NOT dispatched; structured error returned to LLM; loop continues.
|
||||||
|
- LLM calls non-existent tool → existing error path (not phase-related).
|
||||||
|
- **Integration scenarios**:
|
||||||
|
- Phase transition emits a `phase_changed` event (use existing `_broadcast_event` pattern from `experts/orchestrator.py`).
|
||||||
|
- `max_steps` reached mid-phase → `ReActResult.status = "max_steps_reached"` (existing path, no change).
|
||||||
|
|
||||||
|
**Verification**:
|
||||||
|
- `python3 -m pytest tests/unit/test_react_phase_enforcement.py -q` passes.
|
||||||
|
- Existing `tests/unit/test_react_engine.py` still passes (characterization — no policy = no change).
|
||||||
|
- `ruff check src/agentkit/core/react.py src/agentkit/tools/advance_phase.py` clean.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### U4. Wire PLAN_EXEC at chat.py WebSocket path (G6 chat integration)
|
||||||
|
|
||||||
|
**Goal**: Replace the `chat.py:1084` "not yet supported, falling back to REACT" warning with a real PLAN_EXEC handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and dispatches to `ReActEngine.execute` with the policy set.
|
||||||
|
|
||||||
|
**Requirements**: R24, R25 (end-to-end wiring).
|
||||||
|
|
||||||
|
**Dependencies**: U2, U3.
|
||||||
|
|
||||||
|
**Files**:
|
||||||
|
- `src/agentkit/server/routes/chat.py` (modify — add `_execute_plan_exec_ws` handler; branch on `ExecutionMode.PLAN_EXEC`).
|
||||||
|
- `tests/unit/test_chat_plan_exec_ws.py` (new).
|
||||||
|
|
||||||
|
**Approach**:
|
||||||
|
- New helper `_execute_plan_exec_ws(websocket, agent, routing, messages, ...)`:
|
||||||
|
1. Read `server_config.plan_exec` (may be `{}` if not configured → use `default_policy()`).
|
||||||
|
2. Build `PhasePolicy` from config (apply overrides).
|
||||||
|
3. Construct `ReActEngine(..., phase_policy=policy)`.
|
||||||
|
4. Register `AdvancePhaseTool` bound to this engine.
|
||||||
|
5. Call `engine.execute_stream(...)` — reuses existing streaming path.
|
||||||
|
6. Emit `phase_changed` events through the WebSocket (frontend can render phase indicator).
|
||||||
|
- chat.py:1084 changes from `if execution_mode not in (REACT, SKILL_REACT): warn + fall back` to:
|
||||||
|
```python
|
||||||
|
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
|
||||||
|
await _execute_plan_exec_ws(websocket, agent, routing, ...)
|
||||||
|
return
|
||||||
|
if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
|
||||||
|
# existing warning for REWOO/REFLEXION/TEAM_COLLAB
|
||||||
|
...
|
||||||
|
```
|
||||||
|
- REST `send_message` path: explicitly raise `HTTPException(501, "PLAN_EXEC via REST not yet supported; use WebSocket")` — Wave 3 does NOT wire REST (KTD4).
|
||||||
|
|
||||||
|
**Execution note**: characterization-first — test that existing REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning (no regression). Then add PLAN_EXEC wiring.
|
||||||
|
|
||||||
|
**Patterns to follow**:
|
||||||
|
- `src/agentkit/server/routes/chat.py` existing WebSocket handler structure (lines 1082-1100).
|
||||||
|
- `src/agentkit/server/_fallback_chain.py` (Wave 2 U3) for "construct engine per-request with config" pattern.
|
||||||
|
|
||||||
|
**Test scenarios** (covers end-to-end):
|
||||||
|
- **Characterization**:
|
||||||
|
- `ExecutionMode.REWOO` via WebSocket → still falls back to REACT with warning (existing behavior unchanged).
|
||||||
|
- `ExecutionMode.REFLEXION` → same.
|
||||||
|
- `ExecutionMode.TEAM_COLLAB` → same.
|
||||||
|
- **Happy paths**:
|
||||||
|
- `ExecutionMode.PLAN_EXEC` via WebSocket → `_execute_plan_exec_ws` invoked; `ReActEngine` constructed with `phase_policy`; `AdvancePhaseTool` registered.
|
||||||
|
- Planning phase: LLM emits `search` tool call → executed; tool result streamed.
|
||||||
|
- LLM emits `advance_phase` → `phase_changed` event sent to WebSocket client; subsequent `write_file` call now allowed.
|
||||||
|
- **Edge cases**:
|
||||||
|
- `plan_exec` config absent → `default_policy()` used; behavior matches KTD5 whitelist.
|
||||||
|
- `plan_exec.enabled=False` → falls back to REACT (opt-out).
|
||||||
|
- Phase violation: LLM calls `write_file` in Planning → structured error returned; loop continues; `phase_violation` event emitted.
|
||||||
|
- **Error paths**:
|
||||||
|
- REST `send_message` with PLAN_EXEC → 501 error.
|
||||||
|
- Phase policy construction fails (bad config) → 500 error with message.
|
||||||
|
- **Integration scenarios**:
|
||||||
|
- Existing fallback chain (Wave 2 U3) NOT applied to PLAN_EXEC — phase policy and fallback chain are mutually exclusive (KTD5 from Wave 2 plan: chain only wraps REACT/SKILL_REACT at REST). Document this in chat.py comment.
|
||||||
|
|
||||||
|
**Verification**:
|
||||||
|
- `python3 -m pytest tests/unit/test_chat_plan_exec_ws.py -q` passes.
|
||||||
|
- `ruff check src/agentkit/server/routes/chat.py` clean.
|
||||||
|
- Manual test: `agentkit chat` with `@skill:plan_exec_demo` skill config → WebSocket stream includes `phase_changed` events.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Risks & Dependencies
|
||||||
|
|
||||||
|
### Risks
|
||||||
|
|
||||||
|
1. **ReAct core modification risk (high)**: U3 modifies `ReActEngine._execute_loop`. Mitigation: characterization-first tests (U3 Execution note); `phase_policy=None` default preserves all existing behavior; full `test_react_engine.py` regression.
|
||||||
|
2. **Symbol extraction accuracy (medium)**: Regex extractor may miss edge cases (decorated functions, nested generics, multi-line signatures). Mitigation: fall back to "no symbols found → read full file" gracefully; never raise on extraction failure.
|
||||||
|
3. **PLAN_EXEC phase deadlock (medium)**: LLM may never call `advance_phase`, leaving the agent stuck in Planning. Mitigation: `auto_advance_after_steps` config (default 5); timeout via existing `max_steps`.
|
||||||
|
4. **Tool name drift (low)**: Phase whitelist references tool names (`write_file`, `search`, etc.) that may be renamed in future. Mitigation: whitelist is config-driven; rename only requires config update.
|
||||||
|
|
||||||
|
### Dependencies
|
||||||
|
|
||||||
|
- Wave 2 PR #5 (`feat/agent-wave2-medium-coupling`) should be merged first — Wave 3 builds on the `ServerConfig.from_dict` extension pattern and the `_fallback_chain.py` integration shape established there. If PR #5 is still open, Wave 3 branches from `feat/agent-wave2-medium-coupling` rather than `main`.
|
||||||
|
- No external library dependencies (KTD1).
|
||||||
|
|
||||||
|
## System-Wide Impact
|
||||||
|
|
||||||
|
- **Agents using PLAN_EXEC mode**: gain phase enforcement. Existing REACT/SKILL_REACT/DIRECT_CHAT agents: zero change (phase_policy defaults to None).
|
||||||
|
- **Tool registry**: gains two new tools (`read_file`, `advance_phase`). Frontend tool list display may need updating to show the new icons — out of scope for Wave 3 (frontend follows up).
|
||||||
|
- **`agentkit.yaml`**: gains `plan_exec:` section (commented by default). Existing configs unaffected.
|
||||||
|
- **WebSocket clients**: gain `phase_changed` event type. Existing clients ignore unknown event types (verified in Wave 2 — `phase_rollback_*` events follow the same pattern).
|
||||||
|
|
||||||
|
## Sources & Research
|
||||||
|
|
||||||
|
- Origin brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 3 section, KTD6/KTD7).
|
||||||
|
- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (PR #4 merged).
|
||||||
|
- Wave 2 plan: `docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md` (PR #5 open).
|
||||||
|
- Trae Work architecture research (cited in brainstorm): SOLO four-stage state machine pattern.
|
||||||
|
- Qoder architecture research (cited in brainstorm): Spec→Coding→Verify closed loop.
|
||||||
|
- Codebase: `src/agentkit/core/react.py:148` reserves `read_file`/`write_file` tool names in `_DEFAULT_CORE_TOOLS` — Wave 3 U1 delivers the missing `read_file` implementation.
|
||||||
|
- Codebase: `src/agentkit/server/routes/chat.py:1084` documents that PLAN_EXEC is "not yet supported" — Wave 3 U4 closes this gap.
|
||||||
|
|
||||||
|
## Deferred to Implementation
|
||||||
|
|
||||||
|
- Exact regex patterns for non-Python symbol extraction (U1) — design above gives the shape; implementer finalizes patterns based on real-world test fixtures.
|
||||||
|
- `bash_command_filter` regex precision (U2) — defaults block `rm`/`mv`/`>`/`>>`; implementer may add more based on test scenarios.
|
||||||
|
- `phase_changed` event payload shape (U3/U4) — minimal viable shape: `{"phase": "building", "previous": "planning"}`; frontend rendering concerns are out of scope.
|
||||||
|
- Whether `AdvancePhaseTool` accepts a `target_phase` argument for skipping phases (e.g., Planning → Verification) — default no (sequential only); add if test scenarios reveal a need.
|
||||||
|
|
@ -0,0 +1,206 @@
|
||||||
|
"""Phase state machine for PLAN_EXEC mode (G6, R24/R25).
|
||||||
|
|
||||||
|
Four sequential phases enforce per-step tool whitelists:
|
||||||
|
PLANNING → BUILDING → VERIFICATION → DELIVERY
|
||||||
|
|
||||||
|
KTD3 (Wave 3 plan): state machine lives in ReActEngine, not skill config.
|
||||||
|
KTD5: default whitelist matches brainstorm R24 (Planning: think/search;
|
||||||
|
Building: write_file; etc.).
|
||||||
|
KTD6: transitions are LLM-driven via AdvancePhaseTool; auto-advance is opt-in.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field, replace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PhaseState(enum.Enum):
|
||||||
|
"""Phases of the SOLO state machine (extends ExecutionMode.PLAN_EXEC)."""
|
||||||
|
|
||||||
|
PLANNING = "planning"
|
||||||
|
BUILDING = "building"
|
||||||
|
VERIFICATION = "verification"
|
||||||
|
DELIVERY = "delivery"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def next_of(cls, current: "PhaseState") -> "PhaseState | None":
|
||||||
|
"""Return the phase after `current`, or None if `current` is the last."""
|
||||||
|
order = [cls.PLANNING, cls.BUILDING, cls.VERIFICATION, cls.DELIVERY]
|
||||||
|
try:
|
||||||
|
idx = order.index(current)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
if idx + 1 >= len(order):
|
||||||
|
return None
|
||||||
|
return order[idx + 1]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, value: str) -> "PhaseState":
|
||||||
|
"""Parse from string (case-insensitive). Raises ValueError on unknown."""
|
||||||
|
try:
|
||||||
|
return cls(value.lower())
|
||||||
|
except ValueError as e:
|
||||||
|
valid = ", ".join(p.value for p in cls)
|
||||||
|
raise ValueError(f"Invalid phase name {value!r}. Valid: {valid}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# Wildcard token meaning "all tools allowed in this phase".
|
||||||
|
WILDCARD = "*"
|
||||||
|
|
||||||
|
# Default bash command filter for PLANNING and VERIFICATION phases — blocks
|
||||||
|
# commands that mutate the filesystem or execute arbitrary code.
|
||||||
|
# ponytail: regex is intentionally conservative; misses some shell idioms
|
||||||
|
# (e.g., `:>file`, `dd of=file`). Ceiling: a real shell parser would catch
|
||||||
|
# more. Upgrade path = reuse ShellTool._is_dangerous() at enforcement time.
|
||||||
|
# Note: `\b` is a word boundary — works for word commands (rm/mv) but NOT
|
||||||
|
# for `>`/`>>` operators (not word chars). Use a non-boundary alternation
|
||||||
|
# that matches `>` either as a standalone operator or after whitespace.
|
||||||
|
_DEFAULT_BASH_FILTER = re.compile(r"\b(rm|mv|cp|mkdir|rmdir|chmod|chown)\b|(?<!\S)>|>>")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class PhasePolicy:
|
||||||
|
"""Per-phase tool whitelist + bash command filter for PLAN_EXEC mode.
|
||||||
|
|
||||||
|
The policy is enforced by ReActEngine._execute_loop before each tool
|
||||||
|
dispatch. A tool not in the current phase's whitelist is rejected with
|
||||||
|
a structured error returned to the LLM (the loop continues — the LLM
|
||||||
|
gets to react to the rejection and either switch tools or call
|
||||||
|
AdvancePhaseTool).
|
||||||
|
|
||||||
|
Wildcard ``"*"`` in a phase's whitelist means "all tools allowed"
|
||||||
|
(used by DELIVERY by default).
|
||||||
|
"""
|
||||||
|
|
||||||
|
whitelist: dict[PhaseState, frozenset[str]]
|
||||||
|
bash_command_filter: dict[PhaseState, re.Pattern | None] = field(default_factory=dict)
|
||||||
|
auto_advance_after_steps: int | None = None # None = manual (LLM calls advance_phase)
|
||||||
|
start_phase: PhaseState = PhaseState.PLANNING
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
# Fail-fast: empty whitelist for a non-wildcard phase = bug.
|
||||||
|
for phase, tools in self.whitelist.items():
|
||||||
|
if not tools:
|
||||||
|
raise ValueError(
|
||||||
|
f"Phase {phase.value!r} has an empty whitelist — set ['*'] for "
|
||||||
|
f"'all tools allowed' or list specific tool names."
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_tool_allowed(self, tool_name: str, phase: PhaseState) -> bool:
|
||||||
|
"""Return True if `tool_name` is allowed in `phase`."""
|
||||||
|
allowed = self.whitelist.get(phase, frozenset())
|
||||||
|
if WILDCARD in allowed:
|
||||||
|
return True
|
||||||
|
return tool_name in allowed
|
||||||
|
|
||||||
|
def is_bash_command_allowed(self, command: str, phase: PhaseState) -> bool:
|
||||||
|
"""Return True if `command` passes the bash filter for `phase`.
|
||||||
|
|
||||||
|
A None filter = no restriction. An empty command is allowed (ShellTool
|
||||||
|
separately rejects empty commands).
|
||||||
|
"""
|
||||||
|
pattern = self.bash_command_filter.get(phase)
|
||||||
|
if pattern is None:
|
||||||
|
return True
|
||||||
|
return not pattern.search(command)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Serialize for logging/telemetry. Not round-trippable (regex → str)."""
|
||||||
|
return {
|
||||||
|
"whitelist": {phase.value: sorted(tools) for phase, tools in self.whitelist.items()},
|
||||||
|
"bash_command_filter": {
|
||||||
|
phase.value: (p.pattern if p else None)
|
||||||
|
for phase, p in self.bash_command_filter.items()
|
||||||
|
},
|
||||||
|
"auto_advance_after_steps": self.auto_advance_after_steps,
|
||||||
|
"start_phase": self.start_phase.value,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def default_policy() -> PhasePolicy:
|
||||||
|
"""Return the KTD5 default PhasePolicy.
|
||||||
|
|
||||||
|
Whitelist (R24):
|
||||||
|
- PLANNING: search, tool_search, read_file, shell (read-only)
|
||||||
|
- BUILDING: write_file, shell (full), read_file, search
|
||||||
|
- VERIFICATION: shell (test commands), read_file, search
|
||||||
|
- DELIVERY: all tools (wildcard)
|
||||||
|
|
||||||
|
Bash filter:
|
||||||
|
- PLANNING/VERIFICATION: blocks filesystem-mutating commands
|
||||||
|
(rm/mv/cp/mkdir/chmod/chown/>/>>)
|
||||||
|
- BUILDING/DELIVERY: no filter (full bash)
|
||||||
|
"""
|
||||||
|
return PhasePolicy(
|
||||||
|
whitelist={
|
||||||
|
# Tool name is "shell" (ShellTool default); bash_command_filter
|
||||||
|
# gates on the same name. Using "bash" here would make the filter
|
||||||
|
# dead code and block the LLM from shell access.
|
||||||
|
PhaseState.PLANNING: frozenset({"search", "tool_search", "read_file", "shell"}),
|
||||||
|
PhaseState.BUILDING: frozenset(
|
||||||
|
{"write_file", "shell", "read_file", "search", "tool_search"}
|
||||||
|
),
|
||||||
|
PhaseState.VERIFICATION: frozenset({"shell", "read_file", "search"}),
|
||||||
|
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||||
|
},
|
||||||
|
bash_command_filter={
|
||||||
|
PhaseState.PLANNING: _DEFAULT_BASH_FILTER,
|
||||||
|
PhaseState.VERIFICATION: _DEFAULT_BASH_FILTER,
|
||||||
|
PhaseState.BUILDING: None,
|
||||||
|
PhaseState.DELIVERY: None,
|
||||||
|
},
|
||||||
|
auto_advance_after_steps=None, # manual by default
|
||||||
|
start_phase=PhaseState.PLANNING,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def policy_from_config(config: dict[str, Any]) -> PhasePolicy | None:
|
||||||
|
"""Build a PhasePolicy from the `plan_exec` config section.
|
||||||
|
|
||||||
|
Returns None if `config` is empty or `enabled` is False (opt-out).
|
||||||
|
|
||||||
|
Config shape:
|
||||||
|
plan_exec:
|
||||||
|
enabled: true # default true if section present
|
||||||
|
auto_advance_after_steps: 5 # optional
|
||||||
|
start_phase: planning # optional, default planning
|
||||||
|
whitelist_override: # optional, merges with default
|
||||||
|
planning: [search, read_file]
|
||||||
|
building: [write_file, bash]
|
||||||
|
"""
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
if config.get("enabled", True) is False:
|
||||||
|
return None
|
||||||
|
|
||||||
|
policy = default_policy()
|
||||||
|
|
||||||
|
# Start phase
|
||||||
|
start_phase_str = config.get("start_phase")
|
||||||
|
if start_phase_str:
|
||||||
|
policy = replace(policy, start_phase=PhaseState.from_string(start_phase_str))
|
||||||
|
|
||||||
|
# Auto-advance override
|
||||||
|
if "auto_advance_after_steps" in config:
|
||||||
|
policy = replace(policy, auto_advance_after_steps=config["auto_advance_after_steps"])
|
||||||
|
|
||||||
|
# Whitelist override — merge with default (override wins on conflict)
|
||||||
|
override = config.get("whitelist_override") or {}
|
||||||
|
if override:
|
||||||
|
new_whitelist = dict(policy.whitelist)
|
||||||
|
for phase_name, tools in override.items():
|
||||||
|
phase = PhaseState.from_string(phase_name)
|
||||||
|
if not isinstance(tools, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"whitelist_override[{phase_name!r}] must be a list, got {type(tools).__name__}"
|
||||||
|
)
|
||||||
|
new_whitelist[phase] = frozenset(str(t) for t in tools)
|
||||||
|
policy = replace(policy, whitelist=new_whitelist)
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
@ -28,6 +28,7 @@ from agentkit.telemetry.metrics import (
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from agentkit.core.compressor import CompressionStrategy
|
from agentkit.core.compressor import CompressionStrategy
|
||||||
from agentkit.core.middleware import MiddlewareChain
|
from agentkit.core.middleware import MiddlewareChain
|
||||||
|
from agentkit.core.phase import PhasePolicy, PhaseState
|
||||||
from agentkit.core.trace import TraceRecorder
|
from agentkit.core.trace import TraceRecorder
|
||||||
from agentkit.memory.retriever import MemoryRetriever
|
from agentkit.memory.retriever import MemoryRetriever
|
||||||
|
|
||||||
|
|
@ -168,6 +169,9 @@ class ReActEngine:
|
||||||
prompt_cache_enable: bool = True,
|
prompt_cache_enable: bool = True,
|
||||||
flush_interval_ms: int = 0,
|
flush_interval_ms: int = 0,
|
||||||
max_reinjections: int = 1,
|
max_reinjections: int = 1,
|
||||||
|
# U3/G6: PLAN_EXEC phase policy (opt-in). None = no enforcement
|
||||||
|
# (backward compat — all existing callers unaffected).
|
||||||
|
phase_policy: "PhasePolicy | None" = None,
|
||||||
):
|
):
|
||||||
if max_steps < 1:
|
if max_steps < 1:
|
||||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||||
|
|
@ -211,6 +215,15 @@ class ReActEngine:
|
||||||
self._loop_corrected: bool = False
|
self._loop_corrected: bool = False
|
||||||
# U6: Middleware chain (parallel integration, feature flag controlled)
|
# U6: Middleware chain (parallel integration, feature flag controlled)
|
||||||
self._middleware_chain = middleware_chain
|
self._middleware_chain = middleware_chain
|
||||||
|
# U3/G6: PLAN_EXEC phase state. None = no enforcement (default).
|
||||||
|
# When set, _execute_loop checks each tool call against the current
|
||||||
|
# phase's whitelist before dispatch.
|
||||||
|
self._phase_policy = phase_policy
|
||||||
|
self._current_phase: "PhaseState | None" = (
|
||||||
|
phase_policy.start_phase if phase_policy is not None else None
|
||||||
|
)
|
||||||
|
# Steps taken in the current phase (for auto-advance safety net).
|
||||||
|
self._steps_in_phase: int = 0
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset internal state for reuse across conversations.
|
"""Reset internal state for reuse across conversations.
|
||||||
|
|
@ -223,6 +236,99 @@ class ReActEngine:
|
||||||
# This method exists for API clarity and future stateful extensions.
|
# This method exists for API clarity and future stateful extensions.
|
||||||
self._loop_window.clear()
|
self._loop_window.clear()
|
||||||
self._loop_corrected = False
|
self._loop_corrected = False
|
||||||
|
# U3/G6: reset phase state to start_phase (if policy set). Each
|
||||||
|
# execute() call begins a fresh PLANNING phase.
|
||||||
|
if self._phase_policy is not None:
|
||||||
|
self._current_phase = self._phase_policy.start_phase
|
||||||
|
self._steps_in_phase = 0
|
||||||
|
|
||||||
|
# ── U3/G6: phase state machine ────────────────────────────────────
|
||||||
|
|
||||||
|
def advance_phase(self) -> "PhaseState | None":
|
||||||
|
"""Advance to the next phase. Returns the new phase, or None if
|
||||||
|
already at DELIVERY (final phase).
|
||||||
|
|
||||||
|
Called by AdvancePhaseTool when the LLM explicitly signals phase
|
||||||
|
completion. Also called by the auto-advance safety net when
|
||||||
|
``steps_in_phase >= auto_advance_after_steps``.
|
||||||
|
|
||||||
|
Returns None if no phase_policy is set (no-op).
|
||||||
|
"""
|
||||||
|
if self._phase_policy is None or self._current_phase is None:
|
||||||
|
return None
|
||||||
|
from agentkit.core.phase import PhaseState
|
||||||
|
|
||||||
|
nxt = PhaseState.next_of(self._current_phase)
|
||||||
|
if nxt is None:
|
||||||
|
# Already at DELIVERY — return None to signal no transition.
|
||||||
|
return None
|
||||||
|
previous = self._current_phase
|
||||||
|
self._current_phase = nxt
|
||||||
|
self._steps_in_phase = 0
|
||||||
|
logger.info(
|
||||||
|
"Phase transition: %s → %s",
|
||||||
|
previous.value,
|
||||||
|
nxt.value,
|
||||||
|
)
|
||||||
|
return nxt
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self) -> "PhaseState | None":
|
||||||
|
"""Current phase (None if no phase_policy set)."""
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def _maybe_auto_advance(self) -> bool:
|
||||||
|
"""Auto-advance phase if step budget exhausted. Returns True if advanced."""
|
||||||
|
if self._phase_policy is None or self._current_phase is None:
|
||||||
|
return False
|
||||||
|
threshold = self._phase_policy.auto_advance_after_steps
|
||||||
|
if threshold is None:
|
||||||
|
return False
|
||||||
|
if self._steps_in_phase >= threshold:
|
||||||
|
self.advance_phase()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_phase_permission(
|
||||||
|
self, tool_name: str, arguments: dict[str, Any]
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Return None if tool is allowed; return a structured error dict if blocked.
|
||||||
|
|
||||||
|
The error dict replaces what `_execute_tool` would have returned —
|
||||||
|
the loop continues, so the LLM can react to the rejection (call
|
||||||
|
AdvancePhaseTool or pick a different tool).
|
||||||
|
|
||||||
|
Also applies the bash_command_filter for `bash` tool calls.
|
||||||
|
"""
|
||||||
|
if self._phase_policy is None or self._current_phase is None:
|
||||||
|
return None
|
||||||
|
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
|
||||||
|
return {
|
||||||
|
"error": "phase_violation",
|
||||||
|
"message": (
|
||||||
|
f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. "
|
||||||
|
f"Call `advance_phase` to move to the next phase."
|
||||||
|
),
|
||||||
|
"current_phase": self._current_phase.value,
|
||||||
|
"tool": tool_name,
|
||||||
|
"is_error": True,
|
||||||
|
}
|
||||||
|
# Bash command filter (only applies to shell tool — registered as "shell").
|
||||||
|
if tool_name == "shell":
|
||||||
|
command = str(arguments.get("command", ""))
|
||||||
|
if not self._phase_policy.is_bash_command_allowed(command, self._current_phase):
|
||||||
|
return {
|
||||||
|
"error": "phase_violation",
|
||||||
|
"message": (
|
||||||
|
f"Bash command blocked in {self._current_phase.value} phase "
|
||||||
|
f"(filesystem-mutating operations not allowed during "
|
||||||
|
f"planning/verification). Command: {command[:100]}"
|
||||||
|
),
|
||||||
|
"current_phase": self._current_phase.value,
|
||||||
|
"tool": tool_name,
|
||||||
|
"is_error": True,
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
|
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
|
||||||
"""检测重复工具调用模式。
|
"""检测重复工具调用模式。
|
||||||
|
|
@ -498,6 +604,14 @@ class ReActEngine:
|
||||||
if cancellation_token is not None:
|
if cancellation_token is not None:
|
||||||
cancellation_token.check()
|
cancellation_token.check()
|
||||||
|
|
||||||
|
# U3/G6: phase auto-advance safety net.
|
||||||
|
# Incremented per step (LLM call), not per tool_call. When
|
||||||
|
# auto_advance_after_steps is set, advance the phase after
|
||||||
|
# the LLM has been stuck in the same phase for N steps.
|
||||||
|
if self._phase_policy is not None:
|
||||||
|
self._steps_in_phase += 1
|
||||||
|
self._maybe_auto_advance()
|
||||||
|
|
||||||
# Think: 调用 LLM
|
# Think: 调用 LLM
|
||||||
llm_start = time.monotonic()
|
llm_start = time.monotonic()
|
||||||
response = await self._llm_gateway.chat(
|
response = await self._llm_gateway.chat(
|
||||||
|
|
@ -1148,6 +1262,11 @@ class ReActEngine:
|
||||||
if cancellation_token is not None:
|
if cancellation_token is not None:
|
||||||
cancellation_token.check()
|
cancellation_token.check()
|
||||||
|
|
||||||
|
# U3/G6: phase auto-advance safety net (mirrors _execute_loop).
|
||||||
|
if self._phase_policy is not None:
|
||||||
|
self._steps_in_phase += 1
|
||||||
|
self._maybe_auto_advance()
|
||||||
|
|
||||||
# 超时检查
|
# 超时检查
|
||||||
if effective_timeout > 0:
|
if effective_timeout > 0:
|
||||||
elapsed = time.monotonic() - _stream_start
|
elapsed = time.monotonic() - _stream_start
|
||||||
|
|
@ -2069,6 +2188,20 @@ class ReActEngine:
|
||||||
self, tool_name: str, arguments: dict[str, Any], tools: list[Tool]
|
self, tool_name: str, arguments: dict[str, Any], tools: list[Tool]
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""执行工具调用,处理成功和失败情况"""
|
"""执行工具调用,处理成功和失败情况"""
|
||||||
|
# U3/G6: phase enforcement — check before dispatch. If the tool is
|
||||||
|
# blocked, return a structured error instead of dispatching. The loop
|
||||||
|
# still counts this as a step (the LLM gets to react to the rejection).
|
||||||
|
# `advance_phase` tool bypasses the check (it's the LLM's escape hatch).
|
||||||
|
if tool_name != "advance_phase":
|
||||||
|
block = self._check_phase_permission(tool_name, arguments)
|
||||||
|
if block is not None:
|
||||||
|
logger.info(
|
||||||
|
"Phase violation: tool %r blocked in %s phase",
|
||||||
|
tool_name,
|
||||||
|
self._current_phase.value if self._current_phase else "?",
|
||||||
|
)
|
||||||
|
return block
|
||||||
|
|
||||||
tool = self._find_tool(tool_name, tools)
|
tool = self._find_tool(tool_name, tools)
|
||||||
if tool is None:
|
if tool is None:
|
||||||
error_msg = f"Tool '{tool_name}' not found"
|
error_msg = f"Tool '{tool_name}' not found"
|
||||||
|
|
|
||||||
|
|
@ -121,6 +121,9 @@ class ServerConfig:
|
||||||
verification: dict[str, Any] | None = None,
|
verification: dict[str, Any] | None = None,
|
||||||
rollback: dict[str, Any] | None = None,
|
rollback: dict[str, Any] | None = None,
|
||||||
fallback_chain: dict[str, Any] | None = None,
|
fallback_chain: dict[str, Any] | None = None,
|
||||||
|
# G6/U2: PLAN_EXEC phase policy config (opt-in — None = disabled).
|
||||||
|
# Parsed via PhasePolicy.policy_from_config() at chat.py wiring time.
|
||||||
|
plan_exec: dict[str, Any] | None = None,
|
||||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||||
):
|
):
|
||||||
self.host = host
|
self.host = host
|
||||||
|
|
@ -161,6 +164,10 @@ class ServerConfig:
|
||||||
# G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries}
|
# G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries}
|
||||||
# controls three-tier chain at chat.py REST send_message (KTD5).
|
# controls three-tier chain at chat.py REST send_message (KTD5).
|
||||||
self.fallback_chain = fallback_chain or {}
|
self.fallback_chain = fallback_chain or {}
|
||||||
|
# G6/U2: plan_exec phase policy config (opt-in — empty dict = disabled).
|
||||||
|
# Resolved to PhasePolicy via agentkit.core.phase.policy_from_config()
|
||||||
|
# at chat.py WebSocket wiring time (U4).
|
||||||
|
self.plan_exec = plan_exec or {}
|
||||||
self.on_change = on_change
|
self.on_change = on_change
|
||||||
|
|
||||||
# Config watching state
|
# Config watching state
|
||||||
|
|
@ -252,6 +259,8 @@ class ServerConfig:
|
||||||
rollback_data = data.get("rollback", {})
|
rollback_data = data.get("rollback", {})
|
||||||
# G7/U3: fallback_chain 配置 (从 YAML 读取)
|
# G7/U3: fallback_chain 配置 (从 YAML 读取)
|
||||||
fallback_chain_data = data.get("fallback_chain", {})
|
fallback_chain_data = data.get("fallback_chain", {})
|
||||||
|
# G6/U2: plan_exec phase policy 配置 (从 YAML 读取, opt-in)
|
||||||
|
plan_exec_data = data.get("plan_exec", {})
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
host=server.get("host", "0.0.0.0"),
|
host=server.get("host", "0.0.0.0"),
|
||||||
|
|
@ -285,6 +294,7 @@ class ServerConfig:
|
||||||
verification=verification_data,
|
verification=verification_data,
|
||||||
rollback=rollback_data,
|
rollback=rollback_data,
|
||||||
fallback_chain=fallback_chain_data,
|
fallback_chain=fallback_chain_data,
|
||||||
|
plan_exec=plan_exec_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -25,11 +25,13 @@ from fastapi.responses import FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from agentkit.chat.skill_routing import ExecutionMode
|
from agentkit.chat.skill_routing import ExecutionMode
|
||||||
|
from agentkit.core.phase import PhasePolicy, default_policy, policy_from_config
|
||||||
from agentkit.core.protocol import CancellationToken
|
from agentkit.core.protocol import CancellationToken
|
||||||
from agentkit.core.react import ReActEngine
|
from agentkit.core.react import ReActEngine
|
||||||
from agentkit.server._fallback_chain import execute_with_fallback_chain
|
from agentkit.server._fallback_chain import execute_with_fallback_chain
|
||||||
from agentkit.session.manager import SessionManager
|
from agentkit.session.manager import SessionManager
|
||||||
from agentkit.session.models import MessageRole, SessionStatus
|
from agentkit.session.models import MessageRole, SessionStatus
|
||||||
|
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -47,6 +49,8 @@ class CreateSessionRequest(BaseModel):
|
||||||
class SendMessageRequest(BaseModel):
|
class SendMessageRequest(BaseModel):
|
||||||
content: str
|
content: str
|
||||||
role: str = "user"
|
role: str = "user"
|
||||||
|
# Optional execution mode override. "plan_exec" → 501 (KTD4: WebSocket only).
|
||||||
|
execution_mode: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class SessionResponse(BaseModel):
|
class SessionResponse(BaseModel):
|
||||||
|
|
@ -583,6 +587,13 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
||||||
if session.status == SessionStatus.CLOSED:
|
if session.status == SessionStatus.CLOSED:
|
||||||
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
|
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
|
||||||
|
|
||||||
|
# KTD4: PLAN_EXEC is wired only at the WebSocket path. REST raises 501.
|
||||||
|
if request.execution_mode == "plan_exec":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=501,
|
||||||
|
detail="PLAN_EXEC via REST not yet supported; use WebSocket",
|
||||||
|
)
|
||||||
|
|
||||||
# Append user message
|
# Append user message
|
||||||
await sm.append_message(
|
await sm.append_message(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
|
@ -1079,21 +1090,73 @@ async def _handle_chat_message(
|
||||||
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
|
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
|
||||||
return
|
return
|
||||||
|
|
||||||
# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
|
# U4/G6: PLAN_EXEC — build PhasePolicy from server config (KTD4: WebSocket only).
|
||||||
# currently fall back to REACT with a warning.
|
# KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC — phase policy and
|
||||||
if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
|
# fallback chain are mutually exclusive. PLAN_EXEC uses its own engine.
|
||||||
|
phase_policy: PhasePolicy | None = None
|
||||||
|
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
|
||||||
|
server_config = getattr(websocket.app.state, "server_config", None)
|
||||||
|
plan_exec_cfg = getattr(server_config, "plan_exec", None) or {}
|
||||||
|
|
||||||
|
if plan_exec_cfg.get("enabled", True) is False:
|
||||||
|
# Explicit opt-out → fall back to REACT.
|
||||||
|
logger.info(
|
||||||
|
"PLAN_EXEC disabled by config (plan_exec.enabled=False), "
|
||||||
|
"falling back to REACT for session %s",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
phase_policy = policy_from_config(plan_exec_cfg)
|
||||||
|
if phase_policy is None:
|
||||||
|
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
|
||||||
|
phase_policy = default_policy()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"PLAN_EXEC phase policy construction failed for session %s: %s",
|
||||||
|
session_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
# Truncate to 200 chars to match nearby error paths and
|
||||||
|
# avoid leaking config internals (see chat.py:1090, 1320).
|
||||||
|
"data": {"message": f"phase policy error: {str(e)[:200]}"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle advanced execution modes: REWOO/REFLEXION/TEAM_COLLAB
|
||||||
|
# still fall back to REACT with a warning. PLAN_EXEC is handled above.
|
||||||
|
if routing.execution_mode not in (
|
||||||
|
ExecutionMode.REACT,
|
||||||
|
ExecutionMode.SKILL_REACT,
|
||||||
|
ExecutionMode.PLAN_EXEC,
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Execution mode {routing.execution_mode.value} not yet supported "
|
f"Execution mode {routing.execution_mode.value} not yet supported "
|
||||||
f"in chat WebSocket, falling back to REACT"
|
f"in chat WebSocket, falling back to REACT"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute Agent with streaming
|
# Execute Agent with streaming
|
||||||
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
|
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization).
|
||||||
react_engine = getattr(agent, "_react_engine", None)
|
# PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the
|
||||||
if react_engine is None:
|
# agent's _react_engine — it has no policy).
|
||||||
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
if phase_policy is not None:
|
||||||
|
react_engine = ReActEngine(
|
||||||
|
llm_gateway=websocket.app.state.llm_gateway,
|
||||||
|
phase_policy=phase_policy,
|
||||||
|
)
|
||||||
|
# Register AdvancePhaseTool bound to this engine (LLM's escape hatch).
|
||||||
|
advance_phase_tool = AdvancePhaseTool(engine=react_engine)
|
||||||
|
routing.tools = list(routing.tools) + [advance_phase_tool]
|
||||||
else:
|
else:
|
||||||
react_engine.reset()
|
react_engine = getattr(agent, "_react_engine", None)
|
||||||
|
if react_engine is None:
|
||||||
|
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
||||||
|
else:
|
||||||
|
react_engine.reset()
|
||||||
|
|
||||||
# Create confirmation handler that sends request to frontend and waits for reply
|
# Create confirmation handler that sends request to frontend and waits for reply
|
||||||
# Use the same dict object — do NOT use `or {}` because an empty dict is falsy
|
# Use the same dict object — do NOT use `or {}` because an empty dict is falsy
|
||||||
|
|
@ -1149,6 +1212,9 @@ async def _handle_chat_message(
|
||||||
try:
|
try:
|
||||||
final_content = ""
|
final_content = ""
|
||||||
token_buffer: list[str] = []
|
token_buffer: list[str] = []
|
||||||
|
# Track phase transitions for phase_changed events (PLAN_EXEC only).
|
||||||
|
# For non-PLAN_EXEC modes, current_phase is always None → no events.
|
||||||
|
prev_phase = react_engine.current_phase
|
||||||
async for event in react_engine.execute_stream(
|
async for event in react_engine.execute_stream(
|
||||||
messages=chat_messages,
|
messages=chat_messages,
|
||||||
tools=routing.tools,
|
tools=routing.tools,
|
||||||
|
|
@ -1226,6 +1292,22 @@ async def _handle_chat_message(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# U4/G6: emit phase_changed event when the phase state machine
|
||||||
|
# transitions (PLAN_EXEC only). For non-PLAN_EXEC modes,
|
||||||
|
# current_phase is always None → this branch never fires.
|
||||||
|
curr_phase = react_engine.current_phase
|
||||||
|
if curr_phase != prev_phase:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "phase_changed",
|
||||||
|
"data": {
|
||||||
|
"phase": curr_phase.value if curr_phase else None,
|
||||||
|
"previous": prev_phase.value if prev_phase else None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
prev_phase = curr_phase
|
||||||
|
|
||||||
# Append assistant reply to session
|
# Append assistant reply to session
|
||||||
if final_content:
|
if final_content:
|
||||||
await sm.append_message(
|
await sm.append_message(
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ from agentkit.tools.memory_tool import MemoryTool
|
||||||
from agentkit.tools.web_search import WebSearchTool
|
from agentkit.tools.web_search import WebSearchTool
|
||||||
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
|
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
|
||||||
from agentkit.tools.search import ToolSearchIndex
|
from agentkit.tools.search import ToolSearchIndex
|
||||||
|
from agentkit.tools.file_read import ReadFileTool
|
||||||
|
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||||
|
|
||||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||||
try:
|
try:
|
||||||
|
|
@ -52,4 +54,6 @@ __all__ = [
|
||||||
"OutputParser",
|
"OutputParser",
|
||||||
"ParsedOutput",
|
"ParsedOutput",
|
||||||
"ErrorType",
|
"ErrorType",
|
||||||
|
"ReadFileTool",
|
||||||
|
"AdvancePhaseTool",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
"""AdvancePhaseTool — LLM-driven phase transition (G6, KTD6).
|
||||||
|
|
||||||
|
Registered alongside other tools when ReActEngine has a phase_policy set.
|
||||||
|
The LLM calls this tool to signal "I'm done planning, move to building".
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.core.react import ReActEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AdvancePhaseTool(Tool):
|
||||||
|
"""Tool that advances the ReActEngine's current phase.
|
||||||
|
|
||||||
|
KTD6: LLM-driven phase transitions. Auto-advance is opt-in via
|
||||||
|
``plan_exec.auto_advance_after_steps``; this tool is the manual path.
|
||||||
|
|
||||||
|
The tool holds a weak reference to the engine (via bound method
|
||||||
|
``engine.advance_phase``) — registered only when phase_policy is set.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
engine: "ReActEngine",
|
||||||
|
name: str = "advance_phase",
|
||||||
|
description: str | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description
|
||||||
|
or (
|
||||||
|
"Advance the PLAN_EXEC phase state machine to the next phase "
|
||||||
|
"(Planning → Building → Verification → Delivery). Call this "
|
||||||
|
"when you have finished the current phase's work and are ready "
|
||||||
|
"to move on. Returns the new phase name or an error if you "
|
||||||
|
"are already at the final (Delivery) phase."
|
||||||
|
),
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"additionalProperties": False,
|
||||||
|
},
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["phase", "control"],
|
||||||
|
)
|
||||||
|
self._engine = engine
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict[str, Any]:
|
||||||
|
# Capture previous phase before transition (engine is single-threaded per request).
|
||||||
|
previous = self._engine.current_phase
|
||||||
|
new_phase = self._engine.advance_phase()
|
||||||
|
if new_phase is None:
|
||||||
|
# Either no policy set, or already at DELIVERY.
|
||||||
|
current = self._engine.current_phase
|
||||||
|
if current is None:
|
||||||
|
return {
|
||||||
|
"is_error": True,
|
||||||
|
"error": "no_phase_policy",
|
||||||
|
"message": "No phase policy is set — advance_phase is a no-op.",
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"is_error": True,
|
||||||
|
"error": "already_at_final_phase",
|
||||||
|
"message": (f"Already at final phase ({current.value}). Cannot advance further."),
|
||||||
|
"current_phase": current.value,
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"is_error": False,
|
||||||
|
"previous_phase": previous.value if previous else "",
|
||||||
|
"current_phase": new_phase.value,
|
||||||
|
"message": f"Phase advanced to {new_phase.value}.",
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,262 @@
|
||||||
|
"""ReadFileTool — file reading with optional symbol-level sharding (G5, R22/R23).
|
||||||
|
|
||||||
|
Backward compatible with the pre-existing `_FakeTool` benchmark shape — when
|
||||||
|
`symbol=None`, returns the full file content. When `symbol="foo"`, returns
|
||||||
|
the line range of the first matching symbol via `SymbolExtractor`.
|
||||||
|
|
||||||
|
KTD2 (Wave 3 plan): dedicated tool, does NOT extend ShellTool — keeps the
|
||||||
|
file-reading contract clean and gives the LLM a focused schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
from agentkit.tools.symbol_extractor import (
|
||||||
|
SymbolSpan,
|
||||||
|
extract_symbols_from_file,
|
||||||
|
language_for_extension,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFileTool(Tool):
|
||||||
|
"""Read a file from the filesystem, optionally sliced to a single symbol.
|
||||||
|
|
||||||
|
Tool name `read_file` matches the reserved entry in
|
||||||
|
`core/react.py:_DEFAULT_CORE_TOOLS` (which previously had no real
|
||||||
|
implementation — only `_FakeTool` stubs in `cli/benchmark.py`).
|
||||||
|
|
||||||
|
Backward-compat contract: `symbol=None` returns the full file content,
|
||||||
|
matching the shape `{"path": ...}` that downstream callers (benchmark,
|
||||||
|
phase whitelist) already expect.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = "read_file",
|
||||||
|
description: str | None = None,
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description
|
||||||
|
or (
|
||||||
|
"Read a file from the filesystem. By default returns the full file "
|
||||||
|
"content. Pass `symbol` (function/class/struct name) to slice to just "
|
||||||
|
"that symbol's line range — saves context when you only need one "
|
||||||
|
"function from a large file. Pass `start_line`/`end_line` for manual "
|
||||||
|
"slicing. If `symbol` is set but not found, returns the available "
|
||||||
|
"symbol names so you can retry."
|
||||||
|
),
|
||||||
|
input_schema=input_schema or self._default_input_schema(),
|
||||||
|
output_schema=output_schema or self._default_output_schema(),
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["io", "file", "read"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_input_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path to the file to read (absolute or relative to cwd).",
|
||||||
|
},
|
||||||
|
"symbol": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional: name of a function/class/struct/method to slice to. "
|
||||||
|
"When set, returns only the line range of the first matching "
|
||||||
|
"symbol. Supported languages: py, ts/tsx, js/jsx, go, rs, java."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"start_line": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Optional 1-based start line for manual slicing. Overrides `symbol`.",
|
||||||
|
"minimum": 1,
|
||||||
|
},
|
||||||
|
"end_line": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Optional 1-based end line (inclusive) for manual slicing. Overrides `symbol`.",
|
||||||
|
"minimum": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["path"],
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_output_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"content": {"type": "string"},
|
||||||
|
"path": {"type": "string"},
|
||||||
|
"start_line": {"type": "integer"},
|
||||||
|
"end_line": {"type": "integer"},
|
||||||
|
"symbol": {"type": "string"},
|
||||||
|
"available_symbols": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Populated when `symbol` is set but not found.",
|
||||||
|
},
|
||||||
|
"note": {"type": "string"},
|
||||||
|
"is_error": {"type": "boolean"},
|
||||||
|
"error": {"type": "string"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict[str, Any]:
|
||||||
|
raw_path = kwargs.get("path")
|
||||||
|
if not raw_path:
|
||||||
|
return self._error("`path` is required")
|
||||||
|
|
||||||
|
path = Path(raw_path)
|
||||||
|
if not path.is_absolute():
|
||||||
|
path = path.resolve()
|
||||||
|
|
||||||
|
symbol = kwargs.get("symbol")
|
||||||
|
start_line = kwargs.get("start_line")
|
||||||
|
end_line = kwargs.get("end_line")
|
||||||
|
|
||||||
|
# Validate/sanitize line overrides.
|
||||||
|
if start_line is not None and (not isinstance(start_line, int) or start_line < 1):
|
||||||
|
return self._error(f"`start_line` must be a positive integer, got {start_line!r}")
|
||||||
|
if end_line is not None and (not isinstance(end_line, int) or end_line < 1):
|
||||||
|
return self._error(f"`end_line` must be a positive integer, got {end_line!r}")
|
||||||
|
if start_line is not None and end_line is not None and end_line < start_line:
|
||||||
|
return self._error(f"`end_line` ({end_line}) must be >= `start_line` ({start_line})")
|
||||||
|
|
||||||
|
# Filesystem checks.
|
||||||
|
if not path.exists():
|
||||||
|
return self._error(f"File not found: {path}", path=str(path))
|
||||||
|
if path.is_dir():
|
||||||
|
return self._error(f"Path is a directory, not a file: {path}", path=str(path))
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = path.read_text(encoding="utf-8", errors="replace")
|
||||||
|
except PermissionError as e:
|
||||||
|
return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
|
||||||
|
except OSError as e:
|
||||||
|
return self._error(f"Failed to read {path}: {e}", path=str(path))
|
||||||
|
|
||||||
|
lines = content.splitlines()
|
||||||
|
total_lines = len(lines)
|
||||||
|
|
||||||
|
# Manual slicing takes precedence over symbol (per plan U1 Approach).
|
||||||
|
if start_line is not None or end_line is not None:
|
||||||
|
s = max(1, start_line or 1)
|
||||||
|
e = min(total_lines, end_line or total_lines)
|
||||||
|
sliced = "\n".join(lines[s - 1 : e])
|
||||||
|
return {
|
||||||
|
"content": sliced,
|
||||||
|
"path": str(path),
|
||||||
|
"start_line": s,
|
||||||
|
"end_line": e,
|
||||||
|
"total_lines": total_lines,
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Symbol slicing.
|
||||||
|
if symbol:
|
||||||
|
ext = path.suffix.lower()
|
||||||
|
language = language_for_extension(ext)
|
||||||
|
if not language:
|
||||||
|
# Unsupported extension: return full file with note (per plan U1 Edge case).
|
||||||
|
return {
|
||||||
|
"content": content,
|
||||||
|
"path": str(path),
|
||||||
|
"start_line": 1,
|
||||||
|
"end_line": total_lines,
|
||||||
|
"total_lines": total_lines,
|
||||||
|
"note": f"symbol extraction not supported for {ext or 'unknown extension'}",
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
spans, _lang = extract_symbols_from_file(path)
|
||||||
|
# Re-extract using the content we already read so we don't read the file twice.
|
||||||
|
if not spans:
|
||||||
|
# Try extraction from in-memory content (path-based extraction may
|
||||||
|
# have failed silently on OSError; we already read it successfully).
|
||||||
|
from agentkit.tools.symbol_extractor import get_extractor
|
||||||
|
|
||||||
|
extractor = get_extractor(language)
|
||||||
|
if extractor is not None:
|
||||||
|
spans = extractor.extract_symbols(content, language)
|
||||||
|
|
||||||
|
match = _find_symbol(spans, symbol)
|
||||||
|
if match is None:
|
||||||
|
available = sorted({s.name for s in spans})
|
||||||
|
return {
|
||||||
|
"content": "",
|
||||||
|
"path": str(path),
|
||||||
|
"symbol": symbol,
|
||||||
|
"available_symbols": available,
|
||||||
|
"is_error": False,
|
||||||
|
"note": (
|
||||||
|
f"Symbol {symbol!r} not found in {path.name}. "
|
||||||
|
f"Available: {', '.join(available) if available else '(none)'}"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
s = match.start_line
|
||||||
|
e = min(match.end_line, total_lines)
|
||||||
|
sliced = "\n".join(lines[s - 1 : e])
|
||||||
|
return {
|
||||||
|
"content": sliced,
|
||||||
|
"path": str(path),
|
||||||
|
"symbol": symbol,
|
||||||
|
"symbol_kind": match.kind,
|
||||||
|
"start_line": s,
|
||||||
|
"end_line": e,
|
||||||
|
"total_lines": total_lines,
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Default: full file (characterization baseline — matches _FakeTool shape).
|
||||||
|
return {
|
||||||
|
"content": content,
|
||||||
|
"path": str(path),
|
||||||
|
"start_line": 1,
|
||||||
|
"end_line": total_lines,
|
||||||
|
"total_lines": total_lines,
|
||||||
|
"is_error": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _error(
|
||||||
|
message: str, *, path: str | None = None, detail: str | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
result: dict[str, Any] = {
|
||||||
|
"content": "",
|
||||||
|
"is_error": True,
|
||||||
|
"error": message,
|
||||||
|
}
|
||||||
|
if path is not None:
|
||||||
|
result["path"] = path
|
||||||
|
if detail is not None:
|
||||||
|
result["detail"] = detail
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _find_symbol(spans: list[SymbolSpan], name: str) -> SymbolSpan | None:
|
||||||
|
"""Find the first symbol matching `name`. Case-sensitive.
|
||||||
|
|
||||||
|
ponytail: linear scan is fine for typical file symbol counts (<100). The
|
||||||
|
extractor already returns symbols sorted by start_line; first match wins
|
||||||
|
for ambiguous overloads (e.g., Python classes with same name in different
|
||||||
|
modules — not relevant within one file).
|
||||||
|
"""
|
||||||
|
for span in spans:
|
||||||
|
if span.name == name:
|
||||||
|
return span
|
||||||
|
return None
|
||||||
|
|
@ -0,0 +1,278 @@
|
||||||
|
"""Symbol extraction — locate code symbols (functions/classes/structs) by name.
|
||||||
|
|
||||||
|
KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex
|
||||||
|
for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The
|
||||||
|
`SymbolExtractor` protocol is the upgrade seam — a future TreeSitterSymbolExtractor
|
||||||
|
can replace RegexSymbolExtractor behind the same interface.
|
||||||
|
|
||||||
|
ponytail: regex extractor covers ~80% case (top-level function/class/struct
|
||||||
|
declarations). Ceiling: misses nested signatures inside JSX/TSX generics,
|
||||||
|
multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class SymbolSpan:
|
||||||
|
"""A located symbol — name, kind, and 1-based inclusive line range."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
kind: str # "function" | "class" | "method" | "struct" | "impl"
|
||||||
|
start_line: int # 1-based, inclusive
|
||||||
|
end_line: int # 1-based, inclusive
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class SymbolExtractor(Protocol):
|
||||||
|
"""Protocol for symbol extractors — runtime_checkable for isinstance/issubclass."""
|
||||||
|
|
||||||
|
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
|
||||||
|
"""Return all symbols found in `content`.
|
||||||
|
|
||||||
|
`language` is the file extension without leading dot (e.g. "py", "ts").
|
||||||
|
Implementations must never raise on extraction failure — return [] on
|
||||||
|
parse errors and let the caller decide the fallback (full-file read).
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Python — stdlib ast
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AstSymbolExtractor:
|
||||||
|
"""Python symbol extractor using the stdlib `ast` module.
|
||||||
|
|
||||||
|
Captures top-level FunctionDef/AsyncFunctionDef/ClassDef and methods/nested
|
||||||
|
functions inside classes. The end_line is the last line of the node's
|
||||||
|
source segment (decorator-inclusive).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
|
||||||
|
if language != "py":
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
tree = ast.parse(content)
|
||||||
|
except SyntaxError as e:
|
||||||
|
logger.debug("ast.parse failed: %s", e)
|
||||||
|
return []
|
||||||
|
|
||||||
|
lines = content.splitlines()
|
||||||
|
spans: list[SymbolSpan] = []
|
||||||
|
for node in ast.walk(tree):
|
||||||
|
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||||
|
kind = "method" if _is_method(node) else "function"
|
||||||
|
spans.append(_span_from_node(node, kind, lines))
|
||||||
|
elif isinstance(node, ast.ClassDef):
|
||||||
|
spans.append(_span_from_node(node, "class", lines))
|
||||||
|
return spans
|
||||||
|
|
||||||
|
|
||||||
|
def _is_method(node: ast.AST) -> bool:
|
||||||
|
"""A FunctionDef is a method if its parent is a ClassDef.
|
||||||
|
|
||||||
|
`ast.walk` doesn't expose parentage, so we approximate by checking the
|
||||||
|
node's col_offset == 4 (indented inside a class body). ponytail: this
|
||||||
|
misses methods in deeply nested classes — ceiling noted; upgrade path =
|
||||||
|
ast.NodeVisitor with parent tracking.
|
||||||
|
"""
|
||||||
|
return getattr(node, "col_offset", 0) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def _span_from_node(
|
||||||
|
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
|
||||||
|
kind: str,
|
||||||
|
lines: list[str],
|
||||||
|
) -> SymbolSpan:
|
||||||
|
# ast line numbers are 1-based; start at decorator if present (lineno points
|
||||||
|
# to the def/class keyword, decorators are above). Use node.lineno for start
|
||||||
|
# so the returned range matches what the user sees at the def keyword.
|
||||||
|
start = node.lineno
|
||||||
|
# node.end_lineno is the last line of the node body (None on old Pythons).
|
||||||
|
end = node.end_lineno or start
|
||||||
|
# Clamp to actual file length (defensive — ast should not exceed, but
|
||||||
|
# malformed files with no trailing newline can confuse end_lineno).
|
||||||
|
if end > len(lines):
|
||||||
|
end = len(lines)
|
||||||
|
return SymbolSpan(name=node.name, kind=kind, start_line=start, end_line=end)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Regex extractor — TS/JS/Go/Rust/Java
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Each pattern matches a declaration and captures the symbol name in group 1.
|
||||||
|
# Patterns use re.MULTILINE so ^ matches line starts.
|
||||||
|
_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = {
|
||||||
|
"ts": [
|
||||||
|
(
|
||||||
|
"function",
|
||||||
|
re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE),
|
||||||
|
),
|
||||||
|
("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)),
|
||||||
|
(
|
||||||
|
"function",
|
||||||
|
re.compile(
|
||||||
|
r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
|
||||||
|
re.MULTILINE,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"js": [
|
||||||
|
("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)),
|
||||||
|
("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)),
|
||||||
|
(
|
||||||
|
"function",
|
||||||
|
re.compile(
|
||||||
|
r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"go": [
|
||||||
|
("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)),
|
||||||
|
("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)),
|
||||||
|
],
|
||||||
|
"rs": [
|
||||||
|
("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)),
|
||||||
|
("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)),
|
||||||
|
("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)),
|
||||||
|
],
|
||||||
|
"java": [
|
||||||
|
(
|
||||||
|
"function",
|
||||||
|
re.compile(
|
||||||
|
r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{",
|
||||||
|
re.MULTILINE,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"class",
|
||||||
|
re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RegexSymbolExtractor:
|
||||||
|
"""Language-aware regex symbol extractor for TS/JS/Go/Rust/Java.
|
||||||
|
|
||||||
|
Returns SymbolSpans whose end_line is approximated by the next blank line
|
||||||
|
or next-symbol start (whichever comes first). ponytail: this is an
|
||||||
|
approximation — true block-end requires language-aware brace matching.
|
||||||
|
Ceiling: deeply nested blocks may over-extend the range. Upgrade path =
|
||||||
|
tree-sitter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
|
||||||
|
patterns = _REGEX_PATTERNS.get(language)
|
||||||
|
if not patterns:
|
||||||
|
return []
|
||||||
|
|
||||||
|
lines = content.splitlines()
|
||||||
|
# Collect (line_no, name, kind) tuples first, then compute end_line
|
||||||
|
# as the line before the next symbol starts (or EOF).
|
||||||
|
raw_hits: list[tuple[int, str, str]] = []
|
||||||
|
for kind, pattern in patterns:
|
||||||
|
for m in pattern.finditer(content):
|
||||||
|
# Convert match offset to 1-based line number.
|
||||||
|
line_no = content[: m.start()].count("\n") + 1
|
||||||
|
raw_hits.append((line_no, m.group(1), kind))
|
||||||
|
|
||||||
|
if not raw_hits:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Deduplicate: same (line_no, name) may appear for overlapping patterns.
|
||||||
|
seen: set[tuple[int, str]] = set()
|
||||||
|
unique: list[tuple[int, str, str]] = []
|
||||||
|
for line_no, name, kind in raw_hits:
|
||||||
|
key = (line_no, name)
|
||||||
|
if key in seen:
|
||||||
|
continue
|
||||||
|
seen.add(key)
|
||||||
|
unique.append((line_no, name, kind))
|
||||||
|
|
||||||
|
unique.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
spans: list[SymbolSpan] = []
|
||||||
|
for i, (start_line, name, kind) in enumerate(unique):
|
||||||
|
if i + 1 < len(unique):
|
||||||
|
# End at line before next symbol starts, capped at file length.
|
||||||
|
end_line = unique[i + 1][0] - 1
|
||||||
|
else:
|
||||||
|
end_line = len(lines)
|
||||||
|
if end_line < start_line:
|
||||||
|
end_line = start_line
|
||||||
|
spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line))
|
||||||
|
return spans
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dispatch by file extension
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_EXTENSION_LANGUAGE: dict[str, str] = {
|
||||||
|
".py": "py",
|
||||||
|
".ts": "ts",
|
||||||
|
".tsx": "ts",
|
||||||
|
".js": "js",
|
||||||
|
".jsx": "js",
|
||||||
|
".mjs": "js",
|
||||||
|
".cjs": "js",
|
||||||
|
".go": "go",
|
||||||
|
".rs": "rs",
|
||||||
|
".java": "java",
|
||||||
|
}
|
||||||
|
|
||||||
|
_DEFAULT_EXTRACTOR = AstSymbolExtractor()
|
||||||
|
_REGEX_EXTRACTOR = RegexSymbolExtractor()
|
||||||
|
|
||||||
|
|
||||||
|
def language_for_extension(ext: str) -> str:
|
||||||
|
"""Return the language key for a file extension (with or without leading dot).
|
||||||
|
|
||||||
|
Returns "" for unsupported extensions.
|
||||||
|
"""
|
||||||
|
if not ext.startswith("."):
|
||||||
|
ext = "." + ext
|
||||||
|
return _EXTENSION_LANGUAGE.get(ext.lower(), "")
|
||||||
|
|
||||||
|
|
||||||
|
def get_extractor(language: str) -> SymbolExtractor | None:
|
||||||
|
"""Return the appropriate extractor for `language`, or None if unsupported."""
|
||||||
|
if language == "py":
|
||||||
|
return _DEFAULT_EXTRACTOR
|
||||||
|
if language in _REGEX_PATTERNS:
|
||||||
|
return _REGEX_EXTRACTOR
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]:
|
||||||
|
"""Read a file and return (symbols, language).
|
||||||
|
|
||||||
|
Returns ([], "") if the extension is unsupported or the file cannot be read.
|
||||||
|
Never raises — callers use this for fallback routing.
|
||||||
|
"""
|
||||||
|
ext = path.suffix.lower()
|
||||||
|
language = language_for_extension(ext)
|
||||||
|
if not language:
|
||||||
|
return [], ""
|
||||||
|
try:
|
||||||
|
content = path.read_text(encoding="utf-8", errors="replace")
|
||||||
|
except OSError as e:
|
||||||
|
logger.debug("read failed for %s: %s", path, e)
|
||||||
|
return [], language
|
||||||
|
extractor = get_extractor(language)
|
||||||
|
if extractor is None:
|
||||||
|
return [], language
|
||||||
|
return extractor.extract_symbols(content, language), language
|
||||||
|
|
@ -0,0 +1,531 @@
|
||||||
|
"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4).
|
||||||
|
|
||||||
|
Per plan U4 Execution note: characterization-first — verify that existing
|
||||||
|
REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning
|
||||||
|
(no regression). Then add PLAN_EXEC wiring tests.
|
||||||
|
|
||||||
|
KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||||||
|
from agentkit.core.phase import PhaseState
|
||||||
|
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app_with_chat():
|
||||||
|
"""Create a FastAPI app with Chat routes and mocked dependencies."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from agentkit.server.routes.chat import router
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
|
||||||
|
from agentkit.session.manager import SessionManager
|
||||||
|
from agentkit.session.store import InMemorySessionStore
|
||||||
|
|
||||||
|
app.state.session_manager = SessionManager(store=InMemorySessionStore())
|
||||||
|
app.state.llm_gateway = MagicMock()
|
||||||
|
app.state.agent_pool = MagicMock()
|
||||||
|
app.state.server_config = MagicMock()
|
||||||
|
app.state.server_config.api_key = None
|
||||||
|
app.state.server_config.plan_exec = {}
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app_with_chat):
|
||||||
|
return TestClient(app_with_chat)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_routing(
|
||||||
|
execution_mode: ExecutionMode = ExecutionMode.REACT,
|
||||||
|
tools: list | None = None,
|
||||||
|
) -> SkillRoutingResult:
|
||||||
|
"""Build a minimal SkillRoutingResult for testing."""
|
||||||
|
return SkillRoutingResult(
|
||||||
|
execution_mode=execution_mode,
|
||||||
|
tools=tools or [],
|
||||||
|
clean_content="test message",
|
||||||
|
model="default",
|
||||||
|
agent_name="test-agent",
|
||||||
|
system_prompt=None,
|
||||||
|
skill_name=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_websocket_mock(app) -> MagicMock:
|
||||||
|
"""Build a mock WebSocket with app.state and async send_json."""
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.app = app
|
||||||
|
ws.send_json = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
def _make_agent_mock() -> MagicMock:
|
||||||
|
"""Build a mock Agent with _tool_registry and _react_engine."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.name = "test-agent"
|
||||||
|
agent._tool_registry = MagicMock()
|
||||||
|
agent._tool_registry.list_tools.return_value = []
|
||||||
|
agent._system_prompt = None
|
||||||
|
# _react_engine is None to force the code path that creates a new engine
|
||||||
|
agent._react_engine = None
|
||||||
|
agent.get_model.return_value = "default"
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session_manager_mock() -> MagicMock:
|
||||||
|
"""Build a mock SessionManager with async methods."""
|
||||||
|
sm = MagicMock()
|
||||||
|
# get_session returns a mock session with agent_name="test-agent"
|
||||||
|
session = MagicMock()
|
||||||
|
session.agent_name = "test-agent"
|
||||||
|
session.status = "active"
|
||||||
|
sm.get_session = AsyncMock(return_value=session)
|
||||||
|
sm.get_chat_messages = AsyncMock(return_value=[])
|
||||||
|
sm.append_message = AsyncMock()
|
||||||
|
return sm
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None:
|
||||||
|
"""Wire up app.state so _handle_chat_message finds the right routing."""
|
||||||
|
app.state.agent_pool.get_agent.return_value = agent
|
||||||
|
app.state.request_preprocessor = MagicMock()
|
||||||
|
app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# REST — PLAN_EXEC raises 501 (KTD4)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRestPlanExec501:
|
||||||
|
def test_rest_plan_exec_returns_501(self, client):
|
||||||
|
"""REST send_message with execution_mode=plan_exec → 501."""
|
||||||
|
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
||||||
|
session_id = create_resp.json()["session_id"]
|
||||||
|
|
||||||
|
msg_resp = client.post(
|
||||||
|
f"/api/v1/chat/sessions/{session_id}/messages",
|
||||||
|
json={"content": "Hello", "execution_mode": "plan_exec"},
|
||||||
|
)
|
||||||
|
assert msg_resp.status_code == 501
|
||||||
|
assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"]
|
||||||
|
|
||||||
|
def test_rest_react_mode_still_works(self, client):
|
||||||
|
"""REST send_message without execution_mode doesn't 501."""
|
||||||
|
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
||||||
|
session_id = create_resp.json()["session_id"]
|
||||||
|
|
||||||
|
# No execution_mode field → should NOT trigger 501.
|
||||||
|
msg_resp = client.post(
|
||||||
|
f"/api/v1/chat/sessions/{session_id}/messages",
|
||||||
|
json={"content": "Hello"},
|
||||||
|
)
|
||||||
|
assert msg_resp.status_code != 501
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Characterization — REWOO still falls back to REACT (no regression)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rewoo_still_falls_back_to_react_without_phase_policy(app_with_chat):
|
||||||
|
"""Characterization: REWOO via WebSocket → no phase_policy (falls back to REACT)."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.REWOO)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
captured_engine_kwargs: dict = {}
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_engine_kwargs.update(kwargs)
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
return
|
||||||
|
yield # async generator marker
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# REWOO should NOT build a phase_policy
|
||||||
|
assert captured_engine_kwargs.get("phase_policy") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Happy path — PLAN_EXEC builds phase policy + registers AdvancePhaseTool
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_exec_builds_phase_policy_and_registers_advance_phase_tool(
|
||||||
|
app_with_chat,
|
||||||
|
):
|
||||||
|
"""PLAN_EXEC via WebSocket → engine with phase_policy, AdvancePhaseTool registered."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "test"}])
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
captured_engine: list = []
|
||||||
|
captured_tools: list = []
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = (
|
||||||
|
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
captured_tools.extend(kwargs.get("tools", []))
|
||||||
|
captured_engine.append(self)
|
||||||
|
return
|
||||||
|
yield
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(captured_engine) == 1
|
||||||
|
engine = captured_engine[0]
|
||||||
|
assert engine._phase_policy is not None
|
||||||
|
assert engine._current_phase == PhaseState.PLANNING
|
||||||
|
# AdvancePhaseTool was registered in the tools list
|
||||||
|
assert any(isinstance(t, AdvancePhaseTool) for t in captured_tools)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_exec_empty_config_uses_default_policy(app_with_chat):
|
||||||
|
"""Edge: plan_exec config absent (empty dict) → default_policy() used."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
app_with_chat.state.server_config.plan_exec = {}
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
captured_policy: list = []
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_policy.append(kwargs.get("phase_policy"))
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = (
|
||||||
|
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
return
|
||||||
|
yield
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(captured_policy) == 1
|
||||||
|
assert captured_policy[0] is not None
|
||||||
|
# Default policy: PLANNING allows search but not write_file
|
||||||
|
assert "search" in captured_policy[0].whitelist[PhaseState.PLANNING]
|
||||||
|
assert "write_file" not in captured_policy[0].whitelist[PhaseState.PLANNING]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_exec_disabled_falls_back_to_react(app_with_chat):
|
||||||
|
"""Edge: plan_exec.enabled=False → falls back to REACT (no phase_policy)."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
app_with_chat.state.server_config.plan_exec = {"enabled": False}
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
captured_engine_kwargs: dict = {}
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_engine_kwargs.update(kwargs)
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
return
|
||||||
|
yield
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# enabled=False → no phase_policy (falls back to REACT)
|
||||||
|
assert captured_engine_kwargs.get("phase_policy") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Error path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_exec_bad_config_sends_error_and_returns(app_with_chat):
|
||||||
|
"""Error: phase policy construction fails → error event sent, returns early."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"}
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
engine_constructor_called = []
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
engine_constructor_called.append(kwargs)
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
yield
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||||
|
error_messages = [m for m in sent_messages if m.get("type") == "error"]
|
||||||
|
assert len(error_messages) == 1
|
||||||
|
assert "phase policy error" in error_messages[0]["data"]["message"]
|
||||||
|
# Engine constructor was NOT called (returned early)
|
||||||
|
assert len(engine_constructor_called) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# phase_changed event emission
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_phase_changed_event_emitted_on_transition(app_with_chat):
|
||||||
|
"""phase_changed event sent when current_phase changes during execute_stream."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
app_with_chat.state.server_config.plan_exec = {}
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}])
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = PhaseState.PLANNING
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
from agentkit.core.react import ReActEvent
|
||||||
|
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="tool_call",
|
||||||
|
step=1,
|
||||||
|
data={"tool": "search", "output": "ok"},
|
||||||
|
)
|
||||||
|
# Simulate phase transition (as if AdvancePhaseTool was called)
|
||||||
|
self._current_phase = PhaseState.BUILDING
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="final_answer",
|
||||||
|
step=2,
|
||||||
|
data={"output": "done"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="go",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||||
|
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
|
||||||
|
assert len(phase_events) == 1
|
||||||
|
assert phase_events[0]["data"]["phase"] == "building"
|
||||||
|
assert phase_events[0]["data"]["previous"] == "planning"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat):
|
||||||
|
"""Characterization: REACT mode → no phase_changed events."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.REACT)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}])
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._phase_policy = None
|
||||||
|
self._current_phase = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
from agentkit.core.react import ReActEvent
|
||||||
|
|
||||||
|
yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"})
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="hi",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||||
|
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
|
||||||
|
assert len(phase_events) == 0
|
||||||
|
|
@ -0,0 +1,348 @@
|
||||||
|
"""Unit tests for PhasePolicy + PhaseState (G6 core, R24/R25/R26).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- PhaseState enum (next_of, from_string)
|
||||||
|
- default_policy() KTD5 whitelist
|
||||||
|
- PhasePolicy.is_tool_allowed / is_bash_command_allowed
|
||||||
|
- policy_from_config parsing (R26 config-driven)
|
||||||
|
- ServerConfig.plan_exec integration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.phase import (
|
||||||
|
WILDCARD,
|
||||||
|
PhasePolicy,
|
||||||
|
PhaseState,
|
||||||
|
default_policy,
|
||||||
|
policy_from_config,
|
||||||
|
)
|
||||||
|
from agentkit.server.config import ServerConfig
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PhaseState enum
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhaseState:
|
||||||
|
def test_values(self):
|
||||||
|
assert PhaseState.PLANNING.value == "planning"
|
||||||
|
assert PhaseState.BUILDING.value == "building"
|
||||||
|
assert PhaseState.VERIFICATION.value == "verification"
|
||||||
|
assert PhaseState.DELIVERY.value == "delivery"
|
||||||
|
|
||||||
|
def test_next_of(self):
|
||||||
|
assert PhaseState.next_of(PhaseState.PLANNING) == PhaseState.BUILDING
|
||||||
|
assert PhaseState.next_of(PhaseState.BUILDING) == PhaseState.VERIFICATION
|
||||||
|
assert PhaseState.next_of(PhaseState.VERIFICATION) == PhaseState.DELIVERY
|
||||||
|
assert PhaseState.next_of(PhaseState.DELIVERY) is None
|
||||||
|
|
||||||
|
def test_from_string_case_insensitive(self):
|
||||||
|
assert PhaseState.from_string("planning") == PhaseState.PLANNING
|
||||||
|
assert PhaseState.from_string("PLANNING") == PhaseState.PLANNING
|
||||||
|
assert PhaseState.from_string("Building") == PhaseState.BUILDING
|
||||||
|
|
||||||
|
def test_from_string_invalid_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid phase name"):
|
||||||
|
PhaseState.from_string("unknown")
|
||||||
|
with pytest.raises(ValueError, match="Valid:"):
|
||||||
|
PhaseState.from_string("exploration")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# default_policy() — KTD5 whitelist
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultPolicy:
|
||||||
|
def test_has_all_four_phases(self):
|
||||||
|
policy = default_policy()
|
||||||
|
assert PhaseState.PLANNING in policy.whitelist
|
||||||
|
assert PhaseState.BUILDING in policy.whitelist
|
||||||
|
assert PhaseState.VERIFICATION in policy.whitelist
|
||||||
|
assert PhaseState.DELIVERY in policy.whitelist
|
||||||
|
|
||||||
|
def test_planning_whitelist_matches_r24(self):
|
||||||
|
policy = default_policy()
|
||||||
|
allowed = policy.whitelist[PhaseState.PLANNING]
|
||||||
|
assert "search" in allowed
|
||||||
|
assert "read_file" in allowed
|
||||||
|
assert "shell" in allowed
|
||||||
|
assert "tool_search" in allowed
|
||||||
|
# Planning must NOT allow write_file.
|
||||||
|
assert "write_file" not in allowed
|
||||||
|
|
||||||
|
def test_building_whitelist_includes_write_file(self):
|
||||||
|
policy = default_policy()
|
||||||
|
allowed = policy.whitelist[PhaseState.BUILDING]
|
||||||
|
assert "write_file" in allowed
|
||||||
|
assert "shell" in allowed
|
||||||
|
assert "read_file" in allowed
|
||||||
|
|
||||||
|
def test_verification_whitelist_excludes_write(self):
|
||||||
|
policy = default_policy()
|
||||||
|
allowed = policy.whitelist[PhaseState.VERIFICATION]
|
||||||
|
assert "shell" in allowed
|
||||||
|
assert "read_file" in allowed
|
||||||
|
assert "write_file" not in allowed
|
||||||
|
|
||||||
|
def test_delivery_wildcard(self):
|
||||||
|
policy = default_policy()
|
||||||
|
allowed = policy.whitelist[PhaseState.DELIVERY]
|
||||||
|
assert WILDCARD in allowed
|
||||||
|
|
||||||
|
def test_start_phase_default_planning(self):
|
||||||
|
assert default_policy().start_phase == PhaseState.PLANNING
|
||||||
|
|
||||||
|
def test_auto_advance_default_none(self):
|
||||||
|
# KTD6: manual by default.
|
||||||
|
assert default_policy().auto_advance_after_steps is None
|
||||||
|
|
||||||
|
def test_bash_filter_blocks_rm_in_planning(self):
|
||||||
|
policy = default_policy()
|
||||||
|
assert policy.is_bash_command_allowed("ls -la", PhaseState.PLANNING) is True
|
||||||
|
assert policy.is_bash_command_allowed("git status", PhaseState.PLANNING) is True
|
||||||
|
assert policy.is_bash_command_allowed("rm -rf /tmp/x", PhaseState.PLANNING) is False
|
||||||
|
assert policy.is_bash_command_allowed("echo x > file.txt", PhaseState.PLANNING) is False
|
||||||
|
|
||||||
|
def test_bash_filter_no_restriction_in_building(self):
|
||||||
|
policy = default_policy()
|
||||||
|
assert policy.is_bash_command_allowed("rm -rf build/", PhaseState.BUILDING) is True
|
||||||
|
assert policy.is_bash_command_allowed("echo x > out.log", PhaseState.BUILDING) is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PhasePolicy — is_tool_allowed
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsToolAllowed:
|
||||||
|
def test_planning_allows_search(self):
|
||||||
|
policy = default_policy()
|
||||||
|
assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True
|
||||||
|
|
||||||
|
def test_planning_blocks_write_file(self):
|
||||||
|
policy = default_policy()
|
||||||
|
assert policy.is_tool_allowed("write_file", PhaseState.PLANNING) is False
|
||||||
|
|
||||||
|
def test_building_allows_write_file(self):
|
||||||
|
policy = default_policy()
|
||||||
|
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
|
||||||
|
|
||||||
|
def test_delivery_wildcard_allows_anything(self):
|
||||||
|
policy = default_policy()
|
||||||
|
assert policy.is_tool_allowed("any_random_tool", PhaseState.DELIVERY) is True
|
||||||
|
assert policy.is_tool_allowed("write_file", PhaseState.DELIVERY) is True
|
||||||
|
|
||||||
|
def test_unknown_phase_returns_false(self):
|
||||||
|
# ponytail: unknown phase → empty whitelist → no tool allowed.
|
||||||
|
# We can't construct an unknown PhaseState (enum), but if a phase
|
||||||
|
# were missing from the whitelist dict, is_tool_allowed should
|
||||||
|
# return False (defensive).
|
||||||
|
policy = PhasePolicy(
|
||||||
|
whitelist={
|
||||||
|
PhaseState.PLANNING: frozenset({"search"}),
|
||||||
|
PhaseState.BUILDING: frozenset({"write_file"}),
|
||||||
|
PhaseState.VERIFICATION: frozenset({"shell"}),
|
||||||
|
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# BUILDING is in whitelist, so allowed checks work normally.
|
||||||
|
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
|
||||||
|
# Phase missing from whitelist would return False (defensive .get default).
|
||||||
|
# We test this by constructing a minimal policy.
|
||||||
|
minimal = PhasePolicy(
|
||||||
|
whitelist={
|
||||||
|
PhaseState.PLANNING: frozenset({WILDCARD}),
|
||||||
|
PhaseState.BUILDING: frozenset({WILDCARD}),
|
||||||
|
PhaseState.VERIFICATION: frozenset({WILDCARD}),
|
||||||
|
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# VERIFICATION is in whitelist — wildcard allows all.
|
||||||
|
assert minimal.is_tool_allowed("anything", PhaseState.VERIFICATION) is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PhasePolicy — edge cases & errors
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhasePolicyEdgeCases:
|
||||||
|
def test_empty_whitelist_raises(self):
|
||||||
|
# Fail-fast: an empty whitelist for a non-wildcard phase is a bug.
|
||||||
|
with pytest.raises(ValueError, match="empty whitelist"):
|
||||||
|
PhasePolicy(
|
||||||
|
whitelist={
|
||||||
|
PhaseState.PLANNING: frozenset(), # empty!
|
||||||
|
PhaseState.BUILDING: frozenset({WILDCARD}),
|
||||||
|
PhaseState.VERIFICATION: frozenset({WILDCARD}),
|
||||||
|
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_wildcard_only_does_not_raise(self):
|
||||||
|
# Wildcard-only whitelist is valid (means "all tools").
|
||||||
|
policy = PhasePolicy(
|
||||||
|
whitelist={
|
||||||
|
PhaseState.PLANNING: frozenset({WILDCARD}),
|
||||||
|
PhaseState.BUILDING: frozenset({WILDCARD}),
|
||||||
|
PhaseState.VERIFICATION: frozenset({WILDCARD}),
|
||||||
|
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert policy.is_tool_allowed("anything", PhaseState.PLANNING) is True
|
||||||
|
|
||||||
|
def test_to_dict_serializable(self):
|
||||||
|
policy = default_policy()
|
||||||
|
d = policy.to_dict()
|
||||||
|
assert "whitelist" in d
|
||||||
|
assert "planning" in d["whitelist"]
|
||||||
|
assert "delivery" in d["whitelist"]
|
||||||
|
assert d["start_phase"] == "planning"
|
||||||
|
assert d["auto_advance_after_steps"] is None
|
||||||
|
|
||||||
|
def test_custom_bash_filter(self):
|
||||||
|
custom_filter = re.compile(r"\b(pip install|npm install)\b")
|
||||||
|
policy = PhasePolicy(
|
||||||
|
whitelist={
|
||||||
|
PhaseState.PLANNING: frozenset({"shell"}),
|
||||||
|
PhaseState.BUILDING: frozenset({"shell"}),
|
||||||
|
PhaseState.VERIFICATION: frozenset({"shell"}),
|
||||||
|
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||||
|
},
|
||||||
|
bash_command_filter={PhaseState.BUILDING: custom_filter},
|
||||||
|
)
|
||||||
|
assert policy.is_bash_command_allowed("npm install foo", PhaseState.BUILDING) is False
|
||||||
|
assert policy.is_bash_command_allowed("npm run build", PhaseState.BUILDING) is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# policy_from_config — R26 (config-driven)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPolicyFromConfig:
|
||||||
|
def test_empty_config_returns_none(self):
|
||||||
|
assert policy_from_config({}) is None
|
||||||
|
|
||||||
|
def test_enabled_false_returns_none(self):
|
||||||
|
# Opt-out — explicit `enabled: false` disables policy.
|
||||||
|
result = policy_from_config({"enabled": False})
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_enabled_default_true_when_section_present(self):
|
||||||
|
# When section is present but `enabled` is missing, default is True.
|
||||||
|
result = policy_from_config({"auto_advance_after_steps": 3})
|
||||||
|
assert result is not None
|
||||||
|
assert result.auto_advance_after_steps == 3
|
||||||
|
|
||||||
|
def test_auto_advance_after_steps(self):
|
||||||
|
policy = policy_from_config({"enabled": True, "auto_advance_after_steps": 5})
|
||||||
|
assert policy is not None
|
||||||
|
assert policy.auto_advance_after_steps == 5
|
||||||
|
|
||||||
|
def test_start_phase_custom(self):
|
||||||
|
policy = policy_from_config({"enabled": True, "start_phase": "building"})
|
||||||
|
assert policy is not None
|
||||||
|
assert policy.start_phase == PhaseState.BUILDING
|
||||||
|
|
||||||
|
def test_start_phase_invalid_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid phase name"):
|
||||||
|
policy_from_config({"enabled": True, "start_phase": "unknown"})
|
||||||
|
|
||||||
|
def test_whitelist_override_merges_with_default(self):
|
||||||
|
policy = policy_from_config(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"whitelist_override": {
|
||||||
|
"planning": ["search", "read_file"], # removes shell from default
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert policy is not None
|
||||||
|
# Override wins — shell should be removed from planning.
|
||||||
|
assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True
|
||||||
|
assert policy.is_tool_allowed("read_file", PhaseState.PLANNING) is True
|
||||||
|
assert policy.is_tool_allowed("shell", PhaseState.PLANNING) is False
|
||||||
|
# Other phases unchanged.
|
||||||
|
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
|
||||||
|
|
||||||
|
def test_whitelist_override_invalid_phase_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid phase name"):
|
||||||
|
policy_from_config(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"whitelist_override": {"unknown_phase": ["tool"]},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_whitelist_override_non_list_raises(self):
|
||||||
|
with pytest.raises(ValueError, match="must be a list"):
|
||||||
|
policy_from_config(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"whitelist_override": {"planning": "not a list"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_to_dict_round_trip_via_default(self):
|
||||||
|
# Sanity: default policy serializes to a dict with expected keys.
|
||||||
|
policy = default_policy()
|
||||||
|
d = policy.to_dict()
|
||||||
|
assert set(d["whitelist"].keys()) == {
|
||||||
|
"planning",
|
||||||
|
"building",
|
||||||
|
"verification",
|
||||||
|
"delivery",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ServerConfig.plan_exec integration (R26)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestServerConfigPlanExec:
|
||||||
|
def test_default_plan_exec_empty(self):
|
||||||
|
config = ServerConfig.from_dict({})
|
||||||
|
assert config.plan_exec == {}
|
||||||
|
|
||||||
|
def test_plan_exec_loaded_from_dict(self):
|
||||||
|
config = ServerConfig.from_dict(
|
||||||
|
{
|
||||||
|
"plan_exec": {
|
||||||
|
"enabled": True,
|
||||||
|
"auto_advance_after_steps": 5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert config.plan_exec == {"enabled": True, "auto_advance_after_steps": 5}
|
||||||
|
|
||||||
|
def test_plan_exec_empty_dict_default(self):
|
||||||
|
config = ServerConfig.from_dict({"plan_exec": {}})
|
||||||
|
assert config.plan_exec == {}
|
||||||
|
|
||||||
|
def test_plan_exec_resolved_to_policy(self):
|
||||||
|
# Wire the config dict through policy_from_config to verify integration.
|
||||||
|
config = ServerConfig.from_dict(
|
||||||
|
{
|
||||||
|
"plan_exec": {
|
||||||
|
"enabled": True,
|
||||||
|
"auto_advance_after_steps": 3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
policy = policy_from_config(config.plan_exec)
|
||||||
|
assert policy is not None
|
||||||
|
assert policy.auto_advance_after_steps == 3
|
||||||
|
|
||||||
|
def test_plan_exec_disabled_via_config(self):
|
||||||
|
config = ServerConfig.from_dict({"plan_exec": {"enabled": False}})
|
||||||
|
policy = policy_from_config(config.plan_exec)
|
||||||
|
assert policy is None
|
||||||
|
|
@ -0,0 +1,339 @@
|
||||||
|
"""Unit tests for ReActEngine phase enforcement (G6 wiring, R24).
|
||||||
|
|
||||||
|
Per plan U3 Execution note: characterization-first — verify that
|
||||||
|
`ReActEngine(phase_policy=None)` behaves identically to pre-change (no
|
||||||
|
enforcement, no advance_phase tool, no _current_phase mutation). Then add
|
||||||
|
enforcement tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.phase import PhasePolicy, PhaseState, default_policy
|
||||||
|
from agentkit.core.react import ReActEngine
|
||||||
|
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Characterization — phase_policy=None preserves existing behavior
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCharacterizationNoPolicy:
|
||||||
|
"""When phase_policy=None, no enforcement happens and behavior matches
|
||||||
|
pre-Wave-3."""
|
||||||
|
|
||||||
|
def test_init_without_phase_policy(self):
|
||||||
|
# Minimal stub LLM gateway — we're only testing constructor.
|
||||||
|
gateway = MagicMock()
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
assert engine._phase_policy is None
|
||||||
|
assert engine._current_phase is None
|
||||||
|
assert engine._steps_in_phase == 0
|
||||||
|
assert engine.current_phase is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_tool_dispatches_without_phase_check(self):
|
||||||
|
"""Tool dispatch proceeds normally when no policy set."""
|
||||||
|
gateway = MagicMock()
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
|
||||||
|
# MagicMock.name is a special attribute used internally by Mock for
|
||||||
|
# repr — setting it post-construction does not make mock.name == "x"
|
||||||
|
# hold. Patch _find_tool directly to bypass the name lookup.
|
||||||
|
fake_tool = MagicMock()
|
||||||
|
fake_tool.safe_execute = AsyncMock(return_value={"output": "ok"})
|
||||||
|
fake_tool.input_schema = None
|
||||||
|
engine._find_tool = lambda name, tools: fake_tool
|
||||||
|
|
||||||
|
result = await engine._execute_tool("any_tool", {"x": 1}, [fake_tool])
|
||||||
|
assert result == {"output": "ok"}
|
||||||
|
fake_tool.safe_execute.assert_awaited_once_with(x=1)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_advance_phase_returns_none_without_policy(self):
|
||||||
|
gateway = MagicMock()
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
assert engine.advance_phase() is None
|
||||||
|
|
||||||
|
def test_reset_does_not_touch_phase_state_when_no_policy(self):
|
||||||
|
gateway = MagicMock()
|
||||||
|
engine = ReActEngine(llm_gateway=gateway)
|
||||||
|
engine.reset()
|
||||||
|
assert engine._current_phase is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Initialization with phase_policy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhasePolicyInitialization:
|
||||||
|
def test_phase_policy_set_initializes_current_phase(self):
|
||||||
|
gateway = MagicMock()
|
||||||
|
engine = ReActEngine(
|
||||||
|
llm_gateway=gateway,
|
||||||
|
phase_policy=default_policy(),
|
||||||
|
)
|
||||||
|
assert engine._phase_policy is not None
|
||||||
|
assert engine._current_phase == PhaseState.PLANNING
|
||||||
|
assert engine._steps_in_phase == 0
|
||||||
|
|
||||||
|
def test_reset_resets_phase_to_start(self):
|
||||||
|
gateway = MagicMock()
|
||||||
|
engine = ReActEngine(
|
||||||
|
llm_gateway=gateway,
|
||||||
|
phase_policy=default_policy(),
|
||||||
|
)
|
||||||
|
# Manually move phase forward (simulating execute progress).
|
||||||
|
engine.advance_phase() # PLANNING → BUILDING
|
||||||
|
assert engine._current_phase == PhaseState.BUILDING
|
||||||
|
engine._steps_in_phase = 5
|
||||||
|
|
||||||
|
engine.reset()
|
||||||
|
assert engine._current_phase == PhaseState.PLANNING
|
||||||
|
assert engine._steps_in_phase == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# advance_phase() transitions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdvancePhase:
|
||||||
|
@pytest.fixture
|
||||||
|
def engine(self):
|
||||||
|
return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||||
|
|
||||||
|
def test_planning_to_building(self, engine):
|
||||||
|
new_phase = engine.advance_phase()
|
||||||
|
assert new_phase == PhaseState.BUILDING
|
||||||
|
assert engine.current_phase == PhaseState.BUILDING
|
||||||
|
assert engine._steps_in_phase == 0 # counter reset on transition
|
||||||
|
|
||||||
|
def test_building_to_verification(self, engine):
|
||||||
|
engine.advance_phase() # → BUILDING
|
||||||
|
new_phase = engine.advance_phase()
|
||||||
|
assert new_phase == PhaseState.VERIFICATION
|
||||||
|
assert engine.current_phase == PhaseState.VERIFICATION
|
||||||
|
|
||||||
|
def test_verification_to_delivery(self, engine):
|
||||||
|
engine.advance_phase() # → BUILDING
|
||||||
|
engine.advance_phase() # → VERIFICATION
|
||||||
|
new_phase = engine.advance_phase()
|
||||||
|
assert new_phase == PhaseState.DELIVERY
|
||||||
|
assert engine.current_phase == PhaseState.DELIVERY
|
||||||
|
|
||||||
|
def test_delivery_returns_none(self, engine):
|
||||||
|
engine.advance_phase() # → BUILDING
|
||||||
|
engine.advance_phase() # → VERIFICATION
|
||||||
|
engine.advance_phase() # → DELIVERY
|
||||||
|
result = engine.advance_phase()
|
||||||
|
assert result is None
|
||||||
|
assert engine.current_phase == PhaseState.DELIVERY
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _check_phase_permission — whitelist enforcement
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPhasePermission:
|
||||||
|
@pytest.fixture
|
||||||
|
def engine(self):
|
||||||
|
return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||||
|
|
||||||
|
def test_search_allowed_in_planning(self, engine):
|
||||||
|
assert engine._check_phase_permission("search", {}) is None
|
||||||
|
|
||||||
|
def test_write_file_blocked_in_planning(self, engine):
|
||||||
|
result = engine._check_phase_permission("write_file", {})
|
||||||
|
assert result is not None
|
||||||
|
assert result["error"] == "phase_violation"
|
||||||
|
assert "write_file" in result["message"]
|
||||||
|
assert result["current_phase"] == "planning"
|
||||||
|
|
||||||
|
def test_write_file_allowed_in_building(self, engine):
|
||||||
|
engine.advance_phase() # → BUILDING
|
||||||
|
assert engine._check_phase_permission("write_file", {}) is None
|
||||||
|
|
||||||
|
def test_any_tool_allowed_in_delivery(self, engine):
|
||||||
|
engine.advance_phase() # → BUILDING
|
||||||
|
engine.advance_phase() # → VERIFICATION
|
||||||
|
engine.advance_phase() # → DELIVERY
|
||||||
|
assert engine._check_phase_permission("literally_anything", {}) is None
|
||||||
|
|
||||||
|
def test_bash_command_filter_blocks_rm_in_planning(self, engine):
|
||||||
|
result = engine._check_phase_permission("shell", {"command": "rm -rf /tmp"})
|
||||||
|
assert result is not None
|
||||||
|
assert result["error"] == "phase_violation"
|
||||||
|
assert "rm" in result["message"] or "Bash command" in result["message"]
|
||||||
|
|
||||||
|
def test_bash_command_filter_allows_safe_in_planning(self, engine):
|
||||||
|
# `ls` and `git status` are not blocked.
|
||||||
|
assert engine._check_phase_permission("shell", {"command": "ls -la"}) is None
|
||||||
|
assert engine._check_phase_permission("shell", {"command": "git status"}) is None
|
||||||
|
|
||||||
|
def test_bash_command_filter_no_restriction_in_building(self, engine):
|
||||||
|
engine.advance_phase() # → BUILDING
|
||||||
|
# `rm` is allowed in building phase.
|
||||||
|
assert engine._check_phase_permission("shell", {"command": "rm -rf build/"}) is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _execute_tool integration — phase enforcement actually blocks dispatch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteToolPhaseEnforcement:
|
||||||
|
@pytest.fixture
|
||||||
|
def engine_with_tools(self):
|
||||||
|
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||||
|
# Two fake tools: one allowed in PLANNING (search), one not (write_file).
|
||||||
|
# MagicMock.name can't be set post-construction (special attribute),
|
||||||
|
# so we patch _find_tool with a dict-based lookup.
|
||||||
|
search_tool = MagicMock()
|
||||||
|
search_tool.input_schema = None
|
||||||
|
search_tool.safe_execute = AsyncMock(return_value={"results": []})
|
||||||
|
|
||||||
|
write_tool = MagicMock()
|
||||||
|
write_tool.input_schema = None
|
||||||
|
write_tool.safe_execute = AsyncMock(return_value={"written": True})
|
||||||
|
|
||||||
|
tools_by_name = {"search": search_tool, "write_file": write_tool}
|
||||||
|
engine._find_tool = lambda name, tools: tools_by_name.get(name)
|
||||||
|
|
||||||
|
return engine, [search_tool, write_tool]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_blocked_tool_returns_phase_violation_and_skips_dispatch(self, engine_with_tools):
|
||||||
|
engine, tools = engine_with_tools
|
||||||
|
# write_file in PLANNING should be blocked — write_tool.safe_execute
|
||||||
|
# should NEVER be called.
|
||||||
|
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
|
||||||
|
assert result["error"] == "phase_violation"
|
||||||
|
assert result["current_phase"] == "planning"
|
||||||
|
write_tool = tools[1]
|
||||||
|
write_tool.safe_execute.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allowed_tool_dispatches_normally(self, engine_with_tools):
|
||||||
|
engine, tools = engine_with_tools
|
||||||
|
result = await engine._execute_tool("search", {"query": "foo"}, tools)
|
||||||
|
assert result == {"results": []}
|
||||||
|
search_tool = tools[0]
|
||||||
|
search_tool.safe_execute.assert_awaited_once_with(query="foo")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_after_advance_phase_blocked_tool_now_dispatches(self, engine_with_tools):
|
||||||
|
engine, tools = engine_with_tools
|
||||||
|
# First: write_file blocked in PLANNING.
|
||||||
|
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
|
||||||
|
assert result["error"] == "phase_violation"
|
||||||
|
# Advance to BUILDING.
|
||||||
|
engine.advance_phase()
|
||||||
|
# Now: write_file allowed.
|
||||||
|
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
|
||||||
|
assert result == {"written": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Auto-advance safety net (KTD6)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAutoAdvance:
|
||||||
|
def test_auto_advance_after_threshold(self):
|
||||||
|
# Custom policy with auto-advance after 2 steps.
|
||||||
|
policy = PhasePolicy(
|
||||||
|
whitelist={
|
||||||
|
PhaseState.PLANNING: frozenset({"search"}),
|
||||||
|
PhaseState.BUILDING: frozenset({"write_file"}),
|
||||||
|
PhaseState.VERIFICATION: frozenset({"shell"}),
|
||||||
|
PhaseState.DELIVERY: frozenset({"*"}),
|
||||||
|
},
|
||||||
|
auto_advance_after_steps=2,
|
||||||
|
)
|
||||||
|
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=policy)
|
||||||
|
assert engine.current_phase == PhaseState.PLANNING
|
||||||
|
|
||||||
|
# Step 1: counter goes to 1, no advance yet.
|
||||||
|
engine._steps_in_phase += 1
|
||||||
|
assert engine._maybe_auto_advance() is False
|
||||||
|
assert engine.current_phase == PhaseState.PLANNING
|
||||||
|
|
||||||
|
# Step 2: counter hits 2, advance triggered.
|
||||||
|
engine._steps_in_phase += 1
|
||||||
|
assert engine._maybe_auto_advance() is True
|
||||||
|
assert engine.current_phase == PhaseState.BUILDING
|
||||||
|
assert engine._steps_in_phase == 0 # reset on advance
|
||||||
|
|
||||||
|
def test_auto_advance_none_default(self):
|
||||||
|
# default_policy has auto_advance_after_steps=None — no auto-advance.
|
||||||
|
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||||
|
engine._steps_in_phase = 100
|
||||||
|
assert engine._maybe_auto_advance() is False
|
||||||
|
assert engine.current_phase == PhaseState.PLANNING
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AdvancePhaseTool integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdvancePhaseTool:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_advance_phase_tool_transitions_engine(self):
|
||||||
|
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||||
|
tool = AdvancePhaseTool(engine=engine)
|
||||||
|
result = await tool.execute()
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["current_phase"] == "building"
|
||||||
|
assert engine.current_phase == PhaseState.BUILDING
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_advance_phase_tool_at_delivery_returns_error(self):
|
||||||
|
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||||
|
# Walk through all phases.
|
||||||
|
engine.advance_phase() # PLANNING → BUILDING
|
||||||
|
engine.advance_phase() # BUILDING → VERIFICATION
|
||||||
|
engine.advance_phase() # VERIFICATION → DELIVERY
|
||||||
|
tool = AdvancePhaseTool(engine=engine)
|
||||||
|
result = await tool.execute()
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert result["error"] == "already_at_final_phase"
|
||||||
|
assert result["current_phase"] == "delivery"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_advance_phase_tool_without_policy_returns_error(self):
|
||||||
|
engine = ReActEngine(llm_gateway=MagicMock()) # no policy
|
||||||
|
tool = AdvancePhaseTool(engine=engine)
|
||||||
|
result = await tool.execute()
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert result["error"] == "no_phase_policy"
|
||||||
|
|
||||||
|
def test_tool_schema_accepts_no_arguments(self):
|
||||||
|
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||||
|
tool = AdvancePhaseTool(engine=engine)
|
||||||
|
# input_schema has empty properties + additionalProperties:false —
|
||||||
|
# no arguments expected.
|
||||||
|
assert tool.input_schema["properties"] == {}
|
||||||
|
assert tool.input_schema["additionalProperties"] is False
|
||||||
|
|
||||||
|
def test_tool_bypasses_phase_check(self):
|
||||||
|
"""`advance_phase` is the LLM's escape hatch — must never be blocked."""
|
||||||
|
# _check_phase_permission should NOT block advance_phase even in PLANNING.
|
||||||
|
# The bypass is implemented in _execute_tool by name check.
|
||||||
|
# We verify the bypass indirectly: tool dispatches normally even in
|
||||||
|
# PLANNING (where only search/read_file/bash/tool_search are allowed).
|
||||||
|
# advance_phase is not in the whitelist, but the name-based bypass
|
||||||
|
# in _execute_tool lets it through.
|
||||||
|
# (Direct unit test of the bypass would require mocking _find_tool.)
|
||||||
|
# Sanity: advance_phase is not in any whitelist.
|
||||||
|
for phase, allowed in default_policy().whitelist.items():
|
||||||
|
assert "advance_phase" not in allowed, (
|
||||||
|
f"advance_phase must not be in {phase.value} whitelist"
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,367 @@
|
||||||
|
"""Unit tests for ReadFileTool — G5 (R22, R23) + characterization baseline.
|
||||||
|
|
||||||
|
Per plan U1 Execution note: characterization-first — assert that
|
||||||
|
`symbol=None` returns the full file content (matches pre-existing benchmark
|
||||||
|
`_FakeTool` shape) before adding symbol-extraction behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.file_read import ReadFileTool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Schema
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadFileToolSchema:
|
||||||
|
def test_name_is_read_file(self):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
assert tool.name == "read_file"
|
||||||
|
|
||||||
|
def test_required_path(self):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
assert "path" in tool.input_schema["required"]
|
||||||
|
assert "path" in tool.input_schema["properties"]
|
||||||
|
|
||||||
|
def test_optional_symbol_and_lines(self):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
props = tool.input_schema["properties"]
|
||||||
|
assert "symbol" in props
|
||||||
|
assert "start_line" in props
|
||||||
|
assert "end_line" in props
|
||||||
|
# None of the optional fields should be in `required`.
|
||||||
|
required = set(tool.input_schema["required"])
|
||||||
|
assert required == {"path"}
|
||||||
|
|
||||||
|
def test_additional_properties_false(self):
|
||||||
|
# LLM tool-call schemas should reject unknown args (Wave 1 U3 pattern).
|
||||||
|
tool = ReadFileTool()
|
||||||
|
assert tool.input_schema.get("additionalProperties") is False
|
||||||
|
|
||||||
|
def test_tags_contain_io_and_read(self):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
assert "io" in tool.tags
|
||||||
|
assert "read" in tool.tags
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Characterization — symbol=None returns full file
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_py_file(tmp_path):
|
||||||
|
path = tmp_path / "sample.py"
|
||||||
|
path.write_text(
|
||||||
|
textwrap.dedent('''
|
||||||
|
"""Sample module."""
|
||||||
|
|
||||||
|
def my_func():
|
||||||
|
return 42
|
||||||
|
|
||||||
|
|
||||||
|
class MyClass:
|
||||||
|
attr = 1
|
||||||
|
|
||||||
|
def method_a(self):
|
||||||
|
return self.attr
|
||||||
|
''').lstrip(),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_ts_file(tmp_path):
|
||||||
|
path = tmp_path / "sample.ts"
|
||||||
|
path.write_text(
|
||||||
|
textwrap.dedent('''
|
||||||
|
export function renderComponent(): JSX.Element {
|
||||||
|
return <div/>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export class BaseService {
|
||||||
|
abstract run(): void;
|
||||||
|
}
|
||||||
|
''').lstrip(),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
class TestCharacterizationFullFile:
|
||||||
|
"""symbol=None returns the whole file (matches _FakeTool baseline)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_file_returned_when_symbol_none(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_py_file))
|
||||||
|
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["path"] == str(sample_py_file)
|
||||||
|
assert result["start_line"] == 1
|
||||||
|
assert result["end_line"] == result["total_lines"]
|
||||||
|
assert "def my_func" in result["content"]
|
||||||
|
assert "class MyClass" in result["content"]
|
||||||
|
assert result["content"].startswith('"""Sample module."""')
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_file_includes_all_lines(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_py_file))
|
||||||
|
assert result["total_lines"] >= 8
|
||||||
|
assert result["content"].count("\n") >= result["total_lines"] - 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Symbol slicing — happy paths
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSymbolSlicing:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_python_function(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_py_file), symbol="my_func")
|
||||||
|
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["symbol"] == "my_func"
|
||||||
|
assert result["symbol_kind"] == "function"
|
||||||
|
assert "def my_func" in result["content"]
|
||||||
|
assert "return 42" in result["content"]
|
||||||
|
# Should NOT include the class below.
|
||||||
|
assert "class MyClass" not in result["content"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_python_class_includes_method(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_py_file), symbol="MyClass")
|
||||||
|
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["symbol"] == "MyClass"
|
||||||
|
assert result["symbol_kind"] == "class"
|
||||||
|
assert "class MyClass" in result["content"]
|
||||||
|
assert "def method_a" in result["content"] # method included
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_python_method_directly(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_py_file), symbol="method_a")
|
||||||
|
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["symbol"] == "method_a"
|
||||||
|
assert result["symbol_kind"] == "method"
|
||||||
|
assert "def method_a" in result["content"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_typescript_function(self, sample_ts_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_ts_file), symbol="renderComponent")
|
||||||
|
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["symbol"] == "renderComponent"
|
||||||
|
assert "renderComponent" in result["content"]
|
||||||
|
# Should not include the class below.
|
||||||
|
assert "BaseService" not in result["content"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_typescript_class(self, sample_ts_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_ts_file), symbol="BaseService")
|
||||||
|
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["symbol"] == "BaseService"
|
||||||
|
assert result["symbol_kind"] == "class"
|
||||||
|
assert "BaseService" in result["content"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Symbol slicing — edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSymbolSlicingEdgeCases:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_symbol_not_found_lists_available(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_py_file), symbol="nonexistent")
|
||||||
|
|
||||||
|
assert result["is_error"] is False # soft error, not hard
|
||||||
|
assert result["content"] == ""
|
||||||
|
assert result["symbol"] == "nonexistent"
|
||||||
|
available = result["available_symbols"]
|
||||||
|
assert "my_func" in available
|
||||||
|
assert "MyClass" in available
|
||||||
|
assert "method_a" in available
|
||||||
|
assert "nonexistent" not in result["content"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unsupported_extension_returns_full_with_note(self, tmp_path):
|
||||||
|
path = tmp_path / "notes.md"
|
||||||
|
path.write_text("# Hello\nworld\n", encoding="utf-8")
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(path), symbol="anything")
|
||||||
|
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["content"] == "# Hello\nworld\n"
|
||||||
|
assert "symbol extraction not supported" in result["note"]
|
||||||
|
assert ".md" in result["note"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_file(self, tmp_path):
|
||||||
|
path = tmp_path / "empty.py"
|
||||||
|
path.write_text("", encoding="utf-8")
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(path))
|
||||||
|
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["content"] == ""
|
||||||
|
assert result["total_lines"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_with_no_symbols(self, tmp_path):
|
||||||
|
path = tmp_path / "data.py"
|
||||||
|
path.write_text("# just a comment\nPI = 3.14\n", encoding="utf-8")
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(path), symbol="PI")
|
||||||
|
|
||||||
|
# PI is not a def/class — extractor finds no symbols; soft error lists available.
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["content"] == ""
|
||||||
|
assert result["available_symbols"] == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Error paths
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadFileToolErrors:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_path_required(self):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute()
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert "path" in result["error"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_path_empty_string(self):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path="")
|
||||||
|
assert result["is_error"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_not_found(self, tmp_path):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(tmp_path / "missing.py"))
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert "not found" in result["error"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_path_is_directory(self, tmp_path):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(tmp_path))
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert "directory" in result["error"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Manual line slicing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestManualLineSlicing:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_and_end_line(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(
|
||||||
|
path=str(sample_py_file),
|
||||||
|
start_line=3,
|
||||||
|
end_line=5,
|
||||||
|
)
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["start_line"] == 3
|
||||||
|
assert result["end_line"] == 5
|
||||||
|
# Lines 3-5 of the sample file:
|
||||||
|
# line 3: "def my_func():"
|
||||||
|
# line 4: " return 42"
|
||||||
|
# line 5: "" (blank)
|
||||||
|
assert "def my_func" in result["content"]
|
||||||
|
assert "return 42" in result["content"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_line_only_extends_to_eof(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_py_file), start_line=8)
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["start_line"] == 8
|
||||||
|
assert result["end_line"] == result["total_lines"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_line_only_starts_at_one(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_py_file), end_line=2)
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["start_line"] == 1
|
||||||
|
assert result["end_line"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_start_line_zero(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(path=str(sample_py_file), start_line=0)
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert "start_line" in result["error"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_before_start(self, sample_py_file):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(
|
||||||
|
path=str(sample_py_file),
|
||||||
|
start_line=5,
|
||||||
|
end_line=3,
|
||||||
|
)
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert "end_line" in result["error"].lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manual_lines_override_symbol(self, sample_py_file):
|
||||||
|
# Per plan U1 Approach: "start_line/end_line overrides symbol".
|
||||||
|
tool = ReadFileTool()
|
||||||
|
result = await tool.execute(
|
||||||
|
path=str(sample_py_file),
|
||||||
|
symbol="my_func",
|
||||||
|
start_line=1,
|
||||||
|
end_line=1,
|
||||||
|
)
|
||||||
|
assert result["is_error"] is False
|
||||||
|
# Manual slicing won — symbol field absent.
|
||||||
|
assert "symbol" not in result or result.get("symbol") is None
|
||||||
|
assert result["start_line"] == 1
|
||||||
|
assert result["end_line"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration — tool registry discovery
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistryDiscovery:
|
||||||
|
def test_instantiable_without_args(self):
|
||||||
|
# Default constructor — matches the convention used by ToolRegistry
|
||||||
|
# to instantiate tools by class.
|
||||||
|
tool = ReadFileTool()
|
||||||
|
assert tool.name == "read_file"
|
||||||
|
|
||||||
|
def test_to_dict_serializable(self):
|
||||||
|
tool = ReadFileTool()
|
||||||
|
d = tool.to_dict()
|
||||||
|
assert d["name"] == "read_file"
|
||||||
|
assert "input_schema" in d
|
||||||
|
assert "output_schema" in d
|
||||||
|
assert d["tags"] == ["io", "file", "read"]
|
||||||
|
|
@ -0,0 +1,359 @@
|
||||||
|
"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor.
|
||||||
|
|
||||||
|
Covers R22 (file reading supports symbol/function granularity) and KTD1
|
||||||
|
(Python ast + language-aware regex, no tree-sitter dependency).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.symbol_extractor import (
|
||||||
|
AstSymbolExtractor,
|
||||||
|
RegexSymbolExtractor,
|
||||||
|
SymbolSpan,
|
||||||
|
extract_symbols_from_file,
|
||||||
|
get_extractor,
|
||||||
|
language_for_extension,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# language_for_extension
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLanguageForExtension:
|
||||||
|
def test_python_extensions(self):
|
||||||
|
assert language_for_extension("py") == "py"
|
||||||
|
assert language_for_extension(".py") == "py"
|
||||||
|
assert language_for_extension(".PY") == "py" # case-insensitive
|
||||||
|
|
||||||
|
def test_typescript_javascript(self):
|
||||||
|
assert language_for_extension(".ts") == "ts"
|
||||||
|
assert language_for_extension(".tsx") == "ts"
|
||||||
|
assert language_for_extension(".js") == "js"
|
||||||
|
assert language_for_extension(".jsx") == "js"
|
||||||
|
assert language_for_extension(".mjs") == "js"
|
||||||
|
assert language_for_extension(".cjs") == "js"
|
||||||
|
|
||||||
|
def test_go_rust_java(self):
|
||||||
|
assert language_for_extension(".go") == "go"
|
||||||
|
assert language_for_extension(".rs") == "rs"
|
||||||
|
assert language_for_extension(".java") == "java"
|
||||||
|
|
||||||
|
def test_unsupported_returns_empty(self):
|
||||||
|
assert language_for_extension(".md") == ""
|
||||||
|
assert language_for_extension(".txt") == ""
|
||||||
|
assert language_for_extension("") == ""
|
||||||
|
assert language_for_extension(".unknown") == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AstSymbolExtractor — Python
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAstSymbolExtractor:
|
||||||
|
extractor = AstSymbolExtractor()
|
||||||
|
|
||||||
|
def test_unsupported_language_returns_empty(self):
|
||||||
|
assert self.extractor.extract_symbols("function foo() {}", "ts") == []
|
||||||
|
|
||||||
|
def test_syntax_error_returns_empty(self):
|
||||||
|
# Never raises — callers rely on this for fallback routing.
|
||||||
|
result = self.extractor.extract_symbols("def broken(:\n pass", "py")
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_top_level_function(self):
|
||||||
|
content = "def my_func():\n return 42\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "py")
|
||||||
|
assert len(spans) == 1
|
||||||
|
span = spans[0]
|
||||||
|
assert span.name == "my_func"
|
||||||
|
assert span.kind == "function"
|
||||||
|
assert span.start_line == 1
|
||||||
|
assert span.end_line == 2
|
||||||
|
|
||||||
|
def test_async_function(self):
|
||||||
|
content = "async def fetch():\n return 1\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "py")
|
||||||
|
assert len(spans) == 1
|
||||||
|
assert spans[0].name == "fetch"
|
||||||
|
assert spans[0].kind == "function"
|
||||||
|
|
||||||
|
def test_top_level_class(self):
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
class MyClass:
|
||||||
|
"""docstring"""
|
||||||
|
|
||||||
|
def method_a(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
async def method_b(self):
|
||||||
|
return 2
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "py")
|
||||||
|
names = [s.name for s in spans]
|
||||||
|
assert "MyClass" in names
|
||||||
|
assert "method_a" in names
|
||||||
|
assert "method_b" in names
|
||||||
|
|
||||||
|
cls = next(s for s in spans if s.name == "MyClass")
|
||||||
|
assert cls.kind == "class"
|
||||||
|
assert cls.start_line == 1
|
||||||
|
# Class body extends through the last method's end_lineno.
|
||||||
|
assert cls.end_line >= 7
|
||||||
|
|
||||||
|
def test_methods_classified_as_methods(self):
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
class Foo:
|
||||||
|
def bar(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def top_level():
|
||||||
|
pass
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "py")
|
||||||
|
by_name = {s.name: s for s in spans}
|
||||||
|
assert by_name["bar"].kind == "method"
|
||||||
|
assert by_name["top_level"].kind == "function"
|
||||||
|
|
||||||
|
def test_decorated_function(self):
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
@staticmethod
|
||||||
|
def helper():
|
||||||
|
return "hi"
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "py")
|
||||||
|
# Note: extractor uses node.lineno (def line) — decorators above are
|
||||||
|
# excluded by design (matches user-visible symbol start at `def`).
|
||||||
|
assert any(s.name == "helper" for s in spans)
|
||||||
|
span = next(s for s in spans if s.name == "helper")
|
||||||
|
assert span.start_line == 2 # the `def` line
|
||||||
|
|
||||||
|
def test_nested_function(self):
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
def outer():
|
||||||
|
def inner():
|
||||||
|
return 1
|
||||||
|
return inner()
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "py")
|
||||||
|
names = {s.name for s in spans}
|
||||||
|
assert "outer" in names
|
||||||
|
assert "inner" in names
|
||||||
|
|
||||||
|
def test_empty_file(self):
|
||||||
|
assert self.extractor.extract_symbols("", "py") == []
|
||||||
|
|
||||||
|
def test_no_symbols_in_docstring_only_file(self):
|
||||||
|
content = '"""just a docstring"""\n'
|
||||||
|
assert self.extractor.extract_symbols(content, "py") == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RegexSymbolExtractor — TS/JS/Go/Rust/Java
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegexSymbolExtractor:
|
||||||
|
extractor = RegexSymbolExtractor()
|
||||||
|
|
||||||
|
def test_unsupported_language_returns_empty(self):
|
||||||
|
assert self.extractor.extract_symbols("def foo(): pass", "py") == []
|
||||||
|
assert self.extractor.extract_symbols("function foo() {}", "rb") == []
|
||||||
|
|
||||||
|
def test_typescript_function_declaration(self):
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
export function renderComponent(props: Props): JSX.Element {
|
||||||
|
return <div/>;
|
||||||
|
}
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "ts")
|
||||||
|
assert any(s.name == "renderComponent" and s.kind == "function" for s in spans)
|
||||||
|
|
||||||
|
def test_typescript_async_function(self):
|
||||||
|
content = "async function fetchData() {\n return await fetch();\n}\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "ts")
|
||||||
|
assert any(s.name == "fetchData" for s in spans)
|
||||||
|
|
||||||
|
def test_typescript_arrow_function_const(self):
|
||||||
|
content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "ts")
|
||||||
|
assert any(s.name == "handleClick" for s in spans)
|
||||||
|
|
||||||
|
def test_typescript_class(self):
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
export abstract class BaseService {
|
||||||
|
abstract run(): void;
|
||||||
|
}
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "ts")
|
||||||
|
assert any(s.name == "BaseService" and s.kind == "class" for s in spans)
|
||||||
|
|
||||||
|
def test_javascript_function(self):
|
||||||
|
content = "function foo() {\n return 1;\n}\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "js")
|
||||||
|
assert any(s.name == "foo" for s in spans)
|
||||||
|
|
||||||
|
def test_javascript_arrow_const(self):
|
||||||
|
content = "const bar = () => 42;\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "js")
|
||||||
|
assert any(s.name == "bar" for s in spans)
|
||||||
|
|
||||||
|
def test_go_function(self):
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
package main
|
||||||
|
|
||||||
|
func HandleRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Start() {
|
||||||
|
// method
|
||||||
|
}
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "go")
|
||||||
|
names = {s.name for s in spans}
|
||||||
|
assert "HandleRequest" in names
|
||||||
|
assert "Start" in names # method receiver pattern
|
||||||
|
|
||||||
|
def test_go_struct(self):
|
||||||
|
content = "type Server struct {\n Addr string\n}\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "go")
|
||||||
|
assert any(s.name == "Server" and s.kind == "struct" for s in spans)
|
||||||
|
|
||||||
|
def test_rust_function(self):
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
pub fn process(input: &str) -> Result<usize, Error> {
|
||||||
|
Ok(input.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch() -> Bytes {
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "rs")
|
||||||
|
names = {s.name for s in spans}
|
||||||
|
assert "process" in names
|
||||||
|
assert "fetch" in names
|
||||||
|
|
||||||
|
def test_rust_struct(self):
|
||||||
|
content = "pub struct Config {\n pub path: String,\n}\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "rs")
|
||||||
|
assert any(s.name == "Config" and s.kind == "struct" for s in spans)
|
||||||
|
|
||||||
|
def test_rust_impl(self):
|
||||||
|
content = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "rs")
|
||||||
|
assert any(s.name == "Config" and s.kind == "impl" for s in spans)
|
||||||
|
|
||||||
|
def test_java_class(self):
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
package com.example;
|
||||||
|
|
||||||
|
public class UserService {
|
||||||
|
public User findById(long id) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "java")
|
||||||
|
assert any(s.name == "UserService" and s.kind == "class" for s in spans)
|
||||||
|
|
||||||
|
def test_java_method(self):
|
||||||
|
content = "public User findById(long id) {\n return null;\n}\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "java")
|
||||||
|
assert any(s.name == "findById" and s.kind == "function" for s in spans)
|
||||||
|
|
||||||
|
def test_end_line_extends_to_next_symbol(self):
|
||||||
|
# First symbol's end_line is the line before the second symbol starts.
|
||||||
|
content = textwrap.dedent('''
|
||||||
|
function first() {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
function second() {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
''').strip()
|
||||||
|
spans = self.extractor.extract_symbols(content, "js")
|
||||||
|
spans.sort(key=lambda s: s.start_line)
|
||||||
|
first = spans[0]
|
||||||
|
second = spans[1]
|
||||||
|
assert first.name == "first"
|
||||||
|
assert second.name == "second"
|
||||||
|
assert first.end_line == second.start_line - 1
|
||||||
|
|
||||||
|
def test_last_symbol_end_line_is_eof(self):
|
||||||
|
content = "function only() {\n return 1;\n}\n"
|
||||||
|
spans = self.extractor.extract_symbols(content, "js")
|
||||||
|
assert len(spans) == 1
|
||||||
|
assert spans[0].end_line == len(content.splitlines())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_extractor + integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetExtractor:
|
||||||
|
def test_python_returns_ast_extractor(self):
|
||||||
|
ext = get_extractor("py")
|
||||||
|
assert ext is not None
|
||||||
|
assert isinstance(ext, AstSymbolExtractor)
|
||||||
|
|
||||||
|
def test_typescript_returns_regex_extractor(self):
|
||||||
|
ext = get_extractor("ts")
|
||||||
|
assert ext is not None
|
||||||
|
assert isinstance(ext, RegexSymbolExtractor)
|
||||||
|
|
||||||
|
def test_unsupported_returns_none(self):
|
||||||
|
assert get_extractor("md") is None
|
||||||
|
assert get_extractor("") is None
|
||||||
|
assert get_extractor("unknown") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractSymbolsFromFile:
|
||||||
|
def test_python_file(self, tmp_path):
|
||||||
|
path = tmp_path / "module.py"
|
||||||
|
path.write_text("def hello():\n return 'world'\n", encoding="utf-8")
|
||||||
|
spans, lang = extract_symbols_from_file(path)
|
||||||
|
assert lang == "py"
|
||||||
|
assert any(s.name == "hello" for s in spans)
|
||||||
|
|
||||||
|
def test_unsupported_extension(self, tmp_path):
|
||||||
|
path = tmp_path / "notes.md"
|
||||||
|
path.write_text("# Hello\n", encoding="utf-8")
|
||||||
|
spans, lang = extract_symbols_from_file(path)
|
||||||
|
assert lang == ""
|
||||||
|
assert spans == []
|
||||||
|
|
||||||
|
def test_missing_file_returns_empty(self, tmp_path):
|
||||||
|
path = tmp_path / "nonexistent.py"
|
||||||
|
spans, lang = extract_symbols_from_file(path)
|
||||||
|
# lang is detected from extension even if read fails.
|
||||||
|
assert lang == "py"
|
||||||
|
assert spans == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SymbolSpan dataclass
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSymbolSpan:
|
||||||
|
def test_frozen_dataclass(self):
|
||||||
|
span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
|
||||||
|
assert span.name == "foo"
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
span.name = "bar" # type: ignore[misc] — frozen
|
||||||
|
|
||||||
|
def test_equality(self):
|
||||||
|
a = SymbolSpan("foo", "function", 1, 3)
|
||||||
|
b = SymbolSpan("foo", "function", 1, 3)
|
||||||
|
assert a == b
|
||||||
|
assert hash(a) == hash(b)
|
||||||
Loading…
Reference in New Issue