"""Unit tests for UsageStore (U4 — UsageStore Persistence)."""

import asyncio
import json
from datetime import datetime, timedelta, timezone
from typing import Callable, Optional
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.llm.protocol import TokenUsage
from agentkit.llm.providers.usage_store import (
    InMemoryUsageStore,
    RedisUsageStore,
    UsageBucket,
    UsageRecord,
    UsageStore,
    UsageStoreUnavailableError,
    UsageSummary,
    create_usage_store,
)

# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

# Fixed UTC day used by the mocked-Redis query tests. Records built by
# ``_record_json`` are stamped 10:00 UTC on this day, so the window
# [_DAY, _HALF_DAY_END] contains them while staying within a single day
# (one day => one scan/lookup per scope).
_DAY = datetime(2026, 6, 21, tzinfo=timezone.utc)
_HALF_DAY_END = _DAY + timedelta(hours=12)


def _usage(prompt_tokens: int = 100, completion_tokens: int = 50) -> TokenUsage:
    """Return a fresh ``TokenUsage`` (defaults: 100 prompt + 50 completion tokens)."""
    return TokenUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)


def _record_json(
    *,
    prompt_tokens: int,
    completion_tokens: int,
    cost: float,
    user_id: Optional[str],
    department_id: Optional[str],
) -> str:
    """Serialize one usage record in the JSON shape these tests feed to ``lrange``.

    ``total_tokens`` is derived from the prompt/completion counts so fixtures
    cannot drift out of sync. Agent, model, latency and timestamp are fixed
    (timestamp falls inside ``[_DAY, _HALF_DAY_END]``).
    """
    return json.dumps(
        {
            "agent_name": "a",
            "model": "m",
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
            "cost": cost,
            "latency_ms": 200,
            "timestamp": "2026-06-21T10:00:00+00:00",
            "user_id": user_id,
            "department_id": department_id,
        }
    )


def _scan_pattern(mock_redis: MagicMock) -> str:
    """Return the match pattern passed to the most recent ``scan_iter`` call."""
    scan_call = mock_redis.scan_iter.call_args
    return scan_call.kwargs.get("match") or scan_call.args[0]


async def _run_health_check_until(
    store: RedisUsageStore,
    mock_redis: AsyncMock,
    done: Callable[[], bool],
    timeout: float = 1.0,
) -> None:
    """Run ``store._health_check_loop`` against *mock_redis* until ``done()`` holds.

    Polls instead of sleeping a fixed interval, so the tests do not depend on
    how many loop iterations a (possibly slow) CI machine manages within a
    hard-coded sleep. Gives up after *timeout* seconds, or earlier if the
    loop task finishes by itself. The task is always cancelled and awaited; an
    exception raised by the loop (other than cancellation) propagates.
    """
    loop = asyncio.get_running_loop()
    with patch.object(store, "_get_redis", return_value=mock_redis):
        store.HEALTH_CHECK_INTERVAL = 0.01
        task = asyncio.create_task(store._health_check_loop())
        try:
            deadline = loop.time() + timeout
            while not done() and not task.done() and loop.time() < deadline:
                await asyncio.sleep(0.01)
        finally:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass


# ---------------------------------------------------------------------------
# InMemoryUsageStore
# ---------------------------------------------------------------------------


