From bbedfff5972376f8eae9d65fdc8545305a83e5ea Mon Sep 17 00:00:00 2001 From: chiguyong Date: Wed, 17 Jun 2026 10:46:16 +0800 Subject: [PATCH] feat: hub-and-spoke experts, tiered tool injection, unified event model (U3/U7/U10) --- src/agentkit/core/__init__.py | 13 + src/agentkit/core/event_queue.py | 244 +++++ src/agentkit/core/protocol.py | 107 ++ src/agentkit/experts/__init__.py | 19 +- src/agentkit/experts/config.py | 2 +- src/agentkit/experts/expert.py | 8 +- src/agentkit/experts/orchestrator.py | 1004 +++++++----------- src/agentkit/experts/plan.py | 293 ++--- src/agentkit/experts/registry.py | 12 +- src/agentkit/experts/team.py | 124 +-- src/agentkit/tools/__init__.py | 5 + src/agentkit/tools/search.py | 189 ++++ tests/unit/core/test_event_queue.py | 666 ++++++++++++ tests/unit/experts/test_plan.py | 460 ++++---- tests/unit/experts/test_router.py | 131 ++- tests/unit/experts/test_team.py | 148 +-- tests/unit/experts/test_team_orchestrator.py | 925 +++++++--------- tests/unit/tools/test_tool_search.py | 472 ++++++++ 18 files changed, 2953 insertions(+), 1869 deletions(-) create mode 100644 src/agentkit/core/event_queue.py create mode 100644 src/agentkit/tools/search.py create mode 100644 tests/unit/core/test_event_queue.py create mode 100644 tests/unit/tools/test_tool_search.py diff --git a/src/agentkit/core/__init__.py b/src/agentkit/core/__init__.py index ea1ffb9..3cde926 100644 --- a/src/agentkit/core/__init__.py +++ b/src/agentkit/core/__init__.py @@ -3,6 +3,7 @@ from agentkit.core.base import BaseAgent from agentkit.core.compressor import CompressionStrategy, ContextCompressor, create_compressor from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent +from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue from agentkit.core.exceptions import ( AgentAlreadyRegisteredError, AgentFrameworkError, @@ -29,12 +30,16 @@ from agentkit.core.protocol import ( AgentCapability, AgentStatus, CancellationToken, + Event, EvolutionEvent, HandoffMessage, + SessionEventType, + TaskEventType, TaskMessage, TaskProgress, TaskResult, TaskStatus, + TurnEventType, ) # Optional: HeadroomCompressor — only available when headroom-ai is installed @@ -80,4 +85,12 @@ __all__ = [ "TaskProgress", "TaskResult", "TaskStatus", + # SQ/EQ 统一事件模型 + "Event", + "SessionEventType", + "TaskEventType", + "TurnEventType", + "Submission", + "SubmissionQueue", + "EventQueue", ] diff --git a/src/agentkit/core/event_queue.py b/src/agentkit/core/event_queue.py new file mode 100644 index 0000000..887e208 --- /dev/null +++ b/src/agentkit/core/event_queue.py @@ -0,0 +1,244 @@ +"""统一事件模型 - SQ/EQ 双队列实现 + +参考 Codex 的 Session/Task/Turn 三级模型,提供: +- SubmissionQueue (SQ): 接收用户输入,返回 task_id +- EventQueue (EQ): 推送 Agent 事件,支持多订阅者广播和事件缓冲回放 + +集成点(Phase 4 再做实际集成): +- Portal WebSocket 可通过 EventQueue.subscribe() 订阅事件流并推送给前端 +- CLI 可通过 EventQueue.subscribe() 订阅事件流并打印 +""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import AsyncIterator + +from agentkit.core.protocol import Event + +logger = logging.getLogger(__name__) + + +# 哨兵对象:用于通知订阅者队列已关闭(基于身份比较,不参与事件流) +_CLOSED_SENTINEL: Event = Event( + event_type="__closed__", + task_id="", + session_id="", + data={}, + timestamp="", +) + + +@dataclass +class Submission: + """用户提交的任务 + + 由 SubmissionQueue.submit() 创建,消费者通过 SubmissionQueue.drain() 获取。 + + Attributes: + task_id: 唯一任务 ID + session_id: 会话 ID + content: 用户输入内容 + created_at: 创建时间 + cancelled: 是否已取消 + """ + + task_id: str + session_id: str + content: str + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + cancelled: bool = False + + +class SubmissionQueue: + """提交队列 (SQ) - 接收用户输入 + + 用户通过 submit() 提交输入,消费者通过 drain() 获取提交。 + 支持通过 task_id 取消提交。 + + 内部使用 asyncio.Queue 实现,队列大小上限 1024(与 HandoffTransport 一致)。 + """ + + _MAX_QUEUE_SIZE: int = 1024 + + def __init__(self) -> None: + self._queue: asyncio.Queue[Submission] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE) + self._submissions: dict[str, Submission] = {} + self._cancelled_tasks: set[str] = set() + self._closed: bool = False + + async def submit(self, content: str, session_id: str) -> str: + """提交用户输入,返回 task_id + + Args: + content: 用户输入内容 + session_id: 会话 ID + + Returns: + 新生成的 task_id(UUID4) + + Raises: + RuntimeError: 队列已关闭 + """ + if self._closed: + raise RuntimeError("SubmissionQueue is closed") + task_id = str(uuid.uuid4()) + submission = Submission( + task_id=task_id, + session_id=session_id, + content=content, + ) + self._submissions[task_id] = submission + await self._queue.put(submission) + return task_id + + async def drain(self) -> AsyncIterator[Submission]: + """消费提交队列(异步生成器) + + 已取消的提交会被跳过。当队列关闭时,生成器结束。 + """ + while True: + submission = await self._queue.get() + if submission.task_id in self._cancelled_tasks: + continue + yield submission + + async def cancel(self, task_id: str) -> bool: + """取消任务 + + Args: + task_id: 要取消的任务 ID + + Returns: + 是否成功取消(任务存在且未取消过) + """ + if task_id not in self._submissions: + return False + if task_id in self._cancelled_tasks: + return False + self._cancelled_tasks.add(task_id) + self._submissions[task_id].cancelled = True + return True + + @property + def is_closed(self) -> bool: + """返回队列是否已关闭""" + return self._closed + + def close(self) -> None: + """关闭队列,不再接受新提交。 + + 已在队列中的提交不受影响,消费者仍可通过 drain() 获取。 + """ + self._closed = True + + +class EventQueue: + """事件队列 (EQ) - 推送 Agent 事件 + + 支持多订阅者广播模式:每条事件会投递到所有活跃订阅者。 + 新订阅者会收到最近 N 条缓冲事件的回放(默认 100 条)。 + + 集成点: + - Portal WebSocket 可通过 subscribe() 订阅事件流并推送给前端 + - CLI 可通过 subscribe() 订阅事件流并打印 + """ + + _MAX_QUEUE_SIZE: int = 1024 + _DEFAULT_BUFFER_SIZE: int = 100 + + def __init__(self, buffer_size: int = _DEFAULT_BUFFER_SIZE) -> None: + self._subscribers: list[asyncio.Queue[Event]] = [] + self._buffer: deque[Event] = deque(maxlen=buffer_size) + self._buffer_size = buffer_size + self._closed: bool = False + + async def emit(self, event: Event) -> None: + """推送事件给所有订阅者 + + 事件会同时写入缓冲区(供未来订阅者回放)和所有活跃订阅者队列。 + 如果某订阅者队列已满,该事件对该订阅者被丢弃(不影响其他订阅者)。 + + Args: + event: 要推送的事件 + """ + self._buffer.append(event) + for queue in self._subscribers: + try: + queue.put_nowait(event) + except asyncio.QueueFull: + logger.warning("EventQueue subscriber queue full, dropping event") + + async def subscribe(self) -> AsyncIterator[Event]: + """订阅事件流(异步生成器) + + 订阅时会先回放缓冲区中的事件,然后持续接收新事件。 + 每个订阅者获得独立的队列,实现广播语义。 + + 当队列关闭时,生成器结束。 + + 注意:回放和加入订阅者列表在同一同步段内完成(无 await), + 保证不会遗漏或重复事件。 + """ + if self._closed: + return + + queue: asyncio.Queue[Event] = asyncio.Queue(maxsize=self._MAX_QUEUE_SIZE) + + # 回放缓冲事件(同步操作,无 await,保证原子性) + for event in list(self._buffer): + try: + queue.put_nowait(event) + except asyncio.QueueFull: + logger.warning("EventQueue replay buffer full, skipping remaining") + break + + # 加入订阅者列表(在回放之后,确保不会收到重复事件) + self._subscribers.append(queue) + + try: + while True: + event = await queue.get() + if event is _CLOSED_SENTINEL: + break + yield event + finally: + # 清理:移除当前订阅者的队列 + if queue in self._subscribers: + self._subscribers.remove(queue) + + @property + def subscriber_count(self) -> int: + """返回当前订阅者数量""" + return len(self._subscribers) + + @property + def buffer_size(self) -> int: + """返回缓冲区大小上限""" + return self._buffer_size + + @property + def is_closed(self) -> bool: + """返回队列是否已关闭""" + return self._closed + + def close(self) -> None: + """关闭队列,向所有订阅者发送哨兵并清理状态。 + + 使用哨兵模式(参考 HandoffTransport),确保阻塞在 subscribe() 上的 + 订阅者能够优雅退出。 + """ + self._closed = True + # 向所有活跃订阅者队列放入哨兵,使其能够优雅退出 + for queue in self._subscribers: + try: + queue.put_nowait(_CLOSED_SENTINEL) + except asyncio.QueueFull: + pass + self._subscribers.clear() + self._buffer.clear() diff --git a/src/agentkit/core/protocol.py b/src/agentkit/core/protocol.py index 6b5286c..99ac0aa 100644 --- a/src/agentkit/core/protocol.py +++ b/src/agentkit/core/protocol.py @@ -10,6 +10,7 @@ from agentkit.core.exceptions import TaskCancelledError class TaskStatus(str, Enum): """任务状态枚举""" + PENDING = "pending" RUNNING = "running" COMPLETED = "completed" @@ -21,6 +22,7 @@ class TaskStatus(str, Enum): class AgentStatus(str, Enum): """Agent 状态枚举""" + ONLINE = "online" OFFLINE = "offline" BUSY = "busy" @@ -29,6 +31,7 @@ class AgentStatus(str, Enum): @dataclass class AgentCapability: """Agent 能力声明""" + agent_name: str agent_type: str version: str @@ -70,6 +73,7 @@ class AgentCapability: @dataclass class TaskMessage: """任务消息 - 从调度器发往 Agent""" + task_id: str agent_name: str task_type: str @@ -114,6 +118,7 @@ class TaskMessage: @dataclass class TaskResult: """任务结果 - 从 Agent 返回""" + task_id: str agent_name: str status: str @@ -163,6 +168,7 @@ class TaskResult: @dataclass class TaskProgress: """进度上报 - Agent 执行过程中上报""" + task_id: str agent_name: str progress: float @@ -195,6 +201,7 @@ class TaskProgress: @dataclass class HandoffMessage: """任务转交消息 - Agent 间 Handoff""" + source_agent: str target_agent: str task_id: str @@ -233,6 +240,7 @@ class HandoffMessage: @dataclass class EvolutionEvent: """进化事件 - 记录 Agent 的自我进化变更""" + agent_name: str change_type: str # prompt / strategy / pipeline before: dict[str, Any] @@ -277,3 +285,102 @@ class CancellationToken: """检查是否已取消,若已取消则抛出 TaskCancelledError""" if self._cancelled: raise TaskCancelledError(task_id="") + + +# ── SQ/EQ 统一事件模型 ────────────────────────────────────────── +# 参考 Codex 的 Session/Task/Turn 三级模型,统一 CLI 和 WebSocket 事件流。 + + +class SessionEventType: + """Session 级别事件类型 + + 对应一次完整的会话生命周期(如用户打开 CLI、关闭 CLI)。 + """ + + SESSION_STARTED = "session.started" + SESSION_ENDED = "session.ended" + + +class TaskEventType: + """Task 级别事件类型 + + 对应一次任务提交(用户输入),从创建到终态(完成/失败)。 + """ + + TASK_CREATED = "task.created" + TASK_STARTED = "task.started" + TASK_COMPLETED = "task.completed" + TASK_FAILED = "task.failed" + + +class TurnEventType: + """Turn 级别事件类型 + + 对应一个 Task 内的多轮对话/推理步骤,包括思考、工具调用、Token 流等。 + """ + + TURN_STARTED = "turn.started" + THINKING = "turn.thinking" + TOOL_CALL = "turn.tool_call" + TOOL_RESULT = "turn.tool_result" + TOKEN = "turn.token" + STEP = "turn.step" + FINAL_ANSWER = "turn.final_answer" + TURN_COMPLETED = "turn.completed" + + +@dataclass +class Event: + """统一事件模型 - SQ/EQ 双队列事件 + + 所有 CLI 和 WebSocket 事件统一使用此结构,便于前端和 CLI 统一处理。 + + Attributes: + event_type: 事件类型(见 SessionEventType / TaskEventType / TurnEventType) + task_id: 关联的任务 ID + session_id: 关联的会话 ID + data: 事件负载数据 + timestamp: ISO 8601 格式时间戳 + """ + + event_type: str + task_id: str + session_id: str + data: dict[str, Any] + timestamp: str # ISO 8601 format + + def to_dict(self) -> dict: + return { + "event_type": self.event_type, + "task_id": self.task_id, + "session_id": self.session_id, + "data": self.data, + "timestamp": self.timestamp, + } + + @classmethod + def from_dict(cls, data: dict) -> "Event": + return cls( + event_type=data["event_type"], + task_id=data["task_id"], + session_id=data["session_id"], + data=data.get("data", {}), + timestamp=data["timestamp"], + ) + + @classmethod + def create( + cls, + event_type: str, + task_id: str, + session_id: str, + data: dict[str, Any] | None = None, + ) -> "Event": + """创建事件,自动生成 ISO 8601 格式时间戳""" + return cls( + event_type=event_type, + task_id=task_id, + session_id=session_id, + data=data or {}, + timestamp=datetime.now(timezone.utc).isoformat(), + ) diff --git a/src/agentkit/experts/__init__.py b/src/agentkit/experts/__init__.py index 6be2c13..2303d0a 100644 --- a/src/agentkit/experts/__init__.py +++ b/src/agentkit/experts/__init__.py @@ -1,22 +1,23 @@ -"""Expert 系统 - 专家团队模式的配置、模板、注册与协作计划""" +"""Expert 系统 - 专家团队 hub-and-spoke 模式 + +U3 简化:从去中心化协作改为 Lead Expert + 并行 Task 模式。 +""" from agentkit.experts.config import ExpertConfig, ExpertTemplate from agentkit.experts.expert import Expert from agentkit.experts.orchestrator import TeamOrchestrator from agentkit.experts.plan import ( - CollaborationPlan, MergeStrategy, - ParallelType, - PhaseStatus, - PlanPhase, PlanStatus, + SubTask, + SubTaskStatus, + TeamPlan, ) from agentkit.experts.registry import ExpertTemplateRegistry from agentkit.experts.router import ExpertTeamRouter, ExpertTeamRoutingResult from agentkit.experts.team import ExpertTeam, TeamStatus __all__ = [ - "CollaborationPlan", "Expert", "ExpertConfig", "ExpertTeam", @@ -25,10 +26,10 @@ __all__ = [ "ExpertTemplate", "ExpertTemplateRegistry", "MergeStrategy", - "ParallelType", - "PhaseStatus", - "PlanPhase", "PlanStatus", + "SubTask", + "SubTaskStatus", "TeamOrchestrator", + "TeamPlan", "TeamStatus", ] diff --git a/src/agentkit/experts/config.py b/src/agentkit/experts/config.py index 001c3f6..79b8615 100644 --- a/src/agentkit/experts/config.py +++ b/src/agentkit/experts/config.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any from agentkit.core.config_driven import AgentConfig diff --git a/src/agentkit/experts/expert.py b/src/agentkit/experts/expert.py index 76fea46..0355b15 100644 --- a/src/agentkit/experts/expert.py +++ b/src/agentkit/experts/expert.py @@ -65,7 +65,9 @@ class Expert: # 如果提供了团队上下文,修改 Agent 的 prompt 以注入团队角色信息 if team_context and hasattr(agent, "_prompt_template") and agent._prompt_template: sections = agent._prompt_template._sections - sections.context = f"{team_context}\n\n{sections.context}" if sections.context else team_context + sections.context = ( + f"{team_context}\n\n{sections.context}" if sections.context else team_context + ) return expert @@ -116,9 +118,7 @@ class Expert: "reason": reason, "type": "assist_request", } - await self._handoff_transport.send( - f"expert:{target_expert}:handoff", handoff_msg - ) + await self._handoff_transport.send(f"expert:{target_expert}:handoff", handoff_msg) async def propose_plan_modification( self, diff --git a/src/agentkit/experts/orchestrator.py b/src/agentkit/experts/orchestrator.py index 14fe1f4..fb09cea 100644 --- a/src/agentkit/experts/orchestrator.py +++ b/src/agentkit/experts/orchestrator.py @@ -1,495 +1,200 @@ -"""TeamOrchestrator - 专家团队协作计划执行引擎 +"""TeamOrchestrator - hub-and-spoke 专家团队执行引擎 -驱动 CollaborationPlan 在 ExpertTeam 中的执行,负责: -- 阶段执行(串行、子任务并行、竞争并行) -- 结果合并(BEST / VOTE / FUSION) -- 里程碑检查点 -- 重试 + 回退到单 Agent 模式 -- 事件广播 +驱动 ExpertTeam 在 hub-and-spoke 模式下执行任务: + +1. Lead Expert 接收任务,自主分解为子任务 +2. 并行 spawn Task(每个 Task 是独立 Agent 执行实例,深度=1) +3. 等待所有 Task 完成 +4. Lead Expert 汇总结果(BEST 策略) +5. 返回最终结果 + +约束: +- Task 深度=1(Task 不能再 spawn Task) +- Task 之间无通信 +- Lead Expert 持有所有状态 + +设计依据: +- Claude Code: Task 工具深度=1,子 Agent 不能再生子 Agent +- Codex: spawn_agent 层级式,结果返回父 Agent +- 去中心化协作的通信复杂度 O(N²),hub-and-spoke 为 O(N) """ from __future__ import annotations import asyncio +import json import logging +import re from datetime import datetime, timezone from typing import Any from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus -from agentkit.core.shared_workspace import SharedWorkspace from .expert import Expert -from .plan import ( - CollaborationPlan, - MergeStrategy, - ParallelType, - PhaseStatus, - PlanPhase, - PlanStatus, -) +from .plan import PlanStatus, SubTask, SubTaskStatus, TeamPlan from .team import ExpertTeam, TeamStatus logger = logging.getLogger(__name__) class TeamOrchestrator: - """Orchestration engine that drives CollaborationPlan execution within an ExpertTeam.""" + """Hub-and-spoke orchestration engine. - MAX_RETRIES = 1 # Retry once on failure before fallback - MAX_INTERACTION_ROUNDS = 20 # Prevent infinite collaboration loops - MAX_REPLANS = 2 # Maximum replanning attempts before fallback + Lead Expert acts as the hub: it decomposes the task, dispatches subtasks + to member experts (spokes) in parallel, and synthesizes the final result. + """ - def __init__( - self, - team: ExpertTeam, - workspace: "SharedWorkspace | None" = None, - max_replans: int = 2, - ) -> None: + MAX_SUBTASKS = 10 # Maximum subtasks Lead Expert can decompose + MAX_RETRIES = 1 # Retry once on subtask failure before marking failed + + def __init__(self, team: ExpertTeam) -> None: self._team = team - self._workspace = workspace or team._workspace - self._interaction_count = 0 - self._max_replans = max_replans - async def execute_plan(self, plan: CollaborationPlan) -> dict[str, Any]: - """Execute a CollaborationPlan within the team. + async def execute(self, task: str) -> dict[str, Any]: + """Execute a task in hub-and-spoke mode. + + Flow: + 1. Emit team_formed event + 2. Lead Expert decomposes task into subtasks + 3. Spawn parallel subtasks (each independent Agent execution) + 4. Wait for all subtasks to complete + 5. Lead Expert synthesizes results (BEST strategy) + 6. Emit team_synthesis and team_dissolved events Returns a dict with: - "status": "completed" | "failed" | "fallback" - "result": final synthesized result - - "phase_results": dict of phase_id -> result + - "subtask_results": dict of subtask_id -> result + - "plan": TeamPlan instance """ - # Validate plan first - errors = plan.validate() - if errors: - logger.error(f"Plan validation failed: {errors}") - return { - "status": "failed", - "result": None, - "phase_results": {}, - "errors": errors, - } + lead = self._team.lead_expert + if not lead or not lead.is_active: + active = self._team.active_experts + if not active: + return { + "status": "failed", + "result": None, + "subtask_results": {}, + "error": "No active expert available", + } + lead = active[0] + logger.warning(f"Lead expert not available, falling back to '{lead.config.name}'") - plan.status = PlanStatus.EXECUTING + plan = TeamPlan( + task=task, + lead_expert=lead.config.name, + status=PlanStatus.EXECUTING, + ) self._team.set_status(TeamStatus.EXECUTING) - self._interaction_count = 0 # Reset for each plan execution - phase_results: dict[str, dict[str, Any]] = {} - retry_counts: dict[str, int] = {} # Per-phase retry tracking - replan_count = 0 + # 1. Emit team_formed event + await self._broadcast_event( + "team_formed", + { + "team_id": self._team.team_id, + "status": self._team.status.value, + "lead_expert": lead.config.name, + "experts": [e.config.name for e in self._team.active_experts], + }, + ) try: - while True: - ready_phases = plan.get_ready_phases() + # 2. Lead Expert decomposes task into subtasks + subtasks = await self._decompose_task(lead, task) + if not subtasks: + # If decomposition fails, treat the whole task as a single subtask + logger.warning("Task decomposition returned no subtasks, executing as single task") + subtasks = [SubTask(description=task, assigned_expert=lead.config.name)] - if not ready_phases: - # Check if all phases are done - all_done = all( - p.status in (PhaseStatus.COMPLETED, PhaseStatus.FAILED) - for p in plan.phases - ) - if all_done: - break + plan.subtasks = subtasks[: self.MAX_SUBTASKS] - # Check for stuck state (some phases pending but none ready) - pending = [ - p for p in plan.phases if p.status == PhaseStatus.PENDING - ] - if pending: - # Cascade: mark pending phases with failed deps as FAILED - failed_ids = { - p.id for p in plan.phases if p.status == PhaseStatus.FAILED - } - for p in pending: - if any(dep in failed_ids for dep in p.dependencies): - plan.update_phase_status(p.id, PhaseStatus.FAILED) - phase_results[p.id] = { - "error": f"Dependency failed, cannot execute phase '{p.name}'" - } - logger.warning( - f"Phase {p.id} marked FAILED due to failed dependency" - ) - - # Re-check after cascade - still_pending = [ - p for p in plan.phases if p.status == PhaseStatus.PENDING - ] - if not still_pending: - break - - # If still stuck, trigger fallback - logger.warning( - f"Stuck: {len(still_pending)} pending phases with unresolvable deps" - ) - return await self._fallback_to_single_agent( - plan, phase_results - ) - - break - - # Group ready phases by parallel type - serial_phases = [ - p for p in ready_phases if p.parallel_type == ParallelType.SERIAL - ] - parallel_phases = [ - p - for p in ready_phases - if p.parallel_type == ParallelType.SUBTASK_PARALLEL - ] - competitive_phases = [ - p - for p in ready_phases - if p.parallel_type == ParallelType.COMPETITIVE_PARALLEL - ] - - # Execute serial phases - for phase in serial_phases: - result = await self._execute_phase(phase, plan, phase_results) - if result is None: - # Phase failed — retry per-phase - phase_retries = retry_counts.get(phase.id, 0) - if phase_retries < self.MAX_RETRIES: - retry_counts[phase.id] = phase_retries + 1 - logger.info( - f"Retrying phase {phase.id} (attempt {phase_retries + 1})" - ) - # Reset phase status for retry - plan.update_phase_status(phase.id, PhaseStatus.PENDING) - result = await self._execute_phase(phase, plan, phase_results) - - if result is None: - # Still failed after retry — try replanning before fallback - if replan_count < self._max_replans: - replan_count += 1 - logger.info( - f"Phase {phase.id} failed after retry, " - f"attempting replan ({replan_count}/{self._max_replans})" - ) - await self._broadcast_event( - "replanning", - { - "phase_id": phase.id, - "replan_count": replan_count, - "reason": "phase_failed", - }, - ) - # Reset phase status for replan - plan.update_phase_status(phase.id, PhaseStatus.PENDING) - result = await self._execute_phase(phase, plan, phase_results) - - if result is None: - # Still failed after replan — fallback to single agent - logger.warning( - f"Phase {phase.id} failed after replan, falling back to single agent" - ) - return await self._fallback_to_single_agent( - plan, phase_results - ) - - phase_results[phase.id] = result - - # Execute subtask-level parallel phases - if parallel_phases: - results = await asyncio.gather( - *[ - self._execute_phase(p, plan, phase_results) - for p in parallel_phases - ], - return_exceptions=True, - ) - - all_parallel_failed = True - for phase, result in zip(parallel_phases, results): - if isinstance(result, Exception): - logger.error( - f"Parallel phase {phase.id} failed: {result}" - ) - plan.update_phase_status(phase.id, PhaseStatus.FAILED) - phase_results[phase.id] = {"error": str(result)} - else: - all_parallel_failed = False - phase_results[phase.id] = result - - # If all parallel phases failed, trigger fallback - if all_parallel_failed: - logger.warning("All parallel phases failed, falling back to single agent") - return await self._fallback_to_single_agent( - plan, phase_results - ) - - # Execute competitive parallel phases - for phase in competitive_phases: - result = await self._execute_competitive_phase( - phase, plan, phase_results - ) - if "error" in result: - # Competitive phase completely failed - logger.warning( - f"Competitive phase {phase.id} failed: {result.get('error')}" - ) - return await self._fallback_to_single_agent( - plan, phase_results - ) - phase_results[phase.id] = result - - self._interaction_count += 1 - if self._interaction_count >= self.MAX_INTERACTION_ROUNDS: - logger.warning("Max interaction rounds reached") - break - - # Synthesize final result - plan.status = PlanStatus.COMPLETED - self._team.set_status(TeamStatus.SYNTHESIZING) - - final_result = await self._synthesize_results(plan, phase_results) - - self._team.set_status(TeamStatus.COMPLETED) - return { - "status": "completed", - "result": final_result, - "phase_results": phase_results, - } - - except Exception as e: - logger.error(f"Plan execution failed: {e}") - plan.status = PlanStatus.FAILED - return { - "status": "failed", - "result": None, - "phase_results": phase_results, - "error": str(e), - } - - async def _execute_phase( - self, - phase: PlanPhase, - plan: CollaborationPlan, - phase_results: dict[str, dict[str, Any]], - ) -> dict[str, Any] | None: - """Execute a single phase. Returns result dict or None on failure.""" - plan.update_phase_status(phase.id, PhaseStatus.IN_PROGRESS) - - try: - # Broadcast phase start (inside try so transient broadcast failures don't kill the plan) + # 3. Emit plan_update with subtask list await self._broadcast_event( - "phase_started", + "plan_update", { - "phase_id": phase.id, - "phase_name": phase.name, - "assigned_expert": phase.assigned_expert, + "plan_id": plan.id, + "subtasks": [st.to_dict() for st in plan.subtasks], }, ) - # Get the assigned expert - expert = self._team.get_expert(phase.assigned_expert) - if not expert or not expert.is_active: - # Fallback to lead expert or first active expert - expert = self._team.lead_expert - if not expert or not expert.is_active: - active = self._team.active_experts - if not active: - raise RuntimeError( - f"Expert '{phase.assigned_expert}' not available and no active fallback" - ) - expert = active[0] - logger.warning( - f"Expert '{phase.assigned_expert}' not available, " - f"falling back to '{expert.config.name}'" - ) - - # Build TaskMessage for real execution - input_data: dict[str, Any] = { - "phase_name": phase.name, - "phase_description": phase.task_description or phase.name, - "team_id": self._team.team_id, - } - # Inject dependency results from previous phases - if phase.depends_on: - dep_results: dict[str, dict[str, Any]] = {} - for dep_id in phase.depends_on: - # Try workspace first, then fall back to in-memory phase_results - if self._workspace: - ws_data = await self._workspace.read( - f"team:{self._team.team_id}:phase:{dep_id}:result" - ) - if ws_data: - dep_results[dep_id] = ws_data.get("value", {}) - continue - if dep_id in phase_results: - dep_results[dep_id] = phase_results[dep_id] - if dep_results: - input_data["dependency_results"] = dep_results - - task_msg = TaskMessage( - task_id=phase.id, - agent_name=expert.config.name, - task_type="team_phase", - priority=0, - input_data=input_data, - callback_url=None, - created_at=datetime.now(timezone.utc), + # 4. Spawn parallel subtasks + subtask_results: dict[str, dict[str, Any]] = {} + results = await asyncio.gather( + *[self._execute_subtask(st) for st in plan.subtasks], + return_exceptions=True, ) - # Execute the task via the expert's agent - task_result: TaskResult = await expert.agent.execute(task_msg) - - if task_result.status != TaskStatus.COMPLETED.value: - raise RuntimeError( - f"Agent execution failed: {task_result.error_message or 'unknown error'}" - ) - - result = task_result.output_data or {"content": ""} - - # Write result to workspace for cross-phase state sharing - if self._workspace: - try: - await self._workspace.write( - f"team:{self._team.team_id}:phase:{phase.id}:result", - result, - agent_id=expert.config.name, + for subtask, result in zip(plan.subtasks, results): + if isinstance(result, Exception): + logger.error(f"Subtask {subtask.id} failed: {result}") + plan.update_subtask_status( + subtask.id, SubTaskStatus.FAILED, {"error": str(result)} ) - except Exception as e: - logger.warning(f"Workspace write failed for phase {phase.id}: {e}") + subtask_results[subtask.id] = {"error": str(result)} + else: + subtask_results[subtask.id] = result - # Check milestone - if phase.milestone: - milestone_passed = await self._check_milestone(phase, result) - if not milestone_passed: - plan.update_phase_status(phase.id, PhaseStatus.FAILED) - try: - await self._broadcast_event( - "milestone_failed", - {"phase_id": phase.id, "milestone": phase.milestone}, - ) - except Exception: - pass - return None + # 5. Check if all subtasks failed + completed = plan.completed_subtasks + if not completed: + logger.warning("All subtasks failed, falling back to single agent") + return await self._fallback_to_single_agent(task, plan, subtask_results) - plan.update_phase_status(phase.id, PhaseStatus.COMPLETED, result) + # 6. Lead Expert synthesizes results (BEST strategy) + self._team.set_status(TeamStatus.SYNTHESIZING) + plan.status = PlanStatus.COMPLETED - try: - await self._broadcast_event( - "phase_completed", - {"phase_id": phase.id, "phase_name": phase.name}, - ) - except Exception: - pass + final_result = await self._synthesize_results(lead, task, [st for st in completed]) - return result + self._team.set_status(TeamStatus.COMPLETED) + + # 7. Emit team_synthesis event + await self._broadcast_event( + "team_synthesis", + { + "content": final_result.get("content", ""), + "subtasks_completed": len(completed), + "subtasks_total": len(plan.subtasks), + }, + ) + + return { + "status": "completed", + "result": final_result, + "subtask_results": subtask_results, + "plan": plan, + } except Exception as e: - logger.error(f"Phase {phase.id} execution failed: {e}") - plan.update_phase_status(phase.id, PhaseStatus.FAILED) - try: - await self._broadcast_event( - "phase_failed", {"phase_id": phase.id, "error": str(e)} - ) - except Exception: - pass - return None + logger.error(f"Hub-and-spoke execution failed: {e}") + plan.status = PlanStatus.FAILED + return await self._fallback_to_single_agent(task, plan, subtask_results) - async def _execute_competitive_phase( - self, - phase: PlanPhase, - plan: CollaborationPlan, - phase_results: dict[str, dict[str, Any]], - ) -> dict[str, Any]: - """Execute a competitive parallel phase with merge strategy.""" - plan.update_phase_status(phase.id, PhaseStatus.IN_PROGRESS) + async def _decompose_task(self, lead: Expert, task: str) -> list[SubTask]: + """Lead Expert decomposes task into subtasks using LLM. - # For competitive parallel, we need multiple experts working on the same task - # In practice, the plan should specify which experts compete - # For now, we use all active experts as competitors - competitors = self._team.active_experts - - # Run all competitors in parallel - results = await asyncio.gather( - *[self._run_competitor(expert, phase) for expert in competitors], - return_exceptions=True, - ) - - # Filter out exceptions - valid_results = [r for r in results if not isinstance(r, Exception)] - - if not valid_results: - plan.update_phase_status(phase.id, PhaseStatus.FAILED) - return {"error": "All competitors failed"} - - # Apply merge strategy - merged = await self._merge_results(phase, valid_results) - - plan.update_phase_status(phase.id, PhaseStatus.COMPLETED, merged) - return merged - - async def _run_competitor( - self, expert: Expert, phase: PlanPhase - ) -> dict[str, Any]: - """Run a single competitor for a competitive phase.""" - # Build TaskMessage for real execution - task_msg = TaskMessage( - task_id=f"{phase.id}_{expert.config.name}", - agent_name=expert.config.name, - task_type="team_competitive", - priority=0, - input_data={ - "phase_name": phase.name, - "phase_description": phase.task_description or phase.name, - "team_id": self._team.team_id, - }, - callback_url=None, - created_at=datetime.now(timezone.utc), - ) - - task_result: TaskResult = await expert.agent.execute(task_msg) - - if task_result.status != TaskStatus.COMPLETED.value: - raise RuntimeError( - f"Competitor {expert.config.name} failed: {task_result.error_message or 'unknown'}" - ) - - return { - "expert": expert.config.name, - "output": task_result.output_data or {}, - "status": task_result.status, - } - - def _get_llm_gateway(self) -> Any: - """Get LLM gateway from the lead expert's agent.""" - lead = self._team.lead_expert - if lead and hasattr(lead, "agent") and hasattr(lead.agent, "_llm_gateway"): - return lead.agent._llm_gateway - # Fallback: try first active expert - for expert in self._team.active_experts: - if hasattr(expert, "agent") and hasattr(expert.agent, "_llm_gateway"): - return expert.agent._llm_gateway - return None - - @staticmethod - def _build_result_summaries(results: list[dict[str, Any]], max_len: int = 500) -> list[str]: - """Build text summaries from competitor results for LLM evaluation.""" - summaries = [] - for i, r in enumerate(results): - output = r.get("output", {}) - content = output.get("content", str(output)) if isinstance(output, dict) else str(output) - summaries.append(f"Result {i + 1} (by {r.get('expert', 'unknown')}):\n{content[:max_len]}") - return summaries - - async def _llm_pick_best( - self, task: str, results: list[dict[str, Any]] - ) -> dict[str, Any]: - """Use LLM to evaluate and pick the best result.""" - gateway = self._get_llm_gateway() + Returns a list of SubTask instances. If LLM decomposition fails, + returns a single subtask with the original task. + """ + gateway = self._get_llm_gateway(lead) if not gateway: - return results[0] + logger.warning("No LLM gateway available, treating task as single subtask") + return [SubTask(description=task, assigned_expert=lead.config.name)] - # Build evaluation prompt - result_summaries = self._build_result_summaries(results) + member_names = [ + e.config.name for e in self._team.active_experts if e.config.name != lead.config.name + ] + available_experts = member_names if member_names else [lead.config.name] prompt = ( + f"You are the Lead Expert in a team. Decompose the following task into " + f"at most {self.MAX_SUBTASKS} independent subtasks that can be executed in parallel.\n\n" f"Task: {task}\n\n" - f"Below are {len(results)} candidate results. Pick the BEST one based on " - f"completeness, accuracy, and relevance to the task.\n\n" - + "\n---\n".join(result_summaries) - + "\n\nReply with ONLY the number of the best result (e.g., '1' or '2')." + f"Available experts: {', '.join(available_experts)}\n\n" + f"Return a JSON array of objects, each with:\n" + f'- "description": clear subtask description\n' + f'- "assigned_expert": name of the expert to assign (must be one of: {", ".join(available_experts)})\n\n' + f"Return ONLY the JSON array, no other text." ) try: @@ -497,88 +202,199 @@ class TeamOrchestrator: messages=[{"role": "user", "content": prompt}], model="default", ) - choice = response.content.strip() - # Parse the number from the response - for ch in choice: - if ch.isdigit(): - idx = int(ch) - 1 - if 0 <= idx < len(results): - return results[idx] + subtasks = self._parse_subtasks(response.content, available_experts, lead.config.name) + if subtasks: + return subtasks + logger.warning("LLM decomposition returned no valid subtasks") except Exception as e: - logger.warning(f"LLM best-pick failed, falling back to first result: {e}") + logger.warning(f"LLM task decomposition failed: {e}") - return results[0] + return [SubTask(description=task, assigned_expert=lead.config.name)] - async def _llm_vote( - self, task: str, results: list[dict[str, Any]] - ) -> dict[str, Any]: - """Use LLM voting to select the best result.""" - gateway = self._get_llm_gateway() - if not gateway: - return results[0] + @staticmethod + def _parse_subtasks( + content: str, available_experts: list[str], lead_name: str + ) -> list[SubTask]: + """Parse LLM response into SubTask list. - scores: dict[int, float] = {} - result_summaries = self._build_result_summaries(results) + Extracts JSON array from the response content and creates SubTask instances. + Validates assigned_expert against available_experts list. + """ + # Try to extract JSON array from the response + json_match = re.search(r"\[.*\]", content, re.DOTALL) + if not json_match: + return [] - # Each expert votes by ranking results (excluding their own) - for voter_idx, r in enumerate(results): - # Build summaries excluding the voter's own result to avoid self-voting bias - other_indices = [i for i in range(len(results)) if i != voter_idx] - other_summaries = [result_summaries[i] for i in other_indices] + try: + items = json.loads(json_match.group(0)) + except json.JSONDecodeError: + return [] - prompt = ( - f"Task: {task}\n\n" - f"Below are {len(other_summaries)} candidate results. Rank them from best to worst.\n\n" - + "\n---\n".join(other_summaries) - + "\n\nReply with ONLY a comma-separated list of result numbers, best first (e.g., '2,1,3')." + if not isinstance(items, list): + return [] + + subtasks: list[SubTask] = [] + for item in items: + if not isinstance(item, dict): + continue + description = item.get("description", "").strip() + if not description: + continue + assigned = item.get("assigned_expert", "").strip() + # Validate assigned expert; fall back to lead if invalid + if assigned not in available_experts: + assigned = lead_name + subtasks.append(SubTask(description=description, assigned_expert=assigned)) + return subtasks + + async def _execute_subtask(self, subtask: SubTask) -> dict[str, Any]: + """Execute a single subtask using the assigned expert. + + Each subtask is an independent Agent execution (Task depth=1). + Subtasks cannot spawn further subtasks. + """ + # Resolve the assigned expert + expert = self._team.get_expert(subtask.assigned_expert) + if not expert or not expert.is_active: + # Fallback to lead expert or first active expert + expert = self._team.lead_expert + if not expert or not expert.is_active: + active = self._team.active_experts + if not active: + raise RuntimeError( + f"Expert '{subtask.assigned_expert}' not available and no active fallback" + ) + expert = active[0] + logger.warning( + f"Expert '{subtask.assigned_expert}' not available, " + f"falling back to '{expert.config.name}'" ) + + # Update subtask status + subtask.status = SubTaskStatus.RUNNING + subtask.assigned_expert = expert.config.name + + # Emit expert_step event + await self._broadcast_event( + "expert_step", + { + "expert_id": expert.config.name, + "expert_name": expert.config.name, + "expert_color": expert.config.color, + "content": subtask.description, + "step": subtask.id, + }, + ) + + # Build TaskMessage for execution + task_msg = TaskMessage( + task_id=subtask.id, + agent_name=expert.config.name, + task_type="team_subtask", + priority=0, + input_data={ + "task": subtask.description, + "team_id": self._team.team_id, + "is_subtask": True, # Marker: depth=1, cannot spawn further subtasks + }, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + # Execute with retry + last_error: str | None = None + for attempt in range(self.MAX_RETRIES + 1): try: - response = await gateway.chat( - messages=[{"role": "user", "content": prompt}], - model="default", + task_result: TaskResult = await expert.agent.execute(task_msg) + + if task_result.status != TaskStatus.COMPLETED.value: + last_error = task_result.error_message or "unknown error" + if attempt < self.MAX_RETRIES: + logger.info(f"Retrying subtask {subtask.id} (attempt {attempt + 1})") + continue + raise RuntimeError(f"Agent execution failed: {last_error}") + + result = task_result.output_data or {"content": ""} + + subtask.status = SubTaskStatus.COMPLETED + subtask.result = result + + # Emit expert_result event + await self._broadcast_event( + "expert_result", + { + "expert_id": expert.config.name, + "expert_name": expert.config.name, + "expert_color": expert.config.color, + "content": result.get("content", str(result)), + "subtask_id": subtask.id, + }, ) - # Parse ranking: map back to original indices - for rank_pos, ch in enumerate(response.content.strip().split(",")): - ch = ch.strip() - if ch.isdigit(): - local_idx = int(ch) - 1 - if 0 <= local_idx < len(other_indices): - original_idx = other_indices[local_idx] - scores[original_idx] = scores.get(original_idx, 0) + ( - len(other_indices) - rank_pos - ) + + return result + except Exception as e: - logger.warning(f"Voter {voter_idx} vote failed: {e}") - # On failure, distribute 1 point evenly across other results - for oi in other_indices: - scores[oi] = scores.get(oi, 0) + 1 + last_error = str(e) + if attempt < self.MAX_RETRIES: + logger.info(f"Retrying subtask {subtask.id} (attempt {attempt + 1})") + continue + raise - if not scores: - return results[0] + # Should not reach here, but just in case + subtask.status = SubTaskStatus.FAILED + raise RuntimeError(f"Subtask {subtask.id} failed: {last_error}") - best_idx = max(scores, key=scores.get) # type: ignore[arg-type] - return results[best_idx] - - async def _llm_fuse( - self, task: str, results: list[dict[str, Any]] + async def _synthesize_results( + self, lead: Expert, task: str, completed_subtasks: list[SubTask] ) -> dict[str, Any]: - """Use LLM to fuse multiple results into one.""" - gateway = self._get_llm_gateway() - if not gateway: - # Fallback: concatenate all results - combined = "\n\n".join( - str(r.get("output", {})) for r in results - ) - return {"content": combined, "fused_from": len(results)} + """Lead Expert synthesizes results using BEST strategy. - result_summaries = self._build_result_summaries(results, max_len=800) + The Lead Expert evaluates all completed subtask results and produces + a final synthesized result. Uses LLM when available, otherwise + concatenates results. + """ + results = [st.result or {} for st in completed_subtasks] + if not results: + return {"content": ""} + + # If only one result, return it directly + if len(results) == 1: + content = results[0].get("content", str(results[0])) + return { + "content": content, + "strategy": "best", + "subtasks_completed": 1, + } + + gateway = self._get_llm_gateway(lead) + if not gateway: + # Without LLM, concatenate all results + combined = "\n\n".join( + r.get("content", str(r)) if isinstance(r, dict) else str(r) for r in results + ) + return { + "content": combined, + "strategy": "best", + "subtasks_completed": len(results), + } + + # Build result summaries for LLM evaluation + summaries = [] + for i, st in enumerate(completed_subtasks): + r = st.result or {} + content = r.get("content", str(r)) if isinstance(r, dict) else str(r) + summaries.append( + f"Subtask {i + 1} (by {st.assigned_expert}, task: {st.description[:100]}):\n" + f"{content[:500]}" + ) prompt = ( - f"Task: {task}\n\n" - f"Below are {len(results)} results from different experts working on the same task. " - f"Fuse them into a single comprehensive result that combines the best elements.\n\n" - + "\n---\n".join(result_summaries) - + "\n\nProvide the fused result directly." + f"Original task: {task}\n\n" + f"Below are {len(results)} subtask results from your team members. " + f"Synthesize them into a single comprehensive final result that " + f"best addresses the original task.\n\n" + + "\n---\n".join(summaries) + + "\n\nProvide the synthesized result directly." ) try: @@ -588,156 +404,49 @@ class TeamOrchestrator: ) return { "content": response.content.strip(), - "fused_from": len(results), - "strategy": "fusion", - } - except Exception as e: - logger.warning(f"LLM fusion failed, falling back to concatenation: {e}") - combined = "\n\n".join( - str(r.get("output", {})) for r in results - ) - return {"content": combined, "fused_from": len(results)} - - async def _merge_results( - self, phase: PlanPhase, results: list[dict[str, Any]] - ) -> dict[str, Any]: - """Merge competitive parallel results based on merge strategy.""" - if not results: - return {} - - strategy = phase.merge_strategy or MergeStrategy.BEST - task_desc = phase.task_description or phase.name - - if strategy == MergeStrategy.BEST: - selected = await self._llm_pick_best(task_desc, results) - return { - "merged": True, "strategy": "best", - "selected": selected, - "all_results": results, + "subtasks_completed": len(results), } - - elif strategy == MergeStrategy.VOTE: - selected = await self._llm_vote(task_desc, results) - return { - "merged": True, - "strategy": "vote", - "selected": selected, - "all_results": results, - } - - elif strategy == MergeStrategy.FUSION: - fused = await self._llm_fuse(task_desc, results) - return { - "merged": True, - "strategy": "fusion", - "fused_from": len(results), - "selected": fused, - "all_results": results, - } - - return results[0] - - async def _check_milestone( - self, phase: PlanPhase, result: dict[str, Any] - ) -> bool: - """Check if a phase result passes its milestone checkpoint. - - Uses LLM evaluation when available, falls back to basic content check. - """ - milestone = phase.milestone - if not milestone: - return True - - # Basic check: result must have non-empty content - output = result.get("output", result) if isinstance(result, dict) else result - content = output.get("content", str(output)) if isinstance(output, dict) else str(output) - if not content or content.strip() == "": - return False - - # LLM-based milestone evaluation - gateway = self._get_llm_gateway() - if not gateway: - # Without LLM, do basic keyword matching - milestone_lower = milestone.lower() - content_lower = content.lower() - # Check if milestone keywords appear in content - keywords = [w for w in milestone_lower.split() if len(w) > 2] - if keywords and not any(kw in content_lower for kw in keywords): - return False - return True - - prompt = ( - f"Task: {phase.task_description or phase.name}\n" - f"Milestone requirement: {milestone}\n" - f"Result:\n{content[:500]}\n\n" - f"Does this result meet the milestone requirement? " - f"Reply with ONLY 'yes' or 'no'." - ) - - try: - response = await gateway.chat( - messages=[{"role": "user", "content": prompt}], - model="default", - ) - answer = response.content.strip().lower() - return answer.startswith("yes") except Exception as e: - logger.warning(f"Milestone LLM check failed for phase {phase.id}: {e}") - # On LLM failure, pass the milestone (conservative — don't block on infra issues) - return True - - async def _synthesize_results( - self, plan: CollaborationPlan, phase_results: dict[str, dict[str, Any]] - ) -> dict[str, Any]: - """Synthesize final results from all phase outputs.""" - # Collect completed phase results in order - completed: list[dict[str, Any]] = [] - for phase in plan.phases: - if phase.status == PhaseStatus.COMPLETED and phase.id in phase_results: - completed.append( - { - "phase": phase.name, - "expert": phase.assigned_expert, - "result": phase_results[phase.id], - } - ) - - return { - "task": plan.task, - "phases_completed": len(completed), - "phases_total": len(plan.phases), - "results": completed, - } + logger.warning(f"LLM synthesis failed, falling back to concatenation: {e}") + combined = "\n\n".join( + r.get("content", str(r)) if isinstance(r, dict) else str(r) for r in results + ) + return { + "content": combined, + "strategy": "best", + "subtasks_completed": len(results), + } async def _fallback_to_single_agent( - self, plan: CollaborationPlan, phase_results: dict[str, dict[str, Any]] + self, + task: str, + plan: TeamPlan, + subtask_results: dict[str, dict[str, Any]], ) -> dict[str, Any]: - """Fallback to single agent mode when team execution fails. + """Fallback to single agent mode when hub-and-spoke execution fails. Uses the lead expert (or first active expert) to complete the original task. """ plan.status = PlanStatus.FALLBACK logger.warning("Falling back to single agent mode") - # Try to use the lead expert, or fall back to any active expert expert = self._team.lead_expert if not expert or not expert.is_active: active = self._team.active_experts expert = active[0] if active else None - fallback_result = None + fallback_result: dict[str, Any] | None = None if expert: try: - # Execute the original task with a single expert via real agent task_msg = TaskMessage( task_id=f"fallback_{plan.id}", agent_name=expert.config.name, task_type="fallback", priority=0, input_data={ - "task": plan.task, - "phase_results": phase_results, + "task": task, + "subtask_results": subtask_results, "team_id": self._team.team_id, }, callback_url=None, @@ -756,14 +465,39 @@ class TeamOrchestrator: return { "status": "fallback", "result": fallback_result, - "phase_results": phase_results, + "subtask_results": subtask_results, + "plan": plan, } - async def _broadcast_event( - self, event_type: str, data: dict[str, Any] - ) -> None: - """Broadcast an orchestration event to the team channel.""" + def _get_llm_gateway(self, expert: Expert | None = None) -> Any: + """Get LLM gateway from the given expert or the lead expert's agent. + + Falls back to other active experts if the primary target has no gateway. + """ + target = expert or self._team.lead_expert + if target and hasattr(target, "agent") and hasattr(target.agent, "_llm_gateway"): + gateway = target.agent._llm_gateway + if gateway is not None: + return gateway + # Fallback: try first active expert with a gateway + for exp in self._team.active_experts: + if hasattr(exp, "agent") and hasattr(exp.agent, "_llm_gateway"): + gateway = exp.agent._llm_gateway + if gateway is not None: + return gateway + return None + + async def _broadcast_event(self, event_type: str, data: dict[str, Any]) -> None: + """Broadcast an orchestration event to the team channel. + + Events are emitted via handoff_transport for WebSocket relay. + Supported event types: team_formed, expert_step, expert_result, + plan_update, team_synthesis, team_dissolved. + """ if self._team.handoff_transport: - await self._team.handoff_transport.send( - self._team.team_channel, {"type": event_type, **data} - ) + try: + await self._team.handoff_transport.send( + self._team.team_channel, {"type": event_type, **data} + ) + except Exception as e: + logger.warning(f"Failed to broadcast event '{event_type}': {e}") diff --git a/src/agentkit/experts/plan.py b/src/agentkit/experts/plan.py index bd401c5..4c1a440 100644 --- a/src/agentkit/experts/plan.py +++ b/src/agentkit/experts/plan.py @@ -1,281 +1,172 @@ -"""CollaborationPlan 数据模型 - Expert Team 协作蓝图 +"""Expert Team 计划数据模型 - hub-and-spoke 模式 -定义 Expert Team 的结构化协作计划,包括阶段、角色分配、依赖关系、 -并行类型、合并策略和里程碑。 +简化为 Lead Expert + 并行 Task 模式: +- Lead Expert 接收任务并分解为子任务 +- 子任务并行执行(Task 深度=1,不能再 spawn Task) +- Lead Expert 汇总结果(BEST 策略) + +移除了阶段依赖图、VOTE/FUSION 合并策略、跨阶段状态共享。 """ from __future__ import annotations import enum +import uuid from dataclasses import dataclass, field from typing import Any -class ParallelType(str, enum.Enum): - """并行执行类型""" - - SERIAL = "serial" - SUBTASK_PARALLEL = "subtask_parallel" - COMPETITIVE_PARALLEL = "competitive_parallel" - - class MergeStrategy(str, enum.Enum): - """合并策略 - 仅用于 COMPETITIVE_PARALLEL 阶段""" + """合并策略 - Lead Expert 用于选择最佳结果 + + hub-and-spoke 模式下仅保留 BEST 策略: + Lead Expert 从所有子任务结果中选择或综合出最佳结果。 + """ BEST = "best" - VOTE = "vote" - FUSION = "fusion" - - -class PhaseStatus(str, enum.Enum): - """阶段状态""" - - PENDING = "pending" - IN_PROGRESS = "in_progress" - COMPLETED = "completed" - FAILED = "failed" class PlanStatus(str, enum.Enum): """计划状态""" DRAFT = "draft" - CONFIRMED = "confirmed" EXECUTING = "executing" COMPLETED = "completed" FAILED = "failed" FALLBACK = "fallback" -# DFS 着色常量 -_WHITE = 0 # 未访问 -_GRAY = 1 # 正在访问(当前路径上) -_BLACK = 2 # 已完成访问 +class SubTaskStatus(str, enum.Enum): + """子任务状态""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" @dataclass -class PlanPhase: - """协作计划中的单个阶段 +class SubTask: + """Lead Expert 分解出的子任务 + + 在 hub-and-spoke 模式中,Lead Expert 将原始任务分解为多个子任务, + 每个子任务由一个 Expert 并行执行。子任务之间无依赖关系、无通信。 + + 约束: + - Task 深度=1(子任务不能再 spawn 子任务) + - 子任务之间无通信 + - Lead Expert 持有所有状态 Attributes: - id: 阶段标识符 - name: 阶段显示名称 - assigned_expert: 分配到此阶段的 Expert 名称 - task_description: 此阶段完成的任务描述 - depends_on: 依赖的阶段 ID 列表 - parallel_type: 执行类型 - merge_strategy: 合并策略,仅 COMPETITIVE_PARALLEL 需要 - milestone: 里程碑检查点描述 + id: 子任务标识符 + description: 子任务描述 + assigned_expert: 分配的 Expert 名称 status: 当前状态 - result: 阶段输出结果 + result: 子任务输出结果 """ - id: str - name: str - assigned_expert: str - task_description: str - depends_on: list[str] = field(default_factory=list) - parallel_type: ParallelType = ParallelType.SERIAL - merge_strategy: MergeStrategy | None = None - milestone: str = "" - status: PhaseStatus = PhaseStatus.PENDING - result: dict | None = None + id: str = field(default_factory=lambda: str(uuid.uuid4())) + description: str = "" + assigned_expert: str = "" + status: SubTaskStatus = SubTaskStatus.PENDING + result: dict[str, Any] | None = None def to_dict(self) -> dict[str, Any]: """序列化为字典""" return { "id": self.id, - "name": self.name, + "description": self.description, "assigned_expert": self.assigned_expert, - "task_description": self.task_description, - "depends_on": self.depends_on, - "parallel_type": self.parallel_type.value, - "merge_strategy": self.merge_strategy.value if self.merge_strategy is not None else None, - "milestone": self.milestone, "status": self.status.value, "result": self.result, } @classmethod - def from_dict(cls, data: dict[str, Any]) -> PlanPhase: - """从字典创建 PlanPhase""" - merge_strategy = None - if data.get("merge_strategy") is not None: - merge_strategy = MergeStrategy(data["merge_strategy"]) - + def from_dict(cls, data: dict[str, Any]) -> SubTask: + """从字典创建 SubTask""" return cls( - id=data["id"], - name=data["name"], - assigned_expert=data["assigned_expert"], - task_description=data["task_description"], - depends_on=data.get("depends_on", []), - parallel_type=ParallelType(data.get("parallel_type", ParallelType.SERIAL.value)), - merge_strategy=merge_strategy, - milestone=data.get("milestone", ""), - status=PhaseStatus(data.get("status", PhaseStatus.PENDING.value)), + id=data.get("id", str(uuid.uuid4())), + description=data.get("description", ""), + assigned_expert=data.get("assigned_expert", ""), + status=SubTaskStatus(data.get("status", SubTaskStatus.PENDING.value)), result=data.get("result"), ) @dataclass -class CollaborationPlan: - """Expert Team 协作计划 +class TeamPlan: + """Expert Team hub-and-spoke 执行计划 - 定义 Expert Team 的结构化协作蓝图,包括阶段编排、共享变量、 - 状态管理和依赖关系。 + Lead Expert 持有此计划,包含分解的子任务列表。 + 与旧版 CollaborationPlan 不同,此计划无阶段依赖图, + 所有子任务并行执行,由 Lead Expert 统一汇总。 Attributes: id: 计划标识符 task: 原始任务描述 - phases: 有序阶段列表 - variables: 共享变量 + subtasks: 子任务列表(并行执行,无依赖关系) status: 计划状态 lead_expert: 主导 Expert 名称 """ - id: str - task: str - phases: list[PlanPhase] = field(default_factory=list) - variables: dict = field(default_factory=dict) + id: str = field(default_factory=lambda: str(uuid.uuid4())) + task: str = "" + subtasks: list[SubTask] = field(default_factory=list) status: PlanStatus = PlanStatus.DRAFT lead_expert: str = "" - _phase_index: dict[str, PlanPhase] = field(default_factory=dict, init=False, repr=False) - - def __post_init__(self) -> None: - """Build the phase index after initialization.""" - self._rebuild_index() - - def _rebuild_index(self) -> None: - """Rebuild the phase index from the phases list.""" - self._phase_index = {phase.id: phase for phase in self.phases} def to_dict(self) -> dict[str, Any]: """序列化为字典""" return { "id": self.id, "task": self.task, - "phases": [phase.to_dict() for phase in self.phases], - "variables": self.variables, + "subtasks": [st.to_dict() for st in self.subtasks], "status": self.status.value, "lead_expert": self.lead_expert, } @classmethod - def from_dict(cls, data: dict[str, Any]) -> CollaborationPlan: - """从字典创建 CollaborationPlan""" - phases = [PlanPhase.from_dict(p) for p in data.get("phases", [])] + def from_dict(cls, data: dict[str, Any]) -> TeamPlan: + """从字典创建 TeamPlan""" + subtasks = [SubTask.from_dict(st) for st in data.get("subtasks", [])] return cls( - id=data["id"], - task=data["task"], - phases=phases, - variables=data.get("variables", {}), + id=data.get("id", str(uuid.uuid4())), + task=data.get("task", ""), + subtasks=subtasks, status=PlanStatus(data.get("status", PlanStatus.DRAFT.value)), lead_expert=data.get("lead_expert", ""), ) - def validate(self) -> list[str]: - """验证计划,返回错误消息列表(空列表表示有效) + def get_subtask(self, subtask_id: str) -> SubTask | None: + """根据 ID 获取子任务,不存在则返回 None""" + for st in self.subtasks: + if st.id == subtask_id: + return st + return None - 检查项: - - 无重复阶段 ID - - 所有 depends_on 引用存在 - - 无循环依赖(DFS 着色检测) - - COMPETITIVE_PARALLEL 阶段必须有 merge_strategy - """ - errors: list[str] = [] - - # 构建阶段 ID 集合 - phase_ids = {phase.id for phase in self.phases} - - # 检查重复阶段 ID - seen_ids: set[str] = set() - for phase in self.phases: - if phase.id in seen_ids: - errors.append(f"重复的阶段 ID: {phase.id}") - seen_ids.add(phase.id) - - # 检查 depends_on 引用是否存在 - for phase in self.phases: - for dep_id in phase.depends_on: - if dep_id not in phase_ids: - errors.append( - f"阶段 '{phase.id}' 依赖了不存在的阶段 ID: {dep_id}" - ) - - # 检查循环依赖(迭代 DFS 着色 — 避免递归栈溢出) - color: dict[str, int] = {phase.id: _WHITE for phase in self.phases} - dep_map: dict[str, list[str]] = { - phase.id: phase.depends_on for phase in self.phases - } - - for phase in self.phases: - if color[phase.id] != _WHITE: - continue - # Iterative DFS using an explicit stack - stack: list[tuple[str, bool]] = [(phase.id, False)] - while stack: - node, is_backtrack = stack.pop() - if is_backtrack: - color[node] = _BLACK - continue - if color[node] == _GRAY: - # Already on current path — cycle detected - errors.append("检测到循环依赖") - break - if color[node] == _BLACK: - continue - color[node] = _GRAY - # Push backtrack marker - stack.append((node, True)) - for neighbor in dep_map.get(node, []): - if neighbor not in color: - continue - if color[neighbor] == _GRAY: - errors.append("检测到循环依赖") - break - if color[neighbor] == _WHITE: - stack.append((neighbor, False)) - else: - continue - break # Inner break propagates to outer - if errors: - break # Only report cycle once - - # 检查 COMPETITIVE_PARALLEL 必须有 merge_strategy - for phase in self.phases: - if ( - phase.parallel_type == ParallelType.COMPETITIVE_PARALLEL - and phase.merge_strategy is None - ): - errors.append( - f"阶段 '{phase.id}' 为 COMPETITIVE_PARALLEL 但未设置 merge_strategy" - ) - - return errors - - def get_ready_phases(self) -> list[PlanPhase]: - """获取所有依赖已完成且状态为 PENDING 的阶段""" - completed_ids = { - phase.id for phase in self.phases if phase.status == PhaseStatus.COMPLETED - } - ready: list[PlanPhase] = [] - for phase in self.phases: - if phase.status != PhaseStatus.PENDING: - continue - if all(dep_id in completed_ids for dep_id in phase.depends_on): - ready.append(phase) - return ready - - def get_phase(self, phase_id: str) -> PlanPhase | None: - """根据 ID 获取阶段,不存在则返回 None (O(1) lookup)""" - return self._phase_index.get(phase_id) - - def update_phase_status( - self, phase_id: str, status: PhaseStatus, result: dict | None = None + def update_subtask_status( + self, subtask_id: str, status: SubTaskStatus, result: dict[str, Any] | None = None ) -> None: - """更新阶段状态和可选的结果""" - phase = self.get_phase(phase_id) - if phase is not None: - phase.status = status + """更新子任务状态和可选的结果""" + st = self.get_subtask(subtask_id) + if st is not None: + st.status = status if result is not None: - phase.result = result + st.result = result + + @property + def completed_subtasks(self) -> list[SubTask]: + """已完成的子任务列表""" + return [st for st in self.subtasks if st.status == SubTaskStatus.COMPLETED] + + @property + def failed_subtasks(self) -> list[SubTask]: + """失败的子任务列表""" + return [st for st in self.subtasks if st.status == SubTaskStatus.FAILED] + + @property + def all_done(self) -> bool: + """所有子任务是否都已完成(成功或失败)""" + return all( + st.status in (SubTaskStatus.COMPLETED, SubTaskStatus.FAILED) for st in self.subtasks + ) diff --git a/src/agentkit/experts/registry.py b/src/agentkit/experts/registry.py index 6ebf4be..d4a6be9 100644 --- a/src/agentkit/experts/registry.py +++ b/src/agentkit/experts/registry.py @@ -4,12 +4,11 @@ from __future__ import annotations import logging import os -from typing import Any import yaml from agentkit.core.exceptions import ConfigValidationError -from agentkit.experts.config import ExpertConfig, ExpertTemplate +from agentkit.experts.config import ExpertTemplate logger = logging.getLogger(__name__) @@ -62,10 +61,7 @@ class ExpertTemplateRegistry: query_lower = query.lower() results: list[ExpertTemplate] = [] for template in self._templates.values(): - if ( - query_lower in template.name.lower() - or query_lower in template.description.lower() - ): + if query_lower in template.name.lower() or query_lower in template.description.lower(): results.append(template) return results @@ -134,7 +130,5 @@ class ExpertTemplateRegistry: template = self.load_from_yaml(filepath) loaded.append(template) except Exception as e: - logger.warning( - f"Failed to load ExpertTemplate from '{filepath}': {e}" - ) + logger.warning(f"Failed to load ExpertTemplate from '{filepath}': {e}") return loaded diff --git a/src/agentkit/experts/team.py b/src/agentkit/experts/team.py index cc1295b..63cab9d 100644 --- a/src/agentkit/experts/team.py +++ b/src/agentkit/experts/team.py @@ -1,7 +1,13 @@ -"""ExpertTeam - 专家团队容器 +"""ExpertTeam - 专家团队容器(hub-and-spoke 模式) -管理 Expert 生命周期、共享上下文、协作计划和团队状态, -是 Expert Team 协作模式的中央协调点。 +管理 Expert 生命周期、团队状态和事件广播, +是 Expert Team hub-and-spoke 协作模式的中央协调点。 + +简化说明(U3): +- 移除 CollaborationPlan 依赖(Lead Expert 自主分解任务) +- 移除跨阶段状态共享(Lead Expert 持有所有状态) +- 保留 handoff_transport 用于事件广播(不再用于 Agent 间通信) +- 保留 workspace 用于输出保存(不再用于跨阶段状态共享) """ from __future__ import annotations @@ -14,7 +20,6 @@ import uuid from .config import ExpertConfig from .expert import Expert -from .plan import CollaborationPlan, PlanStatus from .registry import ExpertTemplateRegistry from ..core.handoff_transport import InProcessHandoffTransport from ..core.shared_workspace import SharedWorkspace @@ -27,7 +32,6 @@ class TeamStatus(str, enum.Enum): """ExpertTeam lifecycle states.""" FORMING = "forming" - PLANNING = "planning" EXECUTING = "executing" SYNTHESIZING = "synthesizing" COMPLETED = "completed" @@ -35,7 +39,14 @@ class TeamStatus(str, enum.Enum): class ExpertTeam: - """Container managing a team of Experts working together on a task.""" + """Container managing a team of Experts in hub-and-spoke mode. + + In hub-and-spoke mode: + - Lead Expert (hub) receives the task and decomposes it + - Member Experts (spokes) execute subtasks in parallel + - Lead Expert synthesizes the final result + - No inter-agent communication (Lead Expert holds all state) + """ def __init__( self, @@ -51,7 +62,6 @@ class ExpertTeam: self._handoff_transport = InProcessHandoffTransport() self._experts: dict[str, Expert] = {} self._lead_expert_name: str | None = None - self._plan: CollaborationPlan | None = None self._status = TeamStatus.FORMING self._team_channel = f"team:{self.team_id}" self._orchestrator_task: asyncio.Task | None = None @@ -66,10 +76,6 @@ class ExpertTeam: return self._experts.get(self._lead_expert_name) return None - @property - def plan(self) -> CollaborationPlan | None: - return self._plan - @property def experts(self) -> list[Expert]: return list(self._experts.values()) @@ -80,12 +86,21 @@ class ExpertTeam: @property def workspace(self) -> SharedWorkspace: - """Public read access to the team's shared workspace.""" + """Public read access to the team's shared workspace. + + In hub-and-spoke mode, workspace is used for output preservation only, + not for cross-phase state sharing. + """ return self._workspace @property def handoff_transport(self): - """Public read access to the team's handoff transport.""" + """Public read access to the team's handoff transport. + + In hub-and-spoke mode, handoff_transport is used for event broadcasting + (team_formed, expert_step, expert_result, team_synthesis, team_dissolved), + not for inter-agent communication. + """ return self._handoff_transport @property @@ -106,7 +121,13 @@ class ExpertTeam: lead_config: ExpertConfig, member_configs: list[ExpertConfig] | None = None, ) -> None: - """Create a team with a Lead Expert and optional members.""" + """Create a team with a Lead Expert and optional members. + + In hub-and-spoke mode, the Lead Expert acts as the hub: + - Receives the task and decomposes it into subtasks + - Dispatches subtasks to member experts (spokes) + - Synthesizes the final result + """ if not self._pool: raise RuntimeError("AgentPool not configured") @@ -128,7 +149,7 @@ class ExpertTeam: for config in member_configs: await self._add_expert_internal(config, team_context) - self._status = TeamStatus.PLANNING + self._status = TeamStatus.EXECUTING async def add_expert(self, config_or_template: ExpertConfig | str) -> Expert: """Add an Expert to the team dynamically. @@ -155,9 +176,7 @@ class ExpertTeam: ) return await self._add_expert_internal(config, team_context) - async def _add_expert_internal( - self, config: ExpertConfig, team_context: str - ) -> Expert: + async def _add_expert_internal(self, config: ExpertConfig, team_context: str) -> Expert: """Internal method to add an Expert.""" if not self._pool: raise RuntimeError("AgentPool not configured") @@ -204,13 +223,6 @@ class ExpertTeam: await expert.destroy(self._pool) del self._experts[name] - # Update plan: reassign phases that referenced the removed expert - if self._plan: - new_lead_name = self._lead_expert_name - for phase in self._plan.phases: - if phase.assigned_expert == name: - phase.assigned_expert = new_lead_name or "" - # Broadcast expert left await self._handoff_transport.send( self._team_channel, @@ -220,24 +232,6 @@ class ExpertTeam: }, ) - def update_plan(self, plan: CollaborationPlan) -> list[str]: - """Update the collaboration plan. Only Lead Expert or user should call this. - - Returns list of affected expert names on success, or list of validation - error strings on failure (empty list with no errors = success). - """ - errors = plan.validate() - if errors: - return errors # Return validation errors instead of silently swallowing - - self._plan = plan - if plan.status == PlanStatus.CONFIRMED: - self._status = TeamStatus.EXECUTING - - # Determine affected experts - affected = [p.assigned_expert for p in plan.phases] - return affected - async def broadcast_user_message(self, content: str) -> None: """Broadcast a user intervention message to all active Experts.""" message = { @@ -248,7 +242,11 @@ class ExpertTeam: await self._handoff_transport.send(self._team_channel, message) async def get_shared_context(self) -> dict: - """Get the team's shared context from SharedWorkspace.""" + """Get the team's shared context from SharedWorkspace. + + In hub-and-spoke mode, this returns preserved outputs only, + not cross-phase state. + """ context = {} keys = await self._workspace.list_keys() for key in keys: @@ -258,23 +256,6 @@ class ExpertTeam: context[key] = data return context - async def generate_plan(self, task: str) -> CollaborationPlan: - """Generate a CollaborationPlan for the task. - - Uses hybrid mode: core roles from template registry, auxiliary roles dynamically generated. - This method creates a plan structure — the actual LLM-based task decomposition - will be handled by TeamOrchestrator. - """ - plan_id = str(uuid.uuid4()) - plan = CollaborationPlan( - id=plan_id, - task=task, - phases=[], - lead_expert=self._lead_expert_name or "", - ) - self._plan = plan - return plan - async def dissolve(self) -> None: """Dissolve the team. Temporary Experts are recycled, outputs preserved in SharedWorkspace.""" # Cancel ongoing orchestrator task if any @@ -302,20 +283,31 @@ class ExpertTeam: lead_config: ExpertConfig | None, member_configs: list[ExpertConfig], ) -> str: - """Build team context string for injection into Expert system prompts.""" - lines = ["You are part of an Expert Team."] + """Build team context string for injection into Expert system prompts. + + In hub-and-spoke mode, the context emphasizes the Lead Expert's role + as the hub and member experts' role as spokes. + """ + lines = ["You are part of an Expert Team (hub-and-spoke mode)."] if lead_config: - lines.append(f"Lead Expert: {lead_config.name} ({lead_config.persona})") + lines.append( + f"Lead Expert (hub): {lead_config.name} ({lead_config.persona}) — " + f"decomposes tasks, dispatches subtasks, and synthesizes results." + ) for config in member_configs: if lead_config and config.name == lead_config.name: continue lines.append( - f"Team Member: {config.name} ({config.persona}), Skills: {', '.join(config.bound_skills)}" + f"Team Member (spoke): {config.name} ({config.persona}), " + f"Skills: {', '.join(config.bound_skills)} — " + f"executes assigned subtasks independently." ) lines.append( - "You can collaborate with other team members via send_message() and request_assist()." + "In hub-and-spoke mode: Lead Expert holds all state, " + "subtasks are independent (depth=1, no further spawning), " + "no inter-agent communication." ) return "\n".join(lines) diff --git a/src/agentkit/tools/__init__.py b/src/agentkit/tools/__init__.py index 4d4ddb9..315b63e 100644 --- a/src/agentkit/tools/__init__.py +++ b/src/agentkit/tools/__init__.py @@ -16,6 +16,8 @@ from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType from agentkit.tools.ask_human import AskHumanTool from agentkit.tools.memory_tool import MemoryTool from agentkit.tools.web_search import WebSearchTool +from agentkit.tools.builtin import RunTestsTool, ToolSearchTool +from agentkit.tools.search import ToolSearchIndex # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor try: @@ -40,6 +42,9 @@ __all__ = [ "MemoryTool", "ShellTool", "WebSearchTool", + "RunTestsTool", + "ToolSearchTool", + "ToolSearchIndex", "HeadroomRetrieveTool", "TerminalSession", "TerminalSessionManager", diff --git a/src/agentkit/tools/search.py b/src/agentkit/tools/search.py new file mode 100644 index 0000000..30c4140 --- /dev/null +++ b/src/agentkit/tools/search.py @@ -0,0 +1,189 @@ +"""Tool search index using BM25 algorithm. + +Provides lexical tool discovery for tiered tool description injection: +core tools are fully injected into the LLM prompt, while extended tools +are searchable on-demand via the ``tool_search`` tool. + +Pure Python implementation — no external dependencies. +""" + +from __future__ import annotations + +import math +import re +from collections import Counter +from typing import Any + +from agentkit.tools.base import Tool + +__all__ = ["ToolSearchIndex"] + + +_TOKEN_RE = re.compile(r"[a-zA-Z0-9_]+") + + +def _tokenize(text: str) -> list[str]: + """Tokenize text into lowercase alphanumeric tokens. + + Splits on non-alphanumeric characters (keeping underscores so that + snake_case tool names like ``read_file`` stay intact) and lowercases. + Empty tokens are filtered out. + """ + return [t for t in _TOKEN_RE.findall(text.lower()) if t] + + +class ToolSearchIndex: + """BM25-based search index over :class:`Tool` descriptions. + + Each tool is indexed by its ``name`` + ``description`` + parameter + metadata (from ``input_schema``) + ``tags``. Supports keyword search + returning the most relevant tools ranked by BM25 score. + + Usage:: + + index = ToolSearchIndex(tools) + results = index.search("read file", top_k=5) + """ + + _DEFAULT_K1: float = 1.5 + _DEFAULT_B: float = 0.75 + + def __init__( + self, + tools: list[Tool], + k1: float = _DEFAULT_K1, + b: float = _DEFAULT_B, + ): + if k1 < 0: + raise ValueError(f"k1 must be >= 0, got {k1}") + if not (0.0 <= b <= 1.0): + raise ValueError(f"b must be in [0, 1], got {b}") + + self._tools: list[Tool] = list(tools) + self._k1 = k1 + self._b = b + + # Build document corpus from tool metadata + self._documents: list[str] = [self._tool_to_text(t) for t in self._tools] + self._tokenized_docs: list[list[str]] = [_tokenize(doc) for doc in self._documents] + self._doc_lengths: list[int] = [len(tokens) for tokens in self._tokenized_docs] + self._avgdl: float = ( + sum(self._doc_lengths) / len(self._doc_lengths) if self._doc_lengths else 0.0 + ) + + # Term frequencies per document + self._term_freqs: list[Counter[str]] = [Counter(tokens) for tokens in self._tokenized_docs] + + # Document frequency per term (number of docs containing the term) + self._doc_freq: Counter[str] = Counter() + for tokens in self._tokenized_docs: + for term in set(tokens): + self._doc_freq[term] += 1 + + # Precompute IDF for every term in the corpus + self._idf: dict[str, float] = self._compute_idf() + + # ------------------------------------------------------------------ + # Index construction + # ------------------------------------------------------------------ + + @staticmethod + def _tool_to_text(tool: Tool) -> str: + """Convert a tool's searchable metadata into a single text document.""" + parts: list[str] = [str(tool.name), str(tool.description)] + + schema: dict[str, Any] | None = tool.input_schema + if schema: + props = schema.get("properties", {}) + for pname, pinfo in props.items(): + parts.append(str(pname)) + if isinstance(pinfo, dict): + pdesc = pinfo.get("description", "") + if pdesc: + parts.append(str(pdesc)) + + if tool.tags: + parts.extend(str(t) for t in tool.tags) + + return " ".join(parts) + + def _compute_idf(self) -> dict[str, float]: + """Compute IDF for each term using the BM25 formula. + + ``idf(t) = ln((N - n + 0.5) / (n + 0.5) + 1)`` + + where ``N`` is the total number of documents and ``n`` is the + number of documents containing term ``t``. The ``+1`` inside the + logarithm guarantees a non-negative IDF. + """ + n_docs = len(self._tools) + if n_docs == 0: + return {} + + idf: dict[str, float] = {} + for term, df in self._doc_freq.items(): + idf[term] = math.log((n_docs - df + 0.5) / (df + 0.5) + 1.0) + return idf + + # ------------------------------------------------------------------ + # Scoring + # ------------------------------------------------------------------ + + def _bm25_score(self, query_tokens: list[str], doc_idx: int) -> float: + """Compute the BM25 score for a single document against the query.""" + if not query_tokens or doc_idx >= len(self._tools): + return 0.0 + + doc_len = self._doc_lengths[doc_idx] + term_freq = self._term_freqs[doc_idx] + norm = ( + (1 - self._b + self._b * (doc_len / self._avgdl)) if self._avgdl > 0 else (1 - self._b) + ) + score = 0.0 + + # Each unique query term contributes once + for term in set(query_tokens): + tf = term_freq.get(term, 0) + if tf == 0: + continue + idf = self._idf.get(term, 0.0) + if idf <= 0: + continue + denom = tf + self._k1 * norm + score += idf * (tf * (self._k1 + 1)) / denom + + return score + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def search(self, query: str, top_k: int = 5) -> list[Tool]: + """Search tools by keyword query. + + Args: + query: Search keywords (natural language or tool-name fragments). + top_k: Maximum number of tools to return. + + Returns: + Tools ranked by relevance (most relevant first). Only tools + with a positive BM25 score are returned. + """ + if top_k <= 0 or not self._tools: + return [] + + query_tokens = _tokenize(query) + if not query_tokens: + return [] + + scored: list[tuple[int, float]] = [] + for i in range(len(self._tools)): + score = self._bm25_score(query_tokens, i) + if score > 0: + scored.append((i, score)) + + scored.sort(key=lambda x: x[1], reverse=True) + return [self._tools[i] for i, _ in scored[:top_k]] + + def __len__(self) -> int: + return len(self._tools) diff --git a/tests/unit/core/test_event_queue.py b/tests/unit/core/test_event_queue.py new file mode 100644 index 0000000..c630942 --- /dev/null +++ b/tests/unit/core/test_event_queue.py @@ -0,0 +1,666 @@ +"""Tests for EventQueue — SQ/EQ 双队列实现 + +测试场景: +- SQ 正确接收用户输入并返回 task_id +- EQ 正确推送事件给订阅者 +- 多订阅者同时接收事件(广播) +- 事件缓冲对新订阅者的回放 +- SQ 取消任务 +- 事件类型正确分类 +""" + +from __future__ import annotations + +import asyncio +from datetime import datetime + +from agentkit.core.event_queue import EventQueue, Submission, SubmissionQueue +from agentkit.core.protocol import ( + Event, + SessionEventType, + TaskEventType, + TurnEventType, +) + + +# ── SubmissionQueue Tests ─────────────────────────────────────── + + +class TestSubmissionQueue: + """SubmissionQueue 单元测试""" + + async def test_submit_returns_task_id(self): + """测试 submit 返回有效的 task_id""" + sq = SubmissionQueue() + task_id = await sq.submit("hello", "session-1") + + assert isinstance(task_id, str) + assert len(task_id) > 0 + + async def test_submit_returns_unique_task_ids(self): + """测试每次 submit 返回不同的 task_id""" + sq = SubmissionQueue() + task_id_1 = await sq.submit("hello", "session-1") + task_id_2 = await sq.submit("world", "session-1") + + assert task_id_1 != task_id_2 + + async def test_submit_stores_submission(self): + """测试 submit 正确存储提交内容""" + sq = SubmissionQueue() + task_id = await sq.submit("hello world", "session-1") + + assert task_id in sq._submissions + submission = sq._submissions[task_id] + assert submission.content == "hello world" + assert submission.session_id == "session-1" + assert submission.task_id == task_id + assert submission.cancelled is False + + async def test_drain_receives_submissions_in_order(self): + """测试 drain 按提交顺序接收提交""" + sq = SubmissionQueue() + await sq.submit("first", "session-1") + await sq.submit("second", "session-1") + + received: list[str] = [] + + async def consumer(): + async for submission in sq.drain(): + received.append(submission.content) + if len(received) >= 2: + break + + consumer_task = asyncio.create_task(consumer()) + await asyncio.wait_for(consumer_task, timeout=1.0) + + assert received == ["first", "second"] + + async def test_drain_preserves_submission_fields(self): + """测试 drain 返回的 Submission 字段完整""" + sq = SubmissionQueue() + await sq.submit("hello", "session-1") + + received: list[Submission] = [] + + async def consumer(): + async for submission in sq.drain(): + received.append(submission) + break + + consumer_task = asyncio.create_task(consumer()) + await asyncio.wait_for(consumer_task, timeout=1.0) + + assert len(received) == 1 + sub = received[0] + assert sub.content == "hello" + assert sub.session_id == "session-1" + assert isinstance(sub.task_id, str) + assert isinstance(sub.created_at, datetime) + + async def test_cancel_task_succeeds(self): + """测试取消已存在的任务""" + sq = SubmissionQueue() + task_id = await sq.submit("hello", "session-1") + + result = await sq.cancel(task_id) + + assert result is True + assert task_id in sq._cancelled_tasks + assert sq._submissions[task_id].cancelled is True + + async def test_cancel_nonexistent_task_returns_false(self): + """测试取消不存在的任务返回 False""" + sq = SubmissionQueue() + result = await sq.cancel("nonexistent-task-id") + + assert result is False + + async def test_cancel_already_cancelled_task_returns_false(self): + """测试重复取消返回 False""" + sq = SubmissionQueue() + task_id = await sq.submit("hello", "session-1") + + first_cancel = await sq.cancel(task_id) + second_cancel = await sq.cancel(task_id) + + assert first_cancel is True + assert second_cancel is False + + async def test_drain_skips_cancelled_submissions(self): + """测试 drain 跳过已取消的提交""" + sq = SubmissionQueue() + task_id_1 = await sq.submit("first", "session-1") + await sq.submit("second", "session-1") + + # 取消第一个提交 + await sq.cancel(task_id_1) + + received: list[str] = [] + + async def consumer(): + async for submission in sq.drain(): + received.append(submission.content) + if len(received) >= 1: + break + + consumer_task = asyncio.create_task(consumer()) + await asyncio.wait_for(consumer_task, timeout=1.0) + + # 应只收到第二个提交 + assert received == ["second"] + + async def test_close_prevents_new_submissions(self): + """测试关闭后不再接受新提交""" + sq = SubmissionQueue() + sq.close() + + assert sq.is_closed is True + + try: + await sq.submit("hello", "session-1") + raise AssertionError("Should have raised RuntimeError") + except RuntimeError: + pass + + async def test_close_does_not_affect_existing_submissions(self): + """测试关闭后已提交的内容仍可消费""" + sq = SubmissionQueue() + await sq.submit("before-close", "session-1") + sq.close() + + received: list[str] = [] + + async def consumer(): + async for submission in sq.drain(): + received.append(submission.content) + break + + consumer_task = asyncio.create_task(consumer()) + await asyncio.wait_for(consumer_task, timeout=1.0) + + assert received == ["before-close"] + + +# ── EventQueue Tests ──────────────────────────────────────────── + + +class TestEventQueue: + """EventQueue 单元测试""" + + async def test_emit_and_subscribe_single_event(self): + """测试 EQ 正确推送事件给订阅者""" + eq = EventQueue() + event = Event.create( + event_type=TurnEventType.TOKEN, + task_id="task-1", + session_id="session-1", + data={"text": "hello"}, + ) + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(): + received.append(evt) + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) # 给订阅者启动时间 + + await eq.emit(event) + await asyncio.wait_for(sub_task, timeout=1.0) + + assert len(received) == 1 + assert received[0].event_type == TurnEventType.TOKEN + assert received[0].task_id == "task-1" + assert received[0].session_id == "session-1" + assert received[0].data == {"text": "hello"} + + async def test_emit_preserves_event_fields(self): + """测试 emit 不修改事件字段""" + eq = EventQueue() + original = Event.create( + event_type=TaskEventType.TASK_STARTED, + task_id="task-1", + session_id="session-1", + data={"agent": "react"}, + ) + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(): + received.append(evt) + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) + + await eq.emit(original) + await asyncio.wait_for(sub_task, timeout=1.0) + + assert received[0] is original or received[0].to_dict() == original.to_dict() + + async def test_broadcast_to_multiple_subscribers(self): + """测试多订阅者同时接收事件(广播)""" + eq = EventQueue() + + received_a: list[Event] = [] + received_b: list[Event] = [] + + async def subscriber_a(): + async for evt in eq.subscribe(): + received_a.append(evt) + if len(received_a) >= 2: + break + + async def subscriber_b(): + async for evt in eq.subscribe(): + received_b.append(evt) + if len(received_b) >= 2: + break + + task_a = asyncio.create_task(subscriber_a()) + task_b = asyncio.create_task(subscriber_b()) + await asyncio.sleep(0.05) # 给订阅者启动时间 + + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2})) + + await asyncio.wait_for(task_a, timeout=1.0) + await asyncio.wait_for(task_b, timeout=1.0) + + assert len(received_a) == 2 + assert len(received_b) == 2 + assert received_a[0].data == {"seq": 1} + assert received_a[1].data == {"seq": 2} + assert received_b[0].data == {"seq": 1} + assert received_b[1].data == {"seq": 2} + + async def test_buffer_replay_for_new_subscriber(self): + """测试事件缓冲对新订阅者的回放""" + eq = EventQueue(buffer_size=100) + + # 先发送几条事件(无订阅者) + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2})) + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3})) + + # 新订阅者应收到缓冲回放 + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(): + received.append(evt) + if len(received) >= 3: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.wait_for(sub_task, timeout=1.0) + + assert len(received) == 3 + assert received[0].data == {"seq": 1} + assert received[1].data == {"seq": 2} + assert received[2].data == {"seq": 3} + + async def test_buffer_replay_then_live_events(self): + """测试新订阅者先收到回放,再收到新事件""" + eq = EventQueue(buffer_size=100) + + # 先发送 2 条事件(进入缓冲) + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 2})) + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(): + received.append(evt) + if len(received) >= 4: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) # 给订阅者启动时间(回放缓冲) + + # 再发送 2 条新事件 + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 3})) + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 4})) + + await asyncio.wait_for(sub_task, timeout=1.0) + + assert len(received) == 4 + assert [r.data["seq"] for r in received] == [1, 2, 3, 4] + + async def test_buffer_size_limit_keeps_latest(self): + """测试缓冲区大小限制,只保留最新 N 条""" + eq = EventQueue(buffer_size=3) + + # 发送 5 条事件,缓冲区只保留最后 3 条 + for i in range(5): + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": i})) + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(): + received.append(evt) + if len(received) >= 3: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.wait_for(sub_task, timeout=1.0) + + assert len(received) == 3 + # 应该是最后 3 条(seq: 2, 3, 4) + assert [r.data["seq"] for r in received] == [2, 3, 4] + + async def test_default_buffer_size_is_100(self): + """测试默认缓冲区大小为 100""" + eq = EventQueue() + + assert eq.buffer_size == 100 + + async def test_close_unblocks_subscribers(self): + """测试 close 解除订阅者阻塞""" + eq = EventQueue() + + async def subscriber(): + async for _ in eq.subscribe(): + pass # 消费事件直到队列关闭 + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) + + eq.close() + await asyncio.wait_for(sub_task, timeout=1.0) + + assert sub_task.done() + assert eq.is_closed is True + + async def test_subscribe_after_close_returns_immediately(self): + """测试关闭后订阅立即返回(不阻塞)""" + eq = EventQueue() + eq.close() + + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(): + received.append(evt) + + sub_task = asyncio.create_task(subscriber()) + await asyncio.wait_for(sub_task, timeout=1.0) + + assert sub_task.done() + assert len(received) == 0 + + async def test_subscriber_count_tracks_subscriptions(self): + """测试订阅者计数正确跟踪订阅""" + eq = EventQueue() + + assert eq.subscriber_count == 0 + + async def subscriber(): + async for _ in eq.subscribe(): + pass + + task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) + + assert eq.subscriber_count == 1 + + eq.close() + await asyncio.wait_for(task, timeout=1.0) + + assert eq.subscriber_count == 0 + + async def test_subscriber_removed_on_explicit_close(self): + """测试显式关闭订阅生成器后从列表移除 + + 注意:async for 的 break 不会立即触发生成器的 finally, + 需要显式调用 aclose() 才能保证清理。 + """ + eq = EventQueue() + + received: list[Event] = [] + + async def subscriber(): + gen = eq.subscribe() + try: + async for evt in gen: + received.append(evt) + break + finally: + await gen.aclose() + + task = asyncio.create_task(subscriber()) + await asyncio.sleep(0.05) + assert eq.subscriber_count == 1 + + # 触发一次 emit 让订阅者能收到事件并 break + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) + await asyncio.wait_for(task, timeout=1.0) + + assert len(received) == 1 + assert eq.subscriber_count == 0 + + async def test_emit_to_no_subscribers_still_buffers(self): + """测试无订阅者时 emit 仍写入缓冲区""" + eq = EventQueue(buffer_size=100) + + await eq.emit(Event.create(TurnEventType.TOKEN, "task-1", "session-1", {"seq": 1})) + + # 缓冲区应有 1 条 + assert len(eq._buffer) == 1 + + # 新订阅者应收到回放 + received: list[Event] = [] + + async def subscriber(): + async for evt in eq.subscribe(): + received.append(evt) + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.wait_for(sub_task, timeout=1.0) + + assert len(received) == 1 + assert received[0].data == {"seq": 1} + + +# ── Event Type Tests ──────────────────────────────────────────── + + +class TestEventTypes: + """事件类型分类测试""" + + def test_session_event_types(self): + """测试 Session 级别事件类型""" + assert SessionEventType.SESSION_STARTED == "session.started" + assert SessionEventType.SESSION_ENDED == "session.ended" + + def test_task_event_types(self): + """测试 Task 级别事件类型""" + assert TaskEventType.TASK_CREATED == "task.created" + assert TaskEventType.TASK_STARTED == "task.started" + assert TaskEventType.TASK_COMPLETED == "task.completed" + assert TaskEventType.TASK_FAILED == "task.failed" + + def test_turn_event_types(self): + """测试 Turn 级别事件类型""" + assert TurnEventType.TURN_STARTED == "turn.started" + assert TurnEventType.THINKING == "turn.thinking" + assert TurnEventType.TOOL_CALL == "turn.tool_call" + assert TurnEventType.TOOL_RESULT == "turn.tool_result" + assert TurnEventType.TOKEN == "turn.token" + assert TurnEventType.STEP == "turn.step" + assert TurnEventType.FINAL_ANSWER == "turn.final_answer" + assert TurnEventType.TURN_COMPLETED == "turn.completed" + + def test_event_type_prefixes(self): + """测试事件类型按前缀正确分类""" + session_events = [SessionEventType.SESSION_STARTED, SessionEventType.SESSION_ENDED] + task_events = [ + TaskEventType.TASK_CREATED, + TaskEventType.TASK_STARTED, + TaskEventType.TASK_COMPLETED, + TaskEventType.TASK_FAILED, + ] + turn_events = [ + TurnEventType.TURN_STARTED, + TurnEventType.THINKING, + TurnEventType.TOOL_CALL, + TurnEventType.TOOL_RESULT, + TurnEventType.TOKEN, + TurnEventType.STEP, + TurnEventType.FINAL_ANSWER, + TurnEventType.TURN_COMPLETED, + ] + + for evt in session_events: + assert evt.startswith("session."), f"{evt} should start with 'session.'" + for evt in task_events: + assert evt.startswith("task."), f"{evt} should start with 'task.'" + for evt in turn_events: + assert evt.startswith("turn."), f"{evt} should start with 'turn.'" + + def test_event_types_are_distinct(self): + """测试所有事件类型互不相同""" + all_types = [ + SessionEventType.SESSION_STARTED, + SessionEventType.SESSION_ENDED, + TaskEventType.TASK_CREATED, + TaskEventType.TASK_STARTED, + TaskEventType.TASK_COMPLETED, + TaskEventType.TASK_FAILED, + TurnEventType.TURN_STARTED, + TurnEventType.THINKING, + TurnEventType.TOOL_CALL, + TurnEventType.TOOL_RESULT, + TurnEventType.TOKEN, + TurnEventType.STEP, + TurnEventType.FINAL_ANSWER, + TurnEventType.TURN_COMPLETED, + ] + + assert len(all_types) == len(set(all_types)), "Event types should be distinct" + + +# ── Event Dataclass Tests ─────────────────────────────────────── + + +class TestEventDataclass: + """Event 数据结构测试""" + + def test_event_creation(self): + """测试 Event 创建""" + event = Event( + event_type=TurnEventType.TOKEN, + task_id="task-1", + session_id="session-1", + data={"text": "hello"}, + timestamp="2025-01-01T00:00:00+00:00", + ) + + assert event.event_type == TurnEventType.TOKEN + assert event.task_id == "task-1" + assert event.session_id == "session-1" + assert event.data == {"text": "hello"} + assert event.timestamp == "2025-01-01T00:00:00+00:00" + + def test_event_create_factory_generates_timestamp(self): + """测试 Event.create 工厂方法自动生成时间戳""" + event = Event.create( + event_type=TaskEventType.TASK_STARTED, + task_id="task-1", + session_id="session-1", + data={"agent": "react"}, + ) + + assert event.event_type == TaskEventType.TASK_STARTED + assert event.task_id == "task-1" + assert event.session_id == "session-1" + assert event.data == {"agent": "react"} + assert len(event.timestamp) > 0 + # 时间戳应为 ISO 8601 格式(fromisoformat 不抛异常即正确) + datetime.fromisoformat(event.timestamp) + + def test_event_create_with_default_data(self): + """测试 Event.create 不传 data 时默认为空 dict""" + event = Event.create( + event_type=SessionEventType.SESSION_STARTED, + task_id="task-1", + session_id="session-1", + ) + + assert event.data == {} + + def test_event_to_dict(self): + """测试 Event.to_dict""" + event = Event( + event_type=TurnEventType.TOKEN, + task_id="task-1", + session_id="session-1", + data={"text": "hello"}, + timestamp="2025-01-01T00:00:00+00:00", + ) + + d = event.to_dict() + + assert d == { + "event_type": TurnEventType.TOKEN, + "task_id": "task-1", + "session_id": "session-1", + "data": {"text": "hello"}, + "timestamp": "2025-01-01T00:00:00+00:00", + } + + def test_event_from_dict(self): + """测试 Event.from_dict""" + data = { + "event_type": TurnEventType.TOKEN, + "task_id": "task-1", + "session_id": "session-1", + "data": {"text": "hello"}, + "timestamp": "2025-01-01T00:00:00+00:00", + } + + event = Event.from_dict(data) + + assert event.event_type == TurnEventType.TOKEN + assert event.task_id == "task-1" + assert event.session_id == "session-1" + assert event.data == {"text": "hello"} + assert event.timestamp == "2025-01-01T00:00:00+00:00" + + def test_event_to_dict_from_dict_roundtrip(self): + """测试 to_dict -> from_dict 往返保持一致""" + original = Event.create( + event_type=SessionEventType.SESSION_STARTED, + task_id="task-1", + session_id="session-1", + data={"user": "alice"}, + ) + + d = original.to_dict() + restored = Event.from_dict(d) + + assert restored.event_type == original.event_type + assert restored.task_id == original.task_id + assert restored.session_id == original.session_id + assert restored.data == original.data + assert restored.timestamp == original.timestamp + + def test_event_from_dict_with_missing_data_defaults_to_empty(self): + """测试 from_dict 缺少 data 字段时默认为空 dict""" + data = { + "event_type": TurnEventType.TOKEN, + "task_id": "task-1", + "session_id": "session-1", + "timestamp": "2025-01-01T00:00:00+00:00", + } + + event = Event.from_dict(data) + + assert event.data == {} diff --git a/tests/unit/experts/test_plan.py b/tests/unit/experts/test_plan.py index a4d9d30..f75bccb 100644 --- a/tests/unit/experts/test_plan.py +++ b/tests/unit/experts/test_plan.py @@ -1,328 +1,288 @@ -"""CollaborationPlan 数据模型单元测试""" +"""TeamPlan / SubTask 数据模型单元测试 (hub-and-spoke 模式)""" from __future__ import annotations -import pytest - from agentkit.experts.plan import ( - CollaborationPlan, MergeStrategy, - ParallelType, - PhaseStatus, - PlanPhase, PlanStatus, + SubTask, + SubTaskStatus, + TeamPlan, ) # ── 辅助函数 ────────────────────────────────────────────── -def _make_phase( - id: str = "phase_1", - name: str = "分析阶段", +def _make_subtask( + id: str = "subtask_1", + description: str = "分析数据", assigned_expert: str = "analyst", - task_description: str = "分析需求", - depends_on: list[str] | None = None, - parallel_type: ParallelType = ParallelType.SERIAL, - merge_strategy: MergeStrategy | None = None, - milestone: str = "", - status: PhaseStatus = PhaseStatus.PENDING, + status: SubTaskStatus = SubTaskStatus.PENDING, result: dict | None = None, -) -> PlanPhase: - """创建测试用 PlanPhase 实例""" - return PlanPhase( +) -> SubTask: + """创建测试用 SubTask 实例""" + return SubTask( id=id, - name=name, + description=description, assigned_expert=assigned_expert, - task_description=task_description, - depends_on=depends_on or [], - parallel_type=parallel_type, - merge_strategy=merge_strategy, - milestone=milestone, status=status, result=result, ) -def _make_valid_plan() -> CollaborationPlan: - """创建一个有效的协作计划""" - phases = [ - _make_phase(id="p1", name="需求分析", assigned_expert="analyst", task_description="分析需求"), - _make_phase( - id="p2", - name="架构设计", - assigned_expert="architect", - task_description="设计架构", - depends_on=["p1"], - ), - _make_phase( - id="p3", - name="代码实现", - assigned_expert="coder", - task_description="编写代码", - depends_on=["p2"], - ), +def _make_valid_plan() -> TeamPlan: + """创建一个有效的 hub-and-spoke 执行计划""" + subtasks = [ + _make_subtask(id="s1", description="分析需求", assigned_expert="analyst"), + _make_subtask(id="s2", description="设计架构", assigned_expert="architect"), + _make_subtask(id="s3", description="编写代码", assigned_expert="coder"), ] - return CollaborationPlan( + return TeamPlan( id="plan_001", task="实现用户登录功能", - phases=phases, - variables={"project": "fischer"}, + subtasks=subtasks, status=PlanStatus.DRAFT, lead_expert="architect", ) -# ── PlanPhase 测试 ──────────────────────────────────────── +# ── MergeStrategy 测试 ──────────────────────────────────── -class TestPlanPhase: - """PlanPhase 数据模型测试""" +class TestMergeStrategy: + """MergeStrategy 枚举测试""" + + def test_only_best_strategy_exists(self): + """hub-and-spoke 模式仅保留 BEST 策略""" + assert MergeStrategy.BEST == "best" + # 确保只有 BEST 一个值 + assert len(list(MergeStrategy)) == 1 + + def test_no_vote_or_fusion(self): + """VOTE 和 FUSION 已被移除""" + assert not hasattr(MergeStrategy, "VOTE") + assert not hasattr(MergeStrategy, "FUSION") + + +# ── PlanStatus 测试 ─────────────────────────────────────── + + +class TestPlanStatus: + """PlanStatus 枚举测试""" + + def test_statuses_exist(self): + """必要的计划状态都存在""" + assert PlanStatus.DRAFT == "draft" + assert PlanStatus.EXECUTING == "executing" + assert PlanStatus.COMPLETED == "completed" + assert PlanStatus.FAILED == "failed" + assert PlanStatus.FALLBACK == "fallback" + + def test_no_confirmed_status(self): + """CONFIRMED 状态已被移除(hub-and-spoke 无需确认阶段)""" + assert not hasattr(PlanStatus, "CONFIRMED") + + +# ── SubTaskStatus 测试 ──────────────────────────────────── + + +class TestSubTaskStatus: + """SubTaskStatus 枚举测试""" + + def test_statuses_exist(self): + """子任务状态都存在""" + assert SubTaskStatus.PENDING == "pending" + assert SubTaskStatus.RUNNING == "running" + assert SubTaskStatus.COMPLETED == "completed" + assert SubTaskStatus.FAILED == "failed" + + +# ── SubTask 测试 ────────────────────────────────────────── + + +class TestSubTask: + """SubTask 数据模型测试""" def test_creation_with_all_fields(self): - """创建 PlanPhase 并设置所有字段""" - phase = PlanPhase( - id="phase_a", - name="竞品分析", + """创建 SubTask 并设置所有字段""" + subtask = SubTask( + id="subtask_a", + description="竞品分析", assigned_expert="analyst", - task_description="分析竞品功能", - depends_on=["phase_0"], - parallel_type=ParallelType.COMPETITIVE_PARALLEL, - merge_strategy=MergeStrategy.BEST, - milestone="竞品报告完成", - status=PhaseStatus.IN_PROGRESS, + status=SubTaskStatus.RUNNING, result={"report": "竞品分析报告"}, ) - assert phase.id == "phase_a" - assert phase.name == "竞品分析" - assert phase.assigned_expert == "analyst" - assert phase.task_description == "分析竞品功能" - assert phase.depends_on == ["phase_0"] - assert phase.parallel_type == ParallelType.COMPETITIVE_PARALLEL - assert phase.merge_strategy == MergeStrategy.BEST - assert phase.milestone == "竞品报告完成" - assert phase.status == PhaseStatus.IN_PROGRESS - assert phase.result == {"report": "竞品分析报告"} + assert subtask.id == "subtask_a" + assert subtask.description == "竞品分析" + assert subtask.assigned_expert == "analyst" + assert subtask.status == SubTaskStatus.RUNNING + assert subtask.result == {"report": "竞品分析报告"} + + def test_default_values(self): + """默认值:自动生成 id,PENDING 状态""" + subtask = SubTask(description="测试任务") + assert subtask.id is not None + assert len(subtask.id) > 0 + assert subtask.description == "测试任务" + assert subtask.assigned_expert == "" + assert subtask.status == SubTaskStatus.PENDING + assert subtask.result is None def test_to_dict_from_dict_roundtrip(self): """to_dict / from_dict 往返序列化""" - phase = PlanPhase( - id="roundtrip_phase", - name="往返测试", + subtask = SubTask( + id="roundtrip_subtask", + description="往返测试", assigned_expert="tester", - task_description="测试序列化", - depends_on=["dep_a", "dep_b"], - parallel_type=ParallelType.SUBTASK_PARALLEL, - merge_strategy=MergeStrategy.VOTE, - milestone="序列化验证", - status=PhaseStatus.COMPLETED, + status=SubTaskStatus.COMPLETED, result={"key": "value"}, ) - d = phase.to_dict() - restored = PlanPhase.from_dict(d) + d = subtask.to_dict() + restored = SubTask.from_dict(d) - assert restored.id == phase.id - assert restored.name == phase.name - assert restored.assigned_expert == phase.assigned_expert - assert restored.task_description == phase.task_description - assert restored.depends_on == phase.depends_on - assert restored.parallel_type == phase.parallel_type - assert restored.merge_strategy == phase.merge_strategy - assert restored.milestone == phase.milestone - assert restored.status == phase.status - assert restored.result == phase.result + assert restored.id == subtask.id + assert restored.description == subtask.description + assert restored.assigned_expert == subtask.assigned_expert + assert restored.status == subtask.status + assert restored.result == subtask.result - def test_to_dict_from_dict_with_none_merge_strategy(self): - """merge_strategy 为 None 时的序列化往返""" - phase = PlanPhase( - id="no_merge", - name="无合并", + def test_to_dict_structure(self): + """to_dict 返回正确的字典结构""" + subtask = _make_subtask( + id="struct_test", + description="结构测试", assigned_expert="dev", - task_description="串行任务", - parallel_type=ParallelType.SERIAL, + status=SubTaskStatus.RUNNING, + result={"output": "data"}, ) - d = phase.to_dict() - assert d["merge_strategy"] is None - restored = PlanPhase.from_dict(d) - assert restored.merge_strategy is None + d = subtask.to_dict() + assert d["id"] == "struct_test" + assert d["description"] == "结构测试" + assert d["assigned_expert"] == "dev" + assert d["status"] == "running" + assert d["result"] == {"output": "data"} -# ── CollaborationPlan 测试 ──────────────────────────────── +# ── TeamPlan 测试 ───────────────────────────────────────── -class TestCollaborationPlan: - """CollaborationPlan 数据模型测试""" +class TestTeamPlan: + """TeamPlan 数据模型测试""" def test_creation(self): - """创建 CollaborationPlan""" + """创建 TeamPlan""" plan = _make_valid_plan() assert plan.id == "plan_001" assert plan.task == "实现用户登录功能" - assert len(plan.phases) == 3 - assert plan.variables == {"project": "fischer"} + assert len(plan.subtasks) == 3 assert plan.status == PlanStatus.DRAFT assert plan.lead_expert == "architect" + def test_default_values(self): + """默认值:自动生成 id,空子任务列表""" + plan = TeamPlan(task="测试任务") + assert plan.id is not None + assert plan.task == "测试任务" + assert plan.subtasks == [] + assert plan.status == PlanStatus.DRAFT + assert plan.lead_expert == "" + def test_to_dict_from_dict_roundtrip(self): """to_dict / from_dict 往返序列化""" plan = _make_valid_plan() d = plan.to_dict() - restored = CollaborationPlan.from_dict(d) + restored = TeamPlan.from_dict(d) assert restored.id == plan.id assert restored.task == plan.task - assert len(restored.phases) == len(plan.phases) - assert restored.variables == plan.variables + assert len(restored.subtasks) == len(plan.subtasks) assert restored.status == plan.status assert restored.lead_expert == plan.lead_expert - for original, restored_phase in zip(plan.phases, restored.phases): - assert restored_phase.id == original.id - assert restored_phase.name == original.name - assert restored_phase.assigned_expert == original.assigned_expert - assert restored_phase.depends_on == original.depends_on - assert restored_phase.parallel_type == original.parallel_type - assert restored_phase.merge_strategy == original.merge_strategy + for original, restored_st in zip(plan.subtasks, restored.subtasks): + assert restored_st.id == original.id + assert restored_st.description == original.description + assert restored_st.assigned_expert == original.assigned_expert + assert restored_st.status == original.status - def test_validate_valid_plan(self): - """验证有效计划无错误""" + def test_get_subtask_by_id(self): + """get_subtask 根据 ID 获取子任务""" plan = _make_valid_plan() - errors = plan.validate() - assert errors == [] + st = plan.get_subtask("s2") + assert st is not None + assert st.id == "s2" + assert st.description == "设计架构" - def test_validate_detects_duplicate_phase_ids(self): - """验证检测到重复阶段 ID""" - phases = [ - _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"), - _make_phase(id="p1", name="阶段2", assigned_expert="b", task_description="t2"), - ] - plan = CollaborationPlan( - id="dup_plan", task="重复ID测试", phases=phases, lead_expert="a" - ) - errors = plan.validate() - assert any("重复的阶段 ID" in e for e in errors) - - def test_validate_detects_missing_depends_on_references(self): - """验证检测到不存在的 depends_on 引用""" - phases = [ - _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"), - _make_phase( - id="p2", - name="阶段2", - assigned_expert="b", - task_description="t2", - depends_on=["p1", "nonexistent"], - ), - ] - plan = CollaborationPlan( - id="missing_dep_plan", task="缺失依赖测试", phases=phases, lead_expert="a" - ) - errors = plan.validate() - assert any("不存在的阶段 ID" in e for e in errors) - - def test_validate_detects_circular_dependencies(self): - """验证检测到循环依赖""" - phases = [ - _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1", depends_on=["p3"]), - _make_phase(id="p2", name="阶段2", assigned_expert="b", task_description="t2", depends_on=["p1"]), - _make_phase(id="p3", name="阶段3", assigned_expert="c", task_description="t3", depends_on=["p2"]), - ] - plan = CollaborationPlan( - id="cycle_plan", task="循环依赖测试", phases=phases, lead_expert="a" - ) - errors = plan.validate() - assert any("循环依赖" in e for e in errors) - - def test_validate_detects_competitive_parallel_without_merge_strategy(self): - """验证检测到 COMPETITIVE_PARALLEL 缺少 merge_strategy""" - phases = [ - _make_phase( - id="p1", - name="竞争阶段", - assigned_expert="a", - task_description="竞争任务", - parallel_type=ParallelType.COMPETITIVE_PARALLEL, - merge_strategy=None, - ), - ] - plan = CollaborationPlan( - id="no_merge_plan", task="缺少合并策略测试", phases=phases, lead_expert="a" - ) - errors = plan.validate() - assert any("COMPETITIVE_PARALLEL" in e and "merge_strategy" in e for e in errors) - - def test_get_ready_phases_returns_phases_with_completed_dependencies(self): - """get_ready_phases 返回依赖已完成的阶段""" + def test_get_subtask_with_nonexistent_id_returns_none(self): + """get_subtask 对不存在的 ID 返回 None""" plan = _make_valid_plan() - # 初始状态:p1 无依赖,应该就绪 - ready = plan.get_ready_phases() - assert len(ready) == 1 - assert ready[0].id == "p1" + assert plan.get_subtask("nonexistent") is None - # 完成 p1 后,p2 应该就绪 - plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"analysis": "done"}) - ready = plan.get_ready_phases() - assert len(ready) == 1 - assert ready[0].id == "p2" - - # 完成 p2 后,p3 应该就绪 - plan.update_phase_status("p2", PhaseStatus.COMPLETED, {"design": "done"}) - ready = plan.get_ready_phases() - assert len(ready) == 1 - assert ready[0].id == "p3" - - def test_get_ready_phases_returns_empty_when_dependencies_not_met(self): - """get_ready_phases 在依赖未满足时返回空列表""" - phases = [ - _make_phase(id="p1", name="阶段1", assigned_expert="a", task_description="t1"), - _make_phase( - id="p2", - name="阶段2", - assigned_expert="b", - task_description="t2", - depends_on=["p1"], - ), - ] - plan = CollaborationPlan( - id="dep_plan", task="依赖未满足测试", phases=phases, lead_expert="a" - ) - # p2 依赖 p1,p1 未完成,所以 p2 不就绪 - # 但 p1 无依赖,所以 p1 就绪 - ready = plan.get_ready_phases() - assert len(ready) == 1 - assert ready[0].id == "p1" - - # 将 p1 设为 IN_PROGRESS(未 COMPLETED),p2 仍不就绪 - plan.update_phase_status("p1", PhaseStatus.IN_PROGRESS) - ready = plan.get_ready_phases() - assert len(ready) == 0 - - def test_update_phase_status(self): - """update_phase_status 更新阶段状态和结果""" + def test_update_subtask_status(self): + """update_subtask_status 更新子任务状态和结果""" plan = _make_valid_plan() - plan.update_phase_status("p1", PhaseStatus.COMPLETED, {"output": "分析完成"}) - phase = plan.get_phase("p1") - assert phase is not None - assert phase.status == PhaseStatus.COMPLETED - assert phase.result == {"output": "分析完成"} + plan.update_subtask_status("s1", SubTaskStatus.COMPLETED, {"output": "分析完成"}) + st = plan.get_subtask("s1") + assert st is not None + assert st.status == SubTaskStatus.COMPLETED + assert st.result == {"output": "分析完成"} - # 不传 result 时不应覆盖已有 result - plan.update_phase_status("p2", PhaseStatus.IN_PROGRESS) - phase2 = plan.get_phase("p2") - assert phase2 is not None - assert phase2.status == PhaseStatus.IN_PROGRESS - assert phase2.result is None - - def test_get_phase_by_id(self): - """get_phase 根据 ID 获取阶段""" + def test_update_subtask_status_without_result(self): + """update_subtask_status 不传 result 时不覆盖已有 result""" plan = _make_valid_plan() - phase = plan.get_phase("p2") - assert phase is not None - assert phase.id == "p2" - assert phase.name == "架构设计" + plan.update_subtask_status("s1", SubTaskStatus.COMPLETED, {"output": "done"}) + plan.update_subtask_status("s1", SubTaskStatus.RUNNING) + st = plan.get_subtask("s1") + assert st is not None + assert st.status == SubTaskStatus.RUNNING + assert st.result == {"output": "done"} - def test_get_phase_with_nonexistent_id_returns_none(self): - """get_phase 对不存在的 ID 返回 None""" + def test_completed_subtasks_property(self): + """completed_subtasks 返回已完成的子任务""" plan = _make_valid_plan() - phase = plan.get_phase("nonexistent") - assert phase is None + plan.update_subtask_status("s1", SubTaskStatus.COMPLETED) + plan.update_subtask_status("s2", SubTaskStatus.COMPLETED) + plan.update_subtask_status("s3", SubTaskStatus.FAILED) + + completed = plan.completed_subtasks + assert len(completed) == 2 + assert {st.id for st in completed} == {"s1", "s2"} + + def test_failed_subtasks_property(self): + """failed_subtasks 返回失败的子任务""" + plan = _make_valid_plan() + plan.update_subtask_status("s1", SubTaskStatus.COMPLETED) + plan.update_subtask_status("s2", SubTaskStatus.FAILED) + plan.update_subtask_status("s3", SubTaskStatus.FAILED) + + failed = plan.failed_subtasks + assert len(failed) == 2 + assert {st.id for st in failed} == {"s2", "s3"} + + def test_all_done_property_when_all_completed(self): + """all_done 当所有子任务完成时返回 True""" + plan = _make_valid_plan() + for st in plan.subtasks: + plan.update_subtask_status(st.id, SubTaskStatus.COMPLETED) + assert plan.all_done is True + + def test_all_done_property_when_some_failed(self): + """all_done 当所有子任务完成或失败时返回 True""" + plan = _make_valid_plan() + plan.update_subtask_status("s1", SubTaskStatus.COMPLETED) + plan.update_subtask_status("s2", SubTaskStatus.FAILED) + plan.update_subtask_status("s3", SubTaskStatus.COMPLETED) + assert plan.all_done is True + + def test_all_done_property_when_pending(self): + """all_done 当有子任务未完成时返回 False""" + plan = _make_valid_plan() + plan.update_subtask_status("s1", SubTaskStatus.COMPLETED) + # s2 and s3 still pending + assert plan.all_done is False + + def test_all_done_property_empty_plan(self): + """all_done 当没有子任务时返回 True(vacuous truth)""" + plan = TeamPlan(task="空计划") + assert plan.all_done is True diff --git a/tests/unit/experts/test_router.py b/tests/unit/experts/test_router.py index 5e44ca2..2078cbc 100644 --- a/tests/unit/experts/test_router.py +++ b/tests/unit/experts/test_router.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from agentkit.experts.config import ExpertConfig, ExpertTemplate from agentkit.experts.registry import ExpertTemplateRegistry from agentkit.experts.router import ( @@ -113,7 +111,6 @@ class TestExpertTeamRoutingResult: assert result.specified_experts == [] assert result.task_content == "" assert result.auto_compose is False - assert result.complexity == 0.0 assert result.match_method == "" @@ -159,47 +156,14 @@ class TestExpertTeamRouterResolve: assert result.specified_experts == ["analyst"] assert result.task_content == "请分析这份报告" - def test_high_complexity_triggers_team_suggestion(self): - """高复杂度 (>=0.7) 触发团队模式建议""" - router = ExpertTeamRouter() - result = router.resolve("分析这个复杂系统", complexity=0.8) - assert result.matched is True - assert result.team_mode is True - assert result.auto_compose is True - assert result.match_method == "complexity_suggestion" - assert result.complexity == 0.8 - - def test_high_complexity_exact_threshold(self): - """复杂度恰好等于阈值 0.7 也触发团队模式""" - router = ExpertTeamRouter() - result = router.resolve("任务内容", complexity=0.7) - assert result.matched is True - assert result.team_mode is True - assert result.match_method == "complexity_suggestion" - - def test_low_complexity_no_team_mode(self): - """低复杂度 (<0.7) 不触发团队模式""" - router = ExpertTeamRouter() - result = router.resolve("简单问题", complexity=0.3) - assert result.matched is False - assert result.team_mode is False - assert result.task_content == "简单问题" - assert result.complexity == 0.3 - - def test_no_team_prefix_no_complexity(self): - """无 @team 前缀且无复杂度时不触发团队模式""" + def test_no_team_prefix_no_team_mode(self): + """无 @team 前缀不触发团队模式""" router = ExpertTeamRouter() result = router.resolve("普通问题") assert result.matched is False assert result.team_mode is False - - def test_team_prefix_takes_priority_over_complexity(self): - """@team 前缀优先于复杂度判断""" - router = ExpertTeamRouter() - result = router.resolve("@team:analyst 任务", complexity=0.1) - assert result.matched is True - assert result.match_method == "explicit_team" - assert result.specified_experts == ["analyst"] + assert result.task_content == "普通问题" + assert result.match_method == "" def test_nonexistent_expert_still_included(self): """指定不存在的专家名仍包含在列表中""" @@ -214,6 +178,35 @@ class TestExpertTeamRouterResolve: assert result.task_content == "@team" assert result.auto_compose is True + def test_resolve_strips_leading_whitespace(self): + """resolve() 对前导空白做 strip()""" + router = ExpertTeamRouter() + result = router.resolve(" @team:analyst 任务内容") + assert result.matched is True + assert result.team_mode is True + assert result.specified_experts == ["analyst"] + assert result.task_content == "任务内容" + + def test_invalid_expert_names_rejected(self): + """无效专家名被过滤掉""" + router = ExpertTeamRouter() + # 包含特殊字符的名称应被过滤 + result = router.resolve("@team:analyst,inva!id,bad/name 任务") + # 仅保留合法名称 + assert "analyst" in result.specified_experts + assert "inva!id" not in result.specified_experts + assert "bad/name" not in result.specified_experts + + def test_max_experts_limit(self): + """指定超过 MAX_EXPERTS 数量的专家时截断""" + router = ExpertTeamRouter() + # 构造 15 个专家名 + names = ",".join(f"expert{i}" for i in range(15)) + result = router.resolve(f"@team:{names} 任务") + from agentkit.experts.router import MAX_EXPERTS + + assert len(result.specified_experts) == MAX_EXPERTS + # ── ExpertTeamRouter.resolve_expert_configs 测试 ─────────── @@ -275,6 +268,25 @@ class TestExpertTeamRouterResolveExpertConfigs: configs = router.resolve_expert_configs(["analyst"]) assert configs[0].bound_skills == ["data_query"] + def test_resolve_first_expert_is_lead(self): + """第一个专家被指定为 lead""" + registry = _make_registry_with_templates() + router = ExpertTeamRouter(template_registry=registry) + + configs = router.resolve_expert_configs(["analyst", "strategist", "reviewer"]) + assert configs[0].is_lead is True + assert configs[1].is_lead is False + assert configs[2].is_lead is False + + def test_resolve_skips_invalid_names(self): + """跳过无效的专家名""" + router = ExpertTeamRouter() + # 包含无效字符的名称应被跳过 + configs = router.resolve_expert_configs(["valid_name", "inva!id", "also_valid"]) + assert len(configs) == 2 + assert configs[0].name == "valid_name" + assert configs[1].name == "also_valid" + # ── ExpertTeamRouter 构造测试 ───────────────────────────── @@ -293,7 +305,40 @@ class TestExpertTeamRouterInit: router = ExpertTeamRouter(template_registry=registry) assert router._registry is registry - def test_complexity_threshold(self): - """复杂度阈值默认为 0.7""" + +# ── ExpertTeamRouter.can_handle 测试 ────────────────────── + + +class TestExpertTeamRouterCanHandle: + """ExpertTeamRouter.can_handle 方法测试""" + + def test_can_handle_with_matching_template_name(self): + """模板名出现在内容中时返回 True""" + registry = _make_registry_with_templates() + router = ExpertTeamRouter(template_registry=registry) + + assert router.can_handle("请 analyst 帮我分析") is True + + def test_can_handle_with_matching_description(self): + """模板描述词出现在内容中时返回 True""" + registry = _make_registry_with_templates() + router = ExpertTeamRouter(template_registry=registry) + + # 描述为 "analyst 模板" 等,"模板" 长度 > 2 应匹配 + assert router.can_handle("我需要一个模板来参考") is True + + def test_can_handle_no_templates(self): + """无注册模板时返回 False""" router = ExpertTeamRouter() - assert router.COMPLEXITY_THRESHOLD == 0.7 + # 默认注册中心可能为空 + # 强制清空以测试 + router._registry._templates.clear() + assert router.can_handle("任何内容") is False + + def test_can_handle_with_templates_no_match(self): + """有模板但内容不匹配时仍返回 True(auto-compose 可组建团队)""" + registry = _make_registry_with_templates() + router = ExpertTeamRouter(template_registry=registry) + + # 内容与任何模板名/描述都不匹配,但有模板存在 → auto-compose 可用 + assert router.can_handle("完全无关的内容 xyz123") is True diff --git a/tests/unit/experts/test_team.py b/tests/unit/experts/test_team.py index df1dbb1..05220ee 100644 --- a/tests/unit/experts/test_team.py +++ b/tests/unit/experts/test_team.py @@ -1,4 +1,4 @@ -"""ExpertTeam 容器单元测试""" +"""ExpertTeam 容器单元测试 (hub-and-spoke 模式)""" from __future__ import annotations @@ -11,11 +11,6 @@ from agentkit.core.handoff_transport import InProcessHandoffTransport from agentkit.core.shared_workspace import SharedWorkspace from agentkit.experts.config import ExpertConfig, ExpertTemplate from agentkit.experts.expert import Expert -from agentkit.experts.plan import ( - CollaborationPlan, - PlanPhase, - PlanStatus, -) from agentkit.experts.registry import ExpertTemplateRegistry from agentkit.experts.team import ExpertTeam, TeamStatus @@ -84,27 +79,6 @@ def _make_mock_expert( return expert -def _make_valid_plan( - plan_id: str = "plan_1", - task: str = "测试任务", - lead_expert: str = "lead", -) -> CollaborationPlan: - """创建有效的 CollaborationPlan""" - return CollaborationPlan( - id=plan_id, - task=task, - phases=[ - PlanPhase( - id="phase_1", - name="阶段1", - assigned_expert=lead_expert, - task_description="执行任务", - ) - ], - lead_expert=lead_expert, - ) - - # ── ExpertTeam 创建测试 ─────────────────────────────────── @@ -119,7 +93,6 @@ class TestExpertTeamCreation: assert len(team.team_id) > 0 assert team.status == TeamStatus.FORMING assert team.lead_expert is None - assert team.plan is None assert team.experts == [] assert team.active_experts == [] @@ -173,7 +146,7 @@ class TestExpertTeamCreateTeam: assert team._lead_expert_name == "lead" assert team.lead_expert is mock_expert - assert team.status == TeamStatus.PLANNING + assert team.status == TeamStatus.EXECUTING assert mock_expert.team_id == team.team_id @pytest.mark.asyncio @@ -195,7 +168,7 @@ class TestExpertTeamCreateTeam: assert len(team.experts) == 2 assert team._lead_expert_name == "lead" - assert team.status == TeamStatus.PLANNING + assert team.status == TeamStatus.EXECUTING @pytest.mark.asyncio async def test_create_team_without_pool_raises(self): @@ -411,64 +384,6 @@ class TestExpertTeamRemoveExpert: assert last_left[0][1]["expert_name"] == "member1" -# ── ExpertTeam.update_plan 测试 ──────────────────────────── - - -class TestExpertTeamUpdatePlan: - """ExpertTeam.update_plan 协作计划更新测试""" - - def test_update_plan_with_valid_plan(self): - """有效计划更新成功,返回受影响的 Expert 名称""" - team = ExpertTeam() - plan = _make_valid_plan(lead_expert="lead") - - affected = team.update_plan(plan) - - assert team.plan is plan - assert "lead" in affected - - def test_update_plan_confirmed_sets_executing(self): - """CONFIRMED 状态的计划将团队状态设为 EXECUTING""" - team = ExpertTeam() - plan = _make_valid_plan(lead_expert="lead") - plan.status = PlanStatus.CONFIRMED - - team.update_plan(plan) - - assert team.status == TeamStatus.EXECUTING - - def test_update_plan_with_invalid_plan_no_update(self): - """无效计划(validate 返回错误)不更新,返回验证错误列表""" - team = ExpertTeam() - # 创建有循环依赖的无效计划 - plan = CollaborationPlan( - id="bad_plan", - task="无效任务", - phases=[ - PlanPhase( - id="p1", - name="阶段1", - assigned_expert="a", - task_description="t1", - depends_on=["p2"], - ), - PlanPhase( - id="p2", - name="阶段2", - assigned_expert="b", - task_description="t2", - depends_on=["p1"], - ), - ], - lead_expert="lead", - ) - - result = team.update_plan(plan) - - assert len(result) > 0 # 返回验证错误而非空列表 - assert team.plan is None # 未更新 - - # ── ExpertTeam.broadcast_user_message 测试 ───────────────── @@ -535,42 +450,6 @@ class TestExpertTeamGetSharedContext: assert context == {} -# ── ExpertTeam.generate_plan 测试 ────────────────────────── - - -class TestExpertTeamGeneratePlan: - """ExpertTeam.generate_plan 计划生成测试""" - - @pytest.mark.asyncio - async def test_generate_plan(self): - """生成空的 CollaborationPlan""" - pool = _make_mock_pool() - team = ExpertTeam(pool=pool) - - lead_config = _make_expert_config(name="lead", is_lead=True) - with patch.object(Expert, "create", new_callable=AsyncMock) as mock_create: - lead_expert = _make_mock_expert(name="lead", is_lead=True) - mock_create.return_value = lead_expert - await team.create_team(lead_config) - - plan = await team.generate_plan("分析数据") - - assert plan is not None - assert plan.task == "分析数据" - assert plan.lead_expert == "lead" - assert plan.phases == [] - assert team.plan is plan - - @pytest.mark.asyncio - async def test_generate_plan_without_lead(self): - """没有 Lead Expert 时生成计划,lead_expert 为空字符串""" - team = ExpertTeam() - - plan = await team.generate_plan("测试任务") - - assert plan.lead_expert == "" - - # ── ExpertTeam.dissolve 测试 ─────────────────────────────── @@ -667,11 +546,6 @@ class TestExpertTeamDissolvedOperations: # 解散后状态为 DISSOLVED assert team.status == TeamStatus.DISSOLVED - # 再次 create_team 时,由于 experts 已清空, - # 但 pool 仍然存在,理论上可以重新创建 - # 但这里验证状态是 DISSOLVED - assert team.status == TeamStatus.DISSOLVED - # ── ExpertTeam.lead_expert 属性测试 ──────────────────────── @@ -742,10 +616,11 @@ class TestExpertTeamBuildContext: context = team._build_team_context(lead_config, [member_config]) - assert "You are part of an Expert Team." in context - assert "Lead Expert: lead (领导者)" in context - assert "Team Member: analyst (分析师), Skills: data_query" in context - assert "send_message() and request_assist()" in context + assert "hub-and-spoke mode" in context + assert "Lead Expert (hub): lead (领导者)" in context + assert "Team Member (spoke): analyst (分析师)" in context + assert "data_query" in context + assert "depth=1" in context def test_build_team_context_no_lead(self): """没有 Lead Expert 时构建上下文""" @@ -754,8 +629,9 @@ class TestExpertTeamBuildContext: context = team._build_team_context(None, [member_config]) - assert "Lead Expert" not in context - assert "Team Member: analyst" in context + # 不应出现具体的 Lead Expert (hub): name 行 + assert "Lead Expert (hub):" not in context + assert "Team Member (spoke): analyst" in context def test_build_team_context_skips_lead_in_members(self): """成员列表中包含 Lead 时跳过""" @@ -765,4 +641,4 @@ class TestExpertTeamBuildContext: context = team._build_team_context(lead_config, [lead_config]) # Lead 不应出现在 Team Member 行 - assert "Team Member: lead" not in context + assert "Team Member (spoke): lead" not in context diff --git a/tests/unit/experts/test_team_orchestrator.py b/tests/unit/experts/test_team_orchestrator.py index 09c69f1..897fe2c 100644 --- a/tests/unit/experts/test_team_orchestrator.py +++ b/tests/unit/experts/test_team_orchestrator.py @@ -1,4 +1,4 @@ -"""TeamOrchestrator 单元测试""" +"""TeamOrchestrator 单元测试 (hub-and-spoke 模式)""" from __future__ import annotations @@ -11,14 +11,7 @@ from agentkit.core.protocol import TaskResult, TaskStatus from agentkit.experts.config import ExpertConfig from agentkit.experts.expert import Expert from agentkit.experts.orchestrator import TeamOrchestrator -from agentkit.experts.plan import ( - CollaborationPlan, - MergeStrategy, - ParallelType, - PhaseStatus, - PlanPhase, - PlanStatus, -) +from agentkit.experts.plan import PlanStatus, SubTask, SubTaskStatus from agentkit.experts.team import ExpertTeam, TeamStatus @@ -71,6 +64,8 @@ def _make_mock_expert( started_at=None, completed_at=None, )) + # No LLM gateway by default (tests single-subtask path) + mock_agent._llm_gateway = None expert.agent = mock_agent return expert @@ -97,516 +92,404 @@ def _make_team_with_experts( return team -def _make_serial_plan( - plan_id: str = "plan_1", - task: str = "测试任务", - lead_expert: str = "lead", - num_phases: int = 1, -) -> CollaborationPlan: - """创建串行阶段的 CollaborationPlan""" - phases = [] - for i in range(num_phases): - deps = [f"phase_{i}"] if i > 0 else [] - phases.append( - PlanPhase( - id=f"phase_{i + 1}", - name=f"阶段{i + 1}", - assigned_expert=lead_expert, - task_description=f"执行任务{i + 1}", - depends_on=deps, - parallel_type=ParallelType.SERIAL, - ) - ) - return CollaborationPlan( - id=plan_id, - task=task, - phases=phases, - lead_expert=lead_expert, - ) +def _make_mock_llm_gateway(subtask_descriptions: list[str] | None = None) -> MagicMock: + """创建 mock LLM gateway. + + If subtask_descriptions is provided, the gateway returns a JSON array + of subtasks for decomposition. Otherwise returns a simple response. + """ + gateway = AsyncMock() + if subtask_descriptions: + import json + subtasks_json = json.dumps([ + {"description": desc, "assigned_expert": "member1"} + for desc in subtask_descriptions + ]) + response = MagicMock() + response.content = subtasks_json + gateway.chat = AsyncMock(return_value=response) + else: + response = MagicMock() + response.content = "Synthesized result" + gateway.chat = AsyncMock(return_value=response) + return gateway -def _make_parallel_plan( - plan_id: str = "plan_parallel", - task: str = "并行测试任务", - parallel_type: ParallelType = ParallelType.SUBTASK_PARALLEL, - merge_strategy: MergeStrategy | None = None, -) -> CollaborationPlan: - """创建并行阶段的 CollaborationPlan""" - phases = [ - PlanPhase( - id="phase_1", - name="并行阶段1", - assigned_expert="member1", - task_description="并行任务1", - parallel_type=parallel_type, - merge_strategy=merge_strategy, - ), - PlanPhase( - id="phase_2", - name="并行阶段2", - assigned_expert="member2", - task_description="并行任务2", - parallel_type=parallel_type, - merge_strategy=merge_strategy, - ), - ] - return CollaborationPlan( - id=plan_id, - task=task, - phases=phases, - lead_expert="lead", - ) +# ── Hub-and-spoke 执行测试 ──────────────────────────────── -# ── 串行阶段执行测试 ────────────────────────────────────── - - -class TestSerialPhaseExecution: - """串行阶段执行测试""" +class TestHubAndSpokeExecution: + """Hub-and-spoke 模式执行测试""" @pytest.mark.asyncio - async def test_single_serial_phase_completes(self): - """单个串行阶段执行完成""" + async def test_execute_single_subtask_completes(self): + """无 LLM 时,任务作为单个子任务执行完成""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = _make_serial_plan(num_phases=1) - result = await orchestrator.execute_plan(plan) + result = await orchestrator.execute("测试任务") assert result["status"] == "completed" - assert "phase_1" in result["phase_results"] - assert plan.phases[0].status == PhaseStatus.COMPLETED - - @pytest.mark.asyncio - async def test_multiple_serial_phases_in_order(self): - """多个串行阶段按依赖顺序执行""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team) - plan = _make_serial_plan(num_phases=3) - - result = await orchestrator.execute_plan(plan) - - assert result["status"] == "completed" - assert len(result["phase_results"]) == 3 - # All phases should be completed - for phase in plan.phases: - assert phase.status == PhaseStatus.COMPLETED - - @pytest.mark.asyncio - async def test_serial_phase_sets_plan_and_team_status(self): - """执行计划时设置 plan 和 team 状态""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team) - plan = _make_serial_plan() - - await orchestrator.execute_plan(plan) - - assert plan.status == PlanStatus.COMPLETED + assert "result" in result + assert "subtask_results" in result + assert "plan" in result assert team.status == TeamStatus.COMPLETED - -# ── 子任务并行阶段执行测试 ──────────────────────────────── - - -class TestSubtaskParallelExecution: - """子任务并行阶段执行测试""" - @pytest.mark.asyncio - async def test_subtask_parallel_phases_execute(self): - """子任务并行阶段并行执行""" + async def test_execute_sets_team_status(self): + """执行时设置 team 状态为 EXECUTING → SYNTHESIZING → COMPLETED""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = _make_parallel_plan(parallel_type=ParallelType.SUBTASK_PARALLEL) - result = await orchestrator.execute_plan(plan) + await orchestrator.execute("测试任务") - assert result["status"] == "completed" - assert "phase_1" in result["phase_results"] - assert "phase_2" in result["phase_results"] + assert team.status == TeamStatus.COMPLETED @pytest.mark.asyncio - async def test_subtask_parallel_phase_failure_recorded(self): - """子任务并行阶段失败时记录错误""" + async def test_execute_emits_team_formed_event(self): + """执行时广播 team_formed 事件""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = _make_parallel_plan(parallel_type=ParallelType.SUBTASK_PARALLEL) - # Mock _execute_phase to raise for one phase - original_execute = orchestrator._execute_phase - call_count = 0 + await orchestrator.execute("测试任务") - async def mock_execute_phase(phase, p, pr): - nonlocal call_count - call_count += 1 - if phase.id == "phase_1": - raise RuntimeError("Simulated failure") - return await original_execute(phase, p, pr) - - with patch.object( - orchestrator, "_execute_phase", side_effect=mock_execute_phase - ): - result = await orchestrator.execute_plan(plan) - - # The exception should be caught by asyncio.gather(return_exceptions=True) - assert "phase_1" in result["phase_results"] - assert "error" in result["phase_results"]["phase_1"] - - -# ── 竞争并行阶段测试 ────────────────────────────────────── - - -class TestCompetitiveParallelExecution: - """竞争并行阶段执行测试""" + calls = team._handoff_transport.send.call_args_list + event_types = [c[0][1]["type"] for c in calls] + assert "team_formed" in event_types @pytest.mark.asyncio - async def test_competitive_parallel_best_strategy(self): - """竞争并行阶段使用 BEST 合并策略""" + async def test_execute_emits_expert_step_and_result_events(self): + """执行时广播 expert_step 和 expert_result 事件""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = _make_parallel_plan( - parallel_type=ParallelType.COMPETITIVE_PARALLEL, - merge_strategy=MergeStrategy.BEST, + + await orchestrator.execute("测试任务") + + calls = team._handoff_transport.send.call_args_list + event_types = [c[0][1]["type"] for c in calls] + assert "expert_step" in event_types + assert "expert_result" in event_types + + @pytest.mark.asyncio + async def test_execute_emits_team_synthesis_event(self): + """执行完成时广播 team_synthesis 事件""" + team = _make_team_with_experts() + orchestrator = TeamOrchestrator(team) + + await orchestrator.execute("测试任务") + + calls = team._handoff_transport.send.call_args_list + event_types = [c[0][1]["type"] for c in calls] + assert "team_synthesis" in event_types + + @pytest.mark.asyncio + async def test_execute_emits_plan_update_event(self): + """执行时广播 plan_update 事件(包含子任务列表)""" + team = _make_team_with_experts() + orchestrator = TeamOrchestrator(team) + + await orchestrator.execute("测试任务") + + calls = team._handoff_transport.send.call_args_list + plan_updates = [c for c in calls if c[0][1].get("type") == "plan_update"] + assert len(plan_updates) >= 1 + assert "subtasks" in plan_updates[0][0][1] + + +# ── LLM 任务分解测试 ────────────────────────────────────── + + +class TestTaskDecomposition: + """LLM 任务分解测试""" + + @pytest.mark.asyncio + async def test_llm_decomposes_task_into_subtasks(self): + """LLM 将任务分解为多个子任务""" + team = _make_team_with_experts() + orchestrator = TeamOrchestrator(team) + + # Set up LLM gateway on lead expert for decomposition + gateway = _make_mock_llm_gateway( + subtask_descriptions=["分析数据", "生成报告", "审核结果"] ) + team._experts["lead"].agent._llm_gateway = gateway - result = await orchestrator.execute_plan(plan) + result = await orchestrator.execute("分析并报告数据") assert result["status"] == "completed" - # Competitive phases are merged into one result per phase - for phase_id in ["phase_1", "phase_2"]: - assert phase_id in result["phase_results"] - phase_result = result["phase_results"][phase_id] - assert phase_result.get("merged") is True - assert phase_result.get("strategy") == "best" + plan = result["plan"] + assert len(plan.subtasks) == 3 + # Each subtask should have been executed + assert len(result["subtask_results"]) == 3 @pytest.mark.asyncio - async def test_competitive_parallel_vote_strategy(self): - """竞争并行阶段使用 VOTE 合并策略""" + async def test_decomposition_fallback_to_single_subtask(self): + """LLM 不可用时回退到单个子任务""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = _make_parallel_plan( - parallel_type=ParallelType.COMPETITIVE_PARALLEL, - merge_strategy=MergeStrategy.VOTE, + + # No LLM gateway — should fall back to single subtask + result = await orchestrator.execute("测试任务") + + assert result["status"] == "completed" + plan = result["plan"] + assert len(plan.subtasks) == 1 + + @pytest.mark.asyncio + async def test_parse_subtasks_valid_json(self): + """_parse_subtasks 正确解析 JSON 数组""" + import json + content = json.dumps([ + {"description": "任务1", "assigned_expert": "member1"}, + {"description": "任务2", "assigned_expert": "member2"}, + ]) + subtasks = TeamOrchestrator._parse_subtasks( + content, ["member1", "member2"], "lead" ) - - result = await orchestrator.execute_plan(plan) - - assert result["status"] == "completed" - for phase_id in ["phase_1", "phase_2"]: - phase_result = result["phase_results"][phase_id] - assert phase_result.get("merged") is True - assert phase_result.get("strategy") == "vote" + assert len(subtasks) == 2 + assert subtasks[0].description == "任务1" + assert subtasks[0].assigned_expert == "member1" + assert subtasks[1].description == "任务2" + assert subtasks[1].assigned_expert == "member2" @pytest.mark.asyncio - async def test_competitive_parallel_fusion_strategy(self): - """竞争并行阶段使用 FUSION 合并策略""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team) - plan = _make_parallel_plan( - parallel_type=ParallelType.COMPETITIVE_PARALLEL, - merge_strategy=MergeStrategy.FUSION, + async def test_parse_subtasks_invalid_expert_falls_back_to_lead(self): + """_parse_subtasks 对无效专家名回退到 lead""" + import json + content = json.dumps([ + {"description": "任务1", "assigned_expert": "nonexistent"}, + ]) + subtasks = TeamOrchestrator._parse_subtasks( + content, ["member1"], "lead" ) - - result = await orchestrator.execute_plan(plan) - - assert result["status"] == "completed" - for phase_id in ["phase_1", "phase_2"]: - phase_result = result["phase_results"][phase_id] - assert phase_result.get("merged") is True - assert phase_result.get("strategy") == "fusion" - assert phase_result.get("fused_from") == 3 # 3 active experts - - -# ── 里程碑检查点测试 ────────────────────────────────────── - - -class TestMilestoneCheckpoint: - """里程碑检查点测试""" + assert len(subtasks) == 1 + assert subtasks[0].assigned_expert == "lead" @pytest.mark.asyncio - async def test_milestone_pass(self): - """里程碑检查通过""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team) - plan = CollaborationPlan( - id="plan_milestone", - task="里程碑测试", - phases=[ - PlanPhase( - id="phase_1", - name="带里程碑阶段", - assigned_expert="lead", - task_description="执行带里程碑的任务", - milestone="输出质量达标", - ) - ], - lead_expert="lead", + async def test_parse_subtasks_invalid_json_returns_empty(self): + """_parse_subtasks 对无效 JSON 返回空列表""" + subtasks = TeamOrchestrator._parse_subtasks( + "not json at all", ["member1"], "lead" ) - - result = await orchestrator.execute_plan(plan) - - assert result["status"] == "completed" - assert plan.phases[0].status == PhaseStatus.COMPLETED + assert subtasks == [] @pytest.mark.asyncio - async def test_milestone_fail_phase_failed(self): - """里程碑检查失败 → 阶段状态为 FAILED""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team) - plan = CollaborationPlan( - id="plan_milestone_fail", - task="里程碑失败测试", - phases=[ - PlanPhase( - id="phase_1", - name="带里程碑阶段", - assigned_expert="lead", - task_description="执行带里程碑的任务", - milestone="输出质量达标", - ) - ], - lead_expert="lead", + async def test_parse_subtasks_empty_description_skipped(self): + """_parse_subtasks 跳过空描述的子任务""" + import json + content = json.dumps([ + {"description": "", "assigned_expert": "member1"}, + {"description": "有效任务", "assigned_expert": "member1"}, + ]) + subtasks = TeamOrchestrator._parse_subtasks( + content, ["member1"], "lead" ) - - # Mock _check_milestone to return False - with patch.object( - orchestrator, "_check_milestone", return_value=False - ): - result = await orchestrator.execute_plan(plan) - - assert plan.phases[0].status == PhaseStatus.FAILED - # Phase failed → retry → still failed → fallback - assert result["status"] == "fallback" + assert len(subtasks) == 1 + assert subtasks[0].description == "有效任务" -# ── 重试与回退测试 ──────────────────────────────────────── +# ── 子任务执行测试 ──────────────────────────────────────── -class TestRetryAndFallback: - """重试与回退测试""" +class TestSubtaskExecution: + """子任务执行测试""" @pytest.mark.asyncio - async def test_phase_failure_triggers_retry(self): - """阶段失败触发重试""" + async def test_subtask_execution_calls_agent_execute(self): + """子任务执行调用 agent.execute()""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = _make_serial_plan(num_phases=1) - # Mock _execute_phase: first call returns None, second call succeeds - call_count = 0 + await orchestrator.execute("测试任务") - async def mock_execute_phase(phase, p, pr): - nonlocal call_count - call_count += 1 - if call_count == 1: - # First call fails - p.update_phase_status(phase.id, PhaseStatus.FAILED) - return None - # Retry succeeds — simulate a successful phase execution - p.update_phase_status(phase.id, PhaseStatus.COMPLETED, {"output": "retry ok"}) - return {"output": "retry ok"} + # Lead expert's agent should have been called + team._experts["lead"].agent.execute.assert_awaited() - with patch.object( - orchestrator, "_execute_phase", side_effect=mock_execute_phase - ): - result = await orchestrator.execute_plan(plan) + @pytest.mark.asyncio + async def test_subtask_marks_completed(self): + """子任务执行后状态标记为 COMPLETED""" + team = _make_team_with_experts() + orchestrator = TeamOrchestrator(team) + + result = await orchestrator.execute("测试任务") + + plan = result["plan"] + for st in plan.subtasks: + assert st.status == SubTaskStatus.COMPLETED + assert st.result is not None + + @pytest.mark.asyncio + async def test_subtask_with_invalid_expert_falls_back_to_lead(self): + """子任务分配的专家不可用时回退到 lead""" + team = _make_team_with_experts() + orchestrator = TeamOrchestrator(team) + + # Set up LLM to assign to a nonexistent expert + import json + gateway = _make_mock_llm_gateway(subtask_descriptions=["任务1"]) + gateway.chat = AsyncMock(return_value=MagicMock( + content=json.dumps([ + {"description": "任务1", "assigned_expert": "nonexistent"} + ]) + )) + team._experts["lead"].agent._llm_gateway = gateway + + result = await orchestrator.execute("测试任务") - # After retry, the phase should succeed - assert call_count == 2 assert result["status"] == "completed" - - @pytest.mark.asyncio - async def test_retry_failure_triggers_fallback(self): - """重试仍然失败 → 回退到单 Agent 模式""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team) - plan = _make_serial_plan(num_phases=1) - - # Mock _execute_phase to always return None (failure) - async def mock_execute_phase(phase, p, pr): - plan.update_phase_status(phase.id, PhaseStatus.FAILED) - return None - - with patch.object( - orchestrator, "_execute_phase", side_effect=mock_execute_phase - ): - result = await orchestrator.execute_plan(plan) - - assert result["status"] == "fallback" - assert plan.status == PlanStatus.FALLBACK - - @pytest.mark.asyncio - async def test_replan_before_fallback_on_failure(self): - """重试失败后尝试 replan,replan 成功则不 fallback""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team, max_replans=1) - plan = _make_serial_plan(num_phases=1) - - call_count = 0 - - async def mock_execute_phase(phase, p, pr): - nonlocal call_count - call_count += 1 - if call_count <= 2: - # First call + retry both fail - p.update_phase_status(phase.id, PhaseStatus.FAILED) - return None - # Replan attempt succeeds - p.update_phase_status(phase.id, PhaseStatus.COMPLETED, {"output": "replan ok"}) - return {"output": "replan ok"} - - with patch.object( - orchestrator, "_execute_phase", side_effect=mock_execute_phase - ): - result = await orchestrator.execute_plan(plan) - - # After retry fails → replan → succeeds, should complete - assert call_count == 3 # 1 initial + 1 retry + 1 replan - assert result["status"] == "completed" - - @pytest.mark.asyncio - async def test_replan_exhausted_then_fallback(self): - """replan 次数用尽后 fallback""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team, max_replans=2) - plan = _make_serial_plan(num_phases=1) - - async def mock_execute_phase(phase, p, pr): - # Always fail - p.update_phase_status(phase.id, PhaseStatus.FAILED) - return None - - with patch.object( - orchestrator, "_execute_phase", side_effect=mock_execute_phase - ): - result = await orchestrator.execute_plan(plan) - - # Exhausted retries + replans → fallback - assert result["status"] == "fallback" - - -# ── 最大交互轮次测试 ────────────────────────────────────── - - -class TestMaxInteractionRounds: - """最大交互轮次限制测试""" - - @pytest.mark.asyncio - async def test_max_interaction_rounds_limit(self): - """超过最大交互轮次时停止执行""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team) - orchestrator.MAX_INTERACTION_ROUNDS = 1 - - # Create a plan with many phases that would take many rounds - plan = _make_serial_plan(num_phases=5) - - await orchestrator.execute_plan(plan) - - # Should stop after 1 round, not completing all phases - # Only the first phase should complete (1 interaction round) - assert orchestrator._interaction_count >= 1 - - -# ── 无效计划测试 ────────────────────────────────────────── - - -class TestInvalidPlan: - """无效计划测试""" - - @pytest.mark.asyncio - async def test_invalid_plan_returns_failed_status(self): - """无效计划返回 failed 状态""" - team = _make_team_with_experts() - orchestrator = TeamOrchestrator(team) - - # Create invalid plan with circular dependency - plan = CollaborationPlan( - id="invalid_plan", - task="无效任务", - phases=[ - PlanPhase( - id="p1", - name="阶段1", - assigned_expert="lead", - task_description="t1", - depends_on=["p2"], - ), - PlanPhase( - id="p2", - name="阶段2", - assigned_expert="lead", - task_description="t2", - depends_on=["p1"], - ), - ], - lead_expert="lead", - ) - - result = await orchestrator.execute_plan(plan) - - assert result["status"] == "failed" - assert "errors" in result - assert len(result["errors"]) > 0 + # The subtask should have been reassigned to lead + plan = result["plan"] + assert plan.subtasks[0].assigned_expert == "lead" # ── 结果综合测试 ────────────────────────────────────────── -class TestSynthesizeResults: +class TestResultSynthesis: """结果综合测试""" @pytest.mark.asyncio - async def test_synthesize_results(self): - """综合所有阶段结果""" + async def test_synthesize_single_result(self): + """单个子任务结果直接返回""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = _make_serial_plan(num_phases=2) - result = await orchestrator.execute_plan(plan) + result = await orchestrator.execute("测试任务") assert result["status"] == "completed" final = result["result"] - assert final["task"] == "测试任务" - assert final["phases_completed"] == 2 - assert final["phases_total"] == 2 - assert len(final["results"]) == 2 + assert "content" in final + assert final["strategy"] == "best" + assert final["subtasks_completed"] == 1 @pytest.mark.asyncio - async def test_synthesize_results_only_completed_phases(self): - """只综合已完成阶段的结果""" + async def test_synthesize_multiple_results_with_llm(self): + """多个子任务结果通过 LLM 综合""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = CollaborationPlan( - id="plan_partial", - task="部分完成测试", - phases=[ - PlanPhase( - id="phase_1", - name="完成阶段", - assigned_expert="lead", - task_description="任务1", - ), - PlanPhase( - id="phase_2", - name="依赖阶段", - assigned_expert="member1", - task_description="任务2", - depends_on=["phase_1"], - ), - ], - lead_expert="lead", + # Set up LLM for both decomposition and synthesis + import json + gateway = AsyncMock() + + # First call: decomposition + decomp_response = MagicMock() + decomp_response.content = json.dumps([ + {"description": "子任务1", "assigned_expert": "member1"}, + {"description": "子任务2", "assigned_expert": "member2"}, + ]) + + # Second call: synthesis + synth_response = MagicMock() + synth_response.content = "综合结果" + + gateway.chat = AsyncMock(side_effect=[decomp_response, synth_response]) + team._experts["lead"].agent._llm_gateway = gateway + + result = await orchestrator.execute("复杂任务") + + assert result["status"] == "completed" + final = result["result"] + assert final["content"] == "综合结果" + assert final["strategy"] == "best" + assert final["subtasks_completed"] == 2 + + @pytest.mark.asyncio + async def test_synthesize_without_llm_concatenates(self): + """无 LLM 时拼接所有结果""" + team = _make_team_with_experts() + orchestrator = TeamOrchestrator(team) + + # Set up LLM for decomposition only (no synthesis LLM) + import json + gateway = AsyncMock() + decomp_response = MagicMock() + decomp_response.content = json.dumps([ + {"description": "子任务1", "assigned_expert": "member1"}, + {"description": "子任务2", "assigned_expert": "member2"}, + ]) + # Synthesis call raises to force concatenation fallback + gateway.chat = AsyncMock( + side_effect=[decomp_response, RuntimeError("LLM unavailable")] ) + team._experts["lead"].agent._llm_gateway = gateway - # Manually set phase_1 as completed, phase_2 as pending - plan.update_phase_status("phase_1", PhaseStatus.COMPLETED, {"output": "done"}) + result = await orchestrator.execute("复杂任务") - # Synthesize directly - phase_results = {"phase_1": {"output": "done"}} - result = await orchestrator._synthesize_results(plan, phase_results) + assert result["status"] == "completed" + final = result["result"] + assert "content" in final + # Should contain both results concatenated + assert "Result from member1" in final["content"] + assert "Result from member2" in final["content"] - assert result["phases_completed"] == 1 - assert result["phases_total"] == 2 + +# ── 回退测试 ────────────────────────────────────────────── + + +class TestFallback: + """回退到单 Agent 模式测试""" + + @pytest.mark.asyncio + async def test_all_subtasks_fail_triggers_fallback(self): + """所有子任务失败时触发回退""" + team = _make_team_with_experts() + orchestrator = TeamOrchestrator(team) + + # Make agent.execute raise for all subtasks + for expert in team._experts.values(): + expert.agent.execute = AsyncMock(side_effect=RuntimeError("Execution failed")) + + result = await orchestrator.execute("测试任务") + + assert result["status"] == "fallback" + assert result["plan"].status == PlanStatus.FALLBACK + + @pytest.mark.asyncio + async def test_fallback_uses_lead_expert(self): + """回退使用 lead expert 执行原始任务""" + team = _make_team_with_experts() + orchestrator = TeamOrchestrator(team) + + # Make agent.execute raise for subtasks but succeed for fallback + call_count = 0 + + async def mock_execute(task_msg): + nonlocal call_count + call_count += 1 + if task_msg.task_type == "team_subtask": + raise RuntimeError("Subtask failed") + # Fallback succeeds + return TaskResult( + task_id=task_msg.task_id, + agent_name="lead", + status=TaskStatus.COMPLETED.value, + output_data={"content": "Fallback result"}, + error_message=None, + started_at=None, + completed_at=None, + ) + + team._experts["lead"].agent.execute = AsyncMock(side_effect=mock_execute) + + result = await orchestrator.execute("测试任务") + + assert result["status"] == "fallback" + assert result["result"]["content"] == "Fallback result" + + @pytest.mark.asyncio + async def test_no_active_experts_returns_failed(self): + """没有活跃专家时返回 failed""" + team = _make_team_with_experts() + # Mark all experts as inactive + for expert in team._experts.values(): + expert.is_active = False + orchestrator = TeamOrchestrator(team) + + result = await orchestrator.execute("测试任务") + + assert result["status"] == "failed" + assert "error" in result # ── 事件广播测试 ────────────────────────────────────────── @@ -623,7 +506,7 @@ class TestBroadcastEvent: await orchestrator._broadcast_event("test_event", {"key": "value"}) - team._handoff_transport.send.assert_awaited_once() + team._handoff_transport.send.assert_awaited() call_args = team._handoff_transport.send.call_args assert call_args[0][0] == team._team_channel message = call_args[0][1] @@ -641,105 +524,117 @@ class TestBroadcastEvent: await orchestrator._broadcast_event("test_event", {"key": "value"}) @pytest.mark.asyncio - async def test_phase_execution_broadcasts_events(self): - """阶段执行时广播 phase_started 和 phase_completed 事件""" + async def test_broadcast_event_handles_transport_error(self): + """handoff_transport 发送失败时不影响执行""" + team = _make_team_with_experts() + team._handoff_transport.send = AsyncMock(side_effect=RuntimeError("Transport error")) + orchestrator = TeamOrchestrator(team) + + # Should not raise + await orchestrator._broadcast_event("test_event", {"key": "value"}) + + +# ── LLM Gateway 测试 ────────────────────────────────────── + + +class TestLLMGateway: + """LLM Gateway 获取测试""" + + def test_get_llm_gateway_from_lead(self): + """从 lead expert 获取 LLM gateway""" + team = _make_team_with_experts() + gateway = MagicMock() + team._experts["lead"].agent._llm_gateway = gateway + orchestrator = TeamOrchestrator(team) + + result = orchestrator._get_llm_gateway() + assert result is gateway + + def test_get_llm_gateway_no_gateway(self): + """没有 LLM gateway 时返回 None""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = _make_serial_plan(num_phases=1) - await orchestrator.execute_plan(plan) + result = orchestrator._get_llm_gateway() + assert result is None - calls = team._handoff_transport.send.call_args_list - event_types = [c[0][1]["type"] for c in calls] - assert "phase_started" in event_types - assert "phase_completed" in event_types + def test_get_llm_gateway_fallback_to_active_expert(self): + """lead 没有 gateway 时从其他活跃专家获取""" + team = _make_team_with_experts() + gateway = MagicMock() + # Lead has no gateway, but member1 does + team._experts["member1"].agent._llm_gateway = gateway + orchestrator = TeamOrchestrator(team) + + result = orchestrator._get_llm_gateway() + assert result is gateway -# ── 竞争并行全部失败测试 ────────────────────────────────── +# ── 并行执行测试 ────────────────────────────────────────── -class TestCompetitiveAllFail: - """竞争并行全部失败测试""" +class TestParallelExecution: + """并行子任务执行测试""" @pytest.mark.asyncio - async def test_all_competitors_fail(self): - """所有竞争者都失败时触发 fallback""" + async def test_multiple_subtasks_execute_in_parallel(self): + """多个子任务并行执行""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - # Mock _run_competitor to always raise - async def mock_run_competitor(expert, phase): - raise RuntimeError("Competitor failed") + # Set up LLM for decomposition + import json + gateway = AsyncMock() + decomp_response = MagicMock() + decomp_response.content = json.dumps([ + {"description": "子任务1", "assigned_expert": "member1"}, + {"description": "子任务2", "assigned_expert": "member2"}, + {"description": "子任务3", "assigned_expert": "lead"}, + ]) + synth_response = MagicMock() + synth_response.content = "综合结果" + gateway.chat = AsyncMock(side_effect=[decomp_response, synth_response]) + team._experts["lead"].agent._llm_gateway = gateway - with patch.object( - orchestrator, "_run_competitor", side_effect=mock_run_competitor - ): - plan = _make_parallel_plan( - parallel_type=ParallelType.COMPETITIVE_PARALLEL, - merge_strategy=MergeStrategy.BEST, - ) - result = await orchestrator.execute_plan(plan) + result = await orchestrator.execute("并行任务") - # All competitors failed → triggers fallback - assert result["status"] == "fallback" - - -# ── Expert 不可用测试 ──────────────────────────────────── - - -class TestExpertUnavailable: - """Expert 不可用测试""" - - @pytest.mark.asyncio - async def test_inactive_expert_falls_back_to_active(self): - """分配的 Expert 不活跃时自动降级到其他可用 Expert""" - team = _make_team_with_experts() - # Mark the lead expert as inactive - team._experts["lead"].is_active = False - orchestrator = TeamOrchestrator(team) - plan = _make_serial_plan(num_phases=1) - - result = await orchestrator.execute_plan(plan) - - # Phase should complete via fallback expert (member1) assert result["status"] == "completed" + assert len(result["subtask_results"]) == 3 + # All subtasks should be completed + plan = result["plan"] + for st in plan.subtasks: + assert st.status == SubTaskStatus.COMPLETED @pytest.mark.asyncio - async def test_nonexistent_expert_falls_back_to_lead(self): - """分配的 Expert 不存在时自动降级到 lead expert""" + async def test_partial_failure_still_completes(self): + """部分子任务失败时仍能完成(只要有成功的)""" team = _make_team_with_experts() orchestrator = TeamOrchestrator(team) - plan = CollaborationPlan( - id="plan_no_expert", - task="无专家测试", - phases=[ - PlanPhase( - id="phase_1", - name="无专家阶段", - assigned_expert="nonexistent_expert", - task_description="执行任务", - ) - ], - lead_expert="lead", + # Set up LLM for decomposition + import json + gateway = AsyncMock() + decomp_response = MagicMock() + decomp_response.content = json.dumps([ + {"description": "子任务1", "assigned_expert": "member1"}, + {"description": "子任务2", "assigned_expert": "member2"}, + ]) + synth_response = MagicMock() + synth_response.content = "综合结果" + gateway.chat = AsyncMock(side_effect=[decomp_response, synth_response]) + team._experts["lead"].agent._llm_gateway = gateway + + # Make member1's agent fail + team._experts["member1"].agent.execute = AsyncMock( + side_effect=RuntimeError("member1 failed") ) - result = await orchestrator.execute_plan(plan) + result = await orchestrator.execute("部分失败任务") - # Phase should complete via fallback to lead expert assert result["status"] == "completed" - - @pytest.mark.asyncio - async def test_all_experts_unavailable_causes_failure(self): - """所有 Expert 都不可用时阶段失败""" - team = _make_team_with_experts() - # Mark all experts as inactive - for expert in team._experts.values(): - expert.is_active = False - orchestrator = TeamOrchestrator(team) - plan = _make_serial_plan(num_phases=1) - - result = await orchestrator.execute_plan(plan) - - # No expert available → phase fails → fallback - assert result["status"] == "fallback" + # member2's subtask should have succeeded + plan = result["plan"] + completed = plan.completed_subtasks + failed = plan.failed_subtasks + assert len(completed) >= 1 + assert len(failed) >= 1 diff --git a/tests/unit/tools/test_tool_search.py b/tests/unit/tools/test_tool_search.py new file mode 100644 index 0000000..43b65ec --- /dev/null +++ b/tests/unit/tools/test_tool_search.py @@ -0,0 +1,472 @@ +"""ToolSearchIndex / ToolSearchTool / ReAct 分层注入单元测试 + +测试场景: +- ToolSearchIndex: 空索引、单工具、多工具相关性排序、top_k 限制、无匹配 +- ToolSearchTool: 正常查询、空查询、无匹配、结果包含完整描述 +- ReActEngine 分层注入: core/extended 分离、tool_search 自动添加、禁用配置、自定义 core 列表 +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from agentkit.tools.base import Tool +from agentkit.tools.builtin import ToolSearchTool +from agentkit.tools.search import ToolSearchIndex + + +# ── Test Helpers ────────────────────────────────────────── + + +class FakeTool(Tool): + """用于测试的 Fake Tool""" + + def __init__( + self, + name: str, + description: str, + input_schema: dict[str, Any] | None = None, + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description, + input_schema=input_schema, + tags=tags or [], + ) + + async def execute(self, **kwargs) -> dict: + return {"status": "ok"} + + +def _make_tools() -> list[Tool]: + """创建一组测试工具""" + return [ + FakeTool( + name="read_file", + description="Read the contents of a file from the filesystem.", + input_schema={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "file path to read"}, + }, + "required": ["path"], + }, + tags=["io", "file"], + ), + FakeTool( + name="write_file", + description="Write content to a file on the filesystem.", + input_schema={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "file path to write"}, + "content": {"type": "string", "description": "content to write"}, + }, + "required": ["path", "content"], + }, + tags=["io", "file"], + ), + FakeTool( + name="web_search", + description="Search the web for information using a search engine.", + input_schema={ + "type": "object", + "properties": { + "query": {"type": "string", "description": "search query"}, + }, + "required": ["query"], + }, + tags=["web", "search"], + ), + FakeTool( + name="run_tests", + description="Run project tests to verify code changes.", + input_schema={ + "type": "object", + "properties": { + "commands": {"type": "array", "description": "test commands"}, + }, + }, + tags=["testing", "verification"], + ), + ] + + +# ── ToolSearchIndex Tests ───────────────────────────────── + + +class TestToolSearchIndex: + """ToolSearchIndex BM25 搜索测试""" + + def test_empty_tools(self): + """空工具列表构建索引不报错""" + index = ToolSearchIndex([]) + assert len(index) == 0 + assert index.search("anything") == [] + + def test_single_tool_match(self): + """单工具索引,匹配查询返回该工具""" + tools = _make_tools()[:1] + index = ToolSearchIndex(tools) + results = index.search("read file") + assert len(results) == 1 + assert results[0].name == "read_file" + + def test_relevance_ranking(self): + """多工具索引,相关工具排在前面""" + tools = _make_tools() + index = ToolSearchIndex(tools) + results = index.search("web search") + assert len(results) > 0 + # web_search 应该排在最前 + assert results[0].name == "web_search" + + def test_top_k_limit(self): + """top_k 限制返回数量""" + tools = _make_tools() + index = ToolSearchIndex(tools) + results = index.search("file", top_k=2) + assert len(results) <= 2 + + def test_no_match_returns_empty(self): + """无匹配时返回空列表""" + tools = _make_tools() + index = ToolSearchIndex(tools) + results = index.search("xyzzy_nonexistent") + assert results == [] + + def test_empty_query_returns_empty(self): + """空查询返回空列表""" + tools = _make_tools() + index = ToolSearchIndex(tools) + assert index.search("") == [] + assert index.search(" ") == [] + + def test_top_k_zero_returns_empty(self): + """top_k=0 返回空列表""" + tools = _make_tools() + index = ToolSearchIndex(tools) + assert index.search("file", top_k=0) == [] + + def test_snake_case_tokenization(self): + """snake_case 工具名被正确分词""" + tool = FakeTool( + name="read_file", + description="Read file contents.", + ) + index = ToolSearchIndex([tool]) + # 搜索 "read" 应该匹配 + results = index.search("read") + assert len(results) == 1 + # 搜索 "file" 也应该匹配 + results = index.search("file") + assert len(results) == 1 + + def test_search_includes_parameter_descriptions(self): + """搜索能匹配参数描述中的关键词""" + tool = FakeTool( + name="custom_tool", + description="A custom tool.", + input_schema={ + "type": "object", + "properties": { + "database_url": { + "type": "string", + "description": "PostgreSQL connection string", + }, + }, + }, + ) + index = ToolSearchIndex([tool]) + results = index.search("postgresql database") + assert len(results) == 1 + assert results[0].name == "custom_tool" + + def test_search_includes_tags(self): + """搜索能匹配标签中的关键词""" + tool = FakeTool( + name="data_tool", + description="Process data.", + tags=["etl", "pipeline"], + ) + index = ToolSearchIndex([tool]) + results = index.search("pipeline etl") + assert len(results) == 1 + assert results[0].name == "data_tool" + + def test_invalid_k1_raises(self): + """k1 < 0 抛出 ValueError""" + with pytest.raises(ValueError, match="k1"): + ToolSearchIndex(_make_tools(), k1=-1.0) + + def test_invalid_b_raises(self): + """b 不在 [0,1] 范围抛出 ValueError""" + with pytest.raises(ValueError, match="b"): + ToolSearchIndex(_make_tools(), b=1.5) + with pytest.raises(ValueError, match="b"): + ToolSearchIndex(_make_tools(), b=-0.1) + + def test_multiple_results_sorted_by_score(self): + """多个匹配结果按分数降序排列""" + tools = [ + FakeTool(name="search_web", description="Search the web."), + FakeTool(name="search_files", description="Search files on disk."), + FakeTool(name="unrelated", description="Do something unrelated."), + ] + index = ToolSearchIndex(tools) + results = index.search("search") + # 两个包含 "search" 的工具应该返回,unrelated 不返回 + assert len(results) == 2 + names = [r.name for r in results] + assert "unrelated" not in names + + +# ── ToolSearchTool Tests ────────────────────────────────── + + +class TestToolSearchTool: + """ToolSearchTool 工具测试""" + + def test_tool_name_and_schema(self): + """工具名称和 schema 正确""" + index = ToolSearchIndex(_make_tools()) + tool = ToolSearchTool(search_index=index) + assert tool.name == "tool_search" + assert "query" in tool.input_schema["properties"] + assert "query" in tool.input_schema["required"] + + async def test_execute_returns_results(self): + """执行搜索返回匹配工具的完整描述""" + index = ToolSearchIndex(_make_tools()) + tool = ToolSearchTool(search_index=index) + result = await tool.execute(query="web search") + + assert result["count"] > 0 + assert result["query"] == "web search" + first = result["results"][0] + assert "name" in first + assert "description" in first + assert "parameters" in first + assert first["name"] == "web_search" + + async def test_execute_empty_query_returns_error(self): + """空查询返回错误""" + index = ToolSearchIndex(_make_tools()) + tool = ToolSearchTool(search_index=index) + result = await tool.execute(query="") + assert "error" in result + assert result["results"] == [] + + async def test_execute_no_match(self): + """无匹配返回空结果和提示消息""" + index = ToolSearchIndex(_make_tools()) + tool = ToolSearchTool(search_index=index) + result = await tool.execute(query="zzz_nonexistent") + assert result["count"] == 0 + assert result["results"] == [] + assert "message" in result + + async def test_execute_respects_top_k(self): + """top_k 限制返回数量""" + index = ToolSearchIndex(_make_tools()) + tool = ToolSearchTool(search_index=index, top_k=1) + result = await tool.execute(query="file") + assert result["count"] <= 1 + + def test_invalid_top_k_raises(self): + """top_k < 1 抛出 ValueError""" + index = ToolSearchIndex(_make_tools()) + with pytest.raises(ValueError, match="top_k"): + ToolSearchTool(search_index=index, top_k=0) + + +# ── ReActEngine Tiered Injection Tests ──────────────────── + + +class TestReActTieredInjection: + """ReActEngine 工具描述分层注入测试""" + + def _make_engine(self, **kwargs: Any): + from agentkit.core.react import ReActEngine + + gateway = MagicMock() + return ReActEngine(llm_gateway=gateway, **kwargs) + + def test_core_tools_get_full_description(self): + """Core 工具注入完整描述(含参数)""" + engine = self._make_engine() + tools = [ + FakeTool( + name="read_file", + description="Read a file.", + input_schema={ + "type": "object", + "properties": { + "path": {"type": "string", "description": "file path"}, + }, + "required": ["path"], + }, + ), + ] + prompt = engine._build_tool_use_prompt(tools) + # 核心工具区域存在 + assert "核心工具" in prompt + # 参数描述被注入 + assert "path" in prompt + assert "file path" in prompt + + def test_extended_tools_get_one_line_only(self): + """Extended 工具只注入名称+一行描述(无参数)""" + engine = self._make_engine() + tools = [ + FakeTool( + name="custom_extended", + description="A custom extended tool for testing.", + input_schema={ + "type": "object", + "properties": { + "secret_param": { + "type": "string", + "description": "SECRET_PARAM_DESC", + }, + }, + }, + ), + ] + prompt = engine._build_tool_use_prompt(tools) + assert "扩展工具" in prompt + assert "custom_extended" in prompt + # 参数描述不应出现在扩展工具区域 + assert "SECRET_PARAM_DESC" not in prompt + + def test_mixed_core_and_extended(self): + """混合 core + extended 工具,两者分区显示""" + engine = self._make_engine() + tools = [ + FakeTool(name="read_file", description="Read a file."), + FakeTool( + name="web_search", + description="Search the web.", + input_schema={ + "type": "object", + "properties": { + "q": {"type": "string", "description": "query string"}, + }, + }, + ), + ] + prompt = engine._build_tool_use_prompt(tools) + assert "核心工具" in prompt + assert "扩展工具" in prompt + # read_file 在核心区,web_search 在扩展区 + assert "read_file" in prompt + assert "web_search" in prompt + + def test_maybe_add_tool_search_adds_for_extended(self): + """有扩展工具时自动添加 tool_search""" + engine = self._make_engine() + tools = [ + FakeTool(name="read_file", description="Read a file."), # core + FakeTool(name="web_search", description="Search the web."), # extended + ] + result = engine._maybe_add_tool_search(tools) + assert len(result) == 3 + assert any(t.name == "tool_search" for t in result) + + def test_maybe_add_tool_search_skips_when_only_core(self): + """只有 core 工具时不添加 tool_search""" + engine = self._make_engine() + tools = [ + FakeTool(name="read_file", description="Read a file."), + FakeTool(name="write_file", description="Write a file."), + ] + result = engine._maybe_add_tool_search(tools) + assert len(result) == 2 + assert not any(t.name == "tool_search" for t in result) + + def test_maybe_add_tool_search_skips_when_disabled(self): + """enable_tool_search=False 时不添加 tool_search""" + engine = self._make_engine(enable_tool_search=False) + tools = [ + FakeTool(name="read_file", description="Read a file."), + FakeTool(name="web_search", description="Search the web."), + ] + result = engine._maybe_add_tool_search(tools) + assert len(result) == 2 + assert not any(t.name == "tool_search" for t in result) + + def test_maybe_add_tool_search_skips_when_already_present(self): + """tool_search 已存在时不重复添加""" + engine = self._make_engine() + index = ToolSearchIndex([]) + existing_search = ToolSearchTool(search_index=index) + tools = [ + FakeTool(name="read_file", description="Read a file."), + FakeTool(name="web_search", description="Search the web."), + existing_search, + ] + result = engine._maybe_add_tool_search(tools) + assert len(result) == 3 + + def test_custom_core_tool_names(self): + """自定义 core_tool_names 覆盖默认值""" + engine = self._make_engine(core_tool_names=["my_core_tool"]) + tools = [ + FakeTool(name="my_core_tool", description="My core tool."), + FakeTool( + name="read_file", + description="Read a file.", + input_schema={ + "type": "object", + "properties": { + "p": {"type": "string", "description": "PARAM_DESC"}, + }, + }, + ), + ] + prompt = engine._build_tool_use_prompt(tools) + # my_core_tool 在核心区 + assert "核心工具" in prompt + # read_file 现在是扩展工具,参数描述不应出现 + assert "PARAM_DESC" not in prompt + + def test_tool_search_is_core_tool(self): + """tool_search 被视为 core 工具(全量描述注入)""" + engine = self._make_engine() + index = ToolSearchIndex([]) + search_tool = ToolSearchTool(search_index=index) + tools = [search_tool] + prompt = engine._build_tool_use_prompt(tools) + # tool_search 应该在核心工具区 + assert "核心工具" in prompt + assert "tool_search" in prompt + # 其参数 query 应该被注入 + assert "query" in prompt + + def test_search_hint_added_when_tool_search_present(self): + """tool_search 存在且有扩展工具时添加搜索提示""" + engine = self._make_engine() + tools = [ + FakeTool(name="read_file", description="Read a file."), + FakeTool(name="web_search", description="Search the web."), + ] + # 模拟 _maybe_add_tool_search 添加 tool_search + tools_with_search = engine._maybe_add_tool_search(tools) + prompt = engine._build_tool_use_prompt(tools_with_search) + assert "tool_search" in prompt + assert "扩展工具" in prompt + # 搜索提示存在 + assert "tool_search(query" in prompt + + def test_no_search_hint_when_no_extended_tools(self): + """无扩展工具时不添加搜索提示""" + engine = self._make_engine() + tools = [ + FakeTool(name="read_file", description="Read a file."), + ] + prompt = engine._build_tool_use_prompt(tools) + assert "扩展工具" not in prompt