# fischer-agentkit/tests/unit/test_chat_plan_exec_ws.py
#
# (source viewer metadata: 532 lines, 18 KiB, Python)

"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4).
Per plan U4 Execution note: characterization-first — verify that existing
REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning
(no regression). Then add PLAN_EXEC wiring tests.
KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi.testclient import TestClient
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
from agentkit.core.phase import PhaseState
from agentkit.tools.advance_phase import AdvancePhaseTool
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def app_with_chat():
    """Build a FastAPI app exposing the chat router under ``/api/v1``.

    ``app.state`` carries a real in-memory ``SessionManager`` plus
    ``MagicMock`` stand-ins for the LLM gateway, the agent pool and the
    server config.  The server config has ``api_key = None`` and an empty
    ``plan_exec`` section; individual tests overwrite ``plan_exec`` as needed.
    """
    from fastapi import FastAPI

    from agentkit.server.routes.chat import router
    from agentkit.session.manager import SessionManager
    from agentkit.session.store import InMemorySessionStore

    application = FastAPI()
    application.include_router(router, prefix="/api/v1")

    state = application.state
    state.session_manager = SessionManager(store=InMemorySessionStore())
    state.llm_gateway = MagicMock()
    state.agent_pool = MagicMock()

    config = MagicMock()
    config.api_key = None  # no API key configured for the test app
    config.plan_exec = {}
    state.server_config = config
    return application
@pytest.fixture
def client(app_with_chat):
    """Return a synchronous ``TestClient`` bound to the chat app fixture."""
    test_client = TestClient(app_with_chat)
    return test_client
def _make_routing(
    execution_mode: ExecutionMode = ExecutionMode.REACT,
    tools: list | None = None,
) -> SkillRoutingResult:
    """Return a minimal ``SkillRoutingResult`` routed to ``test-agent``.

    Args:
        execution_mode: Mode the routed message should run under.
        tools: Tools attached to the routing; any falsy value (including
            ``None``) becomes a fresh empty list.
    """
    routed_tools = tools or []
    return SkillRoutingResult(
        agent_name="test-agent",
        skill_name=None,
        model="default",
        system_prompt=None,
        clean_content="test message",
        execution_mode=execution_mode,
        tools=routed_tools,
    )
def _make_websocket_mock(app) -> MagicMock:
"""Build a mock WebSocket with app.state and async send_json."""
ws = MagicMock()
ws.app = app
ws.send_json = AsyncMock()
return ws
def _make_agent_mock() -> MagicMock:
"""Build a mock Agent with _tool_registry and _react_engine."""
agent = MagicMock()
agent.name = "test-agent"
agent._tool_registry = MagicMock()
agent._tool_registry.list_tools.return_value = []
agent._system_prompt = None
# _react_engine is None to force the code path that creates a new engine
agent._react_engine = None
agent.get_model.return_value = "default"
return agent
def _make_session_manager_mock() -> MagicMock:
"""Build a mock SessionManager with async methods."""
sm = MagicMock()
# get_session returns a mock session with agent_name="test-agent"
session = MagicMock()
session.agent_name = "test-agent"
session.status = "active"
sm.get_session = AsyncMock(return_value=session)
sm.get_chat_messages = AsyncMock(return_value=[])
sm.append_message = AsyncMock()
return sm
def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None:
"""Wire up app.state so _handle_chat_message finds the right routing."""
app.state.agent_pool.get_agent.return_value = agent
app.state.request_preprocessor = MagicMock()
app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing)
# ---------------------------------------------------------------------------
# REST — PLAN_EXEC raises 501 (KTD4)
# ---------------------------------------------------------------------------
class TestRestPlanExec501:
    """REST ``send_message``: PLAN_EXEC is rejected with HTTP 501 (KTD4)."""

    @staticmethod
    def _create_session(client) -> str:
        """Create a chat session for ``test-agent`` and return its id.

        Asserts that creation returned a 2xx status and includes the response
        body in the failure message.  Without this check, a failed create
        would surface later as an opaque ``KeyError`` on ``session_id``.
        """
        resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
        assert 200 <= resp.status_code < 300, resp.text
        return resp.json()["session_id"]

    def test_rest_plan_exec_returns_501(self, client):
        """REST send_message with execution_mode=plan_exec → 501."""
        session_id = self._create_session(client)
        msg_resp = client.post(
            f"/api/v1/chat/sessions/{session_id}/messages",
            json={"content": "Hello", "execution_mode": "plan_exec"},
        )
        assert msg_resp.status_code == 501, msg_resp.text
        assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"]

    def test_rest_react_mode_still_works(self, client):
        """REST send_message without execution_mode doesn't 501."""
        session_id = self._create_session(client)
        # No execution_mode field → should NOT trigger 501.
        msg_resp = client.post(
            f"/api/v1/chat/sessions/{session_id}/messages",
            json={"content": "Hello"},
        )
        assert msg_resp.status_code != 501, msg_resp.text
