182 lines
7.3 KiB
Python
182 lines
7.3 KiB
Python
"""限流中间件测试 — 覆盖内存后端和 Redis 后端降级场景。"""
|
||
import asyncio
|
||
import time
|
||
|
||
import pytest
|
||
|
||
from app.middleware.rate_limit import (
|
||
MemoryRateLimitBackend,
|
||
RateLimitBackend,
|
||
RedisRateLimitBackend,
|
||
_create_backend,
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# MemoryRateLimitBackend 测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestMemoryRateLimitBackend:
    """Exercise the in-memory rate-limit backend through its async API."""

    # Every test in this class uses the same quota: 5 requests per 60 seconds.
    MAX_REQUESTS = 5
    WINDOW_SECONDS = 60

    def setup_method(self):
        """Give each test its own fresh backend instance."""
        self.backend = MemoryRateLimitBackend()

    async def _burst(self, key, start):
        """Send MAX_REQUESTS requests for *key*, spaced 1 ms apart from *start*."""
        for step in range(self.MAX_REQUESTS):
            await self.backend.is_rate_limited(
                key, start + step * 0.001, self.MAX_REQUESTS, self.WINDOW_SECONDS
            )

    @pytest.mark.asyncio
    async def test_under_limit_not_blocked(self):
        """The first MAX_REQUESTS requests are all reported as not limited."""
        start = time.time()
        for step in range(self.MAX_REQUESTS):
            limited = await self.backend.is_rate_limited(
                "test:key", start + step * 0.001, self.MAX_REQUESTS, self.WINDOW_SECONDS
            )
            assert limited is False

    @pytest.mark.asyncio
    async def test_at_limit_blocked(self):
        """One more request after a full burst is reported as limited."""
        start = time.time()
        # Use up the whole quota first.
        await self._burst("test:key", start)
        # The sixth request inside the same window must be rejected.
        limited = await self.backend.is_rate_limited(
            "test:key", start + 0.006, self.MAX_REQUESTS, self.WINDOW_SECONDS
        )
        assert limited is True

    @pytest.mark.asyncio
    async def test_window_expiry_allows_new_requests(self):
        """A request timestamped after the window has elapsed is accepted."""
        start = time.time()
        # Fill the quota at the beginning of the window.
        await self._burst("test:key", start)
        # 61 seconds later the 60-second window no longer covers the burst.
        after_window = start + 61
        limited = await self.backend.is_rate_limited(
            "test:key", after_window, self.MAX_REQUESTS, self.WINDOW_SECONDS
        )
        assert limited is False

    @pytest.mark.asyncio
    async def test_different_keys_independent(self):
        """Exhausting the quota of one key does not affect another key."""
        start = time.time()
        # Drive key1 over its quota.
        await self._burst("key1", start)
        first_key_limited = await self.backend.is_rate_limited(
            "key1", start + 0.006, self.MAX_REQUESTS, self.WINDOW_SECONDS
        )
        assert first_key_limited is True
        # key2 has sent nothing yet, so it must still be accepted.
        second_key_limited = await self.backend.is_rate_limited(
            "key2", start, self.MAX_REQUESTS, self.WINDOW_SECONDS
        )
        assert second_key_limited is False

    @pytest.mark.asyncio
    async def test_reset_clears_key(self):
        """After reset() a previously limited key is accepted again."""
        start = time.time()
        await self._burst("test:key", start)
        # The key is limited before the reset ...
        limited = await self.backend.is_rate_limited(
            "test:key", start + 0.006, self.MAX_REQUESTS, self.WINDOW_SECONDS
        )
        assert limited is True
        # ... and accepted again once its state has been reset.
        await self.backend.reset("test:key")
        limited = await self.backend.is_rate_limited(
            "test:key", start + 0.007, self.MAX_REQUESTS, self.WINDOW_SECONDS
        )
        assert limited is False

    @pytest.mark.asyncio
    async def test_cleanup_removes_expired_entries(self):
        """Pruning timestamps older than 3600 s drops keys left with no entries.

        NOTE(review): the pruning below is replayed inline on the backend's
        ``_requests`` mapping instead of calling a backend method, so it does
        not exercise the backend's own cleanup code - confirm whether a
        cleanup entry point exists that this test should call instead.
        """
        now = time.time()
        # Register one stale and one fresh request through the public API.
        await self.backend.is_rate_limited(
            "old_key", now - 3700, self.MAX_REQUESTS, self.WINDOW_SECONDS
        )
        await self.backend.is_rate_limited(
            "recent_key", now, self.MAX_REQUESTS, self.WINDOW_SECONDS
        )
        # Pin the stored timestamps so the outcome below is deterministic.
        store = self.backend._requests
        store["old_key"] = [now - 3700]
        store["recent_key"] = [now]
        # Phase 1: keep only timestamps from the last 3600 seconds.
        for name in list(store):
            store[name] = [stamp for stamp in store[name] if now - stamp < 3600]
        # Phase 2: drop every key whose timestamp list became empty.
        stale = [name for name in list(store) if not store[name]]
        for name in stale:
            del store[name]
        assert "old_key" not in self.backend._requests
        assert "recent_key" in self.backend._requests
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# RedisRateLimitBackend 测试(无 Redis 连接时的降级行为)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestRedisRateLimitBackendFallback:
    """Check how the Redis backend behaves when built with no arguments.

    NOTE(review): these tests assume no Redis server is reachable from the
    test environment - confirm against the CI setup.
    """

    @pytest.mark.asyncio
    async def test_redis_unavailable_allows_request(self):
        """A default-constructed Redis backend reports the request as not limited."""
        # With the connection failing, the expected answer is False (let it through).
        limited = await RedisRateLimitBackend().is_rate_limited(
            "test:key", time.time(), 5, 60
        )
        assert limited is False

    @pytest.mark.asyncio
    async def test_redis_reset_silently_fails(self):
        """reset() on a default-constructed Redis backend must not raise."""
        redis_backend = RedisRateLimitBackend()
        # No assertion needed: the test passes as long as nothing is raised.
        await redis_backend.reset("test:key")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Backend 工厂函数测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestCreateBackend:
    """Check which backend class ``_create_backend`` returns per setting."""

    # Dotted paths of the settings attributes patched by the tests below.
    _BACKEND_SETTING = "app.middleware.rate_limit.settings.RATE_LIMIT_BACKEND"
    _REDIS_URL_SETTING = "app.middleware.rate_limit.settings.REDIS_URL"

    def test_default_creates_memory_backend(self, monkeypatch):
        """RATE_LIMIT_BACKEND set to "memory" yields the in-memory backend."""
        monkeypatch.setattr(self._BACKEND_SETTING, "memory")
        assert isinstance(_create_backend(), MemoryRateLimitBackend)

    def test_redis_with_empty_url_falls_back_to_memory(self, monkeypatch):
        """RATE_LIMIT_BACKEND set to "redis" with an empty REDIS_URL yields the in-memory backend."""
        monkeypatch.setattr(self._BACKEND_SETTING, "redis")
        monkeypatch.setattr(self._REDIS_URL_SETTING, "")
        assert isinstance(_create_backend(), MemoryRateLimitBackend)

    def test_redis_with_url_creates_redis_backend(self, monkeypatch):
        """RATE_LIMIT_BACKEND set to "redis" with a non-empty REDIS_URL yields the Redis backend."""
        monkeypatch.setattr(self._BACKEND_SETTING, "redis")
        monkeypatch.setattr(self._REDIS_URL_SETTING, "redis://localhost:6379/0")
        assert isinstance(_create_backend(), RedisRateLimitBackend)

    def test_unknown_backend_defaults_to_memory(self, monkeypatch):
        """An unrecognised RATE_LIMIT_BACKEND value yields the in-memory backend."""
        monkeypatch.setattr(self._BACKEND_SETTING, "unknown")
        assert isinstance(_create_backend(), MemoryRateLimitBackend)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# RateLimitBackend 接口测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestRateLimitBackendInterface:
    """Check the relationship between RateLimitBackend and its implementations."""

    def test_cannot_instantiate_abstract_class(self):
        """Instantiating RateLimitBackend directly raises TypeError."""
        with pytest.raises(TypeError):
            RateLimitBackend()

    def test_memory_backend_is_subclass(self):
        """MemoryRateLimitBackend is a subclass of RateLimitBackend."""
        memory_is_backend = issubclass(MemoryRateLimitBackend, RateLimitBackend)
        assert memory_is_backend

    def test_redis_backend_is_subclass(self):
        """RedisRateLimitBackend is a subclass of RateLimitBackend."""
        redis_is_backend = issubclass(RedisRateLimitBackend, RateLimitBackend)
        assert redis_is_backend
|