feat(session): add Session/Message models and SessionManager with InMemory/Redis stores

This commit is contained in:
chiguyong 2026-06-07 22:43:14 +08:00
parent e4d6efb4bf
commit 493187782c
7 changed files with 1041 additions and 0 deletions

View File

@ -0,0 +1,16 @@
"""Session management - multi-turn conversation support for AgentKit."""
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
from agentkit.session.manager import SessionManager
from agentkit.session.store import InMemorySessionStore, RedisSessionStore, create_session_store
__all__ = [
"Message",
"MessageRole",
"Session",
"SessionStatus",
"SessionManager",
"InMemorySessionStore",
"RedisSessionStore",
"create_session_store",
]

View File

@ -0,0 +1,160 @@
"""SessionManager — high-level API for conversation session management."""
from __future__ import annotations
import logging
from typing import Any
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
from agentkit.session.store import InMemorySessionStore, SessionStore
logger = logging.getLogger(__name__)
class SessionManager:
"""Manages conversation sessions and their messages.
Provides a high-level API for creating, querying, and updating
sessions, as well as appending and retrieving messages.
"""
def __init__(self, store: SessionStore | None = None):
self._store = store or InMemorySessionStore()
@property
def store(self) -> SessionStore:
return self._store
async def create_session(
self,
agent_name: str,
metadata: dict[str, Any] | None = None,
) -> Session:
"""Create a new conversation session bound to an Agent.
Args:
agent_name: Name of the Agent this session is bound to.
metadata: Optional metadata to attach to the session.
Returns:
The newly created Session.
"""
session = Session(
session_id=Session.new_session_id(),
agent_name=agent_name,
metadata=metadata or {},
)
await self._store.save_session(session)
logger.info(f"Session created: {session.session_id} for agent '{agent_name}'")
return session
async def get_session(self, session_id: str) -> Session | None:
"""Get a session by ID."""
return await self._store.get_session(session_id)
async def pause_session(self, session_id: str) -> Session | None:
"""Pause an active session."""
return await self._store.update_session_status(session_id, SessionStatus.PAUSED)
async def resume_session(self, session_id: str) -> Session | None:
"""Resume a paused session."""
return await self._store.update_session_status(session_id, SessionStatus.ACTIVE)
async def close_session(self, session_id: str) -> Session | None:
"""Close a session. Closed sessions cannot accept new messages."""
return await self._store.update_session_status(session_id, SessionStatus.CLOSED)
async def delete_session(self, session_id: str) -> bool:
"""Delete a session and all its messages."""
return await self._store.delete_session(session_id)
async def list_sessions(
self,
agent_name: str | None = None,
limit: int = 100,
) -> list[Session]:
"""List sessions, optionally filtered by agent name."""
return await self._store.list_sessions(agent_name=agent_name, limit=limit)
async def append_message(
self,
session_id: str,
role: MessageRole,
content: str,
tool_call_id: str | None = None,
agent_name: str | None = None,
metadata: dict[str, Any] | None = None,
) -> Message:
"""Append a message to a session.
Args:
session_id: Target session ID.
role: Message role (user/assistant/tool/system).
content: Message content.
tool_call_id: Optional tool call ID for tool messages.
agent_name: Optional agent name for multi-Agent sessions.
metadata: Optional message metadata.
Returns:
The newly created Message.
Raises:
ValueError: If the session does not exist or is closed.
"""
session = await self._store.get_session(session_id)
if session is None:
raise ValueError(f"Session '{session_id}' not found")
if session.status == SessionStatus.CLOSED:
raise ValueError(f"Session '{session_id}' is closed and cannot accept new messages")
message = Message(
message_id=Session.new_message_id(),
session_id=session_id,
role=role,
content=content,
tool_call_id=tool_call_id,
agent_name=agent_name,
metadata=metadata or {},
)
await self._store.append_message(message)
# Update session's updated_at timestamp
session.updated_at = __import__("datetime").datetime.now(__import__("datetime").timezone.utc)
await self._store.save_session(session)
return message
async def get_messages(
self,
session_id: str,
limit: int | None = None,
offset: int = 0,
) -> list[Message]:
"""Get messages for a session with optional pagination.
Args:
session_id: Target session ID.
limit: Maximum number of messages to return. None for all.
offset: Number of messages to skip from the beginning.
Returns:
List of messages ordered chronologically.
"""
return await self._store.get_messages(session_id, limit=limit, offset=offset)
async def get_chat_messages(self, session_id: str) -> list[dict[str, str]]:
"""Get messages formatted for LLM chat API consumption.
Returns messages as OpenAI-compatible dicts suitable for
passing directly to the ReAct engine or LLM Gateway.
"""
messages = await self._store.get_messages(session_id)
return [m.to_chat_message() for m in messages]
async def count_messages(self, session_id: str) -> int:
"""Count messages in a session."""
return await self._store.count_messages(session_id)
async def health_check(self) -> bool:
"""Check if the underlying store is healthy."""
return await self._store.health_check()

