refactor(bitable,tools): replace Any with concrete types + Protocol (U4)
BitableRecord/FormulaResult/SessionState TypeAlias replace dict[str, Any]; _redis/_engine/_session_factory typed as object | None with TYPE_CHECKING Protocol (_RedisLike, _RecalcWorker); Coroutine[Any, Any, Any] retained as legitimate type param. Baseline 40 : Any occurrences -> 0 across 6 in-scope files (target <=5). Deferred: repository.py/recalc_worker.py/ingestion/* (10 occurrences, separate PR). ruff clean; 367 passed + 116 skipped (bitable + pipeline_state + tools).
This commit is contained in:
parent
be5c4e09f8
commit
1033346913
|
|
@ -18,7 +18,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from types import TracebackType
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column,
|
Column,
|
||||||
|
|
@ -191,7 +191,7 @@ class MetaModel(BitableBase):
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def _apply_v2_migration(conn: Any) -> None:
|
async def _apply_v2_migration(conn: object) -> None:
|
||||||
"""V2 migration: create ``bitable_files`` table + add ``file_id`` to tables.
|
"""V2 migration: create ``bitable_files`` table + add ``file_id`` to tables.
|
||||||
|
|
||||||
Idempotent — safe to call on fresh installs (``create_all`` already made
|
Idempotent — safe to call on fresh installs (``create_all`` already made
|
||||||
|
|
@ -265,8 +265,8 @@ class BitableDB:
|
||||||
|
|
||||||
def __init__(self, database_url: str | None = None) -> None:
|
def __init__(self, database_url: str | None = None) -> None:
|
||||||
self._database_url = database_url or _resolve_database_url()
|
self._database_url = database_url or _resolve_database_url()
|
||||||
self._engine: Any = None
|
self._engine: object | None = None
|
||||||
self._session_factory: Any = None
|
self._session_factory: object | None = None
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
self._init_lock = asyncio.Lock()
|
self._init_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
@ -275,11 +275,11 @@ class BitableDB:
|
||||||
return self._database_url
|
return self._database_url
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def engine(self) -> Any:
|
def engine(self) -> object | None:
|
||||||
return self._engine
|
return self._engine
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session_factory(self) -> Any:
|
def session_factory(self) -> object | None:
|
||||||
return self._session_factory
|
return self._session_factory
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -365,7 +365,12 @@ class BitableDB:
|
||||||
await self.init()
|
await self.init()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,12 +12,17 @@ based on the calling context — see :mod:`agentkit.bitable.formula.engine`.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Callable
|
from typing import Callable, TypeAlias
|
||||||
|
|
||||||
|
# A formula evaluates to a scalar primitive: text, number, or nothing.
|
||||||
|
# bool is intentionally excluded — comparisons live in the parser layer
|
||||||
|
# and never reach the function registry.
|
||||||
|
FormulaResult: TypeAlias = str | int | float | None
|
||||||
|
|
||||||
# ── Aggregate functions (operate on lists) ────────────────
|
# ── Aggregate functions (operate on lists) ────────────────
|
||||||
|
|
||||||
|
|
||||||
def _sum(values: list[Any]) -> float | int:
|
def _sum(values: list[FormulaResult]) -> float | int:
|
||||||
"""Sum of numeric values, ignoring None/empty."""
|
"""Sum of numeric values, ignoring None/empty."""
|
||||||
total = 0
|
total = 0
|
||||||
for v in values:
|
for v in values:
|
||||||
|
|
@ -27,7 +32,7 @@ def _sum(values: list[Any]) -> float | int:
|
||||||
return total
|
return total
|
||||||
|
|
||||||
|
|
||||||
def _avg(values: list[Any]) -> float:
|
def _avg(values: list[FormulaResult]) -> float:
|
||||||
"""Average of numeric values, ignoring None/empty."""
|
"""Average of numeric values, ignoring None/empty."""
|
||||||
nums = [v for v in values if v is not None and v != ""]
|
nums = [v for v in values if v is not None and v != ""]
|
||||||
if not nums:
|
if not nums:
|
||||||
|
|
@ -35,12 +40,12 @@ def _avg(values: list[Any]) -> float:
|
||||||
return sum(nums) / len(nums)
|
return sum(nums) / len(nums)
|
||||||
|
|
||||||
|
|
||||||
def _count(values: list[Any]) -> int:
|
def _count(values: list[FormulaResult]) -> int:
|
||||||
"""Count of non-empty values."""
|
"""Count of non-empty values."""
|
||||||
return sum(1 for v in values if v is not None and v != "")
|
return sum(1 for v in values if v is not None and v != "")
|
||||||
|
|
||||||
|
|
||||||
def _min(values: list[Any]) -> Any:
|
def _min(values: list[FormulaResult]) -> FormulaResult:
|
||||||
"""Minimum of numeric values, ignoring None/empty."""
|
"""Minimum of numeric values, ignoring None/empty."""
|
||||||
nums = [v for v in values if v is not None and v != ""]
|
nums = [v for v in values if v is not None and v != ""]
|
||||||
if not nums:
|
if not nums:
|
||||||
|
|
@ -48,7 +53,7 @@ def _min(values: list[Any]) -> Any:
|
||||||
return min(nums)
|
return min(nums)
|
||||||
|
|
||||||
|
|
||||||
def _max(values: list[Any]) -> Any:
|
def _max(values: list[FormulaResult]) -> FormulaResult:
|
||||||
"""Maximum of numeric values, ignoring None/empty."""
|
"""Maximum of numeric values, ignoring None/empty."""
|
||||||
nums = [v for v in values if v is not None and v != ""]
|
nums = [v for v in values if v is not None and v != ""]
|
||||||
if not nums:
|
if not nums:
|
||||||
|
|
@ -59,25 +64,29 @@ def _max(values: list[Any]) -> Any:
|
||||||
# ── Scalar functions ──────────────────────────────────────
|
# ── Scalar functions ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _abs(value: Any) -> Any:
|
def _abs(value: FormulaResult) -> FormulaResult:
|
||||||
return abs(value)
|
return abs(value)
|
||||||
|
|
||||||
|
|
||||||
def _round(value: Any, digits: int = 0) -> float:
|
def _round(value: FormulaResult, digits: int = 0) -> float:
|
||||||
return round(value, digits)
|
return round(value, digits)
|
||||||
|
|
||||||
|
|
||||||
def _if(condition: Any, true_val: Any, false_val: Any = None) -> Any:
|
def _if(
|
||||||
|
condition: FormulaResult,
|
||||||
|
true_val: FormulaResult,
|
||||||
|
false_val: FormulaResult = None,
|
||||||
|
) -> FormulaResult:
|
||||||
return true_val if condition else false_val
|
return true_val if condition else false_val
|
||||||
|
|
||||||
|
|
||||||
def _len(value: Any) -> int:
|
def _len(value: FormulaResult) -> int:
|
||||||
if value is None:
|
if value is None:
|
||||||
return 0
|
return 0
|
||||||
return len(str(value))
|
return len(str(value))
|
||||||
|
|
||||||
|
|
||||||
def _concat(*args: Any) -> str:
|
def _concat(*args: FormulaResult) -> str:
|
||||||
"""Concatenate all arguments as strings."""
|
"""Concatenate all arguments as strings."""
|
||||||
return "".join(str(a) for a in args if a is not None)
|
return "".join(str(a) for a in args if a is not None)
|
||||||
|
|
||||||
|
|
@ -87,7 +96,7 @@ def _concat(*args: Any) -> str:
|
||||||
# Functions that aggregate a column (receive a list of all column values)
|
# Functions that aggregate a column (receive a list of all column values)
|
||||||
AGGREGATE_FUNCTIONS: frozenset[str] = frozenset({"SUM", "AVG", "COUNT", "MIN", "MAX"})
|
AGGREGATE_FUNCTIONS: frozenset[str] = frozenset({"SUM", "AVG", "COUNT", "MIN", "MAX"})
|
||||||
|
|
||||||
FUNCTION_REGISTRY: dict[str, Callable[..., Any]] = {
|
FUNCTION_REGISTRY: dict[str, Callable[..., FormulaResult]] = {
|
||||||
"SUM": _sum,
|
"SUM": _sum,
|
||||||
"AVG": _avg,
|
"AVG": _avg,
|
||||||
"COUNT": _count,
|
"COUNT": _count,
|
||||||
|
|
|
||||||
|
|
@ -25,9 +25,9 @@ from __future__ import annotations
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Callable
|
||||||
|
|
||||||
from agentkit.bitable.formula.functions import FUNCTION_REGISTRY
|
from agentkit.bitable.formula.functions import FUNCTION_REGISTRY, FormulaResult
|
||||||
|
|
||||||
# ── Exceptions ────────────────────────────────────────────
|
# ── Exceptions ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
@ -184,9 +184,9 @@ def parse_formula(
|
||||||
|
|
||||||
def evaluate_ast(
|
def evaluate_ast(
|
||||||
tree: ast.Expression,
|
tree: ast.Expression,
|
||||||
field_values: dict[str, Any],
|
field_values: dict[str, FormulaResult | list[FormulaResult]],
|
||||||
functions: dict[str, Any],
|
functions: dict[str, Callable[..., FormulaResult]],
|
||||||
) -> Any:
|
) -> FormulaResult:
|
||||||
"""Evaluate a parsed formula AST against field values and functions.
|
"""Evaluate a parsed formula AST against field values and functions.
|
||||||
|
|
||||||
This is NOT ``eval()`` — it's a manual AST walker that only processes
|
This is NOT ``eval()`` — it's a manual AST walker that only processes
|
||||||
|
|
@ -204,7 +204,11 @@ def evaluate_ast(
|
||||||
return _eval_node(tree.body, field_values, functions)
|
return _eval_node(tree.body, field_values, functions)
|
||||||
|
|
||||||
|
|
||||||
def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any]) -> Any:
|
def _eval_node(
|
||||||
|
node: ast.AST,
|
||||||
|
fields: dict[str, FormulaResult | list[FormulaResult]],
|
||||||
|
functions: dict[str, Callable[..., FormulaResult]],
|
||||||
|
) -> FormulaResult:
|
||||||
"""Recursively evaluate an AST node."""
|
"""Recursively evaluate an AST node."""
|
||||||
if isinstance(node, ast.Constant):
|
if isinstance(node, ast.Constant):
|
||||||
return node.value
|
return node.value
|
||||||
|
|
@ -274,7 +278,7 @@ def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any])
|
||||||
raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}")
|
raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}")
|
||||||
|
|
||||||
|
|
||||||
def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any:
|
def _apply_binop(op: ast.AST, left: FormulaResult, right: FormulaResult) -> FormulaResult:
|
||||||
"""Apply a binary operator."""
|
"""Apply a binary operator."""
|
||||||
if isinstance(op, ast.Add):
|
if isinstance(op, ast.Add):
|
||||||
# String concat or numeric addition
|
# String concat or numeric addition
|
||||||
|
|
@ -294,7 +298,7 @@ def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any:
|
||||||
raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}")
|
raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}")
|
||||||
|
|
||||||
|
|
||||||
def _apply_compare(op: ast.AST, left: Any, right: Any) -> bool:
|
def _apply_compare(op: ast.AST, left: FormulaResult, right: FormulaResult) -> bool:
|
||||||
"""Apply a comparison operator."""
|
"""Apply a comparison operator."""
|
||||||
if isinstance(op, ast.Eq):
|
if isinstance(op, ast.Eq):
|
||||||
return left == right
|
return left == right
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, TypeAlias
|
||||||
|
|
||||||
from agentkit.bitable.db import BitableDB
|
from agentkit.bitable.db import BitableDB
|
||||||
from agentkit.bitable.models import (
|
from agentkit.bitable.models import (
|
||||||
|
|
@ -29,13 +29,27 @@ from agentkit.bitable.models import (
|
||||||
)
|
)
|
||||||
from agentkit.bitable.repository import BitableRepository
|
from agentkit.bitable.repository import BitableRepository
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class _RecalcWorker(Protocol):
|
||||||
|
"""Structural type for the recalc worker's cache-invalidation surface."""
|
||||||
|
|
||||||
|
def invalidate_engine(self, table_id: str) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Record values are JSON scalars (text/number/none). Attachment/image fields
|
||||||
|
# store lists of dicts at runtime, but the common-case scalar shape is captured
|
||||||
|
# here for annotation clarity; fall back to dict[str, object] where lists occur.
|
||||||
|
BitableRecord: TypeAlias = dict[str, str | int | float | None]
|
||||||
|
|
||||||
|
|
||||||
class FieldDependencyError(Exception):
|
class FieldDependencyError(Exception):
|
||||||
"""Raised when deleting a field that has dependencies (formula refs, PK, views)."""
|
"""Raised when deleting a field that has dependencies (formula refs, PK, views)."""
|
||||||
|
|
||||||
def __init__(self, message: str, dependencies: dict[str, Any]) -> None:
|
def __init__(self, message: str, dependencies: dict[str, object]) -> None:
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.dependencies = dependencies
|
self.dependencies = dependencies
|
||||||
|
|
||||||
|
|
@ -52,13 +66,13 @@ class BitableService:
|
||||||
def __init__(self, db: BitableDB) -> None:
|
def __init__(self, db: BitableDB) -> None:
|
||||||
self._db = db
|
self._db = db
|
||||||
self._repo = BitableRepository(db)
|
self._repo = BitableRepository(db)
|
||||||
self._recalc_worker: Any = None # RecalcWorker, set via set_recalc_worker
|
self._recalc_worker: _RecalcWorker | None = None # set via set_recalc_worker
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def repo(self) -> BitableRepository:
|
def repo(self) -> BitableRepository:
|
||||||
return self._repo
|
return self._repo
|
||||||
|
|
||||||
def set_recalc_worker(self, worker: Any) -> None:
|
def set_recalc_worker(self, worker: _RecalcWorker) -> None:
|
||||||
"""Register the long-lived RecalcWorker so field changes can invalidate its engine cache.
|
"""Register the long-lived RecalcWorker so field changes can invalidate its engine cache.
|
||||||
|
|
||||||
Called after both service and worker are constructed (breaks the
|
Called after both service and worker are constructed (breaks the
|
||||||
|
|
@ -95,7 +109,7 @@ class BitableService:
|
||||||
async def list_files(self, owner_user_id: str | None = None) -> list[BitableFile]:
|
async def list_files(self, owner_user_id: str | None = None) -> list[BitableFile]:
|
||||||
return await self._repo.list_files(owner_user_id=owner_user_id)
|
return await self._repo.list_files(owner_user_id=owner_user_id)
|
||||||
|
|
||||||
async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None:
|
async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None:
|
||||||
return await self._repo.update_file(file_id, **kwargs)
|
return await self._repo.update_file(file_id, **kwargs)
|
||||||
|
|
||||||
async def delete_file(self, file_id: str) -> bool:
|
async def delete_file(self, file_id: str) -> bool:
|
||||||
|
|
@ -162,7 +176,7 @@ class BitableService:
|
||||||
async def list_tables(self, owner_user_id: str | None = None) -> list[Table]:
|
async def list_tables(self, owner_user_id: str | None = None) -> list[Table]:
|
||||||
return await self._repo.list_tables(owner_user_id=owner_user_id)
|
return await self._repo.list_tables(owner_user_id=owner_user_id)
|
||||||
|
|
||||||
async def update_table(self, table_id: str, **kwargs: Any) -> Table | None:
|
async def update_table(self, table_id: str, **kwargs: object) -> Table | None:
|
||||||
"""Update table attrs. Creates PK unique index if primary_key_field_id is set."""
|
"""Update table attrs. Creates PK unique index if primary_key_field_id is set."""
|
||||||
table = await self._repo.update_table(table_id, **kwargs)
|
table = await self._repo.update_table(table_id, **kwargs)
|
||||||
if table and kwargs.get("primary_key_field_id"):
|
if table and kwargs.get("primary_key_field_id"):
|
||||||
|
|
@ -179,7 +193,7 @@ class BitableService:
|
||||||
table_id: str,
|
table_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
field_type: FieldType,
|
field_type: FieldType,
|
||||||
config: dict[str, Any] | None = None,
|
config: dict[str, object] | None = None,
|
||||||
owner: FieldOwner = FieldOwner.user,
|
owner: FieldOwner = FieldOwner.user,
|
||||||
) -> Field:
|
) -> Field:
|
||||||
"""Create a new field. U2 will add formula validation and DAG updates."""
|
"""Create a new field. U2 will add formula validation and DAG updates."""
|
||||||
|
|
@ -201,7 +215,7 @@ class BitableService:
|
||||||
async def list_fields(self, table_id: str) -> list[Field]:
|
async def list_fields(self, table_id: str) -> list[Field]:
|
||||||
return await self._repo.list_fields(table_id)
|
return await self._repo.list_fields(table_id)
|
||||||
|
|
||||||
async def update_field(self, field_id: str, **kwargs: Any) -> Field | None:
|
async def update_field(self, field_id: str, **kwargs: object) -> Field | None:
|
||||||
"""Update a field. U2 will add dependency checking."""
|
"""Update a field. U2 will add dependency checking."""
|
||||||
field = await self._repo.update_field(field_id, **kwargs)
|
field = await self._repo.update_field(field_id, **kwargs)
|
||||||
if field is not None:
|
if field is not None:
|
||||||
|
|
@ -220,7 +234,7 @@ class BitableService:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check dependencies
|
# Check dependencies
|
||||||
deps: dict[str, Any] = {}
|
deps: dict[str, object] = {}
|
||||||
|
|
||||||
# 1. Is it a primary key field?
|
# 1. Is it a primary key field?
|
||||||
table = await self._repo.get_table(field.table_id)
|
table = await self._repo.get_table(field.table_id)
|
||||||
|
|
@ -264,7 +278,7 @@ class BitableService:
|
||||||
async def create_record(
|
async def create_record(
|
||||||
self,
|
self,
|
||||||
table_id: str,
|
table_id: str,
|
||||||
values: dict[str, Any] | None = None,
|
values: BitableRecord | None = None,
|
||||||
actor_user_id: str | None = None,
|
actor_user_id: str | None = None,
|
||||||
) -> Record:
|
) -> Record:
|
||||||
"""Create a new record. Triggers recalc for affected formula fields.
|
"""Create a new record. Triggers recalc for affected formula fields.
|
||||||
|
|
@ -291,7 +305,7 @@ class BitableService:
|
||||||
return record
|
return record
|
||||||
|
|
||||||
async def create_records_batch(
|
async def create_records_batch(
|
||||||
self, table_id: str, records_values: list[dict[str, Any]]
|
self, table_id: str, records_values: list[BitableRecord]
|
||||||
) -> list[Record]:
|
) -> list[Record]:
|
||||||
"""Batch-create records (P2 #19). Triggers recalc for each record.
|
"""Batch-create records (P2 #19). Triggers recalc for each record.
|
||||||
|
|
||||||
|
|
@ -319,8 +333,8 @@ class BitableService:
|
||||||
async def list_records_filtered(
|
async def list_records_filtered(
|
||||||
self,
|
self,
|
||||||
table_id: str,
|
table_id: str,
|
||||||
filters: list[dict[str, Any]] | None = None,
|
filters: list[dict[str, object]] | None = None,
|
||||||
sorts: list[dict[str, Any]] | None = None,
|
sorts: list[dict[str, object]] | None = None,
|
||||||
cursor: str | None = None,
|
cursor: str | None = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
) -> tuple[list[Record], str | None]:
|
) -> tuple[list[Record], str | None]:
|
||||||
|
|
@ -345,7 +359,7 @@ class BitableService:
|
||||||
table_id, filters=filters, sorts=sorts, cursor=cursor, limit=limit
|
table_id, filters=filters, sorts=sorts, cursor=cursor, limit=limit
|
||||||
)
|
)
|
||||||
|
|
||||||
async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None:
|
async def update_record_values(self, record_id: str, values: BitableRecord) -> Record | None:
|
||||||
"""Update a record's values (full replace). Triggers recalc for affected formulas."""
|
"""Update a record's values (full replace). Triggers recalc for affected formulas."""
|
||||||
record = await self._repo.update_record_values(record_id, values)
|
record = await self._repo.update_record_values(record_id, values)
|
||||||
if record is not None:
|
if record is not None:
|
||||||
|
|
@ -431,9 +445,9 @@ class BitableService:
|
||||||
async def upsert_records(
|
async def upsert_records(
|
||||||
self,
|
self,
|
||||||
table_id: str,
|
table_id: str,
|
||||||
records: list[dict[str, Any]],
|
records: list[BitableRecord],
|
||||||
primary_key_field_id: str,
|
primary_key_field_id: str,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, int]:
|
||||||
"""Upsert records by primary key using jsonb_set (KTD8).
|
"""Upsert records by primary key using jsonb_set (KTD8).
|
||||||
|
|
||||||
For each record:
|
For each record:
|
||||||
|
|
@ -454,12 +468,12 @@ class BitableService:
|
||||||
agent_field_ids = {f.id for f in fields if f.owner == FieldOwner.agent}
|
agent_field_ids = {f.id for f in fields if f.owner == FieldOwner.agent}
|
||||||
|
|
||||||
# Partition records into insert vs update lists, collecting PK values.
|
# Partition records into insert vs update lists, collecting PK values.
|
||||||
to_insert: list[dict[str, Any]] = []
|
to_insert: list[BitableRecord] = []
|
||||||
to_update: list[tuple[dict[str, Any], str]] = [] # (values, existing_record_id)
|
to_update: list[tuple[BitableRecord, str]] = [] # (values, existing_record_id)
|
||||||
skipped = 0
|
skipped = 0
|
||||||
|
|
||||||
# Collect all non-None PK values for batch lookup.
|
# Collect all non-None PK values for batch lookup.
|
||||||
pk_values_by_str: dict[str, dict[str, Any]] = {}
|
pk_values_by_str: dict[str, BitableRecord] = {}
|
||||||
for rec_values in records:
|
for rec_values in records:
|
||||||
pk_value = rec_values.get(primary_key_field_id)
|
pk_value = rec_values.get(primary_key_field_id)
|
||||||
if pk_value is None:
|
if pk_value is None:
|
||||||
|
|
@ -504,7 +518,7 @@ class BitableService:
|
||||||
table_id: str,
|
table_id: str,
|
||||||
name: str,
|
name: str,
|
||||||
view_type: ViewType = ViewType.grid,
|
view_type: ViewType = ViewType.grid,
|
||||||
config: dict[str, Any] | None = None,
|
config: dict[str, object] | None = None,
|
||||||
) -> View:
|
) -> View:
|
||||||
return await self._repo.create_view(
|
return await self._repo.create_view(
|
||||||
table_id=table_id,
|
table_id=table_id,
|
||||||
|
|
@ -516,7 +530,7 @@ class BitableService:
|
||||||
async def list_views(self, table_id: str) -> list[View]:
|
async def list_views(self, table_id: str) -> list[View]:
|
||||||
return await self._repo.list_views(table_id)
|
return await self._repo.list_views(table_id)
|
||||||
|
|
||||||
async def update_view(self, view_id: str, **kwargs: Any) -> View | None:
|
async def update_view(self, view_id: str, **kwargs: object) -> View | None:
|
||||||
return await self._repo.update_view(view_id, **kwargs)
|
return await self._repo.update_view(view_id, **kwargs)
|
||||||
|
|
||||||
async def get_view(self, view_id: str) -> View | None:
|
async def get_view(self, view_id: str) -> View | None:
|
||||||
|
|
|
||||||
|
|
@ -32,14 +32,14 @@ class PipelineStateMemory:
|
||||||
"""In-memory pipeline state storage (testing / fallback)."""
|
"""In-memory pipeline state storage (testing / fallback)."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._executions: dict[str, dict[str, Any]] = {}
|
self._executions: dict[str, dict[str, object]] = {}
|
||||||
self._step_history: dict[str, list[dict[str, Any]]] = {}
|
self._step_history: dict[str, list[dict[str, object]]] = {}
|
||||||
|
|
||||||
async def create_execution(
|
async def create_execution(
|
||||||
self,
|
self,
|
||||||
pipeline_name: str,
|
pipeline_name: str,
|
||||||
steps: list[str],
|
steps: list[str],
|
||||||
input_data: dict[str, Any] | None = None,
|
input_data: dict[str, object] | None = None,
|
||||||
tenant_id: str | None = None,
|
tenant_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
execution_id = str(uuid.uuid4())
|
execution_id = str(uuid.uuid4())
|
||||||
|
|
@ -67,7 +67,7 @@ class PipelineStateMemory:
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
step_name: str,
|
step_name: str,
|
||||||
status: str,
|
status: str,
|
||||||
output: dict[str, Any] | None = None,
|
output: dict[str, object] | None = None,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
duration_ms: int | None = None,
|
duration_ms: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -88,7 +88,7 @@ class PipelineStateMemory:
|
||||||
exec_state["error_message"] = error
|
exec_state["error_message"] = error
|
||||||
|
|
||||||
# Record step history event
|
# Record step history event
|
||||||
step_event: dict[str, Any] = {
|
step_event: dict[str, object] = {
|
||||||
"id": str(uuid.uuid4()),
|
"id": str(uuid.uuid4()),
|
||||||
"execution_id": execution_id,
|
"execution_id": execution_id,
|
||||||
"step_name": step_name,
|
"step_name": step_name,
|
||||||
|
|
@ -97,14 +97,16 @@ class PipelineStateMemory:
|
||||||
"error_message": error,
|
"error_message": error,
|
||||||
"duration_ms": duration_ms,
|
"duration_ms": duration_ms,
|
||||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||||
"completed_at": datetime.now(timezone.utc).isoformat() if status in ("completed", "failed") else None,
|
"completed_at": datetime.now(timezone.utc).isoformat()
|
||||||
|
if status in ("completed", "failed")
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
self._step_history[execution_id].append(step_event)
|
self._step_history[execution_id].append(step_event)
|
||||||
|
|
||||||
async def complete_execution(
|
async def complete_execution(
|
||||||
self,
|
self,
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
final_output: dict[str, Any] | None = None,
|
final_output: dict[str, object] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
exec_state = self._executions.get(execution_id)
|
exec_state = self._executions.get(execution_id)
|
||||||
if exec_state is None:
|
if exec_state is None:
|
||||||
|
|
@ -130,7 +132,7 @@ class PipelineStateMemory:
|
||||||
exec_state["updated_at"] = now
|
exec_state["updated_at"] = now
|
||||||
exec_state["completed_at"] = now
|
exec_state["completed_at"] = now
|
||||||
|
|
||||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
|
||||||
return self._executions.get(execution_id)
|
return self._executions.get(execution_id)
|
||||||
|
|
||||||
async def list_executions(
|
async def list_executions(
|
||||||
|
|
@ -138,17 +140,17 @@ class PipelineStateMemory:
|
||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, object]]:
|
||||||
results = list(self._executions.values())
|
results = list(self._executions.values())
|
||||||
if status:
|
if status:
|
||||||
results = [e for e in results if e.get("status") == status]
|
results = [e for e in results if e.get("status") == status]
|
||||||
results.sort(key=lambda e: e.get("created_at", ""), reverse=True)
|
results.sort(key=lambda e: e.get("created_at", ""), reverse=True)
|
||||||
return results[offset : offset + limit]
|
return results[offset : offset + limit]
|
||||||
|
|
||||||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
|
||||||
return self._step_history.get(execution_id, [])
|
return self._step_history.get(execution_id, [])
|
||||||
|
|
||||||
def get_execution_sync(self, execution_id: str) -> dict[str, Any] | None:
|
def get_execution_sync(self, execution_id: str) -> dict[str, object] | None:
|
||||||
"""Synchronous accessor for execution state (used by Redis dual-write)."""
|
"""Synchronous accessor for execution state (used by Redis dual-write)."""
|
||||||
return self._executions.get(execution_id)
|
return self._executions.get(execution_id)
|
||||||
|
|
||||||
|
|
@ -165,7 +167,7 @@ class PipelineStateRedis:
|
||||||
|
|
||||||
def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None:
|
def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None:
|
||||||
self._redis_url = redis_url
|
self._redis_url = redis_url
|
||||||
self._redis: Any = None
|
self._redis: object | None = None
|
||||||
self._fallback = PipelineStateMemory()
|
self._fallback = PipelineStateMemory()
|
||||||
self._use_fallback = False
|
self._use_fallback = False
|
||||||
self._fallback_since: float | None = None
|
self._fallback_since: float | None = None
|
||||||
|
|
@ -181,8 +183,8 @@ class PipelineStateRedis:
|
||||||
return self._redis
|
return self._redis
|
||||||
|
|
||||||
async def _safe_redis_call(
|
async def _safe_redis_call(
|
||||||
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
|
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object
|
||||||
) -> Any:
|
) -> object | None:
|
||||||
"""Execute a Redis call, falling back to memory on failure.
|
"""Execute a Redis call, falling back to memory on failure.
|
||||||
|
|
||||||
After falling back, periodically retries Redis to enable recovery.
|
After falling back, periodically retries Redis to enable recovery.
|
||||||
|
|
@ -192,6 +194,7 @@ class PipelineStateRedis:
|
||||||
# Check if enough time has passed to attempt recovery
|
# Check if enough time has passed to attempt recovery
|
||||||
if self._fallback_since is not None:
|
if self._fallback_since is not None:
|
||||||
import time as _time
|
import time as _time
|
||||||
|
|
||||||
elapsed = _time.monotonic() - self._fallback_since
|
elapsed = _time.monotonic() - self._fallback_since
|
||||||
if elapsed >= self._RECOVERY_COOLDOWN_SECONDS:
|
if elapsed >= self._RECOVERY_COOLDOWN_SECONDS:
|
||||||
try:
|
try:
|
||||||
|
|
@ -218,6 +221,7 @@ class PipelineStateRedis:
|
||||||
logger.warning(f"Redis operation failed, switching to memory fallback: {exc}")
|
logger.warning(f"Redis operation failed, switching to memory fallback: {exc}")
|
||||||
self._use_fallback = True
|
self._use_fallback = True
|
||||||
import time as _time
|
import time as _time
|
||||||
|
|
||||||
self._fallback_since = _time.monotonic()
|
self._fallback_since = _time.monotonic()
|
||||||
self._redis = None
|
self._redis = None
|
||||||
return None
|
return None
|
||||||
|
|
@ -229,7 +233,7 @@ class PipelineStateRedis:
|
||||||
self,
|
self,
|
||||||
pipeline_name: str,
|
pipeline_name: str,
|
||||||
steps: list[str],
|
steps: list[str],
|
||||||
input_data: dict[str, Any] | None = None,
|
input_data: dict[str, object] | None = None,
|
||||||
tenant_id: str | None = None,
|
tenant_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Always write to fallback first for consistency
|
# Always write to fallback first for consistency
|
||||||
|
|
@ -238,7 +242,7 @@ class PipelineStateRedis:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try Redis
|
# Try Redis
|
||||||
async def _redis_create(redis: Any) -> None:
|
async def _redis_create(redis: object) -> None:
|
||||||
state = self._fallback.get_execution_sync(execution_id)
|
state = self._fallback.get_execution_sync(execution_id)
|
||||||
score = datetime.now(timezone.utc).timestamp()
|
score = datetime.now(timezone.utc).timestamp()
|
||||||
pipe = redis.pipeline()
|
pipe = redis.pipeline()
|
||||||
|
|
@ -254,13 +258,15 @@ class PipelineStateRedis:
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
step_name: str,
|
step_name: str,
|
||||||
status: str,
|
status: str,
|
||||||
output: dict[str, Any] | None = None,
|
output: dict[str, object] | None = None,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
duration_ms: int | None = None,
|
duration_ms: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self._fallback.update_step(execution_id, step_name, status, output, error, duration_ms)
|
await self._fallback.update_step(
|
||||||
|
execution_id, step_name, status, output, error, duration_ms
|
||||||
|
)
|
||||||
|
|
||||||
async def _redis_update(redis: Any) -> None:
|
async def _redis_update(redis: object) -> None:
|
||||||
state = self._fallback.get_execution_sync(execution_id)
|
state = self._fallback.get_execution_sync(execution_id)
|
||||||
if state is None:
|
if state is None:
|
||||||
return
|
return
|
||||||
|
|
@ -271,11 +277,11 @@ class PipelineStateRedis:
|
||||||
async def complete_execution(
|
async def complete_execution(
|
||||||
self,
|
self,
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
final_output: dict[str, Any] | None = None,
|
final_output: dict[str, object] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self._fallback.complete_execution(execution_id, final_output)
|
await self._fallback.complete_execution(execution_id, final_output)
|
||||||
|
|
||||||
async def _redis_complete(redis: Any) -> None:
|
async def _redis_complete(redis: object) -> None:
|
||||||
state = self._fallback.get_execution_sync(execution_id)
|
state = self._fallback.get_execution_sync(execution_id)
|
||||||
if state is None:
|
if state is None:
|
||||||
return
|
return
|
||||||
|
|
@ -291,7 +297,7 @@ class PipelineStateRedis:
|
||||||
) -> None:
|
) -> None:
|
||||||
await self._fallback.fail_execution(execution_id, step_name, error)
|
await self._fallback.fail_execution(execution_id, step_name, error)
|
||||||
|
|
||||||
async def _redis_fail(redis: Any) -> None:
|
async def _redis_fail(redis: object) -> None:
|
||||||
state = self._fallback.get_execution_sync(execution_id)
|
state = self._fallback.get_execution_sync(execution_id)
|
||||||
if state is None:
|
if state is None:
|
||||||
return
|
return
|
||||||
|
|
@ -299,7 +305,7 @@ class PipelineStateRedis:
|
||||||
|
|
||||||
await self._safe_redis_call(_redis_fail)
|
await self._safe_redis_call(_redis_fail)
|
||||||
|
|
||||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
|
||||||
# Try Redis first
|
# Try Redis first
|
||||||
if not self._use_fallback:
|
if not self._use_fallback:
|
||||||
try:
|
try:
|
||||||
|
|
@ -318,7 +324,7 @@ class PipelineStateRedis:
|
||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, object]]:
|
||||||
# Try Redis sorted set for efficient listing
|
# Try Redis sorted set for efficient listing
|
||||||
if not self._use_fallback:
|
if not self._use_fallback:
|
||||||
try:
|
try:
|
||||||
|
|
@ -341,7 +347,7 @@ class PipelineStateRedis:
|
||||||
|
|
||||||
return await self._fallback.list_executions(status, limit, offset)
|
return await self._fallback.list_executions(status, limit, offset)
|
||||||
|
|
||||||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
|
||||||
return await self._fallback.get_step_history(execution_id)
|
return await self._fallback.get_step_history(execution_id)
|
||||||
|
|
||||||
async def health_check(self) -> bool:
|
async def health_check(self) -> bool:
|
||||||
|
|
@ -364,20 +370,18 @@ class PipelineStatePG:
|
||||||
If session_factory is None, all methods are no-op.
|
If session_factory is None, all methods are no-op.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, session_factory: Any = None) -> None:
|
def __init__(self, session_factory: object | None = None) -> None:
|
||||||
self._session_factory = session_factory
|
self._session_factory = session_factory
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def enabled(self) -> bool:
|
def enabled(self) -> bool:
|
||||||
return self._session_factory is not None
|
return self._session_factory is not None
|
||||||
|
|
||||||
async def persist_execution(self, state: dict[str, Any]) -> None:
|
async def persist_execution(self, state: dict[str, object]) -> None:
|
||||||
"""Write a completed/failed execution to PostgreSQL."""
|
"""Write a completed/failed execution to PostgreSQL."""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
async with self._session_factory() as session:
|
async with self._session_factory() as session:
|
||||||
model = PipelineExecutionModel(
|
model = PipelineExecutionModel(
|
||||||
id=state["id"],
|
id=state["id"],
|
||||||
|
|
@ -390,18 +394,22 @@ class PipelineStatePG:
|
||||||
final_output=state.get("final_output"),
|
final_output=state.get("final_output"),
|
||||||
error_message=state.get("error_message"),
|
error_message=state.get("error_message"),
|
||||||
tenant_id=state.get("tenant_id"),
|
tenant_id=state.get("tenant_id"),
|
||||||
created_at=datetime.fromisoformat(state["created_at"]) if state.get("created_at") else None,
|
created_at=datetime.fromisoformat(state["created_at"])
|
||||||
updated_at=datetime.fromisoformat(state["updated_at"]) if state.get("updated_at") else None,
|
if state.get("created_at")
|
||||||
completed_at=datetime.fromisoformat(state["completed_at"]) if state.get("completed_at") else None,
|
else None,
|
||||||
|
updated_at=datetime.fromisoformat(state["updated_at"])
|
||||||
|
if state.get("updated_at")
|
||||||
|
else None,
|
||||||
|
completed_at=datetime.fromisoformat(state["completed_at"])
|
||||||
|
if state.get("completed_at")
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
await session.merge(model)
|
await session.merge(model)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Failed to persist execution to PG: {exc}")
|
logger.error(f"Failed to persist execution to PG: {exc}")
|
||||||
|
|
||||||
async def persist_step_history(
|
async def persist_step_history(self, execution_id: str, steps: list[dict[str, object]]) -> None:
|
||||||
self, execution_id: str, steps: list[dict[str, Any]]
|
|
||||||
) -> None:
|
|
||||||
"""Write step history to PostgreSQL."""
|
"""Write step history to PostgreSQL."""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return
|
return
|
||||||
|
|
@ -419,8 +427,12 @@ class PipelineStatePG:
|
||||||
error_message=step.get("error_message"),
|
error_message=step.get("error_message"),
|
||||||
duration_ms=step.get("duration_ms"),
|
duration_ms=step.get("duration_ms"),
|
||||||
retry_attempt=step.get("retry_attempt", 0),
|
retry_attempt=step.get("retry_attempt", 0),
|
||||||
started_at=datetime.fromisoformat(step["started_at"]) if step.get("started_at") else None,
|
started_at=datetime.fromisoformat(step["started_at"])
|
||||||
completed_at=datetime.fromisoformat(step["completed_at"]) if step.get("completed_at") else None,
|
if step.get("started_at")
|
||||||
|
else None,
|
||||||
|
completed_at=datetime.fromisoformat(step["completed_at"])
|
||||||
|
if step.get("completed_at")
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
await session.merge(model)
|
await session.merge(model)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
@ -433,7 +445,7 @@ class PipelineStatePG:
|
||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, object]]:
|
||||||
"""Query historical executions from PostgreSQL."""
|
"""Query historical executions from PostgreSQL."""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return []
|
return []
|
||||||
|
|
@ -445,9 +457,7 @@ class PipelineStatePG:
|
||||||
PipelineExecutionModel.created_at.desc()
|
PipelineExecutionModel.created_at.desc()
|
||||||
)
|
)
|
||||||
if pipeline_name:
|
if pipeline_name:
|
||||||
stmt = stmt.where(
|
stmt = stmt.where(PipelineExecutionModel.pipeline_name == pipeline_name)
|
||||||
PipelineExecutionModel.pipeline_name == pipeline_name
|
|
||||||
)
|
|
||||||
if status:
|
if status:
|
||||||
stmt = stmt.where(PipelineExecutionModel.status == status)
|
stmt = stmt.where(PipelineExecutionModel.status == status)
|
||||||
stmt = stmt.offset(offset).limit(limit)
|
stmt = stmt.offset(offset).limit(limit)
|
||||||
|
|
@ -458,7 +468,7 @@ class PipelineStatePG:
|
||||||
logger.error(f"Failed to query executions from PG: {exc}")
|
logger.error(f"Failed to query executions from PG: {exc}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
|
||||||
"""Get a single execution from PostgreSQL (for Redis miss fallback)."""
|
"""Get a single execution from PostgreSQL (for Redis miss fallback)."""
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
return None
|
return None
|
||||||
|
|
@ -479,7 +489,7 @@ class PipelineStatePG:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _model_to_dict(model: PipelineExecutionModel) -> dict[str, Any]:
|
def _model_to_dict(model: PipelineExecutionModel) -> dict[str, object]:
|
||||||
return {
|
return {
|
||||||
"id": model.id,
|
"id": model.id,
|
||||||
"pipeline_name": model.pipeline_name,
|
"pipeline_name": model.pipeline_name,
|
||||||
|
|
@ -509,7 +519,7 @@ class PipelineStateManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
redis_url: str | None = None,
|
redis_url: str | None = None,
|
||||||
session_factory: Any = None,
|
session_factory: object | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if redis_url:
|
if redis_url:
|
||||||
self._hot = PipelineStateRedis(redis_url=redis_url)
|
self._hot = PipelineStateRedis(redis_url=redis_url)
|
||||||
|
|
@ -529,7 +539,7 @@ class PipelineStateManager:
|
||||||
self,
|
self,
|
||||||
pipeline_name: str,
|
pipeline_name: str,
|
||||||
steps: list[str],
|
steps: list[str],
|
||||||
input_data: dict[str, Any] | None = None,
|
input_data: dict[str, object] | None = None,
|
||||||
tenant_id: str | None = None,
|
tenant_id: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
return await self._hot.create_execution(pipeline_name, steps, input_data, tenant_id)
|
return await self._hot.create_execution(pipeline_name, steps, input_data, tenant_id)
|
||||||
|
|
@ -539,7 +549,7 @@ class PipelineStateManager:
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
step_name: str,
|
step_name: str,
|
||||||
status: str,
|
status: str,
|
||||||
output: dict[str, Any] | None = None,
|
output: dict[str, object] | None = None,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
duration_ms: int | None = None,
|
duration_ms: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -548,7 +558,7 @@ class PipelineStateManager:
|
||||||
async def complete_execution(
|
async def complete_execution(
|
||||||
self,
|
self,
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
final_output: dict[str, Any] | None = None,
|
final_output: dict[str, object] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self._hot.complete_execution(execution_id, final_output)
|
await self._hot.complete_execution(execution_id, final_output)
|
||||||
# Persist to PG
|
# Persist to PG
|
||||||
|
|
@ -574,7 +584,7 @@ class PipelineStateManager:
|
||||||
if step_history:
|
if step_history:
|
||||||
await self._cold.persist_step_history(execution_id, step_history)
|
await self._cold.persist_step_history(execution_id, step_history)
|
||||||
|
|
||||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
|
||||||
# Redis / memory first
|
# Redis / memory first
|
||||||
state = await self._hot.get_execution(execution_id)
|
state = await self._hot.get_execution(execution_id)
|
||||||
if state is not None:
|
if state is not None:
|
||||||
|
|
@ -587,7 +597,7 @@ class PipelineStateManager:
|
||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, object]]:
|
||||||
# Hot store for recent executions
|
# Hot store for recent executions
|
||||||
results = await self._hot.list_executions(status, limit, offset)
|
results = await self._hot.list_executions(status, limit, offset)
|
||||||
if results:
|
if results:
|
||||||
|
|
@ -595,7 +605,7 @@ class PipelineStateManager:
|
||||||
# Cold store for historical queries
|
# Cold store for historical queries
|
||||||
return await self._cold.query_executions(status=status, limit=limit, offset=offset)
|
return await self._cold.query_executions(status=status, limit=limit, offset=offset)
|
||||||
|
|
||||||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
|
||||||
return await self._hot.get_step_history(execution_id)
|
return await self._hot.get_step_history(execution_id)
|
||||||
|
|
||||||
async def health_check(self) -> dict[str, bool]:
|
async def health_check(self) -> dict[str, bool]:
|
||||||
|
|
|
||||||
|
|
@ -15,10 +15,14 @@ import time
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import TypeAlias
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Scalar session state values (text/number/flag/none). Screen-state dicts use
|
||||||
|
# dict[str, object] because they also hold tuple cursor positions.
|
||||||
|
SessionState: TypeAlias = dict[str, str | int | bool | None]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScreenInfo:
|
class ScreenInfo:
|
||||||
|
|
@ -37,7 +41,7 @@ class ActionResult:
|
||||||
output: str = ""
|
output: str = ""
|
||||||
screenshot_base64: str = ""
|
screenshot_base64: str = ""
|
||||||
error: str = ""
|
error: str = ""
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, object] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ComputerUseSession(ABC):
|
class ComputerUseSession(ABC):
|
||||||
|
|
@ -56,7 +60,7 @@ class ComputerUseSession(ABC):
|
||||||
self.session_id = session_id or str(uuid.uuid4())
|
self.session_id = session_id or str(uuid.uuid4())
|
||||||
self.screen = ScreenInfo(width=screen_width, height=screen_height)
|
self.screen = ScreenInfo(width=screen_width, height=screen_height)
|
||||||
self._started = False
|
self._started = False
|
||||||
self._action_history: list[dict[str, Any]] = []
|
self._action_history: list[dict[str, object]] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_started(self) -> bool:
|
def is_started(self) -> bool:
|
||||||
|
|
@ -82,7 +86,7 @@ class ComputerUseSession(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute_action(self, action: str, **params: Any) -> ActionResult:
|
async def execute_action(self, action: str, **params: object) -> ActionResult:
|
||||||
"""执行 UI 操作
|
"""执行 UI 操作
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -94,18 +98,20 @@ class ComputerUseSession(ABC):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def record_action(self, action: str, params: dict[str, Any], result: ActionResult) -> None:
|
def record_action(self, action: str, params: dict[str, object], result: ActionResult) -> None:
|
||||||
"""记录操作历史"""
|
"""记录操作历史"""
|
||||||
self._action_history.append({
|
self._action_history.append(
|
||||||
"timestamp": time.time(),
|
{
|
||||||
"action": action,
|
"timestamp": time.time(),
|
||||||
"params": params,
|
"action": action,
|
||||||
"success": result.success,
|
"params": params,
|
||||||
"output": result.output[:200] if result.output else "",
|
"success": result.success,
|
||||||
})
|
"output": result.output[:200] if result.output else "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_history(self) -> list[dict[str, Any]]:
|
def action_history(self) -> list[dict[str, object]]:
|
||||||
"""获取操作历史(副本)"""
|
"""获取操作历史(副本)"""
|
||||||
return list(self._action_history)
|
return list(self._action_history)
|
||||||
|
|
||||||
|
|
@ -134,7 +140,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
|
||||||
screen_width=screen_width,
|
screen_width=screen_width,
|
||||||
screen_height=screen_height,
|
screen_height=screen_height,
|
||||||
)
|
)
|
||||||
self._screen_state: dict[str, Any] = {
|
self._screen_state: dict[str, object] = {
|
||||||
"focused_element": None,
|
"focused_element": None,
|
||||||
"cursor_position": (0, 0),
|
"cursor_position": (0, 0),
|
||||||
"typed_text": "",
|
"typed_text": "",
|
||||||
|
|
@ -173,7 +179,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
|
||||||
metadata={"screen_state": dict(self._screen_state)},
|
metadata={"screen_state": dict(self._screen_state)},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def execute_action(self, action: str, **params: Any) -> ActionResult:
|
async def execute_action(self, action: str, **params: object) -> ActionResult:
|
||||||
"""模拟执行 UI 操作"""
|
"""模拟执行 UI 操作"""
|
||||||
if not self._started:
|
if not self._started:
|
||||||
return ActionResult(
|
return ActionResult(
|
||||||
|
|
@ -186,7 +192,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
|
||||||
self.record_action(action, params, result)
|
self.record_action(action, params, result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _simulate_action(self, action: str, **params: Any) -> ActionResult:
|
def _simulate_action(self, action: str, **params: object) -> ActionResult:
|
||||||
"""模拟具体操作"""
|
"""模拟具体操作"""
|
||||||
if action == "click":
|
if action == "click":
|
||||||
x = params.get("x", 0)
|
x = params.get("x", 0)
|
||||||
|
|
@ -270,18 +276,78 @@ class LocalComputerUseSession(ComputerUseSession):
|
||||||
screen_width=screen_width,
|
screen_width=screen_width,
|
||||||
screen_height=screen_height,
|
screen_height=screen_height,
|
||||||
)
|
)
|
||||||
self._pyautogui: Any = None
|
self._pyautogui: object | None = None
|
||||||
|
|
||||||
# Allowed keys for the `key` action — prevents OS-level shortcut abuse
|
# Allowed keys for the `key` action — prevents OS-level shortcut abuse
|
||||||
_ALLOWED_KEYS: set[str] = {
|
_ALLOWED_KEYS: set[str] = {
|
||||||
"enter", "return", "tab", "backspace", "delete", "home", "end",
|
"enter",
|
||||||
"up", "down", "left", "right", "pageup", "pagedown",
|
"return",
|
||||||
"space", "escape", "insert",
|
"tab",
|
||||||
"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "f10", "f11", "f12",
|
"backspace",
|
||||||
"shift", "ctrl", "alt", "command",
|
"delete",
|
||||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
|
"home",
|
||||||
"n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
|
"end",
|
||||||
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
"up",
|
||||||
|
"down",
|
||||||
|
"left",
|
||||||
|
"right",
|
||||||
|
"pageup",
|
||||||
|
"pagedown",
|
||||||
|
"space",
|
||||||
|
"escape",
|
||||||
|
"insert",
|
||||||
|
"f1",
|
||||||
|
"f2",
|
||||||
|
"f3",
|
||||||
|
"f4",
|
||||||
|
"f5",
|
||||||
|
"f6",
|
||||||
|
"f7",
|
||||||
|
"f8",
|
||||||
|
"f9",
|
||||||
|
"f10",
|
||||||
|
"f11",
|
||||||
|
"f12",
|
||||||
|
"shift",
|
||||||
|
"ctrl",
|
||||||
|
"alt",
|
||||||
|
"command",
|
||||||
|
"a",
|
||||||
|
"b",
|
||||||
|
"c",
|
||||||
|
"d",
|
||||||
|
"e",
|
||||||
|
"f",
|
||||||
|
"g",
|
||||||
|
"h",
|
||||||
|
"i",
|
||||||
|
"j",
|
||||||
|
"k",
|
||||||
|
"l",
|
||||||
|
"m",
|
||||||
|
"n",
|
||||||
|
"o",
|
||||||
|
"p",
|
||||||
|
"q",
|
||||||
|
"r",
|
||||||
|
"s",
|
||||||
|
"t",
|
||||||
|
"u",
|
||||||
|
"v",
|
||||||
|
"w",
|
||||||
|
"x",
|
||||||
|
"y",
|
||||||
|
"z",
|
||||||
|
"0",
|
||||||
|
"1",
|
||||||
|
"2",
|
||||||
|
"3",
|
||||||
|
"4",
|
||||||
|
"5",
|
||||||
|
"6",
|
||||||
|
"7",
|
||||||
|
"8",
|
||||||
|
"9",
|
||||||
}
|
}
|
||||||
_ALLOWED_BUTTONS: set[str] = {"left", "right", "middle"}
|
_ALLOWED_BUTTONS: set[str] = {"left", "right", "middle"}
|
||||||
_MAX_TEXT_LENGTH: int = 1000
|
_MAX_TEXT_LENGTH: int = 1000
|
||||||
|
|
@ -291,6 +357,7 @@ class LocalComputerUseSession(ComputerUseSession):
|
||||||
"""启动本地桌面会话"""
|
"""启动本地桌面会话"""
|
||||||
try:
|
try:
|
||||||
import pyautogui
|
import pyautogui
|
||||||
|
|
||||||
self._pyautogui = pyautogui
|
self._pyautogui = pyautogui
|
||||||
pyautogui.FAILSAFE = True
|
pyautogui.FAILSAFE = True
|
||||||
pyautogui.PAUSE = 0.1
|
pyautogui.PAUSE = 0.1
|
||||||
|
|
@ -327,7 +394,7 @@ class LocalComputerUseSession(ComputerUseSession):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return ActionResult(success=False, action="screenshot", error=str(e))
|
return ActionResult(success=False, action="screenshot", error=str(e))
|
||||||
|
|
||||||
async def execute_action(self, action: str, **params: Any) -> ActionResult:
|
async def execute_action(self, action: str, **params: object) -> ActionResult:
|
||||||
"""在本地桌面执行 UI 操作"""
|
"""在本地桌面执行 UI 操作"""
|
||||||
if not self._started:
|
if not self._started:
|
||||||
return ActionResult(success=False, action=action, error="Session not started")
|
return ActionResult(success=False, action=action, error="Session not started")
|
||||||
|
|
@ -343,7 +410,7 @@ class LocalComputerUseSession(ComputerUseSession):
|
||||||
"""Check if coordinates are within screen bounds."""
|
"""Check if coordinates are within screen bounds."""
|
||||||
return 0 <= x <= self.screen_width and 0 <= y <= self.screen_height
|
return 0 <= x <= self.screen_width and 0 <= y <= self.screen_height
|
||||||
|
|
||||||
async def _execute_local_action(self, action: str, **params: Any) -> ActionResult:
|
async def _execute_local_action(self, action: str, **params: object) -> ActionResult:
|
||||||
"""Execute a local UI action with input validation."""
|
"""Execute a local UI action with input validation."""
|
||||||
pg = self._pyautogui
|
pg = self._pyautogui
|
||||||
|
|
||||||
|
|
@ -351,16 +418,24 @@ class LocalComputerUseSession(ComputerUseSession):
|
||||||
x, y = params.get("x", 0), params.get("y", 0)
|
x, y = params.get("x", 0), params.get("y", 0)
|
||||||
button = params.get("button", "left")
|
button = params.get("button", "left")
|
||||||
if button not in self._ALLOWED_BUTTONS:
|
if button not in self._ALLOWED_BUTTONS:
|
||||||
return ActionResult(success=False, action="click", error=f"Invalid button: {button}")
|
return ActionResult(
|
||||||
|
success=False, action="click", error=f"Invalid button: {button}"
|
||||||
|
)
|
||||||
if not self._validate_coordinates(x, y):
|
if not self._validate_coordinates(x, y):
|
||||||
return ActionResult(success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})")
|
return ActionResult(
|
||||||
|
success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})"
|
||||||
|
)
|
||||||
pg.click(x, y, button=button)
|
pg.click(x, y, button=button)
|
||||||
return ActionResult(success=True, action="click", output=f"Clicked at ({x}, {y})")
|
return ActionResult(success=True, action="click", output=f"Clicked at ({x}, {y})")
|
||||||
|
|
||||||
if action == "type":
|
if action == "type":
|
||||||
text = params.get("text", "")
|
text = params.get("text", "")
|
||||||
if len(text) > self._MAX_TEXT_LENGTH:
|
if len(text) > self._MAX_TEXT_LENGTH:
|
||||||
return ActionResult(success=False, action="type", error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}")
|
return ActionResult(
|
||||||
|
success=False,
|
||||||
|
action="type",
|
||||||
|
error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}",
|
||||||
|
)
|
||||||
pg.write(text)
|
pg.write(text)
|
||||||
return ActionResult(success=True, action="type", output=f"Typed: {text[:50]}")
|
return ActionResult(success=True, action="type", output=f"Typed: {text[:50]}")
|
||||||
|
|
||||||
|
|
@ -369,16 +444,22 @@ class LocalComputerUseSession(ComputerUseSession):
|
||||||
amount = params.get("amount", 3)
|
amount = params.get("amount", 3)
|
||||||
clicks = amount if direction == "down" else -amount
|
clicks = amount if direction == "down" else -amount
|
||||||
pg.scroll(clicks)
|
pg.scroll(clicks)
|
||||||
return ActionResult(success=True, action="scroll", output=f"Scrolled {direction} by {amount}")
|
return ActionResult(
|
||||||
|
success=True, action="scroll", output=f"Scrolled {direction} by {amount}"
|
||||||
|
)
|
||||||
|
|
||||||
if action == "drag":
|
if action == "drag":
|
||||||
sx, sy = params.get("start_x", 0), params.get("start_y", 0)
|
sx, sy = params.get("start_x", 0), params.get("start_y", 0)
|
||||||
ex, ey = params.get("end_x", 0), params.get("end_y", 0)
|
ex, ey = params.get("end_x", 0), params.get("end_y", 0)
|
||||||
if not (self._validate_coordinates(sx, sy) and self._validate_coordinates(ex, ey)):
|
if not (self._validate_coordinates(sx, sy) and self._validate_coordinates(ex, ey)):
|
||||||
return ActionResult(success=False, action="drag", error="Drag coordinates out of bounds")
|
return ActionResult(
|
||||||
|
success=False, action="drag", error="Drag coordinates out of bounds"
|
||||||
|
)
|
||||||
pg.moveTo(sx, sy)
|
pg.moveTo(sx, sy)
|
||||||
pg.dragTo(ex, ey, duration=0.5)
|
pg.dragTo(ex, ey, duration=0.5)
|
||||||
return ActionResult(success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})")
|
return ActionResult(
|
||||||
|
success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})"
|
||||||
|
)
|
||||||
|
|
||||||
if action == "key":
|
if action == "key":
|
||||||
key_name = params.get("key_name", "")
|
key_name = params.get("key_name", "")
|
||||||
|
|
@ -487,7 +568,7 @@ class DockerComputerUseSession(ComputerUseSession):
|
||||||
screenshot_base64="",
|
screenshot_base64="",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def execute_action(self, action: str, **params: Any) -> ActionResult:
|
async def execute_action(self, action: str, **params: object) -> ActionResult:
|
||||||
"""在 Docker 虚拟桌面执行操作
|
"""在 Docker 虚拟桌面执行操作
|
||||||
|
|
||||||
Stub: 实际实现需要通过 Anthropic Computer Use API。
|
Stub: 实际实现需要通过 Anthropic Computer Use API。
|
||||||
|
|
@ -527,7 +608,7 @@ class ComputerUseSessionManager:
|
||||||
def get_or_create(
|
def get_or_create(
|
||||||
self,
|
self,
|
||||||
session_id: str | None = None,
|
session_id: str | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: object,
|
||||||
) -> ComputerUseSession:
|
) -> ComputerUseSession:
|
||||||
"""获取或创建会话"""
|
"""获取或创建会话"""
|
||||||
if session_id and session_id in self._sessions:
|
if session_id and session_id in self._sessions:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue