"""RedisTaskStore unit tests - uses an in-memory fake Redis (no real Redis required)."""

import builtins
import fnmatch
import json
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.core.protocol import TaskStatus
from agentkit.server.task_store import (
    InMemoryTaskStore,
    RedisTaskStore,
    TaskRecord,
    TaskStore,
    create_task_store,
)


# ═══════════════════════════════════════════════════════════
# Helpers – lightweight fake Redis for unit tests
# ═══════════════════════════════════════════════════════════


class FakeRedis:
    """Minimal in-memory fake that satisfies the RedisTaskStore interface.

    Only the commands these tests exercise are implemented, backed by a plain
    dict. Key expiry is accepted in signatures but never simulated; tests
    emulate TTL expiry by removing entries from ``_data`` directly.
    """

    def __init__(self):
        self._data: dict[str, str] = {}

    @classmethod
    def from_url(cls, url, **kwargs):
        """Stand-in for ``Redis.from_url``; the URL and options are ignored."""
        return cls()

    async def get(self, key):
        """Return the stored value for *key*, or None when absent."""
        return self._data.get(key)

    async def set(self, key, value, ex=None, **kwargs):
        """Store *value* under *key*; ``ex`` (TTL) is accepted but not enforced."""
        self._data[key] = value

    async def delete(self, *keys):
        """Delete one or more keys and return how many existed (like Redis DEL)."""
        removed = 0
        for key in keys:
            if self._data.pop(key, None) is not None:
                removed += 1
        return removed

    async def mget(self, keys):
        """Return the values for *keys* in order, with None for missing keys."""
        return [self._data.get(k) for k in keys]

    async def scan(self, cursor=0, match=None, count=200):
        """Simplified SCAN – returns all matching keys in one batch.

        ``cursor`` and ``count`` are ignored; the returned cursor is always 0,
        which callers interpret as "iteration finished".
        """
        pattern = match or "*"
        matched = [k for k in self._data if fnmatch.fnmatch(k, pattern)]
        return (0, matched)

    async def close(self):
        """No-op; there is no connection to release."""

    async def eval(self, script, numkeys, *args):
        """Simulate Redis EVAL for the update_status Lua script.

        Per the original test comment, this mirrors ``_UPDATE_STATUS_SCRIPT``
        in RedisTaskStore. ``script`` and ``numkeys`` are ignored. ``args`` is
        read as ``(key, ttl, n, field_1, value_1, ..., field_n, value_n)``.
        Each listed field of the JSON document at ``key`` is overwritten, and
        the document is re-encoded, stored and returned.

        Returns:
            The updated JSON string, or None when ``key`` does not exist.
        """
        key = args[0]
        _ttl = int(args[1])  # parsed for validation only; expiry is not simulated
        n = int(args[2])

        raw = self._data.get(key)
        if raw is None:
            return None

        data = json.loads(raw)
        for i in range(n):
            field = args[3 + 2 * i]
            value = args[4 + 2 * i]
            # Values arrive encoded. Decode them when they are valid JSON, so
            # dicts, lists and numbers round-trip. Anything else (e.g. a bare
            # status string, or None) is stored verbatim.
            try:
                data[field] = json.loads(value)
            except (json.JSONDecodeError, TypeError):
                data[field] = value

        encoded = json.dumps(data)
        self._data[key] = encoded
        return encoded


def _make_redis_store(fake_redis: FakeRedis | None = None) -> RedisTaskStore:
    """Build a RedisTaskStore with a FakeRedis injected.

    Args:
        fake_redis: Fake client to inject; a fresh, empty FakeRedis is used
            when None.

    Returns:
        A RedisTaskStore whose ``_redis`` client has been replaced by the fake.
    """
    store = RedisTaskStore(redis_url="redis://fake/0")
    if fake_redis is None:
        fake_redis = FakeRedis()
    store._redis = fake_redis
    return store


def _shift_created_at(
    fake: FakeRedis, store: RedisTaskStore, task_id: str, delta: timedelta
) -> None:
    """Move the stored ``created_at`` of *task_id* by *delta*, in place in *fake*.

    Ordering tests use this so they do not depend on two consecutive
    ``datetime.now()`` calls producing different values. That is not
    guaranteed on platforms with a coarse system clock. The stored timestamp
    is parsed and re-serialised, so its tz-awareness is preserved.
    NOTE(review): assumes records are stored as JSON objects whose
    ``created_at`` is an ISO-8601 string (the format ``TaskRecord.from_dict``
    accepts in the tests below).
    """
    key = store._key(task_id)
    data = json.loads(fake._data[key])
    shifted = datetime.fromisoformat(data["created_at"]) + delta
    data["created_at"] = shifted.isoformat()
    fake._data[key] = json.dumps(data)


# ═══════════════════════════════════════════════════════════
# TaskRecord.from_dict round-trip
# ═══════════════════════════════════════════════════════════


class TestTaskRecordRoundTrip:
    """Verify TaskRecord serialisation / deserialisation."""

    def test_to_dict_from_dict_round_trip(self):
        now = datetime.now(timezone.utc)
        record = TaskRecord(
            task_id="t1",
            agent_name="agent_a",
            skill_name="skill_x",
            input_data={"query": "hello"},
            status=TaskStatus.RUNNING,
            output_data={"result": "world"},
            error_message=None,
            created_at=now,
            started_at=now,
            completed_at=None,
            progress=0.5,
            progress_message="Halfway",
            metadata={"key": "val"},
        )
        restored = TaskRecord.from_dict(record.to_dict())
        assert restored.task_id == record.task_id
        assert restored.agent_name == record.agent_name
        assert restored.skill_name == record.skill_name
        assert restored.input_data == record.input_data
        assert restored.status == record.status
        assert restored.output_data == record.output_data
        assert restored.error_message == record.error_message
        # Timestamps must survive serialisation too, not just the payload fields.
        assert restored.created_at == record.created_at
        assert restored.started_at == record.started_at
        assert restored.completed_at is None
        assert restored.progress == record.progress
        assert restored.progress_message == record.progress_message
        assert restored.metadata == record.metadata

    def test_from_dict_with_none_fields(self):
        data = {
            "task_id": "t2",
            "agent_name": "b",
            "skill_name": None,
            "input_data": {},
            "status": "pending",
            "output_data": None,
            "error_message": None,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "started_at": None,
            "completed_at": None,
            "progress": 0.0,
            "progress_message": "",
            "metadata": {},
        }
        record = TaskRecord.from_dict(data)
        assert record.skill_name is None
        assert record.started_at is None
        assert record.completed_at is None


# ═══════════════════════════════════════════════════════════
# RedisTaskStore – happy path
# ═══════════════════════════════════════════════════════════


class TestRedisTaskStoreHappyPath:
    """Core CRUD operations on RedisTaskStore with mock Redis."""

    @pytest.mark.asyncio
    async def test_create_and_get(self):
        store = _make_redis_store()
        record = await store.create("t1", "agent_a", {"q": "hello"}, skill_name="skill_x")
        assert record.task_id == "t1"
        assert record.agent_name == "agent_a"
        assert record.skill_name == "skill_x"
        assert record.input_data == {"q": "hello"}
        assert record.status == TaskStatus.PENDING

        fetched = await store.get("t1")
        assert fetched is not None
        assert fetched.task_id == "t1"
        assert fetched.agent_name == "agent_a"

    @pytest.mark.asyncio
    async def test_update_status_changes_fields(self):
        store = _make_redis_store()
        await store.create("t1", "agent_a", {})
        now = datetime.now(timezone.utc)
        updated = await store.update_status(
            "t1",
            TaskStatus.RUNNING,
            started_at=now,
            progress=0.5,
            progress_message="Halfway",
        )
        assert updated.status == TaskStatus.RUNNING
        assert updated.progress == 0.5
        assert updated.progress_message == "Halfway"

        # Verify persistence
        fetched = await store.get("t1")
        assert fetched is not None
        assert fetched.status == TaskStatus.RUNNING
        assert fetched.progress == 0.5

    @pytest.mark.asyncio
    async def test_list_tasks_sorted_by_created_at_desc(self):
        fake = FakeRedis()
        store = _make_redis_store(fake)
        await store.create("t1", "agent_a", {})
        await store.create("t2", "agent_b", {})
        # Make t1 unambiguously older. Two back-to-back creates can otherwise
        # share a timestamp on coarse clocks, and the stable sort would then
        # keep t1 first.
        _shift_created_at(fake, store, "t1", timedelta(hours=-1))

        tasks = await store.list_tasks()
        assert len(tasks) == 2
        # Most recent first (t2 created after t1)
        assert tasks[0].task_id == "t2"
        assert tasks[1].task_id == "t1"

    @pytest.mark.asyncio
    async def test_list_tasks_filtered_by_status(self):
        store = _make_redis_store()
        await store.create("t1", "agent_a", {})
        await store.create("t2", "agent_b", {})
        await store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))

        tasks = await store.list_tasks(status=TaskStatus.COMPLETED)
        assert len(tasks) == 1
        assert tasks[0].task_id == "t1"

    @pytest.mark.asyncio
    async def test_list_tasks_respects_limit(self):
        store = _make_redis_store()
        for i in range(5):
            await store.create(f"t{i}", "agent_a", {})

        tasks = await store.list_tasks(limit=3)
        assert len(tasks) == 3

    @pytest.mark.asyncio
    async def test_size_returns_count(self):
        store = _make_redis_store()
        assert await store.size == 0
        await store.create("t1", "agent_a", {})
        assert await store.size == 1
        await store.create("t2", "agent_b", {})
        assert await store.size == 2

    @pytest.mark.asyncio
    async def test_start_cleanup_is_noop(self):
        store = _make_redis_store()
        # Should not raise
        await store.start_cleanup()

    @pytest.mark.asyncio
    async def test_stop_cleanup_closes_redis(self):
        fake = FakeRedis()
        store = _make_redis_store(fake)
        await store.stop_cleanup()
        assert store._redis is None


