fischer-agentkit/src/agentkit/orchestrator/checkpoint.py

261 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""PipelineCheckpoint — 阶段级检查点与断点续跑 (U7)
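
After TeamOrchestrator completes a phase, a checkpoint is saved to Redis (or to
the in-memory fallback). After a crash, resume(plan_id) can recover from the
last completed phase.

Reuses the _safe_redis_call pattern of PipelineStateRedis: when Redis fails,
fall back to an in-memory dict without blocking execution.

Key naming: agentkit:pipeline:checkpoint:{plan_id}:{phase_id}
TTL: 7 days (same as PipelineStateRedis._TTL_SECONDS)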
"""
from __future__ import annotations
import json
import logging
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default lifetime, in seconds, used for Redis key expiry and for expiring
# entries held in the in-memory fallback.
_TTL_SECONDS = 7 * 24 * 3600 # 7 days
# Default prefix shared by every Redis key built in this module.
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
@dataclass
class CheckpointData:
    """Snapshot of one pipeline phase as stored in a checkpoint.

    ``saved_at`` defaults to the current UTC time rendered as an ISO-8601
    string; ``phase_result`` and ``plan_status`` are optional.
    """

    plan_id: str
    phase_id: str
    phase_name: str
    phase_status: str
    phase_result: dict[str, Any] | None = None
    plan_status: str = ""
    saved_at: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )

    def to_dict(self) -> dict[str, Any]:
        """Return every field of this checkpoint as a plain dictionary."""
        as_mapping = asdict(self)
        return as_mapping

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> CheckpointData:
        """Build a checkpoint from ``data``, tolerating absent keys.

        Missing text fields become ``""`` and a missing ``phase_result``
        becomes ``None``.
        """
        text_fields = (
            "plan_id",
            "phase_id",
            "phase_name",
            "phase_status",
            "plan_status",
            "saved_at",
        )
        kwargs: dict[str, Any] = {name: data.get(name, "") for name in text_fields}
        kwargs["phase_result"] = data.get("phase_result")
        return cls(**kwargs)
class PipelineCheckpoint:
    """Phase-level checkpoint store: Redis first, in-memory fallback.

    Every write is recorded in the in-memory dicts and, when a Redis client
    was supplied, mirrored to Redis. Redis errors are logged and swallowed so
    that checkpointing never interrupts the caller.

    Usage::

        checkpoint = PipelineCheckpoint(redis_client=redis)
        await checkpoint.save(plan.id, phase, plan.status.value)
        last = await checkpoint.load(plan.id)
        if last:
            # resume from last completed phase
    """

    def __init__(
        self,
        redis_client: Any = None,
        prefix: str = _KEY_PREFIX,
        ttl_seconds: int = _TTL_SECONDS,
    ) -> None:
        """Create a checkpoint store.

        Args:
            redis_client: Async Redis client, or ``None`` for memory-only mode.
            prefix: Prefix prepended to every Redis key.
            ttl_seconds: Lifetime applied to Redis keys and memory entries.
        """
        self._redis = redis_client
        self._prefix = prefix
        self._ttl = ttl_seconds
        # Memory fallback: plan_id -> {phase_id -> CheckpointData}.
        # Keyed by phase_id so a repeated save overwrites instead of appending.
        self._memory: dict[str, dict[str, CheckpointData]] = {}
        # Memory fallback: plan_id -> (plan_dict, time.time() at save).
        self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {}

    @staticmethod
    def _member_to_str(member: Any) -> str:
        """Return a sorted-set member as ``str``.

        Redis clients that are not configured to decode responses hand back
        ``bytes``; formatting those into a key would yield ``...:b'id'``.
        """
        if isinstance(member, (bytes, bytearray)):
            return bytes(member).decode("utf-8")
        return str(member)

    def _is_expired(self, saved_at: str) -> bool:
        """Return True when ``saved_at`` is older than the configured TTL.

        An empty, unparsable, or timezone-naive timestamp is treated as not
        expired (the subtraction error is caught and ``False`` is returned).
        """
        if not saved_at:
            return False
        try:
            saved_time = datetime.fromisoformat(saved_at)
            age = (datetime.now(timezone.utc) - saved_time).total_seconds()
            return age > self._ttl
        except (ValueError, TypeError):
            return False

    def _key(self, plan_id: str, phase_id: str) -> str:
        """Return the key holding one phase's checkpoint JSON."""
        return f"{self._prefix}:{plan_id}:{phase_id}"

    def _index_key(self, plan_id: str) -> str:
        """Return the sorted-set index key listing a plan's checkpoints."""
        return f"{self._prefix}:index:{plan_id}"

    def _plan_key(self, plan_id: str) -> str:
        """Return the key holding the full plan JSON."""
        return f"{self._prefix}:plan:{plan_id}"

    def _memory_checkpoints(self, plan_id: str) -> list[CheckpointData]:
        """Return the non-expired in-memory checkpoints of ``plan_id``."""
        return [
            c for c in self._memory.get(plan_id, {}).values()
            if not self._is_expired(c.saved_at)
        ]

    async def save_plan(self, plan: Any) -> None:
        """Save the complete plan so it can be rebuilt on resume.

        Args:
            plan: Object with an ``id`` attribute; its ``to_dict()`` result is
                stored when available, otherwise only ``{"id": plan.id}``.
        """
        plan_id = plan.id
        plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id}
        # Memory fallback, stamped so load_plan can apply the TTL.
        self._memory_plans[plan_id] = (plan_dict, time.time())
        if self._redis is not None:
            try:
                await self._redis.set(self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl)
            except Exception as e:
                logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")

    async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
        """Load the stored plan dict, or ``None`` if absent or expired."""
        # Redis first.
        if self._redis is not None:
            try:
                raw = await self._redis.get(self._plan_key(plan_id))
                if raw:
                    return json.loads(raw)
            except Exception as e:
                logger.warning(f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}")
        # Memory fallback, honouring the TTL.
        entry = self._memory_plans.get(plan_id)
        if entry is None:
            return None
        plan_dict, saved_at = entry
        if time.time() - saved_at > self._ttl:
            # Expired: drop it so it is not checked again.
            del self._memory_plans[plan_id]
            return None
        return plan_dict

    async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
        """Save a checkpoint for one phase.

        Args:
            plan_id: Identifier of the plan the phase belongs to.
            phase: Object read via ``id``, ``name``, ``status`` and ``result``
                attributes; missing attributes fall back to defaults.
            plan_status: Current status of the whole plan.
        """
        phase_id = getattr(phase, "id", str(phase))
        phase_name = getattr(phase, "name", "")
        phase_status = getattr(phase, "status", "")
        if hasattr(phase_status, "value"):
            phase_status = phase_status.value
        # Non-dict results are wrapped so the stored value is always a dict.
        phase_result = getattr(phase, "result", None)
        if phase_result is not None and not isinstance(phase_result, dict):
            phase_result = {"content": str(phase_result)}
        data = CheckpointData(
            plan_id=plan_id,
            phase_id=phase_id,
            phase_name=phase_name,
            phase_status=str(phase_status),
            phase_result=phase_result,
            plan_status=plan_status,
        )
        # Pop before re-inserting: a plain overwrite keeps the key's old
        # position, which would break the save-time ordering that
        # list_checkpoints()/load() rely on.
        bucket = self._memory.setdefault(plan_id, {})
        bucket.pop(phase_id, None)
        bucket[phase_id] = data
        if self._redis is not None:
            try:
                score = time.time()
                index_key = self._index_key(plan_id)
                pipe = self._redis.pipeline()
                pipe.set(self._key(plan_id, phase_id), json.dumps(data.to_dict()), ex=self._ttl)
                pipe.zadd(index_key, {phase_id: score})
                # Refresh the index TTL on every save so the sorted set does
                # not outlive the checkpoints it points at.
                pipe.expire(index_key, self._ttl)
                await pipe.execute()
            except Exception as e:
                logger.warning(
                    f"PipelineCheckpoint.save Redis failed for plan {plan_id}, "
                    f"phase {phase_id}: {e} — using memory fallback"
                )

    async def load(self, plan_id: str) -> CheckpointData | None:
        """Load the most recently saved checkpoint whose phase completed.

        Returns:
            The last checkpoint with ``phase_status == "completed"``, or
            ``None`` when there is none.
        """
        checkpoints = await self.list_checkpoints(plan_id)
        if not checkpoints:
            return None
        completed = [c for c in checkpoints if c.phase_status == "completed"]
        if not completed:
            return None
        return completed[-1]

    async def list_checkpoints(self, plan_id: str) -> list[CheckpointData]:
        """List all checkpoints of a plan, ordered by save time."""
        if self._redis is not None:
            try:
                members = await self._redis.zrange(self._index_key(plan_id), 0, -1)
                if not members:
                    # Nothing indexed in Redis: fall back to memory.
                    return self._memory_checkpoints(plan_id)
                # One pipelined round trip instead of a GET per phase.
                pipe = self._redis.pipeline()
                for member in members:
                    pipe.get(self._key(plan_id, self._member_to_str(member)))
                raws = await pipe.execute()
                # Members whose data key has already expired yield None.
                return [
                    CheckpointData.from_dict(json.loads(raw))
                    for raw in raws
                    if raw
                ]
            except Exception as e:
                logger.warning(
                    f"PipelineCheckpoint.list_checkpoints Redis failed for "
                    f"plan {plan_id}: {e} — using memory fallback"
                )
        return self._memory_checkpoints(plan_id)

    async def clear(self, plan_id: str) -> None:
        """Remove every checkpoint and the stored plan for ``plan_id``."""
        self._memory.pop(plan_id, None)
        self._memory_plans.pop(plan_id, None)
        if self._redis is not None:
            try:
                members = await self._redis.zrange(self._index_key(plan_id), 0, -1)
                pipe = self._redis.pipeline()
                for member in members:
                    pipe.delete(self._key(plan_id, self._member_to_str(member)))
                pipe.delete(self._index_key(plan_id))
                pipe.delete(self._plan_key(plan_id))
                await pipe.execute()
            except Exception as e:
                logger.warning(f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}")