fischer-agentkit/tests/unit/test_orchestrator_adaptive.py

237 lines
7.9 KiB
Python

"""Tests for Orchestrator adaptive task decomposition (U5)."""
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock
from agentkit.core.orchestrator import (
Orchestrator,
OrchestratorConfig,
OrchestrationResult,
SubTaskStatus,
)
from agentkit.core.protocol import TaskMessage, TaskStatus
# ── Test Helpers ──────────────────────────────────────────
def _make_task(**overrides) -> TaskMessage:
    """Build a ``TaskMessage`` for use in tests.

    A baseline set of field values is defined below; every keyword
    argument passed by the caller replaces the baseline entry with the
    same name (or is added alongside it) before the message is built.

    Args:
        **overrides: Field values forwarded to ``TaskMessage`` in place
            of the baseline ones.

    Returns:
        The ``TaskMessage`` constructed from the merged field values.
    """
    fields = {
        "task_id": "task-001",
        "agent_name": "test_agent",
        "task_type": "analyze",
        "priority": 0,
        "input_data": {"query": "test"},
        "callback_url": None,
        # Timezone-aware timestamp, taken fresh on every call.
        "created_at": datetime.now(timezone.utc),
        "timeout_seconds": 60,
        # Caller-supplied values win over the baseline above.
        **overrides,
    }
    return TaskMessage(**fields)
def _make_mock_pool(agents: dict[str, AsyncMock] | None = None):
"""Create a mock AgentPool."""
pool = MagicMock()
if agents:
pool.get_agent = lambda name: agents.get(name)
pool.list_agents = lambda: [
{"name": name, "agent_type": "worker", "description": f"Agent {name}"}
for name in agents
]
else:
# Default: single agent that succeeds
mock_agent = AsyncMock()
mock_result = MagicMock()
mock_result.output_data = {"result": "done"}
mock_agent.execute = AsyncMock(return_value=mock_result)
pool.get_agent = lambda name: mock_agent
pool.list_agents = lambda: [
{"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
]
return pool
# ── OrchestratorConfig Tests ──────────────────────────────
class TestOrchestratorConfig:
    """Checks ``OrchestratorConfig`` defaults and explicit overrides."""

    def test_default_values(self):
        """A config built with no arguments exposes the expected defaults."""
        cfg = OrchestratorConfig()

        assert cfg.adaptive is False
        assert cfg.max_iterations == 3
        assert cfg.quality_threshold == 0.7

    def test_custom_values(self):
        """Values passed to the constructor are readable on the instance."""
        custom = {
            "adaptive": True,
            "max_iterations": 5,
            "quality_threshold": 0.9,
        }
        cfg = OrchestratorConfig(**custom)

        assert cfg.adaptive is True
        assert cfg.max_iterations == 5
        assert cfg.quality_threshold == 0.9
# ── Adaptive Execution Tests ─────────────────────────────
class TestOrchestratorAdaptive:
    """Tests for ``Orchestrator.execute_adaptive`` and its private helpers."""

    @staticmethod
    def _orchestrator(**config_kwargs):
        """Return an ``Orchestrator`` backed by the default mock pool.

        ``config_kwargs`` are forwarded unchanged to ``OrchestratorConfig``.
        """
        return Orchestrator(
            agent_pool=_make_mock_pool(),
            config=OrchestratorConfig(**config_kwargs),
        )

    @staticmethod
    def _result(subtask_results, *, status=TaskStatus.COMPLETED, parent_task_id="t1"):
        """Return an ``OrchestrationResult`` carrying ``subtask_results``.

        Plan id, aggregated result and duration are fixed; only the
        subtask results, overall status and parent task id vary per test.
        """
        return OrchestrationResult(
            plan_id="p1",
            parent_task_id=parent_task_id,
            subtask_results=subtask_results,
            aggregated_result={},
            status=status,
            total_duration_ms=100,
        )

    @pytest.mark.asyncio
    async def test_adaptive_false_behaves_like_execute(self):
        """With adaptive=False, execute_adaptive still completes the task."""
        orchestrator = self._orchestrator(adaptive=False)
        outcome = await orchestrator.execute_adaptive(_make_task())
        assert outcome.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_rule_based_evaluate_all_completed(self):
        """Two completed subtasks out of two should score 1.0."""
        orchestrator = self._orchestrator(adaptive=True)
        outcome = self._result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "completed", "output": "ok"},
            }
        )
        verdict = orchestrator._rule_based_evaluate(outcome)
        assert verdict["score"] == 1.0

    @pytest.mark.asyncio
    async def test_rule_based_evaluate_partial(self):
        """One completed and one failed subtask should score 0.5."""
        orchestrator = self._orchestrator(adaptive=True)
        outcome = self._result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "failed", "error": "bad"},
            }
        )
        verdict = orchestrator._rule_based_evaluate(outcome)
        assert verdict["score"] == 0.5
        # The failed subtask's id must be mentioned in the feedback.
        assert "st2" in verdict["feedback"]

    @pytest.mark.asyncio
    async def test_rule_based_evaluate_empty(self):
        """A result without any subtasks should score 0.0."""
        orchestrator = self._orchestrator(adaptive=True)
        outcome = self._result({}, status=TaskStatus.FAILED)
        verdict = orchestrator._rule_based_evaluate(outcome)
        assert verdict["score"] == 0.0

    @pytest.mark.asyncio
    async def test_adaptive_first_round_pass(self):
        """When the first round meets the threshold, the task completes."""
        orchestrator = self._orchestrator(adaptive=True, quality_threshold=0.5)
        outcome = await orchestrator.execute_adaptive(_make_task())
        # The default mock agent always succeeds, so the expected quality
        # (1.0) clears the 0.5 threshold.
        assert outcome.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_orchestration_result_metadata(self):
        """OrchestrationResult exposes a metadata field that starts empty."""
        outcome = self._result({})
        assert outcome.metadata == {}

    @pytest.mark.asyncio
    async def test_reexecute_failed_preserves_completed(self):
        """_reexecute_failed keeps results of already-completed subtasks."""
        orchestrator = self._orchestrator(adaptive=True)
        task = _make_task()
        earlier = self._result(
            {
                "st1": {"status": "completed", "output": "ok"},
                "st2": {"status": "failed", "error": "bad"},
            },
            parent_task_id="task-001",
        )
        verdict = {"score": 0.5, "feedback": "Fix st2"}

        outcome = await orchestrator._reexecute_failed(task, earlier, verdict)

        # st1 was completed before the retry and must survive it untouched.
        assert "st1" in outcome.subtask_results
        assert outcome.subtask_results["st1"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_max_iterations_respected(self):
        """The adaptive loop terminates when every agent call fails."""
        # An agent whose execute() raises on every call.
        failing_agent = AsyncMock()
        failing_agent.execute = AsyncMock(side_effect=RuntimeError("always fails"))

        failing_pool = MagicMock()
        failing_pool.get_agent = lambda name: failing_agent
        failing_pool.list_agents = lambda: [
            {"name": "test_agent", "agent_type": "worker", "description": "Test"}
        ]

        orchestrator = Orchestrator(
            agent_pool=failing_pool,
            config=OrchestratorConfig(
                adaptive=True,
                max_iterations=2,
                quality_threshold=0.9,
            ),
        )
        outcome = await orchestrator.execute_adaptive(_make_task())

        # NOTE(review): this only checks that a terminal status came back;
        # it does not count how many iterations were actually run.
        assert outcome.status in (TaskStatus.FAILED, TaskStatus.COMPLETED)