diff --git a/src/agentkit/server/auth/denylist.py b/src/agentkit/server/auth/denylist.py new file mode 100644 index 0000000..06a14d8 --- /dev/null +++ b/src/agentkit/server/auth/denylist.py @@ -0,0 +1,147 @@ +"""Recently-revoked token denylist for refresh-token rotation (U2 / U3). + +When a refresh token is rotated, its hash is added to this denylist for a +short window (default 30 seconds) so a concurrent retry using the old +token can be detected. If the old token is reused within that window, +the :class:`SessionService` treats it as **token reuse** → revokes +all sessions for the user. + +This is the industry-standard refresh-token rotation pattern (Auth0, +Okta, AWS) that closes the window where an attacker who captured the +old token can still use it after the legitimate user has refreshed. + +Two backends are provided: + +- :class:`InMemoryRecentlyRevoked` — single-process, default for Tauri + sidecar / dev mode. Bounded by ``max_entries``; oldest entries are + evicted on overflow. +- :class:`RedisRecentlyRevoked` — multi-process via a single Redis + ``SET ... EX``; suitable for server deployments with multiple uvicorn + workers. + +The default in-process backend is the one used by the sidecar. Tests +that want full control over the window can pass a custom +:class:`InMemoryRecentlyRevoked` instance to :class:`SessionService`. 
logger = logging.getLogger(__name__)

DEFAULT_TTL_SECONDS = 30
DEFAULT_MAX_ENTRIES = 10_000


def hash_token(token: str) -> str:
    """Return the canonical SHA-256 hex digest used as the denylist key.

    The token is encoded as UTF-8 before hashing; the result is a
    64-character lowercase hexadecimal string.
    """
    return hashlib.sha256(token.encode("utf-8")).hexdigest()


class RecentlyRevoked(Protocol):
    """Storage interface for recently-revoked refresh-token hashes."""

    def add(self, token_hash: str, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> None:
        """Add ``token_hash`` to the denylist for ``ttl_seconds``."""
        ...

    def contains(self, token_hash: str) -> bool:
        """Return ``True`` if the hash is still in the denylist window."""
        ...

    def clear(self) -> None:
        """Empty the denylist (used by tests)."""
        ...


class InMemoryRecentlyRevoked:
    """Bounded in-memory denylist for the Tauri sidecar / dev mode.

    Entries expire ``ttl_seconds`` after they were (re-)added, measured
    with ``time.monotonic()``. Entries are kept in (re-)insertion order;
    when the cache grows past ``max_entries``, expired entries are swept
    first and only then are the oldest remaining entries evicted, so a
    still-live hash is never dropped while a dead one occupies a slot.

    ``contains`` is O(1). ``add`` is O(1) except when the cache is over
    capacity, where the expired-entry sweep is O(n).

    Raises:
        ValueError: If ``max_entries`` is negative.
    """

    def __init__(self, max_entries: int = DEFAULT_MAX_ENTRIES) -> None:
        if max_entries < 0:
            # A negative bound used to surface later as a KeyError from
            # ``popitem`` on an empty dict inside ``add``.
            raise ValueError("max_entries must be >= 0")
        self._max = max_entries
        # token hash -> monotonic deadline after which the entry is dead.
        self._entries: "OrderedDict[str, float]" = OrderedDict()

    def add(self, token_hash: str, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> None:
        """Remember ``token_hash`` for ``ttl_seconds``.

        Re-adding an existing hash resets its TTL and moves it to the
        newest position. A non-positive ``ttl_seconds`` means "no
        window": nothing is stored and any existing entry for the hash
        is dropped, so it cannot push a live entry out of the cache.
        """
        if ttl_seconds <= 0:
            self._entries.pop(token_hash, None)
            return

        now = time.monotonic()
        self._entries[token_hash] = now + ttl_seconds
        self._entries.move_to_end(token_hash)

        if len(self._entries) > self._max:
            # Reclaim dead slots before sacrificing live entries.
            self._purge_expired(now)
        while len(self._entries) > self._max:
            self._entries.popitem(last=False)

    def contains(self, token_hash: str) -> bool:
        """Return ``True`` while ``token_hash`` is inside its window."""
        expires_at = self._entries.get(token_hash)
        if expires_at is None:
            return False
        if expires_at <= time.monotonic():
            # Expired — drop it and report False.
            self._entries.pop(token_hash, None)
            return False
        return True

    def clear(self) -> None:
        """Drop every entry."""
        self._entries.clear()

    def __len__(self) -> int:  # pragma: no cover — debug helper
        return len(self._entries)

    def _purge_expired(self, now: float) -> None:
        """Remove every entry whose deadline is at or before ``now``."""
        dead = [key for key, deadline in self._entries.items() if deadline <= now]
        for key in dead:
            del self._entries[key]


class RedisRecentlyRevoked:
    """Redis-backed denylist for multi-process server deployments.

    Keys are stored with ``SET key value EX <ttl>`` so they auto-expire.
    The ``value`` is irrelevant; only the key's existence matters.

    Falls back to a no-op (returns ``False`` from ``contains``) when no
    client was supplied or a client call raises — this matches the
    policy "never block login on infrastructure that's transiently
    down". Failures are logged at WARNING level.
    """

    KEY_PREFIX = "auth:revoked:"

    def __init__(self, client: "object | None" = None) -> None:
        # ``client`` is duck-typed: it must offer ``set(key, value, ex=...)``,
        # ``exists(key)`` and ``delete(key)``.
        self._client = client

    def add(self, token_hash: str, ttl_seconds: int = DEFAULT_TTL_SECONDS) -> None:
        """Store ``token_hash`` with an expiry of ``ttl_seconds``.

        A non-positive ``ttl_seconds`` deletes the key instead: Redis
        rejects ``EX 0``, and deleting keeps the behaviour aligned with
        the in-memory backend, where such an entry is immediately dead.
        """
        if self._client is None:
            return
        key = self.KEY_PREFIX + token_hash
        try:
            if ttl_seconds <= 0:
                self._client.delete(key)  # type: ignore[attr-defined]
            else:
                self._client.set(key, "1", ex=ttl_seconds)  # type: ignore[attr-defined]
        except Exception as exc:  # noqa: BLE001
            logger.warning("RedisRecentlyRevoked.add failed: %s", exc)

    def contains(self, token_hash: str) -> bool:
        """Return ``True`` if the key exists; ``False`` on any failure."""
        if self._client is None:
            return False
        try:
            return bool(self._client.exists(self.KEY_PREFIX + token_hash))  # type: ignore[attr-defined]
        except Exception as exc:  # noqa: BLE001
            logger.warning("RedisRecentlyRevoked.contains failed: %s", exc)
            return False

    def clear(self) -> None:
        """Do nothing; keys are left to expire on their own."""
        # No global flush; entries auto-expire. This method exists for
        # test symmetry with the in-memory backend.
        return
+ """ + return InMemoryRecentlyRevoked() diff --git a/src/agentkit/server/auth/jwt_utils.py b/src/agentkit/server/auth/jwt_utils.py index 3a4f1ed..b4865d8 100644 --- a/src/agentkit/server/auth/jwt_utils.py +++ b/src/agentkit/server/auth/jwt_utils.py @@ -7,7 +7,24 @@ and will invalidate all tokens on restart, so it must never be used in production. Access tokens are short-lived (15 min) and carry ``type="access"``. -Refresh tokens are long-lived (7 days) and carry ``type="refresh"``. +Refresh tokens are long-lived (7 days default, 30 days when the user opts +into "remember me") and carry ``type="refresh"``. + +V2 claims (Centralized Auth & Token Persistence, U2): +- ``sid`` — the session id (UUID) referencing a row in ``auth_sessions``. + The middleware uses this to validate that the session has not + been revoked server-side, closing the kicked-out window. +- ``jti`` — a per-token unique id, used by the denylist for rotation. + Only the access token carries a ``jti``; the refresh token + uses its hashed value (``sha256(refresh_token)``) as its + rotation handle, so re-generating a jti on every refresh + would be wasteful. + +Backwards-compat: +- Tokens issued before U2 lands (no ``sid``) are still accepted by + :func:`verify_token`. Validation falls through to the legacy + ``user_sessions`` table via :func:`agentkit.server.auth.dependencies` + (see U10 back-compat shim). 
def _refresh_ttl_for(remember_me: bool) -> timedelta:
    """Pick the refresh-token lifetime for a login.

    Args:
        remember_me: Whether the user opted into "remember me".

    Returns:
        ``REFRESH_TOKEN_TTL_REMEMBER_ME`` (30 days) when ``remember_me``
        is truthy, otherwise ``REFRESH_TOKEN_TTL`` (7 days).
    """
    if remember_me:
        return REFRESH_TOKEN_TTL_REMEMBER_ME
    return REFRESH_TOKEN_TTL
