682 lines
24 KiB
Python
682 lines
24 KiB
Python
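"""Unit tests for PipelineCheckpoint (U7).

Coverage:
- save/load of per-phase checkpoints (in-memory mode and mocked Redis)
- save_plan/load_plan serialization of a complete plan
- list_checkpoints ordered by save time
- TTL expiry of in-memory checkpoints and plans
- clear removing all stored data
- fallback to in-memory storage when Redis raises
- CheckpointData to_dict/from_dict serialization
- TeamOrchestrator.resume continuing execution from a checkpoint
"""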
|
||
|
||
from __future__ import annotations
|
||
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.experts.orchestrator import TeamOrchestrator
|
||
from agentkit.experts.plan import PhaseStatus, PlanPhase, PlanStatus, TeamPlan
|
||
from agentkit.orchestrator.checkpoint import CheckpointData, PipelineCheckpoint
|
||
from tests.unit.experts.test_team_orchestrator import (
|
||
_make_mock_llm_gateway,
|
||
_make_team_with_experts,
|
||
)
|
||
|
||
|
||
# ── 辅助函数 ──────────────────────────────────────────────
|
||
|
||
|
||
def _make_phase(
    name: str = "test_phase",
    phase_id: str = "phase_1",
    status: PhaseStatus = PhaseStatus.COMPLETED,
    result: dict | None = None,
) -> PlanPhase:
    """Build a ``PlanPhase`` for tests.

    Args:
        name: Phase display name.
        phase_id: Phase identifier.
        status: Phase status; defaults to ``PhaseStatus.COMPLETED``.
        result: Phase result payload. ``None`` selects the default
            ``{"content": "test output"}``; any other value (including an
            empty dict) is used exactly as given.

    Returns:
        A ``PlanPhase`` assigned to ``"test_expert"`` with task
        description ``"test task"``.
    """
    if result is None:
        # Only substitute the default when the caller passed nothing; a
        # deliberately empty/falsy result must reach the phase unchanged.
        result = {"content": "test output"}
    return PlanPhase(
        id=phase_id,
        name=name,
        assigned_expert="test_expert",
        task_description="test task",
        status=status,
        result=result,
    )
|
||
|
||
|
||
def _make_plan(
    plan_id: str = "plan_1",
    task: str = "test task",
    phases: list[PlanPhase] | None = None,
) -> TeamPlan:
    """Build a ``TeamPlan`` in the EXECUTING state for tests.

    The plan's lead expert is ``"lead"``. ``phases`` replaces the plan's
    phase list only when it is a non-empty list; otherwise whatever
    ``TeamPlan`` initialised is left untouched.
    """
    team_plan = TeamPlan(id=plan_id, task=task, lead_expert="lead")
    team_plan.status = PlanStatus.EXECUTING
    if not phases:
        return team_plan
    team_plan.phases = phases
    return team_plan
|
||
|
||
|
||
def _make_mock_redis() -> AsyncMock:
|
||
"""创建 mock Redis client,模拟 aioredis 行为。"""
|
||
redis = AsyncMock()
|
||
# 内部存储
|
||
store: dict[str, str] = {}
|
||
zsets: dict[str, dict[str, float]] = {}
|
||
|
||
async def _set(key, value, ex=None): # noqa: ARG001
|
||
store[key] = value
|
||
|
||
async def _get(key):
|
||
return store.get(key)
|
||
|
||
async def _zadd(key, mapping):
|
||
zsets.setdefault(key, {}).update(mapping)
|
||
|
||
async def _zrange(key, start, stop):
|
||
members = zsets.get(key, {})
|
||
sorted_members = sorted(members.keys(), key=lambda m: members[m])
|
||
if start == 0 and stop == -1:
|
||
return sorted_members
|
||
return sorted_members[start : stop + 1] if stop >= 0 else sorted_members[start:]
|
||
|
||
async def _delete(*keys):
|
||
count = 0
|
||
for k in keys:
|
||
if k in store:
|
||
del store[k]
|
||
count += 1
|
||
if k in zsets:
|
||
del zsets[k]
|
||
count += 1
|
||
return count
|
||
|
||
def _pipeline():
|
||
commands: list[tuple] = []
|
||
|
||
class _Pipe:
|
||
def set(self, key, value, ex=None):
|
||
commands.append(("set", key, value, ex))
|
||
|
||
def zadd(self, key, mapping):
|
||
commands.append(("zadd", key, mapping))
|
||
|
||
def delete(self, *keys):
|
||
commands.append(("delete", keys))
|
||
|
||
async def execute(self):
|
||
for cmd in commands:
|
||
if cmd[0] == "set":
|
||
await _set(cmd[1], cmd[2], cmd[3])
|
||
elif cmd[0] == "zadd":
|
||
await _zadd(cmd[1], cmd[2])
|
||
elif cmd[0] == "delete":
|
||
await _delete(*cmd[1])
|
||
|
||
return _Pipe()
|
||
|
||
redis.set = AsyncMock(side_effect=_set)
|
||
redis.get = AsyncMock(side_effect=_get)
|
||
redis.zadd = AsyncMock(side_effect=_zadd)
|
||
redis.zrange = AsyncMock(side_effect=_zrange)
|
||
redis.delete = AsyncMock(side_effect=_delete)
|
||
redis.pipeline = MagicMock(side_effect=_pipeline)
|
||
return redis
|
||
|
||
|
||
# ── PipelineCheckpoint 基础测试 ──────────────────────────
|
||
|
||
|
||
class TestCheckpointSaveLoad:
    """Basic save / load / list_checkpoints tests (in-memory mode)."""

    @pytest.mark.asyncio
    async def test_save_then_load_returns_last_completed(self):
        """After save, load returns the last COMPLETED phase."""
        cp = PipelineCheckpoint()  # in-memory mode
        phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)

        await cp.save("plan_1", phase, "executing")
        loaded = await cp.load("plan_1")

        assert loaded is not None
        assert loaded.plan_id == "plan_1"
        assert loaded.phase_id == "p1"
        assert loaded.phase_status == "completed"
        assert loaded.plan_status == "executing"

    @pytest.mark.asyncio
    async def test_load_returns_last_completed_of_multiple(self):
        """With three completed phases saved, load returns the third (latest) one."""
        cp = PipelineCheckpoint()
        for i in range(3):
            phase = _make_phase(
                name=f"phase_{i}",
                phase_id=f"p{i}",
                status=PhaseStatus.COMPLETED,
                result={"content": f"output_{i}"},
            )
            await cp.save("plan_1", phase, "executing")

        loaded = await cp.load("plan_1")
        assert loaded is not None
        assert loaded.phase_id == "p2"
        assert loaded.phase_result == {"content": "output_2"}

    @pytest.mark.asyncio
    async def test_load_nonexistent_plan_returns_none(self):
        """load on an unknown plan_id returns None."""
        cp = PipelineCheckpoint()
        loaded = await cp.load("nonexistent")
        assert loaded is None

    @pytest.mark.asyncio
    async def test_load_with_no_completed_returns_none(self):
        """load returns None when only FAILED phases were saved."""
        cp = PipelineCheckpoint()
        phase = _make_phase(phase_id="p1", status=PhaseStatus.FAILED)
        await cp.save("plan_1", phase, "executing")

        loaded = await cp.load("plan_1")
        assert loaded is None

    @pytest.mark.asyncio
    async def test_list_checkpoints_returns_all(self):
        """list_checkpoints returns every checkpoint, oldest save first."""
        cp = PipelineCheckpoint()
        for i in range(3):
            phase = _make_phase(phase_id=f"p{i}", status=PhaseStatus.COMPLETED)
            await cp.save("plan_1", phase, "executing")

        checkpoints = await cp.list_checkpoints("plan_1")
        assert len(checkpoints) == 3
        assert all(c.plan_id == "plan_1" for c in checkpoints)
        # Checkpoints must come back in save order, not merely all present.
        assert [c.phase_id for c in checkpoints] == ["p0", "p1", "p2"]

    @pytest.mark.asyncio
    async def test_save_serializes_phase_status_enum(self):
        """save stores the PhaseStatus enum as its string value."""
        cp = PipelineCheckpoint()
        phase = _make_phase(status=PhaseStatus.COMPLETED)
        await cp.save("plan_1", phase, "executing")

        checkpoints = await cp.list_checkpoints("plan_1")
        assert checkpoints[0].phase_status == "completed"

    @pytest.mark.asyncio
    async def test_save_handles_non_dict_result(self):
        """save wraps a non-dict phase.result as {"content": result}."""
        cp = PipelineCheckpoint()
        phase = _make_phase()
        phase.result = "plain string result"
        await cp.save("plan_1", phase, "executing")

        checkpoints = await cp.list_checkpoints("plan_1")
        assert checkpoints[0].phase_result == {"content": "plain string result"}
|
||
|
||
|
||
# ── save_plan / load_plan 测试 ───────────────────────────
|
||
|
||
|
||
class TestCheckpointSavePlan:
    """Round-trip tests for ``save_plan`` / ``load_plan``."""

    @pytest.mark.asyncio
    async def test_save_plan_then_load_plan_roundtrip(self):
        """load_plan returns the plan dict previously stored by save_plan."""
        phases = [
            _make_phase(name="phase1", phase_id="p1"),
            _make_phase(name="phase2", phase_id="p2"),
        ]
        plan = _make_plan(plan_id="plan_42", task="build feature", phases=phases)
        cp = PipelineCheckpoint()

        await cp.save_plan(plan)
        restored = await cp.load_plan("plan_42")

        assert restored is not None
        assert restored["id"] == "plan_42"
        assert restored["task"] == "build feature"
        restored_phases = restored["phases"]
        assert len(restored_phases) == 2
        assert restored_phases[0]["name"] == "phase1"

    @pytest.mark.asyncio
    async def test_load_plan_nonexistent_returns_none(self):
        """load_plan on an unknown plan_id returns None."""
        assert await PipelineCheckpoint().load_plan("nonexistent") is None
|
||
|
||
|
||
# ── TTL 过期测试 ─────────────────────────────────────────
|
||
|
||
|
||
class TestCheckpointTTL:
    """TTL expiry tests for in-memory storage (including the Redis fallback path).

    Expiry is simulated by back-dating stored timestamps rather than by
    sleeping. The TTL is set far above any realistic test runtime, so the
    "freshly saved data is still readable" assertions cannot flake on a slow
    runner. The back-dating offset is far above the TTL, so expired data is
    unambiguous.
    """

    # TTL for the expiry tests, in seconds.
    _TTL_SECONDS = 60
    # How far into the past expired entries are moved, in seconds.
    _EXPIRED_AGE_SECONDS = 3600

    @classmethod
    def _expire_checkpoint(cls, cp: PipelineCheckpoint, plan_id: str, index: int = 0) -> None:
        """Back-date the ``index``-th in-memory checkpoint of ``plan_id`` beyond the TTL."""
        from datetime import datetime, timedelta, timezone

        expired_at = datetime.now(timezone.utc) - timedelta(seconds=cls._EXPIRED_AGE_SECONDS)
        cp._memory[plan_id][index].saved_at = expired_at.isoformat()

    @classmethod
    def _expire_plan(cls, cp: PipelineCheckpoint, plan_id: str) -> None:
        """Back-date the in-memory plan snapshot of ``plan_id`` beyond the TTL."""
        import time

        # Entries are (plan_dict, saved_unix_timestamp); keep the dict, move the time.
        plan_dict = cp._memory_plans[plan_id][0]
        cp._memory_plans[plan_id] = (plan_dict, time.time() - cls._EXPIRED_AGE_SECONDS)

    @pytest.mark.asyncio
    async def test_memory_checkpoint_expired_returns_none(self):
        """An expired in-memory checkpoint is no longer returned by load."""
        cp = PipelineCheckpoint(ttl_seconds=self._TTL_SECONDS)
        phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
        await cp.save("plan_1", phase, "executing")

        # Freshly saved: load returns it.
        loaded = await cp.load("plan_1")
        assert loaded is not None
        assert loaded.phase_id == "p1"

        # Back-dated past the TTL: load returns None.
        self._expire_checkpoint(cp, "plan_1")
        loaded = await cp.load("plan_1")
        assert loaded is None

    @pytest.mark.asyncio
    async def test_memory_plan_expired_returns_none(self):
        """An expired in-memory plan is no longer returned by load_plan."""
        cp = PipelineCheckpoint(ttl_seconds=self._TTL_SECONDS)
        plan = _make_plan(plan_id="plan_ttl")
        await cp.save_plan(plan)

        # Freshly saved: load_plan returns it.
        loaded = await cp.load_plan("plan_ttl")
        assert loaded is not None
        assert loaded["id"] == "plan_ttl"

        # Back-dated past the TTL: load_plan returns None.
        self._expire_plan(cp, "plan_ttl")
        loaded = await cp.load_plan("plan_ttl")
        assert loaded is None

    @pytest.mark.asyncio
    async def test_memory_checkpoint_not_expired_returns_data(self):
        """A checkpoint within its TTL is returned normally."""
        cp = PipelineCheckpoint(ttl_seconds=3600)  # 1 hour TTL
        phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
        await cp.save("plan_1", phase, "executing")

        loaded = await cp.load("plan_1")
        assert loaded is not None
        assert loaded.phase_id == "p1"

    @pytest.mark.asyncio
    async def test_memory_list_checkpoints_filters_expired(self):
        """list_checkpoints drops expired in-memory checkpoints."""
        cp = PipelineCheckpoint(ttl_seconds=self._TTL_SECONDS)
        phase1 = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
        phase2 = _make_phase(phase_id="p2", status=PhaseStatus.COMPLETED)
        await cp.save("plan_1", phase1, "executing")
        await cp.save("plan_1", phase2, "executing")

        # Both checkpoints are fresh.
        checkpoints = await cp.list_checkpoints("plan_1")
        assert len(checkpoints) == 2

        # Expire the first one; only the second should remain.
        self._expire_checkpoint(cp, "plan_1", index=0)
        checkpoints = await cp.list_checkpoints("plan_1")
        assert len(checkpoints) == 1
        assert checkpoints[0].phase_id == "p2"

    @pytest.mark.asyncio
    async def test_redis_empty_falls_back_to_memory_with_ttl_filter(self):
        """With Redis empty, load falls back to memory and still applies the TTL."""
        redis = _make_mock_redis()
        cp = PipelineCheckpoint(redis_client=redis, ttl_seconds=self._TTL_SECONDS)

        phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
        # save writes to Redis and to the in-memory fallback.
        await cp.save("plan_1", phase, "executing")

        # Redis still holds the data: load succeeds.
        loaded = await cp.load("plan_1")
        assert loaded is not None

        # Simulate Redis losing the data: zrange returns nothing, get returns None.
        redis.zrange = AsyncMock(return_value=[])
        redis.get = AsyncMock(return_value=None)

        # Memory fallback, not yet expired -> data is returned.
        loaded = await cp.load("plan_1")
        assert loaded is not None
        assert loaded.phase_id == "p1"

        # Memory fallback, expired -> None.
        self._expire_checkpoint(cp, "plan_1")
        loaded = await cp.load("plan_1")
        assert loaded is None
|
||
|
||
|
||
# ── clear 测试 ───────────────────────────────────────────
|
||
|
||
|
||
class TestCheckpointClear:
    """Tests for ``clear``."""

    @pytest.mark.asyncio
    async def test_clear_removes_all_data(self):
        """clear wipes both the checkpoints and the saved plan of a plan_id."""
        cp = PipelineCheckpoint()
        plan, phase = _make_plan(), _make_phase()
        await cp.save_plan(plan)
        await cp.save("plan_1", phase, "executing")

        await cp.clear("plan_1")

        assert await cp.load("plan_1") is None
        assert await cp.list_checkpoints("plan_1") == []
        assert await cp.load_plan("plan_1") is None

    @pytest.mark.asyncio
    async def test_clear_nonexistent_does_not_raise(self):
        """Clearing an unknown plan_id completes without raising."""
        await PipelineCheckpoint().clear("nonexistent")  # should not raise
|
||
|
||
|
||
# ── Redis 模式测试 ───────────────────────────────────────
|
||
|
||
|
||
class TestCheckpointRedis:
    """Redis-mode tests (using the mock Redis client)."""

    @pytest.mark.asyncio
    async def test_redis_save_then_load(self):
        """In Redis mode, save/load work and save goes through the Redis pipeline."""
        redis = _make_mock_redis()
        cp = PipelineCheckpoint(redis_client=redis)

        phase = _make_phase(phase_id="p1", status=PhaseStatus.COMPLETED)
        await cp.save("plan_1", phase, "executing")
        # save also fills the in-memory fallback, so a successful load alone
        # would not prove the Redis write path was exercised.
        redis.pipeline.assert_called()

        loaded = await cp.load("plan_1")
        assert loaded is not None
        assert loaded.phase_id == "p1"
        assert loaded.phase_status == "completed"

    @pytest.mark.asyncio
    async def test_redis_save_plan_then_load_plan(self):
        """In Redis mode, save_plan/load_plan round-trip the plan."""
        redis = _make_mock_redis()
        cp = PipelineCheckpoint(redis_client=redis)

        plan = _make_plan(plan_id="plan_99")
        await cp.save_plan(plan)
        loaded = await cp.load_plan("plan_99")

        assert loaded is not None
        assert loaded["id"] == "plan_99"

    @pytest.mark.asyncio
    async def test_redis_clear_removes_all(self):
        """In Redis mode, clear removes checkpoints and the plan."""
        redis = _make_mock_redis()
        cp = PipelineCheckpoint(redis_client=redis)

        phase = _make_phase()
        plan = _make_plan()
        await cp.save("plan_1", phase, "executing")
        await cp.save_plan(plan)

        await cp.clear("plan_1")

        assert await cp.load("plan_1") is None
        assert await cp.load_plan("plan_1") is None

    @pytest.mark.asyncio
    async def test_redis_failure_falls_back_to_memory(self):
        """When Redis raises, save/load still work via the in-memory fallback."""
        redis = _make_mock_redis()
        # Make the pipeline raise, so save's Redis write fails.
        redis.pipeline = MagicMock(side_effect=Exception("Redis connection lost"))

        cp = PipelineCheckpoint(redis_client=redis)

        phase = _make_phase(status=PhaseStatus.COMPLETED)
        # save must not raise; it falls back to memory.
        await cp.save("plan_1", phase, "executing")

        # With Redis get/zrange also failing, load reads from the memory fallback.
        redis.get = AsyncMock(side_effect=Exception("Redis down"))
        redis.zrange = AsyncMock(side_effect=Exception("Redis down"))

        loaded = await cp.load("plan_1")
        assert loaded is not None
        assert loaded.phase_id == phase.id

    @pytest.mark.asyncio
    async def test_redis_save_exception_does_not_block(self):
        """A Redis error during save does not interrupt execution."""
        redis = _make_mock_redis()
        # The pipeline raises, while zrange/get keep working.
        redis.pipeline = MagicMock(side_effect=Exception("Redis write error"))

        cp = PipelineCheckpoint(redis_client=redis)

        phase = _make_phase()
        # Must not raise.
        await cp.save("plan_1", phase, "executing")

        # Redis zrange is empty, so list_checkpoints falls back to the memory copy.
        checkpoints = await cp.list_checkpoints("plan_1")
        assert len(checkpoints) == 1
|
||
|
||
|
||
# ── CheckpointData 数据类测试 ─────────────────────────────
|
||
|
||
|
||
class TestCheckpointData:
    """Serialization / deserialization tests for ``CheckpointData``."""

    @staticmethod
    def _sample() -> CheckpointData:
        """Return a fully populated CheckpointData used by the round-trip tests."""
        return CheckpointData(
            plan_id="p1",
            phase_id="ph1",
            phase_name="Phase 1",
            phase_status="completed",
            phase_result={"content": "output"},
            plan_status="executing",
        )

    def test_to_dict_contains_all_fields(self):
        """to_dict exposes every field, plus a saved_at entry."""
        serialized = self._sample().to_dict()
        expected = {
            "plan_id": "p1",
            "phase_id": "ph1",
            "phase_name": "Phase 1",
            "phase_status": "completed",
            "phase_result": {"content": "output"},
            "plan_status": "executing",
        }
        for key, value in expected.items():
            assert serialized[key] == value
        assert "saved_at" in serialized

    def test_from_dict_roundtrip(self):
        """from_dict(to_dict()) reproduces the original field values."""
        original = self._sample()
        restored = CheckpointData.from_dict(original.to_dict())

        compared_fields = (
            "plan_id",
            "phase_id",
            "phase_name",
            "phase_status",
            "phase_result",
            "plan_status",
        )
        for field_name in compared_fields:
            assert getattr(restored, field_name) == getattr(original, field_name)

    def test_from_dict_with_missing_fields_uses_defaults(self):
        """Keys missing from the input fall back to empty strings / None."""
        partial = CheckpointData.from_dict({"plan_id": "p1", "phase_id": "ph1"})
        assert partial.plan_id == "p1"
        assert partial.phase_id == "ph1"
        assert partial.phase_name == ""
        assert partial.phase_status == ""
        assert partial.phase_result is None
        assert partial.plan_status == ""
|
||
|
||
|
||
# ── TeamOrchestrator.resume 集成测试 ─────────────────────
|
||
|
||
|
||
class TestOrchestratorResume:
    """Integration tests for ``TeamOrchestrator.resume``."""

    @staticmethod
    def _give_lead_synthesis_gateway(team) -> None:
        """Attach a mock LLM gateway to the ``lead`` expert so synthesis can run."""
        gateway = _make_mock_llm_gateway(synthesis_content="综合结果")
        team._experts["lead"].agent._llm_gateway = gateway

    @pytest.mark.asyncio
    async def test_resume_without_checkpoint_returns_failed(self):
        """resume fails when the orchestrator has no checkpoint manager."""
        orchestrator = TeamOrchestrator(team=_make_team_with_experts())  # no checkpoint

        result = await orchestrator.resume("plan_1")

        assert result["status"] == "failed"
        assert "No checkpoint manager" in result["error"]

    @pytest.mark.asyncio
    async def test_resume_nonexistent_plan_returns_failed(self):
        """resume fails for a plan_id that was never checkpointed."""
        orchestrator = TeamOrchestrator(
            team=_make_team_with_experts(),
            checkpoint=PipelineCheckpoint(),
        )

        result = await orchestrator.resume("nonexistent")

        assert result["status"] == "failed"
        assert "No checkpoint" in result["error"]

    @pytest.mark.asyncio
    async def test_resume_skips_completed_phases(self):
        """resume skips completed phases and only runs the unfinished ones."""
        # Two-phase plan: p1 is already done; p2 depends on p1 and is pending.
        done = PlanPhase(
            id="p1",
            name="phase1",
            assigned_expert="lead",
            task_description="task 1",
            status=PhaseStatus.COMPLETED,
            result={"content": "phase1 output"},
        )
        pending = PlanPhase(
            id="p2",
            name="phase2",
            assigned_expert="member1",
            task_description="task 2",
            depends_on=["p1"],
            status=PhaseStatus.PENDING,
        )
        plan = _make_plan(plan_id="plan_resume", task="test resume", phases=[done, pending])

        # Persist the plan plus a checkpoint for the finished phase.
        cp = PipelineCheckpoint()
        await cp.save_plan(plan)
        await cp.save("plan_resume", done, "executing")

        team = _make_team_with_experts(expert_names=["lead", "member1"])
        self._give_lead_synthesis_gateway(team)
        orchestrator = TeamOrchestrator(team=team, checkpoint=cp)

        result = await orchestrator.resume("plan_resume")

        assert result["status"] == "completed"
        phase_results = result["phase_results"]
        assert "p1" in phase_results  # restored from the checkpoint
        assert "p2" in phase_results  # executed during resume

    @pytest.mark.asyncio
    async def test_resume_all_phases_completed_skips_execution(self):
        """When every phase is already complete, resume goes straight to synthesis."""
        done = PlanPhase(
            id="p1",
            name="phase1",
            assigned_expert="lead",
            task_description="task 1",
            status=PhaseStatus.COMPLETED,
            result={"content": "phase1 output"},
        )
        plan = _make_plan(plan_id="plan_all_done", task="test resume all done", phases=[done])

        cp = PipelineCheckpoint()
        await cp.save_plan(plan)
        await cp.save("plan_all_done", done, "executing")

        team = _make_team_with_experts()
        self._give_lead_synthesis_gateway(team)

        result = await TeamOrchestrator(team=team, checkpoint=cp).resume("plan_all_done")

        assert result["status"] == "completed"
        assert "p1" in result["phase_results"]

    @pytest.mark.asyncio
    async def test_resume_no_active_expert_returns_failed(self):
        """resume fails when no expert on the team is active."""
        waiting = PlanPhase(
            id="p1",
            name="phase1",
            assigned_expert="lead",
            task_description="task 1",
            status=PhaseStatus.PENDING,
        )
        plan = _make_plan(plan_id="plan_no_expert", task="test", phases=[waiting])

        cp = PipelineCheckpoint()
        await cp.save_plan(plan)

        # Deactivate every expert on the team.
        team = _make_team_with_experts()
        for member in team._experts.values():
            member.is_active = False

        result = await TeamOrchestrator(team=team, checkpoint=cp).resume("plan_no_expert")

        assert result["status"] == "failed"
        assert "No active expert" in result["error"]
|