# fischer-agentkit/tests/unit/test_llm_retry.py
#
# 525 lines
# 18 KiB
# Python

"""RetryPolicy and CircuitBreaker tests"""
import asyncio
import time
from unittest.mock import AsyncMock
import pytest
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.retry import (
CircuitBreaker,
CircuitBreakerConfig,
CircuitOpenError,
CircuitState,
RetryConfig,
RetryPolicy,
)
# ---------------------------------------------------------------------------
# RetryPolicy tests
# ---------------------------------------------------------------------------
class TestRetryPolicy:
    """RetryPolicy unit tests.

    Every test hands an async callable to ``RetryPolicy.execute`` and checks
    the returned value or raised error together with the number of times the
    callable was invoked.
    """

    async def test_success_on_first_attempt(self):
        """No retry needed when the call succeeds immediately."""
        policy = RetryPolicy(RetryConfig(max_retries=3))
        fn = AsyncMock(return_value="ok")

        result = await policy.execute(fn)

        assert result == "ok"
        fn.assert_called_once()

    async def test_retry_success_on_second_attempt(self):
        """Retryable error on 1st attempt, success on 2nd."""
        policy = RetryPolicy(RetryConfig(max_retries=3, base_delay=0.01))
        fn = AsyncMock(
            side_effect=[
                LLMProviderError("openai", "HTTP 429: Rate limit"),
                "ok",
            ]
        )

        result = await policy.execute(fn)

        assert result == "ok"
        assert fn.call_count == 2

    async def test_retry_exhausted(self):
        """All attempts fail with retryable errors → final error raised."""
        policy = RetryPolicy(RetryConfig(max_retries=2, base_delay=0.01))
        fn = AsyncMock(
            side_effect=LLMProviderError("openai", "HTTP 500: Internal error")
        )

        with pytest.raises(LLMProviderError) as exc_info:
            await policy.execute(fn)

        assert "500" in str(exc_info.value)
        # max_retries=2 means 3 total attempts (initial + 2 retries)
        assert fn.call_count == 3

    async def test_non_retryable_error_raises_immediately(self):
        """Non-retryable errors (400, 401, 403) should not be retried."""
        policy = RetryPolicy(RetryConfig(max_retries=3, base_delay=0.01))
        fn = AsyncMock(
            side_effect=LLMProviderError("openai", "HTTP 401: Unauthorized")
        )

        with pytest.raises(LLMProviderError) as exc_info:
            await policy.execute(fn)

        assert "401" in str(exc_info.value)
        fn.assert_called_once()

    async def test_exponential_backoff_timing(self):
        """Verify delays increase exponentially."""
        policy = RetryPolicy(
            RetryConfig(max_retries=3, base_delay=0.05, exponential_base=2.0)
        )
        call_times: list[float] = []

        async def failing_fn():
            # Record when each attempt starts so the gaps can be measured.
            call_times.append(time.monotonic())
            raise LLMProviderError("openai", "HTTP 429: Rate limit")

        with pytest.raises(LLMProviderError):
            await policy.execute(failing_fn)

        # 4 calls total (initial + 3 retries)
        assert len(call_times) == 4

        # Gaps between consecutive attempts: ~0.05s, ~0.1s, ~0.2s.
        delay1, delay2, delay3 = (
            later - earlier
            for earlier, later in zip(call_times, call_times[1:])
        )
        # Only lower bounds are asserted (with some slack) so that a slow
        # machine cannot make the test fail.
        assert delay1 >= 0.04  # ~0.05
        assert delay2 >= 0.08  # ~0.10
        assert delay3 >= 0.15  # ~0.20

    async def test_connection_error_is_retryable(self):
        """Connection errors should be retried."""
        policy = RetryPolicy(RetryConfig(max_retries=2, base_delay=0.01))
        fn = AsyncMock(
            side_effect=[
                LLMProviderError("openai", "Connection refused"),
                "ok",
            ]
        )

        result = await policy.execute(fn)

        assert result == "ok"
        assert fn.call_count == 2

    async def test_custom_retryable_status_codes(self):
        """Custom retryable status codes should be respected."""
        config = RetryConfig(
            max_retries=1,
            base_delay=0.01,
            retryable_status_codes={502, 503},
        )
        policy = RetryPolicy(config)
        fn = AsyncMock(
            side_effect=LLMProviderError("openai", "HTTP 429: Rate limit")
        )

        # 429 is NOT in the custom set, so it should not be retried
        with pytest.raises(LLMProviderError):
            await policy.execute(fn)

        fn.assert_called_once()

    async def test_no_retry_when_config_is_none(self):
        """RetryPolicy built without a config runs a succeeding call once."""
        policy = RetryPolicy()
        fn = AsyncMock(return_value="ok")

        result = await policy.execute(fn)

        assert result == "ok"
        # A successful first attempt must not trigger any retry.
        fn.assert_called_once()
# ---------------------------------------------------------------------------
# CircuitBreaker tests
# ---------------------------------------------------------------------------
class TestCircuitBreaker:
    """CircuitBreaker unit tests.

    The tests drive ``CircuitBreaker.execute`` with succeeding or failing
    async callables and observe ``cb.state`` to follow the
    CLOSED -> OPEN -> HALF_OPEN -> CLOSED/OPEN transitions.
    """

    async def test_closed_allows_requests(self):
        """In CLOSED state, requests pass through."""
        cb = CircuitBreaker(CircuitBreakerConfig(), provider="test")
        fn = AsyncMock(return_value="ok")

        result = await cb.execute(fn)

        assert result == "ok"
        assert cb.state == CircuitState.CLOSED

    async def test_closed_to_open_transition(self):
        """After failure_threshold failures, circuit transitions to OPEN."""
        cb = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=3),
            provider="test",
        )
        fn = AsyncMock(side_effect=LLMProviderError("test", "HTTP 500: Error"))

        for _ in range(3):
            with pytest.raises(LLMProviderError):
                await cb.execute(fn)

        assert cb.state == CircuitState.OPEN

    async def test_open_rejects_requests(self):
        """In OPEN state, requests are rejected with CircuitOpenError."""
        cb = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=1),
            provider="test",
        )
        fn = AsyncMock(side_effect=LLMProviderError("test", "HTTP 500: Error"))

        # Trip the circuit
        with pytest.raises(LLMProviderError):
            await cb.execute(fn)
        assert cb.state == CircuitState.OPEN

        # Next request should be rejected
        with pytest.raises(CircuitOpenError):
            await cb.execute(AsyncMock(return_value="ok"))

    async def test_open_to_half_open_after_recovery_timeout(self):
        """After recovery_timeout, circuit transitions from OPEN to HALF_OPEN."""
        cb = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=1, recovery_timeout=0.05),
            provider="test",
        )
        fn = AsyncMock(side_effect=LLMProviderError("test", "HTTP 500: Error"))

        # Trip the circuit
        with pytest.raises(LLMProviderError):
            await cb.execute(fn)
        assert cb.state == CircuitState.OPEN

        # Wait for recovery timeout (slightly longer than the configured 0.05s)
        await asyncio.sleep(0.06)

        # Should now be HALF_OPEN
        assert cb.state == CircuitState.HALF_OPEN

    async def test_half_open_to_closed_on_success(self):
        """In HALF_OPEN, a successful request transitions to CLOSED."""
        cb = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=1, recovery_timeout=0.05),
            provider="test",
        )

        # Trip the circuit
        fn_fail = AsyncMock(side_effect=LLMProviderError("test", "HTTP 500: Error"))
        with pytest.raises(LLMProviderError):
            await cb.execute(fn_fail)

        # Wait for recovery
        await asyncio.sleep(0.06)
        assert cb.state == CircuitState.HALF_OPEN

        # Successful request should transition to CLOSED
        fn_ok = AsyncMock(return_value="ok")
        result = await cb.execute(fn_ok)
        assert result == "ok"
        assert cb.state == CircuitState.CLOSED

    async def test_half_open_to_open_on_failure(self):
        """In HALF_OPEN, a failed request transitions back to OPEN."""
        cb = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=1, recovery_timeout=0.05),
            provider="test",
        )

        # Trip the circuit
        fn_fail = AsyncMock(side_effect=LLMProviderError("test", "HTTP 500: Error"))
        with pytest.raises(LLMProviderError):
            await cb.execute(fn_fail)

        # Wait for recovery
        await asyncio.sleep(0.06)
        assert cb.state == CircuitState.HALF_OPEN

        # Failed request should transition back to OPEN
        with pytest.raises(LLMProviderError):
            await cb.execute(fn_fail)
        assert cb.state == CircuitState.OPEN

    async def test_half_open_max_limits_requests(self):
        """With half_open_max=1, exercise a passing and a failing probe.

        First cycle: the single HALF_OPEN probe succeeds and the circuit
        closes. Second cycle: the probe fails, the circuit re-opens, and the
        next call is rejected with CircuitOpenError.

        NOTE(review): only sequential probes are issued here; a second probe
        within the same HALF_OPEN window is not attempted, so the per-cycle
        limit itself is not directly asserted.
        """
        cb = CircuitBreaker(
            CircuitBreakerConfig(
                failure_threshold=1,
                recovery_timeout=0.05,
                half_open_max=1,
            ),
            provider="test",
        )

        # Trip the circuit
        fn_fail = AsyncMock(side_effect=LLMProviderError("test", "HTTP 500: Error"))
        with pytest.raises(LLMProviderError):
            await cb.execute(fn_fail)

        # Wait for recovery
        await asyncio.sleep(0.06)
        assert cb.state == CircuitState.HALF_OPEN

        # First half-open request succeeds → circuit closes
        fn_ok = AsyncMock(return_value="ok")
        result = await cb.execute(fn_ok)
        assert result == "ok"
        assert cb.state == CircuitState.CLOSED

        # Now trip it again to test half_open_max with a failing probe.
        # The private counter is zeroed explicitly so that the single failure
        # below (failure_threshold=1) is what opens the circuit.
        cb._failure_count = 0
        with pytest.raises(LLMProviderError):
            await cb.execute(fn_fail)
        assert cb.state == CircuitState.OPEN

        # Wait for recovery again
        await asyncio.sleep(0.06)
        assert cb.state == CircuitState.HALF_OPEN

        # The half_open slot is used by a failing request
        with pytest.raises(LLMProviderError):
            await cb.execute(fn_fail)

        # Circuit goes back to OPEN, so next request should be rejected
        assert cb.state == CircuitState.OPEN
        with pytest.raises(CircuitOpenError):
            await cb.execute(AsyncMock(return_value="ok"))

    async def test_failure_count_resets_on_success(self):
        """A success while still CLOSED resets the failure count to zero."""
        cb = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=2, recovery_timeout=0.05),
            provider="test",
        )

        # Cause 1 failure (not enough to trip with failure_threshold=2)
        fn_fail = AsyncMock(side_effect=LLMProviderError("test", "HTTP 500: Error"))
        with pytest.raises(LLMProviderError):
            await cb.execute(fn_fail)
        assert cb.state == CircuitState.CLOSED
        assert cb._failure_count == 1

        # Successful request resets failure count
        fn_ok = AsyncMock(return_value="ok")
        await cb.execute(fn_ok)
        assert cb._failure_count == 0
