# fischer-agentkit/src/agentkit/server/routes/chat.py
#
# 685 lines
# 28 KiB
# Python

"""Chat API routes — multi-turn conversation with Agent via REST and WebSocket."""
from __future__ import annotations
import asyncio
import hmac
import json
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Request
from pydantic import BaseModel
from agentkit.chat.skill_routing import ExecutionMode
from agentkit.core.protocol import CancellationToken
from agentkit.core.react import ReActEngine
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Every route declared below is registered under the "/chat" prefix and tagged "chat".
router = APIRouter(prefix="/chat", tags=["chat"])
# ── Request/Response schemas ──────────────────────────────────────────
class CreateSessionRequest(BaseModel):
    # Request body for creating a chat session (POST /chat/sessions).

    # Name of the agent the new session is bound to.
    agent_name: str
    # Optional free-form metadata; passed through unchanged to
    # SessionManager.create_session by the create_session route.
    metadata: dict[str, Any] | None = None
class SendMessageRequest(BaseModel):
    # Request body for posting a message to an existing session.

    # Text of the message to send to the agent.
    content: str
    # Declared sender role; defaults to "user".
    # NOTE(review): the REST handler in this file always stores the message
    # with MessageRole.USER and never reads this field — confirm this is intended.
    role: str = "user"
class SessionResponse(BaseModel):
    # Serialized view of a chat session returned by the REST endpoints.

    # Identifier of the session.
    session_id: str
    # Name of the agent the session is bound to.
    agent_name: str
    # Session status as a plain string (the enum's ``.value``).
    status: str
    # Metadata stored with the session.
    metadata: dict[str, Any]
    # Timestamps rendered with ``datetime.isoformat()`` by _session_to_response.
    created_at: str
    updated_at: str
class MessageResponse(BaseModel):
    # Serialized view of a single conversation message.

    # Identifier of the message and of the session it belongs to.
    message_id: str
    session_id: str
    # Message role as a plain string (the enum's ``.value``).
    role: str
    # Message text.
    content: str
    # Optional tool-call identifier copied from the stored message.
    tool_call_id: str | None = None
    # Optional name of the agent recorded on the stored message.
    agent_name: str | None = None
    # Creation timestamp rendered with ``datetime.isoformat()``.
    created_at: str
# ── Chat WebSocket connection manager ─────────────────────────────────
class ChatConnectionManager:
    """Registry of live WebSocket connections, grouped by session id.

    Each registered connection is stored together with the dict of pending
    reply futures that belongs to it.
    """

    def __init__(self) -> None:
        # session_id -> [(websocket, pending_replies), ...]
        self._connections: dict[str, list[tuple[WebSocket, dict[str, asyncio.Future]]]] = {}

    def add(self, session_id: str, ws: WebSocket, pending: dict[str, asyncio.Future]) -> None:
        """Register ``ws`` (with its pending-reply dict) under ``session_id``."""
        bucket = self._connections.get(session_id)
        if bucket is None:
            bucket = []
            self._connections[session_id] = bucket
        bucket.append((ws, pending))

    def remove(self, session_id: str, ws: WebSocket) -> None:
        """Unregister ``ws``; the session entry is dropped once it has no connections left."""
        current = self._connections.get(session_id)
        if current is None:
            return
        # Identity comparison: only the exact socket object is removed.
        survivors = [entry for entry in current if entry[0] is not ws]
        if survivors:
            self._connections[session_id] = survivors
        else:
            del self._connections[session_id]

    def get_connections(self, session_id: str) -> list[tuple[WebSocket, dict[str, asyncio.Future]]]:
        """Return the registered (websocket, pending) pairs, or an empty list if none."""
        try:
            return self._connections[session_id]
        except KeyError:
            return []

    async def send_json(self, session_id: str, message: dict) -> None:
        """Broadcast ``message`` to every connection of a session.

        Connections whose send raises are unregistered after the broadcast.
        """
        failed: list[WebSocket] = []
        for entry in self._connections.get(session_id, []):
            sock = entry[0]
            try:
                await sock.send_json(message)
            except Exception:
                failed.append(sock)
        for sock in failed:
            self.remove(session_id, sock)
