"""PipelineCheckpoint — 阶段级检查点与断点续跑 (U7) 在 TeamOrchestrator 阶段完成后保存 checkpoint 到 Redis(或内存降级), 崩溃后可通过 resume(plan_id) 从最后完成阶段恢复。 复用 PipelineStateRedis 的 _safe_redis_call 模式: Redis 失败时降级到内存 dict,不阻断执行。 键命名:agentkit:pipeline:checkpoint:{plan_id}:{phase_id} TTL:7 天(与 PipelineStateRedis._TTL_SECONDS 一致) """ from __future__ import annotations import json import logging import time from dataclasses import asdict, dataclass, field from datetime import datetime, timezone from typing import Any logger = logging.getLogger(__name__) _TTL_SECONDS = 7 * 24 * 3600 # 7 days _KEY_PREFIX = "agentkit:pipeline:checkpoint" @dataclass class CheckpointData: """单个阶段的 checkpoint 数据。""" plan_id: str phase_id: str phase_name: str phase_status: str phase_result: dict[str, Any] | None = None plan_status: str = "" saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) def to_dict(self) -> dict[str, Any]: return asdict(self) @classmethod def from_dict(cls, data: dict[str, Any]) -> CheckpointData: return cls( plan_id=data.get("plan_id", ""), phase_id=data.get("phase_id", ""), phase_name=data.get("phase_name", ""), phase_status=data.get("phase_status", ""), phase_result=data.get("phase_result"), plan_status=data.get("plan_status", ""), saved_at=data.get("saved_at", ""), ) class PipelineCheckpoint: """阶段级检查点存储 — Redis 优先,内存降级。 Usage:: checkpoint = PipelineCheckpoint(redis_client=redis) await checkpoint.save(plan.id, phase, plan.status.value) last = await checkpoint.load(plan.id) if last: # resume from last completed phase """ def __init__( self, redis_client: Any = None, prefix: str = _KEY_PREFIX, ttl_seconds: int = _TTL_SECONDS, ) -> None: self._redis = redis_client self._prefix = prefix self._ttl = ttl_seconds # 内存降级存储:plan_id → {phase_id → CheckpointData} # P1 #6: 改用 dict keyed by phase_id,避免重复 append self._memory: dict[str, dict[str, CheckpointData]] = {} # 内存降级存储:plan_id → (plan_dict, saved_timestamp) self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {} def _is_expired(self, saved_at: str) -> bool: """检查 checkpoint 是否已过期(内存模式 TTL)。""" if not saved_at: return False try: saved_time = datetime.fromisoformat(saved_at) age = (datetime.now(timezone.utc) - saved_time).total_seconds() return age > self._ttl except (ValueError, TypeError): return False def _key(self, plan_id: str, phase_id: str) -> str: return f"{self._prefix}:{plan_id}:{phase_id}" def _index_key(self, plan_id: str) -> str: """Redis Sorted Set 索引键,用于列出某 plan 的所有 checkpoint。""" return f"{self._prefix}:index:{plan_id}" def _plan_key(self, plan_id: str) -> str: """完整 plan JSON 的存储键。""" return f"{self._prefix}:plan:{plan_id}" async def save_plan(self, plan: Any) -> None: """保存完整 TeamPlan(用于 resume 重建)。 Args: plan: TeamPlan 对象(需要有 to_dict() 方法) """ plan_id = plan.id plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id} # 内存降级(带 TTL 时间戳) self._memory_plans[plan_id] = (plan_dict, time.time()) # 尝试写入 Redis if self._redis is not None: try: await self._redis.set(self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl) except Exception as e: logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}") async def load_plan(self, plan_id: str) -> dict[str, Any] | None: """加载完整 plan JSON。""" # 优先 Redis if self._redis is not None: try: raw = await self._redis.get(self._plan_key(plan_id)) if raw: return json.loads(raw) except Exception as e: logger.warning(f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}") # 内存降级(检查 TTL 过期) entry = self._memory_plans.get(plan_id) if entry is None: return None plan_dict, saved_at = entry if time.time() - saved_at > self._ttl: # 已过期,清除 del self._memory_plans[plan_id] return None return plan_dict async def save(self, plan_id: str, phase: Any, plan_status: str) -> None: """保存阶段 checkpoint。 Args: plan_id: 计划 ID phase: PlanPhase 对象(需要有 id, name, status, result 属性) plan_status: 计划当前状态 """ phase_id = getattr(phase, "id", str(phase)) phase_name = getattr(phase, "name", "") phase_status = getattr(phase, "status", "") if hasattr(phase_status, "value"): phase_status = phase_status.value # 序列化 phase.result(可能是 offloaded dict with _ref_key) phase_result = getattr(phase, "result", None) if phase_result is not None and not isinstance(phase_result, dict): phase_result = {"content": str(phase_result)} data = CheckpointData( plan_id=plan_id, phase_id=phase_id, phase_name=phase_name, phase_status=str(phase_status), phase_result=phase_result, plan_status=plan_status, ) # P1 #6: 内存降级用 dict keyed by phase_id,覆盖重复 checkpoint self._memory.setdefault(plan_id, {})[phase_id] = data # 尝试写入 Redis if self._redis is not None: try: score = time.time() pipe = self._redis.pipeline() pipe.set(self._key(plan_id, phase_id), json.dumps(data.to_dict()), ex=self._ttl) pipe.zadd(self._index_key(plan_id), {phase_id: score}) await pipe.execute() except Exception as e: logger.warning( f"PipelineCheckpoint.save Redis failed for plan {plan_id}, " f"phase {phase_id}: {e} — using memory fallback" ) async def load(self, plan_id: str) -> CheckpointData | None: """加载最后完成的阶段 checkpoint。 Returns: 最后一个 COMPLETED 阶段的 CheckpointData,或 None。 """ checkpoints = await self.list_checkpoints(plan_id) if not checkpoints: return None # 返回最后一个 COMPLETED 阶段 completed = [c for c in checkpoints if c.phase_status == "completed"] if not completed: return None return completed[-1] async def list_checkpoints(self, plan_id: str) -> list[CheckpointData]: """列出某 plan 的所有 checkpoint(按保存时间排序)。""" # 优先从 Redis 读取 if self._redis is not None: try: phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1) if not phase_ids: # Redis 无数据,检查内存(过滤过期) return [ c for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at) ] # 批量 GET(pipeline 避免 N+1 往返) pipe = self._redis.pipeline() for phase_id in phase_ids: pipe.get(self._key(plan_id, phase_id)) raws = await pipe.execute() results: list[CheckpointData] = [] for raw in raws: if raw: data = json.loads(raw) results.append(CheckpointData.from_dict(data)) return results except Exception as e: logger.warning( f"PipelineCheckpoint.list_checkpoints Redis failed for " f"plan {plan_id}: {e} — using memory fallback" ) # 内存降级(过滤过期 checkpoint) return [ c for c in self._memory.get(plan_id, {}).values() if not self._is_expired(c.saved_at) ] async def clear(self, plan_id: str) -> None: """清除某 plan 的所有 checkpoint。""" # 清除内存 self._memory.pop(plan_id, None) self._memory_plans.pop(plan_id, None) # 清除 Redis if self._redis is not None: try: phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1) pipe = self._redis.pipeline() for phase_id in phase_ids: pipe.delete(self._key(plan_id, phase_id)) pipe.delete(self._index_key(plan_id)) pipe.delete(self._plan_key(plan_id)) await pipe.execute() except Exception as e: logger.warning(f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}")