refactor(bitable,tools): replace Any with concrete types + Protocol (U4)

BitableRecord/FormulaResult/SessionState TypeAlias replace dict[str, Any]; _redis/_engine/_session_factory typed as object | None with TYPE_CHECKING Protocol (_RedisLike, _RecalcWorker); Coroutine[Any, Any, Any] retained as legitimate type param.

Baseline 40 : Any occurrences -> 0 across 6 in-scope files (target <=5). Deferred: repository.py/recalc_worker.py/ingestion/* (10 occurrences, separate PR).

ruff clean; 367 passed + 116 skipped (bitable + pipeline_state + tools).
This commit is contained in:
chiguyong 2026-06-30 22:32:30 +08:00
parent be5c4e09f8
commit 1033346913
6 changed files with 256 additions and 133 deletions

View File

@ -18,7 +18,7 @@ import logging
import os import os
import uuid as _uuid import uuid as _uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from types import TracebackType
from sqlalchemy import ( from sqlalchemy import (
Column, Column,
@ -191,7 +191,7 @@ class MetaModel(BitableBase):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _apply_v2_migration(conn: Any) -> None: async def _apply_v2_migration(conn: object) -> None:
"""V2 migration: create ``bitable_files`` table + add ``file_id`` to tables. """V2 migration: create ``bitable_files`` table + add ``file_id`` to tables.
Idempotent safe to call on fresh installs (``create_all`` already made Idempotent safe to call on fresh installs (``create_all`` already made
@ -265,8 +265,8 @@ class BitableDB:
def __init__(self, database_url: str | None = None) -> None: def __init__(self, database_url: str | None = None) -> None:
self._database_url = database_url or _resolve_database_url() self._database_url = database_url or _resolve_database_url()
self._engine: Any = None self._engine: object | None = None
self._session_factory: Any = None self._session_factory: object | None = None
self._initialized = False self._initialized = False
self._init_lock = asyncio.Lock() self._init_lock = asyncio.Lock()
@ -275,11 +275,11 @@ class BitableDB:
return self._database_url return self._database_url
@property @property
def engine(self) -> Any: def engine(self) -> object | None:
return self._engine return self._engine
@property @property
def session_factory(self) -> Any: def session_factory(self) -> object | None:
return self._session_factory return self._session_factory
@property @property
@ -365,7 +365,12 @@ class BitableDB:
await self.init() await self.init()
return self return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close() await self.close()

View File

@ -12,12 +12,17 @@ based on the calling context — see :mod:`agentkit.bitable.formula.engine`.
from __future__ import annotations from __future__ import annotations
from typing import Any, Callable from typing import Callable, TypeAlias
# A formula evaluates to a scalar primitive: text, number, or nothing.
# bool is intentionally excluded — comparisons live in the parser layer
# and never reach the function registry.
FormulaResult: TypeAlias = str | int | float | None
# ── Aggregate functions (operate on lists) ──────────────── # ── Aggregate functions (operate on lists) ────────────────
def _sum(values: list[Any]) -> float | int: def _sum(values: list[FormulaResult]) -> float | int:
"""Sum of numeric values, ignoring None/empty.""" """Sum of numeric values, ignoring None/empty."""
total = 0 total = 0
for v in values: for v in values:
@ -27,7 +32,7 @@ def _sum(values: list[Any]) -> float | int:
return total return total
def _avg(values: list[Any]) -> float: def _avg(values: list[FormulaResult]) -> float:
"""Average of numeric values, ignoring None/empty.""" """Average of numeric values, ignoring None/empty."""
nums = [v for v in values if v is not None and v != ""] nums = [v for v in values if v is not None and v != ""]
if not nums: if not nums:
@ -35,12 +40,12 @@ def _avg(values: list[Any]) -> float:
return sum(nums) / len(nums) return sum(nums) / len(nums)
def _count(values: list[Any]) -> int: def _count(values: list[FormulaResult]) -> int:
"""Count of non-empty values.""" """Count of non-empty values."""
return sum(1 for v in values if v is not None and v != "") return sum(1 for v in values if v is not None and v != "")
def _min(values: list[Any]) -> Any: def _min(values: list[FormulaResult]) -> FormulaResult:
"""Minimum of numeric values, ignoring None/empty.""" """Minimum of numeric values, ignoring None/empty."""
nums = [v for v in values if v is not None and v != ""] nums = [v for v in values if v is not None and v != ""]
if not nums: if not nums:
@ -48,7 +53,7 @@ def _min(values: list[Any]) -> Any:
return min(nums) return min(nums)
def _max(values: list[Any]) -> Any: def _max(values: list[FormulaResult]) -> FormulaResult:
"""Maximum of numeric values, ignoring None/empty.""" """Maximum of numeric values, ignoring None/empty."""
nums = [v for v in values if v is not None and v != ""] nums = [v for v in values if v is not None and v != ""]
if not nums: if not nums:
@ -59,25 +64,29 @@ def _max(values: list[Any]) -> Any:
# ── Scalar functions ────────────────────────────────────── # ── Scalar functions ──────────────────────────────────────
def _abs(value: Any) -> Any: def _abs(value: FormulaResult) -> FormulaResult:
return abs(value) return abs(value)
def _round(value: Any, digits: int = 0) -> float: def _round(value: FormulaResult, digits: int = 0) -> float:
return round(value, digits) return round(value, digits)
def _if(condition: Any, true_val: Any, false_val: Any = None) -> Any: def _if(
condition: FormulaResult,
true_val: FormulaResult,
false_val: FormulaResult = None,
) -> FormulaResult:
return true_val if condition else false_val return true_val if condition else false_val
def _len(value: Any) -> int: def _len(value: FormulaResult) -> int:
if value is None: if value is None:
return 0 return 0
return len(str(value)) return len(str(value))
def _concat(*args: Any) -> str: def _concat(*args: FormulaResult) -> str:
"""Concatenate all arguments as strings.""" """Concatenate all arguments as strings."""
return "".join(str(a) for a in args if a is not None) return "".join(str(a) for a in args if a is not None)
@ -87,7 +96,7 @@ def _concat(*args: Any) -> str:
# Functions that aggregate a column (receive a list of all column values) # Functions that aggregate a column (receive a list of all column values)
AGGREGATE_FUNCTIONS: frozenset[str] = frozenset({"SUM", "AVG", "COUNT", "MIN", "MAX"}) AGGREGATE_FUNCTIONS: frozenset[str] = frozenset({"SUM", "AVG", "COUNT", "MIN", "MAX"})
FUNCTION_REGISTRY: dict[str, Callable[..., Any]] = { FUNCTION_REGISTRY: dict[str, Callable[..., FormulaResult]] = {
"SUM": _sum, "SUM": _sum,
"AVG": _avg, "AVG": _avg,
"COUNT": _count, "COUNT": _count,

View File

@ -25,9 +25,9 @@ from __future__ import annotations
import ast import ast
import re import re
from typing import Any from typing import Callable
from agentkit.bitable.formula.functions import FUNCTION_REGISTRY from agentkit.bitable.formula.functions import FUNCTION_REGISTRY, FormulaResult
# ── Exceptions ──────────────────────────────────────────── # ── Exceptions ────────────────────────────────────────────
@ -184,9 +184,9 @@ def parse_formula(
def evaluate_ast( def evaluate_ast(
tree: ast.Expression, tree: ast.Expression,
field_values: dict[str, Any], field_values: dict[str, FormulaResult | list[FormulaResult]],
functions: dict[str, Any], functions: dict[str, Callable[..., FormulaResult]],
) -> Any: ) -> FormulaResult:
"""Evaluate a parsed formula AST against field values and functions. """Evaluate a parsed formula AST against field values and functions.
This is NOT ``eval()`` it's a manual AST walker that only processes This is NOT ``eval()`` it's a manual AST walker that only processes
@ -204,7 +204,11 @@ def evaluate_ast(
return _eval_node(tree.body, field_values, functions) return _eval_node(tree.body, field_values, functions)
def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any]) -> Any: def _eval_node(
node: ast.AST,
fields: dict[str, FormulaResult | list[FormulaResult]],
functions: dict[str, Callable[..., FormulaResult]],
) -> FormulaResult:
"""Recursively evaluate an AST node.""" """Recursively evaluate an AST node."""
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
return node.value return node.value
@ -274,7 +278,7 @@ def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any])
raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}") raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}")
def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any: def _apply_binop(op: ast.AST, left: FormulaResult, right: FormulaResult) -> FormulaResult:
"""Apply a binary operator.""" """Apply a binary operator."""
if isinstance(op, ast.Add): if isinstance(op, ast.Add):
# String concat or numeric addition # String concat or numeric addition
@ -294,7 +298,7 @@ def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any:
raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}") raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}")
def _apply_compare(op: ast.AST, left: Any, right: Any) -> bool: def _apply_compare(op: ast.AST, left: FormulaResult, right: FormulaResult) -> bool:
"""Apply a comparison operator.""" """Apply a comparison operator."""
if isinstance(op, ast.Eq): if isinstance(op, ast.Eq):
return left == right return left == right

View File

@ -12,7 +12,7 @@ import logging
import os import os
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any from typing import TYPE_CHECKING, TypeAlias
from agentkit.bitable.db import BitableDB from agentkit.bitable.db import BitableDB
from agentkit.bitable.models import ( from agentkit.bitable.models import (
@ -29,13 +29,27 @@ from agentkit.bitable.models import (
) )
from agentkit.bitable.repository import BitableRepository from agentkit.bitable.repository import BitableRepository
if TYPE_CHECKING:
from typing import Protocol
class _RecalcWorker(Protocol):
"""Structural type for the recalc worker's cache-invalidation surface."""
def invalidate_engine(self, table_id: str) -> None: ...
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Record values are JSON scalars (text/number/none). Attachment/image fields
# store lists of dicts at runtime, but the common-case scalar shape is captured
# here for annotation clarity; fall back to dict[str, object] where lists occur.
BitableRecord: TypeAlias = dict[str, str | int | float | None]
class FieldDependencyError(Exception): class FieldDependencyError(Exception):
"""Raised when deleting a field that has dependencies (formula refs, PK, views).""" """Raised when deleting a field that has dependencies (formula refs, PK, views)."""
def __init__(self, message: str, dependencies: dict[str, Any]) -> None: def __init__(self, message: str, dependencies: dict[str, object]) -> None:
super().__init__(message) super().__init__(message)
self.dependencies = dependencies self.dependencies = dependencies
@ -52,13 +66,13 @@ class BitableService:
def __init__(self, db: BitableDB) -> None: def __init__(self, db: BitableDB) -> None:
self._db = db self._db = db
self._repo = BitableRepository(db) self._repo = BitableRepository(db)
self._recalc_worker: Any = None # RecalcWorker, set via set_recalc_worker self._recalc_worker: _RecalcWorker | None = None # set via set_recalc_worker
@property @property
def repo(self) -> BitableRepository: def repo(self) -> BitableRepository:
return self._repo return self._repo
def set_recalc_worker(self, worker: Any) -> None: def set_recalc_worker(self, worker: _RecalcWorker) -> None:
"""Register the long-lived RecalcWorker so field changes can invalidate its engine cache. """Register the long-lived RecalcWorker so field changes can invalidate its engine cache.
Called after both service and worker are constructed (breaks the Called after both service and worker are constructed (breaks the
@ -95,7 +109,7 @@ class BitableService:
async def list_files(self, owner_user_id: str | None = None) -> list[BitableFile]: async def list_files(self, owner_user_id: str | None = None) -> list[BitableFile]:
return await self._repo.list_files(owner_user_id=owner_user_id) return await self._repo.list_files(owner_user_id=owner_user_id)
async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None: async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None:
return await self._repo.update_file(file_id, **kwargs) return await self._repo.update_file(file_id, **kwargs)
async def delete_file(self, file_id: str) -> bool: async def delete_file(self, file_id: str) -> bool:
@ -162,7 +176,7 @@ class BitableService:
async def list_tables(self, owner_user_id: str | None = None) -> list[Table]: async def list_tables(self, owner_user_id: str | None = None) -> list[Table]:
return await self._repo.list_tables(owner_user_id=owner_user_id) return await self._repo.list_tables(owner_user_id=owner_user_id)
async def update_table(self, table_id: str, **kwargs: Any) -> Table | None: async def update_table(self, table_id: str, **kwargs: object) -> Table | None:
"""Update table attrs. Creates PK unique index if primary_key_field_id is set.""" """Update table attrs. Creates PK unique index if primary_key_field_id is set."""
table = await self._repo.update_table(table_id, **kwargs) table = await self._repo.update_table(table_id, **kwargs)
if table and kwargs.get("primary_key_field_id"): if table and kwargs.get("primary_key_field_id"):
@ -179,7 +193,7 @@ class BitableService:
table_id: str, table_id: str,
name: str, name: str,
field_type: FieldType, field_type: FieldType,
config: dict[str, Any] | None = None, config: dict[str, object] | None = None,
owner: FieldOwner = FieldOwner.user, owner: FieldOwner = FieldOwner.user,
) -> Field: ) -> Field:
"""Create a new field. U2 will add formula validation and DAG updates.""" """Create a new field. U2 will add formula validation and DAG updates."""
@ -201,7 +215,7 @@ class BitableService:
async def list_fields(self, table_id: str) -> list[Field]: async def list_fields(self, table_id: str) -> list[Field]:
return await self._repo.list_fields(table_id) return await self._repo.list_fields(table_id)
async def update_field(self, field_id: str, **kwargs: Any) -> Field | None: async def update_field(self, field_id: str, **kwargs: object) -> Field | None:
"""Update a field. U2 will add dependency checking.""" """Update a field. U2 will add dependency checking."""
field = await self._repo.update_field(field_id, **kwargs) field = await self._repo.update_field(field_id, **kwargs)
if field is not None: if field is not None:
@ -220,7 +234,7 @@ class BitableService:
return False return False
# Check dependencies # Check dependencies
deps: dict[str, Any] = {} deps: dict[str, object] = {}
# 1. Is it a primary key field? # 1. Is it a primary key field?
table = await self._repo.get_table(field.table_id) table = await self._repo.get_table(field.table_id)
@ -264,7 +278,7 @@ class BitableService:
async def create_record( async def create_record(
self, self,
table_id: str, table_id: str,
values: dict[str, Any] | None = None, values: BitableRecord | None = None,
actor_user_id: str | None = None, actor_user_id: str | None = None,
) -> Record: ) -> Record:
"""Create a new record. Triggers recalc for affected formula fields. """Create a new record. Triggers recalc for affected formula fields.
@ -291,7 +305,7 @@ class BitableService:
return record return record
async def create_records_batch( async def create_records_batch(
self, table_id: str, records_values: list[dict[str, Any]] self, table_id: str, records_values: list[BitableRecord]
) -> list[Record]: ) -> list[Record]:
"""Batch-create records (P2 #19). Triggers recalc for each record. """Batch-create records (P2 #19). Triggers recalc for each record.
@ -319,8 +333,8 @@ class BitableService:
async def list_records_filtered( async def list_records_filtered(
self, self,
table_id: str, table_id: str,
filters: list[dict[str, Any]] | None = None, filters: list[dict[str, object]] | None = None,
sorts: list[dict[str, Any]] | None = None, sorts: list[dict[str, object]] | None = None,
cursor: str | None = None, cursor: str | None = None,
limit: int = 50, limit: int = 50,
) -> tuple[list[Record], str | None]: ) -> tuple[list[Record], str | None]:
@ -345,7 +359,7 @@ class BitableService:
table_id, filters=filters, sorts=sorts, cursor=cursor, limit=limit table_id, filters=filters, sorts=sorts, cursor=cursor, limit=limit
) )
async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None: async def update_record_values(self, record_id: str, values: BitableRecord) -> Record | None:
"""Update a record's values (full replace). Triggers recalc for affected formulas.""" """Update a record's values (full replace). Triggers recalc for affected formulas."""
record = await self._repo.update_record_values(record_id, values) record = await self._repo.update_record_values(record_id, values)
if record is not None: if record is not None:
@ -431,9 +445,9 @@ class BitableService:
async def upsert_records( async def upsert_records(
self, self,
table_id: str, table_id: str,
records: list[dict[str, Any]], records: list[BitableRecord],
primary_key_field_id: str, primary_key_field_id: str,
) -> dict[str, Any]: ) -> dict[str, int]:
"""Upsert records by primary key using jsonb_set (KTD8). """Upsert records by primary key using jsonb_set (KTD8).
For each record: For each record:
@ -454,12 +468,12 @@ class BitableService:
agent_field_ids = {f.id for f in fields if f.owner == FieldOwner.agent} agent_field_ids = {f.id for f in fields if f.owner == FieldOwner.agent}
# Partition records into insert vs update lists, collecting PK values. # Partition records into insert vs update lists, collecting PK values.
to_insert: list[dict[str, Any]] = [] to_insert: list[BitableRecord] = []
to_update: list[tuple[dict[str, Any], str]] = [] # (values, existing_record_id) to_update: list[tuple[BitableRecord, str]] = [] # (values, existing_record_id)
skipped = 0 skipped = 0
# Collect all non-None PK values for batch lookup. # Collect all non-None PK values for batch lookup.
pk_values_by_str: dict[str, dict[str, Any]] = {} pk_values_by_str: dict[str, BitableRecord] = {}
for rec_values in records: for rec_values in records:
pk_value = rec_values.get(primary_key_field_id) pk_value = rec_values.get(primary_key_field_id)
if pk_value is None: if pk_value is None:
@ -504,7 +518,7 @@ class BitableService:
table_id: str, table_id: str,
name: str, name: str,
view_type: ViewType = ViewType.grid, view_type: ViewType = ViewType.grid,
config: dict[str, Any] | None = None, config: dict[str, object] | None = None,
) -> View: ) -> View:
return await self._repo.create_view( return await self._repo.create_view(
table_id=table_id, table_id=table_id,
@ -516,7 +530,7 @@ class BitableService:
async def list_views(self, table_id: str) -> list[View]: async def list_views(self, table_id: str) -> list[View]:
return await self._repo.list_views(table_id) return await self._repo.list_views(table_id)
async def update_view(self, view_id: str, **kwargs: Any) -> View | None: async def update_view(self, view_id: str, **kwargs: object) -> View | None:
return await self._repo.update_view(view_id, **kwargs) return await self._repo.update_view(view_id, **kwargs)
async def get_view(self, view_id: str) -> View | None: async def get_view(self, view_id: str) -> View | None:

View File

@ -32,14 +32,14 @@ class PipelineStateMemory:
"""In-memory pipeline state storage (testing / fallback).""" """In-memory pipeline state storage (testing / fallback)."""
def __init__(self) -> None: def __init__(self) -> None:
self._executions: dict[str, dict[str, Any]] = {} self._executions: dict[str, dict[str, object]] = {}
self._step_history: dict[str, list[dict[str, Any]]] = {} self._step_history: dict[str, list[dict[str, object]]] = {}
async def create_execution( async def create_execution(
self, self,
pipeline_name: str, pipeline_name: str,
steps: list[str], steps: list[str],
input_data: dict[str, Any] | None = None, input_data: dict[str, object] | None = None,
tenant_id: str | None = None, tenant_id: str | None = None,
) -> str: ) -> str:
execution_id = str(uuid.uuid4()) execution_id = str(uuid.uuid4())
@ -67,7 +67,7 @@ class PipelineStateMemory:
execution_id: str, execution_id: str,
step_name: str, step_name: str,
status: str, status: str,
output: dict[str, Any] | None = None, output: dict[str, object] | None = None,
error: str | None = None, error: str | None = None,
duration_ms: int | None = None, duration_ms: int | None = None,
) -> None: ) -> None:
@ -88,7 +88,7 @@ class PipelineStateMemory:
exec_state["error_message"] = error exec_state["error_message"] = error
# Record step history event # Record step history event
step_event: dict[str, Any] = { step_event: dict[str, object] = {
"id": str(uuid.uuid4()), "id": str(uuid.uuid4()),
"execution_id": execution_id, "execution_id": execution_id,
"step_name": step_name, "step_name": step_name,
@ -97,14 +97,16 @@ class PipelineStateMemory:
"error_message": error, "error_message": error,
"duration_ms": duration_ms, "duration_ms": duration_ms,
"started_at": datetime.now(timezone.utc).isoformat(), "started_at": datetime.now(timezone.utc).isoformat(),
"completed_at": datetime.now(timezone.utc).isoformat() if status in ("completed", "failed") else None, "completed_at": datetime.now(timezone.utc).isoformat()
if status in ("completed", "failed")
else None,
} }
self._step_history[execution_id].append(step_event) self._step_history[execution_id].append(step_event)
async def complete_execution( async def complete_execution(
self, self,
execution_id: str, execution_id: str,
final_output: dict[str, Any] | None = None, final_output: dict[str, object] | None = None,
) -> None: ) -> None:
exec_state = self._executions.get(execution_id) exec_state = self._executions.get(execution_id)
if exec_state is None: if exec_state is None:
@ -130,7 +132,7 @@ class PipelineStateMemory:
exec_state["updated_at"] = now exec_state["updated_at"] = now
exec_state["completed_at"] = now exec_state["completed_at"] = now
async def get_execution(self, execution_id: str) -> dict[str, Any] | None: async def get_execution(self, execution_id: str) -> dict[str, object] | None:
return self._executions.get(execution_id) return self._executions.get(execution_id)
async def list_executions( async def list_executions(
@ -138,17 +140,17 @@ class PipelineStateMemory:
status: str | None = None, status: str | None = None,
limit: int = 50, limit: int = 50,
offset: int = 0, offset: int = 0,
) -> list[dict[str, Any]]: ) -> list[dict[str, object]]:
results = list(self._executions.values()) results = list(self._executions.values())
if status: if status:
results = [e for e in results if e.get("status") == status] results = [e for e in results if e.get("status") == status]
results.sort(key=lambda e: e.get("created_at", ""), reverse=True) results.sort(key=lambda e: e.get("created_at", ""), reverse=True)
return results[offset : offset + limit] return results[offset : offset + limit]
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]: async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
return self._step_history.get(execution_id, []) return self._step_history.get(execution_id, [])
def get_execution_sync(self, execution_id: str) -> dict[str, Any] | None: def get_execution_sync(self, execution_id: str) -> dict[str, object] | None:
"""Synchronous accessor for execution state (used by Redis dual-write).""" """Synchronous accessor for execution state (used by Redis dual-write)."""
return self._executions.get(execution_id) return self._executions.get(execution_id)
@ -165,7 +167,7 @@ class PipelineStateRedis:
def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None: def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None:
self._redis_url = redis_url self._redis_url = redis_url
self._redis: Any = None self._redis: object | None = None
self._fallback = PipelineStateMemory() self._fallback = PipelineStateMemory()
self._use_fallback = False self._use_fallback = False
self._fallback_since: float | None = None self._fallback_since: float | None = None
@ -181,8 +183,8 @@ class PipelineStateRedis:
return self._redis return self._redis
async def _safe_redis_call( async def _safe_redis_call(
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object
) -> Any: ) -> object | None:
"""Execute a Redis call, falling back to memory on failure. """Execute a Redis call, falling back to memory on failure.
After falling back, periodically retries Redis to enable recovery. After falling back, periodically retries Redis to enable recovery.
@ -192,6 +194,7 @@ class PipelineStateRedis:
# Check if enough time has passed to attempt recovery # Check if enough time has passed to attempt recovery
if self._fallback_since is not None: if self._fallback_since is not None:
import time as _time import time as _time
elapsed = _time.monotonic() - self._fallback_since elapsed = _time.monotonic() - self._fallback_since
if elapsed >= self._RECOVERY_COOLDOWN_SECONDS: if elapsed >= self._RECOVERY_COOLDOWN_SECONDS:
try: try:
@ -218,6 +221,7 @@ class PipelineStateRedis:
logger.warning(f"Redis operation failed, switching to memory fallback: {exc}") logger.warning(f"Redis operation failed, switching to memory fallback: {exc}")
self._use_fallback = True self._use_fallback = True
import time as _time import time as _time
self._fallback_since = _time.monotonic() self._fallback_since = _time.monotonic()
self._redis = None self._redis = None
return None return None
@ -229,7 +233,7 @@ class PipelineStateRedis:
self, self,
pipeline_name: str, pipeline_name: str,
steps: list[str], steps: list[str],
input_data: dict[str, Any] | None = None, input_data: dict[str, object] | None = None,
tenant_id: str | None = None, tenant_id: str | None = None,
) -> str: ) -> str:
# Always write to fallback first for consistency # Always write to fallback first for consistency
@ -238,7 +242,7 @@ class PipelineStateRedis:
) )
# Try Redis # Try Redis
async def _redis_create(redis: Any) -> None: async def _redis_create(redis: object) -> None:
state = self._fallback.get_execution_sync(execution_id) state = self._fallback.get_execution_sync(execution_id)
score = datetime.now(timezone.utc).timestamp() score = datetime.now(timezone.utc).timestamp()
pipe = redis.pipeline() pipe = redis.pipeline()
@ -254,13 +258,15 @@ class PipelineStateRedis:
execution_id: str, execution_id: str,
step_name: str, step_name: str,
status: str, status: str,
output: dict[str, Any] | None = None, output: dict[str, object] | None = None,
error: str | None = None, error: str | None = None,
duration_ms: int | None = None, duration_ms: int | None = None,
) -> None: ) -> None:
await self._fallback.update_step(execution_id, step_name, status, output, error, duration_ms) await self._fallback.update_step(
execution_id, step_name, status, output, error, duration_ms
)
async def _redis_update(redis: Any) -> None: async def _redis_update(redis: object) -> None:
state = self._fallback.get_execution_sync(execution_id) state = self._fallback.get_execution_sync(execution_id)
if state is None: if state is None:
return return
@ -271,11 +277,11 @@ class PipelineStateRedis:
async def complete_execution( async def complete_execution(
self, self,
execution_id: str, execution_id: str,
final_output: dict[str, Any] | None = None, final_output: dict[str, object] | None = None,
) -> None: ) -> None:
await self._fallback.complete_execution(execution_id, final_output) await self._fallback.complete_execution(execution_id, final_output)
async def _redis_complete(redis: Any) -> None: async def _redis_complete(redis: object) -> None:
state = self._fallback.get_execution_sync(execution_id) state = self._fallback.get_execution_sync(execution_id)
if state is None: if state is None:
return return
@ -291,7 +297,7 @@ class PipelineStateRedis:
) -> None: ) -> None:
await self._fallback.fail_execution(execution_id, step_name, error) await self._fallback.fail_execution(execution_id, step_name, error)
async def _redis_fail(redis: Any) -> None: async def _redis_fail(redis: object) -> None:
state = self._fallback.get_execution_sync(execution_id) state = self._fallback.get_execution_sync(execution_id)
if state is None: if state is None:
return return
@ -299,7 +305,7 @@ class PipelineStateRedis:
await self._safe_redis_call(_redis_fail) await self._safe_redis_call(_redis_fail)
async def get_execution(self, execution_id: str) -> dict[str, Any] | None: async def get_execution(self, execution_id: str) -> dict[str, object] | None:
# Try Redis first # Try Redis first
if not self._use_fallback: if not self._use_fallback:
try: try:
@ -318,7 +324,7 @@ class PipelineStateRedis:
status: str | None = None, status: str | None = None,
limit: int = 50, limit: int = 50,
offset: int = 0, offset: int = 0,
) -> list[dict[str, Any]]: ) -> list[dict[str, object]]:
# Try Redis sorted set for efficient listing # Try Redis sorted set for efficient listing
if not self._use_fallback: if not self._use_fallback:
try: try:
@ -341,7 +347,7 @@ class PipelineStateRedis:
return await self._fallback.list_executions(status, limit, offset) return await self._fallback.list_executions(status, limit, offset)
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]: async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
return await self._fallback.get_step_history(execution_id) return await self._fallback.get_step_history(execution_id)
async def health_check(self) -> bool: async def health_check(self) -> bool:
@ -364,20 +370,18 @@ class PipelineStatePG:
If session_factory is None, all methods are no-op. If session_factory is None, all methods are no-op.
""" """
def __init__(self, session_factory: Any = None) -> None: def __init__(self, session_factory: object | None = None) -> None:
self._session_factory = session_factory self._session_factory = session_factory
@property @property
def enabled(self) -> bool: def enabled(self) -> bool:
return self._session_factory is not None return self._session_factory is not None
async def persist_execution(self, state: dict[str, Any]) -> None: async def persist_execution(self, state: dict[str, object]) -> None:
"""Write a completed/failed execution to PostgreSQL.""" """Write a completed/failed execution to PostgreSQL."""
if not self.enabled: if not self.enabled:
return return
try: try:
from sqlalchemy.ext.asyncio import AsyncSession
async with self._session_factory() as session: async with self._session_factory() as session:
model = PipelineExecutionModel( model = PipelineExecutionModel(
id=state["id"], id=state["id"],
@ -390,18 +394,22 @@ class PipelineStatePG:
final_output=state.get("final_output"), final_output=state.get("final_output"),
error_message=state.get("error_message"), error_message=state.get("error_message"),
tenant_id=state.get("tenant_id"), tenant_id=state.get("tenant_id"),
created_at=datetime.fromisoformat(state["created_at"]) if state.get("created_at") else None, created_at=datetime.fromisoformat(state["created_at"])
updated_at=datetime.fromisoformat(state["updated_at"]) if state.get("updated_at") else None, if state.get("created_at")
completed_at=datetime.fromisoformat(state["completed_at"]) if state.get("completed_at") else None, else None,
updated_at=datetime.fromisoformat(state["updated_at"])
if state.get("updated_at")
else None,
completed_at=datetime.fromisoformat(state["completed_at"])
if state.get("completed_at")
else None,
) )
await session.merge(model) await session.merge(model)
await session.commit() await session.commit()
except Exception as exc: except Exception as exc:
logger.error(f"Failed to persist execution to PG: {exc}") logger.error(f"Failed to persist execution to PG: {exc}")
async def persist_step_history( async def persist_step_history(self, execution_id: str, steps: list[dict[str, object]]) -> None:
self, execution_id: str, steps: list[dict[str, Any]]
) -> None:
"""Write step history to PostgreSQL.""" """Write step history to PostgreSQL."""
if not self.enabled: if not self.enabled:
return return
@ -419,8 +427,12 @@ class PipelineStatePG:
error_message=step.get("error_message"), error_message=step.get("error_message"),
duration_ms=step.get("duration_ms"), duration_ms=step.get("duration_ms"),
retry_attempt=step.get("retry_attempt", 0), retry_attempt=step.get("retry_attempt", 0),
started_at=datetime.fromisoformat(step["started_at"]) if step.get("started_at") else None, started_at=datetime.fromisoformat(step["started_at"])
completed_at=datetime.fromisoformat(step["completed_at"]) if step.get("completed_at") else None, if step.get("started_at")
else None,
completed_at=datetime.fromisoformat(step["completed_at"])
if step.get("completed_at")
else None,
) )
await session.merge(model) await session.merge(model)
await session.commit() await session.commit()
@ -433,7 +445,7 @@ class PipelineStatePG:
status: str | None = None, status: str | None = None,
limit: int = 50, limit: int = 50,
offset: int = 0, offset: int = 0,
) -> list[dict[str, Any]]: ) -> list[dict[str, object]]:
"""Query historical executions from PostgreSQL.""" """Query historical executions from PostgreSQL."""
if not self.enabled: if not self.enabled:
return [] return []
@ -445,9 +457,7 @@ class PipelineStatePG:
PipelineExecutionModel.created_at.desc() PipelineExecutionModel.created_at.desc()
) )
if pipeline_name: if pipeline_name:
stmt = stmt.where( stmt = stmt.where(PipelineExecutionModel.pipeline_name == pipeline_name)
PipelineExecutionModel.pipeline_name == pipeline_name
)
if status: if status:
stmt = stmt.where(PipelineExecutionModel.status == status) stmt = stmt.where(PipelineExecutionModel.status == status)
stmt = stmt.offset(offset).limit(limit) stmt = stmt.offset(offset).limit(limit)
@ -458,7 +468,7 @@ class PipelineStatePG:
logger.error(f"Failed to query executions from PG: {exc}") logger.error(f"Failed to query executions from PG: {exc}")
return [] return []
async def get_execution(self, execution_id: str) -> dict[str, Any] | None: async def get_execution(self, execution_id: str) -> dict[str, object] | None:
"""Get a single execution from PostgreSQL (for Redis miss fallback).""" """Get a single execution from PostgreSQL (for Redis miss fallback)."""
if not self.enabled: if not self.enabled:
return None return None
@ -479,7 +489,7 @@ class PipelineStatePG:
return None return None
@staticmethod @staticmethod
def _model_to_dict(model: PipelineExecutionModel) -> dict[str, Any]: def _model_to_dict(model: PipelineExecutionModel) -> dict[str, object]:
return { return {
"id": model.id, "id": model.id,
"pipeline_name": model.pipeline_name, "pipeline_name": model.pipeline_name,
@ -509,7 +519,7 @@ class PipelineStateManager:
def __init__( def __init__(
self, self,
redis_url: str | None = None, redis_url: str | None = None,
session_factory: Any = None, session_factory: object | None = None,
) -> None: ) -> None:
if redis_url: if redis_url:
self._hot = PipelineStateRedis(redis_url=redis_url) self._hot = PipelineStateRedis(redis_url=redis_url)
@ -529,7 +539,7 @@ class PipelineStateManager:
self, self,
pipeline_name: str, pipeline_name: str,
steps: list[str], steps: list[str],
input_data: dict[str, Any] | None = None, input_data: dict[str, object] | None = None,
tenant_id: str | None = None, tenant_id: str | None = None,
) -> str: ) -> str:
return await self._hot.create_execution(pipeline_name, steps, input_data, tenant_id) return await self._hot.create_execution(pipeline_name, steps, input_data, tenant_id)
@ -539,7 +549,7 @@ class PipelineStateManager:
execution_id: str, execution_id: str,
step_name: str, step_name: str,
status: str, status: str,
output: dict[str, Any] | None = None, output: dict[str, object] | None = None,
error: str | None = None, error: str | None = None,
duration_ms: int | None = None, duration_ms: int | None = None,
) -> None: ) -> None:
@ -548,7 +558,7 @@ class PipelineStateManager:
async def complete_execution( async def complete_execution(
self, self,
execution_id: str, execution_id: str,
final_output: dict[str, Any] | None = None, final_output: dict[str, object] | None = None,
) -> None: ) -> None:
await self._hot.complete_execution(execution_id, final_output) await self._hot.complete_execution(execution_id, final_output)
# Persist to PG # Persist to PG
@ -574,7 +584,7 @@ class PipelineStateManager:
if step_history: if step_history:
await self._cold.persist_step_history(execution_id, step_history) await self._cold.persist_step_history(execution_id, step_history)
async def get_execution(self, execution_id: str) -> dict[str, Any] | None: async def get_execution(self, execution_id: str) -> dict[str, object] | None:
# Redis / memory first # Redis / memory first
state = await self._hot.get_execution(execution_id) state = await self._hot.get_execution(execution_id)
if state is not None: if state is not None:
@ -587,7 +597,7 @@ class PipelineStateManager:
status: str | None = None, status: str | None = None,
limit: int = 50, limit: int = 50,
offset: int = 0, offset: int = 0,
) -> list[dict[str, Any]]: ) -> list[dict[str, object]]:
# Hot store for recent executions # Hot store for recent executions
results = await self._hot.list_executions(status, limit, offset) results = await self._hot.list_executions(status, limit, offset)
if results: if results:
@ -595,7 +605,7 @@ class PipelineStateManager:
# Cold store for historical queries # Cold store for historical queries
return await self._cold.query_executions(status=status, limit=limit, offset=offset) return await self._cold.query_executions(status=status, limit=limit, offset=offset)
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]: async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
return await self._hot.get_step_history(execution_id) return await self._hot.get_step_history(execution_id)
async def health_check(self) -> dict[str, bool]: async def health_check(self) -> dict[str, bool]:

View File

@ -15,10 +15,14 @@ import time
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import TypeAlias
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Scalar session state values (text/number/flag/none). Screen-state dicts use
# dict[str, object] because they also hold tuple cursor positions.
SessionState: TypeAlias = dict[str, str | int | bool | None]
@dataclass @dataclass
class ScreenInfo: class ScreenInfo:
@ -37,7 +41,7 @@ class ActionResult:
output: str = "" output: str = ""
screenshot_base64: str = "" screenshot_base64: str = ""
error: str = "" error: str = ""
metadata: dict[str, Any] = field(default_factory=dict) metadata: dict[str, object] = field(default_factory=dict)
class ComputerUseSession(ABC): class ComputerUseSession(ABC):
@ -56,7 +60,7 @@ class ComputerUseSession(ABC):
self.session_id = session_id or str(uuid.uuid4()) self.session_id = session_id or str(uuid.uuid4())
self.screen = ScreenInfo(width=screen_width, height=screen_height) self.screen = ScreenInfo(width=screen_width, height=screen_height)
self._started = False self._started = False
self._action_history: list[dict[str, Any]] = [] self._action_history: list[dict[str, object]] = []
@property @property
def is_started(self) -> bool: def is_started(self) -> bool:
@ -82,7 +86,7 @@ class ComputerUseSession(ABC):
... ...
@abstractmethod @abstractmethod
async def execute_action(self, action: str, **params: Any) -> ActionResult: async def execute_action(self, action: str, **params: object) -> ActionResult:
"""执行 UI 操作 """执行 UI 操作
Args: Args:
@ -94,18 +98,20 @@ class ComputerUseSession(ABC):
""" """
... ...
def record_action(self, action: str, params: dict[str, Any], result: ActionResult) -> None: def record_action(self, action: str, params: dict[str, object], result: ActionResult) -> None:
"""记录操作历史""" """记录操作历史"""
self._action_history.append({ self._action_history.append(
"timestamp": time.time(), {
"action": action, "timestamp": time.time(),
"params": params, "action": action,
"success": result.success, "params": params,
"output": result.output[:200] if result.output else "", "success": result.success,
}) "output": result.output[:200] if result.output else "",
}
)
@property @property
def action_history(self) -> list[dict[str, Any]]: def action_history(self) -> list[dict[str, object]]:
"""获取操作历史(副本)""" """获取操作历史(副本)"""
return list(self._action_history) return list(self._action_history)
@ -134,7 +140,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
screen_width=screen_width, screen_width=screen_width,
screen_height=screen_height, screen_height=screen_height,
) )
self._screen_state: dict[str, Any] = { self._screen_state: dict[str, object] = {
"focused_element": None, "focused_element": None,
"cursor_position": (0, 0), "cursor_position": (0, 0),
"typed_text": "", "typed_text": "",
@ -173,7 +179,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
metadata={"screen_state": dict(self._screen_state)}, metadata={"screen_state": dict(self._screen_state)},
) )
async def execute_action(self, action: str, **params: Any) -> ActionResult: async def execute_action(self, action: str, **params: object) -> ActionResult:
"""模拟执行 UI 操作""" """模拟执行 UI 操作"""
if not self._started: if not self._started:
return ActionResult( return ActionResult(
@ -186,7 +192,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
self.record_action(action, params, result) self.record_action(action, params, result)
return result return result
def _simulate_action(self, action: str, **params: Any) -> ActionResult: def _simulate_action(self, action: str, **params: object) -> ActionResult:
"""模拟具体操作""" """模拟具体操作"""
if action == "click": if action == "click":
x = params.get("x", 0) x = params.get("x", 0)
@ -270,18 +276,78 @@ class LocalComputerUseSession(ComputerUseSession):
screen_width=screen_width, screen_width=screen_width,
screen_height=screen_height, screen_height=screen_height,
) )
self._pyautogui: Any = None self._pyautogui: object | None = None
# Allowed keys for the `key` action — prevents OS-level shortcut abuse # Allowed keys for the `key` action — prevents OS-level shortcut abuse
_ALLOWED_KEYS: set[str] = { _ALLOWED_KEYS: set[str] = {
"enter", "return", "tab", "backspace", "delete", "home", "end", "enter",
"up", "down", "left", "right", "pageup", "pagedown", "return",
"space", "escape", "insert", "tab",
"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "f10", "f11", "f12", "backspace",
"shift", "ctrl", "alt", "command", "delete",
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "home",
"n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "end",
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "up",
"down",
"left",
"right",
"pageup",
"pagedown",
"space",
"escape",
"insert",
"f1",
"f2",
"f3",
"f4",
"f5",
"f6",
"f7",
"f8",
"f9",
"f10",
"f11",
"f12",
"shift",
"ctrl",
"alt",
"command",
"a",
"b",
"c",
"d",
"e",
"f",
"g",
"h",
"i",
"j",
"k",
"l",
"m",
"n",
"o",
"p",
"q",
"r",
"s",
"t",
"u",
"v",
"w",
"x",
"y",
"z",
"0",
"1",
"2",
"3",
"4",
"5",
"6",
"7",
"8",
"9",
} }
_ALLOWED_BUTTONS: set[str] = {"left", "right", "middle"} _ALLOWED_BUTTONS: set[str] = {"left", "right", "middle"}
_MAX_TEXT_LENGTH: int = 1000 _MAX_TEXT_LENGTH: int = 1000
@ -291,6 +357,7 @@ class LocalComputerUseSession(ComputerUseSession):
"""启动本地桌面会话""" """启动本地桌面会话"""
try: try:
import pyautogui import pyautogui
self._pyautogui = pyautogui self._pyautogui = pyautogui
pyautogui.FAILSAFE = True pyautogui.FAILSAFE = True
pyautogui.PAUSE = 0.1 pyautogui.PAUSE = 0.1
@ -327,7 +394,7 @@ class LocalComputerUseSession(ComputerUseSession):
except Exception as e: except Exception as e:
return ActionResult(success=False, action="screenshot", error=str(e)) return ActionResult(success=False, action="screenshot", error=str(e))
async def execute_action(self, action: str, **params: Any) -> ActionResult: async def execute_action(self, action: str, **params: object) -> ActionResult:
"""在本地桌面执行 UI 操作""" """在本地桌面执行 UI 操作"""
if not self._started: if not self._started:
return ActionResult(success=False, action=action, error="Session not started") return ActionResult(success=False, action=action, error="Session not started")
@ -343,7 +410,7 @@ class LocalComputerUseSession(ComputerUseSession):
"""Check if coordinates are within screen bounds.""" """Check if coordinates are within screen bounds."""
return 0 <= x <= self.screen_width and 0 <= y <= self.screen_height return 0 <= x <= self.screen_width and 0 <= y <= self.screen_height
async def _execute_local_action(self, action: str, **params: Any) -> ActionResult: async def _execute_local_action(self, action: str, **params: object) -> ActionResult:
"""Execute a local UI action with input validation.""" """Execute a local UI action with input validation."""
pg = self._pyautogui pg = self._pyautogui
@ -351,16 +418,24 @@ class LocalComputerUseSession(ComputerUseSession):
x, y = params.get("x", 0), params.get("y", 0) x, y = params.get("x", 0), params.get("y", 0)
button = params.get("button", "left") button = params.get("button", "left")
if button not in self._ALLOWED_BUTTONS: if button not in self._ALLOWED_BUTTONS:
return ActionResult(success=False, action="click", error=f"Invalid button: {button}") return ActionResult(
success=False, action="click", error=f"Invalid button: {button}"
)
if not self._validate_coordinates(x, y): if not self._validate_coordinates(x, y):
return ActionResult(success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})") return ActionResult(
success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})"
)
pg.click(x, y, button=button) pg.click(x, y, button=button)
return ActionResult(success=True, action="click", output=f"Clicked at ({x}, {y})") return ActionResult(success=True, action="click", output=f"Clicked at ({x}, {y})")
if action == "type": if action == "type":
text = params.get("text", "") text = params.get("text", "")
if len(text) > self._MAX_TEXT_LENGTH: if len(text) > self._MAX_TEXT_LENGTH:
return ActionResult(success=False, action="type", error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}") return ActionResult(
success=False,
action="type",
error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}",
)
pg.write(text) pg.write(text)
return ActionResult(success=True, action="type", output=f"Typed: {text[:50]}") return ActionResult(success=True, action="type", output=f"Typed: {text[:50]}")
@ -369,16 +444,22 @@ class LocalComputerUseSession(ComputerUseSession):
amount = params.get("amount", 3) amount = params.get("amount", 3)
clicks = amount if direction == "down" else -amount clicks = amount if direction == "down" else -amount
pg.scroll(clicks) pg.scroll(clicks)
return ActionResult(success=True, action="scroll", output=f"Scrolled {direction} by {amount}") return ActionResult(
success=True, action="scroll", output=f"Scrolled {direction} by {amount}"
)
if action == "drag": if action == "drag":
sx, sy = params.get("start_x", 0), params.get("start_y", 0) sx, sy = params.get("start_x", 0), params.get("start_y", 0)
ex, ey = params.get("end_x", 0), params.get("end_y", 0) ex, ey = params.get("end_x", 0), params.get("end_y", 0)
if not (self._validate_coordinates(sx, sy) and self._validate_coordinates(ex, ey)): if not (self._validate_coordinates(sx, sy) and self._validate_coordinates(ex, ey)):
return ActionResult(success=False, action="drag", error="Drag coordinates out of bounds") return ActionResult(
success=False, action="drag", error="Drag coordinates out of bounds"
)
pg.moveTo(sx, sy) pg.moveTo(sx, sy)
pg.dragTo(ex, ey, duration=0.5) pg.dragTo(ex, ey, duration=0.5)
return ActionResult(success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})") return ActionResult(
success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})"
)
if action == "key": if action == "key":
key_name = params.get("key_name", "") key_name = params.get("key_name", "")
@ -487,7 +568,7 @@ class DockerComputerUseSession(ComputerUseSession):
screenshot_base64="", screenshot_base64="",
) )
async def execute_action(self, action: str, **params: Any) -> ActionResult: async def execute_action(self, action: str, **params: object) -> ActionResult:
"""在 Docker 虚拟桌面执行操作 """在 Docker 虚拟桌面执行操作
Stub: 实际实现需要通过 Anthropic Computer Use API Stub: 实际实现需要通过 Anthropic Computer Use API
@ -527,7 +608,7 @@ class ComputerUseSessionManager:
def get_or_create( def get_or_create(
self, self,
session_id: str | None = None, session_id: str | None = None,
**kwargs: Any, **kwargs: object,
) -> ComputerUseSession: ) -> ComputerUseSession:
"""获取或创建会话""" """获取或创建会话"""
if session_id and session_id in self._sessions: if session_id and session_id in self._sessions: