fischer-agentkit/tests/unit/test_pipeline_checkpoint.py

682 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""PipelineCheckpoint 单元测试 (U7)
测试覆盖:
- save/load 阶段 checkpoint内存模式 + Redis mock
- save_plan/load_plan 完整 plan 序列化
- list_checkpoints 按保存时间排序
- clear 清除所有数据
- Redis 异常降级到内存
- resume 从 checkpoint 恢复执行
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.experts.orchestrator import TeamOrchestrator
from agentkit.experts.plan import PhaseStatus, PlanPhase, PlanStatus, TeamPlan
from agentkit.orchestrator.checkpoint import CheckpointData, PipelineCheckpoint
from tests.unit.experts.test_team_orchestrator import (
_make_mock_llm_gateway,
_make_team_with_experts,
)
# ── 辅助函数 ──────────────────────────────────────────────
def _make_phase(
    name: str = "test_phase",
    phase_id: str = "phase_1",
    status: PhaseStatus = PhaseStatus.COMPLETED,
    result: dict | None = None,
) -> PlanPhase:
    """Build a ``PlanPhase`` for use in tests.

    Args:
        name: Phase name.
        phase_id: Value passed as the phase ``id``.
        status: Phase status; defaults to ``PhaseStatus.COMPLETED``.
        result: Phase result. Only ``None`` is replaced by the default
            ``{"content": "test output"}``; an explicitly supplied empty
            dict is passed through unchanged.

    Returns:
        A ``PlanPhase`` assigned to ``"test_expert"`` with the task
        description ``"test task"``.
    """
    # Compare against None rather than relying on truthiness so that a
    # caller who deliberately passes ``{}`` gets exactly that result.
    if result is None:
        result = {"content": "test output"}
    return PlanPhase(
        id=phase_id,
        name=name,
        assigned_expert="test_expert",
        task_description="test task",
        status=status,
        result=result,
    )
def _make_plan(
    plan_id: str = "plan_1",
    task: str = "test task",
    phases: list[PlanPhase] | None = None,
) -> TeamPlan:
    """Build a ``TeamPlan`` for use in tests.

    The plan is led by ``"lead"`` and its status is set to
    ``PlanStatus.EXECUTING``. ``phases`` is assigned only when it is a
    non-empty list; otherwise the plan keeps whatever ``TeamPlan`` gave it.
    """
    team_plan = TeamPlan(id=plan_id, task=task, lead_expert="lead")
    team_plan.status = PlanStatus.EXECUTING
    if not phases:
        return team_plan
    team_plan.phases = phases
    return team_plan
