feat(U4): G9 PlanPhase rollback + RollbackExecutor

- PlanPhase 新增 validation_command / rollback_command 可选字段 (KTD6 opt-in)
- to_dict 仅在字段非 None 时输出新键,保持既有 dict shape (KTD6 契约)
- 新增 RollbackExecutor (orchestrator/rollback.py) 复用 VerificationLoop
  subprocess 模式,绕过 ShellTool 避免 confirm_callback 拦截 (KTD7)
- TeamOrchestrator._run_phase_rollback 实现 R21 顺序:
  validation → rollback → checkpoint.save (仅在前者通过时调用)
- ServerConfig.from_dict 读取 rollback.default_timeout
- 20 个测试覆盖 characterization / happy / timeout / git integration / 配置
This commit is contained in:
chiguyong 2026-06-29 22:55:08 +08:00
parent 5b2377469a
commit b1841ce21b
6 changed files with 591 additions and 8 deletions

View File

@ -33,6 +33,13 @@ llm:
# the main model on failure or empty content. Commented to preserve # the main model on failure or empty content. Commented to preserve
# default behavior — uncomment to enable. # default behavior — uncomment to enable.
# auxiliary_model: fast # auxiliary_model: fast
# G9/U4: Rollback configuration. Drives RollbackExecutor subprocess timeout
# for PlanPhase.validation_command / PlanPhase.rollback_command. Per-phase
# opt-in (KTD6) — when PlanPhase.rollback_command is unset, no rollback runs.
# Canonical rollback pattern: `git checkout <specific_files>` (file-scoped,
# not `git checkout .` which would wipe unrelated changes).
rollback:
default_timeout: 30.0
session: {backend: memory} session: {backend: memory}
bus: {backend: memory} bus: {backend: memory}
task_store: {backend: memory} task_store: {backend: memory}

View File

