refactor: unify Redis connection pool across all modules
- Create app/core/redis.py with global get_redis() singleton - Replace 4 independent Redis connections: - cache.py: use get_redis() instead of own aioredis.from_url - dispatcher.py: use get_redis() instead of own connection - health_checker.py: use get_redis() instead of per-check connection - rate_limit.py: RedisRateLimitBackend uses get_redis() instead of own connection - Replace main.py readiness endpoint to use get_redis() - Add close_redis() in FastAPI lifespan shutdown - Remove unused aioredis imports from health_checker.py and main.py
This commit is contained in:
parent
bdf351977b
commit
6b90fb5cd6
|
|
@ -34,22 +34,15 @@ class TaskDispatcher:
|
||||||
|
|
||||||
def __init__(self, redis_url: str):
|
def __init__(self, redis_url: str):
|
||||||
self._redis_url = redis_url
|
self._redis_url = redis_url
|
||||||
self._redis: aioredis.Redis | None = None
|
|
||||||
|
|
||||||
async def _get_redis(self) -> aioredis.Redis:
|
async def _get_redis(self):
|
||||||
"""获取 Redis 连接"""
|
"""获取全局 Redis 连接"""
|
||||||
if self._redis is None:
|
from app.core.redis import get_redis
|
||||||
self._redis = aioredis.from_url(
|
return await get_redis()
|
||||||
self._redis_url,
|
|
||||||
decode_responses=True,
|
|
||||||
)
|
|
||||||
return self._redis
|
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""关闭 Redis 连接"""
|
"""关闭 Redis 连接(由全局连接池管理,无需手动关闭)"""
|
||||||
if self._redis:
|
pass
|
||||||
await self._redis.close()
|
|
||||||
self._redis = None
|
|
||||||
|
|
||||||
async def dispatch(
|
async def dispatch(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
"""统一 Redis 连接管理。
|
||||||
|
|
||||||
|
提供全局 Redis 连接池单例,所有模块共享同一连接池,
|
||||||
|
避免多处独立创建连接导致资源浪费和连接数失控。
|
||||||
|
|
||||||
|
用法:
|
||||||
|
from app.core.redis import get_redis
|
||||||
|
|
||||||
|
redis = await get_redis()
|
||||||
|
await redis.set("key", "value")
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 全局 Redis 连接池单例
|
||||||
|
_redis: aioredis.Redis | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_redis() -> aioredis.Redis:
|
||||||
|
"""获取全局 Redis 连接。
|
||||||
|
|
||||||
|
首次调用时根据 settings.REDIS_URL 创建连接池,
|
||||||
|
后续调用返回同一实例。REDIS_URL 为空时抛出 ValueError。
|
||||||
|
"""
|
||||||
|
global _redis
|
||||||
|
if _redis is None:
|
||||||
|
if not settings.REDIS_URL:
|
||||||
|
raise ValueError("REDIS_URL is not configured")
|
||||||
|
_redis = aioredis.from_url(
|
||||||
|
settings.REDIS_URL,
|
||||||
|
encoding="utf-8",
|
||||||
|
decode_responses=True,
|
||||||
|
)
|
||||||
|
logger.info("Redis connection pool created (url=%s)", _safe_url(settings.REDIS_URL))
|
||||||
|
return _redis
|
||||||
|
|
||||||
|
|
||||||
|
async def close_redis() -> None:
|
||||||
|
"""关闭全局 Redis 连接。在应用 shutdown 时调用。"""
|
||||||
|
global _redis
|
||||||
|
if _redis is not None:
|
||||||
|
await _redis.aclose()
|
||||||
|
_redis = None
|
||||||
|
logger.info("Redis connection pool closed")
|
||||||
|
|
||||||
|
|
||||||
|
def is_redis_configured() -> bool:
|
||||||
|
"""检查 Redis 是否已配置(REDIS_URL 非空)。"""
|
||||||
|
return bool(settings.REDIS_URL)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_url(url: str) -> str:
|
||||||
|
"""隐藏 Redis URL 中的密码部分。"""
|
||||||
|
if "@" in url:
|
||||||
|
# redis://:password@host:port/db → redis://***@host:port/db
|
||||||
|
parts = url.split("@", 1)
|
||||||
|
prefix = parts[0].rsplit(":", 1)[0]
|
||||||
|
return f"{prefix}:***@{parts[1]}"
|
||||||
|
return url
|
||||||
|
|
@ -91,6 +91,10 @@ async def lifespan(app: FastAPI):
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
# 关闭全局 Redis 连接池
|
||||||
|
from app.core.redis import close_redis
|
||||||
|
await close_redis()
|
||||||
|
|
||||||
await query_scheduler.shutdown()
|
await query_scheduler.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -227,9 +231,6 @@ async def readiness_check(db: AsyncSession = Depends(get_db)):
|
||||||
供 Kubernetes readinessProbe / Docker healthcheck 使用。
|
供 Kubernetes readinessProbe / Docker healthcheck 使用。
|
||||||
不需要认证。
|
不需要认证。
|
||||||
"""
|
"""
|
||||||
import redis.asyncio as aioredis # type: ignore
|
|
||||||
from app.config import settings as _settings
|
|
||||||
|
|
||||||
# --- 检查数据库 ---
|
# --- 检查数据库 ---
|
||||||
try:
|
try:
|
||||||
await db.execute(text("SELECT 1"))
|
await db.execute(text("SELECT 1"))
|
||||||
|
|
@ -240,9 +241,9 @@ async def readiness_check(db: AsyncSession = Depends(get_db)):
|
||||||
# --- 检查 Redis ---
|
# --- 检查 Redis ---
|
||||||
redis_ok = False
|
redis_ok = False
|
||||||
try:
|
try:
|
||||||
redis_client = aioredis.from_url(_settings.REDIS_URL, socket_connect_timeout=2)
|
from app.core.redis import get_redis
|
||||||
|
redis_client = await get_redis()
|
||||||
await redis_client.ping()
|
await redis_client.ping()
|
||||||
await redis_client.aclose()
|
|
||||||
redis_ok = True
|
redis_ok = True
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -94,21 +94,12 @@ class RedisRateLimitBackend(RateLimitBackend):
|
||||||
|
|
||||||
使用 Redis Sorted Set + ZRANGEBYSCORE 实现滑动窗口限流。
|
使用 Redis Sorted Set + ZRANGEBYSCORE 实现滑动窗口限流。
|
||||||
Pipeline 保证 ZADD + ZRANGEBYSCORE + ZREMRANGEBYSCORE 的原子性。
|
Pipeline 保证 ZADD + ZRANGEBYSCORE + ZREMRANGEBYSCORE 的原子性。
|
||||||
|
使用全局统一 Redis 连接池。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, redis_url: str) -> None:
|
|
||||||
self._redis_url = redis_url
|
|
||||||
self._redis = None
|
|
||||||
|
|
||||||
async def _get_redis(self):
|
async def _get_redis(self):
|
||||||
if self._redis is None:
|
from app.core.redis import get_redis
|
||||||
import redis.asyncio as aioredis
|
return await get_redis()
|
||||||
self._redis = aioredis.from_url(
|
|
||||||
self._redis_url,
|
|
||||||
encoding="utf-8",
|
|
||||||
decode_responses=True,
|
|
||||||
)
|
|
||||||
return self._redis
|
|
||||||
|
|
||||||
async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
|
async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
|
||||||
try:
|
try:
|
||||||
|
|
@ -142,11 +133,6 @@ class RedisRateLimitBackend(RateLimitBackend):
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Redis rate limit reset failed for key=%s: %s", key, exc)
|
logger.warning("Redis rate limit reset failed for key=%s: %s", key, exc)
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
if self._redis is not None:
|
|
||||||
await self._redis.aclose()
|
|
||||||
self._redis = None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# 辅助函数
|
# 辅助函数
|
||||||
|
|
@ -176,13 +162,12 @@ def _create_backend() -> RateLimitBackend:
|
||||||
backend_type = settings.RATE_LIMIT_BACKEND.lower()
|
backend_type = settings.RATE_LIMIT_BACKEND.lower()
|
||||||
|
|
||||||
if backend_type == "redis":
|
if backend_type == "redis":
|
||||||
redis_url = settings.REDIS_URL
|
if not settings.REDIS_URL:
|
||||||
if not redis_url:
|
|
||||||
logger.warning("RATE_LIMIT_BACKEND=redis but REDIS_URL is empty, falling back to memory backend")
|
logger.warning("RATE_LIMIT_BACKEND=redis but REDIS_URL is empty, falling back to memory backend")
|
||||||
return MemoryRateLimitBackend()
|
return MemoryRateLimitBackend()
|
||||||
try:
|
try:
|
||||||
backend = RedisRateLimitBackend(redis_url)
|
backend = RedisRateLimitBackend()
|
||||||
logger.info("Rate limiter using Redis backend (url=%s)", redis_url.split("@")[-1] if "@" in redis_url else redis_url)
|
logger.info("Rate limiter using Redis backend (shared connection pool)")
|
||||||
return backend
|
return backend
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Failed to create Redis rate limit backend: %s, falling back to memory", exc)
|
logger.warning("Failed to create Redis rate limit backend: %s, falling back to memory", exc)
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import redis.asyncio as aioredis
|
from app.core.redis import get_redis
|
||||||
|
|
||||||
from app.config import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -21,25 +19,19 @@ TTL_USER_PROFILE = 600 # 10 分钟
|
||||||
|
|
||||||
|
|
||||||
class CacheService:
|
class CacheService:
|
||||||
"""异步 Redis 缓存服务。"""
|
"""异步 Redis 缓存服务。使用全局统一 Redis 连接池。"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._redis: aioredis.Redis | None = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def redis(self) -> aioredis.Redis:
|
def redis(self):
|
||||||
if self._redis is None:
|
"""获取全局 Redis 连接(懒加载)。"""
|
||||||
self._redis = aioredis.from_url(
|
from app.core.redis import _redis
|
||||||
settings.REDIS_URL,
|
return _redis
|
||||||
encoding="utf-8",
|
|
||||||
decode_responses=True,
|
|
||||||
)
|
|
||||||
return self._redis
|
|
||||||
|
|
||||||
async def get(self, key: str) -> str | None:
|
async def get(self, key: str) -> str | None:
|
||||||
"""从缓存读取字符串值,不存在或出错时返回 None。"""
|
"""从缓存读取字符串值,不存在或出错时返回 None。"""
|
||||||
try:
|
try:
|
||||||
return await self.redis.get(key)
|
redis = await get_redis()
|
||||||
|
return await redis.get(key)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Cache GET failed for key=%s: %s", key, exc)
|
logger.warning("Cache GET failed for key=%s: %s", key, exc)
|
||||||
return None
|
return None
|
||||||
|
|
@ -57,7 +49,8 @@ class CacheService:
|
||||||
async def set(self, key: str, value: str, expire: int = 300) -> None:
|
async def set(self, key: str, value: str, expire: int = 300) -> None:
|
||||||
"""写入缓存字符串值,expire 单位为秒。"""
|
"""写入缓存字符串值,expire 单位为秒。"""
|
||||||
try:
|
try:
|
||||||
await self.redis.set(key, value, ex=expire)
|
redis = await get_redis()
|
||||||
|
await redis.set(key, value, ex=expire)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Cache SET failed for key=%s: %s", key, exc)
|
logger.warning("Cache SET failed for key=%s: %s", key, exc)
|
||||||
|
|
||||||
|
|
@ -71,29 +64,25 @@ class CacheService:
|
||||||
async def delete(self, key: str) -> None:
|
async def delete(self, key: str) -> None:
|
||||||
"""删除指定缓存键。"""
|
"""删除指定缓存键。"""
|
||||||
try:
|
try:
|
||||||
await self.redis.delete(key)
|
redis = await get_redis()
|
||||||
|
await redis.delete(key)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Cache DELETE failed for key=%s: %s", key, exc)
|
logger.warning("Cache DELETE failed for key=%s: %s", key, exc)
|
||||||
|
|
||||||
async def invalidate_pattern(self, pattern: str) -> int:
|
async def invalidate_pattern(self, pattern: str) -> int:
|
||||||
"""批量删除匹配 pattern 的所有缓存键,返回删除数量。"""
|
"""批量删除匹配 pattern 的所有缓存键,返回删除数量。"""
|
||||||
try:
|
try:
|
||||||
keys = await self.redis.keys(pattern)
|
redis = await get_redis()
|
||||||
|
keys = await redis.keys(pattern)
|
||||||
if keys:
|
if keys:
|
||||||
return await self.redis.delete(*keys)
|
return await redis.delete(*keys)
|
||||||
return 0
|
return 0
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Cache INVALIDATE_PATTERN failed for pattern=%s: %s", pattern, exc)
|
logger.warning("Cache INVALIDATE_PATTERN failed for pattern=%s: %s", pattern, exc)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""关闭 Redis 连接。"""
|
|
||||||
if self._redis is not None:
|
|
||||||
await self._redis.aclose()
|
|
||||||
self._redis = None
|
|
||||||
|
|
||||||
|
# 模块级单例
|
||||||
# 模块级单例(懒加载,应用启动后自动创建连接池)
|
|
||||||
_cache_service: CacheService | None = None
|
_cache_service: CacheService | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,12 +51,9 @@ class HealthChecker:
|
||||||
"""检查Redis连接"""
|
"""检查Redis连接"""
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
redis = aioredis.from_url(
|
from app.core.redis import get_redis
|
||||||
self.redis_url,
|
redis = await get_redis()
|
||||||
socket_connect_timeout=2,
|
|
||||||
)
|
|
||||||
await redis.ping()
|
await redis.ping()
|
||||||
await redis.aclose()
|
|
||||||
|
|
||||||
latency = (time.perf_counter() - start) * 1000
|
latency = (time.perf_counter() - start) * 1000
|
||||||
return HealthCheckResult(
|
return HealthCheckResult(
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ class TestRedisRateLimitBackendFallback:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redis_unavailable_allows_request(self):
|
async def test_redis_unavailable_allows_request(self):
|
||||||
"""Redis 不可用时放行请求(不阻塞服务)。"""
|
"""Redis 不可用时放行请求(不阻塞服务)。"""
|
||||||
backend = RedisRateLimitBackend("redis://nonexistent:6379/0")
|
backend = RedisRateLimitBackend()
|
||||||
now = time.time()
|
now = time.time()
|
||||||
# Redis 连接失败时应返回 False(不被限流)
|
# Redis 连接失败时应返回 False(不被限流)
|
||||||
result = await backend.is_rate_limited("test:key", now, 5, 60)
|
result = await backend.is_rate_limited("test:key", now, 5, 60)
|
||||||
|
|
@ -121,7 +121,7 @@ class TestRedisRateLimitBackendFallback:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redis_reset_silently_fails(self):
|
async def test_redis_reset_silently_fails(self):
|
||||||
"""Redis reset 失败时静默处理。"""
|
"""Redis reset 失败时静默处理。"""
|
||||||
backend = RedisRateLimitBackend("redis://nonexistent:6379/0")
|
backend = RedisRateLimitBackend()
|
||||||
# 不应抛出异常
|
# 不应抛出异常
|
||||||
await backend.reset("test:key")
|
await backend.reset("test:key")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue