fischer-agentkit/tests/unit/test_session_manager.py

339 lines
13 KiB
Python

"""Tests for SessionManager."""
import pytest
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
from agentkit.session.store import InMemorySessionStore
@pytest.fixture
def manager():
    """Provide a SessionManager over a brand-new in-memory store.

    ``async_writes`` is left at its default, so appends are persisted before
    the awaited call returns.
    """
    backing_store = InMemorySessionStore()
    return SessionManager(store=backing_store)
@pytest.fixture
def async_manager():
    """Provide a SessionManager with ``async_writes=True`` over a fresh in-memory store.

    This fixture performs no teardown: tests that use it are responsible for
    calling ``flush()`` / ``close()`` on the manager themselves.
    """
    in_memory = InMemorySessionStore()
    return SessionManager(store=in_memory, async_writes=True)
class TestSessionManagerCreate:
    """create_session: new sessions get an id, start ACTIVE and keep metadata."""

    @pytest.mark.asyncio
    async def test_create_session(self, manager):
        new_session = await manager.create_session(agent_name="test-agent")
        assert new_session.session_id is not None
        assert new_session.agent_name == "test-agent"
        assert new_session.status == SessionStatus.ACTIVE

    @pytest.mark.asyncio
    async def test_create_session_with_metadata(self, manager):
        meta = {"user_id": "u1"}
        new_session = await manager.create_session(agent_name="agent1", metadata=meta)
        assert new_session.metadata == {"user_id": "u1"}
class TestSessionManagerGet:
    """get_session: lookup by id for present and absent sessions."""

    @pytest.mark.asyncio
    async def test_get_existing_session(self, manager):
        original = await manager.create_session(agent_name="agent1")
        roundtrip = await manager.get_session(original.session_id)
        assert roundtrip is not None
        assert roundtrip.session_id == original.session_id

    @pytest.mark.asyncio
    async def test_get_nonexistent_session(self, manager):
        # An unknown id yields None rather than raising.
        assert await manager.get_session("nonexistent") is None
class TestSessionManagerLifecycle:
    """Status transitions (pause / resume / close) and deletion."""

    @pytest.mark.asyncio
    async def test_pause_and_resume(self, manager):
        sid = (await manager.create_session(agent_name="agent1")).session_id
        assert (await manager.pause_session(sid)).status == SessionStatus.PAUSED
        assert (await manager.resume_session(sid)).status == SessionStatus.ACTIVE

    @pytest.mark.asyncio
    async def test_close_session(self, manager):
        sid = (await manager.create_session(agent_name="agent1")).session_id
        assert (await manager.close_session(sid)).status == SessionStatus.CLOSED

    @pytest.mark.asyncio
    async def test_close_nonexistent_returns_none(self, manager):
        # Closing an unknown id is reported with None, not an exception.
        assert await manager.close_session("nonexistent") is None

    @pytest.mark.asyncio
    async def test_delete_session(self, manager):
        sid = (await manager.create_session(agent_name="agent1")).session_id
        assert await manager.delete_session(sid) is True
        # Once deleted, the session can no longer be fetched.
        assert await manager.get_session(sid) is None

    @pytest.mark.asyncio
    async def test_delete_nonexistent_returns_false(self, manager):
        assert await manager.delete_session("nonexistent") is False
class TestSessionManagerMessages:
    """append / get / count of session messages on a synchronous manager."""

    @pytest.mark.asyncio
    async def test_append_user_message(self, manager):
        session = await manager.create_session(agent_name="agent1")
        msg = await manager.append_message(
            session_id=session.session_id,
            role=MessageRole.USER,
            content="Hello",
        )
        assert msg.role == MessageRole.USER
        assert msg.content == "Hello"
        assert msg.session_id == session.session_id

    @pytest.mark.asyncio
    async def test_append_assistant_message(self, manager):
        session = await manager.create_session(agent_name="agent1")
        msg = await manager.append_message(
            session_id=session.session_id,
            role=MessageRole.ASSISTANT,
            content="Hi there!",
        )
        # Same contract as the USER case: role, content and owning session.
        assert msg.role == MessageRole.ASSISTANT
        assert msg.content == "Hi there!"
        assert msg.session_id == session.session_id

    @pytest.mark.asyncio
    async def test_get_messages(self, manager):
        session = await manager.create_session(agent_name="agent1")
        await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
        await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")
        messages = await manager.get_messages(session.session_id)
        # Messages come back in insertion order.
        assert [m.content for m in messages] == ["Hello", "Hi!"]

    @pytest.mark.asyncio
    async def test_get_messages_pagination(self, manager):
        session = await manager.create_session(agent_name="agent1")
        for i in range(10):
            await manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content=f"Message {i}",
            )
        # First page: messages 0-2, in order.
        page1 = await manager.get_messages(session.session_id, limit=3, offset=0)
        assert [m.content for m in page1] == ["Message 0", "Message 1", "Message 2"]
        # Next page continues exactly where the first one stopped.
        page2 = await manager.get_messages(session.session_id, limit=3, offset=3)
        assert [m.content for m in page2] == ["Message 3", "Message 4", "Message 5"]
        # Last page is partial: only one message remains at offset 9.
        tail = await manager.get_messages(session.session_id, limit=3, offset=9)
        assert [m.content for m in tail] == ["Message 9"]

    @pytest.mark.asyncio
    async def test_count_messages(self, manager):
        session = await manager.create_session(agent_name="agent1")
        await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
        await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")
        count = await manager.count_messages(session.session_id)
        assert count == 2

    @pytest.mark.asyncio
    async def test_closed_session_rejects_messages(self, manager):
        session = await manager.create_session(agent_name="agent1")
        await manager.close_session(session.session_id)
        with pytest.raises(ValueError, match="closed"):
            await manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="Should fail",
            )

    @pytest.mark.asyncio
    async def test_nonexistent_session_rejects_messages(self, manager):
        with pytest.raises(ValueError, match="not found"):
            await manager.append_message(
                session_id="nonexistent",
                role=MessageRole.USER,
                content="Should fail",
            )

    @pytest.mark.asyncio
    async def test_get_chat_messages(self, manager):
        session = await manager.create_session(agent_name="agent1")
        await manager.append_message(session_id=session.session_id, role=MessageRole.USER, content="Hello")
        await manager.append_message(session_id=session.session_id, role=MessageRole.ASSISTANT, content="Hi!")
        chat_msgs = await manager.get_chat_messages(session.session_id)
        # Chat format is a plain role/content dict per message, in order.
        assert chat_msgs == [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi!"},
        ]
class TestSessionManagerList:
    """list_sessions with and without an agent-name filter."""

    @pytest.mark.asyncio
    async def test_list_sessions(self, manager):
        for name in ("agent1", "agent2"):
            await manager.create_session(agent_name=name)
        everything = await manager.list_sessions()
        assert len(everything) == 2

    @pytest.mark.asyncio
    async def test_list_sessions_by_agent(self, manager):
        for name in ("agent1", "agent2", "agent1"):
            await manager.create_session(agent_name=name)
        filtered = await manager.list_sessions(agent_name="agent1")
        assert len(filtered) == 2
        # The filter must exclude every session owned by another agent.
        assert all(entry.agent_name == "agent1" for entry in filtered)
class TestSessionManagerHealth:
    """health_check on a manager backed by the in-memory store."""

    @pytest.mark.asyncio
    async def test_health_check(self, manager):
        healthy = await manager.health_check()
        assert healthy is True
