fix(review): U4 WebSocket quota enforcement + gateway-layer hardening

This commit is contained in:
chiguyong 2026-06-22 16:37:23 +08:00
parent cd371e4155
commit 45b6752e0d
2 changed files with 91 additions and 3 deletions

View File

@ -167,7 +167,11 @@ class AuthMiddleware(BaseHTTPMiddleware):
} }
return await call_next(request) return await call_next(request)
# Fall through to API key check, then 401 # Fall through to API key check, then 401
elif path.startswith("/api/v1/ws") or path.startswith("/ws"): elif (
path.startswith("/api/v1/ws")
or path.startswith("/ws")
or path.startswith("/api/v1/chat/ws")
):
token = request.query_params.get("token") token = request.query_params.get("token")
if token: if token:
payload = self._verify_jwt(token) payload = self._verify_jwt(token)

View File

@ -609,6 +609,58 @@ async def close_session(session_id: str, req: Request):
# ── WebSocket endpoint ──────────────────────────────────────────────── # ── WebSocket endpoint ────────────────────────────────────────────────
async def _resolve_ws_dept_context(
    websocket: WebSocket,
) -> tuple[str | None, list[str], Path | None]:
    """Resolve user_id, department_ids, and db_path for quota enforcement.

    Reads ``websocket.state.current_user`` (set by :class:`AuthMiddleware`
    when JWT auth via ``?token=`` query param succeeds). If no user context
    is available (API key auth), returns ``(None, [], None)`` — quota is
    not enforced for API key clients.

    Args:
        websocket: The accepted WebSocket connection. Only
            ``websocket.state.current_user`` and
            ``websocket.app.state.auth_db_path`` are read.

    Returns:
        ``(user_id, department_ids, db_path)``:

        - ``user_id``: the authenticated user's id, or ``None``.
        - ``department_ids``: the user's active department ids. Empty for
          admins, API key clients, when no auth DB path is configured, and
          when the department lookup fails.
        - ``db_path``: the auth DB path (needed for quota checks), or
          ``None`` in every case where ``department_ids`` is empty for one
          of the reasons above.

    This coroutine does not propagate department-lookup errors: they are
    logged and reported as ``(user_id, [], None)``.
    """
    current_user: dict[str, Any] | None = getattr(websocket.state, "current_user", None)
    if current_user is None:
        return None, [], None

    user_id = current_user.get("user_id")
    # Admins and API-key clients (no user_id) — no quota enforcement.
    if current_user.get("role") == "admin" or not user_id:
        return user_id, [], None

    raw_db_path = getattr(websocket.app.state, "auth_db_path", None)
    if raw_db_path is None:
        # No auth DB configured: there is nothing to look departments up in,
        # so return the "no quota context" shape without calling the helper.
        return user_id, [], None
    db_path = Path(raw_db_path)

    # Regular user: look up their active department ids.
    # Function-scope import is kept from the original.
    # NOTE(review): presumably this avoids an import cycle — confirm.
    from agentkit.server.admin.context import _fetch_user_department_ids

    try:
        department_ids = await _fetch_user_department_ids(db_path, user_id)
    except Exception:
        # Boundary handler: a failed lookup must not tear down the chat
        # request, so log with traceback and degrade.
        logger.exception(
            "Failed to fetch department ids for WebSocket user %s — "
            "quota enforcement skipped",
            user_id,
        )
        # Returning db_path=None is the same shape used for clients that are
        # exempt from quota checks, so this path is effectively fail-OPEN
        # with respect to quotas, not fail-closed.
        # NOTE(review): confirm this is the intended policy; a true
        # fail-closed behaviour would have to reject the request instead.
        return user_id, [], None
    return user_id, department_ids, db_path
@router.websocket("/ws/{session_id}") @router.websocket("/ws/{session_id}")
async def chat_websocket(websocket: WebSocket, session_id: str) -> None: async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
"""WebSocket endpoint for real-time chat with streaming. """WebSocket endpoint for real-time chat with streaming.
@ -887,12 +939,17 @@ async def _handle_chat_message(
if routing.system_prompt: if routing.system_prompt:
direct_messages.append({"role": "system", "content": routing.system_prompt}) direct_messages.append({"role": "system", "content": routing.system_prompt})
direct_messages.extend(chat_messages) direct_messages.extend(chat_messages)
# Resolve department context for quota enforcement (U4/KTD-3).
ws_user_id, ws_dept_ids, ws_db_path = await _resolve_ws_dept_context(websocket)
try: try:
response = await websocket.app.state.llm_gateway.chat( response = await websocket.app.state.llm_gateway.chat(
messages=direct_messages, messages=direct_messages,
model=routing.model or "default", model=routing.model or "default",
agent_name=agent.name, agent_name=agent.name,
task_type="chat", task_type="chat",
user_id=ws_user_id,
department_ids=ws_dept_ids if ws_dept_ids else None,
db_path=ws_db_path,
) )
final_content = response.content or "" final_content = response.content or ""
if not final_content or not final_content.strip(): if not final_content or not final_content.strip():
@ -913,8 +970,35 @@ async def _handle_chat_message(
agent_name=agent.name, agent_name=agent.name,
) )
except Exception as e: except Exception as e:
# Check if this is a QuotaExceededError (U4: WebSocket quota).
from agentkit.llm.gateway import QuotaExceededError
if isinstance(e, QuotaExceededError):
logger.warning(
"WebSocket DIRECT_CHAT quota exceeded for session %s: %s",
session_id,
e,
)
await websocket.send_json(
{
"type": "error",
"data": {
"message": "quota exceeded",
"quota_info": {
"department_id": e.department_id,
"quota_type": e.quota_type,
"period": e.period,
"limit": e.limit,
"current": e.current,
},
},
}
)
else:
logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}") logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}")
await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}}) await websocket.send_json(
{"type": "error", "data": {"message": str(e)[:200]}}
)
return return
# Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB # Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB