# fischer-agentkit/tests/unit/quality/test_cascade_state_store.py
"""Unit tests for CascadeStateStore (U5 — CascadeStateStore Persistence)."""
import pytest
import time
from unittest.mock import MagicMock, patch
from agentkit.quality.cascade_state_store import (
CascadeStateStore,
InMemoryCascadeStateStore,
RedisCascadeStateStore,
create_cascade_state_store,
)
# ---------------------------------------------------------------------------
# InMemoryCascadeStateStore
# ---------------------------------------------------------------------------
class TestInMemoryCascadeStateStore:
    """Tests for InMemoryCascadeStateStore."""

    def test_implements_protocol(self):
        """The in-memory store satisfies the CascadeStateStore protocol."""
        store = InMemoryCascadeStateStore()
        assert isinstance(store, CascadeStateStore)

    def test_increment_interaction(self):
        """Each increment returns the session's new running count."""
        store = InMemoryCascadeStateStore()
        assert store.increment_interaction("s1") == 1
        assert store.increment_interaction("s1") == 2
        assert store.increment_interaction("s1") == 3

    def test_increment_different_sessions(self):
        """Counters are kept independently per session id."""
        store = InMemoryCascadeStateStore()
        assert store.increment_interaction("s1") == 1
        assert store.increment_interaction("s2") == 1
        assert store.increment_interaction("s1") == 2

    def test_get_interaction(self):
        """get_interaction reports the count accumulated by prior increments."""
        store = InMemoryCascadeStateStore()
        store.increment_interaction("s1")
        store.increment_interaction("s1")
        assert store.get_interaction("s1") == 2

    def test_get_interaction_nonexistent(self):
        """An unknown session reads as zero interactions."""
        store = InMemoryCascadeStateStore()
        assert store.get_interaction("unknown") == 0

    def test_set_and_get_depth(self):
        """A depth written with set_depth is read back by get_depth."""
        store = InMemoryCascadeStateStore()
        store.set_depth("s1", 3)
        assert store.get_depth("s1") == 3

    def test_get_depth_nonexistent(self):
        """An unknown session reads as depth zero."""
        store = InMemoryCascadeStateStore()
        assert store.get_depth("unknown") == 0

    def test_reset(self):
        """reset clears both the interaction count and the depth."""
        store = InMemoryCascadeStateStore()
        store.increment_interaction("s1")
        store.increment_interaction("s1")
        store.set_depth("s1", 2)
        store.reset("s1")
        assert store.get_interaction("s1") == 0
        assert store.get_depth("s1") == 0

    def test_reset_nonexistent_no_error(self):
        """Resetting an unknown session is a harmless no-op."""
        store = InMemoryCascadeStateStore()
        store.reset("nonexistent")  # Should not raise
        # The session must still read as empty afterwards.
        assert store.get_interaction("nonexistent") == 0
        assert store.get_depth("nonexistent") == 0

    def test_ttl_expiry_interaction(self):
        """Interaction counts read as zero once the session TTL has elapsed."""
        store = InMemoryCascadeStateStore(session_ttl=1)  # 1 second TTL
        store.increment_interaction("s1")
        assert store.get_interaction("s1") == 1
        time.sleep(1.1)
        # After TTL, should return 0 (expired)
        assert store.get_interaction("s1") == 0

    def test_ttl_expiry_depth(self):
        """Depth values read as zero once the session TTL has elapsed."""
        store = InMemoryCascadeStateStore(session_ttl=1)
        store.set_depth("s1", 5)
        assert store.get_depth("s1") == 5
        time.sleep(1.1)
        assert store.get_depth("s1") == 0

    def test_ttl_cleanup_on_increment(self):
        """An increment purges expired sessions from the internal counter map."""
        store = InMemoryCascadeStateStore(session_ttl=1)
        store.increment_interaction("s1")
        store.increment_interaction("s2")
        time.sleep(1.1)
        # Incrementing s3 should trigger cleanup of s1 and s2
        store.increment_interaction("s3")
        assert "s1" not in store._interaction_counts
        assert "s2" not in store._interaction_counts
        assert store._interaction_counts.get("s3") == 1

    def test_touch_refreshes_ttl(self):
        """Writing to a session pushes its expiry forward.

        The timings make the outcome unambiguous. Total elapsed time
        (~1.4s) exceeds the 1s TTL, so the entry can only survive if the
        second increment refreshed its timestamp. The time since that
        refresh (~0.7s) is still comfortably inside the TTL.
        """
        store = InMemoryCascadeStateStore(session_ttl=1)
        store.increment_interaction("s1")
        time.sleep(0.7)
        # Touch refreshes the timestamp (the entry is still alive here).
        store.increment_interaction("s1")
        time.sleep(0.7)
        # Alive only because of the refresh: ~0.7s < 1s since last touch,
        # whereas ~1.4s > 1s since creation.
        assert store.get_interaction("s1") == 2

    def test_default_session_ttl(self):
        """Without an explicit TTL the store uses 86400 seconds (one day)."""
        store = InMemoryCascadeStateStore()
        assert store._session_ttl == 86400

    def test_custom_session_ttl(self):
        """An explicit session_ttl is stored as given."""
        store = InMemoryCascadeStateStore(session_ttl=3600)
        assert store._session_ttl == 3600
