feat(core): add Orchestrator adaptive task decomposition (U5)

- execute_adaptive(): iterative execute→evaluate→re-decompose loop
- OrchestratorConfig: adaptive, max_iterations, quality_threshold
- _evaluate_quality(): LLM-based or rule-based quality scoring (0-1)
- _reexecute_failed(): preserves completed subtask results, retries
  failed ones with improvement feedback injected into input_data
- OrchestrationResult.metadata field for tracking iteration history
- 10 new tests, all passing
This commit is contained in:
chiguyong 2026-06-07 23:50:54 +08:00
parent 7054ac02b6
commit 88d8298871
2 changed files with 504 additions and 0 deletions

View File

@ -71,6 +71,16 @@ class OrchestrationResult:
aggregated_result: dict[str, Any] aggregated_result: dict[str, Any]
status: TaskStatus status: TaskStatus
total_duration_ms: float total_duration_ms: float
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class OrchestratorConfig:
"""Orchestrator 配置"""
adaptive: bool = False
max_iterations: int = 3
quality_threshold: float = 0.7
class Orchestrator: class Orchestrator:
@ -95,6 +105,7 @@ class Orchestrator:
llm_gateway: Any = None, llm_gateway: Any = None,
max_parallel: int = 5, max_parallel: int = 5,
subtask_timeout: float = 300.0, subtask_timeout: float = 300.0,
config: OrchestratorConfig | None = None,
): ):
""" """
Args: Args:
@ -103,12 +114,14 @@ class Orchestrator:
llm_gateway: LLM Gateway用于任务分解 llm_gateway: LLM Gateway用于任务分解
max_parallel: 最大并行子任务数 max_parallel: 最大并行子任务数
subtask_timeout: 子任务超时时间 subtask_timeout: 子任务超时时间
config: Orchestrator 配置包含自适应参数
""" """
self._agent_pool = agent_pool self._agent_pool = agent_pool
self._workspace = workspace or SharedWorkspace() self._workspace = workspace or SharedWorkspace()
self._llm_gateway = llm_gateway self._llm_gateway = llm_gateway
self._max_parallel = max_parallel self._max_parallel = max_parallel
self._subtask_timeout = subtask_timeout self._subtask_timeout = subtask_timeout
self._config = config or OrchestratorConfig()
async def execute(self, task: TaskMessage) -> OrchestrationResult: async def execute(self, task: TaskMessage) -> OrchestrationResult:
"""执行编排任务 """执行编排任务
@ -404,3 +417,258 @@ class Orchestrator:
aggregated["partial_success"] = True aggregated["partial_success"] = True
return aggregated return aggregated
async def execute_adaptive(
self, task: TaskMessage,
) -> OrchestrationResult:
"""自适应编排:执行→评估→再分解循环。
execute() 不同此方法在第一轮执行后评估子任务结果质量
如果评估不通过且未达 max_iterations则基于评估反馈重新分解
未达标的子任务保留已完成的子任务结果然后执行新分解的子任务
Args:
task: 原始任务消息
Returns:
OrchestrationResult: 编排结果metadata 中包含迭代历史
"""
import time as _time
start_time = _time.monotonic()
iteration_history: list[dict[str, Any]] = []
# First execution
result = await self.execute(task)
# If adaptive not enabled or already succeeded, return directly
if not self._config.adaptive or result.status == TaskStatus.COMPLETED:
# Check quality even on success
if self._config.adaptive and self._llm_gateway:
quality = await self._evaluate_quality(task, result)
if quality["score"] >= self._config.quality_threshold:
result.metadata["quality_score"] = quality["score"]
return result
return result
# Adaptive loop
current_result = result
for iteration in range(1, self._config.max_iterations + 1):
# Evaluate quality
quality = await self._evaluate_quality(task, current_result)
iteration_history.append({
"iteration": iteration,
"quality_score": quality["score"],
"feedback": quality.get("feedback", ""),
})
if quality["score"] >= self._config.quality_threshold:
logger.info(
f"Adaptive iteration {iteration}: quality "
f"{quality['score']:.2f} >= {self._config.quality_threshold}"
)
current_result.metadata["quality_score"] = quality["score"]
current_result.metadata["iterations"] = iteration_history
return current_result
logger.info(
f"Adaptive iteration {iteration}: quality "
f"{quality['score']:.2f} < {self._config.quality_threshold}, "
f"re-decomposing failed subtasks"
)
# Re-decompose failed subtasks
new_result = await self._reexecute_failed(
task, current_result, quality,
)
current_result = new_result
# Exhausted iterations
current_result.metadata["iterations"] = iteration_history
return current_result
async def _evaluate_quality(
self,
task: TaskMessage,
result: OrchestrationResult,
) -> dict[str, Any]:
"""评估子任务结果质量。
Returns:
Dict with "score" (0-1) and optional "feedback" string.
"""
# Rule-based evaluation when no LLM
if self._llm_gateway is None:
return self._rule_based_evaluate(result)
try:
return await self._llm_evaluate(task, result)
except Exception as e:
logger.warning(f"LLM evaluation failed, falling back to rule-based: {e}")
return self._rule_based_evaluate(result)
def _rule_based_evaluate(
self, result: OrchestrationResult,
) -> dict[str, Any]:
"""基于规则的质量评估:根据完成率打分。"""
total = len(result.subtask_results)
if total == 0:
return {"score": 0.0, "feedback": "No subtasks executed"}
completed = sum(
1 for r in result.subtask_results.values()
if r.get("status") == "completed"
)
score = completed / total
feedback = ""
if score < 1.0:
failed = [
tid for tid, r in result.subtask_results.items()
if r.get("status") != "completed"
]
feedback = f"Failed subtasks: {failed}"
return {"score": score, "feedback": feedback}
async def _llm_evaluate(
self,
task: TaskMessage,
result: OrchestrationResult,
) -> dict[str, Any]:
"""使用 LLM 评估子任务结果质量。"""
import json
subtask_summary = []
for tid, r in result.subtask_results.items():
subtask_summary.append({
"task_id": tid,
"status": r.get("status", "unknown"),
"output_preview": str(r.get("output", ""))[:200],
})
prompt = (
f"Evaluate the quality of the following orchestration result.\n\n"
f"Original task: {task.input_data}\n"
f"Subtask results:\n{json.dumps(subtask_summary, ensure_ascii=False)}\n\n"
f'Respond ONLY with JSON: {{"score": 0.0-1.0, "feedback": "..."}}\n'
f"Score 1.0 = perfect, 0.0 = completely failed."
)
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
model="default",
)
try:
text = response.content.strip()
if text.startswith("```"):
lines = text.split("\n")
text = "\n".join(lines[1:-1])
data = json.loads(text)
return {
"score": float(data.get("score", 0.0)),
"feedback": data.get("feedback", ""),
}
except (json.JSONDecodeError, ValueError) as e:
logger.warning(f"Failed to parse LLM evaluation: {e}")
return self._rule_based_evaluate(result)
async def _reexecute_failed(
self,
task: TaskMessage,
previous_result: OrchestrationResult,
quality: dict[str, Any],
) -> OrchestrationResult:
"""重新执行失败的子任务,保留已完成的结果。"""
import time as _time
start_time = _time.monotonic()
# Identify failed subtasks
failed_task_ids = [
tid for tid, r in previous_result.subtask_results.items()
if r.get("status") != "completed"
]
if not failed_task_ids:
return previous_result
# Create new subtasks for failed ones, incorporating feedback
new_subtasks = []
for tid in failed_task_ids:
old_result = previous_result.subtask_results[tid]
new_subtasks.append(SubTask(
task_id=f"retry-{tid}",
parent_task_id=task.task_id,
assigned_agent=task.agent_name,
task_type=task.task_type,
input_data={
**task.input_data,
"previous_error": old_result.get("error", ""),
"improvement_feedback": quality.get("feedback", ""),
},
))
# Build a mini-plan for the retry subtasks
plan = OrchestrationPlan(
plan_id=f"retry-{previous_result.plan_id}",
parent_task_id=task.task_id,
subtasks=new_subtasks,
parallel_groups=[[st.task_id for st in new_subtasks]],
)
# Execute retry subtasks
retry_results = await self._execute_plan(plan, task)
# Merge: keep completed results, replace failed with retry results
merged_results = {}
for tid, r in previous_result.subtask_results.items():
if r.get("status") == "completed":
merged_results[tid] = r
for tid, r in retry_results.items():
# Map retry task IDs back to original
original_tid = tid.replace("retry-", "", 1)
merged_results[original_tid] = r
# Re-aggregate
all_subtasks = []
for tid, r in merged_results.items():
all_subtasks.append(SubTask(
task_id=tid,
parent_task_id=task.task_id,
assigned_agent=task.agent_name,
task_type=task.task_type,
input_data=task.input_data,
status=SubTaskStatus.COMPLETED if r.get("status") == "completed" else SubTaskStatus.FAILED,
result=r.get("output"),
))
retry_plan = OrchestrationPlan(
plan_id=plan.plan_id,
parent_task_id=task.task_id,
subtasks=all_subtasks,
parallel_groups=[],
)
aggregated = await self._aggregate_results(retry_plan, merged_results, task)
failed_count = sum(
1 for r in merged_results.values() if r.get("status") != "completed"
)
if failed_count == len(merged_results):
status = TaskStatus.FAILED
elif failed_count > 0:
status = TaskStatus.COMPLETED
else:
status = TaskStatus.COMPLETED
duration_ms = (_time.monotonic() - start_time) * 1000
return OrchestrationResult(
plan_id=plan.plan_id,
parent_task_id=task.task_id,
subtask_results=merged_results,
aggregated_result=aggregated,
status=status,
total_duration_ms=duration_ms,
)

