fischer-agentkit/tests/unit/auth/test_session_service.py

336 lines
12 KiB
Python

"""Unit tests for SessionService (U3)."""
from __future__ import annotations

import asyncio
import hashlib
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path

import aiosqlite
import pytest

from agentkit.server.auth.denylist import InMemoryRecentlyRevoked
from agentkit.server.auth.models import init_auth_db
from agentkit.server.auth.session_service import (
    REVOKE_REASON_PASSWORD_CHANGED,
    REVOKE_REASON_REUSE_DETECTED,
    REVOKE_REASON_SESSION_CAP_EVICTION,
    REVOKE_REASON_USER_TERMINATED,
    SessionCreate,
    SessionNotFound,
    SessionReuseDetected,
    SessionService,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def auth_db(tmp_path: Path):
    """Provision an empty auth database file under ``tmp_path``.

    The schema is created with ``init_auth_db``. While the fixture is active,
    ``AGENTKIT_AUTH_DB`` points at the new file. On teardown the variable's
    previous value is restored, or the variable is removed if it was unset.

    Yields:
        Path of the freshly initialised database file.
    """
    path = tmp_path / "auth.db"
    await init_auth_db(path)
    env_key = "AGENTKIT_AUTH_DB"
    saved = os.environ.get(env_key)
    os.environ[env_key] = str(path)
    try:
        yield path
    finally:
        # Put the environment back exactly as we found it.
        if saved is not None:
            os.environ[env_key] = saved
        else:
            os.environ.pop(env_key, None)
@pytest.fixture
async def user_id(auth_db: Path) -> str:
    """Seed the ``users`` table with a single active "alice" member account.

    Returns:
        The freshly generated UUID string used as the row's ``id``.
    """
    new_id = str(uuid.uuid4())
    timestamp = datetime.now(timezone.utc).isoformat()
    insert_sql = (
        "INSERT INTO users (id, username, email, password_hash, role, "
        "is_active, is_terminal_authorized, is_server_terminal_authorized, "
        "created_at, updated_at, last_login_at, created_by) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )
    row = (
        new_id,
        "alice",  # username
        "alice@example.com",  # email
        "x",  # password_hash (dummy value)
        "member",  # role
        1,  # is_active
        0,  # is_terminal_authorized
        0,  # is_server_terminal_authorized
        timestamp,  # created_at
        timestamp,  # updated_at
        None,  # last_login_at
        None,  # created_by
    )
    async with aiosqlite.connect(str(auth_db)) as conn:
        await conn.execute(insert_sql, row)
        await conn.commit()
    return new_id
@pytest.fixture
async def svc(auth_db: Path) -> SessionService:
    """Build a ``SessionService`` on top of the temporary auth DB.

    The in-memory denylist keeps the tests deterministic. A ``session_cap`` of 3
    makes cap-eviction easy to trigger.
    """
    denylist = InMemoryRecentlyRevoked()
    service = SessionService(db_path=auth_db, denylist=denylist, session_cap=3)
    return service
def _make_create(user_id: str, refresh_token: str = "r1") -> SessionCreate:
    """Return a ``SessionCreate`` with fixed, test-friendly device metadata.

    Only ``user_id`` and ``refresh_token`` vary between calls. The provider is
    always ``"local"`` and ``ttl_seconds`` is 3600.
    """
    fixed_fields = dict(
        device_fingerprint="fp",
        device_label="Test device",
        ip="127.0.0.1",
        user_agent="pytest",
        auth_provider="local",
        ttl_seconds=3600,
    )
    return SessionCreate(user_id=user_id, refresh_token=refresh_token, **fixed_fields)
# ---------------------------------------------------------------------------
# create
# ---------------------------------------------------------------------------
async def test_create_inserts_row_and_returns_info(svc: SessionService, user_id: str):
    """create() returns a live session for the user that get() can read back."""
    created = await svc.create(_make_create(user_id))
    assert created.user_id == user_id
    assert created.revoked is False

    round_trip = await svc.get(created.id)
    assert round_trip is not None
    assert round_trip.id == created.id
async def test_create_evicts_oldest_when_at_cap(svc: SessionService, user_id: str):
    """With session_cap=3, a fourth create() revokes the oldest session."""
    first, *survivors = [
        await svc.create(_make_create(user_id, token))
        for token in ("rt-a", "rt-b", "rt-c", "rt-d")
    ]

    evicted = await svc.get(first.id)
    assert evicted.revoked is True
    assert evicted.revoked_reason == REVOKE_REASON_SESSION_CAP_EVICTION

    for session in survivors:
        assert (await svc.get(session.id)).revoked is False

    # Exactly session_cap (3 in the svc fixture) sessions must remain live.
    live = [s for s in await svc.list_for_user(user_id) if not s.revoked]
    assert len(live) == 3
# ---------------------------------------------------------------------------
# rotate
# ---------------------------------------------------------------------------
async def test_rotate_replaces_refresh_hash(svc: SessionService, user_id: str):
    """rotate() keeps the session id but changes which refresh token resolves to it."""
    original = await svc.create(_make_create(user_id, "rt-old"))
    rotated = await svc.rotate(
        old_refresh_token="rt-old",
        new_refresh_token="rt-new",
        new_ttl_seconds=3600,
    )
    assert rotated.id == original.id

    # The retired token must no longer be found...
    assert await svc.find_by_refresh_token("rt-old") is None
    # ...while the new one maps to the same session.
    found = await svc.find_by_refresh_token("rt-new")
    assert found.id == original.id
async def test_rotate_adds_old_to_denylist(svc: SessionService, user_id: str):
    """rotate() records the retired refresh token in the service's denylist.

    The lookup key used here is the SHA-256 hex digest of the raw token. That
    mirrors how the denylist entry is expected to be keyed.
    """
    await svc.create(_make_create(user_id, "rt-1"))
    await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)
    old_token_hash = hashlib.sha256(b"rt-1").hexdigest()
    assert svc._denylist.contains(old_token_hash)
async def test_rotate_unknown_token_raises(svc: SessionService, user_id: str):
    """Rotating a refresh token that was never issued raises SessionNotFound."""
    unknown_token = "never-issued"
    with pytest.raises(SessionNotFound):
        await svc.rotate(unknown_token, "rt-new", new_ttl_seconds=3600)
async def test_rotate_reuse_raises_and_revokes_all(svc: SessionService, user_id: str):
    """Replaying an already-rotated refresh token is detected as reuse.

    The second rotate() with ``rt-1`` must raise SessionReuseDetected. Afterwards
    every session of the user must be revoked with the reuse-detected reason.
    """
    await svc.create(_make_create(user_id, "rt-1"))
    await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)

    # A second attempt with the old token must raise and revoke everything.
    with pytest.raises(SessionReuseDetected):
        await svc.rotate("rt-1", "rt-2b", new_ttl_seconds=3600)

    sessions = await svc.list_for_user(user_id, include_revoked=True)
    # Guard against a vacuous pass: all() over an empty list is True.
    assert sessions
    assert all(s.revoked for s in sessions)
    assert all(s.revoked_reason == REVOKE_REASON_REUSE_DETECTED for s in sessions)
