"""RetryPolicy and CircuitBreaker for LLM provider reliability"""

import asyncio
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable

from agentkit.core.exceptions import LLMProviderError

logger = logging.getLogger(__name__)


@dataclass
class RetryConfig:
    """Retry policy configuration.

    ``max_retries`` is the number of extra attempts after the first one
    (negative values are treated as 0 by ``RetryPolicy``). The delay before
    retry ``n`` (0-based) is ``base_delay * exponential_base**n`` seconds,
    capped at ``max_delay``. ``retryable_status_codes`` are the HTTP codes
    treated as transient.
    """

    max_retries: int = 3
    base_delay: float = 1.0
    max_delay: float = 30.0
    exponential_base: float = 2.0
    retryable_status_codes: set[int] = field(
        default_factory=lambda: {429, 500, 502, 503, 529}
    )


class CircuitState(Enum):
    """Circuit breaker states"""

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


@dataclass
class CircuitBreakerConfig:
    """Circuit breaker configuration.

    ``failure_threshold`` failures while CLOSED open the circuit;
    ``recovery_timeout`` is the number of seconds (measured with
    ``time.monotonic``) after the last failure before an OPEN circuit moves
    to HALF_OPEN; ``half_open_max`` is how many trial calls may be in flight
    while HALF_OPEN.
    """

    failure_threshold: int = 5
    recovery_timeout: float = 60.0
    half_open_max: int = 1


class CircuitOpenError(LLMProviderError):
    """Raised when the circuit breaker is open"""

    def __init__(self, provider: str):
        super().__init__(provider, "Circuit breaker is open")


def _is_retryable_error(error: Exception, retryable_status_codes: set[int]) -> bool:
    """Return True if *error* looks like a transient provider failure.

    Only ``LLMProviderError`` instances are considered. Their ``message`` is
    scanned for an ``"HTTP <code>"`` marker with a code from
    *retryable_status_codes*, or for ``"connect"`` (case-insensitive), which
    matches e.g. "Connection reset" and "failed to connect". Any other
    exception type is treated as non-retryable.
    """
    if not isinstance(error, LLMProviderError):
        return False
    message = error.message
    # Check for HTTP status code pattern in error message
    if any(f"HTTP {code}" in message for code in retryable_status_codes):
        return True
    # Connection errors are retryable; the lowercase match also covers
    # messages containing "Connection".
    return "connect" in message.lower()


class RetryPolicy:
    """Retry with exponential backoff for transient failures"""

    def __init__(self, config: RetryConfig | None = None):
        self._config = config or RetryConfig()

    async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
        """Await ``fn(*args, **kwargs)``, retrying on retryable errors.

        Makes up to ``max_retries + 1`` attempts (always at least one). Between
        attempts it sleeps ``min(base_delay * exponential_base**attempt,
        max_delay)`` seconds and logs a warning.

        Returns:
            The value returned by the awaited call.

        Raises:
            The first non-retryable exception immediately, or the last
            retryable exception once all attempts are used up.
        """
        config = self._config
        # Clamp: a negative setting previously skipped the loop entirely and
        # fell through to ``raise None`` (a TypeError).
        max_retries = max(config.max_retries, 0)
        for attempt in range(max_retries + 1):
            try:
                return await fn(*args, **kwargs)
            except Exception as e:
                if not _is_retryable_error(e, config.retryable_status_codes):
                    raise
                if attempt >= max_retries:
                    raise
                delay = min(
                    config.base_delay * (config.exponential_base**attempt),
                    config.max_delay,
                )
                logger.warning(
                    f"Retry attempt {attempt + 1}/{max_retries} "
                    f"after {delay:.1f}s: {e}"
                )
                await asyncio.sleep(delay)
        # Unreachable: the final attempt either returns or re-raises.
        raise AssertionError("RetryPolicy.execute: retry loop exited unexpectedly")


