From 6b90fb5cd6a5a5acf87393a24a3f0e3659e8fae0 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 4 Jun 2026 14:21:14 +0800 Subject: [PATCH] refactor: unify Redis connection pool across all modules - Create app/core/redis.py with global get_redis() singleton - Replace 4 independent Redis connections: - cache.py: use get_redis() instead of own aioredis.from_url - dispatcher.py: use get_redis() instead of own connection - health_checker.py: use get_redis() instead of per-check connection - rate_limit.py: RedisRateLimitBackend uses get_redis() instead of own connection - Replace main.py readiness endpoint to use get_redis() - Add close_redis() in FastAPI lifespan shutdown - Remove unused aioredis imports from health_checker.py and main.py --- backend/app/agent_framework/dispatcher.py | 19 ++---- backend/app/core/redis.py | 64 +++++++++++++++++++ backend/app/main.py | 11 ++-- backend/app/middleware/rate_limit.py | 27 ++------ backend/app/services/cache.py | 43 +++++-------- backend/app/services/health_checker.py | 7 +- .../test_infrastructure/test_rate_limit.py | 4 +- 7 files changed, 102 insertions(+), 73 deletions(-) create mode 100644 backend/app/core/redis.py diff --git a/backend/app/agent_framework/dispatcher.py b/backend/app/agent_framework/dispatcher.py index a3bee2f..2b342ac 100644 --- a/backend/app/agent_framework/dispatcher.py +++ b/backend/app/agent_framework/dispatcher.py @@ -34,22 +34,15 @@ class TaskDispatcher: def __init__(self, redis_url: str): self._redis_url = redis_url - self._redis: aioredis.Redis | None = None - async def _get_redis(self) -> aioredis.Redis: - """获取 Redis 连接""" - if self._redis is None: - self._redis = aioredis.from_url( - self._redis_url, - decode_responses=True, - ) - return self._redis + async def _get_redis(self): + """获取全局 Redis 连接""" + from app.core.redis import get_redis + return await get_redis() async def close(self): - """关闭 Redis 连接""" - if self._redis: - await self._redis.close() - self._redis = None + """关闭 Redis 连接(由全局连接池管理,无需手动关闭)""" + pass async def dispatch( self, diff --git a/backend/app/core/redis.py b/backend/app/core/redis.py new file mode 100644 index 0000000..b2c6647 --- /dev/null +++ b/backend/app/core/redis.py @@ -0,0 +1,64 @@ +"""统一 Redis 连接管理。 + +提供全局 Redis 连接池单例,所有模块共享同一连接池, +避免多处独立创建连接导致资源浪费和连接数失控。 + +用法: + from app.core.redis import get_redis + + redis = await get_redis() + await redis.set("key", "value") +""" +import logging + +import redis.asyncio as aioredis + +from app.config import settings + +logger = logging.getLogger(__name__) + +# 全局 Redis 连接池单例 +_redis: aioredis.Redis | None = None + + +async def get_redis() -> aioredis.Redis: + """获取全局 Redis 连接。 + + 首次调用时根据 settings.REDIS_URL 创建连接池, + 后续调用返回同一实例。REDIS_URL 为空时抛出 ValueError。 + """ + global _redis + if _redis is None: + if not settings.REDIS_URL: + raise ValueError("REDIS_URL is not configured") + _redis = aioredis.from_url( + settings.REDIS_URL, + encoding="utf-8", + decode_responses=True, + ) + logger.info("Redis connection pool created (url=%s)", _safe_url(settings.REDIS_URL)) + return _redis + + +async def close_redis() -> None: + """关闭全局 Redis 连接。在应用 shutdown 时调用。""" + global _redis + if _redis is not None: + await _redis.aclose() + _redis = None + logger.info("Redis connection pool closed") + + +def is_redis_configured() -> bool: + """检查 Redis 是否已配置(REDIS_URL 非空)。""" + return bool(settings.REDIS_URL) + + +def _safe_url(url: str) -> str: + """隐藏 Redis URL 中的密码部分。""" + if "@" in url: + # redis://:password@host:port/db → redis://***@host:port/db + parts = url.split("@", 1) + prefix = parts[0].rsplit(":", 1)[0] + return f"{prefix}:***@{parts[1]}" + return url diff --git a/backend/app/main.py b/backend/app/main.py index 3d06603..4540b40 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -91,6 +91,10 @@ async def lifespan(app: FastAPI): yield + # 关闭全局 Redis 连接池 + from app.core.redis import close_redis + await close_redis() + await query_scheduler.shutdown() @@ -227,9 +231,6 @@ async def readiness_check(db: AsyncSession = Depends(get_db)): 供 Kubernetes readinessProbe / Docker healthcheck 使用。 不需要认证。 """ - import redis.asyncio as aioredis # type: ignore - from app.config import settings as _settings - # --- 检查数据库 --- try: await db.execute(text("SELECT 1")) @@ -240,9 +241,9 @@ async def readiness_check(db: AsyncSession = Depends(get_db)): # --- 检查 Redis --- redis_ok = False try: - redis_client = aioredis.from_url(_settings.REDIS_URL, socket_connect_timeout=2) + from app.core.redis import get_redis + redis_client = await get_redis() await redis_client.ping() - await redis_client.aclose() redis_ok = True except Exception: pass diff --git a/backend/app/middleware/rate_limit.py b/backend/app/middleware/rate_limit.py index 3f0018b..6faded7 100644 --- a/backend/app/middleware/rate_limit.py +++ b/backend/app/middleware/rate_limit.py @@ -94,21 +94,12 @@ class RedisRateLimitBackend(RateLimitBackend): 使用 Redis Sorted Set + ZRANGEBYSCORE 实现滑动窗口限流。 Pipeline 保证 ZADD + ZRANGEBYSCORE + ZREMRANGEBYSCORE 的原子性。 + 使用全局统一 Redis 连接池。 """ - def __init__(self, redis_url: str) -> None: - self._redis_url = redis_url - self._redis = None - async def _get_redis(self): - if self._redis is None: - import redis.asyncio as aioredis - self._redis = aioredis.from_url( - self._redis_url, - encoding="utf-8", - decode_responses=True, - ) - return self._redis + from app.core.redis import get_redis + return await get_redis() async def is_rate_limited(self, key: str, now: float, max_requests: int, window_seconds: int) -> bool: try: @@ -142,11 +133,6 @@ class RedisRateLimitBackend(RateLimitBackend): except Exception as exc: logger.warning("Redis rate limit reset failed for key=%s: %s", key, exc) - async def close(self) -> None: - if self._redis is not None: - await self._redis.aclose() - self._redis = None - # --------------------------------------------------------------------------- # 辅助函数 @@ -176,13 +162,12 @@ def _create_backend() -> RateLimitBackend: backend_type = settings.RATE_LIMIT_BACKEND.lower() if backend_type == "redis": - redis_url = settings.REDIS_URL - if not redis_url: + if not settings.REDIS_URL: logger.warning("RATE_LIMIT_BACKEND=redis but REDIS_URL is empty, falling back to memory backend") return MemoryRateLimitBackend() try: - backend = RedisRateLimitBackend(redis_url) - logger.info("Rate limiter using Redis backend (url=%s)", redis_url.split("@")[-1] if "@" in redis_url else redis_url) + backend = RedisRateLimitBackend() + logger.info("Rate limiter using Redis backend (shared connection pool)") return backend except Exception as exc: logger.warning("Failed to create Redis rate limit backend: %s, falling back to memory", exc) diff --git a/backend/app/services/cache.py b/backend/app/services/cache.py index aa6b433..6da6de5 100644 --- a/backend/app/services/cache.py +++ b/backend/app/services/cache.py @@ -8,9 +8,7 @@ import json import logging -import redis.asyncio as aioredis - -from app.config import settings +from app.core.redis import get_redis logger = logging.getLogger(__name__) @@ -21,25 +19,19 @@ TTL_USER_PROFILE = 600 # 10 分钟 class CacheService: - """异步 Redis 缓存服务。""" - - def __init__(self) -> None: - self._redis: aioredis.Redis | None = None + """异步 Redis 缓存服务。使用全局统一 Redis 连接池。""" @property - def redis(self) -> aioredis.Redis: - if self._redis is None: - self._redis = aioredis.from_url( - settings.REDIS_URL, - encoding="utf-8", - decode_responses=True, - ) - return self._redis + def redis(self): + """获取全局 Redis 连接(懒加载)。""" + from app.core.redis import _redis + return _redis async def get(self, key: str) -> str | None: """从缓存读取字符串值,不存在或出错时返回 None。""" try: - return await self.redis.get(key) + redis = await get_redis() + return await redis.get(key) except Exception as exc: logger.warning("Cache GET failed for key=%s: %s", key, exc) return None @@ -57,7 +49,8 @@ class CacheService: async def set(self, key: str, value: str, expire: int = 300) -> None: """写入缓存字符串值,expire 单位为秒。""" try: - await self.redis.set(key, value, ex=expire) + redis = await get_redis() + await redis.set(key, value, ex=expire) except Exception as exc: logger.warning("Cache SET failed for key=%s: %s", key, exc) @@ -71,29 +64,25 @@ class CacheService: async def delete(self, key: str) -> None: """删除指定缓存键。""" try: - await self.redis.delete(key) + redis = await get_redis() + await redis.delete(key) except Exception as exc: logger.warning("Cache DELETE failed for key=%s: %s", key, exc) async def invalidate_pattern(self, pattern: str) -> int: """批量删除匹配 pattern 的所有缓存键,返回删除数量。""" try: - keys = await self.redis.keys(pattern) + redis = await get_redis() + keys = await redis.keys(pattern) if keys: - return await self.redis.delete(*keys) + return await redis.delete(*keys) return 0 except Exception as exc: logger.warning("Cache INVALIDATE_PATTERN failed for pattern=%s: %s", pattern, exc) return 0 - async def close(self) -> None: - """关闭 Redis 连接。""" - if self._redis is not None: - await self._redis.aclose() - self._redis = None - -# 模块级单例(懒加载,应用启动后自动创建连接池) +# 模块级单例 _cache_service: CacheService | None = None diff --git a/backend/app/services/health_checker.py b/backend/app/services/health_checker.py index 673e8b5..63dad39 100644 --- a/backend/app/services/health_checker.py +++ b/backend/app/services/health_checker.py @@ -51,12 +51,9 @@ class HealthChecker: """检查Redis连接""" start = time.perf_counter() try: - redis = aioredis.from_url( - self.redis_url, - socket_connect_timeout=2, - ) + from app.core.redis import get_redis + redis = await get_redis() await redis.ping() - await redis.aclose() latency = (time.perf_counter() - start) * 1000 return HealthCheckResult( diff --git a/backend/tests/test_infrastructure/test_rate_limit.py b/backend/tests/test_infrastructure/test_rate_limit.py index ecb8704..0af2d9d 100644 --- a/backend/tests/test_infrastructure/test_rate_limit.py +++ b/backend/tests/test_infrastructure/test_rate_limit.py @@ -112,7 +112,7 @@ class TestRedisRateLimitBackendFallback: @pytest.mark.asyncio async def test_redis_unavailable_allows_request(self): """Redis 不可用时放行请求(不阻塞服务)。""" - backend = RedisRateLimitBackend("redis://nonexistent:6379/0") + backend = RedisRateLimitBackend() now = time.time() # Redis 连接失败时应返回 False(不被限流) result = await backend.is_rate_limited("test:key", now, 5, 60) @@ -121,7 +121,7 @@ class TestRedisRateLimitBackendFallback: @pytest.mark.asyncio async def test_redis_reset_silently_fails(self): """Redis reset 失败时静默处理。""" - backend = RedisRateLimitBackend("redis://nonexistent:6379/0") + backend = RedisRateLimitBackend() # 不应抛出异常 await backend.reset("test:key")