feat: hub-and-spoke experts, tiered tool injection, unified event model (U3/U7/U10)

This commit is contained in:
chiguyong 2026-06-17 10:46:16 +08:00
parent 200174c5c7
commit bbedfff597
18 changed files with 2953 additions and 1869 deletions

View File

@ -3,6 +3,7 @@
from agentkit.core.base import BaseAgent from agentkit.core.base import BaseAgent
from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue
from agentkit.core.exceptions import ( from agentkit.core.exceptions import (
AgentAlreadyRegisteredError, AgentAlreadyRegisteredError,
AgentFrameworkError, AgentFrameworkError,
@ -29,12 +30,16 @@ from agentkit.core.protocol import (
AgentCapability, AgentCapability,
AgentStatus, AgentStatus,
CancellationToken, CancellationToken,
Event,
EvolutionEvent, EvolutionEvent,
HandoffMessage, HandoffMessage,
SessionEventType,
TaskEventType,
TaskMessage, TaskMessage,
TaskProgress, TaskProgress,
TaskResult, TaskResult,
TaskStatus, TaskStatus,
TurnEventType,
) )
# Optional: HeadroomCompressor — only available when headroom-ai is installed # Optional: HeadroomCompressor — only available when headroom-ai is installed
@ -80,4 +85,12 @@ __all__ = [
"TaskProgress", "TaskProgress",
"TaskResult", "TaskResult",
"TaskStatus", "TaskStatus",
# SQ/EQ 统一事件模型
"Event",
"SessionEventType",
"TaskEventType",
"TurnEventType",
"Submission",
"SubmissionQueue",
"EventQueue",
] ]

View File

@ -0,0 +1,244 @@
"""统一事件模型 - SQ/EQ 双队列实现
参考 Codex Session/Task/Turn 三级模型提供
- SubmissionQueue (SQ): 接收用户输入返回 task_id
- EventQueue (EQ): 推送 Agent 事件支持多订阅者广播和事件缓冲回放
集成点Phase 4 再做实际集成
- Portal WebSocket 可通过 EventQueue.subscribe() 订阅事件流并推送给前端
- CLI 可通过 EventQueue.subscribe() 订阅事件流并打印
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import AsyncIterator
from agentkit.core.protocol import Event
logger = logging.getLogger(__name__)
# 哨兵对象:用于通知订阅者队列已关闭(基于身份比较,不参与事件流)
_CLOSED_SENTINEL: Event = Event(
event_type="__closed__",
task_id="",
session_id="",
data={},
timestamp="",
)
@dataclass
class Submission:
"""用户提交的任务
SubmissionQueue.submit() 创建消费者通过 SubmissionQueue.drain() 获取
Attributes:
task_id: 唯一任务 ID
session_id: 会话 ID
content: 用户输入内容
created_at: 创建时间
cancelled: 是否已取消
"""
task_id: str
session_id: str
content: str
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
cancelled: bool = False
class SubmissionQueue:
"""提交队列 (SQ) - 接收用户输入
用户通过 submit() 提交输入消费者通过 drain() 获取提交
支持通过 task_id 取消提交
内部使用 asyncio.Queue 实现队列大小上限 1024 HandoffTransport 一致
"""
_MAX_QUEUE_SIZE: int = 1024
def __init__(self) -> None:
self._queue: asyncio.Queue[Submission] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
self._submissions: dict[str, Submission] = {}
self._cancelled_tasks: set[str] = set()
self._closed: bool = False
async def submit(self, content: str, session_id: str) -> str:
"""提交用户输入,返回 task_id
Args:
content: 用户输入内容
session_id: 会话 ID
Returns:
新生成的 task_idUUID4
Raises:
RuntimeError: 队列已关闭
"""
if self._closed:
raise RuntimeError("SubmissionQueue is closed")
task_id = str(uuid.uuid4())
submission = Submission(
task_id=task_id,
session_id=session_id,
content=content,
)
self._submissions[task_id] = submission
await self._queue.put(submission)
return task_id
async def drain(self) -> AsyncIterator[Submission]:
"""消费提交队列(异步生成器)
已取消的提交会被跳过当队列关闭时生成器结束
"""
while True:
submission = await self._queue.get()
if submission.task_id in self._cancelled_tasks:
continue
yield submission
async def cancel(self, task_id: str) -> bool:
"""取消任务
Args:
task_id: 要取消的任务 ID
Returns:
是否成功取消任务存在且未取消过
"""
if task_id not in self._submissions:
return False
if task_id in self._cancelled_tasks:
return False
self._cancelled_tasks.add(task_id)
self._submissions[task_id].cancelled = True
return True
@property
def is_closed(self) -> bool:
"""返回队列是否已关闭"""
return self._closed
def close(self) -> None:
"""关闭队列,不再接受新提交。
已在队列中的提交不受影响消费者仍可通过 drain() 获取
"""
self._closed = True
class EventQueue:
"""事件队列 (EQ) - 推送 Agent 事件
支持多订阅者广播模式每条事件会投递到所有活跃订阅者
新订阅者会收到最近 N 条缓冲事件的回放默认 100
集成点
- Portal WebSocket 可通过 subscribe() 订阅事件流并推送给前端
- CLI 可通过 subscribe() 订阅事件流并打印
"""
_MAX_QUEUE_SIZE: int = 1024
_DEFAULT_BUFFER_SIZE: int = 100
def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None:
self._subscribers: list[asyncio.Queue[Event]] = []
self._buffer: deque[Event] = deque(maxlen=buffer_size)
self._buffer_size = buffer_size
self._closed: bool = False
async def emit(self, event: Event) -> None:
"""推送事件给所有订阅者
事件会同时写入缓冲区供未来订阅者回放和所有活跃订阅者队列
如果某订阅者队列已满该事件对该订阅者被丢弃不影响其他订阅者
Args:
event: 要推送的事件
"""
self._buffer.append(event)
for queue in self._subscribers:
try:
queue.put_nowait(event)
except asyncio.QueueFull:
logger.warning("EventQueue subscriber queue full, dropping event")
async def subscribe(self) -> AsyncIterator[Event]:
"""订阅事件流(异步生成器)
订阅时会先回放缓冲区中的事件然后持续接收新事件
每个订阅者获得独立的队列实现广播语义
当队列关闭时生成器结束
注意回放和加入订阅者列表在同一同步段内完成 await
保证不会遗漏或重复事件
"""
if self._closed:
return
queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
# 回放缓冲事件(同步操作,无 await保证原子性
for event in list(self._buffer):
try:
queue.put_nowait(event)
except asyncio.QueueFull:
logger.warning("EventQueue replay buffer full, skipping remaining")
break
# 加入订阅者列表(在回放之后,确保不会收到重复事件)
self._subscribers.append(queue)
try:
while True:
event = await queue.get()
if event is _CLOSED_SENTINEL:
break
yield event
finally:
# 清理:移除当前订阅者的队列
if queue in self._subscribers:
self._subscribers.remove(queue)
@property
def subscriber_count(self) -> int:
"""返回当前订阅者数量"""
return len(self._subscribers)
@property
def buffer_size(self) -> int:
"""返回缓冲区大小上限"""
return self._buffer_size
@property
def is_closed(self) -> bool:
"""返回队列是否已关闭"""
return self._closed
def close(self) -> None:
"""关闭队列,向所有订阅者发送哨兵并清理状态。
使用哨兵模式参考 HandoffTransport确保阻塞在 subscribe() 上的
订阅者能够优雅退出
"""
self._closed = True
# 向所有活跃订阅者队列放入哨兵,使其能够优雅退出
for queue in self._subscribers:
try:
queue.put_nowait(_CLOSED_SENTINEL)
except asyncio.QueueFull:
pass
self._subscribers.clear()
self._buffer.clear()

View File

@ -10,6 +10,7 @@ from agentkit.core.exceptions import TaskCancelledError
class TaskStatus(str, Enum): class TaskStatus(str, Enum):
"""任务状态枚举""" """任务状态枚举"""
PENDING = "pending" PENDING = "pending"
RUNNING = "running" RUNNING = "running"
COMPLETED = "completed" COMPLETED = "completed"
@ -21,6 +22,7 @@ class TaskStatus(str, Enum):
class AgentStatus(str, Enum): class AgentStatus(str, Enum):
"""Agent 状态枚举""" """Agent 状态枚举"""
ONLINE = "online" ONLINE = "online"
OFFLINE = "offline" OFFLINE = "offline"
BUSY = "busy" BUSY = "busy"
@ -29,6 +31,7 @@ class AgentStatus(str, Enum):
@dataclass @dataclass
class AgentCapability: class AgentCapability:
"""Agent 能力声明""" """Agent 能力声明"""
agent_name: str agent_name: str
agent_type: str agent_type: str
version: str version: str
@ -70,6 +73,7 @@ class AgentCapability:
@dataclass @dataclass
class TaskMessage: class TaskMessage:
"""任务消息 - 从调度器发往 Agent""" """任务消息 - 从调度器发往 Agent"""
task_id: str task_id: str
agent_name: str agent_name: str
task_type: str task_type: str
@ -114,6 +118,7 @@ class TaskMessage:
@dataclass @dataclass
class TaskResult: class TaskResult:
"""任务结果 - 从 Agent 返回""" """任务结果 - 从 Agent 返回"""
task_id: str task_id: str
agent_name: str agent_name: str
status: str status: str
@ -163,6 +168,7 @@ class TaskResult:
@dataclass @dataclass
class TaskProgress: class TaskProgress:
"""进度上报 - Agent 执行过程中上报""" """进度上报 - Agent 执行过程中上报"""
task_id: str task_id: str
agent_name: str agent_name: str
progress: float progress: float
@ -195,6 +201,7 @@ class TaskProgress:
@dataclass @dataclass
class HandoffMessage: class HandoffMessage:
"""任务转交消息 - Agent 间 Handoff""" """任务转交消息 - Agent 间 Handoff"""
source_agent: str source_agent: str
target_agent: str target_agent: str
task_id: str task_id: str
@ -233,6 +240,7 @@ class HandoffMessage:
@dataclass @dataclass
class EvolutionEvent: class EvolutionEvent:
"""进化事件 - 记录 Agent 的自我进化变更""" """进化事件 - 记录 Agent 的自我进化变更"""
agent_name: str agent_name: str
change_type: str # prompt / strategy / pipeline change_type: str # prompt / strategy / pipeline
before: dict[str, Any] before: dict[str, Any]
@ -277,3 +285,102 @@ class CancellationToken:
"""检查是否已取消,若已取消则抛出 TaskCancelledError""" """检查是否已取消,若已取消则抛出 TaskCancelledError"""
if self._cancelled: if self._cancelled:
raise TaskCancelledError(task_id="") raise TaskCancelledError(task_id="")
# ── SQ/EQ 统一事件模型 ──────────────────────────────────────────
# 参考 Codex 的 Session/Task/Turn 三级模型,统一 CLI 和 WebSocket 事件流。
class SessionEventType:
"""Session 级别事件类型
对应一次完整的会话生命周期如用户打开 CLI关闭 CLI
"""
SESSION_STARTED = "session.started"
SESSION_ENDED = "session.ended"
class TaskEventType:
"""Task 级别事件类型
对应一次任务提交用户输入从创建到终态完成/失败
"""
TASK_CREATED = "task.created"
TASK_STARTED = "task.started"
TASK_COMPLETED = "task.completed"
TASK_FAILED = "task.failed"
class TurnEventType:
"""Turn 级别事件类型
对应一个 Task 内的多轮对话/推理步骤包括思考工具调用Token 流等
"""
TURN_STARTED = "turn.started"
THINKING = "turn.thinking"
TOOL_CALL = "turn.tool_call"
TOOL_RESULT = "turn.tool_result"
TOKEN = "turn.token"
STEP = "turn.step"
FINAL_ANSWER = "turn.final_answer"
TURN_COMPLETED = "turn.completed"
@dataclass
class Event:
"""统一事件模型 - SQ/EQ 双队列事件
所有 CLI WebSocket 事件统一使用此结构便于前端和 CLI 统一处理
Attributes:
event_type: 事件类型 SessionEventType / TaskEventType / TurnEventType
task_id: 关联的任务 ID
session_id: 关联的会话 ID
data: 事件负载数据
timestamp: ISO 8601 格式时间戳
"""
event_type: str
task_id: str
session_id: str
data: dict[str, Any]
timestamp: str # ISO 8601 format
def to_dict(self) -> dict:
return {
"event_type": self.event_type,
"task_id": self.task_id,
"session_id": self.session_id,
"data": self.data,
"timestamp": self.timestamp,
}
@classmethod
def from_dict(cls, data: dict) -> "Event":
return cls(
event_type=data["event_type"],
task_id=data["task_id"],
session_id=data["session_id"],
data=data.get("data", {}),
timestamp=data["timestamp"],
)
@classmethod
def create(
cls,
event_type: str,
task_id: str,
session_id: str,
data: dict[str, Any] | None = None,
) -> "Event":
"""创建事件,自动生成 ISO 8601 格式时间戳"""
return cls(
event_type=event_type,
task_id=task_id,
session_id=session_id,
data=data or {},
timestamp=datetime.now(timezone.utc).isoformat(),
)

View File

@ -1,22 +1,23 @@
"""Expert 系统 - 专家团队模式的配置、模板、注册与协作计划""" """Expert 系统 - 专家团队 hub-and-spoke 模式
U3 简化从去中心化协作改为 Lead Expert + 并行 Task 模式
"""
from agentkit.experts.config import ExpertConfig, ExpertTemplate from agentkit.experts.config import ExpertConfig, ExpertTemplate
from agentkit.experts.expert import Expert from agentkit.experts.expert import Expert
from agentkit.experts.orchestrator import TeamOrchestrator from agentkit.experts.orchestrator import TeamOrchestrator
from agentkit.experts.plan import ( from agentkit.experts.plan import (
CollaborationPlan,
MergeStrategy, MergeStrategy,
ParallelType,
PhaseStatus,
PlanPhase,
PlanStatus, PlanStatus,
SubTask,
SubTaskStatus,
TeamPlan,
) )
from agentkit.experts.registry import ExpertTemplateRegistry from agentkit.experts.registry import ExpertTemplateRegistry
from agentkit.experts.router import ExpertTeamRouter, ExpertTeamRoutingResult from agentkit.experts.router import ExpertTeamRouter, ExpertTeamRoutingResult
from agentkit.experts.team import ExpertTeam, TeamStatus from agentkit.experts.team import ExpertTeam, TeamStatus
__all__ = [ __all__ = [
"CollaborationPlan",
"Expert", "Expert",
"ExpertConfig", "ExpertConfig",
"ExpertTeam", "ExpertTeam",
@ -25,10 +26,10 @@ __all__ = [
"ExpertTemplate", "ExpertTemplate",
"ExpertTemplateRegistry", "ExpertTemplateRegistry",
"MergeStrategy", "MergeStrategy",
"ParallelType",
"PhaseStatus",
"PlanPhase",
"PlanStatus", "PlanStatus",
"SubTask",
"SubTaskStatus",
"TeamOrchestrator", "TeamOrchestrator",
"TeamPlan",
"TeamStatus", "TeamStatus",
] ]

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import Any from typing import Any
from agentkit.core.config_driven import AgentConfig from agentkit.core.config_driven import AgentConfig

View File

@ -65,7 +65,9 @@ class Expert:
# 如果提供了团队上下文,修改 Agent 的 prompt 以注入团队角色信息 # 如果提供了团队上下文,修改 Agent 的 prompt 以注入团队角色信息
if team_context and hasattr(agent, "_prompt_template") and agent._prompt_template: if team_context and hasattr(agent, "_prompt_template") and agent._prompt_template:
sections = agent._prompt_template._sections sections = agent._prompt_template._sections
sections.context = f"{team_context}\n\n{sections.context}" if sections.context else team_context sections.context = (
f"{team_context}\n\n{sections.context}" if sections.context else team_context
)
return expert return expert
@ -116,9 +118,7 @@ class Expert:
"reason": reason, "reason": reason,
"type": "assist_request", "type": "assist_request",
} }
await self._handoff_transport.send( await self._handoff_transport.send(f"expert:{target_expert}:handoff", handoff_msg)
f"expert:{target_expert}:handoff", handoff_msg
)
async def propose_plan_modification( async def propose_plan_modification(
self, self,

File diff suppressed because it is too large Load Diff

View File

@ -1,281 +1,172 @@
"""CollaborationPlan 数据模型 - Expert Team 协作蓝图 """Expert Team 计划数据模型 - hub-and-spoke 模式
定义 Expert Team 的结构化协作计划包括阶段角色分配依赖关系 简化为 Lead Expert + 并行 Task 模式
并行类型合并策略和里程碑 - Lead Expert 接收任务并分解为子任务
- 子任务并行执行Task 深度=1不能再 spawn Task
- Lead Expert 汇总结果BEST 策略
移除了阶段依赖图VOTE/FUSION 合并策略跨阶段状态共享
""" """
from __future__ import annotations from __future__ import annotations
import enum import enum
import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
class ParallelType(str, enum.Enum):
"""并行执行类型"""
SERIAL = "serial"
SUBTASK_PARALLEL = "subtask_parallel"
COMPETITIVE_PARALLEL = "competitive_parallel"
class MergeStrategy(str, enum.Enum): class MergeStrategy(str, enum.Enum):
"""合并策略 - 仅用于 COMPETITIVE_PARALLEL 阶段""" """合并策略 - Lead Expert 用于选择最佳结果
hub-and-spoke 模式下仅保留 BEST 策略
Lead Expert 从所有子任务结果中选择或综合出最佳结果
"""
BEST = "best" BEST = "best"
VOTE = "vote"
FUSION = "fusion"
class PhaseStatus(str, enum.Enum):
"""阶段状态"""
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
class PlanStatus(str, enum.Enum): class PlanStatus(str, enum.Enum):
"""计划状态""" """计划状态"""
DRAFT = "draft" DRAFT = "draft"
CONFIRMED = "confirmed"
EXECUTING = "executing" EXECUTING = "executing"
COMPLETED = "completed" COMPLETED = "completed"
FAILED = "failed" FAILED = "failed"
FALLBACK = "fallback" FALLBACK = "fallback"
# DFS 着色常量 class SubTaskStatus(str, enum.Enum):
_WHITE = 0 # 未访问 """子任务状态"""
_GRAY = 1 # 正在访问(当前路径上)
_BLACK = 2 # 已完成访问 PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
@dataclass @dataclass
class PlanPhase: class SubTask:
"""协作计划中的单个阶段 """Lead Expert 分解出的子任务
hub-and-spoke 模式中Lead Expert 将原始任务分解为多个子任务
每个子任务由一个 Expert 并行执行子任务之间无依赖关系无通信
约束
- Task 深度=1子任务不能再 spawn 子任务
- 子任务之间无通信
- Lead Expert 持有所有状态
Attributes: Attributes:
id: 阶段标识符 id: 子任务标识符
name: 阶段显示名称 description: 子任务描述
assigned_expert: 分配到此阶段的 Expert 名称 assigned_expert: 分配的 Expert 名称
task_description: 此阶段完成的任务描述
depends_on: 依赖的阶段 ID 列表
parallel_type: 执行类型
merge_strategy: 合并策略 COMPETITIVE_PARALLEL 需要
milestone: 里程碑检查点描述
status: 当前状态 status: 当前状态
result: 阶段输出结果 result: 子任务输出结果
""" """
id: str id: str = field(default_factory=lambda: str(uuid.uuid4()))
name: str description: str = ""
assigned_expert: str assigned_expert: str = ""
task_description: str status: SubTaskStatus = SubTaskStatus.PENDING
depends_on: list[str] = field(default_factory=list) result: dict[str, Any] | None = None
parallel_type: ParallelType = ParallelType.SERIAL
merge_strategy: MergeStrategy | None = None
milestone: str = ""
status: PhaseStatus = PhaseStatus.PENDING
result: dict | None = None
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""序列化为字典""" """序列化为字典"""
return { return {
"id": self.id, "id": self.id,
"name": self.name, "description": self.description,
"assigned_expert": self.assigned_expert, "assigned_expert": self.assigned_expert,
"task_description": self.task_description,
"depends_on": self.depends_on,
"parallel_type": self.parallel_type.value,
"merge_strategy": self.merge_strategy.value if self.merge_strategy is not None else None,
"milestone": self.milestone,
"status": self.status.value, "status": self.status.value,
"result": self.result, "result": self.result,
} }
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> PlanPhase: def from_dict(cls, data: dict[str, Any]) -> SubTask:
"""从字典创建 PlanPhase""" """从字典创建 SubTask"""
merge_strategy = None
if data.get("merge_strategy") is not None:
merge_strategy = MergeStrategy(data["merge_strategy"])
return cls( return cls(
id=data["id"], id=data.get("id", str(uuid.uuid4())),
name=data["name"], description=data.get("description", ""),
assigned_expert=data["assigned_expert"], assigned_expert=data.get("assigned_expert", ""),
task_description=data["task_description"], status=SubTaskStatus(data.get("status", SubTaskStatus.PENDING.value)),
depends_on=data.get("depends_on", []),
parallel_type=ParallelType(data.get("parallel_type", ParallelType.SERIAL.value)),
merge_strategy=merge_strategy,
milestone=data.get("milestone", ""),
status=PhaseStatus(data.get("status", PhaseStatus.PENDING.value)),
result=data.get("result"), result=data.get("result"),
) )
@dataclass @dataclass
class CollaborationPlan: class TeamPlan:
"""Expert Team 协作计划 """Expert Team hub-and-spoke 执行计划
定义 Expert Team 的结构化协作蓝图包括阶段编排共享变量 Lead Expert 持有此计划包含分解的子任务列表
状态管理和依赖关系 与旧版 CollaborationPlan 不同此计划无阶段依赖图
所有子任务并行执行 Lead Expert 统一汇总
Attributes: Attributes:
id: 计划标识符 id: 计划标识符
task: 原始任务描述 task: 原始任务描述
phases: 有序阶段列表 subtasks: 子任务列表并行执行无依赖关系
variables: 共享变量
status: 计划状态 status: 计划状态
lead_expert: 主导 Expert 名称 lead_expert: 主导 Expert 名称
""" """
id: str id: str = field(default_factory=lambda: str(uuid.uuid4()))
task: str task: str = ""
phases: list[PlanPhase] = field(default_factory=list) subtasks: list[SubTask] = field(default_factory=list)
variables: dict = field(default_factory=dict)
status: PlanStatus = PlanStatus.DRAFT status: PlanStatus = PlanStatus.DRAFT
lead_expert: str = "" lead_expert: str = ""
_phase_index: dict[str, PlanPhase] = field(default_factory=dict, init=False, repr=False)
def __post_init__(self) -> None:
"""Build the phase index after initialization."""
self._rebuild_index()
def _rebuild_index(self) -> None:
"""Rebuild the phase index from the phases list."""
self._phase_index = {phase.id: phase for phase in self.phases}
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""序列化为字典""" """序列化为字典"""
return { return {
"id": self.id, "id": self.id,
"task": self.task, "task": self.task,
"phases": [phase.to_dict() for phase in self.phases], "subtasks": [st.to_dict() for st in self.subtasks],
"variables": self.variables,
"status": self.status.value, "status": self.status.value,
"lead_expert": self.lead_expert, "lead_expert": self.lead_expert,
} }
@classmethod @classmethod
def from_dict(cls, data: dict[str, Any]) -> CollaborationPlan: def from_dict(cls, data: dict[str, Any]) -> TeamPlan:
"""从字典创建 CollaborationPlan""" """从字典创建 TeamPlan"""
phases = [PlanPhase.from_dict(p) for p in data.get("phases", [])] subtasks = [SubTask.from_dict(st) for st in data.get("subtasks", [])]
return cls( return cls(
id=data["id"], id=data.get("id", str(uuid.uuid4())),
task=data["task"], task=data.get("task", ""),
phases=phases, subtasks=subtasks,
variables=data.get("variables", {}),
status=PlanStatus(data.get("status", PlanStatus.DRAFT.value)), status=PlanStatus(data.get("status", PlanStatus.DRAFT.value)),
lead_expert=data.get("lead_expert", ""), lead_expert=data.get("lead_expert", ""),
) )
def validate(self) -> list[str]: def get_subtask(self, subtask_id: str) -> SubTask | None:
"""验证计划,返回错误消息列表(空列表表示有效) """根据 ID 获取子任务,不存在则返回 None"""
for st in self.subtasks:
if st.id == subtask_id:
return st
return None
检查项 def update_subtask_status(
- 无重复阶段 ID self, subtask_id: str, status: SubTaskStatus, result: dict[str, Any] | None = None
- 所有 depends_on 引用存在
- 无循环依赖DFS 着色检测
- COMPETITIVE_PARALLEL 阶段必须有 merge_strategy
"""
errors: list[str] = []
# 构建阶段 ID 集合
phase_ids = {phase.id for phase in self.phases}
# 检查重复阶段 ID
seen_ids: set[str] = set()
for phase in self.phases:
if phase.id in seen_ids:
errors.append(f"重复的阶段 ID: {phase.id}")
seen_ids.add(phase.id)
# 检查 depends_on 引用是否存在
for phase in self.phases:
for dep_id in phase.depends_on:
if dep_id not in phase_ids:
errors.append(
f"阶段 '{phase.id}' 依赖了不存在的阶段 ID: {dep_id}"
)
# 检查循环依赖(迭代 DFS 着色 — 避免递归栈溢出)
color: dict[str, int] = {phase.id: _WHITE for phase in self.phases}
dep_map: dict[str, list[str]] = {
phase.id: phase.depends_on for phase in self.phases
}
for phase in self.phases:
if color[phase.id] != _WHITE:
continue
# Iterative DFS using an explicit stack
stack: list[tuple[str, bool]] = [(phase.id, False)]
while stack:
node, is_backtrack = stack.pop()
if is_backtrack:
color[node] = _BLACK
continue
if color[node] == _GRAY:
# Already on current path — cycle detected
errors.append("检测到循环依赖")
break
if color[node] == _BLACK:
continue
color[node] = _GRAY
# Push backtrack marker
stack.append((node, True))
for neighbor in dep_map.get(node, []):
if neighbor not in color:
continue
if color[neighbor] == _GRAY:
errors.append("检测到循环依赖")
break
if color[neighbor] == _WHITE:
stack.append((neighbor, False))
else:
continue
break # Inner break propagates to outer
if errors:
break # Only report cycle once
# 检查 COMPETITIVE_PARALLEL 必须有 merge_strategy
for phase in self.phases:
if (
phase.parallel_type == ParallelType.COMPETITIVE_PARALLEL
and phase.merge_strategy is None
):
errors.append(
f"阶段 '{phase.id}' 为 COMPETITIVE_PARALLEL 但未设置 merge_strategy"
)
return errors
def get_ready_phases(self) -> list[PlanPhase]:
"""获取所有依赖已完成且状态为 PENDING 的阶段"""
completed_ids = {
phase.id for phase in self.phases if phase.status == PhaseStatus.COMPLETED
}
ready: list[PlanPhase] = []
for phase in self.phases:
if phase.status != PhaseStatus.PENDING:
continue
if all(dep_id in completed_ids for dep_id in phase.depends_on):
ready.append(phase)
return ready
def get_phase(self, phase_id: str) -> PlanPhase | None:
"""根据 ID 获取阶段,不存在则返回 None (O(1) lookup)"""
return self._phase_index.get(phase_id)
def update_phase_status(
self, phase_id: str, status: PhaseStatus, result: dict | None = None
) -> None: ) -> None:
"""更新阶段状态和可选的结果""" """更新子任务状态和可选的结果"""
phase = self.get_phase(phase_id) st = self.get_subtask(subtask_id)
if phase is not None: if st is not None:
phase.status = status st.status = status
if result is not None: if result is not None:
phase.result = result st.result = result
@property
def completed_subtasks(self) -> list[SubTask]:
"""已完成的子任务列表"""
return [st for st in self.subtasks if st.status == SubTaskStatus.COMPLETED]
@property
def failed_subtasks(self) -> list[SubTask]:
"""失败的子任务列表"""
return [st for st in self.subtasks if st.status == SubTaskStatus.FAILED]
@property
def all_done(self) -> bool:
"""所有子任务是否都已完成(成功或失败)"""
return all(
st.status in (SubTaskStatus.COMPLETED, SubTaskStatus.FAILED) for st in self.subtasks
)

View File

@ -4,12 +4,11 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Any
import yaml import yaml
from agentkit.core.exceptions import ConfigValidationError from agentkit.core.exceptions import ConfigValidationError
from agentkit.experts.config import ExpertConfig, ExpertTemplate from agentkit.experts.config import ExpertTemplate
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -62,10 +61,7 @@ class ExpertTemplateRegistry:
query_lower = query.lower() query_lower = query.lower()
results: list[ExpertTemplate] = [] results: list[ExpertTemplate] = []
for template in self._templates.values(): for template in self._templates.values():
if ( if query_lower in template.name.lower() or query_lower in template.description.lower():
query_lower in template.name.lower()
or query_lower in template.description.lower()
):
results.append(template) results.append(template)
return results return results
@ -134,7 +130,5 @@ class ExpertTemplateRegistry:
template = self.load_from_yaml(filepath) template = self.load_from_yaml(filepath)
loaded.append(template) loaded.append(template)
except Exception as e: except Exception as e:
logger.warning( logger.warning(f"Failed to load ExpertTemplate from '{filepath}': {e}")
f"Failed to load ExpertTemplate from '{filepath}': {e}"
)
return loaded return loaded

View File

@ -1,7 +1,13 @@
"""ExpertTeam - 专家团队容器 """ExpertTeam - 专家团队容器hub-and-spoke 模式)
管理 Expert 生命周期共享上下文协作计划和团队状态 管理 Expert 生命周期团队状态和事件广播
Expert Team 协作模式的中央协调点 Expert Team hub-and-spoke 协作模式的中央协调点
简化说明U3
- 移除 CollaborationPlan 依赖Lead Expert 自主分解任务
- 移除跨阶段状态共享Lead Expert 持有所有状态
- 保留 handoff_transport 用于事件广播不再用于 Agent 间通信
- 保留 workspace 用于输出保存不再用于跨阶段状态共享
""" """
from __future__ import annotations from __future__ import annotations
@ -14,7 +20,6 @@ import uuid
from .config import ExpertConfig from .config import ExpertConfig
from .expert import Expert from .expert import Expert
from .plan import CollaborationPlan, PlanStatus
from .registry import ExpertTemplateRegistry from .registry import ExpertTemplateRegistry
from ..core.handoff_transport import InProcessHandoffTransport from ..core.handoff_transport import InProcessHandoffTransport
from ..core.shared_workspace import SharedWorkspace from ..core.shared_workspace import SharedWorkspace
@ -27,7 +32,6 @@ class TeamStatus(str, enum.Enum):
"""ExpertTeam lifecycle states.""" """ExpertTeam lifecycle states."""
FORMING = "forming" FORMING = "forming"
PLANNING = "planning"
EXECUTING = "executing" EXECUTING = "executing"
SYNTHESIZING = "synthesizing" SYNTHESIZING = "synthesizing"
COMPLETED = "completed" COMPLETED = "completed"
@ -35,7 +39,14 @@ class TeamStatus(str, enum.Enum):
class ExpertTeam: class ExpertTeam:
"""Container managing a team of Experts working together on a task.""" """Container managing a team of Experts in hub-and-spoke mode.
In hub-and-spoke mode:
- Lead Expert (hub) receives the task and decomposes it
- Member Experts (spokes) execute subtasks in parallel
- Lead Expert synthesizes the final result
- No inter-agent communication (Lead Expert holds all state)
"""
def __init__( def __init__(
self, self,
@ -51,7 +62,6 @@ class ExpertTeam:
self._handoff_transport = InProcessHandoffTransport() self._handoff_transport = InProcessHandoffTransport()
self._experts: dict[str, Expert] = {} self._experts: dict[str, Expert] = {}
self._lead_expert_name: str | None = None self._lead_expert_name: str | None = None
self._plan: CollaborationPlan | None = None
self._status = TeamStatus.FORMING self._status = TeamStatus.FORMING
self._team_channel = f"team:{self.team_id}" self._team_channel = f"team:{self.team_id}"
self._orchestrator_task: asyncio.Task | None = None self._orchestrator_task: asyncio.Task | None = None
@ -66,10 +76,6 @@ class ExpertTeam:
return self._experts.get(self._lead_expert_name) return self._experts.get(self._lead_expert_name)
return None return None
@property
def plan(self) -> CollaborationPlan | None:
return self._plan
@property @property
def experts(self) -> list[Expert]: def experts(self) -> list[Expert]:
return list(self._experts.values()) return list(self._experts.values())
@ -80,12 +86,21 @@ class ExpertTeam:
@property @property
def workspace(self) -> SharedWorkspace: def workspace(self) -> SharedWorkspace:
"""Public read access to the team's shared workspace.""" """Public read access to the team's shared workspace.
In hub-and-spoke mode, workspace is used for output preservation only,
not for cross-phase state sharing.
"""
return self._workspace return self._workspace
@property @property
def handoff_transport(self): def handoff_transport(self):
"""Public read access to the team's handoff transport.""" """Public read access to the team's handoff transport.
In hub-and-spoke mode, handoff_transport is used for event broadcasting
(team_formed, expert_step, expert_result, team_synthesis, team_dissolved),
not for inter-agent communication.
"""
return self._handoff_transport return self._handoff_transport
@property @property
@ -106,7 +121,13 @@ class ExpertTeam:
lead_config: ExpertConfig, lead_config: ExpertConfig,
member_configs: list[ExpertConfig] | None = None, member_configs: list[ExpertConfig] | None = None,
) -> None: ) -> None:
"""Create a team with a Lead Expert and optional members.""" """Create a team with a Lead Expert and optional members.
In hub-and-spoke mode, the Lead Expert acts as the hub:
- Receives the task and decomposes it into subtasks
- Dispatches subtasks to member experts (spokes)
- Synthesizes the final result
"""
if not self._pool: if not self._pool:
raise RuntimeError("AgentPool not configured") raise RuntimeError("AgentPool not configured")
@ -128,7 +149,7 @@ class ExpertTeam:
for config in member_configs: for config in member_configs:
await self._add_expert_internal(config, team_context) await self._add_expert_internal(config, team_context)
self._status = TeamStatus.PLANNING self._status = TeamStatus.EXECUTING
async def add_expert(self, config_or_template: ExpertConfig | str) -> Expert: async def add_expert(self, config_or_template: ExpertConfig | str) -> Expert:
"""Add an Expert to the team dynamically. """Add an Expert to the team dynamically.
@ -155,9 +176,7 @@ class ExpertTeam:
) )
return await self._add_expert_internal(config, team_context) return await self._add_expert_internal(config, team_context)
async def _add_expert_internal( async def _add_expert_internal(self, config: ExpertConfig, team_context: str) -> Expert:
self, config: ExpertConfig, team_context: str
) -> Expert:
"""Internal method to add an Expert.""" """Internal method to add an Expert."""
if not self._pool: if not self._pool:
raise RuntimeError("AgentPool not configured") raise RuntimeError("AgentPool not configured")
@ -204,13 +223,6 @@ class ExpertTeam:
await expert.destroy(self._pool) await expert.destroy(self._pool)
del self._experts[name] del self._experts[name]
# Update plan: reassign phases that referenced the removed expert
if self._plan:
new_lead_name = self._lead_expert_name
for phase in self._plan.phases:
if phase.assigned_expert == name:
phase.assigned_expert = new_lead_name or ""
# Broadcast expert left # Broadcast expert left
await self._handoff_transport.send( await self._handoff_transport.send(
self._team_channel, self._team_channel,
@ -220,24 +232,6 @@ class ExpertTeam:
}, },
) )
def update_plan(self, plan: CollaborationPlan) -> list[str]:
"""Update the collaboration plan. Only Lead Expert or user should call this.
Returns list of affected expert names on success, or list of validation
error strings on failure (empty list with no errors = success).
"""
errors = plan.validate()
if errors:
return errors # Return validation errors instead of silently swallowing
self._plan = plan
if plan.status == PlanStatus.CONFIRMED:
self._status = TeamStatus.EXECUTING
# Determine affected experts
affected = [p.assigned_expert for p in plan.phases]
return affected
async def broadcast_user_message(self, content: str) -> None: async def broadcast_user_message(self, content: str) -> None:
"""Broadcast a user intervention message to all active Experts.""" """Broadcast a user intervention message to all active Experts."""
message = { message = {
@ -248,7 +242,11 @@ class ExpertTeam:
await self._handoff_transport.send(self._team_channel, message) await self._handoff_transport.send(self._team_channel, message)
async def get_shared_context(self) -> dict: async def get_shared_context(self) -> dict:
"""Get the team's shared context from SharedWorkspace.""" """Get the team's shared context from SharedWorkspace.
In hub-and-spoke mode, this returns preserved outputs only,
not cross-phase state.
"""
context = {} context = {}
keys = await self._workspace.list_keys() keys = await self._workspace.list_keys()
for key in keys: for key in keys:
@ -258,23 +256,6 @@ class ExpertTeam:
context[key] = data context[key] = data
return context return context
async def generate_plan(self, task: str) -> CollaborationPlan:
"""Generate a CollaborationPlan for the task.
Uses hybrid mode: core roles from template registry, auxiliary roles dynamically generated.
This method creates a plan structure the actual LLM-based task decomposition
will be handled by TeamOrchestrator.
"""
plan_id = str(uuid.uuid4())
plan = CollaborationPlan(
id=plan_id,
task=task,
phases=[],
lead_expert=self._lead_expert_name or "",
)
self._plan = plan
return plan
async def dissolve(self) -> None: async def dissolve(self) -> None:
"""Dissolve the team. Temporary Experts are recycled, outputs preserved in SharedWorkspace.""" """Dissolve the team. Temporary Experts are recycled, outputs preserved in SharedWorkspace."""
# Cancel ongoing orchestrator task if any # Cancel ongoing orchestrator task if any
@ -302,20 +283,31 @@ class ExpertTeam:
lead_config: ExpertConfig | None, lead_config: ExpertConfig | None,
member_configs: list[ExpertConfig], member_configs: list[ExpertConfig],
) -> str: ) -> str:
"""Build team context string for injection into Expert system prompts.""" """Build team context string for injection into Expert system prompts.
lines = ["You are part of an Expert Team."]
In hub-and-spoke mode, the context emphasizes the Lead Expert's role
as the hub and member experts' role as spokes.
"""
lines = ["You are part of an Expert Team (hub-and-spoke mode)."]
if lead_config: if lead_config:
lines.append(f"Lead Expert: {lead_config.name} ({lead_config.persona})") lines.append(
f"Lead Expert (hub): {lead_config.name} ({lead_config.persona}) — "
f"decomposes tasks, dispatches subtasks, and synthesizes results."
)
for config in member_configs: for config in member_configs:
if lead_config and config.name == lead_config.name: if lead_config and config.name == lead_config.name:
continue continue
lines.append( lines.append(
f"Team Member: {config.name} ({config.persona}), Skills: {', '.join(config.bound_skills)}" f"Team Member (spoke): {config.name} ({config.persona}), "
f"Skills: {', '.join(config.bound_skills)}"
f"executes assigned subtasks independently."
) )
lines.append( lines.append(
"You can collaborate with other team members via send_message() and request_assist()." "In hub-and-spoke mode: Lead Expert holds all state, "
"subtasks are independent (depth=1, no further spawning), "
"no inter-agent communication."
) )
return "\n".join(lines) return "\n".join(lines)

View File

@ -16,6 +16,8 @@ from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
from agentkit.tools.ask_human import AskHumanTool from agentkit.tools.ask_human import AskHumanTool
from agentkit.tools.memory_tool import MemoryTool from agentkit.tools.memory_tool import MemoryTool
from agentkit.tools.web_search import WebSearchTool from agentkit.tools.web_search import WebSearchTool
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
from agentkit.tools.search import ToolSearchIndex
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
try: try:
@ -40,6 +42,9 @@ __all__ = [
"MemoryTool", "MemoryTool",
"ShellTool", "ShellTool",
"WebSearchTool", "WebSearchTool",
"RunTestsTool",
"ToolSearchTool",
"ToolSearchIndex",
"HeadroomRetrieveTool", "HeadroomRetrieveTool",
"TerminalSession", "TerminalSession",
"TerminalSessionManager", "TerminalSessionManager",

View File

@ -0,0 +1,189 @@
"""Tool search index using BM25 algorithm.
Provides lexical tool discovery for tiered tool description injection:
core tools are fully injected into the LLM prompt, while extended tools
are searchable on-demand via the ``tool_search`` tool.
Pure Python implementation no external dependencies.
"""
from __future__ import annotations
import math
import re
from collections import Counter
from typing import Any
from agentkit.tools.base import Tool
__all__ = ["ToolSearchIndex"]
_TOKEN_RE = re.compile(r"[a-zA-Z0-9_]+")
def _tokenize(text: str) -> list[str]:
"""Tokenize text into lowercase alphanumeric tokens.
Splits on non-alphanumeric characters (keeping underscores so that
snake_case tool names like ``read_file`` stay intact) and lowercases.
Empty tokens are filtered out.
"""
return [t for t in _TOKEN_RE.findall(text.lower()) if t]
class ToolSearchIndex:
"""BM25-based search index over :class:`Tool` descriptions.
Each tool is indexed by its ``name`` + ``description`` + parameter
metadata (from ``input_schema``) + ``tags``. Supports keyword search
returning the most relevant tools ranked by BM25 score.
Usage::
index = ToolSearchIndex(tools)
results = index.search("read file", top_k=5)
"""
_DEFAULT_K1: float = 1.5
_DEFAULT_B: float = 0.75
def __init__(
self,
tools: list[Tool],
k1: float = _DEFAULT_K1,
b: float = _DEFAULT_B,
):
if k1 < 0:
raise ValueError(f"k1 must be >= 0, got {k1}")
if not (0.0 <= b <= 1.0):
raise ValueError(f"b must be in [0, 1], got {b}")
self._tools: list[Tool] = list(tools)
self._k1 = k1
self._b = b
# Build document corpus from tool metadata
self._documents: list[str] = [self._tool_to_text(t) for t in self._tools]
self._tokenized_docs: list[list[str]] = [_tokenize(doc) for doc in self._documents]
self._doc_lengths: list[int] = [len(tokens) for tokens in self._tokenized_docs]
self._avgdl: float = (
sum(self._doc_lengths) / len(self._doc_lengths) if self._doc_lengths else 0.0
)
# Term frequencies per document
self._term_freqs: list[Counter[str]] = [Counter(tokens) for tokens in self._tokenized_docs]
# Document frequency per term (number of docs containing the term)
self._doc_freq: Counter[str] = Counter()
for tokens in self._tokenized_docs:
for term in set(tokens):
self._doc_freq[term] += 1
# Precompute IDF for every term in the corpus
self._idf: dict[str, float] = self._compute_idf()
# ------------------------------------------------------------------
# Index construction
# ------------------------------------------------------------------
@staticmethod
def _tool_to_text(tool: Tool) -> str:
"""Convert a tool's searchable metadata into a single text document."""
parts: list[str] = [str(tool.name), str(tool.description)]
schema: dict[str, Any] | None = tool.input_schema
if schema:
props = schema.get("properties", {})
for pname, pinfo in props.items():
parts.append(str(pname))
if isinstance(pinfo, dict):
pdesc = pinfo.get("description", "")
if pdesc:
parts.append(str(pdesc))
if tool.tags:
parts.extend(str(t) for t in tool.tags)
return " ".join(parts)
def _compute_idf(self) -> dict[str, float]:
"""Compute IDF for each term using the BM25 formula.
``idf(t) = ln((N - n + 0.5) / (n + 0.5) + 1)``
where ``N`` is the total number of documents and ``n`` is the
number of documents containing term ``t``. The ``+1`` inside the
logarithm guarantees a non-negative IDF.
"""
n_docs = len(self._tools)
if n_docs == 0:
return {}
idf: dict[str, float] = {}
for term, df in self._doc_freq.items():
idf[term] = math.log((n_docs - df + 0.5) / (df + 0.5) + 1.0)
return idf
# ------------------------------------------------------------------
# Scoring
# ------------------------------------------------------------------
def _bm25_score(self, query_tokens: list[str], doc_idx: int) -> float:
"""Compute the BM25 score for a single document against the query."""
if not query_tokens or doc_idx >= len(self._tools):
return 0.0
doc_len = self._doc_lengths[doc_idx]
term_freq = self._term_freqs[doc_idx]
norm = (
(1 - self._b + self._b * (doc_len / self._avgdl)) if self._avgdl > 0 else (1 - self._b)
)
score = 0.0
# Each unique query term contributes once
for term in set(query_tokens):
tf = term_freq.get(term, 0)
if tf == 0:
continue
idf = self._idf.get(term, 0.0)
if idf <= 0:
continue
denom = tf + self._k1 * norm
score += idf * (tf * (self._k1 + 1)) / denom
return score
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def search(self, query: str, top_k: int = 5) -> list[Tool]:
"""Search tools by keyword query.
Args:
query: Search keywords (natural language or tool-name fragments).
top_k: Maximum number of tools to return.
Returns:
Tools ranked by relevance (most relevant first). Only tools
with a positive BM25 score are returned.
"""
if top_k <= 0 or not self._tools:
return []
query_tokens = _tokenize(query)
if not query_tokens:
return []
scored: list[tuple[int, float]] = []
for i in range(len(self._tools)):
score = self._bm25_score(query_tokens, i)
if score > 0:
scored.append((i, score))
scored.sort(key=lambda x: x[1], reverse=True)
return [self._tools[i] for i, _ in scored[:top_k]]
def __len__(self) -> int:
return len(self._tools)

View File

@ -0,0 +1,666 @@
"""Tests for EventQueue — SQ/EQ 双队列实现
测试场景
- SQ 正确接收用户输入并返回 task_id
- EQ 正确推送事件给订阅者
- 多订阅者同时接收事件广播
- 事件缓冲对新订阅者的回放
- SQ 取消任务
- 事件类型正确分类
"""
from __future__ import annotations
import asyncio
from datetime import datetime
from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue
from agentkit.core.protocol import (
Event,
SessionEventType,
TaskEventType,
TurnEventType,
)
# ── SubmissionQueue Tests ───────────────────────────────────────
class TestSubmissionQueue:
"""SubmissionQueue 单元测试"""
async def test_submit_returns_task_id(self):
"""测试 submit 返回有效的 task_id"""
sq = SubmissionQueue()
task_id = await sq.submit("hello", "session-1")
assert isinstance(task_id, str)
assert len(task_id) > 0
async def test_submit_returns_unique_task_ids(self):
"""测试每次 submit 返回不同的 task_id"""
sq = SubmissionQueue()
task_id_1 = await sq.submit("hello", "session-1")
task_id_2 = await sq.submit("world", "session-1")
assert task_id_1 != task_id_2
async def test_submit_stores_submission(self):
"""测试 submit 正确存储提交内容"""
sq = SubmissionQueue()
task_id = await sq.submit("hello world", "session-1")
assert task_id in sq._submissions
submission = sq._submissions[task_id]
assert submission.content == "hello world"
assert submission.session_id == "session-1"
assert submission.task_id == task_id
assert submission.cancelled is False
async def test_drain_receives_submissions_in_order(self):
"""测试 drain 按提交顺序接收提交"""
sq = SubmissionQueue()
await sq.submit("first", "session-1")
await sq.submit("second", "session-1")
received: list[str] = []
async def consumer():
async for submission in sq.drain():
received.append(submission.content)
if len(received) >= 2:
break
consumer_task = asyncio.create_task(consumer())
await asyncio.wait_for(consumer_task, timeout=1.0)
assert received == ["first", "second"]
async def test_drain_preserves_submission_fields(self):
"""测试 drain 返回的 Submission 字段完整"""
sq = SubmissionQueue()
await sq.submit("hello", "session-1")
received: list[Submission] = []
async def consumer():
async for submission in sq.drain():
received.append(submission)
break
consumer_task = asyncio.create_task(consumer())
await asyncio.wait_for(consumer_task, timeout=1.0)
assert len(received) == 1
sub = received[0]
assert sub.content == "hello"
assert sub.session_id == "session-1"
assert isinstance(sub.task_id, str)
assert isinstance(sub.created_at, datetime)
async def test_cancel_task_succeeds(self):
"""测试取消已存在的任务"""
sq = SubmissionQueue()
task_id = await sq.submit("hello", "session-1")
result = await sq.cancel(task_id)
assert result is True
assert task_id in sq._cancelled_tasks
assert sq._submissions[task_id].cancelled is True
async def test_cancel_nonexistent_task_returns_false(self):
"""测试取消不存在的任务返回 False"""
sq = SubmissionQueue()
result = await sq.cancel("nonexistent-task-id")
assert result is False
async def test_cancel_already_cancelled_task_returns_false(self):
"""测试重复取消返回 False"""
sq = SubmissionQueue()
task_id = await sq.submit("hello", "session-1")
first_cancel = await sq.cancel(task_id)
second_cancel = await sq.cancel(task_id)
assert first_cancel is True
assert second_cancel is False
async def test_drain_skips_cancelled_submissions(self):
"""测试 drain 跳过已取消的提交"""
sq = SubmissionQueue()
task_id_1 = await sq.submit("first", "session-1")
await sq.submit("second", "session-1")
# 取消第一个提交
await sq.cancel(task_id_1)
received: list[str] = []
async def consumer():
async for submission in sq.drain():
received.append(submission.content)
if len(received) >= 1:
break
consumer_task = asyncio.create_task(consumer())
await asyncio.wait_for(consumer_task, timeout=1.0)
# 应只收到第二个提交
assert received == ["second"]
async def test_close_prevents_new_submissions(self):
"""测试关闭后不再接受新提交"""
sq = SubmissionQueue()
sq.close()
assert sq.is_closed is True
try:
await sq.submit("hello", "session-1")
raise AssertionError("Should have raised RuntimeError")
except RuntimeError:
pass
async def test_close_does_not_affect_existing_submissions(self):
"""测试关闭后已提交的内容仍可消费"""
sq = SubmissionQueue()
await sq.submit("before-close", "session-1")
sq.close()
received: list[str] = []
async def consumer():
async for submission in sq.drain():
received.append(submission.content)
break
consumer_task = asyncio.create_task(consumer())
await asyncio.wait_for(consumer_task, timeout=1.0)
assert received == ["before-close"]
# ── EventQueue Tests ────────────────────────────────────────────
class TestEventQueue:
"""EventQueue 单元测试"""
async def test_emit_and_subscribe_single_event(self):
"""测试 EQ 正确推送事件给订阅者"""
eq = EventQueue()
event = Event.create(
event_type=TurnEventType.TOKEN,
task_id="task-1",
session_id="session-1",
data={"text": "hello"},
)
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe():
received.append(evt)
break
sub_task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05) # 给订阅者启动时间
await eq.emit(event)
await asyncio.wait_for(sub_task, timeout=1.0)
assert len(received) == 1
assert received[0].event_type == TurnEventType.TOKEN
assert received[0].task_id == "task-1"
assert received[0].session_id == "session-1"
assert received[0].data == {"text": "hello"}
async def test_emit_preserves_event_fields(self):
"""测试 emit 不修改事件字段"""
eq = EventQueue()
original = Event.create(
event_type=TaskEventType.TASK_STARTED,
task_id="task-1",
session_id="session-1",
data={"agent": "react"},
)
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe():
received.append(evt)
break
sub_task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05)
await eq.emit(original)
await asyncio.wait_for(sub_task, timeout=1.0)
assert received[0] is original or received[0].to_dict() == original.to_dict()
async def test_broadcast_to_multiple_subscribers(self):
"""测试多订阅者同时接收事件(广播)"""
eq = EventQueue()
received_a: list[Event] = []
received_b: list[Event] = []
async def subscriber_a():
async for evt in eq.subscribe():
received_a.append(evt)
if len(received_a) >= 2:
break
async def subscriber_b():
async for evt in eq.subscribe():
received_b.append(evt)
if len(received_b) >= 2:
break
task_a = asyncio.create_task(subscriber_a())
task_b = asyncio.create_task(subscriber_b())
await asyncio.sleep(0.05) # 给订阅者启动时间
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2}))
await asyncio.wait_for(task_a, timeout=1.0)
await asyncio.wait_for(task_b, timeout=1.0)
assert len(received_a) == 2
assert len(received_b) == 2
assert received_a[0].data == {"seq": 1}
assert received_a[1].data == {"seq": 2}
assert received_b[0].data == {"seq": 1}
assert received_b[1].data == {"seq": 2}
async def test_buffer_replay_for_new_subscriber(self):
"""测试事件缓冲对新订阅者的回放"""
eq = EventQueue(buffer_size=100)
# 先发送几条事件(无订阅者)
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2}))
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3}))
# 新订阅者应收到缓冲回放
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe():
received.append(evt)
if len(received) >= 3:
break
sub_task = asyncio.create_task(subscriber())
await asyncio.wait_for(sub_task, timeout=1.0)
assert len(received) == 3
assert received[0].data == {"seq": 1}
assert received[1].data == {"seq": 2}
assert received[2].data == {"seq": 3}
async def test_buffer_replay_then_live_events(self):
"""测试新订阅者先收到回放,再收到新事件"""
eq = EventQueue(buffer_size=100)
# 先发送 2 条事件(进入缓冲)
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2}))
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe():
received.append(evt)
if len(received) >= 4:
break
sub_task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05) # 给订阅者启动时间(回放缓冲)
# 再发送 2 条新事件
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3}))
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 4}))
await asyncio.wait_for(sub_task, timeout=1.0)
assert len(received) == 4
assert [r.data["seq"] for r in received] == [1, 2, 3, 4]
async def test_buffer_size_limit_keeps_latest(self):
"""测试缓冲区大小限制,只保留最新 N 条"""
eq = EventQueue(buffer_size=3)
# 发送 5 条事件,缓冲区只保留最后 3 条
for i in range(5):
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": i}))
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe():
received.append(evt)
if len(received) >= 3:
break
sub_task = asyncio.create_task(subscriber())
await asyncio.wait_for(sub_task, timeout=1.0)
assert len(received) == 3
# 应该是最后 3 条seq: 2, 3, 4
assert [r.data["seq"] for r in received] == [2, 3, 4]
async def test_default_buffer_size_is_100(self):
"""测试默认缓冲区大小为 100"""
eq = EventQueue()
assert eq.buffer_size == 100
async def test_close_unblocks_subscribers(self):
"""测试 close 解除订阅者阻塞"""
eq = EventQueue()
async def subscriber():
async for _ in eq.subscribe():
pass # 消费事件直到队列关闭
sub_task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05)
eq.close()
await asyncio.wait_for(sub_task, timeout=1.0)
assert sub_task.done()
assert eq.is_closed is True
async def test_subscribe_after_close_returns_immediately(self):
"""测试关闭后订阅立即返回(不阻塞)"""
eq = EventQueue()
eq.close()
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe():
received.append(evt)
sub_task = asyncio.create_task(subscriber())
await asyncio.wait_for(sub_task, timeout=1.0)
assert sub_task.done()
assert len(received) == 0
async def test_subscriber_count_tracks_subscriptions(self):
"""测试订阅者计数正确跟踪订阅"""
eq = EventQueue()
assert eq.subscriber_count == 0
async def subscriber():
async for _ in eq.subscribe():
pass
task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05)
assert eq.subscriber_count == 1
eq.close()
await asyncio.wait_for(task, timeout=1.0)
assert eq.subscriber_count == 0
async def test_subscriber_removed_on_explicit_close(self):
"""测试显式关闭订阅生成器后从列表移除
注意async for break 不会立即触发生成器的 finally
需要显式调用 aclose() 才能保证清理
"""
eq = EventQueue()
received: list[Event] = []
async def subscriber():
gen = eq.subscribe()
try:
async for evt in gen:
received.append(evt)
break
finally:
await gen.aclose()
task = asyncio.create_task(subscriber())
await asyncio.sleep(0.05)
assert eq.subscriber_count == 1
# 触发一次 emit 让订阅者能收到事件并 break
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
await asyncio.wait_for(task, timeout=1.0)
assert len(received) == 1
assert eq.subscriber_count == 0
async def test_emit_to_no_subscribers_still_buffers(self):
"""测试无订阅者时 emit 仍写入缓冲区"""
eq = EventQueue(buffer_size=100)
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
# 缓冲区应有 1 条
assert len(eq._buffer) == 1
# 新订阅者应收到回放
received: list[Event] = []
async def subscriber():
async for evt in eq.subscribe():
received.append(evt)
break
sub_task = asyncio.create_task(subscriber())
await asyncio.wait_for(sub_task, timeout=1.0)
assert len(received) == 1
assert received[0].data == {"seq": 1}
# ── Event Type Tests ────────────────────────────────────────────
class TestEventTypes:
"""事件类型分类测试"""
def test_session_event_types(self):
"""测试 Session 级别事件类型"""
assert SessionEventType.SESSION_STARTED == "session.started"
assert SessionEventType.SESSION_ENDED == "session.ended"
def test_task_event_types(self):
"""测试 Task 级别事件类型"""
assert TaskEventType.TASK_CREATED == "task.created"
assert TaskEventType.TASK_STARTED == "task.started"
assert TaskEventType.TASK_COMPLETED == "task.completed"
assert TaskEventType.TASK_FAILED == "task.failed"
def test_turn_event_types(self):
"""测试 Turn 级别事件类型"""
assert TurnEventType.TURN_STARTED == "turn.started"
assert TurnEventType.THINKING == "turn.thinking"
assert TurnEventType.TOOL_CALL == "turn.tool_call"
assert TurnEventType.TOOL_RESULT == "turn.tool_result"
assert TurnEventType.TOKEN == "turn.token"
assert TurnEventType.STEP == "turn.step"
assert TurnEventType.FINAL_ANSWER == "turn.final_answer"
assert TurnEventType.TURN_COMPLETED == "turn.completed"
def test_event_type_prefixes(self):
"""测试事件类型按前缀正确分类"""
session_events = [SessionEventType.SESSION_STARTED, SessionEventType.SESSION_ENDED]
task_events = [
TaskEventType.TASK_CREATED,
TaskEventType.TASK_STARTED,
TaskEventType.TASK_COMPLETED,
TaskEventType.TASK_FAILED,
]
turn_events = [
TurnEventType.TURN_STARTED,
TurnEventType.THINKING,
TurnEventType.TOOL_CALL,
TurnEventType.TOOL_RESULT,
TurnEventType.TOKEN,
TurnEventType.STEP,
TurnEventType.FINAL_ANSWER,
TurnEventType.TURN_COMPLETED,
]
for evt in session_events:
assert evt.startswith("session."), f"{evt} should start with 'session.'"
for evt in task_events:
assert evt.startswith("task."), f"{evt} should start with 'task.'"
for evt in turn_events:
assert evt.startswith("turn."), f"{evt} should start with 'turn.'"
def test_event_types_are_distinct(self):
"""测试所有事件类型互不相同"""
all_types = [
SessionEventType.SESSION_STARTED,
SessionEventType.SESSION_ENDED,
TaskEventType.TASK_CREATED,
TaskEventType.TASK_STARTED,
TaskEventType.TASK_COMPLETED,
TaskEventType.TASK_FAILED,
TurnEventType.TURN_STARTED,
TurnEventType.THINKING,
TurnEventType.TOOL_CALL,
TurnEventType.TOOL_RESULT,
TurnEventType.TOKEN,
TurnEventType.STEP,
TurnEventType.FINAL_ANSWER,
TurnEventType.TURN_COMPLETED,
]
assert len(all_types) == len(set(all_types)), "Event types should be distinct"
# ── Event Dataclass Tests ───────────────────────────────────────
class TestEventDataclass:
"""Event 数据结构测试"""
def test_event_creation(self):
"""测试 Event 创建"""
event = Event(
event_type=TurnEventType.TOKEN,
task_id="task-1",
session_id="session-1",
data={"text": "hello"},
timestamp="2025-01-01T00:00:00+00:00",
)
assert event.event_type == TurnEventType.TOKEN
assert event.task_id == "task-1"
assert event.session_id == "session-1"
assert event.data == {"text": "hello"}
assert event.timestamp == "2025-01-01T00:00:00+00:00"
def test_event_create_factory_generates_timestamp(self):
"""测试 Event.create 工厂方法自动生成时间戳"""
event = Event.create(
event_type=TaskEventType.TASK_STARTED,
task_id="task-1",
session_id="session-1",
data={"agent": "react"},
)
assert event.event_type == TaskEventType.TASK_STARTED
assert event.task_id == "task-1"
assert event.session_id == "session-1"
assert event.data == {"agent": "react"}
assert len(event.timestamp) > 0
# 时间戳应为 ISO 8601 格式fromisoformat 不抛异常即正确)
datetime.fromisoformat(event.timestamp)
def test_event_create_with_default_data(self):
"""测试 Event.create 不传 data 时默认为空 dict"""
event = Event.create(
event_type=SessionEventType.SESSION_STARTED,
task_id="task-1",
session_id="session-1",
)
assert event.data == {}
def test_event_to_dict(self):
"""测试 Event.to_dict"""
event = Event(
event_type=TurnEventType.TOKEN,
task_id="task-1",
session_id="session-1",
data={"text": "hello"},
timestamp="2025-01-01T00:00:00+00:00",
)
d = event.to_dict()
assert d == {
"event_type": TurnEventType.TOKEN,
"task_id": "task-1",
"session_id": "session-1",
"data": {"text": "hello"},
"timestamp": "2025-01-01T00:00:00+00:00",
}
def test_event_from_dict(self):
"""测试 Event.from_dict"""
data = {
"event_type": TurnEventType.TOKEN,
"task_id": "task-1",
"session_id": "session-1",
"data": {"text": "hello"},
"timestamp": "2025-01-01T00:00:00+00:00",
}
event = Event.from_dict(data)
assert event.event_type == TurnEventType.TOKEN
assert event.task_id == "task-1"
assert event.session_id == "session-1"
assert event.data == {"text": "hello"}
assert event.timestamp == "2025-01-01T00:00:00+00:00"
def test_event_to_dict_from_dict_roundtrip(self):
"""测试 to_dict -> from_dict 往返保持一致"""
original = Event.create(
event_type=SessionEventType.SESSION_STARTED,
task_id="task-1",
session_id="session-1",
data={"user": "alice"},
)
d = original.to_dict()
restored = Event.from_dict(d)
assert restored.event_type == original.event_type
assert restored.task_id == original.task_id
assert restored.session_id == original.session_id
assert restored.data == original.data
assert restored.timestamp == original.timestamp
def test_event_from_dict_with_missing_data_defaults_to_empty(self):
"""测试 from_dict 缺少 data 字段时默认为空 dict"""
data = {
"event_type": TurnEventType.TOKEN,
"task_id": "task-1",
"session_id": "session-1",
"timestamp": "2025-01-01T00:00:00+00:00",
}
event = Event.from_dict(data)
assert event.data == {}

View File

@ -1,328 +1,288 @@
"""CollaborationPlan 数据模型单元测试""" """TeamPlan / SubTask 数据模型单元测试 (hub-and-spoke 模式)"""
from __future__ import annotations from __future__ import annotations
import pytest
from agentkit.experts.plan import ( from agentkit.experts.plan import (
CollaborationPlan,
MergeStrategy, MergeStrategy,
ParallelType,
PhaseStatus,
PlanPhase,
PlanStatus, PlanStatus,
SubTask,
SubTaskStatus,
TeamPlan,
) )
# ── 辅助函数 ────────────────────────────────────────────── # ── 辅助函数 ──────────────────────────────────────────────
def _make_phase( def _make_subtask(
id: str = "phase_1", id: str = "subtask_1",
name: str = "分析阶段", description: str = "分析数据",
assigned_expert: str = "analyst", assigned_expert: str = "analyst",
task_description: str = "分析需求", status: SubTaskStatus = SubTaskStatus.PENDING,
depends_on: list[str] | None = None,
parallel_type: ParallelType = ParallelType.SERIAL,
merge_strategy: MergeStrategy | None = None,
milestone: str = "",
status: PhaseStatus = PhaseStatus.PENDING,
result: dict | None = None, result: dict | None = None,
) -> PlanPhase: ) -> SubTask:
"""创建测试用 PlanPhase 实例""" """创建测试用 SubTask 实例"""
return PlanPhase( return SubTask(
id=id, id=id,
name=name, description=description,
assigned_expert=assigned_expert, assigned_expert=assigned_expert,
task_description=task_description,
depends_on=depends_on or [],
parallel_type=parallel_type,
merge_strategy=merge_strategy,
milestone=milestone,
status=status, status=status,
result=result, result=result,
) )
def _make_valid_plan() -> CollaborationPlan: def _make_valid_plan() -> TeamPlan:
"""创建一个有效的协作计划""" """创建一个有效的 hub-and-spoke 执行计划"""
phases = [ subtasks = [
_make_phase(id="p1", name="需求分析", assigned_expert="analyst", task_description="分析需求"), _make_subtask(id="s1", description="分析需求", assigned_expert="analyst"),
_make_phase( _make_subtask(id="s2", description="设计架构", assigned_expert="architect"),
id="p2", _make_subtask(id="s3", description="编写代码", assigned_expert="coder"),
name="架构设计",
assigned_expert="architect",
task_description="设计架构",
depends_on=["p1"],
),
_make_phase(
id="p3",
name="代码实现",
assigned_expert="coder",
task_description="编写代码",
depends_on=["p2"],
),
] ]
return CollaborationPlan( return TeamPlan(
id="plan_001", id="plan_001",
task="实现用户登录功能", task="实现用户登录功能",
phases=phases, subtasks=subtasks,
variables={"project": "fischer"},
status=PlanStatus.DRAFT, status=PlanStatus.DRAFT,
lead_expert="architect", lead_expert="architect",
) )
# ── PlanPhase 测试 ──────────────────────────────────────── # ── MergeStrategy 测试 ────────────────────────────────────
class TestPlanPhase: class TestMergeStrategy:
"""PlanPhase 数据模型测试""" """MergeStrategy 枚举测试"""
def test_only_best_strategy_exists(self):
"""hub-and-spoke 模式仅保留 BEST 策略"""
assert MergeStrategy.BEST == "best"
# 确保只有 BEST 一个值
assert len(list(MergeStrategy)) == 1
def test_no_vote_or_fusion(self):
"""VOTE 和 FUSION 已被移除"""
assert not hasattr(MergeStrategy, "VOTE")
assert not hasattr(MergeStrategy, "FUSION")
# ── PlanStatus 测试 ───────────────────────────────────────
class TestPlanStatus:
"""PlanStatus 枚举测试"""
def test_statuses_exist(self):
"""必要的计划状态都存在"""
assert PlanStatus.DRAFT == "draft"
assert PlanStatus.EXECUTING == "executing"
assert PlanStatus.COMPLETED == "completed"
assert PlanStatus.FAILED == "failed"
assert PlanStatus.FALLBACK == "fallback"
def test_no_confirmed_status(self):
"""CONFIRMED 状态已被移除hub-and-spoke 无需确认阶段)"""
assert not hasattr(PlanStatus, "CONFIRMED")
# ── SubTaskStatus 测试 ────────────────────────────────────
class TestSubTaskStatus:
"""SubTaskStatus 枚举测试"""
def test_statuses_exist(self):
"""子任务状态都存在"""
assert SubTaskStatus.PENDING == "pending"
assert SubTaskStatus.RUNNING == "running"
assert SubTaskStatus.COMPLETED == "completed"
assert SubTaskStatus.FAILED == "failed"
# ── SubTask 测试 ──────────────────────────────────────────
class TestSubTask:
"""SubTask 数据模型测试"""
def test_creation_with_all_fields(self): def test_creation_with_all_fields(self):
"""创建 PlanPhase 并设置所有字段""" """创建 SubTask 并设置所有字段"""
phase = PlanPhase( subtask = SubTask(
id="phase_a", id="subtask_a",
name="竞品分析", description="竞品分析",
assigned_expert="analyst", assigned_expert="analyst",
task_description="分析竞品功能", status=SubTaskStatus.RUNNING,
depends_on=["phase_0"],
parallel_type=ParallelType.COMPETITIVE_PARALLEL,
merge_strategy=MergeStrategy.BEST,
milestone="竞品报告完成",
status=PhaseStatus.IN_PROGRESS,
result={"report": "竞品分析报告"}, result={"report": "竞品分析报告"},
) )
assert phase.id == "phase_a" assert subtask.id == "subtask_a"
assert phase.name == "竞品分析" assert subtask.description == "竞品分析"
assert phase.assigned_expert == "analyst" assert subtask.assigned_expert == "analyst"
assert phase.task_description == "分析竞品功能" assert subtask.status == SubTaskStatus.RUNNING
assert phase.depends_on == ["phase_0"] assert subtask.result == {"report": "竞品分析报告"}
assert phase.parallel_type == ParallelType.COMPETITIVE_PARALLEL
assert phase.merge_strategy == MergeStrategy.BEST def test_default_values(self):
assert phase.milestone == "竞品报告完成" """默认值:自动生成 idPENDING 状态"""
assert phase.status == PhaseStatus.IN_PROGRESS subtask = SubTask(description="测试任务")
assert phase.result == {"report": "竞品分析报告"} assert subtask.id is not None
assert len(subtask.id) > 0
assert subtask.description == "测试任务"
assert subtask.assigned_expert == ""
assert subtask.status == SubTaskStatus.PENDING
assert subtask.result is None
def test_to_dict_from_dict_roundtrip(self): def test_to_dict_from_dict_roundtrip(self):
"""to_dict / from_dict 往返序列化""" """to_dict / from_dict 往返序列化"""
phase = PlanPhase( subtask = SubTask(
id="roundtrip_phase", id="roundtrip_subtask",
name="往返测试", description="往返测试",
assigned_expert="tester", assigned_expert="tester",
task_description="测试序列化", status=SubTaskStatus.COMPLETED,
depends_on=["dep_a", "dep_b"],
parallel_type=ParallelType.SUBTASK_PARALLEL,
merge_strategy=MergeStrategy.VOTE,
milestone="序列化验证",
status=PhaseStatus.COMPLETED,
result={"key": "value"}, result={"key": "value"},
) )
d = phase.to_dict() d = subtask.to_dict()
restored = PlanPhase.from_dict(d) restored = SubTask.from_dict(d)
assert restored.id == phase.id assert restored.id == subtask.id
assert restored.name == phase.name assert restored.description == subtask.description
assert restored.assigned_expert == phase.assigned_expert assert restored.assigned_expert == subtask.assigned_expert
assert restored.task_description == phase.task_description assert restored.status == subtask.status
assert restored.depends_on == phase.depends_on assert restored.result == subtask.result
assert restored.parallel_type == phase.parallel_type
assert restored.merge_strategy == phase.merge_strategy
assert restored.milestone == phase.milestone
assert restored.status == phase.status
assert restored.result == phase.result
def test_to_dict_from_dict_with_none_merge_strategy(self): def test_to_dict_structure(self):
"""merge_strategy 为 None 时的序列化往返""" """to_dict 返回正确的字典结构"""
phase = PlanPhase( subtask = _make_subtask(
id="no_merge", id="struct_test",
name="无合并", description="结构测试",
assigned_expert="dev", assigned_expert="dev",
task_description="串行任务", status=SubTaskStatus.RUNNING,
parallel_type=ParallelType.SERIAL, result={"output": "data"},
) )
d = phase.to_dict() d = subtask.to_dict()
assert d["merge_strategy"] is None assert d["id"] == "struct_test"
restored = PlanPhase.from_dict(d) assert d["description"] == "结构测试"
assert restored.merge_strategy is None assert d["assigned_expert"] == "dev"
assert d["status"] == "running"
assert d["result"] == {"output": "data"}
# ── CollaborationPlan 测试 ──────────────────────────────── # ── TeamPlan 测试 ─────────────────────────────────────────
class TestCollaborationPlan: class TestTeamPlan:
"""CollaborationPlan 数据模型测试""" """TeamPlan 数据模型测试"""
def test_creation(self): def test_creation(self):
"""创建 CollaborationPlan""" """创建 TeamPlan"""
plan = _make_valid_plan() plan = _make_valid_plan()
assert plan.id == "plan_001" assert plan.id == "plan_001"
assert plan.task == "实现用户登录功能" assert plan.task == "实现用户登录功能"
assert len(plan.phases) == 3 assert len(plan.subtasks) == 3
assert plan.variables == {"project": "fischer"}
assert plan.status == PlanStatus.DRAFT assert plan.status == PlanStatus.DRAFT
assert plan.lead_expert == "architect" assert plan.lead_expert == "architect"
def test_default_values(self):
"""默认值:自动生成 id空子任务列表"""
plan = TeamPlan(task="测试任务")
assert plan.id is not None
assert plan.task == "测试任务"
assert plan.subtasks == []
assert plan.status == PlanStatus.DRAFT
assert plan.lead_expert == ""
def test_to_dict_from_dict_roundtrip(self): def test_to_dict_from_dict_roundtrip(self):
"""to_dict / from_dict 往返序列化""" """to_dict / from_dict 往返序列化"""
plan = _make_valid_plan() plan = _make_valid_plan()
d = plan.to_dict() d = plan.to_dict()
restored = CollaborationPlan.from_dict(d) restored = TeamPlan.from_dict(d)
assert restored.id == plan.id assert restored.id == plan.id
assert restored.task == plan.task assert restored.task == plan.task
assert len(restored.phases) == len(plan.phases) assert len(restored.subtasks) == len(plan.subtasks)
assert restored.variables == plan.variables
assert restored.status == plan.status assert restored.status == plan.status
assert restored.lead_expert == plan.lead_expert assert restored.lead_expert == plan.lead_expert
for original, restored_phase in zip(plan.phases, restored.phases): for original, restored_st in zip(plan.subtasks, restored.subtasks):
assert restored_phase.id == original.id assert restored_st.id == original.id
assert restored_phase.name == original.name assert restored_st.description == original.description
assert restored_phase.assigned_expert == original.assigned_expert assert restored_st.assigned_expert == original.assigned_expert
assert restored_phase.depends_on == original.depends_on assert restored_st.status == original.status
assert restored_phase.parallel_type == original.parallel_type
assert restored_phase.merge_strategy == original.merge_strategy
def test_validate_valid_plan(self): def test_get_subtask_by_id(self):
"""验证有效计划无错误""" """get_subtask 根据 ID 获取子任务"""
plan = _make_valid_plan() plan = _make_valid_plan()
errors = plan.validate() st = plan.get_subtask("s2")
assert errors == [] assert st is not None
assert st.id == "s2"
assert st.description == "设计架构"
def test_validate_detects_duplicate_phase_ids(self): def test_get_subtask_with_nonexistent_id_returns_none(self):
"""验证检测到重复阶段 ID""" """get_subtask 对不存在的 ID 返回 None"""
phases = [
_make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"),
_make_phase(id="p1", name="阶段2", assigned_expert="b", task_description="t2"),
]
plan = CollaborationPlan(
id="dup_plan", task="重复ID测试", phases=phases, lead_expert="a"
)
errors = plan.validate()
assert any("重复的阶段 ID" in e for e in errors)
def test_validate_detects_missing_depends_on_references(self):
"""验证检测到不存在的 depends_on 引用"""
phases = [
_make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"),
_make_phase(
id="p2",
name="阶段2",
assigned_expert="b",
task_description="t2",
depends_on=["p1", "nonexistent"],
),
]
plan = CollaborationPlan(
id="missing_dep_plan", task="缺失依赖测试", phases=phases, lead_expert="a"
)
errors = plan.validate()
assert any("不存在的阶段 ID" in e for e in errors)
def test_validate_detects_circular_dependencies(self):
"""验证检测到循环依赖"""
phases = [
_make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1", depends_on=["p3"]),
_make_phase(id="p2", name="阶段2", assigned_expert="b", task_description="t2", depends_on=["p1"]),
_make_phase(id="p3", name="阶段3", assigned_expert="c", task_description="t3", depends_on=["p2"]),
]
plan = CollaborationPlan(
id="cycle_plan", task="循环依赖测试", phases=phases, lead_expert="a"
)
errors = plan.validate()
assert any("循环依赖" in e for e in errors)
def test_validate_detects_competitive_parallel_without_merge_strategy(self):
"""验证检测到 COMPETITIVE_PARALLEL 缺少 merge_strategy"""
phases = [
_make_phase(
id="p1",
name="竞争阶段",
assigned_expert="a",
task_description="竞争任务",
parallel_type=ParallelType.COMPETITIVE_PARALLEL,
merge_strategy=None,
),
]
plan = CollaborationPlan(
id="no_merge_plan", task="缺少合并策略测试", phases=phases, lead_expert="a"
)
errors = plan.validate()
assert any("COMPETITIVE_PARALLEL" in e and "merge_strategy" in e for e in errors)
def test_get_ready_phases_returns_phases_with_completed_dependencies(self):
"""get_ready_phases 返回依赖已完成的阶段"""
plan = _make_valid_plan() plan = _make_valid_plan()
# 初始状态p1 无依赖,应该就绪 assert plan.get_subtask("nonexistent") is None
ready = plan.get_ready_phases()
assert len(ready) == 1
assert ready[0].id == "p1"
# 完成 p1 后p2 应该就绪 def test_update_subtask_status(self):
plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"analysis": "done"}) """update_subtask_status 更新子任务状态和结果"""
ready = plan.get_ready_phases()
assert len(ready) == 1
assert ready[0].id == "p2"
# 完成 p2 后p3 应该就绪
plan.update_phase_status("p2", PhaseStatus.COMPLETED, {"design": "done"})
ready = plan.get_ready_phases()
assert len(ready) == 1
assert ready[0].id == "p3"
def test_get_ready_phases_returns_empty_when_dependencies_not_met(self):
"""get_ready_phases 在依赖未满足时返回空列表"""
phases = [
_make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"),
_make_phase(
id="p2",
name="阶段2",
assigned_expert="b",
task_description="t2",
depends_on=["p1"],
),
]
plan = CollaborationPlan(
id="dep_plan", task="依赖未满足测试", phases=phases, lead_expert="a"
)
# p2 依赖 p1p1 未完成,所以 p2 不就绪
# 但 p1 无依赖,所以 p1 就绪
ready = plan.get_ready_phases()
assert len(ready) == 1
assert ready[0].id == "p1"
# 将 p1 设为 IN_PROGRESS未 COMPLETEDp2 仍不就绪
plan.update_phase_status("p1", PhaseStatus.IN_PROGRESS)
ready = plan.get_ready_phases()
assert len(ready) == 0
def test_update_phase_status(self):
"""update_phase_status 更新阶段状态和结果"""
plan = _make_valid_plan() plan = _make_valid_plan()
plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"output": "分析完成"}) plan.update_subtask_status("s1", SubTaskStatus.COMPLETED, {"output": "分析完成"})
phase = plan.get_phase("p1") st = plan.get_subtask("s1")
assert phase is not None assert st is not None
assert phase.status == PhaseStatus.COMPLETED assert st.status == SubTaskStatus.COMPLETED
assert phase.result == {"output": "分析完成"} assert st.result == {"output": "分析完成"}
# 不传 result 时不应覆盖已有 result def test_update_subtask_status_without_result(self):
plan.update_phase_status("p2", PhaseStatus.IN_PROGRESS) """update_subtask_status 不传 result 时不覆盖已有 result"""
phase2 = plan.get_phase("p2")
assert phase2 is not None
assert phase2.status == PhaseStatus.IN_PROGRESS
assert phase2.result is None
def test_get_phase_by_id(self):
"""get_phase 根据 ID 获取阶段"""
plan = _make_valid_plan() plan = _make_valid_plan()
phase = plan.get_phase("p2") plan.update_subtask_status("s1", SubTaskStatus.COMPLETED, {"output": "done"})
assert phase is not None plan.update_subtask_status("s1", SubTaskStatus.RUNNING)
assert phase.id == "p2" st = plan.get_subtask("s1")
assert phase.name == "架构设计" assert st is not None
assert st.status == SubTaskStatus.RUNNING
assert st.result == {"output": "done"}
def test_get_phase_with_nonexistent_id_returns_none(self): def test_completed_subtasks_property(self):
"""get_phase 对不存在的 ID 返回 None""" """completed_subtasks 返回已完成的子任务"""
plan = _make_valid_plan() plan = _make_valid_plan()
phase = plan.get_phase("nonexistent") plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)
assert phase is None plan.update_subtask_status("s2", SubTaskStatus.COMPLETED)
plan.update_subtask_status("s3", SubTaskStatus.FAILED)
completed = plan.completed_subtasks
assert len(completed) == 2
assert {st.id for st in completed} == {"s1", "s2"}
def test_failed_subtasks_property(self):
"""failed_subtasks 返回失败的子任务"""
plan = _make_valid_plan()
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)
plan.update_subtask_status("s2", SubTaskStatus.FAILED)
plan.update_subtask_status("s3", SubTaskStatus.FAILED)
failed = plan.failed_subtasks
assert len(failed) == 2
assert {st.id for st in failed} == {"s2", "s3"}
def test_all_done_property_when_all_completed(self):
"""all_done 当所有子任务完成时返回 True"""
plan = _make_valid_plan()
for st in plan.subtasks:
plan.update_subtask_status(st.id, SubTaskStatus.COMPLETED)
assert plan.all_done is True
def test_all_done_property_when_some_failed(self):
"""all_done 当所有子任务完成或失败时返回 True"""
plan = _make_valid_plan()
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)
plan.update_subtask_status("s2", SubTaskStatus.FAILED)
plan.update_subtask_status("s3", SubTaskStatus.COMPLETED)
assert plan.all_done is True
def test_all_done_property_when_pending(self):
"""all_done 当有子任务未完成时返回 False"""
plan = _make_valid_plan()
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)
# s2 and s3 still pending
assert plan.all_done is False
def test_all_done_property_empty_plan(self):
"""all_done 当没有子任务时返回 Truevacuous truth"""
plan = TeamPlan(task="空计划")
assert plan.all_done is True

