fischer-agentkit/src/agentkit/server/auth/jwt_utils.py

298 lines
10 KiB
Python

"""JWT issuance and verification utilities.
Tokens are HS256-signed with a secret read from the ``AGENTKIT_JWT_SECRET``
environment variable. In dev mode (no secret configured) a random secret is
generated in-process and a warning is logged — this secret is *not* persisted
and will invalidate all tokens on restart, so it must never be used in
production.
Access tokens are short-lived (15 min) and carry ``type="access"``.
Refresh tokens are long-lived (7 days default, 30 days when the user opts
into "remember me") and carry ``type="refresh"``.
V2 claims (Centralized Auth & Token Persistence, U2):
- ``sid`` — the session id (UUID) referencing a row in ``auth_sessions``.
The middleware uses this to validate that the session has not
been revoked server-side, closing the kicked-out window.
- ``jti`` — a per-token unique id, used by the denylist for rotation.
Only the access token carries a ``jti``; the refresh token
uses its hashed value (``sha256(refresh_token)``) as its
rotation handle, so re-generating a jti on every refresh
would be wasteful.
Backwards-compat:
- Tokens issued before U2 land (no ``sid``) are still accepted by
:func:`verify_token`. Validation falls through to the legacy
``user_sessions`` table via :func:`agentkit.server.auth.dependencies`
(see U10 back-compat shim).
"""
from __future__ import annotations
import logging
import os
import secrets
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
import jwt
logger = logging.getLogger(__name__)

# How long each kind of token stays valid after it is issued.
ACCESS_TOKEN_TTL = timedelta(seconds=15 * 60)  # 15 minutes
REFRESH_TOKEN_TTL = timedelta(weeks=1)  # 7 days, the default login
REFRESH_TOKEN_TTL_REMEMBER_ME = timedelta(days=30)  # "remember me" login

# Signing / verification algorithm: HMAC-SHA256 over a shared secret.
JWT_ALGORITHM = "HS256"
@dataclass
class TokenPair:
    """Result of :func:`create_token_pair`: two encoded JWTs and their expiries.

    Plain data holder; instances compare equal field-by-field.
    """

    # Encoded JWT whose ``type`` claim is ``"access"``.
    access_token: str
    # Encoded JWT whose ``type`` claim is ``"refresh"``.
    refresh_token: str
    # When the access token stops being valid (issue time + access TTL).
    access_expires_at: datetime
    # When the refresh token stops being valid (issue time + refresh TTL).
    refresh_expires_at: datetime
def get_jwt_secret() -> str | None:
    """Look up the JWT signing secret in the environment.

    Reads ``AGENTKIT_JWT_SECRET``. An unset *or empty* value counts as
    "not configured". In that case a warning is logged on every such call
    and ``None`` is returned, leaving it to the caller (middleware, routes)
    to decide how to proceed. When a non-empty secret is mandatory — e.g.
    for signing tokens in dev mode — use :func:`get_or_create_jwt_secret`
    instead.

    Returns:
        The configured secret, or ``None`` when it is unset or empty.
    """
    configured = os.environ.get("AGENTKIT_JWT_SECRET")
    if not configured:
        logger.warning(
            "AGENTKIT_JWT_SECRET is not set. JWT auth is disabled in the "
            "middleware; token-signing routes will use an ephemeral secret "
            "that is invalidated on restart. Set AGENTKIT_JWT_SECRET for "
            "production use."
        )
        return None
    return configured
# Process-wide dev secret, generated lazily the first time a signing secret
# is needed while AGENTKIT_JWT_SECRET is unset. Never persisted.
_ephemeral_jwt_secret: str | None = None


