# fischer-agentkit/src/agentkit/server/auth/middleware.py

"""Authentication middleware — JWT + API Key dual-track.

This middleware runs *before* the legacy :class:`APIKeyAuthMiddleware` so that
JWT-authenticated requests (browser sessions) and API-Key-authenticated
requests (programmatic clients) can coexist during the U2 → U5 migration.

Authentication order per request:

0. **Bypass** — ``OPTIONS`` requests (CORS preflight) and requests whose
   ``request.state.current_user`` was already set by an outer middleware are
   passed through without re-authentication.
1. **Whitelist** — paths in :data:`AuthMiddleware.WHITELIST_PATHS` pass through.
2. **JWT** — ``Authorization: Bearer <token>`` header, or, on WebSocket-style
   paths only, a ``?token=<jwt>`` query parameter. Only tokens whose ``type``
   claim is ``"access"`` are accepted. On success, user fields taken from the
   decoded payload are stored on ``request.state.current_user``.
3. **API Key** — ``X-API-Key`` header. Compared in constant time against the
   global ``api_key`` and any ``client_keys`` (loaded from ``clients.yaml``).
   A match sets ``current_user`` with the ``operator`` role.
4. **Dev mode** — when no JWT secret, no global API key, and no client keys
   are configured, all requests pass through (with a one-time warning).
5. Otherwise → ``401 Unauthorized``.
"""
from __future__ import annotations
import hmac
import logging
from typing import Any
import jwt
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from agentkit.server.auth.jwt_utils import verify_token
logger = logging.getLogger(__name__)


class AuthMiddleware(BaseHTTPMiddleware):
    """Dual-track JWT + API Key authentication middleware.

    Args:
        app: ASGI app to wrap.
        jwt_secret: HS256 signing secret. Empty string disables JWT auth.
        api_key: Global API key (e.g. from ``server.api_key`` in agentkit.yaml).
            ``None`` means "not configured". An empty string is treated as
            configured-but-unusable: it never matches and does NOT enable
            dev mode (fail closed).
        client_keys: Mapping of ``client_name -> api_key`` (from clients.yaml).
    """

    WHITELIST_PATHS = (
        "/api/v1/health",
        "/api/v1/auth/login",
        "/api/v1/auth/refresh",
        "/api/v1/auth/logout",
        "/api/v1/auth/whoami",  # Route does its own auth (access OR refresh)
        "/docs",
        "/openapi.json",
        "/redoc",
    )

    # Whitelist entries that also cover their sub-paths (e.g.
    # ``/docs/oauth2-redirect``). Every other WHITELIST_PATHS entry requires
    # an exact match so ``/auth/logout`` does not whitelist ``/auth/logout-others``.
    _DOC_PREFIX_PATHS = ("/docs", "/openapi.json", "/redoc")

    # Path prefixes on which a JWT may be supplied as ``?token=<jwt>`` because
    # WebSocket clients cannot set headers. Kept narrow to limit how often
    # tokens end up in access logs.
    _WS_TOKEN_PATH_PREFIXES = (
        "/api/v1/ws",
        "/ws",
        "/api/v1/chat/ws",
        "/api/v1/portal/ws",
    )

    def __init__(
        self,
        app,
        jwt_secret: str = "",
        api_key: str | None = None,
        client_keys: dict[str, str] | None = None,
    ) -> None:
        super().__init__(app)
        # Normalise a None secret to "" so truthiness checks below are uniform.
        self._jwt_secret = jwt_secret or ""
        self._api_key = api_key
        # Defensive copy so later mutation of the caller's dict has no effect.
        self._client_keys: dict[str, str] = dict(client_keys) if client_keys else {}
        self._dev_mode_warned = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _is_whitelisted(self, path: str) -> bool:
        """Return True if ``path`` matches a whitelisted route.

        Auth routes require an exact match. Documentation routes match exactly
        or as a ``/``-separated sub-path (``/docs/...``), but not as a bare
        string prefix: ``/docs-internal`` is NOT whitelisted by ``/docs``.
        """
        for entry in self.WHITELIST_PATHS:
            if path == entry:
                return True
            if entry in self._DOC_PREFIX_PATHS and path.startswith(entry + "/"):
                return True
        return False

    def _is_dev_mode(self) -> bool:
        """Dev mode = no JWT secret, no global API key, no client keys.

        Note the global key is checked with ``is None``: an explicit
        ``api_key=""`` keeps dev mode OFF (fail closed), even though
        :meth:`_verify_api_key` never matches an empty key.
        """
        return not self._jwt_secret and self._api_key is None and not self._client_keys

    def _verify_jwt(self, token: str) -> dict[str, Any] | None:
        """Verify a JWT bearer token. Returns payload or None.

        Only the signature (via ``verify_token``) and the ``type == "access"``
        claim are checked here; refresh tokens are rejected. V2 tokens carry
        a ``sid`` claim, but the session cache is NOT consulted here. The
        sid-aware validation runs lazily in the :func:`require_authenticated`
        dependency when a route opts in (see :mod:`.dependencies`). This keeps
        the middleware hot path cheap.
        """
        if not self._jwt_secret:
            return None
        try:
            payload = verify_token(token, self._jwt_secret)
        except jwt.InvalidTokenError:
            return None
        # Only accept access tokens for request auth (not refresh tokens)
        if payload.get("type") != "access":
            return None
        return payload

    @staticmethod
    def _user_from_jwt(payload: dict[str, Any]) -> dict[str, Any]:
        """Build the ``request.state.current_user`` dict from a verified JWT payload.

        Shared by the header and query-param paths so both expose the same
        keys, including ``sid``, which downstream session checks rely on.
        """
        return {
            "user_id": payload.get("sub"),
            "username": payload.get("username"),
            "role": payload.get("role"),
            "sid": payload.get("sid"),
        }

    def _verify_api_key(self, provided_key: str) -> tuple[bool, str | None]:
        """Verify an API key in constant time against all known keys.

        Returns ``(matched, client_name)``. ``client_name`` is the key from
        ``client_keys`` that matched, or ``None`` for the global API key.
        If several entries share the same key, the last one in iteration
        order is reported.
        """
        if not provided_key:
            return False, None

        candidates: list[tuple[str, str | None]] = []
        if self._api_key:
            candidates.append((self._api_key, None))
        for name, key in self._client_keys.items():
            candidates.append((key, name))
        if not candidates:
            return False, None

        provided_bytes = provided_key.encode("utf-8")
        # Compare against ALL candidates (no short-circuit) to maintain
        # constant-time behavior at the multi-key level.
        matched_name: str | None = None
        found = False
        for candidate, name in candidates:
            if hmac.compare_digest(provided_bytes, candidate.encode("utf-8")):
                found = True
                matched_name = name
        return found, matched_name

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    async def dispatch(self, request: Request, call_next):
        """Authenticate ``request`` and forward it, or answer ``401``.

        See the module docstring for the full ordering of checks.

        NOTE(review): Starlette's ``BaseHTTPMiddleware`` is documented to
        handle ``http`` scopes only; websocket-scope connections may never
        reach this method. Confirm that WS endpoints authenticate themselves.
        """
        path = request.url.path

        # 0. CORS preflight — OPTIONS requests must never be authenticated.
        #    The browser sends them without credentials; if we return 401
        #    here, CORSMiddleware never sees the request and the browser
        #    blocks the actual request with "Load failed".
        if request.method == "OPTIONS":
            return await call_next(request)

        # 0b. If an outer middleware already set current_user (e.g. a test
        #     dev-admin injector), defer to it instead of re-authenticating.
        if getattr(request.state, "current_user", None) is not None:
            return await call_next(request)

        # 1. Whitelist
        if self._is_whitelisted(path):
            return await call_next(request)

        # 2. JWT (Authorization: Bearer <token>). ?token=<jwt> is also
        #    accepted, but only on WebSocket-style paths and only when no
        #    Bearer header was sent.
        auth_header = request.headers.get("Authorization", "")
        if auth_header.lower().startswith("bearer "):
            token = auth_header[7:].strip()
        elif path.startswith(self._WS_TOKEN_PATH_PREFIXES):
            token = request.query_params.get("token") or ""
        else:
            token = ""
        if token:
            payload = self._verify_jwt(token)
            if payload is not None:
                request.state.current_user = self._user_from_jwt(payload)
                return await call_next(request)
        # Invalid/absent JWT: fall through to API key check, then 401.

        # 3. API Key (X-API-Key header)
        api_key = request.headers.get("X-API-Key")
        if api_key:
            matched, client_name = self._verify_api_key(api_key)
            if matched:
                # Set current_user so RBAC (require_permission) works.
                # API key clients get 'operator' role by default — enough
                # for chat/KB/workflow but not user management.
                request.state.current_user = {
                    "user_id": None,
                    "username": client_name or "api_client",
                    "role": "operator",
                }
                return await call_next(request)

        # 4. Dev mode (no auth configured)
        if self._is_dev_mode():
            if not self._dev_mode_warned:
                logger.warning(
                    "AuthMiddleware running in dev mode (no JWT secret, "
                    "no API key, no client keys). All requests pass through. "
                    "Set AGENTKIT_JWT_SECRET or server.api_key for production."
                )
                self._dev_mode_warned = True
            return await call_next(request)

        # 5. Unauthorized
        return JSONResponse(
            status_code=401,
            content={
                "error": "Unauthorized",
                "message": "Valid JWT (Authorization: Bearer) or API key (X-API-Key) required",
            },
        )