fischer-agentkit/tests/unit/test_pipeline_retry.py

211 lines
6.7 KiB
Python

"""Tests for Pipeline step-level retry with exponential backoff"""
import asyncio
from unittest.mock import AsyncMock
import pytest
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
class TestStepRetryPolicy:
    """Checks the attributes exposed by a freshly constructed StepRetryPolicy."""

    def test_default_values(self):
        """A policy built without arguments carries the expected defaults."""
        default_policy = StepRetryPolicy()
        expected_numbers = (
            ("max_attempts", 3),
            ("base_delay", 1.0),
            ("max_delay", 60.0),
            ("exponential_base", 2.0),
        )
        for attr_name, expected in expected_numbers:
            assert getattr(default_policy, attr_name) == expected
        # Identity check on purpose: jitter must be the bool True, not just truthy.
        assert default_policy.jitter is True
        assert default_policy.retryable_exceptions == (
            ConnectionError,
            TimeoutError,
            OSError,
        )

    def test_custom_values(self):
        """Each keyword passed to the constructor is readable back unchanged."""
        overrides = {
            "max_attempts": 5,
            "base_delay": 0.5,
            "max_delay": 30.0,
            "exponential_base": 3.0,
            "jitter": False,
            "retryable_exceptions": (ValueError,),
        }
        custom_policy = StepRetryPolicy(**overrides)
        numeric_attrs = ("max_attempts", "base_delay", "max_delay", "exponential_base")
        for attr_name in numeric_attrs:
            assert getattr(custom_policy, attr_name) == overrides[attr_name]
        # Identity check on purpose: jitter must be the bool False, not just falsy.
        assert custom_policy.jitter is False
        assert custom_policy.retryable_exceptions == (ValueError,)
class TestCalculateDelay:
    """Exercises StepRetryPolicy.calculate_delay: growth, capping and jitter."""

    def test_delay_increases_exponentially(self):
        """With jitter off and exponential_base=2.0 the delay doubles per attempt."""
        policy = StepRetryPolicy(base_delay=1.0, exponential_base=2.0, jitter=False)
        expected_by_attempt = ((0, 1.0), (1, 2.0), (2, 4.0), (3, 8.0))
        for attempt, expected in expected_by_attempt:
            assert policy.calculate_delay(attempt) == expected

    def test_delay_respects_max_delay(self):
        """The computed delay never exceeds max_delay, however large the attempt."""
        policy = StepRetryPolicy(
            base_delay=1.0, exponential_base=2.0, max_delay=10.0, jitter=False
        )
        expected_by_attempt = (
            (0, 1.0),
            (1, 2.0),
            (2, 4.0),
            (3, 8.0),
            (4, 10.0),  # uncapped value would be 16.0
            (10, 10.0),  # still capped for a much later attempt
        )
        for attempt, expected in expected_by_attempt:
            assert policy.calculate_delay(attempt) == expected

    def test_jitter_adds_randomness(self):
        """With jitter on, attempt-0 delays stay inside a fixed band above base_delay."""
        policy = StepRetryPolicy(base_delay=1.0, exponential_base=2.0, jitter=True)
        lower_bound = 1.0
        # NOTE(review): if jitter adds at most 10% of the delay, the tight upper
        # bound would be 1.1; the looser 1.0 * 1.1 * 1.1 (about 1.21) bound is
        # kept unchanged here -- confirm against calculate_delay's implementation.
        upper_bound = 1.0 * 1.1 * 1.1
        # Collect every sample first, then validate them.
        samples = [policy.calculate_delay(0) for _ in range(100)]
        # NOTE(review): only the range is checked; variation between samples is
        # not asserted.
        for sample in samples:
            assert lower_bound <= sample < upper_bound

    def test_no_jitter_gives_exact_delay(self):
        """With jitter off the result is deterministic for a non-default base."""
        policy = StepRetryPolicy(base_delay=2.0, exponential_base=3.0, jitter=False)
        for attempt, expected in zip((0, 1, 2), (2.0, 6.0, 18.0)):
            assert policy.calculate_delay(attempt) == expected
class TestExecuteWithRetry:
    """End-to-end checks of execute_with_retry against small async callables."""

    @pytest.mark.asyncio
    async def test_success_on_first_attempt(self):
        """A callable that succeeds immediately is called exactly once."""
        func = AsyncMock(return_value="ok")
        policy = StepRetryPolicy(max_attempts=3, jitter=False, base_delay=0.01)

        outcome = await execute_with_retry(func, policy, "test_step")

        assert outcome == "ok"
        assert func.call_count == 1

    @pytest.mark.asyncio
    async def test_success_after_retry(self):
        """Two retryable failures followed by a success yield the success value."""
        attempts = []  # one entry per invocation of flaky_func

        async def flaky_func():
            attempts.append("call")
            if len(attempts) < 3:
                raise ConnectionError("temporary failure")
            return "ok"

        policy = StepRetryPolicy(
            max_attempts=5,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        outcome = await execute_with_retry(flaky_func, policy, "flaky_step")

        assert outcome == "ok"
        assert len(attempts) == 3

    @pytest.mark.asyncio
    async def test_all_attempts_exhausted_raises(self):
        """When every attempt fails, the retryable exception reaches the caller."""
        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError,),
        )

        async def always_fails():
            raise ConnectionError("permanent failure")

        with pytest.raises(ConnectionError, match="permanent failure"):
            await execute_with_retry(always_fails, policy, "failing_step")

    @pytest.mark.asyncio
    async def test_non_retryable_exception_propagates_immediately(self):
        """An exception outside retryable_exceptions is raised after one call."""
        attempts = []  # one entry per invocation of raises_value_error

        async def raises_value_error():
            attempts.append("call")
            raise ValueError("not retryable")

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(ConnectionError, TimeoutError),
        )

        with pytest.raises(ValueError, match="not retryable"):
            await execute_with_retry(raises_value_error, policy, "bad_step")

        # A single invocation means no retry was attempted for this exception.
        assert len(attempts) == 1

    @pytest.mark.asyncio
    async def test_none_policy_means_no_retry(self):
        """Passing None as the policy still runs the callable and returns its value."""
        func = AsyncMock(return_value="direct")

        outcome = await execute_with_retry(func, None, "no_retry_step")

        assert outcome == "direct"
        assert func.call_count == 1

    @pytest.mark.asyncio
    async def test_none_policy_does_not_catch_exceptions(self):
        """With a None policy, an exception from the callable is not swallowed."""

        async def raises():
            raise RuntimeError("boom")

        with pytest.raises(RuntimeError, match="boom"):
            await execute_with_retry(raises, None, "no_retry_step")

    @pytest.mark.asyncio
    async def test_timeout_error_is_retryable(self):
        """A TimeoutError listed as retryable is retried and the step recovers."""
        attempts = []  # one entry per invocation of timeout_then_ok

        async def timeout_then_ok():
            attempts.append("call")
            if len(attempts) == 1:
                raise TimeoutError("timed out")
            return "recovered"

        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
            retryable_exceptions=(TimeoutError,),
        )

        outcome = await execute_with_retry(timeout_then_ok, policy, "timeout_step")

        assert outcome == "recovered"
        assert len(attempts) == 2

    @pytest.mark.asyncio
    async def test_os_error_is_retryable(self):
        """An OSError is retried when the policy uses its default exception set."""
        attempts = []  # one entry per invocation of oserr_then_ok

        async def oserr_then_ok():
            attempts.append("call")
            if len(attempts) == 1:
                raise OSError("network unreachable")
            return "ok"

        # retryable_exceptions is deliberately omitted so the default applies.
        policy = StepRetryPolicy(
            max_attempts=3,
            base_delay=0.01,
            jitter=False,
        )

        outcome = await execute_with_retry(oserr_then_ok, policy, "oserr_step")

        assert outcome == "ok"
        assert len(attempts) == 2