"""Rate-limiting middleware with Redis and in-memory backends.

The Redis backend is recommended in production because limiter state is
shared across instances; the in-memory backend is meant for development and
testing. When Redis is selected but cannot be set up, the backend factory
falls back to the in-memory backend. Errors talking to Redis while handling a
request fail open: the request is allowed and a warning is logged.
"""

import asyncio
import logging
import os
import time
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse

from app.config import settings

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Backend interface
# ---------------------------------------------------------------------------


class RateLimitBackend(ABC):
    """Abstract storage backend for sliding-window rate limiting."""

    @abstractmethod
    async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
        """Check a request against the limit for ``key`` and record it if allowed.

        Args:
            key: Bucket identifier, e.g. ``"global:<client ip>"``.
            now: Current Unix timestamp in seconds.
            max_requests: Number of requests allowed per window.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            True if the request must be rejected, False if it is allowed.
        """

    @abstractmethod
    async def reset(self, key: str) -> None:
        """Forget all recorded requests for ``key``."""

    async def close(self) -> None:
        """Release backend resources; the default implementation does nothing."""


# ---------------------------------------------------------------------------
# In-memory backend
# ---------------------------------------------------------------------------


class MemoryRateLimitBackend(RateLimitBackend):
    """Sliding-window backend kept in process memory.

    State is local to this process, so limits are not shared between workers
    or instances; intended for development and testing.
    """

    # Seconds between background cleanup passes.
    CLEANUP_INTERVAL_SECONDS = 60
    # Minimum age (seconds) before cleanup discards a timestamp. Raised
    # automatically when a caller uses a longer window, so cleanup never drops
    # records that are still inside some window.
    DEFAULT_RETENTION_SECONDS = 3600

    def __init__(self) -> None:
        # key -> timestamps of allowed requests, in insertion order.
        self._requests: dict[str, list[float]] = defaultdict(list)
        self._cleanup_task: asyncio.Task | None = None
        # Set when start_cleanup() ran without an event loop; retried on the next check.
        self._cleanup_pending = False
        self._retention_seconds = self.DEFAULT_RETENTION_SECONDS

    def start_cleanup(self) -> None:
        """Start the background cleanup task unless it is already running.

        When called outside a running event loop (e.g. while the app is being
        assembled synchronously) the start is deferred to the next
        ``is_rate_limited`` call instead of raising ``RuntimeError``.
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            self._cleanup_pending = True
            return
        self._cleanup_pending = False
        self._cleanup_task = loop.create_task(self._cleanup_loop())

    async def _cleanup_loop(self) -> None:
        """Prune stale records every ``CLEANUP_INTERVAL_SECONDS`` until cancelled."""
        while True:
            try:
                await asyncio.sleep(self.CLEANUP_INTERVAL_SECONDS)
                self._prune(time.time())
            except asyncio.CancelledError:
                break
            except Exception as exc:
                logger.warning("Rate limit cleanup error: %s", exc)

    def _prune(self, now: float) -> None:
        """Drop timestamps older than the retention period and delete emptied keys."""
        retention = self._retention_seconds
        for key in list(self._requests):
            fresh = [t for t in self._requests[key] if now - t < retention]
            if fresh:
                self._requests[key] = fresh
            else:
                del self._requests[key]

    async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
        """Return True if ``key`` already has ``max_requests`` requests in the window."""
        if self._cleanup_pending:
            self.start_cleanup()
        # Keep cleanup from discarding timestamps that are still inside this window.
        self._retention_seconds = max(self._retention_seconds, window_seconds)
        recent = [t for t in self._requests[key] if now - t < window_seconds]
        limited = len(recent) >= max_requests
        if not limited:
            # Only allowed requests are recorded; rejections do not extend the block.
            recent.append(now)
        self._requests[key] = recent
        return limited

    async def reset(self, key: str) -> None:
        """Forget all recorded requests for ``key``."""
        self._requests.pop(key, None)

    async def close(self) -> None:
        """Stop the background cleanup task, if it is running."""
        self._cleanup_pending = False
        task, self._cleanup_task = self._cleanup_task, None
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass


# ---------------------------------------------------------------------------
# Redis backend
# ---------------------------------------------------------------------------


class RedisRateLimitBackend(RateLimitBackend):
    """Sliding-window backend on a Redis sorted set (shared by all instances; for production).

    Each allowed request is stored as a sorted-set member scored by its
    timestamp. A MULTI/EXEC pipeline atomically trims expired members, adds
    the current request, reads the member count (ZCARD) and refreshes the key
    TTL. A request found to be over the limit is removed again afterwards.
    """

    def __init__(self, redis_url: str) -> None:
        self._redis_url = redis_url
        self._redis = None  # client is created lazily by _get_redis()

    async def _get_redis(self):
        """Return the Redis client, creating it on first use."""
        if self._redis is None:
            # Imported lazily so the redis package is only required when this backend is used.
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(
                self._redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
        return self._redis

    async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
        """Return True if ``key`` is over the limit; fail open (False) on Redis errors."""
        try:
            redis = await self._get_redis()
            # The timestamp alone can repeat (concurrent requests in the same
            # clock tick); ZADD would then overwrite instead of adding a member.
            member = f"{now}:{uuid.uuid4().hex}"
            async with redis.pipeline(transaction=True) as pipe:
                # Drop entries that fell out of the window.
                pipe.zremrangebyscore(key, "-inf", now - window_seconds)
                pipe.zadd(key, {member: now})
                pipe.zcard(key)
                # TTL of two windows so idle keys do not linger forever.
                pipe.expire(key, window_seconds * 2)
                results = await pipe.execute()
            count = results[2]  # ZCARD result
        except Exception as exc:
            logger.warning("Redis rate limit check failed for key=%s: %s, falling back to allow", key, exc)
            # Fail open so a Redis outage does not take the service down.
            return False

        if count <= max_requests:
            return False
        # Over the limit: remove this request again so rejected attempts do not
        # occupy window slots (same semantics as the in-memory backend);
        # otherwise a client that keeps retrying would stay blocked indefinitely.
        try:
            await redis.zrem(key, member)
        except Exception as exc:
            logger.warning("Redis rate limit rollback failed for key=%s: %s", key, exc)
        return True

    async def reset(self, key: str) -> None:
        """Delete the sorted set for ``key``; errors are logged, not raised."""
        try:
            redis = await self._get_redis()
            await redis.delete(key)
        except Exception as exc:
            logger.warning("Redis rate limit reset failed for key=%s: %s", key, exc)

    async def close(self) -> None:
        """Close the Redis client if one was created."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _extract_user_id_from_request(request: Request) -> str | None:
    """Best-effort extraction of the user id (JWT ``sub`` claim) from the request.

    Returns None when there is no bearer token or it fails verification, so
    rate limiting falls back to IP-only keys without affecting the request.
    """
    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    # The auth scheme is case-insensitive (RFC 7235), so accept "bearer" too.
    if scheme.lower() != "bearer":
        return None
    token = token.strip()
    if not token:
        return None
    try:
        from app.services.auth import verify_token

        payload = verify_token(token)
        user_id: str | None = payload.get("sub")
    except Exception:
        # Deliberately silent: an invalid token is simply treated as anonymous here.
        return None
    return user_id


def _create_backend() -> RateLimitBackend:
    """Build the backend selected by ``settings.RATE_LIMIT_BACKEND``.

    ``"redis"`` selects :class:`RedisRateLimitBackend`, falling back to the
    in-memory backend when ``REDIS_URL`` is empty or the ``redis`` package
    cannot be imported. Any other value (including unset) uses the in-memory
    backend.
    """
    raw_backend = settings.RATE_LIMIT_BACKEND
    backend_type = (raw_backend or "memory").strip().lower()
    if backend_type == "redis":
        redis_url = settings.REDIS_URL
        if not redis_url:
            logger.warning("RATE_LIMIT_BACKEND=redis but REDIS_URL is empty, falling back to memory backend")
            return MemoryRateLimitBackend()
        try:
            # The client is created lazily, so verify the dependency up front;
            # otherwise every request would fail open with a warning.
            import redis.asyncio  # noqa: F401

            backend = RedisRateLimitBackend(redis_url)
        except Exception as exc:
            logger.warning("Failed to create Redis rate limit backend: %s, falling back to memory", exc)
            return MemoryRateLimitBackend()
        # Log only what follows the last "@" so credentials never reach the logs.
        logger.info("Rate limiter using Redis backend (url=%s)", redis_url.split("@")[-1])
        return backend

    if backend_type != "memory":
        logger.warning("Unknown RATE_LIMIT_BACKEND=%r, falling back to memory backend", raw_backend)
    logger.info("Rate limiter using memory backend")
    return MemoryRateLimitBackend()


# ---------------------------------------------------------------------------
# Middleware
# ---------------------------------------------------------------------------


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiting per client IP, or per user + IP when authenticated.

    Rules in ``self.rules`` are checked in order, stopping at the first one
    that rejects the request with HTTP 429:

    1. ``auth_strict`` (login/register) or ``auth`` (other listed auth
       endpoints), keyed by client IP.
    2. ``query_run``: ``POST`` to a path ending in one of its suffixes.
    3. ``global``: every non-exempt request.

    ``/health``, ``/docs*`` and ``/openapi*`` are exempt, and setting the
    environment variable ``RATE_LIMIT_DISABLED=1`` disables limiting
    entirely (used by E2E tests).
    """

    _TOO_MANY_REQUESTS = "请求过于频繁,请稍后再试"
    _TOO_MANY_QUERY_RUNS = "查询执行过于频繁,请稍后再试"

    def __init__(self, app):
        super().__init__(app)
        self._backend = _create_backend()
        # Limiting rules; windows are in seconds.
        self.rules = {
            # login/register: strict, 5 requests per minute per IP.
            "auth_strict": {
                "paths": ["/api/v1/auth/login", "/api/v1/auth/register"],
                "max_requests": 5,
                "window_seconds": 60,
            },
            # Other listed auth endpoints: 20 requests per minute per IP.
            "auth": {
                "paths": ["/api/v1/auth/forgot-password", "/api/v1/auth/refresh"],
                "max_requests": 20,
                "window_seconds": 60,
            },
            # Query execution: 10 per hour; "paths" are matched as suffixes.
            "query_run": {
                "paths": ["/run-now"],
                "max_requests": 10,
                "window_seconds": 3600,
            },
            # Applies to every non-exempt request.
            "global": {
                "max_requests": 100,
                "window_seconds": 60,
            },
        }

    async def _enforce(self, rule_name: str, key: str, now: float, message: str) -> JSONResponse | None:
        """Apply rule ``rule_name`` to ``key``; return a 429 response if limited, else None."""
        rule = self.rules[rule_name]
        window_seconds = rule["window_seconds"]
        if not await self._backend.is_rate_limited(key, now, rule["max_requests"], window_seconds):
            return None
        return JSONResponse(
            status_code=429,
            content={"detail": message},
            # A client that waits a full window without retrying is under the limit again.
            headers={"Retry-After": str(window_seconds)},
        )

    async def dispatch(self, request: Request, call_next):
        # E2E test environments disable rate limiting entirely.
        if os.getenv("RATE_LIMIT_DISABLED") == "1":
            return await call_next(request)

        path = request.url.path
        # Health checks and API docs are never rate limited.
        if path == "/health" or path.startswith(("/docs", "/openapi")):
            return await call_next(request)

        # Started here rather than in __init__: a request always runs inside the
        # event loop, which may not exist yet while the middleware is built.
        # start_cleanup() is a no-op when the task is already running.
        if isinstance(self._backend, MemoryRateLimitBackend):
            self._backend.start_cleanup()

        client_ip = request.client.host if request.client else "unknown"
        now = time.time()
        user_id = _extract_user_id_from_request(request)
        # Per-user rules key on user id + IP when authenticated, else IP only.
        identity = f"{user_id}:{client_ip}" if user_id else client_ip
        # Ignore trailing slashes so "/login/" cannot sidestep the "/login" rule.
        match_path = path.rstrip("/") or "/"

        checks: list[tuple[str, str, str]] = []
        if match_path in self.rules["auth_strict"]["paths"]:
            checks.append(("auth_strict", f"auth_strict:{client_ip}", self._TOO_MANY_REQUESTS))
        elif match_path in self.rules["auth"]["paths"]:
            checks.append(("auth", f"auth:{client_ip}", self._TOO_MANY_REQUESTS))
        if request.method == "POST" and match_path.endswith(tuple(self.rules["query_run"]["paths"])):
            checks.append(("query_run", f"query_run:{identity}", self._TOO_MANY_QUERY_RUNS))
        checks.append(("global", f"global:{identity}", self._TOO_MANY_REQUESTS))

        for rule_name, key, message in checks:
            response = await self._enforce(rule_name, key, now, message)
            if response is not None:
                return response
        return await call_next(request)