feat(U4): G9 PlanPhase rollback + RollbackExecutor
- PlanPhase 新增 validation_command / rollback_command 可选字段 (KTD6 opt-in) - to_dict 仅在字段非 None 时输出新键,保持既有 dict shape (KTD6 契约) - 新增 RollbackExecutor (orchestrator/rollback.py) 复用 VerificationLoop subprocess 模式,绕过 ShellTool 避免 confirm_callback 拦截 (KTD7) - TeamOrchestrator._run_phase_rollback 实现 R21 顺序: validation → rollback → checkpoint.save (仅在前者通过时调用) - ServerConfig.from_dict 读取 rollback.default_timeout - 20 个测试覆盖 characterization / happy / timeout / git integration / 配置
This commit is contained in:
parent
5b2377469a
commit
b1841ce21b
|
|
@ -33,6 +33,13 @@ llm:
|
||||||
# the main model on failure or empty content. Commented to preserve
|
# the main model on failure or empty content. Commented to preserve
|
||||||
# default behavior — uncomment to enable.
|
# default behavior — uncomment to enable.
|
||||||
# auxiliary_model: fast
|
# auxiliary_model: fast
|
||||||
|
# G9/U4: Rollback configuration. Drives RollbackExecutor subprocess timeout
|
||||||
|
# for PlanPhase.validation_command / PlanPhase.rollback_command. Per-phase
|
||||||
|
# opt-in (KTD6) — when PlanPhase.rollback_command is unset, no rollback runs.
|
||||||
|
# Canonical rollback pattern: `git checkout <specific_files>` (file-scoped,
|
||||||
|
# not `git checkout .` which would wipe unrelated changes).
|
||||||
|
rollback:
|
||||||
|
default_timeout: 30.0
|
||||||
session: {backend: memory}
|
session: {backend: memory}
|
||||||
bus: {backend: memory}
|
bus: {backend: memory}
|
||||||
task_store: {backend: memory}
|
task_store: {backend: memory}
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ from typing import Any
|
||||||
from agentkit.core.config_driven import ConfigDrivenAgent
|
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
|
from agentkit.orchestrator.rollback import RollbackExecutor
|
||||||
|
|
||||||
from .expert import Expert
|
from .expert import Expert
|
||||||
from .plan import (
|
from .plan import (
|
||||||
|
|
@ -72,12 +73,17 @@ class TeamOrchestrator:
|
||||||
MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution
|
MAX_DEBATES = 3 # Hard cap on auto-inserted debate phases per execution
|
||||||
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
|
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
|
||||||
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
|
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
|
||||||
|
# G9/U4: RollbackExecutor default timeout for validation_command / rollback_command.
|
||||||
|
# Override via constructor `rollback_timeout` from `rollback.default_timeout` config.
|
||||||
|
DEFAULT_ROLLBACK_TIMEOUT = 30.0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
team: ExpertTeam,
|
team: ExpertTeam,
|
||||||
max_concurrent_phases: int | None = None,
|
max_concurrent_phases: int | None = None,
|
||||||
checkpoint: Any = None,
|
checkpoint: Any = None,
|
||||||
|
workspace_root: str | None = None,
|
||||||
|
rollback_timeout: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._team = team
|
self._team = team
|
||||||
# Track temporary agent names created for context isolation (KTD3)
|
# Track temporary agent names created for context isolation (KTD3)
|
||||||
|
|
@ -93,6 +99,10 @@ class TeamOrchestrator:
|
||||||
self._phase_semaphore = asyncio.Semaphore(limit)
|
self._phase_semaphore = asyncio.Semaphore(limit)
|
||||||
# U7: Pipeline checkpoint for crash recovery
|
# U7: Pipeline checkpoint for crash recovery
|
||||||
self._checkpoint = checkpoint
|
self._checkpoint = checkpoint
|
||||||
|
# G9/U4: workspace_root drives RollbackExecutor cwd; rollback_timeout drives its timeout.
|
||||||
|
# Both default to no-op-friendly values so existing call sites behave identically.
|
||||||
|
self._workspace_root = workspace_root
|
||||||
|
self._rollback_timeout = rollback_timeout or self.DEFAULT_ROLLBACK_TIMEOUT
|
||||||
|
|
||||||
async def execute(self, task: str) -> dict[str, Any]:
|
async def execute(self, task: str) -> dict[str, Any]:
|
||||||
"""Execute a task in pipeline mode.
|
"""Execute a task in pipeline mode.
|
||||||
|
|
@ -262,8 +272,23 @@ class TeamOrchestrator:
|
||||||
else:
|
else:
|
||||||
phase_results[ph.id] = result
|
phase_results[ph.id] = result
|
||||||
|
|
||||||
|
# G9/U4: opt-in rollback (KTD6) + checkpoint ordering (R21).
|
||||||
|
# When phase configures both validation_command and rollback_command:
|
||||||
|
# 1. run validation_command — if it passes, treat phase as recoverable, save checkpoint
|
||||||
|
# 2. if validation fails, run rollback_command
|
||||||
|
# 3. if rollback passes (exit 0), save checkpoint
|
||||||
|
# 4. if rollback fails, skip checkpoint (R21 — avoid persisting broken state)
|
||||||
|
# When neither command is set, behavior is unchanged (existing save).
|
||||||
|
should_save_checkpoint = True
|
||||||
|
if (
|
||||||
|
ph.validation_command
|
||||||
|
and ph.rollback_command
|
||||||
|
and isinstance(result, (Exception, asyncio.CancelledError))
|
||||||
|
):
|
||||||
|
should_save_checkpoint = await self._run_phase_rollback(plan, ph)
|
||||||
|
|
||||||
# U7: Save checkpoint after phase finalizes (success or failure)
|
# U7: Save checkpoint after phase finalizes (success or failure)
|
||||||
if self._checkpoint is not None:
|
if should_save_checkpoint and self._checkpoint is not None:
|
||||||
try:
|
try:
|
||||||
await self._checkpoint.save(plan.id, ph, plan.status.value)
|
await self._checkpoint.save(plan.id, ph, plan.status.value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -393,9 +418,7 @@ class TeamOrchestrator:
|
||||||
# PENDING phases remain PENDING — will be executed by _run_pipeline
|
# PENDING phases remain PENDING — will be executed by _run_pipeline
|
||||||
|
|
||||||
# P2 #8: Restore debate count so MAX_DEBATES limit holds after resume
|
# P2 #8: Restore debate count so MAX_DEBATES limit holds after resume
|
||||||
self._debate_count = sum(
|
self._debate_count = sum(1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE)
|
||||||
1 for ph in plan.phases if ph.phase_type == PhaseType.DEBATE
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
|
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
|
||||||
|
|
@ -688,9 +711,9 @@ class TeamOrchestrator:
|
||||||
and prev_phase.result
|
and prev_phase.result
|
||||||
):
|
):
|
||||||
# U4: Resolve offloaded content from workspace
|
# U4: Resolve offloaded content from workspace
|
||||||
collaboration_outputs[contract.from_expert] = (
|
collaboration_outputs[
|
||||||
await self._read_dependency_output(prev_phase)
|
contract.from_expert
|
||||||
)
|
] = await self._read_dependency_output(prev_phase)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Emit expert_step event
|
# Emit expert_step event
|
||||||
|
|
@ -1809,6 +1832,75 @@ class TeamOrchestrator:
|
||||||
# Recursively mark their dependents
|
# Recursively mark their dependents
|
||||||
await self._mark_dependents_failed(ph.id, plan, phase_results)
|
await self._mark_dependents_failed(ph.id, plan, phase_results)
|
||||||
|
|
||||||
|
async def _run_phase_rollback(self, plan: TeamPlan, ph: PlanPhase) -> bool:
|
||||||
|
"""G9/U4: run validation_command + rollback_command for a failed phase.
|
||||||
|
|
||||||
|
Returns True if checkpoint save should proceed (R21 ordering).
|
||||||
|
- Validation passes → save checkpoint (phase state recoverable)
|
||||||
|
- Validation fails, rollback passes → save checkpoint (rolled back state)
|
||||||
|
- Validation fails, rollback fails → skip checkpoint (broken state)
|
||||||
|
- Subprocess spawn failure or timeout → skip checkpoint
|
||||||
|
"""
|
||||||
|
executor = RollbackExecutor(
|
||||||
|
working_dir=self._workspace_root,
|
||||||
|
timeout=self._rollback_timeout,
|
||||||
|
)
|
||||||
|
await self._broadcast_event(
|
||||||
|
"phase_rollback_started",
|
||||||
|
{
|
||||||
|
"plan_id": plan.id,
|
||||||
|
"phase_id": ph.id,
|
||||||
|
"phase_name": ph.name,
|
||||||
|
"validation_command": ph.validation_command,
|
||||||
|
"rollback_command": ph.rollback_command,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# ponytail: validate first; if validation passes, rollback is skipped (no need).
|
||||||
|
validation = await executor.validate(ph.validation_command or "")
|
||||||
|
if validation.passed:
|
||||||
|
await self._broadcast_event(
|
||||||
|
"phase_rollback_completed",
|
||||||
|
{
|
||||||
|
"plan_id": plan.id,
|
||||||
|
"phase_id": ph.id,
|
||||||
|
"phase_name": ph.name,
|
||||||
|
"rollback_executed": False,
|
||||||
|
"validation_passed": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
rollback = await executor.execute(ph.rollback_command or "")
|
||||||
|
if rollback.passed:
|
||||||
|
await self._broadcast_event(
|
||||||
|
"phase_rollback_completed",
|
||||||
|
{
|
||||||
|
"plan_id": plan.id,
|
||||||
|
"phase_id": ph.id,
|
||||||
|
"phase_name": ph.name,
|
||||||
|
"rollback_executed": True,
|
||||||
|
"validation_passed": False,
|
||||||
|
"rollback_stdout": rollback.stdout,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
f"Rollback failed for phase {ph.id} ({ph.name}): exit={rollback.exit_code} stderr={rollback.stderr}"
|
||||||
|
)
|
||||||
|
await self._broadcast_event(
|
||||||
|
"phase_rollback_failed",
|
||||||
|
{
|
||||||
|
"plan_id": plan.id,
|
||||||
|
"phase_id": ph.id,
|
||||||
|
"phase_name": ph.name,
|
||||||
|
"validation_passed": False,
|
||||||
|
"rollback_exit_code": rollback.exit_code,
|
||||||
|
"rollback_stderr": rollback.stderr,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
async def _synthesize_results(
|
async def _synthesize_results(
|
||||||
self, lead: Expert, task: str, completed_phases: list[PlanPhase]
|
self, lead: Expert, task: str, completed_phases: list[PlanPhase]
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
|
|
|
||||||
|
|
@ -182,6 +182,11 @@ class PlanPhase:
|
||||||
collaboration_contracts: list[CollaborationContract] = field(default_factory=list)
|
collaboration_contracts: list[CollaborationContract] = field(default_factory=list)
|
||||||
rework_count: int = 0
|
rework_count: int = 0
|
||||||
review_feedback: str | None = None
|
review_feedback: str | None = None
|
||||||
|
# G9/U4: opt-in rollback fields. When unset, no rollback executes (KTD6).
|
||||||
|
# validation_command runs first; if it fails, rollback_command runs.
|
||||||
|
# canonical rollback pattern: `git checkout <specific_files>`.
|
||||||
|
validation_command: str | None = None
|
||||||
|
rollback_command: str | None = None
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""序列化为字典"""
|
"""序列化为字典"""
|
||||||
|
|
@ -192,7 +197,7 @@ class PlanPhase:
|
||||||
result_str = self.result.get("content", str(self.result))
|
result_str = self.result.get("content", str(self.result))
|
||||||
else:
|
else:
|
||||||
result_str = str(self.result)
|
result_str = str(self.result)
|
||||||
return {
|
out: dict[str, Any] = {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"assigned_expert": self.assigned_expert,
|
"assigned_expert": self.assigned_expert,
|
||||||
|
|
@ -206,6 +211,12 @@ class PlanPhase:
|
||||||
"rework_count": self.rework_count,
|
"rework_count": self.rework_count,
|
||||||
"review_feedback": self.review_feedback,
|
"review_feedback": self.review_feedback,
|
||||||
}
|
}
|
||||||
|
# G9/U4: only include new keys when set, to preserve pre-change dict shape (KTD6).
|
||||||
|
if self.validation_command is not None:
|
||||||
|
out["validation_command"] = self.validation_command
|
||||||
|
if self.rollback_command is not None:
|
||||||
|
out["rollback_command"] = self.rollback_command
|
||||||
|
return out
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> PlanPhase:
|
def from_dict(cls, data: dict[str, Any]) -> PlanPhase:
|
||||||
|
|
@ -230,6 +241,8 @@ class PlanPhase:
|
||||||
collaboration_contracts=contracts,
|
collaboration_contracts=contracts,
|
||||||
rework_count=data.get("rework_count", 0),
|
rework_count=data.get("rework_count", 0),
|
||||||
review_feedback=data.get("review_feedback"),
|
review_feedback=data.get("review_feedback"),
|
||||||
|
validation_command=data.get("validation_command"),
|
||||||
|
rollback_command=data.get("rollback_command"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""RollbackExecutor — 阶段失败后执行回滚命令 (G9/U4)
|
||||||
|
|
||||||
|
复用 VerificationLoop 的 asyncio.create_subprocess_shell 模式 (KTD7):
|
||||||
|
绕过 ShellTool,避免 confirm_callback 对 `git checkout` 的拦截。
|
||||||
|
|
||||||
|
设计依据:
|
||||||
|
- KTD6: 回滚是 opt-in 行为,未配置 rollback_command 时不会执行
|
||||||
|
- KTD7: 不走 ShellTool,避免 _is_dangerous 触发 confirm_callback
|
||||||
|
- R21: checkpoint.save 仅在回滚校验通过后调用
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RollbackResult:
|
||||||
|
"""单次 subprocess 执行结果"""
|
||||||
|
|
||||||
|
passed: bool
|
||||||
|
exit_code: int
|
||||||
|
stdout: str
|
||||||
|
stderr: str
|
||||||
|
command: str
|
||||||
|
|
||||||
|
|
||||||
|
class RollbackExecutor:
|
||||||
|
"""执行 validation_command / rollback_command 的子进程封装
|
||||||
|
|
||||||
|
与 VerificationLoop 同构,但语义不同:
|
||||||
|
- validate(): 返回 passed=False 表示校验失败,需要触发 rollback
|
||||||
|
- execute(): 返回 passed=False 表示回滚本身失败,需跳过 checkpoint.save (R21)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
working_dir: str | None = None,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
) -> None:
|
||||||
|
self._working_dir = working_dir
|
||||||
|
self._timeout = timeout
|
||||||
|
|
||||||
|
async def _run(self, command: str) -> RollbackResult:
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_shell(
|
||||||
|
command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=self._working_dir,
|
||||||
|
)
|
||||||
|
except Exception as e: # noqa: BLE001 - subprocess spawn failure surface
|
||||||
|
return RollbackResult(
|
||||||
|
passed=False,
|
||||||
|
exit_code=-1,
|
||||||
|
stdout="",
|
||||||
|
stderr=f"Failed to spawn command: {e}",
|
||||||
|
command=command,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=self._timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
try:
|
||||||
|
proc.kill()
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
await proc.wait()
|
||||||
|
return RollbackResult(
|
||||||
|
passed=False,
|
||||||
|
exit_code=-1,
|
||||||
|
stdout="",
|
||||||
|
stderr=f"Command timed out after {self._timeout}s: {command}",
|
||||||
|
command=command,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_str = stdout.decode("utf-8", errors="replace") if stdout else ""
|
||||||
|
err_str = stderr.decode("utf-8", errors="replace") if stderr else ""
|
||||||
|
return RollbackResult(
|
||||||
|
passed=proc.returncode == 0,
|
||||||
|
exit_code=proc.returncode if proc.returncode is not None else -1,
|
||||||
|
stdout=out_str,
|
||||||
|
stderr=err_str,
|
||||||
|
command=command,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def validate(self, command: str) -> RollbackResult:
|
||||||
|
"""运行 validation_command,passed=False 表示需要触发 rollback"""
|
||||||
|
return await self._run(command)
|
||||||
|
|
||||||
|
async def execute(self, command: str) -> RollbackResult:
|
||||||
|
"""运行 rollback_command,passed=False 表示回滚本身失败 (R21)"""
|
||||||
|
return await self._run(command)
|
||||||
|
|
@ -119,6 +119,7 @@ class ServerConfig:
|
||||||
prompt_cache: dict[str, Any] | None = None,
|
prompt_cache: dict[str, Any] | None = None,
|
||||||
streaming: dict[str, Any] | None = None,
|
streaming: dict[str, Any] | None = None,
|
||||||
verification: dict[str, Any] | None = None,
|
verification: dict[str, Any] | None = None,
|
||||||
|
rollback: dict[str, Any] | None = None,
|
||||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||||
):
|
):
|
||||||
self.host = host
|
self.host = host
|
||||||
|
|
@ -153,6 +154,9 @@ class ServerConfig:
|
||||||
# U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1)
|
# U4/G1: verification.max_reinjections 控制 verify 失败回灌次数(默认 1)
|
||||||
# verification_enabled=False 时此配置无效
|
# verification_enabled=False 时此配置无效
|
||||||
self.verification = verification or {}
|
self.verification = verification or {}
|
||||||
|
# G9/U4: rollback.default_timeout 控制 RollbackExecutor subprocess 超时
|
||||||
|
# PlanPhase.rollback_command 未设置时此配置无效 (KTD6 opt-in)
|
||||||
|
self.rollback = rollback or {}
|
||||||
self.on_change = on_change
|
self.on_change = on_change
|
||||||
|
|
||||||
# Config watching state
|
# Config watching state
|
||||||
|
|
@ -240,6 +244,8 @@ class ServerConfig:
|
||||||
prompt_cache_data = data.get("prompt_cache", {})
|
prompt_cache_data = data.get("prompt_cache", {})
|
||||||
streaming_data = data.get("streaming", {})
|
streaming_data = data.get("streaming", {})
|
||||||
verification_data = data.get("verification", {})
|
verification_data = data.get("verification", {})
|
||||||
|
# G9/U4: rollback 配置 (从 YAML 读取,opt-in)
|
||||||
|
rollback_data = data.get("rollback", {})
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
host=server.get("host", "0.0.0.0"),
|
host=server.get("host", "0.0.0.0"),
|
||||||
|
|
@ -271,6 +277,7 @@ class ServerConfig:
|
||||||
prompt_cache=prompt_cache_data,
|
prompt_cache=prompt_cache_data,
|
||||||
streaming=streaming_data,
|
streaming=streaming_data,
|
||||||
verification=verification_data,
|
verification=verification_data,
|
||||||
|
rollback=rollback_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,367 @@
|
||||||
|
"""G9/U4 — PlanPhase rollback fields + RollbackExecutor + TeamOrchestrator integration.
|
||||||
|
|
||||||
|
Characterization-first: captures pre-change behavior (rollback_command=None →
|
||||||
|
no rollback, checkpoint saved) before asserting new rollback behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.experts.orchestrator import TeamOrchestrator
|
||||||
|
from agentkit.experts.plan import PlanPhase, TeamPlan
|
||||||
|
from agentkit.orchestrator.rollback import RollbackExecutor
|
||||||
|
|
||||||
|
|
||||||
|
# ─── PlanPhase field characterization ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPlanPhaseFields:
|
||||||
|
"""PlanPhase serialization for new fields."""
|
||||||
|
|
||||||
|
def test_characterization_no_new_keys_when_unset(self):
|
||||||
|
"""Default PlanPhase.to_dict() must not include new keys (KTD6 contract)."""
|
||||||
|
ph = PlanPhase(name="执行", assigned_expert="lead")
|
||||||
|
out = ph.to_dict()
|
||||||
|
assert "validation_command" not in out
|
||||||
|
assert "rollback_command" not in out
|
||||||
|
# Pre-change keys remain
|
||||||
|
assert out["name"] == "执行"
|
||||||
|
assert out["assigned_expert"] == "lead"
|
||||||
|
assert out["status"] == "pending"
|
||||||
|
|
||||||
|
def test_characterization_from_dict_empty_yields_none(self):
|
||||||
|
"""from_dict({}) produces both new fields as None."""
|
||||||
|
ph = PlanPhase.from_dict({"name": "x"})
|
||||||
|
assert ph.validation_command is None
|
||||||
|
assert ph.rollback_command is None
|
||||||
|
|
||||||
|
def test_serialization_includes_keys_when_set(self):
|
||||||
|
ph = PlanPhase(
|
||||||
|
name="frontend",
|
||||||
|
validation_command="ruff check src/",
|
||||||
|
rollback_command="git checkout src/app.vue",
|
||||||
|
)
|
||||||
|
out = ph.to_dict()
|
||||||
|
assert out["validation_command"] == "ruff check src/"
|
||||||
|
assert out["rollback_command"] == "git checkout src/app.vue"
|
||||||
|
|
||||||
|
def test_serialization_round_trip(self):
|
||||||
|
ph = PlanPhase(
|
||||||
|
name="backend",
|
||||||
|
validation_command="pytest -x -q",
|
||||||
|
rollback_command="git checkout src/api.py",
|
||||||
|
)
|
||||||
|
restored = PlanPhase.from_dict(ph.to_dict())
|
||||||
|
assert restored.validation_command == "pytest -x -q"
|
||||||
|
assert restored.rollback_command == "git checkout src/api.py"
|
||||||
|
|
||||||
|
def test_only_validation_set_still_emits_validation_key(self):
|
||||||
|
"""Asymmetric case: only validation_command set — only that key appears."""
|
||||||
|
ph = PlanPhase(name="x", validation_command="echo ok")
|
||||||
|
out = ph.to_dict()
|
||||||
|
assert "validation_command" in out
|
||||||
|
assert "rollback_command" not in out
|
||||||
|
|
||||||
|
|
||||||
|
# ─── RollbackExecutor subprocess execution ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_workspace(tmp_path):
|
||||||
|
"""Fresh working directory for subprocess execution."""
|
||||||
|
return str(tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRollbackExecutor:
|
||||||
|
"""RollbackExecutor happy/edge/failure paths."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_happy_path_zero_exit(self, tmp_workspace):
|
||||||
|
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||||
|
r = await ex.execute("true")
|
||||||
|
assert r.passed is True
|
||||||
|
assert r.exit_code == 0
|
||||||
|
assert r.command == "true"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_failure_nonzero_exit(self, tmp_workspace):
|
||||||
|
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||||
|
r = await ex.execute("false")
|
||||||
|
assert r.passed is False
|
||||||
|
assert r.exit_code != 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_timeout(self, tmp_workspace):
|
||||||
|
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=0.1)
|
||||||
|
r = await ex.execute("sleep 5")
|
||||||
|
assert r.passed is False
|
||||||
|
assert r.exit_code == -1
|
||||||
|
assert "timed out" in r.stderr
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_captures_stdout(self, tmp_workspace):
|
||||||
|
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||||
|
r = await ex.execute("echo hello-rollback")
|
||||||
|
assert r.passed is True
|
||||||
|
assert "hello-rollback" in r.stdout
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_same_semantics_as_execute(self, tmp_workspace):
|
||||||
|
"""validate() is just execute() with different intent marker."""
|
||||||
|
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||||
|
r = await ex.validate("true")
|
||||||
|
assert r.passed is True
|
||||||
|
r2 = await ex.validate("false")
|
||||||
|
assert r2.passed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_spawn_failure_returns_failed_result(self, tmp_workspace):
|
||||||
|
"""Bad shell command should surface as failed result, not raise."""
|
||||||
|
# Use a definitely-broken command — shell still spawns, returns non-zero.
|
||||||
|
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||||
|
r = await ex.execute("exit 7")
|
||||||
|
assert r.passed is False
|
||||||
|
assert r.exit_code == 7
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cwd_used_for_relative_paths(self, tmp_workspace):
|
||||||
|
"""Files created in working_dir are visible via cwd-relative commands."""
|
||||||
|
ex = RollbackExecutor(working_dir=tmp_workspace, timeout=5.0)
|
||||||
|
await ex.execute("echo content > testfile.txt")
|
||||||
|
r = await ex.execute("test -f testfile.txt")
|
||||||
|
assert r.passed is True
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Real git rollback integration ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitRollbackIntegration:
|
||||||
|
"""Real git repo fixture: writes file, rollback restores via git checkout."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_git_checkout_restores_modified_file(self, tmp_path):
|
||||||
|
# Init repo and commit a baseline file
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
repo = str(tmp_path)
|
||||||
|
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
|
||||||
|
subprocess.run(["git", "config", "user.email", "test@x"], cwd=repo, check=True)
|
||||||
|
subprocess.run(["git", "config", "user.name", "test"], cwd=repo, check=True)
|
||||||
|
baseline = "original content\n"
|
||||||
|
with open(os.path.join(repo, "foo.txt"), "w") as f:
|
||||||
|
f.write(baseline)
|
||||||
|
subprocess.run(["git", "add", "foo.txt"], cwd=repo, check=True)
|
||||||
|
subprocess.run(["git", "commit", "-q", "-m", "baseline"], cwd=repo, check=True)
|
||||||
|
|
||||||
|
# Mutate the file
|
||||||
|
with open(os.path.join(repo, "foo.txt"), "w") as f:
|
||||||
|
f.write("mutated content\n")
|
||||||
|
|
||||||
|
# Run git checkout as rollback_command
|
||||||
|
ex = RollbackExecutor(working_dir=repo, timeout=10.0)
|
||||||
|
r = await ex.execute("git checkout foo.txt")
|
||||||
|
assert r.passed is True
|
||||||
|
|
||||||
|
with open(os.path.join(repo, "foo.txt")) as f:
|
||||||
|
assert f.read() == baseline
|
||||||
|
|
||||||
|
|
||||||
|
# ─── TeamOrchestrator integration ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_team_mock():
|
||||||
|
"""Build a minimal ExpertTeam mock for TeamOrchestrator."""
|
||||||
|
team = MagicMock()
|
||||||
|
team.team_id = "test-team"
|
||||||
|
team.status.value = "executing"
|
||||||
|
team.lead_expert = None
|
||||||
|
team.active_experts = []
|
||||||
|
return team
|
||||||
|
|
||||||
|
|
||||||
|
def _make_orchestrator(team=None, checkpoint=None, workspace_root=None):
|
||||||
|
"""Build a TeamOrchestrator with mocked team and checkpoint."""
|
||||||
|
team = team or _make_team_mock()
|
||||||
|
orch = TeamOrchestrator(
|
||||||
|
team=team,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
workspace_root=workspace_root,
|
||||||
|
rollback_timeout=2.0,
|
||||||
|
)
|
||||||
|
orch._broadcast_event = AsyncMock()
|
||||||
|
return orch
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestratorRollbackIntegration:
|
||||||
|
"""TeamOrchestrator phase failure path integration with rollback."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_characterization_no_rollback_when_unset(self, tmp_path):
|
||||||
|
"""Phase fails, rollback_command=None → checkpoint saved, no rollback events."""
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
checkpoint.save = AsyncMock()
|
||||||
|
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||||
|
|
||||||
|
ph = PlanPhase(name="p1", assigned_expert="lead") # no rollback fields
|
||||||
|
plan = TeamPlan(task="t", lead_expert="lead")
|
||||||
|
plan.phases = [ph]
|
||||||
|
|
||||||
|
# Call _run_phase_rollback — but it should NOT be called from
|
||||||
|
# orchestrator path when fields unset. Verify by simulating the
|
||||||
|
# main-loop guard directly.
|
||||||
|
from agentkit.experts.plan import PhaseStatus as PS
|
||||||
|
|
||||||
|
ph.status = PS.FAILED
|
||||||
|
# Simulate the guard condition in orchestrator.py:280-288
|
||||||
|
should_save = True
|
||||||
|
if (
|
||||||
|
ph.validation_command
|
||||||
|
and ph.rollback_command
|
||||||
|
and isinstance(Exception("x"), (Exception,))
|
||||||
|
):
|
||||||
|
should_save = await orch._run_phase_rollback(plan, ph)
|
||||||
|
# Guard should not fire
|
||||||
|
assert should_save is True
|
||||||
|
# No rollback events broadcast
|
||||||
|
assert orch._broadcast_event.call_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validation_passes_no_rollback_executed(self, tmp_path):
|
||||||
|
"""validation_command returns 0 → rollback NOT executed, checkpoint saved."""
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
checkpoint.save = AsyncMock()
|
||||||
|
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||||
|
|
||||||
|
ph = PlanPhase(
|
||||||
|
name="p1",
|
||||||
|
assigned_expert="lead",
|
||||||
|
validation_command="true", # always passes
|
||||||
|
rollback_command="false", # would fail if executed
|
||||||
|
)
|
||||||
|
plan = TeamPlan(task="t", lead_expert="lead")
|
||||||
|
plan.phases = [ph]
|
||||||
|
|
||||||
|
should_save = await orch._run_phase_rollback(plan, ph)
|
||||||
|
assert should_save is True
|
||||||
|
|
||||||
|
# Events: started + completed (no rollback)
|
||||||
|
events = [
|
||||||
|
c.kwargs.get("event_type") or c.args[0] for c in orch._broadcast_event.call_args_list
|
||||||
|
]
|
||||||
|
assert "phase_rollback_started" in events
|
||||||
|
assert "phase_rollback_completed" in events
|
||||||
|
assert "phase_rollback_failed" not in events
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validation_fails_rollback_succeeds(self, tmp_path):
|
||||||
|
"""Validation fails, rollback returns 0 → checkpoint saved, rollback_executed=True."""
|
||||||
|
# Create file under git tracking, then mutate it
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
repo = str(tmp_path)
|
||||||
|
subprocess.run(["git", "init", "-q"], cwd=repo, check=True)
|
||||||
|
subprocess.run(["git", "config", "user.email", "t@x"], cwd=repo, check=True)
|
||||||
|
subprocess.run(["git", "config", "user.name", "t"], cwd=repo, check=True)
|
||||||
|
with open(os.path.join(repo, "f.txt"), "w") as f:
|
||||||
|
f.write("base\n")
|
||||||
|
subprocess.run(["git", "add", "f.txt"], cwd=repo, check=True)
|
||||||
|
subprocess.run(["git", "commit", "-q", "-m", "base"], cwd=repo, check=True)
|
||||||
|
with open(os.path.join(repo, "f.txt"), "w") as f:
|
||||||
|
f.write("mutated\n")
|
||||||
|
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
checkpoint.save = AsyncMock()
|
||||||
|
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=repo)
|
||||||
|
|
||||||
|
ph = PlanPhase(
|
||||||
|
name="p1",
|
||||||
|
assigned_expert="lead",
|
||||||
|
validation_command="false", # validation fails
|
||||||
|
rollback_command="git checkout f.txt", # rollback succeeds
|
||||||
|
)
|
||||||
|
plan = TeamPlan(task="t", lead_expert="lead")
|
||||||
|
plan.phases = [ph]
|
||||||
|
|
||||||
|
should_save = await orch._run_phase_rollback(plan, ph)
|
||||||
|
assert should_save is True
|
||||||
|
|
||||||
|
# File restored to baseline
|
||||||
|
with open(os.path.join(repo, "f.txt")) as f:
|
||||||
|
assert f.read() == "base\n"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validation_fails_rollback_fails_skips_checkpoint(self, tmp_path):
|
||||||
|
"""Validation fails AND rollback fails → checkpoint NOT saved (R21), event emitted."""
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
checkpoint.save = AsyncMock()
|
||||||
|
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||||
|
|
||||||
|
ph = PlanPhase(
|
||||||
|
name="p1",
|
||||||
|
assigned_expert="lead",
|
||||||
|
validation_command="false", # fails
|
||||||
|
rollback_command="false", # also fails
|
||||||
|
)
|
||||||
|
plan = TeamPlan(task="t", lead_expert="lead")
|
||||||
|
plan.phases = [ph]
|
||||||
|
|
||||||
|
should_save = await orch._run_phase_rollback(plan, ph)
|
||||||
|
assert should_save is False # R21: skip checkpoint
|
||||||
|
|
||||||
|
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
|
||||||
|
assert "phase_rollback_started" in events
|
||||||
|
assert "phase_rollback_failed" in events
|
||||||
|
# phase_rollback_completed should NOT be in events
|
||||||
|
assert "phase_rollback_completed" not in events
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_timeout_skips_checkpoint(self, tmp_path):
|
||||||
|
"""Rollback command times out → checkpoint NOT saved, failed event emitted."""
|
||||||
|
checkpoint = MagicMock()
|
||||||
|
checkpoint.save = AsyncMock()
|
||||||
|
orch = _make_orchestrator(checkpoint=checkpoint, workspace_root=str(tmp_path))
|
||||||
|
orch._rollback_timeout = 0.1 # short timeout
|
||||||
|
|
||||||
|
ph = PlanPhase(
|
||||||
|
name="p1",
|
||||||
|
assigned_expert="lead",
|
||||||
|
validation_command="false",
|
||||||
|
rollback_command="sleep 5",
|
||||||
|
)
|
||||||
|
plan = TeamPlan(task="t", lead_expert="lead")
|
||||||
|
plan.phases = [ph]
|
||||||
|
|
||||||
|
should_save = await orch._run_phase_rollback(plan, ph)
|
||||||
|
assert should_save is False
|
||||||
|
|
||||||
|
events = [c.args[0] for c in orch._broadcast_event.call_args_list]
|
||||||
|
assert "phase_rollback_failed" in events
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Config wiring ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestServerConfigRollback:
|
||||||
|
"""ServerConfig rollback section wiring."""
|
||||||
|
|
||||||
|
def test_rollback_section_read_from_dict(self):
|
||||||
|
from agentkit.server.config import ServerConfig
|
||||||
|
|
||||||
|
config = ServerConfig.from_dict(
|
||||||
|
{
|
||||||
|
"rollback": {
|
||||||
|
"default_timeout": 45.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert config.rollback == {"default_timeout": 45.0}
|
||||||
|
|
||||||
|
def test_rollback_defaults_empty_when_absent(self):
|
||||||
|
from agentkit.server.config import ServerConfig
|
||||||
|
|
||||||
|
config = ServerConfig.from_dict({})
|
||||||
|
assert config.rollback == {}
|
||||||
Loading…
Reference in New Issue