merge: 引入 U11 AuthProvider 抽象层到客户端持久化分支
This commit is contained in:
commit
d42c45e5ad
|
|
@ -5,10 +5,23 @@ stored as ``String(36)`` so the same schema works on both SQLite and
|
||||||
PostgreSQL without dialect-specific types.
|
PostgreSQL without dialect-specific types.
|
||||||
|
|
||||||
Use :func:`init_auth_db` to create the tables on startup.
|
Use :func:`init_auth_db` to create the tables on startup.
|
||||||
|
|
||||||
|
Schema versioning
|
||||||
|
-----------------
|
||||||
|
The :data:`_SCHEMA_VERSION` constant tracks the current auth DB schema. The
|
||||||
|
:func:`_backfill_user_sessions` one-time migration is gated on this version
|
||||||
|
to ensure idempotency. After a successful backfill the version is stored
|
||||||
|
in the ``auth_meta`` table so subsequent restarts are no-ops.
|
||||||
|
|
||||||
|
V2 additions (2026-06-20, Centralized Auth & Token Persistence):
|
||||||
|
- New ``auth_sessions`` table with full device/IP/audit metadata and an
|
||||||
|
``auth_provider`` column for future IdP integration traceability.
|
||||||
|
- New ``auth_meta`` table for storing schema version + migration state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
|
@ -88,10 +101,16 @@ class UserApiKeyModel(Base):
|
||||||
|
|
||||||
|
|
||||||
class UserSessionModel(Base):
|
class UserSessionModel(Base):
|
||||||
"""Refresh-token session record.
|
"""Refresh-token session record (V1, deprecated — see :class:`AuthSessionModel`).
|
||||||
|
|
||||||
Stores the SHA-256 hash of the refresh token (never the plaintext).
|
Stores the SHA-256 hash of the refresh token (never the plaintext).
|
||||||
``revoked_at`` is set on logout / forced revocation.
|
``revoked_at`` is set on logout / forced revocation.
|
||||||
|
|
||||||
|
.. deprecated::
|
||||||
|
Kept for one minor version (per U10 back-compat shim) so legacy
|
||||||
|
clients holding JWTs without ``sid`` claim can still validate.
|
||||||
|
New code should use :class:`AuthSessionModel` (table ``auth_sessions``)
|
||||||
|
which carries full device/IP/audit metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "user_sessions"
|
__tablename__ = "user_sessions"
|
||||||
|
|
@ -107,6 +126,50 @@ class UserSessionModel(Base):
|
||||||
revoked_at: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
revoked_at: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthSessionModel(Base):
|
||||||
|
"""Server-side session record (V2, the primary session table going forward).
|
||||||
|
|
||||||
|
Each row corresponds to a single refresh-token issuance. The full JWT
|
||||||
|
session id (``sid`` claim) is the row's ``id`` (UUID string).
|
||||||
|
|
||||||
|
V2 fields (vs V1 ``user_sessions``):
|
||||||
|
- ``device_fingerprint`` / ``device_label``: surfaces "which device is
|
||||||
|
this session" in the admin UI.
|
||||||
|
- ``ip`` / ``user_agent``: audit trail.
|
||||||
|
- ``last_active_at``: updated on every successful refresh, used to
|
||||||
|
display "last seen" in the sessions list.
|
||||||
|
- ``revoked`` (0/1) + ``revoked_reason``: explicit revoked-state machine
|
||||||
|
with machine-readable reasons (``user_terminated``, ``password_changed``,
|
||||||
|
``admin_revoked``, ``reuse_detected``, ``session_cap_eviction``).
|
||||||
|
- ``previous_session_id``: back-pointer to the previous session id,
|
||||||
|
written on refresh rotation, for audit trail.
|
||||||
|
- ``auth_provider``: the AuthProvider that issued this session
|
||||||
|
(``local`` / ``oidc-stub`` / future ``oidc-keycloak`` / ``saml`` / ``ldap``).
|
||||||
|
Enables admin "list sessions by provider" queries and audit traceability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "auth_sessions"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||||
|
user_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True)
|
||||||
|
refresh_token_hash: Mapped[str] = mapped_column(
|
||||||
|
String(64), unique=True, nullable=False, index=True
|
||||||
|
)
|
||||||
|
device_fingerprint: Mapped[str] = mapped_column(String(128), nullable=False, default="unknown")
|
||||||
|
device_label: Mapped[str] = mapped_column(String(256), nullable=False, default="Unknown device")
|
||||||
|
ip: Mapped[str] = mapped_column(String(64), nullable=False, default="")
|
||||||
|
user_agent: Mapped[str] = mapped_column(String(512), nullable=False, default="")
|
||||||
|
auth_provider: Mapped[str] = mapped_column(
|
||||||
|
String(32), nullable=False, default="local", index=True
|
||||||
|
)
|
||||||
|
created_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
|
||||||
|
last_active_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
|
||||||
|
expires_at: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
revoked: Mapped[bool] = mapped_column(default=False, nullable=False, index=True)
|
||||||
|
revoked_reason: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||||
|
previous_session_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
class TerminalWhitelistUserModel(Base):
|
class TerminalWhitelistUserModel(Base):
|
||||||
"""Per-user terminal command whitelist.
|
"""Per-user terminal command whitelist.
|
||||||
|
|
||||||
|
|
@ -244,6 +307,44 @@ CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_user_sessions_refresh_token_hash
|
CREATE INDEX IF NOT EXISTS idx_user_sessions_refresh_token_hash
|
||||||
ON user_sessions(refresh_token_hash);
|
ON user_sessions(refresh_token_hash);
|
||||||
|
|
||||||
|
-- V2: auth_sessions replaces user_sessions as the primary session table.
|
||||||
|
-- Stores device/IP/audit metadata and auth_provider for IdP traceability.
|
||||||
|
-- Per-row indexes are sized for the most common access patterns:
|
||||||
|
-- * (user_id, revoked, expires_at) — cap-count, list-active, refresh-validate
|
||||||
|
-- * expires_at — cleanup sweeps
|
||||||
|
-- * refresh_token_hash — uniqueness + fast lookup on /auth/refresh
|
||||||
|
-- * auth_provider — admin "list sessions by provider" queries
|
||||||
|
CREATE TABLE IF NOT EXISTS auth_sessions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
refresh_token_hash TEXT NOT NULL UNIQUE,
|
||||||
|
device_fingerprint TEXT NOT NULL DEFAULT 'unknown',
|
||||||
|
device_label TEXT NOT NULL DEFAULT 'Unknown device',
|
||||||
|
ip TEXT NOT NULL DEFAULT '',
|
||||||
|
user_agent TEXT NOT NULL DEFAULT '',
|
||||||
|
auth_provider TEXT NOT NULL DEFAULT 'local',
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
last_active_at TEXT NOT NULL,
|
||||||
|
expires_at TEXT NOT NULL,
|
||||||
|
revoked INTEGER NOT NULL DEFAULT 0,
|
||||||
|
revoked_reason TEXT,
|
||||||
|
previous_session_id TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_id_active
|
||||||
|
ON auth_sessions(user_id, revoked, expires_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_auth_sessions_expires_at
|
||||||
|
ON auth_sessions(expires_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_auth_sessions_auth_provider
|
||||||
|
ON auth_sessions(auth_provider);
|
||||||
|
|
||||||
|
-- V2: auth_meta stores schema version + migration completion markers.
|
||||||
|
-- Used by init_auth_db to gate one-time migrations (idempotency).
|
||||||
|
CREATE TABLE IF NOT EXISTS auth_meta (
|
||||||
|
key TEXT PRIMARY KEY,
|
||||||
|
value TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS terminal_whitelist_user (
|
CREATE TABLE IF NOT EXISTS terminal_whitelist_user (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
|
|
@ -307,12 +408,137 @@ CREATE INDEX IF NOT EXISTS idx_terminal_approvals_status
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Schema versioning + one-time migrations
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Current auth DB schema version. Bump this when adding new tables/columns
|
||||||
|
# that require data backfill or migration. The :func:`init_auth_db` function
|
||||||
|
# uses this together with the ``auth_meta.schema_version`` row to decide
|
||||||
|
# which migrations to run.
|
||||||
|
_SCHEMA_VERSION = 2
|
||||||
|
|
||||||
|
_META_SCHEMA_VERSION_KEY = "schema_version"
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_meta_value(db: aiosqlite.Connection, key: str) -> str | None:
|
||||||
|
"""Read a key from the ``auth_meta`` table. Returns ``None`` if missing."""
|
||||||
|
cursor = await db.execute("SELECT value FROM auth_meta WHERE key = ?", (key,))
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return row["value"] if row else None
|
||||||
|
|
||||||
|
|
||||||
|
async def _set_meta_value(db: aiosqlite.Connection, key: str, value: str) -> None:
|
||||||
|
"""Upsert a key in the ``auth_meta`` table."""
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO auth_meta (key, value, updated_at) VALUES (?, ?, ?) "
|
||||||
|
"ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
|
||||||
|
(key, value, _now_iso()),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _backfill_user_sessions(db: aiosqlite.Connection) -> int:
|
||||||
|
"""One-time backfill from ``user_sessions`` (V1) to ``auth_sessions`` (V2).
|
||||||
|
|
||||||
|
Runs only when ``auth_sessions`` is empty AND ``user_sessions`` has rows.
|
||||||
|
Idempotent: subsequent restarts are no-ops because we mark the backfill
|
||||||
|
as completed in ``auth_meta``.
|
||||||
|
|
||||||
|
For each non-revoked V1 session, copies:
|
||||||
|
- id (reused — see note below)
|
||||||
|
- user_id
|
||||||
|
- refresh_token_hash
|
||||||
|
- device_fingerprint / device_label / ip / user_agent from the legacy
|
||||||
|
``device_info`` JSON blob (best-effort)
|
||||||
|
- created_at / expires_at
|
||||||
|
- last_active_at defaults to created_at
|
||||||
|
- revoked=0 (already filtered)
|
||||||
|
- revoked_reason=None
|
||||||
|
- auth_provider='local' (default; backfilled rows are pre-IdP)
|
||||||
|
|
||||||
|
The original ``id`` is preserved so that legacy clients holding the
|
||||||
|
old refresh_token_hash still match a row in the new table — this is
|
||||||
|
what the back-compat path in U10 (``get_current_user`` for legacy
|
||||||
|
JWTs) relies on.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of rows backfilled (0 if already done or nothing to backfill).
|
||||||
|
"""
|
||||||
|
# Idempotency: check the marker
|
||||||
|
if await _get_meta_value(db, "backfill_user_sessions_v1_to_v2") == "done":
|
||||||
|
return 0
|
||||||
|
|
||||||
|
cursor = await db.execute("SELECT COUNT(*) FROM auth_sessions")
|
||||||
|
(count,) = await cursor.fetchone()
|
||||||
|
if count > 0:
|
||||||
|
# auth_sessions already has data — this is a fresh V2 install, not an
|
||||||
|
# upgrade. Mark the backfill done so we never re-check.
|
||||||
|
await _set_meta_value(db, "backfill_user_sessions_v1_to_v2", "done")
|
||||||
|
await db.commit()
|
||||||
|
return 0
|
||||||
|
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT id, user_id, refresh_token_hash, device_info, created_at, expires_at, revoked_at "
|
||||||
|
"FROM user_sessions WHERE revoked_at IS NULL"
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
backfilled = 0
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
device_info = json.loads(row["device_info"]) if row["device_info"] else {}
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
device_info = {}
|
||||||
|
|
||||||
|
await db.execute(
|
||||||
|
"INSERT OR IGNORE INTO auth_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, device_fingerprint, device_label, "
|
||||||
|
" ip, user_agent, auth_provider, created_at, last_active_at, expires_at, "
|
||||||
|
" revoked, revoked_reason, previous_session_id) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
row["id"], # reuse legacy id for back-compat with old clients
|
||||||
|
row["user_id"],
|
||||||
|
row["refresh_token_hash"],
|
||||||
|
device_info.get("fingerprint", "unknown"),
|
||||||
|
device_info.get("label", "Unknown device"),
|
||||||
|
device_info.get("ip", ""),
|
||||||
|
device_info.get("user_agent", ""),
|
||||||
|
"local", # backfilled rows are pre-IdP by definition
|
||||||
|
row["created_at"],
|
||||||
|
row["created_at"], # last_active_at defaults to created_at
|
||||||
|
row["expires_at"],
|
||||||
|
0, # not revoked (already filtered)
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
backfilled += 1
|
||||||
|
|
||||||
|
if backfilled:
|
||||||
|
logger.info(
|
||||||
|
f"Backfilled {backfilled} user_sessions rows to auth_sessions "
|
||||||
|
f"(schema v{_SCHEMA_VERSION})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark the backfill as completed regardless of how many rows were moved.
|
||||||
|
# (idempotency: even a 0-row backfill is "done".)
|
||||||
|
await _set_meta_value(db, "backfill_user_sessions_v1_to_v2", "done")
|
||||||
|
await db.commit()
|
||||||
|
return backfilled
|
||||||
|
|
||||||
|
|
||||||
async def init_auth_db(db_path: str | Path | None = None) -> Path:
|
async def init_auth_db(db_path: str | Path | None = None) -> Path:
|
||||||
"""Create auth tables if they do not exist.
|
"""Create auth tables if they do not exist.
|
||||||
|
|
||||||
Uses aiosqlite directly (no SQLAlchemy engine) for a lightweight,
|
Uses aiosqlite directly (no SQLAlchemy engine) for a lightweight,
|
||||||
zero-config bootstrap that mirrors :class:`SqliteConversationStore`.
|
zero-config bootstrap that mirrors :class:`SqliteConversationStore`.
|
||||||
|
|
||||||
|
On startup, this function:
|
||||||
|
1. Creates all tables and indexes from :data:`_SCHEMA_SQL` (idempotent).
|
||||||
|
2. Records the current :data:`_SCHEMA_VERSION` in ``auth_meta``.
|
||||||
|
3. Runs any pending one-time migrations (currently: V1 → V2 backfill
|
||||||
|
from ``user_sessions`` to ``auth_sessions``).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_path: Path to the SQLite file. Defaults to
|
db_path: Path to the SQLite file. Defaults to
|
||||||
:data:`DEFAULT_AUTH_DB_PATH` (``data/auth.db`` under the project
|
:data:`DEFAULT_AUTH_DB_PATH` (``data/auth.db`` under the project
|
||||||
|
|
@ -325,8 +551,19 @@ async def init_auth_db(db_path: str | Path | None = None) -> Path:
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
async with aiosqlite.connect(str(path)) as db:
|
async with aiosqlite.connect(str(path)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
await db.execute("PRAGMA journal_mode=WAL")
|
await db.execute("PRAGMA journal_mode=WAL")
|
||||||
await db.executescript(_SCHEMA_SQL)
|
await db.executescript(_SCHEMA_SQL)
|
||||||
|
|
||||||
|
# Record the current schema version (idempotent upsert).
|
||||||
|
current = await _get_meta_value(db, _META_SCHEMA_VERSION_KEY)
|
||||||
|
if current != str(_SCHEMA_VERSION):
|
||||||
|
await _set_meta_value(db, _META_SCHEMA_VERSION_KEY, str(_SCHEMA_VERSION))
|
||||||
|
logger.info(f"Auth DB schema version set to {_SCHEMA_VERSION}")
|
||||||
|
|
||||||
|
# Run pending migrations (each is internally idempotent).
|
||||||
|
await _backfill_user_sessions(db)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
logger.info(f"Auth DB initialized at {path}")
|
logger.info(f"Auth DB initialized at {path}")
|
||||||
|
|
@ -353,3 +590,28 @@ def user_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any
|
||||||
"last_login_at": row["last_login_at"],
|
"last_login_at": row["last_login_at"],
|
||||||
"created_by": row["created_by"],
|
"created_by": row["created_by"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def auth_session_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
|
||||||
|
"""Convert an ``auth_sessions`` row into a JSON-safe dict.
|
||||||
|
|
||||||
|
The ``revoked`` field is normalized to a Python ``bool`` (the DB stores
|
||||||
|
0/1). The full set of audit fields is included so the admin UI and
|
||||||
|
API responses can surface device/IP/last-active information without
|
||||||
|
a separate lookup.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"id": row["id"],
|
||||||
|
"user_id": row["user_id"],
|
||||||
|
"device_fingerprint": row["device_fingerprint"],
|
||||||
|
"device_label": row["device_label"],
|
||||||
|
"ip": row["ip"],
|
||||||
|
"user_agent": row["user_agent"],
|
||||||
|
"auth_provider": row["auth_provider"],
|
||||||
|
"created_at": row["created_at"],
|
||||||
|
"last_active_at": row["last_active_at"],
|
||||||
|
"expires_at": row["expires_at"],
|
||||||
|
"revoked": bool(row["revoked"]),
|
||||||
|
"revoked_reason": row["revoked_reason"],
|
||||||
|
"previous_session_id": row["previous_session_id"],
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,115 @@
|
||||||
|
"""AuthProvider implementations and dependency-injection factory.
|
||||||
|
|
||||||
|
Public surface:
|
||||||
|
|
||||||
|
- :class:`AuthProvider` (re-exported from :mod:`.base`) — the protocol
|
||||||
|
every backend must satisfy.
|
||||||
|
- :class:`LocalAuthProvider` — SQLite + bcrypt default.
|
||||||
|
- :class:`StubOIDCProvider` — interface placeholder for the future
|
||||||
|
OIDC integration.
|
||||||
|
- :class:`User` — provider-agnostic user value object.
|
||||||
|
- :func:`get_auth_provider` — DI factory used by routes via
|
||||||
|
``Depends(get_auth_provider)``.
|
||||||
|
- :func:`reset_auth_provider` — clears the lru_cache singleton
|
||||||
|
(used in tests + when the auth provider config changes at runtime).
|
||||||
|
|
||||||
|
Configuration
|
||||||
|
-------------
|
||||||
|
The provider is selected via the ``AGENTKIT_AUTH_PROVIDER`` environment
|
||||||
|
variable (default: ``"local"``). When the future ``auth.provider``
|
||||||
|
field is added to ``agentkit.yaml`` it should override the env var.
|
||||||
|
|
||||||
|
Adding a new provider
|
||||||
|
---------------------
|
||||||
|
1. Create ``auth/providers/<name>.py`` with a class that satisfies
|
||||||
|
the :class:`AuthProvider` Protocol (i.e. has ``name: str`` and the
|
||||||
|
four async methods).
|
||||||
|
2. Import it here and add a branch to :func:`get_auth_provider`.
|
||||||
|
3. Add it to ``AGENTKIT_AUTH_PROVIDER`` enum in the deployment docs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .base import AuthProvider
|
||||||
|
from .exceptions import AuthProviderError, InvalidCredentials, ProviderNotImplemented
|
||||||
|
from .local import LocalAuthProvider
|
||||||
|
from .oidc_stub import StubOIDCProvider
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Re-exports for ergonomic imports: `from agentkit.server.auth.providers import AuthProvider`
|
||||||
|
__all__ = [
|
||||||
|
"AuthProvider",
|
||||||
|
"AuthProviderError",
|
||||||
|
"InvalidCredentials",
|
||||||
|
"LocalAuthProvider",
|
||||||
|
"ProviderNotImplemented",
|
||||||
|
"StubOIDCProvider",
|
||||||
|
"User",
|
||||||
|
"get_auth_provider",
|
||||||
|
"reset_auth_provider",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Default provider name. Configurable via AGENTKIT_AUTH_PROVIDER env var
|
||||||
|
# (overridable when the auth.provider field is added to agentkit.yaml).
|
||||||
|
DEFAULT_AUTH_PROVIDER = "local"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_provider_name() -> str:
|
||||||
|
"""Return the configured provider name, defaulting to 'local'."""
|
||||||
|
name = os.environ.get("AGENTKIT_AUTH_PROVIDER", DEFAULT_AUTH_PROVIDER).strip()
|
||||||
|
if not name:
|
||||||
|
return DEFAULT_AUTH_PROVIDER
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_db_path() -> Path:
|
||||||
|
"""Return the auth DB path (overridable for tests via env var)."""
|
||||||
|
from ..models import DEFAULT_AUTH_DB_PATH
|
||||||
|
|
||||||
|
env = os.environ.get("AGENTKIT_AUTH_DB")
|
||||||
|
if env:
|
||||||
|
return Path(env)
|
||||||
|
return DEFAULT_AUTH_DB_PATH
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_auth_provider() -> AuthProvider:
|
||||||
|
"""Return the configured :class:`AuthProvider` (memoized singleton).
|
||||||
|
|
||||||
|
The ``lru_cache`` is process-local and the value is resolved on
|
||||||
|
first call. Use :func:`reset_auth_provider` to clear the cache
|
||||||
|
(e.g. when the configuration changes at runtime, or in test
|
||||||
|
fixtures that need a fresh provider with a different db_path).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the configured provider name is not recognized.
|
||||||
|
"""
|
||||||
|
name = _resolve_provider_name()
|
||||||
|
if name == "local":
|
||||||
|
return LocalAuthProvider(db_path=_resolve_db_path())
|
||||||
|
if name == "oidc-stub":
|
||||||
|
return StubOIDCProvider()
|
||||||
|
raise ValueError(
|
||||||
|
f"unknown auth provider: {name!r}. "
|
||||||
|
f"Supported providers: 'local', 'oidc-stub'. "
|
||||||
|
f"Set AGENTKIT_AUTH_PROVIDER or update agentkit.yaml's auth.provider field."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_auth_provider() -> None:
|
||||||
|
"""Clear the memoized :class:`AuthProvider` singleton.
|
||||||
|
|
||||||
|
Use in tests (e.g. in a fixture's teardown) or in code paths that
|
||||||
|
change the auth provider configuration at runtime and need a
|
||||||
|
re-resolution.
|
||||||
|
"""
|
||||||
|
get_auth_provider.cache_clear()
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""AuthProvider protocol — pluggable authentication backend contract.
|
||||||
|
|
||||||
|
The :class:`AuthProvider` Protocol defines the minimal surface area the
|
||||||
|
auth subsystem needs from any authentication backend (Local today, OIDC
|
||||||
|
/ SAML / LDAP tomorrow). Routes and admin endpoints call only these
|
||||||
|
methods and never touch the underlying user store directly. Adding a new
|
||||||
|
IdP integration is a matter of writing a new adapter that satisfies
|
||||||
|
this Protocol.
|
||||||
|
|
||||||
|
Design notes:
|
||||||
|
|
||||||
|
- **No sync surface for password verification**: the only entry point is
|
||||||
|
:meth:`authenticate` which takes the plaintext password. Internally each
|
||||||
|
adapter chooses how to verify (bcrypt locally, redirect to IdP, etc.).
|
||||||
|
- **No persistence responsibility**: providers return :class:`User` objects
|
||||||
|
but do not manage sessions, refresh tokens, or session state. That is
|
||||||
|
the route + SessionService's job (see :mod:`agentkit.server.auth.session`).
|
||||||
|
- **Audit-friendly**: :attr:`name` is written to ``auth_sessions.auth_provider``
|
||||||
|
on every login, so the source of every session is traceable. This is
|
||||||
|
what enables "list sessions by provider" admin queries and future
|
||||||
|
cross-IdP policy enforcement.
|
||||||
|
- **runtime_checkable**: the Protocol is decorated so that
|
||||||
|
``isinstance(provider, AuthProvider)`` works at runtime, enabling
|
||||||
|
defensive checks in tests and in DI wiring.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class AuthProvider(Protocol):
|
||||||
|
"""All authentication backends must implement this surface.
|
||||||
|
|
||||||
|
The route layer only calls the methods below. It is intentionally
|
||||||
|
unaware of whether the backing store is SQLite, an OIDC IdP, LDAP,
|
||||||
|
or anything else.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""Identifier for this provider, written to ``auth_sessions.auth_provider``.
|
||||||
|
|
||||||
|
Convention: ``local`` for the local SQLite + bcrypt backend,
|
||||||
|
``oidc-<idp-name>`` for OIDC (e.g. ``oidc-keycloak``,
|
||||||
|
``oidc-feishu``), ``saml`` / ``ldap`` for the other planned
|
||||||
|
integrations. This value is the stable contract for audit
|
||||||
|
traceability — do not rename without a migration plan.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def authenticate(self, *, username: str, password: str) -> User:
|
||||||
|
"""Verify ``username`` + ``password`` and return the :class:`User`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: The submitted username (or, for OIDC, the IdP
|
||||||
|
subject id — but that's a future adapter's concern).
|
||||||
|
password: The submitted plaintext password.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The matched :class:`User` on success.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidCredentials: if the user does not exist, the password
|
||||||
|
is wrong, or the user is inactive. Callers MUST NOT
|
||||||
|
distinguish between these three cases in the error
|
||||||
|
message returned to the client (timing-attack /
|
||||||
|
username-enumeration mitigation).
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||||
|
"""Look up a :class:`User` by primary key.
|
||||||
|
|
||||||
|
Used by:
|
||||||
|
- Admin endpoints that need to display user info by id
|
||||||
|
- Session validation in the cold-start / whoami path
|
||||||
|
- Audit log enrichment
|
||||||
|
|
||||||
|
Returns ``None`` if no user exists with this id (or the user
|
||||||
|
is inactive — convention: inactive users are "not found" from
|
||||||
|
the auth layer's perspective).
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def sync_user_attributes(self, user_id: str) -> None:
|
||||||
|
"""Refresh user attributes (department / email / title) from the source of truth.
|
||||||
|
|
||||||
|
- :class:`LocalAuthProvider`: no-op (attributes are managed locally).
|
||||||
|
- OIDC adapter (future): pull the latest profile from the IdP and
|
||||||
|
write back to the local ``users`` table.
|
||||||
|
|
||||||
|
Implementations that have nothing to sync should still define
|
||||||
|
this method (returning ``None``) so the contract is uniform.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def revoke_user(self, user_id: str) -> None:
|
||||||
|
"""Disable a user account (e.g. on termination or lock-out).
|
||||||
|
|
||||||
|
- :class:`LocalAuthProvider`: ``UPDATE users SET is_active = 0``
|
||||||
|
- OIDC adapter (future): call the IdP's disable API
|
||||||
|
|
||||||
|
The admin endpoint that calls this does NOT also need to
|
||||||
|
revoke the user's active sessions — that is the
|
||||||
|
:class:`SessionService`'s job, called separately.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Exceptions raised by AuthProvider implementations.
|
||||||
|
|
||||||
|
These are caught by the route layer and translated to HTTP responses
|
||||||
|
(401 for :class:`InvalidCredentials`, 501 for :class:`ProviderNotImplemented`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class AuthProviderError(Exception):
|
||||||
|
"""Base class for all auth provider errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidCredentials(AuthProviderError):
|
||||||
|
"""Raised when username / password is wrong, or the user is inactive.
|
||||||
|
|
||||||
|
Translates to HTTP 401. The error message MUST NOT leak which of
|
||||||
|
"user not found" vs "wrong password" vs "user inactive" failed —
|
||||||
|
that is a username enumeration risk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderNotImplemented(AuthProviderError):
|
||||||
|
"""Raised when a configured provider is not yet implemented.
|
||||||
|
|
||||||
|
Translates to HTTP 501. Used by :class:`StubOIDCProvider` and any
|
||||||
|
future adapter that is registered but has no real implementation.
|
||||||
|
"""
|
||||||
|
|
@ -0,0 +1,163 @@
|
||||||
|
"""Local authentication provider — SQLite + bcrypt.
|
||||||
|
|
||||||
|
The default :class:`AuthProvider` implementation. Authenticates users
|
||||||
|
against the local ``users`` table in the auth SQLite database using
|
||||||
|
the bcrypt cost=12 password hash (see :mod:`agentkit.server.auth.password`).
|
||||||
|
|
||||||
|
This is a behavioral equivalent of the password-verification code that
|
||||||
|
previously lived inline in ``routes/auth.py`` — moved here so the route
|
||||||
|
layer can call a single :meth:`authenticate` method regardless of which
|
||||||
|
backend is configured.
|
||||||
|
|
||||||
|
Future-IdP note
|
||||||
|
---------------
|
||||||
|
When the organization moves to OIDC / SAML / LDAP, this class does not
|
||||||
|
need to be deleted. It can continue to serve as a "local emergency
|
||||||
|
account" provider, configurable side-by-side with the IdP provider via
|
||||||
|
a future composite / multi-provider setup. For now it is the only
|
||||||
|
implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from ..models import DEFAULT_AUTH_DB_PATH
|
||||||
|
from ..password import verify_password
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_db_path() -> Path:
|
||||||
|
"""Resolve the auth DB path with runtime env-var priority.
|
||||||
|
|
||||||
|
The :data:`models.DEFAULT_AUTH_DB_PATH` constant is captured at
|
||||||
|
module-import time and therefore cannot see test-time env mutations.
|
||||||
|
Re-reading ``AGENTKIT_AUTH_DB`` here keeps the provider
|
||||||
|
"test-friendly" (tests can ``monkeypatch.setenv`` before constructing
|
||||||
|
the provider) without giving up the default path when no env is set.
|
||||||
|
"""
|
||||||
|
env = os.environ.get("AGENTKIT_AUTH_DB")
|
||||||
|
if env:
|
||||||
|
return Path(env)
|
||||||
|
return DEFAULT_AUTH_DB_PATH
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAuthProvider:
|
||||||
|
"""AuthProvider backed by the local SQLite ``users`` table + bcrypt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to the auth DB. Defaults to the value of the
|
||||||
|
``AGENTKIT_AUTH_DB`` env var, falling back to
|
||||||
|
:data:`agentkit.server.auth.models.DEFAULT_AUTH_DB_PATH`.
|
||||||
|
Each operation opens a short-lived aiosqlite connection;
|
||||||
|
the existing route layer follows the same pattern, so no
|
||||||
|
connection pooling is introduced here. If a future
|
||||||
|
deployment needs pooling, swap in a ``db_factory: Callable``
|
||||||
|
here without changing the protocol.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "local"
|
||||||
|
|
||||||
|
def __init__(self, db_path: str | Path | None = None) -> None:
|
||||||
|
self._db_path = Path(db_path) if db_path is not None else _resolve_db_path()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_path(self) -> Path:
|
||||||
|
return self._db_path
|
||||||
|
|
||||||
|
async def authenticate(self, *, username: str, password: str) -> User:
|
||||||
|
"""Verify the username + password against the local users table.
|
||||||
|
|
||||||
|
Raises :class:`InvalidCredentials` on every failure mode
|
||||||
|
(unknown user, wrong password, inactive user) with the same
|
||||||
|
error message — preventing username enumeration via error
|
||||||
|
inspection. Constant-time-equivalent behavior is also
|
||||||
|
ensured by always running a real bcrypt computation
|
||||||
|
(against a dummy hash) when the user does not exist,
|
||||||
|
matching the timing of the "user exists, wrong password" path.
|
||||||
|
"""
|
||||||
|
from .exceptions import InvalidCredentials # local import to avoid cycle at module load
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT id, username, email, password_hash, role, is_active, "
|
||||||
|
"is_terminal_authorized, is_server_terminal_authorized, "
|
||||||
|
"created_at, updated_at, last_login_at, created_by "
|
||||||
|
"FROM users WHERE username = ?",
|
||||||
|
(username,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
|
||||||
|
if row is None or not bool(row["is_active"]):
|
||||||
|
# Run a real bcrypt verification against a valid-format dummy
|
||||||
|
# hash so the response time matches the "user exists, wrong
|
||||||
|
# password" path (~250ms). Prevents username enumeration via
|
||||||
|
# timing. The dummy hash is invalid (won't match any password)
|
||||||
|
# but has the right shape so bcrypt.checkpw doesn't short-circuit.
|
||||||
|
_DUMMY_BCRYPT_HASH = "$2b$12$abcdefghijklmnopqrstuuABCDEFGHIJKLMNOPQRSTUVWXYZ0123"
|
||||||
|
verify_password(password, _DUMMY_BCRYPT_HASH)
|
||||||
|
raise InvalidCredentials("invalid username or password")
|
||||||
|
|
||||||
|
if not verify_password(password, row["password_hash"]):
|
||||||
|
raise InvalidCredentials("invalid username or password")
|
||||||
|
|
||||||
|
return _row_to_user(row)
|
||||||
|
|
||||||
|
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||||
|
"""Look up a user by id. Returns ``None`` if not found or inactive."""
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT id, username, email, password_hash, role, is_active, "
|
||||||
|
"is_terminal_authorized, is_server_terminal_authorized, "
|
||||||
|
"created_at, updated_at, last_login_at, created_by "
|
||||||
|
"FROM users WHERE id = ? AND is_active = 1",
|
||||||
|
(user_id,),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return _row_to_user(row) if row else None
|
||||||
|
|
||||||
|
async def sync_user_attributes(self, user_id: str) -> None:
|
||||||
|
"""No-op: local provider has no upstream source of truth to sync from."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def revoke_user(self, user_id: str) -> None:
|
||||||
|
"""Disable a user account (``is_active = 0``)."""
|
||||||
|
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE users SET is_active = 0, updated_at = ? WHERE id = ?",
|
||||||
|
(_now_iso(), user_id),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
logger.info(f"Revoked user {user_id} via LocalAuthProvider")
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_user(row: aiosqlite.Row) -> User:
|
||||||
|
"""Convert a ``users`` row to a :class:`User` value object."""
|
||||||
|
return User(
|
||||||
|
id=row["id"],
|
||||||
|
username=row["username"],
|
||||||
|
email=row["email"],
|
||||||
|
role=row["role"],
|
||||||
|
is_active=bool(row["is_active"]),
|
||||||
|
is_terminal_authorized=bool(row["is_terminal_authorized"]),
|
||||||
|
is_server_terminal_authorized=bool(row["is_server_terminal_authorized"]),
|
||||||
|
created_at=row["created_at"],
|
||||||
|
updated_at=row["updated_at"],
|
||||||
|
last_login_at=row["last_login_at"],
|
||||||
|
created_by=row["created_by"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _now_iso() -> str:
|
||||||
|
"""Return current UTC time as ISO 8601 string."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
@ -0,0 +1,83 @@
|
||||||
|
"""Stub OIDC provider — interface contract placeholder.
|
||||||
|
|
||||||
|
This is a deliberately unimplemented :class:`AuthProvider` whose only
|
||||||
|
job is to fail loudly when ``auth.provider: oidc-stub`` is configured
|
||||||
|
without the real OIDC integration being in place. It exists so that
|
||||||
|
the configuration switch and the dependency-injection wiring can be
|
||||||
|
exercised end-to-end before the IdP integration work is scheduled.
|
||||||
|
|
||||||
|
Future OIDC integration checklist
|
||||||
|
---------------------------------
|
||||||
|
When the real OIDC integration lands, replace this file with
|
||||||
|
:file:`auth/providers/oidc.py` implementing:
|
||||||
|
|
||||||
|
- [ ] :meth:`authenticate` — accept (username, password)? OR redirect-flow?
|
||||||
|
OIDC is fundamentally a redirect flow, NOT a password POST. The
|
||||||
|
right answer is probably a separate ``/auth/oauth/{provider}/redirect``
|
||||||
|
and ``/auth/oauth/{provider}/callback`` route pair that bypasses
|
||||||
|
the password-based login. The :class:`AuthProvider` interface is
|
||||||
|
only for backends that accept username + password; an OIDC
|
||||||
|
adapter may need to extend the protocol or expose a separate
|
||||||
|
callback-handler hook.
|
||||||
|
- [ ] :meth:`get_user_by_id` — query local cache (auto-provisioned
|
||||||
|
OIDC users live in the local ``users`` table after first login).
|
||||||
|
- [ ] :meth:`sync_user_attributes` — pull latest profile from the IdP
|
||||||
|
on each successful login (department / title / email).
|
||||||
|
- [ ] :meth:`revoke_user` — call the IdP's account-disable API (e.g.
|
||||||
|
Keycloak's ``PUT /admin/realms/{realm}/users/{id}/disable``).
|
||||||
|
- [ ] State cache for OAuth ``state`` parameter (Redis, TTL 5 min).
|
||||||
|
- [ ] First-login user auto-provisioning policy: just-in-time
|
||||||
|
creation / reject / invite-only.
|
||||||
|
|
||||||
|
Until then, configuring ``auth.provider: oidc-stub`` and starting
|
||||||
|
the server lets a smoke test confirm:
|
||||||
|
|
||||||
|
1. The DI factory resolves the right class.
|
||||||
|
2. The route layer calls :meth:`authenticate` and surfaces 501.
|
||||||
|
3. ``auth_sessions.auth_provider='oidc-stub'`` would be written
|
||||||
|
for any session that *did* get created (none, in this stub).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .exceptions import ProviderNotImplemented
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
|
class StubOIDCProvider:
|
||||||
|
"""Unimplemented OIDC adapter. All methods raise :class:`ProviderNotImplemented`.
|
||||||
|
|
||||||
|
The class satisfies the :class:`AuthProvider` Protocol structurally
|
||||||
|
(it has ``name`` and the four methods), but every method body is
|
||||||
|
a deliberate failure. This is on purpose: configuring
|
||||||
|
``auth.provider: oidc-stub`` and then successfully logging in would
|
||||||
|
be a security bug.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "oidc-stub"
|
||||||
|
|
||||||
|
async def authenticate(self, *, username: str, password: str) -> User:
|
||||||
|
raise ProviderNotImplemented(
|
||||||
|
"OIDC authentication is not implemented yet. "
|
||||||
|
"Use auth.provider: local in agentkit.yaml, or implement "
|
||||||
|
"auth/providers/oidc.py (see the checklist in this module's "
|
||||||
|
"docstring)."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||||
|
raise ProviderNotImplemented(
|
||||||
|
"StubOIDCProvider.get_user_by_id is not implemented. "
|
||||||
|
"See auth/providers/oidc_stub.py for the future OIDC checklist."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def sync_user_attributes(self, user_id: str) -> None:
|
||||||
|
raise ProviderNotImplemented(
|
||||||
|
"StubOIDCProvider.sync_user_attributes is not implemented. "
|
||||||
|
"See auth/providers/oidc_stub.py for the future OIDC checklist."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def revoke_user(self, user_id: str) -> None:
|
||||||
|
raise ProviderNotImplemented(
|
||||||
|
"StubOIDCProvider.revoke_user is not implemented. "
|
||||||
|
"See auth/providers/oidc_stub.py for the future OIDC checklist."
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
"""Provider-agnostic user data model.
|
||||||
|
|
||||||
|
This is the value object the :class:`AuthProvider` Protocol returns
|
||||||
|
from :meth:`AuthProvider.authenticate` and :meth:`AuthProvider.get_user_by_id`.
|
||||||
|
It is intentionally NOT the SQLAlchemy ``UserModel`` ORM class so that
|
||||||
|
future IdP adapters (OIDC / SAML / LDAP) can construct one without
|
||||||
|
touching a DB.
|
||||||
|
|
||||||
|
The route layer converts ``User`` → ``UserResponse`` (the public
|
||||||
|
API payload) at the boundary; ``User`` itself stays an internal
|
||||||
|
type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, EmailStr
|
||||||
|
|
||||||
|
|
||||||
|
class User(BaseModel):
|
||||||
|
"""A user as known to the auth subsystem.
|
||||||
|
|
||||||
|
Distinct from the SQLAlchemy ``UserModel`` ORM class because:
|
||||||
|
- The auth provider may not have a local row at all (e.g. an
|
||||||
|
OIDC user logged in via SSO who hasn't been provisioned yet
|
||||||
|
— that's a future concern but the data model must allow it).
|
||||||
|
- IdP-sourced attributes (department / title) are not in the
|
||||||
|
SQLAlchemy model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
id: str
|
||||||
|
username: str
|
||||||
|
email: EmailStr
|
||||||
|
role: str = "member"
|
||||||
|
is_active: bool = True
|
||||||
|
is_terminal_authorized: bool = False
|
||||||
|
is_server_terminal_authorized: bool = False
|
||||||
|
created_at: str
|
||||||
|
updated_at: str
|
||||||
|
last_login_at: str | None = None
|
||||||
|
created_by: str | None = None
|
||||||
|
|
@ -0,0 +1,112 @@
|
||||||
|
"""Tests for the AuthProvider protocol and the DI factory (U11)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.server.auth.providers import (
|
||||||
|
AuthProvider,
|
||||||
|
LocalAuthProvider,
|
||||||
|
StubOIDCProvider,
|
||||||
|
get_auth_provider,
|
||||||
|
reset_auth_provider,
|
||||||
|
)
|
||||||
|
from agentkit.server.auth.providers.base import AuthProvider as AuthProviderFromBase
|
||||||
|
from agentkit.server.auth.providers.exceptions import (
|
||||||
|
AuthProviderError,
|
||||||
|
InvalidCredentials,
|
||||||
|
ProviderNotImplemented,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProtocolConformance:
|
||||||
|
"""The runtime_checkable Protocol accepts both real implementations."""
|
||||||
|
|
||||||
|
def test_local_passes_isinstance_check(self):
|
||||||
|
provider = LocalAuthProvider()
|
||||||
|
assert isinstance(provider, AuthProvider)
|
||||||
|
assert isinstance(provider, AuthProviderFromBase)
|
||||||
|
|
||||||
|
def test_stub_passes_isinstance_check(self):
|
||||||
|
provider = StubOIDCProvider()
|
||||||
|
assert isinstance(provider, AuthProvider)
|
||||||
|
assert isinstance(provider, AuthProviderFromBase)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderNames:
|
||||||
|
def test_local_provider_name(self):
|
||||||
|
assert LocalAuthProvider.name == "local"
|
||||||
|
|
||||||
|
def test_stub_oidc_provider_name(self):
|
||||||
|
assert StubOIDCProvider.name == "oidc-stub"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiFactory:
|
||||||
|
"""``get_auth_provider`` is a memoized singleton, env-driven."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Reset cache and env before each test."""
|
||||||
|
reset_auth_provider()
|
||||||
|
self._saved_env = {}
|
||||||
|
for key in ("AGENTKIT_AUTH_PROVIDER",):
|
||||||
|
self._saved_env[key] = __import__("os").environ.pop(key, None)
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
for key, value in self._saved_env.items():
|
||||||
|
if value is not None:
|
||||||
|
__import__("os").environ[key] = value
|
||||||
|
reset_auth_provider()
|
||||||
|
|
||||||
|
def test_default_provider_is_local(self):
|
||||||
|
provider = get_auth_provider()
|
||||||
|
assert isinstance(provider, LocalAuthProvider)
|
||||||
|
assert provider.name == "local"
|
||||||
|
|
||||||
|
def test_oidc_stub_provider(self, monkeypatch):
|
||||||
|
monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "oidc-stub")
|
||||||
|
reset_auth_provider()
|
||||||
|
provider = get_auth_provider()
|
||||||
|
assert isinstance(provider, StubOIDCProvider)
|
||||||
|
assert provider.name == "oidc-stub"
|
||||||
|
|
||||||
|
def test_unknown_provider_raises_value_error(self, monkeypatch):
|
||||||
|
monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "ldap-unknown")
|
||||||
|
reset_auth_provider()
|
||||||
|
with pytest.raises(ValueError, match="unknown auth provider"):
|
||||||
|
get_auth_provider()
|
||||||
|
|
||||||
|
def test_factory_is_memoized(self):
|
||||||
|
first = get_auth_provider()
|
||||||
|
second = get_auth_provider()
|
||||||
|
assert first is second
|
||||||
|
|
||||||
|
def test_reset_clears_cache(self, monkeypatch):
|
||||||
|
first = get_auth_provider()
|
||||||
|
reset_auth_provider()
|
||||||
|
monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "oidc-stub")
|
||||||
|
second = get_auth_provider()
|
||||||
|
assert first is not second
|
||||||
|
assert isinstance(second, StubOIDCProvider)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExceptionHierarchy:
|
||||||
|
def test_invalid_credentials_inherits_auth_provider_error(self):
|
||||||
|
assert issubclass(InvalidCredentials, AuthProviderError)
|
||||||
|
assert issubclass(InvalidCredentials, Exception)
|
||||||
|
|
||||||
|
def test_provider_not_implemented_inherits_auth_provider_error(self):
|
||||||
|
assert issubclass(ProviderNotImplemented, AuthProviderError)
|
||||||
|
assert issubclass(ProviderNotImplemented, Exception)
|
||||||
|
|
||||||
|
def test_invalid_credentials_can_be_raised_and_caught(self):
|
||||||
|
with pytest.raises(InvalidCredentials):
|
||||||
|
raise InvalidCredentials("test message")
|
||||||
|
|
||||||
|
def test_invalid_credentials_caught_as_base(self):
|
||||||
|
"""Route layer catches AuthProviderError to handle all provider errors."""
|
||||||
|
with pytest.raises(AuthProviderError):
|
||||||
|
raise InvalidCredentials("test")
|
||||||
|
|
||||||
|
def test_provider_not_implemented_caught_as_base(self):
|
||||||
|
with pytest.raises(AuthProviderError):
|
||||||
|
raise ProviderNotImplemented("test")
|
||||||
|
|
@ -0,0 +1,254 @@
|
||||||
|
"""Tests for LocalAuthProvider (U11 — concrete Local implementation)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.server.auth.models import init_auth_db
|
||||||
|
from agentkit.server.auth.password import hash_password
|
||||||
|
from agentkit.server.auth.providers import LocalAuthProvider
|
||||||
|
from agentkit.server.auth.providers.exceptions import InvalidCredentials
|
||||||
|
from agentkit.server.auth.providers.user import User
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def auth_db_with_users(tmp_path: Path) -> dict:
|
||||||
|
"""Create a fresh auth DB with two users: one active, one inactive.
|
||||||
|
|
||||||
|
Returns a dict with user info + the db path.
|
||||||
|
"""
|
||||||
|
db_path = tmp_path / "auth.db"
|
||||||
|
await init_auth_db(db_path)
|
||||||
|
|
||||||
|
now_iso = datetime.now(timezone.utc).isoformat()
|
||||||
|
active_id = str(uuid.uuid4())
|
||||||
|
inactive_id = str(uuid.uuid4())
|
||||||
|
active_pw = "correct-horse-battery-staple"
|
||||||
|
inactive_pw = "disabled-user-pw"
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(db_path)) as db:
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO users "
|
||||||
|
"(id, username, email, password_hash, role, is_active, "
|
||||||
|
" is_terminal_authorized, is_server_terminal_authorized, "
|
||||||
|
" created_at, updated_at) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
active_id,
|
||||||
|
"alice",
|
||||||
|
"alice@example.com",
|
||||||
|
hash_password(active_pw),
|
||||||
|
"member",
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
now_iso,
|
||||||
|
now_iso,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO users "
|
||||||
|
"(id, username, email, password_hash, role, is_active, "
|
||||||
|
" is_terminal_authorized, is_server_terminal_authorized, "
|
||||||
|
" created_at, updated_at) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
inactive_id,
|
||||||
|
"bob_inactive",
|
||||||
|
"bob@example.com",
|
||||||
|
hash_password(inactive_pw),
|
||||||
|
"member",
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
now_iso,
|
||||||
|
now_iso,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"db_path": db_path,
|
||||||
|
"active": {
|
||||||
|
"id": active_id,
|
||||||
|
"username": "alice",
|
||||||
|
"password": active_pw,
|
||||||
|
"email": "alice@example.com",
|
||||||
|
"role": "member",
|
||||||
|
},
|
||||||
|
"inactive": {
|
||||||
|
"id": inactive_id,
|
||||||
|
"username": "bob_inactive",
|
||||||
|
"password": inactive_pw,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def provider(auth_db_with_users: dict) -> LocalAuthProvider:
|
||||||
|
return LocalAuthProvider(db_path=auth_db_with_users["db_path"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# authenticate
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthenticate:
|
||||||
|
async def test_valid_credentials_returns_user(
|
||||||
|
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||||
|
):
|
||||||
|
user = await provider.authenticate(
|
||||||
|
username=auth_db_with_users["active"]["username"],
|
||||||
|
password=auth_db_with_users["active"]["password"],
|
||||||
|
)
|
||||||
|
assert isinstance(user, User)
|
||||||
|
assert user.id == auth_db_with_users["active"]["id"]
|
||||||
|
assert user.username == "alice"
|
||||||
|
assert user.email == "alice@example.com"
|
||||||
|
assert user.role == "member"
|
||||||
|
assert user.is_active is True
|
||||||
|
|
||||||
|
async def test_wrong_password_raises_invalid_credentials(
|
||||||
|
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||||
|
):
|
||||||
|
with pytest.raises(InvalidCredentials):
|
||||||
|
await provider.authenticate(
|
||||||
|
username=auth_db_with_users["active"]["username"],
|
||||||
|
password="definitely-not-the-password",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_unknown_user_raises_invalid_credentials(self, provider: LocalAuthProvider):
|
||||||
|
with pytest.raises(InvalidCredentials):
|
||||||
|
await provider.authenticate(username="nonexistent", password="anything")
|
||||||
|
|
||||||
|
async def test_inactive_user_raises_invalid_credentials(
|
||||||
|
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||||
|
):
|
||||||
|
with pytest.raises(InvalidCredentials):
|
||||||
|
await provider.authenticate(
|
||||||
|
username=auth_db_with_users["inactive"]["username"],
|
||||||
|
password=auth_db_with_users["inactive"]["password"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_error_message_does_not_leak_username_existence(
|
||||||
|
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||||
|
):
|
||||||
|
"""Wrong-password and unknown-user errors must have the same message."""
|
||||||
|
try:
|
||||||
|
await provider.authenticate(
|
||||||
|
username=auth_db_with_users["active"]["username"],
|
||||||
|
password="wrong",
|
||||||
|
)
|
||||||
|
except InvalidCredentials as e1:
|
||||||
|
wrong_pw_msg = str(e1)
|
||||||
|
try:
|
||||||
|
await provider.authenticate(username="nobody-here", password="x")
|
||||||
|
except InvalidCredentials as e2:
|
||||||
|
unknown_msg = str(e2)
|
||||||
|
assert wrong_pw_msg == unknown_msg
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_user_by_id
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUserById:
|
||||||
|
async def test_returns_user_for_active_id(
|
||||||
|
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||||
|
):
|
||||||
|
user = await provider.get_user_by_id(auth_db_with_users["active"]["id"])
|
||||||
|
assert user is not None
|
||||||
|
assert user.id == auth_db_with_users["active"]["id"]
|
||||||
|
assert user.username == "alice"
|
||||||
|
|
||||||
|
async def test_returns_none_for_inactive_user(
|
||||||
|
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||||
|
):
|
||||||
|
user = await provider.get_user_by_id(auth_db_with_users["inactive"]["id"])
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
async def test_returns_none_for_unknown_id(self, provider: LocalAuthProvider):
|
||||||
|
user = await provider.get_user_by_id(str(uuid.uuid4()))
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# sync_user_attributes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSyncUserAttributes:
|
||||||
|
async def test_is_noop(self, provider: LocalAuthProvider, auth_db_with_users: dict):
|
||||||
|
"""Local provider has no upstream to sync from — must succeed with no effect."""
|
||||||
|
result = await provider.sync_user_attributes(auth_db_with_users["active"]["id"])
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_noop_does_not_throw_for_unknown_user(self, provider: LocalAuthProvider):
|
||||||
|
"""sync_user_attributes is fire-and-forget — no existence check."""
|
||||||
|
# Should NOT raise even though the id doesn't exist
|
||||||
|
await provider.sync_user_attributes(str(uuid.uuid4()))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# revoke_user
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRevokeUser:
|
||||||
|
async def test_sets_is_active_to_zero(
|
||||||
|
self, provider: LocalAuthProvider, auth_db_with_users: dict, tmp_path: Path
|
||||||
|
):
|
||||||
|
await provider.revoke_user(auth_db_with_users["active"]["id"])
|
||||||
|
async with aiosqlite.connect(str(auth_db_with_users["db_path"])) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT is_active FROM users WHERE id = ?",
|
||||||
|
(auth_db_with_users["active"]["id"],),
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert bool(row["is_active"]) is False
|
||||||
|
|
||||||
|
async def test_revoked_user_can_no_longer_authenticate(
|
||||||
|
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||||
|
):
|
||||||
|
await provider.revoke_user(auth_db_with_users["active"]["id"])
|
||||||
|
with pytest.raises(InvalidCredentials):
|
||||||
|
await provider.authenticate(
|
||||||
|
username=auth_db_with_users["active"]["username"],
|
||||||
|
password=auth_db_with_users["active"]["password"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_revoke_unknown_user_does_not_raise(self, provider: LocalAuthProvider):
|
||||||
|
"""``UPDATE ... WHERE id = ?`` with no match is a no-op, not an error."""
|
||||||
|
await provider.revoke_user(str(uuid.uuid4()))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# default db_path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultDbPath:
|
||||||
|
def test_default_db_path_uses_models_default(self, tmp_path: Path, monkeypatch):
|
||||||
|
"""If no db_path is given, the provider should use the module default.
|
||||||
|
|
||||||
|
The default may resolve to a path that does not exist on a test
|
||||||
|
machine — we only assert that the property returns a Path, not
|
||||||
|
that the file exists.
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("AGENTKIT_AUTH_DB", str(tmp_path / "default.db"))
|
||||||
|
provider = LocalAuthProvider()
|
||||||
|
assert isinstance(provider.db_path, Path)
|
||||||
|
assert str(provider.db_path) == str(tmp_path / "default.db")
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""Tests for StubOIDCProvider (U11 — interface placeholder)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.server.auth.providers import StubOIDCProvider
|
||||||
|
from agentkit.server.auth.providers.exceptions import (
|
||||||
|
AuthProviderError,
|
||||||
|
ProviderNotImplemented,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStubAuthenticate:
|
||||||
|
async def test_raises_provider_not_implemented(self):
|
||||||
|
provider = StubOIDCProvider()
|
||||||
|
with pytest.raises(ProviderNotImplemented):
|
||||||
|
await provider.authenticate(username="alice", password="hunter2")
|
||||||
|
|
||||||
|
async def test_error_message_mentions_oidc(self):
|
||||||
|
"""The error message must guide the operator to the right next step."""
|
||||||
|
provider = StubOIDCProvider()
|
||||||
|
with pytest.raises(ProviderNotImplemented) as exc:
|
||||||
|
await provider.authenticate(username="alice", password="x")
|
||||||
|
msg = str(exc.value)
|
||||||
|
assert "OIDC" in msg
|
||||||
|
assert "not implemented" in msg.lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestStubGetUserById:
|
||||||
|
async def test_raises_provider_not_implemented(self):
|
||||||
|
provider = StubOIDCProvider()
|
||||||
|
with pytest.raises(ProviderNotImplemented):
|
||||||
|
await provider.get_user_by_id("any-id")
|
||||||
|
|
||||||
|
|
||||||
|
class TestStubSyncUserAttributes:
|
||||||
|
async def test_raises_provider_not_implemented(self):
|
||||||
|
provider = StubOIDCProvider()
|
||||||
|
with pytest.raises(ProviderNotImplemented):
|
||||||
|
await provider.sync_user_attributes("any-id")
|
||||||
|
|
||||||
|
|
||||||
|
class TestStubRevokeUser:
|
||||||
|
async def test_raises_provider_not_implemented(self):
|
||||||
|
provider = StubOIDCProvider()
|
||||||
|
with pytest.raises(ProviderNotImplemented):
|
||||||
|
await provider.revoke_user("any-id")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAllErrorsAreAuthProviderError:
|
||||||
|
"""Every stub method's error should be catchable as AuthProviderError."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"method,call_args",
|
||||||
|
[
|
||||||
|
("authenticate", ()),
|
||||||
|
("get_user_by_id", ("id",)),
|
||||||
|
("sync_user_attributes", ("id",)),
|
||||||
|
("revoke_user", ("id",)),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_catchable_as_base(self, method: str, call_args: tuple):
|
||||||
|
provider = StubOIDCProvider()
|
||||||
|
with pytest.raises(AuthProviderError):
|
||||||
|
if method == "authenticate":
|
||||||
|
await provider.authenticate(username="u", password="p")
|
||||||
|
else:
|
||||||
|
await getattr(provider, method)(*call_args)
|
||||||
|
|
@ -0,0 +1,653 @@
|
||||||
|
"""Unit tests for auth.models (U1 — V2 schema + backfill).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- ``auth_sessions`` table creation (columns + indexes)
|
||||||
|
- ``auth_meta`` table creation
|
||||||
|
- ``_SCHEMA_VERSION`` constant value
|
||||||
|
- ``auth_session_row_to_dict`` field-by-field conversion
|
||||||
|
- ``_backfill_user_sessions`` one-time migration (V1 → V2)
|
||||||
|
- ``init_auth_db`` idempotency (subsequent runs are no-ops)
|
||||||
|
- ``auth_provider`` column default value
|
||||||
|
- Index presence (verified via ``PRAGMA index_list``)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.server.auth.models import (
|
||||||
|
AuthSessionModel,
|
||||||
|
UserSessionModel,
|
||||||
|
_SCHEMA_VERSION,
|
||||||
|
auth_session_row_to_dict,
|
||||||
|
init_auth_db,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def fresh_db(tmp_path: Path) -> Path:
|
||||||
|
"""A brand-new auth DB on a fresh path (no data)."""
|
||||||
|
db_path = tmp_path / "auth.db"
|
||||||
|
await init_auth_db(db_path)
|
||||||
|
return db_path
|
||||||
|
|
||||||
|
|
||||||
|
async def _insert_user(db: aiosqlite.Connection, *, user_id: str | None = None) -> str:
|
||||||
|
"""Insert a minimal user row and return its id."""
|
||||||
|
user_id = user_id or str(uuid.uuid4())
|
||||||
|
now_iso = datetime.now(timezone.utc).isoformat()
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO users "
|
||||||
|
"(id, username, email, password_hash, role, is_active, "
|
||||||
|
" is_terminal_authorized, is_server_terminal_authorized, "
|
||||||
|
" created_at, updated_at) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
user_id,
|
||||||
|
f"user-{user_id[:8]}",
|
||||||
|
f"{user_id[:8]}@example.com",
|
||||||
|
"$2b$12$placeholder.hash.placeholder.hash.placeholder.hash",
|
||||||
|
"member",
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
now_iso,
|
||||||
|
now_iso,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
|
async def _insert_user_session(
|
||||||
|
db: aiosqlite.Connection,
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str | None = None,
|
||||||
|
refresh_hash: str | None = None,
|
||||||
|
device_info: str = "{}",
|
||||||
|
revoked: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""Insert a V1 ``user_sessions`` row and return its id."""
|
||||||
|
session_id = session_id or str(uuid.uuid4())
|
||||||
|
refresh_hash = refresh_hash or uuid.uuid4().hex
|
||||||
|
now_iso = datetime.now(timezone.utc).isoformat()
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO user_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, device_info, created_at, expires_at, revoked_at) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
refresh_hash,
|
||||||
|
device_info,
|
||||||
|
now_iso,
|
||||||
|
now_iso,
|
||||||
|
now_iso if revoked else None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]:
|
||||||
|
"""Return the set of index names for a table.
|
||||||
|
|
||||||
|
Sets ``row_factory`` on the connection so we can address columns by name.
|
||||||
|
PRAGMA results in aiosqlite come back as plain tuples unless a row factory
|
||||||
|
is configured.
|
||||||
|
"""
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(f"PRAGMA index_list({table})")
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return {row["name"] for row in rows}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _SCHEMA_VERSION
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSchemaVersion:
|
||||||
|
def test_schema_version_is_v2(self):
|
||||||
|
"""The current schema version is 2 (V2 adds auth_sessions + auth_meta)."""
|
||||||
|
assert _SCHEMA_VERSION == 2
|
||||||
|
|
||||||
|
def test_sqlalchemy_model_table_name(self):
|
||||||
|
assert AuthSessionModel.__tablename__ == "auth_sessions"
|
||||||
|
assert UserSessionModel.__tablename__ == "user_sessions"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# init_auth_db: tables + indexes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestInitAuthDbTables:
|
||||||
|
async def test_creates_auth_sessions_table(self, fresh_db: Path):
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='auth_sessions'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
|
||||||
|
async def test_creates_auth_meta_table(self, fresh_db: Path):
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='auth_meta'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
|
||||||
|
async def test_creates_user_sessions_table_for_back_compat(self, fresh_db: Path):
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='user_sessions'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
|
||||||
|
async def test_records_schema_version_in_auth_meta(self, fresh_db: Path):
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'")
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["value"] == str(_SCHEMA_VERSION)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthSessionsIndexes:
|
||||||
|
async def test_user_id_active_index(self, fresh_db: Path):
|
||||||
|
async with aiosqlite.connect(str(fresh_db)):
|
||||||
|
pass
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
indexes = await _list_index_names(db, "auth_sessions")
|
||||||
|
assert "idx_auth_sessions_user_id_active" in indexes
|
||||||
|
|
||||||
|
async def test_expires_at_index(self, fresh_db: Path):
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
indexes = await _list_index_names(db, "auth_sessions")
|
||||||
|
assert "idx_auth_sessions_expires_at" in indexes
|
||||||
|
|
||||||
|
async def test_auth_provider_index(self, fresh_db: Path):
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
indexes = await _list_index_names(db, "auth_sessions")
|
||||||
|
assert "idx_auth_sessions_auth_provider" in indexes
|
||||||
|
|
||||||
|
async def test_refresh_token_hash_unique_index(self, fresh_db: Path):
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
# SQLite stores column-level UNIQUE constraints as PRAGMA index_list
|
||||||
|
# entries with auto-generated names like sqlite_autoindex_<table>_1.
|
||||||
|
# The PRAGMA index_info on each autoindex reports the constrained columns.
|
||||||
|
cursor = await db.execute("PRAGMA index_list(auth_sessions)")
|
||||||
|
index_entries = await cursor.fetchall()
|
||||||
|
column_names: set[str] = set()
|
||||||
|
for entry in index_entries:
|
||||||
|
col_cursor = await db.execute(f"PRAGMA index_info({entry['name']})")
|
||||||
|
col_rows = await col_cursor.fetchall()
|
||||||
|
for col_row in col_rows:
|
||||||
|
column_names.add(col_row["name"])
|
||||||
|
assert "refresh_token_hash" in column_names
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# auth_sessions columns
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthSessionsColumns:
|
||||||
|
async def test_required_columns_present(self, fresh_db: Path):
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
cursor = await db.execute("PRAGMA table_info(auth_sessions)")
|
||||||
|
cols = {row[1] for row in await cursor.fetchall()}
|
||||||
|
required = {
|
||||||
|
"id",
|
||||||
|
"user_id",
|
||||||
|
"refresh_token_hash",
|
||||||
|
"device_fingerprint",
|
||||||
|
"device_label",
|
||||||
|
"ip",
|
||||||
|
"user_agent",
|
||||||
|
"auth_provider",
|
||||||
|
"created_at",
|
||||||
|
"last_active_at",
|
||||||
|
"expires_at",
|
||||||
|
"revoked",
|
||||||
|
"revoked_reason",
|
||||||
|
"previous_session_id",
|
||||||
|
}
|
||||||
|
assert required.issubset(cols)
|
||||||
|
|
||||||
|
async def test_auth_provider_default_is_local(self, fresh_db: Path):
|
||||||
|
"""Insert a row without auth_provider → column should default to 'local'."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
sid = str(uuid.uuid4())
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO auth_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
|
||||||
|
" revoked) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, 0)",
|
||||||
|
(
|
||||||
|
sid,
|
||||||
|
user_id,
|
||||||
|
"deadbeef",
|
||||||
|
"2026-01-01T00:00:00+00:00",
|
||||||
|
"2026-01-01T00:00:00+00:00",
|
||||||
|
"2026-12-31T00:00:00+00:00",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT auth_provider FROM auth_sessions WHERE id=?", (sid,))
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["auth_provider"] == "local"
|
||||||
|
|
||||||
|
async def test_revoked_default_is_false(self, fresh_db: Path):
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
sid = str(uuid.uuid4())
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO auth_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
sid,
|
||||||
|
user_id,
|
||||||
|
"beefdead",
|
||||||
|
"2026-01-01T00:00:00+00:00",
|
||||||
|
"2026-01-01T00:00:00+00:00",
|
||||||
|
"2026-12-31T00:00:00+00:00",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT revoked FROM auth_sessions WHERE id=?", (sid,))
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert bool(row["revoked"]) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# auth_session_row_to_dict
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthSessionRowToDict:
|
||||||
|
async def test_converts_all_fields(self, fresh_db: Path):
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
sid = str(uuid.uuid4())
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO auth_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, device_fingerprint, device_label, ip, "
|
||||||
|
" user_agent, auth_provider, created_at, last_active_at, expires_at, "
|
||||||
|
" revoked, revoked_reason, previous_session_id) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
sid,
|
||||||
|
user_id,
|
||||||
|
"abc123",
|
||||||
|
"fp-xyz",
|
||||||
|
"macOS Chrome 119",
|
||||||
|
"10.0.0.1",
|
||||||
|
"Mozilla/5.0",
|
||||||
|
"local",
|
||||||
|
"2026-01-01T00:00:00+00:00",
|
||||||
|
"2026-06-20T00:00:00+00:00",
|
||||||
|
"2027-01-01T00:00:00+00:00",
|
||||||
|
1,
|
||||||
|
"user_terminated",
|
||||||
|
"old-sid",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT * FROM auth_sessions WHERE id=?", (sid,))
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
|
||||||
|
d = auth_session_row_to_dict(row)
|
||||||
|
assert d["id"] == sid
|
||||||
|
assert d["user_id"] == user_id
|
||||||
|
assert d["device_fingerprint"] == "fp-xyz"
|
||||||
|
assert d["device_label"] == "macOS Chrome 119"
|
||||||
|
assert d["ip"] == "10.0.0.1"
|
||||||
|
assert d["user_agent"] == "Mozilla/5.0"
|
||||||
|
assert d["auth_provider"] == "local"
|
||||||
|
assert d["revoked"] is True
|
||||||
|
assert d["revoked_reason"] == "user_terminated"
|
||||||
|
assert d["previous_session_id"] == "old-sid"
|
||||||
|
|
||||||
|
async def test_normalizes_revoked_to_bool(self, fresh_db: Path):
|
||||||
|
"""DB stores 0/1; helper should return Python bool."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
sid = str(uuid.uuid4())
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO auth_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
|
||||||
|
" revoked) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, 0)",
|
||||||
|
(sid, user_id, "x", "t", "t", "t"),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT * FROM auth_sessions WHERE id=?", (sid,))
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
d = auth_session_row_to_dict(row)
|
||||||
|
assert isinstance(d["revoked"], bool)
|
||||||
|
assert d["revoked"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _backfill_user_sessions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackfillUserSessions:
|
||||||
|
async def test_backfills_non_revoked_v1_rows(self, fresh_db: Path):
|
||||||
|
"""On a fresh V1 install, all non-revoked rows are migrated to auth_sessions."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
await _insert_user_session(
|
||||||
|
db,
|
||||||
|
user_id=user_id,
|
||||||
|
refresh_hash="hash1",
|
||||||
|
device_info="{}",
|
||||||
|
)
|
||||||
|
await _insert_user_session(
|
||||||
|
db,
|
||||||
|
user_id=user_id,
|
||||||
|
refresh_hash="hash2",
|
||||||
|
device_info=json.dumps(
|
||||||
|
{
|
||||||
|
"fingerprint": "mac-tauri",
|
||||||
|
"label": "macOS Tauri 1.0",
|
||||||
|
"ip": "192.168.1.10",
|
||||||
|
"user_agent": "Tauri/1.0",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await _insert_user_session(
|
||||||
|
db,
|
||||||
|
user_id=user_id,
|
||||||
|
refresh_hash="hash3",
|
||||||
|
revoked=True,
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
# Force a backfill by clearing the marker + dropping auth_sessions rows
|
||||||
|
# (simulating a V1 install upgrading to V2).
|
||||||
|
await db.execute("DELETE FROM auth_sessions")
|
||||||
|
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Re-init to trigger the migration
|
||||||
|
await init_auth_db(fresh_db)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute("SELECT * FROM auth_sessions ORDER BY refresh_token_hash")
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
|
# Only the 2 non-revoked rows should have been backfilled
|
||||||
|
assert len(rows) == 2
|
||||||
|
hashes = {row["refresh_token_hash"] for row in rows}
|
||||||
|
assert hashes == {"hash1", "hash2"}
|
||||||
|
|
||||||
|
async def test_backfill_preserves_original_id(self, fresh_db: Path):
|
||||||
|
"""Backfilled rows reuse the V1 id so legacy clients match."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
original_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
await _insert_user_session(
|
||||||
|
db,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=original_id,
|
||||||
|
refresh_hash="hash-original",
|
||||||
|
)
|
||||||
|
await db.execute("DELETE FROM auth_sessions")
|
||||||
|
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await init_auth_db(fresh_db)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT id FROM auth_sessions WHERE refresh_token_hash='hash-original'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["id"] == original_id
|
||||||
|
|
||||||
|
async def test_backfill_does_not_touch_revoked_v1_rows(self, fresh_db: Path):
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
await _insert_user_session(
|
||||||
|
db,
|
||||||
|
user_id=user_id,
|
||||||
|
refresh_hash="revoked-hash",
|
||||||
|
revoked=True,
|
||||||
|
)
|
||||||
|
await db.execute("DELETE FROM auth_sessions")
|
||||||
|
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await init_auth_db(fresh_db)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT COUNT(*) AS c FROM auth_sessions WHERE refresh_token_hash='revoked-hash'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row["c"] == 0
|
||||||
|
|
||||||
|
async def test_backfill_copies_device_info_fields(self, fresh_db: Path):
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
device_info = json.dumps(
|
||||||
|
{
|
||||||
|
"fingerprint": "win-tauri-abc",
|
||||||
|
"label": "Windows Tauri 1.0",
|
||||||
|
"ip": "10.0.0.5",
|
||||||
|
"user_agent": "Tauri/2.0",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
await _insert_user_session(
|
||||||
|
db,
|
||||||
|
user_id=user_id,
|
||||||
|
refresh_hash="hash-with-device",
|
||||||
|
device_info=device_info,
|
||||||
|
)
|
||||||
|
await db.execute("DELETE FROM auth_sessions")
|
||||||
|
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await init_auth_db(fresh_db)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT device_fingerprint, device_label, ip, user_agent "
|
||||||
|
"FROM auth_sessions WHERE refresh_token_hash='hash-with-device'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row["device_fingerprint"] == "win-tauri-abc"
|
||||||
|
assert row["device_label"] == "Windows Tauri 1.0"
|
||||||
|
assert row["ip"] == "10.0.0.5"
|
||||||
|
assert row["user_agent"] == "Tauri/2.0"
|
||||||
|
|
||||||
|
async def test_backfill_handles_malformed_device_info(self, fresh_db: Path):
|
||||||
|
"""Malformed JSON in device_info → fall back to defaults."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
await _insert_user_session(
|
||||||
|
db,
|
||||||
|
user_id=user_id,
|
||||||
|
refresh_hash="bad-json",
|
||||||
|
device_info="not-json{",
|
||||||
|
)
|
||||||
|
await db.execute("DELETE FROM auth_sessions")
|
||||||
|
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await init_auth_db(fresh_db)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT device_fingerprint, device_label FROM auth_sessions "
|
||||||
|
"WHERE refresh_token_hash='bad-json'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["device_fingerprint"] == "unknown"
|
||||||
|
assert row["device_label"] == "Unknown device"
|
||||||
|
|
||||||
|
async def test_backfill_marks_done_in_auth_meta(self, fresh_db: Path):
|
||||||
|
"""After a backfill, the auth_meta marker is set so it doesn't run again."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
await _insert_user_session(db, user_id=user_id, refresh_hash="h1")
|
||||||
|
await db.execute("DELETE FROM auth_sessions")
|
||||||
|
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await init_auth_db(fresh_db)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT value FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["value"] == "done"
|
||||||
|
|
||||||
|
async def test_backfill_is_idempotent(self, fresh_db: Path):
|
||||||
|
"""Running init twice does NOT duplicate auth_sessions rows."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
for i in range(3):
|
||||||
|
await _insert_user_session(db, user_id=user_id, refresh_hash=f"hash{i}")
|
||||||
|
await db.execute("DELETE FROM auth_sessions")
|
||||||
|
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await init_auth_db(fresh_db)
|
||||||
|
await init_auth_db(fresh_db) # second run should be a no-op
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
cursor = await db.execute("SELECT COUNT(*) AS c FROM auth_sessions")
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row[0] == 3
|
||||||
|
|
||||||
|
async def test_fresh_install_marks_backfill_done_without_rows(self, fresh_db: Path):
|
||||||
|
"""A fresh V2 install (no V1 rows) still marks the backfill as done."""
|
||||||
|
# The fresh_db fixture already ran init_auth_db once.
|
||||||
|
# Re-running should be a no-op (idempotent).
|
||||||
|
await init_auth_db(fresh_db)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT value FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
assert row["value"] == "done"
|
||||||
|
|
||||||
|
async def test_backfill_preserves_expires_at(self, fresh_db: Path):
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
original_exp = "2027-06-20T12:34:56+00:00"
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO user_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, device_info, created_at, expires_at) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
(
|
||||||
|
str(uuid.uuid4()),
|
||||||
|
user_id,
|
||||||
|
"exp-hash",
|
||||||
|
"{}",
|
||||||
|
"2026-01-01T00:00:00+00:00",
|
||||||
|
original_exp,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await db.execute("DELETE FROM auth_sessions")
|
||||||
|
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await init_auth_db(fresh_db)
|
||||||
|
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT expires_at FROM auth_sessions WHERE refresh_token_hash='exp-hash'"
|
||||||
|
)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
assert row["expires_at"] == original_exp
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Active-session query pattern (covers the cap-count and list-active paths)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestActiveSessionQueries:
|
||||||
|
async def test_query_active_sessions_for_user(self, fresh_db: Path):
|
||||||
|
"""The (user_id, revoked, expires_at) index supports the active-session query."""
|
||||||
|
user_id = str(uuid.uuid4())
|
||||||
|
future = "2027-12-31T00:00:00+00:00"
|
||||||
|
past = "2020-01-01T00:00:00+00:00"
|
||||||
|
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||||
|
await _insert_user(db, user_id=user_id)
|
||||||
|
# 2 active, 1 expired, 1 revoked → 2 should match
|
||||||
|
for hash_, exp in [("a", future), ("b", future), ("c", past)]:
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO auth_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
|
||||||
|
" revoked) VALUES (?, ?, ?, ?, ?, ?, 0)",
|
||||||
|
(str(uuid.uuid4()), user_id, hash_, "t", "t", exp),
|
||||||
|
)
|
||||||
|
await db.execute(
|
||||||
|
"INSERT INTO auth_sessions "
|
||||||
|
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
|
||||||
|
" revoked, revoked_reason) "
|
||||||
|
"VALUES (?, ?, ?, ?, ?, ?, 1, 'user_terminated')",
|
||||||
|
(str(uuid.uuid4()), user_id, "d", "t", "t", future),
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
cursor = await db.execute(
|
||||||
|
"SELECT refresh_token_hash FROM auth_sessions "
|
||||||
|
"WHERE user_id = ? AND revoked = 0 AND expires_at > ?",
|
||||||
|
(user_id, "2026-06-20T00:00:00+00:00"),
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
hashes = {row["refresh_token_hash"] for row in rows}
|
||||||
|
assert hashes == {"a", "b"}
|
||||||
Loading…
Reference in New Issue