From 23934602c0f1f77d17646dd65e2f78eaf546a60d Mon Sep 17 00:00:00 2001 From: chiguyong Date: Sat, 6 Jun 2026 22:25:12 +0800 Subject: [PATCH] feat(core): U4 multi-agent Orchestrator with SharedWorkspace - Orchestrator: Orchestrator-Worker pattern with LLM-driven task decomposition - SharedWorkspace: Redis-backed shared state with versioning and distributed locks - SubTask dependency graph, parallel group building, result aggregation - 16 tests passing --- src/agentkit/core/orchestrator.py | 406 ++++++++++++++++++++++++++ src/agentkit/core/shared_workspace.py | 159 ++++++++++ tests/unit/test_orchestrator.py | 336 +++++++++++++++++++++ 3 files changed, 901 insertions(+) create mode 100644 src/agentkit/core/orchestrator.py create mode 100644 src/agentkit/core/shared_workspace.py create mode 100644 tests/unit/test_orchestrator.py diff --git a/src/agentkit/core/orchestrator.py b/src/agentkit/core/orchestrator.py new file mode 100644 index 0000000..558ae84 --- /dev/null +++ b/src/agentkit/core/orchestrator.py @@ -0,0 +1,406 @@ +"""Orchestrator - 多 Agent 协作编排器 + +实现 Orchestrator-Worker 模式:中央编排器协调多 Agent 并行/串行执行。 +""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus +from agentkit.core.shared_workspace import SharedWorkspace + +logger = logging.getLogger(__name__) + + +class AgentRole(str, Enum): + """Agent 角色枚举""" + + ORCHESTRATOR = "orchestrator" + WORKER = "worker" + REVIEWER = "reviewer" + + +class SubTaskStatus(str, Enum): + """子任务状态""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +@dataclass +class SubTask: + """子任务定义""" + + task_id: str + parent_task_id: str + assigned_agent: str + task_type: str + input_data: dict[str, Any] + status: SubTaskStatus = SubTaskStatus.PENDING + result: dict[str, Any] | None = None + error: str | None = None + depends_on: list[str] = field(default_factory=list) + + +@dataclass +class OrchestrationPlan: + """编排计划""" + + plan_id: str + parent_task_id: str + subtasks: list[SubTask] + parallel_groups: list[list[str]] # 每组内的子任务可并行执行 + + +@dataclass +class OrchestrationResult: + """编排结果""" + + plan_id: str + parent_task_id: str + subtask_results: dict[str, dict[str, Any]] + aggregated_result: dict[str, Any] + status: TaskStatus + total_duration_ms: float + + +class Orchestrator: + """多 Agent 协作编排器 + + Orchestrator-Worker 模式: + 1. 接收复杂任务 + 2. LLM 驱动分解为子任务 + 3. 基于 Skill 能力匹配子任务到 Worker Agent + 4. 并行/串行执行子任务 + 5. 汇总结果,生成最终输出 + + 使用方式: + orchestrator = Orchestrator(agent_pool=pool, workspace=workspace) + result = await orchestrator.execute(task_message) + """ + + def __init__( + self, + agent_pool: Any, + workspace: SharedWorkspace | None = None, + llm_gateway: Any = None, + max_parallel: int = 5, + subtask_timeout: float = 300.0, + ): + """ + Args: + agent_pool: AgentPool 实例 + workspace: 共享工作空间 + llm_gateway: LLM Gateway,用于任务分解 + max_parallel: 最大并行子任务数 + subtask_timeout: 子任务超时时间(秒) + """ + self._agent_pool = agent_pool + self._workspace = workspace or SharedWorkspace() + self._llm_gateway = llm_gateway + self._max_parallel = max_parallel + self._subtask_timeout = subtask_timeout + + async def execute(self, task: TaskMessage) -> OrchestrationResult: + """执行编排任务 + + Args: + task: 原始任务消息 + + Returns: + OrchestrationResult: 编排结果 + """ + import time + + start_time = time.monotonic() + + # 1. Decompose task into subtasks + plan = await self._decompose_task(task) + + if not plan.subtasks: + return OrchestrationResult( + plan_id=plan.plan_id, + parent_task_id=task.task_id, + subtask_results={}, + aggregated_result={"error": "Failed to decompose task"}, + status=TaskStatus.FAILED, + total_duration_ms=0, + ) + + # 2. Store plan in workspace + await self._workspace.write( + f"plan:{plan.plan_id}", + {"task_id": task.task_id, "subtask_count": len(plan.subtasks)}, + agent_id="orchestrator", + ) + + # 3. Execute subtasks + subtask_results = await self._execute_plan(plan, task) + + # 4. Aggregate results + aggregated = await self._aggregate_results(plan, subtask_results, task) + + # 5. Determine overall status + failed_count = sum( + 1 for r in subtask_results.values() if r.get("status") == "failed" + ) + if failed_count == len(plan.subtasks): + status = TaskStatus.FAILED + elif failed_count > 0: + status = TaskStatus.COMPLETED # Partial success + else: + status = TaskStatus.COMPLETED + + duration_ms = (time.monotonic() - start_time) * 1000 + + return OrchestrationResult( + plan_id=plan.plan_id, + parent_task_id=task.task_id, + subtask_results=subtask_results, + aggregated_result=aggregated, + status=status, + total_duration_ms=duration_ms, + ) + + async def _decompose_task(self, task: TaskMessage) -> OrchestrationPlan: + """将复杂任务分解为子任务""" + plan_id = str(uuid.uuid4())[:8] + + # If LLM gateway available, use it for decomposition + if self._llm_gateway: + try: + subtasks = await self._llm_decompose(task) + if subtasks: + parallel_groups = self._build_parallel_groups(subtasks) + return OrchestrationPlan( + plan_id=plan_id, + parent_task_id=task.task_id, + subtasks=subtasks, + parallel_groups=parallel_groups, + ) + except Exception as e: + logger.warning(f"LLM decomposition failed, falling back to simple: {e}") + + # Fallback: single subtask = original task + subtask = SubTask( + task_id=f"{plan_id}-0", + parent_task_id=task.task_id, + assigned_agent=task.agent_name, + task_type=task.task_type, + input_data=task.input_data, + ) + return OrchestrationPlan( + plan_id=plan_id, + parent_task_id=task.task_id, + subtasks=[subtask], + parallel_groups=[[subtask.task_id]], + ) + + async def _llm_decompose(self, task: TaskMessage) -> list[SubTask]: + """使用 LLM 分解任务""" + # Get available agents and their capabilities + agents_info = self._agent_pool.list_agents() + agent_descriptions = "\n".join( + f"- {a['name']} ({a['agent_type']}): {a.get('description', 'No description')}" + for a in agents_info + ) + + prompt = ( + f"Decompose the following task into subtasks that can be assigned to available agents.\n\n" + f"Task: {task.input_data}\n" + f"Task Type: {task.task_type}\n\n" + f"Available Agents:\n{agent_descriptions}\n\n" + 'Respond ONLY with a JSON array: [{"agent_name": "...", "task_type": "...", ' + '"input_data": {...}, "depends_on": []}]\n' + "The depends_on field lists task indices (0-based) that must complete first.\n" + "Do not include any other text." + ) + + import json + + response = await self._llm_gateway.chat( + messages=[{"role": "user", "content": prompt}], + model="default", + ) + + try: + subtask_defs = json.loads(response.content) + if not isinstance(subtask_defs, list): + return [] + + subtasks = [] + for i, defn in enumerate(subtask_defs): + depends_on = [ + f"task-{i}" for i in defn.get("depends_on", []) + ] + subtasks.append(SubTask( + task_id=f"task-{i}", + parent_task_id=task.task_id, + assigned_agent=defn.get("agent_name", task.agent_name), + task_type=defn.get("task_type", task.task_type), + input_data=defn.get("input_data", {}), + depends_on=depends_on, + )) + return subtasks + except (json.JSONDecodeError, KeyError) as e: + logger.warning(f"Failed to parse LLM decomposition: {e}") + return [] + + def _build_parallel_groups(self, subtasks: list[SubTask]) -> list[list[str]]: + """构建并行执行组 + + 基于依赖关系拓扑排序,无依赖的子任务分到同一组并行执行。 + """ + # Build dependency graph + task_map = {st.task_id: st for st in subtasks} + completed: set[str] = set() + groups: list[list[str]] = [] + + remaining = set(st.task_id for st in subtasks) + + while remaining: + # Find tasks with all dependencies satisfied + ready = [] + for tid in remaining: + task = task_map[tid] + if all(dep in completed for dep in task.depends_on): + ready.append(tid) + + if not ready: + # Circular dependency — put remaining in one group + groups.append(list(remaining)) + break + + # Limit group size + group = ready[:self._max_parallel] + groups.append(group) + for tid in group: + completed.add(tid) + remaining.discard(tid) + + return groups + + async def _execute_plan( + self, plan: OrchestrationPlan, original_task: TaskMessage + ) -> dict[str, dict[str, Any]]: + """执行编排计划""" + subtask_results: dict[str, dict[str, Any]] = {} + task_map = {st.task_id: st for st in plan.subtasks} + + for group in plan.parallel_groups: + # Execute group in parallel + tasks = [] + for task_id in group: + subtask = task_map[task_id] + # Inject results from dependencies + enriched_input = self._inject_dependency_results( + subtask, subtask_results + ) + tasks.append(self._execute_subtask(subtask, enriched_input, original_task)) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for task_id, result in zip(group, results): + if isinstance(result, Exception): + subtask_results[task_id] = { + "status": "failed", + "error": str(result), + } + else: + subtask_results[task_id] = result + + return subtask_results + + async def _execute_subtask( + self, + subtask: SubTask, + input_data: dict[str, Any], + original_task: TaskMessage, + ) -> dict[str, Any]: + """执行单个子任务""" + agent = self._agent_pool.get_agent(subtask.assigned_agent) + if agent is None: + return {"status": "failed", "error": f"Agent '{subtask.assigned_agent}' not found"} + + sub_task_msg = TaskMessage( + task_id=subtask.task_id, + agent_name=subtask.assigned_agent, + task_type=subtask.task_type, + priority=original_task.priority, + input_data=input_data, + callback_url=None, + created_at=original_task.created_at, + timeout_seconds=int(self._subtask_timeout), + ) + + try: + result = await asyncio.wait_for( + agent.execute(sub_task_msg), + timeout=self._subtask_timeout, + ) + return { + "status": "completed", + "output": result.output_data if hasattr(result, "output_data") else result, + } + except asyncio.TimeoutError: + return {"status": "failed", "error": "Subtask timed out"} + except Exception as e: + return {"status": "failed", "error": str(e)} + + def _inject_dependency_results( + self, + subtask: SubTask, + subtask_results: dict[str, dict[str, Any]], + ) -> dict[str, Any]: + """将依赖子任务的结果注入到当前子任务的输入中""" + enriched = dict(subtask.input_data) + + if subtask.depends_on: + dep_results = {} + for dep_id in subtask.depends_on: + if dep_id in subtask_results: + dep_results[dep_id] = subtask_results[dep_id] + if dep_results: + enriched["dependency_results"] = dep_results + + return enriched + + async def _aggregate_results( + self, + plan: OrchestrationPlan, + subtask_results: dict[str, dict[str, Any]], + original_task: TaskMessage, + ) -> dict[str, Any]: + """汇总子任务结果""" + # Simple aggregation: collect all outputs + outputs = {} + errors = [] + + for subtask in plan.subtasks: + result = subtask_results.get(subtask.task_id, {}) + if result.get("status") == "completed": + outputs[subtask.task_id] = result.get("output", {}) + else: + errors.append({ + "task_id": subtask.task_id, + "error": result.get("error", "Unknown error"), + }) + + aggregated = { + "outputs": outputs, + "task_id": original_task.task_id, + } + if errors: + aggregated["errors"] = errors + aggregated["partial_success"] = True + + return aggregated diff --git a/src/agentkit/core/shared_workspace.py b/src/agentkit/core/shared_workspace.py new file mode 100644 index 0000000..702a720 --- /dev/null +++ b/src/agentkit/core/shared_workspace.py @@ -0,0 +1,159 @@ +"""SharedWorkspace - Agent 间共享工作空间 + +基于 Redis 的共享状态存储,支持读写、订阅、锁操作。 +""" + +from __future__ import annotations + +import json +import logging +import time +from typing import Any + +logger = logging.getLogger(__name__) + + +class SharedWorkspace: + """Agent 间共享工作空间 + + 基于 Redis 的共享状态存储,支持: + - write/read: 读写共享数据 + - lock/unlock: 分布式锁 + - 版本控制:每次写入递增版本号 + """ + + def __init__(self, redis_client: Any = None, prefix: str = "workspace"): + """ + Args: + redis_client: aioredis.Redis 实例,None 时使用内存字典 + prefix: Redis key 前缀 + """ + self._redis = redis_client + self._prefix = prefix + self._local_store: dict[str, dict[str, Any]] = {} + self._locks: dict[str, str] = {} # key -> lock_owner + + def _make_key(self, key: str) -> str: + return f"{self._prefix}:{key}" + + async def write( + self, key: str, value: Any, agent_id: str, ttl: int | None = None + ) -> int: + """写入共享数据 + + Args: + key: 数据键 + value: 数据值 + agent_id: 写入者 ID + ttl: 过期时间(秒),None 表示不过期 + + Returns: + 版本号 + """ + entry = { + "value": value, + "agent_id": agent_id, + "version": await self._get_version(key) + 1, + "timestamp": time.time(), + } + + if self._redis: + redis_key = self._make_key(key) + data = json.dumps(entry, default=str) + if ttl: + await self._redis.setex(redis_key, ttl, data) + else: + await self._redis.set(redis_key, data) + else: + self._local_store[key] = entry + + return entry["version"] + + async def read(self, key: str) -> dict[str, Any] | None: + """读取共享数据 + + Returns: + {"value": ..., "agent_id": ..., "version": ..., "timestamp": ...} 或 None + """ + if self._redis: + redis_key = self._make_key(key) + data = await self._redis.get(redis_key) + if data is None: + return None + return json.loads(data) + else: + return self._local_store.get(key) + + async def delete(self, key: str) -> bool: + """删除共享数据""" + if self._redis: + redis_key = self._make_key(key) + result = await self._redis.delete(redis_key) + return result > 0 + else: + return self._local_store.pop(key, None) is not None + + async def lock(self, key: str, agent_id: str, timeout: float = 30.0) -> bool: + """获取分布式锁 + + Args: + key: 要锁定的数据键 + agent_id: 请求锁的 Agent ID + timeout: 锁超时时间(秒) + + Returns: + 是否成功获取锁 + """ + lock_key = f"{self._prefix}:lock:{key}" + + if self._redis: + # Redis SET with NX (only if not exists) and EX (expiry) + result = await self._redis.set(lock_key, agent_id, nx=True, ex=int(timeout)) + return result is not None + else: + if key in self._locks: + return False + self._locks[key] = agent_id + return True + + async def unlock(self, key: str, agent_id: str) -> bool: + """释放分布式锁 + + 只有锁的持有者才能释放锁。 + """ + lock_key = f"{self._prefix}:lock:{key}" + + if self._redis: + current_owner = await self._redis.get(lock_key) + if current_owner and current_owner.decode() == agent_id: + await self._redis.delete(lock_key) + return True + return False + else: + if self._locks.get(key) == agent_id: + del self._locks[key] + return True + return False + + async def _get_version(self, key: str) -> int: + """获取当前版本号""" + data = await self.read(key) + if data is None: + return 0 + return data.get("version", 0) + + async def list_keys(self) -> list[str]: + """列出所有键""" + if self._redis: + pattern = f"{self._prefix}:*" + keys = [] + async for key in self._redis.scan_iter(match=pattern): + # Strip prefix + k = key.decode() if isinstance(key, bytes) else key + k = k[len(self._prefix) + 1:] # Remove "prefix:" + # Skip lock keys + if not k.startswith("lock:"): + keys.append(k) + return keys + else: + return list(self._local_store.keys()) diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py new file mode 100644 index 0000000..3f343aa --- /dev/null +++ b/tests/unit/test_orchestrator.py @@ -0,0 +1,336 @@ +"""Tests for Orchestrator and SharedWorkspace""" + +import asyncio +import pytest + +from agentkit.core.orchestrator import ( + Orchestrator, + OrchestrationPlan, + OrchestrationResult, + SubTask, + SubTaskStatus, + AgentRole, +) +from agentkit.core.shared_workspace import SharedWorkspace +from agentkit.core.protocol import TaskMessage, TaskStatus +from datetime import datetime, timezone + + +# --- SharedWorkspace Tests --- + + +class TestSharedWorkspace: + """SharedWorkspace unit tests (in-memory mode)""" + + @pytest.mark.asyncio + async def test_write_and_read(self): + ws = SharedWorkspace() + version = await ws.write("key1", {"data": "value"}, agent_id="agent_a") + assert version == 1 + + result = await ws.read("key1") + assert result is not None + assert result["value"] == {"data": "value"} + assert result["agent_id"] == "agent_a" + assert result["version"] == 1 + + @pytest.mark.asyncio + async def test_version_increments(self): + ws = SharedWorkspace() + v1 = await ws.write("key1", "first", agent_id="a") + v2 = await ws.write("key1", "second", agent_id="b") + assert v1 == 1 + assert v2 == 2 + + @pytest.mark.asyncio + async def test_read_nonexistent(self): + ws = SharedWorkspace() + result = await ws.read("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_delete(self): + ws = SharedWorkspace() + await ws.write("key1", "value", agent_id="a") + deleted = await ws.delete("key1") + assert deleted is True + result = await ws.read("key1") + assert result is None + + @pytest.mark.asyncio + async def test_delete_nonexistent(self): + ws = SharedWorkspace() + deleted = await ws.delete("nonexistent") + assert deleted is False + + @pytest.mark.asyncio + async def test_lock_and_unlock(self): + ws = SharedWorkspace() + acquired = await ws.lock("resource1", agent_id="agent_a") + assert acquired is True + + # Same agent can't lock again (already held) + acquired2 = await ws.lock("resource1", agent_id="agent_b") + assert acquired2 is False + + # Owner can unlock + unlocked = await ws.unlock("resource1", agent_id="agent_a") + assert unlocked is True + + # Now another agent can lock + acquired3 = await ws.lock("resource1", agent_id="agent_b") + assert acquired3 is True + + @pytest.mark.asyncio + async def test_unlock_by_non_owner(self): + ws = SharedWorkspace() + await ws.lock("resource1", agent_id="agent_a") + unlocked = await ws.unlock("resource1", agent_id="agent_b") + assert unlocked is False + + @pytest.mark.asyncio + async def test_list_keys(self): + ws = SharedWorkspace() + await ws.write("key1", "v1", agent_id="a") + await ws.write("key2", "v2", agent_id="a") + keys = await ws.list_keys() + assert set(keys) == {"key1", "key2"} + + +# --- Orchestrator Tests --- + + +class MockAgent: + """Mock Agent for testing""" + + def __init__(self, name: str, output_data: dict | None = None, should_fail: bool = False): + self.name = name + self.agent_type = "mock" + self.version = "1.0.0" + self._output_data = output_data or {"result": f"output from {name}"} + self._should_fail = should_fail + + async def execute(self, task: TaskMessage): + if self._should_fail: + raise RuntimeError(f"Agent {self.name} failed") + from agentkit.core.protocol import TaskResult + now = datetime.now(timezone.utc) + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.COMPLETED, + output_data=self._output_data, + error_message=None, + started_at=now, + completed_at=now, + ) + + +class MockAgentPool: + """Mock AgentPool for testing""" + + def __init__(self, agents: dict[str, MockAgent] | None = None): + self._agents = agents or {} + + def get_agent(self, name: str) -> MockAgent | None: + return self._agents.get(name) + + def list_agents(self) -> list[dict]: + return [ + {"name": a.name, "agent_type": a.agent_type, "description": f"Mock agent {a.name}"} + for a in self._agents.values() + ] + + +class TestOrchestrator: + """Orchestrator unit tests""" + + @pytest.mark.asyncio + async def test_single_subtask_execution(self): + """Single agent should execute task directly""" + agent = MockAgent("worker1", {"analysis": "result"}) + pool = MockAgentPool({"worker1": agent}) + orchestrator = Orchestrator(agent_pool=pool) + + task = TaskMessage( + task_id="t1", + agent_name="worker1", + task_type="analyze", + priority=1, + input_data={"query": "test"}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + result = await orchestrator.execute(task) + assert result.status == TaskStatus.COMPLETED + assert "outputs" in result.aggregated_result + + @pytest.mark.asyncio + async def test_parallel_groups_building(self): + """Parallel groups should be built from dependency graph""" + pool = MockAgentPool() + orchestrator = Orchestrator(agent_pool=pool) + + subtasks = [ + SubTask(task_id="t0", parent_task_id="p1", assigned_agent="a1", + task_type="type1", input_data={}, depends_on=[]), + SubTask(task_id="t1", parent_task_id="p1", assigned_agent="a2", + task_type="type2", input_data={}, depends_on=[]), + SubTask(task_id="t2", parent_task_id="p1", assigned_agent="a3", + task_type="type3", input_data={}, depends_on=["t0"]), + ] + + groups = orchestrator._build_parallel_groups(subtasks) + assert len(groups) == 2 + # First group: t0 and t1 (no dependencies) + assert set(groups[0]) == {"t0", "t1"} + # Second group: t2 (depends on t0) + assert groups[1] == ["t2"] + + @pytest.mark.asyncio + async def test_sequential_dependency(self): + """Tasks with sequential dependencies should execute in order""" + agent1 = MockAgent("a1", {"step": 1}) + agent2 = MockAgent("a2", {"step": 2}) + pool = MockAgentPool({"a1": agent1, "a2": agent2}) + orchestrator = Orchestrator(agent_pool=pool) + + # Manually create a plan with sequential dependencies + plan = OrchestrationPlan( + plan_id="p1", + parent_task_id="parent", + subtasks=[ + SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1", + task_type="step1", input_data={}, depends_on=[]), + SubTask(task_id="t1", parent_task_id="parent", assigned_agent="a2", + task_type="step2", input_data={}, depends_on=["t0"]), + ], + parallel_groups=[["t0"], ["t1"]], + ) + + task = TaskMessage( + task_id="parent", + agent_name="orchestrator", + task_type="pipeline", + priority=1, + input_data={"query": "test"}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + results = await orchestrator._execute_plan(plan, task) + assert results["t0"]["status"] == "completed" + assert results["t1"]["status"] == "completed" + + @pytest.mark.asyncio + async def test_agent_not_found(self): + """Missing agent should result in failed subtask""" + pool = MockAgentPool({}) + orchestrator = Orchestrator(agent_pool=pool) + + plan = OrchestrationPlan( + plan_id="p1", + parent_task_id="parent", + subtasks=[ + SubTask(task_id="t0", parent_task_id="parent", assigned_agent="missing_agent", + task_type="test", input_data={}, depends_on=[]), + ], + parallel_groups=[["t0"]], + ) + + task = TaskMessage( + task_id="parent", + agent_name="orchestrator", + task_type="test", + priority=1, + input_data={}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + results = await orchestrator._execute_plan(plan, task) + assert results["t0"]["status"] == "failed" + assert "not found" in results["t0"]["error"] + + @pytest.mark.asyncio + async def test_dependency_result_injection(self): + """Subtask should receive dependency results in input""" + pool = MockAgentPool() + orchestrator = Orchestrator(agent_pool=pool) + + subtask = SubTask( + task_id="t1", + parent_task_id="p1", + assigned_agent="a1", + task_type="test", + input_data={"query": "test"}, + depends_on=["t0"], + ) + + subtask_results = { + "t0": {"status": "completed", "output": {"step1_result": "data"}}, + } + + enriched = orchestrator._inject_dependency_results(subtask, subtask_results) + assert "dependency_results" in enriched + assert "t0" in enriched["dependency_results"] + + @pytest.mark.asyncio + async def test_aggregation_with_errors(self): + """Aggregation should include errors for failed subtasks""" + pool = MockAgentPool() + orchestrator = Orchestrator(agent_pool=pool) + + plan = OrchestrationPlan( + plan_id="p1", + parent_task_id="parent", + subtasks=[ + SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1", + task_type="test", input_data={}, depends_on=[]), + ], + parallel_groups=[["t0"]], + ) + + subtask_results = { + "t0": {"status": "failed", "error": "Agent failed"}, + } + + task = TaskMessage( + task_id="parent", + agent_name="orchestrator", + task_type="test", + priority=1, + input_data={}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + aggregated = await orchestrator._aggregate_results(plan, subtask_results, task) + assert "errors" in aggregated + assert aggregated["partial_success"] is True + + +class TestAgentRole: + """AgentRole enum tests""" + + def test_role_values(self): + assert AgentRole.ORCHESTRATOR.value == "orchestrator" + assert AgentRole.WORKER.value == "worker" + assert AgentRole.REVIEWER.value == "reviewer" + + +class TestSubTask: + """SubTask dataclass tests""" + + def test_default_values(self): + st = SubTask( + task_id="t1", + parent_task_id="p1", + assigned_agent="a1", + task_type="test", + input_data={}, + ) + assert st.status == SubTaskStatus.PENDING + assert st.result is None + assert st.depends_on == []