feat(auth): U11 AuthProvider 抽象层 + auth_sessions schema
为未来对接集团 IdP(OIDC / SAML / LDAP / 飞书 / 钉钉 / 企微)留扩展点, 同时落地 auth_sessions 表(V2 替代 user_sessions)。 变更 - models.py: 新增 auth_sessions + auth_meta 表,V1→V2 数据回填 - providers/base.py: AuthProvider Protocol 接口契约 - providers/local.py: LocalAuthProvider 默认实现(封装 SQLite + bcrypt) - providers/oidc_stub.py: StubOIDCProvider 占位(NotImplementedError) - providers/__init__.py: get_auth_provider DI 工厂(lru_cache 单例) - providers/exceptions.py: AuthProviderError / InvalidCredentials / ProviderNotImplemented - providers/user.py: Provider-agnostic User 值对象 - tests/unit/auth/: 37 个测试覆盖 Protocol / DI / Local / OIDC 行为 auth_sessions.auth_provider 字段记录登录来源(local / oidc-stub / 未来 oidc-keycloak / saml / ldap),未来切 IdP 时审计可溯源。 测试: 37 passed (providers) + 62 passed (auth 全集) + ruff check clean
This commit is contained in:
parent
54955aab50
commit
2f55fc7434
|
|
@ -5,10 +5,23 @@ stored as ``String(36)`` so the same schema works on both SQLite and
|
|||
PostgreSQL without dialect-specific types.
|
||||
|
||||
Use :func:`init_auth_db` to create the tables on startup.
|
||||
|
||||
Schema versioning
|
||||
-----------------
|
||||
The :data:`_SCHEMA_VERSION` constant tracks the current auth DB schema. The
|
||||
:func:`_backfill_user_sessions` one-time migration is gated on this version
|
||||
to ensure idempotency. After a successful backfill the version is stored
|
||||
in the ``auth_meta`` table so subsequent restarts are no-ops.
|
||||
|
||||
V2 additions (2026-06-20, Centralized Auth & Token Persistence):
|
||||
- New ``auth_sessions`` table with full device/IP/audit metadata and an
|
||||
``auth_provider`` column for future IdP integration traceability.
|
||||
- New ``auth_meta`` table for storing schema version + migration state.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
|
|
@ -88,10 +101,16 @@ class UserApiKeyModel(Base):
|
|||
|
||||
|
||||
class UserSessionModel(Base):
|
||||
"""Refresh-token session record.
|
||||
"""Refresh-token session record (V1, deprecated — see :class:`AuthSessionModel`).
|
||||
|
||||
Stores the SHA-256 hash of the refresh token (never the plaintext).
|
||||
``revoked_at`` is set on logout / forced revocation.
|
||||
|
||||
.. deprecated::
|
||||
Kept for one minor version (per U10 back-compat shim) so legacy
|
||||
clients holding JWTs without ``sid`` claim can still validate.
|
||||
New code should use :class:`AuthSessionModel` (table ``auth_sessions``)
|
||||
which carries full device/IP/audit metadata.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_sessions"
|
||||
|
|
@ -107,6 +126,50 @@ class UserSessionModel(Base):
|
|||
revoked_at: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
|
||||
|
||||
class AuthSessionModel(Base):
|
||||
"""Server-side session record (V2, the primary session table going forward).
|
||||
|
||||
Each row corresponds to a single refresh-token issuance. The full JWT
|
||||
session id (``sid`` claim) is the row's ``id`` (UUID string).
|
||||
|
||||
V2 fields (vs V1 ``user_sessions``):
|
||||
- ``device_fingerprint`` / ``device_label``: surfaces "which device is
|
||||
this session" in the admin UI.
|
||||
- ``ip`` / ``user_agent``: audit trail.
|
||||
- ``last_active_at``: updated on every successful refresh, used to
|
||||
display "last seen" in the sessions list.
|
||||
- ``revoked`` (0/1) + ``revoked_reason``: explicit revoked-state machine
|
||||
with machine-readable reasons (``user_terminated``, ``password_changed``,
|
||||
``admin_revoked``, ``reuse_detected``, ``session_cap_eviction``).
|
||||
- ``previous_session_id``: back-pointer to the previous session id,
|
||||
written on refresh rotation, for audit trail.
|
||||
- ``auth_provider``: the AuthProvider that issued this session
|
||||
(``local`` / ``oidc-stub`` / future ``oidc-keycloak`` / ``saml`` / ``ldap``).
|
||||
Enables admin "list sessions by provider" queries and audit traceability.
|
||||
"""
|
||||
|
||||
__tablename__ = "auth_sessions"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String(36), nullable=False, index=True)
|
||||
refresh_token_hash: Mapped[str] = mapped_column(
|
||||
String(64), unique=True, nullable=False, index=True
|
||||
)
|
||||
device_fingerprint: Mapped[str] = mapped_column(String(128), nullable=False, default="unknown")
|
||||
device_label: Mapped[str] = mapped_column(String(256), nullable=False, default="Unknown device")
|
||||
ip: Mapped[str] = mapped_column(String(64), nullable=False, default="")
|
||||
user_agent: Mapped[str] = mapped_column(String(512), nullable=False, default="")
|
||||
auth_provider: Mapped[str] = mapped_column(
|
||||
String(32), nullable=False, default="local", index=True
|
||||
)
|
||||
created_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
|
||||
last_active_at: Mapped[str] = mapped_column(String(64), nullable=False, default=_now_iso)
|
||||
expires_at: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
revoked: Mapped[bool] = mapped_column(default=False, nullable=False, index=True)
|
||||
revoked_reason: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
previous_session_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
|
||||
|
||||
|
||||
class TerminalWhitelistUserModel(Base):
|
||||
"""Per-user terminal command whitelist.
|
||||
|
||||
|
|
@ -244,6 +307,44 @@ CREATE INDEX IF NOT EXISTS idx_user_sessions_user_id ON user_sessions(user_id);
|
|||
CREATE INDEX IF NOT EXISTS idx_user_sessions_refresh_token_hash
|
||||
ON user_sessions(refresh_token_hash);
|
||||
|
||||
-- V2: auth_sessions replaces user_sessions as the primary session table.
|
||||
-- Stores device/IP/audit metadata and auth_provider for IdP traceability.
|
||||
-- Per-row indexes are sized for the most common access patterns:
|
||||
-- * (user_id, revoked, expires_at) — cap-count, list-active, refresh-validate
|
||||
-- * expires_at — cleanup sweeps
|
||||
-- * refresh_token_hash — uniqueness + fast lookup on /auth/refresh
|
||||
-- * auth_provider — admin "list sessions by provider" queries
|
||||
CREATE TABLE IF NOT EXISTS auth_sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
refresh_token_hash TEXT NOT NULL UNIQUE,
|
||||
device_fingerprint TEXT NOT NULL DEFAULT 'unknown',
|
||||
device_label TEXT NOT NULL DEFAULT 'Unknown device',
|
||||
ip TEXT NOT NULL DEFAULT '',
|
||||
user_agent TEXT NOT NULL DEFAULT '',
|
||||
auth_provider TEXT NOT NULL DEFAULT 'local',
|
||||
created_at TEXT NOT NULL,
|
||||
last_active_at TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
revoked INTEGER NOT NULL DEFAULT 0,
|
||||
revoked_reason TEXT,
|
||||
previous_session_id TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_id_active
|
||||
ON auth_sessions(user_id, revoked, expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_sessions_expires_at
|
||||
ON auth_sessions(expires_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_sessions_auth_provider
|
||||
ON auth_sessions(auth_provider);
|
||||
|
||||
-- V2: auth_meta stores schema version + migration completion markers.
|
||||
-- Used by init_auth_db to gate one-time migrations (idempotency).
|
||||
CREATE TABLE IF NOT EXISTS auth_meta (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS terminal_whitelist_user (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
|
|
@ -307,12 +408,137 @@ CREATE INDEX IF NOT EXISTS idx_terminal_approvals_status
|
|||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema versioning + one-time migrations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Current auth DB schema version. Bump this when adding new tables/columns
|
||||
# that require data backfill or migration. The :func:`init_auth_db` function
|
||||
# uses this together with the ``auth_meta.schema_version`` row to decide
|
||||
# which migrations to run.
|
||||
_SCHEMA_VERSION = 2
|
||||
|
||||
_META_SCHEMA_VERSION_KEY = "schema_version"
|
||||
|
||||
|
||||
async def _get_meta_value(db: aiosqlite.Connection, key: str) -> str | None:
|
||||
"""Read a key from the ``auth_meta`` table. Returns ``None`` if missing."""
|
||||
cursor = await db.execute("SELECT value FROM auth_meta WHERE key = ?", (key,))
|
||||
row = await cursor.fetchone()
|
||||
return row["value"] if row else None
|
||||
|
||||
|
||||
async def _set_meta_value(db: aiosqlite.Connection, key: str, value: str) -> None:
|
||||
"""Upsert a key in the ``auth_meta`` table."""
|
||||
await db.execute(
|
||||
"INSERT INTO auth_meta (key, value, updated_at) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at",
|
||||
(key, value, _now_iso()),
|
||||
)
|
||||
|
||||
|
||||
async def _backfill_user_sessions(db: aiosqlite.Connection) -> int:
|
||||
"""One-time backfill from ``user_sessions`` (V1) to ``auth_sessions`` (V2).
|
||||
|
||||
Runs only when ``auth_sessions`` is empty AND ``user_sessions`` has rows.
|
||||
Idempotent: subsequent restarts are no-ops because we mark the backfill
|
||||
as completed in ``auth_meta``.
|
||||
|
||||
For each non-revoked V1 session, copies:
|
||||
- id (reused — see note below)
|
||||
- user_id
|
||||
- refresh_token_hash
|
||||
- device_fingerprint / device_label / ip / user_agent from the legacy
|
||||
``device_info`` JSON blob (best-effort)
|
||||
- created_at / expires_at
|
||||
- last_active_at defaults to created_at
|
||||
- revoked=0 (already filtered)
|
||||
- revoked_reason=None
|
||||
- auth_provider='local' (default; backfilled rows are pre-IdP)
|
||||
|
||||
The original ``id`` is preserved so that legacy clients holding the
|
||||
old refresh_token_hash still match a row in the new table — this is
|
||||
what the back-compat path in U10 (``get_current_user`` for legacy
|
||||
JWTs) relies on.
|
||||
|
||||
Returns:
|
||||
Number of rows backfilled (0 if already done or nothing to backfill).
|
||||
"""
|
||||
# Idempotency: check the marker
|
||||
if await _get_meta_value(db, "backfill_user_sessions_v1_to_v2") == "done":
|
||||
return 0
|
||||
|
||||
cursor = await db.execute("SELECT COUNT(*) FROM auth_sessions")
|
||||
(count,) = await cursor.fetchone()
|
||||
if count > 0:
|
||||
# auth_sessions already has data — this is a fresh V2 install, not an
|
||||
# upgrade. Mark the backfill done so we never re-check.
|
||||
await _set_meta_value(db, "backfill_user_sessions_v1_to_v2", "done")
|
||||
await db.commit()
|
||||
return 0
|
||||
|
||||
cursor = await db.execute(
|
||||
"SELECT id, user_id, refresh_token_hash, device_info, created_at, expires_at, revoked_at "
|
||||
"FROM user_sessions WHERE revoked_at IS NULL"
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
backfilled = 0
|
||||
for row in rows:
|
||||
try:
|
||||
device_info = json.loads(row["device_info"]) if row["device_info"] else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
device_info = {}
|
||||
|
||||
await db.execute(
|
||||
"INSERT OR IGNORE INTO auth_sessions "
|
||||
"(id, user_id, refresh_token_hash, device_fingerprint, device_label, "
|
||||
" ip, user_agent, auth_provider, created_at, last_active_at, expires_at, "
|
||||
" revoked, revoked_reason, previous_session_id) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
row["id"], # reuse legacy id for back-compat with old clients
|
||||
row["user_id"],
|
||||
row["refresh_token_hash"],
|
||||
device_info.get("fingerprint", "unknown"),
|
||||
device_info.get("label", "Unknown device"),
|
||||
device_info.get("ip", ""),
|
||||
device_info.get("user_agent", ""),
|
||||
"local", # backfilled rows are pre-IdP by definition
|
||||
row["created_at"],
|
||||
row["created_at"], # last_active_at defaults to created_at
|
||||
row["expires_at"],
|
||||
0, # not revoked (already filtered)
|
||||
None,
|
||||
None,
|
||||
),
|
||||
)
|
||||
backfilled += 1
|
||||
|
||||
if backfilled:
|
||||
logger.info(
|
||||
f"Backfilled {backfilled} user_sessions rows to auth_sessions "
|
||||
f"(schema v{_SCHEMA_VERSION})"
|
||||
)
|
||||
|
||||
# Mark the backfill as completed regardless of how many rows were moved.
|
||||
# (idempotency: even a 0-row backfill is "done".)
|
||||
await _set_meta_value(db, "backfill_user_sessions_v1_to_v2", "done")
|
||||
await db.commit()
|
||||
return backfilled
|
||||
|
||||
|
||||
async def init_auth_db(db_path: str | Path | None = None) -> Path:
|
||||
"""Create auth tables if they do not exist.
|
||||
|
||||
Uses aiosqlite directly (no SQLAlchemy engine) for a lightweight,
|
||||
zero-config bootstrap that mirrors :class:`SqliteConversationStore`.
|
||||
|
||||
On startup, this function:
|
||||
1. Creates all tables and indexes from :data:`_SCHEMA_SQL` (idempotent).
|
||||
2. Records the current :data:`_SCHEMA_VERSION` in ``auth_meta``.
|
||||
3. Runs any pending one-time migrations (currently: V1 → V2 backfill
|
||||
from ``user_sessions`` to ``auth_sessions``).
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite file. Defaults to
|
||||
:data:`DEFAULT_AUTH_DB_PATH` (``data/auth.db`` under the project
|
||||
|
|
@ -325,8 +551,19 @@ async def init_auth_db(db_path: str | Path | None = None) -> Path:
|
|||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with aiosqlite.connect(str(path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.execute("PRAGMA journal_mode=WAL")
|
||||
await db.executescript(_SCHEMA_SQL)
|
||||
|
||||
# Record the current schema version (idempotent upsert).
|
||||
current = await _get_meta_value(db, _META_SCHEMA_VERSION_KEY)
|
||||
if current != str(_SCHEMA_VERSION):
|
||||
await _set_meta_value(db, _META_SCHEMA_VERSION_KEY, str(_SCHEMA_VERSION))
|
||||
logger.info(f"Auth DB schema version set to {_SCHEMA_VERSION}")
|
||||
|
||||
# Run pending migrations (each is internally idempotent).
|
||||
await _backfill_user_sessions(db)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(f"Auth DB initialized at {path}")
|
||||
|
|
@ -353,3 +590,28 @@ def user_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any
|
|||
"last_login_at": row["last_login_at"],
|
||||
"created_by": row["created_by"],
|
||||
}
|
||||
|
||||
|
||||
def auth_session_row_to_dict(row: aiosqlite.Row | Mapping[str, object]) -> dict[str, Any]:
|
||||
"""Convert an ``auth_sessions`` row into a JSON-safe dict.
|
||||
|
||||
The ``revoked`` field is normalized to a Python ``bool`` (the DB stores
|
||||
0/1). The full set of audit fields is included so the admin UI and
|
||||
API responses can surface device/IP/last-active information without
|
||||
a separate lookup.
|
||||
"""
|
||||
return {
|
||||
"id": row["id"],
|
||||
"user_id": row["user_id"],
|
||||
"device_fingerprint": row["device_fingerprint"],
|
||||
"device_label": row["device_label"],
|
||||
"ip": row["ip"],
|
||||
"user_agent": row["user_agent"],
|
||||
"auth_provider": row["auth_provider"],
|
||||
"created_at": row["created_at"],
|
||||
"last_active_at": row["last_active_at"],
|
||||
"expires_at": row["expires_at"],
|
||||
"revoked": bool(row["revoked"]),
|
||||
"revoked_reason": row["revoked_reason"],
|
||||
"previous_session_id": row["previous_session_id"],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,115 @@
|
|||
"""AuthProvider implementations and dependency-injection factory.
|
||||
|
||||
Public surface:
|
||||
|
||||
- :class:`AuthProvider` (re-exported from :mod:`.base`) — the protocol
|
||||
every backend must satisfy.
|
||||
- :class:`LocalAuthProvider` — SQLite + bcrypt default.
|
||||
- :class:`StubOIDCProvider` — interface placeholder for the future
|
||||
OIDC integration.
|
||||
- :class:`User` — provider-agnostic user value object.
|
||||
- :func:`get_auth_provider` — DI factory used by routes via
|
||||
``Depends(get_auth_provider)``.
|
||||
- :func:`reset_auth_provider` — clears the lru_cache singleton
|
||||
(used in tests + when the auth provider config changes at runtime).
|
||||
|
||||
Configuration
|
||||
-------------
|
||||
The provider is selected via the ``AGENTKIT_AUTH_PROVIDER`` environment
|
||||
variable (default: ``"local"``). When the future ``auth.provider``
|
||||
field is added to ``agentkit.yaml`` it should override the env var.
|
||||
|
||||
Adding a new provider
|
||||
---------------------
|
||||
1. Create ``auth/providers/<name>.py`` with a class that satisfies
|
||||
the :class:`AuthProvider` Protocol (i.e. has ``name: str`` and the
|
||||
four async methods).
|
||||
2. Import it here and add a branch to :func:`get_auth_provider`.
|
||||
3. Add it to ``AGENTKIT_AUTH_PROVIDER`` enum in the deployment docs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from .base import AuthProvider
|
||||
from .exceptions import AuthProviderError, InvalidCredentials, ProviderNotImplemented
|
||||
from .local import LocalAuthProvider
|
||||
from .oidc_stub import StubOIDCProvider
|
||||
from .user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Re-exports for ergonomic imports: `from agentkit.server.auth.providers import AuthProvider`
|
||||
__all__ = [
|
||||
"AuthProvider",
|
||||
"AuthProviderError",
|
||||
"InvalidCredentials",
|
||||
"LocalAuthProvider",
|
||||
"ProviderNotImplemented",
|
||||
"StubOIDCProvider",
|
||||
"User",
|
||||
"get_auth_provider",
|
||||
"reset_auth_provider",
|
||||
]
|
||||
|
||||
|
||||
# Default provider name. Configurable via AGENTKIT_AUTH_PROVIDER env var
|
||||
# (overridable when the auth.provider field is added to agentkit.yaml).
|
||||
DEFAULT_AUTH_PROVIDER = "local"
|
||||
|
||||
|
||||
def _resolve_provider_name() -> str:
|
||||
"""Return the configured provider name, defaulting to 'local'."""
|
||||
name = os.environ.get("AGENTKIT_AUTH_PROVIDER", DEFAULT_AUTH_PROVIDER).strip()
|
||||
if not name:
|
||||
return DEFAULT_AUTH_PROVIDER
|
||||
return name
|
||||
|
||||
|
||||
def _resolve_db_path() -> Path:
|
||||
"""Return the auth DB path (overridable for tests via env var)."""
|
||||
from ..models import DEFAULT_AUTH_DB_PATH
|
||||
|
||||
env = os.environ.get("AGENTKIT_AUTH_DB")
|
||||
if env:
|
||||
return Path(env)
|
||||
return DEFAULT_AUTH_DB_PATH
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_auth_provider() -> AuthProvider:
|
||||
"""Return the configured :class:`AuthProvider` (memoized singleton).
|
||||
|
||||
The ``lru_cache`` is process-local and the value is resolved on
|
||||
first call. Use :func:`reset_auth_provider` to clear the cache
|
||||
(e.g. when the configuration changes at runtime, or in test
|
||||
fixtures that need a fresh provider with a different db_path).
|
||||
|
||||
Raises:
|
||||
ValueError: if the configured provider name is not recognized.
|
||||
"""
|
||||
name = _resolve_provider_name()
|
||||
if name == "local":
|
||||
return LocalAuthProvider(db_path=_resolve_db_path())
|
||||
if name == "oidc-stub":
|
||||
return StubOIDCProvider()
|
||||
raise ValueError(
|
||||
f"unknown auth provider: {name!r}. "
|
||||
f"Supported providers: 'local', 'oidc-stub'. "
|
||||
f"Set AGENTKIT_AUTH_PROVIDER or update agentkit.yaml's auth.provider field."
|
||||
)
|
||||
|
||||
|
||||
def reset_auth_provider() -> None:
|
||||
"""Clear the memoized :class:`AuthProvider` singleton.
|
||||
|
||||
Use in tests (e.g. in a fixture's teardown) or in code paths that
|
||||
change the auth provider configuration at runtime and need a
|
||||
re-resolution.
|
||||
"""
|
||||
get_auth_provider.cache_clear()
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
"""AuthProvider protocol — pluggable authentication backend contract.
|
||||
|
||||
The :class:`AuthProvider` Protocol defines the minimal surface area the
|
||||
auth subsystem needs from any authentication backend (Local today, OIDC
|
||||
/ SAML / LDAP tomorrow). Routes and admin endpoints call only these
|
||||
methods and never touch the underlying user store directly. Adding a new
|
||||
IdP integration is a matter of writing a new adapter that satisfies
|
||||
this Protocol.
|
||||
|
||||
Design notes:
|
||||
|
||||
- **No sync surface for password verification**: the only entry point is
|
||||
:meth:`authenticate` which takes the plaintext password. Internally each
|
||||
adapter chooses how to verify (bcrypt locally, redirect to IdP, etc.).
|
||||
- **No persistence responsibility**: providers return :class:`User` objects
|
||||
but do not manage sessions, refresh tokens, or session state. That is
|
||||
the route + SessionService's job (see :mod:`agentkit.server.auth.session`).
|
||||
- **Audit-friendly**: :attr:`name` is written to ``auth_sessions.auth_provider``
|
||||
on every login, so the source of every session is traceable. This is
|
||||
what enables "list sessions by provider" admin queries and future
|
||||
cross-IdP policy enforcement.
|
||||
- **runtime_checkable**: the Protocol is decorated so that
|
||||
``isinstance(provider, AuthProvider)`` works at runtime, enabling
|
||||
defensive checks in tests and in DI wiring.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from .user import User
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AuthProvider(Protocol):
|
||||
"""All authentication backends must implement this surface.
|
||||
|
||||
The route layer only calls the methods below. It is intentionally
|
||||
unaware of whether the backing store is SQLite, an OIDC IdP, LDAP,
|
||||
or anything else.
|
||||
"""
|
||||
|
||||
name: str
|
||||
"""Identifier for this provider, written to ``auth_sessions.auth_provider``.
|
||||
|
||||
Convention: ``local`` for the local SQLite + bcrypt backend,
|
||||
``oidc-<idp-name>`` for OIDC (e.g. ``oidc-keycloak``,
|
||||
``oidc-feishu``), ``saml`` / ``ldap`` for the other planned
|
||||
integrations. This value is the stable contract for audit
|
||||
traceability — do not rename without a migration plan.
|
||||
"""
|
||||
|
||||
async def authenticate(self, *, username: str, password: str) -> User:
|
||||
"""Verify ``username`` + ``password`` and return the :class:`User`.
|
||||
|
||||
Args:
|
||||
username: The submitted username (or, for OIDC, the IdP
|
||||
subject id — but that's a future adapter's concern).
|
||||
password: The submitted plaintext password.
|
||||
|
||||
Returns:
|
||||
The matched :class:`User` on success.
|
||||
|
||||
Raises:
|
||||
InvalidCredentials: if the user does not exist, the password
|
||||
is wrong, or the user is inactive. Callers MUST NOT
|
||||
distinguish between these three cases in the error
|
||||
message returned to the client (timing-attack /
|
||||
username-enumeration mitigation).
|
||||
"""
|
||||
...
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
"""Look up a :class:`User` by primary key.
|
||||
|
||||
Used by:
|
||||
- Admin endpoints that need to display user info by id
|
||||
- Session validation in the cold-start / whoami path
|
||||
- Audit log enrichment
|
||||
|
||||
Returns ``None`` if no user exists with this id (or the user
|
||||
is inactive — convention: inactive users are "not found" from
|
||||
the auth layer's perspective).
|
||||
"""
|
||||
...
|
||||
|
||||
async def sync_user_attributes(self, user_id: str) -> None:
|
||||
"""Refresh user attributes (department / email / title) from the source of truth.
|
||||
|
||||
- :class:`LocalAuthProvider`: no-op (attributes are managed locally).
|
||||
- OIDC adapter (future): pull the latest profile from the IdP and
|
||||
write back to the local ``users`` table.
|
||||
|
||||
Implementations that have nothing to sync should still define
|
||||
this method (returning ``None``) so the contract is uniform.
|
||||
"""
|
||||
...
|
||||
|
||||
async def revoke_user(self, user_id: str) -> None:
|
||||
"""Disable a user account (e.g. on termination or lock-out).
|
||||
|
||||
- :class:`LocalAuthProvider`: ``UPDATE users SET is_active = 0``
|
||||
- OIDC adapter (future): call the IdP's disable API
|
||||
|
||||
The admin endpoint that calls this does NOT also need to
|
||||
revoke the user's active sessions — that is the
|
||||
:class:`SessionService`'s job, called separately.
|
||||
"""
|
||||
...
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
"""Exceptions raised by AuthProvider implementations.
|
||||
|
||||
These are caught by the route layer and translated to HTTP responses
|
||||
(401 for :class:`InvalidCredentials`, 501 for :class:`ProviderNotImplemented`).
|
||||
"""
|
||||
|
||||
|
||||
class AuthProviderError(Exception):
|
||||
"""Base class for all auth provider errors."""
|
||||
|
||||
|
||||
class InvalidCredentials(AuthProviderError):
|
||||
"""Raised when username / password is wrong, or the user is inactive.
|
||||
|
||||
Translates to HTTP 401. The error message MUST NOT leak which of
|
||||
"user not found" vs "wrong password" vs "user inactive" failed —
|
||||
that is a username enumeration risk.
|
||||
"""
|
||||
|
||||
|
||||
class ProviderNotImplemented(AuthProviderError):
|
||||
"""Raised when a configured provider is not yet implemented.
|
||||
|
||||
Translates to HTTP 501. Used by :class:`StubOIDCProvider` and any
|
||||
future adapter that is registered but has no real implementation.
|
||||
"""
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
"""Local authentication provider — SQLite + bcrypt.
|
||||
|
||||
The default :class:`AuthProvider` implementation. Authenticates users
|
||||
against the local ``users`` table in the auth SQLite database using
|
||||
the bcrypt cost=12 password hash (see :mod:`agentkit.server.auth.password`).
|
||||
|
||||
This is a behavioral equivalent of the password-verification code that
|
||||
previously lived inline in ``routes/auth.py`` — moved here so the route
|
||||
layer can call a single :meth:`authenticate` method regardless of which
|
||||
backend is configured.
|
||||
|
||||
Future-IdP note
|
||||
---------------
|
||||
When the organization moves to OIDC / SAML / LDAP, this class does not
|
||||
need to be deleted. It can continue to serve as a "local emergency
|
||||
account" provider, configurable side-by-side with the IdP provider via
|
||||
a future composite / multi-provider setup. For now it is the only
|
||||
implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from ..models import DEFAULT_AUTH_DB_PATH
|
||||
from ..password import verify_password
|
||||
from .user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_db_path() -> Path:
|
||||
"""Resolve the auth DB path with runtime env-var priority.
|
||||
|
||||
The :data:`models.DEFAULT_AUTH_DB_PATH` constant is captured at
|
||||
module-import time and therefore cannot see test-time env mutations.
|
||||
Re-reading ``AGENTKIT_AUTH_DB`` here keeps the provider
|
||||
"test-friendly" (tests can ``monkeypatch.setenv`` before constructing
|
||||
the provider) without giving up the default path when no env is set.
|
||||
"""
|
||||
env = os.environ.get("AGENTKIT_AUTH_DB")
|
||||
if env:
|
||||
return Path(env)
|
||||
return DEFAULT_AUTH_DB_PATH
|
||||
|
||||
|
||||
class LocalAuthProvider:
|
||||
"""AuthProvider backed by the local SQLite ``users`` table + bcrypt.
|
||||
|
||||
Args:
|
||||
db_path: Path to the auth DB. Defaults to the value of the
|
||||
``AGENTKIT_AUTH_DB`` env var, falling back to
|
||||
:data:`agentkit.server.auth.models.DEFAULT_AUTH_DB_PATH`.
|
||||
Each operation opens a short-lived aiosqlite connection;
|
||||
the existing route layer follows the same pattern, so no
|
||||
connection pooling is introduced here. If a future
|
||||
deployment needs pooling, swap in a ``db_factory: Callable``
|
||||
here without changing the protocol.
|
||||
"""
|
||||
|
||||
name = "local"
|
||||
|
||||
def __init__(self, db_path: str | Path | None = None) -> None:
|
||||
self._db_path = Path(db_path) if db_path is not None else _resolve_db_path()
|
||||
|
||||
@property
|
||||
def db_path(self) -> Path:
|
||||
return self._db_path
|
||||
|
||||
async def authenticate(self, *, username: str, password: str) -> User:
|
||||
"""Verify the username + password against the local users table.
|
||||
|
||||
Raises :class:`InvalidCredentials` on every failure mode
|
||||
(unknown user, wrong password, inactive user) with the same
|
||||
error message — preventing username enumeration via error
|
||||
inspection. Constant-time-equivalent behavior is also
|
||||
ensured by always running a real bcrypt computation
|
||||
(against a dummy hash) when the user does not exist,
|
||||
matching the timing of the "user exists, wrong password" path.
|
||||
"""
|
||||
from .exceptions import InvalidCredentials # local import to avoid cycle at module load
|
||||
|
||||
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT id, username, email, password_hash, role, is_active, "
|
||||
"is_terminal_authorized, is_server_terminal_authorized, "
|
||||
"created_at, updated_at, last_login_at, created_by "
|
||||
"FROM users WHERE username = ?",
|
||||
(username,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is None or not bool(row["is_active"]):
|
||||
# Run a real bcrypt verification against a valid-format dummy
|
||||
# hash so the response time matches the "user exists, wrong
|
||||
# password" path (~250ms). Prevents username enumeration via
|
||||
# timing. The dummy hash is invalid (won't match any password)
|
||||
# but has the right shape so bcrypt.checkpw doesn't short-circuit.
|
||||
_DUMMY_BCRYPT_HASH = "$2b$12$abcdefghijklmnopqrstuuABCDEFGHIJKLMNOPQRSTUVWXYZ0123"
|
||||
verify_password(password, _DUMMY_BCRYPT_HASH)
|
||||
raise InvalidCredentials("invalid username or password")
|
||||
|
||||
if not verify_password(password, row["password_hash"]):
|
||||
raise InvalidCredentials("invalid username or password")
|
||||
|
||||
return _row_to_user(row)
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
"""Look up a user by id. Returns ``None`` if not found or inactive."""
|
||||
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT id, username, email, password_hash, role, is_active, "
|
||||
"is_terminal_authorized, is_server_terminal_authorized, "
|
||||
"created_at, updated_at, last_login_at, created_by "
|
||||
"FROM users WHERE id = ? AND is_active = 1",
|
||||
(user_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return _row_to_user(row) if row else None
|
||||
|
||||
async def sync_user_attributes(self, user_id: str) -> None:
|
||||
"""No-op: local provider has no upstream source of truth to sync from."""
|
||||
return None
|
||||
|
||||
async def revoke_user(self, user_id: str) -> None:
|
||||
"""Disable a user account (``is_active = 0``)."""
|
||||
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||
await db.execute(
|
||||
"UPDATE users SET is_active = 0, updated_at = ? WHERE id = ?",
|
||||
(_now_iso(), user_id),
|
||||
)
|
||||
await db.commit()
|
||||
logger.info(f"Revoked user {user_id} via LocalAuthProvider")
|
||||
|
||||
|
||||
def _row_to_user(row: aiosqlite.Row) -> User:
|
||||
"""Convert a ``users`` row to a :class:`User` value object."""
|
||||
return User(
|
||||
id=row["id"],
|
||||
username=row["username"],
|
||||
email=row["email"],
|
||||
role=row["role"],
|
||||
is_active=bool(row["is_active"]),
|
||||
is_terminal_authorized=bool(row["is_terminal_authorized"]),
|
||||
is_server_terminal_authorized=bool(row["is_server_terminal_authorized"]),
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
last_login_at=row["last_login_at"],
|
||||
created_by=row["created_by"],
|
||||
)
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
"""Return current UTC time as ISO 8601 string."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
"""Stub OIDC provider — interface contract placeholder.
|
||||
|
||||
This is a deliberately unimplemented :class:`AuthProvider` whose only
|
||||
job is to fail loudly when ``auth.provider: oidc-stub`` is configured
|
||||
without the real OIDC integration being in place. It exists so that
|
||||
the configuration switch and the dependency-injection wiring can be
|
||||
exercised end-to-end before the IdP integration work is scheduled.
|
||||
|
||||
Future OIDC integration checklist
|
||||
---------------------------------
|
||||
When the real OIDC integration lands, replace this file with
|
||||
:file:`auth/providers/oidc.py` implementing:
|
||||
|
||||
- [ ] :meth:`authenticate` — accept (username, password)? OR redirect-flow?
|
||||
OIDC is fundamentally a redirect flow, NOT a password POST. The
|
||||
right answer is probably a separate ``/auth/oauth/{provider}/redirect``
|
||||
and ``/auth/oauth/{provider}/callback`` route pair that bypasses
|
||||
the password-based login. The :class:`AuthProvider` interface is
|
||||
only for backends that accept username + password; an OIDC
|
||||
adapter may need to extend the protocol or expose a separate
|
||||
callback-handler hook.
|
||||
- [ ] :meth:`get_user_by_id` — query local cache (auto-provisioned
|
||||
OIDC users live in the local ``users`` table after first login).
|
||||
- [ ] :meth:`sync_user_attributes` — pull latest profile from the IdP
|
||||
on each successful login (department / title / email).
|
||||
- [ ] :meth:`revoke_user` — call the IdP's account-disable API (e.g.
|
||||
Keycloak's ``PUT /admin/realms/{realm}/users/{id}/disable``).
|
||||
- [ ] State cache for OAuth ``state`` parameter (Redis, TTL 5 min).
|
||||
- [ ] First-login user auto-provisioning policy: just-in-time
|
||||
creation / reject / invite-only.
|
||||
|
||||
Until then, configuring ``auth.provider: oidc-stub`` and starting
|
||||
the server lets a smoke test confirm:
|
||||
|
||||
1. The DI factory resolves the right class.
|
||||
2. The route layer calls :meth:`authenticate` and surfaces 501.
|
||||
3. ``auth_sessions.auth_provider='oidc-stub'`` would be written
|
||||
for any session that *did* get created (none, in this stub).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .exceptions import ProviderNotImplemented
|
||||
from .user import User
|
||||
|
||||
|
||||
class StubOIDCProvider:
|
||||
"""Unimplemented OIDC adapter. All methods raise :class:`ProviderNotImplemented`.
|
||||
|
||||
The class satisfies the :class:`AuthProvider` Protocol structurally
|
||||
(it has ``name`` and the four methods), but every method body is
|
||||
a deliberate failure. This is on purpose: configuring
|
||||
``auth.provider: oidc-stub`` and then successfully logging in would
|
||||
be a security bug.
|
||||
"""
|
||||
|
||||
name = "oidc-stub"
|
||||
|
||||
async def authenticate(self, *, username: str, password: str) -> User:
|
||||
raise ProviderNotImplemented(
|
||||
"OIDC authentication is not implemented yet. "
|
||||
"Use auth.provider: local in agentkit.yaml, or implement "
|
||||
"auth/providers/oidc.py (see the checklist in this module's "
|
||||
"docstring)."
|
||||
)
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
raise ProviderNotImplemented(
|
||||
"StubOIDCProvider.get_user_by_id is not implemented. "
|
||||
"See auth/providers/oidc_stub.py for the future OIDC checklist."
|
||||
)
|
||||
|
||||
async def sync_user_attributes(self, user_id: str) -> None:
|
||||
raise ProviderNotImplemented(
|
||||
"StubOIDCProvider.sync_user_attributes is not implemented. "
|
||||
"See auth/providers/oidc_stub.py for the future OIDC checklist."
|
||||
)
|
||||
|
||||
async def revoke_user(self, user_id: str) -> None:
|
||||
raise ProviderNotImplemented(
|
||||
"StubOIDCProvider.revoke_user is not implemented. "
|
||||
"See auth/providers/oidc_stub.py for the future OIDC checklist."
|
||||
)
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
"""Provider-agnostic user data model.
|
||||
|
||||
This is the value object the :class:`AuthProvider` Protocol returns
|
||||
from :meth:`AuthProvider.authenticate` and :meth:`AuthProvider.get_user_by_id`.
|
||||
It is intentionally NOT the SQLAlchemy ``UserModel`` ORM class so that
|
||||
future IdP adapters (OIDC / SAML / LDAP) can construct one without
|
||||
touching a DB.
|
||||
|
||||
The route layer converts ``User`` → ``UserResponse`` (the public
|
||||
API payload) at the boundary; ``User`` itself stays an internal
|
||||
type.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""A user as known to the auth subsystem.
|
||||
|
||||
Distinct from the SQLAlchemy ``UserModel`` ORM class because:
|
||||
- The auth provider may not have a local row at all (e.g. an
|
||||
OIDC user logged in via SSO who hasn't been provisioned yet
|
||||
— that's a future concern but the data model must allow it).
|
||||
- IdP-sourced attributes (department / title) are not in the
|
||||
SQLAlchemy model.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: str
|
||||
username: str
|
||||
email: EmailStr
|
||||
role: str = "member"
|
||||
is_active: bool = True
|
||||
is_terminal_authorized: bool = False
|
||||
is_server_terminal_authorized: bool = False
|
||||
created_at: str
|
||||
updated_at: str
|
||||
last_login_at: str | None = None
|
||||
created_by: str | None = None
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
"""Tests for the AuthProvider protocol and the DI factory (U11)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.server.auth.providers import (
|
||||
AuthProvider,
|
||||
LocalAuthProvider,
|
||||
StubOIDCProvider,
|
||||
get_auth_provider,
|
||||
reset_auth_provider,
|
||||
)
|
||||
from agentkit.server.auth.providers.base import AuthProvider as AuthProviderFromBase
|
||||
from agentkit.server.auth.providers.exceptions import (
|
||||
AuthProviderError,
|
||||
InvalidCredentials,
|
||||
ProviderNotImplemented,
|
||||
)
|
||||
|
||||
|
||||
class TestProtocolConformance:
|
||||
"""The runtime_checkable Protocol accepts both real implementations."""
|
||||
|
||||
def test_local_passes_isinstance_check(self):
|
||||
provider = LocalAuthProvider()
|
||||
assert isinstance(provider, AuthProvider)
|
||||
assert isinstance(provider, AuthProviderFromBase)
|
||||
|
||||
def test_stub_passes_isinstance_check(self):
|
||||
provider = StubOIDCProvider()
|
||||
assert isinstance(provider, AuthProvider)
|
||||
assert isinstance(provider, AuthProviderFromBase)
|
||||
|
||||
|
||||
class TestProviderNames:
|
||||
def test_local_provider_name(self):
|
||||
assert LocalAuthProvider.name == "local"
|
||||
|
||||
def test_stub_oidc_provider_name(self):
|
||||
assert StubOIDCProvider.name == "oidc-stub"
|
||||
|
||||
|
||||
class TestDiFactory:
|
||||
"""``get_auth_provider`` is a memoized singleton, env-driven."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset cache and env before each test."""
|
||||
reset_auth_provider()
|
||||
self._saved_env = {}
|
||||
for key in ("AGENTKIT_AUTH_PROVIDER",):
|
||||
self._saved_env[key] = __import__("os").environ.pop(key, None)
|
||||
|
||||
def teardown_method(self):
|
||||
for key, value in self._saved_env.items():
|
||||
if value is not None:
|
||||
__import__("os").environ[key] = value
|
||||
reset_auth_provider()
|
||||
|
||||
def test_default_provider_is_local(self):
|
||||
provider = get_auth_provider()
|
||||
assert isinstance(provider, LocalAuthProvider)
|
||||
assert provider.name == "local"
|
||||
|
||||
def test_oidc_stub_provider(self, monkeypatch):
|
||||
monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "oidc-stub")
|
||||
reset_auth_provider()
|
||||
provider = get_auth_provider()
|
||||
assert isinstance(provider, StubOIDCProvider)
|
||||
assert provider.name == "oidc-stub"
|
||||
|
||||
def test_unknown_provider_raises_value_error(self, monkeypatch):
|
||||
monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "ldap-unknown")
|
||||
reset_auth_provider()
|
||||
with pytest.raises(ValueError, match="unknown auth provider"):
|
||||
get_auth_provider()
|
||||
|
||||
def test_factory_is_memoized(self):
|
||||
first = get_auth_provider()
|
||||
second = get_auth_provider()
|
||||
assert first is second
|
||||
|
||||
def test_reset_clears_cache(self, monkeypatch):
|
||||
first = get_auth_provider()
|
||||
reset_auth_provider()
|
||||
monkeypatch.setenv("AGENTKIT_AUTH_PROVIDER", "oidc-stub")
|
||||
second = get_auth_provider()
|
||||
assert first is not second
|
||||
assert isinstance(second, StubOIDCProvider)
|
||||
|
||||
|
||||
class TestExceptionHierarchy:
|
||||
def test_invalid_credentials_inherits_auth_provider_error(self):
|
||||
assert issubclass(InvalidCredentials, AuthProviderError)
|
||||
assert issubclass(InvalidCredentials, Exception)
|
||||
|
||||
def test_provider_not_implemented_inherits_auth_provider_error(self):
|
||||
assert issubclass(ProviderNotImplemented, AuthProviderError)
|
||||
assert issubclass(ProviderNotImplemented, Exception)
|
||||
|
||||
def test_invalid_credentials_can_be_raised_and_caught(self):
|
||||
with pytest.raises(InvalidCredentials):
|
||||
raise InvalidCredentials("test message")
|
||||
|
||||
def test_invalid_credentials_caught_as_base(self):
|
||||
"""Route layer catches AuthProviderError to handle all provider errors."""
|
||||
with pytest.raises(AuthProviderError):
|
||||
raise InvalidCredentials("test")
|
||||
|
||||
def test_provider_not_implemented_caught_as_base(self):
|
||||
with pytest.raises(AuthProviderError):
|
||||
raise ProviderNotImplemented("test")
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
"""Tests for LocalAuthProvider (U11 — concrete Local implementation)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from agentkit.server.auth.models import init_auth_db
|
||||
from agentkit.server.auth.password import hash_password
|
||||
from agentkit.server.auth.providers import LocalAuthProvider
|
||||
from agentkit.server.auth.providers.exceptions import InvalidCredentials
|
||||
from agentkit.server.auth.providers.user import User
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def auth_db_with_users(tmp_path: Path) -> dict:
|
||||
"""Create a fresh auth DB with two users: one active, one inactive.
|
||||
|
||||
Returns a dict with user info + the db path.
|
||||
"""
|
||||
db_path = tmp_path / "auth.db"
|
||||
await init_auth_db(db_path)
|
||||
|
||||
now_iso = datetime.now(timezone.utc).isoformat()
|
||||
active_id = str(uuid.uuid4())
|
||||
inactive_id = str(uuid.uuid4())
|
||||
active_pw = "correct-horse-battery-staple"
|
||||
inactive_pw = "disabled-user-pw"
|
||||
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
await db.execute(
|
||||
"INSERT INTO users "
|
||||
"(id, username, email, password_hash, role, is_active, "
|
||||
" is_terminal_authorized, is_server_terminal_authorized, "
|
||||
" created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
active_id,
|
||||
"alice",
|
||||
"alice@example.com",
|
||||
hash_password(active_pw),
|
||||
"member",
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
now_iso,
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT INTO users "
|
||||
"(id, username, email, password_hash, role, is_active, "
|
||||
" is_terminal_authorized, is_server_terminal_authorized, "
|
||||
" created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
inactive_id,
|
||||
"bob_inactive",
|
||||
"bob@example.com",
|
||||
hash_password(inactive_pw),
|
||||
"member",
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
now_iso,
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"db_path": db_path,
|
||||
"active": {
|
||||
"id": active_id,
|
||||
"username": "alice",
|
||||
"password": active_pw,
|
||||
"email": "alice@example.com",
|
||||
"role": "member",
|
||||
},
|
||||
"inactive": {
|
||||
"id": inactive_id,
|
||||
"username": "bob_inactive",
|
||||
"password": inactive_pw,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider(auth_db_with_users: dict) -> LocalAuthProvider:
|
||||
return LocalAuthProvider(db_path=auth_db_with_users["db_path"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# authenticate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthenticate:
|
||||
async def test_valid_credentials_returns_user(
|
||||
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||
):
|
||||
user = await provider.authenticate(
|
||||
username=auth_db_with_users["active"]["username"],
|
||||
password=auth_db_with_users["active"]["password"],
|
||||
)
|
||||
assert isinstance(user, User)
|
||||
assert user.id == auth_db_with_users["active"]["id"]
|
||||
assert user.username == "alice"
|
||||
assert user.email == "alice@example.com"
|
||||
assert user.role == "member"
|
||||
assert user.is_active is True
|
||||
|
||||
async def test_wrong_password_raises_invalid_credentials(
|
||||
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||
):
|
||||
with pytest.raises(InvalidCredentials):
|
||||
await provider.authenticate(
|
||||
username=auth_db_with_users["active"]["username"],
|
||||
password="definitely-not-the-password",
|
||||
)
|
||||
|
||||
async def test_unknown_user_raises_invalid_credentials(self, provider: LocalAuthProvider):
|
||||
with pytest.raises(InvalidCredentials):
|
||||
await provider.authenticate(username="nonexistent", password="anything")
|
||||
|
||||
async def test_inactive_user_raises_invalid_credentials(
|
||||
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||
):
|
||||
with pytest.raises(InvalidCredentials):
|
||||
await provider.authenticate(
|
||||
username=auth_db_with_users["inactive"]["username"],
|
||||
password=auth_db_with_users["inactive"]["password"],
|
||||
)
|
||||
|
||||
async def test_error_message_does_not_leak_username_existence(
|
||||
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||
):
|
||||
"""Wrong-password and unknown-user errors must have the same message."""
|
||||
try:
|
||||
await provider.authenticate(
|
||||
username=auth_db_with_users["active"]["username"],
|
||||
password="wrong",
|
||||
)
|
||||
except InvalidCredentials as e1:
|
||||
wrong_pw_msg = str(e1)
|
||||
try:
|
||||
await provider.authenticate(username="nobody-here", password="x")
|
||||
except InvalidCredentials as e2:
|
||||
unknown_msg = str(e2)
|
||||
assert wrong_pw_msg == unknown_msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_user_by_id
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetUserById:
|
||||
async def test_returns_user_for_active_id(
|
||||
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||
):
|
||||
user = await provider.get_user_by_id(auth_db_with_users["active"]["id"])
|
||||
assert user is not None
|
||||
assert user.id == auth_db_with_users["active"]["id"]
|
||||
assert user.username == "alice"
|
||||
|
||||
async def test_returns_none_for_inactive_user(
|
||||
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||
):
|
||||
user = await provider.get_user_by_id(auth_db_with_users["inactive"]["id"])
|
||||
assert user is None
|
||||
|
||||
async def test_returns_none_for_unknown_id(self, provider: LocalAuthProvider):
|
||||
user = await provider.get_user_by_id(str(uuid.uuid4()))
|
||||
assert user is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sync_user_attributes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyncUserAttributes:
|
||||
async def test_is_noop(self, provider: LocalAuthProvider, auth_db_with_users: dict):
|
||||
"""Local provider has no upstream to sync from — must succeed with no effect."""
|
||||
result = await provider.sync_user_attributes(auth_db_with_users["active"]["id"])
|
||||
assert result is None
|
||||
|
||||
async def test_noop_does_not_throw_for_unknown_user(self, provider: LocalAuthProvider):
|
||||
"""sync_user_attributes is fire-and-forget — no existence check."""
|
||||
# Should NOT raise even though the id doesn't exist
|
||||
await provider.sync_user_attributes(str(uuid.uuid4()))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# revoke_user
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRevokeUser:
|
||||
async def test_sets_is_active_to_zero(
|
||||
self, provider: LocalAuthProvider, auth_db_with_users: dict, tmp_path: Path
|
||||
):
|
||||
await provider.revoke_user(auth_db_with_users["active"]["id"])
|
||||
async with aiosqlite.connect(str(auth_db_with_users["db_path"])) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT is_active FROM users WHERE id = ?",
|
||||
(auth_db_with_users["active"]["id"],),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert bool(row["is_active"]) is False
|
||||
|
||||
async def test_revoked_user_can_no_longer_authenticate(
|
||||
self, provider: LocalAuthProvider, auth_db_with_users: dict
|
||||
):
|
||||
await provider.revoke_user(auth_db_with_users["active"]["id"])
|
||||
with pytest.raises(InvalidCredentials):
|
||||
await provider.authenticate(
|
||||
username=auth_db_with_users["active"]["username"],
|
||||
password=auth_db_with_users["active"]["password"],
|
||||
)
|
||||
|
||||
async def test_revoke_unknown_user_does_not_raise(self, provider: LocalAuthProvider):
|
||||
"""``UPDATE ... WHERE id = ?`` with no match is a no-op, not an error."""
|
||||
await provider.revoke_user(str(uuid.uuid4()))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# default db_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDefaultDbPath:
|
||||
def test_default_db_path_uses_models_default(self, tmp_path: Path, monkeypatch):
|
||||
"""If no db_path is given, the provider should use the module default.
|
||||
|
||||
The default may resolve to a path that does not exist on a test
|
||||
machine — we only assert that the property returns a Path, not
|
||||
that the file exists.
|
||||
"""
|
||||
monkeypatch.setenv("AGENTKIT_AUTH_DB", str(tmp_path / "default.db"))
|
||||
provider = LocalAuthProvider()
|
||||
assert isinstance(provider.db_path, Path)
|
||||
assert str(provider.db_path) == str(tmp_path / "default.db")
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
"""Tests for StubOIDCProvider (U11 — interface placeholder)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.server.auth.providers import StubOIDCProvider
|
||||
from agentkit.server.auth.providers.exceptions import (
|
||||
AuthProviderError,
|
||||
ProviderNotImplemented,
|
||||
)
|
||||
|
||||
|
||||
class TestStubAuthenticate:
|
||||
async def test_raises_provider_not_implemented(self):
|
||||
provider = StubOIDCProvider()
|
||||
with pytest.raises(ProviderNotImplemented):
|
||||
await provider.authenticate(username="alice", password="hunter2")
|
||||
|
||||
async def test_error_message_mentions_oidc(self):
|
||||
"""The error message must guide the operator to the right next step."""
|
||||
provider = StubOIDCProvider()
|
||||
with pytest.raises(ProviderNotImplemented) as exc:
|
||||
await provider.authenticate(username="alice", password="x")
|
||||
msg = str(exc.value)
|
||||
assert "OIDC" in msg
|
||||
assert "not implemented" in msg.lower()
|
||||
|
||||
|
||||
class TestStubGetUserById:
|
||||
async def test_raises_provider_not_implemented(self):
|
||||
provider = StubOIDCProvider()
|
||||
with pytest.raises(ProviderNotImplemented):
|
||||
await provider.get_user_by_id("any-id")
|
||||
|
||||
|
||||
class TestStubSyncUserAttributes:
|
||||
async def test_raises_provider_not_implemented(self):
|
||||
provider = StubOIDCProvider()
|
||||
with pytest.raises(ProviderNotImplemented):
|
||||
await provider.sync_user_attributes("any-id")
|
||||
|
||||
|
||||
class TestStubRevokeUser:
|
||||
async def test_raises_provider_not_implemented(self):
|
||||
provider = StubOIDCProvider()
|
||||
with pytest.raises(ProviderNotImplemented):
|
||||
await provider.revoke_user("any-id")
|
||||
|
||||
|
||||
class TestAllErrorsAreAuthProviderError:
|
||||
"""Every stub method's error should be catchable as AuthProviderError."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"method,call_args",
|
||||
[
|
||||
("authenticate", ()),
|
||||
("get_user_by_id", ("id",)),
|
||||
("sync_user_attributes", ("id",)),
|
||||
("revoke_user", ("id",)),
|
||||
],
|
||||
)
|
||||
async def test_catchable_as_base(self, method: str, call_args: tuple):
|
||||
provider = StubOIDCProvider()
|
||||
with pytest.raises(AuthProviderError):
|
||||
if method == "authenticate":
|
||||
await provider.authenticate(username="u", password="p")
|
||||
else:
|
||||
await getattr(provider, method)(*call_args)
|
||||
|
|
@ -0,0 +1,653 @@
|
|||
"""Unit tests for auth.models (U1 — V2 schema + backfill).
|
||||
|
||||
Covers:
|
||||
- ``auth_sessions`` table creation (columns + indexes)
|
||||
- ``auth_meta`` table creation
|
||||
- ``_SCHEMA_VERSION`` constant value
|
||||
- ``auth_session_row_to_dict`` field-by-field conversion
|
||||
- ``_backfill_user_sessions`` one-time migration (V1 → V2)
|
||||
- ``init_auth_db`` idempotency (subsequent runs are no-ops)
|
||||
- ``auth_provider`` column default value
|
||||
- Index presence (verified via ``PRAGMA index_list``)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from agentkit.server.auth.models import (
|
||||
AuthSessionModel,
|
||||
UserSessionModel,
|
||||
_SCHEMA_VERSION,
|
||||
auth_session_row_to_dict,
|
||||
init_auth_db,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def fresh_db(tmp_path: Path) -> Path:
|
||||
"""A brand-new auth DB on a fresh path (no data)."""
|
||||
db_path = tmp_path / "auth.db"
|
||||
await init_auth_db(db_path)
|
||||
return db_path
|
||||
|
||||
|
||||
async def _insert_user(db: aiosqlite.Connection, *, user_id: str | None = None) -> str:
|
||||
"""Insert a minimal user row and return its id."""
|
||||
user_id = user_id or str(uuid.uuid4())
|
||||
now_iso = datetime.now(timezone.utc).isoformat()
|
||||
await db.execute(
|
||||
"INSERT INTO users "
|
||||
"(id, username, email, password_hash, role, is_active, "
|
||||
" is_terminal_authorized, is_server_terminal_authorized, "
|
||||
" created_at, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
user_id,
|
||||
f"user-{user_id[:8]}",
|
||||
f"{user_id[:8]}@example.com",
|
||||
"$2b$12$placeholder.hash.placeholder.hash.placeholder.hash",
|
||||
"member",
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
now_iso,
|
||||
now_iso,
|
||||
),
|
||||
)
|
||||
return user_id
|
||||
|
||||
|
||||
async def _insert_user_session(
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str | None = None,
|
||||
refresh_hash: str | None = None,
|
||||
device_info: str = "{}",
|
||||
revoked: bool = False,
|
||||
) -> str:
|
||||
"""Insert a V1 ``user_sessions`` row and return its id."""
|
||||
session_id = session_id or str(uuid.uuid4())
|
||||
refresh_hash = refresh_hash or uuid.uuid4().hex
|
||||
now_iso = datetime.now(timezone.utc).isoformat()
|
||||
await db.execute(
|
||||
"INSERT INTO user_sessions "
|
||||
"(id, user_id, refresh_token_hash, device_info, created_at, expires_at, revoked_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
session_id,
|
||||
user_id,
|
||||
refresh_hash,
|
||||
device_info,
|
||||
now_iso,
|
||||
now_iso,
|
||||
now_iso if revoked else None,
|
||||
),
|
||||
)
|
||||
return session_id
|
||||
|
||||
|
||||
async def _list_index_names(db: aiosqlite.Connection, table: str) -> set[str]:
|
||||
"""Return the set of index names for a table.
|
||||
|
||||
Sets ``row_factory`` on the connection so we can address columns by name.
|
||||
PRAGMA results in aiosqlite come back as plain tuples unless a row factory
|
||||
is configured.
|
||||
"""
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(f"PRAGMA index_list({table})")
|
||||
rows = await cursor.fetchall()
|
||||
return {row["name"] for row in rows}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _SCHEMA_VERSION
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSchemaVersion:
|
||||
def test_schema_version_is_v2(self):
|
||||
"""The current schema version is 2 (V2 adds auth_sessions + auth_meta)."""
|
||||
assert _SCHEMA_VERSION == 2
|
||||
|
||||
def test_sqlalchemy_model_table_name(self):
|
||||
assert AuthSessionModel.__tablename__ == "auth_sessions"
|
||||
assert UserSessionModel.__tablename__ == "user_sessions"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# init_auth_db: tables + indexes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInitAuthDbTables:
|
||||
async def test_creates_auth_sessions_table(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
cursor = await db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='auth_sessions'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
|
||||
async def test_creates_auth_meta_table(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
cursor = await db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='auth_meta'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
|
||||
async def test_creates_user_sessions_table_for_back_compat(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
cursor = await db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='user_sessions'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
|
||||
async def test_records_schema_version_in_auth_meta(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT value FROM auth_meta WHERE key='schema_version'")
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["value"] == str(_SCHEMA_VERSION)
|
||||
|
||||
|
||||
class TestAuthSessionsIndexes:
|
||||
async def test_user_id_active_index(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)):
|
||||
pass
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
indexes = await _list_index_names(db, "auth_sessions")
|
||||
assert "idx_auth_sessions_user_id_active" in indexes
|
||||
|
||||
async def test_expires_at_index(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
indexes = await _list_index_names(db, "auth_sessions")
|
||||
assert "idx_auth_sessions_expires_at" in indexes
|
||||
|
||||
async def test_auth_provider_index(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
indexes = await _list_index_names(db, "auth_sessions")
|
||||
assert "idx_auth_sessions_auth_provider" in indexes
|
||||
|
||||
async def test_refresh_token_hash_unique_index(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
# SQLite stores column-level UNIQUE constraints as PRAGMA index_list
|
||||
# entries with auto-generated names like sqlite_autoindex_<table>_1.
|
||||
# The PRAGMA index_info on each autoindex reports the constrained columns.
|
||||
cursor = await db.execute("PRAGMA index_list(auth_sessions)")
|
||||
index_entries = await cursor.fetchall()
|
||||
column_names: set[str] = set()
|
||||
for entry in index_entries:
|
||||
col_cursor = await db.execute(f"PRAGMA index_info({entry['name']})")
|
||||
col_rows = await col_cursor.fetchall()
|
||||
for col_row in col_rows:
|
||||
column_names.add(col_row["name"])
|
||||
assert "refresh_token_hash" in column_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# auth_sessions columns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthSessionsColumns:
|
||||
async def test_required_columns_present(self, fresh_db: Path):
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
cursor = await db.execute("PRAGMA table_info(auth_sessions)")
|
||||
cols = {row[1] for row in await cursor.fetchall()}
|
||||
required = {
|
||||
"id",
|
||||
"user_id",
|
||||
"refresh_token_hash",
|
||||
"device_fingerprint",
|
||||
"device_label",
|
||||
"ip",
|
||||
"user_agent",
|
||||
"auth_provider",
|
||||
"created_at",
|
||||
"last_active_at",
|
||||
"expires_at",
|
||||
"revoked",
|
||||
"revoked_reason",
|
||||
"previous_session_id",
|
||||
}
|
||||
assert required.issubset(cols)
|
||||
|
||||
async def test_auth_provider_default_is_local(self, fresh_db: Path):
|
||||
"""Insert a row without auth_provider → column should default to 'local'."""
|
||||
user_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
sid = str(uuid.uuid4())
|
||||
await db.execute(
|
||||
"INSERT INTO auth_sessions "
|
||||
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
|
||||
" revoked) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, 0)",
|
||||
(
|
||||
sid,
|
||||
user_id,
|
||||
"deadbeef",
|
||||
"2026-01-01T00:00:00+00:00",
|
||||
"2026-01-01T00:00:00+00:00",
|
||||
"2026-12-31T00:00:00+00:00",
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT auth_provider FROM auth_sessions WHERE id=?", (sid,))
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["auth_provider"] == "local"
|
||||
|
||||
async def test_revoked_default_is_false(self, fresh_db: Path):
|
||||
user_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
sid = str(uuid.uuid4())
|
||||
await db.execute(
|
||||
"INSERT INTO auth_sessions "
|
||||
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
sid,
|
||||
user_id,
|
||||
"beefdead",
|
||||
"2026-01-01T00:00:00+00:00",
|
||||
"2026-01-01T00:00:00+00:00",
|
||||
"2026-12-31T00:00:00+00:00",
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT revoked FROM auth_sessions WHERE id=?", (sid,))
|
||||
row = await cursor.fetchone()
|
||||
assert bool(row["revoked"]) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# auth_session_row_to_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthSessionRowToDict:
|
||||
async def test_converts_all_fields(self, fresh_db: Path):
|
||||
user_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
sid = str(uuid.uuid4())
|
||||
await db.execute(
|
||||
"INSERT INTO auth_sessions "
|
||||
"(id, user_id, refresh_token_hash, device_fingerprint, device_label, ip, "
|
||||
" user_agent, auth_provider, created_at, last_active_at, expires_at, "
|
||||
" revoked, revoked_reason, previous_session_id) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
sid,
|
||||
user_id,
|
||||
"abc123",
|
||||
"fp-xyz",
|
||||
"macOS Chrome 119",
|
||||
"10.0.0.1",
|
||||
"Mozilla/5.0",
|
||||
"local",
|
||||
"2026-01-01T00:00:00+00:00",
|
||||
"2026-06-20T00:00:00+00:00",
|
||||
"2027-01-01T00:00:00+00:00",
|
||||
1,
|
||||
"user_terminated",
|
||||
"old-sid",
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM auth_sessions WHERE id=?", (sid,))
|
||||
row = await cursor.fetchone()
|
||||
|
||||
d = auth_session_row_to_dict(row)
|
||||
assert d["id"] == sid
|
||||
assert d["user_id"] == user_id
|
||||
assert d["device_fingerprint"] == "fp-xyz"
|
||||
assert d["device_label"] == "macOS Chrome 119"
|
||||
assert d["ip"] == "10.0.0.1"
|
||||
assert d["user_agent"] == "Mozilla/5.0"
|
||||
assert d["auth_provider"] == "local"
|
||||
assert d["revoked"] is True
|
||||
assert d["revoked_reason"] == "user_terminated"
|
||||
assert d["previous_session_id"] == "old-sid"
|
||||
|
||||
async def test_normalizes_revoked_to_bool(self, fresh_db: Path):
|
||||
"""DB stores 0/1; helper should return Python bool."""
|
||||
user_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
sid = str(uuid.uuid4())
|
||||
await db.execute(
|
||||
"INSERT INTO auth_sessions "
|
||||
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
|
||||
" revoked) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, 0)",
|
||||
(sid, user_id, "x", "t", "t", "t"),
|
||||
)
|
||||
await db.commit()
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM auth_sessions WHERE id=?", (sid,))
|
||||
row = await cursor.fetchone()
|
||||
d = auth_session_row_to_dict(row)
|
||||
assert isinstance(d["revoked"], bool)
|
||||
assert d["revoked"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _backfill_user_sessions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBackfillUserSessions:
|
||||
async def test_backfills_non_revoked_v1_rows(self, fresh_db: Path):
|
||||
"""On a fresh V1 install, all non-revoked rows are migrated to auth_sessions."""
|
||||
user_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
await _insert_user_session(
|
||||
db,
|
||||
user_id=user_id,
|
||||
refresh_hash="hash1",
|
||||
device_info="{}",
|
||||
)
|
||||
await _insert_user_session(
|
||||
db,
|
||||
user_id=user_id,
|
||||
refresh_hash="hash2",
|
||||
device_info=json.dumps(
|
||||
{
|
||||
"fingerprint": "mac-tauri",
|
||||
"label": "macOS Tauri 1.0",
|
||||
"ip": "192.168.1.10",
|
||||
"user_agent": "Tauri/1.0",
|
||||
}
|
||||
),
|
||||
)
|
||||
await _insert_user_session(
|
||||
db,
|
||||
user_id=user_id,
|
||||
refresh_hash="hash3",
|
||||
revoked=True,
|
||||
)
|
||||
await db.commit()
|
||||
# Force a backfill by clearing the marker + dropping auth_sessions rows
|
||||
# (simulating a V1 install upgrading to V2).
|
||||
await db.execute("DELETE FROM auth_sessions")
|
||||
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||
await db.commit()
|
||||
|
||||
# Re-init to trigger the migration
|
||||
await init_auth_db(fresh_db)
|
||||
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute("SELECT * FROM auth_sessions ORDER BY refresh_token_hash")
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
# Only the 2 non-revoked rows should have been backfilled
|
||||
assert len(rows) == 2
|
||||
hashes = {row["refresh_token_hash"] for row in rows}
|
||||
assert hashes == {"hash1", "hash2"}
|
||||
|
||||
async def test_backfill_preserves_original_id(self, fresh_db: Path):
|
||||
"""Backfilled rows reuse the V1 id so legacy clients match."""
|
||||
user_id = str(uuid.uuid4())
|
||||
original_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
await _insert_user_session(
|
||||
db,
|
||||
user_id=user_id,
|
||||
session_id=original_id,
|
||||
refresh_hash="hash-original",
|
||||
)
|
||||
await db.execute("DELETE FROM auth_sessions")
|
||||
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||
await db.commit()
|
||||
|
||||
await init_auth_db(fresh_db)
|
||||
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT id FROM auth_sessions WHERE refresh_token_hash='hash-original'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["id"] == original_id
|
||||
|
||||
async def test_backfill_does_not_touch_revoked_v1_rows(self, fresh_db: Path):
|
||||
user_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
await _insert_user_session(
|
||||
db,
|
||||
user_id=user_id,
|
||||
refresh_hash="revoked-hash",
|
||||
revoked=True,
|
||||
)
|
||||
await db.execute("DELETE FROM auth_sessions")
|
||||
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||
await db.commit()
|
||||
|
||||
await init_auth_db(fresh_db)
|
||||
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT COUNT(*) AS c FROM auth_sessions WHERE refresh_token_hash='revoked-hash'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row["c"] == 0
|
||||
|
||||
async def test_backfill_copies_device_info_fields(self, fresh_db: Path):
|
||||
user_id = str(uuid.uuid4())
|
||||
device_info = json.dumps(
|
||||
{
|
||||
"fingerprint": "win-tauri-abc",
|
||||
"label": "Windows Tauri 1.0",
|
||||
"ip": "10.0.0.5",
|
||||
"user_agent": "Tauri/2.0",
|
||||
}
|
||||
)
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
await _insert_user_session(
|
||||
db,
|
||||
user_id=user_id,
|
||||
refresh_hash="hash-with-device",
|
||||
device_info=device_info,
|
||||
)
|
||||
await db.execute("DELETE FROM auth_sessions")
|
||||
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||
await db.commit()
|
||||
|
||||
await init_auth_db(fresh_db)
|
||||
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT device_fingerprint, device_label, ip, user_agent "
|
||||
"FROM auth_sessions WHERE refresh_token_hash='hash-with-device'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row["device_fingerprint"] == "win-tauri-abc"
|
||||
assert row["device_label"] == "Windows Tauri 1.0"
|
||||
assert row["ip"] == "10.0.0.5"
|
||||
assert row["user_agent"] == "Tauri/2.0"
|
||||
|
||||
async def test_backfill_handles_malformed_device_info(self, fresh_db: Path):
|
||||
"""Malformed JSON in device_info → fall back to defaults."""
|
||||
user_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
await _insert_user_session(
|
||||
db,
|
||||
user_id=user_id,
|
||||
refresh_hash="bad-json",
|
||||
device_info="not-json{",
|
||||
)
|
||||
await db.execute("DELETE FROM auth_sessions")
|
||||
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||
await db.commit()
|
||||
|
||||
await init_auth_db(fresh_db)
|
||||
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT device_fingerprint, device_label FROM auth_sessions "
|
||||
"WHERE refresh_token_hash='bad-json'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["device_fingerprint"] == "unknown"
|
||||
assert row["device_label"] == "Unknown device"
|
||||
|
||||
async def test_backfill_marks_done_in_auth_meta(self, fresh_db: Path):
|
||||
"""After a backfill, the auth_meta marker is set so it doesn't run again."""
|
||||
user_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
await _insert_user_session(db, user_id=user_id, refresh_hash="h1")
|
||||
await db.execute("DELETE FROM auth_sessions")
|
||||
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||
await db.commit()
|
||||
|
||||
await init_auth_db(fresh_db)
|
||||
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT value FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["value"] == "done"
|
||||
|
||||
async def test_backfill_is_idempotent(self, fresh_db: Path):
|
||||
"""Running init twice does NOT duplicate auth_sessions rows."""
|
||||
user_id = str(uuid.uuid4())
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
for i in range(3):
|
||||
await _insert_user_session(db, user_id=user_id, refresh_hash=f"hash{i}")
|
||||
await db.execute("DELETE FROM auth_sessions")
|
||||
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||
await db.commit()
|
||||
|
||||
await init_auth_db(fresh_db)
|
||||
await init_auth_db(fresh_db) # second run should be a no-op
|
||||
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
cursor = await db.execute("SELECT COUNT(*) AS c FROM auth_sessions")
|
||||
row = await cursor.fetchone()
|
||||
assert row[0] == 3
|
||||
|
||||
async def test_fresh_install_marks_backfill_done_without_rows(self, fresh_db: Path):
|
||||
"""A fresh V2 install (no V1 rows) still marks the backfill as done."""
|
||||
# The fresh_db fixture already ran init_auth_db once.
|
||||
# Re-running should be a no-op (idempotent).
|
||||
await init_auth_db(fresh_db)
|
||||
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT value FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row["value"] == "done"
|
||||
|
||||
async def test_backfill_preserves_expires_at(self, fresh_db: Path):
|
||||
user_id = str(uuid.uuid4())
|
||||
original_exp = "2027-06-20T12:34:56+00:00"
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
await db.execute(
|
||||
"INSERT INTO user_sessions "
|
||||
"(id, user_id, refresh_token_hash, device_info, created_at, expires_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
user_id,
|
||||
"exp-hash",
|
||||
"{}",
|
||||
"2026-01-01T00:00:00+00:00",
|
||||
original_exp,
|
||||
),
|
||||
)
|
||||
await db.execute("DELETE FROM auth_sessions")
|
||||
await db.execute("DELETE FROM auth_meta WHERE key='backfill_user_sessions_v1_to_v2'")
|
||||
await db.commit()
|
||||
|
||||
await init_auth_db(fresh_db)
|
||||
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT expires_at FROM auth_sessions WHERE refresh_token_hash='exp-hash'"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
assert row["expires_at"] == original_exp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Active-session query pattern (covers the cap-count and list-active paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestActiveSessionQueries:
|
||||
async def test_query_active_sessions_for_user(self, fresh_db: Path):
|
||||
"""The (user_id, revoked, expires_at) index supports the active-session query."""
|
||||
user_id = str(uuid.uuid4())
|
||||
future = "2027-12-31T00:00:00+00:00"
|
||||
past = "2020-01-01T00:00:00+00:00"
|
||||
async with aiosqlite.connect(str(fresh_db)) as db:
|
||||
await _insert_user(db, user_id=user_id)
|
||||
# 2 active, 1 expired, 1 revoked → 2 should match
|
||||
for hash_, exp in [("a", future), ("b", future), ("c", past)]:
|
||||
await db.execute(
|
||||
"INSERT INTO auth_sessions "
|
||||
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
|
||||
" revoked) VALUES (?, ?, ?, ?, ?, ?, 0)",
|
||||
(str(uuid.uuid4()), user_id, hash_, "t", "t", exp),
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT INTO auth_sessions "
|
||||
"(id, user_id, refresh_token_hash, created_at, last_active_at, expires_at, "
|
||||
" revoked, revoked_reason) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, 1, 'user_terminated')",
|
||||
(str(uuid.uuid4()), user_id, "d", "t", "t", future),
|
||||
)
|
||||
await db.commit()
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT refresh_token_hash FROM auth_sessions "
|
||||
"WHERE user_id = ? AND revoked = 0 AND expires_at > ?",
|
||||
(user_id, "2026-06-20T00:00:00+00:00"),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
hashes = {row["refresh_token_hash"] for row in rows}
|
||||
assert hashes == {"a", "b"}
|
||||
Loading…
Reference in New Issue