# Module-level instance; the WebSocket endpoint below registers and
# unregisters its connections through it.
chat_manager = ChatConnectionManager()
# ── Helper ────────────────────────────────────────────────────────────
# Allow-list of event type names accepted by emit_team_event.
_VALID_TEAM_EVENT_TYPES = frozenset(
    (
        # team lifecycle
        "team_formed",
        "team_synthesis",
        "team_dissolved",
        # expert activity
        "expert_step",
        "expert_result",
        # planning
        "plan_update",
        "plan_step",
        "replanning",
        # phase progress
        "phase_started",
        "phase_completed",
        "phase_failed",
    )
)
async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) -> None:
    """Send a team event frame to the client if its type is allow-listed.

    The frame has the form ``{"type": event_type, "data": data}``. An event
    type that is not in ``_VALID_TEAM_EVENT_TYPES`` is logged with a warning
    and dropped; nothing is sent in that case.

    Documented event types and their ``data`` payloads:
        team_formed    — Team assembled with expert list and plan.
                         data: {team_id, status, experts, plan_phases, lead_expert}
        expert_step    — An expert is executing a step.
                         data: {expert_id, expert_name, expert_color, content, step}
        expert_result  — An expert produced a result.
                         data: {expert_id, expert_name, expert_color, content}
        plan_update    — Team plan phases updated.
                         data: {plan_phases}
        team_synthesis — Final synthesis from the lead expert.
                         data: {content}
        team_dissolved — Team dissolved after completion.
                         data: {team_id}

    The allow-list additionally accepts plan_step, phase_started,
    phase_completed, phase_failed and replanning; their payload shape is not
    documented here.
    """
    if event_type in _VALID_TEAM_EVENT_TYPES:
        frame = {"type": event_type, "data": data}
        await websocket.send_json(frame)
        return
    logger.warning(f"emit_team_event: invalid event_type '{event_type}'")
def _get_session_manager(request: Request) -> SessionManager:
    """Return the session manager stored on the application state of ``request``."""
    app_state = request.app.state
    return app_state.session_manager
def _session_to_response(session) -> SessionResponse:
    """Convert a session object into its API response model.

    The status enum is flattened to its ``.value`` and both timestamps are
    rendered with ``isoformat()``.
    """
    fields = {
        "session_id": session.session_id,
        "agent_name": session.agent_name,
        "status": session.status.value,
        "metadata": session.metadata,
        "created_at": session.created_at.isoformat(),
        "updated_at": session.updated_at.isoformat(),
    }
    return SessionResponse(**fields)
def _message_to_response(msg) -> MessageResponse:
    """Convert a stored message into its API response model.

    The role enum is flattened to its ``.value`` and the creation timestamp
    is rendered with ``isoformat()``.
    """
    fields = {
        "message_id": msg.message_id,
        "session_id": msg.session_id,
        "role": msg.role.value,
        "content": msg.content,
        "tool_call_id": msg.tool_call_id,
        "agent_name": msg.agent_name,
        "created_at": msg.created_at.isoformat(),
    }
    return MessageResponse(**fields)
# ── REST endpoints ────────────────────────────────────────────────────
@router.get("/sessions", response_model=list[SessionResponse])
async def list_sessions(req: Request):
    """List all chat sessions."""
    manager = _get_session_manager(req)
    all_sessions = await manager.list_sessions()
    # Convert each session to its response model, preserving the manager's order.
    return list(map(_session_to_response, all_sessions))
@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Create a new chat session bound to an Agent."""
    manager = _get_session_manager(req)
    # Agent name and metadata are forwarded unchanged from the request body.
    created = await manager.create_session(agent_name=request.agent_name, metadata=request.metadata)
    return _session_to_response(created)
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str, req: Request):
    """Get session information."""
    found = await _get_session_manager(req).get_session(session_id)
    if found is not None:
        return _session_to_response(found)
    # The manager signals an unknown id by returning None.
    raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
@router.get("/sessions/{session_id}/messages", response_model=list[MessageResponse])
async def get_messages(session_id: str, req: Request, limit: int | None = None, offset: int = 0):
    """Get conversation history for a session."""
    manager = _get_session_manager(req)
    # Existence check first: an unknown session is reported as 404 rather
    # than as an empty history.
    if await manager.get_session(session_id) is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    # limit/offset are passed straight through to the session manager.
    history = await manager.get_messages(session_id, limit=limit, offset=offset)
    return list(map(_message_to_response, history))
@router.post("/sessions/{session_id}/messages", response_model=MessageResponse)
async def send_message(session_id: str, request: SendMessageRequest, req: Request):
    """Send a message to the Agent (synchronous mode — waits for full reply)."""
    # Responses:
    #   404 — unknown session, or the session's agent is not in the agent pool
    #   400 — the session is closed
    #   500 — agent execution (or storing the reply) failed
    # NOTE(review): request.role is not consulted; the message is always
    # stored with MessageRole.USER.
    sm = _get_session_manager(req)
    session = await sm.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    if session.status == SessionStatus.CLOSED:
        raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")

    # Resolve the Agent BEFORE touching the history: a missing agent must
    # yield 404 without leaving an unanswered user message in the session.
    # The WebSocket handler in this module resolves the agent first as well.
    pool = req.app.state.agent_pool
    agent = pool.get_agent(session.agent_name)
    if agent is None:
        raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found")

    # Append user message
    await sm.append_message(
        session_id=session_id,
        role=MessageRole.USER,
        content=request.content,
    )
    # Get full conversation history for the Agent (includes the message just appended)
    chat_messages = await sm.get_chat_messages(session_id)

    # Execute the Agent
    try:
        # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
        react_engine = getattr(agent, "_react_engine", None)
        if react_engine is None:
            react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
        else:
            react_engine.reset()
        tools = agent._tool_registry.list_tools() if agent._tool_registry else []
        system_prompt = getattr(agent, "_system_prompt", None) or (
            agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
        )
        result = await react_engine.execute(
            messages=chat_messages,
            tools=tools,
            model=agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default"),
            agent_name=agent.name,
            system_prompt=system_prompt,
        )
        # Append assistant reply
        assistant_msg = await sm.append_message(
            session_id=session_id,
            role=MessageRole.ASSISTANT,
            content=result.output if hasattr(result, "output") else str(result),
            agent_name=agent.name,
        )
        return _message_to_response(assistant_msg)
    except Exception as e:
        # logger.exception records the traceback; "from e" keeps the cause
        # attached to the HTTP error for upstream handlers.
        logger.exception(f"Chat execution error for session {session_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/sessions/{session_id}")
async def close_session(session_id: str, req: Request):
    """Close a chat session."""
    closed = await _get_session_manager(req).close_session(session_id)
    # The manager returns None when no session with this id exists.
    if closed is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    return dict(status="closed", session_id=session_id)
# ── WebSocket endpoint ────────────────────────────────────────────────
@router.websocket("/ws/{session_id}")
async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
    """WebSocket endpoint for real-time chat with streaming.

    Client → Server messages:
        {"type": "message", "content": "...", "model": "..."} — Send a user message
                                                                 ("model" is an optional override)
        {"type": "reply", "request_id": "...", "content": "..."} — Answer an ask_human request
        {"type": "confirmation_reply", "confirmation_id": "...", "approved": true}
                                                     — Answer a confirmation request
        {"type": "cancel"}                           — Cancel current execution
        {"type": "ping"}                             — Heartbeat

    Server → Client messages:
        {"type": "connected", "session_id": "..."}   — Connection confirmed
        {"type": "token", "content": "..."}          — LLM token streaming
        {"type": "thinking", "content": "..."}       — Intermediate reasoning text
        {"type": "step", "data": {...}}              — ReAct step event
        {"type": "skill_match", "data": {...}}       — A skill was matched for the message
        {"type": "ask_human", "question": "...", "request_id": "..."} — Agent asks user
        {"type": "confirmation_request", "data": {...}} — Agent asks for confirmation
        {"type": "confirmation_result", "data": {...}}  — Outcome of a confirmation
        {"type": "final_answer", "content": "..."}   — Agent's final reply
        {"type": "result", "data": {"status": "cancelled"}} — Cancel acknowledged
        {"type": "error", "data": {"message": "..."}} — Error occurred
        {"type": "pong"}                             — Heartbeat response (also sent
                                                       after 300 s without client traffic)

    Expert Team events (Server → Client):
        {"type": "team_formed", "data": {...}}       — Team assembled
        {"type": "expert_step", "data": {...}}       — Expert executing step
        {"type": "expert_result", "data": {...}}     — Expert produced result
        {"type": "plan_update", "data": {...}}       — Plan phases updated
        {"type": "team_synthesis", "data": {...}}    — Final synthesis
        {"type": "team_dissolved", "data": {...}}    — Team dissolved

    When an API key is configured, the client must pass it as the ``api_key``
    query parameter; otherwise the socket is closed with code 4001.
    Frames that are not valid JSON objects are ignored.
    """
    # Authentication
    configured_api_key: str | None = None
    if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config:
        configured_api_key = websocket.app.state.server_config.api_key
    if configured_api_key is None and hasattr(websocket.app.state, "api_key"):
        configured_api_key = websocket.app.state.api_key
    if configured_api_key:
        provided = websocket.query_params.get("api_key")
        # Compare as bytes: hmac.compare_digest raises TypeError when a str
        # argument contains non-ASCII characters, which a client could
        # trigger by sending such a key.
        if not provided or not hmac.compare_digest(
            provided.encode("utf-8"), configured_api_key.encode("utf-8")
        ):
            await websocket.accept()
            await websocket.send_json({"type": "error", "data": {"message": "Invalid api_key"}})
            await websocket.close(code=4001, reason="Invalid api_key")
            return
    await websocket.accept()

    # Validate session
    sm: SessionManager = websocket.app.state.session_manager
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}})
        await websocket.close(code=1000, reason="Session not found")
        return
    if session.status == SessionStatus.CLOSED:
        await websocket.send_json({"type": "error", "data": {"message": "Session is closed"}})
        await websocket.close(code=1000, reason="Session closed")
        return

    # Track pending replies for AskHumanTool and confirmations
    pending_replies: dict[str, asyncio.Future] = {}
    pending_confirmations: dict[str, asyncio.Future] = {}
    chat_manager.add(session_id, websocket, pending_replies)

    # Per-connection concurrency guard to prevent unlimited task creation (DoS mitigation)
    _MAX_CONCURRENT_TASKS = 4
    # Running message task -> the CancellationToken handed to that task, so a
    # "cancel" frame can reach the executions that are actually in flight.
    active_tasks: dict[asyncio.Task, CancellationToken] = {}

    def _on_task_done(finished: asyncio.Task) -> None:
        """Forget a finished task and surface its exception in the log."""
        active_tasks.pop(finished, None)
        if finished.cancelled():
            return
        error = finished.exception()
        if error is not None:
            logger.error(f"Chat task failed for session {session_id}: {error}")

    try:
        await websocket.send_json({"type": "connected", "session_id": session_id})
        # Listen for client messages
        while True:
            try:
                raw = await asyncio.wait_for(websocket.receive_text(), timeout=300.0)
            except asyncio.TimeoutError:
                # Idle keep-alive: nothing received for 300 s.
                await websocket.send_json({"type": "pong"})
                continue
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(msg, dict):
                # Valid JSON but not an object (e.g. a list or a number): ignore
                # it instead of failing on msg.get below.
                continue
            msg_type = msg.get("type")

            if msg_type == "message":
                content = msg.get("content", "")
                model = msg.get("model")  # Optional model override from frontend
                # Drop tasks that finished but whose done-callback has not run
                # yet. Snapshot first: removing entries while iterating the
                # same container raises RuntimeError.
                for finished in [t for t in active_tasks if t.done()]:
                    active_tasks.pop(finished, None)
                if len(active_tasks) >= _MAX_CONCURRENT_TASKS:
                    await websocket.send_json({
                        "type": "error",
                        "data": {"message": "Too many concurrent requests. Please wait for the current task to complete."},
                    })
                    continue
                # Create a fresh CancellationToken for each message
                message_token = CancellationToken()
                # Run in background task so the WebSocket receive loop stays free
                # to process confirmation_reply / reply messages while the agent
                # is waiting for user confirmation (otherwise deadlock).
                task = asyncio.create_task(
                    _handle_chat_message(
                        websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations,
                        model_override=model,
                    )
                )
                active_tasks[task] = message_token
                task.add_done_callback(_on_task_done)

            elif msg_type == "reply":
                # Reply to AskHumanTool
                request_id = msg.get("request_id")
                reply_content = msg.get("content", "")
                reply_future = (
                    pending_replies.get(request_id)
                    if isinstance(request_id, str) and request_id
                    else None
                )
                # A duplicate or late reply must not raise InvalidStateError.
                if reply_future is not None and not reply_future.done():
                    reply_future.set_result(reply_content)

            elif msg_type == "confirmation_reply":
                # Reply to confirmation request
                confirmation_id = msg.get("confirmation_id")
                approved = msg.get("approved", False)
                logger.info(f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}")
                confirmation_future = (
                    pending_confirmations.get(confirmation_id)
                    if isinstance(confirmation_id, str) and confirmation_id
                    else None
                )
                if confirmation_future is None:
                    logger.warning(f"Confirmation {confirmation_id!r} not found in pending_confirmations")
                elif confirmation_future.done():
                    # Already answered, timed out or cancelled; set_result would raise.
                    logger.warning(f"Confirmation {confirmation_id!r} already resolved, ignoring reply")
                else:
                    confirmation_future.set_result(approved)
                    logger.info(f"Confirmation {confirmation_id} set_result({approved})")

            elif msg_type == "cancel":
                # Cancel every in-flight execution through the token it was given.
                for token in list(active_tasks.values()):
                    token.cancel()
                await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        logger.debug(f"Chat WebSocket disconnected for session {session_id}")
    except Exception as e:
        logger.error(f"Chat WebSocket error for session {session_id}: {e}")
        try:
            await websocket.send_json({"type": "error", "data": {"message": str(e)}})
        except Exception:
            # Best effort only: the socket may already be unusable.
            pass
    finally:
        # Stop executions that belong to this connection; their output has
        # nowhere to go once the socket is gone.
        for task, token in list(active_tasks.items()):
            token.cancel()
            if not task.done():
                task.cancel()
        # Clean up pending futures
        for fut in list(pending_replies.values()):
            if not fut.done():
                fut.cancel()
        for fut in list(pending_confirmations.values()):
            if not fut.done():
                fut.cancel()
        chat_manager.remove(session_id, websocket)
async def _send_error_best_effort(websocket: WebSocket, session_id: str, message: str) -> None:
    """Try to deliver an error frame; log instead of raising if the send fails."""
    try:
        await websocket.send_json({"type": "error", "data": {"message": message}})
    except Exception as send_error:
        logger.debug(f"Could not deliver error to client for session {session_id}: {send_error}")


async def _handle_chat_message(
    websocket: WebSocket,
    session_id: str,
    content: str,
    sm: SessionManager,
    cancellation_token: CancellationToken,
    pending_replies: dict[str, asyncio.Future],
    pending_confirmations: dict[str, asyncio.Future] | None = None,
    model_override: str | None = None,
) -> None:
    """Handle a user message: append to session, execute Agent, stream events.

    Uses SimpleRouter for minimal routing: @skill prefix + greeting regex + REACT.

    Args:
        websocket: Connection the events are streamed to.
        session_id: Session the message belongs to.
        content: Raw user text as received from the client.
        sm: Session manager used to read and append messages.
        cancellation_token: Token passed to the streaming execution.
        pending_replies: Shared pending-reply futures of the connection
            (not read by this function).
        pending_confirmations: Dict shared with the receive loop; confirmation
            futures are registered here so the loop can resolve them.
        model_override: Optional model name that replaces the routed model.

    Failures during execution are reported to the client as an "error" frame
    (best effort) and logged; they are not raised to the caller.
    """
    from agentkit.chat.simple_router import SimpleRouter

    # Resolve Agent first (needed for default tools/prompt)
    pool = websocket.app.state.agent_pool
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json({"type": "error", "data": {"message": "Session lost"}})
        return
    agent = pool.get_agent(session.agent_name)
    if agent is None:
        await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}})
        return

    # Default execution parameters from agent
    default_tools = agent._tool_registry.list_tools() if agent._tool_registry else []
    default_system_prompt = getattr(agent, "_system_prompt", None) or (
        agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
    )
    default_model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default")

    # Resolve skill routing using SimpleRouter
    skill_registry = getattr(websocket.app.state, "skill_registry", None)
    simple_router: SimpleRouter = websocket.app.state.simple_router
    routing = await simple_router.route(
        content=content,
        skill_registry=skill_registry,
        default_tools=default_tools,
        default_system_prompt=default_system_prompt,
        default_model=default_model,
        default_agent_name=agent.name,
    )
    # Debug: log tools that will be passed to ReActEngine
    tool_names = [t.name for t in routing.tools]
    logger.info(f"Chat {session_id}: resolved {len(routing.tools)} tools: {tool_names}, model={routing.model}, skill={routing.skill_name}")

    # Apply model override from frontend selector
    if model_override:
        routing.model = model_override

    # Notify frontend about skill match
    if routing.matched:
        await websocket.send_json({
            "type": "skill_match",
            "data": {
                "skill": routing.skill_name,
                "method": routing.match_method,
                "confidence": routing.match_confidence,
            },
        })

    # Append user message (use clean_content if @skill: prefix was stripped)
    await sm.append_message(session_id=session_id, role=MessageRole.USER, content=routing.clean_content)
    # Get full conversation history
    chat_messages = await sm.get_chat_messages(session_id)

    # Handle DIRECT_CHAT: direct LLM call, no ReAct loop
    if routing.execution_mode == ExecutionMode.DIRECT_CHAT:
        direct_messages = []
        if routing.system_prompt:
            direct_messages.append({"role": "system", "content": routing.system_prompt})
        direct_messages.extend(chat_messages)
        try:
            response = await websocket.app.state.llm_gateway.chat(
                messages=direct_messages,
                model=routing.model or "default",
                agent_name=agent.name,
                task_type="chat",
            )
            final_content = response.content or ""
            if not final_content or not final_content.strip():
                from agentkit.core.fallback import EMPTY_LLM_RESPONSE
                final_content = EMPTY_LLM_RESPONSE
            await websocket.send_json({
                "type": "final_answer",
                "content": final_content,
                "is_final": True,
            })
            await sm.append_message(
                session_id=session_id,
                role=MessageRole.ASSISTANT,
                content=final_content,
                agent_name=agent.name,
            )
        except Exception as e:
            logger.exception(f"Chat DIRECT_CHAT error for session {session_id}: {e}")
            # The failure may be the socket itself, so reporting must not raise.
            await _send_error_best_effort(websocket, session_id, str(e)[:200])
        return

    # Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
    # currently fall back to REACT with a warning.
    if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
        logger.warning(
            f"Execution mode {routing.execution_mode.value} not yet supported "
            f"in chat WebSocket, falling back to REACT"
        )

    # Execute Agent with streaming
    # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
    react_engine = getattr(agent, "_react_engine", None)
    if react_engine is None:
        react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
    else:
        react_engine.reset()

    # Create confirmation handler that sends request to frontend and waits for reply
    # Use the same dict object — do NOT use `or {}` because an empty dict is falsy
    # and would create a new dict, breaking the shared state with the WS loop.
    _pending_confirmations = pending_confirmations if pending_confirmations is not None else {}

    async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool:
        """Send confirmation request to frontend via WebSocket and wait for user reply.

        Returns the user's decision, or False on timeout (5 minutes) or cancellation.
        """
        # Register the Future BEFORE sending the request, so a reply that
        # arrives while the send is still in progress can already be matched
        # by the receive loop.
        loop = asyncio.get_running_loop()
        future: asyncio.Future[bool] = loop.create_future()
        _pending_confirmations[confirmation_id] = future
        try:
            # Send confirmation request to frontend
            await websocket.send_json({
                "type": "confirmation_request",
                "data": {
                    "confirmation_id": confirmation_id,
                    "command": command,
                    "reason": reason,
                },
            })
            logger.info(f"Confirmation request {confirmation_id} sent, waiting for reply")
            try:
                # Wait up to 5 minutes for user confirmation
                result = await asyncio.wait_for(future, timeout=300.0)
                logger.info(f"Confirmation request {confirmation_id} resolved: {result}")
                # Immediately notify frontend of the result so the card updates
                # without waiting for the tool to re-execute
                await websocket.send_json({
                    "type": "confirmation_result",
                    "data": {"confirmation_id": confirmation_id, "approved": result},
                })
                return result
            except asyncio.TimeoutError:
                logger.warning(f"Confirmation request {confirmation_id} timed out")
                return False
            except asyncio.CancelledError:
                # Raised e.g. when the receive loop cancels the pending future
                # on disconnect; treated as "not approved".
                logger.warning(f"Confirmation request {confirmation_id} cancelled")
                return False
        finally:
            _pending_confirmations.pop(confirmation_id, None)

    logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}")

    final_content = ""
    # Tokens are held back until the next event shows whether they were
    # intermediate reasoning or part of the answer.
    token_buffer: list[str] = []

    async def _flush_buffer(frame_type: str) -> None:
        """Send the buffered tokens as one frame of ``frame_type`` and empty the buffer."""
        if token_buffer:
            buffered_text = "".join(token_buffer)
            token_buffer.clear()
            await websocket.send_json({"type": frame_type, "content": buffered_text})

    async def _send_step(step_event) -> None:
        """Forward an engine event to the client as a generic "step" frame."""
        await websocket.send_json({
            "type": "step",
            "data": {
                "event_type": step_event.event_type,
                "step": step_event.step,
                "data": step_event.data,
            },
        })

    try:
        async for event in react_engine.execute_stream(
            messages=chat_messages,
            tools=routing.tools,
            model=routing.model,
            agent_name=routing.agent_name,
            system_prompt=routing.system_prompt,
            cancellation_token=cancellation_token,
            confirmation_handler=_confirmation_handler,
        ):
            if event.event_type == "final_answer":
                # Flush any buffered tokens as a single write
                await _flush_buffer("token")
                # Then send final answer
                final_content = event.data.get("output", "")
                if not final_content or not final_content.strip():
                    from agentkit.core.fallback import EMPTY_LLM_RESPONSE
                    final_content = EMPTY_LLM_RESPONSE
                await websocket.send_json({
                    "type": "final_answer",
                    "content": final_content,
                    "is_final": True,
                })
            elif event.event_type == "token":
                # Buffer tokens instead of sending immediately
                token_buffer.append(event.data.get("content", ""))
            elif event.event_type == "thinking":
                # If we have buffered tokens, convert them to a thinking event
                await _flush_buffer("thinking")
                # Also send the thinking event content
                thinking_msg = event.data.get("message", "")
                if thinking_msg:
                    await websocket.send_json({"type": "thinking", "content": thinking_msg})
            elif event.event_type == "tool_call":
                # Convert buffered tokens to thinking (they were "thinking" text before tool call)
                await _flush_buffer("thinking")
                await _send_step(event)
            elif event.event_type == "confirmation_request":
                # Already sent to the client by _confirmation_handler.
                pass
            elif event.event_type == "confirmation_result":
                await websocket.send_json({
                    "type": "confirmation_result",
                    "data": event.data,
                })
            else:
                await _send_step(event)

        # Append assistant reply to session
        if final_content:
            await sm.append_message(
                session_id=session_id,
                role=MessageRole.ASSISTANT,
                content=final_content,
                agent_name=agent.name,
            )
    except Exception as e:
        logger.exception(f"Chat execution error for session {session_id}: {e}")
        # Show meaningful error to user, but avoid leaking full stack traces
        error_msg = str(e)
        # Truncate very long error messages
        if len(error_msg) > 200:
            error_msg = error_msg[:200] + "..."
        # The failure may be the socket itself, so reporting must not raise.
        await _send_error_best_effort(websocket, session_id, error_msg)