feat(core): integrate MessageBus into Orchestrator and AgentPool (U7)
- Orchestrator accepts optional message_bus parameter; workers publish task.progress messages via MessageBus after each subtask execution - AgentPool accepts optional message_bus; auto-registers agents on create and auto-unregisters on remove - app.py initializes MessageBus from config and injects into AgentPool - ServerConfig adds bus configuration field - 5 new tests, all passing
This commit is contained in:
parent
13d6e74099
commit
45283d31e8
|
|
@ -1,7 +1,7 @@
|
|||
"""AgentPool - 运行时 Agent 实例池"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agentkit.core.config_driven import ConfigDrivenAgent
|
||||
from agentkit.core.protocol import AgentStatus
|
||||
|
|
@ -24,12 +24,14 @@ class AgentPool:
|
|||
skill_registry: SkillRegistry,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
compressor: "CompressionStrategy | None" = None,
|
||||
message_bus: Any = None,
|
||||
):
|
||||
self._agents: dict[str, ConfigDrivenAgent] = {}
|
||||
self._llm_gateway = llm_gateway
|
||||
self._skill_registry = skill_registry
|
||||
self._tool_registry = tool_registry or ToolRegistry()
|
||||
self._compressor = compressor
|
||||
self._message_bus = message_bus
|
||||
|
||||
async def create_agent(self, config) -> ConfigDrivenAgent:
|
||||
"""Create and start an Agent instance
|
||||
|
|
@ -53,6 +55,19 @@ class AgentPool:
|
|||
await agent.start()
|
||||
self._agents[config.name] = agent
|
||||
logger.info(f"Agent '{config.name}' created and started in pool")
|
||||
|
||||
# Register agent to MessageBus if available
|
||||
if self._message_bus is not None:
|
||||
try:
|
||||
async def _handle_bus_message(msg):
|
||||
"""Handle incoming bus messages for this agent."""
|
||||
logger.debug(f"Agent '{config.name}' received bus message: {msg.topic}")
|
||||
|
||||
await self._message_bus.subscribe(config.name, _handle_bus_message)
|
||||
logger.info(f"Agent '{config.name}' registered to MessageBus")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to register agent '{config.name}' to MessageBus: {e}")
|
||||
|
||||
return agent
|
||||
|
||||
async def remove_agent(self, name: str) -> None:
|
||||
|
|
@ -60,6 +75,15 @@ class AgentPool:
|
|||
agent = self._agents.pop(name, None)
|
||||
if agent:
|
||||
await agent.stop()
|
||||
|
||||
# Unregister from MessageBus if available
|
||||
if self._message_bus is not None:
|
||||
try:
|
||||
await self._message_bus.unsubscribe(name)
|
||||
logger.info(f"Agent '{name}' unregistered from MessageBus")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to unregister agent '{name}' from MessageBus: {e}")
|
||||
|
||||
logger.info(f"Agent '{name}' stopped and removed from pool")
|
||||
|
||||
def get_agent(self, name: str) -> ConfigDrivenAgent | None:
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ class Orchestrator:
|
|||
max_parallel: int = 5,
|
||||
subtask_timeout: float = 300.0,
|
||||
config: OrchestratorConfig | None = None,
|
||||
message_bus: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
|
@ -115,6 +116,7 @@ class Orchestrator:
|
|||
max_parallel: 最大并行子任务数
|
||||
subtask_timeout: 子任务超时时间(秒)
|
||||
config: Orchestrator 配置,包含自适应参数
|
||||
message_bus: MessageBus 实例,用于 Agent 间通信
|
||||
"""
|
||||
self._agent_pool = agent_pool
|
||||
self._workspace = workspace or SharedWorkspace()
|
||||
|
|
@ -122,6 +124,7 @@ class Orchestrator:
|
|||
self._max_parallel = max_parallel
|
||||
self._subtask_timeout = subtask_timeout
|
||||
self._config = config or OrchestratorConfig()
|
||||
self._message_bus = message_bus
|
||||
|
||||
async def execute(self, task: TaskMessage) -> OrchestrationResult:
|
||||
"""执行编排任务
|
||||
|
|
@ -360,14 +363,64 @@ class Orchestrator:
|
|||
agent.execute(sub_task_msg),
|
||||
timeout=self._subtask_timeout,
|
||||
)
|
||||
return {
|
||||
output = {
|
||||
"status": "completed",
|
||||
"output": result.output_data if hasattr(result, "output_data") else result,
|
||||
}
|
||||
except asyncio.TimeoutError:
|
||||
return {"status": "failed", "error": "Subtask timed out"}
|
||||
|
||||
# Publish progress via MessageBus if available
|
||||
if self._message_bus is not None:
|
||||
try:
|
||||
from agentkit.bus.message import AgentMessage
|
||||
await self._message_bus.publish(AgentMessage(
|
||||
sender=subtask.assigned_agent,
|
||||
recipient="orchestrator",
|
||||
topic="task.progress",
|
||||
payload={
|
||||
"task_id": subtask.task_id,
|
||||
"status": "completed",
|
||||
},
|
||||
))
|
||||
except Exception as e:
|
||||
return {"status": "failed", "error": str(e)}
|
||||
logger.warning(f"Failed to publish progress via MessageBus: {e}")
|
||||
|
||||
return output
|
||||
except asyncio.TimeoutError:
|
||||
error_result = {"status": "failed", "error": "Subtask timed out"}
|
||||
if self._message_bus is not None:
|
||||
try:
|
||||
from agentkit.bus.message import AgentMessage
|
||||
await self._message_bus.publish(AgentMessage(
|
||||
sender=subtask.assigned_agent,
|
||||
recipient="orchestrator",
|
||||
topic="task.progress",
|
||||
payload={
|
||||
"task_id": subtask.task_id,
|
||||
"status": "failed",
|
||||
"error": "Subtask timed out",
|
||||
},
|
||||
))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to publish progress via MessageBus: {e}")
|
||||
return error_result
|
||||
except Exception as e:
|
||||
error_result = {"status": "failed", "error": str(e)}
|
||||
if self._message_bus is not None:
|
||||
try:
|
||||
from agentkit.bus.message import AgentMessage
|
||||
await self._message_bus.publish(AgentMessage(
|
||||
sender=subtask.assigned_agent,
|
||||
recipient="orchestrator",
|
||||
topic="task.progress",
|
||||
payload={
|
||||
"task_id": subtask.task_id,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
},
|
||||
))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to publish progress via MessageBus: {e}")
|
||||
return error_result
|
||||
|
||||
def _inject_dependency_results(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -271,11 +271,23 @@ def create_app(
|
|||
logger.info("HeadroomRetrieveTool registered (CCR retrieval enabled)")
|
||||
except ImportError:
|
||||
pass
|
||||
# Initialize MessageBus for inter-agent communication
|
||||
from agentkit.bus.redis_bus import create_message_bus
|
||||
bus_config = {}
|
||||
if server_config and hasattr(server_config, "bus") and server_config.bus:
|
||||
bus_config = server_config.bus
|
||||
message_bus = create_message_bus(
|
||||
backend=bus_config.get("backend", "memory"),
|
||||
redis_url=bus_config.get("redis_url", "redis://localhost:6379/0"),
|
||||
)
|
||||
app.state.message_bus = message_bus
|
||||
|
||||
app.state.agent_pool = AgentPool(
|
||||
llm_gateway=app.state.llm_gateway,
|
||||
skill_registry=app.state.skill_registry,
|
||||
tool_registry=app.state.tool_registry,
|
||||
compressor=compressor,
|
||||
message_bus=message_bus,
|
||||
)
|
||||
app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway)
|
||||
app.state.quality_gate = QualityGate()
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ class ServerConfig:
|
|||
telemetry: dict[str, Any] | None = None,
|
||||
compression: dict[str, Any] | None = None,
|
||||
session: dict[str, Any] | None = None,
|
||||
bus: dict[str, Any] | None = None,
|
||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||
):
|
||||
self.host = host
|
||||
|
|
@ -126,6 +127,7 @@ class ServerConfig:
|
|||
self.telemetry = telemetry or {}
|
||||
self.compression = compression or {}
|
||||
self.session = session or {}
|
||||
self.bus = bus or {}
|
||||
self.on_change = on_change
|
||||
|
||||
# Config watching state
|
||||
|
|
@ -195,6 +197,7 @@ class ServerConfig:
|
|||
telemetry=telemetry_data,
|
||||
compression=compression_data,
|
||||
session=session_data,
|
||||
bus=server.get("bus"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -0,0 +1,118 @@
|
|||
"""Tests for Orchestrator + MessageBus integration (U7)."""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.bus.message import AgentMessage
|
||||
from agentkit.bus.memory_bus import InMemoryMessageBus
|
||||
from agentkit.core.orchestrator import Orchestrator, OrchestratorConfig
|
||||
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||
|
||||
|
||||
def _make_task(**overrides) -> TaskMessage:
|
||||
defaults = {
|
||||
"task_id": "task-001",
|
||||
"agent_name": "test_agent",
|
||||
"task_type": "analyze",
|
||||
"priority": 0,
|
||||
"input_data": {"query": "test"},
|
||||
"callback_url": None,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"timeout_seconds": 60,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return TaskMessage(**defaults)
|
||||
|
||||
|
||||
def _make_mock_pool():
|
||||
"""Create a mock AgentPool with a working agent."""
|
||||
mock_agent = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.output_data = {"result": "done"}
|
||||
mock_agent.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
pool = MagicMock()
|
||||
pool.get_agent = lambda name: mock_agent
|
||||
pool.list_agents = lambda: [
|
||||
{"name": "test_agent", "agent_type": "worker", "description": "Test agent"}
|
||||
]
|
||||
return pool
|
||||
|
||||
|
||||
class TestOrchestratorWithMessageBus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_publishes_progress(self):
|
||||
"""Worker should publish progress via MessageBus after execution."""
|
||||
bus = InMemoryMessageBus()
|
||||
pool = _make_mock_pool()
|
||||
orch = Orchestrator(agent_pool=pool, message_bus=bus)
|
||||
|
||||
# Subscribe orchestrator to receive progress
|
||||
progress_messages: list[AgentMessage] = []
|
||||
await bus.subscribe("orchestrator", lambda msg: progress_messages.append(msg))
|
||||
|
||||
task = _make_task()
|
||||
result = await orch.execute(task)
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
|
||||
# Give consumer task time to process
|
||||
await asyncio.sleep(0.2)
|
||||
assert len(progress_messages) >= 1
|
||||
assert progress_messages[0].topic == "task.progress"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_bus_works_normally(self):
|
||||
"""Without MessageBus, Orchestrator should work normally."""
|
||||
pool = _make_mock_pool()
|
||||
orch = Orchestrator(agent_pool=pool)
|
||||
|
||||
task = _make_task()
|
||||
result = await orch.execute(task)
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_bus_injected_via_config(self):
|
||||
"""MessageBus should be injectable via constructor."""
|
||||
bus = InMemoryMessageBus()
|
||||
pool = _make_mock_pool()
|
||||
config = OrchestratorConfig(adaptive=True)
|
||||
orch = Orchestrator(
|
||||
agent_pool=pool,
|
||||
message_bus=bus,
|
||||
config=config,
|
||||
)
|
||||
assert orch._message_bus is bus
|
||||
|
||||
|
||||
class TestAgentPoolWithMessageBus:
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_registered_to_bus_on_create(self):
|
||||
"""Agent should be registered to MessageBus when created."""
|
||||
from agentkit.core.agent_pool import AgentPool
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
bus = InMemoryMessageBus()
|
||||
pool = AgentPool(
|
||||
llm_gateway=MagicMock(spec=LLMGateway),
|
||||
skill_registry=SkillRegistry(),
|
||||
message_bus=bus,
|
||||
)
|
||||
|
||||
# Verify bus has the message_bus reference
|
||||
assert pool._message_bus is bus
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_without_bus_works(self):
|
||||
"""AgentPool without MessageBus should work normally."""
|
||||
from agentkit.core.agent_pool import AgentPool
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.skills.registry import SkillRegistry
|
||||
|
||||
pool = AgentPool(
|
||||
llm_gateway=MagicMock(spec=LLMGateway),
|
||||
skill_registry=SkillRegistry(),
|
||||
)
|
||||
assert pool._message_bus is None
|
||||
Loading…
Reference in New Issue