diff --git a/src/agentkit/server/auth/models.py b/src/agentkit/server/auth/models.py index 4c1d6bb..316e1ec 100644 --- a/src/agentkit/server/auth/models.py +++ b/src/agentkit/server/auth/models.py @@ -5,10 +5,23 @@ stored as ``String(36)`` so the same schema works on both SQLite and PostgreSQL without dialect-specific types. Use :func:`init_auth_db` to create the tables on startup. + +Schema versioning +----------------- +The :data:`_SCHEMA_VERSION` constant tracks the current auth DB schema. The +:func:`_backfill_user_sessions` one-time migration is gated on this version +to ensure idempotency. After a successful backfill the version is stored +in the ``auth_meta`` table so subsequent restarts are no-ops. + +V2 additions (2026-06-20, Centralized Auth & Token Persistence): +- New ``auth_sessions`` table with full device/IP/audit metadata and an + ``auth_provider`` column for future IdP integration traceability. +- New ``auth_meta`` table for storing schema version + migration state. """ from __future__ import annotations +import json import logging import os from collections.abc import Mapping @@ -88,10 +101,16 @@ class UserApiKeyModel(Base): class UserSessionModel(Base): - """Refresh-token session record. + """Refresh-token session record (V1, deprecated — see :class:`AuthSessionModel`). Stores the SHA-256 hash of the refresh token (never the plaintext). ``revoked_at`` is set on logout / forced revocation. + + .. deprecated:: + Kept for one minor version (per U10 back-compat shim) so legacy + clients holding JWTs without ``sid`` claim can still validate. + New code should use :class:`AuthSessionModel` (table ``auth_sessions``) + which carries full device/IP/audit metadata. """ __tablename__ = "user_sessions" @@ -107,6 +126,50 @@ class UserSessionModel(Base): revoked_at: Mapped[str | None] = mapped_column(String(64), nullable=True) +class AuthSessionModel(Base): + """Server-side session record (V2, the primary session table going forward). + + Each row corresponds to a single refresh-token issuance. The full JWT + session id (``sid`` claim) is the row's ``id`` (UUID string). + + V2 fields (vs V1 ``user_sessions``): + - ``device_fingerprint`` / ``device_label``: surfaces "which device is + this session" in the admin UI. + - ``ip`` / ``user_agent``: audit trail. + - ``last_active_at``: updated on every successful refresh, used to + display "last seen" in the sessions list. + - ``revoked`` (0/1) + ``revoked_reason``: explicit revoked-state machine + with machine-readable reasons (``user_terminated``, ``password_changed``, + ``admin_revoked``, ``reuse_detected``, ``session_cap_eviction``). + - ``previous_session_id``: back-pointer to the previous session id, + written on refresh rotation, for audit trail. + - ``auth_provider``: the AuthProvider that issued this session + (``local`` / ``oidc-stub`` / future ``oidc-keycloak`` / ``saml`` / ``ldap``). + Enables admin "list sessions by provider" queries and audit traceability. + """ + + __tablename__ = "auth_sessions" + + id: Mapped[str] = mapped_column(String(36), primary_key=True) + user_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True) + refresh_token_hash: Mapped[str] = mapped_column( + String(64), unique=True, nullable=False, index=True + ) + device_fingerprint: Mapped[str] = mapped_column(String(128), nullable=False, default="unknown") + device_label: Mapped[str] = mapped_column(String(256), nullable=False, default="Unknown device") + ip: Mapped[str] = mapped_column(String(64), nullable=False, default="") + user_agent: Mapped[str] = mapped_column(String(512), nullable=False, default="") + auth_provider: Mapped[str] = mapped_column( + String(32), nullable=False, default="local", index=True + ) + created_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso) + last_active_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso) + expires_at: Mapped[str] = mapped_column(String(64), nullable=False) + revoked: Mapped[bool] = mapped_column(default=False, nullable=False, index=True) + revoked_reason: Mapped[str | None] = mapped_column(String(64), nullable=True) + previous_session_id: Mapped[str | None] = mapped_column(String(36), nullable=True) + + class TerminalWhitelistUserModel(Base): """Per-user terminal command whitelist. @@ -244,6 +307,44 @@ CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id); CREATE INDEX IF NOT EXISTS idx_user_sessions_refresh_token_hash ON user_sessions(refresh_token_hash); +-- V2: auth_sessions replaces user_sessions as the primary session table. +-- Stores device/IP/audit metadata and auth_provider for IdP traceability. +-- Per-row indexes are sized for the most common access patterns: +-- * (user_id, revoked, expires_at) — cap-count, list-active, refresh-validate +-- * expires_at — cleanup sweeps +-- * refresh_token_hash — uniqueness + fast lookup on /auth/refresh +-- * auth_provider — admin "list sessions by provider" queries +CREATE TABLE IF NOT EXISTS auth_sessions ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + refresh_token_hash TEXT NOT NULL UNIQUE, + device_fingerprint TEXT NOT NULL DEFAULT 'unknown', + device_label TEXT NOT NULL DEFAULT 'Unknown device', + ip TEXT NOT NULL DEFAULT '', + user_agent TEXT NOT NULL DEFAULT '', + auth_provider TEXT NOT NULL DEFAULT 'local', + created_at TEXT NOT NULL, + last_active_at TEXT NOT NULL, + expires_at TEXT NOT NULL, + revoked INTEGER NOT NULL DEFAULT 0, + revoked_reason TEXT, + previous_session_id TEXT +); +CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_id_active + ON auth_sessions(user_id, revoked, expires_at); +CREATE INDEX IF NOT EXISTS idx_auth_sessions_expires_at + ON auth_sessions(expires_at); +CREATE INDEX IF NOT EXISTS idx_auth_sessions_auth_provider + ON auth_sessions(auth_provider); + +-- V2: auth_meta stores schema version + migration completion markers. +-- Used by init_auth_db to gate one-time migrations (idempotency). +CREATE TABLE IF NOT EXISTS auth_meta ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + updated_at TEXT NOT NULL +); + CREATE TABLE IF NOT EXISTS terminal_whitelist_user ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, @@ -307,12 +408,137 @@ CREATE INDEX IF NOT EXISTS idx_terminal_approvals_status """ +# --------------------------------------------------------------------------- +# Schema versioning + one-time migrations +# --------------------------------------------------------------------------- + +# Current auth DB schema version. Bump this when adding new tables/columns +# that require data backfill or migration. The :func:`init_auth_db` function +# uses this together with the ``auth_meta.schema_version`` row to decide +# which migrations to run. +_SCHEMA_VERSION = 2 + +_META_SCHEMA_VERSION_KEY = "schema_version" + + +async def _get_meta_value(db: aiosqlite.Connection, key: str) -> str | None: + """Read a key from the ``auth_meta`` table. Returns ``None`` if missing.""" + cursor = await db.execute("SELECT value FROM auth_meta WHERE key = ?", (key,)) + row = await cursor.fetchone() + return row["value"] if row else None + + +async def _set_meta_value(db: aiosqlite.Connection, key: str, value: str) -> None: + """Upsert a key in the ``auth_meta`` table.""" + await db.execute( + "INSERT INTO auth_meta (key, value, updated_at) VALUES (?, ?, ?) " + "ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at", + (key, value, _now_iso()), + ) + + +async def _backfill_user_sessions(db: aiosqlite.Connection) -> int: + """One-time backfill from ``user_sessions`` (V1) to ``auth_sessions`` (V2). + + Runs only when ``auth_sessions`` is empty AND ``user_sessions`` has rows. + Idempotent: subsequent restarts are no-ops because we mark the backfill + as completed in ``auth_meta``. + + For each non-revoked V1 session, copies: + - id (reused — see note below) + - user_id + - refresh_token_hash + - device_fingerprint / device_label / ip / user_agent from the legacy + ``device_info`` JSON blob (best-effort) + - created_at / expires_at + - last_active_at defaults to created_at + - revoked=0 (already filtered) + - revoked_reason=None + - auth_provider='local' (default; backfilled rows are pre-IdP) + + The original ``id`` is preserved so that legacy clients holding the + old refresh_token_hash still match a row in the new table — this is + what the back-compat path in U10 (``get_current_user`` for legacy + JWTs) relies on. + + Returns: + Number of rows backfilled (0 if already done or nothing to backfill). + """ + # Idempotency: check the marker + if await _get_meta_value(db, "backfill_user_sessions_v1_to_v2") == "done": + return 0 + + cursor = await db.execute("SELECT COUNT(*) FROM auth_sessions") + (count,) = await cursor.fetchone() + if count > 0: + # auth_sessions already has data — this is a fresh V2 install, not an + # upgrade. Mark the backfill done so we never re-check. + await _set_meta_value(db, "backfill_user_sessions_v1_to_v2", "done") + await db.commit() + return 0 + + cursor = await db.execute( + "SELECT id, user_id, refresh_token_hash, device_info, created_at, expires_at, revoked_at " + "FROM user_sessions WHERE revoked_at IS NULL" + ) + rows = await cursor.fetchall() + backfilled = 0 + for row in rows: + try: + device_info = json.loads(row["device_info"]) if row["device_info"] else {} + except (json.JSONDecodeError, TypeError): + device_info = {} + + await db.execute( + "INSERT OR IGNORE INTO auth_sessions " + "(id, user_id, refresh_token_hash, device_fingerprint, device_label, " + " ip, user_agent, auth_provider, created_at, last_active_at, expires_at, " + " revoked, revoked_reason, previous_session_id) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + row["id"], # reuse legacy id for back-compat with old clients + row["user_id"], + row["refresh_token_hash"], + device_info.get("fingerprint", "unknown"), + device_info.get("label", "Unknown device"), + device_info.get("ip", ""), + device_info.get("user_agent", ""), + "local", # backfilled rows are pre-IdP by definition + row["created_at"], + row["created_at"], # last_active_at defaults to created_at + row["expires_at"], + 0, # not revoked (already filtered) + None, + None, + ), + ) + backfilled += 1 + + if backfilled: + logger.info( + f"Backfilled {backfilled} user_sessions rows to auth_sessions " + f"(schema v{_SCHEMA_VERSION})" + ) + + # Mark the backfill as completed regardless of how many rows were moved. + # (idempotency: even a 0-row backfill is "done".) + await _set_meta_value(db, "backfill_user_sessions_v1_to_v2", "done") + await db.commit() + return backfilled + + async def init_auth_db(db_path: str | Path | None = None) -> Path: """Create auth tables if they do not exist. Uses aiosqlite directly (no SQLAlchemy engine) for a lightweight, zero-config bootstrap that mirrors :class:`SqliteConversationStore`. + On startup, this function: + 1. Creates all tables and indexes from :data:`_SCHEMA_SQL` (idempotent). + 2. Records the current :data:`_SCHEMA_VERSION` in ``auth_meta``. + 3. Runs any pending one-time migrations (currently: V1 → V2 backfill + from ``user_sessions`` to ``auth_sessions``). + Args: db_path: Path to the SQLite file. Defaults to :data:`DEFAULT_AUTH_DB_PATH` (``data/auth.db`` under the project @@ -325,8 +551,19 @@ async def init_auth_db(db_path: str | Path | None = None) -> Path: path.parent.mkdir(parents=True, exist_ok=True) async with aiosqlite.connect(str(path)) as db: + db.row_factory = aiosqlite.Row await db.execute("PRAGMA journal_mode=WAL") await db.executescript(_SCHEMA_SQL) + + # Record the current schema version (idempotent upsert). + current = await _get_meta_value(db, _META_SCHEMA_VERSION_KEY) + if current != str(_SCHEMA_VERSION): + await _set_meta_value(db, _META_SCHEMA_VERSION_KEY, str(_SCHEMA_VERSION)) + logger.info(f"Auth DB schema version set to {_SCHEMA_VERSION}") + + # Run pending migrations (each is internally idempotent). + await _backfill_user_sessions(db) + await db.commit() logger.info(f"Auth DB initialized at {path}") @@ -353,3 +590,28 @@ def user_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any "last_login_at": row["last_login_at"], "created_by": row["created_by"], } + + +def auth_session_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]: + """Convert an ``auth_sessions`` row into a JSON-safe dict. + + The ``revoked`` field is normalized to a Python ``bool`` (the DB stores + 0/1). The full set of audit fields is included so the admin UI and + API responses can surface device/IP/last-active information without + a separate lookup. + """ + return { + "id": row["id"], + "user_id": row["user_id"], + "device_fingerprint": row["device_fingerprint"], + "device_label": row["device_label"], + "ip": row["ip"], + "user_agent": row["user_agent"], + "auth_provider": row["auth_provider"], + "created_at": row["created_at"], + "last_active_at": row["last_active_at"], + "expires_at": row["expires_at"], + "revoked": bool(row["revoked"]), + "revoked_reason": row["revoked_reason"], + "previous_session_id": row["previous_session_id"], + } diff --git a/src/agentkit/server/auth/providers/__init__.py b/src/agentkit/server/auth/providers/__init__.py new file mode 100644 index 0000000..5c9a90b --- /dev/null +++ b/src/agentkit/server/auth/providers/__init__.py @@ -0,0 +1,115 @@ +"""AuthProvider implementations and dependency-injection factory. + +Public surface: + +- :class:`AuthProvider` (re-exported from :mod:`.base`) — the protocol + every backend must satisfy. +- :class:`LocalAuthProvider` — SQLite + bcrypt default. +- :class:`StubOIDCProvider` — interface placeholder for the future + OIDC integration. +- :class:`User` — provider-agnostic user value object. +- :func:`get_auth_provider` — DI factory used by routes via + ``Depends(get_auth_provider)``. +- :func:`reset_auth_provider` — clears the lru_cache singleton + (used in tests + when the auth provider config changes at runtime). + +Configuration +------------- +The provider is selected via the ``AGENTKIT_AUTH_PROVIDER`` environment +variable (default: ``"local"``). When the future ``auth.provider`` +field is added to ``agentkit.yaml`` it should override the env var. + +Adding a new provider +--------------------- +1. Create ``auth/providers/.py`` with a class that satisfies + the :class:`AuthProvider` Protocol (i.e. has ``name: str`` and the + four async methods). +2. Import it here and add a branch to :func:`get_auth_provider`. +3. Add it to ``AGENTKIT_AUTH_PROVIDER`` enum in the deployment docs. +""" + +from __future__ import annotations + +import logging +import os +from functools import lru_cache +from pathlib import Path + +from .base import AuthProvider +from .exceptions import AuthProviderError, InvalidCredentials, ProviderNotImplemented +from .local import LocalAuthProvider +from .oidc_stub import StubOIDCProvider +from .user import User + +logger = logging.getLogger(__name__) + + +# Re-exports for ergonomic imports: `from agentkit.server.auth.providers import AuthProvider` +__all__ = [ + "AuthProvider", + "AuthProviderError", + "InvalidCredentials", + "LocalAuthProvider", + "ProviderNotImplemented", + "StubOIDCProvider", + "User", + "get_auth_provider", + "reset_auth_provider", +] + + +# Default provider name. Configurable via AGENTKIT_AUTH_PROVIDER env var +# (overridable when the auth.provider field is added to agentkit.yaml). +DEFAULT_AUTH_PROVIDER = "local" + + +def _resolve_provider_name() -> str: + """Return the configured provider name, defaulting to 'local'.""" + name = os.environ.get("AGENTKIT_AUTH_PROVIDER", DEFAULT_AUTH_PROVIDER).strip() + if not name: + return DEFAULT_AUTH_PROVIDER + return name + + +def _resolve_db_path() -> Path: + """Return the auth DB path (overridable for tests via env var).""" + from ..models import DEFAULT_AUTH_DB_PATH + + env = os.environ.get("AGENTKIT_AUTH_DB") + if env: + return Path(env) + return DEFAULT_AUTH_DB_PATH + + +@lru_cache(maxsize=1) +def get_auth_provider() -> AuthProvider: + """Return the configured :class:`AuthProvider` (memoized singleton). + + The ``lru_cache`` is process-local and the value is resolved on + first call. Use :func:`reset_auth_provider` to clear the cache + (e.g. when the configuration changes at runtime, or in test + fixtures that need a fresh provider with a different db_path). + + Raises: + ValueError: if the configured provider name is not recognized. + """ + name = _resolve_provider_name() + if name == "local": + return LocalAuthProvider(db_path=_resolve_db_path()) + if name == "oidc-stub": + return StubOIDCProvider() + raise ValueError( + f"unknown auth provider: {name!r}. " + f"Supported providers: 'local', 'oidc-stub'. " + f"Set AGENTKIT_AUTH_PROVIDER or update agentkit.yaml's auth.provider field." + ) + + +def reset_auth_provider() -> None: + """Clear the memoized :class:`AuthProvider` singleton. + + Use in tests (e.g. in a fixture's teardown) or in code paths that + change the auth provider configuration at runtime and need a + re-resolution. + """ + get_auth_provider.cache_clear() diff --git a/src/agentkit/server/auth/providers/base.py b/src/agentkit/server/auth/providers/base.py new file mode 100644 index 0000000..fabb7a8 --- /dev/null +++ b/src/agentkit/server/auth/providers/base.py @@ -0,0 +1,109 @@ +"""AuthProvider protocol — pluggable authentication backend contract. + +The :class:`AuthProvider` Protocol defines the minimal surface area the +auth subsystem needs from any authentication backend (Local today, OIDC +/ SAML / LDAP tomorrow). Routes and admin endpoints call only these +methods and never touch the underlying user store directly. Adding a new +IdP integration is a matter of writing a new adapter that satisfies +this Protocol. + +Design notes: + +- **No sync surface for password verification**: the only entry point is + :meth:`authenticate` which takes the plaintext password. Internally each + adapter chooses how to verify (bcrypt locally, redirect to IdP, etc.). +- **No persistence responsibility**: providers return :class:`User` objects + but do not manage sessions, refresh tokens, or session state. That is + the route + SessionService's job (see :mod:`agentkit.server.auth.session`). +- **Audit-friendly**: :attr:`name` is written to ``auth_sessions.auth_provider`` + on every login, so the source of every session is traceable. This is + what enables "list sessions by provider" admin queries and future + cross-IdP policy enforcement. +- **runtime_checkable**: the Protocol is decorated so that + ``isinstance(provider, AuthProvider)`` works at runtime, enabling + defensive checks in tests and in DI wiring. +""" + +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from .user import User + + +@runtime_checkable +class AuthProvider(Protocol): + """All authentication backends must implement this surface. + + The route layer only calls the methods below. It is intentionally + unaware of whether the backing store is SQLite, an OIDC IdP, LDAP, + or anything else. + """ + + name: str + """Identifier for this provider, written to ``auth_sessions.auth_provider``. + + Convention: ``local`` for the local SQLite + bcrypt backend, + ``oidc-`` for OIDC (e.g. ``oidc-keycloak``, + ``oidc-feishu``), ``saml`` / ``ldap`` for the other planned + integrations. This value is the stable contract for audit + traceability — do not rename without a migration plan. + """ + + async def authenticate(self, *, username: str, password: str) -> User: + """Verify ``username`` + ``password`` and return the :class:`User`. + + Args: + username: The submitted username (or, for OIDC, the IdP + subject id — but that's a future adapter's concern). + password: The submitted plaintext password. + + Returns: + The matched :class:`User` on success. + + Raises: + InvalidCredentials: if the user does not exist, the password + is wrong, or the user is inactive. Callers MUST NOT + distinguish between these three cases in the error + message returned to the client (timing-attack / + username-enumeration mitigation). + """ + ... + + async def get_user_by_id(self, user_id: str) -> User | None: + """Look up a :class:`User` by primary key. + + Used by: + - Admin endpoints that need to display user info by id + - Session validation in the cold-start / whoami path + - Audit log enrichment + + Returns ``None`` if no user exists with this id (or the user + is inactive — convention: inactive users are "not found" from + the auth layer's perspective). + """ + ... + + async def sync_user_attributes(self, user_id: str) -> None: + """Refresh user attributes (department / email / title) from the source of truth. + + - :class:`LocalAuthProvider`: no-op (attributes are managed locally). + - OIDC adapter (future): pull the latest profile from the IdP and + write back to the local ``users`` table. + + Implementations that have nothing to sync should still define + this method (returning ``None``) so the contract is uniform. + """ + ... + + async def revoke_user(self, user_id: str) -> None: + """Disable a user account (e.g. on termination or lock-out). + + - :class:`LocalAuthProvider`: ``UPDATE users SET is_active = 0`` + - OIDC adapter (future): call the IdP's disable API + + The admin endpoint that calls this does NOT also need to + revoke the user's active sessions — that is the + :class:`SessionService`'s job, called separately. + """ + ... diff --git a/src/agentkit/server/auth/providers/exceptions.py b/src/agentkit/server/auth/providers/exceptions.py new file mode 100644 index 0000000..bdc0972 --- /dev/null +++ b/src/agentkit/server/auth/providers/exceptions.py @@ -0,0 +1,26 @@ +"""Exceptions raised by AuthProvider implementations. + +These are caught by the route layer and translated to HTTP responses +(401 for :class:`InvalidCredentials`, 501 for :class:`ProviderNotImplemented`). +""" + + +class AuthProviderError(Exception): + """Base class for all auth provider errors.""" + + +class InvalidCredentials(AuthProviderError): + """Raised when username / password is wrong, or the user is inactive. + + Translates to HTTP 401. The error message MUST NOT leak which of + "user not found" vs "wrong password" vs "user inactive" failed — + that is a username enumeration risk. + """ + + +class ProviderNotImplemented(AuthProviderError): + """Raised when a configured provider is not yet implemented. + + Translates to HTTP 501. Used by :class:`StubOIDCProvider` and any + future adapter that is registered but has no real implementation. + """ diff --git a/src/agentkit/server/auth/providers/local.py b/src/agentkit/server/auth/providers/local.py new file mode 100644 index 0000000..de1f8d5 --- /dev/null +++ b/src/agentkit/server/auth/providers/local.py @@ -0,0 +1,163 @@ +"""Local authentication provider — SQLite + bcrypt. + +The default :class:`AuthProvider` implementation. Authenticates users +against the local ``users`` table in the auth SQLite database using +the bcrypt cost=12 password hash (see :mod:`agentkit.server.auth.password`). + +This is a behavioral equivalent of the password-verification code that +previously lived inline in ``routes/auth.py`` — moved here so the route +layer can call a single :meth:`authenticate` method regardless of which +backend is configured. + +Future-IdP note +--------------- +When the organization moves to OIDC / SAML / LDAP, this class does not +need to be deleted. It can continue to serve as a "local emergency +account" provider, configurable side-by-side with the IdP provider via +a future composite / multi-provider setup. For now it is the only +implementation. +""" + +from __future__ import annotations + +import logging +import os +from pathlib import Path + +import aiosqlite + +from ..models import DEFAULT_AUTH_DB_PATH +from ..password import verify_password +from .user import User + +logger = logging.getLogger(__name__) + + +def _resolve_db_path() -> Path: + """Resolve the auth DB path with runtime env-var priority. + + The :data:`models.DEFAULT_AUTH_DB_PATH` constant is captured at + module-import time and therefore cannot see test-time env mutations. + Re-reading ``AGENTKIT_AUTH_DB`` here keeps the provider + "test-friendly" (tests can ``monkeypatch.setenv`` before constructing + the provider) without giving up the default path when no env is set. + """ + env = os.environ.get("AGENTKIT_AUTH_DB") + if env: + return Path(env) + return DEFAULT_AUTH_DB_PATH + + +class LocalAuthProvider: + """AuthProvider backed by the local SQLite ``users`` table + bcrypt. + + Args: + db_path: Path to the auth DB. Defaults to the value of the + ``AGENTKIT_AUTH_DB`` env var, falling back to + :data:`agentkit.server.auth.models.DEFAULT_AUTH_DB_PATH`. + Each operation opens a short-lived aiosqlite connection; + the existing route layer follows the same pattern, so no + connection pooling is introduced here. If a future + deployment needs pooling, swap in a ``db_factory: Callable`` + here without changing the protocol. + """ + + name = "local" + + def __init__(self, db_path: str | Path | None = None) -> None: + self._db_path = Path(db_path) if db_path is not None else _resolve_db_path() + + @property + def db_path(self) -> Path: + return self._db_path + + async def authenticate(self, *, username: str, password: str) -> User: + """Verify the username + password against the local users table. + + Raises :class:`InvalidCredentials` on every failure mode + (unknown user, wrong password, inactive user) with the same + error message — preventing username enumeration via error + inspection. Constant-time-equivalent behavior is also + ensured by always running a real bcrypt computation + (against a dummy hash) when the user does not exist, + matching the timing of the "user exists, wrong password" path. + """ + from .exceptions import InvalidCredentials # local import to avoid cycle at module load + + async with aiosqlite.connect(str(self._db_path)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT id, username, email, password_hash, role, is_active, " + "is_terminal_authorized, is_server_terminal_authorized, " + "created_at, updated_at, last_login_at, created_by " + "FROM users WHERE username = ?", + (username,), + ) + row = await cursor.fetchone() + + if row is None or not bool(row["is_active"]): + # Run a real bcrypt verification against a valid-format dummy + # hash so the response time matches the "user exists, wrong + # password" path (~250ms). Prevents username enumeration via + # timing. The dummy hash is invalid (won't match any password) + # but has the right shape so bcrypt.checkpw doesn't short-circuit. + _DUMMY_BCRYPT_HASH = "$2b$12$abcdefghijklmnopqrstuuABCDEFGHIJKLMNOPQRSTUVWXYZ0123" + verify_password(password, _DUMMY_BCRYPT_HASH) + raise InvalidCredentials("invalid username or password") + + if not verify_password(password, row["password_hash"]): + raise InvalidCredentials("invalid username or password") + + return _row_to_user(row) + + async def get_user_by_id(self, user_id: str) -> User | None: + """Look up a user by id. Returns ``None`` if not found or inactive.""" + async with aiosqlite.connect(str(self._db_path)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT id, username, email, password_hash, role, is_active, " + "is_terminal_authorized, is_server_terminal_authorized, " + "created_at, updated_at, last_login_at, created_by " + "FROM users WHERE id = ? AND is_active = 1", + (user_id,), + ) + row = await cursor.fetchone() + return _row_to_user(row) if row else None + + async def sync_user_attributes(self, user_id: str) -> None: + """No-op: local provider has no upstream source of truth to sync from.""" + return None + + async def revoke_user(self, user_id: str) -> None: + """Disable a user account (``is_active = 0``).""" + async with aiosqlite.connect(str(self._db_path)) as db: + await db.execute( + "UPDATE users SET is_active = 0, updated_at = ? WHERE id = ?", + (_now_iso(), user_id), + ) + await db.commit() + logger.info(f"Revoked user {user_id} via LocalAuthProvider") + + +def _row_to_user(row: aiosqlite.Row) -> User: + """Convert a ``users`` row to a :class:`User` value object.""" + return User( + id=row["id"], + username=row["username"], + email=row["email"], + role=row["role"], + is_active=bool(row["is_active"]), + is_terminal_authorized=bool(row["is_terminal_authorized"]), + is_server_terminal_authorized=bool(row["is_server_terminal_authorized"]), + created_at=row["created_at"], + updated_at=row["updated_at"], + last_login_at=row["last_login_at"], + created_by=row["created_by"], + ) + + +def _now_iso() -> str: + """Return current UTC time as ISO 8601 string.""" + from datetime import datetime, timezone + + return datetime.now(timezone.utc).isoformat() diff --git a/src/agentkit/server/auth/providers/oidc_stub.py b/src/agentkit/server/auth/providers/oidc_stub.py new file mode 100644 index 0000000..1f64b28 --- /dev/null +++ b/src/agentkit/server/auth/providers/oidc_stub.py @@ -0,0 +1,83 @@ +"""Stub OIDC provider — interface contract placeholder. + +This is a deliberately unimplemented :class:`AuthProvider` whose only +job is to fail loudly when ``auth.provider: oidc-stub`` is configured +without the real OIDC integration being in place. It exists so that +the configuration switch and the dependency-injection wiring can be +exercised end-to-end before the IdP integration work is scheduled. + +Future OIDC integration checklist +--------------------------------- +When the real OIDC integration lands, replace this file with +:file:`auth/providers/oidc.py` implementing: + +- [ ] :meth:`authenticate` — accept (username, password)? OR redirect-flow? + OIDC is fundamentally a redirect flow, NOT a password POST. The + right answer is probably a separate ``/auth/oauth/{provider}/redirect`` + and ``/auth/oauth/{provider}/callback`` route pair that bypasses + the password-based login. The :class:`AuthProvider` interface is + only for backends that accept username + password; an OIDC + adapter may need to extend the protocol or expose a separate + callback-handler hook. +- [ ] :meth:`get_user_by_id` — query local cache (auto-provisioned + OIDC users live in the local ``users`` table after first login). +- [ ] :meth:`sync_user_attributes` — pull latest profile from the IdP + on each successful login (department / title / email). +- [ ] :meth:`revoke_user` — call the IdP's account-disable API (e.g. + Keycloak's ``PUT /admin/realms/{realm}/users/{id}/disable``). +- [ ] State cache for OAuth ``state`` parameter (Redis, TTL 5 min). +- [ ] First-login user auto-provisioning policy: just-in-time + creation / reject / invite-only. + +Until then, configuring ``auth.provider: oidc-stub`` and starting +the server lets a smoke test confirm: + +1. The DI factory resolves the right class. +2. The route layer calls :meth:`authenticate` and surfaces 501. +3. ``auth_sessions.auth_provider='oidc-stub'`` would be written + for any session that *did* get created (none, in this stub). +""" + +from __future__ import annotations + +from .exceptions import ProviderNotImplemented +from .user import User + + +class StubOIDCProvider: + """Unimplemented OIDC adapter. All methods raise :class:`ProviderNotImplemented`. + + The class satisfies the :class:`AuthProvider` Protocol structurally + (it has ``name`` and the four methods), but every method body is + a deliberate failure. This is on purpose: configuring + ``auth.provider: oidc-stub`` and then successfully logging in would + be a security bug. + """ + + name = "oidc-stub" + + async def authenticate(self, *, username: str, password: str) -> User: + raise ProviderNotImplemented( + "OIDC authentication is not implemented yet. " + "Use auth.provider: local in agentkit.yaml, or implement " + "auth/providers/oidc.py (see the checklist in this module's " + "docstring)." + ) + + async def get_user_by_id(self, user_id: str) -> User | None: + raise ProviderNotImplemented( + "StubOIDCProvider.get_user_by_id is not implemented. " + "See auth/providers/oidc_stub.py for the future OIDC checklist." + ) + + async def sync_user_attributes(self, user_id: str) -> None: + raise ProviderNotImplemented( + "StubOIDCProvider.sync_user_attributes is not implemented. " + "See auth/providers/oidc_stub.py for the future OIDC checklist." + ) + + async def revoke_user(self, user_id: str) -> None: + raise ProviderNotImplemented( + "StubOIDCProvider.revoke_user is not implemented. " + "See auth/providers/oidc_stub.py for the future OIDC checklist." + ) diff --git a/src/agentkit/server/auth/providers/user.py b/src/agentkit/server/auth/providers/user.py new file mode 100644 index 0000000..d5227d0 --- /dev/null +++ b/src/agentkit/server/auth/providers/user.py @@ -0,0 +1,42 @@ +"""Provider-agnostic user data model. + +This is the value object the :class:`AuthProvider` Protocol returns +from :meth:`AuthProvider.authenticate` and :meth:`AuthProvider.get_user_by_id`. +It is intentionally NOT the SQLAlchemy ``UserModel`` ORM class so that +future IdP adapters (OIDC / SAML / LDAP) can construct one without +touching a DB. + +The route layer converts ``User`` → ``UserResponse`` (the public +API payload) at the boundary; ``User`` itself stays an internal +type. +""" + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, EmailStr + + +class User(BaseModel): + """A user as known to the auth subsystem. + + Distinct from the SQLAlchemy ``UserModel`` ORM class because: + - The auth provider may not have a local row at all (e.g. an + OIDC user logged in via SSO who hasn't been provisioned yet + — that's a future concern but the data model must allow it). + - IdP-sourced attributes (department / title) are not in the + SQLAlchemy model. + """ + + model_config = ConfigDict(extra="forbid") + + id: str + username: str + email: EmailStr + role: str = "member" + is_active: bool = True + is_terminal_authorized: bool = False + is_server_terminal_authorized: bool = False + created_at: str + updated_at: str + last_login_at: str | None = None + created_by: str | None = None diff --git a/tests/unit/auth/__init__.py b/tests/unit/auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/auth/providers/__init__.py b/tests/unit/auth/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/auth/providers/test_base.py b/tests/unit/auth/providers/test_base.py new file mode 100644 index 0000000..b32f3a9 --- /dev/null +++ b/tests/unit/auth/providers/test_base.py @@ -0,0 +1,112 @@ +"""Tests for the AuthProvider protocol and the DI factory (U11).""" + +from __future__ import annotations + +import pytest + +from agentkit.server.auth.providers import ( + AuthProvider, + LocalAuthProvider, + StubOIDCProvider, + get_auth_provider, + reset_auth_provider, +) +from agentkit.server.auth.providers.base import AuthProvider as AuthProviderFromBase +from agentkit.server.auth.providers.exceptions import ( + AuthProviderError, + InvalidCredentials, + ProviderNotImplemented, +) + + +class TestProtocolConformance: + """The runtime_checkable Protocol accepts both real implementations.""" + + def test_local_passes_isinstance_check(self): + provider = LocalAuthProvider() + assert isinstance(provider, AuthProvider) + assert isinstance(provider, AuthProviderFromBase) + + def test_stub_passes_isinstance_check(self): + provider = StubOIDCProvider() + assert isinstance(provider, AuthProvider) + assert isinstance(provider, AuthProviderFromBase) + + +class TestProviderNames: + def test_local_provider_name(self): + assert LocalAuthProvider.name == "local" + + def test_stub_oidc_provider_name(self): + assert StubOIDCProvider.name == "oidc-stub" + + +class TestDiFactory: + """``get_auth_provider`` is a memoized singleton, env-driven.""" + + def setup_method(self): + """Reset cache and env before each test.""" + reset_auth_provider() + self._saved_env = {} + for key in ("AGENTKIT_AUTH_PROVIDER",): + self._saved_env[key] = __import__("os").environ.pop(key, None) + + def teardown_method(self): + for key, value in self._saved_env.items(): + if value is not None: + __import__("os").environ[key] = value + reset_auth_provider() + + def test_default_provider_is_local(self): + provider = get_auth_provider() + assert isinstance(provider, LocalAuthProvider) + assert provider.name == "local" + + def test_oidc_stub_provider(self, monkeypatch): + monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "oidc-stub") + reset_auth_provider() + provider = get_auth_provider() + assert isinstance(provider, StubOIDCProvider) + assert provider.name == "oidc-stub" + + def test_unknown_provider_raises_value_error(self, monkeypatch): + monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "ldap-unknown") + reset_auth_provider() + with pytest.raises(ValueError, match="unknown auth provider"): + get_auth_provider() + + def test_factory_is_memoized(self): + first = get_auth_provider() + second = get_auth_provider() + assert first is second + + def test_reset_clears_cache(self, monkeypatch): + first = get_auth_provider() + reset_auth_provider() + monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "oidc-stub") + second = get_auth_provider() + assert first is not second + assert isinstance(second, StubOIDCProvider) + + +class TestExceptionHierarchy: + def test_invalid_credentials_inherits_auth_provider_error(self): + assert issubclass(InvalidCredentials, AuthProviderError) + assert issubclass(InvalidCredentials, Exception) + + def test_provider_not_implemented_inherits_auth_provider_error(self): + assert issubclass(ProviderNotImplemented, AuthProviderError) + assert issubclass(ProviderNotImplemented, Exception) + + def test_invalid_credentials_can_be_raised_and_caught(self): + with pytest.raises(InvalidCredentials): + raise InvalidCredentials("test message") + + def test_invalid_credentials_caught_as_base(self): + """Route layer catches AuthProviderError to handle all provider errors.""" + with pytest.raises(AuthProviderError): + raise InvalidCredentials("test") + + def test_provider_not_implemented_caught_as_base(self): + with pytest.raises(AuthProviderError): + raise ProviderNotImplemented("test") diff --git a/tests/unit/auth/providers/test_local.py b/tests/unit/auth/providers/test_local.py new file mode 100644 index 0000000..b726913 --- /dev/null +++ b/tests/unit/auth/providers/test_local.py @@ -0,0 +1,254 @@ +"""Tests for LocalAuthProvider (U11 — concrete Local implementation).""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from pathlib import Path + +import aiosqlite +import pytest + +from agentkit.server.auth.models import init_auth_db +from agentkit.server.auth.password import hash_password +from agentkit.server.auth.providers import LocalAuthProvider +from agentkit.server.auth.providers.exceptions import InvalidCredentials +from agentkit.server.auth.providers.user import User + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +async def auth_db_with_users(tmp_path: Path) -> dict: + """Create a fresh auth DB with two users: one active, one inactive. + + Returns a dict with user info + the db path. + """ + db_path = tmp_path / "auth.db" + await init_auth_db(db_path) + + now_iso = datetime.now(timezone.utc).isoformat() + active_id = str(uuid.uuid4()) + inactive_id = str(uuid.uuid4()) + active_pw = "correct-horse-battery-staple" + inactive_pw = "disabled-user-pw" + + async with aiosqlite.connect(str(db_path)) as db: + await db.execute( + "INSERT INTO users " + "(id, username, email, password_hash, role, is_active, " + " is_terminal_authorized, is_server_terminal_authorized, " + " created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + active_id, + "alice", + "alice@example.com", + hash_password(active_pw), + "member", + 1, + 0, + 0, + now_iso, + now_iso, + ), + ) + await db.execute( + "INSERT INTO users " + "(id, username, email, password_hash, role, is_active, " + " is_terminal_authorized, is_server_terminal_authorized, " + " created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + inactive_id, + "bob_inactive", + "bob@example.com", + hash_password(inactive_pw), + "member", + 0, + 0, + 0, + now_iso, + now_iso, + ), + ) + await db.commit() + + return { + "db_path": db_path, + "active": { + "id": active_id, + "username": "alice", + "password": active_pw, + "email": "alice@example.com", + "role": "member", + }, + "inactive": { + "id": inactive_id, + "username": "bob_inactive", + "password": inactive_pw, + }, + } + + +@pytest.fixture +def provider(auth_db_with_users: dict) -> LocalAuthProvider: + return LocalAuthProvider(db_path=auth_db_with_users["db_path"]) + + +# --------------------------------------------------------------------------- +# authenticate +# --------------------------------------------------------------------------- + + +class TestAuthenticate: + async def test_valid_credentials_returns_user( + self, provider: LocalAuthProvider, auth_db_with_users: dict + ): + user = await provider.authenticate( + username=auth_db_with_users["active"]["username"], + password=auth_db_with_users["active"]["password"], + ) + assert isinstance(user, User) + assert user.id == auth_db_with_users["active"]["id"] + assert user.username == "alice" + assert user.email == "alice@example.com" + assert user.role == "member" + assert user.is_active is True + + async def test_wrong_password_raises_invalid_credentials( + self, provider: LocalAuthProvider, auth_db_with_users: dict + ): + with pytest.raises(InvalidCredentials): + await provider.authenticate( + username=auth_db_with_users["active"]["username"], + password="definitely-not-the-password", + ) + + async def test_unknown_user_raises_invalid_credentials(self, provider: LocalAuthProvider): + with pytest.raises(InvalidCredentials): + await provider.authenticate(username="nonexistent", password="anything") + + async def test_inactive_user_raises_invalid_credentials( + self, provider: LocalAuthProvider, auth_db_with_users: dict + ): + with pytest.raises(InvalidCredentials): + await provider.authenticate( + username=auth_db_with_users["inactive"]["username"], + password=auth_db_with_users["inactive"]["password"], + ) + + async def test_error_message_does_not_leak_username_existence( + self, provider: LocalAuthProvider, auth_db_with_users: dict + ): + """Wrong-password and unknown-user errors must have the same message.""" + try: + await provider.authenticate( + username=auth_db_with_users["active"]["username"], + password="wrong", + ) + except InvalidCredentials as e1: + wrong_pw_msg = str(e1) + try: + await provider.authenticate(username="nobody-here", password="x") + except InvalidCredentials as e2: + unknown_msg = str(e2) + assert wrong_pw_msg == unknown_msg + + +# --------------------------------------------------------------------------- +# get_user_by_id +# --------------------------------------------------------------------------- + + +class TestGetUserById: + async def test_returns_user_for_active_id( + self, provider: LocalAuthProvider, auth_db_with_users: dict + ): + user = await provider.get_user_by_id(auth_db_with_users["active"]["id"]) + assert user is not None + assert user.id == auth_db_with_users["active"]["id"] + assert user.username == "alice" + + async def test_returns_none_for_inactive_user( + self, provider: LocalAuthProvider, auth_db_with_users: dict + ): + user = await provider.get_user_by_id(auth_db_with_users["inactive"]["id"]) + assert user is None + + async def test_returns_none_for_unknown_id(self, provider: LocalAuthProvider): + user = await provider.get_user_by_id(str(uuid.uuid4())) + assert user is None + + +# --------------------------------------------------------------------------- +# sync_user_attributes +# --------------------------------------------------------------------------- + + +class TestSyncUserAttributes: + async def test_is_noop(self, provider: LocalAuthProvider, auth_db_with_users: dict): + """Local provider has no upstream to sync from — must succeed with no effect.""" + result = await provider.sync_user_attributes(auth_db_with_users["active"]["id"]) + assert result is None + + async def test_noop_does_not_throw_for_unknown_user(self, provider: LocalAuthProvider): + """sync_user_attributes is fire-and-forget — no existence check.""" + # Should NOT raise even though the id doesn't exist + await provider.sync_user_attributes(str(uuid.uuid4())) + + +# --------------------------------------------------------------------------- +# revoke_user +# --------------------------------------------------------------------------- + + +class TestRevokeUser: + async def test_sets_is_active_to_zero( + self, provider: LocalAuthProvider, auth_db_with_users: dict, tmp_path: Path + ): + await provider.revoke_user(auth_db_with_users["active"]["id"]) + async with aiosqlite.connect(str(auth_db_with_users["db_path"])) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT is_active FROM users WHERE id = ?", + (auth_db_with_users["active"]["id"],), + ) + row = await cursor.fetchone() + assert bool(row["is_active"]) is False + + async def test_revoked_user_can_no_longer_authenticate( + self, provider: LocalAuthProvider, auth_db_with_users: dict + ): + await provider.revoke_user(auth_db_with_users["active"]["id"]) + with pytest.raises(InvalidCredentials): + await provider.authenticate( + username=auth_db_with_users["active"]["username"], + password=auth_db_with_users["active"]["password"], + ) + + async def test_revoke_unknown_user_does_not_raise(self, provider: LocalAuthProvider): + """``UPDATE ... WHERE id = ?`` with no match is a no-op, not an error.""" + await provider.revoke_user(str(uuid.uuid4())) + + +# --------------------------------------------------------------------------- +# default db_path +# --------------------------------------------------------------------------- + + +class TestDefaultDbPath: + def test_default_db_path_uses_models_default(self, tmp_path: Path, monkeypatch): + """If no db_path is given, the provider should use the module default. + + The default may resolve to a path that does not exist on a test + machine — we only assert that the property returns a Path, not + that the file exists. + """ + monkeypatch.setenv("AGENTKIT_AUTH_DB", str(tmp_path / "default.db")) + provider = LocalAuthProvider() + assert isinstance(provider.db_path, Path) + assert str(provider.db_path) == str(tmp_path / "default.db") diff --git a/tests/unit/auth/providers/test_oidc_stub.py b/tests/unit/auth/providers/test_oidc_stub.py new file mode 100644 index 0000000..1fd70c1 --- /dev/null +++ b/tests/unit/auth/providers/test_oidc_stub.py @@ -0,0 +1,69 @@ +"""Tests for StubOIDCProvider (U11 — interface placeholder).""" + +from __future__ import annotations + +import pytest + +from agentkit.server.auth.providers import StubOIDCProvider +from agentkit.server.auth.providers.exceptions import ( + AuthProviderError, + ProviderNotImplemented, +) + + +class TestStubAuthenticate: + async def test_raises_provider_not_implemented(self): + provider = StubOIDCProvider() + with pytest.raises(ProviderNotImplemented): + await provider.authenticate(username="alice", password="hunter2") + + async def test_error_message_mentions_oidc(self): + """The error message must guide the operator to the right next step.""" + provider = StubOIDCProvider() + with pytest.raises(ProviderNotImplemented) as exc: + await provider.authenticate(username="alice", password="x") + msg = str(exc.value) + assert "OIDC" in msg + assert "not implemented" in msg.lower() + + +class TestStubGetUserById: + async def test_raises_provider_not_implemented(self): + provider = StubOIDCProvider() + with pytest.raises(ProviderNotImplemented): + await provider.get_user_by_id("any-id") + + +class TestStubSyncUserAttributes: + async def test_raises_provider_not_implemented(self): + provider = StubOIDCProvider() + with pytest.raises(ProviderNotImplemented): + await provider.sync_user_attributes("any-id") + + +class TestStubRevokeUser: + async def test_raises_provider_not_implemented(self): + provider = StubOIDCProvider() + with pytest.raises(ProviderNotImplemented): + await provider.revoke_user("any-id") + + +class TestAllErrorsAreAuthProviderError: + """Every stub method's error should be catchable as AuthProviderError.""" + + @pytest.mark.parametrize( + "method,call_args", + [ + ("authenticate", ()), + ("get_user_by_id", ("id",)), + ("sync_user_attributes", ("id",)), + ("revoke_user", ("id",)), + ], + ) + async def test_catchable_as_base(self, method: str, call_args: tuple): + provider = StubOIDCProvider() + with pytest.raises(AuthProviderError): + if method == "authenticate": + await provider.authenticate(username="u", password="p") + else: + await getattr(provider, method)(*call_args) diff --git a/tests/unit/auth/test_models.py b/tests/unit/auth/test_models.py new file mode 100644 index 0000000..3b399ff --- /dev/null +++ b/tests/unit/auth/test_models.py @@ -0,0 +1,653 @@ +"""Unit tests for auth.models (U1 — V2 schema + backfill). + +Covers: +- ``auth_sessions`` table creation (columns + indexes) +- ``auth_meta`` table creation +- ``_SCHEMA_VERSION`` constant value +- ``auth_session_row_to_dict`` field-by-field conversion +- ``_backfill_user_sessions`` one-time migration (V1 → V2) +- ``init_auth_db`` idempotency (subsequent runs are no-ops) +- ``auth_provider`` column default value +- Index presence (verified via ``PRAGMA index_list``) +""" + +from __future__ import annotations + +import json +import uuid +from datetime import datetime, timezone +from pathlib import Path + +import aiosqlite +import pytest + +from agentkit.server.auth.models import ( + AuthSessionModel, + UserSessionModel, + _SCHEMA_VERSION, + auth_session_row_to_dict, + init_auth_db, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +async def fresh_db(tmp_path: Path) -> Path: + """A brand-new auth DB on a fresh path (no data).""" + db_path = tmp_path / "auth.db" + await init_auth_db(db_path) + return db_path + + +async def _insert_user(db: aiosqlite.Connection, *, user_id: str | None = None) -> str: + """Insert a minimal user row and return its id.""" + user_id = user_id or str(uuid.uuid4()) + now_iso = datetime.now(timezone.utc).isoformat() + await db.execute( + "INSERT INTO users " + "(id, username, email, password_hash, role, is_active, " + " is_terminal_authorized, is_server_terminal_authorized, " + " created_at, updated_at) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + user_id, + f"user-{user_id[:8]}", + f"{user_id[:8]}@example.com", + "$2b$12$placeholder.hash.placeholder.hash.placeholder.hash", + "member", + 1, + 0, + 0, + now_iso, + now_iso, + ), + ) + return user_id + + +async def _insert_user_session( + db: aiosqlite.Connection, + *, + user_id: str, + session_id: str | None = None, + refresh_hash: str | None = None, + device_info: str = "{}", + revoked: bool = False, +) -> str: + """Insert a V1 ``user_sessions`` row and return its id.""" + session_id = session_id or str(uuid.uuid4()) + refresh_hash = refresh_hash or uuid.uuid4().hex + now_iso = datetime.now(timezone.utc).isoformat() + await db.execute( + "INSERT INTO user_sessions " + "(id, user_id, refresh_token_hash, device_info, created_at, expires_at, revoked_at) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + ( + session_id, + user_id, + refresh_hash, + device_info, + now_iso, + now_iso, + now_iso if revoked else None, + ), + ) + return session_id + + +async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]: + """Return the set of index names for a table. + + Sets ``row_factory`` on the connection so we can address columns by name. + PRAGMA results in aiosqlite come back as plain tuples unless a row factory + is configured. + """ + db.row_factory = aiosqlite.Row + cursor = await db.execute(f"PRAGMA index_list({table})") + rows = await cursor.fetchall() + return {row["name"] for row in rows} + + +# --------------------------------------------------------------------------- +# _SCHEMA_VERSION +# --------------------------------------------------------------------------- + + +class TestSchemaVersion: + def test_schema_version_is_v2(self): + """The current schema version is 2 (V2 adds auth_sessions + auth_meta).""" + assert _SCHEMA_VERSION == 2 + + def test_sqlalchemy_model_table_name(self): + assert AuthSessionModel.__tablename__ == "auth_sessions" + assert UserSessionModel.__tablename__ == "user_sessions" + + +# --------------------------------------------------------------------------- +# init_auth_db: tables + indexes +# --------------------------------------------------------------------------- + + +class TestInitAuthDbTables: + async def test_creates_auth_sessions_table(self, fresh_db: Path): + async with aiosqlite.connect(str(fresh_db)) as db: + cursor = await db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='auth_sessions'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_creates_auth_meta_table(self, fresh_db: Path): + async with aiosqlite.connect(str(fresh_db)) as db: + cursor = await db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='auth_meta'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_creates_user_sessions_table_for_back_compat(self, fresh_db: Path): + async with aiosqlite.connect(str(fresh_db)) as db: + cursor = await db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='user_sessions'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_records_schema_version_in_auth_meta(self, fresh_db: Path): + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'") + row = await cursor.fetchone() + assert row is not None + assert row["value"] == str(_SCHEMA_VERSION) + + +class TestAuthSessionsIndexes: + async def test_user_id_active_index(self, fresh_db: Path): + async with aiosqlite.connect(str(fresh_db)): + pass + async with aiosqlite.connect(str(fresh_db)) as db: + indexes = await _list_index_names(db, "auth_sessions") + assert "idx_auth_sessions_user_id_active" in indexes + + async def test_expires_at_index(self, fresh_db: Path): + async with aiosqlite.connect(str(fresh_db)) as db: + indexes = await _list_index_names(db, "auth_sessions") + assert "idx_auth_sessions_expires_at" in indexes + + async def test_auth_provider_index(self, fresh_db: Path): + async with aiosqlite.connect(str(fresh_db)) as db: + indexes = await _list_index_names(db, "auth_sessions") + assert "idx_auth_sessions_auth_provider" in indexes + + async def test_refresh_token_hash_unique_index(self, fresh_db: Path): + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + # SQLite stores column-level UNIQUE constraints as PRAGMA index_list + # entries with auto-generated names like sqlite_autoindex__1. + # The PRAGMA index_info on each autoindex reports the constrained columns. + cursor = await db.execute("PRAGMA index_list(auth_sessions)") + index_entries = await cursor.fetchall() + column_names: set[str] = set() + for entry in index_entries: + col_cursor = await db.execute(f"PRAGMA index_info({entry['name']})") + col_rows = await col_cursor.fetchall() + for col_row in col_rows: + column_names.add(col_row["name"]) + assert "refresh_token_hash" in column_names + + +# --------------------------------------------------------------------------- +# auth_sessions columns +# --------------------------------------------------------------------------- + + +class TestAuthSessionsColumns: + async def test_required_columns_present(self, fresh_db: Path): + async with aiosqlite.connect(str(fresh_db)) as db: + cursor = await db.execute("PRAGMA table_info(auth_sessions)") + cols = {row[1] for row in await cursor.fetchall()} + required = { + "id", + "user_id", + "refresh_token_hash", + "device_fingerprint", + "device_label", + "ip", + "user_agent", + "auth_provider", + "created_at", + "last_active_at", + "expires_at", + "revoked", + "revoked_reason", + "previous_session_id", + } + assert required.issubset(cols) + + async def test_auth_provider_default_is_local(self, fresh_db: Path): + """Insert a row without auth_provider → column should default to 'local'.""" + user_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + sid = str(uuid.uuid4()) + await db.execute( + "INSERT INTO auth_sessions " + "(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, " + " revoked) " + "VALUES (?, ?, ?, ?, ?, ?, 0)", + ( + sid, + user_id, + "deadbeef", + "2026-01-01T00:00:00+00:00", + "2026-01-01T00:00:00+00:00", + "2026-12-31T00:00:00+00:00", + ), + ) + await db.commit() + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT auth_provider FROM auth_sessions WHERE id=?", (sid,)) + row = await cursor.fetchone() + assert row is not None + assert row["auth_provider"] == "local" + + async def test_revoked_default_is_false(self, fresh_db: Path): + user_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + sid = str(uuid.uuid4()) + await db.execute( + "INSERT INTO auth_sessions " + "(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + ( + sid, + user_id, + "beefdead", + "2026-01-01T00:00:00+00:00", + "2026-01-01T00:00:00+00:00", + "2026-12-31T00:00:00+00:00", + ), + ) + await db.commit() + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT revoked FROM auth_sessions WHERE id=?", (sid,)) + row = await cursor.fetchone() + assert bool(row["revoked"]) is False + + +# --------------------------------------------------------------------------- +# auth_session_row_to_dict +# --------------------------------------------------------------------------- + + +class TestAuthSessionRowToDict: + async def test_converts_all_fields(self, fresh_db: Path): + user_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + sid = str(uuid.uuid4()) + await db.execute( + "INSERT INTO auth_sessions " + "(id, user_id, refresh_token_hash, device_fingerprint, device_label, ip, " + " user_agent, auth_provider, created_at, last_active_at, expires_at, " + " revoked, revoked_reason, previous_session_id) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + sid, + user_id, + "abc123", + "fp-xyz", + "macOS Chrome 119", + "10.0.0.1", + "Mozilla/5.0", + "local", + "2026-01-01T00:00:00+00:00", + "2026-06-20T00:00:00+00:00", + "2027-01-01T00:00:00+00:00", + 1, + "user_terminated", + "old-sid", + ), + ) + await db.commit() + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM auth_sessions WHERE id=?", (sid,)) + row = await cursor.fetchone() + + d = auth_session_row_to_dict(row) + assert d["id"] == sid + assert d["user_id"] == user_id + assert d["device_fingerprint"] == "fp-xyz" + assert d["device_label"] == "macOS Chrome 119" + assert d["ip"] == "10.0.0.1" + assert d["user_agent"] == "Mozilla/5.0" + assert d["auth_provider"] == "local" + assert d["revoked"] is True + assert d["revoked_reason"] == "user_terminated" + assert d["previous_session_id"] == "old-sid" + + async def test_normalizes_revoked_to_bool(self, fresh_db: Path): + """DB stores 0/1; helper should return Python bool.""" + user_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + sid = str(uuid.uuid4()) + await db.execute( + "INSERT INTO auth_sessions " + "(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, " + " revoked) " + "VALUES (?, ?, ?, ?, ?, ?, 0)", + (sid, user_id, "x", "t", "t", "t"), + ) + await db.commit() + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM auth_sessions WHERE id=?", (sid,)) + row = await cursor.fetchone() + d = auth_session_row_to_dict(row) + assert isinstance(d["revoked"], bool) + assert d["revoked"] is False + + +# --------------------------------------------------------------------------- +# _backfill_user_sessions +# --------------------------------------------------------------------------- + + +class TestBackfillUserSessions: + async def test_backfills_non_revoked_v1_rows(self, fresh_db: Path): + """On a fresh V1 install, all non-revoked rows are migrated to auth_sessions.""" + user_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + await _insert_user_session( + db, + user_id=user_id, + refresh_hash="hash1", + device_info="{}", + ) + await _insert_user_session( + db, + user_id=user_id, + refresh_hash="hash2", + device_info=json.dumps( + { + "fingerprint": "mac-tauri", + "label": "macOS Tauri 1.0", + "ip": "192.168.1.10", + "user_agent": "Tauri/1.0", + } + ), + ) + await _insert_user_session( + db, + user_id=user_id, + refresh_hash="hash3", + revoked=True, + ) + await db.commit() + # Force a backfill by clearing the marker + dropping auth_sessions rows + # (simulating a V1 install upgrading to V2). + await db.execute("DELETE FROM auth_sessions") + await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'") + await db.commit() + + # Re-init to trigger the migration + await init_auth_db(fresh_db) + + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute("SELECT * FROM auth_sessions ORDER BY refresh_token_hash") + rows = await cursor.fetchall() + + # Only the 2 non-revoked rows should have been backfilled + assert len(rows) == 2 + hashes = {row["refresh_token_hash"] for row in rows} + assert hashes == {"hash1", "hash2"} + + async def test_backfill_preserves_original_id(self, fresh_db: Path): + """Backfilled rows reuse the V1 id so legacy clients match.""" + user_id = str(uuid.uuid4()) + original_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + await _insert_user_session( + db, + user_id=user_id, + session_id=original_id, + refresh_hash="hash-original", + ) + await db.execute("DELETE FROM auth_sessions") + await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'") + await db.commit() + + await init_auth_db(fresh_db) + + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT id FROM auth_sessions WHERE refresh_token_hash='hash-original'" + ) + row = await cursor.fetchone() + assert row is not None + assert row["id"] == original_id + + async def test_backfill_does_not_touch_revoked_v1_rows(self, fresh_db: Path): + user_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + await _insert_user_session( + db, + user_id=user_id, + refresh_hash="revoked-hash", + revoked=True, + ) + await db.execute("DELETE FROM auth_sessions") + await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'") + await db.commit() + + await init_auth_db(fresh_db) + + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT COUNT(*) AS c FROM auth_sessions WHERE refresh_token_hash='revoked-hash'" + ) + row = await cursor.fetchone() + assert row["c"] == 0 + + async def test_backfill_copies_device_info_fields(self, fresh_db: Path): + user_id = str(uuid.uuid4()) + device_info = json.dumps( + { + "fingerprint": "win-tauri-abc", + "label": "Windows Tauri 1.0", + "ip": "10.0.0.5", + "user_agent": "Tauri/2.0", + } + ) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + await _insert_user_session( + db, + user_id=user_id, + refresh_hash="hash-with-device", + device_info=device_info, + ) + await db.execute("DELETE FROM auth_sessions") + await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'") + await db.commit() + + await init_auth_db(fresh_db) + + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT device_fingerprint, device_label, ip, user_agent " + "FROM auth_sessions WHERE refresh_token_hash='hash-with-device'" + ) + row = await cursor.fetchone() + assert row["device_fingerprint"] == "win-tauri-abc" + assert row["device_label"] == "Windows Tauri 1.0" + assert row["ip"] == "10.0.0.5" + assert row["user_agent"] == "Tauri/2.0" + + async def test_backfill_handles_malformed_device_info(self, fresh_db: Path): + """Malformed JSON in device_info → fall back to defaults.""" + user_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + await _insert_user_session( + db, + user_id=user_id, + refresh_hash="bad-json", + device_info="not-json{", + ) + await db.execute("DELETE FROM auth_sessions") + await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'") + await db.commit() + + await init_auth_db(fresh_db) + + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT device_fingerprint, device_label FROM auth_sessions " + "WHERE refresh_token_hash='bad-json'" + ) + row = await cursor.fetchone() + assert row is not None + assert row["device_fingerprint"] == "unknown" + assert row["device_label"] == "Unknown device" + + async def test_backfill_marks_done_in_auth_meta(self, fresh_db: Path): + """After a backfill, the auth_meta marker is set so it doesn't run again.""" + user_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + await _insert_user_session(db, user_id=user_id, refresh_hash="h1") + await db.execute("DELETE FROM auth_sessions") + await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'") + await db.commit() + + await init_auth_db(fresh_db) + + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT value FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'" + ) + row = await cursor.fetchone() + assert row is not None + assert row["value"] == "done" + + async def test_backfill_is_idempotent(self, fresh_db: Path): + """Running init twice does NOT duplicate auth_sessions rows.""" + user_id = str(uuid.uuid4()) + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + for i in range(3): + await _insert_user_session(db, user_id=user_id, refresh_hash=f"hash{i}") + await db.execute("DELETE FROM auth_sessions") + await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'") + await db.commit() + + await init_auth_db(fresh_db) + await init_auth_db(fresh_db) # second run should be a no-op + + async with aiosqlite.connect(str(fresh_db)) as db: + cursor = await db.execute("SELECT COUNT(*) AS c FROM auth_sessions") + row = await cursor.fetchone() + assert row[0] == 3 + + async def test_fresh_install_marks_backfill_done_without_rows(self, fresh_db: Path): + """A fresh V2 install (no V1 rows) still marks the backfill as done.""" + # The fresh_db fixture already ran init_auth_db once. + # Re-running should be a no-op (idempotent). + await init_auth_db(fresh_db) + + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT value FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'" + ) + row = await cursor.fetchone() + assert row is not None + assert row["value"] == "done" + + async def test_backfill_preserves_expires_at(self, fresh_db: Path): + user_id = str(uuid.uuid4()) + original_exp = "2027-06-20T12:34:56+00:00" + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + await db.execute( + "INSERT INTO user_sessions " + "(id, user_id, refresh_token_hash, device_info, created_at, expires_at) " + "VALUES (?, ?, ?, ?, ?, ?)", + ( + str(uuid.uuid4()), + user_id, + "exp-hash", + "{}", + "2026-01-01T00:00:00+00:00", + original_exp, + ), + ) + await db.execute("DELETE FROM auth_sessions") + await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'") + await db.commit() + + await init_auth_db(fresh_db) + + async with aiosqlite.connect(str(fresh_db)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT expires_at FROM auth_sessions WHERE refresh_token_hash='exp-hash'" + ) + row = await cursor.fetchone() + assert row["expires_at"] == original_exp + + +# --------------------------------------------------------------------------- +# Active-session query pattern (covers the cap-count and list-active paths) +# --------------------------------------------------------------------------- + + +class TestActiveSessionQueries: + async def test_query_active_sessions_for_user(self, fresh_db: Path): + """The (user_id, revoked, expires_at) index supports the active-session query.""" + user_id = str(uuid.uuid4()) + future = "2027-12-31T00:00:00+00:00" + past = "2020-01-01T00:00:00+00:00" + async with aiosqlite.connect(str(fresh_db)) as db: + await _insert_user(db, user_id=user_id) + # 2 active, 1 expired, 1 revoked → 2 should match + for hash_, exp in [("a", future), ("b", future), ("c", past)]: + await db.execute( + "INSERT INTO auth_sessions " + "(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, " + " revoked) VALUES (?, ?, ?, ?, ?, ?, 0)", + (str(uuid.uuid4()), user_id, hash_, "t", "t", exp), + ) + await db.execute( + "INSERT INTO auth_sessions " + "(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, " + " revoked, revoked_reason) " + "VALUES (?, ?, ?, ?, ?, ?, 1, 'user_terminated')", + (str(uuid.uuid4()), user_id, "d", "t", "t", future), + ) + await db.commit() + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT refresh_token_hash FROM auth_sessions " + "WHERE user_id = ? AND revoked = 0 AND expires_at > ?", + (user_id, "2026-06-20T00:00:00+00:00"), + ) + rows = await cursor.fetchall() + hashes = {row["refresh_token_hash"] for row in rows} + assert hashes == {"a", "b"}