452 lines
16 KiB
Python
452 lines
16 KiB
Python
"""TaskStore - Task state storage with TTL (InMemory / Redis backends)"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from typing import Any
|
||
|
||
from agentkit.core.protocol import TaskStatus
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class TaskRecord:
    """A task's complete lifecycle state, as kept by the task stores.

    ``to_dict`` / ``from_dict`` convert to and from a JSON-friendly dict:
    timestamps travel as ISO-8601 strings and ``status`` as its enum value.
    """

    # Identity of the task and of the agent/skill that runs it.
    task_id: str
    agent_name: str
    skill_name: str | None
    input_data: dict[str, Any]
    # Lifecycle state, updated as the task runs.
    status: TaskStatus = TaskStatus.PENDING
    output_data: dict[str, Any] | None = None
    error_message: str | None = None
    # Timestamps; created_at defaults to the current time in UTC.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    started_at: datetime | None = None
    completed_at: datetime | None = None
    # Free-form progress reporting and extra data.
    progress: float = 0.0
    progress_message: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return the record as a plain dict.

        Datetimes become ISO-8601 strings (``None`` stays ``None`` for the
        optional ones) and ``status`` becomes its enum value. The nested
        ``input_data`` / ``output_data`` / ``metadata`` objects are included
        by reference, not copied.
        """

        def optional_iso(moment: datetime | None) -> str | None:
            return None if not moment else moment.isoformat()

        return dict(
            task_id=self.task_id,
            agent_name=self.agent_name,
            skill_name=self.skill_name,
            input_data=self.input_data,
            status=self.status.value,
            output_data=self.output_data,
            error_message=self.error_message,
            created_at=self.created_at.isoformat(),
            started_at=optional_iso(self.started_at),
            completed_at=optional_iso(self.completed_at),
            progress=self.progress,
            progress_message=self.progress_message,
            metadata=self.metadata,
        )

    @classmethod
    def from_dict(cls, data: dict) -> "TaskRecord":
        """Rebuild a record from a dict such as the one produced by ``to_dict``.

        ``task_id`` and ``agent_name`` are required (``KeyError`` otherwise).
        A missing or empty ``created_at`` is replaced by the current UTC time.
        Other missing fields fall back to neutral values: ``None``, ``{}``,
        ``0.0``, ``""`` or status ``"pending"``.
        """

        def optional_dt(name: str) -> datetime | None:
            text = data.get(name)
            return datetime.fromisoformat(text) if text else None

        return cls(
            task_id=data["task_id"],
            agent_name=data["agent_name"],
            skill_name=data.get("skill_name"),
            input_data=data.get("input_data", {}),
            status=TaskStatus(data.get("status", "pending")),
            output_data=data.get("output_data"),
            error_message=data.get("error_message"),
            created_at=optional_dt("created_at") or datetime.now(timezone.utc),
            started_at=optional_dt("started_at"),
            completed_at=optional_dt("completed_at"),
            progress=data.get("progress", 0.0),
            progress_message=data.get("progress_message", ""),
            metadata=data.get("metadata", {}),
        )
|
||
|
||
|
||
class InMemoryTaskStore:
    """In-memory task state storage with automatic TTL cleanup.

    Records live in a dict keyed by ``task_id``. Once started with
    :meth:`start_cleanup`, a background task periodically drops records that
    meet both conditions:

    - their status is terminal (completed / failed / cancelled);
    - their ``completed_at`` is older than ``ttl_seconds``.

    When ``max_records`` is reached, :meth:`create` evicts the oldest terminal
    record to make room.

    Terminal records without a ``completed_at`` timestamp are never removed by
    the TTL sweep; only eviction in :meth:`create` can remove them.
    """

    # Statuses after which a record is eligible for expiry and eviction.
    _TERMINAL_STATUSES = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)

    def __init__(self, ttl_seconds: int = 3600, max_records: int = 10000, cleanup_interval: float = 60.0):
        """
        Args:
            ttl_seconds: Age, measured from ``completed_at``, after which a
                terminal record is removed by the background sweep.
            max_records: Maximum number of records held at once.
            cleanup_interval: Seconds between background sweeps (default 60).
        """
        self._tasks: dict[str, TaskRecord] = {}
        self._ttl_seconds = ttl_seconds
        self._max_records = max_records
        self._cleanup_interval = cleanup_interval
        self._cleanup_task: asyncio.Task | None = None

    @property
    def backend_type(self) -> str:
        """Return the backend type identifier."""
        return "memory"

    async def start_cleanup(self) -> None:
        """Start the background cleanup task unless one is already running."""
        # Restart if a previous task has finished (e.g. it was cancelled externally);
        # otherwise cleanup would silently stay off.
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop_cleanup(self) -> None:
        """Cancel the background cleanup task and wait for it to finish."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self) -> None:
        """Sweep expired records every ``cleanup_interval`` seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._cleanup_interval)
                self._cleanup_expired()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive, but record the traceback for diagnosis.
                logger.exception("TaskStore cleanup error: %s", e)

    def _cleanup_expired(self) -> None:
        """Remove terminal records whose ``completed_at`` is older than the TTL."""
        now = datetime.now(timezone.utc)
        expired = [
            task_id
            for task_id, record in self._tasks.items()
            if record.status in self._TERMINAL_STATUSES
            and record.completed_at
            and (now - record.completed_at).total_seconds() > self._ttl_seconds
        ]
        for task_id in expired:
            del self._tasks[task_id]
        if expired:
            logger.info("TaskStore cleaned up %d expired records", len(expired))

    def _pick_eviction_candidate(self) -> TaskRecord | None:
        """Return the record to evict when the store is full, or ``None``.

        The choice is the terminal record with the earliest ``completed_at``.
        If no terminal record has a timestamp, it is the first terminal record
        in insertion order.
        """
        terminal = [rec for rec in self._tasks.values() if rec.status in self._TERMINAL_STATUSES]
        dated = [rec for rec in terminal if rec.completed_at]
        if dated:
            return min(dated, key=lambda rec: rec.completed_at)
        return terminal[0] if terminal else None

    def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord:
        """Create, store and return a new PENDING task record.

        An existing record with the same ``task_id`` is replaced.

        Raises:
            RuntimeError: the store holds ``max_records`` records and none is
                terminal, so nothing can be evicted.
        """
        # Replacing an existing id does not grow the store, so no room is needed.
        if task_id not in self._tasks and len(self._tasks) >= self._max_records:
            victim = self._pick_eviction_candidate()
            if victim is None:
                raise RuntimeError("TaskStore is full and no completed tasks to evict")
            del self._tasks[victim.task_id]

        record = TaskRecord(
            task_id=task_id,
            agent_name=agent_name,
            skill_name=skill_name,
            input_data=input_data,
        )
        self._tasks[task_id] = record
        return record

    def get(self, task_id: str) -> TaskRecord | None:
        """Get task record by ID, or ``None`` if unknown or already cleaned up."""
        return self._tasks.get(task_id)

    def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord:
        """Set ``status`` and any extra attributes on a stored record, in place.

        A keyword is applied only if the record already has an attribute of
        that name (``hasattr``); other keywords are silently ignored.

        Raises:
            KeyError: no record exists for ``task_id``.
        """
        record = self._tasks.get(task_id)
        if record is None:
            raise KeyError(f"Task '{task_id}' not found")
        record.status = status
        for key, value in kwargs.items():
            if hasattr(record, key):
                setattr(record, key, value)
        return record

    def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]:
        """Return up to ``limit`` records, newest ``created_at`` first, optionally filtered by status."""
        tasks = list(self._tasks.values())
        # Compare against None explicitly so that no enum value is treated as "no filter".
        if status is not None:
            tasks = [t for t in tasks if t.status == status]
        tasks.sort(key=lambda t: t.created_at, reverse=True)
        return tasks[:limit]

    def count_by_status(self) -> dict[str, int]:
        """Return a dict of status value -> count without materializing all records."""
        counts: dict[str, int] = {}
        for record in self._tasks.values():
            key = record.status.value
            counts[key] = counts.get(key, 0) + 1
        return counts

    @property
    def size(self) -> int:
        """Number of records currently stored."""
        return len(self._tasks)


# Backward-compatible alias
TaskStore = InMemoryTaskStore
|
||
|
||
|
||
class RedisTaskStore:
    """Redis-backed task state storage with TTL.

    Each task is stored as a JSON string under ``agentkit:task:{task_id}``
    with an expiry of ``ttl_seconds``. Every write (``create`` and
    ``update_status``) resets that expiry, so a record disappears
    ``ttl_seconds`` after its last write, whatever its status. Redis handles
    expiry, so :meth:`start_cleanup` is a no-op and :meth:`stop_cleanup`
    only closes the client.

    Enumeration (``list_tasks``, ``count_by_status``, ``size`` and eviction)
    SCANs the key prefix, so its cost grows with the number of stored tasks.
    """

    KEY_PREFIX = "agentkit:task:"

    # COUNT hint for SCAN; each SCAN batch is fetched with a single MGET.
    _SCAN_COUNT = 200

    # Statuses after which a record may be evicted to make room.
    _TERMINAL_STATUSES = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)

    # Fields that update_status() may overwrite in addition to ``status``.
    _UPDATABLE_FIELDS = frozenset(
        {"started_at", "completed_at", "output_data", "error_message", "progress", "progress_message", "metadata"}
    )

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379/0",
        ttl_seconds: int = 3600,
        max_records: int = 10000,
    ):
        self._redis_url = redis_url
        self._ttl_seconds = ttl_seconds
        self._max_records = max_records
        self._redis: Any = None  # redis.asyncio.Redis, lazy init

    @property
    def backend_type(self) -> str:
        """Return the backend type identifier."""
        return "redis"

    async def _get_redis(self):
        """Create the async Redis client on first use and return it."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(
                self._redis_url,
                decode_responses=True,
            )
        return self._redis

    def _key(self, task_id: str) -> str:
        return f"{self.KEY_PREFIX}{task_id}"

    # ── lifecycle (no-ops, Redis TTL handles cleanup) ──────────

    async def start_cleanup(self) -> None:
        """No-op – Redis TTL handles expiry automatically."""

    async def stop_cleanup(self) -> None:
        """Close the Redis client on shutdown; a later call lazily reconnects."""
        client, self._redis = self._redis, None
        if client is None:
            return
        # redis-py >= 5.0.1 names the coroutine aclose() and deprecates close().
        closer = getattr(client, "aclose", None) or client.close
        await closer()

    # ── CRUD ───────────────────────────────────────────────────

    async def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord:
        """Create a new PENDING task record in Redis and return it.

        An existing record with the same ``task_id`` is overwritten.

        NOTE: the size check, eviction and write are separate commands, so
        concurrent creates can briefly exceed ``max_records``.

        Raises:
            RuntimeError: the store is full and no terminal record can be evicted.
        """
        redis = await self._get_redis()
        key = self._key(task_id)

        if await self._count_keys(redis) >= self._max_records:
            # Overwriting an existing key does not grow the keyspace, so no room is needed.
            if not await redis.exists(key) and not await self._evict_oldest_completed(redis):
                raise RuntimeError("TaskStore is full and no completed tasks to evict")

        record = TaskRecord(
            task_id=task_id,
            agent_name=agent_name,
            skill_name=skill_name,
            input_data=input_data,
        )
        await redis.set(key, json.dumps(record.to_dict()), ex=self._ttl_seconds)
        return record

    async def get(self, task_id: str) -> TaskRecord | None:
        """Get task record by ID, or ``None`` if it does not exist or has expired."""
        redis = await self._get_redis()
        raw = await redis.get(self._key(task_id))
        if raw is None:
            return None
        return TaskRecord.from_dict(json.loads(raw))

    # Lua script for atomic read-modify-write.
    # KEYS[1] = task key, ARGV[1] = TTL in seconds,
    # ARGV[2] = JSON object of fields to overwrite.
    _UPDATE_STATUS_SCRIPT = """
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local raw = redis.call('GET', key)
if raw == false then
    return nil
end
local data = cjson.decode(raw)
local patch = cjson.decode(ARGV[2])
for k, v in pairs(patch) do
    data[k] = v
end
local encoded = cjson.encode(data)
redis.call('SET', key, encoded, 'EX', ttl)
return encoded
"""

    async def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord:
        """Atomically set ``status`` and any whitelisted fields on a stored task.

        Only keywords listed in ``_UPDATABLE_FIELDS`` are applied; others are
        ignored. ``datetime`` values are stored as ISO-8601 strings. Everything
        else is sent as JSON, so numbers, ``None`` and nested dicts keep their
        types when read back.

        NOTE(review): the record is re-encoded with Redis' cjson. Empty JSON
        arrays nested in the record come back as ``{}``, and very long floats
        may lose precision.

        Raises:
            KeyError: no record exists for ``task_id`` (never created or expired).
            TypeError: a field value is not JSON-serialisable.
        """
        redis = await self._get_redis()

        patch: dict[str, Any] = {"status": status.value}
        for name, value in kwargs.items():
            if name in self._UPDATABLE_FIELDS:
                patch[name] = value.isoformat() if isinstance(value, datetime) else value

        # Ship the patch as a single JSON document. Per-field string arguments
        # would turn None into "None", numbers into strings and dicts into
        # JSON-in-a-string, and from_dict cannot read those back correctly.
        result = await redis.eval(
            self._UPDATE_STATUS_SCRIPT,
            1,
            self._key(task_id),
            str(self._ttl_seconds),
            json.dumps(patch),
        )
        if result is None:
            raise KeyError(f"Task '{task_id}' not found")
        return TaskRecord.from_dict(json.loads(result))

    async def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]:
        """List tasks, optionally filtered by status, sorted by created_at desc."""
        redis = await self._get_redis()
        tasks = [
            record
            async for record in self._iter_records(redis)
            if status is None or record.status == status
        ]
        tasks.sort(key=lambda t: t.created_at, reverse=True)
        return tasks[:limit]

    async def count_by_status(self) -> dict[str, int]:
        """Return a dict of status value -> count, streaming records via SCAN."""
        redis = await self._get_redis()
        counts: dict[str, int] = {}
        async for record in self._iter_records(redis):
            key = record.status.value
            counts[key] = counts.get(key, 0) + 1
        return counts

    @property
    async def size(self) -> int:
        """Number of task keys currently stored.

        NOTE: this is an async property, so callers must write ``await store.size``.
        """
        redis = await self._get_redis()
        return await self._count_keys(redis)

    # ── helpers ────────────────────────────────────────────────

    async def _scan_key_batches(self, redis):
        """Yield batches of task keys from one full SCAN, each key at most once.

        SCAN may return the same key several times during one iteration, so
        keys already seen are dropped from later batches.
        """
        seen: set[str] = set()
        cursor = 0
        while True:
            cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=self._SCAN_COUNT)
            fresh: list[str] = []
            for key in keys:
                if key not in seen:
                    seen.add(key)
                    fresh.append(key)
            if fresh:
                yield fresh
            if cursor == 0:
                break

    async def _count_keys(self, redis) -> int:
        """Count distinct task keys using SCAN (avoids KEYS on large datasets)."""
        count = 0
        async for batch in self._scan_key_batches(redis):
            count += len(batch)
        return count

    async def _iter_records(self, redis):
        """Yield every readable stored TaskRecord, fetching each SCAN batch with MGET.

        Records that cannot be decoded are logged and skipped, so one corrupt
        record does not break every listing.
        """
        async for batch in self._scan_key_batches(redis):
            for raw in await redis.mget(batch):
                if raw is None:  # expired or deleted between SCAN and MGET
                    continue
                try:
                    record = TaskRecord.from_dict(json.loads(raw))
                except (ValueError, KeyError, TypeError) as exc:
                    logger.warning("Skipping unreadable task record: %s", exc)
                    continue
                yield record

    async def _evict_oldest_completed(self, redis) -> bool:
        """Delete one terminal (completed/failed/cancelled) task to make room.

        The victim is the terminal task with the earliest ``completed_at``. If
        no terminal task has that timestamp, the first one found is used
        instead, which matches the in-memory store's eviction rule.

        Returns True if a record was evicted, False if there was no terminal task.
        """
        terminal = [
            record
            async for record in self._iter_records(redis)
            if record.status in self._TERMINAL_STATUSES
        ]
        if not terminal:
            return False

        dated = [record for record in terminal if record.completed_at is not None]
        victim = min(dated, key=lambda record: record.completed_at) if dated else terminal[0]
        await redis.delete(self._key(victim.task_id))
        return True
|
||
|
||
|
||
def create_task_store(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    ttl_seconds: int = 3600,
    max_records: int = 10000,
) -> InMemoryTaskStore | RedisTaskStore:
    """Factory: create a TaskStore backed by memory or Redis.

    Args:
        backend: ``"redis"`` or ``"memory"``. Case and surrounding whitespace
            are ignored. Any other value logs a warning and falls back to memory.
        redis_url: Connection URL for the Redis backend. The password is masked
            before logging.
        ttl_seconds: Record time-to-live passed to the chosen store.
        max_records: Capacity passed to the chosen store.

    No connection is attempted here: the Redis client is created lazily on
    first use. The fallback to :class:`InMemoryTaskStore` therefore covers only
    failures while setting up the Redis store, such as the ``redis`` package
    not being installed. An unreachable server surfaces later, on the first
    store operation.
    """
    kind = (backend or "memory").strip().lower()

    if kind == "redis":
        try:
            import redis.asyncio as aioredis  # noqa: F401  (availability check only)

            store = RedisTaskStore(
                redis_url=redis_url,
                ttl_seconds=ttl_seconds,
                max_records=max_records,
            )
            logger.info("TaskStore backend: redis (%s)", _sanitize_redis_url(redis_url))
            return store
        except Exception as exc:
            logger.warning("Failed to initialise RedisTaskStore (%s), falling back to InMemoryTaskStore", exc)
    elif kind != "memory":
        # Previously a typo here silently selected the memory backend; make it visible.
        logger.warning("Unknown TaskStore backend %r, falling back to InMemoryTaskStore", backend)

    store = InMemoryTaskStore(ttl_seconds=ttl_seconds, max_records=max_records)
    logger.info("TaskStore backend: memory")
    return store
|
||
|
||
|
||
def _sanitize_redis_url(url: str) -> str:
|
||
"""Mask the password in a Redis URL for safe logging."""
|
||
from urllib.parse import urlparse, urlunparse
|
||
|
||
parsed = urlparse(url)
|
||
if parsed.password:
|
||
netloc = f"{parsed.username}:****@{parsed.hostname}"
|
||
if parsed.port:
|
||
netloc += f":{parsed.port}"
|
||
return urlunparse(parsed._replace(netloc=netloc))
|
||
return url
|