"""Cascade State Store — Persistent cascade detection state with Redis INCR backend. Provides CascadeStateStore Protocol with InMemoryCascadeStateStore and RedisCascadeStateStore backends. Enables CascadeDetector state to survive restarts and work across multiple instances. Key schema (Redis): agentkit:cascade:interactions:{session_id} → INCR counter with TTL agentkit:cascade:depths:{session_id} → SET counter with TTL """ import logging import time from typing import Any, Protocol, runtime_checkable # redis 可选依赖;未安装时回退为 Exception 以保留原 catch-all 语义(降级到 fallback) try: from redis.exceptions import RedisError as _RedisError except ImportError: _RedisError = Exception logger = logging.getLogger(__name__) @runtime_checkable class CascadeStateStore(Protocol): """Persistent cascade detection state interface.""" def increment_interaction(self, session_id: str) -> int: """Atomically increment interaction count. Returns new count.""" ... def get_interaction(self, session_id: str) -> int: """Get current interaction count.""" ... def set_depth(self, session_id: str, depth: int) -> None: """Set loop depth for a session.""" ... def get_depth(self, session_id: str) -> int: """Get current loop depth.""" ... def reset(self, session_id: str) -> None: """Reset all counters for a session.""" ... # --------------------------------------------------------------------------- # InMemoryCascadeStateStore # --------------------------------------------------------------------------- class InMemoryCascadeStateStore: """In-memory cascade state store (default, process-local). Supports optional session TTL to prevent unbounded memory growth. Expired entries are lazily cleaned up on access. 
""" DEFAULT_SESSION_TTL = 86400 # 24 hours def __init__(self, session_ttl: int = 86400): self._session_ttl = session_ttl self._interaction_counts: dict[str, int] = {} self._loop_depths: dict[str, int] = {} self._timestamps: dict[str, float] = {} def _is_expired(self, session_id: str) -> bool: ts = self._timestamps.get(session_id) if ts is None: return False return (time.monotonic() - ts) > self._session_ttl def _cleanup_expired(self) -> None: """Lazy cleanup: remove expired sessions.""" expired = [sid for sid in self._timestamps if self._is_expired(sid)] for sid in expired: self._interaction_counts.pop(sid, None) self._loop_depths.pop(sid, None) self._timestamps.pop(sid, None) def _touch(self, session_id: str) -> None: self._timestamps[session_id] = time.monotonic() def increment_interaction(self, session_id: str) -> int: self._cleanup_expired() self._interaction_counts[session_id] = self._interaction_counts.get(session_id, 0) + 1 self._touch(session_id) return self._interaction_counts[session_id] def get_interaction(self, session_id: str) -> int: if self._is_expired(session_id): self.reset(session_id) return 0 return self._interaction_counts.get(session_id, 0) def set_depth(self, session_id: str, depth: int) -> None: self._touch(session_id) self._loop_depths[session_id] = depth def get_depth(self, session_id: str) -> int: if self._is_expired(session_id): self.reset(session_id) return 0 return self._loop_depths.get(session_id, 0) def reset(self, session_id: str) -> None: self._interaction_counts.pop(session_id, None) self._loop_depths.pop(session_id, None) self._timestamps.pop(session_id, None) # --------------------------------------------------------------------------- # RedisCascadeStateStore # --------------------------------------------------------------------------- class RedisCascadeStateStore: """Redis-backed cascade state store using INCR for atomic increments. 
Key schema: agentkit:cascade:interactions:{session_id} → INCR counter with TTL agentkit:cascade:depths:{session_id} → SET counter with TTL """ INTER_PREFIX = "agentkit:cascade:interactions:" DEPTH_PREFIX = "agentkit:cascade:depths:" SESSION_TTL = 86400 # 24 hours — sessions rarely last longer def __init__(self, redis_url: str = "redis://localhost:6379", session_ttl: int = 86400): self._redis_url = redis_url self._session_ttl = session_ttl self._sync_redis: Any = None self._fallback: InMemoryCascadeStateStore | None = None self._degraded = False def _get_sync_redis(self): """Get or create a persistent sync Redis client (connection pool backed).""" if self._sync_redis is None: import redis as sync_redis self._sync_redis = sync_redis.from_url( self._redis_url, decode_responses=True ) return self._sync_redis def _degrade_to_fallback(self) -> None: if not self._degraded: self._degraded = True if self._fallback is None: self._fallback = InMemoryCascadeStateStore(session_ttl=self._session_ttl) logger.warning("Redis cascade store unreachable, degraded to in-memory") def increment_interaction(self, session_id: str) -> int: if self._degraded and self._fallback is not None: return self._fallback.increment_interaction(session_id) try: r = self._get_sync_redis() key = f"{self.INTER_PREFIX}{session_id}" pipe = r.pipeline() pipe.incr(key) pipe.expire(key, self._session_ttl) results = pipe.execute() return results[0] except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError, TypeError) as e: logger.warning(f"Redis cascade increment failed: {e}") self._degrade_to_fallback() if self._fallback is not None: return self._fallback.increment_interaction(session_id) return 0 def get_interaction(self, session_id: str) -> int: if self._degraded and self._fallback is not None: return self._fallback.get_interaction(session_id) try: r = self._get_sync_redis() val = r.get(f"{self.INTER_PREFIX}{session_id}") return int(val) if val is not None else 0 except (ImportError, 
OSError, _RedisError, ValueError, KeyError, RuntimeError, TypeError) as e: logger.warning(f"Redis cascade get failed: {e}") if self._fallback is not None: return self._fallback.get_interaction(session_id) return 0 def set_depth(self, session_id: str, depth: int) -> None: if self._degraded and self._fallback is not None: self._fallback.set_depth(session_id, depth) return try: r = self._get_sync_redis() key = f"{self.DEPTH_PREFIX}{session_id}" pipe = r.pipeline() pipe.set(key, depth) pipe.expire(key, self._session_ttl) pipe.execute() except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError, TypeError) as e: logger.warning(f"Redis cascade set_depth failed: {e}") self._degrade_to_fallback() if self._fallback is not None: self._fallback.set_depth(session_id, depth) def get_depth(self, session_id: str) -> int: if self._degraded and self._fallback is not None: return self._fallback.get_depth(session_id) try: r = self._get_sync_redis() val = r.get(f"{self.DEPTH_PREFIX}{session_id}") return int(val) if val is not None else 0 except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError, TypeError) as e: logger.warning(f"Redis cascade get_depth failed: {e}") if self._fallback is not None: return self._fallback.get_depth(session_id) return 0 def reset(self, session_id: str) -> None: if self._degraded and self._fallback is not None: self._fallback.reset(session_id) return try: r = self._get_sync_redis() pipe = r.pipeline() pipe.delete(f"{self.INTER_PREFIX}{session_id}") pipe.delete(f"{self.DEPTH_PREFIX}{session_id}") pipe.execute() except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError, TypeError) as e: logger.warning(f"Redis cascade reset failed: {e}") self._degrade_to_fallback() if self._fallback is not None: self._fallback.reset(session_id) def close(self) -> None: """Close the Redis connection pool.""" if self._sync_redis is not None: self._sync_redis.close() self._sync_redis = None # 
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------


def create_cascade_state_store(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
    session_ttl: int = 86400,
) -> CascadeStateStore:
    """Create a cascade state store backend.

    Args:
        backend: ``"auto"`` or ``"redis"`` request the Redis-backed store.
            Any other value selects the in-memory store.
        redis_url: Connection URL handed to ``RedisCascadeStateStore``.
        session_ttl: Session TTL in seconds, passed to whichever store is built.

    Returns:
        A ``RedisCascadeStateStore`` when Redis was requested and the
        ``redis`` package is importable. Otherwise an
        ``InMemoryCascadeStateStore``; a warning is logged when Redis was
        requested but the package is missing. This function only checks
        that the package imports; it does not probe the server.
    """
    wants_redis = backend in ("auto", "redis")
    if not wants_redis:
        return InMemoryCascadeStateStore(session_ttl=session_ttl)
    try:
        import redis  # noqa: F401  -- availability probe only

        return RedisCascadeStateStore(redis_url=redis_url, session_ttl=session_ttl)
    except ImportError:
        logger.warning("redis package not available, falling back to in-memory cascade store")
    return InMemoryCascadeStateStore(session_ttl=session_ttl)