fischer-agentkit/src/agentkit/server/task_store.py

526 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""TaskStore - Task state storage with TTL (InMemory / Redis backends)"""
import asyncio
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
from agentkit.core.protocol import TaskStatus
# Module-level logger named after this module (``__name__``); used by the
# stores and the factory below.
logger = logging.getLogger(__name__)
@dataclass
class TaskRecord:
    """Stored task record with full lifecycle data.

    ``to_dict`` turns the record into a JSON-friendly dict (datetimes as
    ISO-8601 strings, ``status`` as its enum value; nested ``input_data`` /
    ``output_data`` / ``metadata`` are passed through as-is) and
    ``from_dict`` reverses it.
    """

    task_id: str
    agent_name: str
    skill_name: str | None
    input_data: dict[str, Any]
    status: TaskStatus = TaskStatus.PENDING
    output_data: dict[str, Any] | None = None
    error_message: str | None = None
    # Defaults to "now" as a timezone-aware UTC datetime.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    started_at: datetime | None = None
    completed_at: datetime | None = None
    progress: float = 0.0
    progress_message: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return every field as a dict suitable for ``json.dumps``."""
        return {
            "task_id": self.task_id,
            "agent_name": self.agent_name,
            "skill_name": self.skill_name,
            "input_data": self.input_data,
            "status": self.status.value,
            "output_data": self.output_data,
            "error_message": self.error_message,
            "created_at": self.created_at.isoformat(),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "progress": self.progress,
            "progress_message": self.progress_message,
            "metadata": self.metadata,
        }

    @staticmethod
    def _parse_datetime(value: Any) -> datetime | None:
        """Parse an ISO-8601 timestamp; empty/missing values become None.

        Values written through a ``str()``-based serializer (as the Redis Lua
        update path has done) can hold the literal string ``"None"`` for a
        cleared timestamp, so that is treated as missing too. Already-parsed
        ``datetime`` objects are returned unchanged.
        """
        if isinstance(value, datetime):
            return value
        if not value or value == "None":
            return None
        return datetime.fromisoformat(value)

    @staticmethod
    def _parse_progress(value: Any) -> float:
        """Coerce a stored progress value (number or numeric string) to float; default 0.0."""
        if value is None:
            return 0.0
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    @classmethod
    def from_dict(cls, data: dict) -> "TaskRecord":
        """Reconstruct a TaskRecord from a dict (e.g. deserialized from Redis).

        Raises:
            KeyError: if ``task_id`` or ``agent_name`` is missing.
            ValueError: if ``status`` is not a valid TaskStatus value or a
                timestamp is not ISO-8601.
        """
        created_at = cls._parse_datetime(data.get("created_at"))
        return cls(
            task_id=data["task_id"],
            agent_name=data["agent_name"],
            skill_name=data.get("skill_name"),
            # A stored JSON null must not turn these containers into None.
            input_data=data.get("input_data") or {},
            status=TaskStatus(data.get("status") or "pending"),
            output_data=data.get("output_data"),
            error_message=data.get("error_message"),
            created_at=created_at if created_at is not None else datetime.now(timezone.utc),
            started_at=cls._parse_datetime(data.get("started_at")),
            completed_at=cls._parse_datetime(data.get("completed_at")),
            progress=cls._parse_progress(data.get("progress")),
            progress_message=data.get("progress_message") or "",
            metadata=data.get("metadata") or {},
        )
class InMemoryTaskStore:
"""In-memory task state storage with automatic TTL cleanup.
Stores task records indexed by task_id. Automatically removes
completed tasks after a configurable TTL.
"""
def __init__(self, ttl_seconds: int = 3600, max_records: int = 10000):
self._tasks: dict[str, TaskRecord] = {}
self._ttl_seconds = ttl_seconds
self._max_records = max_records
self._cleanup_task: asyncio.Task | None = None
@property
def backend_type(self) -> str:
"""Return the backend type identifier."""
return "memory"
async def start_cleanup(self) -> None:
"""Start background cleanup task"""
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
async def stop_cleanup(self) -> None:
"""Stop background cleanup task"""
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
self._cleanup_task = None
async def _cleanup_loop(self) -> None:
"""Periodically remove expired task records"""
while True:
try:
await asyncio.sleep(60)
self._cleanup_expired()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"TaskStore cleanup error: {e}")
def _cleanup_expired(self) -> None:
"""Remove expired records"""
expired = []
for task_id, record in self._tasks.items():
if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
if record.completed_at:
age = (datetime.now(timezone.utc) - record.completed_at).total_seconds()
if age > self._ttl_seconds:
expired.append(task_id)
for task_id in expired:
del self._tasks[task_id]
if expired:
logger.info(f"TaskStore cleaned up {len(expired)} expired records")
def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord:
"""Create a new task record"""
if len(self._tasks) >= self._max_records:
# Remove oldest completed task
oldest = None
for rec in self._tasks.values():
if rec.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
if oldest is None or (rec.completed_at and (oldest.completed_at is None or rec.completed_at < oldest.completed_at)):
oldest = rec
if oldest:
del self._tasks[oldest.task_id]
else:
raise RuntimeError("TaskStore is full and no completed tasks to evict")
record = TaskRecord(
task_id=task_id,
agent_name=agent_name,
skill_name=skill_name,
input_data=input_data,
)
self._tasks[task_id] = record
return record
def get(self, task_id: str) -> TaskRecord | None:
"""Get task record by ID"""
return self._tasks.get(task_id)
def update_status(self, task_id: str, status: TaskStatus, **kwargs) -> TaskRecord:
"""Update task status and optional fields"""
record = self._tasks.get(task_id)
if record is None:
raise KeyError(f"Task '{task_id}' not found")
record.status = status
for key, value in kwargs.items():
if hasattr(record, key):
setattr(record, key, value)
return record
def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]:
"""List tasks, optionally filtered by status"""
tasks = list(self._tasks.values())
if status:
tasks = [t for t in tasks if t.status == status]
tasks.sort(key=lambda t: t.created_at, reverse=True)
return tasks[:limit]
def count_by_status(self) -> dict[str, int]:
"""Return a dict of status value -> count without materializing all records."""
counts: dict[str, int] = {}
for record in self._tasks.values():
key = record.status.value
counts[key] = counts.get(key, 0) + 1
return counts
@property
def size(self) -> int:
return len(self._tasks)
async def health_check(self) -> bool:
"""Verify the store is operational. Always returns True for in-memory backend."""
return True
# Backward-compatible alias: code that still imports ``TaskStore`` gets the
# in-memory backend. Use ``create_task_store`` to select a backend explicitly.
TaskStore = InMemoryTaskStore
class RedisTaskStore:
"""Redis-backed task state storage with TTL.
Stores each task as a JSON string in Redis with key pattern
``agentkit:task:{task_id}``. Redis TTL handles automatic cleanup,
so start_cleanup / stop_cleanup are no-ops.
"""
KEY_PREFIX = "agentkit:task:"
ZSET_KEY = "agentkit:tasks:by_time"
def __init__(
self,
redis_url: str = "redis://localhost:6379/0",
ttl_seconds: int = 3600,
max_records: int = 10000,
):
self._redis_url = redis_url
self._ttl_seconds = ttl_seconds
self._max_records = max_records
self._redis: Any = None # redis.asyncio.Redis, lazy init
@property
def backend_type(self) -> str:
"""Return the backend type identifier."""
return "redis"
async def _get_redis(self):
"""Lazy-initialise the async Redis client."""
if self._redis is None:
import redis.asyncio as aioredis
self._redis = aioredis.from_url(
self._redis_url,
decode_responses=True,
)
return self._redis
def _key(self, task_id: str) -> str:
return f"{self.KEY_PREFIX}{task_id}"
# ── lifecycle (no-ops, Redis TTL handles cleanup) ──────────
async def start_cleanup(self) -> None:
"""No-op Redis TTL handles expiry automatically."""
async def stop_cleanup(self) -> None:
"""Close the Redis connection pool on shutdown."""
if self._redis is not None:
await self._redis.close()
self._redis = None
# ── CRUD ───────────────────────────────────────────────────
async def create(self, task_id: str, agent_name: str, input_data: dict, skill_name: str | None = None) -> TaskRecord:
"""Create a new task record in Redis."""
redis = await self._get_redis()
# Enforce max_records by counting existing keys
current_size = await self._count_keys(redis)
if current_size >= self._max_records:
# Try to evict the oldest completed task
evicted = await self._evict_oldest_completed(redis)
if not evicted:
raise RuntimeError("TaskStore is full and no completed tasks to evict")
record = TaskRecord(
task_id=task_id,
agent_name=agent_name,
skill_name=skill_name,
input_data=input_data,
)
score = record.created_at.timestamp()
await redis.set(self._key(task_id), json.dumps(record.to_dict()), ex=self._ttl_seconds)
await redis.zadd(self.ZSET_KEY, {task_id: score})
return record
async def get(self, task_id: str) -> TaskRecord | None:
"""Get task record by ID."""
redis = await self._get_redis()
raw = await redis.get(self._key(task_id))
if raw is None:
return None
return TaskRecord.from_dict(json.loads(raw))
# Lua script for atomic read-modify-write
# ARGV[1] = "1" to reset TTL (apply ex=ttl_seconds), "0" to keep existing TTL (KEEPTTL)
# ARGV[2] = ttl_seconds (only used when ARGV[1] == "1")
# ARGV[3] = number of merge fields
# ARGV[4..] = key/value pairs
_UPDATE_STATUS_SCRIPT = """
local reset_ttl = ARGV[1]
local ttl = tonumber(ARGV[2])
local n = tonumber(ARGV[3])
local key = KEYS[1]
local raw = redis.call('GET', key)
if raw == false then
return nil
end
local data = cjson.decode(raw)
for i = 1, n do
local k = ARGV[3 + 2 * (i - 1) + 1]
local v = ARGV[3 + 2 * (i - 1) + 2]
data[k] = v
end
local encoded = cjson.encode(data)
if reset_ttl == "1" then
redis.call('SET', key, encoded, 'EX', ttl)
else
redis.call('SET', key, encoded, 'KEEPTTL')
end
return encoded
"""
async def update_status(self, task_id: str, status: TaskStatus, reset_ttl: bool = False, **kwargs) -> TaskRecord:
"""Update task status and optional fields atomically via Lua script.
Args:
task_id: Task identifier.
status: New task status.
reset_ttl: If True, reset the Redis TTL to ``ttl_seconds``. Defaults to
False so that frequent status updates on a long-running task do not
extend its lifetime indefinitely.
**kwargs: Optional fields to update (started_at, completed_at, etc.).
"""
redis = await self._get_redis()
key = self._key(task_id)
# Build flat list of key-value pairs for the merge fields
merge_fields = {"status": status.value}
for k, value in kwargs.items():
if k in ("started_at", "completed_at", "output_data", "error_message", "progress", "progress_message", "metadata"):
if isinstance(value, datetime):
merge_fields[k] = value.isoformat()
else:
merge_fields[k] = value
# Flatten merge_fields into ARGV pairs
args = ["1" if reset_ttl else "0", str(self._ttl_seconds), str(len(merge_fields))]
for k, v in merge_fields.items():
args.append(k)
args.append(json.dumps(v) if isinstance(v, (dict, list)) else str(v))
result = await redis.eval(self._UPDATE_STATUS_SCRIPT, 1, key, *args)
if result is None:
raise KeyError(f"Task '{task_id}' not found")
data = json.loads(result)
return TaskRecord.from_dict(data)
async def list_tasks(self, status: TaskStatus | None = None, limit: int = 100) -> list[TaskRecord]:
"""List tasks, optionally filtered by status, sorted by created_at desc."""
redis = await self._get_redis()
tasks: list[TaskRecord] = []
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
if keys:
values = await redis.mget(keys)
for raw in values:
if raw is None:
continue
record = TaskRecord.from_dict(json.loads(raw))
if status is None or record.status == status:
tasks.append(record)
if cursor == 0:
break
tasks.sort(key=lambda t: t.created_at, reverse=True)
return tasks[:limit]
async def count_by_status(self) -> dict[str, int]:
"""Return a dict of status value -> count using SCAN without materializing all records."""
redis = await self._get_redis()
counts: dict[str, int] = {}
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
if keys:
values = await redis.mget(keys)
for raw in values:
if raw is None:
continue
record = TaskRecord.from_dict(json.loads(raw))
key = record.status.value
counts[key] = counts.get(key, 0) + 1
if cursor == 0:
break
return counts
@property
async def size(self) -> int:
"""Number of task keys currently stored."""
redis = await self._get_redis()
return await self._count_keys(redis)
async def health_check(self) -> bool:
"""Verify Redis connectivity by sending a PING command."""
try:
redis = await self._get_redis()
return await redis.ping()
except Exception:
return False
# ── helpers ────────────────────────────────────────────────
async def _count_keys(self, redis) -> int:
"""Count task keys. Uses ZCARD on the sorted set for O(1) when
available, falls back to SCAN otherwise."""
try:
count = await redis.zcard(self.ZSET_KEY)
if count > 0:
return count
except Exception:
pass
# Fallback: full SCAN
count = 0
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
count += len(keys)
if cursor == 0:
break
return count
async def _evict_oldest_completed(self, redis) -> bool:
"""Find and delete the oldest completed/failed/cancelled task.
Uses ZRANGE on the sorted set for O(log N) when available,
falls back to full SCAN otherwise.
Returns True if a record was evicted, False otherwise.
"""
# Try ZSET-based eviction first
try:
member_count = await redis.zcard(self.ZSET_KEY)
if member_count > 0:
# Iterate from oldest (lowest score) to find a completed task
task_ids = await redis.zrange(self.ZSET_KEY, 0, -1)
for tid in task_ids:
raw = await redis.get(self._key(tid))
if raw is None:
# Stale ZSET entry clean up
await redis.zrem(self.ZSET_KEY, tid)
continue
record = TaskRecord.from_dict(json.loads(raw))
if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED) and record.completed_at is not None:
await redis.delete(self._key(tid))
await redis.zrem(self.ZSET_KEY, tid)
return True
return False
except Exception:
pass
# Fallback: full SCAN
tasks: list[TaskRecord] = []
cursor = 0
while True:
cursor, keys = await redis.scan(cursor, match=f"{self.KEY_PREFIX}*", count=200)
if keys:
values = await redis.mget(keys)
for raw in values:
if raw is None:
continue
record = TaskRecord.from_dict(json.loads(raw))
if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
tasks.append(record)
if cursor == 0:
break
if not tasks:
return False
# Pick the one with the earliest completed_at
oldest = min(
(t for t in tasks if t.completed_at is not None),
key=lambda t: t.completed_at, # type: ignore[arg-type]
default=None,
)
if oldest is None:
return False
await redis.delete(self._key(oldest.task_id))
try:
await redis.zrem(self.ZSET_KEY, oldest.task_id)
except Exception:
pass
return True
def create_task_store(
    backend: str = "memory",
    redis_url: str = "redis://localhost:6379/0",
    ttl_seconds: int = 3600,
    max_records: int = 10000,
) -> InMemoryTaskStore | RedisTaskStore:
    """Factory: create a TaskStore backed by memory or Redis.

    Args:
        backend: ``"memory"`` or ``"redis"``, matched case-insensitively and
            ignoring surrounding whitespace. Any other value logs a warning
            and yields an in-memory store.
        redis_url: Connection URL for the Redis backend (its password is
            masked before being logged).
        ttl_seconds: Record lifetime passed to the chosen backend.
        max_records: Capacity passed to the chosen backend.

    Returns:
        A :class:`RedisTaskStore` when ``backend`` is ``"redis"`` and the
        ``redis`` package is importable, otherwise an :class:`InMemoryTaskStore`.

    Note:
        No Redis connection is attempted here. Only importability of the
        ``redis`` package is checked; an import or construction error falls
        back to :class:`InMemoryTaskStore` with a warning. Verify runtime
        connectivity with ``await store.health_check()`` during application
        startup.
    """
    normalized = (backend or "memory").strip().lower()
    if normalized == "redis":
        try:
            import redis.asyncio as aioredis  # noqa: F401  (availability check only)

            store = RedisTaskStore(
                redis_url=redis_url,
                ttl_seconds=ttl_seconds,
                max_records=max_records,
            )
            logger.info("TaskStore backend: redis (%s)", _sanitize_redis_url(redis_url))
            return store
        except Exception as exc:
            logger.warning("Failed to initialise RedisTaskStore (%s), falling back to InMemoryTaskStore", exc)
    elif normalized != "memory":
        logger.warning("Unknown TaskStore backend %r, falling back to InMemoryTaskStore", backend)
    store = InMemoryTaskStore(ttl_seconds=ttl_seconds, max_records=max_records)
    logger.info("TaskStore backend: memory")
    return store
def _sanitize_redis_url(url: str) -> str:
"""Mask the password in a Redis URL for safe logging."""
from urllib.parse import urlparse, urlunparse
parsed = urlparse(url)
if parsed.password:
netloc = f"{parsed.username}:****@{parsed.hostname}"
if parsed.port:
netloc += f":{parsed.port}"
return urlunparse(parsed._replace(netloc=netloc))
return url