374 lines
14 KiB
Python
374 lines
14 KiB
Python
"""RedisTaskStore unit tests - uses mock Redis (no real Redis required)"""
|
||
|
||
import json
|
||
from datetime import datetime, timezone
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.protocol import TaskStatus
|
||
from agentkit.server.task_store import (
|
||
InMemoryTaskStore,
|
||
RedisTaskStore,
|
||
TaskRecord,
|
||
TaskStore,
|
||
create_task_store,
|
||
)
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════
|
||
# Helpers – lightweight fake Redis for unit tests
|
||
# ═══════════════════════════════════════════════════════════
|
||
|
||
|
||
class FakeRedis:
    """Minimal in-memory stand-in for the async Redis client used by RedisTaskStore.

    Supports only the commands these tests exercise:

    - string keys: ``get``/``set``/``delete``/``mget``/``scan``
    - sorted sets: ``zadd``/``zcard``/``zrange``/``zrem``
    - ``eval``, which emulates the store's update-status script
    - ``ping`` and ``close``

    TTLs are accepted but never enforced. Tests simulate expiry by removing
    entries from ``_data`` directly.
    """

    def __init__(self):
        # String keys -> stored (JSON-encoded) values.
        self._data: dict[str, str] = {}
        # Sorted-set name -> {member: score}.
        self._zsets: dict[str, dict[str, float]] = {}

    @classmethod
    def from_url(cls, url, **kwargs):
        """Mirror ``Redis.from_url``; the URL and options are ignored."""
        return cls()

    async def get(self, key):
        """Return the value stored under *key*, or ``None`` if absent."""
        return self._data.get(key)

    async def set(self, key, value, ex=None, nx=False, xx=False, **kwargs):
        """Store *value* under *key*.

        Like redis-py, return ``True`` when the value was written. Return
        ``None`` when an ``nx`` (only-if-absent) or ``xx`` (only-if-present)
        condition prevented the write. ``ex`` (TTL seconds) is accepted but
        not enforced.
        """
        exists = key in self._data
        if (nx and exists) or (xx and not exists):
            return None
        self._data[key] = value
        return True

    async def delete(self, *keys):
        """Delete every key in *keys*, whatever its type.

        Like Redis ``DEL``, removes string and sorted-set keys alike and
        returns the number of keys that existed.
        """
        removed = 0
        for key in keys:
            existed = key in self._data or key in self._zsets
            self._data.pop(key, None)
            self._zsets.pop(key, None)
            if existed:
                removed += 1
        return removed

    async def mget(self, keys):
        """Return the values for *keys* in order, with ``None`` for missing keys."""
        return [self._data.get(k) for k in keys]

    async def scan(self, cursor=0, match=None, count=200):
        """Simplified SCAN: return every matching string key in one batch.

        The returned cursor is always 0, which tells the caller that
        iteration is complete. Pattern matching is case-sensitive, as in
        Redis, on every OS.
        """
        import fnmatch

        pattern = match or "*"
        matched = [k for k in self._data if fnmatch.fnmatchcase(k, pattern)]
        return (0, matched)

    async def close(self):
        """No-op; there is no connection to release."""

    async def ping(self):
        """Always report the fake server as reachable."""
        return True

    # ── Sorted-set operations ──────────────────────────────

    async def zadd(self, name, mapping):
        """Add or update members, returning how many were newly added."""
        zs = self._zsets.setdefault(name, {})
        added = 0
        for member, score in mapping.items():
            if member not in zs:
                added += 1
            zs[member] = score
        return added

    async def zcard(self, name):
        """Return the number of members in sorted set *name* (0 if absent)."""
        return len(self._zsets.get(name, {}))

    async def zrange(self, name, start, end):
        """Return members ranked *start*..*end* inclusive, lowest score first.

        Ties are broken by member name so the order is deterministic.
        An *end* of -1 means "through the last member".
        """
        zs = self._zsets.get(name, {})
        sorted_members = sorted(zs.keys(), key=lambda m: (zs[m], m))
        if end == -1:
            # end + 1 would be 0 and yield an empty slice, so handle -1 explicitly.
            return sorted_members[start:]
        return sorted_members[start : end + 1]

    async def zrem(self, name, *members):
        """Remove *members* from sorted set *name*, returning how many were present."""
        zs = self._zsets.get(name, {})
        removed = 0
        for m in members:
            if m in zs:
                del zs[m]
                removed += 1
        return removed

    async def eval(self, script, numkeys, *args):
        """Emulate EVAL of RedisTaskStore's update-status Lua script.

        *script* is ignored. The logic below is meant to mirror
        ``_UPDATE_STATUS_SCRIPT`` in RedisTaskStore
        (NOTE(review): keep the two in sync).

        As in Redis, the first *numkeys* positional args are KEYS and the
        rest are ARGV. Expected layout:

        - ``KEYS[1]`` is the record key.
        - ``ARGV`` is ``reset_ttl, ttl, n`` followed by *n* field/value pairs.

        Returns the re-encoded JSON record, or ``None`` if the key is missing.
        """
        numkeys = int(numkeys)
        keys, argv = args[:numkeys], args[numkeys:]
        key = keys[0]
        # argv[0] is reset_ttl. The TTL is parsed (as the script would) but
        # not enforced, because FakeRedis has no expiry.
        _ttl = int(argv[1])
        n = int(argv[2])

        raw = self._data.get(key)
        if raw is None:
            return None
        data = json.loads(raw)
        for i in range(n):
            field = argv[3 + 2 * i]
            value = argv[4 + 2 * i]
            # Store any value that parses as JSON (object, array, number,
            # boolean, null) in decoded form; keep everything else as a raw
            # string.
            # NOTE(review): this means a string field such as "42" comes back
            # as an int. Confirm that matches the real script.
            try:
                data[field] = json.loads(value)
            except (json.JSONDecodeError, TypeError):
                data[field] = value
        encoded = json.dumps(data)
        self._data[key] = encoded
        return encoded
