"""Chat API routes — multi-turn conversation with Agent via REST and WebSocket.""" from __future__ import annotations import asyncio import json import logging from typing import Any from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Request from pydantic import BaseModel from agentkit.core.protocol import CancellationToken from agentkit.core.react import ReActEngine from agentkit.session.manager import SessionManager from agentkit.session.models import MessageRole, SessionStatus logger = logging.getLogger(__name__) router = APIRouter(prefix="/chat", tags=["chat"]) # ── Request/Response schemas ────────────────────────────────────────── class CreateSessionRequest(BaseModel): agent_name: str metadata: dict[str, Any] | None = None class SendMessageRequest(BaseModel): content: str role: str = "user" class SessionResponse(BaseModel): session_id: str agent_name: str status: str metadata: dict[str, Any] created_at: str updated_at: str class MessageResponse(BaseModel): message_id: str session_id: str role: str content: str tool_call_id: str | None = None agent_name: str | None = None created_at: str # ── Chat WebSocket connection manager ───────────────────────────────── class ChatConnectionManager: """Track active WebSocket connections per session_id.""" def __init__(self) -> None: # session_id -> list of (websocket, pending_replies) self._connections: dict[str, list[tuple[WebSocket, dict[str, asyncio.Future]]]] = {} def add(self, session_id: str, ws: WebSocket, pending: dict[str, asyncio.Future]) -> None: self._connections.setdefault(session_id, []).append((ws, pending)) def remove(self, session_id: str, ws: WebSocket) -> None: conns = self._connections.get(session_id) if conns is None: return self._connections[session_id] = [(w, p) for w, p in conns if w is not ws] if not self._connections[session_id]: del self._connections[session_id] def get_connections(self, session_id: str) -> list[tuple[WebSocket, dict[str, asyncio.Future]]]: return self._connections.get(session_id, []) async def send_json(self, session_id: str, message: dict) -> None: """Send a JSON message to all connections for a session.""" conns = self._connections.get(session_id, []) stale: list[WebSocket] = [] for ws, _ in conns: try: await ws.send_json(message) except Exception: stale.append(ws) for ws in stale: self.remove(session_id, ws) chat_manager = ChatConnectionManager() # ── Helper ──────────────────────────────────────────────────────────── def _get_session_manager(request: Request) -> SessionManager: return request.app.state.session_manager def _session_to_response(session) -> SessionResponse: return SessionResponse( session_id=session.session_id, agent_name=session.agent_name, status=session.status.value, metadata=session.metadata, created_at=session.created_at.isoformat(), updated_at=session.updated_at.isoformat(), ) def _message_to_response(msg) -> MessageResponse: return MessageResponse( message_id=msg.message_id, session_id=msg.session_id, role=msg.role.value, content=msg.content, tool_call_id=msg.tool_call_id, agent_name=msg.agent_name, created_at=msg.created_at.isoformat(), ) # ── REST endpoints ──────────────────────────────────────────────────── @router.get("/sessions", response_model=list[SessionResponse]) async def list_sessions(req: Request): """List all chat sessions.""" sm = _get_session_manager(req) sessions = await sm.list_sessions() return [_session_to_response(s) for s in sessions] @router.post("/sessions", response_model=SessionResponse) async def create_session(request: CreateSessionRequest, req: Request): """Create a new chat session bound to an Agent.""" sm = _get_session_manager(req) session = await sm.create_session( agent_name=request.agent_name, metadata=request.metadata, ) return _session_to_response(session) @router.get("/sessions/{session_id}", response_model=SessionResponse) async def get_session(session_id: str, req: Request): """Get session information.""" sm = _get_session_manager(req) session = await sm.get_session(session_id) if session is None: raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") return _session_to_response(session) @router.get("/sessions/{session_id}/messages", response_model=list[MessageResponse]) async def get_messages(session_id: str, req: Request, limit: int | None = None, offset: int = 0): """Get conversation history for a session.""" sm = _get_session_manager(req) session = await sm.get_session(session_id) if session is None: raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") messages = await sm.get_messages(session_id, limit=limit, offset=offset) return [_message_to_response(m) for m in messages] @router.post("/sessions/{session_id}/messages", response_model=MessageResponse) async def send_message(session_id: str, request: SendMessageRequest, req: Request): """Send a message to the Agent (synchronous mode — waits for full reply).""" sm = _get_session_manager(req) session = await sm.get_session(session_id) if session is None: raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") if session.status == SessionStatus.CLOSED: raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed") # Append user message user_msg = await sm.append_message( session_id=session_id, role=MessageRole.USER, content=request.content, ) # Get full conversation history for the Agent chat_messages = await sm.get_chat_messages(session_id) # Resolve the Agent pool = req.app.state.agent_pool agent = pool.get_agent(session.agent_name) if agent is None: raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found") # Execute the Agent try: react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway) tools = agent._tool_registry.list_tools() if agent._tool_registry else [] system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None) result = await react_engine.execute( messages=chat_messages, tools=tools, model=agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default"), agent_name=agent.name, system_prompt=system_prompt, ) # Append assistant reply assistant_msg = await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, content=result.output if hasattr(result, "output") else str(result), agent_name=agent.name, ) return _message_to_response(assistant_msg) except Exception as e: logger.error(f"Chat execution error for session {session_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.delete("/sessions/{session_id}") async def close_session(session_id: str, req: Request): """Close a chat session.""" sm = _get_session_manager(req) session = await sm.close_session(session_id) if session is None: raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") return {"status": "closed", "session_id": session_id} # ── WebSocket endpoint ──────────────────────────────────────────────── @router.websocket("/ws/{session_id}") async def chat_websocket(websocket: WebSocket, session_id: str) -> None: """WebSocket endpoint for real-time chat with streaming. Client → Server messages: {"type": "message", "content": "..."} — Send a user message {"type": "cancel"} — Cancel current execution {"type": "ping"} — Heartbeat Server → Client messages: {"type": "connected", "session_id": "..."} — Connection confirmed {"type": "token", "content": "..."} — LLM token streaming {"type": "step", "data": {...}} — ReAct step event {"type": "ask_human", "question": "...", "request_id": "..."} — Agent asks user {"type": "final_answer", "content": "..."} — Agent's final reply {"type": "error", "data": {"message": "..."}} — Error occurred {"type": "pong"} — Heartbeat response """ # Authentication configured_api_key: str | None = None if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config: configured_api_key = websocket.app.state.server_config.api_key if configured_api_key is None and hasattr(websocket.app.state, "api_key"): configured_api_key = websocket.app.state.api_key if configured_api_key: provided = websocket.query_params.get("api_key") if provided != configured_api_key: await websocket.accept() await websocket.send_json({"type": "error", "data": {"message": "Invalid api_key"}}) await websocket.close(code=4001, reason="Invalid api_key") return await websocket.accept() # Validate session sm: SessionManager = websocket.app.state.session_manager session = await sm.get_session(session_id) if session is None: await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}}) await websocket.close(code=1000, reason="Session not found") return if session.status == SessionStatus.CLOSED: await websocket.send_json({"type": "error", "data": {"message": "Session is closed"}}) await websocket.close(code=1000, reason="Session closed") return # Track pending replies for AskHumanTool pending_replies: dict[str, asyncio.Future] = {} chat_manager.add(session_id, websocket, pending_replies) cancellation_token = CancellationToken() try: await websocket.send_json({"type": "connected", "session_id": session_id}) # Listen for client messages while True: try: raw = await asyncio.wait_for(websocket.receive_text(), timeout=300.0) except asyncio.TimeoutError: await websocket.send_json({"type": "pong"}) continue try: msg = json.loads(raw) except json.JSONDecodeError: continue msg_type = msg.get("type") if msg_type == "message": content = msg.get("content", "") # Create a fresh CancellationToken for each message message_token = CancellationToken() await _handle_chat_message( websocket, session_id, content, sm, message_token, pending_replies ) elif msg_type == "reply": # Reply to AskHumanTool request_id = msg.get("request_id") reply_content = msg.get("content", "") if request_id and request_id in pending_replies: pending_replies[request_id].set_result(reply_content) elif msg_type == "cancel": cancellation_token.cancel() await websocket.send_json({"type": "result", "data": {"status": "cancelled"}}) elif msg_type == "ping": await websocket.send_json({"type": "pong"}) except WebSocketDisconnect: logger.debug(f"Chat WebSocket disconnected for session {session_id}") except Exception as e: logger.error(f"Chat WebSocket error for session {session_id}: {e}") try: await websocket.send_json({"type": "error", "data": {"message": str(e)}}) except Exception: pass finally: # Clean up pending futures for fut in pending_replies.values(): if not fut.done(): fut.cancel() chat_manager.remove(session_id, websocket) async def _handle_chat_message( websocket: WebSocket, session_id: str, content: str, sm: SessionManager, cancellation_token: CancellationToken, pending_replies: dict[str, asyncio.Future], ) -> None: """Handle a user message: append to session, execute Agent, stream events. When skills are registered, attempts to route the user's message to a matching skill via IntentRouter. If a skill is matched, the skill's prompt, tools, and execution_mode are used instead of the default agent's. """ from agentkit.chat.skill_routing import resolve_skill_routing # Resolve Agent first (needed for default tools/prompt) pool = websocket.app.state.agent_pool session = await sm.get_session(session_id) if session is None: await websocket.send_json({"type": "error", "data": {"message": "Session lost"}}) return agent = pool.get_agent(session.agent_name) if agent is None: await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}}) return # Default execution parameters from agent default_tools = agent._tool_registry.list_tools() if agent._tool_registry else [] default_system_prompt = getattr(agent, "_system_prompt", None) or (agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None) default_model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default") # Resolve skill routing using shared module skill_registry = getattr(websocket.app.state, "skill_registry", None) intent_router = getattr(websocket.app.state, "intent_router", None) routing = await resolve_skill_routing( content=content, skill_registry=skill_registry, intent_router=intent_router, default_tools=default_tools, default_system_prompt=default_system_prompt, default_model=default_model, default_agent_name=agent.name, agent_tool_registry=agent._tool_registry if agent._tool_registry else None, session_id=session_id, ) # Notify frontend about skill match if routing.matched: await websocket.send_json({ "type": "skill_match", "data": { "skill": routing.skill_name, "method": routing.match_method, "confidence": routing.match_confidence, }, }) # Append user message (use clean_content if @skill: prefix was stripped) await sm.append_message(session_id=session_id, role=MessageRole.USER, content=routing.clean_content) # Get full conversation history chat_messages = await sm.get_chat_messages(session_id) # Execute Agent with streaming react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) logger.info(f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}") try: final_content = "" async for event in react_engine.execute_stream( messages=chat_messages, tools=routing.tools, model=routing.model, agent_name=routing.agent_name, system_prompt=routing.system_prompt, cancellation_token=cancellation_token, ): if event.event_type == "final_answer": final_content = event.data.get("output", "") await websocket.send_json({ "type": "final_answer", "content": final_content, }) elif event.event_type == "token": await websocket.send_json({ "type": "token", "content": event.data.get("content", ""), }) else: await websocket.send_json({ "type": "step", "data": { "event_type": event.event_type, "step": event.step, "data": event.data, }, }) # Append assistant reply to session if final_content: await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, content=final_content, agent_name=agent.name, ) except Exception as e: logger.error(f"Chat execution error for session {session_id}: {e}") # Show meaningful error to user, but avoid leaking full stack traces error_msg = str(e) # Truncate very long error messages if len(error_msg) > 200: error_msg = error_msg[:200] + "..." await websocket.send_json({"type": "error", "data": {"message": error_msg}})