# fischer-agentkit/src/agentkit/session/store.py (358 lines, 14 KiB, Python)

"""Session store backends — InMemory and Redis."""
from __future__ import annotations
import json
import logging
import os
from typing import Any, Protocol, runtime_checkable
# redis is an optional dependency; when it is not installed, fall back to
# Exception so the original catch-all semantics are preserved.
try:
from redis.exceptions import RedisError as _RedisError
except ImportError:
_RedisError = Exception
from agentkit.session.models import Message, Session, SessionStatus
logger = logging.getLogger(__name__)
@runtime_checkable
class SessionStore(Protocol):
    """Structural interface shared by all session persistence backends.

    The protocol is ``runtime_checkable``, so ``isinstance(obj, SessionStore)``
    succeeds for any object that exposes all of the methods below.
    """

    async def save_session(self, session: Session) -> None:
        """Persist ``session`` under its ``session_id``."""
        ...

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session, or ``None`` when the id is unknown."""
        ...

    async def update_session_status(
        self,
        session_id: str,
        status: SessionStatus,
    ) -> Session | None:
        """Set the session's status; return the updated session or ``None``."""
        ...

    async def delete_session(self, session_id: str) -> bool:
        """Remove the session and report whether anything was deleted."""
        ...

    async def list_sessions(
        self,
        agent_name: str | None = None,
        limit: int = 100,
    ) -> list[Session]:
        """Return up to ``limit`` sessions, optionally filtered by agent name."""
        ...

    async def append_message(self, message: Message) -> None:
        """Append ``message`` to the history of ``message.session_id``."""
        ...

    async def get_messages(
        self,
        session_id: str,
        limit: int | None = None,
        offset: int = 0,
    ) -> list[Message]:
        """Return a page of the session's messages starting at ``offset``."""
        ...

    async def count_messages(self, session_id: str) -> int:
        """Return how many messages are stored for the session."""
        ...

    async def health_check(self) -> bool:
        """Report whether the backend is currently usable."""
        ...
class InMemorySessionStore:
    """In-memory session store for development and testing.

    Sessions and their message lists live in two plain dicts keyed by
    ``session_id``; nothing is persisted.
    """

    def __init__(self, max_sessions: int = 10000, max_messages_per_session: int = 50000):
        """Create an empty store.

        Args:
            max_sessions: Capacity at which ``save_session`` starts evicting.
            max_messages_per_session: Cap on each session's message list.
        """
        self._sessions: dict[str, Session] = {}
        self._messages: dict[str, list[Message]] = {}
        self._max_sessions = max_sessions
        self._max_messages_per_session = max_messages_per_session

    async def save_session(self, session: Session) -> None:
        """Store ``session``, evicting the oldest closed session when full.

        Raises:
            RuntimeError: The store is at capacity, ``session`` is new, and no
                closed session exists that could be evicted.
        """
        is_new = session.session_id not in self._sessions
        if is_new and len(self._sessions) >= self._max_sessions:
            # Only CLOSED sessions are eligible for eviction.
            closed = [s for s in self._sessions.values() if s.status == SessionStatus.CLOSED]
            if not closed:
                raise RuntimeError("SessionStore is full and no closed sessions to evict")
            oldest = min(closed, key=lambda s: s.updated_at)
            del self._sessions[oldest.session_id]
            self._messages.pop(oldest.session_id, None)
        self._sessions[session.session_id] = session
        # Keep an existing message list when the session is re-saved.
        self._messages.setdefault(session.session_id, [])

    async def get_session(self, session_id: str) -> Session | None:
        """Return the stored session object, or ``None`` if the id is unknown."""
        return self._sessions.get(session_id)

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set ``status`` and refresh ``updated_at`` (UTC) on the stored session.

        Returns:
            The mutated session, or ``None`` if the id is unknown.
        """
        session = self._sessions.get(session_id)
        if session is None:
            return None
        session.status = status
        session.updated_at = datetime.now(timezone.utc)
        return session

    async def delete_session(self, session_id: str) -> bool:
        """Drop the session and its messages; return ``True`` if it existed."""
        if session_id not in self._sessions:
            return False
        del self._sessions[session_id]
        self._messages.pop(session_id, None)
        return True

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, most recently updated first.

        Args:
            agent_name: When not ``None``, keep only sessions whose
                ``agent_name`` equals this value.
            limit: Maximum number of sessions returned.
        """
        sessions = list(self._sessions.values())
        # Compare against None (not truthiness) so that an empty-string agent
        # name filters exactly like it does in the Redis and file backends.
        if agent_name is not None:
            sessions = [s for s in sessions if s.agent_name == agent_name]
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Append ``message``, discarding the oldest entries beyond the cap."""
        msgs = self._messages.setdefault(message.session_id, [])
        if len(msgs) >= self._max_messages_per_session:
            # Make room for the new message: after the append the list holds
            # exactly ``max_messages_per_session`` entries.
            excess = len(msgs) - self._max_messages_per_session + 1
            del msgs[:excess]
        msgs.append(message)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return messages from ``offset`` on, at most ``limit`` when given.

        An unknown ``session_id`` yields an empty list.
        """
        sliced = self._messages.get(session_id, [])[offset:]
        if limit is not None:
            sliced = sliced[:limit]
        return sliced

    async def count_messages(self, session_id: str) -> int:
        """Return the number of stored messages (0 for an unknown session)."""
        return len(self._messages.get(session_id, []))

    async def health_check(self) -> bool:
        """Always ``True``: there is no external dependency to probe."""
        return True
class RedisSessionStore:
    """Redis-backed session store for production use.

    Key patterns:

    - ``agentkit:session:{session_id}`` — session metadata (JSON + TTL)
    - ``agentkit:session:{session_id}:messages`` — message list (Redis list)
    """

    KEY_PREFIX = "agentkit:session:"
    MSG_SUFFIX = ":messages"

    def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 86400):
        """Remember connection settings; the client is created on first use.

        Args:
            redis_url: URL passed to ``redis.asyncio.from_url``.
            ttl_seconds: Expiry applied to session keys and message lists.
        """
        self._redis_url = redis_url
        self._ttl_seconds = ttl_seconds
        self._redis: Any = None

    async def _get_redis(self):
        """Return the cached client, creating it on the first call."""
        if self._redis is None:
            # Imported here so the module loads without redis installed.
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _session_key(self, session_id: str) -> str:
        """Return the key holding the session's JSON metadata."""
        return f"{self.KEY_PREFIX}{session_id}"

    def _messages_key(self, session_id: str) -> str:
        """Return the key holding the session's message list."""
        return f"{self.KEY_PREFIX}{session_id}{self.MSG_SUFFIX}"

    async def save_session(self, session: Session) -> None:
        """Write the session as JSON and (re)set its TTL."""
        redis = await self._get_redis()
        key = self._session_key(session.session_id)
        await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)

    async def get_session(self, session_id: str) -> Session | None:
        """Load the session, or return ``None`` when the key is absent."""
        redis = await self._get_redis()
        raw = await redis.get(self._session_key(session_id))
        if raw is None:
            return None
        return Session.from_dict(json.loads(raw))

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Set ``status``, refresh ``updated_at`` (UTC) and rewrite the key.

        The update is a read followed by a write, and the write resets the TTL.

        Returns:
            The updated session, or ``None`` when the key is absent.
        """
        redis = await self._get_redis()
        key = self._session_key(session_id)
        raw = await redis.get(key)
        if raw is None:
            return None
        session = Session.from_dict(json.loads(raw))
        session.status = status
        session.updated_at = datetime.now(timezone.utc)
        await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)
        return session

    async def delete_session(self, session_id: str) -> bool:
        """Delete the session and message keys; ``True`` if any key existed."""
        redis = await self._get_redis()
        keys = [self._session_key(session_id), self._messages_key(session_id)]
        deleted = await redis.delete(*keys)
        return deleted > 0

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Scan all session keys and return up to ``limit``, newest first.

        Args:
            agent_name: When not ``None``, keep only sessions whose
                ``agent_name`` equals this value.
            limit: Maximum number of sessions returned.
        """
        redis = await self._get_redis()
        sessions: list[Session] = []
        cursor = 0
        while True:
            cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
            # The match pattern also hits message-list keys; drop them.
            session_keys = [k for k in keys if not k.endswith(self.MSG_SUFFIX)]
            if session_keys:
                values = await redis.mget(session_keys)
                for raw in values:
                    if raw is None:
                        # No value came back for this key; skip it.
                        continue
                    session = Session.from_dict(json.loads(raw))
                    if agent_name is None or session.agent_name == agent_name:
                        sessions.append(session)
            if cursor == 0:
                # A zero cursor ends the SCAN iteration.
                break
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Push the message onto the session's list and refresh the list TTL."""
        redis = await self._get_redis()
        key = self._messages_key(message.session_id)
        await redis.rpush(key, json.dumps(message.to_dict()))
        # Set TTL on message list to match session TTL
        await redis.expire(key, self._ttl_seconds)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return a page of messages using ``LRANGE``.

        Args:
            session_id: Session whose messages are read.
            limit: Maximum number of messages; ``None`` means "to the end".
                A limit of zero or less returns an empty list.
            offset: Index of the first message to return.
        """
        if limit is not None and limit <= 0:
            # LRANGE's stop index is inclusive and negative values count from
            # the tail, so ``offset + limit - 1`` would wrap around (limit=0,
            # offset=0 gives stop=-1, i.e. the whole list). Nothing was asked
            # for, so return nothing without a round trip.
            return []
        redis = await self._get_redis()
        key = self._messages_key(session_id)
        # Redis list indices: 0-based, inclusive stop, -1 = last element
        end = -1 if limit is None else offset + limit - 1
        raw_list = await redis.lrange(key, offset, end)
        return [Message.from_dict(json.loads(raw)) for raw in raw_list]

    async def count_messages(self, session_id: str) -> int:
        """Return the length of the session's message list."""
        redis = await self._get_redis()
        return await redis.llen(self._messages_key(session_id))

    async def health_check(self) -> bool:
        """Ping Redis; return ``False`` instead of raising on known failures."""
        try:
            redis = await self._get_redis()
            return await redis.ping()
        except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError):
            return False
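# datetime/timezone supply the ``updated_at`` timestamps written by the stores
# in this module; the names are looked up at call time, so importing them
# after the classes above is safe.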
from datetime import datetime, timezone # noqa: E402
class FileSessionStore:
    """File-based session store — persists sessions to ~/.agentkit/sessions/.

    Each session is stored as a JSON file containing both session metadata
    (under ``"session"``) and messages (under ``"messages"``). Suitable for
    single-user GUI mode without Redis.
    """

    def __init__(self, data_dir: str | None = None):
        """Create the storage directory if needed.

        Args:
            data_dir: Directory for the JSON files; defaults to
                ``~/.agentkit/sessions``.
        """
        if data_dir is None:
            data_dir = os.path.expanduser("~/.agentkit/sessions")
        self._data_dir = data_dir
        os.makedirs(self._data_dir, exist_ok=True)

    def _session_path(self, session_id: str) -> str:
        """Return the JSON file path for ``session_id``."""
        # NOTE(review): session_id is interpolated into the path unvalidated —
        # confirm callers never pass ids containing path separators.
        return os.path.join(self._data_dir, f"{session_id}.json")

    def _read_session_file(self, session_id: str) -> dict | None:
        """Load the session's JSON document, or ``None`` if the file is absent."""
        try:
            # Open directly instead of checking existence first, so a file
            # removed in between cannot surface as an unexpected error.
            with open(self._session_path(session_id), encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def _write_session_file(self, session_id: str, data: dict) -> None:
        """Replace the session's JSON document atomically.

        The document is serialized to a temporary sibling file and then moved
        over the target, so a serialization error or a crash mid-write never
        truncates the previously stored session.
        """
        path = self._session_path(session_id)
        # The ``.tmp`` suffix keeps the file out of ``list_sessions``, which
        # only reads ``*.json``.
        tmp_path = f"{path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        finally:
            try:
                os.remove(tmp_path)
            except OSError:
                # Already moved into place by os.replace, or never created.
                pass

    async def save_session(self, session: Session) -> None:
        """Write the session metadata, keeping any messages already on disk.

        ``updated_at`` in the stored document is set to the current UTC time.
        """
        data = self._read_session_file(session.session_id) or {"messages": []}
        data["session"] = session.to_dict()
        data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(session.session_id, data)

    async def get_session(self, session_id: str) -> Session | None:
        """Load the session, or return ``None`` when no file exists."""
        data = self._read_session_file(session_id)
        if data is None:
            return None
        return Session.from_dict(data["session"])

    async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
        """Store ``status.value`` and a fresh UTC ``updated_at``.

        Returns:
            The updated session, or ``None`` when no file exists.
        """
        data = self._read_session_file(session_id)
        if data is None:
            return None
        data["session"]["status"] = status.value
        data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(session_id, data)
        return Session.from_dict(data["session"])

    async def delete_session(self, session_id: str) -> bool:
        """Remove the session file; return ``True`` if it existed."""
        try:
            os.remove(self._session_path(session_id))
        except FileNotFoundError:
            return False
        return True

    async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
        """Return up to ``limit`` sessions, most recently updated first.

        Files that cannot be read or parsed are skipped.

        Args:
            agent_name: When not ``None``, keep only sessions whose
                ``agent_name`` equals this value.
            limit: Maximum number of sessions returned.
        """
        sessions: list[Session] = []
        for fname in os.listdir(self._data_dir):
            if not fname.endswith(".json"):
                continue
            path = os.path.join(self._data_dir, fname)
            try:
                with open(path, encoding="utf-8") as f:
                    data = json.load(f)
                session = Session.from_dict(data["session"])
            except (ValueError, KeyError, TypeError, OSError):
                # Best effort: one unreadable file must not break the listing.
                continue
            if agent_name is None or session.agent_name == agent_name:
                sessions.append(session)
        sessions.sort(key=lambda s: s.updated_at, reverse=True)
        return sessions[:limit]

    async def append_message(self, message: Message) -> None:
        """Append the message to the session file and bump ``updated_at``.

        When no file exists yet, a stub document holding only the
        ``session_id`` is created to carry the message.
        """
        data = self._read_session_file(message.session_id)
        if data is None:
            data = {"session": {"session_id": message.session_id}, "messages": []}
        data.setdefault("messages", []).append(message.to_dict())
        # Update session timestamp
        if "session" in data:
            data["session"]["updated_at"] = datetime.now(timezone.utc).isoformat()
        self._write_session_file(message.session_id, data)

    async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
        """Return messages from ``offset`` on, at most ``limit`` when given.

        A missing session file yields an empty list.
        """
        data = self._read_session_file(session_id)
        if data is None:
            return []
        msgs = data.get("messages", [])[offset:]
        if limit is not None:
            msgs = msgs[:limit]
        return [Message.from_dict(m) for m in msgs]

    async def count_messages(self, session_id: str) -> int:
        """Return the number of stored messages (0 when no file exists)."""
        data = self._read_session_file(session_id)
        if data is None:
            return 0
        return len(data.get("messages", []))

    async def health_check(self) -> bool:
        """Report whether the storage directory exists."""
        return os.path.isdir(self._data_dir)
def create_session_store(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    ttl_seconds: int = 86400,
    data_dir: str | None = None,
) -> InMemorySessionStore | RedisSessionStore | FileSessionStore:
    """Factory: create a SessionStore backed by memory, file, or Redis.

    - ``memory``: In-memory (lost on restart)
    - ``file``: JSON files in ``~/.agentkit/sessions/`` (persistent, no deps)
    - ``redis``: Redis-backed (production, requires Redis)

    Falls back to InMemorySessionStore if the ``redis`` package cannot be
    imported or the Redis store cannot be constructed. An unrecognised
    ``backend`` value also yields the in-memory store, with a warning logged.

    Args:
        backend: One of ``"memory"``, ``"file"`` or ``"redis"``.
        redis_url: Connection URL, used only by the Redis backend.
        ttl_seconds: Key expiry, used only by the Redis backend.
        data_dir: Storage directory, used only by the file backend.

    Returns:
        The constructed store instance.
    """
    if backend == "file":
        file_store = FileSessionStore(data_dir=data_dir)
        logger.info("SessionStore backend: file (%s)", file_store._data_dir)
        return file_store

    if backend == "redis":
        try:
            # Imported only to verify that the optional dependency is present.
            import redis.asyncio as aioredis  # noqa: F401

            redis_store = RedisSessionStore(redis_url=redis_url, ttl_seconds=ttl_seconds)
        except (ImportError, OSError, _RedisError, ValueError, KeyError, RuntimeError) as exc:
            logger.warning(
                "Failed to initialise RedisSessionStore (%s), falling back to InMemorySessionStore",
                exc,
            )
        else:
            logger.info("SessionStore backend: redis")
            return redis_store
    elif backend != "memory":
        # A typo such as "Redis" would otherwise silently lose persistence.
        logger.warning(
            "Unknown SessionStore backend %r, falling back to InMemorySessionStore",
            backend,
        )

    memory_store = InMemorySessionStore()
    logger.info("SessionStore backend: memory")
    return memory_store