|
||
|
||
|
||
def _make_redis_store(fake_redis: FakeRedis | None = None) -> RedisTaskStore:
    """Create a RedisTaskStore whose client has been swapped for a FakeRedis.

    Args:
        fake_redis: Fake client to inject. A fresh ``FakeRedis`` is used when
            omitted, so callers only pass one when they need to inspect it.

    Returns:
        The store, wired to the fake client instead of a real connection.
    """
    client = FakeRedis() if fake_redis is None else fake_redis
    store = RedisTaskStore(redis_url="redis://fake/0")
    # Overwrite the private client handle so no real Redis is contacted.
    store._redis = client
    return store
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════
|
||
# TaskRecord.from_dict round-trip
|
||
# ═══════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestTaskRecordRoundTrip:
    """Verify TaskRecord serialisation / deserialisation."""

    def test_to_dict_from_dict_round_trip(self):
        """Every field, timestamps included, survives to_dict -> from_dict."""
        now = datetime.now(timezone.utc)
        record = TaskRecord(
            task_id="t1",
            agent_name="agent_a",
            skill_name="skill_x",
            input_data={"query": "hello"},
            status=TaskStatus.RUNNING,
            output_data={"result": "world"},
            error_message=None,
            created_at=now,
            started_at=now,
            completed_at=None,
            progress=0.5,
            progress_message="Halfway",
            metadata={"key": "val"},
        )
        restored = TaskRecord.from_dict(record.to_dict())
        assert restored.task_id == record.task_id
        assert restored.agent_name == record.agent_name
        assert restored.skill_name == record.skill_name
        assert restored.input_data == record.input_data
        assert restored.status == record.status
        assert restored.output_data == record.output_data
        assert restored.error_message is None
        # Timestamps were previously not checked. Aware datetimes compare by
        # instant, so a lossless encode/decode must give back equal values.
        assert restored.created_at == record.created_at
        assert restored.started_at == record.started_at
        assert restored.completed_at is None
        assert restored.progress == record.progress
        assert restored.progress_message == record.progress_message
        assert restored.metadata == record.metadata

    def test_from_dict_with_none_fields(self):
        """Explicit None values deserialize to None; status string is parsed."""
        data = {
            "task_id": "t2",
            "agent_name": "b",
            "skill_name": None,
            "input_data": {},
            "status": "pending",
            "output_data": None,
            "error_message": None,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "started_at": None,
            "completed_at": None,
            "progress": 0.0,
            "progress_message": "",
            "metadata": {},
        }
        record = TaskRecord.from_dict(data)
        assert record.task_id == "t2"
        assert record.status == TaskStatus.PENDING
        assert record.skill_name is None
        assert record.started_at is None
        assert record.completed_at is None
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════
|
||
# RedisTaskStore – happy path
|
||
# ═══════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestRedisTaskStoreHappyPath:
    """Core CRUD operations on RedisTaskStore with mock Redis."""

    @pytest.mark.asyncio
    async def test_create_and_get(self):
        """A created task is returned as PENDING and can be read back."""
        store = _make_redis_store()
        record = await store.create("t1", "agent_a", {"q": "hello"}, skill_name="skill_x")
        assert record.task_id == "t1"
        assert record.agent_name == "agent_a"
        assert record.skill_name == "skill_x"
        assert record.input_data == {"q": "hello"}
        assert record.status == TaskStatus.PENDING

        fetched = await store.get("t1")
        assert fetched is not None
        assert fetched.task_id == "t1"
        assert fetched.agent_name == "agent_a"
        assert fetched.skill_name == "skill_x"
        assert fetched.input_data == {"q": "hello"}
        assert fetched.status == TaskStatus.PENDING

    @pytest.mark.asyncio
    async def test_update_status_changes_fields(self):
        """update_status applies every supplied field and persists it."""
        store = _make_redis_store()
        await store.create("t1", "agent_a", {})
        now = datetime.now(timezone.utc)
        updated = await store.update_status(
            "t1", TaskStatus.RUNNING, started_at=now, progress=0.5, progress_message="Halfway",
        )
        assert updated.status == TaskStatus.RUNNING
        assert updated.started_at == now
        assert updated.progress == 0.5
        assert updated.progress_message == "Halfway"

        # Verify persistence
        fetched = await store.get("t1")
        assert fetched is not None
        assert fetched.status == TaskStatus.RUNNING
        assert fetched.progress == 0.5
        assert fetched.progress_message == "Halfway"

    @pytest.mark.asyncio
    async def test_list_tasks_sorted_by_created_at_desc(self):
        """list_tasks returns the most recently created task first."""
        store = _make_redis_store()
        await store.create("t1", "agent_a", {})
        await store.create("t2", "agent_b", {})
        # NOTE(review): on a coarse clock t1 and t2 may get the same
        # created_at. The expected order then depends on the store's
        # tie-breaking.
        tasks = await store.list_tasks()
        assert len(tasks) == 2
        # Most recent first (t2 created after t1)
        assert tasks[0].task_id == "t2"
        assert tasks[1].task_id == "t1"

    @pytest.mark.asyncio
    async def test_list_tasks_filtered_by_status(self):
        """Filtering by status returns only the matching tasks."""
        store = _make_redis_store()
        await store.create("t1", "agent_a", {})
        await store.create("t2", "agent_b", {})
        await store.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
        tasks = await store.list_tasks(status=TaskStatus.COMPLETED)
        assert len(tasks) == 1
        assert tasks[0].task_id == "t1"

    @pytest.mark.asyncio
    async def test_list_tasks_respects_limit(self):
        """No more than `limit` tasks are returned."""
        store = _make_redis_store()
        for i in range(5):
            await store.create(f"t{i}", "agent_a", {})
        tasks = await store.list_tasks(limit=3)
        assert len(tasks) == 3

    @pytest.mark.asyncio
    async def test_size_returns_count(self):
        """`size` is awaitable and tracks the number of stored tasks."""
        store = _make_redis_store()
        assert await store.size == 0
        await store.create("t1", "agent_a", {})
        assert await store.size == 1
        await store.create("t2", "agent_b", {})
        assert await store.size == 2

    @pytest.mark.asyncio
    async def test_start_cleanup_is_noop(self):
        """start_cleanup completes without error."""
        store = _make_redis_store()
        # Should not raise
        await store.start_cleanup()

    @pytest.mark.asyncio
    async def test_stop_cleanup_closes_redis(self):
        """stop_cleanup awaits the client's close() and drops the handle."""
        fake = FakeRedis()
        # Spy on close() so the test actually verifies the connection is closed,
        # not only that the attribute was cleared.
        fake.close = AsyncMock()
        store = _make_redis_store(fake)
        await store.stop_cleanup()
        fake.close.assert_awaited_once()
        assert store._redis is None
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════
|
||
# RedisTaskStore – error / edge cases
|
||
# ═══════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestRedisTaskStoreErrors:
    """Failure modes and capacity limits of RedisTaskStore."""

    @pytest.mark.asyncio
    async def test_get_nonexistent_returns_none(self):
        """Looking up an unknown task id yields None instead of raising."""
        assert await _make_redis_store().get("nonexistent") is None

    @pytest.mark.asyncio
    async def test_update_status_nonexistent_raises_keyerror(self):
        """Updating an unknown task id raises KeyError whose text says 'not found'."""
        empty_store = _make_redis_store()
        with pytest.raises(KeyError, match="not found"):
            await empty_store.update_status("nonexistent", TaskStatus.RUNNING)

    @pytest.mark.asyncio
    async def test_max_records_evicts_oldest_completed(self):
        """At capacity, a new task evicts the oldest completed task."""
        backend = FakeRedis()
        capped = _make_redis_store(backend)
        capped._max_records = 2

        await capped.create("t1", "agent_a", {})
        await capped.update_status("t1", TaskStatus.COMPLETED, completed_at=datetime.now(timezone.utc))
        await capped.create("t2", "agent_b", {})
        # The third create exceeds the cap of 2. t1 is the only completed
        # task, so it is the one expected to be evicted.
        await capped.create("t3", "agent_c", {})

        assert await capped.get("t1") is None
        for survivor in ("t2", "t3"):
            assert await capped.get(survivor) is not None

    @pytest.mark.asyncio
    async def test_max_records_full_no_completed_raises(self):
        """At capacity with nothing evictable, create raises RuntimeError ('full')."""
        backend = FakeRedis()
        capped = _make_redis_store(backend)
        capped._max_records = 1

        await capped.create("t1", "agent_a", {})
        # Only a PENDING task exists, so nothing is eligible for eviction.
        with pytest.raises(RuntimeError, match="full"):
            await capped.create("t2", "agent_b", {})
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════
|
||
# TTL expiry (simulated by removing key from fake Redis)
|
||
# ═══════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestRedisTaskStoreTTL:
    """Expiry behaviour, emulated by deleting keys straight out of FakeRedis."""

    @pytest.mark.asyncio
    async def test_expired_key_returns_none(self):
        """Once the record key is gone (as after TTL expiry), get() returns None."""
        backend = FakeRedis()
        store = _make_redis_store(backend)
        await store.create("t1", "agent_a", {})

        record_key = store._key("t1")
        # Stand-in for TTL expiry. Raises KeyError if create() never wrote the key.
        del backend._data[record_key]

        assert await store.get("t1") is None
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════
|
||
# create_task_store factory
|
||
# ═══════════════════════════════════════════════════════════
|
||
|
||
|
||
class TestCreateTaskStore:
    """Factory function tests."""

    def test_default_backend_is_memory(self):
        """With no arguments the factory builds an in-memory store."""
        store = create_task_store()
        assert isinstance(store, InMemoryTaskStore)

    def test_explicit_memory_backend(self):
        """backend="memory" builds an in-memory store."""
        store = create_task_store(backend="memory")
        assert isinstance(store, InMemoryTaskStore)

    def test_redis_backend_returns_redis_task_store(self):
        """With redis-py installed, backend="redis" builds a RedisTaskStore."""
        # The factory falls back to memory when redis.asyncio cannot be
        # imported (see the fallback test below). Skip rather than fail when
        # the client library is absent.
        pytest.importorskip("redis.asyncio")
        store = create_task_store(backend="redis", redis_url="redis://localhost:6379/0")
        assert isinstance(store, RedisTaskStore)

    def test_redis_unavailable_falls_back_to_memory(self):
        """If redis.asyncio import fails, factory falls back to InMemoryTaskStore."""
        import builtins

        real_import = builtins.__import__

        def _import_without_redis(name, globals_=None, locals_=None, fromlist=(), level=0):
            # Reject only absolute imports of the redis package. Every other
            # import must still work, so the fallback path is exercised for
            # the right reason and any imports it needs still succeed.
            if level == 0 and (name == "redis" or name.startswith("redis.")):
                raise ImportError("no redis")
            return real_import(name, globals_, locals_, fromlist, level)

        with patch.dict("sys.modules", {"redis.asyncio": None}), patch(
            "builtins.__import__", new=_import_without_redis
        ):
            store = create_task_store(backend="redis")
        assert isinstance(store, InMemoryTaskStore)

    def test_backward_compat_alias(self):
        """TaskStore is an alias for InMemoryTaskStore."""
        assert TaskStore is InMemoryTaskStore
|