feat(phase4): implement Computer Use integration (U12)
- ComputerUseTool: Anthropic API + fallback chain (API→Session→ShellTool→AskHuman) - ComputerUseSession: Docker sandbox + InMemory test session - ComputerUseRecorder: action recording, replay, and persistence 89 new tests passing. Degradation chain verified.
This commit is contained in:
parent
c99aee1423
commit
901e4d9d0a
|
|
@ -0,0 +1,484 @@
|
|||
"""ComputerUseTool - Anthropic Computer Use API 集成
|
||||
|
||||
封装 Anthropic Computer Use API 调用,支持截屏识别、UI 操作、降级策略。
|
||||
操作类型:screenshot / click / type / scroll / drag / key / wait
|
||||
降级链:ComputerUseTool 失败 → 检查是否有 API/CLI 替代 → ShellTool → AskHumanTool
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
from agentkit.tools.computer_use_session import (
|
||||
ComputerUseSession,
|
||||
InMemoryComputerUseSession,
|
||||
ComputerUseSessionManager,
|
||||
ActionResult,
|
||||
)
|
||||
from agentkit.tools.computer_use_recorder import ComputerUseRecorder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Action names accepted by ComputerUseTool.execute().
_VALID_ACTIONS = (
    "screenshot",
    "click",
    "type",
    "scroll",
    "drag",
    "key",
    "wait",
)

# Default endpoint used for Anthropic Computer Use requests.
_ANTHROPIC_COMPUTER_USE_URL = "https://api.anthropic.com/v1/messages"

# Degradation hints: action name -> suggested shell/API alternatives that are
# reported to the caller when every execution path has failed.
_FALLBACK_SHELL_SUGGESTIONS: dict[str, list[str]] = dict(
    click=["考虑使用 xdotool 或 API 替代点击操作"],
    type=["考虑使用 xdotool type 或 API 替代输入操作"],
    scroll=["考虑使用 xdotool scroll 或键盘 PageDown/PageUp"],
    drag=["考虑使用 xdotool mousemove 替代拖拽操作"],
    key=["考虑使用 xdotool key 或 API 替代键盘操作"],
    screenshot=["考虑使用 scrot/screenshot 命令或 API 替代截屏"],
)
|
||||
|
||||
|
||||
class ComputerUseTool(Tool):
    """Computer Use tool backed by the Anthropic Messages API.

    Validates a requested UI action and runs it through a degradation chain:
    Anthropic API -> local session execution -> optional fallback callback ->
    textual fallback suggestion. Every executed action is handed to the
    recorder.

    Usage:
        tool = ComputerUseTool(api_key="sk-...")
        result = await tool.execute(action="screenshot")
        result = await tool.execute(action="click", x=100, y=200)
        result = await tool.execute(action="type", text="hello")
    """

    def __init__(
        self,
        name: str = "computer_use",
        description: str = "Anthropic Computer Use API 集成,支持截屏识别和 UI 操作",
        input_schema: dict[str, Any] | None = None,
        output_schema: dict[str, Any] | None = None,
        version: str = "1.0.0",
        tags: list[str] | None = None,
        api_key: str | None = None,
        model: str = "claude-sonnet-4-20250514",
        api_base_url: str = _ANTHROPIC_COMPUTER_USE_URL,
        session_factory: type[ComputerUseSession] | None = None,
        recorder: ComputerUseRecorder | None = None,
        fallback_callback: Callable[[str, dict[str, Any]], Awaitable[dict[str, Any]]] | None = None,
        max_retries: int = 1,
        request_timeout: float = 30.0,
    ):
        """Configure the tool.

        Args:
            name, description, input_schema, output_schema, version, tags:
                Forwarded to the ``Tool`` base class; schemas and tags fall
                back to the defaults defined on this class.
            api_key: Anthropic API key. When falsy the API step of the
                degradation chain is skipped entirely.
            model: Model identifier sent in the request body.
            api_base_url: URL the request is POSTed to.
            session_factory: Session class used by the session manager;
                defaults to ``InMemoryComputerUseSession``.
            recorder: Recorder receiving every executed action; a fresh
                ``ComputerUseRecorder`` is created when omitted.
            fallback_callback: Optional async callable ``(action, params)``
                returning a dict; consulted after local execution fails.
            max_retries: Stored on the instance.
            request_timeout: Timeout passed to ``httpx.AsyncClient``.
        """
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema or self._default_input_schema(),
            output_schema=output_schema or self._default_output_schema(),
            version=version,
            tags=tags or ["computer-use", "ui-automation", "anthropic"],
        )
        self._api_key = api_key
        self._model = model
        self._api_base_url = api_base_url
        self._session_manager = ComputerUseSessionManager(
            session_factory=session_factory or InMemoryComputerUseSession,
        )
        self._recorder = recorder or ComputerUseRecorder()
        self._fallback_callback = fallback_callback
        # NOTE(review): max_retries is stored but never read in this class.
        self._max_retries = max_retries
        self._request_timeout = request_timeout

    @staticmethod
    def _default_input_schema() -> dict[str, Any]:
        """Return the JSON schema describing the arguments of ``execute``."""
        return {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": list(_VALID_ACTIONS),
                    "description": "操作类型:screenshot/click/type/scroll/drag/key/wait",
                },
                "x": {
                    "type": "integer",
                    "description": "X 坐标(click/drag 操作需要)",
                },
                "y": {
                    "type": "integer",
                    "description": "Y 坐标(click/drag 操作需要)",
                },
                "text": {
                    "type": "string",
                    "description": "输入文本(type 操作需要)",
                },
                "key_name": {
                    "type": "string",
                    "description": "按键名称(key 操作需要),如 Enter、Tab、Escape",
                },
                "direction": {
                    "type": "string",
                    "enum": ["up", "down", "left", "right"],
                    "description": "滚动方向(scroll 操作需要)",
                },
                "amount": {
                    "type": "integer",
                    "description": "滚动量(scroll 操作),默认 3",
                    "default": 3,
                },
                "start_x": {
                    "type": "integer",
                    "description": "拖拽起始 X 坐标(drag 操作需要)",
                },
                "start_y": {
                    "type": "integer",
                    "description": "拖拽起始 Y 坐标(drag 操作需要)",
                },
                "end_x": {
                    "type": "integer",
                    "description": "拖拽结束 X 坐标(drag 操作需要)",
                },
                "end_y": {
                    "type": "integer",
                    "description": "拖拽结束 Y 坐标(drag 操作需要)",
                },
                "duration": {
                    "type": "number",
                    "description": "等待时长(wait 操作),默认 1.0 秒",
                    "default": 1.0,
                },
                "session_id": {
                    "type": "string",
                    "description": "会话 ID,指定后在对应会话中执行操作",
                },
            },
            "required": ["action"],
        }

    @staticmethod
    def _default_output_schema() -> dict[str, Any]:
        """Return the JSON schema describing the dict returned by ``execute``."""
        return {
            "type": "object",
            "properties": {
                "success": {"type": "boolean", "description": "操作是否成功"},
                "action": {"type": "string", "description": "执行的操作类型"},
                "output": {"type": "string", "description": "操作结果描述"},
                "screenshot_base64": {"type": "string", "description": "截图 base64 编码"},
                "error": {"type": "string", "description": "错误信息"},
                "fallback_suggestion": {
                    "type": "string",
                    "description": "降级建议(操作失败时提供)",
                },
                "session_id": {"type": "string", "description": "会话 ID"},
            },
        }

    async def execute(self, **kwargs) -> dict:
        """Run one Computer Use action.

        Args:
            action: Action type (required).
            x, y: Coordinates (click).
            text: Text to enter (type).
            key_name: Key name (key).
            direction, amount: Scroll direction and amount (scroll).
            start_x, start_y, end_x, end_y: Drag coordinates (drag).
            duration: Wait duration (wait).
            session_id: Session to run the action in.

        Returns:
            Dict with ``success``, ``action``, ``session_id`` and, depending
            on the outcome, ``output``, ``screenshot_base64``, ``error`` and
            ``fallback_suggestion``.
        """
        action = kwargs.get("action")
        if not action:
            return self._error_result(
                action="unknown",
                error="action 参数是必需的",
                fallback="指定操作类型:screenshot/click/type/scroll/drag/key/wait",
            )

        if action not in _VALID_ACTIONS:
            return self._error_result(
                action=action,
                error=f"无效的操作类型: {action},支持: {', '.join(_VALID_ACTIONS)}",
            )

        validation_error = self._validate_params(action, kwargs)
        if validation_error:
            return self._error_result(
                action=action,
                error=validation_error,
            )

        session_id = kwargs.get("session_id")

        # Look up the requested session, or create (and start) a new one.
        session = self._session_manager.get_or_create(session_id=session_id)
        if not session.is_started:
            await session.start()

        # Action parameters: everything except the routing keys, with
        # None-valued entries dropped.
        params = {
            k: v
            for k, v in kwargs.items()
            if k not in ("action", "session_id") and v is not None
        }

        result = await self._execute_with_fallback(session, action, params)

        self._recorder.record(action, params, result)

        return self._format_result(result, session.session_id)

    async def _execute_with_fallback(
        self,
        session: ComputerUseSession,
        action: str,
        params: dict[str, Any],
    ) -> ActionResult:
        """Execute ``action`` through the degradation chain.

        Order: Anthropic API (only when an API key is configured) -> local
        session execution -> custom fallback callback -> failure result that
        carries fallback suggestions. Exceptions raised by any step are
        logged and treated as a failure of that step.
        """
        # 1. Anthropic Computer Use API.
        if self._api_key:
            try:
                result = await self._call_anthropic_api(session, action, params)
                if result.success:
                    return result
                logger.warning(
                    "Anthropic API returned failure for %s: %s",
                    action,
                    result.error,
                )
            except Exception as e:
                logger.warning("Anthropic API call failed for %s: %s", action, e)

        # 2. Local execution inside the session.
        try:
            if action == "screenshot":
                result = await session.screenshot()
            else:
                result = await session.execute_action(action, **params)
            if result.success:
                result.metadata["execution_mode"] = "session_local"
                return result
        except Exception as e:
            logger.warning("Session local execution failed for %s: %s", action, e)

        # 3. Caller-supplied fallback callback.
        if self._fallback_callback:
            try:
                fallback_result = await self._fallback_callback(action, params)
                if fallback_result.get("success"):
                    return ActionResult(
                        success=True,
                        action=action,
                        output=fallback_result.get("output", ""),
                        metadata={"execution_mode": "fallback_callback"},
                    )
            except Exception as e:
                logger.warning("Fallback callback failed for %s: %s", action, e)

        # 4. Nothing worked: report the failure together with suggestions.
        suggestions = _FALLBACK_SHELL_SUGGESTIONS.get(action, [])
        fallback_msg = "; ".join(suggestions) if suggestions else "考虑使用 ShellTool 或 AskHumanTool 替代"

        return ActionResult(
            success=False,
            action=action,
            error=f"Computer Use 操作失败,降级建议: {fallback_msg}",
            metadata={
                "execution_mode": "fallback_suggestion",
                "fallback_suggestions": suggestions,
            },
        )

    @staticmethod
    def _build_tool_input(action: str, params: dict[str, Any]) -> dict[str, Any]:
        """Translate ``execute`` parameters into the per-action input dict."""
        tool_input: dict[str, Any] = {"action": action}
        if action == "click":
            tool_input["coordinate"] = [params.get("x", 0), params.get("y", 0)]
        elif action == "type":
            tool_input["text"] = params.get("text", "")
        elif action == "scroll":
            tool_input["direction"] = params.get("direction", "down")
            tool_input["amount"] = params.get("amount", 3)
        elif action == "drag":
            tool_input["start_coordinate"] = [
                params.get("start_x", 0),
                params.get("start_y", 0),
            ]
            tool_input["coordinate"] = [
                params.get("end_x", 0),
                params.get("end_y", 0),
            ]
        elif action == "key":
            tool_input["key"] = params.get("key_name", "")
        elif action == "wait":
            tool_input["duration"] = params.get("duration", 1.0)
        # "screenshot" needs no extra fields.
        return tool_input

    async def _call_anthropic_api(
        self,
        session: ComputerUseSession,
        action: str,
        params: dict[str, Any],
    ) -> ActionResult:
        """Send one Computer Use request to the Anthropic Messages API.

        The current session screenshot (when non-empty) and a text prompt
        describing the action and its parameters are sent as a single user
        message together with the ``computer`` tool definition.

        Returns:
            A failed ``ActionResult`` for any non-200 response; otherwise a
            successful one built from the first ``tool_use`` block for the
            ``computer`` tool, or from the concatenated text blocks when no
            such block is present.

        NOTE(review): this method only reports what the API answered; it
        does not itself apply the action to ``session``.
        """
        import json  # local import: only needed to serialise the prompt

        tool_input = self._build_tool_input(action, params)

        # Current screen contents give the model visual context.
        screenshot_result = await session.screenshot()
        screenshot_b64 = screenshot_result.screenshot_base64

        content_blocks: list[dict[str, Any]] = []
        if screenshot_b64:
            content_blocks.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": screenshot_b64,
                },
            })

        # Include the concrete parameters; without them the request carries
        # no coordinates/text/key for the requested action.
        content_blocks.append({
            "type": "text",
            "text": (
                f"Execute computer use action: {action}\n"
                f"Parameters: {json.dumps(tool_input, ensure_ascii=False, default=str)}"
            ),
        })

        request_body = {
            "model": self._model,
            "max_tokens": 1024,
            "tools": [
                {
                    "type": "computer_20250124",
                    "name": "computer",
                    "display_width_px": session.screen.width,
                    "display_height_px": session.screen.height,
                }
            ],
            "messages": [
                {
                    "role": "user",
                    "content": content_blocks,
                }
            ],
        }

        headers = {
            "x-api-key": self._api_key,
            "anthropic-version": "2023-06-01",
            # The computer_20250124 tool is a beta tool and must be enabled
            # explicitly via this header.
            "anthropic-beta": "computer-use-2025-01-24",
            "content-type": "application/json",
        }

        async with httpx.AsyncClient(timeout=self._request_timeout) as client:
            response = await client.post(
                self._api_base_url,
                json=request_body,
                headers=headers,
            )

        if response.status_code != 200:
            error_detail = response.text[:500]
            return ActionResult(
                success=False,
                action=action,
                error=f"Anthropic API error {response.status_code}: {error_detail}",
            )

        data = response.json()
        blocks = data.get("content", [])

        # Prefer a tool_use block addressed to the "computer" tool.
        for block in blocks:
            if block.get("type") == "tool_use" and block.get("name") == "computer":
                tool_input_resp = block.get("input", {})
                resp_action = tool_input_resp.get("action", action)
                return ActionResult(
                    success=True,
                    action=resp_action,
                    output=f"API executed: {resp_action}",
                    metadata={"api_response": data},
                )

        # No tool_use block: fall back to the plain-text part of the reply.
        text_output = "".join(
            block.get("text", "") for block in blocks if block.get("type") == "text"
        )

        return ActionResult(
            success=True,
            action=action,
            output=text_output[:500] if text_output else "API call completed",
            metadata={"api_response": data},
        )

    def _validate_params(self, action: str, kwargs: dict[str, Any]) -> str | None:
        """Check that the parameters required by ``action`` are present.

        A coordinate passed as ``None`` counts as missing, because
        ``execute`` drops ``None`` values before running the action and the
        action would otherwise silently run at coordinate 0.

        Returns:
            An error message, or ``None`` when validation passes.
        """
        if action == "click":
            if kwargs.get("x") is None or kwargs.get("y") is None:
                return "click 操作需要 x 和 y 参数"
        elif action == "type":
            if not kwargs.get("text"):
                return "type 操作需要 text 参数"
        elif action == "key":
            if not kwargs.get("key_name"):
                return "key 操作需要 key_name 参数"
        elif action == "drag":
            required = ("start_x", "start_y", "end_x", "end_y")
            missing = [r for r in required if kwargs.get(r) is None]
            if missing:
                return f"drag 操作需要 {', '.join(missing)} 参数"
        return None

    def _format_result(self, result: ActionResult, session_id: str) -> dict[str, Any]:
        """Convert an ``ActionResult`` into the dict returned by ``execute``.

        Optional keys (``screenshot_base64``, ``error``,
        ``fallback_suggestion``) are only present when they carry a value.
        """
        formatted: dict[str, Any] = {
            "success": result.success,
            "action": result.action,
            "output": result.output,
            "session_id": session_id,
        }
        if result.screenshot_base64:
            formatted["screenshot_base64"] = result.screenshot_base64
        if result.error:
            formatted["error"] = result.error
        if not result.success and result.metadata.get("fallback_suggestions"):
            formatted["fallback_suggestion"] = "; ".join(
                result.metadata["fallback_suggestions"]
            )
        return formatted

    def _error_result(
        self,
        action: str,
        error: str,
        fallback: str = "",
    ) -> dict[str, Any]:
        """Build the failure dict used before any session is involved."""
        result: dict[str, Any] = {
            "success": False,
            "action": action,
            "error": error,
            "session_id": None,
        }
        if fallback:
            result["fallback_suggestion"] = fallback
        return result

    @property
    def session_manager(self) -> ComputerUseSessionManager:
        """The session manager owned by this tool."""
        return self._session_manager

    @property
    def recorder(self) -> ComputerUseRecorder:
        """The recorder receiving every executed action."""
        return self._recorder
|
||||
|
|
@ -0,0 +1,235 @@
|
|||
"""ComputerUseRecorder - Computer Use 操作录制与回放
|
||||
|
||||
记录每次截屏和操作,支持回放和审核。
|
||||
支持将录制结果持久化到文件,以及从文件加载回放。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.computer_use_session import ComputerUseSession, ActionResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ActionRecord:
    """One recorded Computer Use action together with its outcome."""

    # Time the action was recorded.
    timestamp: float
    # Action type, e.g. "click" or "screenshot".
    action: str
    # Keyword parameters the action was executed with.
    params: dict[str, Any] = field(default_factory=dict)
    # Whether the action succeeded.
    success: bool = False
    # Result description of the action.
    output: str = ""
    # Error message of the action, empty when there is none.
    error: str = ""
    # Path of an associated screenshot file, empty when there is none.
    screenshot_path: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return the record as a plain dict (suitable for JSON dumping)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ActionRecord":
        """Build a record from a dict such as the one ``to_dict`` returns.

        Keys that are not fields of this dataclass are ignored, so that
        data carrying additional keys can still be loaded.

        Raises:
            TypeError: ``data`` is not a mapping, or a required field
                (``timestamp``, ``action``) is missing.
        """
        from collections.abc import Mapping
        from dataclasses import fields

        if not isinstance(data, Mapping):
            raise TypeError(
                f"ActionRecord.from_dict expects a mapping, got {type(data).__name__}"
            )
        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known})
