fix(checkpoint): add TTL expiration for memory fallback mode
内存降级模式之前没有 TTL 过期机制,长期运行进程会导致内存泄漏。 现在 list_checkpoints 和 load_plan 在内存模式下会过滤/清除过期数据。 - list_checkpoints: 内存降级分支过滤过期 checkpoint - load_plan: 内存降级分支检查 TTL 过期,过期则清除并返回 None - 新增 _is_expired 方法检查 saved_at 是否超过 TTL - _memory_plans 类型改为 tuple(plan_dict, timestamp) 以支持 TTL - 新增 5 个 TTL 过期测试覆盖内存模式和 Redis 降级场景
This commit is contained in:
parent
fa152e24ac
commit
0847c0e086
|
|
@ -76,8 +76,19 @@ class PipelineCheckpoint:
|
||||||
self._ttl = ttl_seconds
|
self._ttl = ttl_seconds
|
||||||
# 内存降级存储:plan_id → list of CheckpointData
|
# 内存降级存储:plan_id → list of CheckpointData
|
||||||
self._memory: dict[str, list[CheckpointData]] = {}
|
self._memory: dict[str, list[CheckpointData]] = {}
|
||||||
# 内存降级存储:plan_id → plan dict (for resume)
|
# 内存降级存储:plan_id → (plan_dict, saved_timestamp)
|
||||||
self._memory_plans: dict[str, dict[str, Any]] = {}
|
self._memory_plans: dict[str, tuple[dict[str, Any], float]] = {}
|
||||||
|
|
||||||
|
def _is_expired(self, saved_at: str) -> bool:
|
||||||
|
"""检查 checkpoint 是否已过期(内存模式 TTL)。"""
|
||||||
|
if not saved_at:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
saved_time = datetime.fromisoformat(saved_at)
|
||||||
|
age = (datetime.now(timezone.utc) - saved_time).total_seconds()
|
||||||
|
return age > self._ttl
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return False
|
||||||
|
|
||||||
def _key(self, plan_id: str, phase_id: str) -> str:
|
def _key(self, plan_id: str, phase_id: str) -> str:
|
||||||
return f"{self._prefix}:{plan_id}:{phase_id}"
|
return f"{self._prefix}:{plan_id}:{phase_id}"
|
||||||
|
|
@ -99,8 +110,8 @@ class PipelineCheckpoint:
|
||||||
plan_id = plan.id
|
plan_id = plan.id
|
||||||
plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id}
|
plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id}
|
||||||
|
|
||||||
# 内存降级
|
# 内存降级(带 TTL 时间戳)
|
||||||
self._memory_plans[plan_id] = plan_dict
|
self._memory_plans[plan_id] = (plan_dict, time.time())
|
||||||
|
|
||||||
# 尝试写入 Redis
|
# 尝试写入 Redis
|
||||||
if self._redis is not None:
|
if self._redis is not None:
|
||||||
|
|
@ -125,8 +136,16 @@ class PipelineCheckpoint:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
|
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
|
||||||
)
|
)
|
||||||
# 内存降级
|
# 内存降级(检查 TTL 过期)
|
||||||
return self._memory_plans.get(plan_id)
|
entry = self._memory_plans.get(plan_id)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
plan_dict, saved_at = entry
|
||||||
|
if time.time() - saved_at > self._ttl:
|
||||||
|
# 已过期,清除
|
||||||
|
del self._memory_plans[plan_id]
|
||||||
|
return None
|
||||||
|
return plan_dict
|
||||||
|
|
||||||
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
|
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
|
||||||
"""保存阶段 checkpoint。
|
"""保存阶段 checkpoint。
|
||||||
|
|
@ -196,8 +215,12 @@ class PipelineCheckpoint:
|
||||||
try:
|
try:
|
||||||
phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
|
phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
|
||||||
if not phase_ids:
|
if not phase_ids:
|
||||||
# Redis 无数据,检查内存
|
# Redis 无数据,检查内存(过滤过期)
|
||||||
return list(self._memory.get(plan_id, []))
|
return [
|
||||||
|
c
|
||||||
|
for c in self._memory.get(plan_id, [])
|
||||||
|
if not self._is_expired(c.saved_at)
|
||||||
|
]
|
||||||
|
|
||||||
results: list[CheckpointData] = []
|
results: list[CheckpointData] = []
|
||||||
for phase_id in phase_ids:
|
for phase_id in phase_ids:
|
||||||
|
|
@ -212,8 +235,8 @@ class PipelineCheckpoint:
|
||||||
f"plan {plan_id}: {e} — using memory fallback"
|
f"plan {plan_id}: {e} — using memory fallback"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 内存降级
|
# 内存降级(过滤过期 checkpoint)
|
||||||
return list(self._memory.get(plan_id, []))
|
return [c for c in self._memory.get(plan_id, []) if not self._is_expired(c.saved_at)]
|
||||||
|
|
||||||
async def clear(self, plan_id: str) -> None:
|
async def clear(self, plan_id: str) -> None:
|
||||||
"""清除某 plan 的所有 checkpoint。"""
|
"""清除某 plan 的所有 checkpoint。"""
|
||||||
|
|
|
||||||
|
|
@ -250,6 +250,126 @@ class TestCheckpointSavePlan:
|
||||||
assert loaded is None
|
assert loaded is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── TTL 过期测试 ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointTTL:
|
||||||
|
"""内存模式 TTL 过期测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_checkpoint_expired_returns_none(self):
|
||||||
|
"""内存模式下 checkpoint 过期后 load 返回 None"""
|
||||||
|
# 使用极短 TTL(1 秒)便于测试
|
||||||
|
cp = PipelineCheckpoint(ttl_seconds=1)
|
||||||
|
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
# 立即 load 应正常返回
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.phase_id == "p1"
|
||||||
|
|
||||||
|
# 手动将 saved_at 改为过期时间
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
|
||||||
|
cp._memory["plan_1"][0].saved_at = expired_time
|
||||||
|
|
||||||
|
# 过期后 load 应返回 None
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_plan_expired_returns_none(self):
|
||||||
|
"""内存模式下 plan 过期后 load_plan 返回 None"""
|
||||||
|
cp = PipelineCheckpoint(ttl_seconds=1)
|
||||||
|
plan = _make_plan(plan_id="plan_ttl")
|
||||||
|
await cp.save_plan(plan)
|
||||||
|
|
||||||
|
# 立即 load_plan 应正常返回
|
||||||
|
loaded = await cp.load_plan("plan_ttl")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded["id"] == "plan_ttl"
|
||||||
|
|
||||||
|
# 手动将保存时间改为过期
|
||||||
|
import time as _time
|
||||||
|
|
||||||
|
cp._memory_plans["plan_ttl"] = (cp._memory_plans["plan_ttl"][0], _time.time() - 10)
|
||||||
|
|
||||||
|
# 过期后 load_plan 应返回 None
|
||||||
|
loaded = await cp.load_plan("plan_ttl")
|
||||||
|
assert loaded is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_checkpoint_not_expired_returns_data(self):
|
||||||
|
"""内存模式下 checkpoint 未过期时正常返回"""
|
||||||
|
cp = PipelineCheckpoint(ttl_seconds=3600) # 1 小时 TTL
|
||||||
|
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.phase_id == "p1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_list_checkpoints_filters_expired(self):
|
||||||
|
"""内存模式下 list_checkpoints 过滤过期 checkpoint"""
|
||||||
|
cp = PipelineCheckpoint(ttl_seconds=1)
|
||||||
|
# 保存 2 个 checkpoint
|
||||||
|
phase1 = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
|
||||||
|
phase2 = _make_phase(phase_id="p2", status=PhaseStatus.COMPLETED)
|
||||||
|
await cp.save("plan_1", phase1, "executing")
|
||||||
|
await cp.save("plan_1", phase2, "executing")
|
||||||
|
|
||||||
|
# 立即 list 应返回 2 个
|
||||||
|
checkpoints = await cp.list_checkpoints("plan_1")
|
||||||
|
assert len(checkpoints) == 2
|
||||||
|
|
||||||
|
# 将第一个 checkpoint 标记为过期
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
|
||||||
|
cp._memory["plan_1"][0].saved_at = expired_time
|
||||||
|
|
||||||
|
# list 应过滤掉过期的,只返回 1 个
|
||||||
|
checkpoints = await cp.list_checkpoints("plan_1")
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
assert checkpoints[0].phase_id == "p2"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redis_empty_falls_back_to_memory_with_ttl_filter(self):
|
||||||
|
"""Redis 无数据时降级到内存,并应用 TTL 过滤"""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
cp = PipelineCheckpoint(redis_client=redis, ttl_seconds=1)
|
||||||
|
|
||||||
|
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
|
||||||
|
# save 会同时写入 Redis 和内存降级
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
# 立即 load 应正常返回(Redis 有数据)
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is not None
|
||||||
|
|
||||||
|
# 模拟 Redis 数据丢失:zrange 返回空,get 返回 None
|
||||||
|
redis.zrange = AsyncMock(return_value=[])
|
||||||
|
redis.get = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
# 内存降级 + TTL 未过期 → 应返回数据
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.phase_id == "p1"
|
||||||
|
|
||||||
|
# 标记为过期
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
|
||||||
|
cp._memory["plan_1"][0].saved_at = expired_time
|
||||||
|
|
||||||
|
# 内存降级 + TTL 过期 → 应返回 None
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is None
|
||||||
|
|
||||||
|
|
||||||
# ── clear 测试 ───────────────────────────────────────────
|
# ── clear 测试 ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue