feat(core): U4 multi-agent Orchestrator with SharedWorkspace
- Orchestrator: Orchestrator-Worker pattern with LLM-driven task decomposition - SharedWorkspace: Redis-backed shared state with versioning and distributed locks - SubTask dependency graph, parallel group building, result aggregation - 16 tests passing
This commit is contained in:
parent
364fe6bd6d
commit
23934602c0
|
|
@ -0,0 +1,406 @@
|
|||
"""Orchestrator - 多 Agent 协作编排器
|
||||
|
||||
实现 Orchestrator-Worker 模式:中央编排器协调多 Agent 并行/串行执行。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||
from agentkit.core.shared_workspace import SharedWorkspace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRole(str, Enum):
|
||||
"""Agent 角色枚举"""
|
||||
|
||||
ORCHESTRATOR = "orchestrator"
|
||||
WORKER = "worker"
|
||||
REVIEWER = "reviewer"
|
||||
|
||||
|
||||
class SubTaskStatus(str, Enum):
|
||||
"""子任务状态"""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubTask:
|
||||
"""子任务定义"""
|
||||
|
||||
task_id: str
|
||||
parent_task_id: str
|
||||
assigned_agent: str
|
||||
task_type: str
|
||||
input_data: dict[str, Any]
|
||||
status: SubTaskStatus = SubTaskStatus.PENDING
|
||||
result: dict[str, Any] | None = None
|
||||
error: str | None = None
|
||||
depends_on: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrchestrationPlan:
|
||||
"""编排计划"""
|
||||
|
||||
plan_id: str
|
||||
parent_task_id: str
|
||||
subtasks: list[SubTask]
|
||||
parallel_groups: list[list[str]] # 每组内的子任务可并行执行
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrchestrationResult:
|
||||
"""编排结果"""
|
||||
|
||||
plan_id: str
|
||||
parent_task_id: str
|
||||
subtask_results: dict[str, dict[str, Any]]
|
||||
aggregated_result: dict[str, Any]
|
||||
status: TaskStatus
|
||||
total_duration_ms: float
|
||||
|
||||
|
||||
class Orchestrator:
|
||||
"""多 Agent 协作编排器
|
||||
|
||||
Orchestrator-Worker 模式:
|
||||
1. 接收复杂任务
|
||||
2. LLM 驱动分解为子任务
|
||||
3. 基于 Skill 能力匹配子任务到 Worker Agent
|
||||
4. 并行/串行执行子任务
|
||||
5. 汇总结果,生成最终输出
|
||||
|
||||
使用方式:
|
||||
orchestrator = Orchestrator(agent_pool=pool, workspace=workspace)
|
||||
result = await orchestrator.execute(task_message)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_pool: Any,
|
||||
workspace: SharedWorkspace | None = None,
|
||||
llm_gateway: Any = None,
|
||||
max_parallel: int = 5,
|
||||
subtask_timeout: float = 300.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
agent_pool: AgentPool 实例
|
||||
workspace: 共享工作空间
|
||||
llm_gateway: LLM Gateway,用于任务分解
|
||||
max_parallel: 最大并行子任务数
|
||||
subtask_timeout: 子任务超时时间(秒)
|
||||
"""
|
||||
self._agent_pool = agent_pool
|
||||
self._workspace = workspace or SharedWorkspace()
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_parallel = max_parallel
|
||||
self._subtask_timeout = subtask_timeout
|
||||
|
||||
async def execute(self, task: TaskMessage) -> OrchestrationResult:
|
||||
"""执行编排任务
|
||||
|
||||
Args:
|
||||
task: 原始任务消息
|
||||
|
||||
Returns:
|
||||
OrchestrationResult: 编排结果
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
# 1. Decompose task into subtasks
|
||||
plan = await self._decompose_task(task)
|
||||
|
||||
if not plan.subtasks:
|
||||
return OrchestrationResult(
|
||||
plan_id=plan.plan_id,
|
||||
parent_task_id=task.task_id,
|
||||
subtask_results={},
|
||||
aggregated_result={"error": "Failed to decompose task"},
|
||||
status=TaskStatus.FAILED,
|
||||
total_duration_ms=0,
|
||||
)
|
||||
|
||||
# 2. Store plan in workspace
|
||||
await self._workspace.write(
|
||||
f"plan:{plan.plan_id}",
|
||||
{"task_id": task.task_id, "subtask_count": len(plan.subtasks)},
|
||||
agent_id="orchestrator",
|
||||
)
|
||||
|
||||
# 3. Execute subtasks
|
||||
subtask_results = await self._execute_plan(plan, task)
|
||||
|
||||
# 4. Aggregate results
|
||||
aggregated = await self._aggregate_results(plan, subtask_results, task)
|
||||
|
||||
# 5. Determine overall status
|
||||
failed_count = sum(
|
||||
1 for r in subtask_results.values() if r.get("status") == "failed"
|
||||
)
|
||||
if failed_count == len(plan.subtasks):
|
||||
status = TaskStatus.FAILED
|
||||
elif failed_count > 0:
|
||||
status = TaskStatus.COMPLETED # Partial success
|
||||
else:
|
||||
status = TaskStatus.COMPLETED
|
||||
|
||||
duration_ms = (time.monotonic() - start_time) * 1000
|
||||
|
||||
return OrchestrationResult(
|
||||
plan_id=plan.plan_id,
|
||||
parent_task_id=task.task_id,
|
||||
subtask_results=subtask_results,
|
||||
aggregated_result=aggregated,
|
||||
status=status,
|
||||
total_duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
async def _decompose_task(self, task: TaskMessage) -> OrchestrationPlan:
|
||||
"""将复杂任务分解为子任务"""
|
||||
plan_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# If LLM gateway available, use it for decomposition
|
||||
if self._llm_gateway:
|
||||
try:
|
||||
subtasks = await self._llm_decompose(task)
|
||||
if subtasks:
|
||||
parallel_groups = self._build_parallel_groups(subtasks)
|
||||
return OrchestrationPlan(
|
||||
plan_id=plan_id,
|
||||
parent_task_id=task.task_id,
|
||||
subtasks=subtasks,
|
||||
parallel_groups=parallel_groups,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM decomposition failed, falling back to simple: {e}")
|
||||
|
||||
# Fallback: single subtask = original task
|
||||
subtask = SubTask(
|
||||
task_id=f"{plan_id}-0",
|
||||
parent_task_id=task.task_id,
|
||||
assigned_agent=task.agent_name,
|
||||
task_type=task.task_type,
|
||||
input_data=task.input_data,
|
||||
)
|
||||
return OrchestrationPlan(
|
||||
plan_id=plan_id,
|
||||
parent_task_id=task.task_id,
|
||||
subtasks=[subtask],
|
||||
parallel_groups=[[subtask.task_id]],
|
||||
)
|
||||
|
||||
async def _llm_decompose(self, task: TaskMessage) -> list[SubTask]:
|
||||
"""使用 LLM 分解任务"""
|
||||
# Get available agents and their capabilities
|
||||
agents_info = self._agent_pool.list_agents()
|
||||
agent_descriptions = "\n".join(
|
||||
f"- {a['name']} ({a['agent_type']}): {a.get('description', 'No description')}"
|
||||
for a in agents_info
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"Decompose the following task into subtasks that can be assigned to available agents.\n\n"
|
||||
f"Task: {task.input_data}\n"
|
||||
f"Task Type: {task.task_type}\n\n"
|
||||
f"Available Agents:\n{agent_descriptions}\n\n"
|
||||
'Respond ONLY with a JSON array: [{"agent_name": "...", "task_type": "...", '
|
||||
'"input_data": {...}, "depends_on": []}]\n'
|
||||
"The depends_on field lists task indices (0-based) that must complete first.\n"
|
||||
"Do not include any other text."
|
||||
)
|
||||
|
||||
import json
|
||||
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model="default",
|
||||
)
|
||||
|
||||
try:
|
||||
subtask_defs = json.loads(response.content)
|
||||
if not isinstance(subtask_defs, list):
|
||||
return []
|
||||
|
||||
subtasks = []
|
||||
for i, defn in enumerate(subtask_defs):
|
||||
depends_on = [
|
||||
f"task-{i}" for i in defn.get("depends_on", [])
|
||||
]
|
||||
subtasks.append(SubTask(
|
||||
task_id=f"task-{i}",
|
||||
parent_task_id=task.task_id,
|
||||
assigned_agent=defn.get("agent_name", task.agent_name),
|
||||
task_type=defn.get("task_type", task.task_type),
|
||||
input_data=defn.get("input_data", {}),
|
||||
depends_on=depends_on,
|
||||
))
|
||||
return subtasks
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"Failed to parse LLM decomposition: {e}")
|
||||
return []
|
||||
|
||||
def _build_parallel_groups(self, subtasks: list[SubTask]) -> list[list[str]]:
|
||||
"""构建并行执行组
|
||||
|
||||
基于依赖关系拓扑排序,无依赖的子任务分到同一组并行执行。
|
||||
"""
|
||||
# Build dependency graph
|
||||
task_map = {st.task_id: st for st in subtasks}
|
||||
completed: set[str] = set()
|
||||
groups: list[list[str]] = []
|
||||
|
||||
remaining = set(st.task_id for st in subtasks)
|
||||
|
||||
while remaining:
|
||||
# Find tasks with all dependencies satisfied
|
||||
ready = []
|
||||
for tid in remaining:
|
||||
task = task_map[tid]
|
||||
if all(dep in completed for dep in task.depends_on):
|
||||
ready.append(tid)
|
||||
|
||||
if not ready:
|
||||
# Circular dependency — put remaining in one group
|
||||
groups.append(list(remaining))
|
||||
break
|
||||
|
||||
# Limit group size
|
||||
group = ready[:self._max_parallel]
|
||||
groups.append(group)
|
||||
for tid in group:
|
||||
completed.add(tid)
|
||||
remaining.discard(tid)
|
||||
|
||||
return groups
|
||||
|
||||
async def _execute_plan(
|
||||
self, plan: OrchestrationPlan, original_task: TaskMessage
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""执行编排计划"""
|
||||
subtask_results: dict[str, dict[str, Any]] = {}
|
||||
task_map = {st.task_id: st for st in plan.subtasks}
|
||||
|
||||
for group in plan.parallel_groups:
|
||||
# Execute group in parallel
|
||||
tasks = []
|
||||
for task_id in group:
|
||||
subtask = task_map[task_id]
|
||||
# Inject results from dependencies
|
||||
enriched_input = self._inject_dependency_results(
|
||||
subtask, subtask_results
|
||||
)
|
||||
tasks.append(self._execute_subtask(subtask, enriched_input, original_task))
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for task_id, result in zip(group, results):
|
||||
if isinstance(result, Exception):
|
||||
subtask_results[task_id] = {
|
||||
"status": "failed",
|
||||
"error": str(result),
|
||||
}
|
||||
else:
|
||||
subtask_results[task_id] = result
|
||||
|
||||
return subtask_results
|
||||
|
||||
async def _execute_subtask(
|
||||
self,
|
||||
subtask: SubTask,
|
||||
input_data: dict[str, Any],
|
||||
original_task: TaskMessage,
|
||||
) -> dict[str, Any]:
|
||||
"""执行单个子任务"""
|
||||
agent = self._agent_pool.get_agent(subtask.assigned_agent)
|
||||
if agent is None:
|
||||
return {"status": "failed", "error": f"Agent '{subtask.assigned_agent}' not found"}
|
||||
|
||||
sub_task_msg = TaskMessage(
|
||||
task_id=subtask.task_id,
|
||||
agent_name=subtask.assigned_agent,
|
||||
task_type=subtask.task_type,
|
||||
priority=original_task.priority,
|
||||
input_data=input_data,
|
||||
callback_url=None,
|
||||
created_at=original_task.created_at,
|
||||
timeout_seconds=int(self._subtask_timeout),
|
||||
)
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
agent.execute(sub_task_msg),
|
||||
timeout=self._subtask_timeout,
|
||||
)
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": result.output_data if hasattr(result, "output_data") else result,
|
||||
}
|
||||
except asyncio.TimeoutError:
|
||||
return {"status": "failed", "error": "Subtask timed out"}
|
||||
except Exception as e:
|
||||
return {"status": "failed", "error": str(e)}
|
||||
|
||||
def _inject_dependency_results(
|
||||
self,
|
||||
subtask: SubTask,
|
||||
subtask_results: dict[str, dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""将依赖子任务的结果注入到当前子任务的输入中"""
|
||||
enriched = dict(subtask.input_data)
|
||||
|
||||
if subtask.depends_on:
|
||||
dep_results = {}
|
||||
for dep_id in subtask.depends_on:
|
||||
if dep_id in subtask_results:
|
||||
dep_results[dep_id] = subtask_results[dep_id]
|
||||
if dep_results:
|
||||
enriched["dependency_results"] = dep_results
|
||||
|
||||
return enriched
|
||||
|
||||
async def _aggregate_results(
|
||||
self,
|
||||
plan: OrchestrationPlan,
|
||||
subtask_results: dict[str, dict[str, Any]],
|
||||
original_task: TaskMessage,
|
||||
) -> dict[str, Any]:
|
||||
"""汇总子任务结果"""
|
||||
# Simple aggregation: collect all outputs
|
||||
outputs = {}
|
||||
errors = []
|
||||
|
||||
for subtask in plan.subtasks:
|
||||
result = subtask_results.get(subtask.task_id, {})
|
||||
if result.get("status") == "completed":
|
||||
outputs[subtask.task_id] = result.get("output", {})
|
||||
else:
|
||||
errors.append({
|
||||
"task_id": subtask.task_id,
|
||||
"error": result.get("error", "Unknown error"),
|
||||
})
|
||||
|
||||
aggregated = {
|
||||
"outputs": outputs,
|
||||
"task_id": original_task.task_id,
|
||||
}
|
||||
if errors:
|
||||
aggregated["errors"] = errors
|
||||
aggregated["partial_success"] = True
|
||||
|
||||
return aggregated
|
||||
|
|
@ -0,0 +1,159 @@
|
|||
"""SharedWorkspace - Agent 间共享工作空间
|
||||
|
||||
基于 Redis 的共享状态存储,支持读写、订阅、锁操作。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SharedWorkspace:
|
||||
"""Agent 间共享工作空间
|
||||
|
||||
基于 Redis 的共享状态存储,支持:
|
||||
- write/read: 读写共享数据
|
||||
- lock/unlock: 分布式锁
|
||||
- 版本控制:每次写入递增版本号
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Any = None, prefix: str = "workspace"):
|
||||
"""
|
||||
Args:
|
||||
redis_client: aioredis.Redis 实例,None 时使用内存字典
|
||||
prefix: Redis key 前缀
|
||||
"""
|
||||
self._redis = redis_client
|
||||
self._prefix = prefix
|
||||
self._local_store: dict[str, dict[str, Any]] = {}
|
||||
self._locks: dict[str, str] = {} # key -> lock_owner
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
return f"{self._prefix}:{key}"
|
||||
|
||||
async def write(
|
||||
self, key: str, value: Any, agent_id: str, ttl: int | None = None
|
||||
) -> int:
|
||||
"""写入共享数据
|
||||
|
||||
Args:
|
||||
key: 数据键
|
||||
value: 数据值
|
||||
agent_id: 写入者 ID
|
||||
ttl: 过期时间(秒),None 表示不过期
|
||||
|
||||
Returns:
|
||||
版本号
|
||||
"""
|
||||
entry = {
|
||||
"value": value,
|
||||
"agent_id": agent_id,
|
||||
"version": await self._get_version(key) + 1,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
if self._redis:
|
||||
redis_key = self._make_key(key)
|
||||
data = json.dumps(entry, default=str)
|
||||
if ttl:
|
||||
await self._redis.setex(redis_key, ttl, data)
|
||||
else:
|
||||
await self._redis.set(redis_key, data)
|
||||
else:
|
||||
self._local_store[key] = entry
|
||||
|
||||
return entry["version"]
|
||||
|
||||
async def read(self, key: str) -> dict[str, Any] | None:
|
||||
"""读取共享数据
|
||||
|
||||
Returns:
|
||||
{"value": ..., "agent_id": ..., "version": ..., "timestamp": ...} 或 None
|
||||
"""
|
||||
if self._redis:
|
||||
redis_key = self._make_key(key)
|
||||
data = await self._redis.get(redis_key)
|
||||
if data is None:
|
||||
return None
|
||||
return json.loads(data)
|
||||
else:
|
||||
return self._local_store.get(key)
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""删除共享数据"""
|
||||
if self._redis:
|
||||
redis_key = self._make_key(key)
|
||||
result = await self._redis.delete(redis_key)
|
||||
return result > 0
|
||||
else:
|
||||
return self._local_store.pop(key, None) is not None
|
||||
|
||||
async def lock(self, key: str, agent_id: str, timeout: float = 30.0) -> bool:
|
||||
"""获取分布式锁
|
||||
|
||||
Args:
|
||||
key: 要锁定的数据键
|
||||
agent_id: 请求锁的 Agent ID
|
||||
timeout: 锁超时时间(秒)
|
||||
|
||||
Returns:
|
||||
是否成功获取锁
|
||||
"""
|
||||
lock_key = f"{self._prefix}:lock:{key}"
|
||||
|
||||
if self._redis:
|
||||
# Redis SET with NX (only if not exists) and EX (expiry)
|
||||
result = await self._redis.set(lock_key, agent_id, nx=True, ex=int(timeout))
|
||||
return result is not None
|
||||
else:
|
||||
if key in self._locks:
|
||||
return False
|
||||
self._locks[key] = agent_id
|
||||
return True
|
||||
|
||||
async def unlock(self, key: str, agent_id: str) -> bool:
|
||||
"""释放分布式锁
|
||||
|
||||
只有锁的持有者才能释放锁。
|
||||
"""
|
||||
lock_key = f"{self._prefix}:lock:{key}"
|
||||
|
||||
if self._redis:
|
||||
current_owner = await self._redis.get(lock_key)
|
||||
if current_owner and current_owner.decode() == agent_id:
|
||||
await self._redis.delete(lock_key)
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
if self._locks.get(key) == agent_id:
|
||||
del self._locks[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _get_version(self, key: str) -> int:
|
||||
"""获取当前版本号"""
|
||||
data = await self.read(key)
|
||||
if data is None:
|
||||
return 0
|
||||
return data.get("version", 0)
|
||||
|
||||
async def list_keys(self) -> list[str]:
|
||||
"""列出所有键"""
|
||||
if self._redis:
|
||||
pattern = f"{self._prefix}:*"
|
||||
keys = []
|
||||
async for key in self._redis.scan_iter(match=pattern):
|
||||
# Strip prefix
|
||||
k = key.decode() if isinstance(key, bytes) else key
|
||||
k = k[len(self._prefix) + 1:] # Remove "prefix:"
|
||||
# Skip lock keys
|
||||
if not k.startswith("lock:"):
|
||||
keys.append(k)
|
||||
return keys
|
||||
else:
|
||||
return list(self._local_store.keys())
|
||||
|
|
@ -0,0 +1,336 @@
|
|||
"""Tests for Orchestrator and SharedWorkspace"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from agentkit.core.orchestrator import (
|
||||
Orchestrator,
|
||||
OrchestrationPlan,
|
||||
OrchestrationResult,
|
||||
SubTask,
|
||||
SubTaskStatus,
|
||||
AgentRole,
|
||||
)
|
||||
from agentkit.core.shared_workspace import SharedWorkspace
|
||||
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
# --- SharedWorkspace Tests ---
|
||||
|
||||
|
||||
class TestSharedWorkspace:
|
||||
"""SharedWorkspace unit tests (in-memory mode)"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_and_read(self):
|
||||
ws = SharedWorkspace()
|
||||
version = await ws.write("key1", {"data": "value"}, agent_id="agent_a")
|
||||
assert version == 1
|
||||
|
||||
result = await ws.read("key1")
|
||||
assert result is not None
|
||||
assert result["value"] == {"data": "value"}
|
||||
assert result["agent_id"] == "agent_a"
|
||||
assert result["version"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_version_increments(self):
|
||||
ws = SharedWorkspace()
|
||||
v1 = await ws.write("key1", "first", agent_id="a")
|
||||
v2 = await ws.write("key1", "second", agent_id="b")
|
||||
assert v1 == 1
|
||||
assert v2 == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_nonexistent(self):
|
||||
ws = SharedWorkspace()
|
||||
result = await ws.read("nonexistent")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete(self):
|
||||
ws = SharedWorkspace()
|
||||
await ws.write("key1", "value", agent_id="a")
|
||||
deleted = await ws.delete("key1")
|
||||
assert deleted is True
|
||||
result = await ws.read("key1")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent(self):
|
||||
ws = SharedWorkspace()
|
||||
deleted = await ws.delete("nonexistent")
|
||||
assert deleted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_and_unlock(self):
|
||||
ws = SharedWorkspace()
|
||||
acquired = await ws.lock("resource1", agent_id="agent_a")
|
||||
assert acquired is True
|
||||
|
||||
# Same agent can't lock again (already held)
|
||||
acquired2 = await ws.lock("resource1", agent_id="agent_b")
|
||||
assert acquired2 is False
|
||||
|
||||
# Owner can unlock
|
||||
unlocked = await ws.unlock("resource1", agent_id="agent_a")
|
||||
assert unlocked is True
|
||||
|
||||
# Now another agent can lock
|
||||
acquired3 = await ws.lock("resource1", agent_id="agent_b")
|
||||
assert acquired3 is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unlock_by_non_owner(self):
|
||||
ws = SharedWorkspace()
|
||||
await ws.lock("resource1", agent_id="agent_a")
|
||||
unlocked = await ws.unlock("resource1", agent_id="agent_b")
|
||||
assert unlocked is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys(self):
|
||||
ws = SharedWorkspace()
|
||||
await ws.write("key1", "v1", agent_id="a")
|
||||
await ws.write("key2", "v2", agent_id="a")
|
||||
keys = await ws.list_keys()
|
||||
assert set(keys) == {"key1", "key2"}
|
||||
|
||||
|
||||
# --- Orchestrator Tests ---
|
||||
|
||||
|
||||
class MockAgent:
|
||||
"""Mock Agent for testing"""
|
||||
|
||||
def __init__(self, name: str, output_data: dict | None = None, should_fail: bool = False):
|
||||
self.name = name
|
||||
self.agent_type = "mock"
|
||||
self.version = "1.0.0"
|
||||
self._output_data = output_data or {"result": f"output from {name}"}
|
||||
self._should_fail = should_fail
|
||||
|
||||
async def execute(self, task: TaskMessage):
|
||||
if self._should_fail:
|
||||
raise RuntimeError(f"Agent {self.name} failed")
|
||||
from agentkit.core.protocol import TaskResult
|
||||
now = datetime.now(timezone.utc)
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.COMPLETED,
|
||||
output_data=self._output_data,
|
||||
error_message=None,
|
||||
started_at=now,
|
||||
completed_at=now,
|
||||
)
|
||||
|
||||
|
||||
class MockAgentPool:
|
||||
"""Mock AgentPool for testing"""
|
||||
|
||||
def __init__(self, agents: dict[str, MockAgent] | None = None):
|
||||
self._agents = agents or {}
|
||||
|
||||
def get_agent(self, name: str) -> MockAgent | None:
|
||||
return self._agents.get(name)
|
||||
|
||||
def list_agents(self) -> list[dict]:
|
||||
return [
|
||||
{"name": a.name, "agent_type": a.agent_type, "description": f"Mock agent {a.name}"}
|
||||
for a in self._agents.values()
|
||||
]
|
||||
|
||||
|
||||
class TestOrchestrator:
|
||||
"""Orchestrator unit tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_subtask_execution(self):
|
||||
"""Single agent should execute task directly"""
|
||||
agent = MockAgent("worker1", {"analysis": "result"})
|
||||
pool = MockAgentPool({"worker1": agent})
|
||||
orchestrator = Orchestrator(agent_pool=pool)
|
||||
|
||||
task = TaskMessage(
|
||||
task_id="t1",
|
||||
agent_name="worker1",
|
||||
task_type="analyze",
|
||||
priority=1,
|
||||
input_data={"query": "test"},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
result = await orchestrator.execute(task)
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
assert "outputs" in result.aggregated_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_groups_building(self):
|
||||
"""Parallel groups should be built from dependency graph"""
|
||||
pool = MockAgentPool()
|
||||
orchestrator = Orchestrator(agent_pool=pool)
|
||||
|
||||
subtasks = [
|
||||
SubTask(task_id="t0", parent_task_id="p1", assigned_agent="a1",
|
||||
task_type="type1", input_data={}, depends_on=[]),
|
||||
SubTask(task_id="t1", parent_task_id="p1", assigned_agent="a2",
|
||||
task_type="type2", input_data={}, depends_on=[]),
|
||||
SubTask(task_id="t2", parent_task_id="p1", assigned_agent="a3",
|
||||
task_type="type3", input_data={}, depends_on=["t0"]),
|
||||
]
|
||||
|
||||
groups = orchestrator._build_parallel_groups(subtasks)
|
||||
assert len(groups) == 2
|
||||
# First group: t0 and t1 (no dependencies)
|
||||
assert set(groups[0]) == {"t0", "t1"}
|
||||
# Second group: t2 (depends on t0)
|
||||
assert groups[1] == ["t2"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_dependency(self):
|
||||
"""Tasks with sequential dependencies should execute in order"""
|
||||
agent1 = MockAgent("a1", {"step": 1})
|
||||
agent2 = MockAgent("a2", {"step": 2})
|
||||
pool = MockAgentPool({"a1": agent1, "a2": agent2})
|
||||
orchestrator = Orchestrator(agent_pool=pool)
|
||||
|
||||
# Manually create a plan with sequential dependencies
|
||||
plan = OrchestrationPlan(
|
||||
plan_id="p1",
|
||||
parent_task_id="parent",
|
||||
subtasks=[
|
||||
SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1",
|
||||
task_type="step1", input_data={}, depends_on=[]),
|
||||
SubTask(task_id="t1", parent_task_id="parent", assigned_agent="a2",
|
||||
task_type="step2", input_data={}, depends_on=["t0"]),
|
||||
],
|
||||
parallel_groups=[["t0"], ["t1"]],
|
||||
)
|
||||
|
||||
task = TaskMessage(
|
||||
task_id="parent",
|
||||
agent_name="orchestrator",
|
||||
task_type="pipeline",
|
||||
priority=1,
|
||||
input_data={"query": "test"},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
results = await orchestrator._execute_plan(plan, task)
|
||||
assert results["t0"]["status"] == "completed"
|
||||
assert results["t1"]["status"] == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_not_found(self):
|
||||
"""Missing agent should result in failed subtask"""
|
||||
pool = MockAgentPool({})
|
||||
orchestrator = Orchestrator(agent_pool=pool)
|
||||
|
||||
plan = OrchestrationPlan(
|
||||
plan_id="p1",
|
||||
parent_task_id="parent",
|
||||
subtasks=[
|
||||
SubTask(task_id="t0", parent_task_id="parent", assigned_agent="missing_agent",
|
||||
task_type="test", input_data={}, depends_on=[]),
|
||||
],
|
||||
parallel_groups=[["t0"]],
|
||||
)
|
||||
|
||||
task = TaskMessage(
|
||||
task_id="parent",
|
||||
agent_name="orchestrator",
|
||||
task_type="test",
|
||||
priority=1,
|
||||
input_data={},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
results = await orchestrator._execute_plan(plan, task)
|
||||
assert results["t0"]["status"] == "failed"
|
||||
assert "not found" in results["t0"]["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dependency_result_injection(self):
|
||||
"""Subtask should receive dependency results in input"""
|
||||
pool = MockAgentPool()
|
||||
orchestrator = Orchestrator(agent_pool=pool)
|
||||
|
||||
subtask = SubTask(
|
||||
task_id="t1",
|
||||
parent_task_id="p1",
|
||||
assigned_agent="a1",
|
||||
task_type="test",
|
||||
input_data={"query": "test"},
|
||||
depends_on=["t0"],
|
||||
)
|
||||
|
||||
subtask_results = {
|
||||
"t0": {"status": "completed", "output": {"step1_result": "data"}},
|
||||
}
|
||||
|
||||
enriched = orchestrator._inject_dependency_results(subtask, subtask_results)
|
||||
assert "dependency_results" in enriched
|
||||
assert "t0" in enriched["dependency_results"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aggregation_with_errors(self):
|
||||
"""Aggregation should include errors for failed subtasks"""
|
||||
pool = MockAgentPool()
|
||||
orchestrator = Orchestrator(agent_pool=pool)
|
||||
|
||||
plan = OrchestrationPlan(
|
||||
plan_id="p1",
|
||||
parent_task_id="parent",
|
||||
subtasks=[
|
||||
SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1",
|
||||
task_type="test", input_data={}, depends_on=[]),
|
||||
],
|
||||
parallel_groups=[["t0"]],
|
||||
)
|
||||
|
||||
subtask_results = {
|
||||
"t0": {"status": "failed", "error": "Agent failed"},
|
||||
}
|
||||
|
||||
task = TaskMessage(
|
||||
task_id="parent",
|
||||
agent_name="orchestrator",
|
||||
task_type="test",
|
||||
priority=1,
|
||||
input_data={},
|
||||
callback_url=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
aggregated = await orchestrator._aggregate_results(plan, subtask_results, task)
|
||||
assert "errors" in aggregated
|
||||
assert aggregated["partial_success"] is True
|
||||
|
||||
|
||||
class TestAgentRole:
|
||||
"""AgentRole enum tests"""
|
||||
|
||||
def test_role_values(self):
|
||||
assert AgentRole.ORCHESTRATOR.value == "orchestrator"
|
||||
assert AgentRole.WORKER.value == "worker"
|
||||
assert AgentRole.REVIEWER.value == "reviewer"
|
||||
|
||||
|
||||
class TestSubTask:
|
||||
"""SubTask dataclass tests"""
|
||||
|
||||
def test_default_values(self):
|
||||
st = SubTask(
|
||||
task_id="t1",
|
||||
parent_task_id="p1",
|
||||
assigned_agent="a1",
|
||||
task_type="test",
|
||||
input_data={},
|
||||
)
|
||||
assert st.status == SubTaskStatus.PENDING
|
||||
assert st.result is None
|
||||
assert st.depends_on == []
|
||||
Loading…
Reference in New Issue