401 lines
14 KiB
Python
401 lines
14 KiB
Python
"""Tests for EvolutionStore - evolution event recording and rollback"""
|
|
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from contextlib import asynccontextmanager
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.protocol import EvolutionEvent
|
|
from agentkit.evolution.evolution_store import EvolutionStore
|
|
|
|
|
|
# ── Mock helpers ──────────────────────────────────────────
|
|
|
|
|
|
def _make_entry(
|
|
id: uuid.UUID | None = None,
|
|
agent_name: str = "test_agent",
|
|
change_type: str = "prompt",
|
|
before: dict | None = None,
|
|
after: dict | None = None,
|
|
metrics: dict | None = None,
|
|
status: str = "active",
|
|
created_at: datetime | None = None,
|
|
):
|
|
"""Create a mock DB entry object."""
|
|
entry = MagicMock()
|
|
entry.id = id or uuid.uuid4()
|
|
entry.agent_name = agent_name
|
|
entry.change_type = change_type
|
|
entry.before = before or {}
|
|
entry.after = after or {}
|
|
entry.metrics = metrics
|
|
entry.status = status
|
|
entry.created_at = created_at or datetime.now(timezone.utc)
|
|
return entry
|
|
|
|
|
|
def _make_model():
    """Return a MagicMock that plays the role of the evolution model class.

    Calling the mock (``Model(id=..., agent_name=..., ...)``) yields a new
    MagicMock instance whose attributes come from the keyword arguments,
    with defaults for any that are missing. Because the class itself is a
    MagicMock, attribute access such as ``Model.id`` also works, which is
    what query-building expressions need.
    """

    def _build_instance(*args, **kwargs):
        # Defaults are computed on every call so each instance gets its
        # own UUID and timestamp when the caller does not supply them.
        fallbacks = {
            "id": uuid.uuid4(),
            "agent_name": "test_agent",
            "change_type": "prompt",
            "before": {},
            "after": {},
            "metrics": None,
            "status": "active",
            "created_at": datetime.now(timezone.utc),
        }
        row = MagicMock()
        for column, fallback in fallbacks.items():
            setattr(row, column, kwargs.get(column, fallback))
        return row

    return MagicMock(side_effect=_build_instance)
|
|
|
|
|
|
def _make_select_mock():
    """Return ``(select_mock, statement)`` for patching ``sqlalchemy.select``.

    Calling ``select_mock`` returns ``statement``, and ``statement.where()``
    and ``statement.order_by()`` both return ``statement`` itself, so a
    chained query expression always ends up at the same object.
    """
    statement = MagicMock()
    for chain_method in ("where", "order_by"):
        getattr(statement, chain_method).return_value = statement
    return MagicMock(return_value=statement), statement
|
|
|
|
|
|
class SessionCapture:
    """Collects the sessions handed out by a mock session factory."""

    def __init__(self):
        # Sessions are appended in the order the factory creates them.
        self.sessions = []

    @property
    def last(self):
        """The most recently captured session, or None if there is none."""
        if not self.sessions:
            return None
        return self.sessions[-1]
|
|
|
|
|
|
def _make_execute_result(scalar_one_or_none_val=None, scalars_all_val=None):
    """Build a mock of the object returned by ``db.execute()``.

    ``scalar_one_or_none()`` returns ``scalar_one_or_none_val`` and
    ``scalars().all()`` returns ``scalars_all_val`` (an empty list when that
    argument is falsy). These accessors are called synchronously, so the
    result is a plain MagicMock rather than an AsyncMock.
    """
    rows = scalars_all_val if scalars_all_val else []

    scalars_proxy = MagicMock()
    scalars_proxy.all.return_value = rows

    outcome = MagicMock()
    outcome.scalars.return_value = scalars_proxy
    outcome.scalar_one_or_none.return_value = scalar_one_or_none_val
    return outcome
|
|
|
|
|
|
def _make_session_factory(
    capture: SessionCapture | None = None,
    execute_result=None,
    commit_side_effect=None,
):
    """Build a mock async session factory.

    The returned callable is an async context manager; entering it yields a
    new AsyncMock session each time.

    Args:
        capture: When given, every session produced is appended to
            ``capture.sessions``.
        execute_result: Value returned by ``session.execute()``; when None,
            a default result from ``_make_execute_result()`` is used.
        commit_side_effect: When truthy, installed as the ``side_effect`` of
            ``session.commit`` (e.g. an exception to raise).
    """

    @asynccontextmanager
    async def _open_session():
        db = AsyncMock()
        # add() is synchronous on a real session, so use a plain MagicMock.
        db.add = MagicMock()
        db.rollback = AsyncMock()
        db.refresh = AsyncMock()

        if not commit_side_effect:
            db.commit = AsyncMock()
        else:
            db.commit.side_effect = commit_side_effect

        db.execute.return_value = (
            _make_execute_result() if execute_result is None else execute_result
        )

        if capture is not None:
            capture.sessions.append(db)
        yield db

    return _open_session
|
|
|
|
|
|
# ── Fixtures ──────────────────────────────────────────────
|
|
|
|
|
|
@pytest.fixture
def sample_event():
    """Provide an EvolutionEvent describing a prompt change for test_agent."""
    fields = {
        "agent_name": "test_agent",
        "change_type": "prompt",
        "before": {"prompt": "old prompt"},
        "after": {"prompt": "new prompt"},
        "metrics": {"accuracy": 0.9},
    }
    return EvolutionEvent(**fields)
|
|
|
|
|
|
# ── record() tests ───────────────────────────────────────
|
|
|
|
|
|
class TestRecord:
    """Tests for EvolutionStore.record() using a mocked session and model."""

    async def test_record_returns_event_id(self, sample_event):
        """record() returns a non-None value that parses as a UUID."""
        model_cls = _make_model()
        sessions = SessionCapture()
        store = EvolutionStore(
            session_factory=_make_session_factory(capture=sessions),
            evolution_model=model_cls,
        )

        returned_id = await store.record(sample_event)

        assert returned_id is not None
        uuid.UUID(returned_id)  # raises unless this is a valid UUID string

    async def test_record_sets_event_id_on_event(self, sample_event):
        """record() fills in event_id on the event object it was given."""
        model_cls = _make_model()
        sessions = SessionCapture()
        store = EvolutionStore(
            session_factory=_make_session_factory(capture=sessions),
            evolution_model=model_cls,
        )

        assert sample_event.event_id is None
        await store.record(sample_event)
        assert sample_event.event_id is not None

    async def test_record_creates_model_instance_with_correct_fields(self, sample_event):
        """The model class is instantiated once with the event's fields."""
        model_cls = _make_model()
        sessions = SessionCapture()
        store = EvolutionStore(
            session_factory=_make_session_factory(capture=sessions),
            evolution_model=model_cls,
        )

        await store.record(sample_event)

        model_cls.assert_called_once()
        passed = model_cls.call_args.kwargs
        expected = {
            "agent_name": "test_agent",
            "change_type": "prompt",
            "before": {"prompt": "old prompt"},
            "after": {"prompt": "new prompt"},
            "metrics": {"accuracy": 0.9},
            "status": "active",
        }
        for field_name, wanted in expected.items():
            assert passed[field_name] == wanted

    async def test_record_calls_db_add_and_commit(self, sample_event):
        """record() adds to the session and commits it."""
        model_cls = _make_model()
        sessions = SessionCapture()
        store = EvolutionStore(
            session_factory=_make_session_factory(capture=sessions),
            evolution_model=model_cls,
        )

        await store.record(sample_event)

        db = sessions.last
        db.add.assert_called()
        db.commit.assert_called()

    async def test_record_rollback_on_error(self, sample_event):
        """A failing commit propagates and the session is rolled back."""
        model_cls = _make_model()
        sessions = SessionCapture()
        store = EvolutionStore(
            session_factory=_make_session_factory(
                capture=sessions,
                commit_side_effect=RuntimeError("db error"),
            ),
            evolution_model=model_cls,
        )

        with pytest.raises(RuntimeError, match="db error"):
            await store.record(sample_event)

        db = sessions.last
        db.rollback.assert_called()
|
|
|
|
|
|
# ── rollback() tests ──────────────────────────────────────
|
|
|
|
|
|
class TestRollback:
    """Tests for EvolutionStore.rollback() with ``sqlalchemy.select`` patched."""

    async def test_rollback_success(self):
        """An existing entry is marked rolled_back and the session commits."""
        model_cls = _make_model()
        target_id = uuid.uuid4()

        stored_row = _make_entry(id=target_id, status="active")
        query_result = _make_execute_result(scalar_one_or_none_val=stored_row)

        sessions = SessionCapture()
        store = EvolutionStore(
            session_factory=_make_session_factory(
                capture=sessions,
                execute_result=query_result,
            ),
            evolution_model=model_cls,
        )

        select_stub, _ = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            outcome = await store.rollback(str(target_id))

        assert outcome is True
        assert stored_row.status == "rolled_back"
        sessions.last.commit.assert_called()

    async def test_rollback_not_found(self):
        """rollback() returns False when the query finds no entry."""
        model_cls = _make_model()

        query_result = _make_execute_result(scalar_one_or_none_val=None)

        sessions = SessionCapture()
        store = EvolutionStore(
            session_factory=_make_session_factory(
                capture=sessions,
                execute_result=query_result,
            ),
            evolution_model=model_cls,
        )

        select_stub, _ = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            outcome = await store.rollback(str(uuid.uuid4()))

        assert outcome is False

    async def test_rollback_returns_false_on_error(self):
        """rollback() returns False when session.execute raises."""
        model_cls = _make_model()

        @asynccontextmanager
        async def failing_session_factory():
            db = AsyncMock()
            db.execute.side_effect = RuntimeError("connection lost")
            db.rollback = AsyncMock()
            yield db

        store = EvolutionStore(
            session_factory=failing_session_factory,
            evolution_model=model_cls,
        )

        select_stub, _ = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            outcome = await store.rollback(str(uuid.uuid4()))

        assert outcome is False
