# fischer-agentkit/tests/unit/test_unified_evolution_stor...

# 459 lines
# 16 KiB
# Python

"""Tests for unified EvolutionStoreProtocol compliance
Verifies that all backends implement the full Protocol interface:
- InMemoryEvolutionStore
- PersistentEvolutionStore
- PostgreSQLEvolutionStore (mocked async session)
- EvolutionStore (legacy, with NotImplementedError for skill_version/ab_test)
"""
import os
import tempfile
import pytest
from agentkit.core.protocol import EvolutionEvent
from agentkit.evolution.evolution_store import (
EvolutionStore,
EvolutionStoreProtocol,
InMemoryEvolutionStore,
PersistentEvolutionStore,
create_evolution_store,
)
# ── Fixtures ──────────────────────────────────────────────
@pytest.fixture
def sample_event():
    """Provide a prompt-change EvolutionEvent for the ``test_agent`` agent."""
    fields = {
        "agent_name": "test_agent",
        "change_type": "prompt",
        "before": {"prompt": "old prompt"},
        "after": {"prompt": "new prompt"},
        "metrics": {"accuracy": 0.9},
    }
    return EvolutionEvent(**fields)
@pytest.fixture
def memory_store():
    """Provide a newly constructed InMemoryEvolutionStore."""
    store = InMemoryEvolutionStore()
    return store
@pytest.fixture
def sqlite_store(tmp_path):
    """Provide a PersistentEvolutionStore whose database file is under ``tmp_path``."""
    database_file = tmp_path / "test_unified.db"
    return PersistentEvolutionStore(db_path=str(database_file))
# ── Protocol compliance tests ─────────────────────────────
class TestProtocolCompliance:
    """Verify all stores implement EvolutionStoreProtocol.

    NOTE(review): assuming EvolutionStoreProtocol is a ``typing.Protocol``
    decorated with ``runtime_checkable``, these ``isinstance`` checks only
    confirm that the required members exist; they do not compare signatures.
    """

    def test_inmemory_is_protocol(self):
        """InMemoryEvolutionStore passes the isinstance check."""
        assert isinstance(InMemoryEvolutionStore(), EvolutionStoreProtocol)

    def test_persistent_is_protocol(self, tmp_path):
        """PersistentEvolutionStore passes the isinstance check."""
        db_path = str(tmp_path / "protocol_check.db")
        assert isinstance(PersistentEvolutionStore(db_path=db_path), EvolutionStoreProtocol)

    def test_pg_store_is_protocol(self):
        """PostgreSQLEvolutionStore passes the isinstance check."""
        # Kept local to the test, as in the rest of this module; presumably so
        # that importing the module does not require pg_store - TODO confirm.
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        store = PostgreSQLEvolutionStore(
            database_url="postgresql+asyncpg://test:test@localhost/test"
        )
        assert isinstance(store, EvolutionStoreProtocol)

    def test_legacy_evolution_store_is_protocol(self):
        """Legacy EvolutionStore also passes the Protocol isinstance check."""
        from unittest.mock import MagicMock

        store = EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock())
        assert isinstance(store, EvolutionStoreProtocol)
# ── InMemoryEvolutionStore: full Protocol ─────────────────
class TestInMemoryFullProtocol:
    """Exercise every Protocol method against InMemoryEvolutionStore."""

    async def test_record_and_list_events(self, memory_store, sample_event):
        """A recorded event gets an id and shows up in list_events()."""
        recorded_id = await memory_store.record(sample_event)
        assert recorded_id is not None

        listed = await memory_store.list_events()
        assert len(listed) == 1
        entry = listed[0]
        assert entry["agent_name"] == "test_agent"
        assert entry["change_type"] == "prompt"

    async def test_rollback(self, memory_store, sample_event):
        """rollback() of a recorded event returns True and flips its status."""
        recorded_id = await memory_store.record(sample_event)

        outcome = await memory_store.rollback(recorded_id)
        assert outcome is True

        listed = await memory_store.list_events()
        assert listed[0]["status"] == "rolled_back"

    async def test_rollback_nonexistent(self, memory_store):
        """rollback() of an unknown id returns False."""
        outcome = await memory_store.rollback("nonexistent")
        assert outcome is False

    async def test_list_events_with_filters(self, memory_store):
        """list_events(agent_name=...) returns only that agent's events."""
        for name, kind in (("a", "prompt"), ("b", "strategy")):
            await memory_store.record(
                EvolutionEvent(agent_name=name, change_type=kind, before={}, after={})
            )

        listed = await memory_store.list_events(agent_name="a")
        assert len(listed) == 1
        assert listed[0]["agent_name"] == "a"

    async def test_record_and_list_skill_version(self, memory_store):
        """A recorded skill version is listed with its version and content."""
        version_id = await memory_store.record_skill_version(
            "search", "v1", '{"prompt": "v1"}'
        )
        assert version_id is not None

        listed = await memory_store.list_skill_versions("search")
        assert len(listed) == 1
        entry = listed[0]
        assert entry["version"] == "v1"
        assert entry["content"] == '{"prompt": "v1"}'

    async def test_skill_version_with_parent(self, memory_store):
        """The second recorded version is listed first and keeps its parent link."""
        await memory_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        await memory_store.record_skill_version(
            "search", "v2", '{"prompt": "v2"}', parent_version="v1"
        )

        listed = await memory_store.list_skill_versions("search")
        assert len(listed) == 2
        newest = listed[0]
        assert newest["version"] == "v2"
        assert newest["parent_version"] == "v1"

    async def test_record_and_get_ab_test_result(self, memory_store):
        """A recorded A/B result is returned with variant, score and sample count."""
        result_id = await memory_store.record_ab_test_result("t1", "control", 0.8, 5)
        assert result_id is not None

        fetched = await memory_store.get_ab_test_results("t1")
        assert len(fetched) == 1
        entry = fetched[0]
        assert entry["variant"] == "control"
        assert entry["score"] == 0.8
        assert entry["sample_count"] == 5

    async def test_ab_test_multiple_variants(self, memory_store):
        """Two variants recorded under one test id are both returned."""
        for variant, score in (("control", 0.8), ("experiment", 0.9)):
            await memory_store.record_ab_test_result("t1", variant, score, 10)

        fetched = await memory_store.get_ab_test_results("t1")
        assert len(fetched) == 2

    async def test_list_skill_versions_empty(self, memory_store):
        """An unknown skill name yields an empty list."""
        listed = await memory_store.list_skill_versions("nonexistent")
        assert listed == []

    async def test_get_ab_test_results_empty(self, memory_store):
        """An unknown test id yields an empty list."""
        fetched = await memory_store.get_ab_test_results("nonexistent")
        assert fetched == []
# ── PersistentEvolutionStore: full Protocol ───────────────
class TestSQLiteFullProtocol:
    """Exercise every Protocol method against PersistentEvolutionStore."""

    async def test_record_and_list_events(self, sqlite_store, sample_event):
        """A recorded event gets an id and shows up in list_events()."""
        recorded_id = await sqlite_store.record(sample_event)
        assert recorded_id is not None

        listed = await sqlite_store.list_events()
        assert len(listed) == 1
        assert listed[0]["agent_name"] == "test_agent"

    async def test_rollback(self, sqlite_store, sample_event):
        """rollback() of a recorded event returns True and flips its status."""
        recorded_id = await sqlite_store.record(sample_event)

        outcome = await sqlite_store.rollback(recorded_id)
        assert outcome is True

        listed = await sqlite_store.list_events()
        assert listed[0]["status"] == "rolled_back"

    async def test_record_and_list_skill_version(self, sqlite_store):
        """A recorded skill version is listed with its version string."""
        version_id = await sqlite_store.record_skill_version(
            "search", "v1", '{"prompt": "v1"}'
        )
        assert version_id is not None

        listed = await sqlite_store.list_skill_versions("search")
        assert len(listed) == 1
        assert listed[0]["version"] == "v1"

    async def test_skill_version_with_parent(self, sqlite_store):
        """The second recorded version is listed first and keeps its parent link."""
        await sqlite_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        await sqlite_store.record_skill_version(
            "search", "v2", '{"prompt": "v2"}', parent_version="v1"
        )

        listed = await sqlite_store.list_skill_versions("search")
        assert len(listed) == 2
        newest = listed[0]
        assert newest["version"] == "v2"
        assert newest["parent_version"] == "v1"

    async def test_record_and_get_ab_test_result(self, sqlite_store):
        """A recorded A/B result is returned with its variant."""
        result_id = await sqlite_store.record_ab_test_result("t1", "control", 0.8, 5)
        assert result_id is not None

        fetched = await sqlite_store.get_ab_test_results("t1")
        assert len(fetched) == 1
        assert fetched[0]["variant"] == "control"

    async def test_ab_test_multiple_variants(self, sqlite_store):
        """Two variants recorded under one test id are both returned."""
        for variant, score in (("control", 0.8), ("experiment", 0.9)):
            await sqlite_store.record_ab_test_result("t1", variant, score, 10)

        fetched = await sqlite_store.get_ab_test_results("t1")
        assert len(fetched) == 2

    async def test_list_skill_versions_empty(self, sqlite_store):
        """An unknown skill name yields an empty list."""
        listed = await sqlite_store.list_skill_versions("nonexistent")
        assert listed == []

    async def test_get_ab_test_results_empty(self, sqlite_store):
        """An unknown test id yields an empty list."""
        fetched = await sqlite_store.get_ab_test_results("nonexistent")
        assert fetched == []
