"""Tests for EvolutionStore - evolution event recording and rollback""" import uuid from datetime import datetime, timezone from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, patch import pytest from agentkit.core.protocol import EvolutionEvent from agentkit.evolution.evolution_store import EvolutionStore # ── Mock helpers ────────────────────────────────────────── def _make_entry( id: uuid.UUID | None = None, agent_name: str = "test_agent", change_type: str = "prompt", before: dict | None = None, after: dict | None = None, metrics: dict | None = None, status: str = "active", created_at: datetime | None = None, ): """Create a mock DB entry object.""" entry = MagicMock() entry.id = id or uuid.uuid4() entry.agent_name = agent_name entry.change_type = change_type entry.before = before or {} entry.after = after or {} entry.metrics = metrics entry.status = status entry.created_at = created_at or datetime.now(timezone.utc) return entry def _make_model(): """Create a mock evolution model class. The model class is used like: Model(id=..., agent_name=..., ...) and also as: Model.id, Model.agent_name, etc. in SQLAlchemy select().where(). """ Model = MagicMock() def _init(*args, **kwargs): instance = MagicMock() instance.id = kwargs.get("id", uuid.uuid4()) instance.agent_name = kwargs.get("agent_name", "test_agent") instance.change_type = kwargs.get("change_type", "prompt") instance.before = kwargs.get("before", {}) instance.after = kwargs.get("after", {}) instance.metrics = kwargs.get("metrics") instance.status = kwargs.get("status", "active") instance.created_at = kwargs.get("created_at", datetime.now(timezone.utc)) return instance Model.side_effect = _init return Model def _make_select_mock(): """Create a mock for sqlalchemy.select that supports .where()/.order_by() chaining.""" stmt = MagicMock() stmt.where.return_value = stmt stmt.order_by.return_value = stmt mock_select = MagicMock(return_value=stmt) return mock_select, stmt class SessionCapture: """Helper that captures the session created by the session factory.""" def __init__(self): self.sessions = [] @property def last(self): return self.sessions[-1] if self.sessions else None def _make_execute_result(scalar_one_or_none_val=None, scalars_all_val=None): """Create a mock SQLAlchemy result object. The result from db.execute() has sync methods (scalar_one_or_none, scalars), so we use MagicMock (not AsyncMock) for the result itself. """ result = MagicMock() result.scalar_one_or_none.return_value = scalar_one_or_none_val mock_scalars = MagicMock() mock_scalars.all.return_value = scalars_all_val or [] result.scalars.return_value = mock_scalars return result def _make_session_factory( capture: SessionCapture | None = None, execute_result=None, commit_side_effect=None, ): """Create a mock async session factory. Returns a callable that works as an async context manager producing a session. """ @asynccontextmanager async def _factory(): session = AsyncMock() session.add = MagicMock() if commit_side_effect: session.commit.side_effect = commit_side_effect else: session.commit = AsyncMock() session.rollback = AsyncMock() session.refresh = AsyncMock() if execute_result is not None: session.execute.return_value = execute_result else: default_result = _make_execute_result() session.execute.return_value = default_result if capture is not None: capture.sessions.append(session) yield session return _factory # ── Fixtures ────────────────────────────────────────────── @pytest.fixture def sample_event(): """A sample EvolutionEvent.""" return EvolutionEvent( agent_name="test_agent", change_type="prompt", before={"prompt": "old prompt"}, after={"prompt": "new prompt"}, metrics={"accuracy": 0.9}, ) # ── record() tests ─────────────────────────────────────── class TestRecord: async def test_record_returns_event_id(self, sample_event): Model = _make_model() capture = SessionCapture() sf = _make_session_factory(capture=capture) store = EvolutionStore(session_factory=sf, evolution_model=Model) event_id = await store.record(sample_event) assert event_id is not None uuid.UUID(event_id) # should be a valid UUID string async def test_record_sets_event_id_on_event(self, sample_event): Model = _make_model() capture = SessionCapture() sf = _make_session_factory(capture=capture) store = EvolutionStore(session_factory=sf, evolution_model=Model) assert sample_event.event_id is None await store.record(sample_event) assert sample_event.event_id is not None async def test_record_creates_model_instance_with_correct_fields(self, sample_event): Model = _make_model() capture = SessionCapture() sf = _make_session_factory(capture=capture) store = EvolutionStore(session_factory=sf, evolution_model=Model) await store.record(sample_event) Model.assert_called_once() call_kwargs = Model.call_args[1] assert call_kwargs["agent_name"] == "test_agent" assert call_kwargs["change_type"] == "prompt" assert call_kwargs["before"] == {"prompt": "old prompt"} assert call_kwargs["after"] == {"prompt": "new prompt"} assert call_kwargs["metrics"] == {"accuracy": 0.9} assert call_kwargs["status"] == "active" async def test_record_calls_db_add_and_commit(self, sample_event): Model = _make_model() capture = SessionCapture() sf = _make_session_factory(capture=capture) store = EvolutionStore(session_factory=sf, evolution_model=Model) await store.record(sample_event) session = capture.last session.add.assert_called() session.commit.assert_called() async def test_record_rollback_on_error(self, sample_event): Model = _make_model() capture = SessionCapture() sf = _make_session_factory(capture=capture, commit_side_effect=RuntimeError("db error")) store = EvolutionStore(session_factory=sf, evolution_model=Model) with pytest.raises(RuntimeError, match="db error"): await store.record(sample_event) session = capture.last session.rollback.assert_called() # ── rollback() tests ────────────────────────────────────── class TestRollback: async def test_rollback_success(self): Model = _make_model() entry_id = uuid.uuid4() mock_entry = _make_entry(id=entry_id, status="active") mock_result = _make_execute_result(scalar_one_or_none_val=mock_entry) capture = SessionCapture() sf = _make_session_factory(capture=capture, execute_result=mock_result) store = EvolutionStore(session_factory=sf, evolution_model=Model) mock_select, _ = _make_select_mock() with patch("sqlalchemy.select", mock_select): result = await store.rollback(str(entry_id)) assert result is True assert mock_entry.status == "rolled_back" capture.last.commit.assert_called() async def test_rollback_not_found(self): Model = _make_model() mock_result = _make_execute_result(scalar_one_or_none_val=None) capture = SessionCapture() sf = _make_session_factory(capture=capture, execute_result=mock_result) store = EvolutionStore(session_factory=sf, evolution_model=Model) mock_select, _ = _make_select_mock() with patch("sqlalchemy.select", mock_select): result = await store.rollback(str(uuid.uuid4())) assert result is False async def test_rollback_returns_false_on_error(self): Model = _make_model() @asynccontextmanager async def bad_sf(): session = AsyncMock() session.execute.side_effect = RuntimeError("connection lost") session.rollback = AsyncMock() yield session store = EvolutionStore(session_factory=bad_sf, evolution_model=Model) mock_select, _ = _make_select_mock() with patch("sqlalchemy.select", mock_select): result = await store.rollback(str(uuid.uuid4())) assert result is False # ── list_events() tests ────────────────────────────────── class TestListEvents: async def test_list_events_empty(self): Model = _make_model() sf = _make_session_factory() store = EvolutionStore(session_factory=sf, evolution_model=Model) mock_select, _ = _make_select_mock() with patch("sqlalchemy.select", mock_select): events = await store.list_events() assert events == [] async def test_list_events_returns_entries(self): Model = _make_model() entry1 = _make_entry(agent_name="agent_a", change_type="prompt") entry2 = _make_entry(agent_name="agent_b", change_type="strategy") mock_result = _make_execute_result(scalars_all_val=[entry1, entry2]) sf = _make_session_factory(execute_result=mock_result) store = EvolutionStore(session_factory=sf, evolution_model=Model) mock_select, _ = _make_select_mock() with patch("sqlalchemy.select", mock_select): events = await store.list_events() assert len(events) == 2 assert events[0]["agent_name"] == "agent_a" assert events[1]["agent_name"] == "agent_b" async def test_list_events_dict_shape(self): Model = _make_model() entry = _make_entry( agent_name="test_agent", change_type="prompt", before={"old": 1}, after={"new": 2}, metrics={"score": 0.95}, status="active", ) mock_result = _make_execute_result(scalars_all_val=[entry]) sf = _make_session_factory(execute_result=mock_result) store = EvolutionStore(session_factory=sf, evolution_model=Model) mock_select, _ = _make_select_mock() with patch("sqlalchemy.select", mock_select): events = await store.list_events() e = events[0] assert "id" in e assert e["agent_name"] == "test_agent" assert e["change_type"] == "prompt" assert e["before"] == {"old": 1} assert e["after"] == {"new": 2} assert e["metrics"] == {"score": 0.95} assert e["status"] == "active" assert e["created_at"] is not None async def test_list_events_with_agent_name_filter(self): Model = _make_model() entry = _make_entry(agent_name="target_agent") mock_result = _make_execute_result(scalars_all_val=[entry]) sf = _make_session_factory(execute_result=mock_result) store = EvolutionStore(session_factory=sf, evolution_model=Model) mock_select, mock_stmt = _make_select_mock() with patch("sqlalchemy.select", mock_select): events = await store.list_events(agent_name="target_agent") # Verify .where() was called (chaining) mock_stmt.where.assert_called() assert len(events) == 1 assert events[0]["agent_name"] == "target_agent" async def test_list_events_with_change_type_filter(self): Model = _make_model() entry = _make_entry(change_type="strategy") mock_result = _make_execute_result(scalars_all_val=[entry]) sf = _make_session_factory(execute_result=mock_result) store = EvolutionStore(session_factory=sf, evolution_model=Model) mock_select, mock_stmt = _make_select_mock() with patch("sqlalchemy.select", mock_select): events = await store.list_events(change_type="strategy") mock_stmt.where.assert_called() assert len(events) == 1 assert events[0]["change_type"] == "strategy" async def test_list_events_with_status_filter(self): Model = _make_model() entry = _make_entry(status="rolled_back") mock_result = _make_execute_result(scalars_all_val=[entry]) sf = _make_session_factory(execute_result=mock_result) store = EvolutionStore(session_factory=sf, evolution_model=Model) mock_select, mock_stmt = _make_select_mock() with patch("sqlalchemy.select", mock_select): events = await store.list_events(status="rolled_back") mock_stmt.where.assert_called() assert len(events) == 1 assert events[0]["status"] == "rolled_back" async def test_list_events_returns_empty_on_error(self): Model = _make_model() @asynccontextmanager async def bad_sf(): session = AsyncMock() session.execute.side_effect = RuntimeError("db down") yield session store = EvolutionStore(session_factory=bad_sf, evolution_model=Model) mock_select, _ = _make_select_mock() with patch("sqlalchemy.select", mock_select): events = await store.list_events() assert events == []