def _make_mock_redis() -> AsyncMock:
"""创建 mock Redis client模拟 aioredis 行为。"""
redis = AsyncMock()
# 内部存储
store: dict[str, str] = {}
zsets: dict[str, dict[str, float]] = {}
async def _set(key, value, ex=None): # noqa: ARG001
store[key] = value
async def _get(key):
return store.get(key)
async def _zadd(key, mapping):
zsets.setdefault(key, {}).update(mapping)
async def _zrange(key, start, stop):
members = zsets.get(key, {})
sorted_members = sorted(members.keys(), key=lambda m: members[m])
if start == 0 and stop == -1:
return sorted_members
return sorted_members[start : stop + 1] if stop >= 0 else sorted_members[start:]
async def _delete(*keys):
count = 0
for k in keys:
if k in store:
del store[k]
count += 1
if k in zsets:
del zsets[k]
count += 1
return count
def _pipeline():
commands: list[tuple] = []
class _Pipe:
def set(self, key, value, ex=None):
commands.append(("set", key, value, ex))
def zadd(self, key, mapping):
commands.append(("zadd", key, mapping))
def delete(self, *keys):
commands.append(("delete", keys))
async def execute(self):
for cmd in commands:
if cmd[0] == "set":
await _set(cmd[1], cmd[2], cmd[3])
elif cmd[0] == "zadd":
await _zadd(cmd[1], cmd[2])
elif cmd[0] == "delete":
await _delete(*cmd[1])
return _Pipe()
redis.set = AsyncMock(side_effect=_set)
redis.get = AsyncMock(side_effect=_get)
redis.zadd = AsyncMock(side_effect=_zadd)
redis.zrange = AsyncMock(side_effect=_zrange)
redis.delete = AsyncMock(side_effect=_delete)
redis.pipeline = MagicMock(side_effect=_pipeline)
return redis
# ── PipelineCheckpoint 基础测试 ──────────────────────────
class TestCheckpointSaveLoad:
    """Basic save / load / list_checkpoints tests."""

    @pytest.mark.asyncio
    async def test_save_then_load_returns_last_completed(self):
        """After a save, load returns the last COMPLETED phase."""
        checkpoint = PipelineCheckpoint()  # in-memory mode
        completed = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
        await checkpoint.save("plan_1", completed, "executing")

        restored = await checkpoint.load("plan_1")

        assert restored is not None
        assert restored.plan_id == "plan_1"
        assert restored.phase_id == "p1"
        assert restored.phase_status == "completed"
        assert restored.plan_status == "executing"

    @pytest.mark.asyncio
    async def test_load_returns_last_completed_of_multiple(self):
        """With three completed phases, load returns the third one."""
        checkpoint = PipelineCheckpoint()
        for index in range(3):
            await checkpoint.save(
                "plan_1",
                _make_phase(
                    name=f"phase_{index}",
                    phase_id=f"p{index}",
                    status=PhaseStatus.COMPLETED,
                    result={"content": f"output_{index}"},
                ),
                "executing",
            )

        restored = await checkpoint.load("plan_1")

        assert restored is not None
        assert restored.phase_id == "p2"
        assert restored.phase_result == {"content": "output_2"}

    @pytest.mark.asyncio
    async def test_load_nonexistent_plan_returns_none(self):
        """Loading an unknown plan_id yields None."""
        checkpoint = PipelineCheckpoint()
        assert await checkpoint.load("nonexistent") is None

    @pytest.mark.asyncio
    async def test_load_with_no_completed_returns_none(self):
        """load returns None when only a FAILED phase was saved."""
        checkpoint = PipelineCheckpoint()
        failed = _make_phase(phase_id="p1", status=PhaseStatus.FAILED)
        await checkpoint.save("plan_1", failed, "executing")

        assert await checkpoint.load("plan_1") is None

    @pytest.mark.asyncio
    async def test_list_checkpoints_returns_all(self):
        """list_checkpoints returns every saved checkpoint."""
        checkpoint = PipelineCheckpoint()
        for index in range(3):
            await checkpoint.save(
                "plan_1",
                _make_phase(phase_id=f"p{index}", status=PhaseStatus.COMPLETED),
                "executing",
            )

        entries = await checkpoint.list_checkpoints("plan_1")

        assert len(entries) == 3
        assert all(entry.plan_id == "plan_1" for entry in entries)

    @pytest.mark.asyncio
    async def test_save_serializes_phase_status_enum(self):
        """save stores the PhaseStatus enum as its string form."""
        checkpoint = PipelineCheckpoint()
        await checkpoint.save(
            "plan_1", _make_phase(status=PhaseStatus.COMPLETED), "executing"
        )

        entries = await checkpoint.list_checkpoints("plan_1")

        assert entries[0].phase_status == "completed"

    @pytest.mark.asyncio
    async def test_save_handles_non_dict_result(self):
        """save copes with a phase.result that is not a dict."""
        checkpoint = PipelineCheckpoint()
        phase = _make_phase()
        # Overwrite after construction so the helper's default is bypassed.
        phase.result = "plain string result"
        await checkpoint.save("plan_1", phase, "executing")

        entries = await checkpoint.list_checkpoints("plan_1")

        assert entries[0].phase_result == {"content": "plain string result"}
# ── save_plan / load_plan 测试 ───────────────────────────
class TestCheckpointSavePlan:
    """save_plan / load_plan tests."""

    @pytest.mark.asyncio
    async def test_save_plan_then_load_plan_roundtrip(self):
        """load_plan returns the plan as a dict after save_plan."""
        checkpoint = PipelineCheckpoint()
        two_phases = [
            _make_phase(name="phase1", phase_id="p1"),
            _make_phase(name="phase2", phase_id="p2"),
        ]
        plan = _make_plan(plan_id="plan_42", task="build feature", phases=two_phases)
        await checkpoint.save_plan(plan)

        plan_dict = await checkpoint.load_plan("plan_42")

        assert plan_dict is not None
        assert plan_dict["id"] == "plan_42"
        assert plan_dict["task"] == "build feature"
        assert len(plan_dict["phases"]) == 2
        assert plan_dict["phases"][0]["name"] == "phase1"

    @pytest.mark.asyncio
    async def test_load_plan_nonexistent_returns_none(self):
        """load_plan on an unknown plan_id yields None."""
        checkpoint = PipelineCheckpoint()
        assert await checkpoint.load_plan("nonexistent") is None
# ── TTL 过期测试 ─────────────────────────────────────────
class TestCheckpointTTL:
    """TTL expiry tests for the in-memory store."""

    @staticmethod
    def _stale_timestamp() -> str:
        """Return an ISO-8601 UTC timestamp ten seconds in the past."""
        from datetime import datetime, timedelta, timezone

        return (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()

    @pytest.mark.asyncio
    async def test_memory_checkpoint_expired_returns_none(self):
        """load returns None once an in-memory checkpoint has expired."""
        # A very short TTL (1 second) keeps the test simple.
        checkpoint = PipelineCheckpoint(ttl_seconds=1)
        await checkpoint.save(
            "plan_1",
            _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED),
            "executing",
        )

        # Loading straight away must succeed.
        fresh = await checkpoint.load("plan_1")
        assert fresh is not None
        assert fresh.phase_id == "p1"

        # Backdate saved_at by hand so the entry counts as expired.
        checkpoint._memory["plan_1"]["p1"].saved_at = self._stale_timestamp()

        # After expiry load must return None.
        assert await checkpoint.load("plan_1") is None

    @pytest.mark.asyncio
    async def test_memory_plan_expired_returns_none(self):
        """load_plan returns None once an in-memory plan has expired."""
        checkpoint = PipelineCheckpoint(ttl_seconds=1)
        await checkpoint.save_plan(_make_plan(plan_id="plan_ttl"))

        # Loading straight away must succeed.
        fresh = await checkpoint.load_plan("plan_ttl")
        assert fresh is not None
        assert fresh["id"] == "plan_ttl"

        # Backdate the stored save time by hand.
        import time as _time

        stored_plan = checkpoint._memory_plans["plan_ttl"][0]
        checkpoint._memory_plans["plan_ttl"] = (stored_plan, _time.time() - 10)

        # After expiry load_plan must return None.
        assert await checkpoint.load_plan("plan_ttl") is None

    @pytest.mark.asyncio
    async def test_memory_checkpoint_not_expired_returns_data(self):
        """An in-memory checkpoint inside its TTL is returned normally."""
        checkpoint = PipelineCheckpoint(ttl_seconds=3600)  # one-hour TTL
        await checkpoint.save(
            "plan_1",
            _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED),
            "executing",
        )

        restored = await checkpoint.load("plan_1")

        assert restored is not None
        assert restored.phase_id == "p1"

    @pytest.mark.asyncio
    async def test_memory_list_checkpoints_filters_expired(self):
        """list_checkpoints drops expired in-memory checkpoints."""
        checkpoint = PipelineCheckpoint(ttl_seconds=1)

        # Save two checkpoints.
        first = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
        second = _make_phase(phase_id="p2", status=PhaseStatus.COMPLETED)
        await checkpoint.save("plan_1", first, "executing")
        await checkpoint.save("plan_1", second, "executing")

        # Listing straight away returns both.
        entries = await checkpoint.list_checkpoints("plan_1")
        assert len(entries) == 2

        # Mark the first checkpoint as expired.
        checkpoint._memory["plan_1"]["p1"].saved_at = self._stale_timestamp()

        # The expired one is filtered out, leaving a single entry.
        entries = await checkpoint.list_checkpoints("plan_1")
        assert len(entries) == 1
        assert entries[0].phase_id == "p2"

    @pytest.mark.asyncio
    async def test_redis_empty_falls_back_to_memory_with_ttl_filter(self):
        """With no data in Redis, load falls back to memory and applies the TTL."""
        fake_redis = _make_mock_redis()
        checkpoint = PipelineCheckpoint(redis_client=fake_redis, ttl_seconds=1)

        # save writes to Redis and to the in-memory fallback.
        await checkpoint.save(
            "plan_1",
            _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED),
            "executing",
        )

        # Loading straight away succeeds (Redis holds the data).
        assert await checkpoint.load("plan_1") is not None

        # Simulate Redis data loss: zrange returns nothing, get returns None.
        fake_redis.zrange = AsyncMock(return_value=[])
        fake_redis.get = AsyncMock(return_value=None)

        # Memory fallback with an unexpired entry: data is returned.
        from_memory = await checkpoint.load("plan_1")
        assert from_memory is not None
        assert from_memory.phase_id == "p1"

        # Mark the entry as expired.
        checkpoint._memory["plan_1"]["p1"].saved_at = self._stale_timestamp()

        # Memory fallback with an expired entry: None.
        assert await checkpoint.load("plan_1") is None
# ── clear 测试 ───────────────────────────────────────────
class TestCheckpointClear:
    """clear tests."""

    @pytest.mark.asyncio
    async def test_clear_removes_all_data(self):
        """clear wipes both the checkpoints and the stored plan."""
        checkpoint = PipelineCheckpoint()
        plan = _make_plan()
        phase = _make_phase()
        await checkpoint.save_plan(plan)
        await checkpoint.save("plan_1", phase, "executing")

        await checkpoint.clear("plan_1")

        latest = await checkpoint.load("plan_1")
        assert latest is None
        remaining = await checkpoint.list_checkpoints("plan_1")
        assert remaining == []
        stored_plan = await checkpoint.load_plan("plan_1")
        assert stored_plan is None

    @pytest.mark.asyncio
    async def test_clear_nonexistent_does_not_raise(self):
        """Clearing an unknown plan_id must not raise."""
        checkpoint = PipelineCheckpoint()
        await checkpoint.clear("nonexistent")  # should not raise
# ── Redis 模式测试 ───────────────────────────────────────
class TestCheckpointRedis:
    """Redis-mode tests (backed by the mock Redis client)."""

    @pytest.mark.asyncio
    async def test_redis_save_then_load(self):
        """save/load work in Redis mode."""
        checkpoint = PipelineCheckpoint(redis_client=_make_mock_redis())
        await checkpoint.save(
            "plan_1",
            _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED),
            "executing",
        )

        restored = await checkpoint.load("plan_1")

        assert restored is not None
        assert restored.phase_id == "p1"
        assert restored.phase_status == "completed"

    @pytest.mark.asyncio
    async def test_redis_save_plan_then_load_plan(self):
        """save_plan/load_plan work in Redis mode."""
        checkpoint = PipelineCheckpoint(redis_client=_make_mock_redis())
        await checkpoint.save_plan(_make_plan(plan_id="plan_99"))

        plan_dict = await checkpoint.load_plan("plan_99")

        assert plan_dict is not None
        assert plan_dict["id"] == "plan_99"

    @pytest.mark.asyncio
    async def test_redis_clear_removes_all(self):
        """clear wipes all data in Redis mode."""
        checkpoint = PipelineCheckpoint(redis_client=_make_mock_redis())
        phase = _make_phase()
        plan = _make_plan()
        await checkpoint.save("plan_1", phase, "executing")
        await checkpoint.save_plan(plan)

        await checkpoint.clear("plan_1")

        latest = await checkpoint.load("plan_1")
        assert latest is None
        stored_plan = await checkpoint.load_plan("plan_1")
        assert stored_plan is None

    @pytest.mark.asyncio
    async def test_redis_failure_falls_back_to_memory(self):
        """When Redis raises, save/load keep working via the memory fallback."""
        fake_redis = _make_mock_redis()
        # Make pipeline() raise so that the Redis write inside save fails.
        fake_redis.pipeline = MagicMock(side_effect=Exception("Redis connection lost"))
        checkpoint = PipelineCheckpoint(redis_client=fake_redis)
        phase = _make_phase(status=PhaseStatus.COMPLETED)

        # save must not raise; it degrades to memory.
        await checkpoint.save("plan_1", phase, "executing")

        # With get/zrange failing too, load reads from the memory fallback.
        fake_redis.get = AsyncMock(side_effect=Exception("Redis down"))
        fake_redis.zrange = AsyncMock(side_effect=Exception("Redis down"))
        restored = await checkpoint.load("plan_1")

        assert restored is not None
        assert restored.phase_id == phase.id

    @pytest.mark.asyncio
    async def test_redis_save_exception_does_not_block(self):
        """A Redis error during save does not interrupt execution."""
        fake_redis = _make_mock_redis()
        # pipeline() raises, while zrange/get keep working normally.
        fake_redis.pipeline = MagicMock(side_effect=Exception("Redis write error"))
        checkpoint = PipelineCheckpoint(redis_client=fake_redis)

        # Must not raise.
        await checkpoint.save("plan_1", _make_phase(), "executing")

        # The memory fallback holds the data (Redis zrange is empty, so
        # listing degrades to memory).
        entries = await checkpoint.list_checkpoints("plan_1")
        assert len(entries) == 1
# ── CheckpointData 数据类测试 ─────────────────────────────
class TestCheckpointData:
    """Serialization / deserialization tests for CheckpointData."""

    def test_to_dict_contains_all_fields(self):
        """to_dict exposes every field."""
        as_dict = CheckpointData(
            plan_id="p1",
            phase_id="ph1",
            phase_name="Phase 1",
            phase_status="completed",
            phase_result={"content": "output"},
            plan_status="executing",
        ).to_dict()

        assert as_dict["plan_id"] == "p1"
        assert as_dict["phase_id"] == "ph1"
        assert as_dict["phase_name"] == "Phase 1"
        assert as_dict["phase_status"] == "completed"
        assert as_dict["phase_result"] == {"content": "output"}
        assert as_dict["plan_status"] == "executing"
        assert "saved_at" in as_dict

    def test_from_dict_roundtrip(self):
        """from_dict rebuilds what to_dict produced."""
        source = CheckpointData(
            plan_id="p1",
            phase_id="ph1",
            phase_name="Phase 1",
            phase_status="completed",
            phase_result={"content": "output"},
            plan_status="executing",
        )

        rebuilt = CheckpointData.from_dict(source.to_dict())

        assert rebuilt.plan_id == source.plan_id
        assert rebuilt.phase_id == source.phase_id
        assert rebuilt.phase_name == source.phase_name
        assert rebuilt.phase_status == source.phase_status
        assert rebuilt.phase_result == source.phase_result
        assert rebuilt.plan_status == source.plan_status

    def test_from_dict_with_missing_fields_uses_defaults(self):
        """from_dict falls back to defaults for absent fields."""
        partial = CheckpointData.from_dict({"plan_id": "p1", "phase_id": "ph1"})

        assert partial.plan_id == "p1"
        assert partial.phase_id == "ph1"
        assert partial.phase_name == ""
        assert partial.phase_status == ""
        assert partial.phase_result is None
        assert partial.plan_status == ""
# ── TeamOrchestrator.resume 集成测试 ─────────────────────
class TestOrchestratorResume:
    """Integration tests for TeamOrchestrator.resume."""

    @pytest.mark.asyncio
    async def test_resume_without_checkpoint_returns_failed(self):
        """resume reports failure when no checkpoint manager is configured."""
        orchestrator = TeamOrchestrator(team=_make_team_with_experts())  # no checkpoint

        outcome = await orchestrator.resume("plan_1")

        assert outcome["status"] == "failed"
        assert "No checkpoint manager" in outcome["error"]

    @pytest.mark.asyncio
    async def test_resume_nonexistent_plan_returns_failed(self):
        """resume reports failure for an unknown plan_id."""
        team = _make_team_with_experts()
        checkpoint = PipelineCheckpoint()
        orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)

        outcome = await orchestrator.resume("nonexistent")

        assert outcome["status"] == "failed"
        assert "No checkpoint" in outcome["error"]

    @pytest.mark.asyncio
    async def test_resume_skips_completed_phases(self):
        """resume skips finished phases and runs only the unfinished ones."""
        # Two-phase plan: phase 1 is done, phase 2 depends on phase 1.
        done_phase = PlanPhase(
            id="p1",
            name="phase1",
            assigned_expert="lead",
            task_description="task 1",
            status=PhaseStatus.COMPLETED,
            result={"content": "phase1 output"},
        )
        pending_phase = PlanPhase(
            id="p2",
            name="phase2",
            assigned_expert="member1",
            task_description="task 2",
            depends_on=["p1"],
            status=PhaseStatus.PENDING,
        )
        plan = _make_plan(
            plan_id="plan_resume",
            task="test resume",
            phases=[done_phase, pending_phase],
        )

        # Persist the plan plus a checkpoint for phase 1.
        checkpoint = PipelineCheckpoint()
        await checkpoint.save_plan(plan)
        await checkpoint.save("plan_resume", done_phase, "executing")

        # Build the team and orchestrator.
        team = _make_team_with_experts(expert_names=["lead", "member1"])
        # Install a mock LLM gateway used for synthesis.
        llm_gateway = _make_mock_llm_gateway(synthesis_content="综合结果")
        team._experts["lead"].agent._llm_gateway = llm_gateway
        orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)

        outcome = await orchestrator.resume("plan_resume")

        # Verify the outcome.
        assert outcome["status"] == "completed"
        # Phase 1's result should come back from the checkpoint.
        assert "p1" in outcome["phase_results"]
        # Phase 2 should have been executed.
        assert "p2" in outcome["phase_results"]

    @pytest.mark.asyncio
    async def test_resume_all_phases_completed_skips_execution(self):
        """When every phase is already done, resume goes straight to synthesis."""
        done_phase = PlanPhase(
            id="p1",
            name="phase1",
            assigned_expert="lead",
            task_description="task 1",
            status=PhaseStatus.COMPLETED,
            result={"content": "phase1 output"},
        )
        plan = _make_plan(
            plan_id="plan_all_done",
            task="test resume all done",
            phases=[done_phase],
        )
        checkpoint = PipelineCheckpoint()
        await checkpoint.save_plan(plan)
        await checkpoint.save("plan_all_done", done_phase, "executing")

        team = _make_team_with_experts()
        llm_gateway = _make_mock_llm_gateway(synthesis_content="综合结果")
        team._experts["lead"].agent._llm_gateway = llm_gateway
        orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)

        outcome = await orchestrator.resume("plan_all_done")

        assert outcome["status"] == "completed"
        assert "p1" in outcome["phase_results"]

    @pytest.mark.asyncio
    async def test_resume_no_active_expert_returns_failed(self):
        """resume reports failure when no expert is active."""
        pending_phase = PlanPhase(
            id="p1",
            name="phase1",
            assigned_expert="lead",
            task_description="task 1",
            status=PhaseStatus.PENDING,
        )
        plan = _make_plan(
            plan_id="plan_no_expert",
            task="test",
            phases=[pending_phase],
        )
        checkpoint = PipelineCheckpoint()
        await checkpoint.save_plan(plan)

        # Build a team in which every expert is inactive.
        team = _make_team_with_experts()
        for member in team._experts.values():
            member.is_active = False
        orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)

        outcome = await orchestrator.resume("plan_no_expert")

        assert outcome["status"] == "failed"
        assert "No active expert" in outcome["error"]