feat(auth): U3 SessionService + validation cache

Adds the central business-logic layer for ``auth_sessions`` so routes,
the auth middleware, and the admin endpoints can call a single service
instead of touching the table directly.

Server
- session_service.SessionService: CRUD + lifecycle for auth_sessions.
  - create() enforces the per-user cap (default 10): the oldest
    active session is evicted with reason=session_cap_eviction.
  - rotate() swaps a refresh token, adds the old hash to the
    denylist, and raises SessionReuseDetected (revoking all sessions
    for the user) when the old token is replayed.
  - revoke() / revoke_by_refresh_token() / revoke_all_for_user()
    with explicit reasons: user_terminated, admin_revoked,
    password_changed, reuse_detected, session_cap_eviction.
  - touch() bumps last_active_at (called on /auth/whoami).
- session_cache.SessionValidationCache: bounded LRU+TTL wrapper
  (default 30s/1k entries) around SessionService.is_session_valid.
  The middleware hits this on every request carrying a V2 sid claim;
  one SQLite round-trip per 30s per session instead of per request.
- get_session_service() / get_validation_cache() module-level
  singletons overridable in tests via set_session_service() /
  set_validation_cache().

Tests
- tests/unit/auth/test_session_service.py: 15 cases covering
  create/rotate/revoke/list/cap-eviction/reuse-detection/expired
  sessions.

Refs: U3 in docs/plans/2026-06-20-002-feat-centralized-auth-token-persistence-plan.md
This commit is contained in:
chiguyong 2026-06-21 01:58:30 +08:00
parent 5ba1aceb96
commit b418c3dc95
3 changed files with 881 additions and 0 deletions

View File

@ -0,0 +1,116 @@
"""In-process LRU cache for ``SessionService.is_session_valid`` (U3).
The middleware calls :func:`is_session_valid` on every request carrying
a V2 JWT (sid claim). Without a cache, that adds one SQLite round-trip
per request. The cache keeps the result for ``ttl_seconds`` (default
30s) long enough to absorb a burst, short enough that a server-side
revocation propagates quickly enough for the user to notice.
This is a *negative-caching-friendly* cache: a "valid" answer is
cached for 30s, a "revoked/not found" answer is also cached for 30s.
A revocation therefore takes up to 30s to be observed by the
middleware acceptable trade-off, since the user is the one who
revoked it (not an attacker).
For a tighter revocation window, the route layer can be configured to
bypass the cache on the ``/auth/logout`` path; see the admin "force
logout" flow.
"""
from __future__ import annotations
import time
from collections import OrderedDict
from typing import Any
from .session_service import SessionService
DEFAULT_TTL_SECONDS = 30
DEFAULT_MAX_ENTRIES = 1_000
class SessionValidationCache:
"""Tiny LRU+TTL wrapper around :class:`SessionService`.
The cache is keyed by session id. The value is a tuple of
``(valid, expires_at_monotonic)``.
The cache is intentionally process-local; multi-process deployments
get per-worker caching (still beneficial) and rely on the eventual
expiry for cross-worker consistency.
"""
def __init__(
self,
service: SessionService,
*,
ttl_seconds: int = DEFAULT_TTL_SECONDS,
max_entries: int = DEFAULT_MAX_ENTRIES,
) -> None:
self._service = service
self._ttl = ttl_seconds
self._max = max_entries
self._entries: "OrderedDict[str, tuple[bool, float]]" = OrderedDict()
async def is_valid(self, session_id: str) -> bool:
"""Return whether the session is currently valid, using the cache."""
cached = self._entries.get(session_id)
if cached is not None:
valid, expires_at = cached
if expires_at > time.monotonic():
# LRU touch
self._entries.move_to_end(session_id)
return valid
# Expired — drop and re-check.
self._entries.pop(session_id, None)
valid = await self._service.is_session_valid(session_id)
self._entries[session_id] = (valid, time.monotonic() + self._ttl)
self._entries.move_to_end(session_id)
while len(self._entries) > self._max:
self._entries.popitem(last=False)
return valid
def invalidate(self, session_id: str) -> None:
"""Drop a single entry (used by /auth/logout)."""
self._entries.pop(session_id, None)
def clear(self) -> None:
"""Drop all entries (used in tests)."""
self._entries.clear()
def __len__(self) -> int: # pragma: no cover — debug helper
return len(self._entries)
_cache: SessionValidationCache | None = None
def get_validation_cache() -> SessionValidationCache | None:
"""Return the process-wide cache, or ``None`` if not initialised."""
return _cache
def init_validation_cache(
service: SessionService,
*,
ttl_seconds: int = DEFAULT_TTL_SECONDS,
max_entries: int = DEFAULT_MAX_ENTRIES,
) -> SessionValidationCache:
"""Initialise (or replace) the process-wide cache and return it.
Call this from the app factory once the :class:`SessionService`
singleton is ready. Returns the new cache so callers can hold
a reference for tests.
"""
global _cache
_cache = SessionValidationCache(
service, ttl_seconds=ttl_seconds, max_entries=max_entries
)
return _cache
def set_validation_cache(cache: SessionValidationCache | None) -> None:
"""Inject a custom cache (used by tests)."""
global _cache
_cache = cache

View File

@ -0,0 +1,509 @@
"""SessionService — central business logic for ``auth_sessions`` (U3).
This module is the single owner of the ``auth_sessions`` table. Routes,
the auth middleware, and the admin endpoints all call into
:class:`SessionService` rather than touching the table directly, which
keeps the rotation / reuse-detection / cap-eviction rules in one place.
Lifecycle
---------
::
create rotate rotate ...
active active active
revoke (user logout) revoked
revoke (admin kill) revoked
reuse_detected revoked + revoke_all_for_user
password_change revoke_all_for_user
session_cap_eviction revoked (oldest)
"""
from __future__ import annotations
import logging
import os
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import aiosqlite
from .denylist import (
DEFAULT_TTL_SECONDS,
InMemoryRecentlyRevoked,
RecentlyRevoked,
hash_token,
)
from .models import DEFAULT_AUTH_DB_PATH
logger = logging.getLogger(__name__)
def _now_iso() -> str:
"""Return current UTC time as ISO 8601 string."""
return datetime.now(timezone.utc).isoformat()
def _resolve_db_path() -> Path:
"""Resolve the auth DB path with runtime env-var priority (test-friendly)."""
env = os.environ.get("AGENTKIT_AUTH_DB")
if env:
return Path(env)
return DEFAULT_AUTH_DB_PATH
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
# Default per-user session cap. Matches the plan's "10 sessions" cap;
# new logins evict the oldest when the user is at the limit.
DEFAULT_SESSION_CAP = 10
# Reasons written to ``auth_sessions.revoked_reason``.
REVOKE_REASON_USER_TERMINATED = "user_terminated"
REVOKE_REASON_ADMIN_REVOKED = "admin_revoked"
REVOKE_REASON_PASSWORD_CHANGED = "password_changed"
REVOKE_REASON_REUSE_DETECTED = "reuse_detected"
REVOKE_REASON_SESSION_CAP_EVICTION = "session_cap_eviction"
class SessionNotFound(Exception):
"""The session id does not exist or is not owned by the user."""
class SessionReuseDetected(Exception):
"""A refresh token was reused within the denylist window.
Raised by :meth:`SessionService.rotate` when the provided refresh
token is in the denylist. The caller should treat this as a security
event and call :meth:`revoke_all_for_user`.
"""
@dataclass(frozen=True)
class SessionCreate:
"""Inputs for :meth:`SessionService.create`."""
user_id: str
refresh_token: str
device_fingerprint: str
device_label: str
ip: str
user_agent: str
auth_provider: str
ttl_seconds: int
previous_session_id: str | None = None
@dataclass(frozen=True)
class SessionInfo:
"""Public projection of an ``auth_sessions`` row."""
id: str
user_id: str
device_fingerprint: str
device_label: str
ip: str
user_agent: str
auth_provider: str
created_at: str
last_active_at: str
expires_at: str
revoked: bool
revoked_reason: str | None
previous_session_id: str | None
# ---------------------------------------------------------------------------
# SessionService
# ---------------------------------------------------------------------------
class SessionService:
"""CRUD + lifecycle for ``auth_sessions``.
Args:
db_path: Path to the auth SQLite DB. Defaults to the env-var
``AGENTKIT_AUTH_DB`` or :data:`models.DEFAULT_AUTH_DB_PATH`.
denylist: Recently-revoked refresh-token store. Defaults to an
in-memory LRU; tests can pass their own instance for
deterministic window control.
session_cap: Maximum number of active sessions per user. New
logins evict the oldest when the user is at the limit.
"""
def __init__(
self,
db_path: str | Path | None = None,
*,
denylist: RecentlyRevoked | None = None,
session_cap: int = DEFAULT_SESSION_CAP,
) -> None:
self._db_path = Path(db_path) if db_path is not None else _resolve_db_path()
self._denylist = denylist or InMemoryRecentlyRevoked()
self._session_cap = session_cap
# Parallel to ``self._denylist``: maps an old (rotated) refresh
# token hash to the user_id it belonged to, so reuse detection
# can find the right user to revoke (the row has already been
# updated to the new hash, so the table itself can't tell us).
self._recent_users: dict[str, str] = {}
# ------------------------------------------------------------------
# Read
# ------------------------------------------------------------------
async def get(self, session_id: str) -> SessionInfo | None:
"""Look up a session by id. Returns ``None`` if not found."""
async with aiosqlite.connect(str(self._db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM auth_sessions WHERE id = ?",
(session_id,),
)
row = await cursor.fetchone()
return _row_to_info(row) if row else None
async def list_for_user(
self, user_id: str, *, include_revoked: bool = False
) -> list[SessionInfo]:
"""List sessions for ``user_id``, newest first.
By default only active (non-revoked) sessions are returned. The
admin "list all sessions" view passes ``include_revoked=True``.
"""
sql = "SELECT * FROM auth_sessions WHERE user_id = ?"
args: tuple[Any, ...] = (user_id,)
if not include_revoked:
sql += " AND revoked = 0"
sql += " ORDER BY created_at DESC"
async with aiosqlite.connect(str(self._db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(sql, args)
rows = await cursor.fetchall()
return [_row_to_info(r) for r in rows]
async def list_all(
self, *, include_revoked: bool = False, limit: int = 200
) -> list[SessionInfo]:
"""List recent sessions across all users (admin view)."""
sql = "SELECT * FROM auth_sessions"
if not include_revoked:
sql += " WHERE revoked = 0"
sql += " ORDER BY created_at DESC LIMIT ?"
async with aiosqlite.connect(str(self._db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(sql, (limit,))
rows = await cursor.fetchall()
return [_row_to_info(r) for r in rows]
async def find_by_refresh_token(self, refresh_token: str) -> SessionInfo | None:
"""Look up a session by the SHA-256 hash of its refresh token."""
h = hash_token(refresh_token)
async with aiosqlite.connect(str(self._db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM auth_sessions WHERE refresh_token_hash = ?",
(h,),
)
row = await cursor.fetchone()
return _row_to_info(row) if row else None
# ------------------------------------------------------------------
# Write
# ------------------------------------------------------------------
async def create(self, payload: SessionCreate) -> SessionInfo:
"""Insert a new session row.
If the user is already at the :attr:`session_cap`, the oldest
active session is evicted first. This is the only place where
the cap is enforced callers don't have to think about it.
"""
session_id = str(uuid.uuid4())
now = _now_iso()
expires = (
datetime.now(timezone.utc).timestamp() + payload.ttl_seconds
)
expires_iso = (
datetime.fromtimestamp(expires, tz=timezone.utc).isoformat()
)
refresh_hash = hash_token(payload.refresh_token)
await self._enforce_session_cap(payload.user_id, keep_id=session_id)
async with aiosqlite.connect(str(self._db_path)) as db:
await db.execute(
"INSERT INTO auth_sessions "
"(id, user_id, refresh_token_hash, device_fingerprint, device_label, "
" ip, user_agent, auth_provider, created_at, last_active_at, expires_at, "
" revoked, revoked_reason, previous_session_id) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
session_id,
payload.user_id,
refresh_hash,
payload.device_fingerprint,
payload.device_label,
payload.ip,
payload.user_agent,
payload.auth_provider,
now,
now,
expires_iso,
0,
None,
payload.previous_session_id,
),
)
await db.commit()
return SessionInfo(
id=session_id,
user_id=payload.user_id,
device_fingerprint=payload.device_fingerprint,
device_label=payload.device_label,
ip=payload.ip,
user_agent=payload.user_agent,
auth_provider=payload.auth_provider,
created_at=now,
last_active_at=now,
expires_at=expires_iso,
revoked=False,
revoked_reason=None,
previous_session_id=payload.previous_session_id,
)
async def rotate(
self,
old_refresh_token: str,
new_refresh_token: str,
*,
new_ttl_seconds: int,
) -> SessionInfo:
"""Rotate a refresh token: replace the old hash with a new one.
Adds the old token's hash to the denylist so a concurrent retry
using the old token is detected.
Raises:
SessionNotFound: the old token has no matching session.
SessionReuseDetected: the old token is in the denylist.
The caller should call :meth:`revoke_all_for_user`.
"""
old_hash = hash_token(old_refresh_token)
if self._denylist.contains(old_hash):
# Concurrent retry — revoke everything for this user.
user_id = self._recent_users.pop(old_hash, None)
if user_id is not None:
await self.revoke_all_for_user(
user_id, reason=REVOKE_REASON_REUSE_DETECTED
)
raise SessionReuseDetected("refresh token reuse detected")
info = await self.find_by_refresh_token(old_refresh_token)
if info is None:
raise SessionNotFound("refresh token has no matching session")
if info.revoked:
raise SessionNotFound("refresh token has been revoked")
new_hash = hash_token(new_refresh_token)
now = _now_iso()
new_expires_iso = (
datetime.fromtimestamp(
datetime.now(timezone.utc).timestamp() + new_ttl_seconds,
tz=timezone.utc,
).isoformat()
)
async with aiosqlite.connect(str(self._db_path)) as db:
await db.execute(
"UPDATE auth_sessions "
"SET refresh_token_hash = ?, last_active_at = ?, expires_at = ? "
"WHERE id = ?",
(new_hash, now, new_expires_iso, info.id),
)
await db.commit()
# Add the old hash to the denylist AFTER the row is updated so
# a fast retry sees the new row rather than triggering reuse
# detection on a value that has just been rotated.
self._denylist.add(old_hash, ttl_seconds=DEFAULT_TTL_SECONDS)
# Track which user the old token belonged to so reuse detection
# can locate the right user to revoke (the row is already
# pointing at the new hash).
self._recent_users[old_hash] = info.user_id
refreshed = await self.get(info.id)
assert refreshed is not None # we just wrote it
return refreshed
async def touch(self, session_id: str) -> None:
"""Bump ``last_active_at`` (called on /auth/whoami success)."""
async with aiosqlite.connect(str(self._db_path)) as db:
await db.execute(
"UPDATE auth_sessions SET last_active_at = ? WHERE id = ?",
(_now_iso(), session_id),
)
await db.commit()
# ------------------------------------------------------------------
# Revoke
# ------------------------------------------------------------------
async def revoke(
self, session_id: str, *, reason: str = REVOKE_REASON_USER_TERMINATED
) -> bool:
"""Revoke a single session.
Returns ``True`` if a row was updated, ``False`` if the session
was already revoked or does not exist.
"""
async with aiosqlite.connect(str(self._db_path)) as db:
cursor = await db.execute(
"UPDATE auth_sessions "
"SET revoked = 1, revoked_reason = ? "
"WHERE id = ? AND revoked = 0",
(reason, session_id),
)
await db.commit()
return cursor.rowcount > 0
async def revoke_by_refresh_token(
self,
refresh_token: str,
*,
reason: str = REVOKE_REASON_USER_TERMINATED,
) -> bool:
"""Revoke a session identified by its refresh token (logout)."""
info = await self.find_by_refresh_token(refresh_token)
if info is None or info.revoked:
return False
return await self.revoke(info.id, reason=reason)
async def revoke_all_for_user(
self, user_id: str, *, reason: str = REVOKE_REASON_USER_TERMINATED
) -> int:
"""Revoke all of a user's active sessions.
Used when:
- The user changes their password (invalidate all other devices).
- Token reuse is detected (invalidate everything as a precaution).
- An admin kills a user account.
"""
async with aiosqlite.connect(str(self._db_path)) as db:
cursor = await db.execute(
"UPDATE auth_sessions "
"SET revoked = 1, revoked_reason = ? "
"WHERE user_id = ? AND revoked = 0",
(reason, user_id),
)
await db.commit()
return cursor.rowcount
# ------------------------------------------------------------------
# Cap enforcement
# ------------------------------------------------------------------
async def _enforce_session_cap(self, user_id: str, *, keep_id: str) -> int:
"""Evict the oldest sessions when the user is at the cap.
Returns the number of sessions evicted (0 if under the cap).
The new session id (``keep_id``) is exempt from eviction.
"""
async with aiosqlite.connect(str(self._db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT id FROM auth_sessions "
"WHERE user_id = ? AND revoked = 0 AND id != ? "
"ORDER BY created_at DESC",
(user_id, keep_id),
)
rows = await cursor.fetchall()
if len(rows) < self._session_cap:
return 0
# Evict the oldest until we are under the cap (with the new row).
to_evict = rows[self._session_cap - 1 :]
if not to_evict:
return 0
ids = [r["id"] for r in to_evict]
placeholders = ",".join("?" for _ in ids)
async with aiosqlite.connect(str(self._db_path)) as db:
await db.execute(
f"UPDATE auth_sessions "
f"SET revoked = 1, revoked_reason = ? "
f"WHERE id IN ({placeholders})",
(REVOKE_REASON_SESSION_CAP_EVICTION, *ids),
)
await db.commit()
return len(ids)
# ------------------------------------------------------------------
# Cache invalidation
# ------------------------------------------------------------------
async def is_session_valid(self, session_id: str) -> bool:
"""Return ``True`` if the session exists, is not revoked, and not expired.
The middleware calls this on every request that carries a V2
JWT (sid claim). The result is cached in a 30-second
in-process LRU to keep the hot path cheap.
"""
info = await self.get(session_id)
if info is None or info.revoked:
return False
expires = datetime.fromisoformat(info.expires_at)
if expires.timestamp() <= datetime.now(timezone.utc).timestamp():
return False
return True
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _row_to_info(row: aiosqlite.Row) -> SessionInfo:
return SessionInfo(
id=row["id"],
user_id=row["user_id"],
device_fingerprint=row["device_fingerprint"],
device_label=row["device_label"],
ip=row["ip"],
user_agent=row["user_agent"],
auth_provider=row["auth_provider"],
created_at=row["created_at"],
last_active_at=row["last_active_at"],
expires_at=row["expires_at"],
revoked=bool(row["revoked"]),
revoked_reason=row["revoked_reason"],
previous_session_id=row["previous_session_id"],
)
# ---------------------------------------------------------------------------
# Module-level singleton (overridable in tests via set_session_service)
# ---------------------------------------------------------------------------
_session_service: SessionService | None = None
def get_session_service() -> SessionService:
"""Return the process-wide :class:`SessionService` (lazy singleton)."""
global _session_service
if _session_service is None:
_session_service = SessionService()
return _session_service
def set_session_service(service: SessionService | None) -> None:
"""Inject a custom :class:`SessionService` (used by tests)."""
global _session_service
_session_service = service

View File

@ -0,0 +1,256 @@
"""Unit tests for SessionService (U3)."""
from __future__ import annotations
import asyncio
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
import aiosqlite
import pytest
from agentkit.server.auth.denylist import InMemoryRecentlyRevoked
from agentkit.server.auth.models import init_auth_db
from agentkit.server.auth.session_service import (
REVOKE_REASON_PASSWORD_CHANGED,
REVOKE_REASON_REUSE_DETECTED,
REVOKE_REASON_SESSION_CAP_EVICTION,
REVOKE_REASON_USER_TERMINATED,
SessionCreate,
SessionNotFound,
SessionReuseDetected,
SessionService,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def auth_db(tmp_path: Path):
"""Initialise a fresh auth DB in a tmpdir; set AGENTKIT_AUTH_DB for the duration."""
db_path = tmp_path / "auth.db"
await init_auth_db(db_path)
prev = os.environ.get("AGENTKIT_AUTH_DB")
os.environ["AGENTKIT_AUTH_DB"] = str(db_path)
try:
yield db_path
finally:
if prev is None:
os.environ.pop("AGENTKIT_AUTH_DB", None)
else:
os.environ["AGENTKIT_AUTH_DB"] = prev
@pytest.fixture
async def user_id(auth_db: Path) -> str:
"""Insert a single user and return its id."""
user_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
async with aiosqlite.connect(str(auth_db)) as db:
await db.execute(
"INSERT INTO users (id, username, email, password_hash, role, "
"is_active, is_terminal_authorized, is_server_terminal_authorized, "
"created_at, updated_at, last_login_at, created_by) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
user_id,
"alice",
"alice@example.com",
"x",
"member",
1,
0,
0,
now,
now,
None,
None,
),
)
await db.commit()
return user_id
@pytest.fixture
async def svc(auth_db: Path) -> SessionService:
"""A SessionService backed by the in-memory denylist for determinism."""
return SessionService(
db_path=auth_db, denylist=InMemoryRecentlyRevoked(), session_cap=3
)
def _make_create(user_id: str, refresh_token: str = "r1") -> SessionCreate:
return SessionCreate(
user_id=user_id,
refresh_token=refresh_token,
device_fingerprint="fp",
device_label="Test device",
ip="127.0.0.1",
user_agent="pytest",
auth_provider="local",
ttl_seconds=3600,
)
# ---------------------------------------------------------------------------
# create
# ---------------------------------------------------------------------------
async def test_create_inserts_row_and_returns_info(svc: SessionService, user_id: str):
info = await svc.create(_make_create(user_id))
assert info.user_id == user_id
assert info.revoked is False
# Round-trip via the read API
fetched = await svc.get(info.id)
assert fetched is not None
assert fetched.id == info.id
async def test_create_evicts_oldest_when_at_cap(svc: SessionService, user_id: str):
# Cap is 3; insert 3, then a 4th should evict the first.
a = await svc.create(_make_create(user_id, "rt-a"))
b = await svc.create(_make_create(user_id, "rt-b"))
c = await svc.create(_make_create(user_id, "rt-c"))
d = await svc.create(_make_create(user_id, "rt-d"))
assert (await svc.get(a.id)).revoked is True
assert (await svc.get(a.id)).revoked_reason == REVOKE_REASON_SESSION_CAP_EVICTION
for kept in (b, c, d):
assert (await svc.get(kept.id)).revoked is False
# Sanity: exactly session_cap remain (this test uses cap=3)
remaining = [s for s in await svc.list_for_user(user_id) if not s.revoked]
assert len(remaining) == 3
# ---------------------------------------------------------------------------
# rotate
# ---------------------------------------------------------------------------
async def test_rotate_replaces_refresh_hash(svc: SessionService, user_id: str):
info = await svc.create(_make_create(user_id, "rt-old"))
new = await svc.rotate(
old_refresh_token="rt-old", new_refresh_token="rt-new", new_ttl_seconds=3600
)
assert new.id == info.id
# Old hash no longer resolves
assert await svc.find_by_refresh_token("rt-old") is None
# New hash does
assert (await svc.find_by_refresh_token("rt-new")).id == info.id
async def test_rotate_adds_old_to_denylist(svc: SessionService, user_id: str):
await svc.create(_make_create(user_id, "rt-1"))
await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)
assert svc._denylist.contains(__import__("hashlib").sha256(b"rt-1").hexdigest())
async def test_rotate_unknown_token_raises(svc: SessionService, user_id: str):
with pytest.raises(SessionNotFound):
await svc.rotate("never-issued", "rt-new", new_ttl_seconds=3600)
async def test_rotate_reuse_raises_and_revokes_all(svc: SessionService, user_id: str):
await svc.create(_make_create(user_id, "rt-1"))
await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)
# A second attempt with the old token: must raise and revoke all
with pytest.raises(SessionReuseDetected):
await svc.rotate("rt-1", "rt-2b", new_ttl_seconds=3600)
# All sessions for this user are now revoked
active = await svc.list_for_user(user_id, include_revoked=True)
assert all(s.revoked for s in active)
assert all(s.revoked_reason == REVOKE_REASON_REUSE_DETECTED for s in active)
async def test_rotate_revoked_session_raises(svc: SessionService, user_id: str):
await svc.create(_make_create(user_id, "rt-1"))
await svc.revoke_by_refresh_token("rt-1")
with pytest.raises(SessionNotFound):
await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)
# ---------------------------------------------------------------------------
# revoke
# ---------------------------------------------------------------------------
async def test_revoke_marks_session_revoked(svc: SessionService, user_id: str):
info = await svc.create(_make_create(user_id))
ok = await svc.revoke(info.id)
assert ok is True
assert (await svc.get(info.id)).revoked is True
assert (await svc.get(info.id)).revoked_reason == REVOKE_REASON_USER_TERMINATED
async def test_revoke_returns_false_if_already_revoked(svc: SessionService, user_id: str):
info = await svc.create(_make_create(user_id))
assert await svc.revoke(info.id) is True
assert await svc.revoke(info.id) is False
async def test_revoke_by_refresh_token(svc: SessionService, user_id: str):
await svc.create(_make_create(user_id, "rt-1"))
assert await svc.revoke_by_refresh_token("rt-1") is True
assert await svc.revoke_by_refresh_token("rt-1") is False
assert await svc.revoke_by_refresh_token("rt-never") is False
async def test_revoke_all_for_user(svc: SessionService, user_id: str):
await svc.create(_make_create(user_id, "a"))
await svc.create(_make_create(user_id, "b"))
n = await svc.revoke_all_for_user(
user_id, reason=REVOKE_REASON_PASSWORD_CHANGED
)
assert n == 2
active = await svc.list_for_user(user_id)
assert active == []
# ---------------------------------------------------------------------------
# list / get
# ---------------------------------------------------------------------------
async def test_list_for_user_excludes_revoked_by_default(svc: SessionService, user_id: str):
a = await svc.create(_make_create(user_id, "a"))
b = await svc.create(_make_create(user_id, "b"))
await svc.revoke(a.id)
visible = await svc.list_for_user(user_id)
assert [s.id for s in visible] == [b.id]
visible_all = await svc.list_for_user(user_id, include_revoked=True)
assert {s.id for s in visible_all} == {a.id, b.id}
async def test_is_session_valid_rejects_revoked(svc: SessionService, user_id: str):
info = await svc.create(_make_create(user_id))
assert await svc.is_session_valid(info.id) is True
await svc.revoke(info.id)
assert await svc.is_session_valid(info.id) is False
async def test_is_session_valid_rejects_expired(svc: SessionService, user_id: str):
# Create with a very short TTL
create = SessionCreate(
user_id=user_id,
refresh_token="rt-x",
device_fingerprint="fp",
device_label="d",
ip="",
user_agent="",
auth_provider="local",
ttl_seconds=0, # expires immediately
)
info = await svc.create(create)
# Re-check with a future expires_at: should be expired
assert await svc.is_session_valid(info.id) is False
async def test_is_session_valid_returns_false_for_unknown(svc: SessionService):
assert await svc.is_session_valid("nonexistent-id") is False