View File

@ -2,8 +2,6 @@
from __future__ import annotations from __future__ import annotations
import pytest
from agentkit.experts.config import ExpertConfig, ExpertTemplate from agentkit.experts.config import ExpertConfig, ExpertTemplate
from agentkit.experts.registry import ExpertTemplateRegistry from agentkit.experts.registry import ExpertTemplateRegistry
from agentkit.experts.router import ( from agentkit.experts.router import (
@ -113,7 +111,6 @@ class TestExpertTeamRoutingResult:
assert result.specified_experts == [] assert result.specified_experts == []
assert result.task_content == "" assert result.task_content == ""
assert result.auto_compose is False assert result.auto_compose is False
assert result.complexity == 0.0
assert result.match_method == "" assert result.match_method == ""
@ -159,47 +156,14 @@ class TestExpertTeamRouterResolve:
assert result.specified_experts == ["analyst"] assert result.specified_experts == ["analyst"]
assert result.task_content == "请分析这份报告" assert result.task_content == "请分析这份报告"
def test_high_complexity_triggers_team_suggestion(self): def test_no_team_prefix_no_team_mode(self):
"""高复杂度 (>=0.7) 触发团队模式建议""" """无 @team 前缀不触发团队模式"""
router = ExpertTeamRouter()
result = router.resolve("分析这个复杂系统", complexity=0.8)
assert result.matched is True
assert result.team_mode is True
assert result.auto_compose is True
assert result.match_method == "complexity_suggestion"
assert result.complexity == 0.8
def test_high_complexity_exact_threshold(self):
"""复杂度恰好等于阈值 0.7 也触发团队模式"""
router = ExpertTeamRouter()
result = router.resolve("任务内容", complexity=0.7)
assert result.matched is True
assert result.team_mode is True
assert result.match_method == "complexity_suggestion"
def test_low_complexity_no_team_mode(self):
"""低复杂度 (<0.7) 不触发团队模式"""
router = ExpertTeamRouter()
result = router.resolve("简单问题", complexity=0.3)
assert result.matched is False
assert result.team_mode is False
assert result.task_content == "简单问题"
assert result.complexity == 0.3
def test_no_team_prefix_no_complexity(self):
"""无 @team 前缀且无复杂度时不触发团队模式"""
router = ExpertTeamRouter() router = ExpertTeamRouter()
result = router.resolve("普通问题") result = router.resolve("普通问题")
assert result.matched is False assert result.matched is False
assert result.team_mode is False assert result.team_mode is False
assert result.task_content == "普通问题"
def test_team_prefix_takes_priority_over_complexity(self): assert result.match_method == ""
"""@team 前缀优先于复杂度判断"""
router = ExpertTeamRouter()
result = router.resolve("@team:analyst 任务", complexity=0.1)
assert result.matched is True
assert result.match_method == "explicit_team"
assert result.specified_experts == ["analyst"]
def test_nonexistent_expert_still_included(self): def test_nonexistent_expert_still_included(self):
"""指定不存在的专家名仍包含在列表中""" """指定不存在的专家名仍包含在列表中"""
@ -214,6 +178,35 @@ class TestExpertTeamRouterResolve:
assert result.task_content == "@team" assert result.task_content == "@team"
assert result.auto_compose is True assert result.auto_compose is True
def test_resolve_strips_leading_whitespace(self):
"""resolve() 对前导空白做 strip()"""
router = ExpertTeamRouter()
result = router.resolve(" @team:analyst 任务内容")
assert result.matched is True
assert result.team_mode is True
assert result.specified_experts == ["analyst"]
assert result.task_content == "任务内容"
def test_invalid_expert_names_rejected(self):
"""无效专家名被过滤掉"""
router = ExpertTeamRouter()
# 包含特殊字符的名称应被过滤
result = router.resolve("@team:analyst,inva!id,bad/name 任务")
# 仅保留合法名称
assert "analyst" in result.specified_experts
assert "inva!id" not in result.specified_experts
assert "bad/name" not in result.specified_experts
def test_max_experts_limit(self):
"""指定超过 MAX_EXPERTS 数量的专家时截断"""
router = ExpertTeamRouter()
# 构造 15 个专家名
names = ",".join(f"expert{i}" for i in range(15))
result = router.resolve(f"@team:{names} 任务")
from agentkit.experts.router import MAX_EXPERTS
assert len(result.specified_experts) == MAX_EXPERTS
# ── ExpertTeamRouter.resolve_expert_configs 测试 ─────────── # ── ExpertTeamRouter.resolve_expert_configs 测试 ───────────
@ -275,6 +268,25 @@ class TestExpertTeamRouterResolveExpertConfigs:
configs = router.resolve_expert_configs(["analyst"]) configs = router.resolve_expert_configs(["analyst"])
assert configs[0].bound_skills == ["data_query"] assert configs[0].bound_skills == ["data_query"]
def test_resolve_first_expert_is_lead(self):
"""第一个专家被指定为 lead"""
registry = _make_registry_with_templates()
router = ExpertTeamRouter(template_registry=registry)
configs = router.resolve_expert_configs(["analyst", "strategist", "reviewer"])
assert configs[0].is_lead is True
assert configs[1].is_lead is False
assert configs[2].is_lead is False
def test_resolve_skips_invalid_names(self):
"""跳过无效的专家名"""
router = ExpertTeamRouter()
# 包含无效字符的名称应被跳过
configs = router.resolve_expert_configs(["valid_name", "inva!id", "also_valid"])
assert len(configs) == 2
assert configs[0].name == "valid_name"
assert configs[1].name == "also_valid"
# ── ExpertTeamRouter 构造测试 ───────────────────────────── # ── ExpertTeamRouter 构造测试 ─────────────────────────────
@ -293,7 +305,40 @@ class TestExpertTeamRouterInit:
router = ExpertTeamRouter(template_registry=registry) router = ExpertTeamRouter(template_registry=registry)
assert router._registry is registry assert router._registry is registry
def test_complexity_threshold(self):
"""复杂度阈值默认为 0.7""" # ── ExpertTeamRouter.can_handle 测试 ──────────────────────
class TestExpertTeamRouterCanHandle:
"""ExpertTeamRouter.can_handle 方法测试"""
def test_can_handle_with_matching_template_name(self):
"""模板名出现在内容中时返回 True"""
registry = _make_registry_with_templates()
router = ExpertTeamRouter(template_registry=registry)
assert router.can_handle("请 analyst 帮我分析") is True
def test_can_handle_with_matching_description(self):
"""模板描述词出现在内容中时返回 True"""
registry = _make_registry_with_templates()
router = ExpertTeamRouter(template_registry=registry)
# 描述为 "analyst 模板" 等,"模板" 长度 > 2 应匹配
assert router.can_handle("我需要一个模板来参考") is True
def test_can_handle_no_templates(self):
"""无注册模板时返回 False"""
router = ExpertTeamRouter() router = ExpertTeamRouter()
assert router.COMPLEXITY_THRESHOLD == 0.7 # 默认注册中心可能为空
# 强制清空以测试
router._registry._templates.clear()
assert router.can_handle("任何内容") is False
def test_can_handle_with_templates_no_match(self):
"""有模板但内容不匹配时仍返回 Trueauto-compose 可组建团队)"""
registry = _make_registry_with_templates()
router = ExpertTeamRouter(template_registry=registry)
# 内容与任何模板名/描述都不匹配,但有模板存在 → auto-compose 可用
assert router.can_handle("完全无关的内容 xyz123") is True

View File

@ -1,4 +1,4 @@
"""ExpertTeam 容器单元测试""" """ExpertTeam 容器单元测试 (hub-and-spoke 模式)"""
from __future__ import annotations from __future__ import annotations
@ -11,11 +11,6 @@ from agentkit.core.handoff_transport import InProcessHandoffTransport
from agentkit.core.shared_workspace import SharedWorkspace from agentkit.core.shared_workspace import SharedWorkspace
from agentkit.experts.config import ExpertConfig, ExpertTemplate from agentkit.experts.config import ExpertConfig, ExpertTemplate
from agentkit.experts.expert import Expert from agentkit.experts.expert import Expert
from agentkit.experts.plan import (
CollaborationPlan,
PlanPhase,
PlanStatus,
)
from agentkit.experts.registry import ExpertTemplateRegistry from agentkit.experts.registry import ExpertTemplateRegistry
from agentkit.experts.team import ExpertTeam, TeamStatus from agentkit.experts.team import ExpertTeam, TeamStatus
@ -84,27 +79,6 @@ def _make_mock_expert(
return expert return expert
def _make_valid_plan(
plan_id: str = "plan_1",
task: str = "测试任务",
lead_expert: str = "lead",
) -> CollaborationPlan:
"""创建有效的 CollaborationPlan"""
return CollaborationPlan(
id=plan_id,
task=task,
phases=[
PlanPhase(
id="phase_1",
name="阶段1",
assigned_expert=lead_expert,
task_description="执行任务",
)
],
lead_expert=lead_expert,
)
# ── ExpertTeam 创建测试 ─────────────────────────────────── # ── ExpertTeam 创建测试 ───────────────────────────────────
@ -119,7 +93,6 @@ class TestExpertTeamCreation:
assert len(team.team_id) > 0 assert len(team.team_id) > 0
assert team.status == TeamStatus.FORMING assert team.status == TeamStatus.FORMING
assert team.lead_expert is None assert team.lead_expert is None
assert team.plan is None
assert team.experts == [] assert team.experts == []
assert team.active_experts == [] assert team.active_experts == []
@ -173,7 +146,7 @@ class TestExpertTeamCreateTeam:
assert team._lead_expert_name == "lead" assert team._lead_expert_name == "lead"
assert team.lead_expert is mock_expert assert team.lead_expert is mock_expert
assert team.status == TeamStatus.PLANNING assert team.status == TeamStatus.EXECUTING
assert mock_expert.team_id == team.team_id assert mock_expert.team_id == team.team_id
@pytest.mark.asyncio @pytest.mark.asyncio
@ -195,7 +168,7 @@ class TestExpertTeamCreateTeam:
assert len(team.experts) == 2 assert len(team.experts) == 2
assert team._lead_expert_name == "lead" assert team._lead_expert_name == "lead"
assert team.status == TeamStatus.PLANNING assert team.status == TeamStatus.EXECUTING
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_team_without_pool_raises(self): async def test_create_team_without_pool_raises(self):
@ -411,64 +384,6 @@ class TestExpertTeamRemoveExpert:
assert last_left[0][1]["expert_name"] == "member1" assert last_left[0][1]["expert_name"] == "member1"
# ── ExpertTeam.update_plan 测试 ────────────────────────────
class TestExpertTeamUpdatePlan:
"""ExpertTeam.update_plan 协作计划更新测试"""
def test_update_plan_with_valid_plan(self):
"""有效计划更新成功,返回受影响的 Expert 名称"""
team = ExpertTeam()
plan = _make_valid_plan(lead_expert="lead")
affected = team.update_plan(plan)
assert team.plan is plan
assert "lead" in affected
def test_update_plan_confirmed_sets_executing(self):
"""CONFIRMED 状态的计划将团队状态设为 EXECUTING"""
team = ExpertTeam()
plan = _make_valid_plan(lead_expert="lead")
plan.status = PlanStatus.CONFIRMED
team.update_plan(plan)
assert team.status == TeamStatus.EXECUTING
def test_update_plan_with_invalid_plan_no_update(self):
"""无效计划validate 返回错误)不更新,返回验证错误列表"""
team = ExpertTeam()
# 创建有循环依赖的无效计划
plan = CollaborationPlan(
id="bad_plan",
task="无效任务",
phases=[
PlanPhase(
id="p1",
name="阶段1",
assigned_expert="a",
task_description="t1",
depends_on=["p2"],
),
PlanPhase(
id="p2",
name="阶段2",
assigned_expert="b",
task_description="t2",
depends_on=["p1"],
),
],
lead_expert="lead",
)
result = team.update_plan(plan)
assert len(result) > 0 # 返回验证错误而非空列表
assert team.plan is None # 未更新
# ── ExpertTeam.broadcast_user_message 测试 ───────────────── # ── ExpertTeam.broadcast_user_message 测试 ─────────────────
@ -535,42 +450,6 @@ class TestExpertTeamGetSharedContext:
assert context == {} assert context == {}
# ── ExpertTeam.generate_plan 测试 ──────────────────────────
class TestExpertTeamGeneratePlan:
"""ExpertTeam.generate_plan 计划生成测试"""
@pytest.mark.asyncio
async def test_generate_plan(self):
"""生成空的 CollaborationPlan"""
pool = _make_mock_pool()
team = ExpertTeam(pool=pool)
lead_config = _make_expert_config(name="lead", is_lead=True)
with patch.object(Expert, "create", new_callable=AsyncMock) as mock_create:
lead_expert = _make_mock_expert(name="lead", is_lead=True)
mock_create.return_value = lead_expert
await team.create_team(lead_config)
plan = await team.generate_plan("分析数据")
assert plan is not None
assert plan.task == "分析数据"
assert plan.lead_expert == "lead"
assert plan.phases == []
assert team.plan is plan
@pytest.mark.asyncio
async def test_generate_plan_without_lead(self):
"""没有 Lead Expert 时生成计划lead_expert 为空字符串"""
team = ExpertTeam()
plan = await team.generate_plan("测试任务")
assert plan.lead_expert == ""
# ── ExpertTeam.dissolve 测试 ─────────────────────────────── # ── ExpertTeam.dissolve 测试 ───────────────────────────────
@ -667,11 +546,6 @@ class TestExpertTeamDissolvedOperations:
# 解散后状态为 DISSOLVED # 解散后状态为 DISSOLVED
assert team.status == TeamStatus.DISSOLVED assert team.status == TeamStatus.DISSOLVED
# 再次 create_team 时,由于 experts 已清空,
# 但 pool 仍然存在,理论上可以重新创建
# 但这里验证状态是 DISSOLVED
assert team.status == TeamStatus.DISSOLVED
# ── ExpertTeam.lead_expert 属性测试 ──────────────────────── # ── ExpertTeam.lead_expert 属性测试 ────────────────────────
@ -742,10 +616,11 @@ class TestExpertTeamBuildContext:
context = team._build_team_context(lead_config, [member_config]) context = team._build_team_context(lead_config, [member_config])
assert "You are part of an Expert Team." in context assert "hub-and-spoke mode" in context
assert "Lead Expert: lead (领导者)" in context assert "Lead Expert (hub): lead (领导者)" in context
assert "Team Member: analyst (分析师), Skills: data_query" in context assert "Team Member (spoke): analyst (分析师)" in context
assert "send_message() and request_assist()" in context assert "data_query" in context
assert "depth=1" in context
def test_build_team_context_no_lead(self): def test_build_team_context_no_lead(self):
"""没有 Lead Expert 时构建上下文""" """没有 Lead Expert 时构建上下文"""
@ -754,8 +629,9 @@ class TestExpertTeamBuildContext:
context = team._build_team_context(None, [member_config]) context = team._build_team_context(None, [member_config])
assert "Lead Expert" not in context # 不应出现具体的 Lead Expert (hub): name 行
assert "Team Member: analyst" in context assert "Lead Expert (hub):" not in context
assert "Team Member (spoke): analyst" in context
def test_build_team_context_skips_lead_in_members(self): def test_build_team_context_skips_lead_in_members(self):
"""成员列表中包含 Lead 时跳过""" """成员列表中包含 Lead 时跳过"""
@ -765,4 +641,4 @@ class TestExpertTeamBuildContext:
context = team._build_team_context(lead_config, [lead_config]) context = team._build_team_context(lead_config, [lead_config])
# Lead 不应出现在 Team Member 行 # Lead 不应出现在 Team Member 行
assert "Team Member: lead" not in context assert "Team Member (spoke): lead" not in context

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,472 @@
"""ToolSearchIndex / ToolSearchTool / ReAct 分层注入单元测试
测试场景
- ToolSearchIndex: 空索引单工具多工具相关性排序top_k 限制无匹配
- ToolSearchTool: 正常查询空查询无匹配结果包含完整描述
- ReActEngine 分层注入: core/extended 分离tool_search 自动添加禁用配置自定义 core 列表
"""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from agentkit.tools.base import Tool
from agentkit.tools.builtin import ToolSearchTool
from agentkit.tools.search import ToolSearchIndex
# ── Test Helpers ──────────────────────────────────────────
class FakeTool(Tool):
"""用于测试的 Fake Tool"""
def __init__(
self,
name: str,
description: str,
input_schema: dict[str, Any] | None = None,
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description,
input_schema=input_schema,
tags=tags or [],
)
async def execute(self, **kwargs) -> dict:
return {"status": "ok"}
def _make_tools() -> list[Tool]:
"""创建一组测试工具"""
return [
FakeTool(
name="read_file",
description="Read the contents of a file from the filesystem.",
input_schema={
"type": "object",
"properties": {
"path": {"type": "string", "description": "file path to read"},
},
"required": ["path"],
},
tags=["io", "file"],
),
FakeTool(
name="write_file",
description="Write content to a file on the filesystem.",
input_schema={
"type": "object",
"properties": {
"path": {"type": "string", "description": "file path to write"},
"content": {"type": "string", "description": "content to write"},
},
"required": ["path", "content"],
},
tags=["io", "file"],
),
FakeTool(
name="web_search",
description="Search the web for information using a search engine.",
input_schema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "search query"},
},
"required": ["query"],
},
tags=["web", "search"],
),
FakeTool(
name="run_tests",
description="Run project tests to verify code changes.",
input_schema={
"type": "object",
"properties": {
"commands": {"type": "array", "description": "test commands"},
},
},
tags=["testing", "verification"],
),
]
# ── ToolSearchIndex Tests ─────────────────────────────────
class TestToolSearchIndex:
"""ToolSearchIndex BM25 搜索测试"""
def test_empty_tools(self):
"""空工具列表构建索引不报错"""
index = ToolSearchIndex([])
assert len(index) == 0
assert index.search("anything") == []
def test_single_tool_match(self):
"""单工具索引,匹配查询返回该工具"""
tools = _make_tools()[:1]
index = ToolSearchIndex(tools)
results = index.search("read file")
assert len(results) == 1
assert results[0].name == "read_file"
def test_relevance_ranking(self):
"""多工具索引,相关工具排在前面"""
tools = _make_tools()
index = ToolSearchIndex(tools)
results = index.search("web search")
assert len(results) > 0
# web_search 应该排在最前
assert results[0].name == "web_search"
def test_top_k_limit(self):
"""top_k 限制返回数量"""
tools = _make_tools()
index = ToolSearchIndex(tools)
results = index.search("file", top_k=2)
assert len(results) <= 2
def test_no_match_returns_empty(self):
"""无匹配时返回空列表"""
tools = _make_tools()
index = ToolSearchIndex(tools)
results = index.search("xyzzy_nonexistent")
assert results == []
def test_empty_query_returns_empty(self):
"""空查询返回空列表"""
tools = _make_tools()
index = ToolSearchIndex(tools)
assert index.search("") == []
assert index.search(" ") == []
def test_top_k_zero_returns_empty(self):
"""top_k=0 返回空列表"""
tools = _make_tools()
index = ToolSearchIndex(tools)
assert index.search("file", top_k=0) == []
def test_snake_case_tokenization(self):
"""snake_case 工具名被正确分词"""
tool = FakeTool(
name="read_file",
description="Read file contents.",
)
index = ToolSearchIndex([tool])
# 搜索 "read" 应该匹配
results = index.search("read")
assert len(results) == 1
# 搜索 "file" 也应该匹配
results = index.search("file")
assert len(results) == 1
def test_search_includes_parameter_descriptions(self):
"""搜索能匹配参数描述中的关键词"""
tool = FakeTool(
name="custom_tool",
description="A custom tool.",
input_schema={
"type": "object",
"properties": {
"database_url": {
"type": "string",
"description": "PostgreSQL connection string",
},
},
},
)
index = ToolSearchIndex([tool])
results = index.search("postgresql database")
assert len(results) == 1
assert results[0].name == "custom_tool"
def test_search_includes_tags(self):
"""搜索能匹配标签中的关键词"""
tool = FakeTool(
name="data_tool",
description="Process data.",
tags=["etl", "pipeline"],
)
index = ToolSearchIndex([tool])
results = index.search("pipeline etl")
assert len(results) == 1
assert results[0].name == "data_tool"
def test_invalid_k1_raises(self):
"""k1 < 0 抛出 ValueError"""
with pytest.raises(ValueError, match="k1"):
ToolSearchIndex(_make_tools(), k1=-1.0)
def test_invalid_b_raises(self):
"""b 不在 [0,1] 范围抛出 ValueError"""
with pytest.raises(ValueError, match="b"):
ToolSearchIndex(_make_tools(), b=1.5)
with pytest.raises(ValueError, match="b"):
ToolSearchIndex(_make_tools(), b=-0.1)
def test_multiple_results_sorted_by_score(self):
"""多个匹配结果按分数降序排列"""
tools = [
FakeTool(name="search_web", description="Search the web."),
FakeTool(name="search_files", description="Search files on disk."),
FakeTool(name="unrelated", description="Do something unrelated."),
]
index = ToolSearchIndex(tools)
results = index.search("search")
# 两个包含 "search" 的工具应该返回unrelated 不返回
assert len(results) == 2
names = [r.name for r in results]
assert "unrelated" not in names
# ── ToolSearchTool Tests ──────────────────────────────────
class TestToolSearchTool:
"""ToolSearchTool 工具测试"""
def test_tool_name_and_schema(self):
"""工具名称和 schema 正确"""
index = ToolSearchIndex(_make_tools())
tool = ToolSearchTool(search_index=index)
assert tool.name == "tool_search"
assert "query" in tool.input_schema["properties"]
assert "query" in tool.input_schema["required"]
async def test_execute_returns_results(self):
"""执行搜索返回匹配工具的完整描述"""
index = ToolSearchIndex(_make_tools())
tool = ToolSearchTool(search_index=index)
result = await tool.execute(query="web search")
assert result["count"] > 0
assert result["query"] == "web search"
first = result["results"][0]
assert "name" in first
assert "description" in first
assert "parameters" in first
assert first["name"] == "web_search"
async def test_execute_empty_query_returns_error(self):
"""空查询返回错误"""
index = ToolSearchIndex(_make_tools())
tool = ToolSearchTool(search_index=index)
result = await tool.execute(query="")
assert "error" in result
assert result["results"] == []
async def test_execute_no_match(self):
"""无匹配返回空结果和提示消息"""
index = ToolSearchIndex(_make_tools())
tool = ToolSearchTool(search_index=index)
result = await tool.execute(query="zzz_nonexistent")
assert result["count"] == 0
assert result["results"] == []
assert "message" in result
async def test_execute_respects_top_k(self):
"""top_k 限制返回数量"""
index = ToolSearchIndex(_make_tools())
tool = ToolSearchTool(search_index=index, top_k=1)
result = await tool.execute(query="file")
assert result["count"] <= 1
def test_invalid_top_k_raises(self):
"""top_k < 1 抛出 ValueError"""
index = ToolSearchIndex(_make_tools())
with pytest.raises(ValueError, match="top_k"):
ToolSearchTool(search_index=index, top_k=0)
# ── ReActEngine Tiered Injection Tests ────────────────────
class TestReActTieredInjection:
"""ReActEngine 工具描述分层注入测试"""
def _make_engine(self, **kwargs: Any):
from agentkit.core.react import ReActEngine
gateway = MagicMock()
return ReActEngine(llm_gateway=gateway, **kwargs)
def test_core_tools_get_full_description(self):
"""Core 工具注入完整描述(含参数)"""
engine = self._make_engine()
tools = [
FakeTool(
name="read_file",
description="Read a file.",
input_schema={
"type": "object",
"properties": {
"path": {"type": "string", "description": "file path"},
},
"required": ["path"],
},
),
]
prompt = engine._build_tool_use_prompt(tools)
# 核心工具区域存在
assert "核心工具" in prompt
# 参数描述被注入
assert "path" in prompt
assert "file path" in prompt
def test_extended_tools_get_one_line_only(self):
"""Extended 工具只注入名称+一行描述(无参数)"""
engine = self._make_engine()
tools = [
FakeTool(
name="custom_extended",
description="A custom extended tool for testing.",
input_schema={
"type": "object",
"properties": {
"secret_param": {
"type": "string",
"description": "SECRET_PARAM_DESC",
},
},
},
),
]
prompt = engine._build_tool_use_prompt(tools)
assert "扩展工具" in prompt
assert "custom_extended" in prompt
# 参数描述不应出现在扩展工具区域
assert "SECRET_PARAM_DESC" not in prompt
def test_mixed_core_and_extended(self):
"""混合 core + extended 工具,两者分区显示"""
engine = self._make_engine()
tools = [
FakeTool(name="read_file", description="Read a file."),
FakeTool(
name="web_search",
description="Search the web.",
input_schema={
"type": "object",
"properties": {
"q": {"type": "string", "description": "query string"},
},
},
),
]
prompt = engine._build_tool_use_prompt(tools)
assert "核心工具" in prompt
assert "扩展工具" in prompt
# read_file 在核心区web_search 在扩展区
assert "read_file" in prompt
assert "web_search" in prompt
def test_maybe_add_tool_search_adds_for_extended(self):
"""有扩展工具时自动添加 tool_search"""
engine = self._make_engine()
tools = [
FakeTool(name="read_file", description="Read a file."), # core
FakeTool(name="web_search", description="Search the web."), # extended
]
result = engine._maybe_add_tool_search(tools)
assert len(result) == 3
assert any(t.name == "tool_search" for t in result)
def test_maybe_add_tool_search_skips_when_only_core(self):
"""只有 core 工具时不添加 tool_search"""
engine = self._make_engine()
tools = [
FakeTool(name="read_file", description="Read a file."),
FakeTool(name="write_file", description="Write a file."),
]
result = engine._maybe_add_tool_search(tools)
assert len(result) == 2
assert not any(t.name == "tool_search" for t in result)
def test_maybe_add_tool_search_skips_when_disabled(self):
"""enable_tool_search=False 时不添加 tool_search"""
engine = self._make_engine(enable_tool_search=False)
tools = [
FakeTool(name="read_file", description="Read a file."),
FakeTool(name="web_search", description="Search the web."),
]
result = engine._maybe_add_tool_search(tools)
assert len(result) == 2
assert not any(t.name == "tool_search" for t in result)
def test_maybe_add_tool_search_skips_when_already_present(self):
"""tool_search 已存在时不重复添加"""
engine = self._make_engine()
index = ToolSearchIndex([])
existing_search = ToolSearchTool(search_index=index)
tools = [
FakeTool(name="read_file", description="Read a file."),
FakeTool(name="web_search", description="Search the web."),
existing_search,
]
result = engine._maybe_add_tool_search(tools)
assert len(result) == 3
def test_custom_core_tool_names(self):
"""自定义 core_tool_names 覆盖默认值"""
engine = self._make_engine(core_tool_names=["my_core_tool"])
tools = [
FakeTool(name="my_core_tool", description="My core tool."),
FakeTool(
name="read_file",
description="Read a file.",
input_schema={
"type": "object",
"properties": {
"p": {"type": "string", "description": "PARAM_DESC"},
},
},
),
]
prompt = engine._build_tool_use_prompt(tools)
# my_core_tool 在核心区
assert "核心工具" in prompt
# read_file 现在是扩展工具,参数描述不应出现
assert "PARAM_DESC" not in prompt
def test_tool_search_is_core_tool(self):
"""tool_search 被视为 core 工具(全量描述注入)"""
engine = self._make_engine()
index = ToolSearchIndex([])
search_tool = ToolSearchTool(search_index=index)
tools = [search_tool]
prompt = engine._build_tool_use_prompt(tools)
# tool_search 应该在核心工具区
assert "核心工具" in prompt
assert "tool_search" in prompt
# 其参数 query 应该被注入
assert "query" in prompt
def test_search_hint_added_when_tool_search_present(self):
"""tool_search 存在且有扩展工具时添加搜索提示"""
engine = self._make_engine()
tools = [
FakeTool(name="read_file", description="Read a file."),
FakeTool(name="web_search", description="Search the web."),
]
# 模拟 _maybe_add_tool_search 添加 tool_search
tools_with_search = engine._maybe_add_tool_search(tools)
prompt = engine._build_tool_use_prompt(tools_with_search)
assert "tool_search" in prompt
assert "扩展工具" in prompt
# 搜索提示存在
assert "tool_search(query" in prompt
def test_no_search_hint_when_no_extended_tools(self):
"""无扩展工具时不添加搜索提示"""
engine = self._make_engine()
tools = [
FakeTool(name="read_file", description="Read a file."),
]
prompt = engine._build_tool_use_prompt(tools)
assert "扩展工具" not in prompt