# fischer-agentkit/tests/unit/auth/test_session_service.py
#
# 257 lines
# 8.9 KiB
# Python

"""Unit tests for SessionService (U3)."""
from __future__ import annotations
import asyncio
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
import aiosqlite
import pytest
from agentkit.server.auth.denylist import InMemoryRecentlyRevoked
from agentkit.server.auth.models import init_auth_db
from agentkit.server.auth.session_service import (
REVOKE_REASON_PASSWORD_CHANGED,
REVOKE_REASON_REUSE_DETECTED,
REVOKE_REASON_SESSION_CAP_EVICTION,
REVOKE_REASON_USER_TERMINATED,
SessionCreate,
SessionNotFound,
SessionReuseDetected,
SessionService,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def auth_db(tmp_path: Path):
    """Create a fresh auth DB under tmp_path and point AGENTKIT_AUTH_DB at it.

    Yields the database path. On teardown the environment variable is put
    back to its previous value, or removed if it was not set before.
    """
    env_key = "AGENTKIT_AUTH_DB"
    db_file = tmp_path / "auth.db"
    await init_auth_db(db_file)

    saved = os.environ.get(env_key)
    os.environ[env_key] = str(db_file)
    try:
        yield db_file
    finally:
        # Restore whatever the process environment looked like beforehand.
        if saved is not None:
            os.environ[env_key] = saved
        else:
            os.environ.pop(env_key, None)
@pytest.fixture
async def user_id(auth_db: Path) -> str:
    """Insert one user row ("alice") into the auth DB and return its generated id."""
    new_id = str(uuid.uuid4())
    timestamp = datetime.now(timezone.utc).isoformat()
    insert_sql = (
        "INSERT INTO users (id, username, email, password_hash, role, "
        "is_active, is_terminal_authorized, is_server_terminal_authorized, "
        "created_at, updated_at, last_login_at, created_by) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )
    # Values are positional and follow the column order of the INSERT above.
    row = (
        new_id,  # id
        "alice",  # username
        "alice@example.com",  # email
        "x",  # password_hash
        "member",  # role
        1,  # is_active
        0,  # is_terminal_authorized
        0,  # is_server_terminal_authorized
        timestamp,  # created_at
        timestamp,  # updated_at
        None,  # last_login_at
        None,  # created_by
    )
    async with aiosqlite.connect(str(auth_db)) as conn:
        await conn.execute(insert_sql, row)
        await conn.commit()
    return new_id
@pytest.fixture
async def svc(auth_db: Path) -> SessionService:
    """Build a SessionService on the temp auth DB with an in-memory denylist and session_cap=3."""
    recently_revoked = InMemoryRecentlyRevoked()
    return SessionService(db_path=auth_db, denylist=recently_revoked, session_cap=3)
def _make_create(user_id: str, refresh_token: str = "r1") -> SessionCreate:
    """Return a SessionCreate for *user_id* with fixed test device metadata and a 3600 s TTL."""
    fields = {
        "user_id": user_id,
        "refresh_token": refresh_token,
        "device_fingerprint": "fp",
        "device_label": "Test device",
        "ip": "127.0.0.1",
        "user_agent": "pytest",
        "auth_provider": "local",
        "ttl_seconds": 3600,
    }
    return SessionCreate(**fields)
# ---------------------------------------------------------------------------
# create
# ---------------------------------------------------------------------------
async def test_create_inserts_row_and_returns_info(svc: SessionService, user_id: str):
    """create() returns an unrevoked session for the user that get() can read back."""
    created = await svc.create(_make_create(user_id))
    assert created.user_id == user_id
    assert created.revoked is False

    # Read the same session back through the public lookup.
    roundtrip = await svc.get(created.id)
    assert roundtrip is not None
    assert roundtrip.id == created.id
async def test_create_evicts_oldest_when_at_cap(svc: SessionService, user_id: str):
    """Creating a session beyond the cap revokes the first-created one.

    The ``svc`` fixture is built with ``session_cap=3``, so the fourth
    ``create`` must revoke session ``a`` with the cap-eviction reason and
    leave ``b``, ``c`` and ``d`` unrevoked.
    """
    a = await svc.create(_make_create(user_id, "rt-a"))
    b = await svc.create(_make_create(user_id, "rt-b"))
    c = await svc.create(_make_create(user_id, "rt-c"))
    d = await svc.create(_make_create(user_id, "rt-d"))

    # Fetch once and check for None explicitly so a missing row fails with a
    # readable assertion instead of an AttributeError on None.
    evicted = await svc.get(a.id)
    assert evicted is not None
    assert evicted.revoked is True
    assert evicted.revoked_reason == REVOKE_REASON_SESSION_CAP_EVICTION

    for kept in (b, c, d):
        survivor = await svc.get(kept.id)
        assert survivor is not None
        assert survivor.revoked is False

    # Sanity: exactly session_cap remain (this test uses cap=3)
    remaining = [s for s in await svc.list_for_user(user_id) if not s.revoked]
    assert len(remaining) == 3
# ---------------------------------------------------------------------------
# rotate
# ---------------------------------------------------------------------------
async def test_rotate_replaces_refresh_hash(svc: SessionService, user_id: str):
    """rotate() keeps the session id; only the new refresh token resolves afterwards."""
    original = await svc.create(_make_create(user_id, "rt-old"))
    rotated = await svc.rotate(
        old_refresh_token="rt-old",
        new_refresh_token="rt-new",
        new_ttl_seconds=3600,
    )
    assert rotated.id == original.id

    # The old token must no longer map to any session ...
    stale_lookup = await svc.find_by_refresh_token("rt-old")
    assert stale_lookup is None

    # ... while the new token maps to the same session.
    fresh_lookup = await svc.find_by_refresh_token("rt-new")
    assert fresh_lookup.id == original.id
async def test_rotate_adds_old_to_denylist(svc: SessionService, user_id: str):
    """After rotate(), the SHA-256 hex digest of the old token is in the denylist."""
    # Function-scope import replaces the inline __import__("hashlib") hack.
    import hashlib

    await svc.create(_make_create(user_id, "rt-1"))
    await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)

    old_token_digest = hashlib.sha256(b"rt-1").hexdigest()
    # NOTE(review): this reaches into the private `_denylist` attribute; the
    # test is coupled to SessionService internals.
    assert svc._denylist.contains(old_token_digest)
async def test_rotate_unknown_token_raises(svc: SessionService, user_id: str):
    """rotate() with a refresh token that was never issued raises SessionNotFound."""
    with pytest.raises(SessionNotFound):
        await svc.rotate(
            "never-issued",
            "rt-new",
            new_ttl_seconds=3600,
        )
async def test_rotate_reuse_raises_and_revokes_all(svc: SessionService, user_id: str):
    """Replaying an already-rotated token raises and revokes all the user's sessions.

    After the reuse attempt, every session listed for the user (revoked ones
    included) must be revoked with the reuse-detected reason.
    """
    await svc.create(_make_create(user_id, "rt-1"))
    await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)

    # A second attempt with the old token: must raise and revoke all
    with pytest.raises(SessionReuseDetected):
        await svc.rotate("rt-1", "rt-2b", new_ttl_seconds=3600)

    sessions = await svc.list_for_user(user_id, include_revoked=True)
    # all() is True for an empty iterable, so check the listing is non-empty
    # first; otherwise the assertions below would pass vacuously.
    assert sessions, "expected the created session to be listed"
    assert all(s.revoked for s in sessions)
    assert all(s.revoked_reason == REVOKE_REASON_REUSE_DETECTED for s in sessions)