# ---------------------------------------------------------------------------
# Characterization — REWOO still falls back to REACT (no regression)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_rewoo_still_falls_back_to_react_without_phase_policy(app_with_chat):
    """Characterization: REWOO via WebSocket → no phase_policy (falls back to REACT).

    The stub ``ReActEngine`` records the kwargs of every construction.  The
    test first asserts that at least one construction happened, so the
    "no phase_policy" check cannot pass vacuously when the engine is never
    built.
    """
    from agentkit.server.routes import chat as chat_module

    agent = _make_agent_mock()
    routing = _make_routing(execution_mode=ExecutionMode.REWOO)
    _setup_routing(app_with_chat, routing, agent)
    sm = _make_session_manager_mock()
    ws = _make_websocket_mock(app_with_chat)
    # One kwargs dict per ReActEngine construction.
    engine_inits: list[dict] = []

    class _StubEngine:
        def __init__(self, **kwargs):
            engine_inits.append(kwargs)
            self._phase_policy = kwargs.get("phase_policy")
            self._current_phase = None

        @property
        def current_phase(self):
            return self._current_phase

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            return
            yield  # async generator marker

    # MonkeyPatch.context() is a classmethod; no throwaway instance needed.
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(chat_module, "ReActEngine", _StubEngine)
        await chat_module._handle_chat_message(
            websocket=ws,
            session_id="test-session",
            content="test",
            sm=sm,
            cancellation_token=MagicMock(),
            pending_replies={},
            pending_confirmations=None,
        )

    assert engine_inits, "ReActEngine was never constructed for REWOO"
    # REWOO should NOT build a phase_policy
    assert all(kw.get("phase_policy") is None for kw in engine_inits)
