From 88d8298871943551dfc7b267c6201609740d7874 Mon Sep 17 00:00:00 2001
From: chiguyong
Date: Sun, 7 Jun 2026 23:50:54 +0800
Subject: [PATCH] feat(core): add Orchestrator adaptive task decomposition (U5)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- execute_adaptive(): iterative execute→evaluate→re-decompose loop
- OrchestratorConfig: adaptive, max_iterations, quality_threshold
- _evaluate_quality(): LLM-based or rule-based quality scoring (0-1)
- _reexecute_failed(): preserves completed subtask results, retries
  failed ones with improvement feedback injected into input_data
- OrchestrationResult.metadata field for tracking iteration history
- 10 new tests, all passing
---
 src/agentkit/core/orchestrator.py        | 310 +++++++++++++++++++++++
 tests/unit/test_orchestrator_adaptive.py | 236 +++++++++++++++++
 2 files changed, 546 insertions(+)
 create mode 100644 tests/unit/test_orchestrator_adaptive.py

diff --git a/src/agentkit/core/orchestrator.py b/src/agentkit/core/orchestrator.py
index 558ae84..3ed450b 100644
--- a/src/agentkit/core/orchestrator.py
+++ b/src/agentkit/core/orchestrator.py
@@ -71,6 +71,23 @@ class OrchestrationResult:
     aggregated_result: dict[str, Any]
     status: TaskStatus
     total_duration_ms: float
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class OrchestratorConfig:
+    """Configuration for the Orchestrator's adaptive mode.
+
+    Attributes:
+        adaptive: Enable the evaluate/retry loop in ``execute_adaptive()``.
+        max_iterations: Upper bound on evaluate/retry rounds.
+        quality_threshold: Minimum quality score (0-1) a result needs to be
+            accepted without another retry round.
+    """
+
+    adaptive: bool = False
+    max_iterations: int = 3
+    quality_threshold: float = 0.7
 
 
 class Orchestrator:
@@ -95,6 +112,7 @@ class Orchestrator:
         llm_gateway: Any = None,
         max_parallel: int = 5,
         subtask_timeout: float = 300.0,
+        config: OrchestratorConfig | None = None,
     ):
         """
         Args:
@@ -103,12 +121,14 @@ class Orchestrator:
             llm_gateway: LLM Gateway,用于任务分解
             max_parallel: 最大并行子任务数
             subtask_timeout: 子任务超时时间(秒)
+            config: Orchestrator configuration, including the adaptive-mode parameters
         """
         self._agent_pool = agent_pool
         self._workspace = workspace or SharedWorkspace()
         self._llm_gateway = llm_gateway
         self._max_parallel = max_parallel
         self._subtask_timeout = subtask_timeout
+        self._config = config or OrchestratorConfig()
 
     async def execute(self, task: TaskMessage) -> OrchestrationResult:
         """执行编排任务
@@ -404,3 +424,293 @@ class Orchestrator:
             aggregated["partial_success"] = True
 
         return aggregated
+
+    async def execute_adaptive(
+        self, task: TaskMessage,
+    ) -> OrchestrationResult:
+        """Adaptive orchestration: execute → evaluate → retry loop.
+
+        Runs ``execute()`` once. When adaptive mode is enabled the result is
+        scored by ``_evaluate_quality()``; while the score is below
+        ``quality_threshold`` and fewer than ``max_iterations`` rounds have
+        run, the subtasks that did not complete are re-executed (completed
+        results are preserved) and the merged result is scored again.
+
+        Args:
+            task: The original task message.
+
+        Returns:
+            OrchestrationResult: The last result produced. In adaptive mode
+            ``metadata["iterations"]`` holds the per-round score/feedback
+            history, and ``metadata["quality_score"]`` is set when a round
+            met the threshold.
+        """
+        result = await self.execute(task)
+
+        # Adaptive mode off: behave exactly like execute().
+        if not self._config.adaptive:
+            return result
+
+        # Always evaluate in adaptive mode, even when execute() reported
+        # COMPLETED: partial success and low-quality output must be able to
+        # trigger a retry round.
+        iteration_history: list[dict[str, Any]] = []
+        current_result = result
+        for iteration in range(1, self._config.max_iterations + 1):
+            quality = await self._evaluate_quality(task, current_result)
+            iteration_history.append({
+                "iteration": iteration,
+                "quality_score": quality["score"],
+                "feedback": quality.get("feedback", ""),
+            })
+
+            if quality["score"] >= self._config.quality_threshold:
+                logger.info(
+                    f"Adaptive iteration {iteration}: quality "
+                    f"{quality['score']:.2f} >= {self._config.quality_threshold}"
+                )
+                current_result.metadata["quality_score"] = quality["score"]
+                current_result.metadata["iterations"] = iteration_history
+                return current_result
+
+            logger.info(
+                f"Adaptive iteration {iteration}: quality "
+                f"{quality['score']:.2f} < {self._config.quality_threshold}, "
+                f"re-decomposing failed subtasks"
+            )
+
+            new_result = await self._reexecute_failed(
+                task, current_result, quality,
+            )
+            if new_result is current_result:
+                # Nothing was retried (no failed subtasks), so another round
+                # would only re-score the exact same result.
+                break
+            current_result = new_result
+
+        # Iterations exhausted, or nothing left to retry.
+        current_result.metadata["iterations"] = iteration_history
+        return current_result
+
+    async def _evaluate_quality(
+        self,
+        task: TaskMessage,
+        result: OrchestrationResult,
+    ) -> dict[str, Any]:
+        """Score the quality of an orchestration result.
+
+        Uses the LLM evaluator when a gateway is configured; falls back to
+        the rule-based completion ratio when there is no gateway or the LLM
+        evaluation raises.
+
+        Returns:
+            Dict with "score" (0-1) and optional "feedback" string.
+        """
+        # Rule-based evaluation when no LLM
+        if self._llm_gateway is None:
+            return self._rule_based_evaluate(result)
+
+        try:
+            return await self._llm_evaluate(task, result)
+        except Exception as e:
+            # Best-effort: a broken evaluator must not abort orchestration.
+            logger.warning(f"LLM evaluation failed, falling back to rule-based: {e}")
+            return self._rule_based_evaluate(result)
+
+    def _rule_based_evaluate(
+        self, result: OrchestrationResult,
+    ) -> dict[str, Any]:
+        """Rule-based quality score: the fraction of completed subtasks.
+
+        Returns:
+            Dict with "score" (completed / total; 0.0 when there are no
+            subtasks) and "feedback" naming the subtasks not completed.
+        """
+        total = len(result.subtask_results)
+        if total == 0:
+            return {"score": 0.0, "feedback": "No subtasks executed"}
+
+        failed = [
+            tid for tid, r in result.subtask_results.items()
+            if r.get("status") != "completed"
+        ]
+        score = (total - len(failed)) / total
+        feedback = f"Failed subtasks: {failed}" if failed else ""
+        return {"score": score, "feedback": feedback}
+
+    async def _llm_evaluate(
+        self,
+        task: TaskMessage,
+        result: OrchestrationResult,
+    ) -> dict[str, Any]:
+        """Ask the LLM to score the orchestration result.
+
+        The reply is expected to be a JSON object with "score" and
+        "feedback". The score is clamped to [0, 1]; a reply that cannot be
+        parsed into a finite score falls back to the rule-based evaluation.
+        """
+        import json
+        import math
+
+        subtask_summary = [
+            {
+                "task_id": tid,
+                "status": r.get("status", "unknown"),
+                # Truncated so a large output cannot bloat the prompt.
+                "output_preview": str(r.get("output", ""))[:200],
+            }
+            for tid, r in result.subtask_results.items()
+        ]
+
+        prompt = (
+            f"Evaluate the quality of the following orchestration result.\n\n"
+            f"Original task: {task.input_data}\n"
+            f"Subtask results:\n{json.dumps(subtask_summary, ensure_ascii=False)}\n\n"
+            f'Respond ONLY with JSON: {{"score": 0.0-1.0, "feedback": "..."}}\n'
+            f"Score 1.0 = perfect, 0.0 = completely failed."
+        )
+
+        response = await self._llm_gateway.chat(
+            messages=[{"role": "user", "content": prompt}],
+            model="default",
+        )
+
+        try:
+            text = response.content.strip()
+            if text.startswith("```"):
+                # Strip a Markdown code fence: drop the opening line, and the
+                # last line only when it really is the closing fence.
+                lines = text.split("\n")[1:]
+                if lines and lines[-1].strip().startswith("```"):
+                    lines = lines[:-1]
+                text = "\n".join(lines)
+            data = json.loads(text)
+            score = float(data.get("score", 0.0))
+            if not math.isfinite(score):
+                raise ValueError(f"non-finite score: {score}")
+            return {
+                # Enforce the documented 0-1 range even if the model ignores it.
+                "score": min(1.0, max(0.0, score)),
+                "feedback": data.get("feedback", ""),
+            }
+        except (ValueError, TypeError, AttributeError) as e:
+            # ValueError covers json.JSONDecodeError; TypeError/AttributeError
+            # cover a non-object reply or a null score.
+            logger.warning(f"Failed to parse LLM evaluation: {e}")
+            return self._rule_based_evaluate(result)
+
+    async def _reexecute_failed(
+        self,
+        task: TaskMessage,
+        previous_result: OrchestrationResult,
+        quality: dict[str, Any],
+    ) -> OrchestrationResult:
+        """Re-run the subtasks that did not complete and merge the outcome.
+
+        Completed subtask results are kept as-is. Every other subtask is
+        retried as ``retry-<id>`` with the previous error and the evaluation
+        feedback injected into its input data; the retry result is stored
+        back under the original subtask ID.
+
+        Args:
+            task: The original task message.
+            previous_result: Result of the previous round.
+            quality: Evaluation of ``previous_result``; its "feedback" is
+                passed on to the retried subtasks.
+
+        Returns:
+            ``previous_result`` itself when nothing needs a retry; otherwise
+            a new OrchestrationResult with the merged subtask results whose
+            ``total_duration_ms`` is the previous duration plus this round.
+        """
+        import time as _time
+
+        start_time = _time.monotonic()
+
+        failed_task_ids = [
+            tid for tid, r in previous_result.subtask_results.items()
+            if r.get("status") != "completed"
+        ]
+
+        if not failed_task_ids:
+            return previous_result
+
+        # One retry subtask per failed subtask; the mapping lets retry results
+        # be stored back under the original ID without string surgery.
+        retry_to_original = {f"retry-{tid}": tid for tid in failed_task_ids}
+        new_subtasks = [
+            SubTask(
+                task_id=retry_id,
+                parent_task_id=task.task_id,
+                assigned_agent=task.agent_name,
+                task_type=task.task_type,
+                input_data={
+                    **task.input_data,
+                    "previous_error": previous_result.subtask_results[tid].get("error", ""),
+                    "improvement_feedback": quality.get("feedback", ""),
+                },
+            )
+            for retry_id, tid in retry_to_original.items()
+        ]
+
+        # The retries are independent, so they run as one parallel group.
+        plan = OrchestrationPlan(
+            plan_id=f"retry-{previous_result.plan_id}",
+            parent_task_id=task.task_id,
+            subtasks=new_subtasks,
+            parallel_groups=[[st.task_id for st in new_subtasks]],
+        )
+
+        retry_results = await self._execute_plan(plan, task)
+
+        # Start from the previous results so that a failed subtask whose retry
+        # produced no result stays recorded as failed instead of vanishing
+        # (which would inflate the completion ratio), then overlay the retries.
+        merged_results = dict(previous_result.subtask_results)
+        for retry_id, r in retry_results.items():
+            merged_results[retry_to_original.get(retry_id, retry_id)] = r
+
+        # Rebuild a plan over the merged results for aggregation.
+        all_subtasks = [
+            SubTask(
+                task_id=tid,
+                parent_task_id=task.task_id,
+                assigned_agent=task.agent_name,
+                task_type=task.task_type,
+                input_data=task.input_data,
+                status=(
+                    SubTaskStatus.COMPLETED
+                    if r.get("status") == "completed"
+                    else SubTaskStatus.FAILED
+                ),
+                result=r.get("output"),
+            )
+            for tid, r in merged_results.items()
+        ]
+
+        retry_plan = OrchestrationPlan(
+            plan_id=plan.plan_id,
+            parent_task_id=task.task_id,
+            subtasks=all_subtasks,
+            parallel_groups=[],
+        )
+
+        aggregated = await self._aggregate_results(retry_plan, merged_results, task)
+
+        # FAILED only when every subtask failed; partial success still reports
+        # COMPLETED.
+        all_failed = all(
+            r.get("status") != "completed" for r in merged_results.values()
+        )
+        status = TaskStatus.FAILED if all_failed else TaskStatus.COMPLETED
+
+        retry_duration_ms = (_time.monotonic() - start_time) * 1000
+
+        return OrchestrationResult(
+            plan_id=plan.plan_id,
+            parent_task_id=task.task_id,
+            subtask_results=merged_results,
+            aggregated_result=aggregated,
+            status=status,
+            total_duration_ms=previous_result.total_duration_ms + retry_duration_ms,
+        )
diff --git a/tests/unit/test_orchestrator_adaptive.py b/tests/unit/test_orchestrator_adaptive.py
new file mode 100644
index 0000000..27b363a
--- /dev/null
+++ b/tests/unit/test_orchestrator_adaptive.py
@@ -0,0 +1,236 @@
+"""Tests for Orchestrator adaptive task decomposition (U5)."""
+
+import pytest
+from datetime import datetime, timezone
+from unittest.mock import AsyncMock, MagicMock
+
+from agentkit.core.orchestrator import (
+    Orchestrator,
+    OrchestratorConfig,
+    OrchestrationResult,
+    SubTaskStatus,
+)
+from agentkit.core.protocol import TaskMessage, TaskStatus
+
+
+# ── Test Helpers ──────────────────────────────────────────
+
+
+def _make_task(**overrides) -> TaskMessage:
+    defaults = {
+        "task_id": "task-001",
+        "agent_name": "test_agent",
+        "task_type": "analyze",
+        "priority": 0,
+        "input_data": {"query": "test"},
+        "callback_url": None,
+        "created_at": datetime.now(timezone.utc),
+        "timeout_seconds": 60,
+    }
+    defaults.update(overrides)
+    return TaskMessage(**defaults)
+
+
+def _make_mock_pool(agents: dict[str, AsyncMock] | None = None):
+    """Create a mock AgentPool."""
+    pool = MagicMock()
+
+    if agents:
+        pool.get_agent = lambda name: agents.get(name)
+        pool.list_agents = lambda: [
+            {"name": name, "agent_type": "worker", "description": f"Agent {name}"}
+            for name in agents
+        ]
+    else:
+        # Default: single agent that succeeds
+        mock_agent = AsyncMock()
+        mock_result = MagicMock()
+        mock_result.output_data = {"result": "done"}
+        mock_agent.execute = AsyncMock(return_value=mock_result)
+
+        pool.get_agent = lambda name: mock_agent
+        pool.list_agents = lambda: [
+            {"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
+        ]
+
+    return pool
+
+
+# ── OrchestratorConfig Tests ──────────────────────────────
+
+
+class TestOrchestratorConfig:
+    def test_default_values(self):
+        config = OrchestratorConfig()
+        assert config.adaptive is False
+        assert config.max_iterations == 3
+        assert config.quality_threshold == 0.7
+
+    def test_custom_values(self):
+        config = OrchestratorConfig(
+            adaptive=True,
+            max_iterations=5,
+            quality_threshold=0.9,
+        )
+        assert config.adaptive is True
+        assert config.max_iterations == 5
+        assert config.quality_threshold == 0.9
+
+
+# ── Adaptive Execution Tests ─────────────────────────────
+
+
+class TestOrchestratorAdaptive:
+    @pytest.mark.asyncio
+    async def test_adaptive_false_behaves_like_execute(self):
+        """When adaptive=False, execute_adaptive should behave like execute."""
+        pool = _make_mock_pool()
+        config = OrchestratorConfig(adaptive=False)
+        orch = Orchestrator(agent_pool=pool, config=config)
+
+        task = _make_task()
+        result = await orch.execute_adaptive(task)
+        assert result.status == TaskStatus.COMPLETED
+
+    @pytest.mark.asyncio
+    async def test_rule_based_evaluate_all_completed(self):
+        """All completed subtasks should score 1.0."""
+        pool = _make_mock_pool()
+        config = OrchestratorConfig(adaptive=True)
+        orch = Orchestrator(agent_pool=pool, config=config)
+
+        result = OrchestrationResult(
+            plan_id="p1",
+            parent_task_id="t1",
+            subtask_results={
+                "st1": {"status": "completed", "output": "ok"},
+                "st2": {"status": "completed", "output": "ok"},
+            },
+            aggregated_result={},
+            status=TaskStatus.COMPLETED,
+            total_duration_ms=100,
+        )
+
+        quality = orch._rule_based_evaluate(result)
+        assert quality["score"] == 1.0
+
+    @pytest.mark.asyncio
+    async def test_rule_based_evaluate_partial(self):
+        """Partial completion should score proportionally."""
+        pool = _make_mock_pool()
+        config = OrchestratorConfig(adaptive=True)
+        orch = Orchestrator(agent_pool=pool, config=config)
+
+        result = OrchestrationResult(
+            plan_id="p1",
+            parent_task_id="t1",
+            subtask_results={
+                "st1": {"status": "completed", "output": "ok"},
+                "st2": {"status": "failed", "error": "bad"},
+            },
+            aggregated_result={},
+            status=TaskStatus.COMPLETED,
+            total_duration_ms=100,
+        )
+
+        quality = orch._rule_based_evaluate(result)
+        assert quality["score"] == 0.5
+        assert "st2" in quality["feedback"]
+
+    @pytest.mark.asyncio
+    async def test_rule_based_evaluate_empty(self):
+        """No subtasks should score 0.0."""
+        pool = _make_mock_pool()
+        config = OrchestratorConfig(adaptive=True)
+        orch = Orchestrator(agent_pool=pool, config=config)
+
+        result = OrchestrationResult(
+            plan_id="p1",
+            parent_task_id="t1",
+            subtask_results={},
+            aggregated_result={},
+            status=TaskStatus.FAILED,
+            total_duration_ms=100,
+        )
+
+        quality = orch._rule_based_evaluate(result)
+        assert quality["score"] == 0.0
+
+    @pytest.mark.asyncio
+    async def test_adaptive_first_round_pass(self):
+        """If first round quality passes, return directly."""
+        pool = _make_mock_pool()
+        config = OrchestratorConfig(
+            adaptive=True,
+            quality_threshold=0.5,
+        )
+        orch = Orchestrator(agent_pool=pool, config=config)
+
+        task = _make_task()
+        result = await orch.execute_adaptive(task)
+        # All subtasks complete in mock, so quality = 1.0 >= 0.5
+        assert result.status == TaskStatus.COMPLETED
+
+    @pytest.mark.asyncio
+    async def test_orchestration_result_metadata(self):
+        """OrchestrationResult should have metadata field."""
+        result = OrchestrationResult(
+            plan_id="p1",
+            parent_task_id="t1",
+            subtask_results={},
+            aggregated_result={},
+            status=TaskStatus.COMPLETED,
+            total_duration_ms=100,
+        )
+        assert result.metadata == {}
+
+    @pytest.mark.asyncio
+    async def test_reexecute_failed_preserves_completed(self):
+        """_reexecute_failed should keep completed subtask results."""
+        pool = _make_mock_pool()
+        config = OrchestratorConfig(adaptive=True)
+        orch = Orchestrator(agent_pool=pool, config=config)
+
+        task = _make_task()
+        previous = OrchestrationResult(
+            plan_id="p1",
+            parent_task_id="task-001",
+            subtask_results={
+                "st1": {"status": "completed", "output": "ok"},
+                "st2": {"status": "failed", "error": "bad"},
+            },
+            aggregated_result={},
+            status=TaskStatus.COMPLETED,
+            total_duration_ms=100,
+        )
+        quality = {"score": 0.5, "feedback": "Fix st2"}
+
+        result = await orch._reexecute_failed(task, previous, quality)
+        # st1 should be preserved
+        assert "st1" in result.subtask_results
+        assert result.subtask_results["st1"]["status"] == "completed"
+
+    @pytest.mark.asyncio
+    async def test_max_iterations_respected(self):
+        """Adaptive loop should not exceed max_iterations."""
+        # Create a pool where agent always fails
+        mock_agent = AsyncMock()
+        mock_agent.execute = AsyncMock(side_effect=RuntimeError("always fails"))
+
+        pool = MagicMock()
+        pool.get_agent = lambda name: mock_agent
+        pool.list_agents = lambda: [
+            {"name": "test_agent", "agent_type": "worker", "description": "Test"}
+        ]
+
+        config = OrchestratorConfig(
+            adaptive=True,
+            max_iterations=2,
+            quality_threshold=0.9,
+        )
+        orch = Orchestrator(agent_pool=pool, config=config)
+
+        task = _make_task()
+        result = await orch.execute_adaptive(task)
+        # Should have attempted iterations
+        assert result.status in (TaskStatus.FAILED, TaskStatus.COMPLETED)