"""Tests for Orchestrator adaptive task decomposition (U5).""" import pytest from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock from agentkit.core.orchestrator import ( Orchestrator, OrchestratorConfig, OrchestrationResult, SubTaskStatus, ) from agentkit.core.protocol import TaskMessage, TaskStatus # ── Test Helpers ────────────────────────────────────────── def _make_task(**overrides) -> TaskMessage: defaults = { "task_id": "task-001", "agent_name": "test_agent", "task_type": "analyze", "priority": 0, "input_data": {"query": "test"}, "callback_url": None, "created_at": datetime.now(timezone.utc), "timeout_seconds": 60, } defaults.update(overrides) return TaskMessage(**defaults) def _make_mock_pool(agents: dict[str, AsyncMock] | None = None): """Create a mock AgentPool.""" pool = MagicMock() if agents: pool.get_agent = lambda name: agents.get(name) pool.list_agents = lambda: [ {"name": name, "agent_type": "worker", "description": f"Agent {name}"} for name in agents ] else: # Default: single agent that succeeds mock_agent = AsyncMock() mock_result = MagicMock() mock_result.output_data = {"result": "done"} mock_agent.execute = AsyncMock(return_value=mock_result) pool.get_agent = lambda name: mock_agent pool.list_agents = lambda: [ {"name": "test_agent", "agent_type": "worker", "description": "Test agent"} ] return pool # ── OrchestratorConfig Tests ────────────────────────────── class TestOrchestratorConfig: def test_default_values(self): config = OrchestratorConfig() assert config.adaptive is False assert config.max_iterations == 3 assert config.quality_threshold == 0.7 def test_custom_values(self): config = OrchestratorConfig( adaptive=True, max_iterations=5, quality_threshold=0.9, ) assert config.adaptive is True assert config.max_iterations == 5 assert config.quality_threshold == 0.9 # ── Adaptive Execution Tests ───────────────────────────── class TestOrchestratorAdaptive: @pytest.mark.asyncio async def test_adaptive_false_behaves_like_execute(self): """When adaptive=False, execute_adaptive should behave like execute.""" pool = _make_mock_pool() config = OrchestratorConfig(adaptive=False) orch = Orchestrator(agent_pool=pool, config=config) task = _make_task() result = await orch.execute_adaptive(task) assert result.status == TaskStatus.COMPLETED @pytest.mark.asyncio async def test_rule_based_evaluate_all_completed(self): """All completed subtasks should score 1.0.""" pool = _make_mock_pool() config = OrchestratorConfig(adaptive=True) orch = Orchestrator(agent_pool=pool, config=config) result = OrchestrationResult( plan_id="p1", parent_task_id="t1", subtask_results={ "st1": {"status": "completed", "output": "ok"}, "st2": {"status": "completed", "output": "ok"}, }, aggregated_result={}, status=TaskStatus.COMPLETED, total_duration_ms=100, ) quality = orch._rule_based_evaluate(result) assert quality["score"] == 1.0 @pytest.mark.asyncio async def test_rule_based_evaluate_partial(self): """Partial completion should score proportionally.""" pool = _make_mock_pool() config = OrchestratorConfig(adaptive=True) orch = Orchestrator(agent_pool=pool, config=config) result = OrchestrationResult( plan_id="p1", parent_task_id="t1", subtask_results={ "st1": {"status": "completed", "output": "ok"}, "st2": {"status": "failed", "error": "bad"}, }, aggregated_result={}, status=TaskStatus.COMPLETED, total_duration_ms=100, ) quality = orch._rule_based_evaluate(result) assert quality["score"] == 0.5 assert "st2" in quality["feedback"] @pytest.mark.asyncio async def test_rule_based_evaluate_empty(self): """No subtasks should score 0.0.""" pool = _make_mock_pool() config = OrchestratorConfig(adaptive=True) orch = Orchestrator(agent_pool=pool, config=config) result = OrchestrationResult( plan_id="p1", parent_task_id="t1", subtask_results={}, aggregated_result={}, status=TaskStatus.FAILED, total_duration_ms=100, ) quality = orch._rule_based_evaluate(result) assert quality["score"] == 0.0 @pytest.mark.asyncio async def test_adaptive_first_round_pass(self): """If first round quality passes, return directly.""" pool = _make_mock_pool() config = OrchestratorConfig( adaptive=True, quality_threshold=0.5, ) orch = Orchestrator(agent_pool=pool, config=config) task = _make_task() result = await orch.execute_adaptive(task) # All subtasks complete in mock, so quality = 1.0 >= 0.5 assert result.status == TaskStatus.COMPLETED @pytest.mark.asyncio async def test_orchestration_result_metadata(self): """OrchestrationResult should have metadata field.""" result = OrchestrationResult( plan_id="p1", parent_task_id="t1", subtask_results={}, aggregated_result={}, status=TaskStatus.COMPLETED, total_duration_ms=100, ) assert result.metadata == {} @pytest.mark.asyncio async def test_reexecute_failed_preserves_completed(self): """_reexecute_failed should keep completed subtask results.""" pool = _make_mock_pool() config = OrchestratorConfig(adaptive=True) orch = Orchestrator(agent_pool=pool, config=config) task = _make_task() previous = OrchestrationResult( plan_id="p1", parent_task_id="task-001", subtask_results={ "st1": {"status": "completed", "output": "ok"}, "st2": {"status": "failed", "error": "bad"}, }, aggregated_result={}, status=TaskStatus.COMPLETED, total_duration_ms=100, ) quality = {"score": 0.5, "feedback": "Fix st2"} result = await orch._reexecute_failed(task, previous, quality) # st1 should be preserved assert "st1" in result.subtask_results assert result.subtask_results["st1"]["status"] == "completed" @pytest.mark.asyncio async def test_max_iterations_respected(self): """Adaptive loop should not exceed max_iterations.""" # Create a pool where agent always fails mock_agent = AsyncMock() mock_agent.execute = AsyncMock(side_effect=RuntimeError("always fails")) pool = MagicMock() pool.get_agent = lambda name: mock_agent pool.list_agents = lambda: [ {"name": "test_agent", "agent_type": "worker", "description": "Test"} ] config = OrchestratorConfig( adaptive=True, max_iterations=2, quality_threshold=0.9, ) orch = Orchestrator(agent_pool=pool, config=config) task = _make_task() result = await orch.execute_adaptive(task) # Should have attempted iterations assert result.status in (TaskStatus.FAILED, TaskStatus.COMPLETED)