feat: hub-and-spoke experts, tiered tool injection, unified event model (U3/U7/U10)
This commit is contained in:
parent
200174c5c7
commit
bbedfff597
|
|
@ -3,6 +3,7 @@
|
||||||
from agentkit.core.base import BaseAgent
|
from agentkit.core.base import BaseAgent
|
||||||
from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor
|
from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor
|
||||||
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||||
|
from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue
|
||||||
from agentkit.core.exceptions import (
|
from agentkit.core.exceptions import (
|
||||||
AgentAlreadyRegisteredError,
|
AgentAlreadyRegisteredError,
|
||||||
AgentFrameworkError,
|
AgentFrameworkError,
|
||||||
|
|
@ -29,12 +30,16 @@ from agentkit.core.protocol import (
|
||||||
AgentCapability,
|
AgentCapability,
|
||||||
AgentStatus,
|
AgentStatus,
|
||||||
CancellationToken,
|
CancellationToken,
|
||||||
|
Event,
|
||||||
EvolutionEvent,
|
EvolutionEvent,
|
||||||
HandoffMessage,
|
HandoffMessage,
|
||||||
|
SessionEventType,
|
||||||
|
TaskEventType,
|
||||||
TaskMessage,
|
TaskMessage,
|
||||||
TaskProgress,
|
TaskProgress,
|
||||||
TaskResult,
|
TaskResult,
|
||||||
TaskStatus,
|
TaskStatus,
|
||||||
|
TurnEventType,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional: HeadroomCompressor — only available when headroom-ai is installed
|
# Optional: HeadroomCompressor — only available when headroom-ai is installed
|
||||||
|
|
@ -80,4 +85,12 @@ __all__ = [
|
||||||
"TaskProgress",
|
"TaskProgress",
|
||||||
"TaskResult",
|
"TaskResult",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
|
# SQ/EQ 统一事件模型
|
||||||
|
"Event",
|
||||||
|
"SessionEventType",
|
||||||
|
"TaskEventType",
|
||||||
|
"TurnEventType",
|
||||||
|
"Submission",
|
||||||
|
"SubmissionQueue",
|
||||||
|
"EventQueue",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,244 @@
|
||||||
|
"""统一事件模型 - SQ/EQ 双队列实现
|
||||||
|
|
||||||
|
参考 Codex 的 Session/Task/Turn 三级模型,提供:
|
||||||
|
- SubmissionQueue (SQ): 接收用户输入,返回 task_id
|
||||||
|
- EventQueue (EQ): 推送 Agent 事件,支持多订阅者广播和事件缓冲回放
|
||||||
|
|
||||||
|
集成点(Phase 4 再做实际集成):
|
||||||
|
- Portal WebSocket 可通过 EventQueue.subscribe() 订阅事件流并推送给前端
|
||||||
|
- CLI 可通过 EventQueue.subscribe() 订阅事件流并打印
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
from agentkit.core.protocol import Event
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# 哨兵对象:用于通知订阅者队列已关闭(基于身份比较,不参与事件流)
|
||||||
|
_CLOSED_SENTINEL: Event = Event(
|
||||||
|
event_type="__closed__",
|
||||||
|
task_id="",
|
||||||
|
session_id="",
|
||||||
|
data={},
|
||||||
|
timestamp="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Submission:
|
||||||
|
"""用户提交的任务
|
||||||
|
|
||||||
|
由 SubmissionQueue.submit() 创建,消费者通过 SubmissionQueue.drain() 获取。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
task_id: 唯一任务 ID
|
||||||
|
session_id: 会话 ID
|
||||||
|
content: 用户输入内容
|
||||||
|
created_at: 创建时间
|
||||||
|
cancelled: 是否已取消
|
||||||
|
"""
|
||||||
|
|
||||||
|
task_id: str
|
||||||
|
session_id: str
|
||||||
|
content: str
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
cancelled: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class SubmissionQueue:
|
||||||
|
"""提交队列 (SQ) - 接收用户输入
|
||||||
|
|
||||||
|
用户通过 submit() 提交输入,消费者通过 drain() 获取提交。
|
||||||
|
支持通过 task_id 取消提交。
|
||||||
|
|
||||||
|
内部使用 asyncio.Queue 实现,队列大小上限 1024(与 HandoffTransport 一致)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_MAX_QUEUE_SIZE: int = 1024
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._queue: asyncio.Queue[Submission] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
|
||||||
|
self._submissions: dict[str, Submission] = {}
|
||||||
|
self._cancelled_tasks: set[str] = set()
|
||||||
|
self._closed: bool = False
|
||||||
|
|
||||||
|
async def submit(self, content: str, session_id: str) -> str:
|
||||||
|
"""提交用户输入,返回 task_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 用户输入内容
|
||||||
|
session_id: 会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
新生成的 task_id(UUID4)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: 队列已关闭
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
raise RuntimeError("SubmissionQueue is closed")
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
submission = Submission(
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=session_id,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
self._submissions[task_id] = submission
|
||||||
|
await self._queue.put(submission)
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
async def drain(self) -> AsyncIterator[Submission]:
|
||||||
|
"""消费提交队列(异步生成器)
|
||||||
|
|
||||||
|
已取消的提交会被跳过。当队列关闭时,生成器结束。
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
submission = await self._queue.get()
|
||||||
|
if submission.task_id in self._cancelled_tasks:
|
||||||
|
continue
|
||||||
|
yield submission
|
||||||
|
|
||||||
|
async def cancel(self, task_id: str) -> bool:
|
||||||
|
"""取消任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 要取消的任务 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功取消(任务存在且未取消过)
|
||||||
|
"""
|
||||||
|
if task_id not in self._submissions:
|
||||||
|
return False
|
||||||
|
if task_id in self._cancelled_tasks:
|
||||||
|
return False
|
||||||
|
self._cancelled_tasks.add(task_id)
|
||||||
|
self._submissions[task_id].cancelled = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
"""返回队列是否已关闭"""
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""关闭队列,不再接受新提交。
|
||||||
|
|
||||||
|
已在队列中的提交不受影响,消费者仍可通过 drain() 获取。
|
||||||
|
"""
|
||||||
|
self._closed = True
|
||||||
|
|
||||||
|
|
||||||
|
class EventQueue:
|
||||||
|
"""事件队列 (EQ) - 推送 Agent 事件
|
||||||
|
|
||||||
|
支持多订阅者广播模式:每条事件会投递到所有活跃订阅者。
|
||||||
|
新订阅者会收到最近 N 条缓冲事件的回放(默认 100 条)。
|
||||||
|
|
||||||
|
集成点:
|
||||||
|
- Portal WebSocket 可通过 subscribe() 订阅事件流并推送给前端
|
||||||
|
- CLI 可通过 subscribe() 订阅事件流并打印
|
||||||
|
"""
|
||||||
|
|
||||||
|
_MAX_QUEUE_SIZE: int = 1024
|
||||||
|
_DEFAULT_BUFFER_SIZE: int = 100
|
||||||
|
|
||||||
|
def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None:
|
||||||
|
self._subscribers: list[asyncio.Queue[Event]] = []
|
||||||
|
self._buffer: deque[Event] = deque(maxlen=buffer_size)
|
||||||
|
self._buffer_size = buffer_size
|
||||||
|
self._closed: bool = False
|
||||||
|
|
||||||
|
async def emit(self, event: Event) -> None:
|
||||||
|
"""推送事件给所有订阅者
|
||||||
|
|
||||||
|
事件会同时写入缓冲区(供未来订阅者回放)和所有活跃订阅者队列。
|
||||||
|
如果某订阅者队列已满,该事件对该订阅者被丢弃(不影响其他订阅者)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: 要推送的事件
|
||||||
|
"""
|
||||||
|
self._buffer.append(event)
|
||||||
|
for queue in self._subscribers:
|
||||||
|
try:
|
||||||
|
queue.put_nowait(event)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
logger.warning("EventQueue subscriber queue full, dropping event")
|
||||||
|
|
||||||
|
async def subscribe(self) -> AsyncIterator[Event]:
|
||||||
|
"""订阅事件流(异步生成器)
|
||||||
|
|
||||||
|
订阅时会先回放缓冲区中的事件,然后持续接收新事件。
|
||||||
|
每个订阅者获得独立的队列,实现广播语义。
|
||||||
|
|
||||||
|
当队列关闭时,生成器结束。
|
||||||
|
|
||||||
|
注意:回放和加入订阅者列表在同一同步段内完成(无 await),
|
||||||
|
保证不会遗漏或重复事件。
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
|
||||||
|
queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE)
|
||||||
|
|
||||||
|
# 回放缓冲事件(同步操作,无 await,保证原子性)
|
||||||
|
for event in list(self._buffer):
|
||||||
|
try:
|
||||||
|
queue.put_nowait(event)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
logger.warning("EventQueue replay buffer full, skipping remaining")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 加入订阅者列表(在回放之后,确保不会收到重复事件)
|
||||||
|
self._subscribers.append(queue)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
event = await queue.get()
|
||||||
|
if event is _CLOSED_SENTINEL:
|
||||||
|
break
|
||||||
|
yield event
|
||||||
|
finally:
|
||||||
|
# 清理:移除当前订阅者的队列
|
||||||
|
if queue in self._subscribers:
|
||||||
|
self._subscribers.remove(queue)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def subscriber_count(self) -> int:
|
||||||
|
"""返回当前订阅者数量"""
|
||||||
|
return len(self._subscribers)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def buffer_size(self) -> int:
|
||||||
|
"""返回缓冲区大小上限"""
|
||||||
|
return self._buffer_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_closed(self) -> bool:
|
||||||
|
"""返回队列是否已关闭"""
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""关闭队列,向所有订阅者发送哨兵并清理状态。
|
||||||
|
|
||||||
|
使用哨兵模式(参考 HandoffTransport),确保阻塞在 subscribe() 上的
|
||||||
|
订阅者能够优雅退出。
|
||||||
|
"""
|
||||||
|
self._closed = True
|
||||||
|
# 向所有活跃订阅者队列放入哨兵,使其能够优雅退出
|
||||||
|
for queue in self._subscribers:
|
||||||
|
try:
|
||||||
|
queue.put_nowait(_CLOSED_SENTINEL)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
pass
|
||||||
|
self._subscribers.clear()
|
||||||
|
self._buffer.clear()
|
||||||
|
|
@ -10,6 +10,7 @@ from agentkit.core.exceptions import TaskCancelledError
|
||||||
|
|
||||||
class TaskStatus(str, Enum):
|
class TaskStatus(str, Enum):
|
||||||
"""任务状态枚举"""
|
"""任务状态枚举"""
|
||||||
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
|
|
@ -21,6 +22,7 @@ class TaskStatus(str, Enum):
|
||||||
|
|
||||||
class AgentStatus(str, Enum):
|
class AgentStatus(str, Enum):
|
||||||
"""Agent 状态枚举"""
|
"""Agent 状态枚举"""
|
||||||
|
|
||||||
ONLINE = "online"
|
ONLINE = "online"
|
||||||
OFFLINE = "offline"
|
OFFLINE = "offline"
|
||||||
BUSY = "busy"
|
BUSY = "busy"
|
||||||
|
|
@ -29,6 +31,7 @@ class AgentStatus(str, Enum):
|
||||||
@dataclass
|
@dataclass
|
||||||
class AgentCapability:
|
class AgentCapability:
|
||||||
"""Agent 能力声明"""
|
"""Agent 能力声明"""
|
||||||
|
|
||||||
agent_name: str
|
agent_name: str
|
||||||
agent_type: str
|
agent_type: str
|
||||||
version: str
|
version: str
|
||||||
|
|
@ -70,6 +73,7 @@ class AgentCapability:
|
||||||
@dataclass
|
@dataclass
|
||||||
class TaskMessage:
|
class TaskMessage:
|
||||||
"""任务消息 - 从调度器发往 Agent"""
|
"""任务消息 - 从调度器发往 Agent"""
|
||||||
|
|
||||||
task_id: str
|
task_id: str
|
||||||
agent_name: str
|
agent_name: str
|
||||||
task_type: str
|
task_type: str
|
||||||
|
|
@ -114,6 +118,7 @@ class TaskMessage:
|
||||||
@dataclass
|
@dataclass
|
||||||
class TaskResult:
|
class TaskResult:
|
||||||
"""任务结果 - 从 Agent 返回"""
|
"""任务结果 - 从 Agent 返回"""
|
||||||
|
|
||||||
task_id: str
|
task_id: str
|
||||||
agent_name: str
|
agent_name: str
|
||||||
status: str
|
status: str
|
||||||
|
|
@ -163,6 +168,7 @@ class TaskResult:
|
||||||
@dataclass
|
@dataclass
|
||||||
class TaskProgress:
|
class TaskProgress:
|
||||||
"""进度上报 - Agent 执行过程中上报"""
|
"""进度上报 - Agent 执行过程中上报"""
|
||||||
|
|
||||||
task_id: str
|
task_id: str
|
||||||
agent_name: str
|
agent_name: str
|
||||||
progress: float
|
progress: float
|
||||||
|
|
@ -195,6 +201,7 @@ class TaskProgress:
|
||||||
@dataclass
|
@dataclass
|
||||||
class HandoffMessage:
|
class HandoffMessage:
|
||||||
"""任务转交消息 - Agent 间 Handoff"""
|
"""任务转交消息 - Agent 间 Handoff"""
|
||||||
|
|
||||||
source_agent: str
|
source_agent: str
|
||||||
target_agent: str
|
target_agent: str
|
||||||
task_id: str
|
task_id: str
|
||||||
|
|
@ -233,6 +240,7 @@ class HandoffMessage:
|
||||||
@dataclass
|
@dataclass
|
||||||
class EvolutionEvent:
|
class EvolutionEvent:
|
||||||
"""进化事件 - 记录 Agent 的自我进化变更"""
|
"""进化事件 - 记录 Agent 的自我进化变更"""
|
||||||
|
|
||||||
agent_name: str
|
agent_name: str
|
||||||
change_type: str # prompt / strategy / pipeline
|
change_type: str # prompt / strategy / pipeline
|
||||||
before: dict[str, Any]
|
before: dict[str, Any]
|
||||||
|
|
@ -277,3 +285,102 @@ class CancellationToken:
|
||||||
"""检查是否已取消,若已取消则抛出 TaskCancelledError"""
|
"""检查是否已取消,若已取消则抛出 TaskCancelledError"""
|
||||||
if self._cancelled:
|
if self._cancelled:
|
||||||
raise TaskCancelledError(task_id="")
|
raise TaskCancelledError(task_id="")
|
||||||
|
|
||||||
|
|
||||||
|
# ── SQ/EQ 统一事件模型 ──────────────────────────────────────────
|
||||||
|
# 参考 Codex 的 Session/Task/Turn 三级模型,统一 CLI 和 WebSocket 事件流。
|
||||||
|
|
||||||
|
|
||||||
|
class SessionEventType:
|
||||||
|
"""Session 级别事件类型
|
||||||
|
|
||||||
|
对应一次完整的会话生命周期(如用户打开 CLI、关闭 CLI)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
SESSION_STARTED = "session.started"
|
||||||
|
SESSION_ENDED = "session.ended"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskEventType:
|
||||||
|
"""Task 级别事件类型
|
||||||
|
|
||||||
|
对应一次任务提交(用户输入),从创建到终态(完成/失败)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
TASK_CREATED = "task.created"
|
||||||
|
TASK_STARTED = "task.started"
|
||||||
|
TASK_COMPLETED = "task.completed"
|
||||||
|
TASK_FAILED = "task.failed"
|
||||||
|
|
||||||
|
|
||||||
|
class TurnEventType:
|
||||||
|
"""Turn 级别事件类型
|
||||||
|
|
||||||
|
对应一个 Task 内的多轮对话/推理步骤,包括思考、工具调用、Token 流等。
|
||||||
|
"""
|
||||||
|
|
||||||
|
TURN_STARTED = "turn.started"
|
||||||
|
THINKING = "turn.thinking"
|
||||||
|
TOOL_CALL = "turn.tool_call"
|
||||||
|
TOOL_RESULT = "turn.tool_result"
|
||||||
|
TOKEN = "turn.token"
|
||||||
|
STEP = "turn.step"
|
||||||
|
FINAL_ANSWER = "turn.final_answer"
|
||||||
|
TURN_COMPLETED = "turn.completed"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Event:
|
||||||
|
"""统一事件模型 - SQ/EQ 双队列事件
|
||||||
|
|
||||||
|
所有 CLI 和 WebSocket 事件统一使用此结构,便于前端和 CLI 统一处理。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
event_type: 事件类型(见 SessionEventType / TaskEventType / TurnEventType)
|
||||||
|
task_id: 关联的任务 ID
|
||||||
|
session_id: 关联的会话 ID
|
||||||
|
data: 事件负载数据
|
||||||
|
timestamp: ISO 8601 格式时间戳
|
||||||
|
"""
|
||||||
|
|
||||||
|
event_type: str
|
||||||
|
task_id: str
|
||||||
|
session_id: str
|
||||||
|
data: dict[str, Any]
|
||||||
|
timestamp: str # ISO 8601 format
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"event_type": self.event_type,
|
||||||
|
"task_id": self.task_id,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"data": self.data,
|
||||||
|
"timestamp": self.timestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "Event":
|
||||||
|
return cls(
|
||||||
|
event_type=data["event_type"],
|
||||||
|
task_id=data["task_id"],
|
||||||
|
session_id=data["session_id"],
|
||||||
|
data=data.get("data", {}),
|
||||||
|
timestamp=data["timestamp"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
event_type: str,
|
||||||
|
task_id: str,
|
||||||
|
session_id: str,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
) -> "Event":
|
||||||
|
"""创建事件,自动生成 ISO 8601 格式时间戳"""
|
||||||
|
return cls(
|
||||||
|
event_type=event_type,
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=session_id,
|
||||||
|
data=data or {},
|
||||||
|
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,23 @@
|
||||||
"""Expert 系统 - 专家团队模式的配置、模板、注册与协作计划"""
|
"""Expert 系统 - 专家团队 hub-and-spoke 模式
|
||||||
|
|
||||||
|
U3 简化:从去中心化协作改为 Lead Expert + 并行 Task 模式。
|
||||||
|
"""
|
||||||
|
|
||||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||||||
from agentkit.experts.expert import Expert
|
from agentkit.experts.expert import Expert
|
||||||
from agentkit.experts.orchestrator import TeamOrchestrator
|
from agentkit.experts.orchestrator import TeamOrchestrator
|
||||||
from agentkit.experts.plan import (
|
from agentkit.experts.plan import (
|
||||||
CollaborationPlan,
|
|
||||||
MergeStrategy,
|
MergeStrategy,
|
||||||
ParallelType,
|
|
||||||
PhaseStatus,
|
|
||||||
PlanPhase,
|
|
||||||
PlanStatus,
|
PlanStatus,
|
||||||
|
SubTask,
|
||||||
|
SubTaskStatus,
|
||||||
|
TeamPlan,
|
||||||
)
|
)
|
||||||
from agentkit.experts.registry import ExpertTemplateRegistry
|
from agentkit.experts.registry import ExpertTemplateRegistry
|
||||||
from agentkit.experts.router import ExpertTeamRouter, ExpertTeamRoutingResult
|
from agentkit.experts.router import ExpertTeamRouter, ExpertTeamRoutingResult
|
||||||
from agentkit.experts.team import ExpertTeam, TeamStatus
|
from agentkit.experts.team import ExpertTeam, TeamStatus
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CollaborationPlan",
|
|
||||||
"Expert",
|
"Expert",
|
||||||
"ExpertConfig",
|
"ExpertConfig",
|
||||||
"ExpertTeam",
|
"ExpertTeam",
|
||||||
|
|
@ -25,10 +26,10 @@ __all__ = [
|
||||||
"ExpertTemplate",
|
"ExpertTemplate",
|
||||||
"ExpertTemplateRegistry",
|
"ExpertTemplateRegistry",
|
||||||
"MergeStrategy",
|
"MergeStrategy",
|
||||||
"ParallelType",
|
|
||||||
"PhaseStatus",
|
|
||||||
"PlanPhase",
|
|
||||||
"PlanStatus",
|
"PlanStatus",
|
||||||
|
"SubTask",
|
||||||
|
"SubTaskStatus",
|
||||||
"TeamOrchestrator",
|
"TeamOrchestrator",
|
||||||
|
"TeamPlan",
|
||||||
"TeamStatus",
|
"TeamStatus",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from agentkit.core.config_driven import AgentConfig
|
from agentkit.core.config_driven import AgentConfig
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,9 @@ class Expert:
|
||||||
# 如果提供了团队上下文,修改 Agent 的 prompt 以注入团队角色信息
|
# 如果提供了团队上下文,修改 Agent 的 prompt 以注入团队角色信息
|
||||||
if team_context and hasattr(agent, "_prompt_template") and agent._prompt_template:
|
if team_context and hasattr(agent, "_prompt_template") and agent._prompt_template:
|
||||||
sections = agent._prompt_template._sections
|
sections = agent._prompt_template._sections
|
||||||
sections.context = f"{team_context}\n\n{sections.context}" if sections.context else team_context
|
sections.context = (
|
||||||
|
f"{team_context}\n\n{sections.context}" if sections.context else team_context
|
||||||
|
)
|
||||||
|
|
||||||
return expert
|
return expert
|
||||||
|
|
||||||
|
|
@ -116,9 +118,7 @@ class Expert:
|
||||||
"reason": reason,
|
"reason": reason,
|
||||||
"type": "assist_request",
|
"type": "assist_request",
|
||||||
}
|
}
|
||||||
await self._handoff_transport.send(
|
await self._handoff_transport.send(f"expert:{target_expert}:handoff", handoff_msg)
|
||||||
f"expert:{target_expert}:handoff", handoff_msg
|
|
||||||
)
|
|
||||||
|
|
||||||
async def propose_plan_modification(
|
async def propose_plan_modification(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,281 +1,172 @@
|
||||||
"""CollaborationPlan 数据模型 - Expert Team 协作蓝图
|
"""Expert Team 计划数据模型 - hub-and-spoke 模式
|
||||||
|
|
||||||
定义 Expert Team 的结构化协作计划,包括阶段、角色分配、依赖关系、
|
简化为 Lead Expert + 并行 Task 模式:
|
||||||
并行类型、合并策略和里程碑。
|
- Lead Expert 接收任务并分解为子任务
|
||||||
|
- 子任务并行执行(Task 深度=1,不能再 spawn Task)
|
||||||
|
- Lead Expert 汇总结果(BEST 策略)
|
||||||
|
|
||||||
|
移除了阶段依赖图、VOTE/FUSION 合并策略、跨阶段状态共享。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class ParallelType(str, enum.Enum):
|
|
||||||
"""并行执行类型"""
|
|
||||||
|
|
||||||
SERIAL = "serial"
|
|
||||||
SUBTASK_PARALLEL = "subtask_parallel"
|
|
||||||
COMPETITIVE_PARALLEL = "competitive_parallel"
|
|
||||||
|
|
||||||
|
|
||||||
class MergeStrategy(str, enum.Enum):
|
class MergeStrategy(str, enum.Enum):
|
||||||
"""合并策略 - 仅用于 COMPETITIVE_PARALLEL 阶段"""
|
"""合并策略 - Lead Expert 用于选择最佳结果
|
||||||
|
|
||||||
|
hub-and-spoke 模式下仅保留 BEST 策略:
|
||||||
|
Lead Expert 从所有子任务结果中选择或综合出最佳结果。
|
||||||
|
"""
|
||||||
|
|
||||||
BEST = "best"
|
BEST = "best"
|
||||||
VOTE = "vote"
|
|
||||||
FUSION = "fusion"
|
|
||||||
|
|
||||||
|
|
||||||
class PhaseStatus(str, enum.Enum):
|
|
||||||
"""阶段状态"""
|
|
||||||
|
|
||||||
PENDING = "pending"
|
|
||||||
IN_PROGRESS = "in_progress"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
|
|
||||||
|
|
||||||
class PlanStatus(str, enum.Enum):
|
class PlanStatus(str, enum.Enum):
|
||||||
"""计划状态"""
|
"""计划状态"""
|
||||||
|
|
||||||
DRAFT = "draft"
|
DRAFT = "draft"
|
||||||
CONFIRMED = "confirmed"
|
|
||||||
EXECUTING = "executing"
|
EXECUTING = "executing"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
FALLBACK = "fallback"
|
FALLBACK = "fallback"
|
||||||
|
|
||||||
|
|
||||||
# DFS 着色常量
|
class SubTaskStatus(str, enum.Enum):
|
||||||
_WHITE = 0 # 未访问
|
"""子任务状态"""
|
||||||
_GRAY = 1 # 正在访问(当前路径上)
|
|
||||||
_BLACK = 2 # 已完成访问
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PlanPhase:
|
class SubTask:
|
||||||
"""协作计划中的单个阶段
|
"""Lead Expert 分解出的子任务
|
||||||
|
|
||||||
|
在 hub-and-spoke 模式中,Lead Expert 将原始任务分解为多个子任务,
|
||||||
|
每个子任务由一个 Expert 并行执行。子任务之间无依赖关系、无通信。
|
||||||
|
|
||||||
|
约束:
|
||||||
|
- Task 深度=1(子任务不能再 spawn 子任务)
|
||||||
|
- 子任务之间无通信
|
||||||
|
- Lead Expert 持有所有状态
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
id: 阶段标识符
|
id: 子任务标识符
|
||||||
name: 阶段显示名称
|
description: 子任务描述
|
||||||
assigned_expert: 分配到此阶段的 Expert 名称
|
assigned_expert: 分配的 Expert 名称
|
||||||
task_description: 此阶段完成的任务描述
|
|
||||||
depends_on: 依赖的阶段 ID 列表
|
|
||||||
parallel_type: 执行类型
|
|
||||||
merge_strategy: 合并策略,仅 COMPETITIVE_PARALLEL 需要
|
|
||||||
milestone: 里程碑检查点描述
|
|
||||||
status: 当前状态
|
status: 当前状态
|
||||||
result: 阶段输出结果
|
result: 子任务输出结果
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
name: str
|
description: str = ""
|
||||||
assigned_expert: str
|
assigned_expert: str = ""
|
||||||
task_description: str
|
status: SubTaskStatus = SubTaskStatus.PENDING
|
||||||
depends_on: list[str] = field(default_factory=list)
|
result: dict[str, Any] | None = None
|
||||||
parallel_type: ParallelType = ParallelType.SERIAL
|
|
||||||
merge_strategy: MergeStrategy | None = None
|
|
||||||
milestone: str = ""
|
|
||||||
status: PhaseStatus = PhaseStatus.PENDING
|
|
||||||
result: dict | None = None
|
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""序列化为字典"""
|
"""序列化为字典"""
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"name": self.name,
|
"description": self.description,
|
||||||
"assigned_expert": self.assigned_expert,
|
"assigned_expert": self.assigned_expert,
|
||||||
"task_description": self.task_description,
|
|
||||||
"depends_on": self.depends_on,
|
|
||||||
"parallel_type": self.parallel_type.value,
|
|
||||||
"merge_strategy": self.merge_strategy.value if self.merge_strategy is not None else None,
|
|
||||||
"milestone": self.milestone,
|
|
||||||
"status": self.status.value,
|
"status": self.status.value,
|
||||||
"result": self.result,
|
"result": self.result,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> PlanPhase:
|
def from_dict(cls, data: dict[str, Any]) -> SubTask:
|
||||||
"""从字典创建 PlanPhase"""
|
"""从字典创建 SubTask"""
|
||||||
merge_strategy = None
|
|
||||||
if data.get("merge_strategy") is not None:
|
|
||||||
merge_strategy = MergeStrategy(data["merge_strategy"])
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
id=data["id"],
|
id=data.get("id", str(uuid.uuid4())),
|
||||||
name=data["name"],
|
description=data.get("description", ""),
|
||||||
assigned_expert=data["assigned_expert"],
|
assigned_expert=data.get("assigned_expert", ""),
|
||||||
task_description=data["task_description"],
|
status=SubTaskStatus(data.get("status", SubTaskStatus.PENDING.value)),
|
||||||
depends_on=data.get("depends_on", []),
|
|
||||||
parallel_type=ParallelType(data.get("parallel_type", ParallelType.SERIAL.value)),
|
|
||||||
merge_strategy=merge_strategy,
|
|
||||||
milestone=data.get("milestone", ""),
|
|
||||||
status=PhaseStatus(data.get("status", PhaseStatus.PENDING.value)),
|
|
||||||
result=data.get("result"),
|
result=data.get("result"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CollaborationPlan:
|
class TeamPlan:
|
||||||
"""Expert Team 协作计划
|
"""Expert Team hub-and-spoke 执行计划
|
||||||
|
|
||||||
定义 Expert Team 的结构化协作蓝图,包括阶段编排、共享变量、
|
Lead Expert 持有此计划,包含分解的子任务列表。
|
||||||
状态管理和依赖关系。
|
与旧版 CollaborationPlan 不同,此计划无阶段依赖图,
|
||||||
|
所有子任务并行执行,由 Lead Expert 统一汇总。
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
id: 计划标识符
|
id: 计划标识符
|
||||||
task: 原始任务描述
|
task: 原始任务描述
|
||||||
phases: 有序阶段列表
|
subtasks: 子任务列表(并行执行,无依赖关系)
|
||||||
variables: 共享变量
|
|
||||||
status: 计划状态
|
status: 计划状态
|
||||||
lead_expert: 主导 Expert 名称
|
lead_expert: 主导 Expert 名称
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
task: str
|
task: str = ""
|
||||||
phases: list[PlanPhase] = field(default_factory=list)
|
subtasks: list[SubTask] = field(default_factory=list)
|
||||||
variables: dict = field(default_factory=dict)
|
|
||||||
status: PlanStatus = PlanStatus.DRAFT
|
status: PlanStatus = PlanStatus.DRAFT
|
||||||
lead_expert: str = ""
|
lead_expert: str = ""
|
||||||
_phase_index: dict[str, PlanPhase] = field(default_factory=dict, init=False, repr=False)
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
"""Build the phase index after initialization."""
|
|
||||||
self._rebuild_index()
|
|
||||||
|
|
||||||
def _rebuild_index(self) -> None:
|
|
||||||
"""Rebuild the phase index from the phases list."""
|
|
||||||
self._phase_index = {phase.id: phase for phase in self.phases}
|
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""序列化为字典"""
|
"""序列化为字典"""
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"task": self.task,
|
"task": self.task,
|
||||||
"phases": [phase.to_dict() for phase in self.phases],
|
"subtasks": [st.to_dict() for st in self.subtasks],
|
||||||
"variables": self.variables,
|
|
||||||
"status": self.status.value,
|
"status": self.status.value,
|
||||||
"lead_expert": self.lead_expert,
|
"lead_expert": self.lead_expert,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict[str, Any]) -> CollaborationPlan:
|
def from_dict(cls, data: dict[str, Any]) -> TeamPlan:
|
||||||
"""从字典创建 CollaborationPlan"""
|
"""从字典创建 TeamPlan"""
|
||||||
phases = [PlanPhase.from_dict(p) for p in data.get("phases", [])]
|
subtasks = [SubTask.from_dict(st) for st in data.get("subtasks", [])]
|
||||||
return cls(
|
return cls(
|
||||||
id=data["id"],
|
id=data.get("id", str(uuid.uuid4())),
|
||||||
task=data["task"],
|
task=data.get("task", ""),
|
||||||
phases=phases,
|
subtasks=subtasks,
|
||||||
variables=data.get("variables", {}),
|
|
||||||
status=PlanStatus(data.get("status", PlanStatus.DRAFT.value)),
|
status=PlanStatus(data.get("status", PlanStatus.DRAFT.value)),
|
||||||
lead_expert=data.get("lead_expert", ""),
|
lead_expert=data.get("lead_expert", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self) -> list[str]:
|
def get_subtask(self, subtask_id: str) -> SubTask | None:
|
||||||
"""验证计划,返回错误消息列表(空列表表示有效)
|
"""根据 ID 获取子任务,不存在则返回 None"""
|
||||||
|
for st in self.subtasks:
|
||||||
|
if st.id == subtask_id:
|
||||||
|
return st
|
||||||
|
return None
|
||||||
|
|
||||||
检查项:
|
def update_subtask_status(
|
||||||
- 无重复阶段 ID
|
self, subtask_id: str, status: SubTaskStatus, result: dict[str, Any] | None = None
|
||||||
- 所有 depends_on 引用存在
|
|
||||||
- 无循环依赖(DFS 着色检测)
|
|
||||||
- COMPETITIVE_PARALLEL 阶段必须有 merge_strategy
|
|
||||||
"""
|
|
||||||
errors: list[str] = []
|
|
||||||
|
|
||||||
# 构建阶段 ID 集合
|
|
||||||
phase_ids = {phase.id for phase in self.phases}
|
|
||||||
|
|
||||||
# 检查重复阶段 ID
|
|
||||||
seen_ids: set[str] = set()
|
|
||||||
for phase in self.phases:
|
|
||||||
if phase.id in seen_ids:
|
|
||||||
errors.append(f"重复的阶段 ID: {phase.id}")
|
|
||||||
seen_ids.add(phase.id)
|
|
||||||
|
|
||||||
# 检查 depends_on 引用是否存在
|
|
||||||
for phase in self.phases:
|
|
||||||
for dep_id in phase.depends_on:
|
|
||||||
if dep_id not in phase_ids:
|
|
||||||
errors.append(
|
|
||||||
f"阶段 '{phase.id}' 依赖了不存在的阶段 ID: {dep_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查循环依赖(迭代 DFS 着色 — 避免递归栈溢出)
|
|
||||||
color: dict[str, int] = {phase.id: _WHITE for phase in self.phases}
|
|
||||||
dep_map: dict[str, list[str]] = {
|
|
||||||
phase.id: phase.depends_on for phase in self.phases
|
|
||||||
}
|
|
||||||
|
|
||||||
for phase in self.phases:
|
|
||||||
if color[phase.id] != _WHITE:
|
|
||||||
continue
|
|
||||||
# Iterative DFS using an explicit stack
|
|
||||||
stack: list[tuple[str, bool]] = [(phase.id, False)]
|
|
||||||
while stack:
|
|
||||||
node, is_backtrack = stack.pop()
|
|
||||||
if is_backtrack:
|
|
||||||
color[node] = _BLACK
|
|
||||||
continue
|
|
||||||
if color[node] == _GRAY:
|
|
||||||
# Already on current path — cycle detected
|
|
||||||
errors.append("检测到循环依赖")
|
|
||||||
break
|
|
||||||
if color[node] == _BLACK:
|
|
||||||
continue
|
|
||||||
color[node] = _GRAY
|
|
||||||
# Push backtrack marker
|
|
||||||
stack.append((node, True))
|
|
||||||
for neighbor in dep_map.get(node, []):
|
|
||||||
if neighbor not in color:
|
|
||||||
continue
|
|
||||||
if color[neighbor] == _GRAY:
|
|
||||||
errors.append("检测到循环依赖")
|
|
||||||
break
|
|
||||||
if color[neighbor] == _WHITE:
|
|
||||||
stack.append((neighbor, False))
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
break # Inner break propagates to outer
|
|
||||||
if errors:
|
|
||||||
break # Only report cycle once
|
|
||||||
|
|
||||||
# 检查 COMPETITIVE_PARALLEL 必须有 merge_strategy
|
|
||||||
for phase in self.phases:
|
|
||||||
if (
|
|
||||||
phase.parallel_type == ParallelType.COMPETITIVE_PARALLEL
|
|
||||||
and phase.merge_strategy is None
|
|
||||||
):
|
|
||||||
errors.append(
|
|
||||||
f"阶段 '{phase.id}' 为 COMPETITIVE_PARALLEL 但未设置 merge_strategy"
|
|
||||||
)
|
|
||||||
|
|
||||||
return errors
|
|
||||||
|
|
||||||
def get_ready_phases(self) -> list[PlanPhase]:
|
|
||||||
"""获取所有依赖已完成且状态为 PENDING 的阶段"""
|
|
||||||
completed_ids = {
|
|
||||||
phase.id for phase in self.phases if phase.status == PhaseStatus.COMPLETED
|
|
||||||
}
|
|
||||||
ready: list[PlanPhase] = []
|
|
||||||
for phase in self.phases:
|
|
||||||
if phase.status != PhaseStatus.PENDING:
|
|
||||||
continue
|
|
||||||
if all(dep_id in completed_ids for dep_id in phase.depends_on):
|
|
||||||
ready.append(phase)
|
|
||||||
return ready
|
|
||||||
|
|
||||||
def get_phase(self, phase_id: str) -> PlanPhase | None:
|
|
||||||
"""根据 ID 获取阶段,不存在则返回 None (O(1) lookup)"""
|
|
||||||
return self._phase_index.get(phase_id)
|
|
||||||
|
|
||||||
def update_phase_status(
|
|
||||||
self, phase_id: str, status: PhaseStatus, result: dict | None = None
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""更新阶段状态和可选的结果"""
|
"""更新子任务状态和可选的结果"""
|
||||||
phase = self.get_phase(phase_id)
|
st = self.get_subtask(subtask_id)
|
||||||
if phase is not None:
|
if st is not None:
|
||||||
phase.status = status
|
st.status = status
|
||||||
if result is not None:
|
if result is not None:
|
||||||
phase.result = result
|
st.result = result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def completed_subtasks(self) -> list[SubTask]:
|
||||||
|
"""已完成的子任务列表"""
|
||||||
|
return [st for st in self.subtasks if st.status == SubTaskStatus.COMPLETED]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def failed_subtasks(self) -> list[SubTask]:
|
||||||
|
"""失败的子任务列表"""
|
||||||
|
return [st for st in self.subtasks if st.status == SubTaskStatus.FAILED]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_done(self) -> bool:
|
||||||
|
"""所有子任务是否都已完成(成功或失败)"""
|
||||||
|
return all(
|
||||||
|
st.status in (SubTaskStatus.COMPLETED, SubTaskStatus.FAILED) for st in self.subtasks
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,11 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from agentkit.core.exceptions import ConfigValidationError
|
from agentkit.core.exceptions import ConfigValidationError
|
||||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
from agentkit.experts.config import ExpertTemplate
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -62,10 +61,7 @@ class ExpertTemplateRegistry:
|
||||||
query_lower = query.lower()
|
query_lower = query.lower()
|
||||||
results: list[ExpertTemplate] = []
|
results: list[ExpertTemplate] = []
|
||||||
for template in self._templates.values():
|
for template in self._templates.values():
|
||||||
if (
|
if query_lower in template.name.lower() or query_lower in template.description.lower():
|
||||||
query_lower in template.name.lower()
|
|
||||||
or query_lower in template.description.lower()
|
|
||||||
):
|
|
||||||
results.append(template)
|
results.append(template)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
@ -134,7 +130,5 @@ class ExpertTemplateRegistry:
|
||||||
template = self.load_from_yaml(filepath)
|
template = self.load_from_yaml(filepath)
|
||||||
loaded.append(template)
|
loaded.append(template)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(f"Failed to load ExpertTemplate from '{filepath}': {e}")
|
||||||
f"Failed to load ExpertTemplate from '{filepath}': {e}"
|
|
||||||
)
|
|
||||||
return loaded
|
return loaded
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,13 @@
|
||||||
"""ExpertTeam - 专家团队容器
|
"""ExpertTeam - 专家团队容器(hub-and-spoke 模式)
|
||||||
|
|
||||||
管理 Expert 生命周期、共享上下文、协作计划和团队状态,
|
管理 Expert 生命周期、团队状态和事件广播,
|
||||||
是 Expert Team 协作模式的中央协调点。
|
是 Expert Team hub-and-spoke 协作模式的中央协调点。
|
||||||
|
|
||||||
|
简化说明(U3):
|
||||||
|
- 移除 CollaborationPlan 依赖(Lead Expert 自主分解任务)
|
||||||
|
- 移除跨阶段状态共享(Lead Expert 持有所有状态)
|
||||||
|
- 保留 handoff_transport 用于事件广播(不再用于 Agent 间通信)
|
||||||
|
- 保留 workspace 用于输出保存(不再用于跨阶段状态共享)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -14,7 +20,6 @@ import uuid
|
||||||
|
|
||||||
from .config import ExpertConfig
|
from .config import ExpertConfig
|
||||||
from .expert import Expert
|
from .expert import Expert
|
||||||
from .plan import CollaborationPlan, PlanStatus
|
|
||||||
from .registry import ExpertTemplateRegistry
|
from .registry import ExpertTemplateRegistry
|
||||||
from ..core.handoff_transport import InProcessHandoffTransport
|
from ..core.handoff_transport import InProcessHandoffTransport
|
||||||
from ..core.shared_workspace import SharedWorkspace
|
from ..core.shared_workspace import SharedWorkspace
|
||||||
|
|
@ -27,7 +32,6 @@ class TeamStatus(str, enum.Enum):
|
||||||
"""ExpertTeam lifecycle states."""
|
"""ExpertTeam lifecycle states."""
|
||||||
|
|
||||||
FORMING = "forming"
|
FORMING = "forming"
|
||||||
PLANNING = "planning"
|
|
||||||
EXECUTING = "executing"
|
EXECUTING = "executing"
|
||||||
SYNTHESIZING = "synthesizing"
|
SYNTHESIZING = "synthesizing"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
|
|
@ -35,7 +39,14 @@ class TeamStatus(str, enum.Enum):
|
||||||
|
|
||||||
|
|
||||||
class ExpertTeam:
|
class ExpertTeam:
|
||||||
"""Container managing a team of Experts working together on a task."""
|
"""Container managing a team of Experts in hub-and-spoke mode.
|
||||||
|
|
||||||
|
In hub-and-spoke mode:
|
||||||
|
- Lead Expert (hub) receives the task and decomposes it
|
||||||
|
- Member Experts (spokes) execute subtasks in parallel
|
||||||
|
- Lead Expert synthesizes the final result
|
||||||
|
- No inter-agent communication (Lead Expert holds all state)
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -51,7 +62,6 @@ class ExpertTeam:
|
||||||
self._handoff_transport = InProcessHandoffTransport()
|
self._handoff_transport = InProcessHandoffTransport()
|
||||||
self._experts: dict[str, Expert] = {}
|
self._experts: dict[str, Expert] = {}
|
||||||
self._lead_expert_name: str | None = None
|
self._lead_expert_name: str | None = None
|
||||||
self._plan: CollaborationPlan | None = None
|
|
||||||
self._status = TeamStatus.FORMING
|
self._status = TeamStatus.FORMING
|
||||||
self._team_channel = f"team:{self.team_id}"
|
self._team_channel = f"team:{self.team_id}"
|
||||||
self._orchestrator_task: asyncio.Task | None = None
|
self._orchestrator_task: asyncio.Task | None = None
|
||||||
|
|
@ -66,10 +76,6 @@ class ExpertTeam:
|
||||||
return self._experts.get(self._lead_expert_name)
|
return self._experts.get(self._lead_expert_name)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
|
||||||
def plan(self) -> CollaborationPlan | None:
|
|
||||||
return self._plan
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def experts(self) -> list[Expert]:
|
def experts(self) -> list[Expert]:
|
||||||
return list(self._experts.values())
|
return list(self._experts.values())
|
||||||
|
|
@ -80,12 +86,21 @@ class ExpertTeam:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workspace(self) -> SharedWorkspace:
|
def workspace(self) -> SharedWorkspace:
|
||||||
"""Public read access to the team's shared workspace."""
|
"""Public read access to the team's shared workspace.
|
||||||
|
|
||||||
|
In hub-and-spoke mode, workspace is used for output preservation only,
|
||||||
|
not for cross-phase state sharing.
|
||||||
|
"""
|
||||||
return self._workspace
|
return self._workspace
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def handoff_transport(self):
|
def handoff_transport(self):
|
||||||
"""Public read access to the team's handoff transport."""
|
"""Public read access to the team's handoff transport.
|
||||||
|
|
||||||
|
In hub-and-spoke mode, handoff_transport is used for event broadcasting
|
||||||
|
(team_formed, expert_step, expert_result, team_synthesis, team_dissolved),
|
||||||
|
not for inter-agent communication.
|
||||||
|
"""
|
||||||
return self._handoff_transport
|
return self._handoff_transport
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -106,7 +121,13 @@ class ExpertTeam:
|
||||||
lead_config: ExpertConfig,
|
lead_config: ExpertConfig,
|
||||||
member_configs: list[ExpertConfig] | None = None,
|
member_configs: list[ExpertConfig] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a team with a Lead Expert and optional members."""
|
"""Create a team with a Lead Expert and optional members.
|
||||||
|
|
||||||
|
In hub-and-spoke mode, the Lead Expert acts as the hub:
|
||||||
|
- Receives the task and decomposes it into subtasks
|
||||||
|
- Dispatches subtasks to member experts (spokes)
|
||||||
|
- Synthesizes the final result
|
||||||
|
"""
|
||||||
if not self._pool:
|
if not self._pool:
|
||||||
raise RuntimeError("AgentPool not configured")
|
raise RuntimeError("AgentPool not configured")
|
||||||
|
|
||||||
|
|
@ -128,7 +149,7 @@ class ExpertTeam:
|
||||||
for config in member_configs:
|
for config in member_configs:
|
||||||
await self._add_expert_internal(config, team_context)
|
await self._add_expert_internal(config, team_context)
|
||||||
|
|
||||||
self._status = TeamStatus.PLANNING
|
self._status = TeamStatus.EXECUTING
|
||||||
|
|
||||||
async def add_expert(self, config_or_template: ExpertConfig | str) -> Expert:
|
async def add_expert(self, config_or_template: ExpertConfig | str) -> Expert:
|
||||||
"""Add an Expert to the team dynamically.
|
"""Add an Expert to the team dynamically.
|
||||||
|
|
@ -155,9 +176,7 @@ class ExpertTeam:
|
||||||
)
|
)
|
||||||
return await self._add_expert_internal(config, team_context)
|
return await self._add_expert_internal(config, team_context)
|
||||||
|
|
||||||
async def _add_expert_internal(
|
async def _add_expert_internal(self, config: ExpertConfig, team_context: str) -> Expert:
|
||||||
self, config: ExpertConfig, team_context: str
|
|
||||||
) -> Expert:
|
|
||||||
"""Internal method to add an Expert."""
|
"""Internal method to add an Expert."""
|
||||||
if not self._pool:
|
if not self._pool:
|
||||||
raise RuntimeError("AgentPool not configured")
|
raise RuntimeError("AgentPool not configured")
|
||||||
|
|
@ -204,13 +223,6 @@ class ExpertTeam:
|
||||||
await expert.destroy(self._pool)
|
await expert.destroy(self._pool)
|
||||||
del self._experts[name]
|
del self._experts[name]
|
||||||
|
|
||||||
# Update plan: reassign phases that referenced the removed expert
|
|
||||||
if self._plan:
|
|
||||||
new_lead_name = self._lead_expert_name
|
|
||||||
for phase in self._plan.phases:
|
|
||||||
if phase.assigned_expert == name:
|
|
||||||
phase.assigned_expert = new_lead_name or ""
|
|
||||||
|
|
||||||
# Broadcast expert left
|
# Broadcast expert left
|
||||||
await self._handoff_transport.send(
|
await self._handoff_transport.send(
|
||||||
self._team_channel,
|
self._team_channel,
|
||||||
|
|
@ -220,24 +232,6 @@ class ExpertTeam:
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_plan(self, plan: CollaborationPlan) -> list[str]:
|
|
||||||
"""Update the collaboration plan. Only Lead Expert or user should call this.
|
|
||||||
|
|
||||||
Returns list of affected expert names on success, or list of validation
|
|
||||||
error strings on failure (empty list with no errors = success).
|
|
||||||
"""
|
|
||||||
errors = plan.validate()
|
|
||||||
if errors:
|
|
||||||
return errors # Return validation errors instead of silently swallowing
|
|
||||||
|
|
||||||
self._plan = plan
|
|
||||||
if plan.status == PlanStatus.CONFIRMED:
|
|
||||||
self._status = TeamStatus.EXECUTING
|
|
||||||
|
|
||||||
# Determine affected experts
|
|
||||||
affected = [p.assigned_expert for p in plan.phases]
|
|
||||||
return affected
|
|
||||||
|
|
||||||
async def broadcast_user_message(self, content: str) -> None:
|
async def broadcast_user_message(self, content: str) -> None:
|
||||||
"""Broadcast a user intervention message to all active Experts."""
|
"""Broadcast a user intervention message to all active Experts."""
|
||||||
message = {
|
message = {
|
||||||
|
|
@ -248,7 +242,11 @@ class ExpertTeam:
|
||||||
await self._handoff_transport.send(self._team_channel, message)
|
await self._handoff_transport.send(self._team_channel, message)
|
||||||
|
|
||||||
async def get_shared_context(self) -> dict:
|
async def get_shared_context(self) -> dict:
|
||||||
"""Get the team's shared context from SharedWorkspace."""
|
"""Get the team's shared context from SharedWorkspace.
|
||||||
|
|
||||||
|
In hub-and-spoke mode, this returns preserved outputs only,
|
||||||
|
not cross-phase state.
|
||||||
|
"""
|
||||||
context = {}
|
context = {}
|
||||||
keys = await self._workspace.list_keys()
|
keys = await self._workspace.list_keys()
|
||||||
for key in keys:
|
for key in keys:
|
||||||
|
|
@ -258,23 +256,6 @@ class ExpertTeam:
|
||||||
context[key] = data
|
context[key] = data
|
||||||
return context
|
return context
|
||||||
|
|
||||||
async def generate_plan(self, task: str) -> CollaborationPlan:
|
|
||||||
"""Generate a CollaborationPlan for the task.
|
|
||||||
|
|
||||||
Uses hybrid mode: core roles from template registry, auxiliary roles dynamically generated.
|
|
||||||
This method creates a plan structure — the actual LLM-based task decomposition
|
|
||||||
will be handled by TeamOrchestrator.
|
|
||||||
"""
|
|
||||||
plan_id = str(uuid.uuid4())
|
|
||||||
plan = CollaborationPlan(
|
|
||||||
id=plan_id,
|
|
||||||
task=task,
|
|
||||||
phases=[],
|
|
||||||
lead_expert=self._lead_expert_name or "",
|
|
||||||
)
|
|
||||||
self._plan = plan
|
|
||||||
return plan
|
|
||||||
|
|
||||||
async def dissolve(self) -> None:
|
async def dissolve(self) -> None:
|
||||||
"""Dissolve the team. Temporary Experts are recycled, outputs preserved in SharedWorkspace."""
|
"""Dissolve the team. Temporary Experts are recycled, outputs preserved in SharedWorkspace."""
|
||||||
# Cancel ongoing orchestrator task if any
|
# Cancel ongoing orchestrator task if any
|
||||||
|
|
@ -302,20 +283,31 @@ class ExpertTeam:
|
||||||
lead_config: ExpertConfig | None,
|
lead_config: ExpertConfig | None,
|
||||||
member_configs: list[ExpertConfig],
|
member_configs: list[ExpertConfig],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build team context string for injection into Expert system prompts."""
|
"""Build team context string for injection into Expert system prompts.
|
||||||
lines = ["You are part of an Expert Team."]
|
|
||||||
|
In hub-and-spoke mode, the context emphasizes the Lead Expert's role
|
||||||
|
as the hub and member experts' role as spokes.
|
||||||
|
"""
|
||||||
|
lines = ["You are part of an Expert Team (hub-and-spoke mode)."]
|
||||||
|
|
||||||
if lead_config:
|
if lead_config:
|
||||||
lines.append(f"Lead Expert: {lead_config.name} ({lead_config.persona})")
|
lines.append(
|
||||||
|
f"Lead Expert (hub): {lead_config.name} ({lead_config.persona}) — "
|
||||||
|
f"decomposes tasks, dispatches subtasks, and synthesizes results."
|
||||||
|
)
|
||||||
|
|
||||||
for config in member_configs:
|
for config in member_configs:
|
||||||
if lead_config and config.name == lead_config.name:
|
if lead_config and config.name == lead_config.name:
|
||||||
continue
|
continue
|
||||||
lines.append(
|
lines.append(
|
||||||
f"Team Member: {config.name} ({config.persona}), Skills: {', '.join(config.bound_skills)}"
|
f"Team Member (spoke): {config.name} ({config.persona}), "
|
||||||
|
f"Skills: {', '.join(config.bound_skills)} — "
|
||||||
|
f"executes assigned subtasks independently."
|
||||||
)
|
)
|
||||||
|
|
||||||
lines.append(
|
lines.append(
|
||||||
"You can collaborate with other team members via send_message() and request_assist()."
|
"In hub-and-spoke mode: Lead Expert holds all state, "
|
||||||
|
"subtasks are independent (depth=1, no further spawning), "
|
||||||
|
"no inter-agent communication."
|
||||||
)
|
)
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,8 @@ from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
|
||||||
from agentkit.tools.ask_human import AskHumanTool
|
from agentkit.tools.ask_human import AskHumanTool
|
||||||
from agentkit.tools.memory_tool import MemoryTool
|
from agentkit.tools.memory_tool import MemoryTool
|
||||||
from agentkit.tools.web_search import WebSearchTool
|
from agentkit.tools.web_search import WebSearchTool
|
||||||
|
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
|
||||||
|
from agentkit.tools.search import ToolSearchIndex
|
||||||
|
|
||||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||||
try:
|
try:
|
||||||
|
|
@ -40,6 +42,9 @@ __all__ = [
|
||||||
"MemoryTool",
|
"MemoryTool",
|
||||||
"ShellTool",
|
"ShellTool",
|
||||||
"WebSearchTool",
|
"WebSearchTool",
|
||||||
|
"RunTestsTool",
|
||||||
|
"ToolSearchTool",
|
||||||
|
"ToolSearchIndex",
|
||||||
"HeadroomRetrieveTool",
|
"HeadroomRetrieveTool",
|
||||||
"TerminalSession",
|
"TerminalSession",
|
||||||
"TerminalSessionManager",
|
"TerminalSessionManager",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,189 @@
|
||||||
|
"""Tool search index using BM25 algorithm.
|
||||||
|
|
||||||
|
Provides lexical tool discovery for tiered tool description injection:
|
||||||
|
core tools are fully injected into the LLM prompt, while extended tools
|
||||||
|
are searchable on-demand via the ``tool_search`` tool.
|
||||||
|
|
||||||
|
Pure Python implementation — no external dependencies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
from collections import Counter
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
__all__ = ["ToolSearchIndex"]
|
||||||
|
|
||||||
|
|
||||||
|
_TOKEN_RE = re.compile(r"[a-zA-Z0-9_]+")
|
||||||
|
|
||||||
|
|
||||||
|
def _tokenize(text: str) -> list[str]:
|
||||||
|
"""Tokenize text into lowercase alphanumeric tokens.
|
||||||
|
|
||||||
|
Splits on non-alphanumeric characters (keeping underscores so that
|
||||||
|
snake_case tool names like ``read_file`` stay intact) and lowercases.
|
||||||
|
Empty tokens are filtered out.
|
||||||
|
"""
|
||||||
|
return [t for t in _TOKEN_RE.findall(text.lower()) if t]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolSearchIndex:
|
||||||
|
"""BM25-based search index over :class:`Tool` descriptions.
|
||||||
|
|
||||||
|
Each tool is indexed by its ``name`` + ``description`` + parameter
|
||||||
|
metadata (from ``input_schema``) + ``tags``. Supports keyword search
|
||||||
|
returning the most relevant tools ranked by BM25 score.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
index = ToolSearchIndex(tools)
|
||||||
|
results = index.search("read file", top_k=5)
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DEFAULT_K1: float = 1.5
|
||||||
|
_DEFAULT_B: float = 0.75
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tools: list[Tool],
|
||||||
|
k1: float = _DEFAULT_K1,
|
||||||
|
b: float = _DEFAULT_B,
|
||||||
|
):
|
||||||
|
if k1 < 0:
|
||||||
|
raise ValueError(f"k1 must be >= 0, got {k1}")
|
||||||
|
if not (0.0 <= b <= 1.0):
|
||||||
|
raise ValueError(f"b must be in [0, 1], got {b}")
|
||||||
|
|
||||||
|
self._tools: list[Tool] = list(tools)
|
||||||
|
self._k1 = k1
|
||||||
|
self._b = b
|
||||||
|
|
||||||
|
# Build document corpus from tool metadata
|
||||||
|
self._documents: list[str] = [self._tool_to_text(t) for t in self._tools]
|
||||||
|
self._tokenized_docs: list[list[str]] = [_tokenize(doc) for doc in self._documents]
|
||||||
|
self._doc_lengths: list[int] = [len(tokens) for tokens in self._tokenized_docs]
|
||||||
|
self._avgdl: float = (
|
||||||
|
sum(self._doc_lengths) / len(self._doc_lengths) if self._doc_lengths else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Term frequencies per document
|
||||||
|
self._term_freqs: list[Counter[str]] = [Counter(tokens) for tokens in self._tokenized_docs]
|
||||||
|
|
||||||
|
# Document frequency per term (number of docs containing the term)
|
||||||
|
self._doc_freq: Counter[str] = Counter()
|
||||||
|
for tokens in self._tokenized_docs:
|
||||||
|
for term in set(tokens):
|
||||||
|
self._doc_freq[term] += 1
|
||||||
|
|
||||||
|
# Precompute IDF for every term in the corpus
|
||||||
|
self._idf: dict[str, float] = self._compute_idf()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Index construction
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tool_to_text(tool: Tool) -> str:
|
||||||
|
"""Convert a tool's searchable metadata into a single text document."""
|
||||||
|
parts: list[str] = [str(tool.name), str(tool.description)]
|
||||||
|
|
||||||
|
schema: dict[str, Any] | None = tool.input_schema
|
||||||
|
if schema:
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
for pname, pinfo in props.items():
|
||||||
|
parts.append(str(pname))
|
||||||
|
if isinstance(pinfo, dict):
|
||||||
|
pdesc = pinfo.get("description", "")
|
||||||
|
if pdesc:
|
||||||
|
parts.append(str(pdesc))
|
||||||
|
|
||||||
|
if tool.tags:
|
||||||
|
parts.extend(str(t) for t in tool.tags)
|
||||||
|
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
def _compute_idf(self) -> dict[str, float]:
|
||||||
|
"""Compute IDF for each term using the BM25 formula.
|
||||||
|
|
||||||
|
``idf(t) = ln((N - n + 0.5) / (n + 0.5) + 1)``
|
||||||
|
|
||||||
|
where ``N`` is the total number of documents and ``n`` is the
|
||||||
|
number of documents containing term ``t``. The ``+1`` inside the
|
||||||
|
logarithm guarantees a non-negative IDF.
|
||||||
|
"""
|
||||||
|
n_docs = len(self._tools)
|
||||||
|
if n_docs == 0:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
idf: dict[str, float] = {}
|
||||||
|
for term, df in self._doc_freq.items():
|
||||||
|
idf[term] = math.log((n_docs - df + 0.5) / (df + 0.5) + 1.0)
|
||||||
|
return idf
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Scoring
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _bm25_score(self, query_tokens: list[str], doc_idx: int) -> float:
|
||||||
|
"""Compute the BM25 score for a single document against the query."""
|
||||||
|
if not query_tokens or doc_idx >= len(self._tools):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
doc_len = self._doc_lengths[doc_idx]
|
||||||
|
term_freq = self._term_freqs[doc_idx]
|
||||||
|
norm = (
|
||||||
|
(1 - self._b + self._b * (doc_len / self._avgdl)) if self._avgdl > 0 else (1 - self._b)
|
||||||
|
)
|
||||||
|
score = 0.0
|
||||||
|
|
||||||
|
# Each unique query term contributes once
|
||||||
|
for term in set(query_tokens):
|
||||||
|
tf = term_freq.get(term, 0)
|
||||||
|
if tf == 0:
|
||||||
|
continue
|
||||||
|
idf = self._idf.get(term, 0.0)
|
||||||
|
if idf <= 0:
|
||||||
|
continue
|
||||||
|
denom = tf + self._k1 * norm
|
||||||
|
score += idf * (tf * (self._k1 + 1)) / denom
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def search(self, query: str, top_k: int = 5) -> list[Tool]:
|
||||||
|
"""Search tools by keyword query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search keywords (natural language or tool-name fragments).
|
||||||
|
top_k: Maximum number of tools to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tools ranked by relevance (most relevant first). Only tools
|
||||||
|
with a positive BM25 score are returned.
|
||||||
|
"""
|
||||||
|
if top_k <= 0 or not self._tools:
|
||||||
|
return []
|
||||||
|
|
||||||
|
query_tokens = _tokenize(query)
|
||||||
|
if not query_tokens:
|
||||||
|
return []
|
||||||
|
|
||||||
|
scored: list[tuple[int, float]] = []
|
||||||
|
for i in range(len(self._tools)):
|
||||||
|
score = self._bm25_score(query_tokens, i)
|
||||||
|
if score > 0:
|
||||||
|
scored.append((i, score))
|
||||||
|
|
||||||
|
scored.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
return [self._tools[i] for i, _ in scored[:top_k]]
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._tools)
|
||||||
|
|
@ -0,0 +1,666 @@
|
||||||
|
"""Tests for EventQueue — SQ/EQ 双队列实现
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
- SQ 正确接收用户输入并返回 task_id
|
||||||
|
- EQ 正确推送事件给订阅者
|
||||||
|
- 多订阅者同时接收事件(广播)
|
||||||
|
- 事件缓冲对新订阅者的回放
|
||||||
|
- SQ 取消任务
|
||||||
|
- 事件类型正确分类
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue
|
||||||
|
from agentkit.core.protocol import (
|
||||||
|
Event,
|
||||||
|
SessionEventType,
|
||||||
|
TaskEventType,
|
||||||
|
TurnEventType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── SubmissionQueue Tests ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubmissionQueue:
|
||||||
|
"""SubmissionQueue 单元测试"""
|
||||||
|
|
||||||
|
async def test_submit_returns_task_id(self):
|
||||||
|
"""测试 submit 返回有效的 task_id"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
task_id = await sq.submit("hello", "session-1")
|
||||||
|
|
||||||
|
assert isinstance(task_id, str)
|
||||||
|
assert len(task_id) > 0
|
||||||
|
|
||||||
|
async def test_submit_returns_unique_task_ids(self):
|
||||||
|
"""测试每次 submit 返回不同的 task_id"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
task_id_1 = await sq.submit("hello", "session-1")
|
||||||
|
task_id_2 = await sq.submit("world", "session-1")
|
||||||
|
|
||||||
|
assert task_id_1 != task_id_2
|
||||||
|
|
||||||
|
async def test_submit_stores_submission(self):
|
||||||
|
"""测试 submit 正确存储提交内容"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
task_id = await sq.submit("hello world", "session-1")
|
||||||
|
|
||||||
|
assert task_id in sq._submissions
|
||||||
|
submission = sq._submissions[task_id]
|
||||||
|
assert submission.content == "hello world"
|
||||||
|
assert submission.session_id == "session-1"
|
||||||
|
assert submission.task_id == task_id
|
||||||
|
assert submission.cancelled is False
|
||||||
|
|
||||||
|
async def test_drain_receives_submissions_in_order(self):
|
||||||
|
"""测试 drain 按提交顺序接收提交"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
await sq.submit("first", "session-1")
|
||||||
|
await sq.submit("second", "session-1")
|
||||||
|
|
||||||
|
received: list[str] = []
|
||||||
|
|
||||||
|
async def consumer():
|
||||||
|
async for submission in sq.drain():
|
||||||
|
received.append(submission.content)
|
||||||
|
if len(received) >= 2:
|
||||||
|
break
|
||||||
|
|
||||||
|
consumer_task = asyncio.create_task(consumer())
|
||||||
|
await asyncio.wait_for(consumer_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert received == ["first", "second"]
|
||||||
|
|
||||||
|
async def test_drain_preserves_submission_fields(self):
|
||||||
|
"""测试 drain 返回的 Submission 字段完整"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
await sq.submit("hello", "session-1")
|
||||||
|
|
||||||
|
received: list[Submission] = []
|
||||||
|
|
||||||
|
async def consumer():
|
||||||
|
async for submission in sq.drain():
|
||||||
|
received.append(submission)
|
||||||
|
break
|
||||||
|
|
||||||
|
consumer_task = asyncio.create_task(consumer())
|
||||||
|
await asyncio.wait_for(consumer_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert len(received) == 1
|
||||||
|
sub = received[0]
|
||||||
|
assert sub.content == "hello"
|
||||||
|
assert sub.session_id == "session-1"
|
||||||
|
assert isinstance(sub.task_id, str)
|
||||||
|
assert isinstance(sub.created_at, datetime)
|
||||||
|
|
||||||
|
async def test_cancel_task_succeeds(self):
|
||||||
|
"""测试取消已存在的任务"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
task_id = await sq.submit("hello", "session-1")
|
||||||
|
|
||||||
|
result = await sq.cancel(task_id)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert task_id in sq._cancelled_tasks
|
||||||
|
assert sq._submissions[task_id].cancelled is True
|
||||||
|
|
||||||
|
async def test_cancel_nonexistent_task_returns_false(self):
|
||||||
|
"""测试取消不存在的任务返回 False"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
result = await sq.cancel("nonexistent-task-id")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
async def test_cancel_already_cancelled_task_returns_false(self):
|
||||||
|
"""测试重复取消返回 False"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
task_id = await sq.submit("hello", "session-1")
|
||||||
|
|
||||||
|
first_cancel = await sq.cancel(task_id)
|
||||||
|
second_cancel = await sq.cancel(task_id)
|
||||||
|
|
||||||
|
assert first_cancel is True
|
||||||
|
assert second_cancel is False
|
||||||
|
|
||||||
|
async def test_drain_skips_cancelled_submissions(self):
|
||||||
|
"""测试 drain 跳过已取消的提交"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
task_id_1 = await sq.submit("first", "session-1")
|
||||||
|
await sq.submit("second", "session-1")
|
||||||
|
|
||||||
|
# 取消第一个提交
|
||||||
|
await sq.cancel(task_id_1)
|
||||||
|
|
||||||
|
received: list[str] = []
|
||||||
|
|
||||||
|
async def consumer():
|
||||||
|
async for submission in sq.drain():
|
||||||
|
received.append(submission.content)
|
||||||
|
if len(received) >= 1:
|
||||||
|
break
|
||||||
|
|
||||||
|
consumer_task = asyncio.create_task(consumer())
|
||||||
|
await asyncio.wait_for(consumer_task, timeout=1.0)
|
||||||
|
|
||||||
|
# 应只收到第二个提交
|
||||||
|
assert received == ["second"]
|
||||||
|
|
||||||
|
async def test_close_prevents_new_submissions(self):
|
||||||
|
"""测试关闭后不再接受新提交"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
sq.close()
|
||||||
|
|
||||||
|
assert sq.is_closed is True
|
||||||
|
|
||||||
|
try:
|
||||||
|
await sq.submit("hello", "session-1")
|
||||||
|
raise AssertionError("Should have raised RuntimeError")
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def test_close_does_not_affect_existing_submissions(self):
|
||||||
|
"""测试关闭后已提交的内容仍可消费"""
|
||||||
|
sq = SubmissionQueue()
|
||||||
|
await sq.submit("before-close", "session-1")
|
||||||
|
sq.close()
|
||||||
|
|
||||||
|
received: list[str] = []
|
||||||
|
|
||||||
|
async def consumer():
|
||||||
|
async for submission in sq.drain():
|
||||||
|
received.append(submission.content)
|
||||||
|
break
|
||||||
|
|
||||||
|
consumer_task = asyncio.create_task(consumer())
|
||||||
|
await asyncio.wait_for(consumer_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert received == ["before-close"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── EventQueue Tests ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventQueue:
|
||||||
|
"""EventQueue 单元测试"""
|
||||||
|
|
||||||
|
async def test_emit_and_subscribe_single_event(self):
|
||||||
|
"""测试 EQ 正确推送事件给订阅者"""
|
||||||
|
eq = EventQueue()
|
||||||
|
event = Event.create(
|
||||||
|
event_type=TurnEventType.TOKEN,
|
||||||
|
task_id="task-1",
|
||||||
|
session_id="session-1",
|
||||||
|
data={"text": "hello"},
|
||||||
|
)
|
||||||
|
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
async for evt in eq.subscribe():
|
||||||
|
received.append(evt)
|
||||||
|
break
|
||||||
|
|
||||||
|
sub_task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.sleep(0.05) # 给订阅者启动时间
|
||||||
|
|
||||||
|
await eq.emit(event)
|
||||||
|
await asyncio.wait_for(sub_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert len(received) == 1
|
||||||
|
assert received[0].event_type == TurnEventType.TOKEN
|
||||||
|
assert received[0].task_id == "task-1"
|
||||||
|
assert received[0].session_id == "session-1"
|
||||||
|
assert received[0].data == {"text": "hello"}
|
||||||
|
|
||||||
|
async def test_emit_preserves_event_fields(self):
|
||||||
|
"""测试 emit 不修改事件字段"""
|
||||||
|
eq = EventQueue()
|
||||||
|
original = Event.create(
|
||||||
|
event_type=TaskEventType.TASK_STARTED,
|
||||||
|
task_id="task-1",
|
||||||
|
session_id="session-1",
|
||||||
|
data={"agent": "react"},
|
||||||
|
)
|
||||||
|
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
async for evt in eq.subscribe():
|
||||||
|
received.append(evt)
|
||||||
|
break
|
||||||
|
|
||||||
|
sub_task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
await eq.emit(original)
|
||||||
|
await asyncio.wait_for(sub_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert received[0] is original or received[0].to_dict() == original.to_dict()
|
||||||
|
|
||||||
|
async def test_broadcast_to_multiple_subscribers(self):
|
||||||
|
"""测试多订阅者同时接收事件(广播)"""
|
||||||
|
eq = EventQueue()
|
||||||
|
|
||||||
|
received_a: list[Event] = []
|
||||||
|
received_b: list[Event] = []
|
||||||
|
|
||||||
|
async def subscriber_a():
|
||||||
|
async for evt in eq.subscribe():
|
||||||
|
received_a.append(evt)
|
||||||
|
if len(received_a) >= 2:
|
||||||
|
break
|
||||||
|
|
||||||
|
async def subscriber_b():
|
||||||
|
async for evt in eq.subscribe():
|
||||||
|
received_b.append(evt)
|
||||||
|
if len(received_b) >= 2:
|
||||||
|
break
|
||||||
|
|
||||||
|
task_a = asyncio.create_task(subscriber_a())
|
||||||
|
task_b = asyncio.create_task(subscriber_b())
|
||||||
|
await asyncio.sleep(0.05) # 给订阅者启动时间
|
||||||
|
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2}))
|
||||||
|
|
||||||
|
await asyncio.wait_for(task_a, timeout=1.0)
|
||||||
|
await asyncio.wait_for(task_b, timeout=1.0)
|
||||||
|
|
||||||
|
assert len(received_a) == 2
|
||||||
|
assert len(received_b) == 2
|
||||||
|
assert received_a[0].data == {"seq": 1}
|
||||||
|
assert received_a[1].data == {"seq": 2}
|
||||||
|
assert received_b[0].data == {"seq": 1}
|
||||||
|
assert received_b[1].data == {"seq": 2}
|
||||||
|
|
||||||
|
async def test_buffer_replay_for_new_subscriber(self):
|
||||||
|
"""测试事件缓冲对新订阅者的回放"""
|
||||||
|
eq = EventQueue(buffer_size=100)
|
||||||
|
|
||||||
|
# 先发送几条事件(无订阅者)
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2}))
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3}))
|
||||||
|
|
||||||
|
# 新订阅者应收到缓冲回放
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
async for evt in eq.subscribe():
|
||||||
|
received.append(evt)
|
||||||
|
if len(received) >= 3:
|
||||||
|
break
|
||||||
|
|
||||||
|
sub_task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.wait_for(sub_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert len(received) == 3
|
||||||
|
assert received[0].data == {"seq": 1}
|
||||||
|
assert received[1].data == {"seq": 2}
|
||||||
|
assert received[2].data == {"seq": 3}
|
||||||
|
|
||||||
|
async def test_buffer_replay_then_live_events(self):
|
||||||
|
"""测试新订阅者先收到回放,再收到新事件"""
|
||||||
|
eq = EventQueue(buffer_size=100)
|
||||||
|
|
||||||
|
# 先发送 2 条事件(进入缓冲)
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2}))
|
||||||
|
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
async for evt in eq.subscribe():
|
||||||
|
received.append(evt)
|
||||||
|
if len(received) >= 4:
|
||||||
|
break
|
||||||
|
|
||||||
|
sub_task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.sleep(0.05) # 给订阅者启动时间(回放缓冲)
|
||||||
|
|
||||||
|
# 再发送 2 条新事件
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3}))
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 4}))
|
||||||
|
|
||||||
|
await asyncio.wait_for(sub_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert len(received) == 4
|
||||||
|
assert [r.data["seq"] for r in received] == [1, 2, 3, 4]
|
||||||
|
|
||||||
|
async def test_buffer_size_limit_keeps_latest(self):
|
||||||
|
"""测试缓冲区大小限制,只保留最新 N 条"""
|
||||||
|
eq = EventQueue(buffer_size=3)
|
||||||
|
|
||||||
|
# 发送 5 条事件,缓冲区只保留最后 3 条
|
||||||
|
for i in range(5):
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": i}))
|
||||||
|
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
async for evt in eq.subscribe():
|
||||||
|
received.append(evt)
|
||||||
|
if len(received) >= 3:
|
||||||
|
break
|
||||||
|
|
||||||
|
sub_task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.wait_for(sub_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert len(received) == 3
|
||||||
|
# 应该是最后 3 条(seq: 2, 3, 4)
|
||||||
|
assert [r.data["seq"] for r in received] == [2, 3, 4]
|
||||||
|
|
||||||
|
async def test_default_buffer_size_is_100(self):
|
||||||
|
"""测试默认缓冲区大小为 100"""
|
||||||
|
eq = EventQueue()
|
||||||
|
|
||||||
|
assert eq.buffer_size == 100
|
||||||
|
|
||||||
|
async def test_close_unblocks_subscribers(self):
|
||||||
|
"""测试 close 解除订阅者阻塞"""
|
||||||
|
eq = EventQueue()
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
async for _ in eq.subscribe():
|
||||||
|
pass # 消费事件直到队列关闭
|
||||||
|
|
||||||
|
sub_task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
eq.close()
|
||||||
|
await asyncio.wait_for(sub_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert sub_task.done()
|
||||||
|
assert eq.is_closed is True
|
||||||
|
|
||||||
|
async def test_subscribe_after_close_returns_immediately(self):
|
||||||
|
"""测试关闭后订阅立即返回(不阻塞)"""
|
||||||
|
eq = EventQueue()
|
||||||
|
eq.close()
|
||||||
|
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
async for evt in eq.subscribe():
|
||||||
|
received.append(evt)
|
||||||
|
|
||||||
|
sub_task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.wait_for(sub_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert sub_task.done()
|
||||||
|
assert len(received) == 0
|
||||||
|
|
||||||
|
async def test_subscriber_count_tracks_subscriptions(self):
|
||||||
|
"""测试订阅者计数正确跟踪订阅"""
|
||||||
|
eq = EventQueue()
|
||||||
|
|
||||||
|
assert eq.subscriber_count == 0
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
async for _ in eq.subscribe():
|
||||||
|
pass
|
||||||
|
|
||||||
|
task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
assert eq.subscriber_count == 1
|
||||||
|
|
||||||
|
eq.close()
|
||||||
|
await asyncio.wait_for(task, timeout=1.0)
|
||||||
|
|
||||||
|
assert eq.subscriber_count == 0
|
||||||
|
|
||||||
|
async def test_subscriber_removed_on_explicit_close(self):
|
||||||
|
"""测试显式关闭订阅生成器后从列表移除
|
||||||
|
|
||||||
|
注意:async for 的 break 不会立即触发生成器的 finally,
|
||||||
|
需要显式调用 aclose() 才能保证清理。
|
||||||
|
"""
|
||||||
|
eq = EventQueue()
|
||||||
|
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
gen = eq.subscribe()
|
||||||
|
try:
|
||||||
|
async for evt in gen:
|
||||||
|
received.append(evt)
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
await gen.aclose()
|
||||||
|
|
||||||
|
task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
assert eq.subscriber_count == 1
|
||||||
|
|
||||||
|
# 触发一次 emit 让订阅者能收到事件并 break
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
|
||||||
|
await asyncio.wait_for(task, timeout=1.0)
|
||||||
|
|
||||||
|
assert len(received) == 1
|
||||||
|
assert eq.subscriber_count == 0
|
||||||
|
|
||||||
|
async def test_emit_to_no_subscribers_still_buffers(self):
|
||||||
|
"""测试无订阅者时 emit 仍写入缓冲区"""
|
||||||
|
eq = EventQueue(buffer_size=100)
|
||||||
|
|
||||||
|
await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1}))
|
||||||
|
|
||||||
|
# 缓冲区应有 1 条
|
||||||
|
assert len(eq._buffer) == 1
|
||||||
|
|
||||||
|
# 新订阅者应收到回放
|
||||||
|
received: list[Event] = []
|
||||||
|
|
||||||
|
async def subscriber():
|
||||||
|
async for evt in eq.subscribe():
|
||||||
|
received.append(evt)
|
||||||
|
break
|
||||||
|
|
||||||
|
sub_task = asyncio.create_task(subscriber())
|
||||||
|
await asyncio.wait_for(sub_task, timeout=1.0)
|
||||||
|
|
||||||
|
assert len(received) == 1
|
||||||
|
assert received[0].data == {"seq": 1}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Event Type Tests ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventTypes:
|
||||||
|
"""事件类型分类测试"""
|
||||||
|
|
||||||
|
def test_session_event_types(self):
|
||||||
|
"""测试 Session 级别事件类型"""
|
||||||
|
assert SessionEventType.SESSION_STARTED == "session.started"
|
||||||
|
assert SessionEventType.SESSION_ENDED == "session.ended"
|
||||||
|
|
||||||
|
def test_task_event_types(self):
|
||||||
|
"""测试 Task 级别事件类型"""
|
||||||
|
assert TaskEventType.TASK_CREATED == "task.created"
|
||||||
|
assert TaskEventType.TASK_STARTED == "task.started"
|
||||||
|
assert TaskEventType.TASK_COMPLETED == "task.completed"
|
||||||
|
assert TaskEventType.TASK_FAILED == "task.failed"
|
||||||
|
|
||||||
|
def test_turn_event_types(self):
|
||||||
|
"""测试 Turn 级别事件类型"""
|
||||||
|
assert TurnEventType.TURN_STARTED == "turn.started"
|
||||||
|
assert TurnEventType.THINKING == "turn.thinking"
|
||||||
|
assert TurnEventType.TOOL_CALL == "turn.tool_call"
|
||||||
|
assert TurnEventType.TOOL_RESULT == "turn.tool_result"
|
||||||
|
assert TurnEventType.TOKEN == "turn.token"
|
||||||
|
assert TurnEventType.STEP == "turn.step"
|
||||||
|
assert TurnEventType.FINAL_ANSWER == "turn.final_answer"
|
||||||
|
assert TurnEventType.TURN_COMPLETED == "turn.completed"
|
||||||
|
|
||||||
|
def test_event_type_prefixes(self):
|
||||||
|
"""测试事件类型按前缀正确分类"""
|
||||||
|
session_events = [SessionEventType.SESSION_STARTED, SessionEventType.SESSION_ENDED]
|
||||||
|
task_events = [
|
||||||
|
TaskEventType.TASK_CREATED,
|
||||||
|
TaskEventType.TASK_STARTED,
|
||||||
|
TaskEventType.TASK_COMPLETED,
|
||||||
|
TaskEventType.TASK_FAILED,
|
||||||
|
]
|
||||||
|
turn_events = [
|
||||||
|
TurnEventType.TURN_STARTED,
|
||||||
|
TurnEventType.THINKING,
|
||||||
|
TurnEventType.TOOL_CALL,
|
||||||
|
TurnEventType.TOOL_RESULT,
|
||||||
|
TurnEventType.TOKEN,
|
||||||
|
TurnEventType.STEP,
|
||||||
|
TurnEventType.FINAL_ANSWER,
|
||||||
|
TurnEventType.TURN_COMPLETED,
|
||||||
|
]
|
||||||
|
|
||||||
|
for evt in session_events:
|
||||||
|
assert evt.startswith("session."), f"{evt} should start with 'session.'"
|
||||||
|
for evt in task_events:
|
||||||
|
assert evt.startswith("task."), f"{evt} should start with 'task.'"
|
||||||
|
for evt in turn_events:
|
||||||
|
assert evt.startswith("turn."), f"{evt} should start with 'turn.'"
|
||||||
|
|
||||||
|
def test_event_types_are_distinct(self):
|
||||||
|
"""测试所有事件类型互不相同"""
|
||||||
|
all_types = [
|
||||||
|
SessionEventType.SESSION_STARTED,
|
||||||
|
SessionEventType.SESSION_ENDED,
|
||||||
|
TaskEventType.TASK_CREATED,
|
||||||
|
TaskEventType.TASK_STARTED,
|
||||||
|
TaskEventType.TASK_COMPLETED,
|
||||||
|
TaskEventType.TASK_FAILED,
|
||||||
|
TurnEventType.TURN_STARTED,
|
||||||
|
TurnEventType.THINKING,
|
||||||
|
TurnEventType.TOOL_CALL,
|
||||||
|
TurnEventType.TOOL_RESULT,
|
||||||
|
TurnEventType.TOKEN,
|
||||||
|
TurnEventType.STEP,
|
||||||
|
TurnEventType.FINAL_ANSWER,
|
||||||
|
TurnEventType.TURN_COMPLETED,
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(all_types) == len(set(all_types)), "Event types should be distinct"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Event Dataclass Tests ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventDataclass:
|
||||||
|
"""Event 数据结构测试"""
|
||||||
|
|
||||||
|
def test_event_creation(self):
|
||||||
|
"""测试 Event 创建"""
|
||||||
|
event = Event(
|
||||||
|
event_type=TurnEventType.TOKEN,
|
||||||
|
task_id="task-1",
|
||||||
|
session_id="session-1",
|
||||||
|
data={"text": "hello"},
|
||||||
|
timestamp="2025-01-01T00:00:00+00:00",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == TurnEventType.TOKEN
|
||||||
|
assert event.task_id == "task-1"
|
||||||
|
assert event.session_id == "session-1"
|
||||||
|
assert event.data == {"text": "hello"}
|
||||||
|
assert event.timestamp == "2025-01-01T00:00:00+00:00"
|
||||||
|
|
||||||
|
def test_event_create_factory_generates_timestamp(self):
|
||||||
|
"""测试 Event.create 工厂方法自动生成时间戳"""
|
||||||
|
event = Event.create(
|
||||||
|
event_type=TaskEventType.TASK_STARTED,
|
||||||
|
task_id="task-1",
|
||||||
|
session_id="session-1",
|
||||||
|
data={"agent": "react"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == TaskEventType.TASK_STARTED
|
||||||
|
assert event.task_id == "task-1"
|
||||||
|
assert event.session_id == "session-1"
|
||||||
|
assert event.data == {"agent": "react"}
|
||||||
|
assert len(event.timestamp) > 0
|
||||||
|
# 时间戳应为 ISO 8601 格式(fromisoformat 不抛异常即正确)
|
||||||
|
datetime.fromisoformat(event.timestamp)
|
||||||
|
|
||||||
|
def test_event_create_with_default_data(self):
|
||||||
|
"""测试 Event.create 不传 data 时默认为空 dict"""
|
||||||
|
event = Event.create(
|
||||||
|
event_type=SessionEventType.SESSION_STARTED,
|
||||||
|
task_id="task-1",
|
||||||
|
session_id="session-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.data == {}
|
||||||
|
|
||||||
|
def test_event_to_dict(self):
|
||||||
|
"""测试 Event.to_dict"""
|
||||||
|
event = Event(
|
||||||
|
event_type=TurnEventType.TOKEN,
|
||||||
|
task_id="task-1",
|
||||||
|
session_id="session-1",
|
||||||
|
data={"text": "hello"},
|
||||||
|
timestamp="2025-01-01T00:00:00+00:00",
|
||||||
|
)
|
||||||
|
|
||||||
|
d = event.to_dict()
|
||||||
|
|
||||||
|
assert d == {
|
||||||
|
"event_type": TurnEventType.TOKEN,
|
||||||
|
"task_id": "task-1",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"data": {"text": "hello"},
|
||||||
|
"timestamp": "2025-01-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_event_from_dict(self):
|
||||||
|
"""测试 Event.from_dict"""
|
||||||
|
data = {
|
||||||
|
"event_type": TurnEventType.TOKEN,
|
||||||
|
"task_id": "task-1",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"data": {"text": "hello"},
|
||||||
|
"timestamp": "2025-01-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
|
||||||
|
event = Event.from_dict(data)
|
||||||
|
|
||||||
|
assert event.event_type == TurnEventType.TOKEN
|
||||||
|
assert event.task_id == "task-1"
|
||||||
|
assert event.session_id == "session-1"
|
||||||
|
assert event.data == {"text": "hello"}
|
||||||
|
assert event.timestamp == "2025-01-01T00:00:00+00:00"
|
||||||
|
|
||||||
|
def test_event_to_dict_from_dict_roundtrip(self):
|
||||||
|
"""测试 to_dict -> from_dict 往返保持一致"""
|
||||||
|
original = Event.create(
|
||||||
|
event_type=SessionEventType.SESSION_STARTED,
|
||||||
|
task_id="task-1",
|
||||||
|
session_id="session-1",
|
||||||
|
data={"user": "alice"},
|
||||||
|
)
|
||||||
|
|
||||||
|
d = original.to_dict()
|
||||||
|
restored = Event.from_dict(d)
|
||||||
|
|
||||||
|
assert restored.event_type == original.event_type
|
||||||
|
assert restored.task_id == original.task_id
|
||||||
|
assert restored.session_id == original.session_id
|
||||||
|
assert restored.data == original.data
|
||||||
|
assert restored.timestamp == original.timestamp
|
||||||
|
|
||||||
|
def test_event_from_dict_with_missing_data_defaults_to_empty(self):
|
||||||
|
"""测试 from_dict 缺少 data 字段时默认为空 dict"""
|
||||||
|
data = {
|
||||||
|
"event_type": TurnEventType.TOKEN,
|
||||||
|
"task_id": "task-1",
|
||||||
|
"session_id": "session-1",
|
||||||
|
"timestamp": "2025-01-01T00:00:00+00:00",
|
||||||
|
}
|
||||||
|
|
||||||
|
event = Event.from_dict(data)
|
||||||
|
|
||||||
|
assert event.data == {}
|
||||||
|
|
@ -1,328 +1,288 @@
|
||||||
"""CollaborationPlan 数据模型单元测试"""
|
"""TeamPlan / SubTask 数据模型单元测试 (hub-and-spoke 模式)"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from agentkit.experts.plan import (
|
from agentkit.experts.plan import (
|
||||||
CollaborationPlan,
|
|
||||||
MergeStrategy,
|
MergeStrategy,
|
||||||
ParallelType,
|
|
||||||
PhaseStatus,
|
|
||||||
PlanPhase,
|
|
||||||
PlanStatus,
|
PlanStatus,
|
||||||
|
SubTask,
|
||||||
|
SubTaskStatus,
|
||||||
|
TeamPlan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── 辅助函数 ──────────────────────────────────────────────
|
# ── 辅助函数 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _make_phase(
|
def _make_subtask(
|
||||||
id: str = "phase_1",
|
id: str = "subtask_1",
|
||||||
name: str = "分析阶段",
|
description: str = "分析数据",
|
||||||
assigned_expert: str = "analyst",
|
assigned_expert: str = "analyst",
|
||||||
task_description: str = "分析需求",
|
status: SubTaskStatus = SubTaskStatus.PENDING,
|
||||||
depends_on: list[str] | None = None,
|
|
||||||
parallel_type: ParallelType = ParallelType.SERIAL,
|
|
||||||
merge_strategy: MergeStrategy | None = None,
|
|
||||||
milestone: str = "",
|
|
||||||
status: PhaseStatus = PhaseStatus.PENDING,
|
|
||||||
result: dict | None = None,
|
result: dict | None = None,
|
||||||
) -> PlanPhase:
|
) -> SubTask:
|
||||||
"""创建测试用 PlanPhase 实例"""
|
"""创建测试用 SubTask 实例"""
|
||||||
return PlanPhase(
|
return SubTask(
|
||||||
id=id,
|
id=id,
|
||||||
name=name,
|
description=description,
|
||||||
assigned_expert=assigned_expert,
|
assigned_expert=assigned_expert,
|
||||||
task_description=task_description,
|
|
||||||
depends_on=depends_on or [],
|
|
||||||
parallel_type=parallel_type,
|
|
||||||
merge_strategy=merge_strategy,
|
|
||||||
milestone=milestone,
|
|
||||||
status=status,
|
status=status,
|
||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_valid_plan() -> CollaborationPlan:
|
def _make_valid_plan() -> TeamPlan:
|
||||||
"""创建一个有效的协作计划"""
|
"""创建一个有效的 hub-and-spoke 执行计划"""
|
||||||
phases = [
|
subtasks = [
|
||||||
_make_phase(id="p1", name="需求分析", assigned_expert="analyst", task_description="分析需求"),
|
_make_subtask(id="s1", description="分析需求", assigned_expert="analyst"),
|
||||||
_make_phase(
|
_make_subtask(id="s2", description="设计架构", assigned_expert="architect"),
|
||||||
id="p2",
|
_make_subtask(id="s3", description="编写代码", assigned_expert="coder"),
|
||||||
name="架构设计",
|
|
||||||
assigned_expert="architect",
|
|
||||||
task_description="设计架构",
|
|
||||||
depends_on=["p1"],
|
|
||||||
),
|
|
||||||
_make_phase(
|
|
||||||
id="p3",
|
|
||||||
name="代码实现",
|
|
||||||
assigned_expert="coder",
|
|
||||||
task_description="编写代码",
|
|
||||||
depends_on=["p2"],
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
return CollaborationPlan(
|
return TeamPlan(
|
||||||
id="plan_001",
|
id="plan_001",
|
||||||
task="实现用户登录功能",
|
task="实现用户登录功能",
|
||||||
phases=phases,
|
subtasks=subtasks,
|
||||||
variables={"project": "fischer"},
|
|
||||||
status=PlanStatus.DRAFT,
|
status=PlanStatus.DRAFT,
|
||||||
lead_expert="architect",
|
lead_expert="architect",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── PlanPhase 测试 ────────────────────────────────────────
|
# ── MergeStrategy 测试 ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestPlanPhase:
|
class TestMergeStrategy:
|
||||||
"""PlanPhase 数据模型测试"""
|
"""MergeStrategy 枚举测试"""
|
||||||
|
|
||||||
|
def test_only_best_strategy_exists(self):
|
||||||
|
"""hub-and-spoke 模式仅保留 BEST 策略"""
|
||||||
|
assert MergeStrategy.BEST == "best"
|
||||||
|
# 确保只有 BEST 一个值
|
||||||
|
assert len(list(MergeStrategy)) == 1
|
||||||
|
|
||||||
|
def test_no_vote_or_fusion(self):
|
||||||
|
"""VOTE 和 FUSION 已被移除"""
|
||||||
|
assert not hasattr(MergeStrategy, "VOTE")
|
||||||
|
assert not hasattr(MergeStrategy, "FUSION")
|
||||||
|
|
||||||
|
|
||||||
|
# ── PlanStatus 测试 ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPlanStatus:
|
||||||
|
"""PlanStatus 枚举测试"""
|
||||||
|
|
||||||
|
def test_statuses_exist(self):
|
||||||
|
"""必要的计划状态都存在"""
|
||||||
|
assert PlanStatus.DRAFT == "draft"
|
||||||
|
assert PlanStatus.EXECUTING == "executing"
|
||||||
|
assert PlanStatus.COMPLETED == "completed"
|
||||||
|
assert PlanStatus.FAILED == "failed"
|
||||||
|
assert PlanStatus.FALLBACK == "fallback"
|
||||||
|
|
||||||
|
def test_no_confirmed_status(self):
|
||||||
|
"""CONFIRMED 状态已被移除(hub-and-spoke 无需确认阶段)"""
|
||||||
|
assert not hasattr(PlanStatus, "CONFIRMED")
|
||||||
|
|
||||||
|
|
||||||
|
# ── SubTaskStatus 测试 ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubTaskStatus:
|
||||||
|
"""SubTaskStatus 枚举测试"""
|
||||||
|
|
||||||
|
def test_statuses_exist(self):
|
||||||
|
"""子任务状态都存在"""
|
||||||
|
assert SubTaskStatus.PENDING == "pending"
|
||||||
|
assert SubTaskStatus.RUNNING == "running"
|
||||||
|
assert SubTaskStatus.COMPLETED == "completed"
|
||||||
|
assert SubTaskStatus.FAILED == "failed"
|
||||||
|
|
||||||
|
|
||||||
|
# ── SubTask 测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubTask:
|
||||||
|
"""SubTask 数据模型测试"""
|
||||||
|
|
||||||
def test_creation_with_all_fields(self):
|
def test_creation_with_all_fields(self):
|
||||||
"""创建 PlanPhase 并设置所有字段"""
|
"""创建 SubTask 并设置所有字段"""
|
||||||
phase = PlanPhase(
|
subtask = SubTask(
|
||||||
id="phase_a",
|
id="subtask_a",
|
||||||
name="竞品分析",
|
description="竞品分析",
|
||||||
assigned_expert="analyst",
|
assigned_expert="analyst",
|
||||||
task_description="分析竞品功能",
|
status=SubTaskStatus.RUNNING,
|
||||||
depends_on=["phase_0"],
|
|
||||||
parallel_type=ParallelType.COMPETITIVE_PARALLEL,
|
|
||||||
merge_strategy=MergeStrategy.BEST,
|
|
||||||
milestone="竞品报告完成",
|
|
||||||
status=PhaseStatus.IN_PROGRESS,
|
|
||||||
result={"report": "竞品分析报告"},
|
result={"report": "竞品分析报告"},
|
||||||
)
|
)
|
||||||
assert phase.id == "phase_a"
|
assert subtask.id == "subtask_a"
|
||||||
assert phase.name == "竞品分析"
|
assert subtask.description == "竞品分析"
|
||||||
assert phase.assigned_expert == "analyst"
|
assert subtask.assigned_expert == "analyst"
|
||||||
assert phase.task_description == "分析竞品功能"
|
assert subtask.status == SubTaskStatus.RUNNING
|
||||||
assert phase.depends_on == ["phase_0"]
|
assert subtask.result == {"report": "竞品分析报告"}
|
||||||
assert phase.parallel_type == ParallelType.COMPETITIVE_PARALLEL
|
|
||||||
assert phase.merge_strategy == MergeStrategy.BEST
|
def test_default_values(self):
|
||||||
assert phase.milestone == "竞品报告完成"
|
"""默认值:自动生成 id,PENDING 状态"""
|
||||||
assert phase.status == PhaseStatus.IN_PROGRESS
|
subtask = SubTask(description="测试任务")
|
||||||
assert phase.result == {"report": "竞品分析报告"}
|
assert subtask.id is not None
|
||||||
|
assert len(subtask.id) > 0
|
||||||
|
assert subtask.description == "测试任务"
|
||||||
|
assert subtask.assigned_expert == ""
|
||||||
|
assert subtask.status == SubTaskStatus.PENDING
|
||||||
|
assert subtask.result is None
|
||||||
|
|
||||||
def test_to_dict_from_dict_roundtrip(self):
|
def test_to_dict_from_dict_roundtrip(self):
|
||||||
"""to_dict / from_dict 往返序列化"""
|
"""to_dict / from_dict 往返序列化"""
|
||||||
phase = PlanPhase(
|
subtask = SubTask(
|
||||||
id="roundtrip_phase",
|
id="roundtrip_subtask",
|
||||||
name="往返测试",
|
description="往返测试",
|
||||||
assigned_expert="tester",
|
assigned_expert="tester",
|
||||||
task_description="测试序列化",
|
status=SubTaskStatus.COMPLETED,
|
||||||
depends_on=["dep_a", "dep_b"],
|
|
||||||
parallel_type=ParallelType.SUBTASK_PARALLEL,
|
|
||||||
merge_strategy=MergeStrategy.VOTE,
|
|
||||||
milestone="序列化验证",
|
|
||||||
status=PhaseStatus.COMPLETED,
|
|
||||||
result={"key": "value"},
|
result={"key": "value"},
|
||||||
)
|
)
|
||||||
d = phase.to_dict()
|
d = subtask.to_dict()
|
||||||
restored = PlanPhase.from_dict(d)
|
restored = SubTask.from_dict(d)
|
||||||
|
|
||||||
assert restored.id == phase.id
|
assert restored.id == subtask.id
|
||||||
assert restored.name == phase.name
|
assert restored.description == subtask.description
|
||||||
assert restored.assigned_expert == phase.assigned_expert
|
assert restored.assigned_expert == subtask.assigned_expert
|
||||||
assert restored.task_description == phase.task_description
|
assert restored.status == subtask.status
|
||||||
assert restored.depends_on == phase.depends_on
|
assert restored.result == subtask.result
|
||||||
assert restored.parallel_type == phase.parallel_type
|
|
||||||
assert restored.merge_strategy == phase.merge_strategy
|
|
||||||
assert restored.milestone == phase.milestone
|
|
||||||
assert restored.status == phase.status
|
|
||||||
assert restored.result == phase.result
|
|
||||||
|
|
||||||
def test_to_dict_from_dict_with_none_merge_strategy(self):
|
def test_to_dict_structure(self):
|
||||||
"""merge_strategy 为 None 时的序列化往返"""
|
"""to_dict 返回正确的字典结构"""
|
||||||
phase = PlanPhase(
|
subtask = _make_subtask(
|
||||||
id="no_merge",
|
id="struct_test",
|
||||||
name="无合并",
|
description="结构测试",
|
||||||
assigned_expert="dev",
|
assigned_expert="dev",
|
||||||
task_description="串行任务",
|
status=SubTaskStatus.RUNNING,
|
||||||
parallel_type=ParallelType.SERIAL,
|
result={"output": "data"},
|
||||||
)
|
)
|
||||||
d = phase.to_dict()
|
d = subtask.to_dict()
|
||||||
assert d["merge_strategy"] is None
|
assert d["id"] == "struct_test"
|
||||||
restored = PlanPhase.from_dict(d)
|
assert d["description"] == "结构测试"
|
||||||
assert restored.merge_strategy is None
|
assert d["assigned_expert"] == "dev"
|
||||||
|
assert d["status"] == "running"
|
||||||
|
assert d["result"] == {"output": "data"}
|
||||||
|
|
||||||
|
|
||||||
# ── CollaborationPlan 测试 ────────────────────────────────
|
# ── TeamPlan 测试 ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestCollaborationPlan:
|
class TestTeamPlan:
|
||||||
"""CollaborationPlan 数据模型测试"""
|
"""TeamPlan 数据模型测试"""
|
||||||
|
|
||||||
def test_creation(self):
|
def test_creation(self):
|
||||||
"""创建 CollaborationPlan"""
|
"""创建 TeamPlan"""
|
||||||
plan = _make_valid_plan()
|
plan = _make_valid_plan()
|
||||||
assert plan.id == "plan_001"
|
assert plan.id == "plan_001"
|
||||||
assert plan.task == "实现用户登录功能"
|
assert plan.task == "实现用户登录功能"
|
||||||
assert len(plan.phases) == 3
|
assert len(plan.subtasks) == 3
|
||||||
assert plan.variables == {"project": "fischer"}
|
|
||||||
assert plan.status == PlanStatus.DRAFT
|
assert plan.status == PlanStatus.DRAFT
|
||||||
assert plan.lead_expert == "architect"
|
assert plan.lead_expert == "architect"
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
"""默认值:自动生成 id,空子任务列表"""
|
||||||
|
plan = TeamPlan(task="测试任务")
|
||||||
|
assert plan.id is not None
|
||||||
|
assert plan.task == "测试任务"
|
||||||
|
assert plan.subtasks == []
|
||||||
|
assert plan.status == PlanStatus.DRAFT
|
||||||
|
assert plan.lead_expert == ""
|
||||||
|
|
||||||
def test_to_dict_from_dict_roundtrip(self):
|
def test_to_dict_from_dict_roundtrip(self):
|
||||||
"""to_dict / from_dict 往返序列化"""
|
"""to_dict / from_dict 往返序列化"""
|
||||||
plan = _make_valid_plan()
|
plan = _make_valid_plan()
|
||||||
d = plan.to_dict()
|
d = plan.to_dict()
|
||||||
restored = CollaborationPlan.from_dict(d)
|
restored = TeamPlan.from_dict(d)
|
||||||
|
|
||||||
assert restored.id == plan.id
|
assert restored.id == plan.id
|
||||||
assert restored.task == plan.task
|
assert restored.task == plan.task
|
||||||
assert len(restored.phases) == len(plan.phases)
|
assert len(restored.subtasks) == len(plan.subtasks)
|
||||||
assert restored.variables == plan.variables
|
|
||||||
assert restored.status == plan.status
|
assert restored.status == plan.status
|
||||||
assert restored.lead_expert == plan.lead_expert
|
assert restored.lead_expert == plan.lead_expert
|
||||||
|
|
||||||
for original, restored_phase in zip(plan.phases, restored.phases):
|
for original, restored_st in zip(plan.subtasks, restored.subtasks):
|
||||||
assert restored_phase.id == original.id
|
assert restored_st.id == original.id
|
||||||
assert restored_phase.name == original.name
|
assert restored_st.description == original.description
|
||||||
assert restored_phase.assigned_expert == original.assigned_expert
|
assert restored_st.assigned_expert == original.assigned_expert
|
||||||
assert restored_phase.depends_on == original.depends_on
|
assert restored_st.status == original.status
|
||||||
assert restored_phase.parallel_type == original.parallel_type
|
|
||||||
assert restored_phase.merge_strategy == original.merge_strategy
|
|
||||||
|
|
||||||
def test_validate_valid_plan(self):
|
def test_get_subtask_by_id(self):
|
||||||
"""验证有效计划无错误"""
|
"""get_subtask 根据 ID 获取子任务"""
|
||||||
plan = _make_valid_plan()
|
plan = _make_valid_plan()
|
||||||
errors = plan.validate()
|
st = plan.get_subtask("s2")
|
||||||
assert errors == []
|
assert st is not None
|
||||||
|
assert st.id == "s2"
|
||||||
|
assert st.description == "设计架构"
|
||||||
|
|
||||||
def test_validate_detects_duplicate_phase_ids(self):
|
def test_get_subtask_with_nonexistent_id_returns_none(self):
|
||||||
"""验证检测到重复阶段 ID"""
|
"""get_subtask 对不存在的 ID 返回 None"""
|
||||||
phases = [
|
|
||||||
_make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"),
|
|
||||||
_make_phase(id="p1", name="阶段2", assigned_expert="b", task_description="t2"),
|
|
||||||
]
|
|
||||||
plan = CollaborationPlan(
|
|
||||||
id="dup_plan", task="重复ID测试", phases=phases, lead_expert="a"
|
|
||||||
)
|
|
||||||
errors = plan.validate()
|
|
||||||
assert any("重复的阶段 ID" in e for e in errors)
|
|
||||||
|
|
||||||
def test_validate_detects_missing_depends_on_references(self):
|
|
||||||
"""验证检测到不存在的 depends_on 引用"""
|
|
||||||
phases = [
|
|
||||||
_make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"),
|
|
||||||
_make_phase(
|
|
||||||
id="p2",
|
|
||||||
name="阶段2",
|
|
||||||
assigned_expert="b",
|
|
||||||
task_description="t2",
|
|
||||||
depends_on=["p1", "nonexistent"],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
plan = CollaborationPlan(
|
|
||||||
id="missing_dep_plan", task="缺失依赖测试", phases=phases, lead_expert="a"
|
|
||||||
)
|
|
||||||
errors = plan.validate()
|
|
||||||
assert any("不存在的阶段 ID" in e for e in errors)
|
|
||||||
|
|
||||||
def test_validate_detects_circular_dependencies(self):
|
|
||||||
"""验证检测到循环依赖"""
|
|
||||||
phases = [
|
|
||||||
_make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1", depends_on=["p3"]),
|
|
||||||
_make_phase(id="p2", name="阶段2", assigned_expert="b", task_description="t2", depends_on=["p1"]),
|
|
||||||
_make_phase(id="p3", name="阶段3", assigned_expert="c", task_description="t3", depends_on=["p2"]),
|
|
||||||
]
|
|
||||||
plan = CollaborationPlan(
|
|
||||||
id="cycle_plan", task="循环依赖测试", phases=phases, lead_expert="a"
|
|
||||||
)
|
|
||||||
errors = plan.validate()
|
|
||||||
assert any("循环依赖" in e for e in errors)
|
|
||||||
|
|
||||||
def test_validate_detects_competitive_parallel_without_merge_strategy(self):
|
|
||||||
"""验证检测到 COMPETITIVE_PARALLEL 缺少 merge_strategy"""
|
|
||||||
phases = [
|
|
||||||
_make_phase(
|
|
||||||
id="p1",
|
|
||||||
name="竞争阶段",
|
|
||||||
assigned_expert="a",
|
|
||||||
task_description="竞争任务",
|
|
||||||
parallel_type=ParallelType.COMPETITIVE_PARALLEL,
|
|
||||||
merge_strategy=None,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
plan = CollaborationPlan(
|
|
||||||
id="no_merge_plan", task="缺少合并策略测试", phases=phases, lead_expert="a"
|
|
||||||
)
|
|
||||||
errors = plan.validate()
|
|
||||||
assert any("COMPETITIVE_PARALLEL" in e and "merge_strategy" in e for e in errors)
|
|
||||||
|
|
||||||
def test_get_ready_phases_returns_phases_with_completed_dependencies(self):
|
|
||||||
"""get_ready_phases 返回依赖已完成的阶段"""
|
|
||||||
plan = _make_valid_plan()
|
plan = _make_valid_plan()
|
||||||
# 初始状态:p1 无依赖,应该就绪
|
assert plan.get_subtask("nonexistent") is None
|
||||||
ready = plan.get_ready_phases()
|
|
||||||
assert len(ready) == 1
|
|
||||||
assert ready[0].id == "p1"
|
|
||||||
|
|
||||||
# 完成 p1 后,p2 应该就绪
|
def test_update_subtask_status(self):
|
||||||
plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"analysis": "done"})
|
"""update_subtask_status 更新子任务状态和结果"""
|
||||||
ready = plan.get_ready_phases()
|
|
||||||
assert len(ready) == 1
|
|
||||||
assert ready[0].id == "p2"
|
|
||||||
|
|
||||||
# 完成 p2 后,p3 应该就绪
|
|
||||||
plan.update_phase_status("p2", PhaseStatus.COMPLETED, {"design": "done"})
|
|
||||||
ready = plan.get_ready_phases()
|
|
||||||
assert len(ready) == 1
|
|
||||||
assert ready[0].id == "p3"
|
|
||||||
|
|
||||||
def test_get_ready_phases_returns_empty_when_dependencies_not_met(self):
|
|
||||||
"""get_ready_phases 在依赖未满足时返回空列表"""
|
|
||||||
phases = [
|
|
||||||
_make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"),
|
|
||||||
_make_phase(
|
|
||||||
id="p2",
|
|
||||||
name="阶段2",
|
|
||||||
assigned_expert="b",
|
|
||||||
task_description="t2",
|
|
||||||
depends_on=["p1"],
|
|
||||||
),
|
|
||||||
]
|
|
||||||
plan = CollaborationPlan(
|
|
||||||
id="dep_plan", task="依赖未满足测试", phases=phases, lead_expert="a"
|
|
||||||
)
|
|
||||||
# p2 依赖 p1,p1 未完成,所以 p2 不就绪
|
|
||||||
# 但 p1 无依赖,所以 p1 就绪
|
|
||||||
ready = plan.get_ready_phases()
|
|
||||||
assert len(ready) == 1
|
|
||||||
assert ready[0].id == "p1"
|
|
||||||
|
|
||||||
# 将 p1 设为 IN_PROGRESS(未 COMPLETED),p2 仍不就绪
|
|
||||||
plan.update_phase_status("p1", PhaseStatus.IN_PROGRESS)
|
|
||||||
ready = plan.get_ready_phases()
|
|
||||||
assert len(ready) == 0
|
|
||||||
|
|
||||||
def test_update_phase_status(self):
|
|
||||||
"""update_phase_status 更新阶段状态和结果"""
|
|
||||||
plan = _make_valid_plan()
|
plan = _make_valid_plan()
|
||||||
plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"output": "分析完成"})
|
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED, {"output": "分析完成"})
|
||||||
phase = plan.get_phase("p1")
|
st = plan.get_subtask("s1")
|
||||||
assert phase is not None
|
assert st is not None
|
||||||
assert phase.status == PhaseStatus.COMPLETED
|
assert st.status == SubTaskStatus.COMPLETED
|
||||||
assert phase.result == {"output": "分析完成"}
|
assert st.result == {"output": "分析完成"}
|
||||||
|
|
||||||
# 不传 result 时不应覆盖已有 result
|
def test_update_subtask_status_without_result(self):
|
||||||
plan.update_phase_status("p2", PhaseStatus.IN_PROGRESS)
|
"""update_subtask_status 不传 result 时不覆盖已有 result"""
|
||||||
phase2 = plan.get_phase("p2")
|
|
||||||
assert phase2 is not None
|
|
||||||
assert phase2.status == PhaseStatus.IN_PROGRESS
|
|
||||||
assert phase2.result is None
|
|
||||||
|
|
||||||
def test_get_phase_by_id(self):
|
|
||||||
"""get_phase 根据 ID 获取阶段"""
|
|
||||||
plan = _make_valid_plan()
|
plan = _make_valid_plan()
|
||||||
phase = plan.get_phase("p2")
|
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED, {"output": "done"})
|
||||||
assert phase is not None
|
plan.update_subtask_status("s1", SubTaskStatus.RUNNING)
|
||||||
assert phase.id == "p2"
|
st = plan.get_subtask("s1")
|
||||||
assert phase.name == "架构设计"
|
assert st is not None
|
||||||
|
assert st.status == SubTaskStatus.RUNNING
|
||||||
|
assert st.result == {"output": "done"}
|
||||||
|
|
||||||
def test_get_phase_with_nonexistent_id_returns_none(self):
|
def test_completed_subtasks_property(self):
|
||||||
"""get_phase 对不存在的 ID 返回 None"""
|
"""completed_subtasks 返回已完成的子任务"""
|
||||||
plan = _make_valid_plan()
|
plan = _make_valid_plan()
|
||||||
phase = plan.get_phase("nonexistent")
|
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)
|
||||||
assert phase is None
|
plan.update_subtask_status("s2", SubTaskStatus.COMPLETED)
|
||||||
|
plan.update_subtask_status("s3", SubTaskStatus.FAILED)
|
||||||
|
|
||||||
|
completed = plan.completed_subtasks
|
||||||
|
assert len(completed) == 2
|
||||||
|
assert {st.id for st in completed} == {"s1", "s2"}
|
||||||
|
|
||||||
|
def test_failed_subtasks_property(self):
|
||||||
|
"""failed_subtasks 返回失败的子任务"""
|
||||||
|
plan = _make_valid_plan()
|
||||||
|
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)
|
||||||
|
plan.update_subtask_status("s2", SubTaskStatus.FAILED)
|
||||||
|
plan.update_subtask_status("s3", SubTaskStatus.FAILED)
|
||||||
|
|
||||||
|
failed = plan.failed_subtasks
|
||||||
|
assert len(failed) == 2
|
||||||
|
assert {st.id for st in failed} == {"s2", "s3"}
|
||||||
|
|
||||||
|
def test_all_done_property_when_all_completed(self):
|
||||||
|
"""all_done 当所有子任务完成时返回 True"""
|
||||||
|
plan = _make_valid_plan()
|
||||||
|
for st in plan.subtasks:
|
||||||
|
plan.update_subtask_status(st.id, SubTaskStatus.COMPLETED)
|
||||||
|
assert plan.all_done is True
|
||||||
|
|
||||||
|
def test_all_done_property_when_some_failed(self):
|
||||||
|
"""all_done 当所有子任务完成或失败时返回 True"""
|
||||||
|
plan = _make_valid_plan()
|
||||||
|
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)
|
||||||
|
plan.update_subtask_status("s2", SubTaskStatus.FAILED)
|
||||||
|
plan.update_subtask_status("s3", SubTaskStatus.COMPLETED)
|
||||||
|
assert plan.all_done is True
|
||||||
|
|
||||||
|
def test_all_done_property_when_pending(self):
|
||||||
|
"""all_done 当有子任务未完成时返回 False"""
|
||||||
|
plan = _make_valid_plan()
|
||||||
|
plan.update_subtask_status("s1", SubTaskStatus.COMPLETED)
|
||||||
|
# s2 and s3 still pending
|
||||||
|
assert plan.all_done is False
|
||||||
|
|
||||||
|
def test_all_done_property_empty_plan(self):
|
||||||
|
"""all_done 当没有子任务时返回 True(vacuous truth)"""
|
||||||
|
plan = TeamPlan(task="空计划")
|
||||||
|
assert plan.all_done is True
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,6 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||||||
from agentkit.experts.registry import ExpertTemplateRegistry
|
from agentkit.experts.registry import ExpertTemplateRegistry
|
||||||
from agentkit.experts.router import (
|
from agentkit.experts.router import (
|
||||||
|
|
@ -113,7 +111,6 @@ class TestExpertTeamRoutingResult:
|
||||||
assert result.specified_experts == []
|
assert result.specified_experts == []
|
||||||
assert result.task_content == ""
|
assert result.task_content == ""
|
||||||
assert result.auto_compose is False
|
assert result.auto_compose is False
|
||||||
assert result.complexity == 0.0
|
|
||||||
assert result.match_method == ""
|
assert result.match_method == ""
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -159,47 +156,14 @@ class TestExpertTeamRouterResolve:
|
||||||
assert result.specified_experts == ["analyst"]
|
assert result.specified_experts == ["analyst"]
|
||||||
assert result.task_content == "请分析这份报告"
|
assert result.task_content == "请分析这份报告"
|
||||||
|
|
||||||
def test_high_complexity_triggers_team_suggestion(self):
|
def test_no_team_prefix_no_team_mode(self):
|
||||||
"""高复杂度 (>=0.7) 触发团队模式建议"""
|
"""无 @team 前缀不触发团队模式"""
|
||||||
router = ExpertTeamRouter()
|
|
||||||
result = router.resolve("分析这个复杂系统", complexity=0.8)
|
|
||||||
assert result.matched is True
|
|
||||||
assert result.team_mode is True
|
|
||||||
assert result.auto_compose is True
|
|
||||||
assert result.match_method == "complexity_suggestion"
|
|
||||||
assert result.complexity == 0.8
|
|
||||||
|
|
||||||
def test_high_complexity_exact_threshold(self):
|
|
||||||
"""复杂度恰好等于阈值 0.7 也触发团队模式"""
|
|
||||||
router = ExpertTeamRouter()
|
|
||||||
result = router.resolve("任务内容", complexity=0.7)
|
|
||||||
assert result.matched is True
|
|
||||||
assert result.team_mode is True
|
|
||||||
assert result.match_method == "complexity_suggestion"
|
|
||||||
|
|
||||||
def test_low_complexity_no_team_mode(self):
|
|
||||||
"""低复杂度 (<0.7) 不触发团队模式"""
|
|
||||||
router = ExpertTeamRouter()
|
|
||||||
result = router.resolve("简单问题", complexity=0.3)
|
|
||||||
assert result.matched is False
|
|
||||||
assert result.team_mode is False
|
|
||||||
assert result.task_content == "简单问题"
|
|
||||||
assert result.complexity == 0.3
|
|
||||||
|
|
||||||
def test_no_team_prefix_no_complexity(self):
|
|
||||||
"""无 @team 前缀且无复杂度时不触发团队模式"""
|
|
||||||
router = ExpertTeamRouter()
|
router = ExpertTeamRouter()
|
||||||
result = router.resolve("普通问题")
|
result = router.resolve("普通问题")
|
||||||
assert result.matched is False
|
assert result.matched is False
|
||||||
assert result.team_mode is False
|
assert result.team_mode is False
|
||||||
|
assert result.task_content == "普通问题"
|
||||||
def test_team_prefix_takes_priority_over_complexity(self):
|
assert result.match_method == ""
|
||||||
"""@team 前缀优先于复杂度判断"""
|
|
||||||
router = ExpertTeamRouter()
|
|
||||||
result = router.resolve("@team:analyst 任务", complexity=0.1)
|
|
||||||
assert result.matched is True
|
|
||||||
assert result.match_method == "explicit_team"
|
|
||||||
assert result.specified_experts == ["analyst"]
|
|
||||||
|
|
||||||
def test_nonexistent_expert_still_included(self):
|
def test_nonexistent_expert_still_included(self):
|
||||||
"""指定不存在的专家名仍包含在列表中"""
|
"""指定不存在的专家名仍包含在列表中"""
|
||||||
|
|
@ -214,6 +178,35 @@ class TestExpertTeamRouterResolve:
|
||||||
assert result.task_content == "@team"
|
assert result.task_content == "@team"
|
||||||
assert result.auto_compose is True
|
assert result.auto_compose is True
|
||||||
|
|
||||||
|
def test_resolve_strips_leading_whitespace(self):
|
||||||
|
"""resolve() 对前导空白做 strip()"""
|
||||||
|
router = ExpertTeamRouter()
|
||||||
|
result = router.resolve(" @team:analyst 任务内容")
|
||||||
|
assert result.matched is True
|
||||||
|
assert result.team_mode is True
|
||||||
|
assert result.specified_experts == ["analyst"]
|
||||||
|
assert result.task_content == "任务内容"
|
||||||
|
|
||||||
|
def test_invalid_expert_names_rejected(self):
|
||||||
|
"""无效专家名被过滤掉"""
|
||||||
|
router = ExpertTeamRouter()
|
||||||
|
# 包含特殊字符的名称应被过滤
|
||||||
|
result = router.resolve("@team:analyst,inva!id,bad/name 任务")
|
||||||
|
# 仅保留合法名称
|
||||||
|
assert "analyst" in result.specified_experts
|
||||||
|
assert "inva!id" not in result.specified_experts
|
||||||
|
assert "bad/name" not in result.specified_experts
|
||||||
|
|
||||||
|
def test_max_experts_limit(self):
|
||||||
|
"""指定超过 MAX_EXPERTS 数量的专家时截断"""
|
||||||
|
router = ExpertTeamRouter()
|
||||||
|
# 构造 15 个专家名
|
||||||
|
names = ",".join(f"expert{i}" for i in range(15))
|
||||||
|
result = router.resolve(f"@team:{names} 任务")
|
||||||
|
from agentkit.experts.router import MAX_EXPERTS
|
||||||
|
|
||||||
|
assert len(result.specified_experts) == MAX_EXPERTS
|
||||||
|
|
||||||
|
|
||||||
# ── ExpertTeamRouter.resolve_expert_configs 测试 ───────────
|
# ── ExpertTeamRouter.resolve_expert_configs 测试 ───────────
|
||||||
|
|
||||||
|
|
@ -275,6 +268,25 @@ class TestExpertTeamRouterResolveExpertConfigs:
|
||||||
configs = router.resolve_expert_configs(["analyst"])
|
configs = router.resolve_expert_configs(["analyst"])
|
||||||
assert configs[0].bound_skills == ["data_query"]
|
assert configs[0].bound_skills == ["data_query"]
|
||||||
|
|
||||||
|
def test_resolve_first_expert_is_lead(self):
|
||||||
|
"""第一个专家被指定为 lead"""
|
||||||
|
registry = _make_registry_with_templates()
|
||||||
|
router = ExpertTeamRouter(template_registry=registry)
|
||||||
|
|
||||||
|
configs = router.resolve_expert_configs(["analyst", "strategist", "reviewer"])
|
||||||
|
assert configs[0].is_lead is True
|
||||||
|
assert configs[1].is_lead is False
|
||||||
|
assert configs[2].is_lead is False
|
||||||
|
|
||||||
|
def test_resolve_skips_invalid_names(self):
|
||||||
|
"""跳过无效的专家名"""
|
||||||
|
router = ExpertTeamRouter()
|
||||||
|
# 包含无效字符的名称应被跳过
|
||||||
|
configs = router.resolve_expert_configs(["valid_name", "inva!id", "also_valid"])
|
||||||
|
assert len(configs) == 2
|
||||||
|
assert configs[0].name == "valid_name"
|
||||||
|
assert configs[1].name == "also_valid"
|
||||||
|
|
||||||
|
|
||||||
# ── ExpertTeamRouter 构造测试 ─────────────────────────────
|
# ── ExpertTeamRouter 构造测试 ─────────────────────────────
|
||||||
|
|
||||||
|
|
@ -293,7 +305,40 @@ class TestExpertTeamRouterInit:
|
||||||
router = ExpertTeamRouter(template_registry=registry)
|
router = ExpertTeamRouter(template_registry=registry)
|
||||||
assert router._registry is registry
|
assert router._registry is registry
|
||||||
|
|
||||||
def test_complexity_threshold(self):
|
|
||||||
"""复杂度阈值默认为 0.7"""
|
# ── ExpertTeamRouter.can_handle 测试 ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestExpertTeamRouterCanHandle:
|
||||||
|
"""ExpertTeamRouter.can_handle 方法测试"""
|
||||||
|
|
||||||
|
def test_can_handle_with_matching_template_name(self):
|
||||||
|
"""模板名出现在内容中时返回 True"""
|
||||||
|
registry = _make_registry_with_templates()
|
||||||
|
router = ExpertTeamRouter(template_registry=registry)
|
||||||
|
|
||||||
|
assert router.can_handle("请 analyst 帮我分析") is True
|
||||||
|
|
||||||
|
def test_can_handle_with_matching_description(self):
|
||||||
|
"""模板描述词出现在内容中时返回 True"""
|
||||||
|
registry = _make_registry_with_templates()
|
||||||
|
router = ExpertTeamRouter(template_registry=registry)
|
||||||
|
|
||||||
|
# 描述为 "analyst 模板" 等,"模板" 长度 > 2 应匹配
|
||||||
|
assert router.can_handle("我需要一个模板来参考") is True
|
||||||
|
|
||||||
|
def test_can_handle_no_templates(self):
|
||||||
|
"""无注册模板时返回 False"""
|
||||||
router = ExpertTeamRouter()
|
router = ExpertTeamRouter()
|
||||||
assert router.COMPLEXITY_THRESHOLD == 0.7
|
# 默认注册中心可能为空
|
||||||
|
# 强制清空以测试
|
||||||
|
router._registry._templates.clear()
|
||||||
|
assert router.can_handle("任何内容") is False
|
||||||
|
|
||||||
|
def test_can_handle_with_templates_no_match(self):
|
||||||
|
"""有模板但内容不匹配时仍返回 True(auto-compose 可组建团队)"""
|
||||||
|
registry = _make_registry_with_templates()
|
||||||
|
router = ExpertTeamRouter(template_registry=registry)
|
||||||
|
|
||||||
|
# 内容与任何模板名/描述都不匹配,但有模板存在 → auto-compose 可用
|
||||||
|
assert router.can_handle("完全无关的内容 xyz123") is True
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""ExpertTeam 容器单元测试"""
|
"""ExpertTeam 容器单元测试 (hub-and-spoke 模式)"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
@ -11,11 +11,6 @@ from agentkit.core.handoff_transport import InProcessHandoffTransport
|
||||||
from agentkit.core.shared_workspace import SharedWorkspace
|
from agentkit.core.shared_workspace import SharedWorkspace
|
||||||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||||||
from agentkit.experts.expert import Expert
|
from agentkit.experts.expert import Expert
|
||||||
from agentkit.experts.plan import (
|
|
||||||
CollaborationPlan,
|
|
||||||
PlanPhase,
|
|
||||||
PlanStatus,
|
|
||||||
)
|
|
||||||
from agentkit.experts.registry import ExpertTemplateRegistry
|
from agentkit.experts.registry import ExpertTemplateRegistry
|
||||||
from agentkit.experts.team import ExpertTeam, TeamStatus
|
from agentkit.experts.team import ExpertTeam, TeamStatus
|
||||||
|
|
||||||
|
|
@ -84,27 +79,6 @@ def _make_mock_expert(
|
||||||
return expert
|
return expert
|
||||||
|
|
||||||
|
|
||||||
def _make_valid_plan(
|
|
||||||
plan_id: str = "plan_1",
|
|
||||||
task: str = "测试任务",
|
|
||||||
lead_expert: str = "lead",
|
|
||||||
) -> CollaborationPlan:
|
|
||||||
"""创建有效的 CollaborationPlan"""
|
|
||||||
return CollaborationPlan(
|
|
||||||
id=plan_id,
|
|
||||||
task=task,
|
|
||||||
phases=[
|
|
||||||
PlanPhase(
|
|
||||||
id="phase_1",
|
|
||||||
name="阶段1",
|
|
||||||
assigned_expert=lead_expert,
|
|
||||||
task_description="执行任务",
|
|
||||||
)
|
|
||||||
],
|
|
||||||
lead_expert=lead_expert,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── ExpertTeam 创建测试 ───────────────────────────────────
|
# ── ExpertTeam 创建测试 ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -119,7 +93,6 @@ class TestExpertTeamCreation:
|
||||||
assert len(team.team_id) > 0
|
assert len(team.team_id) > 0
|
||||||
assert team.status == TeamStatus.FORMING
|
assert team.status == TeamStatus.FORMING
|
||||||
assert team.lead_expert is None
|
assert team.lead_expert is None
|
||||||
assert team.plan is None
|
|
||||||
assert team.experts == []
|
assert team.experts == []
|
||||||
assert team.active_experts == []
|
assert team.active_experts == []
|
||||||
|
|
||||||
|
|
@ -173,7 +146,7 @@ class TestExpertTeamCreateTeam:
|
||||||
|
|
||||||
assert team._lead_expert_name == "lead"
|
assert team._lead_expert_name == "lead"
|
||||||
assert team.lead_expert is mock_expert
|
assert team.lead_expert is mock_expert
|
||||||
assert team.status == TeamStatus.PLANNING
|
assert team.status == TeamStatus.EXECUTING
|
||||||
assert mock_expert.team_id == team.team_id
|
assert mock_expert.team_id == team.team_id
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -195,7 +168,7 @@ class TestExpertTeamCreateTeam:
|
||||||
|
|
||||||
assert len(team.experts) == 2
|
assert len(team.experts) == 2
|
||||||
assert team._lead_expert_name == "lead"
|
assert team._lead_expert_name == "lead"
|
||||||
assert team.status == TeamStatus.PLANNING
|
assert team.status == TeamStatus.EXECUTING
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_team_without_pool_raises(self):
|
async def test_create_team_without_pool_raises(self):
|
||||||
|
|
@ -411,64 +384,6 @@ class TestExpertTeamRemoveExpert:
|
||||||
assert last_left[0][1]["expert_name"] == "member1"
|
assert last_left[0][1]["expert_name"] == "member1"
|
||||||
|
|
||||||
|
|
||||||
# ── ExpertTeam.update_plan 测试 ────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestExpertTeamUpdatePlan:
|
|
||||||
"""ExpertTeam.update_plan 协作计划更新测试"""
|
|
||||||
|
|
||||||
def test_update_plan_with_valid_plan(self):
|
|
||||||
"""有效计划更新成功,返回受影响的 Expert 名称"""
|
|
||||||
team = ExpertTeam()
|
|
||||||
plan = _make_valid_plan(lead_expert="lead")
|
|
||||||
|
|
||||||
affected = team.update_plan(plan)
|
|
||||||
|
|
||||||
assert team.plan is plan
|
|
||||||
assert "lead" in affected
|
|
||||||
|
|
||||||
def test_update_plan_confirmed_sets_executing(self):
|
|
||||||
"""CONFIRMED 状态的计划将团队状态设为 EXECUTING"""
|
|
||||||
team = ExpertTeam()
|
|
||||||
plan = _make_valid_plan(lead_expert="lead")
|
|
||||||
plan.status = PlanStatus.CONFIRMED
|
|
||||||
|
|
||||||
team.update_plan(plan)
|
|
||||||
|
|
||||||
assert team.status == TeamStatus.EXECUTING
|
|
||||||
|
|
||||||
def test_update_plan_with_invalid_plan_no_update(self):
|
|
||||||
"""无效计划(validate 返回错误)不更新,返回验证错误列表"""
|
|
||||||
team = ExpertTeam()
|
|
||||||
# 创建有循环依赖的无效计划
|
|
||||||
plan = CollaborationPlan(
|
|
||||||
id="bad_plan",
|
|
||||||
task="无效任务",
|
|
||||||
phases=[
|
|
||||||
PlanPhase(
|
|
||||||
id="p1",
|
|
||||||
name="阶段1",
|
|
||||||
assigned_expert="a",
|
|
||||||
task_description="t1",
|
|
||||||
depends_on=["p2"],
|
|
||||||
),
|
|
||||||
PlanPhase(
|
|
||||||
id="p2",
|
|
||||||
name="阶段2",
|
|
||||||
assigned_expert="b",
|
|
||||||
task_description="t2",
|
|
||||||
depends_on=["p1"],
|
|
||||||
),
|
|
||||||
],
|
|
||||||
lead_expert="lead",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = team.update_plan(plan)
|
|
||||||
|
|
||||||
assert len(result) > 0 # 返回验证错误而非空列表
|
|
||||||
assert team.plan is None # 未更新
|
|
||||||
|
|
||||||
|
|
||||||
# ── ExpertTeam.broadcast_user_message 测试 ─────────────────
|
# ── ExpertTeam.broadcast_user_message 测试 ─────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -535,42 +450,6 @@ class TestExpertTeamGetSharedContext:
|
||||||
assert context == {}
|
assert context == {}
|
||||||
|
|
||||||
|
|
||||||
# ── ExpertTeam.generate_plan 测试 ──────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestExpertTeamGeneratePlan:
|
|
||||||
"""ExpertTeam.generate_plan 计划生成测试"""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_generate_plan(self):
|
|
||||||
"""生成空的 CollaborationPlan"""
|
|
||||||
pool = _make_mock_pool()
|
|
||||||
team = ExpertTeam(pool=pool)
|
|
||||||
|
|
||||||
lead_config = _make_expert_config(name="lead", is_lead=True)
|
|
||||||
with patch.object(Expert, "create", new_callable=AsyncMock) as mock_create:
|
|
||||||
lead_expert = _make_mock_expert(name="lead", is_lead=True)
|
|
||||||
mock_create.return_value = lead_expert
|
|
||||||
await team.create_team(lead_config)
|
|
||||||
|
|
||||||
plan = await team.generate_plan("分析数据")
|
|
||||||
|
|
||||||
assert plan is not None
|
|
||||||
assert plan.task == "分析数据"
|
|
||||||
assert plan.lead_expert == "lead"
|
|
||||||
assert plan.phases == []
|
|
||||||
assert team.plan is plan
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_generate_plan_without_lead(self):
|
|
||||||
"""没有 Lead Expert 时生成计划,lead_expert 为空字符串"""
|
|
||||||
team = ExpertTeam()
|
|
||||||
|
|
||||||
plan = await team.generate_plan("测试任务")
|
|
||||||
|
|
||||||
assert plan.lead_expert == ""
|
|
||||||
|
|
||||||
|
|
||||||
# ── ExpertTeam.dissolve 测试 ───────────────────────────────
|
# ── ExpertTeam.dissolve 测试 ───────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -667,11 +546,6 @@ class TestExpertTeamDissolvedOperations:
|
||||||
# 解散后状态为 DISSOLVED
|
# 解散后状态为 DISSOLVED
|
||||||
assert team.status == TeamStatus.DISSOLVED
|
assert team.status == TeamStatus.DISSOLVED
|
||||||
|
|
||||||
# 再次 create_team 时,由于 experts 已清空,
|
|
||||||
# 但 pool 仍然存在,理论上可以重新创建
|
|
||||||
# 但这里验证状态是 DISSOLVED
|
|
||||||
assert team.status == TeamStatus.DISSOLVED
|
|
||||||
|
|
||||||
|
|
||||||
# ── ExpertTeam.lead_expert 属性测试 ────────────────────────
|
# ── ExpertTeam.lead_expert 属性测试 ────────────────────────
|
||||||
|
|
||||||
|
|
@ -742,10 +616,11 @@ class TestExpertTeamBuildContext:
|
||||||
|
|
||||||
context = team._build_team_context(lead_config, [member_config])
|
context = team._build_team_context(lead_config, [member_config])
|
||||||
|
|
||||||
assert "You are part of an Expert Team." in context
|
assert "hub-and-spoke mode" in context
|
||||||
assert "Lead Expert: lead (领导者)" in context
|
assert "Lead Expert (hub): lead (领导者)" in context
|
||||||
assert "Team Member: analyst (分析师), Skills: data_query" in context
|
assert "Team Member (spoke): analyst (分析师)" in context
|
||||||
assert "send_message() and request_assist()" in context
|
assert "data_query" in context
|
||||||
|
assert "depth=1" in context
|
||||||
|
|
||||||
def test_build_team_context_no_lead(self):
|
def test_build_team_context_no_lead(self):
|
||||||
"""没有 Lead Expert 时构建上下文"""
|
"""没有 Lead Expert 时构建上下文"""
|
||||||
|
|
@ -754,8 +629,9 @@ class TestExpertTeamBuildContext:
|
||||||
|
|
||||||
context = team._build_team_context(None, [member_config])
|
context = team._build_team_context(None, [member_config])
|
||||||
|
|
||||||
assert "Lead Expert" not in context
|
# 不应出现具体的 Lead Expert (hub): name 行
|
||||||
assert "Team Member: analyst" in context
|
assert "Lead Expert (hub):" not in context
|
||||||
|
assert "Team Member (spoke): analyst" in context
|
||||||
|
|
||||||
def test_build_team_context_skips_lead_in_members(self):
|
def test_build_team_context_skips_lead_in_members(self):
|
||||||
"""成员列表中包含 Lead 时跳过"""
|
"""成员列表中包含 Lead 时跳过"""
|
||||||
|
|
@ -765,4 +641,4 @@ class TestExpertTeamBuildContext:
|
||||||
context = team._build_team_context(lead_config, [lead_config])
|
context = team._build_team_context(lead_config, [lead_config])
|
||||||
|
|
||||||
# Lead 不应出现在 Team Member 行
|
# Lead 不应出现在 Team Member 行
|
||||||
assert "Team Member: lead" not in context
|
assert "Team Member (spoke): lead" not in context
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,472 @@
|
||||||
|
"""ToolSearchIndex / ToolSearchTool / ReAct 分层注入单元测试
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
- ToolSearchIndex: 空索引、单工具、多工具相关性排序、top_k 限制、无匹配
|
||||||
|
- ToolSearchTool: 正常查询、空查询、无匹配、结果包含完整描述
|
||||||
|
- ReActEngine 分层注入: core/extended 分离、tool_search 自动添加、禁用配置、自定义 core 列表
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
from agentkit.tools.builtin import ToolSearchTool
|
||||||
|
from agentkit.tools.search import ToolSearchIndex
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test Helpers ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class FakeTool(Tool):
|
||||||
|
"""用于测试的 Fake Tool"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
input_schema=input_schema,
|
||||||
|
tags=tags or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tools() -> list[Tool]:
|
||||||
|
"""创建一组测试工具"""
|
||||||
|
return [
|
||||||
|
FakeTool(
|
||||||
|
name="read_file",
|
||||||
|
description="Read the contents of a file from the filesystem.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {"type": "string", "description": "file path to read"},
|
||||||
|
},
|
||||||
|
"required": ["path"],
|
||||||
|
},
|
||||||
|
tags=["io", "file"],
|
||||||
|
),
|
||||||
|
FakeTool(
|
||||||
|
name="write_file",
|
||||||
|
description="Write content to a file on the filesystem.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {"type": "string", "description": "file path to write"},
|
||||||
|
"content": {"type": "string", "description": "content to write"},
|
||||||
|
},
|
||||||
|
"required": ["path", "content"],
|
||||||
|
},
|
||||||
|
tags=["io", "file"],
|
||||||
|
),
|
||||||
|
FakeTool(
|
||||||
|
name="web_search",
|
||||||
|
description="Search the web for information using a search engine.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "search query"},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
tags=["web", "search"],
|
||||||
|
),
|
||||||
|
FakeTool(
|
||||||
|
name="run_tests",
|
||||||
|
description="Run project tests to verify code changes.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"commands": {"type": "array", "description": "test commands"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
tags=["testing", "verification"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── ToolSearchIndex Tests ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolSearchIndex:
|
||||||
|
"""ToolSearchIndex BM25 搜索测试"""
|
||||||
|
|
||||||
|
def test_empty_tools(self):
|
||||||
|
"""空工具列表构建索引不报错"""
|
||||||
|
index = ToolSearchIndex([])
|
||||||
|
assert len(index) == 0
|
||||||
|
assert index.search("anything") == []
|
||||||
|
|
||||||
|
def test_single_tool_match(self):
|
||||||
|
"""单工具索引,匹配查询返回该工具"""
|
||||||
|
tools = _make_tools()[:1]
|
||||||
|
index = ToolSearchIndex(tools)
|
||||||
|
results = index.search("read file")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].name == "read_file"
|
||||||
|
|
||||||
|
def test_relevance_ranking(self):
|
||||||
|
"""多工具索引,相关工具排在前面"""
|
||||||
|
tools = _make_tools()
|
||||||
|
index = ToolSearchIndex(tools)
|
||||||
|
results = index.search("web search")
|
||||||
|
assert len(results) > 0
|
||||||
|
# web_search 应该排在最前
|
||||||
|
assert results[0].name == "web_search"
|
||||||
|
|
||||||
|
def test_top_k_limit(self):
|
||||||
|
"""top_k 限制返回数量"""
|
||||||
|
tools = _make_tools()
|
||||||
|
index = ToolSearchIndex(tools)
|
||||||
|
results = index.search("file", top_k=2)
|
||||||
|
assert len(results) <= 2
|
||||||
|
|
||||||
|
def test_no_match_returns_empty(self):
|
||||||
|
"""无匹配时返回空列表"""
|
||||||
|
tools = _make_tools()
|
||||||
|
index = ToolSearchIndex(tools)
|
||||||
|
results = index.search("xyzzy_nonexistent")
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
def test_empty_query_returns_empty(self):
|
||||||
|
"""空查询返回空列表"""
|
||||||
|
tools = _make_tools()
|
||||||
|
index = ToolSearchIndex(tools)
|
||||||
|
assert index.search("") == []
|
||||||
|
assert index.search(" ") == []
|
||||||
|
|
||||||
|
def test_top_k_zero_returns_empty(self):
|
||||||
|
"""top_k=0 返回空列表"""
|
||||||
|
tools = _make_tools()
|
||||||
|
index = ToolSearchIndex(tools)
|
||||||
|
assert index.search("file", top_k=0) == []
|
||||||
|
|
||||||
|
def test_snake_case_tokenization(self):
|
||||||
|
"""snake_case 工具名被正确分词"""
|
||||||
|
tool = FakeTool(
|
||||||
|
name="read_file",
|
||||||
|
description="Read file contents.",
|
||||||
|
)
|
||||||
|
index = ToolSearchIndex([tool])
|
||||||
|
# 搜索 "read" 应该匹配
|
||||||
|
results = index.search("read")
|
||||||
|
assert len(results) == 1
|
||||||
|
# 搜索 "file" 也应该匹配
|
||||||
|
results = index.search("file")
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
def test_search_includes_parameter_descriptions(self):
|
||||||
|
"""搜索能匹配参数描述中的关键词"""
|
||||||
|
tool = FakeTool(
|
||||||
|
name="custom_tool",
|
||||||
|
description="A custom tool.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"database_url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "PostgreSQL connection string",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
index = ToolSearchIndex([tool])
|
||||||
|
results = index.search("postgresql database")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].name == "custom_tool"
|
||||||
|
|
||||||
|
def test_search_includes_tags(self):
|
||||||
|
"""搜索能匹配标签中的关键词"""
|
||||||
|
tool = FakeTool(
|
||||||
|
name="data_tool",
|
||||||
|
description="Process data.",
|
||||||
|
tags=["etl", "pipeline"],
|
||||||
|
)
|
||||||
|
index = ToolSearchIndex([tool])
|
||||||
|
results = index.search("pipeline etl")
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].name == "data_tool"
|
||||||
|
|
||||||
|
def test_invalid_k1_raises(self):
|
||||||
|
"""k1 < 0 抛出 ValueError"""
|
||||||
|
with pytest.raises(ValueError, match="k1"):
|
||||||
|
ToolSearchIndex(_make_tools(), k1=-1.0)
|
||||||
|
|
||||||
|
def test_invalid_b_raises(self):
|
||||||
|
"""b 不在 [0,1] 范围抛出 ValueError"""
|
||||||
|
with pytest.raises(ValueError, match="b"):
|
||||||
|
ToolSearchIndex(_make_tools(), b=1.5)
|
||||||
|
with pytest.raises(ValueError, match="b"):
|
||||||
|
ToolSearchIndex(_make_tools(), b=-0.1)
|
||||||
|
|
||||||
|
def test_multiple_results_sorted_by_score(self):
|
||||||
|
"""多个匹配结果按分数降序排列"""
|
||||||
|
tools = [
|
||||||
|
FakeTool(name="search_web", description="Search the web."),
|
||||||
|
FakeTool(name="search_files", description="Search files on disk."),
|
||||||
|
FakeTool(name="unrelated", description="Do something unrelated."),
|
||||||
|
]
|
||||||
|
index = ToolSearchIndex(tools)
|
||||||
|
results = index.search("search")
|
||||||
|
# 两个包含 "search" 的工具应该返回,unrelated 不返回
|
||||||
|
assert len(results) == 2
|
||||||
|
names = [r.name for r in results]
|
||||||
|
assert "unrelated" not in names
|
||||||
|
|
||||||
|
|
||||||
|
# ── ToolSearchTool Tests ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolSearchTool:
|
||||||
|
"""ToolSearchTool 工具测试"""
|
||||||
|
|
||||||
|
def test_tool_name_and_schema(self):
|
||||||
|
"""工具名称和 schema 正确"""
|
||||||
|
index = ToolSearchIndex(_make_tools())
|
||||||
|
tool = ToolSearchTool(search_index=index)
|
||||||
|
assert tool.name == "tool_search"
|
||||||
|
assert "query" in tool.input_schema["properties"]
|
||||||
|
assert "query" in tool.input_schema["required"]
|
||||||
|
|
||||||
|
async def test_execute_returns_results(self):
|
||||||
|
"""执行搜索返回匹配工具的完整描述"""
|
||||||
|
index = ToolSearchIndex(_make_tools())
|
||||||
|
tool = ToolSearchTool(search_index=index)
|
||||||
|
result = await tool.execute(query="web search")
|
||||||
|
|
||||||
|
assert result["count"] > 0
|
||||||
|
assert result["query"] == "web search"
|
||||||
|
first = result["results"][0]
|
||||||
|
assert "name" in first
|
||||||
|
assert "description" in first
|
||||||
|
assert "parameters" in first
|
||||||
|
assert first["name"] == "web_search"
|
||||||
|
|
||||||
|
async def test_execute_empty_query_returns_error(self):
|
||||||
|
"""空查询返回错误"""
|
||||||
|
index = ToolSearchIndex(_make_tools())
|
||||||
|
tool = ToolSearchTool(search_index=index)
|
||||||
|
result = await tool.execute(query="")
|
||||||
|
assert "error" in result
|
||||||
|
assert result["results"] == []
|
||||||
|
|
||||||
|
async def test_execute_no_match(self):
|
||||||
|
"""无匹配返回空结果和提示消息"""
|
||||||
|
index = ToolSearchIndex(_make_tools())
|
||||||
|
tool = ToolSearchTool(search_index=index)
|
||||||
|
result = await tool.execute(query="zzz_nonexistent")
|
||||||
|
assert result["count"] == 0
|
||||||
|
assert result["results"] == []
|
||||||
|
assert "message" in result
|
||||||
|
|
||||||
|
async def test_execute_respects_top_k(self):
|
||||||
|
"""top_k 限制返回数量"""
|
||||||
|
index = ToolSearchIndex(_make_tools())
|
||||||
|
tool = ToolSearchTool(search_index=index, top_k=1)
|
||||||
|
result = await tool.execute(query="file")
|
||||||
|
assert result["count"] <= 1
|
||||||
|
|
||||||
|
def test_invalid_top_k_raises(self):
|
||||||
|
"""top_k < 1 抛出 ValueError"""
|
||||||
|
index = ToolSearchIndex(_make_tools())
|
||||||
|
with pytest.raises(ValueError, match="top_k"):
|
||||||
|
ToolSearchTool(search_index=index, top_k=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ── ReActEngine Tiered Injection Tests ────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestReActTieredInjection:
|
||||||
|
"""ReActEngine 工具描述分层注入测试"""
|
||||||
|
|
||||||
|
def _make_engine(self, **kwargs: Any):
|
||||||
|
from agentkit.core.react import ReActEngine
|
||||||
|
|
||||||
|
gateway = MagicMock()
|
||||||
|
return ReActEngine(llm_gateway=gateway, **kwargs)
|
||||||
|
|
||||||
|
def test_core_tools_get_full_description(self):
|
||||||
|
"""Core 工具注入完整描述(含参数)"""
|
||||||
|
engine = self._make_engine()
|
||||||
|
tools = [
|
||||||
|
FakeTool(
|
||||||
|
name="read_file",
|
||||||
|
description="Read a file.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {"type": "string", "description": "file path"},
|
||||||
|
},
|
||||||
|
"required": ["path"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
prompt = engine._build_tool_use_prompt(tools)
|
||||||
|
# 核心工具区域存在
|
||||||
|
assert "核心工具" in prompt
|
||||||
|
# 参数描述被注入
|
||||||
|
assert "path" in prompt
|
||||||
|
assert "file path" in prompt
|
||||||
|
|
||||||
|
def test_extended_tools_get_one_line_only(self):
|
||||||
|
"""Extended 工具只注入名称+一行描述(无参数)"""
|
||||||
|
engine = self._make_engine()
|
||||||
|
tools = [
|
||||||
|
FakeTool(
|
||||||
|
name="custom_extended",
|
||||||
|
description="A custom extended tool for testing.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"secret_param": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "SECRET_PARAM_DESC",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
prompt = engine._build_tool_use_prompt(tools)
|
||||||
|
assert "扩展工具" in prompt
|
||||||
|
assert "custom_extended" in prompt
|
||||||
|
# 参数描述不应出现在扩展工具区域
|
||||||
|
assert "SECRET_PARAM_DESC" not in prompt
|
||||||
|
|
||||||
|
def test_mixed_core_and_extended(self):
|
||||||
|
"""混合 core + extended 工具,两者分区显示"""
|
||||||
|
engine = self._make_engine()
|
||||||
|
tools = [
|
||||||
|
FakeTool(name="read_file", description="Read a file."),
|
||||||
|
FakeTool(
|
||||||
|
name="web_search",
|
||||||
|
description="Search the web.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"q": {"type": "string", "description": "query string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
prompt = engine._build_tool_use_prompt(tools)
|
||||||
|
assert "核心工具" in prompt
|
||||||
|
assert "扩展工具" in prompt
|
||||||
|
# read_file 在核心区,web_search 在扩展区
|
||||||
|
assert "read_file" in prompt
|
||||||
|
assert "web_search" in prompt
|
||||||
|
|
||||||
|
def test_maybe_add_tool_search_adds_for_extended(self):
|
||||||
|
"""有扩展工具时自动添加 tool_search"""
|
||||||
|
engine = self._make_engine()
|
||||||
|
tools = [
|
||||||
|
FakeTool(name="read_file", description="Read a file."), # core
|
||||||
|
FakeTool(name="web_search", description="Search the web."), # extended
|
||||||
|
]
|
||||||
|
result = engine._maybe_add_tool_search(tools)
|
||||||
|
assert len(result) == 3
|
||||||
|
assert any(t.name == "tool_search" for t in result)
|
||||||
|
|
||||||
|
def test_maybe_add_tool_search_skips_when_only_core(self):
|
||||||
|
"""只有 core 工具时不添加 tool_search"""
|
||||||
|
engine = self._make_engine()
|
||||||
|
tools = [
|
||||||
|
FakeTool(name="read_file", description="Read a file."),
|
||||||
|
FakeTool(name="write_file", description="Write a file."),
|
||||||
|
]
|
||||||
|
result = engine._maybe_add_tool_search(tools)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert not any(t.name == "tool_search" for t in result)
|
||||||
|
|
||||||
|
def test_maybe_add_tool_search_skips_when_disabled(self):
|
||||||
|
"""enable_tool_search=False 时不添加 tool_search"""
|
||||||
|
engine = self._make_engine(enable_tool_search=False)
|
||||||
|
tools = [
|
||||||
|
FakeTool(name="read_file", description="Read a file."),
|
||||||
|
FakeTool(name="web_search", description="Search the web."),
|
||||||
|
]
|
||||||
|
result = engine._maybe_add_tool_search(tools)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert not any(t.name == "tool_search" for t in result)
|
||||||
|
|
||||||
|
def test_maybe_add_tool_search_skips_when_already_present(self):
|
||||||
|
"""tool_search 已存在时不重复添加"""
|
||||||
|
engine = self._make_engine()
|
||||||
|
index = ToolSearchIndex([])
|
||||||
|
existing_search = ToolSearchTool(search_index=index)
|
||||||
|
tools = [
|
||||||
|
FakeTool(name="read_file", description="Read a file."),
|
||||||
|
FakeTool(name="web_search", description="Search the web."),
|
||||||
|
existing_search,
|
||||||
|
]
|
||||||
|
result = engine._maybe_add_tool_search(tools)
|
||||||
|
assert len(result) == 3
|
||||||
|
|
||||||
|
def test_custom_core_tool_names(self):
|
||||||
|
"""自定义 core_tool_names 覆盖默认值"""
|
||||||
|
engine = self._make_engine(core_tool_names=["my_core_tool"])
|
||||||
|
tools = [
|
||||||
|
FakeTool(name="my_core_tool", description="My core tool."),
|
||||||
|
FakeTool(
|
||||||
|
name="read_file",
|
||||||
|
description="Read a file.",
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"p": {"type": "string", "description": "PARAM_DESC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
prompt = engine._build_tool_use_prompt(tools)
|
||||||
|
# my_core_tool 在核心区
|
||||||
|
assert "核心工具" in prompt
|
||||||
|
# read_file 现在是扩展工具,参数描述不应出现
|
||||||
|
assert "PARAM_DESC" not in prompt
|
||||||
|
|
||||||
|
def test_tool_search_is_core_tool(self):
|
||||||
|
"""tool_search 被视为 core 工具(全量描述注入)"""
|
||||||
|
engine = self._make_engine()
|
||||||
|
index = ToolSearchIndex([])
|
||||||
|
search_tool = ToolSearchTool(search_index=index)
|
||||||
|
tools = [search_tool]
|
||||||
|
prompt = engine._build_tool_use_prompt(tools)
|
||||||
|
# tool_search 应该在核心工具区
|
||||||
|
assert "核心工具" in prompt
|
||||||
|
assert "tool_search" in prompt
|
||||||
|
# 其参数 query 应该被注入
|
||||||
|
assert "query" in prompt
|
||||||
|
|
||||||
|
def test_search_hint_added_when_tool_search_present(self):
|
||||||
|
"""tool_search 存在且有扩展工具时添加搜索提示"""
|
||||||
|
engine = self._make_engine()
|
||||||
|
tools = [
|
||||||
|
FakeTool(name="read_file", description="Read a file."),
|
||||||
|
FakeTool(name="web_search", description="Search the web."),
|
||||||
|
]
|
||||||
|
# 模拟 _maybe_add_tool_search 添加 tool_search
|
||||||
|
tools_with_search = engine._maybe_add_tool_search(tools)
|
||||||
|
prompt = engine._build_tool_use_prompt(tools_with_search)
|
||||||
|
assert "tool_search" in prompt
|
||||||
|
assert "扩展工具" in prompt
|
||||||
|
# 搜索提示存在
|
||||||
|
assert "tool_search(query" in prompt
|
||||||
|
|
||||||
|
def test_no_search_hint_when_no_extended_tools(self):
|
||||||
|
"""无扩展工具时不添加搜索提示"""
|
||||||
|
engine = self._make_engine()
|
||||||
|
tools = [
|
||||||
|
FakeTool(name="read_file", description="Read a file."),
|
||||||
|
]
|
||||||
|
prompt = engine._build_tool_use_prompt(tools)
|
||||||
|
assert "扩展工具" not in prompt
|
||||||
Loading…
Reference in New Issue