# ── PostgreSQLEvolutionStore: mocked Protocol ─────────────
class TestPGStoreMocked:
    """Test PostgreSQLEvolutionStore with mocked async session.

    Since we can't require a running PostgreSQL in unit tests,
    we mock the async session to verify the logic paths.
    """

    def _make_pg_store(self):
        """Build a PostgreSQLEvolutionStore pointing at a placeholder URL."""
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        return PostgreSQLEvolutionStore(
            database_url="postgresql+asyncpg://test:test@localhost/test"
        )

    def _make_store_with_session(self, execute_result=None):
        """Build a store whose session factory yields a mock session.

        Args:
            execute_result: When given, ``mock_session.execute`` is replaced by
                an AsyncMock that returns this object.

        Returns:
            A ``(store, mock_session)`` tuple; the store is flagged as already
            initialized and every ``_session_factory()`` context yields the
            same ``mock_session``.
        """
        from contextlib import asynccontextmanager
        from unittest.mock import AsyncMock, MagicMock

        mock_session = AsyncMock()
        # ``add`` is a plain MagicMock: the tests assert on it as a regular
        # (non-awaited) call.
        mock_session.add = MagicMock()
        mock_session.commit = AsyncMock()
        mock_session.rollback = AsyncMock()
        if execute_result is not None:
            mock_session.execute = AsyncMock(return_value=execute_result)

        @asynccontextmanager
        async def mock_sf():
            yield mock_session

        store = self._make_pg_store()
        # Swap in the mocked factory and mark the store initialized so that no
        # real connection setup is attempted (private attributes of the store).
        store._session_factory = mock_sf
        store._initialized = True
        return store, mock_session

    async def test_record_with_mock_session(self, sample_event):
        """record() returns the event id, stores it on the event, adds and commits once."""
        store, mock_session = self._make_store_with_session()

        event_id = await store.record(sample_event)

        assert event_id is not None
        assert sample_event.event_id == event_id
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()

    async def test_rollback_with_mock_session(self):
        """rollback() marks the fetched entry as rolled back and commits once."""
        from unittest.mock import MagicMock

        # Entry that the mocked query result hands back to the store.
        mock_entry = MagicMock()
        mock_entry.status = "active"
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = mock_entry
        store, mock_session = self._make_store_with_session(execute_result=mock_result)

        result = await store.rollback("test-event-id")

        assert result is True
        assert mock_entry.status == "rolled_back"
        mock_session.commit.assert_called_once()

    async def test_rollback_not_found(self):
        """rollback() returns False when the query finds no entry."""
        from unittest.mock import MagicMock

        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        store, _ = self._make_store_with_session(execute_result=mock_result)

        result = await store.rollback("nonexistent")

        assert result is False

    async def test_record_skill_version_with_mock(self):
        """record_skill_version() returns an id, adds and commits once."""
        store, mock_session = self._make_store_with_session()

        vid = await store.record_skill_version("search", "v1", '{"prompt": "v1"}')

        assert vid is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()

    async def test_record_ab_test_result_with_mock(self):
        """record_ab_test_result() returns an id, adds and commits once."""
        store, mock_session = self._make_store_with_session()

        rid = await store.record_ab_test_result("t1", "control", 0.8, 5)

        assert rid is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
# ── Legacy EvolutionStore: NotImplementedError tests ──────
class TestLegacyEvolutionStoreStubs:
    """Legacy EvolutionStore rejects skill-version and A/B-test calls."""

    @staticmethod
    def _legacy_store():
        """Build a legacy EvolutionStore backed entirely by MagicMock collaborators."""
        from unittest.mock import MagicMock

        return EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock())

    async def test_record_skill_version_raises(self):
        """record_skill_version() raises NotImplementedError mentioning skill_version."""
        legacy = self._legacy_store()
        with pytest.raises(NotImplementedError, match="skill_version"):
            await legacy.record_skill_version("s", "v1", "content")

    async def test_list_skill_versions_raises(self):
        """list_skill_versions() raises NotImplementedError mentioning skill_version."""
        legacy = self._legacy_store()
        with pytest.raises(NotImplementedError, match="skill_version"):
            await legacy.list_skill_versions("s")

    async def test_record_ab_test_result_raises(self):
        """record_ab_test_result() raises NotImplementedError mentioning A/B test."""
        legacy = self._legacy_store()
        with pytest.raises(NotImplementedError, match="A/B test"):
            await legacy.record_ab_test_result("t1", "control", 0.8)

    async def test_get_ab_test_results_raises(self):
        """get_ab_test_results() raises NotImplementedError mentioning A/B test."""
        legacy = self._legacy_store()
        with pytest.raises(NotImplementedError, match="A/B test"):
            await legacy.get_ab_test_results("t1")
# ── Factory tests ─────────────────────────────────────────
class TestCreateEvolutionStoreExtended:
    """Check which store class create_evolution_store() returns per backend."""

    def test_create_memory_backend(self):
        """backend="memory" yields an InMemoryEvolutionStore."""
        store = create_evolution_store(backend="memory")
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_sqlite_backend(self, tmp_path):
        """backend="sqlite" with a db_path yields a PersistentEvolutionStore."""
        db_path = str(tmp_path / "factory_test.db")
        store = create_evolution_store(backend="sqlite", db_path=db_path)
        assert isinstance(store, PersistentEvolutionStore)

    def test_create_default_backend(self):
        """Calling the factory without arguments yields an InMemoryEvolutionStore."""
        store = create_evolution_store()
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_sql_backend_without_params_falls_back(self):
        """backend="sql" without further parameters falls back to memory."""
        store = create_evolution_store(backend="sql")
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_postgresql_without_url_falls_back(self, monkeypatch):
        """PostgreSQL backend without database_url falls back to memory."""
        # monkeypatch restores the variable after the test, replacing the
        # manual save/try/finally dance; raising=False tolerates it being unset.
        monkeypatch.delenv("AGENTKIT_DATABASE_URL", raising=False)

        store = create_evolution_store(backend="postgresql")
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_postgresql_with_url(self):
        """PostgreSQL backend with database_url returns PostgreSQLEvolutionStore."""
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        store = create_evolution_store(
            backend="postgresql",
            database_url="postgresql+asyncpg://user:pass@localhost/db",
        )
        assert isinstance(store, PostgreSQLEvolutionStore)

    def test_create_postgresql_with_env_url(self, monkeypatch):
        """PostgreSQL backend reads database_url from environment variable."""
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        # The previous value (or absence) is restored automatically at teardown.
        monkeypatch.setenv(
            "AGENTKIT_DATABASE_URL", "postgresql+asyncpg://user:pass@localhost/db"
        )

        store = create_evolution_store(backend="postgresql")
        assert isinstance(store, PostgreSQLEvolutionStore)