feat(orchestrator): add pipeline checkpoint and crash recovery (U7)

Add PipelineCheckpoint for stage-level crash recovery with Redis-first
+ memory fallback. TeamOrchestrator saves checkpoints after each phase
finalizes and supports resume(plan_id) to continue from the last
completed phase. New POST /api/v1/tasks/{id}/resume endpoint recreates
the team from saved plan and calls resume.
This commit is contained in:
chiguyong 2026-06-24 21:04:18 +08:00
parent 3dfda904d7
commit dfd188b1a4
5 changed files with 1006 additions and 2 deletions

View File

@ -73,7 +73,12 @@ class TeamOrchestrator:
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰 DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"}) STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
def __init__(self, team: ExpertTeam, max_concurrent_phases: int | None = None) -> None: def __init__(
self,
team: ExpertTeam,
max_concurrent_phases: int | None = None,
checkpoint: Any = None,
) -> None:
self._team = team self._team = team
# Track temporary agent names created for context isolation (KTD3) # Track temporary agent names created for context isolation (KTD3)
# Maps phase_id -> temp_agent_name for cleanup # Maps phase_id -> temp_agent_name for cleanup
@ -86,6 +91,8 @@ class TeamOrchestrator:
# U2: 并发限制 — 同层并行阶段加 Semaphore避免 LLM 限流洪峰 # U2: 并发限制 — 同层并行阶段加 Semaphore避免 LLM 限流洪峰
limit = max_concurrent_phases or self.DEFAULT_MAX_CONCURRENT_PHASES limit = max_concurrent_phases or self.DEFAULT_MAX_CONCURRENT_PHASES
self._phase_semaphore = asyncio.Semaphore(limit) self._phase_semaphore = asyncio.Semaphore(limit)
# U7: Pipeline checkpoint for crash recovery
self._checkpoint = checkpoint
async def execute(self, task: str) -> dict[str, Any]: async def execute(self, task: str) -> dict[str, Any]:
"""Execute a task in pipeline mode. """Execute a task in pipeline mode.
@ -173,10 +180,31 @@ class TeamOrchestrator:
}, },
) )
# U7: Save plan for potential resume (before execution starts)
if self._checkpoint is not None:
try:
await self._checkpoint.save_plan(plan)
except Exception as e:
logger.warning(f"Checkpoint save_plan failed: {e}")
# 4. Set EXECUTING status, execute phases # 4. Set EXECUTING status, execute phases
self._team.set_status(TeamStatus.EXECUTING) self._team.set_status(TeamStatus.EXECUTING)
phase_results: dict[str, dict[str, Any]] = {} phase_results: dict[str, dict[str, Any]] = {}
return await self._run_pipeline(lead, plan, phase_results, task)
async def _run_pipeline(
self,
lead: Expert,
plan: TeamPlan,
phase_results: dict[str, dict[str, Any]],
task: str,
) -> dict[str, Any]:
"""Execute the pipeline loop: run pending phases, synthesize, return result.
Shared by execute() and resume(). phase_results may be pre-populated
by resume() with completed phase outputs.
"""
try: try:
# Execute layers sequentially, phases within layer in parallel. # Execute layers sequentially, phases within layer in parallel.
# U3: while-loop re-computes topological_sort each iteration so # U3: while-loop re-computes topological_sort each iteration so
@ -234,6 +262,13 @@ class TeamOrchestrator:
else: else:
phase_results[ph.id] = result phase_results[ph.id] = result
# U7: Save checkpoint after phase finalizes (success or failure)
if self._checkpoint is not None:
try:
await self._checkpoint.save(plan.id, ph, plan.status.value)
except Exception as e:
logger.warning(f"Checkpoint save failed for phase {ph.id}: {e}")
# U3: Divergence detection — check completed phases for conflicts # U3: Divergence detection — check completed phases for conflicts
# and dynamically insert DEBATE phases if needed # and dynamically insert DEBATE phases if needed
if self._debate_count < self.MAX_DEBATES: if self._debate_count < self.MAX_DEBATES:
@ -290,6 +325,82 @@ class TeamOrchestrator:
await self._broadcast_event("team_dissolved", {"team_id": self._team.team_id}) await self._broadcast_event("team_dissolved", {"team_id": self._team.team_id})
return await self._fallback_to_single_agent(task, plan, phase_results) return await self._fallback_to_single_agent(task, plan, phase_results)
async def resume(self, plan_id: str) -> dict[str, Any]:
"""Resume a crashed pipeline from the last completed phase checkpoint.
Flow:
1. Load plan + checkpoints from PipelineCheckpoint
2. Reconstruct TeamPlan, mark completed phases as COMPLETED
3. Pre-populate phase_results with checkpoint data
4. Call _run_pipeline to continue from next pending phase
Returns same dict shape as execute(). If no checkpoint found, returns
a failed result.
"""
if self._checkpoint is None:
return {
"status": "failed",
"result": None,
"phase_results": {},
"error": "No checkpoint manager configured",
}
# 1. Load plan
plan_dict = await self._checkpoint.load_plan(plan_id)
if plan_dict is None:
return {
"status": "failed",
"result": None,
"phase_results": {},
"error": f"No checkpoint found for plan '{plan_id}'",
}
# 2. Reconstruct TeamPlan
plan = TeamPlan.from_dict(plan_dict)
task = plan.task
# 3. Load checkpoints, mark completed phases
checkpoints = await self._checkpoint.list_checkpoints(plan_id)
phase_results: dict[str, dict[str, Any]] = {}
completed_phase_ids: set[str] = set()
for cp in checkpoints:
if cp.phase_status == "completed":
completed_phase_ids.add(cp.phase_id)
# Restore phase result from checkpoint
if cp.phase_result:
phase_results[cp.phase_id] = cp.phase_result
# Apply checkpoint state to plan phases
for ph in plan.phases:
if ph.id in completed_phase_ids:
ph.status = PhaseStatus.COMPLETED
if ph.id in phase_results and phase_results[ph.id]:
ph.result = phase_results[ph.id]
# PENDING phases remain PENDING — will be executed by _run_pipeline
logger.info(
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
f"{len(plan.phases) - len(completed_phase_ids)} pending"
)
# 4. Get lead expert
lead = self._team.lead_expert
if not lead or not lead.is_active:
active = self._team.active_experts
if not active:
return {
"status": "failed",
"result": None,
"phase_results": phase_results,
"error": "No active expert available",
}
lead = active[0]
# 5. Resume execution
self._team.set_status(TeamStatus.EXECUTING)
return await self._run_pipeline(lead, plan, phase_results, task)
async def _decompose_task(self, lead: Expert, task: str) -> list[PlanPhase]: async def _decompose_task(self, lead: Expert, task: str) -> list[PlanPhase]:
"""Lead Expert decomposes task into phases using LLM. """Lead Expert decomposes task into phases using LLM.

