162 lines
5.4 KiB
Python
162 lines
5.4 KiB
Python
"""RetryPolicy and CircuitBreaker for LLM provider reliability"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Callable
|
|
|
|
from agentkit.core.exceptions import LLMProviderError
|
|
|
|
# Module-level logger; RetryPolicy and CircuitBreaker report retries and state transitions through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RetryConfig:
    """Retry policy configuration.

    Attributes:
        max_retries: Number of retries after the first attempt, so a call is
            made at most ``max_retries + 1`` times. Must be >= 0.
        base_delay: Delay in seconds before the first retry. Must be >= 0.
        max_delay: Upper bound in seconds on any single retry delay.
            Must be >= 0.
        exponential_base: Factor by which the delay grows for each further retry.
        retryable_status_codes: HTTP status codes that mark a provider error
            as transient when they appear in its message as ``HTTP <code>``.

    Raises:
        ValueError: If ``max_retries``, ``base_delay`` or ``max_delay`` is negative.
    """

    max_retries: int = 3
    base_delay: float = 1.0
    max_delay: float = 30.0
    exponential_base: float = 2.0
    # default_factory gives every instance its own set, so mutating one
    # config's codes cannot leak into another config.
    retryable_status_codes: set[int] = field(default_factory=lambda: {429, 500, 502, 503, 529})

    def __post_init__(self) -> None:
        # A negative retry count would leave the retry loop with zero attempts,
        # and negative delays have no meaning; fail fast at construction time.
        if self.max_retries < 0:
            raise ValueError(f"max_retries must be >= 0, got {self.max_retries}")
        if self.base_delay < 0:
            raise ValueError(f"base_delay must be >= 0, got {self.base_delay}")
        if self.max_delay < 0:
            raise ValueError(f"max_delay must be >= 0, got {self.max_delay}")
|
|
|
|
|
|
class CircuitState(Enum):
    """States a circuit breaker moves between.

    CLOSED forwards calls, OPEN rejects them, and HALF_OPEN lets a bounded
    number of trial calls through to decide whether to close again.
    """

    # Normal operation: calls reach the provider.
    CLOSED = "closed"
    # Tripped: calls are refused without reaching the provider.
    OPEN = "open"
    # Recovering: limited trial calls are admitted.
    HALF_OPEN = "half_open"
|
|
|
|
|
|
@dataclass
class CircuitBreakerConfig:
    """Circuit breaker configuration.

    Attributes:
        failure_threshold: Number of failures recorded while CLOSED (the count
            resets on success) that trips the breaker to OPEN. Must be >= 1.
        recovery_timeout: Seconds after the most recent failure before an OPEN
            breaker starts admitting trial calls. Must be >= 0.
        half_open_max: Maximum number of trial calls admitted while HALF_OPEN.
            Must be >= 1.

    Raises:
        ValueError: If any field is outside its valid range.
    """

    failure_threshold: int = 5
    recovery_timeout: float = 60.0
    half_open_max: int = 1

    def __post_init__(self) -> None:
        if self.failure_threshold < 1:
            raise ValueError(f"failure_threshold must be >= 1, got {self.failure_threshold}")
        if self.recovery_timeout < 0:
            raise ValueError(f"recovery_timeout must be >= 0, got {self.recovery_timeout}")
        # With zero trial slots a HALF_OPEN breaker refuses every call, so it can
        # never observe the success it needs to close again.
        if self.half_open_max < 1:
            raise ValueError(f"half_open_max must be >= 1, got {self.half_open_max}")
|
|
|
|
|
|
class CircuitOpenError(LLMProviderError):
    """Signals that a call was refused because the circuit breaker is not admitting requests.

    Raised by CircuitBreaker.execute while the breaker is OPEN, or while it is
    HALF_OPEN and all trial slots are already taken.
    """

    def __init__(self, provider: str):
        reason = "Circuit breaker is open"
        super().__init__(provider, reason)
|
|
|
|
|
|
def _mentions_status_code(message: str, code: int) -> bool:
    """Return True if *message* contains ``HTTP <code>`` as a complete status code.

    A bare substring test would let a configured code such as ``50`` match
    ``"HTTP 500"``; requiring that the character after the code is not a digit
    prevents that.
    """
    token = f"HTTP {code}"
    start = message.find(token)
    while start != -1:
        end = start + len(token)
        if end == len(message) or not message[end].isdigit():
            return True
        start = message.find(token, start + 1)
    return False


def _is_retryable_error(error: Exception, retryable_status_codes: set[int]) -> bool:
    """Check if an error is retryable based on its type and message.

    Only LLMProviderError instances are considered; any other exception type is
    treated as non-retryable. A provider error is retryable when its message
    mentions one of *retryable_status_codes* as ``HTTP <code>``, or contains
    ``"connect"`` in any letter case (e.g. "Connection refused").

    Args:
        error: The exception raised by the attempted call.
        retryable_status_codes: Status codes that indicate a transient failure.

    Returns:
        True if the call should be retried, False otherwise.
    """
    if not isinstance(error, LLMProviderError):
        return False

    message = error.message
    if any(_mentions_status_code(message, code) for code in retryable_status_codes):
        return True

    # Connection errors are retryable. The case-insensitive check also covers
    # "Connection", so a separate case-sensitive test is unnecessary.
    return "connect" in message.lower()
|
|
|
|
|
|
class RetryPolicy:
    """Retry an async call with exponential backoff on transient failures.

    The call is made once and then retried up to ``config.max_retries`` more
    times, but only while ``_is_retryable_error`` classifies the raised error
    as retryable. The delay before retry ``k`` (0-based) is
    ``base_delay * exponential_base**k``, capped at ``max_delay``.
    """

    def __init__(self, config: RetryConfig | None = None):
        self._config = config or RetryConfig()

    async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
        """Await ``fn(*args, **kwargs)``, retrying on retryable errors.

        Args:
            fn: Async callable to invoke; its result is awaited.
            *args: Positional arguments forwarded to *fn*.
            **kwargs: Keyword arguments forwarded to *fn*.

        Returns:
            The value returned by the first successful attempt.

        Raises:
            Exception: The error raised by *fn*, re-raised unchanged, when it is
                not retryable or when no retries remain.
        """
        config = self._config
        # Always make at least one attempt. With a negative max_retries the loop
        # would otherwise run zero times and fall through to ``raise None``.
        max_retries = max(0, config.max_retries)

        for attempt in range(max_retries + 1):
            try:
                return await fn(*args, **kwargs)
            except Exception as e:
                if not _is_retryable_error(e, config.retryable_status_codes):
                    raise
                if attempt >= max_retries:
                    raise

                delay = min(
                    config.base_delay * (config.exponential_base**attempt),
                    config.max_delay,
                )
                logger.warning(
                    "Retry attempt %d/%d after %.1fs: %s",
                    attempt + 1,
                    max_retries,
                    delay,
                    e,
                )
                await asyncio.sleep(delay)

        # Unreachable: the last attempt always either returns or re-raises.
        raise AssertionError("RetryPolicy.execute finished without returning or raising")
|
|
|
|
|
|
class CircuitBreaker:
    """Circuit breaker that stops calling a failing provider for a while.

    State machine:
      * CLOSED: calls pass through. When ``failure_threshold`` failures have
        been recorded (a success resets the count), the breaker trips to OPEN.
      * OPEN: calls are rejected with CircuitOpenError until
        ``recovery_timeout`` seconds have passed since the last failure; the
        next read of ``state`` then moves the breaker to HALF_OPEN.
      * HALF_OPEN: up to ``half_open_max`` trial calls are admitted. A success
        closes the breaker and a failure re-opens it.

    Any ``Exception`` raised by the wrapped call counts as a failure. State is
    kept in plain instance attributes without locking.
    """

    def __init__(self, config: CircuitBreakerConfig | None = None, provider: str = ""):
        self._config = config or CircuitBreakerConfig()
        self._provider = provider  # used in CircuitOpenError and log messages
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: float = 0.0  # time.monotonic() of the last failure
        self._half_open_count = 0  # trial calls admitted in the current HALF_OPEN period

    @property
    def state(self) -> CircuitState:
        """Current circuit state, with automatic OPEN -> HALF_OPEN transition."""
        if self._state == CircuitState.OPEN:
            elapsed = time.monotonic() - self._last_failure_time
            if elapsed >= self._config.recovery_timeout:
                self._state = CircuitState.HALF_OPEN
                self._half_open_count = 0
                logger.info("Circuit breaker for '%s' transitioned to HALF_OPEN", self._provider)
        return self._state

    def _on_success(self) -> None:
        """Record a successful call: close a HALF_OPEN breaker and reset the failure count."""
        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.CLOSED
            logger.info("Circuit breaker for '%s' transitioned to CLOSED", self._provider)
        if self._state == CircuitState.CLOSED:
            self._failure_count = 0

    def _on_failure(self) -> None:
        """Record a failed call and trip or re-open the breaker as needed."""
        self._failure_count += 1
        self._last_failure_time = time.monotonic()

        if self._state == CircuitState.HALF_OPEN:
            # A failed trial call means the provider has not recovered yet.
            self._state = CircuitState.OPEN
            logger.warning("Circuit breaker for '%s' transitioned back to OPEN", self._provider)
        elif (
            self._state == CircuitState.CLOSED
            and self._failure_count >= self._config.failure_threshold
        ):
            self._state = CircuitState.OPEN
            logger.warning(
                "Circuit breaker for '%s' transitioned to OPEN after %d failures",
                self._provider,
                self._failure_count,
            )
        # If the breaker is already OPEN (a call that was in flight when it
        # tripped), only the refreshed failure time changes; this extends the
        # recovery window and is not a new transition, so nothing is logged.

    async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
        """Await ``fn(*args, **kwargs)`` through the circuit breaker.

        Args:
            fn: Async callable to invoke; its result is awaited.
            *args: Positional arguments forwarded to *fn*.
            **kwargs: Keyword arguments forwarded to *fn*.

        Returns:
            The value returned by *fn*.

        Raises:
            CircuitOpenError: If the breaker is OPEN, or HALF_OPEN with every
                trial slot in use.
            Exception: Whatever *fn* raises, re-raised after it is recorded as a failure.
        """
        current_state = self.state

        if current_state == CircuitState.OPEN:
            raise CircuitOpenError(self._provider)

        is_probe = current_state == CircuitState.HALF_OPEN
        if is_probe:
            if self._half_open_count >= self._config.half_open_max:
                raise CircuitOpenError(self._provider)
            self._half_open_count += 1

        try:
            result = await fn(*args, **kwargs)
        except Exception:
            self._on_failure()
            raise
        except BaseException:
            # Cancellation (asyncio.CancelledError, e.g. from a timeout) is
            # neither success nor failure. Give the trial slot back: otherwise
            # the breaker would stay HALF_OPEN with every slot consumed, and
            # since only OPEN times out, it would reject all calls forever.
            if is_probe and self._state == CircuitState.HALF_OPEN and self._half_open_count > 0:
                self._half_open_count -= 1
            raise

        self._on_success()
        return result
|