async def test_rotate_revoked_session_raises(svc: SessionService, user_id: str):
    """A refresh token whose session was revoked can no longer be rotated."""
    token = "rt-1"
    await svc.create(_make_create(user_id, token))
    await svc.revoke_by_refresh_token(token)

    with pytest.raises(SessionNotFound):
        await svc.rotate(token, "rt-2", new_ttl_seconds=3600)
# ---------------------------------------------------------------------------
# revoke
# ---------------------------------------------------------------------------
async def test_revoke_marks_session_revoked(svc: SessionService, user_id: str):
    """revoke(id) succeeds and tags the session with the user-terminated reason."""
    session = await svc.create(_make_create(user_id))
    assert await svc.revoke(session.id) is True

    stored = await svc.get(session.id)
    assert stored.revoked is True
    assert stored.revoked_reason == REVOKE_REASON_USER_TERMINATED
async def test_revoke_returns_false_if_already_revoked(svc: SessionService, user_id: str):
    """revoke() reports True the first time and False on a repeat call."""
    session = await svc.create(_make_create(user_id))
    for expected in (True, False):
        assert await svc.revoke(session.id) is expected
async def test_revoke_by_refresh_token(svc: SessionService, user_id: str):
    """revoke_by_refresh_token() returns True only for a live, known token."""
    await svc.create(_make_create(user_id, "rt-1"))
    cases = [
        ("rt-1", True),  # live session: revoked now
        ("rt-1", False),  # already revoked
        ("rt-never", False),  # never issued
    ]
    for token, expected in cases:
        assert await svc.revoke_by_refresh_token(token) is expected
async def test_revoke_all_for_user(svc: SessionService, user_id: str):
    """revoke_all_for_user() revokes every live session and records the reason.

    The call must report the number of sessions it revoked. Afterwards none may
    be listed as live, and each must carry the reason passed in.
    """
    await svc.create(_make_create(user_id, "a"))
    await svc.create(_make_create(user_id, "b"))

    n = await svc.revoke_all_for_user(
        user_id, reason=REVOKE_REASON_PASSWORD_CHANGED
    )
    assert n == 2
    assert await svc.list_for_user(user_id) == []

    # The supplied reason must actually be persisted, not just accepted.
    revoked = await svc.list_for_user(user_id, include_revoked=True)
    assert len(revoked) == 2
    assert all(s.revoked_reason == REVOKE_REASON_PASSWORD_CHANGED for s in revoked)
