feat(core): add Orchestrator adaptive task decomposition (U5)
- execute_adaptive(): iterative execute→evaluate→re-decompose loop - OrchestratorConfig: adaptive, max_iterations, quality_threshold - _evaluate_quality(): LLM-based or rule-based quality scoring (0-1) - _reexecute_failed(): preserves completed subtask results, retries failed ones with improvement feedback injected into input_data - OrchestrationResult.metadata field for tracking iteration history - 10 new tests, all passing
This commit is contained in:
parent
7054ac02b6
commit
88d8298871
|
|
@ -71,6 +71,16 @@ class OrchestrationResult:
|
|||
aggregated_result: dict[str, Any]
|
||||
status: TaskStatus
|
||||
total_duration_ms: float
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrchestratorConfig:
|
||||
"""Orchestrator 配置"""
|
||||
|
||||
adaptive: bool = False
|
||||
max_iterations: int = 3
|
||||
quality_threshold: float = 0.7
|
||||
|
||||
|
||||
class Orchestrator:
|
||||
|
|
@ -95,6 +105,7 @@ class Orchestrator:
|
|||
llm_gateway: Any = None,
|
||||
max_parallel: int = 5,
|
||||
subtask_timeout: float = 300.0,
|
||||
config: OrchestratorConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -103,12 +114,14 @@ class Orchestrator:
|
|||
llm_gateway: LLM Gateway,用于任务分解
|
||||
max_parallel: 最大并行子任务数
|
||||
subtask_timeout: 子任务超时时间(秒)
|
||||
config: Orchestrator 配置,包含自适应参数
|
||||
"""
|
||||
self._agent_pool = agent_pool
|
||||
self._workspace = workspace or SharedWorkspace()
|
||||
self._llm_gateway = llm_gateway
|
||||
self._max_parallel = max_parallel
|
||||
self._subtask_timeout = subtask_timeout
|
||||
self._config = config or OrchestratorConfig()
|
||||
|
||||
async def execute(self, task: TaskMessage) -> OrchestrationResult:
|
||||
"""执行编排任务
|
||||
|
|
@ -404,3 +417,258 @@ class Orchestrator:
|
|||
aggregated["partial_success"] = True
|
||||
|
||||
return aggregated
|
||||
|
||||
async def execute_adaptive(
|
||||
self, task: TaskMessage,
|
||||
) -> OrchestrationResult:
|
||||
"""自适应编排:执行→评估→再分解循环。
|
||||
|
||||
与 execute() 不同,此方法在第一轮执行后评估子任务结果质量,
|
||||
如果评估不通过且未达 max_iterations,则基于评估反馈重新分解
|
||||
未达标的子任务,保留已完成的子任务结果,然后执行新分解的子任务。
|
||||
|
||||
Args:
|
||||
task: 原始任务消息
|
||||
|
||||
Returns:
|
||||
OrchestrationResult: 编排结果,metadata 中包含迭代历史
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
start_time = _time.monotonic()
|
||||
iteration_history: list[dict[str, Any]] = []
|
||||
|
||||
# First execution
|
||||
result = await self.execute(task)
|
||||
|
||||
# If adaptive not enabled or already succeeded, return directly
|
||||
if not self._config.adaptive or result.status == TaskStatus.COMPLETED:
|
||||
# Check quality even on success
|
||||
if self._config.adaptive and self._llm_gateway:
|
||||
quality = await self._evaluate_quality(task, result)
|
||||
if quality["score"] >= self._config.quality_threshold:
|
||||
result.metadata["quality_score"] = quality["score"]
|
||||
return result
|
||||
return result
|
||||
|
||||
# Adaptive loop
|
||||
current_result = result
|
||||
for iteration in range(1, self._config.max_iterations + 1):
|
||||
# Evaluate quality
|
||||
quality = await self._evaluate_quality(task, current_result)
|
||||
iteration_history.append({
|
||||
"iteration": iteration,
|
||||
"quality_score": quality["score"],
|
||||
"feedback": quality.get("feedback", ""),
|
||||
})
|
||||
|
||||
if quality["score"] >= self._config.quality_threshold:
|
||||
logger.info(
|
||||
f"Adaptive iteration {iteration}: quality "
|
||||
f"{quality['score']:.2f} >= {self._config.quality_threshold}"
|
||||
)
|
||||
current_result.metadata["quality_score"] = quality["score"]
|
||||
current_result.metadata["iterations"] = iteration_history
|
||||
return current_result
|
||||
|
||||
logger.info(
|
||||
f"Adaptive iteration {iteration}: quality "
|
||||
f"{quality['score']:.2f} < {self._config.quality_threshold}, "
|
||||
f"re-decomposing failed subtasks"
|
||||
)
|
||||
|
||||
# Re-decompose failed subtasks
|
||||
new_result = await self._reexecute_failed(
|
||||
task, current_result, quality,
|
||||
)
|
||||
current_result = new_result
|
||||
|
||||
# Exhausted iterations
|
||||
current_result.metadata["iterations"] = iteration_history
|
||||
return current_result
|
||||
|
||||
async def _evaluate_quality(
|
||||
self,
|
||||
task: TaskMessage,
|
||||
result: OrchestrationResult,
|
||||
) -> dict[str, Any]:
|
||||
"""评估子任务结果质量。
|
||||
|
||||
Returns:
|
||||
Dict with "score" (0-1) and optional "feedback" string.
|
||||
"""
|
||||
# Rule-based evaluation when no LLM
|
||||
if self._llm_gateway is None:
|
||||
return self._rule_based_evaluate(result)
|
||||
|
||||
try:
|
||||
return await self._llm_evaluate(task, result)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM evaluation failed, falling back to rule-based: {e}")
|
||||
return self._rule_based_evaluate(result)
|
||||
|
||||
def _rule_based_evaluate(
|
||||
self, result: OrchestrationResult,
|
||||
) -> dict[str, Any]:
|
||||
"""基于规则的质量评估:根据完成率打分。"""
|
||||
total = len(result.subtask_results)
|
||||
if total == 0:
|
||||
return {"score": 0.0, "feedback": "No subtasks executed"}
|
||||
|
||||
completed = sum(
|
||||
1 for r in result.subtask_results.values()
|
||||
if r.get("status") == "completed"
|
||||
)
|
||||
score = completed / total
|
||||
feedback = ""
|
||||
if score < 1.0:
|
||||
failed = [
|
||||
tid for tid, r in result.subtask_results.items()
|
||||
if r.get("status") != "completed"
|
||||
]
|
||||
feedback = f"Failed subtasks: {failed}"
|
||||
return {"score": score, "feedback": feedback}
|
||||
|
||||
async def _llm_evaluate(
|
||||
self,
|
||||
task: TaskMessage,
|
||||
result: OrchestrationResult,
|
||||
) -> dict[str, Any]:
|
||||
"""使用 LLM 评估子任务结果质量。"""
|
||||
import json
|
||||
|
||||
subtask_summary = []
|
||||
for tid, r in result.subtask_results.items():
|
||||
subtask_summary.append({
|
||||
"task_id": tid,
|
||||
"status": r.get("status", "unknown"),
|
||||
"output_preview": str(r.get("output", ""))[:200],
|
||||
})
|
||||
|
||||
prompt = (
|
||||
f"Evaluate the quality of the following orchestration result.\n\n"
|
||||
f"Original task: {task.input_data}\n"
|
||||
f"Subtask results:\n{json.dumps(subtask_summary, ensure_ascii=False)}\n\n"
|
||||
f'Respond ONLY with JSON: {{"score": 0.0-1.0, "feedback": "..."}}\n'
|
||||
f"Score 1.0 = perfect, 0.0 = completely failed."
|
||||
)
|
||||
|
||||
response = await self._llm_gateway.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model="default",
|
||||
)
|
||||
|
||||
try:
|
||||
text = response.content.strip()
|
||||
if text.startswith("```"):
|
||||
lines = text.split("\n")
|
||||
text = "\n".join(lines[1:-1])
|
||||
data = json.loads(text)
|
||||
return {
|
||||
"score": float(data.get("score", 0.0)),
|
||||
"feedback": data.get("feedback", ""),
|
||||
}
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse LLM evaluation: {e}")
|
||||
return self._rule_based_evaluate(result)
|
||||
|
||||
async def _reexecute_failed(
|
||||
self,
|
||||
task: TaskMessage,
|
||||
previous_result: OrchestrationResult,
|
||||
quality: dict[str, Any],
|
||||
) -> OrchestrationResult:
|
||||
"""重新执行失败的子任务,保留已完成的结果。"""
|
||||
import time as _time
|
||||
|
||||
start_time = _time.monotonic()
|
||||
|
||||
# Identify failed subtasks
|
||||
failed_task_ids = [
|
||||
tid for tid, r in previous_result.subtask_results.items()
|
||||
if r.get("status") != "completed"
|
||||
]
|
||||
|
||||
if not failed_task_ids:
|
||||
return previous_result
|
||||
|
||||
# Create new subtasks for failed ones, incorporating feedback
|
||||
new_subtasks = []
|
||||
for tid in failed_task_ids:
|
||||
old_result = previous_result.subtask_results[tid]
|
||||
new_subtasks.append(SubTask(
|
||||
task_id=f"retry-{tid}",
|
||||
parent_task_id=task.task_id,
|
||||
assigned_agent=task.agent_name,
|
||||
task_type=task.task_type,
|
||||
input_data={
|
||||
**task.input_data,
|
||||
"previous_error": old_result.get("error", ""),
|
||||
"improvement_feedback": quality.get("feedback", ""),
|
||||
},
|
||||
))
|
||||
|
||||
# Build a mini-plan for the retry subtasks
|
||||
plan = OrchestrationPlan(
|
||||
plan_id=f"retry-{previous_result.plan_id}",
|
||||
parent_task_id=task.task_id,
|
||||
subtasks=new_subtasks,
|
||||
parallel_groups=[[st.task_id for st in new_subtasks]],
|
||||
)
|
||||
|
||||
# Execute retry subtasks
|
||||
retry_results = await self._execute_plan(plan, task)
|
||||
|
||||
# Merge: keep completed results, replace failed with retry results
|
||||
merged_results = {}
|
||||
for tid, r in previous_result.subtask_results.items():
|
||||
if r.get("status") == "completed":
|
||||
merged_results[tid] = r
|
||||
|
||||
for tid, r in retry_results.items():
|
||||
# Map retry task IDs back to original
|
||||
original_tid = tid.replace("retry-", "", 1)
|
||||
merged_results[original_tid] = r
|
||||
|
||||
# Re-aggregate
|
||||
all_subtasks = []
|
||||
for tid, r in merged_results.items():
|
||||
all_subtasks.append(SubTask(
|
||||
task_id=tid,
|
||||
parent_task_id=task.task_id,
|
||||
assigned_agent=task.agent_name,
|
||||
task_type=task.task_type,
|
||||
input_data=task.input_data,
|
||||
status=SubTaskStatus.COMPLETED if r.get("status") == "completed" else SubTaskStatus.FAILED,
|
||||
result=r.get("output"),
|
||||
))
|
||||
|
||||
retry_plan = OrchestrationPlan(
|
||||
plan_id=plan.plan_id,
|
||||
parent_task_id=task.task_id,
|
||||
subtasks=all_subtasks,
|
||||
parallel_groups=[],
|
||||
)
|
||||
|
||||
aggregated = await self._aggregate_results(retry_plan, merged_results, task)
|
||||
|
||||
failed_count = sum(
|
||||
1 for r in merged_results.values() if r.get("status") != "completed"
|
||||
)
|
||||
if failed_count == len(merged_results):
|
||||
status = TaskStatus.FAILED
|
||||
elif failed_count > 0:
|
||||
status = TaskStatus.COMPLETED
|
||||
else:
|
||||
status = TaskStatus.COMPLETED
|
||||
|
||||
duration_ms = (_time.monotonic() - start_time) * 1000
|
||||
|
||||
return OrchestrationResult(
|
||||
plan_id=plan.plan_id,
|
||||
parent_task_id=task.task_id,
|
||||
subtask_results=merged_results,
|
||||
aggregated_result=aggregated,
|
||||
status=status,
|
||||
total_duration_ms=duration_ms,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,236 @@
|
|||
"""Tests for Orchestrator adaptive task decomposition (U5)."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.core.orchestrator import (
|
||||
Orchestrator,
|
||||
OrchestratorConfig,
|
||||
OrchestrationResult,
|
||||
SubTaskStatus,
|
||||
)
|
||||
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||
|
||||
|
||||
# ── Test Helpers ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_task(**overrides) -> TaskMessage:
|
||||
defaults = {
|
||||
"task_id": "task-001",
|
||||
"agent_name": "test_agent",
|
||||
"task_type": "analyze",
|
||||
"priority": 0,
|
||||
"input_data": {"query": "test"},
|
||||
"callback_url": None,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"timeout_seconds": 60,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return TaskMessage(**defaults)
|
||||
|
||||
|
||||
def _make_mock_pool(agents: dict[str, AsyncMock] | None = None):
|
||||
"""Create a mock AgentPool."""
|
||||
pool = MagicMock()
|
||||
|
||||
if agents:
|
||||
pool.get_agent = lambda name: agents.get(name)
|
||||
pool.list_agents = lambda: [
|
||||
{"name": name, "agent_type": "worker", "description": f"Agent {name}"}
|
||||
for name in agents
|
||||
]
|
||||
else:
|
||||
# Default: single agent that succeeds
|
||||
mock_agent = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.output_data = {"result": "done"}
|
||||
mock_agent.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
pool.get_agent = lambda name: mock_agent
|
||||
pool.list_agents = lambda: [
|
||||
{"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
|
||||
]
|
||||
|
||||
return pool
|
||||
|
||||
|
||||
# ── OrchestratorConfig Tests ──────────────────────────────
|
||||
|
||||
|
||||
class TestOrchestratorConfig:
|
||||
def test_default_values(self):
|
||||
config = OrchestratorConfig()
|
||||
assert config.adaptive is False
|
||||
assert config.max_iterations == 3
|
||||
assert config.quality_threshold == 0.7
|
||||
|
||||
def test_custom_values(self):
|
||||
config = OrchestratorConfig(
|
||||
adaptive=True,
|
||||
max_iterations=5,
|
||||
quality_threshold=0.9,
|
||||
)
|
||||
assert config.adaptive is True
|
||||
assert config.max_iterations == 5
|
||||
assert config.quality_threshold == 0.9
|
||||
|
||||
|
||||
# ── Adaptive Execution Tests ─────────────────────────────
|
||||
|
||||
|
||||
class TestOrchestratorAdaptive:
|
||||
@pytest.mark.asyncio
|
||||
async def test_adaptive_false_behaves_like_execute(self):
|
||||
"""When adaptive=False, execute_adaptive should behave like execute."""
|
||||
pool = _make_mock_pool()
|
||||
config = OrchestratorConfig(adaptive=False)
|
||||
orch = Orchestrator(agent_pool=pool, config=config)
|
||||
|
||||
task = _make_task()
|
||||
result = await orch.execute_adaptive(task)
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rule_based_evaluate_all_completed(self):
|
||||
"""All completed subtasks should score 1.0."""
|
||||
pool = _make_mock_pool()
|
||||
config = OrchestratorConfig(adaptive=True)
|
||||
orch = Orchestrator(agent_pool=pool, config=config)
|
||||
|
||||
result = OrchestrationResult(
|
||||
plan_id="p1",
|
||||
parent_task_id="t1",
|
||||
subtask_results={
|
||||
"st1": {"status": "completed", "output": "ok"},
|
||||
"st2": {"status": "completed", "output": "ok"},
|
||||
},
|
||||
aggregated_result={},
|
||||
status=TaskStatus.COMPLETED,
|
||||
total_duration_ms=100,
|
||||
)
|
||||
|
||||
quality = orch._rule_based_evaluate(result)
|
||||
assert quality["score"] == 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rule_based_evaluate_partial(self):
|
||||
"""Partial completion should score proportionally."""
|
||||
pool = _make_mock_pool()
|
||||
config = OrchestratorConfig(adaptive=True)
|
||||
orch = Orchestrator(agent_pool=pool, config=config)
|
||||
|
||||
result = OrchestrationResult(
|
||||
plan_id="p1",
|
||||
parent_task_id="t1",
|
||||
subtask_results={
|
||||
"st1": {"status": "completed", "output": "ok"},
|
||||
"st2": {"status": "failed", "error": "bad"},
|
||||
},
|
||||
aggregated_result={},
|
||||
status=TaskStatus.COMPLETED,
|
||||
total_duration_ms=100,
|
||||
)
|
||||
|
||||
quality = orch._rule_based_evaluate(result)
|
||||
assert quality["score"] == 0.5
|
||||
assert "st2" in quality["feedback"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rule_based_evaluate_empty(self):
|
||||
"""No subtasks should score 0.0."""
|
||||
pool = _make_mock_pool()
|
||||
config = OrchestratorConfig(adaptive=True)
|
||||
orch = Orchestrator(agent_pool=pool, config=config)
|
||||
|
||||
result = OrchestrationResult(
|
||||
plan_id="p1",
|
||||
parent_task_id="t1",
|
||||
subtask_results={},
|
||||
aggregated_result={},
|
||||
status=TaskStatus.FAILED,
|
||||
total_duration_ms=100,
|
||||
)
|
||||
|
||||
quality = orch._rule_based_evaluate(result)
|
||||
assert quality["score"] == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adaptive_first_round_pass(self):
|
||||
"""If first round quality passes, return directly."""
|
||||
pool = _make_mock_pool()
|
||||
config = OrchestratorConfig(
|
||||
adaptive=True,
|
||||
quality_threshold=0.5,
|
||||
)
|
||||
orch = Orchestrator(agent_pool=pool, config=config)
|
||||
|
||||
task = _make_task()
|
||||
result = await orch.execute_adaptive(task)
|
||||
# All subtasks complete in mock, so quality = 1.0 >= 0.5
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestration_result_metadata(self):
|
||||
"""OrchestrationResult should have metadata field."""
|
||||
result = OrchestrationResult(
|
||||
plan_id="p1",
|
||||
parent_task_id="t1",
|
||||
subtask_results={},
|
||||
aggregated_result={},
|
||||
status=TaskStatus.COMPLETED,
|
||||
total_duration_ms=100,
|
||||
)
|
||||
assert result.metadata == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reexecute_failed_preserves_completed(self):
|
||||
"""_reexecute_failed should keep completed subtask results."""
|
||||
pool = _make_mock_pool()
|
||||
config = OrchestratorConfig(adaptive=True)
|
||||
orch = Orchestrator(agent_pool=pool, config=config)
|
||||
|
||||
task = _make_task()
|
||||
previous = OrchestrationResult(
|
||||
plan_id="p1",
|
||||
parent_task_id="task-001",
|
||||
subtask_results={
|
||||
"st1": {"status": "completed", "output": "ok"},
|
||||
"st2": {"status": "failed", "error": "bad"},
|
||||
},
|
||||
aggregated_result={},
|
||||
status=TaskStatus.COMPLETED,
|
||||
total_duration_ms=100,
|
||||
)
|
||||
quality = {"score": 0.5, "feedback": "Fix st2"}
|
||||
|
||||
result = await orch._reexecute_failed(task, previous, quality)
|
||||
# st1 should be preserved
|
||||
assert "st1" in result.subtask_results
|
||||
assert result.subtask_results["st1"]["status"] == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iterations_respected(self):
|
||||
"""Adaptive loop should not exceed max_iterations."""
|
||||
# Create a pool where agent always fails
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.execute = AsyncMock(side_effect=RuntimeError("always fails"))
|
||||
|
||||
pool = MagicMock()
|
||||
pool.get_agent = lambda name: mock_agent
|
||||
pool.list_agents = lambda: [
|
||||
{"name": "test_agent", "agent_type": "worker", "description": "Test"}
|
||||
]
|
||||
|
||||
config = OrchestratorConfig(
|
||||
adaptive=True,
|
||||
max_iterations=2,
|
||||
quality_threshold=0.9,
|
||||
)
|
||||
orch = Orchestrator(agent_pool=pool, config=config)
|
||||
|
||||
task = _make_task()
|
||||
result = await orch.execute_adaptive(task)
|
||||
# Should have attempted iterations
|
||||
assert result.status in (TaskStatus.FAILED, TaskStatus.COMPLETED)
|
||||
Loading…
Reference in New Issue