fix(review): U4 WebSocket quota enforcement + gateway-layer hardening
This commit is contained in:
parent
cd371e4155
commit
45b6752e0d
|
|
@ -167,7 +167,11 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
}
|
}
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
# Fall through to API key check, then 401
|
# Fall through to API key check, then 401
|
||||||
elif path.startswith("/api/v1/ws") or path.startswith("/ws"):
|
elif (
|
||||||
|
path.startswith("/api/v1/ws")
|
||||||
|
or path.startswith("/ws")
|
||||||
|
or path.startswith("/api/v1/chat/ws")
|
||||||
|
):
|
||||||
token = request.query_params.get("token")
|
token = request.query_params.get("token")
|
||||||
if token:
|
if token:
|
||||||
payload = self._verify_jwt(token)
|
payload = self._verify_jwt(token)
|
||||||
|
|
|
||||||
|
|
@ -609,6 +609,58 @@ async def close_session(session_id: str, req: Request):
|
||||||
# ── WebSocket endpoint ────────────────────────────────────────────────
|
# ── WebSocket endpoint ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolve_ws_dept_context(
|
||||||
|
websocket: WebSocket,
|
||||||
|
) -> tuple[str | None, list[str], Path | None]:
|
||||||
|
"""Resolve user_id, department_ids, and db_path for quota enforcement.
|
||||||
|
|
||||||
|
Reads ``websocket.state.current_user`` (set by :class:`AuthMiddleware`
|
||||||
|
when JWT auth via ``?token=`` query param succeeds). If no user context
|
||||||
|
is available (API key auth), returns ``(None, [], None)`` — quota is
|
||||||
|
not enforced for API key clients.
|
||||||
|
|
||||||
|
Returns ``(user_id, department_ids, db_path)``:
|
||||||
|
- ``user_id``: the authenticated user's id, or ``None``.
|
||||||
|
- ``department_ids``: the user's active department ids (empty for
|
||||||
|
admins and API key clients).
|
||||||
|
- ``db_path``: the auth DB path (needed for quota checks), or ``None``
|
||||||
|
if no user context is available.
|
||||||
|
"""
|
||||||
|
current_user: dict[str, Any] | None = getattr(websocket.state, "current_user", None)
|
||||||
|
if current_user is None:
|
||||||
|
return None, [], None
|
||||||
|
|
||||||
|
user_id = current_user.get("user_id")
|
||||||
|
role = current_user.get("role")
|
||||||
|
|
||||||
|
# Admins and API-key clients (user_id=None) — no quota enforcement.
|
||||||
|
if role == "admin" or not user_id:
|
||||||
|
return user_id, [], None
|
||||||
|
|
||||||
|
# Regular user: look up their active department ids.
|
||||||
|
from agentkit.server.admin.context import _fetch_user_department_ids
|
||||||
|
|
||||||
|
db_path = getattr(websocket.app.state, "auth_db_path", None)
|
||||||
|
if db_path is None:
|
||||||
|
db_path_resolved = None
|
||||||
|
else:
|
||||||
|
db_path_resolved = Path(db_path)
|
||||||
|
try:
|
||||||
|
department_ids = await _fetch_user_department_ids(db_path_resolved, user_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to fetch department ids for WebSocket user %s — fail-closed",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
# Fail-closed: return empty list with db_path=None so gateway
|
||||||
|
# skips quota enforcement but we log the failure. The user
|
||||||
|
# gets a degraded experience rather than a security bypass.
|
||||||
|
return user_id, [], None
|
||||||
|
return user_id, department_ids, db_path_resolved
|
||||||
|
|
||||||
|
return user_id, [], None
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws/{session_id}")
|
@router.websocket("/ws/{session_id}")
|
||||||
async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
|
||||||
"""WebSocket endpoint for real-time chat with streaming.
|
"""WebSocket endpoint for real-time chat with streaming.
|
||||||
|
|
@ -887,12 +939,17 @@ async def _handle_chat_message(
|
||||||
if routing.system_prompt:
|
if routing.system_prompt:
|
||||||
direct_messages.append({"role": "system", "content": routing.system_prompt})
|
direct_messages.append({"role": "system", "content": routing.system_prompt})
|
||||||
direct_messages.extend(chat_messages)
|
direct_messages.extend(chat_messages)
|
||||||
|
# Resolve department context for quota enforcement (U4/KTD-3).
|
||||||
|
ws_user_id, ws_dept_ids, ws_db_path = await _resolve_ws_dept_context(websocket)
|
||||||
try:
|
try:
|
||||||
response = await websocket.app.state.llm_gateway.chat(
|
response = await websocket.app.state.llm_gateway.chat(
|
||||||
messages=direct_messages,
|
messages=direct_messages,
|
||||||
model=routing.model or "default",
|
model=routing.model or "default",
|
||||||
agent_name=agent.name,
|
agent_name=agent.name,
|
||||||
task_type="chat",
|
task_type="chat",
|
||||||
|
user_id=ws_user_id,
|
||||||
|
department_ids=ws_dept_ids if ws_dept_ids else None,
|
||||||
|
db_path=ws_db_path,
|
||||||
)
|
)
|
||||||
final_content = response.content or ""
|
final_content = response.content or ""
|
||||||
if not final_content or not final_content.strip():
|
if not final_content or not final_content.strip():
|
||||||
|
|
@ -913,8 +970,35 @@ async def _handle_chat_message(
|
||||||
agent_name=agent.name,
|
agent_name=agent.name,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}")
|
# Check if this is a QuotaExceededError (U4: WebSocket quota).
|
||||||
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}})
|
from agentkit.llm.gateway import QuotaExceededError
|
||||||
|
|
||||||
|
if isinstance(e, QuotaExceededError):
|
||||||
|
logger.warning(
|
||||||
|
"WebSocket DIRECT_CHAT quota exceeded for session %s: %s",
|
||||||
|
session_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"data": {
|
||||||
|
"message": "quota exceeded",
|
||||||
|
"quota_info": {
|
||||||
|
"department_id": e.department_id,
|
||||||
|
"quota_type": e.quota_type,
|
||||||
|
"period": e.period,
|
||||||
|
"limit": e.limit,
|
||||||
|
"current": e.current,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}")
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": "error", "data": {"message": str(e)[:200]}}
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
|
# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue