feat(U4): G6 PLAN_EXEC wiring at chat WebSocket path
- PLAN_EXEC branch builds PhasePolicy from ServerConfig.plan_exec - Empty config → default_policy(); enabled=False → falls back to REACT - Bad config → error event sent, returns early (no engine constructed) - ReActEngine created with phase_policy; AdvancePhaseTool registered - phase_changed events emitted on phase transitions (PLAN_EXEC only) - REST send_message with execution_mode=plan_exec → HTTP 501 (KTD4) - REWOO/REFLEXION/TEAM_COLLAB still fall back to REACT (no regression) - 9 unit tests covering REST 501, characterization, happy path, edge cases, error path, phase_changed events
This commit is contained in:
parent
6efd5957f6
commit
da4eef1349
|
|
@ -25,11 +25,13 @@ from fastapi.responses import FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from agentkit.chat.skill_routing import ExecutionMode
|
from agentkit.chat.skill_routing import ExecutionMode
|
||||||
|
from agentkit.core.phase import PhasePolicy, default_policy, policy_from_config
|
||||||
from agentkit.core.protocol import CancellationToken
|
from agentkit.core.protocol import CancellationToken
|
||||||
from agentkit.core.react import ReActEngine
|
from agentkit.core.react import ReActEngine
|
||||||
from agentkit.server._fallback_chain import execute_with_fallback_chain
|
from agentkit.server._fallback_chain import execute_with_fallback_chain
|
||||||
from agentkit.session.manager import SessionManager
|
from agentkit.session.manager import SessionManager
|
||||||
from agentkit.session.models import MessageRole, SessionStatus
|
from agentkit.session.models import MessageRole, SessionStatus
|
||||||
|
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -47,6 +49,8 @@ class CreateSessionRequest(BaseModel):
|
||||||
class SendMessageRequest(BaseModel):
|
class SendMessageRequest(BaseModel):
|
||||||
content: str
|
content: str
|
||||||
role: str = "user"
|
role: str = "user"
|
||||||
|
# Optional execution mode override. "plan_exec" → 501 (KTD4: WebSocket only).
|
||||||
|
execution_mode: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class SessionResponse(BaseModel):
|
class SessionResponse(BaseModel):
|
||||||
|
|
@ -583,6 +587,13 @@ async def send_message(session_id: str, request: SendMessageRequest, req: Reques
|
||||||
if session.status == SessionStatus.CLOSED:
|
if session.status == SessionStatus.CLOSED:
|
||||||
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
|
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
|
||||||
|
|
||||||
|
# KTD4: PLAN_EXEC is wired only at the WebSocket path. REST raises 501.
|
||||||
|
if request.execution_mode == "plan_exec":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=501,
|
||||||
|
detail="PLAN_EXEC via REST not yet supported; use WebSocket",
|
||||||
|
)
|
||||||
|
|
||||||
# Append user message
|
# Append user message
|
||||||
await sm.append_message(
|
await sm.append_message(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
|
@ -1079,16 +1090,63 @@ async def _handle_chat_message(
|
||||||
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
|
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
|
||||||
return
|
return
|
||||||
|
|
||||||
# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
|
# U4/G6: PLAN_EXEC — build PhasePolicy from server config (KTD4: WebSocket only).
|
||||||
# currently fall back to REACT with a warning.
|
# KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC — phase policy and
|
||||||
if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
|
# fallback chain are mutually exclusive. PLAN_EXEC uses its own engine.
|
||||||
|
phase_policy: PhasePolicy | None = None
|
||||||
|
if routing.execution_mode == ExecutionMode.PLAN_EXEC:
|
||||||
|
server_config = getattr(websocket.app.state, "server_config", None)
|
||||||
|
plan_exec_cfg = getattr(server_config, "plan_exec", None) or {}
|
||||||
|
|
||||||
|
if plan_exec_cfg.get("enabled", True) is False:
|
||||||
|
# Explicit opt-out → fall back to REACT.
|
||||||
|
logger.info(
|
||||||
|
"PLAN_EXEC disabled by config (plan_exec.enabled=False), "
|
||||||
|
"falling back to REACT for session %s",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
phase_policy = policy_from_config(plan_exec_cfg)
|
||||||
|
if phase_policy is None:
|
||||||
|
# Empty config (no `plan_exec:` section) → use KTD5 defaults.
|
||||||
|
phase_policy = default_policy()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"PLAN_EXEC phase policy construction failed for session %s: %s",
|
||||||
|
session_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": "error", "data": {"message": f"phase policy error: {e}"}}
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle advanced execution modes: REWOO/REFLEXION/TEAM_COLLAB
|
||||||
|
# still fall back to REACT with a warning. PLAN_EXEC is handled above.
|
||||||
|
if routing.execution_mode not in (
|
||||||
|
ExecutionMode.REACT,
|
||||||
|
ExecutionMode.SKILL_REACT,
|
||||||
|
ExecutionMode.PLAN_EXEC,
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Execution mode {routing.execution_mode.value} not yet supported "
|
f"Execution mode {routing.execution_mode.value} not yet supported "
|
||||||
f"in chat WebSocket, falling back to REACT"
|
f"in chat WebSocket, falling back to REACT"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute Agent with streaming
|
# Execute Agent with streaming
|
||||||
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
|
# Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization).
|
||||||
|
# PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the
|
||||||
|
# agent's _react_engine — it has no policy).
|
||||||
|
if phase_policy is not None:
|
||||||
|
react_engine = ReActEngine(
|
||||||
|
llm_gateway=websocket.app.state.llm_gateway,
|
||||||
|
phase_policy=phase_policy,
|
||||||
|
)
|
||||||
|
# Register AdvancePhaseTool bound to this engine (LLM's escape hatch).
|
||||||
|
advance_phase_tool = AdvancePhaseTool(engine=react_engine)
|
||||||
|
routing.tools = list(routing.tools) + [advance_phase_tool]
|
||||||
|
else:
|
||||||
react_engine = getattr(agent, "_react_engine", None)
|
react_engine = getattr(agent, "_react_engine", None)
|
||||||
if react_engine is None:
|
if react_engine is None:
|
||||||
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
|
||||||
|
|
@ -1149,6 +1207,9 @@ async def _handle_chat_message(
|
||||||
try:
|
try:
|
||||||
final_content = ""
|
final_content = ""
|
||||||
token_buffer: list[str] = []
|
token_buffer: list[str] = []
|
||||||
|
# Track phase transitions for phase_changed events (PLAN_EXEC only).
|
||||||
|
# For non-PLAN_EXEC modes, current_phase is always None → no events.
|
||||||
|
prev_phase = react_engine.current_phase
|
||||||
async for event in react_engine.execute_stream(
|
async for event in react_engine.execute_stream(
|
||||||
messages=chat_messages,
|
messages=chat_messages,
|
||||||
tools=routing.tools,
|
tools=routing.tools,
|
||||||
|
|
@ -1226,6 +1287,22 @@ async def _handle_chat_message(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# U4/G6: emit phase_changed event when the phase state machine
|
||||||
|
# transitions (PLAN_EXEC only). For non-PLAN_EXEC modes,
|
||||||
|
# current_phase is always None → this branch never fires.
|
||||||
|
curr_phase = react_engine.current_phase
|
||||||
|
if curr_phase != prev_phase:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "phase_changed",
|
||||||
|
"data": {
|
||||||
|
"phase": curr_phase.value if curr_phase else None,
|
||||||
|
"previous": prev_phase.value if prev_phase else None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
prev_phase = curr_phase
|
||||||
|
|
||||||
# Append assistant reply to session
|
# Append assistant reply to session
|
||||||
if final_content:
|
if final_content:
|
||||||
await sm.append_message(
|
await sm.append_message(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,531 @@
|
||||||
|
"""Unit tests for PLAN_EXEC wiring at chat.py WebSocket path (G6, U4).
|
||||||
|
|
||||||
|
Per plan U4 Execution note: characterization-first — verify that existing
|
||||||
|
REWOO/REFLEXION/TEAM_COLLAB modes still fall back to REACT with the warning
|
||||||
|
(no regression). Then add PLAN_EXEC wiring tests.
|
||||||
|
|
||||||
|
KTD4: PLAN_EXEC is wired only at the WebSocket path; REST raises HTTP 501.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||||||
|
from agentkit.core.phase import PhaseState
|
||||||
|
from agentkit.tools.advance_phase import AdvancePhaseTool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app_with_chat():
|
||||||
|
"""Create a FastAPI app with Chat routes and mocked dependencies."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from agentkit.server.routes.chat import router
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
|
||||||
|
from agentkit.session.manager import SessionManager
|
||||||
|
from agentkit.session.store import InMemorySessionStore
|
||||||
|
|
||||||
|
app.state.session_manager = SessionManager(store=InMemorySessionStore())
|
||||||
|
app.state.llm_gateway = MagicMock()
|
||||||
|
app.state.agent_pool = MagicMock()
|
||||||
|
app.state.server_config = MagicMock()
|
||||||
|
app.state.server_config.api_key = None
|
||||||
|
app.state.server_config.plan_exec = {}
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app_with_chat):
|
||||||
|
return TestClient(app_with_chat)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_routing(
|
||||||
|
execution_mode: ExecutionMode = ExecutionMode.REACT,
|
||||||
|
tools: list | None = None,
|
||||||
|
) -> SkillRoutingResult:
|
||||||
|
"""Build a minimal SkillRoutingResult for testing."""
|
||||||
|
return SkillRoutingResult(
|
||||||
|
execution_mode=execution_mode,
|
||||||
|
tools=tools or [],
|
||||||
|
clean_content="test message",
|
||||||
|
model="default",
|
||||||
|
agent_name="test-agent",
|
||||||
|
system_prompt=None,
|
||||||
|
skill_name=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_websocket_mock(app) -> MagicMock:
|
||||||
|
"""Build a mock WebSocket with app.state and async send_json."""
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.app = app
|
||||||
|
ws.send_json = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
def _make_agent_mock() -> MagicMock:
|
||||||
|
"""Build a mock Agent with _tool_registry and _react_engine."""
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.name = "test-agent"
|
||||||
|
agent._tool_registry = MagicMock()
|
||||||
|
agent._tool_registry.list_tools.return_value = []
|
||||||
|
agent._system_prompt = None
|
||||||
|
# _react_engine is None to force the code path that creates a new engine
|
||||||
|
agent._react_engine = None
|
||||||
|
agent.get_model.return_value = "default"
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session_manager_mock() -> MagicMock:
|
||||||
|
"""Build a mock SessionManager with async methods."""
|
||||||
|
sm = MagicMock()
|
||||||
|
# get_session returns a mock session with agent_name="test-agent"
|
||||||
|
session = MagicMock()
|
||||||
|
session.agent_name = "test-agent"
|
||||||
|
session.status = "active"
|
||||||
|
sm.get_session = AsyncMock(return_value=session)
|
||||||
|
sm.get_chat_messages = AsyncMock(return_value=[])
|
||||||
|
sm.append_message = AsyncMock()
|
||||||
|
return sm
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_routing(app, routing: SkillRoutingResult, agent: MagicMock) -> None:
|
||||||
|
"""Wire up app.state so _handle_chat_message finds the right routing."""
|
||||||
|
app.state.agent_pool.get_agent.return_value = agent
|
||||||
|
app.state.request_preprocessor = MagicMock()
|
||||||
|
app.state.request_preprocessor.preprocess = AsyncMock(return_value=routing)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# REST — PLAN_EXEC raises 501 (KTD4)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRestPlanExec501:
|
||||||
|
def test_rest_plan_exec_returns_501(self, client):
|
||||||
|
"""REST send_message with execution_mode=plan_exec → 501."""
|
||||||
|
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
||||||
|
session_id = create_resp.json()["session_id"]
|
||||||
|
|
||||||
|
msg_resp = client.post(
|
||||||
|
f"/api/v1/chat/sessions/{session_id}/messages",
|
||||||
|
json={"content": "Hello", "execution_mode": "plan_exec"},
|
||||||
|
)
|
||||||
|
assert msg_resp.status_code == 501
|
||||||
|
assert "PLAN_EXEC via REST not yet supported" in msg_resp.json()["detail"]
|
||||||
|
|
||||||
|
def test_rest_react_mode_still_works(self, client):
|
||||||
|
"""REST send_message without execution_mode doesn't 501."""
|
||||||
|
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
||||||
|
session_id = create_resp.json()["session_id"]
|
||||||
|
|
||||||
|
# No execution_mode field → should NOT trigger 501.
|
||||||
|
msg_resp = client.post(
|
||||||
|
f"/api/v1/chat/sessions/{session_id}/messages",
|
||||||
|
json={"content": "Hello"},
|
||||||
|
)
|
||||||
|
assert msg_resp.status_code != 501
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Characterization — REWOO still falls back to REACT (no regression)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rewoo_still_falls_back_to_react_without_phase_policy(app_with_chat):
|
||||||
|
"""Characterization: REWOO via WebSocket → no phase_policy (falls back to REACT)."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.REWOO)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
captured_engine_kwargs: dict = {}
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_engine_kwargs.update(kwargs)
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
return
|
||||||
|
yield # async generator marker
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# REWOO should NOT build a phase_policy
|
||||||
|
assert captured_engine_kwargs.get("phase_policy") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Happy path — PLAN_EXEC builds phase policy + registers AdvancePhaseTool
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_exec_builds_phase_policy_and_registers_advance_phase_tool(
|
||||||
|
app_with_chat,
|
||||||
|
):
|
||||||
|
"""PLAN_EXEC via WebSocket → engine with phase_policy, AdvancePhaseTool registered."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "test"}])
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
captured_engine: list = []
|
||||||
|
captured_tools: list = []
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = (
|
||||||
|
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
captured_tools.extend(kwargs.get("tools", []))
|
||||||
|
captured_engine.append(self)
|
||||||
|
return
|
||||||
|
yield
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(captured_engine) == 1
|
||||||
|
engine = captured_engine[0]
|
||||||
|
assert engine._phase_policy is not None
|
||||||
|
assert engine._current_phase == PhaseState.PLANNING
|
||||||
|
# AdvancePhaseTool was registered in the tools list
|
||||||
|
assert any(isinstance(t, AdvancePhaseTool) for t in captured_tools)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_exec_empty_config_uses_default_policy(app_with_chat):
|
||||||
|
"""Edge: plan_exec config absent (empty dict) → default_policy() used."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
app_with_chat.state.server_config.plan_exec = {}
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
captured_policy: list = []
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_policy.append(kwargs.get("phase_policy"))
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = (
|
||||||
|
kwargs.get("phase_policy").start_phase if kwargs.get("phase_policy") else None
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
return
|
||||||
|
yield
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(captured_policy) == 1
|
||||||
|
assert captured_policy[0] is not None
|
||||||
|
# Default policy: PLANNING allows search but not write_file
|
||||||
|
assert "search" in captured_policy[0].whitelist[PhaseState.PLANNING]
|
||||||
|
assert "write_file" not in captured_policy[0].whitelist[PhaseState.PLANNING]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_exec_disabled_falls_back_to_react(app_with_chat):
|
||||||
|
"""Edge: plan_exec.enabled=False → falls back to REACT (no phase_policy)."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
app_with_chat.state.server_config.plan_exec = {"enabled": False}
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
captured_engine_kwargs: dict = {}
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_engine_kwargs.update(kwargs)
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
return
|
||||||
|
yield
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# enabled=False → no phase_policy (falls back to REACT)
|
||||||
|
assert captured_engine_kwargs.get("phase_policy") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Error path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_exec_bad_config_sends_error_and_returns(app_with_chat):
|
||||||
|
"""Error: phase policy construction fails → error event sent, returns early."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
app_with_chat.state.server_config.plan_exec = {"start_phase": "invalid_phase_name"}
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
engine_constructor_called = []
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
engine_constructor_called.append(kwargs)
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
yield
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="test",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||||
|
error_messages = [m for m in sent_messages if m.get("type") == "error"]
|
||||||
|
assert len(error_messages) == 1
|
||||||
|
assert "phase policy error" in error_messages[0]["data"]["message"]
|
||||||
|
# Engine constructor was NOT called (returned early)
|
||||||
|
assert len(engine_constructor_called) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# phase_changed event emission
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_phase_changed_event_emitted_on_transition(app_with_chat):
|
||||||
|
"""phase_changed event sent when current_phase changes during execute_stream."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
app_with_chat.state.server_config.plan_exec = {}
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.PLAN_EXEC)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "go"}])
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._phase_policy = kwargs.get("phase_policy")
|
||||||
|
self._current_phase = PhaseState.PLANNING
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return self._current_phase
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
from agentkit.core.react import ReActEvent
|
||||||
|
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="tool_call",
|
||||||
|
step=1,
|
||||||
|
data={"tool": "search", "output": "ok"},
|
||||||
|
)
|
||||||
|
# Simulate phase transition (as if AdvancePhaseTool was called)
|
||||||
|
self._current_phase = PhaseState.BUILDING
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="final_answer",
|
||||||
|
step=2,
|
||||||
|
data={"output": "done"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="go",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||||
|
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
|
||||||
|
assert len(phase_events) == 1
|
||||||
|
assert phase_events[0]["data"]["phase"] == "building"
|
||||||
|
assert phase_events[0]["data"]["previous"] == "planning"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_phase_changed_event_when_not_plan_exec(app_with_chat):
|
||||||
|
"""Characterization: REACT mode → no phase_changed events."""
|
||||||
|
from agentkit.server.routes import chat as chat_module
|
||||||
|
|
||||||
|
agent = _make_agent_mock()
|
||||||
|
routing = _make_routing(execution_mode=ExecutionMode.REACT)
|
||||||
|
_setup_routing(app_with_chat, routing, agent)
|
||||||
|
|
||||||
|
sm = _make_session_manager_mock()
|
||||||
|
sm.get_chat_messages = AsyncMock(return_value=[{"role": "user", "content": "hi"}])
|
||||||
|
ws = _make_websocket_mock(app_with_chat)
|
||||||
|
|
||||||
|
class _StubEngine:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self._phase_policy = None
|
||||||
|
self._current_phase = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_phase(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_stream(self, **kwargs):
|
||||||
|
from agentkit.core.react import ReActEvent
|
||||||
|
|
||||||
|
yield ReActEvent(event_type="final_answer", step=1, data={"output": "hi"})
|
||||||
|
|
||||||
|
with pytest.MonkeyPatch().context() as mp:
|
||||||
|
mp.setattr(chat_module, "ReActEngine", _StubEngine)
|
||||||
|
|
||||||
|
await chat_module._handle_chat_message(
|
||||||
|
websocket=ws,
|
||||||
|
session_id="test-session",
|
||||||
|
content="hi",
|
||||||
|
sm=sm,
|
||||||
|
cancellation_token=MagicMock(),
|
||||||
|
pending_replies={},
|
||||||
|
pending_confirmations=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
sent_messages = [call.args[0] for call in ws.send_json.call_args_list]
|
||||||
|
phase_events = [m for m in sent_messages if m.get("type") == "phase_changed"]
|
||||||
|
assert len(phase_events) == 0
|
||||||
Loading…
Reference in New Issue