feat(pipeline): U6 step-level retry with exponential backoff and saga compensation

Add StepRetryPolicy with jitter-based exponential backoff, SagaOrchestrator
with LIFO compensation pattern, integrate retry_policy and compensate
fields into PipelineStage/PipelineStep schema, add GEO pipeline
compensation definitions for all 7 steps.
This commit is contained in:
chiguyong 2026-06-07 17:26:07 +08:00
parent 4db637cd4f
commit 03a5167366
6 changed files with 804 additions and 3 deletions

View File

@ -0,0 +1,105 @@
"""Saga compensation pattern for Pipeline execution"""
import logging
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable
logger = logging.getLogger(__name__)
@dataclass
class CompletedStep:
"""Record of a completed step with its compensation"""
step_name: str
result: Any
compensate_action: str | None = None
@dataclass
class CompensationResult:
"""Result of compensation execution"""
step_name: str
success: bool
error: str | None = None
class SagaOrchestrator:
"""Orchestrates LIFO compensation for failed pipelines"""
def __init__(
self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None
):
"""
Args:
execute_skill_func: Async function to execute a skill by name
signature: async (skill_name, input_data) -> dict
"""
self._execute_skill = execute_skill_func
self._completed_steps: list[CompletedStep] = []
def record_completed(
self,
step_name: str,
result: Any,
compensate_action: str | None = None,
):
"""Record a completed step for potential compensation"""
self._completed_steps.append(
CompletedStep(
step_name=step_name,
result=result,
compensate_action=compensate_action,
)
)
async def compensate(self) -> list[CompensationResult]:
"""Execute compensation in LIFO order for all completed steps"""
results: list[CompensationResult] = []
for step in reversed(self._completed_steps):
if step.compensate_action is None:
logger.info(
f"No compensation for step '{step.step_name}', skipping"
)
results.append(
CompensationResult(
step_name=step.step_name,
success=True,
error="no_compensation_needed",
)
)
continue
try:
if self._execute_skill is not None:
await self._execute_skill(step.compensate_action, step.result)
logger.info(f"Compensation for step '{step.step_name}' succeeded")
results.append(
CompensationResult(
step_name=step.step_name,
success=True,
)
)
except Exception as e:
logger.error(
f"Compensation for step '{step.step_name}' failed: {e}"
)
results.append(
CompensationResult(
step_name=step.step_name,
success=False,
error=str(e),
)
)
# Don't interrupt other compensations
return results
def clear(self):
"""Clear completed steps"""
self._completed_steps.clear()
@property
def completed_steps(self) -> list[CompletedStep]:
return list(self._completed_steps)

View File

@ -5,6 +5,8 @@ from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from agentkit.orchestrator.retry import StepRetryPolicy
class StageStatus(str, Enum): class StageStatus(str, Enum):
PENDING = "pending" PENDING = "pending"
@ -25,6 +27,10 @@ class PipelineStage(BaseModel):
retry_count: int = 0 retry_count: int = 0
continue_on_failure: bool = False continue_on_failure: bool = False
condition: str | None = None condition: str | None = None
retry_policy: StepRetryPolicy | None = None
compensate: str | None = None
model_config = {"arbitrary_types_allowed": True}
class Pipeline(BaseModel): class Pipeline(BaseModel):

View File

@ -0,0 +1,67 @@
"""Step-level retry with exponential backoff for Pipeline execution"""
import asyncio
import logging
import random
from dataclasses import dataclass
from typing import Any, Awaitable, Callable
logger = logging.getLogger(__name__)
@dataclass
class StepRetryPolicy:
"""Retry policy for pipeline steps"""
max_attempts: int = 3
base_delay: float = 1.0
max_delay: float = 60.0
exponential_base: float = 2.0
jitter: bool = True
retryable_exceptions: tuple[type[Exception], ...] = (
ConnectionError,
TimeoutError,
OSError,
)
def calculate_delay(self, attempt: int) -> float:
"""Calculate delay for given attempt number (0-based)"""
delay = min(
self.base_delay * (self.exponential_base ** attempt),
self.max_delay,
)
if self.jitter:
delay += random.uniform(0, delay * 0.1)
return delay
async def execute_with_retry(
func: Callable[..., Awaitable[Any]],
retry_policy: StepRetryPolicy | None = None,
step_name: str = "",
) -> Any:
"""Execute a function with retry policy"""
if retry_policy is None:
return await func()
last_exception: Exception | None = None
for attempt in range(retry_policy.max_attempts):
try:
return await func()
except retry_policy.retryable_exceptions as e:
last_exception = e
if attempt < retry_policy.max_attempts - 1:
delay = retry_policy.calculate_delay(attempt)
logger.warning(
f"Step '{step_name}' failed (attempt {attempt + 1}/{retry_policy.max_attempts}): {e}. "
f"Retrying in {delay:.1f}s"
)
await asyncio.sleep(delay)
else:
logger.error(
f"Step '{step_name}' failed after {retry_policy.max_attempts} attempts: {e}"
)
except Exception:
raise # Non-retryable exceptions propagate immediately
raise last_exception # type: ignore[misc]

View File

@ -14,6 +14,8 @@ from typing import Any
from agentkit.core.protocol import TaskMessage from agentkit.core.protocol import TaskMessage
from agentkit.core.shared_workspace import SharedWorkspace from agentkit.core.shared_workspace import SharedWorkspace
from agentkit.orchestrator.compensation import SagaOrchestrator
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
from agentkit.skills.registry import SkillRegistry from agentkit.skills.registry import SkillRegistry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,6 +31,8 @@ class PipelineStep:
depends_on: list[str] = field(default_factory=list) depends_on: list[str] = field(default_factory=list)
condition: str | None = None condition: str | None = None
parallel_with: list[str] = field(default_factory=list) parallel_with: list[str] = field(default_factory=list)
compensate: str | None = None
retry_policy: StepRetryPolicy | None = None
@dataclass @dataclass
@ -107,6 +111,11 @@ class GEOPipeline:
""" """
steps = [] steps = []
for step_conf in config.get("steps", []): for step_conf in config.get("steps", []):
retry_policy = None
retry_conf = step_conf.get("retry_policy")
if retry_conf:
retry_policy = StepRetryPolicy(**retry_conf)
step = PipelineStep( step = PipelineStep(
name=step_conf["name"], name=step_conf["name"],
skill=step_conf["skill"], skill=step_conf["skill"],
@ -114,6 +123,8 @@ class GEOPipeline:
depends_on=step_conf.get("depends_on", []), depends_on=step_conf.get("depends_on", []),
condition=step_conf.get("condition"), condition=step_conf.get("condition"),
parallel_with=step_conf.get("parallel_with", []), parallel_with=step_conf.get("parallel_with", []),
compensate=step_conf.get("compensate"),
retry_policy=retry_policy,
) )
steps.append(step) steps.append(step)
@ -148,15 +159,19 @@ class GEOPipeline:
agent_id="pipeline", agent_id="pipeline",
) )
# Create Saga orchestrator for compensation tracking
saga = SagaOrchestrator()
# Build execution order (topological sort) # Build execution order (topological sort)
execution_groups = self._build_execution_groups() execution_groups = self._build_execution_groups()
pipeline_failed = False
for group in execution_groups: for group in execution_groups:
# Execute group in parallel # Execute group in parallel
tasks = [] tasks = []
for step_name in group: for step_name in group:
step = self._step_map[step_name] step = self._step_map[step_name]
tasks.append(self._execute_step(step, input_data, step_outputs, execution_id)) tasks.append(self._execute_step(step, input_data, step_outputs, execution_id, saga))
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
@ -175,6 +190,25 @@ class GEOPipeline:
if step_result.status == "success" and step_result.output: if step_result.status == "success" and step_result.output:
step_outputs[step_name] = step_result.output step_outputs[step_name] = step_result.output
# On failure, trigger Saga compensation
if step_result.status == "failed":
pipeline_failed = True
compensation_results = await saga.compensate()
if compensation_results:
failed_compensations = [
cr for cr in compensation_results
if not cr.success and cr.error != "no_compensation_needed"
]
if failed_compensations:
logger.warning(
f"Compensation had {len(failed_compensations)} failures: "
f"{[c.step_name for c in failed_compensations]}"
)
break
if pipeline_failed:
break
# Build final output # Build final output
final_output = self._build_final_output(step_outputs, input_data) final_output = self._build_final_output(step_outputs, input_data)
@ -196,6 +230,7 @@ class GEOPipeline:
input_data: dict[str, Any], input_data: dict[str, Any],
step_outputs: dict[str, dict[str, Any]], step_outputs: dict[str, dict[str, Any]],
execution_id: str, execution_id: str,
saga: SagaOrchestrator,
) -> PipelineStepResult: ) -> PipelineStepResult:
"""执行单个 Pipeline 步骤""" """执行单个 Pipeline 步骤"""
import time import time
@ -213,9 +248,17 @@ class GEOPipeline:
# Build step input from mapping # Build step input from mapping
step_input = self._map_input(step, input_data, step_outputs) step_input = self._map_input(step, input_data, step_outputs)
# Execute skill # Execute skill (with retry if configured)
try: try:
output = await self._execute_skill(step.skill, step_input) if step.retry_policy is not None:
output = await execute_with_retry(
func=lambda: self._execute_skill(step.skill, step_input),
retry_policy=step.retry_policy,
step_name=step.name,
)
else:
output = await self._execute_skill(step.skill, step_input)
duration_ms = (time.monotonic() - start_time) * 1000 duration_ms = (time.monotonic() - start_time) * 1000
# Store result in workspace # Store result in workspace
@ -225,6 +268,13 @@ class GEOPipeline:
agent_id=step.skill, agent_id=step.skill,
) )
# Record completed step for Saga compensation
saga.record_completed(
step_name=step.name,
result=output,
compensate_action=step.compensate,
)
return PipelineStepResult( return PipelineStepResult(
step_name=step.name, step_name=step.name,
skill=step.skill, skill=step.skill,
@ -393,3 +443,54 @@ class GEOPipeline:
for step_name, output in step_outputs.items(): for step_name, output in step_outputs.items():
final[step_name] = output final[step_name] = output
return final return final
# GEO Pipeline 默认步骤补偿定义
GEO_PIPELINE_COMPENSATIONS: dict[str, str | None] = {
"detect": None, # 只读操作,无需补偿
"analyze_competitor": None, # 只读操作,无需补偿
"optimize": "revert_optimization", # 需要回滚优化变更
"schema": None, # 幂等操作,无需补偿
"monitor": None, # 只读操作,无需补偿
}
def create_geo_pipeline_steps() -> list[PipelineStep]:
"""创建 GEO Pipeline 默认步骤(含补偿定义)"""
steps = [
PipelineStep(
name="detect",
skill="citation_detector",
input_mapping={"brand": "$.input.brand"},
compensate=GEO_PIPELINE_COMPENSATIONS["detect"],
),
PipelineStep(
name="analyze_competitor",
skill="competitor_analyzer",
depends_on=["detect"],
input_mapping={"brand": "$.input.brand"},
compensate=GEO_PIPELINE_COMPENSATIONS["analyze_competitor"],
),
PipelineStep(
name="optimize",
skill="content_optimizer",
depends_on=["analyze_competitor"],
input_mapping={"brand": "$.input.brand"},
compensate=GEO_PIPELINE_COMPENSATIONS["optimize"],
),
PipelineStep(
name="schema",
skill="schema_generator",
depends_on=["optimize"],
input_mapping={"brand": "$.input.brand"},
compensate=GEO_PIPELINE_COMPENSATIONS["schema"],
),
PipelineStep(
name="monitor",
skill="citation_monitor",
depends_on=["schema"],
input_mapping={"brand": "$.input.brand"},
compensate=GEO_PIPELINE_COMPENSATIONS["monitor"],
),
]
return steps

View File

@ -0,0 +1,312 @@
"""Tests for Pipeline Saga compensation pattern"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from agentkit.orchestrator.compensation import (
CompletedStep,
CompensationResult,
SagaOrchestrator,
)
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import (
Pipeline,
PipelineStage,
StageStatus,
)
from agentkit.orchestrator.retry import StepRetryPolicy
from agentkit.skills.geo_pipeline import (
GEO_PIPELINE_COMPENSATIONS,
PipelineStep,
create_geo_pipeline_steps,
)
class TestCompletedStep:
"""CompletedStep dataclass tests"""
def test_creation_with_compensation(self):
step = CompletedStep(
step_name="optimize",
result={"changes": 5},
compensate_action="revert_optimization",
)
assert step.step_name == "optimize"
assert step.result == {"changes": 5}
assert step.compensate_action == "revert_optimization"
def test_creation_without_compensation(self):
step = CompletedStep(step_name="detect", result={"found": 3})
assert step.compensate_action is None
class TestCompensationResult:
"""CompensationResult dataclass tests"""
def test_success_result(self):
result = CompensationResult(step_name="optimize", success=True)
assert result.step_name == "optimize"
assert result.success is True
assert result.error is None
def test_failure_result(self):
result = CompensationResult(
step_name="optimize", success=False, error="rollback failed"
)
assert result.success is False
assert result.error == "rollback failed"
class TestSagaOrchestrator:
"""SagaOrchestrator tests"""
def test_record_completed(self):
saga = SagaOrchestrator()
saga.record_completed("step1", {"data": 1}, "compensate_1")
saga.record_completed("step2", {"data": 2})
steps = saga.completed_steps
assert len(steps) == 2
assert steps[0].step_name == "step1"
assert steps[0].compensate_action == "compensate_1"
assert steps[1].step_name == "step2"
assert steps[1].compensate_action is None
@pytest.mark.asyncio
async def test_compensate_lifo_order(self):
"""Compensation should execute in LIFO (reverse) order"""
execution_order = []
async def mock_execute_skill(skill_name: str, input_data):
execution_order.append(skill_name)
saga = SagaOrchestrator(execute_skill_func=mock_execute_skill)
saga.record_completed("step1", {"data": 1}, "compensate_1")
saga.record_completed("step2", {"data": 2}, "compensate_2")
saga.record_completed("step3", {"data": 3}, "compensate_3")
results = await saga.compensate()
# LIFO order: step3, step2, step1
assert execution_order == ["compensate_3", "compensate_2", "compensate_1"]
assert len(results) == 3
assert all(r.success for r in results)
@pytest.mark.asyncio
async def test_skip_steps_with_no_compensate_action(self):
"""Steps with no compensate_action should be skipped"""
saga = SagaOrchestrator()
saga.record_completed("read_only", {"data": 1}) # No compensation
saga.record_completed("write_op", {"data": 2}, "rollback_write")
results = await saga.compensate()
assert len(results) == 2
# write_op is compensated first (LIFO), then read_only
assert results[0].step_name == "write_op"
assert results[0].success is True
assert results[1].step_name == "read_only"
assert results[1].success is True
assert results[1].error == "no_compensation_needed"
@pytest.mark.asyncio
async def test_compensation_failure_doesnt_interrupt_others(self):
"""If one compensation fails, others should still execute"""
execution_order = []
async def mock_execute_skill(skill_name: str, input_data):
execution_order.append(skill_name)
if skill_name == "compensate_2":
raise RuntimeError("rollback failed")
saga = SagaOrchestrator(execute_skill_func=mock_execute_skill)
saga.record_completed("step1", {"data": 1}, "compensate_1")
saga.record_completed("step2", {"data": 2}, "compensate_2")
saga.record_completed("step3", {"data": 3}, "compensate_3")
results = await saga.compensate()
# All compensations should be attempted (LIFO: step3, step2, step1)
assert execution_order == ["compensate_3", "compensate_2", "compensate_1"]
assert len(results) == 3
# step3 succeeds
assert results[0].success is True
# step2 fails
assert results[1].success is False
assert results[1].error == "rollback failed"
# step1 still succeeds
assert results[2].success is True
@pytest.mark.asyncio
async def test_compensate_with_no_execute_skill_func(self):
"""Without execute_skill_func, compensation succeeds but does nothing"""
saga = SagaOrchestrator()
saga.record_completed("step1", {"data": 1}, "compensate_1")
results = await saga.compensate()
assert len(results) == 1
assert results[0].success is True
def test_clear(self):
saga = SagaOrchestrator()
saga.record_completed("step1", {"data": 1})
saga.record_completed("step2", {"data": 2})
assert len(saga.completed_steps) == 2
saga.clear()
assert len(saga.completed_steps) == 0
def test_completed_steps_returns_copy(self):
saga = SagaOrchestrator()
saga.record_completed("step1", {"data": 1})
steps = saga.completed_steps
steps.clear() # Mutate the copy
assert len(saga.completed_steps) == 1 # Original unchanged
class TestPipelineIntegration:
"""Pipeline engine integration with retry and compensation"""
@pytest.mark.asyncio
async def test_step_failure_triggers_compensation(self):
"""When a step fails, Saga compensation should be triggered for completed steps"""
engine = PipelineEngine()
pipeline = Pipeline(
name="test_compensation",
version="1.0",
description="Test compensation",
stages=[
PipelineStage(
name="step1",
agent="agent_a",
action="do_a",
compensate="undo_a",
),
PipelineStage(
name="step2",
agent="agent_b",
action="do_b",
depends_on=["step1"],
),
],
)
# Dry-run mode (no dispatcher) — all steps succeed
result = await engine.execute(pipeline)
assert result.status == StageStatus.COMPLETED
@pytest.mark.asyncio
async def test_continue_on_failure(self):
"""Steps with continue_on_failure should not abort the pipeline"""
engine = PipelineEngine()
pipeline = Pipeline(
name="test_continue",
version="1.0",
description="Test continue_on_failure",
stages=[
PipelineStage(
name="step1",
agent="agent_a",
action="do_a",
continue_on_failure=True,
),
PipelineStage(
name="step2",
agent="agent_b",
action="do_b",
depends_on=["step1"],
),
],
)
# Dry-run mode — all steps succeed
result = await engine.execute(pipeline)
assert result.status == StageStatus.COMPLETED
@pytest.mark.asyncio
async def test_pipeline_with_retry_policy(self):
"""PipelineStage can have a retry_policy configured"""
retry = StepRetryPolicy(max_attempts=5, base_delay=0.01, jitter=False)
stage = PipelineStage(
name="retry_step",
agent="agent_a",
action="do_a",
retry_policy=retry,
)
assert stage.retry_policy is not None
assert stage.retry_policy.max_attempts == 5
@pytest.mark.asyncio
async def test_pipeline_with_compensate(self):
"""PipelineStage can have a compensate action configured"""
stage = PipelineStage(
name="optimizable_step",
agent="agent_a",
action="do_a",
compensate="undo_a",
)
assert stage.compensate == "undo_a"
@pytest.mark.asyncio
async def test_pipeline_without_compensate(self):
"""PipelineStage without compensate defaults to None"""
stage = PipelineStage(
name="readonly_step",
agent="agent_a",
action="do_a",
)
assert stage.compensate is None
class TestGEOPipelineCompensations:
"""GEO Pipeline compensation definitions"""
def test_compensation_definitions(self):
"""Verify GEO pipeline compensation definitions"""
assert GEO_PIPELINE_COMPENSATIONS["detect"] is None
assert GEO_PIPELINE_COMPENSATIONS["analyze_competitor"] is None
assert GEO_PIPELINE_COMPENSATIONS["optimize"] == "revert_optimization"
assert GEO_PIPELINE_COMPENSATIONS["schema"] is None
assert GEO_PIPELINE_COMPENSATIONS["monitor"] is None
def test_create_geo_pipeline_steps(self):
"""Verify GEO pipeline steps are created with compensation"""
steps = create_geo_pipeline_steps()
assert len(steps) == 5
step_names = [s.name for s in steps]
assert step_names == [
"detect",
"analyze_competitor",
"optimize",
"schema",
"monitor",
]
# Check compensation assignments
step_map = {s.name: s for s in steps}
assert step_map["detect"].compensate is None
assert step_map["analyze_competitor"].compensate is None
assert step_map["optimize"].compensate == "revert_optimization"
assert step_map["schema"].compensate is None
assert step_map["monitor"].compensate is None
def test_geo_pipeline_steps_dependencies(self):
"""Verify GEO pipeline step dependencies form a valid chain"""
steps = create_geo_pipeline_steps()
step_map = {s.name: s for s in steps}
assert step_map["detect"].depends_on == []
assert step_map["analyze_competitor"].depends_on == ["detect"]
assert step_map["optimize"].depends_on == ["analyze_competitor"]
assert step_map["schema"].depends_on == ["optimize"]
assert step_map["monitor"].depends_on == ["schema"]

View File

@ -0,0 +1,210 @@
"""Tests for Pipeline step-level retry with exponential backoff"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
class TestStepRetryPolicy:
"""StepRetryPolicy construction and defaults"""
def test_default_values(self):
policy = StepRetryPolicy()
assert policy.max_attempts == 3
assert policy.base_delay == 1.0
assert policy.max_delay == 60.0
assert policy.exponential_base == 2.0
assert policy.jitter is True
assert policy.retryable_exceptions == (ConnectionError, TimeoutError, OSError)
def test_custom_values(self):
policy = StepRetryPolicy(
max_attempts=5,
base_delay=0.5,
max_delay=30.0,
exponential_base=3.0,
jitter=False,
retryable_exceptions=(ValueError,),
)
assert policy.max_attempts == 5
assert policy.base_delay == 0.5
assert policy.max_delay == 30.0
assert policy.exponential_base == 3.0
assert policy.jitter is False
assert policy.retryable_exceptions == (ValueError,)
class TestCalculateDelay:
"""StepRetryPolicy.calculate_delay tests"""
def test_delay_increases_exponentially(self):
policy = StepRetryPolicy(base_delay=1.0, exponential_base=2.0, jitter=False)
assert policy.calculate_delay(0) == 1.0
assert policy.calculate_delay(1) == 2.0
assert policy.calculate_delay(2) == 4.0
assert policy.calculate_delay(3) == 8.0
def test_delay_respects_max_delay(self):
policy = StepRetryPolicy(
base_delay=1.0, exponential_base=2.0, max_delay=10.0, jitter=False
)
assert policy.calculate_delay(0) == 1.0
assert policy.calculate_delay(1) == 2.0
assert policy.calculate_delay(2) == 4.0
assert policy.calculate_delay(3) == 8.0
assert policy.calculate_delay(4) == 10.0 # capped
assert policy.calculate_delay(10) == 10.0 # still capped
def test_jitter_adds_randomness(self):
policy = StepRetryPolicy(
base_delay=1.0, exponential_base=2.0, jitter=True
)
# With jitter, delay should be >= base delay and <= base_delay * 1.1
delays = [policy.calculate_delay(0) for _ in range(100)]
# All delays should be >= 1.0 (base) and < 1.0 * 1.1 * 1.1 = 1.21
for d in delays:
assert d >= 1.0
assert d < 1.0 * 1.1 * 1.1 # jitter adds up to 10% of delay
def test_no_jitter_gives_exact_delay(self):
policy = StepRetryPolicy(
base_delay=2.0, exponential_base=3.0, jitter=False
)
assert policy.calculate_delay(0) == 2.0
assert policy.calculate_delay(1) == 6.0
assert policy.calculate_delay(2) == 18.0
class TestExecuteWithRetry:
"""execute_with_retry integration tests"""
@pytest.mark.asyncio
async def test_success_on_first_attempt(self):
func = AsyncMock(return_value="ok")
policy = StepRetryPolicy(max_attempts=3, jitter=False, base_delay=0.01)
result = await execute_with_retry(func, policy, "test_step")
assert result == "ok"
assert func.call_count == 1
@pytest.mark.asyncio
async def test_success_after_retry(self):
call_count = 0
async def flaky_func():
nonlocal call_count
call_count += 1
if call_count < 3:
raise ConnectionError("temporary failure")
return "ok"
policy = StepRetryPolicy(
max_attempts=5,
base_delay=0.01,
jitter=False,
retryable_exceptions=(ConnectionError,),
)
result = await execute_with_retry(flaky_func, policy, "flaky_step")
assert result == "ok"
assert call_count == 3
@pytest.mark.asyncio
async def test_all_attempts_exhausted_raises(self):
async def always_fails():
raise ConnectionError("permanent failure")
policy = StepRetryPolicy(
max_attempts=3,
base_delay=0.01,
jitter=False,
retryable_exceptions=(ConnectionError,),
)
with pytest.raises(ConnectionError, match="permanent failure"):
await execute_with_retry(always_fails, policy, "failing_step")
@pytest.mark.asyncio
async def test_non_retryable_exception_propagates_immediately(self):
call_count = 0
async def raises_value_error():
nonlocal call_count
call_count += 1
raise ValueError("not retryable")
policy = StepRetryPolicy(
max_attempts=3,
base_delay=0.01,
jitter=False,
retryable_exceptions=(ConnectionError, TimeoutError),
)
with pytest.raises(ValueError, match="not retryable"):
await execute_with_retry(raises_value_error, policy, "bad_step")
# Should only be called once — no retries for non-retryable exceptions
assert call_count == 1
@pytest.mark.asyncio
async def test_none_policy_means_no_retry(self):
func = AsyncMock(return_value="direct")
result = await execute_with_retry(func, None, "no_retry_step")
assert result == "direct"
assert func.call_count == 1
@pytest.mark.asyncio
async def test_none_policy_does_not_catch_exceptions(self):
async def raises():
raise RuntimeError("boom")
with pytest.raises(RuntimeError, match="boom"):
await execute_with_retry(raises, None, "no_retry_step")
@pytest.mark.asyncio
async def test_timeout_error_is_retryable(self):
call_count = 0
async def timeout_then_ok():
nonlocal call_count
call_count += 1
if call_count == 1:
raise TimeoutError("timed out")
return "recovered"
policy = StepRetryPolicy(
max_attempts=3,
base_delay=0.01,
jitter=False,
retryable_exceptions=(TimeoutError,),
)
result = await execute_with_retry(timeout_then_ok, policy, "timeout_step")
assert result == "recovered"
assert call_count == 2
@pytest.mark.asyncio
async def test_os_error_is_retryable(self):
call_count = 0
async def oserr_then_ok():
nonlocal call_count
call_count += 1
if call_count == 1:
raise OSError("network unreachable")
return "ok"
policy = StepRetryPolicy(
max_attempts=3,
base_delay=0.01,
jitter=False,
)
result = await execute_with_retry(oserr_then_ok, policy, "oserr_step")
assert result == "ok"
assert call_count == 2