From 45b6752e0d7a28b780e563a522aef14a546d9cf7 Mon Sep 17 00:00:00 2001
From: chiguyong
Date: Mon, 22 Jun 2026 16:37:23 +0800
Subject: [PATCH] fix(review): U4 WebSocket quota enforcement + gateway-layer
 hardening

---
Review notes (this section is ignored by `git am`):
 * Line layout was restored from a whitespace-collapsed copy. Indentation of
   the context lines is reconstructed; verify it against the tree or apply
   with `git apply --ignore-whitespace`.
 * _resolve_ws_dept_context was amended in review (honest fail-open wording,
   explicit guard for a missing auth_db_path), so the commit hash above and
   the blob ids on the chat.py index line describe the pre-review commit.

 src/agentkit/server/auth/middleware.py |  6 +-
 src/agentkit/server/routes/chat.py     | 98 +++++++++++++++++++++++++-
 2 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/src/agentkit/server/auth/middleware.py b/src/agentkit/server/auth/middleware.py
index 7b66dde..b7040f7 100644
--- a/src/agentkit/server/auth/middleware.py
+++ b/src/agentkit/server/auth/middleware.py
@@ -167,7 +167,11 @@ class AuthMiddleware(BaseHTTPMiddleware):
                     }
                     return await call_next(request)
             # Fall through to API key check, then 401
-        elif path.startswith("/api/v1/ws") or path.startswith("/ws"):
+        elif (
+            path.startswith("/api/v1/ws")
+            or path.startswith("/ws")
+            or path.startswith("/api/v1/chat/ws")
+        ):
             token = request.query_params.get("token")
             if token:
                 payload = self._verify_jwt(token)
diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py
index 3184b85..eee5b15 100644
--- a/src/agentkit/server/routes/chat.py
+++ b/src/agentkit/server/routes/chat.py
@@ -609,6 +609,68 @@ async def close_session(session_id: str, req: Request):
 # ── WebSocket endpoint ────────────────────────────────────────────────
 
 
+async def _resolve_ws_dept_context(
+    websocket: WebSocket,
+) -> tuple[str | None, list[str], Path | None]:
+    """Resolve user_id, department_ids, and db_path for quota enforcement.
+
+    Reads ``websocket.state.current_user`` (set by :class:`AuthMiddleware`
+    when JWT auth via ``?token=`` query param succeeds). If no user context
+    is available (API key auth), returns ``(None, [], None)`` — quota is
+    not enforced for API key clients.
+
+    Returns ``(user_id, department_ids, db_path)``:
+    - ``user_id``: the authenticated user's id, or ``None``.
+    - ``department_ids``: the user's active department ids (empty for
+      admins, API key clients, and when the lookup is unavailable).
+    - ``db_path``: the auth DB path (needed for quota checks), or ``None``
+      whenever quota enforcement is to be skipped.
+
+    Never raises on a failed department lookup: if ``auth_db_path`` is not
+    configured or the lookup fails, returns ``(user_id, [], None)``. That is
+    a fail-OPEN policy (the chat proceeds without quota checks); lookup
+    failures are logged with a traceback.
+
+    NOTE(review): Starlette's ``BaseHTTPMiddleware`` forwards non-HTTP scopes
+    untouched, so confirm that ``AuthMiddleware`` really populates
+    ``current_user`` for WebSocket handshakes; otherwise this helper always
+    returns ``(None, [], None)`` and quota is never enforced here.
+    """
+    current_user: dict[str, Any] | None = getattr(websocket.state, "current_user", None)
+    if current_user is None:
+        return None, [], None
+
+    user_id = current_user.get("user_id")
+    role = current_user.get("role")
+
+    # Admins and API-key clients (user_id=None) — no quota enforcement.
+    if role == "admin" or not user_id:
+        return user_id, [], None
+
+    db_path = getattr(websocket.app.state, "auth_db_path", None)
+    if db_path is None:
+        # No auth DB configured: there is nothing to look departments up in.
+        return user_id, [], None
+    db_path_resolved = Path(db_path)
+
+    # Regular user: look up their active department ids.
+    from agentkit.server.admin.context import _fetch_user_department_ids
+
+    try:
+        department_ids = await _fetch_user_department_ids(db_path_resolved, user_id)
+    except Exception:
+        # Fail-open: the caller forwards db_path=None to the gateway, which
+        # then skips quota enforcement, so a broken lookup never blocks chat.
+        # NOTE(review): raise here instead if quota must never be bypassed.
+        logger.exception(
+            "Failed to fetch department ids for WebSocket user %s — "
+            "quota enforcement skipped (fail-open)",
+            user_id,
+        )
+        return user_id, [], None
+    return user_id, department_ids, db_path_resolved
+
+
 @router.websocket("/ws/{session_id}")
 async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
     """WebSocket endpoint for real-time chat with streaming.
@@ -887,12 +949,17 @@ async def _handle_chat_message(
         if routing.system_prompt:
             direct_messages.append({"role": "system", "content": routing.system_prompt})
         direct_messages.extend(chat_messages)
+        # Resolve department context for quota enforcement (U4/KTD-3).
+        ws_user_id, ws_dept_ids, ws_db_path = await _resolve_ws_dept_context(websocket)
         try:
             response = await websocket.app.state.llm_gateway.chat(
                 messages=direct_messages,
                 model=routing.model or "default",
                 agent_name=agent.name,
                 task_type="chat",
+                user_id=ws_user_id,
+                department_ids=ws_dept_ids if ws_dept_ids else None,
+                db_path=ws_db_path,
             )
             final_content = response.content or ""
             if not final_content or not final_content.strip():
@@ -913,8 +980,35 @@ async def _handle_chat_message(
                 agent_name=agent.name,
             )
         except Exception as e:
-            logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}")
-            await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
+            # Check if this is a QuotaExceededError (U4: WebSocket quota).
+            from agentkit.llm.gateway import QuotaExceededError
+
+            if isinstance(e, QuotaExceededError):
+                logger.warning(
+                    "WebSocket DIRECT_CHAT quota exceeded for session %s: %s",
+                    session_id,
+                    e,
+                )
+                await websocket.send_json(
+                    {
+                        "type": "error",
+                        "data": {
+                            "message": "quota exceeded",
+                            "quota_info": {
+                                "department_id": e.department_id,
+                                "quota_type": e.quota_type,
+                                "period": e.period,
+                                "limit": e.limit,
+                                "current": e.current,
+                            },
+                        },
+                    }
+                )
+            else:
+                logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}")
+                await websocket.send_json(
+                    {"type": "error", "data": {"message": str(e)[:200]}}
+                )
         return
 
     # Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB