"""WebSocket route for bidirectional real-time task communication.""" import asyncio import hmac import json import logging from typing import Any from fastapi import APIRouter, WebSocket, WebSocketDisconnect from agentkit.core.protocol import CancellationToken from agentkit.core.react import ReActEngine logger = logging.getLogger(__name__) router = APIRouter(tags=["websocket"]) # WebSocket close codes WS_CODE_UNAUTHENTICATED = 4001 WS_CODE_SERVER_ERROR = 1011 class ConnectionManager: """Track active WebSocket connections per task_id for fan-out.""" def __init__(self) -> None: # task_id -> list of (websocket, cancellation_token) self._connections: dict[str, list[tuple[WebSocket, CancellationToken]]] = {} def add(self, task_id: str, ws: WebSocket, token: CancellationToken) -> None: self._connections.setdefault(task_id, []).append((ws, token)) def remove(self, task_id: str, ws: WebSocket) -> None: conns = self._connections.get(task_id) if conns is None: return self._connections[task_id] = [(w, t) for w, t in conns if w is not ws] if not self._connections[task_id]: del self._connections[task_id] def get_tokens(self, task_id: str) -> list[CancellationToken]: return [t for _, t in self._connections.get(task_id, [])] async def broadcast(self, task_id: str, message: dict[str, Any]) -> None: conns = self._connections.get(task_id, []) stale: list[WebSocket] = [] for ws, _ in conns: try: await ws.send_json(message) except (ConnectionError, RuntimeError, asyncio.TimeoutError): stale.append(ws) for ws in stale: self.remove(task_id, ws) def has_connections(self, task_id: str) -> bool: return bool(self._connections.get(task_id)) manager = ConnectionManager() def _authenticate(websocket: WebSocket, api_key: str | None) -> bool: """Check api_key query param against the configured key. Returns True if the connection should be allowed. """ # No API key configured → dev mode, allow all if not api_key: return True provided = websocket.query_params.get("api_key") return hmac.compare_digest(provided or "", api_key) @router.websocket("/ws/tasks/{task_id}") async def task_websocket(websocket: WebSocket, task_id: str) -> None: """WebSocket endpoint for real-time task execution and monitoring. Client → Server messages: {"type": "cancel"} — Cancel the running task {"type": "ping"} — Heartbeat Server → Client messages: {"type": "connected", "task_id": "..."} — Connection confirmed {"type": "step", "data": {...}} — ReAct step event {"type": "result", "data": {...}} — Final task result {"type": "error", "data": {"message": "..."}} — Error occurred {"type": "pong"} — Heartbeat response """ # Authentication — must accept before sending/closing configured_api_key: str | None = None if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config: configured_api_key = websocket.app.state.server_config.api_key # Fallback: check app.state.api_key (set by create_app when api_key param is used) if configured_api_key is None and hasattr(websocket.app.state, "api_key"): configured_api_key = websocket.app.state.api_key if not _authenticate(websocket, configured_api_key): await websocket.accept() await websocket.send_json( { "type": "error", "data": {"message": "Invalid or missing api_key"}, } ) await websocket.close(code=WS_CODE_UNAUTHENTICATED, reason="Invalid or missing api_key") return await websocket.accept() cancellation_token = CancellationToken() manager.add(task_id, websocket, cancellation_token) try: # Send connected confirmation await websocket.send_json({"type": "connected", "task_id": task_id}) # Resolve agent and start execution in background agent = _resolve_agent(websocket, task_id) if agent is None: await websocket.send_json( { "type": "error", "data": {"message": f"No agent available for task {task_id}"}, } ) return # Run the ReAct loop and client listener concurrently exec_task = asyncio.create_task( _run_react_and_stream(websocket, task_id, agent, cancellation_token) ) listener_task = asyncio.create_task( _listen_client_messages(websocket, task_id, cancellation_token, exec_task) ) done, pending = await asyncio.wait( [exec_task, listener_task], return_when=asyncio.FIRST_COMPLETED, ) for t in pending: t.cancel() try: await t except asyncio.CancelledError: pass # Propagate exec errors if exec_task in done and exec_task.exception(): err = exec_task.exception() logger.error(f"WebSocket exec error for task {task_id}: {err}") except WebSocketDisconnect: logger.debug(f"WebSocket disconnected for task {task_id}") except asyncio.CancelledError: raise except Exception as e: logger.error(f"WebSocket error for task {task_id}: {e}") try: await websocket.send_json( { "type": "error", "data": {"message": str(e)}, } ) except (ConnectionError, RuntimeError, asyncio.TimeoutError): pass finally: manager.remove(task_id, websocket) def _resolve_agent(websocket: WebSocket, task_id: str): """Try to find an agent from the pool for the given task.""" pool = websocket.app.state.agent_pool agents = list(pool._agents.values()) if hasattr(pool, "_agents") else [] if not agents: return None # Try to find agent by task_id mapping if available if hasattr(pool, "get_agent_for_task"): agent = pool.get_agent_for_task(task_id) if agent: return agent return agents[0] async def _run_react_and_stream( websocket: WebSocket, task_id: str, agent, cancellation_token: CancellationToken, ) -> None: """Execute ReAct loop and stream events to the WebSocket client.""" react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) messages = [{"role": "user", "content": str(task_id)}] tools = list(agent._tool_registry._tools.values()) if agent._tool_registry else [] try: async for event in react_engine.execute_stream( messages=messages, tools=tools, model=agent.get_model() if hasattr(agent, "get_model") else (agent._llm_model if hasattr(agent, "_llm_model") else "default"), agent_name=agent.name, system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None, cancellation_token=cancellation_token, ): if event.event_type == "final_answer": await websocket.send_json( { "type": "result", "data": { "output": event.data.get("output", ""), "total_steps": event.data.get("total_steps", 0), "total_tokens": event.data.get("total_tokens", 0), }, } ) else: await websocket.send_json( { "type": "step", "data": { "event_type": event.event_type, "step": event.step, "data": event.data, "timestamp": event.timestamp, }, } ) # Also broadcast to other subscribers await manager.broadcast( task_id, { "type": "step", "data": { "event_type": event.event_type, "step": event.step, "data": event.data, "timestamp": event.timestamp, }, }, ) except asyncio.CancelledError: raise except Exception as e: await websocket.send_json( { "type": "error", "data": {"message": str(e)}, } ) async def _listen_client_messages( websocket: WebSocket, task_id: str, cancellation_token: CancellationToken, _exec_task: asyncio.Task, ) -> None: """Listen for client messages (cancel, ping) with heartbeat timeout.""" try: while True: try: raw = await asyncio.wait_for(websocket.receive_text(), timeout=60.0) except asyncio.TimeoutError: # No message in 60s → close connection await websocket.close(code=1000, reason="Heartbeat timeout") return try: msg = json.loads(raw) except json.JSONDecodeError: continue msg_type = msg.get("type") if msg_type == "cancel": cancellation_token.cancel() # Also cancel any asyncio task via runner runner = websocket.app.state.runner await runner.cancel(task_id) # Cancel all tokens for this task (fan-out) for token in manager.get_tokens(task_id): token.cancel() await websocket.send_json( { "type": "result", "data": {"status": "cancelled", "task_id": task_id}, } ) return elif msg_type == "ping": await websocket.send_json({"type": "pong"}) except WebSocketDisconnect: pass except asyncio.CancelledError: pass