"""Session store backends — InMemory and Redis.""" from __future__ import annotations import json import logging import os from typing import Any, Protocol, runtime_checkable # redis 可选依赖;未安装时回退为 Exception 以保留原 catch-all 语义 try: from redis.exceptions import RedisError as _RedisError except ImportError: _RedisError = Exception from agentkit.session.models import Message, Session, SessionStatus logger = logging.getLogger(__name__) @runtime_checkable class SessionStore(Protocol): """Protocol for session persistence backends.""" async def save_session(self, session: Session) -> None: ... async def get_session(self, session_id: str) -> Session | None: ... async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None: ... async def delete_session(self, session_id: str) -> bool: ... async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]: ... async def append_message(self, message: Message) -> None: ... async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]: ... async def count_messages(self, session_id: str) -> int: ... async def health_check(self) -> bool: ... 
# ``datetime`` is used at call time to stamp ``updated_at`` values;
# ``tempfile`` backs the atomic file writes in ``FileSessionStore``.
import tempfile  # noqa: E402
from datetime import datetime, timezone  # noqa: E402


class InMemorySessionStore:
    """In-memory session store for development and testing.

    State lives in two process-local dicts and is lost when the process exits.

    Args:
        max_sessions: Maximum number of sessions kept. When full, saving a new
            session evicts the least recently updated CLOSED session; if there
            is none, :class:`RuntimeError` is raised.
        max_messages_per_session: Cap on messages kept per session; the oldest
            messages are dropped to make room for new ones.
    """

    def __init__(self, max_sessions: int = 10000, max_messages_per_session: int = 50000):
        self._sessions: dict[str, Session] = {}
        self._messages: dict[str, list[Message]] = {}
        self._max_sessions = max_sessions
        self._max_messages_per_session = max_messages_per_session

    async def save_session(self, session: Session) -> None:
        """Store ``session`` by reference, evicting a closed session if full.

        Raises:
            RuntimeError: The store is at capacity, ``session`` is new, and no
                CLOSED session is available for eviction.
        """
        if len(self._sessions) >= self._max_sessions and session.session_id not in self._sessions:
            # Evict the least recently updated closed session to make room.
            closed = [s for s in self._sessions.values() if s.status == SessionStatus.CLOSED]
            if not closed:
                raise RuntimeError("SessionStore is full and no closed sessions to evict")
            oldest = min(closed, key=lambda s: s.updated_at)
            del self._sessions[oldest.session_id]
            self._messages.pop(oldest.session_id, None)
        self._sessions[session.session_id] = session
        self._messages.setdefault(session.session_id, [])

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session object itself (not a copy), or ``None``."""
        return self._sessions.get(session_id)

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set ``status`` and bump ``updated_at`` in place; ``None`` if unknown."""
        session = self._sessions.get(session_id)
        if session is None:
            return None
        session.status = status
        session.updated_at = datetime.now(timezone.utc)
        return session

    async def delete_session(self, session_id: str) -> bool:
        """Remove a session and its messages.

        Returns ``True`` if a session or any messages were removed.
        """
        had_session = self._sessions.pop(session_id, None) is not None
        # Always drop the message list too: ``append_message`` can create one for
        # a session that was never saved, and it would otherwise never be freed.
        had_messages = self._messages.pop(session_id, None) is not None
        return had_session or had_messages

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, newest ``updated_at`` first.

        A falsy ``agent_name`` (``None`` or ``""``) disables the agent filter.
        """
        sessions = list(self._sessions.values())
        if agent_name:
            sessions = [s for s in sessions if s.agent_name == agent_name]
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Append ``message``, dropping the oldest messages beyond the cap."""
        msgs = self._messages.setdefault(message.session_id, [])
        if len(msgs) >= self._max_messages_per_session:
            # Trim so that, after the append below, the list holds exactly
            # ``max_messages_per_session`` entries (for a cap of at least 1).
            excess = len(msgs) - self._max_messages_per_session + 1
            del msgs[:excess]
        msgs.append(message)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return a new list of messages starting at ``offset``, at most ``limit`` long."""
        msgs = self._messages.get(session_id, [])
        sliced = msgs[offset:]
        if limit is not None:
            sliced = sliced[:limit]
        return sliced

    async def count_messages(self, session_id: str) -> int:
        """Return the number of messages held for ``session_id``."""
        return len(self._messages.get(session_id, []))

    async def health_check(self) -> bool:
        """Always ``True``: an in-memory store has no external dependency."""
        return True


class RedisSessionStore:
    """Redis-backed session store for production use.

    Key patterns:

    - ``agentkit:session:{session_id}`` — session metadata (JSON string + TTL)
    - ``agentkit:session:{session_id}:messages`` — message list (Redis list + TTL)

    The Redis client is created lazily on first use, so constructing this class
    does not contact the server.
    """

    KEY_PREFIX = "agentkit:session:"
    MSG_SUFFIX = ":messages"

    def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 86400):
        self._redis_url = redis_url
        self._ttl_seconds = ttl_seconds
        self._redis: Any = None

    async def _get_redis(self):
        """Return the shared ``redis.asyncio`` client, creating it on first call."""
        if self._redis is None:
            import redis.asyncio as aioredis

            # decode_responses=True makes commands return ``str`` instead of ``bytes``.
            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _session_key(self, session_id: str) -> str:
        return f"{self.KEY_PREFIX}{session_id}"

    def _messages_key(self, session_id: str) -> str:
        return f"{self.KEY_PREFIX}{session_id}{self.MSG_SUFFIX}"

    async def save_session(self, session: Session) -> None:
        """Write the session JSON, (re)setting its TTL to ``ttl_seconds``."""
        redis = await self._get_redis()
        key = self._session_key(session.session_id)
        await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)

    async def get_session(self, session_id: str) -> Session | None:
        """Load and deserialize the session, or ``None`` if the key is missing/expired."""
        redis = await self._get_redis()
        raw = await redis.get(self._session_key(session_id))
        if raw is None:
            return None
        return Session.from_dict(json.loads(raw))

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set ``status``, bump ``updated_at`` and rewrite the session with a fresh TTL.

        NOTE(review): this is a GET followed by SET with no WATCH, so a
        concurrent writer's change to the same session can be overwritten.
        """
        redis = await self._get_redis()
        key = self._session_key(session_id)
        raw = await redis.get(key)
        if raw is None:
            return None
        session = Session.from_dict(json.loads(raw))
        session.status = status
        session.updated_at = datetime.now(timezone.utc)
        await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)
        return session

    async def delete_session(self, session_id: str) -> bool:
        """Delete both the session key and its message list; ``True`` if either existed."""
        redis = await self._get_redis()
        keys = [self._session_key(session_id), self._messages_key(session_id)]
        deleted = await redis.delete(*keys)
        return deleted > 0

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, most recently updated first.

        Walks every key under ``KEY_PREFIX`` with SCAN and loads values with
        MGET, so the cost grows with the total number of stored sessions, not
        with ``limit``.
        """
        redis = await self._get_redis()
        sessions: list[Session] = []
        # SCAN may return the same key more than once across (or within)
        # batches; track what has been loaded so no session is listed twice.
        seen: set[str] = set()
        cursor = 0
        while True:
            cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
            # Skip message-list keys and keys already loaded.
            session_keys = [
                k for k in dict.fromkeys(keys) if not k.endswith(self.MSG_SUFFIX) and k not in seen
            ]
            seen.update(session_keys)
            if session_keys:
                values = await redis.mget(session_keys)
                for raw in values:
                    if raw is None:
                        # Key expired/deleted between SCAN and MGET, or is not a string.
                        continue
                    session = Session.from_dict(json.loads(raw))
                    if agent_name is None or session.agent_name == agent_name:
                        sessions.append(session)
            if cursor == 0:
                break
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """RPUSH the serialized message and refresh the message list's TTL.

        NOTE(review): RPUSH and EXPIRE are separate round trips (a crash in
        between leaves a new list without a TTL), and the session key's own TTL
        is not refreshed here. Consider a MULTI/EXEC pipeline.
        """
        redis = await self._get_redis()
        key = self._messages_key(message.session_id)
        await redis.rpush(key, json.dumps(message.to_dict()))
        # Keep the message list's lifetime in line with the session TTL.
        await redis.expire(key, self._ttl_seconds)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return messages in append order, skipping ``offset`` and returning at most ``limit``.

        ``limit=None`` returns everything from ``offset`` onward; ``limit=0``
        returns an empty list.
        """
        if limit == 0:
            # LRANGE's stop index is inclusive, so offset=0/limit=0 would give
            # stop=-1, which Redis reads as "through the last element".
            return []
        redis = await self._get_redis()
        key = self._messages_key(session_id)
        # Redis list indices are 0-based and inclusive; -1 means the last element.
        end = -1 if limit is None else offset + limit - 1
        raw_list = await redis.lrange(key, offset, end)
        return [Message.from_dict(json.loads(raw)) for raw in raw_list]

    async def count_messages(self, session_id: str) -> int:
        """Return LLEN of the session's message list (0 if it does not exist)."""
        redis = await self._get_redis()
        return await redis.llen(self._messages_key(session_id))

    async def health_check(self) -> bool:
        """Return the PING result; ``False`` if the client can't be created or PING fails."""
        try:
            redis = await self._get_redis()
            return await redis.ping()
        except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError):
            return False


