fischer-agentkit/src/agentkit/tools/computer_use_recorder.py

236 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ComputerUseRecorder - Computer Use 操作录制与回放
记录每次截屏和操作,支持回放和审核。
支持将录制结果持久化到文件,以及从文件加载回放。
"""
from __future__ import annotations

import json
import logging
import time
from dataclasses import asdict, dataclass, field, fields
from pathlib import Path
from typing import Any

from agentkit.tools.computer_use_session import ComputerUseSession, ActionResult
# Module logger: the recorder logs recorded actions (debug), replay outcomes
# (info/error) and save/load events (info) through it.
logger = logging.getLogger(__name__)
@dataclass
class ActionRecord:
    """A single recorded Computer Use action.

    Holds everything needed to audit or replay one action: when it was
    recorded, what was requested, and how it turned out.
    """

    # Creation time as a float (the recorder fills this from time.time()).
    timestamp: float
    # Action name, e.g. "click" or "screenshot".
    action: str
    # Keyword arguments the action was invoked with.
    params: dict[str, Any] = field(default_factory=dict)
    # Whether the action reported success.
    success: bool = False
    # Output text of the action (may be truncated by the recorder).
    output: str = ""
    # Error text of the action; empty when there was none.
    error: str = ""
    # Path of the associated screenshot; empty when there is none.
    screenshot_path: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return all fields as a dict (recursively copied by ``asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ActionRecord:
        """Build a record from a dict such as one produced by :meth:`to_dict`.

        Keys that are not fields of this class (for example ones added by a
        newer recording format) are ignored, so such data still loads.
        ``params`` is shallow-copied so the record does not alias the
        caller's dict.

        Args:
            data: Mapping of field names to values.

        Returns:
            A new ActionRecord.

        Raises:
            TypeError: If a required field (``timestamp`` or ``action``)
                is missing.
        """
        known = {f.name for f in fields(cls)}
        kwargs = {key: value for key, value in data.items() if key in known}
        if isinstance(kwargs.get("params"), dict):
            # Avoid sharing the input dict with the new record.
            kwargs["params"] = dict(kwargs["params"])
        return cls(**kwargs)
class ComputerUseRecorder:
    """Recorder for Computer Use actions.

    Keeps an ordered in-memory log of actions (with the screenshot path
    associated with each one) so they can be inspected, audited, replayed
    against a session, or saved to and loaded from a JSON file.

    Usage:
        recorder = ComputerUseRecorder()
        # Record
        recorder.record("click", {"x": 100, "y": 200}, result)
        # Inspect
        records = recorder.get_records()
        # Replay
        await recorder.replay(session)
        # Persist
        recorder.save_recording("recording.json")
    """

    # Maximum number of characters of output / error text kept per record,
    # so a chatty action cannot bloat the recording.
    MAX_TEXT_LENGTH = 500

    def __init__(self, screenshot_dir: str | Path | None = None):
        """Create an empty recorder.

        Args:
            screenshot_dir: Optional screenshot directory. It is stored on
                the instance but not otherwise used by this class.
        """
        self._records: list[ActionRecord] = []
        self._screenshot_dir = Path(screenshot_dir) if screenshot_dir else None

    def record(
        self,
        action: str,
        params: dict[str, Any],
        result: ActionResult,
        screenshot_path: str = "",
    ) -> ActionRecord:
        """Append one action to the log.

        Args:
            action: Action type, e.g. "click".
            params: Parameters the action was invoked with (shallow-copied).
            result: Outcome of the action; its ``success``, ``output`` and
                ``error`` attributes are read.
            screenshot_path: Path of the related screenshot, if any.

        Returns:
            The newly created ActionRecord.
        """
        limit = self.MAX_TEXT_LENGTH
        entry = ActionRecord(
            timestamp=time.time(),
            action=action,
            params=dict(params),
            success=result.success,
            # Truncate long text so one action cannot dominate the log.
            output=result.output[:limit] if result.output else "",
            error=result.error[:limit] if result.error else "",
            screenshot_path=screenshot_path,
        )
        self._records.append(entry)
        logger.debug(
            "Recorded action: %s success=%s",
            action,
            result.success,
        )
        return entry

    def get_records(self) -> list[ActionRecord]:
        """Return all records as a new list (the records themselves are shared)."""
        return list(self._records)

    def get_records_by_action(self, action: str) -> list[ActionRecord]:
        """Return the records whose action type equals ``action``."""
        return [r for r in self._records if r.action == action]

    def get_failed_records(self) -> list[ActionRecord]:
        """Return the records whose action did not succeed."""
        return [r for r in self._records if not r.success]

    async def replay(self, session: ComputerUseSession) -> list[ActionResult]:
        """Re-execute every recorded action against ``session``.

        Actions run in stored order (recording order, or file order after
        ``load_recording``). The session is started first if it is not
        already. A failing action does not stop the replay: its exception
        is logged and turned into a failed ``ActionResult``.

        Args:
            session: Session to replay into.

        Returns:
            One ``ActionResult`` per record, in the same order.
        """
        results: list[ActionResult] = []
        if not session.is_started:
            await session.start()
        for record in self._records:
            try:
                if record.action == "screenshot":
                    # Screenshots go through the dedicated session API.
                    result = await session.screenshot()
                else:
                    result = await session.execute_action(
                        record.action, **record.params
                    )
            except Exception as e:  # best-effort: keep replaying the rest
                results.append(ActionResult(
                    success=False,
                    action=record.action,
                    error=f"Replay failed: {e}",
                ))
                logger.error(
                    "Replay failed for action %s: %s",
                    record.action,
                    e,
                    exc_info=True,
                )
            else:
                # Kept outside the try so a later failure cannot append a
                # second result for the same record.
                results.append(result)
                logger.info(
                    "Replayed action: %s success=%s",
                    record.action,
                    result.success,
                )
        return results

    def save_recording(self, path: str | Path) -> None:
        """Save the recording to a JSON file.

        The payload is serialized before the file is touched, so a record
        that cannot be serialized raises without truncating an existing
        file at ``path``.

        Args:
            path: Destination file path (JSON format). Missing parent
                directories are created.

        Raises:
            TypeError: If a record contains values ``json`` cannot
                serialize.
        """
        path = Path(path)
        payload = {
            "version": "1.0",
            "recorded_at": time.time(),
            "total_actions": len(self._records),
            "records": [r.to_dict() for r in self._records],
        }
        # Serialize first: a failure here must not clobber an existing file.
        text = json.dumps(payload, indent=2, ensure_ascii=False)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(text, encoding="utf-8")
        logger.info("Recording saved to %s (%d actions)", path, len(self._records))

    def load_recording(self, path: str | Path) -> None:
        """Replace the current records with those stored in a JSON file.

        The current records are only replaced once the whole file has been
        parsed successfully.

        Args:
            path: Source file path (JSON format).

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is not valid JSON or does not have the
                expected recording structure.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Recording file not found: {path}")
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        raw = data.get("records") if isinstance(data, dict) else None
        if not isinstance(raw, list) or not all(
            isinstance(item, dict) for item in raw
        ):
            raise ValueError(f"Invalid recording format: {path}")
        try:
            loaded = [ActionRecord.from_dict(item) for item in raw]
        except TypeError as err:
            # Missing (or unexpected) record fields surface as TypeError.
            raise ValueError(f"Invalid recording format: {path}") from err
        self._records = loaded
        logger.info("Recording loaded from %s (%d actions)", path, len(self._records))

    def clear(self) -> None:
        """Remove all records."""
        self._records.clear()

    @property
    def total_actions(self) -> int:
        """Number of recorded actions."""
        return len(self._records)

    @property
    def success_count(self) -> int:
        """Number of successful actions."""
        return sum(1 for r in self._records if r.success)

    @property
    def failure_count(self) -> int:
        """Number of failed actions."""
        return sum(1 for r in self._records if not r.success)

    def summary(self) -> dict[str, Any]:
        """Return aggregate statistics about the recording.

        ``action_types`` lists each action type once, in first-seen order.
        ``duration`` is the difference between the last and first record
        timestamps (0.0 with fewer than two records).
        """
        records = self._records
        return {
            "total_actions": self.total_actions,
            "success_count": self.success_count,
            "failure_count": self.failure_count,
            # dict.fromkeys de-duplicates while keeping insertion order, so
            # the result is deterministic (set iteration order is not).
            "action_types": list(dict.fromkeys(r.action for r in records)),
            "duration": (
                records[-1].timestamp - records[0].timestamp
                if len(records) >= 2
                else 0.0
            ),
        }