fischer-agentkit/src/agentkit/quality/cascade_state_store.py

257 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Cascade State Store — Persistent cascade detection state with Redis INCR backend.
Provides CascadeStateStore Protocol with InMemoryCascadeStateStore and
RedisCascadeStateStore backends. Enables CascadeDetector state to survive
restarts and work across multiple instances.
Key schema (Redis):
agentkit:cascade:interactions:{session_id} → INCR counter with TTL
agentkit:cascade:depths:{session_id} → SET counter with TTL
"""
import logging
import time
from typing import Any, Protocol, runtime_checkable
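# redis is an optional dependency; when it is not installed, fall back to Exception so the original catch-all semantics (degrade to the fallback store) are preserved.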
try:
from redis.exceptions import RedisError as _RedisError
except ImportError:
_RedisError = Exception
logger = logging.getLogger(__name__)
@runtime_checkable
class CascadeStateStore(Protocol):
    """Structural interface for per-session cascade detection counters.

    Any object exposing the five methods below satisfies this protocol;
    because it is runtime-checkable, ``isinstance`` can be used to verify
    that the methods are present.
    """

    def increment_interaction(self, session_id: str) -> int:
        """Atomically add one to the session's interaction count.

        Returns:
            The count after the increment.
        """

    def get_interaction(self, session_id: str) -> int:
        """Return the session's current interaction count."""

    def set_depth(self, session_id: str, depth: int) -> None:
        """Record ``depth`` as the loop depth of the session."""

    def get_depth(self, session_id: str) -> int:
        """Return the session's current loop depth."""

    def reset(self, session_id: str) -> None:
        """Discard every counter kept for the session."""
# ---------------------------------------------------------------------------
# InMemoryCascadeStateStore
# ---------------------------------------------------------------------------
class InMemoryCascadeStateStore:
    """In-memory cascade state store (default, process-local).

    Interaction counts and loop depths live in plain dictionaries keyed by
    session id. A single last-write timestamp per session (taken from
    ``time.monotonic``) drives an optional TTL that prevents unbounded
    memory growth. Expired entries are cleaned up lazily: every
    ``increment_interaction`` sweeps all sessions, while the other accessors
    only drop the one session they were asked about.

    Only writes (``increment_interaction`` and ``set_depth``) refresh the
    timestamp; reads never extend a session's lifetime.
    """

    DEFAULT_SESSION_TTL = 86400  # 24 hours

    def __init__(self, session_ttl: int = DEFAULT_SESSION_TTL):
        """Create an empty store.

        Args:
            session_ttl: Seconds a session may go without a write before its
                counters are considered expired.
        """
        self._session_ttl = session_ttl
        self._interaction_counts: dict[str, int] = {}
        self._loop_depths: dict[str, int] = {}
        self._timestamps: dict[str, float] = {}

    def _is_expired(self, session_id: str) -> bool:
        """Return True when the session's last write is older than the TTL.

        A session that has never been written has no timestamp and is
        therefore reported as not expired.
        """
        ts = self._timestamps.get(session_id)
        if ts is None:
            return False
        return (time.monotonic() - ts) > self._session_ttl

    def _cleanup_expired(self) -> None:
        """Lazy cleanup: remove every session whose TTL has elapsed."""
        # Collect first so the dictionaries are not mutated while iterating.
        expired = [sid for sid in self._timestamps if self._is_expired(sid)]
        for sid in expired:
            self.reset(sid)

    def _touch(self, session_id: str) -> None:
        """Stamp the session with the current monotonic time."""
        self._timestamps[session_id] = time.monotonic()

    def increment_interaction(self, session_id: str) -> int:
        """Add one to the session's interaction count and return the new count.

        Expired sessions (including this one) are swept beforehand, so an
        expired session restarts at 1.
        """
        self._cleanup_expired()
        count = self._interaction_counts.get(session_id, 0) + 1
        self._interaction_counts[session_id] = count
        self._touch(session_id)
        return count

    def get_interaction(self, session_id: str) -> int:
        """Return the session's interaction count, or 0 if unknown or expired."""
        if self._is_expired(session_id):
            self.reset(session_id)
            return 0
        return self._interaction_counts.get(session_id, 0)

    def set_depth(self, session_id: str, depth: int) -> None:
        """Store ``depth`` as the session's loop depth and refresh its TTL."""
        if self._is_expired(session_id):
            # Drop stale counters before refreshing the timestamp; otherwise
            # the touch below would revive an interaction count whose TTL
            # has already elapsed.
            self.reset(session_id)
        self._touch(session_id)
        self._loop_depths[session_id] = depth

    def get_depth(self, session_id: str) -> int:
        """Return the session's loop depth, or 0 if unknown or expired."""
        if self._is_expired(session_id):
            self.reset(session_id)
            return 0
        return self._loop_depths.get(session_id, 0)

    def reset(self, session_id: str) -> None:
        """Forget the session's counters and timestamp (no-op if unknown)."""
        self._interaction_counts.pop(session_id, None)
        self._loop_depths.pop(session_id, None)
        self._timestamps.pop(session_id, None)
# ---------------------------------------------------------------------------
# RedisCascadeStateStore
# ---------------------------------------------------------------------------
class RedisCascadeStateStore:
    """Cascade state store backed by Redis, with an in-memory safety net.

    Interaction counts are bumped with INCR and loop depths are written with
    SET; every write also refreshes the key's TTL.

    Key schema:
        agentkit:cascade:interactions:{session_id} → INCR counter with TTL
        agentkit:cascade:depths:{session_id} → SET counter with TTL

    When a write (increment, set_depth, reset) fails, the store switches to
    an ``InMemoryCascadeStateStore`` and keeps using it from then on. Failed
    reads only log a warning and do not trigger that switch.
    """

    INTER_PREFIX = "agentkit:cascade:interactions:"
    DEPTH_PREFIX = "agentkit:cascade:depths:"
    SESSION_TTL = 86400  # 24 hours — sessions rarely last longer

    def __init__(self, redis_url: str = "redis://localhost:6379", session_ttl: int = 86400):
        """Remember the connection settings; the client is created on first use.

        Args:
            redis_url: URL handed to ``redis.from_url`` when the client is built.
            session_ttl: Expiry in seconds applied to keys on every write.
        """
        self._redis_url = redis_url
        self._session_ttl = session_ttl
        self._sync_redis: Any = None
        self._fallback: InMemoryCascadeStateStore | None = None
        self._degraded = False

    def _get_sync_redis(self):
        """Return the cached sync Redis client, building it on first call."""
        client = self._sync_redis
        if client is None:
            # Imported lazily so the module loads without the redis package.
            import redis as sync_redis

            client = sync_redis.from_url(self._redis_url, decode_responses=True)
            self._sync_redis = client
        return client

    def _active_fallback(self):
        """Return the in-memory store when degraded mode is on, else None."""
        return self._fallback if self._degraded else None

    def _degrade_to_fallback(self) -> None:
        """Switch to the in-memory store; does nothing once already degraded."""
        if self._degraded:
            return
        self._degraded = True
        if self._fallback is None:
            self._fallback = InMemoryCascadeStateStore(session_ttl=self._session_ttl)
        logger.warning("Redis cascade store unreachable, degraded to in-memory")

    def increment_interaction(self, session_id: str) -> int:
        """Increment the session's interaction counter and return the new value.

        On a Redis failure the store degrades and the increment is applied
        to the in-memory store instead.
        """
        memory = self._active_fallback()
        if memory is not None:
            return memory.increment_interaction(session_id)
        try:
            client = self._get_sync_redis()
            counter_key = f"{self.INTER_PREFIX}{session_id}"
            # Counter bump and TTL refresh are queued on one pipeline.
            batch = client.pipeline()
            batch.incr(counter_key)
            batch.expire(counter_key, self._session_ttl)
            return batch.execute()[0]
        except (
            ImportError,
            OSError,
            _RedisError,
            ValueError,
            KeyError,
            RuntimeError,
            TypeError,
        ) as e:
            logger.warning(f"Redis cascade increment failed: {e}")
            self._degrade_to_fallback()
            memory = self._fallback
            if memory is None:
                return 0
            return memory.increment_interaction(session_id)

    def get_interaction(self, session_id: str) -> int:
        """Return the session's interaction count (0 when the key is absent)."""
        memory = self._active_fallback()
        if memory is not None:
            return memory.get_interaction(session_id)
        try:
            client = self._get_sync_redis()
            raw = client.get(f"{self.INTER_PREFIX}{session_id}")
            if raw is None:
                return 0
            return int(raw)
        except (
            ImportError,
            OSError,
            _RedisError,
            ValueError,
            KeyError,
            RuntimeError,
            TypeError,
        ) as e:
            # Reads log the failure but do not switch to degraded mode.
            logger.warning(f"Redis cascade get failed: {e}")
            memory = self._fallback
            if memory is None:
                return 0
            return memory.get_interaction(session_id)

    def set_depth(self, session_id: str, depth: int) -> None:
        """Store the session's loop depth and refresh the key's TTL.

        On a Redis failure the store degrades and the depth is written to
        the in-memory store instead.
        """
        memory = self._active_fallback()
        if memory is not None:
            memory.set_depth(session_id, depth)
            return
        try:
            client = self._get_sync_redis()
            depth_key = f"{self.DEPTH_PREFIX}{session_id}"
            batch = client.pipeline()
            batch.set(depth_key, depth)
            batch.expire(depth_key, self._session_ttl)
            batch.execute()
        except (
            ImportError,
            OSError,
            _RedisError,
            ValueError,
            KeyError,
            RuntimeError,
            TypeError,
        ) as e:
            logger.warning(f"Redis cascade set_depth failed: {e}")
            self._degrade_to_fallback()
            memory = self._fallback
            if memory is not None:
                memory.set_depth(session_id, depth)

    def get_depth(self, session_id: str) -> int:
        """Return the session's loop depth (0 when the key is absent)."""
        memory = self._active_fallback()
        if memory is not None:
            return memory.get_depth(session_id)
        try:
            client = self._get_sync_redis()
            raw = client.get(f"{self.DEPTH_PREFIX}{session_id}")
            if raw is None:
                return 0
            return int(raw)
        except (
            ImportError,
            OSError,
            _RedisError,
            ValueError,
            KeyError,
            RuntimeError,
            TypeError,
        ) as e:
            # Reads log the failure but do not switch to degraded mode.
            logger.warning(f"Redis cascade get_depth failed: {e}")
            memory = self._fallback
            if memory is None:
                return 0
            return memory.get_depth(session_id)

    def reset(self, session_id: str) -> None:
        """Delete both the interaction and the depth key of the session.

        On a Redis failure the store degrades and the reset is applied to
        the in-memory store instead.
        """
        memory = self._active_fallback()
        if memory is not None:
            memory.reset(session_id)
            return
        try:
            client = self._get_sync_redis()
            batch = client.pipeline()
            for prefix in (self.INTER_PREFIX, self.DEPTH_PREFIX):
                batch.delete(f"{prefix}{session_id}")
            batch.execute()
        except (
            ImportError,
            OSError,
            _RedisError,
            ValueError,
            KeyError,
            RuntimeError,
            TypeError,
        ) as e:
            logger.warning(f"Redis cascade reset failed: {e}")
            self._degrade_to_fallback()
            memory = self._fallback
            if memory is not None:
                memory.reset(session_id)

    def close(self) -> None:
        """Close the Redis client, if one was created, and forget it."""
        client = self._sync_redis
        if client is None:
            return
        client.close()
        self._sync_redis = None
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def create_cascade_state_store(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
    session_ttl: int = 86400,
) -> CascadeStateStore:
    """Build the cascade state store for the requested backend.

    Args:
        backend: ``"auto"`` or ``"redis"`` selects the Redis store when the
            ``redis`` package can be imported; any other value selects the
            in-memory store.
        redis_url: Connection URL passed to the Redis store.
        session_ttl: Session expiry in seconds, passed to whichever store is
            created.

    Returns:
        A ``RedisCascadeStateStore`` when Redis was requested and the package
        is importable, otherwise an ``InMemoryCascadeStateStore``.
    """
    if backend not in ("auto", "redis"):
        return InMemoryCascadeStateStore(session_ttl=session_ttl)

    try:
        # Only probes that the package is installed; the module itself is unused here.
        import redis  # noqa: F401
    except ImportError:
        logger.warning("redis package not available, falling back to in-memory cascade store")
        return InMemoryCascadeStateStore(session_ttl=session_ttl)

    return RedisCascadeStateStore(redis_url=redis_url, session_ttl=session_ttl)