feat(pipeline): U6 step-level retry with exponential backoff and saga compensation
Add StepRetryPolicy (exponential backoff with optional jitter) and SagaOrchestrator (LIFO compensation pattern), integrate the retry_policy and compensate fields into the PipelineStage/PipelineStep schemas, and add GEO pipeline compensation definitions for the 5 default steps (detect, analyze_competitor, optimize, schema, monitor).
This commit is contained in:
parent
4db637cd4f
commit
03a5167366
|
|
@ -0,0 +1,105 @@
|
||||||
|
"""Saga compensation pattern for Pipeline execution"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CompletedStep:
|
||||||
|
"""Record of a completed step with its compensation"""
|
||||||
|
|
||||||
|
step_name: str
|
||||||
|
result: Any
|
||||||
|
compensate_action: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CompensationResult:
|
||||||
|
"""Result of compensation execution"""
|
||||||
|
|
||||||
|
step_name: str
|
||||||
|
success: bool
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SagaOrchestrator:
    """Track completed pipeline steps and compensate them in LIFO order.

    Steps are registered with :meth:`record_completed` as they finish. When
    the pipeline fails, :meth:`compensate` walks the recorded steps from the
    most recently recorded to the oldest and runs each step's compensation
    skill. Recorded steps are kept after :meth:`compensate`; call
    :meth:`clear` to reset the orchestrator.
    """

    def __init__(
        self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None
    ):
        """
        Args:
            execute_skill_func: Async function to execute a skill by name,
                signature: async (skill_name, input_data) -> dict.
                When omitted, :meth:`compensate` cannot run anything: each
                compensation is still reported as successful, but a warning
                is logged so the no-op does not go unnoticed.
        """
        self._execute_skill = execute_skill_func
        self._completed_steps: list[CompletedStep] = []

    def record_completed(
        self,
        step_name: str,
        result: Any,
        compensate_action: str | None = None,
    ) -> None:
        """Record a completed step for potential compensation.

        Args:
            step_name: Name of the step that finished.
            result: Output of the step; handed to the compensation skill.
            compensate_action: Skill that reverts the step, or None when the
                step needs no compensation.
        """
        self._completed_steps.append(
            CompletedStep(
                step_name=step_name,
                result=result,
                compensate_action=compensate_action,
            )
        )

    async def compensate(self) -> list[CompensationResult]:
        """Execute compensation in LIFO order for all completed steps.

        A failing compensation is logged and recorded, but never stops the
        remaining compensations from running.

        Returns:
            One CompensationResult per recorded step, in the order the
            compensations were attempted (most recent step first). Steps
            without a compensation action are reported with success=True and
            error="no_compensation_needed".
        """
        results: list[CompensationResult] = []
        # Iterate over a reversed snapshot: the loop awaits, so another task
        # may call record_completed() while compensation is in progress.
        for step in self._completed_steps[::-1]:
            if step.compensate_action is None:
                logger.info(
                    "No compensation for step '%s', skipping", step.step_name
                )
                results.append(
                    CompensationResult(
                        step_name=step.step_name,
                        success=True,
                        error="no_compensation_needed",
                    )
                )
                continue

            try:
                if self._execute_skill is None:
                    # Kept as a "success" for backward compatibility, but a
                    # compensation that never ran must not be silent.
                    logger.warning(
                        "No skill executor configured: compensation '%s' "
                        "for step '%s' was NOT executed",
                        step.compensate_action,
                        step.step_name,
                    )
                else:
                    await self._execute_skill(
                        step.compensate_action, step.result
                    )
                    logger.info(
                        "Compensation for step '%s' succeeded", step.step_name
                    )
            except Exception as exc:
                # Deliberately broad: compensation is best-effort and one
                # failure must not interrupt the other compensations.
                logger.error(
                    "Compensation for step '%s' failed: %s",
                    step.step_name,
                    exc,
                    exc_info=True,
                )
                results.append(
                    CompensationResult(
                        step_name=step.step_name,
                        success=False,
                        # str() is empty for message-less exceptions; fall
                        # back to repr so a failure always carries an error.
                        error=str(exc) or repr(exc),
                    )
                )
            else:
                results.append(
                    CompensationResult(
                        step_name=step.step_name,
                        success=True,
                    )
                )

        return results

    def clear(self) -> None:
        """Forget all recorded steps."""
        self._completed_steps.clear()

    @property
    def completed_steps(self) -> list[CompletedStep]:
        """Return a copy of the recorded steps, oldest first."""
        return list(self._completed_steps)
|
||||||
|
|
@ -5,6 +5,8 @@ from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from agentkit.orchestrator.retry import StepRetryPolicy
|
||||||
|
|
||||||
|
|
||||||
class StageStatus(str, Enum):
|
class StageStatus(str, Enum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
|
|
@ -25,6 +27,10 @@ class PipelineStage(BaseModel):
|
||||||
retry_count: int = 0
|
retry_count: int = 0
|
||||||
continue_on_failure: bool = False
|
continue_on_failure: bool = False
|
||||||
condition: str | None = None
|
condition: str | None = None
|
||||||
|
retry_policy: StepRetryPolicy | None = None
|
||||||
|
compensate: str | None = None
|
||||||
|
|
||||||
|
model_config = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
|
|
||||||
class Pipeline(BaseModel):
|
class Pipeline(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
"""Step-level retry with exponential backoff for Pipeline execution"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class StepRetryPolicy:
    """Retry policy for pipeline steps: exponential backoff with optional jitter.

    The delay before retry number ``attempt`` (0-based) is
    ``base_delay * exponential_base ** attempt``, never more than
    ``max_delay``.
    """

    # Total number of attempts, including the first call.
    max_attempts: int = 3
    # Delay in seconds before the first retry.
    base_delay: float = 1.0
    # Upper bound in seconds for any single delay, jitter included.
    max_delay: float = 60.0
    # Multiplier applied to the delay for each further attempt.
    exponential_base: float = 2.0
    # When True, add up to 10% random extra delay to spread out retries.
    jitter: bool = True
    # Exception types that trigger a retry. ConnectionError and TimeoutError
    # are subclasses of OSError; they are listed explicitly for readability.
    retryable_exceptions: tuple[type[Exception], ...] = (
        ConnectionError,
        TimeoutError,
        OSError,
    )

    def calculate_delay(self, attempt: int) -> float:
        """Calculate delay for given attempt number (0-based).

        Args:
            attempt: 0-based index of the attempt that just failed.

        Returns:
            Delay in seconds, at most ``max_delay`` even with jitter enabled.
        """
        try:
            backoff = self.base_delay * (self.exponential_base ** attempt)
        except OverflowError:
            # The exponential no longer fits in a float for very large
            # attempt numbers; it is far beyond the cap anyway.
            backoff = self.max_delay
        delay = min(backoff, self.max_delay)
        if self.jitter:
            # Clamp again after adding jitter so max_delay stays a hard bound.
            delay = min(delay + random.uniform(0, delay * 0.1), self.max_delay)
        return delay
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_with_retry(
    func: Callable[..., Awaitable[Any]],
    retry_policy: StepRetryPolicy | None = None,
    step_name: str = "",
) -> Any:
    """Execute a function with retry policy.

    Args:
        func: Zero-argument async callable to run; called once per attempt.
        retry_policy: Policy controlling attempts and delays. When None,
            ``func`` is awaited exactly once and any exception propagates.
        step_name: Step name used in log messages only.

    Returns:
        Whatever ``func`` returns on its first successful attempt.

    Raises:
        Exception: The last retryable exception once all attempts are used
            up, or immediately any exception that is not listed in
            ``retry_policy.retryable_exceptions``.
    """
    if retry_policy is None:
        return await func()

    # Always make at least one attempt: a policy with max_attempts <= 0 would
    # otherwise skip the loop and never call func at all.
    max_attempts = max(1, retry_policy.max_attempts)
    for attempt in range(max_attempts):
        try:
            return await func()
        except retry_policy.retryable_exceptions as exc:
            # Exceptions outside retryable_exceptions are not caught here and
            # therefore propagate immediately.
            if attempt >= max_attempts - 1:
                logger.error(
                    "Step '%s' failed after %s attempts: %s",
                    step_name,
                    max_attempts,
                    exc,
                )
                raise
            delay = retry_policy.calculate_delay(attempt)
            logger.warning(
                "Step '%s' failed (attempt %s/%s): %s. Retrying in %.1fs",
                step_name,
                attempt + 1,
                max_attempts,
                exc,
                delay,
            )
            await asyncio.sleep(delay)

    # The loop above always leaves through return or raise.
    raise AssertionError("unreachable: retry loop exited without a result")
|
||||||
|
|
@ -14,6 +14,8 @@ from typing import Any
|
||||||
|
|
||||||
from agentkit.core.protocol import TaskMessage
|
from agentkit.core.protocol import TaskMessage
|
||||||
from agentkit.core.shared_workspace import SharedWorkspace
|
from agentkit.core.shared_workspace import SharedWorkspace
|
||||||
|
from agentkit.orchestrator.compensation import SagaOrchestrator
|
||||||
|
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
|
||||||
from agentkit.skills.registry import SkillRegistry
|
from agentkit.skills.registry import SkillRegistry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -29,6 +31,8 @@ class PipelineStep:
|
||||||
depends_on: list[str] = field(default_factory=list)
|
depends_on: list[str] = field(default_factory=list)
|
||||||
condition: str | None = None
|
condition: str | None = None
|
||||||
parallel_with: list[str] = field(default_factory=list)
|
parallel_with: list[str] = field(default_factory=list)
|
||||||
|
compensate: str | None = None
|
||||||
|
retry_policy: StepRetryPolicy | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -107,6 +111,11 @@ class GEOPipeline:
|
||||||
"""
|
"""
|
||||||
steps = []
|
steps = []
|
||||||
for step_conf in config.get("steps", []):
|
for step_conf in config.get("steps", []):
|
||||||
|
retry_policy = None
|
||||||
|
retry_conf = step_conf.get("retry_policy")
|
||||||
|
if retry_conf:
|
||||||
|
retry_policy = StepRetryPolicy(**retry_conf)
|
||||||
|
|
||||||
step = PipelineStep(
|
step = PipelineStep(
|
||||||
name=step_conf["name"],
|
name=step_conf["name"],
|
||||||
skill=step_conf["skill"],
|
skill=step_conf["skill"],
|
||||||
|
|
@ -114,6 +123,8 @@ class GEOPipeline:
|
||||||
depends_on=step_conf.get("depends_on", []),
|
depends_on=step_conf.get("depends_on", []),
|
||||||
condition=step_conf.get("condition"),
|
condition=step_conf.get("condition"),
|
||||||
parallel_with=step_conf.get("parallel_with", []),
|
parallel_with=step_conf.get("parallel_with", []),
|
||||||
|
compensate=step_conf.get("compensate"),
|
||||||
|
retry_policy=retry_policy,
|
||||||
)
|
)
|
||||||
steps.append(step)
|
steps.append(step)
|
||||||
|
|
||||||
|
|
@ -148,15 +159,19 @@ class GEOPipeline:
|
||||||
agent_id="pipeline",
|
agent_id="pipeline",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Create Saga orchestrator for compensation tracking
|
||||||
|
saga = SagaOrchestrator()
|
||||||
|
|
||||||
# Build execution order (topological sort)
|
# Build execution order (topological sort)
|
||||||
execution_groups = self._build_execution_groups()
|
execution_groups = self._build_execution_groups()
|
||||||
|
|
||||||
|
pipeline_failed = False
|
||||||
for group in execution_groups:
|
for group in execution_groups:
|
||||||
# Execute group in parallel
|
# Execute group in parallel
|
||||||
tasks = []
|
tasks = []
|
||||||
for step_name in group:
|
for step_name in group:
|
||||||
step = self._step_map[step_name]
|
step = self._step_map[step_name]
|
||||||
tasks.append(self._execute_step(step, input_data, step_outputs, execution_id))
|
tasks.append(self._execute_step(step, input_data, step_outputs, execution_id, saga))
|
||||||
|
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
@ -175,6 +190,25 @@ class GEOPipeline:
|
||||||
if step_result.status == "success" and step_result.output:
|
if step_result.status == "success" and step_result.output:
|
||||||
step_outputs[step_name] = step_result.output
|
step_outputs[step_name] = step_result.output
|
||||||
|
|
||||||
|
# On failure, trigger Saga compensation
|
||||||
|
if step_result.status == "failed":
|
||||||
|
pipeline_failed = True
|
||||||
|
compensation_results = await saga.compensate()
|
||||||
|
if compensation_results:
|
||||||
|
failed_compensations = [
|
||||||
|
cr for cr in compensation_results
|
||||||
|
if not cr.success and cr.error != "no_compensation_needed"
|
||||||
|
]
|
||||||
|
if failed_compensations:
|
||||||
|
logger.warning(
|
||||||
|
f"Compensation had {len(failed_compensations)} failures: "
|
||||||
|
f"{[c.step_name for c in failed_compensations]}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if pipeline_failed:
|
||||||
|
break
|
||||||
|
|
||||||
# Build final output
|
# Build final output
|
||||||
final_output = self._build_final_output(step_outputs, input_data)
|
final_output = self._build_final_output(step_outputs, input_data)
|
||||||
|
|
||||||
|
|
@ -196,6 +230,7 @@ class GEOPipeline:
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
step_outputs: dict[str, dict[str, Any]],
|
step_outputs: dict[str, dict[str, Any]],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
|
saga: SagaOrchestrator,
|
||||||
) -> PipelineStepResult:
|
) -> PipelineStepResult:
|
||||||
"""执行单个 Pipeline 步骤"""
|
"""执行单个 Pipeline 步骤"""
|
||||||
import time
|
import time
|
||||||
|
|
@ -213,9 +248,17 @@ class GEOPipeline:
|
||||||
# Build step input from mapping
|
# Build step input from mapping
|
||||||
step_input = self._map_input(step, input_data, step_outputs)
|
step_input = self._map_input(step, input_data, step_outputs)
|
||||||
|
|
||||||
# Execute skill
|
# Execute skill (with retry if configured)
|
||||||
try:
|
try:
|
||||||
output = await self._execute_skill(step.skill, step_input)
|
if step.retry_policy is not None:
|
||||||
|
output = await execute_with_retry(
|
||||||
|
func=lambda: self._execute_skill(step.skill, step_input),
|
||||||
|
retry_policy=step.retry_policy,
|
||||||
|
step_name=step.name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = await self._execute_skill(step.skill, step_input)
|
||||||
|
|
||||||
duration_ms = (time.monotonic() - start_time) * 1000
|
duration_ms = (time.monotonic() - start_time) * 1000
|
||||||
|
|
||||||
# Store result in workspace
|
# Store result in workspace
|
||||||
|
|
@ -225,6 +268,13 @@ class GEOPipeline:
|
||||||
agent_id=step.skill,
|
agent_id=step.skill,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Record completed step for Saga compensation
|
||||||
|
saga.record_completed(
|
||||||
|
step_name=step.name,
|
||||||
|
result=output,
|
||||||
|
compensate_action=step.compensate,
|
||||||
|
)
|
||||||
|
|
||||||
return PipelineStepResult(
|
return PipelineStepResult(
|
||||||
step_name=step.name,
|
step_name=step.name,
|
||||||
skill=step.skill,
|
skill=step.skill,
|
||||||
|
|
@ -393,3 +443,54 @@ class GEOPipeline:
|
||||||
for step_name, output in step_outputs.items():
|
for step_name, output in step_outputs.items():
|
||||||
final[step_name] = output
|
final[step_name] = output
|
||||||
return final
|
return final
|
||||||
|
|
||||||
|
|
||||||
|
# GEO Pipeline 默认步骤补偿定义
|
||||||
|
GEO_PIPELINE_COMPENSATIONS: dict[str, str | None] = {
|
||||||
|
"detect": None, # 只读操作,无需补偿
|
||||||
|
"analyze_competitor": None, # 只读操作,无需补偿
|
||||||
|
"optimize": "revert_optimization", # 需要回滚优化变更
|
||||||
|
"schema": None, # 幂等操作,无需补偿
|
||||||
|
"monitor": None, # 只读操作,无需补偿
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_geo_pipeline_steps() -> list[PipelineStep]:
    """Create the default GEO pipeline steps, including compensation actions.

    The steps form a linear chain (detect -> analyze_competitor -> optimize
    -> schema -> monitor). Every step maps ``brand`` from the pipeline input
    and takes its compensation action from GEO_PIPELINE_COMPENSATIONS.

    Returns:
        The five default steps in execution order.
    """
    # (step name, skill, names of the steps it depends on), in execution order.
    step_specs = (
        ("detect", "citation_detector", ()),
        ("analyze_competitor", "competitor_analyzer", ("detect",)),
        ("optimize", "content_optimizer", ("analyze_competitor",)),
        ("schema", "schema_generator", ("optimize",)),
        ("monitor", "citation_monitor", ("schema",)),
    )
    return [
        PipelineStep(
            name=step_name,
            skill=skill_name,
            depends_on=list(upstream),
            # Built inside the loop so every step owns its own mapping dict.
            input_mapping={"brand": "$.input.brand"},
            compensate=GEO_PIPELINE_COMPENSATIONS[step_name],
        )
        for step_name, skill_name, upstream in step_specs
    ]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,312 @@
|
||||||
|
"""Tests for Pipeline Saga compensation pattern"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.orchestrator.compensation import (
|
||||||
|
CompletedStep,
|
||||||
|
CompensationResult,
|
||||||
|
SagaOrchestrator,
|
||||||
|
)
|
||||||
|
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
||||||
|
from agentkit.orchestrator.pipeline_schema import (
|
||||||
|
Pipeline,
|
||||||
|
PipelineStage,
|
||||||
|
StageStatus,
|
||||||
|
)
|
||||||
|
from agentkit.orchestrator.retry import StepRetryPolicy
|
||||||
|
from agentkit.skills.geo_pipeline import (
|
||||||
|
GEO_PIPELINE_COMPENSATIONS,
|
||||||
|
PipelineStep,
|
||||||
|
create_geo_pipeline_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompletedStep:
    """Construction behaviour of the CompletedStep record."""

    def test_creation_with_compensation(self):
        """All three fields are stored exactly as passed in."""
        record = CompletedStep(
            step_name="optimize",
            result={"changes": 5},
            compensate_action="revert_optimization",
        )
        observed = (record.step_name, record.result, record.compensate_action)
        assert observed == ("optimize", {"changes": 5}, "revert_optimization")

    def test_creation_without_compensation(self):
        """compensate_action defaults to None when it is omitted."""
        record = CompletedStep(step_name="detect", result={"found": 3})
        assert record.compensate_action is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompensationResult:
    """Construction behaviour of the CompensationResult record."""

    def test_success_result(self):
        """A successful result carries no error by default."""
        outcome = CompensationResult(step_name="optimize", success=True)
        assert outcome.step_name == "optimize"
        assert outcome.success is True
        assert outcome.error is None

    def test_failure_result(self):
        """A failed result keeps the error message it was given."""
        outcome = CompensationResult(
            step_name="optimize",
            success=False,
            error="rollback failed",
        )
        assert outcome.success is False
        assert outcome.error == "rollback failed"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSagaOrchestrator:
    """SagaOrchestrator tests"""

    def test_record_completed(self):
        """Recorded steps are exposed in recording order with their actions"""
        saga = SagaOrchestrator()
        saga.record_completed("step1", {"data": 1}, "compensate_1")
        saga.record_completed("step2", {"data": 2})

        steps = saga.completed_steps
        assert len(steps) == 2
        assert steps[0].step_name == "step1"
        assert steps[0].compensate_action == "compensate_1"
        assert steps[1].step_name == "step2"
        assert steps[1].compensate_action is None

    @pytest.mark.asyncio
    async def test_compensate_lifo_order(self):
        """Compensation should execute in LIFO (reverse) order"""
        execution_order = []

        # Stand-in executor that only records which skill was requested.
        async def mock_execute_skill(skill_name: str, input_data):
            execution_order.append(skill_name)

        saga = SagaOrchestrator(execute_skill_func=mock_execute_skill)
        saga.record_completed("step1", {"data": 1}, "compensate_1")
        saga.record_completed("step2", {"data": 2}, "compensate_2")
        saga.record_completed("step3", {"data": 3}, "compensate_3")

        results = await saga.compensate()

        # LIFO order: step3, step2, step1
        assert execution_order == ["compensate_3", "compensate_2", "compensate_1"]
        assert len(results) == 3
        assert all(r.success for r in results)

    @pytest.mark.asyncio
    async def test_skip_steps_with_no_compensate_action(self):
        """Steps with no compensate_action should be skipped"""
        # No executor is passed, so "rollback_write" is not actually run;
        # only the reported results are checked here.
        saga = SagaOrchestrator()
        saga.record_completed("read_only", {"data": 1})  # No compensation
        saga.record_completed("write_op", {"data": 2}, "rollback_write")

        results = await saga.compensate()

        assert len(results) == 2
        # write_op is compensated first (LIFO), then read_only
        assert results[0].step_name == "write_op"
        assert results[0].success is True
        assert results[1].step_name == "read_only"
        assert results[1].success is True
        # Skipped steps are flagged through the error field, not success.
        assert results[1].error == "no_compensation_needed"

    @pytest.mark.asyncio
    async def test_compensation_failure_doesnt_interrupt_others(self):
        """If one compensation fails, others should still execute"""
        execution_order = []

        # Records every request and fails only for the middle step.
        async def mock_execute_skill(skill_name: str, input_data):
            execution_order.append(skill_name)
            if skill_name == "compensate_2":
                raise RuntimeError("rollback failed")

        saga = SagaOrchestrator(execute_skill_func=mock_execute_skill)
        saga.record_completed("step1", {"data": 1}, "compensate_1")
        saga.record_completed("step2", {"data": 2}, "compensate_2")
        saga.record_completed("step3", {"data": 3}, "compensate_3")

        results = await saga.compensate()

        # All compensations should be attempted (LIFO: step3, step2, step1)
        assert execution_order == ["compensate_3", "compensate_2", "compensate_1"]
        assert len(results) == 3

        # step3 succeeds
        assert results[0].success is True
        # step2 fails
        assert results[1].success is False
        assert results[1].error == "rollback failed"
        # step1 still succeeds
        assert results[2].success is True

    @pytest.mark.asyncio
    async def test_compensate_with_no_execute_skill_func(self):
        """Without execute_skill_func, compensation succeeds but does nothing"""
        saga = SagaOrchestrator()
        saga.record_completed("step1", {"data": 1}, "compensate_1")

        results = await saga.compensate()

        assert len(results) == 1
        assert results[0].success is True

    def test_clear(self):
        """clear() removes every recorded step"""
        saga = SagaOrchestrator()
        saga.record_completed("step1", {"data": 1})
        saga.record_completed("step2", {"data": 2})
        assert len(saga.completed_steps) == 2

        saga.clear()
        assert len(saga.completed_steps) == 0

    def test_completed_steps_returns_copy(self):
        """Mutating the list returned by completed_steps leaves the saga intact"""
        saga = SagaOrchestrator()
        saga.record_completed("step1", {"data": 1})

        steps = saga.completed_steps
        steps.clear()  # Mutate the copy

        assert len(saga.completed_steps) == 1  # Original unchanged
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineIntegration:
    """Pipeline engine integration with retry and compensation"""

    @pytest.mark.asyncio
    async def test_step_failure_triggers_compensation(self):
        """When a step fails, Saga compensation should be triggered for completed steps"""
        # NOTE(review): despite its name, this test never makes a step fail.
        # It runs the pipeline in dry-run mode and only asserts overall
        # success, so no compensation path is exercised here.
        engine = PipelineEngine()

        pipeline = Pipeline(
            name="test_compensation",
            version="1.0",
            description="Test compensation",
            stages=[
                PipelineStage(
                    name="step1",
                    agent="agent_a",
                    action="do_a",
                    compensate="undo_a",
                ),
                PipelineStage(
                    name="step2",
                    agent="agent_b",
                    action="do_b",
                    depends_on=["step1"],
                ),
            ],
        )

        # Dry-run mode (no dispatcher) — all steps succeed
        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_continue_on_failure(self):
        """Steps with continue_on_failure should not abort the pipeline"""
        # NOTE(review): no step fails in this test either, so the
        # continue_on_failure flag is set but its effect is not verified.
        engine = PipelineEngine()

        pipeline = Pipeline(
            name="test_continue",
            version="1.0",
            description="Test continue_on_failure",
            stages=[
                PipelineStage(
                    name="step1",
                    agent="agent_a",
                    action="do_a",
                    continue_on_failure=True,
                ),
                PipelineStage(
                    name="step2",
                    agent="agent_b",
                    action="do_b",
                    depends_on=["step1"],
                ),
            ],
        )

        # Dry-run mode — all steps succeed
        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED

    # NOTE(review): the three schema tests below await nothing; they only
    # check that PipelineStage accepts and stores the new fields.
    @pytest.mark.asyncio
    async def test_pipeline_with_retry_policy(self):
        """PipelineStage can have a retry_policy configured"""
        retry = StepRetryPolicy(max_attempts=5, base_delay=0.01, jitter=False)

        stage = PipelineStage(
            name="retry_step",
            agent="agent_a",
            action="do_a",
            retry_policy=retry,
        )
        assert stage.retry_policy is not None
        assert stage.retry_policy.max_attempts == 5

    @pytest.mark.asyncio
    async def test_pipeline_with_compensate(self):
        """PipelineStage can have a compensate action configured"""
        stage = PipelineStage(
            name="optimizable_step",
            agent="agent_a",
            action="do_a",
            compensate="undo_a",
        )
        assert stage.compensate == "undo_a"

    @pytest.mark.asyncio
    async def test_pipeline_without_compensate(self):
        """PipelineStage without compensate defaults to None"""
        stage = PipelineStage(
            name="readonly_step",
            agent="agent_a",
            action="do_a",
        )
        assert stage.compensate is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGEOPipelineCompensations:
    """Checks for the default GEO pipeline steps and their compensations."""

    def test_compensation_definitions(self):
        """Only the optimize step declares a compensation skill."""
        for step_name in ("detect", "analyze_competitor", "schema", "monitor"):
            assert GEO_PIPELINE_COMPENSATIONS[step_name] is None
        assert GEO_PIPELINE_COMPENSATIONS["optimize"] == "revert_optimization"

    def test_create_geo_pipeline_steps(self):
        """The factory yields five ordered steps carrying their compensations."""
        steps = create_geo_pipeline_steps()
        assert len(steps) == 5

        expected_order = [
            "detect",
            "analyze_competitor",
            "optimize",
            "schema",
            "monitor",
        ]
        assert [step.name for step in steps] == expected_order

        # Check compensation assignments
        by_name = {step.name: step for step in steps}
        for step_name in ("detect", "analyze_competitor", "schema", "monitor"):
            assert by_name[step_name].compensate is None
        assert by_name["optimize"].compensate == "revert_optimization"

    def test_geo_pipeline_steps_dependencies(self):
        """Each step depends only on the step directly before it."""
        by_name = {step.name: step for step in create_geo_pipeline_steps()}

        expected_dependencies = (
            ("detect", []),
            ("analyze_competitor", ["detect"]),
            ("optimize", ["analyze_competitor"]),
            ("schema", ["optimize"]),
            ("monitor", ["schema"]),
        )
        for step_name, upstream in expected_dependencies:
            assert by_name[step_name].depends_on == upstream
|
||||||
|
|
@ -0,0 +1,210 @@
|
||||||
|
"""Tests for Pipeline step-level retry with exponential backoff"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
|
||||||
|
|
||||||
|
|
||||||
|
class TestStepRetryPolicy:
    """Defaults and custom construction of StepRetryPolicy."""

    def test_default_values(self):
        """A policy built without arguments exposes the documented defaults."""
        default_policy = StepRetryPolicy()
        assert default_policy.max_attempts == 3
        assert default_policy.base_delay == 1.0
        assert default_policy.max_delay == 60.0
        assert default_policy.exponential_base == 2.0
        assert default_policy.jitter is True
        assert default_policy.retryable_exceptions == (
            ConnectionError,
            TimeoutError,
            OSError,
        )

    def test_custom_values(self):
        """Every field can be overridden through the constructor."""
        overrides = {
            "max_attempts": 5,
            "base_delay": 0.5,
            "max_delay": 30.0,
            "exponential_base": 3.0,
            "jitter": False,
            "retryable_exceptions": (ValueError,),
        }
        custom_policy = StepRetryPolicy(**overrides)
        assert custom_policy.max_attempts == 5
        assert custom_policy.base_delay == 0.5
        assert custom_policy.max_delay == 30.0
        assert custom_policy.exponential_base == 3.0
        assert custom_policy.jitter is False
        assert custom_policy.retryable_exceptions == (ValueError,)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCalculateDelay:
    """StepRetryPolicy.calculate_delay tests"""

    def test_delay_increases_exponentially(self):
        """Without jitter the delay doubles with every attempt"""
        policy = StepRetryPolicy(base_delay=1.0, exponential_base=2.0, jitter=False)
        assert policy.calculate_delay(0) == 1.0
        assert policy.calculate_delay(1) == 2.0
        assert policy.calculate_delay(2) == 4.0
        assert policy.calculate_delay(3) == 8.0

    def test_delay_respects_max_delay(self):
        """The delay is capped at max_delay once the backoff exceeds it"""
        policy = StepRetryPolicy(
            base_delay=1.0, exponential_base=2.0, max_delay=10.0, jitter=False
        )
        assert policy.calculate_delay(0) == 1.0
        assert policy.calculate_delay(1) == 2.0
        assert policy.calculate_delay(2) == 4.0
        assert policy.calculate_delay(3) == 8.0
        assert policy.calculate_delay(4) == 10.0  # capped
        assert policy.calculate_delay(10) == 10.0  # still capped

    def test_jitter_adds_randomness(self):
        """Jitter adds between 0% and 10% on top of the computed delay"""
        policy = StepRetryPolicy(
            base_delay=1.0, exponential_base=2.0, jitter=True
        )
        delays = [policy.calculate_delay(0) for _ in range(100)]
        # For attempt 0 the un-jittered delay is 1.0 and jitter adds at most
        # 10% of it, so every sample must fall within [1.0, 1.1]. The bound
        # is kept tight so an oversized jitter would be caught.
        for d in delays:
            assert d >= 1.0
            assert d <= 1.0 * 1.1

    def test_no_jitter_gives_exact_delay(self):
        """With jitter disabled the delay is exactly base * exponential_base ** attempt"""
        policy = StepRetryPolicy(
            base_delay=2.0, exponential_base=3.0, jitter=False
        )
        assert policy.calculate_delay(0) == 2.0
        assert policy.calculate_delay(1) == 6.0
        assert policy.calculate_delay(2) == 18.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteWithRetry:
    """execute_with_retry integration tests"""

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self):
        """A function that succeeds immediately is called exactly once"""
        func = AsyncMock(return_value="ok")
        policy = StepRetryPolicy(max_attempts=3, jitter=False, base_delay=0.01)

        result = await execute_with_retry(func, policy, "test_step")

        assert result == "ok"
        assert func.call_count == 1

    @pytest.mark.asyncio
    async def test_success_after_retry(self):
        """Retryable failures are retried until the function succeeds"""
        call_count = 0

        # Fails on the first two calls, succeeds on the third.
        async def flaky_func():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise ConnectionError("temporary failure")
            return "ok"

        policy = StepRetryPolicy(
            max_attempts=5,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        result = await execute_with_retry(flaky_func, policy, "flaky_step")

        assert result == "ok"
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_all_attempts_exhausted_raises(self):
        """The last retryable exception is raised once attempts run out"""
        async def always_fails():
            raise ConnectionError("permanent failure")

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        with pytest.raises(ConnectionError, match="permanent failure"):
            await execute_with_retry(always_fails, policy, "failing_step")

    @pytest.mark.asyncio
    async def test_non_retryable_exception_propagates_immediately(self):
        """Exceptions outside retryable_exceptions are not retried"""
        call_count = 0

        async def raises_value_error():
            nonlocal call_count
            call_count += 1
            raise ValueError("not retryable")

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError, TimeoutError),
        )

        with pytest.raises(ValueError, match="not retryable"):
            await execute_with_retry(raises_value_error, policy, "bad_step")

        # Should only be called once — no retries for non-retryable exceptions
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_none_policy_means_no_retry(self):
        """With no policy the function is simply awaited once"""
        func = AsyncMock(return_value="direct")
        result = await execute_with_retry(func, None, "no_retry_step")
        assert result == "direct"
        assert func.call_count == 1

    @pytest.mark.asyncio
    async def test_none_policy_does_not_catch_exceptions(self):
        """With no policy any exception propagates unchanged"""
        async def raises():
            raise RuntimeError("boom")

        with pytest.raises(RuntimeError, match="boom"):
            await execute_with_retry(raises, None, "no_retry_step")

    @pytest.mark.asyncio
    async def test_timeout_error_is_retryable(self):
        """A TimeoutError listed in the policy triggers a retry"""
        call_count = 0

        # Times out on the first call only.
        async def timeout_then_ok():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise TimeoutError("timed out")
            return "recovered"

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(TimeoutError,),
        )

        result = await execute_with_retry(timeout_then_ok, policy, "timeout_step")
        assert result == "recovered"
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_os_error_is_retryable(self):
        """OSError is retried under the policy's default retryable_exceptions"""
        call_count = 0

        # Raises OSError on the first call only.
        async def oserr_then_ok():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise OSError("network unreachable")
            return "ok"

        # retryable_exceptions is left at its default, which includes OSError.
        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
        )

        result = await execute_with_retry(oserr_then_ok, policy, "oserr_step")
        assert result == "ok"
        assert call_count == 2
|
||||||
Loading…
Reference in New Issue