fix(review): U4 WebSocket quota enforcement + gateway-layer hardening
This commit is contained in:
parent
cd371e4155
commit
45b6752e0d
|
|
@ -167,7 +167,11 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||
}
|
||||
return await call_next(request)
|
||||
# Fall through to API key check, then 401
|
||||
elif path.startswith("/api/v1/ws") or path.startswith("/ws"):
|
||||
elif (
|
||||
path.startswith("/api/v1/ws")
|
||||
or path.startswith("/ws")
|
||||
or path.startswith("/api/v1/chat/ws")
|
||||
):
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
payload = self._verify_jwt(token)
|
||||
|
|
|
|||
|
|
@ -609,6 +609,58 @@ async def close_session(session_id: str, req: Request):
|
|||
# ── WebSocket endpoint ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _resolve_ws_dept_context(
    websocket: WebSocket,
) -> tuple[str | None, list[str], Path | None]:
    """Resolve user_id, department_ids, and db_path for quota enforcement.

    Reads ``websocket.state.current_user`` (expected to be populated by
    :class:`AuthMiddleware` when JWT auth via the ``?token=`` query param
    succeeds) and, for regular users, looks up their department ids in the
    auth database configured at ``websocket.app.state.auth_db_path``.

    Args:
        websocket: The active connection; only ``websocket.state`` and
            ``websocket.app.state`` are read.

    Returns:
        A ``(user_id, department_ids, db_path)`` tuple:

        - ``(None, [], None)`` when no ``current_user`` is attached to the
          connection (e.g. API key auth).
        - ``(user_id, [], None)`` when the user is an admin, has no
          ``user_id``, no ``auth_db_path`` is configured, or the department
          lookup raised.
        - ``(user_id, department_ids, Path(auth_db_path))`` when the lookup
          succeeded.

        A ``None`` ``db_path`` means the caller receives no quota context.
    """
    current_user: dict[str, Any] | None = getattr(websocket.state, "current_user", None)
    if current_user is None:
        return None, [], None

    user_id = current_user.get("user_id")

    # Admins and API-key clients (no user_id) carry no quota context.
    if current_user.get("role") == "admin" or not user_id:
        return user_id, [], None

    # Without an auth DB there is nothing to look departments up in.
    auth_db_path = getattr(websocket.app.state, "auth_db_path", None)
    if auth_db_path is None:
        return user_id, [], None
    db_path = Path(auth_db_path)

    # Imported lazily, and only once a lookup is actually going to happen.
    from agentkit.server.admin.context import _fetch_user_department_ids

    try:
        department_ids = await _fetch_user_department_ids(db_path, user_id)
    except Exception:
        # Best-effort boundary: a failed lookup must not break the chat turn,
        # so log with traceback and hand back an empty quota context.
        # NOTE(review): this was previously labelled "fail-closed", but
        # returning db_path=None is the same result used for clients that are
        # exempt from quota, i.e. the request proceeds without a quota check.
        # Confirm whether the request should be rejected here instead.
        logger.exception(
            "Failed to fetch department ids for WebSocket user %s — "
            "returning no quota context",
            user_id,
        )
        return user_id, [], None

    return user_id, department_ids, db_path
|
||||
|
||||
|
||||
@router.websocket("/ws/{session_id}")
|
||||
async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||
"""WebSocket endpoint for real-time chat with streaming.
|
||||
|
|
@ -887,12 +939,17 @@ async def _handle_chat_message(
|
|||
if routing.system_prompt:
|
||||
direct_messages.append({"role": "system", "content": routing.system_prompt})
|
||||
direct_messages.extend(chat_messages)
|
||||
# Resolve department context for quota enforcement (U4/KTD-3).
|
||||
ws_user_id, ws_dept_ids, ws_db_path = await _resolve_ws_dept_context(websocket)
|
||||
try:
|
||||
response = await websocket.app.state.llm_gateway.chat(
|
||||
messages=direct_messages,
|
||||
model=routing.model or "default",
|
||||
agent_name=agent.name,
|
||||
task_type="chat",
|
||||
user_id=ws_user_id,
|
||||
department_ids=ws_dept_ids if ws_dept_ids else None,
|
||||
db_path=ws_db_path,
|
||||
)
|
||||
final_content = response.content or ""
|
||||
if not final_content or not final_content.strip():
|
||||
|
|
@ -913,8 +970,35 @@ async def _handle_chat_message(
|
|||
agent_name=agent.name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}")
|
||||
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
|
||||
# Check if this is a QuotaExceededError (U4: WebSocket quota).
|
||||
from agentkit.llm.gateway import QuotaExceededError
|
||||
|
||||
if isinstance(e, QuotaExceededError):
|
||||
logger.warning(
|
||||
"WebSocket DIRECT_CHAT quota exceeded for session %s: %s",
|
||||
session_id,
|
||||
e,
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"data": {
|
||||
"message": "quota exceeded",
|
||||
"quota_info": {
|
||||
"department_id": e.department_id,
|
||||
"quota_type": e.quota_type,
|
||||
"period": e.period,
|
||||
"limit": e.limit,
|
||||
"current": e.current,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}")
|
||||
await websocket.send_json(
|
||||
{"type": "error", "data": {"message": str(e)[:200]}}
|
||||
)
|
||||
return
|
||||
|
||||
# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
|
||||
|
|
|
|||
Loading…
Reference in New Issue