305 lines
10 KiB
Python
305 lines
10 KiB
Python
"""WebSocket route for bidirectional real-time task communication."""
|
|
|
|
import asyncio
|
|
import hmac
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
|
|
from agentkit.core.protocol import CancellationToken
|
|
from agentkit.core.react import ReActEngine
|
|
|
|
# WebSocket close codes. 1011 is RFC 6455's "internal error" code; the
# 4000-4999 range is left by RFC 6455 for application-defined codes.
WS_CODE_UNAUTHENTICATED = 4001  # sent when the api_key is missing or wrong
WS_CODE_SERVER_ERROR = 1011

# Module logger, and the router on which this module's WebSocket route is registered.
logger = logging.getLogger(__name__)
router = APIRouter(tags=["websocket"])
|
|
|
|
|
|
class ConnectionManager:
    """Track active WebSocket connections per task_id for fan-out.

    Each registered connection is stored together with the CancellationToken
    its handler supplied, so every token for a task_id can be reached through
    ``get_tokens`` (e.g. to fan out a cancel request).
    """

    def __init__(self) -> None:
        # task_id -> list of (websocket, cancellation_token)
        self._connections: dict[str, list[tuple[WebSocket, CancellationToken]]] = {}

    def add(self, task_id: str, ws: WebSocket, token: CancellationToken) -> None:
        """Register ``ws`` and its ``token`` as a subscriber of ``task_id``."""
        self._connections.setdefault(task_id, []).append((ws, token))

    def remove(self, task_id: str, ws: WebSocket) -> None:
        """Unregister ``ws`` from ``task_id``.

        A no-op when ``task_id`` or ``ws`` is not registered. The task_id
        entry is dropped once its last connection has been removed.
        """
        conns = self._connections.get(task_id)
        if conns is None:
            return
        remaining = [(w, t) for w, t in conns if w is not ws]
        if remaining:
            self._connections[task_id] = remaining
        else:
            del self._connections[task_id]

    def get_tokens(self, task_id: str) -> list[CancellationToken]:
        """Return the cancellation tokens of every connection on ``task_id``."""
        return [t for _, t in self._connections.get(task_id, [])]

    async def broadcast(
        self,
        task_id: str,
        message: dict[str, Any],
        *,
        exclude: WebSocket | None = None,
    ) -> None:
        """Send ``message`` as JSON to every connection on ``task_id``.

        Args:
            task_id: the task whose subscribers receive the message.
            message: JSON-serialisable payload.
            exclude: optional connection to skip, e.g. a sender that has
                already delivered the message itself.

        Connections whose send fails are treated as stale and unregistered;
        a failing subscriber never aborts delivery to the others.
        """
        stale: list[WebSocket] = []
        # Iterate over a snapshot: every send awaits, and other coroutines may
        # add or remove connections for this task_id in the meantime.
        for ws, _ in list(self._connections.get(task_id, [])):
            if ws is exclude:
                continue
            try:
                await ws.send_json(message)
            except (ConnectionError, RuntimeError, asyncio.TimeoutError, WebSocketDisconnect):
                # Recent Starlette versions re-raise transport errors from
                # send() as WebSocketDisconnect, so it must be caught here too.
                stale.append(ws)
        for ws in stale:
            self.remove(task_id, ws)

    def has_connections(self, task_id: str) -> bool:
        """Return True if at least one connection is registered for ``task_id``."""
        return bool(self._connections.get(task_id))


# Module-level registry shared by the handlers in this module.
manager = ConnectionManager()
|
|
|
|
|
|
def _authenticate(websocket: WebSocket, api_key: str | None) -> bool:
|
|
"""Check api_key query param against the configured key.
|
|
|
|
Returns True if the connection should be allowed.
|
|
"""
|
|
# No API key configured → dev mode, allow all
|
|
if not api_key:
|
|
return True
|
|
|
|
provided = websocket.query_params.get("api_key")
|
|
return hmac.compare_digest(provided or "", api_key)
|
|
|
|
|
|
def _configured_api_key(websocket: WebSocket) -> str | None:
    """Return the API key configured for the app serving ``websocket``.

    ``app.state.server_config.api_key`` is used when a truthy
    ``server_config`` exists; ``app.state.api_key`` (set by create_app when
    the api_key param is used) is consulted only if that yields None.
    """
    state = websocket.app.state
    server_config = getattr(state, "server_config", None)
    api_key = server_config.api_key if server_config else None
    if api_key is None:
        api_key = getattr(state, "api_key", None)
    return api_key


@router.websocket("/ws/tasks/{task_id}")
async def task_websocket(websocket: WebSocket, task_id: str) -> None:
    """WebSocket endpoint for real-time task execution and monitoring.

    Client → Server messages:
        {"type": "cancel"} — Cancel the running task
        {"type": "ping"} — Heartbeat

    Server → Client messages:
        {"type": "connected", "task_id": "..."} — Connection confirmed
        {"type": "step", "data": {...}} — ReAct step event
        {"type": "result", "data": {...}} — Final task result
        {"type": "error", "data": {"message": "..."}} — Error occurred
        {"type": "pong"} — Heartbeat response

    On failed authentication the socket is accepted, sent an error message
    and closed with WS_CODE_UNAUTHENTICATED. Otherwise the ReAct stream and
    the client listener run concurrently; when either finishes, the other is
    cancelled.
    """
    # Authentication — must accept before sending/closing
    if not _authenticate(websocket, _configured_api_key(websocket)):
        await websocket.accept()
        await websocket.send_json(
            {
                "type": "error",
                "data": {"message": "Invalid or missing api_key"},
            }
        )
        await websocket.close(code=WS_CODE_UNAUTHENTICATED, reason="Invalid or missing api_key")
        return

    await websocket.accept()

    cancellation_token = CancellationToken()
    manager.add(task_id, websocket, cancellation_token)
    child_tasks: list[asyncio.Task] = []

    try:
        # Send connected confirmation
        await websocket.send_json({"type": "connected", "task_id": task_id})

        agent = _resolve_agent(websocket, task_id)
        if agent is None:
            await websocket.send_json(
                {
                    "type": "error",
                    "data": {"message": f"No agent available for task {task_id}"},
                }
            )
            return

        # Run the ReAct loop and client listener concurrently
        exec_task = asyncio.create_task(
            _run_react_and_stream(websocket, task_id, agent, cancellation_token)
        )
        listener_task = asyncio.create_task(
            _listen_client_messages(websocket, task_id, cancellation_token, exec_task)
        )
        child_tasks = [exec_task, listener_task]

        done, pending = await asyncio.wait(child_tasks, return_when=asyncio.FIRST_COMPLETED)

        for task in pending:
            task.cancel()
        if pending:
            # asyncio.wait does not raise when the awaited tasks end up cancelled,
            # so a cancellation of this handler itself is never swallowed here.
            await asyncio.wait(pending)

        # Retrieve (and log) every child failure so none goes unreported.
        for label, task in (("exec", exec_task), ("listener", listener_task)):
            if task.cancelled():
                continue
            err = task.exception()
            if err is not None:
                logger.error("WebSocket %s error for task %s: %s", label, task_id, err)

    except WebSocketDisconnect:
        logger.debug("WebSocket disconnected for task %s", task_id)
    except asyncio.CancelledError:
        raise
    except Exception as e:
        logger.exception("WebSocket error for task %s: %s", task_id, e)
        try:
            await websocket.send_json(
                {
                    "type": "error",
                    "data": {"message": str(e)},
                }
            )
        except (ConnectionError, RuntimeError, asyncio.TimeoutError, WebSocketDisconnect):
            pass
    finally:
        # If this handler is itself cancelled (e.g. server shutdown) while
        # waiting, the child tasks — including the ReAct loop — would otherwise
        # keep running detached from any client.
        for task in child_tasks:
            if not task.done():
                task.cancel()
        manager.remove(task_id, websocket)
