feat(auth): U3 SessionService + validation cache
Adds the central business-logic layer for ``auth_sessions`` so routes,
the auth middleware, and the admin endpoints can call a single service
instead of touching the table directly.
Server
- session_service.SessionService: CRUD + lifecycle for auth_sessions.
- create() enforces the per-user cap (default 10): the oldest
active session is evicted with reason=session_cap_eviction.
- rotate() swaps a refresh token, adds the old hash to the
denylist, and raises SessionReuseDetected (revoking all sessions
for the user) when the old token is replayed.
- revoke() / revoke_by_refresh_token() / revoke_all_for_user()
with explicit reasons: user_terminated, admin_revoked,
password_changed, reuse_detected, session_cap_eviction.
- touch() bumps last_active_at (called on /auth/whoami).
- session_cache.SessionValidationCache: bounded LRU+TTL wrapper
(default 30s/1k entries) around SessionService.is_session_valid.
The middleware hits this on every request carrying a V2 sid claim;
one SQLite round-trip per 30s per session instead of per request.
- get_session_service() / get_validation_cache() module-level
singletons overridable in tests via set_session_service() /
set_validation_cache().
Tests
- tests/unit/auth/test_session_service.py: 15 cases covering
create/rotate/revoke/list/cap-eviction/reuse-detection/expired
sessions.
Refs: U3 in docs/plans/2026-06-20-002-feat-centralized-auth-token-persistence-plan.md
This commit is contained in:
parent
5ba1aceb96
commit
b418c3dc95
|
|
@ -0,0 +1,116 @@
|
||||||
|
"""In-process LRU cache for ``SessionService.is_session_valid`` (U3).
|
||||||
|
|
||||||
|
The middleware calls :func:`is_session_valid` on every request carrying
|
||||||
|
a V2 JWT (sid claim). Without a cache, that adds one SQLite round-trip
|
||||||
|
per request. The cache keeps the result for ``ttl_seconds`` (default
|
||||||
|
30s) — long enough to absorb a burst, short enough that a server-side
|
||||||
|
revocation propagates quickly enough for the user to notice.
|
||||||
|
|
||||||
|
This is a *negative-caching-friendly* cache: a "valid" answer is
|
||||||
|
cached for 30s, a "revoked/not found" answer is also cached for 30s.
|
||||||
|
A revocation therefore takes up to 30s to be observed by the
|
||||||
|
middleware — acceptable trade-off, since the user is the one who
|
||||||
|
revoked it (not an attacker).
|
||||||
|
|
||||||
|
For a tighter revocation window, the route layer can be configured to
|
||||||
|
bypass the cache on the ``/auth/logout`` path; see the admin "force
|
||||||
|
logout" flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .session_service import SessionService
|
||||||
|
|
||||||
|
DEFAULT_TTL_SECONDS = 30
|
||||||
|
DEFAULT_MAX_ENTRIES = 1_000
|
||||||
|
|
||||||
|
|
||||||
|
class SessionValidationCache:
|
||||||
|
"""Tiny LRU+TTL wrapper around :class:`SessionService`.
|
||||||
|
|
||||||
|
The cache is keyed by session id. The value is a tuple of
|
||||||
|
``(valid, expires_at_monotonic)``.
|
||||||
|
|
||||||
|
The cache is intentionally process-local; multi-process deployments
|
||||||
|
get per-worker caching (still beneficial) and rely on the eventual
|
||||||
|
expiry for cross-worker consistency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
service: SessionService,
|
||||||
|
*,
|
||||||
|
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||||
|
max_entries: int = DEFAULT_MAX_ENTRIES,
|
||||||
|
) -> None:
|
||||||
|
self._service = service
|
||||||
|
self._ttl = ttl_seconds
|
||||||
|
self._max = max_entries
|
||||||
|
self._entries: "OrderedDict[str, tuple[bool, float]]" = OrderedDict()
|
||||||
|
|
||||||
|
async def is_valid(self, session_id: str) -> bool:
|
||||||
|
"""Return whether the session is currently valid, using the cache."""
|
||||||
|
cached = self._entries.get(session_id)
|
||||||
|
if cached is not None:
|
||||||
|
valid, expires_at = cached
|
||||||
|
if expires_at > time.monotonic():
|
||||||
|
# LRU touch
|
||||||
|
self._entries.move_to_end(session_id)
|
||||||
|
return valid
|
||||||
|
# Expired — drop and re-check.
|
||||||
|
self._entries.pop(session_id, None)
|
||||||
|
|
||||||
|
valid = await self._service.is_session_valid(session_id)
|
||||||
|
self._entries[session_id] = (valid, time.monotonic() + self._ttl)
|
||||||
|
self._entries.move_to_end(session_id)
|
||||||
|
while len(self._entries) > self._max:
|
||||||
|
self._entries.popitem(last=False)
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def invalidate(self, session_id: str) -> None:
|
||||||
|
"""Drop a single entry (used by /auth/logout)."""
|
||||||
|
self._entries.pop(session_id, None)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Drop all entries (used in tests)."""
|
||||||
|
self._entries.clear()
|
||||||
|
|
||||||
|
def __len__(self) -> int: # pragma: no cover — debug helper
|
||||||
|
return len(self._entries)
|
||||||
|
|
||||||
|
|
||||||
|
_cache: SessionValidationCache | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_validation_cache() -> SessionValidationCache | None:
|
||||||
|
"""Return the process-wide cache, or ``None`` if not initialised."""
|
||||||
|
return _cache
|
||||||
|
|
||||||
|
|
||||||
|
def init_validation_cache(
|
||||||
|
service: SessionService,
|
||||||
|
*,
|
||||||
|
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||||
|
max_entries: int = DEFAULT_MAX_ENTRIES,
|
||||||
|
) -> SessionValidationCache:
|
||||||
|
"""Initialise (or replace) the process-wide cache and return it.
|
||||||
|
|
||||||
|
Call this from the app factory once the :class:`SessionService`
|
||||||
|
singleton is ready. Returns the new cache so callers can hold
|
||||||
|
a reference for tests.
|
||||||
|
"""
|
||||||
|
global _cache
|
||||||
|
_cache = SessionValidationCache(
|
||||||
|
service, ttl_seconds=ttl_seconds, max_entries=max_entries
|
||||||
|
)
|
||||||
|
return _cache
|
||||||
|
|
||||||
|
|
||||||
|
def set_validation_cache(cache: SessionValidationCache | None) -> None:
|
||||||
|
"""Inject a custom cache (used by tests)."""
|
||||||
|
global _cache
|
||||||
|
_cache = cache
|
||||||
|
|
@ -0,0 +1,509 @@
|
||||||
|
"""SessionService — central business logic for ``auth_sessions`` (U3).
|
||||||
|
|
||||||
|
This module is the single owner of the ``auth_sessions`` table. Routes,
|
||||||
|
the auth middleware, and the admin endpoints all call into
|
||||||
|
:class:`SessionService` rather than touching the table directly, which
|
||||||
|
keeps the rotation / reuse-detection / cap-eviction rules in one place.
|
||||||
|
|
||||||
|
Lifecycle
|
||||||
|
---------
|
||||||
|
::
|
||||||
|
|
||||||
|
create ─► rotate ─► rotate ─► ...
|
||||||
|
│ │ │
|
||||||
|
▼ ▼ ▼
|
||||||
|
active active active
|
||||||
|
│
|
||||||
|
├── revoke (user logout) → revoked
|
||||||
|
├── revoke (admin kill) → revoked
|
||||||
|
├── reuse_detected → revoked + revoke_all_for_user
|
||||||
|
├── password_change → revoke_all_for_user
|
||||||
|
└── session_cap_eviction → revoked (oldest)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from .denylist import (
|
||||||
|
DEFAULT_TTL_SECONDS,
|
||||||
|
InMemoryRecentlyRevoked,
|
||||||
|
RecentlyRevoked,
|
||||||
|
hash_token,
|
||||||
|
)
|
||||||
|
from .models import DEFAULT_AUTH_DB_PATH
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _now_iso() -> str:
|
||||||
|
"""Return current UTC time as ISO 8601 string."""
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_db_path() -> Path:
|
||||||
|
"""Resolve the auth DB path with runtime env-var priority (test-friendly)."""
|
||||||
|
env = os.environ.get("AGENTKIT_AUTH_DB")
|
||||||
|
if env:
|
||||||
|
return Path(env)
|
||||||
|
return DEFAULT_AUTH_DB_PATH
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data classes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Default per-user session cap. Matches the plan's "10 sessions" cap;
|
||||||
|
# new logins evict the oldest when the user is at the limit.
|
||||||
|
DEFAULT_SESSION_CAP = 10
|
||||||
|
|
||||||
|
# Reasons written to ``auth_sessions.revoked_reason``.
|
||||||
|
REVOKE_REASON_USER_TERMINATED = "user_terminated"
|
||||||
|
REVOKE_REASON_ADMIN_REVOKED = "admin_revoked"
|
||||||
|
REVOKE_REASON_PASSWORD_CHANGED = "password_changed"
|
||||||
|
REVOKE_REASON_REUSE_DETECTED = "reuse_detected"
|
||||||
|
REVOKE_REASON_SESSION_CAP_EVICTION = "session_cap_eviction"
|
||||||
|
|
||||||
|
|
||||||
|
class SessionNotFound(Exception):
|
||||||
|
"""The session id does not exist or is not owned by the user."""
|
||||||
|
|
||||||
|
|
||||||
|
class SessionReuseDetected(Exception):
|
||||||
|
"""A refresh token was reused within the denylist window.
|
||||||
|
|
||||||
|
Raised by :meth:`SessionService.rotate` when the provided refresh
|
||||||
|
token is in the denylist. The caller should treat this as a security
|
||||||
|
event and call :meth:`revoke_all_for_user`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SessionCreate:
|
||||||
|
"""Inputs for :meth:`SessionService.create`."""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
refresh_token: str
|
||||||
|
device_fingerprint: str
|
||||||
|
device_label: str
|
||||||
|
ip: str
|
||||||
|
user_agent: str
|
||||||
|
auth_provider: str
|
||||||
|
ttl_seconds: int
|
||||||
|
previous_session_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SessionInfo:
|
||||||
|
"""Public projection of an ``auth_sessions`` row."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
device_fingerprint: str
|
||||||
|
device_label: str
|
||||||
|
ip: str
|
||||||
|
user_agent: str
|
||||||
|
auth_provider: str
|
||||||
|
created_at: str
|
||||||
|
last_active_at: str
|
||||||
|
expires_at: str
|
||||||
|
revoked: bool
|
||||||
|
revoked_reason: str | None
|
||||||
|
previous_session_id: str | None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SessionService
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class SessionService:
|
||||||
|
"""CRUD + lifecycle for ``auth_sessions``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to the auth SQLite DB. Defaults to the env-var
|
||||||
|
``AGENTKIT_AUTH_DB`` or :data:`models.DEFAULT_AUTH_DB_PATH`.
|
||||||
|
denylist: Recently-revoked refresh-token store. Defaults to an
|
||||||
|
in-memory LRU; tests can pass their own instance for
|
||||||
|
deterministic window control.
|
||||||
|
session_cap: Maximum number of active sessions per user. New
|
||||||
|
logins evict the oldest when the user is at the limit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_path: str | Path | None = None,
|
||||||
|
*,
|
||||||
|
denylist: RecentlyRevoked | None = None,
|
||||||
|
session_cap: int = DEFAULT_SESSION_CAP,
|
||||||
|
) -> None:
|
||||||
|
self._db_path = Path(db_path) if db_path is not None else _resolve_db_path()
|
||||||
|
self._denylist = denylist or InMemoryRecentlyRevoked()
|
||||||
|
self._session_cap = session_cap
|
||||||
|
# Parallel to ``self._denylist``: maps an old (rotated) refresh
|
||||||
|
# token hash to the user_id it belonged to, so reuse detection
|
||||||
|
# can find the right user to revoke (the row has already been
|
||||||
|
# updated to the new hash, so the table itself can't tell us).
|
||||||
|
self._recent_users: dict[str, str] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Read
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def get(self, session_id: str) -> SessionInfo | None:
|
||||||
|
"""Look up a session by id. Returns ``None`` if not found."""
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT * FROM auth_sessions WHERE id = ?",
|
||||||
|
(session_id,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return _row_to_info(row) if row else None
|
||||||
|
|
||||||
|
async def list_for_user(
|
||||||
|
self, user_id: str, *, include_revoked: bool = False
|
||||||
|
) -> list[SessionInfo]:
|
||||||
|
"""List sessions for ``user_id``, newest first.
|
||||||
|
|
||||||
|
By default only active (non-revoked) sessions are returned. The
|
||||||
|
admin "list all sessions" view passes ``include_revoked=True``.
|
||||||
|
"""
|
||||||
|
sql = "SELECT * FROM auth_sessions WHERE user_id = ?"
|
||||||
|
args: tuple[Any, ...] = (user_id,)
|
||||||
|
if not include_revoked:
|
||||||
|
sql += " AND revoked = 0"
|
||||||
|
sql += " ORDER BY created_at DESC"
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(sql, args)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return [_row_to_info(r) for r in rows]
|
||||||
|
|
||||||
|
async def list_all(
|
||||||
|
self, *, include_revoked: bool = False, limit: int = 200
|
||||||
|
) -> list[SessionInfo]:
|
||||||
|
"""List recent sessions across all users (admin view)."""
|
||||||
|
sql = "SELECT * FROM auth_sessions"
|
||||||
|
if not include_revoked:
|
||||||
|
sql += " WHERE revoked = 0"
|
||||||
|
sql += " ORDER BY created_at DESC LIMIT ?"
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(sql, (limit,))
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return [_row_to_info(r) for r in rows]
|
||||||
|
|
||||||
|
async def find_by_refresh_token(self, refresh_token: str) -> SessionInfo | None:
|
||||||
|
"""Look up a session by the SHA-256 hash of its refresh token."""
|
||||||
|
h = hash_token(refresh_token)
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT * FROM auth_sessions WHERE refresh_token_hash = ?",
|
||||||
|
(h,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return _row_to_info(row) if row else None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Write
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def create(self, payload: SessionCreate) -> SessionInfo:
|
||||||
|
"""Insert a new session row.
|
||||||
|
|
||||||
|
If the user is already at the :attr:`session_cap`, the oldest
|
||||||
|
active session is evicted first. This is the only place where
|
||||||
|
the cap is enforced — callers don't have to think about it.
|
||||||
|
"""
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
now = _now_iso()
|
||||||
|
expires = (
|
||||||
|
datetime.now(timezone.utc).timestamp() + payload.ttl_seconds
|
||||||
|
)
|
||||||
|
expires_iso = (
|
||||||
|
datetime.fromtimestamp(expires, tz=timezone.utc).isoformat()
|
||||||
|
)
|
||||||
|
refresh_hash = hash_token(payload.refresh_token)
|
||||||
|
|
||||||
|
await self._enforce_session_cap(payload.user_id, keep_id=session_id)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO auth_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, device_fingerprint, device_label, "
|
||||||
|
" ip, user_agent, auth_provider, created_at, last_active_at, expires_at, "
|
||||||
|
" revoked, revoked_reason, previous_session_id) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
session_id,
|
||||||
|
payload.user_id,
|
||||||
|
refresh_hash,
|
||||||
|
payload.device_fingerprint,
|
||||||
|
payload.device_label,
|
||||||
|
payload.ip,
|
||||||
|
payload.user_agent,
|
||||||
|
payload.auth_provider,
|
||||||
|
now,
|
||||||
|
now,
|
||||||
|
expires_iso,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
payload.previous_session_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return SessionInfo(
|
||||||
|
id=session_id,
|
||||||
|
user_id=payload.user_id,
|
||||||
|
device_fingerprint=payload.device_fingerprint,
|
||||||
|
device_label=payload.device_label,
|
||||||
|
ip=payload.ip,
|
||||||
|
user_agent=payload.user_agent,
|
||||||
|
auth_provider=payload.auth_provider,
|
||||||
|
created_at=now,
|
||||||
|
last_active_at=now,
|
||||||
|
expires_at=expires_iso,
|
||||||
|
revoked=False,
|
||||||
|
revoked_reason=None,
|
||||||
|
previous_session_id=payload.previous_session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def rotate(
|
||||||
|
self,
|
||||||
|
old_refresh_token: str,
|
||||||
|
new_refresh_token: str,
|
||||||
|
*,
|
||||||
|
new_ttl_seconds: int,
|
||||||
|
) -> SessionInfo:
|
||||||
|
"""Rotate a refresh token: replace the old hash with a new one.
|
||||||
|
|
||||||
|
Adds the old token's hash to the denylist so a concurrent retry
|
||||||
|
using the old token is detected.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SessionNotFound: the old token has no matching session.
|
||||||
|
SessionReuseDetected: the old token is in the denylist.
|
||||||
|
The caller should call :meth:`revoke_all_for_user`.
|
||||||
|
"""
|
||||||
|
old_hash = hash_token(old_refresh_token)
|
||||||
|
if self._denylist.contains(old_hash):
|
||||||
|
# Concurrent retry — revoke everything for this user.
|
||||||
|
user_id = self._recent_users.pop(old_hash, None)
|
||||||
|
if user_id is not None:
|
||||||
|
await self.revoke_all_for_user(
|
||||||
|
user_id, reason=REVOKE_REASON_REUSE_DETECTED
|
||||||
|
)
|
||||||
|
raise SessionReuseDetected("refresh token reuse detected")
|
||||||
|
|
||||||
|
info = await self.find_by_refresh_token(old_refresh_token)
|
||||||
|
if info is None:
|
||||||
|
raise SessionNotFound("refresh token has no matching session")
|
||||||
|
if info.revoked:
|
||||||
|
raise SessionNotFound("refresh token has been revoked")
|
||||||
|
|
||||||
|
new_hash = hash_token(new_refresh_token)
|
||||||
|
now = _now_iso()
|
||||||
|
new_expires_iso = (
|
||||||
|
datetime.fromtimestamp(
|
||||||
|
datetime.now(timezone.utc).timestamp() + new_ttl_seconds,
|
||||||
|
tz=timezone.utc,
|
||||||
|
).isoformat()
|
||||||
|
)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE auth_sessions "
|
||||||
|
"SET refresh_token_hash = ?, last_active_at = ?, expires_at = ? "
|
||||||
|
"WHERE id = ?",
|
||||||
|
(new_hash, now, new_expires_iso, info.id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Add the old hash to the denylist AFTER the row is updated so
|
||||||
|
# a fast retry sees the new row rather than triggering reuse
|
||||||
|
# detection on a value that has just been rotated.
|
||||||
|
self._denylist.add(old_hash, ttl_seconds=DEFAULT_TTL_SECONDS)
|
||||||
|
# Track which user the old token belonged to so reuse detection
|
||||||
|
# can locate the right user to revoke (the row is already
|
||||||
|
# pointing at the new hash).
|
||||||
|
self._recent_users[old_hash] = info.user_id
|
||||||
|
|
||||||
|
refreshed = await self.get(info.id)
|
||||||
|
assert refreshed is not None # we just wrote it
|
||||||
|
return refreshed
|
||||||
|
|
||||||
|
async def touch(self, session_id: str) -> None:
|
||||||
|
"""Bump ``last_active_at`` (called on /auth/whoami success)."""
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE auth_sessions SET last_active_at = ? WHERE id = ?",
|
||||||
|
(_now_iso(), session_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Revoke
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def revoke(
|
||||||
|
self, session_id: str, *, reason: str = REVOKE_REASON_USER_TERMINATED
|
||||||
|
) -> bool:
|
||||||
|
"""Revoke a single session.
|
||||||
|
|
||||||
|
Returns ``True`` if a row was updated, ``False`` if the session
|
||||||
|
was already revoked or does not exist.
|
||||||
|
"""
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
cursor = await db.execute(
|
||||||
|
"UPDATE auth_sessions "
|
||||||
|
"SET revoked = 1, revoked_reason = ? "
|
||||||
|
"WHERE id = ? AND revoked = 0",
|
||||||
|
(reason, session_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return cursor.rowcount > 0
|
||||||
|
|
||||||
|
async def revoke_by_refresh_token(
|
||||||
|
self,
|
||||||
|
refresh_token: str,
|
||||||
|
*,
|
||||||
|
reason: str = REVOKE_REASON_USER_TERMINATED,
|
||||||
|
) -> bool:
|
||||||
|
"""Revoke a session identified by its refresh token (logout)."""
|
||||||
|
info = await self.find_by_refresh_token(refresh_token)
|
||||||
|
if info is None or info.revoked:
|
||||||
|
return False
|
||||||
|
return await self.revoke(info.id, reason=reason)
|
||||||
|
|
||||||
|
async def revoke_all_for_user(
|
||||||
|
self, user_id: str, *, reason: str = REVOKE_REASON_USER_TERMINATED
|
||||||
|
) -> int:
|
||||||
|
"""Revoke all of a user's active sessions.
|
||||||
|
|
||||||
|
Used when:
|
||||||
|
- The user changes their password (invalidate all other devices).
|
||||||
|
- Token reuse is detected (invalidate everything as a precaution).
|
||||||
|
- An admin kills a user account.
|
||||||
|
"""
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
cursor = await db.execute(
|
||||||
|
"UPDATE auth_sessions "
|
||||||
|
"SET revoked = 1, revoked_reason = ? "
|
||||||
|
"WHERE user_id = ? AND revoked = 0",
|
||||||
|
(reason, user_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Cap enforcement
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _enforce_session_cap(self, user_id: str, *, keep_id: str) -> int:
|
||||||
|
"""Evict the oldest sessions when the user is at the cap.
|
||||||
|
|
||||||
|
Returns the number of sessions evicted (0 if under the cap).
|
||||||
|
The new session id (``keep_id``) is exempt from eviction.
|
||||||
|
"""
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT id FROM auth_sessions "
|
||||||
|
"WHERE user_id = ? AND revoked = 0 AND id != ? "
|
||||||
|
"ORDER BY created_at DESC",
|
||||||
|
(user_id, keep_id),
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
if len(rows) < self._session_cap:
|
||||||
|
return 0
|
||||||
|
# Evict the oldest until we are under the cap (with the new row).
|
||||||
|
to_evict = rows[self._session_cap - 1 :]
|
||||||
|
if not to_evict:
|
||||||
|
return 0
|
||||||
|
ids = [r["id"] for r in to_evict]
|
||||||
|
placeholders = ",".join("?" for _ in ids)
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
await db.execute(
|
||||||
|
f"UPDATE auth_sessions "
|
||||||
|
f"SET revoked = 1, revoked_reason = ? "
|
||||||
|
f"WHERE id IN ({placeholders})",
|
||||||
|
(REVOKE_REASON_SESSION_CAP_EVICTION, *ids),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return len(ids)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Cache invalidation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def is_session_valid(self, session_id: str) -> bool:
|
||||||
|
"""Return ``True`` if the session exists, is not revoked, and not expired.
|
||||||
|
|
||||||
|
The middleware calls this on every request that carries a V2
|
||||||
|
JWT (sid claim). The result is cached in a 30-second
|
||||||
|
in-process LRU to keep the hot path cheap.
|
||||||
|
"""
|
||||||
|
info = await self.get(session_id)
|
||||||
|
if info is None or info.revoked:
|
||||||
|
return False
|
||||||
|
expires = datetime.fromisoformat(info.expires_at)
|
||||||
|
if expires.timestamp() <= datetime.now(timezone.utc).timestamp():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_info(row: aiosqlite.Row) -> SessionInfo:
|
||||||
|
return SessionInfo(
|
||||||
|
id=row["id"],
|
||||||
|
user_id=row["user_id"],
|
||||||
|
device_fingerprint=row["device_fingerprint"],
|
||||||
|
device_label=row["device_label"],
|
||||||
|
ip=row["ip"],
|
||||||
|
user_agent=row["user_agent"],
|
||||||
|
auth_provider=row["auth_provider"],
|
||||||
|
created_at=row["created_at"],
|
||||||
|
last_active_at=row["last_active_at"],
|
||||||
|
expires_at=row["expires_at"],
|
||||||
|
revoked=bool(row["revoked"]),
|
||||||
|
revoked_reason=row["revoked_reason"],
|
||||||
|
previous_session_id=row["previous_session_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level singleton (overridable in tests via set_session_service)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_session_service: SessionService | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_service() -> SessionService:
|
||||||
|
"""Return the process-wide :class:`SessionService` (lazy singleton)."""
|
||||||
|
global _session_service
|
||||||
|
if _session_service is None:
|
||||||
|
_session_service = SessionService()
|
||||||
|
return _session_service
|
||||||
|
|
||||||
|
|
||||||
|
def set_session_service(service: SessionService | None) -> None:
|
||||||
|
"""Inject a custom :class:`SessionService` (used by tests)."""
|
||||||
|
global _session_service
|
||||||
|
_session_service = service
|
||||||
|
|
@ -0,0 +1,256 @@
|
||||||
|
"""Unit tests for SessionService (U3)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.server.auth.denylist import InMemoryRecentlyRevoked
|
||||||
|
from agentkit.server.auth.models import init_auth_db
|
||||||
|
from agentkit.server.auth.session_service import (
|
||||||
|
REVOKE_REASON_PASSWORD_CHANGED,
|
||||||
|
REVOKE_REASON_REUSE_DETECTED,
|
||||||
|
REVOKE_REASON_SESSION_CAP_EVICTION,
|
||||||
|
REVOKE_REASON_USER_TERMINATED,
|
||||||
|
SessionCreate,
|
||||||
|
SessionNotFound,
|
||||||
|
SessionReuseDetected,
|
||||||
|
SessionService,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def auth_db(tmp_path: Path):
|
||||||
|
"""Initialise a fresh auth DB in a tmpdir; set AGENTKIT_AUTH_DB for the duration."""
|
||||||
|
db_path = tmp_path / "auth.db"
|
||||||
|
await init_auth_db(db_path)
|
||||||
|
prev = os.environ.get("AGENTKIT_AUTH_DB")
|
||||||
|
os.environ["AGENTKIT_AUTH_DB"] = str(db_path)
|
||||||
|
try:
|
||||||
|
yield db_path
|
||||||
|
finally:
|
||||||
|
if prev is None:
|
||||||
|
os.environ.pop("AGENTKIT_AUTH_DB", None)
|
||||||
|
else:
|
||||||
|
os.environ["AGENTKIT_AUTH_DB"] = prev
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def user_id(auth_db: Path) -> str:
|
||||||
|
"""Insert a single user and return its id."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
async with aiosqlite.connect(str(auth_db)) as db:
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO users (id, username, email, password_hash, role, "
|
||||||
|
"is_active, is_terminal_authorized, is_server_terminal_authorized, "
|
||||||
|
"created_at, updated_at, last_login_at, created_by) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
user_id,
|
||||||
|
"alice",
|
||||||
|
"alice@example.com",
|
||||||
|
"x",
|
||||||
|
"member",
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
now,
|
||||||
|
now,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def svc(auth_db: Path) -> SessionService:
|
||||||
|
"""A SessionService backed by the in-memory denylist for determinism."""
|
||||||
|
return SessionService(
|
||||||
|
db_path=auth_db, denylist=InMemoryRecentlyRevoked(), session_cap=3
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_create(user_id: str, refresh_token: str = "r1") -> SessionCreate:
|
||||||
|
return SessionCreate(
|
||||||
|
user_id=user_id,
|
||||||
|
refresh_token=refresh_token,
|
||||||
|
device_fingerprint="fp",
|
||||||
|
device_label="Test device",
|
||||||
|
ip="127.0.0.1",
|
||||||
|
user_agent="pytest",
|
||||||
|
auth_provider="local",
|
||||||
|
ttl_seconds=3600,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# create
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_inserts_row_and_returns_info(svc: SessionService, user_id: str):
|
||||||
|
info = await svc.create(_make_create(user_id))
|
||||||
|
assert info.user_id == user_id
|
||||||
|
assert info.revoked is False
|
||||||
|
# Round-trip via the read API
|
||||||
|
fetched = await svc.get(info.id)
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.id == info.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_evicts_oldest_when_at_cap(svc: SessionService, user_id: str):
|
||||||
|
# Cap is 3; insert 3, then a 4th should evict the first.
|
||||||
|
a = await svc.create(_make_create(user_id, "rt-a"))
|
||||||
|
b = await svc.create(_make_create(user_id, "rt-b"))
|
||||||
|
c = await svc.create(_make_create(user_id, "rt-c"))
|
||||||
|
d = await svc.create(_make_create(user_id, "rt-d"))
|
||||||
|
|
||||||
|
assert (await svc.get(a.id)).revoked is True
|
||||||
|
assert (await svc.get(a.id)).revoked_reason == REVOKE_REASON_SESSION_CAP_EVICTION
|
||||||
|
for kept in (b, c, d):
|
||||||
|
assert (await svc.get(kept.id)).revoked is False
|
||||||
|
|
||||||
|
# Sanity: exactly session_cap remain (this test uses cap=3)
|
||||||
|
remaining = [s for s in await svc.list_for_user(user_id) if not s.revoked]
|
||||||
|
assert len(remaining) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# rotate
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_rotate_replaces_refresh_hash(svc: SessionService, user_id: str):
|
||||||
|
info = await svc.create(_make_create(user_id, "rt-old"))
|
||||||
|
new = await svc.rotate(
|
||||||
|
old_refresh_token="rt-old", new_refresh_token="rt-new", new_ttl_seconds=3600
|
||||||
|
)
|
||||||
|
assert new.id == info.id
|
||||||
|
# Old hash no longer resolves
|
||||||
|
assert await svc.find_by_refresh_token("rt-old") is None
|
||||||
|
# New hash does
|
||||||
|
assert (await svc.find_by_refresh_token("rt-new")).id == info.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_rotate_adds_old_to_denylist(svc: SessionService, user_id: str):
|
||||||
|
await svc.create(_make_create(user_id, "rt-1"))
|
||||||
|
await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)
|
||||||
|
assert svc._denylist.contains(__import__("hashlib").sha256(b"rt-1").hexdigest())
|
||||||
|
|
||||||
|
|
||||||
|
async def test_rotate_unknown_token_raises(svc: SessionService, user_id: str):
|
||||||
|
with pytest.raises(SessionNotFound):
|
||||||
|
await svc.rotate("never-issued", "rt-new", new_ttl_seconds=3600)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_rotate_reuse_raises_and_revokes_all(svc: SessionService, user_id: str):
|
||||||
|
await svc.create(_make_create(user_id, "rt-1"))
|
||||||
|
await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)
|
||||||
|
# A second attempt with the old token: must raise and revoke all
|
||||||
|
with pytest.raises(SessionReuseDetected):
|
||||||
|
await svc.rotate("rt-1", "rt-2b", new_ttl_seconds=3600)
|
||||||
|
# All sessions for this user are now revoked
|
||||||
|
active = await svc.list_for_user(user_id, include_revoked=True)
|
||||||
|
assert all(s.revoked for s in active)
|
||||||
|
assert all(s.revoked_reason == REVOKE_REASON_REUSE_DETECTED for s in active)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_rotate_revoked_session_raises(svc: SessionService, user_id: str):
|
||||||
|
await svc.create(_make_create(user_id, "rt-1"))
|
||||||
|
await svc.revoke_by_refresh_token("rt-1")
|
||||||
|
with pytest.raises(SessionNotFound):
|
||||||
|
await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# revoke
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_revoke_marks_session_revoked(svc: SessionService, user_id: str):
|
||||||
|
info = await svc.create(_make_create(user_id))
|
||||||
|
ok = await svc.revoke(info.id)
|
||||||
|
assert ok is True
|
||||||
|
assert (await svc.get(info.id)).revoked is True
|
||||||
|
assert (await svc.get(info.id)).revoked_reason == REVOKE_REASON_USER_TERMINATED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_revoke_returns_false_if_already_revoked(svc: SessionService, user_id: str):
|
||||||
|
info = await svc.create(_make_create(user_id))
|
||||||
|
assert await svc.revoke(info.id) is True
|
||||||
|
assert await svc.revoke(info.id) is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_revoke_by_refresh_token(svc: SessionService, user_id: str):
|
||||||
|
await svc.create(_make_create(user_id, "rt-1"))
|
||||||
|
assert await svc.revoke_by_refresh_token("rt-1") is True
|
||||||
|
assert await svc.revoke_by_refresh_token("rt-1") is False
|
||||||
|
assert await svc.revoke_by_refresh_token("rt-never") is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_revoke_all_for_user(svc: SessionService, user_id: str):
|
||||||
|
await svc.create(_make_create(user_id, "a"))
|
||||||
|
await svc.create(_make_create(user_id, "b"))
|
||||||
|
n = await svc.revoke_all_for_user(
|
||||||
|
user_id, reason=REVOKE_REASON_PASSWORD_CHANGED
|
||||||
|
)
|
||||||
|
assert n == 2
|
||||||
|
active = await svc.list_for_user(user_id)
|
||||||
|
assert active == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# list / get
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def test_list_for_user_excludes_revoked_by_default(svc: SessionService, user_id: str):
|
||||||
|
a = await svc.create(_make_create(user_id, "a"))
|
||||||
|
b = await svc.create(_make_create(user_id, "b"))
|
||||||
|
await svc.revoke(a.id)
|
||||||
|
visible = await svc.list_for_user(user_id)
|
||||||
|
assert [s.id for s in visible] == [b.id]
|
||||||
|
visible_all = await svc.list_for_user(user_id, include_revoked=True)
|
||||||
|
assert {s.id for s in visible_all} == {a.id, b.id}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_is_session_valid_rejects_revoked(svc: SessionService, user_id: str):
|
||||||
|
info = await svc.create(_make_create(user_id))
|
||||||
|
assert await svc.is_session_valid(info.id) is True
|
||||||
|
await svc.revoke(info.id)
|
||||||
|
assert await svc.is_session_valid(info.id) is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_is_session_valid_rejects_expired(svc: SessionService, user_id: str):
|
||||||
|
# Create with a very short TTL
|
||||||
|
create = SessionCreate(
|
||||||
|
user_id=user_id,
|
||||||
|
refresh_token="rt-x",
|
||||||
|
device_fingerprint="fp",
|
||||||
|
device_label="d",
|
||||||
|
ip="",
|
||||||
|
user_agent="",
|
||||||
|
auth_provider="local",
|
||||||
|
ttl_seconds=0, # expires immediately
|
||||||
|
)
|
||||||
|
info = await svc.create(create)
|
||||||
|
# Re-check with a future expires_at: should be expired
|
||||||
|
assert await svc.is_session_valid(info.id) is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_is_session_valid_returns_false_for_unknown(svc: SessionService):
|
||||||
|
assert await svc.is_session_valid("nonexistent-id") is False
|
||||||
Loading…
Reference in New Issue