# ---------------------------------------------------------------------------
# Happy path — PLAN_EXEC builds phase policy + registers AdvancePhaseTool
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plan_exec_builds_phase_policy_and_registers_advance_phase_tool(
    app_with_chat,
):
    """PLAN_EXEC via WebSocket → engine with phase_policy, AdvancePhaseTool registered.

    The stub engine records itself and the ``tools`` passed to
    ``execute_stream``.  The test then checks three things: the engine
    received a phase policy, the policy's start phase is PLANNING, and an
    ``AdvancePhaseTool`` was among the tools.
    """
    from agentkit.server.routes import chat as chat_module

    agent = _make_agent_mock()
    routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
    _setup_routing(app_with_chat, routing, agent)
    sm = _make_session_manager_mock()
    sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "test"}])
    ws = _make_websocket_mock(app_with_chat)
    captured_engine: list = []
    captured_tools: list = []

    class _StubEngine:
        def __init__(self, **kwargs):
            policy = kwargs.get("phase_policy")
            self._phase_policy = policy
            # Start in the policy's start phase, so the PLANNING assertion
            # checks the policy the handler built rather than a stub constant.
            self._current_phase = policy.start_phase if policy else None

        @property
        def current_phase(self):
            return self._current_phase

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            # Treat an omitted tools argument and tools=None both as "no tools"
            # instead of raising TypeError inside the handler.
            captured_tools.extend(kwargs.get("tools") or [])
            captured_engine.append(self)
            return
            yield  # async generator marker

    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(chat_module, "ReActEngine", _StubEngine)
        await chat_module._handle_chat_message(
            websocket=ws,
            session_id="test-session",
            content="test",
            sm=sm,
            cancellation_token=MagicMock(),
            pending_replies={},
            pending_confirmations=None,
        )

    assert len(captured_engine) == 1, "execute_stream should run exactly once"
    engine = captured_engine[0]
    assert engine._phase_policy is not None
    assert engine._current_phase == PhaseState.PLANNING
    # AdvancePhaseTool was registered in the tools list
    assert any(isinstance(t, AdvancePhaseTool) for t in captured_tools), captured_tools
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plan_exec_empty_config_uses_default_policy(app_with_chat):
    """Edge: an empty ``plan_exec`` config means ``default_policy()`` is used.

    The policy handed to the engine must exist, and its PLANNING whitelist
    must allow ``search`` while excluding ``write_file``.
    """
    from agentkit.server.routes import chat as chat_module

    app_with_chat.state.server_config.plan_exec = {}
    agent = _make_agent_mock()
    _setup_routing(
        app_with_chat,
        _make_routing(execution_mode=ExecutionMode.PLAN_EXEC),
        agent,
    )
    sm = _make_session_manager_mock()
    ws = _make_websocket_mock(app_with_chat)
    seen_policies: list = []

    class _StubEngine:
        def __init__(self, **kwargs):
            policy = kwargs.get("phase_policy")
            seen_policies.append(policy)
            self._phase_policy = policy
            self._current_phase = None if not policy else policy.start_phase

        @property
        def current_phase(self):
            return self._current_phase

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            # Finish immediately without producing any events.
            return
            yield

    handler_kwargs = {
        "websocket": ws,
        "session_id": "test-session",
        "content": "test",
        "sm": sm,
        "cancellation_token": MagicMock(),
        "pending_replies": {},
        "pending_confirmations": None,
    }
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(chat_module, "ReActEngine", _StubEngine)
        await chat_module._handle_chat_message(**handler_kwargs)

    assert len(seen_policies) == 1
    policy = seen_policies[0]
    assert policy is not None
    planning_tools = policy.whitelist[PhaseState.PLANNING]
    # Default policy: PLANNING allows search but not write_file
    assert "search" in planning_tools
    assert "write_file" not in planning_tools
@pytest.mark.asyncio
async def test_plan_exec_disabled_falls_back_to_react(app_with_chat):
    """Edge: plan_exec.enabled=False → falls back to REACT (no phase_policy).

    The test also asserts that the stub ``ReActEngine`` was actually
    constructed.  Otherwise the "no phase_policy" check would pass vacuously
    when the engine is never built.
    """
    from agentkit.server.routes import chat as chat_module

    app_with_chat.state.server_config.plan_exec = {"enabled": False}
    agent = _make_agent_mock()
    routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
    _setup_routing(app_with_chat, routing, agent)
    sm = _make_session_manager_mock()
    ws = _make_websocket_mock(app_with_chat)
    # One kwargs dict per ReActEngine construction.
    engine_inits: list[dict] = []

    class _StubEngine:
        def __init__(self, **kwargs):
            engine_inits.append(kwargs)
            self._phase_policy = kwargs.get("phase_policy")
            self._current_phase = None

        @property
        def current_phase(self):
            return self._current_phase

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            return
            yield  # async generator marker

    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(chat_module, "ReActEngine", _StubEngine)
        await chat_module._handle_chat_message(
            websocket=ws,
            session_id="test-session",
            content="test",
            sm=sm,
            cancellation_token=MagicMock(),
            pending_replies={},
            pending_confirmations=None,
        )

    assert engine_inits, "ReActEngine was never constructed with plan_exec disabled"
    # enabled=False → no phase_policy (falls back to REACT)
    assert all(kw.get("phase_policy") is None for kw in engine_inits)