class TestInMemoryUsageStore:
    """Tests for InMemoryUsageStore."""

    def test_implements_protocol(self):
        store = InMemoryUsageStore()
        assert isinstance(store, UsageStore)

    def test_record_single(self):
        store = InMemoryUsageStore()
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200)

        summary = store.get_usage()
        assert summary.total_tokens == 150
        assert summary.total_cost == pytest.approx(0.05)
        assert len(summary.records) == 1

    def test_record_multiple(self):
        store = InMemoryUsageStore()
        store.record("agent1", "gpt-4", _usage(100, 50), cost=0.05, latency_ms=200)
        store.record("agent1", "gpt-4", _usage(200, 100), cost=0.10, latency_ms=300)

        summary = store.get_usage()
        assert summary.total_tokens == 450
        assert summary.total_cost == pytest.approx(0.15)
        assert len(summary.records) == 2

    def test_get_usage_by_agent(self):
        store = InMemoryUsageStore()
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200)
        store.record("agent2", "gpt-4", _usage(), cost=0.05, latency_ms=200)

        summary = store.get_usage(agent_name="agent1")
        assert len(summary.records) == 1
        assert summary.records[0].agent_name == "agent1"

    def test_get_usage_by_time_range(self):
        store = InMemoryUsageStore()
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200)

        # Query with start_time in the future — should return empty.
        future = datetime.now(timezone.utc) + timedelta(hours=1)
        assert len(store.get_usage(start_time=future).records) == 0

        # Query with end_time in the past — should return empty.
        past = datetime.now(timezone.utc) - timedelta(hours=1)
        assert len(store.get_usage(end_time=past).records) == 0

        # Query with a wide range — should return the record.
        start = datetime.now(timezone.utc) - timedelta(hours=1)
        end = datetime.now(timezone.utc) + timedelta(hours=1)
        assert len(store.get_usage(start_time=start, end_time=end).records) == 1

    def test_get_usage_by_model(self):
        store = InMemoryUsageStore()
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200)
        store.record("agent1", "gpt-3.5", _usage(), cost=0.01, latency_ms=100)

        summary = store.get_usage()
        assert "gpt-4" in summary.by_model
        assert "gpt-3.5" in summary.by_model
        assert summary.by_model["gpt-4"]["count"] == 1
        assert summary.by_model["gpt-3.5"]["count"] == 1

    def test_get_usage_empty(self):
        store = InMemoryUsageStore()
        summary = store.get_usage()
        assert summary.total_tokens == 0
        assert summary.total_cost == 0.0
        assert len(summary.records) == 0

    def test_max_records_trimming(self):
        store = InMemoryUsageStore()
        store.MAX_RECORDS = 5
        for i in range(10):
            store.record(f"agent{i}", "gpt-4", _usage(1, 1), cost=0.01, latency_ms=100)

        assert len(store._records) == 5
        # Should keep the last 5 records (oldest dropped first).
        assert store._records[0].agent_name == "agent5"
        assert store._records[-1].agent_name == "agent9"

    def test_usage_record_timestamp(self):
        store = InMemoryUsageStore()
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200)

        rec = store.get_usage().records[0]
        assert rec.timestamp != ""
        # Should be parseable as ISO 8601.
        datetime.fromisoformat(rec.timestamp)

    def test_record_with_user_and_department(self):
        store = InMemoryUsageStore()
        store.record(
            "agent1",
            "gpt-4",
            _usage(),
            cost=0.05,
            latency_ms=200,
            user_id="u1",
            department_id="d1",
        )

        rec = store.get_usage().records[0]
        assert rec.user_id == "u1"
        assert rec.department_id == "d1"

    def test_record_defaults_user_department_to_none(self):
        store = InMemoryUsageStore()
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200)

        rec = store.get_usage().records[0]
        assert rec.user_id is None
        assert rec.department_id is None

    def test_get_usage_filters_by_user(self):
        store = InMemoryUsageStore()
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200, user_id="u1")
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200, user_id="u2")

        summary = store.get_usage(user_id="u1")
        assert len(summary.records) == 1
        assert summary.records[0].user_id == "u1"

    def test_get_usage_filters_by_department(self):
        store = InMemoryUsageStore()
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200, department_id="d1")
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200, department_id="d2")

        summary = store.get_usage(department_id="d1")
        assert len(summary.records) == 1
        assert summary.records[0].department_id == "d1"

    def test_get_usage_by_user(self):
        store = InMemoryUsageStore()
        for user_id, department_id in (("u1", "d1"), ("u2", "d2")):
            store.record(
                "agent1",
                "gpt-4",
                _usage(),
                cost=0.05,
                latency_ms=200,
                user_id=user_id,
                department_id=department_id,
            )

        summary = store.get_usage_by_user("u1")
        assert len(summary.records) == 1
        assert summary.records[0].user_id == "u1"
        assert summary.total_tokens == 150

    def test_get_usage_by_department(self):
        store = InMemoryUsageStore()
        for user_id, department_id in (("u1", "d1"), ("u2", "d2")):
            store.record(
                "agent1",
                "gpt-4",
                _usage(),
                cost=0.05,
                latency_ms=200,
                user_id=user_id,
                department_id=department_id,
            )

        summary = store.get_usage_by_department("d1")
        assert len(summary.records) == 1
        assert summary.records[0].department_id == "d1"
        assert summary.total_tokens == 150

    def test_summary_includes_by_user_and_by_department(self):
        store = InMemoryUsageStore()
        for _ in range(2):
            store.record(
                "agent1",
                "gpt-4",
                _usage(),
                cost=0.05,
                latency_ms=200,
                user_id="u1",
                department_id="d1",
            )

        summary = store.get_usage()
        assert "u1" in summary.by_user
        assert summary.by_user["u1"]["count"] == 2
        assert summary.by_user["u1"]["total_tokens"] == 300
        assert "d1" in summary.by_department
        assert summary.by_department["d1"]["count"] == 2


