merge: 引入 U11 AuthProvider 抽象层到客户端持久化分支

This commit is contained in:
chiguyong 2026-06-21 01:28:23 +08:00
commit d42c45e5ad
13 changed files with 1889 additions and 1 deletions

View File

@ -5,10 +5,23 @@ stored as ``String(36)`` so the same schema works on both SQLite and
PostgreSQL without dialect-specific types. PostgreSQL without dialect-specific types.
Use :func:`init_auth_db` to create the tables on startup. Use :func:`init_auth_db` to create the tables on startup.
Schema versioning
-----------------
The :data:`_SCHEMA_VERSION` constant tracks the current auth DB schema. The
:func:`_backfill_user_sessions` one-time migration is gated on this version
to ensure idempotency. After a successful backfill the version is stored
in the ``auth_meta`` table so subsequent restarts are no-ops.
V2 additions (2026-06-20, Centralized Auth & Token Persistence):
- New ``auth_sessions`` table with full device/IP/audit metadata and an
``auth_provider`` column for future IdP integration traceability.
- New ``auth_meta`` table for storing schema version + migration state.
""" """
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import os import os
from collections.abc import Mapping from collections.abc import Mapping
@ -88,10 +101,16 @@ class UserApiKeyModel(Base):
class UserSessionModel(Base): class UserSessionModel(Base):
"""Refresh-token session record. """Refresh-token session record (V1, deprecated — see :class:`AuthSessionModel`).
Stores the SHA-256 hash of the refresh token (never the plaintext). Stores the SHA-256 hash of the refresh token (never the plaintext).
``revoked_at`` is set on logout / forced revocation. ``revoked_at`` is set on logout / forced revocation.
.. deprecated::
Kept for one minor version (per U10 back-compat shim) so legacy
clients holding JWTs without ``sid`` claim can still validate.
New code should use :class:`AuthSessionModel` (table ``auth_sessions``)
which carries full device/IP/audit metadata.
""" """
__tablename__ = "user_sessions" __tablename__ = "user_sessions"
@ -107,6 +126,50 @@ class UserSessionModel(Base):
revoked_at: Mapped[str | None] = mapped_column(String(64), nullable=True) revoked_at: Mapped[str | None] = mapped_column(String(64), nullable=True)
class AuthSessionModel(Base):
"""Server-side session record (V2, the primary session table going forward).
Each row corresponds to a single refresh-token issuance. The full JWT
session id (``sid`` claim) is the row's ``id`` (UUID string).
V2 fields (vs V1 ``user_sessions``):
- ``device_fingerprint`` / ``device_label``: surfaces "which device is
this session" in the admin UI.
- ``ip`` / ``user_agent``: audit trail.
- ``last_active_at``: updated on every successful refresh, used to
display "last seen" in the sessions list.
- ``revoked`` (0/1) + ``revoked_reason``: explicit revoked-state machine
with machine-readable reasons (``user_terminated``, ``password_changed``,
``admin_revoked``, ``reuse_detected``, ``session_cap_eviction``).
- ``previous_session_id``: back-pointer to the previous session id,
written on refresh rotation, for audit trail.
- ``auth_provider``: the AuthProvider that issued this session
(``local`` / ``oidc-stub`` / future ``oidc-keycloak`` / ``saml`` / ``ldap``).
Enables admin "list sessions by provider" queries and audit traceability.
"""
__tablename__ = "auth_sessions"
id: Mapped[str] = mapped_column(String(36), primary_key=True)
user_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True)
refresh_token_hash: Mapped[str] = mapped_column(
String(64), unique=True, nullable=False, index=True
)
device_fingerprint: Mapped[str] = mapped_column(String(128), nullable=False, default="unknown")
device_label: Mapped[str] = mapped_column(String(256), nullable=False, default="Unknown device")
ip: Mapped[str] = mapped_column(String(64), nullable=False, default="")
user_agent: Mapped[str] = mapped_column(String(512), nullable=False, default="")
auth_provider: Mapped[str] = mapped_column(
String(32), nullable=False, default="local", index=True
)
created_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
last_active_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
expires_at: Mapped[str] = mapped_column(String(64), nullable=False)
revoked: Mapped[bool] = mapped_column(default=False, nullable=False, index=True)
revoked_reason: Mapped[str | None] = mapped_column(String(64), nullable=True)
previous_session_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
class TerminalWhitelistUserModel(Base): class TerminalWhitelistUserModel(Base):
"""Per-user terminal command whitelist. """Per-user terminal command whitelist.
@ -244,6 +307,44 @@ CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_user_sessions_refresh_token_hash CREATE INDEX IF NOT EXISTS idx_user_sessions_refresh_token_hash
ON user_sessions(refresh_token_hash); ON user_sessions(refresh_token_hash);
-- V2: auth_sessions replaces user_sessions as the primary session table.
-- Stores device/IP/audit metadata and auth_provider for IdP traceability.
-- Per-row indexes are sized for the most common access patterns:
-- * (user_id, revoked, expires_at) cap-count, list-active, refresh-validate
-- * expires_at cleanup sweeps
-- * refresh_token_hash uniqueness + fast lookup on /auth/refresh
-- * auth_provider admin "list sessions by provider" queries
CREATE TABLE IF NOT EXISTS auth_sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
refresh_token_hash TEXT NOT NULL UNIQUE,
device_fingerprint TEXT NOT NULL DEFAULT 'unknown',
device_label TEXT NOT NULL DEFAULT 'Unknown device',
ip TEXT NOT NULL DEFAULT '',
user_agent TEXT NOT NULL DEFAULT '',
auth_provider TEXT NOT NULL DEFAULT 'local',
created_at TEXT NOT NULL,
last_active_at TEXT NOT NULL,
expires_at TEXT NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
revoked_reason TEXT,
previous_session_id TEXT
);
CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_id_active
ON auth_sessions(user_id, revoked, expires_at);
CREATE INDEX IF NOT EXISTS idx_auth_sessions_expires_at
ON auth_sessions(expires_at);
CREATE INDEX IF NOT EXISTS idx_auth_sessions_auth_provider
ON auth_sessions(auth_provider);
-- V2: auth_meta stores schema version + migration completion markers.
-- Used by init_auth_db to gate one-time migrations (idempotency).
CREATE TABLE IF NOT EXISTS auth_meta (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS terminal_whitelist_user ( CREATE TABLE IF NOT EXISTS terminal_whitelist_user (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
@ -307,12 +408,137 @@ CREATE INDEX IF NOT EXISTS idx_terminal_approvals_status
""" """
# ---------------------------------------------------------------------------
# Schema versioning + one-time migrations
# ---------------------------------------------------------------------------
# Current auth DB schema version. Bump this when adding new tables/columns
# that require data backfill or migration. The :func:`init_auth_db` function
# uses this together with the ``auth_meta.schema_version`` row to decide
# which migrations to run.
_SCHEMA_VERSION = 2
_META_SCHEMA_VERSION_KEY = "schema_version"
async def _get_meta_value(db: aiosqlite.Connection, key: str) -> str | None:
"""Read a key from the ``auth_meta`` table. Returns ``None`` if missing."""
cursor = await db.execute("SELECT value FROM auth_meta WHERE key = ?", (key,))
row = await cursor.fetchone()
return row["value"] if row else None
async def _set_meta_value(db: aiosqlite.Connection, key: str, value: str) -> None:
"""Upsert a key in the ``auth_meta`` table."""
await db.execute(
"INSERT INTO auth_meta (key, value, updated_at) VALUES (?, ?, ?) "
"ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
(key, value, _now_iso()),
)
async def _backfill_user_sessions(db: aiosqlite.Connection) -> int:
"""One-time backfill from ``user_sessions`` (V1) to ``auth_sessions`` (V2).
Runs only when ``auth_sessions`` is empty AND ``user_sessions`` has rows.
Idempotent: subsequent restarts are no-ops because we mark the backfill
as completed in ``auth_meta``.
For each non-revoked V1 session, copies:
- id (reused see note below)
- user_id
- refresh_token_hash
- device_fingerprint / device_label / ip / user_agent from the legacy
``device_info`` JSON blob (best-effort)
- created_at / expires_at
- last_active_at defaults to created_at
- revoked=0 (already filtered)
- revoked_reason=None
- auth_provider='local' (default; backfilled rows are pre-IdP)
The original ``id`` is preserved so that legacy clients holding the
old refresh_token_hash still match a row in the new table this is
what the back-compat path in U10 (``get_current_user`` for legacy
JWTs) relies on.
Returns:
Number of rows backfilled (0 if already done or nothing to backfill).
"""
# Idempotency: check the marker
if await _get_meta_value(db, "backfill_user_sessions_v1_to_v2") == "done":
return 0
cursor = await db.execute("SELECT COUNT(*) FROM auth_sessions")
(count,) = await cursor.fetchone()
if count > 0:
# auth_sessions already has data — this is a fresh V2 install, not an
# upgrade. Mark the backfill done so we never re-check.
await _set_meta_value(db, "backfill_user_sessions_v1_to_v2", "done")
await db.commit()
return 0
cursor = await db.execute(
"SELECT id, user_id, refresh_token_hash, device_info, created_at, expires_at, revoked_at "
"FROM user_sessions WHERE revoked_at IS NULL"
)
rows = await cursor.fetchall()
backfilled = 0
for row in rows:
try:
device_info = json.loads(row["device_info"]) if row["device_info"] else {}
except (json.JSONDecodeError, TypeError):
device_info = {}
await db.execute(
"INSERT OR IGNORE INTO auth_sessions "
"(id, user_id, refresh_token_hash, device_fingerprint, device_label, "
" ip, user_agent, auth_provider, created_at, last_active_at, expires_at, "
" revoked, revoked_reason, previous_session_id) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
row["id"], # reuse legacy id for back-compat with old clients
row["user_id"],
row["refresh_token_hash"],
device_info.get("fingerprint", "unknown"),
device_info.get("label", "Unknown device"),
device_info.get("ip", ""),
device_info.get("user_agent", ""),
"local", # backfilled rows are pre-IdP by definition
row["created_at"],
row["created_at"], # last_active_at defaults to created_at
row["expires_at"],
0, # not revoked (already filtered)
None,
None,
),
)
backfilled += 1
if backfilled:
logger.info(
f"Backfilled {backfilled} user_sessions rows to auth_sessions "
f"(schema v{_SCHEMA_VERSION})"
)
# Mark the backfill as completed regardless of how many rows were moved.
# (idempotency: even a 0-row backfill is "done".)
await _set_meta_value(db, "backfill_user_sessions_v1_to_v2", "done")
await db.commit()
return backfilled
async def init_auth_db(db_path: str | Path | None = None) -> Path: async def init_auth_db(db_path: str | Path | None = None) -> Path:
"""Create auth tables if they do not exist. """Create auth tables if they do not exist.
Uses aiosqlite directly (no SQLAlchemy engine) for a lightweight, Uses aiosqlite directly (no SQLAlchemy engine) for a lightweight,
zero-config bootstrap that mirrors :class:`SqliteConversationStore`. zero-config bootstrap that mirrors :class:`SqliteConversationStore`.
On startup, this function:
1. Creates all tables and indexes from :data:`_SCHEMA_SQL` (idempotent).
2. Records the current :data:`_SCHEMA_VERSION` in ``auth_meta``.
3. Runs any pending one-time migrations (currently: V1 V2 backfill
from ``user_sessions`` to ``auth_sessions``).
Args: Args:
db_path: Path to the SQLite file. Defaults to db_path: Path to the SQLite file. Defaults to
:data:`DEFAULT_AUTH_DB_PATH` (``data/auth.db`` under the project :data:`DEFAULT_AUTH_DB_PATH` (``data/auth.db`` under the project
@ -325,8 +551,19 @@ async def init_auth_db(db_path: str | Path | None = None) -> Path:
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
async with aiosqlite.connect(str(path)) as db: async with aiosqlite.connect(str(path)) as db:
db.row_factory = aiosqlite.Row
await db.execute("PRAGMA journal_mode=WAL") await db.execute("PRAGMA journal_mode=WAL")
await db.executescript(_SCHEMA_SQL) await db.executescript(_SCHEMA_SQL)
# Record the current schema version (idempotent upsert).
current = await _get_meta_value(db, _META_SCHEMA_VERSION_KEY)
if current != str(_SCHEMA_VERSION):
await _set_meta_value(db, _META_SCHEMA_VERSION_KEY, str(_SCHEMA_VERSION))
logger.info(f"Auth DB schema version set to {_SCHEMA_VERSION}")
# Run pending migrations (each is internally idempotent).
await _backfill_user_sessions(db)
await db.commit() await db.commit()
logger.info(f"Auth DB initialized at {path}") logger.info(f"Auth DB initialized at {path}")
@ -353,3 +590,28 @@ def user_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any
"last_login_at": row["last_login_at"], "last_login_at": row["last_login_at"],
"created_by": row["created_by"], "created_by": row["created_by"],
} }
def auth_session_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
"""Convert an ``auth_sessions`` row into a JSON-safe dict.
The ``revoked`` field is normalized to a Python ``bool`` (the DB stores
0/1). The full set of audit fields is included so the admin UI and
API responses can surface device/IP/last-active information without
a separate lookup.
"""
return {
"id": row["id"],
"user_id": row["user_id"],
"device_fingerprint": row["device_fingerprint"],
"device_label": row["device_label"],
"ip": row["ip"],
"user_agent": row["user_agent"],
"auth_provider": row["auth_provider"],
"created_at": row["created_at"],
"last_active_at": row["last_active_at"],
"expires_at": row["expires_at"],
"revoked": bool(row["revoked"]),
"revoked_reason": row["revoked_reason"],
"previous_session_id": row["previous_session_id"],
}

View File

@ -0,0 +1,115 @@
"""AuthProvider implementations and dependency-injection factory.
Public surface:
- :class:`AuthProvider` (re-exported from :mod:`.base`) the protocol
every backend must satisfy.
- :class:`LocalAuthProvider` SQLite + bcrypt default.
- :class:`StubOIDCProvider` interface placeholder for the future
OIDC integration.
- :class:`User` provider-agnostic user value object.
- :func:`get_auth_provider` DI factory used by routes via
``Depends(get_auth_provider)``.
- :func:`reset_auth_provider` clears the lru_cache singleton
(used in tests + when the auth provider config changes at runtime).
Configuration
-------------
The provider is selected via the ``AGENTKIT_AUTH_PROVIDER`` environment
variable (default: ``"local"``). When the future ``auth.provider``
field is added to ``agentkit.yaml`` it should override the env var.
Adding a new provider
---------------------
1. Create ``auth/providers/<name>.py`` with a class that satisfies
the :class:`AuthProvider` Protocol (i.e. has ``name: str`` and the
four async methods).
2. Import it here and add a branch to :func:`get_auth_provider`.
3. Add it to ``AGENTKIT_AUTH_PROVIDER`` enum in the deployment docs.
"""
from __future__ import annotations
import logging
import os
from functools import lru_cache
from pathlib import Path
from .base import AuthProvider
from .exceptions import AuthProviderError, InvalidCredentials, ProviderNotImplemented
from .local import LocalAuthProvider
from .oidc_stub import StubOIDCProvider
from .user import User
logger = logging.getLogger(__name__)
# Re-exports for ergonomic imports: `from agentkit.server.auth.providers import AuthProvider`
__all__ = [
"AuthProvider",
"AuthProviderError",
"InvalidCredentials",
"LocalAuthProvider",
"ProviderNotImplemented",
"StubOIDCProvider",
"User",
"get_auth_provider",
"reset_auth_provider",
]
# Default provider name. Configurable via AGENTKIT_AUTH_PROVIDER env var
# (overridable when the auth.provider field is added to agentkit.yaml).
DEFAULT_AUTH_PROVIDER = "local"
def _resolve_provider_name() -> str:
"""Return the configured provider name, defaulting to 'local'."""
name = os.environ.get("AGENTKIT_AUTH_PROVIDER", DEFAULT_AUTH_PROVIDER).strip()
if not name:
return DEFAULT_AUTH_PROVIDER
return name
def _resolve_db_path() -> Path:
"""Return the auth DB path (overridable for tests via env var)."""
from ..models import DEFAULT_AUTH_DB_PATH
env = os.environ.get("AGENTKIT_AUTH_DB")
if env:
return Path(env)
return DEFAULT_AUTH_DB_PATH
@lru_cache(maxsize=1)
def get_auth_provider() -> AuthProvider:
"""Return the configured :class:`AuthProvider` (memoized singleton).
The ``lru_cache`` is process-local and the value is resolved on
first call. Use :func:`reset_auth_provider` to clear the cache
(e.g. when the configuration changes at runtime, or in test
fixtures that need a fresh provider with a different db_path).
Raises:
ValueError: if the configured provider name is not recognized.
"""
name = _resolve_provider_name()
if name == "local":
return LocalAuthProvider(db_path=_resolve_db_path())
if name == "oidc-stub":
return StubOIDCProvider()
raise ValueError(
f"unknown auth provider: {name!r}. "
f"Supported providers: 'local', 'oidc-stub'. "
f"Set AGENTKIT_AUTH_PROVIDER or update agentkit.yaml's auth.provider field."
)
def reset_auth_provider() -> None:
"""Clear the memoized :class:`AuthProvider` singleton.
Use in tests (e.g. in a fixture's teardown) or in code paths that
change the auth provider configuration at runtime and need a
re-resolution.
"""
get_auth_provider.cache_clear()

View File

@ -0,0 +1,109 @@
"""AuthProvider protocol — pluggable authentication backend contract.
The :class:`AuthProvider` Protocol defines the minimal surface area the
auth subsystem needs from any authentication backend (Local today, OIDC
/ SAML / LDAP tomorrow). Routes and admin endpoints call only these
methods and never touch the underlying user store directly. Adding a new
IdP integration is a matter of writing a new adapter that satisfies
this Protocol.
Design notes:
- **No sync surface for password verification**: the only entry point is
:meth:`authenticate` which takes the plaintext password. Internally each
adapter chooses how to verify (bcrypt locally, redirect to IdP, etc.).
- **No persistence responsibility**: providers return :class:`User` objects
but do not manage sessions, refresh tokens, or session state. That is
the route + SessionService's job (see :mod:`agentkit.server.auth.session`).
- **Audit-friendly**: :attr:`name` is written to ``auth_sessions.auth_provider``
on every login, so the source of every session is traceable. This is
what enables "list sessions by provider" admin queries and future
cross-IdP policy enforcement.
- **runtime_checkable**: the Protocol is decorated so that
``isinstance(provider, AuthProvider)`` works at runtime, enabling
defensive checks in tests and in DI wiring.
"""
from __future__ import annotations
from typing import Protocol, runtime_checkable
from .user import User
@runtime_checkable
class AuthProvider(Protocol):
"""All authentication backends must implement this surface.
The route layer only calls the methods below. It is intentionally
unaware of whether the backing store is SQLite, an OIDC IdP, LDAP,
or anything else.
"""
name: str
"""Identifier for this provider, written to ``auth_sessions.auth_provider``.
Convention: ``local`` for the local SQLite + bcrypt backend,
``oidc-<idp-name>`` for OIDC (e.g. ``oidc-keycloak``,
``oidc-feishu``), ``saml`` / ``ldap`` for the other planned
integrations. This value is the stable contract for audit
traceability do not rename without a migration plan.
"""
async def authenticate(self, *, username: str, password: str) -> User:
"""Verify ``username`` + ``password`` and return the :class:`User`.
Args:
username: The submitted username (or, for OIDC, the IdP
subject id but that's a future adapter's concern).
password: The submitted plaintext password.
Returns:
The matched :class:`User` on success.
Raises:
InvalidCredentials: if the user does not exist, the password
is wrong, or the user is inactive. Callers MUST NOT
distinguish between these three cases in the error
message returned to the client (timing-attack /
username-enumeration mitigation).
"""
...
async def get_user_by_id(self, user_id: str) -> User | None:
"""Look up a :class:`User` by primary key.
Used by:
- Admin endpoints that need to display user info by id
- Session validation in the cold-start / whoami path
- Audit log enrichment
Returns ``None`` if no user exists with this id (or the user
is inactive convention: inactive users are "not found" from
the auth layer's perspective).
"""
...
async def sync_user_attributes(self, user_id: str) -> None:
"""Refresh user attributes (department / email / title) from the source of truth.
- :class:`LocalAuthProvider`: no-op (attributes are managed locally).
- OIDC adapter (future): pull the latest profile from the IdP and
write back to the local ``users`` table.
Implementations that have nothing to sync should still define
this method (returning ``None``) so the contract is uniform.
"""
...
async def revoke_user(self, user_id: str) -> None:
"""Disable a user account (e.g. on termination or lock-out).
- :class:`LocalAuthProvider`: ``UPDATE users SET is_active = 0``
- OIDC adapter (future): call the IdP's disable API
The admin endpoint that calls this does NOT also need to
revoke the user's active sessions — that is the
:class:`SessionService`'s job, called separately.
"""
...

View File

@ -0,0 +1,26 @@
"""Exceptions raised by AuthProvider implementations.
These are caught by the route layer and translated to HTTP responses
(401 for :class:`InvalidCredentials`, 501 for :class:`ProviderNotImplemented`).
"""
class AuthProviderError(Exception):
"""Base class for all auth provider errors."""
class InvalidCredentials(AuthProviderError):
"""Raised when username / password is wrong, or the user is inactive.
Translates to HTTP 401. The error message MUST NOT leak which of
"user not found" vs "wrong password" vs "user inactive" failed
that is a username enumeration risk.
"""
class ProviderNotImplemented(AuthProviderError):
"""Raised when a configured provider is not yet implemented.
Translates to HTTP 501. Used by :class:`StubOIDCProvider` and any
future adapter that is registered but has no real implementation.
"""

View File

@ -0,0 +1,163 @@
"""Local authentication provider — SQLite + bcrypt.
The default :class:`AuthProvider` implementation. Authenticates users
against the local ``users`` table in the auth SQLite database using
the bcrypt cost=12 password hash (see :mod:`agentkit.server.auth.password`).
This is a behavioral equivalent of the password-verification code that
previously lived inline in ``routes/auth.py`` moved here so the route
layer can call a single :meth:`authenticate` method regardless of which
backend is configured.
Future-IdP note
---------------
When the organization moves to OIDC / SAML / LDAP, this class does not
need to be deleted. It can continue to serve as a "local emergency
account" provider, configurable side-by-side with the IdP provider via
a future composite / multi-provider setup. For now it is the only
implementation.
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
import aiosqlite
from ..models import DEFAULT_AUTH_DB_PATH
from ..password import verify_password
from .user import User
logger = logging.getLogger(__name__)
def _resolve_db_path() -> Path:
"""Resolve the auth DB path with runtime env-var priority.
The :data:`models.DEFAULT_AUTH_DB_PATH` constant is captured at
module-import time and therefore cannot see test-time env mutations.
Re-reading ``AGENTKIT_AUTH_DB`` here keeps the provider
"test-friendly" (tests can ``monkeypatch.setenv`` before constructing
the provider) without giving up the default path when no env is set.
"""
env = os.environ.get("AGENTKIT_AUTH_DB")
if env:
return Path(env)
return DEFAULT_AUTH_DB_PATH
class LocalAuthProvider:
"""AuthProvider backed by the local SQLite ``users`` table + bcrypt.
Args:
db_path: Path to the auth DB. Defaults to the value of the
``AGENTKIT_AUTH_DB`` env var, falling back to
:data:`agentkit.server.auth.models.DEFAULT_AUTH_DB_PATH`.
Each operation opens a short-lived aiosqlite connection;
the existing route layer follows the same pattern, so no
connection pooling is introduced here. If a future
deployment needs pooling, swap in a ``db_factory: Callable``
here without changing the protocol.
"""
name = "local"
def __init__(self, db_path: str | Path | None = None) -> None:
self._db_path = Path(db_path) if db_path is not None else _resolve_db_path()
@property
def db_path(self) -> Path:
return self._db_path
async def authenticate(self, *, username: str, password: str) -> User:
"""Verify the username + password against the local users table.
Raises :class:`InvalidCredentials` on every failure mode
(unknown user, wrong password, inactive user) with the same
error message preventing username enumeration via error
inspection. Constant-time-equivalent behavior is also
ensured by always running a real bcrypt computation
(against a dummy hash) when the user does not exist,
matching the timing of the "user exists, wrong password" path.
"""
from .exceptions import InvalidCredentials # local import to avoid cycle at module load
async with aiosqlite.connect(str(self._db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT id, username, email, password_hash, role, is_active, "
"is_terminal_authorized, is_server_terminal_authorized, "
"created_at, updated_at, last_login_at, created_by "
"FROM users WHERE username = ?",
(username,),
)
row = await cursor.fetchone()
if row is None or not bool(row["is_active"]):
# Run a real bcrypt verification against a valid-format dummy
# hash so the response time matches the "user exists, wrong
# password" path (~250ms). Prevents username enumeration via
# timing. The dummy hash is invalid (won't match any password)
# but has the right shape so bcrypt.checkpw doesn't short-circuit.
_DUMMY_BCRYPT_HASH = "$2b$12$abcdefghijklmnopqrstuuABCDEFGHIJKLMNOPQRSTUVWXYZ0123"
verify_password(password, _DUMMY_BCRYPT_HASH)
raise InvalidCredentials("invalid username or password")
if not verify_password(password, row["password_hash"]):
raise InvalidCredentials("invalid username or password")
return _row_to_user(row)
async def get_user_by_id(self, user_id: str) -> User | None:
"""Look up a user by id. Returns ``None`` if not found or inactive."""
async with aiosqlite.connect(str(self._db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT id, username, email, password_hash, role, is_active, "
"is_terminal_authorized, is_server_terminal_authorized, "
"created_at, updated_at, last_login_at, created_by "
"FROM users WHERE id = ? AND is_active = 1",
(user_id,),
)
row = await cursor.fetchone()
return _row_to_user(row) if row else None
async def sync_user_attributes(self, user_id: str) -> None:
"""No-op: local provider has no upstream source of truth to sync from."""
return None
async def revoke_user(self, user_id: str) -> None:
"""Disable a user account (``is_active = 0``)."""
async with aiosqlite.connect(str(self._db_path)) as db:
await db.execute(
"UPDATE users SET is_active = 0, updated_at = ? WHERE id = ?",
(_now_iso(), user_id),
)
await db.commit()
logger.info(f"Revoked user {user_id} via LocalAuthProvider")
def _row_to_user(row: aiosqlite.Row) -> User:
"""Convert a ``users`` row to a :class:`User` value object."""
return User(
id=row["id"],
username=row["username"],
email=row["email"],
role=row["role"],
is_active=bool(row["is_active"]),
is_terminal_authorized=bool(row["is_terminal_authorized"]),
is_server_terminal_authorized=bool(row["is_server_terminal_authorized"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
last_login_at=row["last_login_at"],
created_by=row["created_by"],
)
def _now_iso() -> str:
"""Return current UTC time as ISO 8601 string."""
from datetime import datetime, timezone
return datetime.now(timezone.utc).isoformat()

View File

@ -0,0 +1,83 @@
"""Stub OIDC provider — interface contract placeholder.
This is a deliberately unimplemented :class:`AuthProvider` whose only
job is to fail loudly when ``auth.provider: oidc-stub`` is configured
without the real OIDC integration being in place. It exists so that
the configuration switch and the dependency-injection wiring can be
exercised end-to-end before the IdP integration work is scheduled.
Future OIDC integration checklist
---------------------------------
When the real OIDC integration lands, replace this file with
:file:`auth/providers/oidc.py` implementing:
- [ ] :meth:`authenticate` accept (username, password)? OR redirect-flow?
OIDC is fundamentally a redirect flow, NOT a password POST. The
right answer is probably a separate ``/auth/oauth/{provider}/redirect``
and ``/auth/oauth/{provider}/callback`` route pair that bypasses
the password-based login. The :class:`AuthProvider` interface is
only for backends that accept username + password; an OIDC
adapter may need to extend the protocol or expose a separate
callback-handler hook.
- [ ] :meth:`get_user_by_id` query local cache (auto-provisioned
OIDC users live in the local ``users`` table after first login).
- [ ] :meth:`sync_user_attributes` pull latest profile from the IdP
on each successful login (department / title / email).
- [ ] :meth:`revoke_user` call the IdP's account-disable API (e.g.
Keycloak's ``PUT /admin/realms/{realm}/users/{id}/disable``).
- [ ] State cache for OAuth ``state`` parameter (Redis, TTL 5 min).
- [ ] First-login user auto-provisioning policy: just-in-time
creation / reject / invite-only.
Until then, configuring ``auth.provider: oidc-stub`` and starting
the server lets a smoke test confirm:
1. The DI factory resolves the right class.
2. The route layer calls :meth:`authenticate` and surfaces 501.
3. ``auth_sessions.auth_provider='oidc-stub'`` would be written
for any session that *did* get created (none, in this stub).
"""
from __future__ import annotations
from .exceptions import ProviderNotImplemented
from .user import User
class StubOIDCProvider:
"""Unimplemented OIDC adapter. All methods raise :class:`ProviderNotImplemented`.
The class satisfies the :class:`AuthProvider` Protocol structurally
(it has ``name`` and the four methods), but every method body is
a deliberate failure. This is on purpose: configuring
``auth.provider: oidc-stub`` and then successfully logging in would
be a security bug.
"""
name = "oidc-stub"
async def authenticate(self, *, username: str, password: str) -> User:
raise ProviderNotImplemented(
"OIDC authentication is not implemented yet. "
"Use auth.provider: local in agentkit.yaml, or implement "
"auth/providers/oidc.py (see the checklist in this module's "
"docstring)."
)
async def get_user_by_id(self, user_id: str) -> User | None:
raise ProviderNotImplemented(
"StubOIDCProvider.get_user_by_id is not implemented. "
"See auth/providers/oidc_stub.py for the future OIDC checklist."
)
async def sync_user_attributes(self, user_id: str) -> None:
raise ProviderNotImplemented(
"StubOIDCProvider.sync_user_attributes is not implemented. "
"See auth/providers/oidc_stub.py for the future OIDC checklist."
)
async def revoke_user(self, user_id: str) -> None:
raise ProviderNotImplemented(
"StubOIDCProvider.revoke_user is not implemented. "
"See auth/providers/oidc_stub.py for the future OIDC checklist."
)

View File

@ -0,0 +1,42 @@
"""Provider-agnostic user data model.
This is the value object the :class:`AuthProvider` Protocol returns
from :meth:`AuthProvider.authenticate` and :meth:`AuthProvider.get_user_by_id`.
It is intentionally NOT the SQLAlchemy ``UserModel`` ORM class so that
future IdP adapters (OIDC / SAML / LDAP) can construct one without
touching a DB.
The route layer converts ``User`` ``UserResponse`` (the public
API payload) at the boundary; ``User`` itself stays an internal
type.
"""
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, EmailStr
class User(BaseModel):
"""A user as known to the auth subsystem.
Distinct from the SQLAlchemy ``UserModel`` ORM class because:
- The auth provider may not have a local row at all (e.g. an
OIDC user logged in via SSO who hasn't been provisioned yet
that's a future concern but the data model must allow it).
- IdP-sourced attributes (department / title) are not in the
SQLAlchemy model.
"""
model_config = ConfigDict(extra="forbid")
id: str
username: str
email: EmailStr
role: str = "member"
is_active: bool = True
is_terminal_authorized: bool = False
is_server_terminal_authorized: bool = False
created_at: str
updated_at: str
last_login_at: str | None = None
created_by: str | None = None

View File

View File

View File

@ -0,0 +1,112 @@
"""Tests for the AuthProvider protocol and the DI factory (U11)."""
from __future__ import annotations
import pytest
from agentkit.server.auth.providers import (
AuthProvider,
LocalAuthProvider,
StubOIDCProvider,
get_auth_provider,
reset_auth_provider,
)
from agentkit.server.auth.providers.base import AuthProvider as AuthProviderFromBase
from agentkit.server.auth.providers.exceptions import (
AuthProviderError,
InvalidCredentials,
ProviderNotImplemented,
)
class TestProtocolConformance:
"""The runtime_checkable Protocol accepts both real implementations."""
def test_local_passes_isinstance_check(self):
provider = LocalAuthProvider()
assert isinstance(provider, AuthProvider)
assert isinstance(provider, AuthProviderFromBase)
def test_stub_passes_isinstance_check(self):
provider = StubOIDCProvider()
assert isinstance(provider, AuthProvider)
assert isinstance(provider, AuthProviderFromBase)
class TestProviderNames:
def test_local_provider_name(self):
assert LocalAuthProvider.name == "local"
def test_stub_oidc_provider_name(self):
assert StubOIDCProvider.name == "oidc-stub"
class TestDiFactory:
"""``get_auth_provider`` is a memoized singleton, env-driven."""
def setup_method(self):
"""Reset cache and env before each test."""
reset_auth_provider()
self._saved_env = {}
for key in ("AGENTKIT_AUTH_PROVIDER",):
self._saved_env[key] = __import__("os").environ.pop(key, None)
def teardown_method(self):
for key, value in self._saved_env.items():
if value is not None:
__import__("os").environ[key] = value
reset_auth_provider()
def test_default_provider_is_local(self):
provider = get_auth_provider()
assert isinstance(provider, LocalAuthProvider)
assert provider.name == "local"
def test_oidc_stub_provider(self, monkeypatch):
monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "oidc-stub")
reset_auth_provider()
provider = get_auth_provider()
assert isinstance(provider, StubOIDCProvider)
assert provider.name == "oidc-stub"
def test_unknown_provider_raises_value_error(self, monkeypatch):
monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "ldap-unknown")
reset_auth_provider()
with pytest.raises(ValueError, match="unknown auth provider"):
get_auth_provider()
def test_factory_is_memoized(self):
first = get_auth_provider()
second = get_auth_provider()
assert first is second
def test_reset_clears_cache(self, monkeypatch):
first = get_auth_provider()
reset_auth_provider()
monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "oidc-stub")
second = get_auth_provider()
assert first is not second
assert isinstance(second, StubOIDCProvider)
class TestExceptionHierarchy:
def test_invalid_credentials_inherits_auth_provider_error(self):
assert issubclass(InvalidCredentials, AuthProviderError)
assert issubclass(InvalidCredentials, Exception)
def test_provider_not_implemented_inherits_auth_provider_error(self):
assert issubclass(ProviderNotImplemented, AuthProviderError)
assert issubclass(ProviderNotImplemented, Exception)
def test_invalid_credentials_can_be_raised_and_caught(self):
with pytest.raises(InvalidCredentials):
raise InvalidCredentials("test message")
def test_invalid_credentials_caught_as_base(self):
"""Route layer catches AuthProviderError to handle all provider errors."""
with pytest.raises(AuthProviderError):
raise InvalidCredentials("test")
def test_provider_not_implemented_caught_as_base(self):
with pytest.raises(AuthProviderError):
raise ProviderNotImplemented("test")

View File

@ -0,0 +1,254 @@
"""Tests for LocalAuthProvider (U11 — concrete Local implementation)."""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from pathlib import Path
import aiosqlite
import pytest
from agentkit.server.auth.models import init_auth_db
from agentkit.server.auth.password import hash_password
from agentkit.server.auth.providers import LocalAuthProvider
from agentkit.server.auth.providers.exceptions import InvalidCredentials
from agentkit.server.auth.providers.user import User
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def auth_db_with_users(tmp_path: Path) -> dict:
"""Create a fresh auth DB with two users: one active, one inactive.
Returns a dict with user info + the db path.
"""
db_path = tmp_path / "auth.db"
await init_auth_db(db_path)
now_iso = datetime.now(timezone.utc).isoformat()
active_id = str(uuid.uuid4())
inactive_id = str(uuid.uuid4())
active_pw = "correct-horse-battery-staple"
inactive_pw = "disabled-user-pw"
async with aiosqlite.connect(str(db_path)) as db:
await db.execute(
"INSERT INTO users "
"(id, username, email, password_hash, role, is_active, "
" is_terminal_authorized, is_server_terminal_authorized, "
" created_at, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
active_id,
"alice",
"alice@example.com",
hash_password(active_pw),
"member",
1,
0,
0,
now_iso,
now_iso,
),
)
await db.execute(
"INSERT INTO users "
"(id, username, email, password_hash, role, is_active, "
" is_terminal_authorized, is_server_terminal_authorized, "
" created_at, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
inactive_id,
"bob_inactive",
"bob@example.com",
hash_password(inactive_pw),
"member",
0,
0,
0,
now_iso,
now_iso,
),
)
await db.commit()
return {
"db_path": db_path,
"active": {
"id": active_id,
"username": "alice",
"password": active_pw,
"email": "alice@example.com",
"role": "member",
},
"inactive": {
"id": inactive_id,
"username": "bob_inactive",
"password": inactive_pw,
},
}
@pytest.fixture
def provider(auth_db_with_users: dict) -> LocalAuthProvider:
return LocalAuthProvider(db_path=auth_db_with_users["db_path"])
# ---------------------------------------------------------------------------
# authenticate
# ---------------------------------------------------------------------------
class TestAuthenticate:
async def test_valid_credentials_returns_user(
self, provider: LocalAuthProvider, auth_db_with_users: dict
):
user = await provider.authenticate(
username=auth_db_with_users["active"]["username"],
password=auth_db_with_users["active"]["password"],
)
assert isinstance(user, User)
assert user.id == auth_db_with_users["active"]["id"]
assert user.username == "alice"
assert user.email == "alice@example.com"
assert user.role == "member"
assert user.is_active is True
async def test_wrong_password_raises_invalid_credentials(
self, provider: LocalAuthProvider, auth_db_with_users: dict
):
with pytest.raises(InvalidCredentials):
await provider.authenticate(
username=auth_db_with_users["active"]["username"],
password="definitely-not-the-password",
)
async def test_unknown_user_raises_invalid_credentials(self, provider: LocalAuthProvider):
with pytest.raises(InvalidCredentials):
await provider.authenticate(username="nonexistent", password="anything")
async def test_inactive_user_raises_invalid_credentials(
self, provider: LocalAuthProvider, auth_db_with_users: dict
):
with pytest.raises(InvalidCredentials):
await provider.authenticate(
username=auth_db_with_users["inactive"]["username"],
password=auth_db_with_users["inactive"]["password"],
)
async def test_error_message_does_not_leak_username_existence(
self, provider: LocalAuthProvider, auth_db_with_users: dict
):
"""Wrong-password and unknown-user errors must have the same message."""
try:
await provider.authenticate(
username=auth_db_with_users["active"]["username"],
password="wrong",
)
except InvalidCredentials as e1:
wrong_pw_msg = str(e1)
try:
await provider.authenticate(username="nobody-here", password="x")
except InvalidCredentials as e2:
unknown_msg = str(e2)
assert wrong_pw_msg == unknown_msg
# ---------------------------------------------------------------------------
# get_user_by_id
# ---------------------------------------------------------------------------
class TestGetUserById:
async def test_returns_user_for_active_id(
self, provider: LocalAuthProvider, auth_db_with_users: dict
):
user = await provider.get_user_by_id(auth_db_with_users["active"]["id"])
assert user is not None
assert user.id == auth_db_with_users["active"]["id"]
assert user.username == "alice"
async def test_returns_none_for_inactive_user(
self, provider: LocalAuthProvider, auth_db_with_users: dict
):
user = await provider.get_user_by_id(auth_db_with_users["inactive"]["id"])
assert user is None
async def test_returns_none_for_unknown_id(self, provider: LocalAuthProvider):
user = await provider.get_user_by_id(str(uuid.uuid4()))
assert user is None
# ---------------------------------------------------------------------------
# sync_user_attributes
# ---------------------------------------------------------------------------
class TestSyncUserAttributes:
async def test_is_noop(self, provider: LocalAuthProvider, auth_db_with_users: dict):
"""Local provider has no upstream to sync from — must succeed with no effect."""
result = await provider.sync_user_attributes(auth_db_with_users["active"]["id"])
assert result is None
async def test_noop_does_not_throw_for_unknown_user(self, provider: LocalAuthProvider):
"""sync_user_attributes is fire-and-forget — no existence check."""
# Should NOT raise even though the id doesn't exist
await provider.sync_user_attributes(str(uuid.uuid4()))
# ---------------------------------------------------------------------------
# revoke_user
# ---------------------------------------------------------------------------
class TestRevokeUser:
async def test_sets_is_active_to_zero(
self, provider: LocalAuthProvider, auth_db_with_users: dict, tmp_path: Path
):
await provider.revoke_user(auth_db_with_users["active"]["id"])
async with aiosqlite.connect(str(auth_db_with_users["db_path"])) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT is_active FROM users WHERE id = ?",
(auth_db_with_users["active"]["id"],),
)
row = await cursor.fetchone()
assert bool(row["is_active"]) is False
async def test_revoked_user_can_no_longer_authenticate(
self, provider: LocalAuthProvider, auth_db_with_users: dict
):
await provider.revoke_user(auth_db_with_users["active"]["id"])
with pytest.raises(InvalidCredentials):
await provider.authenticate(
username=auth_db_with_users["active"]["username"],
password=auth_db_with_users["active"]["password"],
)
async def test_revoke_unknown_user_does_not_raise(self, provider: LocalAuthProvider):
"""``UPDATE ... WHERE id = ?`` with no match is a no-op, not an error."""
await provider.revoke_user(str(uuid.uuid4()))
# ---------------------------------------------------------------------------
# default db_path
# ---------------------------------------------------------------------------
class TestDefaultDbPath:
def test_default_db_path_uses_models_default(self, tmp_path: Path, monkeypatch):
"""If no db_path is given, the provider should use the module default.
The default may resolve to a path that does not exist on a test
machine we only assert that the property returns a Path, not
that the file exists.
"""
monkeypatch.setenv("AGENTKIT_AUTH_DB", str(tmp_path / "default.db"))
provider = LocalAuthProvider()
assert isinstance(provider.db_path, Path)
assert str(provider.db_path) == str(tmp_path / "default.db")

View File

@ -0,0 +1,69 @@
"""Tests for StubOIDCProvider (U11 — interface placeholder)."""
from __future__ import annotations
import pytest
from agentkit.server.auth.providers import StubOIDCProvider
from agentkit.server.auth.providers.exceptions import (
AuthProviderError,
ProviderNotImplemented,
)
class TestStubAuthenticate:
async def test_raises_provider_not_implemented(self):
provider = StubOIDCProvider()
with pytest.raises(ProviderNotImplemented):
await provider.authenticate(username="alice", password="hunter2")
async def test_error_message_mentions_oidc(self):
"""The error message must guide the operator to the right next step."""
provider = StubOIDCProvider()
with pytest.raises(ProviderNotImplemented) as exc:
await provider.authenticate(username="alice", password="x")
msg = str(exc.value)
assert "OIDC" in msg
assert "not implemented" in msg.lower()
class TestStubGetUserById:
async def test_raises_provider_not_implemented(self):
provider = StubOIDCProvider()
with pytest.raises(ProviderNotImplemented):
await provider.get_user_by_id("any-id")
class TestStubSyncUserAttributes:
async def test_raises_provider_not_implemented(self):
provider = StubOIDCProvider()
with pytest.raises(ProviderNotImplemented):
await provider.sync_user_attributes("any-id")
class TestStubRevokeUser:
async def test_raises_provider_not_implemented(self):
provider = StubOIDCProvider()
with pytest.raises(ProviderNotImplemented):
await provider.revoke_user("any-id")
class TestAllErrorsAreAuthProviderError:
"""Every stub method's error should be catchable as AuthProviderError."""
@pytest.mark.parametrize(
"method,call_args",
[
("authenticate", ()),
("get_user_by_id", ("id",)),
("sync_user_attributes", ("id",)),
("revoke_user", ("id",)),
],
)
async def test_catchable_as_base(self, method: str, call_args: tuple):
provider = StubOIDCProvider()
with pytest.raises(AuthProviderError):
if method == "authenticate":
await provider.authenticate(username="u", password="p")
else:
await getattr(provider, method)(*call_args)

View File

@ -0,0 +1,653 @@
"""Unit tests for auth.models (U1 — V2 schema + backfill).
Covers:
- ``auth_sessions`` table creation (columns + indexes)
- ``auth_meta`` table creation
- ``_SCHEMA_VERSION`` constant value
- ``auth_session_row_to_dict`` field-by-field conversion
- ``_backfill_user_sessions`` one-time migration (V1 V2)
- ``init_auth_db`` idempotency (subsequent runs are no-ops)
- ``auth_provider`` column default value
- Index presence (verified via ``PRAGMA index_list``)
"""
from __future__ import annotations
import json
import uuid
from datetime import datetime, timezone
from pathlib import Path
import aiosqlite
import pytest
from agentkit.server.auth.models import (
AuthSessionModel,
UserSessionModel,
_SCHEMA_VERSION,
auth_session_row_to_dict,
init_auth_db,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
"""A brand-new auth DB on a fresh path (no data)."""
db_path = tmp_path / "auth.db"
await init_auth_db(db_path)
return db_path
async def _insert_user(db: aiosqlite.Connection, *, user_id: str | None = None) -> str:
"""Insert a minimal user row and return its id."""
user_id = user_id or str(uuid.uuid4())
now_iso = datetime.now(timezone.utc).isoformat()
await db.execute(
"INSERT INTO users "
"(id, username, email, password_hash, role, is_active, "
" is_terminal_authorized, is_server_terminal_authorized, "
" created_at, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
user_id,
f"user-{user_id[:8]}",
f"{user_id[:8]}@example.com",
"$2b$12$placeholder.hash.placeholder.hash.placeholder.hash",
"member",
1,
0,
0,
now_iso,
now_iso,
),
)
return user_id
async def _insert_user_session(
db: aiosqlite.Connection,
*,
user_id: str,
session_id: str | None = None,
refresh_hash: str | None = None,
device_info: str = "{}",
revoked: bool = False,
) -> str:
"""Insert a V1 ``user_sessions`` row and return its id."""
session_id = session_id or str(uuid.uuid4())
refresh_hash = refresh_hash or uuid.uuid4().hex
now_iso = datetime.now(timezone.utc).isoformat()
await db.execute(
"INSERT INTO user_sessions "
"(id, user_id, refresh_token_hash, device_info, created_at, expires_at, revoked_at) "
"VALUES (?, ?, ?, ?, ?, ?, ?)",
(
session_id,
user_id,
refresh_hash,
device_info,
now_iso,
now_iso,
now_iso if revoked else None,
),
)
return session_id
async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]:
"""Return the set of index names for a table.
Sets ``row_factory`` on the connection so we can address columns by name.
PRAGMA results in aiosqlite come back as plain tuples unless a row factory
is configured.
"""
db.row_factory = aiosqlite.Row
cursor = await db.execute(f"PRAGMA index_list({table})")
rows = await cursor.fetchall()
return {row["name"] for row in rows}
# ---------------------------------------------------------------------------
# _SCHEMA_VERSION
# ---------------------------------------------------------------------------
class TestSchemaVersion:
def test_schema_version_is_v2(self):
"""The current schema version is 2 (V2 adds auth_sessions + auth_meta)."""
assert _SCHEMA_VERSION == 2
def test_sqlalchemy_model_table_name(self):
assert AuthSessionModel.__tablename__ == "auth_sessions"
assert UserSessionModel.__tablename__ == "user_sessions"
# ---------------------------------------------------------------------------
# init_auth_db: tables + indexes
# ---------------------------------------------------------------------------
class TestInitAuthDbTables:
async def test_creates_auth_sessions_table(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)) as db:
cursor = await db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='auth_sessions'"
)
row = await cursor.fetchone()
assert row is not None
async def test_creates_auth_meta_table(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)) as db:
cursor = await db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='auth_meta'"
)
row = await cursor.fetchone()
assert row is not None
async def test_creates_user_sessions_table_for_back_compat(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)) as db:
cursor = await db.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='user_sessions'"
)
row = await cursor.fetchone()
assert row is not None
async def test_records_schema_version_in_auth_meta(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'")
row = await cursor.fetchone()
assert row is not None
assert row["value"] == str(_SCHEMA_VERSION)
class TestAuthSessionsIndexes:
async def test_user_id_active_index(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)):
pass
async with aiosqlite.connect(str(fresh_db)) as db:
indexes = await _list_index_names(db, "auth_sessions")
assert "idx_auth_sessions_user_id_active" in indexes
async def test_expires_at_index(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)) as db:
indexes = await _list_index_names(db, "auth_sessions")
assert "idx_auth_sessions_expires_at" in indexes
async def test_auth_provider_index(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)) as db:
indexes = await _list_index_names(db, "auth_sessions")
assert "idx_auth_sessions_auth_provider" in indexes
async def test_refresh_token_hash_unique_index(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
# SQLite stores column-level UNIQUE constraints as PRAGMA index_list
# entries with auto-generated names like sqlite_autoindex_<table>_1.
# The PRAGMA index_info on each autoindex reports the constrained columns.
cursor = await db.execute("PRAGMA index_list(auth_sessions)")
index_entries = await cursor.fetchall()
column_names: set[str] = set()
for entry in index_entries:
col_cursor = await db.execute(f"PRAGMA index_info({entry['name']})")
col_rows = await col_cursor.fetchall()
for col_row in col_rows:
column_names.add(col_row["name"])
assert "refresh_token_hash" in column_names
# ---------------------------------------------------------------------------
# auth_sessions columns
# ---------------------------------------------------------------------------
class TestAuthSessionsColumns:
async def test_required_columns_present(self, fresh_db: Path):
async with aiosqlite.connect(str(fresh_db)) as db:
cursor = await db.execute("PRAGMA table_info(auth_sessions)")
cols = {row[1] for row in await cursor.fetchall()}
required = {
"id",
"user_id",
"refresh_token_hash",
"device_fingerprint",
"device_label",
"ip",
"user_agent",
"auth_provider",
"created_at",
"last_active_at",
"expires_at",
"revoked",
"revoked_reason",
"previous_session_id",
}
assert required.issubset(cols)
async def test_auth_provider_default_is_local(self, fresh_db: Path):
"""Insert a row without auth_provider → column should default to 'local'."""
user_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
sid = str(uuid.uuid4())
await db.execute(
"INSERT INTO auth_sessions "
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
" revoked) "
"VALUES (?, ?, ?, ?, ?, ?, 0)",
(
sid,
user_id,
"deadbeef",
"2026-01-01T00:00:00+00:00",
"2026-01-01T00:00:00+00:00",
"2026-12-31T00:00:00+00:00",
),
)
await db.commit()
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT auth_provider FROM auth_sessions WHERE id=?", (sid,))
row = await cursor.fetchone()
assert row is not None
assert row["auth_provider"] == "local"
async def test_revoked_default_is_false(self, fresh_db: Path):
user_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
sid = str(uuid.uuid4())
await db.execute(
"INSERT INTO auth_sessions "
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at) "
"VALUES (?, ?, ?, ?, ?, ?)",
(
sid,
user_id,
"beefdead",
"2026-01-01T00:00:00+00:00",
"2026-01-01T00:00:00+00:00",
"2026-12-31T00:00:00+00:00",
),
)
await db.commit()
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT revoked FROM auth_sessions WHERE id=?", (sid,))
row = await cursor.fetchone()
assert bool(row["revoked"]) is False
# ---------------------------------------------------------------------------
# auth_session_row_to_dict
# ---------------------------------------------------------------------------
class TestAuthSessionRowToDict:
async def test_converts_all_fields(self, fresh_db: Path):
user_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
sid = str(uuid.uuid4())
await db.execute(
"INSERT INTO auth_sessions "
"(id, user_id, refresh_token_hash, device_fingerprint, device_label, ip, "
" user_agent, auth_provider, created_at, last_active_at, expires_at, "
" revoked, revoked_reason, previous_session_id) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
sid,
user_id,
"abc123",
"fp-xyz",
"macOS Chrome 119",
"10.0.0.1",
"Mozilla/5.0",
"local",
"2026-01-01T00:00:00+00:00",
"2026-06-20T00:00:00+00:00",
"2027-01-01T00:00:00+00:00",
1,
"user_terminated",
"old-sid",
),
)
await db.commit()
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM auth_sessions WHERE id=?", (sid,))
row = await cursor.fetchone()
d = auth_session_row_to_dict(row)
assert d["id"] == sid
assert d["user_id"] == user_id
assert d["device_fingerprint"] == "fp-xyz"
assert d["device_label"] == "macOS Chrome 119"
assert d["ip"] == "10.0.0.1"
assert d["user_agent"] == "Mozilla/5.0"
assert d["auth_provider"] == "local"
assert d["revoked"] is True
assert d["revoked_reason"] == "user_terminated"
assert d["previous_session_id"] == "old-sid"
async def test_normalizes_revoked_to_bool(self, fresh_db: Path):
"""DB stores 0/1; helper should return Python bool."""
user_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
sid = str(uuid.uuid4())
await db.execute(
"INSERT INTO auth_sessions "
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
" revoked) "
"VALUES (?, ?, ?, ?, ?, ?, 0)",
(sid, user_id, "x", "t", "t", "t"),
)
await db.commit()
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM auth_sessions WHERE id=?", (sid,))
row = await cursor.fetchone()
d = auth_session_row_to_dict(row)
assert isinstance(d["revoked"], bool)
assert d["revoked"] is False
# ---------------------------------------------------------------------------
# _backfill_user_sessions
# ---------------------------------------------------------------------------
class TestBackfillUserSessions:
async def test_backfills_non_revoked_v1_rows(self, fresh_db: Path):
"""On a fresh V1 install, all non-revoked rows are migrated to auth_sessions."""
user_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
await _insert_user_session(
db,
user_id=user_id,
refresh_hash="hash1",
device_info="{}",
)
await _insert_user_session(
db,
user_id=user_id,
refresh_hash="hash2",
device_info=json.dumps(
{
"fingerprint": "mac-tauri",
"label": "macOS Tauri 1.0",
"ip": "192.168.1.10",
"user_agent": "Tauri/1.0",
}
),
)
await _insert_user_session(
db,
user_id=user_id,
refresh_hash="hash3",
revoked=True,
)
await db.commit()
# Force a backfill by clearing the marker + dropping auth_sessions rows
# (simulating a V1 install upgrading to V2).
await db.execute("DELETE FROM auth_sessions")
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
await db.commit()
# Re-init to trigger the migration
await init_auth_db(fresh_db)
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute("SELECT * FROM auth_sessions ORDER BY refresh_token_hash")
rows = await cursor.fetchall()
# Only the 2 non-revoked rows should have been backfilled
assert len(rows) == 2
hashes = {row["refresh_token_hash"] for row in rows}
assert hashes == {"hash1", "hash2"}
async def test_backfill_preserves_original_id(self, fresh_db: Path):
"""Backfilled rows reuse the V1 id so legacy clients match."""
user_id = str(uuid.uuid4())
original_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
await _insert_user_session(
db,
user_id=user_id,
session_id=original_id,
refresh_hash="hash-original",
)
await db.execute("DELETE FROM auth_sessions")
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
await db.commit()
await init_auth_db(fresh_db)
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT id FROM auth_sessions WHERE refresh_token_hash='hash-original'"
)
row = await cursor.fetchone()
assert row is not None
assert row["id"] == original_id
async def test_backfill_does_not_touch_revoked_v1_rows(self, fresh_db: Path):
user_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
await _insert_user_session(
db,
user_id=user_id,
refresh_hash="revoked-hash",
revoked=True,
)
await db.execute("DELETE FROM auth_sessions")
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
await db.commit()
await init_auth_db(fresh_db)
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT COUNT(*) AS c FROM auth_sessions WHERE refresh_token_hash='revoked-hash'"
)
row = await cursor.fetchone()
assert row["c"] == 0
async def test_backfill_copies_device_info_fields(self, fresh_db: Path):
user_id = str(uuid.uuid4())
device_info = json.dumps(
{
"fingerprint": "win-tauri-abc",
"label": "Windows Tauri 1.0",
"ip": "10.0.0.5",
"user_agent": "Tauri/2.0",
}
)
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
await _insert_user_session(
db,
user_id=user_id,
refresh_hash="hash-with-device",
device_info=device_info,
)
await db.execute("DELETE FROM auth_sessions")
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
await db.commit()
await init_auth_db(fresh_db)
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT device_fingerprint, device_label, ip, user_agent "
"FROM auth_sessions WHERE refresh_token_hash='hash-with-device'"
)
row = await cursor.fetchone()
assert row["device_fingerprint"] == "win-tauri-abc"
assert row["device_label"] == "Windows Tauri 1.0"
assert row["ip"] == "10.0.0.5"
assert row["user_agent"] == "Tauri/2.0"
async def test_backfill_handles_malformed_device_info(self, fresh_db: Path):
"""Malformed JSON in device_info → fall back to defaults."""
user_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
await _insert_user_session(
db,
user_id=user_id,
refresh_hash="bad-json",
device_info="not-json{",
)
await db.execute("DELETE FROM auth_sessions")
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
await db.commit()
await init_auth_db(fresh_db)
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT device_fingerprint, device_label FROM auth_sessions "
"WHERE refresh_token_hash='bad-json'"
)
row = await cursor.fetchone()
assert row is not None
assert row["device_fingerprint"] == "unknown"
assert row["device_label"] == "Unknown device"
async def test_backfill_marks_done_in_auth_meta(self, fresh_db: Path):
"""After a backfill, the auth_meta marker is set so it doesn't run again."""
user_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
await _insert_user_session(db, user_id=user_id, refresh_hash="h1")
await db.execute("DELETE FROM auth_sessions")
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
await db.commit()
await init_auth_db(fresh_db)
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT value FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'"
)
row = await cursor.fetchone()
assert row is not None
assert row["value"] == "done"
async def test_backfill_is_idempotent(self, fresh_db: Path):
"""Running init twice does NOT duplicate auth_sessions rows."""
user_id = str(uuid.uuid4())
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
for i in range(3):
await _insert_user_session(db, user_id=user_id, refresh_hash=f"hash{i}")
await db.execute("DELETE FROM auth_sessions")
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
await db.commit()
await init_auth_db(fresh_db)
await init_auth_db(fresh_db) # second run should be a no-op
async with aiosqlite.connect(str(fresh_db)) as db:
cursor = await db.execute("SELECT COUNT(*) AS c FROM auth_sessions")
row = await cursor.fetchone()
assert row[0] == 3
async def test_fresh_install_marks_backfill_done_without_rows(self, fresh_db: Path):
"""A fresh V2 install (no V1 rows) still marks the backfill as done."""
# The fresh_db fixture already ran init_auth_db once.
# Re-running should be a no-op (idempotent).
await init_auth_db(fresh_db)
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT value FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'"
)
row = await cursor.fetchone()
assert row is not None
assert row["value"] == "done"
async def test_backfill_preserves_expires_at(self, fresh_db: Path):
user_id = str(uuid.uuid4())
original_exp = "2027-06-20T12:34:56+00:00"
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
await db.execute(
"INSERT INTO user_sessions "
"(id, user_id, refresh_token_hash, device_info, created_at, expires_at) "
"VALUES (?, ?, ?, ?, ?, ?)",
(
str(uuid.uuid4()),
user_id,
"exp-hash",
"{}",
"2026-01-01T00:00:00+00:00",
original_exp,
),
)
await db.execute("DELETE FROM auth_sessions")
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
await db.commit()
await init_auth_db(fresh_db)
async with aiosqlite.connect(str(fresh_db)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT expires_at FROM auth_sessions WHERE refresh_token_hash='exp-hash'"
)
row = await cursor.fetchone()
assert row["expires_at"] == original_exp
# ---------------------------------------------------------------------------
# Active-session query pattern (covers the cap-count and list-active paths)
# ---------------------------------------------------------------------------
class TestActiveSessionQueries:
async def test_query_active_sessions_for_user(self, fresh_db: Path):
"""The (user_id, revoked, expires_at) index supports the active-session query."""
user_id = str(uuid.uuid4())
future = "2027-12-31T00:00:00+00:00"
past = "2020-01-01T00:00:00+00:00"
async with aiosqlite.connect(str(fresh_db)) as db:
await _insert_user(db, user_id=user_id)
# 2 active, 1 expired, 1 revoked → 2 should match
for hash_, exp in [("a", future), ("b", future), ("c", past)]:
await db.execute(
"INSERT INTO auth_sessions "
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
" revoked) VALUES (?, ?, ?, ?, ?, ?, 0)",
(str(uuid.uuid4()), user_id, hash_, "t", "t", exp),
)
await db.execute(
"INSERT INTO auth_sessions "
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
" revoked, revoked_reason) "
"VALUES (?, ?, ?, ?, ?, ?, 1, 'user_terminated')",
(str(uuid.uuid4()), user_id, "d", "t", "t", future),
)
await db.commit()
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT refresh_token_hash FROM auth_sessions "
"WHERE user_id = ? AND revoked = 0 AND expires_at > ?",
(user_id, "2026-06-20T00:00:00+00:00"),
)
rows = await cursor.fetchall()
hashes = {row["refresh_token_hash"] for row in rows}
assert hashes == {"a", "b"}