@ -30,6 +30,7 @@ from typing import Any
from agentkit.core.config_driven import ConfigDrivenAgent from agentkit.core.config_driven import ConfigDrivenAgent
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.llm.gateway import LLMGateway from agentkit.llm.gateway import LLMGateway
from agentkit.orchestrator.rollback import RollbackExecutor
from .expert import Expert from .expert import Expert
from .plan import ( from .plan import (
@ -72,12 +73,17 @@ class TeamOrchestrator:
MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰 DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"}) STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
# G9/U4: RollbackExecutor default timeout for validation_command / rollback_command.
# Override via constructor `rollback_timeout` from `rollback.default_timeout` config.
DEFAULT_ROLLBACK_TIMEOUT = 30.0
def __init__( def __init__(
self, self,
team: ExpertTeam, team: ExpertTeam,
max_concurrent_phases: int | None = None, max_concurrent_phases: int | None = None,
checkpoint: Any = None, checkpoint: Any = None,
workspace_root: str | None = None,
rollback_timeout: float | None = None,
) -> None: ) -> None:
self._team = team self._team = team
# Track temporary agent names created for context isolation (KTD3) # Track temporary agent names created for context isolation (KTD3)
@ -93,6 +99,10 @@ class TeamOrchestrator:
self._phase_semaphore = asyncio.Semaphore(limit) self._phase_semaphore = asyncio.Semaphore(limit)
# U7: Pipeline checkpoint for crash recovery # U7: Pipeline checkpoint for crash recovery
self._checkpoint = checkpoint self._checkpoint = checkpoint
# G9/U4: workspace_root drives RollbackExecutor cwd; rollback_timeout drives its timeout.
# Both default to no-op-friendly values so existing call sites behave identically.
self._workspace_root = workspace_root
self._rollback_timeout = rollback_timeout or self.DEFAULT_ROLLBACK_TIMEOUT
async def execute(self, task: str) -> dict[str, Any]: async def execute(self, task: str) -> dict[str, Any]:
"""Execute a task in pipeline mode. """Execute a task in pipeline mode.
@ -262,8 +272,23 @@ class TeamOrchestrator:
else: else:
phase_results[ph.id] = result phase_results[ph.id] = result
# G9/U4: opt-in rollback (KTD6) + checkpoint ordering (R21).
# When phase configures both validation_command and rollback_command:
# 1. run validation_command — if it passes, treat phase as recoverable, save checkpoint
# 2. if validation fails, run rollback_command
# 3. if rollback passes (exit 0), save checkpoint
# 4. if rollback fails, skip checkpoint (R21 — avoid persisting broken state)
# When neither command is set, behavior is unchanged (existing save).
should_save_checkpoint = True
if (
ph.validation_command
and ph.rollback_command
and isinstance(result, (Exception, asyncio.CancelledError))
):
should_save_checkpoint = await self._run_phase_rollback(plan, ph)
# U7: Save checkpoint after phase finalizes (success or failure) # U7: Save checkpoint after phase finalizes (success or failure)
if self._checkpoint is not None: if should_save_checkpoint and self._checkpoint is not None:
try: try:
await self._checkpoint.save(plan.id, ph, plan.status.value) await self._checkpoint.save(plan.id, ph, plan.status.value)
except Exception as e: except Exception as e:
@ -393,9 +418,7 @@ class TeamOrchestrator:
# PENDING phases remain PENDING — will be executed by _run_pipeline # PENDING phases remain PENDING — will be executed by _run_pipeline
# P2 #8: Restore debate count so MAX_DEBATES limit holds after resume # P2 #8: Restore debate count so MAX_DEBATES limit holds after resume
self._debate_count = sum( self._debate_count = sum(1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE)
1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE
)
logger.info( logger.info(
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, " f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
@ -688,9 +711,9 @@ class TeamOrchestrator:
and prev_phase.result and prev_phase.result
): ):
# U4: Resolve offloaded content from workspace # U4: Resolve offloaded content from workspace
collaboration_outputs[contract.from_expert] = ( collaboration_outputs[
await self._read_dependency_output(prev_phase) contract.from_expert
) ] = await self._read_dependency_output(prev_phase)
break break
# Emit expert_step event # Emit expert_step event
@ -1809,6 +1832,75 @@ class TeamOrchestrator:
# Recursively mark their dependents # Recursively mark their dependents
await self._mark_dependents_failed(ph.id, plan, phase_results) await self._mark_dependents_failed(ph.id, plan, phase_results)
async def _run_phase_rollback(self, plan: TeamPlan, ph: PlanPhase) -> bool:
"""G9/U4: run validation_command + rollback_command for a failed phase.
Returns True if checkpoint save should proceed (R21 ordering).
- Validation passes save checkpoint (phase state recoverable)
- Validation fails, rollback passes save checkpoint (rolled back state)
- Validation fails, rollback fails skip checkpoint (broken state)
- Subprocess spawn failure or timeout skip checkpoint
"""
executor = RollbackExecutor(
working_dir=self._workspace_root,
timeout=self._rollback_timeout,
)
await self._broadcast_event(
"phase_rollback_started",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"validation_command": ph.validation_command,
"rollback_command": ph.rollback_command,
},
)
# ponytail: validate first; if validation passes, rollback is skipped (no need).
validation = await executor.validate(ph.validation_command or "")
if validation.passed:
await self._broadcast_event(
"phase_rollback_completed",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"rollback_executed": False,
"validation_passed": True,
},
)
return True
rollback = await executor.execute(ph.rollback_command or "")
if rollback.passed:
await self._broadcast_event(
"phase_rollback_completed",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"rollback_executed": True,
"validation_passed": False,
"rollback_stdout": rollback.stdout,
},
)
return True
logger.error(
f"Rollback failed for phase {ph.id} ({ph.name}): exit={rollback.exit_code} stderr={rollback.stderr}"
)
await self._broadcast_event(
"phase_rollback_failed",
{
"plan_id": plan.id,
"phase_id": ph.id,
"phase_name": ph.name,
"validation_passed": False,
"rollback_exit_code": rollback.exit_code,
"rollback_stderr": rollback.stderr,
},
)
return False
async def _synthesize_results( async def _synthesize_results(
self, lead: Expert, task: str, completed_phases: list[PlanPhase] self, lead: Expert, task: str, completed_phases: list[PlanPhase]
) -> dict[str, Any]: ) -> dict[str, Any]:

