fischer-agentkit/src/agentkit/llm/retry.py

162 lines
5.4 KiB
Python

"""RetryPolicy and CircuitBreaker for LLM provider reliability"""
import asyncio
import logging
import re
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable

from agentkit.core.exceptions import LLMProviderError
# Module logger: RetryPolicy logs retry attempts and CircuitBreaker logs its
# state transitions through it.
logger = logging.getLogger(__name__)
# HTTP status codes treated as transient unless a caller supplies its own set.
# Kept immutable here; every RetryConfig gets its own mutable copy.
_DEFAULT_RETRYABLE_STATUS_CODES = frozenset((429, 500, 502, 503, 529))


@dataclass
class RetryConfig:
    """Settings for RetryPolicy's capped exponential backoff.

    The n-th retry (0-based) waits ``base_delay * exponential_base ** n``
    seconds, never more than ``max_delay``.
    """

    # Retries allowed after the first attempt.
    max_retries: int = 3
    # Delay in seconds before the first retry.
    base_delay: float = 1.0
    # Upper bound in seconds for any single backoff.
    max_delay: float = 30.0
    # Growth factor applied per retry.
    exponential_base: float = 2.0
    # Status codes that mark an LLMProviderError as retryable.
    retryable_status_codes: set[int] = field(
        default_factory=lambda: set(_DEFAULT_RETRYABLE_STATUS_CODES)
    )
class CircuitState(Enum):
    """States of a CircuitBreaker."""
    # Calls pass through; failures are counted until the threshold opens the circuit.
    CLOSED = "closed"
    # Calls are rejected until the recovery timeout has elapsed.
    OPEN = "open"
    # A limited number of trial calls probe whether the provider has recovered.
    HALF_OPEN = "half_open"
@dataclass
class CircuitBreakerConfig:
    """Tuning knobs for CircuitBreaker."""
    # Failures recorded since the last success that move a CLOSED breaker to OPEN.
    failure_threshold: int = 5
    # Seconds (measured with time.monotonic) after the last failure before an
    # OPEN breaker starts admitting trial calls as HALF_OPEN.
    recovery_timeout: float = 60.0
    # Maximum trial calls admitted during one HALF_OPEN period.
    half_open_max: int = 1
class CircuitOpenError(LLMProviderError):
    """Raised by CircuitBreaker.execute when it rejects a call without running it.

    Subclasses LLMProviderError, so handlers for provider errors also catch it.
    """
    def __init__(self, provider: str):
        # NOTE(review): assumes LLMProviderError.__init__(provider, message) — confirm.
        super().__init__(provider, "Circuit breaker is open")
# Finds "HTTP <digits>" markers in an error message. Capturing the whole digit
# run means e.g. "HTTP 5000" is read as 5000 rather than matching 500.
_HTTP_STATUS_RE = re.compile(r"HTTP ([0-9]+)")


def _is_retryable_error(error: Exception, retryable_status_codes: set[int]) -> bool:
    """Return True if ``error`` looks like a transient failure worth retrying.

    Retryable:
    * a built-in ``ConnectionError`` (or subclass such as
      ``ConnectionResetError``) raised directly by the call;
    * an ``LLMProviderError`` whose message contains ``HTTP <code>`` with
      ``<code>`` in ``retryable_status_codes``;
    * an ``LLMProviderError`` whose message mentions "connect"
      (case-insensitive), e.g. "Connection refused".

    Anything else is treated as permanent.

    Args:
        error: The exception raised by the attempted call.
        retryable_status_codes: HTTP status codes considered transient.
    """
    # Connection errors are retryable, whether raw or wrapped by the provider.
    if isinstance(error, ConnectionError):
        return True
    if not isinstance(error, LLMProviderError):
        return False
    # NOTE(review): .message is presumably a str; tolerate None defensively.
    message = str(error.message or "")
    if any(int(code) in retryable_status_codes for code in _HTTP_STATUS_RE.findall(message)):
        return True
    # Lower-casing also covers "Connection"/"CONNECT" spellings.
    return "connect" in message.lower()
class RetryPolicy:
    """Retry an async callable with capped exponential backoff.

    Only errors that ``_is_retryable_error`` classifies as transient are
    retried; any other exception propagates on its first occurrence. At most
    ``max_retries + 1`` attempts are made, and always at least one, even when
    ``max_retries`` is configured as a negative number.
    """

    def __init__(self, config: "RetryConfig | None" = None):
        self._config = config if config is not None else RetryConfig()

    def _compute_delay(self, attempt: int) -> float:
        """Return the backoff in seconds after failed attempt ``attempt`` (0-based).

        ``base_delay * exponential_base ** attempt``, capped at ``max_delay``.
        """
        cfg = self._config
        try:
            raw = cfg.base_delay * (cfg.exponential_base**attempt)
        except OverflowError:
            # float ** int raises instead of returning inf once the result
            # leaves the float range; the cap would apply either way.
            return cfg.max_delay
        return min(raw, cfg.max_delay)

    async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
        """Await ``fn(*args, **kwargs)``, retrying transient failures.

        Args:
            fn: Coroutine function to call; awaited on every attempt.
            *args, **kwargs: Forwarded unchanged to ``fn``.

        Returns:
            Whatever ``fn`` returns on the first successful attempt.

        Raises:
            The first non-retryable exception, or the exception from the final
            attempt, re-raised unchanged.
        """
        # Clamp so a negative setting still yields one attempt instead of
        # skipping the loop and raising None.
        max_retries = max(self._config.max_retries, 0)
        attempt = 0
        while True:
            try:
                return await fn(*args, **kwargs)
            except Exception as exc:
                # On the final attempt the error is re-raised whatever its kind,
                # so only classify it when another attempt is possible.
                if attempt >= max_retries or not _is_retryable_error(
                    exc, self._config.retryable_status_codes
                ):
                    raise
                delay = self._compute_delay(attempt)
                logger.warning(
                    "Retry attempt %d/%d after %.1fs: %s",
                    attempt + 1,
                    max_retries,
                    delay,
                    exc,
                )
            # Back off outside the except block so the handled exception (and
            # its traceback) is released, and a cancellation during the sleep
            # is not chained onto it.
            await asyncio.sleep(delay)
            attempt += 1
class CircuitBreaker:
    """Circuit breaker to prevent cascading failures.

    * CLOSED: calls pass through. Each failure increments a counter that a
      success resets; reaching ``failure_threshold`` opens the circuit.
    * OPEN: calls are rejected with ``CircuitOpenError`` until
      ``recovery_timeout`` seconds (``time.monotonic``) have passed since the
      last recorded failure, after which the circuit reads as HALF_OPEN.
    * HALF_OPEN: at most ``half_open_max`` trial calls are admitted; a trial
      success closes the circuit and a trial failure re-opens it.

    NOTE(review): state lives in plain attributes with no lock; this presumably
    relies on all calls coming from a single event loop — confirm with callers.
    """

    def __init__(self, config: CircuitBreakerConfig | None = None, provider: str = ""):
        self._config = config if config is not None else CircuitBreakerConfig()
        # Provider name used in log messages and CircuitOpenError.
        self._provider = provider
        self._state = CircuitState.CLOSED
        # Failures recorded since the last success.
        self._failure_count = 0
        # time.monotonic() timestamp of the most recent failure.
        self._last_failure_time: float = 0.0
        # Trial calls admitted since the circuit last became HALF_OPEN.
        self._half_open_count = 0

    @property
    def state(self) -> CircuitState:
        """Current circuit state.

        Reading this has a side effect: an OPEN circuit whose recovery timeout
        has elapsed moves to HALF_OPEN and its trial counter is reset.
        """
        if self._state == CircuitState.OPEN:
            elapsed = time.monotonic() - self._last_failure_time
            if elapsed >= self._config.recovery_timeout:
                self._state = CircuitState.HALF_OPEN
                self._half_open_count = 0
                logger.info(
                    "Circuit breaker for '%s' transitioned to HALF_OPEN", self._provider
                )
        return self._state

    def _on_success(self) -> None:
        """Record a successful call.

        A HALF_OPEN circuit closes; once CLOSED the failure count is cleared.
        A late success while OPEN changes nothing.
        """
        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.CLOSED
            logger.info("Circuit breaker for '%s' transitioned to CLOSED", self._provider)
        if self._state == CircuitState.CLOSED:
            self._failure_count = 0

    def _on_failure(self) -> None:
        """Record a failed call and open the circuit when warranted."""
        self._failure_count += 1
        self._last_failure_time = time.monotonic()
        if self._state == CircuitState.HALF_OPEN:
            # A failed trial means the provider has not recovered yet.
            self._state = CircuitState.OPEN
            logger.warning(
                "Circuit breaker for '%s' transitioned back to OPEN", self._provider
            )
        elif (
            self._state == CircuitState.CLOSED
            and self._failure_count >= self._config.failure_threshold
        ):
            self._state = CircuitState.OPEN
            logger.warning(
                "Circuit breaker for '%s' transitioned to OPEN after %d failures",
                self._provider,
                self._failure_count,
            )
        # Already OPEN: a call admitted before the circuit (re)opened failed
        # late. The refreshed timestamp extends the cool-down, but there is no
        # transition to log.

    async def execute(self, fn: Callable, *args: object, **kwargs: object) -> object:
        """Await ``fn(*args, **kwargs)`` through the circuit breaker.

        Returns:
            Whatever ``fn`` returns.

        Raises:
            CircuitOpenError: the circuit is OPEN, or HALF_OPEN with all trial
                slots in use; ``fn`` is not called.
            Exception: anything ``fn`` raises, re-raised after being recorded
                as a failure.
        """
        current_state = self.state
        if current_state == CircuitState.OPEN:
            raise CircuitOpenError(self._provider)
        is_trial = current_state == CircuitState.HALF_OPEN
        if is_trial:
            if self._half_open_count >= self._config.half_open_max:
                raise CircuitOpenError(self._provider)
            self._half_open_count += 1
        try:
            result = await fn(*args, **kwargs)
        except Exception:
            self._on_failure()
            raise
        except BaseException:
            # Cancellation (e.g. asyncio.CancelledError from a timeout) says
            # nothing about provider health. Hand the trial slot back;
            # otherwise, with half_open_max=1, the breaker would stay HALF_OPEN
            # and reject every call forever.
            if (
                is_trial
                and self._state == CircuitState.HALF_OPEN
                and self._half_open_count > 0
            ):
                self._half_open_count -= 1
            raise
        self._on_success()
        return result