266 lines
10 KiB
Python
266 lines
10 KiB
Python
"""Async recalc worker for formula fields.

Consumes recalc tasks from the queue, evaluates formulas, and writes results
back to records. Supports crash recovery (resets stale ``calculating`` tasks
on startup) and graceful shutdown. A separate reaper task periodically resets
tasks that have stayed ``calculating`` longer than ``stale_threshold`` seconds.

Lifecycle (managed by app.py lifespan)::

    worker = RecalcWorker(db, service)
    await worker.start()  # starts background task + crash recovery
    ...
    await worker.stop()   # waits for shutdown

The worker runs as an asyncio task, polling the queue every ``poll_interval``
seconds. Each task is processed in its own transaction.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
|
|
from agentkit.bitable.db import BitableDB
|
|
from agentkit.bitable.formula.engine import FormulaEngine
|
|
from agentkit.bitable.models import FieldType, RecalcStatus, RecalcTask
|
|
from agentkit.bitable.repository import BitableRepository
|
|
from agentkit.bitable.service import BitableService
|
|
|
|
logger = logging.getLogger(__name__)

# Default RecalcWorker timings, all expressed in seconds.
_DEFAULT_POLL_INTERVAL = 0.5  # idle pause before re-polling an empty queue
_DEFAULT_REAPER_INTERVAL = 5 * 60  # how often the reaper sweeps (5 minutes)
_DEFAULT_STALE_THRESHOLD = 10 * 60  # 'calculating' age treated as stale (10 minutes)
|
|
|
|
|
|
class RecalcWorker:
    """Background worker that processes formula recalc tasks.

    :meth:`start` launches two asyncio tasks:

    * the main loop, which atomically claims batches of recalc tasks, orders
      each batch by formula dependency and evaluates the tasks one by one;
    * the reaper, which every ``reaper_interval`` seconds resets tasks that
      have been ``calculating`` for longer than ``stale_threshold`` seconds.

    Formula engines are cached per table; call :meth:`invalidate_engine`
    when a table's fields change.

    Usage::

        worker = RecalcWorker(db, service)
        await worker.start()
        # ... worker runs in background ...
        await worker.stop()
    """

    def __init__(
        self,
        db: BitableDB,
        service: BitableService,
        poll_interval: float = _DEFAULT_POLL_INTERVAL,
        reaper_interval: float = _DEFAULT_REAPER_INTERVAL,
        stale_threshold: float = _DEFAULT_STALE_THRESHOLD,
    ) -> None:
        """Create a worker; nothing runs until :meth:`start` is awaited.

        Args:
            db: Database handle; a ``BitableRepository`` is built on top of it.
            service: Bitable service, stored as ``self._service``.
            poll_interval: Seconds to sleep when a poll finds no tasks, and
                after an error in the main loop.
            reaper_interval: Seconds between reaper sweeps.
            stale_threshold: Threshold (seconds) passed to the reaper's
                ``reset_stale_recalc_tasks`` call.
        """
        self._db = db
        self._service = service
        self._repo = BitableRepository(db)
        self._poll_interval = poll_interval
        self._reaper_interval = reaper_interval
        self._stale_threshold = stale_threshold
        self._task: asyncio.Task[None] | None = None
        self._reaper_task: asyncio.Task[None] | None = None
        self._stop_event = asyncio.Event()
        # Per-table formula engines (cached, rebuilt when fields change)
        self._engines: dict[str, FormulaEngine] = {}

    async def start(self) -> None:
        """Start the worker loop and the reaper. Performs crash recovery first.

        Crash recovery resets every ``calculating`` task to pending
        (threshold 0), on the premise that nothing was processing them while
        this worker was down. Calling ``start`` while the worker is already
        running logs a warning and returns, instead of spawning a second pair
        of background tasks and orphaning the first pair.
        """
        if self._is_running():
            logger.warning("RecalcWorker already running; start() ignored")
            return

        # NOTE(review): threshold=0 assumes a single worker process per queue;
        # with several processes this would also reset tasks that another
        # live worker has just claimed.
        reset_count = await self._repo.reset_stale_recalc_tasks(stale_threshold=0.0)
        if reset_count > 0:
            logger.info("RecalcWorker: reset %d stale tasks to pending", reset_count)

        self._stop_event.clear()
        self._task = asyncio.create_task(self._run(), name="recalc-worker")
        self._reaper_task = asyncio.create_task(self._run_reaper(), name="recalc-reaper")
        logger.info("RecalcWorker started")

    def _is_running(self) -> bool:
        """Return True if either background task exists and has not finished."""
        return any(
            t is not None and not t.done() for t in (self._task, self._reaper_task)
        )

    async def stop(self) -> None:
        """Stop the worker loop and the reaper and wait for both to finish.

        The stop event is set first, then both background tasks are cancelled
        and awaited. A recalc task interrupted mid-evaluation is not marked
        as errored (``process_task`` only catches ``Exception``, and
        cancellation is not one), so it stays ``calculating`` until the next
        :meth:`start` or a reaper sweep resets it.
        """
        self._stop_event.set()
        await self._cancel_and_wait(self._task)
        self._task = None
        await self._cancel_and_wait(self._reaper_task)
        self._reaper_task = None
        logger.info("RecalcWorker stopped")

    @staticmethod
    async def _cancel_and_wait(task: asyncio.Task[None] | None) -> None:
        """Cancel *task* (if any) and wait for it to finish.

        Unexpected exceptions from the task are logged rather than raised, so
        stopping one background task can never prevent stopping the other.
        """
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception:
            logger.exception("RecalcWorker: background task %s failed", task.get_name())

    async def _run(self) -> None:
        """Main worker loop: poll the queue, order each batch, process tasks.

        Runs until the stop event is set or the task is cancelled.
        """
        while not self._stop_event.is_set():
            try:
                # Atomic claim (P1 #6): tasks are marked 'calculating' in the
                # same transaction, so concurrent workers never grab the same task.
                tasks = await self._repo.claim_recalc_tasks(limit=10)
                if not tasks:
                    await asyncio.sleep(self._poll_interval)
                    continue

                # P1 #7: sort tasks by topological order so formula-to-formula
                # dependencies resolve correctly within a batch. Tasks for the
                # same record are ordered so that if B depends on A, A is
                # processed first (A's result is written back before B reads it).
                try:
                    tasks = await self._sort_by_topological_order(tasks)
                except Exception:
                    # The batch is already claimed; abandoning it here would
                    # leave every task 'calculating' until the reaper fires.
                    logger.exception(
                        "RecalcWorker: failed to order batch; processing in claim order"
                    )

                for task in tasks:
                    if self._stop_event.is_set():
                        break
                    try:
                        await self.process_task(task)
                    except Exception:
                        # process_task records evaluation errors itself; this
                        # catches failures escaping it (e.g. a failed status
                        # write) so the rest of the claimed batch still runs.
                        logger.exception(
                            "RecalcWorker: unhandled error for task %s", task.id
                        )
            except asyncio.CancelledError:
                break
            except Exception:
                logger.exception("RecalcWorker error in main loop")
                await asyncio.sleep(self._poll_interval)

    async def _sort_by_topological_order(self, tasks: list[RecalcTask]) -> list[RecalcTask]:
        """Sort claimed tasks so dependencies are processed first (P1 #7).

        Builds (or reuses) the engine of every table in the batch to get its
        topological order, then sorts by ``(table_id, record_id, topo_index)``.
        Tasks for fields not in the DAG (or tables with no formula fields)
        get a very large topo_index, so they are processed last within their
        record.
        """
        if len(tasks) <= 1:
            return tasks

        # Position used for fields absent from a table's DAG.
        unordered = 1 << 30

        # {table_id: {field_id: position in that table's topological order}}
        topo_index: dict[str, dict[str, int]] = {}
        for tid in {t.table_id for t in tasks}:
            engine = await self._get_or_build_engine(tid)
            if engine is None:
                topo_index[tid] = {}
                continue
            topo_index[tid] = {fid: i for i, fid in enumerate(engine.topological_order())}

        def _key(t: RecalcTask) -> tuple[str, str, int]:
            idx = topo_index.get(t.table_id, {}).get(t.field_id, unordered)
            return (t.table_id, t.record_id, idx)

        return sorted(tasks, key=_key)

    async def _run_reaper(self) -> None:
        """Reaper loop — reset stale calculating tasks periodically.

        Only resets tasks older than ``stale_threshold`` (P1 #10), so active
        tasks being processed by a live worker are not interrupted. Errors are
        logged and the loop keeps running.
        """
        while not self._stop_event.is_set():
            try:
                await asyncio.sleep(self._reaper_interval)
                count = await self._repo.reset_stale_recalc_tasks(
                    stale_threshold=self._stale_threshold
                )
                if count > 0:
                    logger.info(
                        "RecalcWorker reaper: reset %d stale tasks (threshold=%ss)",
                        count,
                        self._stale_threshold,
                    )
            except asyncio.CancelledError:
                break
            except Exception:
                logger.exception("RecalcWorker reaper error")

    async def process_task(self, task: RecalcTask) -> None:
        """Process a single recalc task: evaluate formula → write result.

        The task is expected to already be in ``calculating`` status when
        called from the worker loop (atomic claim sets it). When called
        synchronously via ``service.process_recalc_task``, this method
        marks it calculating first (idempotent — re-marking is harmless).

        The task ends in ``error`` (with a message) when the field is missing
        or not a formula, has no ``formula_expr``, the table has no formula
        engine, the record is missing, or evaluation raises. Otherwise the
        result is written to the record and the task is marked ``done``.
        Exceptions raised while recording an error status propagate.
        """
        # Idempotent: mark calculating (no-op if already calculating via claim).
        await self._repo.update_recalc_status(task.id, RecalcStatus.calculating)

        try:
            field = await self._repo.get_field(task.field_id)
            if field is None or field.field_type != FieldType.formula:
                await self._repo.update_recalc_status(
                    task.id, RecalcStatus.error, "Field not found or not a formula"
                )
                return

            formula_expr = field.config.get("formula_expr", "")
            if not formula_expr:
                await self._repo.update_recalc_status(
                    task.id, RecalcStatus.error, "No formula_expr in field config"
                )
                return

            engine = await self._get_or_build_engine(task.table_id)
            if engine is None:
                await self._repo.update_recalc_status(
                    task.id, RecalcStatus.error, "No formula fields in table"
                )
                return

            record = await self._repo.get_record(task.record_id)
            if record is None:
                await self._repo.update_recalc_status(
                    task.id, RecalcStatus.error, "Record not found"
                )
                return

            # Whole-column values of every field this formula depends on.
            deps = engine.get_dependencies(task.field_id)
            column_values: dict[str, list[object]] = {}
            for dep_field_id in deps:
                column_values[dep_field_id] = await self._repo.get_column_values(
                    task.table_id, dep_field_id
                )

            result = engine.evaluate(
                task.field_id,
                row_values=record.values,
                column_values=column_values,
            )

            await self._repo.set_formula_value(task.record_id, task.field_id, result)
            await self._repo.update_recalc_status(task.id, RecalcStatus.done)

        except Exception as e:
            logger.exception("RecalcWorker: error processing task %s", task.id)
            # Exceptions raised without arguments stringify to ""; fall back to
            # the type name so the stored error message is never empty.
            message = str(e) or type(e).__name__
            await self._repo.update_recalc_status(task.id, RecalcStatus.error, message[:500])

    async def _get_or_build_engine(self, table_id: str) -> FormulaEngine | None:
        """Get or build a FormulaEngine for a table.

        Returns None if the table has no formula fields (that result is not
        cached, so the fields are re-listed on the next call). Formulas that
        fail to register are logged and skipped; the engine is cached anyway.
        """
        if table_id in self._engines:
            return self._engines[table_id]

        fields = await self._repo.list_fields(table_id)
        formula_fields = [f for f in fields if f.field_type == FieldType.formula]
        if not formula_fields:
            return None

        engine = FormulaEngine()
        for f in formula_fields:
            formula_expr = f.config.get("formula_expr", "")
            if formula_expr:
                try:
                    engine.add_formula(f.id, formula_expr)
                except Exception:
                    logger.exception("RecalcWorker: failed to register formula for field %s", f.id)

        self._engines[table_id] = engine
        return engine

    def invalidate_engine(self, table_id: str) -> None:
        """Invalidate the cached formula engine for a table (call when fields change)."""
        self._engines.pop(table_id, None)
|