119 lines
3.9 KiB
Python
119 lines
3.9 KiB
Python
"""Tests for Orchestrator + MessageBus integration (U7)."""
|
|
|
|
import asyncio
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from agentkit.bus.message import AgentMessage
|
|
from agentkit.bus.memory_bus import InMemoryMessageBus
|
|
from agentkit.core.orchestrator import Orchestrator, OrchestratorConfig
|
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
|
|
|
|
|
def _make_task(**overrides) -> TaskMessage:
    """Build a TaskMessage for tests.

    Starts from a fixed set of baseline field values and applies any keyword
    ``overrides`` on top of them before constructing the message.
    """
    baseline = dict(
        task_id="task-001",
        agent_name="test_agent",
        task_type="analyze",
        priority=0,
        input_data={"query": "test"},
        callback_url=None,
        # Evaluated on every call so each task gets a fresh, timezone-aware timestamp.
        created_at=datetime.now(timezone.utc),
        timeout_seconds=60,
    )
    # Later entries win, so caller-supplied overrides replace the baseline values.
    return TaskMessage(**{**baseline, **overrides})
|
|
|
|
|
|
def _make_mock_pool():
    """Return a mock agent pool whose single agent always succeeds.

    The pool exposes ``get_agent(name)``, which returns the same mock agent
    for every name, and ``list_agents()``, which returns one descriptor dict
    for ``"test_agent"``. Awaiting the agent's ``execute`` yields a mock
    result whose ``output_data`` is ``{"result": "done"}``.
    """
    outcome = MagicMock()
    outcome.output_data = {"result": "done"}

    agent = AsyncMock()
    agent.execute = AsyncMock(return_value=outcome)

    def get_agent(name):
        # The requested name is ignored: every lookup resolves to the same agent.
        return agent

    def list_agents():
        return [
            {"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
        ]

    pool = MagicMock()
    pool.get_agent = get_agent
    pool.list_agents = list_agents
    return pool
|
|
|
|
|
|
class TestOrchestratorWithMessageBus:
    """Tests for Orchestrator execution with and without an injected MessageBus."""

    @pytest.mark.asyncio
    async def test_worker_publishes_progress(self):
        """Worker should publish progress via MessageBus after execution."""
        bus = InMemoryMessageBus()
        pool = _make_mock_pool()
        orch = Orchestrator(agent_pool=pool, message_bus=bus)

        # Subscribe orchestrator to receive progress; record every message delivered.
        progress_messages: list[AgentMessage] = []
        await bus.subscribe("orchestrator", lambda msg: progress_messages.append(msg))

        task = _make_task()
        result = await orch.execute(task)
        assert result.status == TaskStatus.COMPLETED

        # The consumer needs event-loop time to process the message. Poll with a
        # deadline rather than sleeping a fixed interval: the test finishes as
        # soon as a message arrives and tolerates a slow machine without
        # becoming flaky.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + 2.0
        while not progress_messages and loop.time() < deadline:
            await asyncio.sleep(0.01)

        assert len(progress_messages) >= 1
        assert progress_messages[0].topic == "task.progress"

    @pytest.mark.asyncio
    async def test_no_message_bus_works_normally(self):
        """Without MessageBus, Orchestrator should work normally."""
        pool = _make_mock_pool()
        orch = Orchestrator(agent_pool=pool)

        task = _make_task()
        result = await orch.execute(task)
        assert result.status == TaskStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_message_bus_injected_via_config(self):
        """MessageBus should be injectable via the constructor alongside a config."""
        bus = InMemoryMessageBus()
        pool = _make_mock_pool()
        config = OrchestratorConfig(adaptive=True)
        orch = Orchestrator(
            agent_pool=pool,
            message_bus=bus,
            config=config,
        )
        # The orchestrator must keep the exact bus instance it was given.
        assert orch._message_bus is bus
|
|
|
|
|
|
class TestAgentPoolWithMessageBus:
    """Tests for constructing an AgentPool with and without a MessageBus."""

    @pytest.mark.asyncio
    async def test_agent_registered_to_bus_on_create(self):
        """An AgentPool built with a MessageBus should keep that bus instance.

        NOTE(review): the test name refers to agent registration, but only the
        pool's stored bus reference is asserted here.
        """
        from agentkit.core.agent_pool import AgentPool
        from agentkit.llm.gateway import LLMGateway
        from agentkit.skills.registry import SkillRegistry

        message_bus = InMemoryMessageBus()
        gateway_stub = MagicMock(spec=LLMGateway)
        registry = SkillRegistry()
        agent_pool = AgentPool(
            llm_gateway=gateway_stub,
            skill_registry=registry,
            message_bus=message_bus,
        )

        # The pool must reference the very bus object it was constructed with.
        assert agent_pool._message_bus is message_bus

    @pytest.mark.asyncio
    async def test_pool_without_bus_works(self):
        """An AgentPool built without a MessageBus should have no bus set."""
        from agentkit.core.agent_pool import AgentPool
        from agentkit.llm.gateway import LLMGateway
        from agentkit.skills.registry import SkillRegistry

        gateway_stub = MagicMock(spec=LLMGateway)
        registry = SkillRegistry()
        agent_pool = AgentPool(llm_gateway=gateway_stub, skill_registry=registry)

        assert agent_pool._message_bus is None
|