"""TaskStore - Task state storage with TTL (InMemory / Redis backends)""" import asyncio import json import logging from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any from agentkit.core.protocol import TaskStatus logger = logging.getLogger(__name__) @dataclass class TaskRecord: """Stored task record with full lifecycle data""" task_id: str agent_name: str skill_name: str | None input_data: dict[str, Any] status: TaskStatus = TaskStatus.PENDING output_data: dict[str, Any] | None = None error_message: str | None = None created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) started_at: datetime | None = None completed_at: datetime | None = None progress: float = 0.0 progress_message: str = "" metadata: dict[str, Any] = field(default_factory=dict) def to_dict(self) -> dict: return { "task_id": self.task_id, "agent_name": self.agent_name, "skill_name": self.skill_name, "input_data": self.input_data, "status": self.status.value, "output_data": self.output_data, "error_message": self.error_message, "created_at": self.created_at.isoformat(), "started_at": self.started_at.isoformat() if self.started_at else None, "completed_at": self.completed_at.isoformat() if self.completed_at else None, "progress": self.progress, "progress_message": self.progress_message, "metadata": self.metadata, } @classmethod def from_dict(cls, data: dict) -> "TaskRecord": """Reconstruct a TaskRecord from a dict (e.g. deserialized from Redis).""" return cls( task_id=data["task_id"], agent_name=data["agent_name"], skill_name=data.get("skill_name"), input_data=data.get("input_data", {}), status=TaskStatus(data.get("status", "pending")), output_data=data.get("output_data"), error_message=data.get("error_message"), created_at=datetime.fromisoformat(data["created_at"]) if data.get("created_at") else datetime.now(timezone.utc), started_at=datetime.fromisoformat(data["started_at"]) if data.get("started_at") else None, completed_at=datetime.fromisoformat(data["completed_at"]) if data.get("completed_at") else None, progress=data.get("progress", 0.0), progress_message=data.get("progress_message", ""), metadata=data.get("metadata", {}), ) class InMemoryTaskStore: """In-memory task state storage with automatic TTL cleanup. Stores task records indexed by task_id. Automatically removes completed tasks after a configurable TTL. """ def __init__(self, ttl_seconds: int = 3600, max_records: int = 10000): self._tasks: dict[str, TaskRecord] = {} self._ttl_seconds = ttl_seconds self._max_records = max_records self._cleanup_task: asyncio.Task | None = None @property def backend_type(self) -> str: """Return the backend type identifier.""" return "memory" async def start_cleanup(self) -> None: """Start background cleanup task""" if self._cleanup_task is None: self._cleanup_task = asyncio.create_task(self._cleanup_loop()) async def stop_cleanup(self) -> None: """Stop background cleanup task""" if self._cleanup_task: self._cleanup_task.cancel() try: await self._cleanup_task except asyncio.CancelledError: pass self._cleanup_task = None async def _cleanup_loop(self) -> None: """Periodically remove expired task records""" while True: try: await asyncio.sleep(60) self._cleanup_expired() except asyncio.CancelledError: break except Exception as e: logger.error(f"TaskStore cleanup error: {e}") def _cleanup_expired(self) -> None: """Remove expired records""" expired = [] for task_id, record in self._tasks.items(): if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED): if record.completed_at: age = (datetime.now(timezone.utc) - record.completed_at).total_seconds() if age > self._ttl_seconds: expired.append(task_id) for task_id in expired: del self._tasks[task_id] if expired: logger.info(f"TaskStore cleaned up {len(expired)} expired records") def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord: """Create a new task record""" if len(self._tasks) >= self._max_records: # Remove oldest completed task oldest = None for rec in self._tasks.values(): if rec.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED): if oldest is None or (rec.completed_at and (oldest.completed_at is None or rec.completed_at < oldest.completed_at)): oldest = rec if oldest: del self._tasks[oldest.task_id] else: raise RuntimeError("TaskStore is full and no completed tasks to evict") record = TaskRecord( task_id=task_id, agent_name=agent_name, skill_name=skill_name, input_data=input_data, ) self._tasks[task_id] = record return record def get(self, task_id: str) -> TaskRecord | None: """Get task record by ID""" return self._tasks.get(task_id) def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord: """Update task status and optional fields""" record = self._tasks.get(task_id) if record is None: raise KeyError(f"Task '{task_id}' not found") record.status = status for key, value in kwargs.items(): if hasattr(record, key): setattr(record, key, value) return record def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]: """List tasks, optionally filtered by status""" tasks = list(self._tasks.values()) if status: tasks = [t for t in tasks if t.status == status] tasks.sort(key=lambda t: t.created_at, reverse=True) return tasks[:limit] def count_by_status(self) -> dict[str, int]: """Return a dict of status value -> count without materializing all records.""" counts: dict[str, int] = {} for record in self._tasks.values(): key = record.status.value counts[key] = counts.get(key, 0) + 1 return counts @property def size(self) -> int: return len(self._tasks) # Backward-compatible alias TaskStore = InMemoryTaskStore class RedisTaskStore: """Redis-backed task state storage with TTL. Stores each task as a JSON string in Redis with key pattern ``agentkit:task:{task_id}``. Redis TTL handles automatic cleanup, so start_cleanup / stop_cleanup are no-ops. """ KEY_PREFIX = "agentkit:task:" def __init__( self, redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 3600, max_records: int = 10000, ): self._redis_url = redis_url self._ttl_seconds = ttl_seconds self._max_records = max_records self._redis: Any = None # redis.asyncio.Redis, lazy init @property def backend_type(self) -> str: """Return the backend type identifier.""" return "redis" async def _get_redis(self): """Lazy-initialise the async Redis client.""" if self._redis is None: import redis.asyncio as aioredis self._redis = aioredis.from_url( self._redis_url, decode_responses=True, ) return self._redis def _key(self, task_id: str) -> str: return f"{self.KEY_PREFIX}{task_id}" # ── lifecycle (no-ops, Redis TTL handles cleanup) ────────── async def start_cleanup(self) -> None: """No-op – Redis TTL handles expiry automatically.""" async def stop_cleanup(self) -> None: """Close the Redis connection pool on shutdown.""" if self._redis is not None: await self._redis.close() self._redis = None # ── CRUD ─────────────────────────────────────────────────── async def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord: """Create a new task record in Redis.""" redis = await self._get_redis() # Enforce max_records by counting existing keys current_size = await self._count_keys(redis) if current_size >= self._max_records: # Try to evict the oldest completed task evicted = await self._evict_oldest_completed(redis) if not evicted: raise RuntimeError("TaskStore is full and no completed tasks to evict") record = TaskRecord( task_id=task_id, agent_name=agent_name, skill_name=skill_name, input_data=input_data, ) await redis.set(self._key(task_id), json.dumps(record.to_dict()), ex=self._ttl_seconds) return record async def get(self, task_id: str) -> TaskRecord | None: """Get task record by ID.""" redis = await self._get_redis() raw = await redis.get(self._key(task_id)) if raw is None: return None return TaskRecord.from_dict(json.loads(raw)) # Lua script for atomic read-modify-write _UPDATE_STATUS_SCRIPT = """ local key = KEYS[1] local ttl = tonumber(ARGV[1]) local raw = redis.call('GET', key) if raw == false then return nil end local data = cjson.decode(raw) local n = tonumber(ARGV[2]) for i = 1, n do local k = ARGV[2 + 2 * (i - 1) + 1] local v = ARGV[2 + 2 * (i - 1) + 2] data[k] = v end local encoded = cjson.encode(data) redis.call('SET', key, encoded, 'EX', ttl) return encoded """ async def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord: """Update task status and optional fields atomically via Lua script.""" redis = await self._get_redis() key = self._key(task_id) # Build flat list of key-value pairs for the merge fields merge_fields = {"status": status.value} for k, value in kwargs.items(): if k in ("started_at", "completed_at", "output_data", "error_message", "progress", "progress_message", "metadata"): if isinstance(value, datetime): merge_fields[k] = value.isoformat() else: merge_fields[k] = value # Flatten merge_fields into ARGV pairs args = [str(self._ttl_seconds), str(len(merge_fields))] for k, v in merge_fields.items(): args.append(k) args.append(json.dumps(v) if isinstance(v, (dict, list)) else str(v)) result = await redis.eval(self._UPDATE_STATUS_SCRIPT, 1, key, *args) if result is None: raise KeyError(f"Task '{task_id}' not found") data = json.loads(result) return TaskRecord.from_dict(data) async def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]: """List tasks, optionally filtered by status, sorted by created_at desc.""" redis = await self._get_redis() tasks: list[TaskRecord] = [] cursor = 0 while True: cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200) if keys: values = await redis.mget(keys) for raw in values: if raw is None: continue record = TaskRecord.from_dict(json.loads(raw)) if status is None or record.status == status: tasks.append(record) if cursor == 0: break tasks.sort(key=lambda t: t.created_at, reverse=True) return tasks[:limit] async def count_by_status(self) -> dict[str, int]: """Return a dict of status value -> count using SCAN without materializing all records.""" redis = await self._get_redis() counts: dict[str, int] = {} cursor = 0 while True: cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200) if keys: values = await redis.mget(keys) for raw in values: if raw is None: continue record = TaskRecord.from_dict(json.loads(raw)) key = record.status.value counts[key] = counts.get(key, 0) + 1 if cursor == 0: break return counts @property async def size(self) -> int: """Number of task keys currently stored.""" redis = await self._get_redis() return await self._count_keys(redis) # ── helpers ──────────────────────────────────────────────── async def _count_keys(self, redis) -> int: """Count task keys using SCAN (avoid KEYS on large datasets).""" count = 0 cursor = 0 while True: cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200) count += len(keys) if cursor == 0: break return count async def _evict_oldest_completed(self, redis) -> bool: """Find and delete the oldest completed/failed/cancelled task. Returns True if a record was evicted, False otherwise. """ tasks: list[TaskRecord] = [] cursor = 0 while True: cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200) if keys: values = await redis.mget(keys) for raw in values: if raw is None: continue record = TaskRecord.from_dict(json.loads(raw)) if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED): tasks.append(record) if cursor == 0: break if not tasks: return False # Pick the one with the earliest completed_at oldest = min( (t for t in tasks if t.completed_at is not None), key=lambda t: t.completed_at, # type: ignore[arg-type] default=None, ) if oldest is None: return False await redis.delete(self._key(oldest.task_id)) return True def create_task_store( backend: str = "memory", redis_url: str = "redis://localhost:6379/0", ttl_seconds: int = 3600, max_records: int = 10000, ) -> InMemoryTaskStore | RedisTaskStore: """Factory: create a TaskStore backed by memory or Redis. If ``backend="redis"`` and the Redis connection cannot be established, falls back to :class:`InMemoryTaskStore` with a warning. """ if backend == "redis": try: import redis.asyncio as aioredis # noqa: F401 store = RedisTaskStore( redis_url=redis_url, ttl_seconds=ttl_seconds, max_records=max_records, ) logger.info(f"TaskStore backend: redis ({_sanitize_redis_url(redis_url)})") return store except Exception as exc: logger.warning(f"Failed to initialise RedisTaskStore ({exc}), falling back to InMemoryTaskStore") store = InMemoryTaskStore(ttl_seconds=ttl_seconds, max_records=max_records) logger.info("TaskStore backend: memory") return store def _sanitize_redis_url(url: str) -> str: """Mask the password in a Redis URL for safe logging.""" from urllib.parse import urlparse, urlunparse parsed = urlparse(url) if parsed.password: netloc = f"{parsed.username}:****@{parsed.hostname}" if parsed.port: netloc += f":{parsed.port}" return urlunparse(parsed._replace(netloc=netloc)) return url