"""Tests for Pipeline step-level retry with exponential backoff"""

import asyncio
from unittest.mock import AsyncMock

import pytest

from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry


class TestStepRetryPolicy:
    """Checks the attribute values exposed by a constructed StepRetryPolicy."""

    def test_default_values(self) -> None:
        """A policy built with no arguments exposes the expected defaults."""
        expected_defaults = {
            "max_attempts": 3,
            "base_delay": 1.0,
            "max_delay": 60.0,
            "exponential_base": 2.0,
            "retryable_exceptions": (ConnectionError, TimeoutError, OSError),
        }
        policy = StepRetryPolicy()

        for field_name, expected in expected_defaults.items():
            assert getattr(policy, field_name) == expected
        # Identity check: jitter must be the bool True, not merely truthy.
        assert policy.jitter is True

    def test_custom_values(self) -> None:
        """Every constructor keyword is reflected back on the instance."""
        overrides = {
            "max_attempts": 5,
            "base_delay": 0.5,
            "max_delay": 30.0,
            "exponential_base": 3.0,
            "jitter": False,
            "retryable_exceptions": (ValueError,),
        }
        policy = StepRetryPolicy(**overrides)

        for field_name, expected in overrides.items():
            assert getattr(policy, field_name) == expected
        # Identity check: jitter must be the bool False, not merely falsy.
        assert policy.jitter is False


class TestCalculateDelay:
    """Checks the values returned by StepRetryPolicy.calculate_delay."""

    def test_delay_increases_exponentially(self) -> None:
        """With jitter off, the delay doubles for each successive attempt index."""
        policy = StepRetryPolicy(base_delay=1.0, exponential_base=2.0, jitter=False)

        for attempt, expected in enumerate([1.0, 2.0, 4.0, 8.0]):
            assert policy.calculate_delay(attempt) == expected

    def test_delay_respects_max_delay(self) -> None:
        """The computed delay never exceeds the configured max_delay."""
        policy = StepRetryPolicy(
            base_delay=1.0,
            exponential_base=2.0,
            max_delay=10.0,
            jitter=False,
        )
        # attempt index -> expected delay; 16.0 and beyond are clamped to 10.0
        expected_by_attempt = {
            0: 1.0,
            1: 2.0,
            2: 4.0,
            3: 8.0,
            4: 10.0,  # capped
            10: 10.0,  # still capped
        }

        for attempt, expected in expected_by_attempt.items():
            assert policy.calculate_delay(attempt) == expected

    def test_jitter_adds_randomness(self) -> None:
        """With jitter on, every sampled delay stays inside the accepted window.

        The window asserted here is [1.0, 1.0 * 1.1 * 1.1) for attempt 0.
        """
        policy = StepRetryPolicy(base_delay=1.0, exponential_base=2.0, jitter=True)

        samples = [policy.calculate_delay(0) for _ in range(100)]

        # NOTE(review): the original author's note says jitter adds up to 10%
        # of the delay; the exact formula is not visible from this file, so
        # confirm the upper bound against the implementation.
        assert all(1.0 <= sample < 1.0 * 1.1 * 1.1 for sample in samples)

    def test_no_jitter_gives_exact_delay(self) -> None:
        """With jitter off, delays are exactly base_delay * exponential_base**n."""
        policy = StepRetryPolicy(base_delay=2.0, exponential_base=3.0, jitter=False)

        for attempt, expected in enumerate([2.0, 6.0, 18.0]):
            assert policy.calculate_delay(attempt) == expected


class TestExecuteWithRetry:
    """Integration tests that drive execute_with_retry with async callables."""

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self) -> None:
        """A callable that succeeds immediately is invoked exactly once."""
        policy = StepRetryPolicy(max_attempts=3, jitter=False, base_delay=0.01)
        step = AsyncMock(return_value="ok")

        outcome = await execute_with_retry(step, policy, "test_step")

        assert outcome == "ok"
        assert step.call_count == 1

    @pytest.mark.asyncio
    async def test_success_after_retry(self) -> None:
        """Two retryable failures followed by a success yield the success value."""
        call_log = []

        async def flaky_func():
            call_log.append(None)
            # Fail on the first two invocations, succeed on the third.
            if len(call_log) < 3:
                raise ConnectionError("temporary failure")
            return "ok"

        policy = StepRetryPolicy(
            max_attempts=5,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        outcome = await execute_with_retry(flaky_func, policy, "flaky_step")

        assert outcome == "ok"
        assert len(call_log) == 3

    @pytest.mark.asyncio
    async def test_all_attempts_exhausted_raises(self) -> None:
        """A callable that always fails ends with the exception reaching the caller."""

        async def always_fails():
            raise ConnectionError("permanent failure")

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        with pytest.raises(ConnectionError, match="permanent failure"):
            await execute_with_retry(always_fails, policy, "failing_step")

    @pytest.mark.asyncio
    async def test_non_retryable_exception_propagates_immediately(self) -> None:
        """An exception outside retryable_exceptions is raised without any retry."""
        call_log = []

        async def raises_value_error():
            call_log.append(None)
            raise ValueError("not retryable")

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError, TimeoutError),
        )

        with pytest.raises(ValueError, match="not retryable"):
            await execute_with_retry(raises_value_error, policy, "bad_step")

        # Exactly one invocation: ValueError is not in the retryable tuple.
        assert len(call_log) == 1

    @pytest.mark.asyncio
    async def test_none_policy_means_no_retry(self) -> None:
        """Passing None as the policy still runs the callable and returns its value."""
        step = AsyncMock(return_value="direct")

        outcome = await execute_with_retry(step, None, "no_retry_step")

        assert outcome == "direct"
        assert step.call_count == 1

    @pytest.mark.asyncio
    async def test_none_policy_does_not_catch_exceptions(self) -> None:
        """Passing None as the policy lets the callable's exception propagate."""

        async def raises():
            raise RuntimeError("boom")

        with pytest.raises(RuntimeError, match="boom"):
            await execute_with_retry(raises, None, "no_retry_step")

    @pytest.mark.asyncio
    async def test_timeout_error_is_retryable(self) -> None:
        """A TimeoutError on the first call is retried and the second call wins."""
        call_log = []

        async def timeout_then_ok():
            call_log.append(None)
            if len(call_log) == 1:
                raise TimeoutError("timed out")
            return "recovered"

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(TimeoutError,),
        )

        outcome = await execute_with_retry(timeout_then_ok, policy, "timeout_step")

        assert outcome == "recovered"
        assert len(call_log) == 2

    @pytest.mark.asyncio
    async def test_os_error_is_retryable(self) -> None:
        """An OSError on the first call is retried under the default exception set."""
        call_log = []

        async def oserr_then_ok():
            call_log.append(None)
            if len(call_log) == 1:
                raise OSError("network unreachable")
            return "ok"

        # retryable_exceptions is left at its default on purpose.
        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
        )

        outcome = await execute_with_retry(oserr_then_ok, policy, "oserr_step")

        assert outcome == "ok"
        assert len(call_log) == 2