# ---------------------------------------------------------------------------
# UsageRecord / UsageBucket / UsageSummary dataclasses
# ---------------------------------------------------------------------------


class TestDataclasses:
    def test_usage_record_auto_timestamp(self):
        rec = UsageRecord(
            agent_name="a",
            model="m",
            prompt_tokens=1,
            completion_tokens=1,
            total_tokens=2,
            cost=0.01,
            latency_ms=100,
        )
        assert rec.timestamp != ""

    def test_usage_record_explicit_timestamp(self):
        rec = UsageRecord(
            agent_name="a",
            model="m",
            prompt_tokens=1,
            completion_tokens=1,
            total_tokens=2,
            cost=0.01,
            latency_ms=100,
            timestamp="2026-01-01T00:00:00+00:00",
        )
        assert rec.timestamp == "2026-01-01T00:00:00+00:00"

    def test_usage_bucket_defaults(self):
        bucket = UsageBucket()
        assert bucket.prompt_tokens == 0
        assert bucket.completion_tokens == 0
        assert bucket.total_tokens == 0
        assert bucket.cost == 0.0
        assert bucket.count == 0

    def test_usage_summary_defaults(self):
        summary = UsageSummary()
        assert summary.total_tokens == 0
        assert summary.total_cost == 0.0
        assert summary.by_model == {}
        assert summary.records == []


# ---------------------------------------------------------------------------
# RedisUsageStore (degraded mode and key helpers; no Redis client mocking)
# ---------------------------------------------------------------------------


class TestRedisUsageStoreMocked:
    """Tests for RedisUsageStore's degraded mode and key helpers.

    These tests force the degraded state or call pure key helpers directly,
    so they do not install a mocked Redis client.
    """

    def _make_store(self):
        return RedisUsageStore(redis_url="redis://localhost:6379")

    def test_implements_protocol(self):
        store = self._make_store()
        assert isinstance(store, UsageStore)

    def test_degrade_to_fallback(self):
        store = self._make_store()
        assert not store._degraded

        store._degrade_to_fallback()
        assert store._degraded
        assert store._fallback is not None

    def test_record_degraded_uses_fallback(self):
        store = self._make_store()
        store._degrade_to_fallback()
        store.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200)

        # Should be in the fallback store.
        summary = store._fallback.get_usage()
        assert len(summary.records) == 1

    def test_get_usage_degraded_raises_unavailable(self):
        """Fail-closed (KTD-1): degraded get_usage raises UsageStoreUnavailableError."""
        store = self._make_store()
        store._degrade_to_fallback()
        with pytest.raises(UsageStoreUnavailableError):
            store.get_usage()

    def test_get_usage_degraded_raises_even_with_fallback_data(self):
        """Even if the fallback has data, get_usage must fail closed (no quota bypass)."""
        store = self._make_store()
        store._degrade_to_fallback()
        store._fallback.record("agent1", "gpt-4", _usage(), cost=0.05, latency_ms=200)
        with pytest.raises(UsageStoreUnavailableError):
            store.get_usage()

    def test_get_usage_degraded_with_user_filter_raises(self):
        """Fail-closed applies even when filtering by user_id."""
        store = self._make_store()
        store._degrade_to_fallback()
        with pytest.raises(UsageStoreUnavailableError):
            store.get_usage(user_id="u1")

    def test_get_usage_by_user_degraded_raises(self):
        store = self._make_store()
        store._degrade_to_fallback()
        with pytest.raises(UsageStoreUnavailableError):
            store.get_usage_by_user("u1")

    def test_get_usage_by_department_degraded_raises(self):
        store = self._make_store()
        store._degrade_to_fallback()
        with pytest.raises(UsageStoreUnavailableError):
            store.get_usage_by_department("d1")

    def test_today_key_format(self):
        store = self._make_store()
        key = store._today_key()
        # Should be YYYY-MM-DD (zero-padded, hence exactly 10 characters).
        assert len(key) == 10
        datetime.strptime(key, "%Y-%m-%d")

    def test_v2_keys_with_user_and_department(self):
        store = self._make_store()
        hash_key, list_key = store._v2_keys("2026-06-21", "u1", "d1")
        assert hash_key == "agentkit:usage:v2:2026-06-21:u1:d1"
        assert list_key == "agentkit:usage_records:v2:2026-06-21:u1:d1"

    def test_v2_keys_with_none_user_and_department(self):
        store = self._make_store()
        hash_key, list_key = store._v2_keys("2026-06-21", None, None)
        # None values are normalized to "none" in the key.
        assert hash_key == "agentkit:usage:v2:2026-06-21:none:none"
        assert list_key == "agentkit:usage_records:v2:2026-06-21:none:none"

    def test_record_degraded_with_user_and_department(self):
        store = self._make_store()
        store._degrade_to_fallback()
        store.record(
            "agent1",
            "gpt-4",
            _usage(),
            cost=0.05,
            latency_ms=200,
            user_id="u1",
            department_id="d1",
        )

        # Should be in the fallback with user/department attached.
        summary = store._fallback.get_usage()
        assert len(summary.records) == 1
        assert summary.records[0].user_id == "u1"
        assert summary.records[0].department_id == "d1"

    def test_record_async_degraded_uses_fallback(self):
        """Async record in degraded state uses the fallback (recording is allowed)."""
        store = self._make_store()
        store._degrade_to_fallback()
        usage = _usage()

        async def _run():
            await store.record_async("agent1", "gpt-4", usage, cost=0.05, latency_ms=200)

        asyncio.run(_run())

        summary = store._fallback.get_usage()
        assert len(summary.records) == 1


# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------


class TestCreateUsageStore:
    def test_memory_backend(self):
        store = create_usage_store(backend="memory")
        assert isinstance(store, InMemoryUsageStore)

    def test_auto_backend_returns_store(self):
        store = create_usage_store(backend="auto")
        assert isinstance(store, (InMemoryUsageStore, RedisUsageStore))

    def test_redis_backend_returns_store(self):
        store = create_usage_store(backend="redis")
        # May be InMemory if the redis package is unavailable.
        assert isinstance(store, (InMemoryUsageStore, RedisUsageStore))


# ---------------------------------------------------------------------------
# U1: Key construction fix — SCAN patterns for partial scope queries
# ---------------------------------------------------------------------------


class TestRedisUsageStoreKeyConstruction:
    """U1: Verify get_usage builds correct SCAN patterns for partial scopes.

    A partial scope means only user_id OR only department_id is provided
    (not both). Previously, ``get_usage(department_id=X)`` constructed the key
    ``...:none:X``, which only matched records with no user — missing every
    record from actual users in that department. The fix uses SCAN with the
    pattern ``...:*:X`` to aggregate across all users.
    """

    def _make_store_with_mock_redis(self, mock_redis):
        store = RedisUsageStore(redis_url="redis://localhost:6379")
        store._sync_redis = mock_redis
        return store

    def test_get_usage_by_department_scans_all_users(self):
        """get_usage(department_id=X) should SCAN ...:*:X (all users)."""
        mock_redis = MagicMock()
        # Single-day range, so scan_iter is called once per scope.
        mock_redis.scan_iter.return_value = iter(
            [
                "agentkit:usage_records:v2:2026-06-21:u1:d1",
                "agentkit:usage_records:v2:2026-06-21:u2:d1",
                "agentkit:usage_records:v2:2026-06-21:none:d1",
            ]
        )
        # lrange returns records for each v2 key, then the legacy v1 key.
        mock_redis.lrange.side_effect = [
            [_record_json(prompt_tokens=100, completion_tokens=50, cost=0.05,
                          user_id="u1", department_id="d1")],  # key u1:d1
            [_record_json(prompt_tokens=200, completion_tokens=100, cost=0.10,
                          user_id="u2", department_id="d1")],  # key u2:d1
            [_record_json(prompt_tokens=50, completion_tokens=25, cost=0.02,
                          user_id=None, department_id="d1")],  # key none:d1
            [],  # legacy v1 key
        ]
        store = self._make_store_with_mock_redis(mock_redis)

        summary = store.get_usage(department_id="d1", start_time=_DAY, end_time=_HALF_DAY_END)

        # Should aggregate records from all users in d1.
        assert summary.total_tokens == 525  # 150 + 300 + 75
        assert "*:d1" in _scan_pattern(mock_redis)

    def test_get_usage_by_user_scans_all_departments(self):
        """get_usage(user_id=X) should SCAN ...:X:* (all departments)."""
        mock_redis = MagicMock()
        mock_redis.scan_iter.return_value = iter(
            [
                "agentkit:usage_records:v2:2026-06-21:u1:d1",
                "agentkit:usage_records:v2:2026-06-21:u1:d2",
            ]
        )
        mock_redis.lrange.side_effect = [
            [_record_json(prompt_tokens=100, completion_tokens=50, cost=0.05,
                          user_id="u1", department_id="d1")],
            [_record_json(prompt_tokens=200, completion_tokens=100, cost=0.10,
                          user_id="u1", department_id="d2")],
            [],  # legacy v1 key
        ]
        store = self._make_store_with_mock_redis(mock_redis)

        summary = store.get_usage(user_id="u1", start_time=_DAY, end_time=_HALF_DAY_END)

        assert summary.total_tokens == 450  # 150 + 300
        assert "u1:*" in _scan_pattern(mock_redis)

    def test_get_usage_both_user_and_dept_uses_direct_key(self):
        """get_usage(user_id=X, department_id=Y) → direct key lookup (no SCAN)."""
        mock_redis = MagicMock()
        mock_redis.lrange.side_effect = [
            [_record_json(prompt_tokens=100, completion_tokens=50, cost=0.05,
                          user_id="u1", department_id="d1")],  # direct v2 key
            [],  # legacy v1 key
        ]
        store = self._make_store_with_mock_redis(mock_redis)

        summary = store.get_usage(
            user_id="u1",
            department_id="d1",
            start_time=_DAY,
            end_time=_HALF_DAY_END,
        )

        assert summary.total_tokens == 150
        # SCAN should NOT be called (direct key lookup instead).
        mock_redis.scan_iter.assert_not_called()

    def test_get_usage_no_filter_scans_all(self):
        """get_usage() with no filters → SCAN ...:* (all records)."""
        mock_redis = MagicMock()
        mock_redis.scan_iter.return_value = iter(
            ["agentkit:usage_records:v2:2026-06-21:u1:d1"]
        )
        mock_redis.lrange.side_effect = [
            [_record_json(prompt_tokens=100, completion_tokens=50, cost=0.05,
                          user_id="u1", department_id="d1")],
            [],  # legacy v1 key
        ]
        store = self._make_store_with_mock_redis(mock_redis)

        summary = store.get_usage(start_time=_DAY, end_time=_HALF_DAY_END)

        assert summary.total_tokens == 150
        mock_redis.scan_iter.assert_called_once()

    def test_get_usage_empty_redis_returns_empty_summary(self):
        """No records in Redis → empty UsageSummary (not an error)."""
        mock_redis = MagicMock()
        mock_redis.scan_iter.return_value = iter([])
        mock_redis.lrange.return_value = []
        store = self._make_store_with_mock_redis(mock_redis)

        summary = store.get_usage(department_id="d1", start_time=_DAY, end_time=_HALF_DAY_END)

        assert summary.total_tokens == 0
        assert len(summary.records) == 0


# ---------------------------------------------------------------------------
# U1: Degradation recovery — health check clears degraded state
# ---------------------------------------------------------------------------


class TestRedisUsageStoreDegradationRecovery:
    """U1 (KTD-5): Redis recovery clears degraded state via health check."""

    def test_health_check_clears_degraded_on_recovery(self):
        """When Redis recovers, _degraded is cleared and _fallback discarded."""
        store = RedisUsageStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        assert store._degraded
        assert store._fallback is not None

        # The async Redis client's ping succeeds.
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(return_value=True)

        asyncio.run(
            _run_health_check_until(
                store,
                mock_redis,
                done=lambda: not store._degraded and store._fallback is None,
            )
        )

        assert not store._degraded
        assert store._fallback is None

    def test_health_check_keeps_degraded_on_failure(self):
        """When Redis is still down, the degraded state persists."""
        store = RedisUsageStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()

        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(side_effect=ConnectionError("still down"))

        # Wait until the loop has probed Redis at least twice (or times out).
        asyncio.run(
            _run_health_check_until(
                store,
                mock_redis,
                done=lambda: mock_redis.ping.call_count >= 2,
            )
        )

        assert store._degraded  # Still degraded

    def test_aclose_cancels_health_check_task(self):
        """aclose() cancels the health check task."""
        store = RedisUsageStore(redis_url="redis://localhost:6379")

        async def _run():
            # Degrading inside a running loop is what schedules the task.
            store._degrade_to_fallback()
            assert store._health_check_task is not None
            await store.aclose()
            assert store._health_check_task is None

        asyncio.run(_run())


# ---------------------------------------------------------------------------
# U1: Fail-closed — get_usage raises when Redis query fails
# ---------------------------------------------------------------------------


class TestRedisUsageStoreFailClosed:
    """U1 (KTD-1): Redis query failure → UsageStoreUnavailableError (not empty)."""

    @staticmethod
    def _make_store_with_broken_redis():
        mock_redis = MagicMock()
        mock_redis.scan_iter.side_effect = ConnectionError("Redis gone")
        store = RedisUsageStore(redis_url="redis://localhost:6379")
        store._sync_redis = mock_redis
        return store

    def test_get_usage_raises_on_redis_connection_failure(self):
        """When the Redis connection fails during get_usage, raise (fail-closed)."""
        store = self._make_store_with_broken_redis()

        with pytest.raises(UsageStoreUnavailableError):
            store.get_usage(
                department_id="d1",
                start_time=_DAY,
                end_time=_DAY + timedelta(days=1),
            )
        # Should also degrade for future calls.
        assert store._degraded

    def test_get_usage_fail_closed_not_empty_summary(self):
        """Critical: must NOT return an empty summary on Redis failure (that's fail-open)."""
        store = self._make_store_with_broken_redis()

        # Must raise, not return UsageSummary().
        with pytest.raises(UsageStoreUnavailableError):
            store.get_usage(start_time=_DAY, end_time=_DAY + timedelta(days=1))


# ---------------------------------------------------------------------------
# U1: Async record — record_async uses redis.asyncio
# ---------------------------------------------------------------------------


class TestRedisUsageStoreAsyncRecord:
    """U1 (KTD-6): record_async uses the async Redis client, not blocking the event loop."""

    def test_record_async_writes_to_redis(self):
        """record_async should write to Redis via the async client's pipeline."""
        store = RedisUsageStore(redis_url="redis://localhost:6379")
        mock_pipeline = MagicMock()
        mock_pipeline.execute = AsyncMock(return_value=[True])
        mock_redis = MagicMock()
        mock_redis.pipeline.return_value = mock_pipeline

        async def _run():
            with patch.object(store, "_get_redis", return_value=mock_redis):
                await store.record_async(
                    "agent1",
                    "gpt-4",
                    _usage(),
                    cost=0.05,
                    latency_ms=200,
                    user_id="u1",
                    department_id="d1",
                )

        asyncio.run(_run())

        # Verify the pipeline was used.
        mock_redis.pipeline.assert_called_once()
        mock_pipeline.hincrbyfloat.assert_called_once()
        mock_pipeline.rpush.assert_called_once()