Compare commits
8 Commits
92fc38de7e
...
f50d3485ea
| Author | SHA1 | Date |
|---|---|---|
|
|
f50d3485ea | |
|
|
5a0554b27f | |
|
|
da4eef1349 | |
|
|
6efd5957f6 | |
|
|
be4ac797b2 | |
|
|
58ef1719cb | |
|
|
e3f69f963c | |
|
|
a2dcde01b8 |
|
|
@ -28,6 +28,42 @@ llm:
|
|||
coding: bailian-coding/qwen3-coder-plus
|
||||
chat: deepseek/deepseek-chat
|
||||
reasoning: deepseek/deepseek-reasoner
|
||||
# G4/U1: Auxiliary model for cost-sensitive tasks (summarization).
|
||||
# When set, ContextCompressor tries this alias first, falling back to
|
||||
# the main model on failure or empty content. Commented to preserve
|
||||
# default behavior — uncomment to enable.
|
||||
# auxiliary_model: fast
|
||||
# G9/U4: Rollback configuration. Drives RollbackExecutor subprocess timeout
|
||||
# for PlanPhase.validation_command / PlanPhase.rollback_command. Per-phase
|
||||
# opt-in (KTD6) — when PlanPhase.rollback_command is unset, no rollback runs.
|
||||
# Canonical rollback pattern: `git checkout <specific_files>` (file-scoped,
|
||||
# not `git checkout .` which would wipe unrelated changes).
|
||||
rollback:
|
||||
default_timeout: 30.0
|
||||
# G7/U3: Three-tier fallback chain at chat REST send_message.
|
||||
# main → Recovery (ReflexionEngine retry) → Emergency (rule-based classifier).
|
||||
# Wired only at chat REST path (KTD5); CLI / ReWOO / Reflexion internal
|
||||
# ReAct calls bypass the chain (no recursive loop).
|
||||
fallback_chain:
|
||||
enabled: true
|
||||
recovery:
|
||||
enabled: true
|
||||
max_retries: 1 # ReflexionEngine max_reflections override
|
||||
emergency:
|
||||
enabled: true
|
||||
# G6/U2: PLAN_EXEC phase policy — SOLO four-stage state machine.
|
||||
# When `enabled: true`, chat WebSocket PLAN_EXEC requests build a PhasePolicy
|
||||
# (Planning → Building → Verification → Delivery) and enforce per-step tool
|
||||
# whitelists (R24). Transitions are LLM-driven via AdvancePhaseTool; set
|
||||
# `auto_advance_after_steps` to auto-advance as a safety net (KTD6).
|
||||
# Commented to preserve default behavior — uncomment to enable.
|
||||
# plan_exec:
|
||||
# enabled: true
|
||||
# auto_advance_after_steps: 5 # optional, default = manual (LLM calls advance_phase)
|
||||
# start_phase: planning # optional, default = planning
|
||||
# whitelist_override: # optional, merges with default (override wins)
|
||||
# planning: [search, read_file, shell]
|
||||
# building: [write_file, shell, read_file]
|
||||
session: {backend: memory}
|
||||
bus: {backend: memory}
|
||||
task_store: {backend: memory}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,481 @@
|
|||
---
|
||||
date: 2026-06-29
|
||||
type: feat
|
||||
origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md
|
||||
---
|
||||
|
||||
# Wave 2 — Auxiliary LLM Routing, Three-Tier Fallback, Atomic Subtask Rollback
|
||||
|
||||
## Summary
|
||||
|
||||
Wave 2 of the advanced-agent gap optimization brainstorm. Three medium-coupling gaps: G4 routes summary calls through a cheaper auxiliary LLM (falling back to main on failure), G7 introduces a three-tier fallback chain (main → Recovery via existing `ReflexionEngine` → Emergency with structured errors), G9 binds atomic subtask rollback to `PlanPhase` (opt-in via `rollback_command`, coordinated with the existing U7 `PipelineCheckpoint`).
|
||||
|
||||
---
|
||||
|
||||
## Problem Frame
|
||||
|
||||
Wave 1 shipped self-contained quick wins (G1/G2/G3/G8). Wave 2 addresses the medium-coupling gaps that touch multiple layers and resolve two deferred-to-planning decisions surfaced in the brainstorm: G7 Emergency layer rule template shape, and G9 `rollback_command` default behavior. The constraints these gaps must respect come from `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — new fields must preserve existing contracts, dynamic plan mutations must persist immediately, and `PipelineCheckpoint` is in-memory dict + Redis fallback (not a DB row lock).
|
||||
|
||||
---
|
||||
|
||||
## Requirements
|
||||
|
||||
Carried from `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` § Wave 2 (R13-R21):
|
||||
|
||||
- R13. `ContextCompressor` summary task routes to auxiliary model (cheap, e.g. Gemini Flash / Doubao lite), not main model.
|
||||
- R14. `auxiliary_model` is configurable, separated from main `model`.
|
||||
- R15. Summary quality does not degrade: auxiliary failure falls back to main model (not to simple_summary).
|
||||
- R16. Main agent failure triggers Recovery layer (reuse `ReflexionEngine` Evaluate→Reflect→Retry).
|
||||
- R17. Recovery failure triggers Emergency layer (rule-based fallback, structured error + suggestion).
|
||||
- R18. Fallback chain is configurable (per-layer max retries, enable/disable Recovery/Emergency).
|
||||
- R19. `PlanPhase` (`experts/plan.py:147`) gains optional `validation_command` and `rollback_command` fields.
|
||||
- R20. Phase failure auto-executes `rollback_command` when configured (default `git checkout` pattern); coordinated with U7 checkpoint.
|
||||
- R21. `checkpoint.save` runs only after rollback validation passes (avoid persisting failed state).
|
||||
|
||||
Cross-cutting: R26 (configuration via `agentkit.yaml`), R27 (each gap ships a minimal self-check test, per ponytail rule).
|
||||
|
||||
---
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
KTD1. **Auxiliary routing is compressor-scoped, not gateway-level.** `LLMGateway` already has a per-model fallback chain (`gateway.py:170-227`), but it semantics are "main fails → try backup model", not "route by task type". G4 adds an `auxiliary_model` parameter to `ContextCompressor.__init__`; `_summarize` tries auxiliary first, falls back to main on failure. This avoids touching the gateway's existing fallback semantics and keeps the change to one file plus config wiring.
|
||||
|
||||
KTD2. **Emergency layer adds a parallel `error_struct` field, not replacing `error_message`.** `TaskResult.error_message: str | None` is serialized to API responses (`protocol.py:132-145`) — changing its type breaks frontend contracts. New `error_struct: dict | None` field carries `{error_code, message, suggestions, retryable}`. `error_message` remains the human-readable string (now derived from `error_struct.message` when present). Backward compatibility preserved.
|
||||
|
||||
KTD3. **Emergency rule classification by exception type, not by status field.** `TaskTimeoutError` → `timeout`, `LoopDetectedError` → `loop_detected`, `LLMProviderError` → `llm_failure`, `TaskCancelledError` → `cancelled` (not Emergency-eligible, propagates), generic `Exception` → `internal_error`. Each rule maps to a fixed message + suggestion list (no LLM-driven suggestions — pure rule-based, per brainstorm).
|
||||
|
||||
KTD4. **Recovery layer reuses `ReflexionEngine` with main model, not auxiliary.** `ReflexionEngine.execute()` already supports `evaluate_model`/`reflect_model` parameters (`reflexion.py:94-111`), defaulting to the main model. Recovery triggers on main agent failure; using the main model for evaluate/reflect maximizes diagnostic reliability (the same model that failed is best positioned to reflect on its own failure). Auxiliary model is reserved for G4's cost-sensitive compression path.
|
||||
|
||||
KTD5. **Three-tier chain wires at `chat.py:613` (WebSocket main path), not at `ReActEngine.execute()`.** ReAct has 5 call sites (`chat.py:613`, `rewoo.py:1204`, `cli/chat.py:338`, `reflexion.py:216`, `config_driven.py:695`). Wrapping at ReAct would force all 5 sites through Recovery/Emergency, including ReflexionEngine itself (creating a recursive loop). Wrapping at `chat.py:613` covers the primary user-facing path; CLI and ReWOO are out of scope (deferred to follow-up).
|
||||
|
||||
KTD6. **Rollback is opt-in via `rollback_command` field; no implicit default.** Brainstorm KTD5 mentioned "default `git checkout`" but auto-executing `git checkout` on every phase failure changes existing behavior (Finding 1: "新字段默认值须保持既有契约"). Resolution: `rollback_command: str | None = None`. When unset, no rollback executes (preserves existing contract). When set, the configured command runs after validation failure, before checkpoint save. `git checkout <files>` is the canonical pattern documented in YAML examples.
|
||||
|
||||
KTD7. **Rollback executes via `RollbackExecutor`, not via `ShellTool`.** `ShellTool._is_dangerous` would trigger `confirm_callback` for `git checkout` (not in `_SAFE_COMMAND_PREFIXES`, not in `_DANGEROUS_BINARY_FLAGS`). Orchestrator-internal execution bypasses ShellTool — uses `asyncio.create_subprocess_shell` with `cwd`/`timeout`/`proc.kill()` pattern from `verification_loop.py:67-103`. Audit logging added (not a whitelist concern; terminal_whitelist only governs `/api/v1/terminal/server` route per `terminal_server.py:227`).
|
||||
|
||||
KTD8. **G9 checkpoint ordering: validation → rollback → checkpoint save.** Currently `orchestrator.py:265` saves checkpoint unconditionally after phase finalizes (success or failure). R21 requires checkpoint save only after rollback validation passes. New ordering: phase fails → mark FAILED → mark dependents FAILED → execute `validation_command` (if set) → if validation fails, execute `rollback_command` (if set) → save checkpoint (only if rollback validation passed or no rollback configured).
|
||||
|
||||
---
|
||||
|
||||
## High-Level Technical Design
|
||||
|
||||
### Three-tier fallback chain (G7)
|
||||
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> Main: chat request
|
||||
Main --> MainSuccess: success
|
||||
Main --> Recovery: failure (exception or empty_fallback)
|
||||
Recovery --> RecoverySuccess: ReflexionEngine retry succeeds
|
||||
Recovery --> Emergency: max_reflections exhausted
|
||||
Emergency --> [*]: emit error_struct, terminal
|
||||
MainSuccess --> [*]: normal response
|
||||
RecoverySuccess --> [*]: recovered response
|
||||
```
|
||||
|
||||
### G9 phase failure + rollback sequence
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant O as TeamOrchestrator
|
||||
participant P as PlanPhase
|
||||
participant RE as RollbackExecutor
|
||||
participant CP as PipelineCheckpoint
|
||||
|
||||
O->>P: execute phase
|
||||
P-->>O: failure (exception)
|
||||
O->>O: mark FAILED + dependents FAILED
|
||||
alt validation_command set
|
||||
O->>RE: run validation_command
|
||||
RE-->>O: ValidationResult
|
||||
alt validation fails AND rollback_command set
|
||||
O->>RE: run rollback_command
|
||||
RE-->>O: RollbackResult
|
||||
alt rollback validation passes
|
||||
O->>CP: save checkpoint
|
||||
else rollback validation fails
|
||||
O->>O: log rollback_failed, skip checkpoint
|
||||
end
|
||||
else no rollback_command
|
||||
O->>CP: save checkpoint (no rollback)
|
||||
end
|
||||
else no validation_command
|
||||
O->>CP: save checkpoint (no validation)
|
||||
end
|
||||
```
|
||||
|
||||
### Configuration shape (YAML)
|
||||
|
||||
```yaml
|
||||
# G4: Auxiliary LLM for compression
|
||||
llm:
|
||||
auxiliary_model: fast # alias resolved via existing model_aliases mechanism
|
||||
|
||||
# G7: Three-tier fallback chain
|
||||
fallback_chain:
|
||||
enabled: true
|
||||
recovery:
|
||||
enabled: true
|
||||
max_retries: 1 # ReflexionEngine max_reflections override
|
||||
emergency:
|
||||
enabled: true
|
||||
|
||||
# G9: Rollback configuration (per-phase opt-in via PlanPhase.rollback_command)
|
||||
rollback:
|
||||
default_timeout: 30.0 # RollbackExecutor subprocess timeout
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Implementation Units
|
||||
|
||||
### U1. G4 — Auxiliary LLM Routing in ContextCompressor
|
||||
|
||||
**Goal:** Route `_summarize` LLM calls through `auxiliary_model` when configured; fall back to main model on auxiliary failure (not to `_simple_summary`).
|
||||
|
||||
**Requirements:** R13, R14, R15, R26, R27
|
||||
|
||||
**Dependencies:** none (self-contained)
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/core/compressor.py` (add `auxiliary_model` param + routing in `_summarize`)
|
||||
- Modify: `src/agentkit/llm/config.py` (add `auxiliary_model: str | None` field to `LLMConfig`)
|
||||
- Modify: `src/agentkit/server/config.py` (`_build_llm_config` reads `auxiliary_model` from llm section)
|
||||
- Modify: `agentkit.yaml` (document `llm.auxiliary_model: fast` in llm section)
|
||||
- Create: `tests/unit/test_compressor_auxiliary.py`
|
||||
|
||||
**Approach:** `ContextCompressor.__init__` gains `auxiliary_model: str | None = None` param (after existing `model` param). `_summarize` (compressor.py:123-158) restructuring:
|
||||
|
||||
1. If `auxiliary_model` is set and differs from `model`, try `auxiliary_model` first via `self._llm_gateway.chat(model=self._auxiliary_model, ...)`.
|
||||
2. **Truthiness check (Finding 4 anti-pattern):** treat empty `response.content` (None or whitespace-only) as failure, not success. On failure, log and fall through to main model.
|
||||
3. If auxiliary succeeds with non-empty content, return it.
|
||||
4. If auxiliary fails (exception OR empty content), retry with main `self._model`.
|
||||
5. Existing `except Exception → _simple_summary` block remains as the final degradation tier.
|
||||
|
||||
`LLMConfig.auxiliary_model` reads from `data.get("auxiliary_model")` in `_build_llm_config`. Agentkit.yaml already declares `fast: bailian-coding/qwen-turbo` alias — this is the canonical auxiliary target.
|
||||
|
||||
**Execution note:** characterization-first. Capture current `_summarize` behavior (single model, exception → simple_summary) with `auxiliary_model=None` tests, then add `auxiliary_model="fast"` tests for new routing behavior.
|
||||
|
||||
**Patterns to follow:**
|
||||
- `LLMGateway._get_models_to_try` (`gateway.py:170-227`) — existing fallback chain pattern
|
||||
- Wave 1's `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections, `config.py:239-273`)
|
||||
|
||||
**Test scenarios:**
|
||||
- Happy path: `auxiliary_model="fast"` set, auxiliary returns non-empty content → result is auxiliary content, main model not called
|
||||
- Empty content fallback: auxiliary returns `content=""` (or `None`) → main model called, main content returned (covers Finding 4 anti-pattern)
|
||||
- Auxiliary exception: auxiliary raises `LLMProviderError` → main model called, main content returned
|
||||
- Both fail: auxiliary raises, main raises → `_simple_summary` returned (existing degradation preserved)
|
||||
- Characterization: `auxiliary_model=None` → behavior matches current code (single model call)
|
||||
- Config wiring: `LLMConfig.from_dict` reads `auxiliary_model` field from dict; `ServerConfig._build_llm_config` passes it through
|
||||
- Audit: auxiliary call uses `agent_name="compressor"`, `task_type="summarization"` (preserved for usage tracking)
|
||||
|
||||
**Verification:** All test scenarios pass; existing `tests/unit/test_compressor*.py` (if any) still pass with `auxiliary_model=None` default; ruff clean.
|
||||
|
||||
---
|
||||
|
||||
### U2. G7 — Emergency Layer Rule Template + TaskResult Extension
|
||||
|
||||
**Goal:** Add rule-based Emergency classifier with structured error output. No wiring yet — just the infrastructure (classifier + data structure + `fallback.py` extension).
|
||||
|
||||
**Requirements:** R17, R18, R26, R27
|
||||
|
||||
**Dependencies:** none (foundation for U3)
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/core/fallback.py` (add `EmergencyRules` class + `EmergencyError` dataclass, preserve existing 3 constants)
|
||||
- Modify: `src/agentkit/core/protocol.py` (add `error_struct: dict | None = None` field to `TaskResult`, update `to_dict`)
|
||||
- Create: `tests/unit/test_emergency_rules.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
`EmergencyError` dataclass (in `fallback.py`):
|
||||
```
|
||||
@dataclass
|
||||
class EmergencyError:
|
||||
error_code: str # "timeout"|"loop_detected"|"llm_failure"|"internal_error"
|
||||
message: str # human-readable Chinese message (mirrors EMPTY_LLM_RESPONSE style)
|
||||
suggestions: list[str] # actionable user-facing suggestions
|
||||
retryable: bool # whether user retry might succeed
|
||||
original_error: str # str(exc) for traceability
|
||||
|
||||
def to_dict(self) -> dict: ...
|
||||
def to_error_message(self) -> str: ... # formatted "message\n建议:1) ... 2) ..."
|
||||
```
|
||||
|
||||
`EmergencyRules` class (rule-based classifier, no LLM):
|
||||
```
|
||||
class EmergencyRules:
|
||||
@staticmethod
|
||||
def classify(exc: Exception, config: dict | None = None) -> EmergencyError:
|
||||
# Match by exception type (TaskTimeoutError, LoopDetectedError, LLMProviderError, etc.)
|
||||
# config allows per-rule customization (suggestion overrides, retryable overrides)
|
||||
```
|
||||
|
||||
Rule mapping (initial set, expandable via config):
|
||||
- `TaskTimeoutError` → `error_code="timeout"`, `retryable=True`, suggestions: ["稍后重试", "简化任务范围"]
|
||||
- `LoopDetectedError` → `error_code="loop_detected"`, `retryable=True`, suggestions: ["拆分任务", "检查工具参数"]
|
||||
- `LLMProviderError` → `error_code="llm_failure"`, `retryable=True`, suggestions: ["稍后重试", "切换模型"]
|
||||
- `TaskCancelledError` → not classified (propagates as-is, Emergency not triggered)
|
||||
- Generic `Exception` → `error_code="internal_error"`, `retryable=False`, suggestions: ["联系管理员"]
|
||||
|
||||
`TaskResult` extension:
|
||||
```
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
# ... existing fields ...
|
||||
error_message: str | None # unchanged
|
||||
error_struct: dict | None = None # NEW: serialized EmergencyError.to_dict()
|
||||
```
|
||||
|
||||
`to_dict` includes `error_struct` when set; `error_message` continues to hold the human-readable string.
|
||||
|
||||
**Patterns to follow:**
|
||||
- Existing `EMPTY_LLM_RESPONSE` / `MAX_STEPS_REACHED` style in `fallback.py` (Chinese, "建议:..." format)
|
||||
- `VerificationResult.errors: list[str]` field pattern from `verification_loop.py:18-24`
|
||||
- `TaskResult.to_dict` pattern at `protocol.py:132-145`
|
||||
|
||||
**Test scenarios:**
|
||||
- Happy path: `classify(TaskTimeoutError(...))` returns `EmergencyError(error_code="timeout", retryable=True)`
|
||||
- Each exception type maps to correct `error_code`
|
||||
- `TaskCancelledError` is NOT classified (caller must check before invoking `classify`)
|
||||
- `to_dict()` produces all 5 fields
|
||||
- `to_error_message()` formats suggestions as "建议:1) ... 2) ..."
|
||||
- `EmergencyRules.classify(Exception("unknown"))` → `error_code="internal_error"`, `retryable=False`
|
||||
- Config override: custom suggestion list for `timeout` rule via config dict
|
||||
- `TaskResult` with `error_struct` set: `to_dict()` includes both `error_message` and `error_struct`
|
||||
- `TaskResult` with `error_struct=None` (default): `to_dict()` matches current behavior (backward compat)
|
||||
|
||||
**Verification:** All scenarios pass; `fallback.py` existing 3 constants unchanged; `TaskResult.to_dict` for tasks without `error_struct` matches pre-change output byte-for-byte.
|
||||
|
||||
---
|
||||
|
||||
### U3. G7 — Three-Tier Fallback Chain Wiring
|
||||
|
||||
**Goal:** Wire main → Recovery (ReflexionEngine) → Emergency (EmergencyRules) at `chat.py:613`. Composes U2's infrastructure with existing `ReflexionEngine`.
|
||||
|
||||
**Requirements:** R16, R18, R26
|
||||
|
||||
**Dependencies:** U2 (EmergencyRules + TaskResult.error_struct)
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/server/routes/chat.py` (wrap main agent call at L613 with three-tier chain)
|
||||
- Modify: `src/agentkit/server/config.py` (add `fallback_chain` section to `from_dict`)
|
||||
- Modify: `agentkit.yaml` (document `fallback_chain:` section)
|
||||
- Create: `tests/unit/test_fallback_chain.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
New helper module or inline function in `chat.py`:
|
||||
```
|
||||
async def execute_with_fallback_chain(
|
||||
agent: ConfigDrivenAgent,
|
||||
task: Task,
|
||||
config: dict, # fallback_chain config section
|
||||
llm_gateway: LLMGateway,
|
||||
) -> TaskResult:
|
||||
# Tier 1: Main
|
||||
try:
|
||||
result = await agent.execute(task)
|
||||
if result.status == "success" or result.status == "completed":
|
||||
return result
|
||||
# Treat non-success as soft failure → trigger Recovery
|
||||
raise AgentExecutionError(result.error_message or "main agent did not succeed")
|
||||
except (TaskTimeoutError, LoopDetectedError, LLMProviderError, AgentExecutionError) as exc:
|
||||
if not config.get("recovery", {}).get("enabled", True):
|
||||
return _to_emergency(exc, task)
|
||||
|
||||
# Tier 2: Recovery (ReflexionEngine)
|
||||
try:
|
||||
reflexion_engine = ReflexionEngine(
|
||||
llm_gateway=llm_gateway,
|
||||
max_reflections=config.get("recovery", {}).get("max_retries", 1),
|
||||
# ... other params from agent's existing reflexion config
|
||||
)
|
||||
recovery_result = await reflexion_engine.execute(
|
||||
messages=task.messages, # rebuild from task
|
||||
tools=agent.get_tools(),
|
||||
model=task.model,
|
||||
agent_name=agent.name,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
if recovery_result.status == "success":
|
||||
return _recovery_to_task_result(recovery_result, task)
|
||||
except Exception as recovery_exc:
|
||||
logger.warning(f"Recovery layer failed: {recovery_exc}")
|
||||
|
||||
# Tier 3: Emergency
|
||||
return _to_emergency(exc, task)
|
||||
```
|
||||
|
||||
`_to_emergency(exc, task)` constructs `EmergencyError` via `EmergencyRules.classify(exc, config)`, then returns `TaskResult(status=FAILED, error_message=emergency.to_error_message(), error_struct=emergency.to_dict(), ...)`.
|
||||
|
||||
Wiring at `chat.py:613`: replace direct `agent.execute(task)` call with `execute_with_fallback_chain(...)`. Recovery config from `server_config.fallback_chain` (new `ServerConfig.fallback_chain: dict` field, mirroring Wave 1 pattern).
|
||||
|
||||
**Recovery layer scope (KTD5):** Only `chat.py:613` is wrapped. CLI (`cli/chat.py:338`), ReWOO (`rewoo.py:1204`), ReflexionEngine's internal ReAct call, and `config_driven.py:695` are NOT wrapped (would create recursive loop or unwanted coupling). These remain on the direct-execute path. Documented in `## Scope Boundaries`.
|
||||
|
||||
**Patterns to follow:**
|
||||
- `config_driven.py:836` — existing `ReflexionEngine` instantiation pattern (constructor params, execute call)
|
||||
- Wave 1's `verification` config section in `ServerConfig.from_dict` (`config.py:240-273`)
|
||||
|
||||
**Test scenarios:**
|
||||
- Happy path: main agent succeeds → no Recovery, no Emergency triggered; `error_struct=None`
|
||||
- Main fails (timeout) → Recovery triggered → Recovery succeeds → `error_struct=None`, output from ReflexionEngine
|
||||
- Main fails (timeout) → Recovery triggered → Recovery fails (max_reflections exhausted) → Emergency triggered → `error_struct` populated with `error_code="timeout"`, `retryable=True`
|
||||
- Main fails (LoopDetectedError) → Emergency `error_code="loop_detected"`
|
||||
- Main fails (LLMProviderError) → Emergency `error_code="llm_failure"`
|
||||
- Main fails (TaskCancelledError) → propagates as-is (NOT routed to Emergency)
|
||||
- Main fails (generic Exception) → Emergency `error_code="internal_error"`, `retryable=False`
|
||||
- Config: `fallback_chain.recovery.enabled=false` → skip Recovery, go directly to Emergency
|
||||
- Config: `fallback_chain.emergency.enabled=false` → re-raise original exception (no Emergency)
|
||||
- Integration: full chain on real `ConfigDrivenAgent` with mocked LLM (mock main raises, mock ReflexionEngine succeeds)
|
||||
|
||||
**Verification:** All scenarios pass; existing chat WebSocket tests still pass; ruff clean.
|
||||
|
||||
---
|
||||
|
||||
### U4. G9 — PlanPhase Rollback Fields + RollbackExecutor + TeamOrchestrator Integration
|
||||
|
||||
**Goal:** Add `validation_command`/`rollback_command` optional fields to `PlanPhase`; execute rollback on phase failure (opt-in); coordinate with U7 checkpoint ordering per R21.
|
||||
|
||||
**Requirements:** R19, R20, R21, R26, R27
|
||||
|
||||
**Dependencies:** none (uses existing U7 `PipelineCheckpoint` and `VerificationLoop` patterns)
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/experts/plan.py` (`PlanPhase` dataclass + `to_dict`/`from_dict` symmetry)
|
||||
- Modify: `src/agentkit/experts/orchestrator.py` (insert rollback execution between phase failure and checkpoint save)
|
||||
- Create: `src/agentkit/orchestrator/rollback.py` (`RollbackExecutor` class)
|
||||
- Modify: `src/agentkit/server/config.py` (add `rollback` section to `from_dict`)
|
||||
- Modify: `agentkit.yaml` (document `rollback:` section)
|
||||
- Create: `tests/unit/test_phase_rollback.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
**PlanPhase extension (`plan.py:147-233`):** Add two optional fields with `None` default (preserves existing contract per Finding 1):
|
||||
```
|
||||
validation_command: str | None = None
|
||||
rollback_command: str | None = None
|
||||
```
|
||||
Update `to_dict()` to include both fields (only when not None, to keep dict shape minimal). Update `from_dict()` to read both fields.
|
||||
|
||||
**`RollbackExecutor` class (new file `orchestrator/rollback.py`):** Mirrors `VerificationLoop` pattern (`verification_loop.py:67-103`):
|
||||
```
|
||||
class RollbackExecutor:
|
||||
def __init__(self, working_dir: str | None = None, timeout: float = 30.0): ...
|
||||
|
||||
async def execute(self, command: str) -> RollbackResult:
|
||||
# asyncio.create_subprocess_shell with cwd/timeout/proc.kill()
|
||||
# Returns RollbackResult(passed, exit_code, stdout, stderr, command)
|
||||
|
||||
async def validate(self, command: str) -> RollbackResult:
|
||||
# Same as execute but returns passed=False on non-zero exit
|
||||
```
|
||||
|
||||
`RollbackResult` dataclass: `{passed: bool, exit_code: int, stdout: str, stderr: str, command: str}`.
|
||||
|
||||
**TeamOrchestrator integration (`orchestrator.py:246-270`):** New ordering per KTD8:
|
||||
|
||||
```
|
||||
# Existing: phase failure detected, mark FAILED + dependents FAILED
|
||||
# (lines 247-261 unchanged)
|
||||
|
||||
# NEW: rollback phase (only if validation_command and rollback_command configured)
|
||||
should_save_checkpoint = True
|
||||
if ph.validation_command and ph.rollback_command:
|
||||
validator = RollbackExecutor(working_dir=self._workspace_root, timeout=...)
|
||||
validation_result = await validator.validate(ph.validation_command)
|
||||
if not validation_result.passed:
|
||||
rollback_result = await validator.execute(ph.rollback_command)
|
||||
if not rollback_result.passed:
|
||||
# Rollback failed → don't save checkpoint (R21)
|
||||
should_save_checkpoint = False
|
||||
logger.error(f"Rollback failed for phase {ph.id}: {rollback_result.stderr}")
|
||||
# Emit phase_rollback_failed event
|
||||
await self._broadcast_event("phase_rollback_failed", {...})
|
||||
|
||||
# Existing: checkpoint save (conditional now)
|
||||
if should_save_checkpoint and self._checkpoint is not None:
|
||||
try:
|
||||
await self._checkpoint.save(plan.id, ph, plan.status.value)
|
||||
except Exception as e:
|
||||
logger.warning(...)
|
||||
```
|
||||
|
||||
**Events emitted (new):** `phase_rollback_started`, `phase_rollback_completed`, `phase_rollback_failed`. Existing `phase_failed` event unchanged (emitted before rollback).
|
||||
|
||||
**Config:** `rollback.default_timeout: float = 30.0` in `ServerConfig.rollback` dict (Wave 1 pattern).
|
||||
|
||||
**Security boundary (KTD7):** Rollback subprocess does NOT go through `ShellTool` (avoids `confirm_callback` for `git checkout`). Audit log emitted via `_broadcast_event`. `terminal_whitelist` is not consulted (only governs `/api/v1/terminal/server` route, per `terminal_server.py:227`).
|
||||
|
||||
**Execution note:** characterization-first. Test current behavior (`rollback_command=None` → no rollback, checkpoint saved) before adding rollback behavior.
|
||||
|
||||
**Patterns to follow:**
|
||||
- `VerificationLoop` subprocess execution pattern (`verification_loop.py:67-103`) — `asyncio.create_subprocess_shell` + `cwd` + `timeout` + `proc.kill()`
|
||||
- `PlanPhase.to_dict`/`from_dict` symmetric serialization pattern at `plan.py:186-233`
|
||||
- `TeamOrchestrator._execute_phase` event broadcasting pattern (`orchestrator.py:253-260`)
|
||||
- Wave 1's `ServerConfig.from_dict` extension template
|
||||
|
||||
**Test scenarios:**
|
||||
- Characterization: `PlanPhase()` with no `validation_command`/`rollback_command` → `to_dict()` output matches pre-change shape (no new keys); `from_dict({})` produces phase with both fields as None
|
||||
- Serialization: phase with `rollback_command="git checkout foo.py"` → `to_dict()` includes key; `from_dict(to_dict())` round-trips
|
||||
- RollbackExecutor happy path: `execute("git status")` returns `passed=True`, `exit_code=0`
|
||||
- RollbackExecutor timeout: `execute("sleep 10", timeout=0.1)` returns `passed=False`, exit_code=-1 (or similar)
|
||||
- RollbackExecutor failure: `execute("false")` returns `passed=False`, `exit_code=1`
|
||||
- Integration — opt-in default: phase fails, `rollback_command=None` → no RollbackExecutor call, checkpoint saved (existing behavior)
|
||||
- Integration — rollback configured, validation passes (returns 0): rollback NOT executed, checkpoint saved
|
||||
- Integration — rollback configured, validation fails, rollback succeeds (returns 0): rollback executed, checkpoint saved
|
||||
- Integration — rollback configured, validation fails, rollback fails (returns 1): rollback executed, checkpoint NOT saved (R21), `phase_rollback_failed` event emitted
|
||||
- Integration — real git repo fixture: phase writes file via `git checkout`, rollback `git checkout foo.py` restores file; assert file content matches pre-phase state
|
||||
|
||||
**Verification:** All scenarios pass; existing `tests/unit/test_pipeline_state.py` and `tests/unit/test_team_orchestrator*.py` (if any) still pass; ruff clean.
|
||||
|
||||
---
|
||||
|
||||
## Scope Boundaries
|
||||
|
||||
### Deferred to Follow-Up Work
|
||||
|
||||
- **Recovery wiring at non-chat call sites** (`cli/chat.py:338`, `rewoo.py:1204`, `config_driven.py:695`). KTD5 limits Wave 2 to the primary chat WebSocket path. Other entry points can adopt the same wrapper in a follow-up.
|
||||
- **Recovery layer streaming events.** Wave 2's chain returns final `TaskResult`; SSE events for "recovery_started"/"recovery_completed" would require deeper WebSocket protocol changes.
|
||||
- **Patch-level rollback** (`git apply -R <patch>`). KTD6 + KTD7 scope Wave 2 to `git checkout <files>` pattern. Patch-level requires extending `CheckpointData` schema to track patches — Wave 3 candidate.
|
||||
- **DB-backed atomic claim for checkpoint.** Finding 3 noted `PipelineCheckpoint` is in-memory dict + Redis fallback, not a PostgreSQL `FOR UPDATE SKIP LOCKED` pattern. Multi-process atomicity is out of scope; single-process `asyncio.Lock` (already present at `orchestrator.py:53`) is sufficient.
|
||||
|
||||
### Outside this product's identity (carried from brainstorm)
|
||||
|
||||
- Wave 3 (G5/G6) — tree-sitter integration, SOLO four-stage state machine. Strategic direction locked in brainstorm KTD6/KTD7; implementation design deferred to Wave 3 plan.
|
||||
- Node-level checkpoint (ReAct single-step). Stage-level (U7) satisfies core need.
|
||||
- DeerFlow-style disk filesystem. Redis is the persistence layer.
|
||||
- Full LangGraph migration. Self-built architecture stays.
|
||||
|
||||
---
|
||||
|
||||
## Risks & Dependencies
|
||||
|
||||
- **Risk: Auxiliary model availability.** If `auxiliary_model` alias is not configured in `agentkit.yaml`, `LLMGateway` raises `ModelNotFoundError`. Mitigation: `ContextCompressor` catches this in the auxiliary try-block and falls through to main model (R15 fallback path). Documented in YAML comments.
|
||||
- **Risk: Recovery layer recursive loop.** `ReflexionEngine.execute()` internally calls `ReActEngine.execute()` (5th call site). If Recovery itself fails, it must NOT trigger another Recovery — KTD5 wires only at `chat.py:613`, so ReflexionEngine's internal ReAct call bypasses the chain. Recursive loop is structurally impossible.
|
||||
- **Risk: `git checkout` destructive scope.** Misconfigured `rollback_command` (e.g., `git checkout .`) could wipe unrelated changes. Mitigation: rollback is opt-in per-phase (KTD6); audit log via `phase_rollback_started`/`phase_rollback_failed` events; documented YAML examples use file-scoped commands (`git checkout <specific_files>`).
|
||||
- **Risk: `TaskResult.error_struct` field addition breaks pickle/serialization.** `TaskResult` is a dataclass with `to_dict`; new optional field with `None` default is backward-compatible for `to_dict` consumers. Pickle deserialization of pre-change data still works (new field defaults to None).
|
||||
- **Dependency: Wave 1 merged.** `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections) is established and tested. Wave 2's 3 new sections (auxiliary_model, fallback_chain, rollback) reuse this exact pattern.
|
||||
- **Dependency: U7 PipelineCheckpoint already shipped.** `orchestrator/checkpoint.py:56` exists with `save(plan_id, phase, plan_status)` API; U4 only adjusts the calling order, not the API.
|
||||
|
||||
---
|
||||
|
||||
## Sources / Research
|
||||
|
||||
- Brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 2 section, R13-R21, KTD4/KTD5)
|
||||
- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (ServerConfig.from_dict extension template established)
|
||||
- Finding 1 (contract preservation rule): `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — "新字段默认值须保持既有契约" drives KTD6's opt-in default for `rollback_command`
|
||||
- Finding 2 (retry-storm defense): `docs/solutions/security-issues/portal-platform-security-reliability-fixes.md` — Emergency layer's terminal `error_code` pattern (don't propagate unhandled exceptions)
|
||||
- Finding 3 (atomic claim pattern): `docs/solutions/architecture-patterns/bitable-companion-service-security-reliability-patterns.md` — `SKIP LOCKED` not reusable; `PipelineCheckpoint` is in-memory dict + Redis
|
||||
- Finding 4 (empty-response anti-pattern): `docs/solutions/ui-bugs/tauri-reload-loses-session.md` — auxiliary LLM must check truthy return value, not just absence of exception; drives U1's empty-content fallback
|
||||
- Code locations verified during planning:
|
||||
- `src/agentkit/core/compressor.py:35-44,123-158` — `_summarize` injection point for G4
|
||||
- `src/agentkit/llm/gateway.py:170-227` — existing per-model fallback chain (semantics differ from G4's task-type routing)
|
||||
- `src/agentkit/llm/config.py:198-257` — `LLMConfig` dataclass and `from_dict`
|
||||
- `src/agentkit/core/reflexion.py:68-111` — `ReflexionEngine.__init__` and `execute()` signatures (already supports `evaluate_model`/`reflect_model`)
|
||||
- `src/agentkit/core/fallback.py` (full file, 19 lines) — 3 existing constants; `EmergencyRules` adds alongside
|
||||
- `src/agentkit/core/protocol.py:118-145` — `TaskResult` dataclass and `to_dict`
|
||||
- `src/agentkit/experts/plan.py:147-233` — `PlanPhase` dataclass, `to_dict`/`from_dict`
|
||||
- `src/agentkit/experts/orchestrator.py:246-270` — phase failure capture + checkpoint save site (U4 integration point)
|
||||
- `src/agentkit/orchestrator/checkpoint.py:56` — `PipelineCheckpoint` (U7, no API change needed)
|
||||
- `src/agentkit/core/verification_loop.py:67-103` — subprocess execution pattern for `RollbackExecutor`
|
||||
- `src/agentkit/tools/shell.py:168-174,526-534` — `_DANGEROUS_BINARY_FLAGS` (git checkout not listed), `_is_dangerous` logic (KTD7 rationale)
|
||||
|
|
@ -0,0 +1,436 @@
|
|||
---
|
||||
title: "feat: Agent Wave 3 strategic coupling (G5/G6)"
|
||||
date: 2026-06-29
|
||||
type: feat
|
||||
status: draft
|
||||
origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md
|
||||
execution: code
|
||||
---
|
||||
|
||||
# Wave 3 Strategic Coupling — G5 Function-level Sharding + G6 SOLO Phase Constraints
|
||||
|
||||
## Summary
|
||||
|
||||
Wave 3 of the advanced-agent gap optimization closes two strategic gaps deferred from Waves 1-2:
|
||||
|
||||
- **G5 — Function-level code sharding** (R22, R23): file reading gains an optional `symbol` parameter for symbol/function-granularity slicing, backward compatible with full-file reads.
|
||||
- **G6 — SOLO four-stage state machine** (R24, R25): ReAct loop enforces per-phase tool whitelists (Planning → Building → Verification → Delivery). Extends existing `ExecutionMode.PLAN_EXEC` rather than introducing a new mode.
|
||||
|
||||
Wave 1 (G1/G2/G3/G8 — PR #4 merged) and Wave 2 (G4/G7/G9 — PR #5 open) shipped independently. Wave 3 is the **strategic-risk** wave: it introduces a new tool (G5) and touches ReAct core (G6). Per the brainstorm's KTD6/KTD7 locked decisions, G5 integration approach is decided here (not deferred further), and G6 extends PLAN_EXEC rather than adding a new mode.
|
||||
|
||||
## Problem Frame
|
||||
|
||||
The brainstorm (`docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md`) identified nine gaps across three dimensions. Waves 1-2 closed seven (G1-G4, G7-G9). The remaining two gaps are strategic:
|
||||
|
||||
- **G5 (Long-context cost)**: Large files (e.g. 5000-line modules) blow context budget when the agent only needs one function. Today the agent shells out to `cat file.py` or greps; both pull the whole file into context. A symbol-aware slice would cut context cost 10-50x for typical edits.
|
||||
- **G6 (Unsafe tool sequencing)**: ReAct loop lets the LLM call `write_file` during early exploration before committing to a plan. This wastes tokens on premature edits, causes half-baked refactors, and breaks the "plan-then-build" discipline that production agents (Qoder, Trae Work) enforce via phase state machines.
|
||||
|
||||
KTD6 (brainstorm) defers the G5 integration decision to this plan. KTD7 locks G6 to extend PLAN_EXEC rather than introduce a new mode.
|
||||
|
||||
## Requirements
|
||||
|
||||
Carried forward from the brainstorm, Wave 3 section:
|
||||
|
||||
- **R22**: File reading supports symbol/function granularity sharding.
|
||||
- **R23**: Sharding capability exposed as a tool parameter (`symbol="function_name"`), backward compatible with full-file reads.
|
||||
- **R24**: ReAct loop enforces phase constraints — Planning phase only allows `think`/`search`; Building phase only allows `write_file` (and similar write tools).
|
||||
- **R25**: Phase state is configurable; extends `ExecutionMode.PLAN_EXEC`, does NOT introduce a new mode.
|
||||
|
||||
Cross-cutting (brainstorm):
|
||||
|
||||
- **R26**: All optimizations configurable via `agentkit.yaml` (follow `ServerConfig.from_dict` pattern established in Waves 1-2).
|
||||
- **R27**: Each optimization ships a minimal self-check test (ponytail rule).
|
||||
|
||||
Acceptance examples relevant to Wave 3:
|
||||
|
||||
- **AE5 already covered by Wave 2 (G9)** — not in Wave 3 scope.
|
||||
- No explicit AE for G5/G6 in the brainstorm — the plan below specifies test scenarios as the acceptance contract.
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
### KTD1: G5 uses Python `ast` + language-aware regex, NOT tree-sitter
|
||||
|
||||
**Decision**: Implement symbol extraction with the Python stdlib `ast` module for Python files, and a small regex-based extractor for TypeScript/JavaScript/Go/Rust/Java. No new dependency.
|
||||
|
||||
**Rationale**:
|
||||
- `tree-sitter` requires native compilation + per-language grammar files (~30MB installed) — violates the ponytail rule "no new dependency if it can be avoided" and AGENTS.md "禁止使用 any 类型" → prefers minimal stack.
|
||||
- `ast` is stdlib, always available, parses Python accurately.
|
||||
- Regex extractor covers 80% case for TS/JS/Go/Rust/Java (function/class/struct declarations); falls back to "no symbols found → read full file" gracefully.
|
||||
- If a future Wave 4 needs more accurate multi-language parsing, `tree-sitter` can replace the regex layer behind the same `SymbolExtractor` interface.
|
||||
|
||||
**Upgrade path**: replace `RegexSymbolExtractor` with `TreeSitterSymbolExtractor` implementing the same `SymbolExtractor` protocol; no caller changes.
|
||||
|
||||
### KTD2: G5 adds a new `ReadFileTool`, does not extend `ShellTool`
|
||||
|
||||
**Decision**: Add a dedicated `ReadFileTool` in `src/agentkit/tools/file_read.py` with `path` + optional `symbol` + optional `start_line`/`end_line` parameters.
|
||||
|
||||
**Rationale**:
|
||||
- `ShellTool` is for shell execution; grafting symbol extraction onto it muddies the contract.
|
||||
- A dedicated tool gives the LLM a clear schema (`{"path": "...", "symbol": "function_name"}`) and a focused system-prompt description.
|
||||
- Aligns with the existing `_DEFAULT_CORE_TOOLS` list in `core/react.py:148` which already references `read_file` — the name is reserved but the implementation is missing.
|
||||
|
||||
### KTD3: G6 phase state machine lives in ReActEngine, not in the skill config
|
||||
|
||||
**Decision**: Phase state (Planning/Building/Verification/Delivery) is tracked as a mutable field on `ReActEngine` instance. Transitions are driven by LLM-detected phase-completion signals (e.g., the LLM emits `Phase: Building` in its thinking) OR by an explicit `advance_phase` tool.
|
||||
|
||||
**Rationale**:
|
||||
- Skill config declares the policy (which tools per phase, auto-advance vs manual); the engine enforces it per-step. This matches R24 ("ReAct 循环加阶段约束").
|
||||
- Alternative considered: phase state in the agent instance (not engine). Rejected because ReActEngine already owns `max_steps`/`verification_enabled` etc.; phase state belongs with the loop that enforces it.
|
||||
|
||||
### KTD4: PLAN_EXEC mode is wired at chat.py WebSocket path (REST already has fallback chain from Wave 2)
|
||||
|
||||
**Decision**: chat.py:1084 (currently warns "not yet supported, falling back to REACT") will route `ExecutionMode.PLAN_EXEC` to a new `_execute_plan_exec_ws` handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and passes it to `ReActEngine.execute`.
|
||||
|
||||
**Rationale**:
|
||||
- REST `send_message` already uses the Wave 2 three-tier fallback chain; PLAN_EXEC at REST would also need that wrapper. **Out of scope for Wave 3** — only WebSocket path is wired. REST PLAN_EXEC remains "not yet supported" and explicitly raises if invoked.
|
||||
- Single integration point keeps Wave 3 scope bounded; REST wiring is a one-line follow-up once WebSocket path is proven.
|
||||
|
||||
### KTD5: Default phase whitelist matches brainstorm R24
|
||||
|
||||
**Decision**: Default whitelist:
|
||||
- `Planning`: `search`, `tool_search`, `read_file`, `bash` (read-only commands like `git status`, `ls`)
|
||||
- `Building`: `write_file`, `bash` (write commands), `read_file`, `search`
|
||||
- `Verification`: `bash` (test commands), `read_file`, `search`
|
||||
- `Delivery`: all tools (final synthesis)
|
||||
|
||||
**Rationale**:
|
||||
- R24 explicitly names `think`/`search` for Planning and `write_file` for Building.
|
||||
- `bash` is split: read-only in Planning, full in Building. Enforced by adding a `bash_command_filter` callback (regex-based, blocks `rm`/`mv`/`>`/`>>` in Planning/Verification).
|
||||
- `Delivery` allows all tools to support last-mile formatting/cleanup.
|
||||
|
||||
### KTD6: Phase transitions are LLM-driven via `advance_phase` tool (opt-in auto-advance)
|
||||
|
||||
**Decision**: Add an `AdvancePhaseTool` that the LLM can call to transition Planning→Building→Verification→Delivery. Auto-advance (after N steps in current phase) is opt-in via `plan_exec.auto_advance_after_steps`.
|
||||
|
||||
**Rationale**:
|
||||
- LLM-driven transitions match Qoder/Trae Work pattern: LLM declares "planning done" explicitly.
|
||||
- Auto-advance is a safety net for LLMs that forget to call `advance_phase`; default off (ponytail: less code is better).
|
||||
|
||||
## Scope Boundaries
|
||||
|
||||
### In Scope
|
||||
|
||||
- New `ReadFileTool` with `symbol` parameter (G5, R22/R23).
|
||||
- `SymbolExtractor` protocol + `AstSymbolExtractor` (Python) + `RegexSymbolExtractor` (TS/JS/Go/Rust/Java).
|
||||
- `PhasePolicy` dataclass + `PhaseState` enum + per-step tool whitelist enforcement in `ReActEngine.execute` (G6, R24/R25).
|
||||
- `AdvancePhaseTool` for LLM-driven phase transitions.
|
||||
- WebSocket chat path routes `PLAN_EXEC` to new `_execute_plan_exec_ws` handler (KTD4).
|
||||
- `plan_exec` config section in `agentkit.yaml` + `ServerConfig.from_dict` extension (R26).
|
||||
- Tests for each new module (R27).
|
||||
|
||||
### Out of Scope (Deferred to Follow-Up Work)
|
||||
|
||||
- REST `send_message` PLAN_EXEC wiring — once WebSocket path is proven, REST wiring is a follow-up commit.
|
||||
- `tree-sitter` integration for more accurate multi-language parsing (KTD1 upgrade path).
|
||||
- Phase-aware prompt engineering (per-phase system prompt templates) — current plan keeps a single system prompt; phase-specific guidance is a prompt-engineering concern, not a code change.
|
||||
- Phase persistence across session resume (U7 checkpoint already saves plan state; phase state restoration is a separate concern).
|
||||
- Phase rollback on `Building` → `Planning` regression (Wave 2 G9 rollback handles file-level rollback; phase regression is a UX/prompt concern).
|
||||
- Tool-filter UI in the frontend (Wave 3 ships backend-only; frontend surfaces phase via existing event channel if needed in a follow-up).
|
||||
|
||||
### Outside This Product's Identity
|
||||
|
||||
- Replacing the existing ReAct loop with LangGraph (inherited from brainstorm).
|
||||
- Disc-based file system à la DeerFlow (inherited).
|
||||
- Docker sandbox (inherited; only command-level safety via `bash_command_filter`).
|
||||
|
||||
## High-Level Technical Design
|
||||
|
||||
```mermaid
|
||||
flowchart LR
|
||||
subgraph G5[Function Sharding]
|
||||
RF[ReadFileTool] --> SE[SymbolExtractor protocol]
|
||||
SE --> AST[AstSymbolExtractor<br/>Python stdlib ast]
|
||||
SE --> RX[RegexSymbolExtractor<br/>TS/JS/Go/Rust/Java]
|
||||
end
|
||||
|
||||
subgraph G6[Phase State Machine]
|
||||
PP[PhasePolicy config] --> PS[PhaseState enum<br/>Planning/Building/Verify/Delivery]
|
||||
PS --> Filt[Tool filter per step]
|
||||
Filt --> RE[ReActEngine.execute]
|
||||
AP[AdvancePhaseTool] -->|transitions| PS
|
||||
end
|
||||
|
||||
RF -->|file content for symbol| RE
|
||||
RE -->|enforces| Filt
|
||||
```
|
||||
|
||||
The two subsystems compose at the ReAct engine boundary: `ReadFileTool` is one of the tools the LLM can call during any phase (filtered by `PhasePolicy`); `PhaseState` is enforced at the tool-call step before dispatch.
|
||||
|
||||
## Implementation Units
|
||||
|
||||
### U1. SymbolExtractor + ReadFileTool (G5)
|
||||
|
||||
**Goal**: Add `ReadFileTool` with optional `symbol` parameter; implement `SymbolExtractor` protocol with `AstSymbolExtractor` (Python) and `RegexSymbolExtractor` (TS/JS/Go/Rust/Java).
|
||||
|
||||
**Requirements**: R22, R23, R27.
|
||||
|
||||
**Dependencies**: none.
|
||||
|
||||
**Files**:
|
||||
- `src/agentkit/tools/file_read.py` (new)
|
||||
- `src/agentkit/tools/symbol_extractor.py` (new)
|
||||
- `src/agentkit/tools/__init__.py` (modify — register `ReadFileTool`)
|
||||
- `tests/unit/test_symbol_extractor.py` (new)
|
||||
- `tests/unit/test_read_file_tool.py` (new)
|
||||
|
||||
**Approach**:
|
||||
- `SymbolExtractor` is a `Protocol` with one method: `extract_symbols(content: str, language: str) -> list[SymbolSpan]`. `SymbolSpan` carries `name`, `kind` (function/class/method/struct), `start_line`, `end_line`.
|
||||
- `AstSymbolExtractor` walks `ast.parse(content)`; for each `FunctionDef`/`AsyncFunctionDef`/`ClassDef` collects name + line range. Uses `ast.get_source_segment` style (line-based, not node-based, to keep the API simple).
|
||||
- `RegexSymbolExtractor` ships patterns for TS/JS (`function X`, `const X = (...) =>`, `class X`), Go (`func X`), Rust (`fn X`, `struct X`, `impl X`), Java (`public ... X(...)`). Falls back to "no symbols" if no pattern matches.
|
||||
- `ReadFileTool.execute(path, symbol=None, start_line=None, end_line=None)`:
|
||||
- `symbol=None` → read full file (backward compat with the existing `_FakeTool` benchmark shape).
|
||||
- `symbol="foo"` → detect language from extension; call `extract_symbols`; return the line range of the first matching symbol; if no match, return an error result with available symbol names listed (so the LLM can retry).
|
||||
- `start_line`/`end_line` overrides symbol; allows manual slicing.
|
||||
- Tool registered as `read_file` (matches the reserved name in `core/react.py:148`).
|
||||
|
||||
**Execution note**: characterization-first — write a test that asserts the tool returns the full file content when `symbol=None` (matches pre-existing benchmark `_FakeTool` shape) before adding symbol-extraction behavior.
|
||||
|
||||
**Patterns to follow**:
|
||||
- `src/agentkit/tools/document_tool.py` for tool structure (dataclass, `Tool` base class, `input_schema`).
|
||||
- `src/agentkit/tools/schema_tools.py:SchemaExtractTool` for "extract-from-source" pattern.
|
||||
|
||||
**Test scenarios** (covers R22, R23):
|
||||
- **Happy paths**:
|
||||
- Python file, `symbol="MyClass"` → returns class body only (lines from `class MyClass:` through end of class).
|
||||
- Python file, `symbol="my_func"` → returns function body only.
|
||||
- TypeScript file, `symbol="renderComponent"` → returns arrow/function body.
|
||||
- Go file, `symbol="HandleRequest"` → returns func body.
|
||||
- **Edge cases**:
|
||||
- `symbol=None` → returns full file content (characterization).
|
||||
- `symbol="nonexistent"` → returns error result listing available symbols ("Available symbols: foo, bar, baz").
|
||||
- Unsupported file extension (`.md`, `.txt`) → returns full file with `note: symbol extraction not supported for .md`.
|
||||
- Empty file → returns empty content.
|
||||
- File with nested classes → outer class symbol returns including inner class.
|
||||
- **Error paths**:
|
||||
- Path does not exist → raises `FileNotFoundError` (or returns error result matching other tools' convention).
|
||||
- Path is a directory → returns error result.
|
||||
- Permission denied → returns error result.
|
||||
- **Integration scenarios**:
|
||||
- Symbol extraction + line slicing: `symbol="foo"`, `end_line=50` truncates at line 50 even if symbol extends further.
|
||||
- Round-trip: extract symbol, write back via `ShellTool` `sed` (not in scope for tool — just verify extracted range is well-formed).
|
||||
|
||||
**Verification**:
|
||||
- `python3 -m pytest tests/unit/test_symbol_extractor.py tests/unit/test_read_file_tool.py -q` passes.
|
||||
- `ruff check src/agentkit/tools/file_read.py src/agentkit/tools/symbol_extractor.py` clean.
|
||||
- `ReadFileTool` appears in `ToolRegistry.list_tools()` after registration.
|
||||
|
||||
---
|
||||
|
||||
### U2. PhasePolicy + PhaseState + ServerConfig (G6 core)
|
||||
|
||||
**Goal**: Add `PhasePolicy` dataclass, `PhaseState` enum, default whitelist config. Extend `ServerConfig.from_dict` with `plan_exec` section. Wire config to `agentkit.yaml`.
|
||||
|
||||
**Requirements**: R25, R26.
|
||||
|
||||
**Dependencies**: none.
|
||||
|
||||
**Files**:
|
||||
- `src/agentkit/core/phase.py` (new) — `PhaseState` enum, `PhasePolicy` dataclass, `default_policy()` factory.
|
||||
- `src/agentkit/server/config.py` (modify — add `plan_exec` field + `from_dict` parsing).
|
||||
- `agentkit.yaml` (modify — document `plan_exec:` section).
|
||||
- `tests/unit/test_phase_policy.py` (new).
|
||||
|
||||
**Approach**:
|
||||
- `PhaseState = enum("planning building verification delivery")`.
|
||||
- `PhasePolicy` carries:
|
||||
- `whitelist: dict[PhaseState, set[str]]` — tool names allowed per phase.
|
||||
- `bash_command_filter: dict[PhaseState, re.Pattern | None]` — regex that bash args must NOT match (e.g., `r"\b(rm|mv|>|>>)\b"` in Planning).
|
||||
- `auto_advance_after_steps: int | None` — None = manual (LLM calls `advance_phase`); int = auto-advance after N steps.
|
||||
- `start_phase: PhaseState = PhaseState.PLANNING`.
|
||||
- `default_policy()` returns the KTD5 whitelist above.
|
||||
- `ServerConfig.from_dict` reads `plan_exec` section: `enabled`, `whitelist_override` (dict), `auto_advance_after_steps`.
|
||||
- `agentkit.yaml` gains a commented-out `plan_exec:` block (commented to preserve default behavior — opt-in).
|
||||
|
||||
**Patterns to follow**:
|
||||
- `src/agentkit/core/fallback.py` for dataclass + classmethod factory pattern.
|
||||
- `src/agentkit/server/config.py` `from_dict` extension template (established in Wave 1 for `prompt_cache`/`streaming`/`verification`; Wave 2 added `rollback`/`fallback_chain`; Wave 3 adds `plan_exec`).
|
||||
|
||||
**Test scenarios** (covers R25, R26):
|
||||
- **Happy paths**:
|
||||
- `default_policy()` returns policy with all four phases; Planning whitelist contains `search`, `read_file`; Building contains `write_file`.
|
||||
- `PhasePolicy.is_tool_allowed("search", PhaseState.PLANNING)` returns True.
|
||||
- `PhasePolicy.is_tool_allowed("write_file", PhaseState.PLANNING)` returns False.
|
||||
- `PhasePolicy.is_tool_allowed("write_file", PhaseState.BUILDING)` returns True.
|
||||
- **Edge cases**:
|
||||
- Empty whitelist for a phase → all tools rejected (raises `ValueError` at construction time — fail-fast).
|
||||
- `Delivery` phase whitelist contains `"*"` (wildcard) → all tools allowed.
|
||||
- Custom whitelist override merges with default (override wins on conflict).
|
||||
- **Error paths**:
|
||||
- Invalid phase name in config → `ValueError` with message naming the bad value.
|
||||
- `bash_command_filter` regex compile failure → `ValueError`.
|
||||
- **Config integration**:
|
||||
- `ServerConfig.from_dict({"plan_exec": {"enabled": True, "auto_advance_after_steps": 5}})` populates fields correctly.
|
||||
- `ServerConfig.from_dict({})` → `plan_exec = {}` (default).
|
||||
|
||||
**Verification**:
|
||||
- `python3 -m pytest tests/unit/test_phase_policy.py -q` passes.
|
||||
- `ruff check src/agentkit/core/phase.py` clean.
|
||||
|
||||
---
|
||||
|
||||
### U3. AdvancePhaseTool + ReActEngine phase enforcement (G6 wiring)
|
||||
|
||||
**Goal**: Add `AdvancePhaseTool`. Wire `PhasePolicy` into `ReActEngine.execute` so each tool-call step checks `is_tool_allowed(tool_name, current_phase)` before dispatch; blocked calls return a structured error to the LLM ("Tool 'write_file' not allowed in Planning phase — call advance_phase first").
|
||||
|
||||
**Requirements**: R24.
|
||||
|
||||
**Dependencies**: U2.
|
||||
|
||||
**Files**:
|
||||
- `src/agentkit/tools/advance_phase.py` (new) — `AdvancePhaseTool` calls `react_engine.advance_phase()`.
|
||||
- `src/agentkit/core/react.py` (modify — add `phase_policy` param to `__init__` + `execute`; add `_current_phase` field; add `advance_phase()` method; enforce in `_execute_loop`).
|
||||
- `tests/unit/test_react_phase_enforcement.py` (new).
|
||||
|
||||
**Approach**:
|
||||
- `ReActEngine.__init__` accepts `phase_policy: PhasePolicy | None = None`. None = no enforcement (backward compat — all existing callers unaffected).
|
||||
- `_current_phase: PhaseState | None` initialized from `phase_policy.start_phase` if policy set, else None.
|
||||
- `advance_phase()` advances `_current_phase` to next enum value; raises `ValueError` if already at `DELIVERY`.
|
||||
- In `_execute_loop`, before dispatching a tool call:
|
||||
```python
|
||||
if self._phase_policy is not None and self._current_phase is not None:
|
||||
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
|
||||
# Inject structured error into conversation, do NOT dispatch tool.
|
||||
# This counts as a "step" for max_steps purposes.
|
||||
observation = {
|
||||
"error": "phase_violation",
|
||||
"message": f"Tool '{tool_name}' not allowed in {self._current_phase.value} phase",
|
||||
"current_phase": self._current_phase.value,
|
||||
"hint": "Call advance_phase to move to Building phase"
|
||||
}
|
||||
continue # next loop iteration
|
||||
```
|
||||
- Auto-advance: if `phase_policy.auto_advance_after_steps` is set and `_steps_in_phase >= auto_advance_after_steps`, call `advance_phase()` automatically.
|
||||
- `AdvancePhaseTool.execute()` calls the bound engine's `advance_phase()` and returns the new phase name. Registered only when `phase_policy` is not None.
|
||||
|
||||
**Execution note**: characterization-first — test that `ReActEngine` with `phase_policy=None` behaves identically to pre-change (no enforcement, no `advance_phase` tool, no `_current_phase` mutation). Then add enforcement tests.
|
||||
|
||||
**Patterns to follow**:
|
||||
- `src/agentkit/core/react.py` `verification_enabled` pattern (feature flag + step-level check).
|
||||
- `src/agentkit/tools/ask_human.py` for tool that interacts with engine state.
|
||||
|
||||
**Test scenarios** (covers R24):
|
||||
- **Characterization (no policy)**:
|
||||
- `ReActEngine(phase_policy=None)` — all tools allowed in all steps; no `advance_phase` tool registered; behavior matches pre-change.
|
||||
- **Happy paths**:
|
||||
- Planning phase: LLM calls `search` → executes; LLM calls `advance_phase` → phase becomes Building.
|
||||
- Building phase: LLM calls `write_file` → executes; LLM calls `advance_phase` → phase becomes Verification.
|
||||
- Verification phase: LLM calls `bash` with `pytest` → executes; LLM calls `advance_phase` → phase becomes Delivery.
|
||||
- Delivery phase: LLM calls any tool → executes (wildcard).
|
||||
- **Edge cases**:
|
||||
- `advance_phase` called at Delivery → returns error "Already at final phase".
|
||||
- Auto-advance after 3 steps in Planning → phase transitions automatically on 4th step.
|
||||
- `bash` command in Planning contains `rm file` → blocked by `bash_command_filter`.
|
||||
- **Error paths**:
|
||||
- LLM calls `write_file` in Planning → tool NOT dispatched; structured error returned to LLM; loop continues.
|
||||
- LLM calls non-existent tool → existing error path (not phase-related).
|
||||
- **Integration scenarios**:
|
||||
- Phase transition emits a `phase_changed` event (use existing `_broadcast_event` pattern from `experts/orchestrator.py`).
|
||||
- `max_steps` reached mid-phase → `ReActResult.status = "max_steps_reached"` (existing path, no change).
|
||||
|
||||
**Verification**:
|
||||
- `python3 -m pytest tests/unit/test_react_phase_enforcement.py -q` passes.
|
||||
- Existing `tests/unit/test_react_engine.py` still passes (characterization — no policy = no change).
|
||||
- `ruff check src/agentkit/core/react.py src/agentkit/tools/advance_phase.py` clean.
|
||||
|
||||
---
|
||||
|
||||
### U4. Wire PLAN_EXEC at chat.py WebSocket path (G6 chat integration)
|
||||
|
||||
**Goal**: Replace the `chat.py:1084` "not yet supported, falling back to REACT" warning with a real PLAN_EXEC handler that constructs `PhasePolicy` from `ServerConfig.plan_exec` and dispatches to `ReActEngine.execute` with the policy set.
|
||||
|
||||
**Requirements**: R24, R25 (end-to-end wiring).
|
||||
|
||||
**Dependencies**: U2, U3.
|
||||
|
||||
**Files**:
|
||||
- `src/agentkit/server/routes/chat.py` (modify — add `_execute_plan_exec_ws` handler; branch on `ExecutionMode.PLAN_EXEC`).
|
||||
- `tests/unit/test_chat_plan_exec_ws.py` (new).
|
||||
|
||||
**Approach**:
|
||||
- New helper `_execute_plan_exec_ws(websocket, agent, routing, messages, ...)`:
|
||||
1. Read `server_config.plan_exec` (may be `{}` if not configured → use `default_policy()`).
|
||||
2. Build `PhasePolicy` from config (apply overrides).
|
||||
3. Construct `ReActEngine(..., phase_policy=policy)`.
|
||||
4. Register `AdvancePhaseTool` bound to this engine.
|
||||
5. Call `engine.execute_stream(...)` — reuses existing streaming path.
|
||||
6. Emit `phase_changed` events through the WebSocket (frontend can render phase indicator).
|
||||
- chat.py:1084 changes from `if execution_mode not in (REACT, SKILL_REACT): warn + fall back` to:
|
||||
```python
|
||||
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
|
||||
await _execute_plan_exec_ws(websocket, agent, routing, ...)
|
||||
return
|
||||
if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
|
||||
# existing warning for REWOO/REFLEXION/TEAM_COLLAB
|
||||
...
|
||||
```
|
||||
- REST `send_message` path: explicitly raise `HTTPException(501, "PLAN_EXEC via REST not yet supported; use WebSocket")` — Wave 3 does NOT wire REST (KTD4).
|
||||
|
||||
**Execution note**: characterization-first — test that existing REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning (no regression). Then add PLAN_EXEC wiring.
|
||||
|
||||
**Patterns to follow**:
|
||||
- `src/agentkit/server/routes/chat.py` existing WebSocket handler structure (lines 1082-1100).
|
||||
- `src/agentkit/server/_fallback_chain.py` (Wave 2 U3) for "construct engine per-request with config" pattern.
|
||||
|
||||
**Test scenarios** (covers end-to-end):
|
||||
- **Characterization**:
|
||||
- `ExecutionMode.REWOO` via WebSocket → still falls back to REACT with warning (existing behavior unchanged).
|
||||
- `ExecutionMode.REFLEXION` → same.
|
||||
- `ExecutionMode.TEAM_COLLAB` → same.
|
||||
- **Happy paths**:
|
||||
- `ExecutionMode.PLAN_EXEC` via WebSocket → `_execute_plan_exec_ws` invoked; `ReActEngine` constructed with `phase_policy`; `AdvancePhaseTool` registered.
|
||||
- Planning phase: LLM emits `search` tool call → executed; tool result streamed.
|
||||
- LLM emits `advance_phase` → `phase_changed` event sent to WebSocket client; subsequent `write_file` call now allowed.
|
||||
- **Edge cases**:
|
||||
- `plan_exec` config absent → `default_policy()` used; behavior matches KTD5 whitelist.
|
||||
- `plan_exec.enabled=False` → falls back to REACT (opt-out).
|
||||
- Phase violation: LLM calls `write_file` in Planning → structured error returned; loop continues; `phase_violation` event emitted.
|
||||
- **Error paths**:
|
||||
- REST `send_message` with PLAN_EXEC → 501 error.
|
||||
- Phase policy construction fails (bad config) → 500 error with message.
|
||||
- **Integration scenarios**:
|
||||
- Existing fallback chain (Wave 2 U3) NOT applied to PLAN_EXEC — phase policy and fallback chain are mutually exclusive (KTD5 from Wave 2 plan: chain only wraps REACT/SKILL_REACT at REST). Document this in chat.py comment.
|
||||
|
||||
**Verification**:
|
||||
- `python3 -m pytest tests/unit/test_chat_plan_exec_ws.py -q` passes.
|
||||
- `ruff check src/agentkit/server/routes/chat.py` clean.
|
||||
- Manual test: `agentkit chat` with `@skill:plan_exec_demo` skill config → WebSocket stream includes `phase_changed` events.
|
||||
|
||||
---
|
||||
|
||||
## Risks & Dependencies
|
||||
|
||||
### Risks
|
||||
|
||||
1. **ReAct core modification risk (high)**: U3 modifies `ReActEngine._execute_loop`. Mitigation: characterization-first tests (U3 Execution note); `phase_policy=None` default preserves all existing behavior; full `test_react_engine.py` regression.
|
||||
2. **Symbol extraction accuracy (medium)**: Regex extractor may miss edge cases (decorated functions, nested generics, multi-line signatures). Mitigation: fall back to "no symbols found → read full file" gracefully; never raise on extraction failure.
|
||||
3. **PLAN_EXEC phase deadlock (medium)**: LLM may never call `advance_phase`, leaving the agent stuck in Planning. Mitigation: `auto_advance_after_steps` config (default 5); timeout via existing `max_steps`.
|
||||
4. **Tool name drift (low)**: Phase whitelist references tool names (`write_file`, `search`, etc.) that may be renamed in future. Mitigation: whitelist is config-driven; rename only requires config update.
|
||||
|
||||
### Dependencies
|
||||
|
||||
- Wave 2 PR #5 (`feat/agent-wave2-medium-coupling`) should be merged first — Wave 3 builds on the `ServerConfig.from_dict` extension pattern and the `_fallback_chain.py` integration shape established there. If PR #5 is still open, Wave 3 branches from `feat/agent-wave2-medium-coupling` rather than `main`.
|
||||
- No external library dependencies (KTD1).
|
||||
|
||||
## System-Wide Impact
|
||||
|
||||
- **Agents using PLAN_EXEC mode**: gain phase enforcement. Existing REACT/SKILL_REACT/DIRECT_CHAT agents: zero change (phase_policy defaults to None).
|
||||
- **Tool registry**: gains two new tools (`read_file`, `advance_phase`). Frontend tool list display may need updating to show the new icons — out of scope for Wave 3 (frontend follows up).
|
||||
- **`agentkit.yaml`**: gains `plan_exec:` section (commented by default). Existing configs unaffected.
|
||||
- **WebSocket clients**: gain `phase_changed` event type. Existing clients ignore unknown event types (verified in Wave 2 — `phase_rollback_*` events follow the same pattern).
|
||||
|
||||
## Sources & Research
|
||||
|
||||
- Origin brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 3 section, KTD6/KTD7).
|
||||
- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (PR #4 merged).
|
||||
- Wave 2 plan: `docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md` (PR #5 open).
|
||||
- Trae Work architecture research (cited in brainstorm): SOLO four-stage state machine pattern.
|
||||
- Qoder architecture research (cited in brainstorm): Spec→Coding→Verify closed loop.
|
||||
- Codebase: `src/agentkit/core/react.py:148` reserves `read_file`/`write_file` tool names in `_DEFAULT_CORE_TOOLS` — Wave 3 U1 delivers the missing `read_file` implementation.
|
||||
- Codebase: `src/agentkit/server/routes/chat.py:1084` documents that PLAN_EXEC is "not yet supported" — Wave 3 U4 closes this gap.
|
||||
|
||||
## Deferred to Implementation
|
||||
|
||||
- Exact regex patterns for non-Python symbol extraction (U1) — design above gives the shape; implementer finalizes patterns based on real-world test fixtures.
|
||||
- `bash_command_filter` regex precision (U2) — defaults block `rm`/`mv`/`>`/`>>`; implementer may add more based on test scenarios.
|
||||
- `phase_changed` event payload shape (U3/U4) — minimal viable shape: `{"phase": "building", "previous": "planning"}`; frontend rendering concerns are out of scope.
|
||||
- Whether `AdvancePhaseTool` accepts a `target_phase` argument for skipping phases (e.g., Planning → Verification) — default no (sequential only); add if test scenarios reveal a need.
|
||||
|
|
@ -41,6 +41,7 @@ class ContextCompressor:
|
|||
model_context_limit: int = 128_000,
|
||||
headroom_threshold: float = 0.8,
|
||||
min_tokens: int = 8_000,
|
||||
auxiliary_model: str | None = None,
|
||||
):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_tokens = max_tokens
|
||||
|
|
@ -51,6 +52,11 @@ class ContextCompressor:
|
|||
self._model_context_limit = model_context_limit
|
||||
self._headroom_threshold = headroom_threshold
|
||||
self._min_tokens = min_tokens
|
||||
# G4/U1: Auxiliary model for cost-sensitive summarization (e.g. "fast" alias).
|
||||
# When set and differs from main model, _summarize tries auxiliary first,
|
||||
# falls back to main model on failure OR empty content (Finding 4 anti-pattern).
|
||||
# ponytail: ceiling — auxiliary is best-effort; main model is authoritative fallback.
|
||||
self._auxiliary_model = auxiliary_model
|
||||
|
||||
def should_compress(self, messages: list[dict]) -> bool:
|
||||
"""Check if compression should be triggered based on headroom ratio.
|
||||
|
|
@ -92,8 +98,8 @@ class ContextCompressor:
|
|||
if len(non_system) <= self._keep_recent:
|
||||
return messages # Not enough messages to compress
|
||||
|
||||
old_msgs = non_system[:-self._keep_recent]
|
||||
recent_msgs = non_system[-self._keep_recent:]
|
||||
old_msgs = non_system[: -self._keep_recent]
|
||||
recent_msgs = non_system[-self._keep_recent :]
|
||||
|
||||
# Compress old messages
|
||||
summary = await self._summarize(old_msgs)
|
||||
|
|
@ -101,10 +107,12 @@ class ContextCompressor:
|
|||
# Build compressed message list
|
||||
compressed = list(system_msgs)
|
||||
if summary:
|
||||
compressed.append({
|
||||
"role": "system",
|
||||
"content": f"## Conversation Summary\n{summary}",
|
||||
})
|
||||
compressed.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"## Conversation Summary\n{summary}",
|
||||
}
|
||||
)
|
||||
compressed.extend(recent_msgs)
|
||||
|
||||
# Recursive check: if still over budget, compress again
|
||||
|
|
@ -114,22 +122,30 @@ class ContextCompressor:
|
|||
return self._truncate(compressed)
|
||||
if len(recent_msgs) > 1:
|
||||
# Try keeping fewer recent messages
|
||||
return await self._compress_aggressive(messages, _compression_depth=_compression_depth + 1)
|
||||
return await self._compress_aggressive(
|
||||
messages, _compression_depth=_compression_depth + 1
|
||||
)
|
||||
# Last resort: truncate
|
||||
return self._truncate(compressed)
|
||||
|
||||
return compressed
|
||||
|
||||
async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str:
|
||||
"""Summarize a list of messages using LLM"""
|
||||
"""Summarize a list of messages using LLM.
|
||||
|
||||
G4/U1: When ``auxiliary_model`` is configured and differs from the main
|
||||
model, try auxiliary first (cost-optimization). On auxiliary failure OR
|
||||
empty content (Finding 4 anti-pattern — "did not throw is not succeeded"),
|
||||
fall back to main model. Existing ``_simple_summary`` degradation
|
||||
preserved as the final tier when main model also fails.
|
||||
"""
|
||||
if not self._llm_gateway:
|
||||
# No LLM available, do simple truncation
|
||||
return self._simple_summary(messages)
|
||||
|
||||
# Build summary prompt
|
||||
conversation_text = "\n".join(
|
||||
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}"
|
||||
for m in messages
|
||||
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" for m in messages
|
||||
)
|
||||
|
||||
# Pre-truncate if conversation_text exceeds safe token threshold
|
||||
|
|
@ -145,6 +161,25 @@ class ContextCompressor:
|
|||
f"{conversation_text}"
|
||||
)
|
||||
|
||||
# G4: Try auxiliary model first when configured (cheap route).
|
||||
if self._auxiliary_model and self._auxiliary_model != self._model:
|
||||
try:
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model=self._auxiliary_model,
|
||||
agent_name="compressor",
|
||||
task_type="summarization",
|
||||
)
|
||||
# Finding 4: empty content is a failure, not a success.
|
||||
if response.content and response.content.strip():
|
||||
return response.content
|
||||
logger.info("Auxiliary model returned empty content, falling back to main model")
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Auxiliary model summarization failed, falling back to main model: {e}"
|
||||
)
|
||||
|
||||
# Main model path (or auxiliary fallback).
|
||||
try:
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
|
|
@ -166,7 +201,9 @@ class ContextCompressor:
|
|||
parts.append(f"[{role}]: {content}...")
|
||||
return "\n".join(parts)
|
||||
|
||||
async def _compress_aggressive(self, messages: list[dict], _compression_depth: int = 0) -> list[dict]:
|
||||
async def _compress_aggressive(
|
||||
self, messages: list[dict], _compression_depth: int = 0
|
||||
) -> list[dict]:
|
||||
"""More aggressive compression when standard compression isn't enough"""
|
||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||
non_system = [m for m in messages if m.get("role") != "system"]
|
||||
|
|
@ -176,10 +213,12 @@ class ContextCompressor:
|
|||
summary = await self._summarize(non_system[:-1])
|
||||
compressed = list(system_msgs)
|
||||
if summary:
|
||||
compressed.append({
|
||||
"role": "system",
|
||||
"content": f"## Conversation Summary\n{summary}",
|
||||
})
|
||||
compressed.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"## Conversation Summary\n{summary}",
|
||||
}
|
||||
)
|
||||
compressed.append(non_system[-1])
|
||||
return compressed
|
||||
|
||||
|
|
@ -191,7 +230,7 @@ class ContextCompressor:
|
|||
for msg in messages:
|
||||
content = str(msg.get("content", ""))
|
||||
if len(content) > self._max_tokens * 4:
|
||||
msg = {**msg, "content": content[:self._max_tokens * 4] + "...[truncated]"}
|
||||
msg = {**msg, "content": content[: self._max_tokens * 4] + "...[truncated]"}
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
|
@ -226,6 +265,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
|
|||
if provider == "headroom":
|
||||
try:
|
||||
from agentkit.core.headroom_compressor import HeadroomCompressor
|
||||
|
||||
compressor = HeadroomCompressor(config)
|
||||
if compressor.is_available():
|
||||
return compressor
|
||||
|
|
@ -235,8 +275,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
|
|||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"HeadroomCompressor module not available. "
|
||||
"Falling back to ContextCompressor."
|
||||
"HeadroomCompressor module not available. Falling back to ContextCompressor."
|
||||
)
|
||||
# Fallback to summary compressor
|
||||
return ContextCompressor(
|
||||
|
|
@ -253,11 +292,9 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
|
|||
|
||||
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
|
||||
"""Render PromptTemplate with caching - returns cached result for same variables"""
|
||||
cache_key = hashlib.md5(
|
||||
json.dumps(variables or {}, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
cache_key = hashlib.md5(json.dumps(variables or {}, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
if not hasattr(template, '_render_cache'):
|
||||
if not hasattr(template, "_render_cache"):
|
||||
template._render_cache = {}
|
||||
|
||||
if cache_key in template._render_cache:
|
||||
|
|
@ -270,5 +307,5 @@ def render_cached(template, variables: dict[str, Any] | None = None) -> list[dic
|
|||
|
||||
def clear_cache(template) -> None:
|
||||
"""Clear the render cache on a PromptTemplate instance"""
|
||||
if hasattr(template, '_render_cache'):
|
||||
if hasattr(template, "_render_cache"):
|
||||
template._render_cache.clear()
|
||||
|
|
|
|||
|
|
@ -2,8 +2,22 @@
|
|||
|
||||
All layers (ReActEngine, Portal, Chat) should use these constants
|
||||
to ensure consistent user-facing messages.
|
||||
|
||||
G7/U2: Also hosts ``EmergencyError`` and ``EmergencyRules`` for the
|
||||
three-tier fallback chain's Emergency layer (rule-based classifier,
|
||||
no LLM). See ``docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md``.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.exceptions import (
|
||||
LLMProviderError,
|
||||
LoopDetectedError,
|
||||
TaskCancelledError,
|
||||
TaskTimeoutError,
|
||||
)
|
||||
|
||||
# When LLM returns empty content after all fallback models exhausted
|
||||
EMPTY_LLM_RESPONSE = (
|
||||
"模型未返回有效内容,已尝试备用模型仍未成功。"
|
||||
|
|
@ -16,3 +30,135 @@ MAX_STEPS_REACHED = "已达到最大推理步数,但仍未得到完整结论
|
|||
|
||||
# When a shell command succeeds but produces no output
|
||||
SHELL_NO_OUTPUT = "[命令执行成功,无输出内容]"
|
||||
|
||||
|
||||
# ── G7/U2: Emergency layer ──────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmergencyError:
|
||||
"""Structured error produced by the Emergency layer (rule-classified).
|
||||
|
||||
Carries a stable ``error_code`` for programmatic dispatch (frontend
|
||||
retry UI, telemetry), a human-readable ``message`` mirroring
|
||||
``EMPTY_LLM_RESPONSE`` style, actionable ``suggestions``, and the
|
||||
original exception string for traceability.
|
||||
|
||||
The ``retryable`` flag distinguishes recoverable user errors
|
||||
(timeout, loop, LLM hiccup) from internal bugs (retryable=False).
|
||||
"""
|
||||
|
||||
error_code: str # "timeout" | "loop_detected" | "llm_failure" | "internal_error"
|
||||
message: str # human-readable Chinese message
|
||||
suggestions: list[str] # actionable user-facing suggestions
|
||||
retryable: bool # whether a user retry might succeed
|
||||
original_error: str # str(exc) for traceability
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"error_code": self.error_code,
|
||||
"message": self.message,
|
||||
"suggestions": list(self.suggestions),
|
||||
"retryable": self.retryable,
|
||||
"original_error": self.original_error,
|
||||
}
|
||||
|
||||
def to_error_message(self) -> str:
|
||||
"""Format as a single human-readable string with suggestions.
|
||||
|
||||
Mirrors ``EMPTY_LLM_RESPONSE`` style: ``<message>建议:1) ... 2) ...``
|
||||
"""
|
||||
if not self.suggestions:
|
||||
return self.message
|
||||
suggestion_str = "".join(f"{i}) {s};" for i, s in enumerate(self.suggestions, 1))
|
||||
# Strip trailing ";" and prefix with "建议:"
|
||||
suggestion_str = suggestion_str.rstrip(";")
|
||||
return f"{self.message}建议:{suggestion_str}。"
|
||||
|
||||
|
||||
# Default rule set: (exception_type, error_code, message, suggestions, retryable)
|
||||
# ponytail: ceiling — rule-based, no LLM. Adding a new rule = append a tuple.
|
||||
# Upgrade path: LLM-driven suggestions would require a separate classifier
|
||||
# class (out of scope per brainstorm KTD4).
|
||||
_DEFAULT_RULES: list[tuple[type[Exception], str, str, list[str], bool]] = [
|
||||
(
|
||||
TaskTimeoutError,
|
||||
"timeout",
|
||||
"任务执行超时。",
|
||||
["稍后重试", "简化任务范围"],
|
||||
True,
|
||||
),
|
||||
(
|
||||
LoopDetectedError,
|
||||
"loop_detected",
|
||||
"检测到推理循环。",
|
||||
["拆分任务", "检查工具参数"],
|
||||
True,
|
||||
),
|
||||
(
|
||||
LLMProviderError,
|
||||
"llm_failure",
|
||||
"LLM 服务调用失败。",
|
||||
["稍后重试", "切换模型"],
|
||||
True,
|
||||
),
|
||||
]
|
||||
|
||||
_DEFAULT_ERROR_CODE = "internal_error"
|
||||
_DEFAULT_MESSAGE = "Agent 执行内部错误。"
|
||||
_DEFAULT_SUGGESTIONS: list[str] = ["联系管理员"]
|
||||
_DEFAULT_RETRYABLE = False
|
||||
|
||||
|
||||
class EmergencyRules:
|
||||
"""Rule-based classifier for the Emergency layer.
|
||||
|
||||
Maps exception types to ``EmergencyError`` instances. No LLM, no I/O —
|
||||
pure function of ``(exception, config)``.
|
||||
|
||||
Caller responsibility: ``TaskCancelledError`` MUST propagate as-is
|
||||
(per KTD3); the caller checks ``isinstance(exc, TaskCancelledError)``
|
||||
before invoking :meth:`classify`. Calling :meth:`classify` with a
|
||||
``TaskCancelledError`` raises ``ValueError`` to surface the bug.
|
||||
|
||||
Config override shape (optional ``config`` arg):
|
||||
|
||||
```python
|
||||
{
|
||||
"timeout": {"suggestions": ["自定义建议"], "retryable": False},
|
||||
"llm_failure": {"message": "自定义消息"},
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def classify(exc: Exception, config: dict | None = None) -> EmergencyError:
|
||||
if isinstance(exc, TaskCancelledError):
|
||||
# Contract: caller must check before invoking classify.
|
||||
raise ValueError(
|
||||
"TaskCancelledError must propagate as-is; caller must check "
|
||||
"before invoking EmergencyRules.classify"
|
||||
)
|
||||
|
||||
config = config or {}
|
||||
|
||||
for exc_type, code, message, suggestions, retryable in _DEFAULT_RULES:
|
||||
if isinstance(exc, exc_type):
|
||||
override = config.get(code, {}) if isinstance(config, dict) else {}
|
||||
return EmergencyError(
|
||||
error_code=code,
|
||||
message=override.get("message", message),
|
||||
suggestions=override.get("suggestions", list(suggestions)),
|
||||
retryable=override.get("retryable", retryable),
|
||||
original_error=str(exc),
|
||||
)
|
||||
|
||||
# Generic fallback
|
||||
override = config.get(_DEFAULT_ERROR_CODE, {}) if isinstance(config, dict) else {}
|
||||
return EmergencyError(
|
||||
error_code=_DEFAULT_ERROR_CODE,
|
||||
message=override.get("message", _DEFAULT_MESSAGE),
|
||||
suggestions=override.get("suggestions", list(_DEFAULT_SUGGESTIONS)),
|
||||
retryable=override.get("retryable", _DEFAULT_RETRYABLE),
|
||||
original_error=str(exc),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,206 @@
|
|||
"""Phase state machine for PLAN_EXEC mode (G6, R24/R25).
|
||||
|
||||
Four sequential phases enforce per-step tool whitelists:
|
||||
PLANNING → BUILDING → VERIFICATION → DELIVERY
|
||||
|
||||
KTD3 (Wave 3 plan): state machine lives in ReActEngine, not skill config.
|
||||
KTD5: default whitelist matches brainstorm R24 (Planning: think/search;
|
||||
Building: write_file; etc.).
|
||||
KTD6: transitions are LLM-driven via AdvancePhaseTool; auto-advance is opt-in.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PhaseState(enum.Enum):
|
||||
"""Phases of the SOLO state machine (extends ExecutionMode.PLAN_EXEC)."""
|
||||
|
||||
PLANNING = "planning"
|
||||
BUILDING = "building"
|
||||
VERIFICATION = "verification"
|
||||
DELIVERY = "delivery"
|
||||
|
||||
@classmethod
|
||||
def next_of(cls, current: "PhaseState") -> "PhaseState | None":
|
||||
"""Return the phase after `current`, or None if `current` is the last."""
|
||||
order = [cls.PLANNING, cls.BUILDING, cls.VERIFICATION, cls.DELIVERY]
|
||||
try:
|
||||
idx = order.index(current)
|
||||
except ValueError:
|
||||
return None
|
||||
if idx + 1 >= len(order):
|
||||
return None
|
||||
return order[idx + 1]
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "PhaseState":
|
||||
"""Parse from string (case-insensitive). Raises ValueError on unknown."""
|
||||
try:
|
||||
return cls(value.lower())
|
||||
except ValueError as e:
|
||||
valid = ", ".join(p.value for p in cls)
|
||||
raise ValueError(f"Invalid phase name {value!r}. Valid: {valid}") from e
|
||||
|
||||
|
||||
# Wildcard token meaning "all tools allowed in this phase".
|
||||
WILDCARD = "*"
|
||||
|
||||
# Default bash command filter for PLANNING and VERIFICATION phases — blocks
|
||||
# commands that mutate the filesystem or execute arbitrary code.
|
||||
# ponytail: regex is intentionally conservative; misses some shell idioms
|
||||
# (e.g., `:>file`, `dd of=file`). Ceiling: a real shell parser would catch
|
||||
# more. Upgrade path = reuse ShellTool._is_dangerous() at enforcement time.
|
||||
# Note: `\b` is a word boundary — works for word commands (rm/mv) but NOT
|
||||
# for `>`/`>>` operators (not word chars). Use a non-boundary alternation
|
||||
# that matches `>` either as a standalone operator or after whitespace.
|
||||
_DEFAULT_BASH_FILTER = re.compile(r"\b(rm|mv|cp|mkdir|rmdir|chmod|chown)\b|(?<!\S)>|>>")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PhasePolicy:
|
||||
"""Per-phase tool whitelist + bash command filter for PLAN_EXEC mode.
|
||||
|
||||
The policy is enforced by ReActEngine._execute_loop before each tool
|
||||
dispatch. A tool not in the current phase's whitelist is rejected with
|
||||
a structured error returned to the LLM (the loop continues — the LLM
|
||||
gets to react to the rejection and either switch tools or call
|
||||
AdvancePhaseTool).
|
||||
|
||||
Wildcard ``"*"`` in a phase's whitelist means "all tools allowed"
|
||||
(used by DELIVERY by default).
|
||||
"""
|
||||
|
||||
whitelist: dict[PhaseState, frozenset[str]]
|
||||
bash_command_filter: dict[PhaseState, re.Pattern | None] = field(default_factory=dict)
|
||||
auto_advance_after_steps: int | None = None # None = manual (LLM calls advance_phase)
|
||||
start_phase: PhaseState = PhaseState.PLANNING
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Fail-fast: empty whitelist for a non-wildcard phase = bug.
|
||||
for phase, tools in self.whitelist.items():
|
||||
if not tools:
|
||||
raise ValueError(
|
||||
f"Phase {phase.value!r} has an empty whitelist — set ['*'] for "
|
||||
f"'all tools allowed' or list specific tool names."
|
||||
)
|
||||
|
||||
def is_tool_allowed(self, tool_name: str, phase: PhaseState) -> bool:
|
||||
"""Return True if `tool_name` is allowed in `phase`."""
|
||||
allowed = self.whitelist.get(phase, frozenset())
|
||||
if WILDCARD in allowed:
|
||||
return True
|
||||
return tool_name in allowed
|
||||
|
||||
def is_bash_command_allowed(self, command: str, phase: PhaseState) -> bool:
|
||||
"""Return True if `command` passes the bash filter for `phase`.
|
||||
|
||||
A None filter = no restriction. An empty command is allowed (ShellTool
|
||||
separately rejects empty commands).
|
||||
"""
|
||||
pattern = self.bash_command_filter.get(phase)
|
||||
if pattern is None:
|
||||
return True
|
||||
return not pattern.search(command)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize for logging/telemetry. Not round-trippable (regex → str)."""
|
||||
return {
|
||||
"whitelist": {phase.value: sorted(tools) for phase, tools in self.whitelist.items()},
|
||||
"bash_command_filter": {
|
||||
phase.value: (p.pattern if p else None)
|
||||
for phase, p in self.bash_command_filter.items()
|
||||
},
|
||||
"auto_advance_after_steps": self.auto_advance_after_steps,
|
||||
"start_phase": self.start_phase.value,
|
||||
}
|
||||
|
||||
|
||||
def default_policy() -> PhasePolicy:
|
||||
"""Return the KTD5 default PhasePolicy.
|
||||
|
||||
Whitelist (R24):
|
||||
- PLANNING: search, tool_search, read_file, shell (read-only)
|
||||
- BUILDING: write_file, shell (full), read_file, search
|
||||
- VERIFICATION: shell (test commands), read_file, search
|
||||
- DELIVERY: all tools (wildcard)
|
||||
|
||||
Bash filter:
|
||||
- PLANNING/VERIFICATION: blocks filesystem-mutating commands
|
||||
(rm/mv/cp/mkdir/chmod/chown/>/>>)
|
||||
- BUILDING/DELIVERY: no filter (full bash)
|
||||
"""
|
||||
return PhasePolicy(
|
||||
whitelist={
|
||||
# Tool name is "shell" (ShellTool default); bash_command_filter
|
||||
# gates on the same name. Using "bash" here would make the filter
|
||||
# dead code and block the LLM from shell access.
|
||||
PhaseState.PLANNING: frozenset({"search", "tool_search", "read_file", "shell"}),
|
||||
PhaseState.BUILDING: frozenset(
|
||||
{"write_file", "shell", "read_file", "search", "tool_search"}
|
||||
),
|
||||
PhaseState.VERIFICATION: frozenset({"shell", "read_file", "search"}),
|
||||
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||
},
|
||||
bash_command_filter={
|
||||
PhaseState.PLANNING: _DEFAULT_BASH_FILTER,
|
||||
PhaseState.VERIFICATION: _DEFAULT_BASH_FILTER,
|
||||
PhaseState.BUILDING: None,
|
||||
PhaseState.DELIVERY: None,
|
||||
},
|
||||
auto_advance_after_steps=None, # manual by default
|
||||
start_phase=PhaseState.PLANNING,
|
||||
)
|
||||
|
||||
|
||||
def policy_from_config(config: dict[str, Any]) -> PhasePolicy | None:
|
||||
"""Build a PhasePolicy from the `plan_exec` config section.
|
||||
|
||||
Returns None if `config` is empty or `enabled` is False (opt-out).
|
||||
|
||||
Config shape:
|
||||
plan_exec:
|
||||
enabled: true # default true if section present
|
||||
auto_advance_after_steps: 5 # optional
|
||||
start_phase: planning # optional, default planning
|
||||
whitelist_override: # optional, merges with default
|
||||
planning: [search, read_file]
|
||||
building: [write_file, bash]
|
||||
"""
|
||||
if not config:
|
||||
return None
|
||||
if config.get("enabled", True) is False:
|
||||
return None
|
||||
|
||||
policy = default_policy()
|
||||
|
||||
# Start phase
|
||||
start_phase_str = config.get("start_phase")
|
||||
if start_phase_str:
|
||||
policy = replace(policy, start_phase=PhaseState.from_string(start_phase_str))
|
||||
|
||||
# Auto-advance override
|
||||
if "auto_advance_after_steps" in config:
|
||||
policy = replace(policy, auto_advance_after_steps=config["auto_advance_after_steps"])
|
||||
|
||||
# Whitelist override — merge with default (override wins on conflict)
|
||||
override = config.get("whitelist_override") or {}
|
||||
if override:
|
||||
new_whitelist = dict(policy.whitelist)
|
||||
for phase_name, tools in override.items():
|
||||
phase = PhaseState.from_string(phase_name)
|
||||
if not isinstance(tools, list):
|
||||
raise ValueError(
|
||||
f"whitelist_override[{phase_name!r}] must be a list, got {type(tools).__name__}"
|
||||
)
|
||||
new_whitelist[phase] = frozenset(str(t) for t in tools)
|
||||
policy = replace(policy, whitelist=new_whitelist)
|
||||
|
||||
return policy
|
||||
|
|
@ -128,6 +128,10 @@ class TaskResult:
|
|||
completed_at: datetime
|
||||
metrics: dict | None = None
|
||||
trace: Any | None = None
|
||||
# G7/U2: Emergency layer structured error. None preserves existing contract
|
||||
# (error_message alone carries the human-readable string). When set,
|
||||
# carries serialized EmergencyError.to_dict() for programmatic dispatch.
|
||||
error_struct: dict | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = {
|
||||
|
|
@ -142,6 +146,8 @@ class TaskResult:
|
|||
}
|
||||
if self.trace is not None:
|
||||
d["trace"] = self.trace.to_dict() if hasattr(self.trace, "to_dict") else self.trace
|
||||
if self.error_struct is not None:
|
||||
d["error_struct"] = self.error_struct
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
|
|
@ -162,6 +168,7 @@ class TaskResult:
|
|||
completed_at=completed_at or datetime.now(timezone.utc),
|
||||
metrics=data.get("metrics"),
|
||||
trace=data.get("trace"),
|
||||
error_struct=data.get("error_struct"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from agentkit.telemetry.metrics import (
|
|||
if TYPE_CHECKING:
|
||||
from agentkit.core.compressor import CompressionStrategy
|
||||
from agentkit.core.middleware import MiddlewareChain
|
||||
from agentkit.core.phase import PhasePolicy, PhaseState
|
||||
from agentkit.core.trace import TraceRecorder
|
||||
from agentkit.memory.retriever import MemoryRetriever
|
||||
|
||||
|
|
@ -168,6 +169,9 @@ class ReActEngine:
|
|||
prompt_cache_enable: bool = True,
|
||||
flush_interval_ms: int = 0,
|
||||
max_reinjections: int = 1,
|
||||
# U3/G6: PLAN_EXEC phase policy (opt-in). None = no enforcement
|
||||
# (backward compat — all existing callers unaffected).
|
||||
phase_policy: "PhasePolicy | None" = None,
|
||||
):
|
||||
if max_steps < 1:
|
||||
raise ValueError(f"max_steps must be >= 1, got {max_steps}")
|
||||
|
|
@ -211,6 +215,15 @@ class ReActEngine:
|
|||
self._loop_corrected: bool = False
|
||||
# U6: Middleware chain (parallel integration, feature flag controlled)
|
||||
self._middleware_chain = middleware_chain
|
||||
# U3/G6: PLAN_EXEC phase state. None = no enforcement (default).
|
||||
# When set, _execute_loop checks each tool call against the current
|
||||
# phase's whitelist before dispatch.
|
||||
self._phase_policy = phase_policy
|
||||
self._current_phase: "PhaseState | None" = (
|
||||
phase_policy.start_phase if phase_policy is not None else None
|
||||
)
|
||||
# Steps taken in the current phase (for auto-advance safety net).
|
||||
self._steps_in_phase: int = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for reuse across conversations.
|
||||
|
|
@ -223,6 +236,99 @@ class ReActEngine:
|
|||
# This method exists for API clarity and future stateful extensions.
|
||||
self._loop_window.clear()
|
||||
self._loop_corrected = False
|
||||
# U3/G6: reset phase state to start_phase (if policy set). Each
|
||||
# execute() call begins a fresh PLANNING phase.
|
||||
if self._phase_policy is not None:
|
||||
self._current_phase = self._phase_policy.start_phase
|
||||
self._steps_in_phase = 0
|
||||
|
||||
# ── U3/G6: phase state machine ────────────────────────────────────
|
||||
|
||||
def advance_phase(self) -> "PhaseState | None":
|
||||
"""Advance to the next phase. Returns the new phase, or None if
|
||||
already at DELIVERY (final phase).
|
||||
|
||||
Called by AdvancePhaseTool when the LLM explicitly signals phase
|
||||
completion. Also called by the auto-advance safety net when
|
||||
``steps_in_phase >= auto_advance_after_steps``.
|
||||
|
||||
Returns None if no phase_policy is set (no-op).
|
||||
"""
|
||||
if self._phase_policy is None or self._current_phase is None:
|
||||
return None
|
||||
from agentkit.core.phase import PhaseState
|
||||
|
||||
nxt = PhaseState.next_of(self._current_phase)
|
||||
if nxt is None:
|
||||
# Already at DELIVERY — return None to signal no transition.
|
||||
return None
|
||||
previous = self._current_phase
|
||||
self._current_phase = nxt
|
||||
self._steps_in_phase = 0
|
||||
logger.info(
|
||||
"Phase transition: %s → %s",
|
||||
previous.value,
|
||||
nxt.value,
|
||||
)
|
||||
return nxt
|
||||
|
||||
@property
|
||||
def current_phase(self) -> "PhaseState | None":
|
||||
"""Current phase (None if no phase_policy set)."""
|
||||
return self._current_phase
|
||||
|
||||
def _maybe_auto_advance(self) -> bool:
|
||||
"""Auto-advance phase if step budget exhausted. Returns True if advanced."""
|
||||
if self._phase_policy is None or self._current_phase is None:
|
||||
return False
|
||||
threshold = self._phase_policy.auto_advance_after_steps
|
||||
if threshold is None:
|
||||
return False
|
||||
if self._steps_in_phase >= threshold:
|
||||
self.advance_phase()
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_phase_permission(
|
||||
self, tool_name: str, arguments: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Return None if tool is allowed; return a structured error dict if blocked.
|
||||
|
||||
The error dict replaces what `_execute_tool` would have returned —
|
||||
the loop continues, so the LLM can react to the rejection (call
|
||||
AdvancePhaseTool or pick a different tool).
|
||||
|
||||
Also applies the bash_command_filter for `bash` tool calls.
|
||||
"""
|
||||
if self._phase_policy is None or self._current_phase is None:
|
||||
return None
|
||||
if not self._phase_policy.is_tool_allowed(tool_name, self._current_phase):
|
||||
return {
|
||||
"error": "phase_violation",
|
||||
"message": (
|
||||
f"Tool {tool_name!r} not allowed in {self._current_phase.value} phase. "
|
||||
f"Call `advance_phase` to move to the next phase."
|
||||
),
|
||||
"current_phase": self._current_phase.value,
|
||||
"tool": tool_name,
|
||||
"is_error": True,
|
||||
}
|
||||
# Bash command filter (only applies to shell tool — registered as "shell").
|
||||
if tool_name == "shell":
|
||||
command = str(arguments.get("command", ""))
|
||||
if not self._phase_policy.is_bash_command_allowed(command, self._current_phase):
|
||||
return {
|
||||
"error": "phase_violation",
|
||||
"message": (
|
||||
f"Bash command blocked in {self._current_phase.value} phase "
|
||||
f"(filesystem-mutating operations not allowed during "
|
||||
f"planning/verification). Command: {command[:100]}"
|
||||
),
|
||||
"current_phase": self._current_phase.value,
|
||||
"tool": tool_name,
|
||||
"is_error": True,
|
||||
}
|
||||
return None
|
||||
|
||||
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
|
||||
"""检测重复工具调用模式。
|
||||
|
|
@ -498,6 +604,14 @@ class ReActEngine:
|
|||
if cancellation_token is not None:
|
||||
cancellation_token.check()
|
||||
|
||||
# U3/G6: phase auto-advance safety net.
|
||||
# Incremented per step (LLM call), not per tool_call. When
|
||||
# auto_advance_after_steps is set, advance the phase after
|
||||
# the LLM has been stuck in the same phase for N steps.
|
||||
if self._phase_policy is not None:
|
||||
self._steps_in_phase += 1
|
||||
self._maybe_auto_advance()
|
||||
|
||||
# Think: 调用 LLM
|
||||
llm_start = time.monotonic()
|
||||
response = await self._llm_gateway.chat(
|
||||
|
|
@ -1148,6 +1262,11 @@ class ReActEngine:
|
|||
if cancellation_token is not None:
|
||||
cancellation_token.check()
|
||||
|
||||
# U3/G6: phase auto-advance safety net (mirrors _execute_loop).
|
||||
if self._phase_policy is not None:
|
||||
self._steps_in_phase += 1
|
||||
self._maybe_auto_advance()
|
||||
|
||||
# 超时检查
|
||||
if effective_timeout > 0:
|
||||
elapsed = time.monotonic() - _stream_start
|
||||
|
|
@ -2069,6 +2188,20 @@ class ReActEngine:
|
|||
self, tool_name: str, arguments: dict[str, Any], tools: list[Tool]
|
||||
) -> dict:
|
||||
"""执行工具调用,处理成功和失败情况"""
|
||||
# U3/G6: phase enforcement — check before dispatch. If the tool is
|
||||
# blocked, return a structured error instead of dispatching. The loop
|
||||
# still counts this as a step (the LLM gets to react to the rejection).
|
||||
# `advance_phase` tool bypasses the check (it's the LLM's escape hatch).
|
||||
if tool_name != "advance_phase":
|
||||
block = self._check_phase_permission(tool_name, arguments)
|
||||
if block is not None:
|
||||
logger.info(
|
||||
"Phase violation: tool %r blocked in %s phase",
|
||||
tool_name,
|
||||
self._current_phase.value if self._current_phase else "?",
|
||||
)
|
||||
return block
|
||||
|
||||
tool = self._find_tool(tool_name, tools)
|
||||
if tool is None:
|
||||
error_msg = f"Tool '{tool_name}' not found"
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from typing import Any
|
|||
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.orchestrator.rollback import RollbackExecutor
|
||||
|
||||
from .expert import Expert
|
||||
from .plan import (
|
||||
|
|
@ -72,12 +73,17 @@ class TeamOrchestrator:
|
|||
MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution
|
||||
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
|
||||
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
|
||||
# G9/U4: RollbackExecutor default timeout for validation_command / rollback_command.
|
||||
# Override via constructor `rollback_timeout` from `rollback.default_timeout` config.
|
||||
DEFAULT_ROLLBACK_TIMEOUT = 30.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
team: ExpertTeam,
|
||||
max_concurrent_phases: int | None = None,
|
||||
checkpoint: Any = None,
|
||||
workspace_root: str | None = None,
|
||||
rollback_timeout: float | None = None,
|
||||
) -> None:
|
||||
self._team = team
|
||||
# Track temporary agent names created for context isolation (KTD3)
|
||||
|
|
@ -93,6 +99,10 @@ class TeamOrchestrator:
|
|||
self._phase_semaphore = asyncio.Semaphore(limit)
|
||||
# U7: Pipeline checkpoint for crash recovery
|
||||
self._checkpoint = checkpoint
|
||||
# G9/U4: workspace_root drives RollbackExecutor cwd; rollback_timeout drives its timeout.
|
||||
# Both default to no-op-friendly values so existing call sites behave identically.
|
||||
self._workspace_root = workspace_root
|
||||
self._rollback_timeout = rollback_timeout or self.DEFAULT_ROLLBACK_TIMEOUT
|
||||
|
||||
async def execute(self, task: str) -> dict[str, Any]:
|
||||
"""Execute a task in pipeline mode.
|
||||
|
|
@ -262,8 +272,23 @@ class TeamOrchestrator:
|
|||
else:
|
||||
phase_results[ph.id] = result
|
||||
|
||||
# G9/U4: opt-in rollback (KTD6) + checkpoint ordering (R21).
|
||||
# When phase configures both validation_command and rollback_command:
|
||||
# 1. run validation_command — if it passes, treat phase as recoverable, save checkpoint
|
||||
# 2. if validation fails, run rollback_command
|
||||
# 3. if rollback passes (exit 0), save checkpoint
|
||||
# 4. if rollback fails, skip checkpoint (R21 — avoid persisting broken state)
|
||||
# When neither command is set, behavior is unchanged (existing save).
|
||||
should_save_checkpoint = True
|
||||
if (
|
||||
ph.validation_command
|
||||
and ph.rollback_command
|
||||
and isinstance(result, (Exception, asyncio.CancelledError))
|
||||
):
|
||||
should_save_checkpoint = await self._run_phase_rollback(plan, ph)
|
||||
|
||||
# U7: Save checkpoint after phase finalizes (success or failure)
|
||||
if self._checkpoint is not None:
|
||||
if should_save_checkpoint and self._checkpoint is not None:
|
||||
try:
|
||||
await self._checkpoint.save(plan.id, ph, plan.status.value)
|
||||
except Exception as e:
|
||||
|
|
@ -393,9 +418,7 @@ class TeamOrchestrator:
|
|||
# PENDING phases remain PENDING — will be executed by _run_pipeline
|
||||
|
||||
# P2 #8: Restore debate count so MAX_DEBATES limit holds after resume
|
||||
self._debate_count = sum(
|
||||
1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE
|
||||
)
|
||||
self._debate_count = sum(1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE)
|
||||
|
||||
logger.info(
|
||||
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
|
||||
|
|
@ -688,9 +711,9 @@ class TeamOrchestrator:
|
|||
and prev_phase.result
|
||||
):
|
||||
# U4: Resolve offloaded content from workspace
|
||||
collaboration_outputs[contract.from_expert] = (
|
||||
await self._read_dependency_output(prev_phase)
|
||||
)
|
||||
collaboration_outputs[
|
||||
contract.from_expert
|
||||
] = await self._read_dependency_output(prev_phase)
|
||||
break
|
||||
|
||||
# Emit expert_step event
|
||||
|
|
@ -1809,6 +1832,75 @@ class TeamOrchestrator:
|
|||
# Recursively mark their dependents
|
||||
await self._mark_dependents_failed(ph.id, plan, phase_results)
|
||||
|
||||
async def _run_phase_rollback(self, plan: TeamPlan, ph: PlanPhase) -> bool:
|
||||
"""G9/U4: run validation_command + rollback_command for a failed phase.
|
||||
|
||||
Returns True if checkpoint save should proceed (R21 ordering).
|
||||
- Validation passes → save checkpoint (phase state recoverable)
|
||||
- Validation fails, rollback passes → save checkpoint (rolled back state)
|
||||
- Validation fails, rollback fails → skip checkpoint (broken state)
|
||||
- Subprocess spawn failure or timeout → skip checkpoint
|
||||
"""
|
||||
executor = RollbackExecutor(
|
||||
working_dir=self._workspace_root,
|
||||
timeout=self._rollback_timeout,
|
||||
)
|
||||
await self._broadcast_event(
|
||||
"phase_rollback_started",
|
||||
{
|
||||
"plan_id": plan.id,
|
||||
"phase_id": ph.id,
|
||||
"phase_name": ph.name,
|
||||
"validation_command": ph.validation_command,
|
||||
"rollback_command": ph.rollback_command,
|
||||
},
|
||||
)
|
||||
# ponytail: validate first; if validation passes, rollback is skipped (no need).
|
||||
validation = await executor.validate(ph.validation_command or "")
|
||||
if validation.passed:
|
||||
await self._broadcast_event(
|
||||
"phase_rollback_completed",
|
||||
{
|
||||
"plan_id": plan.id,
|
||||
"phase_id": ph.id,
|
||||
"phase_name": ph.name,
|
||||
"rollback_executed": False,
|
||||
"validation_passed": True,
|
||||
},
|
||||
)
|
||||
return True
|
||||
|
||||
rollback = await executor.execute(ph.rollback_command or "")
|
||||
if rollback.passed:
|
||||
await self._broadcast_event(
|
||||
"phase_rollback_completed",
|
||||
{
|
||||
"plan_id": plan.id,
|
||||
"phase_id": ph.id,
|
||||
"phase_name": ph.name,
|
||||
"rollback_executed": True,
|
||||
"validation_passed": False,
|
||||
"rollback_stdout": rollback.stdout,
|
||||
},
|
||||
)
|
||||
return True
|
||||
|
||||
logger.error(
|
||||
f"Rollback failed for phase {ph.id} ({ph.name}): exit={rollback.exit_code} stderr={rollback.stderr}"
|
||||
)
|
||||
await self._broadcast_event(
|
||||
"phase_rollback_failed",
|
||||
{
|
||||
"plan_id": plan.id,
|
||||
"phase_id": ph.id,
|
||||
"phase_name": ph.name,
|
||||
"validation_passed": False,
|
||||
"rollback_exit_code": rollback.exit_code,
|
||||
"rollback_stderr": rollback.stderr,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
async def _synthesize_results(
|
||||
self, lead: Expert, task: str, completed_phases: list[PlanPhase]
|
||||
) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -182,6 +182,11 @@ class PlanPhase:
|
|||
collaboration_contracts: list[CollaborationContract] = field(default_factory=list)
|
||||
rework_count: int = 0
|
||||
review_feedback: str | None = None
|
||||
# G9/U4: opt-in rollback fields. When unset, no rollback executes (KTD6).
|
||||
# validation_command runs first; if it fails, rollback_command runs.
|
||||
# canonical rollback pattern: `git checkout <specific_files>`.
|
||||
validation_command: str | None = None
|
||||
rollback_command: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""序列化为字典"""
|
||||
|
|
@ -192,7 +197,7 @@ class PlanPhase:
|
|||
result_str = self.result.get("content", str(self.result))
|
||||
else:
|
||||
result_str = str(self.result)
|
||||
return {
|
||||
out: dict[str, Any] = {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"assigned_expert": self.assigned_expert,
|
||||
|
|
@ -206,6 +211,12 @@ class PlanPhase:
|
|||
"rework_count": self.rework_count,
|
||||
"review_feedback": self.review_feedback,
|
||||
}
|
||||
# G9/U4: only include new keys when set, to preserve pre-change dict shape (KTD6).
|
||||
if self.validation_command is not None:
|
||||
out["validation_command"] = self.validation_command
|
||||
if self.rollback_command is not None:
|
||||
out["rollback_command"] = self.rollback_command
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> PlanPhase:
|
||||
|
|
@ -230,6 +241,8 @@ class PlanPhase:
|
|||
collaboration_contracts=contracts,
|
||||
rework_count=data.get("rework_count", 0),
|
||||
review_feedback=data.get("review_feedback"),
|
||||
validation_command=data.get("validation_command"),
|
||||
rollback_command=data.get("rollback_command"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -203,6 +203,9 @@ class LLMConfig:
|
|||
model_aliases: dict[str, str] = field(default_factory=dict)
|
||||
fallbacks: dict[str, list[str]] = field(default_factory=dict)
|
||||
cache: CacheConfig | None = None
|
||||
# G4/U1: Auxiliary model alias for cost-sensitive tasks (e.g. summarization).
|
||||
# Resolved via existing model_aliases mechanism. None = use main model only.
|
||||
auxiliary_model: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "LLMConfig":
|
||||
|
|
@ -254,4 +257,5 @@ class LLMConfig:
|
|||
model_aliases=data.get("model_aliases", {}),
|
||||
fallbacks=data.get("fallbacks", {}),
|
||||
cache=cache,
|
||||
auxiliary_model=data.get("auxiliary_model"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
"""RollbackExecutor — 阶段失败后执行回滚命令 (G9/U4)
|
||||
|
||||
复用 VerificationLoop 的 asyncio.create_subprocess_shell 模式 (KTD7):
|
||||
绕过 ShellTool,避免 confirm_callback 对 `git checkout` 的拦截。
|
||||
|
||||
设计依据:
|
||||
- KTD6: 回滚是 opt-in 行为,未配置 rollback_command 时不会执行
|
||||
- KTD7: 不走 ShellTool,避免 _is_dangerous 触发 confirm_callback
|
||||
- R21: checkpoint.save 仅在回滚校验通过后调用
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RollbackResult:
|
||||
"""单次 subprocess 执行结果"""
|
||||
|
||||
passed: bool
|
||||
exit_code: int
|
||||
stdout: str
|
||||
stderr: str
|
||||
command: str
|
||||
|
||||
|
||||
class RollbackExecutor:
|
||||
"""执行 validation_command / rollback_command 的子进程封装
|
||||
|
||||
与 VerificationLoop 同构,但语义不同:
|
||||
- validate(): 返回 passed=False 表示校验失败,需要触发 rollback
|
||||
- execute(): 返回 passed=False 表示回滚本身失败,需跳过 checkpoint.save (R21)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
working_dir: str | None = None,
|
||||
timeout: float = 30.0,
|
||||
) -> None:
|
||||
self._working_dir = working_dir
|
||||
self._timeout = timeout
|
||||
|
||||
async def _run(self, command: str) -> RollbackResult:
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=self._working_dir,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001 - subprocess spawn failure surface
|
||||
return RollbackResult(
|
||||
passed=False,
|
||||
exit_code=-1,
|
||||
stdout="",
|
||||
stderr=f"Failed to spawn command: {e}",
|
||||
command=command,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self._timeout)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
proc.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
await proc.wait()
|
||||
return RollbackResult(
|
||||
passed=False,
|
||||
exit_code=-1,
|
||||
stdout="",
|
||||
stderr=f"Command timed out after {self._timeout}s: {command}",
|
||||
command=command,
|
||||
)
|
||||
|
||||
out_str = stdout.decode("utf-8", errors="replace") if stdout else ""
|
||||
err_str = stderr.decode("utf-8", errors="replace") if stderr else ""
|
||||
return RollbackResult(
|
||||
passed=proc.returncode == 0,
|
||||
exit_code=proc.returncode if proc.returncode is not None else -1,
|
||||
stdout=out_str,
|
||||
stderr=err_str,
|
||||
command=command,
|
||||
)
|
||||
|
||||
async def validate(self, command: str) -> RollbackResult:
|
||||
"""运行 validation_command,passed=False 表示需要触发 rollback"""
|
||||
return await self._run(command)
|
||||
|
||||
async def execute(self, command: str) -> RollbackResult:
|
||||
"""运行 rollback_command,passed=False 表示回滚本身失败 (R21)"""
|
||||
return await self._run(command)
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
"""G7/U3 — Three-tier fallback chain (main → Recovery → Emergency).
|
||||
|
||||
Wired at chat.py REST send_message endpoint. Composes U2's EmergencyRules
|
||||
with existing ReflexionEngine for the Recovery layer.
|
||||
|
||||
Scope (KTD5): Only the chat REST path is wrapped. CLI / ReWOO / Reflexion
|
||||
internal ReAct calls are NOT wrapped (would create recursive loop).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.exceptions import (
|
||||
LLMProviderError,
|
||||
LoopDetectedError,
|
||||
TaskCancelledError,
|
||||
TaskTimeoutError,
|
||||
)
|
||||
from agentkit.core.fallback import EmergencyError, EmergencyRules
|
||||
from agentkit.core.react import ReActEngine, ReActResult
|
||||
from agentkit.core.reflexion import ReflexionEngine
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ReActResult.status values that indicate soft failure → trigger Recovery.
|
||||
# "success" is the only clean-pass; everything else is fallback-worthy.
|
||||
_SOFT_FAILURE_STATUSES = frozenset({"empty_fallback", "verify_failed", "timeout"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatExecutionResult:
|
||||
"""Wrapper produced by execute_with_fallback_chain.
|
||||
|
||||
Carries a ReActResult-like ``output`` field plus an optional
|
||||
``error_struct`` (set only when Emergency tier fires). The chat
|
||||
handler reads ``.output`` for the assistant reply and ``.error_struct``
|
||||
for the optional structured error payload.
|
||||
"""
|
||||
|
||||
output: str
|
||||
status: str # "success" | "recovered" | "emergency"
|
||||
error_struct: dict[str, Any] | None = None
|
||||
trajectory: list[Any] = field(default_factory=list)
|
||||
total_steps: int = 0
|
||||
total_tokens: int = 0
|
||||
fallback_strategy: str | None = None
|
||||
|
||||
|
||||
def _react_to_chat_result(react: ReActResult) -> ChatExecutionResult:
|
||||
return ChatExecutionResult(
|
||||
output=react.output,
|
||||
status="success",
|
||||
trajectory=react.trajectory,
|
||||
total_steps=react.total_steps,
|
||||
total_tokens=react.total_tokens,
|
||||
fallback_strategy=react.fallback_strategy,
|
||||
)
|
||||
|
||||
|
||||
def _reflexion_to_chat_result(reflexion_result: Any) -> ChatExecutionResult:
|
||||
"""Best-effort conversion from ReflexionResult to ChatExecutionResult."""
|
||||
output = getattr(reflexion_result, "output", None) or getattr(
|
||||
reflexion_result, "final_answer", ""
|
||||
)
|
||||
return ChatExecutionResult(
|
||||
output=output or "",
|
||||
status="recovered",
|
||||
trajectory=getattr(reflexion_result, "trajectory", []) or [],
|
||||
total_steps=getattr(reflexion_result, "total_steps", 0),
|
||||
total_tokens=getattr(reflexion_result, "total_tokens", 0),
|
||||
fallback_strategy="reflexion_recovery",
|
||||
)
|
||||
|
||||
|
||||
def _to_emergency(exc: Exception, config: dict | None) -> ChatExecutionResult:
|
||||
emergency: EmergencyError = EmergencyRules.classify(exc, config)
|
||||
return ChatExecutionResult(
|
||||
output=emergency.to_error_message(),
|
||||
status="emergency",
|
||||
error_struct=emergency.to_dict(),
|
||||
fallback_strategy="emergency",
|
||||
)
|
||||
|
||||
|
||||
async def execute_with_fallback_chain(
|
||||
*,
|
||||
react_engine: ReActEngine,
|
||||
llm_gateway: LLMGateway,
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Any] | None,
|
||||
model: str,
|
||||
agent_name: str,
|
||||
system_prompt: str | None,
|
||||
fallback_chain_config: dict | None = None,
|
||||
) -> ChatExecutionResult:
|
||||
"""Three-tier fallback chain: Main → Recovery (ReflexionEngine) → Emergency.
|
||||
|
||||
KTD5: only this entry point wraps the chain. ReflexionEngine's internal
|
||||
ReAct call bypasses the chain (no recursive loop possible).
|
||||
|
||||
Returns ChatExecutionResult with status:
|
||||
- "success": main agent succeeded
|
||||
- "recovered": main failed, ReflexionEngine recovery succeeded
|
||||
- "emergency": main failed, recovery failed/exhausted, Emergency layer fired
|
||||
"""
|
||||
config = fallback_chain_config or {}
|
||||
recovery_cfg = config.get("recovery", {}) if isinstance(config, dict) else {}
|
||||
emergency_cfg = config.get("emergency", {}) if isinstance(config, dict) else {}
|
||||
recovery_enabled = recovery_cfg.get("enabled", True) if isinstance(recovery_cfg, dict) else True
|
||||
emergency_enabled = (
|
||||
emergency_cfg.get("enabled", True) if isinstance(emergency_cfg, dict) else True
|
||||
)
|
||||
max_reflections = recovery_cfg.get("max_retries", 1) if isinstance(recovery_cfg, dict) else 1
|
||||
|
||||
# ── Tier 1: Main ──────────────────────────────────────────────
|
||||
main_exc: Exception | None = None
|
||||
try:
|
||||
result = await react_engine.execute(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
if result.status == "success":
|
||||
return _react_to_chat_result(result)
|
||||
# Soft failure (empty_fallback / verify_failed / timeout) → trigger Recovery
|
||||
if result.status in _SOFT_FAILURE_STATUSES:
|
||||
main_exc = AgentSoftFailureError(
|
||||
f"main agent status={result.status}: {result.output[:200]}"
|
||||
)
|
||||
else:
|
||||
# Unknown status — treat as success-like (don't trigger recovery)
|
||||
return _react_to_chat_result(result)
|
||||
except TaskCancelledError:
|
||||
# KTD3: TaskCancelledError propagates as-is, NOT routed to Emergency.
|
||||
raise
|
||||
except (TaskTimeoutError, LoopDetectedError, LLMProviderError) as exc:
|
||||
main_exc = exc
|
||||
except Exception as exc: # noqa: BLE001 - last-resort catch for Emergency routing
|
||||
main_exc = exc
|
||||
|
||||
# ── Tier 2: Recovery (ReflexionEngine) ────────────────────────
|
||||
if recovery_enabled and main_exc is not None:
|
||||
try:
|
||||
reflexion = ReflexionEngine(
|
||||
llm_gateway=llm_gateway,
|
||||
max_reflections=max_reflections,
|
||||
)
|
||||
recovery_result = await reflexion.execute(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
# Recovery succeeds if Reflexion reports success or produces output.
|
||||
recovery_status = getattr(recovery_result, "status", "")
|
||||
if recovery_status == "success" or getattr(recovery_result, "output", None):
|
||||
return _reflexion_to_chat_result(recovery_result)
|
||||
logger.warning(
|
||||
f"Recovery layer did not succeed (status={recovery_status}), "
|
||||
f"falling through to Emergency"
|
||||
)
|
||||
except TaskCancelledError:
|
||||
raise
|
||||
except Exception as recovery_exc: # noqa: BLE001
|
||||
logger.warning(f"Recovery layer raised: {recovery_exc}; falling through to Emergency")
|
||||
|
||||
# ── Tier 3: Emergency ─────────────────────────────────────────
|
||||
if not emergency_enabled:
|
||||
# Re-raise original exception if Emergency disabled.
|
||||
if main_exc is not None:
|
||||
raise main_exc
|
||||
# No exception but no success either — synthesise an emergency-style result.
|
||||
return ChatExecutionResult(
|
||||
output="Agent 未返回有效结果且 Emergency 层已禁用。",
|
||||
status="emergency",
|
||||
fallback_strategy="emergency_disabled",
|
||||
)
|
||||
|
||||
# main_exc may be None if main returned soft-failure status without raising.
|
||||
# Synthesize a generic exception for Emergency classification.
|
||||
exc_for_emergency = main_exc or AgentSoftFailureError("soft failure without exception")
|
||||
return _to_emergency(exc_for_emergency, config)
|
||||
|
||||
|
||||
class AgentSoftFailureError(Exception):
|
||||
"""Internal marker — main agent returned a soft-failure status without raising.
|
||||
|
||||
Used to feed the Emergency classifier when main status was e.g.
|
||||
``empty_fallback`` (no exception raised, but result not usable).
|
||||
Classified as ``internal_error`` by EmergencyRules (generic fallback).
|
||||
"""
|
||||
|
|
@ -119,6 +119,11 @@ class ServerConfig:
|
|||
prompt_cache: dict[str, Any] | None = None,
|
||||
streaming: dict[str, Any] | None = None,
|
||||
verification: dict[str, Any] | None = None,
|
||||
rollback: dict[str, Any] | None = None,
|
||||
fallback_chain: dict[str, Any] | None = None,
|
||||
# G6/U2: PLAN_EXEC phase policy config (opt-in — None = disabled).
|
||||
# Parsed via PhasePolicy.policy_from_config() at chat.py wiring time.
|
||||
plan_exec: dict[str, Any] | None = None,
|
||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||
):
|
||||
self.host = host
|
||||
|
|
@ -153,6 +158,16 @@ class ServerConfig:
|
|||
# U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1)
|
||||
# verification_enabled=False 时此配置无效
|
||||
self.verification = verification or {}
|
||||
# G9/U4: rollback.default_timeout 控制 RollbackExecutor subprocess 超时
|
||||
# PlanPhase.rollback_command 未设置时此配置无效 (KTD6 opt-in)
|
||||
self.rollback = rollback or {}
|
||||
# G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries}
|
||||
# controls three-tier chain at chat.py REST send_message (KTD5).
|
||||
self.fallback_chain = fallback_chain or {}
|
||||
# G6/U2: plan_exec phase policy config (opt-in — empty dict = disabled).
|
||||
# Resolved to PhasePolicy via agentkit.core.phase.policy_from_config()
|
||||
# at chat.py WebSocket wiring time (U4).
|
||||
self.plan_exec = plan_exec or {}
|
||||
self.on_change = on_change
|
||||
|
||||
# Config watching state
|
||||
|
|
@ -240,6 +255,12 @@ class ServerConfig:
|
|||
prompt_cache_data = data.get("prompt_cache", {})
|
||||
streaming_data = data.get("streaming", {})
|
||||
verification_data = data.get("verification", {})
|
||||
# G9/U4: rollback 配置 (从 YAML 读取,opt-in)
|
||||
rollback_data = data.get("rollback", {})
|
||||
# G7/U3: fallback_chain 配置 (从 YAML 读取)
|
||||
fallback_chain_data = data.get("fallback_chain", {})
|
||||
# G6/U2: plan_exec phase policy 配置 (从 YAML 读取, opt-in)
|
||||
plan_exec_data = data.get("plan_exec", {})
|
||||
|
||||
return cls(
|
||||
host=server.get("host", "0.0.0.0"),
|
||||
|
|
@ -271,6 +292,9 @@ class ServerConfig:
|
|||
prompt_cache=prompt_cache_data,
|
||||
streaming=streaming_data,
|
||||
verification=verification_data,
|
||||
rollback=rollback_data,
|
||||
fallback_chain=fallback_chain_data,
|
||||
plan_exec=plan_exec_data,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -320,6 +344,9 @@ class ServerConfig:
|
|||
model_aliases=model_aliases,
|
||||
fallbacks=data.get("fallbacks", {}),
|
||||
cache=cache_config,
|
||||
# G4/U1: auxiliary model alias for cost-sensitive summarization.
|
||||
# Resolved via model_aliases; None = no auxiliary routing.
|
||||
auxiliary_model=data.get("auxiliary_model"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -25,10 +25,13 @@ from fastapi.responses import FileResponse
|
|||
from pydantic import BaseModel
|
||||
|
||||
from agentkit.chat.skill_routing import ExecutionMode
|
||||
from agentkit.core.phase import PhasePolicy, default_policy, policy_from_config
|
||||
from agentkit.core.protocol import CancellationToken
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.server._fallback_chain import execute_with_fallback_chain
|
||||
from agentkit.session.manager import SessionManager
|
||||
from agentkit.session.models import MessageRole, SessionStatus
|
||||
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -46,6 +49,8 @@ class CreateSessionRequest(BaseModel):
|
|||
class SendMessageRequest(BaseModel):
|
||||
content: str
|
||||
role: str = "user"
|
||||
# Optional execution mode override. "plan_exec" → 501 (KTD4: WebSocket only).
|
||||
execution_mode: str | None = None
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
|
|
@ -582,6 +587,13 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
|||
if session.status == SessionStatus.CLOSED:
|
||||
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
|
||||
|
||||
# KTD4: PLAN_EXEC is wired only at the WebSocket path. REST raises 501.
|
||||
if request.execution_mode == "plan_exec":
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="PLAN_EXEC via REST not yet supported; use WebSocket",
|
||||
)
|
||||
|
||||
# Append user message
|
||||
await sm.append_message(
|
||||
session_id=session_id,
|
||||
|
|
@ -610,7 +622,15 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
|||
system_prompt = getattr(agent, "_system_prompt", None) or (
|
||||
agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
|
||||
)
|
||||
result = await react_engine.execute(
|
||||
# G7/U3: Three-tier fallback chain (main → Recovery → Emergency).
|
||||
# Wired only here (KTD5); CLI / ReWOO / Reflexion internal ReAct bypass.
|
||||
server_config = getattr(req.app.state, "server_config", None)
|
||||
fallback_chain_cfg = (
|
||||
getattr(server_config, "fallback_chain", None) if server_config else None
|
||||
)
|
||||
chat_result = await execute_with_fallback_chain(
|
||||
react_engine=react_engine,
|
||||
llm_gateway=req.app.state.llm_gateway,
|
||||
messages=chat_messages,
|
||||
tools=tools,
|
||||
model=agent.get_model()
|
||||
|
|
@ -618,16 +638,26 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
|||
else getattr(agent, "_llm_model", "default"),
|
||||
agent_name=agent.name,
|
||||
system_prompt=system_prompt,
|
||||
fallback_chain_config=fallback_chain_cfg,
|
||||
)
|
||||
|
||||
# Append assistant reply
|
||||
assistant_msg = await sm.append_message(
|
||||
session_id=session_id,
|
||||
role=MessageRole.ASSISTANT,
|
||||
content=result.output if hasattr(result, "output") else str(result),
|
||||
content=chat_result.output,
|
||||
agent_name=agent.name,
|
||||
)
|
||||
return _message_to_response(assistant_msg)
|
||||
response = _message_to_response(assistant_msg)
|
||||
# Attach structured error payload when Emergency tier fired.
|
||||
if chat_result.error_struct is not None:
|
||||
response_dict = (
|
||||
response.model_dump() if hasattr(response, "model_dump") else dict(response)
|
||||
)
|
||||
response_dict["error_struct"] = chat_result.error_struct
|
||||
response_dict["fallback_status"] = chat_result.status
|
||||
return response_dict
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Chat execution error for session {session_id}: {e}")
|
||||
|
|
@ -1060,21 +1090,73 @@ async def _handle_chat_message(
|
|||
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
|
||||
return
|
||||
|
||||
# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
|
||||
# currently fall back to REACT with a warning.
|
||||
if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
|
||||
# U4/G6: PLAN_EXEC — build PhasePolicy from server config (KTD4: WebSocket only).
|
||||
# KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC — phase policy and
|
||||
# fallback chain are mutually exclusive. PLAN_EXEC uses its own engine.
|
||||
phase_policy: PhasePolicy | None = None
|
||||
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
|
||||
server_config = getattr(websocket.app.state, "server_config", None)
|
||||
plan_exec_cfg = getattr(server_config, "plan_exec", None) or {}
|
||||
|
||||
if plan_exec_cfg.get("enabled", True) is False:
|
||||
# Explicit opt-out → fall back to REACT.
|
||||
logger.info(
|
||||
"PLAN_EXEC disabled by config (plan_exec.enabled=False), "
|
||||
"falling back to REACT for session %s",
|
||||
session_id,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
phase_policy = policy_from_config(plan_exec_cfg)
|
||||
if phase_policy is None:
|
||||
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
|
||||
phase_policy = default_policy()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"PLAN_EXEC phase policy construction failed for session %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
# Truncate to 200 chars to match nearby error paths and
|
||||
# avoid leaking config internals (see chat.py:1090, 1320).
|
||||
"data": {"message": f"phase policy error: {str(e)[:200]}"},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Handle advanced execution modes: REWOO/REFLEXION/TEAM_COLLAB
|
||||
# still fall back to REACT with a warning. PLAN_EXEC is handled above.
|
||||
if routing.execution_mode not in (
|
||||
ExecutionMode.REACT,
|
||||
ExecutionMode.SKILL_REACT,
|
||||
ExecutionMode.PLAN_EXEC,
|
||||
):
|
||||
logger.warning(
|
||||
f"Execution mode {routing.execution_mode.value} not yet supported "
|
||||
f"in chat WebSocket, falling back to REACT"
|
||||
)
|
||||
|
||||
# Execute Agent with streaming
|
||||
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
|
||||
react_engine = getattr(agent, "_react_engine", None)
|
||||
if react_engine is None:
|
||||
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
||||
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization).
|
||||
# PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the
|
||||
# agent's _react_engine — it has no policy).
|
||||
if phase_policy is not None:
|
||||
react_engine = ReActEngine(
|
||||
llm_gateway=websocket.app.state.llm_gateway,
|
||||
phase_policy=phase_policy,
|
||||
)
|
||||
# Register AdvancePhaseTool bound to this engine (LLM's escape hatch).
|
||||
advance_phase_tool = AdvancePhaseTool(engine=react_engine)
|
||||
routing.tools = list(routing.tools) + [advance_phase_tool]
|
||||
else:
|
||||
react_engine.reset()
|
||||
react_engine = getattr(agent, "_react_engine", None)
|
||||
if react_engine is None:
|
||||
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
||||
else:
|
||||
react_engine.reset()
|
||||
|
||||
# Create confirmation handler that sends request to frontend and waits for reply
|
||||
# Use the same dict object — do NOT use `or {}` because an empty dict is falsy
|
||||
|
|
@ -1130,6 +1212,9 @@ async def _handle_chat_message(
|
|||
try:
|
||||
final_content = ""
|
||||
token_buffer: list[str] = []
|
||||
# Track phase transitions for phase_changed events (PLAN_EXEC only).
|
||||
# For non-PLAN_EXEC modes, current_phase is always None → no events.
|
||||
prev_phase = react_engine.current_phase
|
||||
async for event in react_engine.execute_stream(
|
||||
messages=chat_messages,
|
||||
tools=routing.tools,
|
||||
|
|
@ -1207,6 +1292,22 @@ async def _handle_chat_message(
|
|||
}
|
||||
)
|
||||
|
||||
# U4/G6: emit phase_changed event when the phase state machine
|
||||
# transitions (PLAN_EXEC only). For non-PLAN_EXEC modes,
|
||||
# current_phase is always None → this branch never fires.
|
||||
curr_phase = react_engine.current_phase
|
||||
if curr_phase != prev_phase:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "phase_changed",
|
||||
"data": {
|
||||
"phase": curr_phase.value if curr_phase else None,
|
||||
"previous": prev_phase.value if prev_phase else None,
|
||||
},
|
||||
}
|
||||
)
|
||||
prev_phase = curr_phase
|
||||
|
||||
# Append assistant reply to session
|
||||
if final_content:
|
||||
await sm.append_message(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ from agentkit.tools.memory_tool import MemoryTool
|
|||
from agentkit.tools.web_search import WebSearchTool
|
||||
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
|
||||
from agentkit.tools.search import ToolSearchIndex
|
||||
from agentkit.tools.file_read import ReadFileTool
|
||||
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||
|
||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||
try:
|
||||
|
|
@ -52,4 +54,6 @@ __all__ = [
|
|||
"OutputParser",
|
||||
"ParsedOutput",
|
||||
"ErrorType",
|
||||
"ReadFileTool",
|
||||
"AdvancePhaseTool",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,82 @@
|
|||
"""AdvancePhaseTool — LLM-driven phase transition (G6, KTD6).
|
||||
|
||||
Registered alongside other tools when ReActEngine has a phase_policy set.
|
||||
The LLM calls this tool to signal "I'm done planning, move to building".
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.core.react import ReActEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdvancePhaseTool(Tool):
|
||||
"""Tool that advances the ReActEngine's current phase.
|
||||
|
||||
KTD6: LLM-driven phase transitions. Auto-advance is opt-in via
|
||||
``plan_exec.auto_advance_after_steps``; this tool is the manual path.
|
||||
|
||||
The tool holds a weak reference to the engine (via bound method
|
||||
``engine.advance_phase``) — registered only when phase_policy is set.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: "ReActEngine",
|
||||
name: str = "advance_phase",
|
||||
description: str | None = None,
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description
|
||||
or (
|
||||
"Advance the PLAN_EXEC phase state machine to the next phase "
|
||||
"(Planning → Building → Verification → Delivery). Call this "
|
||||
"when you have finished the current phase's work and are ready "
|
||||
"to move on. Returns the new phase name or an error if you "
|
||||
"are already at the final (Delivery) phase."
|
||||
),
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"additionalProperties": False,
|
||||
},
|
||||
version=version,
|
||||
tags=tags or ["phase", "control"],
|
||||
)
|
||||
self._engine = engine
|
||||
|
||||
async def execute(self, **kwargs) -> dict[str, Any]:
|
||||
# Capture previous phase before transition (engine is single-threaded per request).
|
||||
previous = self._engine.current_phase
|
||||
new_phase = self._engine.advance_phase()
|
||||
if new_phase is None:
|
||||
# Either no policy set, or already at DELIVERY.
|
||||
current = self._engine.current_phase
|
||||
if current is None:
|
||||
return {
|
||||
"is_error": True,
|
||||
"error": "no_phase_policy",
|
||||
"message": "No phase policy is set — advance_phase is a no-op.",
|
||||
}
|
||||
return {
|
||||
"is_error": True,
|
||||
"error": "already_at_final_phase",
|
||||
"message": (f"Already at final phase ({current.value}). Cannot advance further."),
|
||||
"current_phase": current.value,
|
||||
}
|
||||
return {
|
||||
"is_error": False,
|
||||
"previous_phase": previous.value if previous else "",
|
||||
"current_phase": new_phase.value,
|
||||
"message": f"Phase advanced to {new_phase.value}.",
|
||||
}
|
||||
|
|
@ -0,0 +1,262 @@
|
|||
"""ReadFileTool — file reading with optional symbol-level sharding (G5, R22/R23).
|
||||
|
||||
Backward compatible with the pre-existing `_FakeTool` benchmark shape — when
|
||||
`symbol=None`, returns the full file content. When `symbol="foo"`, returns
|
||||
the line range of the first matching symbol via `SymbolExtractor`.
|
||||
|
||||
KTD2 (Wave 3 plan): dedicated tool, does NOT extend ShellTool — keeps the
|
||||
file-reading contract clean and gives the LLM a focused schema.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
from agentkit.tools.symbol_extractor import (
|
||||
SymbolSpan,
|
||||
extract_symbols_from_file,
|
||||
language_for_extension,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
|
||||
"""Read a file from the filesystem, optionally sliced to a single symbol.
|
||||
|
||||
Tool name `read_file` matches the reserved entry in
|
||||
`core/react.py:_DEFAULT_CORE_TOOLS` (which previously had no real
|
||||
implementation — only `_FakeTool` stubs in `cli/benchmark.py`).
|
||||
|
||||
Backward-compat contract: `symbol=None` returns the full file content,
|
||||
matching the shape `{"path": ...}` that downstream callers (benchmark,
|
||||
phase whitelist) already expect.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "read_file",
|
||||
description: str | None = None,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description
|
||||
or (
|
||||
"Read a file from the filesystem. By default returns the full file "
|
||||
"content. Pass `symbol` (function/class/struct name) to slice to just "
|
||||
"that symbol's line range — saves context when you only need one "
|
||||
"function from a large file. Pass `start_line`/`end_line` for manual "
|
||||
"slicing. If `symbol` is set but not found, returns the available "
|
||||
"symbol names so you can retry."
|
||||
),
|
||||
input_schema=input_schema or self._default_input_schema(),
|
||||
output_schema=output_schema or self._default_output_schema(),
|
||||
version=version,
|
||||
tags=tags or ["io", "file", "read"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _default_input_schema() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to read (absolute or relative to cwd).",
|
||||
},
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional: name of a function/class/struct/method to slice to. "
|
||||
"When set, returns only the line range of the first matching "
|
||||
"symbol. Supported languages: py, ts/tsx, js/jsx, go, rs, java."
|
||||
),
|
||||
},
|
||||
"start_line": {
|
||||
"type": "integer",
|
||||
"description": "Optional 1-based start line for manual slicing. Overrides `symbol`.",
|
||||
"minimum": 1,
|
||||
},
|
||||
"end_line": {
|
||||
"type": "integer",
|
||||
"description": "Optional 1-based end line (inclusive) for manual slicing. Overrides `symbol`.",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _default_output_schema() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string"},
|
||||
"path": {"type": "string"},
|
||||
"start_line": {"type": "integer"},
|
||||
"end_line": {"type": "integer"},
|
||||
"symbol": {"type": "string"},
|
||||
"available_symbols": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Populated when `symbol` is set but not found.",
|
||||
},
|
||||
"note": {"type": "string"},
|
||||
"is_error": {"type": "boolean"},
|
||||
"error": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs) -> dict[str, Any]:
|
||||
raw_path = kwargs.get("path")
|
||||
if not raw_path:
|
||||
return self._error("`path` is required")
|
||||
|
||||
path = Path(raw_path)
|
||||
if not path.is_absolute():
|
||||
path = path.resolve()
|
||||
|
||||
symbol = kwargs.get("symbol")
|
||||
start_line = kwargs.get("start_line")
|
||||
end_line = kwargs.get("end_line")
|
||||
|
||||
# Validate/sanitize line overrides.
|
||||
if start_line is not None and (not isinstance(start_line, int) or start_line < 1):
|
||||
return self._error(f"`start_line` must be a positive integer, got {start_line!r}")
|
||||
if end_line is not None and (not isinstance(end_line, int) or end_line < 1):
|
||||
return self._error(f"`end_line` must be a positive integer, got {end_line!r}")
|
||||
if start_line is not None and end_line is not None and end_line < start_line:
|
||||
return self._error(f"`end_line` ({end_line}) must be >= `start_line` ({start_line})")
|
||||
|
||||
# Filesystem checks.
|
||||
if not path.exists():
|
||||
return self._error(f"File not found: {path}", path=str(path))
|
||||
if path.is_dir():
|
||||
return self._error(f"Path is a directory, not a file: {path}", path=str(path))
|
||||
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8", errors="replace")
|
||||
except PermissionError as e:
|
||||
return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
|
||||
except OSError as e:
|
||||
return self._error(f"Failed to read {path}: {e}", path=str(path))
|
||||
|
||||
lines = content.splitlines()
|
||||
total_lines = len(lines)
|
||||
|
||||
# Manual slicing takes precedence over symbol (per plan U1 Approach).
|
||||
if start_line is not None or end_line is not None:
|
||||
s = max(1, start_line or 1)
|
||||
e = min(total_lines, end_line or total_lines)
|
||||
sliced = "\n".join(lines[s - 1 : e])
|
||||
return {
|
||||
"content": sliced,
|
||||
"path": str(path),
|
||||
"start_line": s,
|
||||
"end_line": e,
|
||||
"total_lines": total_lines,
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
# Symbol slicing.
|
||||
if symbol:
|
||||
ext = path.suffix.lower()
|
||||
language = language_for_extension(ext)
|
||||
if not language:
|
||||
# Unsupported extension: return full file with note (per plan U1 Edge case).
|
||||
return {
|
||||
"content": content,
|
||||
"path": str(path),
|
||||
"start_line": 1,
|
||||
"end_line": total_lines,
|
||||
"total_lines": total_lines,
|
||||
"note": f"symbol extraction not supported for {ext or 'unknown extension'}",
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
spans, _lang = extract_symbols_from_file(path)
|
||||
# Re-extract using the content we already read so we don't read the file twice.
|
||||
if not spans:
|
||||
# Try extraction from in-memory content (path-based extraction may
|
||||
# have failed silently on OSError; we already read it successfully).
|
||||
from agentkit.tools.symbol_extractor import get_extractor
|
||||
|
||||
extractor = get_extractor(language)
|
||||
if extractor is not None:
|
||||
spans = extractor.extract_symbols(content, language)
|
||||
|
||||
match = _find_symbol(spans, symbol)
|
||||
if match is None:
|
||||
available = sorted({s.name for s in spans})
|
||||
return {
|
||||
"content": "",
|
||||
"path": str(path),
|
||||
"symbol": symbol,
|
||||
"available_symbols": available,
|
||||
"is_error": False,
|
||||
"note": (
|
||||
f"Symbol {symbol!r} not found in {path.name}. "
|
||||
f"Available: {', '.join(available) if available else '(none)'}"
|
||||
),
|
||||
}
|
||||
|
||||
s = match.start_line
|
||||
e = min(match.end_line, total_lines)
|
||||
sliced = "\n".join(lines[s - 1 : e])
|
||||
return {
|
||||
"content": sliced,
|
||||
"path": str(path),
|
||||
"symbol": symbol,
|
||||
"symbol_kind": match.kind,
|
||||
"start_line": s,
|
||||
"end_line": e,
|
||||
"total_lines": total_lines,
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
# Default: full file (characterization baseline — matches _FakeTool shape).
|
||||
return {
|
||||
"content": content,
|
||||
"path": str(path),
|
||||
"start_line": 1,
|
||||
"end_line": total_lines,
|
||||
"total_lines": total_lines,
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _error(
|
||||
message: str, *, path: str | None = None, detail: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {
|
||||
"content": "",
|
||||
"is_error": True,
|
||||
"error": message,
|
||||
}
|
||||
if path is not None:
|
||||
result["path"] = path
|
||||
if detail is not None:
|
||||
result["detail"] = detail
|
||||
return result
|
||||
|
||||
|
||||
def _find_symbol(spans: list[SymbolSpan], name: str) -> SymbolSpan | None:
|
||||
"""Find the first symbol matching `name`. Case-sensitive.
|
||||
|
||||
ponytail: linear scan is fine for typical file symbol counts (<100). The
|
||||
extractor already returns symbols sorted by start_line; first match wins
|
||||
for ambiguous overloads (e.g., Python classes with same name in different
|
||||
modules — not relevant within one file).
|
||||
"""
|
||||
for span in spans:
|
||||
if span.name == name:
|
||||
return span
|
||||
return None
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
"""Symbol extraction — locate code symbols (functions/classes/structs) by name.
|
||||
|
||||
KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex
|
||||
for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The
|
||||
`SymbolExtractor` protocol is the upgrade seam — a future TreeSitterSymbolExtractor
|
||||
can replace RegexSymbolExtractor behind the same interface.
|
||||
|
||||
ponytail: regex extractor covers ~80% case (top-level function/class/struct
|
||||
declarations). Ceiling: misses nested signatures inside JSX/TSX generics,
|
||||
multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SymbolSpan:
|
||||
"""A located symbol — name, kind, and 1-based inclusive line range."""
|
||||
|
||||
name: str
|
||||
kind: str # "function" | "class" | "method" | "struct" | "impl"
|
||||
start_line: int # 1-based, inclusive
|
||||
end_line: int # 1-based, inclusive
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SymbolExtractor(Protocol):
|
||||
"""Protocol for symbol extractors — runtime_checkable for isinstance/issubclass."""
|
||||
|
||||
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
|
||||
"""Return all symbols found in `content`.
|
||||
|
||||
`language` is the file extension without leading dot (e.g. "py", "ts").
|
||||
Implementations must never raise on extraction failure — return [] on
|
||||
parse errors and let the caller decide the fallback (full-file read).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Python — stdlib ast
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AstSymbolExtractor:
|
||||
"""Python symbol extractor using the stdlib `ast` module.
|
||||
|
||||
Captures top-level FunctionDef/AsyncFunctionDef/ClassDef and methods/nested
|
||||
functions inside classes. The end_line is the last line of the node's
|
||||
source segment (decorator-inclusive).
|
||||
"""
|
||||
|
||||
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
|
||||
if language != "py":
|
||||
return []
|
||||
try:
|
||||
tree = ast.parse(content)
|
||||
except SyntaxError as e:
|
||||
logger.debug("ast.parse failed: %s", e)
|
||||
return []
|
||||
|
||||
lines = content.splitlines()
|
||||
spans: list[SymbolSpan] = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
kind = "method" if _is_method(node) else "function"
|
||||
spans.append(_span_from_node(node, kind, lines))
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
spans.append(_span_from_node(node, "class", lines))
|
||||
return spans
|
||||
|
||||
|
||||
def _is_method(node: ast.AST) -> bool:
|
||||
"""A FunctionDef is a method if its parent is a ClassDef.
|
||||
|
||||
`ast.walk` doesn't expose parentage, so we approximate by checking the
|
||||
node's col_offset == 4 (indented inside a class body). ponytail: this
|
||||
misses methods in deeply nested classes — ceiling noted; upgrade path =
|
||||
ast.NodeVisitor with parent tracking.
|
||||
"""
|
||||
return getattr(node, "col_offset", 0) > 0
|
||||
|
||||
|
||||
def _span_from_node(
|
||||
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
|
||||
kind: str,
|
||||
lines: list[str],
|
||||
) -> SymbolSpan:
|
||||
# ast line numbers are 1-based; start at decorator if present (lineno points
|
||||
# to the def/class keyword, decorators are above). Use node.lineno for start
|
||||
# so the returned range matches what the user sees at the def keyword.
|
||||
start = node.lineno
|
||||
# node.end_lineno is the last line of the node body (None on old Pythons).
|
||||
end = node.end_lineno or start
|
||||
# Clamp to actual file length (defensive — ast should not exceed, but
|
||||
# malformed files with no trailing newline can confuse end_lineno).
|
||||
if end > len(lines):
|
||||
end = len(lines)
|
||||
return SymbolSpan(name=node.name, kind=kind, start_line=start, end_line=end)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regex extractor — TS/JS/Go/Rust/Java
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Each pattern matches a declaration and captures the symbol name in group 1.
|
||||
# Patterns use re.MULTILINE so ^ matches line starts.
|
||||
_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = {
|
||||
"ts": [
|
||||
(
|
||||
"function",
|
||||
re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE),
|
||||
),
|
||||
("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)),
|
||||
(
|
||||
"function",
|
||||
re.compile(
|
||||
r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
|
||||
re.MULTILINE,
|
||||
),
|
||||
),
|
||||
],
|
||||
"js": [
|
||||
("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)),
|
||||
("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)),
|
||||
(
|
||||
"function",
|
||||
re.compile(
|
||||
r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE
|
||||
),
|
||||
),
|
||||
],
|
||||
"go": [
|
||||
("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)),
|
||||
("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)),
|
||||
],
|
||||
"rs": [
|
||||
("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)),
|
||||
("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)),
|
||||
("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)),
|
||||
],
|
||||
"java": [
|
||||
(
|
||||
"function",
|
||||
re.compile(
|
||||
r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{",
|
||||
re.MULTILINE,
|
||||
),
|
||||
),
|
||||
(
|
||||
"class",
|
||||
re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE),
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class RegexSymbolExtractor:
|
||||
"""Language-aware regex symbol extractor for TS/JS/Go/Rust/Java.
|
||||
|
||||
Returns SymbolSpans whose end_line is approximated by the next blank line
|
||||
or next-symbol start (whichever comes first). ponytail: this is an
|
||||
approximation — true block-end requires language-aware brace matching.
|
||||
Ceiling: deeply nested blocks may over-extend the range. Upgrade path =
|
||||
tree-sitter.
|
||||
"""
|
||||
|
||||
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
|
||||
patterns = _REGEX_PATTERNS.get(language)
|
||||
if not patterns:
|
||||
return []
|
||||
|
||||
lines = content.splitlines()
|
||||
# Collect (line_no, name, kind) tuples first, then compute end_line
|
||||
# as the line before the next symbol starts (or EOF).
|
||||
raw_hits: list[tuple[int, str, str]] = []
|
||||
for kind, pattern in patterns:
|
||||
for m in pattern.finditer(content):
|
||||
# Convert match offset to 1-based line number.
|
||||
line_no = content[: m.start()].count("\n") + 1
|
||||
raw_hits.append((line_no, m.group(1), kind))
|
||||
|
||||
if not raw_hits:
|
||||
return []
|
||||
|
||||
# Deduplicate: same (line_no, name) may appear for overlapping patterns.
|
||||
seen: set[tuple[int, str]] = set()
|
||||
unique: list[tuple[int, str, str]] = []
|
||||
for line_no, name, kind in raw_hits:
|
||||
key = (line_no, name)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
unique.append((line_no, name, kind))
|
||||
|
||||
unique.sort(key=lambda x: x[0])
|
||||
|
||||
spans: list[SymbolSpan] = []
|
||||
for i, (start_line, name, kind) in enumerate(unique):
|
||||
if i + 1 < len(unique):
|
||||
# End at line before next symbol starts, capped at file length.
|
||||
end_line = unique[i + 1][0] - 1
|
||||
else:
|
||||
end_line = len(lines)
|
||||
if end_line < start_line:
|
||||
end_line = start_line
|
||||
spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line))
|
||||
return spans
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch by file extension
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXTENSION_LANGUAGE: dict[str, str] = {
|
||||
".py": "py",
|
||||
".ts": "ts",
|
||||
".tsx": "ts",
|
||||
".js": "js",
|
||||
".jsx": "js",
|
||||
".mjs": "js",
|
||||
".cjs": "js",
|
||||
".go": "go",
|
||||
".rs": "rs",
|
||||
".java": "java",
|
||||
}
|
||||
|
||||
_DEFAULT_EXTRACTOR = AstSymbolExtractor()
|
||||
_REGEX_EXTRACTOR = RegexSymbolExtractor()
|
||||
|
||||
|
||||
def language_for_extension(ext: str) -> str:
|
||||
"""Return the language key for a file extension (with or without leading dot).
|
||||
|
||||
Returns "" for unsupported extensions.
|
||||
"""
|
||||
if not ext.startswith("."):
|
||||
ext = "." + ext
|
||||
return _EXTENSION_LANGUAGE.get(ext.lower(), "")
|
||||
|
||||
|
||||
def get_extractor(language: str) -> SymbolExtractor | None:
|
||||
"""Return the appropriate extractor for `language`, or None if unsupported."""
|
||||
if language == "py":
|
||||
return _DEFAULT_EXTRACTOR
|
||||
if language in _REGEX_PATTERNS:
|
||||
return _REGEX_EXTRACTOR
|
||||
return None
|
||||
|
||||
|
||||
def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]:
|
||||
"""Read a file and return (symbols, language).
|
||||
|
||||
Returns ([], "") if the extension is unsupported or the file cannot be read.
|
||||
Never raises — callers use this for fallback routing.
|
||||
"""
|
||||
ext = path.suffix.lower()
|
||||
language = language_for_extension(ext)
|
||||
if not language:
|
||||
return [], ""
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8", errors="replace")
|
||||
except OSError as e:
|
||||
logger.debug("read failed for %s: %s", path, e)
|
||||
return [], language
|
||||
extractor = get_extractor(language)
|
||||
if extractor is None:
|
||||
return [], language
|
||||
return extractor.extract_symbols(content, language), language
|
||||
|
|
@ -0,0 +1,531 @@
|
|||
"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4).
|
||||
|
||||
Per plan U4 Execution note: characterization-first — verify that existing
|
||||
REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning
|
||||
(no regression). Then add PLAN_EXEC wiring tests.
|
||||
|
||||
KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||||
from agentkit.core.phase import PhaseState
|
||||
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_chat():
|
||||
"""Create a FastAPI app with Chat routes and mocked dependencies."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from agentkit.server.routes.chat import router
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
from agentkit.session.manager import SessionManager
|
||||
from agentkit.session.store import InMemorySessionStore
|
||||
|
||||
app.state.session_manager = SessionManager(store=InMemorySessionStore())
|
||||
app.state.llm_gateway = MagicMock()
|
||||
app.state.agent_pool = MagicMock()
|
||||
app.state.server_config = MagicMock()
|
||||
app.state.server_config.api_key = None
|
||||
app.state.server_config.plan_exec = {}
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app_with_chat):
|
||||
return TestClient(app_with_chat)
|
||||
|
||||
|
||||
def _make_routing(
|
||||
execution_mode: ExecutionMode = ExecutionMode.REACT,
|
||||
tools: list | None = None,
|
||||
) -> SkillRoutingResult:
|
||||
"""Build a minimal SkillRoutingResult for testing."""
|
||||
return SkillRoutingResult(
|
||||
execution_mode=execution_mode,
|
||||
tools=tools or [],
|
||||
clean_content="test message",
|
||||
model="default",
|
||||
agent_name="test-agent",
|
||||
system_prompt=None,
|
||||
skill_name=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_websocket_mock(app) -> MagicMock:
|
||||
"""Build a mock WebSocket with app.state and async send_json."""
|
||||
ws = MagicMock()
|
||||
ws.app = app
|
||||
ws.send_json = AsyncMock()
|
||||
return ws
|
||||
|
||||
|
||||
def _make_agent_mock() -> MagicMock:
|
||||
"""Build a mock Agent with _tool_registry and _react_engine."""
|
||||
agent = MagicMock()
|
||||
agent.name = "test-agent"
|
||||
agent._tool_registry = MagicMock()
|
||||
agent._tool_registry.list_tools.return_value = []
|
||||
agent._system_prompt = None
|
||||
# _react_engine is None to force the code path that creates a new engine
|
||||
agent._react_engine = None
|
||||
agent.get_model.return_value = "default"
|
||||
return agent
|
||||
|
||||
|
||||
def _make_session_manager_mock() -> MagicMock:
|
||||
"""Build a mock SessionManager with async methods."""
|
||||
sm = MagicMock()
|
||||
# get_session returns a mock session with agent_name="test-agent"
|
||||
session = MagicMock()
|
||||
session.agent_name = "test-agent"
|
||||
session.status = "active"
|
||||
sm.get_session = AsyncMock(return_value=session)
|
||||
sm.get_chat_messages = AsyncMock(return_value=[])
|
||||
sm.append_message = AsyncMock()
|
||||
return sm
|
||||
|
||||
|
||||
def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None:
|
||||
"""Wire up app.state so _handle_chat_message finds the right routing."""
|
||||
app.state.agent_pool.get_agent.return_value = agent
|
||||
app.state.request_preprocessor = MagicMock()
|
||||
app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# REST — PLAN_EXEC raises 501 (KTD4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRestPlanExec501:
|
||||
def test_rest_plan_exec_returns_501(self, client):
|
||||
"""REST send_message with execution_mode=plan_exec → 501."""
|
||||
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
||||
session_id = create_resp.json()["session_id"]
|
||||
|
||||
msg_resp = client.post(
|
||||
f"/api/v1/chat/sessions/{session_id}/messages",
|
||||
json={"content": "Hello", "execution_mode": "plan_exec"},
|
||||
)
|
||||
assert msg_resp.status_code == 501
|
||||
assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"]
|
||||
|
||||
def test_rest_react_mode_still_works(self, client):
|
||||
"""REST send_message without execution_mode doesn't 501."""
|
||||
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
||||
session_id = create_resp.json()["session_id"]
|
||||
|
||||
# No execution_mode field → should NOT trigger 501.
|
||||
msg_resp = client.post(
|
||||
f"/api/v1/chat/sessions/{session_id}/messages",
|
||||
json={"content": "Hello"},
|
||||
)
|
||||
assert msg_resp.status_code != 501
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Characterization — REWOO still falls back to REACT (no regression)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rewoo_still_falls_back_to_react_without_phase_policy(app_with_chat):
|
||||
"""Characterization: REWOO via WebSocket → no phase_policy (falls back to REACT)."""
|
||||
from agentkit.server.routes import chat as chat_module
|
||||
|
||||
agent = _make_agent_mock()
|
||||
routing = _make_routing(execution_mode=ExecutionMode.REWOO)
|
||||
_setup_routing(app_with_chat, routing, agent)
|
||||
|
||||
sm = _make_session_manager_mock()
|
||||
ws = _make_websocket_mock(app_with_chat)
|
||||
|
||||
captured_engine_kwargs: dict = {}
|
||||
|
||||
class _StubEngine:
|
||||
def __init__(self, **kwargs):
|
||||
captured_engine_kwargs.update(kwargs)
|
||||
self._phase_policy = kwargs.get("phase_policy")
|
||||
self._current_phase = None
|
||||
|
||||
@property
|
||||
def current_phase(self):
|
||||
return self._current_phase
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
async def execute_stream(self, **kwargs):
|
||||
return
|
||||
yield # async generator marker
|
||||
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||
|
||||
await chat_module._handle_chat_message(
|
||||
websocket=ws,
|
||||
session_id="test-session",
|
||||
content="test",
|
||||
sm=sm,
|
||||
cancellation_token=MagicMock(),
|
||||
pending_replies={},
|
||||
pending_confirmations=None,
|
||||
)
|
||||
|
||||
# REWOO should NOT build a phase_policy
|
||||
assert captured_engine_kwargs.get("phase_policy") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy path — PLAN_EXEC builds phase policy + registers AdvancePhaseTool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_exec_builds_phase_policy_and_registers_advance_phase_tool(
|
||||
app_with_chat,
|
||||
):
|
||||
"""PLAN_EXEC via WebSocket → engine with phase_policy, AdvancePhaseTool registered."""
|
||||
from agentkit.server.routes import chat as chat_module
|
||||
|
||||
agent = _make_agent_mock()
|
||||
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||
_setup_routing(app_with_chat, routing, agent)
|
||||
|
||||
sm = _make_session_manager_mock()
|
||||
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "test"}])
|
||||
ws = _make_websocket_mock(app_with_chat)
|
||||
|
||||
captured_engine: list = []
|
||||
captured_tools: list = []
|
||||
|
||||
class _StubEngine:
|
||||
def __init__(self, **kwargs):
|
||||
self._phase_policy = kwargs.get("phase_policy")
|
||||
self._current_phase = (
|
||||
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
|
||||
)
|
||||
|
||||
@property
|
||||
def current_phase(self):
|
||||
return self._current_phase
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
async def execute_stream(self, **kwargs):
|
||||
captured_tools.extend(kwargs.get("tools", []))
|
||||
captured_engine.append(self)
|
||||
return
|
||||
yield
|
||||
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||
|
||||
await chat_module._handle_chat_message(
|
||||
websocket=ws,
|
||||
session_id="test-session",
|
||||
content="test",
|
||||
sm=sm,
|
||||
cancellation_token=MagicMock(),
|
||||
pending_replies={},
|
||||
pending_confirmations=None,
|
||||
)
|
||||
|
||||
assert len(captured_engine) == 1
|
||||
engine = captured_engine[0]
|
||||
assert engine._phase_policy is not None
|
||||
assert engine._current_phase == PhaseState.PLANNING
|
||||
# AdvancePhaseTool was registered in the tools list
|
||||
assert any(isinstance(t, AdvancePhaseTool) for t in captured_tools)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_exec_empty_config_uses_default_policy(app_with_chat):
|
||||
"""Edge: plan_exec config absent (empty dict) → default_policy() used."""
|
||||
from agentkit.server.routes import chat as chat_module
|
||||
|
||||
app_with_chat.state.server_config.plan_exec = {}
|
||||
|
||||
agent = _make_agent_mock()
|
||||
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||
_setup_routing(app_with_chat, routing, agent)
|
||||
|
||||
sm = _make_session_manager_mock()
|
||||
ws = _make_websocket_mock(app_with_chat)
|
||||
|
||||
captured_policy: list = []
|
||||
|
||||
class _StubEngine:
|
||||
def __init__(self, **kwargs):
|
||||
captured_policy.append(kwargs.get("phase_policy"))
|
||||
self._phase_policy = kwargs.get("phase_policy")
|
||||
self._current_phase = (
|
||||
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
|
||||
)
|
||||
|
||||
@property
|
||||
def current_phase(self):
|
||||
return self._current_phase
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
async def execute_stream(self, **kwargs):
|
||||
return
|
||||
yield
|
||||
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||
|
||||
await chat_module._handle_chat_message(
|
||||
websocket=ws,
|
||||
session_id="test-session",
|
||||
content="test",
|
||||
sm=sm,
|
||||
cancellation_token=MagicMock(),
|
||||
pending_replies={},
|
||||
pending_confirmations=None,
|
||||
)
|
||||
|
||||
assert len(captured_policy) == 1
|
||||
assert captured_policy[0] is not None
|
||||
# Default policy: PLANNING allows search but not write_file
|
||||
assert "search" in captured_policy[0].whitelist[PhaseState.PLANNING]
|
||||
assert "write_file" not in captured_policy[0].whitelist[PhaseState.PLANNING]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_exec_disabled_falls_back_to_react(app_with_chat):
|
||||
"""Edge: plan_exec.enabled=False → falls back to REACT (no phase_policy)."""
|
||||
from agentkit.server.routes import chat as chat_module
|
||||
|
||||
app_with_chat.state.server_config.plan_exec = {"enabled": False}
|
||||
|
||||
agent = _make_agent_mock()
|
||||
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||
_setup_routing(app_with_chat, routing, agent)
|
||||
|
||||
sm = _make_session_manager_mock()
|
||||
ws = _make_websocket_mock(app_with_chat)
|
||||
|
||||
captured_engine_kwargs: dict = {}
|
||||
|
||||
class _StubEngine:
|
||||
def __init__(self, **kwargs):
|
||||
captured_engine_kwargs.update(kwargs)
|
||||
self._phase_policy = kwargs.get("phase_policy")
|
||||
self._current_phase = None
|
||||
|
||||
@property
|
||||
def current_phase(self):
|
||||
return self._current_phase
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
async def execute_stream(self, **kwargs):
|
||||
return
|
||||
yield
|
||||
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||
|
||||
await chat_module._handle_chat_message(
|
||||
websocket=ws,
|
||||
session_id="test-session",
|
||||
content="test",
|
||||
sm=sm,
|
||||
cancellation_token=MagicMock(),
|
||||
pending_replies={},
|
||||
pending_confirmations=None,
|
||||
)
|
||||
|
||||
# enabled=False → no phase_policy (falls back to REACT)
|
||||
assert captured_engine_kwargs.get("phase_policy") is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_exec_bad_config_sends_error_and_returns(app_with_chat):
|
||||
"""Error: phase policy construction fails → error event sent, returns early."""
|
||||
from agentkit.server.routes import chat as chat_module
|
||||
|
||||
app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"}
|
||||
|
||||
agent = _make_agent_mock()
|
||||
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||
_setup_routing(app_with_chat, routing, agent)
|
||||
|
||||
sm = _make_session_manager_mock()
|
||||
ws = _make_websocket_mock(app_with_chat)
|
||||
|
||||
engine_constructor_called = []
|
||||
|
||||
class _StubEngine:
|
||||
def __init__(self, **kwargs):
|
||||
engine_constructor_called.append(kwargs)
|
||||
|
||||
async def execute_stream(self, **kwargs):
|
||||
yield
|
||||
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||
|
||||
await chat_module._handle_chat_message(
|
||||
websocket=ws,
|
||||
session_id="test-session",
|
||||
content="test",
|
||||
sm=sm,
|
||||
cancellation_token=MagicMock(),
|
||||
pending_replies={},
|
||||
pending_confirmations=None,
|
||||
)
|
||||
|
||||
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||
error_messages = [m for m in sent_messages if m.get("type") == "error"]
|
||||
assert len(error_messages) == 1
|
||||
assert "phase policy error" in error_messages[0]["data"]["message"]
|
||||
# Engine constructor was NOT called (returned early)
|
||||
assert len(engine_constructor_called) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# phase_changed event emission
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_phase_changed_event_emitted_on_transition(app_with_chat):
|
||||
"""phase_changed event sent when current_phase changes during execute_stream."""
|
||||
from agentkit.server.routes import chat as chat_module
|
||||
|
||||
app_with_chat.state.server_config.plan_exec = {}
|
||||
|
||||
agent = _make_agent_mock()
|
||||
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||
_setup_routing(app_with_chat, routing, agent)
|
||||
|
||||
sm = _make_session_manager_mock()
|
||||
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}])
|
||||
ws = _make_websocket_mock(app_with_chat)
|
||||
|
||||
class _StubEngine:
|
||||
def __init__(self, **kwargs):
|
||||
self._phase_policy = kwargs.get("phase_policy")
|
||||
self._current_phase = PhaseState.PLANNING
|
||||
|
||||
@property
|
||||
def current_phase(self):
|
||||
return self._current_phase
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
async def execute_stream(self, **kwargs):
|
||||
from agentkit.core.react import ReActEvent
|
||||
|
||||
yield ReActEvent(
|
||||
event_type="tool_call",
|
||||
step=1,
|
||||
data={"tool": "search", "output": "ok"},
|
||||
)
|
||||
# Simulate phase transition (as if AdvancePhaseTool was called)
|
||||
self._current_phase = PhaseState.BUILDING
|
||||
yield ReActEvent(
|
||||
event_type="final_answer",
|
||||
step=2,
|
||||
data={"output": "done"},
|
||||
)
|
||||
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||
|
||||
await chat_module._handle_chat_message(
|
||||
websocket=ws,
|
||||
session_id="test-session",
|
||||
content="go",
|
||||
sm=sm,
|
||||
cancellation_token=MagicMock(),
|
||||
pending_replies={},
|
||||
pending_confirmations=None,
|
||||
)
|
||||
|
||||
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
|
||||
assert len(phase_events) == 1
|
||||
assert phase_events[0]["data"]["phase"] == "building"
|
||||
assert phase_events[0]["data"]["previous"] == "planning"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat):
|
||||
"""Characterization: REACT mode → no phase_changed events."""
|
||||
from agentkit.server.routes import chat as chat_module
|
||||
|
||||
agent = _make_agent_mock()
|
||||
routing = _make_routing(execution_mode=ExecutionMode.REACT)
|
||||
_setup_routing(app_with_chat, routing, agent)
|
||||
|
||||
sm = _make_session_manager_mock()
|
||||
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}])
|
||||
ws = _make_websocket_mock(app_with_chat)
|
||||
|
||||
class _StubEngine:
|
||||
def __init__(self, **kwargs):
|
||||
self._phase_policy = None
|
||||
self._current_phase = None
|
||||
|
||||
@property
|
||||
def current_phase(self):
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
async def execute_stream(self, **kwargs):
|
||||
from agentkit.core.react import ReActEvent
|
||||
|
||||
yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"})
|
||||
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||
|
||||
await chat_module._handle_chat_message(
|
||||
websocket=ws,
|
||||
session_id="test-session",
|
||||
content="hi",
|
||||
sm=sm,
|
||||
cancellation_token=MagicMock(),
|
||||
pending_replies={},
|
||||
pending_confirmations=None,
|
||||
)
|
||||
|
||||
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
|
||||
assert len(phase_events) == 0
|
||||
|
|
@ -0,0 +1,400 @@
|
|||
"""G4/U1 — Auxiliary LLM routing in ContextCompressor.
|
||||
|
||||
Verifies:
|
||||
- auxiliary_model routes _summarize through the cheaper model first
|
||||
- empty content (Finding 4 anti-pattern) triggers fallback to main model
|
||||
- auxiliary exception triggers fallback to main model
|
||||
- both auxiliary and main failing falls through to _simple_summary
|
||||
- auxiliary_model=None preserves existing single-model behavior (characterization)
|
||||
- config wiring (LLMConfig.from_dict, ServerConfig._build_llm_config)
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.core.compressor import ContextCompressor
|
||||
from agentkit.llm.config import LLMConfig
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────
|
||||
|
||||
|
||||
def make_gateway_with_response(content: str, model: str = "test") -> MagicMock:
|
||||
"""Mock LLMGateway returning a fixed response."""
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(
|
||||
return_value=LLMResponse(
|
||||
content=content,
|
||||
model=model,
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
|
||||
)
|
||||
)
|
||||
return gateway
|
||||
|
||||
|
||||
def make_gateway_side_effect(responses_by_model: dict[str, LLMResponse | Exception]) -> MagicMock:
|
||||
"""Mock LLMGateway returning different responses (or raising) keyed by model name.
|
||||
|
||||
Each call to gateway.chat(model=X) pops the next response for X from a queue,
|
||||
so repeated calls to the same model can return different values.
|
||||
"""
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
queues = {m: list(rs) for m, rs in responses_by_model.items()}
|
||||
|
||||
async def chat_side_effect(*, messages, model, **kwargs):
|
||||
queue = queues.get(model)
|
||||
if queue is None:
|
||||
raise ValueError(f"unexpected model={model}")
|
||||
if not queue:
|
||||
raise ValueError(f"queue for model={model} exhausted")
|
||||
item = queue.pop(0)
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
return item
|
||||
|
||||
gateway.chat = AsyncMock(side_effect=chat_side_effect)
|
||||
return gateway
|
||||
|
||||
|
||||
def make_long_messages(count: int = 4, content_length: int = 2000) -> list[dict]:
|
||||
"""Generate long messages that exceed token budget (triggers compression)."""
|
||||
messages = [{"role": "system", "content": "You are a helpful assistant."}]
|
||||
for i in range(count):
|
||||
messages.append({"role": "user", "content": "x" * content_length + f" m{i}"})
|
||||
messages.append({"role": "assistant", "content": "y" * content_length + f" r{i}"})
|
||||
messages.append({"role": "user", "content": "recent question"})
|
||||
messages.append({"role": "assistant", "content": "recent answer"})
|
||||
return messages
|
||||
|
||||
|
||||
# ── Characterization: auxiliary_model=None preserves existing behavior ──
|
||||
|
||||
|
||||
class TestAuxiliaryNoneCharacterization:
|
||||
"""auxiliary_model=None (default) — single model call, existing behavior."""
|
||||
|
||||
async def test_no_auxiliary_calls_main_once(self):
|
||||
gateway = make_gateway_with_response("main summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
# auxiliary_model omitted → None
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
gateway.chat.assert_awaited_once()
|
||||
# The call used the main model
|
||||
assert gateway.chat.await_args.kwargs.get("model") == "main"
|
||||
# Summary surfaced in result
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert any("main summary" in m["content"] for m in summary_msgs)
|
||||
|
||||
async def test_main_failure_falls_to_simple_summary(self):
|
||||
gateway = MagicMock()
|
||||
gateway.chat = AsyncMock(side_effect=Exception("main LLM error"))
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# _simple_summary produces truncated messages with "..."
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert len(summary_msgs) == 1
|
||||
assert "..." in summary_msgs[0]["content"]
|
||||
|
||||
|
||||
# ── New behavior: auxiliary routing ──────────────────
|
||||
|
||||
|
||||
class TestAuxiliaryRouting:
|
||||
"""auxiliary_model set and differs from main → auxiliary tried first."""
|
||||
|
||||
async def test_auxiliary_success_returns_auxiliary_content(self):
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [
|
||||
LLMResponse(
|
||||
content="aux summary",
|
||||
model="fast",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
"main": [
|
||||
LLMResponse(
|
||||
content="MAIN SHOULD NOT BE USED",
|
||||
model="main",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# Auxiliary called; main NOT called
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) == 1
|
||||
assert len(main_calls) == 0
|
||||
# Result contains auxiliary summary
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert any("aux summary" in m["content"] for m in summary_msgs)
|
||||
|
||||
async def test_empty_content_triggers_main_fallback(self):
|
||||
"""Finding 4 anti-pattern: empty content is a failure, not a success."""
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [
|
||||
LLMResponse(
|
||||
content="",
|
||||
model="fast",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=0),
|
||||
)
|
||||
],
|
||||
"main": [
|
||||
LLMResponse(
|
||||
content="main summary",
|
||||
model="main",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# Auxiliary called once (returned empty)
|
||||
# Main called once (fallback)
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) == 1
|
||||
assert len(main_calls) == 1
|
||||
# Result contains main summary (not the empty auxiliary)
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert any("main summary" in m["content"] for m in summary_msgs)
|
||||
|
||||
async def test_whitespace_content_triggers_main_fallback(self):
|
||||
"""Whitespace-only content also counts as empty (Finding 4)."""
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [
|
||||
LLMResponse(
|
||||
content=" \n ",
|
||||
model="fast",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=0),
|
||||
)
|
||||
],
|
||||
"main": [
|
||||
LLMResponse(
|
||||
content="main summary",
|
||||
model="main",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
await compressor.compress(make_long_messages())
|
||||
|
||||
# Both auxiliary and main called
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) == 1
|
||||
assert len(main_calls) == 1
|
||||
|
||||
async def test_auxiliary_exception_triggers_main_fallback(self):
|
||||
from agentkit.core.exceptions import LLMProviderError
|
||||
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [LLMProviderError("aux", "provider down")],
|
||||
"main": [
|
||||
LLMResponse(
|
||||
content="main summary",
|
||||
model="main",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# Both called; main succeeded
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) == 1
|
||||
assert len(main_calls) == 1
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert any("main summary" in m["content"] for m in summary_msgs)
|
||||
|
||||
async def test_both_fail_falls_to_simple_summary(self):
|
||||
"""Auxiliary raises, main raises → existing _simple_summary degradation."""
|
||||
# Note: aggressive compression path may invoke _summarize multiple times.
|
||||
# Queue provides enough responses to handle that without raising queue-exhausted.
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [Exception("aux boom")] * 5,
|
||||
"main": [Exception("main boom")] * 5,
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# Both called at least once
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) >= 1
|
||||
assert len(main_calls) >= 1
|
||||
# _simple_summary output has "..." truncation markers
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert len(summary_msgs) == 1
|
||||
assert "..." in summary_msgs[0]["content"]
|
||||
|
||||
async def test_auxiliary_equal_to_main_skipped(self):
|
||||
"""auxiliary_model == model → no auxiliary routing (single call to main)."""
|
||||
gateway = make_gateway_with_response("main summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="main", # same as main
|
||||
)
|
||||
await compressor.compress(make_long_messages())
|
||||
|
||||
# Only one call (to main); auxiliary block skipped
|
||||
assert gateway.chat.await_count == 1
|
||||
assert gateway.chat.await_args.kwargs.get("model") == "main"
|
||||
|
||||
async def test_audit_fields_preserved(self):
|
||||
"""Auxiliary call uses agent_name='compressor', task_type='summarization'."""
|
||||
gateway = make_gateway_with_response("aux summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
# Override the mock to use a single-response gateway where auxiliary succeeds
|
||||
# (the make_gateway_with_response mock returns same response regardless of model)
|
||||
await compressor.compress(make_long_messages())
|
||||
|
||||
# Single call (auxiliary succeeded) — verify audit fields
|
||||
call_kwargs = gateway.chat.await_args.kwargs
|
||||
assert call_kwargs.get("agent_name") == "compressor"
|
||||
assert call_kwargs.get("task_type") == "summarization"
|
||||
|
||||
|
||||
# ── Config wiring ────────────────────────────────────
|
||||
|
||||
|
||||
class TestConfigWiring:
|
||||
"""LLMConfig + ServerConfig read auxiliary_model from dict."""
|
||||
|
||||
def test_llm_config_from_dict_reads_auxiliary_model(self):
|
||||
cfg = LLMConfig.from_dict(
|
||||
{
|
||||
"providers": {},
|
||||
"model_aliases": {"fast": "p/m"},
|
||||
"auxiliary_model": "fast",
|
||||
}
|
||||
)
|
||||
assert cfg.auxiliary_model == "fast"
|
||||
|
||||
def test_llm_config_from_dict_auxiliary_none_when_absent(self):
|
||||
cfg = LLMConfig.from_dict({"providers": {}})
|
||||
assert cfg.auxiliary_model is None
|
||||
|
||||
def test_llm_config_default_auxiliary_none(self):
|
||||
cfg = LLMConfig()
|
||||
assert cfg.auxiliary_model is None
|
||||
|
||||
def test_server_config_build_llm_config_reads_auxiliary_model(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
llm_data = {
|
||||
"providers": {
|
||||
"p": {
|
||||
"type": "openai",
|
||||
"api_key": "k",
|
||||
"base_url": "http://x",
|
||||
"models": {"m": {"alias": "fast"}},
|
||||
}
|
||||
},
|
||||
"auxiliary_model": "fast",
|
||||
}
|
||||
llm_config = ServerConfig._build_llm_config(llm_data)
|
||||
assert llm_config.auxiliary_model == "fast"
|
||||
# Also verify model_aliases still built correctly
|
||||
assert llm_config.model_aliases.get("fast") == "p/m"
|
||||
|
||||
def test_server_config_build_llm_config_auxiliary_none_when_absent(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
llm_config = ServerConfig._build_llm_config({"providers": {}})
|
||||
assert llm_config.auxiliary_model is None
|
||||
|
|
@ -0,0 +1,318 @@
|
|||
"""G7/U2 — Emergency layer rule template + TaskResult extension.
|
||||
|
||||
Verifies:
|
||||
- EmergencyRules.classify maps each exception type to correct error_code
|
||||
- TaskCancelledError raises ValueError (caller must propagate as-is)
|
||||
- EmergencyError.to_dict produces all 5 fields
|
||||
- EmergencyError.to_error_message formats suggestions as "建议:1) ... 2) ..."
|
||||
- Config overrides apply (suggestions, retryable, message)
|
||||
- TaskResult.error_struct field: default None preserves byte-for-byte
|
||||
to_dict() output (backward compat)
|
||||
- TaskResult round-trip serialization includes error_struct when set
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.exceptions import (
|
||||
LLMProviderError,
|
||||
LoopDetectedError,
|
||||
TaskCancelledError,
|
||||
TaskTimeoutError,
|
||||
)
|
||||
from agentkit.core.fallback import (
|
||||
EMPTY_LLM_RESPONSE,
|
||||
MAX_STEPS_REACHED,
|
||||
SHELL_NO_OUTPUT,
|
||||
EmergencyError,
|
||||
EmergencyRules,
|
||||
)
|
||||
from agentkit.core.protocol import TaskResult
|
||||
|
||||
|
||||
# ── Constants unchanged (contract preservation) ──────
|
||||
|
||||
|
||||
class TestExistingConstantsUnchanged:
|
||||
"""Existing 3 constants preserved byte-for-byte."""
|
||||
|
||||
def test_empty_llm_response_unchanged(self):
|
||||
assert "模型未返回有效内容" in EMPTY_LLM_RESPONSE
|
||||
assert "建议" in EMPTY_LLM_RESPONSE
|
||||
|
||||
def test_max_steps_reached_unchanged(self):
|
||||
assert "已达到最大推理步数" in MAX_STEPS_REACHED
|
||||
|
||||
def test_shell_no_output_unchanged(self):
|
||||
assert SHELL_NO_OUTPUT == "[命令执行成功,无输出内容]"
|
||||
|
||||
|
||||
# ── EmergencyRules.classify ──────────────────────────
|
||||
|
||||
|
||||
class TestEmergencyRulesClassify:
|
||||
"""classify() maps exception types to EmergencyError."""
|
||||
|
||||
def test_timeout(self):
|
||||
exc = TaskTimeoutError(task_id="t1", timeout_seconds=30)
|
||||
err = EmergencyRules.classify(exc)
|
||||
assert err.error_code == "timeout"
|
||||
assert err.retryable is True
|
||||
assert "稍后重试" in err.suggestions
|
||||
assert "简化任务范围" in err.suggestions
|
||||
assert err.original_error == str(exc)
|
||||
|
||||
def test_loop_detected(self):
|
||||
exc = LoopDetectedError(tool_name="shell", repetitions=3)
|
||||
err = EmergencyRules.classify(exc)
|
||||
assert err.error_code == "loop_detected"
|
||||
assert err.retryable is True
|
||||
assert "拆分任务" in err.suggestions
|
||||
assert "检查工具参数" in err.suggestions
|
||||
|
||||
def test_llm_provider_error(self):
|
||||
exc = LLMProviderError("openai", "rate limited")
|
||||
err = EmergencyRules.classify(exc)
|
||||
assert err.error_code == "llm_failure"
|
||||
assert err.retryable is True
|
||||
assert "稍后重试" in err.suggestions
|
||||
assert "切换模型" in err.suggestions
|
||||
|
||||
def test_llm_error_subclass_also_classified(self):
|
||||
"""LLMProviderError is a subclass of LLMError; ensure isinstance check works."""
|
||||
from agentkit.core.exceptions import LLMError
|
||||
|
||||
class CustomLLMError(LLMError):
|
||||
pass
|
||||
|
||||
err = EmergencyRules.classify(CustomLLMError("custom"))
|
||||
# CustomLLMError is NOT a LLMProviderError, falls through to generic
|
||||
assert err.error_code == "internal_error"
|
||||
|
||||
def test_generic_exception_internal_error(self):
|
||||
err = EmergencyRules.classify(Exception("unknown boom"))
|
||||
assert err.error_code == "internal_error"
|
||||
assert err.retryable is False
|
||||
assert "联系管理员" in err.suggestions
|
||||
assert err.original_error == "unknown boom"
|
||||
|
||||
def test_task_cancelled_raises(self):
|
||||
"""TaskCancelledError must propagate; classify() raises ValueError."""
|
||||
exc = TaskCancelledError(task_id="t1")
|
||||
with pytest.raises(ValueError, match="TaskCancelledError"):
|
||||
EmergencyRules.classify(exc)
|
||||
|
||||
def test_subclass_of_timeout_classified(self):
|
||||
"""Subclasses of TaskTimeoutError are classified as timeout."""
|
||||
|
||||
class CustomTimeout(TaskTimeoutError):
|
||||
def __init__(self):
|
||||
super().__init__(task_id="custom", timeout_seconds=10)
|
||||
|
||||
err = EmergencyRules.classify(CustomTimeout())
|
||||
assert err.error_code == "timeout"
|
||||
|
||||
|
||||
# ── EmergencyError serialization ─────────────────────
|
||||
|
||||
|
||||
class TestEmergencyErrorSerialization:
|
||||
"""to_dict / to_error_message on EmergencyError."""
|
||||
|
||||
def test_to_dict_produces_all_five_fields(self):
|
||||
err = EmergencyError(
|
||||
error_code="timeout",
|
||||
message="任务执行超时。",
|
||||
suggestions=["稍后重试", "简化任务范围"],
|
||||
retryable=True,
|
||||
original_error="Task t1 timed out after 30s",
|
||||
)
|
||||
d = err.to_dict()
|
||||
assert set(d.keys()) == {
|
||||
"error_code",
|
||||
"message",
|
||||
"suggestions",
|
||||
"retryable",
|
||||
"original_error",
|
||||
}
|
||||
assert d["error_code"] == "timeout"
|
||||
assert d["message"] == "任务执行超时。"
|
||||
assert d["suggestions"] == ["稍后重试", "简化任务范围"]
|
||||
assert d["retryable"] is True
|
||||
assert d["original_error"] == "Task t1 timed out after 30s"
|
||||
|
||||
def test_to_dict_suggestions_list_is_copy(self):
|
||||
"""to_dict returns a fresh list, not the internal reference."""
|
||||
suggestions = ["a", "b"]
|
||||
err = EmergencyError(
|
||||
error_code="x",
|
||||
message="m",
|
||||
suggestions=suggestions,
|
||||
retryable=False,
|
||||
original_error="e",
|
||||
)
|
||||
d = err.to_dict()
|
||||
assert d["suggestions"] is not suggestions
|
||||
d["suggestions"].append("c")
|
||||
assert err.suggestions == ["a", "b"]
|
||||
|
||||
def test_to_error_message_with_suggestions(self):
|
||||
err = EmergencyError(
|
||||
error_code="timeout",
|
||||
message="任务执行超时。",
|
||||
suggestions=["稍后重试", "简化任务范围"],
|
||||
retryable=True,
|
||||
original_error="err",
|
||||
)
|
||||
msg = err.to_error_message()
|
||||
assert msg.startswith("任务执行超时。建议:")
|
||||
assert "1) 稍后重试" in msg
|
||||
assert "2) 简化任务范围" in msg
|
||||
# Format mirrors EMPTY_LLM_RESPONSE style
|
||||
assert msg.endswith("。")
|
||||
|
||||
def test_to_error_message_no_suggestions(self):
|
||||
err = EmergencyError(
|
||||
error_code="x",
|
||||
message="just a message",
|
||||
suggestions=[],
|
||||
retryable=False,
|
||||
original_error="e",
|
||||
)
|
||||
assert err.to_error_message() == "just a message"
|
||||
|
||||
def test_to_error_message_single_suggestion(self):
|
||||
err = EmergencyError(
|
||||
error_code="x",
|
||||
message="msg",
|
||||
suggestions=["only one"],
|
||||
retryable=False,
|
||||
original_error="e",
|
||||
)
|
||||
msg = err.to_error_message()
|
||||
assert msg == "msg建议:1) only one。"
|
||||
|
||||
|
||||
# ── Config override ──────────────────────────────────
|
||||
|
||||
|
||||
class TestConfigOverride:
|
||||
"""classify() applies per-rule config overrides."""
|
||||
|
||||
def test_override_suggestions(self):
|
||||
exc = TaskTimeoutError(task_id="t", timeout_seconds=1)
|
||||
cfg = {"timeout": {"suggestions": ["自定义建议 A", "自定义建议 B"]}}
|
||||
err = EmergencyRules.classify(exc, config=cfg)
|
||||
assert err.suggestions == ["自定义建议 A", "自定义建议 B"]
|
||||
assert err.error_code == "timeout"
|
||||
|
||||
def test_override_retryable(self):
|
||||
exc = LLMProviderError("openai", "boom")
|
||||
cfg = {"llm_failure": {"retryable": False}}
|
||||
err = EmergencyRules.classify(exc, config=cfg)
|
||||
assert err.retryable is False
|
||||
|
||||
def test_override_message(self):
|
||||
exc = LoopDetectedError(tool_name="x", repetitions=2)
|
||||
cfg = {"loop_detected": {"message": "循环啦!"}}
|
||||
err = EmergencyRules.classify(exc, config=cfg)
|
||||
assert err.message == "循环啦!"
|
||||
|
||||
def test_override_internal_error_rule(self):
|
||||
cfg = {"internal_error": {"suggestions": ["联系客服"]}}
|
||||
err = EmergencyRules.classify(Exception("boom"), config=cfg)
|
||||
assert err.error_code == "internal_error"
|
||||
assert err.suggestions == ["联系客服"]
|
||||
|
||||
def test_config_none_uses_defaults(self):
|
||||
err = EmergencyRules.classify(TaskTimeoutError(task_id="t", timeout_seconds=1))
|
||||
assert err.error_code == "timeout"
|
||||
assert err.retryable is True
|
||||
|
||||
def test_config_empty_dict_uses_defaults(self):
|
||||
err = EmergencyRules.classify(
|
||||
TaskTimeoutError(task_id="t", timeout_seconds=1), config={}
|
||||
)
|
||||
assert err.error_code == "timeout"
|
||||
assert err.retryable is True
|
||||
|
||||
|
||||
# ── TaskResult.error_struct extension ────────────────
|
||||
|
||||
|
||||
def _make_task_result(
|
||||
error_struct: dict | None = None, error_message: str | None = None
|
||||
) -> TaskResult:
|
||||
now = datetime.now(timezone.utc)
|
||||
return TaskResult(
|
||||
task_id="t1",
|
||||
agent_name="a1",
|
||||
status="completed",
|
||||
output_data={"k": "v"},
|
||||
error_message=error_message,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
metrics={"m": 1},
|
||||
error_struct=error_struct,
|
||||
)
|
||||
|
||||
|
||||
class TestTaskResultErrorStruct:
|
||||
"""TaskResult.error_struct field — backward-compatible extension."""
|
||||
|
||||
def test_default_error_struct_is_none(self):
|
||||
tr = _make_task_result()
|
||||
assert tr.error_struct is None
|
||||
|
||||
def test_to_dict_without_error_struct_preserves_existing_shape(self):
|
||||
"""error_struct=None → to_dict() output has NO error_struct key (byte-for-byte)."""
|
||||
tr = _make_task_result()
|
||||
d = tr.to_dict()
|
||||
assert "error_struct" not in d
|
||||
# Existing keys unchanged
|
||||
assert set(d.keys()) == {
|
||||
"task_id",
|
||||
"agent_name",
|
||||
"status",
|
||||
"output_data",
|
||||
"error_message",
|
||||
"started_at",
|
||||
"completed_at",
|
||||
"metrics",
|
||||
}
|
||||
|
||||
def test_to_dict_with_error_struct_includes_key(self):
|
||||
struct = {
|
||||
"error_code": "timeout",
|
||||
"message": "超时",
|
||||
"suggestions": ["重试"],
|
||||
"retryable": True,
|
||||
"original_error": "boom",
|
||||
}
|
||||
tr = _make_task_result(error_struct=struct, error_message="超时建议:1) 重试。")
|
||||
d = tr.to_dict()
|
||||
assert d["error_struct"] == struct
|
||||
assert d["error_message"] == "超时建议:1) 重试。"
|
||||
|
||||
def test_from_dict_round_trip_with_error_struct(self):
|
||||
struct = {"error_code": "loop_detected", "message": "m", "suggestions": [], "retryable": True, "original_error": "e"}
|
||||
tr = _make_task_result(error_struct=struct)
|
||||
d = tr.to_dict()
|
||||
restored = TaskResult.from_dict(d)
|
||||
assert restored.error_struct == struct
|
||||
|
||||
def test_from_dict_without_error_struct_defaults_none(self):
|
||||
tr = _make_task_result()
|
||||
d = tr.to_dict()
|
||||
# Simulate legacy data without error_struct key
|
||||
restored = TaskResult.from_dict(d)
|
||||
assert restored.error_struct is None
|
||||
|
||||
def test_error_message_and_error_struct_coexist(self):
|
||||
"""Both fields can be set simultaneously (parallel contract per KTD2)."""
|
||||
struct = {"error_code": "timeout", "message": "超时", "suggestions": ["重试"], "retryable": True, "original_error": "err"}
|
||||
tr = _make_task_result(error_struct=struct, error_message="超时建议:1) 重试。")
|
||||
d = tr.to_dict()
|
||||
assert d["error_message"] == "超时建议:1) 重试。"
|
||||
assert d["error_struct"] == struct
|
||||
|
|
@ -0,0 +1,404 @@
|
|||
"""G7/U3 — Three-tier fallback chain wiring tests.
|
||||
|
||||
Verifies Main → Recovery (ReflexionEngine) → Emergency (EmergencyRules)
|
||||
at chat REST path. Mocks ReActEngine + ReflexionEngine + LLMGateway.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.exceptions import (
|
||||
LLMProviderError,
|
||||
LoopDetectedError,
|
||||
TaskCancelledError,
|
||||
TaskTimeoutError,
|
||||
)
|
||||
from agentkit.core.react import ReActResult
|
||||
from agentkit.server._fallback_chain import execute_with_fallback_chain
|
||||
|
||||
|
||||
def _make_react_result(status: str = "success", output: str = "ok") -> ReActResult:
|
||||
return ReActResult(
|
||||
output=output,
|
||||
trajectory=[],
|
||||
total_steps=1,
|
||||
total_tokens=10,
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _make_react_engine(result=None, raises=None):
|
||||
"""Build a fake ReActEngine with .execute returning result or raising."""
|
||||
engine = MagicMock()
|
||||
engine.reset = MagicMock()
|
||||
if raises is not None:
|
||||
engine.execute = AsyncMock(side_effect=raises)
|
||||
else:
|
||||
engine.execute = AsyncMock(return_value=result or _make_react_result())
|
||||
return engine
|
||||
|
||||
|
||||
def _make_llm_gateway():
|
||||
gw = MagicMock()
|
||||
gw.chat = AsyncMock(return_value=MagicMock(content="recovered"))
|
||||
return gw
|
||||
|
||||
|
||||
def _make_reflexion_result(status: str = "success", output: str = "recovered"):
|
||||
"""Synthesize a ReflexionResult-like object."""
|
||||
return MagicMock(
|
||||
status=status,
|
||||
output=output,
|
||||
trajectory=[],
|
||||
total_steps=1,
|
||||
total_tokens=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_reflexion(monkeypatch):
|
||||
"""Patch ReflexionEngine used inside the chain to a controllable mock."""
|
||||
from agentkit.server import _fallback_chain
|
||||
|
||||
instances: list[MagicMock] = []
|
||||
|
||||
class _MockReflexion:
|
||||
def __init__(self, llm_gateway, max_reflections=1, **kwargs):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_reflections = max_reflections
|
||||
self.execute = AsyncMock(return_value=_make_reflexion_result())
|
||||
instances.append(self)
|
||||
|
||||
monkeypatch.setattr(_fallback_chain, "ReflexionEngine", _MockReflexion)
|
||||
return instances
|
||||
|
||||
|
||||
# ─── Tier 1: Main ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMainTier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_success_no_recovery_no_emergency(self):
|
||||
engine = _make_react_engine(result=_make_react_result(status="success", output="hello"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "success"
|
||||
assert result.output == "hello"
|
||||
assert result.error_struct is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_unknown_status_treated_as_success(self):
|
||||
"""Unknown status (not in soft_failure set) is treated as success-like."""
|
||||
engine = _make_react_engine(result=_make_react_result(status="partial", output="x"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "success"
|
||||
|
||||
|
||||
# ─── Tier 2: Recovery ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRecoveryTier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_timeout_triggers_recovery_success(self, patched_reflexion):
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "recovered"
|
||||
assert result.output == "recovered"
|
||||
# ReflexionEngine was instantiated and called
|
||||
assert len(patched_reflexion) == 1
|
||||
patched_reflexion[0].execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_loop_detected_triggers_recovery(self, patched_reflexion):
|
||||
engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "recovered"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_llm_provider_error_triggers_recovery(self, patched_reflexion):
|
||||
engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="503"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "recovered"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_soft_failure_status_triggers_recovery(self, patched_reflexion):
|
||||
"""Soft failure (empty_fallback) without exception still triggers Recovery."""
|
||||
engine = _make_react_engine(result=_make_react_result(status="empty_fallback", output=""))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "recovered"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_disabled_skips_to_emergency(self):
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "timeout"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_failure_falls_through_to_emergency(self, patched_reflexion):
|
||||
"""Recovery raises → Emergency tier fires with original exception."""
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
# Make ReflexionEngine.execute raise
|
||||
patched_reflexion_instance = MagicMock()
|
||||
patched_reflexion_instance.execute = AsyncMock(
|
||||
side_effect=RuntimeError("reflexion crashed")
|
||||
)
|
||||
# Override the patched class to use our instance
|
||||
from agentkit.server import _fallback_chain
|
||||
|
||||
original_cls = _fallback_chain.ReflexionEngine
|
||||
|
||||
class _MockReflexionWithExc:
|
||||
def __init__(self, **kwargs):
|
||||
self.execute = patched_reflexion_instance.execute
|
||||
|
||||
_fallback_chain.ReflexionEngine = _MockReflexionWithExc
|
||||
try:
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
finally:
|
||||
_fallback_chain.ReflexionEngine = original_cls
|
||||
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "timeout"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_unsuccessful_status_falls_through(self, patched_reflexion):
|
||||
"""Recovery returns non-success status → Emergency fires."""
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
# Make ReflexionEngine return unsuccessful result with empty output
|
||||
from agentkit.server import _fallback_chain
|
||||
|
||||
class _MockReflexionNoOutput:
|
||||
def __init__(self, **kwargs):
|
||||
self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None))
|
||||
|
||||
original_cls = _fallback_chain.ReflexionEngine
|
||||
_fallback_chain.ReflexionEngine = _MockReflexionNoOutput
|
||||
try:
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
finally:
|
||||
_fallback_chain.ReflexionEngine = original_cls
|
||||
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "timeout"
|
||||
|
||||
|
||||
# ─── Tier 3: Emergency ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEmergencyTier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_timeout_error_code(self, patched_reflexion):
|
||||
# Make recovery fail (empty result) so Emergency fires
|
||||
from agentkit.server import _fallback_chain
|
||||
|
||||
class _MockReflexionEmpty:
|
||||
def __init__(self, **kwargs):
|
||||
self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None))
|
||||
|
||||
original_cls = _fallback_chain.ReflexionEngine
|
||||
_fallback_chain.ReflexionEngine = _MockReflexionEmpty
|
||||
try:
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
finally:
|
||||
_fallback_chain.ReflexionEngine = original_cls
|
||||
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "timeout"
|
||||
assert result.error_struct["retryable"] is True
|
||||
assert "建议" in result.output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_loop_detected_error_code(self):
|
||||
engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5))
|
||||
# Recovery disabled so Emergency fires directly
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "loop_detected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_llm_failure_error_code(self):
|
||||
engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="500"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "llm_failure"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_internal_error_for_generic_exception(self):
|
||||
engine = _make_react_engine(raises=RuntimeError("unexpected"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "internal_error"
|
||||
assert result.error_struct["retryable"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cancelled_propagates_not_routed_to_emergency(self):
|
||||
"""TaskCancelledError must propagate, not be classified by Emergency."""
|
||||
engine = _make_react_engine(raises=TaskCancelledError(task_id="t1"))
|
||||
with pytest.raises(TaskCancelledError):
|
||||
await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_disabled_reraises_original(self):
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
with pytest.raises(TaskTimeoutError):
|
||||
await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={
|
||||
"recovery": {"enabled": False},
|
||||
"emergency": {"enabled": False},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ─── Config wiring ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestServerConfigFallbackChain:
|
||||
def test_fallback_chain_section_read_from_dict(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig.from_dict(
|
||||
{
|
||||
"fallback_chain": {
|
||||
"enabled": True,
|
||||
"recovery": {"enabled": False, "max_retries": 3},
|
||||
"emergency": {"enabled": True},
|
||||
}
|
||||
}
|
||||
)
|
||||
assert config.fallback_chain["enabled"] is True
|
||||
assert config.fallback_chain["recovery"] == {"enabled": False, "max_retries": 3}
|
||||
assert config.fallback_chain["emergency"] == {"enabled": True}
|
||||
|
||||
def test_fallback_chain_defaults_empty_when_absent(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig.from_dict({})
|
||||
assert config.fallback_chain == {}
|
||||
|
|
@ -0,0 +1,348 @@
|
|||
"""Unit tests for PhasePolicy + PhaseState (G6 core, R24/R25/R26).
|
||||
|
||||
Covers:
|
||||
- PhaseState enum (next_of, from_string)
|
||||
- default_policy() KTD5 whitelist
|
||||
- PhasePolicy.is_tool_allowed / is_bash_command_allowed
|
||||
- policy_from_config parsing (R26 config-driven)
|
||||
- ServerConfig.plan_exec integration
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.phase import (
|
||||
WILDCARD,
|
||||
PhasePolicy,
|
||||
PhaseState,
|
||||
default_policy,
|
||||
policy_from_config,
|
||||
)
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PhaseState enum
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPhaseState:
|
||||
def test_values(self):
|
||||
assert PhaseState.PLANNING.value == "planning"
|
||||
assert PhaseState.BUILDING.value == "building"
|
||||
assert PhaseState.VERIFICATION.value == "verification"
|
||||
assert PhaseState.DELIVERY.value == "delivery"
|
||||
|
||||
def test_next_of(self):
|
||||
assert PhaseState.next_of(PhaseState.PLANNING) == PhaseState.BUILDING
|
||||
assert PhaseState.next_of(PhaseState.BUILDING) == PhaseState.VERIFICATION
|
||||
assert PhaseState.next_of(PhaseState.VERIFICATION) == PhaseState.DELIVERY
|
||||
assert PhaseState.next_of(PhaseState.DELIVERY) is None
|
||||
|
||||
def test_from_string_case_insensitive(self):
|
||||
assert PhaseState.from_string("planning") == PhaseState.PLANNING
|
||||
assert PhaseState.from_string("PLANNING") == PhaseState.PLANNING
|
||||
assert PhaseState.from_string("Building") == PhaseState.BUILDING
|
||||
|
||||
def test_from_string_invalid_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid phase name"):
|
||||
PhaseState.from_string("unknown")
|
||||
with pytest.raises(ValueError, match="Valid:"):
|
||||
PhaseState.from_string("exploration")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# default_policy() — KTD5 whitelist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDefaultPolicy:
|
||||
def test_has_all_four_phases(self):
|
||||
policy = default_policy()
|
||||
assert PhaseState.PLANNING in policy.whitelist
|
||||
assert PhaseState.BUILDING in policy.whitelist
|
||||
assert PhaseState.VERIFICATION in policy.whitelist
|
||||
assert PhaseState.DELIVERY in policy.whitelist
|
||||
|
||||
def test_planning_whitelist_matches_r24(self):
|
||||
policy = default_policy()
|
||||
allowed = policy.whitelist[PhaseState.PLANNING]
|
||||
assert "search" in allowed
|
||||
assert "read_file" in allowed
|
||||
assert "shell" in allowed
|
||||
assert "tool_search" in allowed
|
||||
# Planning must NOT allow write_file.
|
||||
assert "write_file" not in allowed
|
||||
|
||||
def test_building_whitelist_includes_write_file(self):
|
||||
policy = default_policy()
|
||||
allowed = policy.whitelist[PhaseState.BUILDING]
|
||||
assert "write_file" in allowed
|
||||
assert "shell" in allowed
|
||||
assert "read_file" in allowed
|
||||
|
||||
def test_verification_whitelist_excludes_write(self):
|
||||
policy = default_policy()
|
||||
allowed = policy.whitelist[PhaseState.VERIFICATION]
|
||||
assert "shell" in allowed
|
||||
assert "read_file" in allowed
|
||||
assert "write_file" not in allowed
|
||||
|
||||
def test_delivery_wildcard(self):
|
||||
policy = default_policy()
|
||||
allowed = policy.whitelist[PhaseState.DELIVERY]
|
||||
assert WILDCARD in allowed
|
||||
|
||||
def test_start_phase_default_planning(self):
|
||||
assert default_policy().start_phase == PhaseState.PLANNING
|
||||
|
||||
def test_auto_advance_default_none(self):
|
||||
# KTD6: manual by default.
|
||||
assert default_policy().auto_advance_after_steps is None
|
||||
|
||||
def test_bash_filter_blocks_rm_in_planning(self):
|
||||
policy = default_policy()
|
||||
assert policy.is_bash_command_allowed("ls -la", PhaseState.PLANNING) is True
|
||||
assert policy.is_bash_command_allowed("git status", PhaseState.PLANNING) is True
|
||||
assert policy.is_bash_command_allowed("rm -rf /tmp/x", PhaseState.PLANNING) is False
|
||||
assert policy.is_bash_command_allowed("echo x > file.txt", PhaseState.PLANNING) is False
|
||||
|
||||
def test_bash_filter_no_restriction_in_building(self):
|
||||
policy = default_policy()
|
||||
assert policy.is_bash_command_allowed("rm -rf build/", PhaseState.BUILDING) is True
|
||||
assert policy.is_bash_command_allowed("echo x > out.log", PhaseState.BUILDING) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PhasePolicy — is_tool_allowed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsToolAllowed:
|
||||
def test_planning_allows_search(self):
|
||||
policy = default_policy()
|
||||
assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True
|
||||
|
||||
def test_planning_blocks_write_file(self):
|
||||
policy = default_policy()
|
||||
assert policy.is_tool_allowed("write_file", PhaseState.PLANNING) is False
|
||||
|
||||
def test_building_allows_write_file(self):
|
||||
policy = default_policy()
|
||||
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
|
||||
|
||||
def test_delivery_wildcard_allows_anything(self):
|
||||
policy = default_policy()
|
||||
assert policy.is_tool_allowed("any_random_tool", PhaseState.DELIVERY) is True
|
||||
assert policy.is_tool_allowed("write_file", PhaseState.DELIVERY) is True
|
||||
|
||||
def test_unknown_phase_returns_false(self):
|
||||
# ponytail: unknown phase → empty whitelist → no tool allowed.
|
||||
# We can't construct an unknown PhaseState (enum), but if a phase
|
||||
# were missing from the whitelist dict, is_tool_allowed should
|
||||
# return False (defensive).
|
||||
policy = PhasePolicy(
|
||||
whitelist={
|
||||
PhaseState.PLANNING: frozenset({"search"}),
|
||||
PhaseState.BUILDING: frozenset({"write_file"}),
|
||||
PhaseState.VERIFICATION: frozenset({"shell"}),
|
||||
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||
}
|
||||
)
|
||||
# BUILDING is in whitelist, so allowed checks work normally.
|
||||
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
|
||||
# Phase missing from whitelist would return False (defensive .get default).
|
||||
# We test this by constructing a minimal policy.
|
||||
minimal = PhasePolicy(
|
||||
whitelist={
|
||||
PhaseState.PLANNING: frozenset({WILDCARD}),
|
||||
PhaseState.BUILDING: frozenset({WILDCARD}),
|
||||
PhaseState.VERIFICATION: frozenset({WILDCARD}),
|
||||
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||
}
|
||||
)
|
||||
# VERIFICATION is in whitelist — wildcard allows all.
|
||||
assert minimal.is_tool_allowed("anything", PhaseState.VERIFICATION) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PhasePolicy — edge cases & errors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPhasePolicyEdgeCases:
|
||||
def test_empty_whitelist_raises(self):
|
||||
# Fail-fast: an empty whitelist for a non-wildcard phase is a bug.
|
||||
with pytest.raises(ValueError, match="empty whitelist"):
|
||||
PhasePolicy(
|
||||
whitelist={
|
||||
PhaseState.PLANNING: frozenset(), # empty!
|
||||
PhaseState.BUILDING: frozenset({WILDCARD}),
|
||||
PhaseState.VERIFICATION: frozenset({WILDCARD}),
|
||||
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||
}
|
||||
)
|
||||
|
||||
def test_wildcard_only_does_not_raise(self):
|
||||
# Wildcard-only whitelist is valid (means "all tools").
|
||||
policy = PhasePolicy(
|
||||
whitelist={
|
||||
PhaseState.PLANNING: frozenset({WILDCARD}),
|
||||
PhaseState.BUILDING: frozenset({WILDCARD}),
|
||||
PhaseState.VERIFICATION: frozenset({WILDCARD}),
|
||||
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||
}
|
||||
)
|
||||
assert policy.is_tool_allowed("anything", PhaseState.PLANNING) is True
|
||||
|
||||
def test_to_dict_serializable(self):
|
||||
policy = default_policy()
|
||||
d = policy.to_dict()
|
||||
assert "whitelist" in d
|
||||
assert "planning" in d["whitelist"]
|
||||
assert "delivery" in d["whitelist"]
|
||||
assert d["start_phase"] == "planning"
|
||||
assert d["auto_advance_after_steps"] is None
|
||||
|
||||
def test_custom_bash_filter(self):
|
||||
custom_filter = re.compile(r"\b(pip install|npm install)\b")
|
||||
policy = PhasePolicy(
|
||||
whitelist={
|
||||
PhaseState.PLANNING: frozenset({"shell"}),
|
||||
PhaseState.BUILDING: frozenset({"shell"}),
|
||||
PhaseState.VERIFICATION: frozenset({"shell"}),
|
||||
PhaseState.DELIVERY: frozenset({WILDCARD}),
|
||||
},
|
||||
bash_command_filter={PhaseState.BUILDING: custom_filter},
|
||||
)
|
||||
assert policy.is_bash_command_allowed("npm install foo", PhaseState.BUILDING) is False
|
||||
assert policy.is_bash_command_allowed("npm run build", PhaseState.BUILDING) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# policy_from_config — R26 (config-driven)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPolicyFromConfig:
|
||||
def test_empty_config_returns_none(self):
|
||||
assert policy_from_config({}) is None
|
||||
|
||||
def test_enabled_false_returns_none(self):
|
||||
# Opt-out — explicit `enabled: false` disables policy.
|
||||
result = policy_from_config({"enabled": False})
|
||||
assert result is None
|
||||
|
||||
def test_enabled_default_true_when_section_present(self):
|
||||
# When section is present but `enabled` is missing, default is True.
|
||||
result = policy_from_config({"auto_advance_after_steps": 3})
|
||||
assert result is not None
|
||||
assert result.auto_advance_after_steps == 3
|
||||
|
||||
def test_auto_advance_after_steps(self):
|
||||
policy = policy_from_config({"enabled": True, "auto_advance_after_steps": 5})
|
||||
assert policy is not None
|
||||
assert policy.auto_advance_after_steps == 5
|
||||
|
||||
def test_start_phase_custom(self):
|
||||
policy = policy_from_config({"enabled": True, "start_phase": "building"})
|
||||
assert policy is not None
|
||||
assert policy.start_phase == PhaseState.BUILDING
|
||||
|
||||
def test_start_phase_invalid_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid phase name"):
|
||||
policy_from_config({"enabled": True, "start_phase": "unknown"})
|
||||
|
||||
def test_whitelist_override_merges_with_default(self):
|
||||
policy = policy_from_config(
|
||||
{
|
||||
"enabled": True,
|
||||
"whitelist_override": {
|
||||
"planning": ["search", "read_file"], # removes shell from default
|
||||
},
|
||||
}
|
||||
)
|
||||
assert policy is not None
|
||||
# Override wins — shell should be removed from planning.
|
||||
assert policy.is_tool_allowed("search", PhaseState.PLANNING) is True
|
||||
assert policy.is_tool_allowed("read_file", PhaseState.PLANNING) is True
|
||||
assert policy.is_tool_allowed("shell", PhaseState.PLANNING) is False
|
||||
# Other phases unchanged.
|
||||
assert policy.is_tool_allowed("write_file", PhaseState.BUILDING) is True
|
||||
|
||||
def test_whitelist_override_invalid_phase_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid phase name"):
|
||||
policy_from_config(
|
||||
{
|
||||
"enabled": True,
|
||||
"whitelist_override": {"unknown_phase": ["tool"]},
|
||||
}
|
||||
)
|
||||
|
||||
def test_whitelist_override_non_list_raises(self):
|
||||
with pytest.raises(ValueError, match="must be a list"):
|
||||
policy_from_config(
|
||||
{
|
||||
"enabled": True,
|
||||
"whitelist_override": {"planning": "not a list"},
|
||||
}
|
||||
)
|
||||
|
||||
def test_to_dict_round_trip_via_default(self):
|
||||
# Sanity: default policy serializes to a dict with expected keys.
|
||||
policy = default_policy()
|
||||
d = policy.to_dict()
|
||||
assert set(d["whitelist"].keys()) == {
|
||||
"planning",
|
||||
"building",
|
||||
"verification",
|
||||
"delivery",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ServerConfig.plan_exec integration (R26)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestServerConfigPlanExec:
|
||||
def test_default_plan_exec_empty(self):
|
||||
config = ServerConfig.from_dict({})
|
||||
assert config.plan_exec == {}
|
||||
|
||||
def test_plan_exec_loaded_from_dict(self):
|
||||
config = ServerConfig.from_dict(
|
||||
{
|
||||
"plan_exec": {
|
||||
"enabled": True,
|
||||
"auto_advance_after_steps": 5,
|
||||
}
|
||||
}
|
||||
)
|
||||
assert config.plan_exec == {"enabled": True, "auto_advance_after_steps": 5}
|
||||
|
||||
def test_plan_exec_empty_dict_default(self):
|
||||
config = ServerConfig.from_dict({"plan_exec": {}})
|
||||
assert config.plan_exec == {}
|
||||
|
||||
def test_plan_exec_resolved_to_policy(self):
|
||||
# Wire the config dict through policy_from_config to verify integration.
|
||||
config = ServerConfig.from_dict(
|
||||
{
|
||||
"plan_exec": {
|
||||
"enabled": True,
|
||||
"auto_advance_after_steps": 3,
|
||||
}
|
||||
}
|
||||
)
|
||||
policy = policy_from_config(config.plan_exec)
|
||||
assert policy is not None
|
||||
assert policy.auto_advance_after_steps == 3
|
||||
|
||||
def test_plan_exec_disabled_via_config(self):
|
||||
config = ServerConfig.from_dict({"plan_exec": {"enabled": False}})
|
||||
policy = policy_from_config(config.plan_exec)
|
||||
assert policy is None
|
||||
|
|
@ -0,0 +1,367 @@
|
|||
"""G9/U4 — PlanPhase rollback fields + RollbackExecutor + TeamOrchestrator integration.
|
||||
|
||||
Characterization-first: captures pre-change behavior (rollback_command=None →
|
||||
no rollback, checkpoint saved) before asserting new rollback behavior.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.experts.orchestrator import TeamOrchestrator
|
||||
from agentkit.experts.plan import PlanPhase, TeamPlan
|
||||
from agentkit.orchestrator.rollback import RollbackExecutor
|
||||
|
||||
|
||||
# ─── PlanPhase field characterization ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlanPhaseFields:
|
||||
"""PlanPhase serialization for new fields."""
|
||||
|
||||
def test_characterization_no_new_keys_when_unset(self):
|
||||
"""Default PlanPhase.to_dict() must not include new keys (KTD6 contract)."""
|
||||
ph = PlanPhase(name="执行", assigned_expert="lead")
|
||||
out = ph.to_dict()
|
||||
assert "validation_command" not in out
|
||||
assert "rollback_command" not in out
|
||||
# Pre-change keys remain
|
||||
assert out["name"] == "执行"
|
||||
assert out["assigned_expert"] == "lead"
|
||||
assert out["status"] == "pending"
|
||||
|
||||
def test_characterization_from_dict_empty_yields_none(self):
|
||||
"""from_dict({}) produces both new fields as None."""
|
||||
ph = PlanPhase.from_dict({"name": "x"})
|
||||
assert ph.validation_command is None
|
||||
assert ph.rollback_command is None
|
||||
|
||||
def test_serialization_includes_keys_when_set(self):
|
||||
ph = PlanPhase(
|
||||
name="frontend",
|
||||
validation_command="ruff check src/",
|
||||
rollback_command="git checkout src/app.vue",
|
||||
)
|
||||
out = ph.to_dict()
|
||||
assert out["validation_command"] == "ruff check src/"
|
||||
assert out["rollback_command"] == "git checkout src/app.vue"
|
||||
|
||||
def test_serialization_round_trip(self):
|
||||
ph = PlanPhase(
|
||||
name="backend",
|
||||
validation_command="pytest -x -q",
|
||||
rollback_command="git checkout src/api.py",
|
||||
)
|
||||
restored = PlanPhase.from_dict(ph.to_dict())
|
||||
assert restored.validation_command == "pytest -x -q"
|
||||
assert restored.rollback_command == "git checkout src/api.py"
|
||||
|
||||
def test_only_validation_set_still_emits_validation_key(self):
|
||||
"""Asymmetric case: only validation_command set — only that key appears."""
|
||||
ph = PlanPhase(name="x", validation_command="echo ok")
|
||||
out = ph.to_dict()
|
||||
assert "validation_command" in out
|
||||
assert "rollback_command" not in out
|
||||
|
||||
|
||||
# ─── RollbackExecutor subprocess execution ─────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_workspace(tmp_path):
|
||||
"""Fresh working directory for subprocess execution."""
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
class TestRollbackExecutor:
|
||||
"""RollbackExecutor happy/edge/failure paths."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_happy_path_zero_exit(self, tmp_workspace):
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.execute("true")
|
||||
assert r.passed is True
|
||||
assert r.exit_code == 0
|
||||
assert r.command == "true"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_failure_nonzero_exit(self, tmp_workspace):
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.execute("false")
|
||||
assert r.passed is False
|
||||
assert r.exit_code != 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_timeout(self, tmp_workspace):
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=0.1)
|
||||
r = await ex.execute("sleep 5")
|
||||
assert r.passed is False
|
||||
assert r.exit_code == -1
|
||||
assert "timed out" in r.stderr
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_captures_stdout(self, tmp_workspace):
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.execute("echo hello-rollback")
|
||||
assert r.passed is True
|
||||
assert "hello-rollback" in r.stdout
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_same_semantics_as_execute(self, tmp_workspace):
|
||||
"""validate() is just execute() with different intent marker."""
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.validate("true")
|
||||
assert r.passed is True
|
||||
r2 = await ex.validate("false")
|
||||
assert r2.passed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spawn_failure_returns_failed_result(self, tmp_workspace):
|
||||
"""Bad shell command should surface as failed result, not raise."""
|
||||
# Use a definitely-broken command — shell still spawns, returns non-zero.
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.execute("exit 7")
|
||||
assert r.passed is False
|
||||
assert r.exit_code == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cwd_used_for_relative_paths(self, tmp_workspace):
|
||||
"""Files created in working_dir are visible via cwd-relative commands."""
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
await ex.execute("echo content > testfile.txt")
|
||||
r = await ex.execute("test -f testfile.txt")
|
||||
assert r.passed is True
|
||||
|
||||
|
||||
# ─── Real git rollback integration ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGitRollbackIntegration:
|
||||
"""Real git repo fixture: writes file, rollback restores via git checkout."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_checkout_restores_modified_file(self, tmp_path):
|
||||
# Init repo and commit a baseline file
|
||||
import subprocess
|
||||
|
||||
repo = str(tmp_path)
|
||||
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@x"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "config", "user.name", "test"], cwd=repo, check=True)
|
||||
baseline = "original content\n"
|
||||
with open(os.path.join(repo, "foo.txt"), "w") as f:
|
||||
f.write(baseline)
|
||||
subprocess.run(["git", "add", "foo.txt"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "commit", "-q", "-m", "baseline"], cwd=repo, check=True)
|
||||
|
||||
# Mutate the file
|
||||
with open(os.path.join(repo, "foo.txt"), "w") as f:
|
||||
f.write("mutated content\n")
|
||||
|
||||
# Run git checkout as rollback_command
|
||||
ex = RollbackExecutor(working_dir=repo, timeout=10.0)
|
||||
r = await ex.execute("git checkout foo.txt")
|
||||
assert r.passed is True
|
||||
|
||||
with open(os.path.join(repo, "foo.txt")) as f:
|
||||
assert f.read() == baseline
|
||||
|
||||
|
||||
# ─── TeamOrchestrator integration ─────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_team_mock():
|
||||
"""Build a minimal ExpertTeam mock for TeamOrchestrator."""
|
||||
team = MagicMock()
|
||||
team.team_id = "test-team"
|
||||
team.status.value = "executing"
|
||||
team.lead_expert = None
|
||||
team.active_experts = []
|
||||
return team
|
||||
|
||||
|
||||
def _make_orchestrator(team=None, checkpoint=None, workspace_root=None):
|
||||
"""Build a TeamOrchestrator with mocked team and checkpoint."""
|
||||
team = team or _make_team_mock()
|
||||
orch = TeamOrchestrator(
|
||||
team=team,
|
||||
checkpoint=checkpoint,
|
||||
workspace_root=workspace_root,
|
||||
rollback_timeout=2.0,
|
||||
)
|
||||
orch._broadcast_event = AsyncMock()
|
||||
return orch
|
||||
|
||||
|
||||
class TestOrchestratorRollbackIntegration:
|
||||
"""TeamOrchestrator phase failure path integration with rollback."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_characterization_no_rollback_when_unset(self, tmp_path):
|
||||
"""Phase fails, rollback_command=None → checkpoint saved, no rollback events."""
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||
|
||||
ph = PlanPhase(name="p1", assigned_expert="lead") # no rollback fields
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
# Call _run_phase_rollback — but it should NOT be called from
|
||||
# orchestrator path when fields unset. Verify by simulating the
|
||||
# main-loop guard directly.
|
||||
from agentkit.experts.plan import PhaseStatus as PS
|
||||
|
||||
ph.status = PS.FAILED
|
||||
# Simulate the guard condition in orchestrator.py:280-288
|
||||
should_save = True
|
||||
if (
|
||||
ph.validation_command
|
||||
and ph.rollback_command
|
||||
and isinstance(Exception("x"), (Exception,))
|
||||
):
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
# Guard should not fire
|
||||
assert should_save is True
|
||||
# No rollback events broadcast
|
||||
assert orch._broadcast_event.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_passes_no_rollback_executed(self, tmp_path):
|
||||
"""validation_command returns 0 → rollback NOT executed, checkpoint saved."""
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||
|
||||
ph = PlanPhase(
|
||||
name="p1",
|
||||
assigned_expert="lead",
|
||||
validation_command="true", # always passes
|
||||
rollback_command="false", # would fail if executed
|
||||
)
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
assert should_save is True
|
||||
|
||||
# Events: started + completed (no rollback)
|
||||
events = [
|
||||
c.kwargs.get("event_type") or c.args[0] for c in orch._broadcast_event.call_args_list
|
||||
]
|
||||
assert "phase_rollback_started" in events
|
||||
assert "phase_rollback_completed" in events
|
||||
assert "phase_rollback_failed" not in events
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_fails_rollback_succeeds(self, tmp_path):
|
||||
"""Validation fails, rollback returns 0 → checkpoint saved, rollback_executed=True."""
|
||||
# Create file under git tracking, then mutate it
|
||||
import subprocess
|
||||
|
||||
repo = str(tmp_path)
|
||||
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "config", "user.email", "t@x"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "config", "user.name", "t"], cwd=repo, check=True)
|
||||
with open(os.path.join(repo, "f.txt"), "w") as f:
|
||||
f.write("base\n")
|
||||
subprocess.run(["git", "add", "f.txt"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "commit", "-q", "-m", "base"], cwd=repo, check=True)
|
||||
with open(os.path.join(repo, "f.txt"), "w") as f:
|
||||
f.write("mutated\n")
|
||||
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=repo)
|
||||
|
||||
ph = PlanPhase(
|
||||
name="p1",
|
||||
assigned_expert="lead",
|
||||
validation_command="false", # validation fails
|
||||
rollback_command="git checkout f.txt", # rollback succeeds
|
||||
)
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
assert should_save is True
|
||||
|
||||
# File restored to baseline
|
||||
with open(os.path.join(repo, "f.txt")) as f:
|
||||
assert f.read() == "base\n"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_fails_rollback_fails_skips_checkpoint(self, tmp_path):
|
||||
"""Validation fails AND rollback fails → checkpoint NOT saved (R21), event emitted."""
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||
|
||||
ph = PlanPhase(
|
||||
name="p1",
|
||||
assigned_expert="lead",
|
||||
validation_command="false", # fails
|
||||
rollback_command="false", # also fails
|
||||
)
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
assert should_save is False # R21: skip checkpoint
|
||||
|
||||
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
|
||||
assert "phase_rollback_started" in events
|
||||
assert "phase_rollback_failed" in events
|
||||
# phase_rollback_completed should NOT be in events
|
||||
assert "phase_rollback_completed" not in events
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rollback_timeout_skips_checkpoint(self, tmp_path):
|
||||
"""Rollback command times out → checkpoint NOT saved, failed event emitted."""
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||
orch._rollback_timeout = 0.1 # short timeout
|
||||
|
||||
ph = PlanPhase(
|
||||
name="p1",
|
||||
assigned_expert="lead",
|
||||
validation_command="false",
|
||||
rollback_command="sleep 5",
|
||||
)
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
assert should_save is False
|
||||
|
||||
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
|
||||
assert "phase_rollback_failed" in events
|
||||
|
||||
|
||||
# ─── Config wiring ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestServerConfigRollback:
|
||||
"""ServerConfig rollback section wiring."""
|
||||
|
||||
def test_rollback_section_read_from_dict(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig.from_dict(
|
||||
{
|
||||
"rollback": {
|
||||
"default_timeout": 45.0,
|
||||
}
|
||||
}
|
||||
)
|
||||
assert config.rollback == {"default_timeout": 45.0}
|
||||
|
||||
def test_rollback_defaults_empty_when_absent(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig.from_dict({})
|
||||
assert config.rollback == {}
|
||||
|
|
@ -0,0 +1,339 @@
|
|||
"""Unit tests for ReActEngine phase enforcement (G6 wiring, R24).
|
||||
|
||||
Per plan U3 Execution note: characterization-first — verify that
|
||||
`ReActEngine(phase_policy=None)` behaves identically to pre-change (no
|
||||
enforcement, no advance_phase tool, no _current_phase mutation). Then add
|
||||
enforcement tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.phase import PhasePolicy, PhaseState, default_policy
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Characterization — phase_policy=None preserves existing behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCharacterizationNoPolicy:
|
||||
"""When phase_policy=None, no enforcement happens and behavior matches
|
||||
pre-Wave-3."""
|
||||
|
||||
def test_init_without_phase_policy(self):
|
||||
# Minimal stub LLM gateway — we're only testing constructor.
|
||||
gateway = MagicMock()
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
assert engine._phase_policy is None
|
||||
assert engine._current_phase is None
|
||||
assert engine._steps_in_phase == 0
|
||||
assert engine.current_phase is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool_dispatches_without_phase_check(self):
|
||||
"""Tool dispatch proceeds normally when no policy set."""
|
||||
gateway = MagicMock()
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
|
||||
# MagicMock.name is a special attribute used internally by Mock for
|
||||
# repr — setting it post-construction does not make mock.name == "x"
|
||||
# hold. Patch _find_tool directly to bypass the name lookup.
|
||||
fake_tool = MagicMock()
|
||||
fake_tool.safe_execute = AsyncMock(return_value={"output": "ok"})
|
||||
fake_tool.input_schema = None
|
||||
engine._find_tool = lambda name, tools: fake_tool
|
||||
|
||||
result = await engine._execute_tool("any_tool", {"x": 1}, [fake_tool])
|
||||
assert result == {"output": "ok"}
|
||||
fake_tool.safe_execute.assert_awaited_once_with(x=1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_advance_phase_returns_none_without_policy(self):
|
||||
gateway = MagicMock()
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
assert engine.advance_phase() is None
|
||||
|
||||
def test_reset_does_not_touch_phase_state_when_no_policy(self):
|
||||
gateway = MagicMock()
|
||||
engine = ReActEngine(llm_gateway=gateway)
|
||||
engine.reset()
|
||||
assert engine._current_phase is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Initialization with phase_policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPhasePolicyInitialization:
|
||||
def test_phase_policy_set_initializes_current_phase(self):
|
||||
gateway = MagicMock()
|
||||
engine = ReActEngine(
|
||||
llm_gateway=gateway,
|
||||
phase_policy=default_policy(),
|
||||
)
|
||||
assert engine._phase_policy is not None
|
||||
assert engine._current_phase == PhaseState.PLANNING
|
||||
assert engine._steps_in_phase == 0
|
||||
|
||||
def test_reset_resets_phase_to_start(self):
|
||||
gateway = MagicMock()
|
||||
engine = ReActEngine(
|
||||
llm_gateway=gateway,
|
||||
phase_policy=default_policy(),
|
||||
)
|
||||
# Manually move phase forward (simulating execute progress).
|
||||
engine.advance_phase() # PLANNING → BUILDING
|
||||
assert engine._current_phase == PhaseState.BUILDING
|
||||
engine._steps_in_phase = 5
|
||||
|
||||
engine.reset()
|
||||
assert engine._current_phase == PhaseState.PLANNING
|
||||
assert engine._steps_in_phase == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# advance_phase() transitions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdvancePhase:
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||
|
||||
def test_planning_to_building(self, engine):
|
||||
new_phase = engine.advance_phase()
|
||||
assert new_phase == PhaseState.BUILDING
|
||||
assert engine.current_phase == PhaseState.BUILDING
|
||||
assert engine._steps_in_phase == 0 # counter reset on transition
|
||||
|
||||
def test_building_to_verification(self, engine):
|
||||
engine.advance_phase() # → BUILDING
|
||||
new_phase = engine.advance_phase()
|
||||
assert new_phase == PhaseState.VERIFICATION
|
||||
assert engine.current_phase == PhaseState.VERIFICATION
|
||||
|
||||
def test_verification_to_delivery(self, engine):
|
||||
engine.advance_phase() # → BUILDING
|
||||
engine.advance_phase() # → VERIFICATION
|
||||
new_phase = engine.advance_phase()
|
||||
assert new_phase == PhaseState.DELIVERY
|
||||
assert engine.current_phase == PhaseState.DELIVERY
|
||||
|
||||
def test_delivery_returns_none(self, engine):
|
||||
engine.advance_phase() # → BUILDING
|
||||
engine.advance_phase() # → VERIFICATION
|
||||
engine.advance_phase() # → DELIVERY
|
||||
result = engine.advance_phase()
|
||||
assert result is None
|
||||
assert engine.current_phase == PhaseState.DELIVERY
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _check_phase_permission — whitelist enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPhasePermission:
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
return ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||
|
||||
def test_search_allowed_in_planning(self, engine):
|
||||
assert engine._check_phase_permission("search", {}) is None
|
||||
|
||||
def test_write_file_blocked_in_planning(self, engine):
|
||||
result = engine._check_phase_permission("write_file", {})
|
||||
assert result is not None
|
||||
assert result["error"] == "phase_violation"
|
||||
assert "write_file" in result["message"]
|
||||
assert result["current_phase"] == "planning"
|
||||
|
||||
def test_write_file_allowed_in_building(self, engine):
|
||||
engine.advance_phase() # → BUILDING
|
||||
assert engine._check_phase_permission("write_file", {}) is None
|
||||
|
||||
def test_any_tool_allowed_in_delivery(self, engine):
|
||||
engine.advance_phase() # → BUILDING
|
||||
engine.advance_phase() # → VERIFICATION
|
||||
engine.advance_phase() # → DELIVERY
|
||||
assert engine._check_phase_permission("literally_anything", {}) is None
|
||||
|
||||
def test_bash_command_filter_blocks_rm_in_planning(self, engine):
|
||||
result = engine._check_phase_permission("shell", {"command": "rm -rf /tmp"})
|
||||
assert result is not None
|
||||
assert result["error"] == "phase_violation"
|
||||
assert "rm" in result["message"] or "Bash command" in result["message"]
|
||||
|
||||
def test_bash_command_filter_allows_safe_in_planning(self, engine):
|
||||
# `ls` and `git status` are not blocked.
|
||||
assert engine._check_phase_permission("shell", {"command": "ls -la"}) is None
|
||||
assert engine._check_phase_permission("shell", {"command": "git status"}) is None
|
||||
|
||||
def test_bash_command_filter_no_restriction_in_building(self, engine):
|
||||
engine.advance_phase() # → BUILDING
|
||||
# `rm` is allowed in building phase.
|
||||
assert engine._check_phase_permission("shell", {"command": "rm -rf build/"}) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _execute_tool integration — phase enforcement actually blocks dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExecuteToolPhaseEnforcement:
|
||||
@pytest.fixture
|
||||
def engine_with_tools(self):
|
||||
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||
# Two fake tools: one allowed in PLANNING (search), one not (write_file).
|
||||
# MagicMock.name can't be set post-construction (special attribute),
|
||||
# so we patch _find_tool with a dict-based lookup.
|
||||
search_tool = MagicMock()
|
||||
search_tool.input_schema = None
|
||||
search_tool.safe_execute = AsyncMock(return_value={"results": []})
|
||||
|
||||
write_tool = MagicMock()
|
||||
write_tool.input_schema = None
|
||||
write_tool.safe_execute = AsyncMock(return_value={"written": True})
|
||||
|
||||
tools_by_name = {"search": search_tool, "write_file": write_tool}
|
||||
engine._find_tool = lambda name, tools: tools_by_name.get(name)
|
||||
|
||||
return engine, [search_tool, write_tool]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_tool_returns_phase_violation_and_skips_dispatch(self, engine_with_tools):
|
||||
engine, tools = engine_with_tools
|
||||
# write_file in PLANNING should be blocked — write_tool.safe_execute
|
||||
# should NEVER be called.
|
||||
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
|
||||
assert result["error"] == "phase_violation"
|
||||
assert result["current_phase"] == "planning"
|
||||
write_tool = tools[1]
|
||||
write_tool.safe_execute.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allowed_tool_dispatches_normally(self, engine_with_tools):
|
||||
engine, tools = engine_with_tools
|
||||
result = await engine._execute_tool("search", {"query": "foo"}, tools)
|
||||
assert result == {"results": []}
|
||||
search_tool = tools[0]
|
||||
search_tool.safe_execute.assert_awaited_once_with(query="foo")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_after_advance_phase_blocked_tool_now_dispatches(self, engine_with_tools):
|
||||
engine, tools = engine_with_tools
|
||||
# First: write_file blocked in PLANNING.
|
||||
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
|
||||
assert result["error"] == "phase_violation"
|
||||
# Advance to BUILDING.
|
||||
engine.advance_phase()
|
||||
# Now: write_file allowed.
|
||||
result = await engine._execute_tool("write_file", {"path": "/x"}, tools)
|
||||
assert result == {"written": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-advance safety net (KTD6)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAutoAdvance:
|
||||
def test_auto_advance_after_threshold(self):
|
||||
# Custom policy with auto-advance after 2 steps.
|
||||
policy = PhasePolicy(
|
||||
whitelist={
|
||||
PhaseState.PLANNING: frozenset({"search"}),
|
||||
PhaseState.BUILDING: frozenset({"write_file"}),
|
||||
PhaseState.VERIFICATION: frozenset({"shell"}),
|
||||
PhaseState.DELIVERY: frozenset({"*"}),
|
||||
},
|
||||
auto_advance_after_steps=2,
|
||||
)
|
||||
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=policy)
|
||||
assert engine.current_phase == PhaseState.PLANNING
|
||||
|
||||
# Step 1: counter goes to 1, no advance yet.
|
||||
engine._steps_in_phase += 1
|
||||
assert engine._maybe_auto_advance() is False
|
||||
assert engine.current_phase == PhaseState.PLANNING
|
||||
|
||||
# Step 2: counter hits 2, advance triggered.
|
||||
engine._steps_in_phase += 1
|
||||
assert engine._maybe_auto_advance() is True
|
||||
assert engine.current_phase == PhaseState.BUILDING
|
||||
assert engine._steps_in_phase == 0 # reset on advance
|
||||
|
||||
def test_auto_advance_none_default(self):
|
||||
# default_policy has auto_advance_after_steps=None — no auto-advance.
|
||||
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||
engine._steps_in_phase = 100
|
||||
assert engine._maybe_auto_advance() is False
|
||||
assert engine.current_phase == PhaseState.PLANNING
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AdvancePhaseTool integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdvancePhaseTool:
|
||||
@pytest.mark.asyncio
|
||||
async def test_advance_phase_tool_transitions_engine(self):
|
||||
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||
tool = AdvancePhaseTool(engine=engine)
|
||||
result = await tool.execute()
|
||||
assert result["is_error"] is False
|
||||
assert result["current_phase"] == "building"
|
||||
assert engine.current_phase == PhaseState.BUILDING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_advance_phase_tool_at_delivery_returns_error(self):
|
||||
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||
# Walk through all phases.
|
||||
engine.advance_phase() # PLANNING → BUILDING
|
||||
engine.advance_phase() # BUILDING → VERIFICATION
|
||||
engine.advance_phase() # VERIFICATION → DELIVERY
|
||||
tool = AdvancePhaseTool(engine=engine)
|
||||
result = await tool.execute()
|
||||
assert result["is_error"] is True
|
||||
assert result["error"] == "already_at_final_phase"
|
||||
assert result["current_phase"] == "delivery"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_advance_phase_tool_without_policy_returns_error(self):
|
||||
engine = ReActEngine(llm_gateway=MagicMock()) # no policy
|
||||
tool = AdvancePhaseTool(engine=engine)
|
||||
result = await tool.execute()
|
||||
assert result["is_error"] is True
|
||||
assert result["error"] == "no_phase_policy"
|
||||
|
||||
def test_tool_schema_accepts_no_arguments(self):
|
||||
engine = ReActEngine(llm_gateway=MagicMock(), phase_policy=default_policy())
|
||||
tool = AdvancePhaseTool(engine=engine)
|
||||
# input_schema has empty properties + additionalProperties:false —
|
||||
# no arguments expected.
|
||||
assert tool.input_schema["properties"] == {}
|
||||
assert tool.input_schema["additionalProperties"] is False
|
||||
|
||||
def test_tool_bypasses_phase_check(self):
|
||||
"""`advance_phase` is the LLM's escape hatch — must never be blocked."""
|
||||
# _check_phase_permission should NOT block advance_phase even in PLANNING.
|
||||
# The bypass is implemented in _execute_tool by name check.
|
||||
# We verify the bypass indirectly: tool dispatches normally even in
|
||||
# PLANNING (where only search/read_file/bash/tool_search are allowed).
|
||||
# advance_phase is not in the whitelist, but the name-based bypass
|
||||
# in _execute_tool lets it through.
|
||||
# (Direct unit test of the bypass would require mocking _find_tool.)
|
||||
# Sanity: advance_phase is not in any whitelist.
|
||||
for phase, allowed in default_policy().whitelist.items():
|
||||
assert "advance_phase" not in allowed, (
|
||||
f"advance_phase must not be in {phase.value} whitelist"
|
||||
)
|
||||
|
|
@ -0,0 +1,367 @@
|
|||
"""Unit tests for ReadFileTool — G5 (R22, R23) + characterization baseline.
|
||||
|
||||
Per plan U1 Execution note: characterization-first — assert that
|
||||
`symbol=None` returns the full file content (matches pre-existing benchmark
|
||||
`_FakeTool` shape) before adding symbol-extraction behavior.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.file_read import ReadFileTool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadFileToolSchema:
|
||||
def test_name_is_read_file(self):
|
||||
tool = ReadFileTool()
|
||||
assert tool.name == "read_file"
|
||||
|
||||
def test_required_path(self):
|
||||
tool = ReadFileTool()
|
||||
assert "path" in tool.input_schema["required"]
|
||||
assert "path" in tool.input_schema["properties"]
|
||||
|
||||
def test_optional_symbol_and_lines(self):
|
||||
tool = ReadFileTool()
|
||||
props = tool.input_schema["properties"]
|
||||
assert "symbol" in props
|
||||
assert "start_line" in props
|
||||
assert "end_line" in props
|
||||
# None of the optional fields should be in `required`.
|
||||
required = set(tool.input_schema["required"])
|
||||
assert required == {"path"}
|
||||
|
||||
def test_additional_properties_false(self):
|
||||
# LLM tool-call schemas should reject unknown args (Wave 1 U3 pattern).
|
||||
tool = ReadFileTool()
|
||||
assert tool.input_schema.get("additionalProperties") is False
|
||||
|
||||
def test_tags_contain_io_and_read(self):
|
||||
tool = ReadFileTool()
|
||||
assert "io" in tool.tags
|
||||
assert "read" in tool.tags
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Characterization — symbol=None returns full file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_py_file(tmp_path):
|
||||
path = tmp_path / "sample.py"
|
||||
path.write_text(
|
||||
textwrap.dedent('''
|
||||
"""Sample module."""
|
||||
|
||||
def my_func():
|
||||
return 42
|
||||
|
||||
|
||||
class MyClass:
|
||||
attr = 1
|
||||
|
||||
def method_a(self):
|
||||
return self.attr
|
||||
''').lstrip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ts_file(tmp_path):
|
||||
path = tmp_path / "sample.ts"
|
||||
path.write_text(
|
||||
textwrap.dedent('''
|
||||
export function renderComponent(): JSX.Element {
|
||||
return <div/>;
|
||||
}
|
||||
|
||||
export class BaseService {
|
||||
abstract run(): void;
|
||||
}
|
||||
''').lstrip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
class TestCharacterizationFullFile:
|
||||
"""symbol=None returns the whole file (matches _FakeTool baseline)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_file_returned_when_symbol_none(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file))
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["path"] == str(sample_py_file)
|
||||
assert result["start_line"] == 1
|
||||
assert result["end_line"] == result["total_lines"]
|
||||
assert "def my_func" in result["content"]
|
||||
assert "class MyClass" in result["content"]
|
||||
assert result["content"].startswith('"""Sample module."""')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_file_includes_all_lines(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file))
|
||||
assert result["total_lines"] >= 8
|
||||
assert result["content"].count("\n") >= result["total_lines"] - 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Symbol slicing — happy paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSymbolSlicing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_python_function(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), symbol="my_func")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "my_func"
|
||||
assert result["symbol_kind"] == "function"
|
||||
assert "def my_func" in result["content"]
|
||||
assert "return 42" in result["content"]
|
||||
# Should NOT include the class below.
|
||||
assert "class MyClass" not in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_python_class_includes_method(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), symbol="MyClass")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "MyClass"
|
||||
assert result["symbol_kind"] == "class"
|
||||
assert "class MyClass" in result["content"]
|
||||
assert "def method_a" in result["content"] # method included
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_python_method_directly(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), symbol="method_a")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "method_a"
|
||||
assert result["symbol_kind"] == "method"
|
||||
assert "def method_a" in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typescript_function(self, sample_ts_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_ts_file), symbol="renderComponent")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "renderComponent"
|
||||
assert "renderComponent" in result["content"]
|
||||
# Should not include the class below.
|
||||
assert "BaseService" not in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typescript_class(self, sample_ts_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_ts_file), symbol="BaseService")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "BaseService"
|
||||
assert result["symbol_kind"] == "class"
|
||||
assert "BaseService" in result["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Symbol slicing — edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSymbolSlicingEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_symbol_not_found_lists_available(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), symbol="nonexistent")
|
||||
|
||||
assert result["is_error"] is False # soft error, not hard
|
||||
assert result["content"] == ""
|
||||
assert result["symbol"] == "nonexistent"
|
||||
available = result["available_symbols"]
|
||||
assert "my_func" in available
|
||||
assert "MyClass" in available
|
||||
assert "method_a" in available
|
||||
assert "nonexistent" not in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_extension_returns_full_with_note(self, tmp_path):
|
||||
path = tmp_path / "notes.md"
|
||||
path.write_text("# Hello\nworld\n", encoding="utf-8")
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(path), symbol="anything")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["content"] == "# Hello\nworld\n"
|
||||
assert "symbol extraction not supported" in result["note"]
|
||||
assert ".md" in result["note"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_file(self, tmp_path):
|
||||
path = tmp_path / "empty.py"
|
||||
path.write_text("", encoding="utf-8")
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(path))
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["content"] == ""
|
||||
assert result["total_lines"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_with_no_symbols(self, tmp_path):
|
||||
path = tmp_path / "data.py"
|
||||
path.write_text("# just a comment\nPI = 3.14\n", encoding="utf-8")
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(path), symbol="PI")
|
||||
|
||||
# PI is not a def/class — extractor finds no symbols; soft error lists available.
|
||||
assert result["is_error"] is False
|
||||
assert result["content"] == ""
|
||||
assert result["available_symbols"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadFileToolErrors:
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_required(self):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute()
|
||||
assert result["is_error"] is True
|
||||
assert "path" in result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_empty_string(self):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path="")
|
||||
assert result["is_error"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_not_found(self, tmp_path):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(tmp_path / "missing.py"))
|
||||
assert result["is_error"] is True
|
||||
assert "not found" in result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_is_directory(self, tmp_path):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(tmp_path))
|
||||
assert result["is_error"] is True
|
||||
assert "directory" in result["error"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Manual line slicing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestManualLineSlicing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_and_end_line(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(
|
||||
path=str(sample_py_file),
|
||||
start_line=3,
|
||||
end_line=5,
|
||||
)
|
||||
assert result["is_error"] is False
|
||||
assert result["start_line"] == 3
|
||||
assert result["end_line"] == 5
|
||||
# Lines 3-5 of the sample file:
|
||||
# line 3: "def my_func():"
|
||||
# line 4: " return 42"
|
||||
# line 5: "" (blank)
|
||||
assert "def my_func" in result["content"]
|
||||
assert "return 42" in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_line_only_extends_to_eof(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), start_line=8)
|
||||
assert result["is_error"] is False
|
||||
assert result["start_line"] == 8
|
||||
assert result["end_line"] == result["total_lines"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_line_only_starts_at_one(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), end_line=2)
|
||||
assert result["is_error"] is False
|
||||
assert result["start_line"] == 1
|
||||
assert result["end_line"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_start_line_zero(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), start_line=0)
|
||||
assert result["is_error"] is True
|
||||
assert "start_line" in result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_before_start(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(
|
||||
path=str(sample_py_file),
|
||||
start_line=5,
|
||||
end_line=3,
|
||||
)
|
||||
assert result["is_error"] is True
|
||||
assert "end_line" in result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manual_lines_override_symbol(self, sample_py_file):
|
||||
# Per plan U1 Approach: "start_line/end_line overrides symbol".
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(
|
||||
path=str(sample_py_file),
|
||||
symbol="my_func",
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
)
|
||||
assert result["is_error"] is False
|
||||
# Manual slicing won — symbol field absent.
|
||||
assert "symbol" not in result or result.get("symbol") is None
|
||||
assert result["start_line"] == 1
|
||||
assert result["end_line"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration — tool registry discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolRegistryDiscovery:
|
||||
def test_instantiable_without_args(self):
|
||||
# Default constructor — matches the convention used by ToolRegistry
|
||||
# to instantiate tools by class.
|
||||
tool = ReadFileTool()
|
||||
assert tool.name == "read_file"
|
||||
|
||||
def test_to_dict_serializable(self):
|
||||
tool = ReadFileTool()
|
||||
d = tool.to_dict()
|
||||
assert d["name"] == "read_file"
|
||||
assert "input_schema" in d
|
||||
assert "output_schema" in d
|
||||
assert d["tags"] == ["io", "file", "read"]
|
||||
|
|
@ -0,0 +1,359 @@
|
|||
"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor.
|
||||
|
||||
Covers R22 (file reading supports symbol/function granularity) and KTD1
|
||||
(Python ast + language-aware regex, no tree-sitter dependency).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.symbol_extractor import (
|
||||
AstSymbolExtractor,
|
||||
RegexSymbolExtractor,
|
||||
SymbolSpan,
|
||||
extract_symbols_from_file,
|
||||
get_extractor,
|
||||
language_for_extension,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# language_for_extension
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLanguageForExtension:
|
||||
def test_python_extensions(self):
|
||||
assert language_for_extension("py") == "py"
|
||||
assert language_for_extension(".py") == "py"
|
||||
assert language_for_extension(".PY") == "py" # case-insensitive
|
||||
|
||||
def test_typescript_javascript(self):
|
||||
assert language_for_extension(".ts") == "ts"
|
||||
assert language_for_extension(".tsx") == "ts"
|
||||
assert language_for_extension(".js") == "js"
|
||||
assert language_for_extension(".jsx") == "js"
|
||||
assert language_for_extension(".mjs") == "js"
|
||||
assert language_for_extension(".cjs") == "js"
|
||||
|
||||
def test_go_rust_java(self):
|
||||
assert language_for_extension(".go") == "go"
|
||||
assert language_for_extension(".rs") == "rs"
|
||||
assert language_for_extension(".java") == "java"
|
||||
|
||||
def test_unsupported_returns_empty(self):
|
||||
assert language_for_extension(".md") == ""
|
||||
assert language_for_extension(".txt") == ""
|
||||
assert language_for_extension("") == ""
|
||||
assert language_for_extension(".unknown") == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AstSymbolExtractor — Python
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAstSymbolExtractor:
|
||||
extractor = AstSymbolExtractor()
|
||||
|
||||
def test_unsupported_language_returns_empty(self):
|
||||
assert self.extractor.extract_symbols("function foo() {}", "ts") == []
|
||||
|
||||
def test_syntax_error_returns_empty(self):
|
||||
# Never raises — callers rely on this for fallback routing.
|
||||
result = self.extractor.extract_symbols("def broken(:\n pass", "py")
|
||||
assert result == []
|
||||
|
||||
def test_top_level_function(self):
|
||||
content = "def my_func():\n return 42\n"
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "my_func"
|
||||
assert span.kind == "function"
|
||||
assert span.start_line == 1
|
||||
assert span.end_line == 2
|
||||
|
||||
def test_async_function(self):
|
||||
content = "async def fetch():\n return 1\n"
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
assert len(spans) == 1
|
||||
assert spans[0].name == "fetch"
|
||||
assert spans[0].kind == "function"
|
||||
|
||||
def test_top_level_class(self):
|
||||
content = textwrap.dedent('''
|
||||
class MyClass:
|
||||
"""docstring"""
|
||||
|
||||
def method_a(self):
|
||||
return 1
|
||||
|
||||
async def method_b(self):
|
||||
return 2
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
names = [s.name for s in spans]
|
||||
assert "MyClass" in names
|
||||
assert "method_a" in names
|
||||
assert "method_b" in names
|
||||
|
||||
cls = next(s for s in spans if s.name == "MyClass")
|
||||
assert cls.kind == "class"
|
||||
assert cls.start_line == 1
|
||||
# Class body extends through the last method's end_lineno.
|
||||
assert cls.end_line >= 7
|
||||
|
||||
def test_methods_classified_as_methods(self):
|
||||
content = textwrap.dedent('''
|
||||
class Foo:
|
||||
def bar(self):
|
||||
pass
|
||||
|
||||
def top_level():
|
||||
pass
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
by_name = {s.name: s for s in spans}
|
||||
assert by_name["bar"].kind == "method"
|
||||
assert by_name["top_level"].kind == "function"
|
||||
|
||||
def test_decorated_function(self):
|
||||
content = textwrap.dedent('''
|
||||
@staticmethod
|
||||
def helper():
|
||||
return "hi"
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
# Note: extractor uses node.lineno (def line) — decorators above are
|
||||
# excluded by design (matches user-visible symbol start at `def`).
|
||||
assert any(s.name == "helper" for s in spans)
|
||||
span = next(s for s in spans if s.name == "helper")
|
||||
assert span.start_line == 2 # the `def` line
|
||||
|
||||
def test_nested_function(self):
|
||||
content = textwrap.dedent('''
|
||||
def outer():
|
||||
def inner():
|
||||
return 1
|
||||
return inner()
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
names = {s.name for s in spans}
|
||||
assert "outer" in names
|
||||
assert "inner" in names
|
||||
|
||||
def test_empty_file(self):
|
||||
assert self.extractor.extract_symbols("", "py") == []
|
||||
|
||||
def test_no_symbols_in_docstring_only_file(self):
|
||||
content = '"""just a docstring"""\n'
|
||||
assert self.extractor.extract_symbols(content, "py") == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RegexSymbolExtractor — TS/JS/Go/Rust/Java
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegexSymbolExtractor:
|
||||
extractor = RegexSymbolExtractor()
|
||||
|
||||
def test_unsupported_language_returns_empty(self):
|
||||
assert self.extractor.extract_symbols("def foo(): pass", "py") == []
|
||||
assert self.extractor.extract_symbols("function foo() {}", "rb") == []
|
||||
|
||||
def test_typescript_function_declaration(self):
|
||||
content = textwrap.dedent('''
|
||||
export function renderComponent(props: Props): JSX.Element {
|
||||
return <div/>;
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "ts")
|
||||
assert any(s.name == "renderComponent" and s.kind == "function" for s in spans)
|
||||
|
||||
def test_typescript_async_function(self):
|
||||
content = "async function fetchData() {\n return await fetch();\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "ts")
|
||||
assert any(s.name == "fetchData" for s in spans)
|
||||
|
||||
def test_typescript_arrow_function_const(self):
|
||||
content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n"
|
||||
spans = self.extractor.extract_symbols(content, "ts")
|
||||
assert any(s.name == "handleClick" for s in spans)
|
||||
|
||||
def test_typescript_class(self):
|
||||
content = textwrap.dedent('''
|
||||
export abstract class BaseService {
|
||||
abstract run(): void;
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "ts")
|
||||
assert any(s.name == "BaseService" and s.kind == "class" for s in spans)
|
||||
|
||||
def test_javascript_function(self):
|
||||
content = "function foo() {\n return 1;\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "js")
|
||||
assert any(s.name == "foo" for s in spans)
|
||||
|
||||
def test_javascript_arrow_const(self):
|
||||
content = "const bar = () => 42;\n"
|
||||
spans = self.extractor.extract_symbols(content, "js")
|
||||
assert any(s.name == "bar" for s in spans)
|
||||
|
||||
def test_go_function(self):
|
||||
content = textwrap.dedent('''
|
||||
package main
|
||||
|
||||
func HandleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
|
||||
func (s *Server) Start() {
|
||||
// method
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "go")
|
||||
names = {s.name for s in spans}
|
||||
assert "HandleRequest" in names
|
||||
assert "Start" in names # method receiver pattern
|
||||
|
||||
def test_go_struct(self):
|
||||
content = "type Server struct {\n Addr string\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "go")
|
||||
assert any(s.name == "Server" and s.kind == "struct" for s in spans)
|
||||
|
||||
def test_rust_function(self):
|
||||
content = textwrap.dedent('''
|
||||
pub fn process(input: &str) -> Result<usize, Error> {
|
||||
Ok(input.len())
|
||||
}
|
||||
|
||||
async fn fetch() -> Bytes {
|
||||
unimplemented!()
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "rs")
|
||||
names = {s.name for s in spans}
|
||||
assert "process" in names
|
||||
assert "fetch" in names
|
||||
|
||||
def test_rust_struct(self):
|
||||
content = "pub struct Config {\n pub path: String,\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "rs")
|
||||
assert any(s.name == "Config" and s.kind == "struct" for s in spans)
|
||||
|
||||
def test_rust_impl(self):
|
||||
content = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "rs")
|
||||
assert any(s.name == "Config" and s.kind == "impl" for s in spans)
|
||||
|
||||
def test_java_class(self):
|
||||
content = textwrap.dedent('''
|
||||
package com.example;
|
||||
|
||||
public class UserService {
|
||||
public User findById(long id) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "java")
|
||||
assert any(s.name == "UserService" and s.kind == "class" for s in spans)
|
||||
|
||||
def test_java_method(self):
|
||||
content = "public User findById(long id) {\n return null;\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "java")
|
||||
assert any(s.name == "findById" and s.kind == "function" for s in spans)
|
||||
|
||||
def test_end_line_extends_to_next_symbol(self):
|
||||
# First symbol's end_line is the line before the second symbol starts.
|
||||
content = textwrap.dedent('''
|
||||
function first() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
function second() {
|
||||
return 2;
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "js")
|
||||
spans.sort(key=lambda s: s.start_line)
|
||||
first = spans[0]
|
||||
second = spans[1]
|
||||
assert first.name == "first"
|
||||
assert second.name == "second"
|
||||
assert first.end_line == second.start_line - 1
|
||||
|
||||
def test_last_symbol_end_line_is_eof(self):
|
||||
content = "function only() {\n return 1;\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "js")
|
||||
assert len(spans) == 1
|
||||
assert spans[0].end_line == len(content.splitlines())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_extractor + integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetExtractor:
|
||||
def test_python_returns_ast_extractor(self):
|
||||
ext = get_extractor("py")
|
||||
assert ext is not None
|
||||
assert isinstance(ext, AstSymbolExtractor)
|
||||
|
||||
def test_typescript_returns_regex_extractor(self):
|
||||
ext = get_extractor("ts")
|
||||
assert ext is not None
|
||||
assert isinstance(ext, RegexSymbolExtractor)
|
||||
|
||||
def test_unsupported_returns_none(self):
|
||||
assert get_extractor("md") is None
|
||||
assert get_extractor("") is None
|
||||
assert get_extractor("unknown") is None
|
||||
|
||||
|
||||
class TestExtractSymbolsFromFile:
|
||||
def test_python_file(self, tmp_path):
|
||||
path = tmp_path / "module.py"
|
||||
path.write_text("def hello():\n return 'world'\n", encoding="utf-8")
|
||||
spans, lang = extract_symbols_from_file(path)
|
||||
assert lang == "py"
|
||||
assert any(s.name == "hello" for s in spans)
|
||||
|
||||
def test_unsupported_extension(self, tmp_path):
|
||||
path = tmp_path / "notes.md"
|
||||
path.write_text("# Hello\n", encoding="utf-8")
|
||||
spans, lang = extract_symbols_from_file(path)
|
||||
assert lang == ""
|
||||
assert spans == []
|
||||
|
||||
def test_missing_file_returns_empty(self, tmp_path):
|
||||
path = tmp_path / "nonexistent.py"
|
||||
spans, lang = extract_symbols_from_file(path)
|
||||
# lang is detected from extension even if read fails.
|
||||
assert lang == "py"
|
||||
assert spans == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SymbolSpan dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSymbolSpan:
|
||||
def test_frozen_dataclass(self):
|
||||
span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
|
||||
assert span.name == "foo"
|
||||
with pytest.raises(Exception):
|
||||
span.name = "bar" # type: ignore[misc] — frozen
|
||||
|
||||
def test_equality(self):
|
||||
a = SymbolSpan("foo", "function", 1, 3)
|
||||
b = SymbolSpan("foo", "function", 1, 3)
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
Loading…
Reference in New Issue