# ---------------------------------------------------------------------------
# RedisCascadeStateStore (mocked)
# ---------------------------------------------------------------------------
class TestRedisCascadeStateStoreMocked:
    """Tests for RedisCascadeStateStore with mocked Redis."""

    @staticmethod
    def _store_with_mock_client():
        """Return ``(store, client, pipeline)`` backed by a MagicMock Redis client.

        The pipeline's command methods return the pipeline itself, so chained
        calls work. The pipeline also acts as its own context manager, so
        the tests do not depend on whether the store uses ``with``.
        """
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        client = MagicMock()
        pipeline = MagicMock()
        for command in ("incr", "expire", "set", "delete"):
            getattr(pipeline, command).return_value = pipeline
        pipeline.__enter__.return_value = pipeline
        client.pipeline.return_value = pipeline
        store._sync_redis = client
        return store, client, pipeline

    def test_implements_protocol(self):
        """The Redis store satisfies the CascadeStateStore protocol."""
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        assert isinstance(store, CascadeStateStore)

    def test_degrade_to_fallback(self):
        """Degrading flips the flag and installs an in-memory fallback."""
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        assert not store._degraded
        store._degrade_to_fallback()
        assert store._degraded
        assert isinstance(store._fallback, InMemoryCascadeStateStore)

    def test_increment_degraded_uses_fallback(self):
        """A degraded store routes increments to the fallback."""
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        result = store.increment_interaction("s1")
        assert result == 1
        assert store._fallback.get_interaction("s1") == 1

    def test_get_interaction_degraded_uses_fallback(self):
        """A degraded store reads interaction counts from the fallback."""
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        store._fallback.increment_interaction("s1")
        assert store.get_interaction("s1") == 1

    def test_set_depth_degraded_uses_fallback(self):
        """A degraded store writes depth values to the fallback."""
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        store.set_depth("s1", 3)
        assert store._fallback.get_depth("s1") == 3

    def test_get_depth_degraded_uses_fallback(self):
        """A degraded store reads depth values from the fallback."""
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        store._fallback.set_depth("s1", 5)
        assert store.get_depth("s1") == 5

    def test_reset_degraded_uses_fallback(self):
        """A degraded store applies reset to the fallback."""
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        store._fallback.increment_interaction("s1")
        store.reset("s1")
        assert store._fallback.get_interaction("s1") == 0

    def test_increment_redis_failure_degrades(self):
        """A connection failure degrades the store instead of raising.

        NOTE(review): this uses a real, unresolvable host name and so
        performs an actual connection attempt.
        """
        store = RedisCascadeStateStore(redis_url="redis://nonexistent:6379")
        # Will fail to connect and degrade
        result = store.increment_interaction("s1")
        assert store._degraded
        assert result >= 0  # Should return something (fallback or 0)
        # Once degraded, further calls must be served by the in-memory fallback.
        assert isinstance(store._fallback, InMemoryCascadeStateStore)
        assert store.increment_interaction("s1") == store._fallback.get_interaction("s1")

    def test_custom_session_ttl(self):
        """An explicit session_ttl is stored as given."""
        store = RedisCascadeStateStore(
            redis_url="redis://localhost:6379", session_ttl=3600
        )
        assert store._session_ttl == 3600

    def test_default_session_ttl(self):
        """Without an explicit TTL the store uses 86400 seconds (one day)."""
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        assert store._session_ttl == 86400

    def test_increment_with_mock_redis(self):
        """Increment issues INCR + EXPIRE in one pipeline and returns the count."""
        store, _, pipeline = self._store_with_mock_client()
        pipeline.execute.return_value = [1, True]
        result = store.increment_interaction("s1")
        assert result == 1
        pipeline.incr.assert_called_once()
        pipeline.expire.assert_called_once()
        pipeline.execute.assert_called_once()

    def test_get_interaction_with_mock_redis(self):
        """A stored string count is returned as an int."""
        store, client, _ = self._store_with_mock_client()
        client.get.return_value = "3"
        result = store.get_interaction("s1")
        assert result == 3

    def test_get_interaction_redis_returns_none(self):
        """A missing Redis key reads as zero interactions."""
        store, client, _ = self._store_with_mock_client()
        client.get.return_value = None
        result = store.get_interaction("s1")
        assert result == 0

    def test_set_depth_with_mock_redis(self):
        """set_depth issues SET + EXPIRE in one pipeline."""
        store, _, pipeline = self._store_with_mock_client()
        pipeline.execute.return_value = [True, True]
        store.set_depth("s1", 3)
        pipeline.set.assert_called_once()
        pipeline.expire.assert_called_once()
        pipeline.execute.assert_called_once()

    def test_reset_with_mock_redis(self):
        """reset deletes both per-session keys in one pipeline."""
        store, _, pipeline = self._store_with_mock_client()
        pipeline.execute.return_value = [1, 1]
        store.reset("s1")
        assert pipeline.delete.call_count == 2
        pipeline.execute.assert_called_once()
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
class TestCreateCascadeStateStore:
    """Tests for the create_cascade_state_store factory."""

    def test_memory_backend(self):
        """backend="memory" yields the in-memory implementation."""
        assert isinstance(
            create_cascade_state_store(backend="memory"),
            InMemoryCascadeStateStore,
        )

    def test_auto_backend_returns_store(self):
        """backend="auto" yields one of the two known store implementations."""
        known_types = (InMemoryCascadeStateStore, RedisCascadeStateStore)
        assert isinstance(create_cascade_state_store(backend="auto"), known_types)

    def test_redis_backend_returns_store(self):
        """backend="redis" yields one of the two known store implementations."""
        known_types = (InMemoryCascadeStateStore, RedisCascadeStateStore)
        assert isinstance(create_cascade_state_store(backend="redis"), known_types)

    def test_session_ttl_passed_to_redis(self):
        """A Redis-backed store receives the requested session TTL."""
        store = create_cascade_state_store(backend="redis", session_ttl=7200)
        if not isinstance(store, RedisCascadeStateStore):
            # Nothing to verify when the factory produced a non-Redis store.
            return
        assert store._session_ttl == 7200

    def test_session_ttl_passed_to_memory(self):
        """An in-memory store receives the requested session TTL."""
        store = create_cascade_state_store(session_ttl=7200, backend="memory")
        assert isinstance(store, InMemoryCascadeStateStore)
        assert store._session_ttl == 7200