358 lines
14 KiB
Python
358 lines
14 KiB
Python
"""Session store backends — InMemory and Redis."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
import json
import logging
import os
import tempfile
from typing import Any, Protocol, runtime_checkable
|
|
|
|
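# redis is an optional dependency; if it is not installed, fall back to Exception so the except clauses below keep their original catch-all behavior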
|
|
try:
|
|
from redis.exceptions import RedisError as _RedisError
|
|
except ImportError:
|
|
_RedisError = Exception
|
|
|
|
from agentkit.session.models import Message, Session, SessionStatus
|
|
|
|
# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@runtime_checkable
class SessionStore(Protocol):
    """Structural interface shared by every session persistence backend.

    Any object exposing these coroutine methods satisfies the protocol; a
    runtime ``isinstance`` check only verifies that the method names exist.
    """

    # --- session metadata -------------------------------------------------

    async def save_session(self, session: Session) -> None:
        """Create or overwrite the stored record for ``session``."""

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session, or ``None`` if it is unknown."""

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set the status of a stored session; return ``None`` if it is unknown."""

    async def delete_session(self, session_id: str) -> bool:
        """Remove a session and its messages; report whether anything was removed."""

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, most recently updated first, optionally filtered by agent."""

    # --- messages ---------------------------------------------------------

    async def append_message(self, message: Message) -> None:
        """Append ``message`` to the message log of its session."""

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return a session's messages, skipping ``offset`` and capping at ``limit``."""

    async def count_messages(self, session_id: str) -> int:
        """Return how many messages are stored for a session."""

    # --- diagnostics ------------------------------------------------------

    async def health_check(self) -> bool:
        """Return whether the backend is usable."""
|
|
|
|
|
|
class InMemorySessionStore:
    """In-memory session store for development and testing.

    Sessions and messages are held in plain dicts and are lost when the
    process exits. ``get_session`` returns the stored object itself, so
    caller mutations are visible to the store.

    Args:
        max_sessions: capacity. When it is reached, saving a *new* session
            evicts the least-recently-updated CLOSED session, or raises if
            there is none.
        max_messages_per_session: per-session cap. Appending beyond it drops
            the oldest messages.
    """

    def __init__(self, max_sessions: int = 10000, max_messages_per_session: int = 50000):
        self._sessions: dict[str, Session] = {}
        # Keyed by session_id. May hold entries for sessions that were never
        # saved, because append_message does not require the session to exist.
        self._messages: dict[str, list[Message]] = {}
        self._max_sessions = max_sessions
        self._max_messages_per_session = max_messages_per_session

    async def save_session(self, session: Session) -> None:
        """Store ``session``, evicting the oldest closed session if the store is full.

        Raises:
            RuntimeError: the store is at capacity, ``session`` is new, and no
                CLOSED session exists that could be evicted.
        """
        is_new = session.session_id not in self._sessions
        if is_new and len(self._sessions) >= self._max_sessions:
            self._evict_oldest_closed()
        self._sessions[session.session_id] = session
        # Keep any messages that were appended before the session was saved.
        self._messages.setdefault(session.session_id, [])

    def _evict_oldest_closed(self) -> None:
        """Drop the least-recently-updated CLOSED session together with its messages."""
        closed = [s for s in self._sessions.values() if s.status == SessionStatus.CLOSED]
        if not closed:
            raise RuntimeError("SessionStore is full and no closed sessions to evict")
        oldest = min(closed, key=lambda s: s.updated_at)
        del self._sessions[oldest.session_id]
        self._messages.pop(oldest.session_id, None)

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session object, or ``None`` if it is unknown."""
        return self._sessions.get(session_id)

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set ``status`` in place and bump ``updated_at`` to the current UTC time.

        Returns the updated session, or ``None`` if ``session_id`` is unknown.
        """
        session = self._sessions.get(session_id)
        if session is None:
            return None
        session.status = status
        session.updated_at = datetime.now(timezone.utc)
        return session

    async def delete_session(self, session_id: str) -> bool:
        """Remove a session and its messages; return ``False`` if it was unknown."""
        if session_id not in self._sessions:
            return False
        del self._sessions[session_id]
        self._messages.pop(session_id, None)
        return True

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, most recently updated first.

        ``agent_name=None`` disables filtering. Any other value, including the
        empty string, must match ``Session.agent_name`` exactly. This is the
        same rule the Redis and file backends apply.
        """
        sessions = [
            s for s in self._sessions.values()
            if agent_name is None or s.agent_name == agent_name
        ]
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Append ``message`` to its session's list, trimming the oldest entries at the cap."""
        msgs = self._messages.setdefault(message.session_id, [])
        if len(msgs) >= self._max_messages_per_session:
            # Make room for exactly one more message.
            excess = len(msgs) - self._max_messages_per_session + 1
            del msgs[:excess]
        msgs.append(message)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return a copy of the session's messages from ``offset``, at most ``limit`` of them."""
        sliced = self._messages.get(session_id, [])[offset:]
        if limit is not None:
            sliced = sliced[:limit]
        return sliced

    async def count_messages(self, session_id: str) -> int:
        """Return the number of stored messages for ``session_id`` (0 if unknown)."""
        return len(self._messages.get(session_id, []))

    async def health_check(self) -> bool:
        """Always healthy: there is no external dependency."""
        return True
|
|
|
|
|
|
class RedisSessionStore:
    """Redis-backed session store for production use.

    Key patterns:
    - ``agentkit:session:{session_id}`` — session metadata (JSON + TTL)
    - ``agentkit:session:{session_id}:messages`` — message list (Redis list)

    Both keys use ``ttl_seconds``. Every write refreshes the expiry of both
    keys inside one MULTI/EXEC pipeline, so a session's metadata and its
    messages expire together.
    """

    KEY_PREFIX = "agentkit:session:"
    MSG_SUFFIX = ":messages"

    def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 86400):
        self._redis_url = redis_url
        self._ttl_seconds = ttl_seconds
        # Client is created lazily by _get_redis().
        self._redis: Any = None

    async def _get_redis(self):
        """Return the shared client, creating it on first use.

        ``redis`` is imported here so the module loads without the optional
        dependency. ``decode_responses=True`` makes replies ``str``.
        """
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _session_key(self, session_id: str) -> str:
        return f"{self.KEY_PREFIX}{session_id}"

    def _messages_key(self, session_id: str) -> str:
        return f"{self.KEY_PREFIX}{session_id}{self.MSG_SUFFIX}"

    async def _write_session(self, redis: Any, session: Session) -> None:
        """Persist ``session`` and refresh its message-list TTL in one transaction."""
        async with redis.pipeline(transaction=True) as pipe:
            pipe.set(
                self._session_key(session.session_id),
                json.dumps(session.to_dict()),
                ex=self._ttl_seconds,
            )
            # No-op when the session has no messages yet.
            pipe.expire(self._messages_key(session.session_id), self._ttl_seconds)
            await pipe.execute()

    async def save_session(self, session: Session) -> None:
        """Create or overwrite the session metadata, resetting both keys' TTL."""
        redis = await self._get_redis()
        await self._write_session(redis, session)

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session, or ``None`` if missing or expired."""
        redis = await self._get_redis()
        raw = await redis.get(self._session_key(session_id))
        if raw is None:
            return None
        return Session.from_dict(json.loads(raw))

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set ``status``, bump ``updated_at`` (UTC) and rewrite the session.

        Returns the updated session, or ``None`` if it is missing or expired.
        NOTE: the GET/SET pair is not atomic; a concurrent writer in between
        can be overwritten.
        """
        redis = await self._get_redis()
        raw = await redis.get(self._session_key(session_id))
        if raw is None:
            return None
        session = Session.from_dict(json.loads(raw))
        session.status = status
        session.updated_at = datetime.now(timezone.utc)
        await self._write_session(redis, session)
        return session

    async def delete_session(self, session_id: str) -> bool:
        """Delete the session and message keys; ``True`` if either existed."""
        redis = await self._get_redis()
        keys = [self._session_key(session_id), self._messages_key(session_id)]
        deleted = await redis.delete(*keys)
        return deleted > 0

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, most recently updated first.

        Walks every key under ``KEY_PREFIX`` with SCAN and sorts in memory, so
        the cost grows with the total number of stored sessions.
        """
        redis = await self._get_redis()
        sessions: list[Session] = []
        cursor = 0
        while True:
            cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
            # The pattern also matches message-list keys; skip those.
            session_keys = [k for k in keys if not k.endswith(self.MSG_SUFFIX)]
            if session_keys:
                values = await redis.mget(session_keys)
                for raw in values:
                    if raw is None:
                        # Expired between SCAN and MGET.
                        continue
                    session = Session.from_dict(json.loads(raw))
                    if agent_name is None or session.agent_name == agent_name:
                        sessions.append(session)
            if cursor == 0:
                break
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Append ``message`` and refresh the TTL of both session keys atomically.

        RPUSH and EXPIRE run in one MULTI/EXEC, so a crash can no longer leave
        a message list without an expiry. Refreshing the session key keeps an
        active conversation's metadata alive as long as its messages.
        """
        redis = await self._get_redis()
        messages_key = self._messages_key(message.session_id)
        async with redis.pipeline(transaction=True) as pipe:
            pipe.rpush(messages_key, json.dumps(message.to_dict()))
            pipe.expire(messages_key, self._ttl_seconds)
            pipe.expire(self._session_key(message.session_id), self._ttl_seconds)
            await pipe.execute()

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return messages starting at ``offset``, at most ``limit`` of them.

        A ``limit`` of zero or less yields an empty list, matching the other
        backends.
        """
        if limit is not None and limit <= 0:
            # LRANGE's stop index is inclusive. limit=0 would give stop=offset-1,
            # which for offset=0 is -1 ("to the end") and returns everything.
            return []
        redis = await self._get_redis()
        end = -1 if limit is None else offset + limit - 1
        raw_list = await redis.lrange(self._messages_key(session_id), offset, end)
        return [Message.from_dict(json.loads(raw)) for raw in raw_list]

    async def count_messages(self, session_id: str) -> int:
        """Return the length of the session's message list (0 if missing)."""
        redis = await self._get_redis()
        return await redis.llen(self._messages_key(session_id))

    async def health_check(self) -> bool:
        """PING the server; return ``False`` instead of raising on failure."""
        try:
            redis = await self._get_redis()
            return await redis.ping()
        except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError):
            return False
|
|
|
|
|
|
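# datetime/timezone stamp updated_at in the store methods above and below; those methods resolve the names at call time, so importing them here still works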
|
|
from datetime import datetime, timezone # noqa: E402
|
|
|
|
|
|
class FileSessionStore:
    """File-based session store — persists sessions to ~/.agentkit/sessions/.

    Each session is stored as one JSON file, ``{session_id}.json``. The file
    holds the session metadata under ``"session"`` and the messages under
    ``"messages"``. Suitable for single-user GUI mode without Redis.

    Every write goes to a temporary file in the same directory, which is then
    renamed over the target with ``os.replace``. An interrupted write
    therefore leaves the previous file intact instead of a truncated one.
    File I/O is blocking and runs directly inside the coroutines.
    """

    def __init__(self, data_dir: str | None = None):
        if data_dir is None:
            data_dir = os.path.expanduser("~/.agentkit/sessions")
        self._data_dir = data_dir
        os.makedirs(self._data_dir, exist_ok=True)

    def _session_path(self, session_id: str) -> str:
        """Return the JSON file path for ``session_id``.

        Raises:
            ValueError: ``session_id`` is empty or contains a path separator.
                Such an id would name a file outside ``data_dir``, or no
                usable file at all.
        """
        if not session_id or os.path.basename(session_id) != session_id:
            raise ValueError(f"Invalid session_id for file storage: {session_id!r}")
        return os.path.join(self._data_dir, f"{session_id}.json")

    def _read_session_file(self, session_id: str) -> dict | None:
        """Load the JSON document for ``session_id``; ``None`` if no file exists.

        Raises:
            ValueError: invalid ``session_id``, or the file is not valid JSON.
        """
        path = self._session_path(session_id)
        try:
            with open(path, encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def _write_session_file(self, session_id: str, data: dict) -> None:
        """Replace the JSON document for ``session_id`` with ``data`` via write-then-rename."""
        path = self._session_path(session_id)
        # The temp file lives in data_dir so os.replace stays on one filesystem.
        # Python documents the rename as atomic on POSIX. The ".tmp" suffix
        # keeps it out of list_sessions.
        fd, tmp_path = tempfile.mkstemp(dir=self._data_dir, prefix=".session-", suffix=".tmp")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            # Remove the partial temp file; the original session file is untouched.
            with contextlib.suppress(OSError):
                os.unlink(tmp_path)
            raise

    async def save_session(self, session: Session) -> None:
        """Write ``session`` metadata, keeping any messages already on disk.

        The stored ``updated_at`` is overwritten with the current UTC time.
        """
        data = self._read_session_file(session.session_id) or {"messages": []}
        data["session"] = session.to_dict()
        data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(session.session_id, data)

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session, or ``None`` if it has no file."""
        data = self._read_session_file(session_id)
        if data is None:
            return None
        return Session.from_dict(data["session"])

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set ``status`` and bump ``updated_at`` (UTC); ``None`` if the session has no file."""
        data = self._read_session_file(session_id)
        if data is None:
            return None
        data["session"]["status"] = status.value
        data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(session_id, data)
        return Session.from_dict(data["session"])

    async def delete_session(self, session_id: str) -> bool:
        """Delete the session file; return ``False`` if it did not exist."""
        path = self._session_path(session_id)
        try:
            os.remove(path)
        except FileNotFoundError:
            return False
        return True

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, most recently updated first.

        Files that cannot be read or parsed into a Session are skipped.
        """
        sessions: list[Session] = []
        for fname in os.listdir(self._data_dir):
            if not fname.endswith(".json"):
                continue
            path = os.path.join(self._data_dir, fname)
            try:
                with open(path, encoding="utf-8") as f:
                    data = json.load(f)
                session = Session.from_dict(data["session"])
                if agent_name is None or session.agent_name == agent_name:
                    sessions.append(session)
            except (ValueError, KeyError, TypeError, OSError):
                continue
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Append ``message`` to its session file, creating a stub file if needed.

        The stub's "session" entry only contains ``session_id``.
        NOTE(review): whether ``Session.from_dict`` accepts such a stub is not
        visible here.
        """
        data = self._read_session_file(message.session_id)
        if data is None:
            data = {"session": {"session_id": message.session_id}, "messages": []}
        data.setdefault("messages", []).append(message.to_dict())
        # Update session timestamp
        if "session" in data:
            data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(message.session_id, data)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return messages from ``offset``, at most ``limit`` of them ([] if no file)."""
        data = self._read_session_file(session_id)
        if data is None:
            return []
        msgs = data.get("messages", [])[offset:]
        if limit is not None:
            msgs = msgs[:limit]
        return [Message.from_dict(m) for m in msgs]

    async def count_messages(self, session_id: str) -> int:
        """Return the number of stored messages (0 if the session has no file)."""
        data = self._read_session_file(session_id)
        if data is None:
            return 0
        return len(data.get("messages", []))

    async def health_check(self) -> bool:
        """Healthy when the data directory exists."""
        return os.path.isdir(self._data_dir)
|
|
|
|
|
|
def create_session_store(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    ttl_seconds: int = 86400,
    data_dir: str | None = None,
) -> InMemorySessionStore | RedisSessionStore | FileSessionStore:
    """Factory: create a SessionStore backed by memory, file, or Redis.

    - ``memory``: In-memory (lost on restart)
    - ``file``: JSON files in ``data_dir``, defaulting to
      ``~/.agentkit/sessions/`` (persistent, no deps)
    - ``redis``: Redis-backed (production, requires the ``redis`` package)

    ``backend`` is matched case-insensitively, ignoring surrounding
    whitespace. Unknown names fall back to the in-memory store with a warning.

    RedisSessionStore defers creating its client until first use. This
    factory therefore falls back to InMemorySessionStore only when the
    ``redis`` package cannot be imported. An unreachable server is detected
    later, e.g. by ``health_check()``.
    """
    choice = backend.strip().lower()

    if choice == "file":
        file_store = FileSessionStore(data_dir=data_dir)
        logger.info("SessionStore backend: file (%s)", file_store._data_dir)
        return file_store

    if choice == "redis":
        try:
            import redis.asyncio as aioredis  # noqa: F401  (availability probe only)

            redis_store = RedisSessionStore(redis_url=redis_url, ttl_seconds=ttl_seconds)
        except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError) as exc:
            logger.warning(
                "Failed to initialise RedisSessionStore (%s), falling back to InMemorySessionStore",
                exc,
            )
        else:
            logger.info("SessionStore backend: redis")
            return redis_store
    elif choice != "memory":
        logger.warning("Unknown SessionStore backend %r, falling back to memory", backend)

    memory_store = InMemorySessionStore()
    logger.info("SessionStore backend: memory")
    return memory_store
|