"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4). Per plan U4 Execution note: characterization-first — verify that existing REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning (no regression). Then add PLAN_EXEC wiring tests. KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501. """ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock import pytest from fastapi.testclient import TestClient from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult from agentkit.core.phase import PhaseState from agentkit.tools.advance_phase import AdvancePhaseTool # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture def app_with_chat(): """Create a FastAPI app with Chat routes and mocked dependencies.""" from fastapi import FastAPI from agentkit.server.routes.chat import router app = FastAPI() app.include_router(router, prefix="/api/v1") from agentkit.session.manager import SessionManager from agentkit.session.store import InMemorySessionStore app.state.session_manager = SessionManager(store=InMemorySessionStore()) app.state.llm_gateway = MagicMock() app.state.agent_pool = MagicMock() app.state.server_config = MagicMock() app.state.server_config.api_key = None app.state.server_config.plan_exec = {} return app @pytest.fixture def client(app_with_chat): return TestClient(app_with_chat) def _make_routing( execution_mode: ExecutionMode = ExecutionMode.REACT, tools: list | None = None, ) -> SkillRoutingResult: """Build a minimal SkillRoutingResult for testing.""" return SkillRoutingResult( execution_mode=execution_mode, tools=tools or [], clean_content="test message", model="default", agent_name="test-agent", system_prompt=None, skill_name=None, ) def _make_websocket_mock(app) -> MagicMock: """Build a mock WebSocket with app.state and async send_json.""" ws = MagicMock() ws.app = app ws.send_json = AsyncMock() return ws def _make_agent_mock() -> MagicMock: """Build a mock Agent with _tool_registry and _react_engine.""" agent = MagicMock() agent.name = "test-agent" agent._tool_registry = MagicMock() agent._tool_registry.list_tools.return_value = [] agent._system_prompt = None # _react_engine is None to force the code path that creates a new engine agent._react_engine = None agent.get_model.return_value = "default" return agent def _make_session_manager_mock() -> MagicMock: """Build a mock SessionManager with async methods.""" sm = MagicMock() # get_session returns a mock session with agent_name="test-agent" session = MagicMock() session.agent_name = "test-agent" session.status = "active" sm.get_session = AsyncMock(return_value=session) sm.get_chat_messages = AsyncMock(return_value=[]) sm.append_message = AsyncMock() return sm def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None: """Wire up app.state so _handle_chat_message finds the right routing.""" app.state.agent_pool.get_agent.return_value = agent app.state.request_preprocessor = MagicMock() app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing) # --------------------------------------------------------------------------- # REST — PLAN_EXEC raises 501 (KTD4) # --------------------------------------------------------------------------- class TestRestPlanExec501: def test_rest_plan_exec_returns_501(self, client): """REST send_message with execution_mode=plan_exec → 501.""" create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) session_id = create_resp.json()["session_id"] msg_resp = client.post( f"/api/v1/chat/sessions/{session_id}/messages", json={"content": "Hello", "execution_mode": "plan_exec"}, ) assert msg_resp.status_code == 501 assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"] def test_rest_react_mode_still_works(self, client): """REST send_message without execution_mode doesn't 501.""" create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) session_id = create_resp.json()["session_id"] # No execution_mode field → should NOT trigger 501. msg_resp = client.post( f"/api/v1/chat/sessions/{session_id}/messages", json={"content": "Hello"}, ) assert msg_resp.status_code != 501 # --------------------------------------------------------------------------- # Characterization — REWOO still falls back to REACT (no regression) # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_rewoo_still_falls_back_to_react_without_phase_policy(app_with_chat): """Characterization: REWOO via WebSocket → no phase_policy (falls back to REACT).""" from agentkit.server.routes import chat as chat_module agent = _make_agent_mock() routing = _make_routing(execution_mode=ExecutionMode.REWOO) _setup_routing(app_with_chat, routing, agent) sm = _make_session_manager_mock() ws = _make_websocket_mock(app_with_chat) captured_engine_kwargs: dict = {} class _StubEngine: def __init__(self, **kwargs): captured_engine_kwargs.update(kwargs) self._phase_policy = kwargs.get("phase_policy") self._current_phase = None @property def current_phase(self): return self._current_phase def reset(self): pass async def execute_stream(self, **kwargs): return yield # async generator marker with pytest.MonkeyPatch().context() as mp: mp.setattr(chat_module, "ReActEngine", _StubEngine) await chat_module._handle_chat_message( websocket=ws, session_id="test-session", content="test", sm=sm, cancellation_token=MagicMock(), pending_replies={}, pending_confirmations=None, ) # REWOO should NOT build a phase_policy assert captured_engine_kwargs.get("phase_policy") is None # --------------------------------------------------------------------------- # Happy path — PLAN_EXEC builds phase policy + registers AdvancePhaseTool # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_plan_exec_builds_phase_policy_and_registers_advance_phase_tool( app_with_chat, ): """PLAN_EXEC via WebSocket → engine with phase_policy, AdvancePhaseTool registered.""" from agentkit.server.routes import chat as chat_module agent = _make_agent_mock() routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) _setup_routing(app_with_chat, routing, agent) sm = _make_session_manager_mock() sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "test"}]) ws = _make_websocket_mock(app_with_chat) captured_engine: list = [] captured_tools: list = [] class _StubEngine: def __init__(self, **kwargs): self._phase_policy = kwargs.get("phase_policy") self._current_phase = ( kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None ) @property def current_phase(self): return self._current_phase def reset(self): pass async def execute_stream(self, **kwargs): captured_tools.extend(kwargs.get("tools", [])) captured_engine.append(self) return yield with pytest.MonkeyPatch().context() as mp: mp.setattr(chat_module, "ReActEngine", _StubEngine) await chat_module._handle_chat_message( websocket=ws, session_id="test-session", content="test", sm=sm, cancellation_token=MagicMock(), pending_replies={}, pending_confirmations=None, ) assert len(captured_engine) == 1 engine = captured_engine[0] assert engine._phase_policy is not None assert engine._current_phase == PhaseState.PLANNING # AdvancePhaseTool was registered in the tools list assert any(isinstance(t, AdvancePhaseTool) for t in captured_tools) # --------------------------------------------------------------------------- # Edge cases # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_plan_exec_empty_config_uses_default_policy(app_with_chat): """Edge: plan_exec config absent (empty dict) → default_policy() used.""" from agentkit.server.routes import chat as chat_module app_with_chat.state.server_config.plan_exec = {} agent = _make_agent_mock() routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) _setup_routing(app_with_chat, routing, agent) sm = _make_session_manager_mock() ws = _make_websocket_mock(app_with_chat) captured_policy: list = [] class _StubEngine: def __init__(self, **kwargs): captured_policy.append(kwargs.get("phase_policy")) self._phase_policy = kwargs.get("phase_policy") self._current_phase = ( kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None ) @property def current_phase(self): return self._current_phase def reset(self): pass async def execute_stream(self, **kwargs): return yield with pytest.MonkeyPatch().context() as mp: mp.setattr(chat_module, "ReActEngine", _StubEngine) await chat_module._handle_chat_message( websocket=ws, session_id="test-session", content="test", sm=sm, cancellation_token=MagicMock(), pending_replies={}, pending_confirmations=None, ) assert len(captured_policy) == 1 assert captured_policy[0] is not None # Default policy: PLANNING allows search but not write_file assert "search" in captured_policy[0].whitelist[PhaseState.PLANNING] assert "write_file" not in captured_policy[0].whitelist[PhaseState.PLANNING] @pytest.mark.asyncio async def test_plan_exec_disabled_falls_back_to_react(app_with_chat): """Edge: plan_exec.enabled=False → falls back to REACT (no phase_policy).""" from agentkit.server.routes import chat as chat_module app_with_chat.state.server_config.plan_exec = {"enabled": False} agent = _make_agent_mock() routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) _setup_routing(app_with_chat, routing, agent) sm = _make_session_manager_mock() ws = _make_websocket_mock(app_with_chat) captured_engine_kwargs: dict = {} class _StubEngine: def __init__(self, **kwargs): captured_engine_kwargs.update(kwargs) self._phase_policy = kwargs.get("phase_policy") self._current_phase = None @property def current_phase(self): return self._current_phase def reset(self): pass async def execute_stream(self, **kwargs): return yield with pytest.MonkeyPatch().context() as mp: mp.setattr(chat_module, "ReActEngine", _StubEngine) await chat_module._handle_chat_message( websocket=ws, session_id="test-session", content="test", sm=sm, cancellation_token=MagicMock(), pending_replies={}, pending_confirmations=None, ) # enabled=False → no phase_policy (falls back to REACT) assert captured_engine_kwargs.get("phase_policy") is None # --------------------------------------------------------------------------- # Error path # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_plan_exec_bad_config_sends_error_and_returns(app_with_chat): """Error: phase policy construction fails → error event sent, returns early.""" from agentkit.server.routes import chat as chat_module app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"} agent = _make_agent_mock() routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) _setup_routing(app_with_chat, routing, agent) sm = _make_session_manager_mock() ws = _make_websocket_mock(app_with_chat) engine_constructor_called = [] class _StubEngine: def __init__(self, **kwargs): engine_constructor_called.append(kwargs) async def execute_stream(self, **kwargs): yield with pytest.MonkeyPatch().context() as mp: mp.setattr(chat_module, "ReActEngine", _StubEngine) await chat_module._handle_chat_message( websocket=ws, session_id="test-session", content="test", sm=sm, cancellation_token=MagicMock(), pending_replies={}, pending_confirmations=None, ) sent_messages = [call.args[0] for call in ws.send_json.call_args_list] error_messages = [m for m in sent_messages if m.get("type") == "error"] assert len(error_messages) == 1 assert "phase policy error" in error_messages[0]["data"]["message"] # Engine constructor was NOT called (returned early) assert len(engine_constructor_called) == 0 # --------------------------------------------------------------------------- # phase_changed event emission # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_phase_changed_event_emitted_on_transition(app_with_chat): """phase_changed event sent when current_phase changes during execute_stream.""" from agentkit.server.routes import chat as chat_module app_with_chat.state.server_config.plan_exec = {} agent = _make_agent_mock() routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC) _setup_routing(app_with_chat, routing, agent) sm = _make_session_manager_mock() sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}]) ws = _make_websocket_mock(app_with_chat) class _StubEngine: def __init__(self, **kwargs): self._phase_policy = kwargs.get("phase_policy") self._current_phase = PhaseState.PLANNING @property def current_phase(self): return self._current_phase def reset(self): pass async def execute_stream(self, **kwargs): from agentkit.core.react import ReActEvent yield ReActEvent( event_type="tool_call", step=1, data={"tool": "search", "output": "ok"}, ) # Simulate phase transition (as if AdvancePhaseTool was called) self._current_phase = PhaseState.BUILDING yield ReActEvent( event_type="final_answer", step=2, data={"output": "done"}, ) with pytest.MonkeyPatch().context() as mp: mp.setattr(chat_module, "ReActEngine", _StubEngine) await chat_module._handle_chat_message( websocket=ws, session_id="test-session", content="go", sm=sm, cancellation_token=MagicMock(), pending_replies={}, pending_confirmations=None, ) sent_messages = [call.args[0] for call in ws.send_json.call_args_list] phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"] assert len(phase_events) == 1 assert phase_events[0]["data"]["phase"] == "building" assert phase_events[0]["data"]["previous"] == "planning" @pytest.mark.asyncio async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat): """Characterization: REACT mode → no phase_changed events.""" from agentkit.server.routes import chat as chat_module agent = _make_agent_mock() routing = _make_routing(execution_mode=ExecutionMode.REACT) _setup_routing(app_with_chat, routing, agent) sm = _make_session_manager_mock() sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}]) ws = _make_websocket_mock(app_with_chat) class _StubEngine: def __init__(self, **kwargs): self._phase_policy = None self._current_phase = None @property def current_phase(self): return None def reset(self): pass async def execute_stream(self, **kwargs): from agentkit.core.react import ReActEvent yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"}) with pytest.MonkeyPatch().context() as mp: mp.setattr(chat_module, "ReActEngine", _StubEngine) await chat_module._handle_chat_message( websocket=ws, session_id="test-session", content="hi", sm=sm, cancellation_token=MagicMock(), pending_replies={}, pending_confirmations=None, ) sent_messages = [call.args[0] for call in ws.send_json.call_args_list] phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"] assert len(phase_events) == 0