|
||||
|
||||
|
||||
class ComputerUseRecorder:
    """Recorder for Computer Use actions.

    Keeps an in-memory list of ``ActionRecord`` entries and supports
    replaying them against a session as well as saving to / loading from a
    JSON file.

    Usage:
        recorder = ComputerUseRecorder()

        # record
        recorder.record("click", {"x": 100, "y": 200}, result)

        # inspect
        records = recorder.get_records()

        # replay
        await recorder.replay(session)

        # persist
        recorder.save_recording("recording.json")
    """

    def __init__(self, screenshot_dir: str | Path | None = None):
        """Create an empty recorder.

        Args:
            screenshot_dir: Optional directory, stored as a ``Path``.
                NOTE(review): nothing in this class reads it yet.
        """
        self._records: list[ActionRecord] = []
        self._screenshot_dir = Path(screenshot_dir) if screenshot_dir else None

    def record(
        self,
        action: str,
        params: dict[str, Any],
        result: ActionResult,
        screenshot_path: str = "",
    ) -> ActionRecord:
        """Append one action to the recording.

        Args:
            action: Action type.
            params: Action parameters; a shallow copy is stored.
            result: Outcome of the action; ``output`` and ``error`` are
                truncated to 500 characters.
            screenshot_path: Path of an associated screenshot file.

        Returns:
            The ``ActionRecord`` that was appended.
        """
        record = ActionRecord(
            timestamp=time.time(),
            action=action,
            params=dict(params),
            success=result.success,
            output=result.output[:500] if result.output else "",
            error=result.error[:500] if result.error else "",
            screenshot_path=screenshot_path,
        )
        self._records.append(record)
        logger.debug(
            "Recorded action: %s success=%s",
            action,
            result.success,
        )
        return record

    def get_records(self) -> list[ActionRecord]:
        """Return all records (as a new list)."""
        return list(self._records)

    def get_records_by_action(self, action: str) -> list[ActionRecord]:
        """Return the records whose action type equals ``action``."""
        return [r for r in self._records if r.action == action]

    def get_failed_records(self) -> list[ActionRecord]:
        """Return the records of actions that did not succeed."""
        return [r for r in self._records if not r.success]

    async def replay(self, session: ComputerUseSession) -> list[ActionResult]:
        """Re-execute every recorded action, in recorded order.

        The session is started first when it is not running. An exception
        raised by one step is turned into a failed ``ActionResult`` and the
        replay continues with the next record.

        Args:
            session: Target session.

        Returns:
            One ``ActionResult`` per recorded action.
        """
        results: list[ActionResult] = []

        if not session.is_started:
            await session.start()

        # Iterate over a snapshot: the loop awaits, so the live list could be
        # appended to (e.g. by record()) while the replay is in progress.
        for record in list(self._records):
            try:
                if record.action == "screenshot":
                    result = await session.screenshot()
                else:
                    result = await session.execute_action(
                        record.action, **record.params
                    )
                results.append(result)
                logger.info(
                    "Replayed action: %s success=%s",
                    record.action,
                    result.success,
                )
            except Exception as e:
                results.append(ActionResult(
                    success=False,
                    action=record.action,
                    error=f"Replay failed: {e}",
                ))
                logger.error(
                    "Replay failed for action %s: %s",
                    record.action,
                    e,
                )

        return results

    def save_recording(self, path: str | Path) -> None:
        """Write the recording to ``path`` as JSON.

        Missing parent directories are created.

        Args:
            path: Destination file path.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            "version": "1.0",
            "recorded_at": time.time(),
            "total_actions": len(self._records),
            "records": [r.to_dict() for r in self._records],
        }

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        logger.info("Recording saved to %s (%d actions)", path, len(self._records))

    def load_recording(self, path: str | Path) -> None:
        """Replace the current records with those stored in ``path``.

        The existing records are only replaced once the whole file has been
        parsed successfully.

        Args:
            path: Source file path (JSON).

        Raises:
            FileNotFoundError: The file does not exist.
            ValueError: The file is not valid JSON, its top level is not an
                object with a ``records`` list, or a record entry cannot be
                turned into an ``ActionRecord``.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Recording file not found: {path}")

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # The top level may be any JSON value; only an object can hold
        # "records".
        if not isinstance(data, dict) or not isinstance(data.get("records"), list):
            raise ValueError(f"Invalid recording format: {path}")

        try:
            loaded = [ActionRecord.from_dict(r) for r in data["records"]]
        except TypeError as exc:
            # Malformed entry (not an object, missing/unexpected fields).
            raise ValueError(f"Invalid recording format: {path}") from exc

        self._records = loaded
        logger.info("Recording loaded from %s (%d actions)", path, len(self._records))

    def clear(self) -> None:
        """Drop all records."""
        self._records.clear()

    @property
    def total_actions(self) -> int:
        """Number of recorded actions."""
        return len(self._records)

    @property
    def success_count(self) -> int:
        """Number of recorded actions that succeeded."""
        return sum(1 for r in self._records if r.success)

    @property
    def failure_count(self) -> int:
        """Number of recorded actions that failed."""
        return sum(1 for r in self._records if not r.success)

    def summary(self) -> dict[str, Any]:
        """Return aggregate statistics about the recording.

        ``action_types`` is sorted so the result is deterministic;
        ``duration`` is the timestamp difference between the last and the
        first record (0.0 with fewer than two records).
        """
        return {
            "total_actions": self.total_actions,
            "success_count": self.success_count,
            "failure_count": self.failure_count,
            "action_types": sorted({r.action for r in self._records}),
            "duration": (
                self._records[-1].timestamp - self._records[0].timestamp
                if len(self._records) >= 2
                else 0.0
            ),
        }