|
|
|
|
|
|
# ── list_events() tests ──────────────────────────────────
|
|
|
|
|
|
class TestListEvents:
    """Tests for EvolutionStore.list_events() with ``sqlalchemy.select`` patched."""

    async def test_list_events_empty(self):
        """With the default (empty) query result, list_events() returns []."""
        model_cls = _make_model()
        store = EvolutionStore(
            session_factory=_make_session_factory(),
            evolution_model=model_cls,
        )

        select_stub, _ = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            listed = await store.list_events()

        assert listed == []

    async def test_list_events_returns_entries(self):
        """Rows from the query come back as dicts, in the same order."""
        model_cls = _make_model()
        first_row = _make_entry(agent_name="agent_a", change_type="prompt")
        second_row = _make_entry(agent_name="agent_b", change_type="strategy")

        query_result = _make_execute_result(scalars_all_val=[first_row, second_row])

        store = EvolutionStore(
            session_factory=_make_session_factory(execute_result=query_result),
            evolution_model=model_cls,
        )

        select_stub, _ = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            listed = await store.list_events()

        assert len(listed) == 2
        assert listed[0]["agent_name"] == "agent_a"
        assert listed[1]["agent_name"] == "agent_b"

    async def test_list_events_dict_shape(self):
        """Each returned dict carries the row's fields under the expected keys."""
        model_cls = _make_model()
        row = _make_entry(
            agent_name="test_agent",
            change_type="prompt",
            before={"old": 1},
            after={"new": 2},
            metrics={"score": 0.95},
            status="active",
        )

        query_result = _make_execute_result(scalars_all_val=[row])

        store = EvolutionStore(
            session_factory=_make_session_factory(execute_result=query_result),
            evolution_model=model_cls,
        )

        select_stub, _ = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            listed = await store.list_events()

        item = listed[0]
        assert "id" in item
        expected = {
            "agent_name": "test_agent",
            "change_type": "prompt",
            "before": {"old": 1},
            "after": {"new": 2},
            "metrics": {"score": 0.95},
            "status": "active",
        }
        for key, wanted in expected.items():
            assert item[key] == wanted
        assert item["created_at"] is not None

    async def test_list_events_with_agent_name_filter(self):
        """Passing agent_name adds a .where() clause to the statement."""
        model_cls = _make_model()
        row = _make_entry(agent_name="target_agent")

        query_result = _make_execute_result(scalars_all_val=[row])

        store = EvolutionStore(
            session_factory=_make_session_factory(execute_result=query_result),
            evolution_model=model_cls,
        )

        select_stub, statement = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            listed = await store.list_events(agent_name="target_agent")

        # The filter must have been applied through statement chaining.
        statement.where.assert_called()
        assert len(listed) == 1
        assert listed[0]["agent_name"] == "target_agent"

    async def test_list_events_with_change_type_filter(self):
        """Passing change_type adds a .where() clause to the statement."""
        model_cls = _make_model()
        row = _make_entry(change_type="strategy")

        query_result = _make_execute_result(scalars_all_val=[row])

        store = EvolutionStore(
            session_factory=_make_session_factory(execute_result=query_result),
            evolution_model=model_cls,
        )

        select_stub, statement = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            listed = await store.list_events(change_type="strategy")

        statement.where.assert_called()
        assert len(listed) == 1
        assert listed[0]["change_type"] == "strategy"

    async def test_list_events_with_status_filter(self):
        """Passing status adds a .where() clause to the statement."""
        model_cls = _make_model()
        row = _make_entry(status="rolled_back")

        query_result = _make_execute_result(scalars_all_val=[row])

        store = EvolutionStore(
            session_factory=_make_session_factory(execute_result=query_result),
            evolution_model=model_cls,
        )

        select_stub, statement = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            listed = await store.list_events(status="rolled_back")

        statement.where.assert_called()
        assert len(listed) == 1
        assert listed[0]["status"] == "rolled_back"

    async def test_list_events_returns_empty_on_error(self):
        """list_events() returns [] when session.execute raises."""
        model_cls = _make_model()

        @asynccontextmanager
        async def failing_session_factory():
            db = AsyncMock()
            db.execute.side_effect = RuntimeError("db down")
            yield db

        store = EvolutionStore(
            session_factory=failing_session_factory,
            evolution_model=model_cls,
        )

        select_stub, _ = _make_select_mock()
        with patch("sqlalchemy.select", select_stub):
            listed = await store.list_events()

        assert listed == []
|