459 lines
16 KiB
Python
459 lines
16 KiB
Python
"""Tests for unified EvolutionStoreProtocol compliance
|
|
|
|
Verifies that all backends implement the full Protocol interface:
|
|
- InMemoryEvolutionStore
|
|
- PersistentEvolutionStore
|
|
- PostgreSQLEvolutionStore (mocked async session)
|
|
- EvolutionStore (legacy, with NotImplementedError for skill_version/ab_test)
|
|
"""
|
|
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.protocol import EvolutionEvent
|
|
from agentkit.evolution.evolution_store import (
|
|
EvolutionStore,
|
|
EvolutionStoreProtocol,
|
|
InMemoryEvolutionStore,
|
|
PersistentEvolutionStore,
|
|
create_evolution_store,
|
|
)
|
|
|
|
|
|
# ── Fixtures ──────────────────────────────────────────────
|
|
|
|
|
|
@pytest.fixture
def sample_event():
    """Build one prompt-change EvolutionEvent for the ``test_agent`` agent."""
    previous_state = {"prompt": "old prompt"}
    updated_state = {"prompt": "new prompt"}
    event = EvolutionEvent(
        agent_name="test_agent",
        change_type="prompt",
        before=previous_state,
        after=updated_state,
        metrics={"accuracy": 0.9},
    )
    return event
|
|
|
|
|
|
@pytest.fixture
def memory_store():
    """Provide a newly constructed InMemoryEvolutionStore."""
    store = InMemoryEvolutionStore()
    return store
|
|
|
|
|
|
@pytest.fixture
def sqlite_store(tmp_path):
    """Provide a PersistentEvolutionStore whose database file lives under ``tmp_path``."""
    database_file = tmp_path / "test_unified.db"
    return PersistentEvolutionStore(db_path=str(database_file))
|
|
|
|
|
|
# ── Protocol compliance tests ─────────────────────────────
|
|
|
|
|
|
class TestProtocolCompliance:
    """Verify all stores implement EvolutionStoreProtocol.

    Each test only constructs a store and checks it with ``isinstance``
    against ``EvolutionStoreProtocol``; no store methods are called.
    """

    def test_inmemory_is_protocol(self):
        """InMemoryEvolutionStore passes the Protocol ``isinstance`` check."""
        assert isinstance(InMemoryEvolutionStore(), EvolutionStoreProtocol)

    def test_persistent_is_protocol(self, tmp_path):
        """PersistentEvolutionStore passes the Protocol ``isinstance`` check."""
        db_path = str(tmp_path / "protocol_check.db")
        assert isinstance(PersistentEvolutionStore(db_path=db_path), EvolutionStoreProtocol)

    def test_pg_store_is_protocol(self):
        """PostgreSQLEvolutionStore passes the Protocol ``isinstance`` check."""
        # Imported inside the test so the module itself does not import pg_store.
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        store = PostgreSQLEvolutionStore(database_url="postgresql+asyncpg://test:test@localhost/test")
        assert isinstance(store, EvolutionStoreProtocol)

    def test_legacy_evolution_store_is_protocol(self):
        """Legacy EvolutionStore also satisfies Protocol (has all method signatures)."""
        from unittest.mock import MagicMock

        # MagicMock stand-ins are enough: only the store's type is inspected.
        store = EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock())
        assert isinstance(store, EvolutionStoreProtocol)
|
|
|
|
|
|
# ── InMemoryEvolutionStore: full Protocol ─────────────────
|
|
|
|
|
|
class TestInMemoryFullProtocol:
    """Exercise every Protocol method against InMemoryEvolutionStore."""

    async def test_record_and_list_events(self, memory_store, sample_event):
        """A recorded event gets an id and is returned by ``list_events``."""
        recorded_id = await memory_store.record(sample_event)
        assert recorded_id is not None

        listed = await memory_store.list_events()
        assert len(listed) == 1
        first = listed[0]
        assert first["agent_name"] == "test_agent"
        assert first["change_type"] == "prompt"

    async def test_rollback(self, memory_store, sample_event):
        """Rolling back a recorded event returns True and marks it ``rolled_back``."""
        recorded_id = await memory_store.record(sample_event)
        outcome = await memory_store.rollback(recorded_id)
        assert outcome is True

        listed = await memory_store.list_events()
        assert listed[0]["status"] == "rolled_back"

    async def test_rollback_nonexistent(self, memory_store):
        """Rolling back an unknown event id returns False."""
        outcome = await memory_store.rollback("nonexistent")
        assert outcome is False

    async def test_list_events_with_filters(self, memory_store):
        """``list_events(agent_name=...)`` returns only that agent's events."""
        for agent, change in (("a", "prompt"), ("b", "strategy")):
            await memory_store.record(
                EvolutionEvent(agent_name=agent, change_type=change, before={}, after={})
            )

        matching = await memory_store.list_events(agent_name="a")
        assert len(matching) == 1
        assert matching[0]["agent_name"] == "a"

    async def test_record_and_list_skill_version(self, memory_store):
        """A recorded skill version is listed with its version and content."""
        version_id = await memory_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        assert version_id is not None

        listed = await memory_store.list_skill_versions("search")
        assert len(listed) == 1
        only = listed[0]
        assert only["version"] == "v1"
        assert only["content"] == '{"prompt": "v1"}'

    async def test_skill_version_with_parent(self, memory_store):
        """The second recorded version is listed first and keeps its parent link."""
        await memory_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        await memory_store.record_skill_version(
            "search", "v2", '{"prompt": "v2"}', parent_version="v1"
        )

        listed = await memory_store.list_skill_versions("search")
        assert len(listed) == 2
        newest = listed[0]
        assert newest["version"] == "v2"
        assert newest["parent_version"] == "v1"

    async def test_record_and_get_ab_test_result(self, memory_store):
        """A recorded A/B result is returned with variant, score and sample count."""
        result_id = await memory_store.record_ab_test_result("t1", "control", 0.8, 5)
        assert result_id is not None

        fetched = await memory_store.get_ab_test_results("t1")
        assert len(fetched) == 1
        only = fetched[0]
        assert only["variant"] == "control"
        assert only["score"] == 0.8
        assert only["sample_count"] == 5

    async def test_ab_test_multiple_variants(self, memory_store):
        """Two variants recorded under one test id are both returned."""
        await memory_store.record_ab_test_result("t1", "control", 0.8, 10)
        await memory_store.record_ab_test_result("t1", "experiment", 0.9, 10)

        fetched = await memory_store.get_ab_test_results("t1")
        assert len(fetched) == 2

    async def test_list_skill_versions_empty(self, memory_store):
        """Listing versions of an unknown skill yields an empty list."""
        listed = await memory_store.list_skill_versions("nonexistent")
        assert listed == []

    async def test_get_ab_test_results_empty(self, memory_store):
        """Fetching results of an unknown A/B test yields an empty list."""
        fetched = await memory_store.get_ab_test_results("nonexistent")
        assert fetched == []
|
|
|
|
|
|
# ── PersistentEvolutionStore: full Protocol ───────────────
|
|
|
|
|
|
class TestSQLiteFullProtocol:
    """Exercise every Protocol method against PersistentEvolutionStore."""

    async def test_record_and_list_events(self, sqlite_store, sample_event):
        """A recorded event gets an id and is returned by ``list_events``."""
        recorded_id = await sqlite_store.record(sample_event)
        assert recorded_id is not None

        listed = await sqlite_store.list_events()
        assert len(listed) == 1
        assert listed[0]["agent_name"] == "test_agent"

    async def test_rollback(self, sqlite_store, sample_event):
        """Rolling back a recorded event returns True and marks it ``rolled_back``."""
        recorded_id = await sqlite_store.record(sample_event)
        outcome = await sqlite_store.rollback(recorded_id)
        assert outcome is True

        listed = await sqlite_store.list_events()
        assert listed[0]["status"] == "rolled_back"

    async def test_record_and_list_skill_version(self, sqlite_store):
        """A recorded skill version is listed with its version label."""
        version_id = await sqlite_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        assert version_id is not None

        listed = await sqlite_store.list_skill_versions("search")
        assert len(listed) == 1
        assert listed[0]["version"] == "v1"

    async def test_skill_version_with_parent(self, sqlite_store):
        """The second recorded version is listed first and keeps its parent link."""
        await sqlite_store.record_skill_version("search", "v1", '{"prompt": "v1"}')
        await sqlite_store.record_skill_version(
            "search", "v2", '{"prompt": "v2"}', parent_version="v1"
        )

        listed = await sqlite_store.list_skill_versions("search")
        assert len(listed) == 2
        newest = listed[0]
        assert newest["version"] == "v2"
        assert newest["parent_version"] == "v1"

    async def test_record_and_get_ab_test_result(self, sqlite_store):
        """A recorded A/B result is returned with its variant name."""
        result_id = await sqlite_store.record_ab_test_result("t1", "control", 0.8, 5)
        assert result_id is not None

        fetched = await sqlite_store.get_ab_test_results("t1")
        assert len(fetched) == 1
        assert fetched[0]["variant"] == "control"

    async def test_ab_test_multiple_variants(self, sqlite_store):
        """Two variants recorded under one test id are both returned."""
        await sqlite_store.record_ab_test_result("t1", "control", 0.8, 10)
        await sqlite_store.record_ab_test_result("t1", "experiment", 0.9, 10)

        fetched = await sqlite_store.get_ab_test_results("t1")
        assert len(fetched) == 2

    async def test_list_skill_versions_empty(self, sqlite_store):
        """Listing versions of an unknown skill yields an empty list."""
        listed = await sqlite_store.list_skill_versions("nonexistent")
        assert listed == []

    async def test_get_ab_test_results_empty(self, sqlite_store):
        """Fetching results of an unknown A/B test yields an empty list."""
        fetched = await sqlite_store.get_ab_test_results("nonexistent")
        assert fetched == []
|
|
|
|
|
|
# ── PostgreSQLEvolutionStore: mocked Protocol ─────────────
|
|
|
|
|
|
class TestPGStoreMocked:
    """Test PostgreSQLEvolutionStore with mocked async session.

    Since we can't require a running PostgreSQL in unit tests,
    we mock the async session to verify the logic paths.
    """

    def _make_pg_store(self):
        """Construct a PostgreSQLEvolutionStore pointing at a placeholder URL."""
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        return PostgreSQLEvolutionStore(database_url="postgresql+asyncpg://test:test@localhost/test")

    def _make_mocked_store(self, execute_result=None):
        """Return ``(store, mock_session)`` with the store wired to a mock session.

        Args:
            execute_result: When given, ``mock_session.execute`` is replaced by
                an ``AsyncMock`` that returns this object; otherwise ``execute``
                is left as the default ``AsyncMock`` attribute.

        Returns:
            A tuple of the store and the mock session its session factory yields.
        """
        from contextlib import asynccontextmanager
        from unittest.mock import AsyncMock, MagicMock

        mock_session = AsyncMock()
        # A plain MagicMock so ``add`` can be called without being awaited.
        mock_session.add = MagicMock()
        mock_session.commit = AsyncMock()
        mock_session.rollback = AsyncMock()
        if execute_result is not None:
            mock_session.execute = AsyncMock(return_value=execute_result)

        @asynccontextmanager
        async def mock_sf():
            yield mock_session

        store = self._make_pg_store()
        # Swap in the mock factory and flag the store as already initialized,
        # exactly as each test previously did inline.
        store._session_factory = mock_sf
        store._initialized = True
        return store, mock_session

    async def test_record_with_mock_session(self, sample_event):
        """``record`` returns the event id, stamps it on the event, adds and commits once."""
        store, mock_session = self._make_mocked_store()

        event_id = await store.record(sample_event)

        assert event_id is not None
        assert sample_event.event_id == event_id
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()

    async def test_rollback_with_mock_session(self):
        """``rollback`` flips a found entry to ``rolled_back`` and commits once."""
        from unittest.mock import MagicMock

        # Entry handed back by the mocked query result.
        mock_entry = MagicMock()
        mock_entry.status = "active"

        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = mock_entry

        store, mock_session = self._make_mocked_store(execute_result=mock_result)

        result = await store.rollback("test-event-id")

        assert result is True
        assert mock_entry.status == "rolled_back"
        mock_session.commit.assert_called_once()

    async def test_rollback_not_found(self):
        """``rollback`` returns False when the query result holds no entry."""
        from unittest.mock import MagicMock

        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None

        store, _ = self._make_mocked_store(execute_result=mock_result)

        result = await store.rollback("nonexistent")

        assert result is False

    async def test_record_skill_version_with_mock(self):
        """``record_skill_version`` returns an id and adds and commits once."""
        store, mock_session = self._make_mocked_store()

        vid = await store.record_skill_version("search", "v1", '{"prompt": "v1"}')

        assert vid is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()

    async def test_record_ab_test_result_with_mock(self):
        """``record_ab_test_result`` returns an id and adds and commits once."""
        store, mock_session = self._make_mocked_store()

        rid = await store.record_ab_test_result("t1", "control", 0.8, 5)

        assert rid is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
|
|
|
|
|
|
# ── Legacy EvolutionStore: NotImplementedError tests ──────
|
|
|
|
|
|
class TestLegacyEvolutionStoreStubs:
    """The legacy EvolutionStore rejects skill-version and A/B-test calls with NotImplementedError."""

    @staticmethod
    def _legacy_store():
        """Build a legacy EvolutionStore backed entirely by MagicMock collaborators."""
        from unittest.mock import MagicMock

        return EvolutionStore(session_factory=MagicMock(), evolution_model=MagicMock())

    async def test_record_skill_version_raises(self):
        """``record_skill_version`` raises with a message mentioning ``skill_version``."""
        legacy = self._legacy_store()
        with pytest.raises(NotImplementedError, match="skill_version"):
            await legacy.record_skill_version("s", "v1", "content")

    async def test_list_skill_versions_raises(self):
        """``list_skill_versions`` raises with a message mentioning ``skill_version``."""
        legacy = self._legacy_store()
        with pytest.raises(NotImplementedError, match="skill_version"):
            await legacy.list_skill_versions("s")

    async def test_record_ab_test_result_raises(self):
        """``record_ab_test_result`` raises with a message mentioning ``A/B test``."""
        legacy = self._legacy_store()
        with pytest.raises(NotImplementedError, match="A/B test"):
            await legacy.record_ab_test_result("t1", "control", 0.8)

    async def test_get_ab_test_results_raises(self):
        """``get_ab_test_results`` raises with a message mentioning ``A/B test``."""
        legacy = self._legacy_store()
        with pytest.raises(NotImplementedError, match="A/B test"):
            await legacy.get_ab_test_results("t1")
|
|
|
|
|
|
# ── Factory tests ─────────────────────────────────────────
|
|
|
|
|
|
class TestCreateEvolutionStoreExtended:
    """Check which store type ``create_evolution_store`` returns for each backend argument."""

    def test_create_memory_backend(self):
        """``backend="memory"`` yields an InMemoryEvolutionStore."""
        store = create_evolution_store(backend="memory")
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_sqlite_backend(self, tmp_path):
        """``backend="sqlite"`` with a ``db_path`` yields a PersistentEvolutionStore."""
        db_path = str(tmp_path / "factory_test.db")
        store = create_evolution_store(backend="sqlite", db_path=db_path)
        assert isinstance(store, PersistentEvolutionStore)

    def test_create_default_backend(self):
        """Calling the factory with no arguments yields an InMemoryEvolutionStore."""
        store = create_evolution_store()
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_sql_backend_without_params_falls_back(self):
        """``backend="sql"`` without further parameters falls back to memory."""
        store = create_evolution_store(backend="sql")
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_postgresql_without_url_falls_back(self):
        """PostgreSQL backend without database_url falls back to memory."""
        from unittest.mock import patch

        # patch.dict snapshots os.environ and restores it on exit, so the
        # variable can simply be dropped for the duration of the block.
        with patch.dict(os.environ):
            os.environ.pop("AGENTKIT_DATABASE_URL", None)
            store = create_evolution_store(backend="postgresql")
        assert isinstance(store, InMemoryEvolutionStore)

    def test_create_postgresql_with_url(self):
        """PostgreSQL backend with database_url returns PostgreSQLEvolutionStore."""
        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        store = create_evolution_store(
            backend="postgresql",
            database_url="postgresql+asyncpg://user:pass@localhost/db",
        )
        assert isinstance(store, PostgreSQLEvolutionStore)

    def test_create_postgresql_with_env_url(self):
        """PostgreSQL backend reads database_url from environment variable."""
        from unittest.mock import patch

        from agentkit.evolution.pg_store import PostgreSQLEvolutionStore

        env_url = "postgresql+asyncpg://user:pass@localhost/db"
        # The variable is set only inside the block; any previous value (or
        # its absence) is restored when the context manager exits.
        with patch.dict(os.environ, {"AGENTKIT_DATABASE_URL": env_url}):
            store = create_evolution_store(backend="postgresql")
        assert isinstance(store, PostgreSQLEvolutionStore)
|