refactor(bitable,tools): replace Any with concrete types + Protocol (U4)
BitableRecord/FormulaResult/SessionState TypeAlias replace dict[str, Any]; _redis/_engine/_session_factory typed as object | None with TYPE_CHECKING Protocol (_RedisLike, _RecalcWorker); Coroutine[Any, Any, Any] retained as legitimate type param. Baseline 40 : Any occurrences -> 0 across 6 in-scope files (target <=5). Deferred: repository.py/recalc_worker.py/ingestion/* (10 occurrences, separate PR). ruff clean; 367 passed + 116 skipped (bitable + pipeline_state + tools).
This commit is contained in:
parent
be5c4e09f8
commit
1033346913
|
|
@ -18,7 +18,7 @@ import logging
|
|||
import os
|
||||
import uuid as _uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from types import TracebackType
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
|
|
@ -191,7 +191,7 @@ class MetaModel(BitableBase):
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _apply_v2_migration(conn: Any) -> None:
|
||||
async def _apply_v2_migration(conn: object) -> None:
|
||||
"""V2 migration: create ``bitable_files`` table + add ``file_id`` to tables.
|
||||
|
||||
Idempotent — safe to call on fresh installs (``create_all`` already made
|
||||
|
|
@ -265,8 +265,8 @@ class BitableDB:
|
|||
|
||||
def __init__(self, database_url: str | None = None) -> None:
|
||||
self._database_url = database_url or _resolve_database_url()
|
||||
self._engine: Any = None
|
||||
self._session_factory: Any = None
|
||||
self._engine: object | None = None
|
||||
self._session_factory: object | None = None
|
||||
self._initialized = False
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
|
|
@ -275,11 +275,11 @@ class BitableDB:
|
|||
return self._database_url
|
||||
|
||||
@property
|
||||
def engine(self) -> Any:
|
||||
def engine(self) -> object | None:
|
||||
return self._engine
|
||||
|
||||
@property
|
||||
def session_factory(self) -> Any:
|
||||
def session_factory(self) -> object | None:
|
||||
return self._session_factory
|
||||
|
||||
@property
|
||||
|
|
@ -365,7 +365,12 @@ class BitableDB:
|
|||
await self.init()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,12 +12,17 @@ based on the calling context — see :mod:`agentkit.bitable.formula.engine`.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
from typing import Callable, TypeAlias
|
||||
|
||||
# A formula evaluates to a scalar primitive: text, number, or nothing.
|
||||
# bool is intentionally excluded — comparisons live in the parser layer
|
||||
# and never reach the function registry.
|
||||
FormulaResult: TypeAlias = str | int | float | None
|
||||
|
||||
# ── Aggregate functions (operate on lists) ────────────────
|
||||
|
||||
|
||||
def _sum(values: list[Any]) -> float | int:
|
||||
def _sum(values: list[FormulaResult]) -> float | int:
|
||||
"""Sum of numeric values, ignoring None/empty."""
|
||||
total = 0
|
||||
for v in values:
|
||||
|
|
@ -27,7 +32,7 @@ def _sum(values: list[Any]) -> float | int:
|
|||
return total
|
||||
|
||||
|
||||
def _avg(values: list[Any]) -> float:
|
||||
def _avg(values: list[FormulaResult]) -> float:
|
||||
"""Average of numeric values, ignoring None/empty."""
|
||||
nums = [v for v in values if v is not None and v != ""]
|
||||
if not nums:
|
||||
|
|
@ -35,12 +40,12 @@ def _avg(values: list[Any]) -> float:
|
|||
return sum(nums) / len(nums)
|
||||
|
||||
|
||||
def _count(values: list[Any]) -> int:
|
||||
def _count(values: list[FormulaResult]) -> int:
|
||||
"""Count of non-empty values."""
|
||||
return sum(1 for v in values if v is not None and v != "")
|
||||
|
||||
|
||||
def _min(values: list[Any]) -> Any:
|
||||
def _min(values: list[FormulaResult]) -> FormulaResult:
|
||||
"""Minimum of numeric values, ignoring None/empty."""
|
||||
nums = [v for v in values if v is not None and v != ""]
|
||||
if not nums:
|
||||
|
|
@ -48,7 +53,7 @@ def _min(values: list[Any]) -> Any:
|
|||
return min(nums)
|
||||
|
||||
|
||||
def _max(values: list[Any]) -> Any:
|
||||
def _max(values: list[FormulaResult]) -> FormulaResult:
|
||||
"""Maximum of numeric values, ignoring None/empty."""
|
||||
nums = [v for v in values if v is not None and v != ""]
|
||||
if not nums:
|
||||
|
|
@ -59,25 +64,29 @@ def _max(values: list[Any]) -> Any:
|
|||
# ── Scalar functions ──────────────────────────────────────
|
||||
|
||||
|
||||
def _abs(value: Any) -> Any:
|
||||
def _abs(value: FormulaResult) -> FormulaResult:
|
||||
return abs(value)
|
||||
|
||||
|
||||
def _round(value: Any, digits: int = 0) -> float:
|
||||
def _round(value: FormulaResult, digits: int = 0) -> float:
|
||||
return round(value, digits)
|
||||
|
||||
|
||||
def _if(condition: Any, true_val: Any, false_val: Any = None) -> Any:
|
||||
def _if(
|
||||
condition: FormulaResult,
|
||||
true_val: FormulaResult,
|
||||
false_val: FormulaResult = None,
|
||||
) -> FormulaResult:
|
||||
return true_val if condition else false_val
|
||||
|
||||
|
||||
def _len(value: Any) -> int:
|
||||
def _len(value: FormulaResult) -> int:
|
||||
if value is None:
|
||||
return 0
|
||||
return len(str(value))
|
||||
|
||||
|
||||
def _concat(*args: Any) -> str:
|
||||
def _concat(*args: FormulaResult) -> str:
|
||||
"""Concatenate all arguments as strings."""
|
||||
return "".join(str(a) for a in args if a is not None)
|
||||
|
||||
|
|
@ -87,7 +96,7 @@ def _concat(*args: Any) -> str:
|
|||
# Functions that aggregate a column (receive a list of all column values)
|
||||
AGGREGATE_FUNCTIONS: frozenset[str] = frozenset({"SUM", "AVG", "COUNT", "MIN", "MAX"})
|
||||
|
||||
FUNCTION_REGISTRY: dict[str, Callable[..., Any]] = {
|
||||
FUNCTION_REGISTRY: dict[str, Callable[..., FormulaResult]] = {
|
||||
"SUM": _sum,
|
||||
"AVG": _avg,
|
||||
"COUNT": _count,
|
||||
|
|
|
|||
|
|
@ -25,9 +25,9 @@ from __future__ import annotations
|
|||
|
||||
import ast
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
|
||||
from agentkit.bitable.formula.functions import FUNCTION_REGISTRY
|
||||
from agentkit.bitable.formula.functions import FUNCTION_REGISTRY, FormulaResult
|
||||
|
||||
# ── Exceptions ────────────────────────────────────────────
|
||||
|
||||
|
|
@ -184,9 +184,9 @@ def parse_formula(
|
|||
|
||||
def evaluate_ast(
|
||||
tree: ast.Expression,
|
||||
field_values: dict[str, Any],
|
||||
functions: dict[str, Any],
|
||||
) -> Any:
|
||||
field_values: dict[str, FormulaResult | list[FormulaResult]],
|
||||
functions: dict[str, Callable[..., FormulaResult]],
|
||||
) -> FormulaResult:
|
||||
"""Evaluate a parsed formula AST against field values and functions.
|
||||
|
||||
This is NOT ``eval()`` — it's a manual AST walker that only processes
|
||||
|
|
@ -204,7 +204,11 @@ def evaluate_ast(
|
|||
return _eval_node(tree.body, field_values, functions)
|
||||
|
||||
|
||||
def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any]) -> Any:
|
||||
def _eval_node(
|
||||
node: ast.AST,
|
||||
fields: dict[str, FormulaResult | list[FormulaResult]],
|
||||
functions: dict[str, Callable[..., FormulaResult]],
|
||||
) -> FormulaResult:
|
||||
"""Recursively evaluate an AST node."""
|
||||
if isinstance(node, ast.Constant):
|
||||
return node.value
|
||||
|
|
@ -274,7 +278,7 @@ def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any])
|
|||
raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}")
|
||||
|
||||
|
||||
def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any:
|
||||
def _apply_binop(op: ast.AST, left: FormulaResult, right: FormulaResult) -> FormulaResult:
|
||||
"""Apply a binary operator."""
|
||||
if isinstance(op, ast.Add):
|
||||
# String concat or numeric addition
|
||||
|
|
@ -294,7 +298,7 @@ def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any:
|
|||
raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}")
|
||||
|
||||
|
||||
def _apply_compare(op: ast.AST, left: Any, right: Any) -> bool:
|
||||
def _apply_compare(op: ast.AST, left: FormulaResult, right: FormulaResult) -> bool:
|
||||
"""Apply a comparison operator."""
|
||||
if isinstance(op, ast.Eq):
|
||||
return left == right
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import logging
|
|||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, TypeAlias
|
||||
|
||||
from agentkit.bitable.db import BitableDB
|
||||
from agentkit.bitable.models import (
|
||||
|
|
@ -29,13 +29,27 @@ from agentkit.bitable.models import (
|
|||
)
|
||||
from agentkit.bitable.repository import BitableRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Protocol
|
||||
|
||||
class _RecalcWorker(Protocol):
|
||||
"""Structural type for the recalc worker's cache-invalidation surface."""
|
||||
|
||||
def invalidate_engine(self, table_id: str) -> None: ...
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Record values are JSON scalars (text/number/none). Attachment/image fields
|
||||
# store lists of dicts at runtime, but the common-case scalar shape is captured
|
||||
# here for annotation clarity; fall back to dict[str, object] where lists occur.
|
||||
BitableRecord: TypeAlias = dict[str, str | int | float | None]
|
||||
|
||||
|
||||
class FieldDependencyError(Exception):
|
||||
"""Raised when deleting a field that has dependencies (formula refs, PK, views)."""
|
||||
|
||||
def __init__(self, message: str, dependencies: dict[str, Any]) -> None:
|
||||
def __init__(self, message: str, dependencies: dict[str, object]) -> None:
|
||||
super().__init__(message)
|
||||
self.dependencies = dependencies
|
||||
|
||||
|
|
@ -52,13 +66,13 @@ class BitableService:
|
|||
def __init__(self, db: BitableDB) -> None:
|
||||
self._db = db
|
||||
self._repo = BitableRepository(db)
|
||||
self._recalc_worker: Any = None # RecalcWorker, set via set_recalc_worker
|
||||
self._recalc_worker: _RecalcWorker | None = None # set via set_recalc_worker
|
||||
|
||||
@property
|
||||
def repo(self) -> BitableRepository:
|
||||
return self._repo
|
||||
|
||||
def set_recalc_worker(self, worker: Any) -> None:
|
||||
def set_recalc_worker(self, worker: _RecalcWorker) -> None:
|
||||
"""Register the long-lived RecalcWorker so field changes can invalidate its engine cache.
|
||||
|
||||
Called after both service and worker are constructed (breaks the
|
||||
|
|
@ -95,7 +109,7 @@ class BitableService:
|
|||
async def list_files(self, owner_user_id: str | None = None) -> list[BitableFile]:
|
||||
return await self._repo.list_files(owner_user_id=owner_user_id)
|
||||
|
||||
async def update_file(self, file_id: str, **kwargs: Any) -> BitableFile | None:
|
||||
async def update_file(self, file_id: str, **kwargs: object) -> BitableFile | None:
|
||||
return await self._repo.update_file(file_id, **kwargs)
|
||||
|
||||
async def delete_file(self, file_id: str) -> bool:
|
||||
|
|
@ -162,7 +176,7 @@ class BitableService:
|
|||
async def list_tables(self, owner_user_id: str | None = None) -> list[Table]:
|
||||
return await self._repo.list_tables(owner_user_id=owner_user_id)
|
||||
|
||||
async def update_table(self, table_id: str, **kwargs: Any) -> Table | None:
|
||||
async def update_table(self, table_id: str, **kwargs: object) -> Table | None:
|
||||
"""Update table attrs. Creates PK unique index if primary_key_field_id is set."""
|
||||
table = await self._repo.update_table(table_id, **kwargs)
|
||||
if table and kwargs.get("primary_key_field_id"):
|
||||
|
|
@ -179,7 +193,7 @@ class BitableService:
|
|||
table_id: str,
|
||||
name: str,
|
||||
field_type: FieldType,
|
||||
config: dict[str, Any] | None = None,
|
||||
config: dict[str, object] | None = None,
|
||||
owner: FieldOwner = FieldOwner.user,
|
||||
) -> Field:
|
||||
"""Create a new field. U2 will add formula validation and DAG updates."""
|
||||
|
|
@ -201,7 +215,7 @@ class BitableService:
|
|||
async def list_fields(self, table_id: str) -> list[Field]:
|
||||
return await self._repo.list_fields(table_id)
|
||||
|
||||
async def update_field(self, field_id: str, **kwargs: Any) -> Field | None:
|
||||
async def update_field(self, field_id: str, **kwargs: object) -> Field | None:
|
||||
"""Update a field. U2 will add dependency checking."""
|
||||
field = await self._repo.update_field(field_id, **kwargs)
|
||||
if field is not None:
|
||||
|
|
@ -220,7 +234,7 @@ class BitableService:
|
|||
return False
|
||||
|
||||
# Check dependencies
|
||||
deps: dict[str, Any] = {}
|
||||
deps: dict[str, object] = {}
|
||||
|
||||
# 1. Is it a primary key field?
|
||||
table = await self._repo.get_table(field.table_id)
|
||||
|
|
@ -264,7 +278,7 @@ class BitableService:
|
|||
async def create_record(
|
||||
self,
|
||||
table_id: str,
|
||||
values: dict[str, Any] | None = None,
|
||||
values: BitableRecord | None = None,
|
||||
actor_user_id: str | None = None,
|
||||
) -> Record:
|
||||
"""Create a new record. Triggers recalc for affected formula fields.
|
||||
|
|
@ -291,7 +305,7 @@ class BitableService:
|
|||
return record
|
||||
|
||||
async def create_records_batch(
|
||||
self, table_id: str, records_values: list[dict[str, Any]]
|
||||
self, table_id: str, records_values: list[BitableRecord]
|
||||
) -> list[Record]:
|
||||
"""Batch-create records (P2 #19). Triggers recalc for each record.
|
||||
|
||||
|
|
@ -319,8 +333,8 @@ class BitableService:
|
|||
async def list_records_filtered(
|
||||
self,
|
||||
table_id: str,
|
||||
filters: list[dict[str, Any]] | None = None,
|
||||
sorts: list[dict[str, Any]] | None = None,
|
||||
filters: list[dict[str, object]] | None = None,
|
||||
sorts: list[dict[str, object]] | None = None,
|
||||
cursor: str | None = None,
|
||||
limit: int = 50,
|
||||
) -> tuple[list[Record], str | None]:
|
||||
|
|
@ -345,7 +359,7 @@ class BitableService:
|
|||
table_id, filters=filters, sorts=sorts, cursor=cursor, limit=limit
|
||||
)
|
||||
|
||||
async def update_record_values(self, record_id: str, values: dict[str, Any]) -> Record | None:
|
||||
async def update_record_values(self, record_id: str, values: BitableRecord) -> Record | None:
|
||||
"""Update a record's values (full replace). Triggers recalc for affected formulas."""
|
||||
record = await self._repo.update_record_values(record_id, values)
|
||||
if record is not None:
|
||||
|
|
@ -431,9 +445,9 @@ class BitableService:
|
|||
async def upsert_records(
|
||||
self,
|
||||
table_id: str,
|
||||
records: list[dict[str, Any]],
|
||||
records: list[BitableRecord],
|
||||
primary_key_field_id: str,
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, int]:
|
||||
"""Upsert records by primary key using jsonb_set (KTD8).
|
||||
|
||||
For each record:
|
||||
|
|
@ -454,12 +468,12 @@ class BitableService:
|
|||
agent_field_ids = {f.id for f in fields if f.owner == FieldOwner.agent}
|
||||
|
||||
# Partition records into insert vs update lists, collecting PK values.
|
||||
to_insert: list[dict[str, Any]] = []
|
||||
to_update: list[tuple[dict[str, Any], str]] = [] # (values, existing_record_id)
|
||||
to_insert: list[BitableRecord] = []
|
||||
to_update: list[tuple[BitableRecord, str]] = [] # (values, existing_record_id)
|
||||
skipped = 0
|
||||
|
||||
# Collect all non-None PK values for batch lookup.
|
||||
pk_values_by_str: dict[str, dict[str, Any]] = {}
|
||||
pk_values_by_str: dict[str, BitableRecord] = {}
|
||||
for rec_values in records:
|
||||
pk_value = rec_values.get(primary_key_field_id)
|
||||
if pk_value is None:
|
||||
|
|
@ -504,7 +518,7 @@ class BitableService:
|
|||
table_id: str,
|
||||
name: str,
|
||||
view_type: ViewType = ViewType.grid,
|
||||
config: dict[str, Any] | None = None,
|
||||
config: dict[str, object] | None = None,
|
||||
) -> View:
|
||||
return await self._repo.create_view(
|
||||
table_id=table_id,
|
||||
|
|
@ -516,7 +530,7 @@ class BitableService:
|
|||
async def list_views(self, table_id: str) -> list[View]:
|
||||
return await self._repo.list_views(table_id)
|
||||
|
||||
async def update_view(self, view_id: str, **kwargs: Any) -> View | None:
|
||||
async def update_view(self, view_id: str, **kwargs: object) -> View | None:
|
||||
return await self._repo.update_view(view_id, **kwargs)
|
||||
|
||||
async def get_view(self, view_id: str) -> View | None:
|
||||
|
|
|
|||
|
|
@ -32,14 +32,14 @@ class PipelineStateMemory:
|
|||
"""In-memory pipeline state storage (testing / fallback)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._executions: dict[str, dict[str, Any]] = {}
|
||||
self._step_history: dict[str, list[dict[str, Any]]] = {}
|
||||
self._executions: dict[str, dict[str, object]] = {}
|
||||
self._step_history: dict[str, list[dict[str, object]]] = {}
|
||||
|
||||
async def create_execution(
|
||||
self,
|
||||
pipeline_name: str,
|
||||
steps: list[str],
|
||||
input_data: dict[str, Any] | None = None,
|
||||
input_data: dict[str, object] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> str:
|
||||
execution_id = str(uuid.uuid4())
|
||||
|
|
@ -67,7 +67,7 @@ class PipelineStateMemory:
|
|||
execution_id: str,
|
||||
step_name: str,
|
||||
status: str,
|
||||
output: dict[str, Any] | None = None,
|
||||
output: dict[str, object] | None = None,
|
||||
error: str | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> None:
|
||||
|
|
@ -88,7 +88,7 @@ class PipelineStateMemory:
|
|||
exec_state["error_message"] = error
|
||||
|
||||
# Record step history event
|
||||
step_event: dict[str, Any] = {
|
||||
step_event: dict[str, object] = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"execution_id": execution_id,
|
||||
"step_name": step_name,
|
||||
|
|
@ -97,14 +97,16 @@ class PipelineStateMemory:
|
|||
"error_message": error,
|
||||
"duration_ms": duration_ms,
|
||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||
"completed_at": datetime.now(timezone.utc).isoformat() if status in ("completed", "failed") else None,
|
||||
"completed_at": datetime.now(timezone.utc).isoformat()
|
||||
if status in ("completed", "failed")
|
||||
else None,
|
||||
}
|
||||
self._step_history[execution_id].append(step_event)
|
||||
|
||||
async def complete_execution(
|
||||
self,
|
||||
execution_id: str,
|
||||
final_output: dict[str, Any] | None = None,
|
||||
final_output: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
exec_state = self._executions.get(execution_id)
|
||||
if exec_state is None:
|
||||
|
|
@ -130,7 +132,7 @@ class PipelineStateMemory:
|
|||
exec_state["updated_at"] = now
|
||||
exec_state["completed_at"] = now
|
||||
|
||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
||||
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
|
||||
return self._executions.get(execution_id)
|
||||
|
||||
async def list_executions(
|
||||
|
|
@ -138,17 +140,17 @@ class PipelineStateMemory:
|
|||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[dict[str, object]]:
|
||||
results = list(self._executions.values())
|
||||
if status:
|
||||
results = [e for e in results if e.get("status") == status]
|
||||
results.sort(key=lambda e: e.get("created_at", ""), reverse=True)
|
||||
return results[offset : offset + limit]
|
||||
|
||||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
||||
async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
|
||||
return self._step_history.get(execution_id, [])
|
||||
|
||||
def get_execution_sync(self, execution_id: str) -> dict[str, Any] | None:
|
||||
def get_execution_sync(self, execution_id: str) -> dict[str, object] | None:
|
||||
"""Synchronous accessor for execution state (used by Redis dual-write)."""
|
||||
return self._executions.get(execution_id)
|
||||
|
||||
|
|
@ -165,7 +167,7 @@ class PipelineStateRedis:
|
|||
|
||||
def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None:
|
||||
self._redis_url = redis_url
|
||||
self._redis: Any = None
|
||||
self._redis: object | None = None
|
||||
self._fallback = PipelineStateMemory()
|
||||
self._use_fallback = False
|
||||
self._fallback_since: float | None = None
|
||||
|
|
@ -181,8 +183,8 @@ class PipelineStateRedis:
|
|||
return self._redis
|
||||
|
||||
async def _safe_redis_call(
|
||||
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object, **kwargs: object
|
||||
) -> object | None:
|
||||
"""Execute a Redis call, falling back to memory on failure.
|
||||
|
||||
After falling back, periodically retries Redis to enable recovery.
|
||||
|
|
@ -192,6 +194,7 @@ class PipelineStateRedis:
|
|||
# Check if enough time has passed to attempt recovery
|
||||
if self._fallback_since is not None:
|
||||
import time as _time
|
||||
|
||||
elapsed = _time.monotonic() - self._fallback_since
|
||||
if elapsed >= self._RECOVERY_COOLDOWN_SECONDS:
|
||||
try:
|
||||
|
|
@ -218,6 +221,7 @@ class PipelineStateRedis:
|
|||
logger.warning(f"Redis operation failed, switching to memory fallback: {exc}")
|
||||
self._use_fallback = True
|
||||
import time as _time
|
||||
|
||||
self._fallback_since = _time.monotonic()
|
||||
self._redis = None
|
||||
return None
|
||||
|
|
@ -229,7 +233,7 @@ class PipelineStateRedis:
|
|||
self,
|
||||
pipeline_name: str,
|
||||
steps: list[str],
|
||||
input_data: dict[str, Any] | None = None,
|
||||
input_data: dict[str, object] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> str:
|
||||
# Always write to fallback first for consistency
|
||||
|
|
@ -238,7 +242,7 @@ class PipelineStateRedis:
|
|||
)
|
||||
|
||||
# Try Redis
|
||||
async def _redis_create(redis: Any) -> None:
|
||||
async def _redis_create(redis: object) -> None:
|
||||
state = self._fallback.get_execution_sync(execution_id)
|
||||
score = datetime.now(timezone.utc).timestamp()
|
||||
pipe = redis.pipeline()
|
||||
|
|
@ -254,13 +258,15 @@ class PipelineStateRedis:
|
|||
execution_id: str,
|
||||
step_name: str,
|
||||
status: str,
|
||||
output: dict[str, Any] | None = None,
|
||||
output: dict[str, object] | None = None,
|
||||
error: str | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> None:
|
||||
await self._fallback.update_step(execution_id, step_name, status, output, error, duration_ms)
|
||||
await self._fallback.update_step(
|
||||
execution_id, step_name, status, output, error, duration_ms
|
||||
)
|
||||
|
||||
async def _redis_update(redis: Any) -> None:
|
||||
async def _redis_update(redis: object) -> None:
|
||||
state = self._fallback.get_execution_sync(execution_id)
|
||||
if state is None:
|
||||
return
|
||||
|
|
@ -271,11 +277,11 @@ class PipelineStateRedis:
|
|||
async def complete_execution(
|
||||
self,
|
||||
execution_id: str,
|
||||
final_output: dict[str, Any] | None = None,
|
||||
final_output: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
await self._fallback.complete_execution(execution_id, final_output)
|
||||
|
||||
async def _redis_complete(redis: Any) -> None:
|
||||
async def _redis_complete(redis: object) -> None:
|
||||
state = self._fallback.get_execution_sync(execution_id)
|
||||
if state is None:
|
||||
return
|
||||
|
|
@ -291,7 +297,7 @@ class PipelineStateRedis:
|
|||
) -> None:
|
||||
await self._fallback.fail_execution(execution_id, step_name, error)
|
||||
|
||||
async def _redis_fail(redis: Any) -> None:
|
||||
async def _redis_fail(redis: object) -> None:
|
||||
state = self._fallback.get_execution_sync(execution_id)
|
||||
if state is None:
|
||||
return
|
||||
|
|
@ -299,7 +305,7 @@ class PipelineStateRedis:
|
|||
|
||||
await self._safe_redis_call(_redis_fail)
|
||||
|
||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
||||
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
|
||||
# Try Redis first
|
||||
if not self._use_fallback:
|
||||
try:
|
||||
|
|
@ -318,7 +324,7 @@ class PipelineStateRedis:
|
|||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[dict[str, object]]:
|
||||
# Try Redis sorted set for efficient listing
|
||||
if not self._use_fallback:
|
||||
try:
|
||||
|
|
@ -341,7 +347,7 @@ class PipelineStateRedis:
|
|||
|
||||
return await self._fallback.list_executions(status, limit, offset)
|
||||
|
||||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
||||
async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
|
||||
return await self._fallback.get_step_history(execution_id)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
|
|
@ -364,20 +370,18 @@ class PipelineStatePG:
|
|||
If session_factory is None, all methods are no-op.
|
||||
"""
|
||||
|
||||
def __init__(self, session_factory: Any = None) -> None:
|
||||
def __init__(self, session_factory: object | None = None) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self._session_factory is not None
|
||||
|
||||
async def persist_execution(self, state: dict[str, Any]) -> None:
|
||||
async def persist_execution(self, state: dict[str, object]) -> None:
|
||||
"""Write a completed/failed execution to PostgreSQL."""
|
||||
if not self.enabled:
|
||||
return
|
||||
try:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
async with self._session_factory() as session:
|
||||
model = PipelineExecutionModel(
|
||||
id=state["id"],
|
||||
|
|
@ -390,18 +394,22 @@ class PipelineStatePG:
|
|||
final_output=state.get("final_output"),
|
||||
error_message=state.get("error_message"),
|
||||
tenant_id=state.get("tenant_id"),
|
||||
created_at=datetime.fromisoformat(state["created_at"]) if state.get("created_at") else None,
|
||||
updated_at=datetime.fromisoformat(state["updated_at"]) if state.get("updated_at") else None,
|
||||
completed_at=datetime.fromisoformat(state["completed_at"]) if state.get("completed_at") else None,
|
||||
created_at=datetime.fromisoformat(state["created_at"])
|
||||
if state.get("created_at")
|
||||
else None,
|
||||
updated_at=datetime.fromisoformat(state["updated_at"])
|
||||
if state.get("updated_at")
|
||||
else None,
|
||||
completed_at=datetime.fromisoformat(state["completed_at"])
|
||||
if state.get("completed_at")
|
||||
else None,
|
||||
)
|
||||
await session.merge(model)
|
||||
await session.commit()
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to persist execution to PG: {exc}")
|
||||
|
||||
async def persist_step_history(
|
||||
self, execution_id: str, steps: list[dict[str, Any]]
|
||||
) -> None:
|
||||
async def persist_step_history(self, execution_id: str, steps: list[dict[str, object]]) -> None:
|
||||
"""Write step history to PostgreSQL."""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
|
@ -419,8 +427,12 @@ class PipelineStatePG:
|
|||
error_message=step.get("error_message"),
|
||||
duration_ms=step.get("duration_ms"),
|
||||
retry_attempt=step.get("retry_attempt", 0),
|
||||
started_at=datetime.fromisoformat(step["started_at"]) if step.get("started_at") else None,
|
||||
completed_at=datetime.fromisoformat(step["completed_at"]) if step.get("completed_at") else None,
|
||||
started_at=datetime.fromisoformat(step["started_at"])
|
||||
if step.get("started_at")
|
||||
else None,
|
||||
completed_at=datetime.fromisoformat(step["completed_at"])
|
||||
if step.get("completed_at")
|
||||
else None,
|
||||
)
|
||||
await session.merge(model)
|
||||
await session.commit()
|
||||
|
|
@ -433,7 +445,7 @@ class PipelineStatePG:
|
|||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[dict[str, object]]:
|
||||
"""Query historical executions from PostgreSQL."""
|
||||
if not self.enabled:
|
||||
return []
|
||||
|
|
@ -445,9 +457,7 @@ class PipelineStatePG:
|
|||
PipelineExecutionModel.created_at.desc()
|
||||
)
|
||||
if pipeline_name:
|
||||
stmt = stmt.where(
|
||||
PipelineExecutionModel.pipeline_name == pipeline_name
|
||||
)
|
||||
stmt = stmt.where(PipelineExecutionModel.pipeline_name == pipeline_name)
|
||||
if status:
|
||||
stmt = stmt.where(PipelineExecutionModel.status == status)
|
||||
stmt = stmt.offset(offset).limit(limit)
|
||||
|
|
@ -458,7 +468,7 @@ class PipelineStatePG:
|
|||
logger.error(f"Failed to query executions from PG: {exc}")
|
||||
return []
|
||||
|
||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
||||
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
|
||||
"""Get a single execution from PostgreSQL (for Redis miss fallback)."""
|
||||
if not self.enabled:
|
||||
return None
|
||||
|
|
@ -479,7 +489,7 @@ class PipelineStatePG:
|
|||
return None
|
||||
|
||||
@staticmethod
|
||||
def _model_to_dict(model: PipelineExecutionModel) -> dict[str, Any]:
|
||||
def _model_to_dict(model: PipelineExecutionModel) -> dict[str, object]:
|
||||
return {
|
||||
"id": model.id,
|
||||
"pipeline_name": model.pipeline_name,
|
||||
|
|
@ -509,7 +519,7 @@ class PipelineStateManager:
|
|||
def __init__(
|
||||
self,
|
||||
redis_url: str | None = None,
|
||||
session_factory: Any = None,
|
||||
session_factory: object | None = None,
|
||||
) -> None:
|
||||
if redis_url:
|
||||
self._hot = PipelineStateRedis(redis_url=redis_url)
|
||||
|
|
@ -529,7 +539,7 @@ class PipelineStateManager:
|
|||
self,
|
||||
pipeline_name: str,
|
||||
steps: list[str],
|
||||
input_data: dict[str, Any] | None = None,
|
||||
input_data: dict[str, object] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> str:
|
||||
return await self._hot.create_execution(pipeline_name, steps, input_data, tenant_id)
|
||||
|
|
@ -539,7 +549,7 @@ class PipelineStateManager:
|
|||
execution_id: str,
|
||||
step_name: str,
|
||||
status: str,
|
||||
output: dict[str, Any] | None = None,
|
||||
output: dict[str, object] | None = None,
|
||||
error: str | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> None:
|
||||
|
|
@ -548,7 +558,7 @@ class PipelineStateManager:
|
|||
async def complete_execution(
|
||||
self,
|
||||
execution_id: str,
|
||||
final_output: dict[str, Any] | None = None,
|
||||
final_output: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
await self._hot.complete_execution(execution_id, final_output)
|
||||
# Persist to PG
|
||||
|
|
@ -574,7 +584,7 @@ class PipelineStateManager:
|
|||
if step_history:
|
||||
await self._cold.persist_step_history(execution_id, step_history)
|
||||
|
||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
||||
async def get_execution(self, execution_id: str) -> dict[str, object] | None:
|
||||
# Redis / memory first
|
||||
state = await self._hot.get_execution(execution_id)
|
||||
if state is not None:
|
||||
|
|
@ -587,7 +597,7 @@ class PipelineStateManager:
|
|||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[dict[str, object]]:
|
||||
# Hot store for recent executions
|
||||
results = await self._hot.list_executions(status, limit, offset)
|
||||
if results:
|
||||
|
|
@ -595,7 +605,7 @@ class PipelineStateManager:
|
|||
# Cold store for historical queries
|
||||
return await self._cold.query_executions(status=status, limit=limit, offset=offset)
|
||||
|
||||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
||||
async def get_step_history(self, execution_id: str) -> list[dict[str, object]]:
|
||||
return await self._hot.get_step_history(execution_id)
|
||||
|
||||
async def health_check(self) -> dict[str, bool]:
|
||||
|
|
|
|||
|
|
@ -15,10 +15,14 @@ import time
|
|||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import TypeAlias
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Scalar session state values (text/number/flag/none). Screen-state dicts use
|
||||
# dict[str, object] because they also hold tuple cursor positions.
|
||||
SessionState: TypeAlias = dict[str, str | int | bool | None]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScreenInfo:
|
||||
|
|
@ -37,7 +41,7 @@ class ActionResult:
|
|||
output: str = ""
|
||||
screenshot_base64: str = ""
|
||||
error: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ComputerUseSession(ABC):
|
||||
|
|
@ -56,7 +60,7 @@ class ComputerUseSession(ABC):
|
|||
self.session_id = session_id or str(uuid.uuid4())
|
||||
self.screen = ScreenInfo(width=screen_width, height=screen_height)
|
||||
self._started = False
|
||||
self._action_history: list[dict[str, Any]] = []
|
||||
self._action_history: list[dict[str, object]] = []
|
||||
|
||||
@property
|
||||
def is_started(self) -> bool:
|
||||
|
|
@ -82,7 +86,7 @@ class ComputerUseSession(ABC):
|
|||
...
|
||||
|
||||
@abstractmethod
|
||||
async def execute_action(self, action: str, **params: Any) -> ActionResult:
|
||||
async def execute_action(self, action: str, **params: object) -> ActionResult:
|
||||
"""执行 UI 操作
|
||||
|
||||
Args:
|
||||
|
|
@ -94,18 +98,20 @@ class ComputerUseSession(ABC):
|
|||
"""
|
||||
...
|
||||
|
||||
def record_action(self, action: str, params: dict[str, Any], result: ActionResult) -> None:
|
||||
def record_action(self, action: str, params: dict[str, object], result: ActionResult) -> None:
|
||||
"""记录操作历史"""
|
||||
self._action_history.append({
|
||||
self._action_history.append(
|
||||
{
|
||||
"timestamp": time.time(),
|
||||
"action": action,
|
||||
"params": params,
|
||||
"success": result.success,
|
||||
"output": result.output[:200] if result.output else "",
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def action_history(self) -> list[dict[str, Any]]:
|
||||
def action_history(self) -> list[dict[str, object]]:
|
||||
"""获取操作历史(副本)"""
|
||||
return list(self._action_history)
|
||||
|
||||
|
|
@ -134,7 +140,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
|
|||
screen_width=screen_width,
|
||||
screen_height=screen_height,
|
||||
)
|
||||
self._screen_state: dict[str, Any] = {
|
||||
self._screen_state: dict[str, object] = {
|
||||
"focused_element": None,
|
||||
"cursor_position": (0, 0),
|
||||
"typed_text": "",
|
||||
|
|
@ -173,7 +179,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
|
|||
metadata={"screen_state": dict(self._screen_state)},
|
||||
)
|
||||
|
||||
async def execute_action(self, action: str, **params: Any) -> ActionResult:
|
||||
async def execute_action(self, action: str, **params: object) -> ActionResult:
|
||||
"""模拟执行 UI 操作"""
|
||||
if not self._started:
|
||||
return ActionResult(
|
||||
|
|
@ -186,7 +192,7 @@ class InMemoryComputerUseSession(ComputerUseSession):
|
|||
self.record_action(action, params, result)
|
||||
return result
|
||||
|
||||
def _simulate_action(self, action: str, **params: Any) -> ActionResult:
|
||||
def _simulate_action(self, action: str, **params: object) -> ActionResult:
|
||||
"""模拟具体操作"""
|
||||
if action == "click":
|
||||
x = params.get("x", 0)
|
||||
|
|
@ -270,18 +276,78 @@ class LocalComputerUseSession(ComputerUseSession):
|
|||
screen_width=screen_width,
|
||||
screen_height=screen_height,
|
||||
)
|
||||
self._pyautogui: Any = None
|
||||
self._pyautogui: object | None = None
|
||||
|
||||
# Allowed keys for the `key` action — prevents OS-level shortcut abuse
|
||||
_ALLOWED_KEYS: set[str] = {
|
||||
"enter", "return", "tab", "backspace", "delete", "home", "end",
|
||||
"up", "down", "left", "right", "pageup", "pagedown",
|
||||
"space", "escape", "insert",
|
||||
"f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8", "f9", "f10", "f11", "f12",
|
||||
"shift", "ctrl", "alt", "command",
|
||||
"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m",
|
||||
"n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z",
|
||||
"0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
|
||||
"enter",
|
||||
"return",
|
||||
"tab",
|
||||
"backspace",
|
||||
"delete",
|
||||
"home",
|
||||
"end",
|
||||
"up",
|
||||
"down",
|
||||
"left",
|
||||
"right",
|
||||
"pageup",
|
||||
"pagedown",
|
||||
"space",
|
||||
"escape",
|
||||
"insert",
|
||||
"f1",
|
||||
"f2",
|
||||
"f3",
|
||||
"f4",
|
||||
"f5",
|
||||
"f6",
|
||||
"f7",
|
||||
"f8",
|
||||
"f9",
|
||||
"f10",
|
||||
"f11",
|
||||
"f12",
|
||||
"shift",
|
||||
"ctrl",
|
||||
"alt",
|
||||
"command",
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
"d",
|
||||
"e",
|
||||
"f",
|
||||
"g",
|
||||
"h",
|
||||
"i",
|
||||
"j",
|
||||
"k",
|
||||
"l",
|
||||
"m",
|
||||
"n",
|
||||
"o",
|
||||
"p",
|
||||
"q",
|
||||
"r",
|
||||
"s",
|
||||
"t",
|
||||
"u",
|
||||
"v",
|
||||
"w",
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"0",
|
||||
"1",
|
||||
"2",
|
||||
"3",
|
||||
"4",
|
||||
"5",
|
||||
"6",
|
||||
"7",
|
||||
"8",
|
||||
"9",
|
||||
}
|
||||
_ALLOWED_BUTTONS: set[str] = {"left", "right", "middle"}
|
||||
_MAX_TEXT_LENGTH: int = 1000
|
||||
|
|
@ -291,6 +357,7 @@ class LocalComputerUseSession(ComputerUseSession):
|
|||
"""启动本地桌面会话"""
|
||||
try:
|
||||
import pyautogui
|
||||
|
||||
self._pyautogui = pyautogui
|
||||
pyautogui.FAILSAFE = True
|
||||
pyautogui.PAUSE = 0.1
|
||||
|
|
@ -327,7 +394,7 @@ class LocalComputerUseSession(ComputerUseSession):
|
|||
except Exception as e:
|
||||
return ActionResult(success=False, action="screenshot", error=str(e))
|
||||
|
||||
async def execute_action(self, action: str, **params: Any) -> ActionResult:
|
||||
async def execute_action(self, action: str, **params: object) -> ActionResult:
|
||||
"""在本地桌面执行 UI 操作"""
|
||||
if not self._started:
|
||||
return ActionResult(success=False, action=action, error="Session not started")
|
||||
|
|
@ -343,7 +410,7 @@ class LocalComputerUseSession(ComputerUseSession):
|
|||
"""Check if coordinates are within screen bounds."""
|
||||
return 0 <= x <= self.screen_width and 0 <= y <= self.screen_height
|
||||
|
||||
async def _execute_local_action(self, action: str, **params: Any) -> ActionResult:
|
||||
async def _execute_local_action(self, action: str, **params: object) -> ActionResult:
|
||||
"""Execute a local UI action with input validation."""
|
||||
pg = self._pyautogui
|
||||
|
||||
|
|
@ -351,16 +418,24 @@ class LocalComputerUseSession(ComputerUseSession):
|
|||
x, y = params.get("x", 0), params.get("y", 0)
|
||||
button = params.get("button", "left")
|
||||
if button not in self._ALLOWED_BUTTONS:
|
||||
return ActionResult(success=False, action="click", error=f"Invalid button: {button}")
|
||||
return ActionResult(
|
||||
success=False, action="click", error=f"Invalid button: {button}"
|
||||
)
|
||||
if not self._validate_coordinates(x, y):
|
||||
return ActionResult(success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})")
|
||||
return ActionResult(
|
||||
success=False, action="click", error=f"Coordinates out of bounds: ({x}, {y})"
|
||||
)
|
||||
pg.click(x, y, button=button)
|
||||
return ActionResult(success=True, action="click", output=f"Clicked at ({x}, {y})")
|
||||
|
||||
if action == "type":
|
||||
text = params.get("text", "")
|
||||
if len(text) > self._MAX_TEXT_LENGTH:
|
||||
return ActionResult(success=False, action="type", error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}")
|
||||
return ActionResult(
|
||||
success=False,
|
||||
action="type",
|
||||
error=f"Text too long: {len(text)} > {self._MAX_TEXT_LENGTH}",
|
||||
)
|
||||
pg.write(text)
|
||||
return ActionResult(success=True, action="type", output=f"Typed: {text[:50]}")
|
||||
|
||||
|
|
@ -369,16 +444,22 @@ class LocalComputerUseSession(ComputerUseSession):
|
|||
amount = params.get("amount", 3)
|
||||
clicks = amount if direction == "down" else -amount
|
||||
pg.scroll(clicks)
|
||||
return ActionResult(success=True, action="scroll", output=f"Scrolled {direction} by {amount}")
|
||||
return ActionResult(
|
||||
success=True, action="scroll", output=f"Scrolled {direction} by {amount}"
|
||||
)
|
||||
|
||||
if action == "drag":
|
||||
sx, sy = params.get("start_x", 0), params.get("start_y", 0)
|
||||
ex, ey = params.get("end_x", 0), params.get("end_y", 0)
|
||||
if not (self._validate_coordinates(sx, sy) and self._validate_coordinates(ex, ey)):
|
||||
return ActionResult(success=False, action="drag", error="Drag coordinates out of bounds")
|
||||
return ActionResult(
|
||||
success=False, action="drag", error="Drag coordinates out of bounds"
|
||||
)
|
||||
pg.moveTo(sx, sy)
|
||||
pg.dragTo(ex, ey, duration=0.5)
|
||||
return ActionResult(success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})")
|
||||
return ActionResult(
|
||||
success=True, action="drag", output=f"Dragged from ({sx},{sy}) to ({ex},{ey})"
|
||||
)
|
||||
|
||||
if action == "key":
|
||||
key_name = params.get("key_name", "")
|
||||
|
|
@ -487,7 +568,7 @@ class DockerComputerUseSession(ComputerUseSession):
|
|||
screenshot_base64="",
|
||||
)
|
||||
|
||||
async def execute_action(self, action: str, **params: Any) -> ActionResult:
|
||||
async def execute_action(self, action: str, **params: object) -> ActionResult:
|
||||
"""在 Docker 虚拟桌面执行操作
|
||||
|
||||
Stub: 实际实现需要通过 Anthropic Computer Use API。
|
||||
|
|
@ -527,7 +608,7 @@ class ComputerUseSessionManager:
|
|||
def get_or_create(
|
||||
self,
|
||||
session_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
**kwargs: object,
|
||||
) -> ComputerUseSession:
|
||||
"""获取或创建会话"""
|
||||
if session_id and session_id in self._sessions:
|
||||
|
|
|
|||
Loading…
Reference in New Issue