68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
"""Step-level retry with exponential backoff for Pipeline execution"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import random
|
|
from dataclasses import dataclass
|
|
from typing import Awaitable, Callable
|
|
|
|
# Module-level logger named after this module; execute_with_retry uses it to
# report retry warnings and final-failure errors.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class StepRetryPolicy:
    """Retry policy for pipeline steps.

    Attributes:
        max_attempts: Total number of tries, including the first call.
        base_delay: Delay in seconds used for attempt 0.
        max_delay: Cap in seconds applied to the exponential term
            (jitter is added on top of the capped value).
        exponential_base: Factor the delay is multiplied by per attempt.
        jitter: When true, add a random extra of up to 10% of the delay.
        retryable_exceptions: Exception types that trigger a retry.
    """

    max_attempts: int = 3
    base_delay: float = 1.0
    max_delay: float = 60.0
    exponential_base: float = 2.0
    jitter: bool = True
    # ConnectionError and TimeoutError are subclasses of OSError, so OSError
    # alone already matches them; they are kept listed explicitly.
    retryable_exceptions: tuple[type[Exception], ...] = (
        ConnectionError,
        TimeoutError,
        OSError,
    )

    def calculate_delay(self, attempt: int) -> float:
        """Calculate delay for given attempt number (0-based).

        The delay is ``base_delay * exponential_base ** attempt`` capped at
        ``max_delay``. With ``jitter`` enabled, a uniformly random amount
        between 0 and 10% of the capped delay is added afterwards, so the
        result may exceed ``max_delay`` by up to 10%.

        Args:
            attempt: Zero-based index of the attempt that just failed.

        Returns:
            The number of seconds to wait before the next attempt.
        """
        try:
            backoff = self.base_delay * (self.exponential_base**attempt)
        except OverflowError:
            # The power (or the int -> float conversion) overflows for large
            # attempt numbers; the growth is past any finite cap by then.
            backoff = self.max_delay

        delay = min(backoff, self.max_delay)
        if self.jitter:
            delay += random.uniform(0, delay * 0.1)
        return delay
|
|
|
|
|
|
async def execute_with_retry(
    func: Callable[..., Awaitable[object]],
    retry_policy: "StepRetryPolicy | None" = None,
    step_name: str = "",
) -> object:
    """Execute a function with retry policy.

    ``func`` is called with no arguments and awaited. Without a policy it is
    awaited exactly once. With a policy, exceptions listed in
    ``retry_policy.retryable_exceptions`` trigger another attempt after
    sleeping for ``retry_policy.calculate_delay(attempt)`` seconds, up to
    ``retry_policy.max_attempts`` attempts in total. At least one attempt is
    always made, even if ``max_attempts`` is zero or negative.

    Args:
        func: Zero-argument callable returning an awaitable.
        retry_policy: Policy controlling attempts and delays, or ``None`` to
            disable retries.
        step_name: Name used in log messages.

    Returns:
        Whatever the awaited ``func()`` returns on the first success.

    Raises:
        Exception: The retryable exception from the final attempt, or any
            non-retryable exception immediately when it occurs.
    """
    if retry_policy is None:
        return await func()

    # range(0) would skip the call entirely and leave nothing to raise, so a
    # non-positive max_attempts is treated as a single attempt.
    max_attempts = max(1, retry_policy.max_attempts)
    final_attempt = max_attempts - 1

    for attempt in range(max_attempts):
        try:
            return await func()
        except retry_policy.retryable_exceptions as e:
            # Exceptions outside retryable_exceptions are not caught here and
            # propagate to the caller immediately.
            if attempt >= final_attempt:
                logger.error(
                    f"Step '{step_name}' failed after {max_attempts} attempts: {e}"
                )
                raise
            delay = retry_policy.calculate_delay(attempt)
            logger.warning(
                f"Step '{step_name}' failed (attempt {attempt + 1}/{max_attempts}): {e}. "
                f"Retrying in {delay:.1f}s"
            )
            await asyncio.sleep(delay)

    # Every iteration either returns, raises, or continues to a later
    # iteration; the final one cannot continue.
    raise AssertionError("unreachable: retry loop exited without result")
|