# ═══════════════════════════════════════════════════════════
# RedisTaskStore – error / edge cases
# ═══════════════════════════════════════════════════════════


class TestRedisTaskStoreErrors:
    """Error and edge-case handling."""

    @pytest.mark.asyncio
    async def test_get_nonexistent_returns_none(self):
        store = _make_redis_store()
        result = await store.get("nonexistent")
        assert result is None

    @pytest.mark.asyncio
    async def test_update_status_nonexistent_raises_keyerror(self):
        store = _make_redis_store()
        with pytest.raises(KeyError, match="not found"):
            await store.update_status("nonexistent", TaskStatus.RUNNING)

    @pytest.mark.asyncio
    async def test_max_records_evicts_oldest_completed(self):
        fake = FakeRedis()
        store = _make_redis_store(fake)
        store._max_records = 2

        await store.create("t1", "agent_a", {})
        await store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
        await store.create("t2", "agent_b", {})

        # t3 should evict t1 (oldest completed)
        await store.create("t3", "agent_c", {})
        assert await store.get("t1") is None
        assert await store.get("t2") is not None
        assert await store.get("t3") is not None

    @pytest.mark.asyncio
    async def test_max_records_full_no_completed_raises(self):
        fake = FakeRedis()
        store = _make_redis_store(fake)
        store._max_records = 1

        await store.create("t1", "agent_a", {})
        # All tasks are PENDING, no completed to evict
        with pytest.raises(RuntimeError, match="full"):
            await store.create("t2", "agent_b", {})


# ═══════════════════════════════════════════════════════════
# TTL expiry (simulated by removing key from fake Redis)
# ═══════════════════════════════════════════════════════════


class TestRedisTaskStoreTTL:
    """Simulate TTL expiry by manually removing keys from FakeRedis."""

    @pytest.mark.asyncio
    async def test_expired_key_returns_none(self):
        fake = FakeRedis()
        store = _make_redis_store(fake)
        await store.create("t1", "agent_a", {})

        # Simulate TTL expiry: remove key from fake Redis
        fake._data.pop(store._key("t1"))

        result = await store.get("t1")
        assert result is None


# ═══════════════════════════════════════════════════════════
# create_task_store factory
# ═══════════════════════════════════════════════════════════


class TestCreateTaskStore:
    """Factory function tests."""

    def test_default_backend_is_memory(self):
        store = create_task_store()
        assert isinstance(store, InMemoryTaskStore)

    def test_explicit_memory_backend(self):
        store = create_task_store(backend="memory")
        assert isinstance(store, InMemoryTaskStore)

    def test_redis_backend_returns_redis_task_store(self):
        store = create_task_store(backend="redis", redis_url="redis://localhost:6379/0")
        assert isinstance(store, RedisTaskStore)

    def test_redis_unavailable_falls_back_to_memory(self):
        """If redis.asyncio import fails, factory falls back to InMemoryTaskStore."""
        real_import = builtins.__import__

        def _import_without_redis(name, *args, **kwargs):
            # Make only the redis package unimportable. Every other import
            # keeps working, so the fallback cannot be triggered (and this
            # test cannot pass) because of an unrelated ImportError.
            if name == "redis" or name.startswith("redis."):
                raise ImportError("no redis")
            return real_import(name, *args, **kwargs)

        # A None entry in sys.modules also makes `import redis.asyncio` raise.
        with patch.dict("sys.modules", {"redis.asyncio": None}), patch(
            "builtins.__import__", side_effect=_import_without_redis
        ):
            store = create_task_store(backend="redis")
        assert isinstance(store, InMemoryTaskStore)

    def test_backward_compat_alias(self):
        """TaskStore is an alias for InMemoryTaskStore."""
        assert TaskStore is InMemoryTaskStore