368 lines
14 KiB
Python
368 lines
14 KiB
Python
"""G9/U4 — PlanPhase rollback fields + RollbackExecutor + TeamOrchestrator integration.
|
|
|
|
Characterization-first: captures pre-change behavior (rollback_command=None →
|
|
no rollback, checkpoint saved) before asserting new rollback behavior.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from agentkit.experts.orchestrator import TeamOrchestrator
|
|
from agentkit.experts.plan import PlanPhase, TeamPlan
|
|
from agentkit.orchestrator.rollback import RollbackExecutor
|
|
|
|
|
|
# ─── PlanPhase field characterization ──────────────────────────────────────
|
|
|
|
|
|
class TestPlanPhaseFields:
    """Serialization contract for PlanPhase's optional command fields."""

    def test_characterization_no_new_keys_when_unset(self):
        """A default PlanPhase serializes without the new keys (KTD6 contract)."""
        phase = PlanPhase(name="执行", assigned_expert="lead")
        payload = phase.to_dict()

        assert "validation_command" not in payload
        assert "rollback_command" not in payload

        # Keys that existed before the change must still be emitted.
        assert payload["name"] == "执行"
        assert payload["assigned_expert"] == "lead"
        assert payload["status"] == "pending"

    def test_characterization_from_dict_empty_yields_none(self):
        """from_dict() on a dict lacking the command keys leaves both as None."""
        phase = PlanPhase.from_dict({"name": "x"})

        assert phase.validation_command is None
        assert phase.rollback_command is None

    def test_serialization_includes_keys_when_set(self):
        """Both command fields appear in to_dict() once they are set."""
        payload = PlanPhase(
            name="frontend",
            validation_command="ruff check src/",
            rollback_command="git checkout src/app.vue",
        ).to_dict()

        assert payload["validation_command"] == "ruff check src/"
        assert payload["rollback_command"] == "git checkout src/app.vue"

    def test_serialization_round_trip(self):
        """to_dict() followed by from_dict() preserves both command fields."""
        original = PlanPhase(
            name="backend",
            validation_command="pytest -x -q",
            rollback_command="git checkout src/api.py",
        )
        serialized = original.to_dict()
        restored = PlanPhase.from_dict(serialized)

        assert restored.validation_command == "pytest -x -q"
        assert restored.rollback_command == "git checkout src/api.py"

    def test_only_validation_set_still_emits_validation_key(self):
        """Asymmetric case: with only validation_command set, only that key is emitted."""
        payload = PlanPhase(name="x", validation_command="echo ok").to_dict()

        assert "validation_command" in payload
        assert "rollback_command" not in payload
|
|
|
|
|
|
# ─── RollbackExecutor subprocess execution ─────────────────────────────────
|
|
|
|
|
|
@pytest.fixture
def tmp_workspace(tmp_path):
    """Provide pytest's per-test temporary directory as a plain string path."""
    workspace = os.fspath(tmp_path)
    return workspace
|
|
|
|
|
|
class TestRollbackExecutor:
    """RollbackExecutor happy/edge/failure paths.

    The commands used here (``true``, ``false``, ``sleep``, ``exit``, ``test``)
    are POSIX shell commands, so these tests presumably require a POSIX shell.
    """

    @pytest.mark.asyncio
    async def test_execute_happy_path_zero_exit(self, tmp_workspace):
        """A command exiting 0 yields a passed result that echoes the command."""
        ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
        r = await ex.execute("true")
        assert r.passed is True
        assert r.exit_code == 0
        assert r.command == "true"

    @pytest.mark.asyncio
    async def test_execute_failure_nonzero_exit(self, tmp_workspace):
        """A command exiting non-zero yields a failed result."""
        ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
        r = await ex.execute("false")
        assert r.passed is False
        assert r.exit_code != 0

    @pytest.mark.asyncio
    async def test_execute_timeout(self, tmp_workspace):
        """A command outliving the timeout fails with exit_code -1 and a 'timed out' stderr."""
        ex = RollbackExecutor(working_dir=tmp_workspace, timeout=0.1)
        r = await ex.execute("sleep 5")
        assert r.passed is False
        assert r.exit_code == -1
        assert "timed out" in r.stderr

    @pytest.mark.asyncio
    async def test_execute_captures_stdout(self, tmp_workspace):
        """Standard output of the command is captured on the result."""
        ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
        r = await ex.execute("echo hello-rollback")
        assert r.passed is True
        assert "hello-rollback" in r.stdout

    @pytest.mark.asyncio
    async def test_validate_same_semantics_as_execute(self, tmp_workspace):
        """validate() reports pass/fail from the exit status just like execute()."""
        ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
        r = await ex.validate("true")
        assert r.passed is True
        r2 = await ex.validate("false")
        assert r2.passed is False

    @pytest.mark.asyncio
    async def test_spawn_failure_returns_failed_result(self, tmp_workspace):
        """A non-zero exit surfaces as a failed result carrying its exit code.

        Despite the test name, the shell itself spawns successfully here; what
        is verified is that the command's exit status (7) is propagated on the
        result instead of an exception being raised.
        """
        ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
        r = await ex.execute("exit 7")
        assert r.passed is False
        assert r.exit_code == 7

    @pytest.mark.asyncio
    async def test_cwd_used_for_relative_paths(self, tmp_workspace):
        """Relative paths in commands resolve against working_dir."""
        ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)

        created = await ex.execute("echo content > testfile.txt")
        assert created.passed is True

        # Check the file on disk: a follow-up `test -f` alone would also pass
        # if working_dir were ignored, since both commands would then share
        # whatever directory the process happens to run in.
        assert os.path.isfile(os.path.join(tmp_workspace, "testfile.txt"))

        r = await ex.execute("test -f testfile.txt")
        assert r.passed is True
|
|
|
|
|
|
# ─── Real git rollback integration ────────────────────────────────────────
|
|
|
|
|
|
class TestGitRollbackIntegration:
    """Real git repo fixture: writes a file, rollback restores it via git checkout."""

    @pytest.mark.asyncio
    async def test_git_checkout_restores_modified_file(self, tmp_path):
        """``git checkout <file>`` run through RollbackExecutor restores the committed content."""
        import shutil
        import subprocess

        # Skip rather than error out on machines without a git binary.
        if shutil.which("git") is None:
            pytest.skip("git executable not available")

        repo = str(tmp_path)
        target = os.path.join(repo, "foo.txt")
        baseline = "original content\n"

        def git(*args):
            """Run a git subcommand inside the temporary repo, failing loudly."""
            subprocess.run(["git", *args], cwd=repo, check=True)

        # Init repo and commit a baseline file.
        git("init", "-q")
        git("config", "user.email", "test@x")
        git("config", "user.name", "test")
        # A globally enabled commit.gpgsign would make the commit below fail
        # (or prompt) on developer machines; force it off for this repo only.
        git("config", "commit.gpgsign", "false")
        with open(target, "w", encoding="utf-8") as f:
            f.write(baseline)
        git("add", "foo.txt")
        git("commit", "-q", "-m", "baseline")

        # Mutate the file and confirm the mutation really landed on disk.
        with open(target, "w", encoding="utf-8") as f:
            f.write("mutated content\n")
        with open(target, encoding="utf-8") as f:
            assert f.read() != baseline

        # Run git checkout as the rollback command.
        ex = RollbackExecutor(working_dir=repo, timeout=10.0)
        r = await ex.execute("git checkout foo.txt")
        assert r.passed is True

        with open(target, encoding="utf-8") as f:
            assert f.read() == baseline
