diff --git a/src/agentkit/session/__init__.py b/src/agentkit/session/__init__.py new file mode 100644 index 0000000..b51b601 --- /dev/null +++ b/src/agentkit/session/__init__.py @@ -0,0 +1,16 @@ +"""Session management - multi-turn conversation support for AgentKit.""" + +from agentkit.session.models import Message, MessageRole, Session, SessionStatus +from agentkit.session.manager import SessionManager +from agentkit.session.store import InMemorySessionStore, RedisSessionStore, create_session_store + +__all__ = [ + "Message", + "MessageRole", + "Session", + "SessionStatus", + "SessionManager", + "InMemorySessionStore", + "RedisSessionStore", + "create_session_store", +] diff --git a/src/agentkit/session/manager.py b/src/agentkit/session/manager.py new file mode 100644 index 0000000..207a3a7 --- /dev/null +++ b/src/agentkit/session/manager.py @@ -0,0 +1,160 @@ +"""SessionManager — high-level API for conversation session management.""" + +from __future__ import annotations + +import logging +from typing import Any + +from agentkit.session.models import Message, MessageRole, Session, SessionStatus +from agentkit.session.store import InMemorySessionStore, SessionStore + +logger = logging.getLogger(__name__) + + +class SessionManager: + """Manages conversation sessions and their messages. + + Provides a high-level API for creating, querying, and updating + sessions, as well as appending and retrieving messages. + """ + + def __init__(self, store: SessionStore | None = None): + self._store = store or InMemorySessionStore() + + @property + def store(self) -> SessionStore: + return self._store + + async def create_session( + self, + agent_name: str, + metadata: dict[str, Any] | None = None, + ) -> Session: + """Create a new conversation session bound to an Agent. + + Args: + agent_name: Name of the Agent this session is bound to. + metadata: Optional metadata to attach to the session. + + Returns: + The newly created Session. + """ + session = Session( + session_id=Session.new_session_id(), + agent_name=agent_name, + metadata=metadata or {}, + ) + await self._store.save_session(session) + logger.info(f"Session created: {session.session_id} for agent '{agent_name}'") + return session + + async def get_session(self, session_id: str) -> Session | None: + """Get a session by ID.""" + return await self._store.get_session(session_id) + + async def pause_session(self, session_id: str) -> Session | None: + """Pause an active session.""" + return await self._store.update_session_status(session_id, SessionStatus.PAUSED) + + async def resume_session(self, session_id: str) -> Session | None: + """Resume a paused session.""" + return await self._store.update_session_status(session_id, SessionStatus.ACTIVE) + + async def close_session(self, session_id: str) -> Session | None: + """Close a session. Closed sessions cannot accept new messages.""" + return await self._store.update_session_status(session_id, SessionStatus.CLOSED) + + async def delete_session(self, session_id: str) -> bool: + """Delete a session and all its messages.""" + return await self._store.delete_session(session_id) + + async def list_sessions( + self, + agent_name: str | None = None, + limit: int = 100, + ) -> list[Session]: + """List sessions, optionally filtered by agent name.""" + return await self._store.list_sessions(agent_name=agent_name, limit=limit) + + async def append_message( + self, + session_id: str, + role: MessageRole, + content: str, + tool_call_id: str | None = None, + agent_name: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Message: + """Append a message to a session. + + Args: + session_id: Target session ID. + role: Message role (user/assistant/tool/system). + content: Message content. + tool_call_id: Optional tool call ID for tool messages. + agent_name: Optional agent name for multi-Agent sessions. + metadata: Optional message metadata. + + Returns: + The newly created Message. + + Raises: + ValueError: If the session does not exist or is closed. + """ + session = await self._store.get_session(session_id) + if session is None: + raise ValueError(f"Session '{session_id}' not found") + if session.status == SessionStatus.CLOSED: + raise ValueError(f"Session '{session_id}' is closed and cannot accept new messages") + + message = Message( + message_id=Session.new_message_id(), + session_id=session_id, + role=role, + content=content, + tool_call_id=tool_call_id, + agent_name=agent_name, + metadata=metadata or {}, + ) + await self._store.append_message(message) + + # Update session's updated_at timestamp + session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc) + await self._store.save_session(session) + + return message + + async def get_messages( + self, + session_id: str, + limit: int | None = None, + offset: int = 0, + ) -> list[Message]: + """Get messages for a session with optional pagination. + + Args: + session_id: Target session ID. + limit: Maximum number of messages to return. None for all. + offset: Number of messages to skip from the beginning. + + Returns: + List of messages ordered chronologically. + """ + return await self._store.get_messages(session_id, limit=limit, offset=offset) + + async def get_chat_messages(self, session_id: str) -> list[dict[str, str]]: + """Get messages formatted for LLM chat API consumption. + + Returns messages as OpenAI-compatible dicts suitable for + passing directly to the ReAct engine or LLM Gateway. + """ + messages = await self._store.get_messages(session_id) + return [m.to_chat_message() for m in messages] + + async def count_messages(self, session_id: str) -> int: + """Count messages in a session.""" + return await self._store.count_messages(session_id) + + async def health_check(self) -> bool: + """Check if the underlying store is healthy.""" + return await self._store.health_check() diff --git a/src/agentkit/session/models.py b/src/agentkit/session/models.py new file mode 100644 index 0000000..74a32b1 --- /dev/null +++ b/src/agentkit/session/models.py @@ -0,0 +1,125 @@ +"""Session and Message data models for multi-turn conversations.""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any + + +class SessionStatus(str, Enum): + """Session lifecycle states.""" + + ACTIVE = "active" + PAUSED = "paused" + CLOSED = "closed" + + +class MessageRole(str, Enum): + """Message role — mirrors OpenAI chat message roles.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +@dataclass +class Message: + """A single message within a conversation session. + + Maps directly to the ``messages`` list consumed by the ReAct engine. + """ + + message_id: str + session_id: str + role: MessageRole + content: str + tool_call_id: str | None = None + agent_name: str | None = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + return { + "message_id": self.message_id, + "session_id": self.session_id, + "role": self.role.value, + "content": self.content, + "tool_call_id": self.tool_call_id, + "agent_name": self.agent_name, + "created_at": self.created_at.isoformat(), + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Message: + return cls( + message_id=data["message_id"], + session_id=data["session_id"], + role=MessageRole(data["role"]), + content=data["content"], + tool_call_id=data.get("tool_call_id"), + agent_name=data.get("agent_name"), + created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(timezone.utc), + metadata=data.get("metadata", {}), + ) + + def to_chat_message(self) -> dict[str, str]: + """Convert to OpenAI-compatible chat message dict. + + Returns a dict suitable for the ``messages`` parameter of LLM chat APIs. + """ + msg: dict[str, str] = {"role": self.role.value, "content": self.content} + if self.tool_call_id is not None: + msg["tool_call_id"] = self.tool_call_id + return msg + + +@dataclass +class Session: + """A conversation session binding a user to an Agent. + + Sessions track lifecycle state and accumulate Messages. They are + persisted via :class:`SessionStore` backends. + """ + + session_id: str + agent_name: str + status: SessionStatus = SessionStatus.ACTIVE + metadata: dict[str, Any] = field(default_factory=dict) + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + def to_dict(self) -> dict[str, Any]: + return { + "session_id": self.session_id, + "agent_name": self.agent_name, + "status": self.status.value, + "metadata": self.metadata, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Session: + return cls( + session_id=data["session_id"], + agent_name=data["agent_name"], + status=SessionStatus(data.get("status", "active")), + metadata=data.get("metadata", {}), + created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(timezone.utc), + updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(timezone.utc), + ) + + @staticmethod + def new_session_id() -> str: + """Generate a new session ID.""" + return str(uuid.uuid4()) + + @staticmethod + def new_message_id() -> str: + """Generate a new message ID.""" + return str(uuid.uuid4()) diff --git a/src/agentkit/session/store.py b/src/agentkit/session/store.py new file mode 100644 index 0000000..b16c7f7 --- /dev/null +++ b/src/agentkit/session/store.py @@ -0,0 +1,238 @@ +"""Session store backends — InMemory and Redis.""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Protocol, runtime_checkable + +from agentkit.session.models import Message, Session, SessionStatus + +logger = logging.getLogger(__name__) + + +@runtime_checkable +class SessionStore(Protocol): + """Protocol for session persistence backends.""" + + async def save_session(self, session: Session) -> None: ... + async def get_session(self, session_id: str) -> Session | None: ... + async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None: ... + async def delete_session(self, session_id: str) -> bool: ... + async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]: ... + + async def append_message(self, message: Message) -> None: ... + async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]: ... + async def count_messages(self, session_id: str) -> int: ... + + async def health_check(self) -> bool: ... + + +class InMemorySessionStore: + """In-memory session store for development and testing.""" + + def __init__(self, max_sessions: int = 10000, max_messages_per_session: int = 50000): + self._sessions: dict[str, Session] = {} + self._messages: dict[str, list[Message]] = {} + self._max_sessions = max_sessions + self._max_messages_per_session = max_messages_per_session + + async def save_session(self, session: Session) -> None: + if len(self._sessions) >= self._max_sessions and session.session_id not in self._sessions: + # Evict oldest closed session + closed = [s for s in self._sessions.values() if s.status == SessionStatus.CLOSED] + if closed: + oldest = min(closed, key=lambda s: s.updated_at) + del self._sessions[oldest.session_id] + self._messages.pop(oldest.session_id, None) + else: + raise RuntimeError("SessionStore is full and no closed sessions to evict") + self._sessions[session.session_id] = session + if session.session_id not in self._messages: + self._messages[session.session_id] = [] + + async def get_session(self, session_id: str) -> Session | None: + return self._sessions.get(session_id) + + async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None: + session = self._sessions.get(session_id) + if session is None: + return None + session.status = status + session.updated_at = datetime.now(timezone.utc) + return session + + async def delete_session(self, session_id: str) -> bool: + if session_id in self._sessions: + del self._sessions[session_id] + self._messages.pop(session_id, None) + return True + return False + + async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]: + sessions = list(self._sessions.values()) + if agent_name: + sessions = [s for s in sessions if s.agent_name == agent_name] + sessions.sort(key=lambda s: s.updated_at, reverse=True) + return sessions[:limit] + + async def append_message(self, message: Message) -> None: + msgs = self._messages.setdefault(message.session_id, []) + if len(msgs) >= self._max_messages_per_session: + # Remove oldest messages to stay within limit + excess = len(msgs) - self._max_messages_per_session + 1 + del msgs[:excess] + msgs.append(message) + + async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]: + msgs = self._messages.get(session_id, []) + sliced = msgs[offset:] + if limit is not None: + sliced = sliced[:limit] + return sliced + + async def count_messages(self, session_id: str) -> int: + return len(self._messages.get(session_id, [])) + + async def health_check(self) -> bool: + return True + + +class RedisSessionStore: + """Redis-backed session store for production use. + + Key patterns: + - ``agentkit:session:{session_id}`` — session metadata (JSON + TTL) + - ``agentkit:session:{session_id}:messages`` — message list (Redis list) + """ + + KEY_PREFIX = "agentkit:session:" + MSG_SUFFIX = ":messages" + + def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 86400): + self._redis_url = redis_url + self._ttl_seconds = ttl_seconds + self._redis: Any = None + + async def _get_redis(self): + if self._redis is None: + import redis.asyncio as aioredis + + self._redis = aioredis.from_url(self._redis_url, decode_responses=True) + return self._redis + + def _session_key(self, session_id: str) -> str: + return f"{self.KEY_PREFIX}{session_id}" + + def _messages_key(self, session_id: str) -> str: + return f"{self.KEY_PREFIX}{session_id}{self.MSG_SUFFIX}" + + async def save_session(self, session: Session) -> None: + redis = await self._get_redis() + key = self._session_key(session.session_id) + await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds) + + async def get_session(self, session_id: str) -> Session | None: + redis = await self._get_redis() + raw = await redis.get(self._session_key(session_id)) + if raw is None: + return None + return Session.from_dict(json.loads(raw)) + + async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None: + redis = await self._get_redis() + key = self._session_key(session_id) + raw = await redis.get(key) + if raw is None: + return None + session = Session.from_dict(json.loads(raw)) + session.status = status + session.updated_at = datetime.now(timezone.utc) + await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds) + return session + + async def delete_session(self, session_id: str) -> bool: + redis = await self._get_redis() + keys = [self._session_key(session_id), self._messages_key(session_id)] + deleted = await redis.delete(*keys) + return deleted > 0 + + async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]: + redis = await self._get_redis() + sessions: list[Session] = [] + cursor = 0 + while True: + cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200) + # Filter out message list keys + session_keys = [k for k in keys if not k.endswith(self.MSG_SUFFIX)] + if session_keys: + values = await redis.mget(session_keys) + for raw in values: + if raw is None: + continue + session = Session.from_dict(json.loads(raw)) + if agent_name is None or session.agent_name == agent_name: + sessions.append(session) + if cursor == 0: + break + sessions.sort(key=lambda s: s.updated_at, reverse=True) + return sessions[:limit] + + async def append_message(self, message: Message) -> None: + redis = await self._get_redis() + key = self._messages_key(message.session_id) + await redis.rpush(key, json.dumps(message.to_dict())) + # Set TTL on message list to match session TTL + await redis.expire(key, self._ttl_seconds) + + async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]: + redis = await self._get_redis() + key = self._messages_key(session_id) + # Use LRANGE for offset-based pagination + # Redis list indices: 0-based, -1 = last element + start = offset + if limit is not None: + end = offset + limit - 1 + else: + end = -1 + raw_list = await redis.lrange(key, start, end) + return [Message.from_dict(json.loads(raw)) for raw in raw_list] + + async def count_messages(self, session_id: str) -> int: + redis = await self._get_redis() + return await redis.llen(self._messages_key(session_id)) + + async def health_check(self) -> bool: + try: + redis = await self._get_redis() + return await redis.ping() + except Exception: + return False + + +# Needed for from_dict deserialization +from datetime import datetime, timezone # noqa: E402 + + +def create_session_store( + backend: str = "memory", + redis_url: str = "redis://localhost:6379/0", + ttl_seconds: int = 86400, +) -> InMemorySessionStore | RedisSessionStore: + """Factory: create a SessionStore backed by memory or Redis. + + Falls back to InMemorySessionStore if Redis is unavailable. + """ + if backend == "redis": + try: + import redis.asyncio as aioredis # noqa: F401 + + store = RedisSessionStore(redis_url=redis_url, ttl_seconds=ttl_seconds) + logger.info(f"SessionStore backend: redis") + return store + except Exception as exc: + logger.warning(f"Failed to initialise RedisSessionStore ({exc}), falling back to InMemorySessionStore") + + store = InMemorySessionStore() + logger.info("SessionStore backend: memory") + return store diff --git a/tests/unit/test_session_manager.py b/tests/unit/test_session_manager.py new file mode 100644 index 0000000..d3195a6 --- /dev/null +++ b/tests/unit/test_session_manager.py @@ -0,0 +1,199 @@ +"""Tests for SessionManager.""" + +import pytest + +from agentkit.session.manager import SessionManager +from agentkit.session.models import MessageRole, SessionStatus +from agentkit.session.store import InMemorySessionStore + + +@pytest.fixture +def manager(): + return SessionManager(store=InMemorySessionStore()) + + +class TestSessionManagerCreate: + @pytest.mark.asyncio + async def test_create_session(self, manager): + session = await manager.create_session(agent_name="test-agent") + assert session.session_id is not None + assert session.agent_name == "test-agent" + assert session.status == SessionStatus.ACTIVE + + @pytest.mark.asyncio + async def test_create_session_with_metadata(self, manager): + session = await manager.create_session( + agent_name="agent1", + metadata={"user_id": "u1"}, + ) + assert session.metadata == {"user_id": "u1"} + + +class TestSessionManagerGet: + @pytest.mark.asyncio + async def test_get_existing_session(self, manager): + created = await manager.create_session(agent_name="agent1") + fetched = await manager.get_session(created.session_id) + assert fetched is not None + assert fetched.session_id == created.session_id + + @pytest.mark.asyncio + async def test_get_nonexistent_session(self, manager): + result = await manager.get_session("nonexistent") + assert result is None + + +class TestSessionManagerLifecycle: + @pytest.mark.asyncio + async def test_pause_and_resume(self, manager): + session = await manager.create_session(agent_name="agent1") + paused = await manager.pause_session(session.session_id) + assert paused.status == SessionStatus.PAUSED + + resumed = await manager.resume_session(session.session_id) + assert resumed.status == SessionStatus.ACTIVE + + @pytest.mark.asyncio + async def test_close_session(self, manager): + session = await manager.create_session(agent_name="agent1") + closed = await manager.close_session(session.session_id) + assert closed.status == SessionStatus.CLOSED + + @pytest.mark.asyncio + async def test_close_nonexistent_returns_none(self, manager): + result = await manager.close_session("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_delete_session(self, manager): + session = await manager.create_session(agent_name="agent1") + deleted = await manager.delete_session(session.session_id) + assert deleted is True + assert await manager.get_session(session.session_id) is None + + @pytest.mark.asyncio + async def test_delete_nonexistent_returns_false(self, manager): + deleted = await manager.delete_session("nonexistent") + assert deleted is False + + +class TestSessionManagerMessages: + @pytest.mark.asyncio + async def test_append_user_message(self, manager): + session = await manager.create_session(agent_name="agent1") + msg = await manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content="Hello", + ) + assert msg.role == MessageRole.USER + assert msg.content == "Hello" + assert msg.session_id == session.session_id + + @pytest.mark.asyncio + async def test_append_assistant_message(self, manager): + session = await manager.create_session(agent_name="agent1") + msg = await manager.append_message( + session_id=session.session_id, + role=MessageRole.ASSISTANT, + content="Hi there!", + ) + assert msg.role == MessageRole.ASSISTANT + + @pytest.mark.asyncio + async def test_get_messages(self, manager): + session = await manager.create_session(agent_name="agent1") + await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello") + await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!") + + messages = await manager.get_messages(session.session_id) + assert len(messages) == 2 + assert messages[0].content == "Hello" + assert messages[1].content == "Hi!" + + @pytest.mark.asyncio + async def test_get_messages_pagination(self, manager): + session = await manager.create_session(agent_name="agent1") + for i in range(10): + await manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content=f"Message {i}", + ) + + # Get first 3 messages + page1 = await manager.get_messages(session.session_id, limit=3, offset=0) + assert len(page1) == 3 + assert page1[0].content == "Message 0" + + # Get next 3 messages + page2 = await manager.get_messages(session.session_id, limit=3, offset=3) + assert len(page2) == 3 + assert page2[0].content == "Message 3" + + @pytest.mark.asyncio + async def test_count_messages(self, manager): + session = await manager.create_session(agent_name="agent1") + await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello") + await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!") + + count = await manager.count_messages(session.session_id) + assert count == 2 + + @pytest.mark.asyncio + async def test_closed_session_rejects_messages(self, manager): + session = await manager.create_session(agent_name="agent1") + await manager.close_session(session.session_id) + + with pytest.raises(ValueError, match="closed"): + await manager.append_message( + session_id=session.session_id, + role=MessageRole.USER, + content="Should fail", + ) + + @pytest.mark.asyncio + async def test_nonexistent_session_rejects_messages(self, manager): + with pytest.raises(ValueError, match="not found"): + await manager.append_message( + session_id="nonexistent", + role=MessageRole.USER, + content="Should fail", + ) + + @pytest.mark.asyncio + async def test_get_chat_messages(self, manager): + session = await manager.create_session(agent_name="agent1") + await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello") + await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!") + + chat_msgs = await manager.get_chat_messages(session.session_id) + assert len(chat_msgs) == 2 + assert chat_msgs[0] == {"role": "user", "content": "Hello"} + assert chat_msgs[1] == {"role": "assistant", "content": "Hi!"} + + +class TestSessionManagerList: + @pytest.mark.asyncio + async def test_list_sessions(self, manager): + await manager.create_session(agent_name="agent1") + await manager.create_session(agent_name="agent2") + + sessions = await manager.list_sessions() + assert len(sessions) == 2 + + @pytest.mark.asyncio + async def test_list_sessions_by_agent(self, manager): + await manager.create_session(agent_name="agent1") + await manager.create_session(agent_name="agent2") + await manager.create_session(agent_name="agent1") + + sessions = await manager.list_sessions(agent_name="agent1") + assert len(sessions) == 2 + assert all(s.agent_name == "agent1" for s in sessions) + + +class TestSessionManagerHealth: + @pytest.mark.asyncio + async def test_health_check(self, manager): + assert await manager.health_check() is True diff --git a/tests/unit/test_session_models.py b/tests/unit/test_session_models.py new file mode 100644 index 0000000..b386566 --- /dev/null +++ b/tests/unit/test_session_models.py @@ -0,0 +1,146 @@ +"""Tests for Session and Message data models.""" + +import pytest + +from agentkit.session.models import Message, MessageRole, Session, SessionStatus + + +class TestSessionStatus: + def test_status_values(self): + assert SessionStatus.ACTIVE == "active" + assert SessionStatus.PAUSED == "paused" + assert SessionStatus.CLOSED == "closed" + + def test_status_from_string(self): + assert SessionStatus("active") == SessionStatus.ACTIVE + assert SessionStatus("paused") == SessionStatus.PAUSED + assert SessionStatus("closed") == SessionStatus.CLOSED + + +class TestMessageRole: + def test_role_values(self): + assert MessageRole.SYSTEM == "system" + assert MessageRole.USER == "user" + assert MessageRole.ASSISTANT == "assistant" + assert MessageRole.TOOL == "tool" + + +class TestSession: + def test_create_session(self): + session = Session(session_id="s1", agent_name="test-agent") + assert session.session_id == "s1" + assert session.agent_name == "test-agent" + assert session.status == SessionStatus.ACTIVE + assert session.metadata == {} + assert session.created_at is not None + assert session.updated_at is not None + + def test_session_to_dict_and_back(self): + session = Session( + session_id="s1", + agent_name="agent1", + status=SessionStatus.PAUSED, + metadata={"key": "value"}, + ) + d = session.to_dict() + assert d["session_id"] == "s1" + assert d["agent_name"] == "agent1" + assert d["status"] == "paused" + assert d["metadata"] == {"key": "value"} + + restored = Session.from_dict(d) + assert restored.session_id == session.session_id + assert restored.agent_name == session.agent_name + assert restored.status == session.status + assert restored.metadata == session.metadata + + def test_new_session_id_is_unique(self): + ids = {Session.new_session_id() for _ in range(100)} + assert len(ids) == 100 + + def test_new_message_id_is_unique(self): + ids = {Session.new_message_id() for _ in range(100)} + assert len(ids) == 100 + + +class TestMessage: + def test_create_message(self): + msg = Message( + message_id="m1", + session_id="s1", + role=MessageRole.USER, + content="Hello", + ) + assert msg.message_id == "m1" + assert msg.session_id == "s1" + assert msg.role == MessageRole.USER + assert msg.content == "Hello" + assert msg.tool_call_id is None + assert msg.agent_name is None + assert msg.metadata == {} + + def test_message_with_tool_call(self): + msg = Message( + message_id="m1", + session_id="s1", + role=MessageRole.TOOL, + content="result", + tool_call_id="tc1", + agent_name="agent1", + ) + assert msg.tool_call_id == "tc1" + assert msg.agent_name == "agent1" + + def test_message_to_dict_and_back(self): + msg = Message( + message_id="m1", + session_id="s1", + role=MessageRole.ASSISTANT, + content="Hi there", + tool_call_id="tc1", + agent_name="agent1", + metadata={"step": 1}, + ) + d = msg.to_dict() + assert d["message_id"] == "m1" + assert d["role"] == "assistant" + assert d["tool_call_id"] == "tc1" + + restored = Message.from_dict(d) + assert restored.message_id == msg.message_id + assert restored.role == msg.role + assert restored.content == msg.content + assert restored.tool_call_id == msg.tool_call_id + assert restored.agent_name == msg.agent_name + assert restored.metadata == msg.metadata + + def test_to_chat_message_user(self): + msg = Message( + message_id="m1", + session_id="s1", + role=MessageRole.USER, + content="Hello", + ) + chat_msg = msg.to_chat_message() + assert chat_msg == {"role": "user", "content": "Hello"} + + def test_to_chat_message_tool(self): + msg = Message( + message_id="m1", + session_id="s1", + role=MessageRole.TOOL, + content="result", + tool_call_id="tc1", + ) + chat_msg = msg.to_chat_message() + assert chat_msg == {"role": "tool", "content": "result", "tool_call_id": "tc1"} + + def test_to_chat_message_no_tool_call_id(self): + msg = Message( + message_id="m1", + session_id="s1", + role=MessageRole.ASSISTANT, + content="Hi", + ) + chat_msg = msg.to_chat_message() + assert "tool_call_id" not in chat_msg diff --git a/tests/unit/test_session_store.py b/tests/unit/test_session_store.py new file mode 100644 index 0000000..0d224db --- /dev/null +++ b/tests/unit/test_session_store.py @@ -0,0 +1,157 @@ +"""Tests for InMemorySessionStore.""" + +import pytest + +from agentkit.session.models import Message, MessageRole, Session, SessionStatus +from agentkit.session.store import InMemorySessionStore + + +@pytest.fixture +def store(): + return InMemorySessionStore() + + +async def _create_session(store, session_id="s1", agent_name="agent1"): + session = Session(session_id=session_id, agent_name=agent_name) + await store.save_session(session) + return session + + +async def _create_message(store, session_id, role=MessageRole.USER, content="Hello"): + msg = Message( + message_id=Session.new_message_id(), + session_id=session_id, + role=role, + content=content, + ) + await store.append_message(msg) + return msg + + +class TestInMemorySessionStoreCRUD: + @pytest.mark.asyncio + async def test_save_and_get(self, store): + session = await _create_session(store) + fetched = await store.get_session("s1") + assert fetched is not None + assert fetched.session_id == "s1" + + @pytest.mark.asyncio + async def test_get_nonexistent(self, store): + result = await store.get_session("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_update_status(self, store): + await _create_session(store) + updated = await store.update_session_status("s1", SessionStatus.PAUSED) + assert updated is not None + assert updated.status == SessionStatus.PAUSED + + @pytest.mark.asyncio + async def test_update_status_nonexistent(self, store): + result = await store.update_session_status("nonexistent", SessionStatus.PAUSED) + assert result is None + + @pytest.mark.asyncio + async def test_delete(self, store): + await _create_session(store) + assert await store.delete_session("s1") is True + assert await store.get_session("s1") is None + + @pytest.mark.asyncio + async def test_delete_nonexistent(self, store): + assert await store.delete_session("nonexistent") is False + + @pytest.mark.asyncio + async def test_list_sessions(self, store): + await _create_session(store, "s1", "agent1") + await _create_session(store, "s2", "agent2") + sessions = await store.list_sessions() + assert len(sessions) == 2 + + @pytest.mark.asyncio + async def test_list_sessions_by_agent(self, store): + await _create_session(store, "s1", "agent1") + await _create_session(store, "s2", "agent2") + sessions = await store.list_sessions(agent_name="agent1") + assert len(sessions) == 1 + assert sessions[0].agent_name == "agent1" + + +class TestInMemorySessionStoreMessages: + @pytest.mark.asyncio + async def test_append_and_get(self, store): + await _create_session(store) + await _create_message(store, "s1", content="Hello") + await _create_message(store, "s1", content="World") + + messages = await store.get_messages("s1") + assert len(messages) == 2 + assert messages[0].content == "Hello" + assert messages[1].content == "World" + + @pytest.mark.asyncio + async def test_get_messages_pagination(self, store): + await _create_session(store) + for i in range(5): + await _create_message(store, "s1", content=f"Msg {i}") + + page = await store.get_messages("s1", limit=2, offset=1) + assert len(page) == 2 + assert page[0].content == "Msg 1" + assert page[1].content == "Msg 2" + + @pytest.mark.asyncio + async def test_count_messages(self, store): + await _create_session(store) + await _create_message(store, "s1") + await _create_message(store, "s1") + assert await store.count_messages("s1") == 2 + + @pytest.mark.asyncio + async def test_count_messages_empty_session(self, store): + assert await store.count_messages("nonexistent") == 0 + + @pytest.mark.asyncio + async def test_get_messages_empty_session(self, store): + messages = await store.get_messages("nonexistent") + assert messages == [] + + @pytest.mark.asyncio + async def test_delete_session_removes_messages(self, store): + await _create_session(store) + await _create_message(store, "s1") + await store.delete_session("s1") + assert await store.count_messages("s1") == 0 + + +class TestInMemorySessionStoreEviction: + @pytest.mark.asyncio + async def test_evict_closed_session_on_full(self): + store = InMemorySessionStore(max_sessions=2) + s1 = await _create_session(store, "s1") + await _create_session(store, "s2") + + # Close s1 so it can be evicted + await store.update_session_status("s1", SessionStatus.CLOSED) + + # Creating a third session should evict s1 + await _create_session(store, "s3") + assert await store.get_session("s1") is None + assert await store.get_session("s2") is not None + assert await store.get_session("s3") is not None + + @pytest.mark.asyncio + async def test_full_no_closed_raises(self): + store = InMemorySessionStore(max_sessions=2) + await _create_session(store, "s1") + await _create_session(store, "s2") + with pytest.raises(RuntimeError, match="full"): + await _create_session(store, "s3") + + +class TestInMemorySessionStoreHealth: + @pytest.mark.asyncio + async def test_health_check(self, store): + assert await store.health_check() is True