feat(core): add Orchestrator adaptive task decomposition (U5)
- execute_adaptive(): iterative execute→evaluate→re-decompose loop
- OrchestratorConfig: adaptive, max_iterations, quality_threshold
- _evaluate_quality(): LLM-based or rule-based quality scoring (0-1)
- _reexecute_failed(): preserves completed subtask results and retries failed ones with improvement feedback injected into input_data
- OrchestrationResult.metadata: new field for tracking iteration history
- 10 new tests, all passing
This commit is contained in:
parent
7054ac02b6
commit
88d8298871
|
|
@ -71,6 +71,16 @@ class OrchestrationResult:
|
||||||
aggregated_result: dict[str, Any]
|
aggregated_result: dict[str, Any]
|
||||||
status: TaskStatus
|
status: TaskStatus
|
||||||
total_duration_ms: float
|
total_duration_ms: float
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class OrchestratorConfig:
    """Settings that control the Orchestrator's adaptive execution mode."""

    # Switches the evaluate/re-execute loop of execute_adaptive() on or off.
    adaptive: bool = False
    # Upper bound on the number of evaluate/re-execute rounds.
    max_iterations: int = 3
    # A quality score at or above this value is accepted as good enough.
    quality_threshold: float = 0.7
|
||||||
|
|
||||||
|
|
||||||
class Orchestrator:
|
class Orchestrator:
|
||||||
|
|
@ -95,6 +105,7 @@ class Orchestrator:
|
||||||
llm_gateway: Any = None,
|
llm_gateway: Any = None,
|
||||||
max_parallel: int = 5,
|
max_parallel: int = 5,
|
||||||
subtask_timeout: float = 300.0,
|
subtask_timeout: float = 300.0,
|
||||||
|
config: OrchestratorConfig | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -103,12 +114,14 @@ class Orchestrator:
|
||||||
llm_gateway: LLM Gateway,用于任务分解
|
llm_gateway: LLM Gateway,用于任务分解
|
||||||
max_parallel: 最大并行子任务数
|
max_parallel: 最大并行子任务数
|
||||||
subtask_timeout: 子任务超时时间(秒)
|
subtask_timeout: 子任务超时时间(秒)
|
||||||
|
config: Orchestrator 配置,包含自适应参数
|
||||||
"""
|
"""
|
||||||
self._agent_pool = agent_pool
|
self._agent_pool = agent_pool
|
||||||
self._workspace = workspace or SharedWorkspace()
|
self._workspace = workspace or SharedWorkspace()
|
||||||
self._llm_gateway = llm_gateway
|
self._llm_gateway = llm_gateway
|
||||||
self._max_parallel = max_parallel
|
self._max_parallel = max_parallel
|
||||||
self._subtask_timeout = subtask_timeout
|
self._subtask_timeout = subtask_timeout
|
||||||
|
self._config = config or OrchestratorConfig()
|
||||||
|
|
||||||
async def execute(self, task: TaskMessage) -> OrchestrationResult:
|
async def execute(self, task: TaskMessage) -> OrchestrationResult:
|
||||||
"""执行编排任务
|
"""执行编排任务
|
||||||
|
|
@ -404,3 +417,258 @@ class Orchestrator:
|
||||||
aggregated["partial_success"] = True
|
aggregated["partial_success"] = True
|
||||||
|
|
||||||
return aggregated
|
return aggregated
|
||||||
|
|
||||||
|
async def execute_adaptive(
    self, task: TaskMessage,
) -> OrchestrationResult:
    """Adaptive orchestration: execute, evaluate, then re-execute failures.

    The task is first run through ``execute()``. When adaptive mode is
    enabled, the result is scored with ``_evaluate_quality()``; while the
    score stays below ``quality_threshold`` and fewer than
    ``max_iterations`` rounds have run, the unfinished subtasks are
    re-executed via ``_reexecute_failed()`` and the merged result is
    scored again.

    Args:
        task: The original task message.

    Returns:
        OrchestrationResult: The last result produced. In adaptive mode
        its ``metadata`` carries ``iterations`` (one entry per
        evaluate/re-execute round), ``quality_score`` (score of the
        returned result) and ``adaptive_duration_ms`` (wall-clock time of
        the whole adaptive run).
    """
    import time as _time

    start_time = _time.monotonic()
    iteration_history: list[dict[str, Any]] = []
    threshold = self._config.quality_threshold

    def _finalize(final: OrchestrationResult, score: float) -> OrchestrationResult:
        # Single exit point so every adaptive return carries the same metadata.
        final.metadata["quality_score"] = score
        final.metadata["iterations"] = iteration_history
        final.metadata["adaptive_duration_ms"] = (
            (_time.monotonic() - start_time) * 1000
        )
        return final

    # First execution
    current_result = await self.execute(task)

    # Only a disabled adaptive mode short-circuits. A COMPLETED status may
    # still hide failed subtasks or poor output, so it must be evaluated.
    if not self._config.adaptive:
        return current_result

    for iteration in range(1, self._config.max_iterations + 1):
        quality = await self._evaluate_quality(task, current_result)
        iteration_history.append({
            "iteration": iteration,
            "quality_score": quality["score"],
            "feedback": quality.get("feedback", ""),
        })

        if quality["score"] >= threshold:
            logger.info(
                f"Adaptive iteration {iteration}: quality "
                f"{quality['score']:.2f} >= {self._config.quality_threshold}"
            )
            return _finalize(current_result, quality["score"])

        logger.info(
            f"Adaptive iteration {iteration}: quality "
            f"{quality['score']:.2f} < {self._config.quality_threshold}, "
            f"re-decomposing failed subtasks"
        )

        new_result = await self._reexecute_failed(
            task, current_result, quality,
        )
        if new_result is current_result:
            # _reexecute_failed hands back its input untouched when there is
            # nothing left to retry; further rounds would only repeat the
            # same evaluation.
            return _finalize(current_result, quality["score"])
        current_result = new_result

    # Iterations exhausted: the last retry has not been scored yet, so
    # evaluate it once to report the quality of what is actually returned.
    final_quality = await self._evaluate_quality(task, current_result)
    return _finalize(current_result, final_quality["score"])
|
||||||
|
|
||||||
|
async def _evaluate_quality(
    self,
    task: TaskMessage,
    result: OrchestrationResult,
) -> dict[str, Any]:
    """Score the quality of the subtask results.

    The LLM evaluator is used when a gateway is configured; without a
    gateway, or when the LLM evaluation raises, the rule-based score is
    returned instead.

    Returns:
        Dict with "score" (0-1) and optional "feedback" string.
    """
    if self._llm_gateway is not None:
        try:
            return await self._llm_evaluate(task, result)
        except Exception as e:
            # Best-effort: an evaluator failure must not abort orchestration.
            logger.warning(f"LLM evaluation failed, falling back to rule-based: {e}")

    return self._rule_based_evaluate(result)
|
||||||
|
|
||||||
|
def _rule_based_evaluate(
    self, result: OrchestrationResult,
) -> dict[str, Any]:
    """Rule-based quality score: the fraction of subtasks that completed.

    Returns a dict with "score" (completed / total, or 0.0 when there are
    no subtasks) and "feedback" listing the ids of subtasks whose status
    is anything other than "completed".
    """
    outcomes = result.subtask_results
    subtask_count = len(outcomes)
    if subtask_count == 0:
        return {"score": 0.0, "feedback": "No subtasks executed"}

    unfinished = [
        tid for tid, outcome in outcomes.items()
        if outcome.get("status") != "completed"
    ]
    score = (subtask_count - len(unfinished)) / subtask_count
    feedback = f"Failed subtasks: {unfinished}" if score < 1.0 else ""
    return {"score": score, "feedback": feedback}
|
||||||
|
|
||||||
|
async def _llm_evaluate(
    self,
    task: TaskMessage,
    result: OrchestrationResult,
) -> dict[str, Any]:
    """Ask the LLM gateway to score the subtask results.

    A compact summary of every subtask (id, status, first 200 characters
    of its output) is sent together with the original task input. The
    reply is expected to contain a JSON object with "score" and
    "feedback".

    Returns:
        Dict with "score" clamped to the range 0-1 and "feedback". When
        the reply cannot be parsed into a usable score, the rule-based
        evaluation of ``result`` is returned instead.
    """
    import json
    import math

    subtask_summary = [
        {
            "task_id": tid,
            "status": r.get("status", "unknown"),
            "output_preview": str(r.get("output", ""))[:200],
        }
        for tid, r in result.subtask_results.items()
    ]

    prompt = (
        f"Evaluate the quality of the following orchestration result.\n\n"
        f"Original task: {task.input_data}\n"
        f"Subtask results:\n{json.dumps(subtask_summary, ensure_ascii=False)}\n\n"
        f'Respond ONLY with JSON: {{"score": 0.0-1.0, "feedback": "..."}}\n'
        f"Score 1.0 = perfect, 0.0 = completely failed."
    )

    response = await self._llm_gateway.chat(
        messages=[{"role": "user", "content": prompt}],
        model="default",
    )

    try:
        text = (response.content or "").strip()
        # Take the outermost {...} span instead of slicing fence lines: this
        # survives a missing closing fence, a single-line fenced reply and
        # explanatory text around the JSON.
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            text = text[start:end + 1]

        data = json.loads(text)
        if not isinstance(data, dict):
            raise ValueError("LLM evaluation is not a JSON object")

        score = float(data.get("score", 0.0))
        if math.isnan(score):
            # json.loads accepts NaN; it would make every threshold
            # comparison False, so treat it as unparseable.
            raise ValueError("LLM evaluation score is NaN")

        return {
            # The contract is a 0-1 score; models occasionally overshoot.
            "score": min(1.0, max(0.0, score)),
            "feedback": data.get("feedback", ""),
        }
    except (json.JSONDecodeError, ValueError, TypeError) as e:
        logger.warning(f"Failed to parse LLM evaluation: {e}")
        return self._rule_based_evaluate(result)