|
||||
|
|
@ -0,0 +1,417 @@
|
|||
"""ComputerUseSession - 虚拟桌面会话管理
|
||||
|
||||
管理虚拟桌面会话(Docker 沙箱),维护操作上下文。
|
||||
提供 InMemoryComputerUseSession 用于测试,DockerComputerUseSession 用于生产。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ScreenInfo:
    """Dimensions of a session's screen."""

    # Screen width.
    width: int = 1280
    # Screen height.
    height: int = 720
|
||||
|
||||
|
||||
@dataclass
class ActionResult:
    """Outcome of a single session action."""

    # Whether the action succeeded.
    success: bool
    # Action type the result belongs to.
    action: str
    # Result description.
    output: str = ""
    # Base64-encoded screenshot; empty when there is none.
    screenshot_base64: str = ""
    # Error message; empty when there is none.
    error: str = ""
    # Free-form extra data; each instance gets its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ComputerUseSession(ABC):
    """Abstract base class for a virtual desktop session.

    Holds the session id, screen dimensions, started flag and an action
    history. Subclasses provide the concrete start/stop/screenshot/action
    behaviour.
    """

    def __init__(
        self,
        session_id: str | None = None,
        screen_width: int = 1280,
        screen_height: int = 720,
    ):
        """Initialise the shared session state.

        Args:
            session_id: Identifier to use; a random UUID4 string is generated
                when it is falsy.
            screen_width: Screen width stored in ``self.screen``.
            screen_height: Screen height stored in ``self.screen``.
        """
        if not session_id:
            session_id = str(uuid.uuid4())
        self.session_id = session_id
        self.screen = ScreenInfo(width=screen_width, height=screen_height)
        self._started = False
        self._action_history: list[dict[str, Any]] = []

    @property
    def is_started(self) -> bool:
        """Whether the session is currently marked as started."""
        return self._started

    @abstractmethod
    async def start(self) -> None:
        """Start the virtual desktop session."""
        ...

    @abstractmethod
    async def stop(self) -> None:
        """Stop the virtual desktop session."""
        ...

    @abstractmethod
    async def screenshot(self) -> ActionResult:
        """Capture the current screen.

        Returns:
            ActionResult carrying ``screenshot_base64`` and a description of
            the screen contents.
        """
        ...

    @abstractmethod
    async def execute_action(self, action: str, **params: Any) -> ActionResult:
        """Execute a UI action.

        Args:
            action: Action type (click/type/scroll/drag/key/wait).
            **params: Action parameters.

        Returns:
            ActionResult describing the outcome.
        """
        ...

    def record_action(self, action: str, params: dict[str, Any], result: ActionResult) -> None:
        """Append an entry for ``action`` to the session's history.

        The stored output is truncated to 200 characters.
        """
        entry = {
            "timestamp": time.time(),
            "action": action,
            "params": params,
            "success": result.success,
            "output": result.output[:200] if result.output else "",
        }
        self._action_history.append(entry)

    @property
    def action_history(self) -> list[dict[str, Any]]:
        """The action history (as a new list)."""
        return list(self._action_history)

    def __repr__(self) -> str:
        if self._started:
            status = "started"
        else:
            status = "stopped"
        cls_name = type(self).__name__
        short_id = self.session_id[:8]
        return (
            f"<{cls_name} id={short_id} "
            f"screen={self.screen.width}x{self.screen.height} {status}>"
        )