View File

@ -0,0 +1,236 @@
"""Tests for Orchestrator adaptive task decomposition (U5)."""
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.orchestrator import (
Orchestrator,
OrchestratorConfig,
OrchestrationResult,
SubTaskStatus,
)
from agentkit.core.protocol import TaskMessage, TaskStatus
# ── Test Helpers ──────────────────────────────────────────
def _make_task(**overrides) -> TaskMessage:
defaults = {
"task_id": "task-001",
"agent_name": "test_agent",
"task_type": "analyze",
"priority": 0,
"input_data": {"query": "test"},
"callback_url": None,
"created_at": datetime.now(timezone.utc),
"timeout_seconds": 60,
}
defaults.update(overrides)
return TaskMessage(**defaults)
def _make_mock_pool(agents: dict[str, AsyncMock] | None = None):
"""Create a mock AgentPool."""
pool = MagicMock()
if agents:
pool.get_agent = lambda name: agents.get(name)
pool.list_agents = lambda: [
{"name": name, "agent_type": "worker", "description": f"Agent {name}"}
for name in agents
]
else:
# Default: single agent that succeeds
mock_agent = AsyncMock()
mock_result = MagicMock()
mock_result.output_data = {"result": "done"}
mock_agent.execute = AsyncMock(return_value=mock_result)
pool.get_agent = lambda name: mock_agent
pool.list_agents = lambda: [
{"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
]
return pool
# ── OrchestratorConfig Tests ──────────────────────────────
class TestOrchestratorConfig:
def test_default_values(self):
config = OrchestratorConfig()
assert config.adaptive is False
assert config.max_iterations == 3
assert config.quality_threshold == 0.7
def test_custom_values(self):
config = OrchestratorConfig(
adaptive=True,
max_iterations=5,
quality_threshold=0.9,
)
assert config.adaptive is True
assert config.max_iterations == 5
assert config.quality_threshold == 0.9
# ── Adaptive Execution Tests ─────────────────────────────
class TestOrchestratorAdaptive:
@pytest.mark.asyncio
async def test_adaptive_false_behaves_like_execute(self):
"""When adaptive=False, execute_adaptive should behave like execute."""
pool = _make_mock_pool()
config = OrchestratorConfig(adaptive=False)
orch = Orchestrator(agent_pool=pool, config=config)
task = _make_task()
result = await orch.execute_adaptive(task)
assert result.status == TaskStatus.COMPLETED
@pytest.mark.asyncio
async def test_rule_based_evaluate_all_completed(self):
"""All completed subtasks should score 1.0."""
pool = _make_mock_pool()
config = OrchestratorConfig(adaptive=True)
orch = Orchestrator(agent_pool=pool, config=config)
result = OrchestrationResult(
plan_id="p1",
parent_task_id="t1",
subtask_results={
"st1": {"status": "completed", "output": "ok"},
"st2": {"status": "completed", "output": "ok"},
},
aggregated_result={},
status=TaskStatus.COMPLETED,
total_duration_ms=100,
)
quality = orch._rule_based_evaluate(result)
assert quality["score"] == 1.0
@pytest.mark.asyncio
async def test_rule_based_evaluate_partial(self):
"""Partial completion should score proportionally."""
pool = _make_mock_pool()
config = OrchestratorConfig(adaptive=True)
orch = Orchestrator(agent_pool=pool, config=config)
result = OrchestrationResult(
plan_id="p1",
parent_task_id="t1",
subtask_results={
"st1": {"status": "completed", "output": "ok"},
"st2": {"status": "failed", "error": "bad"},
},
aggregated_result={},
status=TaskStatus.COMPLETED,
total_duration_ms=100,
)
quality = orch._rule_based_evaluate(result)
assert quality["score"] == 0.5
assert "st2" in quality["feedback"]
@pytest.mark.asyncio
async def test_rule_based_evaluate_empty(self):
"""No subtasks should score 0.0."""
pool = _make_mock_pool()
config = OrchestratorConfig(adaptive=True)
orch = Orchestrator(agent_pool=pool, config=config)
result = OrchestrationResult(
plan_id="p1",
parent_task_id="t1",
subtask_results={},
aggregated_result={},
status=TaskStatus.FAILED,
total_duration_ms=100,
)
quality = orch._rule_based_evaluate(result)
assert quality["score"] == 0.0
@pytest.mark.asyncio
async def test_adaptive_first_round_pass(self):
"""If first round quality passes, return directly."""
pool = _make_mock_pool()
config = OrchestratorConfig(
adaptive=True,
quality_threshold=0.5,
)
orch = Orchestrator(agent_pool=pool, config=config)
task = _make_task()
result = await orch.execute_adaptive(task)
# All subtasks complete in mock, so quality = 1.0 >= 0.5
assert result.status == TaskStatus.COMPLETED
@pytest.mark.asyncio
async def test_orchestration_result_metadata(self):
"""OrchestrationResult should have metadata field."""
result = OrchestrationResult(
plan_id="p1",
parent_task_id="t1",
subtask_results={},
aggregated_result={},
status=TaskStatus.COMPLETED,
total_duration_ms=100,
)
assert result.metadata == {}
@pytest.mark.asyncio
async def test_reexecute_failed_preserves_completed(self):
"""_reexecute_failed should keep completed subtask results."""
pool = _make_mock_pool()
config = OrchestratorConfig(adaptive=True)
orch = Orchestrator(agent_pool=pool, config=config)
task = _make_task()
previous = OrchestrationResult(
plan_id="p1",
parent_task_id="task-001",
subtask_results={
"st1": {"status": "completed", "output": "ok"},
"st2": {"status": "failed", "error": "bad"},
},
aggregated_result={},
status=TaskStatus.COMPLETED,
total_duration_ms=100,
)
quality = {"score": 0.5, "feedback": "Fix st2"}
result = await orch._reexecute_failed(task, previous, quality)
# st1 should be preserved
assert "st1" in result.subtask_results
assert result.subtask_results["st1"]["status"] == "completed"
@pytest.mark.asyncio
async def test_max_iterations_respected(self):
"""Adaptive loop should not exceed max_iterations."""
# Create a pool where agent always fails
mock_agent = AsyncMock()
mock_agent.execute = AsyncMock(side_effect=RuntimeError("always fails"))
pool = MagicMock()
pool.get_agent = lambda name: mock_agent
pool.list_agents = lambda: [
{"name": "test_agent", "agent_type": "worker", "description": "Test"}
]
config = OrchestratorConfig(
adaptive=True,
max_iterations=2,
quality_threshold=0.9,
)
orch = Orchestrator(agent_pool=pool, config=config)
task = _make_task()
result = await orch.execute_adaptive(task)
# Should have attempted iterations
assert result.status in (TaskStatus.FAILED, TaskStatus.COMPLETED)