# ---------------------------------------------------------------------------
# Integration: Provider with RetryPolicy + CircuitBreaker
# ---------------------------------------------------------------------------
class TestProviderRetryIntegration:
    """Integration tests: providers wired with retry and/or circuit breaker.

    Each test replaces the provider's ``_chat_impl`` with a stub and calls
    ``provider.chat`` to observe how the configured wrappers behave.
    """

    async def test_openai_provider_with_retry_succeeds_after_retry(self):
        """OpenAICompatibleProvider with retry config retries on 429."""
        from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage
        from agentkit.llm.providers.openai import OpenAICompatibleProvider

        provider = OpenAICompatibleProvider(
            api_key="test-key",
            retry_config=RetryConfig(max_retries=2, base_delay=0.01),
        )

        # Every invocation of the stub is recorded; the first one fails.
        received = []

        async def flaky_chat_impl(request):
            received.append(request)
            if len(received) == 1:
                raise LLMProviderError("openai", "HTTP 429: Rate limit")
            usage = TokenUsage(prompt_tokens=5, completion_tokens=3)
            return LLMResponse(
                content="retried ok",
                model="gpt-4o-mini",
                usage=usage,
            )

        provider._chat_impl = flaky_chat_impl

        user_message = {"role": "user", "content": "Hi"}
        llm_request = LLMRequest(messages=[user_message], model="gpt-4o-mini")

        reply = await provider.chat(llm_request)

        assert reply.content == "retried ok"
        assert len(received) == 2

    async def test_anthropic_provider_with_circuit_breaker(self):
        """AnthropicProvider with circuit breaker rejects when open."""
        from agentkit.llm.protocol import LLMRequest
        from agentkit.llm.providers.anthropic import AnthropicProvider

        provider = AnthropicProvider(
            api_key="test-key",
            circuit_breaker_config=CircuitBreakerConfig(failure_threshold=1),
        )

        # A stub that always fails, so a single call trips the breaker.
        always_failing = AsyncMock(
            side_effect=LLMProviderError("anthropic", "HTTP 500: Error")
        )
        provider._chat_impl = always_failing

        user_message = {"role": "user", "content": "Hi"}
        llm_request = LLMRequest(
            messages=[user_message],
            model="claude-sonnet-4-20250514",
        )

        # First call surfaces the provider error and opens the circuit.
        with pytest.raises(LLMProviderError):
            await provider.chat(llm_request)

        # Second call is short-circuited by the breaker.
        with pytest.raises(CircuitOpenError):
            await provider.chat(llm_request)

    async def test_provider_without_retry_config_works_as_before(self):
        """Provider without retry/circuit_breaker config works normally."""
        from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage
        from agentkit.llm.providers.openai import OpenAICompatibleProvider

        provider = OpenAICompatibleProvider(api_key="test-key")

        # Neither wrapper is created when no config is supplied.
        assert provider._retry_policy is None
        assert provider._circuit_breaker is None

        canned_reply = LLMResponse(
            content="no retry",
            model="gpt-4o-mini",
            usage=TokenUsage(prompt_tokens=5, completion_tokens=3),
        )
        provider._chat_impl = AsyncMock(return_value=canned_reply)

        user_message = {"role": "user", "content": "Hi"}
        llm_request = LLMRequest(messages=[user_message], model="gpt-4o-mini")

        reply = await provider.chat(llm_request)

        assert reply.content == "no retry"

    async def test_provider_with_both_retry_and_circuit_breaker(self):
        """Provider with both retry and circuit breaker wraps correctly."""
        from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage
        from agentkit.llm.providers.openai import OpenAICompatibleProvider

        provider = OpenAICompatibleProvider(
            api_key="test-key",
            retry_config=RetryConfig(max_retries=2, base_delay=0.01),
            circuit_breaker_config=CircuitBreakerConfig(failure_threshold=5),
        )

        # Every invocation of the stub is recorded; the first two fail.
        received = []

        async def flaky_chat_impl(request):
            received.append(request)
            if len(received) <= 2:
                raise LLMProviderError("openai", "HTTP 429: Rate limit")
            usage = TokenUsage(prompt_tokens=5, completion_tokens=3)
            return LLMResponse(
                content="success after retry",
                model="gpt-4o-mini",
                usage=usage,
            )

        provider._chat_impl = flaky_chat_impl

        user_message = {"role": "user", "content": "Hi"}
        llm_request = LLMRequest(messages=[user_message], model="gpt-4o-mini")

        reply = await provider.chat(llm_request)

        assert reply.content == "success after retry"
        assert len(received) == 3
# ---------------------------------------------------------------------------
# Config integration tests
# ---------------------------------------------------------------------------
class TestConfigIntegration:
    """Tests for LLMConfig.from_dict with and without retry/circuit_breaker."""

    def test_from_dict_with_retry_and_circuit_breaker(self):
        """YAML config with retry and circuit_breaker sections loads correctly."""
        from agentkit.llm.config import LLMConfig

        retry_section = {"max_retries": 5, "base_delay": 2.0}
        breaker_section = {"failure_threshold": 3, "recovery_timeout": 30.0}
        openai_section = {
            "api_key": "sk-test",
            "base_url": "https://api.openai.com/v1",
            "retry": retry_section,
            "circuit_breaker": breaker_section,
        }

        loaded = LLMConfig.from_dict({"providers": {"openai": openai_section}})
        openai_conf = loaded.providers["openai"]

        # Both optional sections are present and carry the supplied values.
        assert openai_conf.retry is not None
        assert openai_conf.retry.max_retries == 5
        assert openai_conf.retry.base_delay == 2.0

        assert openai_conf.circuit_breaker is not None
        assert openai_conf.circuit_breaker.failure_threshold == 3
        assert openai_conf.circuit_breaker.recovery_timeout == 30.0

    def test_from_dict_without_retry_or_circuit_breaker(self):
        """Config without retry/circuit_breaker sections loads with None."""
        from agentkit.llm.config import LLMConfig

        openai_section = {
            "api_key": "sk-test",
            "base_url": "https://api.openai.com/v1",
        }

        loaded = LLMConfig.from_dict({"providers": {"openai": openai_section}})
        openai_conf = loaded.providers["openai"]

        # Omitted sections are represented as None on the provider config.
        assert openai_conf.retry is None
        assert openai_conf.circuit_breaker is None