feat(orchestrator): add pipeline checkpoint and crash recovery (U7)
Add PipelineCheckpoint for stage-level crash recovery with Redis-first
+ memory fallback. TeamOrchestrator saves checkpoints after each phase
finalizes and supports resume(plan_id) to continue from the last
completed phase. New POST /api/v1/tasks/{id}/resume endpoint recreates
the team from saved plan and calls resume.
This commit is contained in:
parent
3dfda904d7
commit
dfd188b1a4
|
|
@ -73,7 +73,12 @@ class TeamOrchestrator:
|
||||||
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
|
DEFAULT_MAX_CONCURRENT_PHASES = 3 # 同层最大并发阶段数,避免 LLM 限流洪峰
|
||||||
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
|
STOP_COMMANDS = frozenset({"/stop", "停止", "stop", "结束"})
|
||||||
|
|
||||||
def __init__(self, team: ExpertTeam, max_concurrent_phases: int | None = None) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
team: ExpertTeam,
|
||||||
|
max_concurrent_phases: int | None = None,
|
||||||
|
checkpoint: Any = None,
|
||||||
|
) -> None:
|
||||||
self._team = team
|
self._team = team
|
||||||
# Track temporary agent names created for context isolation (KTD3)
|
# Track temporary agent names created for context isolation (KTD3)
|
||||||
# Maps phase_id -> temp_agent_name for cleanup
|
# Maps phase_id -> temp_agent_name for cleanup
|
||||||
|
|
@ -86,6 +91,8 @@ class TeamOrchestrator:
|
||||||
# U2: 并发限制 — 同层并行阶段加 Semaphore,避免 LLM 限流洪峰
|
# U2: 并发限制 — 同层并行阶段加 Semaphore,避免 LLM 限流洪峰
|
||||||
limit = max_concurrent_phases or self.DEFAULT_MAX_CONCURRENT_PHASES
|
limit = max_concurrent_phases or self.DEFAULT_MAX_CONCURRENT_PHASES
|
||||||
self._phase_semaphore = asyncio.Semaphore(limit)
|
self._phase_semaphore = asyncio.Semaphore(limit)
|
||||||
|
# U7: Pipeline checkpoint for crash recovery
|
||||||
|
self._checkpoint = checkpoint
|
||||||
|
|
||||||
async def execute(self, task: str) -> dict[str, Any]:
|
async def execute(self, task: str) -> dict[str, Any]:
|
||||||
"""Execute a task in pipeline mode.
|
"""Execute a task in pipeline mode.
|
||||||
|
|
@ -173,10 +180,31 @@ class TeamOrchestrator:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# U7: Save plan for potential resume (before execution starts)
|
||||||
|
if self._checkpoint is not None:
|
||||||
|
try:
|
||||||
|
await self._checkpoint.save_plan(plan)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Checkpoint save_plan failed: {e}")
|
||||||
|
|
||||||
# 4. Set EXECUTING status, execute phases
|
# 4. Set EXECUTING status, execute phases
|
||||||
self._team.set_status(TeamStatus.EXECUTING)
|
self._team.set_status(TeamStatus.EXECUTING)
|
||||||
phase_results: dict[str, dict[str, Any]] = {}
|
phase_results: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
return await self._run_pipeline(lead, plan, phase_results, task)
|
||||||
|
|
||||||
|
async def _run_pipeline(
|
||||||
|
self,
|
||||||
|
lead: Expert,
|
||||||
|
plan: TeamPlan,
|
||||||
|
phase_results: dict[str, dict[str, Any]],
|
||||||
|
task: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Execute the pipeline loop: run pending phases, synthesize, return result.
|
||||||
|
|
||||||
|
Shared by execute() and resume(). phase_results may be pre-populated
|
||||||
|
by resume() with completed phase outputs.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Execute layers sequentially, phases within layer in parallel.
|
# Execute layers sequentially, phases within layer in parallel.
|
||||||
# U3: while-loop re-computes topological_sort each iteration so
|
# U3: while-loop re-computes topological_sort each iteration so
|
||||||
|
|
@ -234,6 +262,13 @@ class TeamOrchestrator:
|
||||||
else:
|
else:
|
||||||
phase_results[ph.id] = result
|
phase_results[ph.id] = result
|
||||||
|
|
||||||
|
# U7: Save checkpoint after phase finalizes (success or failure)
|
||||||
|
if self._checkpoint is not None:
|
||||||
|
try:
|
||||||
|
await self._checkpoint.save(plan.id, ph, plan.status.value)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Checkpoint save failed for phase {ph.id}: {e}")
|
||||||
|
|
||||||
# U3: Divergence detection — check completed phases for conflicts
|
# U3: Divergence detection — check completed phases for conflicts
|
||||||
# and dynamically insert DEBATE phases if needed
|
# and dynamically insert DEBATE phases if needed
|
||||||
if self._debate_count < self.MAX_DEBATES:
|
if self._debate_count < self.MAX_DEBATES:
|
||||||
|
|
@ -290,6 +325,82 @@ class TeamOrchestrator:
|
||||||
await self._broadcast_event("team_dissolved", {"team_id": self._team.team_id})
|
await self._broadcast_event("team_dissolved", {"team_id": self._team.team_id})
|
||||||
return await self._fallback_to_single_agent(task, plan, phase_results)
|
return await self._fallback_to_single_agent(task, plan, phase_results)
|
||||||
|
|
||||||
|
async def resume(self, plan_id: str) -> dict[str, Any]:
|
||||||
|
"""Resume a crashed pipeline from the last completed phase checkpoint.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Load plan + checkpoints from PipelineCheckpoint
|
||||||
|
2. Reconstruct TeamPlan, mark completed phases as COMPLETED
|
||||||
|
3. Pre-populate phase_results with checkpoint data
|
||||||
|
4. Call _run_pipeline to continue from next pending phase
|
||||||
|
|
||||||
|
Returns same dict shape as execute(). If no checkpoint found, returns
|
||||||
|
a failed result.
|
||||||
|
"""
|
||||||
|
if self._checkpoint is None:
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"result": None,
|
||||||
|
"phase_results": {},
|
||||||
|
"error": "No checkpoint manager configured",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 1. Load plan
|
||||||
|
plan_dict = await self._checkpoint.load_plan(plan_id)
|
||||||
|
if plan_dict is None:
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"result": None,
|
||||||
|
"phase_results": {},
|
||||||
|
"error": f"No checkpoint found for plan '{plan_id}'",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 2. Reconstruct TeamPlan
|
||||||
|
plan = TeamPlan.from_dict(plan_dict)
|
||||||
|
task = plan.task
|
||||||
|
|
||||||
|
# 3. Load checkpoints, mark completed phases
|
||||||
|
checkpoints = await self._checkpoint.list_checkpoints(plan_id)
|
||||||
|
phase_results: dict[str, dict[str, Any]] = {}
|
||||||
|
completed_phase_ids: set[str] = set()
|
||||||
|
|
||||||
|
for cp in checkpoints:
|
||||||
|
if cp.phase_status == "completed":
|
||||||
|
completed_phase_ids.add(cp.phase_id)
|
||||||
|
# Restore phase result from checkpoint
|
||||||
|
if cp.phase_result:
|
||||||
|
phase_results[cp.phase_id] = cp.phase_result
|
||||||
|
|
||||||
|
# Apply checkpoint state to plan phases
|
||||||
|
for ph in plan.phases:
|
||||||
|
if ph.id in completed_phase_ids:
|
||||||
|
ph.status = PhaseStatus.COMPLETED
|
||||||
|
if ph.id in phase_results and phase_results[ph.id]:
|
||||||
|
ph.result = phase_results[ph.id]
|
||||||
|
# PENDING phases remain PENDING — will be executed by _run_pipeline
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Resuming plan {plan_id}: {len(completed_phase_ids)} completed, "
|
||||||
|
f"{len(plan.phases) - len(completed_phase_ids)} pending"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Get lead expert
|
||||||
|
lead = self._team.lead_expert
|
||||||
|
if not lead or not lead.is_active:
|
||||||
|
active = self._team.active_experts
|
||||||
|
if not active:
|
||||||
|
return {
|
||||||
|
"status": "failed",
|
||||||
|
"result": None,
|
||||||
|
"phase_results": phase_results,
|
||||||
|
"error": "No active expert available",
|
||||||
|
}
|
||||||
|
lead = active[0]
|
||||||
|
|
||||||
|
# 5. Resume execution
|
||||||
|
self._team.set_status(TeamStatus.EXECUTING)
|
||||||
|
return await self._run_pipeline(lead, plan, phase_results, task)
|
||||||
|
|
||||||
async def _decompose_task(self, lead: Expert, task: str) -> list[PlanPhase]:
|
async def _decompose_task(self, lead: Expert, task: str) -> list[PlanPhase]:
|
||||||
"""Lead Expert decomposes task into phases using LLM.
|
"""Lead Expert decomposes task into phases using LLM.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,237 @@
|
||||||
|
"""PipelineCheckpoint — 阶段级检查点与断点续跑 (U7)
|
||||||
|
|
||||||
|
在 TeamOrchestrator 阶段完成后保存 checkpoint 到 Redis(或内存降级),
|
||||||
|
崩溃后可通过 resume(plan_id) 从最后完成阶段恢复。
|
||||||
|
|
||||||
|
复用 PipelineStateRedis 的 _safe_redis_call 模式:
|
||||||
|
Redis 失败时降级到内存 dict,不阻断执行。
|
||||||
|
|
||||||
|
键命名:agentkit:pipeline:checkpoint:{plan_id}:{phase_id}
|
||||||
|
TTL:7 天(与 PipelineStateRedis._TTL_SECONDS 一致)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_TTL_SECONDS = 7 * 24 * 3600 # 7 days
|
||||||
|
_KEY_PREFIX = "agentkit:pipeline:checkpoint"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CheckpointData:
|
||||||
|
"""单个阶段的 checkpoint 数据。"""
|
||||||
|
|
||||||
|
plan_id: str
|
||||||
|
phase_id: str
|
||||||
|
phase_name: str
|
||||||
|
phase_status: str
|
||||||
|
phase_result: dict[str, Any] | None = None
|
||||||
|
plan_status: str = ""
|
||||||
|
saved_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return asdict(self)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> CheckpointData:
|
||||||
|
return cls(
|
||||||
|
plan_id=data.get("plan_id", ""),
|
||||||
|
phase_id=data.get("phase_id", ""),
|
||||||
|
phase_name=data.get("phase_name", ""),
|
||||||
|
phase_status=data.get("phase_status", ""),
|
||||||
|
phase_result=data.get("phase_result"),
|
||||||
|
plan_status=data.get("plan_status", ""),
|
||||||
|
saved_at=data.get("saved_at", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineCheckpoint:
|
||||||
|
"""阶段级检查点存储 — Redis 优先,内存降级。
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
checkpoint = PipelineCheckpoint(redis_client=redis)
|
||||||
|
await checkpoint.save(plan.id, phase, plan.status.value)
|
||||||
|
last = await checkpoint.load(plan.id)
|
||||||
|
if last:
|
||||||
|
# resume from last completed phase
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
redis_client: Any = None,
|
||||||
|
prefix: str = _KEY_PREFIX,
|
||||||
|
ttl_seconds: int = _TTL_SECONDS,
|
||||||
|
) -> None:
|
||||||
|
self._redis = redis_client
|
||||||
|
self._prefix = prefix
|
||||||
|
self._ttl = ttl_seconds
|
||||||
|
# 内存降级存储:plan_id → list of CheckpointData
|
||||||
|
self._memory: dict[str, list[CheckpointData]] = {}
|
||||||
|
# 内存降级存储:plan_id → plan dict (for resume)
|
||||||
|
self._memory_plans: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
def _key(self, plan_id: str, phase_id: str) -> str:
|
||||||
|
return f"{self._prefix}:{plan_id}:{phase_id}"
|
||||||
|
|
||||||
|
def _index_key(self, plan_id: str) -> str:
|
||||||
|
"""Redis Sorted Set 索引键,用于列出某 plan 的所有 checkpoint。"""
|
||||||
|
return f"{self._prefix}:index:{plan_id}"
|
||||||
|
|
||||||
|
def _plan_key(self, plan_id: str) -> str:
|
||||||
|
"""完整 plan JSON 的存储键。"""
|
||||||
|
return f"{self._prefix}:plan:{plan_id}"
|
||||||
|
|
||||||
|
async def save_plan(self, plan: Any) -> None:
|
||||||
|
"""保存完整 TeamPlan(用于 resume 重建)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plan: TeamPlan 对象(需要有 to_dict() 方法)
|
||||||
|
"""
|
||||||
|
plan_id = plan.id
|
||||||
|
plan_dict = plan.to_dict() if hasattr(plan, "to_dict") else {"id": plan_id}
|
||||||
|
|
||||||
|
# 内存降级
|
||||||
|
self._memory_plans[plan_id] = plan_dict
|
||||||
|
|
||||||
|
# 尝试写入 Redis
|
||||||
|
if self._redis is not None:
|
||||||
|
try:
|
||||||
|
await self._redis.set(
|
||||||
|
self._plan_key(plan_id), json.dumps(plan_dict), ex=self._ttl
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"PipelineCheckpoint.save_plan Redis failed for plan {plan_id}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_plan(self, plan_id: str) -> dict[str, Any] | None:
|
||||||
|
"""加载完整 plan JSON。"""
|
||||||
|
# 优先 Redis
|
||||||
|
if self._redis is not None:
|
||||||
|
try:
|
||||||
|
raw = await self._redis.get(self._plan_key(plan_id))
|
||||||
|
if raw:
|
||||||
|
return json.loads(raw)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"PipelineCheckpoint.load_plan Redis failed for plan {plan_id}: {e}"
|
||||||
|
)
|
||||||
|
# 内存降级
|
||||||
|
return self._memory_plans.get(plan_id)
|
||||||
|
|
||||||
|
async def save(self, plan_id: str, phase: Any, plan_status: str) -> None:
|
||||||
|
"""保存阶段 checkpoint。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plan_id: 计划 ID
|
||||||
|
phase: PlanPhase 对象(需要有 id, name, status, result 属性)
|
||||||
|
plan_status: 计划当前状态
|
||||||
|
"""
|
||||||
|
phase_id = getattr(phase, "id", str(phase))
|
||||||
|
phase_name = getattr(phase, "name", "")
|
||||||
|
phase_status = getattr(phase, "status", "")
|
||||||
|
if hasattr(phase_status, "value"):
|
||||||
|
phase_status = phase_status.value
|
||||||
|
|
||||||
|
# 序列化 phase.result(可能是 offloaded dict with _ref_key)
|
||||||
|
phase_result = getattr(phase, "result", None)
|
||||||
|
if phase_result is not None and not isinstance(phase_result, dict):
|
||||||
|
phase_result = {"content": str(phase_result)}
|
||||||
|
|
||||||
|
data = CheckpointData(
|
||||||
|
plan_id=plan_id,
|
||||||
|
phase_id=phase_id,
|
||||||
|
phase_name=phase_name,
|
||||||
|
phase_status=str(phase_status),
|
||||||
|
phase_result=phase_result,
|
||||||
|
plan_status=plan_status,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 总是写入内存降级(保证一致性)
|
||||||
|
self._memory.setdefault(plan_id, []).append(data)
|
||||||
|
|
||||||
|
# 尝试写入 Redis
|
||||||
|
if self._redis is not None:
|
||||||
|
try:
|
||||||
|
score = time.time()
|
||||||
|
pipe = self._redis.pipeline()
|
||||||
|
pipe.set(self._key(plan_id, phase_id), json.dumps(data.to_dict()), ex=self._ttl)
|
||||||
|
pipe.zadd(self._index_key(plan_id), {phase_id: score})
|
||||||
|
await pipe.execute()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"PipelineCheckpoint.save Redis failed for plan {plan_id}, "
|
||||||
|
f"phase {phase_id}: {e} — using memory fallback"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load(self, plan_id: str) -> CheckpointData | None:
|
||||||
|
"""加载最后完成的阶段 checkpoint。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最后一个 COMPLETED 阶段的 CheckpointData,或 None。
|
||||||
|
"""
|
||||||
|
checkpoints = await self.list_checkpoints(plan_id)
|
||||||
|
if not checkpoints:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 返回最后一个 COMPLETED 阶段
|
||||||
|
completed = [c for c in checkpoints if c.phase_status == "completed"]
|
||||||
|
if not completed:
|
||||||
|
return None
|
||||||
|
return completed[-1]
|
||||||
|
|
||||||
|
async def list_checkpoints(self, plan_id: str) -> list[CheckpointData]:
|
||||||
|
"""列出某 plan 的所有 checkpoint(按保存时间排序)。"""
|
||||||
|
# 优先从 Redis 读取
|
||||||
|
if self._redis is not None:
|
||||||
|
try:
|
||||||
|
phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
|
||||||
|
if not phase_ids:
|
||||||
|
# Redis 无数据,检查内存
|
||||||
|
return list(self._memory.get(plan_id, []))
|
||||||
|
|
||||||
|
results: list[CheckpointData] = []
|
||||||
|
for phase_id in phase_ids:
|
||||||
|
raw = await self._redis.get(self._key(plan_id, phase_id))
|
||||||
|
if raw:
|
||||||
|
data = json.loads(raw)
|
||||||
|
results.append(CheckpointData.from_dict(data))
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"PipelineCheckpoint.list_checkpoints Redis failed for "
|
||||||
|
f"plan {plan_id}: {e} — using memory fallback"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 内存降级
|
||||||
|
return list(self._memory.get(plan_id, []))
|
||||||
|
|
||||||
|
async def clear(self, plan_id: str) -> None:
|
||||||
|
"""清除某 plan 的所有 checkpoint。"""
|
||||||
|
# 清除内存
|
||||||
|
self._memory.pop(plan_id, None)
|
||||||
|
self._memory_plans.pop(plan_id, None)
|
||||||
|
|
||||||
|
# 清除 Redis
|
||||||
|
if self._redis is not None:
|
||||||
|
try:
|
||||||
|
phase_ids = await self._redis.zrange(self._index_key(plan_id), 0, -1)
|
||||||
|
pipe = self._redis.pipeline()
|
||||||
|
for phase_id in phase_ids:
|
||||||
|
pipe.delete(self._key(plan_id, phase_id))
|
||||||
|
pipe.delete(self._index_key(plan_id))
|
||||||
|
pipe.delete(self._plan_key(plan_id))
|
||||||
|
await pipe.execute()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"PipelineCheckpoint.clear Redis failed for plan {plan_id}: {e}"
|
||||||
|
)
|
||||||
|
|
@ -430,7 +430,13 @@ async def _execute_team_collab(
|
||||||
)
|
)
|
||||||
|
|
||||||
await team.create_team(lead_config=lead_config, member_configs=member_configs)
|
await team.create_team(lead_config=lead_config, member_configs=member_configs)
|
||||||
orchestrator = TeamOrchestrator(team=team)
|
# U7: Create checkpoint manager for crash recovery
|
||||||
|
from agentkit.orchestrator.checkpoint import PipelineCheckpoint
|
||||||
|
|
||||||
|
checkpoint = PipelineCheckpoint(
|
||||||
|
redis_client=getattr(app_state, "working_redis_client", None)
|
||||||
|
)
|
||||||
|
orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)
|
||||||
# U4: Register active team so WS messages during execution route as interventions
|
# U4: Register active team so WS messages during execution route as interventions
|
||||||
_register_active_team(session_id, team)
|
_register_active_team(session_id, team)
|
||||||
result = await orchestrator.execute(routing_result.task_content)
|
result = await orchestrator.execute(routing_result.task_content)
|
||||||
|
|
|
||||||
|
|
@ -209,6 +209,95 @@ async def cancel_task(task_id: str, req: Request):
|
||||||
return {"task_id": task_id, "status": "cancelled"}
|
return {"task_id": task_id, "status": "cancelled"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tasks/{task_id}/resume")
|
||||||
|
async def resume_task(task_id: str, req: Request):
|
||||||
|
"""Resume a crashed pipeline from the last completed phase checkpoint.
|
||||||
|
|
||||||
|
Reconstructs the team from the saved plan's expert names, creates a new
|
||||||
|
TeamOrchestrator with the checkpoint manager, and calls resume().
|
||||||
|
"""
|
||||||
|
from agentkit.experts.orchestrator import TeamOrchestrator
|
||||||
|
from agentkit.experts.router import ExpertTeamRouter
|
||||||
|
from agentkit.experts.team import ExpertTeam
|
||||||
|
from agentkit.orchestrator.checkpoint import PipelineCheckpoint
|
||||||
|
|
||||||
|
app_state = req.app.state
|
||||||
|
|
||||||
|
# 1. Create checkpoint manager
|
||||||
|
checkpoint = PipelineCheckpoint(
|
||||||
|
redis_client=getattr(app_state, "working_redis_client", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Load plan to get expert names
|
||||||
|
plan_dict = await checkpoint.load_plan(task_id)
|
||||||
|
if plan_dict is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"No checkpoint found for task '{task_id}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Extract unique expert names from plan
|
||||||
|
expert_names: list[str] = []
|
||||||
|
lead_name = plan_dict.get("lead_expert", "")
|
||||||
|
if lead_name:
|
||||||
|
expert_names.append(lead_name)
|
||||||
|
for ph in plan_dict.get("phases", []):
|
||||||
|
name = ph.get("assigned_expert", "")
|
||||||
|
if name and name not in expert_names:
|
||||||
|
expert_names.append(name)
|
||||||
|
|
||||||
|
if not expert_names:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Cannot resume: no experts found in saved plan",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Resolve expert configs via ExpertTeamRouter
|
||||||
|
template_registry = getattr(app_state, "expert_template_registry", None)
|
||||||
|
if template_registry is None:
|
||||||
|
from agentkit.experts.registry import ExpertTemplateRegistry
|
||||||
|
|
||||||
|
template_registry = ExpertTemplateRegistry()
|
||||||
|
|
||||||
|
team_router = ExpertTeamRouter(template_registry=template_registry)
|
||||||
|
expert_configs = team_router.resolve_expert_configs(expert_names)
|
||||||
|
if not expert_configs:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Cannot resume: failed to resolve expert configs",
|
||||||
|
)
|
||||||
|
|
||||||
|
lead_config = expert_configs[0]
|
||||||
|
member_configs = expert_configs[1:] if len(expert_configs) > 1 else []
|
||||||
|
|
||||||
|
# 5. Create team + orchestrator
|
||||||
|
team = ExpertTeam(
|
||||||
|
pool=app_state.agent_pool,
|
||||||
|
template_registry=template_registry,
|
||||||
|
redis_client=getattr(app_state, "working_redis_client", None),
|
||||||
|
)
|
||||||
|
await team.create_team(lead_config=lead_config, member_configs=member_configs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint)
|
||||||
|
result = await orchestrator.resume(task_id)
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await team.dissolve()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {
|
||||||
|
"task_id": task_id,
|
||||||
|
"status": result.get("status", "unknown"),
|
||||||
|
"result": result.get("result"),
|
||||||
|
"phase_results": {
|
||||||
|
pid: pr if isinstance(pr, dict) else {"content": str(pr)}
|
||||||
|
for pid, pr in (result.get("phase_results") or {}).items()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/tasks/stream")
|
@router.post("/tasks/stream")
|
||||||
async def stream_task(request: SubmitTaskRequest, req: Request):
|
async def stream_task(request: SubmitTaskRequest, req: Request):
|
||||||
"""Submit a task and stream ReAct events via SSE"""
|
"""Submit a task and stream ReAct events via SSE"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,561 @@
|
||||||
|
"""PipelineCheckpoint 单元测试 (U7)
|
||||||
|
|
||||||
|
测试覆盖:
|
||||||
|
- save/load 阶段 checkpoint(内存模式 + Redis mock)
|
||||||
|
- save_plan/load_plan 完整 plan 序列化
|
||||||
|
- list_checkpoints 按保存时间排序
|
||||||
|
- clear 清除所有数据
|
||||||
|
- Redis 异常降级到内存
|
||||||
|
- resume 从 checkpoint 恢复执行
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.experts.orchestrator import TeamOrchestrator
|
||||||
|
from agentkit.experts.plan import PhaseStatus, PlanPhase, PlanStatus, TeamPlan
|
||||||
|
from agentkit.orchestrator.checkpoint import CheckpointData, PipelineCheckpoint
|
||||||
|
from tests.unit.experts.test_team_orchestrator import (
|
||||||
|
_make_mock_llm_gateway,
|
||||||
|
_make_team_with_experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 辅助函数 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_phase(
|
||||||
|
name: str = "test_phase",
|
||||||
|
phase_id: str = "phase_1",
|
||||||
|
status: PhaseStatus = PhaseStatus.COMPLETED,
|
||||||
|
result: dict | None = None,
|
||||||
|
) -> PlanPhase:
|
||||||
|
"""创建测试用 PlanPhase"""
|
||||||
|
return PlanPhase(
|
||||||
|
id=phase_id,
|
||||||
|
name=name,
|
||||||
|
assigned_expert="test_expert",
|
||||||
|
task_description="test task",
|
||||||
|
status=status,
|
||||||
|
result=result or {"content": "test output"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_plan(
|
||||||
|
plan_id: str = "plan_1",
|
||||||
|
task: str = "test task",
|
||||||
|
phases: list[PlanPhase] | None = None,
|
||||||
|
) -> TeamPlan:
|
||||||
|
"""创建测试用 TeamPlan"""
|
||||||
|
plan = TeamPlan(id=plan_id, task=task, lead_expert="lead")
|
||||||
|
plan.status = PlanStatus.EXECUTING
|
||||||
|
if phases:
|
||||||
|
plan.phases = phases
|
||||||
|
return plan
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_redis() -> AsyncMock:
|
||||||
|
"""创建 mock Redis client,模拟 aioredis 行为。"""
|
||||||
|
redis = AsyncMock()
|
||||||
|
# 内部存储
|
||||||
|
store: dict[str, str] = {}
|
||||||
|
zsets: dict[str, dict[str, float]] = {}
|
||||||
|
|
||||||
|
async def _set(key, value, ex=None): # noqa: ARG001
|
||||||
|
store[key] = value
|
||||||
|
|
||||||
|
async def _get(key):
|
||||||
|
return store.get(key)
|
||||||
|
|
||||||
|
async def _zadd(key, mapping):
|
||||||
|
zsets.setdefault(key, {}).update(mapping)
|
||||||
|
|
||||||
|
async def _zrange(key, start, stop):
|
||||||
|
members = zsets.get(key, {})
|
||||||
|
sorted_members = sorted(members.keys(), key=lambda m: members[m])
|
||||||
|
if start == 0 and stop == -1:
|
||||||
|
return sorted_members
|
||||||
|
return sorted_members[start : stop + 1] if stop >= 0 else sorted_members[start:]
|
||||||
|
|
||||||
|
async def _delete(*keys):
|
||||||
|
count = 0
|
||||||
|
for k in keys:
|
||||||
|
if k in store:
|
||||||
|
del store[k]
|
||||||
|
count += 1
|
||||||
|
if k in zsets:
|
||||||
|
del zsets[k]
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _pipeline():
|
||||||
|
commands: list[tuple] = []
|
||||||
|
|
||||||
|
class _Pipe:
|
||||||
|
def set(self, key, value, ex=None):
|
||||||
|
commands.append(("set", key, value, ex))
|
||||||
|
|
||||||
|
def zadd(self, key, mapping):
|
||||||
|
commands.append(("zadd", key, mapping))
|
||||||
|
|
||||||
|
def delete(self, *keys):
|
||||||
|
commands.append(("delete", keys))
|
||||||
|
|
||||||
|
async def execute(self):
|
||||||
|
for cmd in commands:
|
||||||
|
if cmd[0] == "set":
|
||||||
|
await _set(cmd[1], cmd[2], cmd[3])
|
||||||
|
elif cmd[0] == "zadd":
|
||||||
|
await _zadd(cmd[1], cmd[2])
|
||||||
|
elif cmd[0] == "delete":
|
||||||
|
await _delete(*cmd[1])
|
||||||
|
|
||||||
|
return _Pipe()
|
||||||
|
|
||||||
|
redis.set = AsyncMock(side_effect=_set)
|
||||||
|
redis.get = AsyncMock(side_effect=_get)
|
||||||
|
redis.zadd = AsyncMock(side_effect=_zadd)
|
||||||
|
redis.zrange = AsyncMock(side_effect=_zrange)
|
||||||
|
redis.delete = AsyncMock(side_effect=_delete)
|
||||||
|
redis.pipeline = MagicMock(side_effect=_pipeline)
|
||||||
|
return redis
|
||||||
|
|
||||||
|
|
||||||
|
# ── PipelineCheckpoint 基础测试 ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointSaveLoad:
|
||||||
|
"""save / load / list_checkpoints 基础测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_then_load_returns_last_completed(self):
|
||||||
|
"""save 后 load 返回最后一个 COMPLETED 阶段"""
|
||||||
|
cp = PipelineCheckpoint() # 内存模式
|
||||||
|
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
|
||||||
|
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.plan_id == "plan_1"
|
||||||
|
assert loaded.phase_id == "p1"
|
||||||
|
assert loaded.phase_status == "completed"
|
||||||
|
assert loaded.plan_status == "executing"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_returns_last_completed_of_multiple(self):
|
||||||
|
"""3 个阶段完成后,load 返回第 3 个(最后一个 COMPLETED)"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
for i in range(3):
|
||||||
|
phase = _make_phase(
|
||||||
|
name=f"phase_{i}",
|
||||||
|
phase_id=f"p{i}",
|
||||||
|
status=PhaseStatus.COMPLETED,
|
||||||
|
result={"content": f"output_{i}"},
|
||||||
|
)
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.phase_id == "p2"
|
||||||
|
assert loaded.phase_result == {"content": "output_2"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_nonexistent_plan_returns_none(self):
|
||||||
|
"""load 不存在的 plan_id → None"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
loaded = await cp.load("nonexistent")
|
||||||
|
assert loaded is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_with_no_completed_returns_none(self):
|
||||||
|
"""只有 FAILED 阶段时 load 返回 None"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
phase = _make_phase(phase_id="p1", status=PhaseStatus.FAILED)
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints_returns_all(self):
|
||||||
|
"""list_checkpoints 返回所有 checkpoint"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
for i in range(3):
|
||||||
|
phase = _make_phase(phase_id=f"p{i}", status=PhaseStatus.COMPLETED)
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
checkpoints = await cp.list_checkpoints("plan_1")
|
||||||
|
assert len(checkpoints) == 3
|
||||||
|
assert all(c.plan_id == "plan_1" for c in checkpoints)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_serializes_phase_status_enum(self):
|
||||||
|
"""save 正确序列化 PhaseStatus enum 为字符串"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
phase = _make_phase(status=PhaseStatus.COMPLETED)
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
checkpoints = await cp.list_checkpoints("plan_1")
|
||||||
|
assert checkpoints[0].phase_status == "completed"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_handles_non_dict_result(self):
|
||||||
|
"""save 处理非 dict 的 phase.result"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
phase = _make_phase()
|
||||||
|
phase.result = "plain string result"
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
checkpoints = await cp.list_checkpoints("plan_1")
|
||||||
|
assert checkpoints[0].phase_result == {"content": "plain string result"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── save_plan / load_plan 测试 ───────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointSavePlan:
|
||||||
|
"""save_plan / load_plan 测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_plan_then_load_plan_roundtrip(self):
|
||||||
|
"""save_plan 后 load_plan 返回 plan dict"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
plan = _make_plan(
|
||||||
|
plan_id="plan_42",
|
||||||
|
task="build feature",
|
||||||
|
phases=[
|
||||||
|
_make_phase(name="phase1", phase_id="p1"),
|
||||||
|
_make_phase(name="phase2", phase_id="p2"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
await cp.save_plan(plan)
|
||||||
|
loaded = await cp.load_plan("plan_42")
|
||||||
|
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded["id"] == "plan_42"
|
||||||
|
assert loaded["task"] == "build feature"
|
||||||
|
assert len(loaded["phases"]) == 2
|
||||||
|
assert loaded["phases"][0]["name"] == "phase1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_plan_nonexistent_returns_none(self):
|
||||||
|
"""load_plan 不存在的 plan_id → None"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
loaded = await cp.load_plan("nonexistent")
|
||||||
|
assert loaded is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── clear 测试 ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointClear:
|
||||||
|
"""clear 测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_removes_all_data(self):
|
||||||
|
"""clear 清除所有 checkpoint 和 plan 数据"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
plan = _make_plan()
|
||||||
|
phase = _make_phase()
|
||||||
|
await cp.save_plan(plan)
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
await cp.clear("plan_1")
|
||||||
|
|
||||||
|
assert await cp.load("plan_1") is None
|
||||||
|
assert await cp.list_checkpoints("plan_1") == []
|
||||||
|
assert await cp.load_plan("plan_1") is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_nonexistent_does_not_raise(self):
|
||||||
|
"""clear 不存在的 plan_id 不抛异常"""
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
await cp.clear("nonexistent") # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
# ── Redis 模式测试 ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointRedis:
|
||||||
|
"""Redis 模式测试(使用 mock Redis)"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redis_save_then_load(self):
|
||||||
|
"""Redis 模式下 save/load 正常工作"""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
cp = PipelineCheckpoint(redis_client=redis)
|
||||||
|
|
||||||
|
phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.phase_id == "p1"
|
||||||
|
assert loaded.phase_status == "completed"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redis_save_plan_then_load_plan(self):
|
||||||
|
"""Redis 模式下 save_plan/load_plan 正常工作"""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
cp = PipelineCheckpoint(redis_client=redis)
|
||||||
|
|
||||||
|
plan = _make_plan(plan_id="plan_99")
|
||||||
|
await cp.save_plan(plan)
|
||||||
|
loaded = await cp.load_plan("plan_99")
|
||||||
|
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded["id"] == "plan_99"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redis_clear_removes_all(self):
|
||||||
|
"""Redis 模式下 clear 清除所有数据"""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
cp = PipelineCheckpoint(redis_client=redis)
|
||||||
|
|
||||||
|
phase = _make_phase()
|
||||||
|
plan = _make_plan()
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
await cp.save_plan(plan)
|
||||||
|
|
||||||
|
await cp.clear("plan_1")
|
||||||
|
|
||||||
|
assert await cp.load("plan_1") is None
|
||||||
|
assert await cp.load_plan("plan_1") is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redis_failure_falls_back_to_memory(self):
|
||||||
|
"""Redis 异常时降级到内存,save/load 仍工作"""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
# 让 pipeline 抛异常(save 写 Redis 失败)
|
||||||
|
redis.pipeline = MagicMock(side_effect=Exception("Redis connection lost"))
|
||||||
|
|
||||||
|
cp = PipelineCheckpoint(redis_client=redis)
|
||||||
|
|
||||||
|
phase = _make_phase(status=PhaseStatus.COMPLETED)
|
||||||
|
# save 不应抛异常,降级到内存
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
# Redis get/zrange 也失败时,load 从内存降级读取
|
||||||
|
redis.get = AsyncMock(side_effect=Exception("Redis down"))
|
||||||
|
redis.zrange = AsyncMock(side_effect=Exception("Redis down"))
|
||||||
|
|
||||||
|
loaded = await cp.load("plan_1")
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.phase_id == phase.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redis_save_exception_does_not_block(self):
|
||||||
|
"""save 时 Redis 异常不阻断执行"""
|
||||||
|
redis = _make_mock_redis()
|
||||||
|
# 让 pipeline 抛异常,但 zrange/get 仍正常工作
|
||||||
|
redis.pipeline = MagicMock(side_effect=Exception("Redis write error"))
|
||||||
|
|
||||||
|
cp = PipelineCheckpoint(redis_client=redis)
|
||||||
|
|
||||||
|
phase = _make_phase()
|
||||||
|
# 不应抛异常
|
||||||
|
await cp.save("plan_1", phase, "executing")
|
||||||
|
|
||||||
|
# 内存降级中应有数据(Redis zrange 返回空 → 降级到内存)
|
||||||
|
checkpoints = await cp.list_checkpoints("plan_1")
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── CheckpointData 数据类测试 ─────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointData:
|
||||||
|
"""CheckpointData 序列化/反序列化测试"""
|
||||||
|
|
||||||
|
def test_to_dict_contains_all_fields(self):
|
||||||
|
"""to_dict 包含所有字段"""
|
||||||
|
data = CheckpointData(
|
||||||
|
plan_id="p1",
|
||||||
|
phase_id="ph1",
|
||||||
|
phase_name="Phase 1",
|
||||||
|
phase_status="completed",
|
||||||
|
phase_result={"content": "output"},
|
||||||
|
plan_status="executing",
|
||||||
|
)
|
||||||
|
d = data.to_dict()
|
||||||
|
assert d["plan_id"] == "p1"
|
||||||
|
assert d["phase_id"] == "ph1"
|
||||||
|
assert d["phase_name"] == "Phase 1"
|
||||||
|
assert d["phase_status"] == "completed"
|
||||||
|
assert d["phase_result"] == {"content": "output"}
|
||||||
|
assert d["plan_status"] == "executing"
|
||||||
|
assert "saved_at" in d
|
||||||
|
|
||||||
|
def test_from_dict_roundtrip(self):
|
||||||
|
"""from_dict 反序列化正确"""
|
||||||
|
original = CheckpointData(
|
||||||
|
plan_id="p1",
|
||||||
|
phase_id="ph1",
|
||||||
|
phase_name="Phase 1",
|
||||||
|
phase_status="completed",
|
||||||
|
phase_result={"content": "output"},
|
||||||
|
plan_status="executing",
|
||||||
|
)
|
||||||
|
d = original.to_dict()
|
||||||
|
restored = CheckpointData.from_dict(d)
|
||||||
|
|
||||||
|
assert restored.plan_id == original.plan_id
|
||||||
|
assert restored.phase_id == original.phase_id
|
||||||
|
assert restored.phase_name == original.phase_name
|
||||||
|
assert restored.phase_status == original.phase_status
|
||||||
|
assert restored.phase_result == original.phase_result
|
||||||
|
assert restored.plan_status == original.plan_status
|
||||||
|
|
||||||
|
def test_from_dict_with_missing_fields_uses_defaults(self):
|
||||||
|
"""from_dict 缺失字段时使用默认值"""
|
||||||
|
data = CheckpointData.from_dict({"plan_id": "p1", "phase_id": "ph1"})
|
||||||
|
assert data.plan_id == "p1"
|
||||||
|
assert data.phase_id == "ph1"
|
||||||
|
assert data.phase_name == ""
|
||||||
|
assert data.phase_status == ""
|
||||||
|
assert data.phase_result is None
|
||||||
|
assert data.plan_status == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── TeamOrchestrator.resume 集成测试 ─────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestratorResume:
|
||||||
|
"""TeamOrchestrator.resume 集成测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_without_checkpoint_returns_failed(self):
|
||||||
|
"""无 checkpoint manager 时 resume 返回 failed"""
|
||||||
|
team = _make_team_with_experts()
|
||||||
|
orchestrator = TeamOrchestrator(team=team) # no checkpoint
|
||||||
|
|
||||||
|
result = await orchestrator.resume("plan_1")
|
||||||
|
|
||||||
|
assert result["status"] == "failed"
|
||||||
|
assert "No checkpoint manager" in result["error"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_nonexistent_plan_returns_failed(self):
|
||||||
|
"""resume 不存在的 plan_id 返回 failed"""
|
||||||
|
team = _make_team_with_experts()
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
|
||||||
|
|
||||||
|
result = await orchestrator.resume("nonexistent")
|
||||||
|
|
||||||
|
assert result["status"] == "failed"
|
||||||
|
assert "No checkpoint" in result["error"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_skips_completed_phases(self):
|
||||||
|
"""resume 跳过已完成阶段,只执行未完成阶段"""
|
||||||
|
# 创建一个有 2 个阶段的 plan,阶段 1 已完成,阶段 2 依赖阶段 1
|
||||||
|
phase1 = PlanPhase(
|
||||||
|
id="p1",
|
||||||
|
name="phase1",
|
||||||
|
assigned_expert="lead",
|
||||||
|
task_description="task 1",
|
||||||
|
status=PhaseStatus.COMPLETED,
|
||||||
|
result={"content": "phase1 output"},
|
||||||
|
)
|
||||||
|
phase2 = PlanPhase(
|
||||||
|
id="p2",
|
||||||
|
name="phase2",
|
||||||
|
assigned_expert="member1",
|
||||||
|
task_description="task 2",
|
||||||
|
depends_on=["p1"],
|
||||||
|
status=PhaseStatus.PENDING,
|
||||||
|
)
|
||||||
|
plan = _make_plan(
|
||||||
|
plan_id="plan_resume",
|
||||||
|
task="test resume",
|
||||||
|
phases=[phase1, phase2],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存 plan + checkpoint for phase1
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
await cp.save_plan(plan)
|
||||||
|
await cp.save("plan_resume", phase1, "executing")
|
||||||
|
|
||||||
|
# 创建 team + orchestrator
|
||||||
|
team = _make_team_with_experts(expert_names=["lead", "member1"])
|
||||||
|
# 设置 mock LLM gateway 用于 synthesis
|
||||||
|
gateway = _make_mock_llm_gateway(synthesis_content="综合结果")
|
||||||
|
team._experts["lead"].agent._llm_gateway = gateway
|
||||||
|
|
||||||
|
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
|
||||||
|
|
||||||
|
result = await orchestrator.resume("plan_resume")
|
||||||
|
|
||||||
|
# 验证结果
|
||||||
|
assert result["status"] == "completed"
|
||||||
|
# phase1 的结果应从 checkpoint 恢复
|
||||||
|
assert "p1" in result["phase_results"]
|
||||||
|
# phase2 应被执行
|
||||||
|
assert "p2" in result["phase_results"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_all_phases_completed_skips_execution(self):
|
||||||
|
"""resume 时所有阶段都已完成 → 直接 synthesis"""
|
||||||
|
phase1 = PlanPhase(
|
||||||
|
id="p1",
|
||||||
|
name="phase1",
|
||||||
|
assigned_expert="lead",
|
||||||
|
task_description="task 1",
|
||||||
|
status=PhaseStatus.COMPLETED,
|
||||||
|
result={"content": "phase1 output"},
|
||||||
|
)
|
||||||
|
plan = _make_plan(
|
||||||
|
plan_id="plan_all_done",
|
||||||
|
task="test resume all done",
|
||||||
|
phases=[phase1],
|
||||||
|
)
|
||||||
|
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
await cp.save_plan(plan)
|
||||||
|
await cp.save("plan_all_done", phase1, "executing")
|
||||||
|
|
||||||
|
team = _make_team_with_experts()
|
||||||
|
gateway = _make_mock_llm_gateway(synthesis_content="综合结果")
|
||||||
|
team._experts["lead"].agent._llm_gateway = gateway
|
||||||
|
|
||||||
|
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
|
||||||
|
result = await orchestrator.resume("plan_all_done")
|
||||||
|
|
||||||
|
assert result["status"] == "completed"
|
||||||
|
assert "p1" in result["phase_results"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resume_no_active_expert_returns_failed(self):
|
||||||
|
"""resume 时无活跃专家返回 failed"""
|
||||||
|
phase1 = PlanPhase(
|
||||||
|
id="p1",
|
||||||
|
name="phase1",
|
||||||
|
assigned_expert="lead",
|
||||||
|
task_description="task 1",
|
||||||
|
status=PhaseStatus.PENDING,
|
||||||
|
)
|
||||||
|
plan = _make_plan(
|
||||||
|
plan_id="plan_no_expert",
|
||||||
|
task="test",
|
||||||
|
phases=[phase1],
|
||||||
|
)
|
||||||
|
|
||||||
|
cp = PipelineCheckpoint()
|
||||||
|
await cp.save_plan(plan)
|
||||||
|
|
||||||
|
# 创建无活跃专家的 team
|
||||||
|
team = _make_team_with_experts()
|
||||||
|
for expert in team._experts.values():
|
||||||
|
expert.is_active = False
|
||||||
|
|
||||||
|
orchestrator = TeamOrchestrator(team=team, checkpoint=cp)
|
||||||
|
result = await orchestrator.resume("plan_no_expert")
|
||||||
|
|
||||||
|
assert result["status"] == "failed"
|
||||||
|
assert "No active expert" in result["error"]
|
||||||
Loading…
Reference in New Issue