feat(agent): Wave 2 medium coupling (G4/G7/G9) (#5)
This commit is contained in:
parent
78ed93fc81
commit
a2dcde01b8
|
|
@ -28,6 +28,29 @@ llm:
|
|||
coding: bailian-coding/qwen3-coder-plus
|
||||
chat: deepseek/deepseek-chat
|
||||
reasoning: deepseek/deepseek-reasoner
|
||||
# G4/U1: Auxiliary model for cost-sensitive tasks (summarization).
|
||||
# When set, ContextCompressor tries this alias first, falling back to
|
||||
# the main model on failure or empty content. Commented to preserve
|
||||
# default behavior — uncomment to enable.
|
||||
# auxiliary_model: fast
|
||||
# G9/U4: Rollback configuration. Drives RollbackExecutor subprocess timeout
|
||||
# for PlanPhase.validation_command / PlanPhase.rollback_command. Per-phase
|
||||
# opt-in (KTD6) — when PlanPhase.rollback_command is unset, no rollback runs.
|
||||
# Canonical rollback pattern: `git checkout <specific_files>` (file-scoped,
|
||||
# not `git checkout .` which would wipe unrelated changes).
|
||||
rollback:
|
||||
default_timeout: 30.0
|
||||
# G7/U3: Three-tier fallback chain at chat REST send_message.
|
||||
# main → Recovery (ReflexionEngine retry) → Emergency (rule-based classifier).
|
||||
# Wired only at chat REST path (KTD5); CLI / ReWOO / Reflexion internal
|
||||
# ReAct calls bypass the chain (no recursive loop).
|
||||
fallback_chain:
|
||||
enabled: true
|
||||
recovery:
|
||||
enabled: true
|
||||
max_retries: 1 # ReflexionEngine max_reflections override
|
||||
emergency:
|
||||
enabled: true
|
||||
session: {backend: memory}
|
||||
bus: {backend: memory}
|
||||
task_store: {backend: memory}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,481 @@
|
|||
---
|
||||
date: 2026-06-29
|
||||
type: feat
|
||||
origin: docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md
|
||||
---
|
||||
|
||||
# Wave 2 — Auxiliary LLM Routing, Three-Tier Fallback, Atomic Subtask Rollback
|
||||
|
||||
## Summary
|
||||
|
||||
Wave 2 of the advanced-agent gap optimization brainstorm. Three medium-coupling gaps: G4 routes summary calls through a cheaper auxiliary LLM (falling back to main on failure), G7 introduces a three-tier fallback chain (main → Recovery via existing `ReflexionEngine` → Emergency with structured errors), G9 binds atomic subtask rollback to `PlanPhase` (opt-in via `rollback_command`, coordinated with the existing U7 `PipelineCheckpoint`).
|
||||
|
||||
---
|
||||
|
||||
## Problem Frame
|
||||
|
||||
Wave 1 shipped self-contained quick wins (G1/G2/G3/G8). Wave 2 addresses the medium-coupling gaps that touch multiple layers and resolve two deferred-to-planning decisions surfaced in the brainstorm: G7 Emergency layer rule template shape, and G9 `rollback_command` default behavior. The constraints these gaps must respect come from `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — new fields must preserve existing contracts, dynamic plan mutations must persist immediately, and `PipelineCheckpoint` is in-memory dict + Redis fallback (not a DB row lock).
|
||||
|
||||
---
|
||||
|
||||
## Requirements
|
||||
|
||||
Carried from `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` § Wave 2 (R13-R21):
|
||||
|
||||
- R13. `ContextCompressor` summary task routes to auxiliary model (cheap, e.g. Gemini Flash / Doubao lite), not main model.
|
||||
- R14. `auxiliary_model` is configurable, separated from main `model`.
|
||||
- R15. Summary quality does not degrade: auxiliary failure falls back to main model (not to simple_summary).
|
||||
- R16. Main agent failure triggers Recovery layer (reuse `ReflexionEngine` Evaluate→Reflect→Retry).
|
||||
- R17. Recovery failure triggers Emergency layer (rule-based fallback, structured error + suggestion).
|
||||
- R18. Fallback chain is configurable (per-layer max retries, enable/disable Recovery/Emergency).
|
||||
- R19. `PlanPhase` (`experts/plan.py:147`) gains optional `validation_command` and `rollback_command` fields.
|
||||
- R20. Phase failure auto-executes `rollback_command` when configured (default `git checkout` pattern); coordinated with U7 checkpoint.
|
||||
- R21. `checkpoint.save` runs only after rollback validation passes (avoid persisting failed state).
|
||||
|
||||
Cross-cutting: R26 (configuration via `agentkit.yaml`), R27 (each gap ships a minimal self-check test, per ponytail rule).
|
||||
|
||||
---
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
KTD1. **Auxiliary routing is compressor-scoped, not gateway-level.** `LLMGateway` already has a per-model fallback chain (`gateway.py:170-227`), but it semantics are "main fails → try backup model", not "route by task type". G4 adds an `auxiliary_model` parameter to `ContextCompressor.__init__`; `_summarize` tries auxiliary first, falls back to main on failure. This avoids touching the gateway's existing fallback semantics and keeps the change to one file plus config wiring.
|
||||
|
||||
KTD2. **Emergency layer adds a parallel `error_struct` field, not replacing `error_message`.** `TaskResult.error_message: str | None` is serialized to API responses (`protocol.py:132-145`) — changing its type breaks frontend contracts. New `error_struct: dict | None` field carries `{error_code, message, suggestions, retryable}`. `error_message` remains the human-readable string (now derived from `error_struct.message` when present). Backward compatibility preserved.
|
||||
|
||||
KTD3. **Emergency rule classification by exception type, not by status field.** `TaskTimeoutError` → `timeout`, `LoopDetectedError` → `loop_detected`, `LLMProviderError` → `llm_failure`, `TaskCancelledError` → `cancelled` (not Emergency-eligible, propagates), generic `Exception` → `internal_error`. Each rule maps to a fixed message + suggestion list (no LLM-driven suggestions — pure rule-based, per brainstorm).
|
||||
|
||||
KTD4. **Recovery layer reuses `ReflexionEngine` with main model, not auxiliary.** `ReflexionEngine.execute()` already supports `evaluate_model`/`reflect_model` parameters (`reflexion.py:94-111`), defaulting to the main model. Recovery triggers on main agent failure; using the main model for evaluate/reflect maximizes diagnostic reliability (the same model that failed is best positioned to reflect on its own failure). Auxiliary model is reserved for G4's cost-sensitive compression path.
|
||||
|
||||
KTD5. **Three-tier chain wires at `chat.py:613` (WebSocket main path), not at `ReActEngine.execute()`.** ReAct has 5 call sites (`chat.py:613`, `rewoo.py:1204`, `cli/chat.py:338`, `reflexion.py:216`, `config_driven.py:695`). Wrapping at ReAct would force all 5 sites through Recovery/Emergency, including ReflexionEngine itself (creating a recursive loop). Wrapping at `chat.py:613` covers the primary user-facing path; CLI and ReWOO are out of scope (deferred to follow-up).
|
||||
|
||||
KTD6. **Rollback is opt-in via `rollback_command` field; no implicit default.** Brainstorm KTD5 mentioned "default `git checkout`" but auto-executing `git checkout` on every phase failure changes existing behavior (Finding 1: "新字段默认值须保持既有契约"). Resolution: `rollback_command: str | None = None`. When unset, no rollback executes (preserves existing contract). When set, the configured command runs after validation failure, before checkpoint save. `git checkout <files>` is the canonical pattern documented in YAML examples.
|
||||
|
||||
KTD7. **Rollback executes via `RollbackExecutor`, not via `ShellTool`.** `ShellTool._is_dangerous` would trigger `confirm_callback` for `git checkout` (not in `_SAFE_COMMAND_PREFIXES`, not in `_DANGEROUS_BINARY_FLAGS`). Orchestrator-internal execution bypasses ShellTool — uses `asyncio.create_subprocess_shell` with `cwd`/`timeout`/`proc.kill()` pattern from `verification_loop.py:67-103`. Audit logging added (not a whitelist concern; terminal_whitelist only governs `/api/v1/terminal/server` route per `terminal_server.py:227`).
|
||||
|
||||
KTD8. **G9 checkpoint ordering: validation → rollback → checkpoint save.** Currently `orchestrator.py:265` saves checkpoint unconditionally after phase finalizes (success or failure). R21 requires checkpoint save only after rollback validation passes. New ordering: phase fails → mark FAILED → mark dependents FAILED → execute `validation_command` (if set) → if validation fails, execute `rollback_command` (if set) → save checkpoint (only if rollback validation passed or no rollback configured).
|
||||
|
||||
---
|
||||
|
||||
## High-Level Technical Design
|
||||
|
||||
### Three-tier fallback chain (G7)
|
||||
|
||||
```mermaid
|
||||
stateDiagram-v2
|
||||
[*] --> Main: chat request
|
||||
Main --> MainSuccess: success
|
||||
Main --> Recovery: failure (exception or empty_fallback)
|
||||
Recovery --> RecoverySuccess: ReflexionEngine retry succeeds
|
||||
Recovery --> Emergency: max_reflections exhausted
|
||||
Emergency --> [*]: emit error_struct, terminal
|
||||
MainSuccess --> [*]: normal response
|
||||
RecoverySuccess --> [*]: recovered response
|
||||
```
|
||||
|
||||
### G9 phase failure + rollback sequence
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant O as TeamOrchestrator
|
||||
participant P as PlanPhase
|
||||
participant RE as RollbackExecutor
|
||||
participant CP as PipelineCheckpoint
|
||||
|
||||
O->>P: execute phase
|
||||
P-->>O: failure (exception)
|
||||
O->>O: mark FAILED + dependents FAILED
|
||||
alt validation_command set
|
||||
O->>RE: run validation_command
|
||||
RE-->>O: ValidationResult
|
||||
alt validation fails AND rollback_command set
|
||||
O->>RE: run rollback_command
|
||||
RE-->>O: RollbackResult
|
||||
alt rollback validation passes
|
||||
O->>CP: save checkpoint
|
||||
else rollback validation fails
|
||||
O->>O: log rollback_failed, skip checkpoint
|
||||
end
|
||||
else no rollback_command
|
||||
O->>CP: save checkpoint (no rollback)
|
||||
end
|
||||
else no validation_command
|
||||
O->>CP: save checkpoint (no validation)
|
||||
end
|
||||
```
|
||||
|
||||
### Configuration shape (YAML)
|
||||
|
||||
```yaml
|
||||
# G4: Auxiliary LLM for compression
|
||||
llm:
|
||||
auxiliary_model: fast # alias resolved via existing model_aliases mechanism
|
||||
|
||||
# G7: Three-tier fallback chain
|
||||
fallback_chain:
|
||||
enabled: true
|
||||
recovery:
|
||||
enabled: true
|
||||
max_retries: 1 # ReflexionEngine max_reflections override
|
||||
emergency:
|
||||
enabled: true
|
||||
|
||||
# G9: Rollback configuration (per-phase opt-in via PlanPhase.rollback_command)
|
||||
rollback:
|
||||
default_timeout: 30.0 # RollbackExecutor subprocess timeout
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Implementation Units
|
||||
|
||||
### U1. G4 — Auxiliary LLM Routing in ContextCompressor
|
||||
|
||||
**Goal:** Route `_summarize` LLM calls through `auxiliary_model` when configured; fall back to main model on auxiliary failure (not to `_simple_summary`).
|
||||
|
||||
**Requirements:** R13, R14, R15, R26, R27
|
||||
|
||||
**Dependencies:** none (self-contained)
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/core/compressor.py` (add `auxiliary_model` param + routing in `_summarize`)
|
||||
- Modify: `src/agentkit/llm/config.py` (add `auxiliary_model: str | None` field to `LLMConfig`)
|
||||
- Modify: `src/agentkit/server/config.py` (`_build_llm_config` reads `auxiliary_model` from llm section)
|
||||
- Modify: `agentkit.yaml` (document `llm.auxiliary_model: fast` in llm section)
|
||||
- Create: `tests/unit/test_compressor_auxiliary.py`
|
||||
|
||||
**Approach:** `ContextCompressor.__init__` gains `auxiliary_model: str | None = None` param (after existing `model` param). `_summarize` (compressor.py:123-158) restructuring:
|
||||
|
||||
1. If `auxiliary_model` is set and differs from `model`, try `auxiliary_model` first via `self._llm_gateway.chat(model=self._auxiliary_model, ...)`.
|
||||
2. **Truthiness check (Finding 4 anti-pattern):** treat empty `response.content` (None or whitespace-only) as failure, not success. On failure, log and fall through to main model.
|
||||
3. If auxiliary succeeds with non-empty content, return it.
|
||||
4. If auxiliary fails (exception OR empty content), retry with main `self._model`.
|
||||
5. Existing `except Exception → _simple_summary` block remains as the final degradation tier.
|
||||
|
||||
`LLMConfig.auxiliary_model` reads from `data.get("auxiliary_model")` in `_build_llm_config`. Agentkit.yaml already declares `fast: bailian-coding/qwen-turbo` alias — this is the canonical auxiliary target.
|
||||
|
||||
**Execution note:** characterization-first. Capture current `_summarize` behavior (single model, exception → simple_summary) with `auxiliary_model=None` tests, then add `auxiliary_model="fast"` tests for new routing behavior.
|
||||
|
||||
**Patterns to follow:**
|
||||
- `LLMGateway._get_models_to_try` (`gateway.py:170-227`) — existing fallback chain pattern
|
||||
- Wave 1's `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections, `config.py:239-273`)
|
||||
|
||||
**Test scenarios:**
|
||||
- Happy path: `auxiliary_model="fast"` set, auxiliary returns non-empty content → result is auxiliary content, main model not called
|
||||
- Empty content fallback: auxiliary returns `content=""` (or `None`) → main model called, main content returned (covers Finding 4 anti-pattern)
|
||||
- Auxiliary exception: auxiliary raises `LLMProviderError` → main model called, main content returned
|
||||
- Both fail: auxiliary raises, main raises → `_simple_summary` returned (existing degradation preserved)
|
||||
- Characterization: `auxiliary_model=None` → behavior matches current code (single model call)
|
||||
- Config wiring: `LLMConfig.from_dict` reads `auxiliary_model` field from dict; `ServerConfig._build_llm_config` passes it through
|
||||
- Audit: auxiliary call uses `agent_name="compressor"`, `task_type="summarization"` (preserved for usage tracking)
|
||||
|
||||
**Verification:** All test scenarios pass; existing `tests/unit/test_compressor*.py` (if any) still pass with `auxiliary_model=None` default; ruff clean.
|
||||
|
||||
---
|
||||
|
||||
### U2. G7 — Emergency Layer Rule Template + TaskResult Extension
|
||||
|
||||
**Goal:** Add rule-based Emergency classifier with structured error output. No wiring yet — just the infrastructure (classifier + data structure + `fallback.py` extension).
|
||||
|
||||
**Requirements:** R17, R18, R26, R27
|
||||
|
||||
**Dependencies:** none (foundation for U3)
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/core/fallback.py` (add `EmergencyRules` class + `EmergencyError` dataclass, preserve existing 3 constants)
|
||||
- Modify: `src/agentkit/core/protocol.py` (add `error_struct: dict | None = None` field to `TaskResult`, update `to_dict`)
|
||||
- Create: `tests/unit/test_emergency_rules.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
`EmergencyError` dataclass (in `fallback.py`):
|
||||
```
|
||||
@dataclass
|
||||
class EmergencyError:
|
||||
error_code: str # "timeout"|"loop_detected"|"llm_failure"|"internal_error"
|
||||
message: str # human-readable Chinese message (mirrors EMPTY_LLM_RESPONSE style)
|
||||
suggestions: list[str] # actionable user-facing suggestions
|
||||
retryable: bool # whether user retry might succeed
|
||||
original_error: str # str(exc) for traceability
|
||||
|
||||
def to_dict(self) -> dict: ...
|
||||
def to_error_message(self) -> str: ... # formatted "message\n建议:1) ... 2) ..."
|
||||
```
|
||||
|
||||
`EmergencyRules` class (rule-based classifier, no LLM):
|
||||
```
|
||||
class EmergencyRules:
|
||||
@staticmethod
|
||||
def classify(exc: Exception, config: dict | None = None) -> EmergencyError:
|
||||
# Match by exception type (TaskTimeoutError, LoopDetectedError, LLMProviderError, etc.)
|
||||
# config allows per-rule customization (suggestion overrides, retryable overrides)
|
||||
```
|
||||
|
||||
Rule mapping (initial set, expandable via config):
|
||||
- `TaskTimeoutError` → `error_code="timeout"`, `retryable=True`, suggestions: ["稍后重试", "简化任务范围"]
|
||||
- `LoopDetectedError` → `error_code="loop_detected"`, `retryable=True`, suggestions: ["拆分任务", "检查工具参数"]
|
||||
- `LLMProviderError` → `error_code="llm_failure"`, `retryable=True`, suggestions: ["稍后重试", "切换模型"]
|
||||
- `TaskCancelledError` → not classified (propagates as-is, Emergency not triggered)
|
||||
- Generic `Exception` → `error_code="internal_error"`, `retryable=False`, suggestions: ["联系管理员"]
|
||||
|
||||
`TaskResult` extension:
|
||||
```
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
# ... existing fields ...
|
||||
error_message: str | None # unchanged
|
||||
error_struct: dict | None = None # NEW: serialized EmergencyError.to_dict()
|
||||
```
|
||||
|
||||
`to_dict` includes `error_struct` when set; `error_message` continues to hold the human-readable string.
|
||||
|
||||
**Patterns to follow:**
|
||||
- Existing `EMPTY_LLM_RESPONSE` / `MAX_STEPS_REACHED` style in `fallback.py` (Chinese, "建议:..." format)
|
||||
- `VerificationResult.errors: list[str]` field pattern from `verification_loop.py:18-24`
|
||||
- `TaskResult.to_dict` pattern at `protocol.py:132-145`
|
||||
|
||||
**Test scenarios:**
|
||||
- Happy path: `classify(TaskTimeoutError(...))` returns `EmergencyError(error_code="timeout", retryable=True)`
|
||||
- Each exception type maps to correct `error_code`
|
||||
- `TaskCancelledError` is NOT classified (caller must check before invoking `classify`)
|
||||
- `to_dict()` produces all 5 fields
|
||||
- `to_error_message()` formats suggestions as "建议:1) ... 2) ..."
|
||||
- `EmergencyRules.classify(Exception("unknown"))` → `error_code="internal_error"`, `retryable=False`
|
||||
- Config override: custom suggestion list for `timeout` rule via config dict
|
||||
- `TaskResult` with `error_struct` set: `to_dict()` includes both `error_message` and `error_struct`
|
||||
- `TaskResult` with `error_struct=None` (default): `to_dict()` matches current behavior (backward compat)
|
||||
|
||||
**Verification:** All scenarios pass; `fallback.py` existing 3 constants unchanged; `TaskResult.to_dict` for tasks without `error_struct` matches pre-change output byte-for-byte.
|
||||
|
||||
---
|
||||
|
||||
### U3. G7 — Three-Tier Fallback Chain Wiring
|
||||
|
||||
**Goal:** Wire main → Recovery (ReflexionEngine) → Emergency (EmergencyRules) at `chat.py:613`. Composes U2's infrastructure with existing `ReflexionEngine`.
|
||||
|
||||
**Requirements:** R16, R18, R26
|
||||
|
||||
**Dependencies:** U2 (EmergencyRules + TaskResult.error_struct)
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/server/routes/chat.py` (wrap main agent call at L613 with three-tier chain)
|
||||
- Modify: `src/agentkit/server/config.py` (add `fallback_chain` section to `from_dict`)
|
||||
- Modify: `agentkit.yaml` (document `fallback_chain:` section)
|
||||
- Create: `tests/unit/test_fallback_chain.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
New helper module or inline function in `chat.py`:
|
||||
```
|
||||
async def execute_with_fallback_chain(
|
||||
agent: ConfigDrivenAgent,
|
||||
task: Task,
|
||||
config: dict, # fallback_chain config section
|
||||
llm_gateway: LLMGateway,
|
||||
) -> TaskResult:
|
||||
# Tier 1: Main
|
||||
try:
|
||||
result = await agent.execute(task)
|
||||
if result.status == "success" or result.status == "completed":
|
||||
return result
|
||||
# Treat non-success as soft failure → trigger Recovery
|
||||
raise AgentExecutionError(result.error_message or "main agent did not succeed")
|
||||
except (TaskTimeoutError, LoopDetectedError, LLMProviderError, AgentExecutionError) as exc:
|
||||
if not config.get("recovery", {}).get("enabled", True):
|
||||
return _to_emergency(exc, task)
|
||||
|
||||
# Tier 2: Recovery (ReflexionEngine)
|
||||
try:
|
||||
reflexion_engine = ReflexionEngine(
|
||||
llm_gateway=llm_gateway,
|
||||
max_reflections=config.get("recovery", {}).get("max_retries", 1),
|
||||
# ... other params from agent's existing reflexion config
|
||||
)
|
||||
recovery_result = await reflexion_engine.execute(
|
||||
messages=task.messages, # rebuild from task
|
||||
tools=agent.get_tools(),
|
||||
model=task.model,
|
||||
agent_name=agent.name,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
if recovery_result.status == "success":
|
||||
return _recovery_to_task_result(recovery_result, task)
|
||||
except Exception as recovery_exc:
|
||||
logger.warning(f"Recovery layer failed: {recovery_exc}")
|
||||
|
||||
# Tier 3: Emergency
|
||||
return _to_emergency(exc, task)
|
||||
```
|
||||
|
||||
`_to_emergency(exc, task)` constructs `EmergencyError` via `EmergencyRules.classify(exc, config)`, then returns `TaskResult(status=FAILED, error_message=emergency.to_error_message(), error_struct=emergency.to_dict(), ...)`.
|
||||
|
||||
Wiring at `chat.py:613`: replace direct `agent.execute(task)` call with `execute_with_fallback_chain(...)`. Recovery config from `server_config.fallback_chain` (new `ServerConfig.fallback_chain: dict` field, mirroring Wave 1 pattern).
|
||||
|
||||
**Recovery layer scope (KTD5):** Only `chat.py:613` is wrapped. CLI (`cli/chat.py:338`), ReWOO (`rewoo.py:1204`), ReflexionEngine's internal ReAct call, and `config_driven.py:695` are NOT wrapped (would create recursive loop or unwanted coupling). These remain on the direct-execute path. Documented in `## Scope Boundaries`.
|
||||
|
||||
**Patterns to follow:**
|
||||
- `config_driven.py:836` — existing `ReflexionEngine` instantiation pattern (constructor params, execute call)
|
||||
- Wave 1's `verification` config section in `ServerConfig.from_dict` (`config.py:240-273`)
|
||||
|
||||
**Test scenarios:**
|
||||
- Happy path: main agent succeeds → no Recovery, no Emergency triggered; `error_struct=None`
|
||||
- Main fails (timeout) → Recovery triggered → Recovery succeeds → `error_struct=None`, output from ReflexionEngine
|
||||
- Main fails (timeout) → Recovery triggered → Recovery fails (max_reflections exhausted) → Emergency triggered → `error_struct` populated with `error_code="timeout"`, `retryable=True`
|
||||
- Main fails (LoopDetectedError) → Emergency `error_code="loop_detected"`
|
||||
- Main fails (LLMProviderError) → Emergency `error_code="llm_failure"`
|
||||
- Main fails (TaskCancelledError) → propagates as-is (NOT routed to Emergency)
|
||||
- Main fails (generic Exception) → Emergency `error_code="internal_error"`, `retryable=False`
|
||||
- Config: `fallback_chain.recovery.enabled=false` → skip Recovery, go directly to Emergency
|
||||
- Config: `fallback_chain.emergency.enabled=false` → re-raise original exception (no Emergency)
|
||||
- Integration: full chain on real `ConfigDrivenAgent` with mocked LLM (mock main raises, mock ReflexionEngine succeeds)
|
||||
|
||||
**Verification:** All scenarios pass; existing chat WebSocket tests still pass; ruff clean.
|
||||
|
||||
---
|
||||
|
||||
### U4. G9 — PlanPhase Rollback Fields + RollbackExecutor + TeamOrchestrator Integration
|
||||
|
||||
**Goal:** Add `validation_command`/`rollback_command` optional fields to `PlanPhase`; execute rollback on phase failure (opt-in); coordinate with U7 checkpoint ordering per R21.
|
||||
|
||||
**Requirements:** R19, R20, R21, R26, R27
|
||||
|
||||
**Dependencies:** none (uses existing U7 `PipelineCheckpoint` and `VerificationLoop` patterns)
|
||||
|
||||
**Files:**
|
||||
- Modify: `src/agentkit/experts/plan.py` (`PlanPhase` dataclass + `to_dict`/`from_dict` symmetry)
|
||||
- Modify: `src/agentkit/experts/orchestrator.py` (insert rollback execution between phase failure and checkpoint save)
|
||||
- Create: `src/agentkit/orchestrator/rollback.py` (`RollbackExecutor` class)
|
||||
- Modify: `src/agentkit/server/config.py` (add `rollback` section to `from_dict`)
|
||||
- Modify: `agentkit.yaml` (document `rollback:` section)
|
||||
- Create: `tests/unit/test_phase_rollback.py`
|
||||
|
||||
**Approach:**
|
||||
|
||||
**PlanPhase extension (`plan.py:147-233`):** Add two optional fields with `None` default (preserves existing contract per Finding 1):
|
||||
```
|
||||
validation_command: str | None = None
|
||||
rollback_command: str | None = None
|
||||
```
|
||||
Update `to_dict()` to include both fields (only when not None, to keep dict shape minimal). Update `from_dict()` to read both fields.
|
||||
|
||||
**`RollbackExecutor` class (new file `orchestrator/rollback.py`):** Mirrors `VerificationLoop` pattern (`verification_loop.py:67-103`):
|
||||
```
|
||||
class RollbackExecutor:
|
||||
def __init__(self, working_dir: str | None = None, timeout: float = 30.0): ...
|
||||
|
||||
async def execute(self, command: str) -> RollbackResult:
|
||||
# asyncio.create_subprocess_shell with cwd/timeout/proc.kill()
|
||||
# Returns RollbackResult(passed, exit_code, stdout, stderr, command)
|
||||
|
||||
async def validate(self, command: str) -> RollbackResult:
|
||||
# Same as execute but returns passed=False on non-zero exit
|
||||
```
|
||||
|
||||
`RollbackResult` dataclass: `{passed: bool, exit_code: int, stdout: str, stderr: str, command: str}`.
|
||||
|
||||
**TeamOrchestrator integration (`orchestrator.py:246-270`):** New ordering per KTD8:
|
||||
|
||||
```
|
||||
# Existing: phase failure detected, mark FAILED + dependents FAILED
|
||||
# (lines 247-261 unchanged)
|
||||
|
||||
# NEW: rollback phase (only if validation_command and rollback_command configured)
|
||||
should_save_checkpoint = True
|
||||
if ph.validation_command and ph.rollback_command:
|
||||
validator = RollbackExecutor(working_dir=self._workspace_root, timeout=...)
|
||||
validation_result = await validator.validate(ph.validation_command)
|
||||
if not validation_result.passed:
|
||||
rollback_result = await validator.execute(ph.rollback_command)
|
||||
if not rollback_result.passed:
|
||||
# Rollback failed → don't save checkpoint (R21)
|
||||
should_save_checkpoint = False
|
||||
logger.error(f"Rollback failed for phase {ph.id}: {rollback_result.stderr}")
|
||||
# Emit phase_rollback_failed event
|
||||
await self._broadcast_event("phase_rollback_failed", {...})
|
||||
|
||||
# Existing: checkpoint save (conditional now)
|
||||
if should_save_checkpoint and self._checkpoint is not None:
|
||||
try:
|
||||
await self._checkpoint.save(plan.id, ph, plan.status.value)
|
||||
except Exception as e:
|
||||
logger.warning(...)
|
||||
```
|
||||
|
||||
**Events emitted (new):** `phase_rollback_started`, `phase_rollback_completed`, `phase_rollback_failed`. Existing `phase_failed` event unchanged (emitted before rollback).
|
||||
|
||||
**Config:** `rollback.default_timeout: float = 30.0` in `ServerConfig.rollback` dict (Wave 1 pattern).
|
||||
|
||||
**Security boundary (KTD7):** Rollback subprocess does NOT go through `ShellTool` (avoids `confirm_callback` for `git checkout`). Audit log emitted via `_broadcast_event`. `terminal_whitelist` is not consulted (only governs `/api/v1/terminal/server` route, per `terminal_server.py:227`).
|
||||
|
||||
**Execution note:** characterization-first. Test current behavior (`rollback_command=None` → no rollback, checkpoint saved) before adding rollback behavior.
|
||||
|
||||
**Patterns to follow:**
|
||||
- `VerificationLoop` subprocess execution pattern (`verification_loop.py:67-103`) — `asyncio.create_subprocess_shell` + `cwd` + `timeout` + `proc.kill()`
|
||||
- `PlanPhase.to_dict`/`from_dict` symmetric serialization pattern at `plan.py:186-233`
|
||||
- `TeamOrchestrator._execute_phase` event broadcasting pattern (`orchestrator.py:253-260`)
|
||||
- Wave 1's `ServerConfig.from_dict` extension template
|
||||
|
||||
**Test scenarios:**
|
||||
- Characterization: `PlanPhase()` with no `validation_command`/`rollback_command` → `to_dict()` output matches pre-change shape (no new keys); `from_dict({})` produces phase with both fields as None
|
||||
- Serialization: phase with `rollback_command="git checkout foo.py"` → `to_dict()` includes key; `from_dict(to_dict())` round-trips
|
||||
- RollbackExecutor happy path: `execute("git status")` returns `passed=True`, `exit_code=0`
|
||||
- RollbackExecutor timeout: `execute("sleep 10", timeout=0.1)` returns `passed=False`, exit_code=-1 (or similar)
|
||||
- RollbackExecutor failure: `execute("false")` returns `passed=False`, `exit_code=1`
|
||||
- Integration — opt-in default: phase fails, `rollback_command=None` → no RollbackExecutor call, checkpoint saved (existing behavior)
|
||||
- Integration — rollback configured, validation passes (returns 0): rollback NOT executed, checkpoint saved
|
||||
- Integration — rollback configured, validation fails, rollback succeeds (returns 0): rollback executed, checkpoint saved
|
||||
- Integration — rollback configured, validation fails, rollback fails (returns 1): rollback executed, checkpoint NOT saved (R21), `phase_rollback_failed` event emitted
|
||||
- Integration — real git repo fixture: phase writes file via `git checkout`, rollback `git checkout foo.py` restores file; assert file content matches pre-phase state
|
||||
|
||||
**Verification:** All scenarios pass; existing `tests/unit/test_pipeline_state.py` and `tests/unit/test_team_orchestrator*.py` (if any) still pass; ruff clean.
|
||||
|
||||
---
|
||||
|
||||
## Scope Boundaries
|
||||
|
||||
### Deferred to Follow-Up Work
|
||||
|
||||
- **Recovery wiring at non-chat call sites** (`cli/chat.py:338`, `rewoo.py:1204`, `config_driven.py:695`). KTD5 limits Wave 2 to the primary chat WebSocket path. Other entry points can adopt the same wrapper in a follow-up.
|
||||
- **Recovery layer streaming events.** Wave 2's chain returns final `TaskResult`; SSE events for "recovery_started"/"recovery_completed" would require deeper WebSocket protocol changes.
|
||||
- **Patch-level rollback** (`git apply -R <patch>`). KTD6 + KTD7 scope Wave 2 to `git checkout <files>` pattern. Patch-level requires extending `CheckpointData` schema to track patches — Wave 3 candidate.
|
||||
- **DB-backed atomic claim for checkpoint.** Finding 3 noted `PipelineCheckpoint` is in-memory dict + Redis fallback, not a PostgreSQL `FOR UPDATE SKIP LOCKED` pattern. Multi-process atomicity is out of scope; single-process `asyncio.Lock` (already present at `orchestrator.py:53`) is sufficient.
|
||||
|
||||
### Outside this product's identity (carried from brainstorm)
|
||||
|
||||
- Wave 3 (G5/G6) — tree-sitter integration, SOLO four-stage state machine. Strategic direction locked in brainstorm KTD6/KTD7; implementation design deferred to Wave 3 plan.
|
||||
- Node-level checkpoint (ReAct single-step). Stage-level (U7) satisfies core need.
|
||||
- DeerFlow-style disk filesystem. Redis is the persistence layer.
|
||||
- Full LangGraph migration. Self-built architecture stays.
|
||||
|
||||
---
|
||||
|
||||
## Risks & Dependencies
|
||||
|
||||
- **Risk: Auxiliary model availability.** If `auxiliary_model` alias is not configured in `agentkit.yaml`, `LLMGateway` raises `ModelNotFoundError`. Mitigation: `ContextCompressor` catches this in the auxiliary try-block and falls through to main model (R15 fallback path). Documented in YAML comments.
|
||||
- **Risk: Recovery layer recursive loop.** `ReflexionEngine.execute()` internally calls `ReActEngine.execute()` (5th call site). If Recovery itself fails, it must NOT trigger another Recovery — KTD5 wires only at `chat.py:613`, so ReflexionEngine's internal ReAct call bypasses the chain. Recursive loop is structurally impossible.
|
||||
- **Risk: `git checkout` destructive scope.** Misconfigured `rollback_command` (e.g., `git checkout .`) could wipe unrelated changes. Mitigation: rollback is opt-in per-phase (KTD6); audit log via `phase_rollback_started`/`phase_rollback_failed` events; documented YAML examples use file-scoped commands (`git checkout <specific_files>`).
|
||||
- **Risk: `TaskResult.error_struct` field addition breaks pickle/serialization.** `TaskResult` is a dataclass with `to_dict`; new optional field with `None` default is backward-compatible for `to_dict` consumers. Pickle deserialization of pre-change data still works (new field defaults to None).
|
||||
- **Dependency: Wave 1 merged.** `ServerConfig.from_dict` extension template (prompt_cache/streaming/verification sections) is established and tested. Wave 2's 3 new sections (auxiliary_model, fallback_chain, rollback) reuse this exact pattern.
|
||||
- **Dependency: U7 PipelineCheckpoint already shipped.** `orchestrator/checkpoint.py:56` exists with `save(plan_id, phase, plan_status)` API; U4 only adjusts the calling order, not the API.
|
||||
|
||||
---
|
||||
|
||||
## Sources / Research
|
||||
|
||||
- Brainstorm: `docs/brainstorms/2026-06-29-advanced-agent-gap-optimization-requirements.md` (Wave 2 section, R13-R21, KTD4/KTD5)
|
||||
- Wave 1 plan: `docs/plans/2026-06-29-002-feat-agent-wave1-quick-wins-plan.md` (ServerConfig.from_dict extension template established)
|
||||
- Finding 1 (contract preservation rule): `docs/solutions/logic-errors/long-horizon-reliability-code-review-fixes.md` — "新字段默认值须保持既有契约" drives KTD6's opt-in default for `rollback_command`
|
||||
- Finding 2 (retry-storm defense): `docs/solutions/security-issues/portal-platform-security-reliability-fixes.md` — Emergency layer's terminal `error_code` pattern (don't propagate unhandled exceptions)
|
||||
- Finding 3 (atomic claim pattern): `docs/solutions/architecture-patterns/bitable-companion-service-security-reliability-patterns.md` — `SKIP LOCKED` not reusable; `PipelineCheckpoint` is in-memory dict + Redis
|
||||
- Finding 4 (empty-response anti-pattern): `docs/solutions/ui-bugs/tauri-reload-loses-session.md` — auxiliary LLM must check truthy return value, not just absence of exception; drives U1's empty-content fallback
|
||||
- Code locations verified during planning:
|
||||
- `src/agentkit/core/compressor.py:35-44,123-158` — `_summarize` injection point for G4
|
||||
- `src/agentkit/llm/gateway.py:170-227` — existing per-model fallback chain (semantics differ from G4's task-type routing)
|
||||
- `src/agentkit/llm/config.py:198-257` — `LLMConfig` dataclass and `from_dict`
|
||||
- `src/agentkit/core/reflexion.py:68-111` — `ReflexionEngine.__init__` and `execute()` signatures (already supports `evaluate_model`/`reflect_model`)
|
||||
- `src/agentkit/core/fallback.py` (full file, 19 lines) — 3 existing constants; `EmergencyRules` adds alongside
|
||||
- `src/agentkit/core/protocol.py:118-145` — `TaskResult` dataclass and `to_dict`
|
||||
- `src/agentkit/experts/plan.py:147-233` — `PlanPhase` dataclass, `to_dict`/`from_dict`
|
||||
- `src/agentkit/experts/orchestrator.py:246-270` — phase failure capture + checkpoint save site (U4 integration point)
|
||||
- `src/agentkit/orchestrator/checkpoint.py:56` — `PipelineCheckpoint` (U7, no API change needed)
|
||||
- `src/agentkit/core/verification_loop.py:67-103` — subprocess execution pattern for `RollbackExecutor`
|
||||
- `src/agentkit/tools/shell.py:168-174,526-534` — `_DANGEROUS_BINARY_FLAGS` (git checkout not listed), `_is_dangerous` logic (KTD7 rationale)
|
||||
|
|
@ -41,6 +41,7 @@ class ContextCompressor:
|
|||
model_context_limit: int = 128_000,
|
||||
headroom_threshold: float = 0.8,
|
||||
min_tokens: int = 8_000,
|
||||
auxiliary_model: str | None = None,
|
||||
):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_tokens = max_tokens
|
||||
|
|
@ -51,6 +52,11 @@ class ContextCompressor:
|
|||
self._model_context_limit = model_context_limit
|
||||
self._headroom_threshold = headroom_threshold
|
||||
self._min_tokens = min_tokens
|
||||
# G4/U1: Auxiliary model for cost-sensitive summarization (e.g. "fast" alias).
|
||||
# When set and differs from main model, _summarize tries auxiliary first,
|
||||
# falls back to main model on failure OR empty content (Finding 4 anti-pattern).
|
||||
# ponytail: ceiling — auxiliary is best-effort; main model is authoritative fallback.
|
||||
self._auxiliary_model = auxiliary_model
|
||||
|
||||
def should_compress(self, messages: list[dict]) -> bool:
|
||||
"""Check if compression should be triggered based on headroom ratio.
|
||||
|
|
@ -101,10 +107,12 @@ class ContextCompressor:
|
|||
# Build compressed message list
|
||||
compressed = list(system_msgs)
|
||||
if summary:
|
||||
compressed.append({
|
||||
compressed.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"## Conversation Summary\n{summary}",
|
||||
})
|
||||
}
|
||||
)
|
||||
compressed.extend(recent_msgs)
|
||||
|
||||
# Recursive check: if still over budget, compress again
|
||||
|
|
@ -114,22 +122,30 @@ class ContextCompressor:
|
|||
return self._truncate(compressed)
|
||||
if len(recent_msgs) > 1:
|
||||
# Try keeping fewer recent messages
|
||||
return await self._compress_aggressive(messages, _compression_depth=_compression_depth + 1)
|
||||
return await self._compress_aggressive(
|
||||
messages, _compression_depth=_compression_depth + 1
|
||||
)
|
||||
# Last resort: truncate
|
||||
return self._truncate(compressed)
|
||||
|
||||
return compressed
|
||||
|
||||
async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str:
|
||||
"""Summarize a list of messages using LLM"""
|
||||
"""Summarize a list of messages using LLM.
|
||||
|
||||
G4/U1: When ``auxiliary_model`` is configured and differs from the main
|
||||
model, try auxiliary first (cost-optimization). On auxiliary failure OR
|
||||
empty content (Finding 4 anti-pattern — "did not throw is not succeeded"),
|
||||
fall back to main model. Existing ``_simple_summary`` degradation
|
||||
preserved as the final tier when main model also fails.
|
||||
"""
|
||||
if not self._llm_gateway:
|
||||
# No LLM available, do simple truncation
|
||||
return self._simple_summary(messages)
|
||||
|
||||
# Build summary prompt
|
||||
conversation_text = "\n".join(
|
||||
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}"
|
||||
for m in messages
|
||||
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" for m in messages
|
||||
)
|
||||
|
||||
# Pre-truncate if conversation_text exceeds safe token threshold
|
||||
|
|
@ -145,6 +161,25 @@ class ContextCompressor:
|
|||
f"{conversation_text}"
|
||||
)
|
||||
|
||||
# G4: Try auxiliary model first when configured (cheap route).
|
||||
if self._auxiliary_model and self._auxiliary_model != self._model:
|
||||
try:
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model=self._auxiliary_model,
|
||||
agent_name="compressor",
|
||||
task_type="summarization",
|
||||
)
|
||||
# Finding 4: empty content is a failure, not a success.
|
||||
if response.content and response.content.strip():
|
||||
return response.content
|
||||
logger.info("Auxiliary model returned empty content, falling back to main model")
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
f"Auxiliary model summarization failed, falling back to main model: {e}"
|
||||
)
|
||||
|
||||
# Main model path (or auxiliary fallback).
|
||||
try:
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
|
|
@ -166,7 +201,9 @@ class ContextCompressor:
|
|||
parts.append(f"[{role}]: {content}...")
|
||||
return "\n".join(parts)
|
||||
|
||||
async def _compress_aggressive(self, messages: list[dict], _compression_depth: int = 0) -> list[dict]:
|
||||
async def _compress_aggressive(
|
||||
self, messages: list[dict], _compression_depth: int = 0
|
||||
) -> list[dict]:
|
||||
"""More aggressive compression when standard compression isn't enough"""
|
||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||
non_system = [m for m in messages if m.get("role") != "system"]
|
||||
|
|
@ -176,10 +213,12 @@ class ContextCompressor:
|
|||
summary = await self._summarize(non_system[:-1])
|
||||
compressed = list(system_msgs)
|
||||
if summary:
|
||||
compressed.append({
|
||||
compressed.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"## Conversation Summary\n{summary}",
|
||||
})
|
||||
}
|
||||
)
|
||||
compressed.append(non_system[-1])
|
||||
return compressed
|
||||
|
||||
|
|
@ -226,6 +265,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
|
|||
if provider == "headroom":
|
||||
try:
|
||||
from agentkit.core.headroom_compressor import HeadroomCompressor
|
||||
|
||||
compressor = HeadroomCompressor(config)
|
||||
if compressor.is_available():
|
||||
return compressor
|
||||
|
|
@ -235,8 +275,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
|
|||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"HeadroomCompressor module not available. "
|
||||
"Falling back to ContextCompressor."
|
||||
"HeadroomCompressor module not available. Falling back to ContextCompressor."
|
||||
)
|
||||
# Fallback to summary compressor
|
||||
return ContextCompressor(
|
||||
|
|
@ -253,11 +292,9 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
|
|||
|
||||
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
|
||||
"""Render PromptTemplate with caching - returns cached result for same variables"""
|
||||
cache_key = hashlib.md5(
|
||||
json.dumps(variables or {}, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
cache_key = hashlib.md5(json.dumps(variables or {}, sort_keys=True).encode()).hexdigest()
|
||||
|
||||
if not hasattr(template, '_render_cache'):
|
||||
if not hasattr(template, "_render_cache"):
|
||||
template._render_cache = {}
|
||||
|
||||
if cache_key in template._render_cache:
|
||||
|
|
@ -270,5 +307,5 @@ def render_cached(template, variables: dict[str, Any] | None = None) -> list[dic
|
|||
|
||||
def clear_cache(template) -> None:
|
||||
"""Clear the render cache on a PromptTemplate instance"""
|
||||
if hasattr(template, '_render_cache'):
|
||||
if hasattr(template, "_render_cache"):
|
||||
template._render_cache.clear()
|
||||
|
|
|
|||
|
|
@ -2,8 +2,22 @@
|
|||
|
||||
All layers (ReActEngine, Portal, Chat) should use these constants
|
||||
to ensure consistent user-facing messages.
|
||||
|
||||
G7/U2: Also hosts ``EmergencyError`` and ``EmergencyRules`` for the
|
||||
three-tier fallback chain's Emergency layer (rule-based classifier,
|
||||
no LLM). See ``docs/plans/2026-06-29-003-feat-agent-wave2-medium-coupling-plan.md``.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.exceptions import (
|
||||
LLMProviderError,
|
||||
LoopDetectedError,
|
||||
TaskCancelledError,
|
||||
TaskTimeoutError,
|
||||
)
|
||||
|
||||
# When LLM returns empty content after all fallback models exhausted
|
||||
EMPTY_LLM_RESPONSE = (
|
||||
"模型未返回有效内容,已尝试备用模型仍未成功。"
|
||||
|
|
@ -16,3 +30,135 @@ MAX_STEPS_REACHED = "已达到最大推理步数,但仍未得到完整结论
|
|||
|
||||
# When a shell command succeeds but produces no output
|
||||
SHELL_NO_OUTPUT = "[命令执行成功,无输出内容]"
|
||||
|
||||
|
||||
# ── G7/U2: Emergency layer ──────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmergencyError:
|
||||
"""Structured error produced by the Emergency layer (rule-classified).
|
||||
|
||||
Carries a stable ``error_code`` for programmatic dispatch (frontend
|
||||
retry UI, telemetry), a human-readable ``message`` mirroring
|
||||
``EMPTY_LLM_RESPONSE`` style, actionable ``suggestions``, and the
|
||||
original exception string for traceability.
|
||||
|
||||
The ``retryable`` flag distinguishes recoverable user errors
|
||||
(timeout, loop, LLM hiccup) from internal bugs (retryable=False).
|
||||
"""
|
||||
|
||||
error_code: str # "timeout" | "loop_detected" | "llm_failure" | "internal_error"
|
||||
message: str # human-readable Chinese message
|
||||
suggestions: list[str] # actionable user-facing suggestions
|
||||
retryable: bool # whether a user retry might succeed
|
||||
original_error: str # str(exc) for traceability
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"error_code": self.error_code,
|
||||
"message": self.message,
|
||||
"suggestions": list(self.suggestions),
|
||||
"retryable": self.retryable,
|
||||
"original_error": self.original_error,
|
||||
}
|
||||
|
||||
def to_error_message(self) -> str:
|
||||
"""Format as a single human-readable string with suggestions.
|
||||
|
||||
Mirrors ``EMPTY_LLM_RESPONSE`` style: ``<message>建议:1) ... 2) ...``
|
||||
"""
|
||||
if not self.suggestions:
|
||||
return self.message
|
||||
suggestion_str = "".join(f"{i}) {s};" for i, s in enumerate(self.suggestions, 1))
|
||||
# Strip trailing ";" and prefix with "建议:"
|
||||
suggestion_str = suggestion_str.rstrip(";")
|
||||
return f"{self.message}建议:{suggestion_str}。"
|
||||
|
||||
|
||||
# Default rule set: (exception_type, error_code, message, suggestions, retryable)
|
||||
# ponytail: ceiling — rule-based, no LLM. Adding a new rule = append a tuple.
|
||||
# Upgrade path: LLM-driven suggestions would require a separate classifier
|
||||
# class (out of scope per brainstorm KTD4).
|
||||
_DEFAULT_RULES: list[tuple[type[Exception], str, str, list[str], bool]] = [
|
||||
(
|
||||
TaskTimeoutError,
|
||||
"timeout",
|
||||
"任务执行超时。",
|
||||
["稍后重试", "简化任务范围"],
|
||||
True,
|
||||
),
|
||||
(
|
||||
LoopDetectedError,
|
||||
"loop_detected",
|
||||
"检测到推理循环。",
|
||||
["拆分任务", "检查工具参数"],
|
||||
True,
|
||||
),
|
||||
(
|
||||
LLMProviderError,
|
||||
"llm_failure",
|
||||
"LLM 服务调用失败。",
|
||||
["稍后重试", "切换模型"],
|
||||
True,
|
||||
),
|
||||
]
|
||||
|
||||
_DEFAULT_ERROR_CODE = "internal_error"
|
||||
_DEFAULT_MESSAGE = "Agent 执行内部错误。"
|
||||
_DEFAULT_SUGGESTIONS: list[str] = ["联系管理员"]
|
||||
_DEFAULT_RETRYABLE = False
|
||||
|
||||
|
||||
class EmergencyRules:
|
||||
"""Rule-based classifier for the Emergency layer.
|
||||
|
||||
Maps exception types to ``EmergencyError`` instances. No LLM, no I/O —
|
||||
pure function of ``(exception, config)``.
|
||||
|
||||
Caller responsibility: ``TaskCancelledError`` MUST propagate as-is
|
||||
(per KTD3); the caller checks ``isinstance(exc, TaskCancelledError)``
|
||||
before invoking :meth:`classify`. Calling :meth:`classify` with a
|
||||
``TaskCancelledError`` raises ``ValueError`` to surface the bug.
|
||||
|
||||
Config override shape (optional ``config`` arg):
|
||||
|
||||
```python
|
||||
{
|
||||
"timeout": {"suggestions": ["自定义建议"], "retryable": False},
|
||||
"llm_failure": {"message": "自定义消息"},
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def classify(exc: Exception, config: dict | None = None) -> EmergencyError:
|
||||
if isinstance(exc, TaskCancelledError):
|
||||
# Contract: caller must check before invoking classify.
|
||||
raise ValueError(
|
||||
"TaskCancelledError must propagate as-is; caller must check "
|
||||
"before invoking EmergencyRules.classify"
|
||||
)
|
||||
|
||||
config = config or {}
|
||||
|
||||
for exc_type, code, message, suggestions, retryable in _DEFAULT_RULES:
|
||||
if isinstance(exc, exc_type):
|
||||
override = config.get(code, {}) if isinstance(config, dict) else {}
|
||||
return EmergencyError(
|
||||
error_code=code,
|
||||
message=override.get("message", message),
|
||||
suggestions=override.get("suggestions", list(suggestions)),
|
||||
retryable=override.get("retryable", retryable),
|
||||
original_error=str(exc),
|
||||
)
|
||||
|
||||
# Generic fallback
|
||||
override = config.get(_DEFAULT_ERROR_CODE, {}) if isinstance(config, dict) else {}
|
||||
return EmergencyError(
|
||||
error_code=_DEFAULT_ERROR_CODE,
|
||||
message=override.get("message", _DEFAULT_MESSAGE),
|
||||
suggestions=override.get("suggestions", list(_DEFAULT_SUGGESTIONS)),
|
||||
retryable=override.get("retryable", _DEFAULT_RETRYABLE),
|
||||
original_error=str(exc),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -128,6 +128,10 @@ class TaskResult:
|
|||
completed_at: datetime
|
||||
metrics: dict | None = None
|
||||
trace: Any | None = None
|
||||
# G7/U2: Emergency layer structured error. None preserves existing contract
|
||||
# (error_message alone carries the human-readable string). When set,
|
||||
# carries serialized EmergencyError.to_dict() for programmatic dispatch.
|
||||
error_struct: dict | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
d = {
|
||||
|
|
@ -142,6 +146,8 @@ class TaskResult:
|
|||
}
|
||||
if self.trace is not None:
|
||||
d["trace"] = self.trace.to_dict() if hasattr(self.trace, "to_dict") else self.trace
|
||||
if self.error_struct is not None:
|
||||
d["error_struct"] = self.error_struct
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
|
|
@ -162,6 +168,7 @@ class TaskResult:
|
|||
completed_at=completed_at or datetime.now(timezone.utc),
|
||||
metrics=data.get("metrics"),
|
||||
trace=data.get("trace"),
|
||||
error_struct=data.get("error_struct"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from typing import Any
|
|||
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.orchestrator.rollback import RollbackExecutor
|
||||
|
||||
from .expert import Expert
|
||||
from .plan import (
|
||||
|
|
@ -72,12 +73,17 @@ class TeamOrchestrator:
|
|||
MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution
|
||||
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
|
||||
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
|
||||
# G9/U4: RollbackExecutor default timeout for validation_command / rollback_command.
|
||||
# Override via constructor `rollback_timeout` from `rollback.default_timeout` config.
|
||||
DEFAULT_ROLLBACK_TIMEOUT = 30.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
team: ExpertTeam,
|
||||
max_concurrent_phases: int | None = None,
|
||||
checkpoint: Any = None,
|
||||
workspace_root: str | None = None,
|
||||
rollback_timeout: float | None = None,
|
||||
) -> None:
|
||||
self._team = team
|
||||
# Track temporary agent names created for context isolation (KTD3)
|
||||
|
|
@ -93,6 +99,10 @@ class TeamOrchestrator:
|
|||
self._phase_semaphore = asyncio.Semaphore(limit)
|
||||
# U7: Pipeline checkpoint for crash recovery
|
||||
self._checkpoint = checkpoint
|
||||
# G9/U4: workspace_root drives RollbackExecutor cwd; rollback_timeout drives its timeout.
|
||||
# Both default to no-op-friendly values so existing call sites behave identically.
|
||||
self._workspace_root = workspace_root
|
||||
self._rollback_timeout = rollback_timeout or self.DEFAULT_ROLLBACK_TIMEOUT
|
||||
|
||||
async def execute(self, task: str) -> dict[str, Any]:
|
||||
"""Execute a task in pipeline mode.
|
||||
|
|
@ -262,8 +272,23 @@ class TeamOrchestrator:
|
|||
else:
|
||||
phase_results[ph.id] = result
|
||||
|
||||
# G9/U4: opt-in rollback (KTD6) + checkpoint ordering (R21).
|
||||
# When phase configures both validation_command and rollback_command:
|
||||
# 1. run validation_command — if it passes, treat phase as recoverable, save checkpoint
|
||||
# 2. if validation fails, run rollback_command
|
||||
# 3. if rollback passes (exit 0), save checkpoint
|
||||
# 4. if rollback fails, skip checkpoint (R21 — avoid persisting broken state)
|
||||
# When neither command is set, behavior is unchanged (existing save).
|
||||
should_save_checkpoint = True
|
||||
if (
|
||||
ph.validation_command
|
||||
and ph.rollback_command
|
||||
and isinstance(result, (Exception, asyncio.CancelledError))
|
||||
):
|
||||
should_save_checkpoint = await self._run_phase_rollback(plan, ph)
|
||||
|
||||
# U7: Save checkpoint after phase finalizes (success or failure)
|
||||
if self._checkpoint is not None:
|
||||
if should_save_checkpoint and self._checkpoint is not None:
|
||||
try:
|
||||
await self._checkpoint.save(plan.id, ph, plan.status.value)
|
||||
except Exception as e:
|
||||
|
|
@ -393,9 +418,7 @@ class TeamOrchestrator:
|
|||
# PENDING phases remain PENDING — will be executed by _run_pipeline
|
||||
|
||||
# P2 #8: Restore debate count so MAX_DEBATES limit holds after resume
|
||||
self._debate_count = sum(
|
||||
1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE
|
||||
)
|
||||
self._debate_count = sum(1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE)
|
||||
|
||||
logger.info(
|
||||
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
|
||||
|
|
@ -688,9 +711,9 @@ class TeamOrchestrator:
|
|||
and prev_phase.result
|
||||
):
|
||||
# U4: Resolve offloaded content from workspace
|
||||
collaboration_outputs[contract.from_expert] = (
|
||||
await self._read_dependency_output(prev_phase)
|
||||
)
|
||||
collaboration_outputs[
|
||||
contract.from_expert
|
||||
] = await self._read_dependency_output(prev_phase)
|
||||
break
|
||||
|
||||
# Emit expert_step event
|
||||
|
|
@ -1809,6 +1832,75 @@ class TeamOrchestrator:
|
|||
# Recursively mark their dependents
|
||||
await self._mark_dependents_failed(ph.id, plan, phase_results)
|
||||
|
||||
async def _run_phase_rollback(self, plan: TeamPlan, ph: PlanPhase) -> bool:
|
||||
"""G9/U4: run validation_command + rollback_command for a failed phase.
|
||||
|
||||
Returns True if checkpoint save should proceed (R21 ordering).
|
||||
- Validation passes → save checkpoint (phase state recoverable)
|
||||
- Validation fails, rollback passes → save checkpoint (rolled back state)
|
||||
- Validation fails, rollback fails → skip checkpoint (broken state)
|
||||
- Subprocess spawn failure or timeout → skip checkpoint
|
||||
"""
|
||||
executor = RollbackExecutor(
|
||||
working_dir=self._workspace_root,
|
||||
timeout=self._rollback_timeout,
|
||||
)
|
||||
await self._broadcast_event(
|
||||
"phase_rollback_started",
|
||||
{
|
||||
"plan_id": plan.id,
|
||||
"phase_id": ph.id,
|
||||
"phase_name": ph.name,
|
||||
"validation_command": ph.validation_command,
|
||||
"rollback_command": ph.rollback_command,
|
||||
},
|
||||
)
|
||||
# ponytail: validate first; if validation passes, rollback is skipped (no need).
|
||||
validation = await executor.validate(ph.validation_command or "")
|
||||
if validation.passed:
|
||||
await self._broadcast_event(
|
||||
"phase_rollback_completed",
|
||||
{
|
||||
"plan_id": plan.id,
|
||||
"phase_id": ph.id,
|
||||
"phase_name": ph.name,
|
||||
"rollback_executed": False,
|
||||
"validation_passed": True,
|
||||
},
|
||||
)
|
||||
return True
|
||||
|
||||
rollback = await executor.execute(ph.rollback_command or "")
|
||||
if rollback.passed:
|
||||
await self._broadcast_event(
|
||||
"phase_rollback_completed",
|
||||
{
|
||||
"plan_id": plan.id,
|
||||
"phase_id": ph.id,
|
||||
"phase_name": ph.name,
|
||||
"rollback_executed": True,
|
||||
"validation_passed": False,
|
||||
"rollback_stdout": rollback.stdout,
|
||||
},
|
||||
)
|
||||
return True
|
||||
|
||||
logger.error(
|
||||
f"Rollback failed for phase {ph.id} ({ph.name}): exit={rollback.exit_code} stderr={rollback.stderr}"
|
||||
)
|
||||
await self._broadcast_event(
|
||||
"phase_rollback_failed",
|
||||
{
|
||||
"plan_id": plan.id,
|
||||
"phase_id": ph.id,
|
||||
"phase_name": ph.name,
|
||||
"validation_passed": False,
|
||||
"rollback_exit_code": rollback.exit_code,
|
||||
"rollback_stderr": rollback.stderr,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
async def _synthesize_results(
|
||||
self, lead: Expert, task: str, completed_phases: list[PlanPhase]
|
||||
) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -182,6 +182,11 @@ class PlanPhase:
|
|||
collaboration_contracts: list[CollaborationContract] = field(default_factory=list)
|
||||
rework_count: int = 0
|
||||
review_feedback: str | None = None
|
||||
# G9/U4: opt-in rollback fields. When unset, no rollback executes (KTD6).
|
||||
# validation_command runs first; if it fails, rollback_command runs.
|
||||
# canonical rollback pattern: `git checkout <specific_files>`.
|
||||
validation_command: str | None = None
|
||||
rollback_command: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""序列化为字典"""
|
||||
|
|
@ -192,7 +197,7 @@ class PlanPhase:
|
|||
result_str = self.result.get("content", str(self.result))
|
||||
else:
|
||||
result_str = str(self.result)
|
||||
return {
|
||||
out: dict[str, Any] = {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"assigned_expert": self.assigned_expert,
|
||||
|
|
@ -206,6 +211,12 @@ class PlanPhase:
|
|||
"rework_count": self.rework_count,
|
||||
"review_feedback": self.review_feedback,
|
||||
}
|
||||
# G9/U4: only include new keys when set, to preserve pre-change dict shape (KTD6).
|
||||
if self.validation_command is not None:
|
||||
out["validation_command"] = self.validation_command
|
||||
if self.rollback_command is not None:
|
||||
out["rollback_command"] = self.rollback_command
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> PlanPhase:
|
||||
|
|
@ -230,6 +241,8 @@ class PlanPhase:
|
|||
collaboration_contracts=contracts,
|
||||
rework_count=data.get("rework_count", 0),
|
||||
review_feedback=data.get("review_feedback"),
|
||||
validation_command=data.get("validation_command"),
|
||||
rollback_command=data.get("rollback_command"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -203,6 +203,9 @@ class LLMConfig:
|
|||
model_aliases: dict[str, str] = field(default_factory=dict)
|
||||
fallbacks: dict[str, list[str]] = field(default_factory=dict)
|
||||
cache: CacheConfig | None = None
|
||||
# G4/U1: Auxiliary model alias for cost-sensitive tasks (e.g. summarization).
|
||||
# Resolved via existing model_aliases mechanism. None = use main model only.
|
||||
auxiliary_model: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "LLMConfig":
|
||||
|
|
@ -254,4 +257,5 @@ class LLMConfig:
|
|||
model_aliases=data.get("model_aliases", {}),
|
||||
fallbacks=data.get("fallbacks", {}),
|
||||
cache=cache,
|
||||
auxiliary_model=data.get("auxiliary_model"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
"""RollbackExecutor — 阶段失败后执行回滚命令 (G9/U4)
|
||||
|
||||
复用 VerificationLoop 的 asyncio.create_subprocess_shell 模式 (KTD7):
|
||||
绕过 ShellTool,避免 confirm_callback 对 `git checkout` 的拦截。
|
||||
|
||||
设计依据:
|
||||
- KTD6: 回滚是 opt-in 行为,未配置 rollback_command 时不会执行
|
||||
- KTD7: 不走 ShellTool,避免 _is_dangerous 触发 confirm_callback
|
||||
- R21: checkpoint.save 仅在回滚校验通过后调用
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RollbackResult:
|
||||
"""单次 subprocess 执行结果"""
|
||||
|
||||
passed: bool
|
||||
exit_code: int
|
||||
stdout: str
|
||||
stderr: str
|
||||
command: str
|
||||
|
||||
|
||||
class RollbackExecutor:
|
||||
"""执行 validation_command / rollback_command 的子进程封装
|
||||
|
||||
与 VerificationLoop 同构,但语义不同:
|
||||
- validate(): 返回 passed=False 表示校验失败,需要触发 rollback
|
||||
- execute(): 返回 passed=False 表示回滚本身失败,需跳过 checkpoint.save (R21)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
working_dir: str | None = None,
|
||||
timeout: float = 30.0,
|
||||
) -> None:
|
||||
self._working_dir = working_dir
|
||||
self._timeout = timeout
|
||||
|
||||
async def _run(self, command: str) -> RollbackResult:
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=self._working_dir,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001 - subprocess spawn failure surface
|
||||
return RollbackResult(
|
||||
passed=False,
|
||||
exit_code=-1,
|
||||
stdout="",
|
||||
stderr=f"Failed to spawn command: {e}",
|
||||
command=command,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self._timeout)
|
||||
except asyncio.TimeoutError:
|
||||
try:
|
||||
proc.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
await proc.wait()
|
||||
return RollbackResult(
|
||||
passed=False,
|
||||
exit_code=-1,
|
||||
stdout="",
|
||||
stderr=f"Command timed out after {self._timeout}s: {command}",
|
||||
command=command,
|
||||
)
|
||||
|
||||
out_str = stdout.decode("utf-8", errors="replace") if stdout else ""
|
||||
err_str = stderr.decode("utf-8", errors="replace") if stderr else ""
|
||||
return RollbackResult(
|
||||
passed=proc.returncode == 0,
|
||||
exit_code=proc.returncode if proc.returncode is not None else -1,
|
||||
stdout=out_str,
|
||||
stderr=err_str,
|
||||
command=command,
|
||||
)
|
||||
|
||||
async def validate(self, command: str) -> RollbackResult:
|
||||
"""运行 validation_command,passed=False 表示需要触发 rollback"""
|
||||
return await self._run(command)
|
||||
|
||||
async def execute(self, command: str) -> RollbackResult:
|
||||
"""运行 rollback_command,passed=False 表示回滚本身失败 (R21)"""
|
||||
return await self._run(command)
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
"""G7/U3 — Three-tier fallback chain (main → Recovery → Emergency).
|
||||
|
||||
Wired at chat.py REST send_message endpoint. Composes U2's EmergencyRules
|
||||
with existing ReflexionEngine for the Recovery layer.
|
||||
|
||||
Scope (KTD5): Only the chat REST path is wrapped. CLI / ReWOO / Reflexion
|
||||
internal ReAct calls are NOT wrapped (would create recursive loop).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.exceptions import (
|
||||
LLMProviderError,
|
||||
LoopDetectedError,
|
||||
TaskCancelledError,
|
||||
TaskTimeoutError,
|
||||
)
|
||||
from agentkit.core.fallback import EmergencyError, EmergencyRules
|
||||
from agentkit.core.react import ReActEngine, ReActResult
|
||||
from agentkit.core.reflexion import ReflexionEngine
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ReActResult.status values that indicate soft failure → trigger Recovery.
|
||||
# "success" is the only clean-pass; everything else is fallback-worthy.
|
||||
_SOFT_FAILURE_STATUSES = frozenset({"empty_fallback", "verify_failed", "timeout"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatExecutionResult:
|
||||
"""Wrapper produced by execute_with_fallback_chain.
|
||||
|
||||
Carries a ReActResult-like ``output`` field plus an optional
|
||||
``error_struct`` (set only when Emergency tier fires). The chat
|
||||
handler reads ``.output`` for the assistant reply and ``.error_struct``
|
||||
for the optional structured error payload.
|
||||
"""
|
||||
|
||||
output: str
|
||||
status: str # "success" | "recovered" | "emergency"
|
||||
error_struct: dict[str, Any] | None = None
|
||||
trajectory: list[Any] = field(default_factory=list)
|
||||
total_steps: int = 0
|
||||
total_tokens: int = 0
|
||||
fallback_strategy: str | None = None
|
||||
|
||||
|
||||
def _react_to_chat_result(react: ReActResult) -> ChatExecutionResult:
|
||||
return ChatExecutionResult(
|
||||
output=react.output,
|
||||
status="success",
|
||||
trajectory=react.trajectory,
|
||||
total_steps=react.total_steps,
|
||||
total_tokens=react.total_tokens,
|
||||
fallback_strategy=react.fallback_strategy,
|
||||
)
|
||||
|
||||
|
||||
def _reflexion_to_chat_result(reflexion_result: Any) -> ChatExecutionResult:
|
||||
"""Best-effort conversion from ReflexionResult to ChatExecutionResult."""
|
||||
output = getattr(reflexion_result, "output", None) or getattr(
|
||||
reflexion_result, "final_answer", ""
|
||||
)
|
||||
return ChatExecutionResult(
|
||||
output=output or "",
|
||||
status="recovered",
|
||||
trajectory=getattr(reflexion_result, "trajectory", []) or [],
|
||||
total_steps=getattr(reflexion_result, "total_steps", 0),
|
||||
total_tokens=getattr(reflexion_result, "total_tokens", 0),
|
||||
fallback_strategy="reflexion_recovery",
|
||||
)
|
||||
|
||||
|
||||
def _to_emergency(exc: Exception, config: dict | None) -> ChatExecutionResult:
|
||||
emergency: EmergencyError = EmergencyRules.classify(exc, config)
|
||||
return ChatExecutionResult(
|
||||
output=emergency.to_error_message(),
|
||||
status="emergency",
|
||||
error_struct=emergency.to_dict(),
|
||||
fallback_strategy="emergency",
|
||||
)
|
||||
|
||||
|
||||
async def execute_with_fallback_chain(
|
||||
*,
|
||||
react_engine: ReActEngine,
|
||||
llm_gateway: LLMGateway,
|
||||
messages: list[dict[str, str]],
|
||||
tools: list[Any] | None,
|
||||
model: str,
|
||||
agent_name: str,
|
||||
system_prompt: str | None,
|
||||
fallback_chain_config: dict | None = None,
|
||||
) -> ChatExecutionResult:
|
||||
"""Three-tier fallback chain: Main → Recovery (ReflexionEngine) → Emergency.
|
||||
|
||||
KTD5: only this entry point wraps the chain. ReflexionEngine's internal
|
||||
ReAct call bypasses the chain (no recursive loop possible).
|
||||
|
||||
Returns ChatExecutionResult with status:
|
||||
- "success": main agent succeeded
|
||||
- "recovered": main failed, ReflexionEngine recovery succeeded
|
||||
- "emergency": main failed, recovery failed/exhausted, Emergency layer fired
|
||||
"""
|
||||
config = fallback_chain_config or {}
|
||||
recovery_cfg = config.get("recovery", {}) if isinstance(config, dict) else {}
|
||||
emergency_cfg = config.get("emergency", {}) if isinstance(config, dict) else {}
|
||||
recovery_enabled = recovery_cfg.get("enabled", True) if isinstance(recovery_cfg, dict) else True
|
||||
emergency_enabled = (
|
||||
emergency_cfg.get("enabled", True) if isinstance(emergency_cfg, dict) else True
|
||||
)
|
||||
max_reflections = recovery_cfg.get("max_retries", 1) if isinstance(recovery_cfg, dict) else 1
|
||||
|
||||
# ── Tier 1: Main ──────────────────────────────────────────────
|
||||
main_exc: Exception | None = None
|
||||
try:
|
||||
result = await react_engine.execute(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
if result.status == "success":
|
||||
return _react_to_chat_result(result)
|
||||
# Soft failure (empty_fallback / verify_failed / timeout) → trigger Recovery
|
||||
if result.status in _SOFT_FAILURE_STATUSES:
|
||||
main_exc = AgentSoftFailureError(
|
||||
f"main agent status={result.status}: {result.output[:200]}"
|
||||
)
|
||||
else:
|
||||
# Unknown status — treat as success-like (don't trigger recovery)
|
||||
return _react_to_chat_result(result)
|
||||
except TaskCancelledError:
|
||||
# KTD3: TaskCancelledError propagates as-is, NOT routed to Emergency.
|
||||
raise
|
||||
except (TaskTimeoutError, LoopDetectedError, LLMProviderError) as exc:
|
||||
main_exc = exc
|
||||
except Exception as exc: # noqa: BLE001 - last-resort catch for Emergency routing
|
||||
main_exc = exc
|
||||
|
||||
# ── Tier 2: Recovery (ReflexionEngine) ────────────────────────
|
||||
if recovery_enabled and main_exc is not None:
|
||||
try:
|
||||
reflexion = ReflexionEngine(
|
||||
llm_gateway=llm_gateway,
|
||||
max_reflections=max_reflections,
|
||||
)
|
||||
recovery_result = await reflexion.execute(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
agent_name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
# Recovery succeeds if Reflexion reports success or produces output.
|
||||
recovery_status = getattr(recovery_result, "status", "")
|
||||
if recovery_status == "success" or getattr(recovery_result, "output", None):
|
||||
return _reflexion_to_chat_result(recovery_result)
|
||||
logger.warning(
|
||||
f"Recovery layer did not succeed (status={recovery_status}), "
|
||||
f"falling through to Emergency"
|
||||
)
|
||||
except TaskCancelledError:
|
||||
raise
|
||||
except Exception as recovery_exc: # noqa: BLE001
|
||||
logger.warning(f"Recovery layer raised: {recovery_exc}; falling through to Emergency")
|
||||
|
||||
# ── Tier 3: Emergency ─────────────────────────────────────────
|
||||
if not emergency_enabled:
|
||||
# Re-raise original exception if Emergency disabled.
|
||||
if main_exc is not None:
|
||||
raise main_exc
|
||||
# No exception but no success either — synthesise an emergency-style result.
|
||||
return ChatExecutionResult(
|
||||
output="Agent 未返回有效结果且 Emergency 层已禁用。",
|
||||
status="emergency",
|
||||
fallback_strategy="emergency_disabled",
|
||||
)
|
||||
|
||||
# main_exc may be None if main returned soft-failure status without raising.
|
||||
# Synthesize a generic exception for Emergency classification.
|
||||
exc_for_emergency = main_exc or AgentSoftFailureError("soft failure without exception")
|
||||
return _to_emergency(exc_for_emergency, config)
|
||||
|
||||
|
||||
class AgentSoftFailureError(Exception):
|
||||
"""Internal marker — main agent returned a soft-failure status without raising.
|
||||
|
||||
Used to feed the Emergency classifier when main status was e.g.
|
||||
``empty_fallback`` (no exception raised, but result not usable).
|
||||
Classified as ``internal_error`` by EmergencyRules (generic fallback).
|
||||
"""
|
||||
|
|
@ -119,6 +119,8 @@ class ServerConfig:
|
|||
prompt_cache: dict[str, Any] | None = None,
|
||||
streaming: dict[str, Any] | None = None,
|
||||
verification: dict[str, Any] | None = None,
|
||||
rollback: dict[str, Any] | None = None,
|
||||
fallback_chain: dict[str, Any] | None = None,
|
||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||
):
|
||||
self.host = host
|
||||
|
|
@ -153,6 +155,12 @@ class ServerConfig:
|
|||
# U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1)
|
||||
# verification_enabled=False 时此配置无效
|
||||
self.verification = verification or {}
|
||||
# G9/U4: rollback.default_timeout 控制 RollbackExecutor subprocess 超时
|
||||
# PlanPhase.rollback_command 未设置时此配置无效 (KTD6 opt-in)
|
||||
self.rollback = rollback or {}
|
||||
# G7/U3: fallback_chain.{recovery,emergency}.{enabled,max_retries}
|
||||
# controls three-tier chain at chat.py REST send_message (KTD5).
|
||||
self.fallback_chain = fallback_chain or {}
|
||||
self.on_change = on_change
|
||||
|
||||
# Config watching state
|
||||
|
|
@ -240,6 +248,10 @@ class ServerConfig:
|
|||
prompt_cache_data = data.get("prompt_cache", {})
|
||||
streaming_data = data.get("streaming", {})
|
||||
verification_data = data.get("verification", {})
|
||||
# G9/U4: rollback 配置 (从 YAML 读取,opt-in)
|
||||
rollback_data = data.get("rollback", {})
|
||||
# G7/U3: fallback_chain 配置 (从 YAML 读取)
|
||||
fallback_chain_data = data.get("fallback_chain", {})
|
||||
|
||||
return cls(
|
||||
host=server.get("host", "0.0.0.0"),
|
||||
|
|
@ -271,6 +283,8 @@ class ServerConfig:
|
|||
prompt_cache=prompt_cache_data,
|
||||
streaming=streaming_data,
|
||||
verification=verification_data,
|
||||
rollback=rollback_data,
|
||||
fallback_chain=fallback_chain_data,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -320,6 +334,9 @@ class ServerConfig:
|
|||
model_aliases=model_aliases,
|
||||
fallbacks=data.get("fallbacks", {}),
|
||||
cache=cache_config,
|
||||
# G4/U1: auxiliary model alias for cost-sensitive summarization.
|
||||
# Resolved via model_aliases; None = no auxiliary routing.
|
||||
auxiliary_model=data.get("auxiliary_model"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from pydantic import BaseModel
|
|||
from agentkit.chat.skill_routing import ExecutionMode
|
||||
from agentkit.core.protocol import CancellationToken
|
||||
from agentkit.core.react import ReActEngine
|
||||
from agentkit.server._fallback_chain import execute_with_fallback_chain
|
||||
from agentkit.session.manager import SessionManager
|
||||
from agentkit.session.models import MessageRole, SessionStatus
|
||||
|
||||
|
|
@ -610,7 +611,15 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
|||
system_prompt = getattr(agent, "_system_prompt", None) or (
|
||||
agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
|
||||
)
|
||||
result = await react_engine.execute(
|
||||
# G7/U3: Three-tier fallback chain (main → Recovery → Emergency).
|
||||
# Wired only here (KTD5); CLI / ReWOO / Reflexion internal ReAct bypass.
|
||||
server_config = getattr(req.app.state, "server_config", None)
|
||||
fallback_chain_cfg = (
|
||||
getattr(server_config, "fallback_chain", None) if server_config else None
|
||||
)
|
||||
chat_result = await execute_with_fallback_chain(
|
||||
react_engine=react_engine,
|
||||
llm_gateway=req.app.state.llm_gateway,
|
||||
messages=chat_messages,
|
||||
tools=tools,
|
||||
model=agent.get_model()
|
||||
|
|
@ -618,16 +627,26 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
|||
else getattr(agent, "_llm_model", "default"),
|
||||
agent_name=agent.name,
|
||||
system_prompt=system_prompt,
|
||||
fallback_chain_config=fallback_chain_cfg,
|
||||
)
|
||||
|
||||
# Append assistant reply
|
||||
assistant_msg = await sm.append_message(
|
||||
session_id=session_id,
|
||||
role=MessageRole.ASSISTANT,
|
||||
content=result.output if hasattr(result, "output") else str(result),
|
||||
content=chat_result.output,
|
||||
agent_name=agent.name,
|
||||
)
|
||||
return _message_to_response(assistant_msg)
|
||||
response = _message_to_response(assistant_msg)
|
||||
# Attach structured error payload when Emergency tier fired.
|
||||
if chat_result.error_struct is not None:
|
||||
response_dict = (
|
||||
response.model_dump() if hasattr(response, "model_dump") else dict(response)
|
||||
)
|
||||
response_dict["error_struct"] = chat_result.error_struct
|
||||
response_dict["fallback_status"] = chat_result.status
|
||||
return response_dict
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Chat execution error for session {session_id}: {e}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,400 @@
|
|||
"""G4/U1 — Auxiliary LLM routing in ContextCompressor.
|
||||
|
||||
Verifies:
|
||||
- auxiliary_model routes _summarize through the cheaper model first
|
||||
- empty content (Finding 4 anti-pattern) triggers fallback to main model
|
||||
- auxiliary exception triggers fallback to main model
|
||||
- both auxiliary and main failing falls through to _simple_summary
|
||||
- auxiliary_model=None preserves existing single-model behavior (characterization)
|
||||
- config wiring (LLMConfig.from_dict, ServerConfig._build_llm_config)
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.core.compressor import ContextCompressor
|
||||
from agentkit.llm.config import LLMConfig
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────
|
||||
|
||||
|
||||
def make_gateway_with_response(content: str, model: str = "test") -> MagicMock:
|
||||
"""Mock LLMGateway returning a fixed response."""
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
gateway.chat = AsyncMock(
|
||||
return_value=LLMResponse(
|
||||
content=content,
|
||||
model=model,
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
|
||||
)
|
||||
)
|
||||
return gateway
|
||||
|
||||
|
||||
def make_gateway_side_effect(responses_by_model: dict[str, LLMResponse | Exception]) -> MagicMock:
|
||||
"""Mock LLMGateway returning different responses (or raising) keyed by model name.
|
||||
|
||||
Each call to gateway.chat(model=X) pops the next response for X from a queue,
|
||||
so repeated calls to the same model can return different values.
|
||||
"""
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
|
||||
gateway = MagicMock(spec=LLMGateway)
|
||||
queues = {m: list(rs) for m, rs in responses_by_model.items()}
|
||||
|
||||
async def chat_side_effect(*, messages, model, **kwargs):
|
||||
queue = queues.get(model)
|
||||
if queue is None:
|
||||
raise ValueError(f"unexpected model={model}")
|
||||
if not queue:
|
||||
raise ValueError(f"queue for model={model} exhausted")
|
||||
item = queue.pop(0)
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
return item
|
||||
|
||||
gateway.chat = AsyncMock(side_effect=chat_side_effect)
|
||||
return gateway
|
||||
|
||||
|
||||
def make_long_messages(count: int = 4, content_length: int = 2000) -> list[dict]:
|
||||
"""Generate long messages that exceed token budget (triggers compression)."""
|
||||
messages = [{"role": "system", "content": "You are a helpful assistant."}]
|
||||
for i in range(count):
|
||||
messages.append({"role": "user", "content": "x" * content_length + f" m{i}"})
|
||||
messages.append({"role": "assistant", "content": "y" * content_length + f" r{i}"})
|
||||
messages.append({"role": "user", "content": "recent question"})
|
||||
messages.append({"role": "assistant", "content": "recent answer"})
|
||||
return messages
|
||||
|
||||
|
||||
# ── Characterization: auxiliary_model=None preserves existing behavior ──
|
||||
|
||||
|
||||
class TestAuxiliaryNoneCharacterization:
|
||||
"""auxiliary_model=None (default) — single model call, existing behavior."""
|
||||
|
||||
async def test_no_auxiliary_calls_main_once(self):
|
||||
gateway = make_gateway_with_response("main summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
# auxiliary_model omitted → None
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
gateway.chat.assert_awaited_once()
|
||||
# The call used the main model
|
||||
assert gateway.chat.await_args.kwargs.get("model") == "main"
|
||||
# Summary surfaced in result
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert any("main summary" in m["content"] for m in summary_msgs)
|
||||
|
||||
async def test_main_failure_falls_to_simple_summary(self):
|
||||
gateway = MagicMock()
|
||||
gateway.chat = AsyncMock(side_effect=Exception("main LLM error"))
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# _simple_summary produces truncated messages with "..."
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert len(summary_msgs) == 1
|
||||
assert "..." in summary_msgs[0]["content"]
|
||||
|
||||
|
||||
# ── New behavior: auxiliary routing ──────────────────
|
||||
|
||||
|
||||
class TestAuxiliaryRouting:
|
||||
"""auxiliary_model set and differs from main → auxiliary tried first."""
|
||||
|
||||
async def test_auxiliary_success_returns_auxiliary_content(self):
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [
|
||||
LLMResponse(
|
||||
content="aux summary",
|
||||
model="fast",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
"main": [
|
||||
LLMResponse(
|
||||
content="MAIN SHOULD NOT BE USED",
|
||||
model="main",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# Auxiliary called; main NOT called
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) == 1
|
||||
assert len(main_calls) == 0
|
||||
# Result contains auxiliary summary
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert any("aux summary" in m["content"] for m in summary_msgs)
|
||||
|
||||
async def test_empty_content_triggers_main_fallback(self):
|
||||
"""Finding 4 anti-pattern: empty content is a failure, not a success."""
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [
|
||||
LLMResponse(
|
||||
content="",
|
||||
model="fast",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=0),
|
||||
)
|
||||
],
|
||||
"main": [
|
||||
LLMResponse(
|
||||
content="main summary",
|
||||
model="main",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# Auxiliary called once (returned empty)
|
||||
# Main called once (fallback)
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) == 1
|
||||
assert len(main_calls) == 1
|
||||
# Result contains main summary (not the empty auxiliary)
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert any("main summary" in m["content"] for m in summary_msgs)
|
||||
|
||||
async def test_whitespace_content_triggers_main_fallback(self):
|
||||
"""Whitespace-only content also counts as empty (Finding 4)."""
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [
|
||||
LLMResponse(
|
||||
content=" \n ",
|
||||
model="fast",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=0),
|
||||
)
|
||||
],
|
||||
"main": [
|
||||
LLMResponse(
|
||||
content="main summary",
|
||||
model="main",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
await compressor.compress(make_long_messages())
|
||||
|
||||
# Both auxiliary and main called
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) == 1
|
||||
assert len(main_calls) == 1
|
||||
|
||||
async def test_auxiliary_exception_triggers_main_fallback(self):
|
||||
from agentkit.core.exceptions import LLMProviderError
|
||||
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [LLMProviderError("aux", "provider down")],
|
||||
"main": [
|
||||
LLMResponse(
|
||||
content="main summary",
|
||||
model="main",
|
||||
usage=TokenUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
],
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# Both called; main succeeded
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) == 1
|
||||
assert len(main_calls) == 1
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert any("main summary" in m["content"] for m in summary_msgs)
|
||||
|
||||
async def test_both_fail_falls_to_simple_summary(self):
|
||||
"""Auxiliary raises, main raises → existing _simple_summary degradation."""
|
||||
# Note: aggressive compression path may invoke _summarize multiple times.
|
||||
# Queue provides enough responses to handle that without raising queue-exhausted.
|
||||
gateway = make_gateway_side_effect(
|
||||
{
|
||||
"fast": [Exception("aux boom")] * 5,
|
||||
"main": [Exception("main boom")] * 5,
|
||||
}
|
||||
)
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
result = await compressor.compress(make_long_messages())
|
||||
|
||||
# Both called at least once
|
||||
aux_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "fast"]
|
||||
main_calls = [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == "main"]
|
||||
assert len(aux_calls) >= 1
|
||||
assert len(main_calls) >= 1
|
||||
# _simple_summary output has "..." truncation markers
|
||||
summary_msgs = [
|
||||
m
|
||||
for m in result
|
||||
if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
|
||||
]
|
||||
assert len(summary_msgs) == 1
|
||||
assert "..." in summary_msgs[0]["content"]
|
||||
|
||||
async def test_auxiliary_equal_to_main_skipped(self):
|
||||
"""auxiliary_model == model → no auxiliary routing (single call to main)."""
|
||||
gateway = make_gateway_with_response("main summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="main", # same as main
|
||||
)
|
||||
await compressor.compress(make_long_messages())
|
||||
|
||||
# Only one call (to main); auxiliary block skipped
|
||||
assert gateway.chat.await_count == 1
|
||||
assert gateway.chat.await_args.kwargs.get("model") == "main"
|
||||
|
||||
async def test_audit_fields_preserved(self):
|
||||
"""Auxiliary call uses agent_name='compressor', task_type='summarization'."""
|
||||
gateway = make_gateway_with_response("aux summary")
|
||||
compressor = ContextCompressor(
|
||||
llm_gateway=gateway,
|
||||
max_tokens=100,
|
||||
keep_recent=2,
|
||||
model="main",
|
||||
auxiliary_model="fast",
|
||||
)
|
||||
# Override the mock to use a single-response gateway where auxiliary succeeds
|
||||
# (the make_gateway_with_response mock returns same response regardless of model)
|
||||
await compressor.compress(make_long_messages())
|
||||
|
||||
# Single call (auxiliary succeeded) — verify audit fields
|
||||
call_kwargs = gateway.chat.await_args.kwargs
|
||||
assert call_kwargs.get("agent_name") == "compressor"
|
||||
assert call_kwargs.get("task_type") == "summarization"
|
||||
|
||||
|
||||
# ── Config wiring ────────────────────────────────────
|
||||
|
||||
|
||||
class TestConfigWiring:
|
||||
"""LLMConfig + ServerConfig read auxiliary_model from dict."""
|
||||
|
||||
def test_llm_config_from_dict_reads_auxiliary_model(self):
|
||||
cfg = LLMConfig.from_dict(
|
||||
{
|
||||
"providers": {},
|
||||
"model_aliases": {"fast": "p/m"},
|
||||
"auxiliary_model": "fast",
|
||||
}
|
||||
)
|
||||
assert cfg.auxiliary_model == "fast"
|
||||
|
||||
def test_llm_config_from_dict_auxiliary_none_when_absent(self):
|
||||
cfg = LLMConfig.from_dict({"providers": {}})
|
||||
assert cfg.auxiliary_model is None
|
||||
|
||||
def test_llm_config_default_auxiliary_none(self):
|
||||
cfg = LLMConfig()
|
||||
assert cfg.auxiliary_model is None
|
||||
|
||||
def test_server_config_build_llm_config_reads_auxiliary_model(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
llm_data = {
|
||||
"providers": {
|
||||
"p": {
|
||||
"type": "openai",
|
||||
"api_key": "k",
|
||||
"base_url": "http://x",
|
||||
"models": {"m": {"alias": "fast"}},
|
||||
}
|
||||
},
|
||||
"auxiliary_model": "fast",
|
||||
}
|
||||
llm_config = ServerConfig._build_llm_config(llm_data)
|
||||
assert llm_config.auxiliary_model == "fast"
|
||||
# Also verify model_aliases still built correctly
|
||||
assert llm_config.model_aliases.get("fast") == "p/m"
|
||||
|
||||
def test_server_config_build_llm_config_auxiliary_none_when_absent(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
llm_config = ServerConfig._build_llm_config({"providers": {}})
|
||||
assert llm_config.auxiliary_model is None
|
||||
|
|
@ -0,0 +1,318 @@
|
|||
"""G7/U2 — Emergency layer rule template + TaskResult extension.
|
||||
|
||||
Verifies:
|
||||
- EmergencyRules.classify maps each exception type to correct error_code
|
||||
- TaskCancelledError raises ValueError (caller must propagate as-is)
|
||||
- EmergencyError.to_dict produces all 5 fields
|
||||
- EmergencyError.to_error_message formats suggestions as "建议:1) ... 2) ..."
|
||||
- Config overrides apply (suggestions, retryable, message)
|
||||
- TaskResult.error_struct field: default None preserves byte-for-byte
|
||||
to_dict() output (backward compat)
|
||||
- TaskResult round-trip serialization includes error_struct when set
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.exceptions import (
|
||||
LLMProviderError,
|
||||
LoopDetectedError,
|
||||
TaskCancelledError,
|
||||
TaskTimeoutError,
|
||||
)
|
||||
from agentkit.core.fallback import (
|
||||
EMPTY_LLM_RESPONSE,
|
||||
MAX_STEPS_REACHED,
|
||||
SHELL_NO_OUTPUT,
|
||||
EmergencyError,
|
||||
EmergencyRules,
|
||||
)
|
||||
from agentkit.core.protocol import TaskResult
|
||||
|
||||
|
||||
# ── Constants unchanged (contract preservation) ──────
|
||||
|
||||
|
||||
class TestExistingConstantsUnchanged:
|
||||
"""Existing 3 constants preserved byte-for-byte."""
|
||||
|
||||
def test_empty_llm_response_unchanged(self):
|
||||
assert "模型未返回有效内容" in EMPTY_LLM_RESPONSE
|
||||
assert "建议" in EMPTY_LLM_RESPONSE
|
||||
|
||||
def test_max_steps_reached_unchanged(self):
|
||||
assert "已达到最大推理步数" in MAX_STEPS_REACHED
|
||||
|
||||
def test_shell_no_output_unchanged(self):
|
||||
assert SHELL_NO_OUTPUT == "[命令执行成功,无输出内容]"
|
||||
|
||||
|
||||
# ── EmergencyRules.classify ──────────────────────────
|
||||
|
||||
|
||||
class TestEmergencyRulesClassify:
|
||||
"""classify() maps exception types to EmergencyError."""
|
||||
|
||||
def test_timeout(self):
|
||||
exc = TaskTimeoutError(task_id="t1", timeout_seconds=30)
|
||||
err = EmergencyRules.classify(exc)
|
||||
assert err.error_code == "timeout"
|
||||
assert err.retryable is True
|
||||
assert "稍后重试" in err.suggestions
|
||||
assert "简化任务范围" in err.suggestions
|
||||
assert err.original_error == str(exc)
|
||||
|
||||
def test_loop_detected(self):
|
||||
exc = LoopDetectedError(tool_name="shell", repetitions=3)
|
||||
err = EmergencyRules.classify(exc)
|
||||
assert err.error_code == "loop_detected"
|
||||
assert err.retryable is True
|
||||
assert "拆分任务" in err.suggestions
|
||||
assert "检查工具参数" in err.suggestions
|
||||
|
||||
def test_llm_provider_error(self):
|
||||
exc = LLMProviderError("openai", "rate limited")
|
||||
err = EmergencyRules.classify(exc)
|
||||
assert err.error_code == "llm_failure"
|
||||
assert err.retryable is True
|
||||
assert "稍后重试" in err.suggestions
|
||||
assert "切换模型" in err.suggestions
|
||||
|
||||
def test_llm_error_subclass_also_classified(self):
|
||||
"""LLMProviderError is a subclass of LLMError; ensure isinstance check works."""
|
||||
from agentkit.core.exceptions import LLMError
|
||||
|
||||
class CustomLLMError(LLMError):
|
||||
pass
|
||||
|
||||
err = EmergencyRules.classify(CustomLLMError("custom"))
|
||||
# CustomLLMError is NOT a LLMProviderError, falls through to generic
|
||||
assert err.error_code == "internal_error"
|
||||
|
||||
def test_generic_exception_internal_error(self):
|
||||
err = EmergencyRules.classify(Exception("unknown boom"))
|
||||
assert err.error_code == "internal_error"
|
||||
assert err.retryable is False
|
||||
assert "联系管理员" in err.suggestions
|
||||
assert err.original_error == "unknown boom"
|
||||
|
||||
def test_task_cancelled_raises(self):
|
||||
"""TaskCancelledError must propagate; classify() raises ValueError."""
|
||||
exc = TaskCancelledError(task_id="t1")
|
||||
with pytest.raises(ValueError, match="TaskCancelledError"):
|
||||
EmergencyRules.classify(exc)
|
||||
|
||||
def test_subclass_of_timeout_classified(self):
|
||||
"""Subclasses of TaskTimeoutError are classified as timeout."""
|
||||
|
||||
class CustomTimeout(TaskTimeoutError):
|
||||
def __init__(self):
|
||||
super().__init__(task_id="custom", timeout_seconds=10)
|
||||
|
||||
err = EmergencyRules.classify(CustomTimeout())
|
||||
assert err.error_code == "timeout"
|
||||
|
||||
|
||||
# ── EmergencyError serialization ─────────────────────
|
||||
|
||||
|
||||
class TestEmergencyErrorSerialization:
|
||||
"""to_dict / to_error_message on EmergencyError."""
|
||||
|
||||
def test_to_dict_produces_all_five_fields(self):
|
||||
err = EmergencyError(
|
||||
error_code="timeout",
|
||||
message="任务执行超时。",
|
||||
suggestions=["稍后重试", "简化任务范围"],
|
||||
retryable=True,
|
||||
original_error="Task t1 timed out after 30s",
|
||||
)
|
||||
d = err.to_dict()
|
||||
assert set(d.keys()) == {
|
||||
"error_code",
|
||||
"message",
|
||||
"suggestions",
|
||||
"retryable",
|
||||
"original_error",
|
||||
}
|
||||
assert d["error_code"] == "timeout"
|
||||
assert d["message"] == "任务执行超时。"
|
||||
assert d["suggestions"] == ["稍后重试", "简化任务范围"]
|
||||
assert d["retryable"] is True
|
||||
assert d["original_error"] == "Task t1 timed out after 30s"
|
||||
|
||||
def test_to_dict_suggestions_list_is_copy(self):
|
||||
"""to_dict returns a fresh list, not the internal reference."""
|
||||
suggestions = ["a", "b"]
|
||||
err = EmergencyError(
|
||||
error_code="x",
|
||||
message="m",
|
||||
suggestions=suggestions,
|
||||
retryable=False,
|
||||
original_error="e",
|
||||
)
|
||||
d = err.to_dict()
|
||||
assert d["suggestions"] is not suggestions
|
||||
d["suggestions"].append("c")
|
||||
assert err.suggestions == ["a", "b"]
|
||||
|
||||
def test_to_error_message_with_suggestions(self):
|
||||
err = EmergencyError(
|
||||
error_code="timeout",
|
||||
message="任务执行超时。",
|
||||
suggestions=["稍后重试", "简化任务范围"],
|
||||
retryable=True,
|
||||
original_error="err",
|
||||
)
|
||||
msg = err.to_error_message()
|
||||
assert msg.startswith("任务执行超时。建议:")
|
||||
assert "1) 稍后重试" in msg
|
||||
assert "2) 简化任务范围" in msg
|
||||
# Format mirrors EMPTY_LLM_RESPONSE style
|
||||
assert msg.endswith("。")
|
||||
|
||||
def test_to_error_message_no_suggestions(self):
|
||||
err = EmergencyError(
|
||||
error_code="x",
|
||||
message="just a message",
|
||||
suggestions=[],
|
||||
retryable=False,
|
||||
original_error="e",
|
||||
)
|
||||
assert err.to_error_message() == "just a message"
|
||||
|
||||
def test_to_error_message_single_suggestion(self):
|
||||
err = EmergencyError(
|
||||
error_code="x",
|
||||
message="msg",
|
||||
suggestions=["only one"],
|
||||
retryable=False,
|
||||
original_error="e",
|
||||
)
|
||||
msg = err.to_error_message()
|
||||
assert msg == "msg建议:1) only one。"
|
||||
|
||||
|
||||
# ── Config override ──────────────────────────────────
|
||||
|
||||
|
||||
class TestConfigOverride:
|
||||
"""classify() applies per-rule config overrides."""
|
||||
|
||||
def test_override_suggestions(self):
|
||||
exc = TaskTimeoutError(task_id="t", timeout_seconds=1)
|
||||
cfg = {"timeout": {"suggestions": ["自定义建议 A", "自定义建议 B"]}}
|
||||
err = EmergencyRules.classify(exc, config=cfg)
|
||||
assert err.suggestions == ["自定义建议 A", "自定义建议 B"]
|
||||
assert err.error_code == "timeout"
|
||||
|
||||
def test_override_retryable(self):
|
||||
exc = LLMProviderError("openai", "boom")
|
||||
cfg = {"llm_failure": {"retryable": False}}
|
||||
err = EmergencyRules.classify(exc, config=cfg)
|
||||
assert err.retryable is False
|
||||
|
||||
def test_override_message(self):
|
||||
exc = LoopDetectedError(tool_name="x", repetitions=2)
|
||||
cfg = {"loop_detected": {"message": "循环啦!"}}
|
||||
err = EmergencyRules.classify(exc, config=cfg)
|
||||
assert err.message == "循环啦!"
|
||||
|
||||
def test_override_internal_error_rule(self):
|
||||
cfg = {"internal_error": {"suggestions": ["联系客服"]}}
|
||||
err = EmergencyRules.classify(Exception("boom"), config=cfg)
|
||||
assert err.error_code == "internal_error"
|
||||
assert err.suggestions == ["联系客服"]
|
||||
|
||||
def test_config_none_uses_defaults(self):
|
||||
err = EmergencyRules.classify(TaskTimeoutError(task_id="t", timeout_seconds=1))
|
||||
assert err.error_code == "timeout"
|
||||
assert err.retryable is True
|
||||
|
||||
def test_config_empty_dict_uses_defaults(self):
|
||||
err = EmergencyRules.classify(
|
||||
TaskTimeoutError(task_id="t", timeout_seconds=1), config={}
|
||||
)
|
||||
assert err.error_code == "timeout"
|
||||
assert err.retryable is True
|
||||
|
||||
|
||||
# ── TaskResult.error_struct extension ────────────────
|
||||
|
||||
|
||||
def _make_task_result(
|
||||
error_struct: dict | None = None, error_message: str | None = None
|
||||
) -> TaskResult:
|
||||
now = datetime.now(timezone.utc)
|
||||
return TaskResult(
|
||||
task_id="t1",
|
||||
agent_name="a1",
|
||||
status="completed",
|
||||
output_data={"k": "v"},
|
||||
error_message=error_message,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
metrics={"m": 1},
|
||||
error_struct=error_struct,
|
||||
)
|
||||
|
||||
|
||||
class TestTaskResultErrorStruct:
|
||||
"""TaskResult.error_struct field — backward-compatible extension."""
|
||||
|
||||
def test_default_error_struct_is_none(self):
|
||||
tr = _make_task_result()
|
||||
assert tr.error_struct is None
|
||||
|
||||
def test_to_dict_without_error_struct_preserves_existing_shape(self):
|
||||
"""error_struct=None → to_dict() output has NO error_struct key (byte-for-byte)."""
|
||||
tr = _make_task_result()
|
||||
d = tr.to_dict()
|
||||
assert "error_struct" not in d
|
||||
# Existing keys unchanged
|
||||
assert set(d.keys()) == {
|
||||
"task_id",
|
||||
"agent_name",
|
||||
"status",
|
||||
"output_data",
|
||||
"error_message",
|
||||
"started_at",
|
||||
"completed_at",
|
||||
"metrics",
|
||||
}
|
||||
|
||||
def test_to_dict_with_error_struct_includes_key(self):
|
||||
struct = {
|
||||
"error_code": "timeout",
|
||||
"message": "超时",
|
||||
"suggestions": ["重试"],
|
||||
"retryable": True,
|
||||
"original_error": "boom",
|
||||
}
|
||||
tr = _make_task_result(error_struct=struct, error_message="超时建议:1) 重试。")
|
||||
d = tr.to_dict()
|
||||
assert d["error_struct"] == struct
|
||||
assert d["error_message"] == "超时建议:1) 重试。"
|
||||
|
||||
def test_from_dict_round_trip_with_error_struct(self):
|
||||
struct = {"error_code": "loop_detected", "message": "m", "suggestions": [], "retryable": True, "original_error": "e"}
|
||||
tr = _make_task_result(error_struct=struct)
|
||||
d = tr.to_dict()
|
||||
restored = TaskResult.from_dict(d)
|
||||
assert restored.error_struct == struct
|
||||
|
||||
def test_from_dict_without_error_struct_defaults_none(self):
|
||||
tr = _make_task_result()
|
||||
d = tr.to_dict()
|
||||
# Simulate legacy data without error_struct key
|
||||
restored = TaskResult.from_dict(d)
|
||||
assert restored.error_struct is None
|
||||
|
||||
def test_error_message_and_error_struct_coexist(self):
|
||||
"""Both fields can be set simultaneously (parallel contract per KTD2)."""
|
||||
struct = {"error_code": "timeout", "message": "超时", "suggestions": ["重试"], "retryable": True, "original_error": "err"}
|
||||
tr = _make_task_result(error_struct=struct, error_message="超时建议:1) 重试。")
|
||||
d = tr.to_dict()
|
||||
assert d["error_message"] == "超时建议:1) 重试。"
|
||||
assert d["error_struct"] == struct
|
||||
|
|
@ -0,0 +1,404 @@
|
|||
"""G7/U3 — Three-tier fallback chain wiring tests.
|
||||
|
||||
Verifies Main → Recovery (ReflexionEngine) → Emergency (EmergencyRules)
|
||||
at chat REST path. Mocks ReActEngine + ReflexionEngine + LLMGateway.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.exceptions import (
|
||||
LLMProviderError,
|
||||
LoopDetectedError,
|
||||
TaskCancelledError,
|
||||
TaskTimeoutError,
|
||||
)
|
||||
from agentkit.core.react import ReActResult
|
||||
from agentkit.server._fallback_chain import execute_with_fallback_chain
|
||||
|
||||
|
||||
def _make_react_result(status: str = "success", output: str = "ok") -> ReActResult:
|
||||
return ReActResult(
|
||||
output=output,
|
||||
trajectory=[],
|
||||
total_steps=1,
|
||||
total_tokens=10,
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _make_react_engine(result=None, raises=None):
|
||||
"""Build a fake ReActEngine with .execute returning result or raising."""
|
||||
engine = MagicMock()
|
||||
engine.reset = MagicMock()
|
||||
if raises is not None:
|
||||
engine.execute = AsyncMock(side_effect=raises)
|
||||
else:
|
||||
engine.execute = AsyncMock(return_value=result or _make_react_result())
|
||||
return engine
|
||||
|
||||
|
||||
def _make_llm_gateway():
|
||||
gw = MagicMock()
|
||||
gw.chat = AsyncMock(return_value=MagicMock(content="recovered"))
|
||||
return gw
|
||||
|
||||
|
||||
def _make_reflexion_result(status: str = "success", output: str = "recovered"):
|
||||
"""Synthesize a ReflexionResult-like object."""
|
||||
return MagicMock(
|
||||
status=status,
|
||||
output=output,
|
||||
trajectory=[],
|
||||
total_steps=1,
|
||||
total_tokens=5,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patched_reflexion(monkeypatch):
|
||||
"""Patch ReflexionEngine used inside the chain to a controllable mock."""
|
||||
from agentkit.server import _fallback_chain
|
||||
|
||||
instances: list[MagicMock] = []
|
||||
|
||||
class _MockReflexion:
|
||||
def __init__(self, llm_gateway, max_reflections=1, **kwargs):
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_reflections = max_reflections
|
||||
self.execute = AsyncMock(return_value=_make_reflexion_result())
|
||||
instances.append(self)
|
||||
|
||||
monkeypatch.setattr(_fallback_chain, "ReflexionEngine", _MockReflexion)
|
||||
return instances
|
||||
|
||||
|
||||
# ─── Tier 1: Main ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMainTier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_success_no_recovery_no_emergency(self):
|
||||
engine = _make_react_engine(result=_make_react_result(status="success", output="hello"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "success"
|
||||
assert result.output == "hello"
|
||||
assert result.error_struct is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_unknown_status_treated_as_success(self):
|
||||
"""Unknown status (not in soft_failure set) is treated as success-like."""
|
||||
engine = _make_react_engine(result=_make_react_result(status="partial", output="x"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "success"
|
||||
|
||||
|
||||
# ─── Tier 2: Recovery ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRecoveryTier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_timeout_triggers_recovery_success(self, patched_reflexion):
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "recovered"
|
||||
assert result.output == "recovered"
|
||||
# ReflexionEngine was instantiated and called
|
||||
assert len(patched_reflexion) == 1
|
||||
patched_reflexion[0].execute.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_loop_detected_triggers_recovery(self, patched_reflexion):
|
||||
engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "recovered"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_llm_provider_error_triggers_recovery(self, patched_reflexion):
|
||||
engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="503"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "recovered"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_main_soft_failure_status_triggers_recovery(self, patched_reflexion):
|
||||
"""Soft failure (empty_fallback) without exception still triggers Recovery."""
|
||||
engine = _make_react_engine(result=_make_react_result(status="empty_fallback", output=""))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
assert result.status == "recovered"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_disabled_skips_to_emergency(self):
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "timeout"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_failure_falls_through_to_emergency(self, patched_reflexion):
|
||||
"""Recovery raises → Emergency tier fires with original exception."""
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
# Make ReflexionEngine.execute raise
|
||||
patched_reflexion_instance = MagicMock()
|
||||
patched_reflexion_instance.execute = AsyncMock(
|
||||
side_effect=RuntimeError("reflexion crashed")
|
||||
)
|
||||
# Override the patched class to use our instance
|
||||
from agentkit.server import _fallback_chain
|
||||
|
||||
original_cls = _fallback_chain.ReflexionEngine
|
||||
|
||||
class _MockReflexionWithExc:
|
||||
def __init__(self, **kwargs):
|
||||
self.execute = patched_reflexion_instance.execute
|
||||
|
||||
_fallback_chain.ReflexionEngine = _MockReflexionWithExc
|
||||
try:
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
finally:
|
||||
_fallback_chain.ReflexionEngine = original_cls
|
||||
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "timeout"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recovery_unsuccessful_status_falls_through(self, patched_reflexion):
|
||||
"""Recovery returns non-success status → Emergency fires."""
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
# Make ReflexionEngine return unsuccessful result with empty output
|
||||
from agentkit.server import _fallback_chain
|
||||
|
||||
class _MockReflexionNoOutput:
|
||||
def __init__(self, **kwargs):
|
||||
self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None))
|
||||
|
||||
original_cls = _fallback_chain.ReflexionEngine
|
||||
_fallback_chain.ReflexionEngine = _MockReflexionNoOutput
|
||||
try:
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
finally:
|
||||
_fallback_chain.ReflexionEngine = original_cls
|
||||
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "timeout"
|
||||
|
||||
|
||||
# ─── Tier 3: Emergency ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEmergencyTier:
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_timeout_error_code(self, patched_reflexion):
|
||||
# Make recovery fail (empty result) so Emergency fires
|
||||
from agentkit.server import _fallback_chain
|
||||
|
||||
class _MockReflexionEmpty:
|
||||
def __init__(self, **kwargs):
|
||||
self.execute = AsyncMock(return_value=MagicMock(status="failed", output=None))
|
||||
|
||||
original_cls = _fallback_chain.ReflexionEngine
|
||||
_fallback_chain.ReflexionEngine = _MockReflexionEmpty
|
||||
try:
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
)
|
||||
finally:
|
||||
_fallback_chain.ReflexionEngine = original_cls
|
||||
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "timeout"
|
||||
assert result.error_struct["retryable"] is True
|
||||
assert "建议" in result.output
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_loop_detected_error_code(self):
|
||||
engine = _make_react_engine(raises=LoopDetectedError(tool_name="search", repetitions=5))
|
||||
# Recovery disabled so Emergency fires directly
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "loop_detected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_llm_failure_error_code(self):
|
||||
engine = _make_react_engine(raises=LLMProviderError(provider="openai", reason="500"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "llm_failure"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_internal_error_for_generic_exception(self):
|
||||
engine = _make_react_engine(raises=RuntimeError("unexpected"))
|
||||
result = await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
assert result.status == "emergency"
|
||||
assert result.error_struct["error_code"] == "internal_error"
|
||||
assert result.error_struct["retryable"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cancelled_propagates_not_routed_to_emergency(self):
|
||||
"""TaskCancelledError must propagate, not be classified by Emergency."""
|
||||
engine = _make_react_engine(raises=TaskCancelledError(task_id="t1"))
|
||||
with pytest.raises(TaskCancelledError):
|
||||
await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={"recovery": {"enabled": False}},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emergency_disabled_reraises_original(self):
|
||||
engine = _make_react_engine(raises=TaskTimeoutError(task_id="t1", timeout_seconds=10))
|
||||
with pytest.raises(TaskTimeoutError):
|
||||
await execute_with_fallback_chain(
|
||||
react_engine=engine,
|
||||
llm_gateway=_make_llm_gateway(),
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=[],
|
||||
model="default",
|
||||
agent_name="a",
|
||||
system_prompt=None,
|
||||
fallback_chain_config={
|
||||
"recovery": {"enabled": False},
|
||||
"emergency": {"enabled": False},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ─── Config wiring ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestServerConfigFallbackChain:
|
||||
def test_fallback_chain_section_read_from_dict(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig.from_dict(
|
||||
{
|
||||
"fallback_chain": {
|
||||
"enabled": True,
|
||||
"recovery": {"enabled": False, "max_retries": 3},
|
||||
"emergency": {"enabled": True},
|
||||
}
|
||||
}
|
||||
)
|
||||
assert config.fallback_chain["enabled"] is True
|
||||
assert config.fallback_chain["recovery"] == {"enabled": False, "max_retries": 3}
|
||||
assert config.fallback_chain["emergency"] == {"enabled": True}
|
||||
|
||||
def test_fallback_chain_defaults_empty_when_absent(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig.from_dict({})
|
||||
assert config.fallback_chain == {}
|
||||
|
|
@ -0,0 +1,367 @@
|
|||
"""G9/U4 — PlanPhase rollback fields + RollbackExecutor + TeamOrchestrator integration.
|
||||
|
||||
Characterization-first: captures pre-change behavior (rollback_command=None →
|
||||
no rollback, checkpoint saved) before asserting new rollback behavior.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.experts.orchestrator import TeamOrchestrator
|
||||
from agentkit.experts.plan import PlanPhase, TeamPlan
|
||||
from agentkit.orchestrator.rollback import RollbackExecutor
|
||||
|
||||
|
||||
# ─── PlanPhase field characterization ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlanPhaseFields:
|
||||
"""PlanPhase serialization for new fields."""
|
||||
|
||||
def test_characterization_no_new_keys_when_unset(self):
|
||||
"""Default PlanPhase.to_dict() must not include new keys (KTD6 contract)."""
|
||||
ph = PlanPhase(name="执行", assigned_expert="lead")
|
||||
out = ph.to_dict()
|
||||
assert "validation_command" not in out
|
||||
assert "rollback_command" not in out
|
||||
# Pre-change keys remain
|
||||
assert out["name"] == "执行"
|
||||
assert out["assigned_expert"] == "lead"
|
||||
assert out["status"] == "pending"
|
||||
|
||||
def test_characterization_from_dict_empty_yields_none(self):
|
||||
"""from_dict({}) produces both new fields as None."""
|
||||
ph = PlanPhase.from_dict({"name": "x"})
|
||||
assert ph.validation_command is None
|
||||
assert ph.rollback_command is None
|
||||
|
||||
def test_serialization_includes_keys_when_set(self):
|
||||
ph = PlanPhase(
|
||||
name="frontend",
|
||||
validation_command="ruff check src/",
|
||||
rollback_command="git checkout src/app.vue",
|
||||
)
|
||||
out = ph.to_dict()
|
||||
assert out["validation_command"] == "ruff check src/"
|
||||
assert out["rollback_command"] == "git checkout src/app.vue"
|
||||
|
||||
def test_serialization_round_trip(self):
|
||||
ph = PlanPhase(
|
||||
name="backend",
|
||||
validation_command="pytest -x -q",
|
||||
rollback_command="git checkout src/api.py",
|
||||
)
|
||||
restored = PlanPhase.from_dict(ph.to_dict())
|
||||
assert restored.validation_command == "pytest -x -q"
|
||||
assert restored.rollback_command == "git checkout src/api.py"
|
||||
|
||||
def test_only_validation_set_still_emits_validation_key(self):
|
||||
"""Asymmetric case: only validation_command set — only that key appears."""
|
||||
ph = PlanPhase(name="x", validation_command="echo ok")
|
||||
out = ph.to_dict()
|
||||
assert "validation_command" in out
|
||||
assert "rollback_command" not in out
|
||||
|
||||
|
||||
# ─── RollbackExecutor subprocess execution ─────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_workspace(tmp_path):
|
||||
"""Fresh working directory for subprocess execution."""
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
class TestRollbackExecutor:
|
||||
"""RollbackExecutor happy/edge/failure paths."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_happy_path_zero_exit(self, tmp_workspace):
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.execute("true")
|
||||
assert r.passed is True
|
||||
assert r.exit_code == 0
|
||||
assert r.command == "true"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_failure_nonzero_exit(self, tmp_workspace):
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.execute("false")
|
||||
assert r.passed is False
|
||||
assert r.exit_code != 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_timeout(self, tmp_workspace):
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=0.1)
|
||||
r = await ex.execute("sleep 5")
|
||||
assert r.passed is False
|
||||
assert r.exit_code == -1
|
||||
assert "timed out" in r.stderr
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_captures_stdout(self, tmp_workspace):
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.execute("echo hello-rollback")
|
||||
assert r.passed is True
|
||||
assert "hello-rollback" in r.stdout
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_same_semantics_as_execute(self, tmp_workspace):
|
||||
"""validate() is just execute() with different intent marker."""
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.validate("true")
|
||||
assert r.passed is True
|
||||
r2 = await ex.validate("false")
|
||||
assert r2.passed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spawn_failure_returns_failed_result(self, tmp_workspace):
|
||||
"""Bad shell command should surface as failed result, not raise."""
|
||||
# Use a definitely-broken command — shell still spawns, returns non-zero.
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
r = await ex.execute("exit 7")
|
||||
assert r.passed is False
|
||||
assert r.exit_code == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cwd_used_for_relative_paths(self, tmp_workspace):
|
||||
"""Files created in working_dir are visible via cwd-relative commands."""
|
||||
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||
await ex.execute("echo content > testfile.txt")
|
||||
r = await ex.execute("test -f testfile.txt")
|
||||
assert r.passed is True
|
||||
|
||||
|
||||
# ─── Real git rollback integration ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGitRollbackIntegration:
|
||||
"""Real git repo fixture: writes file, rollback restores via git checkout."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_checkout_restores_modified_file(self, tmp_path):
|
||||
# Init repo and commit a baseline file
|
||||
import subprocess
|
||||
|
||||
repo = str(tmp_path)
|
||||
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "config", "user.email", "test@x"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "config", "user.name", "test"], cwd=repo, check=True)
|
||||
baseline = "original content\n"
|
||||
with open(os.path.join(repo, "foo.txt"), "w") as f:
|
||||
f.write(baseline)
|
||||
subprocess.run(["git", "add", "foo.txt"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "commit", "-q", "-m", "baseline"], cwd=repo, check=True)
|
||||
|
||||
# Mutate the file
|
||||
with open(os.path.join(repo, "foo.txt"), "w") as f:
|
||||
f.write("mutated content\n")
|
||||
|
||||
# Run git checkout as rollback_command
|
||||
ex = RollbackExecutor(working_dir=repo, timeout=10.0)
|
||||
r = await ex.execute("git checkout foo.txt")
|
||||
assert r.passed is True
|
||||
|
||||
with open(os.path.join(repo, "foo.txt")) as f:
|
||||
assert f.read() == baseline
|
||||
|
||||
|
||||
# ─── TeamOrchestrator integration ─────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_team_mock():
|
||||
"""Build a minimal ExpertTeam mock for TeamOrchestrator."""
|
||||
team = MagicMock()
|
||||
team.team_id = "test-team"
|
||||
team.status.value = "executing"
|
||||
team.lead_expert = None
|
||||
team.active_experts = []
|
||||
return team
|
||||
|
||||
|
||||
def _make_orchestrator(team=None, checkpoint=None, workspace_root=None):
|
||||
"""Build a TeamOrchestrator with mocked team and checkpoint."""
|
||||
team = team or _make_team_mock()
|
||||
orch = TeamOrchestrator(
|
||||
team=team,
|
||||
checkpoint=checkpoint,
|
||||
workspace_root=workspace_root,
|
||||
rollback_timeout=2.0,
|
||||
)
|
||||
orch._broadcast_event = AsyncMock()
|
||||
return orch
|
||||
|
||||
|
||||
class TestOrchestratorRollbackIntegration:
|
||||
"""TeamOrchestrator phase failure path integration with rollback."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_characterization_no_rollback_when_unset(self, tmp_path):
|
||||
"""Phase fails, rollback_command=None → checkpoint saved, no rollback events."""
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||
|
||||
ph = PlanPhase(name="p1", assigned_expert="lead") # no rollback fields
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
# Call _run_phase_rollback — but it should NOT be called from
|
||||
# orchestrator path when fields unset. Verify by simulating the
|
||||
# main-loop guard directly.
|
||||
from agentkit.experts.plan import PhaseStatus as PS
|
||||
|
||||
ph.status = PS.FAILED
|
||||
# Simulate the guard condition in orchestrator.py:280-288
|
||||
should_save = True
|
||||
if (
|
||||
ph.validation_command
|
||||
and ph.rollback_command
|
||||
and isinstance(Exception("x"), (Exception,))
|
||||
):
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
# Guard should not fire
|
||||
assert should_save is True
|
||||
# No rollback events broadcast
|
||||
assert orch._broadcast_event.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_passes_no_rollback_executed(self, tmp_path):
|
||||
"""validation_command returns 0 → rollback NOT executed, checkpoint saved."""
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||
|
||||
ph = PlanPhase(
|
||||
name="p1",
|
||||
assigned_expert="lead",
|
||||
validation_command="true", # always passes
|
||||
rollback_command="false", # would fail if executed
|
||||
)
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
assert should_save is True
|
||||
|
||||
# Events: started + completed (no rollback)
|
||||
events = [
|
||||
c.kwargs.get("event_type") or c.args[0] for c in orch._broadcast_event.call_args_list
|
||||
]
|
||||
assert "phase_rollback_started" in events
|
||||
assert "phase_rollback_completed" in events
|
||||
assert "phase_rollback_failed" not in events
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_fails_rollback_succeeds(self, tmp_path):
|
||||
"""Validation fails, rollback returns 0 → checkpoint saved, rollback_executed=True."""
|
||||
# Create file under git tracking, then mutate it
|
||||
import subprocess
|
||||
|
||||
repo = str(tmp_path)
|
||||
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "config", "user.email", "t@x"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "config", "user.name", "t"], cwd=repo, check=True)
|
||||
with open(os.path.join(repo, "f.txt"), "w") as f:
|
||||
f.write("base\n")
|
||||
subprocess.run(["git", "add", "f.txt"], cwd=repo, check=True)
|
||||
subprocess.run(["git", "commit", "-q", "-m", "base"], cwd=repo, check=True)
|
||||
with open(os.path.join(repo, "f.txt"), "w") as f:
|
||||
f.write("mutated\n")
|
||||
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=repo)
|
||||
|
||||
ph = PlanPhase(
|
||||
name="p1",
|
||||
assigned_expert="lead",
|
||||
validation_command="false", # validation fails
|
||||
rollback_command="git checkout f.txt", # rollback succeeds
|
||||
)
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
assert should_save is True
|
||||
|
||||
# File restored to baseline
|
||||
with open(os.path.join(repo, "f.txt")) as f:
|
||||
assert f.read() == "base\n"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_fails_rollback_fails_skips_checkpoint(self, tmp_path):
|
||||
"""Validation fails AND rollback fails → checkpoint NOT saved (R21), event emitted."""
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||
|
||||
ph = PlanPhase(
|
||||
name="p1",
|
||||
assigned_expert="lead",
|
||||
validation_command="false", # fails
|
||||
rollback_command="false", # also fails
|
||||
)
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
assert should_save is False # R21: skip checkpoint
|
||||
|
||||
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
|
||||
assert "phase_rollback_started" in events
|
||||
assert "phase_rollback_failed" in events
|
||||
# phase_rollback_completed should NOT be in events
|
||||
assert "phase_rollback_completed" not in events
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rollback_timeout_skips_checkpoint(self, tmp_path):
|
||||
"""Rollback command times out → checkpoint NOT saved, failed event emitted."""
|
||||
checkpoint = MagicMock()
|
||||
checkpoint.save = AsyncMock()
|
||||
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||
orch._rollback_timeout = 0.1 # short timeout
|
||||
|
||||
ph = PlanPhase(
|
||||
name="p1",
|
||||
assigned_expert="lead",
|
||||
validation_command="false",
|
||||
rollback_command="sleep 5",
|
||||
)
|
||||
plan = TeamPlan(task="t", lead_expert="lead")
|
||||
plan.phases = [ph]
|
||||
|
||||
should_save = await orch._run_phase_rollback(plan, ph)
|
||||
assert should_save is False
|
||||
|
||||
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
|
||||
assert "phase_rollback_failed" in events
|
||||
|
||||
|
||||
# ─── Config wiring ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestServerConfigRollback:
|
||||
"""ServerConfig rollback section wiring."""
|
||||
|
||||
def test_rollback_section_read_from_dict(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig.from_dict(
|
||||
{
|
||||
"rollback": {
|
||||
"default_timeout": 45.0,
|
||||
}
|
||||
}
|
||||
)
|
||||
assert config.rollback == {"default_timeout": 45.0}
|
||||
|
||||
def test_rollback_defaults_empty_when_absent(self):
|
||||
from agentkit.server.config import ServerConfig
|
||||
|
||||
config = ServerConfig.from_dict({})
|
||||
assert config.rollback == {}
|
||||
Loading…
Reference in New Issue