"""JWT issuance and verification utilities.

Tokens are HS256-signed with a secret read from the ``AGENTKIT_JWT_SECRET``
environment variable. In dev mode (no secret configured) a random secret is
generated in-process and a warning is logged — this secret is *not*
persisted and will invalidate all tokens on restart, so it must never be
used in production.

Access tokens are short-lived (15 min) and carry ``type="access"``. Refresh
tokens are long-lived (7 days default, 30 days when the user opts into
"remember me") and carry ``type="refresh"``.

V2 claims (Centralized Auth & Token Persistence, U2):

- ``sid`` — the session id (UUID) referencing a row in ``auth_sessions``.
  The middleware uses this to validate that the session has not been
  revoked server-side, closing the kicked-out window.
- ``jti`` — a per-token unique id, used by the denylist for rotation.
  Only the access token carries a ``jti``; the refresh token uses its
  hashed value (``sha256(refresh_token)``) as its rotation handle, so
  re-generating a jti on every refresh would be wasteful.

Backwards-compat:

- Tokens issued before U2 landed (no ``sid``) are still accepted by
  :func:`verify_token`. Validation falls through to the legacy
  ``user_sessions`` table via :func:`agentkit.server.auth.dependencies`
  (see U10 back-compat shim).
"""

from __future__ import annotations

import logging
import os
import secrets
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone

import jwt

logger = logging.getLogger(__name__)

# Token lifetimes
ACCESS_TOKEN_TTL = timedelta(minutes=15)
REFRESH_TOKEN_TTL = timedelta(days=7)
REFRESH_TOKEN_TTL_REMEMBER_ME = timedelta(days=30)

# JWT algorithm
JWT_ALGORITHM = "HS256"

# Token types accepted by verify_token() when the caller does not pin one.
# A tuple (not a set) so membership is tested with ``==`` and an unhashable
# ``type`` claim (e.g. a JSON list) is rejected instead of raising TypeError.
_ACCEPTED_TOKEN_TYPES = ("access", "refresh")

# Ephemeral dev secret, generated lazily by get_or_create_jwt_secret() and
# reused for the rest of the process so that tokens signed by one call can
# be verified by a later call.
_ephemeral_secret: str | None = None


@dataclass
class TokenPair:
    """A signed access + refresh JWT pair with their expiry timestamps."""

    access_token: str
    refresh_token: str
    access_expires_at: datetime
    refresh_expires_at: datetime


def get_jwt_secret() -> str | None:
    """Return the configured JWT secret, or ``None`` if not configured.

    Reads from the ``AGENTKIT_JWT_SECRET`` env var. Returns ``None`` (and
    logs a warning) when the env var is unset or empty — callers
    (middleware, routes) decide how to handle the unset case. Use
    :func:`get_or_create_jwt_secret` when a non-empty secret is required
    (e.g. for signing tokens in dev mode).
    """
    secret = os.environ.get("AGENTKIT_JWT_SECRET")
    if secret:
        return secret
    logger.warning(
        "AGENTKIT_JWT_SECRET is not set. JWT auth is disabled in the "
        "middleware; token-signing routes will use an ephemeral secret "
        "that is invalidated on restart. Set AGENTKIT_JWT_SECRET for "
        "production use."
    )
    return None


def get_or_create_jwt_secret() -> str:
    """Return the configured JWT secret, or a per-process ephemeral one.

    Use this when a non-empty secret is required (e.g. signing tokens).
    For middleware configuration, prefer :func:`get_jwt_secret` so that
    the absence of a secret disables JWT auth (dev mode).

    The ephemeral secret is generated on the first call that needs it and
    then reused for every later call in the same process; generating a new
    one per call would make tokens signed by one call unverifiable by the
    next. The environment variable is re-read on every call and always
    takes precedence over the cached ephemeral value.
    """
    global _ephemeral_secret

    secret = os.environ.get("AGENTKIT_JWT_SECRET")
    if secret:
        return secret

    if _ephemeral_secret is None:
        _ephemeral_secret = secrets.token_urlsafe(48)
        # Logged once, at generation time, rather than on every call.
        logger.warning(
            "AGENTKIT_JWT_SECRET is not set — generated an ephemeral dev secret. "
            "All JWTs will be invalidated on process restart. "
            "Set AGENTKIT_JWT_SECRET in the environment for production use."
        )
    return _ephemeral_secret


def _refresh_ttl_for(remember_me: bool) -> timedelta:
    """Return the refresh-token TTL for a given login.

    - ``remember_me=False`` (default) → 7 days
    - ``remember_me=True`` → 30 days
    """
    return REFRESH_TOKEN_TTL_REMEMBER_ME if remember_me else REFRESH_TOKEN_TTL


def _encode(payload: dict[str, object], secret: str) -> str:
    """Sign *payload* with HS256 and return the token as ``str``."""
    token = jwt.encode(payload, secret, algorithm=JWT_ALGORITHM)
    # PyJWT >= 2 returns str; older versions returned bytes. Normalize.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return token


def _access_claims(
    user_id: str,
    username: str,
    role: str,
    issued_at: datetime,
    expires_at: datetime,
    session_id: str | None,
) -> dict[str, object]:
    """Build the access-token payload shared by both issuing functions.

    ``sid`` and a fresh ``jti`` are added only when *session_id* is truthy.
    """
    claims: dict[str, object] = {
        "sub": user_id,
        "username": username,
        "role": role,
        "type": "access",
        "iat": int(issued_at.timestamp()),
        "exp": int(expires_at.timestamp()),
    }
    if session_id:
        claims["sid"] = session_id
        # jti: per-token unique id (access only). The refresh token uses
        # the sha256 of its plaintext value as the rotation handle in the
        # denylist; giving it a jti too would be redundant and would bloat
        # the JWT.
        claims["jti"] = str(uuid.uuid4())
    return claims


def create_token_pair(
    user_id: str,
    username: str,
    role: str,
    secret: str,
    *,
    session_id: str | None = None,
    remember_me: bool = False,
    now: datetime | None = None,
    legacy_mode: bool = False,
) -> TokenPair:
    """Create a signed access + refresh JWT pair.

    Args:
        user_id: Subject (user id) — stored as ``sub``.
        username: Username claim.
        role: Role claim (e.g. ``member``, ``admin``).
        secret: HS256 signing secret.
        session_id: Server-side session id (UUID) referencing a row in
            ``auth_sessions``. When provided, the tokens carry a ``sid``
            claim and the access token also carries a ``jti`` claim. Pass
            ``None`` for the V1 back-compat path (U10): the resulting
            tokens are accepted by ``verify_token`` but lack session
            validation.
        remember_me: When ``True``, the refresh token is valid for 30 days
            instead of the default 7.
        now: Override the issued-at time (for testing). Defaults to UTC
            now. Pass a timezone-aware value: ``datetime.timestamp()``
            interprets naive datetimes as local time.
        legacy_mode: When ``True``, the resulting tokens are intentionally
            issued without a ``sid`` claim, regardless of whether
            ``session_id`` was passed. This is the U10 back-compat flag
            used by the login route for clients with ``X-Client-Version``
            below the rollout cutoff. Default ``False`` — production
            tokens always carry ``sid``.

    Returns:
        A :class:`TokenPair` with both signed tokens and their expiry times.

    Raises:
        ValueError: If ``secret`` is empty.
    """
    if not secret:
        raise ValueError("JWT secret must not be empty")

    # U10: legacy_mode forces the no-sid path even if the caller has
    # already created a session row. This avoids handing a fresh sid
    # to a client that doesn't know how to use it (the
    # ``/auth/whoami`` route will fall back to its legacy branch).
    effective_session_id: str | None = None if legacy_mode else session_id

    issued_at = now or datetime.now(timezone.utc)
    access_exp = issued_at + ACCESS_TOKEN_TTL
    refresh_exp = issued_at + _refresh_ttl_for(remember_me)

    access_payload = _access_claims(
        user_id, username, role, issued_at, access_exp, effective_session_id
    )
    refresh_payload: dict[str, object] = {
        "sub": user_id,
        "username": username,
        "role": role,
        "type": "refresh",
        "iat": int(issued_at.timestamp()),
        "exp": int(refresh_exp.timestamp()),
        # Persist the remember_me flag so /auth/refresh can inherit
        # the original TTL (30d vs 7d) without the client re-sending
        # it. Without this claim, every refresh would reset to the
        # default 7-day TTL, defeating the "remember me for 30 days"
        # checkbox.
        "rmb": remember_me,
    }
    if effective_session_id:
        refresh_payload["sid"] = effective_session_id

    return TokenPair(
        access_token=_encode(access_payload, secret),
        refresh_token=_encode(refresh_payload, secret),
        access_expires_at=access_exp,
        refresh_expires_at=refresh_exp,
    )


def create_access_token(
    user_id: str,
    username: str,
    role: str,
    secret: str,
    *,
    session_id: str | None = None,
    now: datetime | None = None,
) -> str:
    """Create a single signed access JWT (no refresh token).

    Used by ``/auth/whoami`` cold-start to issue a fresh access token
    without creating a new refresh token (the client already has one).
    This avoids the token-amplification risk of ``create_token_pair``
    which would silently discard the new refresh token.

    Args:
        user_id: Subject (user id) — stored as ``sub``.
        username: Username claim.
        role: Role claim.
        secret: HS256 signing secret.
        session_id: Server-side session id. When provided, the token
            carries ``sid`` and ``jti`` claims.
        now: Override the issued-at time (for testing). Defaults to UTC
            now. Pass a timezone-aware value: ``datetime.timestamp()``
            interprets naive datetimes as local time.

    Returns:
        The signed access token string.

    Raises:
        ValueError: If ``secret`` is empty.
    """
    if not secret:
        raise ValueError("JWT secret must not be empty")

    issued_at = now or datetime.now(timezone.utc)
    access_exp = issued_at + ACCESS_TOKEN_TTL
    access_payload = _access_claims(
        user_id, username, role, issued_at, access_exp, session_id
    )
    return _encode(access_payload, secret)


def verify_token(
    token: str,
    secret: str,
    *,
    expected_type: str | None = None,
) -> dict[str, object]:
    """Verify a JWT and return its payload.

    Args:
        token: The JWT string to verify.
        secret: The HS256 signing secret.
        expected_type: When set, the ``type`` claim must match (e.g.
            ``"access"`` or ``"refresh"``). When ``None``, both ``access``
            and ``refresh`` tokens are accepted (used by the
            ``/auth/whoami`` cold-start path); any other or missing
            ``type`` is rejected.

    Returns:
        The decoded payload as a dict. V2 tokens include ``sid`` and
        ``jti`` (access only); V1 tokens lack these fields and the caller
        (U10 back-compat path) must handle the absence.

    Raises:
        jwt.InvalidTokenError: If the token is malformed, expired, has an
            invalid signature, or has the wrong ``type`` claim. Also raised
            when ``secret`` is empty.
    """
    if not secret:
        raise jwt.InvalidTokenError("JWT secret must not be empty")

    payload = jwt.decode(token, secret, algorithms=[JWT_ALGORITHM])
    token_type = payload.get("type")

    if expected_type is None:
        # "No pinned type" means "access or refresh", not "anything signed
        # with our secret": reject unknown or absent type claims.
        if token_type not in _ACCEPTED_TOKEN_TYPES:
            raise jwt.InvalidTokenError(
                f"unsupported token type: {token_type!r}"
            )
    elif token_type != expected_type:
        raise jwt.InvalidTokenError(
            f"token type mismatch: expected {expected_type!r}, got {token_type!r}"
        )
    return payload