geo/backend/app/middleware/rate_limit.py

83 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Simple in-memory rate-limiting middleware (MVP; does not depend on Redis).
"""
import time
from collections import defaultdict
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Sliding-window rate limiter keyed by client IP, kept in process memory.

    Request timestamps are stored in a dict on the middleware instance, so
    the counters are local to this process and start empty on every restart.

    Three rules are applied, in this order, and the first one that is
    exhausted short-circuits with an HTTP 429 JSON response:

    * ``auth``      - exact match on the configured auth endpoints.
    * ``query_run`` - POST requests whose path ends with a configured suffix.
    * ``global``    - every request that is not exempt.

    ``/health`` and anything under ``/docs`` or ``/openapi`` is never limited.
    """

    # Minimum number of seconds between two sweeps that drop idle client keys.
    _SWEEP_INTERVAL_SECONDS = 60

    def __init__(self, app):
        """Wrap ``app`` and set up empty counters plus the rule table."""
        super().__init__(app)
        # {"<rule name>:<client ip>": [timestamp, ...]}, oldest first.
        # Timestamps come from time.monotonic(), so they are only meaningful
        # relative to each other inside this process.
        self._requests = defaultdict(list)
        self._last_sweep = time.monotonic()
        # Rate-limit rules.
        self.rules = {
            "auth": {  # /api/v1/auth/login, register, forgot-password
                "paths": [
                    "/api/v1/auth/login",
                    "/api/v1/auth/register",
                    "/api/v1/auth/forgot-password",
                ],
                "max_requests": 5,
                "window_seconds": 60,
            },
            "query_run": {  # run-now
                "paths": ["/run-now"],  # matched with str.endswith
                "max_requests": 10,
                "window_seconds": 3600,
            },
            "global": {
                "max_requests": 100,
                "window_seconds": 60,
            },
        }

    async def dispatch(self, request: Request, call_next):
        """Apply the matching rules to ``request`` or forward it downstream.

        Returns the downstream response, or a 429 ``JSONResponse`` with a
        ``Retry-After`` header when one of the rules is exhausted.
        """
        client_ip = request.client.host if request.client else "unknown"
        path = request.url.path
        # Monotonic clock: wall-clock adjustments must not stretch or shrink
        # the rate-limit windows.
        now = time.monotonic()

        # Health check and API docs are never rate limited.
        if path == "/health" or path.startswith(("/docs", "/openapi")):
            return await call_next(request)

        self._sweep_stale_keys(now)

        # Ordered (rule name, 429 detail message) pairs that apply here.
        checks = []
        if path in self.rules["auth"]["paths"]:
            checks.append(("auth", "请求过于频繁,请稍后再试"))
        run_suffixes = tuple(self.rules["query_run"]["paths"])
        if request.method == "POST" and path.endswith(run_suffixes):
            checks.append(("query_run", "查询执行过于频繁,请稍后再试"))
        checks.append(("global", "请求过于频繁,请稍后再试"))

        for rule_name, detail in checks:
            rule = self.rules[rule_name]
            key = f"{rule_name}:{client_ip}"
            if self._is_rate_limited(key, now, rule):
                retry_after = self._retry_after(key, now, rule)
                return JSONResponse(
                    status_code=429,
                    content={"detail": detail},
                    headers={"Retry-After": str(retry_after)},
                )
        return await call_next(request)

    def _is_rate_limited(self, key, now, rule):
        """Record a hit for ``key`` unless its window is already full.

        Args:
            key: Counter key, ``"<rule name>:<client ip>"``.
            now: Current timestamp, on the same clock as the stored ones.
            rule: Mapping with ``window_seconds`` and ``max_requests``.

        Returns:
            True when the request must be rejected (nothing is recorded in
            that case), False when it was accepted and recorded.
        """
        window = rule["window_seconds"]
        max_req = rule["max_requests"]
        # Drop timestamps that have slid out of the window.
        recent = [t for t in self._requests[key] if now - t < window]
        self._requests[key] = recent
        if len(recent) >= max_req:
            return True
        recent.append(now)
        return False

    def _retry_after(self, key, now, rule):
        """Return whole seconds (>= 1) until the oldest hit for ``key`` expires."""
        stamps = self._requests.get(key)
        if not stamps:
            return 1
        remaining = rule["window_seconds"] - (now - stamps[0])
        # Round up so a client that waits this long is not rejected again.
        return max(1, int(remaining) + 1)

    def _sweep_stale_keys(self, now):
        """Delete counters whose newest hit is outside its rule's window.

        Without this, every client IP ever seen would keep an entry in
        ``_requests`` forever. Runs at most once per
        ``_SWEEP_INTERVAL_SECONDS`` to keep the per-request cost low.
        """
        if now - self._last_sweep < self._SWEEP_INTERVAL_SECONDS:
            return
        self._last_sweep = now
        # Used for keys whose prefix does not name a known rule.
        longest_window = max(r["window_seconds"] for r in self.rules.values())
        for key in list(self._requests):
            # Keys are "<rule name>:<ip>"; split once so IPv6 colons survive.
            rule = self.rules.get(key.split(":", 1)[0])
            window = rule["window_seconds"] if rule else longest_window
            stamps = self._requests[key]
            if not stamps or now - stamps[-1] >= window:
                del self._requests[key]