|
||||
|
||||
|
||||
class InMemoryComputerUseSession(ComputerUseSession):
    """In-memory session used for testing.

    No real desktop is involved: screenshots and actions are simulated
    against a small state dict.
    """

    def __init__(
        self,
        session_id: str | None = None,
        screen_width: int = 1280,
        screen_height: int = 720,
    ):
        super().__init__(
            session_id=session_id,
            screen_width=screen_width,
            screen_height=screen_height,
        )
        # Simulated screen state mutated by click/type/drag.
        self._screen_state: dict[str, Any] = {
            "focused_element": None,
            "cursor_position": (0, 0),
            "typed_text": "",
        }

    async def start(self) -> None:
        """Mark the session as started."""
        self._started = True
        logger.info("InMemory session started: %s", self.session_id[:8])

    async def stop(self) -> None:
        """Mark the session as stopped."""
        self._started = False
        logger.info("InMemory session stopped: %s", self.session_id[:8])

    async def screenshot(self) -> ActionResult:
        """Return an empty screenshot plus a description of the state."""
        if not self._started:
            return ActionResult(
                success=False,
                action="screenshot",
                error="Session not started",
            )

        state = self._screen_state
        cursor = state["cursor_position"]
        focused = state["focused_element"]
        typed = state["typed_text"]
        description = (
            f"Screen {self.screen.width}x{self.screen.height}, "
            f"cursor at {cursor}, "
            f"focused: {focused}, "
            f"typed: '{typed}'"
        )
        return ActionResult(
            success=True,
            action="screenshot",
            output=description,
            screenshot_base64="",
            metadata={"screen_state": {**state}},
        )

    async def execute_action(self, action: str, **params: Any) -> ActionResult:
        """Simulate a UI action and add it to the action history."""
        if not self._started:
            return ActionResult(
                success=False,
                action=action,
                error="Session not started",
            )

        outcome = self._simulate_action(action, **params)
        self.record_action(action, params, outcome)
        return outcome

    def _simulate_action(self, action: str, **params: Any) -> ActionResult:
        """Dispatch ``action`` to its simulation helper.

        Unknown actions yield a failed result.
        """
        handlers = {
            "click": self._sim_click,
            "type": self._sim_type,
            "scroll": self._sim_scroll,
            "drag": self._sim_drag,
            "key": self._sim_key,
            "wait": self._sim_wait,
        }
        handler = handlers.get(action)
        if handler is None:
            return ActionResult(
                success=False,
                action=action,
                error=f"Unknown action: {action}",
            )
        return handler(params)

    def _sim_click(self, params: dict[str, Any]) -> ActionResult:
        """Move the cursor to (x, y) and focus a synthetic element there."""
        x, y = params.get("x", 0), params.get("y", 0)
        self._screen_state["cursor_position"] = (x, y)
        self._screen_state["focused_element"] = f"element@({x},{y})"
        return ActionResult(success=True, action="click", output=f"Clicked at ({x}, {y})")

    def _sim_type(self, params: dict[str, Any]) -> ActionResult:
        """Append the text to the accumulated typed text."""
        text = params.get("text", "")
        self._screen_state["typed_text"] += text
        return ActionResult(success=True, action="type", output=f"Typed: {text}")

    def _sim_scroll(self, params: dict[str, Any]) -> ActionResult:
        """Report a scroll; the simulated state is left unchanged."""
        direction = params.get("direction", "down")
        amount = params.get("amount", 3)
        return ActionResult(
            success=True,
            action="scroll",
            output=f"Scrolled {direction} by {amount}",
        )

    def _sim_drag(self, params: dict[str, Any]) -> ActionResult:
        """Move the cursor to the drag end point."""
        start_x = params.get("start_x", 0)
        start_y = params.get("start_y", 0)
        end_x = params.get("end_x", 0)
        end_y = params.get("end_y", 0)
        self._screen_state["cursor_position"] = (end_x, end_y)
        return ActionResult(
            success=True,
            action="drag",
            output=f"Dragged from ({start_x},{start_y}) to ({end_x},{end_y})",
        )

    def _sim_key(self, params: dict[str, Any]) -> ActionResult:
        """Report a key press; the simulated state is left unchanged."""
        key_name = params.get("key_name", "")
        return ActionResult(success=True, action="key", output=f"Pressed key: {key_name}")

    def _sim_wait(self, params: dict[str, Any]) -> ActionResult:
        """Report a wait without actually sleeping."""
        duration = params.get("duration", 1.0)
        return ActionResult(success=True, action="wait", output=f"Waited {duration}s")
