"""Simple in-memory rate-limiting middleware (MVP, no Redis dependency).

All counters live in this process's memory: limits apply per worker process
and are lost on restart.
"""
import time
from collections import defaultdict, deque

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse


def _extract_user_id_from_request(request: Request) -> str | None:
    """Try to read the user id (JWT ``sub`` claim) from the Authorization header.

    Best-effort: a missing/malformed header, an invalid token or a missing
    claim all yield ``None`` and never affect the main request flow.

    Returns:
        The ``sub`` claim of a ``Bearer`` token accepted by ``verify_token``,
        or ``None``.
    """
    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    # HTTP auth scheme names are case-insensitive, so accept "bearer" too.
    if scheme.lower() != "bearer":
        return None
    token = token.strip()
    if not token:
        return None
    try:
        # Imported lazily; presumably to avoid an import cycle with app.services
        # (NOTE(review): confirm).
        from app.services.auth import verify_token

        payload = verify_token(token)
        user_id: str | None = payload.get("sub")
        return user_id
    except Exception:
        # Deliberately swallowed: failing to identify the user only means the
        # request is rate limited by IP alone.
        return None


def _too_many_requests(detail: str) -> JSONResponse:
    """Build the HTTP 429 response returned when a limit is exceeded."""
    return JSONResponse(status_code=429, content={"detail": detail})


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter keyed by client IP and, when known, user id.

    Rules (``self.rules``):

    * ``auth_strict``: login/register, 5 requests / 60 s per IP.
    * ``auth``: the other listed auth endpoints, 20 requests / 60 s per IP.
    * ``query_run``: ``POST`` to paths ending in ``/run-now``, 10 requests /
      3600 s per user+IP (per IP for anonymous requests).
    * ``global``: every non-exempt request (including the ones above),
      100 requests / 60 s per user+IP (per IP for anonymous requests).

    ``/health``, ``/docs*`` and ``/openapi*`` are never limited. Rejected
    requests are not recorded, so they do not extend a client's lockout.
    """

    # Minimum number of seconds between purges of idle keys from ``_requests``.
    _SWEEP_INTERVAL_SECONDS = 60

    def __init__(self, app):
        super().__init__(app)
        # {key: deque of time.monotonic() stamps of accepted requests, oldest first}
        self._requests: defaultdict[str, deque[float]] = defaultdict(deque)
        self._last_sweep = time.monotonic()
        # Rate-limit rules
        self.rules = {
            "auth_strict": {
                # /api/v1/auth/login, register - strict: 5 requests/minute/IP
                "paths": ["/api/v1/auth/login", "/api/v1/auth/register"],
                "max_requests": 5,
                "window_seconds": 60,
            },
            "auth": {
                # Remaining /api/v1/auth/ endpoints
                "paths": ["/api/v1/auth/forgot-password", "/api/v1/auth/refresh"],
                "max_requests": 20,
                "window_seconds": 60,
            },
            "query_run": {
                # run-now
                "paths": ["/run-now"],  # matched with endswith
                "max_requests": 10,
                "window_seconds": 3600,
            },
            "global": {
                "max_requests": 100,
                "window_seconds": 60,
            },
        }

    async def dispatch(self, request: Request, call_next):
        client_ip = request.client.host if request.client else "unknown"
        path = request.url.path
        # Monotonic clock: wall-clock adjustments must not shrink or stretch windows.
        now = time.monotonic()

        # Health checks and API docs are never rate limited
        if path == "/health" or path.startswith(("/docs", "/openapi")):
            return await call_next(request)

        self._maybe_sweep(now)

        user_id = _extract_user_id_from_request(request)

        # Strict auth endpoints (login/register): 5 requests/minute/IP
        if path in self.rules["auth_strict"]["paths"]:
            key = f"auth_strict:{client_ip}"
            if self._is_rate_limited(key, now, self.rules["auth_strict"]):
                return _too_many_requests("请求过于频繁,请稍后再试")
        # Other auth endpoints
        elif path in self.rules["auth"]["paths"]:
            key = f"auth:{client_ip}"
            if self._is_rate_limited(key, now, self.rules["auth"]):
                return _too_many_requests("请求过于频繁,请稍后再试")

        # Query execution, keyed by user id + IP (IP only when anonymous)
        run_rule = self.rules["query_run"]
        if request.method == "POST" and path.endswith(tuple(run_rule["paths"])):
            key = f"query_run:{user_id}:{client_ip}" if user_id else f"query_run:{client_ip}"
            if self._is_rate_limited(key, now, run_rule):
                return _too_many_requests("查询执行过于频繁,请稍后再试")

        # Global limit, keyed by user id + IP (IP only when anonymous)
        key = f"global:{user_id}:{client_ip}" if user_id else f"global:{client_ip}"
        if self._is_rate_limited(key, now, self.rules["global"]):
            return _too_many_requests("请求过于频繁,请稍后再试")

        return await call_next(request)

    def _maybe_sweep(self, now: float) -> None:
        """Drop keys that hold no timestamp inside any rule's window.

        Without this, every distinct IP / user+IP key ever seen would stay in
        ``_requests`` forever. Runs at most once per ``_SWEEP_INTERVAL_SECONDS``.
        """
        if now - self._last_sweep < self._SWEEP_INTERVAL_SECONDS:
            return
        self._last_sweep = now
        # No rule looks back further than the longest window, so a key whose
        # newest stamp is that old is empty for whichever rule it belongs to.
        horizon = max(rule["window_seconds"] for rule in self.rules.values())
        stale = [
            key
            for key, stamps in self._requests.items()
            if not stamps or now - stamps[-1] >= horizon
        ]
        for key in stale:
            del self._requests[key]

    def _is_rate_limited(self, key, now, rule):
        """Check ``key`` against ``rule`` and record the request if allowed.

        Args:
            key: Counter key (rule name plus client identity).
            now: Current ``time.monotonic()`` value.
            rule: Dict with ``max_requests`` and ``window_seconds``.

        Returns:
            True if ``key`` already has ``max_requests`` accepted requests in
            the last ``window_seconds`` (nothing is recorded); otherwise
            records ``now`` and returns False.
        """
        window = rule["window_seconds"]
        stamps = self._requests[key]
        # Stamps are appended in non-decreasing order (monotonic clock, no await
        # between reading it and recording), so expired ones sit at the left end.
        while stamps and now - stamps[0] >= window:
            stamps.popleft()
        if len(stamps) >= rule["max_requests"]:
            return True
        stamps.append(now)
        return False