|
|
|
|
|
|
def _resolve_agent(websocket: WebSocket, task_id: str):
|
|
"""Try to find an agent from the pool for the given task."""
|
|
pool = websocket.app.state.agent_pool
|
|
agents = list(pool._agents.values()) if hasattr(pool, "_agents") else []
|
|
if not agents:
|
|
return None
|
|
# Try to find agent by task_id mapping if available
|
|
if hasattr(pool, "get_agent_for_task"):
|
|
agent = pool.get_agent_for_task(task_id)
|
|
if agent:
|
|
return agent
|
|
return agents[0]
|
|
|
|
|
|
async def _run_react_and_stream(
    websocket: WebSocket,
    task_id: str,
    agent,
    cancellation_token: CancellationToken,
) -> None:
    """Execute the ReAct loop for ``task_id`` and stream its events.

    ``final_answer`` events are sent to ``websocket`` as a ``result`` message.
    Every other event is broadcast as a ``step`` message to all connections
    registered with ``manager`` for ``task_id``.

    If execution fails, the error is logged and reported to ``websocket`` as
    an ``error`` message (best effort). Cancellation is re-raised.
    """
    try:
        react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)

        # NOTE(review): the task_id itself is sent as the user prompt — confirm intended.
        messages = [{"role": "user", "content": str(task_id)}]

        # Agent internals are optional; fall back to defaults when absent.
        tool_registry = getattr(agent, "_tool_registry", None)
        tools = list(tool_registry._tools.values()) if tool_registry else []
        if hasattr(agent, "get_model"):
            model = agent.get_model()
        else:
            model = getattr(agent, "_llm_model", "default")
        system_prompt = getattr(agent, "_system_prompt", None)

        async for event in react_engine.execute_stream(
            messages=messages,
            tools=tools,
            model=model,
            agent_name=agent.name,
            system_prompt=system_prompt,
            cancellation_token=cancellation_token,
        ):
            if event.event_type == "final_answer":
                await websocket.send_json(
                    {
                        "type": "result",
                        "data": {
                            "output": event.data.get("output", ""),
                            "total_steps": event.data.get("total_steps", 0),
                            "total_tokens": event.data.get("total_tokens", 0),
                        },
                    }
                )
                continue

            step_message = {
                "type": "step",
                "data": {
                    "event_type": event.event_type,
                    "step": event.step,
                    "data": event.data,
                    "timestamp": event.timestamp,
                },
            }
            # task_websocket registers ``websocket`` with ``manager`` before
            # starting this coroutine, so one broadcast reaches it together
            # with every other subscriber. A separate direct send would
            # deliver each step to this client twice.
            await manager.broadcast(task_id, step_message)

    except asyncio.CancelledError:
        raise
    except WebSocketDisconnect:
        # The client went away mid-stream; there is nobody left to report to.
        logger.debug("WebSocket disconnected while streaming task %s", task_id)
    except Exception as e:
        logger.exception("ReAct execution failed for task %s", task_id)
        try:
            await websocket.send_json(
                {
                    "type": "error",
                    "data": {"message": str(e)},
                }
            )
        except (ConnectionError, RuntimeError, asyncio.TimeoutError, WebSocketDisconnect):
            # Client already gone; the failure has been logged above.
            pass
|
|
|
|
|
|
async def _listen_client_messages(
|
|
websocket: WebSocket,
|
|
task_id: str,
|
|
cancellation_token: CancellationToken,
|
|
_exec_task: asyncio.Task,
|
|
) -> None:
|
|
"""Listen for client messages (cancel, ping) with heartbeat timeout."""
|
|
try:
|
|
while True:
|
|
try:
|
|
raw = await asyncio.wait_for(websocket.receive_text(), timeout=60.0)
|
|
except asyncio.TimeoutError:
|
|
# No message in 60s → close connection
|
|
await websocket.close(code=1000, reason="Heartbeat timeout")
|
|
return
|
|
|
|
try:
|
|
msg = json.loads(raw)
|
|
except json.JSONDecodeError:
|
|
continue
|
|
|
|
msg_type = msg.get("type")
|
|
|
|
if msg_type == "cancel":
|
|
cancellation_token.cancel()
|
|
# Also cancel any asyncio task via runner
|
|
runner = websocket.app.state.runner
|
|
await runner.cancel(task_id)
|
|
# Cancel all tokens for this task (fan-out)
|
|
for token in manager.get_tokens(task_id):
|
|
token.cancel()
|
|
await websocket.send_json(
|
|
{
|
|
"type": "result",
|
|
"data": {"status": "cancelled", "task_id": task_id},
|
|
}
|
|
)
|
|
return
|
|
|
|
elif msg_type == "ping":
|
|
await websocket.send_json({"type": "pong"})
|
|
|
|
except WebSocketDisconnect:
|
|
pass
|
|
except asyncio.CancelledError:
|
|
pass
|