feat(session): add Session/Message models and SessionManager with InMemory/Redis stores
This commit is contained in:
parent
e4d6efb4bf
commit
493187782c
|
|
@ -0,0 +1,16 @@
|
|||
"""Session management - multi-turn conversation support for AgentKit."""
|
||||
|
||||
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
||||
from agentkit.session.manager import SessionManager
|
||||
from agentkit.session.store import InMemorySessionStore, RedisSessionStore, create_session_store
|
||||
|
||||
__all__ = [
|
||||
"Message",
|
||||
"MessageRole",
|
||||
"Session",
|
||||
"SessionStatus",
|
||||
"SessionManager",
|
||||
"InMemorySessionStore",
|
||||
"RedisSessionStore",
|
||||
"create_session_store",
|
||||
]
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
"""SessionManager — high-level API for conversation session management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
||||
from agentkit.session.store import InMemorySessionStore, SessionStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manages conversation sessions and their messages.
|
||||
|
||||
Provides a high-level API for creating, querying, and updating
|
||||
sessions, as well as appending and retrieving messages.
|
||||
"""
|
||||
|
||||
def __init__(self, store: SessionStore | None = None):
|
||||
self._store = store or InMemorySessionStore()
|
||||
|
||||
@property
|
||||
def store(self) -> SessionStore:
|
||||
return self._store
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
agent_name: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> Session:
|
||||
"""Create a new conversation session bound to an Agent.
|
||||
|
||||
Args:
|
||||
agent_name: Name of the Agent this session is bound to.
|
||||
metadata: Optional metadata to attach to the session.
|
||||
|
||||
Returns:
|
||||
The newly created Session.
|
||||
"""
|
||||
session = Session(
|
||||
session_id=Session.new_session_id(),
|
||||
agent_name=agent_name,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
await self._store.save_session(session)
|
||||
logger.info(f"Session created: {session.session_id} for agent '{agent_name}'")
|
||||
return session
|
||||
|
||||
async def get_session(self, session_id: str) -> Session | None:
|
||||
"""Get a session by ID."""
|
||||
return await self._store.get_session(session_id)
|
||||
|
||||
async def pause_session(self, session_id: str) -> Session | None:
|
||||
"""Pause an active session."""
|
||||
return await self._store.update_session_status(session_id, SessionStatus.PAUSED)
|
||||
|
||||
async def resume_session(self, session_id: str) -> Session | None:
|
||||
"""Resume a paused session."""
|
||||
return await self._store.update_session_status(session_id, SessionStatus.ACTIVE)
|
||||
|
||||
async def close_session(self, session_id: str) -> Session | None:
|
||||
"""Close a session. Closed sessions cannot accept new messages."""
|
||||
return await self._store.update_session_status(session_id, SessionStatus.CLOSED)
|
||||
|
||||
async def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session and all its messages."""
|
||||
return await self._store.delete_session(session_id)
|
||||
|
||||
async def list_sessions(
|
||||
self,
|
||||
agent_name: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[Session]:
|
||||
"""List sessions, optionally filtered by agent name."""
|
||||
return await self._store.list_sessions(agent_name=agent_name, limit=limit)
|
||||
|
||||
async def append_message(
|
||||
self,
|
||||
session_id: str,
|
||||
role: MessageRole,
|
||||
content: str,
|
||||
tool_call_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> Message:
|
||||
"""Append a message to a session.
|
||||
|
||||
Args:
|
||||
session_id: Target session ID.
|
||||
role: Message role (user/assistant/tool/system).
|
||||
content: Message content.
|
||||
tool_call_id: Optional tool call ID for tool messages.
|
||||
agent_name: Optional agent name for multi-Agent sessions.
|
||||
metadata: Optional message metadata.
|
||||
|
||||
Returns:
|
||||
The newly created Message.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session does not exist or is closed.
|
||||
"""
|
||||
session = await self._store.get_session(session_id)
|
||||
if session is None:
|
||||
raise ValueError(f"Session '{session_id}' not found")
|
||||
if session.status == SessionStatus.CLOSED:
|
||||
raise ValueError(f"Session '{session_id}' is closed and cannot accept new messages")
|
||||
|
||||
message = Message(
|
||||
message_id=Session.new_message_id(),
|
||||
session_id=session_id,
|
||||
role=role,
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
agent_name=agent_name,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
await self._store.append_message(message)
|
||||
|
||||
# Update session's updated_at timestamp
|
||||
session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc)
|
||||
await self._store.save_session(session)
|
||||
|
||||
return message
|
||||
|
||||
async def get_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
limit: int | None = None,
|
||||
offset: int = 0,
|
||||
) -> list[Message]:
|
||||
"""Get messages for a session with optional pagination.
|
||||
|
||||
Args:
|
||||
session_id: Target session ID.
|
||||
limit: Maximum number of messages to return. None for all.
|
||||
offset: Number of messages to skip from the beginning.
|
||||
|
||||
Returns:
|
||||
List of messages ordered chronologically.
|
||||
"""
|
||||
return await self._store.get_messages(session_id, limit=limit, offset=offset)
|
||||
|
||||
async def get_chat_messages(self, session_id: str) -> list[dict[str, str]]:
|
||||
"""Get messages formatted for LLM chat API consumption.
|
||||
|
||||
Returns messages as OpenAI-compatible dicts suitable for
|
||||
passing directly to the ReAct engine or LLM Gateway.
|
||||
"""
|
||||
messages = await self._store.get_messages(session_id)
|
||||
return [m.to_chat_message() for m in messages]
|
||||
|
||||
async def count_messages(self, session_id: str) -> int:
|
||||
"""Count messages in a session."""
|
||||
return await self._store.count_messages(session_id)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Check if the underlying store is healthy."""
|
||||
return await self._store.health_check()
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
"""Session and Message data models for multi-turn conversations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SessionStatus(str, Enum):
|
||||
"""Session lifecycle states."""
|
||||
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class MessageRole(str, Enum):
|
||||
"""Message role — mirrors OpenAI chat message roles."""
|
||||
|
||||
SYSTEM = "system"
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""A single message within a conversation session.
|
||||
|
||||
Maps directly to the ``messages`` list consumed by the ReAct engine.
|
||||
"""
|
||||
|
||||
message_id: str
|
||||
session_id: str
|
||||
role: MessageRole
|
||||
content: str
|
||||
tool_call_id: str | None = None
|
||||
agent_name: str | None = None
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"message_id": self.message_id,
|
||||
"session_id": self.session_id,
|
||||
"role": self.role.value,
|
||||
"content": self.content,
|
||||
"tool_call_id": self.tool_call_id,
|
||||
"agent_name": self.agent_name,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Message:
|
||||
return cls(
|
||||
message_id=data["message_id"],
|
||||
session_id=data["session_id"],
|
||||
role=MessageRole(data["role"]),
|
||||
content=data["content"],
|
||||
tool_call_id=data.get("tool_call_id"),
|
||||
agent_name=data.get("agent_name"),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(timezone.utc),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
def to_chat_message(self) -> dict[str, str]:
|
||||
"""Convert to OpenAI-compatible chat message dict.
|
||||
|
||||
Returns a dict suitable for the ``messages`` parameter of LLM chat APIs.
|
||||
"""
|
||||
msg: dict[str, str] = {"role": self.role.value, "content": self.content}
|
||||
if self.tool_call_id is not None:
|
||||
msg["tool_call_id"] = self.tool_call_id
|
||||
return msg
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""A conversation session binding a user to an Agent.
|
||||
|
||||
Sessions track lifecycle state and accumulate Messages. They are
|
||||
persisted via :class:`SessionStore` backends.
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
agent_name: str
|
||||
status: SessionStatus = SessionStatus.ACTIVE
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"agent_name": self.agent_name,
|
||||
"status": self.status.value,
|
||||
"metadata": self.metadata,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Session:
|
||||
return cls(
|
||||
session_id=data["session_id"],
|
||||
agent_name=data["agent_name"],
|
||||
status=SessionStatus(data.get("status", "active")),
|
||||
metadata=data.get("metadata", {}),
|
||||
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(timezone.utc),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def new_session_id() -> str:
|
||||
"""Generate a new session ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
@staticmethod
|
||||
def new_message_id() -> str:
|
||||
"""Generate a new message ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
|
@ -0,0 +1,238 @@
|
|||
"""Session store backends — InMemory and Redis."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
from agentkit.session.models import Message, Session, SessionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SessionStore(Protocol):
|
||||
"""Protocol for session persistence backends."""
|
||||
|
||||
async def save_session(self, session: Session) -> None: ...
|
||||
async def get_session(self, session_id: str) -> Session | None: ...
|
||||
async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None: ...
|
||||
async def delete_session(self, session_id: str) -> bool: ...
|
||||
async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]: ...
|
||||
|
||||
async def append_message(self, message: Message) -> None: ...
|
||||
async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]: ...
|
||||
async def count_messages(self, session_id: str) -> int: ...
|
||||
|
||||
async def health_check(self) -> bool: ...
|
||||
|
||||
|
||||
class InMemorySessionStore:
|
||||
"""In-memory session store for development and testing."""
|
||||
|
||||
def __init__(self, max_sessions: int = 10000, max_messages_per_session: int = 50000):
|
||||
self._sessions: dict[str, Session] = {}
|
||||
self._messages: dict[str, list[Message]] = {}
|
||||
self._max_sessions = max_sessions
|
||||
self._max_messages_per_session = max_messages_per_session
|
||||
|
||||
async def save_session(self, session: Session) -> None:
|
||||
if len(self._sessions) >= self._max_sessions and session.session_id not in self._sessions:
|
||||
# Evict oldest closed session
|
||||
closed = [s for s in self._sessions.values() if s.status == SessionStatus.CLOSED]
|
||||
if closed:
|
||||
oldest = min(closed, key=lambda s: s.updated_at)
|
||||
del self._sessions[oldest.session_id]
|
||||
self._messages.pop(oldest.session_id, None)
|
||||
else:
|
||||
raise RuntimeError("SessionStore is full and no closed sessions to evict")
|
||||
self._sessions[session.session_id] = session
|
||||
if session.session_id not in self._messages:
|
||||
self._messages[session.session_id] = []
|
||||
|
||||
async def get_session(self, session_id: str) -> Session | None:
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
|
||||
session = self._sessions.get(session_id)
|
||||
if session is None:
|
||||
return None
|
||||
session.status = status
|
||||
session.updated_at = datetime.now(timezone.utc)
|
||||
return session
|
||||
|
||||
async def delete_session(self, session_id: str) -> bool:
|
||||
if session_id in self._sessions:
|
||||
del self._sessions[session_id]
|
||||
self._messages.pop(session_id, None)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
|
||||
sessions = list(self._sessions.values())
|
||||
if agent_name:
|
||||
sessions = [s for s in sessions if s.agent_name == agent_name]
|
||||
sessions.sort(key=lambda s: s.updated_at, reverse=True)
|
||||
return sessions[:limit]
|
||||
|
||||
async def append_message(self, message: Message) -> None:
|
||||
msgs = self._messages.setdefault(message.session_id, [])
|
||||
if len(msgs) >= self._max_messages_per_session:
|
||||
# Remove oldest messages to stay within limit
|
||||
excess = len(msgs) - self._max_messages_per_session + 1
|
||||
del msgs[:excess]
|
||||
msgs.append(message)
|
||||
|
||||
async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
|
||||
msgs = self._messages.get(session_id, [])
|
||||
sliced = msgs[offset:]
|
||||
if limit is not None:
|
||||
sliced = sliced[:limit]
|
||||
return sliced
|
||||
|
||||
async def count_messages(self, session_id: str) -> int:
|
||||
return len(self._messages.get(session_id, []))
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis-backed session store for production use.
|
||||
|
||||
Key patterns:
|
||||
- ``agentkit:session:{session_id}`` — session metadata (JSON + TTL)
|
||||
- ``agentkit:session:{session_id}:messages`` — message list (Redis list)
|
||||
"""
|
||||
|
||||
KEY_PREFIX = "agentkit:session:"
|
||||
MSG_SUFFIX = ":messages"
|
||||
|
||||
def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 86400):
|
||||
self._redis_url = redis_url
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._redis: Any = None
|
||||
|
||||
async def _get_redis(self):
|
||||
if self._redis is None:
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
|
||||
return self._redis
|
||||
|
||||
def _session_key(self, session_id: str) -> str:
|
||||
return f"{self.KEY_PREFIX}{session_id}"
|
||||
|
||||
def _messages_key(self, session_id: str) -> str:
|
||||
return f"{self.KEY_PREFIX}{session_id}{self.MSG_SUFFIX}"
|
||||
|
||||
async def save_session(self, session: Session) -> None:
|
||||
redis = await self._get_redis()
|
||||
key = self._session_key(session.session_id)
|
||||
await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)
|
||||
|
||||
async def get_session(self, session_id: str) -> Session | None:
|
||||
redis = await self._get_redis()
|
||||
raw = await redis.get(self._session_key(session_id))
|
||||
if raw is None:
|
||||
return None
|
||||
return Session.from_dict(json.loads(raw))
|
||||
|
||||
async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
|
||||
redis = await self._get_redis()
|
||||
key = self._session_key(session_id)
|
||||
raw = await redis.get(key)
|
||||
if raw is None:
|
||||
return None
|
||||
session = Session.from_dict(json.loads(raw))
|
||||
session.status = status
|
||||
session.updated_at = datetime.now(timezone.utc)
|
||||
await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)
|
||||
return session
|
||||
|
||||
async def delete_session(self, session_id: str) -> bool:
|
||||
redis = await self._get_redis()
|
||||
keys = [self._session_key(session_id), self._messages_key(session_id)]
|
||||
deleted = await redis.delete(*keys)
|
||||
return deleted > 0
|
||||
|
||||
async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
|
||||
redis = await self._get_redis()
|
||||
sessions: list[Session] = []
|
||||
cursor = 0
|
||||
while True:
|
||||
cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
|
||||
# Filter out message list keys
|
||||
session_keys = [k for k in keys if not k.endswith(self.MSG_SUFFIX)]
|
||||
if session_keys:
|
||||
values = await redis.mget(session_keys)
|
||||
for raw in values:
|
||||
if raw is None:
|
||||
continue
|
||||
session = Session.from_dict(json.loads(raw))
|
||||
if agent_name is None or session.agent_name == agent_name:
|
||||
sessions.append(session)
|
||||
if cursor == 0:
|
||||
break
|
||||
sessions.sort(key=lambda s: s.updated_at, reverse=True)
|
||||
return sessions[:limit]
|
||||
|
||||
async def append_message(self, message: Message) -> None:
|
||||
redis = await self._get_redis()
|
||||
key = self._messages_key(message.session_id)
|
||||
await redis.rpush(key, json.dumps(message.to_dict()))
|
||||
# Set TTL on message list to match session TTL
|
||||
await redis.expire(key, self._ttl_seconds)
|
||||
|
||||
async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
|
||||
redis = await self._get_redis()
|
||||
key = self._messages_key(session_id)
|
||||
# Use LRANGE for offset-based pagination
|
||||
# Redis list indices: 0-based, -1 = last element
|
||||
start = offset
|
||||
if limit is not None:
|
||||
end = offset + limit - 1
|
||||
else:
|
||||
end = -1
|
||||
raw_list = await redis.lrange(key, start, end)
|
||||
return [Message.from_dict(json.loads(raw)) for raw in raw_list]
|
||||
|
||||
async def count_messages(self, session_id: str) -> int:
|
||||
redis = await self._get_redis()
|
||||
return await redis.llen(self._messages_key(session_id))
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await redis.ping()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Needed for from_dict deserialization
|
||||
from datetime import datetime, timezone # noqa: E402
|
||||
|
||||
|
||||
def create_session_store(
|
||||
backend: str = "memory",
|
||||
redis_url: str = "redis://localhost:6379/0",
|
||||
ttl_seconds: int = 86400,
|
||||
) -> InMemorySessionStore | RedisSessionStore:
|
||||
"""Factory: create a SessionStore backed by memory or Redis.
|
||||
|
||||
Falls back to InMemorySessionStore if Redis is unavailable.
|
||||
"""
|
||||
if backend == "redis":
|
||||
try:
|
||||
import redis.asyncio as aioredis # noqa: F401
|
||||
|
||||
store = RedisSessionStore(redis_url=redis_url, ttl_seconds=ttl_seconds)
|
||||
logger.info(f"SessionStore backend: redis")
|
||||
return store
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to initialise RedisSessionStore ({exc}), falling back to InMemorySessionStore")
|
||||
|
||||
store = InMemorySessionStore()
|
||||
logger.info("SessionStore backend: memory")
|
||||
return store
|
||||
|
|
@ -0,0 +1,199 @@
|
|||
"""Tests for SessionManager."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.session.manager import SessionManager
|
||||
from agentkit.session.models import MessageRole, SessionStatus
|
||||
from agentkit.session.store import InMemorySessionStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager():
|
||||
return SessionManager(store=InMemorySessionStore())
|
||||
|
||||
|
||||
class TestSessionManagerCreate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(self, manager):
|
||||
session = await manager.create_session(agent_name="test-agent")
|
||||
assert session.session_id is not None
|
||||
assert session.agent_name == "test-agent"
|
||||
assert session.status == SessionStatus.ACTIVE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session_with_metadata(self, manager):
|
||||
session = await manager.create_session(
|
||||
agent_name="agent1",
|
||||
metadata={"user_id": "u1"},
|
||||
)
|
||||
assert session.metadata == {"user_id": "u1"}
|
||||
|
||||
|
||||
class TestSessionManagerGet:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_existing_session(self, manager):
|
||||
created = await manager.create_session(agent_name="agent1")
|
||||
fetched = await manager.get_session(created.session_id)
|
||||
assert fetched is not None
|
||||
assert fetched.session_id == created.session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_session(self, manager):
|
||||
result = await manager.get_session("nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSessionManagerLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_and_resume(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
paused = await manager.pause_session(session.session_id)
|
||||
assert paused.status == SessionStatus.PAUSED
|
||||
|
||||
resumed = await manager.resume_session(session.session_id)
|
||||
assert resumed.status == SessionStatus.ACTIVE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_session(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
closed = await manager.close_session(session.session_id)
|
||||
assert closed.status == SessionStatus.CLOSED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_nonexistent_returns_none(self, manager):
|
||||
result = await manager.close_session("nonexistent")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_session(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
deleted = await manager.delete_session(session.session_id)
|
||||
assert deleted is True
|
||||
assert await manager.get_session(session.session_id) is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_returns_false(self, manager):
|
||||
deleted = await manager.delete_session("nonexistent")
|
||||
assert deleted is False
|
||||
|
||||
|
||||
class TestSessionManagerMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_user_message(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
msg = await manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content="Hello",
|
||||
)
|
||||
assert msg.role == MessageRole.USER
|
||||
assert msg.content == "Hello"
|
||||
assert msg.session_id == session.session_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_assistant_message(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
msg = await manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="Hi there!",
|
||||
)
|
||||
assert msg.role == MessageRole.ASSISTANT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
|
||||
await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")
|
||||
|
||||
messages = await manager.get_messages(session.session_id)
|
||||
assert len(messages) == 2
|
||||
assert messages[0].content == "Hello"
|
||||
assert messages[1].content == "Hi!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_pagination(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
for i in range(10):
|
||||
await manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content=f"Message {i}",
|
||||
)
|
||||
|
||||
# Get first 3 messages
|
||||
page1 = await manager.get_messages(session.session_id, limit=3, offset=0)
|
||||
assert len(page1) == 3
|
||||
assert page1[0].content == "Message 0"
|
||||
|
||||
# Get next 3 messages
|
||||
page2 = await manager.get_messages(session.session_id, limit=3, offset=3)
|
||||
assert len(page2) == 3
|
||||
assert page2[0].content == "Message 3"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_messages(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
|
||||
await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")
|
||||
|
||||
count = await manager.count_messages(session.session_id)
|
||||
assert count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_closed_session_rejects_messages(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
await manager.close_session(session.session_id)
|
||||
|
||||
with pytest.raises(ValueError, match="closed"):
|
||||
await manager.append_message(
|
||||
session_id=session.session_id,
|
||||
role=MessageRole.USER,
|
||||
content="Should fail",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonexistent_session_rejects_messages(self, manager):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
await manager.append_message(
|
||||
session_id="nonexistent",
|
||||
role=MessageRole.USER,
|
||||
content="Should fail",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_chat_messages(self, manager):
|
||||
session = await manager.create_session(agent_name="agent1")
|
||||
await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
|
||||
await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")
|
||||
|
||||
chat_msgs = await manager.get_chat_messages(session.session_id)
|
||||
assert len(chat_msgs) == 2
|
||||
assert chat_msgs[0] == {"role": "user", "content": "Hello"}
|
||||
assert chat_msgs[1] == {"role": "assistant", "content": "Hi!"}
|
||||
|
||||
|
||||
class TestSessionManagerList:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions(self, manager):
|
||||
await manager.create_session(agent_name="agent1")
|
||||
await manager.create_session(agent_name="agent2")
|
||||
|
||||
sessions = await manager.list_sessions()
|
||||
assert len(sessions) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_by_agent(self, manager):
|
||||
await manager.create_session(agent_name="agent1")
|
||||
await manager.create_session(agent_name="agent2")
|
||||
await manager.create_session(agent_name="agent1")
|
||||
|
||||
sessions = await manager.list_sessions(agent_name="agent1")
|
||||
assert len(sessions) == 2
|
||||
assert all(s.agent_name == "agent1" for s in sessions)
|
||||
|
||||
|
||||
class TestSessionManagerHealth:
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, manager):
|
||||
assert await manager.health_check() is True
|
||||
|
|
@ -0,0 +1,146 @@
|
|||
"""Tests for Session and Message data models."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
||||
|
||||
|
||||
class TestSessionStatus:
|
||||
def test_status_values(self):
|
||||
assert SessionStatus.ACTIVE == "active"
|
||||
assert SessionStatus.PAUSED == "paused"
|
||||
assert SessionStatus.CLOSED == "closed"
|
||||
|
||||
def test_status_from_string(self):
|
||||
assert SessionStatus("active") == SessionStatus.ACTIVE
|
||||
assert SessionStatus("paused") == SessionStatus.PAUSED
|
||||
assert SessionStatus("closed") == SessionStatus.CLOSED
|
||||
|
||||
|
||||
class TestMessageRole:
|
||||
def test_role_values(self):
|
||||
assert MessageRole.SYSTEM == "system"
|
||||
assert MessageRole.USER == "user"
|
||||
assert MessageRole.ASSISTANT == "assistant"
|
||||
assert MessageRole.TOOL == "tool"
|
||||
|
||||
|
||||
class TestSession:
|
||||
def test_create_session(self):
|
||||
session = Session(session_id="s1", agent_name="test-agent")
|
||||
assert session.session_id == "s1"
|
||||
assert session.agent_name == "test-agent"
|
||||
assert session.status == SessionStatus.ACTIVE
|
||||
assert session.metadata == {}
|
||||
assert session.created_at is not None
|
||||
assert session.updated_at is not None
|
||||
|
||||
def test_session_to_dict_and_back(self):
|
||||
session = Session(
|
||||
session_id="s1",
|
||||
agent_name="agent1",
|
||||
status=SessionStatus.PAUSED,
|
||||
metadata={"key": "value"},
|
||||
)
|
||||
d = session.to_dict()
|
||||
assert d["session_id"] == "s1"
|
||||
assert d["agent_name"] == "agent1"
|
||||
assert d["status"] == "paused"
|
||||
assert d["metadata"] == {"key": "value"}
|
||||
|
||||
restored = Session.from_dict(d)
|
||||
assert restored.session_id == session.session_id
|
||||
assert restored.agent_name == session.agent_name
|
||||
assert restored.status == session.status
|
||||
assert restored.metadata == session.metadata
|
||||
|
||||
def test_new_session_id_is_unique(self):
|
||||
ids = {Session.new_session_id() for _ in range(100)}
|
||||
assert len(ids) == 100
|
||||
|
||||
def test_new_message_id_is_unique(self):
|
||||
ids = {Session.new_message_id() for _ in range(100)}
|
||||
assert len(ids) == 100
|
||||
|
||||
|
||||
class TestMessage:
|
||||
def test_create_message(self):
|
||||
msg = Message(
|
||||
message_id="m1",
|
||||
session_id="s1",
|
||||
role=MessageRole.USER,
|
||||
content="Hello",
|
||||
)
|
||||
assert msg.message_id == "m1"
|
||||
assert msg.session_id == "s1"
|
||||
assert msg.role == MessageRole.USER
|
||||
assert msg.content == "Hello"
|
||||
assert msg.tool_call_id is None
|
||||
assert msg.agent_name is None
|
||||
assert msg.metadata == {}
|
||||
|
||||
def test_message_with_tool_call(self):
|
||||
msg = Message(
|
||||
message_id="m1",
|
||||
session_id="s1",
|
||||
role=MessageRole.TOOL,
|
||||
content="result",
|
||||
tool_call_id="tc1",
|
||||
agent_name="agent1",
|
||||
)
|
||||
assert msg.tool_call_id == "tc1"
|
||||
assert msg.agent_name == "agent1"
|
||||
|
||||
def test_message_to_dict_and_back(self):
|
||||
msg = Message(
|
||||
message_id="m1",
|
||||
session_id="s1",
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="Hi there",
|
||||
tool_call_id="tc1",
|
||||
agent_name="agent1",
|
||||
metadata={"step": 1},
|
||||
)
|
||||
d = msg.to_dict()
|
||||
assert d["message_id"] == "m1"
|
||||
assert d["role"] == "assistant"
|
||||
assert d["tool_call_id"] == "tc1"
|
||||
|
||||
restored = Message.from_dict(d)
|
||||
assert restored.message_id == msg.message_id
|
||||
assert restored.role == msg.role
|
||||
assert restored.content == msg.content
|
||||
assert restored.tool_call_id == msg.tool_call_id
|
||||
assert restored.agent_name == msg.agent_name
|
||||
assert restored.metadata == msg.metadata
|
||||
|
||||
def test_to_chat_message_user(self):
|
||||
msg = Message(
|
||||
message_id="m1",
|
||||
session_id="s1",
|
||||
role=MessageRole.USER,
|
||||
content="Hello",
|
||||
)
|
||||
chat_msg = msg.to_chat_message()
|
||||
assert chat_msg == {"role": "user", "content": "Hello"}
|
||||
|
||||
def test_to_chat_message_tool(self):
|
||||
msg = Message(
|
||||
message_id="m1",
|
||||
session_id="s1",
|
||||
role=MessageRole.TOOL,
|
||||
content="result",
|
||||
tool_call_id="tc1",
|
||||
)
|
||||
chat_msg = msg.to_chat_message()
|
||||
assert chat_msg == {"role": "tool", "content": "result", "tool_call_id": "tc1"}
|
||||
|
||||
def test_to_chat_message_no_tool_call_id(self):
|
||||
msg = Message(
|
||||
message_id="m1",
|
||||
session_id="s1",
|
||||
role=MessageRole.ASSISTANT,
|
||||
content="Hi",
|
||||
)
|
||||
chat_msg = msg.to_chat_message()
|
||||
assert "tool_call_id" not in chat_msg
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
"""Tests for InMemorySessionStore."""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
|
||||
from agentkit.session.store import InMemorySessionStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store():
|
||||
return InMemorySessionStore()
|
||||
|
||||
|
||||
async def _create_session(store, session_id="s1", agent_name="agent1"):
|
||||
session = Session(session_id=session_id, agent_name=agent_name)
|
||||
await store.save_session(session)
|
||||
return session
|
||||
|
||||
|
||||
async def _create_message(store, session_id, role=MessageRole.USER, content="Hello"):
|
||||
msg = Message(
|
||||
message_id=Session.new_message_id(),
|
||||
session_id=session_id,
|
||||
role=role,
|
||||
content=content,
|
||||
)
|
||||
await store.append_message(msg)
|
||||
return msg
|
||||
|
||||
|
||||
class TestInMemorySessionStoreCRUD:
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_and_get(self, store):
|
||||
session = await _create_session(store)
|
||||
fetched = await store.get_session("s1")
|
||||
assert fetched is not None
|
||||
assert fetched.session_id == "s1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent(self, store):
|
||||
result = await store.get_session("nonexistent")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status(self, store):
|
||||
await _create_session(store)
|
||||
updated = await store.update_session_status("s1", SessionStatus.PAUSED)
|
||||
assert updated is not None
|
||||
assert updated.status == SessionStatus.PAUSED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_nonexistent(self, store):
|
||||
result = await store.update_session_status("nonexistent", SessionStatus.PAUSED)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete(self, store):
|
||||
await _create_session(store)
|
||||
assert await store.delete_session("s1") is True
|
||||
assert await store.get_session("s1") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent(self, store):
|
||||
assert await store.delete_session("nonexistent") is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions(self, store):
|
||||
await _create_session(store, "s1", "agent1")
|
||||
await _create_session(store, "s2", "agent2")
|
||||
sessions = await store.list_sessions()
|
||||
assert len(sessions) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions_by_agent(self, store):
|
||||
await _create_session(store, "s1", "agent1")
|
||||
await _create_session(store, "s2", "agent2")
|
||||
sessions = await store.list_sessions(agent_name="agent1")
|
||||
assert len(sessions) == 1
|
||||
assert sessions[0].agent_name == "agent1"
|
||||
|
||||
|
||||
class TestInMemorySessionStoreMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_append_and_get(self, store):
|
||||
await _create_session(store)
|
||||
await _create_message(store, "s1", content="Hello")
|
||||
await _create_message(store, "s1", content="World")
|
||||
|
||||
messages = await store.get_messages("s1")
|
||||
assert len(messages) == 2
|
||||
assert messages[0].content == "Hello"
|
||||
assert messages[1].content == "World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_pagination(self, store):
|
||||
await _create_session(store)
|
||||
for i in range(5):
|
||||
await _create_message(store, "s1", content=f"Msg {i}")
|
||||
|
||||
page = await store.get_messages("s1", limit=2, offset=1)
|
||||
assert len(page) == 2
|
||||
assert page[0].content == "Msg 1"
|
||||
assert page[1].content == "Msg 2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_messages(self, store):
|
||||
await _create_session(store)
|
||||
await _create_message(store, "s1")
|
||||
await _create_message(store, "s1")
|
||||
assert await store.count_messages("s1") == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_messages_empty_session(self, store):
|
||||
assert await store.count_messages("nonexistent") == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_messages_empty_session(self, store):
|
||||
messages = await store.get_messages("nonexistent")
|
||||
assert messages == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_session_removes_messages(self, store):
|
||||
await _create_session(store)
|
||||
await _create_message(store, "s1")
|
||||
await store.delete_session("s1")
|
||||
assert await store.count_messages("s1") == 0
|
||||
|
||||
|
||||
class TestInMemorySessionStoreEviction:
|
||||
@pytest.mark.asyncio
|
||||
async def test_evict_closed_session_on_full(self):
|
||||
store = InMemorySessionStore(max_sessions=2)
|
||||
s1 = await _create_session(store, "s1")
|
||||
await _create_session(store, "s2")
|
||||
|
||||
# Close s1 so it can be evicted
|
||||
await store.update_session_status("s1", SessionStatus.CLOSED)
|
||||
|
||||
# Creating a third session should evict s1
|
||||
await _create_session(store, "s3")
|
||||
assert await store.get_session("s1") is None
|
||||
assert await store.get_session("s2") is not None
|
||||
assert await store.get_session("s3") is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_no_closed_raises(self):
|
||||
store = InMemorySessionStore(max_sessions=2)
|
||||
await _create_session(store, "s1")
|
||||
await _create_session(store, "s2")
|
||||
with pytest.raises(RuntimeError, match="full"):
|
||||
await _create_session(store, "s3")
|
||||
|
||||
|
||||
class TestInMemorySessionStoreHealth:
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, store):
|
||||
assert await store.health_check() is True
|
||||
Loading…
Reference in New Issue