# ---------------------------------------------------------------------------
# list / get
# ---------------------------------------------------------------------------
async def test_list_for_user_excludes_revoked_by_default(svc: SessionService, user_id: str):
    """list_for_user() hides revoked sessions unless include_revoked=True."""
    revoked_one = await svc.create(_make_create(user_id, "a"))
    live_one = await svc.create(_make_create(user_id, "b"))
    await svc.revoke(revoked_one.id)

    default_ids = [s.id for s in await svc.list_for_user(user_id)]
    assert default_ids == [live_one.id]

    every_id = {s.id for s in await svc.list_for_user(user_id, include_revoked=True)}
    assert every_id == {revoked_one.id, live_one.id}
async def test_is_session_valid_rejects_revoked(svc: SessionService, user_id: str):
    """A session is valid while live and becomes invalid once revoked."""
    sid = (await svc.create(_make_create(user_id))).id
    assert await svc.is_session_valid(sid) is True
    await svc.revoke(sid)
    assert await svc.is_session_valid(sid) is False
async def test_is_session_valid_rejects_expired(svc: SessionService, user_id: str):
    """A session created with ttl_seconds=0 is reported as invalid (expired)."""
    # Create with a zero TTL so the session is already expired on creation.
    create = SessionCreate(
        user_id=user_id,
        refresh_token="rt-x",
        device_fingerprint="fp",
        device_label="d",
        ip="",
        user_agent="",
        auth_provider="local",
        ttl_seconds=0,  # expires immediately
    )
    info = await svc.create(create)
    # The session's expiry is not in the future, so it must be rejected.
    # NOTE(review): this assumes the validity check treats expires_at <= now as
    # expired, or that the clock advances between create() and the check.
    # Confirm against is_session_valid's comparison to rule out flakiness.
    assert await svc.is_session_valid(info.id) is False
async def test_is_session_valid_returns_false_for_unknown(svc: SessionService):
    """An id that was never issued yields False rather than raising."""
    missing_id = "nonexistent-id"
    assert await svc.is_session_valid(missing_id) is False
# ---------------------------------------------------------------------------
# list_active_by_provider (U2 fix: expired sessions must be filtered)
# ---------------------------------------------------------------------------
async def test_list_active_by_provider_excludes_expired(svc: SessionService, user_id: str):
    """Expired sessions must NOT appear in list_active_by_provider (U2 fix).

    The method's docstring promises "non-revoked, non-expired" sessions. Before
    the U2 fix, its SQL checked only ``revoked = 0`` and ignored ``expires_at``.
    """
    # ttl_seconds=0 makes this session expire as soon as it is created.
    stale = await svc.create(
        SessionCreate(
            user_id=user_id,
            refresh_token="rt-expired",
            device_fingerprint="fp",
            device_label="expired-device",
            ip="",
            user_agent="",
            auth_provider="local",
            ttl_seconds=0,
        )
    )
    fresh = await svc.create(_make_create(user_id, "rt-live"))

    listed = {s.id for s in await svc.list_active_by_provider("local")}
    # Only the live session may be reported as active.
    assert fresh.id in listed
    assert stale.id not in listed
async def test_list_active_by_provider_excludes_revoked(svc: SessionService, user_id: str):
    """Regression guard: revoked sessions are left out of the active listing."""
    to_revoke = await svc.create(_make_create(user_id, "rt-a"))
    kept = await svc.create(_make_create(user_id, "rt-b"))
    await svc.revoke(to_revoke.id)

    listed = {s.id for s in await svc.list_active_by_provider("local")}
    assert kept.id in listed
    assert to_revoke.id not in listed
async def test_list_active_by_provider_filters_by_provider(
    svc: SessionService, user_id: str
):
    """Only sessions matching the requested auth_provider are returned.

    One "local" and one "oidc" session are created for the same user. Each
    provider query must return exactly its own session.
    """
    # SessionCreate is a frozen dataclass — build each with its provider.
    local_create = SessionCreate(
        user_id=user_id,
        refresh_token="rt-local",
        device_fingerprint="fp",
        device_label="d",
        ip="",
        user_agent="",
        auth_provider="local",
        ttl_seconds=3600,
    )
    local_info = await svc.create(local_create)
    oidc_create = SessionCreate(
        user_id=user_id,
        refresh_token="rt-oidc",
        device_fingerprint="fp",
        device_label="d",
        ip="",
        user_agent="",
        auth_provider="oidc",
        ttl_seconds=3600,
    )
    oidc_info = await svc.create(oidc_create)

    local_only = await svc.list_active_by_provider("local")
    assert all(s.auth_provider == "local" for s in local_only)
    # Pin the exact session, not just the count.
    assert [s.id for s in local_only] == [local_info.id]

    # The filter must also work in the other direction.
    oidc_only = await svc.list_active_by_provider("oidc")
    assert [s.id for s in oidc_only] == [oidc_info.id]