336 lines
12 KiB
Python
336 lines
12 KiB
Python
"""Unit tests for SessionService (U3)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
|
|
from agentkit.server.auth.denylist import InMemoryRecentlyRevoked
|
|
from agentkit.server.auth.models import init_auth_db
|
|
from agentkit.server.auth.session_service import (
|
|
REVOKE_REASON_PASSWORD_CHANGED,
|
|
REVOKE_REASON_REUSE_DETECTED,
|
|
REVOKE_REASON_SESSION_CAP_EVICTION,
|
|
REVOKE_REASON_USER_TERMINATED,
|
|
SessionCreate,
|
|
SessionNotFound,
|
|
SessionReuseDetected,
|
|
SessionService,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
async def auth_db(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
    """Initialise a fresh auth DB in a tmpdir and point AGENTKIT_AUTH_DB at it.

    ``monkeypatch.setenv`` records the previous value of ``AGENTKIT_AUTH_DB``
    (or that it was unset) and restores it at teardown. This replaces a
    hand-rolled save/restore of ``os.environ``.

    Yields:
        Path to the initialised SQLite auth database file.
    """
    db_path = tmp_path / "auth.db"
    await init_auth_db(db_path)
    # Set only after the schema exists, so nothing can observe a half-built DB.
    monkeypatch.setenv("AGENTKIT_AUTH_DB", str(db_path))
    yield db_path
|
|
|
|
|
|
@pytest.fixture
async def user_id(auth_db: Path) -> str:
    """Create one active 'member' user named alice and return the generated id."""
    new_id = str(uuid.uuid4())
    stamp = datetime.now(timezone.utc).isoformat()

    sql = (
        "INSERT INTO users (id, username, email, password_hash, role, "
        "is_active, is_terminal_authorized, is_server_terminal_authorized, "
        "created_at, updated_at, last_login_at, created_by) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )
    row = (
        new_id,
        "alice",
        "alice@example.com",
        "x",  # dummy password hash; no test in this module checks passwords
        "member",
        1,  # is_active
        0,  # is_terminal_authorized
        0,  # is_server_terminal_authorized
        stamp,  # created_at
        stamp,  # updated_at
        None,  # last_login_at
        None,  # created_by
    )

    async with aiosqlite.connect(str(auth_db)) as conn:
        await conn.execute(sql, row)
        await conn.commit()

    return new_id
|
|
|
|
|
|
@pytest.fixture
async def svc(auth_db: Path) -> SessionService:
    """Build a SessionService over the temp auth DB with a session cap of 3.

    An in-memory denylist keeps revocation state local to the test, so
    behaviour is deterministic.
    """
    denylist = InMemoryRecentlyRevoked()
    service = SessionService(db_path=auth_db, denylist=denylist, session_cap=3)
    return service
|
|
|
|
|
|
def _make_create(user_id: str, refresh_token: str = "r1") -> SessionCreate:
    """Return a SessionCreate for *user_id* using fixed local-provider defaults.

    Only the refresh token varies between callers. The device, IP and
    user-agent fields are constant placeholders, and ttl_seconds is 3600.
    """
    kwargs = {
        "user_id": user_id,
        "refresh_token": refresh_token,
        "device_fingerprint": "fp",
        "device_label": "Test device",
        "ip": "127.0.0.1",
        "user_agent": "pytest",
        "auth_provider": "local",
        "ttl_seconds": 3600,
    }
    return SessionCreate(**kwargs)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# create
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_create_inserts_row_and_returns_info(svc: SessionService, user_id: str):
    """create() returns an unrevoked session for the user that get() can read back."""
    created = await svc.create(_make_create(user_id))
    assert created.user_id == user_id
    assert created.revoked is False

    # The new row must also be visible through the read API.
    roundtrip = await svc.get(created.id)
    assert roundtrip is not None
    assert roundtrip.id == created.id
|
|
|
|
|
|
async def test_create_evicts_oldest_when_at_cap(svc: SessionService, user_id: str):
    """Creating a session past the cap revokes the oldest one with the eviction reason.

    The ``svc`` fixture uses ``session_cap=3``. Three sessions fill the cap,
    and the fourth must evict the first while leaving the others untouched.
    """
    a = await svc.create(_make_create(user_id, "rt-a"))
    b = await svc.create(_make_create(user_id, "rt-b"))
    c = await svc.create(_make_create(user_id, "rt-c"))
    d = await svc.create(_make_create(user_id, "rt-d"))

    # Fetch once and check existence first, so a missing row fails with a
    # clear assertion instead of an AttributeError on None.
    evicted = await svc.get(a.id)
    assert evicted is not None
    assert evicted.revoked is True
    assert evicted.revoked_reason == REVOKE_REASON_SESSION_CAP_EVICTION

    for kept in (b, c, d):
        row = await svc.get(kept.id)
        assert row is not None
        assert row.revoked is False

    # Sanity: exactly session_cap (3 in this fixture) unrevoked sessions remain.
    remaining = [s for s in await svc.list_for_user(user_id) if not s.revoked]
    assert len(remaining) == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# rotate
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_rotate_replaces_refresh_hash(svc: SessionService, user_id: str):
    """rotate() keeps the session id but changes which refresh token resolves to it."""
    original = await svc.create(_make_create(user_id, "rt-old"))
    rotated = await svc.rotate(
        old_refresh_token="rt-old",
        new_refresh_token="rt-new",
        new_ttl_seconds=3600,
    )
    assert rotated.id == original.id

    # The retired token must no longer resolve to any session...
    assert await svc.find_by_refresh_token("rt-old") is None
    # ...while the replacement resolves to the same session.
    by_new = await svc.find_by_refresh_token("rt-new")
    assert by_new.id == original.id
|
|
|
|
|
|
async def test_rotate_adds_old_to_denylist(svc: SessionService, user_id: str):
    """rotate() puts the retired refresh token's hash on the recently-revoked denylist.

    The test expects the denylist entry to be the SHA-256 hex digest of the
    old token, not the raw token string.
    """
    import hashlib

    await svc.create(_make_create(user_id, "rt-1"))
    await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)

    old_hash = hashlib.sha256(b"rt-1").hexdigest()
    # _denylist is private; inspected directly because no public accessor exists here.
    assert svc._denylist.contains(old_hash)
|
|
|
|
|
|
async def test_rotate_unknown_token_raises(svc: SessionService, user_id: str):
    """Rotating a refresh token that was never issued raises SessionNotFound."""
    never_issued = "never-issued"
    with pytest.raises(SessionNotFound):
        await svc.rotate(never_issued, "rt-new", new_ttl_seconds=3600)
|
|
|
|
|
|
async def test_rotate_reuse_raises_and_revokes_all(svc: SessionService, user_id: str):
    """Replaying an already-rotated refresh token raises and revokes every session.

    After ``rt-1`` is rotated to ``rt-2``, presenting ``rt-1`` again is token
    reuse. rotate() must raise SessionReuseDetected and revoke all of the
    user's sessions with REVOKE_REASON_REUSE_DETECTED.
    """
    await svc.create(_make_create(user_id, "rt-1"))
    await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)

    # A second attempt with the old token must raise and trigger mass revocation.
    with pytest.raises(SessionReuseDetected):
        await svc.rotate("rt-1", "rt-2b", new_ttl_seconds=3600)

    sessions = await svc.list_for_user(user_id, include_revoked=True)
    # Guard against a vacuous pass: all() is True for an empty list.
    assert sessions, "expected at least one session for the user"
    assert all(s.revoked for s in sessions)
    assert all(s.revoked_reason == REVOKE_REASON_REUSE_DETECTED for s in sessions)
|
|
|
|
|
|
async def test_rotate_revoked_session_raises(svc: SessionService, user_id: str):
    """A refresh token whose session was explicitly revoked cannot be rotated."""
    token = "rt-1"
    await svc.create(_make_create(user_id, token))
    await svc.revoke_by_refresh_token(token)

    with pytest.raises(SessionNotFound):
        await svc.rotate(token, "rt-2", new_ttl_seconds=3600)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# revoke
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_revoke_marks_session_revoked(svc: SessionService, user_id: str):
    """revoke() returns True and records the user-terminated reason on the row."""
    session = await svc.create(_make_create(user_id))
    assert await svc.revoke(session.id) is True

    stored = await svc.get(session.id)
    assert stored.revoked is True
    assert stored.revoked_reason == REVOKE_REASON_USER_TERMINATED
|
|
|
|
|
|
async def test_revoke_returns_false_if_already_revoked(svc: SessionService, user_id: str):
    """A second revoke() on the same session reports False."""
    session = await svc.create(_make_create(user_id))

    first = await svc.revoke(session.id)
    assert first is True

    second = await svc.revoke(session.id)
    assert second is False
|
|
|
|
|
|
async def test_revoke_by_refresh_token(svc: SessionService, user_id: str):
    """revoke_by_refresh_token() is True only the first time, and False for unknown tokens."""
    await svc.create(_make_create(user_id, "rt-1"))

    # Applied in order: a live token, the same token again, a never-issued one.
    expectations = (("rt-1", True), ("rt-1", False), ("rt-never", False))
    for token, expected in expectations:
        assert await svc.revoke_by_refresh_token(token) is expected
|
|
|
|
|
|
async def test_revoke_all_for_user(svc: SessionService, user_id: str):
    """revoke_all_for_user() reports how many sessions it revoked and leaves none active."""
    for token in ("a", "b"):
        await svc.create(_make_create(user_id, token))

    revoked_count = await svc.revoke_all_for_user(
        user_id, reason=REVOKE_REASON_PASSWORD_CHANGED
    )
    assert revoked_count == 2
    assert await svc.list_for_user(user_id) == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list / get
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_list_for_user_excludes_revoked_by_default(svc: SessionService, user_id: str):
    """list_for_user() hides revoked sessions unless include_revoked=True."""
    revoked = await svc.create(_make_create(user_id, "a"))
    live = await svc.create(_make_create(user_id, "b"))
    await svc.revoke(revoked.id)

    default_ids = [s.id for s in await svc.list_for_user(user_id)]
    assert default_ids == [live.id]

    every_id = {s.id for s in await svc.list_for_user(user_id, include_revoked=True)}
    assert every_id == {revoked.id, live.id}
|
|
|
|
|
|
async def test_is_session_valid_rejects_revoked(svc: SessionService, user_id: str):
    """is_session_valid() changes from True to False once the session is revoked."""
    sid = (await svc.create(_make_create(user_id))).id

    assert await svc.is_session_valid(sid) is True
    await svc.revoke(sid)
    assert await svc.is_session_valid(sid) is False
|
|
|
|
|
|
async def test_is_session_valid_rejects_expired(svc: SessionService, user_id: str):
    """A session created with ttl_seconds=0 is reported invalid by is_session_valid()."""
    # A zero TTL is expected to make the session expire at creation time, so
    # it is already past expiry when checked below. No clock manipulation is
    # needed.
    create = SessionCreate(
        user_id=user_id,
        refresh_token="rt-x",
        device_fingerprint="fp",
        device_label="d",
        ip="",
        user_agent="",
        auth_provider="local",
        ttl_seconds=0,  # expires immediately
    )
    info = await svc.create(create)

    # Not revoked, only expired: validity must still be False.
    assert await svc.is_session_valid(info.id) is False
|
|
|
|
|
|
async def test_is_session_valid_returns_false_for_unknown(svc: SessionService):
    """An id that was never created is reported as invalid."""
    missing_id = "nonexistent-id"
    assert await svc.is_session_valid(missing_id) is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_active_by_provider (U2 fix: expired sessions must be filtered)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_list_active_by_provider_excludes_expired(svc: SessionService, user_id: str):
    """Expired sessions must NOT appear in list_active_by_provider (U2 fix).

    list_active_by_provider is documented as returning "non-revoked,
    non-expired" sessions. Before the U2 fix, its query filtered only on
    ``revoked = 0`` and never looked at ``expires_at``.
    """
    # ttl_seconds=0: this session is expired from the start.
    expired = await svc.create(
        SessionCreate(
            user_id=user_id,
            refresh_token="rt-expired",
            device_fingerprint="fp",
            device_label="expired-device",
            ip="",
            user_agent="",
            auth_provider="local",
            ttl_seconds=0,
        )
    )
    # A second session with _make_create's 3600-second TTL, which should still be live.
    live = await svc.create(_make_create(user_id, "rt-live"))

    ids = {s.id for s in await svc.list_active_by_provider("local")}
    assert live.id in ids
    assert expired.id not in ids
|
|
|
|
|
|
async def test_list_active_by_provider_excludes_revoked(svc: SessionService, user_id: str):
    """Regression guard: revoked sessions are also left out of list_active_by_provider."""
    gone = await svc.create(_make_create(user_id, "rt-a"))
    still_live = await svc.create(_make_create(user_id, "rt-b"))
    await svc.revoke(gone.id)

    ids = {s.id for s in await svc.list_active_by_provider("local")}
    assert still_live.id in ids
    assert gone.id not in ids
|
|
|
|
|
|
async def test_list_active_by_provider_filters_by_provider(
    svc: SessionService, user_id: str
):
    """Only sessions whose auth_provider matches the requested provider are returned."""
    # One live session per provider, created in order local then oidc. Each
    # session gets its own refresh token: "rt-local" and "rt-oidc".
    for provider in ("local", "oidc"):
        await svc.create(
            SessionCreate(
                user_id=user_id,
                refresh_token=f"rt-{provider}",
                device_fingerprint="fp",
                device_label="d",
                ip="",
                user_agent="",
                auth_provider=provider,
                ttl_seconds=3600,
            )
        )

    local_only = await svc.list_active_by_provider("local")
    assert all(s.auth_provider == "local" for s in local_only)
    assert len(local_only) == 1
|