refactor(bitable,tools): replace Any with concrete types + Protocol (U4)

BitableRecord/FormulaResult/SessionState TypeAlias replace dict[str, Any]; _redis/_engine/_session_factory typed as object | None with TYPE_CHECKING Protocol (_RedisLike, _RecalcWorker); Coroutine[Any, Any, Any] retained as legitimate type param.

Baseline 40 : Any occurrences -> 0 across 6 in-scope files (target <=5). Deferred: repository.py/recalc_worker.py/ingestion/* (10 occurrences, separate PR).

ruff clean; 367 passed + 116 skipped (bitable + pipeline_state + tools).
This commit is contained in:
chiguyong 2026-06-30 22:32:30 +08:00
parent be5c4e09f8
commit 1033346913
6 changed files with 256 additions and 133 deletions

View File

@ -18,7 +18,7 @@ import logging
import os
import uuid as _uuid
from datetime import datetime, timezone
from typing import Any
from types import TracebackType
from sqlalchemy import (
Column,
@ -191,7 +191,7 @@ class MetaModel(BitableBase):
# ---------------------------------------------------------------------------
async def _apply_v2_migration(conn: Any) -> None:
async def _apply_v2_migration(conn: object) -> None:
"""V2 migration: create ``bitable_files`` table + add ``file_id`` to tables.
Idempotent safe to call on fresh installs (``create_all`` already made
@ -265,8 +265,8 @@ class BitableDB:
def __init__(self, database_url: str | None = None) -> None:
self._database_url = database_url or _resolve_database_url()
self._engine: Any = None
self._session_factory: Any = None
self._engine: object | None = None
self._session_factory: object | None = None
self._initialized = False
self._init_lock = asyncio.Lock()
@ -275,11 +275,11 @@ class BitableDB:
return self._database_url
@property
def engine(self) -> Any:
def engine(self) -> object | None:
return self._engine
@property
def session_factory(self) -> Any:
def session_factory(self) -> object | None:
return self._session_factory
@property
@ -365,7 +365,12 @@ class BitableDB:
await self.init()
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()

View File

@ -12,12 +12,17 @@ based on the calling context — see :mod:`agentkit.bitable.formula.engine`.
from __future__ import annotations
from typing import Any, Callable
from typing import Callable, TypeAlias
# A formula evaluates to a scalar primitive: text, number, or nothing.
# bool is intentionally excluded — comparisons live in the parser layer
# and never reach the function registry.
FormulaResult: TypeAlias = str | int | float | None
# ── Aggregate functions (operate on lists) ────────────────
def _sum(values: list[Any]) -> float | int:
def _sum(values: list[FormulaResult]) -> float | int:
"""Sum of numeric values, ignoring None/empty."""
total = 0
for v in values:
@ -27,7 +32,7 @@ def _sum(values: list[Any]) -> float | int:
return total
def _avg(values: list[Any]) -> float:
def _avg(values: list[FormulaResult]) -> float:
"""Average of numeric values, ignoring None/empty."""
nums = [v for v in values if v is not None and v != ""]
if not nums:
@ -35,12 +40,12 @@ def _avg(values: list[Any]) -> float:
return sum(nums) / len(nums)
def _count(values: list[Any]) -> int:
def _count(values: list[FormulaResult]) -> int:
"""Count of non-empty values."""
return sum(1 for v in values if v is not None and v != "")
def _min(values: list[Any]) -> Any:
def _min(values: list[FormulaResult]) -> FormulaResult:
"""Minimum of numeric values, ignoring None/empty."""
nums = [v for v in values if v is not None and v != ""]
if not nums:
@ -48,7 +53,7 @@ def _min(values: list[Any]) -> Any:
return min(nums)
def _max(values: list[Any]) -> Any:
def _max(values: list[FormulaResult]) -> FormulaResult:
"""Maximum of numeric values, ignoring None/empty."""
nums = [v for v in values if v is not None and v != ""]
if not nums:
@ -59,25 +64,29 @@ def _max(values: list[Any]) -> Any:
# ── Scalar functions ──────────────────────────────────────
def _abs(value: Any) -> Any:
def _abs(value: FormulaResult) -> FormulaResult:
return abs(value)
def _round(value: Any, digits: int = 0) -> float:
def _round(value: FormulaResult, digits: int = 0) -> float:
return round(value, digits)
def _if(condition: Any, true_val: Any, false_val: Any = None) -> Any:
def _if(
condition: FormulaResult,
true_val: FormulaResult,
false_val: FormulaResult = None,
) -> FormulaResult:
return true_val if condition else false_val
def _len(value: Any) -> int:
def _len(value: FormulaResult) -> int:
if value is None:
return 0
return len(str(value))
def _concat(*args: Any) -> str:
def _concat(*args: FormulaResult) -> str:
"""Concatenate all arguments as strings."""
return "".join(str(a) for a in args if a is not None)
@ -87,7 +96,7 @@ def _concat(*args: Any) -> str:
# Functions that aggregate a column (receive a list of all column values)
AGGREGATE_FUNCTIONS: frozenset[str] = frozenset({"SUM", "AVG", "COUNT", "MIN", "MAX"})
FUNCTION_REGISTRY: dict[str, Callable[..., Any]] = {
FUNCTION_REGISTRY: dict[str, Callable[..., FormulaResult]] = {
"SUM": _sum,
"AVG": _avg,
"COUNT": _count,

View File

@ -25,9 +25,9 @@ from __future__ import annotations
import ast
import re
from typing import Any
from typing import Callable
from agentkit.bitable.formula.functions import FUNCTION_REGISTRY
from agentkit.bitable.formula.functions import FUNCTION_REGISTRY, FormulaResult
# ── Exceptions ────────────────────────────────────────────
@ -184,9 +184,9 @@ def parse_formula(
def evaluate_ast(
tree: ast.Expression,
field_values: dict[str, Any],
functions: dict[str, Any],
) -> Any:
field_values: dict[str, FormulaResult | list[FormulaResult]],
functions: dict[str, Callable[..., FormulaResult]],
) -> FormulaResult:
"""Evaluate a parsed formula AST against field values and functions.
This is NOT ``eval()`` it's a manual AST walker that only processes
@ -204,7 +204,11 @@ def evaluate_ast(
return _eval_node(tree.body, field_values, functions)
def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any]) -> Any:
def _eval_node(
node: ast.AST,
fields: dict[str, FormulaResult | list[FormulaResult]],
functions: dict[str, Callable[..., FormulaResult]],
) -> FormulaResult:
"""Recursively evaluate an AST node."""
if isinstance(node, ast.Constant):
return node.value
@ -274,7 +278,7 @@ def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any])
raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}")
def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any:
def _apply_binop(op: ast.AST, left: FormulaResult, right: FormulaResult) -> FormulaResult:
"""Apply a binary operator."""
if isinstance(op, ast.Add):
# String concat or numeric addition
@ -294,7 +298,7 @@ def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any:
raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}")
def _apply_compare(op: ast.AST, left: Any, right: Any) -> bool:
def _apply_compare(op: ast.AST, left: FormulaResult, right: FormulaResult) -> bool:
"""Apply a comparison operator."""
if isinstance(op, ast.Eq):
return left == right

View File

@ -12,7 +12,7 @@ import logging
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, TypeAlias
from agentkit.bitable.db import BitableDB
from agentkit.bitable.models import (
@ -29,13 +29,27 @@ from agentkit.bitable.models import (
)
from agentkit.bitable.repository import BitableRepository
if TYPE_CHECKING:
from typing import Protocol
class _RecalcWorker(Protocol):
"""Structural type for the recalc worker's cache-invalidation surface."""
def invalidate_engine(self, table_id: str) -> None: ...
logger = logging.getLogger(__name__)
# Record values are JSON scalars (text/number/none). Attachment/image fields
# store lists of dicts at runtime, but the common-case scalar shape is captured
# here for annotation clarity; fall back to dict[str, object] where lists occur.
BitableRecord: TypeAlias = dict[str, str | int | float | None]
class FieldDependencyError(Exception):
"""Raised when deleting a field that has dependencies (formula refs, PK, views)."""
def __init__(self, message: str, dependencies: dict[str, Any]) -> None:
def __init__(self, message: str, dependencies: dict[str, object]) -> None:
super().__init__(message)
self.dependencies = dependencies
@ -52,13 +66,13 @@ class BitableService:
def __init__(self, db: BitableDB) -> None:
self._db = db
self._repo = BitableRepository(db)
self._recalc_worker: Any = None # RecalcWorker, set via set_recalc_worker
self._recalc_worker: _RecalcWorker | None = None # set via set_recalc_worker
@property
def repo(self) -> BitableRepository:
return self._repo
def set_recalc_worker(self, worker: Any) -> None:
def set_recalc_worker(self, worker: _RecalcWorker) -> None:
"""Register the long-lived RecalcWorker so field changes can invalidate its engine cache.
Called after both service and worker are constructed (breaks the
@ -95,7 +109,7 @@ class BitableService:
async def list_files(self, owner_user_id: str | None = None) -> list[BitableFile]:
return await self._repo.list_files(owner_user_id=owner_user_id)
async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None:
async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None:
return await self._repo.update_file(file_id, **kwargs)
async def delete_file(self, file_id: str) -> bool:
@ -162,7 +176,7 @@ class BitableService:
async def list_tables(self, owner_user_id: str | None = None) -> list[Table]:
return await self._repo.list_tables(owner_user_id=owner_user_id)
async def update_table(self, table_id: str, **kwargs: Any) -> Table | None:
async def update_table(self, table_id: str, **kwargs: object) -> Table | None:
"""Update table attrs. Creates PK unique index if primary_key_field_id is set."""
table = await self._repo.update_table(table_id, **kwargs)
if table and kwargs.get("primary_key_field_id"):
@ -179,7 +193,7 @@ class BitableService:
table_id: str,
name: str,
field_type: FieldType,
config: dict[str, Any] | None = None,
config: dict[str, object] | None = None,
owner: FieldOwner = FieldOwner.user,
) -> Field:
"""Create a new field. U2 will add formula validation and DAG updates."""
@ -201,7 +215,7 @@ class BitableService:
async def list_fields(self, table_id: str) -> list[Field]:
return await self._repo.list_fields(table_id)
async def update_field(self, field_id: str, **kwargs: Any) -> Field | None:
async def update_field(self, field_id: str, **kwargs: object) -> Field | None:
"""Update a field. U2 will add dependency checking."""
field = await self._repo.update_field(field_id, **kwargs)
if field is not None:
@ -220,7 +234,7 @@ class BitableService:
return False
# Check dependencies
deps: dict[str, Any] = {}
deps: dict[str, object] = {}
# 1. Is it a primary key field?
table = await self._repo.get_table(field.table_id)
@ -264,7 +278,7 @@ class BitableService:
async def create_record(
self,
table_id: str,
values: dict[str, Any] | None = None,
values: BitableRecord | None = None,
actor_user_id: str | None = None,
) -> Record:
"""Create a new record. Triggers recalc for affected formula fields.
@ -291,7 +305,7 @@ class BitableService:
return record
async def create_records_batch(
self, table_id: str, records_values: list[dict[str, Any]]
self, table_id: str, records_values: list[BitableRecord]
) -> list[Record]:
"""Batch-create records (P2 #19). Triggers recalc for each record.
@ -319,8 +333,8 @@ class BitableService:
async def list_records_filtered(
self,
table_id: str,
filters: list[dict[str, Any]] | None = None,
sorts: list[dict[str, Any]] | None = None,
filters: list[dict[str, object]] | None = None,
sorts: list[dict[str, object]] | None = None,
cursor: str | None = None,
limit: int = 50,
) -> tuple[list[Record], str | None]:
@ -345,7 +359,7 @@ class BitableService:
table_id, filters=filters, sorts=sorts, cursor=cursor, limit=limit
)
async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None:
async def update_record_values(self, record_id: str, values: BitableRecord) -> Record | None:
"""Update a record's values (full replace). Triggers recalc for affected formulas."""
record = await self._repo.update_record_values(record_id, values)
if record is not None:
@ -431,9 +445,9 @@ class BitableService:
async def upsert_records(
self,
table_id: str,
records: list[dict[str, Any]],
records: list[BitableRecord],
primary_key_field_id: str,
) -> dict[str, Any]:
) -> dict[str, int]:
"""Upsert records by primary key using jsonb_set (KTD8).
For each record:
@ -454,12 +468,12 @@ class BitableService:
agent_field_ids = {f.id for f in fields if f.owner == FieldOwner.agent}
# Partition records into insert vs update lists, collecting PK values.
to_insert: list[dict[str, Any]] = []
to_update: list[tuple[dict[str, Any], str]] = [] # (values, existing_record_id)
to_insert: list[BitableRecord] = []
to_update: list[tuple[BitableRecord, str]] = [] # (values, existing_record_id)
skipped = 0
# Collect all non-None PK values for batch lookup.
pk_values_by_str: dict[str, dict[str, Any]] = {}
pk_values_by_str: dict[str, BitableRecord] = {}
for rec_values in records:
pk_value = rec_values.get(primary_key_field_id)
if pk_value is None:
@ -504,7 +518,7 @@ class BitableService:
table_id: str,
name: str,
view_type: ViewType = ViewType.grid,
config: dict[str, Any] | None = None,
config: dict[str, object] | None = None,
) -> View:
return await self._repo.create_view(
table_id=table_id,
@ -516,7 +530,7 @@ class BitableService:
async def list_views(self, table_id: str) -> list[View]:
return await self._repo.list_views(table_id)
async def update_view(self, view_id: str, **kwargs: Any) -> View | None:
async def update_view(self, view_id: str, **kwargs: object) -> View | None:
return await self._repo.update_view(view_id, **kwargs)
async def get_view(self, view_id: str) -> View | None:

View File

@ -32,14 +32,14 @@ class PipelineStateMemory:
"""In-memory pipeline state storage (testing / fallback)."""
def __init__(self) -> None:
self._executions: dict[str, dict[str, Any]] = {}
self._step_history: dict[str, list[dict[str, Any]]] = {}
self._executions: dict[str, dict[str, object]] = {}
self._step_history: dict[str, list[dict[str, object]]] = {}
async def create_execution(
self,
pipeline_name: str,
steps: list[str],
input_data: dict[str, Any] | None = None,
input_data: dict[str, object] | None = None,
tenant_id: str | None = None,
) -> str:
execution_id = str(uuid.uuid4())
@ -67,7 +67,7 @@ class PipelineStateMemory:
execution_id: str,
step_name: str,
status: str,
output: dict[str, Any] | None = None,
output: dict[str, object] | None = None,
error: str | None = None,
duration_ms: int | None = None,
) -> None:
@ -88,7 +88,7 @@ class PipelineStateMemory:
exec_state["error_message"] = error
# Record step history event
step_event: dict[str, Any] = {
step_event: dict[str, object] = {
"id": str(uuid.uuid4()),
"execution_id": execution_id,
"step_name": step_name,
@ -97,14 +97,16 @@ class PipelineStateMemory:
"error_message": error,
"duration_ms": duration_ms,
"started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": datetime.now(timezone.utc).isoformat() if status in ("completed", "failed") else None,
"completed_at": datetime.now(timezone.utc).isoformat()
if status in ("completed", "failed")
else None,
}
self._step_history[execution_id].append(step_event)
async def complete_execution(
self,
execution_id: str,
final_output: dict[str, Any] | None = None,
final_output: dict[str, object] | None = None,
) -> None:
exec_state = self._executions.get(execution_id)
if exec_state is None:
@ -130,7 +132,7 @@ class PipelineStateMemory:
exec_state["updated_at"] = now
exec_state["completed_at"] = now
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
return self._executions.get(execution_id)
async def list_executions(
@ -138,17 +140,17 @@ class PipelineStateMemory:
status: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[dict[str, Any]]:
) -> list[dict[str, object]]:
results = list(self._executions.values())
if status:
results = [e for e in results if e.get("status") == status]
results.sort(key=lambda e: e.get("created_at", ""), reverse=True)
return results[offset : offset + limit]
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
return self._step_history.get(execution_id, [])
def get_execution_sync(self, execution_id: str) -> dict[str, Any] | None:
def get_execution_sync(self, execution_id: str) -> dict[str, object] | None:
"""Synchronous accessor for execution state (used by Redis dual-write)."""
return self._executions.get(execution_id)
@ -165,7 +167,7 @@ class PipelineStateRedis:
def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None:
self._redis_url = redis_url
self._redis: Any = None
self._redis: object | None = None
self._fallback = PipelineStateMemory()
self._use_fallback = False
self._fallback_since: float | None = None
@ -181,8 +183,8 @@ class PipelineStateRedis:
return self._redis
async def _safe_redis_call(
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
) -> Any:
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object
) -> object | None:
"""Execute a Redis call, falling back to memory on failure.
After falling back, periodically retries Redis to enable recovery.
@ -192,6 +194,7 @@ class PipelineStateRedis:
# Check if enough time has passed to attempt recovery
if self._fallback_since is not None:
import time as _time
elapsed = _time.monotonic() - self._fallback_since
if elapsed >= self._RECOVERY_COOLDOWN_SECONDS:
try:
@ -218,6 +221,7 @@ class PipelineStateRedis:
logger.warning(f"Redis operation failed, switching to memory fallback: {exc}")
self._use_fallback = True
import time as _time
self._fallback_since = _time.monotonic()
self._redis = None
return None
@ -229,7 +233,7 @@ class PipelineStateRedis:
self,
pipeline_name: str,
steps: list[str],
input_data: dict[str, Any] | None = None,
input_data: dict[str, object] | None = None,
tenant_id: str | None = None,
) -> str:
# Always write to fallback first for consistency
@ -238,7 +242,7 @@ class PipelineStateRedis:
)
# Try Redis
async def _redis_create(redis: Any) -> None:
async def _redis_create(redis: object) -> None:
state = self._fallback.get_execution_sync(execution_id)
score = datetime.now(timezone.utc).timestamp()
pipe = redis.pipeline()
@ -254,13 +258,15 @@ class PipelineStateRedis:
execution_id: str,
step_name: str,
status: str,
output: dict[str, Any] | None = None,
output: dict[str, object] | None = None,
error: str | None = None,
duration_ms: int | None = None,
) -> None:
await self._fallback.update_step(execution_id, step_name, status, output, error, duration_ms)
await self._fallback.update_step(
execution_id, step_name, status, output, error, duration_ms
)
async def _redis_update(redis: Any) -> None:
async def _redis_update(redis: object) -> None:
state = self._fallback.get_execution_sync(execution_id)
if state is None:
return
@ -271,11 +277,11 @@ class PipelineStateRedis:
async def complete_execution(
self,
execution_id: str,
final_output: dict[str, Any] | None = None,
final_output: dict[str, object] | None = None,
) -> None:
await self._fallback.complete_execution(execution_id, final_output)
async def _redis_complete(redis: Any) -> None:
async def _redis_complete(redis: object) -> None:
state = self._fallback.get_execution_sync(execution_id)
if state is None:
return
@ -291,7 +297,7 @@ class PipelineStateRedis:
) -> None:
await self._fallback.fail_execution(execution_id, step_name, error)
async def _redis_fail(redis: Any) -> None:
async def _redis_fail(redis: object) -> None:
state = self._fallback.get_execution_sync(execution_id)
if state is None:
return
@ -299,7 +305,7 @@ class PipelineStateRedis:
await self._safe_redis_call(_redis_fail)
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
# Try Redis first
if not self._use_fallback:
try:
@ -318,7 +324,7 @@ class PipelineStateRedis:
status: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[dict[str, Any]]:
) -> list[dict[str, object]]:
# Try Redis sorted set for efficient listing
if not self._use_fallback:
try:
@ -341,7 +347,7 @@ class PipelineStateRedis:
return await self._fallback.list_executions(status, limit, offset)
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
return await self._fallback.get_step_history(execution_id)
async def health_check(self) -> bool:
@ -364,20 +370,18 @@ class PipelineStatePG:
If session_factory is None, all methods are no-op.
"""
def __init__(self, session_factory: Any = None) -> None:
def __init__(self, session_factory: object | None = None) -> None:
self._session_factory = session_factory
@property
def enabled(self) -> bool:
return self._session_factory is not None
async def persist_execution(self, state: dict[str, Any]) -> None:
async def persist_execution(self, state: dict[str, object]) -> None:
"""Write a completed/failed execution to PostgreSQL."""
if not self.enabled:
return
try:
from sqlalchemy.ext.asyncio import AsyncSession
async with self._session_factory() as session:
model = PipelineExecutionModel(
id=state["id"],
@ -390,18 +394,22 @@ class PipelineStatePG:
final_output=state.get("final_output"),
error_message=state.get("error_message"),
tenant_id=state.get("tenant_id"),
created_at=datetime.fromisoformat(state["created_at"]) if state.get("created_at") else None,
updated_at=datetime.fromisoformat(state["updated_at"]) if state.get("updated_at") else None,
completed_at=datetime.fromisoformat(state["completed_at"]) if state.get("completed_at") else None,
created_at=datetime.fromisoformat(state["created_at"])
if state.get("created_at")
else None,
updated_at=datetime.fromisoformat(state["updated_at"])
if state.get("updated_at")
else None,
completed_at=datetime.fromisoformat(state["completed_at"])
if state.get("completed_at")
else None,
)
await session.merge(model)
await session.commit()
except Exception as exc:
logger.error(f"Failed to persist execution to PG: {exc}")
async def persist_step_history(
self, execution_id: str, steps: list[dict[str, Any]]
) -> None:
async def persist_step_history(self, execution_id: str, steps: list[dict[str, object]]) -> None:
"""Write step history to PostgreSQL."""
if not self.enabled:
return
@ -419,8 +427,12 @@ class PipelineStatePG:
error_message=step.get("error_message"),
duration_ms=step.get("duration_ms"),
retry_attempt=step.get("retry_attempt", 0),
started_at=datetime.fromisoformat(step["started_at"]) if step.get("started_at") else None,
completed_at=datetime.fromisoformat(step["completed_at"]) if step.get("completed_at") else None,
started_at=datetime.fromisoformat(step["started_at"])
if step.get("started_at")
else None,
completed_at=datetime.fromisoformat(step["completed_at"])
if step.get("completed_at")
else None,
)
await session.merge(model)
await session.commit()
@ -433,7 +445,7 @@ class PipelineStatePG:
status: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[dict[str, Any]]:
) -> list[dict[str, object]]:
"""Query historical executions from PostgreSQL."""
if not self.enabled:
return []
@ -445,9 +457,7 @@ class PipelineStatePG:
PipelineExecutionModel.created_at.desc()
)
if pipeline_name:
stmt = stmt.where(
PipelineExecutionModel.pipeline_name == pipeline_name
)
stmt = stmt.where(PipelineExecutionModel.pipeline_name == pipeline_name)
if status:
stmt = stmt.where(PipelineExecutionModel.status == status)
stmt = stmt.offset(offset).limit(limit)
@ -458,7 +468,7 @@ class PipelineStatePG:
logger.error(f"Failed to query executions from PG: {exc}")
return []
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
"""Get a single execution from PostgreSQL (for Redis miss fallback)."""
if not self.enabled:
return None
@ -479,7 +489,7 @@ class PipelineStatePG:
return None
@staticmethod
def _model_to_dict(model: PipelineExecutionModel) -> dict[str, Any]:
def _model_to_dict(model: PipelineExecutionModel) -> dict[str, object]:
return {
"id": model.id,
"pipeline_name": model.pipeline_name,
@ -509,7 +519,7 @@ class PipelineStateManager:
def __init__(
self,
redis_url: str | None = None,
session_factory: Any = None,
session_factory: object | None = None,
) -> None:
if redis_url:
self._hot = PipelineStateRedis(redis_url=redis_url)
@ -529,7 +539,7 @@ class PipelineStateManager:
self,
pipeline_name: str,
steps: list[str],
input_data: dict[str, Any] | None = None,
input_data: dict[str, object] | None = None,
tenant_id: str | None = None,
) -> str:
return await self._hot.create_execution(pipeline_name, steps, input_data, tenant_id)
@ -539,7 +549,7 @@ class PipelineStateManager:
execution_id: str,
step_name: str,
status: str,
output: dict[str, Any] | None = None,
output: dict[str, object] | None = None,
error: str | None = None,
duration_ms: int | None = None,
) -> None:
@ -548,7 +558,7 @@ class PipelineStateManager:
async def complete_execution(
self,
execution_id: str,
final_output: dict[str, Any] | None = None,
final_output: dict[str, object] | None = None,
) -> None:
await self._hot.complete_execution(execution_id, final_output)
# Persist to PG
@ -574,7 +584,7 @@ class PipelineStateManager:
if step_history:
await self._cold.persist_step_history(execution_id, step_history)
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
# Redis / memory first
state = await self._hot.get_execution(execution_id)
if state is not None:
@ -587,7 +597,7 @@ class PipelineStateManager:
status: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[dict[str, Any]]:
) -> list[dict[str, object]]:
# Hot store for recent executions
results = await self._hot.list_executions(status, limit, offset)
if results:
@ -595,7 +605,7 @@ class PipelineStateManager:
# Cold store for historical queries
return await self._cold.query_executions(status=status, limit=limit, offset=offset)
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
return await self._hot.get_step_history(execution_id)
async def health_check(self) -> dict[str, bool]:

View File

@ -15,10 +15,14 @@ import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from typing import TypeAlias
logger = logging.getLogger(__name__)
# Scalar session state values (text/number/flag/none). Screen-state dicts use
# dict[str, object] because they also hold tuple cursor positions.
SessionState: TypeAlias = dict[str, str | int | bool | None]
@dataclass
class ScreenInfo:
@ -37,7 +41,7 @@ class ActionResult:
output: str = ""
screenshot_base64: str = ""
error: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
metadata: dict[str, object] = field(default_factory=dict)
class ComputerUseSession(ABC):
@ -56,7 +60,7 @@ class ComputerUseSession(ABC):
self.session_id = session_id or str(uuid.uuid4())
self.screen = ScreenInfo(width=screen_width, height=screen_height)
self._started = False
self._action_history: list[dict[str, Any]] = []
self._action_history: list[dict[str, object]] = []
@property
def is_started(self) -> bool:
@ -82,7 +86,7 @@ class ComputerUseSession(ABC):
...
@abstractmethod
async def execute_action(self, action: str, **params: Any) -> ActionResult:
async def execute_action(self, action: str, **params: object) -> ActionResult:
"""执行 UI 操作
Args:
@ -94,18 +98,20 @@ class ComputerUseSession(ABC):
"""
...
def record_action(self, action: str, params: dict[str, Any], result: ActionResult) -> None:
def record_action(self, action: str, params: dict[str, object], result: ActionResult) -> None:
"""记录操作历史"""
self._action_history.append({
self._action_history.append(
{
"timestamp": time.time(),
"action": action,
"params": params,
"success": result.success,
"output": result.output[:200] if result.output else "",
})
}
)
@property
def action_history(self) -> list[dict[str, Any]]:
def action_history(self) -> list[dict[str, object]]:
"""获取操作历史(副本)"""
return list(self._action_history)
@ -134,7 +140,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
screen_width=screen_width,
screen_height=screen_height,
)
self._screen_state: dict[str, Any] = {
self._screen_state: dict[str, object] = {
"focused_element": None,
"cursor_position": (0, 0),
"typed_text": "",
@ -173,7 +179,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
metadata={"screen_state": dict(self._screen_state)},
)
async def execute_action(self, action: str, **params: Any) -> ActionResult:
async def execute_action(self, action: str, **params: object) -> ActionResult:
"""模拟执行 UI 操作"""
if not self._started:
return ActionResult(
@ -186,7 +192,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
self.record_action(action, params, result)
return result
def _simulate_action(self, action: str, **params: Any) -> ActionResult:
def _simulate_action(self, action: str, **params: object) -> ActionResult:
"""模拟具体操作"""
if action == "click":
x = params.get("x", 0)
@ -270,18 +276,78 @@ class LocalComputerUseSession(ComputerUseSession):
screen_width=screen_width,
screen_height=screen_height,
)
self._pyautogui: Any = None
self._pyautogui: object | None = None
# Allowed keys for the `key` action — prevents OS-level shortcut abuse
_ALLOWED_KEYS: set[str] = {
"enter", "return", "tab", "backspace", "delete", "home", "end",
"up", "down", "left", "right", "pageup", "pagedown",
"space", "escape", "insert",
"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "f10", "f11", "f12",
"shift", "ctrl", "alt", "command",
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
"n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
"enter",
"return",
"tab",
"backspace",
"delete",
"home",
"end",
"up",
"down",
"left",
"right",
"pageup",
"pagedown",
"space",
"escape",
"insert",
"f1",
"f2",
"f3",
"f4",
"f5",
"f6",
"f7",
"f8",
"f9",
"f10",
"f11",
"f12",
"shift",
"ctrl",
"alt",
"command",
"a",
"b",
"c",
"d",
"e",
"f",
"g",
"h",
"i",
"j",
"k",
"l",
"m",
"n",
"o",
"p",
"q",
"r",
"s",
"t",
"u",
"v",
"w",
"x",
"y",
"z",
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
}
_ALLOWED_BUTTONS: set[str] = {"left", "right", "middle"}
_MAX_TEXT_LENGTH: int = 1000
@ -291,6 +357,7 @@ class LocalComputerUseSession(ComputerUseSession):
"""启动本地桌面会话"""
try:
import pyautogui
self._pyautogui = pyautogui
pyautogui.FAILSAFE = True
pyautogui.PAUSE = 0.1
@ -327,7 +394,7 @@ class LocalComputerUseSession(ComputerUseSession):
except Exception as e:
return ActionResult(success=False, action="screenshot", error=str(e))
async def execute_action(self, action: str, **params: Any) -> ActionResult:
async def execute_action(self, action: str, **params: object) -> ActionResult:
"""在本地桌面执行 UI 操作"""
if not self._started:
return ActionResult(success=False, action=action, error="Session not started")
@ -343,7 +410,7 @@ class LocalComputerUseSession(ComputerUseSession):
"""Check if coordinates are within screen bounds."""
return 0 <= x <= self.screen_width and 0 <= y <= self.screen_height
async def _execute_local_action(self, action: str, **params: Any) -> ActionResult:
async def _execute_local_action(self, action: str, **params: object) -> ActionResult:
"""Execute a local UI action with input validation."""
pg = self._pyautogui
@ -351,16 +418,24 @@ class LocalComputerUseSession(ComputerUseSession):
x, y = params.get("x", 0), params.get("y", 0)
button = params.get("button", "left")
if button not in self._ALLOWED_BUTTONS:
return ActionResult(success=False, action="click", error=f"Invalid button: {button}")
return ActionResult(
success=False, action="click", error=f"Invalid button: {button}"
)
if not self._validate_coordinates(x, y):
return ActionResult(success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})")
return ActionResult(
success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})"
)
pg.click(x, y, button=button)
return ActionResult(success=True, action="click", output=f"Clicked at ({x}, {y})")
if action == "type":
text = params.get("text", "")
if len(text) > self._MAX_TEXT_LENGTH:
return ActionResult(success=False, action="type", error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}")
return ActionResult(
success=False,
action="type",
error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}",
)
pg.write(text)
return ActionResult(success=True, action="type", output=f"Typed: {text[:50]}")
@ -369,16 +444,22 @@ class LocalComputerUseSession(ComputerUseSession):
amount = params.get("amount", 3)
clicks = amount if direction == "down" else -amount
pg.scroll(clicks)
return ActionResult(success=True, action="scroll", output=f"Scrolled {direction} by {amount}")
return ActionResult(
success=True, action="scroll", output=f"Scrolled {direction} by {amount}"
)
if action == "drag":
sx, sy = params.get("start_x", 0), params.get("start_y", 0)
ex, ey = params.get("end_x", 0), params.get("end_y", 0)
if not (self._validate_coordinates(sx, sy) and self._validate_coordinates(ex, ey)):
return ActionResult(success=False, action="drag", error="Drag coordinates out of bounds")
return ActionResult(
success=False, action="drag", error="Drag coordinates out of bounds"
)
pg.moveTo(sx, sy)
pg.dragTo(ex, ey, duration=0.5)
return ActionResult(success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})")
return ActionResult(
success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})"
)
if action == "key":
key_name = params.get("key_name", "")
@ -487,7 +568,7 @@ class DockerComputerUseSession(ComputerUseSession):
screenshot_base64="",
)
async def execute_action(self, action: str, **params: Any) -> ActionResult:
async def execute_action(self, action: str, **params: object) -> ActionResult:
"""在 Docker 虚拟桌面执行操作
Stub: 实际实现需要通过 Anthropic Computer Use API
@ -527,7 +608,7 @@ class ComputerUseSessionManager:
def get_or_create(
self,
session_id: str | None = None,
**kwargs: Any,
**kwargs: object,
) -> ComputerUseSession:
"""获取或创建会话"""
if session_id and session_id in self._sessions: