257 lines
9.7 KiB
Python
257 lines
9.7 KiB
Python
"""Cascade State Store — Persistent cascade detection state with Redis INCR backend.
|
||
|
||
Provides the CascadeStateStore Protocol with InMemoryCascadeStateStore and
RedisCascadeStateStore backends. The Redis backend lets CascadeDetector state
survive restarts and be shared across multiple instances; the in-memory
backend is process-local.

Key schema (Redis):
    agentkit:cascade:interactions:{session_id} → INCR counter with TTL
    agentkit:cascade:depths:{session_id}       → SET counter with TTL
|
||
"""
|
||
|
||
import logging
|
||
import time
|
||
from typing import Any, Protocol, runtime_checkable
|
||
|
||
logger = logging.getLogger(__name__)

# redis is an optional dependency. When it is not installed, alias
# _RedisError to Exception so handlers that catch it keep the original
# catch-all behaviour (any failure degrades to the fallback).
try:
    from redis.exceptions import RedisError as _RedisError
except ImportError:
    _RedisError = Exception
|
||
|
||
|
||
@runtime_checkable
class CascadeStateStore(Protocol):
    """Structural interface shared by every cascade-state backend.

    A backend tracks two integers per session id: an interaction counter and
    a loop depth. Any object providing these five methods satisfies the
    protocol, including under ``isinstance`` checks.
    """

    def increment_interaction(self, session_id: str) -> int:
        """Add one to the session's interaction counter atomically; return the updated value."""
        ...

    def get_interaction(self, session_id: str) -> int:
        """Return the session's interaction counter as currently stored."""
        ...

    def set_depth(self, session_id: str, depth: int) -> None:
        """Store *depth* as the session's loop depth."""
        ...

    def get_depth(self, session_id: str) -> int:
        """Return the session's loop depth as currently stored."""
        ...

    def reset(self, session_id: str) -> None:
        """Discard every counter held for the session."""
        ...
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# InMemoryCascadeStateStore
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class InMemoryCascadeStateStore:
    """In-memory cascade state store (default, process-local).

    For each session it keeps an interaction counter, a loop depth and the
    ``time.monotonic()`` time of the last write. A session whose last write is
    more than ``session_ttl`` seconds old is treated as absent: reads return 0
    and the next write starts it from scratch. Expired sessions are also swept
    out of memory lazily, at most once per ``_SWEEP_INTERVAL`` seconds (or per
    ``session_ttl`` if that is shorter), so memory does not grow unbounded.

    No locking is performed; NOTE(review): guard with a lock if one instance
    is shared between threads.
    """

    DEFAULT_SESSION_TTL = 86400  # 24 hours

    # Max seconds between full sweeps. Sweeping on every increment made each
    # call O(number of sessions). Correctness does not depend on the sweep,
    # because every read/write method first expires the session it touches.
    _SWEEP_INTERVAL = 60.0

    def __init__(self, session_ttl: int = DEFAULT_SESSION_TTL):
        """Create an empty store whose sessions expire after *session_ttl* seconds."""
        self._session_ttl = session_ttl
        self._interaction_counts: dict[str, int] = {}
        self._loop_depths: dict[str, int] = {}
        # session_id -> monotonic time of the session's last write
        self._timestamps: dict[str, float] = {}
        # -inf guarantees that the very first write performs a sweep.
        self._last_sweep = float("-inf")

    def _is_expired(self, session_id: str) -> bool:
        """Return True if the session exists and its last write is older than the TTL."""
        ts = self._timestamps.get(session_id)
        if ts is None:
            return False
        return (time.monotonic() - ts) > self._session_ttl

    def _cleanup_expired(self) -> None:
        """Full sweep: remove every expired session (O(number of sessions))."""
        now = time.monotonic()
        self._last_sweep = now
        expired = [sid for sid, ts in self._timestamps.items() if now - ts > self._session_ttl]
        for sid in expired:
            self.reset(sid)

    def _maybe_cleanup(self) -> None:
        """Run the full sweep only if the sweep interval has elapsed."""
        interval = min(self._SWEEP_INTERVAL, self._session_ttl)
        if time.monotonic() - self._last_sweep >= interval:
            self._cleanup_expired()

    def _expire_if_stale(self, session_id: str) -> None:
        """Drop the session's state if its TTL has elapsed."""
        if self._is_expired(session_id):
            self.reset(session_id)

    def _touch(self, session_id: str) -> None:
        """Record now as the session's last-write time."""
        self._timestamps[session_id] = time.monotonic()

    def increment_interaction(self, session_id: str) -> int:
        """Increment the session's interaction counter and return the new value.

        An expired session restarts at 1.
        """
        self._maybe_cleanup()
        self._expire_if_stale(session_id)
        count = self._interaction_counts.get(session_id, 0) + 1
        self._interaction_counts[session_id] = count
        self._touch(session_id)
        return count

    def get_interaction(self, session_id: str) -> int:
        """Return the interaction counter, or 0 for an unknown/expired session."""
        self._expire_if_stale(session_id)
        return self._interaction_counts.get(session_id, 0)

    def set_depth(self, session_id: str, depth: int) -> None:
        """Store the session's loop depth and refresh its last-write time."""
        self._maybe_cleanup()
        # Expire first: touching a stale session would otherwise revive its
        # old interaction count together with the new depth.
        self._expire_if_stale(session_id)
        self._loop_depths[session_id] = depth
        self._touch(session_id)

    def get_depth(self, session_id: str) -> int:
        """Return the loop depth, or 0 for an unknown/expired session."""
        self._expire_if_stale(session_id)
        return self._loop_depths.get(session_id, 0)

    def reset(self, session_id: str) -> None:
        """Forget all state for the session (no-op if it is unknown)."""
        self._interaction_counts.pop(session_id, None)
        self._loop_depths.pop(session_id, None)
        self._timestamps.pop(session_id, None)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# RedisCascadeStateStore
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class RedisCascadeStateStore:
    """Redis-backed cascade state store using INCR for atomic increments.

    Key schema:
        agentkit:cascade:interactions:{session_id} → INCR counter with TTL
        agentkit:cascade:depths:{session_id}       → SET counter with TTL

    Every write re-arms the TTL (``session_ttl`` seconds) of the key it
    touches. The Redis client is created lazily on first use. If a write
    (increment, set_depth, reset) fails, the store switches permanently to an
    ``InMemoryCascadeStateStore`` fallback. A failed read is logged and
    answered from that fallback when it exists (otherwise 0), but does not
    switch modes by itself.
    """

    INTER_PREFIX = "agentkit:cascade:interactions:"
    DEPTH_PREFIX = "agentkit:cascade:depths:"
    SESSION_TTL = 86400  # 24 hours — sessions rarely last longer

    def __init__(self, redis_url: str = "redis://localhost:6379", session_ttl: int = SESSION_TTL):
        self._redis_url = redis_url
        self._session_ttl = session_ttl
        self._sync_redis: Any = None  # created lazily by _get_sync_redis()
        self._fallback: InMemoryCascadeStateStore | None = None
        self._degraded = False  # True once a write failure switched us to _fallback

    # -- internals ---------------------------------------------------------

    def _get_sync_redis(self):
        """Get or create a persistent sync Redis client (connection pool backed)."""
        if self._sync_redis is None:
            # Optional dependency: an ImportError here is handled by _dispatch.
            import redis as sync_redis

            self._sync_redis = sync_redis.from_url(self._redis_url, decode_responses=True)
        return self._sync_redis

    def _degrade_to_fallback(self) -> None:
        """Switch to the in-memory fallback; only the first call logs a warning."""
        if not self._degraded:
            self._degraded = True
            if self._fallback is None:
                self._fallback = InMemoryCascadeStateStore(session_ttl=self._session_ttl)
            logger.warning("Redis cascade store unreachable, degraded to in-memory")

    def _inter_key(self, session_id: str) -> str:
        return f"{self.INTER_PREFIX}{session_id}"

    def _depth_key(self, session_id: str) -> str:
        return f"{self.DEPTH_PREFIX}{session_id}"

    @staticmethod
    def _to_int(raw: Any) -> int:
        """Convert a GET reply (None for a missing key) to an int, defaulting to 0."""
        return int(raw) if raw is not None else 0

    def _dispatch(self, op, redis_call, fallback_call, default, *, degrade):
        """Run one operation against Redis, falling back to in-memory state.

        Args:
            op: Operation label used in the failure log message.
            redis_call: Callable receiving the Redis client; its result is returned.
            fallback_call: Callable receiving the in-memory fallback store.
            default: Returned when Redis fails and no fallback exists yet.
            degrade: Whether a Redis failure switches the store to in-memory mode.
        """
        if self._degraded and self._fallback is not None:
            return fallback_call(self._fallback)
        try:
            return redis_call(self._get_sync_redis())
        except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError, TypeError) as e:
            logger.warning("Redis cascade %s failed: %s", op, e)
            if degrade:
                self._degrade_to_fallback()
            if self._fallback is not None:
                return fallback_call(self._fallback)
            return default

    # -- CascadeStateStore API --------------------------------------------

    def increment_interaction(self, session_id: str) -> int:
        """INCR the session's counter, refresh its TTL, and return the new count."""
        key = self._inter_key(session_id)

        def incr(r):
            with r.pipeline() as pipe:
                pipe.incr(key)
                pipe.expire(key, self._session_ttl)
                return pipe.execute()[0]  # reply to INCR

        return self._dispatch(
            "increment",
            incr,
            lambda fb: fb.increment_interaction(session_id),
            0,
            degrade=True,
        )

    def get_interaction(self, session_id: str) -> int:
        """Return the session's interaction count (0 if the key is missing)."""
        key = self._inter_key(session_id)
        return self._dispatch(
            "get",
            lambda r: self._to_int(r.get(key)),
            lambda fb: fb.get_interaction(session_id),
            0,
            degrade=False,
        )

    def set_depth(self, session_id: str, depth: int) -> None:
        """SET the session's loop depth and refresh its TTL."""
        key = self._depth_key(session_id)

        def store(r):
            with r.pipeline() as pipe:
                pipe.set(key, depth)
                pipe.expire(key, self._session_ttl)
                pipe.execute()

        self._dispatch(
            "set_depth",
            store,
            lambda fb: fb.set_depth(session_id, depth),
            None,
            degrade=True,
        )

    def get_depth(self, session_id: str) -> int:
        """Return the session's loop depth (0 if the key is missing)."""
        key = self._depth_key(session_id)
        return self._dispatch(
            "get_depth",
            lambda r: self._to_int(r.get(key)),
            lambda fb: fb.get_depth(session_id),
            0,
            degrade=False,
        )

    def reset(self, session_id: str) -> None:
        """Delete both of the session's keys."""

        def delete_keys(r):
            with r.pipeline() as pipe:
                pipe.delete(self._inter_key(session_id))
                pipe.delete(self._depth_key(session_id))
                pipe.execute()

        self._dispatch(
            "reset",
            delete_keys,
            lambda fb: fb.reset(session_id),
            None,
            degrade=True,
        )

    def close(self) -> None:
        """Close the Redis connection pool."""
        if self._sync_redis is not None:
            self._sync_redis.close()
            self._sync_redis = None
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Factory
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def create_cascade_state_store(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
    session_ttl: int = 86400,
) -> CascadeStateStore:
    """Build the cascade state store selected by *backend*.

    Args:
        backend: ``"auto"`` or ``"redis"`` selects the Redis backend when the
            ``redis`` package is importable; any other value selects the
            in-memory backend.
        redis_url: Connection URL handed to the Redis backend.
        session_ttl: Session expiry in seconds, passed to either backend.

    Returns:
        The chosen store. A missing ``redis`` package is logged as a warning
        and yields the in-memory store instead of raising.
    """
    if backend not in ("auto", "redis"):
        return InMemoryCascadeStateStore(session_ttl=session_ttl)

    try:
        import redis  # noqa: F401  (availability probe only)
    except ImportError:
        logger.warning("redis package not available, falling back to in-memory cascade store")
    else:
        return RedisCascadeStateStore(redis_url=redis_url, session_ttl=session_ttl)

    return InMemoryCascadeStateStore(session_ttl=session_ttl)
|