|
||||||
|
|
||||||
|
async def _reexecute_failed(
    self,
    task: TaskMessage,
    previous_result: OrchestrationResult,
    quality: dict[str, Any],
) -> OrchestrationResult:
    """Re-run the subtasks that did not complete and merge the outcomes.

    Every subtask in ``previous_result`` whose status is not "completed"
    is re-submitted as a ``retry-<id>`` subtask. Its input is the parent
    task's input plus the previous error and the evaluator feedback.
    Retry results are stored back under the original subtask id.

    Args:
        task: The original (parent) task message.
        previous_result: Result of the previous round.
        quality: Evaluation dict; its "feedback" is passed to the retries.

    Returns:
        ``previous_result`` itself when nothing needs a retry; otherwise a
        new OrchestrationResult with merged subtask results, a fresh
        aggregation, and a duration that includes the previous round.
    """
    import time as _time

    start_time = _time.monotonic()
    retry_prefix = "retry-"

    # Identify failed subtasks
    failed_task_ids = [
        tid for tid, r in previous_result.subtask_results.items()
        if r.get("status") != "completed"
    ]
    if not failed_task_ids:
        return previous_result

    feedback = quality.get("feedback", "")

    # NOTE(review): retries reuse the parent task's agent, type and input
    # because the original SubTask definitions are not available on
    # OrchestrationResult — confirm this is the intended retry scope.
    new_subtasks = [
        SubTask(
            task_id=f"{retry_prefix}{tid}",
            parent_task_id=task.task_id,
            assigned_agent=task.agent_name,
            task_type=task.task_type,
            input_data={
                **task.input_data,
                "previous_error": previous_result.subtask_results[tid].get("error", ""),
                "improvement_feedback": feedback,
            },
        )
        for tid in failed_task_ids
    ]

    # Build a mini-plan that runs all retry subtasks in one parallel group
    plan = OrchestrationPlan(
        plan_id=f"retry-{previous_result.plan_id}",
        parent_task_id=task.task_id,
        subtasks=new_subtasks,
        parallel_groups=[[st.task_id for st in new_subtasks]],
    )

    retry_results = await self._execute_plan(plan, task)

    # Start from ALL previous results so a failed subtask whose retry
    # produced no result stays recorded as failed instead of vanishing
    # (which would inflate the completion ratio), then overlay retries.
    merged_results = dict(previous_result.subtask_results)
    for tid, r in retry_results.items():
        # Strip only a leading marker; ids that merely contain "retry-"
        # elsewhere must not be altered.
        merged_results[tid.removeprefix(retry_prefix)] = r

    # Re-aggregate over synthetic subtasks that mirror the merged outcomes
    all_subtasks = [
        SubTask(
            task_id=tid,
            parent_task_id=task.task_id,
            assigned_agent=task.agent_name,
            task_type=task.task_type,
            input_data=task.input_data,
            status=(
                SubTaskStatus.COMPLETED
                if r.get("status") == "completed"
                else SubTaskStatus.FAILED
            ),
            result=r.get("output"),
        )
        for tid, r in merged_results.items()
    ]

    retry_plan = OrchestrationPlan(
        plan_id=plan.plan_id,
        parent_task_id=task.task_id,
        subtasks=all_subtasks,
        parallel_groups=[],
    )

    aggregated = await self._aggregate_results(retry_plan, merged_results, task)

    failed_count = sum(
        1 for r in merged_results.values() if r.get("status") != "completed"
    )
    # Partial success still counts as COMPLETED; only a total failure fails.
    status = (
        TaskStatus.FAILED
        if failed_count == len(merged_results)
        else TaskStatus.COMPLETED
    )

    # The merged result contains work from the previous round as well, so
    # report cumulative time rather than just this retry round.
    duration_ms = (
        previous_result.total_duration_ms
        + (_time.monotonic() - start_time) * 1000
    )

    return OrchestrationResult(
        plan_id=plan.plan_id,
        parent_task_id=task.task_id,
        subtask_results=merged_results,
        aggregated_result=aggregated,
        status=status,
        total_duration_ms=duration_ms,
    )
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,236 @@
|
||||||
|
"""Tests for Orchestrator adaptive task decomposition (U5)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from agentkit.core.orchestrator import (
|
||||||
|
Orchestrator,
|
||||||
|
OrchestratorConfig,
|
||||||
|
OrchestrationResult,
|
||||||
|
SubTaskStatus,
|
||||||
|
)
|
||||||
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test Helpers ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_task(**overrides) -> TaskMessage:
    """Build a TaskMessage from test defaults; keyword arguments override them."""
    fields = dict(
        task_id="task-001",
        agent_name="test_agent",
        task_type="analyze",
        priority=0,
        input_data={"query": "test"},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=60,
    )
    return TaskMessage(**{**fields, **overrides})
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_pool(agents: dict[str, AsyncMock] | None = None):
    """Create a mock AgentPool.

    With a non-empty ``agents`` mapping, the pool looks agents up by name
    and lists one worker entry per key. Otherwise it serves a single
    default agent whose ``execute`` resolves to an object with
    ``output_data == {"result": "done"}``.
    """
    pool = MagicMock()

    if not agents:
        # Default: single agent that succeeds, returned for any name
        outcome = MagicMock()
        outcome.output_data = {"result": "done"}
        default_agent = AsyncMock()
        default_agent.execute = AsyncMock(return_value=outcome)

        pool.get_agent = lambda name: default_agent
        pool.list_agents = lambda: [
            {"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
        ]
        return pool

    def _describe_all():
        listing = []
        for agent_name in agents:
            listing.append({
                "name": agent_name,
                "agent_type": "worker",
                "description": f"Agent {agent_name}",
            })
        return listing

    pool.get_agent = lambda name: agents.get(name)
    pool.list_agents = _describe_all
    return pool
|
||||||
|
|
||||||
|
|
||||||
|
# ── OrchestratorConfig Tests ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestratorConfig:
    """Checks the defaults and explicit construction of OrchestratorConfig."""

    def test_default_values(self):
        """A bare config has adaptive mode off, 3 iterations, 0.7 threshold."""
        cfg = OrchestratorConfig()
        assert cfg.adaptive is False
        assert cfg.max_iterations == 3
        assert cfg.quality_threshold == 0.7

    def test_custom_values(self):
        """Explicit keyword arguments are stored unchanged."""
        settings = {
            "adaptive": True,
            "max_iterations": 5,
            "quality_threshold": 0.9,
        }
        cfg = OrchestratorConfig(**settings)
        assert cfg.adaptive is True
        assert cfg.max_iterations == 5
        assert cfg.quality_threshold == 0.9
|
||||||
|
|
||||||
|
|
||||||
|
# ── Adaptive Execution Tests ─────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestratorAdaptive:
    """Tests for Orchestrator.execute_adaptive and its helper methods."""

    @staticmethod
    def _result(subtask_results, status=TaskStatus.COMPLETED, parent_task_id="t1"):
        """Build an OrchestrationResult fixture with fixed plan id and duration."""
        return OrchestrationResult(
            plan_id="p1",
            parent_task_id=parent_task_id,
            subtask_results=subtask_results,
            aggregated_result={},
            status=status,
            total_duration_ms=100,
        )

    @staticmethod
    def _adaptive_orchestrator(pool=None, **config_overrides):
        """Build an Orchestrator with adaptive mode on and no LLM gateway."""
        config = OrchestratorConfig(adaptive=True, **config_overrides)
        return Orchestrator(agent_pool=pool or _make_mock_pool(), config=config)

    @pytest.mark.asyncio
    async def test_adaptive_false_behaves_like_execute(self):
        """When adaptive=False, execute_adaptive should behave like execute."""
        pool = _make_mock_pool()
        config = OrchestratorConfig(adaptive=False)
        orch = Orchestrator(agent_pool=pool, config=config)

        task = _make_task()
        result = await orch.execute_adaptive(task)
        assert result.status == TaskStatus.COMPLETED
        # No adaptive round ran, so no iteration history may be attached.
        assert "iterations" not in result.metadata

    # The rule-based evaluator is synchronous, so these tests need no
    # event loop.
    def test_rule_based_evaluate_all_completed(self):
        """All completed subtasks should score 1.0."""
        orch = self._adaptive_orchestrator()
        result = self._result({
            "st1": {"status": "completed", "output": "ok"},
            "st2": {"status": "completed", "output": "ok"},
        })

        quality = orch._rule_based_evaluate(result)
        assert quality["score"] == 1.0
        assert quality["feedback"] == ""

    def test_rule_based_evaluate_partial(self):
        """Partial completion should score proportionally."""
        orch = self._adaptive_orchestrator()
        result = self._result({
            "st1": {"status": "completed", "output": "ok"},
            "st2": {"status": "failed", "error": "bad"},
        })

        quality = orch._rule_based_evaluate(result)
        assert quality["score"] == 0.5
        assert "st2" in quality["feedback"]
        assert "st1" not in quality["feedback"]

    def test_rule_based_evaluate_empty(self):
        """No subtasks should score 0.0."""
        orch = self._adaptive_orchestrator()
        result = self._result({}, status=TaskStatus.FAILED)

        quality = orch._rule_based_evaluate(result)
        assert quality["score"] == 0.0
        assert quality["feedback"] == "No subtasks executed"

    @pytest.mark.asyncio
    async def test_adaptive_first_round_pass(self):
        """If first round quality passes, return directly."""
        orch = self._adaptive_orchestrator(quality_threshold=0.5)

        task = _make_task()
        result = await orch.execute_adaptive(task)
        # All subtasks complete in mock, so quality = 1.0 >= 0.5
        assert result.status == TaskStatus.COMPLETED
        # If a score was recorded, it must not be below the threshold.
        assert result.metadata.get("quality_score", 1.0) >= 0.5

    def test_orchestration_result_metadata(self):
        """OrchestrationResult should have metadata field."""
        first = self._result({})
        second = self._result({})
        assert first.metadata == {}
        # default_factory must give every instance its own dict.
        assert first.metadata is not second.metadata

    @pytest.mark.asyncio
    async def test_reexecute_failed_preserves_completed(self):
        """_reexecute_failed should keep completed subtask results."""
        orch = self._adaptive_orchestrator()

        task = _make_task()
        previous = self._result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "failed", "error": "bad"},
            },
            parent_task_id="task-001",
        )
        quality = {"score": 0.5, "feedback": "Fix st2"}

        result = await orch._reexecute_failed(task, previous, quality)
        # st1 should be preserved
        assert "st1" in result.subtask_results
        assert result.subtask_results["st1"]["status"] == "completed"
        assert result.subtask_results["st1"]["output"] == "ok"
        # A retry round produces a new result tied to the same parent task.
        assert result is not previous
        assert result.parent_task_id == task.task_id
        assert result.plan_id == "retry-p1"

    @pytest.mark.asyncio
    async def test_reexecute_failed_without_failures_is_noop(self):
        """With nothing to retry, the previous result is returned as-is."""
        orch = self._adaptive_orchestrator()
        previous = self._result(
            {"st1": {"status": "completed", "output": "ok"}},
            parent_task_id="task-001",
        )

        result = await orch._reexecute_failed(
            _make_task(), previous, {"score": 1.0, "feedback": ""},
        )
        assert result is previous

    @pytest.mark.asyncio
    async def test_max_iterations_respected(self):
        """Adaptive loop should not exceed max_iterations."""
        # Create a pool where agent always fails
        mock_agent = AsyncMock()
        mock_agent.execute = AsyncMock(side_effect=RuntimeError("always fails"))

        pool = MagicMock()
        pool.get_agent = lambda name: mock_agent
        pool.list_agents = lambda: [
            {"name": "test_agent", "agent_type": "worker", "description": "Test"}
        ]

        max_iterations = 2
        orch = self._adaptive_orchestrator(
            pool,
            max_iterations=max_iterations,
            quality_threshold=0.9,
        )

        task = _make_task()
        result = await orch.execute_adaptive(task)
        assert result.status in (TaskStatus.FAILED, TaskStatus.COMPLETED)

        # The recorded history is the observable bound on loop rounds.
        history = result.metadata.get("iterations", [])
        assert len(history) <= max_iterations
        assert [entry["iteration"] for entry in history] == list(
            range(1, len(history) + 1)
        )
|
||||||
Loading…
Reference in New Issue