233 lines
9.3 KiB
Python
233 lines
9.3 KiB
Python
"""Authentication middleware — JWT + API Key dual-track.
|
|
|
|
This middleware runs *before* the legacy :class:`APIKeyAuthMiddleware` so that
|
|
JWT-authenticated requests (browser sessions) and API-Key-authenticated
|
|
requests (programmatic clients) can coexist during the U2 → U5 migration.
|
|
|
|
Authentication order per request:
|
|
|
|
1. **Whitelist** — paths in :data:`AuthMiddleware.WHITELIST_PATHS` pass through.
|
|
2. **JWT** — ``Authorization: Bearer <token>`` header. On success, the decoded
|
|
payload is stored on ``request.state.current_user``.
|
|
3. **API Key** — ``X-API-Key`` header. Compared in constant time against the
|
|
global ``api_key`` and any ``client_keys`` (loaded from ``clients.yaml``).
|
|
4. **Dev mode** — when no JWT secret, no global API key, and no client keys
|
|
are configured, all requests pass through (with a one-time warning).
|
|
5. Otherwise → ``401 Unauthorized``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hmac
|
|
import logging
|
|
from typing import Any
|
|
|
|
import jwt
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
from starlette.responses import JSONResponse
|
|
|
|
from agentkit.server.auth.jwt_utils import verify_token
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """Dual-track JWT + API Key authentication middleware.

    Each request is checked in this order: CORS preflight pass-through,
    pre-authenticated pass-through, whitelist, JWT, API key, dev mode, and
    finally a ``401`` JSON response.

    Args:
        app: ASGI app to wrap.
        jwt_secret: HS256 signing secret. Empty string disables JWT auth.
        api_key: Global API key (e.g. from ``server.api_key`` in agentkit.yaml).
        client_keys: Mapping of ``client_name -> api_key`` (from clients.yaml).
    """

    WHITELIST_PATHS = (
        "/api/v1/health",
        "/api/v1/auth/login",
        "/api/v1/auth/refresh",
        "/api/v1/auth/logout",
        "/api/v1/auth/whoami",  # Route does its own auth (access OR refresh)
        "/docs",
        "/openapi.json",
        "/redoc",
    )

    # Whitelist entries that additionally allow sub-paths (``<entry>/...``).
    _DOC_PATH_PREFIXES = ("/docs", "/openapi.json", "/redoc")

    # Path prefixes for which a ``?token=<jwt>`` query parameter is honored.
    _WS_PATH_PREFIXES = (
        "/api/v1/ws",
        "/ws",
        "/api/v1/chat/ws",
        "/api/v1/portal/ws",
    )

    def __init__(
        self,
        app,
        jwt_secret: str = "",
        api_key: str | None = None,
        client_keys: dict[str, str] | None = None,
    ) -> None:
        super().__init__(app)
        self._jwt_secret = jwt_secret or ""
        self._api_key = api_key
        # Copy so later mutation of the caller's mapping cannot change auth.
        self._client_keys: dict[str, str] = dict(client_keys) if client_keys else {}
        self._dev_mode_warned = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _is_whitelisted(self, path: str) -> bool:
        """Return True if ``path`` matches a whitelisted route.

        Every entry of ``WHITELIST_PATHS`` matches exactly. Documentation
        entries additionally match their sub-paths (``/docs/...``). The
        sub-path check requires a ``/`` boundary, so sibling routes such as
        ``/docs-internal`` or ``/auth/logout-others`` are NOT whitelisted.
        """
        if path in self.WHITELIST_PATHS:
            return True
        return any(
            prefix in self.WHITELIST_PATHS and path.startswith(prefix + "/")
            for prefix in self._DOC_PATH_PREFIXES
        )

    def _is_dev_mode(self) -> bool:
        """Dev mode = no JWT secret, no global API key, no client keys."""
        # An empty-string ``api_key`` is not None, so it keeps dev mode off
        # even though it can never match a request: this fails closed.
        return not self._jwt_secret and self._api_key is None and not self._client_keys

    def _verify_jwt(self, token: str) -> dict[str, Any] | None:
        """Verify a JWT bearer token. Returns the payload or None.

        V2 tokens carry a ``sid`` claim. The middleware does NOT consult
        the session cache here — it only checks the JWT signature and the
        ``type`` claim. The sid-aware validation runs lazily in the
        :func:`require_authenticated` dependency when the route opts in
        (see :mod:`.dependencies`). This keeps the middleware hot path
        cheap and lets the route author decide on cache TTL.

        Returns None when JWT auth is disabled, the token is empty, the
        token fails verification, or it is not an access token.
        """
        if not self._jwt_secret or not token:
            return None
        try:
            payload = verify_token(token, self._jwt_secret)
        except jwt.InvalidTokenError:
            return None
        # Only accept access tokens for request auth (not refresh tokens).
        if payload.get("type") != "access":
            return None
        return payload

    def _verify_api_key(self, provided_key: str) -> tuple[bool, str | None]:
        """Verify an API key in constant time against all known keys.

        Returns ``(matched, client_name)``. ``client_name`` is the key from
        ``client_keys`` that matched, or ``None`` for the global API key.
        """
        if not provided_key:
            return False, None

        candidates: list[tuple[str, str | None]] = []
        if self._api_key:
            candidates.append((self._api_key, None))
        candidates.extend((key, name) for name, key in self._client_keys.items())
        if not candidates:
            return False, None

        provided_bytes = provided_key.encode("utf-8")
        # Compare against ALL candidates (no short-circuit) to maintain
        # constant-time behavior at the multi-key level.
        found = False
        matched_name: str | None = None
        for candidate, name in candidates:
            if hmac.compare_digest(provided_bytes, candidate.encode("utf-8")):
                found = True
                matched_name = name
        return found, matched_name

    def _extract_jwt(self, request: Request, path: str) -> str | None:
        """Return the raw JWT carried by the request, or None.

        A Bearer ``Authorization`` header takes precedence. The
        ``?token=<jwt>`` query parameter is only honored for WebSocket
        paths, to limit its exposure in access logs.
        """
        auth_header = request.headers.get("Authorization", "")
        if auth_header.lower().startswith("bearer "):
            return auth_header[7:].strip()
        if path.startswith(self._WS_PATH_PREFIXES):
            return request.query_params.get("token")
        return None

    @staticmethod
    def _user_from_payload(payload: dict[str, Any]) -> dict[str, Any]:
        """Build the ``current_user`` dict from verified JWT claims."""
        # ``sid`` is included for header and query-param tokens alike so
        # that downstream session checks see the same shape for both.
        return {
            "user_id": payload.get("sub"),
            "username": payload.get("username"),
            "role": payload.get("role"),
            "sid": payload.get("sid"),
        }

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    async def dispatch(self, request: Request, call_next):
        """Authenticate ``request`` and forward it, or answer with 401.

        On JWT or API-key success, ``request.state.current_user`` is set
        before the request is forwarded via ``call_next``.
        """
        # 0. CORS preflight — OPTIONS requests must never be authenticated.
        #    The browser sends them without credentials; if we return 401
        #    here, CORSMiddleware never sees the request and the browser
        #    blocks the actual request with "Load failed".
        if request.method == "OPTIONS":
            return await call_next(request)

        # 0b. If an outer middleware already set current_user (e.g. a test
        #     dev-admin injector), defer to it instead of re-authenticating.
        if getattr(request.state, "current_user", None) is not None:
            return await call_next(request)

        # 1. Whitelist
        path = request.url.path
        if self._is_whitelisted(path):
            return await call_next(request)

        # 2. JWT (Authorization: Bearer <token>, or ?token=<jwt> on WS paths).
        #    NOTE(review): Starlette's BaseHTTPMiddleware forwards non-HTTP
        #    scopes without calling dispatch, so this only sees HTTP requests
        #    to the WS paths — confirm WebSocket handshakes are authenticated
        #    elsewhere.
        payload = self._verify_jwt(self._extract_jwt(request, path) or "")
        if payload is not None:
            request.state.current_user = self._user_from_payload(payload)
            return await call_next(request)
        # An invalid or missing JWT falls through to the API key check.

        # 3. API Key (X-API-Key header)
        api_key = request.headers.get("X-API-Key")
        if api_key:
            matched, client_name = self._verify_api_key(api_key)
            if matched:
                # Set current_user so RBAC (require_permission) works.
                # API key clients get 'operator' role by default — enough
                # for chat/KB/workflow but not user management.
                request.state.current_user = {
                    "user_id": None,
                    "username": client_name or "api_client",
                    "role": "operator",
                }
                return await call_next(request)

        # 4. Dev mode (no auth configured)
        if self._is_dev_mode():
            if not self._dev_mode_warned:
                logger.warning(
                    "AuthMiddleware running in dev mode (no JWT secret, "
                    "no API key, no client keys). All requests pass through. "
                    "Set AGENTKIT_JWT_SECRET or server.api_key for production."
                )
                self._dev_mode_warned = True
            return await call_next(request)

        # 5. Unauthorized
        return JSONResponse(
            status_code=401,
            content={
                "error": "Unauthorized",
                "message": "Valid JWT (Authorization: Bearer) or API key (X-API-Key) required",
            },
        )
|