class TestAsyncWrites:
    """Tests for async (non-blocking) write behaviour.

    The ``async_manager`` fixture does no teardown of its own. Every test that
    is not specifically exercising ``close()`` therefore shuts the manager down
    in a ``finally`` block. Otherwise a failing assertion would skip the
    shutdown and leave the background writer running.
    """

    @pytest.mark.asyncio
    async def test_append_message_returns_immediately(self, async_manager):
        """append_message returns the Message before it is persisted."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            msg = await async_manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="Hello",
            )
            # Message is returned immediately
            assert msg.role == MessageRole.USER
            assert msg.content == "Hello"
            # Drain the background writer, then verify persistence
            await async_manager.flush()
            persisted = await async_manager.store.get_messages(session.session_id)
            assert len(persisted) == 1
            assert persisted[0].content == "Hello"
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_get_chat_messages_includes_wal_buffered(self, async_manager):
        """get_chat_messages returns WAL-buffered messages not yet persisted."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            # Append a message — it may still be in the WAL buffer
            await async_manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="Buffered",
            )
            # The message must appear exactly once, whether it is still
            # buffered or already persisted. A looser ">= 1" check cannot
            # catch it being returned from both the WAL and the store.
            chat_msgs = await async_manager.get_chat_messages(session.session_id)
            assert [m["content"] for m in chat_msgs] == ["Buffered"]
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_flush_ensures_all_pending_writes(self, async_manager):
        """flush() waits until all queued writes are persisted."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            for i in range(5):
                await async_manager.append_message(
                    session_id=session.session_id,
                    role=MessageRole.USER,
                    content=f"Msg {i}",
                )
            await async_manager.flush()
            persisted = await async_manager.store.get_messages(session.session_id)
            assert len(persisted) == 5
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_rapid_appends_are_batched(self, async_manager):
        """Multiple rapid append_messages are all persisted correctly."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            # Fire off many messages rapidly
            for i in range(20):
                await async_manager.append_message(
                    session_id=session.session_id,
                    role=MessageRole.USER,
                    content=f"Rapid {i}",
                )
            await async_manager.flush()
            persisted = await async_manager.store.get_messages(session.session_id)
            assert len(persisted) == 20
            # 20 rows whose contents are the 20 distinct expected strings:
            # every message was written exactly once.
            assert {m.content for m in persisted} == {f"Rapid {i}" for i in range(20)}
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_session_close_flushes_pending_writes(self, async_manager):
        """Closing a session flushes pending writes first."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            await async_manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="Before close",
            )
            closed = await async_manager.close_session(session.session_id)
            assert closed.status == SessionStatus.CLOSED
            # Message should be persisted because close_session flushes
            persisted = await async_manager.store.get_messages(session.session_id)
            assert len(persisted) == 1
            assert persisted[0].content == "Before close"
        finally:
            await async_manager.close()

    @pytest.mark.asyncio
    async def test_manager_close_stops_writer(self, async_manager):
        """close() flushes and stops the background writer."""
        session = await async_manager.create_session(agent_name="agent1")
        await async_manager.append_message(
            session_id=session.session_id,
            role=MessageRole.USER,
            content="Final",
        )
        # close() is the action under test here, so it is not in a finally.
        await async_manager.close()
        # After close, the write queue should be stopped
        assert async_manager._write_queue is None or async_manager._write_queue._worker is None

    @pytest.mark.asyncio
    async def test_async_writes_disabled_by_default(self):
        """Without async_writes=True, writes are synchronous."""
        mgr = SessionManager(store=InMemorySessionStore())
        assert mgr._write_queue is None
        session = await mgr.create_session(agent_name="agent1")
        await mgr.append_message(
            session_id=session.session_id,
            role=MessageRole.USER,
            content="Sync",
        )
        # Should be immediately persisted (no flush needed)
        persisted = await mgr.store.get_messages(session.session_id)
        assert len(persisted) == 1

    @pytest.mark.asyncio
    async def test_get_messages_includes_wal_buffered(self, async_manager):
        """get_messages returns WAL-buffered messages not yet persisted."""
        try:
            session = await async_manager.create_session(agent_name="agent1")
            await async_manager.append_message(
                session_id=session.session_id,
                role=MessageRole.USER,
                content="WAL msg",
            )
            # Exactly once: neither lost nor returned from both WAL and store.
            messages = await async_manager.get_messages(session.session_id)
            assert [m.content for m in messages] == ["WAL msg"]
        finally:
            await async_manager.close()