diff --git a/src/agentkit/orchestrator/checkpoint.py b/src/agentkit/orchestrator/checkpoint.py index 9814d16..4c115bf 100644 --- a/src/agentkit/orchestrator/checkpoint.py +++ b/src/agentkit/orchestrator/checkpoint.py @@ -76,8 +76,19 @@ class PipelineCheckpoint: self._ttl = ttl_seconds # 内存降级存储:plan_id → list of CheckpointData self._memory: dict[str, list[CheckpointData]] = {} - # 内存降级存储:plan_id → plan dict (for resume) - self._memory_plans: dict[str, dict[str, Any]] = {} + # 内存降级存储:plan_id → (plan_dict, saved_timestamp) + self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {} + + def _is_expired(self, saved_at: str) -> bool: + """检查 checkpoint 是否已过期(内存模式 TTL)。""" + if not saved_at: + return False + try: + saved_time = datetime.fromisoformat(saved_at) + age = (datetime.now(timezone.utc) - saved_time).total_seconds() + return age > self._ttl + except (ValueError, TypeError): + return False def _key(self, plan_id: str, phase_id: str) -> str: return f"{self._prefix}:{plan_id}:{phase_id}" @@ -99,8 +110,8 @@ class PipelineCheckpoint: plan_id = plan.id plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id} - # 内存降级 - self._memory_plans[plan_id] = plan_dict + # 内存降级(带 TTL 时间戳) + self._memory_plans[plan_id] = (plan_dict, time.time()) # 尝试写入 Redis if self._redis is not None: @@ -125,8 +136,16 @@ class PipelineCheckpoint: logger.warning( f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}" ) - # 内存降级 - return self._memory_plans.get(plan_id) + # 内存降级(检查 TTL 过期) + entry = self._memory_plans.get(plan_id) + if entry is None: + return None + plan_dict, saved_at = entry + if time.time() - saved_at > self._ttl: + # 已过期,清除 + del self._memory_plans[plan_id] + return None + return plan_dict async def save(self, plan_id: str, phase: Any, plan_status: str) -> None: """保存阶段 checkpoint。 @@ -196,8 +215,12 @@ class PipelineCheckpoint: try: phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1) if not phase_ids: - # Redis 无数据,检查内存 - return list(self._memory.get(plan_id, [])) + # Redis 无数据,检查内存(过滤过期) + return [ + c + for c in self._memory.get(plan_id, []) + if not self._is_expired(c.saved_at) + ] results: list[CheckpointData] = [] for phase_id in phase_ids: @@ -212,8 +235,8 @@ class PipelineCheckpoint: f"plan {plan_id}: {e} — using memory fallback" ) - # 内存降级 - return list(self._memory.get(plan_id, [])) + # 内存降级(过滤过期 checkpoint) + return [c for c in self._memory.get(plan_id, []) if not self._is_expired(c.saved_at)] async def clear(self, plan_id: str) -> None: """清除某 plan 的所有 checkpoint。""" diff --git a/tests/unit/test_pipeline_checkpoint.py b/tests/unit/test_pipeline_checkpoint.py index 3ed8182..68cf75b 100644 --- a/tests/unit/test_pipeline_checkpoint.py +++ b/tests/unit/test_pipeline_checkpoint.py @@ -250,6 +250,126 @@ class TestCheckpointSavePlan: assert loaded is None +# ── TTL 过期测试 ───────────────────────────────────────── + + +class TestCheckpointTTL: + """内存模式 TTL 过期测试""" + + @pytest.mark.asyncio + async def test_memory_checkpoint_expired_returns_none(self): + """内存模式下 checkpoint 过期后 load 返回 None""" + # 使用极短 TTL(1 秒)便于测试 + cp = PipelineCheckpoint(ttl_seconds=1) + phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) + await cp.save("plan_1", phase, "executing") + + # 立即 load 应正常返回 + loaded = await cp.load("plan_1") + assert loaded is not None + assert loaded.phase_id == "p1" + + # 手动将 saved_at 改为过期时间 + from datetime import datetime, timedelta, timezone + + expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() + cp._memory["plan_1"][0].saved_at = expired_time + + # 过期后 load 应返回 None + loaded = await cp.load("plan_1") + assert loaded is None + + @pytest.mark.asyncio + async def test_memory_plan_expired_returns_none(self): + """内存模式下 plan 过期后 load_plan 返回 None""" + cp = PipelineCheckpoint(ttl_seconds=1) + plan = _make_plan(plan_id="plan_ttl") + await cp.save_plan(plan) + + # 立即 load_plan 应正常返回 + loaded = await cp.load_plan("plan_ttl") + assert loaded is not None + assert loaded["id"] == "plan_ttl" + + # 手动将保存时间改为过期 + import time as _time + + cp._memory_plans["plan_ttl"] = (cp._memory_plans["plan_ttl"][0], _time.time() - 10) + + # 过期后 load_plan 应返回 None + loaded = await cp.load_plan("plan_ttl") + assert loaded is None + + @pytest.mark.asyncio + async def test_memory_checkpoint_not_expired_returns_data(self): + """内存模式下 checkpoint 未过期时正常返回""" + cp = PipelineCheckpoint(ttl_seconds=3600) # 1 小时 TTL + phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) + await cp.save("plan_1", phase, "executing") + + loaded = await cp.load("plan_1") + assert loaded is not None + assert loaded.phase_id == "p1" + + @pytest.mark.asyncio + async def test_memory_list_checkpoints_filters_expired(self): + """内存模式下 list_checkpoints 过滤过期 checkpoint""" + cp = PipelineCheckpoint(ttl_seconds=1) + # 保存 2 个 checkpoint + phase1 = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) + phase2 = _make_phase(phase_id="p2", status=PhaseStatus.COMPLETED) + await cp.save("plan_1", phase1, "executing") + await cp.save("plan_1", phase2, "executing") + + # 立即 list 应返回 2 个 + checkpoints = await cp.list_checkpoints("plan_1") + assert len(checkpoints) == 2 + + # 将第一个 checkpoint 标记为过期 + from datetime import datetime, timedelta, timezone + + expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() + cp._memory["plan_1"][0].saved_at = expired_time + + # list 应过滤掉过期的,只返回 1 个 + checkpoints = await cp.list_checkpoints("plan_1") + assert len(checkpoints) == 1 + assert checkpoints[0].phase_id == "p2" + + @pytest.mark.asyncio + async def test_redis_empty_falls_back_to_memory_with_ttl_filter(self): + """Redis 无数据时降级到内存,并应用 TTL 过滤""" + redis = _make_mock_redis() + cp = PipelineCheckpoint(redis_client=redis, ttl_seconds=1) + + phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) + # save 会同时写入 Redis 和内存降级 + await cp.save("plan_1", phase, "executing") + + # 立即 load 应正常返回(Redis 有数据) + loaded = await cp.load("plan_1") + assert loaded is not None + + # 模拟 Redis 数据丢失:zrange 返回空,get 返回 None + redis.zrange = AsyncMock(return_value=[]) + redis.get = AsyncMock(return_value=None) + + # 内存降级 + TTL 未过期 → 应返回数据 + loaded = await cp.load("plan_1") + assert loaded is not None + assert loaded.phase_id == "p1" + + # 标记为过期 + from datetime import datetime, timedelta, timezone + + expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() + cp._memory["plan_1"][0].saved_at = expired_time + + # 内存降级 + TTL 过期 → 应返回 None + loaded = await cp.load("plan_1") + assert loaded is None + + # ── clear 测试 ───────────────────────────────────────────