fischer-agentkit/tests/unit/test_task_store_redis.py

374 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""RedisTaskStore unit tests - uses mock Redis (no real Redis required)"""
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.core.protocol import TaskStatus
from agentkit.server.task_store import (
InMemoryTaskStore,
RedisTaskStore,
TaskRecord,
TaskStore,
create_task_store,
)
# ═══════════════════════════════════════════════════════════
# Helpers: lightweight fake Redis for unit tests
# ═══════════════════════════════════════════════════════════
class FakeRedis:
    """Minimal in-memory fake that satisfies the RedisTaskStore interface.

    Plain string values live in ``_data`` and sorted sets in ``_zsets``.
    Expiry options (``ex`` on SET, the TTL arguments passed to ``eval``) are
    accepted but never enforced; tests simulate expiry by removing keys from
    ``_data`` directly.
    """

    def __init__(self):
        self._data: dict[str, str] = {}
        self._zsets: dict[str, dict[str, float]] = {}

    @classmethod
    def from_url(cls, url, **kwargs):
        """Mirror ``Redis.from_url``; the URL and options are ignored."""
        return cls()

    # ── String / key operations ────────────────────────────

    async def get(self, key):
        return self._data.get(key)

    async def set(self, key, value, ex=None, **kwargs):
        # ``ex`` and any other SET options are accepted but ignored.
        self._data[key] = value

    async def delete(self, *keys):
        """Delete *keys* of any type and return how many existed, like ``DEL``.

        Accepts one or more keys (single-key calls keep working) and removes
        sorted sets as well as plain values, since real Redis keys share one
        namespace regardless of type.
        """
        removed = 0
        for key in keys:
            existed = key in self._data or key in self._zsets
            self._data.pop(key, None)
            self._zsets.pop(key, None)
            if existed:
                removed += 1
        return removed

    async def mget(self, keys):
        return [self._data.get(k) for k in keys]

    async def scan(self, cursor=0, match=None, count=200):
        """Simplified SCAN: return every matching key in a single batch.

        The returned cursor is always 0, which callers treat as "done".
        """
        import fnmatch

        pattern = match or "*"
        matched = [k for k in self._data if fnmatch.fnmatch(k, pattern)]
        return (0, matched)

    async def close(self):
        pass

    async def ping(self):
        return True

    # ── Sorted-set operations ──────────────────────────────

    async def zadd(self, name, mapping):
        """Add/update members; return the number of *newly added* members."""
        zs = self._zsets.setdefault(name, {})
        added = 0
        for member, score in mapping.items():
            if member not in zs:
                added += 1
            zs[member] = score
        return added

    async def zcard(self, name):
        return len(self._zsets.get(name, {}))

    async def zrange(self, name, start, end):
        """Return members by ascending score with an inclusive *end* index."""
        zs = self._zsets.get(name, {})
        # Sort by score, then by member for deterministic order on ties.
        sorted_members = sorted(zs.keys(), key=lambda m: (zs[m], m))
        # end == -1 needs special-casing: ``end + 1`` would be 0, an empty slice.
        if end == -1:
            return sorted_members[start:]
        return sorted_members[start : end + 1]

    async def zrem(self, name, *members):
        """Remove *members*; return how many were actually present."""
        zs = self._zsets.get(name, {})
        removed = 0
        for m in members:
            if m in zs:
                del zs[m]
                removed += 1
        return removed

    async def eval(self, script, numkeys, *args):
        """Simulate Redis EVAL for the update_status Lua script.

        This implements the same logic as _UPDATE_STATUS_SCRIPT in
        RedisTaskStore. The *script* text and *numkeys* are ignored; *args*
        is read positionally as::

            (key, reset_ttl, ttl, n, field_1, value_1, ..., field_n, value_n)

        ``reset_ttl`` and ``ttl`` have no effect here because the fake does
        not model expiry. Returns the re-encoded JSON record, or None when
        *key* does not exist.
        """
        key = args[0]
        reset_ttl = args[1]
        ttl = int(args[2])
        n = int(args[3])
        raw = self._data.get(key)
        if raw is None:
            return None
        data = json.loads(raw)
        for i in range(n):
            k = args[4 + 2 * i]
            v = args[5 + 2 * i]
            # Try to parse JSON values (numbers, dicts, lists, null); anything
            # that is not valid JSON (e.g. a bare status string) stays a string.
            try:
                data[k] = json.loads(v)
            except (json.JSONDecodeError, TypeError):
                data[k] = v
        encoded = json.dumps(data)
        self._data[key] = encoded
        return encoded
def _make_redis_store(fake_redis: FakeRedis | None = None) -> RedisTaskStore:
    """Return a RedisTaskStore whose client is a FakeRedis.

    A fresh FakeRedis is used when *fake_redis* is omitted; pass one in when
    the test needs to inspect or tamper with the underlying data.
    """
    store = RedisTaskStore(redis_url="redis://fake/0")
    # Inject the fake client by assigning the store's private ``_redis`` slot.
    store._redis = fake_redis if fake_redis is not None else FakeRedis()
    return store
# ═══════════════════════════════════════════════════════════
# TaskRecord.from_dict round-trip
# ═══════════════════════════════════════════════════════════
class TestTaskRecordRoundTrip:
    """Verify TaskRecord serialisation / deserialisation."""

    def test_to_dict_from_dict_round_trip(self):
        now = datetime.now(timezone.utc)
        record = TaskRecord(
            task_id="t1",
            agent_name="agent_a",
            skill_name="skill_x",
            input_data={"query": "hello"},
            status=TaskStatus.RUNNING,
            output_data={"result": "world"},
            error_message=None,
            created_at=now,
            started_at=now,
            completed_at=None,
            progress=0.5,
            progress_message="Halfway",
            metadata={"key": "val"},
        )
        restored = TaskRecord.from_dict(record.to_dict())
        assert restored.task_id == record.task_id
        assert restored.agent_name == record.agent_name
        assert restored.skill_name == record.skill_name
        assert restored.input_data == record.input_data
        assert restored.status == record.status
        assert restored.output_data == record.output_data
        assert restored.progress == record.progress
        assert restored.progress_message == record.progress_message
        assert restored.metadata == record.metadata
        # Timestamps and nullable fields are the most fragile part of the
        # (de)serialisation, so check that they survive the round-trip too.
        assert restored.error_message is None
        assert restored.created_at == record.created_at
        assert restored.started_at == record.started_at
        assert restored.completed_at is None

    def test_from_dict_with_none_fields(self):
        data = {
            "task_id": "t2",
            "agent_name": "b",
            "skill_name": None,
            "input_data": {},
            "status": "pending",
            "output_data": None,
            "error_message": None,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "started_at": None,
            "completed_at": None,
            "progress": 0.0,
            "progress_message": "",
            "metadata": {},
        }
        record = TaskRecord.from_dict(data)
        assert record.task_id == "t2"
        # The raw status string must be parsed back into the enum member.
        assert record.status == TaskStatus.PENDING
        assert record.skill_name is None
        assert record.output_data is None
        assert record.started_at is None
        assert record.completed_at is None
# ═══════════════════════════════════════════════════════════
# RedisTaskStore: happy path
# ═══════════════════════════════════════════════════════════
class TestRedisTaskStoreHappyPath:
    """Core CRUD operations on RedisTaskStore with mock Redis."""

    @pytest.mark.asyncio
    async def test_create_and_get(self):
        store = _make_redis_store()
        record = await store.create("t1", "agent_a", {"q": "hello"}, skill_name="skill_x")
        assert record.task_id == "t1"
        assert record.agent_name == "agent_a"
        assert record.skill_name == "skill_x"
        assert record.input_data == {"q": "hello"}
        assert record.status == TaskStatus.PENDING
        fetched = await store.get("t1")
        assert fetched is not None
        assert fetched.task_id == "t1"
        assert fetched.agent_name == "agent_a"
        # Everything set at creation must also survive the trip through Redis.
        assert fetched.skill_name == "skill_x"
        assert fetched.input_data == {"q": "hello"}
        assert fetched.status == TaskStatus.PENDING

    @pytest.mark.asyncio
    async def test_update_status_changes_fields(self):
        store = _make_redis_store()
        await store.create("t1", "agent_a", {})
        now = datetime.now(timezone.utc)
        updated = await store.update_status(
            "t1", TaskStatus.RUNNING, started_at=now, progress=0.5, progress_message="Halfway",
        )
        assert updated.status == TaskStatus.RUNNING
        assert updated.progress == 0.5
        assert updated.progress_message == "Halfway"
        # Verify persistence of every updated field, not just the returned copy.
        fetched = await store.get("t1")
        assert fetched is not None
        assert fetched.status == TaskStatus.RUNNING
        assert fetched.progress == 0.5
        assert fetched.progress_message == "Halfway"
        assert fetched.started_at is not None

    @pytest.mark.asyncio
    async def test_list_tasks_sorted_by_created_at_desc(self):
        store = _make_redis_store()
        await store.create("t1", "agent_a", {})
        await store.create("t2", "agent_b", {})
        tasks = await store.list_tasks()
        assert len(tasks) == 2
        # Most recent first (t2 created after t1)
        assert tasks[0].task_id == "t2"
        assert tasks[1].task_id == "t1"

    @pytest.mark.asyncio
    async def test_list_tasks_filtered_by_status(self):
        store = _make_redis_store()
        await store.create("t1", "agent_a", {})
        await store.create("t2", "agent_b", {})
        await store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
        tasks = await store.list_tasks(status=TaskStatus.COMPLETED)
        assert len(tasks) == 1
        assert tasks[0].task_id == "t1"

    @pytest.mark.asyncio
    async def test_list_tasks_respects_limit(self):
        store = _make_redis_store()
        for i in range(5):
            await store.create(f"t{i}", "agent_a", {})
        tasks = await store.list_tasks(limit=3)
        assert len(tasks) == 3

    @pytest.mark.asyncio
    async def test_size_returns_count(self):
        store = _make_redis_store()
        assert await store.size == 0
        await store.create("t1", "agent_a", {})
        assert await store.size == 1
        await store.create("t2", "agent_b", {})
        assert await store.size == 2

    @pytest.mark.asyncio
    async def test_start_cleanup_is_noop(self):
        store = _make_redis_store()
        # Should not raise
        await store.start_cleanup()

    @pytest.mark.asyncio
    async def test_stop_cleanup_closes_redis(self):
        fake = FakeRedis()
        # Spy on close() so the test proves the client is actually closed,
        # not merely that the store dropped its reference to it.
        fake.close = AsyncMock()
        store = _make_redis_store(fake)
        await store.stop_cleanup()
        fake.close.assert_awaited_once()
        assert store._redis is None
# ═══════════════════════════════════════════════════════════
# RedisTaskStore: errors and edge cases
# ═══════════════════════════════════════════════════════════
class TestRedisTaskStoreErrors:
    """Failure modes and boundary conditions of RedisTaskStore."""

    @pytest.mark.asyncio
    async def test_get_nonexistent_returns_none(self):
        store = _make_redis_store()
        assert await store.get("nonexistent") is None

    @pytest.mark.asyncio
    async def test_update_status_nonexistent_raises_keyerror(self):
        store = _make_redis_store()
        with pytest.raises(KeyError, match="not found"):
            await store.update_status("nonexistent", TaskStatus.RUNNING)

    @pytest.mark.asyncio
    async def test_max_records_evicts_oldest_completed(self):
        store = _make_redis_store(FakeRedis())
        store._max_records = 2

        await store.create("t1", "agent_a", {})
        finished_at = datetime.now(timezone.utc)
        await store.update_status("t1", TaskStatus.COMPLETED, completed_at=finished_at)
        await store.create("t2", "agent_b", {})

        # The store is at capacity, so adding t3 has to evict t1, the oldest
        # (and only) completed task.
        await store.create("t3", "agent_c", {})

        assert await store.get("t1") is None
        assert await store.get("t2") is not None
        assert await store.get("t3") is not None

    @pytest.mark.asyncio
    async def test_max_records_full_no_completed_raises(self):
        store = _make_redis_store(FakeRedis())
        store._max_records = 1
        await store.create("t1", "agent_a", {})
        # The only stored task is still PENDING, so nothing is evictable.
        with pytest.raises(RuntimeError, match="full"):
            await store.create("t2", "agent_b", {})
# ═══════════════════════════════════════════════════════════
# TTL expiry (simulated by removing key from fake Redis)
# ═══════════════════════════════════════════════════════════
class TestRedisTaskStoreTTL:
    """Emulate Redis key expiry by deleting entries straight out of FakeRedis."""

    @pytest.mark.asyncio
    async def test_expired_key_returns_none(self):
        backend = FakeRedis()
        store = _make_redis_store(backend)
        await store.create("t1", "agent_a", {})
        # Drop the record behind the store's back, as Redis would once the TTL
        # lapses (raises KeyError if the record was never written).
        del backend._data[store._key("t1")]
        assert await store.get("t1") is None
# ═══════════════════════════════════════════════════════════
# create_task_store factory
# ═══════════════════════════════════════════════════════════
class TestCreateTaskStore:
    """Factory function tests."""

    def test_default_backend_is_memory(self):
        store = create_task_store()
        assert isinstance(store, InMemoryTaskStore)

    def test_explicit_memory_backend(self):
        store = create_task_store(backend="memory")
        assert isinstance(store, InMemoryTaskStore)

    def test_redis_backend_returns_redis_task_store(self):
        # Without the redis client installed the factory falls back to the
        # in-memory store, so this check is only meaningful when it is present.
        pytest.importorskip("redis.asyncio")
        store = create_task_store(backend="redis", redis_url="redis://localhost:6379/0")
        assert isinstance(store, RedisTaskStore)

    def test_redis_unavailable_falls_back_to_memory(self):
        """If redis.asyncio import fails, factory falls back to InMemoryTaskStore."""
        import builtins

        real_import = builtins.__import__

        def import_without_redis(name, *args, **kwargs):
            # Make only the redis client unimportable; every other import
            # (stdlib, logging, the store's own dependencies) keeps working.
            if name == "redis" or name.startswith("redis."):
                raise ImportError("no redis")
            return real_import(name, *args, **kwargs)

        # A None entry in sys.modules makes ``import redis.asyncio`` fail even
        # when the package is installed; the __import__ hook additionally
        # covers ``from redis import asyncio`` once the submodule is cached as
        # an attribute of the already-imported ``redis`` package.
        with patch.dict("sys.modules", {"redis.asyncio": None}):
            with patch("builtins.__import__", side_effect=import_without_redis):
                store = create_task_store(backend="redis")
        assert isinstance(store, InMemoryTaskStore)

    def test_backward_compat_alias(self):
        """TaskStore is an alias for InMemoryTaskStore."""
        assert TaskStore is InMemoryTaskStore