"""Tests for SessionManager.

Covers session creation, lookup, lifecycle transitions (pause / resume /
close / delete), message append / read / count / pagination, chat-message
export, session listing, health checks, and the optional ``async_writes``
mode in which appended messages go through a background write queue that
``flush()`` and ``close()`` drain.
"""

import pytest
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
from agentkit.session.store import InMemorySessionStore


@pytest.fixture
def manager():
    """Build a SessionManager over a brand-new in-memory store.

    ``async_writes`` is left at its default, so appended messages are written
    to the store synchronously.
    """
    backing_store = InMemorySessionStore()
    return SessionManager(store=backing_store)


@pytest.fixture
def async_manager():
    """Build a SessionManager over a brand-new in-memory store with ``async_writes=True``.

    Writes are queued for a background writer; tests are expected to call
    ``flush()`` before inspecting the store and ``close()`` when done.
    """
    backing_store = InMemorySessionStore()
    return SessionManager(store=backing_store, async_writes=True)


class TestSessionManagerCreate:
    """Creating sessions through ``SessionManager.create_session``."""

    @pytest.mark.asyncio
    async def test_create_session(self, manager):
        new_session = await manager.create_session(agent_name="test-agent")

        # A new session is assigned an id, records its agent, and starts ACTIVE.
        assert new_session.session_id is not None
        assert (new_session.agent_name, new_session.status) == ("test-agent", SessionStatus.ACTIVE)

    @pytest.mark.asyncio
    async def test_create_session_with_metadata(self, manager):
        new_session = await manager.create_session(agent_name="agent1", metadata={"user_id": "u1"})

        # Caller-supplied metadata is exposed on the returned session.
        assert new_session.metadata == {"user_id": "u1"}


class TestSessionManagerGet:
    """Looking sessions up by id through ``SessionManager.get_session``."""

    @pytest.mark.asyncio
    async def test_get_existing_session(self, manager):
        original = await manager.create_session(agent_name="agent1")
        looked_up = await manager.get_session(original.session_id)

        # The lookup finds the session that was just created.
        assert looked_up is not None
        assert looked_up.session_id == original.session_id

    @pytest.mark.asyncio
    async def test_get_nonexistent_session(self, manager):
        # An unknown id yields None rather than raising.
        assert await manager.get_session("nonexistent") is None


class TestSessionManagerLifecycle:
    """Status transitions (pause, resume, close) and deletion of sessions."""

    @pytest.mark.asyncio
    async def test_pause_and_resume(self, manager):
        sid = (await manager.create_session(agent_name="agent1")).session_id

        after_pause = await manager.pause_session(sid)
        assert after_pause.status == SessionStatus.PAUSED

        after_resume = await manager.resume_session(sid)
        assert after_resume.status == SessionStatus.ACTIVE

    @pytest.mark.asyncio
    async def test_close_session(self, manager):
        sid = (await manager.create_session(agent_name="agent1")).session_id
        assert (await manager.close_session(sid)).status == SessionStatus.CLOSED

    @pytest.mark.asyncio
    async def test_close_nonexistent_returns_none(self, manager):
        # Closing an unknown id is reported with None instead of an exception.
        assert await manager.close_session("nonexistent") is None

    @pytest.mark.asyncio
    async def test_delete_session(self, manager):
        sid = (await manager.create_session(agent_name="agent1")).session_id

        assert await manager.delete_session(sid) is True
        # Once deleted, the session can no longer be fetched.
        assert await manager.get_session(sid) is None

    @pytest.mark.asyncio
    async def test_delete_nonexistent_returns_false(self, manager):
        # Deleting an unknown id reports False instead of raising.
        assert await manager.delete_session("nonexistent") is False


class TestSessionManagerMessages:
    """Appending, reading, counting, paginating and exporting session messages."""

    @pytest.mark.asyncio
    async def test_append_user_message(self, manager):
        session = await manager.create_session(agent_name="agent1")
        msg = await manager.append_message(
            session_id=session.session_id,
            role=MessageRole.USER,
            content="Hello",
        )
        assert msg.role == MessageRole.USER
        assert msg.content == "Hello"
        assert msg.session_id == session.session_id

    @pytest.mark.asyncio
    async def test_append_assistant_message(self, manager):
        session = await manager.create_session(agent_name="agent1")
        msg = await manager.append_message(
            session_id=session.session_id,
            role=MessageRole.ASSISTANT,
            content="Hi there!",
        )
        assert msg.role == MessageRole.ASSISTANT
        # Hold the ASSISTANT role to the same contract as the USER role test.
        assert msg.content == "Hi there!"
        assert msg.session_id == session.session_id

    @pytest.mark.asyncio
    async def test_get_messages(self, manager):
        session = await manager.create_session(agent_name="agent1")
        await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
        await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")

        messages = await manager.get_messages(session.session_id)
        # Messages come back in insertion order.
        assert [m.content for m in messages] == ["Hello", "Hi!"]

    @pytest.mark.asyncio
    async def test_get_messages_pagination(self, manager):
        session = await manager.create_session(agent_name="agent1")
        for i in range(10):
            await manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content=f"Message {i}",
            )

        # First page: messages 0-2, in insertion order.
        page1 = await manager.get_messages(session.session_id, limit=3, offset=0)
        assert [m.content for m in page1] == ["Message 0", "Message 1", "Message 2"]

        # Second page continues exactly where the first stopped (no gap, no overlap).
        page2 = await manager.get_messages(session.session_id, limit=3, offset=3)
        assert [m.content for m in page2] == ["Message 3", "Message 4", "Message 5"]

        # A page that runs past the end is truncated to the remaining messages.
        last_page = await manager.get_messages(session.session_id, limit=3, offset=9)
        assert [m.content for m in last_page] == ["Message 9"]

        # An offset at or beyond the message count yields an empty page.
        past_end = await manager.get_messages(session.session_id, limit=3, offset=10)
        assert list(past_end) == []

    @pytest.mark.asyncio
    async def test_count_messages(self, manager):
        session = await manager.create_session(agent_name="agent1")
        await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
        await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")

        count = await manager.count_messages(session.session_id)
        assert count == 2

    @pytest.mark.asyncio
    async def test_closed_session_rejects_messages(self, manager):
        session = await manager.create_session(agent_name="agent1")
        await manager.close_session(session.session_id)

        with pytest.raises(ValueError, match="closed"):
            await manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="Should fail",
            )

    @pytest.mark.asyncio
    async def test_nonexistent_session_rejects_messages(self, manager):
        with pytest.raises(ValueError, match="not found"):
            await manager.append_message(
                session_id="nonexistent",
                role=MessageRole.USER,
                content="Should fail",
            )

    @pytest.mark.asyncio
    async def test_get_chat_messages(self, manager):
        session = await manager.create_session(agent_name="agent1")
        await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
        await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")

        chat_msgs = await manager.get_chat_messages(session.session_id)
        # Chat export is a list of plain {"role", "content"} dicts in insertion order.
        assert len(chat_msgs) == 2
        assert chat_msgs[0] == {"role": "user", "content": "Hello"}
        assert chat_msgs[1] == {"role": "assistant", "content": "Hi!"}


class TestSessionManagerList:
    """Enumerating sessions, with and without an agent-name filter."""

    @pytest.mark.asyncio
    async def test_list_sessions(self, manager):
        for agent in ("agent1", "agent2"):
            await manager.create_session(agent_name=agent)

        listed = await manager.list_sessions()
        assert len(listed) == 2

    @pytest.mark.asyncio
    async def test_list_sessions_by_agent(self, manager):
        for agent in ("agent1", "agent2", "agent1"):
            await manager.create_session(agent_name=agent)

        agent1_sessions = await manager.list_sessions(agent_name="agent1")
        # Only the two "agent1" sessions survive the filter.
        assert len(agent1_sessions) == 2
        assert all(s.agent_name == "agent1" for s in agent1_sessions)


class TestSessionManagerHealth:
    """Health reporting for a manager backed by the in-memory store."""

    @pytest.mark.asyncio
    async def test_health_check(self, manager):
        healthy = await manager.health_check()
        # The result must be the bool True, not merely a truthy value.
        assert healthy is True


class TestAsyncWrites:
    """Tests for async (non-blocking) write behaviour.

    Every test that uses the ``async_manager`` fixture closes the manager in a
    ``finally`` clause, so the background writer is shut down even when an
    assertion fails part-way through.
    """

    @pytest.mark.asyncio
    async def test_append_message_returns_immediately(self, async_manager):
        """append_message returns the Message before it is persisted."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            msg = await async_manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="Hello",
            )
            # Message is returned immediately
            assert msg.role == MessageRole.USER
            assert msg.content == "Hello"
            # Drain the background writer, then verify persistence
            await async_manager.flush()
            persisted = await async_manager.store.get_messages(session.session_id)
            assert len(persisted) == 1
            assert persisted[0].content == "Hello"
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_get_chat_messages_includes_wal_buffered(self, async_manager):
        """get_chat_messages returns WAL-buffered messages not yet persisted."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            # Append a message — it may still be in the WAL buffer
            await async_manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="Buffered",
            )
            chat_msgs = await async_manager.get_chat_messages(session.session_id)
            # Whether it is still buffered or already persisted, the message must
            # appear exactly once: merging the WAL buffer with the store must not
            # duplicate it.
            assert list(chat_msgs) == [{"role": "user", "content": "Buffered"}]
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_flush_ensures_all_pending_writes(self, async_manager):
        """flush() waits until all queued writes are persisted."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            for i in range(5):
                await async_manager.append_message(
                    session_id=session.session_id,
                    role=MessageRole.USER,
                    content=f"Msg {i}",
                )
            await async_manager.flush()
            persisted = await async_manager.store.get_messages(session.session_id)
            assert len(persisted) == 5
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_rapid_appends_are_batched(self, async_manager):
        """Multiple rapid append_messages are all persisted correctly."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            expected = [f"Rapid {i}" for i in range(20)]
            # Fire off many messages rapidly
            returned = []
            for text in expected:
                returned.append(
                    await async_manager.append_message(
                        session_id=session.session_id,
                        role=MessageRole.USER,
                        content=text,
                    )
                )
            # Each call hands back its own Message, in call order.
            assert [m.content for m in returned] == expected

            await async_manager.flush()
            persisted = await async_manager.store.get_messages(session.session_id)
            assert len(persisted) == 20
            # Store-level order is not asserted here, only that every message
            # was persisted exactly once.
            assert sorted(m.content for m in persisted) == sorted(expected)
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_session_close_flushes_pending_writes(self, async_manager):
        """Closing a session flushes pending writes first."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            await async_manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="Before close",
            )
            closed = await async_manager.close_session(session.session_id)
            assert closed.status == SessionStatus.CLOSED
            # Message should be persisted because close_session flushes
            persisted = await async_manager.store.get_messages(session.session_id)
            assert len(persisted) == 1
            assert persisted[0].content == "Before close"
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_manager_close_stops_writer(self, async_manager):
        """close() flushes and stops the background writer."""
        # close() runs exactly once: in finally, so setup failures still clean up.
        try:
            session = await async_manager.create_session(agent_name="agent1")
            await async_manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="Final",
            )
        finally:
            await async_manager.close()

        # After close, the write queue should be stopped
        assert async_manager._write_queue is None or async_manager._write_queue._worker is None
        # ...and the pending write must have been flushed before stopping.
        # NOTE(review): assumes the in-memory store stays readable after manager.close().
        persisted = await async_manager.store.get_messages(session.session_id)
        assert [m.content for m in persisted] == ["Final"]

    @pytest.mark.asyncio
    async def test_async_writes_disabled_by_default(self):
        """Without async_writes=True, writes are synchronous."""
        mgr = SessionManager(store=InMemorySessionStore())
        assert mgr._write_queue is None
        session = await mgr.create_session(agent_name="agent1")
        await mgr.append_message(
            session_id=session.session_id,
            role=MessageRole.USER,
            content="Sync",
        )
        # Should be immediately persisted (no flush needed)
        persisted = await mgr.store.get_messages(session.session_id)
        assert len(persisted) == 1

    @pytest.mark.asyncio
    async def test_get_messages_includes_wal_buffered(self, async_manager):
        """get_messages returns WAL-buffered messages not yet persisted."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            await async_manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="WAL msg",
            )
            messages = await async_manager.get_messages(session.session_id)
            # Buffered or persisted, the message must be returned exactly once.
            assert [m.content for m in messages] == ["WAL msg"]
        finally:
            await async_manager.close()