"""Tests for Orchestrator and SharedWorkspace"""

import asyncio
from datetime import datetime, timezone

import pytest

from agentkit.core.orchestrator import (
    Orchestrator,
    OrchestrationPlan,
    OrchestrationResult,
    SubTask,
    SubTaskStatus,
    AgentRole,
)
from agentkit.core.shared_workspace import SharedWorkspace
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus


def _make_task(
    task_id: str = "parent",
    agent_name: str = "orchestrator",
    task_type: str = "test",
    input_data: dict | None = None,
) -> TaskMessage:
    """Build a priority-1 TaskMessage with no callback, stamped with the current UTC time."""
    return TaskMessage(
        task_id=task_id,
        agent_name=agent_name,
        task_type=task_type,
        priority=1,
        input_data=input_data if input_data is not None else {},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
    )


# --- SharedWorkspace Tests ---


class TestSharedWorkspace:
    """SharedWorkspace unit tests (in-memory mode)"""

    @pytest.mark.asyncio
    async def test_write_and_read(self):
        ws = SharedWorkspace()
        version = await ws.write("key1", {"data": "value"}, agent_id="agent_a")
        assert version == 1

        result = await ws.read("key1")
        assert result is not None
        assert result["value"] == {"data": "value"}
        assert result["agent_id"] == "agent_a"
        assert result["version"] == 1

    @pytest.mark.asyncio
    async def test_version_increments(self):
        ws = SharedWorkspace()
        v1 = await ws.write("key1", "first", agent_id="a")
        v2 = await ws.write("key1", "second", agent_id="b")
        assert v1 == 1
        assert v2 == 2

    @pytest.mark.asyncio
    async def test_read_nonexistent(self):
        ws = SharedWorkspace()
        result = await ws.read("nonexistent")
        assert result is None

    @pytest.mark.asyncio
    async def test_delete(self):
        ws = SharedWorkspace()
        await ws.write("key1", "value", agent_id="a")
        deleted = await ws.delete("key1")
        assert deleted is True
        result = await ws.read("key1")
        assert result is None

    @pytest.mark.asyncio
    async def test_delete_nonexistent(self):
        ws = SharedWorkspace()
        deleted = await ws.delete("nonexistent")
        assert deleted is False

    @pytest.mark.asyncio
    async def test_lock_and_unlock(self):
        ws = SharedWorkspace()
        acquired = await ws.lock("resource1", agent_id="agent_a")
        assert acquired is True

        # A different agent cannot acquire a lock that is already held
        acquired2 = await ws.lock("resource1", agent_id="agent_b")
        assert acquired2 is False

        # Owner can unlock
        unlocked = await ws.unlock("resource1", agent_id="agent_a")
        assert unlocked is True

        # Now another agent can lock
        acquired3 = await ws.lock("resource1", agent_id="agent_b")
        assert acquired3 is True

    @pytest.mark.asyncio
    async def test_unlock_by_non_owner(self):
        ws = SharedWorkspace()
        await ws.lock("resource1", agent_id="agent_a")
        unlocked = await ws.unlock("resource1", agent_id="agent_b")
        assert unlocked is False

    @pytest.mark.asyncio
    async def test_list_keys(self):
        ws = SharedWorkspace()
        await ws.write("key1", "v1", agent_id="a")
        await ws.write("key2", "v2", agent_id="a")
        keys = await ws.list_keys()
        assert set(keys) == {"key1", "key2"}


# --- Orchestrator Tests ---


class MockAgent:
    """Mock Agent for testing.

    ``execute`` returns a COMPLETED ``TaskResult`` carrying ``output_data``,
    or raises ``RuntimeError`` when ``should_fail`` is set.
    """

    def __init__(
        self,
        name: str,
        output_data: dict | None = None,
        should_fail: bool = False,
        call_log: list[str] | None = None,
    ):
        """
        Args:
            name: Agent name; also used in the default output and error message.
            output_data: Payload returned in the result. ``None`` selects a
                default ``{"result": "output from <name>"}``.
            should_fail: When True, ``execute`` raises ``RuntimeError``.
            call_log: Optional list shared between agents; every ``execute``
                call appends this agent's name so tests can check call order.
        """
        self.name = name
        self.agent_type = "mock"
        self.version = "1.0.0"
        # `is None` (not `or`) so an explicitly passed empty dict is preserved.
        self._output_data = (
            output_data if output_data is not None else {"result": f"output from {name}"}
        )
        self._should_fail = should_fail
        self._call_log = call_log

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Record the call (if logging) and return a completed result or raise."""
        if self._call_log is not None:
            self._call_log.append(self.name)
        if self._should_fail:
            raise RuntimeError(f"Agent {self.name} failed")
        now = datetime.now(timezone.utc)
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.COMPLETED,
            output_data=self._output_data,
            error_message=None,
            started_at=now,
            completed_at=now,
        )


class MockAgentPool:
    """Mock AgentPool for testing: a name -> MockAgent lookup."""

    def __init__(self, agents: dict[str, MockAgent] | None = None):
        # `is None` so a caller-supplied (even empty) dict is used as-is.
        self._agents = agents if agents is not None else {}

    def get_agent(self, name: str) -> MockAgent | None:
        """Return the agent registered under ``name``, or None."""
        return self._agents.get(name)

    def list_agents(self) -> list[dict]:
        """Describe every registered agent by name, type and description."""
        return [
            {"name": a.name, "agent_type": a.agent_type, "description": f"Mock agent {a.name}"}
            for a in self._agents.values()
        ]


class TestOrchestrator:
    """Orchestrator unit tests"""

    @pytest.mark.asyncio
    async def test_single_subtask_execution(self):
        """Single agent should execute task directly"""
        agent = MockAgent("worker1", {"analysis": "result"})
        pool = MockAgentPool({"worker1": agent})
        orchestrator = Orchestrator(agent_pool=pool)

        task = _make_task(
            task_id="t1",
            agent_name="worker1",
            task_type="analyze",
            input_data={"query": "test"},
        )

        result = await orchestrator.execute(task)
        assert result.status == TaskStatus.COMPLETED
        assert "outputs" in result.aggregated_result

    @pytest.mark.asyncio
    async def test_parallel_groups_building(self):
        """Parallel groups should be built from dependency graph"""
        pool = MockAgentPool()
        orchestrator = Orchestrator(agent_pool=pool)

        subtasks = [
            SubTask(task_id="t0", parent_task_id="p1", assigned_agent="a1",
                    task_type="type1", input_data={}, depends_on=[]),
            SubTask(task_id="t1", parent_task_id="p1", assigned_agent="a2",
                    task_type="type2", input_data={}, depends_on=[]),
            SubTask(task_id="t2", parent_task_id="p1", assigned_agent="a3",
                    task_type="type3", input_data={}, depends_on=["t0"]),
        ]

        groups = orchestrator._build_parallel_groups(subtasks)
        assert len(groups) == 2
        # First group: t0 and t1 (no dependencies)
        assert set(groups[0]) == {"t0", "t1"}
        # Second group: t2 (depends on t0)
        assert groups[1] == ["t2"]

    @pytest.mark.asyncio
    async def test_sequential_dependency(self):
        """Tasks with sequential dependencies should execute in order"""
        call_log: list[str] = []
        agent1 = MockAgent("a1", {"step": 1}, call_log=call_log)
        agent2 = MockAgent("a2", {"step": 2}, call_log=call_log)
        pool = MockAgentPool({"a1": agent1, "a2": agent2})
        orchestrator = Orchestrator(agent_pool=pool)

        # Manually create a plan with sequential dependencies
        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1",
                        task_type="step1", input_data={}, depends_on=[]),
                SubTask(task_id="t1", parent_task_id="parent", assigned_agent="a2",
                        task_type="step2", input_data={}, depends_on=["t0"]),
            ],
            parallel_groups=[["t0"], ["t1"]],
        )

        task = _make_task(task_type="pipeline", input_data={"query": "test"})

        results = await orchestrator._execute_plan(plan, task)
        assert results["t0"]["status"] == "completed"
        assert results["t1"]["status"] == "completed"
        # The dependent subtask's agent must run after its prerequisite's.
        assert call_log == ["a1", "a2"]

    @pytest.mark.asyncio
    async def test_agent_not_found(self):
        """Missing agent should result in failed subtask"""
        pool = MockAgentPool({})
        orchestrator = Orchestrator(agent_pool=pool)

        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                SubTask(task_id="t0", parent_task_id="parent", assigned_agent="missing_agent",
                        task_type="test", input_data={}, depends_on=[]),
            ],
            parallel_groups=[["t0"]],
        )

        task = _make_task()

        results = await orchestrator._execute_plan(plan, task)
        assert results["t0"]["status"] == "failed"
        assert "not found" in results["t0"]["error"]

    @pytest.mark.asyncio
    async def test_dependency_result_injection(self):
        """Subtask should receive dependency results in input"""
        pool = MockAgentPool()
        orchestrator = Orchestrator(agent_pool=pool)

        subtask = SubTask(
            task_id="t1",
            parent_task_id="p1",
            assigned_agent="a1",
            task_type="test",
            input_data={"query": "test"},
            depends_on=["t0"],
        )
        subtask_results = {
            "t0": {"status": "completed", "output": {"step1_result": "data"}},
        }

        enriched = orchestrator._inject_dependency_results(subtask, subtask_results)
        assert "dependency_results" in enriched
        assert "t0" in enriched["dependency_results"]

    @pytest.mark.asyncio
    async def test_aggregation_with_errors(self):
        """Aggregation should include errors for failed subtasks"""
        pool = MockAgentPool()
        orchestrator = Orchestrator(agent_pool=pool)

        plan = OrchestrationPlan(
            plan_id="p1",
            parent_task_id="parent",
            subtasks=[
                SubTask(task_id="t0", parent_task_id="parent", assigned_agent="a1",
                        task_type="test", input_data={}, depends_on=[]),
            ],
            parallel_groups=[["t0"]],
        )

        subtask_results = {
            "t0": {"status": "failed", "error": "Agent failed"},
        }

        task = _make_task()

        aggregated = await orchestrator._aggregate_results(plan, subtask_results, task)
        assert "errors" in aggregated
        assert aggregated["partial_success"] is True


class TestAgentRole:
    """AgentRole enum tests"""

    def test_role_values(self):
        assert AgentRole.ORCHESTRATOR.value == "orchestrator"
        assert AgentRole.WORKER.value == "worker"
        assert AgentRole.REVIEWER.value == "reviewer"


class TestSubTask:
    """SubTask dataclass tests"""

    def test_default_values(self):
        st = SubTask(
            task_id="t1",
            parent_task_id="p1",
            assigned_agent="a1",
            task_type="test",
            input_data={},
        )
        assert st.status == SubTaskStatus.PENDING
        assert st.result is None
        assert st.depends_on == []