337 lines
11 KiB
Python
337 lines
11 KiB
Python
"""Tests for Orchestrator and SharedWorkspace"""
|
|
|
|
import asyncio
|
|
import pytest
|
|
|
|
from agentkit.core.orchestrator import (
|
|
Orchestrator,
|
|
OrchestrationPlan,
|
|
OrchestrationResult,
|
|
SubTask,
|
|
SubTaskStatus,
|
|
AgentRole,
|
|
)
|
|
from agentkit.core.shared_workspace import SharedWorkspace
|
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
# --- SharedWorkspace Tests ---
|
|
|
|
|
|
class TestSharedWorkspace:
    """SharedWorkspace unit tests (in-memory mode).

    Every test constructs its own ``SharedWorkspace()`` with no arguments, so
    no state leaks between tests.
    """

    @pytest.mark.asyncio
    async def test_write_and_read(self):
        """A first write returns version 1 and reads back with its metadata."""
        ws = SharedWorkspace()
        version = await ws.write("key1", {"data": "value"}, agent_id="agent_a")
        assert version == 1

        result = await ws.read("key1")
        assert result is not None
        assert result["value"] == {"data": "value"}
        assert result["agent_id"] == "agent_a"
        assert result["version"] == 1

    @pytest.mark.asyncio
    async def test_version_increments(self):
        """Successive writes to one key bump its version; reads see the latest."""
        ws = SharedWorkspace()
        v1 = await ws.write("key1", "first", agent_id="a")
        v2 = await ws.write("key1", "second", agent_id="b")
        assert v1 == 1
        assert v2 == 2

        # The overwrite must be visible, not just counted.
        latest = await ws.read("key1")
        assert latest is not None
        assert latest["value"] == "second"
        assert latest["version"] == 2

    @pytest.mark.asyncio
    async def test_read_nonexistent(self):
        """Reading a key that was never written returns None."""
        ws = SharedWorkspace()
        result = await ws.read("nonexistent")
        assert result is None

    @pytest.mark.asyncio
    async def test_delete(self):
        """Deleting an existing key reports True and removes it."""
        ws = SharedWorkspace()
        await ws.write("key1", "value", agent_id="a")
        deleted = await ws.delete("key1")
        assert deleted is True
        result = await ws.read("key1")
        assert result is None

    @pytest.mark.asyncio
    async def test_delete_nonexistent(self):
        """Deleting a missing key reports False."""
        ws = SharedWorkspace()
        deleted = await ws.delete("nonexistent")
        assert deleted is False

    @pytest.mark.asyncio
    async def test_lock_and_unlock(self):
        """A lock is exclusive until its owner releases it."""
        ws = SharedWorkspace()
        acquired = await ws.lock("resource1", agent_id="agent_a")
        assert acquired is True

        # A different agent cannot acquire a lock that is already held.
        acquired2 = await ws.lock("resource1", agent_id="agent_b")
        assert acquired2 is False

        # Owner can unlock
        unlocked = await ws.unlock("resource1", agent_id="agent_a")
        assert unlocked is True

        # Now another agent can lock
        acquired3 = await ws.lock("resource1", agent_id="agent_b")
        assert acquired3 is True

    @pytest.mark.asyncio
    async def test_unlock_by_non_owner(self):
        """Only the agent holding a lock may release it."""
        ws = SharedWorkspace()
        await ws.lock("resource1", agent_id="agent_a")
        unlocked = await ws.unlock("resource1", agent_id="agent_b")
        assert unlocked is False

    @pytest.mark.asyncio
    async def test_list_keys(self):
        """list_keys reports every written key."""
        ws = SharedWorkspace()
        await ws.write("key1", "v1", agent_id="a")
        await ws.write("key2", "v2", agent_id="a")
        keys = await ws.list_keys()
        assert set(keys) == {"key1", "key2"}
|
|
|
|
|
|
# --- Orchestrator Tests ---
|
|
|
|
|
|
class MockAgent:
    """Mock Agent for testing.

    ``execute`` either raises or returns a COMPLETED ``TaskResult`` carrying a
    fixed payload.

    Args:
        name: Agent name; echoed into ``TaskResult.agent_name``.
        output_data: Payload returned as ``TaskResult.output_data``. ``None``
            selects the default ``{"result": "output from <name>"}``; any
            other value, including an empty dict, is used exactly as given.
        should_fail: When true, ``execute`` raises ``RuntimeError`` instead of
            returning a result.
    """

    def __init__(self, name: str, output_data: dict | None = None, should_fail: bool = False):
        self.name = name
        self.agent_type = "mock"
        self.version = "1.0.0"
        # Test against None (not truthiness) so an explicit {} is honored
        # instead of being silently replaced by the default payload.
        if output_data is None:
            output_data = {"result": f"output from {name}"}
        self._output_data = output_data
        self._should_fail = should_fail

    async def execute(self, task: TaskMessage):
        """Return a COMPLETED TaskResult for ``task``.

        ``started_at`` and ``completed_at`` are the same UTC timestamp.

        Raises:
            RuntimeError: if the agent was constructed with ``should_fail=True``.
        """
        if self._should_fail:
            raise RuntimeError(f"Agent {self.name} failed")
        # Imported lazily: only this method needs TaskResult.
        from agentkit.core.protocol import TaskResult

        now = datetime.now(timezone.utc)
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.COMPLETED,
            output_data=self._output_data,
            error_message=None,
            started_at=now,
            completed_at=now,
        )
|
|
|
|
|
|
class MockAgentPool:
    """In-memory stand-in for AgentPool that looks agents up by name."""

    def __init__(self, agents: dict[str, MockAgent] | None = None):
        # Both None and an empty mapping fall back to a fresh empty dict.
        self._agents = agents if agents else {}

    def get_agent(self, name: str) -> MockAgent | None:
        """Return the agent registered under ``name``, or None if absent."""
        registry = self._agents
        return registry.get(name)

    def list_agents(self) -> list[dict]:
        """Summarize each registered agent as a name/type/description dict."""
        return [
            dict(
                name=agent.name,
                agent_type=agent.agent_type,
                description=f"Mock agent {agent.name}",
            )
            for agent in self._agents.values()
        ]
|
|
|
|
|
|
class TestOrchestrator:
    """Orchestrator unit tests: execution, grouping, injection, aggregation."""

    @staticmethod
    def _parent_task(task_id, agent_name, task_type, input_data):
        """Build a priority-1, callback-less TaskMessage stamped with UTC now."""
        return TaskMessage(
            task_id=task_id,
            agent_name=agent_name,
            task_type=task_type,
            priority=1,
            input_data=input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    @staticmethod
    def _subtask(task_id, parent_task_id, agent, task_type, depends_on=None):
        """Build a SubTask with empty input and the given dependency ids."""
        return SubTask(
            task_id=task_id,
            parent_task_id=parent_task_id,
            assigned_agent=agent,
            task_type=task_type,
            input_data={},
            depends_on=list(depends_on or []),
        )

    @pytest.mark.asyncio
    async def test_single_subtask_execution(self):
        """A task routed to one agent completes with aggregated outputs."""
        worker = MockAgent("worker1", {"analysis": "result"})
        orchestrator = Orchestrator(agent_pool=MockAgentPool({"worker1": worker}))

        task = self._parent_task("t1", "worker1", "analyze", {"query": "test"})

        outcome = await orchestrator.execute(task)
        assert outcome.status == TaskStatus.COMPLETED
        assert "outputs" in outcome.aggregated_result

    @pytest.mark.asyncio
    async def test_parallel_groups_building(self):
        """The dependency graph is layered into parallel waves."""
        orchestrator = Orchestrator(agent_pool=MockAgentPool())

        subtasks = [
            self._subtask("t0", "p1", "a1", "type1"),
            self._subtask("t1", "p1", "a2", "type2"),
            self._subtask("t2", "p1", "a3", "type3", depends_on=["t0"]),
        ]

        waves = orchestrator._build_parallel_groups(subtasks)
        assert len(waves) == 2
        independent, dependent = waves
        # t0 and t1 have no prerequisites, so they share the first wave.
        assert set(independent) == {"t0", "t1"}
        # t2 waits on t0 and therefore runs in the second wave.
        assert dependent == ["t2"]

    @pytest.mark.asyncio
    async def test_sequential_dependency(self):
        """A two-step chain runs both steps to completion."""
        pool = MockAgentPool({
            "a1": MockAgent("a1", {"step": 1}),
            "a2": MockAgent("a2", {"step": 2}),
        })
        orchestrator = Orchestrator(agent_pool=pool)

        # Hand-built plan: t1 may only start once t0 has finished.
        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                self._subtask("t0", "parent", "a1", "step1"),
                self._subtask("t1", "parent", "a2", "step2", depends_on=["t0"]),
            ],
            parallel_groups=[["t0"], ["t1"]],
        )
        task = self._parent_task("parent", "orchestrator", "pipeline", {"query": "test"})

        results = await orchestrator._execute_plan(plan, task)
        for subtask_id in ("t0", "t1"):
            assert results[subtask_id]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_agent_not_found(self):
        """A subtask assigned to an unknown agent is reported as failed."""
        orchestrator = Orchestrator(agent_pool=MockAgentPool({}))

        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[self._subtask("t0", "parent", "missing_agent", "test")],
            parallel_groups=[["t0"]],
        )
        task = self._parent_task("parent", "orchestrator", "test", {})

        results = await orchestrator._execute_plan(plan, task)
        failure = results["t0"]
        assert failure["status"] == "failed"
        assert "not found" in failure["error"]

    @pytest.mark.asyncio
    async def test_dependency_result_injection(self):
        """Upstream results are merged into a dependent subtask's input."""
        orchestrator = Orchestrator(agent_pool=MockAgentPool())

        dependent = SubTask(
            task_id="t1",
            parent_task_id="p1",
            assigned_agent="a1",
            task_type="test",
            input_data={"query": "test"},
            depends_on=["t0"],
        )
        upstream = {
            "t0": {"status": "completed", "output": {"step1_result": "data"}},
        }

        enriched = orchestrator._inject_dependency_results(dependent, upstream)
        assert "dependency_results" in enriched
        assert "t0" in enriched["dependency_results"]

    @pytest.mark.asyncio
    async def test_aggregation_with_errors(self):
        """Failed subtasks surface under 'errors' with partial_success set."""
        orchestrator = Orchestrator(agent_pool=MockAgentPool())

        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[self._subtask("t0", "parent", "a1", "test")],
            parallel_groups=[["t0"]],
        )
        outcomes = {
            "t0": {"status": "failed", "error": "Agent failed"},
        }
        task = self._parent_task("parent", "orchestrator", "test", {})

        aggregated = await orchestrator._aggregate_results(plan, outcomes, task)
        assert "errors" in aggregated
        assert aggregated["partial_success"] is True
|
|
|
|
|
|
class TestAgentRole:
    """Pins the string values behind each AgentRole member."""

    def test_role_values(self):
        expectations = (
            (AgentRole.ORCHESTRATOR, "orchestrator"),
            (AgentRole.WORKER, "worker"),
            (AgentRole.REVIEWER, "reviewer"),
        )
        for role, expected_value in expectations:
            assert role.value == expected_value
|
|
|
|
|
|
class TestSubTask:
    """Verifies SubTask's defaults for fields the caller omits."""

    def test_default_values(self):
        required_fields = dict(
            task_id="t1",
            parent_task_id="p1",
            assigned_agent="a1",
            task_type="test",
            input_data={},
        )
        subtask = SubTask(**required_fields)

        # status, result and depends_on were not passed, so defaults apply.
        assert subtask.status == SubTaskStatus.PENDING
        assert subtask.result is None
        assert subtask.depends_on == []
|