def verify_token(
    token: str,
    secret: str,
    *,
    expected_type: str | None = None,
) -> dict[str, Any]:
    """Decode and validate a JWT, optionally enforcing its ``type`` claim.

    Args:
        token: The encoded JWT.
        secret: The HS256 signing secret; must be non-empty.
        expected_type: If given, the payload's ``type`` claim must equal
            this value (e.g. ``"access"`` or ``"refresh"``). If ``None``,
            the ``type`` claim is not checked at all, so access and
            refresh tokens are both returned as-is.

    Returns:
        The decoded payload. Tokens minted with a session id carry
        ``sid`` (and ``jti`` on access tokens); tokens minted without one
        lack those keys and callers must tolerate their absence.

    Raises:
        jwt.InvalidTokenError: If ``secret`` is empty, if ``jwt.decode``
            rejects the token (malformed, expired, bad signature), or if
            ``expected_type`` is set and does not match the ``type``
            claim.
    """
    if not secret:
        raise jwt.InvalidTokenError("JWT secret must not be empty")

    payload = jwt.decode(token, secret, algorithms=[JWT_ALGORITHM])

    # No type constraint requested: hand the claims back untouched.
    if expected_type is None:
        return payload

    token_type = payload.get("type")
    if token_type == expected_type:
        return payload
    raise jwt.InvalidTokenError(
        f"token type mismatch: expected {expected_type!r}, got {token_type!r}"
    )
cache.add("a", ttl_seconds=10) + cache.add("b", ttl_seconds=10) + cache.add("a", ttl_seconds=10) # re-add: should now be MRU + cache.add("c", ttl_seconds=10) # evicts oldest, which is 'b' now + assert cache.contains("a") is True + assert cache.contains("b") is False + assert cache.contains("c") is True + + +def test_clear_drops_all_entries(): + cache: InMemoryRecentlyRevoked = InMemoryRecentlyRevoked() + cache.add("a", ttl_seconds=10) + cache.add("b", ttl_seconds=10) + cache.clear() + assert cache.contains("a") is False + assert cache.contains("b") is False + + +def test_hash_token_is_stable_and_sha256_hex(): + h1 = hash_token("xyz") + h2 = hash_token("xyz") + assert h1 == h2 + assert len(h1) == 64 + assert all(c in "0123456789abcdef" for c in h1) diff --git a/tests/unit/auth/test_jwt_utils.py b/tests/unit/auth/test_jwt_utils.py new file mode 100644 index 0000000..ee04aa7 --- /dev/null +++ b/tests/unit/auth/test_jwt_utils.py @@ -0,0 +1,144 @@ +"""Unit tests for jwt_utils (V2 sid/jti claims).""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import jwt +import pytest + +from agentkit.server.auth.jwt_utils import ( + ACCESS_TOKEN_TTL, + JWT_ALGORITHM, + REFRESH_TOKEN_TTL, + REFRESH_TOKEN_TTL_REMEMBER_ME, + create_token_pair, + verify_token, +) + +SECRET = "test-secret-with-at-least-32-bytes-1234" + + +def test_create_token_pair_has_required_v1_claims(): + pair = create_token_pair( + user_id="u-1", username="alice", role="member", secret=SECRET + ) + for token in (pair.access_token, pair.refresh_token): + payload = jwt.decode(token, SECRET, algorithms=[JWT_ALGORITHM]) + assert payload["sub"] == "u-1" + assert payload["username"] == "alice" + assert payload["role"] == "member" + assert payload["iat"] and payload["exp"] + + +def test_v1_pair_has_no_sid_or_jti(): + pair = create_token_pair( + user_id="u-1", username="alice", role="member", secret=SECRET + ) + access = jwt.decode(pair.access_token, SECRET, 
algorithms=[JWT_ALGORITHM]) + refresh = jwt.decode(pair.refresh_token, SECRET, algorithms=[JWT_ALGORITHM]) + assert "sid" not in access + assert "jti" not in access + assert "sid" not in refresh + + +def test_v2_pair_includes_sid_on_both_tokens_and_jti_on_access(): + pair = create_token_pair( + user_id="u-1", + username="alice", + role="member", + secret=SECRET, + session_id="sess-abc", + ) + access = jwt.decode(pair.access_token, SECRET, algorithms=[JWT_ALGORITHM]) + refresh = jwt.decode(pair.refresh_token, SECRET, algorithms=[JWT_ALGORITHM]) + assert access["sid"] == "sess-abc" + assert refresh["sid"] == "sess-abc" + assert access["jti"] and isinstance(access["jti"], str) + # Refresh intentionally has no jti — rotation uses the token hash. + assert "jti" not in refresh + + +def test_refresh_token_type_is_refresh(): + pair = create_token_pair( + user_id="u-1", username="alice", role="member", secret=SECRET + ) + refresh = jwt.decode(pair.refresh_token, SECRET, algorithms=[JWT_ALGORITHM]) + assert refresh["type"] == "refresh" + + +def test_access_token_type_is_access(): + pair = create_token_pair( + user_id="u-1", username="alice", role="member", secret=SECRET + ) + access = jwt.decode(pair.access_token, SECRET, algorithms=[JWT_ALGORITHM]) + assert access["type"] == "access" + + +def test_default_refresh_ttl_is_7_days(): + now = datetime(2026, 6, 20, 0, 0, 0, tzinfo=timezone.utc) + pair = create_token_pair( + user_id="u-1", + username="alice", + role="member", + secret=SECRET, + now=now, + ) + assert (pair.access_expires_at - now) == ACCESS_TOKEN_TTL + assert (pair.refresh_expires_at - now) == REFRESH_TOKEN_TTL + + +def test_remember_me_extends_refresh_ttl_to_30_days(): + now = datetime(2026, 6, 20, 0, 0, 0, tzinfo=timezone.utc) + pair = create_token_pair( + user_id="u-1", + username="alice", + role="member", + secret=SECRET, + now=now, + remember_me=True, + ) + assert (pair.refresh_expires_at - now) == REFRESH_TOKEN_TTL_REMEMBER_ME + + +def 
test_verify_token_accepts_access_and_refresh_by_default(): + pair = create_token_pair( + user_id="u-1", username="alice", role="member", secret=SECRET + ) + a = verify_token(pair.access_token, SECRET) + r = verify_token(pair.refresh_token, SECRET) + assert a["type"] == "access" + assert r["type"] == "refresh" + + +def test_verify_token_expected_type_filters_other_type(): + pair = create_token_pair( + user_id="u-1", username="alice", role="member", secret=SECRET + ) + with pytest.raises(jwt.InvalidTokenError): + verify_token(pair.access_token, SECRET, expected_type="refresh") + with pytest.raises(jwt.InvalidTokenError): + verify_token(pair.refresh_token, SECRET, expected_type="access") + + +def test_verify_token_rejects_expired(): + # Use a "now" far in the past so the access token is definitely expired. + past = datetime.now(timezone.utc) - timedelta(hours=1) + pair = create_token_pair( + user_id="u-1", username="alice", role="member", secret=SECRET, now=past + ) + with pytest.raises(jwt.ExpiredSignatureError): + verify_token(pair.access_token, SECRET) + + +def test_verify_token_rejects_wrong_secret(): + pair = create_token_pair( + user_id="u-1", username="alice", role="member", secret=SECRET + ) + with pytest.raises(jwt.InvalidTokenError): + verify_token(pair.access_token, "other-secret") + + +def test_create_token_pair_rejects_empty_secret(): + with pytest.raises(ValueError): + create_token_pair(user_id="u-1", username="alice", role="member", secret="")