feat(core): U4 multi-agent Orchestrator with SharedWorkspace
- Orchestrator: Orchestrator-Worker pattern with LLM-driven task decomposition - SharedWorkspace: Redis-backed shared state with versioning and distributed locks - SubTask dependency graph, parallel group building, result aggregation - 16 tests passing
This commit is contained in:
parent
364fe6bd6d
commit
23934602c0
|
|
@ -0,0 +1,406 @@
|
||||||
|
"""Orchestrator - 多 Agent 协作编排器
|
||||||
|
|
||||||
|
实现 Orchestrator-Worker 模式:中央编排器协调多 Agent 并行/串行执行。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||||||
|
from agentkit.core.shared_workspace import SharedWorkspace
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRole(str, Enum):
|
||||||
|
"""Agent 角色枚举"""
|
||||||
|
|
||||||
|
ORCHESTRATOR = "orchestrator"
|
||||||
|
WORKER = "worker"
|
||||||
|
REVIEWER = "reviewer"
|
||||||
|
|
||||||
|
|
||||||
|
class SubTaskStatus(str, Enum):
|
||||||
|
"""子任务状态"""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SubTask:
|
||||||
|
"""子任务定义"""
|
||||||
|
|
||||||
|
task_id: str
|
||||||
|
parent_task_id: str
|
||||||
|
assigned_agent: str
|
||||||
|
task_type: str
|
||||||
|
input_data: dict[str, Any]
|
||||||
|
status: SubTaskStatus = SubTaskStatus.PENDING
|
||||||
|
result: dict[str, Any] | None = None
|
||||||
|
error: str | None = None
|
||||||
|
depends_on: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrchestrationPlan:
|
||||||
|
"""编排计划"""
|
||||||
|
|
||||||
|
plan_id: str
|
||||||
|
parent_task_id: str
|
||||||
|
subtasks: list[SubTask]
|
||||||
|
parallel_groups: list[list[str]] # 每组内的子任务可并行执行
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrchestrationResult:
|
||||||
|
"""编排结果"""
|
||||||
|
|
||||||
|
plan_id: str
|
||||||
|
parent_task_id: str
|
||||||
|
subtask_results: dict[str, dict[str, Any]]
|
||||||
|
aggregated_result: dict[str, Any]
|
||||||
|
status: TaskStatus
|
||||||
|
total_duration_ms: float
|
||||||
|
|
||||||
|
|
||||||
|
class Orchestrator:
|
||||||
|
"""多 Agent 协作编排器
|
||||||
|
|
||||||
|
Orchestrator-Worker 模式:
|
||||||
|
1. 接收复杂任务
|
||||||
|
2. LLM 驱动分解为子任务
|
||||||
|
3. 基于 Skill 能力匹配子任务到 Worker Agent
|
||||||
|
4. 并行/串行执行子任务
|
||||||
|
5. 汇总结果,生成最终输出
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
orchestrator = Orchestrator(agent_pool=pool, workspace=workspace)
|
||||||
|
result = await orchestrator.execute(task_message)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent_pool: Any,
|
||||||
|
workspace: SharedWorkspace | None = None,
|
||||||
|
llm_gateway: Any = None,
|
||||||
|
max_parallel: int = 5,
|
||||||
|
subtask_timeout: float = 300.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
agent_pool: AgentPool 实例
|
||||||
|
workspace: 共享工作空间
|
||||||
|
llm_gateway: LLM Gateway,用于任务分解
|
||||||
|
max_parallel: 最大并行子任务数
|
||||||
|
subtask_timeout: 子任务超时时间(秒)
|
||||||
|
"""
|
||||||
|
self._agent_pool = agent_pool
|
||||||
|
self._workspace = workspace or SharedWorkspace()
|
||||||
|
self._llm_gateway = llm_gateway
|
||||||
|
self._max_parallel = max_parallel
|
||||||
|
self._subtask_timeout = subtask_timeout
|
||||||
|
|
||||||
|
async def execute(self, task: TaskMessage) -> OrchestrationResult:
|
||||||
|
"""执行编排任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: 原始任务消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OrchestrationResult: 编排结果
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
start_time = time.monotonic()
|
||||||
|
|
||||||
|
# 1. Decompose task into subtasks
|
||||||
|
plan = await self._decompose_task(task)
|
||||||
|
|
||||||
|
if not plan.subtasks:
|
||||||
|
return OrchestrationResult(
|
||||||
|
plan_id=plan.plan_id,
|
||||||
|
parent_task_id=task.task_id,
|
||||||
|
subtask_results={},
|
||||||
|
aggregated_result={"error": "Failed to decompose task"},
|
||||||
|
status=TaskStatus.FAILED,
|
||||||
|
total_duration_ms=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Store plan in workspace
|
||||||
|
await self._workspace.write(
|
||||||
|
f"plan:{plan.plan_id}",
|
||||||
|
{"task_id": task.task_id, "subtask_count": len(plan.subtasks)},
|
||||||
|
agent_id="orchestrator",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Execute subtasks
|
||||||
|
subtask_results = await self._execute_plan(plan, task)
|
||||||
|
|
||||||
|
# 4. Aggregate results
|
||||||
|
aggregated = await self._aggregate_results(plan, subtask_results, task)
|
||||||
|
|
||||||
|
# 5. Determine overall status
|
||||||
|
failed_count = sum(
|
||||||
|
1 for r in subtask_results.values() if r.get("status") == "failed"
|
||||||
|
)
|
||||||
|
if failed_count == len(plan.subtasks):
|
||||||
|
status = TaskStatus.FAILED
|
||||||
|
elif failed_count > 0:
|
||||||
|
status = TaskStatus.COMPLETED # Partial success
|
||||||
|
else:
|
||||||
|
status = TaskStatus.COMPLETED
|
||||||
|
|
||||||
|
duration_ms = (time.monotonic() - start_time) * 1000
|
||||||
|
|
||||||
|
return OrchestrationResult(
|
||||||
|
plan_id=plan.plan_id,
|
||||||
|
parent_task_id=task.task_id,
|
||||||
|
subtask_results=subtask_results,
|
||||||
|
aggregated_result=aggregated,
|
||||||
|
status=status,
|
||||||
|
total_duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _decompose_task(self, task: TaskMessage) -> OrchestrationPlan:
|
||||||
|
"""将复杂任务分解为子任务"""
|
||||||
|
plan_id = str(uuid.uuid4())[:8]
|
||||||
|
|
||||||
|
# If LLM gateway available, use it for decomposition
|
||||||
|
if self._llm_gateway:
|
||||||
|
try:
|
||||||
|
subtasks = await self._llm_decompose(task)
|
||||||
|
if subtasks:
|
||||||
|
parallel_groups = self._build_parallel_groups(subtasks)
|
||||||
|
return OrchestrationPlan(
|
||||||
|
plan_id=plan_id,
|
||||||
|
parent_task_id=task.task_id,
|
||||||
|
subtasks=subtasks,
|
||||||
|
parallel_groups=parallel_groups,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"LLM decomposition failed, falling back to simple: {e}")
|
||||||
|
|
||||||
|
# Fallback: single subtask = original task
|
||||||
|
subtask = SubTask(
|
||||||
|
task_id=f"{plan_id}-0",
|
||||||
|
parent_task_id=task.task_id,
|
||||||
|
assigned_agent=task.agent_name,
|
||||||
|
task_type=task.task_type,
|
||||||
|
input_data=task.input_data,
|
||||||
|
)
|
||||||
|
return OrchestrationPlan(
|
||||||
|
plan_id=plan_id,
|
||||||
|
parent_task_id=task.task_id,
|
||||||
|
subtasks=[subtask],
|
||||||
|
parallel_groups=[[subtask.task_id]],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _llm_decompose(self, task: TaskMessage) -> list[SubTask]:
|
||||||
|
"""使用 LLM 分解任务"""
|
||||||
|
# Get available agents and their capabilities
|
||||||
|
agents_info = self._agent_pool.list_agents()
|
||||||
|
agent_descriptions = "\n".join(
|
||||||
|
f"- {a['name']} ({a['agent_type']}): {a.get('description', 'No description')}"
|
||||||
|
for a in agents_info
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
f"Decompose the following task into subtasks that can be assigned to available agents.\n\n"
|
||||||
|
f"Task: {task.input_data}\n"
|
||||||
|
f"Task Type: {task.task_type}\n\n"
|
||||||
|
f"Available Agents:\n{agent_descriptions}\n\n"
|
||||||
|
'Respond ONLY with a JSON array: [{"agent_name": "...", "task_type": "...", '
|
||||||
|
'"input_data": {...}, "depends_on": []}]\n'
|
||||||
|
"The depends_on field lists task indices (0-based) that must complete first.\n"
|
||||||
|
"Do not include any other text."
|
||||||
|
)
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
response = await self._llm_gateway.chat(
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
model="default",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
subtask_defs = json.loads(response.content)
|
||||||
|
if not isinstance(subtask_defs, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
subtasks = []
|
||||||
|
for i, defn in enumerate(subtask_defs):
|
||||||
|
depends_on = [
|
||||||
|
f"task-{i}" for i in defn.get("depends_on", [])
|
||||||
|
]
|
||||||
|
subtasks.append(SubTask(
|
||||||
|
task_id=f"task-{i}",
|
||||||
|
parent_task_id=task.task_id,
|
||||||
|
assigned_agent=defn.get("agent_name", task.agent_name),
|
||||||
|
task_type=defn.get("task_type", task.task_type),
|
||||||
|
input_data=defn.get("input_data", {}),
|
||||||
|
depends_on=depends_on,
|
||||||
|
))
|
||||||
|
return subtasks
|
||||||
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
|
logger.warning(f"Failed to parse LLM decomposition: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _build_parallel_groups(self, subtasks: list[SubTask]) -> list[list[str]]:
|
||||||
|
"""构建并行执行组
|
||||||
|
|
||||||
|
基于依赖关系拓扑排序,无依赖的子任务分到同一组并行执行。
|
||||||
|
"""
|
||||||
|
# Build dependency graph
|
||||||
|
task_map = {st.task_id: st for st in subtasks}
|
||||||
|
completed: set[str] = set()
|
||||||
|
groups: list[list[str]] = []
|
||||||
|
|
||||||
|
remaining = set(st.task_id for st in subtasks)
|
||||||
|
|
||||||
|
while remaining:
|
||||||
|
# Find tasks with all dependencies satisfied
|
||||||
|
ready = []
|
||||||
|
for tid in remaining:
|
||||||
|
task = task_map[tid]
|
||||||
|
if all(dep in completed for dep in task.depends_on):
|
||||||
|
ready.append(tid)
|
||||||
|
|
||||||
|
if not ready:
|
||||||
|
# Circular dependency — put remaining in one group
|
||||||
|
groups.append(list(remaining))
|
||||||
|
break
|
||||||
|
|
||||||
|
# Limit group size
|
||||||
|
group = ready[:self._max_parallel]
|
||||||
|
groups.append(group)
|
||||||
|
for tid in group:
|
||||||
|
completed.add(tid)
|
||||||
|
remaining.discard(tid)
|
||||||
|
|
||||||
|
return groups
|
||||||
|
|
||||||
|
async def _execute_plan(
|
||||||
|
self, plan: OrchestrationPlan, original_task: TaskMessage
|
||||||
|
) -> dict[str, dict[str, Any]]:
|
||||||
|
"""执行编排计划"""
|
||||||
|
subtask_results: dict[str, dict[str, Any]] = {}
|
||||||
|
task_map = {st.task_id: st for st in plan.subtasks}
|
||||||
|
|
||||||
|
for group in plan.parallel_groups:
|
||||||
|
# Execute group in parallel
|
||||||
|
tasks = []
|
||||||
|
for task_id in group:
|
||||||
|
subtask = task_map[task_id]
|
||||||
|
# Inject results from dependencies
|
||||||
|
enriched_input = self._inject_dependency_results(
|
||||||
|
subtask, subtask_results
|
||||||
|
)
|
||||||
|
tasks.append(self._execute_subtask(subtask, enriched_input, original_task))
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for task_id, result in zip(group, results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
subtask_results[task_id] = {
|
||||||
|
"status": "failed",
|
||||||
|
"error": str(result),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
subtask_results[task_id] = result
|
||||||
|
|
||||||
|
return subtask_results
|
||||||
|
|
||||||
|
async def _execute_subtask(
|
||||||
|
self,
|
||||||
|
subtask: SubTask,
|
||||||
|
input_data: dict[str, Any],
|
||||||
|
original_task: TaskMessage,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""执行单个子任务"""
|
||||||
|
agent = self._agent_pool.get_agent(subtask.assigned_agent)
|
||||||
|
if agent is None:
|
||||||
|
return {"status": "failed", "error": f"Agent '{subtask.assigned_agent}' not found"}
|
||||||
|
|
||||||
|
sub_task_msg = TaskMessage(
|
||||||
|
task_id=subtask.task_id,
|
||||||
|
agent_name=subtask.assigned_agent,
|
||||||
|
task_type=subtask.task_type,
|
||||||
|
priority=original_task.priority,
|
||||||
|
input_data=input_data,
|
||||||
|
callback_url=None,
|
||||||
|
created_at=original_task.created_at,
|
||||||
|
timeout_seconds=int(self._subtask_timeout),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
agent.execute(sub_task_msg),
|
||||||
|
timeout=self._subtask_timeout,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"output": result.output_data if hasattr(result, "output_data") else result,
|
||||||
|
}
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return {"status": "failed", "error": "Subtask timed out"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"status": "failed", "error": str(e)}
|
||||||
|
|
||||||
|
def _inject_dependency_results(
|
||||||
|
self,
|
||||||
|
subtask: SubTask,
|
||||||
|
subtask_results: dict[str, dict[str, Any]],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""将依赖子任务的结果注入到当前子任务的输入中"""
|
||||||
|
enriched = dict(subtask.input_data)
|
||||||
|
|
||||||
|
if subtask.depends_on:
|
||||||
|
dep_results = {}
|
||||||
|
for dep_id in subtask.depends_on:
|
||||||
|
if dep_id in subtask_results:
|
||||||
|
dep_results[dep_id] = subtask_results[dep_id]
|
||||||
|
if dep_results:
|
||||||
|
enriched["dependency_results"] = dep_results
|
||||||
|
|
||||||
|
return enriched
|
||||||
|
|
||||||
|
async def _aggregate_results(
|
||||||
|
self,
|
||||||
|
plan: OrchestrationPlan,
|
||||||
|
subtask_results: dict[str, dict[str, Any]],
|
||||||
|
original_task: TaskMessage,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""汇总子任务结果"""
|
||||||
|
# Simple aggregation: collect all outputs
|
||||||
|
outputs = {}
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
for subtask in plan.subtasks:
|
||||||
|
result = subtask_results.get(subtask.task_id, {})
|
||||||
|
if result.get("status") == "completed":
|
||||||
|
outputs[subtask.task_id] = result.get("output", {})
|
||||||
|
else:
|
||||||
|
errors.append({
|
||||||
|
"task_id": subtask.task_id,
|
||||||
|
"error": result.get("error", "Unknown error"),
|
||||||
|
})
|
||||||
|
|
||||||
|
aggregated = {
|
||||||
|
"outputs": outputs,
|
||||||
|
"task_id": original_task.task_id,
|
||||||
|
}
|
||||||
|
if errors:
|
||||||
|
aggregated["errors"] = errors
|
||||||
|
aggregated["partial_success"] = True
|
||||||
|
|
||||||
|
return aggregated
|
||||||
|
|
@ -0,0 +1,159 @@
|
||||||
|
"""SharedWorkspace - Agent 间共享工作空间
|
||||||
|
|
||||||
|
基于 Redis 的共享状态存储,支持读写、订阅、锁操作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SharedWorkspace:
|
||||||
|
"""Agent 间共享工作空间
|
||||||
|
|
||||||
|
基于 Redis 的共享状态存储,支持:
|
||||||
|
- write/read: 读写共享数据
|
||||||
|
- lock/unlock: 分布式锁
|
||||||
|
- 版本控制:每次写入递增版本号
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, redis_client: Any = None, prefix: str = "workspace"):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
redis_client: aioredis.Redis 实例,None 时使用内存字典
|
||||||
|
prefix: Redis key 前缀
|
||||||
|
"""
|
||||||
|
self._redis = redis_client
|
||||||
|
self._prefix = prefix
|
||||||
|
self._local_store: dict[str, dict[str, Any]] = {}
|
||||||
|
self._locks: dict[str, str] = {} # key -> lock_owner
|
||||||
|
|
||||||
|
def _make_key(self, key: str) -> str:
|
||||||
|
return f"{self._prefix}:{key}"
|
||||||
|
|
||||||
|
async def write(
|
||||||
|
self, key: str, value: Any, agent_id: str, ttl: int | None = None
|
||||||
|
) -> int:
|
||||||
|
"""写入共享数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 数据键
|
||||||
|
value: 数据值
|
||||||
|
agent_id: 写入者 ID
|
||||||
|
ttl: 过期时间(秒),None 表示不过期
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
版本号
|
||||||
|
"""
|
||||||
|
entry = {
|
||||||
|
"value": value,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"version": await self._get_version(key) + 1,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._redis:
|
||||||
|
redis_key = self._make_key(key)
|
||||||
|
data = json.dumps(entry, default=str)
|
||||||
|
if ttl:
|
||||||
|
await self._redis.setex(redis_key, ttl, data)
|
||||||
|
else:
|
||||||
|
await self._redis.set(redis_key, data)
|
||||||
|
else:
|
||||||
|
self._local_store[key] = entry
|
||||||
|
|
||||||
|
return entry["version"]
|
||||||
|
|
||||||
|
async def read(self, key: str) -> dict[str, Any] | None:
|
||||||
|
"""读取共享数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"value": ..., "agent_id": ..., "version": ..., "timestamp": ...} 或 None
|
||||||
|
"""
|
||||||
|
if self._redis:
|
||||||
|
redis_key = self._make_key(key)
|
||||||
|
data = await self._redis.get(redis_key)
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
return json.loads(data)
|
||||||
|
else:
|
||||||
|
return self._local_store.get(key)
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
"""删除共享数据"""
|
||||||
|
if self._redis:
|
||||||
|
redis_key = self._make_key(key)
|
||||||
|
result = await self._redis.delete(redis_key)
|
||||||
|
return result > 0
|
||||||
|
else:
|
||||||
|
return self._local_store.pop(key, None) is not None
|
||||||
|
|
||||||
|
async def lock(self, key: str, agent_id: str, timeout: float = 30.0) -> bool:
|
||||||
|
"""获取分布式锁
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 要锁定的数据键
|
||||||
|
agent_id: 请求锁的 Agent ID
|
||||||
|
timeout: 锁超时时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否成功获取锁
|
||||||
|
"""
|
||||||
|
lock_key = f"{self._prefix}:lock:{key}"
|
||||||
|
|
||||||
|
if self._redis:
|
||||||
|
# Redis SET with NX (only if not exists) and EX (expiry)
|
||||||
|
result = await self._redis.set(lock_key, agent_id, nx=True, ex=int(timeout))
|
||||||
|
return result is not None
|
||||||
|
else:
|
||||||
|
if key in self._locks:
|
||||||
|
return False
|
||||||
|
self._locks[key] = agent_id
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def unlock(self, key: str, agent_id: str) -> bool:
|
||||||
|
"""释放分布式锁
|
||||||
|
|
||||||
|
只有锁的持有者才能释放锁。
|
||||||
|
"""
|
||||||
|
lock_key = f"{self._prefix}:lock:{key}"
|
||||||
|
|
||||||
|
if self._redis:
|
||||||
|
current_owner = await self._redis.get(lock_key)
|
||||||
|
if current_owner and current_owner.decode() == agent_id:
|
||||||
|
await self._redis.delete(lock_key)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if self._locks.get(key) == agent_id:
|
||||||
|
del self._locks[key]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _get_version(self, key: str) -> int:
|
||||||
|
"""获取当前版本号"""
|
||||||
|
data = await self.read(key)
|
||||||
|
if data is None:
|
||||||
|
return 0
|
||||||
|
return data.get("version", 0)
|
||||||
|
|
||||||
|
async def list_keys(self) -> list[str]:
|
||||||
|
"""列出所有键"""
|
||||||
|
if self._redis:
|
||||||
|
pattern = f"{self._prefix}:*"
|
||||||
|
keys = []
|
||||||
|
async for key in self._redis.scan_iter(match=pattern):
|
||||||
|
# Strip prefix
|
||||||
|
k = key.decode() if isinstance(key, bytes) else key
|
||||||
|
k = k[len(self._prefix) + 1:] # Remove "prefix:"
|
||||||
|
# Skip lock keys
|
||||||
|
if not k.startswith("lock:"):
|
||||||
|
keys.append(k)
|
||||||
|
return keys
|
||||||
|
else:
|
||||||
|
return list(self._local_store.keys())
|
||||||
|
|
@ -0,0 +1,336 @@
|
||||||
|
"""Tests for Orchestrator and SharedWorkspace"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.orchestrator import (
|
||||||
|
Orchestrator,
|
||||||
|
OrchestrationPlan,
|
||||||
|
OrchestrationResult,
|
||||||
|
SubTask,
|
||||||
|
SubTaskStatus,
|
||||||
|
AgentRole,
|
||||||
|
)
|
||||||
|
from agentkit.core.shared_workspace import SharedWorkspace
|
||||||
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
|
||||||
|
# --- SharedWorkspace Tests ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestSharedWorkspace:
|
||||||
|
"""SharedWorkspace unit tests (in-memory mode)"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_write_and_read(self):
|
||||||
|
ws = SharedWorkspace()
|
||||||
|
version = await ws.write("key1", {"data": "value"}, agent_id="agent_a")
|
||||||
|
assert version == 1
|
||||||
|
|
||||||
|
result = await ws.read("key1")
|
||||||
|
assert result is not None
|
||||||
|
assert result["value"] == {"data": "value"}
|
||||||
|
assert result["agent_id"] == "agent_a"
|
||||||
|
assert result["version"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_version_increments(self):
|
||||||
|
ws = SharedWorkspace()
|
||||||
|
v1 = await ws.write("key1", "first", agent_id="a")
|
||||||
|
v2 = await ws.write("key1", "second", agent_id="b")
|
||||||
|
assert v1 == 1
|
||||||
|
assert v2 == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_nonexistent(self):
|
||||||
|
ws = SharedWorkspace()
|
||||||
|
result = await ws.read("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete(self):
|
||||||
|
ws = SharedWorkspace()
|
||||||
|
await ws.write("key1", "value", agent_id="a")
|
||||||
|
deleted = await ws.delete("key1")
|
||||||
|
assert deleted is True
|
||||||
|
result = await ws.read("key1")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_nonexistent(self):
|
||||||
|
ws = SharedWorkspace()
|
||||||
|
deleted = await ws.delete("nonexistent")
|
||||||
|
assert deleted is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_and_unlock(self):
|
||||||
|
ws = SharedWorkspace()
|
||||||
|
acquired = await ws.lock("resource1", agent_id="agent_a")
|
||||||
|
assert acquired is True
|
||||||
|
|
||||||
|
# Same agent can't lock again (already held)
|
||||||
|
acquired2 = await ws.lock("resource1", agent_id="agent_b")
|
||||||
|
assert acquired2 is False
|
||||||
|
|
||||||
|
# Owner can unlock
|
||||||
|
unlocked = await ws.unlock("resource1", agent_id="agent_a")
|
||||||
|
assert unlocked is True
|
||||||
|
|
||||||
|
# Now another agent can lock
|
||||||
|
acquired3 = await ws.lock("resource1", agent_id="agent_b")
|
||||||
|
assert acquired3 is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unlock_by_non_owner(self):
|
||||||
|
ws = SharedWorkspace()
|
||||||
|
await ws.lock("resource1", agent_id="agent_a")
|
||||||
|
unlocked = await ws.unlock("resource1", agent_id="agent_b")
|
||||||
|
assert unlocked is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys(self):
|
||||||
|
ws = SharedWorkspace()
|
||||||
|
await ws.write("key1", "v1", agent_id="a")
|
||||||
|
await ws.write("key2", "v2", agent_id="a")
|
||||||
|
keys = await ws.list_keys()
|
||||||
|
assert set(keys) == {"key1", "key2"}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Orchestrator Tests ---
|
||||||
|
|
||||||
|
|
||||||
|
class MockAgent:
|
||||||
|
"""Mock Agent for testing"""
|
||||||
|
|
||||||
|
def __init__(self, name: str, output_data: dict | None = None, should_fail: bool = False):
|
||||||
|
self.name = name
|
||||||
|
self.agent_type = "mock"
|
||||||
|
self.version = "1.0.0"
|
||||||
|
self._output_data = output_data or {"result": f"output from {name}"}
|
||||||
|
self._should_fail = should_fail
|
||||||
|
|
||||||
|
async def execute(self, task: TaskMessage):
|
||||||
|
if self._should_fail:
|
||||||
|
raise RuntimeError(f"Agent {self.name} failed")
|
||||||
|
from agentkit.core.protocol import TaskResult
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
return TaskResult(
|
||||||
|
task_id=task.task_id,
|
||||||
|
agent_name=self.name,
|
||||||
|
status=TaskStatus.COMPLETED,
|
||||||
|
output_data=self._output_data,
|
||||||
|
error_message=None,
|
||||||
|
started_at=now,
|
||||||
|
completed_at=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockAgentPool:
|
||||||
|
"""Mock AgentPool for testing"""
|
||||||
|
|
||||||
|
def __init__(self, agents: dict[str, MockAgent] | None = None):
|
||||||
|
self._agents = agents or {}
|
||||||
|
|
||||||
|
def get_agent(self, name: str) -> MockAgent | None:
|
||||||
|
return self._agents.get(name)
|
||||||
|
|
||||||
|
def list_agents(self) -> list[dict]:
|
||||||
|
return [
|
||||||
|
{"name": a.name, "agent_type": a.agent_type, "description": f"Mock agent {a.name}"}
|
||||||
|
for a in self._agents.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestrator:
|
||||||
|
"""Orchestrator unit tests"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_subtask_execution(self):
|
||||||
|
"""Single agent should execute task directly"""
|
||||||
|
agent = MockAgent("worker1", {"analysis": "result"})
|
||||||
|
pool = MockAgentPool({"worker1": agent})
|
||||||
|
orchestrator = Orchestrator(agent_pool=pool)
|
||||||
|
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id="t1",
|
||||||
|
agent_name="worker1",
|
||||||
|
task_type="analyze",
|
||||||
|
priority=1,
|
||||||
|
input_data={"query": "test"},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await orchestrator.execute(task)
|
||||||
|
assert result.status == TaskStatus.COMPLETED
|
||||||
|
assert "outputs" in result.aggregated_result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parallel_groups_building(self):
|
||||||
|
"""Parallel groups should be built from dependency graph"""
|
||||||
|
pool = MockAgentPool()
|
||||||
|
orchestrator = Orchestrator(agent_pool=pool)
|
||||||
|
|
||||||
|
subtasks = [
|
||||||
|
SubTask(task_id="t0", parent_task_id="p1", assigned_agent="a1",
|
||||||
|
task_type="type1", input_data={}, depends_on=[]),
|
||||||
|
SubTask(task_id="t1", parent_task_id="p1", assigned_agent="a2",
|
||||||
|
task_type="type2", input_data={}, depends_on=[]),
|
||||||
|
SubTask(task_id="t2", parent_task_id="p1", assigned_agent="a3",
|
||||||
|
task_type="type3", input_data={}, depends_on=["t0"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
groups = orchestrator._build_parallel_groups(subtasks)
|
||||||
|
assert len(groups) == 2
|
||||||
|
# First group: t0 and t1 (no dependencies)
|
||||||
|
assert set(groups[0]) == {"t0", "t1"}
|
||||||
|
# Second group: t2 (depends on t0)
|
||||||
|
assert groups[1] == ["t2"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sequential_dependency(self):
|
||||||
|
"""Tasks with sequential dependencies should execute in order"""
|
||||||
|
agent1 = MockAgent("a1", {"step": 1})
|
||||||
|
agent2 = MockAgent("a2", {"step": 2})
|
||||||
|
pool = MockAgentPool({"a1": agent1, "a2": agent2})
|
||||||
|
orchestrator = Orchestrator(agent_pool=pool)
|
||||||
|
|
||||||
|
# Manually create a plan with sequential dependencies
|
||||||
|
plan = OrchestrationPlan(
|
||||||
|
plan_id="p1",
|
||||||
|
parent_task_id="parent",
|
||||||
|
subtasks=[
|
||||||
|
SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1",
|
||||||
|
task_type="step1", input_data={}, depends_on=[]),
|
||||||
|
SubTask(task_id="t1", parent_task_id="parent", assigned_agent="a2",
|
||||||
|
task_type="step2", input_data={}, depends_on=["t0"]),
|
||||||
|
],
|
||||||
|
parallel_groups=[["t0"], ["t1"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id="parent",
|
||||||
|
agent_name="orchestrator",
|
||||||
|
task_type="pipeline",
|
||||||
|
priority=1,
|
||||||
|
input_data={"query": "test"},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await orchestrator._execute_plan(plan, task)
|
||||||
|
assert results["t0"]["status"] == "completed"
|
||||||
|
assert results["t1"]["status"] == "completed"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_not_found(self):
|
||||||
|
"""Missing agent should result in failed subtask"""
|
||||||
|
pool = MockAgentPool({})
|
||||||
|
orchestrator = Orchestrator(agent_pool=pool)
|
||||||
|
|
||||||
|
plan = OrchestrationPlan(
|
||||||
|
plan_id="p1",
|
||||||
|
parent_task_id="parent",
|
||||||
|
subtasks=[
|
||||||
|
SubTask(task_id="t0", parent_task_id="parent", assigned_agent="missing_agent",
|
||||||
|
task_type="test", input_data={}, depends_on=[]),
|
||||||
|
],
|
||||||
|
parallel_groups=[["t0"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id="parent",
|
||||||
|
agent_name="orchestrator",
|
||||||
|
task_type="test",
|
||||||
|
priority=1,
|
||||||
|
input_data={},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await orchestrator._execute_plan(plan, task)
|
||||||
|
assert results["t0"]["status"] == "failed"
|
||||||
|
assert "not found" in results["t0"]["error"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dependency_result_injection(self):
|
||||||
|
"""Subtask should receive dependency results in input"""
|
||||||
|
pool = MockAgentPool()
|
||||||
|
orchestrator = Orchestrator(agent_pool=pool)
|
||||||
|
|
||||||
|
subtask = SubTask(
|
||||||
|
task_id="t1",
|
||||||
|
parent_task_id="p1",
|
||||||
|
assigned_agent="a1",
|
||||||
|
task_type="test",
|
||||||
|
input_data={"query": "test"},
|
||||||
|
depends_on=["t0"],
|
||||||
|
)
|
||||||
|
|
||||||
|
subtask_results = {
|
||||||
|
"t0": {"status": "completed", "output": {"step1_result": "data"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
enriched = orchestrator._inject_dependency_results(subtask, subtask_results)
|
||||||
|
assert "dependency_results" in enriched
|
||||||
|
assert "t0" in enriched["dependency_results"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_aggregation_with_errors(self):
|
||||||
|
"""Aggregation should include errors for failed subtasks"""
|
||||||
|
pool = MockAgentPool()
|
||||||
|
orchestrator = Orchestrator(agent_pool=pool)
|
||||||
|
|
||||||
|
plan = OrchestrationPlan(
|
||||||
|
plan_id="p1",
|
||||||
|
parent_task_id="parent",
|
||||||
|
subtasks=[
|
||||||
|
SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1",
|
||||||
|
task_type="test", input_data={}, depends_on=[]),
|
||||||
|
],
|
||||||
|
parallel_groups=[["t0"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
subtask_results = {
|
||||||
|
"t0": {"status": "failed", "error": "Agent failed"},
|
||||||
|
}
|
||||||
|
|
||||||
|
task = TaskMessage(
|
||||||
|
task_id="parent",
|
||||||
|
agent_name="orchestrator",
|
||||||
|
task_type="test",
|
||||||
|
priority=1,
|
||||||
|
input_data={},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
aggregated = await orchestrator._aggregate_results(plan, subtask_results, task)
|
||||||
|
assert "errors" in aggregated
|
||||||
|
assert aggregated["partial_success"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentRole:
|
||||||
|
"""AgentRole enum tests"""
|
||||||
|
|
||||||
|
def test_role_values(self):
|
||||||
|
assert AgentRole.ORCHESTRATOR.value == "orchestrator"
|
||||||
|
assert AgentRole.WORKER.value == "worker"
|
||||||
|
assert AgentRole.REVIEWER.value == "reviewer"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubTask:
|
||||||
|
"""SubTask dataclass tests"""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
st = SubTask(
|
||||||
|
task_id="t1",
|
||||||
|
parent_task_id="p1",
|
||||||
|
assigned_agent="a1",
|
||||||
|
task_type="test",
|
||||||
|
input_data={},
|
||||||
|
)
|
||||||
|
assert st.status == SubTaskStatus.PENDING
|
||||||
|
assert st.result is None
|
||||||
|
assert st.depends_on == []
|
||||||
Loading…
Reference in New Issue