View File

@ -0,0 +1,125 @@
"""Session and Message data models for multi-turn conversations."""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any
class SessionStatus(str, Enum):
"""Session lifecycle states."""
ACTIVE = "active"
PAUSED = "paused"
CLOSED = "closed"
class MessageRole(str, Enum):
"""Message role — mirrors OpenAI chat message roles."""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@dataclass
class Message:
"""A single message within a conversation session.
Maps directly to the ``messages`` list consumed by the ReAct engine.
"""
message_id: str
session_id: str
role: MessageRole
content: str
tool_call_id: str | None = None
agent_name: str | None = None
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
metadata: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
return {
"message_id": self.message_id,
"session_id": self.session_id,
"role": self.role.value,
"content": self.content,
"tool_call_id": self.tool_call_id,
"agent_name": self.agent_name,
"created_at": self.created_at.isoformat(),
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Message:
return cls(
message_id=data["message_id"],
session_id=data["session_id"],
role=MessageRole(data["role"]),
content=data["content"],
tool_call_id=data.get("tool_call_id"),
agent_name=data.get("agent_name"),
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(timezone.utc),
metadata=data.get("metadata", {}),
)
def to_chat_message(self) -> dict[str, str]:
"""Convert to OpenAI-compatible chat message dict.
Returns a dict suitable for the ``messages`` parameter of LLM chat APIs.
"""
msg: dict[str, str] = {"role": self.role.value, "content": self.content}
if self.tool_call_id is not None:
msg["tool_call_id"] = self.tool_call_id
return msg
@dataclass
class Session:
"""A conversation session binding a user to an Agent.
Sessions track lifecycle state and accumulate Messages. They are
persisted via :class:`SessionStore` backends.
"""
session_id: str
agent_name: str
status: SessionStatus = SessionStatus.ACTIVE
metadata: dict[str, Any] = field(default_factory=dict)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def to_dict(self) -> dict[str, Any]:
return {
"session_id": self.session_id,
"agent_name": self.agent_name,
"status": self.status.value,
"metadata": self.metadata,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Session:
return cls(
session_id=data["session_id"],
agent_name=data["agent_name"],
status=SessionStatus(data.get("status", "active")),
metadata=data.get("metadata", {}),
created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(timezone.utc),
updated_at=datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else datetime.now(timezone.utc),
)
@staticmethod
def new_session_id() -> str:
"""Generate a new session ID."""
return str(uuid.uuid4())
@staticmethod
def new_message_id() -> str:
"""Generate a new message ID."""
return str(uuid.uuid4())

View File

@ -0,0 +1,238 @@
"""Session store backends — InMemory and Redis."""
from __future__ import annotations
import json
import logging
from typing import Any, Protocol, runtime_checkable
from agentkit.session.models import Message, Session, SessionStatus
logger = logging.getLogger(__name__)
@runtime_checkable
class SessionStore(Protocol):
"""Protocol for session persistence backends."""
async def save_session(self, session: Session) -> None: ...
async def get_session(self, session_id: str) -> Session | None: ...
async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None: ...
async def delete_session(self, session_id: str) -> bool: ...
async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]: ...
async def append_message(self, message: Message) -> None: ...
async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]: ...
async def count_messages(self, session_id: str) -> int: ...
async def health_check(self) -> bool: ...
class InMemorySessionStore:
"""In-memory session store for development and testing."""
def __init__(self, max_sessions: int = 10000, max_messages_per_session: int = 50000):
self._sessions: dict[str, Session] = {}
self._messages: dict[str, list[Message]] = {}
self._max_sessions = max_sessions
self._max_messages_per_session = max_messages_per_session
async def save_session(self, session: Session) -> None:
if len(self._sessions) >= self._max_sessions and session.session_id not in self._sessions:
# Evict oldest closed session
closed = [s for s in self._sessions.values() if s.status == SessionStatus.CLOSED]
if closed:
oldest = min(closed, key=lambda s: s.updated_at)
del self._sessions[oldest.session_id]
self._messages.pop(oldest.session_id, None)
else:
raise RuntimeError("SessionStore is full and no closed sessions to evict")
self._sessions[session.session_id] = session
if session.session_id not in self._messages:
self._messages[session.session_id] = []
async def get_session(self, session_id: str) -> Session | None:
return self._sessions.get(session_id)
async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
session = self._sessions.get(session_id)
if session is None:
return None
session.status = status
session.updated_at = datetime.now(timezone.utc)
return session
async def delete_session(self, session_id: str) -> bool:
if session_id in self._sessions:
del self._sessions[session_id]
self._messages.pop(session_id, None)
return True
return False
async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
sessions = list(self._sessions.values())
if agent_name:
sessions = [s for s in sessions if s.agent_name == agent_name]
sessions.sort(key=lambda s: s.updated_at, reverse=True)
return sessions[:limit]
async def append_message(self, message: Message) -> None:
msgs = self._messages.setdefault(message.session_id, [])
if len(msgs) >= self._max_messages_per_session:
# Remove oldest messages to stay within limit
excess = len(msgs) - self._max_messages_per_session + 1
del msgs[:excess]
msgs.append(message)
async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
msgs = self._messages.get(session_id, [])
sliced = msgs[offset:]
if limit is not None:
sliced = sliced[:limit]
return sliced
async def count_messages(self, session_id: str) -> int:
return len(self._messages.get(session_id, []))
async def health_check(self) -> bool:
return True
class RedisSessionStore:
"""Redis-backed session store for production use.
Key patterns:
- ``agentkit:session:{session_id}`` session metadata (JSON + TTL)
- ``agentkit:session:{session_id}:messages`` message list (Redis list)
"""
KEY_PREFIX = "agentkit:session:"
MSG_SUFFIX = ":messages"
def __init__(self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 86400):
self._redis_url = redis_url
self._ttl_seconds = ttl_seconds
self._redis: Any = None
async def _get_redis(self):
if self._redis is None:
import redis.asyncio as aioredis
self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
return self._redis
def _session_key(self, session_id: str) -> str:
return f"{self.KEY_PREFIX}{session_id}"
def _messages_key(self, session_id: str) -> str:
return f"{self.KEY_PREFIX}{session_id}{self.MSG_SUFFIX}"
async def save_session(self, session: Session) -> None:
redis = await self._get_redis()
key = self._session_key(session.session_id)
await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)
async def get_session(self, session_id: str) -> Session | None:
redis = await self._get_redis()
raw = await redis.get(self._session_key(session_id))
if raw is None:
return None
return Session.from_dict(json.loads(raw))
async def update_session_status(self, session_id: str, status: SessionStatus) -> Session | None:
redis = await self._get_redis()
key = self._session_key(session_id)
raw = await redis.get(key)
if raw is None:
return None
session = Session.from_dict(json.loads(raw))
session.status = status
session.updated_at = datetime.now(timezone.utc)
await redis.set(key, json.dumps(session.to_dict()), ex=self._ttl_seconds)
return session
async def delete_session(self, session_id: str) -> bool:
redis = await self._get_redis()
keys = [self._session_key(session_id), self._messages_key(session_id)]
deleted = await redis.delete(*keys)
return deleted > 0
async def list_sessions(self, agent_name: str | None = None, limit: int = 100) -> list[Session]:
redis = await self._get_redis()
sessions: list[Session] = []
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
# Filter out message list keys
session_keys = [k for k in keys if not k.endswith(self.MSG_SUFFIX)]
if session_keys:
values = await redis.mget(session_keys)
for raw in values:
if raw is None:
continue
session = Session.from_dict(json.loads(raw))
if agent_name is None or session.agent_name == agent_name:
sessions.append(session)
if cursor == 0:
break
sessions.sort(key=lambda s: s.updated_at, reverse=True)
return sessions[:limit]
async def append_message(self, message: Message) -> None:
redis = await self._get_redis()
key = self._messages_key(message.session_id)
await redis.rpush(key, json.dumps(message.to_dict()))
# Set TTL on message list to match session TTL
await redis.expire(key, self._ttl_seconds)
async def get_messages(self, session_id: str, limit: int | None = None, offset: int = 0) -> list[Message]:
redis = await self._get_redis()
key = self._messages_key(session_id)
# Use LRANGE for offset-based pagination
# Redis list indices: 0-based, -1 = last element
start = offset
if limit is not None:
end = offset + limit - 1
else:
end = -1
raw_list = await redis.lrange(key, start, end)
return [Message.from_dict(json.loads(raw)) for raw in raw_list]
async def count_messages(self, session_id: str) -> int:
redis = await self._get_redis()
return await redis.llen(self._messages_key(session_id))
async def health_check(self) -> bool:
try:
redis = await self._get_redis()
return await redis.ping()
except Exception:
return False
# Needed for from_dict deserialization
from datetime import datetime, timezone # noqa: E402
def create_session_store(
backend: str = "memory",
redis_url: str = "redis://localhost:6379/0",
ttl_seconds: int = 86400,
) -> InMemorySessionStore | RedisSessionStore:
"""Factory: create a SessionStore backed by memory or Redis.
Falls back to InMemorySessionStore if Redis is unavailable.
"""
if backend == "redis":
try:
import redis.asyncio as aioredis # noqa: F401
store = RedisSessionStore(redis_url=redis_url, ttl_seconds=ttl_seconds)
logger.info(f"SessionStore backend: redis")
return store
except Exception as exc:
logger.warning(f"Failed to initialise RedisSessionStore ({exc}), falling back to InMemorySessionStore")
store = InMemorySessionStore()
logger.info("SessionStore backend: memory")
return store

