feat(agent): Wave 2 medium coupling (G4/G7/G9) (#5)
Deploy to Production / deploy (push) Waiting to run Details
Test / backend-test (push) Waiting to run Details
Test / frontend-unit (push) Waiting to run Details
Test / api-e2e (push) Waiting to run Details
Test / frontend-e2e (push) Waiting to run Details

This commit is contained in:
Fischer 2026-06-30 09:09:33 +08:00
parent 78ed93fc81
commit a2dcde01b8
16 changed files with 2658 additions and 34 deletions

View File

@ -28,6 +28,29 @@ llm:
coding: bailian-coding/qwen3-coder-plus coding: bailian-coding/qwen3-coder-plus
chat: deepseek/deepseek-chat chat: deepseek/deepseek-chat
reasoning: deepseek/deepseek-reasoner reasoning: deepseek/deepseek-reasoner
# G4/U1: Auxiliary model for cost-sensitive tasks (summarization).
# When set, ContextCompressor tries this alias first, falling back to
# the main model on failure or empty content. Commented to preserve
# default behavior — uncomment to enable.
# auxiliary_model: fast
# G9/U4: Rollback configuration. Drives RollbackExecutor subprocess timeout
# for PlanPhase.validation_command / PlanPhase.rollback_command. Per-phase
# opt-in (KTD6) — when PlanPhase.rollback_command is unset, no rollback runs.
# Canonical rollback pattern: `git checkout <specific_files>` (file-scoped,
# not `git checkout .` which would wipe unrelated changes).
rollback:
default_timeout: 30.0
# G7/U3: Three-tier fallback chain at chat REST send_message.
# main → Recovery (ReflexionEngine retry) → Emergency (rule-based classifier).
# Wired only at chat REST path (KTD5); CLI / ReWOO / Reflexion internal
# ReAct calls bypass the chain (no recursive loop).
fallback_chain:
enabled: true
recovery:
enabled: true
max_retries: 1 # ReflexionEngine max_reflections override
emergency:
enabled: true
session: {backend: memory} session: {backend: memory}
bus: {backend: memory} bus: {backend: memory}
task_store: {backend: memory} task_store: {backend: memory}

View File

@ -0,0 +1,481 @@
---
date: 2026-06-29
type: feat
origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md
---
# Wave 2 — Auxiliary LLM Routing, Three-Tier Fallback, Atomic Subtask Rollback
## Summary
Wave 2 of the advanced-agent gap optimization brainstorm. Three medium-coupling gaps: G4 routes summary calls through a cheaper auxiliary LLM (falling back to main on failure), G7 introduces a three-tier fallback chain (main → Recovery via existing `ReflexionEngine` → Emergency with structured errors), G9 binds atomic subtask rollback to `PlanPhase` (opt-in via `rollback_command`, coordinated with the existing U7 `PipelineCheckpoint`).
---
## Problem Frame
Wave 1 shipped self-contained quick wins (G1/G2/G3/G8). Wave 2 addresses the medium-coupling gaps that touch multiple layers and resolve two deferred-to-planning decisions surfaced in the brainstorm: G7 Emergency layer rule template shape, and G9 `rollback_command` default behavior. The constraints these gaps must respect come from `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — new fields must preserve existing contracts, dynamic plan mutations must persist immediately, and `PipelineCheckpoint` is in-memory dict + Redis fallback (not a DB row lock).
---
## Requirements
Carried from `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` § Wave 2 (R13-R21):
- R13. `ContextCompressor` summary task routes to auxiliary model (cheap, e.g. Gemini Flash / Doubao lite), not main model.
- R14. `auxiliary_model` is configurable, separated from main `model`.
- R15. Summary quality does not degrade: auxiliary failure falls back to main model (not to simple_summary).
- R16. Main agent failure triggers Recovery layer (reuse `ReflexionEngine` Evaluate→Reflect→Retry).
- R17. Recovery failure triggers Emergency layer (rule-based fallback, structured error + suggestion).
- R18. Fallback chain is configurable (per-layer max retries, enable/disable Recovery/Emergency).
- R19. `PlanPhase` (`experts/plan.py:147`) gains optional `validation_command` and `rollback_command` fields.
- R20. Phase failure auto-executes `rollback_command` when configured (default `git checkout` pattern); coordinated with U7 checkpoint.
- R21. `checkpoint.save` runs only after rollback validation passes (avoid persisting failed state).
Cross-cutting: R26 (configuration via `agentkit.yaml`), R27 (each gap ships a minimal self-check test, per ponytail rule).
---
## Key Technical Decisions
KTD1. **Auxiliary routing is compressor-scoped, not gateway-level.** `LLMGateway` already has a per-model fallback chain (`gateway.py:170-227`), but it semantics are "main fails → try backup model", not "route by task type". G4 adds an `auxiliary_model` parameter to `ContextCompressor.__init__`; `_summarize` tries auxiliary first, falls back to main on failure. This avoids touching the gateway's existing fallback semantics and keeps the change to one file plus config wiring.
KTD2. **Emergency layer adds a parallel `error_struct` field, not replacing `error_message`.** `TaskResult.error_message: str | None` is serialized to API responses (`protocol.py:132-145`) — changing its type breaks frontend contracts. New `error_struct: dict | None` field carries `{error_code, message, suggestions, retryable}`. `error_message` remains the human-readable string (now derived from `error_struct.message` when present). Backward compatibility preserved.
KTD3. **Emergency rule classification by exception type, not by status field.** `TaskTimeoutError``timeout`, `LoopDetectedError``loop_detected`, `LLMProviderError``llm_failure`, `TaskCancelledError``cancelled` (not Emergency-eligible, propagates), generic `Exception``internal_error`. Each rule maps to a fixed message + suggestion list (no LLM-driven suggestions — pure rule-based, per brainstorm).
KTD4. **Recovery layer reuses `ReflexionEngine` with main model, not auxiliary.** `ReflexionEngine.execute()` already supports `evaluate_model`/`reflect_model` parameters (`reflexion.py:94-111`), defaulting to the main model. Recovery triggers on main agent failure; using the main model for evaluate/reflect maximizes diagnostic reliability (the same model that failed is best positioned to reflect on its own failure). Auxiliary model is reserved for G4's cost-sensitive compression path.
KTD5. **Three-tier chain wires at `chat.py:613` (WebSocket main path), not at `ReActEngine.execute()`.** ReAct has 5 call sites (`chat.py:613`, `rewoo.py:1204`, `cli/chat.py:338`, `reflexion.py:216`, `config_driven.py:695`). Wrapping at ReAct would force all 5 sites through Recovery/Emergency, including ReflexionEngine itself (creating a recursive loop). Wrapping at `chat.py:613` covers the primary user-facing path; CLI and ReWOO are out of scope (deferred to follow-up).
KTD6. **Rollback is opt-in via `rollback_command` field; no implicit default.** Brainstorm KTD5 mentioned "default `git checkout`" but auto-executing `git checkout` on every phase failure changes existing behavior (Finding 1: "新字段默认值须保持既有契约"). Resolution: `rollback_command: str | None = None`. When unset, no rollback executes (preserves existing contract). When set, the configured command runs after validation failure, before checkpoint save. `git checkout <files>` is the canonical pattern documented in YAML examples.
KTD7. **Rollback executes via `RollbackExecutor`, not via `ShellTool`.** `ShellTool._is_dangerous` would trigger `confirm_callback` for `git checkout` (not in `_SAFE_COMMAND_PREFIXES`, not in `_DANGEROUS_BINARY_FLAGS`). Orchestrator-internal execution bypasses ShellTool — uses `asyncio.create_subprocess_shell` with `cwd`/`timeout`/`proc.kill()` pattern from `verification_loop.py:67-103`. Audit logging added (not a whitelist concern; terminal_whitelist only governs `/api/v1/terminal/server` route per `terminal_server.py:227`).
KTD8. **G9 checkpoint ordering: validation → rollback → checkpoint save.** Currently `orchestrator.py:265` saves checkpoint unconditionally after phase finalizes (success or failure). R21 requires checkpoint save only after rollback validation passes. New ordering: phase fails → mark FAILED → mark dependents FAILED → execute `validation_command` (if set) → if validation fails, execute `rollback_command` (if set) → save checkpoint (only if rollback validation passed or no rollback configured).
---
## High-Level Technical Design
### Three-tier fallback chain (G7)
```mermaid
stateDiagram-v2
[*] --> Main: chat request
Main --> MainSuccess: success
Main --> Recovery: failure (exception or empty_fallback)
Recovery --> RecoverySuccess: ReflexionEngine retry succeeds
Recovery --> Emergency: max_reflections exhausted
Emergency --> [*]: emit error_struct, terminal
MainSuccess --> [*]: normal response
RecoverySuccess --> [*]: recovered response
```
### G9 phase failure + rollback sequence
```mermaid
sequenceDiagram
participant O as TeamOrchestrator
participant P as PlanPhase
participant RE as RollbackExecutor
participant CP as PipelineCheckpoint
O->>P: execute phase
P-->>O: failure (exception)
O->>O: mark FAILED + dependents FAILED
alt validation_command set
O->>RE: run validation_command
RE-->>O: ValidationResult
alt validation fails AND rollback_command set
O->>RE: run rollback_command
RE-->>O: RollbackResult
alt rollback validation passes
O->>CP: save checkpoint
else rollback validation fails
O->>O: log rollback_failed, skip checkpoint
end
else no rollback_command
O->>CP: save checkpoint (no rollback)
end
else no validation_command
O->>CP: save checkpoint (no validation)
end
```
### Configuration shape (YAML)
```yaml
# G4: Auxiliary LLM for compression
llm:
auxiliary_model: fast # alias resolved via existing model_aliases mechanism
# G7: Three-tier fallback chain
fallback_chain:
enabled: true
recovery:
enabled: true
max_retries: 1 # ReflexionEngine max_reflections override
emergency:
enabled: true
# G9: Rollback configuration (per-phase opt-in via PlanPhase.rollback_command)
rollback:
default_timeout: 30.0 # RollbackExecutor subprocess timeout
```
---
## Implementation Units
### U1. G4 — Auxiliary LLM Routing in ContextCompressor
**Goal:** Route `_summarize` LLM calls through `auxiliary_model` when configured; fall back to main model on auxiliary failure (not to `_simple_summary`).
**Requirements:** R13, R14, R15, R26, R27
**Dependencies:** none (self-contained)
**Files:**
- Modify: `src/agentkit/core/compressor.py` (add `auxiliary_model` param + routing in `_summarize`)
- Modify: `src/agentkit/llm/config.py` (add `auxiliary_model: str | None` field to `LLMConfig`)
- Modify: `src/agentkit/server/config.py` (`_build_llm_config` reads `auxiliary_model` from llm section)
- Modify: `agentkit.yaml` (document `llm.auxiliary_model: fast` in llm section)
- Create: `tests/unit/test_compressor_auxiliary.py`
**Approach:** `ContextCompressor.__init__` gains `auxiliary_model: str | None = None` param (after existing `model` param). `_summarize` (compressor.py:123-158) restructuring:
1. If `auxiliary_model` is set and differs from `model`, try `auxiliary_model` first via `self._llm_gateway.chat(model=self._auxiliary_model, ...)`.
2. **Truthiness check (Finding 4 anti-pattern):** treat empty `response.content` (None or whitespace-only) as failure, not success. On failure, log and fall through to main model.
3. If auxiliary succeeds with non-empty content, return it.
4. If auxiliary fails (exception OR empty content), retry with main `self._model`.
5. Existing `except Exception → _simple_summary` block remains as the final degradation tier.
`LLMConfig.auxiliary_model` reads from `data.get("auxiliary_model")` in `_build_llm_config`. Agentkit.yaml already declares `fast: bailian-coding/qwen-turbo` alias — this is the canonical auxiliary target.
**Execution note:** characterization-first. Capture current `_summarize` behavior (single model, exception → simple_summary) with `auxiliary_model=None` tests, then add `auxiliary_model="fast"` tests for new routing behavior.
**Patterns to follow:**
- `LLMGateway._get_models_to_try` (`gateway.py:170-227`) — existing fallback chain pattern
- Wave 1's `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections, `config.py:239-273`)
**Test scenarios:**
- Happy path: `auxiliary_model="fast"` set, auxiliary returns non-empty content → result is auxiliary content, main model not called
- Empty content fallback: auxiliary returns `content=""` (or `None`) → main model called, main content returned (covers Finding 4 anti-pattern)
- Auxiliary exception: auxiliary raises `LLMProviderError` → main model called, main content returned
- Both fail: auxiliary raises, main raises → `_simple_summary` returned (existing degradation preserved)
- Characterization: `auxiliary_model=None` → behavior matches current code (single model call)
- Config wiring: `LLMConfig.from_dict` reads `auxiliary_model` field from dict; `ServerConfig._build_llm_config` passes it through
- Audit: auxiliary call uses `agent_name="compressor"`, `task_type="summarization"` (preserved for usage tracking)
**Verification:** All test scenarios pass; existing `tests/unit/test_compressor*.py` (if any) still pass with `auxiliary_model=None` default; ruff clean.
---
### U2. G7 — Emergency Layer Rule Template + TaskResult Extension
**Goal:** Add rule-based Emergency classifier with structured error output. No wiring yet — just the infrastructure (classifier + data structure + `fallback.py` extension).
**Requirements:** R17, R18, R26, R27
**Dependencies:** none (foundation for U3)
**Files:**
- Modify: `src/agentkit/core/fallback.py` (add `EmergencyRules` class + `EmergencyError` dataclass, preserve existing 3 constants)
- Modify: `src/agentkit/core/protocol.py` (add `error_struct: dict | None = None` field to `TaskResult`, update `to_dict`)
- Create: `tests/unit/test_emergency_rules.py`
**Approach:**
`EmergencyError` dataclass (in `fallback.py`):
```
@dataclass
class EmergencyError:
error_code: str # "timeout"|"loop_detected"|"llm_failure"|"internal_error"
message: str # human-readable Chinese message (mirrors EMPTY_LLM_RESPONSE style)
suggestions: list[str] # actionable user-facing suggestions
retryable: bool # whether user retry might succeed
original_error: str # str(exc) for traceability
def to_dict(self) -> dict: ...
def to_error_message(self) -> str: ... # formatted "message\n建议1) ... 2) ..."
```
`EmergencyRules` class (rule-based classifier, no LLM):
```
class EmergencyRules:
@staticmethod
def classify(exc: Exception, config: dict | None = None) -> EmergencyError:
# Match by exception type (TaskTimeoutError, LoopDetectedError, LLMProviderError, etc.)
# config allows per-rule customization (suggestion overrides, retryable overrides)
```
Rule mapping (initial set, expandable via config):
- `TaskTimeoutError``error_code="timeout"`, `retryable=True`, suggestions: ["稍后重试", "简化任务范围"]
- `LoopDetectedError``error_code="loop_detected"`, `retryable=True`, suggestions: ["拆分任务", "检查工具参数"]
- `LLMProviderError``error_code="llm_failure"`, `retryable=True`, suggestions: ["稍后重试", "切换模型"]
- `TaskCancelledError` → not classified (propagates as-is, Emergency not triggered)
- Generic `Exception``error_code="internal_error"`, `retryable=False`, suggestions: ["联系管理员"]
`TaskResult` extension:
```
@dataclass
class TaskResult:
# ... existing fields ...
error_message: str | None # unchanged
error_struct: dict | None = None # NEW: serialized EmergencyError.to_dict()
```
`to_dict` includes `error_struct` when set; `error_message` continues to hold the human-readable string.
**Patterns to follow:**
- Existing `EMPTY_LLM_RESPONSE` / `MAX_STEPS_REACHED` style in `fallback.py` (Chinese, "建议:..." format)
- `VerificationResult.errors: list[str]` field pattern from `verification_loop.py:18-24`
- `TaskResult.to_dict` pattern at `protocol.py:132-145`
**Test scenarios:**
- Happy path: `classify(TaskTimeoutError(...))` returns `EmergencyError(error_code="timeout", retryable=True)`
- Each exception type maps to correct `error_code`
- `TaskCancelledError` is NOT classified (caller must check before invoking `classify`)
- `to_dict()` produces all 5 fields
- `to_error_message()` formats suggestions as "建议1) ... 2) ..."
- `EmergencyRules.classify(Exception("unknown"))``error_code="internal_error"`, `retryable=False`
- Config override: custom suggestion list for `timeout` rule via config dict
- `TaskResult` with `error_struct` set: `to_dict()` includes both `error_message` and `error_struct`
- `TaskResult` with `error_struct=None` (default): `to_dict()` matches current behavior (backward compat)
**Verification:** All scenarios pass; `fallback.py` existing 3 constants unchanged; `TaskResult.to_dict` for tasks without `error_struct` matches pre-change output byte-for-byte.
---
### U3. G7 — Three-Tier Fallback Chain Wiring
**Goal:** Wire main → Recovery (ReflexionEngine) → Emergency (EmergencyRules) at `chat.py:613`. Composes U2's infrastructure with existing `ReflexionEngine`.
**Requirements:** R16, R18, R26
**Dependencies:** U2 (EmergencyRules + TaskResult.error_struct)
**Files:**
- Modify: `src/agentkit/server/routes/chat.py` (wrap main agent call at L613 with three-tier chain)
- Modify: `src/agentkit/server/config.py` (add `fallback_chain` section to `from_dict`)
- Modify: `agentkit.yaml` (document `fallback_chain:` section)
- Create: `tests/unit/test_fallback_chain.py`
**Approach:**
New helper module or inline function in `chat.py`:
```
async def execute_with_fallback_chain(
agent: ConfigDrivenAgent,
task: Task,
config: dict, # fallback_chain config section
llm_gateway: LLMGateway,
) -> TaskResult:
# Tier 1: Main
try:
result = await agent.execute(task)
if result.status == "success" or result.status == "completed":
return result
# Treat non-success as soft failure → trigger Recovery
raise AgentExecutionError(result.error_message or "main agent did not succeed")
except (TaskTimeoutError, LoopDetectedError, LLMProviderError, AgentExecutionError) as exc:
if not config.get("recovery", {}).get("enabled", True):
return _to_emergency(exc, task)
# Tier 2: Recovery (ReflexionEngine)
try:
reflexion_engine = ReflexionEngine(
llm_gateway=llm_gateway,
max_reflections=config.get("recovery", {}).get("max_retries", 1),
# ... other params from agent's existing reflexion config
)
recovery_result = await reflexion_engine.execute(
messages=task.messages, # rebuild from task
tools=agent.get_tools(),
model=task.model,
agent_name=agent.name,
task_id=task.task_id,
)
if recovery_result.status == "success":
return _recovery_to_task_result(recovery_result, task)
except Exception as recovery_exc:
logger.warning(f"Recovery layer failed: {recovery_exc}")
# Tier 3: Emergency
return _to_emergency(exc, task)
```
`_to_emergency(exc, task)` constructs `EmergencyError` via `EmergencyRules.classify(exc, config)`, then returns `TaskResult(status=FAILED, error_message=emergency.to_error_message(), error_struct=emergency.to_dict(), ...)`.
Wiring at `chat.py:613`: replace direct `agent.execute(task)` call with `execute_with_fallback_chain(...)`. Recovery config from `server_config.fallback_chain` (new `ServerConfig.fallback_chain: dict` field, mirroring Wave 1 pattern).
**Recovery layer scope (KTD5):** Only `chat.py:613` is wrapped. CLI (`cli/chat.py:338`), ReWOO (`rewoo.py:1204`), ReflexionEngine's internal ReAct call, and `config_driven.py:695` are NOT wrapped (would create recursive loop or unwanted coupling). These remain on the direct-execute path. Documented in `## Scope Boundaries`.
**Patterns to follow:**
- `config_driven.py:836` — existing `ReflexionEngine` instantiation pattern (constructor params, execute call)
- Wave 1's `verification` config section in `ServerConfig.from_dict` (`config.py:240-273`)
**Test scenarios:**
- Happy path: main agent succeeds → no Recovery, no Emergency triggered; `error_struct=None`
- Main fails (timeout) → Recovery triggered → Recovery succeeds → `error_struct=None`, output from ReflexionEngine
- Main fails (timeout) → Recovery triggered → Recovery fails (max_reflections exhausted) → Emergency triggered → `error_struct` populated with `error_code="timeout"`, `retryable=True`
- Main fails (LoopDetectedError) → Emergency `error_code="loop_detected"`
- Main fails (LLMProviderError) → Emergency `error_code="llm_failure"`
- Main fails (TaskCancelledError) → propagates as-is (NOT routed to Emergency)
- Main fails (generic Exception) → Emergency `error_code="internal_error"`, `retryable=False`
- Config: `fallback_chain.recovery.enabled=false` → skip Recovery, go directly to Emergency
- Config: `fallback_chain.emergency.enabled=false` → re-raise original exception (no Emergency)
- Integration: full chain on real `ConfigDrivenAgent` with mocked LLM (mock main raises, mock ReflexionEngine succeeds)
**Verification:** All scenarios pass; existing chat WebSocket tests still pass; ruff clean.
---
### U4. G9 — PlanPhase Rollback Fields + RollbackExecutor + TeamOrchestrator Integration
**Goal:** Add `validation_command`/`rollback_command` optional fields to `PlanPhase`; execute rollback on phase failure (opt-in); coordinate with U7 checkpoint ordering per R21.
**Requirements:** R19, R20, R21, R26, R27
**Dependencies:** none (uses existing U7 `PipelineCheckpoint` and `VerificationLoop` patterns)
**Files:**
- Modify: `src/agentkit/experts/plan.py` (`PlanPhase` dataclass + `to_dict`/`from_dict` symmetry)
- Modify: `src/agentkit/experts/orchestrator.py` (insert rollback execution between phase failure and checkpoint save)
- Create: `src/agentkit/orchestrator/rollback.py` (`RollbackExecutor` class)
- Modify: `src/agentkit/server/config.py` (add `rollback` section to `from_dict`)
- Modify: `agentkit.yaml` (document `rollback:` section)
- Create: `tests/unit/test_phase_rollback.py`
**Approach:**
**PlanPhase extension (`plan.py:147-233`):** Add two optional fields with `None` default (preserves existing contract per Finding 1):
```
validation_command: str | None = None
rollback_command: str | None = None
```
Update `to_dict()` to include both fields (only when not None, to keep dict shape minimal). Update `from_dict()` to read both fields.
**`RollbackExecutor` class (new file `orchestrator/rollback.py`):** Mirrors `VerificationLoop` pattern (`verification_loop.py:67-103`):
```
class RollbackExecutor:
def __init__(self, working_dir: str | None = None, timeout: float = 30.0): ...
async def execute(self, command: str) -> RollbackResult:
# asyncio.create_subprocess_shell with cwd/timeout/proc.kill()
# Returns RollbackResult(passed, exit_code, stdout, stderr, command)
async def validate(self, command: str) -> RollbackResult:
# Same as execute but returns passed=False on non-zero exit
```
`RollbackResult` dataclass: `{passed: bool, exit_code: int, stdout: str, stderr: str, command: str}`.
**TeamOrchestrator integration (`orchestrator.py:246-270`):** New ordering per KTD8:
```
# Existing: phase failure detected, mark FAILED + dependents FAILED
# (lines 247-261 unchanged)
# NEW: rollback phase (only if validation_command and rollback_command configured)
should_save_checkpoint = True
if ph.validation_command and ph.rollback_command:
validator = RollbackExecutor(working_dir=self._workspace_root, timeout=...)
validation_result = await validator.validate(ph.validation_command)
if not validation_result.passed:
rollback_result = await validator.execute(ph.rollback_command)
if not rollback_result.passed:
# Rollback failed → don't save checkpoint (R21)
should_save_checkpoint = False
logger.error(f"Rollback failed for phase {ph.id}: {rollback_result.stderr}")
# Emit phase_rollback_failed event
await self._broadcast_event("phase_rollback_failed", {...})
# Existing: checkpoint save (conditional now)
if should_save_checkpoint and self._checkpoint is not None:
try:
await self._checkpoint.save(plan.id, ph, plan.status.value)
except Exception as e:
logger.warning(...)
```
**Events emitted (new):** `phase_rollback_started`, `phase_rollback_completed`, `phase_rollback_failed`. Existing `phase_failed` event unchanged (emitted before rollback).
**Config:** `rollback.default_timeout: float = 30.0` in `ServerConfig.rollback` dict (Wave 1 pattern).
**Security boundary (KTD7):** Rollback subprocess does NOT go through `ShellTool` (avoids `confirm_callback` for `git checkout`). Audit log emitted via `_broadcast_event`. `terminal_whitelist` is not consulted (only governs `/api/v1/terminal/server` route, per `terminal_server.py:227`).
**Execution note:** characterization-first. Test current behavior (`rollback_command=None` → no rollback, checkpoint saved) before adding rollback behavior.
**Patterns to follow:**
- `VerificationLoop` subprocess execution pattern (`verification_loop.py:67-103`) — `asyncio.create_subprocess_shell` + `cwd` + `timeout` + `proc.kill()`
- `PlanPhase.to_dict`/`from_dict` symmetric serialization pattern at `plan.py:186-233`
- `TeamOrchestrator._execute_phase` event broadcasting pattern (`orchestrator.py:253-260`)
- Wave 1's `ServerConfig.from_dict` extension template
**Test scenarios:**
- Characterization: `PlanPhase()` with no `validation_command`/`rollback_command` → `to_dict()` output matches pre-change shape (no new keys); `from_dict({})` produces phase with both fields as None
- Serialization: phase with `rollback_command="git checkout foo.py"``to_dict()` includes key; `from_dict(to_dict())` round-trips
- RollbackExecutor happy path: `execute("git status")` returns `passed=True`, `exit_code=0`
- RollbackExecutor timeout: `execute("sleep 10", timeout=0.1)` returns `passed=False`, exit_code=-1 (or similar)
- RollbackExecutor failure: `execute("false")` returns `passed=False`, `exit_code=1`
- Integration — opt-in default: phase fails, `rollback_command=None` → no RollbackExecutor call, checkpoint saved (existing behavior)
- Integration — rollback configured, validation passes (returns 0): rollback NOT executed, checkpoint saved
- Integration — rollback configured, validation fails, rollback succeeds (returns 0): rollback executed, checkpoint saved
- Integration — rollback configured, validation fails, rollback fails (returns 1): rollback executed, checkpoint NOT saved (R21), `phase_rollback_failed` event emitted
- Integration — real git repo fixture: phase writes file via `git checkout`, rollback `git checkout foo.py` restores file; assert file content matches pre-phase state
**Verification:** All scenarios pass; existing `tests/unit/test_pipeline_state.py` and `tests/unit/test_team_orchestrator*.py` (if any) still pass; ruff clean.
---
## Scope Boundaries
### Deferred to Follow-Up Work
- **Recovery wiring at non-chat call sites** (`cli/chat.py:338`, `rewoo.py:1204`, `config_driven.py:695`). KTD5 limits Wave 2 to the primary chat WebSocket path. Other entry points can adopt the same wrapper in a follow-up.
- **Recovery layer streaming events.** Wave 2's chain returns final `TaskResult`; SSE events for "recovery_started"/"recovery_completed" would require deeper WebSocket protocol changes.
- **Patch-level rollback** (`git apply -R <patch>`). KTD6 + KTD7 scope Wave 2 to `git checkout <files>` pattern. Patch-level requires extending `CheckpointData` schema to track patches — Wave 3 candidate.
- **DB-backed atomic claim for checkpoint.** Finding 3 noted `PipelineCheckpoint` is in-memory dict + Redis fallback, not a PostgreSQL `FOR UPDATE SKIP LOCKED` pattern. Multi-process atomicity is out of scope; single-process `asyncio.Lock` (already present at `orchestrator.py:53`) is sufficient.
### Outside this product's identity (carried from brainstorm)
- Wave 3 (G5/G6) — tree-sitter integration, SOLO four-stage state machine. Strategic direction locked in brainstorm KTD6/KTD7; implementation design deferred to Wave 3 plan.
- Node-level checkpoint (ReAct single-step). Stage-level (U7) satisfies core need.
- DeerFlow-style disk filesystem. Redis is the persistence layer.
- Full LangGraph migration. Self-built architecture stays.
---
## Risks & Dependencies
- **Risk: Auxiliary model availability.** If `auxiliary_model` alias is not configured in `agentkit.yaml`, `LLMGateway` raises `ModelNotFoundError`. Mitigation: `ContextCompressor` catches this in the auxiliary try-block and falls through to main model (R15 fallback path). Documented in YAML comments.
- **Risk: Recovery layer recursive loop.** `ReflexionEngine.execute()` internally calls `ReActEngine.execute()` (5th call site). If Recovery itself fails, it must NOT trigger another Recovery — KTD5 wires only at `chat.py:613`, so ReflexionEngine's internal ReAct call bypasses the chain. Recursive loop is structurally impossible.
- **Risk: `git checkout` destructive scope.** Misconfigured `rollback_command` (e.g., `git checkout .`) could wipe unrelated changes. Mitigation: rollback is opt-in per-phase (KTD6); audit log via `phase_rollback_started`/`phase_rollback_failed` events; documented YAML examples use file-scoped commands (`git checkout <specific_files>`).
- **Risk: `TaskResult.error_struct` field addition breaks pickle/serialization.** `TaskResult` is a dataclass with `to_dict`; new optional field with `None` default is backward-compatible for `to_dict` consumers. Pickle deserialization of pre-change data still works (new field defaults to None).
- **Dependency: Wave 1 merged.** `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections) is established and tested. Wave 2's 3 new sections (auxiliary_model, fallback_chain, rollback) reuse this exact pattern.
- **Dependency: U7 PipelineCheckpoint already shipped.** `orchestrator/checkpoint.py:56` exists with `save(plan_id, phase, plan_status)` API; U4 only adjusts the calling order, not the API.
---
## Sources / Research
- Brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 2 section, R13-R21, KTD4/KTD5)
- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (ServerConfig.from_dict extension template established)
- Finding 1 (contract preservation rule): `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — "新字段默认值须保持既有契约" drives KTD6's opt-in default for `rollback_command`
- Finding 2 (retry-storm defense): `docs/solutions/security-issues/portal-platform-security-reliability-fixes.md` — Emergency layer's terminal `error_code` pattern (don't propagate unhandled exceptions)
- Finding 3 (atomic claim pattern): `docs/solutions/architecture-patterns/bitable-companion-service-security-reliability-patterns.md``SKIP LOCKED` not reusable; `PipelineCheckpoint` is in-memory dict + Redis
- Finding 4 (empty-response anti-pattern): `docs/solutions/ui-bugs/tauri-reload-loses-session.md` — auxiliary LLM must check truthy return value, not just absence of exception; drives U1's empty-content fallback
- Code locations verified during planning:
- `src/agentkit/core/compressor.py:35-44,123-158``_summarize` injection point for G4
- `src/agentkit/llm/gateway.py:170-227` — existing per-model fallback chain (semantics differ from G4's task-type routing)
- `src/agentkit/llm/config.py:198-257``LLMConfig` dataclass and `from_dict`
- `src/agentkit/core/reflexion.py:68-111``ReflexionEngine.__init__` and `execute()` signatures (already supports `evaluate_model`/`reflect_model`)
- `src/agentkit/core/fallback.py` (full file, 19 lines) — 3 existing constants; `EmergencyRules` adds alongside
- `src/agentkit/core/protocol.py:118-145``TaskResult` dataclass and `to_dict`
- `src/agentkit/experts/plan.py:147-233``PlanPhase` dataclass, `to_dict`/`from_dict`
- `src/agentkit/experts/orchestrator.py:246-270` — phase failure capture + checkpoint save site (U4 integration point)
- `src/agentkit/orchestrator/checkpoint.py:56``PipelineCheckpoint` (U7, no API change needed)
- `src/agentkit/core/verification_loop.py:67-103` — subprocess execution pattern for `RollbackExecutor`
- `src/agentkit/tools/shell.py:168-174,526-534``_DANGEROUS_BINARY_FLAGS` (git checkout not listed), `_is_dangerous` logic (KTD7 rationale)

View File

@ -41,6 +41,7 @@ class ContextCompressor:
model_context_limit: int = 128_000, model_context_limit: int = 128_000,
headroom_threshold: float = 0.8, headroom_threshold: float = 0.8,
min_tokens: int = 8_000, min_tokens: int = 8_000,
auxiliary_model: str | None = None,
): ):
self._llm_gateway = llm_gateway self._llm_gateway = llm_gateway
self._max_tokens = max_tokens self._max_tokens = max_tokens
@ -51,6 +52,11 @@ class ContextCompressor:
self._model_context_limit = model_context_limit self._model_context_limit = model_context_limit
self._headroom_threshold = headroom_threshold self._headroom_threshold = headroom_threshold
self._min_tokens = min_tokens self._min_tokens = min_tokens
# G4/U1: Auxiliary model for cost-sensitive summarization (e.g. "fast" alias).
# When set and differs from main model, _summarize tries auxiliary first,
# falls back to main model on failure OR empty content (Finding 4 anti-pattern).
# ponytail: ceiling — auxiliary is best-effort; main model is authoritative fallback.
self._auxiliary_model = auxiliary_model
def should_compress(self, messages: list[dict]) -> bool: def should_compress(self, messages: list[dict]) -> bool:
"""Check if compression should be triggered based on headroom ratio. """Check if compression should be triggered based on headroom ratio.
@ -92,8 +98,8 @@ class ContextCompressor:
if len(non_system) <= self._keep_recent: if len(non_system) <= self._keep_recent:
return messages # Not enough messages to compress return messages # Not enough messages to compress
old_msgs = non_system[:-self._keep_recent] old_msgs = non_system[: -self._keep_recent]
recent_msgs = non_system[-self._keep_recent:] recent_msgs = non_system[-self._keep_recent :]
# Compress old messages # Compress old messages
summary = await self._summarize(old_msgs) summary = await self._summarize(old_msgs)
@ -101,10 +107,12 @@ class ContextCompressor:
# Build compressed message list # Build compressed message list
compressed = list(system_msgs) compressed = list(system_msgs)
if summary: if summary:
compressed.append({ compressed.append(
"role": "system", {
"content": f"## Conversation Summary\n{summary}", "role": "system",
}) "content": f"## Conversation Summary\n{summary}",
}
)
compressed.extend(recent_msgs) compressed.extend(recent_msgs)
# Recursive check: if still over budget, compress again # Recursive check: if still over budget, compress again
@ -114,22 +122,30 @@ class ContextCompressor:
return self._truncate(compressed) return self._truncate(compressed)
if len(recent_msgs) > 1: if len(recent_msgs) > 1:
# Try keeping fewer recent messages # Try keeping fewer recent messages
return await self._compress_aggressive(messages, _compression_depth=_compression_depth + 1) return await self._compress_aggressive(
messages, _compression_depth=_compression_depth + 1
)
# Last resort: truncate # Last resort: truncate
return self._truncate(compressed) return self._truncate(compressed)
return compressed return compressed
async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str: async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str:
"""Summarize a list of messages using LLM""" """Summarize a list of messages using LLM.
G4/U1: When ``auxiliary_model`` is configured and differs from the main
model, try auxiliary first (cost-optimization). On auxiliary failure OR
empty content (Finding 4 anti-pattern "did not throw is not succeeded"),
fall back to main model. Existing ``_simple_summary`` degradation
preserved as the final tier when main model also fails.
"""
if not self._llm_gateway: if not self._llm_gateway:
# No LLM available, do simple truncation # No LLM available, do simple truncation
return self._simple_summary(messages) return self._simple_summary(messages)
# Build summary prompt # Build summary prompt
conversation_text = "\n".join( conversation_text = "\n".join(
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" for m in messages
for m in messages
) )
# Pre-truncate if conversation_text exceeds safe token threshold # Pre-truncate if conversation_text exceeds safe token threshold
@ -145,6 +161,25 @@ class ContextCompressor:
f"{conversation_text}" f"{conversation_text}"
) )
# G4: Try auxiliary model first when configured (cheap route).
if self._auxiliary_model and self._auxiliary_model != self._model:
try:
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
model=self._auxiliary_model,
agent_name="compressor",
task_type="summarization",
)
# Finding 4: empty content is a failure, not a success.
if response.content and response.content.strip():
return response.content
logger.info("Auxiliary model returned empty content, falling back to main model")
except Exception as e:
logger.info(
f"Auxiliary model summarization failed, falling back to main model: {e}"
)
# Main model path (or auxiliary fallback).
try: try:
response = await self._llm_gateway.chat( response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
@ -166,7 +201,9 @@ class ContextCompressor:
parts.append(f"[{role}]: {content}...") parts.append(f"[{role}]: {content}...")
return "\n".join(parts) return "\n".join(parts)
async def _compress_aggressive(self, messages: list[dict], _compression_depth: int = 0) -> list[dict]: async def _compress_aggressive(
self, messages: list[dict], _compression_depth: int = 0
) -> list[dict]:
"""More aggressive compression when standard compression isn't enough""" """More aggressive compression when standard compression isn't enough"""
system_msgs = [m for m in messages if m.get("role") == "system"] system_msgs = [m for m in messages if m.get("role") == "system"]
non_system = [m for m in messages if m.get("role") != "system"] non_system = [m for m in messages if m.get("role") != "system"]
@ -176,10 +213,12 @@ class ContextCompressor:
summary = await self._summarize(non_system[:-1]) summary = await self._summarize(non_system[:-1])
compressed = list(system_msgs) compressed = list(system_msgs)
if summary: if summary:
compressed.append({ compressed.append(
"role": "system", {
"content": f"## Conversation Summary\n{summary}", "role": "system",
}) "content": f"## Conversation Summary\n{summary}",
}
)
compressed.append(non_system[-1]) compressed.append(non_system[-1])
return compressed return compressed
@ -191,7 +230,7 @@ class ContextCompressor:
for msg in messages: for msg in messages:
content = str(msg.get("content", "")) content = str(msg.get("content", ""))
if len(content) > self._max_tokens * 4: if len(content) > self._max_tokens * 4:
msg = {**msg, "content": content[:self._max_tokens * 4] + "...[truncated]"} msg = {**msg, "content": content[: self._max_tokens * 4] + "...[truncated]"}
result.append(msg) result.append(msg)
return result return result
@ -226,6 +265,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
if provider == "headroom": if provider == "headroom":
try: try:
from agentkit.core.headroom_compressor import HeadroomCompressor from agentkit.core.headroom_compressor import HeadroomCompressor
compressor = HeadroomCompressor(config) compressor = HeadroomCompressor(config)
if compressor.is_available(): if compressor.is_available():
return compressor return compressor
@ -235,8 +275,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
) )
except ImportError: except ImportError:
logger.warning( logger.warning(
"HeadroomCompressor module not available. " "HeadroomCompressor module not available. Falling back to ContextCompressor."
"Falling back to ContextCompressor."
) )
# Fallback to summary compressor # Fallback to summary compressor
return ContextCompressor( return ContextCompressor(
@ -253,11 +292,9 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]: def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
"""Render PromptTemplate with caching - returns cached result for same variables""" """Render PromptTemplate with caching - returns cached result for same variables"""
cache_key = hashlib.md5( cache_key = hashlib.md5(json.dumps(variables or {}, sort_keys=True).encode()).hexdigest()
json.dumps(variables or {}, sort_keys=True).encode()
).hexdigest()
if not hasattr(template, '_render_cache'): if not hasattr(template, "_render_cache"):
template._render_cache = {} template._render_cache = {}
if cache_key in template._render_cache: if cache_key in template._render_cache:
@ -270,5 +307,5 @@ def render_cached(template, variables: dict[str, Any] | None = None) -> list[dic
def clear_cache(template) -> None: def clear_cache(template) -> None:
"""Clear the render cache on a PromptTemplate instance""" """Clear the render cache on a PromptTemplate instance"""
if hasattr(template, '_render_cache'): if hasattr(template, "_render_cache"):
template._render_cache.clear() template._render_cache.clear()

View File

@ -2,8 +2,22 @@
All layers (ReActEngine, Portal, Chat) should use these constants All layers (ReActEngine, Portal, Chat) should use these constants
to ensure consistent user-facing messages. to ensure consistent user-facing messages.
G7/U2: Also hosts ``EmergencyError`` and ``EmergencyRules`` for the
three-tier fallback chain's Emergency layer (rule-based classifier,
no LLM). See ``docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md``.
""" """
from dataclasses import dataclass
from typing import Any
from agentkit.core.exceptions import (
LLMProviderError,
LoopDetectedError,
TaskCancelledError,
TaskTimeoutError,
)
# When LLM returns empty content after all fallback models exhausted # When LLM returns empty content after all fallback models exhausted
EMPTY_LLM_RESPONSE = ( EMPTY_LLM_RESPONSE = (
"模型未返回有效内容,已尝试备用模型仍未成功。" "模型未返回有效内容,已尝试备用模型仍未成功。"
@ -16,3 +30,135 @@ MAX_STEPS_REACHED = "已达到最大推理步数,但仍未得到完整结论
# When a shell command succeeds but produces no output # When a shell command succeeds but produces no output
SHELL_NO_OUTPUT = "[命令执行成功,无输出内容]" SHELL_NO_OUTPUT = "[命令执行成功,无输出内容]"
# ── G7/U2: Emergency layer ──────────────────────────────────────
@dataclass
class EmergencyError:
"""Structured error produced by the Emergency layer (rule-classified).
Carries a stable ``error_code`` for programmatic dispatch (frontend
retry UI, telemetry), a human-readable ``message`` mirroring
``EMPTY_LLM_RESPONSE`` style, actionable ``suggestions``, and the
original exception string for traceability.
The ``retryable`` flag distinguishes recoverable user errors
(timeout, loop, LLM hiccup) from internal bugs (retryable=False).
"""
error_code: str # "timeout" | "loop_detected" | "llm_failure" | "internal_error"
message: str # human-readable Chinese message
suggestions: list[str] # actionable user-facing suggestions
retryable: bool # whether a user retry might succeed
original_error: str # str(exc) for traceability
def to_dict(self) -> dict[str, Any]:
return {
"error_code": self.error_code,
"message": self.message,
"suggestions": list(self.suggestions),
"retryable": self.retryable,
"original_error": self.original_error,
}
def to_error_message(self) -> str:
"""Format as a single human-readable string with suggestions.
Mirrors ``EMPTY_LLM_RESPONSE`` style: ``<message>建议1) ... 2) ...``
"""
if not self.suggestions:
return self.message
suggestion_str = "".join(f"{i}) {s}" for i, s in enumerate(self.suggestions, 1))
# Strip trailing "" and prefix with "建议:"
suggestion_str = suggestion_str.rstrip("")
return f"{self.message}建议:{suggestion_str}"
# Default rule set: (exception_type, error_code, message, suggestions, retryable)
# ponytail: ceiling — rule-based, no LLM. Adding a new rule = append a tuple.
# Upgrade path: LLM-driven suggestions would require a separate classifier
# class (out of scope per brainstorm KTD4).
_DEFAULT_RULES: list[tuple[type[Exception], str, str, list[str], bool]] = [
(
TaskTimeoutError,
"timeout",
"任务执行超时。",
["稍后重试", "简化任务范围"],
True,
),
(
LoopDetectedError,
"loop_detected",
"检测到推理循环。",
["拆分任务", "检查工具参数"],
True,
),
(
LLMProviderError,
"llm_failure",
"LLM 服务调用失败。",
["稍后重试", "切换模型"],
True,
),
]
_DEFAULT_ERROR_CODE = "internal_error"
_DEFAULT_MESSAGE = "Agent 执行内部错误。"
_DEFAULT_SUGGESTIONS: list[str] = ["联系管理员"]
_DEFAULT_RETRYABLE = False
class EmergencyRules:
"""Rule-based classifier for the Emergency layer.
Maps exception types to ``EmergencyError`` instances. No LLM, no I/O
pure function of ``(exception, config)``.
Caller responsibility: ``TaskCancelledError`` MUST propagate as-is
(per KTD3); the caller checks ``isinstance(exc, TaskCancelledError)``
before invoking :meth:`classify`. Calling :meth:`classify` with a
``TaskCancelledError`` raises ``ValueError`` to surface the bug.
Config override shape (optional ``config`` arg):
```python
{
"timeout": {"suggestions": ["自定义建议"], "retryable": False},
"llm_failure": {"message": "自定义消息"},
}
```
"""
@staticmethod
def classify(exc: Exception, config: dict | None = None) -> EmergencyError:
if isinstance(exc, TaskCancelledError):
# Contract: caller must check before invoking classify.
raise ValueError(
"TaskCancelledError must propagate as-is; caller must check "
"before invoking EmergencyRules.classify"
)
config = config or {}
for exc_type, code, message, suggestions, retryable in _DEFAULT_RULES:
if isinstance(exc, exc_type):
override = config.get(code, {}) if isinstance(config, dict) else {}
return EmergencyError(
error_code=code,
message=override.get("message", message),
suggestions=override.get("suggestions", list(suggestions)),
retryable=override.get("retryable", retryable),
original_error=str(exc),
)
# Generic fallback
override = config.get(_DEFAULT_ERROR_CODE, {}) if isinstance(config, dict) else {}
return EmergencyError(
error_code=_DEFAULT_ERROR_CODE,
message=override.get("message", _DEFAULT_MESSAGE),
suggestions=override.get("suggestions", list(_DEFAULT_SUGGESTIONS)),
retryable=override.get("retryable", _DEFAULT_RETRYABLE),
original_error=str(exc),
)

View File

@ -128,6 +128,10 @@ class TaskResult:
completed_at: datetime completed_at: datetime
metrics: dict | None = None metrics: dict | None = None
trace: Any | None = None trace: Any | None = None
# G7/U2: Emergency layer structured error. None preserves existing contract
# (error_message alone carries the human-readable string). When set,
# carries serialized EmergencyError.to_dict() for programmatic dispatch.
error_struct: dict | None = None
def to_dict(self) -> dict: def to_dict(self) -> dict:
d = { d = {
@ -142,6 +146,8 @@ class TaskResult:
} }
if self.trace is not None: if self.trace is not None:
d["trace"] = self.trace.to_dict() if hasattr(self.trace, "to_dict") else self.trace d["trace"] = self.trace.to_dict() if hasattr(self.trace, "to_dict") else self.trace
if self.error_struct is not None:
d["error_struct"] = self.error_struct
return d return d
@classmethod @classmethod
@ -162,6 +168,7 @@ class TaskResult:
completed_at=completed_at or datetime.now(timezone.utc), completed_at=completed_at or datetime.now(timezone.utc),
metrics=data.get("metrics"), metrics=data.get("metrics"),
trace=data.get("trace"), trace=data.get("trace"),
error_struct=data.get("error_struct"),
) )

View File

@ -30,6 +30,7 @@ from typing import Any
from agentkit.core.config_driven import ConfigDrivenAgent from agentkit.core.config_driven import ConfigDrivenAgent
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.llm.gateway import LLMGateway from agentkit.llm.gateway import LLMGateway
from agentkit.orchestrator.rollback import RollbackExecutor
from .expert import Expert from .expert import Expert
from .plan import ( from .plan import (
@ -72,12 +73,17 @@ class TeamOrchestrator:
MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰 DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"}) STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
# G9/U4: RollbackExecutor default timeout for validation_command / rollback_command.
# Override via constructor `rollback_timeout` from `rollback.default_timeout` config.
DEFAULT_ROLLBACK_TIMEOUT = 30.0
def __init__( def __init__(
self, self,
team: ExpertTeam, team: ExpertTeam,
max_concurrent_phases: int | None = None, max_concurrent_phases: int | None = None,
checkpoint: Any = None, checkpoint: Any = None,
workspace_root: str | None = None,
rollback_timeout: float | None = None,
) -> None: ) -> None:
self._team = team self._team = team
# Track temporary agent names created for context isolation (KTD3) # Track temporary agent names created for context isolation (KTD3)
@ -93,6 +99,10 @@ class TeamOrchestrator:
self._phase_semaphore = asyncio.Semaphore(limit) self._phase_semaphore = asyncio.Semaphore(limit)
# U7: Pipeline checkpoint for crash recovery # U7: Pipeline checkpoint for crash recovery
self._checkpoint = checkpoint self._checkpoint = checkpoint
# G9/U4: workspace_root drives RollbackExecutor cwd; rollback_timeout drives its timeout.
# Both default to no-op-friendly values so existing call sites behave identically.
self._workspace_root = workspace_root
self._rollback_timeout = rollback_timeout or self.DEFAULT_ROLLBACK_TIMEOUT
async def execute(self, task: str) -> dict[str, Any]: async def execute(self, task: str) -> dict[str, Any]:
"""Execute a task in pipeline mode. """Execute a task in pipeline mode.
@ -262,8 +272,23 @@ class TeamOrchestrator:
else: else:
phase_results[ph.id] = result phase_results[ph.id] = result
# G9/U4: opt-in rollback (KTD6) + checkpoint ordering (R21).
# When phase configures both validation_command and rollback_command:
# 1. run validation_command — if it passes, treat phase as recoverable, save checkpoint
# 2. if validation fails, run rollback_command
# 3. if rollback passes (exit 0), save checkpoint
# 4. if rollback fails, skip checkpoint (R21 — avoid persisting broken state)
# When neither command is set, behavior is unchanged (existing save).
should_save_checkpoint = True
if (
ph.validation_command
and ph.rollback_command
and isinstance(result, (Exception, asyncio.CancelledError))
):
should_save_checkpoint = await self._run_phase_rollback(plan, ph)
# U7: Save checkpoint after phase finalizes (success or failure) # U7: Save checkpoint after phase finalizes (success or failure)
if self._checkpoint is not None: if should_save_checkpoint and self._checkpoint is not None:
try: try:
await self._checkpoint.save(plan.id, ph, plan.status.value) await self._checkpoint.save(plan.id, ph, plan.status.value)
except Exception as e: except Exception as e:
@ -393,9 +418,7 @@ class TeamOrchestrator:
# PENDING phases remain PENDING — will be executed by _run_pipeline # PENDING phases remain PENDING — will be executed by _run_pipeline
# P2 #8: Restore debate count so MAX_DEBATES limit holds after resume # P2 #8: Restore debate count so MAX_DEBATES limit holds after resume
self._debate_count = sum( self._debate_count = sum(1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE)
1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE
)
logger.info( logger.info(
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, " f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
@ -688,9 +711,9 @@ class TeamOrchestrator:
and prev_phase.result and prev_phase.result
): ):
# U4: Resolve offloaded content from workspace # U4: Resolve offloaded content from workspace
collaboration_outputs[contract.from_expert] = ( collaboration_outputs[
await self._read_dependency_output(prev_phase) contract.from_expert
) ] = await self._read_dependency_output(prev_phase)
break break
# Emit expert_step event # Emit expert_step event
@ -1809,6 +1832,75 @@ class TeamOrchestrator:
# Recursively mark their dependents # Recursively mark their dependents
await self._mark_dependents_failed(ph.id, plan, phase_results) await self._mark_dependents_failed(ph.id, plan, phase_results)
async def _run_phase_rollback(self, plan: TeamPlan, ph: PlanPhase) -> bool:
"""G9/U4: run validation_command + rollback_command for a failed phase.
Returns True if checkpoint save should proceed (R21 ordering).
- Validation passes save checkpoint (phase state recoverable)
- Validation fails, rollback passes save checkpoint (rolled back state)
- Validation fails, rollback fails skip checkpoint (broken state)
- Subprocess spawn failure or timeout skip checkpoint
"""
executor = RollbackExecutor(
working_dir=self._workspace_root,
timeout=self._rollback_timeout,
)
await self._broadcast_event(
"phase_rollback_started",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"validation_command": ph.validation_command,
"rollback_command": ph.rollback_command,
},
)
# ponytail: validate first; if validation passes, rollback is skipped (no need).
validation = await executor.validate(ph.validation_command or "")
if validation.passed:
await self._broadcast_event(
"phase_rollback_completed",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"rollback_executed": False,
"validation_passed": True,
},
)
return True
rollback = await executor.execute(ph.rollback_command or "")
if rollback.passed:
await self._broadcast_event(
"phase_rollback_completed",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"rollback_executed": True,
"validation_passed": False,
"rollback_stdout": rollback.stdout,
},
)
return True
logger.error(
f"Rollback failed for phase {ph.id} ({ph.name}): exit={rollback.exit_code} stderr={rollback.stderr}"
)
await self._broadcast_event(
"phase_rollback_failed",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"validation_passed": False,
"rollback_exit_code": rollback.exit_code,
"rollback_stderr": rollback.stderr,
},
)
return False
async def _synthesize_results( async def _synthesize_results(
self, lead: Expert, task: str, completed_phases: list[PlanPhase] self, lead: Expert, task: str, completed_phases: list[PlanPhase]
) -> dict[str, Any]: ) -> dict[str, Any]:

View File

@ -182,6 +182,11 @@ class PlanPhase:
collaboration_contracts: list[CollaborationContract] = field(default_factory=list) collaboration_contracts: list[CollaborationContract] = field(default_factory=list)
rework_count: int = 0 rework_count: int = 0
review_feedback: str | None = None review_feedback: str | None = None
# G9/U4: opt-in rollback fields. When unset, no rollback executes (KTD6).
# validation_command runs first; if it fails, rollback_command runs.
# canonical rollback pattern: `git checkout <specific_files>`.
validation_command: str | None = None
rollback_command: str | None = None
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""序列化为字典""" """序列化为字典"""
@ -192,7 +197,7 @@ class PlanPhase:
result_str = self.result.get("content", str(self.result)) result_str = self.result.get("content", str(self.result))
else: else:
result_str = str(self.result) result_str = str(self.result)
return { out: dict[str, Any] = {
"id": self.id, "id": self.id,
"name": self.name, "name": self.name,
"assigned_expert": self.assigned_expert, "assigned_expert": self.assigned_expert,
@ -206,6 +211,12 @@ class PlanPhase:
"rework_count": self.rework_count, "rework_count": self.rework_count,
"review_feedback": self.review_feedback, "review_feedback": self.review_feedback,
} }
# G9/U4: only include new keys when set, to preserve pre-change dict shape (KTD6).
if self.validation_command is not None:
out["validation_command"] = self.validation_command
if self.rollback_command is not None:
out["rollback_command"] = self.rollback_command
return out
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> PlanPhase: def from_dict(cls, data: dict[str, Any]) -> PlanPhase:
@ -230,6 +241,8 @@ class PlanPhase:
collaboration_contracts=contracts, collaboration_contracts=contracts,
rework_count=data.get("rework_count", 0), rework_count=data.get("rework_count", 0),
review_feedback=data.get("review_feedback"), review_feedback=data.get("review_feedback"),
validation_command=data.get("validation_command"),
rollback_command=data.get("rollback_command"),
) )

View File

@ -203,6 +203,9 @@ class LLMConfig:
model_aliases: dict[str, str] = field(default_factory=dict) model_aliases: dict[str, str] = field(default_factory=dict)
fallbacks: dict[str, list[str]] = field(default_factory=dict) fallbacks: dict[str, list[str]] = field(default_factory=dict)
cache: CacheConfig | None = None cache: CacheConfig | None = None
# G4/U1: Auxiliary model alias for cost-sensitive tasks (e.g. summarization).
# Resolved via existing model_aliases mechanism. None = use main model only.
auxiliary_model: str | None = None
@classmethod @classmethod
def from_dict(cls, data: dict) -> "LLMConfig": def from_dict(cls, data: dict) -> "LLMConfig":
@ -254,4 +257,5 @@ class LLMConfig:
model_aliases=data.get("model_aliases", {}), model_aliases=data.get("model_aliases", {}),
fallbacks=data.get("fallbacks", {}), fallbacks=data.get("fallbacks", {}),
cache=cache, cache=cache,
auxiliary_model=data.get("auxiliary_model"),
) )

View File

@ -0,0 +1,97 @@
"""RollbackExecutor — 阶段失败后执行回滚命令 (G9/U4)
复用 VerificationLoop asyncio.create_subprocess_shell 模式 (KTD7)
绕过 ShellTool避免 confirm_callback `git checkout` 的拦截
设计依据
- KTD6: 回滚是 opt-in 行为未配置 rollback_command 时不会执行
- KTD7: 不走 ShellTool避免 _is_dangerous 触发 confirm_callback
- R21: checkpoint.save 仅在回滚校验通过后调用
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class RollbackResult:
"""单次 subprocess 执行结果"""
passed: bool
exit_code: int
stdout: str
stderr: str
command: str
class RollbackExecutor:
"""执行 validation_command / rollback_command 的子进程封装
VerificationLoop 同构但语义不同
- validate(): 返回 passed=False 表示校验失败需要触发 rollback
- execute(): 返回 passed=False 表示回滚本身失败需跳过 checkpoint.save (R21)
"""
def __init__(
self,
working_dir: str | None = None,
timeout: float = 30.0,
) -> None:
self._working_dir = working_dir
self._timeout = timeout
async def _run(self, command: str) -> RollbackResult:
try:
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=self._working_dir,
)
except Exception as e: # noqa: BLE001 - subprocess spawn failure surface
return RollbackResult(
passed=False,
exit_code=-1,
stdout="",
stderr=f"Failed to spawn command: {e}",
command=command,
)
try:
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self._timeout)
except asyncio.TimeoutError:
try:
proc.kill()
except ProcessLookupError:
pass
await proc.wait()
return RollbackResult(
passed=False,
exit_code=-1,
stdout="",
stderr=f"Command timed out after {self._timeout}s: {command}",
command=command,
)
out_str = stdout.decode("utf-8", errors="replace") if stdout else ""
err_str = stderr.decode("utf-8", errors="replace") if stderr else ""
return RollbackResult(
passed=proc.returncode == 0,
exit_code=proc.returncode if proc.returncode is not None else -1,
stdout=out_str,
stderr=err_str,
command=command,
)
async def validate(self, command: str) -> RollbackResult:
"""运行 validation_commandpassed=False 表示需要触发 rollback"""
return await self._run(command)
async def execute(self, command: str) -> RollbackResult:
"""运行 rollback_commandpassed=False 表示回滚本身失败 (R21)"""
return await self._run(command)

