refactor: unify Redis connection pool across all modules

- Create app/core/redis.py with global get_redis() singleton
- Replace 4 independent Redis connections:
  - cache.py: use get_redis() instead of own aioredis.from_url
  - dispatcher.py: use get_redis() instead of own connection
  - health_checker.py: use get_redis() instead of per-check connection
  - rate_limit.py: RedisRateLimitBackend uses get_redis() instead of own connection
- Replace main.py readiness endpoint to use get_redis()
- Add close_redis() in FastAPI lifespan shutdown
- Remove unused aioredis imports from health_checker.py and main.py
This commit is contained in:
chiguyong 2026-06-04 14:21:14 +08:00
parent bdf351977b
commit 6b90fb5cd6
7 changed files with 102 additions and 73 deletions

View File

@ -34,22 +34,15 @@ class TaskDispatcher:
def __init__(self, redis_url: str):
self._redis_url = redis_url
self._redis: aioredis.Redis | None = None
async def _get_redis(self) -> aioredis.Redis:
"""获取 Redis 连接"""
if self._redis is None:
self._redis = aioredis.from_url(
self._redis_url,
decode_responses=True,
)
return self._redis
async def _get_redis(self):
"""获取全局 Redis 连接"""
from app.core.redis import get_redis
return await get_redis()
async def close(self):
"""关闭 Redis 连接"""
if self._redis:
await self._redis.close()
self._redis = None
"""关闭 Redis 连接(由全局连接池管理,无需手动关闭)"""
pass
async def dispatch(
self,

64
backend/app/core/redis.py Normal file
View File

@ -0,0 +1,64 @@
"""统一 Redis 连接管理。
提供全局 Redis 连接池单例所有模块共享同一连接池
避免多处独立创建连接导致资源浪费和连接数失控
用法
from app.core.redis import get_redis
redis = await get_redis()
await redis.set("key", "value")
"""
import logging
import redis.asyncio as aioredis
from app.config import settings
logger = logging.getLogger(__name__)
# 全局 Redis 连接池单例
_redis: aioredis.Redis | None = None
async def get_redis() -> aioredis.Redis:
"""获取全局 Redis 连接。
首次调用时根据 settings.REDIS_URL 创建连接池
后续调用返回同一实例REDIS_URL 为空时抛出 ValueError
"""
global _redis
if _redis is None:
if not settings.REDIS_URL:
raise ValueError("REDIS_URL is not configured")
_redis = aioredis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
logger.info("Redis connection pool created (url=%s)", _safe_url(settings.REDIS_URL))
return _redis
async def close_redis() -> None:
"""关闭全局 Redis 连接。在应用 shutdown 时调用。"""
global _redis
if _redis is not None:
await _redis.aclose()
_redis = None
logger.info("Redis connection pool closed")
def is_redis_configured() -> bool:
"""检查 Redis 是否已配置REDIS_URL 非空)。"""
return bool(settings.REDIS_URL)
def _safe_url(url: str) -> str:
"""隐藏 Redis URL 中的密码部分。"""
if "@" in url:
# redis://:password@host:port/db → redis://***@host:port/db
parts = url.split("@", 1)
prefix = parts[0].rsplit(":", 1)[0]
return f"{prefix}:***@{parts[1]}"
return url

View File

@ -91,6 +91,10 @@ async def lifespan(app: FastAPI):
yield
# 关闭全局 Redis 连接池
from app.core.redis import close_redis
await close_redis()
await query_scheduler.shutdown()
@ -227,9 +231,6 @@ async def readiness_check(db: AsyncSession = Depends(get_db)):
Kubernetes readinessProbe / Docker healthcheck 使用
不需要认证
"""
import redis.asyncio as aioredis # type: ignore
from app.config import settings as _settings
# --- 检查数据库 ---
try:
await db.execute(text("SELECT 1"))
@ -240,9 +241,9 @@ async def readiness_check(db: AsyncSession = Depends(get_db)):
# --- 检查 Redis ---
redis_ok = False
try:
redis_client = aioredis.from_url(_settings.REDIS_URL, socket_connect_timeout=2)
from app.core.redis import get_redis
redis_client = await get_redis()
await redis_client.ping()
await redis_client.aclose()
redis_ok = True
except Exception:
pass

View File

@ -94,21 +94,12 @@ class RedisRateLimitBackend(RateLimitBackend):
使用 Redis Sorted Set + ZRANGEBYSCORE 实现滑动窗口限流
Pipeline 保证 ZADD + ZRANGEBYSCORE + ZREMRANGEBYSCORE 的原子性
使用全局统一 Redis 连接池
"""
def __init__(self, redis_url: str) -> None:
self._redis_url = redis_url
self._redis = None
async def _get_redis(self):
if self._redis is None:
import redis.asyncio as aioredis
self._redis = aioredis.from_url(
self._redis_url,
encoding="utf-8",
decode_responses=True,
)
return self._redis
from app.core.redis import get_redis
return await get_redis()
async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool:
try:
@ -142,11 +133,6 @@ class RedisRateLimitBackend(RateLimitBackend):
except Exception as exc:
logger.warning("Redis rate limit reset failed for key=%s: %s", key, exc)
async def close(self) -> None:
if self._redis is not None:
await self._redis.aclose()
self._redis = None
# ---------------------------------------------------------------------------
# 辅助函数
@ -176,13 +162,12 @@ def _create_backend() -> RateLimitBackend:
backend_type = settings.RATE_LIMIT_BACKEND.lower()
if backend_type == "redis":
redis_url = settings.REDIS_URL
if not redis_url:
if not settings.REDIS_URL:
logger.warning("RATE_LIMIT_BACKEND=redis but REDIS_URL is empty, falling back to memory backend")
return MemoryRateLimitBackend()
try:
backend = RedisRateLimitBackend(redis_url)
logger.info("Rate limiter using Redis backend (url=%s)", redis_url.split("@")[-1] if "@" in redis_url else redis_url)
backend = RedisRateLimitBackend()
logger.info("Rate limiter using Redis backend (shared connection pool)")
return backend
except Exception as exc:
logger.warning("Failed to create Redis rate limit backend: %s, falling back to memory", exc)

View File

@ -8,9 +8,7 @@
import json
import logging
import redis.asyncio as aioredis
from app.config import settings
from app.core.redis import get_redis
logger = logging.getLogger(__name__)
@ -21,25 +19,19 @@ TTL_USER_PROFILE = 600 # 10 分钟
class CacheService:
"""异步 Redis 缓存服务。"""
def __init__(self) -> None:
self._redis: aioredis.Redis | None = None
"""异步 Redis 缓存服务。使用全局统一 Redis 连接池。"""
@property
def redis(self) -> aioredis.Redis:
if self._redis is None:
self._redis = aioredis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=True,
)
return self._redis
def redis(self):
"""获取全局 Redis 连接(懒加载)。"""
from app.core.redis import _redis
return _redis
async def get(self, key: str) -> str | None:
"""从缓存读取字符串值,不存在或出错时返回 None。"""
try:
return await self.redis.get(key)
redis = await get_redis()
return await redis.get(key)
except Exception as exc:
logger.warning("Cache GET failed for key=%s: %s", key, exc)
return None
@ -57,7 +49,8 @@ class CacheService:
async def set(self, key: str, value: str, expire: int = 300) -> None:
"""写入缓存字符串值expire 单位为秒。"""
try:
await self.redis.set(key, value, ex=expire)
redis = await get_redis()
await redis.set(key, value, ex=expire)
except Exception as exc:
logger.warning("Cache SET failed for key=%s: %s", key, exc)
@ -71,29 +64,25 @@ class CacheService:
async def delete(self, key: str) -> None:
"""删除指定缓存键。"""
try:
await self.redis.delete(key)
redis = await get_redis()
await redis.delete(key)
except Exception as exc:
logger.warning("Cache DELETE failed for key=%s: %s", key, exc)
async def invalidate_pattern(self, pattern: str) -> int:
"""批量删除匹配 pattern 的所有缓存键,返回删除数量。"""
try:
keys = await self.redis.keys(pattern)
redis = await get_redis()
keys = await redis.keys(pattern)
if keys:
return await self.redis.delete(*keys)
return await redis.delete(*keys)
return 0
except Exception as exc:
logger.warning("Cache INVALIDATE_PATTERN failed for pattern=%s: %s", pattern, exc)
return 0
async def close(self) -> None:
"""关闭 Redis 连接。"""
if self._redis is not None:
await self._redis.aclose()
self._redis = None
# 模块级单例(懒加载,应用启动后自动创建连接池)
# 模块级单例
_cache_service: CacheService | None = None

View File

@ -51,12 +51,9 @@ class HealthChecker:
"""检查Redis连接"""
start = time.perf_counter()
try:
redis = aioredis.from_url(
self.redis_url,
socket_connect_timeout=2,
)
from app.core.redis import get_redis
redis = await get_redis()
await redis.ping()
await redis.aclose()
latency = (time.perf_counter() - start) * 1000
return HealthCheckResult(

View File

@ -112,7 +112,7 @@ class TestRedisRateLimitBackendFallback:
@pytest.mark.asyncio
async def test_redis_unavailable_allows_request(self):
"""Redis 不可用时放行请求(不阻塞服务)。"""
backend = RedisRateLimitBackend("redis://nonexistent:6379/0")
backend = RedisRateLimitBackend()
now = time.time()
# Redis 连接失败时应返回 False不被限流
result = await backend.is_rate_limited("test:key", now, 5, 60)
@ -121,7 +121,7 @@ class TestRedisRateLimitBackendFallback:
@pytest.mark.asyncio
async def test_redis_reset_silently_fails(self):
"""Redis reset 失败时静默处理。"""
backend = RedisRateLimitBackend("redis://nonexistent:6379/0")
backend = RedisRateLimitBackend()
# 不应抛出异常
await backend.reset("test:key")