View File

@ -0,0 +1,199 @@
"""Tests for SessionManager."""
import pytest
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
from agentkit.session.store import InMemorySessionStore
@pytest.fixture
def manager():
return SessionManager(store=InMemorySessionStore())
class TestSessionManagerCreate:
@pytest.mark.asyncio
async def test_create_session(self, manager):
session = await manager.create_session(agent_name="test-agent")
assert session.session_id is not None
assert session.agent_name == "test-agent"
assert session.status == SessionStatus.ACTIVE
@pytest.mark.asyncio
async def test_create_session_with_metadata(self, manager):
session = await manager.create_session(
agent_name="agent1",
metadata={"user_id": "u1"},
)
assert session.metadata == {"user_id": "u1"}
class TestSessionManagerGet:
@pytest.mark.asyncio
async def test_get_existing_session(self, manager):
created = await manager.create_session(agent_name="agent1")
fetched = await manager.get_session(created.session_id)
assert fetched is not None
assert fetched.session_id == created.session_id
@pytest.mark.asyncio
async def test_get_nonexistent_session(self, manager):
result = await manager.get_session("nonexistent")
assert result is None
class TestSessionManagerLifecycle:
@pytest.mark.asyncio
async def test_pause_and_resume(self, manager):
session = await manager.create_session(agent_name="agent1")
paused = await manager.pause_session(session.session_id)
assert paused.status == SessionStatus.PAUSED
resumed = await manager.resume_session(session.session_id)
assert resumed.status == SessionStatus.ACTIVE
@pytest.mark.asyncio
async def test_close_session(self, manager):
session = await manager.create_session(agent_name="agent1")
closed = await manager.close_session(session.session_id)
assert closed.status == SessionStatus.CLOSED
@pytest.mark.asyncio
async def test_close_nonexistent_returns_none(self, manager):
result = await manager.close_session("nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_delete_session(self, manager):
session = await manager.create_session(agent_name="agent1")
deleted = await manager.delete_session(session.session_id)
assert deleted is True
assert await manager.get_session(session.session_id) is None
@pytest.mark.asyncio
async def test_delete_nonexistent_returns_false(self, manager):
deleted = await manager.delete_session("nonexistent")
assert deleted is False
class TestSessionManagerMessages:
@pytest.mark.asyncio
async def test_append_user_message(self, manager):
session = await manager.create_session(agent_name="agent1")
msg = await manager.append_message(
session_id=session.session_id,
role=MessageRole.USER,
content="Hello",
)
assert msg.role == MessageRole.USER
assert msg.content == "Hello"
assert msg.session_id == session.session_id
@pytest.mark.asyncio
async def test_append_assistant_message(self, manager):
session = await manager.create_session(agent_name="agent1")
msg = await manager.append_message(
session_id=session.session_id,
role=MessageRole.ASSISTANT,
content="Hi there!",
)
assert msg.role == MessageRole.ASSISTANT
@pytest.mark.asyncio
async def test_get_messages(self, manager):
session = await manager.create_session(agent_name="agent1")
await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")
messages = await manager.get_messages(session.session_id)
assert len(messages) == 2
assert messages[0].content == "Hello"
assert messages[1].content == "Hi!"
@pytest.mark.asyncio
async def test_get_messages_pagination(self, manager):
session = await manager.create_session(agent_name="agent1")
for i in range(10):
await manager.append_message(
session_id=session.session_id,
role=MessageRole.USER,
content=f"Message {i}",
)
# Get first 3 messages
page1 = await manager.get_messages(session.session_id, limit=3, offset=0)
assert len(page1) == 3
assert page1[0].content == "Message 0"
# Get next 3 messages
page2 = await manager.get_messages(session.session_id, limit=3, offset=3)
assert len(page2) == 3
assert page2[0].content == "Message 3"
@pytest.mark.asyncio
async def test_count_messages(self, manager):
session = await manager.create_session(agent_name="agent1")
await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")
count = await manager.count_messages(session.session_id)
assert count == 2
@pytest.mark.asyncio
async def test_closed_session_rejects_messages(self, manager):
session = await manager.create_session(agent_name="agent1")
await manager.close_session(session.session_id)
with pytest.raises(ValueError, match="closed"):
await manager.append_message(
session_id=session.session_id,
role=MessageRole.USER,
content="Should fail",
)
@pytest.mark.asyncio
async def test_nonexistent_session_rejects_messages(self, manager):
with pytest.raises(ValueError, match="not found"):
await manager.append_message(
session_id="nonexistent",
role=MessageRole.USER,
content="Should fail",
)
@pytest.mark.asyncio
async def test_get_chat_messages(self, manager):
session = await manager.create_session(agent_name="agent1")
await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")
chat_msgs = await manager.get_chat_messages(session.session_id)
assert len(chat_msgs) == 2
assert chat_msgs[0] == {"role": "user", "content": "Hello"}
assert chat_msgs[1] == {"role": "assistant", "content": "Hi!"}
class TestSessionManagerList:
@pytest.mark.asyncio
async def test_list_sessions(self, manager):
await manager.create_session(agent_name="agent1")
await manager.create_session(agent_name="agent2")
sessions = await manager.list_sessions()
assert len(sessions) == 2
@pytest.mark.asyncio
async def test_list_sessions_by_agent(self, manager):
await manager.create_session(agent_name="agent1")
await manager.create_session(agent_name="agent2")
await manager.create_session(agent_name="agent1")
sessions = await manager.list_sessions(agent_name="agent1")
assert len(sessions) == 2
assert all(s.agent_name == "agent1" for s in sessions)
class TestSessionManagerHealth:
@pytest.mark.asyncio
async def test_health_check(self, manager):
assert await manager.health_check() is True

View File

@ -0,0 +1,146 @@
"""Tests for Session and Message data models."""
import pytest
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
class TestSessionStatus:
def test_status_values(self):
assert SessionStatus.ACTIVE == "active"
assert SessionStatus.PAUSED == "paused"
assert SessionStatus.CLOSED == "closed"
def test_status_from_string(self):
assert SessionStatus("active") == SessionStatus.ACTIVE
assert SessionStatus("paused") == SessionStatus.PAUSED
assert SessionStatus("closed") == SessionStatus.CLOSED
class TestMessageRole:
def test_role_values(self):
assert MessageRole.SYSTEM == "system"
assert MessageRole.USER == "user"
assert MessageRole.ASSISTANT == "assistant"
assert MessageRole.TOOL == "tool"
class TestSession:
def test_create_session(self):
session = Session(session_id="s1", agent_name="test-agent")
assert session.session_id == "s1"
assert session.agent_name == "test-agent"
assert session.status == SessionStatus.ACTIVE
assert session.metadata == {}
assert session.created_at is not None
assert session.updated_at is not None
def test_session_to_dict_and_back(self):
session = Session(
session_id="s1",
agent_name="agent1",
status=SessionStatus.PAUSED,
metadata={"key": "value"},
)
d = session.to_dict()
assert d["session_id"] == "s1"
assert d["agent_name"] == "agent1"
assert d["status"] == "paused"
assert d["metadata"] == {"key": "value"}
restored = Session.from_dict(d)
assert restored.session_id == session.session_id
assert restored.agent_name == session.agent_name
assert restored.status == session.status
assert restored.metadata == session.metadata
def test_new_session_id_is_unique(self):
ids = {Session.new_session_id() for _ in range(100)}
assert len(ids) == 100
def test_new_message_id_is_unique(self):
ids = {Session.new_message_id() for _ in range(100)}
assert len(ids) == 100
class TestMessage:
def test_create_message(self):
msg = Message(
message_id="m1",
session_id="s1",
role=MessageRole.USER,
content="Hello",
)
assert msg.message_id == "m1"
assert msg.session_id == "s1"
assert msg.role == MessageRole.USER
assert msg.content == "Hello"
assert msg.tool_call_id is None
assert msg.agent_name is None
assert msg.metadata == {}
def test_message_with_tool_call(self):
msg = Message(
message_id="m1",
session_id="s1",
role=MessageRole.TOOL,
content="result",
tool_call_id="tc1",
agent_name="agent1",
)
assert msg.tool_call_id == "tc1"
assert msg.agent_name == "agent1"
def test_message_to_dict_and_back(self):
msg = Message(
message_id="m1",
session_id="s1",
role=MessageRole.ASSISTANT,
content="Hi there",
tool_call_id="tc1",
agent_name="agent1",
metadata={"step": 1},
)
d = msg.to_dict()
assert d["message_id"] == "m1"
assert d["role"] == "assistant"
assert d["tool_call_id"] == "tc1"
restored = Message.from_dict(d)
assert restored.message_id == msg.message_id
assert restored.role == msg.role
assert restored.content == msg.content
assert restored.tool_call_id == msg.tool_call_id
assert restored.agent_name == msg.agent_name
assert restored.metadata == msg.metadata
def test_to_chat_message_user(self):
msg = Message(
message_id="m1",
session_id="s1",
role=MessageRole.USER,
content="Hello",
)
chat_msg = msg.to_chat_message()
assert chat_msg == {"role": "user", "content": "Hello"}
def test_to_chat_message_tool(self):
msg = Message(
message_id="m1",
session_id="s1",
role=MessageRole.TOOL,
content="result",
tool_call_id="tc1",
)
chat_msg = msg.to_chat_message()
assert chat_msg == {"role": "tool", "content": "result", "tool_call_id": "tc1"}
def test_to_chat_message_no_tool_call_id(self):
msg = Message(
message_id="m1",
session_id="s1",
role=MessageRole.ASSISTANT,
content="Hi",
)
chat_msg = msg.to_chat_message()
assert "tool_call_id" not in chat_msg

View File

@ -0,0 +1,157 @@
"""Tests for InMemorySessionStore."""
import pytest
from agentkit.session.models import Message, MessageRole, Session, SessionStatus
from agentkit.session.store import InMemorySessionStore
@pytest.fixture
def store():
return InMemorySessionStore()
async def _create_session(store, session_id="s1", agent_name="agent1"):
session = Session(session_id=session_id, agent_name=agent_name)
await store.save_session(session)
return session
async def _create_message(store, session_id, role=MessageRole.USER, content="Hello"):
msg = Message(
message_id=Session.new_message_id(),
session_id=session_id,
role=role,
content=content,
)
await store.append_message(msg)
return msg
class TestInMemorySessionStoreCRUD:
@pytest.mark.asyncio
async def test_save_and_get(self, store):
session = await _create_session(store)
fetched = await store.get_session("s1")
assert fetched is not None
assert fetched.session_id == "s1"
@pytest.mark.asyncio
async def test_get_nonexistent(self, store):
result = await store.get_session("nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_update_status(self, store):
await _create_session(store)
updated = await store.update_session_status("s1", SessionStatus.PAUSED)
assert updated is not None
assert updated.status == SessionStatus.PAUSED
@pytest.mark.asyncio
async def test_update_status_nonexistent(self, store):
result = await store.update_session_status("nonexistent", SessionStatus.PAUSED)
assert result is None
@pytest.mark.asyncio
async def test_delete(self, store):
await _create_session(store)
assert await store.delete_session("s1") is True
assert await store.get_session("s1") is None
@pytest.mark.asyncio
async def test_delete_nonexistent(self, store):
assert await store.delete_session("nonexistent") is False
@pytest.mark.asyncio
async def test_list_sessions(self, store):
await _create_session(store, "s1", "agent1")
await _create_session(store, "s2", "agent2")
sessions = await store.list_sessions()
assert len(sessions) == 2
@pytest.mark.asyncio
async def test_list_sessions_by_agent(self, store):
await _create_session(store, "s1", "agent1")
await _create_session(store, "s2", "agent2")
sessions = await store.list_sessions(agent_name="agent1")
assert len(sessions) == 1
assert sessions[0].agent_name == "agent1"
class TestInMemorySessionStoreMessages:
@pytest.mark.asyncio
async def test_append_and_get(self, store):
await _create_session(store)
await _create_message(store, "s1", content="Hello")
await _create_message(store, "s1", content="World")
messages = await store.get_messages("s1")
assert len(messages) == 2
assert messages[0].content == "Hello"
assert messages[1].content == "World"
@pytest.mark.asyncio
async def test_get_messages_pagination(self, store):
await _create_session(store)
for i in range(5):
await _create_message(store, "s1", content=f"Msg {i}")
page = await store.get_messages("s1", limit=2, offset=1)
assert len(page) == 2
assert page[0].content == "Msg 1"
assert page[1].content == "Msg 2"
@pytest.mark.asyncio
async def test_count_messages(self, store):
await _create_session(store)
await _create_message(store, "s1")
await _create_message(store, "s1")
assert await store.count_messages("s1") == 2
@pytest.mark.asyncio
async def test_count_messages_empty_session(self, store):
assert await store.count_messages("nonexistent") == 0
@pytest.mark.asyncio
async def test_get_messages_empty_session(self, store):
messages = await store.get_messages("nonexistent")
assert messages == []
@pytest.mark.asyncio
async def test_delete_session_removes_messages(self, store):
await _create_session(store)
await _create_message(store, "s1")
await store.delete_session("s1")
assert await store.count_messages("s1") == 0
class TestInMemorySessionStoreEviction:
@pytest.mark.asyncio
async def test_evict_closed_session_on_full(self):
store = InMemorySessionStore(max_sessions=2)
s1 = await _create_session(store, "s1")
await _create_session(store, "s2")
# Close s1 so it can be evicted
await store.update_session_status("s1", SessionStatus.CLOSED)
# Creating a third session should evict s1
await _create_session(store, "s3")
assert await store.get_session("s1") is None
assert await store.get_session("s2") is not None
assert await store.get_session("s3") is not None
@pytest.mark.asyncio
async def test_full_no_closed_raises(self):
store = InMemorySessionStore(max_sessions=2)
await _create_session(store, "s1")
await _create_session(store, "s2")
with pytest.raises(RuntimeError, match="full"):
await _create_session(store, "s3")
class TestInMemorySessionStoreHealth:
@pytest.mark.asyncio
async def test_health_check(self, store):
assert await store.health_check() is True