View File

@ -0,0 +1,199 @@
"""G7/U3 — Three-tier fallback chain (main → Recovery → Emergency).
Wired at chat.py REST send_message endpoint. Composes U2's EmergencyRules
with existing ReflexionEngine for the Recovery layer.
Scope (KTD5): Only the chat REST path is wrapped. CLI / ReWOO / Reflexion
internal ReAct calls are NOT wrapped (would create recursive loop).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
from agentkit.core.exceptions import (
LLMProviderError,
LoopDetectedError,
TaskCancelledError,
TaskTimeoutError,
)
from agentkit.core.fallback import EmergencyError, EmergencyRules
from agentkit.core.react import ReActEngine, ReActResult
from agentkit.core.reflexion import ReflexionEngine
from agentkit.llm.gateway import LLMGateway
logger = logging.getLogger(__name__)
# ReActResult.status values that indicate soft failure → trigger Recovery.
# "success" is the only clean-pass; everything else is fallback-worthy.
_SOFT_FAILURE_STATUSES = frozenset({"empty_fallback", "verify_failed", "timeout"})
@dataclass
class ChatExecutionResult:
"""Wrapper produced by execute_with_fallback_chain.
Carries a ReActResult-like ``output`` field plus an optional
``error_struct`` (set only when Emergency tier fires). The chat
handler reads ``.output`` for the assistant reply and ``.error_struct``
for the optional structured error payload.
"""
output: str
status: str # "success" | "recovered" | "emergency"
error_struct: dict[str, Any] | None = None
trajectory: list[Any] = field(default_factory=list)
total_steps: int = 0
total_tokens: int = 0
fallback_strategy: str | None = None
def _react_to_chat_result(react: ReActResult) -> ChatExecutionResult:
return ChatExecutionResult(
output=react.output,
status="success",
trajectory=react.trajectory,
total_steps=react.total_steps,
total_tokens=react.total_tokens,
fallback_strategy=react.fallback_strategy,
)
def _reflexion_to_chat_result(reflexion_result: Any) -> ChatExecutionResult:
"""Best-effort conversion from ReflexionResult to ChatExecutionResult."""
output = getattr(reflexion_result, "output", None) or getattr(
reflexion_result, "final_answer", ""
)
return ChatExecutionResult(
output=output or "",
status="recovered",
trajectory=getattr(reflexion_result, "trajectory", []) or [],
total_steps=getattr(reflexion_result, "total_steps", 0),
total_tokens=getattr(reflexion_result, "total_tokens", 0),
fallback_strategy="reflexion_recovery",
)
def _to_emergency(exc: Exception, config: dict | None) -> ChatExecutionResult:
emergency: EmergencyError = EmergencyRules.classify(exc, config)
return ChatExecutionResult(
output=emergency.to_error_message(),
status="emergency",
error_struct=emergency.to_dict(),
fallback_strategy="emergency",
)
async def execute_with_fallback_chain(
*,
react_engine: ReActEngine,
llm_gateway: LLMGateway,
messages: list[dict[str, str]],
tools: list[Any] | None,
model: str,
agent_name: str,
system_prompt: str | None,
fallback_chain_config: dict | None = None,
) -> ChatExecutionResult:
"""Three-tier fallback chain: Main → Recovery (ReflexionEngine) → Emergency.
KTD5: only this entry point wraps the chain. ReflexionEngine's internal
ReAct call bypasses the chain (no recursive loop possible).
Returns ChatExecutionResult with status:
- "success": main agent succeeded
- "recovered": main failed, ReflexionEngine recovery succeeded
- "emergency": main failed, recovery failed/exhausted, Emergency layer fired
"""
config = fallback_chain_config or {}
recovery_cfg = config.get("recovery", {}) if isinstance(config, dict) else {}
emergency_cfg = config.get("emergency", {}) if isinstance(config, dict) else {}
recovery_enabled = recovery_cfg.get("enabled", True) if isinstance(recovery_cfg, dict) else True
emergency_enabled = (
emergency_cfg.get("enabled", True) if isinstance(emergency_cfg, dict) else True
)
max_reflections = recovery_cfg.get("max_retries", 1) if isinstance(recovery_cfg, dict) else 1
# ── Tier 1: Main ──────────────────────────────────────────────
main_exc: Exception | None = None
try:
result = await react_engine.execute(
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
system_prompt=system_prompt,
)
if result.status == "success":
return _react_to_chat_result(result)
# Soft failure (empty_fallback / verify_failed / timeout) → trigger Recovery
if result.status in _SOFT_FAILURE_STATUSES:
main_exc = AgentSoftFailureError(
f"main agent status={result.status}: {result.output[:200]}"
)
else:
# Unknown status — treat as success-like (don't trigger recovery)
return _react_to_chat_result(result)
except TaskCancelledError:
# KTD3: TaskCancelledError propagates as-is, NOT routed to Emergency.
raise
except (TaskTimeoutError, LoopDetectedError, LLMProviderError) as exc:
main_exc = exc
except Exception as exc: # noqa: BLE001 - last-resort catch for Emergency routing
main_exc = exc
# ── Tier 2: Recovery (ReflexionEngine) ────────────────────────
if recovery_enabled and main_exc is not None:
try:
reflexion = ReflexionEngine(
llm_gateway=llm_gateway,
max_reflections=max_reflections,
)
recovery_result = await reflexion.execute(
messages=messages,
tools=tools,
model=model,
agent_name=agent_name,
system_prompt=system_prompt,
)
# Recovery succeeds if Reflexion reports success or produces output.
recovery_status = getattr(recovery_result, "status", "")
if recovery_status == "success" or getattr(recovery_result, "output", None):
return _reflexion_to_chat_result(recovery_result)
logger.warning(
f"Recovery layer did not succeed (status={recovery_status}), "
f"falling through to Emergency"
)
except TaskCancelledError:
raise
except Exception as recovery_exc: # noqa: BLE001
logger.warning(f"Recovery layer raised: {recovery_exc}; falling through to Emergency")
# ── Tier 3: Emergency ─────────────────────────────────────────
if not emergency_enabled:
# Re-raise original exception if Emergency disabled.
if main_exc is not None:
raise main_exc
# No exception but no success either — synthesise an emergency-style result.
return ChatExecutionResult(
output="Agent 未返回有效结果且 Emergency 层已禁用。",
status="emergency",
fallback_strategy="emergency_disabled",
)
# main_exc may be None if main returned soft-failure status without raising.
# Synthesize a generic exception for Emergency classification.
exc_for_emergency = main_exc or AgentSoftFailureError("soft failure without exception")
return _to_emergency(exc_for_emergency, config)
class AgentSoftFailureError(Exception):
"""Internal marker — main agent returned a soft-failure status without raising.
Used to feed the Emergency classifier when main status was e.g.
``empty_fallback`` (no exception raised, but result not usable).
Classified as ``internal_error`` by EmergencyRules (generic fallback).
"""

View File

@ -119,6 +119,8 @@ class ServerConfig:
prompt_cache: dict[str, Any] | None = None, prompt_cache: dict[str, Any] | None = None,
streaming: dict[str, Any] | None = None, streaming: dict[str, Any] | None = None,
verification: dict[str, Any] | None = None, verification: dict[str, Any] | None = None,
rollback: dict[str, Any] | None = None,
fallback_chain: dict[str, Any] | None = None,
on_change: Callable[["ServerConfig"], None] | None = None, on_change: Callable[["ServerConfig"], None] | None = None,
): ):
self.host = host self.host = host
@ -153,6 +155,12 @@ class ServerConfig:
# U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1) # U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1)
# verification_enabled=False 时此配置无效 # verification_enabled=False 时此配置无效
self.verification = verification or {} self.verification = verification or {}
# G9/U4: rollback.default_timeout 控制 RollbackExecutor subprocess 超时
# PlanPhase.rollback_command 未设置时此配置无效 (KTD6 opt-in)
self.rollback = rollback or {}
# G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries}
# controls three-tier chain at chat.py REST send_message (KTD5).
self.fallback_chain = fallback_chain or {}
self.on_change = on_change self.on_change = on_change
# Config watching state # Config watching state
@ -240,6 +248,10 @@ class ServerConfig:
prompt_cache_data = data.get("prompt_cache", {}) prompt_cache_data = data.get("prompt_cache", {})
streaming_data = data.get("streaming", {}) streaming_data = data.get("streaming", {})
verification_data = data.get("verification", {}) verification_data = data.get("verification", {})
# G9/U4: rollback 配置 (从 YAML 读取opt-in)
rollback_data = data.get("rollback", {})
# G7/U3: fallback_chain 配置 (从 YAML 读取)
fallback_chain_data = data.get("fallback_chain", {})
return cls( return cls(
host=server.get("host", "0.0.0.0"), host=server.get("host", "0.0.0.0"),
@ -271,6 +283,8 @@ class ServerConfig:
prompt_cache=prompt_cache_data, prompt_cache=prompt_cache_data,
streaming=streaming_data, streaming=streaming_data,
verification=verification_data, verification=verification_data,
rollback=rollback_data,
fallback_chain=fallback_chain_data,
) )
@staticmethod @staticmethod
@ -320,6 +334,9 @@ class ServerConfig:
model_aliases=model_aliases, model_aliases=model_aliases,
fallbacks=data.get("fallbacks", {}), fallbacks=data.get("fallbacks", {}),
cache=cache_config, cache=cache_config,
# G4/U1: auxiliary model alias for cost-sensitive summarization.
# Resolved via model_aliases; None = no auxiliary routing.
auxiliary_model=data.get("auxiliary_model"),
) )
@staticmethod @staticmethod

View File

@ -27,6 +27,7 @@ from pydantic import BaseModel
from agentkit.chat.skill_routing import ExecutionMode from agentkit.chat.skill_routing import ExecutionMode
from agentkit.core.protocol import CancellationToken from agentkit.core.protocol import CancellationToken
from agentkit.core.react import ReActEngine from agentkit.core.react import ReActEngine
from agentkit.server._fallback_chain import execute_with_fallback_chain
from agentkit.session.manager import SessionManager from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus from agentkit.session.models import MessageRole, SessionStatus
@ -610,7 +611,15 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
system_prompt = getattr(agent, "_system_prompt", None) or ( system_prompt = getattr(agent, "_system_prompt", None) or (
agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
) )
result = await react_engine.execute( # G7/U3: Three-tier fallback chain (main → Recovery → Emergency).
# Wired only here (KTD5); CLI / ReWOO / Reflexion internal ReAct bypass.
server_config = getattr(req.app.state, "server_config", None)
fallback_chain_cfg = (
getattr(server_config, "fallback_chain", None) if server_config else None
)
chat_result = await execute_with_fallback_chain(
react_engine=react_engine,
llm_gateway=req.app.state.llm_gateway,
messages=chat_messages, messages=chat_messages,
tools=tools, tools=tools,
model=agent.get_model() model=agent.get_model()
@ -618,16 +627,26 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
else getattr(agent, "_llm_model", "default"), else getattr(agent, "_llm_model", "default"),
agent_name=agent.name, agent_name=agent.name,
system_prompt=system_prompt, system_prompt=system_prompt,
fallback_chain_config=fallback_chain_cfg,
) )
# Append assistant reply # Append assistant reply
assistant_msg = await sm.append_message( assistant_msg = await sm.append_message(
session_id=session_id, session_id=session_id,
role=MessageRole.ASSISTANT, role=MessageRole.ASSISTANT,
content=result.output if hasattr(result, "output") else str(result), content=chat_result.output,
agent_name=agent.name, agent_name=agent.name,
) )
return _message_to_response(assistant_msg) response = _message_to_response(assistant_msg)
# Attach structured error payload when Emergency tier fired.
if chat_result.error_struct is not None:
response_dict = (
response.model_dump() if hasattr(response, "model_dump") else dict(response)
)
response_dict["error_struct"] = chat_result.error_struct
response_dict["fallback_status"] = chat_result.status
return response_dict
return response
except Exception as e: except Exception as e:
logger.error(f"Chat execution error for session {session_id}: {e}") logger.error(f"Chat execution error for session {session_id}: {e}")

View File

@ -0,0 +1,400 @@
"""G4/U1 — Auxiliary LLM routing in ContextCompressor.
Verifies:
- auxiliary_model routes _summarize through the cheaper model first
- empty content (Finding 4 anti-pattern) triggers fallback to main model
- auxiliary exception triggers fallback to main model
- both auxiliary and main failing falls through to _simple_summary
- auxiliary_model=None preserves existing single-model behavior (characterization)
- config wiring (LLMConfig.from_dict, ServerConfig._build_llm_config)
"""
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.compressor import ContextCompressor
from agentkit.llm.config import LLMConfig
from agentkit.llm.protocol import LLMResponse, TokenUsage
# ── Helpers ──────────────────────────────────────────
def make_gateway_with_response(content: str, model: str = "test") -> MagicMock:
"""Mock LLMGateway returning a fixed response."""
from agentkit.llm.gateway import LLMGateway
gateway = MagicMock(spec=LLMGateway)
gateway.chat = AsyncMock(
return_value=LLMResponse(
content=content,
model=model,
usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
)
)
return gateway
def make_gateway_side_effect(responses_by_model: dict[str, LLMResponse | Exception]) -> MagicMock:
"""Mock LLMGateway returning different responses (or raising) keyed by model name.
Each call to gateway.chat(model=X) pops the next response for X from a queue,
so repeated calls to the same model can return different values.
"""
from agentkit.llm.gateway import LLMGateway
gateway = MagicMock(spec=LLMGateway)
queues = {m: list(rs) for m, rs in responses_by_model.items()}
async def chat_side_effect(*, messages, model, **kwargs):
queue = queues.get(model)
if queue is None:
raise ValueError(f"unexpected model={model}")
if not queue:
raise ValueError(f"queue for model={model} exhausted")
item = queue.pop(0)
if isinstance(item, Exception):
raise item
return item
gateway.chat = AsyncMock(side_effect=chat_side_effect)
return gateway
def make_long_messages(count: int = 4, content_length: int = 2000) -> list[dict]:
"""Generate long messages that exceed token budget (triggers compression)."""
messages = [{"role": "system", "content": "You are a helpful assistant."}]
for i in range(count):
messages.append({"role": "user", "content": "x" * content_length + f" m{i}"})
messages.append({"role": "assistant", "content": "y" * content_length + f" r{i}"})
messages.append({"role": "user", "content": "recent question"})
messages.append({"role": "assistant", "content": "recent answer"})
return messages
# ── Characterization: auxiliary_model=None preserves existing behavior ──
class TestAuxiliaryNoneCharacterization:
"""auxiliary_model=None (default) — single model call, existing behavior."""
async def test_no_auxiliary_calls_main_once(self):
gateway = make_gateway_with_response("main summary")
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
# auxiliary_model omitted → None
)
result = await compressor.compress(make_long_messages())
gateway.chat.assert_awaited_once()
# The call used the main model
assert gateway.chat.await_args.kwargs.get("model") == "main"
# Summary surfaced in result
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert any("main summary" in m["content"] for m in summary_msgs)
async def test_main_failure_falls_to_simple_summary(self):
gateway = MagicMock()
gateway.chat = AsyncMock(side_effect=Exception("main LLM error"))
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
)
result = await compressor.compress(make_long_messages())
# _simple_summary produces truncated messages with "..."
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert len(summary_msgs) == 1
assert "..." in summary_msgs[0]["content"]
# ── New behavior: auxiliary routing ──────────────────
class TestAuxiliaryRouting:
"""auxiliary_model set and differs from main → auxiliary tried first."""
async def test_auxiliary_success_returns_auxiliary_content(self):
gateway = make_gateway_side_effect(
{
"fast": [
LLMResponse(
content="aux summary",
model="fast",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
"main": [
LLMResponse(
content="MAIN SHOULD NOT BE USED",
model="main",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
result = await compressor.compress(make_long_messages())
# Auxiliary called; main NOT called
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) == 1
assert len(main_calls) == 0
# Result contains auxiliary summary
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert any("aux summary" in m["content"] for m in summary_msgs)
async def test_empty_content_triggers_main_fallback(self):
"""Finding 4 anti-pattern: empty content is a failure, not a success."""
gateway = make_gateway_side_effect(
{
"fast": [
LLMResponse(
content="",
model="fast",
usage=TokenUsage(prompt_tokens=1, completion_tokens=0),
)
],
"main": [
LLMResponse(
content="main summary",
model="main",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
result = await compressor.compress(make_long_messages())
# Auxiliary called once (returned empty)
# Main called once (fallback)
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) == 1
assert len(main_calls) == 1
# Result contains main summary (not the empty auxiliary)
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert any("main summary" in m["content"] for m in summary_msgs)
async def test_whitespace_content_triggers_main_fallback(self):
"""Whitespace-only content also counts as empty (Finding 4)."""
gateway = make_gateway_side_effect(
{
"fast": [
LLMResponse(
content=" \n ",
model="fast",
usage=TokenUsage(prompt_tokens=1, completion_tokens=0),
)
],
"main": [
LLMResponse(
content="main summary",
model="main",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
await compressor.compress(make_long_messages())
# Both auxiliary and main called
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) == 1
assert len(main_calls) == 1
async def test_auxiliary_exception_triggers_main_fallback(self):
from agentkit.core.exceptions import LLMProviderError
gateway = make_gateway_side_effect(
{
"fast": [LLMProviderError("aux", "provider down")],
"main": [
LLMResponse(
content="main summary",
model="main",
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
)
],
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
result = await compressor.compress(make_long_messages())
# Both called; main succeeded
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) == 1
assert len(main_calls) == 1
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert any("main summary" in m["content"] for m in summary_msgs)
async def test_both_fail_falls_to_simple_summary(self):
"""Auxiliary raises, main raises → existing _simple_summary degradation."""
# Note: aggressive compression path may invoke _summarize multiple times.
# Queue provides enough responses to handle that without raising queue-exhausted.
gateway = make_gateway_side_effect(
{
"fast": [Exception("aux boom")] * 5,
"main": [Exception("main boom")] * 5,
}
)
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
result = await compressor.compress(make_long_messages())
# Both called at least once
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
assert len(aux_calls) >= 1
assert len(main_calls) >= 1
# _simple_summary output has "..." truncation markers
summary_msgs = [
m
for m in result
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
]
assert len(summary_msgs) == 1
assert "..." in summary_msgs[0]["content"]
async def test_auxiliary_equal_to_main_skipped(self):
"""auxiliary_model == model → no auxiliary routing (single call to main)."""
gateway = make_gateway_with_response("main summary")
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="main", # same as main
)
await compressor.compress(make_long_messages())
# Only one call (to main); auxiliary block skipped
assert gateway.chat.await_count == 1
assert gateway.chat.await_args.kwargs.get("model") == "main"
async def test_audit_fields_preserved(self):
"""Auxiliary call uses agent_name='compressor', task_type='summarization'."""
gateway = make_gateway_with_response("aux summary")
compressor = ContextCompressor(
llm_gateway=gateway,
max_tokens=100,
keep_recent=2,
model="main",
auxiliary_model="fast",
)
# Override the mock to use a single-response gateway where auxiliary succeeds
# (the make_gateway_with_response mock returns same response regardless of model)
await compressor.compress(make_long_messages())
# Single call (auxiliary succeeded) — verify audit fields
call_kwargs = gateway.chat.await_args.kwargs
assert call_kwargs.get("agent_name") == "compressor"
assert call_kwargs.get("task_type") == "summarization"
# ── Config wiring ────────────────────────────────────
class TestConfigWiring:
"""LLMConfig + ServerConfig read auxiliary_model from dict."""
def test_llm_config_from_dict_reads_auxiliary_model(self):
cfg = LLMConfig.from_dict(
{
"providers": {},
"model_aliases": {"fast": "p/m"},
"auxiliary_model": "fast",
}
)
assert cfg.auxiliary_model == "fast"
def test_llm_config_from_dict_auxiliary_none_when_absent(self):
cfg = LLMConfig.from_dict({"providers": {}})
assert cfg.auxiliary_model is None
def test_llm_config_default_auxiliary_none(self):
cfg = LLMConfig()
assert cfg.auxiliary_model is None
def test_server_config_build_llm_config_reads_auxiliary_model(self):
from agentkit.server.config import ServerConfig
llm_data = {
"providers": {
"p": {
"type": "openai",
"api_key": "k",
"base_url": "http://x",
"models": {"m": {"alias": "fast"}},
}
},
"auxiliary_model": "fast",
}
llm_config = ServerConfig._build_llm_config(llm_data)
assert llm_config.auxiliary_model == "fast"
# Also verify model_aliases still built correctly
assert llm_config.model_aliases.get("fast") == "p/m"
def test_server_config_build_llm_config_auxiliary_none_when_absent(self):
from agentkit.server.config import ServerConfig
llm_config = ServerConfig._build_llm_config({"providers": {}})
assert llm_config.auxiliary_model is None

View File

@ -0,0 +1,318 @@
"""G7/U2 — Emergency layer rule template + TaskResult extension.
Verifies:
- EmergencyRules.classify maps each exception type to correct error_code
- TaskCancelledError raises ValueError (caller must propagate as-is)
- EmergencyError.to_dict produces all 5 fields
- EmergencyError.to_error_message formats suggestions as "建议1) ... 2) ..."
- Config overrides apply (suggestions, retryable, message)
- TaskResult.error_struct field: default None preserves byte-for-byte
to_dict() output (backward compat)
- TaskResult round-trip serialization includes error_struct when set
"""
from datetime import datetime, timezone
import pytest
from agentkit.core.exceptions import (
LLMProviderError,
LoopDetectedError,
TaskCancelledError,
TaskTimeoutError,
)
from agentkit.core.fallback import (
EMPTY_LLM_RESPONSE,
MAX_STEPS_REACHED,
SHELL_NO_OUTPUT,
EmergencyError,
EmergencyRules,
)
from agentkit.core.protocol import TaskResult
# ── Constants unchanged (contract preservation) ──────
class TestExistingConstantsUnchanged:
"""Existing 3 constants preserved byte-for-byte."""
def test_empty_llm_response_unchanged(self):
assert "模型未返回有效内容" in EMPTY_LLM_RESPONSE
assert "建议" in EMPTY_LLM_RESPONSE
def test_max_steps_reached_unchanged(self):
assert "已达到最大推理步数" in MAX_STEPS_REACHED
def test_shell_no_output_unchanged(self):
assert SHELL_NO_OUTPUT == "[命令执行成功,无输出内容]"
# ── EmergencyRules.classify ──────────────────────────
class TestEmergencyRulesClassify:
"""classify() maps exception types to EmergencyError."""
def test_timeout(self):
exc = TaskTimeoutError(task_id="t1", timeout_seconds=30)
err = EmergencyRules.classify(exc)
assert err.error_code == "timeout"
assert err.retryable is True
assert "稍后重试" in err.suggestions
assert "简化任务范围" in err.suggestions
assert err.original_error == str(exc)
def test_loop_detected(self):
exc = LoopDetectedError(tool_name="shell", repetitions=3)
err = EmergencyRules.classify(exc)
assert err.error_code == "loop_detected"
assert err.retryable is True
assert "拆分任务" in err.suggestions
assert "检查工具参数" in err.suggestions
def test_llm_provider_error(self):
exc = LLMProviderError("openai", "rate limited")
err = EmergencyRules.classify(exc)
assert err.error_code == "llm_failure"
assert err.retryable is True
assert "稍后重试" in err.suggestions
assert "切换模型" in err.suggestions
def test_llm_error_subclass_also_classified(self):
"""LLMProviderError is a subclass of LLMError; ensure isinstance check works."""
from agentkit.core.exceptions import LLMError
class CustomLLMError(LLMError):
pass
err = EmergencyRules.classify(CustomLLMError("custom"))
# CustomLLMError is NOT a LLMProviderError, falls through to generic
assert err.error_code == "internal_error"
def test_generic_exception_internal_error(self):
err = EmergencyRules.classify(Exception("unknown boom"))
assert err.error_code == "internal_error"
assert err.retryable is False
assert "联系管理员" in err.suggestions
assert err.original_error == "unknown boom"
def test_task_cancelled_raises(self):
"""TaskCancelledError must propagate; classify() raises ValueError."""
exc = TaskCancelledError(task_id="t1")
with pytest.raises(ValueError, match="TaskCancelledError"):
EmergencyRules.classify(exc)
def test_subclass_of_timeout_classified(self):
"""Subclasses of TaskTimeoutError are classified as timeout."""
class CustomTimeout(TaskTimeoutError):
def __init__(self):
super().__init__(task_id="custom", timeout_seconds=10)
err = EmergencyRules.classify(CustomTimeout())
assert err.error_code == "timeout"
# ── EmergencyError serialization ─────────────────────
class TestEmergencyErrorSerialization:
"""to_dict / to_error_message on EmergencyError."""
def test_to_dict_produces_all_five_fields(self):
err = EmergencyError(
error_code="timeout",
message="任务执行超时。",
suggestions=["稍后重试", "简化任务范围"],
retryable=True,
original_error="Task t1 timed out after 30s",
)
d = err.to_dict()
assert set(d.keys()) == {
"error_code",
"message",
"suggestions",
"retryable",
"original_error",
}
assert d["error_code"] == "timeout"
assert d["message"] == "任务执行超时。"
assert d["suggestions"] == ["稍后重试", "简化任务范围"]
assert d["retryable"] is True
assert d["original_error"] == "Task t1 timed out after 30s"
def test_to_dict_suggestions_list_is_copy(self):
"""to_dict returns a fresh list, not the internal reference."""
suggestions = ["a", "b"]
err = EmergencyError(
error_code="x",
message="m",
suggestions=suggestions,
retryable=False,
original_error="e",
)
d = err.to_dict()
assert d["suggestions"] is not suggestions
d["suggestions"].append("c")
assert err.suggestions == ["a", "b"]
def test_to_error_message_with_suggestions(self):
err = EmergencyError(
error_code="timeout",
message="任务执行超时。",
suggestions=["稍后重试", "简化任务范围"],
retryable=True,
original_error="err",
)
msg = err.to_error_message()
assert msg.startswith("任务执行超时。建议:")
assert "1) 稍后重试" in msg
assert "2) 简化任务范围" in msg
# Format mirrors EMPTY_LLM_RESPONSE style
assert msg.endswith("")
def test_to_error_message_no_suggestions(self):
err = EmergencyError(
error_code="x",
message="just a message",
suggestions=[],
retryable=False,
original_error="e",
)
assert err.to_error_message() == "just a message"
def test_to_error_message_single_suggestion(self):
err = EmergencyError(
error_code="x",
message="msg",
suggestions=["only one"],
retryable=False,
original_error="e",
)
msg = err.to_error_message()
assert msg == "msg建议1) only one。"
# ── Config override ──────────────────────────────────
class TestConfigOverride:
"""classify() applies per-rule config overrides."""
def test_override_suggestions(self):
exc = TaskTimeoutError(task_id="t", timeout_seconds=1)
cfg = {"timeout": {"suggestions": ["自定义建议 A", "自定义建议 B"]}}
err = EmergencyRules.classify(exc, config=cfg)
assert err.suggestions == ["自定义建议 A", "自定义建议 B"]
assert err.error_code == "timeout"
def test_override_retryable(self):
exc = LLMProviderError("openai", "boom")
cfg = {"llm_failure": {"retryable": False}}
err = EmergencyRules.classify(exc, config=cfg)
assert err.retryable is False
def test_override_message(self):
exc = LoopDetectedError(tool_name="x", repetitions=2)
cfg = {"loop_detected": {"message": "循环啦!"}}
err = EmergencyRules.classify(exc, config=cfg)
assert err.message == "循环啦!"
def test_override_internal_error_rule(self):
cfg = {"internal_error": {"suggestions": ["联系客服"]}}
err = EmergencyRules.classify(Exception("boom"), config=cfg)
assert err.error_code == "internal_error"
assert err.suggestions == ["联系客服"]
def test_config_none_uses_defaults(self):
err = EmergencyRules.classify(TaskTimeoutError(task_id="t", timeout_seconds=1))
assert err.error_code == "timeout"
assert err.retryable is True
def test_config_empty_dict_uses_defaults(self):
err = EmergencyRules.classify(
TaskTimeoutError(task_id="t", timeout_seconds=1), config={}
)
assert err.error_code == "timeout"
assert err.retryable is True
# ── TaskResult.error_struct extension ────────────────
def _make_task_result(
error_struct: dict | None = None, error_message: str | None = None
) -> TaskResult:
now = datetime.now(timezone.utc)
return TaskResult(
task_id="t1",
agent_name="a1",
status="completed",
output_data={"k": "v"},
error_message=error_message,
started_at=now,
completed_at=now,
metrics={"m": 1},
error_struct=error_struct,
)
class TestTaskResultErrorStruct:
"""TaskResult.error_struct field — backward-compatible extension."""
def test_default_error_struct_is_none(self):
tr = _make_task_result()
assert tr.error_struct is None
def test_to_dict_without_error_struct_preserves_existing_shape(self):
"""error_struct=None → to_dict() output has NO error_struct key (byte-for-byte)."""
tr = _make_task_result()
d = tr.to_dict()
assert "error_struct" not in d
# Existing keys unchanged
assert set(d.keys()) == {
"task_id",
"agent_name",
"status",
"output_data",
"error_message",
"started_at",
"completed_at",
"metrics",
}
def test_to_dict_with_error_struct_includes_key(self):
struct = {
"error_code": "timeout",
"message": "超时",
"suggestions": ["重试"],
"retryable": True,
"original_error": "boom",
}
tr = _make_task_result(error_struct=struct, error_message="超时建议1) 重试。")
d = tr.to_dict()
assert d["error_struct"] == struct
assert d["error_message"] == "超时建议1) 重试。"
def test_from_dict_round_trip_with_error_struct(self):
struct = {"error_code": "loop_detected", "message": "m", "suggestions": [], "retryable": True, "original_error": "e"}
tr = _make_task_result(error_struct=struct)
d = tr.to_dict()
restored = TaskResult.from_dict(d)
assert restored.error_struct == struct
def test_from_dict_without_error_struct_defaults_none(self):
tr = _make_task_result()
d = tr.to_dict()
# Simulate legacy data without error_struct key
restored = TaskResult.from_dict(d)
assert restored.error_struct is None
def test_error_message_and_error_struct_coexist(self):
"""Both fields can be set simultaneously (parallel contract per KTD2)."""
struct = {"error_code": "timeout", "message": "超时", "suggestions": ["重试"], "retryable": True, "original_error": "err"}
tr = _make_task_result(error_struct=struct, error_message="超时建议1) 重试。")
d = tr.to_dict()
assert d["error_message"] == "超时建议1) 重试。"
assert d["error_struct"] == struct

View File

@ -0,0 +1,404 @@
"""G7/U3 — Three-tier fallback chain wiring tests.
Verifies Main Recovery (ReflexionEngine) Emergency (EmergencyRules)
at chat REST path. Mocks ReActEngine + ReflexionEngine + LLMGateway.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.exceptions import (
LLMProviderError,
LoopDetectedError,
TaskCancelledError,
TaskTimeoutError,
)
from agentkit.core.react import ReActResult
from agentkit.server._fallback_chain import execute_with_fallback_chain
def _make_react_result(status: str = "success", output: str = "ok") -> ReActResult:
return ReActResult(
output=output,
trajectory=[],
total_steps=1,
total_tokens=10,
status=status,
)
def _make_react_engine(result=None, raises=None):
"""Build a fake ReActEngine with .execute returning result or raising."""
engine = MagicMock()
engine.reset = MagicMock()
if raises is not None:
engine.execute = AsyncMock(side_effect=raises)
else:
engine.execute = AsyncMock(return_value=result or _make_react_result())
return engine
def _make_llm_gateway():
gw = MagicMock()
gw.chat = AsyncMock(return_value=MagicMock(content="recovered"))
return gw
def _make_reflexion_result(status: str = "success", output: str = "recovered"):
"""Synthesize a ReflexionResult-like object."""
return MagicMock(
status=status,
output=output,
trajectory=[],
total_steps=1,
total_tokens=5,
)
@pytest.fixture
def patched_reflexion(monkeypatch):
"""Patch ReflexionEngine used inside the chain to a controllable mock."""
from agentkit.server import _fallback_chain
instances: list[MagicMock] = []
class _MockReflexion:
def __init__(self, llm_gateway, max_reflections=1, **kwargs):
self._llm_gateway = llm_gateway
self._max_reflections = max_reflections
self.execute = AsyncMock(return_value=_make_reflexion_result())
instances.append(self)
monkeypatch.setattr(_fallback_chain, "ReflexionEngine", _MockReflexion)
return instances
# ─── Tier 1: Main ─────────────────────────────────────────────────────────
class TestMainTier:
@pytest.mark.asyncio
async def test_main_success_no_recovery_no_emergency(self):
engine = _make_react_engine(result=_make_react_result(status="success", output="hello"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "success"
assert result.output == "hello"
assert result.error_struct is None
@pytest.mark.asyncio
async def test_main_unknown_status_treated_as_success(self):
"""Unknown status (not in soft_failure set) is treated as success-like."""
engine = _make_react_engine(result=_make_react_result(status="partial", output="x"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "success"
# ─── Tier 2: Recovery ──────────────────────────────────────────────────────
class TestRecoveryTier:
@pytest.mark.asyncio
async def test_main_timeout_triggers_recovery_success(self, patched_reflexion):
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "recovered"
assert result.output == "recovered"
# ReflexionEngine was instantiated and called
assert len(patched_reflexion) == 1
patched_reflexion[0].execute.assert_awaited_once()
@pytest.mark.asyncio
async def test_main_loop_detected_triggers_recovery(self, patched_reflexion):
engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "recovered"
@pytest.mark.asyncio
async def test_main_llm_provider_error_triggers_recovery(self, patched_reflexion):
engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="503"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "recovered"
@pytest.mark.asyncio
async def test_main_soft_failure_status_triggers_recovery(self, patched_reflexion):
"""Soft failure (empty_fallback) without exception still triggers Recovery."""
engine = _make_react_engine(result=_make_react_result(status="empty_fallback", output=""))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
assert result.status == "recovered"
@pytest.mark.asyncio
async def test_recovery_disabled_skips_to_emergency(self):
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
assert result.status == "emergency"
assert result.error_struct["error_code"] == "timeout"
@pytest.mark.asyncio
async def test_recovery_failure_falls_through_to_emergency(self, patched_reflexion):
"""Recovery raises → Emergency tier fires with original exception."""
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
# Make ReflexionEngine.execute raise
patched_reflexion_instance = MagicMock()
patched_reflexion_instance.execute = AsyncMock(
side_effect=RuntimeError("reflexion crashed")
)
# Override the patched class to use our instance
from agentkit.server import _fallback_chain
original_cls = _fallback_chain.ReflexionEngine
class _MockReflexionWithExc:
def __init__(self, **kwargs):
self.execute = patched_reflexion_instance.execute
_fallback_chain.ReflexionEngine = _MockReflexionWithExc
try:
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
finally:
_fallback_chain.ReflexionEngine = original_cls
assert result.status == "emergency"
assert result.error_struct["error_code"] == "timeout"
@pytest.mark.asyncio
async def test_recovery_unsuccessful_status_falls_through(self, patched_reflexion):
"""Recovery returns non-success status → Emergency fires."""
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
# Make ReflexionEngine return unsuccessful result with empty output
from agentkit.server import _fallback_chain
class _MockReflexionNoOutput:
def __init__(self, **kwargs):
self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None))
original_cls = _fallback_chain.ReflexionEngine
_fallback_chain.ReflexionEngine = _MockReflexionNoOutput
try:
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
finally:
_fallback_chain.ReflexionEngine = original_cls
assert result.status == "emergency"
assert result.error_struct["error_code"] == "timeout"
# ─── Tier 3: Emergency ────────────────────────────────────────────────────
class TestEmergencyTier:
@pytest.mark.asyncio
async def test_emergency_timeout_error_code(self, patched_reflexion):
# Make recovery fail (empty result) so Emergency fires
from agentkit.server import _fallback_chain
class _MockReflexionEmpty:
def __init__(self, **kwargs):
self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None))
original_cls = _fallback_chain.ReflexionEngine
_fallback_chain.ReflexionEngine = _MockReflexionEmpty
try:
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
)
finally:
_fallback_chain.ReflexionEngine = original_cls
assert result.status == "emergency"
assert result.error_struct["error_code"] == "timeout"
assert result.error_struct["retryable"] is True
assert "建议" in result.output
@pytest.mark.asyncio
async def test_emergency_loop_detected_error_code(self):
engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5))
# Recovery disabled so Emergency fires directly
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
assert result.status == "emergency"
assert result.error_struct["error_code"] == "loop_detected"
@pytest.mark.asyncio
async def test_emergency_llm_failure_error_code(self):
engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="500"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
assert result.status == "emergency"
assert result.error_struct["error_code"] == "llm_failure"
@pytest.mark.asyncio
async def test_emergency_internal_error_for_generic_exception(self):
engine = _make_react_engine(raises=RuntimeError("unexpected"))
result = await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
assert result.status == "emergency"
assert result.error_struct["error_code"] == "internal_error"
assert result.error_struct["retryable"] is False
@pytest.mark.asyncio
async def test_task_cancelled_propagates_not_routed_to_emergency(self):
"""TaskCancelledError must propagate, not be classified by Emergency."""
engine = _make_react_engine(raises=TaskCancelledError(task_id="t1"))
with pytest.raises(TaskCancelledError):
await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={"recovery": {"enabled": False}},
)
@pytest.mark.asyncio
async def test_emergency_disabled_reraises_original(self):
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
with pytest.raises(TaskTimeoutError):
await execute_with_fallback_chain(
react_engine=engine,
llm_gateway=_make_llm_gateway(),
messages=[{"role": "user", "content": "hi"}],
tools=[],
model="default",
agent_name="a",
system_prompt=None,
fallback_chain_config={
"recovery": {"enabled": False},
"emergency": {"enabled": False},
},
)
# ─── Config wiring ────────────────────────────────────────────────────────
class TestServerConfigFallbackChain:
def test_fallback_chain_section_read_from_dict(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict(
{
"fallback_chain": {
"enabled": True,
"recovery": {"enabled": False, "max_retries": 3},
"emergency": {"enabled": True},
}
}
)
assert config.fallback_chain["enabled"] is True
assert config.fallback_chain["recovery"] == {"enabled": False, "max_retries": 3}
assert config.fallback_chain["emergency"] == {"enabled": True}
def test_fallback_chain_defaults_empty_when_absent(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict({})
assert config.fallback_chain == {}

View File

@ -0,0 +1,367 @@
"""G9/U4 — PlanPhase rollback fields + RollbackExecutor + TeamOrchestrator integration.
Characterization-first: captures pre-change behavior (rollback_command=None
no rollback, checkpoint saved) before asserting new rollback behavior.
"""
from __future__ import annotations
import os
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.experts.orchestrator import TeamOrchestrator
from agentkit.experts.plan import PlanPhase, TeamPlan
from agentkit.orchestrator.rollback import RollbackExecutor
# ─── PlanPhase field characterization ──────────────────────────────────────
class TestPlanPhaseFields:
"""PlanPhase serialization for new fields."""
def test_characterization_no_new_keys_when_unset(self):
"""Default PlanPhase.to_dict() must not include new keys (KTD6 contract)."""
ph = PlanPhase(name="执行", assigned_expert="lead")
out = ph.to_dict()
assert "validation_command" not in out
assert "rollback_command" not in out
# Pre-change keys remain
assert out["name"] == "执行"
assert out["assigned_expert"] == "lead"
assert out["status"] == "pending"
def test_characterization_from_dict_empty_yields_none(self):
"""from_dict({}) produces both new fields as None."""
ph = PlanPhase.from_dict({"name": "x"})
assert ph.validation_command is None
assert ph.rollback_command is None
def test_serialization_includes_keys_when_set(self):
ph = PlanPhase(
name="frontend",
validation_command="ruff check src/",
rollback_command="git checkout src/app.vue",
)
out = ph.to_dict()
assert out["validation_command"] == "ruff check src/"
assert out["rollback_command"] == "git checkout src/app.vue"
def test_serialization_round_trip(self):
ph = PlanPhase(
name="backend",
validation_command="pytest -x -q",
rollback_command="git checkout src/api.py",
)
restored = PlanPhase.from_dict(ph.to_dict())
assert restored.validation_command == "pytest -x -q"
assert restored.rollback_command == "git checkout src/api.py"
def test_only_validation_set_still_emits_validation_key(self):
"""Asymmetric case: only validation_command set — only that key appears."""
ph = PlanPhase(name="x", validation_command="echo ok")
out = ph.to_dict()
assert "validation_command" in out
assert "rollback_command" not in out
# ─── RollbackExecutor subprocess execution ─────────────────────────────────
@pytest.fixture
def tmp_workspace(tmp_path):
"""Fresh working directory for subprocess execution."""
return str(tmp_path)
class TestRollbackExecutor:
"""RollbackExecutor happy/edge/failure paths."""
@pytest.mark.asyncio
async def test_execute_happy_path_zero_exit(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("true")
assert r.passed is True
assert r.exit_code == 0
assert r.command == "true"
@pytest.mark.asyncio
async def test_execute_failure_nonzero_exit(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("false")
assert r.passed is False
assert r.exit_code != 0
@pytest.mark.asyncio
async def test_execute_timeout(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=0.1)
r = await ex.execute("sleep 5")
assert r.passed is False
assert r.exit_code == -1
assert "timed out" in r.stderr
@pytest.mark.asyncio
async def test_execute_captures_stdout(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("echo hello-rollback")
assert r.passed is True
assert "hello-rollback" in r.stdout
@pytest.mark.asyncio
async def test_validate_same_semantics_as_execute(self, tmp_workspace):
"""validate() is just execute() with different intent marker."""
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.validate("true")
assert r.passed is True
r2 = await ex.validate("false")
assert r2.passed is False
@pytest.mark.asyncio
async def test_spawn_failure_returns_failed_result(self, tmp_workspace):
"""Bad shell command should surface as failed result, not raise."""
# Use a definitely-broken command — shell still spawns, returns non-zero.
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("exit 7")
assert r.passed is False
assert r.exit_code == 7
@pytest.mark.asyncio
async def test_cwd_used_for_relative_paths(self, tmp_workspace):
"""Files created in working_dir are visible via cwd-relative commands."""
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
await ex.execute("echo content > testfile.txt")
r = await ex.execute("test -f testfile.txt")
assert r.passed is True
# ─── Real git rollback integration ────────────────────────────────────────
class TestGitRollbackIntegration:
"""Real git repo fixture: writes file, rollback restores via git checkout."""
@pytest.mark.asyncio
async def test_git_checkout_restores_modified_file(self, tmp_path):
# Init repo and commit a baseline file
import subprocess
repo = str(tmp_path)
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.email", "test@x"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.name", "test"], cwd=repo, check=True)
baseline = "original content\n"
with open(os.path.join(repo, "foo.txt"), "w") as f:
f.write(baseline)
subprocess.run(["git", "add", "foo.txt"], cwd=repo, check=True)
subprocess.run(["git", "commit", "-q", "-m", "baseline"], cwd=repo, check=True)
# Mutate the file
with open(os.path.join(repo, "foo.txt"), "w") as f:
f.write("mutated content\n")
# Run git checkout as rollback_command
ex = RollbackExecutor(working_dir=repo, timeout=10.0)
r = await ex.execute("git checkout foo.txt")
assert r.passed is True
with open(os.path.join(repo, "foo.txt")) as f:
assert f.read() == baseline
# ─── TeamOrchestrator integration ─────────────────────────────────────────
def _make_team_mock():
"""Build a minimal ExpertTeam mock for TeamOrchestrator."""
team = MagicMock()
team.team_id = "test-team"
team.status.value = "executing"
team.lead_expert = None
team.active_experts = []
return team
def _make_orchestrator(team=None, checkpoint=None, workspace_root=None):
"""Build a TeamOrchestrator with mocked team and checkpoint."""
team = team or _make_team_mock()
orch = TeamOrchestrator(
team=team,
checkpoint=checkpoint,
workspace_root=workspace_root,
rollback_timeout=2.0,
)
orch._broadcast_event = AsyncMock()
return orch
class TestOrchestratorRollbackIntegration:
"""TeamOrchestrator phase failure path integration with rollback."""
@pytest.mark.asyncio
async def test_characterization_no_rollback_when_unset(self, tmp_path):
"""Phase fails, rollback_command=None → checkpoint saved, no rollback events."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
ph = PlanPhase(name="p1", assigned_expert="lead") # no rollback fields
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
# Call _run_phase_rollback — but it should NOT be called from
# orchestrator path when fields unset. Verify by simulating the
# main-loop guard directly.
from agentkit.experts.plan import PhaseStatus as PS
ph.status = PS.FAILED
# Simulate the guard condition in orchestrator.py:280-288
should_save = True
if (
ph.validation_command
and ph.rollback_command
and isinstance(Exception("x"), (Exception,))
):
should_save = await orch._run_phase_rollback(plan, ph)
# Guard should not fire
assert should_save is True
# No rollback events broadcast
assert orch._broadcast_event.call_count == 0
@pytest.mark.asyncio
async def test_validation_passes_no_rollback_executed(self, tmp_path):
"""validation_command returns 0 → rollback NOT executed, checkpoint saved."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="true", # always passes
rollback_command="false", # would fail if executed
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is True
# Events: started + completed (no rollback)
events = [
c.kwargs.get("event_type") or c.args[0] for c in orch._broadcast_event.call_args_list
]
assert "phase_rollback_started" in events
assert "phase_rollback_completed" in events
assert "phase_rollback_failed" not in events
@pytest.mark.asyncio
async def test_validation_fails_rollback_succeeds(self, tmp_path):
"""Validation fails, rollback returns 0 → checkpoint saved, rollback_executed=True."""
# Create file under git tracking, then mutate it
import subprocess
repo = str(tmp_path)
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.email", "t@x"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.name", "t"], cwd=repo, check=True)
with open(os.path.join(repo, "f.txt"), "w") as f:
f.write("base\n")
subprocess.run(["git", "add", "f.txt"], cwd=repo, check=True)
subprocess.run(["git", "commit", "-q", "-m", "base"], cwd=repo, check=True)
with open(os.path.join(repo, "f.txt"), "w") as f:
f.write("mutated\n")
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=repo)
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="false", # validation fails
rollback_command="git checkout f.txt", # rollback succeeds
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is True
# File restored to baseline
with open(os.path.join(repo, "f.txt")) as f:
assert f.read() == "base\n"
@pytest.mark.asyncio
async def test_validation_fails_rollback_fails_skips_checkpoint(self, tmp_path):
"""Validation fails AND rollback fails → checkpoint NOT saved (R21), event emitted."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="false", # fails
rollback_command="false", # also fails
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is False # R21: skip checkpoint
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
assert "phase_rollback_started" in events
assert "phase_rollback_failed" in events
# phase_rollback_completed should NOT be in events
assert "phase_rollback_completed" not in events
@pytest.mark.asyncio
async def test_rollback_timeout_skips_checkpoint(self, tmp_path):
"""Rollback command times out → checkpoint NOT saved, failed event emitted."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
orch._rollback_timeout = 0.1 # short timeout
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="false",
rollback_command="sleep 5",
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is False
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
assert "phase_rollback_failed" in events
# ─── Config wiring ────────────────────────────────────────────────────────
class TestServerConfigRollback:
"""ServerConfig rollback section wiring."""
def test_rollback_section_read_from_dict(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict(
{
"rollback": {
"default_timeout": 45.0,
}
}
)
assert config.rollback == {"default_timeout": 45.0}
def test_rollback_defaults_empty_when_absent(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict({})
assert config.rollback == {}