class CircuitBreaker:
    """Circuit breaker to prevent cascading failures.

    CLOSED: calls pass through. Failures increment a counter that is reset by
    a success; reaching ``failure_threshold`` opens the circuit.
    OPEN: calls are rejected with ``CircuitOpenError`` until
    ``recovery_timeout`` seconds have passed since the last recorded failure.
    HALF_OPEN: up to ``half_open_max`` trial calls are admitted at a time; a
    success closes the circuit, a failure reopens it.

    State is mutated without locks. The check-and-update steps in ``execute``
    contain no ``await``, so this is safe when used from a single asyncio
    event loop; it is not safe to share across threads.
    """

    def __init__(self, config: CircuitBreakerConfig | None = None, provider: str = ""):
        self._config = config or CircuitBreakerConfig()
        self._provider = provider
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: float = 0.0
        # Number of trial calls admitted since entering HALF_OPEN.
        self._half_open_count = 0

    @property
    def state(self) -> CircuitState:
        """Current circuit state, with automatic OPEN -> HALF_OPEN transition."""
        if self._state == CircuitState.OPEN:
            elapsed = time.monotonic() - self._last_failure_time
            if elapsed >= self._config.recovery_timeout:
                self._state = CircuitState.HALF_OPEN
                self._half_open_count = 0
                logger.info(f"Circuit breaker for '{self._provider}' transitioned to HALF_OPEN")
        return self._state

    def _on_success(self) -> None:
        """Record a successful call: close a HALF_OPEN circuit, reset the counter."""
        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.CLOSED
            logger.info(f"Circuit breaker for '{self._provider}' transitioned to CLOSED")
        if self._state == CircuitState.CLOSED:
            self._failure_count = 0

    def _on_failure(self) -> None:
        """Record a failed call and open the circuit when warranted."""
        self._failure_count += 1
        self._last_failure_time = time.monotonic()
        if self._state == CircuitState.HALF_OPEN:
            # A failed trial call sends the circuit straight back to OPEN.
            self._state = CircuitState.OPEN
            logger.warning(f"Circuit breaker for '{self._provider}' transitioned back to OPEN")
        elif (
            self._state == CircuitState.CLOSED
            and self._failure_count >= self._config.failure_threshold
        ):
            self._state = CircuitState.OPEN
            logger.warning(
                f"Circuit breaker for '{self._provider}' transitioned to OPEN "
                f"after {self._failure_count} failures"
            )
        # If already OPEN (a call admitted earlier failed late), only the
        # refreshed _last_failure_time changes, extending the OPEN period;
        # no duplicate "transitioned" log is emitted.

    async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
        """Await ``fn(*args, **kwargs)`` through the circuit breaker.

        Returns:
            The value returned by the awaited call.

        Raises:
            CircuitOpenError: if the circuit is OPEN, or HALF_OPEN with all
                trial slots in use.
            Any exception raised by ``fn`` (recorded as a failure when it is
            an ``Exception``).
        """
        current_state = self.state
        if current_state == CircuitState.OPEN:
            raise CircuitOpenError(self._provider)
        is_probe = current_state == CircuitState.HALF_OPEN
        if is_probe:
            if self._half_open_count >= self._config.half_open_max:
                raise CircuitOpenError(self._provider)
            self._half_open_count += 1
        try:
            result = await fn(*args, **kwargs)
        except Exception:
            self._on_failure()
            raise
        except BaseException:
            # asyncio.CancelledError (and KeyboardInterrupt etc.) derive from
            # BaseException and say nothing about provider health, so they are
            # not counted as failures. A cancelled probe must still return its
            # HALF_OPEN slot; otherwise every slot stays taken and the breaker
            # rejects all calls forever, because ``state`` only auto-transitions
            # out of OPEN.
            if is_probe and self._state == CircuitState.HALF_OPEN:
                self._half_open_count = max(self._half_open_count - 1, 0)
            raise
        self._on_success()
        return result