"""PipelineCheckpoint 单元测试 (U7) 测试覆盖: - save/load 阶段 checkpoint(内存模式 + Redis mock) - save_plan/load_plan 完整 plan 序列化 - list_checkpoints 按保存时间排序 - clear 清除所有数据 - Redis 异常降级到内存 - resume 从 checkpoint 恢复执行 """ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock import pytest from agentkit.experts.orchestrator import TeamOrchestrator from agentkit.experts.plan import PhaseStatus, PlanPhase, PlanStatus, TeamPlan from agentkit.orchestrator.checkpoint import CheckpointData, PipelineCheckpoint from tests.unit.experts.test_team_orchestrator import ( _make_mock_llm_gateway, _make_team_with_experts, ) # ── 辅助函数 ────────────────────────────────────────────── def _make_phase( name: str = "test_phase", phase_id: str = "phase_1", status: PhaseStatus = PhaseStatus.COMPLETED, result: dict | None = None, ) -> PlanPhase: """创建测试用 PlanPhase""" return PlanPhase( id=phase_id, name=name, assigned_expert="test_expert", task_description="test task", status=status, result=result or {"content": "test output"}, ) def _make_plan( plan_id: str = "plan_1", task: str = "test task", phases: list[PlanPhase] | None = None, ) -> TeamPlan: """创建测试用 TeamPlan""" plan = TeamPlan(id=plan_id, task=task, lead_expert="lead") plan.status = PlanStatus.EXECUTING if phases: plan.phases = phases return plan def _make_mock_redis() -> AsyncMock: """创建 mock Redis client,模拟 aioredis 行为。""" redis = AsyncMock() # 内部存储 store: dict[str, str] = {} zsets: dict[str, dict[str, float]] = {} async def _set(key, value, ex=None): # noqa: ARG001 store[key] = value async def _get(key): return store.get(key) async def _zadd(key, mapping): zsets.setdefault(key, {}).update(mapping) async def _zrange(key, start, stop): members = zsets.get(key, {}) sorted_members = sorted(members.keys(), key=lambda m: members[m]) if start == 0 and stop == -1: return sorted_members return sorted_members[start : stop + 1] if stop >= 0 else sorted_members[start:] async def _delete(*keys): count = 0 for k in keys: if k in store: del store[k] count += 1 if k in zsets: del zsets[k] count += 1 return count def _pipeline(): commands: list[tuple] = [] class _Pipe: def set(self, key, value, ex=None): commands.append(("set", key, value, ex)) def zadd(self, key, mapping): commands.append(("zadd", key, mapping)) def delete(self, *keys): commands.append(("delete", keys)) async def execute(self): for cmd in commands: if cmd[0] == "set": await _set(cmd[1], cmd[2], cmd[3]) elif cmd[0] == "zadd": await _zadd(cmd[1], cmd[2]) elif cmd[0] == "delete": await _delete(*cmd[1]) return _Pipe() redis.set = AsyncMock(side_effect=_set) redis.get = AsyncMock(side_effect=_get) redis.zadd = AsyncMock(side_effect=_zadd) redis.zrange = AsyncMock(side_effect=_zrange) redis.delete = AsyncMock(side_effect=_delete) redis.pipeline = MagicMock(side_effect=_pipeline) return redis # ── PipelineCheckpoint 基础测试 ────────────────────────── class TestCheckpointSaveLoad: """save / load / list_checkpoints 基础测试""" @pytest.mark.asyncio async def test_save_then_load_returns_last_completed(self): """save 后 load 返回最后一个 COMPLETED 阶段""" cp = PipelineCheckpoint() # 内存模式 phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) await cp.save("plan_1", phase, "executing") loaded = await cp.load("plan_1") assert loaded is not None assert loaded.plan_id == "plan_1" assert loaded.phase_id == "p1" assert loaded.phase_status == "completed" assert loaded.plan_status == "executing" @pytest.mark.asyncio async def test_load_returns_last_completed_of_multiple(self): """3 个阶段完成后,load 返回第 3 个(最后一个 COMPLETED)""" cp = PipelineCheckpoint() for i in range(3): phase = _make_phase( name=f"phase_{i}", phase_id=f"p{i}", status=PhaseStatus.COMPLETED, result={"content": f"output_{i}"}, ) await cp.save("plan_1", phase, "executing") loaded = await cp.load("plan_1") assert loaded is not None assert loaded.phase_id == "p2" assert loaded.phase_result == {"content": "output_2"} @pytest.mark.asyncio async def test_load_nonexistent_plan_returns_none(self): """load 不存在的 plan_id → None""" cp = PipelineCheckpoint() loaded = await cp.load("nonexistent") assert loaded is None @pytest.mark.asyncio async def test_load_with_no_completed_returns_none(self): """只有 FAILED 阶段时 load 返回 None""" cp = PipelineCheckpoint() phase = _make_phase(phase_id="p1", status=PhaseStatus.FAILED) await cp.save("plan_1", phase, "executing") loaded = await cp.load("plan_1") assert loaded is None @pytest.mark.asyncio async def test_list_checkpoints_returns_all(self): """list_checkpoints 返回所有 checkpoint""" cp = PipelineCheckpoint() for i in range(3): phase = _make_phase(phase_id=f"p{i}", status=PhaseStatus.COMPLETED) await cp.save("plan_1", phase, "executing") checkpoints = await cp.list_checkpoints("plan_1") assert len(checkpoints) == 3 assert all(c.plan_id == "plan_1" for c in checkpoints) @pytest.mark.asyncio async def test_save_serializes_phase_status_enum(self): """save 正确序列化 PhaseStatus enum 为字符串""" cp = PipelineCheckpoint() phase = _make_phase(status=PhaseStatus.COMPLETED) await cp.save("plan_1", phase, "executing") checkpoints = await cp.list_checkpoints("plan_1") assert checkpoints[0].phase_status == "completed" @pytest.mark.asyncio async def test_save_handles_non_dict_result(self): """save 处理非 dict 的 phase.result""" cp = PipelineCheckpoint() phase = _make_phase() phase.result = "plain string result" await cp.save("plan_1", phase, "executing") checkpoints = await cp.list_checkpoints("plan_1") assert checkpoints[0].phase_result == {"content": "plain string result"} # ── save_plan / load_plan 测试 ─────────────────────────── class TestCheckpointSavePlan: """save_plan / load_plan 测试""" @pytest.mark.asyncio async def test_save_plan_then_load_plan_roundtrip(self): """save_plan 后 load_plan 返回 plan dict""" cp = PipelineCheckpoint() plan = _make_plan( plan_id="plan_42", task="build feature", phases=[ _make_phase(name="phase1", phase_id="p1"), _make_phase(name="phase2", phase_id="p2"), ], ) await cp.save_plan(plan) loaded = await cp.load_plan("plan_42") assert loaded is not None assert loaded["id"] == "plan_42" assert loaded["task"] == "build feature" assert len(loaded["phases"]) == 2 assert loaded["phases"][0]["name"] == "phase1" @pytest.mark.asyncio async def test_load_plan_nonexistent_returns_none(self): """load_plan 不存在的 plan_id → None""" cp = PipelineCheckpoint() loaded = await cp.load_plan("nonexistent") assert loaded is None # ── TTL 过期测试 ───────────────────────────────────────── class TestCheckpointTTL: """内存模式 TTL 过期测试""" @pytest.mark.asyncio async def test_memory_checkpoint_expired_returns_none(self): """内存模式下 checkpoint 过期后 load 返回 None""" # 使用极短 TTL(1 秒)便于测试 cp = PipelineCheckpoint(ttl_seconds=1) phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) await cp.save("plan_1", phase, "executing") # 立即 load 应正常返回 loaded = await cp.load("plan_1") assert loaded is not None assert loaded.phase_id == "p1" # 手动将 saved_at 改为过期时间 from datetime import datetime, timedelta, timezone expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() cp._memory["plan_1"]["p1"].saved_at = expired_time # 过期后 load 应返回 None loaded = await cp.load("plan_1") assert loaded is None @pytest.mark.asyncio async def test_memory_plan_expired_returns_none(self): """内存模式下 plan 过期后 load_plan 返回 None""" cp = PipelineCheckpoint(ttl_seconds=1) plan = _make_plan(plan_id="plan_ttl") await cp.save_plan(plan) # 立即 load_plan 应正常返回 loaded = await cp.load_plan("plan_ttl") assert loaded is not None assert loaded["id"] == "plan_ttl" # 手动将保存时间改为过期 import time as _time cp._memory_plans["plan_ttl"] = (cp._memory_plans["plan_ttl"][0], _time.time() - 10) # 过期后 load_plan 应返回 None loaded = await cp.load_plan("plan_ttl") assert loaded is None @pytest.mark.asyncio async def test_memory_checkpoint_not_expired_returns_data(self): """内存模式下 checkpoint 未过期时正常返回""" cp = PipelineCheckpoint(ttl_seconds=3600) # 1 小时 TTL phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) await cp.save("plan_1", phase, "executing") loaded = await cp.load("plan_1") assert loaded is not None assert loaded.phase_id == "p1" @pytest.mark.asyncio async def test_memory_list_checkpoints_filters_expired(self): """内存模式下 list_checkpoints 过滤过期 checkpoint""" cp = PipelineCheckpoint(ttl_seconds=1) # 保存 2 个 checkpoint phase1 = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) phase2 = _make_phase(phase_id="p2", status=PhaseStatus.COMPLETED) await cp.save("plan_1", phase1, "executing") await cp.save("plan_1", phase2, "executing") # 立即 list 应返回 2 个 checkpoints = await cp.list_checkpoints("plan_1") assert len(checkpoints) == 2 # 将第一个 checkpoint 标记为过期 from datetime import datetime, timedelta, timezone expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() cp._memory["plan_1"]["p1"].saved_at = expired_time # list 应过滤掉过期的,只返回 1 个 checkpoints = await cp.list_checkpoints("plan_1") assert len(checkpoints) == 1 assert checkpoints[0].phase_id == "p2" @pytest.mark.asyncio async def test_redis_empty_falls_back_to_memory_with_ttl_filter(self): """Redis 无数据时降级到内存,并应用 TTL 过滤""" redis = _make_mock_redis() cp = PipelineCheckpoint(redis_client=redis, ttl_seconds=1) phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) # save 会同时写入 Redis 和内存降级 await cp.save("plan_1", phase, "executing") # 立即 load 应正常返回(Redis 有数据) loaded = await cp.load("plan_1") assert loaded is not None # 模拟 Redis 数据丢失:zrange 返回空,get 返回 None redis.zrange = AsyncMock(return_value=[]) redis.get = AsyncMock(return_value=None) # 内存降级 + TTL 未过期 → 应返回数据 loaded = await cp.load("plan_1") assert loaded is not None assert loaded.phase_id == "p1" # 标记为过期 from datetime import datetime, timedelta, timezone expired_time = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat() cp._memory["plan_1"]["p1"].saved_at = expired_time # 内存降级 + TTL 过期 → 应返回 None loaded = await cp.load("plan_1") assert loaded is None # ── clear 测试 ─────────────────────────────────────────── class TestCheckpointClear: """clear 测试""" @pytest.mark.asyncio async def test_clear_removes_all_data(self): """clear 清除所有 checkpoint 和 plan 数据""" cp = PipelineCheckpoint() plan = _make_plan() phase = _make_phase() await cp.save_plan(plan) await cp.save("plan_1", phase, "executing") await cp.clear("plan_1") assert await cp.load("plan_1") is None assert await cp.list_checkpoints("plan_1") == [] assert await cp.load_plan("plan_1") is None @pytest.mark.asyncio async def test_clear_nonexistent_does_not_raise(self): """clear 不存在的 plan_id 不抛异常""" cp = PipelineCheckpoint() await cp.clear("nonexistent") # should not raise # ── Redis 模式测试 ─────────────────────────────────────── class TestCheckpointRedis: """Redis 模式测试(使用 mock Redis)""" @pytest.mark.asyncio async def test_redis_save_then_load(self): """Redis 模式下 save/load 正常工作""" redis = _make_mock_redis() cp = PipelineCheckpoint(redis_client=redis) phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED) await cp.save("plan_1", phase, "executing") loaded = await cp.load("plan_1") assert loaded is not None assert loaded.phase_id == "p1" assert loaded.phase_status == "completed" @pytest.mark.asyncio async def test_redis_save_plan_then_load_plan(self): """Redis 模式下 save_plan/load_plan 正常工作""" redis = _make_mock_redis() cp = PipelineCheckpoint(redis_client=redis) plan = _make_plan(plan_id="plan_99") await cp.save_plan(plan) loaded = await cp.load_plan("plan_99") assert loaded is not None assert loaded["id"] == "plan_99" @pytest.mark.asyncio async def test_redis_clear_removes_all(self): """Redis 模式下 clear 清除所有数据""" redis = _make_mock_redis() cp = PipelineCheckpoint(redis_client=redis) phase = _make_phase() plan = _make_plan() await cp.save("plan_1", phase, "executing") await cp.save_plan(plan) await cp.clear("plan_1") assert await cp.load("plan_1") is None assert await cp.load_plan("plan_1") is None @pytest.mark.asyncio async def test_redis_failure_falls_back_to_memory(self): """Redis 异常时降级到内存,save/load 仍工作""" redis = _make_mock_redis() # 让 pipeline 抛异常(save 写 Redis 失败) redis.pipeline = MagicMock(side_effect=Exception("Redis connection lost")) cp = PipelineCheckpoint(redis_client=redis) phase = _make_phase(status=PhaseStatus.COMPLETED) # save 不应抛异常,降级到内存 await cp.save("plan_1", phase, "executing") # Redis get/zrange 也失败时,load 从内存降级读取 redis.get = AsyncMock(side_effect=Exception("Redis down")) redis.zrange = AsyncMock(side_effect=Exception("Redis down")) loaded = await cp.load("plan_1") assert loaded is not None assert loaded.phase_id == phase.id @pytest.mark.asyncio async def test_redis_save_exception_does_not_block(self): """save 时 Redis 异常不阻断执行""" redis = _make_mock_redis() # 让 pipeline 抛异常,但 zrange/get 仍正常工作 redis.pipeline = MagicMock(side_effect=Exception("Redis write error")) cp = PipelineCheckpoint(redis_client=redis) phase = _make_phase() # 不应抛异常 await cp.save("plan_1", phase, "executing") # 内存降级中应有数据(Redis zrange 返回空 → 降级到内存) checkpoints = await cp.list_checkpoints("plan_1") assert len(checkpoints) == 1 # ── CheckpointData 数据类测试 ───────────────────────────── class TestCheckpointData: """CheckpointData 序列化/反序列化测试""" def test_to_dict_contains_all_fields(self): """to_dict 包含所有字段""" data = CheckpointData( plan_id="p1", phase_id="ph1", phase_name="Phase 1", phase_status="completed", phase_result={"content": "output"}, plan_status="executing", ) d = data.to_dict() assert d["plan_id"] == "p1" assert d["phase_id"] == "ph1" assert d["phase_name"] == "Phase 1" assert d["phase_status"] == "completed" assert d["phase_result"] == {"content": "output"} assert d["plan_status"] == "executing" assert "saved_at" in d def test_from_dict_roundtrip(self): """from_dict 反序列化正确""" original = CheckpointData( plan_id="p1", phase_id="ph1", phase_name="Phase 1", phase_status="completed", phase_result={"content": "output"}, plan_status="executing", ) d = original.to_dict() restored = CheckpointData.from_dict(d) assert restored.plan_id == original.plan_id assert restored.phase_id == original.phase_id assert restored.phase_name == original.phase_name assert restored.phase_status == original.phase_status assert restored.phase_result == original.phase_result assert restored.plan_status == original.plan_status def test_from_dict_with_missing_fields_uses_defaults(self): """from_dict 缺失字段时使用默认值""" data = CheckpointData.from_dict({"plan_id": "p1", "phase_id": "ph1"}) assert data.plan_id == "p1" assert data.phase_id == "ph1" assert data.phase_name == "" assert data.phase_status == "" assert data.phase_result is None assert data.plan_status == "" # ── TeamOrchestrator.resume 集成测试 ───────────────────── class TestOrchestratorResume: """TeamOrchestrator.resume 集成测试""" @pytest.mark.asyncio async def test_resume_without_checkpoint_returns_failed(self): """无 checkpoint manager 时 resume 返回 failed""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team=team) # no checkpoint result = await orchestrator.resume("plan_1") assert result["status"] == "failed" assert "No checkpoint manager" in result["error"] @pytest.mark.asyncio async def test_resume_nonexistent_plan_returns_failed(self): """resume 不存在的 plan_id 返回 failed""" team = _make_team_with_experts() cp = PipelineCheckpoint() orchestrator = TeamOrchestrator(team=team, checkpoint=cp) result = await orchestrator.resume("nonexistent") assert result["status"] == "failed" assert "No checkpoint" in result["error"] @pytest.mark.asyncio async def test_resume_skips_completed_phases(self): """resume 跳过已完成阶段,只执行未完成阶段""" # 创建一个有 2 个阶段的 plan,阶段 1 已完成,阶段 2 依赖阶段 1 phase1 = PlanPhase( id="p1", name="phase1", assigned_expert="lead", task_description="task 1", status=PhaseStatus.COMPLETED, result={"content": "phase1 output"}, ) phase2 = PlanPhase( id="p2", name="phase2", assigned_expert="member1", task_description="task 2", depends_on=["p1"], status=PhaseStatus.PENDING, ) plan = _make_plan( plan_id="plan_resume", task="test resume", phases=[phase1, phase2], ) # 保存 plan + checkpoint for phase1 cp = PipelineCheckpoint() await cp.save_plan(plan) await cp.save("plan_resume", phase1, "executing") # 创建 team + orchestrator team = _make_team_with_experts(expert_names=["lead", "member1"]) # 设置 mock LLM gateway 用于 synthesis gateway = _make_mock_llm_gateway(synthesis_content="综合结果") team._experts["lead"].agent._llm_gateway = gateway orchestrator = TeamOrchestrator(team=team, checkpoint=cp) result = await orchestrator.resume("plan_resume") # 验证结果 assert result["status"] == "completed" # phase1 的结果应从 checkpoint 恢复 assert "p1" in result["phase_results"] # phase2 应被执行 assert "p2" in result["phase_results"] @pytest.mark.asyncio async def test_resume_all_phases_completed_skips_execution(self): """resume 时所有阶段都已完成 → 直接 synthesis""" phase1 = PlanPhase( id="p1", name="phase1", assigned_expert="lead", task_description="task 1", status=PhaseStatus.COMPLETED, result={"content": "phase1 output"}, ) plan = _make_plan( plan_id="plan_all_done", task="test resume all done", phases=[phase1], ) cp = PipelineCheckpoint() await cp.save_plan(plan) await cp.save("plan_all_done", phase1, "executing") team = _make_team_with_experts() gateway = _make_mock_llm_gateway(synthesis_content="综合结果") team._experts["lead"].agent._llm_gateway = gateway orchestrator = TeamOrchestrator(team=team, checkpoint=cp) result = await orchestrator.resume("plan_all_done") assert result["status"] == "completed" assert "p1" in result["phase_results"] @pytest.mark.asyncio async def test_resume_no_active_expert_returns_failed(self): """resume 时无活跃专家返回 failed""" phase1 = PlanPhase( id="p1", name="phase1", assigned_expert="lead", task_description="task 1", status=PhaseStatus.PENDING, ) plan = _make_plan( plan_id="plan_no_expert", task="test", phases=[phase1], ) cp = PipelineCheckpoint() await cp.save_plan(plan) # 创建无活跃专家的 team team = _make_team_with_experts() for expert in team._experts.values(): expert.is_active = False orchestrator = TeamOrchestrator(team=team, checkpoint=cp) result = await orchestrator.resume("plan_no_expert") assert result["status"] == "failed" assert "No active expert" in result["error"]