Compare commits

...

8 Commits

Author SHA1 Message Date
chiguyong f50d3485ea fix(review): Wave 3 code review fixes
Test / backend-test (pull_request) Has been cancelled Details
Test / frontend-unit (pull_request) Has been cancelled Details
Test / api-e2e (pull_request) Has been cancelled Details
Test / frontend-e2e (pull_request) Has been cancelled Details
P1: bash/shell tool name mismatch. PhasePolicy whitelist used "bash" but
ShellTool registers as "shell". The bash_command_filter was dead code
(never matched the real tool name). Fixed in phase.py whitelist,
react.py filter check, agentkit.yaml config, and all tests.

P1: AdvancePhaseTool missing import in tools/__init__.py. Was in
__all__ but never imported. Added the import.

P2: chat.py phase policy error message echoed verbatim to WS client.
Truncated to 200 chars to match nearby error paths and avoid leaking
config internals.

P2: policy_from_config rebuilt PhasePolicy 3x via full-field copy.
Replaced with dataclasses.replace() so new PhasePolicy fields are not
silently dropped in future reconstructions.

ce-code-review (mode:agent) step of LFG pipeline.
2026-06-30 09:13:07 +08:00
chiguyong 5a0554b27f refactor(advance_phase): simplify previous_phase capture
Remove over-engineered _previous_value static method that did index
math on a hardcoded phase list. Instead, capture the previous phase
before the transition — clearer intent, fewer lines, same behavior.

ce-simplify-code step of LFG pipeline.
2026-06-30 09:13:06 +08:00
chiguyong da4eef1349 feat(U4): G6 PLAN_EXEC wiring at chat WebSocket path
- PLAN_EXEC branch builds PhasePolicy from ServerConfig.plan_exec
- Empty config → default_policy(); enabled=False → falls back to REACT
- Bad config → error event sent, returns early (no engine constructed)
- ReActEngine created with phase_policy; AdvancePhaseTool registered
- phase_changed events emitted on phase transitions (PLAN_EXEC only)
- REST send_message with execution_mode=plan_exec → HTTP 501 (KTD4)
- REWOO/REFLEXION/TEAM_COLLAB still fall back to REACT (no regression)
- 9 unit tests covering REST 501, characterization, happy path, edge cases, error path, phase_changed events
2026-06-30 09:13:06 +08:00
chiguyong 6efd5957f6 feat(U3): G6 AdvancePhaseTool + ReActEngine phase enforcement
- AdvancePhaseTool calls engine.advance_phase(), returns new phase or error
- ReActEngine.__init__ accepts phase_policy param (None = no enforcement, backward compat)
- _current_phase + _steps_in_phase fields track state machine
- advance_phase() transitions PLANNING → BUILDING → VERIFICATION → DELIVERY
- _check_phase_permission() returns structured error dict if tool blocked
- _execute_tool checks phase before dispatch (advance_phase name bypasses)
- Auto-advance safety net via _maybe_auto_advance() + auto_advance_after_steps
- Phase reset in reset() method
- 27 unit tests covering characterization, permission, transitions, auto-advance, tool integration
2026-06-30 09:13:06 +08:00
chiguyong be4ac797b2 feat(U2): G6 PhaseState + PhasePolicy + ServerConfig.plan_exec
- PhaseState enum (PLANNING/BUILDING/VERIFICATION/DELIVERY) with next_of/from_string
- PhasePolicy dataclass with whitelist + bash_command_filter + auto_advance_after_steps
- default_policy() factory — KTD5 whitelist matching R24 (Planning: search/read_file;
  Building: write_file; Delivery: wildcard)
- bash_command_filter blocks rm/mv/cp/>/>> in PLANNING/VERIFICATION phases
- policy_from_config() parses plan_exec YAML section (R26) with override merge
- ServerConfig.plan_exec field + from_dict parsing (extends Wave 1/2 pattern)
- agentkit.yaml gains commented plan_exec section (opt-in)
- 37 unit tests covering PhaseState, default_policy, is_tool_allowed,
  bash filter, config parsing, and ServerConfig integration
2026-06-30 09:13:06 +08:00
chiguyong 58ef1719cb feat(U1): G5 SymbolExtractor + ReadFileTool with symbol slicing
- SymbolExtractor protocol + SymbolSpan dataclass
- AstSymbolExtractor for Python (stdlib ast, no tree-sitter dep — KTD1)
- RegexSymbolExtractor for TS/JS/Go/Rust/Java (language-aware regex)
- ReadFileTool with path/symbol/start_line/end_line params
- symbol=None returns full file (characterization baseline matching _FakeTool)
- symbol='foo' returns first matching symbol's line range
- symbol not found returns available_symbols list (soft error)
- Unsupported extension returns full file with note
- Manual start_line/end_line overrides symbol
- 66 unit tests covering R22/R23 + characterization + edge cases
2026-06-30 09:13:06 +08:00
chiguyong e3f69f963c docs(plan): Wave 3 strategic coupling plan (G5/G6) 2026-06-30 09:13:06 +08:00
Fischer a2dcde01b8 feat(agent): Wave 2 medium coupling (G4/G7/G9) (#5)
Deploy to Production / deploy (push) Waiting to run Details
Test / backend-test (push) Waiting to run Details
Test / frontend-unit (push) Waiting to run Details
Test / api-e2e (push) Waiting to run Details
Test / frontend-e2e (push) Waiting to run Details
2026-06-30 09:09:33 +08:00
28 changed files with 6116 additions and 42 deletions

View File

@ -28,6 +28,42 @@ llm:
coding: bailian-coding/qwen3-coder-plus
chat: deepseek/deepseek-chat
reasoning: deepseek/deepseek-reasoner
# G4/U1: Auxiliary model for cost-sensitive tasks (summarization).
# When set, ContextCompressor tries this alias first, falling back to
# the main model on failure or empty content. Commented to preserve
# default behavior — uncomment to enable.
# auxiliary_model: fast
# G9/U4: Rollback configuration. Drives RollbackExecutor subprocess timeout
# for PlanPhase.validation_command / PlanPhase.rollback_command. Per-phase
# opt-in (KTD6) — when PlanPhase.rollback_command is unset, no rollback runs.
# Canonical rollback pattern: `git checkout <specific_files>` (file-scoped,
# not `git checkout .` which would wipe unrelated changes).
rollback:
default_timeout: 30.0
# G7/U3: Three-tier fallback chain at chat REST send_message.
# main → Recovery (ReflexionEngine retry) → Emergency (rule-based classifier).
# Wired only at chat REST path (KTD5); CLI / ReWOO / Reflexion internal
# ReAct calls bypass the chain (no recursive loop).
fallback_chain:
enabled: true
recovery:
enabled: true
max_retries: 1 # ReflexionEngine max_reflections override
emergency:
enabled: true
# G6/U2: PLAN_EXEC phase policy — SOLO four-stage state machine.
# When `enabled: true`, chat WebSocket PLAN_EXEC requests build a PhasePolicy
# (Planning → Building → Verification → Delivery) and enforce per-step tool
# whitelists (R24). Transitions are LLM-driven via AdvancePhaseTool; set
# `auto_advance_after_steps` to auto-advance as a safety net (KTD6).
# Commented to preserve default behavior — uncomment to enable.
# plan_exec:
# enabled: true
# auto_advance_after_steps: 5 # optional, default = manual (LLM calls advance_phase)
# start_phase: planning # optional, default = planning
# whitelist_override: # optional, merges with default (override wins)
# planning: [search, read_file, shell]
# building: [write_file, shell, read_file]
session: {backend: memory}
bus: {backend: memory}
task_store: {backend: memory}

View File

@ -0,0 +1,481 @@
---
date: 2026-06-29
type: feat
origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md
---
# Wave 2 — Auxiliary LLM Routing, Three-Tier Fallback, Atomic Subtask Rollback
## Summary
Wave 2 of the advanced-agent gap optimization brainstorm. Three medium-coupling gaps: G4 routes summary calls through a cheaper auxiliary LLM (falling back to main on failure), G7 introduces a three-tier fallback chain (main → Recovery via existing `ReflexionEngine` → Emergency with structured errors), G9 binds atomic subtask rollback to `PlanPhase` (opt-in via `rollback_command`, coordinated with the existing U7 `PipelineCheckpoint`).
---
## Problem Frame
Wave 1 shipped self-contained quick wins (G1/G2/G3/G8). Wave 2 addresses the medium-coupling gaps that touch multiple layers and resolve two deferred-to-planning decisions surfaced in the brainstorm: G7 Emergency layer rule template shape, and G9 `rollback_command` default behavior. The constraints these gaps must respect come from `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — new fields must preserve existing contracts, dynamic plan mutations must persist immediately, and `PipelineCheckpoint` is in-memory dict + Redis fallback (not a DB row lock).
---
## Requirements
Carried from `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` § Wave 2 (R13-R21):
- R13. `ContextCompressor` summary task routes to auxiliary model (cheap, e.g. Gemini Flash / Doubao lite), not main model.
- R14. `auxiliary_model` is configurable, separated from main `model`.
- R15. Summary quality does not degrade: auxiliary failure falls back to main model (not to simple_summary).
- R16. Main agent failure triggers Recovery layer (reuse `ReflexionEngine` Evaluate→Reflect→Retry).
- R17. Recovery failure triggers Emergency layer (rule-based fallback, structured error + suggestion).
- R18. Fallback chain is configurable (per-layer max retries, enable/disable Recovery/Emergency).
- R19. `PlanPhase` (`experts/plan.py:147`) gains optional `validation_command` and `rollback_command` fields.
- R20. Phase failure auto-executes `rollback_command` when configured (default `git checkout` pattern); coordinated with U7 checkpoint.
- R21. `checkpoint.save` runs only after rollback validation passes (avoid persisting failed state).
Cross-cutting: R26 (configuration via `agentkit.yaml`), R27 (each gap ships a minimal self-check test, per ponytail rule).
---
## Key Technical Decisions
KTD1. **Auxiliary routing is compressor-scoped, not gateway-level.** `LLMGateway` already has a per-model fallback chain (`gateway.py:170-227`), but it semantics are "main fails → try backup model", not "route by task type". G4 adds an `auxiliary_model` parameter to `ContextCompressor.__init__`; `_summarize` tries auxiliary first, falls back to main on failure. This avoids touching the gateway's existing fallback semantics and keeps the change to one file plus config wiring.
KTD2. **Emergency layer adds a parallel `error_struct` field, not replacing `error_message`.** `TaskResult.error_message: str | None` is serialized to API responses (`protocol.py:132-145`) — changing its type breaks frontend contracts. New `error_struct: dict | None` field carries `{error_code, message, suggestions, retryable}`. `error_message` remains the human-readable string (now derived from `error_struct.message` when present). Backward compatibility preserved.
KTD3. **Emergency rule classification by exception type, not by status field.** `TaskTimeoutError``timeout`, `LoopDetectedError``loop_detected`, `LLMProviderError``llm_failure`, `TaskCancelledError``cancelled` (not Emergency-eligible, propagates), generic `Exception``internal_error`. Each rule maps to a fixed message + suggestion list (no LLM-driven suggestions — pure rule-based, per brainstorm).
KTD4. **Recovery layer reuses `ReflexionEngine` with main model, not auxiliary.** `ReflexionEngine.execute()` already supports `evaluate_model`/`reflect_model` parameters (`reflexion.py:94-111`), defaulting to the main model. Recovery triggers on main agent failure; using the main model for evaluate/reflect maximizes diagnostic reliability (the same model that failed is best positioned to reflect on its own failure). Auxiliary model is reserved for G4's cost-sensitive compression path.
KTD5. **Three-tier chain wires at `chat.py:613` (WebSocket main path), not at `ReActEngine.execute()`.** ReAct has 5 call sites (`chat.py:613`, `rewoo.py:1204`, `cli/chat.py:338`, `reflexion.py:216`, `config_driven.py:695`). Wrapping at ReAct would force all 5 sites through Recovery/Emergency, including ReflexionEngine itself (creating a recursive loop). Wrapping at `chat.py:613` covers the primary user-facing path; CLI and ReWOO are out of scope (deferred to follow-up).
KTD6. **Rollback is opt-in via `rollback_command` field; no implicit default.** Brainstorm KTD5 mentioned "default `git checkout`" but auto-executing `git checkout` on every phase failure changes existing behavior (Finding 1: "新字段默认值须保持既有契约"). Resolution: `rollback_command: str | None = None`. When unset, no rollback executes (preserves existing contract). When set, the configured command runs after validation failure, before checkpoint save. `git checkout <files>` is the canonical pattern documented in YAML examples.
KTD7. **Rollback executes via `RollbackExecutor`, not via `ShellTool`.** `ShellTool._is_dangerous` would trigger `confirm_callback` for `git checkout` (not in `_SAFE_COMMAND_PREFIXES`, not in `_DANGEROUS_BINARY_FLAGS`). Orchestrator-internal execution bypasses ShellTool — uses `asyncio.create_subprocess_shell` with `cwd`/`timeout`/`proc.kill()` pattern from `verification_loop.py:67-103`. Audit logging added (not a whitelist concern; terminal_whitelist only governs `/api/v1/terminal/server` route per `terminal_server.py:227`).
KTD8. **G9 checkpoint ordering: validation → rollback → checkpoint save.** Currently `orchestrator.py:265` saves checkpoint unconditionally after phase finalizes (success or failure). R21 requires checkpoint save only after rollback validation passes. New ordering: phase fails → mark FAILED → mark dependents FAILED → execute `validation_command` (if set) → if validation fails, execute `rollback_command` (if set) → save checkpoint (only if rollback validation passed or no rollback configured).
---
## High-Level Technical Design
### Three-tier fallback chain (G7)
```mermaid
stateDiagram-v2
[*] --> Main: chat request
Main --> MainSuccess: success
Main --> Recovery: failure (exception or empty_fallback)
Recovery --> RecoverySuccess: ReflexionEngine retry succeeds
Recovery --> Emergency: max_reflections exhausted
Emergency --> [*]: emit error_struct, terminal
MainSuccess --> [*]: normal response
RecoverySuccess --> [*]: recovered response
```
### G9 phase failure + rollback sequence
```mermaid
sequenceDiagram
participant O as TeamOrchestrator
participant P as PlanPhase
participant RE as RollbackExecutor
participant CP as PipelineCheckpoint
O->>P: execute phase
P-->>O: failure (exception)
O->>O: mark FAILED + dependents FAILED
alt validation_command set
O->>RE: run validation_command
RE-->>O: ValidationResult
alt validation fails AND rollback_command set
O->>RE: run rollback_command
RE-->>O: RollbackResult
alt rollback validation passes
O->>CP: save checkpoint
else rollback validation fails
O->>O: log rollback_failed, skip checkpoint
end
else no rollback_command
O->>CP: save checkpoint (no rollback)
end
else no validation_command
O->>CP: save checkpoint (no validation)
end
```
### Configuration shape (YAML)
```yaml
# G4: Auxiliary LLM for compression
llm:
auxiliary_model: fast # alias resolved via existing model_aliases mechanism
# G7: Three-tier fallback chain
fallback_chain:
enabled: true
recovery:
enabled: true
max_retries: 1 # ReflexionEngine max_reflections override
emergency:
enabled: true
# G9: Rollback configuration (per-phase opt-in via PlanPhase.rollback_command)
rollback:
default_timeout: 30.0 # RollbackExecutor subprocess timeout
```
---
## Implementation Units
### U1. G4 — Auxiliary LLM Routing in ContextCompressor
**Goal:** Route `_summarize` LLM calls through `auxiliary_model` when configured; fall back to main model on auxiliary failure (not to `_simple_summary`).
**Requirements:** R13, R14, R15, R26, R27
**Dependencies:** none (self-contained)
**Files:**
- Modify: `src/agentkit/core/compressor.py` (add `auxiliary_model` param + routing in `_summarize`)
- Modify: `src/agentkit/llm/config.py` (add `auxiliary_model: str | None` field to `LLMConfig`)
- Modify: `src/agentkit/server/config.py` (`_build_llm_config` reads `auxiliary_model` from llm section)
- Modify: `agentkit.yaml` (document `llm.auxiliary_model: fast` in llm section)
- Create: `tests/unit/test_compressor_auxiliary.py`
**Approach:** `ContextCompressor.__init__` gains `auxiliary_model: str | None = None` param (after existing `model` param). `_summarize` (compressor.py:123-158) restructuring:
1. If `auxiliary_model` is set and differs from `model`, try `auxiliary_model` first via `self._llm_gateway.chat(model=self._auxiliary_model, ...)`.
2. **Truthiness check (Finding 4 anti-pattern):** treat empty `response.content` (None or whitespace-only) as failure, not success. On failure, log and fall through to main model.
3. If auxiliary succeeds with non-empty content, return it.
4. If auxiliary fails (exception OR empty content), retry with main `self._model`.
5. Existing `except Exception → _simple_summary` block remains as the final degradation tier.
`LLMConfig.auxiliary_model` reads from `data.get("auxiliary_model")` in `_build_llm_config`. Agentkit.yaml already declares `fast: bailian-coding/qwen-turbo` alias — this is the canonical auxiliary target.
**Execution note:** characterization-first. Capture current `_summarize` behavior (single model, exception → simple_summary) with `auxiliary_model=None` tests, then add `auxiliary_model="fast"` tests for new routing behavior.
**Patterns to follow:**
- `LLMGateway._get_models_to_try` (`gateway.py:170-227`) — existing fallback chain pattern
- Wave 1's `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections, `config.py:239-273`)
**Test scenarios:**
- Happy path: `auxiliary_model="fast"` set, auxiliary returns non-empty content → result is auxiliary content, main model not called
- Empty content fallback: auxiliary returns `content=""` (or `None`) → main model called, main content returned (covers Finding 4 anti-pattern)
- Auxiliary exception: auxiliary raises `LLMProviderError` → main model called, main content returned
- Both fail: auxiliary raises, main raises → `_simple_summary` returned (existing degradation preserved)
- Characterization: `auxiliary_model=None` → behavior matches current code (single model call)
- Config wiring: `LLMConfig.from_dict` reads `auxiliary_model` field from dict; `ServerConfig._build_llm_config` passes it through
- Audit: auxiliary call uses `agent_name="compressor"`, `task_type="summarization"` (preserved for usage tracking)
**Verification:** All test scenarios pass; existing `tests/unit/test_compressor*.py` (if any) still pass with `auxiliary_model=None` default; ruff clean.
---
### U2. G7 — Emergency Layer Rule Template + TaskResult Extension
**Goal:** Add rule-based Emergency classifier with structured error output. No wiring yet — just the infrastructure (classifier + data structure + `fallback.py` extension).
**Requirements:** R17, R18, R26, R27
**Dependencies:** none (foundation for U3)
**Files:**
- Modify: `src/agentkit/core/fallback.py` (add `EmergencyRules` class + `EmergencyError` dataclass, preserve existing 3 constants)
- Modify: `src/agentkit/core/protocol.py` (add `error_struct: dict | None = None` field to `TaskResult`, update `to_dict`)
- Create: `tests/unit/test_emergency_rules.py`
**Approach:**
`EmergencyError` dataclass (in `fallback.py`):
```
@dataclass
class EmergencyError:
error_code: str # "timeout"|"loop_detected"|"llm_failure"|"internal_error"
message: str # human-readable Chinese message (mirrors EMPTY_LLM_RESPONSE style)
suggestions: list[str] # actionable user-facing suggestions
retryable: bool # whether user retry might succeed
original_error: str # str(exc) for traceability
def to_dict(self) -> dict: ...
def to_error_message(self) -> str: ... # formatted "message\n建议1) ... 2) ..."
```
`EmergencyRules` class (rule-based classifier, no LLM):
```
class EmergencyRules:
@staticmethod
def classify(exc: Exception, config: dict | None = None) -> EmergencyError:
# Match by exception type (TaskTimeoutError, LoopDetectedError, LLMProviderError, etc.)
# config allows per-rule customization (suggestion overrides, retryable overrides)
```
Rule mapping (initial set, expandable via config):
- `TaskTimeoutError``error_code="timeout"`, `retryable=True`, suggestions: ["稍后重试", "简化任务范围"]
- `LoopDetectedError``error_code="loop_detected"`, `retryable=True`, suggestions: ["拆分任务", "检查工具参数"]
- `LLMProviderError``error_code="llm_failure"`, `retryable=True`, suggestions: ["稍后重试", "切换模型"]
- `TaskCancelledError` → not classified (propagates as-is, Emergency not triggered)
- Generic `Exception``error_code="internal_error"`, `retryable=False`, suggestions: ["联系管理员"]
`TaskResult` extension:
```
@dataclass
class TaskResult:
# ... existing fields ...
error_message: str | None # unchanged
error_struct: dict | None = None # NEW: serialized EmergencyError.to_dict()
```
`to_dict` includes `error_struct` when set; `error_message` continues to hold the human-readable string.
**Patterns to follow:**
- Existing `EMPTY_LLM_RESPONSE` / `MAX_STEPS_REACHED` style in `fallback.py` (Chinese, "建议:..." format)
- `VerificationResult.errors: list[str]` field pattern from `verification_loop.py:18-24`
- `TaskResult.to_dict` pattern at `protocol.py:132-145`
**Test scenarios:**
- Happy path: `classify(TaskTimeoutError(...))` returns `EmergencyError(error_code="timeout", retryable=True)`
- Each exception type maps to correct `error_code`
- `TaskCancelledError` is NOT classified (caller must check before invoking `classify`)
- `to_dict()` produces all 5 fields
- `to_error_message()` formats suggestions as "建议1) ... 2) ..."
- `EmergencyRules.classify(Exception("unknown"))``error_code="internal_error"`, `retryable=False`
- Config override: custom suggestion list for `timeout` rule via config dict
- `TaskResult` with `error_struct` set: `to_dict()` includes both `error_message` and `error_struct`
- `TaskResult` with `error_struct=None` (default): `to_dict()` matches current behavior (backward compat)
**Verification:** All scenarios pass; `fallback.py` existing 3 constants unchanged; `TaskResult.to_dict` for tasks without `error_struct` matches pre-change output byte-for-byte.
---
### U3. G7 — Three-Tier Fallback Chain Wiring
**Goal:** Wire main → Recovery (ReflexionEngine) → Emergency (EmergencyRules) at `chat.py:613`. Composes U2's infrastructure with existing `ReflexionEngine`.
**Requirements:** R16, R18, R26
**Dependencies:** U2 (EmergencyRules + TaskResult.error_struct)
**Files:**
- Modify: `src/agentkit/server/routes/chat.py` (wrap main agent call at L613 with three-tier chain)
- Modify: `src/agentkit/server/config.py` (add `fallback_chain` section to `from_dict`)
- Modify: `agentkit.yaml` (document `fallback_chain:` section)
- Create: `tests/unit/test_fallback_chain.py`
**Approach:**
New helper module or inline function in `chat.py`:
```
async def execute_with_fallback_chain(
agent: ConfigDrivenAgent,
task: Task,
config: dict, # fallback_chain config section
llm_gateway: LLMGateway,
) -> TaskResult:
# Tier 1: Main
try:
result = await agent.execute(task)
if result.status == "success" or result.status == "completed":
return result
# Treat non-success as soft failure → trigger Recovery
raise AgentExecutionError(result.error_message or "main agent did not succeed")
except (TaskTimeoutError, LoopDetectedError, LLMProviderError, AgentExecutionError) as exc:
if not config.get("recovery", {}).get("enabled", True):
return _to_emergency(exc, task)
# Tier 2: Recovery (ReflexionEngine)
try:
reflexion_engine = ReflexionEngine(
llm_gateway=llm_gateway,
max_reflections=config.get("recovery", {}).get("max_retries", 1),
# ... other params from agent's existing reflexion config
)
recovery_result = await reflexion_engine.execute(
messages=task.messages, # rebuild from task
tools=agent.get_tools(),
model=task.model,
agent_name=agent.name,
task_id=task.task_id,
)
if recovery_result.status == "success":
return _recovery_to_task_result(recovery_result, task)
except Exception as recovery_exc:
logger.warning(f"Recovery layer failed: {recovery_exc}")
# Tier 3: Emergency
return _to_emergency(exc, task)
```
`_to_emergency(exc, task)` constructs `EmergencyError` via `EmergencyRules.classify(exc, config)`, then returns `TaskResult(status=FAILED, error_message=emergency.to_error_message(), error_struct=emergency.to_dict(), ...)`.
Wiring at `chat.py:613`: replace direct `agent.execute(task)` call with `execute_with_fallback_chain(...)`. Recovery config from `server_config.fallback_chain` (new `ServerConfig.fallback_chain: dict` field, mirroring Wave 1 pattern).
**Recovery layer scope (KTD5):** Only `chat.py:613` is wrapped. CLI (`cli/chat.py:338`), ReWOO (`rewoo.py:1204`), ReflexionEngine's internal ReAct call, and `config_driven.py:695` are NOT wrapped (would create recursive loop or unwanted coupling). These remain on the direct-execute path. Documented in `## Scope Boundaries`.
**Patterns to follow:**
- `config_driven.py:836` — existing `ReflexionEngine` instantiation pattern (constructor params, execute call)
- Wave 1's `verification` config section in `ServerConfig.from_dict` (`config.py:240-273`)
**Test scenarios:**
- Happy path: main agent succeeds → no Recovery, no Emergency triggered; `error_struct=None`
- Main fails (timeout) → Recovery triggered → Recovery succeeds → `error_struct=None`, output from ReflexionEngine
- Main fails (timeout) → Recovery triggered → Recovery fails (max_reflections exhausted) → Emergency triggered → `error_struct` populated with `error_code="timeout"`, `retryable=True`
- Main fails (LoopDetectedError) → Emergency `error_code="loop_detected"`
- Main fails (LLMProviderError) → Emergency `error_code="llm_failure"`
- Main fails (TaskCancelledError) → propagates as-is (NOT routed to Emergency)
- Main fails (generic Exception) → Emergency `error_code="internal_error"`, `retryable=False`
- Config: `fallback_chain.recovery.enabled=false` → skip Recovery, go directly to Emergency
- Config: `fallback_chain.emergency.enabled=false` → re-raise original exception (no Emergency)
- Integration: full chain on real `ConfigDrivenAgent` with mocked LLM (mock main raises, mock ReflexionEngine succeeds)
**Verification:** All scenarios pass; existing chat WebSocket tests still pass; ruff clean.
---
### U4. G9 — PlanPhase Rollback Fields + RollbackExecutor + TeamOrchestrator Integration
**Goal:** Add `validation_command`/`rollback_command` optional fields to `PlanPhase`; execute rollback on phase failure (opt-in); coordinate with U7 checkpoint ordering per R21.
**Requirements:** R19, R20, R21, R26, R27
**Dependencies:** none (uses existing U7 `PipelineCheckpoint` and `VerificationLoop` patterns)
**Files:**
- Modify: `src/agentkit/experts/plan.py` (`PlanPhase` dataclass + `to_dict`/`from_dict` symmetry)
- Modify: `src/agentkit/experts/orchestrator.py` (insert rollback execution between phase failure and checkpoint save)
- Create: `src/agentkit/orchestrator/rollback.py` (`RollbackExecutor` class)
- Modify: `src/agentkit/server/config.py` (add `rollback` section to `from_dict`)
- Modify: `agentkit.yaml` (document `rollback:` section)
- Create: `tests/unit/test_phase_rollback.py`
**Approach:**
**PlanPhase extension (`plan.py:147-233`):** Add two optional fields with `None` default (preserves existing contract per Finding 1):
```
validation_command: str | None = None
rollback_command: str | None = None
```
Update `to_dict()` to include both fields (only when not None, to keep dict shape minimal). Update `from_dict()` to read both fields.
**`RollbackExecutor` class (new file `orchestrator/rollback.py`):** Mirrors `VerificationLoop` pattern (`verification_loop.py:67-103`):
```
class RollbackExecutor:
def __init__(self, working_dir: str | None = None, timeout: float = 30.0): ...
async def execute(self, command: str) -> RollbackResult:
# asyncio.create_subprocess_shell with cwd/timeout/proc.kill()
# Returns RollbackResult(passed, exit_code, stdout, stderr, command)
async def validate(self, command: str) -> RollbackResult:
# Same as execute but returns passed=False on non-zero exit
```
`RollbackResult` dataclass: `{passed: bool, exit_code: int, stdout: str, stderr: str, command: str}`.
**TeamOrchestrator integration (`orchestrator.py:246-270`):** New ordering per KTD8:
```
# Existing: phase failure detected, mark FAILED + dependents FAILED
# (lines 247-261 unchanged)
# NEW: rollback phase (only if validation_command and rollback_command configured)
should_save_checkpoint = True
if ph.validation_command and ph.rollback_command:
validator = RollbackExecutor(working_dir=self._workspace_root, timeout=...)
validation_result = await validator.validate(ph.validation_command)
if not validation_result.passed:
rollback_result = await validator.execute(ph.rollback_command)
if not rollback_result.passed:
# Rollback failed → don't save checkpoint (R21)
should_save_checkpoint = False
logger.error(f"Rollback failed for phase {ph.id}: {rollback_result.stderr}")
# Emit phase_rollback_failed event
await self._broadcast_event("phase_rollback_failed", {...})
# Existing: checkpoint save (conditional now)
if should_save_checkpoint and self._checkpoint is not None:
try:
await self._checkpoint.save(plan.id, ph, plan.status.value)
except Exception as e:
logger.warning(...)
```
**Events emitted (new):** `phase_rollback_started`, `phase_rollback_completed`, `phase_rollback_failed`. Existing `phase_failed` event unchanged (emitted before rollback).
**Config:** `rollback.default_timeout: float = 30.0` in `ServerConfig.rollback` dict (Wave 1 pattern).
**Security boundary (KTD7):** Rollback subprocess does NOT go through `ShellTool` (avoids `confirm_callback` for `git checkout`). Audit log emitted via `_broadcast_event`. `terminal_whitelist` is not consulted (only governs `/api/v1/terminal/server` route, per `terminal_server.py:227`).
**Execution note:** characterization-first. Test current behavior (`rollback_command=None` → no rollback, checkpoint saved) before adding rollback behavior.
**Patterns to follow:**
- `VerificationLoop` subprocess execution pattern (`verification_loop.py:67-103`) — `asyncio.create_subprocess_shell` + `cwd` + `timeout` + `proc.kill()`
- `PlanPhase.to_dict`/`from_dict` symmetric serialization pattern at `plan.py:186-233`
- `TeamOrchestrator._execute_phase` event broadcasting pattern (`orchestrator.py:253-260`)
- Wave 1's `ServerConfig.from_dict` extension template
**Test scenarios:**
- Characterization: `PlanPhase()` with no `validation_command`/`rollback_command` → `to_dict()` output matches pre-change shape (no new keys); `from_dict({})` produces phase with both fields as None
- Serialization: phase with `rollback_command="git checkout foo.py"``to_dict()` includes key; `from_dict(to_dict())` round-trips
- RollbackExecutor happy path: `execute("git status")` returns `passed=True`, `exit_code=0`
- RollbackExecutor timeout: `execute("sleep 10", timeout=0.1)` returns `passed=False`, exit_code=-1 (or similar)
- RollbackExecutor failure: `execute("false")` returns `passed=False`, `exit_code=1`
- Integration — opt-in default: phase fails, `rollback_command=None` → no RollbackExecutor call, checkpoint saved (existing behavior)
- Integration — rollback configured, validation passes (returns 0): rollback NOT executed, checkpoint saved
- Integration — rollback configured, validation fails, rollback succeeds (returns 0): rollback executed, checkpoint saved
- Integration — rollback configured, validation fails, rollback fails (returns 1): rollback executed, checkpoint NOT saved (R21), `phase_rollback_failed` event emitted
- Integration — real git repo fixture: phase writes file via `git checkout`, rollback `git checkout foo.py` restores file; assert file content matches pre-phase state
**Verification:** All scenarios pass; existing `tests/unit/test_pipeline_state.py` and `tests/unit/test_team_orchestrator*.py` (if any) still pass; ruff clean.
---
## Scope Boundaries
### Deferred to Follow-Up Work
- **Recovery wiring at non-chat call sites** (`cli/chat.py:338`, `rewoo.py:1204`, `config_driven.py:695`). KTD5 limits Wave 2 to the primary chat WebSocket path. Other entry points can adopt the same wrapper in a follow-up.
- **Recovery layer streaming events.** Wave 2's chain returns final `TaskResult`; SSE events for "recovery_started"/"recovery_completed" would require deeper WebSocket protocol changes.
- **Patch-level rollback** (`git apply -R <patch>`). KTD6 + KTD7 scope Wave 2 to `git checkout <files>` pattern. Patch-level requires extending `CheckpointData` schema to track patches — Wave 3 candidate.
- **DB-backed atomic claim for checkpoint.** Finding 3 noted `PipelineCheckpoint` is in-memory dict + Redis fallback, not a PostgreSQL `FOR UPDATE SKIP LOCKED` pattern. Multi-process atomicity is out of scope; single-process `asyncio.Lock` (already present at `orchestrator.py:53`) is sufficient.
### Outside this product's identity (carried from brainstorm)
- Wave 3 (G5/G6) — tree-sitter integration, SOLO four-stage state machine. Strategic direction locked in brainstorm KTD6/KTD7; implementation design deferred to Wave 3 plan.
- Node-level checkpoint (ReAct single-step). Stage-level (U7) satisfies core need.
- DeerFlow-style disk filesystem. Redis is the persistence layer.
- Full LangGraph migration. Self-built architecture stays.
---
## Risks & Dependencies
- **Risk: Auxiliary model availability.** If `auxiliary_model` alias is not configured in `agentkit.yaml`, `LLMGateway` raises `ModelNotFoundError`. Mitigation: `ContextCompressor` catches this in the auxiliary try-block and falls through to main model (R15 fallback path). Documented in YAML comments.
- **Risk: Recovery layer recursive loop.** `ReflexionEngine.execute()` internally calls `ReActEngine.execute()` (5th call site). If Recovery itself fails, it must NOT trigger another Recovery — KTD5 wires only at `chat.py:613`, so ReflexionEngine's internal ReAct call bypasses the chain. Recursive loop is structurally impossible.
- **Risk: `git checkout` destructive scope.** Misconfigured `rollback_command` (e.g., `git checkout .`) could wipe unrelated changes. Mitigation: rollback is opt-in per-phase (KTD6); audit log via `phase_rollback_started`/`phase_rollback_failed` events; documented YAML examples use file-scoped commands (`git checkout <specific_files>`).
- **Risk: `TaskResult.error_struct` field addition breaks pickle/serialization.** `TaskResult` is a dataclass with `to_dict`; new optional field with `None` default is backward-compatible for `to_dict` consumers. Pickle deserialization of pre-change data still works (new field defaults to None).
- **Dependency: Wave 1 merged.** `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections) is established and tested. Wave 2's 3 new sections (auxiliary_model, fallback_chain, rollback) reuse this exact pattern.
- **Dependency: U7 PipelineCheckpoint already shipped.** `orchestrator/checkpoint.py:56` exists with `save(plan_id, phase, plan_status)` API; U4 only adjusts the calling order, not the API.
---
## Sources / Research
- Brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 2 section, R13-R21, KTD4/KTD5)
- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (ServerConfig.from_dict extension template established)
- Finding 1 (contract preservation rule): `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — "新字段默认值须保持既有契约" drives KTD6's opt-in default for `rollback_command`
- Finding 2 (retry-storm defense): `docs/solutions/security-issues/portal-platform-security-reliability-fixes.md` — Emergency layer's terminal `error_code` pattern (don't propagate unhandled exceptions)
- Finding 3 (atomic claim pattern): `docs/solutions/architecture-patterns/bitable-companion-service-security-reliability-patterns.md``SKIP LOCKED` not reusable; `PipelineCheckpoint` is in-memory dict + Redis
- Finding 4 (empty-response anti-pattern): `docs/solutions/ui-bugs/tauri-reload-loses-session.md` — auxiliary LLM must check truthy return value, not just absence of exception; drives U1's empty-content fallback
- Code locations verified during planning:
- `src/agentkit/core/compressor.py:35-44,123-158``_summarize` injection point for G4
- `src/agentkit/llm/gateway.py:170-227` — existing per-model fallback chain (semantics differ from G4's task-type routing)
- `src/agentkit/llm/config.py:198-257``LLMConfig` dataclass and `from_dict`
- `src/agentkit/core/reflexion.py:68-111``ReflexionEngine.__init__` and `execute()` signatures (already supports `evaluate_model`/`reflect_model`)
- `src/agentkit/core/fallback.py` (full file, 19 lines) — 3 existing constants; `EmergencyRules` adds alongside
- `src/agentkit/core/protocol.py:118-145``TaskResult` dataclass and `to_dict`
- `src/agentkit/experts/plan.py:147-233``PlanPhase` dataclass, `to_dict`/`from_dict`
- `src/agentkit/experts/orchestrator.py:246-270` — phase failure capture + checkpoint save site (U4 integration point)
- `src/agentkit/orchestrator/checkpoint.py:56``PipelineCheckpoint` (U7, no API change needed)
- `src/agentkit/core/verification_loop.py:67-103` — subprocess execution pattern for `RollbackExecutor`
- `src/agentkit/tools/shell.py:168-174,526-534``_DANGEROUS_BINARY_FLAGS` (git checkout not listed), `_is_dangerous` logic (KTD7 rationale)

View File

@ -0,0 +1,436 @@
---
title: "feat: Agent Wave 3 strategic coupling (G5/G6)"
date: 2026-06-29
type: feat
status: draft
origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md
execution: code
---
# Wave 3 Strategic Coupling — G5 Function-level Sharding + G6 SOLO Phase Constraints
## Summary
Wave 3 of the advanced-agent gap optimization closes two strategic gaps deferred from Waves 1-2:
- **G5 — Function-level code sharding** (R22, R23): file reading gains an optional `symbol` parameter for symbol/function-granularity slicing, backward compatible with full-file reads.
- **G6 — SOLO four-stage state machine** (R24, R25): ReAct loop enforces per-phase tool whitelists (Planning → Building → Verification → Delivery). Extends existing `ExecutionMode.PLAN_EXEC` rather than introducing a new mode.
Wave 1 (G1/G2/G3/G8 — PR #4 merged) and Wave 2 (G4/G7/G9 — PR #5 open) shipped independently. Wave 3 is the **strategic-risk** wave: it introduces a new tool (G5) and touches ReAct core (G6). Per the brainstorm's KTD6/KTD7 locked decisions, G5 integration approach is decided here (not deferred further), and G6 extends PLAN_EXEC rather than adding a new mode.
## Problem Frame
The brainstorm (`docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md`) identified nine gaps across three dimensions. Waves 1-2 closed seven (G1-G4, G7-G9). The remaining two gaps are strategic:
- **G5 (Long-context cost)**: Large files (e.g. 5000-line modules) blow context budget when the agent only needs one function. Today the agent shells out to `cat file.py` or greps; both pull the whole file into context. A symbol-aware slice would cut context cost 10-50x for typical edits.
- **G6 (Unsafe tool sequencing)**: ReAct loop lets the LLM call `write_file` during early exploration before committing to a plan. This wastes tokens on premature edits, causes half-baked refactors, and breaks the "plan-then-build" discipline that production agents (Qoder, Trae Work) enforce via phase state machines.
KTD6 (brainstorm) defers the G5 integration decision to this plan. KTD7 locks G6 to extend PLAN_EXEC rather than introduce a new mode.
## Requirements
Carried forward from the brainstorm, Wave 3 section:
- **R22**: File reading supports symbol/function granularity sharding.
- **R23**: Sharding capability exposed as a tool parameter (`symbol="function_name"`), backward compatible with full-file reads.
- **R24**: ReAct loop enforces phase constraints — Planning phase only allows `think`/`search`; Building phase only allows `write_file` (and similar write tools).
- **R25**: Phase state is configurable; extends `ExecutionMode.PLAN_EXEC`, does NOT introduce a new mode.
Cross-cutting (brainstorm):
- **R26**: All optimizations configurable via `agentkit.yaml` (follow `ServerConfig.from_dict` pattern established in Waves 1-2).
- **R27**: Each optimization ships a minimal self-check test (ponytail rule).
Acceptance examples relevant to Wave 3:
- **AE5 already covered by Wave 2 (G9)** — not in Wave 3 scope.
- No explicit AE for G5/G6 in the brainstorm — the plan below specifies test scenarios as the acceptance contract.
## Key Technical Decisions
### KTD1: G5 uses Python `ast` + language-aware regex, NOT tree-sitter
**Decision**: Implement symbol extraction with the Python stdlib `ast` module for Python files, and a small regex-based extractor for TypeScript/JavaScript/Go/Rust/Java. No new dependency.
**Rationale**:
- `tree-sitter` requires native compilation + per-language grammar files (~30MB installed) — violates the ponytail rule "no new dependency if it can be avoided" and AGENTS.md "禁止使用 any 类型" → prefers minimal stack.
- `ast` is stdlib, always available, parses Python accurately.
- Regex extractor covers 80% case for TS/JS/Go/Rust/Java (function/class/struct declarations); falls back to "no symbols found → read full file" gracefully.
- If a future Wave 4 needs more accurate multi-language parsing, `tree-sitter` can replace the regex layer behind the same `SymbolExtractor` interface.
**Upgrade path**: replace `RegexSymbolExtractor` with `TreeSitterSymbolExtractor` implementing the same `SymbolExtractor` protocol; no caller changes.
### KTD2: G5 adds a new `ReadFileTool`, does not extend `ShellTool`
**Decision**: Add a dedicated `ReadFileTool` in `src/agentkit/tools/file_read.py` with `path` + optional `symbol` + optional `start_line`/`end_line` parameters.
**Rationale**:
- `ShellTool` is for shell execution; grafting symbol extraction onto it muddies the contract.
- A dedicated tool gives the LLM a clear schema (`{"path": "...", "symbol": "function_name"}`) and a focused system-prompt description.
- Aligns with the existing `_DEFAULT_CORE_TOOLS` list in `core/react.py:148` which already references `read_file` — the name is reserved but the implementation is missing.
### KTD3: G6 phase state machine lives in ReActEngine, not in the skill config
**Decision**: Phase state (Planning/Building/Verification/Delivery) is tracked as a mutable field on `ReActEngine` instance. Transitions are driven by LLM-detected phase-completion signals (e.g., the LLM emits `Phase: Building` in its thinking) OR by an explicit `advance_phase` tool.
**Rationale**:
- Skill config declares the policy (which tools per phase, auto-advance vs manual); the engine enforces it per-step. This matches R24 ("ReAct 循环加阶段约束").
- Alternative considered: phase state in the agent instance (not engine). Rejected because ReActEngine already owns `max_steps`/`verification_enabled` etc.; phase state belongs with the loop that enforces it.
### KTD4: PLAN_EXEC mode is wired at chat.py WebSocket path (REST already has fallback chain from Wave 2)
**Decision**: chat.py:1084 (currently warns "not yet supported, falling back to REACT") will route `ExecutionMode.PLAN_EXEC` to a new `_execute_plan_exec_ws` handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and passes it to `ReActEngine.execute`.
**Rationale**:
- REST `send_message` already uses the Wave 2 three-tier fallback chain; PLAN_EXEC at REST would also need that wrapper. **Out of scope for Wave 3** — only WebSocket path is wired. REST PLAN_EXEC remains "not yet supported" and explicitly raises if invoked.
- Single integration point keeps Wave 3 scope bounded; REST wiring is a one-line follow-up once WebSocket path is proven.
### KTD5: Default phase whitelist matches brainstorm R24
**Decision**: Default whitelist:
- `Planning`: `search`, `tool_search`, `read_file`, `bash` (read-only commands like `git status`, `ls`)
- `Building`: `write_file`, `bash` (write commands), `read_file`, `search`
- `Verification`: `bash` (test commands), `read_file`, `search`
- `Delivery`: all tools (final synthesis)
**Rationale**:
- R24 explicitly names `think`/`search` for Planning and `write_file` for Building.
- `bash` is split: read-only in Planning, full in Building. Enforced by adding a `bash_command_filter` callback (regex-based, blocks `rm`/`mv`/`>`/`>>` in Planning/Verification).
- `Delivery` allows all tools to support last-mile formatting/cleanup.
### KTD6: Phase transitions are LLM-driven via `advance_phase` tool (opt-in auto-advance)
**Decision**: Add an `AdvancePhaseTool` that the LLM can call to transition Planning→Building→Verification→Delivery. Auto-advance (after N steps in current phase) is opt-in via `plan_exec.auto_advance_after_steps`.
**Rationale**:
- LLM-driven transitions match Qoder/Trae Work pattern: LLM declares "planning done" explicitly.
- Auto-advance is a safety net for LLMs that forget to call `advance_phase`; default off (ponytail: less code is better).
## Scope Boundaries
### In Scope
- New `ReadFileTool` with `symbol` parameter (G5, R22/R23).
- `SymbolExtractor` protocol + `AstSymbolExtractor` (Python) + `RegexSymbolExtractor` (TS/JS/Go/Rust/Java).
- `PhasePolicy` dataclass + `PhaseState` enum + per-step tool whitelist enforcement in `ReActEngine.execute` (G6, R24/R25).
- `AdvancePhaseTool` for LLM-driven phase transitions.
- WebSocket chat path routes `PLAN_EXEC` to new `_execute_plan_exec_ws` handler (KTD4).
- `plan_exec` config section in `agentkit.yaml` + `ServerConfig.from_dict` extension (R26).
- Tests for each new module (R27).
### Out of Scope (Deferred to Follow-Up Work)
- REST `send_message` PLAN_EXEC wiring — once WebSocket path is proven, REST wiring is a follow-up commit.
- `tree-sitter` integration for more accurate multi-language parsing (KTD1 upgrade path).
- Phase-aware prompt engineering (per-phase system prompt templates) — current plan keeps a single system prompt; phase-specific guidance is a prompt-engineering concern, not a code change.
- Phase persistence across session resume (U7 checkpoint already saves plan state; phase state restoration is a separate concern).
- Phase rollback on `Building``Planning` regression (Wave 2 G9 rollback handles file-level rollback; phase regression is a UX/prompt concern).
- Tool-filter UI in the frontend (Wave 3 ships backend-only; frontend surfaces phase via existing event channel if needed in a follow-up).
### Outside This Product's Identity
- Replacing the existing ReAct loop with LangGraph (inherited from brainstorm).
- Disc-based file system à la DeerFlow (inherited).
- Docker sandbox (inherited; only command-level safety via `bash_command_filter`).
## High-Level Technical Design
```mermaid
flowchart LR
subgraph G5[Function Sharding]
RF[ReadFileTool] --> SE[SymbolExtractor protocol]
SE --> AST[AstSymbolExtractor<br/>Python stdlib ast]
SE --> RX[RegexSymbolExtractor<br/>TS/JS/Go/Rust/Java]
end
subgraph G6[Phase State Machine]
PP[PhasePolicy config] --> PS[PhaseState enum<br/>Planning/Building/Verify/Delivery]
PS --> Filt[Tool filter per step]
Filt --> RE[ReActEngine.execute]
AP[AdvancePhaseTool] -->|transitions| PS
end
RF -->|file content for symbol| RE
RE -->|enforces| Filt
```
The two subsystems compose at the ReAct engine boundary: `ReadFileTool` is one of the tools the LLM can call during any phase (filtered by `PhasePolicy`); `PhaseState` is enforced at the tool-call step before dispatch.
## Implementation Units
### U1. SymbolExtractor + ReadFileTool (G5)
**Goal**: Add `ReadFileTool` with optional `symbol` parameter; implement `SymbolExtractor` protocol with `AstSymbolExtractor` (Python) and `RegexSymbolExtractor` (TS/JS/Go/Rust/Java).
**Requirements**: R22, R23, R27.
**Dependencies**: none.
**Files**:
- `src/agentkit/tools/file_read.py` (new)
- `src/agentkit/tools/symbol_extractor.py` (new)
- `src/agentkit/tools/__init__.py` (modify — register `ReadFileTool`)
- `tests/unit/test_symbol_extractor.py` (new)
- `tests/unit/test_read_file_tool.py` (new)
**Approach**:
- `SymbolExtractor` is a `Protocol` with one method: `extract_symbols(content: str, language: str) -> list[SymbolSpan]`. `SymbolSpan` carries `name`, `kind` (function/class/method/struct), `start_line`, `end_line`.
- `AstSymbolExtractor` walks `ast.parse(content)`; for each `FunctionDef`/`AsyncFunctionDef`/`ClassDef` collects name + line range. Uses `ast.get_source_segment` style (line-based, not node-based, to keep the API simple).
- `RegexSymbolExtractor` ships patterns for TS/JS (`function X`, `const X = (...) =>`, `class X`), Go (`func X`), Rust (`fn X`, `struct X`, `impl X`), Java (`public ... X(...)`). Falls back to "no symbols" if no pattern matches.
- `ReadFileTool.execute(path, symbol=None, start_line=None, end_line=None)`:
- `symbol=None` → read full file (backward compat with the existing `_FakeTool` benchmark shape).
- `symbol="foo"` → detect language from extension; call `extract_symbols`; return the line range of the first matching symbol; if no match, return an error result with available symbol names listed (so the LLM can retry).
- `start_line`/`end_line` overrides symbol; allows manual slicing.
- Tool registered as `read_file` (matches the reserved name in `core/react.py:148`).
**Execution note**: characterization-first — write a test that asserts the tool returns the full file content when `symbol=None` (matches pre-existing benchmark `_FakeTool` shape) before adding symbol-extraction behavior.
**Patterns to follow**:
- `src/agentkit/tools/document_tool.py` for tool structure (dataclass, `Tool` base class, `input_schema`).
- `src/agentkit/tools/schema_tools.py:SchemaExtractTool` for "extract-from-source" pattern.
**Test scenarios** (covers R22, R23):
- **Happy paths**:
- Python file, `symbol="MyClass"` → returns class body only (lines from `class MyClass:` through end of class).
- Python file, `symbol="my_func"` → returns function body only.
- TypeScript file, `symbol="renderComponent"` → returns arrow/function body.
- Go file, `symbol="HandleRequest"` → returns func body.
- **Edge cases**:
- `symbol=None` → returns full file content (characterization).
- `symbol="nonexistent"` → returns error result listing available symbols ("Available symbols: foo, bar, baz").
- Unsupported file extension (`.md`, `.txt`) → returns full file with `note: symbol extraction not supported for .md`.
- Empty file → returns empty content.
- File with nested classes → outer class symbol returns including inner class.
- **Error paths**:
- Path does not exist → raises `FileNotFoundError` (or returns error result matching other tools' convention).
- Path is a directory → returns error result.
- Permission denied → returns error result.
- **Integration scenarios**:
- Symbol extraction + line slicing: `symbol="foo"`, `end_line=50` truncates at line 50 even if symbol extends further.
- Round-trip: extract symbol, write back via `ShellTool` `sed` (not in scope for tool — just verify extracted range is well-formed).
**Verification**:
- `python3 -m pytest tests/unit/test_symbol_extractor.py tests/unit/test_read_file_tool.py -q` passes.
- `ruff check src/agentkit/tools/file_read.py src/agentkit/tools/symbol_extractor.py` clean.
- `ReadFileTool` appears in `ToolRegistry.list_tools()` after registration.
---
### U2. PhasePolicy + PhaseState + ServerConfig (G6 core)
**Goal**: Add `PhasePolicy` dataclass, `PhaseState` enum, default whitelist config. Extend `ServerConfig.from_dict` with `plan_exec` section. Wire config to `agentkit.yaml`.
**Requirements**: R25, R26.
**Dependencies**: none.
**Files**:
- `src/agentkit/core/phase.py` (new) — `PhaseState` enum, `PhasePolicy` dataclass, `default_policy()` factory.
- `src/agentkit/server/config.py` (modify — add `plan_exec` field + `from_dict` parsing).
- `agentkit.yaml` (modify — document `plan_exec:` section).
- `tests/unit/test_phase_policy.py` (new).
**Approach**:
- `PhaseState = enum("planning building verification delivery")`.
- `PhasePolicy` carries:
- `whitelist: dict[PhaseState, set[str]]` — tool names allowed per phase.
- `bash_command_filter: dict[PhaseState, re.Pattern | None]` — regex that bash args must NOT match (e.g., `r"\b(rm|mv|>|>>)\b"` in Planning).
- `auto_advance_after_steps: int | None` — None = manual (LLM calls `advance_phase`); int = auto-advance after N steps.
- `start_phase: PhaseState = PhaseState.PLANNING`.
- `default_policy()` returns the KTD5 whitelist above.
- `ServerConfig.from_dict` reads `plan_exec` section: `enabled`, `whitelist_override` (dict), `auto_advance_after_steps`.
- `agentkit.yaml` gains a commented-out `plan_exec:` block (commented to preserve default behavior — opt-in).
**Patterns to follow**:
- `src/agentkit/core/fallback.py` for dataclass + classmethod factory pattern.
- `src/agentkit/server/config.py` `from_dict` extension template (established in Wave 1 for `prompt_cache`/`streaming`/`verification`; Wave 2 added `rollback`/`fallback_chain`; Wave 3 adds `plan_exec`).
**Test scenarios** (covers R25, R26):
- **Happy paths**:
- `default_policy()` returns policy with all four phases; Planning whitelist contains `search`, `read_file`; Building contains `write_file`.
- `PhasePolicy.is_tool_allowed("search", PhaseState.PLANNING)` returns True.
- `PhasePolicy.is_tool_allowed("write_file", PhaseState.PLANNING)` returns False.
- `PhasePolicy.is_tool_allowed("write_file", PhaseState.BUILDING)` returns True.
- **Edge cases**:
- Empty whitelist for a phase → all tools rejected (raises `ValueError` at construction time — fail-fast).
- `Delivery` phase whitelist contains `"*"` (wildcard) → all tools allowed.
- Custom whitelist override merges with default (override wins on conflict).
- **Error paths**:
- Invalid phase name in config → `ValueError` with message naming the bad value.
- `bash_command_filter` regex compile failure → `ValueError`.
- **Config integration**:
- `ServerConfig.from_dict({"plan_exec": {"enabled": True, "auto_advance_after_steps": 5}})` populates fields correctly.
- `ServerConfig.from_dict({})``plan_exec = {}` (default).
**Verification**:
- `python3 -m pytest tests/unit/test_phase_policy.py -q` passes.
- `ruff check src/agentkit/core/phase.py` clean.
---
### U3. AdvancePhaseTool + ReActEngine phase enforcement (G6 wiring)
**Goal**: Add `AdvancePhaseTool`. Wire `PhasePolicy` into `ReActEngine.execute` so each tool-call step checks `is_tool_allowed(tool_name, current_phase)` before dispatch; blocked calls return a structured error to the LLM ("Tool 'write_file' not allowed in Planning phase — call advance_phase first").
**Requirements**: R24.
**Dependencies**: U2.
**Files**:
- `src/agentkit/tools/advance_phase.py` (new) — `AdvancePhaseTool` calls `react_engine.advance_phase()`.
- `src/agentkit/core/react.py` (modify — add `phase_policy` param to `__init__` + `execute`; add `_current_phase` field; add `advance_phase()` method; enforce in `_execute_loop`).
- `tests/unit/test_react_phase_enforcement.py` (new).
**Approach**:
- `ReActEngine.__init__` accepts `phase_policy: PhasePolicy | None = None`. None = no enforcement (backward compat — all existing callers unaffected).
- `_current_phase: PhaseState | None` initialized from `phase_policy.start_phase` if policy set, else None.
- `advance_phase()` advances `_current_phase` to next enum value; raises `ValueError` if already at `DELIVERY`.
- In `_execute_loop`, before dispatching a tool call:
```python
if self._phase_policy is not None and self._current_phase is not None:
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
# Inject structured error into conversation, do NOT dispatch tool.
# This counts as a "step" for max_steps purposes.
observation = {
"error": "phase_violation",
"message": f"Tool '{tool_name}' not allowed in {self._current_phase.value} phase",
"current_phase": self._current_phase.value,
"hint": "Call advance_phase to move to Building phase"
}
continue # next loop iteration
```
- Auto-advance: if `phase_policy.auto_advance_after_steps` is set and `_steps_in_phase >= auto_advance_after_steps`, call `advance_phase()` automatically.
- `AdvancePhaseTool.execute()` calls the bound engine's `advance_phase()` and returns the new phase name. Registered only when `phase_policy` is not None.
**Execution note**: characterization-first — test that `ReActEngine` with `phase_policy=None` behaves identically to pre-change (no enforcement, no `advance_phase` tool, no `_current_phase` mutation). Then add enforcement tests.
**Patterns to follow**:
- `src/agentkit/core/react.py` `verification_enabled` pattern (feature flag + step-level check).
- `src/agentkit/tools/ask_human.py` for tool that interacts with engine state.
**Test scenarios** (covers R24):
- **Characterization (no policy)**:
- `ReActEngine(phase_policy=None)` — all tools allowed in all steps; no `advance_phase` tool registered; behavior matches pre-change.
- **Happy paths**:
- Planning phase: LLM calls `search` → executes; LLM calls `advance_phase` → phase becomes Building.
- Building phase: LLM calls `write_file` → executes; LLM calls `advance_phase` → phase becomes Verification.
- Verification phase: LLM calls `bash` with `pytest` → executes; LLM calls `advance_phase` → phase becomes Delivery.
- Delivery phase: LLM calls any tool → executes (wildcard).
- **Edge cases**:
- `advance_phase` called at Delivery → returns error "Already at final phase".
- Auto-advance after 3 steps in Planning → phase transitions automatically on 4th step.
- `bash` command in Planning contains `rm file` → blocked by `bash_command_filter`.
- **Error paths**:
- LLM calls `write_file` in Planning → tool NOT dispatched; structured error returned to LLM; loop continues.
- LLM calls non-existent tool → existing error path (not phase-related).
- **Integration scenarios**:
- Phase transition emits a `phase_changed` event (use existing `_broadcast_event` pattern from `experts/orchestrator.py`).
- `max_steps` reached mid-phase → `ReActResult.status = "max_steps_reached"` (existing path, no change).
**Verification**:
- `python3 -m pytest tests/unit/test_react_phase_enforcement.py -q` passes.
- Existing `tests/unit/test_react_engine.py` still passes (characterization — no policy = no change).
- `ruff check src/agentkit/core/react.py src/agentkit/tools/advance_phase.py` clean.
---
### U4. Wire PLAN_EXEC at chat.py WebSocket path (G6 chat integration)
**Goal**: Replace the `chat.py:1084` "not yet supported, falling back to REACT" warning with a real PLAN_EXEC handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and dispatches to `ReActEngine.execute` with the policy set.
**Requirements**: R24, R25 (end-to-end wiring).
**Dependencies**: U2, U3.
**Files**:
- `src/agentkit/server/routes/chat.py` (modify — add `_execute_plan_exec_ws` handler; branch on `ExecutionMode.PLAN_EXEC`).
- `tests/unit/test_chat_plan_exec_ws.py` (new).
**Approach**:
- New helper `_execute_plan_exec_ws(websocket, agent, routing, messages, ...)`:
1. Read `server_config.plan_exec` (may be `{}` if not configured → use `default_policy()`).
2. Build `PhasePolicy` from config (apply overrides).
3. Construct `ReActEngine(..., phase_policy=policy)`.
4. Register `AdvancePhaseTool` bound to this engine.
5. Call `engine.execute_stream(...)` — reuses existing streaming path.
6. Emit `phase_changed` events through the WebSocket (frontend can render phase indicator).
- chat.py:1084 changes from `if execution_mode not in (REACT, SKILL_REACT): warn + fall back` to:
```python
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
await _execute_plan_exec_ws(websocket, agent, routing, ...)
return
if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
# existing warning for REWOO/REFLEXION/TEAM_COLLAB
...
```
- REST `send_message` path: explicitly raise `HTTPException(501, "PLAN_EXEC via REST not yet supported; use WebSocket")` — Wave 3 does NOT wire REST (KTD4).
**Execution note**: characterization-first — test that existing REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning (no regression). Then add PLAN_EXEC wiring.
**Patterns to follow**:
- `src/agentkit/server/routes/chat.py` existing WebSocket handler structure (lines 1082-1100).
- `src/agentkit/server/_fallback_chain.py` (Wave 2 U3) for "construct engine per-request with config" pattern.
**Test scenarios** (covers end-to-end):
- **Characterization**:
- `ExecutionMode.REWOO` via WebSocket → still falls back to REACT with warning (existing behavior unchanged).
- `ExecutionMode.REFLEXION` → same.
- `ExecutionMode.TEAM_COLLAB` → same.
- **Happy paths**:
- `ExecutionMode.PLAN_EXEC` via WebSocket → `_execute_plan_exec_ws` invoked; `ReActEngine` constructed with `phase_policy`; `AdvancePhaseTool` registered.
- Planning phase: LLM emits `search` tool call → executed; tool result streamed.
- LLM emits `advance_phase``phase_changed` event sent to WebSocket client; subsequent `write_file` call now allowed.
- **Edge cases**:
- `plan_exec` config absent → `default_policy()` used; behavior matches KTD5 whitelist.
- `plan_exec.enabled=False` → falls back to REACT (opt-out).
- Phase violation: LLM calls `write_file` in Planning → structured error returned; loop continues; `phase_violation` event emitted.
- **Error paths**:
- REST `send_message` with PLAN_EXEC → 501 error.
- Phase policy construction fails (bad config) → 500 error with message.
- **Integration scenarios**:
- Existing fallback chain (Wave 2 U3) NOT applied to PLAN_EXEC — phase policy and fallback chain are mutually exclusive (KTD5 from Wave 2 plan: chain only wraps REACT/SKILL_REACT at REST). Document this in chat.py comment.
**Verification**:
- `python3 -m pytest tests/unit/test_chat_plan_exec_ws.py -q` passes.
- `ruff check src/agentkit/server/routes/chat.py` clean.
- Manual test: `agentkit chat` with `@skill:plan_exec_demo` skill config → WebSocket stream includes `phase_changed` events.
---
## Risks & Dependencies
### Risks
1. **ReAct core modification risk (high)**: U3 modifies `ReActEngine._execute_loop`. Mitigation: characterization-first tests (U3 Execution note); `phase_policy=None` default preserves all existing behavior; full `test_react_engine.py` regression.
2. **Symbol extraction accuracy (medium)**: Regex extractor may miss edge cases (decorated functions, nested generics, multi-line signatures). Mitigation: fall back to "no symbols found → read full file" gracefully; never raise on extraction failure.
3. **PLAN_EXEC phase deadlock (medium)**: LLM may never call `advance_phase`, leaving the agent stuck in Planning. Mitigation: `auto_advance_after_steps` config (default 5); timeout via existing `max_steps`.
4. **Tool name drift (low)**: Phase whitelist references tool names (`write_file`, `search`, etc.) that may be renamed in future. Mitigation: whitelist is config-driven; rename only requires config update.
### Dependencies
- Wave 2 PR #5 (`feat/agent-wave2-medium-coupling`) should be merged first — Wave 3 builds on the `ServerConfig.from_dict` extension pattern and the `_fallback_chain.py` integration shape established there. If PR #5 is still open, Wave 3 branches from `feat/agent-wave2-medium-coupling` rather than `main`.
- No external library dependencies (KTD1).
## System-Wide Impact
- **Agents using PLAN_EXEC mode**: gain phase enforcement. Existing REACT/SKILL_REACT/DIRECT_CHAT agents: zero change (phase_policy defaults to None).
- **Tool registry**: gains two new tools (`read_file`, `advance_phase`). Frontend tool list display may need updating to show the new icons — out of scope for Wave 3 (frontend follows up).
- **`agentkit.yaml`**: gains `plan_exec:` section (commented by default). Existing configs unaffected.
- **WebSocket clients**: gain `phase_changed` event type. Existing clients ignore unknown event types (verified in Wave 2 — `phase_rollback_*` events follow the same pattern).
## Sources & Research
- Origin brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 3 section, KTD6/KTD7).
- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (PR #4 merged).
- Wave 2 plan: `docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md` (PR #5 open).
- Trae Work architecture research (cited in brainstorm): SOLO four-stage state machine pattern.
- Qoder architecture research (cited in brainstorm): Spec→Coding→Verify closed loop.
- Codebase: `src/agentkit/core/react.py:148` reserves `read_file`/`write_file` tool names in `_DEFAULT_CORE_TOOLS` — Wave 3 U1 delivers the missing `read_file` implementation.
- Codebase: `src/agentkit/server/routes/chat.py:1084` documents that PLAN_EXEC is "not yet supported" — Wave 3 U4 closes this gap.
## Deferred to Implementation
- Exact regex patterns for non-Python symbol extraction (U1) — design above gives the shape; implementer finalizes patterns based on real-world test fixtures.
- `bash_command_filter` regex precision (U2) — defaults block `rm`/`mv`/`>`/`>>`; implementer may add more based on test scenarios.
- `phase_changed` event payload shape (U3/U4) — minimal viable shape: `{"phase": "building", "previous": "planning"}`; frontend rendering concerns are out of scope.
- Whether `AdvancePhaseTool` accepts a `target_phase` argument for skipping phases (e.g., Planning → Verification) — default no (sequential only); add if test scenarios reveal a need.

View File

@ -41,6 +41,7 @@ class ContextCompressor:
model_context_limit: int = 128_000,
headroom_threshold: float = 0.8,
min_tokens: int = 8_000,
auxiliary_model: str | None = None,
):
self._llm_gateway = llm_gateway
self._max_tokens = max_tokens
@ -51,6 +52,11 @@ class ContextCompressor:
self._model_context_limit = model_context_limit
self._headroom_threshold = headroom_threshold
self._min_tokens = min_tokens
# G4/U1: Auxiliary model for cost-sensitive summarization (e.g. "fast" alias).
# When set and differs from main model, _summarize tries auxiliary first,
# falls back to main model on failure OR empty content (Finding 4 anti-pattern).
# ponytail: ceiling — auxiliary is best-effort; main model is authoritative fallback.
self._auxiliary_model = auxiliary_model
def should_compress(self, messages: list[dict]) -> bool:
"""Check if compression should be triggered based on headroom ratio.
@ -92,8 +98,8 @@ class ContextCompressor:
if len(non_system) <= self._keep_recent:
return messages # Not enough messages to compress
old_msgs = non_system[:-self._keep_recent]
recent_msgs = non_system[-self._keep_recent:]
old_msgs = non_system[: -self._keep_recent]
recent_msgs = non_system[-self._keep_recent :]
# Compress old messages
summary = await self._summarize(old_msgs)
@ -101,10 +107,12 @@ class ContextCompressor:
# Build compressed message list
compressed = list(system_msgs)
if summary:
compressed.append({
"role": "system",
"content": f"## Conversation Summary\n{summary}",
})
compressed.append(
{
"role": "system",
"content": f"## Conversation Summary\n{summary}",
}
)
compressed.extend(recent_msgs)
# Recursive check: if still over budget, compress again
@ -114,22 +122,30 @@ class ContextCompressor:
return self._truncate(compressed)
if len(recent_msgs) > 1:
# Try keeping fewer recent messages
return await self._compress_aggressive(messages, _compression_depth=_compression_depth + 1)
return await self._compress_aggressive(
messages, _compression_depth=_compression_depth + 1
)
# Last resort: truncate
return self._truncate(compressed)
return compressed
async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str:
"""Summarize a list of messages using LLM"""
"""Summarize a list of messages using LLM.
G4/U1: When ``auxiliary_model`` is configured and differs from the main
model, try auxiliary first (cost-optimization). On auxiliary failure OR
empty content (Finding 4 anti-pattern "did not throw is not succeeded"),
fall back to main model. Existing ``_simple_summary`` degradation
preserved as the final tier when main model also fails.
"""
if not self._llm_gateway:
# No LLM available, do simple truncation
return self._simple_summary(messages)
# Build summary prompt
conversation_text = "\n".join(
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}"
for m in messages
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" for m in messages
)
# Pre-truncate if conversation_text exceeds safe token threshold
@ -145,6 +161,25 @@ class ContextCompressor:
f"{conversation_text}"
)
# G4: Try auxiliary model first when configured (cheap route).
if self._auxiliary_model and self._auxiliary_model != self._model:
try:
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
model=self._auxiliary_model,
agent_name="compressor",
task_type="summarization",
)
# Finding 4: empty content is a failure, not a success.
if response.content and response.content.strip():
return response.content
logger.info("Auxiliary model returned empty content, falling back to main model")
except Exception as e:
logger.info(
f"Auxiliary model summarization failed, falling back to main model: {e}"
)
# Main model path (or auxiliary fallback).
try:
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
@ -166,7 +201,9 @@ class ContextCompressor:
parts.append(f"[{role}]: {content}...")
return "\n".join(parts)
async def _compress_aggressive(self, messages: list[dict], _compression_depth: int = 0) -> list[dict]:
async def _compress_aggressive(
self, messages: list[dict], _compression_depth: int = 0
) -> list[dict]:
"""More aggressive compression when standard compression isn't enough"""
system_msgs = [m for m in messages if m.get("role") == "system"]
non_system = [m for m in messages if m.get("role") != "system"]
@ -176,10 +213,12 @@ class ContextCompressor:
summary = await self._summarize(non_system[:-1])
compressed = list(system_msgs)
if summary:
compressed.append({
"role": "system",
"content": f"## Conversation Summary\n{summary}",
})
compressed.append(
{
"role": "system",
"content": f"## Conversation Summary\n{summary}",
}
)
compressed.append(non_system[-1])
return compressed
@ -191,7 +230,7 @@ class ContextCompressor:
for msg in messages:
content = str(msg.get("content", ""))
if len(content) > self._max_tokens * 4:
msg = {**msg, "content": content[:self._max_tokens * 4] + "...[truncated]"}
msg = {**msg, "content": content[: self._max_tokens * 4] + "...[truncated]"}
result.append(msg)
return result
@ -226,6 +265,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
if provider == "headroom":
try:
from agentkit.core.headroom_compressor import HeadroomCompressor
compressor = HeadroomCompressor(config)
if compressor.is_available():
return compressor
@ -235,8 +275,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
)
except ImportError:
logger.warning(
"HeadroomCompressor module not available. "
"Falling back to ContextCompressor."
"HeadroomCompressor module not available. Falling back to ContextCompressor."
)
# Fallback to summary compressor
return ContextCompressor(
@ -253,11 +292,9 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
"""Render PromptTemplate with caching - returns cached result for same variables"""
cache_key = hashlib.md5(
json.dumps(variables or {}, sort_keys=True).encode()
).hexdigest()
cache_key = hashlib.md5(json.dumps(variables or {}, sort_keys=True).encode()).hexdigest()
if not hasattr(template, '_render_cache'):
if not hasattr(template, "_render_cache"):
template._render_cache = {}
if cache_key in template._render_cache:
@ -270,5 +307,5 @@ def render_cached(template, variables: dict[str, Any] | None = None) -> list[dic
def clear_cache(template) -> None:
"""Clear the render cache on a PromptTemplate instance"""
if hasattr(template, '_render_cache'):
if hasattr(template, "_render_cache"):
template._render_cache.clear()

View File

@ -2,8 +2,22 @@
All layers (ReActEngine, Portal, Chat) should use these constants
to ensure consistent user-facing messages.
G7/U2: Also hosts ``EmergencyError`` and ``EmergencyRules`` for the
three-tier fallback chain's Emergency layer (rule-based classifier,
no LLM). See ``docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md``.
"""
from dataclasses import dataclass
from typing import Any
from agentkit.core.exceptions import (
LLMProviderError,
LoopDetectedError,
TaskCancelledError,
TaskTimeoutError,
)
# When LLM returns empty content after all fallback models exhausted
EMPTY_LLM_RESPONSE = (
"模型未返回有效内容,已尝试备用模型仍未成功。"
@ -16,3 +30,135 @@ MAX_STEPS_REACHED = "已达到最大推理步数,但仍未得到完整结论
# When a shell command succeeds but produces no output
SHELL_NO_OUTPUT = "[命令执行成功,无输出内容]"
# ── G7/U2: Emergency layer ──────────────────────────────────────
@dataclass
class EmergencyError:
"""Structured error produced by the Emergency layer (rule-classified).
Carries a stable ``error_code`` for programmatic dispatch (frontend
retry UI, telemetry), a human-readable ``message`` mirroring
``EMPTY_LLM_RESPONSE`` style, actionable ``suggestions``, and the
original exception string for traceability.
The ``retryable`` flag distinguishes recoverable user errors
(timeout, loop, LLM hiccup) from internal bugs (retryable=False).
"""
error_code: str # "timeout" | "loop_detected" | "llm_failure" | "internal_error"
message: str # human-readable Chinese message
suggestions: list[str] # actionable user-facing suggestions
retryable: bool # whether a user retry might succeed
original_error: str # str(exc) for traceability
def to_dict(self) -> dict[str, Any]:
return {
"error_code": self.error_code,
"message": self.message,
"suggestions": list(self.suggestions),
"retryable": self.retryable,
"original_error": self.original_error,
}
def to_error_message(self) -> str:
"""Format as a single human-readable string with suggestions.
Mirrors ``EMPTY_LLM_RESPONSE`` style: ``<message>建议1) ... 2) ...``
"""
if not self.suggestions:
return self.message
suggestion_str = "".join(f"{i}) {s}" for i, s in enumerate(self.suggestions, 1))
# Strip trailing "" and prefix with "建议:"
suggestion_str = suggestion_str.rstrip("")
return f"{self.message}建议:{suggestion_str}"
# Default rule set: (exception_type, error_code, message, suggestions, retryable)
# ponytail: ceiling — rule-based, no LLM. Adding a new rule = append a tuple.
# Upgrade path: LLM-driven suggestions would require a separate classifier
# class (out of scope per brainstorm KTD4).
_DEFAULT_RULES: list[tuple[type[Exception], str, str, list[str], bool]] = [
(
TaskTimeoutError,
"timeout",
"任务执行超时。",
["稍后重试", "简化任务范围"],
True,
),
(
LoopDetectedError,
"loop_detected",
"检测到推理循环。",
["拆分任务", "检查工具参数"],
True,
),
(
LLMProviderError,
"llm_failure",
"LLM 服务调用失败。",
["稍后重试", "切换模型"],
True,
),
]
_DEFAULT_ERROR_CODE = "internal_error"
_DEFAULT_MESSAGE = "Agent 执行内部错误。"
_DEFAULT_SUGGESTIONS: list[str] = ["联系管理员"]
_DEFAULT_RETRYABLE = False
class EmergencyRules:
"""Rule-based classifier for the Emergency layer.
Maps exception types to ``EmergencyError`` instances. No LLM, no I/O
pure function of ``(exception, config)``.
Caller responsibility: ``TaskCancelledError`` MUST propagate as-is
(per KTD3); the caller checks ``isinstance(exc, TaskCancelledError)``
before invoking :meth:`classify`. Calling :meth:`classify` with a
``TaskCancelledError`` raises ``ValueError`` to surface the bug.
Config override shape (optional ``config`` arg):
```python
{
"timeout": {"suggestions": ["自定义建议"], "retryable": False},
"llm_failure": {"message": "自定义消息"},
}
```
"""
@staticmethod
def classify(exc: Exception, config: dict | None = None) -> EmergencyError:
if isinstance(exc, TaskCancelledError):
# Contract: caller must check before invoking classify.
raise ValueError(
"TaskCancelledError must propagate as-is; caller must check "
"before invoking EmergencyRules.classify"
)
config = config or {}
for exc_type, code, message, suggestions, retryable in _DEFAULT_RULES:
if isinstance(exc, exc_type):
override = config.get(code, {}) if isinstance(config, dict) else {}
return EmergencyError(
error_code=code,
message=override.get("message", message),
suggestions=override.get("suggestions", list(suggestions)),
retryable=override.get("retryable", retryable),
original_error=str(exc),
)
# Generic fallback
override = config.get(_DEFAULT_ERROR_CODE, {}) if isinstance(config, dict) else {}
return EmergencyError(
error_code=_DEFAULT_ERROR_CODE,
message=override.get("message", _DEFAULT_MESSAGE),
suggestions=override.get("suggestions", list(_DEFAULT_SUGGESTIONS)),
retryable=override.get("retryable", _DEFAULT_RETRYABLE),
original_error=str(exc),
)

206
src/agentkit/core/phase.py Normal file
View File

@ -0,0 +1,206 @@
"""Phase state machine for PLAN_EXEC mode (G6, R24/R25).
Four sequential phases enforce per-step tool whitelists:
PLANNING BUILDING VERIFICATION DELIVERY
KTD3 (Wave 3 plan): state machine lives in ReActEngine, not skill config.
KTD5: default whitelist matches brainstorm R24 (Planning: think/search;
Building: write_file; etc.).
KTD6: transitions are LLM-driven via AdvancePhaseTool; auto-advance is opt-in.
"""
from __future__ import annotations
import enum
import logging
import re
from dataclasses import dataclass, field, replace
from typing import Any
logger = logging.getLogger(__name__)
class PhaseState(enum.Enum):
"""Phases of the SOLO state machine (extends ExecutionMode.PLAN_EXEC)."""
PLANNING = "planning"
BUILDING = "building"
VERIFICATION = "verification"
DELIVERY = "delivery"
@classmethod
def next_of(cls, current: "PhaseState") -> "PhaseState | None":
"""Return the phase after `current`, or None if `current` is the last."""
order = [cls.PLANNING, cls.BUILDING, cls.VERIFICATION, cls.DELIVERY]
try:
idx = order.index(current)
except ValueError:
return None
if idx + 1 >= len(order):
return None
return order[idx + 1]
@classmethod
def from_string(cls, value: str) -> "PhaseState":
"""Parse from string (case-insensitive). Raises ValueError on unknown."""
try:
return cls(value.lower())
except ValueError as e:
valid = ", ".join(p.value for p in cls)
raise ValueError(f"Invalid phase name {value!r}. Valid: {valid}") from e
# Wildcard token meaning "all tools allowed in this phase".
WILDCARD = "*"
# Default bash command filter for PLANNING and VERIFICATION phases — blocks
# commands that mutate the filesystem or execute arbitrary code.
# ponytail: regex is intentionally conservative; misses some shell idioms
# (e.g., `:>file`, `dd of=file`). Ceiling: a real shell parser would catch
# more. Upgrade path = reuse ShellTool._is_dangerous() at enforcement time.
# Note: `\b` is a word boundary — works for word commands (rm/mv) but NOT
# for `>`/`>>` operators (not word chars). Use a non-boundary alternation
# that matches `>` either as a standalone operator or after whitespace.
_DEFAULT_BASH_FILTER = re.compile(r"\b(rm|mv|cp|mkdir|rmdir|chmod|chown)\b|(?<!\S)>|>>")
@dataclass(slots=True)
class PhasePolicy:
"""Per-phase tool whitelist + bash command filter for PLAN_EXEC mode.
The policy is enforced by ReActEngine._execute_loop before each tool
dispatch. A tool not in the current phase's whitelist is rejected with
a structured error returned to the LLM (the loop continues the LLM
gets to react to the rejection and either switch tools or call
AdvancePhaseTool).
Wildcard ``"*"`` in a phase's whitelist means "all tools allowed"
(used by DELIVERY by default).
"""
whitelist: dict[PhaseState, frozenset[str]]
bash_command_filter: dict[PhaseState, re.Pattern | None] = field(default_factory=dict)
auto_advance_after_steps: int | None = None # None = manual (LLM calls advance_phase)
start_phase: PhaseState = PhaseState.PLANNING
def __post_init__(self) -> None:
# Fail-fast: empty whitelist for a non-wildcard phase = bug.
for phase, tools in self.whitelist.items():
if not tools:
raise ValueError(
f"Phase {phase.value!r} has an empty whitelist — set ['*'] for "
f"'all tools allowed' or list specific tool names."
)
def is_tool_allowed(self, tool_name: str, phase: PhaseState) -> bool:
"""Return True if `tool_name` is allowed in `phase`."""
allowed = self.whitelist.get(phase, frozenset())
if WILDCARD in allowed:
return True
return tool_name in allowed
def is_bash_command_allowed(self, command: str, phase: PhaseState) -> bool:
"""Return True if `command` passes the bash filter for `phase`.
A None filter = no restriction. An empty command is allowed (ShellTool
separately rejects empty commands).
"""
pattern = self.bash_command_filter.get(phase)
if pattern is None:
return True
return not pattern.search(command)
def to_dict(self) -> dict[str, Any]:
"""Serialize for logging/telemetry. Not round-trippable (regex → str)."""
return {
"whitelist": {phase.value: sorted(tools) for phase, tools in self.whitelist.items()},
"bash_command_filter": {
phase.value: (p.pattern if p else None)
for phase, p in self.bash_command_filter.items()
},
"auto_advance_after_steps": self.auto_advance_after_steps,
"start_phase": self.start_phase.value,
}
def default_policy() -> PhasePolicy:
"""Return the KTD5 default PhasePolicy.
Whitelist (R24):
- PLANNING: search, tool_search, read_file, shell (read-only)
- BUILDING: write_file, shell (full), read_file, search
- VERIFICATION: shell (test commands), read_file, search
- DELIVERY: all tools (wildcard)
Bash filter:
- PLANNING/VERIFICATION: blocks filesystem-mutating commands
(rm/mv/cp/mkdir/chmod/chown/>/>>)
- BUILDING/DELIVERY: no filter (full bash)
"""
return PhasePolicy(
whitelist={
# Tool name is "shell" (ShellTool default); bash_command_filter
# gates on the same name. Using "bash" here would make the filter
# dead code and block the LLM from shell access.
PhaseState.PLANNING: frozenset({"search", "tool_search", "read_file", "shell"}),
PhaseState.BUILDING: frozenset(
{"write_file", "shell", "read_file", "search", "tool_search"}
),
PhaseState.VERIFICATION: frozenset({"shell", "read_file", "search"}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
},
bash_command_filter={
PhaseState.PLANNING: _DEFAULT_BASH_FILTER,
PhaseState.VERIFICATION: _DEFAULT_BASH_FILTER,
PhaseState.BUILDING: None,
PhaseState.DELIVERY: None,
},
auto_advance_after_steps=None, # manual by default
start_phase=PhaseState.PLANNING,
)
def policy_from_config(config: dict[str, Any]) -> PhasePolicy | None:
"""Build a PhasePolicy from the `plan_exec` config section.
Returns None if `config` is empty or `enabled` is False (opt-out).
Config shape:
plan_exec:
enabled: true # default true if section present
auto_advance_after_steps: 5 # optional
start_phase: planning # optional, default planning
whitelist_override: # optional, merges with default
planning: [search, read_file]
building: [write_file, bash]
"""
if not config:
return None
if config.get("enabled", True) is False:
return None
policy = default_policy()
# Start phase
start_phase_str = config.get("start_phase")
if start_phase_str:
policy = replace(policy, start_phase=PhaseState.from_string(start_phase_str))
# Auto-advance override
if "auto_advance_after_steps" in config:
policy = replace(policy, auto_advance_after_steps=config["auto_advance_after_steps"])
# Whitelist override — merge with default (override wins on conflict)
override = config.get("whitelist_override") or {}
if override:
new_whitelist = dict(policy.whitelist)
for phase_name, tools in override.items():
phase = PhaseState.from_string(phase_name)
if not isinstance(tools, list):
raise ValueError(
f"whitelist_override[{phase_name!r}] must be a list, got {type(tools).__name__}"
)
new_whitelist[phase] = frozenset(str(t) for t in tools)
policy = replace(policy, whitelist=new_whitelist)
return policy

View File

@ -128,6 +128,10 @@ class TaskResult:
completed_at: datetime
metrics: dict | None = None
trace: Any | None = None
# G7/U2: Emergency layer structured error. None preserves existing contract
# (error_message alone carries the human-readable string). When set,
# carries serialized EmergencyError.to_dict() for programmatic dispatch.
error_struct: dict | None = None
def to_dict(self) -> dict:
d = {
@ -142,6 +146,8 @@ class TaskResult:
}
if self.trace is not None:
d["trace"] = self.trace.to_dict() if hasattr(self.trace, "to_dict") else self.trace
if self.error_struct is not None:
d["error_struct"] = self.error_struct
return d
@classmethod
@ -162,6 +168,7 @@ class TaskResult:
completed_at=completed_at or datetime.now(timezone.utc),
metrics=data.get("metrics"),
trace=data.get("trace"),
error_struct=data.get("error_struct"),
)

View File

@ -28,6 +28,7 @@ from agentkit.telemetry.metrics import (
if TYPE_CHECKING:
from agentkit.core.compressor import CompressionStrategy
from agentkit.core.middleware import MiddlewareChain
from agentkit.core.phase import PhasePolicy, PhaseState
from agentkit.core.trace import TraceRecorder
from agentkit.memory.retriever import MemoryRetriever
@ -168,6 +169,9 @@ class ReActEngine:
prompt_cache_enable: bool = True,
flush_interval_ms: int = 0,
max_reinjections: int = 1,
# U3/G6: PLAN_EXEC phase policy (opt-in). None = no enforcement
# (backward compat — all existing callers unaffected).
phase_policy: "PhasePolicy | None" = None,
):
if max_steps < 1:
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
@ -211,6 +215,15 @@ class ReActEngine:
self._loop_corrected: bool = False
# U6: Middleware chain (parallel integration, feature flag controlled)
self._middleware_chain = middleware_chain
# U3/G6: PLAN_EXEC phase state. None = no enforcement (default).
# When set, _execute_loop checks each tool call against the current
# phase's whitelist before dispatch.
self._phase_policy = phase_policy
self._current_phase: "PhaseState | None" = (
phase_policy.start_phase if phase_policy is not None else None
)
# Steps taken in the current phase (for auto-advance safety net).
self._steps_in_phase: int = 0
def reset(self) -> None:
"""Reset internal state for reuse across conversations.
@ -223,6 +236,99 @@ class ReActEngine:
# This method exists for API clarity and future stateful extensions.
self._loop_window.clear()
self._loop_corrected = False
# U3/G6: reset phase state to start_phase (if policy set). Each
# execute() call begins a fresh PLANNING phase.
if self._phase_policy is not None:
self._current_phase = self._phase_policy.start_phase
self._steps_in_phase = 0
# ── U3/G6: phase state machine ────────────────────────────────────
def advance_phase(self) -> "PhaseState | None":
"""Advance to the next phase. Returns the new phase, or None if
already at DELIVERY (final phase).
Called by AdvancePhaseTool when the LLM explicitly signals phase
completion. Also called by the auto-advance safety net when
``steps_in_phase >= auto_advance_after_steps``.
Returns None if no phase_policy is set (no-op).
"""
if self._phase_policy is None or self._current_phase is None:
return None
from agentkit.core.phase import PhaseState
nxt = PhaseState.next_of(self._current_phase)
if nxt is None:
# Already at DELIVERY — return None to signal no transition.
return None
previous = self._current_phase
self._current_phase = nxt
self._steps_in_phase = 0
logger.info(
"Phase transition: %s%s",
previous.value,
nxt.value,
)
return nxt
@property
def current_phase(self) -> "PhaseState | None":
"""Current phase (None if no phase_policy set)."""
return self._current_phase
def _maybe_auto_advance(self) -> bool:
"""Auto-advance phase if step budget exhausted. Returns True if advanced."""
if self._phase_policy is None or self._current_phase is None:
return False
threshold = self._phase_policy.auto_advance_after_steps
if threshold is None:
return False
if self._steps_in_phase >= threshold:
self.advance_phase()
return True
return False
def _check_phase_permission(
self, tool_name: str, arguments: dict[str, Any]
) -> dict[str, Any] | None:
"""Return None if tool is allowed; return a structured error dict if blocked.
The error dict replaces what `_execute_tool` would have returned
the loop continues, so the LLM can react to the rejection (call
AdvancePhaseTool or pick a different tool).
Also applies the bash_command_filter for `bash` tool calls.
"""
if self._phase_policy is None or self._current_phase is None:
return None
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
return {
"error": "phase_violation",
"message": (
f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. "
f"Call `advance_phase` to move to the next phase."
),
"current_phase": self._current_phase.value,
"tool": tool_name,
"is_error": True,
}
# Bash command filter (only applies to shell tool — registered as "shell").
if tool_name == "shell":
command = str(arguments.get("command", ""))
if not self._phase_policy.is_bash_command_allowed(command, self._current_phase):
return {
"error": "phase_violation",
"message": (
f"Bash command blocked in {self._current_phase.value} phase "
f"(filesystem-mutating operations not allowed during "
f"planning/verification). Command: {command[:100]}"
),
"current_phase": self._current_phase.value,
"tool": tool_name,
"is_error": True,
}
return None
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
"""检测重复工具调用模式。
@ -498,6 +604,14 @@ class ReActEngine:
if cancellation_token is not None:
cancellation_token.check()
# U3/G6: phase auto-advance safety net.
# Incremented per step (LLM call), not per tool_call. When
# auto_advance_after_steps is set, advance the phase after
# the LLM has been stuck in the same phase for N steps.
if self._phase_policy is not None:
self._steps_in_phase += 1
self._maybe_auto_advance()
# Think: 调用 LLM
llm_start = time.monotonic()
response = await self._llm_gateway.chat(
@ -1148,6 +1262,11 @@ class ReActEngine:
if cancellation_token is not None:
cancellation_token.check()
# U3/G6: phase auto-advance safety net (mirrors _execute_loop).
if self._phase_policy is not None:
self._steps_in_phase += 1
self._maybe_auto_advance()
# 超时检查
if effective_timeout > 0:
elapsed = time.monotonic() - _stream_start
@ -2069,6 +2188,20 @@ class ReActEngine:
self, tool_name: str, arguments: dict[str, Any], tools: list[Tool]
) -> dict:
"""执行工具调用,处理成功和失败情况"""
# U3/G6: phase enforcement — check before dispatch. If the tool is
# blocked, return a structured error instead of dispatching. The loop
# still counts this as a step (the LLM gets to react to the rejection).
# `advance_phase` tool bypasses the check (it's the LLM's escape hatch).
if tool_name != "advance_phase":
block = self._check_phase_permission(tool_name, arguments)
if block is not None:
logger.info(
"Phase violation: tool %r blocked in %s phase",
tool_name,
self._current_phase.value if self._current_phase else "?",
)
return block
tool = self._find_tool(tool_name, tools)
if tool is None:
error_msg = f"Tool '{tool_name}' not found"

View File

@ -30,6 +30,7 @@ from typing import Any
from agentkit.core.config_driven import ConfigDrivenAgent
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.llm.gateway import LLMGateway
from agentkit.orchestrator.rollback import RollbackExecutor
from .expert import Expert
from .plan import (
@ -72,12 +73,17 @@ class TeamOrchestrator:
MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
# G9/U4: RollbackExecutor default timeout for validation_command / rollback_command.
# Override via constructor `rollback_timeout` from `rollback.default_timeout` config.
DEFAULT_ROLLBACK_TIMEOUT = 30.0
def __init__(
self,
team: ExpertTeam,
max_concurrent_phases: int | None = None,
checkpoint: Any = None,
workspace_root: str | None = None,
rollback_timeout: float | None = None,
) -> None:
self._team = team
# Track temporary agent names created for context isolation (KTD3)
@ -93,6 +99,10 @@ class TeamOrchestrator:
self._phase_semaphore = asyncio.Semaphore(limit)
# U7: Pipeline checkpoint for crash recovery
self._checkpoint = checkpoint
# G9/U4: workspace_root drives RollbackExecutor cwd; rollback_timeout drives its timeout.
# Both default to no-op-friendly values so existing call sites behave identically.
self._workspace_root = workspace_root
self._rollback_timeout = rollback_timeout or self.DEFAULT_ROLLBACK_TIMEOUT
async def execute(self, task: str) -> dict[str, Any]:
"""Execute a task in pipeline mode.
@ -262,8 +272,23 @@ class TeamOrchestrator:
else:
phase_results[ph.id] = result
# G9/U4: opt-in rollback (KTD6) + checkpoint ordering (R21).
# When phase configures both validation_command and rollback_command:
# 1. run validation_command — if it passes, treat phase as recoverable, save checkpoint
# 2. if validation fails, run rollback_command
# 3. if rollback passes (exit 0), save checkpoint
# 4. if rollback fails, skip checkpoint (R21 — avoid persisting broken state)
# When neither command is set, behavior is unchanged (existing save).
should_save_checkpoint = True
if (
ph.validation_command
and ph.rollback_command
and isinstance(result, (Exception, asyncio.CancelledError))
):
should_save_checkpoint = await self._run_phase_rollback(plan, ph)
# U7: Save checkpoint after phase finalizes (success or failure)
if self._checkpoint is not None:
if should_save_checkpoint and self._checkpoint is not None:
try:
await self._checkpoint.save(plan.id, ph, plan.status.value)
except Exception as e:
@ -393,9 +418,7 @@ class TeamOrchestrator:
# PENDING phases remain PENDING — will be executed by _run_pipeline
# P2 #8: Restore debate count so MAX_DEBATES limit holds after resume
self._debate_count = sum(
1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE
)
self._debate_count = sum(1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE)
logger.info(
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
@ -688,9 +711,9 @@ class TeamOrchestrator:
and prev_phase.result
):
# U4: Resolve offloaded content from workspace
collaboration_outputs[contract.from_expert] = (
await self._read_dependency_output(prev_phase)
)
collaboration_outputs[
contract.from_expert
] = await self._read_dependency_output(prev_phase)
break
# Emit expert_step event
@ -1809,6 +1832,75 @@ class TeamOrchestrator:
# Recursively mark their dependents
await self._mark_dependents_failed(ph.id, plan, phase_results)
async def _run_phase_rollback(self, plan: TeamPlan, ph: PlanPhase) -> bool:
"""G9/U4: run validation_command + rollback_command for a failed phase.
Returns True if checkpoint save should proceed (R21 ordering).
- Validation passes save checkpoint (phase state recoverable)
- Validation fails, rollback passes save checkpoint (rolled back state)
- Validation fails, rollback fails skip checkpoint (broken state)
- Subprocess spawn failure or timeout skip checkpoint
"""
executor = RollbackExecutor(
working_dir=self._workspace_root,
timeout=self._rollback_timeout,
)
await self._broadcast_event(
"phase_rollback_started",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"validation_command": ph.validation_command,
"rollback_command": ph.rollback_command,
},
)
# ponytail: validate first; if validation passes, rollback is skipped (no need).
validation = await executor.validate(ph.validation_command or "")
if validation.passed:
await self._broadcast_event(
"phase_rollback_completed",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"rollback_executed": False,
"validation_passed": True,
},
)
return True
rollback = await executor.execute(ph.rollback_command or "")
if rollback.passed:
await self._broadcast_event(
"phase_rollback_completed",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"rollback_executed": True,
"validation_passed": False,
"rollback_stdout": rollback.stdout,
},
)
return True
logger.error(
f"Rollback failed for phase {ph.id} ({ph.name}): exit={rollback.exit_code} stderr={rollback.stderr}"
)
await self._broadcast_event(
"phase_rollback_failed",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"validation_passed": False,
"rollback_exit_code": rollback.exit_code,
"rollback_stderr": rollback.stderr,
},
)
return False
async def _synthesize_results(
self, lead: Expert, task: str, completed_phases: list[PlanPhase]
) -> dict[str, Any]:

View File

@ -182,6 +182,11 @@ class PlanPhase:
collaboration_contracts: list[CollaborationContract] = field(default_factory=list)
rework_count: int = 0
review_feedback: str | None = None
# G9/U4: opt-in rollback fields. When unset, no rollback executes (KTD6).
# validation_command runs first; if it fails, rollback_command runs.
# canonical rollback pattern: `git checkout <specific_files>`.
validation_command: str | None = None
rollback_command: str | None = None
def to_dict(self) -> dict[str, Any]:
"""序列化为字典"""
@ -192,7 +197,7 @@ class PlanPhase:
result_str = self.result.get("content", str(self.result))
else:
result_str = str(self.result)
return {
out: dict[str, Any] = {
"id": self.id,
"name": self.name,
"assigned_expert": self.assigned_expert,
@ -206,6 +211,12 @@ class PlanPhase:
"rework_count": self.rework_count,
"review_feedback": self.review_feedback,
}
# G9/U4: only include new keys when set, to preserve pre-change dict shape (KTD6).
if self.validation_command is not None:
out["validation_command"] = self.validation_command
if self.rollback_command is not None:
out["rollback_command"] = self.rollback_command
return out
@classmethod
def from_dict(cls, data: dict[str, Any]) -> PlanPhase:
@ -230,6 +241,8 @@ class PlanPhase:
collaboration_contracts=contracts,
rework_count=data.get("rework_count", 0),
review_feedback=data.get("review_feedback"),
validation_command=data.get("validation_command"),
rollback_command=data.get("rollback_command"),
)

View File

@ -203,6 +203,9 @@ class LLMConfig:
model_aliases: dict[str, str] = field(default_factory=dict)
fallbacks: dict[str, list[str]] = field(default_factory=dict)
cache: CacheConfig | None = None
# G4/U1: Auxiliary model alias for cost-sensitive tasks (e.g. summarization).
# Resolved via existing model_aliases mechanism. None = use main model only.
auxiliary_model: str | None = None
@classmethod
def from_dict(cls, data: dict) -> "LLMConfig":
@ -254,4 +257,5 @@ class LLMConfig:
model_aliases=data.get("model_aliases", {}),
fallbacks=data.get("fallbacks", {}),
cache=cache,
auxiliary_model=data.get("auxiliary_model"),
)

View File

@ -0,0 +1,97 @@
"""RollbackExecutor — 阶段失败后执行回滚命令 (G9/U4)
复用 VerificationLoop asyncio.create_subprocess_shell 模式 (KTD7)
绕过 ShellTool避免 confirm_callback `git checkout` 的拦截
设计依据
- KTD6: 回滚是 opt-in 行为未配置 rollback_command 时不会执行
- KTD7: 不走 ShellTool避免 _is_dangerous 触发 confirm_callback
- R21: checkpoint.save 仅在回滚校验通过后调用
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class RollbackResult:
"""单次 subprocess 执行结果"""
passed: bool
exit_code: int
stdout: str
stderr: str
command: str
class RollbackExecutor:
"""执行 validation_command / rollback_command 的子进程封装
VerificationLoop 同构但语义不同
- validate(): 返回 passed=False 表示校验失败需要触发 rollback
- execute(): 返回 passed=False 表示回滚本身失败需跳过 checkpoint.save (R21)
"""
def __init__(
self,
working_dir: str | None = None,
timeout: float = 30.0,
) -> None:
self._working_dir = working_dir
self._timeout = timeout
async def _run(self, command: str) -> RollbackResult:
try:
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=self._working_dir,
)
except Exception as e: # noqa: BLE001 - subprocess spawn failure surface
return RollbackResult(
passed=False,
exit_code=-1,
stdout="",
stderr=f"Failed to spawn command: {e}",
command=command,
)
try:
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self._timeout)
except asyncio.TimeoutError:
try:
proc.kill()
except ProcessLookupError:
pass
await proc.wait()
return RollbackResult(
passed=False,
exit_code=-1,
stdout="",
stderr=f"Command timed out after {self._timeout}s: {command}",
command=command,
)
out_str = stdout.decode("utf-8", errors="replace") if stdout else ""
err_str = stderr.decode("utf-8", errors="replace") if stderr else ""
return RollbackResult(
passed=proc.returncode == 0,
exit_code=proc.returncode if proc.returncode is not None else -1,
stdout=out_str,
stderr=err_str,
command=command,
)
async def validate(self, command: str) -> RollbackResult:
"""运行 validation_commandpassed=False 表示需要触发 rollback"""
return await self._run(command)
async def execute(self, command: str) -> RollbackResult:
"""运行 rollback_commandpassed=False 表示回滚本身失败 (R21)"""
return await self._run(command)

View File

@ -0,0 +1,199 @@
"""G7/U3 — Three-tier fallback chain (main → Recovery → Emergency).
Wired at chat.py REST send_message endpoint. Composes U2's EmergencyRules
with existing ReflexionEngine for the Recovery layer.
Scope (KTD5): Only the chat REST path is wrapped. CLI / ReWOO / Reflexion
internal ReAct calls are NOT wrapped (would create recursive loop).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
from agentkit.core.exceptions import (
LLMProviderError,
LoopDetectedError,
TaskCancelledError,
TaskTimeoutError,
)
from agentkit.core.fallback import EmergencyError, EmergencyRules
from agentkit.core.react import ReActEngine, ReActResult
from agentkit.core.reflexion import ReflexionEngine
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__)
# ReActResult.status values that indicate soft failure → trigger Recovery.
# "success" is the only clean-pass; everything else is fallback-worthy.
_SOFT_FAILURE_STATUSES = frozenset({"empty_fallback", "verify_failed", "timeout"})
@dataclass
class ChatExecutionResult:
"""Wrapper produced by execute_with_fallback_chain.
Carries a ReActResult-like ``output`` field plus an optional
``error_struct`` (set only when Emergency tier fires). The chat
handler reads ``.output`` for the assistant reply and ``.error_struct``
for the optional structured error payload.
"""
output: str
status: str # "success" | "recovered" | "emergency"
error_struct: dict[str, Any] | None = None
trajectory: list[Any] = field(default_factory=list)
total_steps: int = 0
total_tokens: int = 0
fallback_strategy: str | None = None
def _react_to_chat_result(react: ReActResult) -> ChatExecutionResult:
return ChatExecutionResult(
output=react.output,
status="success",
trajectory=react.trajectory,
total_steps=react.total_steps,
total_tokens=react.total_tokens,
fallback_strategy=react.fallback_strategy,
)
def _reflexion_to_chat_result(reflexion_result: Any) -> ChatExecutionResult:
"""Best-effort conversion from ReflexionResult to ChatExecutionResult."""
output = getattr(reflexion_result, "output", None) or getattr(
reflexion_result, "final_answer", ""
)
return ChatExecutionResult(
output=output or "",
status="recovered",
trajectory=getattr(reflexion_result, "trajectory", []) or [],
total_steps=getattr(reflexion_result, "total_steps", 0),
total_tokens=getattr(reflexion_result, "total_tokens", 0),
fallback_strategy="reflexion_recovery",
)
def _to_emergency(exc: Exception, config: dict | None) -> ChatExecutionResult:
emergency: EmergencyError = EmergencyRules.classify(exc, config)
return ChatExecutionResult(
output=emergency.to_error_message(),
status="emergency",
error_struct=emergency.to_dict(),
fallback_strategy="emergency",
)
async def execute_with_fallback_chain(
*,
react_engine: ReActEngine,
llm_gateway: LLMGateway,
messages: list[dict[str, str]],
tools: list[Any] | None,
model: str,
agent_name: str,
system_prompt: str | None,
fallback_chain_config: dict | None = None,
) -> ChatExecutionResult:
"""Three-tier fallback chain: Main → Recovery (ReflexionEngine) → Emergency.
KTD5: only this entry point wraps the chain. ReflexionEngine's internal
ReAct call bypasses the chain (no recursive loop possible).
Returns ChatExecutionResult with status:
- "success": main agent succeeded
- "recovered": main failed, ReflexionEngine recovery succeeded
- "emergency": main failed, recovery failed/exhausted, Emergency layer fired
"""
config = fallback_chain_config or {}
recovery_cfg = config.get("recovery", {}) if isinstance(config, dict) else {}
emergency_cfg = config.get("emergency", {}) if isinstance(config, dict) else {}
recovery_enabled = recovery_cfg.get("enabled", True) if isinstance(recovery_cfg, dict) else True
emergency_enabled = (
emergency_cfg.get("enabled", True) if isinstance(emergency_cfg, dict) else True
)
max_reflections = recovery_cfg.get("max_retries", 1) if isinstance(recovery_cfg, dict) else 1
# ── Tier 1: Main ──────────────────────────────────────────────
main_exc: Exception | None = None
try:
result = await react_engine.execute(
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
system_prompt=system_prompt,
)
if result.status == "success":
return _react_to_chat_result(result)
# Soft failure (empty_fallback / verify_failed / timeout) → trigger Recovery
if result.status in _SOFT_FAILURE_STATUSES:
main_exc = AgentSoftFailureError(
f"main agent status={result.status}: {result.output[:200]}"
)
else:
# Unknown status — treat as success-like (don't trigger recovery)
return _react_to_chat_result(result)
except TaskCancelledError:
# KTD3: TaskCancelledError propagates as-is, NOT routed to Emergency.
raise
except (TaskTimeoutError, LoopDetectedError, LLMProviderError) as exc:
main_exc = exc
except Exception as exc: # noqa: BLE001 - last-resort catch for Emergency routing
main_exc = exc
# ── Tier 2: Recovery (ReflexionEngine) ────────────────────────
if recovery_enabled and main_exc is not None:
try:
reflexion = ReflexionEngine(
llm_gateway=llm_gateway,
max_reflections=max_reflections,
)
recovery_result = await reflexion.execute(
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
system_prompt=system_prompt,
)
# Recovery succeeds if Reflexion reports success or produces output.
recovery_status = getattr(recovery_result, "status", "")
if recovery_status == "success" or getattr(recovery_result, "output", None):
return _reflexion_to_chat_result(recovery_result)
logger.warning(
f"Recovery layer did not succeed (status={recovery_status}), "
f"falling through to Emergency"
)
except TaskCancelledError:
raise
except Exception as recovery_exc: # noqa: BLE001
logger.warning(f"Recovery layer raised: {recovery_exc}; falling through to Emergency")
# ── Tier 3: Emergency ─────────────────────────────────────────
if not emergency_enabled:
# Re-raise original exception if Emergency disabled.
if main_exc is not None:
raise main_exc
# No exception but no success either — synthesise an emergency-style result.
return ChatExecutionResult(
output="Agent 未返回有效结果且 Emergency 层已禁用。",
status="emergency",
fallback_strategy="emergency_disabled",
)
# main_exc may be None if main returned soft-failure status without raising.
# Synthesize a generic exception for Emergency classification.
exc_for_emergency = main_exc or AgentSoftFailureError("soft failure without exception")
return _to_emergency(exc_for_emergency, config)
class AgentSoftFailureError(Exception):
"""Internal marker — main agent returned a soft-failure status without raising.
Used to feed the Emergency classifier when main status was e.g.
``empty_fallback`` (no exception raised, but result not usable).
Classified as ``internal_error`` by EmergencyRules (generic fallback).
"""

View File

@ -119,6 +119,11 @@ class ServerConfig:
prompt_cache: dict[str, Any] | None = None,
streaming: dict[str, Any] | None = None,
verification: dict[str, Any] | None = None,
rollback: dict[str, Any] | None = None,
fallback_chain: dict[str, Any] | None = None,
# G6/U2: PLAN_EXEC phase policy config (opt-in — None = disabled).
# Parsed via PhasePolicy.policy_from_config() at chat.py wiring time.
plan_exec: dict[str, Any] | None = None,
on_change: Callable[["ServerConfig"], None] | None = None,
):
self.host = host
@ -153,6 +158,16 @@ class ServerConfig:
# U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1)
# verification_enabled=False 时此配置无效
self.verification = verification or {}
# G9/U4: rollback.default_timeout 控制 RollbackExecutor subprocess 超时
# PlanPhase.rollback_command 未设置时此配置无效 (KTD6 opt-in)
self.rollback = rollback or {}
# G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries}
# controls three-tier chain at chat.py REST send_message (KTD5).
self.fallback_chain = fallback_chain or {}
# G6/U2: plan_exec phase policy config (opt-in — empty dict = disabled).
# Resolved to PhasePolicy via agentkit.core.phase.policy_from_config()
# at chat.py WebSocket wiring time (U4).
self.plan_exec = plan_exec or {}
self.on_change = on_change
# Config watching state
@ -240,6 +255,12 @@ class ServerConfig:
prompt_cache_data = data.get("prompt_cache", {})
streaming_data = data.get("streaming", {})
verification_data = data.get("verification", {})
# G9/U4: rollback 配置 (从 YAML 读取opt-in)
rollback_data = data.get("rollback", {})
# G7/U3: fallback_chain 配置 (从 YAML 读取)
fallback_chain_data = data.get("fallback_chain", {})
# G6/U2: plan_exec phase policy 配置 (从 YAML 读取, opt-in)
plan_exec_data = data.get("plan_exec", {})
return cls(
host=server.get("host", "0.0.0.0"),
@ -271,6 +292,9 @@ class ServerConfig:
prompt_cache=prompt_cache_data,
streaming=streaming_data,
verification=verification_data,
rollback=rollback_data,
fallback_chain=fallback_chain_data,
plan_exec=plan_exec_data,
)
@staticmethod
@ -320,6 +344,9 @@ class ServerConfig:
model_aliases=model_aliases,
fallbacks=data.get("fallbacks", {}),
cache=cache_config,
# G4/U1: auxiliary model alias for cost-sensitive summarization.
# Resolved via model_aliases; None = no auxiliary routing.
auxiliary_model=data.get("auxiliary_model"),
)
@staticmethod

View File

@ -25,10 +25,13 @@ from fastapi.responses import FileResponse
from pydantic import BaseModel
from agentkit.chat.skill_routing import ExecutionMode
from agentkit.core.phase import PhasePolicy, default_policy, policy_from_config
from agentkit.core.protocol import CancellationToken
from agentkit.core.react import ReActEngine
from agentkit.server._fallback_chain import execute_with_fallback_chain
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
from agentkit.tools.advance_phase import AdvancePhaseTool
logger = logging.getLogger(__name__)
@ -46,6 +49,8 @@ class CreateSessionRequest(BaseModel):
class SendMessageRequest(BaseModel):
content: str
role: str = "user"
# Optional execution mode override. "plan_exec" → 501 (KTD4: WebSocket only).
execution_mode: str | None = None
class SessionResponse(BaseModel):
@ -582,6 +587,13 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
if session.status == SessionStatus.CLOSED:
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
# KTD4: PLAN_EXEC is wired only at the WebSocket path. REST raises 501.
if request.execution_mode == "plan_exec":
raise HTTPException(
status_code=501,
detail="PLAN_EXEC via REST not yet supported; use WebSocket",
)
# Append user message
await sm.append_message(
session_id=session_id,
@ -610,7 +622,15 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
system_prompt = getattr(agent, "_system_prompt", None) or (
agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
)
result = await react_engine.execute(
# G7/U3: Three-tier fallback chain (main → Recovery → Emergency).
# Wired only here (KTD5); CLI / ReWOO / Reflexion internal ReAct bypass.
server_config = getattr(req.app.state, "server_config", None)
fallback_chain_cfg = (
getattr(server_config, "fallback_chain", None) if server_config else None
)
chat_result = await execute_with_fallback_chain(
react_engine=react_engine,
llm_gateway=req.app.state.llm_gateway,
messages=chat_messages,
tools=tools,
model=agent.get_model()
@ -618,16 +638,26 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
else getattr(agent, "_llm_model", "default"),
agent_name=agent.name,
system_prompt=system_prompt,
fallback_chain_config=fallback_chain_cfg,
)
# Append assistant reply
assistant_msg = await sm.append_message(
session_id=session_id,
role=MessageRole.ASSISTANT,
content=result.output if hasattr(result, "output") else str(result),
content=chat_result.output,
agent_name=agent.name,
)
return _message_to_response(assistant_msg)
response = _message_to_response(assistant_msg)
# Attach structured error payload when Emergency tier fired.
if chat_result.error_struct is not None:
response_dict = (
response.model_dump() if hasattr(response, "model_dump") else dict(response)
)
response_dict["error_struct"] = chat_result.error_struct
response_dict["fallback_status"] = chat_result.status
return response_dict
return response
except Exception as e:
logger.error(f"Chat execution error for session {session_id}: {e}")
@ -1060,21 +1090,73 @@ async def _handle_chat_message(
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
return
# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
# currently fall back to REACT with a warning.
if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
# U4/G6: PLAN_EXEC — build PhasePolicy from server config (KTD4: WebSocket only).
# KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC — phase policy and
# fallback chain are mutually exclusive. PLAN_EXEC uses its own engine.
phase_policy: PhasePolicy | None = None
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
server_config = getattr(websocket.app.state, "server_config", None)
plan_exec_cfg = getattr(server_config, "plan_exec", None) or {}
if plan_exec_cfg.get("enabled", True) is False:
# Explicit opt-out → fall back to REACT.
logger.info(
"PLAN_EXEC disabled by config (plan_exec.enabled=False), "
"falling back to REACT for session %s",
session_id,
)
else:
try:
phase_policy = policy_from_config(plan_exec_cfg)
if phase_policy is None:
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
phase_policy = default_policy()
except Exception as e:
logger.error(
"PLAN_EXEC phase policy construction failed for session %s: %s",
session_id,
e,
)
await websocket.send_json(
{
"type": "error",
# Truncate to 200 chars to match nearby error paths and
# avoid leaking config internals (see chat.py:1090, 1320).
"data": {"message": f"phase policy error: {str(e)[:200]}"},
}
)
return
# Handle advanced execution modes: REWOO/REFLEXION/TEAM_COLLAB
# still fall back to REACT with a warning. PLAN_EXEC is handled above.
if routing.execution_mode not in (
ExecutionMode.REACT,
ExecutionMode.SKILL_REACT,
ExecutionMode.PLAN_EXEC,
):
logger.warning(
f"Execution mode {routing.execution_mode.value} not yet supported "
f"in chat WebSocket, falling back to REACT"
)
# Execute Agent with streaming
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
react_engine = getattr(agent, "_react_engine", None)
if react_engine is None:
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization).
# PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the
# agent's _react_engine — it has no policy).
if phase_policy is not None:
react_engine = ReActEngine(
llm_gateway=websocket.app.state.llm_gateway,
phase_policy=phase_policy,
)
# Register AdvancePhaseTool bound to this engine (LLM's escape hatch).
advance_phase_tool = AdvancePhaseTool(engine=react_engine)
routing.tools = list(routing.tools) + [advance_phase_tool]
else:
react_engine.reset()
react_engine = getattr(agent, "_react_engine", None)
if react_engine is None:
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
else:
react_engine.reset()
# Create confirmation handler that sends request to frontend and waits for reply
# Use the same dict object — do NOT use `or {}` because an empty dict is falsy
@ -1130,6 +1212,9 @@ async def _handle_chat_message(
try:
final_content = ""
token_buffer: list[str] = []
# Track phase transitions for phase_changed events (PLAN_EXEC only).
# For non-PLAN_EXEC modes, current_phase is always None → no events.
prev_phase = react_engine.current_phase
async for event in react_engine.execute_stream(
messages=chat_messages,
tools=routing.tools,
@ -1207,6 +1292,22 @@ async def _handle_chat_message(
}
)
# U4/G6: emit phase_changed event when the phase state machine
# transitions (PLAN_EXEC only). For non-PLAN_EXEC modes,
# current_phase is always None → this branch never fires.
curr_phase = react_engine.current_phase
if curr_phase != prev_phase:
await websocket.send_json(
{
"type": "phase_changed",
"data": {
"phase": curr_phase.value if curr_phase else None,
"previous": prev_phase.value if prev_phase else None,
},
}
)
prev_phase = curr_phase
# Append assistant reply to session
if final_content:
await sm.append_message(

View File

@ -18,6 +18,8 @@ from agentkit.tools.memory_tool import MemoryTool
from agentkit.tools.web_search import WebSearchTool
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
from agentkit.tools.search import ToolSearchIndex
from agentkit.tools.file_read import ReadFileTool
from agentkit.tools.advance_phase import AdvancePhaseTool
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
try:
@ -52,4 +54,6 @@ __all__ = [
"OutputParser",
"ParsedOutput",
"ErrorType",
"ReadFileTool",
"AdvancePhaseTool",
]

View File

@ -0,0 +1,82 @@
"""AdvancePhaseTool — LLM-driven phase transition (G6, KTD6).
Registered alongside other tools when ReActEngine has a phase_policy set.
The LLM calls this tool to signal "I'm done planning, move to building".
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from agentkit.tools.base import Tool
if TYPE_CHECKING:
from agentkit.core.react import ReActEngine
logger = logging.getLogger(__name__)
class AdvancePhaseTool(Tool):
"""Tool that advances the ReActEngine's current phase.
KTD6: LLM-driven phase transitions. Auto-advance is opt-in via
``plan_exec.auto_advance_after_steps``; this tool is the manual path.
The tool holds a weak reference to the engine (via bound method
``engine.advance_phase``) registered only when phase_policy is set.
"""
def __init__(
self,
engine: "ReActEngine",
name: str = "advance_phase",
description: str | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description
or (
"Advance the PLAN_EXEC phase state machine to the next phase "
"(Planning → Building → Verification → Delivery). Call this "
"when you have finished the current phase's work and are ready "
"to move on. Returns the new phase name or an error if you "
"are already at the final (Delivery) phase."
),
input_schema={
"type": "object",
"properties": {},
"additionalProperties": False,
},
version=version,
tags=tags or ["phase", "control"],
)
self._engine = engine
async def execute(self, **kwargs) -> dict[str, Any]:
# Capture previous phase before transition (engine is single-threaded per request).
previous = self._engine.current_phase
new_phase = self._engine.advance_phase()
if new_phase is None:
# Either no policy set, or already at DELIVERY.
current = self._engine.current_phase
if current is None:
return {
"is_error": True,
"error": "no_phase_policy",
"message": "No phase policy is set — advance_phase is a no-op.",
}
return {
"is_error": True,
"error": "already_at_final_phase",
"message": (f"Already at final phase ({current.value}). Cannot advance further."),
"current_phase": current.value,
}
return {
"is_error": False,
"previous_phase": previous.value if previous else "",
"current_phase": new_phase.value,
"message": f"Phase advanced to {new_phase.value}.",
}

View File

@ -0,0 +1,262 @@
"""ReadFileTool — file reading with optional symbol-level sharding (G5, R22/R23).
Backward compatible with the pre-existing `_FakeTool` benchmark shape when
`symbol=None`, returns the full file content. When `symbol="foo"`, returns
the line range of the first matching symbol via `SymbolExtractor`.
KTD2 (Wave 3 plan): dedicated tool, does NOT extend ShellTool keeps the
file-reading contract clean and gives the LLM a focused schema.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from agentkit.tools.base import Tool
from agentkit.tools.symbol_extractor import (
SymbolSpan,
extract_symbols_from_file,
language_for_extension,
)
logger = logging.getLogger(__name__)
class ReadFileTool(Tool):
"""Read a file from the filesystem, optionally sliced to a single symbol.
Tool name `read_file` matches the reserved entry in
`core/react.py:_DEFAULT_CORE_TOOLS` (which previously had no real
implementation only `_FakeTool` stubs in `cli/benchmark.py`).
Backward-compat contract: `symbol=None` returns the full file content,
matching the shape `{"path": ...}` that downstream callers (benchmark,
phase whitelist) already expect.
"""
def __init__(
self,
name: str = "read_file",
description: str | None = None,
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description
or (
"Read a file from the filesystem. By default returns the full file "
"content. Pass `symbol` (function/class/struct name) to slice to just "
"that symbol's line range — saves context when you only need one "
"function from a large file. Pass `start_line`/`end_line` for manual "
"slicing. If `symbol` is set but not found, returns the available "
"symbol names so you can retry."
),
input_schema=input_schema or self._default_input_schema(),
output_schema=output_schema or self._default_output_schema(),
version=version,
tags=tags or ["io", "file", "read"],
)
@staticmethod
def _default_input_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to read (absolute or relative to cwd).",
},
"symbol": {
"type": "string",
"description": (
"Optional: name of a function/class/struct/method to slice to. "
"When set, returns only the line range of the first matching "
"symbol. Supported languages: py, ts/tsx, js/jsx, go, rs, java."
),
},
"start_line": {
"type": "integer",
"description": "Optional 1-based start line for manual slicing. Overrides `symbol`.",
"minimum": 1,
},
"end_line": {
"type": "integer",
"description": "Optional 1-based end line (inclusive) for manual slicing. Overrides `symbol`.",
"minimum": 1,
},
},
"required": ["path"],
"additionalProperties": False,
}
@staticmethod
def _default_output_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"content": {"type": "string"},
"path": {"type": "string"},
"start_line": {"type": "integer"},
"end_line": {"type": "integer"},
"symbol": {"type": "string"},
"available_symbols": {
"type": "array",
"items": {"type": "string"},
"description": "Populated when `symbol` is set but not found.",
},
"note": {"type": "string"},
"is_error": {"type": "boolean"},
"error": {"type": "string"},
},
}
async def execute(self, **kwargs) -> dict[str, Any]:
raw_path = kwargs.get("path")
if not raw_path:
return self._error("`path` is required")
path = Path(raw_path)
if not path.is_absolute():
path = path.resolve()
symbol = kwargs.get("symbol")
start_line = kwargs.get("start_line")
end_line = kwargs.get("end_line")
# Validate/sanitize line overrides.
if start_line is not None and (not isinstance(start_line, int) or start_line < 1):
return self._error(f"`start_line` must be a positive integer, got {start_line!r}")
if end_line is not None and (not isinstance(end_line, int) or end_line < 1):
return self._error(f"`end_line` must be a positive integer, got {end_line!r}")
if start_line is not None and end_line is not None and end_line < start_line:
return self._error(f"`end_line` ({end_line}) must be >= `start_line` ({start_line})")
# Filesystem checks.
if not path.exists():
return self._error(f"File not found: {path}", path=str(path))
if path.is_dir():
return self._error(f"Path is a directory, not a file: {path}", path=str(path))
try:
content = path.read_text(encoding="utf-8", errors="replace")
except PermissionError as e:
return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
except OSError as e:
return self._error(f"Failed to read {path}: {e}", path=str(path))
lines = content.splitlines()
total_lines = len(lines)
# Manual slicing takes precedence over symbol (per plan U1 Approach).
if start_line is not None or end_line is not None:
s = max(1, start_line or 1)
e = min(total_lines, end_line or total_lines)
sliced = "\n".join(lines[s - 1 : e])
return {
"content": sliced,
"path": str(path),
"start_line": s,
"end_line": e,
"total_lines": total_lines,
"is_error": False,
}
# Symbol slicing.
if symbol:
ext = path.suffix.lower()
language = language_for_extension(ext)
if not language:
# Unsupported extension: return full file with note (per plan U1 Edge case).
return {
"content": content,
"path": str(path),
"start_line": 1,
"end_line": total_lines,
"total_lines": total_lines,
"note": f"symbol extraction not supported for {ext or 'unknown extension'}",
"is_error": False,
}
spans, _lang = extract_symbols_from_file(path)
# Re-extract using the content we already read so we don't read the file twice.
if not spans:
# Try extraction from in-memory content (path-based extraction may
# have failed silently on OSError; we already read it successfully).
from agentkit.tools.symbol_extractor import get_extractor
extractor = get_extractor(language)
if extractor is not None:
spans = extractor.extract_symbols(content, language)
match = _find_symbol(spans, symbol)
if match is None:
available = sorted({s.name for s in spans})
return {
"content": "",
"path": str(path),
"symbol": symbol,
"available_symbols": available,
"is_error": False,
"note": (
f"Symbol {symbol!r} not found in {path.name}. "
f"Available: {', '.join(available) if available else '(none)'}"
),
}
s = match.start_line
e = min(match.end_line, total_lines)
sliced = "\n".join(lines[s - 1 : e])
return {
"content": sliced,
"path": str(path),
"symbol": symbol,
"symbol_kind": match.kind,
"start_line": s,
"end_line": e,
"total_lines": total_lines,
"is_error": False,
}
# Default: full file (characterization baseline — matches _FakeTool shape).
return {
"content": content,
"path": str(path),
"start_line": 1,
"end_line": total_lines,
"total_lines": total_lines,
"is_error": False,
}
@staticmethod
def _error(
message: str, *, path: str | None = None, detail: str | None = None
) -> dict[str, Any]:
result: dict[str, Any] = {
"content": "",
"is_error": True,
"error": message,
}
if path is not None:
result["path"] = path
if detail is not None:
result["detail"] = detail
return result
def _find_symbol(spans: list[SymbolSpan], name: str) -> SymbolSpan | None:
"""Find the first symbol matching `name`. Case-sensitive.
ponytail: linear scan is fine for typical file symbol counts (<100). The
extractor already returns symbols sorted by start_line; first match wins
for ambiguous overloads (e.g., Python classes with same name in different
modules not relevant within one file).
"""
for span in spans:
if span.name == name:
return span
return None

View File

@ -0,0 +1,278 @@
"""Symbol extraction — locate code symbols (functions/classes/structs) by name.
KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex
for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The
`SymbolExtractor` protocol is the upgrade seam a future TreeSitterSymbolExtractor
can replace RegexSymbolExtractor behind the same interface.
ponytail: regex extractor covers ~80% case (top-level function/class/struct
declarations). Ceiling: misses nested signatures inside JSX/TSX generics,
multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter.
"""
from __future__ import annotations
import ast
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Protocol, runtime_checkable
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class SymbolSpan:
"""A located symbol — name, kind, and 1-based inclusive line range."""
name: str
kind: str # "function" | "class" | "method" | "struct" | "impl"
start_line: int # 1-based, inclusive
end_line: int # 1-based, inclusive
@runtime_checkable
class SymbolExtractor(Protocol):
"""Protocol for symbol extractors — runtime_checkable for isinstance/issubclass."""
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
"""Return all symbols found in `content`.
`language` is the file extension without leading dot (e.g. "py", "ts").
Implementations must never raise on extraction failure return [] on
parse errors and let the caller decide the fallback (full-file read).
"""
...
# ---------------------------------------------------------------------------
# Python — stdlib ast
# ---------------------------------------------------------------------------
class AstSymbolExtractor:
"""Python symbol extractor using the stdlib `ast` module.
Captures top-level FunctionDef/AsyncFunctionDef/ClassDef and methods/nested
functions inside classes. The end_line is the last line of the node's
source segment (decorator-inclusive).
"""
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
if language != "py":
return []
try:
tree = ast.parse(content)
except SyntaxError as e:
logger.debug("ast.parse failed: %s", e)
return []
lines = content.splitlines()
spans: list[SymbolSpan] = []
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
kind = "method" if _is_method(node) else "function"
spans.append(_span_from_node(node, kind, lines))
elif isinstance(node, ast.ClassDef):
spans.append(_span_from_node(node, "class", lines))
return spans
def _is_method(node: ast.AST) -> bool:
"""A FunctionDef is a method if its parent is a ClassDef.
`ast.walk` doesn't expose parentage, so we approximate by checking the
node's col_offset == 4 (indented inside a class body). ponytail: this
misses methods in deeply nested classes ceiling noted; upgrade path =
ast.NodeVisitor with parent tracking.
"""
return getattr(node, "col_offset", 0) > 0
def _span_from_node(
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
kind: str,
lines: list[str],
) -> SymbolSpan:
# ast line numbers are 1-based; start at decorator if present (lineno points
# to the def/class keyword, decorators are above). Use node.lineno for start
# so the returned range matches what the user sees at the def keyword.
start = node.lineno
# node.end_lineno is the last line of the node body (None on old Pythons).
end = node.end_lineno or start
# Clamp to actual file length (defensive — ast should not exceed, but
# malformed files with no trailing newline can confuse end_lineno).
if end > len(lines):
end = len(lines)
return SymbolSpan(name=node.name, kind=kind, start_line=start, end_line=end)
# ---------------------------------------------------------------------------
# Regex extractor — TS/JS/Go/Rust/Java
# ---------------------------------------------------------------------------
# Each pattern matches a declaration and captures the symbol name in group 1.
# Patterns use re.MULTILINE so ^ matches line starts.
_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = {
"ts": [
(
"function",
re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE),
),
("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)),
(
"function",
re.compile(
r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
re.MULTILINE,
),
),
],
"js": [
("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)),
("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)),
(
"function",
re.compile(
r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE
),
),
],
"go": [
("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)),
("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)),
],
"rs": [
("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)),
("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)),
("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)),
],
"java": [
(
"function",
re.compile(
r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{",
re.MULTILINE,
),
),
(
"class",
re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE),
),
],
}
class RegexSymbolExtractor:
"""Language-aware regex symbol extractor for TS/JS/Go/Rust/Java.
Returns SymbolSpans whose end_line is approximated by the next blank line
or next-symbol start (whichever comes first). ponytail: this is an
approximation true block-end requires language-aware brace matching.
Ceiling: deeply nested blocks may over-extend the range. Upgrade path =
tree-sitter.
"""
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
patterns = _REGEX_PATTERNS.get(language)
if not patterns:
return []
lines = content.splitlines()
# Collect (line_no, name, kind) tuples first, then compute end_line
# as the line before the next symbol starts (or EOF).
raw_hits: list[tuple[int, str, str]] = []
for kind, pattern in patterns:
for m in pattern.finditer(content):
# Convert match offset to 1-based line number.
line_no = content[: m.start()].count("\n") + 1
raw_hits.append((line_no, m.group(1), kind))
if not raw_hits:
return []
# Deduplicate: same (line_no, name) may appear for overlapping patterns.
seen: set[tuple[int, str]] = set()
unique: list[tuple[int, str, str]] = []
for line_no, name, kind in raw_hits:
key = (line_no, name)
if key in seen:
continue
seen.add(key)
unique.append((line_no, name, kind))
unique.sort(key=lambda x: x[0])
spans: list[SymbolSpan] = []
for i, (start_line, name, kind) in enumerate(unique):
if i + 1 < len(unique):
# End at line before next symbol starts, capped at file length.
end_line = unique[i + 1][0] - 1
else:
end_line = len(lines)
if end_line < start_line:
end_line = start_line
spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line))
return spans
# ---------------------------------------------------------------------------
# Dispatch by file extension
# ---------------------------------------------------------------------------
_EXTENSION_LANGUAGE: dict[str, str] = {
".py": "py",
".ts": "ts",
".tsx": "ts",
".js": "js",
".jsx": "js",
".mjs": "js",
".cjs": "js",
".go": "go",
".rs": "rs",
".java": "java",
}
_DEFAULT_EXTRACTOR = AstSymbolExtractor()
_REGEX_EXTRACTOR = RegexSymbolExtractor()
def language_for_extension(ext: str) -> str:
"""Return the language key for a file extension (with or without leading dot).
Returns "" for unsupported extensions.
"""
if not ext.startswith("."):
ext = "." + ext
return _EXTENSION_LANGUAGE.get(ext.lower(), "")
def get_extractor(language: str) -> SymbolExtractor | None:
"""Return the appropriate extractor for `language`, or None if unsupported."""
if language == "py":
return _DEFAULT_EXTRACTOR
if language in _REGEX_PATTERNS:
return _REGEX_EXTRACTOR
return None
def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]:
"""Read a file and return (symbols, language).
Returns ([], "") if the extension is unsupported or the file cannot be read.
Never raises callers use this for fallback routing.
"""
ext = path.suffix.lower()
language = language_for_extension(ext)
if not language:
return [], ""
try:
content = path.read_text(encoding="utf-8", errors="replace")
except OSError as e:
logger.debug("read failed for %s: %s", path, e)
return [], language
extractor = get_extractor(language)
if extractor is None:
return [], language
return extractor.extract_symbols(content, language), language

View File

@ -0,0 +1,531 @@
"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4).
Per plan U4 Execution note: characterization-first verify that existing
REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning
(no regression). Then add PLAN_EXEC wiring tests.
KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
from agentkit.core.phase import PhaseState
from agentkit.tools.advance_phase import AdvancePhaseTool
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def app_with_chat():
"""Create a FastAPI app with Chat routes and mocked dependencies."""
from fastapi import FastAPI
from agentkit.server.routes.chat import router
app = FastAPI()
app.include_router(router, prefix="/api/v1")
from agentkit.session.manager import SessionManager
from agentkit.session.store import InMemorySessionStore
app.state.session_manager = SessionManager(store=InMemorySessionStore())
app.state.llm_gateway = MagicMock()
app.state.agent_pool = MagicMock()
app.state.server_config = MagicMock()
app.state.server_config.api_key = None
app.state.server_config.plan_exec = {}
return app
@pytest.fixture
def client(app_with_chat):
return TestClient(app_with_chat)
def _make_routing(
execution_mode: ExecutionMode = ExecutionMode.REACT,
tools: list | None = None,
) -> SkillRoutingResult:
"""Build a minimal SkillRoutingResult for testing."""
return SkillRoutingResult(
execution_mode=execution_mode,
tools=tools or [],
clean_content="test message",
model="default",
agent_name="test-agent",
system_prompt=None,
skill_name=None,
)
def _make_websocket_mock(app) -> MagicMock:
"""Build a mock WebSocket with app.state and async send_json."""
ws = MagicMock()
ws.app = app
ws.send_json = AsyncMock()
return ws
def _make_agent_mock() -> MagicMock:
"""Build a mock Agent with _tool_registry and _react_engine."""
agent = MagicMock()
agent.name = "test-agent"
agent._tool_registry = MagicMock()
agent._tool_registry.list_tools.return_value = []
agent._system_prompt = None
# _react_engine is None to force the code path that creates a new engine
agent._react_engine = None
agent.get_model.return_value = "default"
return agent
def _make_session_manager_mock() -> MagicMock:
"""Build a mock SessionManager with async methods."""
sm = MagicMock()
# get_session returns a mock session with agent_name="test-agent"
session = MagicMock()
session.agent_name = "test-agent"
session.status = "active"
sm.get_session = AsyncMock(return_value=session)
sm.get_chat_messages = AsyncMock(return_value=[])
sm.append_message = AsyncMock()
return sm
def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None:
"""Wire up app.state so _handle_chat_message finds the right routing."""
app.state.agent_pool.get_agent.return_value = agent
app.state.request_preprocessor = MagicMock()
app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing)
# ---------------------------------------------------------------------------
# REST — PLAN_EXEC raises 501 (KTD4)
# ---------------------------------------------------------------------------
class TestRestPlanExec501:
def test_rest_plan_exec_returns_501(self, client):
"""REST send_message with execution_mode=plan_exec → 501."""
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
session_id = create_resp.json()["session_id"]
msg_resp = client.post(
f"/api/v1/chat/sessions/{session_id}/messages",
json={"content": "Hello", "execution_mode": "plan_exec"},
)
assert msg_resp.status_code == 501
assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"]
def test_rest_react_mode_still_works(self, client):
"""REST send_message without execution_mode doesn't 501."""
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
session_id = create_resp.json()["session_id"]
# No execution_mode field → should NOT trigger 501.
msg_resp = client.post(
f"/api/v1/chat/sessions/{session_id}/messages",
json={"content": "Hello"},
)
assert msg_resp.status_code != 501
# ---------------------------------------------------------------------------
# Characterization — REWOO still falls back to REACT (no regression)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_rewoo_still_falls_back_to_react_without_phase_policy(app_with_chat):
"""Characterization: REWOO via WebSocket → no phase_policy (falls back to REACT)."""
from agentkit.server.routes import chat as chat_module
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.REWOO)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
ws = _make_websocket_mock(app_with_chat)
captured_engine_kwargs: dict = {}
class _StubEngine:
def __init__(self, **kwargs):
captured_engine_kwargs.update(kwargs)
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = None
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
return
yield # async generator marker
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
# REWOO should NOT build a phase_policy
assert captured_engine_kwargs.get("phase_policy") is None
# ---------------------------------------------------------------------------
# Happy path — PLAN_EXEC builds phase policy + registers AdvancePhaseTool
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plan_exec_builds_phase_policy_and_registers_advance_phase_tool(
app_with_chat,
):
"""PLAN_EXEC via WebSocket → engine with phase_policy, AdvancePhaseTool registered."""
from agentkit.server.routes import chat as chat_module
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "test"}])
ws = _make_websocket_mock(app_with_chat)
captured_engine: list = []
captured_tools: list = []
class _StubEngine:
def __init__(self, **kwargs):
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = (
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
)
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
captured_tools.extend(kwargs.get("tools", []))
captured_engine.append(self)
return
yield
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
assert len(captured_engine) == 1
engine = captured_engine[0]
assert engine._phase_policy is not None
assert engine._current_phase == PhaseState.PLANNING
# AdvancePhaseTool was registered in the tools list
assert any(isinstance(t, AdvancePhaseTool) for t in captured_tools)
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plan_exec_empty_config_uses_default_policy(app_with_chat):
"""Edge: plan_exec config absent (empty dict) → default_policy() used."""
from agentkit.server.routes import chat as chat_module
app_with_chat.state.server_config.plan_exec = {}
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
ws = _make_websocket_mock(app_with_chat)
captured_policy: list = []
class _StubEngine:
def __init__(self, **kwargs):
captured_policy.append(kwargs.get("phase_policy"))
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = (
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
)
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
return
yield
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
assert len(captured_policy) == 1
assert captured_policy[0] is not None
# Default policy: PLANNING allows search but not write_file
assert "search" in captured_policy[0].whitelist[PhaseState.PLANNING]
assert "write_file" not in captured_policy[0].whitelist[PhaseState.PLANNING]
@pytest.mark.asyncio
async def test_plan_exec_disabled_falls_back_to_react(app_with_chat):
"""Edge: plan_exec.enabled=False → falls back to REACT (no phase_policy)."""
from agentkit.server.routes import chat as chat_module
app_with_chat.state.server_config.plan_exec = {"enabled": False}
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
ws = _make_websocket_mock(app_with_chat)
captured_engine_kwargs: dict = {}
class _StubEngine:
def __init__(self, **kwargs):
captured_engine_kwargs.update(kwargs)
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = None
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
return
yield
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
# enabled=False → no phase_policy (falls back to REACT)
assert captured_engine_kwargs.get("phase_policy") is None
# ---------------------------------------------------------------------------
# Error path
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plan_exec_bad_config_sends_error_and_returns(app_with_chat):
"""Error: phase policy construction fails → error event sent, returns early."""
from agentkit.server.routes import chat as chat_module
app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"}
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
ws = _make_websocket_mock(app_with_chat)
engine_constructor_called = []
class _StubEngine:
def __init__(self, **kwargs):
engine_constructor_called.append(kwargs)
async def execute_stream(self, **kwargs):
yield
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="test",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
error_messages = [m for m in sent_messages if m.get("type") == "error"]
assert len(error_messages) == 1
assert "phase policy error" in error_messages[0]["data"]["message"]
# Engine constructor was NOT called (returned early)
assert len(engine_constructor_called) == 0
# ---------------------------------------------------------------------------
# phase_changed event emission
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_phase_changed_event_emitted_on_transition(app_with_chat):
"""phase_changed event sent when current_phase changes during execute_stream."""
from agentkit.server.routes import chat as chat_module
app_with_chat.state.server_config.plan_exec = {}
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}])
ws = _make_websocket_mock(app_with_chat)
class _StubEngine:
def __init__(self, **kwargs):
self._phase_policy = kwargs.get("phase_policy")
self._current_phase = PhaseState.PLANNING
@property
def current_phase(self):
return self._current_phase
def reset(self):
pass
async def execute_stream(self, **kwargs):
from agentkit.core.react import ReActEvent
yield ReActEvent(
event_type="tool_call",
step=1,
data={"tool": "search", "output": "ok"},
)
# Simulate phase transition (as if AdvancePhaseTool was called)
self._current_phase = PhaseState.BUILDING
yield ReActEvent(
event_type="final_answer",
step=2,
data={"output": "done"},
)
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="go",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
assert len(phase_events) == 1
assert phase_events[0]["data"]["phase"] == "building"
assert phase_events[0]["data"]["previous"] == "planning"
@pytest.mark.asyncio
async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat):
"""Characterization: REACT mode → no phase_changed events."""
from agentkit.server.routes import chat as chat_module
agent = _make_agent_mock()
routing = _make_routing(execution_mode=ExecutionMode.REACT)
_setup_routing(app_with_chat, routing, agent)
sm = _make_session_manager_mock()
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}])
ws = _make_websocket_mock(app_with_chat)
class _StubEngine:
def __init__(self, **kwargs):
self._phase_policy = None
self._current_phase = None
@property
def current_phase(self):
return None
def reset(self):
pass
async def execute_stream(self, **kwargs):
from agentkit.core.react import ReActEvent
yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"})
with pytest.MonkeyPatch().context() as mp:
mp.setattr(chat_module, "ReActEngine", _StubEngine)
await chat_module._handle_chat_message(
websocket=ws,
session_id="test-session",
content="hi",
sm=sm,
cancellation_token=MagicMock(),
pending_replies={},
pending_confirmations=None,
)
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
assert len(phase_events) == 0

View File

@ -0,0 +1,400 @@
"""G4/U1 — Auxiliary LLM routing in ContextCompressor.
Verifies:
- auxiliary_model routes _summarize through the cheaper model first
- empty content (Finding 4 anti-pattern) triggers fallback to main model
- auxiliary exception triggers fallback to main model
- both auxiliary and main failing falls through to _simple_summary
- auxiliary_model=None preserves existing single-model behavior (characterization)
- config wiring (LLMConfig.from_dict, ServerConfig._build_llm_config)
"""
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.compressor import ContextCompressor
from agentkit.llm.config import LLMConfig
from agentkit.llm.protocol import LLMResponse, TokenUsage
# ── Helpers ──────────────────────────────────────────
def make_gateway_with_response(content: str, model: str = "test") -> MagicMock:
"""Mock LLMGateway returning a fixed response."""
from agentkit.llm.gateway import LLMGateway
gateway = MagicMock(spec=LLMGateway)
gateway.chat = AsyncMock(
return_value=LLMResponse(
content=content,
model=model,
usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
)
)
return gateway
def make_gateway_side_effect(responses_by_model: dict[str, LLMResponse | Exception]) -> MagicMock:
"""Mock LLMGateway returning different responses (or raising) keyed by model name.
Each call to gateway.chat(model=X) pops the next response for X from a queue,
so repeated calls to the same model can return different values.
"""
from agentkit.llm.gateway import LLMGateway
gateway = MagicMock(spec=LLMGateway)
queues = {m: list(rs) for m, rs in responses_by_model.items()}
async def chat_side_effect(*, messages, model, **kwargs):
queue = queues.get(model)
if queue is None:
raise ValueError(f"unexpected model={model}")
if not queue:
raise ValueError(f"queue for model={model} exhausted")
item = queue.pop(0)
if isinstance(item, Exception):
raise item
return item
gateway.chat = AsyncMock(side_effect=chat_side_effect)
return gateway
def make_long_messages(count: int = 4, content_length: int = 2000) -> list[dict]:
"""Generate long messages that exceed token budget (triggers compression)."""
messages = [{"role": "system", "content": "You are a helpful assistant."}]
for i in range(count):
messages.append({"role": "user", "content": "x" * content_length + f" m{i}"})
messages.append({"role": "assistant", "content": "y" * content_length + f" r{i}"})
messages.append({"role": "user", "content": "recent question"})
messages.append({"role": "assistant", "content": "recent answer"})
return messages
# ── Characterization: auxiliary_model=None preserves existing behavior ──
class TestAuxiliaryNoneCharacterization:
"""auxiliary_model=None (default) — single model call, existing behavior."""
async def test_no_auxiliary_calls_main_once(self):
gateway = make_gateway_with_response("main summary")
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
# auxiliary_model omitted → None
)
result = await compressor.compress(make_long_messages())
gateway.chat.assert_awaited_once()
# The call used the main model
assert gateway.chat.await_args.kwargs.get("model") == "main"
# Summary surfaced in result
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert any("main summary" in m["content"] for m in summary_msgs)
async def test_main_failure_falls_to_simple_summary(self):
gateway = MagicMock()
gateway.chat = AsyncMock(side_effect=Exception("main LLM error"))
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
)
result = await compressor.compress(make_long_messages())
# _simple_summary produces truncated messages with "..."
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert len(summary_msgs) == 1
assert "..." in summary_msgs[0]["content"]
# ── New behavior: auxiliary routing ──────────────────
class TestAuxiliaryRouting:
"""auxiliary_model set and differs from main → auxiliary tried first."""
async def test_auxiliary_success_returns_auxiliary_content(self):
gateway = make_gateway_side_effect(
{
"fast": [
LLMResponse(
content="aux summary",
model="fast",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
"main": [
LLMResponse(
content="MAIN SHOULD NOT BE USED",
model="main",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
result = await compressor.compress(make_long_messages())
# Auxiliary called; main NOT called
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) == 1
assert len(main_calls) == 0
# Result contains auxiliary summary
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert any("aux summary" in m["content"] for m in summary_msgs)
async def test_empty_content_triggers_main_fallback(self):
"""Finding 4 anti-pattern: empty content is a failure, not a success."""
gateway = make_gateway_side_effect(
{
"fast": [
LLMResponse(
content="",
model="fast",
usage=TokenUsage(prompt_tokens=1, completion_tokens=0),
)
],
"main": [
LLMResponse(
content="main summary",
model="main",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
result = await compressor.compress(make_long_messages())
# Auxiliary called once (returned empty)
# Main called once (fallback)
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) == 1
assert len(main_calls) == 1
# Result contains main summary (not the empty auxiliary)
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert any("main summary" in m["content"] for m in summary_msgs)
async def test_whitespace_content_triggers_main_fallback(self):
"""Whitespace-only content also counts as empty (Finding 4)."""
gateway = make_gateway_side_effect(
{
"fast": [
LLMResponse(
content=" \n ",
model="fast",
usage=TokenUsage(prompt_tokens=1, completion_tokens=0),
)
],
"main": [
LLMResponse(
content="main summary",
model="main",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
await compressor.compress(make_long_messages())
# Both auxiliary and main called
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) == 1
assert len(main_calls) == 1
async def test_auxiliary_exception_triggers_main_fallback(self):
from agentkit.core.exceptions import LLMProviderError
gateway = make_gateway_side_effect(
{
"fast": [LLMProviderError("aux", "provider down")],
"main": [
LLMResponse(
content="main summary",
model="main",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
result = await compressor.compress(make_long_messages())
# Both called; main succeeded
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) == 1
assert len(main_calls) == 1
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert any("main summary" in m["content"] for m in summary_msgs)
async def test_both_fail_falls_to_simple_summary(self):
"""Auxiliary raises, main raises → existing _simple_summary degradation."""
# Note: aggressive compression path may invoke _summarize multiple times.
# Queue provides enough responses to handle that without raising queue-exhausted.
gateway = make_gateway_side_effect(
{
"fast": [Exception("aux boom")] * 5,
"main": [Exception("main boom")] * 5,
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
result = await compressor.compress(make_long_messages())
# Both called at least once
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) >= 1
assert len(main_calls) >= 1
# _simple_summary output has "..." truncation markers
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert len(summary_msgs) == 1
assert "..." in summary_msgs[0]["content"]
async def test_auxiliary_equal_to_main_skipped(self):
"""auxiliary_model == model → no auxiliary routing (single call to main)."""
gateway = make_gateway_with_response("main summary")
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="main", # same as main
)
await compressor.compress(make_long_messages())
# Only one call (to main); auxiliary block skipped
assert gateway.chat.await_count == 1
assert gateway.chat.await_args.kwargs.get("model") == "main"
async def test_audit_fields_preserved(self):
"""Auxiliary call uses agent_name='compressor', task_type='summarization'."""
gateway = make_gateway_with_response("aux summary")
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
# Override the mock to use a single-response gateway where auxiliary succeeds
# (the make_gateway_with_response mock returns same response regardless of model)
await compressor.compress(make_long_messages())
# Single call (auxiliary succeeded) — verify audit fields
call_kwargs = gateway.chat.await_args.kwargs
assert call_kwargs.get("agent_name") == "compressor"
assert call_kwargs.get("task_type") == "summarization"
# ── Config wiring ────────────────────────────────────
class TestConfigWiring:
"""LLMConfig + ServerConfig read auxiliary_model from dict."""
def test_llm_config_from_dict_reads_auxiliary_model(self):
cfg = LLMConfig.from_dict(
{
"providers": {},
"model_aliases": {"fast": "p/m"},
"auxiliary_model": "fast",
}
)
assert cfg.auxiliary_model == "fast"
def test_llm_config_from_dict_auxiliary_none_when_absent(self):
cfg = LLMConfig.from_dict({"providers": {}})
assert cfg.auxiliary_model is None
def test_llm_config_default_auxiliary_none(self):
cfg = LLMConfig()
assert cfg.auxiliary_model is None
def test_server_config_build_llm_config_reads_auxiliary_model(self):
from agentkit.server.config import ServerConfig
llm_data = {
"providers": {
"p": {
"type": "openai",
"api_key": "k",
"base_url": "http://x",
"models": {"m": {"alias": "fast"}},
}
},
"auxiliary_model": "fast",
}
llm_config = ServerConfig._build_llm_config(llm_data)
assert llm_config.auxiliary_model == "fast"
# Also verify model_aliases still built correctly
assert llm_config.model_aliases.get("fast") == "p/m"
def test_server_config_build_llm_config_auxiliary_none_when_absent(self):
from agentkit.server.config import ServerConfig
llm_config = ServerConfig._build_llm_config({"providers": {}})
assert llm_config.auxiliary_model is None

View File

@ -0,0 +1,318 @@
"""G7/U2 — Emergency layer rule template + TaskResult extension.
Verifies:
- EmergencyRules.classify maps each exception type to correct error_code
- TaskCancelledError raises ValueError (caller must propagate as-is)
- EmergencyError.to_dict produces all 5 fields
- EmergencyError.to_error_message formats suggestions as "建议1) ... 2) ..."
- Config overrides apply (suggestions, retryable, message)
- TaskResult.error_struct field: default None preserves byte-for-byte
to_dict() output (backward compat)
- TaskResult round-trip serialization includes error_struct when set
"""
from datetime import datetime, timezone
import pytest
from agentkit.core.exceptions import (
LLMProviderError,
LoopDetectedError,
TaskCancelledError,
TaskTimeoutError,
)
from agentkit.core.fallback import (
EMPTY_LLM_RESPONSE,
MAX_STEPS_REACHED,
SHELL_NO_OUTPUT,
EmergencyError,
EmergencyRules,
)
from agentkit.core.protocol import TaskResult
# ── Constants unchanged (contract preservation) ──────
class TestExistingConstantsUnchanged:
"""Existing 3 constants preserved byte-for-byte."""
def test_empty_llm_response_unchanged(self):
assert "模型未返回有效内容" in EMPTY_LLM_RESPONSE
assert "建议" in EMPTY_LLM_RESPONSE
def test_max_steps_reached_unchanged(self):
assert "已达到最大推理步数" in MAX_STEPS_REACHED
def test_shell_no_output_unchanged(self):
assert SHELL_NO_OUTPUT == "[命令执行成功,无输出内容]"
# ── EmergencyRules.classify ──────────────────────────
class TestEmergencyRulesClassify:
"""classify() maps exception types to EmergencyError."""
def test_timeout(self):
exc = TaskTimeoutError(task_id="t1", timeout_seconds=30)
err = EmergencyRules.classify(exc)
assert err.error_code == "timeout"
assert err.retryable is True
assert "稍后重试" in err.suggestions
assert "简化任务范围" in err.suggestions
assert err.original_error == str(exc)
def test_loop_detected(self):
exc = LoopDetectedError(tool_name="shell", repetitions=3)
err = EmergencyRules.classify(exc)
assert err.error_code == "loop_detected"
assert err.retryable is True
assert "拆分任务" in err.suggestions
assert "检查工具参数" in err.suggestions
def test_llm_provider_error(self):
exc = LLMProviderError("openai", "rate limited")
err = EmergencyRules.classify(exc)
assert err.error_code == "llm_failure"
assert err.retryable is True
assert "稍后重试" in err.suggestions
assert "切换模型" in err.suggestions
def test_llm_error_subclass_also_classified(self):
"""LLMProviderError is a subclass of LLMError; ensure isinstance check works."""
from agentkit.core.exceptions import LLMError
class CustomLLMError(LLMError):
pass
err = EmergencyRules.classify(CustomLLMError("custom"))
# CustomLLMError is NOT a LLMProviderError, falls through to generic
assert err.error_code == "internal_error"
def test_generic_exception_internal_error(self):
err = EmergencyRules.classify(Exception("unknown boom"))
assert err.error_code == "internal_error"
assert err.retryable is False
assert "联系管理员" in err.suggestions
assert err.original_error == "unknown boom"
def test_task_cancelled_raises(self):
"""TaskCancelledError must propagate; classify() raises ValueError."""
exc = TaskCancelledError(task_id="t1")
with pytest.raises(ValueError, match="TaskCancelledError"):
EmergencyRules.classify(exc)
def test_subclass_of_timeout_classified(self):
"""Subclasses of TaskTimeoutError are classified as timeout."""
class CustomTimeout(TaskTimeoutError):
def __init__(self):
super().__init__(task_id="custom", timeout_seconds=10)
err = EmergencyRules.classify(CustomTimeout())
assert err.error_code == "timeout"
# ── EmergencyError serialization ─────────────────────
class TestEmergencyErrorSerialization:
"""to_dict / to_error_message on EmergencyError."""
def test_to_dict_produces_all_five_fields(self):
err = EmergencyError(
error_code="timeout",
message="任务执行超时。",
suggestions=["稍后重试", "简化任务范围"],
retryable=True,
original_error="Task t1 timed out after 30s",
)
d = err.to_dict()
assert set(d.keys()) == {
"error_code",
"message",
"suggestions",
"retryable",
"original_error",
}
assert d["error_code"] == "timeout"
assert d["message"] == "任务执行超时。"
assert d["suggestions"] == ["稍后重试", "简化任务范围"]
assert d["retryable"] is True
assert d["original_error"] == "Task t1 timed out after 30s"
def test_to_dict_suggestions_list_is_copy(self):
"""to_dict returns a fresh list, not the internal reference."""
suggestions = ["a", "b"]
err = EmergencyError(
error_code="x",
message="m",
suggestions=suggestions,
retryable=False,
original_error="e",
)
d = err.to_dict()
assert d["suggestions"] is not suggestions
d["suggestions"].append("c")
assert err.suggestions == ["a", "b"]
def test_to_error_message_with_suggestions(self):
err = EmergencyError(
error_code="timeout",
message="任务执行超时。",
suggestions=["稍后重试", "简化任务范围"],
retryable=True,
original_error="err",
)
msg = err.to_error_message()
assert msg.startswith("任务执行超时。建议:")
assert "1) 稍后重试" in msg
assert "2) 简化任务范围" in msg
# Format mirrors EMPTY_LLM_RESPONSE style
assert msg.endswith("")
def test_to_error_message_no_suggestions(self):
err = EmergencyError(
error_code="x",
message="just a message",
suggestions=[],
retryable=False,
original_error="e",
)
assert err.to_error_message() == "just a message"
def test_to_error_message_single_suggestion(self):
err = EmergencyError(
error_code="x",
message="msg",
suggestions=["only one"],
retryable=False,
original_error="e",
)
msg = err.to_error_message()
assert msg == "msg建议1) only one。"
# ── Config override ──────────────────────────────────
class TestConfigOverride:
"""classify() applies per-rule config overrides."""
def test_override_suggestions(self):
exc = TaskTimeoutError(task_id="t", timeout_seconds=1)
cfg = {"timeout": {"suggestions": ["自定义建议 A", "自定义建议 B"]}}
err = EmergencyRules.classify(exc, config=cfg)
assert err.suggestions == ["自定义建议 A", "自定义建议 B"]
assert err.error_code == "timeout"
def test_override_retryable(self):
exc = LLMProviderError("openai", "boom")
cfg = {"llm_failure": {"retryable": False}}
err = EmergencyRules.classify(exc, config=cfg)
assert err.retryable is False
def test_override_message(self):
exc = LoopDetectedError(tool_name="x", repetitions=2)
cfg = {"loop_detected": {"message": "循环啦!"}}
err = EmergencyRules.classify(exc, config=cfg)
assert err.message == "循环啦!"
def test_override_internal_error_rule(self):
cfg = {"internal_error": {"suggestions": ["联系客服"]}}
err = EmergencyRules.classify(Exception("boom"), config=cfg)
assert err.error_code == "internal_error"
assert err.suggestions == ["联系客服"]
def test_config_none_uses_defaults(self):
err = EmergencyRules.classify(TaskTimeoutError(task_id="t", timeout_seconds=1))
assert err.error_code == "timeout"
assert err.retryable is True
def test_config_empty_dict_uses_defaults(self):
err = EmergencyRules.classify(
TaskTimeoutError(task_id="t", timeout_seconds=1), config={}
)
assert err.error_code == "timeout"
assert err.retryable is True
# ── TaskResult.error_struct extension ────────────────
def _make_task_result(
error_struct: dict | None = None, error_message: str | None = None
) -> TaskResult:
now = datetime.now(timezone.utc)
return TaskResult(
task_id="t1",
agent_name="a1",
status="completed",
output_data={"k": "v"},
error_message=error_message,
started_at=now,
completed_at=now,
metrics={"m": 1},
error_struct=error_struct,
)
class TestTaskResultErrorStruct:
"""TaskResult.error_struct field — backward-compatible extension."""
def test_default_error_struct_is_none(self):
tr = _make_task_result()
assert tr.error_struct is None
def test_to_dict_without_error_struct_preserves_existing_shape(self):
"""error_struct=None → to_dict() output has NO error_struct key (byte-for-byte)."""
tr = _make_task_result()
d = tr.to_dict()
assert "error_struct" not in d
# Existing keys unchanged
assert set(d.keys()) == {
"task_id",
"agent_name",
"status",
"output_data",
"error_message",
"started_at",
"completed_at",
"metrics",
}
def test_to_dict_with_error_struct_includes_key(self):
struct = {
"error_code": "timeout",
"message": "超时",
"suggestions": ["重试"],
"retryable": True,
"original_error": "boom",
}
tr = _make_task_result(error_struct=struct, error_message="超时建议1) 重试。")
d = tr.to_dict()
assert d["error_struct"] == struct
assert d["error_message"] == "超时建议1) 重试。"
def test_from_dict_round_trip_with_error_struct(self):
struct = {"error_code": "loop_detected", "message": "m", "suggestions": [], "retryable": True, "original_error": "e"}
tr = _make_task_result(error_struct=struct)
d = tr.to_dict()
restored = TaskResult.from_dict(d)
assert restored.error_struct == struct
def test_from_dict_without_error_struct_defaults_none(self):
tr = _make_task_result()
d = tr.to_dict()
# Simulate legacy data without error_struct key
restored = TaskResult.from_dict(d)
assert restored.error_struct is None
def test_error_message_and_error_struct_coexist(self):
"""Both fields can be set simultaneously (parallel contract per KTD2)."""
struct = {"error_code": "timeout", "message": "超时", "suggestions": ["重试"], "retryable": True, "original_error": "err"}
tr = _make_task_result(error_struct=struct, error_message="超时建议1) 重试。")
d = tr.to_dict()
assert d["error_message"] == "超时建议1) 重试。"
assert d["error_struct"] == struct

View File

@ -0,0 +1,404 @@
"""G7/U3 — Three-tier fallback chain wiring tests.
Verifies Main Recovery (ReflexionEngine) Emergency (EmergencyRules)
at chat REST path. Mocks ReActEngine + ReflexionEngine + LLMGateway.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.exceptions import (
LLMProviderError,
LoopDetectedError,
TaskCancelledError,
TaskTimeoutError,
)
from agentkit.core.react import ReActResult
from agentkit.server._fallback_chain import execute_with_fallback_chain
def _make_react_result(status: str = "success", output: str = "ok") -> ReActResult:
return ReActResult(
output=output,
trajectory=[],
total_steps=1,
total_tokens=10,
status=status,
)
def _make_react_engine(result=None, raises=None):
"""Build a fake ReActEngine with .execute returning result or raising."""
engine = MagicMock()
engine.reset = MagicMock()
if raises is not None:
engine.execute = AsyncMock(side_effect=raises)
else:
engine.execute = AsyncMock(return_value=result or _make_react_result())
return engine
def _make_llm_gateway():
gw = MagicMock()
gw.chat = AsyncMock(return_value=MagicMock(content="recovered"))
return gw
def _make_reflexion_result(status: str = "success", output: str = "recovered"):
"""Synthesize a ReflexionResult-like object."""
return MagicMock(
status=status,
output=output,
trajectory=[],
total_steps=1,
total_tokens=5,
)
@pytest.fixture
def patched_reflexion(monkeypatch):
"""Patch ReflexionEngine used inside the chain to a controllable mock."""
from agentkit.server import _fallback_chain
instances: list[MagicMock] = []
class _MockReflexion:
def __init__(self, llm_gateway, max_reflections=1, **kwargs):
self._llm_gateway = llm_gateway
self._max_reflections = max_reflections
self.execute = AsyncMock(return_value=_make_reflexion_result())
instances.append(self)
monkeypatch.setattr(_fallback_chain, "ReflexionEngine", _MockReflexion)
return instances
# ─── Tier 1: Main ─────────────────────────────────────────────────────────
class TestMainTier:
@pytest.mark.asyncio
async def test_main_success_no_recovery_no_emergency(self):
engine = _make_react_engine(result=_make_react_result(status="success", output="hello"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "success"
assert result.output == "hello"
assert result.error_struct is None
@pytest.mark.asyncio
async def test_main_unknown_status_treated_as_success(self):
"""Unknown status (not in soft_failure set) is treated as success-like."""
engine = _make_react_engine(result=_make_react_result(status="partial", output="x"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "success"
# ─── Tier 2: Recovery ──────────────────────────────────────────────────────
class TestRecoveryTier:
@pytest.mark.asyncio
async def test_main_timeout_triggers_recovery_success(self, patched_reflexion):
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "recovered"
assert result.output == "recovered"
# ReflexionEngine was instantiated and called
assert len(patched_reflexion) == 1
patched_reflexion[0].execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_main_loop_detected_triggers_recovery(self, patched_reflexion):
engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "recovered"
@pytest.mark.asyncio
async def test_main_llm_provider_error_triggers_recovery(self, patched_reflexion):
engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="503"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "recovered"
@pytest.mark.asyncio
async def test_main_soft_failure_status_triggers_recovery(self, patched_reflexion):
"""Soft failure (empty_fallback) without exception still triggers Recovery."""
engine = _make_react_engine(result=_make_react_result(status="empty_fallback", output=""))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "recovered"
@pytest.mark.asyncio
async def test_recovery_disabled_skips_to_emergency(self):
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
assert result.status == "emergency"
assert result.error_struct["error_code"] == "timeout"
@pytest.mark.asyncio
async def test_recovery_failure_falls_through_to_emergency(self, patched_reflexion):
"""Recovery raises → Emergency tier fires with original exception."""
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
# Make ReflexionEngine.execute raise
patched_reflexion_instance = MagicMock()
patched_reflexion_instance.execute = AsyncMock(
side_effect=RuntimeError("reflexion crashed")
)
# Override the patched class to use our instance
from agentkit.server import _fallback_chain
original_cls = _fallback_chain.ReflexionEngine
class _MockReflexionWithExc:
def __init__(self, **kwargs):
self.execute = patched_reflexion_instance.execute
_fallback_chain.ReflexionEngine = _MockReflexionWithExc
try:
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
finally:
_fallback_chain.ReflexionEngine = original_cls
assert result.status == "emergency"
assert result.error_struct["error_code"] == "timeout"
@pytest.mark.asyncio
async def test_recovery_unsuccessful_status_falls_through(self, patched_reflexion):
"""Recovery returns non-success status → Emergency fires."""
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
# Make ReflexionEngine return unsuccessful result with empty output
from agentkit.server import _fallback_chain
class _MockReflexionNoOutput:
def __init__(self, **kwargs):
self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None))
original_cls = _fallback_chain.ReflexionEngine
_fallback_chain.ReflexionEngine = _MockReflexionNoOutput
try:
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
finally:
_fallback_chain.ReflexionEngine = original_cls
assert result.status == "emergency"
assert result.error_struct["error_code"] == "timeout"
# ─── Tier 3: Emergency ────────────────────────────────────────────────────
class TestEmergencyTier:
@pytest.mark.asyncio
async def test_emergency_timeout_error_code(self, patched_reflexion):
# Make recovery fail (empty result) so Emergency fires
from agentkit.server import _fallback_chain
class _MockReflexionEmpty:
def __init__(self, **kwargs):
self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None))
original_cls = _fallback_chain.ReflexionEngine
_fallback_chain.ReflexionEngine = _MockReflexionEmpty
try:
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
finally:
_fallback_chain.ReflexionEngine = original_cls
assert result.status == "emergency"
assert result.error_struct["error_code"] == "timeout"
assert result.error_struct["retryable"] is True
assert "建议" in result.output
@pytest.mark.asyncio
async def test_emergency_loop_detected_error_code(self):
engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5))
# Recovery disabled so Emergency fires directly
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
assert result.status == "emergency"
assert result.error_struct["error_code"] == "loop_detected"
@pytest.mark.asyncio
async def test_emergency_llm_failure_error_code(self):
engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="500"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
assert result.status == "emergency"
assert result.error_struct["error_code"] == "llm_failure"
@pytest.mark.asyncio
async def test_emergency_internal_error_for_generic_exception(self):
engine = _make_react_engine(raises=RuntimeError("unexpected"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
assert result.status == "emergency"
assert result.error_struct["error_code"] == "internal_error"
assert result.error_struct["retryable"] is False
@pytest.mark.asyncio
async def test_task_cancelled_propagates_not_routed_to_emergency(self):
"""TaskCancelledError must propagate, not be classified by Emergency."""
engine = _make_react_engine(raises=TaskCancelledError(task_id="t1"))
with pytest.raises(TaskCancelledError):
await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
@pytest.mark.asyncio
async def test_emergency_disabled_reraises_original(self):
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
with pytest.raises(TaskTimeoutError):
await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={
"recovery": {"enabled": False},
"emergency": {"enabled": False},
},
)
# ─── Config wiring ────────────────────────────────────────────────────────
class TestServerConfigFallbackChain:
def test_fallback_chain_section_read_from_dict(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict(
{
"fallback_chain": {
"enabled": True,
"recovery": {"enabled": False, "max_retries": 3},
"emergency": {"enabled": True},
}
}
)
assert config.fallback_chain["enabled"] is True
assert config.fallback_chain["recovery"] == {"enabled": False, "max_retries": 3}
assert config.fallback_chain["emergency"] == {"enabled": True}
def test_fallback_chain_defaults_empty_when_absent(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict({})
assert config.fallback_chain == {}

View File

@ -0,0 +1,348 @@
"""Unit tests for PhasePolicy + PhaseState (G6 core, R24/R25/R26).
Covers:
- PhaseState enum (next_of, from_string)
- default_policy() KTD5 whitelist
- PhasePolicy.is_tool_allowed / is_bash_command_allowed
- policy_from_config parsing (R26 config-driven)
- ServerConfig.plan_exec integration
"""
from __future__ import annotations
import re
import pytest
from agentkit.core.phase import (
WILDCARD,
PhasePolicy,
PhaseState,
default_policy,
policy_from_config,
)
from agentkit.server.config import ServerConfig
# ---------------------------------------------------------------------------
# PhaseState enum
# ---------------------------------------------------------------------------
class TestPhaseState:
def test_values(self):
assert PhaseState.PLANNING.value == "planning"
assert PhaseState.BUILDING.value == "building"
assert PhaseState.VERIFICATION.value == "verification"
assert PhaseState.DELIVERY.value == "delivery"
def test_next_of(self):
assert PhaseState.next_of(PhaseState.PLANNING) == PhaseState.BUILDING
assert PhaseState.next_of(PhaseState.BUILDING) == PhaseState.VERIFICATION
assert PhaseState.next_of(PhaseState.VERIFICATION) == PhaseState.DELIVERY
assert PhaseState.next_of(PhaseState.DELIVERY) is None
def test_from_string_case_insensitive(self):
assert PhaseState.from_string("planning") == PhaseState.PLANNING
assert PhaseState.from_string("PLANNING") == PhaseState.PLANNING
assert PhaseState.from_string("Building") == PhaseState.BUILDING
def test_from_string_invalid_raises(self):
with pytest.raises(ValueError, match="Invalid phase name"):
PhaseState.from_string("unknown")
with pytest.raises(ValueError, match="Valid:"):
PhaseState.from_string("exploration")
# ---------------------------------------------------------------------------
# default_policy() — KTD5 whitelist
# ---------------------------------------------------------------------------
class TestDefaultPolicy:
def test_has_all_four_phases(self):
policy = default_policy()
assert PhaseState.PLANNING in policy.whitelist
assert PhaseState.BUILDING in policy.whitelist
assert PhaseState.VERIFICATION in policy.whitelist
assert PhaseState.DELIVERY in policy.whitelist
def test_planning_whitelist_matches_r24(self):
policy = default_policy()
allowed = policy.whitelist[PhaseState.PLANNING]
assert "search" in allowed
assert "read_file" in allowed
assert "shell" in allowed
assert "tool_search" in allowed
# Planning must NOT allow write_file.
assert "write_file" not in allowed
def test_building_whitelist_includes_write_file(self):
policy = default_policy()
allowed = policy.whitelist[PhaseState.BUILDING]
assert "write_file" in allowed
assert "shell" in allowed
assert "read_file" in allowed
def test_verification_whitelist_excludes_write(self):
policy = default_policy()
allowed = policy.whitelist[PhaseState.VERIFICATION]
assert "shell" in allowed
assert "read_file" in allowed
assert "write_file" not in allowed
def test_delivery_wildcard(self):
policy = default_policy()
allowed = policy.whitelist[PhaseState.DELIVERY]
assert WILDCARD in allowed
def test_start_phase_default_planning(self):
assert default_policy().start_phase == PhaseState.PLANNING
def test_auto_advance_default_none(self):
# KTD6: manual by default.
assert default_policy().auto_advance_after_steps is None
def test_bash_filter_blocks_rm_in_planning(self):
policy = default_policy()
assert policy.is_bash_command_allowed("ls -la", PhaseState.PLANNING) is True
assert policy.is_bash_command_allowed("git status", PhaseState.PLANNING) is True
assert policy.is_bash_command_allowed("rm -rf /tmp/x", PhaseState.PLANNING) is False
assert policy.is_bash_command_allowed("echo x > file.txt", PhaseState.PLANNING) is False
def test_bash_filter_no_restriction_in_building(self):
policy = default_policy()
assert policy.is_bash_command_allowed("rm -rf build/", PhaseState.BUILDING) is True
assert policy.is_bash_command_allowed("echo x > out.log", PhaseState.BUILDING) is True
# ---------------------------------------------------------------------------
# PhasePolicy — is_tool_allowed
# ---------------------------------------------------------------------------
class TestIsToolAllowed:
def test_planning_allows_search(self):
policy = default_policy()
assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True
def test_planning_blocks_write_file(self):
policy = default_policy()
assert policy.is_tool_allowed("write_file", PhaseState.PLANNING) is False
def test_building_allows_write_file(self):
policy = default_policy()
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
def test_delivery_wildcard_allows_anything(self):
policy = default_policy()
assert policy.is_tool_allowed("any_random_tool", PhaseState.DELIVERY) is True
assert policy.is_tool_allowed("write_file", PhaseState.DELIVERY) is True
def test_unknown_phase_returns_false(self):
# ponytail: unknown phase → empty whitelist → no tool allowed.
# We can't construct an unknown PhaseState (enum), but if a phase
# were missing from the whitelist dict, is_tool_allowed should
# return False (defensive).
policy = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({"search"}),
PhaseState.BUILDING: frozenset({"write_file"}),
PhaseState.VERIFICATION: frozenset({"shell"}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
}
)
# BUILDING is in whitelist, so allowed checks work normally.
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
# Phase missing from whitelist would return False (defensive .get default).
# We test this by constructing a minimal policy.
minimal = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({WILDCARD}),
PhaseState.BUILDING: frozenset({WILDCARD}),
PhaseState.VERIFICATION: frozenset({WILDCARD}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
}
)
# VERIFICATION is in whitelist — wildcard allows all.
assert minimal.is_tool_allowed("anything", PhaseState.VERIFICATION) is True
# ---------------------------------------------------------------------------
# PhasePolicy — edge cases & errors
# ---------------------------------------------------------------------------
class TestPhasePolicyEdgeCases:
def test_empty_whitelist_raises(self):
# Fail-fast: an empty whitelist for a non-wildcard phase is a bug.
with pytest.raises(ValueError, match="empty whitelist"):
PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset(), # empty!
PhaseState.BUILDING: frozenset({WILDCARD}),
PhaseState.VERIFICATION: frozenset({WILDCARD}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
}
)
def test_wildcard_only_does_not_raise(self):
# Wildcard-only whitelist is valid (means "all tools").
policy = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({WILDCARD}),
PhaseState.BUILDING: frozenset({WILDCARD}),
PhaseState.VERIFICATION: frozenset({WILDCARD}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
}
)
assert policy.is_tool_allowed("anything", PhaseState.PLANNING) is True
def test_to_dict_serializable(self):
policy = default_policy()
d = policy.to_dict()
assert "whitelist" in d
assert "planning" in d["whitelist"]
assert "delivery" in d["whitelist"]
assert d["start_phase"] == "planning"
assert d["auto_advance_after_steps"] is None
def test_custom_bash_filter(self):
custom_filter = re.compile(r"\b(pip install|npm install)\b")
policy = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({"shell"}),
PhaseState.BUILDING: frozenset({"shell"}),
PhaseState.VERIFICATION: frozenset({"shell"}),
PhaseState.DELIVERY: frozenset({WILDCARD}),
},
bash_command_filter={PhaseState.BUILDING: custom_filter},
)
assert policy.is_bash_command_allowed("npm install foo", PhaseState.BUILDING) is False
assert policy.is_bash_command_allowed("npm run build", PhaseState.BUILDING) is True
# ---------------------------------------------------------------------------
# policy_from_config — R26 (config-driven)
# ---------------------------------------------------------------------------
class TestPolicyFromConfig:
def test_empty_config_returns_none(self):
assert policy_from_config({}) is None
def test_enabled_false_returns_none(self):
# Opt-out — explicit `enabled: false` disables policy.
result = policy_from_config({"enabled": False})
assert result is None
def test_enabled_default_true_when_section_present(self):
# When section is present but `enabled` is missing, default is True.
result = policy_from_config({"auto_advance_after_steps": 3})
assert result is not None
assert result.auto_advance_after_steps == 3
def test_auto_advance_after_steps(self):
policy = policy_from_config({"enabled": True, "auto_advance_after_steps": 5})
assert policy is not None
assert policy.auto_advance_after_steps == 5
def test_start_phase_custom(self):
policy = policy_from_config({"enabled": True, "start_phase": "building"})
assert policy is not None
assert policy.start_phase == PhaseState.BUILDING
def test_start_phase_invalid_raises(self):
with pytest.raises(ValueError, match="Invalid phase name"):
policy_from_config({"enabled": True, "start_phase": "unknown"})
def test_whitelist_override_merges_with_default(self):
policy = policy_from_config(
{
"enabled": True,
"whitelist_override": {
"planning": ["search", "read_file"], # removes shell from default
},
}
)
assert policy is not None
# Override wins — shell should be removed from planning.
assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True
assert policy.is_tool_allowed("read_file", PhaseState.PLANNING) is True
assert policy.is_tool_allowed("shell", PhaseState.PLANNING) is False
# Other phases unchanged.
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
def test_whitelist_override_invalid_phase_raises(self):
with pytest.raises(ValueError, match="Invalid phase name"):
policy_from_config(
{
"enabled": True,
"whitelist_override": {"unknown_phase": ["tool"]},
}
)
def test_whitelist_override_non_list_raises(self):
with pytest.raises(ValueError, match="must be a list"):
policy_from_config(
{
"enabled": True,
"whitelist_override": {"planning": "not a list"},
}
)
def test_to_dict_round_trip_via_default(self):
# Sanity: default policy serializes to a dict with expected keys.
policy = default_policy()
d = policy.to_dict()
assert set(d["whitelist"].keys()) == {
"planning",
"building",
"verification",
"delivery",
}
# ---------------------------------------------------------------------------
# ServerConfig.plan_exec integration (R26)
# ---------------------------------------------------------------------------
class TestServerConfigPlanExec:
def test_default_plan_exec_empty(self):
config = ServerConfig.from_dict({})
assert config.plan_exec == {}
def test_plan_exec_loaded_from_dict(self):
config = ServerConfig.from_dict(
{
"plan_exec": {
"enabled": True,
"auto_advance_after_steps": 5,
}
}
)
assert config.plan_exec == {"enabled": True, "auto_advance_after_steps": 5}
def test_plan_exec_empty_dict_default(self):
config = ServerConfig.from_dict({"plan_exec": {}})
assert config.plan_exec == {}
def test_plan_exec_resolved_to_policy(self):
# Wire the config dict through policy_from_config to verify integration.
config = ServerConfig.from_dict(
{
"plan_exec": {
"enabled": True,
"auto_advance_after_steps": 3,
}
}
)
policy = policy_from_config(config.plan_exec)
assert policy is not None
assert policy.auto_advance_after_steps == 3
def test_plan_exec_disabled_via_config(self):
config = ServerConfig.from_dict({"plan_exec": {"enabled": False}})
policy = policy_from_config(config.plan_exec)
assert policy is None

View File

@ -0,0 +1,367 @@
"""G9/U4 — PlanPhase rollback fields + RollbackExecutor + TeamOrchestrator integration.
Characterization-first: captures pre-change behavior (rollback_command=None
no rollback, checkpoint saved) before asserting new rollback behavior.
"""
from __future__ import annotations
import os
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.experts.orchestrator import TeamOrchestrator
from agentkit.experts.plan import PlanPhase, TeamPlan
from agentkit.orchestrator.rollback import RollbackExecutor
# ─── PlanPhase field characterization ──────────────────────────────────────
class TestPlanPhaseFields:
"""PlanPhase serialization for new fields."""
def test_characterization_no_new_keys_when_unset(self):
"""Default PlanPhase.to_dict() must not include new keys (KTD6 contract)."""
ph = PlanPhase(name="执行", assigned_expert="lead")
out = ph.to_dict()
assert "validation_command" not in out
assert "rollback_command" not in out
# Pre-change keys remain
assert out["name"] == "执行"
assert out["assigned_expert"] == "lead"
assert out["status"] == "pending"
def test_characterization_from_dict_empty_yields_none(self):
"""from_dict({}) produces both new fields as None."""
ph = PlanPhase.from_dict({"name": "x"})
assert ph.validation_command is None
assert ph.rollback_command is None
def test_serialization_includes_keys_when_set(self):
ph = PlanPhase(
name="frontend",
validation_command="ruff check src/",
rollback_command="git checkout src/app.vue",
)
out = ph.to_dict()
assert out["validation_command"] == "ruff check src/"
assert out["rollback_command"] == "git checkout src/app.vue"
def test_serialization_round_trip(self):
ph = PlanPhase(
name="backend",
validation_command="pytest -x -q",
rollback_command="git checkout src/api.py",
)
restored = PlanPhase.from_dict(ph.to_dict())
assert restored.validation_command == "pytest -x -q"
assert restored.rollback_command == "git checkout src/api.py"
def test_only_validation_set_still_emits_validation_key(self):
"""Asymmetric case: only validation_command set — only that key appears."""
ph = PlanPhase(name="x", validation_command="echo ok")
out = ph.to_dict()
assert "validation_command" in out
assert "rollback_command" not in out
# ─── RollbackExecutor subprocess execution ─────────────────────────────────
@pytest.fixture
def tmp_workspace(tmp_path):
"""Fresh working directory for subprocess execution."""
return str(tmp_path)
class TestRollbackExecutor:
"""RollbackExecutor happy/edge/failure paths."""
@pytest.mark.asyncio
async def test_execute_happy_path_zero_exit(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("true")
assert r.passed is True
assert r.exit_code == 0
assert r.command == "true"
@pytest.mark.asyncio
async def test_execute_failure_nonzero_exit(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("false")
assert r.passed is False
assert r.exit_code != 0
@pytest.mark.asyncio
async def test_execute_timeout(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=0.1)
r = await ex.execute("sleep 5")
assert r.passed is False
assert r.exit_code == -1
assert "timed out" in r.stderr
@pytest.mark.asyncio
async def test_execute_captures_stdout(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("echo hello-rollback")
assert r.passed is True
assert "hello-rollback" in r.stdout
@pytest.mark.asyncio
async def test_validate_same_semantics_as_execute(self, tmp_workspace):
"""validate() is just execute() with different intent marker."""
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.validate("true")
assert r.passed is True
r2 = await ex.validate("false")
assert r2.passed is False
@pytest.mark.asyncio
async def test_spawn_failure_returns_failed_result(self, tmp_workspace):
"""Bad shell command should surface as failed result, not raise."""
# Use a definitely-broken command — shell still spawns, returns non-zero.
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("exit 7")
assert r.passed is False
assert r.exit_code == 7
@pytest.mark.asyncio
async def test_cwd_used_for_relative_paths(self, tmp_workspace):
"""Files created in working_dir are visible via cwd-relative commands."""
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
await ex.execute("echo content > testfile.txt")
r = await ex.execute("test -f testfile.txt")
assert r.passed is True
# ─── Real git rollback integration ────────────────────────────────────────
class TestGitRollbackIntegration:
"""Real git repo fixture: writes file, rollback restores via git checkout."""
@pytest.mark.asyncio
async def test_git_checkout_restores_modified_file(self, tmp_path):
# Init repo and commit a baseline file
import subprocess
repo = str(tmp_path)
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.email", "test@x"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.name", "test"], cwd=repo, check=True)
baseline = "original content\n"
with open(os.path.join(repo, "foo.txt"), "w") as f:
f.write(baseline)
subprocess.run(["git", "add", "foo.txt"], cwd=repo, check=True)
subprocess.run(["git", "commit", "-q", "-m", "baseline"], cwd=repo, check=True)
# Mutate the file
with open(os.path.join(repo, "foo.txt"), "w") as f:
f.write("mutated content\n")
# Run git checkout as rollback_command
ex = RollbackExecutor(working_dir=repo, timeout=10.0)
r = await ex.execute("git checkout foo.txt")
assert r.passed is True
with open(os.path.join(repo, "foo.txt")) as f:
assert f.read() == baseline
# ─── TeamOrchestrator integration ─────────────────────────────────────────
def _make_team_mock():
"""Build a minimal ExpertTeam mock for TeamOrchestrator."""
team = MagicMock()
team.team_id = "test-team"
team.status.value = "executing"
team.lead_expert = None
team.active_experts = []
return team
def _make_orchestrator(team=None, checkpoint=None, workspace_root=None):
"""Build a TeamOrchestrator with mocked team and checkpoint."""
team = team or _make_team_mock()
orch = TeamOrchestrator(
team=team,
checkpoint=checkpoint,
workspace_root=workspace_root,
rollback_timeout=2.0,
)
orch._broadcast_event = AsyncMock()
return orch
class TestOrchestratorRollbackIntegration:
"""TeamOrchestrator phase failure path integration with rollback."""
@pytest.mark.asyncio
async def test_characterization_no_rollback_when_unset(self, tmp_path):
"""Phase fails, rollback_command=None → checkpoint saved, no rollback events."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
ph = PlanPhase(name="p1", assigned_expert="lead") # no rollback fields
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
# Call _run_phase_rollback — but it should NOT be called from
# orchestrator path when fields unset. Verify by simulating the
# main-loop guard directly.
from agentkit.experts.plan import PhaseStatus as PS
ph.status = PS.FAILED
# Simulate the guard condition in orchestrator.py:280-288
should_save = True
if (
ph.validation_command
and ph.rollback_command
and isinstance(Exception("x"), (Exception,))
):
should_save = await orch._run_phase_rollback(plan, ph)
# Guard should not fire
assert should_save is True
# No rollback events broadcast
assert orch._broadcast_event.call_count == 0
@pytest.mark.asyncio
async def test_validation_passes_no_rollback_executed(self, tmp_path):
"""validation_command returns 0 → rollback NOT executed, checkpoint saved."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="true", # always passes
rollback_command="false", # would fail if executed
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is True
# Events: started + completed (no rollback)
events = [
c.kwargs.get("event_type") or c.args[0] for c in orch._broadcast_event.call_args_list
]
assert "phase_rollback_started" in events
assert "phase_rollback_completed" in events
assert "phase_rollback_failed" not in events
@pytest.mark.asyncio
async def test_validation_fails_rollback_succeeds(self, tmp_path):
"""Validation fails, rollback returns 0 → checkpoint saved, rollback_executed=True."""
# Create file under git tracking, then mutate it
import subprocess
repo = str(tmp_path)
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.email", "t@x"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.name", "t"], cwd=repo, check=True)
with open(os.path.join(repo, "f.txt"), "w") as f:
f.write("base\n")
subprocess.run(["git", "add", "f.txt"], cwd=repo, check=True)
subprocess.run(["git", "commit", "-q", "-m", "base"], cwd=repo, check=True)
with open(os.path.join(repo, "f.txt"), "w") as f:
f.write("mutated\n")
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=repo)
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="false", # validation fails
rollback_command="git checkout f.txt", # rollback succeeds
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is True
# File restored to baseline
with open(os.path.join(repo, "f.txt")) as f:
assert f.read() == "base\n"
@pytest.mark.asyncio
async def test_validation_fails_rollback_fails_skips_checkpoint(self, tmp_path):
"""Validation fails AND rollback fails → checkpoint NOT saved (R21), event emitted."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="false", # fails
rollback_command="false", # also fails
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is False # R21: skip checkpoint
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
assert "phase_rollback_started" in events
assert "phase_rollback_failed" in events
# phase_rollback_completed should NOT be in events
assert "phase_rollback_completed" not in events
@pytest.mark.asyncio
async def test_rollback_timeout_skips_checkpoint(self, tmp_path):
"""Rollback command times out → checkpoint NOT saved, failed event emitted."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
orch._rollback_timeout = 0.1 # short timeout
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="false",
rollback_command="sleep 5",
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is False
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
assert "phase_rollback_failed" in events
# ─── Config wiring ────────────────────────────────────────────────────────
class TestServerConfigRollback:
"""ServerConfig rollback section wiring."""
def test_rollback_section_read_from_dict(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict(
{
"rollback": {
"default_timeout": 45.0,
}
}
)
assert config.rollback == {"default_timeout": 45.0}
def test_rollback_defaults_empty_when_absent(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict({})
assert config.rollback == {}

View File

@ -0,0 +1,339 @@
"""Unit tests for ReActEngine phase enforcement (G6 wiring, R24).
Per plan U3 Execution note: characterization-first verify that
`ReActEngine(phase_policy=None)` behaves identically to pre-change (no
enforcement, no advance_phase tool, no _current_phase mutation). Then add
enforcement tests.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.phase import PhasePolicy, PhaseState, default_policy
from agentkit.core.react import ReActEngine
from agentkit.tools.advance_phase import AdvancePhaseTool
# ---------------------------------------------------------------------------
# Characterization — phase_policy=None preserves existing behavior
# ---------------------------------------------------------------------------
class TestCharacterizationNoPolicy:
"""When phase_policy=None, no enforcement happens and behavior matches
pre-Wave-3."""
def test_init_without_phase_policy(self):
# Minimal stub LLM gateway — we're only testing constructor.
gateway = MagicMock()
engine = ReActEngine(llm_gateway=gateway)
assert engine._phase_policy is None
assert engine._current_phase is None
assert engine._steps_in_phase == 0
assert engine.current_phase is None
@pytest.mark.asyncio
async def test_execute_tool_dispatches_without_phase_check(self):
"""Tool dispatch proceeds normally when no policy set."""
gateway = MagicMock()
engine = ReActEngine(llm_gateway=gateway)
# MagicMock.name is a special attribute used internally by Mock for
# repr — setting it post-construction does not make mock.name == "x"
# hold. Patch _find_tool directly to bypass the name lookup.
fake_tool = MagicMock()
fake_tool.safe_execute = AsyncMock(return_value={"output": "ok"})
fake_tool.input_schema = None
engine._find_tool = lambda name, tools: fake_tool
result = await engine._execute_tool("any_tool", {"x": 1}, [fake_tool])
assert result == {"output": "ok"}
fake_tool.safe_execute.assert_awaited_once_with(x=1)
@pytest.mark.asyncio
async def test_advance_phase_returns_none_without_policy(self):
gateway = MagicMock()
engine = ReActEngine(llm_gateway=gateway)
assert engine.advance_phase() is None
def test_reset_does_not_touch_phase_state_when_no_policy(self):
gateway = MagicMock()
engine = ReActEngine(llm_gateway=gateway)
engine.reset()
assert engine._current_phase is None
# ---------------------------------------------------------------------------
# Initialization with phase_policy
# ---------------------------------------------------------------------------
class TestPhasePolicyInitialization:
def test_phase_policy_set_initializes_current_phase(self):
gateway = MagicMock()
engine = ReActEngine(
llm_gateway=gateway,
phase_policy=default_policy(),
)
assert engine._phase_policy is not None
assert engine._current_phase == PhaseState.PLANNING
assert engine._steps_in_phase == 0
def test_reset_resets_phase_to_start(self):
gateway = MagicMock()
engine = ReActEngine(
llm_gateway=gateway,
phase_policy=default_policy(),
)
# Manually move phase forward (simulating execute progress).
engine.advance_phase() # PLANNING → BUILDING
assert engine._current_phase == PhaseState.BUILDING
engine._steps_in_phase = 5
engine.reset()
assert engine._current_phase == PhaseState.PLANNING
assert engine._steps_in_phase == 0
# ---------------------------------------------------------------------------
# advance_phase() transitions
# ---------------------------------------------------------------------------
class TestAdvancePhase:
@pytest.fixture
def engine(self):
return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
def test_planning_to_building(self, engine):
new_phase = engine.advance_phase()
assert new_phase == PhaseState.BUILDING
assert engine.current_phase == PhaseState.BUILDING
assert engine._steps_in_phase == 0 # counter reset on transition
def test_building_to_verification(self, engine):
engine.advance_phase() # → BUILDING
new_phase = engine.advance_phase()
assert new_phase == PhaseState.VERIFICATION
assert engine.current_phase == PhaseState.VERIFICATION
def test_verification_to_delivery(self, engine):
engine.advance_phase() # → BUILDING
engine.advance_phase() # → VERIFICATION
new_phase = engine.advance_phase()
assert new_phase == PhaseState.DELIVERY
assert engine.current_phase == PhaseState.DELIVERY
def test_delivery_returns_none(self, engine):
engine.advance_phase() # → BUILDING
engine.advance_phase() # → VERIFICATION
engine.advance_phase() # → DELIVERY
result = engine.advance_phase()
assert result is None
assert engine.current_phase == PhaseState.DELIVERY
# ---------------------------------------------------------------------------
# _check_phase_permission — whitelist enforcement
# ---------------------------------------------------------------------------
class TestPhasePermission:
@pytest.fixture
def engine(self):
return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
def test_search_allowed_in_planning(self, engine):
assert engine._check_phase_permission("search", {}) is None
def test_write_file_blocked_in_planning(self, engine):
result = engine._check_phase_permission("write_file", {})
assert result is not None
assert result["error"] == "phase_violation"
assert "write_file" in result["message"]
assert result["current_phase"] == "planning"
def test_write_file_allowed_in_building(self, engine):
engine.advance_phase() # → BUILDING
assert engine._check_phase_permission("write_file", {}) is None
def test_any_tool_allowed_in_delivery(self, engine):
engine.advance_phase() # → BUILDING
engine.advance_phase() # → VERIFICATION
engine.advance_phase() # → DELIVERY
assert engine._check_phase_permission("literally_anything", {}) is None
def test_bash_command_filter_blocks_rm_in_planning(self, engine):
result = engine._check_phase_permission("shell", {"command": "rm -rf /tmp"})
assert result is not None
assert result["error"] == "phase_violation"
assert "rm" in result["message"] or "Bash command" in result["message"]
def test_bash_command_filter_allows_safe_in_planning(self, engine):
# `ls` and `git status` are not blocked.
assert engine._check_phase_permission("shell", {"command": "ls -la"}) is None
assert engine._check_phase_permission("shell", {"command": "git status"}) is None
def test_bash_command_filter_no_restriction_in_building(self, engine):
engine.advance_phase() # → BUILDING
# `rm` is allowed in building phase.
assert engine._check_phase_permission("shell", {"command": "rm -rf build/"}) is None
# ---------------------------------------------------------------------------
# _execute_tool integration — phase enforcement actually blocks dispatch
# ---------------------------------------------------------------------------
class TestExecuteToolPhaseEnforcement:
@pytest.fixture
def engine_with_tools(self):
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
# Two fake tools: one allowed in PLANNING (search), one not (write_file).
# MagicMock.name can't be set post-construction (special attribute),
# so we patch _find_tool with a dict-based lookup.
search_tool = MagicMock()
search_tool.input_schema = None
search_tool.safe_execute = AsyncMock(return_value={"results": []})
write_tool = MagicMock()
write_tool.input_schema = None
write_tool.safe_execute = AsyncMock(return_value={"written": True})
tools_by_name = {"search": search_tool, "write_file": write_tool}
engine._find_tool = lambda name, tools: tools_by_name.get(name)
return engine, [search_tool, write_tool]
@pytest.mark.asyncio
async def test_blocked_tool_returns_phase_violation_and_skips_dispatch(self, engine_with_tools):
engine, tools = engine_with_tools
# write_file in PLANNING should be blocked — write_tool.safe_execute
# should NEVER be called.
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
assert result["error"] == "phase_violation"
assert result["current_phase"] == "planning"
write_tool = tools[1]
write_tool.safe_execute.assert_not_called()
@pytest.mark.asyncio
async def test_allowed_tool_dispatches_normally(self, engine_with_tools):
engine, tools = engine_with_tools
result = await engine._execute_tool("search", {"query": "foo"}, tools)
assert result == {"results": []}
search_tool = tools[0]
search_tool.safe_execute.assert_awaited_once_with(query="foo")
@pytest.mark.asyncio
async def test_after_advance_phase_blocked_tool_now_dispatches(self, engine_with_tools):
engine, tools = engine_with_tools
# First: write_file blocked in PLANNING.
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
assert result["error"] == "phase_violation"
# Advance to BUILDING.
engine.advance_phase()
# Now: write_file allowed.
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
assert result == {"written": True}
# ---------------------------------------------------------------------------
# Auto-advance safety net (KTD6)
# ---------------------------------------------------------------------------
class TestAutoAdvance:
def test_auto_advance_after_threshold(self):
# Custom policy with auto-advance after 2 steps.
policy = PhasePolicy(
whitelist={
PhaseState.PLANNING: frozenset({"search"}),
PhaseState.BUILDING: frozenset({"write_file"}),
PhaseState.VERIFICATION: frozenset({"shell"}),
PhaseState.DELIVERY: frozenset({"*"}),
},
auto_advance_after_steps=2,
)
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=policy)
assert engine.current_phase == PhaseState.PLANNING
# Step 1: counter goes to 1, no advance yet.
engine._steps_in_phase += 1
assert engine._maybe_auto_advance() is False
assert engine.current_phase == PhaseState.PLANNING
# Step 2: counter hits 2, advance triggered.
engine._steps_in_phase += 1
assert engine._maybe_auto_advance() is True
assert engine.current_phase == PhaseState.BUILDING
assert engine._steps_in_phase == 0 # reset on advance
def test_auto_advance_none_default(self):
# default_policy has auto_advance_after_steps=None — no auto-advance.
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
engine._steps_in_phase = 100
assert engine._maybe_auto_advance() is False
assert engine.current_phase == PhaseState.PLANNING
# ---------------------------------------------------------------------------
# AdvancePhaseTool integration
# ---------------------------------------------------------------------------
class TestAdvancePhaseTool:
@pytest.mark.asyncio
async def test_advance_phase_tool_transitions_engine(self):
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
tool = AdvancePhaseTool(engine=engine)
result = await tool.execute()
assert result["is_error"] is False
assert result["current_phase"] == "building"
assert engine.current_phase == PhaseState.BUILDING
@pytest.mark.asyncio
async def test_advance_phase_tool_at_delivery_returns_error(self):
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
# Walk through all phases.
engine.advance_phase() # PLANNING → BUILDING
engine.advance_phase() # BUILDING → VERIFICATION
engine.advance_phase() # VERIFICATION → DELIVERY
tool = AdvancePhaseTool(engine=engine)
result = await tool.execute()
assert result["is_error"] is True
assert result["error"] == "already_at_final_phase"
assert result["current_phase"] == "delivery"
@pytest.mark.asyncio
async def test_advance_phase_tool_without_policy_returns_error(self):
engine = ReActEngine(llm_gateway=MagicMock()) # no policy
tool = AdvancePhaseTool(engine=engine)
result = await tool.execute()
assert result["is_error"] is True
assert result["error"] == "no_phase_policy"
def test_tool_schema_accepts_no_arguments(self):
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
tool = AdvancePhaseTool(engine=engine)
# input_schema has empty properties + additionalProperties:false —
# no arguments expected.
assert tool.input_schema["properties"] == {}
assert tool.input_schema["additionalProperties"] is False
def test_tool_bypasses_phase_check(self):
"""`advance_phase` is the LLM's escape hatch — must never be blocked."""
# _check_phase_permission should NOT block advance_phase even in PLANNING.
# The bypass is implemented in _execute_tool by name check.
# We verify the bypass indirectly: tool dispatches normally even in
# PLANNING (where only search/read_file/bash/tool_search are allowed).
# advance_phase is not in the whitelist, but the name-based bypass
# in _execute_tool lets it through.
# (Direct unit test of the bypass would require mocking _find_tool.)
# Sanity: advance_phase is not in any whitelist.
for phase, allowed in default_policy().whitelist.items():
assert "advance_phase" not in allowed, (
f"advance_phase must not be in {phase.value} whitelist"
)

View File

@ -0,0 +1,367 @@
"""Unit tests for ReadFileTool — G5 (R22, R23) + characterization baseline.
Per plan U1 Execution note: characterization-first assert that
`symbol=None` returns the full file content (matches pre-existing benchmark
`_FakeTool` shape) before adding symbol-extraction behavior.
"""
from __future__ import annotations
import textwrap
import pytest
from agentkit.tools.file_read import ReadFileTool
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
class TestReadFileToolSchema:
def test_name_is_read_file(self):
tool = ReadFileTool()
assert tool.name == "read_file"
def test_required_path(self):
tool = ReadFileTool()
assert "path" in tool.input_schema["required"]
assert "path" in tool.input_schema["properties"]
def test_optional_symbol_and_lines(self):
tool = ReadFileTool()
props = tool.input_schema["properties"]
assert "symbol" in props
assert "start_line" in props
assert "end_line" in props
# None of the optional fields should be in `required`.
required = set(tool.input_schema["required"])
assert required == {"path"}
def test_additional_properties_false(self):
# LLM tool-call schemas should reject unknown args (Wave 1 U3 pattern).
tool = ReadFileTool()
assert tool.input_schema.get("additionalProperties") is False
def test_tags_contain_io_and_read(self):
tool = ReadFileTool()
assert "io" in tool.tags
assert "read" in tool.tags
# ---------------------------------------------------------------------------
# Characterization — symbol=None returns full file
# ---------------------------------------------------------------------------
@pytest.fixture
def sample_py_file(tmp_path):
path = tmp_path / "sample.py"
path.write_text(
textwrap.dedent('''
"""Sample module."""
def my_func():
return 42
class MyClass:
attr = 1
def method_a(self):
return self.attr
''').lstrip(),
encoding="utf-8",
)
return path
@pytest.fixture
def sample_ts_file(tmp_path):
path = tmp_path / "sample.ts"
path.write_text(
textwrap.dedent('''
export function renderComponent(): JSX.Element {
return <div/>;
}
export class BaseService {
abstract run(): void;
}
''').lstrip(),
encoding="utf-8",
)
return path
class TestCharacterizationFullFile:
"""symbol=None returns the whole file (matches _FakeTool baseline)."""
@pytest.mark.asyncio
async def test_full_file_returned_when_symbol_none(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file))
assert result["is_error"] is False
assert result["path"] == str(sample_py_file)
assert result["start_line"] == 1
assert result["end_line"] == result["total_lines"]
assert "def my_func" in result["content"]
assert "class MyClass" in result["content"]
assert result["content"].startswith('"""Sample module."""')
@pytest.mark.asyncio
async def test_full_file_includes_all_lines(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file))
assert result["total_lines"] >= 8
assert result["content"].count("\n") >= result["total_lines"] - 1
# ---------------------------------------------------------------------------
# Symbol slicing — happy paths
# ---------------------------------------------------------------------------
class TestSymbolSlicing:
@pytest.mark.asyncio
async def test_python_function(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="my_func")
assert result["is_error"] is False
assert result["symbol"] == "my_func"
assert result["symbol_kind"] == "function"
assert "def my_func" in result["content"]
assert "return 42" in result["content"]
# Should NOT include the class below.
assert "class MyClass" not in result["content"]
@pytest.mark.asyncio
async def test_python_class_includes_method(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="MyClass")
assert result["is_error"] is False
assert result["symbol"] == "MyClass"
assert result["symbol_kind"] == "class"
assert "class MyClass" in result["content"]
assert "def method_a" in result["content"] # method included
@pytest.mark.asyncio
async def test_python_method_directly(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="method_a")
assert result["is_error"] is False
assert result["symbol"] == "method_a"
assert result["symbol_kind"] == "method"
assert "def method_a" in result["content"]
@pytest.mark.asyncio
async def test_typescript_function(self, sample_ts_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_ts_file), symbol="renderComponent")
assert result["is_error"] is False
assert result["symbol"] == "renderComponent"
assert "renderComponent" in result["content"]
# Should not include the class below.
assert "BaseService" not in result["content"]
@pytest.mark.asyncio
async def test_typescript_class(self, sample_ts_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_ts_file), symbol="BaseService")
assert result["is_error"] is False
assert result["symbol"] == "BaseService"
assert result["symbol_kind"] == "class"
assert "BaseService" in result["content"]
# ---------------------------------------------------------------------------
# Symbol slicing — edge cases
# ---------------------------------------------------------------------------
class TestSymbolSlicingEdgeCases:
@pytest.mark.asyncio
async def test_symbol_not_found_lists_available(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="nonexistent")
assert result["is_error"] is False # soft error, not hard
assert result["content"] == ""
assert result["symbol"] == "nonexistent"
available = result["available_symbols"]
assert "my_func" in available
assert "MyClass" in available
assert "method_a" in available
assert "nonexistent" not in result["content"]
@pytest.mark.asyncio
async def test_unsupported_extension_returns_full_with_note(self, tmp_path):
path = tmp_path / "notes.md"
path.write_text("# Hello\nworld\n", encoding="utf-8")
tool = ReadFileTool()
result = await tool.execute(path=str(path), symbol="anything")
assert result["is_error"] is False
assert result["content"] == "# Hello\nworld\n"
assert "symbol extraction not supported" in result["note"]
assert ".md" in result["note"]
@pytest.mark.asyncio
async def test_empty_file(self, tmp_path):
path = tmp_path / "empty.py"
path.write_text("", encoding="utf-8")
tool = ReadFileTool()
result = await tool.execute(path=str(path))
assert result["is_error"] is False
assert result["content"] == ""
assert result["total_lines"] == 0
@pytest.mark.asyncio
async def test_file_with_no_symbols(self, tmp_path):
path = tmp_path / "data.py"
path.write_text("# just a comment\nPI = 3.14\n", encoding="utf-8")
tool = ReadFileTool()
result = await tool.execute(path=str(path), symbol="PI")
# PI is not a def/class — extractor finds no symbols; soft error lists available.
assert result["is_error"] is False
assert result["content"] == ""
assert result["available_symbols"] == []
# ---------------------------------------------------------------------------
# Error paths
# ---------------------------------------------------------------------------
class TestReadFileToolErrors:
@pytest.mark.asyncio
async def test_path_required(self):
tool = ReadFileTool()
result = await tool.execute()
assert result["is_error"] is True
assert "path" in result["error"].lower()
@pytest.mark.asyncio
async def test_path_empty_string(self):
tool = ReadFileTool()
result = await tool.execute(path="")
assert result["is_error"] is True
@pytest.mark.asyncio
async def test_file_not_found(self, tmp_path):
tool = ReadFileTool()
result = await tool.execute(path=str(tmp_path / "missing.py"))
assert result["is_error"] is True
assert "not found" in result["error"].lower()
@pytest.mark.asyncio
async def test_path_is_directory(self, tmp_path):
tool = ReadFileTool()
result = await tool.execute(path=str(tmp_path))
assert result["is_error"] is True
assert "directory" in result["error"].lower()
# ---------------------------------------------------------------------------
# Manual line slicing
# ---------------------------------------------------------------------------
class TestManualLineSlicing:
@pytest.mark.asyncio
async def test_start_and_end_line(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(
path=str(sample_py_file),
start_line=3,
end_line=5,
)
assert result["is_error"] is False
assert result["start_line"] == 3
assert result["end_line"] == 5
# Lines 3-5 of the sample file:
# line 3: "def my_func():"
# line 4: " return 42"
# line 5: "" (blank)
assert "def my_func" in result["content"]
assert "return 42" in result["content"]
@pytest.mark.asyncio
async def test_start_line_only_extends_to_eof(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), start_line=8)
assert result["is_error"] is False
assert result["start_line"] == 8
assert result["end_line"] == result["total_lines"]
@pytest.mark.asyncio
async def test_end_line_only_starts_at_one(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), end_line=2)
assert result["is_error"] is False
assert result["start_line"] == 1
assert result["end_line"] == 2
@pytest.mark.asyncio
async def test_invalid_start_line_zero(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), start_line=0)
assert result["is_error"] is True
assert "start_line" in result["error"].lower()
@pytest.mark.asyncio
async def test_end_before_start(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(
path=str(sample_py_file),
start_line=5,
end_line=3,
)
assert result["is_error"] is True
assert "end_line" in result["error"].lower()
@pytest.mark.asyncio
async def test_manual_lines_override_symbol(self, sample_py_file):
# Per plan U1 Approach: "start_line/end_line overrides symbol".
tool = ReadFileTool()
result = await tool.execute(
path=str(sample_py_file),
symbol="my_func",
start_line=1,
end_line=1,
)
assert result["is_error"] is False
# Manual slicing won — symbol field absent.
assert "symbol" not in result or result.get("symbol") is None
assert result["start_line"] == 1
assert result["end_line"] == 1
# ---------------------------------------------------------------------------
# Integration — tool registry discovery
# ---------------------------------------------------------------------------
class TestToolRegistryDiscovery:
def test_instantiable_without_args(self):
# Default constructor — matches the convention used by ToolRegistry
# to instantiate tools by class.
tool = ReadFileTool()
assert tool.name == "read_file"
def test_to_dict_serializable(self):
tool = ReadFileTool()
d = tool.to_dict()
assert d["name"] == "read_file"
assert "input_schema" in d
assert "output_schema" in d
assert d["tags"] == ["io", "file", "read"]

View File

@ -0,0 +1,359 @@
"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor.
Covers R22 (file reading supports symbol/function granularity) and KTD1
(Python ast + language-aware regex, no tree-sitter dependency).
"""
from __future__ import annotations
import textwrap
import pytest
from agentkit.tools.symbol_extractor import (
AstSymbolExtractor,
RegexSymbolExtractor,
SymbolSpan,
extract_symbols_from_file,
get_extractor,
language_for_extension,
)
# ---------------------------------------------------------------------------
# language_for_extension
# ---------------------------------------------------------------------------
class TestLanguageForExtension:
def test_python_extensions(self):
assert language_for_extension("py") == "py"
assert language_for_extension(".py") == "py"
assert language_for_extension(".PY") == "py" # case-insensitive
def test_typescript_javascript(self):
assert language_for_extension(".ts") == "ts"
assert language_for_extension(".tsx") == "ts"
assert language_for_extension(".js") == "js"
assert language_for_extension(".jsx") == "js"
assert language_for_extension(".mjs") == "js"
assert language_for_extension(".cjs") == "js"
def test_go_rust_java(self):
assert language_for_extension(".go") == "go"
assert language_for_extension(".rs") == "rs"
assert language_for_extension(".java") == "java"
def test_unsupported_returns_empty(self):
assert language_for_extension(".md") == ""
assert language_for_extension(".txt") == ""
assert language_for_extension("") == ""
assert language_for_extension(".unknown") == ""
# ---------------------------------------------------------------------------
# AstSymbolExtractor — Python
# ---------------------------------------------------------------------------
class TestAstSymbolExtractor:
extractor = AstSymbolExtractor()
def test_unsupported_language_returns_empty(self):
assert self.extractor.extract_symbols("function foo() {}", "ts") == []
def test_syntax_error_returns_empty(self):
# Never raises — callers rely on this for fallback routing.
result = self.extractor.extract_symbols("def broken(:\n pass", "py")
assert result == []
def test_top_level_function(self):
content = "def my_func():\n return 42\n"
spans = self.extractor.extract_symbols(content, "py")
assert len(spans) == 1
span = spans[0]
assert span.name == "my_func"
assert span.kind == "function"
assert span.start_line == 1
assert span.end_line == 2
def test_async_function(self):
content = "async def fetch():\n return 1\n"
spans = self.extractor.extract_symbols(content, "py")
assert len(spans) == 1
assert spans[0].name == "fetch"
assert spans[0].kind == "function"
def test_top_level_class(self):
content = textwrap.dedent('''
class MyClass:
"""docstring"""
def method_a(self):
return 1
async def method_b(self):
return 2
''').strip()
spans = self.extractor.extract_symbols(content, "py")
names = [s.name for s in spans]
assert "MyClass" in names
assert "method_a" in names
assert "method_b" in names
cls = next(s for s in spans if s.name == "MyClass")
assert cls.kind == "class"
assert cls.start_line == 1
# Class body extends through the last method's end_lineno.
assert cls.end_line >= 7
def test_methods_classified_as_methods(self):
content = textwrap.dedent('''
class Foo:
def bar(self):
pass
def top_level():
pass
''').strip()
spans = self.extractor.extract_symbols(content, "py")
by_name = {s.name: s for s in spans}
assert by_name["bar"].kind == "method"
assert by_name["top_level"].kind == "function"
def test_decorated_function(self):
content = textwrap.dedent('''
@staticmethod
def helper():
return "hi"
''').strip()
spans = self.extractor.extract_symbols(content, "py")
# Note: extractor uses node.lineno (def line) — decorators above are
# excluded by design (matches user-visible symbol start at `def`).
assert any(s.name == "helper" for s in spans)
span = next(s for s in spans if s.name == "helper")
assert span.start_line == 2 # the `def` line
def test_nested_function(self):
content = textwrap.dedent('''
def outer():
def inner():
return 1
return inner()
''').strip()
spans = self.extractor.extract_symbols(content, "py")
names = {s.name for s in spans}
assert "outer" in names
assert "inner" in names
def test_empty_file(self):
assert self.extractor.extract_symbols("", "py") == []
def test_no_symbols_in_docstring_only_file(self):
content = '"""just a docstring"""\n'
assert self.extractor.extract_symbols(content, "py") == []
# ---------------------------------------------------------------------------
# RegexSymbolExtractor — TS/JS/Go/Rust/Java
# ---------------------------------------------------------------------------
class TestRegexSymbolExtractor:
extractor = RegexSymbolExtractor()
def test_unsupported_language_returns_empty(self):
assert self.extractor.extract_symbols("def foo(): pass", "py") == []
assert self.extractor.extract_symbols("function foo() {}", "rb") == []
def test_typescript_function_declaration(self):
content = textwrap.dedent('''
export function renderComponent(props: Props): JSX.Element {
return <div/>;
}
''').strip()
spans = self.extractor.extract_symbols(content, "ts")
assert any(s.name == "renderComponent" and s.kind == "function" for s in spans)
def test_typescript_async_function(self):
content = "async function fetchData() {\n return await fetch();\n}\n"
spans = self.extractor.extract_symbols(content, "ts")
assert any(s.name == "fetchData" for s in spans)
def test_typescript_arrow_function_const(self):
content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n"
spans = self.extractor.extract_symbols(content, "ts")
assert any(s.name == "handleClick" for s in spans)
def test_typescript_class(self):
content = textwrap.dedent('''
export abstract class BaseService {
abstract run(): void;
}
''').strip()
spans = self.extractor.extract_symbols(content, "ts")
assert any(s.name == "BaseService" and s.kind == "class" for s in spans)
def test_javascript_function(self):
content = "function foo() {\n return 1;\n}\n"
spans = self.extractor.extract_symbols(content, "js")
assert any(s.name == "foo" for s in spans)
def test_javascript_arrow_const(self):
content = "const bar = () => 42;\n"
spans = self.extractor.extract_symbols(content, "js")
assert any(s.name == "bar" for s in spans)
def test_go_function(self):
content = textwrap.dedent('''
package main
func HandleRequest(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}
func (s *Server) Start() {
// method
}
''').strip()
spans = self.extractor.extract_symbols(content, "go")
names = {s.name for s in spans}
assert "HandleRequest" in names
assert "Start" in names # method receiver pattern
def test_go_struct(self):
content = "type Server struct {\n Addr string\n}\n"
spans = self.extractor.extract_symbols(content, "go")
assert any(s.name == "Server" and s.kind == "struct" for s in spans)
def test_rust_function(self):
content = textwrap.dedent('''
pub fn process(input: &str) -> Result<usize, Error> {
Ok(input.len())
}
async fn fetch() -> Bytes {
unimplemented!()
}
''').strip()
spans = self.extractor.extract_symbols(content, "rs")
names = {s.name for s in spans}
assert "process" in names
assert "fetch" in names
def test_rust_struct(self):
content = "pub struct Config {\n pub path: String,\n}\n"
spans = self.extractor.extract_symbols(content, "rs")
assert any(s.name == "Config" and s.kind == "struct" for s in spans)
def test_rust_impl(self):
content = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
spans = self.extractor.extract_symbols(content, "rs")
assert any(s.name == "Config" and s.kind == "impl" for s in spans)
def test_java_class(self):
content = textwrap.dedent('''
package com.example;
public class UserService {
public User findById(long id) {
return null;
}
}
''').strip()
spans = self.extractor.extract_symbols(content, "java")
assert any(s.name == "UserService" and s.kind == "class" for s in spans)
def test_java_method(self):
content = "public User findById(long id) {\n return null;\n}\n"
spans = self.extractor.extract_symbols(content, "java")
assert any(s.name == "findById" and s.kind == "function" for s in spans)
def test_end_line_extends_to_next_symbol(self):
# First symbol's end_line is the line before the second symbol starts.
content = textwrap.dedent('''
function first() {
return 1;
}
function second() {
return 2;
}
''').strip()
spans = self.extractor.extract_symbols(content, "js")
spans.sort(key=lambda s: s.start_line)
first = spans[0]
second = spans[1]
assert first.name == "first"
assert second.name == "second"
assert first.end_line == second.start_line - 1
def test_last_symbol_end_line_is_eof(self):
content = "function only() {\n return 1;\n}\n"
spans = self.extractor.extract_symbols(content, "js")
assert len(spans) == 1
assert spans[0].end_line == len(content.splitlines())
# ---------------------------------------------------------------------------
# get_extractor + integration
# ---------------------------------------------------------------------------
class TestGetExtractor:
def test_python_returns_ast_extractor(self):
ext = get_extractor("py")
assert ext is not None
assert isinstance(ext, AstSymbolExtractor)
def test_typescript_returns_regex_extractor(self):
ext = get_extractor("ts")
assert ext is not None
assert isinstance(ext, RegexSymbolExtractor)
def test_unsupported_returns_none(self):
assert get_extractor("md") is None
assert get_extractor("") is None
assert get_extractor("unknown") is None
class TestExtractSymbolsFromFile:
def test_python_file(self, tmp_path):
path = tmp_path / "module.py"
path.write_text("def hello():\n return 'world'\n", encoding="utf-8")
spans, lang = extract_symbols_from_file(path)
assert lang == "py"
assert any(s.name == "hello" for s in spans)
def test_unsupported_extension(self, tmp_path):
path = tmp_path / "notes.md"
path.write_text("# Hello\n", encoding="utf-8")
spans, lang = extract_symbols_from_file(path)
assert lang == ""
assert spans == []
def test_missing_file_returns_empty(self, tmp_path):
path = tmp_path / "nonexistent.py"
spans, lang = extract_symbols_from_file(path)
# lang is detected from extension even if read fails.
assert lang == "py"
assert spans == []
# ---------------------------------------------------------------------------
# SymbolSpan dataclass
# ---------------------------------------------------------------------------
class TestSymbolSpan:
def test_frozen_dataclass(self):
span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
assert span.name == "foo"
with pytest.raises(Exception):
span.name = "bar" # type: ignore[misc] — frozen
def test_equality(self):
a = SymbolSpan("foo", "function", 1, 3)
b = SymbolSpan("foo", "function", 1, 3)
assert a == b
assert hash(a) == hash(b)