From 7869caddc77ad3ebf0be370a7d5ad95d2940f3e9 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Tue, 30 Jun 2026 00:16:39 +0800 Subject: [PATCH] feat(U4): G6 PLAN_EXEC wiring at chat WebSocket path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - PLAN_EXEC branch builds PhasePolicy from ServerConfig.plan_exec - Empty config → default_policy(); enabled=False → falls back to REACT - Bad config → error event sent, returns early (no engine constructed) - ReActEngine created with phase_policy; AdvancePhaseTool registered - phase_changed events emitted on phase transitions (PLAN_EXEC only) - REST send_message with execution_mode=plan_exec → HTTP 501 (KTD4) - REWOO/REFLEXION/TEAM_COLLAB still fall back to REACT (no regression) - 9 unit tests covering REST 501, characterization, happy path, edge cases, error path, phase_changed events --- src/agentkit/server/routes/chat.py | 93 ++++- tests/unit/test_chat_plan_exec_ws.py | 531 +++++++++++++++++++++++++++ 2 files changed, 616 insertions(+), 8 deletions(-) create mode 100644 tests/unit/test_chat_plan_exec_ws.py diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py index 6737740..b1fe4ad 100644 --- a/src/agentkit/server/routes/chat.py +++ b/src/agentkit/server/routes/chat.py @@ -25,11 +25,13 @@ from fastapi.responses import FileResponse from pydantic import BaseModel from agentkit.chat.skill_routing import ExecutionMode +from agentkit.core.phase import PhasePolicy, default_policy, policy_from_config from agentkit.core.protocol import CancellationToken from agentkit.core.react import ReActEngine from agentkit.server._fallback_chain import execute_with_fallback_chain from agentkit.session.manager import SessionManager from agentkit.session.models import MessageRole, SessionStatus +from agentkit.tools.advance_phase import AdvancePhaseTool logger = logging.getLogger(__name__) @@ -47,6 +49,8 @@ class CreateSessionRequest(BaseModel): class SendMessageRequest(BaseModel): content: str role: str = "user" + # Optional execution mode override. "plan_exec" → 501 (KTD4: WebSocket only). + execution_mode: str | None = None class SessionResponse(BaseModel): @@ -583,6 +587,13 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques if session.status == SessionStatus.CLOSED: raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed") + # KTD4: PLAN_EXEC is wired only at the WebSocket path. REST raises 501. + if request.execution_mode == "plan_exec": + raise HTTPException( + status_code=501, + detail="PLAN_EXEC via REST not yet supported; use WebSocket", + ) + # Append user message await sm.append_message( session_id=session_id, @@ -1079,21 +1090,68 @@ async def _handle_chat_message( await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}}) return - # Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB - # currently fall back to REACT with a warning. - if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT): + # U4/G6: PLAN_EXEC — build PhasePolicy from server config (KTD4: WebSocket only). + # KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC — phase policy and + # fallback chain are mutually exclusive. PLAN_EXEC uses its own engine. + phase_policy: PhasePolicy | None = None + if routing.execution_mode == ExecutionMode.PLAN_EXEC: + server_config = getattr(websocket.app.state, "server_config", None) + plan_exec_cfg = getattr(server_config, "plan_exec", None) or {} + + if plan_exec_cfg.get("enabled", True) is False: + # Explicit opt-out → fall back to REACT. + logger.info( + "PLAN_EXEC disabled by config (plan_exec.enabled=False), " + "falling back to REACT for session %s", + session_id, + ) + else: + try: + phase_policy = policy_from_config(plan_exec_cfg) + if phase_policy is None: + # Empty config (no `plan_exec:` section) → use KTD5 defaults. + phase_policy = default_policy() + except Exception as e: + logger.error( + "PLAN_EXEC phase policy construction failed for session %s: %s", + session_id, + e, + ) + await websocket.send_json( + {"type": "error", "data": {"message": f"phase policy error: {e}"}} + ) + return + + # Handle advanced execution modes: REWOO/REFLEXION/TEAM_COLLAB + # still fall back to REACT with a warning. PLAN_EXEC is handled above. + if routing.execution_mode not in ( + ExecutionMode.REACT, + ExecutionMode.SKILL_REACT, + ExecutionMode.PLAN_EXEC, + ): logger.warning( f"Execution mode {routing.execution_mode.value} not yet supported " f"in chat WebSocket, falling back to REACT" ) # Execute Agent with streaming - # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization) - react_engine = getattr(agent, "_react_engine", None) - if react_engine is None: - react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) + # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization). + # PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the + # agent's _react_engine — it has no policy). + if phase_policy is not None: + react_engine = ReActEngine( + llm_gateway=websocket.app.state.llm_gateway, + phase_policy=phase_policy, + ) + # Register AdvancePhaseTool bound to this engine (LLM's escape hatch). + advance_phase_tool = AdvancePhaseTool(engine=react_engine) + routing.tools = list(routing.tools) + [advance_phase_tool] else: - react_engine.reset() + react_engine = getattr(agent, "_react_engine", None) + if react_engine is None: + react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) + else: + react_engine.reset() # Create confirmation handler that sends request to frontend and waits for reply # Use the same dict object — do NOT use `or {}` because an empty dict is falsy @@ -1149,6 +1207,9 @@ async def _handle_chat_message( try: final_content = "" token_buffer: list[str] = [] + # Track phase transitions for phase_changed events (PLAN_EXEC only). + # For non-PLAN_EXEC modes, current_phase is always None → no events. + prev_phase = react_engine.current_phase async for event in react_engine.execute_stream( messages=chat_messages, tools=routing.tools, @@ -1226,6 +1287,22 @@ async def _handle_chat_message( } ) + # U4/G6: emit phase_changed event when the phase state machine + # transitions (PLAN_EXEC only). For non-PLAN_EXEC modes, + # current_phase is always None → this branch never fires. + curr_phase = react_engine.current_phase + if curr_phase != prev_phase: + await websocket.send_json( + { + "type": "phase_changed", + "data": { + "phase": curr_phase.value if curr_phase else None, + "previous": prev_phase.value if prev_phase else None, + }, + } + ) + prev_phase = curr_phase + # Append assistant reply to session if final_content: await sm.append_message( diff --git a/tests/unit/test_chat_plan_exec_ws.py b/tests/unit/test_chat_plan_exec_ws.py new file mode 100644 index 0000000..84fe358 --- /dev/null +++ b/tests/unit/test_chat_plan_exec_ws.py @@ -0,0 +1,531 @@ +"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4). + +Per plan U4 Execution note: characterization-first — verify that existing +REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning +(no regression). Then add PLAN_EXEC wiring tests. + +KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi.testclient import TestClient + +from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult +from agentkit.core.phase import PhaseState +from agentkit.tools.advance_phase import AdvancePhaseTool + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def app_with_chat(): + """Create a FastAPI app with Chat routes and mocked dependencies.""" + from fastapi import FastAPI + + from agentkit.server.routes.chat import router + + app = FastAPI() + app.include_router(router, prefix="/api/v1") + + from agentkit.session.manager import SessionManager + from agentkit.session.store import InMemorySessionStore + + app.state.session_manager = SessionManager(store=InMemorySessionStore()) + app.state.llm_gateway = MagicMock() + app.state.agent_pool = MagicMock() + app.state.server_config = MagicMock() + app.state.server_config.api_key = None + app.state.server_config.plan_exec = {} + return app + + +@pytest.fixture +def client(app_with_chat): + return TestClient(app_with_chat) + + +def _make_routing( + execution_mode: ExecutionMode = ExecutionMode.REACT, + tools: list | None = None, +) -> SkillRoutingResult: + """Build a minimal SkillRoutingResult for testing.""" + return SkillRoutingResult( + execution_mode=execution_mode, + tools=tools or [], + clean_content="test message", + model="default", + agent_name="test-agent", + system_prompt=None, + skill_name=None, + ) + + +def _make_websocket_mock(app) -> MagicMock: + """Build a mock WebSocket with app.state and async send_json.""" + ws = MagicMock() + ws.app = app + ws.send_json = AsyncMock() + return ws + + +def _make_agent_mock() -> MagicMock: + """Build a mock Agent with _tool_registry and _react_engine.""" + agent = MagicMock() + agent.name = "test-agent" + agent._tool_registry = MagicMock() + agent._tool_registry.list_tools.return_value = [] + agent._system_prompt = None + # _react_engine is None to force the code path that creates a new engine + agent._react_engine = None + agent.get_model.return_value = "default" + return agent + + +def _make_session_manager_mock() -> MagicMock: + """Build a mock SessionManager with async methods.""" + sm = MagicMock() + # get_session returns a mock session with agent_name="test-agent" + session = MagicMock() + session.agent_name = "test-agent" + session.status = "active" + sm.get_session = AsyncMock(return_value=session) + sm.get_chat_messages = AsyncMock(return_value=[]) + sm.append_message = AsyncMock() + return sm + + +def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None: + """Wire up app.state so _handle_chat_message finds the right routing.""" + app.state.agent_pool.get_agent.return_value = agent + app.state.request_preprocessor = MagicMock() + app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing) + + +# --------------------------------------------------------------------------- +# REST — PLAN_EXEC raises 501 (KTD4) +# --------------------------------------------------------------------------- + + +class TestRestPlanExec501: + def test_rest_plan_exec_returns_501(self, client): + """REST send_message with execution_mode=plan_exec → 501.""" + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + msg_resp = client.post( + f"/api/v1/chat/sessions/{session_id}/messages", + json={"content": "Hello", "execution_mode": "plan_exec"}, + ) + assert msg_resp.status_code == 501 + assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"] + + def test_rest_react_mode_still_works(self, client): + """REST send_message without execution_mode doesn't 501.""" + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + # No execution_mode field → should NOT trigger 501. + msg_resp = client.post( + f"/api/v1/chat/sessions/{session_id}/messages", + json={"content": "Hello"}, + ) + assert msg_resp.status_code != 501 + + +# --------------------------------------------------------------------------- +# Characterization — REWOO still falls back to REACT (no regression) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_rewoo_still_falls_back_to_react_without_phase_policy(app_with_chat): + """Characterization: REWOO via WebSocket → no phase_policy (falls back to REACT).""" + from agentkit.server.routes import chat as chat_module + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.REWOO) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + ws = _make_websocket_mock(app_with_chat) + + captured_engine_kwargs: dict = {} + + class _StubEngine: + def __init__(self, **kwargs): + captured_engine_kwargs.update(kwargs) + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = None + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + return + yield # async generator marker + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + # REWOO should NOT build a phase_policy + assert captured_engine_kwargs.get("phase_policy") is None + + +# --------------------------------------------------------------------------- +# Happy path — PLAN_EXEC builds phase policy + registers AdvancePhaseTool +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_plan_exec_builds_phase_policy_and_registers_advance_phase_tool( + app_with_chat, +): + """PLAN_EXEC via WebSocket → engine with phase_policy, AdvancePhaseTool registered.""" + from agentkit.server.routes import chat as chat_module + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "test"}]) + ws = _make_websocket_mock(app_with_chat) + + captured_engine: list = [] + captured_tools: list = [] + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = ( + kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None + ) + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + captured_tools.extend(kwargs.get("tools", [])) + captured_engine.append(self) + return + yield + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + assert len(captured_engine) == 1 + engine = captured_engine[0] + assert engine._phase_policy is not None + assert engine._current_phase == PhaseState.PLANNING + # AdvancePhaseTool was registered in the tools list + assert any(isinstance(t, AdvancePhaseTool) for t in captured_tools) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_plan_exec_empty_config_uses_default_policy(app_with_chat): + """Edge: plan_exec config absent (empty dict) → default_policy() used.""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + ws = _make_websocket_mock(app_with_chat) + + captured_policy: list = [] + + class _StubEngine: + def __init__(self, **kwargs): + captured_policy.append(kwargs.get("phase_policy")) + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = ( + kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None + ) + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + return + yield + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + assert len(captured_policy) == 1 + assert captured_policy[0] is not None + # Default policy: PLANNING allows search but not write_file + assert "search" in captured_policy[0].whitelist[PhaseState.PLANNING] + assert "write_file" not in captured_policy[0].whitelist[PhaseState.PLANNING] + + +@pytest.mark.asyncio +async def test_plan_exec_disabled_falls_back_to_react(app_with_chat): + """Edge: plan_exec.enabled=False → falls back to REACT (no phase_policy).""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {"enabled": False} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + ws = _make_websocket_mock(app_with_chat) + + captured_engine_kwargs: dict = {} + + class _StubEngine: + def __init__(self, **kwargs): + captured_engine_kwargs.update(kwargs) + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = None + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + return + yield + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + # enabled=False → no phase_policy (falls back to REACT) + assert captured_engine_kwargs.get("phase_policy") is None + + +# --------------------------------------------------------------------------- +# Error path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_plan_exec_bad_config_sends_error_and_returns(app_with_chat): + """Error: phase policy construction fails → error event sent, returns early.""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + ws = _make_websocket_mock(app_with_chat) + + engine_constructor_called = [] + + class _StubEngine: + def __init__(self, **kwargs): + engine_constructor_called.append(kwargs) + + async def execute_stream(self, **kwargs): + yield + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="test", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + error_messages = [m for m in sent_messages if m.get("type") == "error"] + assert len(error_messages) == 1 + assert "phase policy error" in error_messages[0]["data"]["message"] + # Engine constructor was NOT called (returned early) + assert len(engine_constructor_called) == 0 + + +# --------------------------------------------------------------------------- +# phase_changed event emission +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_phase_changed_event_emitted_on_transition(app_with_chat): + """phase_changed event sent when current_phase changes during execute_stream.""" + from agentkit.server.routes import chat as chat_module + + app_with_chat.state.server_config.plan_exec = {} + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}]) + ws = _make_websocket_mock(app_with_chat) + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = kwargs.get("phase_policy") + self._current_phase = PhaseState.PLANNING + + @property + def current_phase(self): + return self._current_phase + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + from agentkit.core.react import ReActEvent + + yield ReActEvent( + event_type="tool_call", + step=1, + data={"tool": "search", "output": "ok"}, + ) + # Simulate phase transition (as if AdvancePhaseTool was called) + self._current_phase = PhaseState.BUILDING + yield ReActEvent( + event_type="final_answer", + step=2, + data={"output": "done"}, + ) + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="go", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"] + assert len(phase_events) == 1 + assert phase_events[0]["data"]["phase"] == "building" + assert phase_events[0]["data"]["previous"] == "planning" + + +@pytest.mark.asyncio +async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat): + """Characterization: REACT mode → no phase_changed events.""" + from agentkit.server.routes import chat as chat_module + + agent = _make_agent_mock() + routing = _make_routing(execution_mode=ExecutionMode.REACT) + _setup_routing(app_with_chat, routing, agent) + + sm = _make_session_manager_mock() + sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}]) + ws = _make_websocket_mock(app_with_chat) + + class _StubEngine: + def __init__(self, **kwargs): + self._phase_policy = None + self._current_phase = None + + @property + def current_phase(self): + return None + + def reset(self): + pass + + async def execute_stream(self, **kwargs): + from agentkit.core.react import ReActEvent + + yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"}) + + with pytest.MonkeyPatch().context() as mp: + mp.setattr(chat_module, "ReActEngine", _StubEngine) + + await chat_module._handle_chat_message( + websocket=ws, + session_id="test-session", + content="hi", + sm=sm, + cancellation_token=MagicMock(), + pending_replies={}, + pending_confirmations=None, + ) + + sent_messages = [call.args[0] for call in ws.send_json.call_args_list] + phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"] + assert len(phase_events) == 0