From a2dcde01b80f61e7bbcef832891014f8c2cdfc34 Mon Sep 17 00:00:00 2001 From: Fischer Date: Tue, 30 Jun 2026 09:09:33 +0800 Subject: [PATCH] feat(agent): Wave 2 medium coupling (G4/G7/G9) (#5) --- agentkit.yaml | 23 + ...3-feat-agent-wave2-medium-coupling-plan.md | 481 ++++++++++++++++++ src/agentkit/core/compressor.py | 83 ++- src/agentkit/core/fallback.py | 146 ++++++ src/agentkit/core/protocol.py | 7 + src/agentkit/experts/orchestrator.py | 106 +++- src/agentkit/experts/plan.py | 15 +- src/agentkit/llm/config.py | 4 + src/agentkit/orchestrator/rollback.py | 97 ++++ src/agentkit/server/_fallback_chain.py | 199 ++++++++ src/agentkit/server/config.py | 17 + src/agentkit/server/routes/chat.py | 25 +- tests/unit/test_compressor_auxiliary.py | 400 +++++++++++++++ tests/unit/test_emergency_rules.py | 318 ++++++++++++ tests/unit/test_fallback_chain.py | 404 +++++++++++++++ tests/unit/test_phase_rollback.py | 367 +++++++++++++ 16 files changed, 2658 insertions(+), 34 deletions(-) create mode 100644 docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md create mode 100644 src/agentkit/orchestrator/rollback.py create mode 100644 src/agentkit/server/_fallback_chain.py create mode 100644 tests/unit/test_compressor_auxiliary.py create mode 100644 tests/unit/test_emergency_rules.py create mode 100644 tests/unit/test_fallback_chain.py create mode 100644 tests/unit/test_phase_rollback.py diff --git a/agentkit.yaml b/agentkit.yaml index 692c566..9588b70 100644 --- a/agentkit.yaml +++ b/agentkit.yaml @@ -28,6 +28,29 @@ llm: coding: bailian-coding/qwen3-coder-plus chat: deepseek/deepseek-chat reasoning: deepseek/deepseek-reasoner + # G4/U1: Auxiliary model for cost-sensitive tasks (summarization). + # When set, ContextCompressor tries this alias first, falling back to + # the main model on failure or empty content. Commented to preserve + # default behavior — uncomment to enable. + # auxiliary_model: fast +# G9/U4: Rollback configuration. Drives RollbackExecutor subprocess timeout +# for PlanPhase.validation_command / PlanPhase.rollback_command. Per-phase +# opt-in (KTD6) — when PlanPhase.rollback_command is unset, no rollback runs. +# Canonical rollback pattern: `git checkout ` (file-scoped, +# not `git checkout .` which would wipe unrelated changes). +rollback: + default_timeout: 30.0 +# G7/U3: Three-tier fallback chain at chat REST send_message. +# main → Recovery (ReflexionEngine retry) → Emergency (rule-based classifier). +# Wired only at chat REST path (KTD5); CLI / ReWOO / Reflexion internal +# ReAct calls bypass the chain (no recursive loop). +fallback_chain: + enabled: true + recovery: + enabled: true + max_retries: 1 # ReflexionEngine max_reflections override + emergency: + enabled: true session: {backend: memory} bus: {backend: memory} task_store: {backend: memory} diff --git a/docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md b/docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md new file mode 100644 index 0000000..e0e7bde --- /dev/null +++ b/docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md @@ -0,0 +1,481 @@ +--- +date: 2026-06-29 +type: feat +origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md +--- + +# Wave 2 — Auxiliary LLM Routing, Three-Tier Fallback, Atomic Subtask Rollback + +## Summary + +Wave 2 of the advanced-agent gap optimization brainstorm. Three medium-coupling gaps: G4 routes summary calls through a cheaper auxiliary LLM (falling back to main on failure), G7 introduces a three-tier fallback chain (main → Recovery via existing `ReflexionEngine` → Emergency with structured errors), G9 binds atomic subtask rollback to `PlanPhase` (opt-in via `rollback_command`, coordinated with the existing U7 `PipelineCheckpoint`). + +--- + +## Problem Frame + +Wave 1 shipped self-contained quick wins (G1/G2/G3/G8). Wave 2 addresses the medium-coupling gaps that touch multiple layers and resolve two deferred-to-planning decisions surfaced in the brainstorm: G7 Emergency layer rule template shape, and G9 `rollback_command` default behavior. The constraints these gaps must respect come from `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — new fields must preserve existing contracts, dynamic plan mutations must persist immediately, and `PipelineCheckpoint` is in-memory dict + Redis fallback (not a DB row lock). + +--- + +## Requirements + +Carried from `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` § Wave 2 (R13-R21): + +- R13. `ContextCompressor` summary task routes to auxiliary model (cheap, e.g. Gemini Flash / Doubao lite), not main model. +- R14. `auxiliary_model` is configurable, separated from main `model`. +- R15. Summary quality does not degrade: auxiliary failure falls back to main model (not to simple_summary). +- R16. Main agent failure triggers Recovery layer (reuse `ReflexionEngine` Evaluate→Reflect→Retry). +- R17. Recovery failure triggers Emergency layer (rule-based fallback, structured error + suggestion). +- R18. Fallback chain is configurable (per-layer max retries, enable/disable Recovery/Emergency). +- R19. `PlanPhase` (`experts/plan.py:147`) gains optional `validation_command` and `rollback_command` fields. +- R20. Phase failure auto-executes `rollback_command` when configured (default `git checkout` pattern); coordinated with U7 checkpoint. +- R21. `checkpoint.save` runs only after rollback validation passes (avoid persisting failed state). + +Cross-cutting: R26 (configuration via `agentkit.yaml`), R27 (each gap ships a minimal self-check test, per ponytail rule). + +--- + +## Key Technical Decisions + +KTD1. **Auxiliary routing is compressor-scoped, not gateway-level.** `LLMGateway` already has a per-model fallback chain (`gateway.py:170-227`), but it semantics are "main fails → try backup model", not "route by task type". G4 adds an `auxiliary_model` parameter to `ContextCompressor.__init__`; `_summarize` tries auxiliary first, falls back to main on failure. This avoids touching the gateway's existing fallback semantics and keeps the change to one file plus config wiring. + +KTD2. **Emergency layer adds a parallel `error_struct` field, not replacing `error_message`.** `TaskResult.error_message: str | None` is serialized to API responses (`protocol.py:132-145`) — changing its type breaks frontend contracts. New `error_struct: dict | None` field carries `{error_code, message, suggestions, retryable}`. `error_message` remains the human-readable string (now derived from `error_struct.message` when present). Backward compatibility preserved. + +KTD3. **Emergency rule classification by exception type, not by status field.** `TaskTimeoutError` → `timeout`, `LoopDetectedError` → `loop_detected`, `LLMProviderError` → `llm_failure`, `TaskCancelledError` → `cancelled` (not Emergency-eligible, propagates), generic `Exception` → `internal_error`. Each rule maps to a fixed message + suggestion list (no LLM-driven suggestions — pure rule-based, per brainstorm). + +KTD4. **Recovery layer reuses `ReflexionEngine` with main model, not auxiliary.** `ReflexionEngine.execute()` already supports `evaluate_model`/`reflect_model` parameters (`reflexion.py:94-111`), defaulting to the main model. Recovery triggers on main agent failure; using the main model for evaluate/reflect maximizes diagnostic reliability (the same model that failed is best positioned to reflect on its own failure). Auxiliary model is reserved for G4's cost-sensitive compression path. + +KTD5. **Three-tier chain wires at `chat.py:613` (WebSocket main path), not at `ReActEngine.execute()`.** ReAct has 5 call sites (`chat.py:613`, `rewoo.py:1204`, `cli/chat.py:338`, `reflexion.py:216`, `config_driven.py:695`). Wrapping at ReAct would force all 5 sites through Recovery/Emergency, including ReflexionEngine itself (creating a recursive loop). Wrapping at `chat.py:613` covers the primary user-facing path; CLI and ReWOO are out of scope (deferred to follow-up). + +KTD6. **Rollback is opt-in via `rollback_command` field; no implicit default.** Brainstorm KTD5 mentioned "default `git checkout`" but auto-executing `git checkout` on every phase failure changes existing behavior (Finding 1: "新字段默认值须保持既有契约"). Resolution: `rollback_command: str | None = None`. When unset, no rollback executes (preserves existing contract). When set, the configured command runs after validation failure, before checkpoint save. `git checkout ` is the canonical pattern documented in YAML examples. + +KTD7. **Rollback executes via `RollbackExecutor`, not via `ShellTool`.** `ShellTool._is_dangerous` would trigger `confirm_callback` for `git checkout` (not in `_SAFE_COMMAND_PREFIXES`, not in `_DANGEROUS_BINARY_FLAGS`). Orchestrator-internal execution bypasses ShellTool — uses `asyncio.create_subprocess_shell` with `cwd`/`timeout`/`proc.kill()` pattern from `verification_loop.py:67-103`. Audit logging added (not a whitelist concern; terminal_whitelist only governs `/api/v1/terminal/server` route per `terminal_server.py:227`). + +KTD8. **G9 checkpoint ordering: validation → rollback → checkpoint save.** Currently `orchestrator.py:265` saves checkpoint unconditionally after phase finalizes (success or failure). R21 requires checkpoint save only after rollback validation passes. New ordering: phase fails → mark FAILED → mark dependents FAILED → execute `validation_command` (if set) → if validation fails, execute `rollback_command` (if set) → save checkpoint (only if rollback validation passed or no rollback configured). + +--- + +## High-Level Technical Design + +### Three-tier fallback chain (G7) + +```mermaid +stateDiagram-v2 + [*] --> Main: chat request + Main --> MainSuccess: success + Main --> Recovery: failure (exception or empty_fallback) + Recovery --> RecoverySuccess: ReflexionEngine retry succeeds + Recovery --> Emergency: max_reflections exhausted + Emergency --> [*]: emit error_struct, terminal + MainSuccess --> [*]: normal response + RecoverySuccess --> [*]: recovered response +``` + +### G9 phase failure + rollback sequence + +```mermaid +sequenceDiagram + participant O as TeamOrchestrator + participant P as PlanPhase + participant RE as RollbackExecutor + participant CP as PipelineCheckpoint + + O->>P: execute phase + P-->>O: failure (exception) + O->>O: mark FAILED + dependents FAILED + alt validation_command set + O->>RE: run validation_command + RE-->>O: ValidationResult + alt validation fails AND rollback_command set + O->>RE: run rollback_command + RE-->>O: RollbackResult + alt rollback validation passes + O->>CP: save checkpoint + else rollback validation fails + O->>O: log rollback_failed, skip checkpoint + end + else no rollback_command + O->>CP: save checkpoint (no rollback) + end + else no validation_command + O->>CP: save checkpoint (no validation) + end +``` + +### Configuration shape (YAML) + +```yaml +# G4: Auxiliary LLM for compression +llm: + auxiliary_model: fast # alias resolved via existing model_aliases mechanism + +# G7: Three-tier fallback chain +fallback_chain: + enabled: true + recovery: + enabled: true + max_retries: 1 # ReflexionEngine max_reflections override + emergency: + enabled: true + +# G9: Rollback configuration (per-phase opt-in via PlanPhase.rollback_command) +rollback: + default_timeout: 30.0 # RollbackExecutor subprocess timeout +``` + +--- + +## Implementation Units + +### U1. G4 — Auxiliary LLM Routing in ContextCompressor + +**Goal:** Route `_summarize` LLM calls through `auxiliary_model` when configured; fall back to main model on auxiliary failure (not to `_simple_summary`). + +**Requirements:** R13, R14, R15, R26, R27 + +**Dependencies:** none (self-contained) + +**Files:** +- Modify: `src/agentkit/core/compressor.py` (add `auxiliary_model` param + routing in `_summarize`) +- Modify: `src/agentkit/llm/config.py` (add `auxiliary_model: str | None` field to `LLMConfig`) +- Modify: `src/agentkit/server/config.py` (`_build_llm_config` reads `auxiliary_model` from llm section) +- Modify: `agentkit.yaml` (document `llm.auxiliary_model: fast` in llm section) +- Create: `tests/unit/test_compressor_auxiliary.py` + +**Approach:** `ContextCompressor.__init__` gains `auxiliary_model: str | None = None` param (after existing `model` param). `_summarize` (compressor.py:123-158) restructuring: + +1. If `auxiliary_model` is set and differs from `model`, try `auxiliary_model` first via `self._llm_gateway.chat(model=self._auxiliary_model, ...)`. +2. **Truthiness check (Finding 4 anti-pattern):** treat empty `response.content` (None or whitespace-only) as failure, not success. On failure, log and fall through to main model. +3. If auxiliary succeeds with non-empty content, return it. +4. If auxiliary fails (exception OR empty content), retry with main `self._model`. +5. Existing `except Exception → _simple_summary` block remains as the final degradation tier. + +`LLMConfig.auxiliary_model` reads from `data.get("auxiliary_model")` in `_build_llm_config`. Agentkit.yaml already declares `fast: bailian-coding/qwen-turbo` alias — this is the canonical auxiliary target. + +**Execution note:** characterization-first. Capture current `_summarize` behavior (single model, exception → simple_summary) with `auxiliary_model=None` tests, then add `auxiliary_model="fast"` tests for new routing behavior. + +**Patterns to follow:** +- `LLMGateway._get_models_to_try` (`gateway.py:170-227`) — existing fallback chain pattern +- Wave 1's `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections, `config.py:239-273`) + +**Test scenarios:** +- Happy path: `auxiliary_model="fast"` set, auxiliary returns non-empty content → result is auxiliary content, main model not called +- Empty content fallback: auxiliary returns `content=""` (or `None`) → main model called, main content returned (covers Finding 4 anti-pattern) +- Auxiliary exception: auxiliary raises `LLMProviderError` → main model called, main content returned +- Both fail: auxiliary raises, main raises → `_simple_summary` returned (existing degradation preserved) +- Characterization: `auxiliary_model=None` → behavior matches current code (single model call) +- Config wiring: `LLMConfig.from_dict` reads `auxiliary_model` field from dict; `ServerConfig._build_llm_config` passes it through +- Audit: auxiliary call uses `agent_name="compressor"`, `task_type="summarization"` (preserved for usage tracking) + +**Verification:** All test scenarios pass; existing `tests/unit/test_compressor*.py` (if any) still pass with `auxiliary_model=None` default; ruff clean. + +--- + +### U2. G7 — Emergency Layer Rule Template + TaskResult Extension + +**Goal:** Add rule-based Emergency classifier with structured error output. No wiring yet — just the infrastructure (classifier + data structure + `fallback.py` extension). + +**Requirements:** R17, R18, R26, R27 + +**Dependencies:** none (foundation for U3) + +**Files:** +- Modify: `src/agentkit/core/fallback.py` (add `EmergencyRules` class + `EmergencyError` dataclass, preserve existing 3 constants) +- Modify: `src/agentkit/core/protocol.py` (add `error_struct: dict | None = None` field to `TaskResult`, update `to_dict`) +- Create: `tests/unit/test_emergency_rules.py` + +**Approach:** + +`EmergencyError` dataclass (in `fallback.py`): +``` +@dataclass +class EmergencyError: + error_code: str # "timeout"|"loop_detected"|"llm_failure"|"internal_error" + message: str # human-readable Chinese message (mirrors EMPTY_LLM_RESPONSE style) + suggestions: list[str] # actionable user-facing suggestions + retryable: bool # whether user retry might succeed + original_error: str # str(exc) for traceability + + def to_dict(self) -> dict: ... + def to_error_message(self) -> str: ... # formatted "message\n建议:1) ... 2) ..." +``` + +`EmergencyRules` class (rule-based classifier, no LLM): +``` +class EmergencyRules: + @staticmethod + def classify(exc: Exception, config: dict | None = None) -> EmergencyError: + # Match by exception type (TaskTimeoutError, LoopDetectedError, LLMProviderError, etc.) + # config allows per-rule customization (suggestion overrides, retryable overrides) +``` + +Rule mapping (initial set, expandable via config): +- `TaskTimeoutError` → `error_code="timeout"`, `retryable=True`, suggestions: ["稍后重试", "简化任务范围"] +- `LoopDetectedError` → `error_code="loop_detected"`, `retryable=True`, suggestions: ["拆分任务", "检查工具参数"] +- `LLMProviderError` → `error_code="llm_failure"`, `retryable=True`, suggestions: ["稍后重试", "切换模型"] +- `TaskCancelledError` → not classified (propagates as-is, Emergency not triggered) +- Generic `Exception` → `error_code="internal_error"`, `retryable=False`, suggestions: ["联系管理员"] + +`TaskResult` extension: +``` +@dataclass +class TaskResult: + # ... existing fields ... + error_message: str | None # unchanged + error_struct: dict | None = None # NEW: serialized EmergencyError.to_dict() +``` + +`to_dict` includes `error_struct` when set; `error_message` continues to hold the human-readable string. + +**Patterns to follow:** +- Existing `EMPTY_LLM_RESPONSE` / `MAX_STEPS_REACHED` style in `fallback.py` (Chinese, "建议:..." format) +- `VerificationResult.errors: list[str]` field pattern from `verification_loop.py:18-24` +- `TaskResult.to_dict` pattern at `protocol.py:132-145` + +**Test scenarios:** +- Happy path: `classify(TaskTimeoutError(...))` returns `EmergencyError(error_code="timeout", retryable=True)` +- Each exception type maps to correct `error_code` +- `TaskCancelledError` is NOT classified (caller must check before invoking `classify`) +- `to_dict()` produces all 5 fields +- `to_error_message()` formats suggestions as "建议:1) ... 2) ..." +- `EmergencyRules.classify(Exception("unknown"))` → `error_code="internal_error"`, `retryable=False` +- Config override: custom suggestion list for `timeout` rule via config dict +- `TaskResult` with `error_struct` set: `to_dict()` includes both `error_message` and `error_struct` +- `TaskResult` with `error_struct=None` (default): `to_dict()` matches current behavior (backward compat) + +**Verification:** All scenarios pass; `fallback.py` existing 3 constants unchanged; `TaskResult.to_dict` for tasks without `error_struct` matches pre-change output byte-for-byte. + +--- + +### U3. G7 — Three-Tier Fallback Chain Wiring + +**Goal:** Wire main → Recovery (ReflexionEngine) → Emergency (EmergencyRules) at `chat.py:613`. Composes U2's infrastructure with existing `ReflexionEngine`. + +**Requirements:** R16, R18, R26 + +**Dependencies:** U2 (EmergencyRules + TaskResult.error_struct) + +**Files:** +- Modify: `src/agentkit/server/routes/chat.py` (wrap main agent call at L613 with three-tier chain) +- Modify: `src/agentkit/server/config.py` (add `fallback_chain` section to `from_dict`) +- Modify: `agentkit.yaml` (document `fallback_chain:` section) +- Create: `tests/unit/test_fallback_chain.py` + +**Approach:** + +New helper module or inline function in `chat.py`: +``` +async def execute_with_fallback_chain( + agent: ConfigDrivenAgent, + task: Task, + config: dict, # fallback_chain config section + llm_gateway: LLMGateway, +) -> TaskResult: + # Tier 1: Main + try: + result = await agent.execute(task) + if result.status == "success" or result.status == "completed": + return result + # Treat non-success as soft failure → trigger Recovery + raise AgentExecutionError(result.error_message or "main agent did not succeed") + except (TaskTimeoutError, LoopDetectedError, LLMProviderError, AgentExecutionError) as exc: + if not config.get("recovery", {}).get("enabled", True): + return _to_emergency(exc, task) + + # Tier 2: Recovery (ReflexionEngine) + try: + reflexion_engine = ReflexionEngine( + llm_gateway=llm_gateway, + max_reflections=config.get("recovery", {}).get("max_retries", 1), + # ... other params from agent's existing reflexion config + ) + recovery_result = await reflexion_engine.execute( + messages=task.messages, # rebuild from task + tools=agent.get_tools(), + model=task.model, + agent_name=agent.name, + task_id=task.task_id, + ) + if recovery_result.status == "success": + return _recovery_to_task_result(recovery_result, task) + except Exception as recovery_exc: + logger.warning(f"Recovery layer failed: {recovery_exc}") + + # Tier 3: Emergency + return _to_emergency(exc, task) +``` + +`_to_emergency(exc, task)` constructs `EmergencyError` via `EmergencyRules.classify(exc, config)`, then returns `TaskResult(status=FAILED, error_message=emergency.to_error_message(), error_struct=emergency.to_dict(), ...)`. + +Wiring at `chat.py:613`: replace direct `agent.execute(task)` call with `execute_with_fallback_chain(...)`. Recovery config from `server_config.fallback_chain` (new `ServerConfig.fallback_chain: dict` field, mirroring Wave 1 pattern). + +**Recovery layer scope (KTD5):** Only `chat.py:613` is wrapped. CLI (`cli/chat.py:338`), ReWOO (`rewoo.py:1204`), ReflexionEngine's internal ReAct call, and `config_driven.py:695` are NOT wrapped (would create recursive loop or unwanted coupling). These remain on the direct-execute path. Documented in `## Scope Boundaries`. + +**Patterns to follow:** +- `config_driven.py:836` — existing `ReflexionEngine` instantiation pattern (constructor params, execute call) +- Wave 1's `verification` config section in `ServerConfig.from_dict` (`config.py:240-273`) + +**Test scenarios:** +- Happy path: main agent succeeds → no Recovery, no Emergency triggered; `error_struct=None` +- Main fails (timeout) → Recovery triggered → Recovery succeeds → `error_struct=None`, output from ReflexionEngine +- Main fails (timeout) → Recovery triggered → Recovery fails (max_reflections exhausted) → Emergency triggered → `error_struct` populated with `error_code="timeout"`, `retryable=True` +- Main fails (LoopDetectedError) → Emergency `error_code="loop_detected"` +- Main fails (LLMProviderError) → Emergency `error_code="llm_failure"` +- Main fails (TaskCancelledError) → propagates as-is (NOT routed to Emergency) +- Main fails (generic Exception) → Emergency `error_code="internal_error"`, `retryable=False` +- Config: `fallback_chain.recovery.enabled=false` → skip Recovery, go directly to Emergency +- Config: `fallback_chain.emergency.enabled=false` → re-raise original exception (no Emergency) +- Integration: full chain on real `ConfigDrivenAgent` with mocked LLM (mock main raises, mock ReflexionEngine succeeds) + +**Verification:** All scenarios pass; existing chat WebSocket tests still pass; ruff clean. + +--- + +### U4. G9 — PlanPhase Rollback Fields + RollbackExecutor + TeamOrchestrator Integration + +**Goal:** Add `validation_command`/`rollback_command` optional fields to `PlanPhase`; execute rollback on phase failure (opt-in); coordinate with U7 checkpoint ordering per R21. + +**Requirements:** R19, R20, R21, R26, R27 + +**Dependencies:** none (uses existing U7 `PipelineCheckpoint` and `VerificationLoop` patterns) + +**Files:** +- Modify: `src/agentkit/experts/plan.py` (`PlanPhase` dataclass + `to_dict`/`from_dict` symmetry) +- Modify: `src/agentkit/experts/orchestrator.py` (insert rollback execution between phase failure and checkpoint save) +- Create: `src/agentkit/orchestrator/rollback.py` (`RollbackExecutor` class) +- Modify: `src/agentkit/server/config.py` (add `rollback` section to `from_dict`) +- Modify: `agentkit.yaml` (document `rollback:` section) +- Create: `tests/unit/test_phase_rollback.py` + +**Approach:** + +**PlanPhase extension (`plan.py:147-233`):** Add two optional fields with `None` default (preserves existing contract per Finding 1): +``` +validation_command: str | None = None +rollback_command: str | None = None +``` +Update `to_dict()` to include both fields (only when not None, to keep dict shape minimal). Update `from_dict()` to read both fields. + +**`RollbackExecutor` class (new file `orchestrator/rollback.py`):** Mirrors `VerificationLoop` pattern (`verification_loop.py:67-103`): +``` +class RollbackExecutor: + def __init__(self, working_dir: str | None = None, timeout: float = 30.0): ... + + async def execute(self, command: str) -> RollbackResult: + # asyncio.create_subprocess_shell with cwd/timeout/proc.kill() + # Returns RollbackResult(passed, exit_code, stdout, stderr, command) + + async def validate(self, command: str) -> RollbackResult: + # Same as execute but returns passed=False on non-zero exit +``` + +`RollbackResult` dataclass: `{passed: bool, exit_code: int, stdout: str, stderr: str, command: str}`. + +**TeamOrchestrator integration (`orchestrator.py:246-270`):** New ordering per KTD8: + +``` +# Existing: phase failure detected, mark FAILED + dependents FAILED +# (lines 247-261 unchanged) + +# NEW: rollback phase (only if validation_command and rollback_command configured) +should_save_checkpoint = True +if ph.validation_command and ph.rollback_command: + validator = RollbackExecutor(working_dir=self._workspace_root, timeout=...) + validation_result = await validator.validate(ph.validation_command) + if not validation_result.passed: + rollback_result = await validator.execute(ph.rollback_command) + if not rollback_result.passed: + # Rollback failed → don't save checkpoint (R21) + should_save_checkpoint = False + logger.error(f"Rollback failed for phase {ph.id}: {rollback_result.stderr}") + # Emit phase_rollback_failed event + await self._broadcast_event("phase_rollback_failed", {...}) + +# Existing: checkpoint save (conditional now) +if should_save_checkpoint and self._checkpoint is not None: + try: + await self._checkpoint.save(plan.id, ph, plan.status.value) + except Exception as e: + logger.warning(...) +``` + +**Events emitted (new):** `phase_rollback_started`, `phase_rollback_completed`, `phase_rollback_failed`. Existing `phase_failed` event unchanged (emitted before rollback). + +**Config:** `rollback.default_timeout: float = 30.0` in `ServerConfig.rollback` dict (Wave 1 pattern). + +**Security boundary (KTD7):** Rollback subprocess does NOT go through `ShellTool` (avoids `confirm_callback` for `git checkout`). Audit log emitted via `_broadcast_event`. `terminal_whitelist` is not consulted (only governs `/api/v1/terminal/server` route, per `terminal_server.py:227`). + +**Execution note:** characterization-first. Test current behavior (`rollback_command=None` → no rollback, checkpoint saved) before adding rollback behavior. + +**Patterns to follow:** +- `VerificationLoop` subprocess execution pattern (`verification_loop.py:67-103`) — `asyncio.create_subprocess_shell` + `cwd` + `timeout` + `proc.kill()` +- `PlanPhase.to_dict`/`from_dict` symmetric serialization pattern at `plan.py:186-233` +- `TeamOrchestrator._execute_phase` event broadcasting pattern (`orchestrator.py:253-260`) +- Wave 1's `ServerConfig.from_dict` extension template + +**Test scenarios:** +- Characterization: `PlanPhase()` with no `validation_command`/`rollback_command` → `to_dict()` output matches pre-change shape (no new keys); `from_dict({})` produces phase with both fields as None +- Serialization: phase with `rollback_command="git checkout foo.py"` → `to_dict()` includes key; `from_dict(to_dict())` round-trips +- RollbackExecutor happy path: `execute("git status")` returns `passed=True`, `exit_code=0` +- RollbackExecutor timeout: `execute("sleep 10", timeout=0.1)` returns `passed=False`, exit_code=-1 (or similar) +- RollbackExecutor failure: `execute("false")` returns `passed=False`, `exit_code=1` +- Integration — opt-in default: phase fails, `rollback_command=None` → no RollbackExecutor call, checkpoint saved (existing behavior) +- Integration — rollback configured, validation passes (returns 0): rollback NOT executed, checkpoint saved +- Integration — rollback configured, validation fails, rollback succeeds (returns 0): rollback executed, checkpoint saved +- Integration — rollback configured, validation fails, rollback fails (returns 1): rollback executed, checkpoint NOT saved (R21), `phase_rollback_failed` event emitted +- Integration — real git repo fixture: phase writes file via `git checkout`, rollback `git checkout foo.py` restores file; assert file content matches pre-phase state + +**Verification:** All scenarios pass; existing `tests/unit/test_pipeline_state.py` and `tests/unit/test_team_orchestrator*.py` (if any) still pass; ruff clean. + +--- + +## Scope Boundaries + +### Deferred to Follow-Up Work + +- **Recovery wiring at non-chat call sites** (`cli/chat.py:338`, `rewoo.py:1204`, `config_driven.py:695`). KTD5 limits Wave 2 to the primary chat WebSocket path. Other entry points can adopt the same wrapper in a follow-up. +- **Recovery layer streaming events.** Wave 2's chain returns final `TaskResult`; SSE events for "recovery_started"/"recovery_completed" would require deeper WebSocket protocol changes. +- **Patch-level rollback** (`git apply -R `). KTD6 + KTD7 scope Wave 2 to `git checkout ` pattern. Patch-level requires extending `CheckpointData` schema to track patches — Wave 3 candidate. +- **DB-backed atomic claim for checkpoint.** Finding 3 noted `PipelineCheckpoint` is in-memory dict + Redis fallback, not a PostgreSQL `FOR UPDATE SKIP LOCKED` pattern. Multi-process atomicity is out of scope; single-process `asyncio.Lock` (already present at `orchestrator.py:53`) is sufficient. + +### Outside this product's identity (carried from brainstorm) + +- Wave 3 (G5/G6) — tree-sitter integration, SOLO four-stage state machine. Strategic direction locked in brainstorm KTD6/KTD7; implementation design deferred to Wave 3 plan. +- Node-level checkpoint (ReAct single-step). Stage-level (U7) satisfies core need. +- DeerFlow-style disk filesystem. Redis is the persistence layer. +- Full LangGraph migration. Self-built architecture stays. + +--- + +## Risks & Dependencies + +- **Risk: Auxiliary model availability.** If `auxiliary_model` alias is not configured in `agentkit.yaml`, `LLMGateway` raises `ModelNotFoundError`. Mitigation: `ContextCompressor` catches this in the auxiliary try-block and falls through to main model (R15 fallback path). Documented in YAML comments. +- **Risk: Recovery layer recursive loop.** `ReflexionEngine.execute()` internally calls `ReActEngine.execute()` (5th call site). If Recovery itself fails, it must NOT trigger another Recovery — KTD5 wires only at `chat.py:613`, so ReflexionEngine's internal ReAct call bypasses the chain. Recursive loop is structurally impossible. +- **Risk: `git checkout` destructive scope.** Misconfigured `rollback_command` (e.g., `git checkout .`) could wipe unrelated changes. Mitigation: rollback is opt-in per-phase (KTD6); audit log via `phase_rollback_started`/`phase_rollback_failed` events; documented YAML examples use file-scoped commands (`git checkout `). +- **Risk: `TaskResult.error_struct` field addition breaks pickle/serialization.** `TaskResult` is a dataclass with `to_dict`; new optional field with `None` default is backward-compatible for `to_dict` consumers. Pickle deserialization of pre-change data still works (new field defaults to None). +- **Dependency: Wave 1 merged.** `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections) is established and tested. Wave 2's 3 new sections (auxiliary_model, fallback_chain, rollback) reuse this exact pattern. +- **Dependency: U7 PipelineCheckpoint already shipped.** `orchestrator/checkpoint.py:56` exists with `save(plan_id, phase, plan_status)` API; U4 only adjusts the calling order, not the API. + +--- + +## Sources / Research + +- Brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 2 section, R13-R21, KTD4/KTD5) +- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (ServerConfig.from_dict extension template established) +- Finding 1 (contract preservation rule): `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — "新字段默认值须保持既有契约" drives KTD6's opt-in default for `rollback_command` +- Finding 2 (retry-storm defense): `docs/solutions/security-issues/portal-platform-security-reliability-fixes.md` — Emergency layer's terminal `error_code` pattern (don't propagate unhandled exceptions) +- Finding 3 (atomic claim pattern): `docs/solutions/architecture-patterns/bitable-companion-service-security-reliability-patterns.md` — `SKIP LOCKED` not reusable; `PipelineCheckpoint` is in-memory dict + Redis +- Finding 4 (empty-response anti-pattern): `docs/solutions/ui-bugs/tauri-reload-loses-session.md` — auxiliary LLM must check truthy return value, not just absence of exception; drives U1's empty-content fallback +- Code locations verified during planning: + - `src/agentkit/core/compressor.py:35-44,123-158` — `_summarize` injection point for G4 + - `src/agentkit/llm/gateway.py:170-227` — existing per-model fallback chain (semantics differ from G4's task-type routing) + - `src/agentkit/llm/config.py:198-257` — `LLMConfig` dataclass and `from_dict` + - `src/agentkit/core/reflexion.py:68-111` — `ReflexionEngine.__init__` and `execute()` signatures (already supports `evaluate_model`/`reflect_model`) + - `src/agentkit/core/fallback.py` (full file, 19 lines) — 3 existing constants; `EmergencyRules` adds alongside + - `src/agentkit/core/protocol.py:118-145` — `TaskResult` dataclass and `to_dict` + - `src/agentkit/experts/plan.py:147-233` — `PlanPhase` dataclass, `to_dict`/`from_dict` + - `src/agentkit/experts/orchestrator.py:246-270` — phase failure capture + checkpoint save site (U4 integration point) + - `src/agentkit/orchestrator/checkpoint.py:56` — `PipelineCheckpoint` (U7, no API change needed) + - `src/agentkit/core/verification_loop.py:67-103` — subprocess execution pattern for `RollbackExecutor` + - `src/agentkit/tools/shell.py:168-174,526-534` — `_DANGEROUS_BINARY_FLAGS` (git checkout not listed), `_is_dangerous` logic (KTD7 rationale) diff --git a/src/agentkit/core/compressor.py b/src/agentkit/core/compressor.py index 2a3ecca..e0d2a90 100644 --- a/src/agentkit/core/compressor.py +++ b/src/agentkit/core/compressor.py @@ -41,6 +41,7 @@ class ContextCompressor: model_context_limit: int = 128_000, headroom_threshold: float = 0.8, min_tokens: int = 8_000, + auxiliary_model: str | None = None, ): self._llm_gateway = llm_gateway self._max_tokens = max_tokens @@ -51,6 +52,11 @@ class ContextCompressor: self._model_context_limit = model_context_limit self._headroom_threshold = headroom_threshold self._min_tokens = min_tokens + # G4/U1: Auxiliary model for cost-sensitive summarization (e.g. "fast" alias). + # When set and differs from main model, _summarize tries auxiliary first, + # falls back to main model on failure OR empty content (Finding 4 anti-pattern). + # ponytail: ceiling — auxiliary is best-effort; main model is authoritative fallback. + self._auxiliary_model = auxiliary_model def should_compress(self, messages: list[dict]) -> bool: """Check if compression should be triggered based on headroom ratio. @@ -92,8 +98,8 @@ class ContextCompressor: if len(non_system) <= self._keep_recent: return messages # Not enough messages to compress - old_msgs = non_system[:-self._keep_recent] - recent_msgs = non_system[-self._keep_recent:] + old_msgs = non_system[: -self._keep_recent] + recent_msgs = non_system[-self._keep_recent :] # Compress old messages summary = await self._summarize(old_msgs) @@ -101,10 +107,12 @@ class ContextCompressor: # Build compressed message list compressed = list(system_msgs) if summary: - compressed.append({ - "role": "system", - "content": f"## Conversation Summary\n{summary}", - }) + compressed.append( + { + "role": "system", + "content": f"## Conversation Summary\n{summary}", + } + ) compressed.extend(recent_msgs) # Recursive check: if still over budget, compress again @@ -114,22 +122,30 @@ class ContextCompressor: return self._truncate(compressed) if len(recent_msgs) > 1: # Try keeping fewer recent messages - return await self._compress_aggressive(messages, _compression_depth=_compression_depth + 1) + return await self._compress_aggressive( + messages, _compression_depth=_compression_depth + 1 + ) # Last resort: truncate return self._truncate(compressed) return compressed async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str: - """Summarize a list of messages using LLM""" + """Summarize a list of messages using LLM. + + G4/U1: When ``auxiliary_model`` is configured and differs from the main + model, try auxiliary first (cost-optimization). On auxiliary failure OR + empty content (Finding 4 anti-pattern — "did not throw is not succeeded"), + fall back to main model. Existing ``_simple_summary`` degradation + preserved as the final tier when main model also fails. + """ if not self._llm_gateway: # No LLM available, do simple truncation return self._simple_summary(messages) # Build summary prompt conversation_text = "\n".join( - f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" - for m in messages + f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" for m in messages ) # Pre-truncate if conversation_text exceeds safe token threshold @@ -145,6 +161,25 @@ class ContextCompressor: f"{conversation_text}" ) + # G4: Try auxiliary model first when configured (cheap route). + if self._auxiliary_model and self._auxiliary_model != self._model: + try: + response = await self._llm_gateway.chat( + messages=[{"role": "user", "content": prompt}], + model=self._auxiliary_model, + agent_name="compressor", + task_type="summarization", + ) + # Finding 4: empty content is a failure, not a success. + if response.content and response.content.strip(): + return response.content + logger.info("Auxiliary model returned empty content, falling back to main model") + except Exception as e: + logger.info( + f"Auxiliary model summarization failed, falling back to main model: {e}" + ) + + # Main model path (or auxiliary fallback). try: response = await self._llm_gateway.chat( messages=[{"role": "user", "content": prompt}], @@ -166,7 +201,9 @@ class ContextCompressor: parts.append(f"[{role}]: {content}...") return "\n".join(parts) - async def _compress_aggressive(self, messages: list[dict], _compression_depth: int = 0) -> list[dict]: + async def _compress_aggressive( + self, messages: list[dict], _compression_depth: int = 0 + ) -> list[dict]: """More aggressive compression when standard compression isn't enough""" system_msgs = [m for m in messages if m.get("role") == "system"] non_system = [m for m in messages if m.get("role") != "system"] @@ -176,10 +213,12 @@ class ContextCompressor: summary = await self._summarize(non_system[:-1]) compressed = list(system_msgs) if summary: - compressed.append({ - "role": "system", - "content": f"## Conversation Summary\n{summary}", - }) + compressed.append( + { + "role": "system", + "content": f"## Conversation Summary\n{summary}", + } + ) compressed.append(non_system[-1]) return compressed @@ -191,7 +230,7 @@ class ContextCompressor: for msg in messages: content = str(msg.get("content", "")) if len(content) > self._max_tokens * 4: - msg = {**msg, "content": content[:self._max_tokens * 4] + "...[truncated]"} + msg = {**msg, "content": content[: self._max_tokens * 4] + "...[truncated]"} result.append(msg) return result @@ -226,6 +265,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate if provider == "headroom": try: from agentkit.core.headroom_compressor import HeadroomCompressor + compressor = HeadroomCompressor(config) if compressor.is_available(): return compressor @@ -235,8 +275,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate ) except ImportError: logger.warning( - "HeadroomCompressor module not available. " - "Falling back to ContextCompressor." + "HeadroomCompressor module not available. Falling back to ContextCompressor." ) # Fallback to summary compressor return ContextCompressor( @@ -253,11 +292,9 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]: """Render PromptTemplate with caching - returns cached result for same variables""" - cache_key = hashlib.md5( - json.dumps(variables or {}, sort_keys=True).encode() - ).hexdigest() + cache_key = hashlib.md5(json.dumps(variables or {}, sort_keys=True).encode()).hexdigest() - if not hasattr(template, '_render_cache'): + if not hasattr(template, "_render_cache"): template._render_cache = {} if cache_key in template._render_cache: @@ -270,5 +307,5 @@ def render_cached(template, variables: dict[str, Any] | None = None) -> list[dic def clear_cache(template) -> None: """Clear the render cache on a PromptTemplate instance""" - if hasattr(template, '_render_cache'): + if hasattr(template, "_render_cache"): template._render_cache.clear() diff --git a/src/agentkit/core/fallback.py b/src/agentkit/core/fallback.py index c3f7455..1c28549 100644 --- a/src/agentkit/core/fallback.py +++ b/src/agentkit/core/fallback.py @@ -2,8 +2,22 @@ All layers (ReActEngine, Portal, Chat) should use these constants to ensure consistent user-facing messages. + +G7/U2: Also hosts ``EmergencyError`` and ``EmergencyRules`` for the +three-tier fallback chain's Emergency layer (rule-based classifier, +no LLM). See ``docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md``. """ +from dataclasses import dataclass +from typing import Any + +from agentkit.core.exceptions import ( + LLMProviderError, + LoopDetectedError, + TaskCancelledError, + TaskTimeoutError, +) + # When LLM returns empty content after all fallback models exhausted EMPTY_LLM_RESPONSE = ( "模型未返回有效内容,已尝试备用模型仍未成功。" @@ -16,3 +30,135 @@ MAX_STEPS_REACHED = "已达到最大推理步数,但仍未得到完整结论 # When a shell command succeeds but produces no output SHELL_NO_OUTPUT = "[命令执行成功,无输出内容]" + + +# ── G7/U2: Emergency layer ────────────────────────────────────── + + +@dataclass +class EmergencyError: + """Structured error produced by the Emergency layer (rule-classified). + + Carries a stable ``error_code`` for programmatic dispatch (frontend + retry UI, telemetry), a human-readable ``message`` mirroring + ``EMPTY_LLM_RESPONSE`` style, actionable ``suggestions``, and the + original exception string for traceability. + + The ``retryable`` flag distinguishes recoverable user errors + (timeout, loop, LLM hiccup) from internal bugs (retryable=False). + """ + + error_code: str # "timeout" | "loop_detected" | "llm_failure" | "internal_error" + message: str # human-readable Chinese message + suggestions: list[str] # actionable user-facing suggestions + retryable: bool # whether a user retry might succeed + original_error: str # str(exc) for traceability + + def to_dict(self) -> dict[str, Any]: + return { + "error_code": self.error_code, + "message": self.message, + "suggestions": list(self.suggestions), + "retryable": self.retryable, + "original_error": self.original_error, + } + + def to_error_message(self) -> str: + """Format as a single human-readable string with suggestions. + + Mirrors ``EMPTY_LLM_RESPONSE`` style: ``建议:1) ... 2) ...`` + """ + if not self.suggestions: + return self.message + suggestion_str = "".join(f"{i}) {s};" for i, s in enumerate(self.suggestions, 1)) + # Strip trailing ";" and prefix with "建议:" + suggestion_str = suggestion_str.rstrip(";") + return f"{self.message}建议:{suggestion_str}。" + + +# Default rule set: (exception_type, error_code, message, suggestions, retryable) +# ponytail: ceiling — rule-based, no LLM. Adding a new rule = append a tuple. +# Upgrade path: LLM-driven suggestions would require a separate classifier +# class (out of scope per brainstorm KTD4). +_DEFAULT_RULES: list[tuple[type[Exception], str, str, list[str], bool]] = [ + ( + TaskTimeoutError, + "timeout", + "任务执行超时。", + ["稍后重试", "简化任务范围"], + True, + ), + ( + LoopDetectedError, + "loop_detected", + "检测到推理循环。", + ["拆分任务", "检查工具参数"], + True, + ), + ( + LLMProviderError, + "llm_failure", + "LLM 服务调用失败。", + ["稍后重试", "切换模型"], + True, + ), +] + +_DEFAULT_ERROR_CODE = "internal_error" +_DEFAULT_MESSAGE = "Agent 执行内部错误。" +_DEFAULT_SUGGESTIONS: list[str] = ["联系管理员"] +_DEFAULT_RETRYABLE = False + + +class EmergencyRules: + """Rule-based classifier for the Emergency layer. + + Maps exception types to ``EmergencyError`` instances. No LLM, no I/O — + pure function of ``(exception, config)``. + + Caller responsibility: ``TaskCancelledError`` MUST propagate as-is + (per KTD3); the caller checks ``isinstance(exc, TaskCancelledError)`` + before invoking :meth:`classify`. Calling :meth:`classify` with a + ``TaskCancelledError`` raises ``ValueError`` to surface the bug. + + Config override shape (optional ``config`` arg): + + ```python + { + "timeout": {"suggestions": ["自定义建议"], "retryable": False}, + "llm_failure": {"message": "自定义消息"}, + } + ``` + """ + + @staticmethod + def classify(exc: Exception, config: dict | None = None) -> EmergencyError: + if isinstance(exc, TaskCancelledError): + # Contract: caller must check before invoking classify. + raise ValueError( + "TaskCancelledError must propagate as-is; caller must check " + "before invoking EmergencyRules.classify" + ) + + config = config or {} + + for exc_type, code, message, suggestions, retryable in _DEFAULT_RULES: + if isinstance(exc, exc_type): + override = config.get(code, {}) if isinstance(config, dict) else {} + return EmergencyError( + error_code=code, + message=override.get("message", message), + suggestions=override.get("suggestions", list(suggestions)), + retryable=override.get("retryable", retryable), + original_error=str(exc), + ) + + # Generic fallback + override = config.get(_DEFAULT_ERROR_CODE, {}) if isinstance(config, dict) else {} + return EmergencyError( + error_code=_DEFAULT_ERROR_CODE, + message=override.get("message", _DEFAULT_MESSAGE), + suggestions=override.get("suggestions", list(_DEFAULT_SUGGESTIONS)), + retryable=override.get("retryable", _DEFAULT_RETRYABLE), + original_error=str(exc), + ) diff --git a/src/agentkit/core/protocol.py b/src/agentkit/core/protocol.py index 99ac0aa..60ed289 100644 --- a/src/agentkit/core/protocol.py +++ b/src/agentkit/core/protocol.py @@ -128,6 +128,10 @@ class TaskResult: completed_at: datetime metrics: dict | None = None trace: Any | None = None + # G7/U2: Emergency layer structured error. None preserves existing contract + # (error_message alone carries the human-readable string). When set, + # carries serialized EmergencyError.to_dict() for programmatic dispatch. + error_struct: dict | None = None def to_dict(self) -> dict: d = { @@ -142,6 +146,8 @@ class TaskResult: } if self.trace is not None: d["trace"] = self.trace.to_dict() if hasattr(self.trace, "to_dict") else self.trace + if self.error_struct is not None: + d["error_struct"] = self.error_struct return d @classmethod @@ -162,6 +168,7 @@ class TaskResult: completed_at=completed_at or datetime.now(timezone.utc), metrics=data.get("metrics"), trace=data.get("trace"), + error_struct=data.get("error_struct"), ) diff --git a/src/agentkit/experts/orchestrator.py b/src/agentkit/experts/orchestrator.py index 990e6a4..4ed80ed 100644 --- a/src/agentkit/experts/orchestrator.py +++ b/src/agentkit/experts/orchestrator.py @@ -30,6 +30,7 @@ from typing import Any from agentkit.core.config_driven import ConfigDrivenAgent from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus from agentkit.llm.gateway import LLMGateway +from agentkit.orchestrator.rollback import RollbackExecutor from .expert import Expert from .plan import ( @@ -72,12 +73,17 @@ class TeamOrchestrator: MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰 STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"}) + # G9/U4: RollbackExecutor default timeout for validation_command / rollback_command. + # Override via constructor `rollback_timeout` from `rollback.default_timeout` config. + DEFAULT_ROLLBACK_TIMEOUT = 30.0 def __init__( self, team: ExpertTeam, max_concurrent_phases: int | None = None, checkpoint: Any = None, + workspace_root: str | None = None, + rollback_timeout: float | None = None, ) -> None: self._team = team # Track temporary agent names created for context isolation (KTD3) @@ -93,6 +99,10 @@ class TeamOrchestrator: self._phase_semaphore = asyncio.Semaphore(limit) # U7: Pipeline checkpoint for crash recovery self._checkpoint = checkpoint + # G9/U4: workspace_root drives RollbackExecutor cwd; rollback_timeout drives its timeout. + # Both default to no-op-friendly values so existing call sites behave identically. + self._workspace_root = workspace_root + self._rollback_timeout = rollback_timeout or self.DEFAULT_ROLLBACK_TIMEOUT async def execute(self, task: str) -> dict[str, Any]: """Execute a task in pipeline mode. @@ -262,8 +272,23 @@ class TeamOrchestrator: else: phase_results[ph.id] = result + # G9/U4: opt-in rollback (KTD6) + checkpoint ordering (R21). + # When phase configures both validation_command and rollback_command: + # 1. run validation_command — if it passes, treat phase as recoverable, save checkpoint + # 2. if validation fails, run rollback_command + # 3. if rollback passes (exit 0), save checkpoint + # 4. if rollback fails, skip checkpoint (R21 — avoid persisting broken state) + # When neither command is set, behavior is unchanged (existing save). + should_save_checkpoint = True + if ( + ph.validation_command + and ph.rollback_command + and isinstance(result, (Exception, asyncio.CancelledError)) + ): + should_save_checkpoint = await self._run_phase_rollback(plan, ph) + # U7: Save checkpoint after phase finalizes (success or failure) - if self._checkpoint is not None: + if should_save_checkpoint and self._checkpoint is not None: try: await self._checkpoint.save(plan.id, ph, plan.status.value) except Exception as e: @@ -393,9 +418,7 @@ class TeamOrchestrator: # PENDING phases remain PENDING — will be executed by _run_pipeline # P2 #8: Restore debate count so MAX_DEBATES limit holds after resume - self._debate_count = sum( - 1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE - ) + self._debate_count = sum(1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE) logger.info( f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, " @@ -688,9 +711,9 @@ class TeamOrchestrator: and prev_phase.result ): # U4: Resolve offloaded content from workspace - collaboration_outputs[contract.from_expert] = ( - await self._read_dependency_output(prev_phase) - ) + collaboration_outputs[ + contract.from_expert + ] = await self._read_dependency_output(prev_phase) break # Emit expert_step event @@ -1809,6 +1832,75 @@ class TeamOrchestrator: # Recursively mark their dependents await self._mark_dependents_failed(ph.id, plan, phase_results) + async def _run_phase_rollback(self, plan: TeamPlan, ph: PlanPhase) -> bool: + """G9/U4: run validation_command + rollback_command for a failed phase. + + Returns True if checkpoint save should proceed (R21 ordering). + - Validation passes → save checkpoint (phase state recoverable) + - Validation fails, rollback passes → save checkpoint (rolled back state) + - Validation fails, rollback fails → skip checkpoint (broken state) + - Subprocess spawn failure or timeout → skip checkpoint + """ + executor = RollbackExecutor( + working_dir=self._workspace_root, + timeout=self._rollback_timeout, + ) + await self._broadcast_event( + "phase_rollback_started", + { + "plan_id": plan.id, + "phase_id": ph.id, + "phase_name": ph.name, + "validation_command": ph.validation_command, + "rollback_command": ph.rollback_command, + }, + ) + # ponytail: validate first; if validation passes, rollback is skipped (no need). + validation = await executor.validate(ph.validation_command or "") + if validation.passed: + await self._broadcast_event( + "phase_rollback_completed", + { + "plan_id": plan.id, + "phase_id": ph.id, + "phase_name": ph.name, + "rollback_executed": False, + "validation_passed": True, + }, + ) + return True + + rollback = await executor.execute(ph.rollback_command or "") + if rollback.passed: + await self._broadcast_event( + "phase_rollback_completed", + { + "plan_id": plan.id, + "phase_id": ph.id, + "phase_name": ph.name, + "rollback_executed": True, + "validation_passed": False, + "rollback_stdout": rollback.stdout, + }, + ) + return True + + logger.error( + f"Rollback failed for phase {ph.id} ({ph.name}): exit={rollback.exit_code} stderr={rollback.stderr}" + ) + await self._broadcast_event( + "phase_rollback_failed", + { + "plan_id": plan.id, + "phase_id": ph.id, + "phase_name": ph.name, + "validation_passed": False, + "rollback_exit_code": rollback.exit_code, + "rollback_stderr": rollback.stderr, + }, + ) + return False + async def _synthesize_results( self, lead: Expert, task: str, completed_phases: list[PlanPhase] ) -> dict[str, Any]: diff --git a/src/agentkit/experts/plan.py b/src/agentkit/experts/plan.py index a8cd52c..2a95cf0 100644 --- a/src/agentkit/experts/plan.py +++ b/src/agentkit/experts/plan.py @@ -182,6 +182,11 @@ class PlanPhase: collaboration_contracts: list[CollaborationContract] = field(default_factory=list) rework_count: int = 0 review_feedback: str | None = None + # G9/U4: opt-in rollback fields. When unset, no rollback executes (KTD6). + # validation_command runs first; if it fails, rollback_command runs. + # canonical rollback pattern: `git checkout `. + validation_command: str | None = None + rollback_command: str | None = None def to_dict(self) -> dict[str, Any]: """序列化为字典""" @@ -192,7 +197,7 @@ class PlanPhase: result_str = self.result.get("content", str(self.result)) else: result_str = str(self.result) - return { + out: dict[str, Any] = { "id": self.id, "name": self.name, "assigned_expert": self.assigned_expert, @@ -206,6 +211,12 @@ class PlanPhase: "rework_count": self.rework_count, "review_feedback": self.review_feedback, } + # G9/U4: only include new keys when set, to preserve pre-change dict shape (KTD6). + if self.validation_command is not None: + out["validation_command"] = self.validation_command + if self.rollback_command is not None: + out["rollback_command"] = self.rollback_command + return out @classmethod def from_dict(cls, data: dict[str, Any]) -> PlanPhase: @@ -230,6 +241,8 @@ class PlanPhase: collaboration_contracts=contracts, rework_count=data.get("rework_count", 0), review_feedback=data.get("review_feedback"), + validation_command=data.get("validation_command"), + rollback_command=data.get("rollback_command"), ) diff --git a/src/agentkit/llm/config.py b/src/agentkit/llm/config.py index 6357582..1ee2aa9 100644 --- a/src/agentkit/llm/config.py +++ b/src/agentkit/llm/config.py @@ -203,6 +203,9 @@ class LLMConfig: model_aliases: dict[str, str] = field(default_factory=dict) fallbacks: dict[str, list[str]] = field(default_factory=dict) cache: CacheConfig | None = None + # G4/U1: Auxiliary model alias for cost-sensitive tasks (e.g. summarization). + # Resolved via existing model_aliases mechanism. None = use main model only. + auxiliary_model: str | None = None @classmethod def from_dict(cls, data: dict) -> "LLMConfig": @@ -254,4 +257,5 @@ class LLMConfig: model_aliases=data.get("model_aliases", {}), fallbacks=data.get("fallbacks", {}), cache=cache, + auxiliary_model=data.get("auxiliary_model"), ) diff --git a/src/agentkit/orchestrator/rollback.py b/src/agentkit/orchestrator/rollback.py new file mode 100644 index 0000000..15bf21b --- /dev/null +++ b/src/agentkit/orchestrator/rollback.py @@ -0,0 +1,97 @@ +"""RollbackExecutor — 阶段失败后执行回滚命令 (G9/U4) + +复用 VerificationLoop 的 asyncio.create_subprocess_shell 模式 (KTD7): +绕过 ShellTool,避免 confirm_callback 对 `git checkout` 的拦截。 + +设计依据: +- KTD6: 回滚是 opt-in 行为,未配置 rollback_command 时不会执行 +- KTD7: 不走 ShellTool,避免 _is_dangerous 触发 confirm_callback +- R21: checkpoint.save 仅在回滚校验通过后调用 +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class RollbackResult: + """单次 subprocess 执行结果""" + + passed: bool + exit_code: int + stdout: str + stderr: str + command: str + + +class RollbackExecutor: + """执行 validation_command / rollback_command 的子进程封装 + + 与 VerificationLoop 同构,但语义不同: + - validate(): 返回 passed=False 表示校验失败,需要触发 rollback + - execute(): 返回 passed=False 表示回滚本身失败,需跳过 checkpoint.save (R21) + """ + + def __init__( + self, + working_dir: str | None = None, + timeout: float = 30.0, + ) -> None: + self._working_dir = working_dir + self._timeout = timeout + + async def _run(self, command: str) -> RollbackResult: + try: + proc = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=self._working_dir, + ) + except Exception as e: # noqa: BLE001 - subprocess spawn failure surface + return RollbackResult( + passed=False, + exit_code=-1, + stdout="", + stderr=f"Failed to spawn command: {e}", + command=command, + ) + + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self._timeout) + except asyncio.TimeoutError: + try: + proc.kill() + except ProcessLookupError: + pass + await proc.wait() + return RollbackResult( + passed=False, + exit_code=-1, + stdout="", + stderr=f"Command timed out after {self._timeout}s: {command}", + command=command, + ) + + out_str = stdout.decode("utf-8", errors="replace") if stdout else "" + err_str = stderr.decode("utf-8", errors="replace") if stderr else "" + return RollbackResult( + passed=proc.returncode == 0, + exit_code=proc.returncode if proc.returncode is not None else -1, + stdout=out_str, + stderr=err_str, + command=command, + ) + + async def validate(self, command: str) -> RollbackResult: + """运行 validation_command,passed=False 表示需要触发 rollback""" + return await self._run(command) + + async def execute(self, command: str) -> RollbackResult: + """运行 rollback_command,passed=False 表示回滚本身失败 (R21)""" + return await self._run(command) diff --git a/src/agentkit/server/_fallback_chain.py b/src/agentkit/server/_fallback_chain.py new file mode 100644 index 0000000..48d080f --- /dev/null +++ b/src/agentkit/server/_fallback_chain.py @@ -0,0 +1,199 @@ +"""G7/U3 — Three-tier fallback chain (main → Recovery → Emergency). + +Wired at chat.py REST send_message endpoint. Composes U2's EmergencyRules +with existing ReflexionEngine for the Recovery layer. + +Scope (KTD5): Only the chat REST path is wrapped. CLI / ReWOO / Reflexion +internal ReAct calls are NOT wrapped (would create recursive loop). +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +from agentkit.core.exceptions import ( + LLMProviderError, + LoopDetectedError, + TaskCancelledError, + TaskTimeoutError, +) +from agentkit.core.fallback import EmergencyError, EmergencyRules +from agentkit.core.react import ReActEngine, ReActResult +from agentkit.core.reflexion import ReflexionEngine +from agentkit.llm.gateway import LLMGateway + +logger = logging.getLogger(__name__) + + +# ReActResult.status values that indicate soft failure → trigger Recovery. +# "success" is the only clean-pass; everything else is fallback-worthy. +_SOFT_FAILURE_STATUSES = frozenset({"empty_fallback", "verify_failed", "timeout"}) + + +@dataclass +class ChatExecutionResult: + """Wrapper produced by execute_with_fallback_chain. + + Carries a ReActResult-like ``output`` field plus an optional + ``error_struct`` (set only when Emergency tier fires). The chat + handler reads ``.output`` for the assistant reply and ``.error_struct`` + for the optional structured error payload. + """ + + output: str + status: str # "success" | "recovered" | "emergency" + error_struct: dict[str, Any] | None = None + trajectory: list[Any] = field(default_factory=list) + total_steps: int = 0 + total_tokens: int = 0 + fallback_strategy: str | None = None + + +def _react_to_chat_result(react: ReActResult) -> ChatExecutionResult: + return ChatExecutionResult( + output=react.output, + status="success", + trajectory=react.trajectory, + total_steps=react.total_steps, + total_tokens=react.total_tokens, + fallback_strategy=react.fallback_strategy, + ) + + +def _reflexion_to_chat_result(reflexion_result: Any) -> ChatExecutionResult: + """Best-effort conversion from ReflexionResult to ChatExecutionResult.""" + output = getattr(reflexion_result, "output", None) or getattr( + reflexion_result, "final_answer", "" + ) + return ChatExecutionResult( + output=output or "", + status="recovered", + trajectory=getattr(reflexion_result, "trajectory", []) or [], + total_steps=getattr(reflexion_result, "total_steps", 0), + total_tokens=getattr(reflexion_result, "total_tokens", 0), + fallback_strategy="reflexion_recovery", + ) + + +def _to_emergency(exc: Exception, config: dict | None) -> ChatExecutionResult: + emergency: EmergencyError = EmergencyRules.classify(exc, config) + return ChatExecutionResult( + output=emergency.to_error_message(), + status="emergency", + error_struct=emergency.to_dict(), + fallback_strategy="emergency", + ) + + +async def execute_with_fallback_chain( + *, + react_engine: ReActEngine, + llm_gateway: LLMGateway, + messages: list[dict[str, str]], + tools: list[Any] | None, + model: str, + agent_name: str, + system_prompt: str | None, + fallback_chain_config: dict | None = None, +) -> ChatExecutionResult: + """Three-tier fallback chain: Main → Recovery (ReflexionEngine) → Emergency. + + KTD5: only this entry point wraps the chain. ReflexionEngine's internal + ReAct call bypasses the chain (no recursive loop possible). + + Returns ChatExecutionResult with status: + - "success": main agent succeeded + - "recovered": main failed, ReflexionEngine recovery succeeded + - "emergency": main failed, recovery failed/exhausted, Emergency layer fired + """ + config = fallback_chain_config or {} + recovery_cfg = config.get("recovery", {}) if isinstance(config, dict) else {} + emergency_cfg = config.get("emergency", {}) if isinstance(config, dict) else {} + recovery_enabled = recovery_cfg.get("enabled", True) if isinstance(recovery_cfg, dict) else True + emergency_enabled = ( + emergency_cfg.get("enabled", True) if isinstance(emergency_cfg, dict) else True + ) + max_reflections = recovery_cfg.get("max_retries", 1) if isinstance(recovery_cfg, dict) else 1 + + # ── Tier 1: Main ────────────────────────────────────────────── + main_exc: Exception | None = None + try: + result = await react_engine.execute( + messages=messages, + tools=tools, + model=model, + agent_name=agent_name, + system_prompt=system_prompt, + ) + if result.status == "success": + return _react_to_chat_result(result) + # Soft failure (empty_fallback / verify_failed / timeout) → trigger Recovery + if result.status in _SOFT_FAILURE_STATUSES: + main_exc = AgentSoftFailureError( + f"main agent status={result.status}: {result.output[:200]}" + ) + else: + # Unknown status — treat as success-like (don't trigger recovery) + return _react_to_chat_result(result) + except TaskCancelledError: + # KTD3: TaskCancelledError propagates as-is, NOT routed to Emergency. + raise + except (TaskTimeoutError, LoopDetectedError, LLMProviderError) as exc: + main_exc = exc + except Exception as exc: # noqa: BLE001 - last-resort catch for Emergency routing + main_exc = exc + + # ── Tier 2: Recovery (ReflexionEngine) ──────────────────────── + if recovery_enabled and main_exc is not None: + try: + reflexion = ReflexionEngine( + llm_gateway=llm_gateway, + max_reflections=max_reflections, + ) + recovery_result = await reflexion.execute( + messages=messages, + tools=tools, + model=model, + agent_name=agent_name, + system_prompt=system_prompt, + ) + # Recovery succeeds if Reflexion reports success or produces output. + recovery_status = getattr(recovery_result, "status", "") + if recovery_status == "success" or getattr(recovery_result, "output", None): + return _reflexion_to_chat_result(recovery_result) + logger.warning( + f"Recovery layer did not succeed (status={recovery_status}), " + f"falling through to Emergency" + ) + except TaskCancelledError: + raise + except Exception as recovery_exc: # noqa: BLE001 + logger.warning(f"Recovery layer raised: {recovery_exc}; falling through to Emergency") + + # ── Tier 3: Emergency ───────────────────────────────────────── + if not emergency_enabled: + # Re-raise original exception if Emergency disabled. + if main_exc is not None: + raise main_exc + # No exception but no success either — synthesise an emergency-style result. + return ChatExecutionResult( + output="Agent 未返回有效结果且 Emergency 层已禁用。", + status="emergency", + fallback_strategy="emergency_disabled", + ) + + # main_exc may be None if main returned soft-failure status without raising. + # Synthesize a generic exception for Emergency classification. + exc_for_emergency = main_exc or AgentSoftFailureError("soft failure without exception") + return _to_emergency(exc_for_emergency, config) + + +class AgentSoftFailureError(Exception): + """Internal marker — main agent returned a soft-failure status without raising. + + Used to feed the Emergency classifier when main status was e.g. + ``empty_fallback`` (no exception raised, but result not usable). + Classified as ``internal_error`` by EmergencyRules (generic fallback). + """ diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index 1d67cf4..fcfe979 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -119,6 +119,8 @@ class ServerConfig: prompt_cache: dict[str, Any] | None = None, streaming: dict[str, Any] | None = None, verification: dict[str, Any] | None = None, + rollback: dict[str, Any] | None = None, + fallback_chain: dict[str, Any] | None = None, on_change: Callable[["ServerConfig"], None] | None = None, ): self.host = host @@ -153,6 +155,12 @@ class ServerConfig: # U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1) # verification_enabled=False 时此配置无效 self.verification = verification or {} + # G9/U4: rollback.default_timeout 控制 RollbackExecutor subprocess 超时 + # PlanPhase.rollback_command 未设置时此配置无效 (KTD6 opt-in) + self.rollback = rollback or {} + # G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries} + # controls three-tier chain at chat.py REST send_message (KTD5). + self.fallback_chain = fallback_chain or {} self.on_change = on_change # Config watching state @@ -240,6 +248,10 @@ class ServerConfig: prompt_cache_data = data.get("prompt_cache", {}) streaming_data = data.get("streaming", {}) verification_data = data.get("verification", {}) + # G9/U4: rollback 配置 (从 YAML 读取,opt-in) + rollback_data = data.get("rollback", {}) + # G7/U3: fallback_chain 配置 (从 YAML 读取) + fallback_chain_data = data.get("fallback_chain", {}) return cls( host=server.get("host", "0.0.0.0"), @@ -271,6 +283,8 @@ class ServerConfig: prompt_cache=prompt_cache_data, streaming=streaming_data, verification=verification_data, + rollback=rollback_data, + fallback_chain=fallback_chain_data, ) @staticmethod @@ -320,6 +334,9 @@ class ServerConfig: model_aliases=model_aliases, fallbacks=data.get("fallbacks", {}), cache=cache_config, + # G4/U1: auxiliary model alias for cost-sensitive summarization. + # Resolved via model_aliases; None = no auxiliary routing. + auxiliary_model=data.get("auxiliary_model"), ) @staticmethod diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index d0bc489..6737740 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -27,6 +27,7 @@ from pydantic import BaseModel from agentkit.chat.skill_routing import ExecutionMode from agentkit.core.protocol import CancellationToken from agentkit.core.react import ReActEngine +from agentkit.server._fallback_chain import execute_with_fallback_chain from agentkit.session.manager import SessionManager from agentkit.session.models import MessageRole, SessionStatus @@ -610,7 +611,15 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques system_prompt = getattr(agent, "_system_prompt", None) or ( agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None ) - result = await react_engine.execute( + # G7/U3: Three-tier fallback chain (main → Recovery → Emergency). + # Wired only here (KTD5); CLI / ReWOO / Reflexion internal ReAct bypass. + server_config = getattr(req.app.state, "server_config", None) + fallback_chain_cfg = ( + getattr(server_config, "fallback_chain", None) if server_config else None + ) + chat_result = await execute_with_fallback_chain( + react_engine=react_engine, + llm_gateway=req.app.state.llm_gateway, messages=chat_messages, tools=tools, model=agent.get_model() @@ -618,16 +627,26 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques else getattr(agent, "_llm_model", "default"), agent_name=agent.name, system_prompt=system_prompt, + fallback_chain_config=fallback_chain_cfg, ) # Append assistant reply assistant_msg = await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, - content=result.output if hasattr(result, "output") else str(result), + content=chat_result.output, agent_name=agent.name, ) - return _message_to_response(assistant_msg) + response = _message_to_response(assistant_msg) + # Attach structured error payload when Emergency tier fired. + if chat_result.error_struct is not None: + response_dict = ( + response.model_dump() if hasattr(response, "model_dump") else dict(response) + ) + response_dict["error_struct"] = chat_result.error_struct + response_dict["fallback_status"] = chat_result.status + return response_dict + return response except Exception as e: logger.error(f"Chat execution error for session {session_id}: {e}") diff --git a/tests/unit/test_compressor_auxiliary.py b/tests/unit/test_compressor_auxiliary.py new file mode 100644 index 0000000..daf8c70 --- /dev/null +++ b/tests/unit/test_compressor_auxiliary.py @@ -0,0 +1,400 @@ +"""G4/U1 — Auxiliary LLM routing in ContextCompressor. + +Verifies: +- auxiliary_model routes _summarize through the cheaper model first +- empty content (Finding 4 anti-pattern) triggers fallback to main model +- auxiliary exception triggers fallback to main model +- both auxiliary and main failing falls through to _simple_summary +- auxiliary_model=None preserves existing single-model behavior (characterization) +- config wiring (LLMConfig.from_dict, ServerConfig._build_llm_config) +""" + +from unittest.mock import AsyncMock, MagicMock + +from agentkit.core.compressor import ContextCompressor +from agentkit.llm.config import LLMConfig +from agentkit.llm.protocol import LLMResponse, TokenUsage + + +# ── Helpers ────────────────────────────────────────── + + +def make_gateway_with_response(content: str, model: str = "test") -> MagicMock: + """Mock LLMGateway returning a fixed response.""" + from agentkit.llm.gateway import LLMGateway + + gateway = MagicMock(spec=LLMGateway) + gateway.chat = AsyncMock( + return_value=LLMResponse( + content=content, + model=model, + usage=TokenUsage(prompt_tokens=10, completion_tokens=10), + ) + ) + return gateway + + +def make_gateway_side_effect(responses_by_model: dict[str, LLMResponse | Exception]) -> MagicMock: + """Mock LLMGateway returning different responses (or raising) keyed by model name. + + Each call to gateway.chat(model=X) pops the next response for X from a queue, + so repeated calls to the same model can return different values. + """ + from agentkit.llm.gateway import LLMGateway + + gateway = MagicMock(spec=LLMGateway) + queues = {m: list(rs) for m, rs in responses_by_model.items()} + + async def chat_side_effect(*, messages, model, **kwargs): + queue = queues.get(model) + if queue is None: + raise ValueError(f"unexpected model={model}") + if not queue: + raise ValueError(f"queue for model={model} exhausted") + item = queue.pop(0) + if isinstance(item, Exception): + raise item + return item + + gateway.chat = AsyncMock(side_effect=chat_side_effect) + return gateway + + +def make_long_messages(count: int = 4, content_length: int = 2000) -> list[dict]: + """Generate long messages that exceed token budget (triggers compression).""" + messages = [{"role": "system", "content": "You are a helpful assistant."}] + for i in range(count): + messages.append({"role": "user", "content": "x" * content_length + f" m{i}"}) + messages.append({"role": "assistant", "content": "y" * content_length + f" r{i}"}) + messages.append({"role": "user", "content": "recent question"}) + messages.append({"role": "assistant", "content": "recent answer"}) + return messages + + +# ── Characterization: auxiliary_model=None preserves existing behavior ── + + +class TestAuxiliaryNoneCharacterization: + """auxiliary_model=None (default) — single model call, existing behavior.""" + + async def test_no_auxiliary_calls_main_once(self): + gateway = make_gateway_with_response("main summary") + compressor = ContextCompressor( + llm_gateway=gateway, + max_tokens=100, + keep_recent=2, + model="main", + # auxiliary_model omitted → None + ) + result = await compressor.compress(make_long_messages()) + + gateway.chat.assert_awaited_once() + # The call used the main model + assert gateway.chat.await_args.kwargs.get("model") == "main" + # Summary surfaced in result + summary_msgs = [ + m + for m in result + if m.get("role") == "system" and "Conversation Summary" in m.get("content", "") + ] + assert any("main summary" in m["content"] for m in summary_msgs) + + async def test_main_failure_falls_to_simple_summary(self): + gateway = MagicMock() + gateway.chat = AsyncMock(side_effect=Exception("main LLM error")) + compressor = ContextCompressor( + llm_gateway=gateway, + max_tokens=100, + keep_recent=2, + model="main", + ) + result = await compressor.compress(make_long_messages()) + + # _simple_summary produces truncated messages with "..." + summary_msgs = [ + m + for m in result + if m.get("role") == "system" and "Conversation Summary" in m.get("content", "") + ] + assert len(summary_msgs) == 1 + assert "..." in summary_msgs[0]["content"] + + +# ── New behavior: auxiliary routing ────────────────── + + +class TestAuxiliaryRouting: + """auxiliary_model set and differs from main → auxiliary tried first.""" + + async def test_auxiliary_success_returns_auxiliary_content(self): + gateway = make_gateway_side_effect( + { + "fast": [ + LLMResponse( + content="aux summary", + model="fast", + usage=TokenUsage(prompt_tokens=1, completion_tokens=1), + ) + ], + "main": [ + LLMResponse( + content="MAIN SHOULD NOT BE USED", + model="main", + usage=TokenUsage(prompt_tokens=1, completion_tokens=1), + ) + ], + } + ) + compressor = ContextCompressor( + llm_gateway=gateway, + max_tokens=100, + keep_recent=2, + model="main", + auxiliary_model="fast", + ) + result = await compressor.compress(make_long_messages()) + + # Auxiliary called; main NOT called + aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"] + main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"] + assert len(aux_calls) == 1 + assert len(main_calls) == 0 + # Result contains auxiliary summary + summary_msgs = [ + m + for m in result + if m.get("role") == "system" and "Conversation Summary" in m.get("content", "") + ] + assert any("aux summary" in m["content"] for m in summary_msgs) + + async def test_empty_content_triggers_main_fallback(self): + """Finding 4 anti-pattern: empty content is a failure, not a success.""" + gateway = make_gateway_side_effect( + { + "fast": [ + LLMResponse( + content="", + model="fast", + usage=TokenUsage(prompt_tokens=1, completion_tokens=0), + ) + ], + "main": [ + LLMResponse( + content="main summary", + model="main", + usage=TokenUsage(prompt_tokens=1, completion_tokens=1), + ) + ], + } + ) + compressor = ContextCompressor( + llm_gateway=gateway, + max_tokens=100, + keep_recent=2, + model="main", + auxiliary_model="fast", + ) + result = await compressor.compress(make_long_messages()) + + # Auxiliary called once (returned empty) + # Main called once (fallback) + aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"] + main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"] + assert len(aux_calls) == 1 + assert len(main_calls) == 1 + # Result contains main summary (not the empty auxiliary) + summary_msgs = [ + m + for m in result + if m.get("role") == "system" and "Conversation Summary" in m.get("content", "") + ] + assert any("main summary" in m["content"] for m in summary_msgs) + + async def test_whitespace_content_triggers_main_fallback(self): + """Whitespace-only content also counts as empty (Finding 4).""" + gateway = make_gateway_side_effect( + { + "fast": [ + LLMResponse( + content=" \n ", + model="fast", + usage=TokenUsage(prompt_tokens=1, completion_tokens=0), + ) + ], + "main": [ + LLMResponse( + content="main summary", + model="main", + usage=TokenUsage(prompt_tokens=1, completion_tokens=1), + ) + ], + } + ) + compressor = ContextCompressor( + llm_gateway=gateway, + max_tokens=100, + keep_recent=2, + model="main", + auxiliary_model="fast", + ) + await compressor.compress(make_long_messages()) + + # Both auxiliary and main called + aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"] + main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"] + assert len(aux_calls) == 1 + assert len(main_calls) == 1 + + async def test_auxiliary_exception_triggers_main_fallback(self): + from agentkit.core.exceptions import LLMProviderError + + gateway = make_gateway_side_effect( + { + "fast": [LLMProviderError("aux", "provider down")], + "main": [ + LLMResponse( + content="main summary", + model="main", + usage=TokenUsage(prompt_tokens=1, completion_tokens=1), + ) + ], + } + ) + compressor = ContextCompressor( + llm_gateway=gateway, + max_tokens=100, + keep_recent=2, + model="main", + auxiliary_model="fast", + ) + result = await compressor.compress(make_long_messages()) + + # Both called; main succeeded + aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"] + main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"] + assert len(aux_calls) == 1 + assert len(main_calls) == 1 + summary_msgs = [ + m + for m in result + if m.get("role") == "system" and "Conversation Summary" in m.get("content", "") + ] + assert any("main summary" in m["content"] for m in summary_msgs) + + async def test_both_fail_falls_to_simple_summary(self): + """Auxiliary raises, main raises → existing _simple_summary degradation.""" + # Note: aggressive compression path may invoke _summarize multiple times. + # Queue provides enough responses to handle that without raising queue-exhausted. + gateway = make_gateway_side_effect( + { + "fast": [Exception("aux boom")] * 5, + "main": [Exception("main boom")] * 5, + } + ) + compressor = ContextCompressor( + llm_gateway=gateway, + max_tokens=100, + keep_recent=2, + model="main", + auxiliary_model="fast", + ) + result = await compressor.compress(make_long_messages()) + + # Both called at least once + aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"] + main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"] + assert len(aux_calls) >= 1 + assert len(main_calls) >= 1 + # _simple_summary output has "..." truncation markers + summary_msgs = [ + m + for m in result + if m.get("role") == "system" and "Conversation Summary" in m.get("content", "") + ] + assert len(summary_msgs) == 1 + assert "..." in summary_msgs[0]["content"] + + async def test_auxiliary_equal_to_main_skipped(self): + """auxiliary_model == model → no auxiliary routing (single call to main).""" + gateway = make_gateway_with_response("main summary") + compressor = ContextCompressor( + llm_gateway=gateway, + max_tokens=100, + keep_recent=2, + model="main", + auxiliary_model="main", # same as main + ) + await compressor.compress(make_long_messages()) + + # Only one call (to main); auxiliary block skipped + assert gateway.chat.await_count == 1 + assert gateway.chat.await_args.kwargs.get("model") == "main" + + async def test_audit_fields_preserved(self): + """Auxiliary call uses agent_name='compressor', task_type='summarization'.""" + gateway = make_gateway_with_response("aux summary") + compressor = ContextCompressor( + llm_gateway=gateway, + max_tokens=100, + keep_recent=2, + model="main", + auxiliary_model="fast", + ) + # Override the mock to use a single-response gateway where auxiliary succeeds + # (the make_gateway_with_response mock returns same response regardless of model) + await compressor.compress(make_long_messages()) + + # Single call (auxiliary succeeded) — verify audit fields + call_kwargs = gateway.chat.await_args.kwargs + assert call_kwargs.get("agent_name") == "compressor" + assert call_kwargs.get("task_type") == "summarization" + + +# ── Config wiring ──────────────────────────────────── + + +class TestConfigWiring: + """LLMConfig + ServerConfig read auxiliary_model from dict.""" + + def test_llm_config_from_dict_reads_auxiliary_model(self): + cfg = LLMConfig.from_dict( + { + "providers": {}, + "model_aliases": {"fast": "p/m"}, + "auxiliary_model": "fast", + } + ) + assert cfg.auxiliary_model == "fast" + + def test_llm_config_from_dict_auxiliary_none_when_absent(self): + cfg = LLMConfig.from_dict({"providers": {}}) + assert cfg.auxiliary_model is None + + def test_llm_config_default_auxiliary_none(self): + cfg = LLMConfig() + assert cfg.auxiliary_model is None + + def test_server_config_build_llm_config_reads_auxiliary_model(self): + from agentkit.server.config import ServerConfig + + llm_data = { + "providers": { + "p": { + "type": "openai", + "api_key": "k", + "base_url": "http://x", + "models": {"m": {"alias": "fast"}}, + } + }, + "auxiliary_model": "fast", + } + llm_config = ServerConfig._build_llm_config(llm_data) + assert llm_config.auxiliary_model == "fast" + # Also verify model_aliases still built correctly + assert llm_config.model_aliases.get("fast") == "p/m" + + def test_server_config_build_llm_config_auxiliary_none_when_absent(self): + from agentkit.server.config import ServerConfig + + llm_config = ServerConfig._build_llm_config({"providers": {}}) + assert llm_config.auxiliary_model is None diff --git a/tests/unit/test_emergency_rules.py b/tests/unit/test_emergency_rules.py new file mode 100644 index 0000000..63a0d76 --- /dev/null +++ b/tests/unit/test_emergency_rules.py @@ -0,0 +1,318 @@ +"""G7/U2 — Emergency layer rule template + TaskResult extension. + +Verifies: +- EmergencyRules.classify maps each exception type to correct error_code +- TaskCancelledError raises ValueError (caller must propagate as-is) +- EmergencyError.to_dict produces all 5 fields +- EmergencyError.to_error_message formats suggestions as "建议:1) ... 2) ..." +- Config overrides apply (suggestions, retryable, message) +- TaskResult.error_struct field: default None preserves byte-for-byte + to_dict() output (backward compat) +- TaskResult round-trip serialization includes error_struct when set +""" + +from datetime import datetime, timezone + +import pytest + +from agentkit.core.exceptions import ( + LLMProviderError, + LoopDetectedError, + TaskCancelledError, + TaskTimeoutError, +) +from agentkit.core.fallback import ( + EMPTY_LLM_RESPONSE, + MAX_STEPS_REACHED, + SHELL_NO_OUTPUT, + EmergencyError, + EmergencyRules, +) +from agentkit.core.protocol import TaskResult + + +# ── Constants unchanged (contract preservation) ────── + + +class TestExistingConstantsUnchanged: + """Existing 3 constants preserved byte-for-byte.""" + + def test_empty_llm_response_unchanged(self): + assert "模型未返回有效内容" in EMPTY_LLM_RESPONSE + assert "建议" in EMPTY_LLM_RESPONSE + + def test_max_steps_reached_unchanged(self): + assert "已达到最大推理步数" in MAX_STEPS_REACHED + + def test_shell_no_output_unchanged(self): + assert SHELL_NO_OUTPUT == "[命令执行成功,无输出内容]" + + +# ── EmergencyRules.classify ────────────────────────── + + +class TestEmergencyRulesClassify: + """classify() maps exception types to EmergencyError.""" + + def test_timeout(self): + exc = TaskTimeoutError(task_id="t1", timeout_seconds=30) + err = EmergencyRules.classify(exc) + assert err.error_code == "timeout" + assert err.retryable is True + assert "稍后重试" in err.suggestions + assert "简化任务范围" in err.suggestions + assert err.original_error == str(exc) + + def test_loop_detected(self): + exc = LoopDetectedError(tool_name="shell", repetitions=3) + err = EmergencyRules.classify(exc) + assert err.error_code == "loop_detected" + assert err.retryable is True + assert "拆分任务" in err.suggestions + assert "检查工具参数" in err.suggestions + + def test_llm_provider_error(self): + exc = LLMProviderError("openai", "rate limited") + err = EmergencyRules.classify(exc) + assert err.error_code == "llm_failure" + assert err.retryable is True + assert "稍后重试" in err.suggestions + assert "切换模型" in err.suggestions + + def test_llm_error_subclass_also_classified(self): + """LLMProviderError is a subclass of LLMError; ensure isinstance check works.""" + from agentkit.core.exceptions import LLMError + + class CustomLLMError(LLMError): + pass + + err = EmergencyRules.classify(CustomLLMError("custom")) + # CustomLLMError is NOT a LLMProviderError, falls through to generic + assert err.error_code == "internal_error" + + def test_generic_exception_internal_error(self): + err = EmergencyRules.classify(Exception("unknown boom")) + assert err.error_code == "internal_error" + assert err.retryable is False + assert "联系管理员" in err.suggestions + assert err.original_error == "unknown boom" + + def test_task_cancelled_raises(self): + """TaskCancelledError must propagate; classify() raises ValueError.""" + exc = TaskCancelledError(task_id="t1") + with pytest.raises(ValueError, match="TaskCancelledError"): + EmergencyRules.classify(exc) + + def test_subclass_of_timeout_classified(self): + """Subclasses of TaskTimeoutError are classified as timeout.""" + + class CustomTimeout(TaskTimeoutError): + def __init__(self): + super().__init__(task_id="custom", timeout_seconds=10) + + err = EmergencyRules.classify(CustomTimeout()) + assert err.error_code == "timeout" + + +# ── EmergencyError serialization ───────────────────── + + +class TestEmergencyErrorSerialization: + """to_dict / to_error_message on EmergencyError.""" + + def test_to_dict_produces_all_five_fields(self): + err = EmergencyError( + error_code="timeout", + message="任务执行超时。", + suggestions=["稍后重试", "简化任务范围"], + retryable=True, + original_error="Task t1 timed out after 30s", + ) + d = err.to_dict() + assert set(d.keys()) == { + "error_code", + "message", + "suggestions", + "retryable", + "original_error", + } + assert d["error_code"] == "timeout" + assert d["message"] == "任务执行超时。" + assert d["suggestions"] == ["稍后重试", "简化任务范围"] + assert d["retryable"] is True + assert d["original_error"] == "Task t1 timed out after 30s" + + def test_to_dict_suggestions_list_is_copy(self): + """to_dict returns a fresh list, not the internal reference.""" + suggestions = ["a", "b"] + err = EmergencyError( + error_code="x", + message="m", + suggestions=suggestions, + retryable=False, + original_error="e", + ) + d = err.to_dict() + assert d["suggestions"] is not suggestions + d["suggestions"].append("c") + assert err.suggestions == ["a", "b"] + + def test_to_error_message_with_suggestions(self): + err = EmergencyError( + error_code="timeout", + message="任务执行超时。", + suggestions=["稍后重试", "简化任务范围"], + retryable=True, + original_error="err", + ) + msg = err.to_error_message() + assert msg.startswith("任务执行超时。建议:") + assert "1) 稍后重试" in msg + assert "2) 简化任务范围" in msg + # Format mirrors EMPTY_LLM_RESPONSE style + assert msg.endswith("。") + + def test_to_error_message_no_suggestions(self): + err = EmergencyError( + error_code="x", + message="just a message", + suggestions=[], + retryable=False, + original_error="e", + ) + assert err.to_error_message() == "just a message" + + def test_to_error_message_single_suggestion(self): + err = EmergencyError( + error_code="x", + message="msg", + suggestions=["only one"], + retryable=False, + original_error="e", + ) + msg = err.to_error_message() + assert msg == "msg建议:1) only one。" + + +# ── Config override ────────────────────────────────── + + +class TestConfigOverride: + """classify() applies per-rule config overrides.""" + + def test_override_suggestions(self): + exc = TaskTimeoutError(task_id="t", timeout_seconds=1) + cfg = {"timeout": {"suggestions": ["自定义建议 A", "自定义建议 B"]}} + err = EmergencyRules.classify(exc, config=cfg) + assert err.suggestions == ["自定义建议 A", "自定义建议 B"] + assert err.error_code == "timeout" + + def test_override_retryable(self): + exc = LLMProviderError("openai", "boom") + cfg = {"llm_failure": {"retryable": False}} + err = EmergencyRules.classify(exc, config=cfg) + assert err.retryable is False + + def test_override_message(self): + exc = LoopDetectedError(tool_name="x", repetitions=2) + cfg = {"loop_detected": {"message": "循环啦!"}} + err = EmergencyRules.classify(exc, config=cfg) + assert err.message == "循环啦!" + + def test_override_internal_error_rule(self): + cfg = {"internal_error": {"suggestions": ["联系客服"]}} + err = EmergencyRules.classify(Exception("boom"), config=cfg) + assert err.error_code == "internal_error" + assert err.suggestions == ["联系客服"] + + def test_config_none_uses_defaults(self): + err = EmergencyRules.classify(TaskTimeoutError(task_id="t", timeout_seconds=1)) + assert err.error_code == "timeout" + assert err.retryable is True + + def test_config_empty_dict_uses_defaults(self): + err = EmergencyRules.classify( + TaskTimeoutError(task_id="t", timeout_seconds=1), config={} + ) + assert err.error_code == "timeout" + assert err.retryable is True + + +# ── TaskResult.error_struct extension ──────────────── + + +def _make_task_result( + error_struct: dict | None = None, error_message: str | None = None +) -> TaskResult: + now = datetime.now(timezone.utc) + return TaskResult( + task_id="t1", + agent_name="a1", + status="completed", + output_data={"k": "v"}, + error_message=error_message, + started_at=now, + completed_at=now, + metrics={"m": 1}, + error_struct=error_struct, + ) + + +class TestTaskResultErrorStruct: + """TaskResult.error_struct field — backward-compatible extension.""" + + def test_default_error_struct_is_none(self): + tr = _make_task_result() + assert tr.error_struct is None + + def test_to_dict_without_error_struct_preserves_existing_shape(self): + """error_struct=None → to_dict() output has NO error_struct key (byte-for-byte).""" + tr = _make_task_result() + d = tr.to_dict() + assert "error_struct" not in d + # Existing keys unchanged + assert set(d.keys()) == { + "task_id", + "agent_name", + "status", + "output_data", + "error_message", + "started_at", + "completed_at", + "metrics", + } + + def test_to_dict_with_error_struct_includes_key(self): + struct = { + "error_code": "timeout", + "message": "超时", + "suggestions": ["重试"], + "retryable": True, + "original_error": "boom", + } + tr = _make_task_result(error_struct=struct, error_message="超时建议:1) 重试。") + d = tr.to_dict() + assert d["error_struct"] == struct + assert d["error_message"] == "超时建议:1) 重试。" + + def test_from_dict_round_trip_with_error_struct(self): + struct = {"error_code": "loop_detected", "message": "m", "suggestions": [], "retryable": True, "original_error": "e"} + tr = _make_task_result(error_struct=struct) + d = tr.to_dict() + restored = TaskResult.from_dict(d) + assert restored.error_struct == struct + + def test_from_dict_without_error_struct_defaults_none(self): + tr = _make_task_result() + d = tr.to_dict() + # Simulate legacy data without error_struct key + restored = TaskResult.from_dict(d) + assert restored.error_struct is None + + def test_error_message_and_error_struct_coexist(self): + """Both fields can be set simultaneously (parallel contract per KTD2).""" + struct = {"error_code": "timeout", "message": "超时", "suggestions": ["重试"], "retryable": True, "original_error": "err"} + tr = _make_task_result(error_struct=struct, error_message="超时建议:1) 重试。") + d = tr.to_dict() + assert d["error_message"] == "超时建议:1) 重试。" + assert d["error_struct"] == struct diff --git a/tests/unit/test_fallback_chain.py b/tests/unit/test_fallback_chain.py new file mode 100644 index 0000000..0ada39d --- /dev/null +++ b/tests/unit/test_fallback_chain.py @@ -0,0 +1,404 @@ +"""G7/U3 — Three-tier fallback chain wiring tests. + +Verifies Main → Recovery (ReflexionEngine) → Emergency (EmergencyRules) +at chat REST path. Mocks ReActEngine + ReflexionEngine + LLMGateway. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentkit.core.exceptions import ( + LLMProviderError, + LoopDetectedError, + TaskCancelledError, + TaskTimeoutError, +) +from agentkit.core.react import ReActResult +from agentkit.server._fallback_chain import execute_with_fallback_chain + + +def _make_react_result(status: str = "success", output: str = "ok") -> ReActResult: + return ReActResult( + output=output, + trajectory=[], + total_steps=1, + total_tokens=10, + status=status, + ) + + +def _make_react_engine(result=None, raises=None): + """Build a fake ReActEngine with .execute returning result or raising.""" + engine = MagicMock() + engine.reset = MagicMock() + if raises is not None: + engine.execute = AsyncMock(side_effect=raises) + else: + engine.execute = AsyncMock(return_value=result or _make_react_result()) + return engine + + +def _make_llm_gateway(): + gw = MagicMock() + gw.chat = AsyncMock(return_value=MagicMock(content="recovered")) + return gw + + +def _make_reflexion_result(status: str = "success", output: str = "recovered"): + """Synthesize a ReflexionResult-like object.""" + return MagicMock( + status=status, + output=output, + trajectory=[], + total_steps=1, + total_tokens=5, + ) + + +@pytest.fixture +def patched_reflexion(monkeypatch): + """Patch ReflexionEngine used inside the chain to a controllable mock.""" + from agentkit.server import _fallback_chain + + instances: list[MagicMock] = [] + + class _MockReflexion: + def __init__(self, llm_gateway, max_reflections=1, **kwargs): + self._llm_gateway = llm_gateway + self._max_reflections = max_reflections + self.execute = AsyncMock(return_value=_make_reflexion_result()) + instances.append(self) + + monkeypatch.setattr(_fallback_chain, "ReflexionEngine", _MockReflexion) + return instances + + +# ─── Tier 1: Main ───────────────────────────────────────────────────────── + + +class TestMainTier: + @pytest.mark.asyncio + async def test_main_success_no_recovery_no_emergency(self): + engine = _make_react_engine(result=_make_react_result(status="success", output="hello")) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + ) + assert result.status == "success" + assert result.output == "hello" + assert result.error_struct is None + + @pytest.mark.asyncio + async def test_main_unknown_status_treated_as_success(self): + """Unknown status (not in soft_failure set) is treated as success-like.""" + engine = _make_react_engine(result=_make_react_result(status="partial", output="x")) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + ) + assert result.status == "success" + + +# ─── Tier 2: Recovery ────────────────────────────────────────────────────── + + +class TestRecoveryTier: + @pytest.mark.asyncio + async def test_main_timeout_triggers_recovery_success(self, patched_reflexion): + engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10)) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + ) + assert result.status == "recovered" + assert result.output == "recovered" + # ReflexionEngine was instantiated and called + assert len(patched_reflexion) == 1 + patched_reflexion[0].execute.assert_awaited_once() + + @pytest.mark.asyncio + async def test_main_loop_detected_triggers_recovery(self, patched_reflexion): + engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5)) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + ) + assert result.status == "recovered" + + @pytest.mark.asyncio + async def test_main_llm_provider_error_triggers_recovery(self, patched_reflexion): + engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="503")) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + ) + assert result.status == "recovered" + + @pytest.mark.asyncio + async def test_main_soft_failure_status_triggers_recovery(self, patched_reflexion): + """Soft failure (empty_fallback) without exception still triggers Recovery.""" + engine = _make_react_engine(result=_make_react_result(status="empty_fallback", output="")) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + ) + assert result.status == "recovered" + + @pytest.mark.asyncio + async def test_recovery_disabled_skips_to_emergency(self): + engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10)) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + fallback_chain_config={"recovery": {"enabled": False}}, + ) + assert result.status == "emergency" + assert result.error_struct["error_code"] == "timeout" + + @pytest.mark.asyncio + async def test_recovery_failure_falls_through_to_emergency(self, patched_reflexion): + """Recovery raises → Emergency tier fires with original exception.""" + engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10)) + # Make ReflexionEngine.execute raise + patched_reflexion_instance = MagicMock() + patched_reflexion_instance.execute = AsyncMock( + side_effect=RuntimeError("reflexion crashed") + ) + # Override the patched class to use our instance + from agentkit.server import _fallback_chain + + original_cls = _fallback_chain.ReflexionEngine + + class _MockReflexionWithExc: + def __init__(self, **kwargs): + self.execute = patched_reflexion_instance.execute + + _fallback_chain.ReflexionEngine = _MockReflexionWithExc + try: + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + ) + finally: + _fallback_chain.ReflexionEngine = original_cls + + assert result.status == "emergency" + assert result.error_struct["error_code"] == "timeout" + + @pytest.mark.asyncio + async def test_recovery_unsuccessful_status_falls_through(self, patched_reflexion): + """Recovery returns non-success status → Emergency fires.""" + engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10)) + # Make ReflexionEngine return unsuccessful result with empty output + from agentkit.server import _fallback_chain + + class _MockReflexionNoOutput: + def __init__(self, **kwargs): + self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None)) + + original_cls = _fallback_chain.ReflexionEngine + _fallback_chain.ReflexionEngine = _MockReflexionNoOutput + try: + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + ) + finally: + _fallback_chain.ReflexionEngine = original_cls + + assert result.status == "emergency" + assert result.error_struct["error_code"] == "timeout" + + +# ─── Tier 3: Emergency ──────────────────────────────────────────────────── + + +class TestEmergencyTier: + @pytest.mark.asyncio + async def test_emergency_timeout_error_code(self, patched_reflexion): + # Make recovery fail (empty result) so Emergency fires + from agentkit.server import _fallback_chain + + class _MockReflexionEmpty: + def __init__(self, **kwargs): + self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None)) + + original_cls = _fallback_chain.ReflexionEngine + _fallback_chain.ReflexionEngine = _MockReflexionEmpty + try: + engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10)) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + ) + finally: + _fallback_chain.ReflexionEngine = original_cls + + assert result.status == "emergency" + assert result.error_struct["error_code"] == "timeout" + assert result.error_struct["retryable"] is True + assert "建议" in result.output + + @pytest.mark.asyncio + async def test_emergency_loop_detected_error_code(self): + engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5)) + # Recovery disabled so Emergency fires directly + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + fallback_chain_config={"recovery": {"enabled": False}}, + ) + assert result.status == "emergency" + assert result.error_struct["error_code"] == "loop_detected" + + @pytest.mark.asyncio + async def test_emergency_llm_failure_error_code(self): + engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="500")) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + fallback_chain_config={"recovery": {"enabled": False}}, + ) + assert result.status == "emergency" + assert result.error_struct["error_code"] == "llm_failure" + + @pytest.mark.asyncio + async def test_emergency_internal_error_for_generic_exception(self): + engine = _make_react_engine(raises=RuntimeError("unexpected")) + result = await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + fallback_chain_config={"recovery": {"enabled": False}}, + ) + assert result.status == "emergency" + assert result.error_struct["error_code"] == "internal_error" + assert result.error_struct["retryable"] is False + + @pytest.mark.asyncio + async def test_task_cancelled_propagates_not_routed_to_emergency(self): + """TaskCancelledError must propagate, not be classified by Emergency.""" + engine = _make_react_engine(raises=TaskCancelledError(task_id="t1")) + with pytest.raises(TaskCancelledError): + await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + fallback_chain_config={"recovery": {"enabled": False}}, + ) + + @pytest.mark.asyncio + async def test_emergency_disabled_reraises_original(self): + engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10)) + with pytest.raises(TaskTimeoutError): + await execute_with_fallback_chain( + react_engine=engine, + llm_gateway=_make_llm_gateway(), + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="default", + agent_name="a", + system_prompt=None, + fallback_chain_config={ + "recovery": {"enabled": False}, + "emergency": {"enabled": False}, + }, + ) + + +# ─── Config wiring ──────────────────────────────────────────────────────── + + +class TestServerConfigFallbackChain: + def test_fallback_chain_section_read_from_dict(self): + from agentkit.server.config import ServerConfig + + config = ServerConfig.from_dict( + { + "fallback_chain": { + "enabled": True, + "recovery": {"enabled": False, "max_retries": 3}, + "emergency": {"enabled": True}, + } + } + ) + assert config.fallback_chain["enabled"] is True + assert config.fallback_chain["recovery"] == {"enabled": False, "max_retries": 3} + assert config.fallback_chain["emergency"] == {"enabled": True} + + def test_fallback_chain_defaults_empty_when_absent(self): + from agentkit.server.config import ServerConfig + + config = ServerConfig.from_dict({}) + assert config.fallback_chain == {} diff --git a/tests/unit/test_phase_rollback.py b/tests/unit/test_phase_rollback.py new file mode 100644 index 0000000..4d9b6c6 --- /dev/null +++ b/tests/unit/test_phase_rollback.py @@ -0,0 +1,367 @@ +"""G9/U4 — PlanPhase rollback fields + RollbackExecutor + TeamOrchestrator integration. + +Characterization-first: captures pre-change behavior (rollback_command=None → +no rollback, checkpoint saved) before asserting new rollback behavior. +""" + +from __future__ import annotations + +import os +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentkit.experts.orchestrator import TeamOrchestrator +from agentkit.experts.plan import PlanPhase, TeamPlan +from agentkit.orchestrator.rollback import RollbackExecutor + + +# ─── PlanPhase field characterization ────────────────────────────────────── + + +class TestPlanPhaseFields: + """PlanPhase serialization for new fields.""" + + def test_characterization_no_new_keys_when_unset(self): + """Default PlanPhase.to_dict() must not include new keys (KTD6 contract).""" + ph = PlanPhase(name="执行", assigned_expert="lead") + out = ph.to_dict() + assert "validation_command" not in out + assert "rollback_command" not in out + # Pre-change keys remain + assert out["name"] == "执行" + assert out["assigned_expert"] == "lead" + assert out["status"] == "pending" + + def test_characterization_from_dict_empty_yields_none(self): + """from_dict({}) produces both new fields as None.""" + ph = PlanPhase.from_dict({"name": "x"}) + assert ph.validation_command is None + assert ph.rollback_command is None + + def test_serialization_includes_keys_when_set(self): + ph = PlanPhase( + name="frontend", + validation_command="ruff check src/", + rollback_command="git checkout src/app.vue", + ) + out = ph.to_dict() + assert out["validation_command"] == "ruff check src/" + assert out["rollback_command"] == "git checkout src/app.vue" + + def test_serialization_round_trip(self): + ph = PlanPhase( + name="backend", + validation_command="pytest -x -q", + rollback_command="git checkout src/api.py", + ) + restored = PlanPhase.from_dict(ph.to_dict()) + assert restored.validation_command == "pytest -x -q" + assert restored.rollback_command == "git checkout src/api.py" + + def test_only_validation_set_still_emits_validation_key(self): + """Asymmetric case: only validation_command set — only that key appears.""" + ph = PlanPhase(name="x", validation_command="echo ok") + out = ph.to_dict() + assert "validation_command" in out + assert "rollback_command" not in out + + +# ─── RollbackExecutor subprocess execution ───────────────────────────────── + + +@pytest.fixture +def tmp_workspace(tmp_path): + """Fresh working directory for subprocess execution.""" + return str(tmp_path) + + +class TestRollbackExecutor: + """RollbackExecutor happy/edge/failure paths.""" + + @pytest.mark.asyncio + async def test_execute_happy_path_zero_exit(self, tmp_workspace): + ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0) + r = await ex.execute("true") + assert r.passed is True + assert r.exit_code == 0 + assert r.command == "true" + + @pytest.mark.asyncio + async def test_execute_failure_nonzero_exit(self, tmp_workspace): + ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0) + r = await ex.execute("false") + assert r.passed is False + assert r.exit_code != 0 + + @pytest.mark.asyncio + async def test_execute_timeout(self, tmp_workspace): + ex = RollbackExecutor(working_dir=tmp_workspace, timeout=0.1) + r = await ex.execute("sleep 5") + assert r.passed is False + assert r.exit_code == -1 + assert "timed out" in r.stderr + + @pytest.mark.asyncio + async def test_execute_captures_stdout(self, tmp_workspace): + ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0) + r = await ex.execute("echo hello-rollback") + assert r.passed is True + assert "hello-rollback" in r.stdout + + @pytest.mark.asyncio + async def test_validate_same_semantics_as_execute(self, tmp_workspace): + """validate() is just execute() with different intent marker.""" + ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0) + r = await ex.validate("true") + assert r.passed is True + r2 = await ex.validate("false") + assert r2.passed is False + + @pytest.mark.asyncio + async def test_spawn_failure_returns_failed_result(self, tmp_workspace): + """Bad shell command should surface as failed result, not raise.""" + # Use a definitely-broken command — shell still spawns, returns non-zero. + ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0) + r = await ex.execute("exit 7") + assert r.passed is False + assert r.exit_code == 7 + + @pytest.mark.asyncio + async def test_cwd_used_for_relative_paths(self, tmp_workspace): + """Files created in working_dir are visible via cwd-relative commands.""" + ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0) + await ex.execute("echo content > testfile.txt") + r = await ex.execute("test -f testfile.txt") + assert r.passed is True + + +# ─── Real git rollback integration ──────────────────────────────────────── + + +class TestGitRollbackIntegration: + """Real git repo fixture: writes file, rollback restores via git checkout.""" + + @pytest.mark.asyncio + async def test_git_checkout_restores_modified_file(self, tmp_path): + # Init repo and commit a baseline file + import subprocess + + repo = str(tmp_path) + subprocess.run(["git", "init", "-q"], cwd=repo, check=True) + subprocess.run(["git", "config", "user.email", "test@x"], cwd=repo, check=True) + subprocess.run(["git", "config", "user.name", "test"], cwd=repo, check=True) + baseline = "original content\n" + with open(os.path.join(repo, "foo.txt"), "w") as f: + f.write(baseline) + subprocess.run(["git", "add", "foo.txt"], cwd=repo, check=True) + subprocess.run(["git", "commit", "-q", "-m", "baseline"], cwd=repo, check=True) + + # Mutate the file + with open(os.path.join(repo, "foo.txt"), "w") as f: + f.write("mutated content\n") + + # Run git checkout as rollback_command + ex = RollbackExecutor(working_dir=repo, timeout=10.0) + r = await ex.execute("git checkout foo.txt") + assert r.passed is True + + with open(os.path.join(repo, "foo.txt")) as f: + assert f.read() == baseline + + +# ─── TeamOrchestrator integration ───────────────────────────────────────── + + +def _make_team_mock(): + """Build a minimal ExpertTeam mock for TeamOrchestrator.""" + team = MagicMock() + team.team_id = "test-team" + team.status.value = "executing" + team.lead_expert = None + team.active_experts = [] + return team + + +def _make_orchestrator(team=None, checkpoint=None, workspace_root=None): + """Build a TeamOrchestrator with mocked team and checkpoint.""" + team = team or _make_team_mock() + orch = TeamOrchestrator( + team=team, + checkpoint=checkpoint, + workspace_root=workspace_root, + rollback_timeout=2.0, + ) + orch._broadcast_event = AsyncMock() + return orch + + +class TestOrchestratorRollbackIntegration: + """TeamOrchestrator phase failure path integration with rollback.""" + + @pytest.mark.asyncio + async def test_characterization_no_rollback_when_unset(self, tmp_path): + """Phase fails, rollback_command=None → checkpoint saved, no rollback events.""" + checkpoint = MagicMock() + checkpoint.save = AsyncMock() + orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path)) + + ph = PlanPhase(name="p1", assigned_expert="lead") # no rollback fields + plan = TeamPlan(task="t", lead_expert="lead") + plan.phases = [ph] + + # Call _run_phase_rollback — but it should NOT be called from + # orchestrator path when fields unset. Verify by simulating the + # main-loop guard directly. + from agentkit.experts.plan import PhaseStatus as PS + + ph.status = PS.FAILED + # Simulate the guard condition in orchestrator.py:280-288 + should_save = True + if ( + ph.validation_command + and ph.rollback_command + and isinstance(Exception("x"), (Exception,)) + ): + should_save = await orch._run_phase_rollback(plan, ph) + # Guard should not fire + assert should_save is True + # No rollback events broadcast + assert orch._broadcast_event.call_count == 0 + + @pytest.mark.asyncio + async def test_validation_passes_no_rollback_executed(self, tmp_path): + """validation_command returns 0 → rollback NOT executed, checkpoint saved.""" + checkpoint = MagicMock() + checkpoint.save = AsyncMock() + orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path)) + + ph = PlanPhase( + name="p1", + assigned_expert="lead", + validation_command="true", # always passes + rollback_command="false", # would fail if executed + ) + plan = TeamPlan(task="t", lead_expert="lead") + plan.phases = [ph] + + should_save = await orch._run_phase_rollback(plan, ph) + assert should_save is True + + # Events: started + completed (no rollback) + events = [ + c.kwargs.get("event_type") or c.args[0] for c in orch._broadcast_event.call_args_list + ] + assert "phase_rollback_started" in events + assert "phase_rollback_completed" in events + assert "phase_rollback_failed" not in events + + @pytest.mark.asyncio + async def test_validation_fails_rollback_succeeds(self, tmp_path): + """Validation fails, rollback returns 0 → checkpoint saved, rollback_executed=True.""" + # Create file under git tracking, then mutate it + import subprocess + + repo = str(tmp_path) + subprocess.run(["git", "init", "-q"], cwd=repo, check=True) + subprocess.run(["git", "config", "user.email", "t@x"], cwd=repo, check=True) + subprocess.run(["git", "config", "user.name", "t"], cwd=repo, check=True) + with open(os.path.join(repo, "f.txt"), "w") as f: + f.write("base\n") + subprocess.run(["git", "add", "f.txt"], cwd=repo, check=True) + subprocess.run(["git", "commit", "-q", "-m", "base"], cwd=repo, check=True) + with open(os.path.join(repo, "f.txt"), "w") as f: + f.write("mutated\n") + + checkpoint = MagicMock() + checkpoint.save = AsyncMock() + orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=repo) + + ph = PlanPhase( + name="p1", + assigned_expert="lead", + validation_command="false", # validation fails + rollback_command="git checkout f.txt", # rollback succeeds + ) + plan = TeamPlan(task="t", lead_expert="lead") + plan.phases = [ph] + + should_save = await orch._run_phase_rollback(plan, ph) + assert should_save is True + + # File restored to baseline + with open(os.path.join(repo, "f.txt")) as f: + assert f.read() == "base\n" + + @pytest.mark.asyncio + async def test_validation_fails_rollback_fails_skips_checkpoint(self, tmp_path): + """Validation fails AND rollback fails → checkpoint NOT saved (R21), event emitted.""" + checkpoint = MagicMock() + checkpoint.save = AsyncMock() + orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path)) + + ph = PlanPhase( + name="p1", + assigned_expert="lead", + validation_command="false", # fails + rollback_command="false", # also fails + ) + plan = TeamPlan(task="t", lead_expert="lead") + plan.phases = [ph] + + should_save = await orch._run_phase_rollback(plan, ph) + assert should_save is False # R21: skip checkpoint + + events = [c.args[0] for c in orch._broadcast_event.call_args_list] + assert "phase_rollback_started" in events + assert "phase_rollback_failed" in events + # phase_rollback_completed should NOT be in events + assert "phase_rollback_completed" not in events + + @pytest.mark.asyncio + async def test_rollback_timeout_skips_checkpoint(self, tmp_path): + """Rollback command times out → checkpoint NOT saved, failed event emitted.""" + checkpoint = MagicMock() + checkpoint.save = AsyncMock() + orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path)) + orch._rollback_timeout = 0.1 # short timeout + + ph = PlanPhase( + name="p1", + assigned_expert="lead", + validation_command="false", + rollback_command="sleep 5", + ) + plan = TeamPlan(task="t", lead_expert="lead") + plan.phases = [ph] + + should_save = await orch._run_phase_rollback(plan, ph) + assert should_save is False + + events = [c.args[0] for c in orch._broadcast_event.call_args_list] + assert "phase_rollback_failed" in events + + +# ─── Config wiring ──────────────────────────────────────────────────────── + + +class TestServerConfigRollback: + """ServerConfig rollback section wiring.""" + + def test_rollback_section_read_from_dict(self): + from agentkit.server.config import ServerConfig + + config = ServerConfig.from_dict( + { + "rollback": { + "default_timeout": 45.0, + } + } + ) + assert config.rollback == {"default_timeout": 45.0} + + def test_rollback_defaults_empty_when_absent(self): + from agentkit.server.config import ServerConfig + + config = ServerConfig.from_dict({}) + assert config.rollback == {}