refactor: unify Redis connection pool across all modules
- Create app/core/redis.py with global get_redis() singleton - Replace 4 independent Redis connections: - cache.py: use get_redis() instead of own aioredis.from_url - dispatcher.py: use get_redis() instead of own connection - health_checker.py: use get_redis() instead of per-check connection - rate_limit.py: RedisRateLimitBackend uses get_redis() instead of own connection - Replace main.py readiness endpoint to use get_redis() - Add close_redis() in FastAPI lifespan shutdown - Remove unused aioredis imports from health_checker.py and main.py
This commit is contained in:
parent
bdf351977b
commit
6b90fb5cd6
|
|
@ -34,22 +34,15 @@ class TaskDispatcher:
|
|||
|
||||
def __init__(self, redis_url: str):
|
||||
self._redis_url = redis_url
|
||||
self._redis: aioredis.Redis | None = None
|
||||
|
||||
async def _get_redis(self) -> aioredis.Redis:
|
||||
"""获取 Redis 连接"""
|
||||
if self._redis is None:
|
||||
self._redis = aioredis.from_url(
|
||||
self._redis_url,
|
||||
decode_responses=True,
|
||||
)
|
||||
return self._redis
|
||||
async def _get_redis(self):
|
||||
"""获取全局 Redis 连接"""
|
||||
from app.core.redis import get_redis
|
||||
return await get_redis()
|
||||
|
||||
async def close(self):
|
||||
"""关闭 Redis 连接"""
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
"""关闭 Redis 连接(由全局连接池管理,无需手动关闭)"""
|
||||
pass
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
"""统一 Redis 连接管理。
|
||||
|
||||
提供全局 Redis 连接池单例,所有模块共享同一连接池,
|
||||
避免多处独立创建连接导致资源浪费和连接数失控。
|
||||
|
||||
用法:
|
||||
from app.core.redis import get_redis
|
||||
|
||||
redis = await get_redis()
|
||||
await redis.set("key", "value")
|
||||
"""
|
||||
import logging
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from app.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局 Redis 连接池单例
|
||||
_redis: aioredis.Redis | None = None
|
||||
|
||||
|
||||
async def get_redis() -> aioredis.Redis:
|
||||
"""获取全局 Redis 连接。
|
||||
|
||||
首次调用时根据 settings.REDIS_URL 创建连接池,
|
||||
后续调用返回同一实例。REDIS_URL 为空时抛出 ValueError。
|
||||
"""
|
||||
global _redis
|
||||
if _redis is None:
|
||||
if not settings.REDIS_URL:
|
||||
raise ValueError("REDIS_URL is not configured")
|
||||
_redis = aioredis.from_url(
|
||||
settings.REDIS_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
logger.info("Redis connection pool created (url=%s)", _safe_url(settings.REDIS_URL))
|
||||
return _redis
|
||||
|
||||
|
||||
async def close_redis() -> None:
|
||||
"""关闭全局 Redis 连接。在应用 shutdown 时调用。"""
|
||||
global _redis
|
||||
if _redis is not None:
|
||||
await _redis.aclose()
|
||||
_redis = None
|
||||
logger.info("Redis connection pool closed")
|
||||
|
||||
|
||||
def is_redis_configured() -> bool:
|
||||
"""检查 Redis 是否已配置(REDIS_URL 非空)。"""
|
||||
return bool(settings.REDIS_URL)
|
||||
|
||||
|
||||
def _safe_url(url: str) -> str:
|
||||
"""隐藏 Redis URL 中的密码部分。"""
|
||||
if "@" in url:
|
||||
# redis://:password@host:port/db → redis://***@host:port/db
|
||||
parts = url.split("@", 1)
|
||||
prefix = parts[0].rsplit(":", 1)[0]
|
||||
return f"{prefix}:***@{parts[1]}"
|
||||
return url
|
||||
|
|
@ -91,6 +91,10 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
yield
|
||||
|
||||
# 关闭全局 Redis 连接池
|
||||
from app.core.redis import close_redis
|
||||
await close_redis()
|
||||
|
||||
await query_scheduler.shutdown()
|
||||
|
||||
|
||||
|
|
@ -227,9 +231,6 @@ async def readiness_check(db: AsyncSession = Depends(get_db)):
|
|||
供 Kubernetes readinessProbe / Docker healthcheck 使用。
|
||||
不需要认证。
|
||||
"""
|
||||
import redis.asyncio as aioredis # type: ignore
|
||||
from app.config import settings as _settings
|
||||
|
||||
# --- 检查数据库 ---
|
||||
try:
|
||||
await db.execute(text("SELECT 1"))
|
||||
|
|
@ -240,9 +241,9 @@ async def readiness_check(db: AsyncSession = Depends(get_db)):
|
|||
# --- 检查 Redis ---
|
||||
redis_ok = False
|
||||
try:
|
||||
redis_client = aioredis.from_url(_settings.REDIS_URL, socket_connect_timeout=2)
|
||||
from app.core.redis import get_redis
|
||||
redis_client = await get_redis()
|
||||
await redis_client.ping()
|
||||
await redis_client.aclose()
|
||||
redis_ok = True
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -94,21 +94,12 @@ class RedisRateLimitBackend(RateLimitBackend):
|
|||
|
||||
使用 Redis Sorted Set + ZRANGEBYSCORE 实现滑动窗口限流。
|
||||
Pipeline 保证 ZADD + ZRANGEBYSCORE + ZREMRANGEBYSCORE 的原子性。
|
||||
使用全局统一 Redis 连接池。
|
||||
"""
|
||||
|
||||
def __init__(self, redis_url: str) -> None:
|
||||
self._redis_url = redis_url
|
||||
self._redis = None
|
||||
|
||||
async def _get_redis(self):
|
||||
if self._redis is None:
|
||||
import redis.asyncio as aioredis
|
||||
self._redis = aioredis.from_url(
|
||||
self._redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
return self._redis
|
||||
from app.core.redis import get_redis
|
||||
return await get_redis()
|
||||
|
||||
async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
|
||||
try:
|
||||
|
|
@ -142,11 +133,6 @@ class RedisRateLimitBackend(RateLimitBackend):
|
|||
except Exception as exc:
|
||||
logger.warning("Redis rate limit reset failed for key=%s: %s", key, exc)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._redis is not None:
|
||||
await self._redis.aclose()
|
||||
self._redis = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
|
|
@ -176,13 +162,12 @@ def _create_backend() -> RateLimitBackend:
|
|||
backend_type = settings.RATE_LIMIT_BACKEND.lower()
|
||||
|
||||
if backend_type == "redis":
|
||||
redis_url = settings.REDIS_URL
|
||||
if not redis_url:
|
||||
if not settings.REDIS_URL:
|
||||
logger.warning("RATE_LIMIT_BACKEND=redis but REDIS_URL is empty, falling back to memory backend")
|
||||
return MemoryRateLimitBackend()
|
||||
try:
|
||||
backend = RedisRateLimitBackend(redis_url)
|
||||
logger.info("Rate limiter using Redis backend (url=%s)", redis_url.split("@")[-1] if "@" in redis_url else redis_url)
|
||||
backend = RedisRateLimitBackend()
|
||||
logger.info("Rate limiter using Redis backend (shared connection pool)")
|
||||
return backend
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to create Redis rate limit backend: %s, falling back to memory", exc)
|
||||
|
|
|
|||
|
|
@ -8,9 +8,7 @@
|
|||
import json
|
||||
import logging
|
||||
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from app.config import settings
|
||||
from app.core.redis import get_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -21,25 +19,19 @@ TTL_USER_PROFILE = 600 # 10 分钟
|
|||
|
||||
|
||||
class CacheService:
|
||||
"""异步 Redis 缓存服务。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._redis: aioredis.Redis | None = None
|
||||
"""异步 Redis 缓存服务。使用全局统一 Redis 连接池。"""
|
||||
|
||||
@property
|
||||
def redis(self) -> aioredis.Redis:
|
||||
if self._redis is None:
|
||||
self._redis = aioredis.from_url(
|
||||
settings.REDIS_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
return self._redis
|
||||
def redis(self):
|
||||
"""获取全局 Redis 连接(懒加载)。"""
|
||||
from app.core.redis import _redis
|
||||
return _redis
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
"""从缓存读取字符串值,不存在或出错时返回 None。"""
|
||||
try:
|
||||
return await self.redis.get(key)
|
||||
redis = await get_redis()
|
||||
return await redis.get(key)
|
||||
except Exception as exc:
|
||||
logger.warning("Cache GET failed for key=%s: %s", key, exc)
|
||||
return None
|
||||
|
|
@ -57,7 +49,8 @@ class CacheService:
|
|||
async def set(self, key: str, value: str, expire: int = 300) -> None:
|
||||
"""写入缓存字符串值,expire 单位为秒。"""
|
||||
try:
|
||||
await self.redis.set(key, value, ex=expire)
|
||||
redis = await get_redis()
|
||||
await redis.set(key, value, ex=expire)
|
||||
except Exception as exc:
|
||||
logger.warning("Cache SET failed for key=%s: %s", key, exc)
|
||||
|
||||
|
|
@ -71,29 +64,25 @@ class CacheService:
|
|||
async def delete(self, key: str) -> None:
|
||||
"""删除指定缓存键。"""
|
||||
try:
|
||||
await self.redis.delete(key)
|
||||
redis = await get_redis()
|
||||
await redis.delete(key)
|
||||
except Exception as exc:
|
||||
logger.warning("Cache DELETE failed for key=%s: %s", key, exc)
|
||||
|
||||
async def invalidate_pattern(self, pattern: str) -> int:
|
||||
"""批量删除匹配 pattern 的所有缓存键,返回删除数量。"""
|
||||
try:
|
||||
keys = await self.redis.keys(pattern)
|
||||
redis = await get_redis()
|
||||
keys = await redis.keys(pattern)
|
||||
if keys:
|
||||
return await self.redis.delete(*keys)
|
||||
return await redis.delete(*keys)
|
||||
return 0
|
||||
except Exception as exc:
|
||||
logger.warning("Cache INVALIDATE_PATTERN failed for pattern=%s: %s", pattern, exc)
|
||||
return 0
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 Redis 连接。"""
|
||||
if self._redis is not None:
|
||||
await self._redis.aclose()
|
||||
self._redis = None
|
||||
|
||||
|
||||
# 模块级单例(懒加载,应用启动后自动创建连接池)
|
||||
# 模块级单例
|
||||
_cache_service: CacheService | None = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -51,12 +51,9 @@ class HealthChecker:
|
|||
"""检查Redis连接"""
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
redis = aioredis.from_url(
|
||||
self.redis_url,
|
||||
socket_connect_timeout=2,
|
||||
)
|
||||
from app.core.redis import get_redis
|
||||
redis = await get_redis()
|
||||
await redis.ping()
|
||||
await redis.aclose()
|
||||
|
||||
latency = (time.perf_counter() - start) * 1000
|
||||
return HealthCheckResult(
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ class TestRedisRateLimitBackendFallback:
|
|||
@pytest.mark.asyncio
|
||||
async def test_redis_unavailable_allows_request(self):
|
||||
"""Redis 不可用时放行请求(不阻塞服务)。"""
|
||||
backend = RedisRateLimitBackend("redis://nonexistent:6379/0")
|
||||
backend = RedisRateLimitBackend()
|
||||
now = time.time()
|
||||
# Redis 连接失败时应返回 False(不被限流)
|
||||
result = await backend.is_rate_limited("test:key", now, 5, 60)
|
||||
|
|
@ -121,7 +121,7 @@ class TestRedisRateLimitBackendFallback:
|
|||
@pytest.mark.asyncio
|
||||
async def test_redis_reset_silently_fails(self):
|
||||
"""Redis reset 失败时静默处理。"""
|
||||
backend = RedisRateLimitBackend("redis://nonexistent:6379/0")
|
||||
backend = RedisRateLimitBackend()
|
||||
# 不应抛出异常
|
||||
await backend.reset("test:key")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue