feat(auth): U2 JWT sid/jti claims + refresh-token denylist

Adds V2 JWT claim schema that closes the kicked-out window and enables
refresh-token rotation with reuse detection.

Server
- jwt_utils.create_token_pair now takes ``session_id`` and ``remember_me``
  kwargs.  When ``session_id`` is provided, both tokens carry a ``sid``
  claim and the access token also carries a ``jti`` claim; the refresh
  token's jti is intentionally absent (rotation uses the token hash).
- New ``REFRESH_TOKEN_TTL_REMEMBER_ME = 30d`` (default 7d) selected by
  the ``remember_me`` flag.
- ``verify_token`` now supports an optional ``expected_type`` filter
  (e.g. ``"access"`` / ``"refresh"``); when omitted, both types pass
  (used by /auth/whoami's cold-start path).
- New ``auth.denylist`` module: ``InMemoryRecentlyRevoked`` (default for
  the Tauri sidecar / dev mode) and ``RedisRecentlyRevoked`` (multi-
  process server).  Bounded LRU with auto-expiry via ``time.monotonic()``.

Backwards-compat
- Tokens issued before U2 (no ``sid``) are still accepted by
  ``verify_token``; validation falls through to the legacy
  ``user_sessions`` table via the U10 shim (next commit).

Tests
- tests/unit/auth/test_jwt_utils.py: 12 cases — V1/V2 claim presence,
  default + remember-me TTL, expected_type filter, expiry, wrong secret.
- tests/unit/auth/test_denylist.py: 6 cases — add/contains, TTL expiry,
  LRU eviction, re-add refresh, clear, hash stability.

Refs: U2 in docs/plans/2026-06-20-002-feat-centralized-auth-token-persistence-plan.md
This commit is contained in:
chiguyong 2026-06-21 01:53:13 +08:00
parent e39bf56248
commit 5ba1aceb96
4 changed files with 428 additions and 9 deletions

View File

@ -0,0 +1,147 @@
"""Recently-revoked token denylist for refresh-token rotation (U2 / U3).
When a refresh token is rotated, its hash is added to this denylist for a
short window (default 30 seconds) so a concurrent retry using the old
token can be detected. If the old token is reused within that window,
the :class:`SessionService` treats it as **token reuse** revokes
all sessions for the user.
This is the industry-standard refresh-token rotation pattern (Auth0,
Okta, AWS) that closes the window where an attacker who captured the
old token can still use it after the legitimate user has refreshed.
Two backends are provided:
- :class:`InMemoryRecentlyRevoked` single-process, default for Tauri
sidecar / dev mode. Bounded by ``max_entries``; oldest entries are
evicted on overflow.
- :class:`RedisRecentlyRevoked` multi-process via Redis ``SET`` +
``EXPIRE``; suitable for server deployments with multiple uvicorn
workers.
The default in-process backend is the one used by the sidecar. Tests
that want full control over the window can pass a custom
:class:`InMemoryRecentlyRevoked` instance to :class:`SessionService`.
"""
from __future__ import annotations
import hashlib
import logging
import time
from collections import OrderedDict
from typing import Protocol
logger = logging.getLogger(__name__)
DEFAULT_TTL_SECONDS = 30
DEFAULT_MAX_ENTRIES = 10_000
def hash_token(token: str) -> str:
"""Return the canonical SHA-256 hex digest used as the denylist key."""
return hashlib.sha256(token.encode("utf-8")).hexdigest()
class RecentlyRevoked(Protocol):
"""Storage interface for recently-revoked refresh-token hashes."""
def add(self, token_hash: str, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> None:
"""Add ``token_hash`` to the denylist for ``ttl_seconds``."""
...
def contains(self, token_hash: str) -> bool:
"""Return ``True`` if the hash is still in the denylist window."""
...
def clear(self) -> None:
"""Empty the denylist (used by tests)."""
...
class InMemoryRecentlyRevoked:
"""Bounded in-memory LRU denylist for the Tauri sidecar / dev mode.
Entries auto-expire after ``ttl_seconds``. When the cache exceeds
``max_entries`` the oldest entries are evicted. All operations are
O(1).
"""
def __init__(self, max_entries: int = DEFAULT_MAX_ENTRIES) -> None:
self._max = max_entries
self._entries: "OrderedDict[str, float]" = OrderedDict()
def add(self, token_hash: str, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> None:
expires_at = time.monotonic() + ttl_seconds
# Move-to-end semantics: re-adding an existing key resets the TTL
# and refreshes its position (LRU).
self._entries[token_hash] = expires_at
self._entries.move_to_end(token_hash)
while len(self._entries) > self._max:
self._entries.popitem(last=False)
def contains(self, token_hash: str) -> bool:
expires_at = self._entries.get(token_hash)
if expires_at is None:
return False
if expires_at <= time.monotonic():
# Expired — drop it and report False.
self._entries.pop(token_hash, None)
return False
return True
def clear(self) -> None:
self._entries.clear()
def __len__(self) -> int: # pragma: no cover — debug helper
return len(self._entries)
class RedisRecentlyRevoked:
"""Redis-backed denylist for multi-process server deployments.
Keys are stored with ``SET key value EX <ttl>`` so they auto-expire.
The ``value`` is irrelevant; only the key's existence matters.
Falls back to a no-op (returns ``False`` from ``contains``) when the
``redis`` package is unavailable or the connection fails this
matches the policy "never block login on infrastructure that's
transiently down".
"""
KEY_PREFIX = "auth:revoked:"
def __init__(self, client: object | None = None) -> None:
self._client = client
def add(self, token_hash: str, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> None:
if self._client is None:
return
try:
self._client.set(self.KEY_PREFIX + token_hash, "1", ex=ttl_seconds) # type: ignore[attr-defined]
except Exception as exc: # noqa: BLE001
logger.warning("RedisRecentlyRevoked.add failed: %s", exc)
def contains(self, token_hash: str) -> bool:
if self._client is None:
return False
try:
return bool(self._client.exists(self.KEY_PREFIX + token_hash)) # type: ignore[attr-defined]
except Exception as exc: # noqa: BLE001
logger.warning("RedisRecentlyRevoked.contains failed: %s", exc)
return False
def clear(self) -> None:
# No global flush; entries auto-expire. This method exists for
# test symmetry with the in-memory backend.
return
def build_default() -> RecentlyRevoked:
"""Build the default denylist for the current process.
Returns the in-memory implementation. Callers (e.g.
:class:`SessionService`) can pass a custom instance for testing
or to use the Redis backend in production.
"""
return InMemoryRecentlyRevoked()

View File

@ -7,7 +7,24 @@ and will invalidate all tokens on restart, so it must never be used in
production.
Access tokens are short-lived (15 min) and carry ``type="access"``.
Refresh tokens are long-lived (7 days) and carry ``type="refresh"``.
Refresh tokens are long-lived (7 days default, 30 days when the user opts
into "remember me") and carry ``type="refresh"``.
V2 claims (Centralized Auth & Token Persistence, U2):
- ``sid`` the session id (UUID) referencing a row in ``auth_sessions``.
The middleware uses this to validate that the session has not
been revoked server-side, closing the kicked-out window.
- ``jti`` a per-token unique id, used by the denylist for rotation.
Only the access token carries a ``jti``; the refresh token
uses its hashed value (``sha256(refresh_token)``) as its
rotation handle, so re-generating a jti on every refresh
would be wasteful.
Backwards-compat:
- Tokens issued before U2 land (no ``sid``) are still accepted by
:func:`verify_token`. Validation falls through to the legacy
``user_sessions`` table via :func:`agentkit.server.auth.dependencies`
(see U10 back-compat shim).
"""
from __future__ import annotations
@ -15,6 +32,7 @@ from __future__ import annotations
import logging
import os
import secrets
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any
@ -26,6 +44,7 @@ logger = logging.getLogger(__name__)
# Token lifetimes
ACCESS_TOKEN_TTL = timedelta(minutes=15)
REFRESH_TOKEN_TTL = timedelta(days=7)
REFRESH_TOKEN_TTL_REMEMBER_ME = timedelta(days=30)
# JWT algorithm
JWT_ALGORITHM = "HS256"
@ -80,12 +99,23 @@ def get_or_create_jwt_secret() -> str:
return ephemeral
def _refresh_ttl_for(remember_me: bool) -> timedelta:
"""Return the refresh-token TTL for a given login.
- ``remember_me=False`` (default) 7 days
- ``remember_me=True`` 30 days
"""
return REFRESH_TOKEN_TTL_REMEMBER_ME if remember_me else REFRESH_TOKEN_TTL
def create_token_pair(
user_id: str,
username: str,
role: str,
secret: str,
*,
session_id: str | None = None,
remember_me: bool = False,
now: datetime | None = None,
) -> TokenPair:
"""Create a signed access + refresh JWT pair.
@ -95,6 +125,14 @@ def create_token_pair(
username: Username claim.
role: Role claim (e.g. ``member``, ``admin``).
secret: HS256 signing secret.
session_id: Server-side session id (UUID) referencing a row in
``auth_sessions``. When provided, the tokens carry a
``sid`` claim and the access token also carries a
``jti`` claim. Pass ``None`` for the V1 back-compat path
(U10): the resulting tokens are accepted by ``verify_token``
but lack session validation.
remember_me: When ``True``, the refresh token is valid for 30
days instead of the default 7.
now: Override the issued-at time (for testing). Defaults to UTC now.
Returns:
@ -104,8 +142,15 @@ def create_token_pair(
raise ValueError("JWT secret must not be empty")
issued_at = now or datetime.now(timezone.utc)
refresh_ttl = _refresh_ttl_for(remember_me)
access_exp = issued_at + ACCESS_TOKEN_TTL
refresh_exp = issued_at + REFRESH_TOKEN_TTL
refresh_exp = issued_at + refresh_ttl
# jti: per-token unique id (access only). The refresh token uses the
# sha256 of its plaintext value as the rotation handle in the
# denylist; giving it a jti too would be redundant and would bloat
# the JWT.
jti = str(uuid.uuid4()) if session_id else None
access_payload: dict[str, Any] = {
"sub": user_id,
@ -123,6 +168,10 @@ def create_token_pair(
"iat": int(issued_at.timestamp()),
"exp": int(refresh_exp.timestamp()),
}
if session_id:
access_payload["sid"] = session_id
access_payload["jti"] = jti
refresh_payload["sid"] = session_id
access_token = jwt.encode(access_payload, secret, algorithm=JWT_ALGORITHM)
refresh_token = jwt.encode(refresh_payload, secret, algorithm=JWT_ALGORITHM)
@ -141,23 +190,38 @@ def create_token_pair(
)
def verify_token(token: str, secret: str) -> dict[str, Any]:
def verify_token(
token: str,
secret: str,
*,
expected_type: str | None = None,
) -> dict[str, Any]:
"""Verify a JWT and return its payload.
Args:
token: The JWT string to verify.
secret: The HS256 signing secret.
expected_type: When set, the ``type`` claim must match (e.g.
``"access"`` or ``"refresh"``). When ``None``, both
``access`` and ``refresh`` tokens are accepted (used by
the ``/auth/whoami`` cold-start path).
Returns:
The decoded payload as a dict (contains ``sub``, ``username``,
``role``, ``type``, ``iat``, ``exp``).
The decoded payload as a dict. V2 tokens include ``sid`` and
``jti`` (access only); V1 tokens lack these fields and the
caller (U10 back-compat path) must handle the absence.
Raises:
jwt.InvalidTokenError: If the token is malformed, expired, or has
an invalid signature. Subclasses include ``ExpiredSignatureError``
and ``DecodeError``.
jwt.InvalidTokenError: If the token is malformed, expired, has an
invalid signature, or has the wrong ``type`` claim.
"""
if not secret:
raise jwt.InvalidTokenError("JWT secret must not be empty")
return jwt.decode(token, secret, algorithms=[JWT_ALGORITHM])
payload = jwt.decode(token, secret, algorithms=[JWT_ALGORITHM])
token_type = payload.get("type")
if expected_type is not None and token_type != expected_type:
raise jwt.InvalidTokenError(
f"token type mismatch: expected {expected_type!r}, got {token_type!r}"
)
return payload

View File

@ -0,0 +1,64 @@
"""Unit tests for the recently-revoked denylist (U2 / U3)."""
from __future__ import annotations
import time
from agentkit.server.auth.denylist import (
InMemoryRecentlyRevoked,
hash_token,
)
def test_add_then_contains_within_ttl():
cache: InMemoryRecentlyRevoked = InMemoryRecentlyRevoked()
h = hash_token("token-a")
cache.add(h, ttl_seconds=10)
assert cache.contains(h) is True
def test_contains_returns_false_after_ttl_expires():
cache: InMemoryRecentlyRevoked = InMemoryRecentlyRevoked()
h = hash_token("token-b")
cache.add(h, ttl_seconds=0)
# TTL=0 → expires immediately (or very soon)
time.sleep(0.05)
assert cache.contains(h) is False
def test_evicts_oldest_when_overflowed():
cache: InMemoryRecentlyRevoked = InMemoryRecentlyRevoked(max_entries=2)
cache.add("a", ttl_seconds=10)
cache.add("b", ttl_seconds=10)
cache.add("c", ttl_seconds=10) # pushes 'a' out (LRU eviction)
assert cache.contains("a") is False
assert cache.contains("b") is True
assert cache.contains("c") is True
def test_re_add_refreshes_ttl_and_lru_position():
cache: InMemoryRecentlyRevoked = InMemoryRecentlyRevoked(max_entries=2)
cache.add("a", ttl_seconds=10)
cache.add("b", ttl_seconds=10)
cache.add("a", ttl_seconds=10) # re-add: should now be MRU
cache.add("c", ttl_seconds=10) # evicts oldest, which is 'b' now
assert cache.contains("a") is True
assert cache.contains("b") is False
assert cache.contains("c") is True
def test_clear_drops_all_entries():
cache: InMemoryRecentlyRevoked = InMemoryRecentlyRevoked()
cache.add("a", ttl_seconds=10)
cache.add("b", ttl_seconds=10)
cache.clear()
assert cache.contains("a") is False
assert cache.contains("b") is False
def test_hash_token_is_stable_and_sha256_hex():
h1 = hash_token("xyz")
h2 = hash_token("xyz")
assert h1 == h2
assert len(h1) == 64
assert all(c in "0123456789abcdef" for c in h1)

View File

@ -0,0 +1,144 @@
"""Unit tests for jwt_utils (V2 sid/jti claims)."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
import jwt
import pytest
from agentkit.server.auth.jwt_utils import (
ACCESS_TOKEN_TTL,
JWT_ALGORITHM,
REFRESH_TOKEN_TTL,
REFRESH_TOKEN_TTL_REMEMBER_ME,
create_token_pair,
verify_token,
)
SECRET = "test-secret-with-at-least-32-bytes-1234"
def test_create_token_pair_has_required_v1_claims():
pair = create_token_pair(
user_id="u-1", username="alice", role="member", secret=SECRET
)
for token in (pair.access_token, pair.refresh_token):
payload = jwt.decode(token, SECRET, algorithms=[JWT_ALGORITHM])
assert payload["sub"] == "u-1"
assert payload["username"] == "alice"
assert payload["role"] == "member"
assert payload["iat"] and payload["exp"]
def test_v1_pair_has_no_sid_or_jti():
pair = create_token_pair(
user_id="u-1", username="alice", role="member", secret=SECRET
)
access = jwt.decode(pair.access_token, SECRET, algorithms=[JWT_ALGORITHM])
refresh = jwt.decode(pair.refresh_token, SECRET, algorithms=[JWT_ALGORITHM])
assert "sid" not in access
assert "jti" not in access
assert "sid" not in refresh
def test_v2_pair_includes_sid_on_both_tokens_and_jti_on_access():
pair = create_token_pair(
user_id="u-1",
username="alice",
role="member",
secret=SECRET,
session_id="sess-abc",
)
access = jwt.decode(pair.access_token, SECRET, algorithms=[JWT_ALGORITHM])
refresh = jwt.decode(pair.refresh_token, SECRET, algorithms=[JWT_ALGORITHM])
assert access["sid"] == "sess-abc"
assert refresh["sid"] == "sess-abc"
assert access["jti"] and isinstance(access["jti"], str)
# Refresh intentionally has no jti — rotation uses the token hash.
assert "jti" not in refresh
def test_refresh_token_type_is_refresh():
pair = create_token_pair(
user_id="u-1", username="alice", role="member", secret=SECRET
)
refresh = jwt.decode(pair.refresh_token, SECRET, algorithms=[JWT_ALGORITHM])
assert refresh["type"] == "refresh"
def test_access_token_type_is_access():
pair = create_token_pair(
user_id="u-1", username="alice", role="member", secret=SECRET
)
access = jwt.decode(pair.access_token, SECRET, algorithms=[JWT_ALGORITHM])
assert access["type"] == "access"
def test_default_refresh_ttl_is_7_days():
now = datetime(2026, 6, 20, 0, 0, 0, tzinfo=timezone.utc)
pair = create_token_pair(
user_id="u-1",
username="alice",
role="member",
secret=SECRET,
now=now,
)
assert (pair.access_expires_at - now) == ACCESS_TOKEN_TTL
assert (pair.refresh_expires_at - now) == REFRESH_TOKEN_TTL
def test_remember_me_extends_refresh_ttl_to_30_days():
now = datetime(2026, 6, 20, 0, 0, 0, tzinfo=timezone.utc)
pair = create_token_pair(
user_id="u-1",
username="alice",
role="member",
secret=SECRET,
now=now,
remember_me=True,
)
assert (pair.refresh_expires_at - now) == REFRESH_TOKEN_TTL_REMEMBER_ME
def test_verify_token_accepts_access_and_refresh_by_default():
pair = create_token_pair(
user_id="u-1", username="alice", role="member", secret=SECRET
)
a = verify_token(pair.access_token, SECRET)
r = verify_token(pair.refresh_token, SECRET)
assert a["type"] == "access"
assert r["type"] == "refresh"
def test_verify_token_expected_type_filters_other_type():
pair = create_token_pair(
user_id="u-1", username="alice", role="member", secret=SECRET
)
with pytest.raises(jwt.InvalidTokenError):
verify_token(pair.access_token, SECRET, expected_type="refresh")
with pytest.raises(jwt.InvalidTokenError):
verify_token(pair.refresh_token, SECRET, expected_type="access")
def test_verify_token_rejects_expired():
# Use a "now" far in the past so the access token is definitely expired.
past = datetime.now(timezone.utc) - timedelta(hours=1)
pair = create_token_pair(
user_id="u-1", username="alice", role="member", secret=SECRET, now=past
)
with pytest.raises(jwt.ExpiredSignatureError):
verify_token(pair.access_token, SECRET)
def test_verify_token_rejects_wrong_secret():
pair = create_token_pair(
user_id="u-1", username="alice", role="member", secret=SECRET
)
with pytest.raises(jwt.InvalidTokenError):
verify_token(pair.access_token, "other-secret")
def test_create_token_pair_rejects_empty_secret():
with pytest.raises(ValueError):
create_token_pair(user_id="u-1", username="alice", role="member", secret="")