diff --git a/src/agentkit/tools/computer_use.py b/src/agentkit/tools/computer_use.py new file mode 100644 index 0000000..cebef33 --- /dev/null +++ b/src/agentkit/tools/computer_use.py @@ -0,0 +1,484 @@ +"""ComputerUseTool - Anthropic Computer Use API 集成 + +封装 Anthropic Computer Use API 调用,支持截屏识别、UI 操作、降级策略。 +操作类型:screenshot / click / type / scroll / drag / key / wait +降级链:ComputerUseTool 失败 → 检查是否有 API/CLI 替代 → ShellTool → AskHumanTool +""" + +from __future__ import annotations + +import base64 +import logging +from typing import Any, Callable, Awaitable + +import httpx + +from agentkit.tools.base import Tool +from agentkit.tools.computer_use_session import ( + ComputerUseSession, + InMemoryComputerUseSession, + ComputerUseSessionManager, + ActionResult, +) +from agentkit.tools.computer_use_recorder import ComputerUseRecorder + +logger = logging.getLogger(__name__) + +# 支持的操作类型 +_VALID_ACTIONS = ("screenshot", "click", "type", "scroll", "drag", "key", "wait") + +# Anthropic Computer Use API 端点 +_ANTHROPIC_COMPUTER_USE_URL = "https://api.anthropic.com/v1/messages" + +# 降级建议映射:操作 → 可能的 Shell 替代 +_FALLBACK_SHELL_SUGGESTIONS: dict[str, list[str]] = { + "click": ["考虑使用 xdotool 或 API 替代点击操作"], + "type": ["考虑使用 xdotool type 或 API 替代输入操作"], + "scroll": ["考虑使用 xdotool scroll 或键盘 PageDown/PageUp"], + "drag": ["考虑使用 xdotool mousemove 替代拖拽操作"], + "key": ["考虑使用 xdotool key 或 API 替代键盘操作"], + "screenshot": ["考虑使用 scrot/screenshot 命令或 API 替代截屏"], +} + + +class ComputerUseTool(Tool): + """Computer Use 工具 + + 封装 Anthropic Computer Use API 调用,支持截屏识别和 UI 操作。 + 支持降级链:API 失败 → Shell 替代建议 → AskHuman。 + + Usage: + tool = ComputerUseTool(api_key="sk-...") + result = await tool.execute(action="screenshot") + result = await tool.execute(action="click", x=100, y=200) + result = await tool.execute(action="type", text="hello") + """ + + def __init__( + self, + name: str = "computer_use", + description: str = "Anthropic Computer Use API 集成,支持截屏识别和 UI 操作", + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + api_key: str | None = None, + model: str = "claude-sonnet-4-20250514", + api_base_url: str = _ANTHROPIC_COMPUTER_USE_URL, + session_factory: type[ComputerUseSession] | None = None, + recorder: ComputerUseRecorder | None = None, + fallback_callback: Callable[[str, dict[str, Any]], Awaitable[dict[str, Any]]] | None = None, + max_retries: int = 1, + request_timeout: float = 30.0, + ): + super().__init__( + name=name, + description=description, + input_schema=input_schema or self._default_input_schema(), + output_schema=output_schema or self._default_output_schema(), + version=version, + tags=tags or ["computer-use", "ui-automation", "anthropic"], + ) + self._api_key = api_key + self._model = model + self._api_base_url = api_base_url + self._session_manager = ComputerUseSessionManager( + session_factory=session_factory or InMemoryComputerUseSession, + ) + self._recorder = recorder or ComputerUseRecorder() + self._fallback_callback = fallback_callback + self._max_retries = max_retries + self._request_timeout = request_timeout + + @staticmethod + def _default_input_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": list(_VALID_ACTIONS), + "description": "操作类型:screenshot/click/type/scroll/drag/key/wait", + }, + "x": { + "type": "integer", + "description": "X 坐标(click/drag 操作需要)", + }, + "y": { + "type": "integer", + "description": "Y 坐标(click/drag 操作需要)", + }, + "text": { + "type": "string", + "description": "输入文本(type 操作需要)", + }, + "key_name": { + "type": "string", + "description": "按键名称(key 操作需要),如 Enter、Tab、Escape", + }, + "direction": { + "type": "string", + "enum": ["up", "down", "left", "right"], + "description": "滚动方向(scroll 操作需要)", + }, + "amount": { + "type": "integer", + "description": "滚动量(scroll 操作),默认 3", + "default": 3, + }, + "start_x": { + "type": "integer", + "description": "拖拽起始 X 坐标(drag 操作需要)", + }, + "start_y": { + "type": "integer", + "description": "拖拽起始 Y 坐标(drag 操作需要)", + }, + "end_x": { + "type": "integer", + "description": "拖拽结束 X 坐标(drag 操作需要)", + }, + "end_y": { + "type": "integer", + "description": "拖拽结束 Y 坐标(drag 操作需要)", + }, + "duration": { + "type": "number", + "description": "等待时长(wait 操作),默认 1.0 秒", + "default": 1.0, + }, + "session_id": { + "type": "string", + "description": "会话 ID,指定后在对应会话中执行操作", + }, + }, + "required": ["action"], + } + + @staticmethod + def _default_output_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "success": {"type": "boolean", "description": "操作是否成功"}, + "action": {"type": "string", "description": "执行的操作类型"}, + "output": {"type": "string", "description": "操作结果描述"}, + "screenshot_base64": {"type": "string", "description": "截图 base64 编码"}, + "error": {"type": "string", "description": "错误信息"}, + "fallback_suggestion": { + "type": "string", + "description": "降级建议(操作失败时提供)", + }, + "session_id": {"type": "string", "description": "会话 ID"}, + }, + } + + async def execute(self, **kwargs) -> dict: + """执行 Computer Use 操作 + + Args: + action: 操作类型(必需) + x, y: 坐标(click 操作) + text: 输入文本(type 操作) + key_name: 按键名称(key 操作) + direction, amount: 滚动方向和量(scroll 操作) + start_x, start_y, end_x, end_y: 拖拽坐标(drag 操作) + duration: 等待时长(wait 操作) + session_id: 会话 ID + + Returns: + 包含 success, action, output 等字段的字典 + """ + action = kwargs.get("action") + if not action: + return self._error_result( + action="unknown", + error="action 参数是必需的", + fallback="指定操作类型:screenshot/click/type/scroll/drag/key/wait", + ) + + if action not in _VALID_ACTIONS: + return self._error_result( + action=action, + error=f"无效的操作类型: {action},支持: {', '.join(_VALID_ACTIONS)}", + ) + + # 参数验证 + validation_error = self._validate_params(action, kwargs) + if validation_error: + return self._error_result( + action=action, + error=validation_error, + ) + + session_id = kwargs.get("session_id") + + # 获取或创建会话 + session = self._session_manager.get_or_create(session_id=session_id) + if not session.is_started: + await session.start() + + # 提取操作参数(排除 action 和 session_id) + params = {k: v for k, v in kwargs.items() if k not in ("action", "session_id") and v is not None} + + # 尝试通过 API 执行 + result = await self._execute_with_fallback(session, action, params) + + # 录制操作 + self._recorder.record(action, params, result) + + return self._format_result(result, session.session_id) + + async def _execute_with_fallback( + self, + session: ComputerUseSession, + action: str, + params: dict[str, Any], + ) -> ActionResult: + """带降级链的操作执行 + + 降级链:Anthropic API → Session 本地执行 → Shell 替代建议 → AskHuman + """ + # 1. 尝试 Anthropic Computer Use API + if self._api_key: + try: + result = await self._call_anthropic_api(session, action, params) + if result.success: + return result + logger.warning( + "Anthropic API returned failure for %s: %s", + action, + result.error, + ) + except Exception as e: + logger.warning("Anthropic API call failed for %s: %s", action, e) + + # 2. 降级到 Session 本地执行 + try: + if action == "screenshot": + result = await session.screenshot() + else: + result = await session.execute_action(action, **params) + if result.success: + result.metadata["execution_mode"] = "session_local" + return result + except Exception as e: + logger.warning("Session local execution failed for %s: %s", action, e) + + # 3. 尝试自定义降级回调 + if self._fallback_callback: + try: + fallback_result = await self._fallback_callback(action, params) + if fallback_result.get("success"): + return ActionResult( + success=True, + action=action, + output=fallback_result.get("output", ""), + metadata={"execution_mode": "fallback_callback"}, + ) + except Exception as e: + logger.warning("Fallback callback failed for %s: %s", action, e) + + # 4. 返回降级建议 + suggestions = _FALLBACK_SHELL_SUGGESTIONS.get(action, []) + fallback_msg = "; ".join(suggestions) if suggestions else "考虑使用 ShellTool 或 AskHumanTool 替代" + + return ActionResult( + success=False, + action=action, + error=f"Computer Use 操作失败,降级建议: {fallback_msg}", + metadata={ + "execution_mode": "fallback_suggestion", + "fallback_suggestions": suggestions, + }, + ) + + async def _call_anthropic_api( + self, + session: ComputerUseSession, + action: str, + params: dict[str, Any], + ) -> ActionResult: + """调用 Anthropic Computer Use API + + 通过 Anthropic Messages API 发送 computer_use_tool 请求。 + """ + # 构造 computer_use 工具调用参数 + tool_input: dict[str, Any] = {"action": action} + if action == "click": + tool_input["coordinate"] = [params.get("x", 0), params.get("y", 0)] + elif action == "type": + tool_input["text"] = params.get("text", "") + elif action == "scroll": + tool_input["direction"] = params.get("direction", "down") + tool_input["amount"] = params.get("amount", 3) + elif action == "drag": + tool_input["start_coordinate"] = [ + params.get("start_x", 0), + params.get("start_y", 0), + ] + tool_input["coordinate"] = [ + params.get("end_x", 0), + params.get("end_y", 0), + ] + elif action == "key": + tool_input["key"] = params.get("key_name", "") + elif action == "wait": + tool_input["duration"] = params.get("duration", 1.0) + elif action == "screenshot": + pass # screenshot 不需要额外参数 + + # 获取当前截图作为上下文 + screenshot_result = await session.screenshot() + screenshot_b64 = screenshot_result.screenshot_base64 + + # 构造 API 请求 + content_blocks: list[dict[str, Any]] = [] + if screenshot_b64: + content_blocks.append({ + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": screenshot_b64, + }, + }) + + content_blocks.append({ + "type": "text", + "text": f"Execute computer use action: {action}", + }) + + request_body = { + "model": self._model, + "max_tokens": 1024, + "tools": [ + { + "type": "computer_20250124", + "name": "computer", + "display_width_px": session.screen.width, + "display_height_px": session.screen.height, + } + ], + "messages": [ + { + "role": "user", + "content": content_blocks, + } + ], + } + + headers = { + "x-api-key": self._api_key, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + } + + async with httpx.AsyncClient(timeout=self._request_timeout) as client: + response = await client.post( + self._api_base_url, + json=request_body, + headers=headers, + ) + + if response.status_code != 200: + error_detail = response.text[:500] + return ActionResult( + success=False, + action=action, + error=f"Anthropic API error {response.status_code}: {error_detail}", + ) + + data = response.json() + + # 解析 API 响应中的 tool_use 内容 + for block in data.get("content", []): + if block.get("type") == "tool_use" and block.get("name") == "computer": + tool_input_resp = block.get("input", {}) + resp_action = tool_input_resp.get("action", action) + return ActionResult( + success=True, + action=resp_action, + output=f"API executed: {resp_action}", + metadata={"api_response": data}, + ) + + # API 没有返回 tool_use,可能是纯文本响应 + text_output = "" + for block in data.get("content", []): + if block.get("type") == "text": + text_output += block.get("text", "") + + return ActionResult( + success=True, + action=action, + output=text_output[:500] if text_output else "API call completed", + metadata={"api_response": data}, + ) + + def _validate_params(self, action: str, kwargs: dict[str, Any]) -> str | None: + """验证操作参数 + + Returns: + 错误信息,None 表示验证通过 + """ + if action == "click": + if "x" not in kwargs or "y" not in kwargs: + return "click 操作需要 x 和 y 参数" + elif action == "type": + if not kwargs.get("text"): + return "type 操作需要 text 参数" + elif action == "key": + if not kwargs.get("key_name"): + return "key 操作需要 key_name 参数" + elif action == "drag": + required = ("start_x", "start_y", "end_x", "end_y") + missing = [r for r in required if r not in kwargs] + if missing: + return f"drag 操作需要 {', '.join(missing)} 参数" + return None + + def _format_result(self, result: ActionResult, session_id: str) -> dict[str, Any]: + """格式化操作结果""" + formatted: dict[str, Any] = { + "success": result.success, + "action": result.action, + "output": result.output, + "session_id": session_id, + } + if result.screenshot_base64: + formatted["screenshot_base64"] = result.screenshot_base64 + if result.error: + formatted["error"] = result.error + if not result.success and result.metadata.get("fallback_suggestions"): + formatted["fallback_suggestion"] = "; ".join( + result.metadata["fallback_suggestions"] + ) + return formatted + + def _error_result( + self, + action: str, + error: str, + fallback: str = "", + ) -> dict[str, Any]: + """构造错误结果""" + result: dict[str, Any] = { + "success": False, + "action": action, + "error": error, + "session_id": None, + } + if fallback: + result["fallback_suggestion"] = fallback + return result + + @property + def session_manager(self) -> ComputerUseSessionManager: + """获取会话管理器""" + return self._session_manager + + @property + def recorder(self) -> ComputerUseRecorder: + """获取录制器""" + return self._recorder diff --git a/src/agentkit/tools/computer_use_recorder.py b/src/agentkit/tools/computer_use_recorder.py new file mode 100644 index 0000000..7b567c1 --- /dev/null +++ b/src/agentkit/tools/computer_use_recorder.py @@ -0,0 +1,235 @@ +"""ComputerUseRecorder - Computer Use 操作录制与回放 + +记录每次截屏和操作,支持回放和审核。 +支持将录制结果持久化到文件,以及从文件加载回放。 +""" + +from __future__ import annotations + +import json +import logging +import time +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any + +from agentkit.tools.computer_use_session import ComputerUseSession, ActionResult + +logger = logging.getLogger(__name__) + + +@dataclass +class ActionRecord: + """操作记录 + + 记录一次 Computer Use 操作的完整信息。 + """ + + timestamp: float + action: str + params: dict[str, Any] = field(default_factory=dict) + success: bool = False + output: str = "" + error: str = "" + screenshot_path: str = "" + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ActionRecord: + return cls(**data) + + +class ComputerUseRecorder: + """Computer Use 操作录制器 + + 记录每次截屏和操作,支持回放和审核。 + + Usage: + recorder = ComputerUseRecorder() + + # 录制 + recorder.record("click", {"x": 100, "y": 200}, result) + + # 获取记录 + records = recorder.get_records() + + # 回放 + await recorder.replay(session) + + # 持久化 + recorder.save_recording("recording.json") + """ + + def __init__(self, screenshot_dir: str | Path | None = None): + self._records: list[ActionRecord] = [] + self._screenshot_dir = Path(screenshot_dir) if screenshot_dir else None + + def record( + self, + action: str, + params: dict[str, Any], + result: ActionResult, + screenshot_path: str = "", + ) -> ActionRecord: + """记录一次操作 + + Args: + action: 操作类型 + params: 操作参数 + result: 操作结果 + screenshot_path: 截图文件路径 + + Returns: + ActionRecord 记录实例 + """ + record = ActionRecord( + timestamp=time.time(), + action=action, + params=dict(params), + success=result.success, + output=result.output[:500] if result.output else "", + error=result.error[:500] if result.error else "", + screenshot_path=screenshot_path, + ) + self._records.append(record) + logger.debug( + "Recorded action: %s success=%s", + action, + result.success, + ) + return record + + def get_records(self) -> list[ActionRecord]: + """获取所有操作记录(副本)""" + return list(self._records) + + def get_records_by_action(self, action: str) -> list[ActionRecord]: + """按操作类型筛选记录""" + return [r for r in self._records if r.action == action] + + def get_failed_records(self) -> list[ActionRecord]: + """获取所有失败的操作记录""" + return [r for r in self._records if not r.success] + + async def replay(self, session: ComputerUseSession) -> list[ActionResult]: + """回放操作序列 + + 按时间顺序重新执行所有录制的操作。 + + Args: + session: 目标会话 + + Returns: + 每步操作的 ActionResult 列表 + """ + results: list[ActionResult] = [] + + if not session.is_started: + await session.start() + + for record in self._records: + try: + if record.action == "screenshot": + result = await session.screenshot() + else: + result = await session.execute_action( + record.action, **record.params + ) + results.append(result) + logger.info( + "Replayed action: %s success=%s", + record.action, + result.success, + ) + except Exception as e: + results.append(ActionResult( + success=False, + action=record.action, + error=f"Replay failed: {e}", + )) + logger.error( + "Replay failed for action %s: %s", + record.action, + e, + ) + + return results + + def save_recording(self, path: str | Path) -> None: + """保存录制到文件 + + Args: + path: 文件路径(JSON 格式) + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + + data = { + "version": "1.0", + "recorded_at": time.time(), + "total_actions": len(self._records), + "records": [r.to_dict() for r in self._records], + } + + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, ensure_ascii=False) + + logger.info("Recording saved to %s (%d actions)", path, len(self._records)) + + def load_recording(self, path: str | Path) -> None: + """从文件加载录制 + + Args: + path: 文件路径(JSON 格式) + + Raises: + FileNotFoundError: 文件不存在 + ValueError: 文件格式无效 + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Recording file not found: {path}") + + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + if "records" not in data or not isinstance(data["records"], list): + raise ValueError(f"Invalid recording format: {path}") + + self._records = [ActionRecord.from_dict(r) for r in data["records"]] + logger.info("Recording loaded from %s (%d actions)", path, len(self._records)) + + def clear(self) -> None: + """清空录制记录""" + self._records.clear() + + @property + def total_actions(self) -> int: + """总操作数""" + return len(self._records) + + @property + def success_count(self) -> int: + """成功操作数""" + return sum(1 for r in self._records if r.success) + + @property + def failure_count(self) -> int: + """失败操作数""" + return sum(1 for r in self._records if not r.success) + + def summary(self) -> dict[str, Any]: + """生成录制摘要""" + return { + "total_actions": self.total_actions, + "success_count": self.success_count, + "failure_count": self.failure_count, + "action_types": list({r.action for r in self._records}), + "duration": ( + self._records[-1].timestamp - self._records[0].timestamp + if len(self._records) >= 2 + else 0.0 + ), + } diff --git a/src/agentkit/tools/computer_use_session.py b/src/agentkit/tools/computer_use_session.py new file mode 100644 index 0000000..d185fcf --- /dev/null +++ b/src/agentkit/tools/computer_use_session.py @@ -0,0 +1,417 @@ +"""ComputerUseSession - 虚拟桌面会话管理 + +管理虚拟桌面会话(Docker 沙箱),维护操作上下文。 +提供 InMemoryComputerUseSession 用于测试,DockerComputerUseSession 用于生产。 +""" + +from __future__ import annotations + +import logging +import time +import uuid +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class ScreenInfo: + """屏幕信息""" + + width: int = 1280 + height: int = 720 + + +@dataclass +class ActionResult: + """操作执行结果""" + + success: bool + action: str + output: str = "" + screenshot_base64: str = "" + error: str = "" + metadata: dict[str, Any] = field(default_factory=dict) + + +class ComputerUseSession(ABC): + """虚拟桌面会话抽象基类 + + 管理虚拟桌面会话,维护操作上下文。 + 子类实现具体的会话管理(Docker 沙箱、内存模拟等)。 + """ + + def __init__( + self, + session_id: str | None = None, + screen_width: int = 1280, + screen_height: int = 720, + ): + self.session_id = session_id or str(uuid.uuid4()) + self.screen = ScreenInfo(width=screen_width, height=screen_height) + self._started = False + self._action_history: list[dict[str, Any]] = [] + + @property + def is_started(self) -> bool: + return self._started + + @abstractmethod + async def start(self) -> None: + """启动虚拟桌面会话""" + ... + + @abstractmethod + async def stop(self) -> None: + """停止虚拟桌面会话""" + ... + + @abstractmethod + async def screenshot(self) -> ActionResult: + """截取当前屏幕 + + Returns: + ActionResult 包含 screenshot_base64 和屏幕内容描述 + """ + ... + + @abstractmethod + async def execute_action(self, action: str, **params: Any) -> ActionResult: + """执行 UI 操作 + + Args: + action: 操作类型 (click/type/scroll/drag/key/wait) + **params: 操作参数 + + Returns: + ActionResult 包含操作结果 + """ + ... + + def record_action(self, action: str, params: dict[str, Any], result: ActionResult) -> None: + """记录操作历史""" + self._action_history.append({ + "timestamp": time.time(), + "action": action, + "params": params, + "success": result.success, + "output": result.output[:200] if result.output else "", + }) + + @property + def action_history(self) -> list[dict[str, Any]]: + """获取操作历史(副本)""" + return list(self._action_history) + + def __repr__(self) -> str: + status = "started" if self._started else "stopped" + return ( + f"<{type(self).__name__} id={self.session_id[:8]} " + f"screen={self.screen.width}x{self.screen.height} {status}>" + ) + + +class InMemoryComputerUseSession(ComputerUseSession): + """内存模拟会话,用于测试 + + 不连接真实虚拟桌面,模拟截屏和操作结果。 + """ + + def __init__( + self, + session_id: str | None = None, + screen_width: int = 1280, + screen_height: int = 720, + ): + super().__init__( + session_id=session_id, + screen_width=screen_width, + screen_height=screen_height, + ) + self._screen_state: dict[str, Any] = { + "focused_element": None, + "cursor_position": (0, 0), + "typed_text": "", + } + + async def start(self) -> None: + """启动内存会话""" + self._started = True + logger.info("InMemory session started: %s", self.session_id[:8]) + + async def stop(self) -> None: + """停止内存会话""" + self._started = False + logger.info("InMemory session stopped: %s", self.session_id[:8]) + + async def screenshot(self) -> ActionResult: + """模拟截屏,返回空截图和当前状态描述""" + if not self._started: + return ActionResult( + success=False, + action="screenshot", + error="Session not started", + ) + + state_desc = ( + f"Screen {self.screen.width}x{self.screen.height}, " + f"cursor at {self._screen_state['cursor_position']}, " + f"focused: {self._screen_state['focused_element']}, " + f"typed: '{self._screen_state['typed_text']}'" + ) + return ActionResult( + success=True, + action="screenshot", + output=state_desc, + screenshot_base64="", + metadata={"screen_state": dict(self._screen_state)}, + ) + + async def execute_action(self, action: str, **params: Any) -> ActionResult: + """模拟执行 UI 操作""" + if not self._started: + return ActionResult( + success=False, + action=action, + error="Session not started", + ) + + result = self._simulate_action(action, **params) + self.record_action(action, params, result) + return result + + def _simulate_action(self, action: str, **params: Any) -> ActionResult: + """模拟具体操作""" + if action == "click": + x = params.get("x", 0) + y = params.get("y", 0) + self._screen_state["cursor_position"] = (x, y) + self._screen_state["focused_element"] = f"element@({x},{y})" + return ActionResult( + success=True, + action="click", + output=f"Clicked at ({x}, {y})", + ) + + if action == "type": + text = params.get("text", "") + self._screen_state["typed_text"] += text + return ActionResult( + success=True, + action="type", + output=f"Typed: {text}", + ) + + if action == "scroll": + direction = params.get("direction", "down") + amount = params.get("amount", 3) + return ActionResult( + success=True, + action="scroll", + output=f"Scrolled {direction} by {amount}", + ) + + if action == "drag": + start_x = params.get("start_x", 0) + start_y = params.get("start_y", 0) + end_x = params.get("end_x", 0) + end_y = params.get("end_y", 0) + self._screen_state["cursor_position"] = (end_x, end_y) + return ActionResult( + success=True, + action="drag", + output=f"Dragged from ({start_x},{start_y}) to ({end_x},{end_y})", + ) + + if action == "key": + key_name = params.get("key_name", "") + return ActionResult( + success=True, + action="key", + output=f"Pressed key: {key_name}", + ) + + if action == "wait": + duration = params.get("duration", 1.0) + return ActionResult( + success=True, + action="wait", + output=f"Waited {duration}s", + ) + + return ActionResult( + success=False, + action=action, + error=f"Unknown action: {action}", + ) + + +class DockerComputerUseSession(ComputerUseSession): + """Docker 沙箱虚拟桌面会话 + + 通过 Docker 容器运行虚拟桌面,连接 Anthropic Computer Use API。 + 当前为 stub 实现,可后续接入真实 Docker 环境。 + """ + + def __init__( + self, + session_id: str | None = None, + screen_width: int = 1280, + screen_height: int = 720, + container_image: str = "anthropic/computer-use-demo:latest", + docker_url: str = "unix:///var/run/docker.sock", + ): + super().__init__( + session_id=session_id, + screen_width=screen_width, + screen_height=screen_height, + ) + self._container_image = container_image + self._docker_url = docker_url + self._container_id: str | None = None + + @property + def container_id(self) -> str | None: + return self._container_id + + async def start(self) -> None: + """启动 Docker 容器 + + Stub: 实际实现需要通过 Docker API 创建容器。 + """ + logger.info( + "Docker session start requested: %s (image=%s)", + self.session_id[:8], + self._container_image, + ) + # TODO: 实际 Docker 容器创建逻辑 + # async with docker.DockerClient(base_url=self._docker_url) as client: + # container = await client.containers.run( + # self._container_image, + # detach=True, + # environment={"SCREEN_WIDTH": self.screen.width, ...}, + # ) + # self._container_id = container.id + self._started = True + self._container_id = f"stub-{self.session_id[:8]}" + logger.info("Docker session started (stub): %s", self.session_id[:8]) + + async def stop(self) -> None: + """停止 Docker 容器""" + if self._container_id: + logger.info( + "Docker session stop requested: %s (container=%s)", + self.session_id[:8], + self._container_id[:12], + ) + # TODO: 实际 Docker 容器停止逻辑 + # async with docker.DockerClient(base_url=self._docker_url) as client: + # container = await client.containers.get(self._container_id) + # await container.stop() + # await container.remove() + self._started = False + self._container_id = None + logger.info("Docker session stopped: %s", self.session_id[:8]) + + async def screenshot(self) -> ActionResult: + """截取 Docker 虚拟桌面屏幕 + + Stub: 实际实现需要通过 VNC/HTTP 获取截图。 + """ + if not self._started: + return ActionResult( + success=False, + action="screenshot", + error="Session not started", + ) + + # TODO: 实际截屏逻辑 + return ActionResult( + success=True, + action="screenshot", + output=f"Screenshot from container {self._container_id}", + screenshot_base64="", + ) + + async def execute_action(self, action: str, **params: Any) -> ActionResult: + """在 Docker 虚拟桌面执行操作 + + Stub: 实际实现需要通过 Anthropic Computer Use API。 + """ + if not self._started: + return ActionResult( + success=False, + action=action, + error="Session not started", + ) + + # TODO: 实际操作执行逻辑 + result = ActionResult( + success=True, + action=action, + output=f"Executed {action} in container {self._container_id}", + ) + self.record_action(action, params, result) + return result + + +class ComputerUseSessionManager: + """Computer Use 会话管理器 + + 管理多个 ComputerUseSession 实例的生命周期。 + """ + + def __init__( + self, + max_sessions: int = 10, + session_factory: type[ComputerUseSession] | None = None, + ): + self._sessions: dict[str, ComputerUseSession] = {} + self._max_sessions = max_sessions + self._session_factory = session_factory or InMemoryComputerUseSession + + def get_or_create( + self, + session_id: str | None = None, + **kwargs: Any, + ) -> ComputerUseSession: + """获取或创建会话""" + if session_id and session_id in self._sessions: + return self._sessions[session_id] + + session = self._session_factory(session_id=session_id, **kwargs) + self._sessions[session.session_id] = session + + # 超过上限时移除最旧会话 + if len(self._sessions) > self._max_sessions: + oldest_id = next(iter(self._sessions)) + self.remove(oldest_id) + + return session + + def get(self, session_id: str) -> ComputerUseSession | None: + """获取指定会话""" + return self._sessions.get(session_id) + + def remove(self, session_id: str) -> None: + """移除会话""" + session = self._sessions.pop(session_id, None) + if session and session.is_started: + logger.warning("Removing started session: %s", session_id[:8]) + + def list_sessions(self) -> list[str]: + """列出所有会话 ID""" + return list(self._sessions.keys()) + + def has_session(self, session_id: str) -> bool: + """检查会话是否存在""" + return session_id in self._sessions + + async def close_all(self) -> None: + """关闭所有会话""" + for session in self._sessions.values(): + if session.is_started: + await session.stop() + self._sessions.clear() diff --git a/tests/unit/tools/test_computer_use.py b/tests/unit/tools/test_computer_use.py new file mode 100644 index 0000000..7608d2c --- /dev/null +++ b/tests/unit/tools/test_computer_use.py @@ -0,0 +1,1070 @@ +"""ComputerUseTool / ComputerUseSession / ComputerUseRecorder 单元测试 + +测试场景: +- 截屏并识别 UI 元素 → 返回可操作区域列表 +- 点击指定坐标 → 操作成功 +- 输入文本到输入框 → 操作成功 +- 多步骤 UI 操作 → 每步根据结果决定下一步 +- API 不可用时降级到 ShellTool → 正确降级 +- Covers AE2: Computer Use 失败 → 降级到 OA 系统 API +- 操作录制回放 → 可回放操作序列 +""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.tools.computer_use import ComputerUseTool +from agentkit.tools.computer_use_session import ( + ComputerUseSession, + InMemoryComputerUseSession, + DockerComputerUseSession, + ComputerUseSessionManager, + ActionResult, + ScreenInfo, +) +from agentkit.tools.computer_use_recorder import ComputerUseRecorder, ActionRecord + + +# ============================================================ +# ActionResult 测试 +# ============================================================ + + +class TestActionResult: + """测试 ActionResult 数据类""" + + def test_success_result(self): + result = ActionResult(success=True, action="click", output="Clicked at (100, 200)") + assert result.success is True + assert result.action == "click" + assert result.output == "Clicked at (100, 200)" + assert result.error == "" + assert result.screenshot_base64 == "" + + def test_failure_result(self): + result = ActionResult(success=False, action="click", error="Session not started") + assert result.success is False + assert result.error == "Session not started" + + def test_result_with_metadata(self): + result = ActionResult( + success=True, + action="screenshot", + metadata={"screen_state": {"cursor": (0, 0)}}, + ) + assert result.metadata["screen_state"]["cursor"] == (0, 0) + + +# ============================================================ +# ScreenInfo 测试 +# ============================================================ + + +class TestScreenInfo: + """测试 ScreenInfo 数据类""" + + def test_default_screen(self): + screen = ScreenInfo() + assert screen.width == 1280 + assert screen.height == 720 + + def test_custom_screen(self): + screen = ScreenInfo(width=1920, height=1080) + assert screen.width == 1920 + assert screen.height == 1080 + + +# ============================================================ +# InMemoryComputerUseSession 测试 +# ============================================================ + + +class TestInMemoryComputerUseSession: + """测试 InMemoryComputerUseSession 内存模拟会话""" + + def test_construction_default(self): + session = InMemoryComputerUseSession() + assert session.session_id is not None + assert session.screen.width == 1280 + assert session.screen.height == 720 + assert session.is_started is False + + def test_construction_custom(self): + session = InMemoryComputerUseSession( + session_id="test-123", + screen_width=1920, + screen_height=1080, + ) + assert session.session_id == "test-123" + assert session.screen.width == 1920 + + @pytest.mark.asyncio + async def test_start_stop(self): + session = InMemoryComputerUseSession() + assert session.is_started is False + await session.start() + assert session.is_started is True + await session.stop() + assert session.is_started is False + + @pytest.mark.asyncio + async def test_screenshot_not_started(self): + """未启动时截屏失败""" + session = InMemoryComputerUseSession() + result = await session.screenshot() + assert result.success is False + assert "not started" in result.error + + @pytest.mark.asyncio + async def test_screenshot_started(self): + """启动后截屏成功""" + session = InMemoryComputerUseSession() + await session.start() + result = await session.screenshot() + assert result.success is True + assert result.action == "screenshot" + assert "1280x720" in result.output + + @pytest.mark.asyncio + async def test_click_action(self): + """点击操作""" + session = InMemoryComputerUseSession() + await session.start() + result = await session.execute_action("click", x=100, y=200) + assert result.success is True + assert "(100, 200)" in result.output + + @pytest.mark.asyncio + async def test_type_action(self): + """输入文本操作""" + session = InMemoryComputerUseSession() + await session.start() + result = await session.execute_action("type", text="hello world") + assert result.success is True + assert "hello world" in result.output + + @pytest.mark.asyncio + async def test_scroll_action(self): + """滚动操作""" + session = InMemoryComputerUseSession() + await session.start() + result = await session.execute_action("scroll", direction="down", amount=5) + assert result.success is True + assert "down" in result.output + assert "5" in result.output + + @pytest.mark.asyncio + async def test_drag_action(self): + """拖拽操作""" + session = InMemoryComputerUseSession() + await session.start() + result = await session.execute_action( + "drag", start_x=10, start_y=20, end_x=100, end_y=200 + ) + assert result.success is True + assert "(10,20)" in result.output + assert "(100,200)" in result.output + + @pytest.mark.asyncio + async def test_key_action(self): + """按键操作""" + session = InMemoryComputerUseSession() + await session.start() + result = await session.execute_action("key", key_name="Enter") + assert result.success is True + assert "Enter" in result.output + + @pytest.mark.asyncio + async def test_wait_action(self): + """等待操作""" + session = InMemoryComputerUseSession() + await session.start() + result = await session.execute_action("wait", duration=2.0) + assert result.success is True + assert "2.0" in result.output + + @pytest.mark.asyncio + async def test_unknown_action(self): + """未知操作类型""" + session = InMemoryComputerUseSession() + await session.start() + result = await session.execute_action("unknown_action") + assert result.success is False + assert "Unknown" in result.error + + @pytest.mark.asyncio + async def test_action_not_started(self): + """未启动时执行操作失败""" + session = InMemoryComputerUseSession() + result = await session.execute_action("click", x=0, y=0) + assert result.success is False + assert "not started" in result.error + + @pytest.mark.asyncio + async def test_action_history(self): + """操作历史记录""" + session = InMemoryComputerUseSession() + await session.start() + await session.execute_action("click", x=10, y=20) + await session.execute_action("type", text="test") + assert len(session.action_history) == 2 + assert session.action_history[0]["action"] == "click" + assert session.action_history[1]["action"] == "type" + + @pytest.mark.asyncio + async def test_action_history_is_copy(self): + """操作历史返回副本""" + session = InMemoryComputerUseSession() + await session.start() + await session.execute_action("click", x=0, y=0) + history = session.action_history + history.clear() + assert len(session.action_history) == 1 + + def test_repr(self): + session = InMemoryComputerUseSession(session_id="test-123") + r = repr(session) + assert "InMemory" in r + assert "stopped" in r + + +# ============================================================ +# DockerComputerUseSession 测试 +# ============================================================ + + +class TestDockerComputerUseSession: + """测试 DockerComputerUseSession(stub 实现)""" + + def test_construction(self): + session = DockerComputerUseSession( + session_id="docker-1", + container_image="anthropic/computer-use-demo:latest", + ) + assert session.session_id == "docker-1" + assert session.container_id is None + + @pytest.mark.asyncio + async def test_start_stop(self): + session = DockerComputerUseSession(session_id="docker-1") + await session.start() + assert session.is_started is True + assert session.container_id is not None + await session.stop() + assert session.is_started is False + assert session.container_id is None + + @pytest.mark.asyncio + async def test_screenshot_not_started(self): + session = DockerComputerUseSession(session_id="docker-1") + result = await session.screenshot() + assert result.success is False + + @pytest.mark.asyncio + async def test_screenshot_started(self): + session = DockerComputerUseSession(session_id="docker-1") + await session.start() + result = await session.screenshot() + assert result.success is True + + @pytest.mark.asyncio + async def test_execute_action_not_started(self): + session = DockerComputerUseSession(session_id="docker-1") + result = await session.execute_action("click", x=0, y=0) + assert result.success is False + + @pytest.mark.asyncio + async def test_execute_action_started(self): + session = DockerComputerUseSession(session_id="docker-1") + await session.start() + result = await session.execute_action("click", x=100, y=200) + assert result.success is True + + +# ============================================================ +# ComputerUseSessionManager 测试 +# ============================================================ + + +class TestComputerUseSessionManager: + """测试 ComputerUseSessionManager 会话管理""" + + def test_get_or_create_new(self): + manager = ComputerUseSessionManager() + session = manager.get_or_create("s1") + assert session.session_id == "s1" + + def test_get_or_create_existing(self): + manager = ComputerUseSessionManager() + s1 = manager.get_or_create("s1") + s2 = manager.get_or_create("s1") + assert s1 is s2 + + def test_get_existing(self): + manager = ComputerUseSessionManager() + manager.get_or_create("s1") + session = manager.get("s1") + assert session is not None + + def test_get_nonexistent(self): + manager = ComputerUseSessionManager() + assert manager.get("nonexistent") is None + + def test_remove(self): + manager = ComputerUseSessionManager() + manager.get_or_create("s1") + manager.remove("s1") + assert manager.get("s1") is None + + def test_list_sessions(self): + manager = ComputerUseSessionManager() + manager.get_or_create("s1") + manager.get_or_create("s2") + assert sorted(manager.list_sessions()) == ["s1", "s2"] + + def test_has_session(self): + manager = ComputerUseSessionManager() + manager.get_or_create("s1") + assert manager.has_session("s1") is True + assert manager.has_session("s2") is False + + def test_max_sessions_eviction(self): + manager = ComputerUseSessionManager(max_sessions=2) + manager.get_or_create("s1") + manager.get_or_create("s2") + manager.get_or_create("s3") + assert not manager.has_session("s1") + assert manager.has_session("s2") + assert manager.has_session("s3") + + @pytest.mark.asyncio + async def test_close_all(self): + manager = ComputerUseSessionManager() + s1 = manager.get_or_create("s1") + s2 = manager.get_or_create("s2") + await s1.start() + await s2.start() + await manager.close_all() + assert manager.list_sessions() == [] + assert s1.is_started is False + assert s2.is_started is False + + def test_custom_session_factory(self): + manager = ComputerUseSessionManager( + session_factory=DockerComputerUseSession, + ) + session = manager.get_or_create("docker-1") + assert isinstance(session, DockerComputerUseSession) + + +# ============================================================ +# ComputerUseRecorder 测试 +# ============================================================ + + +class TestActionRecord: + """测试 ActionRecord 数据类""" + + def test_to_dict(self): + record = ActionRecord( + timestamp=1000.0, + action="click", + params={"x": 100, "y": 200}, + success=True, + output="Clicked at (100, 200)", + ) + d = record.to_dict() + assert d["action"] == "click" + assert d["params"]["x"] == 100 + assert d["success"] is True + + def test_from_dict(self): + data = { + "timestamp": 1000.0, + "action": "type", + "params": {"text": "hello"}, + "success": True, + "output": "Typed: hello", + "error": "", + "screenshot_path": "", + } + record = ActionRecord.from_dict(data) + assert record.action == "type" + assert record.params["text"] == "hello" + + def test_roundtrip(self): + record = ActionRecord( + timestamp=1000.0, + action="click", + params={"x": 50, "y": 60}, + success=False, + error="Timeout", + ) + d = record.to_dict() + restored = ActionRecord.from_dict(d) + assert restored.action == record.action + assert restored.success == record.success + assert restored.error == record.error + + +class TestComputerUseRecorder: + """测试 ComputerUseRecorder 操作录制器""" + + def test_record_action(self): + recorder = ComputerUseRecorder() + result = ActionResult(success=True, action="click", output="Clicked") + record = recorder.record("click", {"x": 100, "y": 200}, result) + assert record.action == "click" + assert record.success is True + + def test_get_records(self): + recorder = ComputerUseRecorder() + r1 = ActionResult(success=True, action="click", output="ok") + r2 = ActionResult(success=True, action="type", output="ok") + recorder.record("click", {"x": 1, "y": 2}, r1) + recorder.record("type", {"text": "hi"}, r2) + records = recorder.get_records() + assert len(records) == 2 + + def test_get_records_is_copy(self): + recorder = ComputerUseRecorder() + result = ActionResult(success=True, action="click", output="ok") + recorder.record("click", {"x": 1, "y": 2}, result) + records = recorder.get_records() + records.clear() + assert len(recorder.get_records()) == 1 + + def test_get_records_by_action(self): + recorder = ComputerUseRecorder() + r1 = ActionResult(success=True, action="click", output="ok") + r2 = ActionResult(success=True, action="type", output="ok") + r3 = ActionResult(success=True, action="click", output="ok") + recorder.record("click", {"x": 1, "y": 2}, r1) + recorder.record("type", {"text": "hi"}, r2) + recorder.record("click", {"x": 3, "y": 4}, r3) + click_records = recorder.get_records_by_action("click") + assert len(click_records) == 2 + + def test_get_failed_records(self): + recorder = ComputerUseRecorder() + r1 = ActionResult(success=True, action="click", output="ok") + r2 = ActionResult(success=False, action="type", error="failed") + recorder.record("click", {"x": 1, "y": 2}, r1) + recorder.record("type", {"text": "hi"}, r2) + failed = recorder.get_failed_records() + assert len(failed) == 1 + assert failed[0].action == "type" + + def test_total_actions(self): + recorder = ComputerUseRecorder() + assert recorder.total_actions == 0 + result = ActionResult(success=True, action="click", output="ok") + recorder.record("click", {"x": 1, "y": 2}, result) + assert recorder.total_actions == 1 + + def test_success_failure_counts(self): + recorder = ComputerUseRecorder() + r1 = ActionResult(success=True, action="click", output="ok") + r2 = ActionResult(success=False, action="type", error="fail") + r3 = ActionResult(success=True, action="scroll", output="ok") + recorder.record("click", {}, r1) + recorder.record("type", {}, r2) + recorder.record("scroll", {}, r3) + assert recorder.success_count == 2 + assert recorder.failure_count == 1 + + def test_summary(self): + recorder = ComputerUseRecorder() + r1 = ActionResult(success=True, action="click", output="ok") + r2 = ActionResult(success=False, action="type", error="fail") + recorder.record("click", {}, r1) + recorder.record("type", {}, r2) + s = recorder.summary() + assert s["total_actions"] == 2 + assert s["success_count"] == 1 + assert s["failure_count"] == 1 + assert "click" in s["action_types"] + assert "type" in s["action_types"] + + def test_clear(self): + recorder = ComputerUseRecorder() + result = ActionResult(success=True, action="click", output="ok") + recorder.record("click", {}, result) + recorder.clear() + assert recorder.total_actions == 0 + + def test_save_and_load_recording(self): + recorder = ComputerUseRecorder() + r1 = ActionResult(success=True, action="click", output="Clicked at (10, 20)") + r2 = ActionResult(success=True, action="type", output="Typed: hello") + recorder.record("click", {"x": 10, "y": 20}, r1) + recorder.record("type", {"text": "hello"}, r2) + + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f: + path = f.name + + try: + recorder.save_recording(path) + + # 加载到新录制器 + recorder2 = ComputerUseRecorder() + recorder2.load_recording(path) + assert recorder2.total_actions == 2 + records = recorder2.get_records() + assert records[0].action == "click" + assert records[1].action == "type" + finally: + Path(path).unlink(missing_ok=True) + + def test_load_nonexistent_file(self): + recorder = ComputerUseRecorder() + with pytest.raises(FileNotFoundError): + recorder.load_recording("/nonexistent/path/recording.json") + + def test_load_invalid_format(self): + recorder = ComputerUseRecorder() + with tempfile.NamedTemporaryFile( + suffix=".json", delete=False, mode="w" + ) as f: + json.dump({"invalid": True}, f) + path = f.name + + try: + with pytest.raises(ValueError, match="Invalid recording format"): + recorder.load_recording(path) + finally: + Path(path).unlink(missing_ok=True) + + @pytest.mark.asyncio + async def test_replay(self): + """回放操作序列""" + recorder = ComputerUseRecorder() + r1 = ActionResult(success=True, action="click", output="Clicked") + r2 = ActionResult(success=True, action="type", output="Typed") + recorder.record("click", {"x": 10, "y": 20}, r1) + recorder.record("type", {"text": "hello"}, r2) + + session = InMemoryComputerUseSession(session_id="replay-test") + results = await recorder.replay(session) + assert len(results) == 2 + assert results[0].success is True + assert results[1].success is True + assert session.is_started is True + + @pytest.mark.asyncio + async def test_replay_with_screenshot(self): + """回放包含截屏的操作序列""" + recorder = ComputerUseRecorder() + r1 = ActionResult(success=True, action="screenshot", output="Screen captured") + recorder.record("screenshot", {}, r1) + + session = InMemoryComputerUseSession(session_id="replay-screenshot") + results = await recorder.replay(session) + assert len(results) == 1 + assert results[0].success is True + + +# ============================================================ +# ComputerUseTool 构造测试 +# ============================================================ + + +class TestComputerUseToolConstruction: + """测试 ComputerUseTool 构造""" + + def test_default_construction(self): + tool = ComputerUseTool() + assert tool.name == "computer_use" + assert tool.input_schema is not None + assert "action" in tool.input_schema["properties"] + assert tool.input_schema["required"] == ["action"] + + def test_custom_construction(self): + tool = ComputerUseTool(name="my_cu", version="2.0.0") + assert tool.name == "my_cu" + assert tool.version == "2.0.0" + + def test_to_dict(self): + tool = ComputerUseTool() + d = tool.to_dict() + assert d["name"] == "computer_use" + assert "input_schema" in d + + def test_repr(self): + tool = ComputerUseTool() + r = repr(tool) + assert "ComputerUseTool" in r + assert "computer_use" in r + + def test_session_manager_accessible(self): + tool = ComputerUseTool() + assert tool.session_manager is not None + + def test_recorder_accessible(self): + tool = ComputerUseTool() + assert tool.recorder is not None + + def test_custom_recorder(self): + recorder = ComputerUseRecorder() + tool = ComputerUseTool(recorder=recorder) + assert tool.recorder is recorder + + +# ============================================================ +# ComputerUseTool 执行测试 +# ============================================================ + + +class TestComputerUseToolExecution: + """测试 ComputerUseTool 操作执行""" + + @pytest.mark.asyncio + async def test_screenshot_action(self): + """截屏并识别 UI 元素""" + tool = ComputerUseTool() + result = await tool.execute(action="screenshot") + assert result["success"] is True + assert result["action"] == "screenshot" + assert "output" in result + + @pytest.mark.asyncio + async def test_click_action(self): + """点击指定坐标""" + tool = ComputerUseTool() + result = await tool.execute(action="click", x=100, y=200) + assert result["success"] is True + assert result["action"] == "click" + assert "(100, 200)" in result["output"] + + @pytest.mark.asyncio + async def test_type_action(self): + """输入文本到输入框""" + tool = ComputerUseTool() + result = await tool.execute(action="type", text="hello world") + assert result["success"] is True + assert result["action"] == "type" + assert "hello world" in result["output"] + + @pytest.mark.asyncio + async def test_scroll_action(self): + """滚动操作""" + tool = ComputerUseTool() + result = await tool.execute(action="scroll", direction="down", amount=5) + assert result["success"] is True + assert "down" in result["output"] + + @pytest.mark.asyncio + async def test_drag_action(self): + """拖拽操作""" + tool = ComputerUseTool() + result = await tool.execute( + action="drag", + start_x=10, start_y=20, + end_x=100, end_y=200, + ) + assert result["success"] is True + assert result["action"] == "drag" + + @pytest.mark.asyncio + async def test_key_action(self): + """按键操作""" + tool = ComputerUseTool() + result = await tool.execute(action="key", key_name="Enter") + assert result["success"] is True + assert "Enter" in result["output"] + + @pytest.mark.asyncio + async def test_wait_action(self): + """等待操作""" + tool = ComputerUseTool() + result = await tool.execute(action="wait", duration=0.1) + assert result["success"] is True + + @pytest.mark.asyncio + async def test_missing_action(self): + """缺少 action 参数""" + tool = ComputerUseTool() + result = await tool.execute() + assert result["success"] is False + assert "action" in result["error"] + + @pytest.mark.asyncio + async def test_invalid_action(self): + """无效操作类型""" + tool = ComputerUseTool() + result = await tool.execute(action="invalid_action") + assert result["success"] is False + assert "无效" in result["error"] + + @pytest.mark.asyncio + async def test_click_missing_coordinates(self): + """click 缺少坐标参数""" + tool = ComputerUseTool() + result = await tool.execute(action="click") + assert result["success"] is False + assert "x" in result["error"] + + @pytest.mark.asyncio + async def test_type_missing_text(self): + """type 缺少文本参数""" + tool = ComputerUseTool() + result = await tool.execute(action="type") + assert result["success"] is False + assert "text" in result["error"] + + @pytest.mark.asyncio + async def test_key_missing_key_name(self): + """key 缺少按键名称参数""" + tool = ComputerUseTool() + result = await tool.execute(action="key") + assert result["success"] is False + assert "key_name" in result["error"] + + @pytest.mark.asyncio + async def test_drag_missing_params(self): + """drag 缺少参数""" + tool = ComputerUseTool() + result = await tool.execute(action="drag", start_x=0, start_y=0) + assert result["success"] is False + assert "end_x" in result["error"] + + @pytest.mark.asyncio + async def test_session_id(self): + """指定会话 ID""" + tool = ComputerUseTool() + result = await tool.execute(action="screenshot", session_id="my-session") + assert result["success"] is True + assert result["session_id"] == "my-session" + + @pytest.mark.asyncio + async def test_session_reuse(self): + """同一会话 ID 复用会话""" + tool = ComputerUseTool() + r1 = await tool.execute(action="click", x=10, y=20, session_id="reuse-test") + r2 = await tool.execute(action="type", text="hello", session_id="reuse-test") + assert r1["session_id"] == "reuse-test" + assert r2["session_id"] == "reuse-test" + assert tool.session_manager.has_session("reuse-test") + + +# ============================================================ +# ComputerUseTool 多步骤 UI 操作测试 +# ============================================================ + + +class TestComputerUseToolMultiStep: + """测试多步骤 UI 操作""" + + @pytest.mark.asyncio + async def test_multi_step_ui_operation(self): + """多步骤 UI 操作:截屏→点击→输入→按键""" + tool = ComputerUseTool() + session_id = "multi-step" + + # Step 1: 截屏识别 + r1 = await tool.execute(action="screenshot", session_id=session_id) + assert r1["success"] is True + + # Step 2: 点击输入框 + r2 = await tool.execute(action="click", x=100, y=200, session_id=session_id) + assert r2["success"] is True + + # Step 3: 输入文本 + r3 = await tool.execute(action="type", text="test input", session_id=session_id) + assert r3["success"] is True + + # Step 4: 按回车提交 + r4 = await tool.execute(action="key", key_name="Enter", session_id=session_id) + assert r4["success"] is True + + # 验证录制 + assert tool.recorder.total_actions == 4 + assert tool.recorder.success_count == 4 + + @pytest.mark.asyncio + async def test_each_step_informs_next(self): + """每步根据结果决定下一步""" + tool = ComputerUseTool() + session_id = "adaptive" + + # Step 1: 截屏 + r1 = await tool.execute(action="screenshot", session_id=session_id) + assert r1["success"] is True + + # 根据截屏结果决定下一步(模拟决策逻辑) + if r1["success"]: + r2 = await tool.execute(action="click", x=50, y=50, session_id=session_id) + else: + r2 = await tool.execute(action="wait", duration=1.0, session_id=session_id) + + assert r2["success"] is True + + +# ============================================================ +# ComputerUseTool 降级测试 +# ============================================================ + + +class TestComputerUseToolFallback: + """测试 ComputerUseTool 降级链""" + + @pytest.mark.asyncio + async def test_fallback_without_api_key(self): + """无 API Key 时降级到 Session 本地执行""" + tool = ComputerUseTool() # 无 api_key + result = await tool.execute(action="click", x=100, y=200) + assert result["success"] is True # InMemory session 可以执行 + + @pytest.mark.asyncio + async def test_api_failure_fallback_to_session(self): + """API 调用失败时降级到 Session 本地执行""" + tool = ComputerUseTool(api_key="sk-test-key") + + with patch.object( + tool, "_call_anthropic_api", + new_callable=AsyncMock, + side_effect=Exception("API connection failed"), + ): + result = await tool.execute(action="click", x=100, y=200) + assert result["success"] is True # 降级到 InMemory session + + @pytest.mark.asyncio + async def test_api_and_session_failure_fallback_suggestion(self): + """API 和 Session 都失败时返回降级建议""" + tool = ComputerUseTool(api_key="sk-test-key") + + # Mock API 失败 + with patch.object( + tool, "_call_anthropic_api", + new_callable=AsyncMock, + return_value=ActionResult(success=False, action="click", error="API error"), + ): + # Mock Session 也失败 + mock_session = AsyncMock(spec=ComputerUseSession) + mock_session.session_id = "fallback-test" + mock_session.screen = ScreenInfo() + mock_session.is_started = True + mock_session.screenshot.return_value = ActionResult( + success=True, action="screenshot", screenshot_base64="" + ) + mock_session.execute_action.return_value = ActionResult( + success=False, action="click", error="Session error" + ) + + # 直接注入 mock session + tool._session_manager._sessions["fallback-test"] = mock_session + + result = await tool.execute( + action="click", x=100, y=200, session_id="fallback-test" + ) + assert result["success"] is False + assert "fallback_suggestion" in result + + @pytest.mark.asyncio + async def test_custom_fallback_callback(self): + """自定义降级回调""" + async def fallback(action: str, params: dict) -> dict: + return {"success": True, "output": f"Fallback executed: {action}"} + + tool = ComputerUseTool(fallback_callback=fallback) + + # Mock API 和 Session 都失败 + with patch.object( + tool, "_call_anthropic_api", + new_callable=AsyncMock, + side_effect=Exception("API failed"), + ): + mock_session = AsyncMock(spec=ComputerUseSession) + mock_session.session_id = "cb-test" + mock_session.screen = ScreenInfo() + mock_session.is_started = True + mock_session.screenshot.return_value = ActionResult( + success=True, action="screenshot", screenshot_base64="" + ) + mock_session.execute_action.return_value = ActionResult( + success=False, action="click", error="Session error" + ) + tool._session_manager._sessions["cb-test"] = mock_session + + result = await tool.execute( + action="click", x=100, y=200, session_id="cb-test" + ) + assert result["success"] is True + assert "Fallback" in result["output"] + + @pytest.mark.asyncio + async def test_ae2_computer_use_fallback_to_oa_api(self): + """AE2: Computer Use 失败 → 降级到 OA 系统 API + + 模拟场景:Computer Use 无法操作 OA 系统 UI, + 降级到 OA 系统 API 完成操作。 + """ + oa_api_called = False + + async def oa_api_fallback(action: str, params: dict) -> dict: + nonlocal oa_api_called + oa_api_called = True + return { + "success": True, + "output": f"OA API completed: {action} with {params}", + } + + tool = ComputerUseTool( + api_key="sk-test-key", + fallback_callback=oa_api_fallback, + ) + + # Mock API 和 Session 都失败 + with patch.object( + tool, "_call_anthropic_api", + new_callable=AsyncMock, + side_effect=Exception("API unavailable"), + ): + mock_session = AsyncMock(spec=ComputerUseSession) + mock_session.session_id = "oa-test" + mock_session.screen = ScreenInfo() + mock_session.is_started = True + mock_session.screenshot.return_value = ActionResult( + success=True, action="screenshot", screenshot_base64="" + ) + mock_session.execute_action.return_value = ActionResult( + success=False, action="click", error="UI not accessible" + ) + tool._session_manager._sessions["oa-test"] = mock_session + + result = await tool.execute( + action="click", x=100, y=200, session_id="oa-test" + ) + assert result["success"] is True + assert oa_api_called is True + + +# ============================================================ +# ComputerUseTool 录制集成测试 +# ============================================================ + + +class TestComputerUseToolRecording: + """测试 ComputerUseTool 与 Recorder 的集成""" + + @pytest.mark.asyncio + async def test_actions_recorded(self): + """操作自动录制""" + tool = ComputerUseTool() + await tool.execute(action="click", x=10, y=20) + await tool.execute(action="type", text="hello") + assert tool.recorder.total_actions == 2 + + @pytest.mark.asyncio + async def test_recording_save_and_replay(self): + """录制保存和回放""" + tool = ComputerUseTool() + await tool.execute(action="click", x=10, y=20, session_id="rec-1") + await tool.execute(action="type", text="hello", session_id="rec-1") + + with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f: + path = f.name + + try: + tool.recorder.save_recording(path) + + # 加载到新录制器并回放 + recorder2 = ComputerUseRecorder() + recorder2.load_recording(path) + assert recorder2.total_actions == 2 + + session = InMemoryComputerUseSession(session_id="replay-1") + results = await recorder2.replay(session) + assert len(results) == 2 + assert all(r.success for r in results) + finally: + Path(path).unlink(missing_ok=True) + + @pytest.mark.asyncio + async def test_recording_summary(self): + """录制摘要""" + tool = ComputerUseTool() + await tool.execute(action="click", x=10, y=20) + await tool.execute(action="type", text="hello") + summary = tool.recorder.summary() + assert summary["total_actions"] == 2 + assert summary["success_count"] == 2 + assert summary["failure_count"] == 0 + + +# ============================================================ +# ComputerUseTool API 调用测试(Mock httpx) +# ============================================================ + + +class TestComputerUseToolAPICall: + """测试 ComputerUseTool Anthropic API 调用(Mock)""" + + @pytest.mark.asyncio + async def test_api_call_success(self): + """API 调用成功""" + tool = ComputerUseTool(api_key="sk-test-key") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "content": [ + { + "type": "tool_use", + "name": "computer", + "input": {"action": "click"}, + } + ] + } + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + result = await tool.execute(action="click", x=100, y=200) + assert result["success"] is True + + @pytest.mark.asyncio + async def test_api_call_http_error(self): + """API 调用 HTTP 错误""" + tool = ComputerUseTool(api_key="sk-test-key") + + mock_response = MagicMock() + mock_response.status_code = 429 + mock_response.text = "Rate limited" + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + # API 返回错误,降级到 session 本地执行 + result = await tool.execute(action="click", x=100, y=200) + assert result["success"] is True # 降级成功 + + @pytest.mark.asyncio + async def test_api_call_network_error(self): + """API 调用网络错误""" + tool = ComputerUseTool(api_key="sk-test-key") + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.side_effect = Exception("Connection refused") + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + # 网络错误,降级到 session 本地执行 + result = await tool.execute(action="click", x=100, y=200) + assert result["success"] is True # 降级成功