"""Tests for Pipeline step-level retry with exponential backoff."""
import asyncio
from unittest.mock import AsyncMock
import pytest
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry


class TestStepRetryPolicy:
    """Construction of StepRetryPolicy: built-in defaults and explicit overrides."""

    def test_default_values(self):
        # A policy built without arguments must expose the expected defaults.
        defaults = StepRetryPolicy()

        assert defaults.max_attempts == 3
        assert defaults.base_delay == 1.0
        assert defaults.max_delay == 60.0
        assert defaults.exponential_base == 2.0
        assert defaults.jitter is True
        assert defaults.retryable_exceptions == (ConnectionError, TimeoutError, OSError)

    def test_custom_values(self):
        # Every keyword passed to the constructor must be stored as given.
        overrides = {
            "max_attempts": 5,
            "base_delay": 0.5,
            "max_delay": 30.0,
            "exponential_base": 3.0,
            "jitter": False,
            "retryable_exceptions": (ValueError,),
        }
        custom = StepRetryPolicy(**overrides)

        assert custom.max_attempts == 5
        assert custom.base_delay == 0.5
        assert custom.max_delay == 30.0
        assert custom.exponential_base == 3.0
        assert custom.jitter is False
        assert custom.retryable_exceptions == (ValueError,)


class TestCalculateDelay:
    """StepRetryPolicy.calculate_delay tests."""

    def test_delay_increases_exponentially(self):
        # With jitter off, delay(n) == base_delay * exponential_base ** n.
        # The default max_delay (60.0) is never reached for these attempts.
        policy = StepRetryPolicy(base_delay=1.0, exponential_base=2.0, jitter=False)
        assert policy.calculate_delay(0) == 1.0
        assert policy.calculate_delay(1) == 2.0
        assert policy.calculate_delay(2) == 4.0
        assert policy.calculate_delay(3) == 8.0

    def test_delay_respects_max_delay(self):
        policy = StepRetryPolicy(
            base_delay=1.0, exponential_base=2.0, max_delay=10.0, jitter=False
        )
        assert policy.calculate_delay(0) == 1.0
        assert policy.calculate_delay(1) == 2.0
        assert policy.calculate_delay(2) == 4.0
        assert policy.calculate_delay(3) == 8.0
        assert policy.calculate_delay(4) == 10.0  # uncapped would be 16.0
        assert policy.calculate_delay(10) == 10.0  # still capped

    def test_jitter_adds_randomness(self):
        policy = StepRetryPolicy(base_delay=1.0, exponential_base=2.0, jitter=True)
        delays = [policy.calculate_delay(0) for _ in range(100)]

        # Jitter must never shorten the delay below the un-jittered value (1.0).
        # The upper bound is 1.0 * 1.1 * 1.1 = 1.21, looser than the nominal
        # "up to 10% extra" (1.1).
        # NOTE(review): confirm whether this slack is intended or the bound
        # should be tightened to 1.1.
        for d in delays:
            assert d >= 1.0
            assert d < 1.0 * 1.1 * 1.1

        # The bounds alone would also accept a constant, jitter-free delay;
        # 100 samples must not all coincide if jitter is actually random.
        assert len(set(delays)) > 1

    def test_no_jitter_gives_exact_delay(self):
        policy = StepRetryPolicy(base_delay=2.0, exponential_base=3.0, jitter=False)
        assert policy.calculate_delay(0) == 2.0
        assert policy.calculate_delay(1) == 6.0
        assert policy.calculate_delay(2) == 18.0


class TestExecuteWithRetry:
    """execute_with_retry integration tests."""

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self):
        func = AsyncMock(return_value="ok")
        policy = StepRetryPolicy(max_attempts=3, jitter=False, base_delay=0.01)

        result = await execute_with_retry(func, policy, "test_step")

        assert result == "ok"
        assert func.call_count == 1

    @pytest.mark.asyncio
    async def test_success_after_retry(self):
        call_count = 0

        async def flaky_func():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise ConnectionError("temporary failure")
            return "ok"

        policy = StepRetryPolicy(
            max_attempts=5,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        result = await execute_with_retry(flaky_func, policy, "flaky_step")

        assert result == "ok"
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_all_attempts_exhausted_raises(self):
        call_count = 0

        async def always_fails():
            nonlocal call_count
            call_count += 1
            raise ConnectionError("permanent failure")

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        with pytest.raises(ConnectionError, match="permanent failure"):
            await execute_with_retry(always_fails, policy, "failing_step")

        # Without counting calls, an implementation that never retried would
        # also pass this test. This assumes max_attempts counts total calls
        # (including the first), not extra retries.
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_non_retryable_exception_propagates_immediately(self):
        call_count = 0

        async def raises_value_error():
            nonlocal call_count
            call_count += 1
            raise ValueError("not retryable")

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError, TimeoutError),
        )

        with pytest.raises(ValueError, match="not retryable"):
            await execute_with_retry(raises_value_error, policy, "bad_step")

        # Should only be called once — no retries for non-retryable exceptions
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_none_policy_means_no_retry(self):
        func = AsyncMock(return_value="direct")
        result = await execute_with_retry(func, None, "no_retry_step")
        assert result == "direct"
        assert func.call_count == 1

    @pytest.mark.asyncio
    async def test_none_policy_does_not_catch_exceptions(self):
        call_count = 0

        async def raises():
            nonlocal call_count
            call_count += 1
            raise RuntimeError("boom")

        with pytest.raises(RuntimeError, match="boom"):
            await execute_with_retry(raises, None, "no_retry_step")

        # With no policy, the failing step must run exactly once, not be retried.
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_timeout_error_is_retryable(self):
        call_count = 0

        async def timeout_then_ok():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise TimeoutError("timed out")
            return "recovered"

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(TimeoutError,),
        )

        result = await execute_with_retry(timeout_then_ok, policy, "timeout_step")
        assert result == "recovered"
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_os_error_is_retryable(self):
        call_count = 0

        async def oserr_then_ok():
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise OSError("network unreachable")
            return "ok"

        # retryable_exceptions is intentionally omitted: this exercises the
        # default tuple, which includes OSError.
        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
        )

        result = await execute_with_retry(oserr_then_ok, policy, "oserr_step")
        assert result == "ok"
        assert call_count == 2