From b418c3dc95e41bff64b1a1c94eb7b9b3bc52b702 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Sun, 21 Jun 2026 01:58:30 +0800 Subject: [PATCH] feat(auth): U3 SessionService + validation cache Adds the central business-logic layer for ``auth_sessions`` so routes, the auth middleware, and the admin endpoints can call a single service instead of touching the table directly. Server - session_service.SessionService: CRUD + lifecycle for auth_sessions. - create() enforces the per-user cap (default 10): the oldest active session is evicted with reason=session_cap_eviction. - rotate() swaps a refresh token, adds the old hash to the denylist, and raises SessionReuseDetected (revoking all sessions for the user) when the old token is replayed. - revoke() / revoke_by_refresh_token() / revoke_all_for_user() with explicit reasons: user_terminated, admin_revoked, password_changed, reuse_detected, session_cap_eviction. - touch() bumps last_active_at (called on /auth/whoami). - session_cache.SessionValidationCache: bounded LRU+TTL wrapper (default 30s/1k entries) around SessionService.is_session_valid. The middleware hits this on every request carrying a V2 sid claim; one SQLite round-trip per 30s per session instead of per request. - get_session_service() / get_validation_cache() module-level singletons overridable in tests via set_session_service() / set_validation_cache(). Tests - tests/unit/auth/test_session_service.py: 15 cases covering create/rotate/revoke/list/cap-eviction/reuse-detection/expired sessions. Refs: U3 in docs/plans/2026-06-20-002-feat-centralized-auth-token-persistence-plan.md --- src/agentkit/server/auth/session_cache.py | 116 +++++ src/agentkit/server/auth/session_service.py | 509 ++++++++++++++++++++ tests/unit/auth/test_session_service.py | 256 ++++++++++ 3 files changed, 881 insertions(+) create mode 100644 src/agentkit/server/auth/session_cache.py create mode 100644 src/agentkit/server/auth/session_service.py create mode 100644 tests/unit/auth/test_session_service.py diff --git a/src/agentkit/server/auth/session_cache.py b/src/agentkit/server/auth/session_cache.py new file mode 100644 index 0000000..a69bfe7 --- /dev/null +++ b/src/agentkit/server/auth/session_cache.py @@ -0,0 +1,116 @@ +"""In-process LRU cache for ``SessionService.is_session_valid`` (U3). + +The middleware calls :func:`is_session_valid` on every request carrying +a V2 JWT (sid claim). Without a cache, that adds one SQLite round-trip +per request. The cache keeps the result for ``ttl_seconds`` (default +30s) — long enough to absorb a burst, short enough that a server-side +revocation propagates quickly enough for the user to notice. + +This is a *negative-caching-friendly* cache: a "valid" answer is +cached for 30s, a "revoked/not found" answer is also cached for 30s. +A revocation therefore takes up to 30s to be observed by the +middleware — acceptable trade-off, since the user is the one who +revoked it (not an attacker). + +For a tighter revocation window, the route layer can be configured to +bypass the cache on the ``/auth/logout`` path; see the admin "force +logout" flow. +""" + +from __future__ import annotations + +import time +from collections import OrderedDict +from typing import Any + +from .session_service import SessionService + +DEFAULT_TTL_SECONDS = 30 +DEFAULT_MAX_ENTRIES = 1_000 + + +class SessionValidationCache: + """Tiny LRU+TTL wrapper around :class:`SessionService`. + + The cache is keyed by session id. The value is a tuple of + ``(valid, expires_at_monotonic)``. + + The cache is intentionally process-local; multi-process deployments + get per-worker caching (still beneficial) and rely on the eventual + expiry for cross-worker consistency. + """ + + def __init__( + self, + service: SessionService, + *, + ttl_seconds: int = DEFAULT_TTL_SECONDS, + max_entries: int = DEFAULT_MAX_ENTRIES, + ) -> None: + self._service = service + self._ttl = ttl_seconds + self._max = max_entries + self._entries: "OrderedDict[str, tuple[bool, float]]" = OrderedDict() + + async def is_valid(self, session_id: str) -> bool: + """Return whether the session is currently valid, using the cache.""" + cached = self._entries.get(session_id) + if cached is not None: + valid, expires_at = cached + if expires_at > time.monotonic(): + # LRU touch + self._entries.move_to_end(session_id) + return valid + # Expired — drop and re-check. + self._entries.pop(session_id, None) + + valid = await self._service.is_session_valid(session_id) + self._entries[session_id] = (valid, time.monotonic() + self._ttl) + self._entries.move_to_end(session_id) + while len(self._entries) > self._max: + self._entries.popitem(last=False) + return valid + + def invalidate(self, session_id: str) -> None: + """Drop a single entry (used by /auth/logout).""" + self._entries.pop(session_id, None) + + def clear(self) -> None: + """Drop all entries (used in tests).""" + self._entries.clear() + + def __len__(self) -> int: # pragma: no cover — debug helper + return len(self._entries) + + +_cache: SessionValidationCache | None = None + + +def get_validation_cache() -> SessionValidationCache | None: + """Return the process-wide cache, or ``None`` if not initialised.""" + return _cache + + +def init_validation_cache( + service: SessionService, + *, + ttl_seconds: int = DEFAULT_TTL_SECONDS, + max_entries: int = DEFAULT_MAX_ENTRIES, +) -> SessionValidationCache: + """Initialise (or replace) the process-wide cache and return it. + + Call this from the app factory once the :class:`SessionService` + singleton is ready. Returns the new cache so callers can hold + a reference for tests. + """ + global _cache + _cache = SessionValidationCache( + service, ttl_seconds=ttl_seconds, max_entries=max_entries + ) + return _cache + + +def set_validation_cache(cache: SessionValidationCache | None) -> None: + """Inject a custom cache (used by tests).""" + global _cache + _cache = cache diff --git a/src/agentkit/server/auth/session_service.py b/src/agentkit/server/auth/session_service.py new file mode 100644 index 0000000..c76ad3e --- /dev/null +++ b/src/agentkit/server/auth/session_service.py @@ -0,0 +1,509 @@ +"""SessionService — central business logic for ``auth_sessions`` (U3). + +This module is the single owner of the ``auth_sessions`` table. Routes, +the auth middleware, and the admin endpoints all call into +:class:`SessionService` rather than touching the table directly, which +keeps the rotation / reuse-detection / cap-eviction rules in one place. + +Lifecycle +--------- +:: + + create ─► rotate ─► rotate ─► ... + │ │ │ + ▼ ▼ ▼ + active active active + │ + ├── revoke (user logout) → revoked + ├── revoke (admin kill) → revoked + ├── reuse_detected → revoked + revoke_all_for_user + ├── password_change → revoke_all_for_user + └── session_cap_eviction → revoked (oldest) +""" + +from __future__ import annotations + +import logging +import os +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import aiosqlite + +from .denylist import ( + DEFAULT_TTL_SECONDS, + InMemoryRecentlyRevoked, + RecentlyRevoked, + hash_token, +) +from .models import DEFAULT_AUTH_DB_PATH + +logger = logging.getLogger(__name__) + + +def _now_iso() -> str: + """Return current UTC time as ISO 8601 string.""" + return datetime.now(timezone.utc).isoformat() + + +def _resolve_db_path() -> Path: + """Resolve the auth DB path with runtime env-var priority (test-friendly).""" + env = os.environ.get("AGENTKIT_AUTH_DB") + if env: + return Path(env) + return DEFAULT_AUTH_DB_PATH + + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + + +# Default per-user session cap. Matches the plan's "10 sessions" cap; +# new logins evict the oldest when the user is at the limit. +DEFAULT_SESSION_CAP = 10 + +# Reasons written to ``auth_sessions.revoked_reason``. +REVOKE_REASON_USER_TERMINATED = "user_terminated" +REVOKE_REASON_ADMIN_REVOKED = "admin_revoked" +REVOKE_REASON_PASSWORD_CHANGED = "password_changed" +REVOKE_REASON_REUSE_DETECTED = "reuse_detected" +REVOKE_REASON_SESSION_CAP_EVICTION = "session_cap_eviction" + + +class SessionNotFound(Exception): + """The session id does not exist or is not owned by the user.""" + + +class SessionReuseDetected(Exception): + """A refresh token was reused within the denylist window. + + Raised by :meth:`SessionService.rotate` when the provided refresh + token is in the denylist. The caller should treat this as a security + event and call :meth:`revoke_all_for_user`. + """ + + +@dataclass(frozen=True) +class SessionCreate: + """Inputs for :meth:`SessionService.create`.""" + + user_id: str + refresh_token: str + device_fingerprint: str + device_label: str + ip: str + user_agent: str + auth_provider: str + ttl_seconds: int + previous_session_id: str | None = None + + +@dataclass(frozen=True) +class SessionInfo: + """Public projection of an ``auth_sessions`` row.""" + + id: str + user_id: str + device_fingerprint: str + device_label: str + ip: str + user_agent: str + auth_provider: str + created_at: str + last_active_at: str + expires_at: str + revoked: bool + revoked_reason: str | None + previous_session_id: str | None + + +# --------------------------------------------------------------------------- +# SessionService +# --------------------------------------------------------------------------- + + +class SessionService: + """CRUD + lifecycle for ``auth_sessions``. + + Args: + db_path: Path to the auth SQLite DB. Defaults to the env-var + ``AGENTKIT_AUTH_DB`` or :data:`models.DEFAULT_AUTH_DB_PATH`. + denylist: Recently-revoked refresh-token store. Defaults to an + in-memory LRU; tests can pass their own instance for + deterministic window control. + session_cap: Maximum number of active sessions per user. New + logins evict the oldest when the user is at the limit. + """ + + def __init__( + self, + db_path: str | Path | None = None, + *, + denylist: RecentlyRevoked | None = None, + session_cap: int = DEFAULT_SESSION_CAP, + ) -> None: + self._db_path = Path(db_path) if db_path is not None else _resolve_db_path() + self._denylist = denylist or InMemoryRecentlyRevoked() + self._session_cap = session_cap + # Parallel to ``self._denylist``: maps an old (rotated) refresh + # token hash to the user_id it belonged to, so reuse detection + # can find the right user to revoke (the row has already been + # updated to the new hash, so the table itself can't tell us). + self._recent_users: dict[str, str] = {} + + # ------------------------------------------------------------------ + # Read + # ------------------------------------------------------------------ + + async def get(self, session_id: str) -> SessionInfo | None: + """Look up a session by id. Returns ``None`` if not found.""" + async with aiosqlite.connect(str(self._db_path)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT * FROM auth_sessions WHERE id = ?", + (session_id,), + ) + row = await cursor.fetchone() + return _row_to_info(row) if row else None + + async def list_for_user( + self, user_id: str, *, include_revoked: bool = False + ) -> list[SessionInfo]: + """List sessions for ``user_id``, newest first. + + By default only active (non-revoked) sessions are returned. The + admin "list all sessions" view passes ``include_revoked=True``. + """ + sql = "SELECT * FROM auth_sessions WHERE user_id = ?" + args: tuple[Any, ...] = (user_id,) + if not include_revoked: + sql += " AND revoked = 0" + sql += " ORDER BY created_at DESC" + async with aiosqlite.connect(str(self._db_path)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute(sql, args) + rows = await cursor.fetchall() + return [_row_to_info(r) for r in rows] + + async def list_all( + self, *, include_revoked: bool = False, limit: int = 200 + ) -> list[SessionInfo]: + """List recent sessions across all users (admin view).""" + sql = "SELECT * FROM auth_sessions" + if not include_revoked: + sql += " WHERE revoked = 0" + sql += " ORDER BY created_at DESC LIMIT ?" + async with aiosqlite.connect(str(self._db_path)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute(sql, (limit,)) + rows = await cursor.fetchall() + return [_row_to_info(r) for r in rows] + + async def find_by_refresh_token(self, refresh_token: str) -> SessionInfo | None: + """Look up a session by the SHA-256 hash of its refresh token.""" + h = hash_token(refresh_token) + async with aiosqlite.connect(str(self._db_path)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT * FROM auth_sessions WHERE refresh_token_hash = ?", + (h,), + ) + row = await cursor.fetchone() + return _row_to_info(row) if row else None + + # ------------------------------------------------------------------ + # Write + # ------------------------------------------------------------------ + + async def create(self, payload: SessionCreate) -> SessionInfo: + """Insert a new session row. + + If the user is already at the :attr:`session_cap`, the oldest + active session is evicted first. This is the only place where + the cap is enforced — callers don't have to think about it. + """ + session_id = str(uuid.uuid4()) + now = _now_iso() + expires = ( + datetime.now(timezone.utc).timestamp() + payload.ttl_seconds + ) + expires_iso = ( + datetime.fromtimestamp(expires, tz=timezone.utc).isoformat() + ) + refresh_hash = hash_token(payload.refresh_token) + + await self._enforce_session_cap(payload.user_id, keep_id=session_id) + + async with aiosqlite.connect(str(self._db_path)) as db: + await db.execute( + "INSERT INTO auth_sessions " + "(id, user_id, refresh_token_hash, device_fingerprint, device_label, " + " ip, user_agent, auth_provider, created_at, last_active_at, expires_at, " + " revoked, revoked_reason, previous_session_id) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + session_id, + payload.user_id, + refresh_hash, + payload.device_fingerprint, + payload.device_label, + payload.ip, + payload.user_agent, + payload.auth_provider, + now, + now, + expires_iso, + 0, + None, + payload.previous_session_id, + ), + ) + await db.commit() + + return SessionInfo( + id=session_id, + user_id=payload.user_id, + device_fingerprint=payload.device_fingerprint, + device_label=payload.device_label, + ip=payload.ip, + user_agent=payload.user_agent, + auth_provider=payload.auth_provider, + created_at=now, + last_active_at=now, + expires_at=expires_iso, + revoked=False, + revoked_reason=None, + previous_session_id=payload.previous_session_id, + ) + + async def rotate( + self, + old_refresh_token: str, + new_refresh_token: str, + *, + new_ttl_seconds: int, + ) -> SessionInfo: + """Rotate a refresh token: replace the old hash with a new one. + + Adds the old token's hash to the denylist so a concurrent retry + using the old token is detected. + + Raises: + SessionNotFound: the old token has no matching session. + SessionReuseDetected: the old token is in the denylist. + The caller should call :meth:`revoke_all_for_user`. + """ + old_hash = hash_token(old_refresh_token) + if self._denylist.contains(old_hash): + # Concurrent retry — revoke everything for this user. + user_id = self._recent_users.pop(old_hash, None) + if user_id is not None: + await self.revoke_all_for_user( + user_id, reason=REVOKE_REASON_REUSE_DETECTED + ) + raise SessionReuseDetected("refresh token reuse detected") + + info = await self.find_by_refresh_token(old_refresh_token) + if info is None: + raise SessionNotFound("refresh token has no matching session") + if info.revoked: + raise SessionNotFound("refresh token has been revoked") + + new_hash = hash_token(new_refresh_token) + now = _now_iso() + new_expires_iso = ( + datetime.fromtimestamp( + datetime.now(timezone.utc).timestamp() + new_ttl_seconds, + tz=timezone.utc, + ).isoformat() + ) + + async with aiosqlite.connect(str(self._db_path)) as db: + await db.execute( + "UPDATE auth_sessions " + "SET refresh_token_hash = ?, last_active_at = ?, expires_at = ? " + "WHERE id = ?", + (new_hash, now, new_expires_iso, info.id), + ) + await db.commit() + + # Add the old hash to the denylist AFTER the row is updated so + # a fast retry sees the new row rather than triggering reuse + # detection on a value that has just been rotated. + self._denylist.add(old_hash, ttl_seconds=DEFAULT_TTL_SECONDS) + # Track which user the old token belonged to so reuse detection + # can locate the right user to revoke (the row is already + # pointing at the new hash). + self._recent_users[old_hash] = info.user_id + + refreshed = await self.get(info.id) + assert refreshed is not None # we just wrote it + return refreshed + + async def touch(self, session_id: str) -> None: + """Bump ``last_active_at`` (called on /auth/whoami success).""" + async with aiosqlite.connect(str(self._db_path)) as db: + await db.execute( + "UPDATE auth_sessions SET last_active_at = ? WHERE id = ?", + (_now_iso(), session_id), + ) + await db.commit() + + # ------------------------------------------------------------------ + # Revoke + # ------------------------------------------------------------------ + + async def revoke( + self, session_id: str, *, reason: str = REVOKE_REASON_USER_TERMINATED + ) -> bool: + """Revoke a single session. + + Returns ``True`` if a row was updated, ``False`` if the session + was already revoked or does not exist. + """ + async with aiosqlite.connect(str(self._db_path)) as db: + cursor = await db.execute( + "UPDATE auth_sessions " + "SET revoked = 1, revoked_reason = ? " + "WHERE id = ? AND revoked = 0", + (reason, session_id), + ) + await db.commit() + return cursor.rowcount > 0 + + async def revoke_by_refresh_token( + self, + refresh_token: str, + *, + reason: str = REVOKE_REASON_USER_TERMINATED, + ) -> bool: + """Revoke a session identified by its refresh token (logout).""" + info = await self.find_by_refresh_token(refresh_token) + if info is None or info.revoked: + return False + return await self.revoke(info.id, reason=reason) + + async def revoke_all_for_user( + self, user_id: str, *, reason: str = REVOKE_REASON_USER_TERMINATED + ) -> int: + """Revoke all of a user's active sessions. + + Used when: + - The user changes their password (invalidate all other devices). + - Token reuse is detected (invalidate everything as a precaution). + - An admin kills a user account. + """ + async with aiosqlite.connect(str(self._db_path)) as db: + cursor = await db.execute( + "UPDATE auth_sessions " + "SET revoked = 1, revoked_reason = ? " + "WHERE user_id = ? AND revoked = 0", + (reason, user_id), + ) + await db.commit() + return cursor.rowcount + + # ------------------------------------------------------------------ + # Cap enforcement + # ------------------------------------------------------------------ + + async def _enforce_session_cap(self, user_id: str, *, keep_id: str) -> int: + """Evict the oldest sessions when the user is at the cap. + + Returns the number of sessions evicted (0 if under the cap). + The new session id (``keep_id``) is exempt from eviction. + """ + async with aiosqlite.connect(str(self._db_path)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT id FROM auth_sessions " + "WHERE user_id = ? AND revoked = 0 AND id != ? " + "ORDER BY created_at DESC", + (user_id, keep_id), + ) + rows = await cursor.fetchall() + if len(rows) < self._session_cap: + return 0 + # Evict the oldest until we are under the cap (with the new row). + to_evict = rows[self._session_cap - 1 :] + if not to_evict: + return 0 + ids = [r["id"] for r in to_evict] + placeholders = ",".join("?" for _ in ids) + async with aiosqlite.connect(str(self._db_path)) as db: + await db.execute( + f"UPDATE auth_sessions " + f"SET revoked = 1, revoked_reason = ? " + f"WHERE id IN ({placeholders})", + (REVOKE_REASON_SESSION_CAP_EVICTION, *ids), + ) + await db.commit() + return len(ids) + + # ------------------------------------------------------------------ + # Cache invalidation + # ------------------------------------------------------------------ + + async def is_session_valid(self, session_id: str) -> bool: + """Return ``True`` if the session exists, is not revoked, and not expired. + + The middleware calls this on every request that carries a V2 + JWT (sid claim). The result is cached in a 30-second + in-process LRU to keep the hot path cheap. + """ + info = await self.get(session_id) + if info is None or info.revoked: + return False + expires = datetime.fromisoformat(info.expires_at) + if expires.timestamp() <= datetime.now(timezone.utc).timestamp(): + return False + return True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _row_to_info(row: aiosqlite.Row) -> SessionInfo: + return SessionInfo( + id=row["id"], + user_id=row["user_id"], + device_fingerprint=row["device_fingerprint"], + device_label=row["device_label"], + ip=row["ip"], + user_agent=row["user_agent"], + auth_provider=row["auth_provider"], + created_at=row["created_at"], + last_active_at=row["last_active_at"], + expires_at=row["expires_at"], + revoked=bool(row["revoked"]), + revoked_reason=row["revoked_reason"], + previous_session_id=row["previous_session_id"], + ) + + +# --------------------------------------------------------------------------- +# Module-level singleton (overridable in tests via set_session_service) +# --------------------------------------------------------------------------- + + +_session_service: SessionService | None = None + + +def get_session_service() -> SessionService: + """Return the process-wide :class:`SessionService` (lazy singleton).""" + global _session_service + if _session_service is None: + _session_service = SessionService() + return _session_service + + +def set_session_service(service: SessionService | None) -> None: + """Inject a custom :class:`SessionService` (used by tests).""" + global _session_service + _session_service = service diff --git a/tests/unit/auth/test_session_service.py b/tests/unit/auth/test_session_service.py new file mode 100644 index 0000000..de679c7 --- /dev/null +++ b/tests/unit/auth/test_session_service.py @@ -0,0 +1,256 @@ +"""Unit tests for SessionService (U3).""" + +from __future__ import annotations + +import asyncio +import os +import uuid +from datetime import datetime, timezone +from pathlib import Path + +import aiosqlite +import pytest + +from agentkit.server.auth.denylist import InMemoryRecentlyRevoked +from agentkit.server.auth.models import init_auth_db +from agentkit.server.auth.session_service import ( + REVOKE_REASON_PASSWORD_CHANGED, + REVOKE_REASON_REUSE_DETECTED, + REVOKE_REASON_SESSION_CAP_EVICTION, + REVOKE_REASON_USER_TERMINATED, + SessionCreate, + SessionNotFound, + SessionReuseDetected, + SessionService, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +async def auth_db(tmp_path: Path): + """Initialise a fresh auth DB in a tmpdir; set AGENTKIT_AUTH_DB for the duration.""" + db_path = tmp_path / "auth.db" + await init_auth_db(db_path) + prev = os.environ.get("AGENTKIT_AUTH_DB") + os.environ["AGENTKIT_AUTH_DB"] = str(db_path) + try: + yield db_path + finally: + if prev is None: + os.environ.pop("AGENTKIT_AUTH_DB", None) + else: + os.environ["AGENTKIT_AUTH_DB"] = prev + + +@pytest.fixture +async def user_id(auth_db: Path) -> str: + """Insert a single user and return its id.""" + user_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc).isoformat() + async with aiosqlite.connect(str(auth_db)) as db: + await db.execute( + "INSERT INTO users (id, username, email, password_hash, role, " + "is_active, is_terminal_authorized, is_server_terminal_authorized, " + "created_at, updated_at, last_login_at, created_by) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + user_id, + "alice", + "alice@example.com", + "x", + "member", + 1, + 0, + 0, + now, + now, + None, + None, + ), + ) + await db.commit() + return user_id + + +@pytest.fixture +async def svc(auth_db: Path) -> SessionService: + """A SessionService backed by the in-memory denylist for determinism.""" + return SessionService( + db_path=auth_db, denylist=InMemoryRecentlyRevoked(), session_cap=3 + ) + + +def _make_create(user_id: str, refresh_token: str = "r1") -> SessionCreate: + return SessionCreate( + user_id=user_id, + refresh_token=refresh_token, + device_fingerprint="fp", + device_label="Test device", + ip="127.0.0.1", + user_agent="pytest", + auth_provider="local", + ttl_seconds=3600, + ) + + +# --------------------------------------------------------------------------- +# create +# --------------------------------------------------------------------------- + + +async def test_create_inserts_row_and_returns_info(svc: SessionService, user_id: str): + info = await svc.create(_make_create(user_id)) + assert info.user_id == user_id + assert info.revoked is False + # Round-trip via the read API + fetched = await svc.get(info.id) + assert fetched is not None + assert fetched.id == info.id + + +async def test_create_evicts_oldest_when_at_cap(svc: SessionService, user_id: str): + # Cap is 3; insert 3, then a 4th should evict the first. + a = await svc.create(_make_create(user_id, "rt-a")) + b = await svc.create(_make_create(user_id, "rt-b")) + c = await svc.create(_make_create(user_id, "rt-c")) + d = await svc.create(_make_create(user_id, "rt-d")) + + assert (await svc.get(a.id)).revoked is True + assert (await svc.get(a.id)).revoked_reason == REVOKE_REASON_SESSION_CAP_EVICTION + for kept in (b, c, d): + assert (await svc.get(kept.id)).revoked is False + + # Sanity: exactly session_cap remain (this test uses cap=3) + remaining = [s for s in await svc.list_for_user(user_id) if not s.revoked] + assert len(remaining) == 3 + + +# --------------------------------------------------------------------------- +# rotate +# --------------------------------------------------------------------------- + + +async def test_rotate_replaces_refresh_hash(svc: SessionService, user_id: str): + info = await svc.create(_make_create(user_id, "rt-old")) + new = await svc.rotate( + old_refresh_token="rt-old", new_refresh_token="rt-new", new_ttl_seconds=3600 + ) + assert new.id == info.id + # Old hash no longer resolves + assert await svc.find_by_refresh_token("rt-old") is None + # New hash does + assert (await svc.find_by_refresh_token("rt-new")).id == info.id + + +async def test_rotate_adds_old_to_denylist(svc: SessionService, user_id: str): + await svc.create(_make_create(user_id, "rt-1")) + await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600) + assert svc._denylist.contains(__import__("hashlib").sha256(b"rt-1").hexdigest()) + + +async def test_rotate_unknown_token_raises(svc: SessionService, user_id: str): + with pytest.raises(SessionNotFound): + await svc.rotate("never-issued", "rt-new", new_ttl_seconds=3600) + + +async def test_rotate_reuse_raises_and_revokes_all(svc: SessionService, user_id: str): + await svc.create(_make_create(user_id, "rt-1")) + await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600) + # A second attempt with the old token: must raise and revoke all + with pytest.raises(SessionReuseDetected): + await svc.rotate("rt-1", "rt-2b", new_ttl_seconds=3600) + # All sessions for this user are now revoked + active = await svc.list_for_user(user_id, include_revoked=True) + assert all(s.revoked for s in active) + assert all(s.revoked_reason == REVOKE_REASON_REUSE_DETECTED for s in active) + + +async def test_rotate_revoked_session_raises(svc: SessionService, user_id: str): + await svc.create(_make_create(user_id, "rt-1")) + await svc.revoke_by_refresh_token("rt-1") + with pytest.raises(SessionNotFound): + await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600) + + +# --------------------------------------------------------------------------- +# revoke +# --------------------------------------------------------------------------- + + +async def test_revoke_marks_session_revoked(svc: SessionService, user_id: str): + info = await svc.create(_make_create(user_id)) + ok = await svc.revoke(info.id) + assert ok is True + assert (await svc.get(info.id)).revoked is True + assert (await svc.get(info.id)).revoked_reason == REVOKE_REASON_USER_TERMINATED + + +async def test_revoke_returns_false_if_already_revoked(svc: SessionService, user_id: str): + info = await svc.create(_make_create(user_id)) + assert await svc.revoke(info.id) is True + assert await svc.revoke(info.id) is False + + +async def test_revoke_by_refresh_token(svc: SessionService, user_id: str): + await svc.create(_make_create(user_id, "rt-1")) + assert await svc.revoke_by_refresh_token("rt-1") is True + assert await svc.revoke_by_refresh_token("rt-1") is False + assert await svc.revoke_by_refresh_token("rt-never") is False + + +async def test_revoke_all_for_user(svc: SessionService, user_id: str): + await svc.create(_make_create(user_id, "a")) + await svc.create(_make_create(user_id, "b")) + n = await svc.revoke_all_for_user( + user_id, reason=REVOKE_REASON_PASSWORD_CHANGED + ) + assert n == 2 + active = await svc.list_for_user(user_id) + assert active == [] + + +# --------------------------------------------------------------------------- +# list / get +# --------------------------------------------------------------------------- + + +async def test_list_for_user_excludes_revoked_by_default(svc: SessionService, user_id: str): + a = await svc.create(_make_create(user_id, "a")) + b = await svc.create(_make_create(user_id, "b")) + await svc.revoke(a.id) + visible = await svc.list_for_user(user_id) + assert [s.id for s in visible] == [b.id] + visible_all = await svc.list_for_user(user_id, include_revoked=True) + assert {s.id for s in visible_all} == {a.id, b.id} + + +async def test_is_session_valid_rejects_revoked(svc: SessionService, user_id: str): + info = await svc.create(_make_create(user_id)) + assert await svc.is_session_valid(info.id) is True + await svc.revoke(info.id) + assert await svc.is_session_valid(info.id) is False + + +async def test_is_session_valid_rejects_expired(svc: SessionService, user_id: str): + # Create with a very short TTL + create = SessionCreate( + user_id=user_id, + refresh_token="rt-x", + device_fingerprint="fp", + device_label="d", + ip="", + user_agent="", + auth_provider="local", + ttl_seconds=0, # expires immediately + ) + info = await svc.create(create) + # Re-check with a future expires_at: should be expired + assert await svc.is_session_valid(info.id) is False + + +async def test_is_session_valid_returns_false_for_unknown(svc: SessionService): + assert await svc.is_session_valid("nonexistent-id") is False