# fischer-agentkit/tests/unit/test_orchestrator.py
# Source listing: 337 lines, 11 KiB, Python

"""Tests for Orchestrator and SharedWorkspace"""
import asyncio
import pytest
from agentkit.core.orchestrator import (
Orchestrator,
OrchestrationPlan,
OrchestrationResult,
SubTask,
SubTaskStatus,
AgentRole,
)
from agentkit.core.shared_workspace import SharedWorkspace
from agentkit.core.protocol import TaskMessage, TaskStatus
from datetime import datetime, timezone
# --- SharedWorkspace Tests ---
class TestSharedWorkspace:
    """SharedWorkspace unit tests (in-memory mode).

    Each test constructs its own ``SharedWorkspace()`` with default arguments.
    """

    @pytest.mark.asyncio
    async def test_write_and_read(self):
        """A first write returns version 1 and reads back with its metadata."""
        ws = SharedWorkspace()
        version = await ws.write("key1", {"data": "value"}, agent_id="agent_a")
        assert version == 1

        result = await ws.read("key1")
        assert result is not None
        assert result["value"] == {"data": "value"}
        assert result["agent_id"] == "agent_a"
        assert result["version"] == 1

    @pytest.mark.asyncio
    async def test_version_increments(self):
        """Writing the same key twice returns versions 1 and then 2."""
        ws = SharedWorkspace()
        v1 = await ws.write("key1", "first", agent_id="a")
        v2 = await ws.write("key1", "second", agent_id="b")
        assert v1 == 1
        assert v2 == 2

    @pytest.mark.asyncio
    async def test_read_nonexistent(self):
        """Reading a key that was never written returns ``None``."""
        ws = SharedWorkspace()
        result = await ws.read("nonexistent")
        assert result is None

    @pytest.mark.asyncio
    async def test_delete(self):
        """Deleting an existing key returns ``True`` and the key reads as ``None``."""
        ws = SharedWorkspace()
        await ws.write("key1", "value", agent_id="a")

        deleted = await ws.delete("key1")
        assert deleted is True

        result = await ws.read("key1")
        assert result is None

    @pytest.mark.asyncio
    async def test_delete_nonexistent(self):
        """Deleting a key that was never written returns ``False``."""
        ws = SharedWorkspace()
        deleted = await ws.delete("nonexistent")
        assert deleted is False

    @pytest.mark.asyncio
    async def test_lock_and_unlock(self):
        """A held lock blocks other agents until its owner releases it."""
        ws = SharedWorkspace()
        acquired = await ws.lock("resource1", agent_id="agent_a")
        assert acquired is True

        # A different agent cannot take the lock while agent_a holds it
        acquired2 = await ws.lock("resource1", agent_id="agent_b")
        assert acquired2 is False

        # Owner can unlock
        unlocked = await ws.unlock("resource1", agent_id="agent_a")
        assert unlocked is True

        # Now another agent can lock
        acquired3 = await ws.lock("resource1", agent_id="agent_b")
        assert acquired3 is True

    @pytest.mark.asyncio
    async def test_unlock_by_non_owner(self):
        """An agent that does not hold the lock cannot release it."""
        ws = SharedWorkspace()
        await ws.lock("resource1", agent_id="agent_a")
        unlocked = await ws.unlock("resource1", agent_id="agent_b")
        assert unlocked is False

    @pytest.mark.asyncio
    async def test_list_keys(self):
        """``list_keys`` reports exactly the keys that were written."""
        ws = SharedWorkspace()
        await ws.write("key1", "v1", agent_id="a")
        await ws.write("key2", "v2", agent_id="a")
        keys = await ws.list_keys()
        # Compared as a set: the test does not depend on listing order.
        assert set(keys) == {"key1", "key2"}