View File

@ -182,6 +182,11 @@ class PlanPhase:
collaboration_contracts: list[CollaborationContract] = field(default_factory=list) collaboration_contracts: list[CollaborationContract] = field(default_factory=list)
rework_count: int = 0 rework_count: int = 0
review_feedback: str | None = None review_feedback: str | None = None
# G9/U4: opt-in rollback fields. When unset, no rollback executes (KTD6).
# validation_command runs first; if it fails, rollback_command runs.
# canonical rollback pattern: `git checkout <specific_files>`.
validation_command: str | None = None
rollback_command: str | None = None
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""序列化为字典""" """序列化为字典"""
@ -192,7 +197,7 @@ class PlanPhase:
result_str = self.result.get("content", str(self.result)) result_str = self.result.get("content", str(self.result))
else: else:
result_str = str(self.result) result_str = str(self.result)
return { out: dict[str, Any] = {
"id": self.id, "id": self.id,
"name": self.name, "name": self.name,
"assigned_expert": self.assigned_expert, "assigned_expert": self.assigned_expert,
@ -206,6 +211,12 @@ class PlanPhase:
"rework_count": self.rework_count, "rework_count": self.rework_count,
"review_feedback": self.review_feedback, "review_feedback": self.review_feedback,
} }
# G9/U4: only include new keys when set, to preserve pre-change dict shape (KTD6).
if self.validation_command is not None:
out["validation_command"] = self.validation_command
if self.rollback_command is not None:
out["rollback_command"] = self.rollback_command
return out
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> PlanPhase: def from_dict(cls, data: dict[str, Any]) -> PlanPhase:
@ -230,6 +241,8 @@ class PlanPhase:
collaboration_contracts=contracts, collaboration_contracts=contracts,
rework_count=data.get("rework_count", 0), rework_count=data.get("rework_count", 0),
review_feedback=data.get("review_feedback"), review_feedback=data.get("review_feedback"),
validation_command=data.get("validation_command"),
rollback_command=data.get("rollback_command"),
) )

View File

@ -0,0 +1,97 @@
"""RollbackExecutor — 阶段失败后执行回滚命令 (G9/U4)
复用 VerificationLoop asyncio.create_subprocess_shell 模式 (KTD7)
绕过 ShellTool避免 confirm_callback `git checkout` 的拦截
设计依据
- KTD6: 回滚是 opt-in 行为未配置 rollback_command 时不会执行
- KTD7: 不走 ShellTool避免 _is_dangerous 触发 confirm_callback
- R21: checkpoint.save 仅在回滚校验通过后调用
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class RollbackResult:
"""单次 subprocess 执行结果"""
passed: bool
exit_code: int
stdout: str
stderr: str
command: str
class RollbackExecutor:
"""执行 validation_command / rollback_command 的子进程封装
VerificationLoop 同构但语义不同
- validate(): 返回 passed=False 表示校验失败需要触发 rollback
- execute(): 返回 passed=False 表示回滚本身失败需跳过 checkpoint.save (R21)
"""
def __init__(
self,
working_dir: str | None = None,
timeout: float = 30.0,
) -> None:
self._working_dir = working_dir
self._timeout = timeout
async def _run(self, command: str) -> RollbackResult:
try:
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=self._working_dir,
)
except Exception as e: # noqa: BLE001 - subprocess spawn failure surface
return RollbackResult(
passed=False,
exit_code=-1,
stdout="",
stderr=f"Failed to spawn command: {e}",
command=command,
)
try:
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self._timeout)
except asyncio.TimeoutError:
try:
proc.kill()
except ProcessLookupError:
pass
await proc.wait()
return RollbackResult(
passed=False,
exit_code=-1,
stdout="",
stderr=f"Command timed out after {self._timeout}s: {command}",
command=command,
)
out_str = stdout.decode("utf-8", errors="replace") if stdout else ""
err_str = stderr.decode("utf-8", errors="replace") if stderr else ""
return RollbackResult(
passed=proc.returncode == 0,
exit_code=proc.returncode if proc.returncode is not None else -1,
stdout=out_str,
stderr=err_str,
command=command,
)
async def validate(self, command: str) -> RollbackResult:
"""运行 validation_commandpassed=False 表示需要触发 rollback"""
return await self._run(command)
async def execute(self, command: str) -> RollbackResult:
"""运行 rollback_commandpassed=False 表示回滚本身失败 (R21)"""
return await self._run(command)

View File

@ -119,6 +119,7 @@ class ServerConfig:
prompt_cache: dict[str, Any] | None = None, prompt_cache: dict[str, Any] | None = None,
streaming: dict[str, Any] | None = None, streaming: dict[str, Any] | None = None,
verification: dict[str, Any] | None = None, verification: dict[str, Any] | None = None,
rollback: dict[str, Any] | None = None,
on_change: Callable[["ServerConfig"], None] | None = None, on_change: Callable[["ServerConfig"], None] | None = None,
): ):
self.host = host self.host = host
@ -153,6 +154,9 @@ class ServerConfig:
# U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1) # U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1)
# verification_enabled=False 时此配置无效 # verification_enabled=False 时此配置无效
self.verification = verification or {} self.verification = verification or {}
# G9/U4: rollback.default_timeout 控制 RollbackExecutor subprocess 超时
# PlanPhase.rollback_command 未设置时此配置无效 (KTD6 opt-in)
self.rollback = rollback or {}
self.on_change = on_change self.on_change = on_change
# Config watching state # Config watching state
@ -240,6 +244,8 @@ class ServerConfig:
prompt_cache_data = data.get("prompt_cache", {}) prompt_cache_data = data.get("prompt_cache", {})
streaming_data = data.get("streaming", {}) streaming_data = data.get("streaming", {})
verification_data = data.get("verification", {}) verification_data = data.get("verification", {})
# G9/U4: rollback 配置 (从 YAML 读取opt-in)
rollback_data = data.get("rollback", {})
return cls( return cls(
host=server.get("host", "0.0.0.0"), host=server.get("host", "0.0.0.0"),
@ -271,6 +277,7 @@ class ServerConfig:
prompt_cache=prompt_cache_data, prompt_cache=prompt_cache_data,
streaming=streaming_data, streaming=streaming_data,
verification=verification_data, verification=verification_data,
rollback=rollback_data,
) )
@staticmethod @staticmethod

View File

@ -0,0 +1,367 @@
"""G9/U4 — PlanPhase rollback fields + RollbackExecutor + TeamOrchestrator integration.
Characterization-first: captures pre-change behavior (rollback_command=None
no rollback, checkpoint saved) before asserting new rollback behavior.
"""
from __future__ import annotations
import os
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.experts.orchestrator import TeamOrchestrator
from agentkit.experts.plan import PlanPhase, TeamPlan
from agentkit.orchestrator.rollback import RollbackExecutor
# ─── PlanPhase field characterization ──────────────────────────────────────
class TestPlanPhaseFields:
"""PlanPhase serialization for new fields."""
def test_characterization_no_new_keys_when_unset(self):
"""Default PlanPhase.to_dict() must not include new keys (KTD6 contract)."""
ph = PlanPhase(name="执行", assigned_expert="lead")
out = ph.to_dict()
assert "validation_command" not in out
assert "rollback_command" not in out
# Pre-change keys remain
assert out["name"] == "执行"
assert out["assigned_expert"] == "lead"
assert out["status"] == "pending"
def test_characterization_from_dict_empty_yields_none(self):
"""from_dict({}) produces both new fields as None."""
ph = PlanPhase.from_dict({"name": "x"})
assert ph.validation_command is None
assert ph.rollback_command is None
def test_serialization_includes_keys_when_set(self):
ph = PlanPhase(
name="frontend",
validation_command="ruff check src/",
rollback_command="git checkout src/app.vue",
)
out = ph.to_dict()
assert out["validation_command"] == "ruff check src/"
assert out["rollback_command"] == "git checkout src/app.vue"
def test_serialization_round_trip(self):
ph = PlanPhase(
name="backend",
validation_command="pytest -x -q",
rollback_command="git checkout src/api.py",
)
restored = PlanPhase.from_dict(ph.to_dict())
assert restored.validation_command == "pytest -x -q"
assert restored.rollback_command == "git checkout src/api.py"
def test_only_validation_set_still_emits_validation_key(self):
"""Asymmetric case: only validation_command set — only that key appears."""
ph = PlanPhase(name="x", validation_command="echo ok")
out = ph.to_dict()
assert "validation_command" in out
assert "rollback_command" not in out
# ─── RollbackExecutor subprocess execution ─────────────────────────────────
@pytest.fixture
def tmp_workspace(tmp_path):
"""Fresh working directory for subprocess execution."""
return str(tmp_path)
class TestRollbackExecutor:
"""RollbackExecutor happy/edge/failure paths."""
@pytest.mark.asyncio
async def test_execute_happy_path_zero_exit(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("true")
assert r.passed is True
assert r.exit_code == 0
assert r.command == "true"
@pytest.mark.asyncio
async def test_execute_failure_nonzero_exit(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("false")
assert r.passed is False
assert r.exit_code != 0
@pytest.mark.asyncio
async def test_execute_timeout(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=0.1)
r = await ex.execute("sleep 5")
assert r.passed is False
assert r.exit_code == -1
assert "timed out" in r.stderr
@pytest.mark.asyncio
async def test_execute_captures_stdout(self, tmp_workspace):
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("echo hello-rollback")
assert r.passed is True
assert "hello-rollback" in r.stdout
@pytest.mark.asyncio
async def test_validate_same_semantics_as_execute(self, tmp_workspace):
"""validate() is just execute() with different intent marker."""
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.validate("true")
assert r.passed is True
r2 = await ex.validate("false")
assert r2.passed is False
@pytest.mark.asyncio
async def test_spawn_failure_returns_failed_result(self, tmp_workspace):
"""Bad shell command should surface as failed result, not raise."""
# Use a definitely-broken command — shell still spawns, returns non-zero.
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
r = await ex.execute("exit 7")
assert r.passed is False
assert r.exit_code == 7
@pytest.mark.asyncio
async def test_cwd_used_for_relative_paths(self, tmp_workspace):
"""Files created in working_dir are visible via cwd-relative commands."""
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
await ex.execute("echo content > testfile.txt")
r = await ex.execute("test -f testfile.txt")
assert r.passed is True
# ─── Real git rollback integration ────────────────────────────────────────
class TestGitRollbackIntegration:
"""Real git repo fixture: writes file, rollback restores via git checkout."""
@pytest.mark.asyncio
async def test_git_checkout_restores_modified_file(self, tmp_path):
# Init repo and commit a baseline file
import subprocess
repo = str(tmp_path)
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.email", "test@x"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.name", "test"], cwd=repo, check=True)
baseline = "original content\n"
with open(os.path.join(repo, "foo.txt"), "w") as f:
f.write(baseline)
subprocess.run(["git", "add", "foo.txt"], cwd=repo, check=True)
subprocess.run(["git", "commit", "-q", "-m", "baseline"], cwd=repo, check=True)
# Mutate the file
with open(os.path.join(repo, "foo.txt"), "w") as f:
f.write("mutated content\n")
# Run git checkout as rollback_command
ex = RollbackExecutor(working_dir=repo, timeout=10.0)
r = await ex.execute("git checkout foo.txt")
assert r.passed is True
with open(os.path.join(repo, "foo.txt")) as f:
assert f.read() == baseline
# ─── TeamOrchestrator integration ─────────────────────────────────────────
def _make_team_mock():
"""Build a minimal ExpertTeam mock for TeamOrchestrator."""
team = MagicMock()
team.team_id = "test-team"
team.status.value = "executing"
team.lead_expert = None
team.active_experts = []
return team
def _make_orchestrator(team=None, checkpoint=None, workspace_root=None):
"""Build a TeamOrchestrator with mocked team and checkpoint."""
team = team or _make_team_mock()
orch = TeamOrchestrator(
team=team,
checkpoint=checkpoint,
workspace_root=workspace_root,
rollback_timeout=2.0,
)
orch._broadcast_event = AsyncMock()
return orch
class TestOrchestratorRollbackIntegration:
"""TeamOrchestrator phase failure path integration with rollback."""
@pytest.mark.asyncio
async def test_characterization_no_rollback_when_unset(self, tmp_path):
"""Phase fails, rollback_command=None → checkpoint saved, no rollback events."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
ph = PlanPhase(name="p1", assigned_expert="lead") # no rollback fields
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
# Call _run_phase_rollback — but it should NOT be called from
# orchestrator path when fields unset. Verify by simulating the
# main-loop guard directly.
from agentkit.experts.plan import PhaseStatus as PS
ph.status = PS.FAILED
# Simulate the guard condition in orchestrator.py:280-288
should_save = True
if (
ph.validation_command
and ph.rollback_command
and isinstance(Exception("x"), (Exception,))
):
should_save = await orch._run_phase_rollback(plan, ph)
# Guard should not fire
assert should_save is True
# No rollback events broadcast
assert orch._broadcast_event.call_count == 0
@pytest.mark.asyncio
async def test_validation_passes_no_rollback_executed(self, tmp_path):
"""validation_command returns 0 → rollback NOT executed, checkpoint saved."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="true", # always passes
rollback_command="false", # would fail if executed
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is True
# Events: started + completed (no rollback)
events = [
c.kwargs.get("event_type") or c.args[0] for c in orch._broadcast_event.call_args_list
]
assert "phase_rollback_started" in events
assert "phase_rollback_completed" in events
assert "phase_rollback_failed" not in events
@pytest.mark.asyncio
async def test_validation_fails_rollback_succeeds(self, tmp_path):
"""Validation fails, rollback returns 0 → checkpoint saved, rollback_executed=True."""
# Create file under git tracking, then mutate it
import subprocess
repo = str(tmp_path)
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.email", "t@x"], cwd=repo, check=True)
subprocess.run(["git", "config", "user.name", "t"], cwd=repo, check=True)
with open(os.path.join(repo, "f.txt"), "w") as f:
f.write("base\n")
subprocess.run(["git", "add", "f.txt"], cwd=repo, check=True)
subprocess.run(["git", "commit", "-q", "-m", "base"], cwd=repo, check=True)
with open(os.path.join(repo, "f.txt"), "w") as f:
f.write("mutated\n")
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=repo)
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="false", # validation fails
rollback_command="git checkout f.txt", # rollback succeeds
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is True
# File restored to baseline
with open(os.path.join(repo, "f.txt")) as f:
assert f.read() == "base\n"
@pytest.mark.asyncio
async def test_validation_fails_rollback_fails_skips_checkpoint(self, tmp_path):
"""Validation fails AND rollback fails → checkpoint NOT saved (R21), event emitted."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="false", # fails
rollback_command="false", # also fails
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is False # R21: skip checkpoint
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
assert "phase_rollback_started" in events
assert "phase_rollback_failed" in events
# phase_rollback_completed should NOT be in events
assert "phase_rollback_completed" not in events
@pytest.mark.asyncio
async def test_rollback_timeout_skips_checkpoint(self, tmp_path):
"""Rollback command times out → checkpoint NOT saved, failed event emitted."""
checkpoint = MagicMock()
checkpoint.save = AsyncMock()
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
orch._rollback_timeout = 0.1 # short timeout
ph = PlanPhase(
name="p1",
assigned_expert="lead",
validation_command="false",
rollback_command="sleep 5",
)
plan = TeamPlan(task="t", lead_expert="lead")
plan.phases = [ph]
should_save = await orch._run_phase_rollback(plan, ph)
assert should_save is False
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
assert "phase_rollback_failed" in events
# ─── Config wiring ────────────────────────────────────────────────────────
class TestServerConfigRollback:
"""ServerConfig rollback section wiring."""
def test_rollback_section_read_from_dict(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict(
{
"rollback": {
"default_timeout": 45.0,
}
}
)
assert config.rollback == {"default_timeout": 45.0}
def test_rollback_defaults_empty_when_absent(self):
from agentkit.server.config import ServerConfig
config = ServerConfig.from_dict({})
assert config.rollback == {}