329 lines
12 KiB
Python
329 lines
12 KiB
Python
"""CollaborationPlan 数据模型单元测试"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import pytest
|
||
|
||
from agentkit.experts.plan import (
|
||
CollaborationPlan,
|
||
MergeStrategy,
|
||
ParallelType,
|
||
PhaseStatus,
|
||
PlanPhase,
|
||
PlanStatus,
|
||
)
|
||
|
||
|
||
# ── Helpers ───────────────────────────────────────────────


def _make_phase(
    id: str = "phase_1",
    name: str = "分析阶段",
    assigned_expert: str = "analyst",
    task_description: str = "分析需求",
    depends_on: list[str] | None = None,
    parallel_type: ParallelType = ParallelType.SERIAL,
    merge_strategy: MergeStrategy | None = None,
    milestone: str = "",
    status: PhaseStatus = PhaseStatus.PENDING,
    result: dict | None = None,
) -> PlanPhase:
    """Build a ``PlanPhase`` for tests; every field can be overridden by keyword.

    The defaults describe a PENDING, SERIAL phase with no dependencies, no
    merge strategy, an empty milestone and no result.

    ``depends_on`` is normalized so that ``None`` (or any other falsy value)
    becomes a brand-new empty list. This way the default case never shares
    one list object between phases.

    Note: the first parameter is named ``id`` (shadowing the builtin) to mirror
    the ``PlanPhase`` field; callers in this module pass it by keyword.
    """
    dependencies = depends_on if depends_on else []
    return PlanPhase(
        id=id,
        name=name,
        assigned_expert=assigned_expert,
        task_description=task_description,
        depends_on=dependencies,
        parallel_type=parallel_type,
        merge_strategy=merge_strategy,
        milestone=milestone,
        status=status,
        result=result,
    )


def _make_valid_plan() -> CollaborationPlan:
    """Return a valid DRAFT plan made of three chained phases: p1 -> p2 -> p3.

    Each phase depends only on the one before it. All phases use the
    ``_make_phase`` defaults for parallel type (SERIAL) and status (PENDING).
    The plan is led by ``architect`` and carries ``{"project": "fischer"}``
    as its variables.
    """
    # (phase id, display name, assigned expert, task description), in execution order.
    chain = [
        ("p1", "需求分析", "analyst", "分析需求"),
        ("p2", "架构设计", "architect", "设计架构"),
        ("p3", "代码实现", "coder", "编写代码"),
    ]

    phases = []
    previous_id = None
    for phase_id, phase_name, expert, description in chain:
        phases.append(
            _make_phase(
                id=phase_id,
                name=phase_name,
                assigned_expert=expert,
                task_description=description,
                # The first phase has no predecessor; the rest depend on the previous one.
                depends_on=[previous_id] if previous_id else None,
            )
        )
        previous_id = phase_id

    return CollaborationPlan(
        id="plan_001",
        task="实现用户登录功能",
        phases=phases,
        variables={"project": "fischer"},
        status=PlanStatus.DRAFT,
        lead_expert="architect",
    )


# ── PlanPhase tests ───────────────────────────────────────


class TestPlanPhase:
    """Unit tests for the PlanPhase data model."""

    # Every PlanPhase attribute, in the order the assertions check them.
    _FIELDS = (
        "id",
        "name",
        "assigned_expert",
        "task_description",
        "depends_on",
        "parallel_type",
        "merge_strategy",
        "milestone",
        "status",
        "result",
    )

    def test_creation_with_all_fields(self):
        """Each constructor argument ends up on the attribute of the same name."""
        phase = PlanPhase(
            id="phase_a",
            name="竞品分析",
            assigned_expert="analyst",
            task_description="分析竞品功能",
            depends_on=["phase_0"],
            parallel_type=ParallelType.COMPETITIVE_PARALLEL,
            merge_strategy=MergeStrategy.BEST,
            milestone="竞品报告完成",
            status=PhaseStatus.IN_PROGRESS,
            result={"report": "竞品分析报告"},
        )
        # Fresh literals (not the objects passed in) so the check is independent.
        expected = {
            "id": "phase_a",
            "name": "竞品分析",
            "assigned_expert": "analyst",
            "task_description": "分析竞品功能",
            "depends_on": ["phase_0"],
            "parallel_type": ParallelType.COMPETITIVE_PARALLEL,
            "merge_strategy": MergeStrategy.BEST,
            "milestone": "竞品报告完成",
            "status": PhaseStatus.IN_PROGRESS,
            "result": {"report": "竞品分析报告"},
        }
        for attr, value in expected.items():
            assert getattr(phase, attr) == value, attr

    def test_to_dict_from_dict_roundtrip(self):
        """from_dict(to_dict()) reproduces every field of a fully populated phase."""
        phase = PlanPhase(
            id="roundtrip_phase",
            name="往返测试",
            assigned_expert="tester",
            task_description="测试序列化",
            depends_on=["dep_a", "dep_b"],
            parallel_type=ParallelType.SUBTASK_PARALLEL,
            merge_strategy=MergeStrategy.VOTE,
            milestone="序列化验证",
            status=PhaseStatus.COMPLETED,
            result={"key": "value"},
        )
        restored = PlanPhase.from_dict(phase.to_dict())

        for attr in self._FIELDS:
            assert getattr(restored, attr) == getattr(phase, attr), attr

    def test_to_dict_from_dict_with_none_merge_strategy(self):
        """A missing merge_strategy serializes to None and deserializes back to None."""
        phase = PlanPhase(
            id="no_merge",
            name="无合并",
            assigned_expert="dev",
            task_description="串行任务",
            parallel_type=ParallelType.SERIAL,
        )
        serialized = phase.to_dict()
        assert serialized["merge_strategy"] is None
        assert PlanPhase.from_dict(serialized).merge_strategy is None


# ── CollaborationPlan tests ───────────────────────────────


class TestCollaborationPlan:
    """Unit tests for the CollaborationPlan data model."""

    def test_creation(self):
        """A plan exposes exactly the values it was constructed with."""
        plan = _make_valid_plan()
        assert plan.id == "plan_001"
        assert plan.task == "实现用户登录功能"
        assert len(plan.phases) == 3
        assert plan.variables == {"project": "fischer"}
        assert plan.status == PlanStatus.DRAFT
        assert plan.lead_expert == "architect"

    def test_to_dict_from_dict_roundtrip(self):
        """to_dict / from_dict round-trip preserves the plan and every phase field."""
        plan = _make_valid_plan()
        # Complete one phase first so the round-trip also covers a non-default
        # status and a non-None result, not just the factory defaults.
        plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"analysis": "done"})

        restored = CollaborationPlan.from_dict(plan.to_dict())

        assert restored.id == plan.id
        assert restored.task == plan.task
        assert len(restored.phases) == len(plan.phases)
        assert restored.variables == plan.variables
        assert restored.status == plan.status
        assert restored.lead_expert == plan.lead_expert

        # Compare every PlanPhase field (lengths were checked above, so zip
        # cannot silently truncate).
        for original, restored_phase in zip(plan.phases, restored.phases):
            assert restored_phase.id == original.id
            assert restored_phase.name == original.name
            assert restored_phase.assigned_expert == original.assigned_expert
            assert restored_phase.task_description == original.task_description
            assert restored_phase.depends_on == original.depends_on
            assert restored_phase.parallel_type == original.parallel_type
            assert restored_phase.merge_strategy == original.merge_strategy
            assert restored_phase.milestone == original.milestone
            assert restored_phase.status == original.status
            assert restored_phase.result == original.result

    def test_validate_valid_plan(self):
        """A well-formed plan validates with no errors."""
        plan = _make_valid_plan()
        errors = plan.validate()
        assert errors == []

    def test_validate_detects_duplicate_phase_ids(self):
        """validate reports phases that share the same ID."""
        phases = [
            _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"),
            _make_phase(id="p1", name="阶段2", assigned_expert="b", task_description="t2"),
        ]
        plan = CollaborationPlan(
            id="dup_plan", task="重复ID测试", phases=phases, lead_expert="a"
        )
        errors = plan.validate()
        assert any("重复的阶段 ID" in e for e in errors)

    def test_validate_detects_missing_depends_on_references(self):
        """validate reports depends_on entries that name no existing phase."""
        phases = [
            _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"),
            _make_phase(
                id="p2",
                name="阶段2",
                assigned_expert="b",
                task_description="t2",
                depends_on=["p1", "nonexistent"],
            ),
        ]
        plan = CollaborationPlan(
            id="missing_dep_plan", task="缺失依赖测试", phases=phases, lead_expert="a"
        )
        errors = plan.validate()
        assert any("不存在的阶段 ID" in e for e in errors)

    def test_validate_detects_circular_dependencies(self):
        """validate reports a dependency cycle (p1 -> p3 -> p2 -> p1)."""
        phases = [
            _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1", depends_on=["p3"]),
            _make_phase(id="p2", name="阶段2", assigned_expert="b", task_description="t2", depends_on=["p1"]),
            _make_phase(id="p3", name="阶段3", assigned_expert="c", task_description="t3", depends_on=["p2"]),
        ]
        plan = CollaborationPlan(
            id="cycle_plan", task="循环依赖测试", phases=phases, lead_expert="a"
        )
        errors = plan.validate()
        assert any("循环依赖" in e for e in errors)

    def test_validate_detects_competitive_parallel_without_merge_strategy(self):
        """validate reports a COMPETITIVE_PARALLEL phase that has no merge_strategy."""
        phases = [
            _make_phase(
                id="p1",
                name="竞争阶段",
                assigned_expert="a",
                task_description="竞争任务",
                parallel_type=ParallelType.COMPETITIVE_PARALLEL,
                merge_strategy=None,
            ),
        ]
        plan = CollaborationPlan(
            id="no_merge_plan", task="缺少合并策略测试", phases=phases, lead_expert="a"
        )
        errors = plan.validate()
        assert any("COMPETITIVE_PARALLEL" in e and "merge_strategy" in e for e in errors)

    def test_get_ready_phases_returns_phases_with_completed_dependencies(self):
        """get_ready_phases walks the chain as each dependency completes."""
        plan = _make_valid_plan()
        # Initially only p1 (no dependencies) is ready.
        ready = plan.get_ready_phases()
        assert len(ready) == 1
        assert ready[0].id == "p1"

        # Completing p1 unblocks p2.
        plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"analysis": "done"})
        ready = plan.get_ready_phases()
        assert len(ready) == 1
        assert ready[0].id == "p2"

        # Completing p2 unblocks p3.
        plan.update_phase_status("p2", PhaseStatus.COMPLETED, {"design": "done"})
        ready = plan.get_ready_phases()
        assert len(ready) == 1
        assert ready[0].id == "p3"

        # Once every phase is completed there is nothing left to schedule.
        plan.update_phase_status("p3", PhaseStatus.COMPLETED, {"code": "done"})
        assert len(plan.get_ready_phases()) == 0

    def test_get_ready_phases_returns_empty_when_dependencies_not_met(self):
        """get_ready_phases excludes phases whose dependencies are not COMPLETED."""
        phases = [
            _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"),
            _make_phase(
                id="p2",
                name="阶段2",
                assigned_expert="b",
                task_description="t2",
                depends_on=["p1"],
            ),
        ]
        plan = CollaborationPlan(
            id="dep_plan", task="依赖未满足测试", phases=phases, lead_expert="a"
        )
        # p2 depends on the unfinished p1, so only p1 (no dependencies) is ready.
        ready = plan.get_ready_phases()
        assert len(ready) == 1
        assert ready[0].id == "p1"

        # p1 IN_PROGRESS (not COMPLETED): p1 is no longer pending and p2 is still blocked.
        plan.update_phase_status("p1", PhaseStatus.IN_PROGRESS)
        ready = plan.get_ready_phases()
        assert len(ready) == 0

    def test_update_phase_status(self):
        """update_phase_status sets status/result; omitting result keeps the old one."""
        plan = _make_valid_plan()
        plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"output": "分析完成"})
        phase = plan.get_phase("p1")
        assert phase is not None
        assert phase.status == PhaseStatus.COMPLETED
        assert phase.result == {"output": "分析完成"}

        # A phase that never had a result still has none after a status-only update.
        plan.update_phase_status("p2", PhaseStatus.IN_PROGRESS)
        phase2 = plan.get_phase("p2")
        assert phase2 is not None
        assert phase2.status == PhaseStatus.IN_PROGRESS
        assert phase2.result is None

        # The p2 check above cannot tell "kept" apart from "overwritten with None",
        # because p2's result was None to begin with. Repeat a status-only update
        # on p1, which does have a result, to actually pin the no-overwrite rule.
        plan.update_phase_status("p1", PhaseStatus.COMPLETED)
        phase = plan.get_phase("p1")
        assert phase is not None
        assert phase.status == PhaseStatus.COMPLETED
        assert phase.result == {"output": "分析完成"}

    def test_get_phase_by_id(self):
        """get_phase returns the phase with the requested ID."""
        plan = _make_valid_plan()
        phase = plan.get_phase("p2")
        assert phase is not None
        assert phase.id == "p2"
        assert phase.name == "架构设计"

    def test_get_phase_with_nonexistent_id_returns_none(self):
        """get_phase returns None for an unknown ID."""
        plan = _make_valid_plan()
        phase = plan.get_phase("nonexistent")
        assert phase is None