feat: Sentry integration + rate limiter dual backend

- Initialize Sentry SDK in FastAPI (auto-disabled when DSN empty)
- Add sentry_sdk.set_measurement in metrics middleware
- Add sentry-sdk[fastapi] to requirements
- Refactor rate_limit.py: abstract RateLimitBackend + MemoryBackend + RedisBackend
- Redis backend uses sorted set sliding window with pipeline atomicity
- Memory backend adds asyncio cleanup task to prevent memory growth
- Auto-fallback to memory when Redis unavailable
- Add RATE_LIMIT_BACKEND config (default: memory)
This commit is contained in:
chiguyong 2026-06-04 14:04:36 +08:00
parent ee8578c3d7
commit 3737a90471
5 changed files with 213 additions and 26 deletions

View File

@ -13,7 +13,7 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=str(_env_path), extra="ignore")
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres123@db:5432/geo_platform"
REDIS_URL: str = "redis://redis:6379/0"
REDIS_URL: str = ""
# JWT 密钥:必须通过环境变量设置,不提供任何默认值
JWT_SECRET: str
@ -100,6 +100,13 @@ class Settings(BaseSettings):
ATTRIBUTION_WINDOW_DAYS: int = 28
# Sentry Monitoring
SENTRY_DSN: str = ""
ENVIRONMENT: str = "development"
# Rate Limiting
RATE_LIMIT_BACKEND: str = "memory" # "memory" or "redis"
@field_validator("JWT_SECRET")
@classmethod
def validate_jwt_secret(cls, v: str) -> str:

View File

@ -13,6 +13,16 @@ from sqlalchemy import text
from app.logging_config import setup_logging
setup_logging()
# Initialise the Sentry SDK. The guard below skips initialisation entirely
# when SENTRY_DSN is empty, so reporting is off unless a DSN is configured.
import sentry_sdk
from app.config import settings as _sentry_settings
if _sentry_settings.SENTRY_DSN:
    sentry_sdk.init(
        dsn=_sentry_settings.SENTRY_DSN,
        # Sampling rate handed to the SDK for performance traces.
        traces_sample_rate=0.1,
        # Environment tag taken from the application settings.
        environment=_sentry_settings.ENVIRONMENT,
    )
from fastapi.middleware.cors import CORSMiddleware
from app.api.admin import router as admin_router

View File

@ -28,7 +28,7 @@ class MetricsMiddleware(BaseHTTPMiddleware):
"""记录每个 HTTP 请求的耗时,并:
- 在响应头写入 X-Response-Time
- 对超过阈值的慢请求输出 WARNING 日志携带结构化字段
- 预留 Sentry 集成点TODO 注释标注
- 预留 Sentry 集成点已集成
"""
async def dispatch(self, request: Request, call_next) -> Response:
@ -61,8 +61,12 @@ class MetricsMiddleware(BaseHTTPMiddleware):
else:
logger.debug("Request completed", extra=log_extra)
# TODO: 集成 Sentry 性能监控
# if sentry_sdk: sentry_sdk.set_measurement("response_time_ms", duration_ms)
# Sentry 性能监控
try:
import sentry_sdk
sentry_sdk.set_measurement("response_time_ms", duration_ms)
except Exception:
pass
return response

View File

@ -1,12 +1,156 @@
"""
基于内存的简易限流中间件MVP不依赖Redis
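Rate-limiting middleware with two interchangeable backends: Redis and in-memory.
Production should use the Redis backend so that all instances share one rate-limit state.
Development/testing uses the memory backend; it is also the fallback when RATE_LIMIT_BACKEND=redis but REDIS_URL is empty.
If a Redis call fails while handling a request, the request is allowed rather than rejected.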
"""
import asyncio
import logging
import os
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from app.config import settings
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Backend 抽象接口
# ---------------------------------------------------------------------------
class RateLimitBackend(ABC):
    """Contract that every rate-limit storage backend has to fulfil."""

    @abstractmethod
    async def is_rate_limited(
        self,
        key: str,
        now: float,
        max_requests: int,
        window_seconds: int,
    ) -> bool:
        """Decide whether the request counted under ``key`` must be rejected.

        Args:
            key: Identifier of the bucket being checked.
            now: Timestamp of the current request, in seconds.
            max_requests: Number of requests allowed within one window.
            window_seconds: Length of the sliding window, in seconds.

        Returns:
            ``True`` when the request is rate limited, ``False`` otherwise.
        """

    @abstractmethod
    async def reset(self, key: str) -> None:
        """Discard all rate-limit state stored for ``key``."""
# ---------------------------------------------------------------------------
# 内存后端
# ---------------------------------------------------------------------------
class MemoryRateLimitBackend(RateLimitBackend):
    """In-process sliding-window rate limiter (single instance; dev/test use).

    Every key maps to the list of timestamps of its accepted requests. The
    state is a plain dict on this object, so it is not shared with other
    processes.
    """

    # Seconds between two background cleanup passes.
    _CLEANUP_INTERVAL_SECONDS = 60
    # Minimum age (seconds) before the background cleanup drops a record.
    _MIN_RETENTION_SECONDS = 3600

    def __init__(self) -> None:
        self._requests: dict[str, list[float]] = defaultdict(list)
        self._cleanup_task: asyncio.Task | None = None
        # Set when start_cleanup() was called before an event loop existed.
        self._cleanup_requested: bool = False
        # Largest window seen so far. The cleanup must keep records at least
        # this long, otherwise a rule with a window above the minimum
        # retention would lose entries that still count.
        self._max_window_seconds: float = 0

    def start_cleanup(self) -> None:
        """Start the periodic cleanup task unless it is already running.

        Safe to call without a running event loop (for instance when the
        owner is constructed at import time): instead of raising
        ``RuntimeError`` the request is remembered and the task is started
        by the first ``is_rate_limited`` call.
        """
        if self._cleanup_task is not None and not self._cleanup_task.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            self._cleanup_requested = True
            return
        self._cleanup_requested = False
        self._cleanup_task = loop.create_task(self._cleanup_loop())

    def _purge_expired(self, now: float) -> None:
        """Drop records older than the retention period and remove empty keys."""
        retention = max(self._MIN_RETENTION_SECONDS, self._max_window_seconds)
        # Iterate over a snapshot because keys are deleted along the way.
        for key in list(self._requests):
            fresh = [t for t in self._requests[key] if now - t < retention]
            if fresh:
                self._requests[key] = fresh
            else:
                del self._requests[key]

    async def _cleanup_loop(self) -> None:
        """Run one cleanup pass per interval until the task is cancelled."""
        while True:
            try:
                await asyncio.sleep(self._CLEANUP_INTERVAL_SECONDS)
                self._purge_expired(time.time())
            except asyncio.CancelledError:
                break
            except Exception as exc:
                # Keep the loop alive; a failed pass is retried next interval.
                logger.warning("Rate limit cleanup error: %s", exc)

    async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
        """Record the request under ``key`` unless its window is already full.

        Args:
            key: Identifier of the bucket being checked.
            now: Timestamp of the current request, in seconds.
            max_requests: Number of requests allowed within one window.
            window_seconds: Length of the sliding window, in seconds.

        Returns:
            ``True`` when the window already holds ``max_requests`` entries
            (the request is not recorded), ``False`` otherwise.
        """
        if self._cleanup_requested:
            # Deferred start: a loop is guaranteed to be running here.
            self.start_cleanup()
        if window_seconds > self._max_window_seconds:
            self._max_window_seconds = window_seconds

        recent = [t for t in self._requests[key] if now - t < window_seconds]
        self._requests[key] = recent
        if len(recent) >= max_requests:
            return True
        recent.append(now)
        return False

    async def reset(self, key: str) -> None:
        """Forget every recorded request for ``key``."""
        self._requests.pop(key, None)

    async def close(self) -> None:
        """Cancel the background cleanup task, if one is running."""
        task, self._cleanup_task = self._cleanup_task, None
        self._cleanup_requested = False
        if task is not None and not task.done():
            task.cancel()
# ---------------------------------------------------------------------------
# Redis 后端
# ---------------------------------------------------------------------------
class RedisRateLimitBackend(RateLimitBackend):
    """Redis-backed sliding-window rate limiter (shared across instances).

    Each key is a sorted set whose members are individual requests scored by
    their timestamp. A single transactional pipeline trims entries that left
    the window (ZREMRANGEBYSCORE), records the current request (ZADD), counts
    the window (ZCARD) and refreshes the key TTL (EXPIRE).
    """

    def __init__(self, redis_url: str) -> None:
        self._redis_url = redis_url
        # Client is created lazily by _get_redis(); nothing connects here.
        self._redis = None

    async def _get_redis(self):
        """Return the shared Redis client, creating it on first use."""
        if self._redis is None:
            import redis.asyncio as aioredis
            self._redis = aioredis.from_url(
                self._redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
        return self._redis

    async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
        """Count the request under ``key`` and report whether it is over the limit.

        Args:
            key: Redis key of the sorted set used as the bucket.
            now: Timestamp of the current request, in seconds.
            max_requests: Number of requests allowed within one window.
            window_seconds: Length of the sliding window, in seconds.

        Returns:
            ``True`` when the window holds more than ``max_requests`` entries
            including this one. ``False`` otherwise, and also when Redis
            fails: errors let the request through on purpose.
        """
        import uuid

        try:
            redis = await self._get_redis()
            window_start = now - window_seconds
            # The timestamp alone is not unique: two requests sharing the same
            # ``now`` would collapse into one member and be counted once.
            member = f"{now}:{uuid.uuid4().hex}"
            async with redis.pipeline(transaction=True) as pipe:
                # Drop entries that fell out of the window.
                pipe.zremrangebyscore(key, "-inf", window_start)
                # Record the current request.
                pipe.zadd(key, {member: now})
                # Count requests inside the window (including this one).
                pipe.zcard(key)
                # Expire idle keys after twice the window to avoid leftovers.
                pipe.expire(key, window_seconds * 2)
                results = await pipe.execute()
            count = results[2]  # reply of ZCARD
        except Exception as exc:
            logger.warning("Redis rate limit check failed for key=%s: %s, falling back to allow", key, exc)
            # Fail open so a Redis outage does not take the service down.
            return False

        if count <= max_requests:
            return False

        # Rejected requests must not occupy a slot, matching the memory
        # backend; otherwise a client that keeps retrying never recovers.
        try:
            await redis.zrem(key, member)
        except Exception as exc:
            logger.warning("Redis rate limit rollback failed for key=%s: %s", key, exc)
        return True

    async def reset(self, key: str) -> None:
        """Delete the bucket stored under ``key``; failures are only logged."""
        try:
            redis = await self._get_redis()
            await redis.delete(key)
        except Exception as exc:
            logger.warning("Redis rate limit reset failed for key=%s: %s", key, exc)

    async def close(self) -> None:
        """Close the Redis client if one was created."""
        # Drop the reference first so a failing close cannot leave a
        # half-closed client behind for later calls.
        client, self._redis = self._redis, None
        if client is None:
            return
        # Prefer ``aclose``; fall back to ``close`` for clients without it.
        closer = getattr(client, "aclose", None) or client.close
        await closer()
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def _extract_user_id_from_request(request: Request) -> str | None:
"""尝试从 Authorization header 解析 user_idJWT sub
@ -27,11 +171,41 @@ def _extract_user_id_from_request(request: Request) -> str | None:
return None
def _create_backend() -> RateLimitBackend:
    """Build the rate-limit backend selected by ``settings.RATE_LIMIT_BACKEND``.

    ``"redis"`` (case-insensitive, surrounding whitespace ignored) yields a
    ``RedisRateLimitBackend`` when ``settings.REDIS_URL`` is set; an empty URL
    or an error while constructing the backend falls back to the memory
    backend. Every other value yields the memory backend, with a warning when
    the value is not ``"memory"`` so a typo does not go unnoticed.

    Returns:
        The backend instance the middleware should use.
    """
    backend_type = settings.RATE_LIMIT_BACKEND.strip().lower()

    if backend_type == "redis":
        redis_url = settings.REDIS_URL
        if not redis_url:
            logger.warning("RATE_LIMIT_BACKEND=redis but REDIS_URL is empty, falling back to memory backend")
            return MemoryRateLimitBackend()
        try:
            backend = RedisRateLimitBackend(redis_url)
        except Exception as exc:
            logger.warning("Failed to create Redis rate limit backend: %s, falling back to memory", exc)
            return MemoryRateLimitBackend()
        # Log only what follows the last "@" so credentials in the URL stay
        # out of the logs; URLs without "@" are logged unchanged.
        logger.info("Rate limiter using Redis backend (url=%s)", redis_url.rsplit("@", 1)[-1])
        return backend

    if backend_type != "memory":
        logger.warning(
            "Unknown RATE_LIMIT_BACKEND=%r, falling back to memory backend",
            settings.RATE_LIMIT_BACKEND,
        )
    # Memory backend is the default.
    backend = MemoryRateLimitBackend()
    logger.info("Rate limiter using memory backend")
    return backend
# ---------------------------------------------------------------------------
# 中间件
# ---------------------------------------------------------------------------
class RateLimitMiddleware(BaseHTTPMiddleware):
def __init__(self, app):
super().__init__(app)
# {key: [(timestamp, ...)]}
self._requests = defaultdict(list)
self._backend = _create_backend()
# 如果是内存后端,启动后台清理任务
if isinstance(self._backend, MemoryRateLimitBackend):
self._backend.start_cleanup()
# 限流规则
self.rules = {
@ -57,6 +231,10 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
}
async def dispatch(self, request: Request, call_next):
# E2E 测试环境放宽限流
if os.getenv("RATE_LIMIT_DISABLED") == "1":
return await call_next(request)
client_ip = request.client.host if request.client else "unknown"
path = request.url.path
now = time.time()
@ -71,7 +249,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
# 检查严格限流认证接口login/register5次/分钟/IP
if any(path == p for p in self.rules["auth_strict"]["paths"]):
key = f"auth_strict:{client_ip}"
if self._is_rate_limited(key, now, self.rules["auth_strict"]):
if await self._backend.is_rate_limited(key, now, self.rules["auth_strict"]["max_requests"], self.rules["auth_strict"]["window_seconds"]):
return JSONResponse(
status_code=429,
content={"detail": "请求过于频繁,请稍后再试"}
@ -80,7 +258,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
# 检查普通认证接口限流
elif any(path == p for p in self.rules["auth"]["paths"]):
key = f"auth:{client_ip}"
if self._is_rate_limited(key, now, self.rules["auth"]):
if await self._backend.is_rate_limited(key, now, self.rules["auth"]["max_requests"], self.rules["auth"]["window_seconds"]):
return JSONResponse(
status_code=429,
content={"detail": "请求过于频繁,请稍后再试"}
@ -92,7 +270,7 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
key = f"query_run:{user_id}:{client_ip}"
else:
key = f"query_run:{client_ip}"
if self._is_rate_limited(key, now, self.rules["query_run"]):
if await self._backend.is_rate_limited(key, now, self.rules["query_run"]["max_requests"], self.rules["query_run"]["window_seconds"]):
return JSONResponse(
status_code=429,
content={"detail": "查询执行过于频繁,请稍后再试"}
@ -103,23 +281,10 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
key = f"global:{user_id}:{client_ip}"
else:
key = f"global:{client_ip}"
if self._is_rate_limited(key, now, self.rules["global"]):
if await self._backend.is_rate_limited(key, now, self.rules["global"]["max_requests"], self.rules["global"]["window_seconds"]):
return JSONResponse(
status_code=429,
content={"detail": "请求过于频繁,请稍后再试"}
)
return await call_next(request)
def _is_rate_limited(self, key, now, rule):
window = rule["window_seconds"]
max_req = rule["max_requests"]
# 清理过期记录
self._requests[key] = [t for t in self._requests[key] if now - t < window]
if len(self._requests[key]) >= max_req:
return True
self._requests[key].append(now)
return False

View File

@ -45,6 +45,7 @@ fpdf2>=2.7
# 监控
prometheus-client>=0.19.0
sentry-sdk[fastapi]>=2.0.0
# 文档解析
PyMuPDF>=1.23.0