237 lines
7.9 KiB
Python
237 lines
7.9 KiB
Python
"""Tests for Orchestrator adaptive task decomposition (U5)."""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from agentkit.core.orchestrator import (
|
|
Orchestrator,
|
|
OrchestratorConfig,
|
|
OrchestrationResult,
|
|
SubTaskStatus,
|
|
)
|
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
|
|
|
|
|
# ── Test Helpers ──────────────────────────────────────────
|
|
|
|
|
|
def _make_task(**overrides) -> TaskMessage:
    """Build a ``TaskMessage`` for tests from baseline fields plus overrides.

    Args:
        **overrides: Field values that replace the baseline value of the
            same name (or add a further field) before construction.

    Returns:
        The ``TaskMessage`` built from the merged fields.
    """
    baseline = dict(
        task_id="task-001",
        agent_name="test_agent",
        task_type="analyze",
        priority=0,
        input_data={"query": "test"},
        callback_url=None,
        # Timezone-aware "now", evaluated each time the helper is called.
        created_at=datetime.now(timezone.utc),
        timeout_seconds=60,
    )
    # Later entries win, so caller-supplied overrides replace the baseline.
    return TaskMessage(**{**baseline, **overrides})
|
|
|
|
|
|
def _make_mock_pool(agents: dict[str, AsyncMock] | None = None):
|
|
"""Create a mock AgentPool."""
|
|
pool = MagicMock()
|
|
|
|
if agents:
|
|
pool.get_agent = lambda name: agents.get(name)
|
|
pool.list_agents = lambda: [
|
|
{"name": name, "agent_type": "worker", "description": f"Agent {name}"}
|
|
for name in agents
|
|
]
|
|
else:
|
|
# Default: single agent that succeeds
|
|
mock_agent = AsyncMock()
|
|
mock_result = MagicMock()
|
|
mock_result.output_data = {"result": "done"}
|
|
mock_agent.execute = AsyncMock(return_value=mock_result)
|
|
|
|
pool.get_agent = lambda name: mock_agent
|
|
pool.list_agents = lambda: [
|
|
{"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
|
|
]
|
|
|
|
return pool
|
|
|
|
|
|
# ── OrchestratorConfig Tests ──────────────────────────────
|
|
|
|
|
|
class TestOrchestratorConfig:
    """Checks the default and explicitly supplied values of ``OrchestratorConfig``."""

    def test_default_values(self):
        """A config built without arguments exposes the expected defaults."""
        cfg = OrchestratorConfig()

        assert cfg.adaptive is False
        assert cfg.max_iterations == 3
        assert cfg.quality_threshold == 0.7

    def test_custom_values(self):
        """Values passed to the constructor read back with the same values."""
        settings = {
            "adaptive": True,
            "max_iterations": 5,
            "quality_threshold": 0.9,
        }
        cfg = OrchestratorConfig(**settings)

        assert cfg.adaptive is True
        assert cfg.max_iterations == 5
        assert cfg.quality_threshold == 0.9
|
|
|
|
|
|
# ── Adaptive Execution Tests ─────────────────────────────
|
|
|
|
|
|
class TestOrchestratorAdaptive:
    """Tests for ``Orchestrator.execute_adaptive`` and the helpers behind it.

    Every test is an asyncio-marked coroutine, including those that only
    call synchronous code.
    NOTE(review): the tests that never ``await`` could probably be plain
    functions — confirm that constructing an Orchestrator does not need a
    running event loop before changing them.
    """

    @staticmethod
    def _build_orchestrator(pool=None, **config_kwargs):
        """Return an Orchestrator over *pool*, or over the default mock pool.

        ``config_kwargs`` are forwarded to ``OrchestratorConfig``.
        """
        agent_pool = _make_mock_pool() if pool is None else pool
        settings = OrchestratorConfig(**config_kwargs)
        return Orchestrator(agent_pool=agent_pool, config=settings)

    @staticmethod
    def _build_result(subtask_results, status, parent_task_id="t1"):
        """Return an ``OrchestrationResult`` for plan ``"p1"`` lasting 100 ms."""
        return OrchestrationResult(
            plan_id="p1",
            parent_task_id=parent_task_id,
            subtask_results=subtask_results,
            aggregated_result={},
            status=status,
            total_duration_ms=100,
        )

    @pytest.mark.asyncio
    async def test_adaptive_false_behaves_like_execute(self):
        """When adaptive=False, execute_adaptive should behave like execute."""
        orchestrator = self._build_orchestrator(adaptive=False)

        outcome = await orchestrator.execute_adaptive(_make_task())

        assert outcome.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_rule_based_evaluate_all_completed(self):
        """All completed subtasks should score 1.0."""
        orchestrator = self._build_orchestrator(adaptive=True)
        all_done = self._build_result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "completed", "output": "ok"},
            },
            TaskStatus.COMPLETED,
        )

        verdict = orchestrator._rule_based_evaluate(all_done)

        assert verdict["score"] == 1.0

    @pytest.mark.asyncio
    async def test_rule_based_evaluate_partial(self):
        """Partial completion should score proportionally."""
        orchestrator = self._build_orchestrator(adaptive=True)
        half_done = self._build_result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "failed", "error": "bad"},
            },
            TaskStatus.COMPLETED,
        )

        verdict = orchestrator._rule_based_evaluate(half_done)

        # One of two subtasks completed; the failed one must be named
        # in the feedback.
        assert verdict["score"] == 0.5
        assert "st2" in verdict["feedback"]

    @pytest.mark.asyncio
    async def test_rule_based_evaluate_empty(self):
        """No subtasks should score 0.0."""
        orchestrator = self._build_orchestrator(adaptive=True)
        nothing_ran = self._build_result({}, TaskStatus.FAILED)

        verdict = orchestrator._rule_based_evaluate(nothing_ran)

        assert verdict["score"] == 0.0

    @pytest.mark.asyncio
    async def test_adaptive_first_round_pass(self):
        """If first round quality passes, return directly."""
        orchestrator = self._build_orchestrator(
            adaptive=True,
            quality_threshold=0.5,
        )

        outcome = await orchestrator.execute_adaptive(_make_task())

        # The default mock agent always succeeds, so the first round is
        # expected to clear the 0.5 threshold.
        assert outcome.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_orchestration_result_metadata(self):
        """OrchestrationResult should have metadata field."""
        bare = self._build_result({}, TaskStatus.COMPLETED)

        # No metadata was passed, so the field must read back as an empty dict.
        assert bare.metadata == {}

    @pytest.mark.asyncio
    async def test_reexecute_failed_preserves_completed(self):
        """_reexecute_failed should keep completed subtask results."""
        orchestrator = self._build_orchestrator(adaptive=True)
        task = _make_task()
        earlier = self._build_result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "failed", "error": "bad"},
            },
            TaskStatus.COMPLETED,
            parent_task_id="task-001",
        )
        assessment = {"score": 0.5, "feedback": "Fix st2"}

        rerun = await orchestrator._reexecute_failed(task, earlier, assessment)

        # st1 should be preserved
        assert "st1" in rerun.subtask_results
        assert rerun.subtask_results["st1"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_max_iterations_respected(self):
        """Adaptive loop should not exceed max_iterations."""
        # Pool whose only agent raises on every execute() call.
        failing_agent = AsyncMock()
        failing_agent.execute = AsyncMock(side_effect=RuntimeError("always fails"))

        failing_pool = MagicMock()
        failing_pool.get_agent = lambda name: failing_agent
        failing_pool.list_agents = lambda: [
            {"name": "test_agent", "agent_type": "worker", "description": "Test"}
        ]

        orchestrator = self._build_orchestrator(
            failing_pool,
            adaptive=True,
            max_iterations=2,
            quality_threshold=0.9,
        )

        outcome = await orchestrator.execute_adaptive(_make_task())

        # NOTE(review): this only checks that the run returns with a FAILED
        # or COMPLETED status; it does not count iterations. Consider
        # asserting on failing_agent.execute's await count once the expected
        # number of calls per iteration is confirmed.
        assert outcome.status in (TaskStatus.FAILED, TaskStatus.COMPLETED)
|