183 lines
7.1 KiB
Python
183 lines
7.1 KiB
Python
"""Unit tests for SqliteConversationStore."""
|
|
|
|
from __future__ import annotations

from collections.abc import AsyncIterator
from pathlib import Path

import pytest

from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
|
|
|
|
|
@pytest.fixture
def db_path(tmp_path: Path) -> str:
    """Provide the location of a throwaway SQLite file for one test.

    The file name is placed inside pytest's ``tmp_path`` directory and the
    result is handed back as a plain string rather than a ``Path``.
    """
    database_file = tmp_path / "test_conversations.db"
    return str(database_file)
|
|
|
|
|
|
@pytest.fixture
async def store(db_path: str) -> AsyncIterator[SqliteConversationStore]:
    """Yield a SqliteConversationStore backed by the temporary database.

    The function is an async generator: everything before ``yield`` is setup,
    everything after it is teardown, so the annotation is an iterator of
    stores rather than a store.

    Args:
        db_path: Path of the SQLite file, supplied by the ``db_path`` fixture.

    Yields:
        The store instance the test should exercise.
    """
    s = SqliteConversationStore(db_path=db_path)
    yield s
    # Teardown: release the store's database handle once the test is done.
    await s._close_db()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Basic CRUD
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBasicCRUD:
    """Create, append and read operations against a freshly created store."""

    async def test_create_conversation(self, store: SqliteConversationStore) -> None:
        """A conversation created without an explicit id is fully initialised."""
        created = await store.get_or_create()

        assert created.id
        assert created.created_at is not None
        assert created.updated_at is not None
        assert created.messages == []

    async def test_create_conversation_with_id(self, store: SqliteConversationStore) -> None:
        """A caller-supplied id is used as the conversation id."""
        created = await store.get_or_create("my-conv-id")

        assert created.id == "my-conv-id"

    async def test_get_or_create_returns_existing(self, store: SqliteConversationStore) -> None:
        """Asking for an existing id returns that conversation with its message."""
        conversation_id = "reuse-id"
        await store.get_or_create(conversation_id)
        await store.add_message(conversation_id, "user", "hello")

        fetched_again = await store.get_or_create(conversation_id)

        assert fetched_again.id == "reuse-id"
        # The previously added message is expected to be visible on the
        # returned object (served from the in-memory cache, per the original
        # author's note).
        assert len(fetched_again.messages) == 1

    async def test_add_message(self, store: SqliteConversationStore) -> None:
        """add_message returns the stored message with empty metadata by default."""
        await store.get_or_create("msg-test")

        added = await store.add_message("msg-test", "user", "Hello world")

        assert added.role == "user"
        assert added.content == "Hello world"
        assert added.metadata == {}

    async def test_add_message_with_metadata(self, store: SqliteConversationStore) -> None:
        """Metadata passed to add_message is returned on the message."""
        await store.get_or_create("meta-test")

        added = await store.add_message("meta-test", "assistant", "Hi", {"key": "value"})

        assert added.metadata == {"key": "value"}

    async def test_add_message_nonexistent_conversation_raises(
        self, store: SqliteConversationStore
    ) -> None:
        """Adding to an unknown conversation raises KeyError mentioning 'not found'."""
        with pytest.raises(KeyError, match="not found"):
            await store.add_message("nonexistent", "user", "hello")

    async def test_get_history(self, store: SqliteConversationStore) -> None:
        """History comes back in the order the messages were added."""
        await store.get_or_create("hist-test")
        await store.add_message("hist-test", "user", "msg1")
        await store.add_message("hist-test", "assistant", "msg2")
        await store.add_message("hist-test", "user", "msg3")

        messages = await store.get_history("hist-test")

        assert len(messages) == 3
        assert [entry.content for entry in messages[:3]] == ["msg1", "msg2", "msg3"]

    async def test_get_history_with_limit(self, store: SqliteConversationStore) -> None:
        """A limit of 3 keeps only the three most recently added messages."""
        await store.get_or_create("limit-test")
        for index in range(10):
            await store.add_message("limit-test", "user", f"msg{index}")

        tail = await store.get_history("limit-test", limit=3)

        assert len(tail) == 3
        # Of msg0..msg9, only the final three should remain, oldest first.
        assert [entry.content for entry in tail[:3]] == ["msg7", "msg8", "msg9"]

    async def test_get_history_empty(self, store: SqliteConversationStore) -> None:
        """History of an unknown conversation is an empty list, not an error."""
        messages = await store.get_history("nonexistent")

        assert messages == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Persistence
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPersistence:
    """Data written through one store instance is readable through another.

    Each test builds its own stores from ``db_path`` instead of using the
    ``store`` fixture, because it needs two instances over the same file.
    Every store is closed in a ``finally`` block so a failing assertion or
    an exception does not leave the database handle open.
    """

    async def test_data_survives_store_recreation(self, db_path: str) -> None:
        """Create store, add data, create new store with same DB path, verify data survives."""
        store1 = SqliteConversationStore(db_path=db_path)
        try:
            await store1.get_or_create("persist-conv")
            await store1.add_message("persist-conv", "user", "persistent message")
        finally:
            await store1._close_db()

        # A second instance over the same file must see the first one's message.
        store2 = SqliteConversationStore(db_path=db_path)
        try:
            history = await store2.get_history("persist-conv")
            assert len(history) == 1
            assert history[0].content == "persistent message"
            assert history[0].role == "user"
        finally:
            await store2._close_db()

    async def test_conversations_survive_recreation(self, db_path: str) -> None:
        """Conversations created by one store are listed by a later store."""
        store1 = SqliteConversationStore(db_path=db_path)
        try:
            await store1.get_or_create("conv-a")
            await store1.get_or_create("conv-b")
        finally:
            await store1._close_db()

        store2 = SqliteConversationStore(db_path=db_path)
        try:
            convs = await store2.list_conversations()
            ids = {c.id for c in convs}
            assert "conv-a" in ids
            assert "conv-b" in ids
        finally:
            await store2._close_db()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_conversations
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestListConversations:
    """Ordering and limiting behaviour of ``list_conversations``."""

    async def test_list_conversations_ordering(self, store: SqliteConversationStore) -> None:
        """Most recently updated conversation should appear first."""
        for conversation_id in ("conv-first", "conv-second"):
            await store.get_or_create(conversation_id)

        # Touch "conv-first" after "conv-second" was created so that it
        # becomes the most recently updated of the two.
        # NOTE(review): this presumably relies on the store's timestamps being
        # fine-grained enough to separate back-to-back writes - confirm.
        await store.add_message("conv-first", "user", "update")

        listed = await store.list_conversations()

        assert len(listed) == 2
        newest = listed[0]
        assert newest.id == "conv-first"

    async def test_list_conversations_limit(self, store: SqliteConversationStore) -> None:
        """With five conversations stored, ``limit=3`` returns exactly three."""
        for index in range(5):
            await store.get_or_create(f"conv-{index}")

        listed = await store.list_conversations(limit=3)

        assert len(listed) == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# LRU cache eviction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestLRUCache:
    """Behaviour of the store's in-memory cache when it is over capacity."""

    async def test_cache_eviction(self, db_path: str) -> None:
        """Cache should evict oldest entries when over limit (data still in SQLite)."""
        # Named differently from the ``store`` fixture to avoid confusion;
        # this test needs its own instance with a small ``max_conversations``.
        small_store = SqliteConversationStore(db_path=db_path, max_conversations=3)
        try:
            for i in range(5):
                await small_store.get_or_create(f"evict-{i}")

            # Cache should have at most 3 entries
            assert len(small_store._cache) <= 3

            # But all 5 conversations should be in SQLite
            convs = await small_store.list_conversations(limit=10)
            assert len(convs) == 5
        finally:
            # Close even when an assertion above fails, so the database
            # handle is not left open.
            await small_store._close_db()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# restore_from_store (no-op)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRestoreFromStore:
    """Checks for ``restore_from_store`` on the SQLite-backed store."""

    async def test_restore_is_noop(self, store: SqliteConversationStore) -> None:
        """restore_from_store should be a no-op for SQLite store."""
        # The only expectation checked here is that the call completes
        # without raising; its return value is deliberately ignored.
        await store.restore_from_store()
|