275 lines
10 KiB
Python
275 lines
10 KiB
Python
"""Unit tests for CascadeStateStore (U5 — CascadeStateStore Persistence)."""
|
|
|
|
import pytest
|
|
import time
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from agentkit.quality.cascade_state_store import (
|
|
CascadeStateStore,
|
|
InMemoryCascadeStateStore,
|
|
RedisCascadeStateStore,
|
|
create_cascade_state_store,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# InMemoryCascadeStateStore
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestInMemoryCascadeStateStore:
    """Tests for InMemoryCascadeStateStore.

    The TTL tests use real ``time.sleep`` calls with a ~0.1 s margin over the
    configured TTL, so this class takes a few seconds of wall-clock time.
    """

    def test_implements_protocol(self):
        store = InMemoryCascadeStateStore()
        assert isinstance(store, CascadeStateStore)

    def test_increment_interaction(self):
        store = InMemoryCascadeStateStore()
        assert store.increment_interaction("s1") == 1
        assert store.increment_interaction("s1") == 2
        assert store.increment_interaction("s1") == 3

    def test_increment_different_sessions(self):
        # Each session id keeps its own independent counter.
        store = InMemoryCascadeStateStore()
        assert store.increment_interaction("s1") == 1
        assert store.increment_interaction("s2") == 1
        assert store.increment_interaction("s1") == 2

    def test_get_interaction(self):
        store = InMemoryCascadeStateStore()
        store.increment_interaction("s1")
        store.increment_interaction("s1")
        assert store.get_interaction("s1") == 2

    def test_get_interaction_nonexistent(self):
        store = InMemoryCascadeStateStore()
        assert store.get_interaction("unknown") == 0

    def test_set_and_get_depth(self):
        store = InMemoryCascadeStateStore()
        store.set_depth("s1", 3)
        assert store.get_depth("s1") == 3

    def test_get_depth_nonexistent(self):
        store = InMemoryCascadeStateStore()
        assert store.get_depth("unknown") == 0

    def test_reset(self):
        store = InMemoryCascadeStateStore()
        store.increment_interaction("s1")
        store.increment_interaction("s1")
        store.set_depth("s1", 2)
        store.reset("s1")
        assert store.get_interaction("s1") == 0
        assert store.get_depth("s1") == 0

    def test_reset_only_affects_target_session(self):
        # Resetting one session must leave every other session's state intact.
        store = InMemoryCascadeStateStore()
        store.increment_interaction("s1")
        store.increment_interaction("s2")
        store.set_depth("s2", 4)
        store.reset("s1")
        assert store.get_interaction("s1") == 0
        assert store.get_interaction("s2") == 1
        assert store.get_depth("s2") == 4

    def test_reset_nonexistent_no_error(self):
        store = InMemoryCascadeStateStore()
        store.reset("nonexistent")  # Should not raise
        # ...and must not conjure state for the unknown session.
        assert store.get_interaction("nonexistent") == 0
        assert store.get_depth("nonexistent") == 0

    def test_ttl_expiry_interaction(self):
        store = InMemoryCascadeStateStore(session_ttl=1)  # 1 second TTL
        store.increment_interaction("s1")
        assert store.get_interaction("s1") == 1
        time.sleep(1.1)
        # After TTL, should return 0 (expired)
        assert store.get_interaction("s1") == 0

    def test_ttl_expiry_depth(self):
        store = InMemoryCascadeStateStore(session_ttl=1)
        store.set_depth("s1", 5)
        assert store.get_depth("s1") == 5
        time.sleep(1.1)
        assert store.get_depth("s1") == 0

    def test_ttl_cleanup_on_increment(self):
        store = InMemoryCascadeStateStore(session_ttl=1)
        store.increment_interaction("s1")
        store.increment_interaction("s2")
        time.sleep(1.1)
        # Incrementing s3 should trigger cleanup of s1 and s2
        store.increment_interaction("s3")
        assert "s1" not in store._interaction_counts
        assert "s2" not in store._interaction_counts
        assert store._interaction_counts.get("s3") == 1

    def test_touch_refreshes_ttl(self):
        store = InMemoryCascadeStateStore(session_ttl=2)
        store.increment_interaction("s1")
        time.sleep(1.0)
        # Touch refreshes the timestamp
        store.increment_interaction("s1")
        time.sleep(1.0)
        # Should still be alive (1s < 2s since last touch); ~2s have elapsed
        # since the first write, so a store that did not refresh would expire.
        assert store.get_interaction("s1") == 2

    def test_default_session_ttl(self):
        store = InMemoryCascadeStateStore()
        assert store._session_ttl == 86400

    def test_custom_session_ttl(self):
        store = InMemoryCascadeStateStore(session_ttl=3600)
        assert store._session_ttl == 3600
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# RedisCascadeStateStore (mocked)
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRedisCascadeStateStoreMocked:
    """Tests for RedisCascadeStateStore with mocked Redis.

    Tests either force degraded mode via ``_degrade_to_fallback()`` or inject a
    ``MagicMock`` client into ``_sync_redis``, so no Redis server is required.
    """

    @staticmethod
    def _attach_mock_pipeline(store, execute_result):
        """Inject a mock Redis client whose ``pipeline()`` returns a chainable mock.

        Queued commands (incr/set/expire/delete) return the pipeline itself so
        chained calls keep working; ``execute()`` returns *execute_result*.
        Returns the pipeline mock so callers can assert on it.
        """
        mock_pipeline = MagicMock()
        for command in ("incr", "set", "expire", "delete"):
            getattr(mock_pipeline, command).return_value = mock_pipeline
        mock_pipeline.execute.return_value = execute_result
        mock_redis = MagicMock()
        mock_redis.pipeline.return_value = mock_pipeline
        store._sync_redis = mock_redis
        return mock_pipeline

    def test_implements_protocol(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        assert isinstance(store, CascadeStateStore)

    def test_degrade_to_fallback(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        assert not store._degraded
        store._degrade_to_fallback()
        assert store._degraded
        assert isinstance(store._fallback, InMemoryCascadeStateStore)

    def test_increment_degraded_uses_fallback(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        result = store.increment_interaction("s1")
        assert result == 1
        assert store._fallback.get_interaction("s1") == 1

    def test_get_interaction_degraded_uses_fallback(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        store._fallback.increment_interaction("s1")
        assert store.get_interaction("s1") == 1

    def test_set_depth_degraded_uses_fallback(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        store.set_depth("s1", 3)
        assert store._fallback.get_depth("s1") == 3

    def test_get_depth_degraded_uses_fallback(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        store._fallback.set_depth("s1", 5)
        assert store.get_depth("s1") == 5

    def test_reset_degraded_uses_fallback(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        store._fallback.increment_interaction("s1")
        store.reset("s1")
        assert store._fallback.get_interaction("s1") == 0

    def test_increment_redis_failure_degrades(self):
        # Nothing normally listens on loopback port 1, so the connection is
        # refused quickly without depending on DNS resolution of a bogus host.
        store = RedisCascadeStateStore(redis_url="redis://127.0.0.1:1")
        result = store.increment_interaction("s1")
        assert store._degraded
        assert isinstance(store._fallback, InMemoryCascadeStateStore)
        # The failing call returns either the fallback's count (1) or 0.
        assert result in (0, 1)

    def test_custom_session_ttl(self):
        store = RedisCascadeStateStore(
            redis_url="redis://localhost:6379", session_ttl=3600
        )
        assert store._session_ttl == 3600

    def test_default_session_ttl(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        assert store._session_ttl == 86400

    def test_increment_with_mock_redis(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        mock_pipeline = self._attach_mock_pipeline(store, [1, True])

        result = store.increment_interaction("s1")
        assert result == 1
        mock_pipeline.incr.assert_called_once()
        mock_pipeline.expire.assert_called_once()
        # Queued commands only reach Redis once the pipeline is executed.
        mock_pipeline.execute.assert_called_once()

    def test_get_interaction_with_mock_redis(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        mock_redis = MagicMock()
        mock_redis.get.return_value = "3"
        store._sync_redis = mock_redis

        result = store.get_interaction("s1")
        assert result == 3

    def test_get_interaction_redis_returns_none(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        mock_redis = MagicMock()
        mock_redis.get.return_value = None
        store._sync_redis = mock_redis

        result = store.get_interaction("s1")
        assert result == 0

    def test_set_depth_with_mock_redis(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        mock_pipeline = self._attach_mock_pipeline(store, [True, True])

        store.set_depth("s1", 3)
        mock_pipeline.set.assert_called_once()
        mock_pipeline.expire.assert_called_once()
        mock_pipeline.execute.assert_called_once()

    def test_reset_with_mock_redis(self):
        store = RedisCascadeStateStore(redis_url="redis://localhost:6379")
        mock_pipeline = self._attach_mock_pipeline(store, [1, 1])

        store.reset("s1")
        assert mock_pipeline.delete.call_count == 2
        mock_pipeline.execute.assert_called_once()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCreateCascadeStateStore:
    """Tests for the ``create_cascade_state_store`` factory."""

    def test_memory_backend(self):
        store = create_cascade_state_store(backend="memory")
        assert isinstance(store, InMemoryCascadeStateStore)

    def test_auto_backend_returns_store(self):
        store = create_cascade_state_store(backend="auto")
        assert isinstance(store, (InMemoryCascadeStateStore, RedisCascadeStateStore))

    def test_redis_backend_returns_store(self):
        store = create_cascade_state_store(backend="redis")
        assert isinstance(store, (InMemoryCascadeStateStore, RedisCascadeStateStore))

    def test_session_ttl_passed_to_redis(self):
        store = create_cascade_state_store(
            backend="redis", session_ttl=7200
        )
        if not isinstance(store, RedisCascadeStateStore):
            # The redis backend may return an in-memory store (see
            # test_redis_backend_returns_store). Report that as a skip so the
            # missing coverage is visible instead of passing with no assertion.
            pytest.skip("factory did not return a RedisCascadeStateStore")
        assert store._session_ttl == 7200

    def test_session_ttl_passed_to_memory(self):
        store = create_cascade_state_store(
            backend="memory", session_ttl=7200
        )
        assert isinstance(store, InMemoryCascadeStateStore)
        assert store._session_ttl == 7200
|