async def test_rotate_revoked_session_raises(svc: SessionService, user_id: str):
    """rotate() on a token whose session was revoked raises SessionNotFound."""
    token = "rt-1"
    await svc.create(_make_create(user_id, token))
    await svc.revoke_by_refresh_token(token)

    with pytest.raises(SessionNotFound):
        await svc.rotate(token, "rt-2", new_ttl_seconds=3600)
# ---------------------------------------------------------------------------
# revoke
# ---------------------------------------------------------------------------
async def test_revoke_marks_session_revoked(svc: SessionService, user_id: str):
    """revoke() returns True and the session reads back as revoked by the user."""
    info = await svc.create(_make_create(user_id))

    ok = await svc.revoke(info.id)
    assert ok is True

    # Fetch once and check for None explicitly so a missing row fails with a
    # readable assertion instead of an AttributeError on None.
    revoked = await svc.get(info.id)
    assert revoked is not None
    assert revoked.revoked is True
    assert revoked.revoked_reason == REVOKE_REASON_USER_TERMINATED
async def test_revoke_returns_false_if_already_revoked(svc: SessionService, user_id: str):
    """The first revoke() of a session returns True; a repeat returns False."""
    session = await svc.create(_make_create(user_id))

    first_attempt = await svc.revoke(session.id)
    assert first_attempt is True

    second_attempt = await svc.revoke(session.id)
    assert second_attempt is False
