feat(core): U4 multi-agent Orchestrator with SharedWorkspace

- Orchestrator: Orchestrator-Worker pattern with LLM-driven task decomposition
- SharedWorkspace: Redis-backed shared state with versioning and distributed locks
- SubTask dependency graph, parallel group building, result aggregation
- 16 tests passing
This commit is contained in:
chiguyong 2026-06-06 22:25:12 +08:00
parent 364fe6bd6d
commit 23934602c0
3 changed files with 901 additions and 0 deletions

View File

@ -0,0 +1,406 @@
"""Orchestrator - 多 Agent 协作编排器
实现 Orchestrator-Worker 模式中央编排器协调多 Agent 并行/串行执行
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
from agentkit.core.shared_workspace import SharedWorkspace
logger = logging.getLogger(__name__)
class AgentRole(str, Enum):
    """Roles an agent can take in a multi-agent collaboration.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Coordinates the other agents.
    ORCHESTRATOR = "orchestrator"
    # Carries out an assigned subtask.
    WORKER = "worker"
    # Reviewer role.
    REVIEWER = "reviewer"
class SubTaskStatus(str, Enum):
    """Lifecycle states of a subtask.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
@dataclass
class SubTask:
    """Definition of a single unit of work carved out of a parent task."""

    # Identifier of this subtask.
    task_id: str
    # Identifier of the task this subtask was derived from.
    parent_task_id: str
    # Name of the agent that should run this subtask.
    assigned_agent: str
    # Task type forwarded to the agent.
    task_type: str
    # Input payload forwarded to the agent.
    input_data: dict[str, Any]
    # Current lifecycle state; new subtasks start as PENDING.
    status: SubTaskStatus = SubTaskStatus.PENDING
    # Result payload, if any.
    result: dict[str, Any] | None = None
    # Error description, if any.
    error: str | None = None
    # IDs of subtasks that must finish before this one starts.
    # default_factory gives every instance its own list.
    depends_on: list[str] = field(default_factory=list)
@dataclass
class OrchestrationPlan:
    """Execution plan produced by decomposing a parent task."""

    # Identifier of this plan.
    plan_id: str
    # Identifier of the task the plan was built for.
    parent_task_id: str
    # All subtasks that belong to the plan.
    subtasks: list[SubTask]
    # Ordered groups of subtask IDs; the subtasks inside one group may run in parallel.
    parallel_groups: list[list[str]]
@dataclass
class OrchestrationResult:
    """Outcome of running an orchestration plan."""

    # Identifier of the plan that was executed.
    plan_id: str
    # Identifier of the original task.
    parent_task_id: str
    # Per-subtask result dictionaries, keyed by subtask ID.
    subtask_results: dict[str, dict[str, Any]]
    # Combined output built from the subtask results.
    aggregated_result: dict[str, Any]
    # Overall status of the orchestration.
    status: TaskStatus
    # Wall-clock duration of the orchestration in milliseconds.
    total_duration_ms: float
class Orchestrator:
    """Central coordinator for multi-agent execution (Orchestrator-Worker pattern).

    Workflow:
        1. Receive a complex task.
        2. Decompose it into subtasks (LLM-driven when a gateway is configured,
           otherwise the task is run as a single subtask).
        3. Dispatch every subtask to the agent named in the plan.
        4. Run the plan group by group; subtasks in one group run concurrently.
        5. Aggregate the subtask results into the final output.

    Usage:
        orchestrator = Orchestrator(agent_pool=pool, workspace=workspace)
        result = await orchestrator.execute(task_message)
    """

    def __init__(
        self,
        agent_pool: Any,
        workspace: SharedWorkspace | None = None,
        llm_gateway: Any = None,
        max_parallel: int = 5,
        subtask_timeout: float = 300.0,
    ):
        """
        Args:
            agent_pool: Agent pool; must provide ``get_agent(name)`` and
                ``list_agents()``.
            workspace: Shared workspace; a fresh ``SharedWorkspace()`` is
                created when omitted.
            llm_gateway: LLM gateway used for task decomposition; when falsy,
                the task is executed as a single subtask.
            max_parallel: Maximum number of subtasks per parallel group.
                Values below 1 are treated as 1 when groups are built.
            subtask_timeout: Per-subtask timeout in seconds.
        """
        self._agent_pool = agent_pool
        self._workspace = workspace or SharedWorkspace()
        self._llm_gateway = llm_gateway
        self._max_parallel = max_parallel
        self._subtask_timeout = subtask_timeout

    async def execute(self, task: TaskMessage) -> OrchestrationResult:
        """Decompose ``task``, run the resulting plan and aggregate the results.

        Args:
            task: The original task message.

        Returns:
            OrchestrationResult: Status is FAILED when the plan is empty or
            every subtask failed; otherwise COMPLETED (including partial
            success, which is reported inside ``aggregated_result``).
        """
        import time

        start_time = time.monotonic()

        # 1. Decompose the task into subtasks.
        plan = await self._decompose_task(task)
        if not plan.subtasks:
            return OrchestrationResult(
                plan_id=plan.plan_id,
                parent_task_id=task.task_id,
                subtask_results={},
                aggregated_result={"error": "Failed to decompose task"},
                status=TaskStatus.FAILED,
                # Report the time actually spent decomposing instead of a constant 0.
                total_duration_ms=(time.monotonic() - start_time) * 1000,
            )

        # 2. Record the plan in the shared workspace.
        await self._workspace.write(
            f"plan:{plan.plan_id}",
            {"task_id": task.task_id, "subtask_count": len(plan.subtasks)},
            agent_id="orchestrator",
        )

        # 3. Execute the subtasks.
        subtask_results = await self._execute_plan(plan, task)

        # 4. Aggregate the results.
        aggregated = await self._aggregate_results(plan, subtask_results, task)

        # 5. Overall status: FAILED only when every subtask failed; a partial
        #    failure still counts as COMPLETED.
        failed_count = sum(
            1 for r in subtask_results.values() if r.get("status") == "failed"
        )
        if failed_count == len(plan.subtasks):
            status = TaskStatus.FAILED
        else:
            status = TaskStatus.COMPLETED

        return OrchestrationResult(
            plan_id=plan.plan_id,
            parent_task_id=task.task_id,
            subtask_results=subtask_results,
            aggregated_result=aggregated,
            status=status,
            total_duration_ms=(time.monotonic() - start_time) * 1000,
        )

    async def _decompose_task(self, task: TaskMessage) -> OrchestrationPlan:
        """Build an orchestration plan for ``task``.

        Uses the LLM gateway when one is configured. If that fails or yields
        no subtasks, falls back to a plan with a single subtask that mirrors
        the original task.
        """
        plan_id = str(uuid.uuid4())[:8]

        if self._llm_gateway:
            try:
                subtasks = await self._llm_decompose(task)
                if subtasks:
                    return OrchestrationPlan(
                        plan_id=plan_id,
                        parent_task_id=task.task_id,
                        subtasks=subtasks,
                        parallel_groups=self._build_parallel_groups(subtasks),
                    )
            except Exception as e:
                # Best effort: any decomposition failure degrades to the fallback below.
                logger.warning(
                    "LLM decomposition failed, falling back to simple: %s", e
                )

        # Fallback: single subtask = original task.
        subtask = SubTask(
            task_id=f"{plan_id}-0",
            parent_task_id=task.task_id,
            assigned_agent=task.agent_name,
            task_type=task.task_type,
            input_data=task.input_data,
        )
        return OrchestrationPlan(
            plan_id=plan_id,
            parent_task_id=task.task_id,
            subtasks=[subtask],
            parallel_groups=[[subtask.task_id]],
        )

    async def _llm_decompose(self, task: TaskMessage) -> list[SubTask]:
        """Ask the LLM gateway to split ``task`` into subtasks.

        Returns:
            The parsed subtasks, with IDs ``task-<index>``. An empty list is
            returned when the response is not a JSON array or an entry is
            malformed.
        """
        import json

        # Describe the available agents so the model can assign work to them.
        agents_info = self._agent_pool.list_agents()
        agent_descriptions = "\n".join(
            f"- {a['name']} ({a['agent_type']}): {a.get('description', 'No description')}"
            for a in agents_info
        )
        prompt = (
            f"Decompose the following task into subtasks that can be assigned to available agents.\n\n"
            f"Task: {task.input_data}\n"
            f"Task Type: {task.task_type}\n\n"
            f"Available Agents:\n{agent_descriptions}\n\n"
            'Respond ONLY with a JSON array: [{"agent_name": "...", "task_type": "...", '
            '"input_data": {...}, "depends_on": []}]\n'
            "The depends_on field lists task indices (0-based) that must complete first.\n"
            "Do not include any other text."
        )
        response = await self._llm_gateway.chat(
            messages=[{"role": "user", "content": prompt}],
            model="default",
        )
        try:
            subtask_defs = json.loads(response.content)
            if not isinstance(subtask_defs, list):
                return []
            subtasks = []
            for index, defn in enumerate(subtask_defs):
                # The model refers to dependencies by 0-based index; map them to
                # task IDs. A null/missing depends_on means "no dependencies".
                depends_on = [
                    f"task-{dep_index}"
                    for dep_index in (defn.get("depends_on") or [])
                ]
                subtasks.append(SubTask(
                    task_id=f"task-{index}",
                    parent_task_id=task.task_id,
                    assigned_agent=defn.get("agent_name", task.agent_name),
                    task_type=defn.get("task_type", task.task_type),
                    input_data=defn.get("input_data", {}),
                    depends_on=depends_on,
                ))
            return subtasks
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
            # TypeError/AttributeError cover non-string content, non-dict
            # entries and non-iterable depends_on values.
            logger.warning("Failed to parse LLM decomposition: %s", e)
            return []

    def _build_parallel_groups(self, subtasks: list[SubTask]) -> list[list[str]]:
        """Group subtask IDs into ordered batches that can run in parallel.

        Each pass collects the subtasks whose dependencies have all been
        placed in an earlier group and emits at most ``max_parallel`` of them
        as one group. Subtasks keep the order in which they appear in
        ``subtasks``, so the result is deterministic.

        If no remaining subtask is ready (circular or unknown dependency),
        all remaining subtasks are emitted together as one final group.
        """
        task_map = {st.task_id: st for st in subtasks}
        # A limit below 1 would produce empty groups forever; clamp it.
        limit = max(1, self._max_parallel)
        completed: set[str] = set()
        groups: list[list[str]] = []
        pending = list(task_map)  # dict keys: de-duplicated, input order kept

        while pending:
            ready = [
                tid for tid in pending
                if all(dep in completed for dep in task_map[tid].depends_on)
            ]
            if not ready:
                # Circular dependency — put remaining in one group.
                groups.append(pending)
                break
            group = ready[:limit]
            groups.append(group)
            completed.update(group)
            pending = [tid for tid in pending if tid not in completed]
        return groups

    async def _execute_plan(
        self, plan: OrchestrationPlan, original_task: TaskMessage
    ) -> dict[str, dict[str, Any]]:
        """Run the plan group by group and collect one result per subtask.

        Groups run sequentially; the subtasks inside a group run concurrently.
        An exception raised by a subtask is recorded as a failed result rather
        than propagated.
        """
        subtask_results: dict[str, dict[str, Any]] = {}
        task_map = {st.task_id: st for st in plan.subtasks}

        for group in plan.parallel_groups:
            pending = []
            for task_id in group:
                subtask = task_map[task_id]
                # Make results of earlier groups visible to dependants.
                enriched_input = self._inject_dependency_results(
                    subtask, subtask_results
                )
                pending.append(
                    self._execute_subtask(subtask, enriched_input, original_task)
                )

            results = await asyncio.gather(*pending, return_exceptions=True)
            for task_id, result in zip(group, results):
                if isinstance(result, Exception):
                    subtask_results[task_id] = {
                        "status": "failed",
                        "error": str(result),
                    }
                else:
                    subtask_results[task_id] = result
        return subtask_results

    async def _execute_subtask(
        self,
        subtask: SubTask,
        input_data: dict[str, Any],
        original_task: TaskMessage,
    ) -> dict[str, Any]:
        """Run one subtask on its assigned agent.

        Returns:
            ``{"status": "completed", "output": ...}`` on success, or
            ``{"status": "failed", "error": ...}`` when the agent is missing,
            the call times out, or the agent raises.
        """
        agent = self._agent_pool.get_agent(subtask.assigned_agent)
        if agent is None:
            return {"status": "failed", "error": f"Agent '{subtask.assigned_agent}' not found"}

        # Priority and creation time are inherited from the original task.
        sub_task_msg = TaskMessage(
            task_id=subtask.task_id,
            agent_name=subtask.assigned_agent,
            task_type=subtask.task_type,
            priority=original_task.priority,
            input_data=input_data,
            callback_url=None,
            created_at=original_task.created_at,
            timeout_seconds=int(self._subtask_timeout),
        )
        try:
            result = await asyncio.wait_for(
                agent.execute(sub_task_msg),
                timeout=self._subtask_timeout,
            )
            return {
                "status": "completed",
                # Agents may return an object with output_data or a raw value.
                "output": result.output_data if hasattr(result, "output_data") else result,
            }
        except asyncio.TimeoutError:
            return {"status": "failed", "error": "Subtask timed out"}
        except Exception as e:
            return {"status": "failed", "error": str(e)}

    def _inject_dependency_results(
        self,
        subtask: SubTask,
        subtask_results: dict[str, dict[str, Any]],
    ) -> dict[str, Any]:
        """Return a copy of the subtask input enriched with dependency results.

        Results of the dependencies already present in ``subtask_results``
        are added under the ``"dependency_results"`` key. The subtask's own
        ``input_data`` is not modified.
        """
        enriched = dict(subtask.input_data)
        dep_results = {
            dep_id: subtask_results[dep_id]
            for dep_id in subtask.depends_on
            if dep_id in subtask_results
        }
        if dep_results:
            enriched["dependency_results"] = dep_results
        return enriched

    async def _aggregate_results(
        self,
        plan: OrchestrationPlan,
        subtask_results: dict[str, dict[str, Any]],
        original_task: TaskMessage,
    ) -> dict[str, Any]:
        """Combine the subtask results into a single dictionary.

        Returns:
            A dict with ``"outputs"`` (completed subtask outputs by ID) and
            ``"task_id"``. When any subtask did not complete, ``"errors"``
            and ``"partial_success": True`` are added as well.
        """
        outputs = {}
        errors = []
        for subtask in plan.subtasks:
            result = subtask_results.get(subtask.task_id, {})
            if result.get("status") == "completed":
                outputs[subtask.task_id] = result.get("output", {})
            else:
                errors.append({
                    "task_id": subtask.task_id,
                    "error": result.get("error", "Unknown error"),
                })

        aggregated = {
            "outputs": outputs,
            "task_id": original_task.task_id,
        }
        if errors:
            aggregated["errors"] = errors
            aggregated["partial_success"] = True
        return aggregated

View File

@ -0,0 +1,159 @@
"""SharedWorkspace - Agent 间共享工作空间
基于 Redis 的共享状态存储支持读写订阅锁操作
"""
from __future__ import annotations
import json
import logging
import time
from typing import Any
logger = logging.getLogger(__name__)
class SharedWorkspace:
    """Shared key/value workspace for cooperating agents.

    Backed by Redis when a client is supplied, otherwise by in-process
    dictionaries. Provides:
    - write/read/delete of shared entries
    - lock/unlock: owner-checked locks with an expiry
    - versioning: every write stores the previous version + 1

    NOTE(review): the version bump in ``write`` is a read followed by a
    write, so concurrent writers to the same key may store the same
    version number.
    """

    # Atomic compare-and-delete: release the lock only if the caller still
    # owns it. A separate GET followed by DEL could delete a lock that
    # expired and was re-acquired by another agent in between.
    _UNLOCK_SCRIPT = (
        "if redis.call('get', KEYS[1]) == ARGV[1] then "
        "return redis.call('del', KEYS[1]) "
        "else return 0 end"
    )

    def __init__(self, redis_client: Any = None, prefix: str = "workspace"):
        """
        Args:
            redis_client: Async Redis client; when falsy, in-memory
                dictionaries are used instead.
            prefix: Prefix prepended to every Redis key.
        """
        self._redis = redis_client
        self._prefix = prefix
        self._local_store: dict[str, dict[str, Any]] = {}
        self._local_expiry: dict[str, float] = {}  # key -> monotonic deadline
        self._locks: dict[str, str] = {}  # key -> lock_owner
        self._lock_expiry: dict[str, float] = {}  # key -> monotonic deadline

    def _make_key(self, key: str) -> str:
        """Return the Redis key for a workspace key."""
        return f"{self._prefix}:{key}"

    def _evict_if_expired(self, key: str) -> None:
        """Drop an in-memory entry whose TTL has elapsed."""
        deadline = self._local_expiry.get(key)
        if deadline is not None and time.monotonic() >= deadline:
            self._local_store.pop(key, None)
            self._local_expiry.pop(key, None)

    def _drop_expired_lock(self, key: str) -> None:
        """Drop an in-memory lock whose timeout has elapsed."""
        deadline = self._lock_expiry.get(key)
        if deadline is not None and time.monotonic() >= deadline:
            self._locks.pop(key, None)
            self._lock_expiry.pop(key, None)

    async def write(
        self, key: str, value: Any, agent_id: str, ttl: int | None = None
    ) -> int:
        """Write a shared entry.

        Args:
            key: Entry key.
            value: Entry value.
            agent_id: ID of the writer.
            ttl: Time to live in seconds; None (or 0) means no expiry.

        Returns:
            The version number stored with the entry.
        """
        entry = {
            "value": value,
            "agent_id": agent_id,
            "version": await self._get_version(key) + 1,
            "timestamp": time.time(),
        }
        if self._redis:
            redis_key = self._make_key(key)
            # default=str stringifies values json cannot serialize natively.
            data = json.dumps(entry, default=str)
            if ttl:
                await self._redis.setex(redis_key, ttl, data)
            else:
                await self._redis.set(redis_key, data)
        else:
            self._local_store[key] = entry
            # Mirror the Redis branch: a write with a TTL sets a deadline,
            # a write without one clears any previous deadline.
            if ttl:
                self._local_expiry[key] = time.monotonic() + ttl
            else:
                self._local_expiry.pop(key, None)
        return entry["version"]

    async def read(self, key: str) -> dict[str, Any] | None:
        """Read a shared entry.

        Returns:
            ``{"value": ..., "agent_id": ..., "version": ..., "timestamp": ...}``
            or None when the key does not exist (or has expired).
        """
        if self._redis:
            data = await self._redis.get(self._make_key(key))
            if data is None:
                return None
            return json.loads(data)
        self._evict_if_expired(key)
        return self._local_store.get(key)

    async def delete(self, key: str) -> bool:
        """Delete a shared entry; return True when something was removed."""
        if self._redis:
            removed = await self._redis.delete(self._make_key(key))
            return removed > 0
        self._evict_if_expired(key)
        self._local_expiry.pop(key, None)
        return self._local_store.pop(key, None) is not None

    async def lock(self, key: str, agent_id: str, timeout: float = 30.0) -> bool:
        """Try to acquire the lock for ``key`` without waiting.

        Args:
            key: Key to lock.
            agent_id: ID of the agent requesting the lock.
            timeout: Seconds after which the lock expires on its own.

        Returns:
            True when the lock was acquired, False when it is held.
        """
        lock_key = f"{self._prefix}:lock:{key}"
        if self._redis:
            # SET NX PX: millisecond expiry keeps sub-second timeouts valid
            # (int(timeout) would truncate them to 0, which Redis rejects).
            ttl_ms = max(1, int(timeout * 1000))
            result = await self._redis.set(lock_key, agent_id, nx=True, px=ttl_ms)
            return bool(result)
        self._drop_expired_lock(key)
        if key in self._locks:
            return False
        self._locks[key] = agent_id
        self._lock_expiry[key] = time.monotonic() + timeout
        return True

    async def unlock(self, key: str, agent_id: str) -> bool:
        """Release the lock for ``key``.

        Only the current holder can release a lock.

        Returns:
            True when the lock was released, False when ``agent_id`` does
            not hold it (including when it has already expired).
        """
        lock_key = f"{self._prefix}:lock:{key}"
        if self._redis:
            # The ownership check and the delete run as one atomic script.
            released = await self._redis.eval(
                self._UNLOCK_SCRIPT, 1, lock_key, agent_id
            )
            return bool(released)
        self._drop_expired_lock(key)
        if self._locks.get(key) == agent_id:
            del self._locks[key]
            self._lock_expiry.pop(key, None)
            return True
        return False

    async def _get_version(self, key: str) -> int:
        """Return the stored version of ``key``, or 0 when it does not exist."""
        data = await self.read(key)
        if data is None:
            return 0
        return data.get("version", 0)

    async def list_keys(self) -> list[str]:
        """List all entry keys (without the prefix); lock keys are excluded."""
        if self._redis:
            pattern = f"{self._prefix}:*"
            keys = []
            async for raw_key in self._redis.scan_iter(match=pattern):
                # Clients may return bytes or str depending on decode_responses.
                name = raw_key.decode() if isinstance(raw_key, bytes) else raw_key
                name = name[len(self._prefix) + 1:]  # Remove "prefix:"
                if not name.startswith("lock:"):
                    keys.append(name)
            return keys
        for key in list(self._local_store):
            self._evict_if_expired(key)
        return list(self._local_store)

View File

@ -0,0 +1,336 @@
"""Tests for Orchestrator and SharedWorkspace"""
import asyncio
import pytest
from agentkit.core.orchestrator import (
Orchestrator,
OrchestrationPlan,
OrchestrationResult,
SubTask,
SubTaskStatus,
AgentRole,
)
from agentkit.core.shared_workspace import SharedWorkspace
from agentkit.core.protocol import TaskMessage, TaskStatus
from datetime import datetime, timezone
# --- SharedWorkspace Tests ---
class TestSharedWorkspace:
    """Unit tests for SharedWorkspace using the in-memory backend."""

    @pytest.mark.asyncio
    async def test_write_and_read(self):
        """A written entry can be read back together with its metadata."""
        workspace = SharedWorkspace()
        written_version = await workspace.write(
            "key1", {"data": "value"}, agent_id="agent_a"
        )
        assert written_version == 1

        entry = await workspace.read("key1")
        assert entry is not None
        assert entry["value"] == {"data": "value"}
        assert entry["agent_id"] == "agent_a"
        assert entry["version"] == 1

    @pytest.mark.asyncio
    async def test_version_increments(self):
        """Each write to the same key bumps the version by one."""
        workspace = SharedWorkspace()
        first = await workspace.write("key1", "first", agent_id="a")
        second = await workspace.write("key1", "second", agent_id="b")
        assert first == 1
        assert second == 2

    @pytest.mark.asyncio
    async def test_read_nonexistent(self):
        """Reading an unknown key yields None."""
        workspace = SharedWorkspace()
        missing = await workspace.read("nonexistent")
        assert missing is None

    @pytest.mark.asyncio
    async def test_delete(self):
        """Deleting an existing key reports True and removes the entry."""
        workspace = SharedWorkspace()
        await workspace.write("key1", "value", agent_id="a")

        was_deleted = await workspace.delete("key1")
        assert was_deleted is True

        after = await workspace.read("key1")
        assert after is None

    @pytest.mark.asyncio
    async def test_delete_nonexistent(self):
        """Deleting an unknown key reports False."""
        workspace = SharedWorkspace()
        was_deleted = await workspace.delete("nonexistent")
        assert was_deleted is False

    @pytest.mark.asyncio
    async def test_lock_and_unlock(self):
        """A held lock blocks other agents until its owner releases it."""
        workspace = SharedWorkspace()

        first_attempt = await workspace.lock("resource1", agent_id="agent_a")
        assert first_attempt is True

        # A different agent cannot take a lock that is already held.
        second_attempt = await workspace.lock("resource1", agent_id="agent_b")
        assert second_attempt is False

        # The owner can release it.
        released = await workspace.unlock("resource1", agent_id="agent_a")
        assert released is True

        # Once released, another agent can acquire it.
        third_attempt = await workspace.lock("resource1", agent_id="agent_b")
        assert third_attempt is True

    @pytest.mark.asyncio
    async def test_unlock_by_non_owner(self):
        """Only the lock owner may release the lock."""
        workspace = SharedWorkspace()
        await workspace.lock("resource1", agent_id="agent_a")
        released = await workspace.unlock("resource1", agent_id="agent_b")
        assert released is False

    @pytest.mark.asyncio
    async def test_list_keys(self):
        """list_keys returns every written key."""
        workspace = SharedWorkspace()
        await workspace.write("key1", "v1", agent_id="a")
        await workspace.write("key2", "v2", agent_id="a")
        listed = await workspace.list_keys()
        assert set(listed) == {"key1", "key2"}
# --- Orchestrator Tests ---
class MockAgent:
    """Test double for an agent: returns canned output or raises on demand."""

    def __init__(self, name: str, output_data: dict | None = None, should_fail: bool = False):
        """
        Args:
            name: Agent name, echoed in results and error messages.
            output_data: Output to return; a default payload mentioning the
                agent name is used when this is falsy.
            should_fail: When True, ``execute`` raises RuntimeError.
        """
        self.name = name
        self.agent_type = "mock"
        self.version = "1.0.0"
        self._output_data = output_data or {"result": f"output from {name}"}
        self._should_fail = should_fail

    async def execute(self, task: TaskMessage):
        """Return a COMPLETED TaskResult for ``task``, or raise if configured to fail."""
        if self._should_fail:
            raise RuntimeError(f"Agent {self.name} failed")

        from agentkit.core.protocol import TaskResult

        # Start and completion share a single timestamp.
        timestamp = datetime.now(timezone.utc)
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.COMPLETED,
            output_data=self._output_data,
            error_message=None,
            started_at=timestamp,
            completed_at=timestamp,
        )
class MockAgentPool:
    """Test double for an agent pool backed by a plain name -> agent dict."""

    def __init__(self, agents: dict[str, MockAgent] | None = None):
        """Store ``agents``; an empty pool is used when it is falsy."""
        self._agents = agents or {}

    def get_agent(self, name: str) -> MockAgent | None:
        """Return the agent registered under ``name``, or None."""
        return self._agents.get(name)

    def list_agents(self) -> list[dict]:
        """Describe every registered agent as a dict of name, type and description."""
        descriptions = []
        for agent in self._agents.values():
            descriptions.append({
                "name": agent.name,
                "agent_type": agent.agent_type,
                "description": f"Mock agent {agent.name}",
            })
        return descriptions
class TestOrchestrator:
    """Unit tests for Orchestrator planning, execution and aggregation."""

    @staticmethod
    def _task(task_id, agent_name, task_type, input_data):
        """Build a priority-1 TaskMessage without a callback, created now (UTC)."""
        return TaskMessage(
            task_id=task_id,
            agent_name=agent_name,
            task_type=task_type,
            priority=1,
            input_data=input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    @staticmethod
    def _subtask(task_id, parent, agent, task_type, depends_on):
        """Build a SubTask with empty input data."""
        return SubTask(
            task_id=task_id,
            parent_task_id=parent,
            assigned_agent=agent,
            task_type=task_type,
            input_data={},
            depends_on=depends_on,
        )

    @pytest.mark.asyncio
    async def test_single_subtask_execution(self):
        """Single agent should execute task directly"""
        worker = MockAgent("worker1", {"analysis": "result"})
        orchestrator = Orchestrator(agent_pool=MockAgentPool({"worker1": worker}))
        message = self._task("t1", "worker1", "analyze", {"query": "test"})

        outcome = await orchestrator.execute(message)

        assert outcome.status == TaskStatus.COMPLETED
        assert "outputs" in outcome.aggregated_result

    @pytest.mark.asyncio
    async def test_parallel_groups_building(self):
        """Parallel groups should be built from dependency graph"""
        orchestrator = Orchestrator(agent_pool=MockAgentPool())
        subtasks = [
            self._subtask("t0", "p1", "a1", "type1", []),
            self._subtask("t1", "p1", "a2", "type2", []),
            self._subtask("t2", "p1", "a3", "type3", ["t0"]),
        ]

        groups = orchestrator._build_parallel_groups(subtasks)

        assert len(groups) == 2
        first_group, second_group = groups
        # t0 and t1 have no dependencies, so they share the first group.
        assert set(first_group) == {"t0", "t1"}
        # t2 waits for t0 and therefore lands in the second group.
        assert second_group == ["t2"]

    @pytest.mark.asyncio
    async def test_sequential_dependency(self):
        """Tasks with sequential dependencies should execute in order"""
        agents = {
            "a1": MockAgent("a1", {"step": 1}),
            "a2": MockAgent("a2", {"step": 2}),
        }
        orchestrator = Orchestrator(agent_pool=MockAgentPool(agents))
        # Hand-built plan: t1 may only run after t0.
        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                self._subtask("t0", "parent", "a1", "step1", []),
                self._subtask("t1", "parent", "a2", "step2", ["t0"]),
            ],
            parallel_groups=[["t0"], ["t1"]],
        )
        message = self._task("parent", "orchestrator", "pipeline", {"query": "test"})

        results = await orchestrator._execute_plan(plan, message)

        assert results["t0"]["status"] == "completed"
        assert results["t1"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_agent_not_found(self):
        """Missing agent should result in failed subtask"""
        orchestrator = Orchestrator(agent_pool=MockAgentPool({}))
        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                self._subtask("t0", "parent", "missing_agent", "test", []),
            ],
            parallel_groups=[["t0"]],
        )
        message = self._task("parent", "orchestrator", "test", {})

        results = await orchestrator._execute_plan(plan, message)

        assert results["t0"]["status"] == "failed"
        assert "not found" in results["t0"]["error"]

    @pytest.mark.asyncio
    async def test_dependency_result_injection(self):
        """Subtask should receive dependency results in input"""
        orchestrator = Orchestrator(agent_pool=MockAgentPool())
        dependant = SubTask(
            task_id="t1",
            parent_task_id="p1",
            assigned_agent="a1",
            task_type="test",
            input_data={"query": "test"},
            depends_on=["t0"],
        )
        finished = {
            "t0": {"status": "completed", "output": {"step1_result": "data"}},
        }

        enriched = orchestrator._inject_dependency_results(dependant, finished)

        assert "dependency_results" in enriched
        assert "t0" in enriched["dependency_results"]

    @pytest.mark.asyncio
    async def test_aggregation_with_errors(self):
        """Aggregation should include errors for failed subtasks"""
        orchestrator = Orchestrator(agent_pool=MockAgentPool())
        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                self._subtask("t0", "parent", "a1", "test", []),
            ],
            parallel_groups=[["t0"]],
        )
        failed_results = {
            "t0": {"status": "failed", "error": "Agent failed"},
        }
        message = self._task("parent", "orchestrator", "test", {})

        aggregated = await orchestrator._aggregate_results(plan, failed_results, message)

        assert "errors" in aggregated
        assert aggregated["partial_success"] is True
class TestAgentRole:
    """Checks the string values of the AgentRole enum."""

    def test_role_values(self):
        """Every role maps to its expected lowercase string value."""
        expected_values = [
            (AgentRole.ORCHESTRATOR, "orchestrator"),
            (AgentRole.WORKER, "worker"),
            (AgentRole.REVIEWER, "reviewer"),
        ]
        for role, expected in expected_values:
            assert role.value == expected
class TestSubTask:
    """Checks the defaults of the SubTask dataclass."""

    def test_default_values(self):
        """A SubTask built from required fields only is pending, with no result or dependencies."""
        required_fields = {
            "task_id": "t1",
            "parent_task_id": "p1",
            "assigned_agent": "a1",
            "task_type": "test",
            "input_data": {},
        }
        subtask = SubTask(**required_fields)
        assert subtask.status == SubTaskStatus.PENDING
        assert subtask.result is None
        assert subtask.depends_on == []