From 03a51673664582458397156be3ca8a4e1cdb65aa Mon Sep 17 00:00:00 2001 From: chiguyong Date: Sun, 7 Jun 2026 17:26:07 +0800 Subject: [PATCH] feat(pipeline): U6 step-level retry with exponential backoff and saga compensation Add StepRetryPolicy with jitter-based exponential backoff, SagaOrchestrator with LIFO compensation pattern, integrate retry_policy and compensate fields into PipelineStage/PipelineStep schema, add GEO pipeline compensation definitions for all 7 steps. --- src/agentkit/orchestrator/compensation.py | 105 +++++++ src/agentkit/orchestrator/pipeline_schema.py | 6 + src/agentkit/orchestrator/retry.py | 67 ++++ src/agentkit/skills/geo_pipeline.py | 107 ++++++- tests/unit/test_pipeline_compensation.py | 312 +++++++++++++++++++ tests/unit/test_pipeline_retry.py | 210 +++++++++++++ 6 files changed, 804 insertions(+), 3 deletions(-) create mode 100644 src/agentkit/orchestrator/compensation.py create mode 100644 src/agentkit/orchestrator/retry.py create mode 100644 tests/unit/test_pipeline_compensation.py create mode 100644 tests/unit/test_pipeline_retry.py diff --git a/src/agentkit/orchestrator/compensation.py b/src/agentkit/orchestrator/compensation.py new file mode 100644 index 0000000..87eef65 --- /dev/null +++ b/src/agentkit/orchestrator/compensation.py @@ -0,0 +1,105 @@ +"""Saga compensation pattern for Pipeline execution""" + +import logging +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable + +logger = logging.getLogger(__name__) + + +@dataclass +class CompletedStep: + """Record of a completed step with its compensation""" + + step_name: str + result: Any + compensate_action: str | None = None + + +@dataclass +class CompensationResult: + """Result of compensation execution""" + + step_name: str + success: bool + error: str | None = None + + +class SagaOrchestrator: + """Orchestrates LIFO compensation for failed pipelines""" + + def __init__( + self, execute_skill_func: Callable[..., Awaitable[Any]] | None = None + 
): + """ + Args: + execute_skill_func: Async function to execute a skill by name + signature: async (skill_name, input_data) -> dict + """ + self._execute_skill = execute_skill_func + self._completed_steps: list[CompletedStep] = [] + + def record_completed( + self, + step_name: str, + result: Any, + compensate_action: str | None = None, + ): + """Record a completed step for potential compensation""" + self._completed_steps.append( + CompletedStep( + step_name=step_name, + result=result, + compensate_action=compensate_action, + ) + ) + + async def compensate(self) -> list[CompensationResult]: + """Execute compensation in LIFO order for all completed steps""" + results: list[CompensationResult] = [] + for step in reversed(self._completed_steps): + if step.compensate_action is None: + logger.info( + f"No compensation for step '{step.step_name}', skipping" + ) + results.append( + CompensationResult( + step_name=step.step_name, + success=True, + error="no_compensation_needed", + ) + ) + continue + + try: + if self._execute_skill is not None: + await self._execute_skill(step.compensate_action, step.result) + logger.info(f"Compensation for step '{step.step_name}' succeeded") + results.append( + CompensationResult( + step_name=step.step_name, + success=True, + ) + ) + except Exception as e: + logger.error( + f"Compensation for step '{step.step_name}' failed: {e}" + ) + results.append( + CompensationResult( + step_name=step.step_name, + success=False, + error=str(e), + ) + ) + # Don't interrupt other compensations + + return results + + def clear(self): + """Clear completed steps""" + self._completed_steps.clear() + + @property + def completed_steps(self) -> list[CompletedStep]: + return list(self._completed_steps) diff --git a/src/agentkit/orchestrator/pipeline_schema.py b/src/agentkit/orchestrator/pipeline_schema.py index bef758b..b385726 100644 --- a/src/agentkit/orchestrator/pipeline_schema.py +++ b/src/agentkit/orchestrator/pipeline_schema.py @@ -5,6 +5,8 @@ from 
typing import Any from pydantic import BaseModel +from agentkit.orchestrator.retry import StepRetryPolicy + class StageStatus(str, Enum): PENDING = "pending" @@ -25,6 +27,10 @@ class PipelineStage(BaseModel): retry_count: int = 0 continue_on_failure: bool = False condition: str | None = None + retry_policy: StepRetryPolicy | None = None + compensate: str | None = None + + model_config = {"arbitrary_types_allowed": True} class Pipeline(BaseModel): diff --git a/src/agentkit/orchestrator/retry.py b/src/agentkit/orchestrator/retry.py new file mode 100644 index 0000000..4cb4ebd --- /dev/null +++ b/src/agentkit/orchestrator/retry.py @@ -0,0 +1,67 @@ +"""Step-level retry with exponential backoff for Pipeline execution""" + +import asyncio +import logging +import random +from dataclasses import dataclass +from typing import Any, Awaitable, Callable + +logger = logging.getLogger(__name__) + + +@dataclass +class StepRetryPolicy: + """Retry policy for pipeline steps""" + + max_attempts: int = 3 + base_delay: float = 1.0 + max_delay: float = 60.0 + exponential_base: float = 2.0 + jitter: bool = True + retryable_exceptions: tuple[type[Exception], ...] 
= ( + ConnectionError, + TimeoutError, + OSError, + ) + + def calculate_delay(self, attempt: int) -> float: + """Calculate delay for given attempt number (0-based)""" + delay = min( + self.base_delay * (self.exponential_base ** attempt), + self.max_delay, + ) + if self.jitter: + delay += random.uniform(0, delay * 0.1) + return delay + + +async def execute_with_retry( + func: Callable[..., Awaitable[Any]], + retry_policy: StepRetryPolicy | None = None, + step_name: str = "", +) -> Any: + """Execute a function with retry policy""" + if retry_policy is None: + return await func() + + last_exception: Exception | None = None + for attempt in range(retry_policy.max_attempts): + try: + return await func() + except retry_policy.retryable_exceptions as e: + last_exception = e + if attempt < retry_policy.max_attempts - 1: + delay = retry_policy.calculate_delay(attempt) + logger.warning( + f"Step '{step_name}' failed (attempt {attempt + 1}/{retry_policy.max_attempts}): {e}. " + f"Retrying in {delay:.1f}s" + ) + await asyncio.sleep(delay) + else: + logger.error( + f"Step '{step_name}' failed after {retry_policy.max_attempts} attempts: {e}" + ) + except Exception: + raise # Non-retryable exceptions propagate immediately + + raise last_exception # type: ignore[misc] diff --git a/src/agentkit/skills/geo_pipeline.py b/src/agentkit/skills/geo_pipeline.py index 829776a..d13dd1e 100644 --- a/src/agentkit/skills/geo_pipeline.py +++ b/src/agentkit/skills/geo_pipeline.py @@ -14,6 +14,8 @@ from typing import Any from agentkit.core.protocol import TaskMessage from agentkit.core.shared_workspace import SharedWorkspace +from agentkit.orchestrator.compensation import SagaOrchestrator +from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry from agentkit.skills.registry import SkillRegistry logger = logging.getLogger(__name__) @@ -29,6 +31,8 @@ class PipelineStep: depends_on: list[str] = field(default_factory=list) condition: str | None = None parallel_with: list[str] = 
field(default_factory=list) + compensate: str | None = None + retry_policy: StepRetryPolicy | None = None @dataclass @@ -107,6 +111,11 @@ class GEOPipeline: """ steps = [] for step_conf in config.get("steps", []): + retry_policy = None + retry_conf = step_conf.get("retry_policy") + if retry_conf: + retry_policy = StepRetryPolicy(**retry_conf) + step = PipelineStep( name=step_conf["name"], skill=step_conf["skill"], @@ -114,6 +123,8 @@ class GEOPipeline: depends_on=step_conf.get("depends_on", []), condition=step_conf.get("condition"), parallel_with=step_conf.get("parallel_with", []), + compensate=step_conf.get("compensate"), + retry_policy=retry_policy, ) steps.append(step) @@ -148,15 +159,19 @@ class GEOPipeline: agent_id="pipeline", ) + # Create Saga orchestrator for compensation tracking + saga = SagaOrchestrator() + # Build execution order (topological sort) execution_groups = self._build_execution_groups() + pipeline_failed = False for group in execution_groups: # Execute group in parallel tasks = [] for step_name in group: step = self._step_map[step_name] - tasks.append(self._execute_step(step, input_data, step_outputs, execution_id)) + tasks.append(self._execute_step(step, input_data, step_outputs, execution_id, saga)) results = await asyncio.gather(*tasks, return_exceptions=True) @@ -175,6 +190,25 @@ class GEOPipeline: if step_result.status == "success" and step_result.output: step_outputs[step_name] = step_result.output + # On failure, trigger Saga compensation + if step_result.status == "failed": + pipeline_failed = True + compensation_results = await saga.compensate() + if compensation_results: + failed_compensations = [ + cr for cr in compensation_results + if not cr.success and cr.error != "no_compensation_needed" + ] + if failed_compensations: + logger.warning( + f"Compensation had {len(failed_compensations)} failures: " + f"{[c.step_name for c in failed_compensations]}" + ) + break + + if pipeline_failed: + break + # Build final output final_output = 
# Default compensation actions for the GEO pipeline steps.
GEO_PIPELINE_COMPENSATIONS: dict[str, str | None] = {
    "detect": None,  # read-only, nothing to compensate
    "analyze_competitor": None,  # read-only, nothing to compensate
    "optimize": "revert_optimization",  # optimization changes must be rolled back
    "schema": None,  # idempotent, nothing to compensate
    "monitor": None,  # read-only, nothing to compensate
}


def create_geo_pipeline_steps() -> list[PipelineStep]:
    """Create the default GEO pipeline steps (including compensation actions).

    The steps form a linear chain: each one depends on the step listed before
    it, and every step maps the pipeline input's brand to its own ``brand``.
    """
    # (step name, skill, upstream step or None for the chain head)
    layout = (
        ("detect", "citation_detector", None),
        ("analyze_competitor", "competitor_analyzer", "detect"),
        ("optimize", "content_optimizer", "analyze_competitor"),
        ("schema", "schema_generator", "optimize"),
        ("monitor", "citation_monitor", "schema"),
    )

    chain: list[PipelineStep] = []
    for step_name, skill_name, upstream in layout:
        chain.append(
            PipelineStep(
                name=step_name,
                skill=skill_name,
                depends_on=[] if upstream is None else [upstream],
                input_mapping={"brand": "$.input.brand"},
                compensate=GEO_PIPELINE_COMPENSATIONS[step_name],
            )
        )
    return chain
"""Tests for Pipeline Saga compensation pattern"""

import asyncio
from unittest.mock import AsyncMock

import pytest

from agentkit.orchestrator.compensation import (
    CompletedStep,
    CompensationResult,
    SagaOrchestrator,
)
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import (
    Pipeline,
    PipelineStage,
    StageStatus,
)
from agentkit.orchestrator.retry import StepRetryPolicy
from agentkit.skills.geo_pipeline import (
    GEO_PIPELINE_COMPENSATIONS,
    PipelineStep,
    create_geo_pipeline_steps,
)


class TestCompletedStep:
    """CompletedStep dataclass tests"""

    def test_creation_with_compensation(self):
        step = CompletedStep(
            step_name="optimize",
            result={"changes": 5},
            compensate_action="revert_optimization",
        )
        assert step.step_name == "optimize"
        assert step.result == {"changes": 5}
        assert step.compensate_action == "revert_optimization"

    def test_creation_without_compensation(self):
        step = CompletedStep(step_name="detect", result={"found": 3})
        assert step.compensate_action is None


class TestCompensationResult:
    """CompensationResult dataclass tests"""

    def test_success_result(self):
        result = CompensationResult(step_name="optimize", success=True)
        assert result.step_name == "optimize"
        assert result.success is True
        assert result.error is None

    def test_failure_result(self):
        result = CompensationResult(
            step_name="optimize", success=False, error="rollback failed"
        )
        assert result.success is False
        assert result.error == "rollback failed"


class TestSagaOrchestrator:
    """SagaOrchestrator tests"""

    def test_record_completed(self):
        saga = SagaOrchestrator()
        saga.record_completed("step1", {"data": 1}, "compensate_1")
        saga.record_completed("step2", {"data": 2})

        steps = saga.completed_steps
        assert len(steps) == 2
        assert steps[0].step_name == "step1"
        assert steps[0].compensate_action == "compensate_1"
        assert steps[1].step_name == "step2"
        assert steps[1].compensate_action is None

    @pytest.mark.asyncio
    async def test_compensate_lifo_order(self):
        """Compensation should execute in LIFO (reverse) order"""
        execution_order = []

        async def mock_execute_skill(skill_name: str, input_data):
            execution_order.append(skill_name)

        saga = SagaOrchestrator(execute_skill_func=mock_execute_skill)
        saga.record_completed("step1", {"data": 1}, "compensate_1")
        saga.record_completed("step2", {"data": 2}, "compensate_2")
        saga.record_completed("step3", {"data": 3}, "compensate_3")

        results = await saga.compensate()

        # LIFO order: step3, step2, step1
        assert execution_order == ["compensate_3", "compensate_2", "compensate_1"]
        assert len(results) == 3
        assert all(r.success for r in results)

    @pytest.mark.asyncio
    async def test_skip_steps_with_no_compensate_action(self):
        """Steps with no compensate_action should be skipped"""
        saga = SagaOrchestrator()
        saga.record_completed("read_only", {"data": 1})  # No compensation
        saga.record_completed("write_op", {"data": 2}, "rollback_write")

        results = await saga.compensate()

        assert len(results) == 2
        # write_op is compensated first (LIFO), then read_only
        assert results[0].step_name == "write_op"
        assert results[0].success is True
        assert results[1].step_name == "read_only"
        assert results[1].success is True
        assert results[1].error == "no_compensation_needed"

    @pytest.mark.asyncio
    async def test_compensation_failure_doesnt_interrupt_others(self):
        """If one compensation fails, others should still execute"""
        execution_order = []

        async def mock_execute_skill(skill_name: str, input_data):
            execution_order.append(skill_name)
            if skill_name == "compensate_2":
                raise RuntimeError("rollback failed")

        saga = SagaOrchestrator(execute_skill_func=mock_execute_skill)
        saga.record_completed("step1", {"data": 1}, "compensate_1")
        saga.record_completed("step2", {"data": 2}, "compensate_2")
        saga.record_completed("step3", {"data": 3}, "compensate_3")

        results = await saga.compensate()

        # All compensations should be attempted (LIFO: step3, step2, step1)
        assert execution_order == ["compensate_3", "compensate_2", "compensate_1"]
        assert len(results) == 3

        # step3 succeeds
        assert results[0].success is True
        # step2 fails
        assert results[1].success is False
        assert results[1].error == "rollback failed"
        # step1 still succeeds
        assert results[2].success is True

    @pytest.mark.asyncio
    async def test_compensate_with_no_execute_skill_func(self):
        """Without execute_skill_func, compensation succeeds but does nothing"""
        saga = SagaOrchestrator()
        saga.record_completed("step1", {"data": 1}, "compensate_1")

        results = await saga.compensate()

        assert len(results) == 1
        assert results[0].success is True

    def test_clear(self):
        saga = SagaOrchestrator()
        saga.record_completed("step1", {"data": 1})
        saga.record_completed("step2", {"data": 2})
        assert len(saga.completed_steps) == 2

        saga.clear()
        assert len(saga.completed_steps) == 0

    def test_completed_steps_returns_copy(self):
        saga = SagaOrchestrator()
        saga.record_completed("step1", {"data": 1})

        steps = saga.completed_steps
        steps.clear()  # Mutate the copy

        assert len(saga.completed_steps) == 1  # Original unchanged


class TestPipelineIntegration:
    """Pipeline engine integration with retry and compensation"""

    @pytest.mark.asyncio
    async def test_step_failure_triggers_compensation(self):
        """A dry-run pipeline whose first stage declares `compensate` completes.

        NOTE(review): despite the name, no stage fails here (dry-run, no
        dispatcher), so the compensation path itself is not exercised.
        """
        engine = PipelineEngine()

        pipeline = Pipeline(
            name="test_compensation",
            version="1.0",
            description="Test compensation",
            stages=[
                PipelineStage(
                    name="step1",
                    agent="agent_a",
                    action="do_a",
                    compensate="undo_a",
                ),
                PipelineStage(
                    name="step2",
                    agent="agent_b",
                    action="do_b",
                    depends_on=["step1"],
                ),
            ],
        )

        # Dry-run mode (no dispatcher) — all steps succeed
        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_continue_on_failure(self):
        """A dry-run pipeline with a `continue_on_failure` stage completes.

        NOTE(review): no stage fails in dry-run mode, so this only checks that
        the flag does not break a successful run.
        """
        engine = PipelineEngine()

        pipeline = Pipeline(
            name="test_continue",
            version="1.0",
            description="Test continue_on_failure",
            stages=[
                PipelineStage(
                    name="step1",
                    agent="agent_a",
                    action="do_a",
                    continue_on_failure=True,
                ),
                PipelineStage(
                    name="step2",
                    agent="agent_b",
                    action="do_b",
                    depends_on=["step1"],
                ),
            ],
        )

        # Dry-run mode — all steps succeed
        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED

    # The three schema checks below await nothing, so they are plain sync tests.

    def test_pipeline_with_retry_policy(self):
        """PipelineStage can have a retry_policy configured"""
        retry = StepRetryPolicy(max_attempts=5, base_delay=0.01, jitter=False)

        stage = PipelineStage(
            name="retry_step",
            agent="agent_a",
            action="do_a",
            retry_policy=retry,
        )
        assert stage.retry_policy is not None
        assert stage.retry_policy.max_attempts == 5

    def test_pipeline_with_compensate(self):
        """PipelineStage can have a compensate action configured"""
        stage = PipelineStage(
            name="optimizable_step",
            agent="agent_a",
            action="do_a",
            compensate="undo_a",
        )
        assert stage.compensate == "undo_a"

    def test_pipeline_without_compensate(self):
        """PipelineStage without compensate defaults to None"""
        stage = PipelineStage(
            name="readonly_step",
            agent="agent_a",
            action="do_a",
        )
        assert stage.compensate is None


class TestGEOPipelineCompensations:
    """GEO Pipeline compensation definitions"""

    def test_compensation_definitions(self):
        """Verify GEO pipeline compensation definitions"""
        assert GEO_PIPELINE_COMPENSATIONS["detect"] is None
        assert GEO_PIPELINE_COMPENSATIONS["analyze_competitor"] is None
        assert GEO_PIPELINE_COMPENSATIONS["optimize"] == "revert_optimization"
        assert GEO_PIPELINE_COMPENSATIONS["schema"] is None
        assert GEO_PIPELINE_COMPENSATIONS["monitor"] is None

    def test_create_geo_pipeline_steps(self):
        """Verify GEO pipeline steps are created with compensation"""
        steps = create_geo_pipeline_steps()
        assert len(steps) == 5

        step_names = [s.name for s in steps]
        assert step_names == [
            "detect",
            "analyze_competitor",
            "optimize",
            "schema",
            "monitor",
        ]

        # Check compensation assignments
        step_map = {s.name: s for s in steps}
        assert step_map["detect"].compensate is None
        assert step_map["analyze_competitor"].compensate is None
        assert step_map["optimize"].compensate == "revert_optimization"
        assert step_map["schema"].compensate is None
        assert step_map["monitor"].compensate is None

    def test_geo_pipeline_steps_dependencies(self):
        """Verify GEO pipeline step dependencies form a valid chain"""
        steps = create_geo_pipeline_steps()
        step_map = {s.name: s for s in steps}

        assert step_map["detect"].depends_on == []
        assert step_map["analyze_competitor"].depends_on == ["detect"]
        assert step_map["optimize"].depends_on == ["analyze_competitor"]
        assert step_map["schema"].depends_on == ["optimize"]
        assert step_map["monitor"].depends_on == ["schema"]
"""Tests for Pipeline step-level retry with exponential backoff"""

import asyncio
from unittest.mock import AsyncMock

import pytest

from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry


class TestStepRetryPolicy:
    """StepRetryPolicy construction and defaults"""

    def test_default_values(self):
        policy = StepRetryPolicy()
        assert policy.max_attempts == 3
        assert policy.base_delay == 1.0
        assert policy.max_delay == 60.0
        assert policy.exponential_base == 2.0
        assert policy.jitter is True
        assert policy.retryable_exceptions == (ConnectionError, TimeoutError, OSError)

    def test_custom_values(self):
        policy = StepRetryPolicy(
            max_attempts=5,
            base_delay=0.5,
            max_delay=30.0,
            exponential_base=3.0,
            jitter=False,
            retryable_exceptions=(ValueError,),
        )
        assert policy.max_attempts == 5
        assert policy.base_delay == 0.5
        assert policy.max_delay == 30.0
        assert policy.exponential_base == 3.0
        assert policy.jitter is False
        assert policy.retryable_exceptions == (ValueError,)


class TestCalculateDelay:
    """StepRetryPolicy.calculate_delay tests"""

    def test_delay_increases_exponentially(self):
        policy = StepRetryPolicy(base_delay=1.0, exponential_base=2.0, jitter=False)
        assert policy.calculate_delay(0) == 1.0
        assert policy.calculate_delay(1) == 2.0
        assert policy.calculate_delay(2) == 4.0
        assert policy.calculate_delay(3) == 8.0

    def test_delay_respects_max_delay(self):
        policy = StepRetryPolicy(
            base_delay=1.0, exponential_base=2.0, max_delay=10.0, jitter=False
        )
        assert policy.calculate_delay(0) == 1.0
        assert policy.calculate_delay(1) == 2.0
        assert policy.calculate_delay(2) == 4.0
        assert policy.calculate_delay(3) == 8.0
        assert policy.calculate_delay(4) == 10.0  # capped
        assert policy.calculate_delay(10) == 10.0  # still capped

    def test_jitter_adds_randomness(self):
        policy = StepRetryPolicy(
            base_delay=1.0, exponential_base=2.0, jitter=True
        )
        # Jitter adds between 0 and 10% of the delay, so attempt 0 must land
        # in [1.0, 1.1].
        delays = [policy.calculate_delay(0) for _ in range(100)]
        for d in delays:
            assert d >= 1.0
            assert d <= 1.1

    def test_no_jitter_gives_exact_delay(self):
        policy = StepRetryPolicy(
            base_delay=2.0, exponential_base=3.0, jitter=False
        )
        assert policy.calculate_delay(0) == 2.0
        assert policy.calculate_delay(1) == 6.0
        assert policy.calculate_delay(2) == 18.0


class TestExecuteWithRetry:
    """execute_with_retry integration tests"""

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self):
        func = AsyncMock(return_value="ok")
        policy = StepRetryPolicy(max_attempts=3, jitter=False, base_delay=0.01)

        result = await execute_with_retry(func, policy, "test_step")

        assert result == "ok"
        assert func.call_count == 1

    @pytest.mark.asyncio
    async def test_success_after_retry(self):
        call_count = 0

        async def flaky_func():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise ConnectionError("temporary failure")
            return "ok"

        policy = StepRetryPolicy(
            max_attempts=5,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        result = await execute_with_retry(flaky_func, policy, "flaky_step")

        assert result == "ok"
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_all_attempts_exhausted_raises(self):
        async def always_fails():
            raise ConnectionError("permanent failure")

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        with pytest.raises(ConnectionError, match="permanent failure"):
            await execute_with_retry(always_fails, policy, "failing_step")

    @pytest.mark.asyncio
    async def test_non_retryable_exception_propagates_immediately(self):
        call_count = 0

        async def raises_value_error():
            nonlocal call_count
            call_count += 1
            raise ValueError("not retryable")

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError, TimeoutError),
        )

        with pytest.raises(ValueError, match="not retryable"):
            await execute_with_retry(raises_value_error, policy, "bad_step")

        # Should only be called once — no retries for non-retryable exceptions
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_none_policy_means_no_retry(self):
        func = AsyncMock(return_value="direct")
        result = await execute_with_retry(func, None, "no_retry_step")
        assert result == "direct"
        assert func.call_count == 1

    @pytest.mark.asyncio
    async def test_none_policy_does_not_catch_exceptions(self):
        async def raises():
            raise RuntimeError("boom")

        with pytest.raises(RuntimeError, match="boom"):
            await execute_with_retry(raises, None, "no_retry_step")

    @pytest.mark.asyncio
    async def test_timeout_error_is_retryable(self):
        call_count = 0

        async def timeout_then_ok():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise TimeoutError("timed out")
            return "recovered"

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(TimeoutError,),
        )

        result = await execute_with_retry(timeout_then_ok, policy, "timeout_step")
        assert result == "recovered"
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_os_error_is_retryable(self):
        """OSError is retried under the default retryable_exceptions."""
        call_count = 0

        async def oserr_then_ok():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise OSError("network unreachable")
            return "ok"

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
        )

        result = await execute_with_retry(oserr_then_ok, policy, "oserr_step")
        assert result == "ok"
        assert call_count == 2