From 45283d31e84a6424d4791fe3efd83d23bb417ceb Mon Sep 17 00:00:00 2001 From: chiguyong Date: Mon, 8 Jun 2026 00:03:40 +0800 Subject: [PATCH] feat(core): integrate MessageBus into Orchestrator and AgentPool (U7) - Orchestrator accepts optional message_bus parameter; workers publish task.progress messages via MessageBus after each subtask execution - AgentPool accepts optional message_bus; auto-registers agents on create and auto-unregisters on remove - app.py initializes MessageBus from config and injects into AgentPool - ServerConfig adds bus configuration field - 5 new tests, all passing --- src/agentkit/core/agent_pool.py | 26 +++++- src/agentkit/core/orchestrator.py | 59 +++++++++++++- src/agentkit/server/app.py | 12 +++ src/agentkit/server/config.py | 3 + tests/unit/test_orchestrator_bus.py | 118 ++++++++++++++++++++++++++++ 5 files changed, 214 insertions(+), 4 deletions(-) create mode 100644 tests/unit/test_orchestrator_bus.py diff --git a/src/agentkit/core/agent_pool.py b/src/agentkit/core/agent_pool.py index 200ac77..1525390 100644 --- a/src/agentkit/core/agent_pool.py +++ b/src/agentkit/core/agent_pool.py @@ -1,7 +1,7 @@ """AgentPool - 运行时 Agent 实例池""" import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from agentkit.core.config_driven import ConfigDrivenAgent from agentkit.core.protocol import AgentStatus @@ -24,12 +24,14 @@ class AgentPool: skill_registry: SkillRegistry, tool_registry: ToolRegistry | None = None, compressor: "CompressionStrategy | None" = None, + message_bus: Any = None, ): self._agents: dict[str, ConfigDrivenAgent] = {} self._llm_gateway = llm_gateway self._skill_registry = skill_registry self._tool_registry = tool_registry or ToolRegistry() self._compressor = compressor + self._message_bus = message_bus async def create_agent(self, config) -> ConfigDrivenAgent: """Create and start an Agent instance @@ -53,6 +55,19 @@ class AgentPool: await agent.start() self._agents[config.name] = agent logger.info(f"Agent '{config.name}' created and started in pool") + + # Register agent to MessageBus if available + if self._message_bus is not None: + try: + async def _handle_bus_message(msg): + """Handle incoming bus messages for this agent.""" + logger.debug(f"Agent '{config.name}' received bus message: {msg.topic}") + + await self._message_bus.subscribe(config.name, _handle_bus_message) + logger.info(f"Agent '{config.name}' registered to MessageBus") + except Exception as e: + logger.warning(f"Failed to register agent '{config.name}' to MessageBus: {e}") + return agent async def remove_agent(self, name: str) -> None: @@ -60,6 +75,15 @@ class AgentPool: agent = self._agents.pop(name, None) if agent: await agent.stop() + + # Unregister from MessageBus if available + if self._message_bus is not None: + try: + await self._message_bus.unsubscribe(name) + logger.info(f"Agent '{name}' unregistered from MessageBus") + except Exception as e: + logger.warning(f"Failed to unregister agent '{name}' from MessageBus: {e}") + logger.info(f"Agent '{name}' stopped and removed from pool") def get_agent(self, name: str) -> ConfigDrivenAgent | None: diff --git a/src/agentkit/core/orchestrator.py b/src/agentkit/core/orchestrator.py index 3ed450b..7d1bb65 100644 --- a/src/agentkit/core/orchestrator.py +++ b/src/agentkit/core/orchestrator.py @@ -106,6 +106,7 @@ class Orchestrator: max_parallel: int = 5, subtask_timeout: float = 300.0, config: OrchestratorConfig | None = None, + message_bus: Any = None, ): """ Args: @@ -115,6 +116,7 @@ class Orchestrator: max_parallel: 最大并行子任务数 subtask_timeout: 子任务超时时间(秒) config: Orchestrator 配置,包含自适应参数 + message_bus: MessageBus 实例,用于 Agent 间通信 """ self._agent_pool = agent_pool self._workspace = workspace or SharedWorkspace() @@ -122,6 +124,7 @@ class Orchestrator: self._max_parallel = max_parallel self._subtask_timeout = subtask_timeout self._config = config or OrchestratorConfig() + self._message_bus = message_bus async def execute(self, task: TaskMessage) -> OrchestrationResult: """执行编排任务 @@ -360,14 +363,64 @@ class Orchestrator: agent.execute(sub_task_msg), timeout=self._subtask_timeout, ) - return { + output = { "status": "completed", "output": result.output_data if hasattr(result, "output_data") else result, } + + # Publish progress via MessageBus if available + if self._message_bus is not None: + try: + from agentkit.bus.message import AgentMessage + await self._message_bus.publish(AgentMessage( + sender=subtask.assigned_agent, + recipient="orchestrator", + topic="task.progress", + payload={ + "task_id": subtask.task_id, + "status": "completed", + }, + )) + except Exception as e: + logger.warning(f"Failed to publish progress via MessageBus: {e}") + + return output except asyncio.TimeoutError: - return {"status": "failed", "error": "Subtask timed out"} + error_result = {"status": "failed", "error": "Subtask timed out"} + if self._message_bus is not None: + try: + from agentkit.bus.message import AgentMessage + await self._message_bus.publish(AgentMessage( + sender=subtask.assigned_agent, + recipient="orchestrator", + topic="task.progress", + payload={ + "task_id": subtask.task_id, + "status": "failed", + "error": "Subtask timed out", + }, + )) + except Exception as e: + logger.warning(f"Failed to publish progress via MessageBus: {e}") + return error_result except Exception as e: - return {"status": "failed", "error": str(e)} + error_result = {"status": "failed", "error": str(e)} + if self._message_bus is not None: + try: + from agentkit.bus.message import AgentMessage + await self._message_bus.publish(AgentMessage( + sender=subtask.assigned_agent, + recipient="orchestrator", + topic="task.progress", + payload={ + "task_id": subtask.task_id, + "status": "failed", + "error": str(e), + }, + )) + except Exception as e: + logger.warning(f"Failed to publish progress via MessageBus: {e}") + return error_result def _inject_dependency_results( self, diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 6ad6018..1874925 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -271,11 +271,23 @@ def create_app( logger.info("HeadroomRetrieveTool registered (CCR retrieval enabled)") except ImportError: pass + # Initialize MessageBus for inter-agent communication + from agentkit.bus.redis_bus import create_message_bus + bus_config = {} + if server_config and hasattr(server_config, "bus") and server_config.bus: + bus_config = server_config.bus + message_bus = create_message_bus( + backend=bus_config.get("backend", "memory"), + redis_url=bus_config.get("redis_url", "redis://localhost:6379/0"), + ) + app.state.message_bus = message_bus + app.state.agent_pool = AgentPool( llm_gateway=app.state.llm_gateway, skill_registry=app.state.skill_registry, tool_registry=app.state.tool_registry, compressor=compressor, + message_bus=message_bus, ) app.state.intent_router = IntentRouter(llm_gateway=app.state.llm_gateway) app.state.quality_gate = QualityGate() diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index 449d644..5d66a7d 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -107,6 +107,7 @@ class ServerConfig: telemetry: dict[str, Any] | None = None, compression: dict[str, Any] | None = None, session: dict[str, Any] | None = None, + bus: dict[str, Any] | None = None, on_change: Callable[["ServerConfig"], None] | None = None, ): self.host = host @@ -126,6 +127,7 @@ class ServerConfig: self.telemetry = telemetry or {} self.compression = compression or {} self.session = session or {} + self.bus = bus or {} self.on_change = on_change # Config watching state @@ -195,6 +197,7 @@ class ServerConfig: telemetry=telemetry_data, compression=compression_data, session=session_data, + bus=server.get("bus"), ) @staticmethod diff --git a/tests/unit/test_orchestrator_bus.py b/tests/unit/test_orchestrator_bus.py new file mode 100644 index 0000000..dcd8762 --- /dev/null +++ b/tests/unit/test_orchestrator_bus.py @@ -0,0 +1,118 @@ +"""Tests for Orchestrator + MessageBus integration (U7).""" + +import asyncio +import pytest +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +from agentkit.bus.message import AgentMessage +from agentkit.bus.memory_bus import InMemoryMessageBus +from agentkit.core.orchestrator import Orchestrator, OrchestratorConfig +from agentkit.core.protocol import TaskMessage, TaskStatus + + +def _make_task(**overrides) -> TaskMessage: + defaults = { + "task_id": "task-001", + "agent_name": "test_agent", + "task_type": "analyze", + "priority": 0, + "input_data": {"query": "test"}, + "callback_url": None, + "created_at": datetime.now(timezone.utc), + "timeout_seconds": 60, + } + defaults.update(overrides) + return TaskMessage(**defaults) + + +def _make_mock_pool(): + """Create a mock AgentPool with a working agent.""" + mock_agent = AsyncMock() + mock_result = MagicMock() + mock_result.output_data = {"result": "done"} + mock_agent.execute = AsyncMock(return_value=mock_result) + + pool = MagicMock() + pool.get_agent = lambda name: mock_agent + pool.list_agents = lambda: [ + {"name": "test_agent", "agent_type": "worker", "description": "Test agent"} + ] + return pool + + +class TestOrchestratorWithMessageBus: + @pytest.mark.asyncio + async def test_worker_publishes_progress(self): + """Worker should publish progress via MessageBus after execution.""" + bus = InMemoryMessageBus() + pool = _make_mock_pool() + orch = Orchestrator(agent_pool=pool, message_bus=bus) + + # Subscribe orchestrator to receive progress + progress_messages: list[AgentMessage] = [] + await bus.subscribe("orchestrator", lambda msg: progress_messages.append(msg)) + + task = _make_task() + result = await orch.execute(task) + assert result.status == TaskStatus.COMPLETED + + # Give consumer task time to process + await asyncio.sleep(0.2) + assert len(progress_messages) >= 1 + assert progress_messages[0].topic == "task.progress" + + @pytest.mark.asyncio + async def test_no_message_bus_works_normally(self): + """Without MessageBus, Orchestrator should work normally.""" + pool = _make_mock_pool() + orch = Orchestrator(agent_pool=pool) + + task = _make_task() + result = await orch.execute(task) + assert result.status == TaskStatus.COMPLETED + + @pytest.mark.asyncio + async def test_message_bus_injected_via_config(self): + """MessageBus should be injectable via constructor.""" + bus = InMemoryMessageBus() + pool = _make_mock_pool() + config = OrchestratorConfig(adaptive=True) + orch = Orchestrator( + agent_pool=pool, + message_bus=bus, + config=config, + ) + assert orch._message_bus is bus + + +class TestAgentPoolWithMessageBus: + @pytest.mark.asyncio + async def test_agent_registered_to_bus_on_create(self): + """Agent should be registered to MessageBus when created.""" + from agentkit.core.agent_pool import AgentPool + from agentkit.llm.gateway import LLMGateway + from agentkit.skills.registry import SkillRegistry + + bus = InMemoryMessageBus() + pool = AgentPool( + llm_gateway=MagicMock(spec=LLMGateway), + skill_registry=SkillRegistry(), + message_bus=bus, + ) + + # Verify bus has the message_bus reference + assert pool._message_bus is bus + + @pytest.mark.asyncio + async def test_pool_without_bus_works(self): + """AgentPool without MessageBus should work normally.""" + from agentkit.core.agent_pool import AgentPool + from agentkit.llm.gateway import LLMGateway + from agentkit.skills.registry import SkillRegistry + + pool = AgentPool( + llm_gateway=MagicMock(spec=LLMGateway), + skill_registry=SkillRegistry(), + ) + assert pool._message_bus is None