525 lines
18 KiB
Python
525 lines
18 KiB
Python
"""RetryPolicy and CircuitBreaker tests"""
|
|
|
|
import asyncio
|
|
import time
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.exceptions import LLMProviderError
|
|
from agentkit.llm.retry import (
|
|
CircuitBreaker,
|
|
CircuitBreakerConfig,
|
|
CircuitOpenError,
|
|
CircuitState,
|
|
RetryConfig,
|
|
RetryPolicy,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RetryPolicy tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRetryPolicy:
    """Unit tests for ``RetryPolicy.execute``.

    Every test wraps an async callable in a policy and checks the value that
    comes back (or the error that propagates) together with the number of
    times the callable was invoked.
    """

    async def test_success_on_first_attempt(self):
        """A call that succeeds immediately is invoked exactly once."""
        operation = AsyncMock(return_value="ok")
        policy = RetryPolicy(RetryConfig(max_retries=3))

        outcome = await policy.execute(operation)

        assert outcome == "ok"
        operation.assert_called_once()

    async def test_retry_success_on_second_attempt(self):
        """An HTTP 429 on the first attempt is retried; the second succeeds."""
        scripted_outcomes = [
            LLMProviderError("openai", "HTTP 429: Rate limit"),
            "ok",
        ]
        operation = AsyncMock(side_effect=scripted_outcomes)
        policy = RetryPolicy(RetryConfig(max_retries=3, base_delay=0.01))

        outcome = await policy.execute(operation)

        assert outcome == "ok"
        assert operation.call_count == 2

    async def test_retry_exhausted(self):
        """When every attempt fails with HTTP 500 the error is re-raised."""
        max_retries = 2
        operation = AsyncMock(
            side_effect=LLMProviderError("openai", "HTTP 500: Internal error")
        )
        policy = RetryPolicy(RetryConfig(max_retries=max_retries, base_delay=0.01))

        with pytest.raises(LLMProviderError) as raised:
            await policy.execute(operation)

        assert "500" in str(raised.value)
        # One initial attempt plus ``max_retries`` retries.
        assert operation.call_count == max_retries + 1

    async def test_non_retryable_error_raises_immediately(self):
        """An HTTP 401 propagates after a single attempt (no retry)."""
        operation = AsyncMock(
            side_effect=LLMProviderError("openai", "HTTP 401: Unauthorized")
        )
        policy = RetryPolicy(RetryConfig(max_retries=3, base_delay=0.01))

        with pytest.raises(LLMProviderError) as raised:
            await policy.execute(operation)

        assert "401" in str(raised.value)
        operation.assert_called_once()

    async def test_exponential_backoff_timing(self):
        """Gaps between consecutive attempts grow with the exponential base."""
        policy = RetryPolicy(
            RetryConfig(max_retries=3, base_delay=0.05, exponential_base=2.0)
        )
        timestamps: list[float] = []

        async def always_rate_limited():
            # Record when each attempt starts so the gaps can be measured.
            timestamps.append(time.monotonic())
            raise LLMProviderError("openai", "HTTP 429: Rate limit")

        with pytest.raises(LLMProviderError):
            await policy.execute(always_rate_limited)

        # Initial attempt + 3 retries.
        assert len(timestamps) == 4

        gaps = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]
        # Nominal delays are ~0.05s, ~0.10s and ~0.20s; the minimums checked
        # here are set slightly below those values.
        for gap, minimum in zip(gaps, (0.04, 0.08, 0.15)):
            assert gap >= minimum

    async def test_connection_error_is_retryable(self):
        """A "Connection refused" provider error is retried."""
        scripted_outcomes = [
            LLMProviderError("openai", "Connection refused"),
            "ok",
        ]
        operation = AsyncMock(side_effect=scripted_outcomes)
        policy = RetryPolicy(RetryConfig(max_retries=2, base_delay=0.01))

        outcome = await policy.execute(operation)

        assert outcome == "ok"
        assert operation.call_count == 2

    async def test_custom_retryable_status_codes(self):
        """Only status codes listed in ``retryable_status_codes`` are retried."""
        custom_codes = {502, 503}
        policy = RetryPolicy(
            RetryConfig(
                max_retries=1,
                base_delay=0.01,
                retryable_status_codes=custom_codes,
            )
        )
        operation = AsyncMock(
            side_effect=LLMProviderError("openai", "HTTP 429: Rate limit")
        )

        # 429 is not part of the custom set, so a single attempt is expected.
        with pytest.raises(LLMProviderError):
            await policy.execute(operation)
        operation.assert_called_once()

    async def test_no_retry_when_config_is_none(self):
        """A policy built without an explicit config runs a successful call once."""
        operation = AsyncMock(return_value="ok")
        policy = RetryPolicy()

        outcome = await policy.execute(operation)

        assert outcome == "ok"
        # Mirrors test_success_on_first_attempt: a successful call must not
        # be repeated, whatever the default configuration is.
        operation.assert_called_once()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CircuitBreaker tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCircuitBreaker:
    """Unit tests for the ``CircuitBreaker`` state machine.

    The tests drive a breaker through ``execute`` and observe the
    CLOSED / OPEN / HALF_OPEN transitions via its ``state`` attribute.
    """

    @staticmethod
    def _failing_call():
        """Return an async mock that raises a provider HTTP 500 on every call."""
        return AsyncMock(side_effect=LLMProviderError("test", "HTTP 500: Error"))

    @staticmethod
    async def _expect_provider_error(breaker, call):
        """Run ``call`` through ``breaker`` and require an ``LLMProviderError``."""
        with pytest.raises(LLMProviderError):
            await breaker.execute(call)

    async def test_closed_allows_requests(self):
        """A CLOSED breaker forwards the call and stays CLOSED."""
        breaker = CircuitBreaker(CircuitBreakerConfig(), provider="test")
        passing_call = AsyncMock(return_value="ok")

        outcome = await breaker.execute(passing_call)

        assert outcome == "ok"
        assert breaker.state == CircuitState.CLOSED

    async def test_closed_to_open_transition(self):
        """Reaching ``failure_threshold`` failures opens the circuit."""
        threshold = 3
        breaker = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=threshold),
            provider="test",
        )
        failing_call = self._failing_call()

        for _attempt in range(threshold):
            await self._expect_provider_error(breaker, failing_call)

        assert breaker.state == CircuitState.OPEN

    async def test_open_rejects_requests(self):
        """An OPEN breaker raises ``CircuitOpenError`` for new requests."""
        breaker = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=1),
            provider="test",
        )

        # A single failure is enough to trip the circuit here.
        await self._expect_provider_error(breaker, self._failing_call())
        assert breaker.state == CircuitState.OPEN

        # The follow-up request is refused.
        with pytest.raises(CircuitOpenError):
            await breaker.execute(AsyncMock(return_value="ok"))

    async def test_open_to_half_open_after_recovery_timeout(self):
        """Once ``recovery_timeout`` has elapsed, OPEN is reported as HALF_OPEN."""
        breaker = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=1, recovery_timeout=0.05),
            provider="test",
        )

        # Trip the circuit.
        await self._expect_provider_error(breaker, self._failing_call())
        assert breaker.state == CircuitState.OPEN

        # Sleep just past the recovery timeout.
        await asyncio.sleep(0.06)

        assert breaker.state == CircuitState.HALF_OPEN

    async def test_half_open_to_closed_on_success(self):
        """A successful probe in HALF_OPEN closes the circuit."""
        breaker = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=1, recovery_timeout=0.05),
            provider="test",
        )

        # Trip the circuit.
        await self._expect_provider_error(breaker, self._failing_call())

        # Sleep just past the recovery timeout.
        await asyncio.sleep(0.06)
        assert breaker.state == CircuitState.HALF_OPEN

        # The successful probe should bring the breaker back to CLOSED.
        outcome = await breaker.execute(AsyncMock(return_value="ok"))

        assert outcome == "ok"
        assert breaker.state == CircuitState.CLOSED

    async def test_half_open_to_open_on_failure(self):
        """A failed probe in HALF_OPEN re-opens the circuit."""
        breaker = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=1, recovery_timeout=0.05),
            provider="test",
        )
        failing_call = self._failing_call()

        # Trip the circuit.
        await self._expect_provider_error(breaker, failing_call)

        # Sleep just past the recovery timeout.
        await asyncio.sleep(0.06)
        assert breaker.state == CircuitState.HALF_OPEN

        # The failing probe should send the breaker back to OPEN.
        await self._expect_provider_error(breaker, failing_call)

        assert breaker.state == CircuitState.OPEN

    async def test_half_open_max_limits_requests(self):
        """With ``half_open_max=1``, one probe is allowed per recovery cycle."""
        breaker_config = CircuitBreakerConfig(
            failure_threshold=1,
            recovery_timeout=0.05,
            half_open_max=1,
        )
        breaker = CircuitBreaker(breaker_config, provider="test")
        failing_call = self._failing_call()

        # Trip the circuit.
        await self._expect_provider_error(breaker, failing_call)

        # Sleep just past the recovery timeout.
        await asyncio.sleep(0.06)
        assert breaker.state == CircuitState.HALF_OPEN

        # First cycle: the probe succeeds and the circuit closes.
        outcome = await breaker.execute(AsyncMock(return_value="ok"))
        assert outcome == "ok"
        assert breaker.state == CircuitState.CLOSED

        # Second cycle: trip again so the probe slot can be spent on a failure.
        # The counter is cleared by hand first; failure_threshold=1 means one
        # failing call re-opens the circuit.
        breaker._failure_count = 0
        await self._expect_provider_error(breaker, failing_call)
        assert breaker.state == CircuitState.OPEN

        # Sleep just past the recovery timeout again.
        await asyncio.sleep(0.06)
        assert breaker.state == CircuitState.HALF_OPEN

        # The single half-open slot is consumed by a failing request.
        await self._expect_provider_error(breaker, failing_call)

        # Back to OPEN, so the next request is refused outright.
        assert breaker.state == CircuitState.OPEN
        with pytest.raises(CircuitOpenError):
            await breaker.execute(AsyncMock(return_value="ok"))

    async def test_failure_count_resets_on_success(self):
        """A success in CLOSED state clears the accumulated failure count."""
        breaker = CircuitBreaker(
            CircuitBreakerConfig(failure_threshold=2, recovery_timeout=0.05),
            provider="test",
        )

        # One failure is below the threshold of 2, so the circuit stays CLOSED.
        await self._expect_provider_error(breaker, self._failing_call())
        assert breaker.state == CircuitState.CLOSED
        assert breaker._failure_count == 1

        # A successful request zeroes the counter.
        await breaker.execute(AsyncMock(return_value="ok"))

        assert breaker._failure_count == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration: Provider with RetryPolicy + CircuitBreaker
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestProviderRetryIntegration:
    """Integration tests: providers wired with retry and/or circuit breaker.

    Each test replaces the provider's ``_chat_impl`` with a stub and goes
    through the public ``chat`` entry point.
    """

    async def test_openai_provider_with_retry_succeeds_after_retry(self):
        """An OpenAI-compatible provider with a retry config retries on 429."""
        from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage
        from agentkit.llm.providers.openai import OpenAICompatibleProvider

        provider = OpenAICompatibleProvider(
            api_key="test-key",
            retry_config=RetryConfig(max_retries=2, base_delay=0.01),
        )

        seen_requests = []

        async def flaky_chat_impl(request):
            # The first invocation is rate limited; later ones succeed.
            seen_requests.append(request)
            if len(seen_requests) == 1:
                raise LLMProviderError("openai", "HTTP 429: Rate limit")
            return LLMResponse(
                content="retried ok",
                model="gpt-4o-mini",
                usage=TokenUsage(prompt_tokens=5, completion_tokens=3),
            )

        provider._chat_impl = flaky_chat_impl

        chat_request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gpt-4o-mini",
        )
        reply = await provider.chat(chat_request)

        assert reply.content == "retried ok"
        assert len(seen_requests) == 2

    async def test_anthropic_provider_with_circuit_breaker(self):
        """An Anthropic provider with a breaker rejects calls once it is open."""
        from agentkit.llm.protocol import LLMRequest
        from agentkit.llm.providers.anthropic import AnthropicProvider

        provider = AnthropicProvider(
            api_key="test-key",
            circuit_breaker_config=CircuitBreakerConfig(failure_threshold=1),
        )

        # Every underlying call fails, which is what trips the circuit.
        server_error = LLMProviderError("anthropic", "HTTP 500: Error")
        provider._chat_impl = AsyncMock(side_effect=server_error)

        chat_request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="claude-sonnet-4-20250514",
        )

        # First call: the provider error surfaces and the circuit trips.
        with pytest.raises(LLMProviderError):
            await provider.chat(chat_request)

        # Second call: refused by the circuit breaker.
        with pytest.raises(CircuitOpenError):
            await provider.chat(chat_request)

    async def test_provider_without_retry_config_works_as_before(self):
        """A provider built without retry/breaker configs behaves normally."""
        from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage
        from agentkit.llm.providers.openai import OpenAICompatibleProvider

        provider = OpenAICompatibleProvider(api_key="test-key")

        # Neither wrapper is created when no config is supplied.
        assert provider._retry_policy is None
        assert provider._circuit_breaker is None

        canned_reply = LLMResponse(
            content="no retry",
            model="gpt-4o-mini",
            usage=TokenUsage(prompt_tokens=5, completion_tokens=3),
        )
        provider._chat_impl = AsyncMock(return_value=canned_reply)

        chat_request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gpt-4o-mini",
        )
        reply = await provider.chat(chat_request)

        assert reply.content == "no retry"

    async def test_provider_with_both_retry_and_circuit_breaker(self):
        """Retry and circuit breaker can be configured together on a provider."""
        from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage
        from agentkit.llm.providers.openai import OpenAICompatibleProvider

        provider = OpenAICompatibleProvider(
            api_key="test-key",
            retry_config=RetryConfig(max_retries=2, base_delay=0.01),
            circuit_breaker_config=CircuitBreakerConfig(failure_threshold=5),
        )

        seen_requests = []

        async def flaky_chat_impl(request):
            # The first two invocations are rate limited; the third succeeds.
            seen_requests.append(request)
            if len(seen_requests) <= 2:
                raise LLMProviderError("openai", "HTTP 429: Rate limit")
            return LLMResponse(
                content="success after retry",
                model="gpt-4o-mini",
                usage=TokenUsage(prompt_tokens=5, completion_tokens=3),
            )

        provider._chat_impl = flaky_chat_impl

        chat_request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gpt-4o-mini",
        )
        reply = await provider.chat(chat_request)

        assert reply.content == "success after retry"
        assert len(seen_requests) == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config integration tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConfigIntegration:
    """Tests for ``LLMConfig.from_dict`` with retry / circuit_breaker sections."""

    def test_from_dict_with_retry_and_circuit_breaker(self):
        """Provider entries carrying both sections expose the parsed values."""
        from agentkit.llm.config import LLMConfig

        retry_section = {
            "max_retries": 5,
            "base_delay": 2.0,
        }
        breaker_section = {
            "failure_threshold": 3,
            "recovery_timeout": 30.0,
        }
        raw_config = {
            "providers": {
                "openai": {
                    "api_key": "sk-test",
                    "base_url": "https://api.openai.com/v1",
                    "retry": retry_section,
                    "circuit_breaker": breaker_section,
                },
            },
        }

        openai_conf = LLMConfig.from_dict(raw_config).providers["openai"]

        retry = openai_conf.retry
        assert retry is not None
        assert retry.max_retries == 5
        assert retry.base_delay == 2.0

        breaker = openai_conf.circuit_breaker
        assert breaker is not None
        assert breaker.failure_threshold == 3
        assert breaker.recovery_timeout == 30.0

    def test_from_dict_without_retry_or_circuit_breaker(self):
        """Provider entries lacking both sections load them as ``None``."""
        from agentkit.llm.config import LLMConfig

        raw_config = {
            "providers": {
                "openai": {
                    "api_key": "sk-test",
                    "base_url": "https://api.openai.com/v1",
                },
            },
        }

        openai_conf = LLMConfig.from_dict(raw_config).providers["openai"]

        assert openai_conf.retry is None
        assert openai_conf.circuit_breaker is None
|