class FileSessionStore:
    """File-based session store — persists sessions to ``~/.agentkit/sessions/``.

    Each session is stored as one JSON file, ``{session_id}.json``. The file
    holds both the session metadata (``"session"``) and its messages
    (``"messages"``). Suitable for single-user GUI mode without Redis. Every
    operation reads and/or rewrites the whole file.

    Session ids become file names. Ids that are empty, ``.``/``..``, or contain
    a path separator or NUL are rejected: reads treat them as missing and writes
    raise :class:`ValueError`, so no file is ever touched outside ``data_dir``.
    """

    def __init__(self, data_dir: str | None = None):
        if data_dir is None:
            data_dir = os.path.expanduser("~/.agentkit/sessions")
        self._data_dir = data_dir
        os.makedirs(self._data_dir, exist_ok=True)

    @staticmethod
    def _is_valid_session_id(session_id: str) -> bool:
        """Return True if ``session_id`` is safe to use as a bare file name."""
        name = str(session_id)
        return (
            name not in ("", ".", "..")
            and "\x00" not in name
            # basename() strips any directory part, so equality means "no separators".
            and os.path.basename(name) == name
        )

    def _session_path(self, session_id: str) -> str:
        """Return the JSON file path for ``session_id``.

        Raises:
            ValueError: ``session_id`` is not a safe bare file name.
        """
        if not self._is_valid_session_id(session_id):
            raise ValueError(f"Invalid session_id for FileSessionStore: {session_id!r}")
        return os.path.join(self._data_dir, f"{session_id}.json")

    def _read_session_file(self, session_id: str) -> dict | None:
        """Load the raw JSON document, or ``None`` if absent or the id is invalid."""
        if not self._is_valid_session_id(session_id):
            return None
        try:
            with open(self._session_path(session_id), encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def _write_session_file(self, session_id: str, data: dict) -> None:
        """Atomically replace the JSON document for ``session_id``.

        The data is written to a temporary file in the data directory and then
        moved over the target with :func:`os.replace`. A crash or a
        serialization error therefore never leaves a truncated session file.
        ``tempfile.mkstemp`` creates the file readable/writable by the owner only.

        Raises:
            ValueError: ``session_id`` is not a safe bare file name.
        """
        path = self._session_path(session_id)
        # The ".tmp" suffix keeps in-progress files out of ``list_sessions``,
        # which only reads names ending in ".json".
        fd, tmp_path = tempfile.mkstemp(dir=self._data_dir, prefix=".session-", suffix=".tmp")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    async def save_session(self, session: Session) -> None:
        """Write session metadata, keeping existing messages; stamps ``updated_at`` with now."""
        data = self._read_session_file(session.session_id) or {"messages": []}
        data["session"] = session.to_dict()
        data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(session.session_id, data)

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session, or ``None`` if there is no file for it."""
        data = self._read_session_file(session_id)
        if data is None:
            return None
        return Session.from_dict(data["session"])

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Store ``status.value`` and a fresh ``updated_at``; ``None`` if unknown."""
        data = self._read_session_file(session_id)
        if data is None:
            return None
        data["session"]["status"] = status.value
        data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(session_id, data)
        return Session.from_dict(data["session"])

    async def delete_session(self, session_id: str) -> bool:
        """Delete the session file; ``True`` if it existed."""
        if not self._is_valid_session_id(session_id):
            return False
        try:
            os.remove(self._session_path(session_id))
        except FileNotFoundError:
            return False
        return True

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, newest ``updated_at`` first.

        Unreadable or malformed ``*.json`` files are skipped silently.
        """
        sessions: list[Session] = []
        for fname in os.listdir(self._data_dir):
            if not fname.endswith(".json"):
                continue
            path = os.path.join(self._data_dir, fname)
            try:
                with open(path, encoding="utf-8") as f:
                    data = json.load(f)
                session = Session.from_dict(data["session"])
                if agent_name is None or session.agent_name == agent_name:
                    sessions.append(session)
            except (ValueError, KeyError, TypeError, OSError):
                continue
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Append ``message`` to its session file (creating a stub file if needed).

        Raises:
            ValueError: ``message.session_id`` is not a safe bare file name.
        """
        data = self._read_session_file(message.session_id)
        if data is None:
            data = {"session": {"session_id": message.session_id}, "messages": []}
        data.setdefault("messages", []).append(message.to_dict())
        # Update session timestamp
        if "session" in data:
            data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(message.session_id, data)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return messages starting at ``offset``, at most ``limit`` of them."""
        data = self._read_session_file(session_id)
        if data is None:
            return []
        msgs = data.get("messages", [])[offset:]
        if limit is not None:
            msgs = msgs[:limit]
        return [Message.from_dict(m) for m in msgs]

    async def count_messages(self, session_id: str) -> int:
        """Return the number of messages in the session file (0 if absent)."""
        data = self._read_session_file(session_id)
        if data is None:
            return 0
        return len(data.get("messages", []))

    async def health_check(self) -> bool:
        """Return ``True`` if the data directory exists."""
        return os.path.isdir(self._data_dir)


def create_session_store(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    ttl_seconds: int = 86400,
    data_dir: str | None = None,
) -> InMemorySessionStore | RedisSessionStore | FileSessionStore:
    """Factory: create a SessionStore backed by memory, file, or Redis.

    - ``memory``: In-memory (lost on restart)
    - ``file``: JSON files in ``data_dir`` (default ``~/.agentkit/sessions/``;
      persistent, no deps)
    - ``redis``: Redis-backed (production, requires Redis)

    Falls back to :class:`InMemorySessionStore` if the ``redis`` package cannot
    be imported. No connection is attempted here, because the Redis client is
    created lazily. An unrecognised ``backend`` name also falls back to memory,
    with a warning.
    """
    if backend == "file":
        file_store = FileSessionStore(data_dir=data_dir)
        logger.info("SessionStore backend: file (%s)", file_store._data_dir)
        return file_store

    if backend == "redis":
        try:
            import redis.asyncio as aioredis  # noqa: F401  (probe for the optional dependency)

            redis_store = RedisSessionStore(redis_url=redis_url, ttl_seconds=ttl_seconds)
        except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError) as exc:
            logger.warning(
                "Failed to initialise RedisSessionStore (%s), falling back to InMemorySessionStore", exc
            )
        else:
            logger.info("SessionStore backend: redis")
            return redis_store
    elif backend != "memory":
        logger.warning("Unknown SessionStore backend %r, falling back to InMemorySessionStore", backend)

    store = InMemorySessionStore()
    logger.info("SessionStore backend: memory")
    return store