# ---------------------------------------------------------------------------
# Error path
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_plan_exec_bad_config_sends_error_and_returns(app_with_chat):
    """Error: an invalid ``start_phase`` makes phase-policy construction fail.

    The handler must send exactly one ``error`` event whose message mentions
    ``phase policy error``.  It must then return without constructing
    ``ReActEngine``.
    """
    from agentkit.server.routes import chat as chat_module

    app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"}
    agent = _make_agent_mock()
    routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
    _setup_routing(app_with_chat, routing, agent)
    sm = _make_session_manager_mock()
    ws = _make_websocket_mock(app_with_chat)
    constructor_calls = []

    class _StubEngine:
        # Deliberately minimal: the handler is expected to bail out first.
        def __init__(self, **kwargs):
            constructor_calls.append(kwargs)

        async def execute_stream(self, **kwargs):
            yield

    handler_kwargs = {
        "websocket": ws,
        "session_id": "test-session",
        "content": "test",
        "sm": sm,
        "cancellation_token": MagicMock(),
        "pending_replies": {},
        "pending_confirmations": None,
    }
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(chat_module, "ReActEngine", _StubEngine)
        await chat_module._handle_chat_message(**handler_kwargs)

    payloads = [c.args[0] for c in ws.send_json.call_args_list]
    errors = [p for p in payloads if p.get("type") == "error"]
    assert len(errors) == 1
    assert "phase policy error" in errors[0]["data"]["message"]
    # Returned early: the engine was never built.
    assert not constructor_calls
# ---------------------------------------------------------------------------
# phase_changed event emission
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_phase_changed_event_emitted_on_transition(app_with_chat):
    """A phase transition during ``execute_stream`` emits one ``phase_changed``.

    The stub engine starts in PLANNING and yields a ``tool_call`` event.  It
    then switches to BUILDING, as ``AdvancePhaseTool`` would, and yields a
    ``final_answer`` event.  The handler must report the change with
    ``phase="building"`` and ``previous="planning"``.
    """
    from agentkit.server.routes import chat as chat_module
    from agentkit.core.react import ReActEvent

    app_with_chat.state.server_config.plan_exec = {}
    agent = _make_agent_mock()
    routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
    _setup_routing(app_with_chat, routing, agent)
    sm = _make_session_manager_mock()
    sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}])
    ws = _make_websocket_mock(app_with_chat)

    class _StubEngine:
        def __init__(self, **kwargs):
            self._phase_policy = kwargs.get("phase_policy")
            self._current_phase = PhaseState.PLANNING

        @property
        def current_phase(self):
            return self._current_phase

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            yield ReActEvent(
                event_type="tool_call",
                step=1,
                data={"tool": "search", "output": "ok"},
            )
            # Simulate phase transition (as if AdvancePhaseTool was called)
            self._current_phase = PhaseState.BUILDING
            yield ReActEvent(event_type="final_answer", step=2, data={"output": "done"})

    handler_kwargs = {
        "websocket": ws,
        "session_id": "test-session",
        "content": "go",
        "sm": sm,
        "cancellation_token": MagicMock(),
        "pending_replies": {},
        "pending_confirmations": None,
    }
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(chat_module, "ReActEngine", _StubEngine)
        await chat_module._handle_chat_message(**handler_kwargs)

    payloads = [c.args[0] for c in ws.send_json.call_args_list]
    phase_events = [p for p in payloads if p.get("type") == "phase_changed"]
    assert len(phase_events) == 1
    event_data = phase_events[0]["data"]
    assert event_data["phase"] == "building"
    assert event_data["previous"] == "planning"
@pytest.mark.asyncio
async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat):
    """Characterization: REACT mode never emits ``phase_changed`` events.

    The stub engine reports no current phase and yields a single
    ``final_answer`` event.
    """
    from agentkit.server.routes import chat as chat_module
    from agentkit.core.react import ReActEvent

    agent = _make_agent_mock()
    routing = _make_routing(execution_mode=ExecutionMode.REACT)
    _setup_routing(app_with_chat, routing, agent)
    sm = _make_session_manager_mock()
    sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}])
    ws = _make_websocket_mock(app_with_chat)

    class _StubEngine:
        def __init__(self, **kwargs):
            self._phase_policy = None
            self._current_phase = None

        @property
        def current_phase(self):
            # Always phaseless, regardless of _current_phase.
            return None

        def reset(self):
            pass

        async def execute_stream(self, **kwargs):
            yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"})

    handler_kwargs = {
        "websocket": ws,
        "session_id": "test-session",
        "content": "hi",
        "sm": sm,
        "cancellation_token": MagicMock(),
        "pending_replies": {},
        "pending_confirmations": None,
    }
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(chat_module, "ReActEngine", _StubEngine)
        await chat_module._handle_chat_message(**handler_kwargs)

    payloads = [c.args[0] for c in ws.send_json.call_args_list]
    phase_events = [p for p in payloads if p.get("type") == "phase_changed"]
    assert len(phase_events) == 0