def get_or_create_jwt_secret() -> str:
    """Return the configured JWT secret, or a per-process ephemeral one.

    Use this when a non-empty secret is required (e.g. signing tokens).
    For middleware configuration, prefer :func:`get_jwt_secret` so that
    the absence of a secret disables JWT auth (dev mode).

    When ``AGENTKIT_JWT_SECRET`` is unset or empty, a random secret is
    generated on the first call and reused by every later call in the same
    process. Tokens signed earlier in the process therefore still verify;
    they only become invalid when the process restarts. The dev-mode
    warning is logged once, when the secret is generated.

    Returns:
        The configured secret if set, otherwise the cached ephemeral secret.
    """
    global _ephemeral_jwt_secret

    secret = os.environ.get("AGENTKIT_JWT_SECRET")
    if secret:
        return secret

    # Previously a fresh secret was minted on every call. That made tokens
    # signed with one call's secret unverifiable with the next call's.
    # NOTE: no lock guards this lazy init; two threads racing on the very
    # first call could each generate (and briefly use) their own secret.
    if _ephemeral_jwt_secret is None:
        _ephemeral_jwt_secret = secrets.token_urlsafe(48)
        logger.warning(
            "AGENTKIT_JWT_SECRET is not set — generated an ephemeral dev secret. "
            "All JWTs will be invalidated on process restart. "
            "Set AGENTKIT_JWT_SECRET in the environment for production use."
        )
    return _ephemeral_jwt_secret
def _refresh_ttl_for(remember_me: bool) -> timedelta:
    """Pick the refresh-token lifetime for a login.

    A truthy ``remember_me`` selects ``REFRESH_TOKEN_TTL_REMEMBER_ME``
    (30 days); anything falsy selects ``REFRESH_TOKEN_TTL`` (7 days).
    """
    if remember_me:
        return REFRESH_TOKEN_TTL_REMEMBER_ME
    return REFRESH_TOKEN_TTL
def create_token_pair(
    user_id: str,
    username: str,
    role: str,
    secret: str,
    *,
    session_id: str | None = None,
    remember_me: bool = False,
    now: datetime | None = None,
    legacy_mode: bool = False,
) -> TokenPair:
    """Sign a fresh access JWT and refresh JWT for one login.

    Args:
        user_id: Stored as the ``sub`` claim of both tokens.
        username: Stored as the ``username`` claim of both tokens.
        role: Stored as the ``role`` claim (e.g. ``member``, ``admin``).
        secret: HS256 signing secret; must be non-empty.
        session_id: Id of the server-side ``auth_sessions`` row. When it
            is truthy (and ``legacy_mode`` is off), both tokens get a
            ``sid`` claim and the access token also gets a random ``jti``.
            Pass ``None`` for the V1 back-compat path (U10): such tokens
            are still accepted by :func:`verify_token` but carry no
            session reference.
        remember_me: Selects the 30-day refresh lifetime instead of the
            7-day default. Also stored in the refresh token as ``rmb``.
        now: Issued-at override for tests; defaults to the current UTC time.
        legacy_mode: U10 back-compat switch, used by the login route for
            clients whose ``X-Client-Version`` is below the rollout cutoff.
            When ``True`` the tokens never carry ``sid``/``jti``, even if
            ``session_id`` was given. Defaults to ``False``.

    Returns:
        A :class:`TokenPair` holding both encoded tokens and the
        ``datetime`` at which each one expires.

    Raises:
        ValueError: If ``secret`` is empty.
    """
    if not secret:
        raise ValueError("JWT secret must not be empty")

    def _sign(claims: dict[str, object]) -> str:
        # PyJWT >= 2 returns ``str``; older releases returned ``bytes``.
        encoded = jwt.encode(claims, secret, algorithm=JWT_ALGORITHM)
        if isinstance(encoded, bytes):
            return encoded.decode("utf-8")
        return encoded

    # U10: in legacy mode drop the session id even if the caller already
    # created a session row. A client that cannot use ``sid`` then never
    # receives one, and ``/auth/whoami`` takes its legacy branch.
    sid = None if legacy_mode else session_id

    start = now or datetime.now(timezone.utc)
    access_expiry = start + ACCESS_TOKEN_TTL
    refresh_expiry = start + _refresh_ttl_for(remember_me)
    issued_ts = int(start.timestamp())

    shared = {"sub": user_id, "username": username, "role": role}
    access_claims: dict[str, object] = {
        **shared,
        "type": "access",
        "iat": issued_ts,
        "exp": int(access_expiry.timestamp()),
    }
    refresh_claims: dict[str, object] = {
        **shared,
        "type": "refresh",
        "iat": issued_ts,
        "exp": int(refresh_expiry.timestamp()),
        # Carried so /auth/refresh can keep the original lifetime (30 d vs
        # 7 d) without the client re-sending its "remember me" choice.
        "rmb": remember_me,
    }

    if sid:
        access_claims["sid"] = sid
        # Only the access token gets a jti. The refresh token's rotation
        # handle is sha256 of its plaintext, so a jti there is redundant.
        access_claims["jti"] = str(uuid.uuid4())
        refresh_claims["sid"] = sid

    return TokenPair(
        access_token=_sign(access_claims),
        refresh_token=_sign(refresh_claims),
        access_expires_at=access_expiry,
        refresh_expires_at=refresh_expiry,
    )
def create_access_token(
    user_id: str,
    username: str,
    role: str,
    secret: str,
    *,
    session_id: str | None = None,
    now: datetime | None = None,
) -> str:
    """Sign a single access JWT, without minting a new refresh token.

    Used by the ``/auth/whoami`` cold-start path, where the client already
    holds a refresh token. Calling :func:`create_token_pair` there would
    produce a refresh token that is silently thrown away
    (token amplification).

    Args:
        user_id: Stored as the ``sub`` claim.
        username: Stored as the ``username`` claim.
        role: Stored as the ``role`` claim.
        secret: HS256 signing secret; must be non-empty.
        session_id: Server-side session id. When truthy, the token also
            carries ``sid`` and a random ``jti``.
        now: Issued-at override for tests; defaults to the current UTC time.

    Returns:
        The encoded access token as ``str``.

    Raises:
        ValueError: If ``secret`` is empty.
    """
    if not secret:
        raise ValueError("JWT secret must not be empty")

    issued = now or datetime.now(timezone.utc)
    expires = issued + ACCESS_TOKEN_TTL

    claims: dict[str, object] = {
        "sub": user_id,
        "username": username,
        "role": role,
        "type": "access",
        "iat": int(issued.timestamp()),
        "exp": int(expires.timestamp()),
    }
    if session_id:
        claims.update(sid=session_id, jti=str(uuid.uuid4()))

    token = jwt.encode(claims, secret, algorithm=JWT_ALGORITHM)
    # PyJWT >= 2 returns ``str``; older releases returned ``bytes``.
    if isinstance(token, bytes):
        return token.decode("utf-8")
    return token
def verify_token(
    token: str,
    secret: str,
    *,
    expected_type: str | None = None,
) -> dict[str, object]:
    """Verify a JWT and return its payload.

    Args:
        token: The JWT string to verify.
        secret: The HS256 signing secret.
        expected_type: When set, the ``type`` claim must equal it (e.g.
            ``"access"`` or ``"refresh"``). When ``None``, only ``access``
            and ``refresh`` tokens are accepted (used by the
            ``/auth/whoami`` cold-start path). Tokens with any other
            ``type``, or with no ``type`` claim, are rejected.

    Returns:
        The decoded payload as a dict. V2 tokens include ``sid``, and
        access tokens also include ``jti``. V1 tokens lack these fields,
        and the caller (U10 back-compat path) must handle their absence.

    Raises:
        jwt.InvalidTokenError: If the secret is empty, or the token is
            malformed, expired, has an invalid signature, or has a
            ``type`` claim that is not accepted (see ``expected_type``).
    """
    if not secret:
        raise jwt.InvalidTokenError("JWT secret must not be empty")

    # An explicit ``algorithms`` allowlist means a token whose header names
    # any other algorithm is rejected rather than trusted.
    payload = jwt.decode(token, secret, algorithms=[JWT_ALGORITHM])
    token_type = payload.get("type")

    if expected_type is not None:
        if token_type != expected_type:
            raise jwt.InvalidTokenError(
                f"token type mismatch: expected {expected_type!r}, got {token_type!r}"
            )
    elif token_type not in ("access", "refresh"):
        # Enforce the documented contract for the untyped path. Any other
        # JWT signed with the same secret (or one with no ``type``) must
        # not be accepted as an authentication token.
        raise jwt.InvalidTokenError(
            f"unsupported token type: expected 'access' or 'refresh', got {token_type!r}"
        )
    return payload