|
|
|
|
|
|
# ─── TeamOrchestrator integration ─────────────────────────────────────────
|
|
|
|
|
|
def _make_team_mock():
|
|
"""Build a minimal ExpertTeam mock for TeamOrchestrator."""
|
|
team = MagicMock()
|
|
team.team_id = "test-team"
|
|
team.status.value = "executing"
|
|
team.lead_expert = None
|
|
team.active_experts = []
|
|
return team
|
|
|
|
|
|
def _make_orchestrator(team=None, checkpoint=None, workspace_root=None):
    """Build a TeamOrchestrator with a mocked team and a stubbed event broadcaster.

    Args:
        team: Team object handed to the orchestrator; when ``None`` a mock from
            ``_make_team_mock()`` is used.
        checkpoint: Checkpoint object forwarded to the orchestrator as-is.
        workspace_root: Workspace directory forwarded to the orchestrator as-is.

    Returns:
        The orchestrator, constructed with ``rollback_timeout=2.0`` and with
        ``_broadcast_event`` replaced by an ``AsyncMock`` so tests can inspect
        emitted events.
    """
    # Compare against None explicitly: `team or ...` would silently swap out a
    # caller-supplied team that happens to be falsy (e.g. defines __len__/__bool__).
    if team is None:
        team = _make_team_mock()
    orch = TeamOrchestrator(
        team=team,
        checkpoint=checkpoint,
        workspace_root=workspace_root,
        rollback_timeout=2.0,
    )
    orch._broadcast_event = AsyncMock()
    return orch
|
|
|
|
|
|
class TestOrchestratorRollbackIntegration:
    """TeamOrchestrator phase-failure path integration with rollback.

    Each test drives ``_run_phase_rollback`` and checks its boolean result
    (named ``should_save`` here) plus the event types passed to the stubbed
    ``_broadcast_event``.
    """

    @staticmethod
    def _event_types(orch):
        """Return the event type of every ``_broadcast_event`` call, in call order.

        The event type is read from the ``event_type`` keyword when present,
        otherwise from the first positional argument, so the assertions do not
        depend on how the orchestrator happens to pass it.
        """
        types = []
        for call in orch._broadcast_event.call_args_list:
            event_type = call.kwargs.get("event_type")
            if not event_type and call.args:
                event_type = call.args[0]
            types.append(event_type)
        return types

    @pytest.mark.asyncio
    async def test_characterization_no_rollback_when_unset(self, tmp_path):
        """Phase fails with no rollback fields → guard does not fire, no rollback events."""
        from agentkit.experts.plan import PhaseStatus as PS

        checkpoint = MagicMock()
        checkpoint.save = AsyncMock()
        orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))

        ph = PlanPhase(name="p1", assigned_expert="lead")  # no rollback fields
        plan = TeamPlan(task="t", lead_expert="lead")
        plan.phases = [ph]
        ph.status = PS.FAILED

        # _run_phase_rollback should NOT be reached from the orchestrator path
        # when the fields are unset. This mirrors the main-loop guard (both
        # commands must be set) rather than running the main loop itself.
        should_save = True
        if ph.validation_command and ph.rollback_command:
            should_save = await orch._run_phase_rollback(plan, ph)

        # Guard should not fire.
        assert should_save is True
        # No rollback events broadcast.
        assert orch._broadcast_event.call_count == 0

    @pytest.mark.asyncio
    async def test_validation_passes_no_rollback_executed(self, tmp_path):
        """validation_command returns 0 → rollback NOT executed, result says save."""
        checkpoint = MagicMock()
        checkpoint.save = AsyncMock()
        orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))

        ph = PlanPhase(
            name="p1",
            assigned_expert="lead",
            validation_command="true",  # always passes
            rollback_command="false",  # would fail if executed
        )
        plan = TeamPlan(task="t", lead_expert="lead")
        plan.phases = [ph]

        should_save = await orch._run_phase_rollback(plan, ph)
        assert should_save is True

        # Events: started + completed (no rollback failure).
        events = self._event_types(orch)
        assert "phase_rollback_started" in events
        assert "phase_rollback_completed" in events
        assert "phase_rollback_failed" not in events

    @pytest.mark.asyncio
    async def test_validation_fails_rollback_succeeds(self, tmp_path):
        """Validation fails, rollback returns 0 → result says save, file restored."""
        import shutil
        import subprocess

        # Skip rather than error out on machines without a git binary.
        if shutil.which("git") is None:
            pytest.skip("git executable not available")

        repo = str(tmp_path)
        target = os.path.join(repo, "f.txt")

        def git(*args):
            """Run a git subcommand inside the temporary repo, failing loudly."""
            subprocess.run(["git", *args], cwd=repo, check=True)

        # Create a file under git tracking, then mutate it.
        git("init", "-q")
        git("config", "user.email", "t@x")
        git("config", "user.name", "t")
        # A globally enabled commit.gpgsign would break the commit below.
        git("config", "commit.gpgsign", "false")
        with open(target, "w", encoding="utf-8") as f:
            f.write("base\n")
        git("add", "f.txt")
        git("commit", "-q", "-m", "base")
        with open(target, "w", encoding="utf-8") as f:
            f.write("mutated\n")

        checkpoint = MagicMock()
        checkpoint.save = AsyncMock()
        orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=repo)

        ph = PlanPhase(
            name="p1",
            assigned_expert="lead",
            validation_command="false",  # validation fails
            rollback_command="git checkout f.txt",  # rollback succeeds
        )
        plan = TeamPlan(task="t", lead_expert="lead")
        plan.phases = [ph]

        should_save = await orch._run_phase_rollback(plan, ph)
        assert should_save is True

        # File restored to baseline.
        with open(target, encoding="utf-8") as f:
            assert f.read() == "base\n"

    @pytest.mark.asyncio
    async def test_validation_fails_rollback_fails_skips_checkpoint(self, tmp_path):
        """Validation fails AND rollback fails → result says skip checkpoint (R21), failed event emitted."""
        checkpoint = MagicMock()
        checkpoint.save = AsyncMock()
        orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))

        ph = PlanPhase(
            name="p1",
            assigned_expert="lead",
            validation_command="false",  # fails
            rollback_command="false",  # also fails
        )
        plan = TeamPlan(task="t", lead_expert="lead")
        plan.phases = [ph]

        should_save = await orch._run_phase_rollback(plan, ph)
        assert should_save is False  # R21: skip checkpoint

        events = self._event_types(orch)
        assert "phase_rollback_started" in events
        assert "phase_rollback_failed" in events
        # A failed rollback must not also be reported as completed.
        assert "phase_rollback_completed" not in events

    @pytest.mark.asyncio
    async def test_rollback_timeout_skips_checkpoint(self, tmp_path):
        """Rollback command times out → result says skip checkpoint, failed event emitted."""
        checkpoint = MagicMock()
        checkpoint.save = AsyncMock()
        orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
        # NOTE(review): assumes the orchestrator reads _rollback_timeout at
        # rollback time rather than caching it at construction — confirm.
        orch._rollback_timeout = 0.1  # short timeout

        ph = PlanPhase(
            name="p1",
            assigned_expert="lead",
            validation_command="false",
            rollback_command="sleep 5",
        )
        plan = TeamPlan(task="t", lead_expert="lead")
        plan.phases = [ph]

        should_save = await orch._run_phase_rollback(plan, ph)
        assert should_save is False

        events = self._event_types(orch)
        assert "phase_rollback_failed" in events
|
|
|
|
|
|
# ─── Config wiring ────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestServerConfigRollback:
    """Wiring of the ``rollback`` section on ServerConfig."""

    def test_rollback_section_read_from_dict(self):
        """A ``rollback`` mapping given to from_dict() is exposed unchanged."""
        from agentkit.server.config import ServerConfig

        raw_settings = {"rollback": {"default_timeout": 45.0}}
        loaded = ServerConfig.from_dict(raw_settings)

        assert loaded.rollback == {"default_timeout": 45.0}

    def test_rollback_defaults_empty_when_absent(self):
        """Without a ``rollback`` section the attribute is an empty dict."""
        from agentkit.server.config import ServerConfig

        loaded = ServerConfig.from_dict({})

        assert loaded.rollback == {}
|