feat: hub-and-spoke experts, tiered tool injection, unified event model (U3/U7/U10)
This commit is contained in:
parent
200174c5c7
commit
bbedfff597
|
|
@ -3,6 +3,7 @@
|
|||
from agentkit.core.base import BaseAgent
|
||||
from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor
|
||||
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||
from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue
|
||||
from agentkit.core.exceptions import (
|
||||
AgentAlreadyRegisteredError,
|
||||
AgentFrameworkError,
|
||||
|
|
@ -29,12 +30,16 @@ from agentkit.core.protocol import (
|
|||
AgentCapability,
|
||||
AgentStatus,
|
||||
CancellationToken,
|
||||
Event,
|
||||
EvolutionEvent,
|
||||
HandoffMessage,
|
||||
SessionEventType,
|
||||
TaskEventType,
|
||||
TaskMessage,
|
||||
TaskProgress,
|
||||
TaskResult,
|
||||
TaskStatus,
|
||||
TurnEventType,
|
||||
)
|
||||
|
||||
# Optional: HeadroomCompressor — only available when headroom-ai is installed
|
||||
|
|
@ -80,4 +85,12 @@ __all__ = [
|
|||
"TaskProgress",
|
||||
"TaskResult",
|
||||
"TaskStatus",
|
||||
# SQ/EQ 统一事件模型
|
||||
"Event",
|
||||
"SessionEventType",
|
||||
"TaskEventType",
|
||||
"TurnEventType",
|
||||
"Submission",
|
||||
"SubmissionQueue",
|
||||
"EventQueue",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,244 @@
|
|||
"""统一事件模型 - SQ/EQ 双队列实现
|
||||
|
||||
参考 Codex 的 Session/Task/Turn 三级模型,提供:
|
||||
- SubmissionQueue (SQ): 接收用户输入,返回 task_id
|
||||
- EventQueue (EQ): 推送 Agent 事件,支持多订阅者广播和事件缓冲回放
|
||||
|
||||
集成点(Phase 4 再做实际集成):
|
||||
- Portal WebSocket 可通过 EventQueue.subscribe() 订阅事件流并推送给前端
|
||||
- CLI 可通过 EventQueue.subscribe() 订阅事件流并打印
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncIterator
|
||||
|
||||
from agentkit.core.protocol import Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 哨兵对象:用于通知订阅者队列已关闭(基于身份比较,不参与事件流)
|
||||
_CLOSED_SENTINEL: Event = Event(
|
||||
event_type="__closed__",
|
||||
task_id="",
|
||||
session_id="",
|
||||
data={},
|
||||
timestamp="",
|
||||
)
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


@dataclass
class Submission:
    """One piece of user input waiting to be processed.

    Instances are built by ``SubmissionQueue.submit()`` and handed to
    consumers through ``SubmissionQueue.drain()``.

    Attributes:
        task_id: Unique task identifier.
        session_id: Identifier of the session the input belongs to.
        content: The text the user submitted.
        created_at: Creation time; defaults to "now" in UTC.
        cancelled: Whether the submission has been cancelled.
    """

    task_id: str
    session_id: str
    content: str
    created_at: datetime = field(default_factory=_utc_now)
    cancelled: bool = False
|
||||
|
||||
|
||||
class SubmissionQueue:
    """Submission queue (SQ): accepts user input and hands it to consumers.

    Producers call ``submit()``; consumers iterate ``drain()``. A submission
    can be cancelled by its ``task_id``.

    Backed by an ``asyncio.Queue`` capped at 1024 entries (the original author
    notes this matches HandoffTransport).
    """

    _MAX_QUEUE_SIZE: int = 1024
    # Identity-compared marker that close() enqueues so that consumers blocked
    # inside drain() wake up and finish. It is never yielded to callers.
    _CLOSE_SENTINEL: object = object()

    def __init__(self) -> None:
        self._queue: asyncio.Queue[Submission] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
        self._submissions: dict[str, Submission] = {}
        self._cancelled_tasks: set[str] = set()
        self._closed: bool = False

    async def submit(self, content: str, session_id: str) -> str:
        """Queue a piece of user input and return its task id.

        Waits while the queue is full.

        Args:
            content: The user's input.
            session_id: Session the input belongs to.

        Returns:
            The newly generated task id (a UUID4 string).

        Raises:
            RuntimeError: The queue has been closed.
        """
        if self._closed:
            raise RuntimeError("SubmissionQueue is closed")
        task_id = str(uuid.uuid4())
        submission = Submission(
            task_id=task_id,
            session_id=session_id,
            content=content,
        )
        self._submissions[task_id] = submission
        await self._queue.put(submission)
        return task_id

    async def drain(self) -> AsyncIterator[Submission]:
        """Yield queued submissions in order (async generator).

        Cancelled submissions are skipped. After ``close()`` the generator
        delivers whatever is still queued and then finishes.
        """
        while True:
            # Closed and nothing left: finish instead of waiting forever.
            if self._closed and self._queue.empty():
                return
            submission = await self._queue.get()
            if submission is self._CLOSE_SENTINEL:
                # Put the marker back so any other consumer wakes up as well.
                self._signal_closed()
                return
            if submission.task_id in self._cancelled_tasks:
                continue
            yield submission

    async def cancel(self, task_id: str) -> bool:
        """Cancel a task.

        Args:
            task_id: Id of the task to cancel.

        Returns:
            True when the task is known and had not been cancelled before;
            False otherwise.
        """
        if task_id not in self._submissions:
            return False
        if task_id in self._cancelled_tasks:
            return False
        self._cancelled_tasks.add(task_id)
        self._submissions[task_id].cancelled = True
        return True

    @property
    def is_closed(self) -> bool:
        """Whether the queue has been closed."""
        return self._closed

    def close(self) -> None:
        """Close the queue; no new submissions are accepted.

        Submissions already queued are unaffected and are still delivered by
        ``drain()``, which then terminates. Calling ``close()`` again is a
        no-op.
        """
        if self._closed:
            return
        self._closed = True
        self._signal_closed()

    def _signal_closed(self) -> None:
        """Enqueue the close marker behind any pending submissions."""
        try:
            self._queue.put_nowait(self._CLOSE_SENTINEL)
        except asyncio.QueueFull:
            # A full queue means get() returns immediately for consumers;
            # drain() notices ``_closed`` once the backlog is empty.
            pass
|
||||
|
||||
|
||||
class EventQueue:
    """Event queue (EQ): broadcasts agent events to subscribers.

    Every emitted event is delivered to all active subscribers. A new
    subscriber first receives a replay of the most recently buffered events
    (100 by default).

    Intended integration points, per the module notes:
    - a Portal WebSocket can iterate ``subscribe()`` and forward events;
    - a CLI can iterate ``subscribe()`` and print events.
    """

    _MAX_QUEUE_SIZE: int = 1024
    _DEFAULT_BUFFER_SIZE: int = 100

    def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None:
        self._subscribers: list[asyncio.Queue[Event]] = []
        self._buffer: deque[Event] = deque(maxlen=buffer_size)
        self._buffer_size = buffer_size
        self._closed: bool = False

    async def emit(self, event: Event) -> None:
        """Deliver an event to every subscriber and record it for replay.

        If one subscriber's queue is full, the event is dropped for that
        subscriber only; the others are unaffected. Events emitted after
        ``close()`` are discarded.

        Args:
            event: The event to broadcast.
        """
        if self._closed:
            # subscribe() refuses new subscribers once closed, so a buffered
            # event could never be read.
            return
        self._buffer.append(event)
        for queue in self._subscribers:
            try:
                queue.put_nowait(event)
            except asyncio.QueueFull:
                logger.warning("EventQueue subscriber queue full, dropping event")

    async def subscribe(self) -> AsyncIterator[Event]:
        """Subscribe to the event stream (async generator).

        Buffered events are replayed first, then new events are yielded as
        they arrive. Each subscriber owns a private queue, which gives
        broadcast semantics. The generator finishes once the queue is closed.

        Replay and registration run in one synchronous stretch (no ``await``
        between them), so no event is missed or duplicated in between.
        """
        if self._closed:
            return

        queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)

        # Replay buffered events; synchronous, so emit() cannot interleave.
        for event in self._buffer:
            try:
                queue.put_nowait(event)
            except asyncio.QueueFull:
                logger.warning("EventQueue replay buffer full, skipping remaining")
                break

        # Register only after the replay so nothing is delivered twice.
        self._subscribers.append(queue)

        try:
            while True:
                # close() cannot enqueue its sentinel into a full queue; once
                # the backlog is consumed, the closed flag ends the stream.
                if self._closed and queue.empty():
                    break
                event = await queue.get()
                if event is _CLOSED_SENTINEL:
                    break
                yield event
        finally:
            # Unregister; close() may already have cleared the list.
            try:
                self._subscribers.remove(queue)
            except ValueError:
                pass

    @property
    def subscriber_count(self) -> int:
        """Number of currently registered subscribers."""
        return len(self._subscribers)

    @property
    def buffer_size(self) -> int:
        """Maximum number of events kept for replay."""
        return self._buffer_size

    @property
    def is_closed(self) -> bool:
        """Whether the queue has been closed."""
        return self._closed

    def close(self) -> None:
        """Close the queue, wake every subscriber, and drop internal state.

        A sentinel is pushed to each subscriber queue so that subscribers
        blocked in ``subscribe()`` exit cleanly.
        """
        self._closed = True
        for queue in self._subscribers:
            try:
                queue.put_nowait(_CLOSED_SENTINEL)
            except asyncio.QueueFull:
                # This subscriber still has a full backlog; it stops via the
                # closed-and-empty check in subscribe() after consuming it.
                pass
        self._subscribers.clear()
        self._buffer.clear()
|
||||
|
|
@ -10,6 +10,7 @@ from agentkit.core.exceptions import TaskCancelledError
|
|||
|
||||
class TaskStatus(str, Enum):
|
||||
"""任务状态枚举"""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
|
|
@ -21,6 +22,7 @@ class TaskStatus(str, Enum):
|
|||
|
||||
class AgentStatus(str, Enum):
|
||||
"""Agent 状态枚举"""
|
||||
|
||||
ONLINE = "online"
|
||||
OFFLINE = "offline"
|
||||
BUSY = "busy"
|
||||
|
|
@ -29,6 +31,7 @@ class AgentStatus(str, Enum):
|
|||
@dataclass
|
||||
class AgentCapability:
|
||||
"""Agent 能力声明"""
|
||||
|
||||
agent_name: str
|
||||
agent_type: str
|
||||
version: str
|
||||
|
|
@ -70,6 +73,7 @@ class AgentCapability:
|
|||
@dataclass
|
||||
class TaskMessage:
|
||||
"""任务消息 - 从调度器发往 Agent"""
|
||||
|
||||
task_id: str
|
||||
agent_name: str
|
||||
task_type: str
|
||||
|
|
@ -114,6 +118,7 @@ class TaskMessage:
|
|||
@dataclass
|
||||
class TaskResult:
|
||||
"""任务结果 - 从 Agent 返回"""
|
||||
|
||||
task_id: str
|
||||
agent_name: str
|
||||
status: str
|
||||
|
|
@ -163,6 +168,7 @@ class TaskResult:
|
|||
@dataclass
|
||||
class TaskProgress:
|
||||
"""进度上报 - Agent 执行过程中上报"""
|
||||
|
||||
task_id: str
|
||||
agent_name: str
|
||||
progress: float
|
||||
|
|
@ -195,6 +201,7 @@ class TaskProgress:
|
|||
@dataclass
|
||||
class HandoffMessage:
|
||||
"""任务转交消息 - Agent 间 Handoff"""
|
||||
|
||||
source_agent: str
|
||||
target_agent: str
|
||||
task_id: str
|
||||
|
|
@ -233,6 +240,7 @@ class HandoffMessage:
|
|||
@dataclass
|
||||
class EvolutionEvent:
|
||||
"""进化事件 - 记录 Agent 的自我进化变更"""
|
||||
|
||||
agent_name: str
|
||||
change_type: str # prompt / strategy / pipeline
|
||||
before: dict[str, Any]
|
||||
|
|
@ -277,3 +285,102 @@ class CancellationToken:
|
|||
"""检查是否已取消,若已取消则抛出 TaskCancelledError"""
|
||||
if self._cancelled:
|
||||
raise TaskCancelledError(task_id="")
|
||||
|
||||
|
||||
# ── SQ/EQ 统一事件模型 ──────────────────────────────────────────
|
||||
# 参考 Codex 的 Session/Task/Turn 三级模型,统一 CLI 和 WebSocket 事件流。
|
||||
|
||||
|
||||
class SessionEventType:
    """Event type names at session scope.

    A session covers one complete interaction lifetime, for example from the
    user opening the CLI until closing it.
    """

    SESSION_STARTED = "session.started"
    SESSION_ENDED = "session.ended"
|
||||
|
||||
|
||||
class TaskEventType:
    """Event type names at task scope.

    A task corresponds to one submission of user input, from creation to a
    terminal state (completed or failed).
    """

    TASK_CREATED = "task.created"
    TASK_STARTED = "task.started"
    TASK_COMPLETED = "task.completed"
    TASK_FAILED = "task.failed"
|
||||
|
||||
|
||||
class TurnEventType:
    """Event type names at turn scope.

    A turn is one dialogue/reasoning step inside a task; these names cover
    thinking, tool calls, token streaming and related steps.
    """

    TURN_STARTED = "turn.started"
    THINKING = "turn.thinking"
    TOOL_CALL = "turn.tool_call"
    TOOL_RESULT = "turn.tool_result"
    TOKEN = "turn.token"
    STEP = "turn.step"
    FINAL_ANSWER = "turn.final_answer"
    TURN_COMPLETED = "turn.completed"
|
||||
|
||||
|
||||
@dataclass
class Event:
    """Unified event record used by the SQ/EQ queues.

    CLI and WebSocket events share this one structure so that both front
    ends can process them the same way.

    Attributes:
        event_type: Event type name (see SessionEventType / TaskEventType /
            TurnEventType).
        task_id: Id of the related task.
        session_id: Id of the related session.
        data: Event payload.
        timestamp: Timestamp text in ISO 8601 format.
    """

    event_type: str
    task_id: str
    session_id: str
    data: dict[str, Any]
    timestamp: str  # ISO 8601 text

    def to_dict(self) -> dict:
        """Return the event as a plain dict; ``data`` is referenced, not copied."""
        names = ("event_type", "task_id", "session_id", "data", "timestamp")
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: dict) -> "Event":
        """Build an event from a dict; a missing ``data`` key becomes ``{}``."""
        return cls(
            data["event_type"],
            data["task_id"],
            data["session_id"],
            data.get("data", {}),
            data["timestamp"],
        )

    @classmethod
    def create(
        cls,
        event_type: str,
        task_id: str,
        session_id: str,
        data: dict[str, Any] | None = None,
    ) -> "Event":
        """Build an event stamped with the current UTC time in ISO 8601 form."""
        payload = data if data else {}
        stamp = datetime.now(timezone.utc).isoformat()
        return cls(event_type, task_id, session_id, payload, stamp)
|
||||
|
|
|
|||
|
|
@ -1,22 +1,23 @@
|
|||
"""Expert 系统 - 专家团队模式的配置、模板、注册与协作计划"""
|
||||
"""Expert 系统 - 专家团队 hub-and-spoke 模式
|
||||
|
||||
U3 简化:从去中心化协作改为 Lead Expert + 并行 Task 模式。
|
||||
"""
|
||||
|
||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||||
from agentkit.experts.expert import Expert
|
||||
from agentkit.experts.orchestrator import TeamOrchestrator
|
||||
from agentkit.experts.plan import (
|
||||
CollaborationPlan,
|
||||
MergeStrategy,
|
||||
ParallelType,
|
||||
PhaseStatus,
|
||||
PlanPhase,
|
||||
PlanStatus,
|
||||
SubTask,
|
||||
SubTaskStatus,
|
||||
TeamPlan,
|
||||
)
|
||||
from agentkit.experts.registry import ExpertTemplateRegistry
|
||||
from agentkit.experts.router import ExpertTeamRouter, ExpertTeamRoutingResult
|
||||
from agentkit.experts.team import ExpertTeam, TeamStatus
|
||||
|
||||
__all__ = [
|
||||
"CollaborationPlan",
|
||||
"Expert",
|
||||
"ExpertConfig",
|
||||
"ExpertTeam",
|
||||
|
|
@ -25,10 +26,10 @@ __all__ = [
|
|||
"ExpertTemplate",
|
||||
"ExpertTemplateRegistry",
|
||||
"MergeStrategy",
|
||||
"ParallelType",
|
||||
"PhaseStatus",
|
||||
"PlanPhase",
|
||||
"PlanStatus",
|
||||
"SubTask",
|
||||
"SubTaskStatus",
|
||||
"TeamOrchestrator",
|
||||
"TeamPlan",
|
||||
"TeamStatus",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.config_driven import AgentConfig
|
||||
|
|
|
|||
|
|
@ -65,7 +65,9 @@ class Expert:
|
|||
# 如果提供了团队上下文,修改 Agent 的 prompt 以注入团队角色信息
|
||||
if team_context and hasattr(agent, "_prompt_template") and agent._prompt_template:
|
||||
sections = agent._prompt_template._sections
|
||||
sections.context = f"{team_context}\n\n{sections.context}" if sections.context else team_context
|
||||
sections.context = (
|
||||
f"{team_context}\n\n{sections.context}" if sections.context else team_context
|
||||
)
|
||||
|
||||
return expert
|
||||
|
||||
|
|
@ -116,9 +118,7 @@ class Expert:
|
|||
"reason": reason,
|
||||
"type": "assist_request",
|
||||
}
|
||||
await self._handoff_transport.send(
|
||||
f"expert:{target_expert}:handoff", handoff_msg
|
||||
)
|
||||
await self._handoff_transport.send(f"expert:{target_expert}:handoff", handoff_msg)
|
||||
|
||||
async def propose_plan_modification(
|
||||
self,
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,281 +1,172 @@
|
|||
"""CollaborationPlan 数据模型 - Expert Team 协作蓝图
|
||||
"""Expert Team 计划数据模型 - hub-and-spoke 模式
|
||||
|
||||
定义 Expert Team 的结构化协作计划,包括阶段、角色分配、依赖关系、
|
||||
并行类型、合并策略和里程碑。
|
||||
简化为 Lead Expert + 并行 Task 模式:
|
||||
- Lead Expert 接收任务并分解为子任务
|
||||
- 子任务并行执行(Task 深度=1,不能再 spawn Task)
|
||||
- Lead Expert 汇总结果(BEST 策略)
|
||||
|
||||
移除了阶段依赖图、VOTE/FUSION 合并策略、跨阶段状态共享。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ParallelType(str, enum.Enum):
|
||||
"""并行执行类型"""
|
||||
|
||||
SERIAL = "serial"
|
||||
SUBTASK_PARALLEL = "subtask_parallel"
|
||||
COMPETITIVE_PARALLEL = "competitive_parallel"
|
||||
|
||||
|
||||
class MergeStrategy(str, enum.Enum):
|
||||
"""合并策略 - 仅用于 COMPETITIVE_PARALLEL 阶段"""
|
||||
"""合并策略 - Lead Expert 用于选择最佳结果
|
||||
|
||||
hub-and-spoke 模式下仅保留 BEST 策略:
|
||||
Lead Expert 从所有子任务结果中选择或综合出最佳结果。
|
||||
"""
|
||||
|
||||
BEST = "best"
|
||||
VOTE = "vote"
|
||||
FUSION = "fusion"
|
||||
|
||||
|
||||
class PhaseStatus(str, enum.Enum):
|
||||
"""阶段状态"""
|
||||
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class PlanStatus(str, enum.Enum):
|
||||
"""计划状态"""
|
||||
|
||||
DRAFT = "draft"
|
||||
CONFIRMED = "confirmed"
|
||||
EXECUTING = "executing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
FALLBACK = "fallback"
|
||||
|
||||
|
||||
# DFS 着色常量
|
||||
_WHITE = 0 # 未访问
|
||||
_GRAY = 1 # 正在访问(当前路径上)
|
||||
_BLACK = 2 # 已完成访问
|
||||
class SubTaskStatus(str, enum.Enum):
|
||||
"""子任务状态"""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlanPhase:
|
||||
"""协作计划中的单个阶段
|
||||
class SubTask:
|
||||
"""Lead Expert 分解出的子任务
|
||||
|
||||
在 hub-and-spoke 模式中,Lead Expert 将原始任务分解为多个子任务,
|
||||
每个子任务由一个 Expert 并行执行。子任务之间无依赖关系、无通信。
|
||||
|
||||
约束:
|
||||
- Task 深度=1(子任务不能再 spawn 子任务)
|
||||
- 子任务之间无通信
|
||||
- Lead Expert 持有所有状态
|
||||
|
||||
Attributes:
|
||||
id: 阶段标识符
|
||||
name: 阶段显示名称
|
||||
assigned_expert: 分配到此阶段的 Expert 名称
|
||||
task_description: 此阶段完成的任务描述
|
||||
depends_on: 依赖的阶段 ID 列表
|
||||
parallel_type: 执行类型
|
||||
merge_strategy: 合并策略,仅 COMPETITIVE_PARALLEL 需要
|
||||
milestone: 里程碑检查点描述
|
||||
id: 子任务标识符
|
||||
description: 子任务描述
|
||||
assigned_expert: 分配的 Expert 名称
|
||||
status: 当前状态
|
||||
result: 阶段输出结果
|
||||
result: 子任务输出结果
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
assigned_expert: str
|
||||
task_description: str
|
||||
depends_on: list[str] = field(default_factory=list)
|
||||
parallel_type: ParallelType = ParallelType.SERIAL
|
||||
merge_strategy: MergeStrategy | None = None
|
||||
milestone: str = ""
|
||||
status: PhaseStatus = PhaseStatus.PENDING
|
||||
result: dict | None = None
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
description: str = ""
|
||||
assigned_expert: str = ""
|
||||
status: SubTaskStatus = SubTaskStatus.PENDING
|
||||
result: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""序列化为字典"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"assigned_expert": self.assigned_expert,
|
||||
"task_description": self.task_description,
|
||||
"depends_on": self.depends_on,
|
||||
"parallel_type": self.parallel_type.value,
|
||||
"merge_strategy": self.merge_strategy.value if self.merge_strategy is not None else None,
|
||||
"milestone": self.milestone,
|
||||
"status": self.status.value,
|
||||
"result": self.result,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> PlanPhase:
|
||||
"""从字典创建 PlanPhase"""
|
||||
merge_strategy = None
|
||||
if data.get("merge_strategy") is not None:
|
||||
merge_strategy = MergeStrategy(data["merge_strategy"])
|
||||
|
||||
def from_dict(cls, data: dict[str, Any]) -> SubTask:
|
||||
"""从字典创建 SubTask"""
|
||||
return cls(
|
||||
id=data["id"],
|
||||
name=data["name"],
|
||||
assigned_expert=data["assigned_expert"],
|
||||
task_description=data["task_description"],
|
||||
depends_on=data.get("depends_on", []),
|
||||
parallel_type=ParallelType(data.get("parallel_type", ParallelType.SERIAL.value)),
|
||||
merge_strategy=merge_strategy,
|
||||
milestone=data.get("milestone", ""),
|
||||
status=PhaseStatus(data.get("status", PhaseStatus.PENDING.value)),
|
||||
id=data.get("id", str(uuid.uuid4())),
|
||||
description=data.get("description", ""),
|
||||
assigned_expert=data.get("assigned_expert", ""),
|
||||
status=SubTaskStatus(data.get("status", SubTaskStatus.PENDING.value)),
|
||||
result=data.get("result"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CollaborationPlan:
|
||||
"""Expert Team 协作计划
|
||||
class TeamPlan:
|
||||
"""Expert Team hub-and-spoke 执行计划
|
||||
|
||||
定义 Expert Team 的结构化协作蓝图,包括阶段编排、共享变量、
|
||||
状态管理和依赖关系。
|
||||
Lead Expert 持有此计划,包含分解的子任务列表。
|
||||
与旧版 CollaborationPlan 不同,此计划无阶段依赖图,
|
||||
所有子任务并行执行,由 Lead Expert 统一汇总。
|
||||
|
||||
Attributes:
|
||||
id: 计划标识符
|
||||
task: 原始任务描述
|
||||
phases: 有序阶段列表
|
||||
variables: 共享变量
|
||||
subtasks: 子任务列表(并行执行,无依赖关系)
|
||||
status: 计划状态
|
||||
lead_expert: 主导 Expert 名称
|
||||
"""
|
||||
|
||||
id: str
|
||||
task: str
|
||||
phases: list[PlanPhase] = field(default_factory=list)
|
||||
variables: dict = field(default_factory=dict)
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
task: str = ""
|
||||
subtasks: list[SubTask] = field(default_factory=list)
|
||||
status: PlanStatus = PlanStatus.DRAFT
|
||||
lead_expert: str = ""
|
||||
_phase_index: dict[str, PlanPhase] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Build the phase index after initialization."""
|
||||
self._rebuild_index()
|
||||
|
||||
def _rebuild_index(self) -> None:
|
||||
"""Rebuild the phase index from the phases list."""
|
||||
self._phase_index = {phase.id: phase for phase in self.phases}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""序列化为字典"""
|
||||
return {
|
||||
"id": self.id,
|
||||
"task": self.task,
|
||||
"phases": [phase.to_dict() for phase in self.phases],
|
||||
"variables": self.variables,
|
||||
"subtasks": [st.to_dict() for st in self.subtasks],
|
||||
"status": self.status.value,
|
||||
"lead_expert": self.lead_expert,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> CollaborationPlan:
|
||||
"""从字典创建 CollaborationPlan"""
|
||||
phases = [PlanPhase.from_dict(p) for p in data.get("phases", [])]
|
||||
def from_dict(cls, data: dict[str, Any]) -> TeamPlan:
|
||||
"""从字典创建 TeamPlan"""
|
||||
subtasks = [SubTask.from_dict(st) for st in data.get("subtasks", [])]
|
||||
return cls(
|
||||
id=data["id"],
|
||||
task=data["task"],
|
||||
phases=phases,
|
||||
variables=data.get("variables", {}),
|
||||
id=data.get("id", str(uuid.uuid4())),
|
||||
task=data.get("task", ""),
|
||||
subtasks=subtasks,
|
||||
status=PlanStatus(data.get("status", PlanStatus.DRAFT.value)),
|
||||
lead_expert=data.get("lead_expert", ""),
|
||||
)
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""验证计划,返回错误消息列表(空列表表示有效)
|
||||
def get_subtask(self, subtask_id: str) -> SubTask | None:
|
||||
"""根据 ID 获取子任务,不存在则返回 None"""
|
||||
for st in self.subtasks:
|
||||
if st.id == subtask_id:
|
||||
return st
|
||||
return None
|
||||
|
||||
检查项:
|
||||
- 无重复阶段 ID
|
||||
- 所有 depends_on 引用存在
|
||||
- 无循环依赖(DFS 着色检测)
|
||||
- COMPETITIVE_PARALLEL 阶段必须有 merge_strategy
|
||||
"""
|
||||
errors: list[str] = []
|
||||
|
||||
# 构建阶段 ID 集合
|
||||
phase_ids = {phase.id for phase in self.phases}
|
||||
|
||||
# 检查重复阶段 ID
|
||||
seen_ids: set[str] = set()
|
||||
for phase in self.phases:
|
||||
if phase.id in seen_ids:
|
||||
errors.append(f"重复的阶段 ID: {phase.id}")
|
||||
seen_ids.add(phase.id)
|
||||
|
||||
# 检查 depends_on 引用是否存在
|
||||
for phase in self.phases:
|
||||
for dep_id in phase.depends_on:
|
||||
if dep_id not in phase_ids:
|
||||
errors.append(
|
||||
f"阶段 '{phase.id}' 依赖了不存在的阶段 ID: {dep_id}"
|
||||
)
|
||||
|
||||
# 检查循环依赖(迭代 DFS 着色 — 避免递归栈溢出)
|
||||
color: dict[str, int] = {phase.id: _WHITE for phase in self.phases}
|
||||
dep_map: dict[str, list[str]] = {
|
||||
phase.id: phase.depends_on for phase in self.phases
|
||||
}
|
||||
|
||||
for phase in self.phases:
|
||||
if color[phase.id] != _WHITE:
|
||||
continue
|
||||
# Iterative DFS using an explicit stack
|
||||
stack: list[tuple[str, bool]] = [(phase.id, False)]
|
||||
while stack:
|
||||
node, is_backtrack = stack.pop()
|
||||
if is_backtrack:
|
||||
color[node] = _BLACK
|
||||
continue
|
||||
if color[node] == _GRAY:
|
||||
# Already on current path — cycle detected
|
||||
errors.append("检测到循环依赖")
|
||||
break
|
||||
if color[node] == _BLACK:
|
||||
continue
|
||||
color[node] = _GRAY
|
||||
# Push backtrack marker
|
||||
stack.append((node, True))
|
||||
for neighbor in dep_map.get(node, []):
|
||||
if neighbor not in color:
|
||||
continue
|
||||
if color[neighbor] == _GRAY:
|
||||
errors.append("检测到循环依赖")
|
||||
break
|
||||
if color[neighbor] == _WHITE:
|
||||
stack.append((neighbor, False))
|
||||
else:
|
||||
continue
|
||||
break # Inner break propagates to outer
|
||||
if errors:
|
||||
break # Only report cycle once
|
||||
|
||||
# 检查 COMPETITIVE_PARALLEL 必须有 merge_strategy
|
||||
for phase in self.phases:
|
||||
if (
|
||||
phase.parallel_type == ParallelType.COMPETITIVE_PARALLEL
|
||||
and phase.merge_strategy is None
|
||||
):
|
||||
errors.append(
|
||||
f"阶段 '{phase.id}' 为 COMPETITIVE_PARALLEL 但未设置 merge_strategy"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
def get_ready_phases(self) -> list[PlanPhase]:
|
||||
"""获取所有依赖已完成且状态为 PENDING 的阶段"""
|
||||
completed_ids = {
|
||||
phase.id for phase in self.phases if phase.status == PhaseStatus.COMPLETED
|
||||
}
|
||||
ready: list[PlanPhase] = []
|
||||
for phase in self.phases:
|
||||
if phase.status != PhaseStatus.PENDING:
|
||||
continue
|
||||
if all(dep_id in completed_ids for dep_id in phase.depends_on):
|
||||
ready.append(phase)
|
||||
return ready
|
||||
|
||||
def get_phase(self, phase_id: str) -> PlanPhase | None:
|
||||
"""根据 ID 获取阶段,不存在则返回 None (O(1) lookup)"""
|
||||
return self._phase_index.get(phase_id)
|
||||
|
||||
def update_phase_status(
|
||||
self, phase_id: str, status: PhaseStatus, result: dict | None = None
|
||||
def update_subtask_status(
|
||||
self, subtask_id: str, status: SubTaskStatus, result: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""更新阶段状态和可选的结果"""
|
||||
phase = self.get_phase(phase_id)
|
||||
if phase is not None:
|
||||
phase.status = status
|
||||
"""更新子任务状态和可选的结果"""
|
||||
st = self.get_subtask(subtask_id)
|
||||
if st is not None:
|
||||
st.status = status
|
||||
if result is not None:
|
||||
phase.result = result
|
||||
st.result = result
|
||||
|
||||
@property
|
||||
def completed_subtasks(self) -> list[SubTask]:
|
||||
"""已完成的子任务列表"""
|
||||
return [st for st in self.subtasks if st.status == SubTaskStatus.COMPLETED]
|
||||
|
||||
@property
|
||||
def failed_subtasks(self) -> list[SubTask]:
|
||||
"""失败的子任务列表"""
|
||||
return [st for st in self.subtasks if st.status == SubTaskStatus.FAILED]
|
||||
|
||||
@property
|
||||
def all_done(self) -> bool:
|
||||
"""所有子任务是否都已完成(成功或失败)"""
|
||||
return all(
|
||||
st.status in (SubTaskStatus.COMPLETED, SubTaskStatus.FAILED) for st in self.subtasks
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,12 +4,11 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from agentkit.core.exceptions import ConfigValidationError
|
||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||||
from agentkit.experts.config import ExpertTemplate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -62,10 +61,7 @@ class ExpertTemplateRegistry:
|
|||
query_lower = query.lower()
|
||||
results: list[ExpertTemplate] = []
|
||||
for template in self._templates.values():
|
||||
if (
|
||||
query_lower in template.name.lower()
|
||||
or query_lower in template.description.lower()
|
||||
):
|
||||
if query_lower in template.name.lower() or query_lower in template.description.lower():
|
||||
results.append(template)
|
||||
return results
|
||||
|
||||
|
|
@ -134,7 +130,5 @@ class ExpertTemplateRegistry:
|
|||
template = self.load_from_yaml(filepath)
|
||||
loaded.append(template)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to load ExpertTemplate from '{filepath}': {e}"
|
||||
)
|
||||
logger.warning(f"Failed to load ExpertTemplate from '{filepath}': {e}")
|
||||
return loaded
|
||||
|
|
|
|||
|
|
@ -1,7 +1,13 @@
|
|||
"""ExpertTeam - 专家团队容器
|
||||
"""ExpertTeam - 专家团队容器(hub-and-spoke 模式)
|
||||
|
||||
管理 Expert 生命周期、共享上下文、协作计划和团队状态,
|
||||
是 Expert Team 协作模式的中央协调点。
|
||||
管理 Expert 生命周期、团队状态和事件广播,
|
||||
是 Expert Team hub-and-spoke 协作模式的中央协调点。
|
||||
|
||||
简化说明(U3):
|
||||
- 移除 CollaborationPlan 依赖(Lead Expert 自主分解任务)
|
||||
- 移除跨阶段状态共享(Lead Expert 持有所有状态)
|
||||
- 保留 handoff_transport 用于事件广播(不再用于 Agent 间通信)
|
||||
- 保留 workspace 用于输出保存(不再用于跨阶段状态共享)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -14,7 +20,6 @@ import uuid
|
|||
|
||||
from .config import ExpertConfig
|
||||
from .expert import Expert
|
||||
from .plan import CollaborationPlan, PlanStatus
|
||||
from .registry import ExpertTemplateRegistry
|
||||
from ..core.handoff_transport import InProcessHandoffTransport
|
||||
from ..core.shared_workspace import SharedWorkspace
|
||||
|
|
@ -27,7 +32,6 @@ class TeamStatus(str, enum.Enum):
|
|||
"""ExpertTeam lifecycle states."""
|
||||
|
||||
FORMING = "forming"
|
||||
PLANNING = "planning"
|
||||
EXECUTING = "executing"
|
||||
SYNTHESIZING = "synthesizing"
|
||||
COMPLETED = "completed"
|
||||
|
|
@ -35,7 +39,14 @@ class TeamStatus(str, enum.Enum):
|
|||
|
||||
|
||||
class ExpertTeam:
|
||||
"""Container managing a team of Experts working together on a task."""
|
||||
"""Container managing a team of Experts in hub-and-spoke mode.
|
||||
|
||||
In hub-and-spoke mode:
|
||||
- Lead Expert (hub) receives the task and decomposes it
|
||||
- Member Experts (spokes) execute subtasks in parallel
|
||||
- Lead Expert synthesizes the final result
|
||||
- No inter-agent communication (Lead Expert holds all state)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -51,7 +62,6 @@ class ExpertTeam:
|
|||
self._handoff_transport = InProcessHandoffTransport()
|
||||
self._experts: dict[str, Expert] = {}
|
||||
self._lead_expert_name: str | None = None
|
||||
self._plan: CollaborationPlan | None = None
|
||||
self._status = TeamStatus.FORMING
|
||||
self._team_channel = f"team:{self.team_id}"
|
||||
self._orchestrator_task: asyncio.Task | None = None
|
||||
|
|
@ -66,10 +76,6 @@ class ExpertTeam:
|
|||
return self._experts.get(self._lead_expert_name)
|
||||
return None
|
||||
|
||||
@property
|
||||
def plan(self) -> CollaborationPlan | None:
|
||||
return self._plan
|
||||
|
||||
@property
|
||||
def experts(self) -> list[Expert]:
|
||||
return list(self._experts.values())
|
||||
|
|
@ -80,12 +86,21 @@ class ExpertTeam:
|
|||
|
||||
@property
|
||||
def workspace(self) -> SharedWorkspace:
|
||||
"""Public read access to the team's shared workspace."""
|
||||
"""Public read access to the team's shared workspace.
|
||||
|
||||
In hub-and-spoke mode, workspace is used for output preservation only,
|
||||
not for cross-phase state sharing.
|
||||
"""
|
||||
return self._workspace
|
||||
|
||||
@property
|
||||
def handoff_transport(self):
|
||||
"""Public read access to the team's handoff transport."""
|
||||
"""Public read access to the team's handoff transport.
|
||||
|
||||
In hub-and-spoke mode, handoff_transport is used for event broadcasting
|
||||
(team_formed, expert_step, expert_result, team_synthesis, team_dissolved),
|
||||
not for inter-agent communication.
|
||||
"""
|
||||
return self._handoff_transport
|
||||
|
||||
@property
|
||||
|
|
@ -106,7 +121,13 @@ class ExpertTeam:
|
|||
lead_config: ExpertConfig,
|
||||
member_configs: list[ExpertConfig] | None = None,
|
||||
) -> None:
|
||||
"""Create a team with a Lead Expert and optional members."""
|
||||
"""Create a team with a Lead Expert and optional members.
|
||||
|
||||
In hub-and-spoke mode, the Lead Expert acts as the hub:
|
||||
- Receives the task and decomposes it into subtasks
|
||||
- Dispatches subtasks to member experts (spokes)
|
||||
- Synthesizes the final result
|
||||
"""
|
||||
if not self._pool:
|
||||
raise RuntimeError("AgentPool not configured")
|
||||
|
||||
|
|
@ -128,7 +149,7 @@ class ExpertTeam:
|
|||
for config in member_configs:
|
||||
await self._add_expert_internal(config, team_context)
|
||||
|
||||
self._status = TeamStatus.PLANNING
|
||||
self._status = TeamStatus.EXECUTING
|
||||
|
||||
async def add_expert(self, config_or_template: ExpertConfig | str) -> Expert:
|
||||
"""Add an Expert to the team dynamically.
|
||||
|
|
@ -155,9 +176,7 @@ class ExpertTeam:
|
|||
)
|
||||
return await self._add_expert_internal(config, team_context)
|
||||
|
||||
async def _add_expert_internal(
|
||||
self, config: ExpertConfig, team_context: str
|
||||
) -> Expert:
|
||||
async def _add_expert_internal(self, config: ExpertConfig, team_context: str) -> Expert:
|
||||
"""Internal method to add an Expert."""
|
||||
if not self._pool:
|
||||
raise RuntimeError("AgentPool not configured")
|
||||
|
|
@ -204,13 +223,6 @@ class ExpertTeam:
|
|||
await expert.destroy(self._pool)
|
||||
del self._experts[name]
|
||||
|
||||
# Update plan: reassign phases that referenced the removed expert
|
||||
if self._plan:
|
||||
new_lead_name = self._lead_expert_name
|
||||
for phase in self._plan.phases:
|
||||
if phase.assigned_expert == name:
|
||||
phase.assigned_expert = new_lead_name or ""
|
||||
|
||||
# Broadcast expert left
|
||||
await self._handoff_transport.send(
|
||||
self._team_channel,
|
||||
|
|
@ -220,24 +232,6 @@ class ExpertTeam:
|
|||
},
|
||||
)
|
||||
|
||||
def update_plan(self, plan: CollaborationPlan) -> list[str]:
|
||||
"""Update the collaboration plan. Only Lead Expert or user should call this.
|
||||
|
||||
Returns list of affected expert names on success, or list of validation
|
||||
error strings on failure (empty list with no errors = success).
|
||||
"""
|
||||
errors = plan.validate()
|
||||
if errors:
|
||||
return errors # Return validation errors instead of silently swallowing
|
||||
|
||||
self._plan = plan
|
||||
if plan.status == PlanStatus.CONFIRMED:
|
||||
self._status = TeamStatus.EXECUTING
|
||||
|
||||
# Determine affected experts
|
||||
affected = [p.assigned_expert for p in plan.phases]
|
||||
return affected
|
||||
|
||||
async def broadcast_user_message(self, content: str) -> None:
|
||||
"""Broadcast a user intervention message to all active Experts."""
|
||||
message = {
|
||||
|
|
@ -248,7 +242,11 @@ class ExpertTeam:
|
|||
await self._handoff_transport.send(self._team_channel, message)
|
||||
|
||||
async def get_shared_context(self) -> dict:
|
||||
"""Get the team's shared context from SharedWorkspace."""
|
||||
"""Get the team's shared context from SharedWorkspace.
|
||||
|
||||
In hub-and-spoke mode, this returns preserved outputs only,
|
||||
not cross-phase state.
|
||||
"""
|
||||
context = {}
|
||||
keys = await self._workspace.list_keys()
|
||||
for key in keys:
|
||||
|
|
@ -258,23 +256,6 @@ class ExpertTeam:
|
|||
context[key] = data
|
||||
return context
|
||||
|
||||
async def generate_plan(self, task: str) -> CollaborationPlan:
|
||||
"""Generate a CollaborationPlan for the task.
|
||||
|
||||
Uses hybrid mode: core roles from template registry, auxiliary roles dynamically generated.
|
||||
This method creates a plan structure — the actual LLM-based task decomposition
|
||||
will be handled by TeamOrchestrator.
|
||||
"""
|
||||
plan_id = str(uuid.uuid4())
|
||||
plan = CollaborationPlan(
|
||||
id=plan_id,
|
||||
task=task,
|
||||
phases=[],
|
||||
lead_expert=self._lead_expert_name or "",
|
||||
)
|
||||
self._plan = plan
|
||||
return plan
|
||||
|
||||
async def dissolve(self) -> None:
|
||||
"""Dissolve the team. Temporary Experts are recycled, outputs preserved in SharedWorkspace."""
|
||||
# Cancel ongoing orchestrator task if any
|
||||
|
|
@ -302,20 +283,31 @@ class ExpertTeam:
|
|||
lead_config: ExpertConfig | None,
|
||||
member_configs: list[ExpertConfig],
|
||||
) -> str:
|
||||
"""Build team context string for injection into Expert system prompts."""
|
||||
lines = ["You are part of an Expert Team."]
|
||||
"""Build team context string for injection into Expert system prompts.
|
||||
|
||||
In hub-and-spoke mode, the context emphasizes the Lead Expert's role
|
||||
as the hub and member experts' role as spokes.
|
||||
"""
|
||||
lines = ["You are part of an Expert Team (hub-and-spoke mode)."]
|
||||
|
||||
if lead_config:
|
||||
lines.append(f"Lead Expert: {lead_config.name} ({lead_config.persona})")
|
||||
lines.append(
|
||||
f"Lead Expert (hub): {lead_config.name} ({lead_config.persona}) — "
|
||||
f"decomposes tasks, dispatches subtasks, and synthesizes results."
|
||||
)
|
||||
|
||||
for config in member_configs:
|
||||
if lead_config and config.name == lead_config.name:
|
||||
continue
|
||||
lines.append(
|
||||
f"Team Member: {config.name} ({config.persona}), Skills: {', '.join(config.bound_skills)}"
|
||||
f"Team Member (spoke): {config.name} ({config.persona}), "
|
||||
f"Skills: {', '.join(config.bound_skills)} — "
|
||||
f"executes assigned subtasks independently."
|
||||
)
|
||||
|
||||
lines.append(
|
||||
"You can collaborate with other team members via send_message() and request_assist()."
|
||||
"In hub-and-spoke mode: Lead Expert holds all state, "
|
||||
"subtasks are independent (depth=1, no further spawning), "
|
||||
"no inter-agent communication."
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
|
|||
from agentkit.tools.ask_human import AskHumanTool
|
||||
from agentkit.tools.memory_tool import MemoryTool
|
||||
from agentkit.tools.web_search import WebSearchTool
|
||||
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
|
||||
from agentkit.tools.search import ToolSearchIndex
|
||||
|
||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||
try:
|
||||
|
|
@ -40,6 +42,9 @@ __all__ = [
|
|||
"MemoryTool",
|
||||
"ShellTool",
|
||||
"WebSearchTool",
|
||||
"RunTestsTool",
|
||||
"ToolSearchTool",
|
||||
"ToolSearchIndex",
|
||||
"HeadroomRetrieveTool",
|
||||
"TerminalSession",
|
||||
"TerminalSessionManager",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,189 @@
|
|||
"""Tool search index using BM25 algorithm.
|
||||
|
||||
Provides lexical tool discovery for tiered tool description injection:
|
||||
core tools are fully injected into the LLM prompt, while extended tools
|
||||
are searchable on-demand via the ``tool_search`` tool.
|
||||
|
||||
Pure Python implementation — no external dependencies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
__all__ = ["ToolSearchIndex"]
|
||||
|
||||
|
||||
# One token = a maximal run of ASCII letters, digits and underscores.
_TOKEN_RE = re.compile(r"[a-zA-Z0-9_]+")


def _tokenize(text: str) -> list[str]:
    """Break *text* into lowercase word tokens.

    The text is lowercased first; every maximal run of ASCII letters,
    digits and underscores then becomes one token.  Underscores are part
    of a token, so a snake_case identifier such as ``read_file`` is kept
    as a single token rather than being split.  Any other character
    (punctuation, whitespace, non-ASCII text) only acts as a separator.

    Returns:
        The tokens in order of appearance; an empty list when the text
        contains no matching characters.
    """
    lowered = text.lower()
    # The ``+`` quantifier guarantees every match is non-empty.
    return [match.group(0) for match in _TOKEN_RE.finditer(lowered)]
|
||||
|
||||
|
||||
class ToolSearchIndex:
    """BM25-based search index over :class:`Tool` descriptions.

    Each tool is indexed by its ``name`` + ``description`` + parameter
    metadata (from ``input_schema``) + ``tags``.  Supports keyword search
    returning the most relevant tools ranked by BM25 score.

    The index is built once in the constructor from a snapshot of the
    given tool list; later changes to the caller's list are not seen.

    Usage::

        index = ToolSearchIndex(tools)
        results = index.search("read file", top_k=5)
    """

    _DEFAULT_K1: float = 1.5
    _DEFAULT_B: float = 0.75

    def __init__(
        self,
        tools: list[Tool],
        k1: float = _DEFAULT_K1,
        b: float = _DEFAULT_B,
    ):
        """Build the index.

        Args:
            tools: Tools to index.  The list is copied.
            k1: BM25 term-frequency saturation parameter (must be >= 0).
            b: BM25 length-normalisation parameter (must be in [0, 1]).

        Raises:
            ValueError: If ``k1`` or ``b`` is outside its valid range.
        """
        if k1 < 0:
            raise ValueError(f"k1 must be >= 0, got {k1}")
        if not (0.0 <= b <= 1.0):
            raise ValueError(f"b must be in [0, 1], got {b}")

        self._tools: list[Tool] = list(tools)
        self._k1 = k1
        self._b = b

        # Build document corpus from tool metadata
        self._documents: list[str] = [self._tool_to_text(t) for t in self._tools]
        self._tokenized_docs: list[list[str]] = [_tokenize(doc) for doc in self._documents]
        self._doc_lengths: list[int] = [len(tokens) for tokens in self._tokenized_docs]
        # Average document length; 0.0 for an empty corpus.
        self._avgdl: float = (
            sum(self._doc_lengths) / len(self._doc_lengths) if self._doc_lengths else 0.0
        )

        # Term frequencies per document
        self._term_freqs: list[Counter[str]] = [Counter(tokens) for tokens in self._tokenized_docs]

        # Document frequency per term (number of docs containing the term)
        self._doc_freq: Counter[str] = Counter()
        for tokens in self._tokenized_docs:
            for term in set(tokens):
                self._doc_freq[term] += 1

        # Precompute IDF for every term in the corpus
        self._idf: dict[str, float] = self._compute_idf()

    # ------------------------------------------------------------------
    # Index construction
    # ------------------------------------------------------------------

    @staticmethod
    def _tool_to_text(tool: Tool) -> str:
        """Convert a tool's searchable metadata into a single text document.

        The document is the space-joined sequence of: name, description,
        each schema property name (followed by its ``description`` when the
        property is a dict with a non-empty one), then the tags.

        Malformed metadata is tolerated rather than raising:
        a ``None`` name/description is skipped (instead of being indexed as
        the word "None"), a schema or ``properties`` value that is not a
        dict is ignored, and a plain-string ``tags`` value is treated as a
        single tag (instead of being split into characters).
        """
        parts: list[str] = [
            str(value) for value in (tool.name, tool.description) if value is not None
        ]

        schema: dict[str, Any] | None = tool.input_schema
        props = schema.get("properties") if isinstance(schema, dict) else None
        if isinstance(props, dict):
            for pname, pinfo in props.items():
                parts.append(str(pname))
                if isinstance(pinfo, dict):
                    pdesc = pinfo.get("description", "")
                    if pdesc:
                        parts.append(str(pdesc))

        tags = tool.tags
        if tags:
            if isinstance(tags, str):
                # Iterating a str would yield one "tag" per character.
                parts.append(tags)
            else:
                parts.extend(str(t) for t in tags)

        return " ".join(parts)

    def _compute_idf(self) -> dict[str, float]:
        """Compute IDF for each term using the BM25 formula.

        ``idf(t) = ln((N - n + 0.5) / (n + 0.5) + 1)``

        where ``N`` is the total number of documents and ``n`` is the
        number of documents containing term ``t``.  The ``+1`` inside the
        logarithm guarantees a non-negative IDF.

        Returns:
            Mapping of term to IDF; empty when the index holds no tools.
        """
        n_docs = len(self._tools)
        if n_docs == 0:
            return {}

        idf: dict[str, float] = {}
        for term, df in self._doc_freq.items():
            idf[term] = math.log((n_docs - df + 0.5) / (df + 0.5) + 1.0)
        return idf

    # ------------------------------------------------------------------
    # Scoring
    # ------------------------------------------------------------------

    def _bm25_score(self, query_tokens: list[str], doc_idx: int) -> float:
        """Compute the BM25 score for a single document against the query.

        Returns ``0.0`` for an empty query or an index past the end of the
        tool list.
        """
        if not query_tokens or doc_idx >= len(self._tools):
            return 0.0

        doc_len = self._doc_lengths[doc_idx]
        term_freq = self._term_freqs[doc_idx]
        # Length normalisation; when every document is empty (avgdl == 0)
        # the length ratio is undefined, so only the (1 - b) part is used.
        norm = (
            (1 - self._b + self._b * (doc_len / self._avgdl)) if self._avgdl > 0 else (1 - self._b)
        )
        score = 0.0

        # Each unique query term contributes once
        for term in set(query_tokens):
            tf = term_freq.get(term, 0)
            if tf == 0:
                continue
            idf = self._idf.get(term, 0.0)
            if idf <= 0:
                continue
            # tf > 0 here, so the denominator is strictly positive.
            denom = tf + self._k1 * norm
            score += idf * (tf * (self._k1 + 1)) / denom

        return score

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def search(self, query: str, top_k: int = 5) -> list[Tool]:
        """Search tools by keyword query.

        Args:
            query: Search keywords (natural language or tool-name fragments).
            top_k: Maximum number of tools to return.

        Returns:
            Tools ranked by relevance (most relevant first).  Only tools
            with a positive BM25 score are returned.  An empty list is
            returned when ``top_k <= 0``, the index is empty, or the query
            yields no tokens.
        """
        if top_k <= 0 or not self._tools:
            return []

        query_tokens = _tokenize(query)
        if not query_tokens:
            return []

        scored: list[tuple[int, float]] = []
        for i in range(len(self._tools)):
            score = self._bm25_score(query_tokens, i)
            if score > 0:
                scored.append((i, score))

        # list.sort is stable (also with reverse=True), so tools with equal
        # scores keep their original index order.
        scored.sort(key=lambda x: x[1], reverse=True)
        return [self._tools[i] for i, _ in scored[:top_k]]

    def __len__(self) -> int:
        """Return the number of indexed tools."""
        return len(self._tools)
|
||||
|
|
@ -0,0 +1,666 @@
|
|||
"""Tests for EventQueue — SQ/EQ 双队列实现
|
||||
|
||||
测试场景:
|
||||
- SQ 正确接收用户输入并返回 task_id
|
||||
- EQ 正确推送事件给订阅者
|
||||
- 多订阅者同时接收事件(广播)
|
||||
- 事件缓冲对新订阅者的回放
|
||||
- SQ 取消任务
|
||||
- 事件类型正确分类
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue
|
||||
from agentkit.core.protocol import (
|
||||
Event,
|
||||
SessionEventType,
|
||||
TaskEventType,
|
||||
TurnEventType,
|
||||
)
|
||||
|
||||
|
||||
# ── SubmissionQueue Tests ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestSubmissionQueue:
    """Unit tests for SubmissionQueue."""

    @staticmethod
    async def _drain(sq: SubmissionQueue, limit: int) -> list[Submission]:
        """Consume ``sq.drain()`` until *limit* submissions have arrived.

        The consumer runs as its own task and is bounded by a one-second
        timeout, so a queue that never delivers fails the test instead of
        hanging it.
        """
        collected: list[Submission] = []

        async def _consume() -> None:
            async for item in sq.drain():
                collected.append(item)
                if len(collected) >= limit:
                    break

        await asyncio.wait_for(asyncio.create_task(_consume()), timeout=1.0)
        return collected

    async def test_submit_returns_task_id(self):
        """submit() returns a non-empty string task_id."""
        queue = SubmissionQueue()

        new_id = await queue.submit("hello", "session-1")

        assert isinstance(new_id, str)
        assert len(new_id) > 0

    async def test_submit_returns_unique_task_ids(self):
        """Every submit() call returns a different task_id."""
        queue = SubmissionQueue()

        first_id = await queue.submit("hello", "session-1")
        second_id = await queue.submit("world", "session-1")

        assert first_id != second_id

    async def test_submit_stores_submission(self):
        """submit() records the submission under its task_id."""
        queue = SubmissionQueue()

        new_id = await queue.submit("hello world", "session-1")

        assert new_id in queue._submissions
        stored = queue._submissions[new_id]
        assert stored.content == "hello world"
        assert stored.session_id == "session-1"
        assert stored.task_id == new_id
        assert stored.cancelled is False

    async def test_drain_receives_submissions_in_order(self):
        """drain() yields submissions in the order they were submitted."""
        queue = SubmissionQueue()
        for text in ("first", "second"):
            await queue.submit(text, "session-1")

        drained = await self._drain(queue, 2)

        assert [item.content for item in drained] == ["first", "second"]

    async def test_drain_preserves_submission_fields(self):
        """The Submission yielded by drain() carries all of its fields."""
        queue = SubmissionQueue()
        await queue.submit("hello", "session-1")

        drained = await self._drain(queue, 1)

        assert len(drained) == 1
        only = drained[0]
        assert only.content == "hello"
        assert only.session_id == "session-1"
        assert isinstance(only.task_id, str)
        assert isinstance(only.created_at, datetime)

    async def test_cancel_task_succeeds(self):
        """Cancelling an existing task succeeds and marks it cancelled."""
        queue = SubmissionQueue()
        new_id = await queue.submit("hello", "session-1")

        outcome = await queue.cancel(new_id)

        assert outcome is True
        assert new_id in queue._cancelled_tasks
        assert queue._submissions[new_id].cancelled is True

    async def test_cancel_nonexistent_task_returns_false(self):
        """Cancelling an unknown task_id returns False."""
        queue = SubmissionQueue()

        outcome = await queue.cancel("nonexistent-task-id")

        assert outcome is False

    async def test_cancel_already_cancelled_task_returns_false(self):
        """Cancelling the same task a second time returns False."""
        queue = SubmissionQueue()
        new_id = await queue.submit("hello", "session-1")

        outcomes = [await queue.cancel(new_id), await queue.cancel(new_id)]

        assert outcomes[0] is True
        assert outcomes[1] is False

    async def test_drain_skips_cancelled_submissions(self):
        """drain() does not yield submissions that were cancelled."""
        queue = SubmissionQueue()
        cancelled_id = await queue.submit("first", "session-1")
        await queue.submit("second", "session-1")

        # Cancel the first submission before anything is consumed.
        await queue.cancel(cancelled_id)

        drained = await self._drain(queue, 1)

        # Only the second submission should come through.
        assert [item.content for item in drained] == ["second"]

    async def test_close_prevents_new_submissions(self):
        """After close(), submit() raises RuntimeError."""
        queue = SubmissionQueue()
        queue.close()

        assert queue.is_closed is True

        rejected = False
        try:
            await queue.submit("hello", "session-1")
        except RuntimeError:
            rejected = True
        assert rejected, "Should have raised RuntimeError"

    async def test_close_does_not_affect_existing_submissions(self):
        """Submissions made before close() can still be drained."""
        queue = SubmissionQueue()
        await queue.submit("before-close", "session-1")
        queue.close()

        drained = await self._drain(queue, 1)

        assert [item.content for item in drained] == ["before-close"]
|
||||
|
||||
|
||||
# ── EventQueue Tests ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEventQueue:
    """Unit tests for EventQueue."""

    @staticmethod
    def _token(seq: int) -> Event:
        """Build a TOKEN event for task-1/session-1 carrying ``{"seq": seq}``."""
        return Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": seq})

    @staticmethod
    async def _collect(eq: EventQueue, limit: int | None = None) -> list[Event]:
        """Subscribe to *eq* and gather events.

        Stops after *limit* events, or runs until the subscription ends
        when *limit* is ``None``.
        """
        gathered: list[Event] = []
        async for evt in eq.subscribe():
            gathered.append(evt)
            if limit is not None and len(gathered) >= limit:
                break
        return gathered

    async def test_emit_and_subscribe_single_event(self):
        """An emitted event is delivered to a subscriber."""
        eq = EventQueue()
        event = Event.create(
            event_type=TurnEventType.TOKEN,
            task_id="task-1",
            session_id="session-1",
            data={"text": "hello"},
        )

        listener = asyncio.create_task(self._collect(eq, 1))
        await asyncio.sleep(0.05)  # give the subscriber time to start

        await eq.emit(event)
        received = await asyncio.wait_for(listener, timeout=1.0)

        assert len(received) == 1
        delivered = received[0]
        assert delivered.event_type == TurnEventType.TOKEN
        assert delivered.task_id == "task-1"
        assert delivered.session_id == "session-1"
        assert delivered.data == {"text": "hello"}

    async def test_emit_preserves_event_fields(self):
        """emit() does not alter the event's fields."""
        eq = EventQueue()
        original = Event.create(
            event_type=TaskEventType.TASK_STARTED,
            task_id="task-1",
            session_id="session-1",
            data={"agent": "react"},
        )

        listener = asyncio.create_task(self._collect(eq, 1))
        await asyncio.sleep(0.05)

        await eq.emit(original)
        received = await asyncio.wait_for(listener, timeout=1.0)

        assert received[0] is original or received[0].to_dict() == original.to_dict()

    async def test_broadcast_to_multiple_subscribers(self):
        """Every subscriber receives every event (broadcast)."""
        eq = EventQueue()

        listener_a = asyncio.create_task(self._collect(eq, 2))
        listener_b = asyncio.create_task(self._collect(eq, 2))
        await asyncio.sleep(0.05)  # give the subscribers time to start

        await eq.emit(self._token(1))
        await eq.emit(self._token(2))

        received_a = await asyncio.wait_for(listener_a, timeout=1.0)
        received_b = await asyncio.wait_for(listener_b, timeout=1.0)

        assert len(received_a) == 2
        assert len(received_b) == 2
        assert received_a[0].data == {"seq": 1}
        assert received_a[1].data == {"seq": 2}
        assert received_b[0].data == {"seq": 1}
        assert received_b[1].data == {"seq": 2}

    async def test_buffer_replay_for_new_subscriber(self):
        """A new subscriber gets the buffered events replayed."""
        eq = EventQueue(buffer_size=100)

        # Emit a few events while nobody is subscribed.
        for seq in (1, 2, 3):
            await eq.emit(self._token(seq))

        # The late subscriber should see the buffered events.
        listener = asyncio.create_task(self._collect(eq, 3))
        received = await asyncio.wait_for(listener, timeout=1.0)

        assert len(received) == 3
        assert received[0].data == {"seq": 1}
        assert received[1].data == {"seq": 2}
        assert received[2].data == {"seq": 3}

    async def test_buffer_replay_then_live_events(self):
        """A new subscriber gets the replay first, then live events."""
        eq = EventQueue(buffer_size=100)

        # Two events go into the buffer before anyone subscribes.
        await eq.emit(self._token(1))
        await eq.emit(self._token(2))

        listener = asyncio.create_task(self._collect(eq, 4))
        await asyncio.sleep(0.05)  # let the subscriber start (and replay the buffer)

        # Two more events arrive live.
        await eq.emit(self._token(3))
        await eq.emit(self._token(4))

        received = await asyncio.wait_for(listener, timeout=1.0)

        assert len(received) == 4
        assert [r.data["seq"] for r in received] == [1, 2, 3, 4]

    async def test_buffer_size_limit_keeps_latest(self):
        """The buffer is bounded and keeps only the newest N events."""
        eq = EventQueue(buffer_size=3)

        # Emit 5 events; only the last 3 should stay buffered.
        for i in range(5):
            await eq.emit(self._token(i))

        listener = asyncio.create_task(self._collect(eq, 3))
        received = await asyncio.wait_for(listener, timeout=1.0)

        assert len(received) == 3
        # Expect the last three (seq: 2, 3, 4).
        assert [r.data["seq"] for r in received] == [2, 3, 4]

    async def test_default_buffer_size_is_100(self):
        """The default buffer size is 100."""
        eq = EventQueue()

        assert eq.buffer_size == 100

    async def test_close_unblocks_subscribers(self):
        """close() releases subscribers that are waiting for events."""
        eq = EventQueue()

        # Consumes events until the queue is closed.
        listener = asyncio.create_task(self._collect(eq))
        await asyncio.sleep(0.05)

        eq.close()
        await asyncio.wait_for(listener, timeout=1.0)

        assert listener.done()
        assert eq.is_closed is True

    async def test_subscribe_after_close_returns_immediately(self):
        """Subscribing to a closed queue finishes at once without blocking."""
        eq = EventQueue()
        eq.close()

        listener = asyncio.create_task(self._collect(eq))
        received = await asyncio.wait_for(listener, timeout=1.0)

        assert listener.done()
        assert len(received) == 0

    async def test_subscriber_count_tracks_subscriptions(self):
        """subscriber_count follows subscriptions as they start and end."""
        eq = EventQueue()

        assert eq.subscriber_count == 0

        listener = asyncio.create_task(self._collect(eq))
        await asyncio.sleep(0.05)

        assert eq.subscriber_count == 1

        eq.close()
        await asyncio.wait_for(listener, timeout=1.0)

        assert eq.subscriber_count == 0

    async def test_subscriber_removed_on_explicit_close(self):
        """Explicitly closing the subscription generator deregisters it.

        Note: breaking out of ``async for`` does not immediately run the
        generator's ``finally`` block; ``aclose()`` must be called
        explicitly to guarantee the cleanup.
        """
        eq = EventQueue()

        received: list[Event] = []

        async def take_one_then_close():
            stream = eq.subscribe()
            try:
                async for evt in stream:
                    received.append(evt)
                    break
            finally:
                await stream.aclose()

        listener = asyncio.create_task(take_one_then_close())
        await asyncio.sleep(0.05)
        assert eq.subscriber_count == 1

        # One emit lets the subscriber receive an event and break out.
        await eq.emit(self._token(1))
        await asyncio.wait_for(listener, timeout=1.0)

        assert len(received) == 1
        assert eq.subscriber_count == 0

    async def test_emit_to_no_subscribers_still_buffers(self):
        """emit() still writes to the buffer when nobody is subscribed."""
        eq = EventQueue(buffer_size=100)

        await eq.emit(self._token(1))

        # The buffer should hold exactly one event.
        assert len(eq._buffer) == 1

        # A new subscriber should get it replayed.
        listener = asyncio.create_task(self._collect(eq, 1))
        received = await asyncio.wait_for(listener, timeout=1.0)

        assert len(received) == 1
        assert received[0].data == {"seq": 1}
|
||||
|
||||
|
||||
# ── Event Type Tests ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEventTypes:
    """Tests for how event types are classified."""

    def test_session_event_types(self):
        """Session-level event types have the expected string values."""
        expected = [
            (SessionEventType.SESSION_STARTED, "session.started"),
            (SessionEventType.SESSION_ENDED, "session.ended"),
        ]
        for member, literal in expected:
            assert member == literal

    def test_task_event_types(self):
        """Task-level event types have the expected string values."""
        expected = [
            (TaskEventType.TASK_CREATED, "task.created"),
            (TaskEventType.TASK_STARTED, "task.started"),
            (TaskEventType.TASK_COMPLETED, "task.completed"),
            (TaskEventType.TASK_FAILED, "task.failed"),
        ]
        for member, literal in expected:
            assert member == literal

    def test_turn_event_types(self):
        """Turn-level event types have the expected string values."""
        expected = [
            (TurnEventType.TURN_STARTED, "turn.started"),
            (TurnEventType.THINKING, "turn.thinking"),
            (TurnEventType.TOOL_CALL, "turn.tool_call"),
            (TurnEventType.TOOL_RESULT, "turn.tool_result"),
            (TurnEventType.TOKEN, "turn.token"),
            (TurnEventType.STEP, "turn.step"),
            (TurnEventType.FINAL_ANSWER, "turn.final_answer"),
            (TurnEventType.TURN_COMPLETED, "turn.completed"),
        ]
        for member, literal in expected:
            assert member == literal

    def test_event_type_prefixes(self):
        """Each event type carries the prefix of the level it belongs to."""
        session_events = [SessionEventType.SESSION_STARTED, SessionEventType.SESSION_ENDED]
        task_events = [
            TaskEventType.TASK_CREATED,
            TaskEventType.TASK_STARTED,
            TaskEventType.TASK_COMPLETED,
            TaskEventType.TASK_FAILED,
        ]
        turn_events = [
            TurnEventType.TURN_STARTED,
            TurnEventType.THINKING,
            TurnEventType.TOOL_CALL,
            TurnEventType.TOOL_RESULT,
            TurnEventType.TOKEN,
            TurnEventType.STEP,
            TurnEventType.FINAL_ANSWER,
            TurnEventType.TURN_COMPLETED,
        ]

        for evt in session_events:
            assert evt.startswith("session."), f"{evt} should start with 'session.'"
        for evt in task_events:
            assert evt.startswith("task."), f"{evt} should start with 'task.'"
        for evt in turn_events:
            assert evt.startswith("turn."), f"{evt} should start with 'turn.'"

    def test_event_types_are_distinct(self):
        """No two event types share the same value."""
        all_types = [
            SessionEventType.SESSION_STARTED,
            SessionEventType.SESSION_ENDED,
            TaskEventType.TASK_CREATED,
            TaskEventType.TASK_STARTED,
            TaskEventType.TASK_COMPLETED,
            TaskEventType.TASK_FAILED,
            TurnEventType.TURN_STARTED,
            TurnEventType.THINKING,
            TurnEventType.TOOL_CALL,
            TurnEventType.TOOL_RESULT,
            TurnEventType.TOKEN,
            TurnEventType.STEP,
            TurnEventType.FINAL_ANSWER,
            TurnEventType.TURN_COMPLETED,
        ]

        unique_types = set(all_types)
        assert len(all_types) == len(unique_types), "Event types should be distinct"
|
||||
|
||||
|
||||
# ── Event Dataclass Tests ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestEventDataclass:
|
||||
"""Event 数据结构测试"""
|
||||
|
||||
def test_event_creation(self):
|
||||
"""测试 Event 创建"""
|
||||
event = Event(
|
||||
event_type=TurnEventType.TOKEN,
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
data={"text": "hello"},
|
||||
timestamp="2025-01-01T00:00:00+00:00",
|
||||
)
|
||||
|
||||
assert event.event_type == TurnEventType.TOKEN
|
||||
assert event.task_id == "task-1"
|
||||
assert event.session_id == "session-1"
|
||||
assert event.data == {"text": "hello"}
|
||||
assert event.timestamp == "2025-01-01T00:00:00+00:00"
|
||||
|
||||
def test_event_create_factory_generates_timestamp(self):
|
||||
"""测试 Event.create 工厂方法自动生成时间戳"""
|
||||
event = Event.create(
|
||||
event_type=TaskEventType.TASK_STARTED,
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
data={"agent": "react"},
|
||||
)
|
||||
|
||||
assert event.event_type == TaskEventType.TASK_STARTED
|
||||
assert event.task_id == "task-1"
|
||||
assert event.session_id == "session-1"
|
||||
assert event.data == {"agent": "react"}
|
||||
assert len(event.timestamp) > 0
|
||||
# 时间戳应为 ISO 8601 格式(fromisoformat 不抛异常即正确)
|
||||
datetime.fromisoformat(event.timestamp)
|
||||
|
||||
def test_event_create_with_default_data(self):
|
||||
"""测试 Event.create 不传 data 时默认为空 dict"""
|
||||
event = Event.create(
|
||||
event_type=SessionEventType.SESSION_STARTED,
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
)
|
||||
|
||||
assert event.data == {}
|
||||
|
||||
def test_event_to_dict(self):
|
||||
"""测试 Event.to_dict"""
|
||||
event = Event(
|
||||
event_type=TurnEventType.TOKEN,
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
data={"text": "hello"},
|
||||
timestamp="2025-01-01T00:00:00+00:00",
|
||||
)
|
||||
|
||||
d = event.to_dict()
|
||||
|
||||
assert d == {
|
||||
"event_type": TurnEventType.TOKEN,
|
||||
"task_id": "task-1",
|
||||
"session_id": "session-1",
|
||||
"data": {"text": "hello"},
|
||||
"timestamp": "2025-01-01T00:00:00+00:00",
|
||||
}
|
||||
|
||||
def test_event_from_dict(self):
|
||||
"""测试 Event.from_dict"""
|
||||
data = {
|
||||
"event_type": TurnEventType.TOKEN,
|
||||
"task_id": "task-1",
|
||||
"session_id": "session-1",
|
||||
"data": {"text": "hello"},
|
||||
"timestamp": "2025-01-01T00:00:00+00:00",
|
||||
}
|
||||
|
||||
event = Event.from_dict(data)
|
||||
|
||||
assert event.event_type == TurnEventType.TOKEN
|
||||
assert event.task_id == "task-1"
|
||||
assert event.session_id == "session-1"
|
||||
assert event.data == {"text": "hello"}
|
||||
assert event.timestamp == "2025-01-01T00:00:00+00:00"
|
||||
|
||||
def test_event_to_dict_from_dict_roundtrip(self):
    """An event survives to_dict -> from_dict with every field unchanged."""
    original = Event.create(
        event_type=SessionEventType.SESSION_STARTED,
        task_id="task-1",
        session_id="session-1",
        data={"user": "alice"},
    )

    restored = Event.from_dict(original.to_dict())

    # Compare field by field, in the same order as the Event fields above.
    for field_name in ("event_type", "task_id", "session_id", "data", "timestamp"):
        assert getattr(restored, field_name) == getattr(original, field_name)
|
||||
|
||||
def test_event_from_dict_with_missing_data_defaults_to_empty(self):
    """Event.from_dict falls back to an empty dict when the "data" key is absent."""
    without_data = {
        "event_type": TurnEventType.TOKEN,
        "task_id": "task-1",
        "session_id": "session-1",
        "timestamp": "2025-01-01T00:00:00+00:00",
    }

    rebuilt = Event.from_dict(without_data)

    assert rebuilt.data == {}
|
||||
|
|
@ -1,328 +1,288 @@
|
|||
"""CollaborationPlan 数据模型单元测试"""
|
||||
"""TeamPlan / SubTask 数据模型单元测试 (hub-and-spoke 模式)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.experts.plan import (
|
||||
CollaborationPlan,
|
||||
MergeStrategy,
|
||||
ParallelType,
|
||||
PhaseStatus,
|
||||
PlanPhase,
|
||||
PlanStatus,
|
||||
SubTask,
|
||||
SubTaskStatus,
|
||||
TeamPlan,
|
||||
)
|
||||
|
||||
|
||||
# ── 辅助函数 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_phase(
|
||||
id: str = "phase_1",
|
||||
name: str = "分析阶段",
|
||||
def _make_subtask(
|
||||
id: str = "subtask_1",
|
||||
description: str = "分析数据",
|
||||
assigned_expert: str = "analyst",
|
||||
task_description: str = "分析需求",
|
||||
depends_on: list[str] | None = None,
|
||||
parallel_type: ParallelType = ParallelType.SERIAL,
|
||||
merge_strategy: MergeStrategy | None = None,
|
||||
milestone: str = "",
|
||||
status: PhaseStatus = PhaseStatus.PENDING,
|
||||
status: SubTaskStatus = SubTaskStatus.PENDING,
|
||||
result: dict | None = None,
|
||||
) -> PlanPhase:
|
||||
"""创建测试用 PlanPhase 实例"""
|
||||
return PlanPhase(
|
||||
) -> SubTask:
|
||||
"""创建测试用 SubTask 实例"""
|
||||
return SubTask(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
assigned_expert=assigned_expert,
|
||||
task_description=task_description,
|
||||
depends_on=depends_on or [],
|
||||
parallel_type=parallel_type,
|
||||
merge_strategy=merge_strategy,
|
||||
milestone=milestone,
|
||||
status=status,
|
||||
result=result,
|
||||
)
|
||||
|
||||
|
||||
def _make_valid_plan() -> CollaborationPlan:
|
||||
"""创建一个有效的协作计划"""
|
||||
phases = [
|
||||
_make_phase(id="p1", name="需求分析", assigned_expert="analyst", task_description="分析需求"),
|
||||
_make_phase(
|
||||
id="p2",
|
||||
name="架构设计",
|
||||
assigned_expert="architect",
|
||||
task_description="设计架构",
|
||||
depends_on=["p1"],
|
||||
),
|
||||
_make_phase(
|
||||
id="p3",
|
||||
name="代码实现",
|
||||
assigned_expert="coder",
|
||||
task_description="编写代码",
|
||||
depends_on=["p2"],
|
||||
),
|
||||
def _make_valid_plan() -> TeamPlan:
|
||||
"""创建一个有效的 hub-and-spoke 执行计划"""
|
||||
subtasks = [
|
||||
_make_subtask(id="s1", description="分析需求", assigned_expert="analyst"),
|
||||
_make_subtask(id="s2", description="设计架构", assigned_expert="architect"),
|
||||
_make_subtask(id="s3", description="编写代码", assigned_expert="coder"),
|
||||
]
|
||||
return CollaborationPlan(
|
||||
return TeamPlan(
|
||||
id="plan_001",
|
||||
task="实现用户登录功能",
|
||||
phases=phases,
|
||||
variables={"project": "fischer"},
|
||||
subtasks=subtasks,
|
||||
status=PlanStatus.DRAFT,
|
||||
lead_expert="architect",
|
||||
)
|
||||
|
||||
|
||||
# ── PlanPhase 测试 ────────────────────────────────────────
|
||||
# ── MergeStrategy 测试 ────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlanPhase:
|
||||
"""PlanPhase 数据模型测试"""
|
||||
class TestMergeStrategy:
|
||||
"""MergeStrategy 枚举测试"""
|
||||
|
||||
def test_only_best_strategy_exists(self):
    """Hub-and-spoke mode keeps BEST as the only merge strategy."""
    assert MergeStrategy.BEST == "best"
    # BEST must be the single member of the enum.
    members = list(MergeStrategy)
    assert len(members) == 1
|
||||
|
||||
def test_no_vote_or_fusion(self):
    """VOTE and FUSION are no longer attributes of MergeStrategy."""
    for removed_name in ("VOTE", "FUSION"):
        assert not hasattr(MergeStrategy, removed_name)
|
||||
|
||||
|
||||
# ── PlanStatus 测试 ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlanStatus:
    """Tests for the PlanStatus enum."""

    def test_statuses_exist(self):
        """Every required plan status compares equal to its string value."""
        expected_values = [
            (PlanStatus.DRAFT, "draft"),
            (PlanStatus.EXECUTING, "executing"),
            (PlanStatus.COMPLETED, "completed"),
            (PlanStatus.FAILED, "failed"),
            (PlanStatus.FALLBACK, "fallback"),
        ]
        for member, value in expected_values:
            assert member == value

    def test_no_confirmed_status(self):
        """CONFIRMED is gone: hub-and-spoke has no confirmation stage."""
        has_confirmed = hasattr(PlanStatus, "CONFIRMED")
        assert not has_confirmed
|
||||
|
||||
|
||||
# ── SubTaskStatus 测试 ────────────────────────────────────
|
||||
|
||||
|
||||
class TestSubTaskStatus:
    """Tests for the SubTaskStatus enum."""

    def test_statuses_exist(self):
        """Every subtask status compares equal to its string value."""
        expected_values = [
            (SubTaskStatus.PENDING, "pending"),
            (SubTaskStatus.RUNNING, "running"),
            (SubTaskStatus.COMPLETED, "completed"),
            (SubTaskStatus.FAILED, "failed"),
        ]
        for member, value in expected_values:
            assert member == value
|
||||
|
||||
|
||||
# ── SubTask 测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSubTask:
|
||||
"""SubTask 数据模型测试"""
|
||||
|
||||
def test_creation_with_all_fields(self):
|
||||
"""创建 PlanPhase 并设置所有字段"""
|
||||
phase = PlanPhase(
|
||||
id="phase_a",
|
||||
name="竞品分析",
|
||||
"""创建 SubTask 并设置所有字段"""
|
||||
subtask = SubTask(
|
||||
id="subtask_a",
|
||||
description="竞品分析",
|
||||
assigned_expert="analyst",
|
||||
task_description="分析竞品功能",
|
||||
depends_on=["phase_0"],
|
||||
parallel_type=ParallelType.COMPETITIVE_PARALLEL,
|
||||
merge_strategy=MergeStrategy.BEST,
|
||||
milestone="竞品报告完成",
|
||||
status=PhaseStatus.IN_PROGRESS,
|
||||
status=SubTaskStatus.RUNNING,
|
||||
result={"report": "竞品分析报告"},
|
||||
)
|
||||
assert phase.id == "phase_a"
|
||||
assert phase.name == "竞品分析"
|
||||
assert phase.assigned_expert == "analyst"
|
||||
assert phase.task_description == "分析竞品功能"
|
||||
assert phase.depends_on == ["phase_0"]
|
||||
assert phase.parallel_type == ParallelType.COMPETITIVE_PARALLEL
|
||||
assert phase.merge_strategy == MergeStrategy.BEST
|
||||
assert phase.milestone == "竞品报告完成"
|
||||
assert phase.status == PhaseStatus.IN_PROGRESS
|
||||
assert phase.result == {"report": "竞品分析报告"}
|
||||
assert subtask.id == "subtask_a"
|
||||
assert subtask.description == "竞品分析"
|
||||
assert subtask.assigned_expert == "analyst"
|
||||
assert subtask.status == SubTaskStatus.RUNNING
|
||||
assert subtask.result == {"report": "竞品分析报告"}
|
||||
|
||||
def test_default_values(self):
    """A SubTask built from a description alone gets an id and PENDING status."""
    bare = SubTask(description="测试任务")

    # The id is generated automatically and must be non-empty.
    assert bare.id is not None
    assert len(bare.id) > 0

    assert bare.description == "测试任务"
    assert bare.assigned_expert == ""
    assert bare.status == SubTaskStatus.PENDING
    assert bare.result is None
|
||||
|
||||
def test_to_dict_from_dict_roundtrip(self):
|
||||
"""to_dict / from_dict 往返序列化"""
|
||||
phase = PlanPhase(
|
||||
id="roundtrip_phase",
|
||||
name="往返测试",
|
||||
subtask = SubTask(
|
||||
id="roundtrip_subtask",
|
||||
description="往返测试",
|
||||
assigned_expert="tester",
|
||||
task_description="测试序列化",
|
||||
depends_on=["dep_a", "dep_b"],
|
||||
parallel_type=ParallelType.SUBTASK_PARALLEL,
|
||||
merge_strategy=MergeStrategy.VOTE,
|
||||
milestone="序列化验证",
|
||||
status=PhaseStatus.COMPLETED,
|
||||
status=SubTaskStatus.COMPLETED,
|
||||
result={"key": "value"},
|
||||
)
|
||||
d = phase.to_dict()
|
||||
restored = PlanPhase.from_dict(d)
|
||||
d = subtask.to_dict()
|
||||
restored = SubTask.from_dict(d)
|
||||
|
||||
assert restored.id == phase.id
|
||||
assert restored.name == phase.name
|
||||
assert restored.assigned_expert == phase.assigned_expert
|
||||
assert restored.task_description == phase.task_description
|
||||
assert restored.depends_on == phase.depends_on
|
||||
assert restored.parallel_type == phase.parallel_type
|
||||
assert restored.merge_strategy == phase.merge_strategy
|
||||
assert restored.milestone == phase.milestone
|
||||
assert restored.status == phase.status
|
||||
assert restored.result == phase.result
|
||||
assert restored.id == subtask.id
|
||||
assert restored.description == subtask.description
|
||||
assert restored.assigned_expert == subtask.assigned_expert
|
||||
assert restored.status == subtask.status
|
||||
assert restored.result == subtask.result
|
||||
|
||||
def test_to_dict_from_dict_with_none_merge_strategy(self):
|
||||
"""merge_strategy 为 None 时的序列化往返"""
|
||||
phase = PlanPhase(
|
||||
id="no_merge",
|
||||
name="无合并",
|
||||
def test_to_dict_structure(self):
|
||||
"""to_dict 返回正确的字典结构"""
|
||||
subtask = _make_subtask(
|
||||
id="struct_test",
|
||||
description="结构测试",
|
||||
assigned_expert="dev",
|
||||
task_description="串行任务",
|
||||
parallel_type=ParallelType.SERIAL,
|
||||
status=SubTaskStatus.RUNNING,
|
||||
result={"output": "data"},
|
||||
)
|
||||
d = phase.to_dict()
|
||||
assert d["merge_strategy"] is None
|
||||
restored = PlanPhase.from_dict(d)
|
||||
assert restored.merge_strategy is None
|
||||
d = subtask.to_dict()
|
||||
assert d["id"] == "struct_test"
|
||||
assert d["description"] == "结构测试"
|
||||
assert d["assigned_expert"] == "dev"
|
||||
assert d["status"] == "running"
|
||||
assert d["result"] == {"output": "data"}
|
||||
|
||||
|
||||
# ── CollaborationPlan 测试 ────────────────────────────────
|
||||
# ── TeamPlan 测试 ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCollaborationPlan:
|
||||
"""CollaborationPlan 数据模型测试"""
|
||||
class TestTeamPlan:
|
||||
"""TeamPlan 数据模型测试"""
|
||||
|
||||
def test_creation(self):
|
||||
"""创建 CollaborationPlan"""
|
||||
"""创建 TeamPlan"""
|
||||
plan = _make_valid_plan()
|
||||
assert plan.id == "plan_001"
|
||||
assert plan.task == "实现用户登录功能"
|
||||
assert len(plan.phases) == 3
|
||||
assert plan.variables == {"project": "fischer"}
|
||||
assert len(plan.subtasks) == 3
|
||||
assert plan.status == PlanStatus.DRAFT
|
||||
assert plan.lead_expert == "architect"
|
||||
|
||||
def test_default_values(self):
    """A TeamPlan built from a task alone gets an id and an empty subtask list."""
    bare = TeamPlan(task="测试任务")

    assert bare.id is not None
    assert bare.task == "测试任务"
    assert bare.subtasks == []
    assert bare.status == PlanStatus.DRAFT
    assert bare.lead_expert == ""
|
||||
|
||||
def test_to_dict_from_dict_roundtrip(self):
|
||||
"""to_dict / from_dict 往返序列化"""
|
||||
plan = _make_valid_plan()
|
||||
d = plan.to_dict()
|
||||
restored = CollaborationPlan.from_dict(d)
|
||||
restored = TeamPlan.from_dict(d)
|
||||
|
||||
assert restored.id == plan.id
|
||||
assert restored.task == plan.task
|
||||
assert len(restored.phases) == len(plan.phases)
|
||||
assert restored.variables == plan.variables
|
||||
assert len(restored.subtasks) == len(plan.subtasks)
|
||||
assert restored.status == plan.status
|
||||
assert restored.lead_expert == plan.lead_expert
|
||||
|
||||
for original, restored_phase in zip(plan.phases, restored.phases):
|
||||
assert restored_phase.id == original.id
|
||||
assert restored_phase.name == original.name
|
||||
assert restored_phase.assigned_expert == original.assigned_expert
|
||||
assert restored_phase.depends_on == original.depends_on
|
||||
assert restored_phase.parallel_type == original.parallel_type
|
||||
assert restored_phase.merge_strategy == original.merge_strategy
|
||||
for original, restored_st in zip(plan.subtasks, restored.subtasks):
|
||||
assert restored_st.id == original.id
|
||||
assert restored_st.description == original.description
|
||||
assert restored_st.assigned_expert == original.assigned_expert
|
||||
assert restored_st.status == original.status
|
||||
|
||||
def test_validate_valid_plan(self):
|
||||
"""验证有效计划无错误"""
|
||||
def test_get_subtask_by_id(self):
|
||||
"""get_subtask 根据 ID 获取子任务"""
|
||||
plan = _make_valid_plan()
|
||||
errors = plan.validate()
|
||||
assert errors == []
|
||||
st = plan.get_subtask("s2")
|
||||
assert st is not None
|
||||
assert st.id == "s2"
|
||||
assert st.description == "设计架构"
|
||||
|
||||
def test_validate_detects_duplicate_phase_ids(self):
    """validate() reports an error when two phases share the same ID."""
    first = _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1")
    clash = _make_phase(id="p1", name="阶段2", assigned_expert="b", task_description="t2")
    plan = CollaborationPlan(
        id="dup_plan",
        task="重复ID测试",
        phases=[first, clash],
        lead_expert="a",
    )

    problems = plan.validate()

    assert any("重复的阶段 ID" in message for message in problems)
|
||||
|
||||
def test_validate_detects_missing_depends_on_references(self):
    """validate() reports an error when depends_on names an unknown phase."""
    root = _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1")
    # "nonexistent" is not the ID of any phase in this plan.
    dangling = _make_phase(
        id="p2",
        name="阶段2",
        assigned_expert="b",
        task_description="t2",
        depends_on=["p1", "nonexistent"],
    )
    plan = CollaborationPlan(
        id="missing_dep_plan",
        task="缺失依赖测试",
        phases=[root, dangling],
        lead_expert="a",
    )

    problems = plan.validate()

    assert any("不存在的阶段 ID" in message for message in problems)
|
||||
|
||||
def test_validate_detects_circular_dependencies(self):
    """validate() reports an error for a dependency cycle p1 -> p3 -> p2 -> p1."""
    # (id, name, expert, task description, the single phase it depends on)
    ring = (
        ("p1", "阶段1", "a", "t1", "p3"),
        ("p2", "阶段2", "b", "t2", "p1"),
        ("p3", "阶段3", "c", "t3", "p2"),
    )
    phases = [
        _make_phase(
            id=phase_id,
            name=name,
            assigned_expert=expert,
            task_description=description,
            depends_on=[upstream],
        )
        for phase_id, name, expert, description, upstream in ring
    ]
    plan = CollaborationPlan(
        id="cycle_plan",
        task="循环依赖测试",
        phases=phases,
        lead_expert="a",
    )

    problems = plan.validate()

    assert any("循环依赖" in message for message in problems)
|
||||
|
||||
def test_validate_detects_competitive_parallel_without_merge_strategy(self):
    """validate() reports a COMPETITIVE_PARALLEL phase with no merge_strategy."""
    competing = _make_phase(
        id="p1",
        name="竞争阶段",
        assigned_expert="a",
        task_description="竞争任务",
        parallel_type=ParallelType.COMPETITIVE_PARALLEL,
        merge_strategy=None,
    )
    plan = CollaborationPlan(
        id="no_merge_plan",
        task="缺少合并策略测试",
        phases=[competing],
        lead_expert="a",
    )

    problems = plan.validate()

    # A single message has to mention both the parallel type and the field.
    assert any(
        "COMPETITIVE_PARALLEL" in message and "merge_strategy" in message
        for message in problems
    )
|
||||
|
||||
def test_get_ready_phases_returns_phases_with_completed_dependencies(self):
|
||||
"""get_ready_phases 返回依赖已完成的阶段"""
|
||||
def test_get_subtask_with_nonexistent_id_returns_none(self):
|
||||
"""get_subtask 对不存在的 ID 返回 None"""
|
||||
plan = _make_valid_plan()
|
||||
# 初始状态:p1 无依赖,应该就绪
|
||||
ready = plan.get_ready_phases()
|
||||
assert len(ready) == 1
|
||||
assert ready[0].id == "p1"
|
||||
assert plan.get_subtask("nonexistent") is None
|
||||
|
||||
# 完成 p1 后,p2 应该就绪
|
||||
plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"analysis": "done"})
|
||||
ready = plan.get_ready_phases()
|
||||
assert len(ready) == 1
|
||||
assert ready[0].id == "p2"
|
||||
|
||||
# 完成 p2 后,p3 应该就绪
|
||||
plan.update_phase_status("p2", PhaseStatus.COMPLETED, {"design": "done"})
|
||||
ready = plan.get_ready_phases()
|
||||
assert len(ready) == 1
|
||||
assert ready[0].id == "p3"
|
||||
|
||||
def test_get_ready_phases_returns_empty_when_dependencies_not_met(self):
    """get_ready_phases returns nothing once the only root phase is in progress."""
    root = _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1")
    follower = _make_phase(
        id="p2",
        name="阶段2",
        assigned_expert="b",
        task_description="t2",
        depends_on=["p1"],
    )
    plan = CollaborationPlan(
        id="dep_plan",
        task="依赖未满足测试",
        phases=[root, follower],
        lead_expert="a",
    )

    # p2 waits on p1, which has not completed, so p2 is not ready;
    # p1 has no dependencies, so it is the only ready phase.
    initially_ready = plan.get_ready_phases()
    assert len(initially_ready) == 1
    assert initially_ready[0].id == "p1"

    # With p1 IN_PROGRESS (not COMPLETED), p2 is still not ready.
    plan.update_phase_status("p1", PhaseStatus.IN_PROGRESS)
    still_ready = plan.get_ready_phases()
    assert len(still_ready) == 0
|
||||
|
||||
def test_update_phase_status(self):
|
||||
"""update_phase_status 更新阶段状态和结果"""
|
||||
def test_update_subtask_status(self):
|
||||
"""update_subtask_status 更新子任务状态和结果"""
|
||||
plan = _make_valid_plan()
|
||||
plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"output": "分析完成"})
|
||||
phase = plan.get_phase("p1")
|
||||
assert phase is not None
|
||||
assert phase.status == PhaseStatus.COMPLETED
|
||||
assert phase.result == {"output": "分析完成"}
|
||||
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED, {"output": "分析完成"})
|
||||
st = plan.get_subtask("s1")
|
||||
assert st is not None
|
||||
assert st.status == SubTaskStatus.COMPLETED
|
||||
assert st.result == {"output": "分析完成"}
|
||||
|
||||
# 不传 result 时不应覆盖已有 result
|
||||
plan.update_phase_status("p2", PhaseStatus.IN_PROGRESS)
|
||||
phase2 = plan.get_phase("p2")
|
||||
assert phase2 is not None
|
||||
assert phase2.status == PhaseStatus.IN_PROGRESS
|
||||
assert phase2.result is None
|
||||
|
||||
def test_get_phase_by_id(self):
|
||||
"""get_phase 根据 ID 获取阶段"""
|
||||
def test_update_subtask_status_without_result(self):
|
||||
"""update_subtask_status 不传 result 时不覆盖已有 result"""
|
||||
plan = _make_valid_plan()
|
||||
phase = plan.get_phase("p2")
|
||||
assert phase is not None
|
||||
assert phase.id == "p2"
|
||||
assert phase.name == "架构设计"
|
||||
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED, {"output": "done"})
|
||||
plan.update_subtask_status("s1", SubTaskStatus.RUNNING)
|
||||
st = plan.get_subtask("s1")
|
||||
assert st is not None
|
||||
assert st.status == SubTaskStatus.RUNNING
|
||||
assert st.result == {"output": "done"}
|
||||
|
||||
def test_get_phase_with_nonexistent_id_returns_none(self):
|
||||
"""get_phase 对不存在的 ID 返回 None"""
|
||||
def test_completed_subtasks_property(self):
|
||||
"""completed_subtasks 返回已完成的子任务"""
|
||||
plan = _make_valid_plan()
|
||||
phase = plan.get_phase("nonexistent")
|
||||
assert phase is None
|
||||
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)
|
||||
plan.update_subtask_status("s2", SubTaskStatus.COMPLETED)
|
||||
plan.update_subtask_status("s3", SubTaskStatus.FAILED)
|
||||
|
||||
completed = plan.completed_subtasks
|
||||
assert len(completed) == 2
|
||||
assert {st.id for st in completed} == {"s1", "s2"}
|
||||
|
||||
def test_failed_subtasks_property(self):
    """failed_subtasks returns exactly the subtasks marked FAILED."""
    plan = _make_valid_plan()
    outcomes = (
        ("s1", SubTaskStatus.COMPLETED),
        ("s2", SubTaskStatus.FAILED),
        ("s3", SubTaskStatus.FAILED),
    )
    for subtask_id, outcome in outcomes:
        plan.update_subtask_status(subtask_id, outcome)

    failed = plan.failed_subtasks

    assert len(failed) == 2
    assert {subtask.id for subtask in failed} == {"s2", "s3"}
|
||||
|
||||
def test_all_done_property_when_all_completed(self):
    """all_done is True once every subtask has been marked COMPLETED."""
    plan = _make_valid_plan()

    for subtask in plan.subtasks:
        plan.update_subtask_status(subtask.id, SubTaskStatus.COMPLETED)

    assert plan.all_done is True
|
||||
|
||||
def test_all_done_property_when_some_failed(self):
    """all_done is True when every subtask is either COMPLETED or FAILED."""
    plan = _make_valid_plan()
    outcomes = (
        ("s1", SubTaskStatus.COMPLETED),
        ("s2", SubTaskStatus.FAILED),
        ("s3", SubTaskStatus.COMPLETED),
    )
    for subtask_id, outcome in outcomes:
        plan.update_subtask_status(subtask_id, outcome)

    assert plan.all_done is True
|
||||
|
||||
def test_all_done_property_when_pending(self):
    """all_done is False while some subtasks have not finished."""
    plan = _make_valid_plan()

    # Only s1 is updated; s2 and s3 keep the status they were created with.
    plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)

    assert plan.all_done is False
|
||||
|
||||
def test_all_done_property_empty_plan(self):
    """all_done is True for a plan with no subtasks (vacuous truth)."""
    empty_plan = TeamPlan(task="空计划")

    assert empty_plan.all_done is True
|
||||
|
|
|
|||
|
|
@ -2,8 +2,6 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||||
from agentkit.experts.registry import ExpertTemplateRegistry
|
||||
from agentkit.experts.router import (
|
||||
|
|
@ -113,7 +111,6 @@ class TestExpertTeamRoutingResult:
|
|||
assert result.specified_experts == []
|
||||
assert result.task_content == ""
|
||||
assert result.auto_compose is False
|
||||
assert result.complexity == 0.0
|
||||
assert result.match_method == ""
|
||||
|
||||
|
||||
|
|
@ -159,47 +156,14 @@ class TestExpertTeamRouterResolve:
|
|||
assert result.specified_experts == ["analyst"]
|
||||
assert result.task_content == "请分析这份报告"
|
||||
|
||||
def test_high_complexity_triggers_team_suggestion(self):
    """A complexity of 0.7 or more makes resolve() suggest team mode."""
    outcome = ExpertTeamRouter().resolve("分析这个复杂系统", complexity=0.8)

    assert outcome.matched is True
    assert outcome.team_mode is True
    assert outcome.auto_compose is True
    assert outcome.match_method == "complexity_suggestion"
    assert outcome.complexity == 0.8
|
||||
|
||||
def test_high_complexity_exact_threshold(self):
    """A complexity exactly at the 0.7 threshold also triggers team mode."""
    outcome = ExpertTeamRouter().resolve("任务内容", complexity=0.7)

    assert outcome.matched is True
    assert outcome.team_mode is True
    assert outcome.match_method == "complexity_suggestion"
|
||||
|
||||
def test_low_complexity_no_team_mode(self):
    """A complexity below 0.7 does not trigger team mode."""
    outcome = ExpertTeamRouter().resolve("简单问题", complexity=0.3)

    assert outcome.matched is False
    assert outcome.team_mode is False
    assert outcome.task_content == "简单问题"
    assert outcome.complexity == 0.3
|
||||
|
||||
def test_no_team_prefix_no_complexity(self):
|
||||
"""无 @team 前缀且无复杂度时不触发团队模式"""
|
||||
def test_no_team_prefix_no_team_mode(self):
|
||||
"""无 @team 前缀不触发团队模式"""
|
||||
router = ExpertTeamRouter()
|
||||
result = router.resolve("普通问题")
|
||||
assert result.matched is False
|
||||
assert result.team_mode is False
|
||||
|
||||
def test_team_prefix_takes_priority_over_complexity(self):
|
||||
"""@team 前缀优先于复杂度判断"""
|
||||
router = ExpertTeamRouter()
|
||||
result = router.resolve("@team:analyst 任务", complexity=0.1)
|
||||
assert result.matched is True
|
||||
assert result.match_method == "explicit_team"
|
||||
assert result.specified_experts == ["analyst"]
|
||||
assert result.task_content == "普通问题"
|
||||
assert result.match_method == ""
|
||||
|
||||
def test_nonexistent_expert_still_included(self):
|
||||
"""指定不存在的专家名仍包含在列表中"""
|
||||
|
|
@ -214,6 +178,35 @@ class TestExpertTeamRouterResolve:
|
|||
assert result.task_content == "@team"
|
||||
assert result.auto_compose is True
|
||||
|
||||
def test_resolve_strips_leading_whitespace(self):
    """resolve() still recognises the @team prefix behind leading whitespace."""
    outcome = ExpertTeamRouter().resolve(" @team:analyst 任务内容")

    assert outcome.matched is True
    assert outcome.team_mode is True
    assert outcome.specified_experts == ["analyst"]
    assert outcome.task_content == "任务内容"
|
||||
|
||||
def test_invalid_expert_names_rejected(self):
    """Expert names containing special characters are filtered out."""
    # "inva!id" and "bad/name" contain special characters and should be dropped.
    outcome = ExpertTeamRouter().resolve("@team:analyst,inva!id,bad/name 任务")
    kept = outcome.specified_experts

    # Only the legal name remains.
    assert "analyst" in kept
    assert "inva!id" not in kept
    assert "bad/name" not in kept
|
||||
|
||||
def test_max_experts_limit(self):
    """Naming more than MAX_EXPERTS experts truncates the list to MAX_EXPERTS."""
    from agentkit.experts.router import MAX_EXPERTS

    # Fifteen comma-separated expert names: expert0 .. expert14.
    requested = ",".join(f"expert{index}" for index in range(15))
    outcome = ExpertTeamRouter().resolve(f"@team:{requested} 任务")

    assert len(outcome.specified_experts) == MAX_EXPERTS
|
||||
|
||||
|
||||
# ── ExpertTeamRouter.resolve_expert_configs 测试 ───────────
|
||||
|
||||
|
|
@ -275,6 +268,25 @@ class TestExpertTeamRouterResolveExpertConfigs:
|
|||
configs = router.resolve_expert_configs(["analyst"])
|
||||
assert configs[0].bound_skills == ["data_query"]
|
||||
|
||||
def test_resolve_first_expert_is_lead(self):
    """Only the first resolved expert is flagged as lead."""
    router = ExpertTeamRouter(template_registry=_make_registry_with_templates())

    configs = router.resolve_expert_configs(["analyst", "strategist", "reviewer"])
    first, second, third = configs[0], configs[1], configs[2]

    assert first.is_lead is True
    assert second.is_lead is False
    assert third.is_lead is False
|
||||
|
||||
def test_resolve_skips_invalid_names(self):
    """resolve_expert_configs leaves out names with invalid characters."""
    candidates = ["valid_name", "inva!id", "also_valid"]

    # "inva!id" contains an invalid character and should be skipped.
    configs = ExpertTeamRouter().resolve_expert_configs(candidates)

    assert len(configs) == 2
    assert configs[0].name == "valid_name"
    assert configs[1].name == "also_valid"
|
||||
|
||||
|
||||
# ── ExpertTeamRouter 构造测试 ─────────────────────────────
|
||||
|
||||
|
|
@ -293,7 +305,40 @@ class TestExpertTeamRouterInit:
|
|||
router = ExpertTeamRouter(template_registry=registry)
|
||||
assert router._registry is registry
|
||||
|
||||
def test_complexity_threshold(self):
|
||||
"""复杂度阈值默认为 0.7"""
|
||||
|
||||
# ── ExpertTeamRouter.can_handle 测试 ──────────────────────
|
||||
|
||||
|
||||
class TestExpertTeamRouterCanHandle:
|
||||
"""ExpertTeamRouter.can_handle 方法测试"""
|
||||
|
||||
def test_can_handle_with_matching_template_name(self):
    """can_handle returns True when a template name appears in the content."""
    router = ExpertTeamRouter(template_registry=_make_registry_with_templates())

    handled = router.can_handle("请 analyst 帮我分析")

    assert handled is True
|
||||
|
||||
def test_can_handle_with_matching_description(self):
    """can_handle returns True when a template description word appears in the content."""
    router = ExpertTeamRouter(template_registry=_make_registry_with_templates())

    # Presumably the registry helper (not visible here) gives templates
    # descriptions such as "analyst 模板", so "模板" in the content should match.
    # NOTE(review): the original comment cites a "length > 2" rule; confirm the
    # exact threshold against can_handle.
    handled = router.can_handle("我需要一个模板来参考")

    assert handled is True
|
||||
|
||||
def test_can_handle_no_templates(self):
|
||||
"""无注册模板时返回 False"""
|
||||
router = ExpertTeamRouter()
|
||||
assert router.COMPLEXITY_THRESHOLD == 0.7
|
||||
# 默认注册中心可能为空
|
||||
# 强制清空以测试
|
||||
router._registry._templates.clear()
|
||||
assert router.can_handle("任何内容") is False
|
||||
|
||||
def test_can_handle_with_templates_no_match(self):
    """can_handle is still True when templates exist but none matches the content."""
    router = ExpertTeamRouter(template_registry=_make_registry_with_templates())

    # The content matches no template name or description, but templates are
    # registered, so auto-compose can still build a team.
    handled = router.can_handle("完全无关的内容 xyz123")

    assert handled is True
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""ExpertTeam 容器单元测试"""
|
||||
"""ExpertTeam 容器单元测试 (hub-and-spoke 模式)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -11,11 +11,6 @@ from agentkit.core.handoff_transport import InProcessHandoffTransport
|
|||
from agentkit.core.shared_workspace import SharedWorkspace
|
||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||||
from agentkit.experts.expert import Expert
|
||||
from agentkit.experts.plan import (
|
||||
CollaborationPlan,
|
||||
PlanPhase,
|
||||
PlanStatus,
|
||||
)
|
||||
from agentkit.experts.registry import ExpertTemplateRegistry
|
||||
from agentkit.experts.team import ExpertTeam, TeamStatus
|
||||
|
||||
|
|
@ -84,27 +79,6 @@ def _make_mock_expert(
|
|||
return expert
|
||||
|
||||
|
||||
def _make_valid_plan(
    plan_id: str = "plan_1",
    task: str = "测试任务",
    lead_expert: str = "lead",
) -> CollaborationPlan:
    """Build a CollaborationPlan with a single phase assigned to the lead expert."""
    only_phase = PlanPhase(
        id="phase_1",
        name="阶段1",
        assigned_expert=lead_expert,
        task_description="执行任务",
    )
    return CollaborationPlan(
        id=plan_id,
        task=task,
        phases=[only_phase],
        lead_expert=lead_expert,
    )
|
||||
|
||||
|
||||
# ── ExpertTeam 创建测试 ───────────────────────────────────
|
||||
|
||||
|
||||
|
|
@ -119,7 +93,6 @@ class TestExpertTeamCreation:
|
|||
assert len(team.team_id) > 0
|
||||
assert team.status == TeamStatus.FORMING
|
||||
assert team.lead_expert is None
|
||||
assert team.plan is None
|
||||
assert team.experts == []
|
||||
assert team.active_experts == []
|
||||
|
||||
|
|
@ -173,7 +146,7 @@ class TestExpertTeamCreateTeam:
|
|||
|
||||
assert team._lead_expert_name == "lead"
|
||||
assert team.lead_expert is mock_expert
|
||||
assert team.status == TeamStatus.PLANNING
|
||||
assert team.status == TeamStatus.EXECUTING
|
||||
assert mock_expert.team_id == team.team_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -195,7 +168,7 @@ class TestExpertTeamCreateTeam:
|
|||
|
||||
assert len(team.experts) == 2
|
||||
assert team._lead_expert_name == "lead"
|
||||
assert team.status == TeamStatus.PLANNING
|
||||
assert team.status == TeamStatus.EXECUTING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_without_pool_raises(self):
|
||||
|
|
@ -411,64 +384,6 @@ class TestExpertTeamRemoveExpert:
|
|||
assert last_left[0][1]["expert_name"] == "member1"
|
||||
|
||||
|
||||
# ── ExpertTeam.update_plan 测试 ────────────────────────────
|
||||
|
||||
|
||||
class TestExpertTeamUpdatePlan:
    """Tests for updating the collaboration plan via ExpertTeam.update_plan."""

    def test_update_plan_with_valid_plan(self):
        """A valid plan is stored and the affected expert names are returned."""
        team = ExpertTeam()
        new_plan = _make_valid_plan(lead_expert="lead")

        affected = team.update_plan(new_plan)

        assert team.plan is new_plan
        assert "lead" in affected

    def test_update_plan_confirmed_sets_executing(self):
        """A plan in CONFIRMED status moves the team to EXECUTING."""
        team = ExpertTeam()
        confirmed_plan = _make_valid_plan(lead_expert="lead")
        confirmed_plan.status = PlanStatus.CONFIRMED

        team.update_plan(confirmed_plan)

        assert team.status == TeamStatus.EXECUTING

    def test_update_plan_with_invalid_plan_no_update(self):
        """An invalid plan is not stored; the validation errors are returned."""
        team = ExpertTeam()
        # An invalid plan: p1 and p2 depend on each other (circular dependency).
        forward = PlanPhase(
            id="p1",
            name="阶段1",
            assigned_expert="a",
            task_description="t1",
            depends_on=["p2"],
        )
        backward = PlanPhase(
            id="p2",
            name="阶段2",
            assigned_expert="b",
            task_description="t2",
            depends_on=["p1"],
        )
        bad_plan = CollaborationPlan(
            id="bad_plan",
            task="无效任务",
            phases=[forward, backward],
            lead_expert="lead",
        )

        outcome = team.update_plan(bad_plan)

        assert len(outcome) > 0  # validation errors, not an empty list
        assert team.plan is None  # the team's plan was left untouched
|
||||
|
||||
|
||||
# ── ExpertTeam.broadcast_user_message 测试 ─────────────────
|
||||
|
||||
|
||||
|
|
@ -535,42 +450,6 @@ class TestExpertTeamGetSharedContext:
|
|||
assert context == {}
|
||||
|
||||
|
||||
# ── ExpertTeam.generate_plan 测试 ──────────────────────────
|
||||
|
||||
|
||||
class TestExpertTeamGeneratePlan:
    """Tests for plan generation via ExpertTeam.generate_plan."""

    @pytest.mark.asyncio
    async def test_generate_plan(self):
        """generate_plan produces a CollaborationPlan with no phases."""
        team = ExpertTeam(pool=_make_mock_pool())
        lead_config = _make_expert_config(name="lead", is_lead=True)

        # Expert.create is patched so create_team gets a mock lead expert.
        with patch.object(Expert, "create", new_callable=AsyncMock) as mock_create:
            mock_create.return_value = _make_mock_expert(name="lead", is_lead=True)
            await team.create_team(lead_config)

        generated = await team.generate_plan("分析数据")

        assert generated is not None
        assert generated.task == "分析数据"
        assert generated.lead_expert == "lead"
        assert generated.phases == []
        assert team.plan is generated

    @pytest.mark.asyncio
    async def test_generate_plan_without_lead(self):
        """Without a lead expert the generated plan's lead_expert is an empty string."""
        team = ExpertTeam()

        generated = await team.generate_plan("测试任务")

        assert generated.lead_expert == ""
|
||||
|
||||
|
||||
# ── ExpertTeam.dissolve 测试 ───────────────────────────────
|
||||
|
||||
|
||||
|
|
@ -667,11 +546,6 @@ class TestExpertTeamDissolvedOperations:
|
|||
# 解散后状态为 DISSOLVED
|
||||
assert team.status == TeamStatus.DISSOLVED
|
||||
|
||||
# 再次 create_team 时,由于 experts 已清空,
|
||||
# 但 pool 仍然存在,理论上可以重新创建
|
||||
# 但这里验证状态是 DISSOLVED
|
||||
assert team.status == TeamStatus.DISSOLVED
|
||||
|
||||
|
||||
# ── ExpertTeam.lead_expert 属性测试 ────────────────────────
|
||||
|
||||
|
|
@ -742,10 +616,11 @@ class TestExpertTeamBuildContext:
|
|||
|
||||
context = team._build_team_context(lead_config, [member_config])
|
||||
|
||||
assert "You are part of an Expert Team." in context
|
||||
assert "Lead Expert: lead (领导者)" in context
|
||||
assert "Team Member: analyst (分析师), Skills: data_query" in context
|
||||
assert "send_message() and request_assist()" in context
|
||||
assert "hub-and-spoke mode" in context
|
||||
assert "Lead Expert (hub): lead (领导者)" in context
|
||||
assert "Team Member (spoke): analyst (分析师)" in context
|
||||
assert "data_query" in context
|
||||
assert "depth=1" in context
|
||||
|
||||
def test_build_team_context_no_lead(self):
|
||||
"""没有 Lead Expert 时构建上下文"""
|
||||
|
|
@ -754,8 +629,9 @@ class TestExpertTeamBuildContext:
|
|||
|
||||
context = team._build_team_context(None, [member_config])
|
||||
|
||||
assert "Lead Expert" not in context
|
||||
assert "Team Member: analyst" in context
|
||||
# 不应出现具体的 Lead Expert (hub): name 行
|
||||
assert "Lead Expert (hub):" not in context
|
||||
assert "Team Member (spoke): analyst" in context
|
||||
|
||||
def test_build_team_context_skips_lead_in_members(self):
|
||||
"""成员列表中包含 Lead 时跳过"""
|
||||
|
|
@ -765,4 +641,4 @@ class TestExpertTeamBuildContext:
|
|||
context = team._build_team_context(lead_config, [lead_config])
|
||||
|
||||
# Lead 不应出现在 Team Member 行
|
||||
assert "Team Member: lead" not in context
|
||||
assert "Team Member (spoke): lead" not in context
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,472 @@
|
|||
"""ToolSearchIndex / ToolSearchTool / ReAct 分层注入单元测试
|
||||
|
||||
测试场景:
|
||||
- ToolSearchIndex: 空索引、单工具、多工具相关性排序、top_k 限制、无匹配
|
||||
- ToolSearchTool: 正常查询、空查询、无匹配、结果包含完整描述
|
||||
- ReActEngine 分层注入: core/extended 分离、tool_search 自动添加、禁用配置、自定义 core 列表
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
from agentkit.tools.builtin import ToolSearchTool
|
||||
from agentkit.tools.search import ToolSearchIndex
|
||||
|
||||
|
||||
# ── Test Helpers ──────────────────────────────────────────
|
||||
|
||||
|
||||
class FakeTool(Tool):
    """Minimal concrete ``Tool`` used as a test double.

    It forwards its metadata to the ``Tool`` constructor unchanged and
    always reports success from ``execute``.
    """

    def __init__(
        self,
        name: str,
        description: str,
        input_schema: dict[str, Any] | None = None,
        tags: list[str] | None = None,
    ):
        """Create a fake tool.

        Args:
            name: Tool name passed through to ``Tool``.
            description: Tool description passed through to ``Tool``.
            input_schema: Optional JSON-schema-like dict; passed as-is
                (including ``None``).
            tags: Optional tag list; a falsy value becomes an empty list.
        """
        effective_tags = tags if tags else []
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema,
            tags=effective_tags,
        )

    async def execute(self, **kwargs) -> dict:
        """Ignore all arguments and return a fixed success payload."""
        return {"status": "ok"}
|
||||
|
||||
|
||||
def _make_tools() -> list[Tool]:
    """Build the shared four-tool fixture.

    Returns, in order: ``read_file``, ``write_file``, ``web_search`` and
    ``run_tests``. Each is a ``FakeTool`` with its own description,
    input schema and tags.
    """
    read_schema = {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "file path to read"},
        },
        "required": ["path"],
    }
    write_schema = {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "file path to write"},
            "content": {"type": "string", "description": "content to write"},
        },
        "required": ["path", "content"],
    }
    search_schema = {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "search query"},
        },
        "required": ["query"],
    }
    # Unlike the other schemas, this one declares no "required" key.
    tests_schema = {
        "type": "object",
        "properties": {
            "commands": {"type": "array", "description": "test commands"},
        },
    }

    reader = FakeTool(
        name="read_file",
        description="Read the contents of a file from the filesystem.",
        input_schema=read_schema,
        tags=["io", "file"],
    )
    writer = FakeTool(
        name="write_file",
        description="Write content to a file on the filesystem.",
        input_schema=write_schema,
        tags=["io", "file"],
    )
    searcher = FakeTool(
        name="web_search",
        description="Search the web for information using a search engine.",
        input_schema=search_schema,
        tags=["web", "search"],
    )
    test_runner = FakeTool(
        name="run_tests",
        description="Run project tests to verify code changes.",
        input_schema=tests_schema,
        tags=["testing", "verification"],
    )
    return [reader, writer, searcher, test_runner]
|
||||
|
||||
|
||||
# ── ToolSearchIndex Tests ─────────────────────────────────
|
||||
|
||||
|
||||
class TestToolSearchIndex:
    """Search tests for ``ToolSearchIndex`` (BM25 ranking)."""

    def test_empty_tools(self):
        """Building an index from an empty tool list does not raise."""
        empty_index = ToolSearchIndex([])
        assert len(empty_index) == 0
        assert empty_index.search("anything") == []

    def test_single_tool_match(self):
        """A one-tool index returns that tool for a matching query."""
        only_tool = _make_tools()[:1]
        hits = ToolSearchIndex(only_tool).search("read file")
        assert len(hits) == 1
        assert hits[0].name == "read_file"

    def test_relevance_ranking(self):
        """With several tools indexed, the most relevant one ranks first."""
        hits = ToolSearchIndex(_make_tools()).search("web search")
        assert len(hits) > 0
        # web_search is expected at the top of the ranking.
        assert hits[0].name == "web_search"

    def test_top_k_limit(self):
        """``top_k`` caps the number of returned results."""
        hits = ToolSearchIndex(_make_tools()).search("file", top_k=2)
        assert len(hits) <= 2

    def test_no_match_returns_empty(self):
        """A query that matches nothing yields an empty list."""
        hits = ToolSearchIndex(_make_tools()).search("xyzzy_nonexistent")
        assert hits == []

    def test_empty_query_returns_empty(self):
        """Empty and whitespace-only queries yield an empty list."""
        search_index = ToolSearchIndex(_make_tools())
        for blank_query in ("", " "):
            assert search_index.search(blank_query) == []

    def test_top_k_zero_returns_empty(self):
        """``top_k=0`` yields an empty list."""
        search_index = ToolSearchIndex(_make_tools())
        assert search_index.search("file", top_k=0) == []

    def test_snake_case_tokenization(self):
        """Each word of a snake_case tool name is searchable on its own."""
        snake_tool = FakeTool(
            name="read_file",
            description="Read file contents.",
        )
        search_index = ToolSearchIndex([snake_tool])
        # Both "read" and "file" should hit the single indexed tool.
        for word in ("read", "file"):
            hits = search_index.search(word)
            assert len(hits) == 1

    def test_search_includes_parameter_descriptions(self):
        """Keywords found only in parameter names/descriptions still match."""
        schema = {
            "type": "object",
            "properties": {
                "database_url": {
                    "type": "string",
                    "description": "PostgreSQL connection string",
                },
            },
        }
        db_tool = FakeTool(
            name="custom_tool",
            description="A custom tool.",
            input_schema=schema,
        )
        hits = ToolSearchIndex([db_tool]).search("postgresql database")
        assert len(hits) == 1
        assert hits[0].name == "custom_tool"

    def test_search_includes_tags(self):
        """Keywords found only in tags still match."""
        tagged_tool = FakeTool(
            name="data_tool",
            description="Process data.",
            tags=["etl", "pipeline"],
        )
        hits = ToolSearchIndex([tagged_tool]).search("pipeline etl")
        assert len(hits) == 1
        assert hits[0].name == "data_tool"

    def test_invalid_k1_raises(self):
        """A negative ``k1`` raises ``ValueError``."""
        with pytest.raises(ValueError, match="k1"):
            ToolSearchIndex(_make_tools(), k1=-1.0)

    def test_invalid_b_raises(self):
        """A ``b`` outside ``[0, 1]`` raises ``ValueError``."""
        for bad_b in (1.5, -0.1):
            with pytest.raises(ValueError, match="b"):
                ToolSearchIndex(_make_tools(), b=bad_b)

    def test_multiple_results_sorted_by_score(self):
        """Only tools that match the query are returned."""
        candidates = [
            FakeTool(name="search_web", description="Search the web."),
            FakeTool(name="search_files", description="Search files on disk."),
            FakeTool(name="unrelated", description="Do something unrelated."),
        ]
        hits = ToolSearchIndex(candidates).search("search")
        # The two "search" tools come back; the unrelated one does not.
        assert len(hits) == 2
        assert all(hit.name != "unrelated" for hit in hits)
|
||||
|
||||
|
||||
# ── ToolSearchTool Tests ──────────────────────────────────
|
||||
|
||||
|
||||
class TestToolSearchTool:
    """Tests for the ``ToolSearchTool`` built-in tool.

    Coroutine tests carry an explicit ``@pytest.mark.asyncio`` marker so
    they are awaited regardless of the configured pytest-asyncio mode
    (the marker is required in strict mode and harmless in auto mode).
    """

    def test_tool_name_and_schema(self):
        """The tool exposes the expected name and a required ``query`` parameter."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index)
        assert tool.name == "tool_search"
        assert "query" in tool.input_schema["properties"]
        assert "query" in tool.input_schema["required"]

    @pytest.mark.asyncio
    async def test_execute_returns_results(self):
        """Executing a search returns full descriptions of the matching tools."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index)
        result = await tool.execute(query="web search")

        assert result["count"] > 0
        assert result["query"] == "web search"
        first = result["results"][0]
        assert "name" in first
        assert "description" in first
        assert "parameters" in first
        assert first["name"] == "web_search"

    @pytest.mark.asyncio
    async def test_execute_empty_query_returns_error(self):
        """An empty query returns an error payload with no results."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index)
        result = await tool.execute(query="")
        assert "error" in result
        assert result["results"] == []

    @pytest.mark.asyncio
    async def test_execute_no_match(self):
        """A query with no match returns empty results plus a hint message."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index)
        result = await tool.execute(query="zzz_nonexistent")
        assert result["count"] == 0
        assert result["results"] == []
        assert "message" in result

    @pytest.mark.asyncio
    async def test_execute_respects_top_k(self):
        """``top_k`` caps the number of returned results."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index, top_k=1)
        result = await tool.execute(query="file")
        assert result["count"] <= 1

    def test_invalid_top_k_raises(self):
        """``top_k < 1`` raises ``ValueError`` at construction time."""
        index = ToolSearchIndex(_make_tools())
        with pytest.raises(ValueError, match="top_k"):
            ToolSearchTool(search_index=index, top_k=0)
|
||||
|
||||
|
||||
# ── ReActEngine Tiered Injection Tests ────────────────────
|
||||
|
||||
|
||||
class TestReActTieredInjection:
    """Tests for ReActEngine's tiered (core / extended) tool-description injection."""

    def _make_engine(self, **kwargs: Any):
        """Build a ``ReActEngine`` with a mocked LLM gateway.

        Extra keyword arguments are forwarded to the ``ReActEngine``
        constructor (e.g. ``enable_tool_search``, ``core_tool_names``).
        """
        # Imported lazily, as in the original test module.
        from agentkit.core.react import ReActEngine

        gateway = MagicMock()
        return ReActEngine(llm_gateway=gateway, **kwargs)

    def test_core_tools_get_full_description(self):
        """Core tools are injected with their full description, parameters included."""
        engine = self._make_engine()
        tools = [
            FakeTool(
                name="read_file",
                description="Read a file.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "file path"},
                    },
                    "required": ["path"],
                },
            ),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        # The core-tools section is present.
        assert "核心工具" in prompt
        # Parameter name and description are injected.
        assert "path" in prompt
        assert "file path" in prompt

    def test_extended_tools_get_one_line_only(self):
        """Extended tools are injected as name + one-line description, without parameters."""
        engine = self._make_engine()
        tools = [
            FakeTool(
                name="custom_extended",
                description="A custom extended tool for testing.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "secret_param": {
                            "type": "string",
                            "description": "SECRET_PARAM_DESC",
                        },
                    },
                },
            ),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        assert "扩展工具" in prompt
        assert "custom_extended" in prompt
        # Parameter descriptions must not leak into the extended section.
        assert "SECRET_PARAM_DESC" not in prompt

    def test_mixed_core_and_extended(self):
        """With core and extended tools mixed, both sections are rendered."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(
                name="web_search",
                description="Search the web.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "q": {"type": "string", "description": "query string"},
                    },
                },
            ),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        assert "核心工具" in prompt
        assert "扩展工具" in prompt
        # read_file belongs to the core section, web_search to the extended one.
        assert "read_file" in prompt
        assert "web_search" in prompt
        # web_search is extended, so its parameter description must be withheld;
        # without this check the test could not tell the two tiers apart.
        assert "query string" not in prompt

    def test_maybe_add_tool_search_adds_for_extended(self):
        """``tool_search`` is appended automatically when extended tools exist."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),  # core
            FakeTool(name="web_search", description="Search the web."),  # extended
        ]
        result = engine._maybe_add_tool_search(tools)
        assert len(result) == 3
        assert any(t.name == "tool_search" for t in result)

    def test_maybe_add_tool_search_skips_when_only_core(self):
        """``tool_search`` is not added when every tool is a core tool."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="write_file", description="Write a file."),
        ]
        result = engine._maybe_add_tool_search(tools)
        assert len(result) == 2
        assert not any(t.name == "tool_search" for t in result)

    def test_maybe_add_tool_search_skips_when_disabled(self):
        """``tool_search`` is not added when ``enable_tool_search=False``."""
        engine = self._make_engine(enable_tool_search=False)
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
        ]
        result = engine._maybe_add_tool_search(tools)
        assert len(result) == 2
        assert not any(t.name == "tool_search" for t in result)

    def test_maybe_add_tool_search_skips_when_already_present(self):
        """An existing ``tool_search`` is not duplicated."""
        engine = self._make_engine()
        index = ToolSearchIndex([])
        existing_search = ToolSearchTool(search_index=index)
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
            existing_search,
        ]
        result = engine._maybe_add_tool_search(tools)
        assert len(result) == 3
        # The length alone does not prove the absence of a duplicate; pin the
        # number of tool_search entries explicitly.
        assert sum(1 for t in result if t.name == "tool_search") == 1

    def test_custom_core_tool_names(self):
        """A custom ``core_tool_names`` list overrides the default core set."""
        engine = self._make_engine(core_tool_names=["my_core_tool"])
        tools = [
            FakeTool(name="my_core_tool", description="My core tool."),
            FakeTool(
                name="read_file",
                description="Read a file.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "p": {"type": "string", "description": "PARAM_DESC"},
                    },
                },
            ),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        # my_core_tool lands in the core section.
        assert "核心工具" in prompt
        # read_file is now an extended tool, so its parameter description is hidden.
        assert "PARAM_DESC" not in prompt

    def test_tool_search_is_core_tool(self):
        """``tool_search`` is treated as a core tool (full description injected)."""
        engine = self._make_engine()
        index = ToolSearchIndex([])
        search_tool = ToolSearchTool(search_index=index)
        tools = [search_tool]
        prompt = engine._build_tool_use_prompt(tools)
        # tool_search should appear in the core-tools section.
        assert "核心工具" in prompt
        assert "tool_search" in prompt
        # Its ``query`` parameter should be injected.
        assert "query" in prompt

    def test_search_hint_added_when_tool_search_present(self):
        """A search hint is added when ``tool_search`` and extended tools coexist."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
        ]
        # Let the engine append tool_search the same way it does at runtime.
        tools_with_search = engine._maybe_add_tool_search(tools)
        prompt = engine._build_tool_use_prompt(tools_with_search)
        assert "tool_search" in prompt
        assert "扩展工具" in prompt
        # The search hint is present.
        assert "tool_search(query" in prompt

    def test_no_search_hint_when_no_extended_tools(self):
        """No extended section and no search hint when there are no extended tools."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        assert "扩展工具" not in prompt
        # The test name promises the hint is absent; assert it directly using
        # the same marker the positive test checks for.
        assert "tool_search(query" not in prompt
|
||||
Loading…
Reference in New Issue