# --- Orchestrator Tests ---
class MockAgent:
    """Mock Agent for testing.

    Carries ``name``, ``agent_type`` and ``version`` attributes and an async
    ``execute`` that either raises or returns a canned completed result.
    """

    def __init__(self, name: str, output_data: dict | None = None, should_fail: bool = False):
        """Configure the canned behaviour of the mock.

        Args:
            name: Agent name; also used as ``agent_name`` in returned results.
            output_data: Payload returned by ``execute``. ``None`` selects the
                default ``{"result": "output from <name>"}``.
            should_fail: When true, ``execute`` raises ``RuntimeError``.
        """
        self.name = name
        self.agent_type = "mock"
        self.version = "1.0.0"
        # Compare against None rather than relying on truthiness, so that an
        # explicitly supplied empty dict is kept instead of being replaced by
        # the default payload.
        if output_data is None:
            output_data = {"result": f"output from {name}"}
        self._output_data = output_data
        self._should_fail = should_fail

    async def execute(self, task: TaskMessage):
        """Return a completed result for *task*, or raise if configured to fail.

        Args:
            task: The task being executed; only ``task.task_id`` is read.

        Returns:
            A ``TaskResult`` with status ``TaskStatus.COMPLETED``, the
            configured output data, no error message, and identical
            ``started_at`` / ``completed_at`` timestamps (UTC).

        Raises:
            RuntimeError: If the mock was created with ``should_fail=True``.
        """
        if self._should_fail:
            raise RuntimeError(f"Agent {self.name} failed")

        # TaskResult is only needed on the success path, so it is imported here.
        from agentkit.core.protocol import TaskResult

        now = datetime.now(timezone.utc)
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.COMPLETED,
            output_data=self._output_data,
            error_message=None,
            started_at=now,
            completed_at=now,
        )
class MockAgentPool:
    """In-memory stand-in for an AgentPool, keyed by agent name."""

    def __init__(self, agents: dict[str, MockAgent] | None = None):
        """Keep *agents* as the registry; ``None`` or an empty mapping gives a new empty dict."""
        self._agents = agents if agents else {}

    def get_agent(self, name: str) -> MockAgent | None:
        """Return the agent registered under *name*, or ``None`` if there is none."""
        registry = self._agents
        return registry.get(name)

    def list_agents(self) -> list[dict]:
        """Return one ``name`` / ``agent_type`` / ``description`` dict per registered agent."""
        summaries = []
        for agent in self._agents.values():
            summaries.append(
                {
                    "name": agent.name,
                    "agent_type": agent.agent_type,
                    "description": f"Mock agent {agent.name}",
                }
            )
        return summaries
class TestOrchestrator:
    """Orchestrator unit tests.

    The tests drive an ``Orchestrator`` built on a ``MockAgentPool``; several
    call private helpers (``_build_parallel_groups``, ``_execute_plan``,
    ``_inject_dependency_results``, ``_aggregate_results``) directly.
    """

    @staticmethod
    def _make_task(task_id: str, agent_name: str, task_type: str, input_data: dict) -> TaskMessage:
        """Build a priority-1 ``TaskMessage`` with no callback, created now (UTC)."""
        return TaskMessage(
            task_id=task_id,
            agent_name=agent_name,
            task_type=task_type,
            priority=1,
            input_data=input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    @pytest.mark.asyncio
    async def test_single_subtask_execution(self):
        """Single agent should execute task directly"""
        agent = MockAgent("worker1", {"analysis": "result"})
        pool = MockAgentPool({"worker1": agent})
        orchestrator = Orchestrator(agent_pool=pool)
        task = self._make_task("t1", "worker1", "analyze", {"query": "test"})

        result = await orchestrator.execute(task)

        assert result.status == TaskStatus.COMPLETED
        assert "outputs" in result.aggregated_result

    @pytest.mark.asyncio
    async def test_parallel_groups_building(self):
        """Parallel groups should be built from dependency graph"""
        pool = MockAgentPool()
        orchestrator = Orchestrator(agent_pool=pool)
        subtasks = [
            SubTask(task_id="t0", parent_task_id="p1", assigned_agent="a1",
                    task_type="type1", input_data={}, depends_on=[]),
            SubTask(task_id="t1", parent_task_id="p1", assigned_agent="a2",
                    task_type="type2", input_data={}, depends_on=[]),
            SubTask(task_id="t2", parent_task_id="p1", assigned_agent="a3",
                    task_type="type3", input_data={}, depends_on=["t0"]),
        ]

        groups = orchestrator._build_parallel_groups(subtasks)

        assert len(groups) == 2
        # First group: t0 and t1 (no dependencies)
        assert set(groups[0]) == {"t0", "t1"}
        # Second group: t2 (depends on t0)
        assert groups[1] == ["t2"]

    @pytest.mark.asyncio
    async def test_sequential_dependency(self):
        """Tasks with sequential dependencies should execute in order

        NOTE(review): the assertions only check that both subtasks report
        "completed"; execution order itself is not observed here.
        """
        agent1 = MockAgent("a1", {"step": 1})
        agent2 = MockAgent("a2", {"step": 2})
        pool = MockAgentPool({"a1": agent1, "a2": agent2})
        orchestrator = Orchestrator(agent_pool=pool)
        # Manually create a plan with sequential dependencies
        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1",
                        task_type="step1", input_data={}, depends_on=[]),
                SubTask(task_id="t1", parent_task_id="parent", assigned_agent="a2",
                        task_type="step2", input_data={}, depends_on=["t0"]),
            ],
            parallel_groups=[["t0"], ["t1"]],
        )
        task = self._make_task("parent", "orchestrator", "pipeline", {"query": "test"})

        results = await orchestrator._execute_plan(plan, task)

        assert results["t0"]["status"] == "completed"
        assert results["t1"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_agent_not_found(self):
        """Missing agent should result in failed subtask"""
        pool = MockAgentPool({})
        orchestrator = Orchestrator(agent_pool=pool)
        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                SubTask(task_id="t0", parent_task_id="parent", assigned_agent="missing_agent",
                        task_type="test", input_data={}, depends_on=[]),
            ],
            parallel_groups=[["t0"]],
        )
        task = self._make_task("parent", "orchestrator", "test", {})

        results = await orchestrator._execute_plan(plan, task)

        assert results["t0"]["status"] == "failed"
        assert "not found" in results["t0"]["error"]

    @pytest.mark.asyncio
    async def test_dependency_result_injection(self):
        """Subtask should receive dependency results in input"""
        pool = MockAgentPool()
        orchestrator = Orchestrator(agent_pool=pool)
        subtask = SubTask(
            task_id="t1",
            parent_task_id="p1",
            assigned_agent="a1",
            task_type="test",
            input_data={"query": "test"},
            depends_on=["t0"],
        )
        subtask_results = {
            "t0": {"status": "completed", "output": {"step1_result": "data"}},
        }

        enriched = orchestrator._inject_dependency_results(subtask, subtask_results)

        assert "dependency_results" in enriched
        assert "t0" in enriched["dependency_results"]

    @pytest.mark.asyncio
    async def test_aggregation_with_errors(self):
        """Aggregation should include errors for failed subtasks"""
        pool = MockAgentPool()
        orchestrator = Orchestrator(agent_pool=pool)
        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1",
                        task_type="test", input_data={}, depends_on=[]),
            ],
            parallel_groups=[["t0"]],
        )
        subtask_results = {
            "t0": {"status": "failed", "error": "Agent failed"},
        }
        task = self._make_task("parent", "orchestrator", "test", {})

        aggregated = await orchestrator._aggregate_results(plan, subtask_results, task)

        assert "errors" in aggregated
        assert aggregated["partial_success"] is True
class TestAgentRole:
    """Checks on the AgentRole enum."""

    def test_role_values(self):
        """Each role member carries its lowercase name as its value."""
        expected_values = (
            ("ORCHESTRATOR", "orchestrator"),
            ("WORKER", "worker"),
            ("REVIEWER", "reviewer"),
        )
        # Members are looked up one at a time so that a failure on an earlier
        # role is reported before later members are even accessed.
        for member_name, expected in expected_values:
            assert getattr(AgentRole, member_name).value == expected
class TestSubTask:
    """Checks on the SubTask dataclass."""

    def test_default_values(self):
        """Fields omitted at construction take their default values."""
        required_fields = {
            "task_id": "t1",
            "parent_task_id": "p1",
            "assigned_agent": "a1",
            "task_type": "test",
            "input_data": {},
        }
        subtask = SubTask(**required_fields)

        assert subtask.status == SubTaskStatus.PENDING
        assert subtask.result is None
        assert subtask.depends_on == []