async def test_revoke_by_refresh_token(svc: SessionService, user_id: str):
    """revoke_by_refresh_token() is True once, then False; unknown tokens give False."""
    await svc.create(_make_create(user_id, "rt-1"))

    first_attempt = await svc.revoke_by_refresh_token("rt-1")
    assert first_attempt is True

    repeat_attempt = await svc.revoke_by_refresh_token("rt-1")
    assert repeat_attempt is False

    unknown_attempt = await svc.revoke_by_refresh_token("rt-never")
    assert unknown_attempt is False
async def test_revoke_all_for_user(svc: SessionService, user_id: str):
    """revoke_all_for_user() reports the number revoked and leaves nothing in the default listing."""
    for token in ("a", "b"):
        await svc.create(_make_create(user_id, token))

    revoked_count = await svc.revoke_all_for_user(
        user_id,
        reason=REVOKE_REASON_PASSWORD_CHANGED,
    )
    assert revoked_count == 2

    still_listed = await svc.list_for_user(user_id)
    assert still_listed == []
# ---------------------------------------------------------------------------
# list / get
# ---------------------------------------------------------------------------
async def test_list_for_user_excludes_revoked_by_default(svc: SessionService, user_id: str):
    """list_for_user() hides revoked sessions unless include_revoked=True is passed."""
    revoked_session = await svc.create(_make_create(user_id, "a"))
    live_session = await svc.create(_make_create(user_id, "b"))
    await svc.revoke(revoked_session.id)

    # Default listing: only the live session is visible.
    default_listing = await svc.list_for_user(user_id)
    default_ids = [s.id for s in default_listing]
    assert default_ids == [live_session.id]

    # Opt-in listing: both sessions are present (order is not checked).
    full_listing = await svc.list_for_user(user_id, include_revoked=True)
    full_ids = {s.id for s in full_listing}
    assert full_ids == {revoked_session.id, live_session.id}
async def test_is_session_valid_rejects_revoked(svc: SessionService, user_id: str):
    """is_session_valid() is True for a new session and False once it is revoked."""
    session = await svc.create(_make_create(user_id))

    before_revoke = await svc.is_session_valid(session.id)
    assert before_revoke is True

    await svc.revoke(session.id)

    after_revoke = await svc.is_session_valid(session.id)
    assert after_revoke is False
async def test_is_session_valid_rejects_expired(svc: SessionService, user_id: str):
    """A session created with ttl_seconds=0 is reported as invalid."""
    # Built by hand rather than via _make_create, which hard-codes a 3600 s TTL.
    zero_ttl_request = SessionCreate(
        user_id=user_id,
        refresh_token="rt-x",
        device_fingerprint="fp",
        device_label="d",
        ip="",
        user_agent="",
        auth_provider="local",
        ttl_seconds=0,  # expires immediately
    )
    session = await svc.create(zero_ttl_request)

    # With a zero TTL the session is expected to count as expired already.
    still_valid = await svc.is_session_valid(session.id)
    assert still_valid is False
async def test_is_session_valid_returns_false_for_unknown(svc: SessionService):
    """is_session_valid() returns False for an id that matches no session."""
    verdict = await svc.is_session_valid("nonexistent-id")
    assert verdict is False