291 lines
11 KiB
Python
291 lines
11 KiB
Python
"""
|
||
限流中间件 — 支持 Redis 和内存双后端。
|
||
|
||
生产环境推荐使用 Redis 后端(支持多实例共享限流状态),
|
||
开发/测试环境使用内存后端。Redis 不可用时自动降级到内存后端。
|
||
"""
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import time
|
||
from abc import ABC, abstractmethod
|
||
from collections import defaultdict
|
||
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from starlette.requests import Request
|
||
from starlette.responses import JSONResponse
|
||
|
||
from app.config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Backend abstract interface
# ---------------------------------------------------------------------------
|
||
|
||
class RateLimitBackend(ABC):
    """Abstract interface implemented by every rate-limit storage backend."""

    @abstractmethod
    async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
        """Check whether the request counted under ``key`` is rate limited.

        Args:
            key: Identifier of the bucket the request is counted in.
            now: Timestamp of the current request, in seconds.
            max_requests: Number of requests allowed inside one window.
            window_seconds: Length of the window, in seconds.

        Returns:
            True when the request is rate limited, False when it may proceed.
        """

    @abstractmethod
    async def reset(self, key: str) -> None:
        """Reset the rate-limit state stored for ``key``."""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# In-memory backend
# ---------------------------------------------------------------------------
|
||
|
||
class MemoryRateLimitBackend(RateLimitBackend):
    """In-memory rate-limit backend (single instance; for development and tests).

    Every key maps to the list of timestamps of the requests that were allowed
    for it. The state lives inside this object only, so it is not shared
    between application instances.
    """

    # Seconds the background task sleeps between two cleanup passes.
    _CLEANUP_INTERVAL_SECONDS = 60
    # Age (seconds) after which the background task drops a record.
    # NOTE(review): a window longer than this would lose records early; the
    # largest window configured in this file is 3600 seconds.
    _MAX_RECORD_AGE_SECONDS = 3600

    def __init__(self) -> None:
        self._requests: dict[str, list[float]] = defaultdict(list)
        self._cleanup_task: asyncio.Task | None = None
        # True when start_cleanup() was requested while no event loop was
        # running; the task is then created on the first is_rate_limited().
        self._cleanup_pending = False

    def start_cleanup(self) -> None:
        """Start the background task that periodically drops expired records.

        Calling this while the task is already running is a no-op. When no
        event loop is running yet (``asyncio.create_task`` would raise
        ``RuntimeError``), the start is deferred until the first
        ``is_rate_limited`` call instead of failing.
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            self._cleanup_pending = True
            return
        self._cleanup_pending = False
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    def _purge_expired(self, now: float) -> None:
        """Drop records older than the maximum age and delete emptied keys."""
        # Iterate over a snapshot of the keys because entries are deleted.
        for key in list(self._requests):
            fresh = [t for t in self._requests[key] if now - t < self._MAX_RECORD_AGE_SECONDS]
            if fresh:
                self._requests[key] = fresh
            else:
                del self._requests[key]

    async def _cleanup_loop(self) -> None:
        """Purge expired records once per cleanup interval until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._CLEANUP_INTERVAL_SECONDS)
                self._purge_expired(time.time())
            except asyncio.CancelledError:
                break
            except Exception as exc:
                # Keep the loop alive: one failed pass must not stop cleanup.
                logger.warning("Rate limit cleanup error: %s", exc)

    async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
        """Record the request under ``key`` unless its window is already full.

        Args:
            key: Identifier of the bucket the request is counted in.
            now: Timestamp of the current request, in seconds.
            max_requests: Number of requests allowed inside one window.
            window_seconds: Length of the sliding window, in seconds.

        Returns:
            True when ``max_requests`` requests are already recorded inside
            the window (the current request is not recorded), else False.
        """
        if self._cleanup_pending:
            # An event loop is running now, so the deferred start can succeed.
            self.start_cleanup()

        # Keep only the timestamps that still fall inside the window.
        recent = [t for t in self._requests[key] if now - t < window_seconds]
        self._requests[key] = recent

        if len(recent) >= max_requests:
            return True

        recent.append(now)
        return False

    async def reset(self, key: str) -> None:
        """Forget every request recorded for ``key``."""
        self._requests.pop(key, None)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Redis backend
# ---------------------------------------------------------------------------
|
||
|
||
class RedisRateLimitBackend(RateLimitBackend):
    """Redis-backed rate-limit backend (state shared by all instances).

    A sliding window is kept per key in a Redis sorted set whose scores are
    request timestamps. Each check runs ZREMRANGEBYSCORE + ZADD + ZCARD +
    EXPIRE in one pipeline created with ``transaction=True``.
    """

    def __init__(self, redis_url: str) -> None:
        self._redis_url = redis_url
        # The client is created lazily by _get_redis().
        self._redis = None

    async def _get_redis(self):
        """Return the Redis client, creating it on first use."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(
                self._redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
        return self._redis

    async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
        """Count the request under ``key`` and report whether it exceeds the limit.

        Args:
            key: Redis key of the sorted set holding the window.
            now: Timestamp of the current request, in seconds.
            max_requests: Number of requests allowed inside one window.
            window_seconds: Length of the sliding window, in seconds.

        Returns:
            True when the request exceeds ``max_requests`` inside the window.
            False otherwise, and also when the Redis check itself fails
            (the request is allowed so that a Redis outage does not block
            normal traffic).
        """
        try:
            redis = await self._get_redis()
            window_start = now - window_seconds
            # The random suffix keeps members unique: with the bare timestamp
            # two requests carrying the same ``now`` would overwrite each
            # other in the sorted set and be counted once.
            member = f"{now}:{os.urandom(8).hex()}"

            async with redis.pipeline(transaction=True) as pipe:
                # Remove records that fell out of the window.
                pipe.zremrangebyscore(key, "-inf", window_start)
                # Add the current request.
                pipe.zadd(key, {member: now})
                # Count the requests inside the window.
                pipe.zcard(key)
                # Expire the key after twice the window so idle keys vanish.
                pipe.expire(key, window_seconds * 2)
                results = await pipe.execute()

            count = results[2]  # result of ZCARD
            if count <= max_requests:
                return False
        except Exception as exc:
            logger.warning("Redis rate limit check failed for key=%s: %s, falling back to allow", key, exc)
            # Allow the request when Redis fails, to avoid hurting normal service.
            return False

        # The request is rejected: take its record out again so rejected
        # attempts do not keep the window full, matching the in-memory
        # backend, which never records a rejected request.
        try:
            await redis.zrem(key, member)
        except Exception as exc:
            logger.warning("Redis rate limit rollback failed for key=%s: %s", key, exc)
        return True

    async def reset(self, key: str) -> None:
        """Delete the window stored for ``key``; failures are only logged."""
        try:
            redis = await self._get_redis()
            await redis.delete(key)
        except Exception as exc:
            logger.warning("Redis rate limit reset failed for key=%s: %s", key, exc)

    async def close(self) -> None:
        """Close the Redis client if one was created."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
|
||
|
||
def _extract_user_id_from_request(request: Request) -> str | None:
    """Try to read the user id (JWT ``sub`` claim) from the Authorization header.

    The authentication scheme is matched case-insensitively, as HTTP
    authentication schemes are case-insensitive (RFC 7235), so both
    ``Bearer <token>`` and ``bearer <token>`` are recognised.

    Args:
        request: Incoming request whose ``authorization`` header is inspected.

    Returns:
        The ``sub`` value of the token payload, or None when the header is
        missing, uses another scheme, carries no token, or the token cannot
        be verified. Failures never propagate to the caller.
    """
    auth_header = request.headers.get("authorization", "")
    scheme, _, token = auth_header.partition(" ")
    if scheme.lower() != "bearer":
        return None

    token = token.strip()
    if not token:
        return None

    try:
        # Imported lazily, as in the original code.
        # NOTE(review): presumably verify_token validates the JWT and returns
        # its claims as a mapping — confirm against app.services.auth.
        from app.services.auth import verify_token

        payload = verify_token(token)
        user_id: str | None = payload.get("sub")
        return user_id
    except Exception:
        # Deliberate best effort: an unreadable token only means the request
        # is not attributed to a user.
        return None
|
||
|
||
|
||
def _create_backend() -> RateLimitBackend:
    """Create the rate-limit backend selected by the application settings.

    ``settings.RATE_LIMIT_BACKEND`` equal to ``"redis"`` (case-insensitive)
    selects the Redis backend; any other value selects the in-memory backend.
    The Redis choice falls back to the in-memory backend when
    ``settings.REDIS_URL`` is empty, when the ``redis`` package cannot be
    imported, or when the backend object cannot be created.

    Returns:
        The backend instance to use.
    """
    backend_type = settings.RATE_LIMIT_BACKEND.lower()

    if backend_type != "redis":
        # Memory is the default backend.
        backend = MemoryRateLimitBackend()
        logger.info("Rate limiter using memory backend")
        return backend

    redis_url = settings.REDIS_URL
    if not redis_url:
        logger.warning("RATE_LIMIT_BACKEND=redis but REDIS_URL is empty, falling back to memory backend")
        return MemoryRateLimitBackend()

    try:
        # The backend imports the client lazily on first use. Importing it
        # here makes a missing package trigger the fallback now, instead of
        # every later check failing and allowing all requests.
        import redis.asyncio  # noqa: F401

        backend = RedisRateLimitBackend(redis_url)
    except Exception as exc:
        logger.warning("Failed to create Redis rate limit backend: %s, falling back to memory", exc)
        return MemoryRateLimitBackend()

    # Log only the part after "@" so credentials in the URL are not written out.
    logger.info("Rate limiter using Redis backend (url=%s)", redis_url.split("@")[-1] if "@" in redis_url else redis_url)
    return backend
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Middleware
# ---------------------------------------------------------------------------
|
||
|
||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware that answers 429 when a request exceeds a rate-limit rule.

    Checks run in this order: strict auth endpoints (per IP), other listed
    auth endpoints (per IP), POST requests ending in a ``query_run`` path
    (per user + IP, or per IP when unauthenticated), then the global limit
    (same keying). The first exceeded rule ends the request.
    """

    # Response detail used by every rule except "query_run".
    _DETAIL_TOO_MANY_REQUESTS = "请求过于频繁,请稍后再试"
    # Response detail used by the "query_run" rule.
    _DETAIL_TOO_MANY_QUERY_RUNS = "查询执行过于频繁,请稍后再试"

    def __init__(self, app):
        super().__init__(app)
        self._backend = _create_backend()

        # The memory backend needs a background cleanup task. If no event
        # loop is running yet, dispatch() retries the start.
        self._start_memory_cleanup()

        # Rate-limit rules.
        self.rules = {
            "auth_strict": {  # login / register: strict limit, 5 per minute per IP
                "paths": ["/api/v1/auth/login", "/api/v1/auth/register"],
                "max_requests": 5,
                "window_seconds": 60,
            },
            "auth": {  # the other auth endpoints listed here (exact match)
                "paths": ["/api/v1/auth/forgot-password", "/api/v1/auth/refresh"],
                "max_requests": 20,
                "window_seconds": 60,
            },
            "query_run": {  # run-now
                "paths": ["/run-now"],  # matched with endswith
                "max_requests": 10,
                "window_seconds": 3600,
            },
            "global": {
                "max_requests": 100,
                "window_seconds": 60,
            },
        }

    def _start_memory_cleanup(self) -> None:
        """Start the memory backend's cleanup task; no-op for other backends."""
        if not isinstance(self._backend, MemoryRateLimitBackend):
            return
        try:
            # start_cleanup() does nothing while its task is already running.
            self._backend.start_cleanup()
        except RuntimeError as exc:
            # asyncio.create_task needs a running event loop; there may be
            # none while the middleware is being constructed.
            logger.debug("Rate limit cleanup task not started yet: %s", exc)

    @staticmethod
    def _too_many_requests(detail: str) -> JSONResponse:
        """Build the 429 response carrying ``detail``."""
        return JSONResponse(
            status_code=429,
            content={"detail": detail},
        )

    async def _exceeds(self, rule_name: str, key: str, now: float) -> bool:
        """Ask the backend whether ``key`` exceeds the rule called ``rule_name``."""
        rule = self.rules[rule_name]
        return await self._backend.is_rate_limited(key, now, rule["max_requests"], rule["window_seconds"])

    async def dispatch(self, request: Request, call_next):
        """Apply the rate-limit rules, then pass the request on.

        Args:
            request: Incoming request.
            call_next: Callable that forwards the request to the next handler.

        Returns:
            A 429 JSON response when a rule is exceeded, otherwise the
            response produced by ``call_next``.
        """
        # Rate limiting is switched off for E2E test environments.
        if os.getenv("RATE_LIMIT_DISABLED") == "1":
            return await call_next(request)

        path = request.url.path

        # Health check and API documentation are never limited.
        if path == "/health" or path.startswith(("/docs", "/openapi")):
            return await call_next(request)

        # An event loop is running here, so a deferred cleanup start succeeds.
        self._start_memory_cleanup()

        # NOTE(review): this is the peer address; behind a reverse proxy it
        # is presumably the proxy's address — confirm the deployment setup.
        client_ip = request.client.host if request.client else "unknown"
        now = time.time()

        # Try to read the user id from the Authorization header.
        user_id = _extract_user_id_from_request(request)
        # Authenticated requests are keyed by user id + IP, others by IP only.
        subject = f"{user_id}:{client_ip}" if user_id else client_ip

        # Strict limit for login/register (5 per minute per IP).
        if path in self.rules["auth_strict"]["paths"]:
            if await self._exceeds("auth_strict", f"auth_strict:{client_ip}", now):
                return self._too_many_requests(self._DETAIL_TOO_MANY_REQUESTS)

        # Regular limit for the other listed auth endpoints.
        elif path in self.rules["auth"]["paths"]:
            if await self._exceeds("auth", f"auth:{client_ip}", now):
                return self._too_many_requests(self._DETAIL_TOO_MANY_REQUESTS)

        # Query execution limit; the suffixes come from the rule itself so
        # that editing the rule's "paths" takes effect.
        if request.method == "POST" and path.endswith(tuple(self.rules["query_run"]["paths"])):
            if await self._exceeds("query_run", f"query_run:{subject}", now):
                return self._too_many_requests(self._DETAIL_TOO_MANY_QUERY_RUNS)

        # Global limit.
        if await self._exceeds("global", f"global:{subject}", now):
            return self._too_many_requests(self._DETAIL_TOO_MANY_REQUESTS)

        return await call_next(request)
|