261 lines
9.5 KiB
Python
261 lines
9.5 KiB
Python
"""PipelineCheckpoint — 阶段级检查点与断点续跑 (U7)
|
||
|
||
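
After TeamOrchestrator finishes a phase, a checkpoint is saved to Redis (or to
an in-memory fallback); after a crash, execution can be restored from the last
completed phase via resume(plan_id).

Reuses the _safe_redis_call pattern of PipelineStateRedis: when Redis fails,
degrade to an in-memory dict instead of blocking execution.

Key naming: agentkit:pipeline:checkpoint:{plan_id}:{phase_id}
TTL: 7 days (consistent with PipelineStateRedis._TTL_SECONDS)
"""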
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import time
|
||
from dataclasses import asdict, dataclass, field
|
||
from datetime import datetime, timezone
|
||
from typing import Any
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Namespace prepended to every checkpoint-related key.
_KEY_PREFIX = "agentkit:pipeline:checkpoint"

# Default checkpoint lifetime in seconds: 7 days.
_TTL_SECONDS = 7 * 24 * 60 * 60
|
||
|
||
|
||
def _utc_now_iso() -> str:
    """Return the current UTC time formatted as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


@dataclass
class CheckpointData:
    """Checkpoint record for a single pipeline phase."""

    plan_id: str
    phase_id: str
    phase_name: str
    phase_status: str
    phase_result: dict[str, Any] | None = None
    plan_status: str = ""
    # Filled with the creation time (UTC, ISO-8601) when not supplied.
    saved_at: str = field(default_factory=_utc_now_iso)

    def to_dict(self) -> dict[str, Any]:
        """Return the record as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> CheckpointData:
        """Build a record from a mapping.

        Missing string fields default to ``""``; a missing ``phase_result``
        defaults to ``None``.
        """
        text_fields = (
            "plan_id",
            "phase_id",
            "phase_name",
            "phase_status",
            "plan_status",
            "saved_at",
        )
        kwargs: dict[str, Any] = {name: data.get(name, "") for name in text_fields}
        kwargs["phase_result"] = data.get("phase_result")
        return cls(**kwargs)
|
||
|
||
|
||
class PipelineCheckpoint:
    """Phase-level checkpoint store: Redis first, in-memory fallback.

    Every write is recorded in in-process dictionaries and, when a Redis
    client was supplied, mirrored to Redis as well. Redis errors are logged
    and swallowed so that checkpointing never interrupts the caller.

    Usage::

        checkpoint = PipelineCheckpoint(redis_client=redis)
        await checkpoint.save(plan.id, phase, plan.status.value)
        last = await checkpoint.load(plan.id)
        if last:
            # resume from last completed phase
    """

    def __init__(
        self,
        redis_client: Any = None,
        prefix: str = _KEY_PREFIX,
        ttl_seconds: int = _TTL_SECONDS,
    ) -> None:
        """Create a checkpoint store.

        Args:
            redis_client: Client whose ``get``/``set``/``zrange`` calls and
                pipeline ``execute()`` are awaitable, or ``None`` to use the
                in-memory store only.
            prefix: Namespace prepended to every Redis key.
            ttl_seconds: Lifetime of stored entries, in seconds.
        """
        self._redis = redis_client
        self._prefix = prefix
        self._ttl = ttl_seconds
        # Memory fallback: plan_id -> {phase_id -> CheckpointData}.
        # Keyed by phase_id so a repeated save replaces the earlier entry
        # instead of appending a duplicate (P1 #6).
        self._memory: dict[str, dict[str, CheckpointData]] = {}
        # Memory fallback: plan_id -> (plan_dict, time.time() at save).
        self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {}

    def _is_expired(self, saved_at: str) -> bool:
        """Return True if an in-memory checkpoint is older than the TTL.

        An empty or unparseable ``saved_at`` is treated as not expired.
        """
        if not saved_at:
            return False
        try:
            saved_time = datetime.fromisoformat(saved_at)
            age = (datetime.now(timezone.utc) - saved_time).total_seconds()
            return age > self._ttl
        except (ValueError, TypeError):
            return False

    @staticmethod
    def _as_text(value: Any) -> Any:
        """Decode ``bytes`` returned by Redis to ``str``; pass others through.

        Clients created without ``decode_responses=True`` return sorted-set
        members as bytes, which must be decoded before building key names.
        """
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        return value

    def _memory_checkpoints(self, plan_id: str) -> list[CheckpointData]:
        """Return the non-expired in-memory checkpoints of a plan, in save order."""
        return [
            c for c in self._memory.get(plan_id, {}).values()
            if not self._is_expired(c.saved_at)
        ]

    def _key(self, plan_id: str, phase_id: str) -> str:
        """Redis key holding one phase checkpoint."""
        return f"{self._prefix}:{plan_id}:{phase_id}"

    def _index_key(self, plan_id: str) -> str:
        """Redis sorted-set index key listing all checkpoints of a plan."""
        return f"{self._prefix}:index:{plan_id}"

    def _plan_key(self, plan_id: str) -> str:
        """Redis key holding the full plan JSON."""
        return f"{self._prefix}:plan:{plan_id}"

    async def save_plan(self, plan: Any) -> None:
        """Store the complete plan so it can be rebuilt on resume.

        Args:
            plan: Object with an ``id`` attribute. Its ``to_dict()`` result is
                stored when that method exists; otherwise only ``{"id": ...}``.
        """
        plan_id = plan.id
        plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id}

        # Memory fallback, stamped with the save time for TTL checks.
        self._memory_plans[plan_id] = (plan_dict, time.time())

        # Best-effort mirror to Redis.
        if self._redis is not None:
            try:
                await self._redis.set(self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl)
            except Exception as e:
                logger.warning(f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}")

    async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
        """Load the stored plan dict, or ``None`` if absent or expired.

        Redis is consulted first; on a miss or an error the in-memory copy is
        used, and dropped if it is older than the TTL.
        """
        if self._redis is not None:
            try:
                raw = await self._redis.get(self._plan_key(plan_id))
                if raw:
                    return json.loads(raw)
            except Exception as e:
                logger.warning(f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}")

        entry = self._memory_plans.get(plan_id)
        if entry is None:
            return None
        plan_dict, saved_at = entry
        if time.time() - saved_at > self._ttl:
            # Expired: evict the stale entry.
            del self._memory_plans[plan_id]
            return None
        return plan_dict

    async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
        """Save a checkpoint for one phase.

        Args:
            plan_id: Plan identifier.
            phase: Phase object; its ``id``, ``name``, ``status`` and
                ``result`` attributes are read when present.
            plan_status: Current status of the plan.
        """
        phase_id = getattr(phase, "id", str(phase))
        phase_name = getattr(phase, "name", "")
        phase_status = getattr(phase, "status", "")
        if hasattr(phase_status, "value"):
            # Enum-like status: store its underlying value.
            phase_status = phase_status.value

        # Non-dict results are wrapped so the stored shape is always a dict.
        phase_result = getattr(phase, "result", None)
        if phase_result is not None and not isinstance(phase_result, dict):
            phase_result = {"content": str(phase_result)}

        data = CheckpointData(
            plan_id=plan_id,
            phase_id=phase_id,
            phase_name=phase_name,
            phase_status=str(phase_status),
            phase_result=phase_result,
            plan_status=plan_status,
        )

        # Remove-then-insert instead of overwriting in place: dict order then
        # follows save order, matching the Redis index where ZADD refreshes
        # the score of an existing member. load() relies on that ordering.
        bucket = self._memory.setdefault(plan_id, {})
        bucket.pop(phase_id, None)
        bucket[phase_id] = data

        # Best-effort mirror to Redis.
        if self._redis is not None:
            try:
                score = time.time()
                index_key = self._index_key(plan_id)
                pipe = self._redis.pipeline()
                pipe.set(self._key(plan_id, phase_id), json.dumps(data.to_dict()), ex=self._ttl)
                pipe.zadd(index_key, {phase_id: score})
                # Without a TTL the index would stay in Redis forever after
                # the checkpoint keys themselves have expired.
                pipe.expire(index_key, self._ttl)
                await pipe.execute()
            except Exception as e:
                logger.warning(
                    f"PipelineCheckpoint.save Redis failed for plan {plan_id}, "
                    f"phase {phase_id}: {e} — using memory fallback"
                )

    async def load(self, plan_id: str) -> CheckpointData | None:
        """Load the most recently saved completed-phase checkpoint.

        Returns:
            The last checkpoint whose ``phase_status`` is ``"completed"``,
            or ``None`` when there is none.
        """
        checkpoints = await self.list_checkpoints(plan_id)
        completed = [c for c in checkpoints if c.phase_status == "completed"]
        return completed[-1] if completed else None

    async def list_checkpoints(self, plan_id: str) -> list[CheckpointData]:
        """List all checkpoints of a plan, ordered by save time.

        Redis is read first; an empty index or any Redis error falls back to
        the non-expired in-memory checkpoints.
        """
        if self._redis is not None:
            try:
                phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
                if not phase_ids:
                    # Nothing indexed in Redis: use the memory copy.
                    return self._memory_checkpoints(plan_id)

                # Batch the GETs in one pipeline to avoid N+1 round trips.
                pipe = self._redis.pipeline()
                for phase_id in phase_ids:
                    pipe.get(self._key(plan_id, self._as_text(phase_id)))
                raws = await pipe.execute()

                # Entries whose key already expired come back empty; skip them.
                return [CheckpointData.from_dict(json.loads(raw)) for raw in raws if raw]
            except Exception as e:
                logger.warning(
                    f"PipelineCheckpoint.list_checkpoints Redis failed for "
                    f"plan {plan_id}: {e} — using memory fallback"
                )

        return self._memory_checkpoints(plan_id)

    async def clear(self, plan_id: str) -> None:
        """Remove every checkpoint and the stored plan for ``plan_id``."""
        self._memory.pop(plan_id, None)
        self._memory_plans.pop(plan_id, None)

        if self._redis is not None:
            try:
                phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
                pipe = self._redis.pipeline()
                for phase_id in phase_ids:
                    pipe.delete(self._key(plan_id, self._as_text(phase_id)))
                pipe.delete(self._index_key(plan_id))
                pipe.delete(self._plan_key(plan_id))
                await pipe.execute()
            except Exception as e:
                logger.warning(f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}")
|