|
||||
|
||||
|
||||
class DockerComputerUseSession(ComputerUseSession):
    """Docker-sandbox virtual desktop session.

    Currently a stub: no container is created, and screenshots/actions
    return canned successful results. Real Docker integration is still to
    be wired in.
    """

    def __init__(
        self,
        session_id: str | None = None,
        screen_width: int = 1280,
        screen_height: int = 720,
        container_image: str = "anthropic/computer-use-demo:latest",
        docker_url: str = "unix:///var/run/docker.sock",
    ):
        super().__init__(
            session_id=session_id,
            screen_width=screen_width,
            screen_height=screen_height,
        )
        self._container_image = container_image
        self._docker_url = docker_url
        self._container_id: str | None = None

    @property
    def container_id(self) -> str | None:
        """Identifier of the (stub) container, or None when not started."""
        return self._container_id

    async def start(self) -> None:
        """Start the session.

        Stub: only flips the started flag and assigns a placeholder
        container id derived from the session id.
        """
        short_id = self.session_id[:8]
        logger.info(
            "Docker session start requested: %s (image=%s)",
            short_id,
            self._container_image,
        )
        # TODO: create the container through the Docker API (image
        # self._container_image, detached, screen size passed via the
        # environment) and store its real id in self._container_id.
        self._started = True
        self._container_id = f"stub-{short_id}"
        logger.info("Docker session started (stub): %s", short_id)

    async def stop(self) -> None:
        """Stop the session and forget the container id."""
        container = self._container_id
        if container:
            logger.info(
                "Docker session stop requested: %s (container=%s)",
                self.session_id[:8],
                container[:12],
            )
            # TODO: stop and remove the real container through the Docker API.
        self._started = False
        self._container_id = None
        logger.info("Docker session stopped: %s", self.session_id[:8])

    async def screenshot(self) -> ActionResult:
        """Capture the virtual desktop screen.

        Stub: returns a successful result with an empty screenshot.
        """
        if self._started:
            # TODO: fetch a real screenshot (e.g. over VNC/HTTP).
            return ActionResult(
                success=True,
                action="screenshot",
                output=f"Screenshot from container {self._container_id}",
                screenshot_base64="",
            )
        return ActionResult(
            success=False,
            action="screenshot",
            error="Session not started",
        )

    async def execute_action(self, action: str, **params: Any) -> ActionResult:
        """Execute an action on the virtual desktop.

        Stub: reports success without doing anything, and adds the action to
        the session history.
        """
        if not self._started:
            return ActionResult(
                success=False,
                action=action,
                error="Session not started",
            )

        # TODO: perform the action for real.
        outcome = ActionResult(
            success=True,
            action=action,
            output=f"Executed {action} in container {self._container_id}",
        )
        self.record_action(action, params, outcome)
        return outcome
