geo/backend/app/middleware/rate_limit.py

126 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
基于内存的简易限流中间件MVP不依赖Redis
"""
import math
import time
from collections import defaultdict

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
def _extract_user_id_from_request(request: Request) -> str | None:
    """Best-effort: return the JWT ``sub`` claim from the Authorization header.

    Used only to choose a rate-limit bucket. Any failure yields None and the
    caller falls back to per-IP limiting. Failures include a missing or
    malformed header, an empty token, ``verify_token`` raising, or the auth
    module failing to import. A failure never affects the request itself.
    """
    # Split "<scheme> <credentials>". The auth scheme is case-insensitive
    # (RFC 9110 §11.1), so "bearer xyz" is accepted as well as "Bearer xyz".
    scheme, _, token = request.headers.get("authorization", "").partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return None
    try:
        # Imported lazily, presumably to avoid an import cycle with the auth
        # service -- NOTE(review): confirm.
        from app.services.auth import verify_token

        payload = verify_token(token)
        # NOTE(review): assumes verify_token returns a mapping of JWT claims.
        user_id: str | None = payload.get("sub")
    except Exception:
        # Deliberately best-effort: rate limiting must never break a request.
        return None
    return user_id
class RateLimitMiddleware(BaseHTTPMiddleware):
    """In-memory sliding-window rate limiter (MVP: no Redis).

    Each rule keeps, per bucket key, the timestamps of the requests it
    accepted inside its window. A request is rejected with HTTP 429 and a
    ``Retry-After`` header as soon as any rule it falls under is full.

    Only requests that pass *every* applicable rule are recorded, so a
    rejected request never eats into another rule's quota.

    Limitations: counters live in this process's memory. Each worker process
    therefore enforces its own limits, and all counters reset on restart.
    """

    # Detail messages returned in the 429 response body.
    _DETAIL_DEFAULT = "请求过于频繁,请稍后再试"
    _DETAIL_QUERY_RUN = "查询执行过于频繁,请稍后再试"
    # Minimum number of seconds between two sweeps of idle buckets.
    _SWEEP_INTERVAL_SECONDS = 60

    def __init__(self, app):
        super().__init__(app)
        # {bucket key: [accepted-request timestamps, oldest first]}.
        # Keys look like "<rule name>:<identity>". Timestamps come from
        # time.monotonic().
        self._requests = defaultdict(list)
        self._last_sweep = time.monotonic()
        # Rate-limit rules.
        self.rules = {
            # Login / register: strict, 5 requests per minute per IP.
            "auth_strict": {
                "paths": ["/api/v1/auth/login", "/api/v1/auth/register"],
                "max_requests": 5,
                "window_seconds": 60,
            },
            # Remaining listed auth endpoints: 20 requests per minute per IP.
            "auth": {
                "paths": ["/api/v1/auth/forgot-password", "/api/v1/auth/refresh"],
                "max_requests": 20,
                "window_seconds": 60,
            },
            # POST to any path ending with one of these suffixes: 10 per hour.
            "query_run": {
                "paths": ["/run-now"],
                "max_requests": 10,
                "window_seconds": 3600,
            },
            # Applied to every non-exempt request: 100 per minute.
            "global": {
                "max_requests": 100,
                "window_seconds": 60,
            },
        }

    async def dispatch(self, request: Request, call_next):
        """Apply every matching rule; return a 429 response or forward the request."""
        # NOTE(review): behind a reverse proxy, request.client.host is the
        # proxy's address unless the ASGI server trusts proxy headers
        # (e.g. uvicorn --proxy-headers). Confirm the deployment setup.
        client_ip = request.client.host if request.client else "unknown"
        path = request.url.path
        # Monotonic clock: window arithmetic must not jump with wall-clock changes.
        now = time.monotonic()

        # Health check and API docs are never rate limited.
        if path == "/health" or path.startswith(("/docs", "/openapi")):
            return await call_next(request)

        # Authenticated callers are bucketed per (user, IP), anonymous ones per IP.
        user_id = _extract_user_id_from_request(request)
        identity = f"{user_id}:{client_ip}" if user_id else client_ip

        # Collect (rule name, bucket key, 429 detail) for every rule this
        # request falls under. Order is priority: the first full rule decides
        # the response message.
        checks = []
        if path in self.rules["auth_strict"]["paths"]:
            checks.append(
                ("auth_strict", f"auth_strict:{client_ip}", self._DETAIL_DEFAULT)
            )
        elif path in self.rules["auth"]["paths"]:
            checks.append(("auth", f"auth:{client_ip}", self._DETAIL_DEFAULT))
        if request.method == "POST" and any(
            path.endswith(suffix) for suffix in self.rules["query_run"]["paths"]
        ):
            checks.append(
                ("query_run", f"query_run:{identity}", self._DETAIL_QUERY_RUN)
            )
        checks.append(("global", f"global:{identity}", self._DETAIL_DEFAULT))

        self._sweep_idle_buckets(now)

        for rule_name, key, detail in checks:
            wait = self._seconds_until_allowed(key, now, self.rules[rule_name])
            if wait is not None:
                return self._too_many_requests(detail, wait)
        # Record only after every applicable rule has accepted the request.
        for _rule_name, key, _detail in checks:
            self._requests[key].append(now)
        return await call_next(request)

    def _seconds_until_allowed(self, key, now, rule):
        """Prune expired entries from bucket ``key`` and check it against ``rule``.

        Returns None if one more request fits in the window. Otherwise returns
        the number of seconds until it would. Nothing is recorded.
        """
        window = rule["window_seconds"]
        max_req = rule["max_requests"]
        bucket = [t for t in self._requests.get(key, ()) if now - t < window]
        if bucket:
            self._requests[key] = bucket
        else:
            # Drop empty buckets so idle keys do not accumulate.
            self._requests.pop(key, None)
        if len(bucket) < max_req:
            return None
        if max_req <= 0:
            return float(window)
        # Timestamps are appended in order. The entry that must expire for
        # the count to drop below max_req is bucket[len - max_req].
        return window - (now - bucket[len(bucket) - max_req])

    def _is_rate_limited(self, key, now, rule):
        """Check and record in one step for a single rule.

        Returns True if ``key`` is already at the rule's limit. Otherwise
        records ``now`` in the bucket and returns False.
        """
        if self._seconds_until_allowed(key, now, rule) is not None:
            return True
        self._requests[key].append(now)
        return False

    def _sweep_idle_buckets(self, now):
        """At most every _SWEEP_INTERVAL_SECONDS, delete buckets with no live entries.

        Without this sweep, every distinct client IP or user ever seen would
        keep a key in self._requests for the life of the process.
        """
        if now - self._last_sweep < self._SWEEP_INTERVAL_SECONDS:
            return
        self._last_sweep = now
        longest = max(rule["window_seconds"] for rule in self.rules.values())
        stale = []
        for key, stamps in self._requests.items():
            # The key prefix is the rule name. Unknown prefixes use the longest window.
            rule = self.rules.get(key.split(":", 1)[0])
            window = rule["window_seconds"] if rule else longest
            if not stamps or now - stamps[-1] >= window:
                stale.append(key)
        for key in stale:
            del self._requests[key]

    @staticmethod
    def _too_many_requests(detail, retry_after):
        """Build the 429 response. Retry-After is whole seconds, at least 1."""
        return JSONResponse(
            status_code=429,
            content={"detail": detail},
            headers={"Retry-After": str(max(1, math.ceil(retry_after)))},
        )