feat(orchestrator): add pipeline checkpoint and crash recovery (U7)
Add PipelineCheckpoint for stage-level crash recovery with Redis-first
+ memory fallback. TeamOrchestrator saves checkpoints after each phase
finalizes and supports resume(plan_id) to continue from the last
completed phase. New POST /api/v1/tasks/{id}/resume endpoint recreates
the team from saved plan and calls resume.
This commit is contained in:
parent
3dfda904d7
commit
dfd188b1a4
|
|
@ -73,7 +73,12 @@ class TeamOrchestrator:
|
|||
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
|
||||
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
|
||||
|
||||
def __init__(self, team: ExpertTeam, max_concurrent_phases: int | None = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
team: ExpertTeam,
|
||||
max_concurrent_phases: int | None = None,
|
||||
checkpoint: Any = None,
|
||||
) -> None:
|
||||
self._team = team
|
||||
# Track temporary agent names created for context isolation (KTD3)
|
||||
# Maps phase_id -> temp_agent_name for cleanup
|
||||
|
|
@ -86,6 +91,8 @@ class TeamOrchestrator:
|
|||
# U2: 并发限制 — 同层并行阶段加 Semaphore,避免 LLM 限流洪峰
|
||||
limit = max_concurrent_phases or self.DEFAULT_MAX_CONCURRENT_PHASES
|
||||
self._phase_semaphore = asyncio.Semaphore(limit)
|
||||
# U7: Pipeline checkpoint for crash recovery
|
||||
self._checkpoint = checkpoint
|
||||
|
||||
async def execute(self, task: str) -> dict[str, Any]:
|
||||
"""Execute a task in pipeline mode.
|
||||
|
|
@ -173,10 +180,31 @@ class TeamOrchestrator:
|
|||
},
|
||||
)
|
||||
|
||||
# U7: Save plan for potential resume (before execution starts)
|
||||
if self._checkpoint is not None:
|
||||
try:
|
||||
await self._checkpoint.save_plan(plan)
|
||||
except Exception as e:
|
||||
logger.warning(f"Checkpoint save_plan failed: {e}")
|
||||
|
||||
# 4. Set EXECUTING status, execute phases
|
||||
self._team.set_status(TeamStatus.EXECUTING)
|
||||
phase_results: dict[str, dict[str, Any]] = {}
|
||||
|
||||
return await self._run_pipeline(lead, plan, phase_results, task)
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
lead: Expert,
|
||||
plan: TeamPlan,
|
||||
phase_results: dict[str, dict[str, Any]],
|
||||
task: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute the pipeline loop: run pending phases, synthesize, return result.
|
||||
|
||||
Shared by execute() and resume(). phase_results may be pre-populated
|
||||
by resume() with completed phase outputs.
|
||||
"""
|
||||
try:
|
||||
# Execute layers sequentially, phases within layer in parallel.
|
||||
# U3: while-loop re-computes topological_sort each iteration so
|
||||
|
|
@ -234,6 +262,13 @@ class TeamOrchestrator:
|
|||
else:
|
||||
phase_results[ph.id] = result
|
||||
|
||||
# U7: Save checkpoint after phase finalizes (success or failure)
|
||||
if self._checkpoint is not None:
|
||||
try:
|
||||
await self._checkpoint.save(plan.id, ph, plan.status.value)
|
||||
except Exception as e:
|
||||
logger.warning(f"Checkpoint save failed for phase {ph.id}: {e}")
|
||||
|
||||
# U3: Divergence detection — check completed phases for conflicts
|
||||
# and dynamically insert DEBATE phases if needed
|
||||
if self._debate_count < self.MAX_DEBATES:
|
||||
|
|
@ -290,6 +325,82 @@ class TeamOrchestrator:
|
|||
await self._broadcast_event("team_dissolved", {"team_id": self._team.team_id})
|
||||
return await self._fallback_to_single_agent(task, plan, phase_results)
|
||||
|
||||
async def resume(self, plan_id: str) -> dict[str, Any]:
|
||||
"""Resume a crashed pipeline from the last completed phase checkpoint.
|
||||
|
||||
Flow:
|
||||
1. Load plan + checkpoints from PipelineCheckpoint
|
||||
2. Reconstruct TeamPlan, mark completed phases as COMPLETED
|
||||
3. Pre-populate phase_results with checkpoint data
|
||||
4. Call _run_pipeline to continue from next pending phase
|
||||
|
||||
Returns same dict shape as execute(). If no checkpoint found, returns
|
||||
a failed result.
|
||||
"""
|
||||
if self._checkpoint is None:
|
||||
return {
|
||||
"status": "failed",
|
||||
"result": None,
|
||||
"phase_results": {},
|
||||
"error": "No checkpoint manager configured",
|
||||
}
|
||||
|
||||
# 1. Load plan
|
||||
plan_dict = await self._checkpoint.load_plan(plan_id)
|
||||
if plan_dict is None:
|
||||
return {
|
||||
"status": "failed",
|
||||
"result": None,
|
||||
"phase_results": {},
|
||||
"error": f"No checkpoint found for plan '{plan_id}'",
|
||||
}
|
||||
|
||||
# 2. Reconstruct TeamPlan
|
||||
plan = TeamPlan.from_dict(plan_dict)
|
||||
task = plan.task
|
||||
|
||||
# 3. Load checkpoints, mark completed phases
|
||||
checkpoints = await self._checkpoint.list_checkpoints(plan_id)
|
||||
phase_results: dict[str, dict[str, Any]] = {}
|
||||
completed_phase_ids: set[str] = set()
|
||||
|
||||
for cp in checkpoints:
|
||||
if cp.phase_status == "completed":
|
||||
completed_phase_ids.add(cp.phase_id)
|
||||
# Restore phase result from checkpoint
|
||||
if cp.phase_result:
|
||||
phase_results[cp.phase_id] = cp.phase_result
|
||||
|
||||
# Apply checkpoint state to plan phases
|
||||
for ph in plan.phases:
|
||||
if ph.id in completed_phase_ids:
|
||||
ph.status = PhaseStatus.COMPLETED
|
||||
if ph.id in phase_results and phase_results[ph.id]:
|
||||
ph.result = phase_results[ph.id]
|
||||
# PENDING phases remain PENDING — will be executed by _run_pipeline
|
||||
|
||||
logger.info(
|
||||
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
|
||||
f"{len(plan.phases) - len(completed_phase_ids)} pending"
|
||||
)
|
||||
|
||||
# 4. Get lead expert
|
||||
lead = self._team.lead_expert
|
||||
if not lead or not lead.is_active:
|
||||
active = self._team.active_experts
|
||||
if not active:
|
||||
return {
|
||||
"status": "failed",
|
||||
"result": None,
|
||||
"phase_results": phase_results,
|
||||
"error": "No active expert available",
|
||||
}
|
||||
lead = active[0]
|
||||
|
||||
# 5. Resume execution
|
||||
self._team.set_status(TeamStatus.EXECUTING)
|
||||
return await self._run_pipeline(lead, plan, phase_results, task)
|
||||
|
||||
async def _decompose_task(self, lead: Expert, task: str) -> list[PlanPhase]:
|
||||
"""Lead Expert decomposes task into phases using LLM.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,237 @@
|
|||
"""PipelineCheckpoint — 阶段级检查点与断点续跑 (U7)
|
||||
|
||||
在 TeamOrchestrator 阶段完成后保存 checkpoint 到 Redis(或内存降级),
|
||||
崩溃后可通过 resume(plan_id) 从最后完成阶段恢复。
|
||||
|
||||
复用 PipelineStateRedis 的 _safe_redis_call 模式:
|
||||
Redis 失败时降级到内存 dict,不阻断执行。
|
||||
|
||||
键命名:agentkit:pipeline:checkpoint:{plan_id}:{phase_id}
|
||||
TTL:7 天(与 PipelineStateRedis._TTL_SECONDS 一致)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TTL_SECONDS = 7 * 24 * 3600 # 7 days
|
||||
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CheckpointData:
|
||||
"""单个阶段的 checkpoint 数据。"""
|
||||
|
||||
plan_id: str
|
||||
phase_id: str
|
||||
phase_name: str
|
||||
phase_status: str
|
||||
phase_result: dict[str, Any] | None = None
|
||||
plan_status: str = ""
|
||||
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> CheckpointData:
|
||||
return cls(
|
||||
plan_id=data.get("plan_id", ""),
|
||||
phase_id=data.get("phase_id", ""),
|
||||
phase_name=data.get("phase_name", ""),
|
||||
phase_status=data.get("phase_status", ""),
|
||||
phase_result=data.get("phase_result"),
|
||||
plan_status=data.get("plan_status", ""),
|
||||
saved_at=data.get("saved_at", ""),
|
||||
)
|
||||
|
||||
|
||||
class PipelineCheckpoint:
|
||||
"""阶段级检查点存储 — Redis 优先,内存降级。
|
||||
|
||||
Usage::
|
||||
|
||||
checkpoint = PipelineCheckpoint(redis_client=redis)
|
||||
await checkpoint.save(plan.id, phase, plan.status.value)
|
||||
last = await checkpoint.load(plan.id)
|
||||
if last:
|
||||
# resume from last completed phase
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Any = None,
|
||||
prefix: str = _KEY_PREFIX,
|
||||
ttl_seconds: int = _TTL_SECONDS,
|
||||
) -> None:
|
||||
self._redis = redis_client
|
||||
self._prefix = prefix
|
||||
self._ttl = ttl_seconds
|
||||
# 内存降级存储:plan_id → list of CheckpointData
|
||||
self._memory: dict[str, list[CheckpointData]] = {}
|
||||
# 内存降级存储:plan_id → plan dict (for resume)
|
||||
self._memory_plans: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def _key(self, plan_id: str, phase_id: str) -> str:
|
||||
return f"{self._prefix}:{plan_id}:{phase_id}"
|
||||
|
||||
def _index_key(self, plan_id: str) -> str:
|
||||
"""Redis Sorted Set 索引键,用于列出某 plan 的所有 checkpoint。"""
|
||||
return f"{self._prefix}:index:{plan_id}"
|
||||
|
||||
def _plan_key(self, plan_id: str) -> str:
|
||||
"""完整 plan JSON 的存储键。"""
|
||||
return f"{self._prefix}:plan:{plan_id}"
|
||||
|
||||
async def save_plan(self, plan: Any) -> None:
|
||||
"""保存完整 TeamPlan(用于 resume 重建)。
|
||||
|
||||
Args:
|
||||
plan: TeamPlan 对象(需要有 to_dict() 方法)
|
||||
"""
|
||||
plan_id = plan.id
|
||||
plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id}
|
||||
|
||||
# 内存降级
|
||||
self._memory_plans[plan_id] = plan_dict
|
||||
|
||||
# 尝试写入 Redis
|
||||
if self._redis is not None:
|
||||
try:
|
||||
await self._redis.set(
|
||||
self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}"
|
||||
)
|
||||
|
||||
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
|
||||
"""加载完整 plan JSON。"""
|
||||
# 优先 Redis
|
||||
if self._redis is not None:
|
||||
try:
|
||||
raw = await self._redis.get(self._plan_key(plan_id))
|
||||
if raw:
|
||||
return json.loads(raw)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
|
||||
)
|
||||
# 内存降级
|
||||
return self._memory_plans.get(plan_id)
|
||||
|
||||
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
|
||||
"""保存阶段 checkpoint。
|
||||
|
||||
Args:
|
||||
plan_id: 计划 ID
|
||||
phase: PlanPhase 对象(需要有 id, name, status, result 属性)
|
||||
plan_status: 计划当前状态
|
||||
"""
|
||||
phase_id = getattr(phase, "id", str(phase))
|
||||
phase_name = getattr(phase, "name", "")
|
||||
phase_status = getattr(phase, "status", "")
|
||||
if hasattr(phase_status, "value"):
|
||||
phase_status = phase_status.value
|
||||
|
||||
# 序列化 phase.result(可能是 offloaded dict with _ref_key)
|
||||
phase_result = getattr(phase, "result", None)
|
||||
if phase_result is not None and not isinstance(phase_result, dict):
|
||||
phase_result = {"content": str(phase_result)}
|
||||
|
||||
data = CheckpointData(
|
||||
plan_id=plan_id,
|
||||
phase_id=phase_id,
|
||||
phase_name=phase_name,
|
||||
phase_status=str(phase_status),
|
||||
phase_result=phase_result,
|
||||
plan_status=plan_status,
|
||||
)
|
||||
|
||||
# 总是写入内存降级(保证一致性)
|
||||
self._memory.setdefault(plan_id, []).append(data)
|
||||
|
||||
# 尝试写入 Redis
|
||||
if self._redis is not None:
|
||||
try:
|
||||
score = time.time()
|
||||
pipe = self._redis.pipeline()
|
||||
pipe.set(self._key(plan_id, phase_id), json.dumps(data.to_dict()), ex=self._ttl)
|
||||
pipe.zadd(self._index_key(plan_id), {phase_id: score})
|
||||
await pipe.execute()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PipelineCheckpoint.save Redis failed for plan {plan_id}, "
|
||||
f"phase {phase_id}: {e} — using memory fallback"
|
||||
)
|
||||
|
||||
async def load(self, plan_id: str) -> CheckpointData | None:
|
||||
"""加载最后完成的阶段 checkpoint。
|
||||
|
||||
Returns:
|
||||
最后一个 COMPLETED 阶段的 CheckpointData,或 None。
|
||||
"""
|
||||
checkpoints = await self.list_checkpoints(plan_id)
|
||||
if not checkpoints:
|
||||
return None
|
||||
|
||||
# 返回最后一个 COMPLETED 阶段
|
||||
completed = [c for c in checkpoints if c.phase_status == "completed"]
|
||||
if not completed:
|
||||
return None
|
||||
return completed[-1]
|
||||
|
||||
async def list_checkpoints(self, plan_id: str) -> list[CheckpointData]:
|
||||
"""列出某 plan 的所有 checkpoint(按保存时间排序)。"""
|
||||
# 优先从 Redis 读取
|
||||
if self._redis is not None:
|
||||
try:
|
||||
phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
|
||||
if not phase_ids:
|
||||
# Redis 无数据,检查内存
|
||||
return list(self._memory.get(plan_id, []))
|
||||
|
||||
results: list[CheckpointData] = []
|
||||
for phase_id in phase_ids:
|
||||
raw = await self._redis.get(self._key(plan_id, phase_id))
|
||||
if raw:
|
||||
data = json.loads(raw)
|
||||
results.append(CheckpointData.from_dict(data))
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PipelineCheckpoint.list_checkpoints Redis failed for "
|
||||
f"plan {plan_id}: {e} — using memory fallback"
|
||||
)
|
||||
|
||||
# 内存降级
|
||||
return list(self._memory.get(plan_id, []))
|
||||
|
||||
async def clear(self, plan_id: str) -> None:
|
||||
"""清除某 plan 的所有 checkpoint。"""
|
||||
# 清除内存
|
||||
self._memory.pop(plan_id, None)
|
||||
self._memory_plans.pop(plan_id, None)
|
||||
|
||||
# 清除 Redis
|
||||
if self._redis is not None:
|
||||
try:
|
||||
phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
|
||||
pipe = self._redis.pipeline()
|
||||
for phase_id in phase_ids:
|
||||
pipe.delete(self._key(plan_id, phase_id))
|
||||
pipe.delete(self._index_key(plan_id))
|
||||
pipe.delete(self._plan_key(plan_id))
|
||||
await pipe.execute()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}"
|
||||
)
|
||||
|
|
@ -430,7 +430,13 @@ async def _execute_team_collab(
|
|||
)
|
||||
|
||||
await team.create_team(lead_config=lead_config, member_configs=member_configs)
|
||||
orchestrator = TeamOrchestrator(team=team)
|
||||
# U7: Create checkpoint manager for crash recovery
|
||||
from agentkit.orchestrator.checkpoint import PipelineCheckpoint
|
||||
|
||||
checkpoint = PipelineCheckpoint(
|
||||
redis_client=getattr(app_state, "working_redis_client", None)
|
||||
)
|
||||
orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)
|
||||
# U4: Register active team so WS messages during execution route as interventions
|
||||
_register_active_team(session_id, team)
|
||||
result = await orchestrator.execute(routing_result.task_content)
|
||||
|
|
|
|||
|
|
@ -209,6 +209,95 @@ async def cancel_task(task_id: str, req: Request):
|
|||
return {"task_id": task_id, "status": "cancelled"}
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/resume")
|
||||
async def resume_task(task_id: str, req: Request):
|
||||
"""Resume a crashed pipeline from the last completed phase checkpoint.
|
||||
|
||||
Reconstructs the team from the saved plan's expert names, creates a new
|
||||
TeamOrchestrator with the checkpoint manager, and calls resume().
|
||||
"""
|
||||
from agentkit.experts.orchestrator import TeamOrchestrator
|
||||
from agentkit.experts.router import ExpertTeamRouter
|
||||
from agentkit.experts.team import ExpertTeam
|
||||
from agentkit.orchestrator.checkpoint import PipelineCheckpoint
|
||||
|
||||
app_state = req.app.state
|
||||
|
||||
# 1. Create checkpoint manager
|
||||
checkpoint = PipelineCheckpoint(
|
||||
redis_client=getattr(app_state, "working_redis_client", None)
|
||||
)
|
||||
|
||||
# 2. Load plan to get expert names
|
||||
plan_dict = await checkpoint.load_plan(task_id)
|
||||
if plan_dict is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No checkpoint found for task '{task_id}'",
|
||||
)
|
||||
|
||||
# 3. Extract unique expert names from plan
|
||||
expert_names: list[str] = []
|
||||
lead_name = plan_dict.get("lead_expert", "")
|
||||
if lead_name:
|
||||
expert_names.append(lead_name)
|
||||
for ph in plan_dict.get("phases", []):
|
||||
name = ph.get("assigned_expert", "")
|
||||
if name and name not in expert_names:
|
||||
expert_names.append(name)
|
||||
|
||||
if not expert_names:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot resume: no experts found in saved plan",
|
||||
)
|
||||
|
||||
# 4. Resolve expert configs via ExpertTeamRouter
|
||||
template_registry = getattr(app_state, "expert_template_registry", None)
|
||||
if template_registry is None:
|
||||
from agentkit.experts.registry import ExpertTemplateRegistry
|
||||
|
||||
template_registry = ExpertTemplateRegistry()
|
||||
|
||||
team_router = ExpertTeamRouter(template_registry=template_registry)
|
||||
expert_configs = team_router.resolve_expert_configs(expert_names)
|
||||
if not expert_configs:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot resume: failed to resolve expert configs",
|
||||
)
|
||||
|
||||
lead_config = expert_configs[0]
|
||||
member_configs = expert_configs[1:] if len(expert_configs) > 1 else []
|
||||
|
||||
# 5. Create team + orchestrator
|
||||
team = ExpertTeam(
|
||||
pool=app_state.agent_pool,
|
||||
template_registry=template_registry,
|
||||
redis_client=getattr(app_state, "working_redis_client", None),
|
||||
)
|
||||
await team.create_team(lead_config=lead_config, member_configs=member_configs)
|
||||
|
||||
try:
|
||||
orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)
|
||||
result = await orchestrator.resume(task_id)
|
||||
finally:
|
||||
try:
|
||||
await team.dissolve()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": result.get("status", "unknown"),
|
||||
"result": result.get("result"),
|
||||
"phase_results": {
|
||||
pid: pr if isinstance(pr, dict) else {"content": str(pr)}
|
||||
for pid, pr in (result.get("phase_results") or {}).items()
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/tasks/stream")
|
||||
async def stream_task(request: SubmitTaskRequest, req: Request):
|
||||
"""Submit a task and stream ReAct events via SSE"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,561 @@
|
|||
"""PipelineCheckpoint 单元测试 (U7)
|
||||
|
||||
测试覆盖:
|
||||
- save/load 阶段 checkpoint(内存模式 + Redis mock)
|
||||
- save_plan/load_plan 完整 plan 序列化
|
||||
- list_checkpoints 按保存时间排序
|
||||
- clear 清除所有数据
|
||||
- Redis 异常降级到内存
|
||||
- resume 从 checkpoint 恢复执行
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.experts.orchestrator import TeamOrchestrator
|
||||
from agentkit.experts.plan import PhaseStatus, PlanPhase, PlanStatus, TeamPlan
|
||||
from agentkit.orchestrator.checkpoint import CheckpointData, PipelineCheckpoint
|
||||
from tests.unit.experts.test_team_orchestrator import (
|
||||
_make_mock_llm_gateway,
|
||||
_make_team_with_experts,
|
||||
)
|
||||
|
||||
|
||||
# ── 辅助函数 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_phase(
|
||||
name: str = "test_phase",
|
||||
phase_id: str = "phase_1",
|
||||
status: PhaseStatus = PhaseStatus.COMPLETED,
|
||||
result: dict | None = None,
|
||||
) -> PlanPhase:
|
||||
"""创建测试用 PlanPhase"""
|
||||
return PlanPhase(
|
||||
id=phase_id,
|
||||
name=name,
|
||||
assigned_expert="test_expert",
|
||||
task_description="test task",
|
||||
status=status,
|
||||
result=result or {"content": "test output"},
|
||||
)
|
||||
|
||||
|
||||
def _make_plan(
|
||||
plan_id: str = "plan_1",
|
||||
task: str = "test task",
|
||||
phases: list[PlanPhase] | None = None,
|
||||
) -> TeamPlan:
|
||||
"""创建测试用 TeamPlan"""
|
||||
plan = TeamPlan(id=plan_id, task=task, lead_expert="lead")
|
||||
plan.status = PlanStatus.EXECUTING
|
||||
if phases:
|
||||
plan.phases = phases
|
||||
return plan
|
||||
|
||||
|
||||
def _make_mock_redis() -> AsyncMock:
|
||||
"""创建 mock Redis client,模拟 aioredis 行为。"""
|
||||
redis = AsyncMock()
|
||||
# 内部存储
|
||||
store: dict[str, str] = {}
|
||||
zsets: dict[str, dict[str, float]] = {}
|
||||
|
||||
async def _set(key, value, ex=None): # noqa: ARG001
|
||||
store[key] = value
|
||||
|
||||
async def _get(key):
|
||||
return store.get(key)
|
||||
|
||||
async def _zadd(key, mapping):
|
||||
zsets.setdefault(key, {}).update(mapping)
|
||||
|
||||
async def _zrange(key, start, stop):
|
||||
members = zsets.get(key, {})
|
||||
sorted_members = sorted(members.keys(), key=lambda m: members[m])
|
||||
if start == 0 and stop == -1:
|
||||
return sorted_members
|
||||
return sorted_members[start : stop + 1] if stop >= 0 else sorted_members[start:]
|
||||
|
||||
async def _delete(*keys):
|
||||
count = 0
|
||||
for k in keys:
|
||||
if k in store:
|
||||
del store[k]
|
||||
count += 1
|
||||
if k in zsets:
|
||||
del zsets[k]
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _pipeline():
|
||||
commands: list[tuple] = []
|
||||
|
||||
class _Pipe:
|
||||
def set(self, key, value, ex=None):
|
||||
commands.append(("set", key, value, ex))
|
||||
|
||||
def zadd(self, key, mapping):
|
||||
commands.append(("zadd", key, mapping))
|
||||
|
||||
def delete(self, *keys):
|
||||
commands.append(("delete", keys))
|
||||
|
||||
async def execute(self):
|
||||
for cmd in commands:
|
||||
if cmd[0] == "set":
|
||||
await _set(cmd[1], cmd[2], cmd[3])
|
||||
elif cmd[0] == "zadd":
|
||||
await _zadd(cmd[1], cmd[2])
|
||||
elif cmd[0] == "delete":
|
||||
await _delete(*cmd[1])
|
||||
|
||||
return _Pipe()
|
||||
|
||||
redis.set = AsyncMock(side_effect=_set)
|
||||
redis.get = AsyncMock(side_effect=_get)
|
||||
redis.zadd = AsyncMock(side_effect=_zadd)
|
||||
redis.zrange = AsyncMock(side_effect=_zrange)
|
||||
redis.delete = AsyncMock(side_effect=_delete)
|
||||
redis.pipeline = MagicMock(side_effect=_pipeline)
|
||||
return redis
|
||||
|
||||
|
||||
# ── PipelineCheckpoint 基础测试 ──────────────────────────
|
||||
|
||||
|
||||
class TestCheckpointSaveLoad:
|
||||
"""save / load / list_checkpoints 基础测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_then_load_returns_last_completed(self):
|
||||
"""save 后 load 返回最后一个 COMPLETED 阶段"""
|
||||
cp = PipelineCheckpoint() # 内存模式
|
||||
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
|
||||
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
loaded = await cp.load("plan_1")
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.plan_id == "plan_1"
|
||||
assert loaded.phase_id == "p1"
|
||||
assert loaded.phase_status == "completed"
|
||||
assert loaded.plan_status == "executing"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_returns_last_completed_of_multiple(self):
|
||||
"""3 个阶段完成后,load 返回第 3 个(最后一个 COMPLETED)"""
|
||||
cp = PipelineCheckpoint()
|
||||
for i in range(3):
|
||||
phase = _make_phase(
|
||||
name=f"phase_{i}",
|
||||
phase_id=f"p{i}",
|
||||
status=PhaseStatus.COMPLETED,
|
||||
result={"content": f"output_{i}"},
|
||||
)
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
|
||||
loaded = await cp.load("plan_1")
|
||||
assert loaded is not None
|
||||
assert loaded.phase_id == "p2"
|
||||
assert loaded.phase_result == {"content": "output_2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_nonexistent_plan_returns_none(self):
|
||||
"""load 不存在的 plan_id → None"""
|
||||
cp = PipelineCheckpoint()
|
||||
loaded = await cp.load("nonexistent")
|
||||
assert loaded is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_with_no_completed_returns_none(self):
|
||||
"""只有 FAILED 阶段时 load 返回 None"""
|
||||
cp = PipelineCheckpoint()
|
||||
phase = _make_phase(phase_id="p1", status=PhaseStatus.FAILED)
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
|
||||
loaded = await cp.load("plan_1")
|
||||
assert loaded is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_checkpoints_returns_all(self):
|
||||
"""list_checkpoints 返回所有 checkpoint"""
|
||||
cp = PipelineCheckpoint()
|
||||
for i in range(3):
|
||||
phase = _make_phase(phase_id=f"p{i}", status=PhaseStatus.COMPLETED)
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
|
||||
checkpoints = await cp.list_checkpoints("plan_1")
|
||||
assert len(checkpoints) == 3
|
||||
assert all(c.plan_id == "plan_1" for c in checkpoints)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_serializes_phase_status_enum(self):
|
||||
"""save 正确序列化 PhaseStatus enum 为字符串"""
|
||||
cp = PipelineCheckpoint()
|
||||
phase = _make_phase(status=PhaseStatus.COMPLETED)
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
|
||||
checkpoints = await cp.list_checkpoints("plan_1")
|
||||
assert checkpoints[0].phase_status == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_handles_non_dict_result(self):
|
||||
"""save 处理非 dict 的 phase.result"""
|
||||
cp = PipelineCheckpoint()
|
||||
phase = _make_phase()
|
||||
phase.result = "plain string result"
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
|
||||
checkpoints = await cp.list_checkpoints("plan_1")
|
||||
assert checkpoints[0].phase_result == {"content": "plain string result"}
|
||||
|
||||
|
||||
# ── save_plan / load_plan 测试 ───────────────────────────
|
||||
|
||||
|
||||
class TestCheckpointSavePlan:
|
||||
"""save_plan / load_plan 测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_plan_then_load_plan_roundtrip(self):
|
||||
"""save_plan 后 load_plan 返回 plan dict"""
|
||||
cp = PipelineCheckpoint()
|
||||
plan = _make_plan(
|
||||
plan_id="plan_42",
|
||||
task="build feature",
|
||||
phases=[
|
||||
_make_phase(name="phase1", phase_id="p1"),
|
||||
_make_phase(name="phase2", phase_id="p2"),
|
||||
],
|
||||
)
|
||||
|
||||
await cp.save_plan(plan)
|
||||
loaded = await cp.load_plan("plan_42")
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded["id"] == "plan_42"
|
||||
assert loaded["task"] == "build feature"
|
||||
assert len(loaded["phases"]) == 2
|
||||
assert loaded["phases"][0]["name"] == "phase1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_plan_nonexistent_returns_none(self):
|
||||
"""load_plan 不存在的 plan_id → None"""
|
||||
cp = PipelineCheckpoint()
|
||||
loaded = await cp.load_plan("nonexistent")
|
||||
assert loaded is None
|
||||
|
||||
|
||||
# ── clear 测试 ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCheckpointClear:
|
||||
"""clear 测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_removes_all_data(self):
|
||||
"""clear 清除所有 checkpoint 和 plan 数据"""
|
||||
cp = PipelineCheckpoint()
|
||||
plan = _make_plan()
|
||||
phase = _make_phase()
|
||||
await cp.save_plan(plan)
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
|
||||
await cp.clear("plan_1")
|
||||
|
||||
assert await cp.load("plan_1") is None
|
||||
assert await cp.list_checkpoints("plan_1") == []
|
||||
assert await cp.load_plan("plan_1") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_nonexistent_does_not_raise(self):
|
||||
"""clear 不存在的 plan_id 不抛异常"""
|
||||
cp = PipelineCheckpoint()
|
||||
await cp.clear("nonexistent") # should not raise
|
||||
|
||||
|
||||
# ── Redis 模式测试 ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestCheckpointRedis:
|
||||
"""Redis 模式测试(使用 mock Redis)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_save_then_load(self):
|
||||
"""Redis 模式下 save/load 正常工作"""
|
||||
redis = _make_mock_redis()
|
||||
cp = PipelineCheckpoint(redis_client=redis)
|
||||
|
||||
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
|
||||
loaded = await cp.load("plan_1")
|
||||
assert loaded is not None
|
||||
assert loaded.phase_id == "p1"
|
||||
assert loaded.phase_status == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_save_plan_then_load_plan(self):
|
||||
"""Redis 模式下 save_plan/load_plan 正常工作"""
|
||||
redis = _make_mock_redis()
|
||||
cp = PipelineCheckpoint(redis_client=redis)
|
||||
|
||||
plan = _make_plan(plan_id="plan_99")
|
||||
await cp.save_plan(plan)
|
||||
loaded = await cp.load_plan("plan_99")
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded["id"] == "plan_99"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_clear_removes_all(self):
|
||||
"""Redis 模式下 clear 清除所有数据"""
|
||||
redis = _make_mock_redis()
|
||||
cp = PipelineCheckpoint(redis_client=redis)
|
||||
|
||||
phase = _make_phase()
|
||||
plan = _make_plan()
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
await cp.save_plan(plan)
|
||||
|
||||
await cp.clear("plan_1")
|
||||
|
||||
assert await cp.load("plan_1") is None
|
||||
assert await cp.load_plan("plan_1") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_failure_falls_back_to_memory(self):
|
||||
"""Redis 异常时降级到内存,save/load 仍工作"""
|
||||
redis = _make_mock_redis()
|
||||
# 让 pipeline 抛异常(save 写 Redis 失败)
|
||||
redis.pipeline = MagicMock(side_effect=Exception("Redis connection lost"))
|
||||
|
||||
cp = PipelineCheckpoint(redis_client=redis)
|
||||
|
||||
phase = _make_phase(status=PhaseStatus.COMPLETED)
|
||||
# save 不应抛异常,降级到内存
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
|
||||
# Redis get/zrange 也失败时,load 从内存降级读取
|
||||
redis.get = AsyncMock(side_effect=Exception("Redis down"))
|
||||
redis.zrange = AsyncMock(side_effect=Exception("Redis down"))
|
||||
|
||||
loaded = await cp.load("plan_1")
|
||||
assert loaded is not None
|
||||
assert loaded.phase_id == phase.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_save_exception_does_not_block(self):
|
||||
"""save 时 Redis 异常不阻断执行"""
|
||||
redis = _make_mock_redis()
|
||||
# 让 pipeline 抛异常,但 zrange/get 仍正常工作
|
||||
redis.pipeline = MagicMock(side_effect=Exception("Redis write error"))
|
||||
|
||||
cp = PipelineCheckpoint(redis_client=redis)
|
||||
|
||||
phase = _make_phase()
|
||||
# 不应抛异常
|
||||
await cp.save("plan_1", phase, "executing")
|
||||
|
||||
# 内存降级中应有数据(Redis zrange 返回空 → 降级到内存)
|
||||
checkpoints = await cp.list_checkpoints("plan_1")
|
||||
assert len(checkpoints) == 1
|
||||
|
||||
|
||||
# ── CheckpointData 数据类测试 ─────────────────────────────
|
||||
|
||||
|
||||
class TestCheckpointData:
|
||||
"""CheckpointData 序列化/反序列化测试"""
|
||||
|
||||
def test_to_dict_contains_all_fields(self):
|
||||
"""to_dict 包含所有字段"""
|
||||
data = CheckpointData(
|
||||
plan_id="p1",
|
||||
phase_id="ph1",
|
||||
phase_name="Phase 1",
|
||||
phase_status="completed",
|
||||
phase_result={"content": "output"},
|
||||
plan_status="executing",
|
||||
)
|
||||
d = data.to_dict()
|
||||
assert d["plan_id"] == "p1"
|
||||
assert d["phase_id"] == "ph1"
|
||||
assert d["phase_name"] == "Phase 1"
|
||||
assert d["phase_status"] == "completed"
|
||||
assert d["phase_result"] == {"content": "output"}
|
||||
assert d["plan_status"] == "executing"
|
||||
assert "saved_at" in d
|
||||
|
||||
def test_from_dict_roundtrip(self):
|
||||
"""from_dict 反序列化正确"""
|
||||
original = CheckpointData(
|
||||
plan_id="p1",
|
||||
phase_id="ph1",
|
||||
phase_name="Phase 1",
|
||||
phase_status="completed",
|
||||
phase_result={"content": "output"},
|
||||
plan_status="executing",
|
||||
)
|
||||
d = original.to_dict()
|
||||
restored = CheckpointData.from_dict(d)
|
||||
|
||||
assert restored.plan_id == original.plan_id
|
||||
assert restored.phase_id == original.phase_id
|
||||
assert restored.phase_name == original.phase_name
|
||||
assert restored.phase_status == original.phase_status
|
||||
assert restored.phase_result == original.phase_result
|
||||
assert restored.plan_status == original.plan_status
|
||||
|
||||
def test_from_dict_with_missing_fields_uses_defaults(self):
|
||||
"""from_dict 缺失字段时使用默认值"""
|
||||
data = CheckpointData.from_dict({"plan_id": "p1", "phase_id": "ph1"})
|
||||
assert data.plan_id == "p1"
|
||||
assert data.phase_id == "ph1"
|
||||
assert data.phase_name == ""
|
||||
assert data.phase_status == ""
|
||||
assert data.phase_result is None
|
||||
assert data.plan_status == ""
|
||||
|
||||
|
||||
# ── TeamOrchestrator.resume 集成测试 ─────────────────────
|
||||
|
||||
|
||||
class TestOrchestratorResume:
|
||||
"""TeamOrchestrator.resume 集成测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_without_checkpoint_returns_failed(self):
|
||||
"""无 checkpoint manager 时 resume 返回 failed"""
|
||||
team = _make_team_with_experts()
|
||||
orchestrator = TeamOrchestrator(team=team) # no checkpoint
|
||||
|
||||
result = await orchestrator.resume("plan_1")
|
||||
|
||||
assert result["status"] == "failed"
|
||||
assert "No checkpoint manager" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_nonexistent_plan_returns_failed(self):
|
||||
"""resume 不存在的 plan_id 返回 failed"""
|
||||
team = _make_team_with_experts()
|
||||
cp = PipelineCheckpoint()
|
||||
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
|
||||
|
||||
result = await orchestrator.resume("nonexistent")
|
||||
|
||||
assert result["status"] == "failed"
|
||||
assert "No checkpoint" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_skips_completed_phases(self):
|
||||
"""resume 跳过已完成阶段,只执行未完成阶段"""
|
||||
# 创建一个有 2 个阶段的 plan,阶段 1 已完成,阶段 2 依赖阶段 1
|
||||
phase1 = PlanPhase(
|
||||
id="p1",
|
||||
name="phase1",
|
||||
assigned_expert="lead",
|
||||
task_description="task 1",
|
||||
status=PhaseStatus.COMPLETED,
|
||||
result={"content": "phase1 output"},
|
||||
)
|
||||
phase2 = PlanPhase(
|
||||
id="p2",
|
||||
name="phase2",
|
||||
assigned_expert="member1",
|
||||
task_description="task 2",
|
||||
depends_on=["p1"],
|
||||
status=PhaseStatus.PENDING,
|
||||
)
|
||||
plan = _make_plan(
|
||||
plan_id="plan_resume",
|
||||
task="test resume",
|
||||
phases=[phase1, phase2],
|
||||
)
|
||||
|
||||
# 保存 plan + checkpoint for phase1
|
||||
cp = PipelineCheckpoint()
|
||||
await cp.save_plan(plan)
|
||||
await cp.save("plan_resume", phase1, "executing")
|
||||
|
||||
# 创建 team + orchestrator
|
||||
team = _make_team_with_experts(expert_names=["lead", "member1"])
|
||||
# 设置 mock LLM gateway 用于 synthesis
|
||||
gateway = _make_mock_llm_gateway(synthesis_content="综合结果")
|
||||
team._experts["lead"].agent._llm_gateway = gateway
|
||||
|
||||
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
|
||||
|
||||
result = await orchestrator.resume("plan_resume")
|
||||
|
||||
# 验证结果
|
||||
assert result["status"] == "completed"
|
||||
# phase1 的结果应从 checkpoint 恢复
|
||||
assert "p1" in result["phase_results"]
|
||||
# phase2 应被执行
|
||||
assert "p2" in result["phase_results"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_all_phases_completed_skips_execution(self):
|
||||
"""resume 时所有阶段都已完成 → 直接 synthesis"""
|
||||
phase1 = PlanPhase(
|
||||
id="p1",
|
||||
name="phase1",
|
||||
assigned_expert="lead",
|
||||
task_description="task 1",
|
||||
status=PhaseStatus.COMPLETED,
|
||||
result={"content": "phase1 output"},
|
||||
)
|
||||
plan = _make_plan(
|
||||
plan_id="plan_all_done",
|
||||
task="test resume all done",
|
||||
phases=[phase1],
|
||||
)
|
||||
|
||||
cp = PipelineCheckpoint()
|
||||
await cp.save_plan(plan)
|
||||
await cp.save("plan_all_done", phase1, "executing")
|
||||
|
||||
team = _make_team_with_experts()
|
||||
gateway = _make_mock_llm_gateway(synthesis_content="综合结果")
|
||||
team._experts["lead"].agent._llm_gateway = gateway
|
||||
|
||||
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
|
||||
result = await orchestrator.resume("plan_all_done")
|
||||
|
||||
assert result["status"] == "completed"
|
||||
assert "p1" in result["phase_results"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_no_active_expert_returns_failed(self):
|
||||
"""resume 时无活跃专家返回 failed"""
|
||||
phase1 = PlanPhase(
|
||||
id="p1",
|
||||
name="phase1",
|
||||
assigned_expert="lead",
|
||||
task_description="task 1",
|
||||
status=PhaseStatus.PENDING,
|
||||
)
|
||||
plan = _make_plan(
|
||||
plan_id="plan_no_expert",
|
||||
task="test",
|
||||
phases=[phase1],
|
||||
)
|
||||
|
||||
cp = PipelineCheckpoint()
|
||||
await cp.save_plan(plan)
|
||||
|
||||
# 创建无活跃专家的 team
|
||||
team = _make_team_with_experts()
|
||||
for expert in team._experts.values():
|
||||
expert.is_active = False
|
||||
|
||||
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
|
||||
result = await orchestrator.resume("plan_no_expert")
|
||||
|
||||
assert result["status"] == "failed"
|
||||
assert "No active expert" in result["error"]
|
||||
Loading…
Reference in New Issue