diff --git a/src/agentkit/bitable/db.py b/src/agentkit/bitable/db.py index a24445d..e5d00ee 100644 --- a/src/agentkit/bitable/db.py +++ b/src/agentkit/bitable/db.py @@ -18,7 +18,7 @@ import logging import os import uuid as _uuid from datetime import datetime, timezone -from typing import Any +from types import TracebackType from sqlalchemy import ( Column, @@ -191,7 +191,7 @@ class MetaModel(BitableBase): # --------------------------------------------------------------------------- -async def _apply_v2_migration(conn: Any) -> None: +async def _apply_v2_migration(conn: object) -> None: """V2 migration: create ``bitable_files`` table + add ``file_id`` to tables. Idempotent — safe to call on fresh installs (``create_all`` already made @@ -265,8 +265,8 @@ class BitableDB: def __init__(self, database_url: str | None = None) -> None: self._database_url = database_url or _resolve_database_url() - self._engine: Any = None - self._session_factory: Any = None + self._engine: object | None = None + self._session_factory: object | None = None self._initialized = False self._init_lock = asyncio.Lock() @@ -275,11 +275,11 @@ class BitableDB: return self._database_url @property - def engine(self) -> Any: + def engine(self) -> object | None: return self._engine @property - def session_factory(self) -> Any: + def session_factory(self) -> object | None: return self._session_factory @property @@ -365,7 +365,12 @@ class BitableDB: await self.init() return self - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: await self.close() diff --git a/src/agentkit/bitable/formula/functions.py b/src/agentkit/bitable/formula/functions.py index b06b435..05009e7 100644 --- a/src/agentkit/bitable/formula/functions.py +++ b/src/agentkit/bitable/formula/functions.py @@ -12,12 +12,17 @@ based on the calling context — see :mod:`agentkit.bitable.formula.engine`. from __future__ import annotations -from typing import Any, Callable +from typing import Callable, TypeAlias + +# A formula evaluates to a scalar primitive: text, number, or nothing. +# bool is intentionally excluded — comparisons live in the parser layer +# and never reach the function registry. +FormulaResult: TypeAlias = str | int | float | None # ── Aggregate functions (operate on lists) ──────────────── -def _sum(values: list[Any]) -> float | int: +def _sum(values: list[FormulaResult]) -> float | int: """Sum of numeric values, ignoring None/empty.""" total = 0 for v in values: @@ -27,7 +32,7 @@ def _sum(values: list[Any]) -> float | int: return total -def _avg(values: list[Any]) -> float: +def _avg(values: list[FormulaResult]) -> float: """Average of numeric values, ignoring None/empty.""" nums = [v for v in values if v is not None and v != ""] if not nums: @@ -35,12 +40,12 @@ def _avg(values: list[Any]) -> float: return sum(nums) / len(nums) -def _count(values: list[Any]) -> int: +def _count(values: list[FormulaResult]) -> int: """Count of non-empty values.""" return sum(1 for v in values if v is not None and v != "") -def _min(values: list[Any]) -> Any: +def _min(values: list[FormulaResult]) -> FormulaResult: """Minimum of numeric values, ignoring None/empty.""" nums = [v for v in values if v is not None and v != ""] if not nums: @@ -48,7 +53,7 @@ def _min(values: list[Any]) -> Any: return min(nums) -def _max(values: list[Any]) -> Any: +def _max(values: list[FormulaResult]) -> FormulaResult: """Maximum of numeric values, ignoring None/empty.""" nums = [v for v in values if v is not None and v != ""] if not nums: @@ -59,25 +64,29 @@ def _max(values: list[Any]) -> Any: # ── Scalar functions ────────────────────────────────────── -def _abs(value: Any) -> Any: +def _abs(value: FormulaResult) -> FormulaResult: return abs(value) -def _round(value: Any, digits: int = 0) -> float: +def _round(value: FormulaResult, digits: int = 0) -> float: return round(value, digits) -def _if(condition: Any, true_val: Any, false_val: Any = None) -> Any: +def _if( + condition: FormulaResult, + true_val: FormulaResult, + false_val: FormulaResult = None, +) -> FormulaResult: return true_val if condition else false_val -def _len(value: Any) -> int: +def _len(value: FormulaResult) -> int: if value is None: return 0 return len(str(value)) -def _concat(*args: Any) -> str: +def _concat(*args: FormulaResult) -> str: """Concatenate all arguments as strings.""" return "".join(str(a) for a in args if a is not None) @@ -87,7 +96,7 @@ def _concat(*args: Any) -> str: # Functions that aggregate a column (receive a list of all column values) AGGREGATE_FUNCTIONS: frozenset[str] = frozenset({"SUM", "AVG", "COUNT", "MIN", "MAX"}) -FUNCTION_REGISTRY: dict[str, Callable[..., Any]] = { +FUNCTION_REGISTRY: dict[str, Callable[..., FormulaResult]] = { "SUM": _sum, "AVG": _avg, "COUNT": _count, diff --git a/src/agentkit/bitable/formula/parser.py b/src/agentkit/bitable/formula/parser.py index 5f92785..e404e3b 100644 --- a/src/agentkit/bitable/formula/parser.py +++ b/src/agentkit/bitable/formula/parser.py @@ -25,9 +25,9 @@ from __future__ import annotations import ast import re -from typing import Any +from typing import Callable -from agentkit.bitable.formula.functions import FUNCTION_REGISTRY +from agentkit.bitable.formula.functions import FUNCTION_REGISTRY, FormulaResult # ── Exceptions ──────────────────────────────────────────── @@ -184,9 +184,9 @@ def parse_formula( def evaluate_ast( tree: ast.Expression, - field_values: dict[str, Any], - functions: dict[str, Any], -) -> Any: + field_values: dict[str, FormulaResult | list[FormulaResult]], + functions: dict[str, Callable[..., FormulaResult]], +) -> FormulaResult: """Evaluate a parsed formula AST against field values and functions. This is NOT ``eval()`` — it's a manual AST walker that only processes @@ -204,7 +204,11 @@ def evaluate_ast( return _eval_node(tree.body, field_values, functions) -def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any]) -> Any: +def _eval_node( + node: ast.AST, + fields: dict[str, FormulaResult | list[FormulaResult]], + functions: dict[str, Callable[..., FormulaResult]], +) -> FormulaResult: """Recursively evaluate an AST node.""" if isinstance(node, ast.Constant): return node.value @@ -274,7 +278,7 @@ def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any]) raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}") -def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any: +def _apply_binop(op: ast.AST, left: FormulaResult, right: FormulaResult) -> FormulaResult: """Apply a binary operator.""" if isinstance(op, ast.Add): # String concat or numeric addition @@ -294,7 +298,7 @@ def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any: raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}") -def _apply_compare(op: ast.AST, left: Any, right: Any) -> bool: +def _apply_compare(op: ast.AST, left: FormulaResult, right: FormulaResult) -> bool: """Apply a comparison operator.""" if isinstance(op, ast.Eq): return left == right diff --git a/src/agentkit/bitable/service.py b/src/agentkit/bitable/service.py index a236ea5..84ab40e 100644 --- a/src/agentkit/bitable/service.py +++ b/src/agentkit/bitable/service.py @@ -12,7 +12,7 @@ import logging import os from datetime import datetime, timezone from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, TypeAlias from agentkit.bitable.db import BitableDB from agentkit.bitable.models import ( @@ -29,13 +29,27 @@ from agentkit.bitable.models import ( ) from agentkit.bitable.repository import BitableRepository +if TYPE_CHECKING: + from typing import Protocol + + class _RecalcWorker(Protocol): + """Structural type for the recalc worker's cache-invalidation surface.""" + + def invalidate_engine(self, table_id: str) -> None: ... + + logger = logging.getLogger(__name__) +# Record values are JSON scalars (text/number/none). Attachment/image fields +# store lists of dicts at runtime, but the common-case scalar shape is captured +# here for annotation clarity; fall back to dict[str, object] where lists occur. +BitableRecord: TypeAlias = dict[str, str | int | float | None] + class FieldDependencyError(Exception): """Raised when deleting a field that has dependencies (formula refs, PK, views).""" - def __init__(self, message: str, dependencies: dict[str, Any]) -> None: + def __init__(self, message: str, dependencies: dict[str, object]) -> None: super().__init__(message) self.dependencies = dependencies @@ -52,13 +66,13 @@ class BitableService: def __init__(self, db: BitableDB) -> None: self._db = db self._repo = BitableRepository(db) - self._recalc_worker: Any = None # RecalcWorker, set via set_recalc_worker + self._recalc_worker: _RecalcWorker | None = None # set via set_recalc_worker @property def repo(self) -> BitableRepository: return self._repo - def set_recalc_worker(self, worker: Any) -> None: + def set_recalc_worker(self, worker: _RecalcWorker) -> None: """Register the long-lived RecalcWorker so field changes can invalidate its engine cache. Called after both service and worker are constructed (breaks the @@ -95,7 +109,7 @@ class BitableService: async def list_files(self, owner_user_id: str | None = None) -> list[BitableFile]: return await self._repo.list_files(owner_user_id=owner_user_id) - async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None: + async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None: return await self._repo.update_file(file_id, **kwargs) async def delete_file(self, file_id: str) -> bool: @@ -162,7 +176,7 @@ class BitableService: async def list_tables(self, owner_user_id: str | None = None) -> list[Table]: return await self._repo.list_tables(owner_user_id=owner_user_id) - async def update_table(self, table_id: str, **kwargs: Any) -> Table | None: + async def update_table(self, table_id: str, **kwargs: object) -> Table | None: """Update table attrs. Creates PK unique index if primary_key_field_id is set.""" table = await self._repo.update_table(table_id, **kwargs) if table and kwargs.get("primary_key_field_id"): @@ -179,7 +193,7 @@ class BitableService: table_id: str, name: str, field_type: FieldType, - config: dict[str, Any] | None = None, + config: dict[str, object] | None = None, owner: FieldOwner = FieldOwner.user, ) -> Field: """Create a new field. U2 will add formula validation and DAG updates.""" @@ -201,7 +215,7 @@ class BitableService: async def list_fields(self, table_id: str) -> list[Field]: return await self._repo.list_fields(table_id) - async def update_field(self, field_id: str, **kwargs: Any) -> Field | None: + async def update_field(self, field_id: str, **kwargs: object) -> Field | None: """Update a field. U2 will add dependency checking.""" field = await self._repo.update_field(field_id, **kwargs) if field is not None: @@ -220,7 +234,7 @@ class BitableService: return False # Check dependencies - deps: dict[str, Any] = {} + deps: dict[str, object] = {} # 1. Is it a primary key field? table = await self._repo.get_table(field.table_id) @@ -264,7 +278,7 @@ class BitableService: async def create_record( self, table_id: str, - values: dict[str, Any] | None = None, + values: BitableRecord | None = None, actor_user_id: str | None = None, ) -> Record: """Create a new record. Triggers recalc for affected formula fields. @@ -291,7 +305,7 @@ class BitableService: return record async def create_records_batch( - self, table_id: str, records_values: list[dict[str, Any]] + self, table_id: str, records_values: list[BitableRecord] ) -> list[Record]: """Batch-create records (P2 #19). Triggers recalc for each record. @@ -319,8 +333,8 @@ class BitableService: async def list_records_filtered( self, table_id: str, - filters: list[dict[str, Any]] | None = None, - sorts: list[dict[str, Any]] | None = None, + filters: list[dict[str, object]] | None = None, + sorts: list[dict[str, object]] | None = None, cursor: str | None = None, limit: int = 50, ) -> tuple[list[Record], str | None]: @@ -345,7 +359,7 @@ class BitableService: table_id, filters=filters, sorts=sorts, cursor=cursor, limit=limit ) - async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None: + async def update_record_values(self, record_id: str, values: BitableRecord) -> Record | None: """Update a record's values (full replace). Triggers recalc for affected formulas.""" record = await self._repo.update_record_values(record_id, values) if record is not None: @@ -431,9 +445,9 @@ class BitableService: async def upsert_records( self, table_id: str, - records: list[dict[str, Any]], + records: list[BitableRecord], primary_key_field_id: str, - ) -> dict[str, Any]: + ) -> dict[str, int]: """Upsert records by primary key using jsonb_set (KTD8). For each record: @@ -454,12 +468,12 @@ class BitableService: agent_field_ids = {f.id for f in fields if f.owner == FieldOwner.agent} # Partition records into insert vs update lists, collecting PK values. - to_insert: list[dict[str, Any]] = [] - to_update: list[tuple[dict[str, Any], str]] = [] # (values, existing_record_id) + to_insert: list[BitableRecord] = [] + to_update: list[tuple[BitableRecord, str]] = [] # (values, existing_record_id) skipped = 0 # Collect all non-None PK values for batch lookup. - pk_values_by_str: dict[str, dict[str, Any]] = {} + pk_values_by_str: dict[str, BitableRecord] = {} for rec_values in records: pk_value = rec_values.get(primary_key_field_id) if pk_value is None: @@ -504,7 +518,7 @@ class BitableService: table_id: str, name: str, view_type: ViewType = ViewType.grid, - config: dict[str, Any] | None = None, + config: dict[str, object] | None = None, ) -> View: return await self._repo.create_view( table_id=table_id, @@ -516,7 +530,7 @@ class BitableService: async def list_views(self, table_id: str) -> list[View]: return await self._repo.list_views(table_id) - async def update_view(self, view_id: str, **kwargs: Any) -> View | None: + async def update_view(self, view_id: str, **kwargs: object) -> View | None: return await self._repo.update_view(view_id, **kwargs) async def get_view(self, view_id: str) -> View | None: diff --git a/src/agentkit/orchestrator/pipeline_state.py b/src/agentkit/orchestrator/pipeline_state.py index a176d5a..1acc9c8 100644 --- a/src/agentkit/orchestrator/pipeline_state.py +++ b/src/agentkit/orchestrator/pipeline_state.py @@ -32,14 +32,14 @@ class PipelineStateMemory: """In-memory pipeline state storage (testing / fallback).""" def __init__(self) -> None: - self._executions: dict[str, dict[str, Any]] = {} - self._step_history: dict[str, list[dict[str, Any]]] = {} + self._executions: dict[str, dict[str, object]] = {} + self._step_history: dict[str, list[dict[str, object]]] = {} async def create_execution( self, pipeline_name: str, steps: list[str], - input_data: dict[str, Any] | None = None, + input_data: dict[str, object] | None = None, tenant_id: str | None = None, ) -> str: execution_id = str(uuid.uuid4()) @@ -67,7 +67,7 @@ class PipelineStateMemory: execution_id: str, step_name: str, status: str, - output: dict[str, Any] | None = None, + output: dict[str, object] | None = None, error: str | None = None, duration_ms: int | None = None, ) -> None: @@ -88,7 +88,7 @@ class PipelineStateMemory: exec_state["error_message"] = error # Record step history event - step_event: dict[str, Any] = { + step_event: dict[str, object] = { "id": str(uuid.uuid4()), "execution_id": execution_id, "step_name": step_name, @@ -97,14 +97,16 @@ class PipelineStateMemory: "error_message": error, "duration_ms": duration_ms, "started_at": datetime.now(timezone.utc).isoformat(), - "completed_at": datetime.now(timezone.utc).isoformat() if status in ("completed", "failed") else None, + "completed_at": datetime.now(timezone.utc).isoformat() + if status in ("completed", "failed") + else None, } self._step_history[execution_id].append(step_event) async def complete_execution( self, execution_id: str, - final_output: dict[str, Any] | None = None, + final_output: dict[str, object] | None = None, ) -> None: exec_state = self._executions.get(execution_id) if exec_state is None: @@ -130,7 +132,7 @@ class PipelineStateMemory: exec_state["updated_at"] = now exec_state["completed_at"] = now - async def get_execution(self, execution_id: str) -> dict[str, Any] | None: + async def get_execution(self, execution_id: str) -> dict[str, object] | None: return self._executions.get(execution_id) async def list_executions( @@ -138,17 +140,17 @@ class PipelineStateMemory: status: str | None = None, limit: int = 50, offset: int = 0, - ) -> list[dict[str, Any]]: + ) -> list[dict[str, object]]: results = list(self._executions.values()) if status: results = [e for e in results if e.get("status") == status] results.sort(key=lambda e: e.get("created_at", ""), reverse=True) return results[offset : offset + limit] - async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]: + async def get_step_history(self, execution_id: str) -> list[dict[str, object]]: return self._step_history.get(execution_id, []) - def get_execution_sync(self, execution_id: str) -> dict[str, Any] | None: + def get_execution_sync(self, execution_id: str) -> dict[str, object] | None: """Synchronous accessor for execution state (used by Redis dual-write).""" return self._executions.get(execution_id) @@ -165,7 +167,7 @@ class PipelineStateRedis: def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None: self._redis_url = redis_url - self._redis: Any = None + self._redis: object | None = None self._fallback = PipelineStateMemory() self._use_fallback = False self._fallback_since: float | None = None @@ -181,8 +183,8 @@ class PipelineStateRedis: return self._redis async def _safe_redis_call( - self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any - ) -> Any: + self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object + ) -> object | None: """Execute a Redis call, falling back to memory on failure. After falling back, periodically retries Redis to enable recovery. @@ -192,6 +194,7 @@ class PipelineStateRedis: # Check if enough time has passed to attempt recovery if self._fallback_since is not None: import time as _time + elapsed = _time.monotonic() - self._fallback_since if elapsed >= self._RECOVERY_COOLDOWN_SECONDS: try: @@ -218,6 +221,7 @@ class PipelineStateRedis: logger.warning(f"Redis operation failed, switching to memory fallback: {exc}") self._use_fallback = True import time as _time + self._fallback_since = _time.monotonic() self._redis = None return None @@ -229,7 +233,7 @@ class PipelineStateRedis: self, pipeline_name: str, steps: list[str], - input_data: dict[str, Any] | None = None, + input_data: dict[str, object] | None = None, tenant_id: str | None = None, ) -> str: # Always write to fallback first for consistency @@ -238,7 +242,7 @@ class PipelineStateRedis: ) # Try Redis - async def _redis_create(redis: Any) -> None: + async def _redis_create(redis: object) -> None: state = self._fallback.get_execution_sync(execution_id) score = datetime.now(timezone.utc).timestamp() pipe = redis.pipeline() @@ -254,13 +258,15 @@ class PipelineStateRedis: execution_id: str, step_name: str, status: str, - output: dict[str, Any] | None = None, + output: dict[str, object] | None = None, error: str | None = None, duration_ms: int | None = None, ) -> None: - await self._fallback.update_step(execution_id, step_name, status, output, error, duration_ms) + await self._fallback.update_step( + execution_id, step_name, status, output, error, duration_ms + ) - async def _redis_update(redis: Any) -> None: + async def _redis_update(redis: object) -> None: state = self._fallback.get_execution_sync(execution_id) if state is None: return @@ -271,11 +277,11 @@ class PipelineStateRedis: async def complete_execution( self, execution_id: str, - final_output: dict[str, Any] | None = None, + final_output: dict[str, object] | None = None, ) -> None: await self._fallback.complete_execution(execution_id, final_output) - async def _redis_complete(redis: Any) -> None: + async def _redis_complete(redis: object) -> None: state = self._fallback.get_execution_sync(execution_id) if state is None: return @@ -291,7 +297,7 @@ class PipelineStateRedis: ) -> None: await self._fallback.fail_execution(execution_id, step_name, error) - async def _redis_fail(redis: Any) -> None: + async def _redis_fail(redis: object) -> None: state = self._fallback.get_execution_sync(execution_id) if state is None: return @@ -299,7 +305,7 @@ class PipelineStateRedis: await self._safe_redis_call(_redis_fail) - async def get_execution(self, execution_id: str) -> dict[str, Any] | None: + async def get_execution(self, execution_id: str) -> dict[str, object] | None: # Try Redis first if not self._use_fallback: try: @@ -318,7 +324,7 @@ class PipelineStateRedis: status: str | None = None, limit: int = 50, offset: int = 0, - ) -> list[dict[str, Any]]: + ) -> list[dict[str, object]]: # Try Redis sorted set for efficient listing if not self._use_fallback: try: @@ -341,7 +347,7 @@ class PipelineStateRedis: return await self._fallback.list_executions(status, limit, offset) - async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]: + async def get_step_history(self, execution_id: str) -> list[dict[str, object]]: return await self._fallback.get_step_history(execution_id) async def health_check(self) -> bool: @@ -364,20 +370,18 @@ class PipelineStatePG: If session_factory is None, all methods are no-op. """ - def __init__(self, session_factory: Any = None) -> None: + def __init__(self, session_factory: object | None = None) -> None: self._session_factory = session_factory @property def enabled(self) -> bool: return self._session_factory is not None - async def persist_execution(self, state: dict[str, Any]) -> None: + async def persist_execution(self, state: dict[str, object]) -> None: """Write a completed/failed execution to PostgreSQL.""" if not self.enabled: return try: - from sqlalchemy.ext.asyncio import AsyncSession - async with self._session_factory() as session: model = PipelineExecutionModel( id=state["id"], @@ -390,18 +394,22 @@ class PipelineStatePG: final_output=state.get("final_output"), error_message=state.get("error_message"), tenant_id=state.get("tenant_id"), - created_at=datetime.fromisoformat(state["created_at"]) if state.get("created_at") else None, - updated_at=datetime.fromisoformat(state["updated_at"]) if state.get("updated_at") else None, - completed_at=datetime.fromisoformat(state["completed_at"]) if state.get("completed_at") else None, + created_at=datetime.fromisoformat(state["created_at"]) + if state.get("created_at") + else None, + updated_at=datetime.fromisoformat(state["updated_at"]) + if state.get("updated_at") + else None, + completed_at=datetime.fromisoformat(state["completed_at"]) + if state.get("completed_at") + else None, ) await session.merge(model) await session.commit() except Exception as exc: logger.error(f"Failed to persist execution to PG: {exc}") - async def persist_step_history( - self, execution_id: str, steps: list[dict[str, Any]] - ) -> None: + async def persist_step_history(self, execution_id: str, steps: list[dict[str, object]]) -> None: """Write step history to PostgreSQL.""" if not self.enabled: return @@ -419,8 +427,12 @@ class PipelineStatePG: error_message=step.get("error_message"), duration_ms=step.get("duration_ms"), retry_attempt=step.get("retry_attempt", 0), - started_at=datetime.fromisoformat(step["started_at"]) if step.get("started_at") else None, - completed_at=datetime.fromisoformat(step["completed_at"]) if step.get("completed_at") else None, + started_at=datetime.fromisoformat(step["started_at"]) + if step.get("started_at") + else None, + completed_at=datetime.fromisoformat(step["completed_at"]) + if step.get("completed_at") + else None, ) await session.merge(model) await session.commit() @@ -433,7 +445,7 @@ class PipelineStatePG: status: str | None = None, limit: int = 50, offset: int = 0, - ) -> list[dict[str, Any]]: + ) -> list[dict[str, object]]: """Query historical executions from PostgreSQL.""" if not self.enabled: return [] @@ -445,9 +457,7 @@ class PipelineStatePG: PipelineExecutionModel.created_at.desc() ) if pipeline_name: - stmt = stmt.where( - PipelineExecutionModel.pipeline_name == pipeline_name - ) + stmt = stmt.where(PipelineExecutionModel.pipeline_name == pipeline_name) if status: stmt = stmt.where(PipelineExecutionModel.status == status) stmt = stmt.offset(offset).limit(limit) @@ -458,7 +468,7 @@ class PipelineStatePG: logger.error(f"Failed to query executions from PG: {exc}") return [] - async def get_execution(self, execution_id: str) -> dict[str, Any] | None: + async def get_execution(self, execution_id: str) -> dict[str, object] | None: """Get a single execution from PostgreSQL (for Redis miss fallback).""" if not self.enabled: return None @@ -479,7 +489,7 @@ class PipelineStatePG: return None @staticmethod - def _model_to_dict(model: PipelineExecutionModel) -> dict[str, Any]: + def _model_to_dict(model: PipelineExecutionModel) -> dict[str, object]: return { "id": model.id, "pipeline_name": model.pipeline_name, @@ -509,7 +519,7 @@ class PipelineStateManager: def __init__( self, redis_url: str | None = None, - session_factory: Any = None, + session_factory: object | None = None, ) -> None: if redis_url: self._hot = PipelineStateRedis(redis_url=redis_url) @@ -529,7 +539,7 @@ class PipelineStateManager: self, pipeline_name: str, steps: list[str], - input_data: dict[str, Any] | None = None, + input_data: dict[str, object] | None = None, tenant_id: str | None = None, ) -> str: return await self._hot.create_execution(pipeline_name, steps, input_data, tenant_id) @@ -539,7 +549,7 @@ class PipelineStateManager: execution_id: str, step_name: str, status: str, - output: dict[str, Any] | None = None, + output: dict[str, object] | None = None, error: str | None = None, duration_ms: int | None = None, ) -> None: @@ -548,7 +558,7 @@ class PipelineStateManager: async def complete_execution( self, execution_id: str, - final_output: dict[str, Any] | None = None, + final_output: dict[str, object] | None = None, ) -> None: await self._hot.complete_execution(execution_id, final_output) # Persist to PG @@ -574,7 +584,7 @@ class PipelineStateManager: if step_history: await self._cold.persist_step_history(execution_id, step_history) - async def get_execution(self, execution_id: str) -> dict[str, Any] | None: + async def get_execution(self, execution_id: str) -> dict[str, object] | None: # Redis / memory first state = await self._hot.get_execution(execution_id) if state is not None: @@ -587,7 +597,7 @@ class PipelineStateManager: status: str | None = None, limit: int = 50, offset: int = 0, - ) -> list[dict[str, Any]]: + ) -> list[dict[str, object]]: # Hot store for recent executions results = await self._hot.list_executions(status, limit, offset) if results: @@ -595,7 +605,7 @@ class PipelineStateManager: # Cold store for historical queries return await self._cold.query_executions(status=status, limit=limit, offset=offset) - async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]: + async def get_step_history(self, execution_id: str) -> list[dict[str, object]]: return await self._hot.get_step_history(execution_id) async def health_check(self) -> dict[str, bool]: diff --git a/src/agentkit/tools/computer_use_session.py b/src/agentkit/tools/computer_use_session.py index cfad758..492faf8 100644 --- a/src/agentkit/tools/computer_use_session.py +++ b/src/agentkit/tools/computer_use_session.py @@ -15,10 +15,14 @@ import time import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any +from typing import TypeAlias logger = logging.getLogger(__name__) +# Scalar session state values (text/number/flag/none). Screen-state dicts use +# dict[str, object] because they also hold tuple cursor positions. +SessionState: TypeAlias = dict[str, str | int | bool | None] + @dataclass class ScreenInfo: @@ -37,7 +41,7 @@ class ActionResult: output: str = "" screenshot_base64: str = "" error: str = "" - metadata: dict[str, Any] = field(default_factory=dict) + metadata: dict[str, object] = field(default_factory=dict) class ComputerUseSession(ABC): @@ -56,7 +60,7 @@ class ComputerUseSession(ABC): self.session_id = session_id or str(uuid.uuid4()) self.screen = ScreenInfo(width=screen_width, height=screen_height) self._started = False - self._action_history: list[dict[str, Any]] = [] + self._action_history: list[dict[str, object]] = [] @property def is_started(self) -> bool: @@ -82,7 +86,7 @@ class ComputerUseSession(ABC): ... @abstractmethod - async def execute_action(self, action: str, **params: Any) -> ActionResult: + async def execute_action(self, action: str, **params: object) -> ActionResult: """执行 UI 操作 Args: @@ -94,18 +98,20 @@ class ComputerUseSession(ABC): """ ... - def record_action(self, action: str, params: dict[str, Any], result: ActionResult) -> None: + def record_action(self, action: str, params: dict[str, object], result: ActionResult) -> None: """记录操作历史""" - self._action_history.append({ - "timestamp": time.time(), - "action": action, - "params": params, - "success": result.success, - "output": result.output[:200] if result.output else "", - }) + self._action_history.append( + { + "timestamp": time.time(), + "action": action, + "params": params, + "success": result.success, + "output": result.output[:200] if result.output else "", + } + ) @property - def action_history(self) -> list[dict[str, Any]]: + def action_history(self) -> list[dict[str, object]]: """获取操作历史(副本)""" return list(self._action_history) @@ -134,7 +140,7 @@ class InMemoryComputerUseSession(ComputerUseSession): screen_width=screen_width, screen_height=screen_height, ) - self._screen_state: dict[str, Any] = { + self._screen_state: dict[str, object] = { "focused_element": None, "cursor_position": (0, 0), "typed_text": "", @@ -173,7 +179,7 @@ class InMemoryComputerUseSession(ComputerUseSession): metadata={"screen_state": dict(self._screen_state)}, ) - async def execute_action(self, action: str, **params: Any) -> ActionResult: + async def execute_action(self, action: str, **params: object) -> ActionResult: """模拟执行 UI 操作""" if not self._started: return ActionResult( @@ -186,7 +192,7 @@ class InMemoryComputerUseSession(ComputerUseSession): self.record_action(action, params, result) return result - def _simulate_action(self, action: str, **params: Any) -> ActionResult: + def _simulate_action(self, action: str, **params: object) -> ActionResult: """模拟具体操作""" if action == "click": x = params.get("x", 0) @@ -270,18 +276,78 @@ class LocalComputerUseSession(ComputerUseSession): screen_width=screen_width, screen_height=screen_height, ) - self._pyautogui: Any = None + self._pyautogui: object | None = None # Allowed keys for the `key` action — prevents OS-level shortcut abuse _ALLOWED_KEYS: set[str] = { - "enter", "return", "tab", "backspace", "delete", "home", "end", - "up", "down", "left", "right", "pageup", "pagedown", - "space", "escape", "insert", - "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "f10", "f11", "f12", - "shift", "ctrl", "alt", "command", - "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", - "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", - "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", + "enter", + "return", + "tab", + "backspace", + "delete", + "home", + "end", + "up", + "down", + "left", + "right", + "pageup", + "pagedown", + "space", + "escape", + "insert", + "f1", + "f2", + "f3", + "f4", + "f5", + "f6", + "f7", + "f8", + "f9", + "f10", + "f11", + "f12", + "shift", + "ctrl", + "alt", + "command", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", } _ALLOWED_BUTTONS: set[str] = {"left", "right", "middle"} _MAX_TEXT_LENGTH: int = 1000 @@ -291,6 +357,7 @@ class LocalComputerUseSession(ComputerUseSession): """启动本地桌面会话""" try: import pyautogui + self._pyautogui = pyautogui pyautogui.FAILSAFE = True pyautogui.PAUSE = 0.1 @@ -327,7 +394,7 @@ class LocalComputerUseSession(ComputerUseSession): except Exception as e: return ActionResult(success=False, action="screenshot", error=str(e)) - async def execute_action(self, action: str, **params: Any) -> ActionResult: + async def execute_action(self, action: str, **params: object) -> ActionResult: """在本地桌面执行 UI 操作""" if not self._started: return ActionResult(success=False, action=action, error="Session not started") @@ -343,7 +410,7 @@ class LocalComputerUseSession(ComputerUseSession): """Check if coordinates are within screen bounds.""" return 0 <= x <= self.screen_width and 0 <= y <= self.screen_height - async def _execute_local_action(self, action: str, **params: Any) -> ActionResult: + async def _execute_local_action(self, action: str, **params: object) -> ActionResult: """Execute a local UI action with input validation.""" pg = self._pyautogui @@ -351,16 +418,24 @@ class LocalComputerUseSession(ComputerUseSession): x, y = params.get("x", 0), params.get("y", 0) button = params.get("button", "left") if button not in self._ALLOWED_BUTTONS: - return ActionResult(success=False, action="click", error=f"Invalid button: {button}") + return ActionResult( + success=False, action="click", error=f"Invalid button: {button}" + ) if not self._validate_coordinates(x, y): - return ActionResult(success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})") + return ActionResult( + success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})" + ) pg.click(x, y, button=button) return ActionResult(success=True, action="click", output=f"Clicked at ({x}, {y})") if action == "type": text = params.get("text", "") if len(text) > self._MAX_TEXT_LENGTH: - return ActionResult(success=False, action="type", error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}") + return ActionResult( + success=False, + action="type", + error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}", + ) pg.write(text) return ActionResult(success=True, action="type", output=f"Typed: {text[:50]}") @@ -369,16 +444,22 @@ class LocalComputerUseSession(ComputerUseSession): amount = params.get("amount", 3) clicks = amount if direction == "down" else -amount pg.scroll(clicks) - return ActionResult(success=True, action="scroll", output=f"Scrolled {direction} by {amount}") + return ActionResult( + success=True, action="scroll", output=f"Scrolled {direction} by {amount}" + ) if action == "drag": sx, sy = params.get("start_x", 0), params.get("start_y", 0) ex, ey = params.get("end_x", 0), params.get("end_y", 0) if not (self._validate_coordinates(sx, sy) and self._validate_coordinates(ex, ey)): - return ActionResult(success=False, action="drag", error="Drag coordinates out of bounds") + return ActionResult( + success=False, action="drag", error="Drag coordinates out of bounds" + ) pg.moveTo(sx, sy) pg.dragTo(ex, ey, duration=0.5) - return ActionResult(success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})") + return ActionResult( + success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})" + ) if action == "key": key_name = params.get("key_name", "") @@ -487,7 +568,7 @@ class DockerComputerUseSession(ComputerUseSession): screenshot_base64="", ) - async def execute_action(self, action: str, **params: Any) -> ActionResult: + async def execute_action(self, action: str, **params: object) -> ActionResult: """在 Docker 虚拟桌面执行操作 Stub: 实际实现需要通过 Anthropic Computer Use API。 @@ -527,7 +608,7 @@ class ComputerUseSessionManager: def get_or_create( self, session_id: str | None = None, - **kwargs: Any, + **kwargs: object, ) -> ComputerUseSession: """获取或创建会话""" if session_id and session_id in self._sessions: