fischer-agentkit/src/agentkit/bitable/recalc_worker.py

266 lines
10 KiB
Python

"""Async recalc worker for formula fields.

Consumes recalc tasks from the queue, evaluates formulas, and writes results
back to records. Supports crash recovery (resets stale ``calculating`` tasks
on startup) and shutdown via ``stop()``, which cancels the background tasks
and waits for them to exit.

Lifecycle (managed by app.py lifespan)::

    worker = RecalcWorker(db, service)
    await worker.start()  # crash recovery, then starts worker + reaper tasks
    ...
    await worker.stop()   # cancels the background tasks and waits for them

The worker runs as an asyncio task. It claims queued tasks in batches and
sleeps ``poll_interval`` seconds only when the queue is empty (or after an
error). Each task is processed on its own: its formula result and its final
status are written back through separate repository calls.
"""
from __future__ import annotations
import asyncio
import logging
from agentkit.bitable.db import BitableDB
from agentkit.bitable.formula.engine import FormulaEngine
from agentkit.bitable.models import FieldType, RecalcStatus, RecalcTask
from agentkit.bitable.repository import BitableRepository
from agentkit.bitable.service import BitableService
logger = logging.getLogger(__name__)

# Tuning defaults for RecalcWorker; every value is expressed in seconds.
_DEFAULT_POLL_INTERVAL: float = 0.5  # idle sleep when the queue is empty
_DEFAULT_REAPER_INTERVAL: int = 5 * 60  # pause between reaper passes
_DEFAULT_STALE_THRESHOLD: int = 10 * 60  # age at which 'calculating' is stale
class RecalcWorker:
    """Background worker that processes formula recalc tasks.

    Usage::

        worker = RecalcWorker(db, service)
        await worker.start()
        # ... worker runs in background ...
        await worker.stop()

    ``start()`` launches two asyncio tasks: the main loop, which claims
    queued recalc tasks and evaluates them, and a reaper, which periodically
    resets tasks left in ``calculating`` for longer than ``stale_threshold``
    seconds.
    """

    def __init__(
        self,
        db: BitableDB,
        service: BitableService,
        poll_interval: float = _DEFAULT_POLL_INTERVAL,
        reaper_interval: float = _DEFAULT_REAPER_INTERVAL,
        stale_threshold: float = _DEFAULT_STALE_THRESHOLD,
    ) -> None:
        """Create an idle worker; call :meth:`start` to begin processing.

        Args:
            db: Database handle; the worker builds its own
                ``BitableRepository`` on top of it.
            service: Bitable service; stored on the instance but not
                referenced by this class's own methods.
            poll_interval: Seconds to sleep when the queue is empty or after
                an unexpected main-loop error.
            reaper_interval: Seconds between reaper passes.
            stale_threshold: Seconds a task may stay ``calculating`` before
                the reaper resets it.
        """
        self._db = db
        self._service = service
        self._repo = BitableRepository(db)
        self._poll_interval = poll_interval
        self._reaper_interval = reaper_interval
        self._stale_threshold = stale_threshold
        self._task: asyncio.Task[None] | None = None
        self._reaper_task: asyncio.Task[None] | None = None
        self._stop_event = asyncio.Event()
        # Per-table formula engines, built lazily and cached until
        # invalidate_engine() drops them.
        self._engines: dict[str, FormulaEngine] = {}

    async def start(self) -> None:
        """Run crash recovery, then launch the worker and reaper tasks.

        Calling ``start()`` while either background task is still running is
        a no-op. Without this guard, a second call would overwrite the task
        handles and leave an orphaned loop that ``stop()`` can never cancel.
        """
        if any(
            t is not None and not t.done() for t in (self._task, self._reaper_task)
        ):
            logger.warning("RecalcWorker.start() called while already running; ignoring")
            return
        # Crash recovery: reset every 'calculating' task (threshold 0), on the
        # assumption that any such task was claimed by a previous run of this
        # worker that is no longer alive.
        # NOTE(review): if several worker processes share one queue, this also
        # resets tasks a live peer is processing (they would be evaluated
        # twice) — confirm single-worker deployment.
        reset_count = await self._repo.reset_stale_recalc_tasks(stale_threshold=0.0)
        if reset_count > 0:
            logger.info("RecalcWorker: reset %d stale tasks to pending", reset_count)
        self._stop_event.clear()
        self._task = asyncio.create_task(self._run(), name="recalc-worker")
        self._reaper_task = asyncio.create_task(self._run_reaper(), name="recalc-reaper")
        logger.info("RecalcWorker started")

    async def stop(self) -> None:
        """Stop the worker and reaper and wait for both to exit.

        The stop event is set, then both tasks are cancelled. A recalc task
        that is mid-evaluation when cancellation lands is interrupted.
        ``process_task`` only catches ``Exception``, and ``CancelledError`` is a
        ``BaseException`` on the Python 3.8+ this file requires (it uses
        ``create_task(name=...)``). The interrupted task, and any claimed but
        unprocessed ones, are therefore left ``calculating`` until the
        crash-recovery step of a later ``start()`` resets them.
        """
        self._stop_event.set()
        await self._cancel_and_wait(self._task)
        self._task = None
        await self._cancel_and_wait(self._reaper_task)
        self._reaper_task = None
        logger.info("RecalcWorker stopped")

    @staticmethod
    async def _cancel_and_wait(task: asyncio.Task[None] | None) -> None:
        """Cancel *task* (if any) and wait until it has finished."""
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _run(self) -> None:
        """Main loop: claim a batch, order it, and process each task.

        Sleeps ``poll_interval`` seconds when the queue is empty or after an
        unexpected error. Otherwise it claims the next batch immediately.
        """
        while not self._stop_event.is_set():
            try:
                # Atomic claim (P1 #6): tasks are marked 'calculating' in the
                # same transaction, so concurrent workers never grab the same task.
                tasks = await self._repo.claim_recalc_tasks(limit=10)
                if not tasks:
                    await asyncio.sleep(self._poll_interval)
                    continue
                # P1 #7: order tasks so that, for a record, a formula's
                # dependencies are evaluated (and written back) before it.
                tasks = await self._sort_by_topological_order(tasks)
                for task in tasks:
                    if self._stop_event.is_set():
                        # Remaining claimed tasks stay 'calculating'; they are
                        # recovered by the reaper or the next start().
                        break
                    await self._process_claimed(task)
            except asyncio.CancelledError:
                break
            except Exception:
                logger.exception("RecalcWorker error in main loop")
                await asyncio.sleep(self._poll_interval)

    async def _process_claimed(self, task: RecalcTask) -> None:
        """Process one claimed task without letting its failure abort the batch.

        ``process_task`` records most failures on the task itself, but errors
        from its own status updates can still escape. Without this isolation,
        such an error would skip every remaining task in the already-claimed
        batch, leaving them 'calculating' until the reaper resets them.
        ``CancelledError`` is a ``BaseException`` (3.8+), so cancellation
        still propagates.
        """
        try:
            await self.process_task(task)
        except Exception:
            logger.exception(
                "RecalcWorker: unhandled error for task %s; continuing batch", task.id
            )

    async def _sort_by_topological_order(self, tasks: list[RecalcTask]) -> list[RecalcTask]:
        """Sort claimed tasks so dependencies are processed first (P1 #7).

        Each task gets the sort key ``(table_id, record_id, topo_index)``,
        where ``topo_index`` is the field's position in its table's
        topological order. Fields absent from that order sort after every
        ordered field (index ``1 << 30``). If a table's order cannot be
        computed, its tasks keep that sentinel too. They are still processed
        (in stable, claim order) instead of abandoning the already-claimed
        batch.
        """
        if len(tasks) <= 1:
            return tasks
        # {table_id: {field_id: position}}
        topo_index: dict[str, dict[str, int]] = {}
        for tid in {t.table_id for t in tasks}:
            topo_index[tid] = await self._topo_positions(tid)

        def _key(t: RecalcTask) -> tuple[str, str, int]:
            idx = topo_index.get(t.table_id, {}).get(t.field_id, 1 << 30)
            return (t.table_id, t.record_id, idx)

        return sorted(tasks, key=_key)

    async def _topo_positions(self, table_id: str) -> dict[str, int]:
        """Map field_id -> topological position for *table_id* ({} on failure)."""
        try:
            engine = await self._get_or_build_engine(table_id)
            if engine is None:
                return {}
            return {fid: i for i, fid in enumerate(engine.topological_order())}
        except Exception:
            logger.exception(
                "RecalcWorker: cannot compute topological order for table %s; "
                "processing its tasks unordered",
                table_id,
            )
            return {}

    async def _run_reaper(self) -> None:
        """Reaper loop — reset stale calculating tasks periodically.

        Sleeps ``reaper_interval`` seconds before each pass. It only resets
        tasks older than ``stale_threshold`` (P1 #10), so tasks a live worker
        is currently processing are not interrupted. Errors are logged and the
        loop continues.
        """
        while not self._stop_event.is_set():
            try:
                await asyncio.sleep(self._reaper_interval)
                count = await self._repo.reset_stale_recalc_tasks(
                    stale_threshold=self._stale_threshold
                )
                if count > 0:
                    logger.info(
                        "RecalcWorker reaper: reset %d stale tasks (threshold=%ss)",
                        count,
                        self._stale_threshold,
                    )
            except asyncio.CancelledError:
                break
            except Exception:
                logger.exception("RecalcWorker reaper error")

    async def process_task(self, task: RecalcTask) -> None:
        """Process a single recalc task: evaluate formula → write result.

        The task is normally already ``calculating`` (set by the atomic claim).
        When called synchronously via ``service.process_recalc_task``, it is
        marked ``calculating`` here first; re-marking is harmless.

        Outcome:
            * ``done`` — the formula value was written via ``set_formula_value``.
            * ``error`` with a message — the field is missing or not a formula,
              it has no ``formula_expr``, the table has no formula engine, the
              record is missing, or evaluation raised. For an evaluation error,
              the message is ``str(exc)`` truncated to 500 characters.

        Raises:
            Whatever the repository raises from the initial ``calculating``
            update or from the final ``error`` update; these are not caught.
        """
        # Idempotent: mark calculating (no-op if already calculating via claim).
        await self._repo.update_recalc_status(task.id, RecalcStatus.calculating)
        try:
            field = await self._repo.get_field(task.field_id)
            if field is None or field.field_type != FieldType.formula:
                await self._repo.update_recalc_status(
                    task.id, RecalcStatus.error, "Field not found or not a formula"
                )
                return
            formula_expr = field.config.get("formula_expr", "")
            if not formula_expr:
                await self._repo.update_recalc_status(
                    task.id, RecalcStatus.error, "No formula_expr in field config"
                )
                return
            # NOTE(review): the engine is cached per table. If this field's
            # formula_expr changed and invalidate_engine() was not called, the
            # engine still evaluates the old expression; formula_expr above is
            # only checked for emptiness.
            engine = await self._get_or_build_engine(task.table_id)
            if engine is None:
                await self._repo.update_recalc_status(
                    task.id, RecalcStatus.error, "No formula fields in table"
                )
                return
            record = await self._repo.get_record(task.record_id)
            if record is None:
                await self._repo.update_recalc_status(
                    task.id, RecalcStatus.error, "Record not found"
                )
                return
            # Whole-column values for each dependency, for aggregate formulas.
            deps = engine.get_dependencies(task.field_id)
            column_values: dict[str, list[object]] = {}
            for dep_field_id in deps:
                column_values[dep_field_id] = await self._repo.get_column_values(
                    task.table_id, dep_field_id
                )
            result = engine.evaluate(
                task.field_id,
                row_values=record.values,
                column_values=column_values,
            )
            await self._repo.set_formula_value(task.record_id, task.field_id, result)
            await self._repo.update_recalc_status(task.id, RecalcStatus.done)
        except Exception as e:
            logger.exception("RecalcWorker: error processing task %s", task.id)
            await self._repo.update_recalc_status(task.id, RecalcStatus.error, str(e)[:500])

    async def _get_or_build_engine(self, table_id: str) -> FormulaEngine | None:
        """Get or build a FormulaEngine for a table.

        Returns None if the table has no formula fields; that result is not
        cached, so such tables are re-queried on each call. Formulas that fail
        to register are logged and skipped. The returned engine may therefore
        lack some of the table's formula fields.
        """
        if table_id in self._engines:
            return self._engines[table_id]
        fields = await self._repo.list_fields(table_id)
        formula_fields = [f for f in fields if f.field_type == FieldType.formula]
        if not formula_fields:
            return None
        engine = FormulaEngine()
        for f in formula_fields:
            formula_expr = f.config.get("formula_expr", "")
            if formula_expr:
                try:
                    engine.add_formula(f.id, formula_expr)
                except Exception:
                    logger.exception("RecalcWorker: failed to register formula for field %s", f.id)
        self._engines[table_id] = engine
        return engine

    def invalidate_engine(self, table_id: str) -> None:
        """Invalidate the cached formula engine for a table (call when fields change)."""
        self._engines.pop(table_id, None)