298 lines
10 KiB
Python
298 lines
10 KiB
Python
"""JWT issuance and verification utilities.
|
|
|
|
Tokens are HS256-signed with a secret read from the ``AGENTKIT_JWT_SECRET``
|
|
environment variable. In dev mode (no secret configured) a random secret is
|
|
generated in-process and a warning is logged — this secret is *not* persisted
|
|
and will invalidate all tokens on restart, so it must never be used in
|
|
production.
|
|
|
|
Access tokens are short-lived (15 min) and carry ``type="access"``.
|
|
Refresh tokens are long-lived (7 days default, 30 days when the user opts
|
|
into "remember me") and carry ``type="refresh"``.
|
|
|
|
V2 claims (Centralized Auth & Token Persistence, U2):
|
|
- ``sid`` — the session id (UUID) referencing a row in ``auth_sessions``.
|
|
The middleware uses this to validate that the session has not
|
|
been revoked server-side, closing the kicked-out window.
|
|
- ``jti`` — a per-token unique id, used by the denylist for rotation.
|
|
Only the access token carries a ``jti``; the refresh token
|
|
uses its hashed value (``sha256(refresh_token)``) as its
|
|
rotation handle, so re-generating a jti on every refresh
|
|
would be wasteful.
|
|
|
|
Backwards-compat:
|
|
- Tokens issued before U2 landed (no ``sid``) are still accepted by
|
|
:func:`verify_token`. Validation falls through to the legacy
|
|
``user_sessions`` table via :func:`agentkit.server.auth.dependencies`
|
|
(see U10 back-compat shim).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import secrets
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
import jwt
|
|
|
|
logger = logging.getLogger(__name__)

# --- Token lifetimes ---------------------------------------------------------
# Access tokens are short-lived: 15 minutes.
ACCESS_TOKEN_TTL = timedelta(seconds=15 * 60)
# Default refresh-token lifetime: one week.
REFRESH_TOKEN_TTL = timedelta(weeks=1)
# Refresh-token lifetime when the user opts into "remember me".
REFRESH_TOKEN_TTL_REMEMBER_ME = timedelta(days=30)

# --- Signing -----------------------------------------------------------------
# HMAC-SHA256; the same algorithm name is used for both encoding and decoding.
JWT_ALGORITHM = "HS256"
|
|
|
|
|
|
@dataclass()
class TokenPair:
    """Freshly signed access and refresh JWTs, plus when each one expires."""

    # Encoded access JWT (carries ``type="access"``).
    access_token: str
    # Encoded refresh JWT (carries ``type="refresh"``).
    refresh_token: str
    # Expiry instant of ``access_token``.
    access_expires_at: datetime
    # Expiry instant of ``refresh_token``.
    refresh_expires_at: datetime
|
|
|
|
|
|
def get_jwt_secret() -> str | None:
    """Return ``AGENTKIT_JWT_SECRET`` when it holds a non-empty value.

    An unset *or empty* ``AGENTKIT_JWT_SECRET`` yields ``None`` and logs a
    warning (on every such call). The caller (middleware, routes) decides what
    ``None`` means. Use :func:`get_or_create_jwt_secret` when a non-empty
    secret is mandatory, e.g. for signing tokens in dev mode.
    """
    configured = os.environ.get("AGENTKIT_JWT_SECRET")
    if not configured:
        # Unset and empty string are treated the same: no secret configured.
        logger.warning(
            "AGENTKIT_JWT_SECRET is not set. JWT auth is disabled in the "
            "middleware; token-signing routes will use an ephemeral secret "
            "that is invalidated on restart. Set AGENTKIT_JWT_SECRET for "
            "production use."
        )
        return None
    return configured
|
|
|
|
|
|
# Process-wide holder for the ephemeral dev secret. A dict keyed by a constant
# lets ``setdefault`` act as a first-writer-wins publish step, so every caller
# in this process ends up with the same value.
_EPHEMERAL_SECRET_CACHE: dict[str, str] = {}


def get_or_create_jwt_secret() -> str:
    """Return the configured JWT secret, or a process-wide ephemeral one.

    Use this when a non-empty secret is required (e.g. signing tokens).
    For middleware configuration, prefer :func:`get_jwt_secret` so that
    the absence of a secret disables JWT auth (dev mode).

    When ``AGENTKIT_JWT_SECRET`` is unset or empty, a random secret is
    generated on the first call and cached. Later calls return that same
    value, so tokens signed through one call still verify through another.
    The cached secret lives only in memory. It is therefore lost (and all
    tokens invalidated) when the process restarts. The warning is logged only
    by the call that created the secret.

    Returns:
        A non-empty secret string.
    """
    secret = os.environ.get("AGENTKIT_JWT_SECRET")
    if secret:
        return secret

    cached = _EPHEMERAL_SECRET_CACHE.get("secret")
    if cached is not None:
        return cached

    candidate = secrets.token_urlsafe(48)
    # If another caller published first, adopt its value instead of ours.
    ephemeral = _EPHEMERAL_SECRET_CACHE.setdefault("secret", candidate)
    if ephemeral is candidate:
        logger.warning(
            "AGENTKIT_JWT_SECRET is not set — generated an ephemeral dev secret. "
            "All JWTs will be invalidated on process restart. "
            "Set AGENTKIT_JWT_SECRET in the environment for production use."
        )
    return ephemeral
|
|
|
|
|
|
def _refresh_ttl_for(remember_me: bool) -> timedelta:
    """Choose how long a refresh token issued for this login stays valid.

    - falsy ``remember_me`` → ``REFRESH_TOKEN_TTL`` (7 days)
    - truthy ``remember_me`` → ``REFRESH_TOKEN_TTL_REMEMBER_ME`` (30 days)
    """
    if not remember_me:
        return REFRESH_TOKEN_TTL
    return REFRESH_TOKEN_TTL_REMEMBER_ME
|
|
|
|
|
|
def create_token_pair(
    user_id: str,
    username: str,
    role: str,
    secret: str,
    *,
    session_id: str | None = None,
    remember_me: bool = False,
    now: datetime | None = None,
    legacy_mode: bool = False,
) -> TokenPair:
    """Sign a matching access + refresh JWT pair.

    Args:
        user_id: Subject (user id), written to the ``sub`` claim.
        username: Value of the ``username`` claim.
        role: Value of the ``role`` claim (e.g. ``member``, ``admin``).
        secret: HS256 signing secret; must be non-empty.
        session_id: Server-side session id (UUID) of an ``auth_sessions``
            row. When truthy, both tokens get a ``sid`` claim, and the
            access token also gets a random ``jti``. ``None`` selects the
            V1 back-compat path (U10). Those tokens pass ``verify_token``
            but carry no session binding.
        remember_me: When truthy, the refresh token lasts 30 days instead of
            7. The flag is also stored in the refresh token as ``rmb``.
        now: Issued-at override (for tests); defaults to the current UTC time.
        legacy_mode: U10 back-compat switch. When ``True``, no ``sid`` (and
            hence no ``jti``) is emitted even if ``session_id`` was given.
            The login route uses it for clients whose ``X-Client-Version``
            is below the rollout cutoff.

    Returns:
        A :class:`TokenPair` holding both encoded tokens and their expiry
        datetimes.

    Raises:
        ValueError: If ``secret`` is empty.
    """
    if not secret:
        raise ValueError("JWT secret must not be empty")

    # U10: legacy clients never receive a sid, even when the caller has
    # already created a session row. Such a client would not know how to use
    # it (``/auth/whoami`` falls back to its legacy branch instead).
    sid: str | None = None if legacy_mode else session_id

    issued_at = datetime.now(timezone.utc) if now is None else now
    access_exp = issued_at + ACCESS_TOKEN_TTL
    refresh_exp = issued_at + _refresh_ttl_for(remember_me)

    # The access half is exactly a standalone access token: same claims,
    # same order, and a fresh ``jti`` only when a sid is present.
    access_token = create_access_token(
        user_id,
        username,
        role,
        secret,
        session_id=sid,
        now=issued_at,
    )

    # No jti on the refresh token: its rotation handle in the denylist is
    # sha256(refresh_token) instead.
    # NOTE(review): these claims are fully deterministic at one-second
    # resolution (iat/exp are truncated to ints, no random claim). Two pairs
    # issued for the same user/session/remember_me within the same second
    # therefore yield byte-identical refresh tokens and the same sha256
    # rotation handle. Confirm the /auth/refresh rotation path cannot hit
    # (or tolerates) that case.
    refresh_claims: dict[str, object] = {
        "sub": user_id,
        "username": username,
        "role": role,
        "type": "refresh",
        "iat": int(issued_at.timestamp()),
        "exp": int(refresh_exp.timestamp()),
        # Persist remember_me so /auth/refresh can keep the original TTL
        # (30d vs 7d) without the client re-sending it. Otherwise every
        # refresh would fall back to 7 days, defeating the
        # "remember me for 30 days" checkbox.
        "rmb": remember_me,
    }
    if sid:
        refresh_claims["sid"] = sid

    refresh_token = jwt.encode(refresh_claims, secret, algorithm=JWT_ALGORITHM)
    # PyJWT >= 2 returns str; older releases returned bytes.
    if isinstance(refresh_token, bytes):
        refresh_token = refresh_token.decode("utf-8")

    return TokenPair(
        access_token=access_token,
        refresh_token=refresh_token,
        access_expires_at=access_exp,
        refresh_expires_at=refresh_exp,
    )
|
|
|
|
|
|
def create_access_token(
    user_id: str,
    username: str,
    role: str,
    secret: str,
    *,
    session_id: str | None = None,
    now: datetime | None = None,
) -> str:
    """Sign a standalone access JWT, without a companion refresh token.

    ``/auth/whoami`` cold-start uses this to hand out a fresh access token to
    a client that already holds a refresh token. Going through
    ``create_token_pair`` would mint a refresh token that gets thrown away
    (token amplification).

    Args:
        user_id: Subject (user id), written to the ``sub`` claim.
        username: Value of the ``username`` claim.
        role: Value of the ``role`` claim.
        secret: HS256 signing secret; must be non-empty.
        session_id: Server-side session id. When truthy, the token carries
            ``sid`` plus a random ``jti``.
        now: Issued-at override (for tests); defaults to the current UTC time.

    Returns:
        The encoded access token as a ``str``.

    Raises:
        ValueError: If ``secret`` is empty.
    """
    if not secret:
        raise ValueError("JWT secret must not be empty")

    issued_at = datetime.now(timezone.utc) if now is None else now
    expires_at = issued_at + ACCESS_TOKEN_TTL

    # Claim order is kept fixed; sid/jti are appended last, sid before jti.
    claims: dict[str, object] = {
        "sub": user_id,
        "username": username,
        "role": role,
        "type": "access",
        "iat": int(issued_at.timestamp()),
        "exp": int(expires_at.timestamp()),
    }
    if session_id:
        claims.update(sid=session_id, jti=str(uuid.uuid4()))

    encoded = jwt.encode(claims, secret, algorithm=JWT_ALGORITHM)
    # PyJWT >= 2 returns str; older releases returned bytes.
    return encoded.decode("utf-8") if isinstance(encoded, bytes) else encoded
|
|
|
|
|
|
def verify_token(
    token: str,
    secret: str,
    *,
    expected_type: str | None = None,
) -> dict[str, object]:
    """Verify a JWT and return its payload.

    Args:
        token: The JWT string to verify.
        secret: The HS256 signing secret.
        expected_type: When set, the ``type`` claim must equal it (e.g.
            ``"access"`` or ``"refresh"``). When ``None``, the ``type`` claim
            must be either ``"access"`` or ``"refresh"`` (used by the
            ``/auth/whoami`` cold-start path). Tokens with any other or a
            missing ``type`` are rejected.

    Returns:
        The decoded payload as a dict. V2 tokens include ``sid`` and
        ``jti`` (access only); V1 tokens lack these fields and the
        caller (U10 back-compat path) must handle the absence.

    Raises:
        jwt.InvalidTokenError: If the secret is empty, or the token is
            malformed, expired, has an invalid signature, or has a ``type``
            claim that is not accepted.
    """
    if not secret:
        raise jwt.InvalidTokenError("JWT secret must not be empty")

    payload = jwt.decode(token, secret, algorithms=[JWT_ALGORITHM])
    token_type = payload.get("type")
    if expected_type is not None:
        if token_type != expected_type:
            raise jwt.InvalidTokenError(
                f"token type mismatch: expected {expected_type!r}, got {token_type!r}"
            )
    elif token_type not in ("access", "refresh"):
        # Only the two token kinds this module issues are accepted. A token
        # signed with the same secret for any other purpose (or with no type
        # at all) must not pass as a session token.
        raise jwt.InvalidTokenError(
            f"token type mismatch: expected 'access' or 'refresh', got {token_type!r}"
        )
    return payload
|