"""
基于内存的简易限流中间件(MVP不依赖Redis)

Simple in-memory rate-limiting middleware (MVP; does not depend on Redis).
"""

import math
import time
from collections import defaultdict

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP sliding-window rate limiter whose state lives in this instance's memory.

    Default rules (override any of them by name via ``rules=``):

    * ``auth``      - exact-path match on login/register/forgot-password, 5 requests / 60 s.
    * ``query_run`` - POST requests whose path ends with a configured suffix
      (``/run-now``), 10 requests / 3600 s.
    * ``global``    - every non-exempt request, 100 requests / 60 s.

    ``/health``, ``/docs*`` and ``/openapi*`` are never limited. A request over a limit
    gets a 429 JSON response with a ``Retry-After`` header.

    NOTE: clients are keyed on ``request.client.host`` ("unknown" when absent). Behind a
    reverse proxy this is presumably the proxy's address; confirm against the deployment.
    NOTE: counters are per instance, so separate worker processes each count independently.
    There is no locking: this relies on the check-and-record in ``dispatch`` running
    without an ``await`` in between on a single event-loop thread.
    """

    # How often (seconds) idle buckets are purged, so memory is bounded by recently
    # active clients rather than by every IP ever seen.
    _SWEEP_INTERVAL_SECONDS = 60.0

    def __init__(self, app, *, rules=None):
        """Create the middleware.

        Args:
            app: The downstream ASGI app, passed through to ``BaseHTTPMiddleware``.
            rules: Optional mapping of rule name -> rule dict that replaces the default
                rule of the same name. ``dispatch`` reads the "auth", "query_run" and
                "global" rules, so overrides must keep their ``max_requests`` /
                ``window_seconds`` (and ``paths`` for auth/query_run) keys.
        """
        super().__init__(app)
        # {"<rule name>:<client ip>": [monotonic timestamp, ...]} - one float per accepted hit.
        self._requests = defaultdict(list)
        # Monotonic time of the last idle-bucket sweep.
        self._last_sweep = 0.0

        # Rate-limit rules.
        self.rules = {
            "auth": {  # /api/v1/auth/login, register, forgot-password
                "paths": ["/api/v1/auth/login", "/api/v1/auth/register", "/api/v1/auth/forgot-password"],
                "max_requests": 5,
                "window_seconds": 60,
            },
            "query_run": {  # run-now
                "paths": ["/run-now"],  # matched with endswith
                "max_requests": 10,
                "window_seconds": 3600,
            },
            "global": {
                "max_requests": 100,
                "window_seconds": 60,
            },
        }
        if rules:
            self.rules.update(rules)

    async def dispatch(self, request: Request, call_next):
        """Apply the auth, run-now and global limits in turn; reply 429 on the first one exceeded."""
        client_ip = request.client.host if request.client else "unknown"
        path = request.url.path
        # Monotonic clock: wall-clock jumps (NTP, manual changes) would otherwise
        # stretch or collapse the sliding windows.
        now = time.monotonic()

        # Health checks and API docs are never rate limited.
        if path == "/health" or path.startswith(("/docs", "/openapi")):
            return await call_next(request)

        if now - self._last_sweep >= self._SWEEP_INTERVAL_SECONDS:
            self._sweep(now)

        # Authentication endpoints: exact path match.
        auth_rule = self.rules["auth"]
        if path in auth_rule["paths"]:
            key = f"auth:{client_ip}"
            if self._is_rate_limited(key, now, auth_rule):
                return self._too_many_requests(key, now, auth_rule, "请求过于频繁,请稍后再试")

        # Query execution: POST only, suffix match on the configured paths.
        run_rule = self.rules["query_run"]
        if request.method == "POST" and path.endswith(tuple(run_rule["paths"])):
            key = f"query_run:{client_ip}"
            if self._is_rate_limited(key, now, run_rule):
                return self._too_many_requests(key, now, run_rule, "查询执行过于频繁,请稍后再试")

        # Global limit for every non-exempt request.
        # NOTE: a request that passed a specific rule above has already consumed that
        # quota even if the global rule rejects it here.
        global_rule = self.rules["global"]
        key = f"global:{client_ip}"
        if self._is_rate_limited(key, now, global_rule):
            return self._too_many_requests(key, now, global_rule, "请求过于频繁,请稍后再试")

        return await call_next(request)

    def _is_rate_limited(self, key: str, now: float, rule: dict) -> bool:
        """Return True if ``key`` is at ``rule``'s limit; otherwise record this hit and return False.

        Rejected requests are not recorded, so they do not extend the block.
        """
        window = rule["window_seconds"]
        max_req = rule["max_requests"]

        # Keep only hits still inside the sliding window.
        recent = [t for t in self._requests[key] if now - t < window]
        self._requests[key] = recent

        if len(recent) >= max_req:
            return True

        recent.append(now)
        return False

    def _retry_after(self, key: str, now: float, rule: dict) -> int:
        """Whole seconds (at least 1) until the oldest recorded hit for ``key`` leaves the window."""
        stamps = self._requests.get(key)  # .get: don't create a bucket as a side effect
        oldest = min(stamps) if stamps else now
        return max(1, math.ceil(oldest + rule["window_seconds"] - now))

    def _too_many_requests(self, key: str, now: float, rule: dict, detail: str) -> JSONResponse:
        """Build the 429 response for ``key``, including a ``Retry-After`` hint."""
        return JSONResponse(
            status_code=429,
            content={"detail": detail},
            headers={"Retry-After": str(self._retry_after(key, now, rule))},
        )

    def _sweep(self, now: float) -> None:
        """Delete buckets with no hit inside their rule's window.

        Such buckets would be filtered to empty on their next use anyway, so dropping
        them does not change any limiting decision; it only frees memory.
        """
        self._last_sweep = now
        # Keys whose prefix names no current rule fall back to the widest window.
        fallback = max((r["window_seconds"] for r in self.rules.values()), default=0)
        stale = []
        for key, stamps in self._requests.items():
            rule = self.rules.get(key.partition(":")[0])  # key is "<rule name>:<client ip>"
            window = rule["window_seconds"] if rule else fallback
            if not stamps or now - max(stamps) >= window:
                stale.append(key)
        for key in stale:
            del self._requests[key]