|
||||
|
||||
|
||||
class ComputerUseSessionManager:
    """Registry of ``ComputerUseSession`` instances keyed by session id."""

    def __init__(
        self,
        max_sessions: int = 10,
        session_factory: type[ComputerUseSession] | None = None,
    ):
        """Create an empty manager.

        Args:
            max_sessions: Upper bound on registered sessions; exceeding it
                evicts the oldest registered session.
            session_factory: Callable used to build new sessions; defaults
                to ``InMemoryComputerUseSession``.
        """
        self._sessions: dict[str, ComputerUseSession] = {}
        self._max_sessions = max_sessions
        self._session_factory = session_factory or InMemoryComputerUseSession

    def get_or_create(
        self,
        session_id: str | None = None,
        **kwargs: Any,
    ) -> ComputerUseSession:
        """Return the session registered under ``session_id`` or create one.

        Extra keyword arguments are forwarded to the session factory when a
        new session is created.
        """
        if session_id and session_id in self._sessions:
            return self._sessions[session_id]

        session = self._session_factory(session_id=session_id, **kwargs)
        self._sessions[session.session_id] = session

        # Over the limit: evict the oldest entry (dicts keep insertion order).
        # NOTE(review): remove() only unregisters; an evicted session that is
        # still started is not stopped here.
        if len(self._sessions) > self._max_sessions:
            oldest_id = next(iter(self._sessions))
            self.remove(oldest_id)

        return session

    def get(self, session_id: str) -> ComputerUseSession | None:
        """Return the session registered under ``session_id``, if any."""
        return self._sessions.get(session_id)

    def remove(self, session_id: str) -> None:
        """Unregister a session; logs a warning when it is still started."""
        session = self._sessions.pop(session_id, None)
        if session and session.is_started:
            logger.warning("Removing started session: %s", session_id[:8])

    def list_sessions(self) -> list[str]:
        """Return the ids of all registered sessions."""
        return list(self._sessions.keys())

    def has_session(self, session_id: str) -> bool:
        """Return whether ``session_id`` is registered."""
        return session_id in self._sessions

    async def close_all(self) -> None:
        """Stop every started session and empty the registry.

        A failing ``stop()`` no longer aborts the shutdown: the remaining
        sessions are still stopped and the registry is always cleared.

        Raises:
            Exception: The first error raised by a session's ``stop()``,
                re-raised after all sessions were processed. Later errors
                are logged.
        """
        first_error: Exception | None = None
        try:
            # Snapshot: the loop awaits, so the registry may change meanwhile.
            for session in list(self._sessions.values()):
                if not session.is_started:
                    continue
                try:
                    await session.stop()
                except Exception as exc:
                    if first_error is None:
                        first_error = exc
                    else:
                        logger.warning(
                            "Failed to stop session %s: %s",
                            session.session_id[:8],
                            exc,
                        )
        finally:
            self._sessions.clear()

        if first_error is not None:
            raise first_error
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue