126 lines
4.5 KiB
Python
126 lines
4.5 KiB
Python
"""
|
||
Simple in-memory rate-limiting middleware (the MVP does not depend on Redis).
|
||
"""
|
||
import time
|
||
from collections import defaultdict
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from starlette.requests import Request
|
||
from starlette.responses import JSONResponse
|
||
|
||
|
||
def _extract_user_id_from_request(request: Request) -> str | None:
    """Best-effort extraction of the user id (JWT ``sub`` claim) from a request.

    Reads the ``Authorization`` header, requires the ``Bearer`` scheme and
    hands the token to ``app.services.auth.verify_token``.  The scheme name is
    compared case-insensitively because HTTP authentication scheme names are
    case-insensitive, so ``bearer <token>`` is treated like ``Bearer <token>``.

    Args:
        request: The incoming request; only ``request.headers`` is read.

    Returns:
        The ``sub`` entry of the verified payload, or ``None`` when the header
        is missing, uses another scheme, carries no token, or the token cannot
        be verified.  Verification failures never propagate to the caller.
    """
    auth_header = request.headers.get("authorization", "")
    scheme, _, credentials = auth_header.partition(" ")
    if scheme.lower() != "bearer":
        return None

    # Tolerate extra whitespace between the scheme and the token.
    token = credentials.strip()
    if not token:
        return None

    try:
        # Imported inside the ``try`` so that an import failure also degrades
        # to "anonymous" (presumably the import is local to avoid an import
        # cycle -- confirm against app.services.auth).
        from app.services.auth import verify_token

        payload = verify_token(token)
        user_id: str | None = payload.get("sub")
        return user_id
    except Exception:
        # Deliberately broad: identifying the user is optional for rate
        # limiting, so any failure here simply means "treat as anonymous".
        return None
|
||
|
||
|
||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter that keeps its counters in process memory.

    Rules applied by :meth:`dispatch` (first rejection wins):

    * ``auth_strict`` -- login / register: 5 requests per 60 s per client IP.
    * ``auth`` -- forgot-password / refresh: 20 requests per 60 s per client IP.
    * ``query_run`` -- ``POST`` to a path ending in ``/run-now``: 10 requests
      per 3600 s per user id + IP (IP only for unauthenticated requests).
    * ``global`` -- every non-exempt request: 100 requests per 60 s per
      user id + IP (IP only for unauthenticated requests).

    ``/health`` and paths starting with ``/docs`` or ``/openapi`` are exempt.
    Rejected requests receive a 429 JSON response and are not recorded.
    """

    # Minimum number of seconds between two sweeps that drop idle keys.
    _SWEEP_INTERVAL_SECONDS = 60
    # Time of the last sweep.  Kept as a class-level default so the attribute
    # exists even on instances created without running ``__init__``.
    _last_sweep = 0.0

    def __init__(self, app):
        """Wrap *app* and set up the empty request log and the limit rules."""
        super().__init__(app)
        # {key: [timestamp, ...]} -- ``time.time()`` values of the requests
        # accepted under each key, in arrival order.
        self._requests = defaultdict(list)

        # Rate-limit rules.
        self.rules = {
            # Login / register: strict limit of 5 requests per minute per IP.
            "auth_strict": {
                "paths": ["/api/v1/auth/login", "/api/v1/auth/register"],
                "max_requests": 5,
                "window_seconds": 60,
            },
            # The remaining rate-limited /api/v1/auth/ endpoints.
            "auth": {
                "paths": ["/api/v1/auth/forgot-password", "/api/v1/auth/refresh"],
                "max_requests": 20,
                "window_seconds": 60,
            },
            # Query execution ("run-now"); paths are matched as suffixes.
            "query_run": {
                "paths": ["/run-now"],
                "max_requests": 10,
                "window_seconds": 3600,
            },
            "global": {
                "max_requests": 100,
                "window_seconds": 60,
            },
        }

    async def dispatch(self, request: Request, call_next):
        """Apply the rate-limit rules to *request*.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, or a 429 ``JSONResponse`` when one of the
            rules rejects the request.
        """
        path = request.url.path

        # Health check and API documentation are never rate limited.
        if path == "/health" or path.startswith(("/docs", "/openapi")):
            return await call_next(request)

        client_ip = request.client.host if request.client else "unknown"
        now = time.time()

        # Without this, every client ever seen would keep an entry in
        # ``_requests`` forever, because a key is otherwise only pruned when
        # that same key is seen again.
        self._sweep_idle_keys(now)

        # Per-IP limits on the authentication endpoints.
        strict_rule = self.rules["auth_strict"]
        auth_rule = self.rules["auth"]
        if path in strict_rule["paths"]:
            if self._is_rate_limited(f"auth_strict:{client_ip}", now, strict_rule):
                return self._too_many_requests("请求过于频繁,请稍后再试")
        elif path in auth_rule["paths"]:
            if self._is_rate_limited(f"auth:{client_ip}", now, auth_rule):
                return self._too_many_requests("请求过于频繁,请稍后再试")

        # The user id is only needed from here on, so token verification is
        # skipped for requests already rejected by the per-IP checks above.
        user_id = _extract_user_id_from_request(request)
        # Authenticated requests are tracked per user id + IP, anonymous
        # requests per IP only.
        subject = f"{user_id}:{client_ip}" if user_id else client_ip

        # Query execution limit.
        run_rule = self.rules["query_run"]
        if request.method == "POST" and path.endswith(tuple(run_rule["paths"])):
            if self._is_rate_limited(f"query_run:{subject}", now, run_rule):
                return self._too_many_requests("查询执行过于频繁,请稍后再试")

        # Global limit.
        if self._is_rate_limited(f"global:{subject}", now, self.rules["global"]):
            return self._too_many_requests("请求过于频繁,请稍后再试")

        return await call_next(request)

    @staticmethod
    def _too_many_requests(detail):
        """Build the 429 JSON response carrying *detail* as its message."""
        return JSONResponse(status_code=429, content={"detail": detail})

    def _sweep_idle_keys(self, now):
        """Drop keys with no request inside the longest configured window.

        Runs at most once every ``_SWEEP_INTERVAL_SECONDS``.  A key whose
        newest timestamp is at least as old as the longest window in
        ``self.rules`` would be pruned to an empty list by any of those rules,
        so removing it does not change any limiting decision.
        """
        if now - self._last_sweep < self._SWEEP_INTERVAL_SECONDS:
            return
        self._last_sweep = now

        horizon = max(rule["window_seconds"] for rule in self.rules.values())
        # Timestamps are appended in arrival order, so the last one is the
        # most recent request recorded under that key.
        idle_keys = [
            key
            for key, stamps in self._requests.items()
            if not stamps or now - stamps[-1] >= horizon
        ]
        for key in idle_keys:
            del self._requests[key]

    def _is_rate_limited(self, key, now, rule):
        """Check *key* against *rule* and record the request when allowed.

        Args:
            key: Identifier of the counter to check.
            now: Current time in seconds, as used for the stored timestamps.
            rule: Mapping with ``window_seconds`` and ``max_requests``.

        Returns:
            ``True`` when the key already has ``max_requests`` or more
            timestamps inside the window (the request is then not recorded),
            ``False`` otherwise (the request is recorded at ``now``).
        """
        window = rule["window_seconds"]
        max_req = rule["max_requests"]

        # Discard timestamps that have fallen out of the window.
        recent = [t for t in self._requests[key] if now - t < window]
        self._requests[key] = recent

        if len(recent) >= max_req:
            return True

        recent.append(now)
        return False
|