View File

@ -0,0 +1,237 @@
"""PipelineCheckpoint — 阶段级检查点与断点续跑 (U7)
TeamOrchestrator 阶段完成后保存 checkpoint Redis或内存降级
崩溃后可通过 resume(plan_id) 从最后完成阶段恢复
复用 PipelineStateRedis _safe_redis_call 模式
Redis 失败时降级到内存 dict不阻断执行
键命名agentkit:pipeline:checkpoint:{plan_id}:{phase_id}
TTL7 PipelineStateRedis._TTL_SECONDS 一致
"""
from __future__ import annotations
import json
import logging
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any
logger = logging.getLogger(__name__)
_TTL_SECONDS = 7 * 24 * 3600 # 7 days
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
@dataclass
class CheckpointData:
"""单个阶段的 checkpoint 数据。"""
plan_id: str
phase_id: str
phase_name: str
phase_status: str
phase_result: dict[str, Any] | None = None
plan_status: str = ""
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
def to_dict(self) -> dict[str, Any]:
return asdict(self)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> CheckpointData:
return cls(
plan_id=data.get("plan_id", ""),
phase_id=data.get("phase_id", ""),
phase_name=data.get("phase_name", ""),
phase_status=data.get("phase_status", ""),
phase_result=data.get("phase_result"),
plan_status=data.get("plan_status", ""),
saved_at=data.get("saved_at", ""),
)
class PipelineCheckpoint:
"""阶段级检查点存储 — Redis 优先,内存降级。
Usage::
checkpoint = PipelineCheckpoint(redis_client=redis)
await checkpoint.save(plan.id, phase, plan.status.value)
last = await checkpoint.load(plan.id)
if last:
# resume from last completed phase
"""
def __init__(
self,
redis_client: Any = None,
prefix: str = _KEY_PREFIX,
ttl_seconds: int = _TTL_SECONDS,
) -> None:
self._redis = redis_client
self._prefix = prefix
self._ttl = ttl_seconds
# 内存降级存储plan_id → list of CheckpointData
self._memory: dict[str, list[CheckpointData]] = {}
# 内存降级存储plan_id → plan dict (for resume)
self._memory_plans: dict[str, dict[str, Any]] = {}
def _key(self, plan_id: str, phase_id: str) -> str:
return f"{self._prefix}:{plan_id}:{phase_id}"
def _index_key(self, plan_id: str) -> str:
"""Redis Sorted Set 索引键,用于列出某 plan 的所有 checkpoint。"""
return f"{self._prefix}:index:{plan_id}"
def _plan_key(self, plan_id: str) -> str:
"""完整 plan JSON 的存储键。"""
return f"{self._prefix}:plan:{plan_id}"
async def save_plan(self, plan: Any) -> None:
"""保存完整 TeamPlan用于 resume 重建)。
Args:
plan: TeamPlan 对象需要有 to_dict() 方法
"""
plan_id = plan.id
plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id}
# 内存降级
self._memory_plans[plan_id] = plan_dict
# 尝试写入 Redis
if self._redis is not None:
try:
await self._redis.set(
self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl
)
except Exception as e:
logger.warning(
f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}"
)
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
"""加载完整 plan JSON。"""
# 优先 Redis
if self._redis is not None:
try:
raw = await self._redis.get(self._plan_key(plan_id))
if raw:
return json.loads(raw)
except Exception as e:
logger.warning(
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
)
# 内存降级
return self._memory_plans.get(plan_id)
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
"""保存阶段 checkpoint。
Args:
plan_id: 计划 ID
phase: PlanPhase 对象需要有 id, name, status, result 属性
plan_status: 计划当前状态
"""
phase_id = getattr(phase, "id", str(phase))
phase_name = getattr(phase, "name", "")
phase_status = getattr(phase, "status", "")
if hasattr(phase_status, "value"):
phase_status = phase_status.value
# 序列化 phase.result可能是 offloaded dict with _ref_key
phase_result = getattr(phase, "result", None)
if phase_result is not None and not isinstance(phase_result, dict):
phase_result = {"content": str(phase_result)}
data = CheckpointData(
plan_id=plan_id,
phase_id=phase_id,
phase_name=phase_name,
phase_status=str(phase_status),
phase_result=phase_result,
plan_status=plan_status,
)
# 总是写入内存降级(保证一致性)
self._memory.setdefault(plan_id, []).append(data)
# 尝试写入 Redis
if self._redis is not None:
try:
score = time.time()
pipe = self._redis.pipeline()
pipe.set(self._key(plan_id, phase_id), json.dumps(data.to_dict()), ex=self._ttl)
pipe.zadd(self._index_key(plan_id), {phase_id: score})
await pipe.execute()
except Exception as e:
logger.warning(
f"PipelineCheckpoint.save Redis failed for plan {plan_id}, "
f"phase {phase_id}: {e} — using memory fallback"
)
async def load(self, plan_id: str) -> CheckpointData | None:
"""加载最后完成的阶段 checkpoint。
Returns:
最后一个 COMPLETED 阶段的 CheckpointData None
"""
checkpoints = await self.list_checkpoints(plan_id)
if not checkpoints:
return None
# 返回最后一个 COMPLETED 阶段
completed = [c for c in checkpoints if c.phase_status == "completed"]
if not completed:
return None
return completed[-1]
async def list_checkpoints(self, plan_id: str) -> list[CheckpointData]:
"""列出某 plan 的所有 checkpoint按保存时间排序"""
# 优先从 Redis 读取
if self._redis is not None:
try:
phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
if not phase_ids:
# Redis 无数据,检查内存
return list(self._memory.get(plan_id, []))
results: list[CheckpointData] = []
for phase_id in phase_ids:
raw = await self._redis.get(self._key(plan_id, phase_id))
if raw:
data = json.loads(raw)
results.append(CheckpointData.from_dict(data))
return results
except Exception as e:
logger.warning(
f"PipelineCheckpoint.list_checkpoints Redis failed for "
f"plan {plan_id}: {e} — using memory fallback"
)
# 内存降级
return list(self._memory.get(plan_id, []))
async def clear(self, plan_id: str) -> None:
"""清除某 plan 的所有 checkpoint。"""
# 清除内存
self._memory.pop(plan_id, None)
self._memory_plans.pop(plan_id, None)
# 清除 Redis
if self._redis is not None:
try:
phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
pipe = self._redis.pipeline()
for phase_id in phase_ids:
pipe.delete(self._key(plan_id, phase_id))
pipe.delete(self._index_key(plan_id))
pipe.delete(self._plan_key(plan_id))
await pipe.execute()
except Exception as e:
logger.warning(
f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}"
)

View File

@ -430,7 +430,13 @@ async def _execute_team_collab(
) )
await team.create_team(lead_config=lead_config, member_configs=member_configs) await team.create_team(lead_config=lead_config, member_configs=member_configs)
orchestrator = TeamOrchestrator(team=team) # U7: Create checkpoint manager for crash recovery
from agentkit.orchestrator.checkpoint import PipelineCheckpoint
checkpoint = PipelineCheckpoint(
redis_client=getattr(app_state, "working_redis_client", None)
)
orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)
# U4: Register active team so WS messages during execution route as interventions # U4: Register active team so WS messages during execution route as interventions
_register_active_team(session_id, team) _register_active_team(session_id, team)
result = await orchestrator.execute(routing_result.task_content) result = await orchestrator.execute(routing_result.task_content)

View File

@ -209,6 +209,95 @@ async def cancel_task(task_id: str, req: Request):
return {"task_id": task_id, "status": "cancelled"} return {"task_id": task_id, "status": "cancelled"}
@router.post("/tasks/{task_id}/resume")
async def resume_task(task_id: str, req: Request):
"""Resume a crashed pipeline from the last completed phase checkpoint.
Reconstructs the team from the saved plan's expert names, creates a new
TeamOrchestrator with the checkpoint manager, and calls resume().
"""
from agentkit.experts.orchestrator import TeamOrchestrator
from agentkit.experts.router import ExpertTeamRouter
from agentkit.experts.team import ExpertTeam
from agentkit.orchestrator.checkpoint import PipelineCheckpoint
app_state = req.app.state
# 1. Create checkpoint manager
checkpoint = PipelineCheckpoint(
redis_client=getattr(app_state, "working_redis_client", None)
)
# 2. Load plan to get expert names
plan_dict = await checkpoint.load_plan(task_id)
if plan_dict is None:
raise HTTPException(
status_code=404,
detail=f"No checkpoint found for task '{task_id}'",
)
# 3. Extract unique expert names from plan
expert_names: list[str] = []
lead_name = plan_dict.get("lead_expert", "")
if lead_name:
expert_names.append(lead_name)
for ph in plan_dict.get("phases", []):
name = ph.get("assigned_expert", "")
if name and name not in expert_names:
expert_names.append(name)
if not expert_names:
raise HTTPException(
status_code=400,
detail="Cannot resume: no experts found in saved plan",
)
# 4. Resolve expert configs via ExpertTeamRouter
template_registry = getattr(app_state, "expert_template_registry", None)
if template_registry is None:
from agentkit.experts.registry import ExpertTemplateRegistry
template_registry = ExpertTemplateRegistry()
team_router = ExpertTeamRouter(template_registry=template_registry)
expert_configs = team_router.resolve_expert_configs(expert_names)
if not expert_configs:
raise HTTPException(
status_code=400,
detail="Cannot resume: failed to resolve expert configs",
)
lead_config = expert_configs[0]
member_configs = expert_configs[1:] if len(expert_configs) > 1 else []
# 5. Create team + orchestrator
team = ExpertTeam(
pool=app_state.agent_pool,
template_registry=template_registry,
redis_client=getattr(app_state, "working_redis_client", None),
)
await team.create_team(lead_config=lead_config, member_configs=member_configs)
try:
orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)
result = await orchestrator.resume(task_id)
finally:
try:
await team.dissolve()
except Exception:
pass
return {
"task_id": task_id,
"status": result.get("status", "unknown"),
"result": result.get("result"),
"phase_results": {
pid: pr if isinstance(pr, dict) else {"content": str(pr)}
for pid, pr in (result.get("phase_results") or {}).items()
},
}
@router.post("/tasks/stream") @router.post("/tasks/stream")
async def stream_task(request: SubmitTaskRequest, req: Request): async def stream_task(request: SubmitTaskRequest, req: Request):
"""Submit a task and stream ReAct events via SSE""" """Submit a task and stream ReAct events via SSE"""

View File

@ -0,0 +1,561 @@
"""PipelineCheckpoint 单元测试 (U7)
测试覆盖
- save/load 阶段 checkpoint内存模式 + Redis mock
- save_plan/load_plan 完整 plan 序列化
- list_checkpoints 按保存时间排序
- clear 清除所有数据
- Redis 异常降级到内存
- resume checkpoint 恢复执行
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.experts.orchestrator import TeamOrchestrator
from agentkit.experts.plan import PhaseStatus, PlanPhase, PlanStatus, TeamPlan
from agentkit.orchestrator.checkpoint import CheckpointData, PipelineCheckpoint
from tests.unit.experts.test_team_orchestrator import (
_make_mock_llm_gateway,
_make_team_with_experts,
)
# ── 辅助函数 ──────────────────────────────────────────────
def _make_phase(
name: str = "test_phase",
phase_id: str = "phase_1",
status: PhaseStatus = PhaseStatus.COMPLETED,
result: dict | None = None,
) -> PlanPhase:
"""创建测试用 PlanPhase"""
return PlanPhase(
id=phase_id,
name=name,
assigned_expert="test_expert",
task_description="test task",
status=status,
result=result or {"content": "test output"},
)
def _make_plan(
plan_id: str = "plan_1",
task: str = "test task",
phases: list[PlanPhase] | None = None,
) -> TeamPlan:
"""创建测试用 TeamPlan"""
plan = TeamPlan(id=plan_id, task=task, lead_expert="lead")
plan.status = PlanStatus.EXECUTING
if phases:
plan.phases = phases
return plan
def _make_mock_redis() -> AsyncMock:
"""创建 mock Redis client模拟 aioredis 行为。"""
redis = AsyncMock()
# 内部存储
store: dict[str, str] = {}
zsets: dict[str, dict[str, float]] = {}
async def _set(key, value, ex=None): # noqa: ARG001
store[key] = value
async def _get(key):
return store.get(key)
async def _zadd(key, mapping):
zsets.setdefault(key, {}).update(mapping)
async def _zrange(key, start, stop):
members = zsets.get(key, {})
sorted_members = sorted(members.keys(), key=lambda m: members[m])
if start == 0 and stop == -1:
return sorted_members
return sorted_members[start : stop + 1] if stop >= 0 else sorted_members[start:]
async def _delete(*keys):
count = 0
for k in keys:
if k in store:
del store[k]
count += 1
if k in zsets:
del zsets[k]
count += 1
return count
def _pipeline():
commands: list[tuple] = []
class _Pipe:
def set(self, key, value, ex=None):
commands.append(("set", key, value, ex))
def zadd(self, key, mapping):
commands.append(("zadd", key, mapping))
def delete(self, *keys):
commands.append(("delete", keys))
async def execute(self):
for cmd in commands:
if cmd[0] == "set":
await _set(cmd[1], cmd[2], cmd[3])
elif cmd[0] == "zadd":
await _zadd(cmd[1], cmd[2])
elif cmd[0] == "delete":
await _delete(*cmd[1])
return _Pipe()
redis.set = AsyncMock(side_effect=_set)
redis.get = AsyncMock(side_effect=_get)
redis.zadd = AsyncMock(side_effect=_zadd)
redis.zrange = AsyncMock(side_effect=_zrange)
redis.delete = AsyncMock(side_effect=_delete)
redis.pipeline = MagicMock(side_effect=_pipeline)
return redis
# ── PipelineCheckpoint 基础测试 ──────────────────────────
class TestCheckpointSaveLoad:
"""save / load / list_checkpoints 基础测试"""
@pytest.mark.asyncio
async def test_save_then_load_returns_last_completed(self):
"""save 后 load 返回最后一个 COMPLETED 阶段"""
cp = PipelineCheckpoint() # 内存模式
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
await cp.save("plan_1", phase, "executing")
loaded = await cp.load("plan_1")
assert loaded is not None
assert loaded.plan_id == "plan_1"
assert loaded.phase_id == "p1"
assert loaded.phase_status == "completed"
assert loaded.plan_status == "executing"
@pytest.mark.asyncio
async def test_load_returns_last_completed_of_multiple(self):
"""3 个阶段完成后load 返回第 3 个(最后一个 COMPLETED"""
cp = PipelineCheckpoint()
for i in range(3):
phase = _make_phase(
name=f"phase_{i}",
phase_id=f"p{i}",
status=PhaseStatus.COMPLETED,
result={"content": f"output_{i}"},
)
await cp.save("plan_1", phase, "executing")
loaded = await cp.load("plan_1")
assert loaded is not None
assert loaded.phase_id == "p2"
assert loaded.phase_result == {"content": "output_2"}
@pytest.mark.asyncio
async def test_load_nonexistent_plan_returns_none(self):
"""load 不存在的 plan_id → None"""
cp = PipelineCheckpoint()
loaded = await cp.load("nonexistent")
assert loaded is None
@pytest.mark.asyncio
async def test_load_with_no_completed_returns_none(self):
"""只有 FAILED 阶段时 load 返回 None"""
cp = PipelineCheckpoint()
phase = _make_phase(phase_id="p1", status=PhaseStatus.FAILED)
await cp.save("plan_1", phase, "executing")
loaded = await cp.load("plan_1")
assert loaded is None
@pytest.mark.asyncio
async def test_list_checkpoints_returns_all(self):
"""list_checkpoints 返回所有 checkpoint"""
cp = PipelineCheckpoint()
for i in range(3):
phase = _make_phase(phase_id=f"p{i}", status=PhaseStatus.COMPLETED)
await cp.save("plan_1", phase, "executing")
checkpoints = await cp.list_checkpoints("plan_1")
assert len(checkpoints) == 3
assert all(c.plan_id == "plan_1" for c in checkpoints)
@pytest.mark.asyncio
async def test_save_serializes_phase_status_enum(self):
"""save 正确序列化 PhaseStatus enum 为字符串"""
cp = PipelineCheckpoint()
phase = _make_phase(status=PhaseStatus.COMPLETED)
await cp.save("plan_1", phase, "executing")
checkpoints = await cp.list_checkpoints("plan_1")
assert checkpoints[0].phase_status == "completed"
@pytest.mark.asyncio
async def test_save_handles_non_dict_result(self):
"""save 处理非 dict 的 phase.result"""
cp = PipelineCheckpoint()
phase = _make_phase()
phase.result = "plain string result"
await cp.save("plan_1", phase, "executing")
checkpoints = await cp.list_checkpoints("plan_1")
assert checkpoints[0].phase_result == {"content": "plain string result"}
# ── save_plan / load_plan 测试 ───────────────────────────
class TestCheckpointSavePlan:
"""save_plan / load_plan 测试"""
@pytest.mark.asyncio
async def test_save_plan_then_load_plan_roundtrip(self):
"""save_plan 后 load_plan 返回 plan dict"""
cp = PipelineCheckpoint()
plan = _make_plan(
plan_id="plan_42",
task="build feature",
phases=[
_make_phase(name="phase1", phase_id="p1"),
_make_phase(name="phase2", phase_id="p2"),
],
)
await cp.save_plan(plan)
loaded = await cp.load_plan("plan_42")
assert loaded is not None
assert loaded["id"] == "plan_42"
assert loaded["task"] == "build feature"
assert len(loaded["phases"]) == 2
assert loaded["phases"][0]["name"] == "phase1"
@pytest.mark.asyncio
async def test_load_plan_nonexistent_returns_none(self):
"""load_plan 不存在的 plan_id → None"""
cp = PipelineCheckpoint()
loaded = await cp.load_plan("nonexistent")
assert loaded is None
# ── clear 测试 ───────────────────────────────────────────
class TestCheckpointClear:
"""clear 测试"""
@pytest.mark.asyncio
async def test_clear_removes_all_data(self):
"""clear 清除所有 checkpoint 和 plan 数据"""
cp = PipelineCheckpoint()
plan = _make_plan()
phase = _make_phase()
await cp.save_plan(plan)
await cp.save("plan_1", phase, "executing")
await cp.clear("plan_1")
assert await cp.load("plan_1") is None
assert await cp.list_checkpoints("plan_1") == []
assert await cp.load_plan("plan_1") is None
@pytest.mark.asyncio
async def test_clear_nonexistent_does_not_raise(self):
"""clear 不存在的 plan_id 不抛异常"""
cp = PipelineCheckpoint()
await cp.clear("nonexistent") # should not raise
# ── Redis 模式测试 ───────────────────────────────────────
class TestCheckpointRedis:
"""Redis 模式测试(使用 mock Redis"""
@pytest.mark.asyncio
async def test_redis_save_then_load(self):
"""Redis 模式下 save/load 正常工作"""
redis = _make_mock_redis()
cp = PipelineCheckpoint(redis_client=redis)
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
await cp.save("plan_1", phase, "executing")
loaded = await cp.load("plan_1")
assert loaded is not None
assert loaded.phase_id == "p1"
assert loaded.phase_status == "completed"
@pytest.mark.asyncio
async def test_redis_save_plan_then_load_plan(self):
"""Redis 模式下 save_plan/load_plan 正常工作"""
redis = _make_mock_redis()
cp = PipelineCheckpoint(redis_client=redis)
plan = _make_plan(plan_id="plan_99")
await cp.save_plan(plan)
loaded = await cp.load_plan("plan_99")
assert loaded is not None
assert loaded["id"] == "plan_99"
@pytest.mark.asyncio
async def test_redis_clear_removes_all(self):
"""Redis 模式下 clear 清除所有数据"""
redis = _make_mock_redis()
cp = PipelineCheckpoint(redis_client=redis)
phase = _make_phase()
plan = _make_plan()
await cp.save("plan_1", phase, "executing")
await cp.save_plan(plan)
await cp.clear("plan_1")
assert await cp.load("plan_1") is None
assert await cp.load_plan("plan_1") is None
@pytest.mark.asyncio
async def test_redis_failure_falls_back_to_memory(self):
"""Redis 异常时降级到内存save/load 仍工作"""
redis = _make_mock_redis()
# 让 pipeline 抛异常save 写 Redis 失败)
redis.pipeline = MagicMock(side_effect=Exception("Redis connection lost"))
cp = PipelineCheckpoint(redis_client=redis)
phase = _make_phase(status=PhaseStatus.COMPLETED)
# save 不应抛异常,降级到内存
await cp.save("plan_1", phase, "executing")
# Redis get/zrange 也失败时load 从内存降级读取
redis.get = AsyncMock(side_effect=Exception("Redis down"))
redis.zrange = AsyncMock(side_effect=Exception("Redis down"))
loaded = await cp.load("plan_1")
assert loaded is not None
assert loaded.phase_id == phase.id
@pytest.mark.asyncio
async def test_redis_save_exception_does_not_block(self):
"""save 时 Redis 异常不阻断执行"""
redis = _make_mock_redis()
# 让 pipeline 抛异常,但 zrange/get 仍正常工作
redis.pipeline = MagicMock(side_effect=Exception("Redis write error"))
cp = PipelineCheckpoint(redis_client=redis)
phase = _make_phase()
# 不应抛异常
await cp.save("plan_1", phase, "executing")
# 内存降级中应有数据Redis zrange 返回空 → 降级到内存)
checkpoints = await cp.list_checkpoints("plan_1")
assert len(checkpoints) == 1
# ── CheckpointData 数据类测试 ─────────────────────────────
class TestCheckpointData:
"""CheckpointData 序列化/反序列化测试"""
def test_to_dict_contains_all_fields(self):
"""to_dict 包含所有字段"""
data = CheckpointData(
plan_id="p1",
phase_id="ph1",
phase_name="Phase 1",
phase_status="completed",
phase_result={"content": "output"},
plan_status="executing",
)
d = data.to_dict()
assert d["plan_id"] == "p1"
assert d["phase_id"] == "ph1"
assert d["phase_name"] == "Phase 1"
assert d["phase_status"] == "completed"
assert d["phase_result"] == {"content": "output"}
assert d["plan_status"] == "executing"
assert "saved_at" in d
def test_from_dict_roundtrip(self):
"""from_dict 反序列化正确"""
original = CheckpointData(
plan_id="p1",
phase_id="ph1",
phase_name="Phase 1",
phase_status="completed",
phase_result={"content": "output"},
plan_status="executing",
)
d = original.to_dict()
restored = CheckpointData.from_dict(d)
assert restored.plan_id == original.plan_id
assert restored.phase_id == original.phase_id
assert restored.phase_name == original.phase_name
assert restored.phase_status == original.phase_status
assert restored.phase_result == original.phase_result
assert restored.plan_status == original.plan_status
def test_from_dict_with_missing_fields_uses_defaults(self):
"""from_dict 缺失字段时使用默认值"""
data = CheckpointData.from_dict({"plan_id": "p1", "phase_id": "ph1"})
assert data.plan_id == "p1"
assert data.phase_id == "ph1"
assert data.phase_name == ""
assert data.phase_status == ""
assert data.phase_result is None
assert data.plan_status == ""
# ── TeamOrchestrator.resume 集成测试 ─────────────────────
class TestOrchestratorResume:
"""TeamOrchestrator.resume 集成测试"""
@pytest.mark.asyncio
async def test_resume_without_checkpoint_returns_failed(self):
"""无 checkpoint manager 时 resume 返回 failed"""
team = _make_team_with_experts()
orchestrator = TeamOrchestrator(team=team) # no checkpoint
result = await orchestrator.resume("plan_1")
assert result["status"] == "failed"
assert "No checkpoint manager" in result["error"]
@pytest.mark.asyncio
async def test_resume_nonexistent_plan_returns_failed(self):
"""resume 不存在的 plan_id 返回 failed"""
team = _make_team_with_experts()
cp = PipelineCheckpoint()
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
result = await orchestrator.resume("nonexistent")
assert result["status"] == "failed"
assert "No checkpoint" in result["error"]
@pytest.mark.asyncio
async def test_resume_skips_completed_phases(self):
"""resume 跳过已完成阶段,只执行未完成阶段"""
# 创建一个有 2 个阶段的 plan阶段 1 已完成,阶段 2 依赖阶段 1
phase1 = PlanPhase(
id="p1",
name="phase1",
assigned_expert="lead",
task_description="task 1",
status=PhaseStatus.COMPLETED,
result={"content": "phase1 output"},
)
phase2 = PlanPhase(
id="p2",
name="phase2",
assigned_expert="member1",
task_description="task 2",
depends_on=["p1"],
status=PhaseStatus.PENDING,
)
plan = _make_plan(
plan_id="plan_resume",
task="test resume",
phases=[phase1, phase2],
)
# 保存 plan + checkpoint for phase1
cp = PipelineCheckpoint()
await cp.save_plan(plan)
await cp.save("plan_resume", phase1, "executing")
# 创建 team + orchestrator
team = _make_team_with_experts(expert_names=["lead", "member1"])
# 设置 mock LLM gateway 用于 synthesis
gateway = _make_mock_llm_gateway(synthesis_content="综合结果")
team._experts["lead"].agent._llm_gateway = gateway
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
result = await orchestrator.resume("plan_resume")
# 验证结果
assert result["status"] == "completed"
# phase1 的结果应从 checkpoint 恢复
assert "p1" in result["phase_results"]
# phase2 应被执行
assert "p2" in result["phase_results"]
@pytest.mark.asyncio
async def test_resume_all_phases_completed_skips_execution(self):
"""resume 时所有阶段都已完成 → 直接 synthesis"""
phase1 = PlanPhase(
id="p1",
name="phase1",
assigned_expert="lead",
task_description="task 1",
status=PhaseStatus.COMPLETED,
result={"content": "phase1 output"},
)
plan = _make_plan(
plan_id="plan_all_done",
task="test resume all done",
phases=[phase1],
)
cp = PipelineCheckpoint()
await cp.save_plan(plan)
await cp.save("plan_all_done", phase1, "executing")
team = _make_team_with_experts()
gateway = _make_mock_llm_gateway(synthesis_content="综合结果")
team._experts["lead"].agent._llm_gateway = gateway
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
result = await orchestrator.resume("plan_all_done")
assert result["status"] == "completed"
assert "p1" in result["phase_results"]
@pytest.mark.asyncio
async def test_resume_no_active_expert_returns_failed(self):
"""resume 时无活跃专家返回 failed"""
phase1 = PlanPhase(
id="p1",
name="phase1",
assigned_expert="lead",
task_description="task 1",
status=PhaseStatus.PENDING,
)
plan = _make_plan(
plan_id="plan_no_expert",
task="test",
phases=[phase1],
)
cp = PipelineCheckpoint()
await cp.save_plan(plan)
# 创建无活跃专家的 team
team = _make_team_with_experts()
for expert in team._experts.values():
expert.is_active = False
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
result = await orchestrator.resume("plan_no_expert")
assert result["status"] == "failed"
assert "No active expert" in result["error"]