fix(checkpoint): add TTL expiration for memory fallback mode

内存降级模式之前没有 TTL 过期机制,长期运行进程会导致内存泄漏。
现在 list_checkpoints 和 load_plan 在内存模式下会过滤/清除过期数据。

- list_checkpoints: 内存降级分支过滤过期 checkpoint
- load_plan: 内存降级分支检查 TTL 过期,过期则清除并返回 None
- 新增 _is_expired 方法检查 saved_at 是否超过 TTL
- _memory_plans 类型改为 tuple(plan_dict, timestamp) 以支持 TTL
- 新增 5 个 TTL 过期测试覆盖内存模式和 Redis 降级场景
This commit is contained in:
chiguyong 2026-06-24 22:04:55 +08:00
parent fa152e24ac
commit 0847c0e086
2 changed files with 153 additions and 10 deletions

View File

@ -76,8 +76,19 @@ class PipelineCheckpoint:
self._ttl = ttl_seconds self._ttl = ttl_seconds
# 内存降级存储plan_id → list of CheckpointData # 内存降级存储plan_id → list of CheckpointData
self._memory: dict[str, list[CheckpointData]] = {} self._memory: dict[str, list[CheckpointData]] = {}
# 内存降级存储plan_id → plan dict (for resume) # 内存降级存储plan_id → (plan_dict, saved_timestamp)
self._memory_plans: dict[str, dict[str, Any]] = {} self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {}
def _is_expired(self, saved_at: str) -> bool:
"""检查 checkpoint 是否已过期(内存模式 TTL"""
if not saved_at:
return False
try:
saved_time = datetime.fromisoformat(saved_at)
age = (datetime.now(timezone.utc) - saved_time).total_seconds()
return age > self._ttl
except (ValueError, TypeError):
return False
def _key(self, plan_id: str, phase_id: str) -> str: def _key(self, plan_id: str, phase_id: str) -> str:
return f"{self._prefix}:{plan_id}:{phase_id}" return f"{self._prefix}:{plan_id}:{phase_id}"
@ -99,8 +110,8 @@ class PipelineCheckpoint:
plan_id = plan.id plan_id = plan.id
plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id} plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id}
# 内存降级 # 内存降级(带 TTL 时间戳)
self._memory_plans[plan_id] = plan_dict self._memory_plans[plan_id] = (plan_dict, time.time())
# 尝试写入 Redis # 尝试写入 Redis
if self._redis is not None: if self._redis is not None:
@ -125,8 +136,16 @@ class PipelineCheckpoint:
logger.warning( logger.warning(
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}" f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
) )
# 内存降级 # 内存降级(检查 TTL 过期)
return self._memory_plans.get(plan_id) entry = self._memory_plans.get(plan_id)
if entry is None:
return None
plan_dict, saved_at = entry
if time.time() - saved_at > self._ttl:
# 已过期,清除
del self._memory_plans[plan_id]
return None
return plan_dict
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None: async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
"""保存阶段 checkpoint。 """保存阶段 checkpoint。
@ -196,8 +215,12 @@ class PipelineCheckpoint:
try: try:
phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1) phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
if not phase_ids: if not phase_ids:
# Redis 无数据,检查内存 # Redis 无数据,检查内存(过滤过期)
return list(self._memory.get(plan_id, [])) return [
c
for c in self._memory.get(plan_id, [])
if not self._is_expired(c.saved_at)
]
results: list[CheckpointData] = [] results: list[CheckpointData] = []
for phase_id in phase_ids: for phase_id in phase_ids:
@ -212,8 +235,8 @@ class PipelineCheckpoint:
f"plan {plan_id}: {e} — using memory fallback" f"plan {plan_id}: {e} — using memory fallback"
) )
# 内存降级 # 内存降级(过滤过期 checkpoint
return list(self._memory.get(plan_id, [])) return [c for c in self._memory.get(plan_id, []) if not self._is_expired(c.saved_at)]
async def clear(self, plan_id: str) -> None: async def clear(self, plan_id: str) -> None:
"""清除某 plan 的所有 checkpoint。""" """清除某 plan 的所有 checkpoint。"""

View File

@ -250,6 +250,126 @@ class TestCheckpointSavePlan:
assert loaded is None assert loaded is None
# ── TTL 过期测试 ─────────────────────────────────────────
class TestCheckpointTTL:
"""内存模式 TTL 过期测试"""
@pytest.mark.asyncio
async def test_memory_checkpoint_expired_returns_none(self):
"""内存模式下 checkpoint 过期后 load 返回 None"""
# 使用极短 TTL1 秒)便于测试
cp = PipelineCheckpoint(ttl_seconds=1)
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
await cp.save("plan_1", phase, "executing")
# 立即 load 应正常返回
loaded = await cp.load("plan_1")
assert loaded is not None
assert loaded.phase_id == "p1"
# 手动将 saved_at 改为过期时间
from datetime import datetime, timedelta, timezone
expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
cp._memory["plan_1"][0].saved_at = expired_time
# 过期后 load 应返回 None
loaded = await cp.load("plan_1")
assert loaded is None
@pytest.mark.asyncio
async def test_memory_plan_expired_returns_none(self):
"""内存模式下 plan 过期后 load_plan 返回 None"""
cp = PipelineCheckpoint(ttl_seconds=1)
plan = _make_plan(plan_id="plan_ttl")
await cp.save_plan(plan)
# 立即 load_plan 应正常返回
loaded = await cp.load_plan("plan_ttl")
assert loaded is not None
assert loaded["id"] == "plan_ttl"
# 手动将保存时间改为过期
import time as _time
cp._memory_plans["plan_ttl"] = (cp._memory_plans["plan_ttl"][0], _time.time() - 10)
# 过期后 load_plan 应返回 None
loaded = await cp.load_plan("plan_ttl")
assert loaded is None
@pytest.mark.asyncio
async def test_memory_checkpoint_not_expired_returns_data(self):
"""内存模式下 checkpoint 未过期时正常返回"""
cp = PipelineCheckpoint(ttl_seconds=3600) # 1 小时 TTL
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
await cp.save("plan_1", phase, "executing")
loaded = await cp.load("plan_1")
assert loaded is not None
assert loaded.phase_id == "p1"
@pytest.mark.asyncio
async def test_memory_list_checkpoints_filters_expired(self):
"""内存模式下 list_checkpoints 过滤过期 checkpoint"""
cp = PipelineCheckpoint(ttl_seconds=1)
# 保存 2 个 checkpoint
phase1 = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
phase2 = _make_phase(phase_id="p2", status=PhaseStatus.COMPLETED)
await cp.save("plan_1", phase1, "executing")
await cp.save("plan_1", phase2, "executing")
# 立即 list 应返回 2 个
checkpoints = await cp.list_checkpoints("plan_1")
assert len(checkpoints) == 2
# 将第一个 checkpoint 标记为过期
from datetime import datetime, timedelta, timezone
expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
cp._memory["plan_1"][0].saved_at = expired_time
# list 应过滤掉过期的,只返回 1 个
checkpoints = await cp.list_checkpoints("plan_1")
assert len(checkpoints) == 1
assert checkpoints[0].phase_id == "p2"
@pytest.mark.asyncio
async def test_redis_empty_falls_back_to_memory_with_ttl_filter(self):
"""Redis 无数据时降级到内存,并应用 TTL 过滤"""
redis = _make_mock_redis()
cp = PipelineCheckpoint(redis_client=redis, ttl_seconds=1)
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
# save 会同时写入 Redis 和内存降级
await cp.save("plan_1", phase, "executing")
# 立即 load 应正常返回Redis 有数据)
loaded = await cp.load("plan_1")
assert loaded is not None
# 模拟 Redis 数据丢失zrange 返回空get 返回 None
redis.zrange = AsyncMock(return_value=[])
redis.get = AsyncMock(return_value=None)
# 内存降级 + TTL 未过期 → 应返回数据
loaded = await cp.load("plan_1")
assert loaded is not None
assert loaded.phase_id == "p1"
# 标记为过期
from datetime import datetime, timedelta, timezone
expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
cp._memory["plan_1"][0].saved_at = expired_time
# 内存降级 + TTL 过期 → 应返回 None
loaded = await cp.load("plan_1")
assert loaded is None
# ── clear 测试 ─────────────────────────────────────────── # ── clear 测试 ───────────────────────────────────────────