257 lines
8.9 KiB
Python
257 lines
8.9 KiB
Python
"""Unit tests for SessionService (U3)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
|
|
from agentkit.server.auth.denylist import InMemoryRecentlyRevoked
|
|
from agentkit.server.auth.models import init_auth_db
|
|
from agentkit.server.auth.session_service import (
|
|
REVOKE_REASON_PASSWORD_CHANGED,
|
|
REVOKE_REASON_REUSE_DETECTED,
|
|
REVOKE_REASON_SESSION_CAP_EVICTION,
|
|
REVOKE_REASON_USER_TERMINATED,
|
|
SessionCreate,
|
|
SessionNotFound,
|
|
SessionReuseDetected,
|
|
SessionService,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
async def auth_db(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Initialise a fresh auth DB in ``tmp_path`` and point ``AGENTKIT_AUTH_DB`` at it.

    The environment variable is set through pytest's ``monkeypatch`` fixture,
    which restores the previous value (or removes the variable if it was unset)
    when the test finishes. This replaces a hand-rolled ``try/finally``
    save/restore of ``os.environ``.

    Returns:
        Path to the initialised SQLite database file.
    """
    db_path = tmp_path / "auth.db"
    await init_auth_db(db_path)
    # Set only after the DB exists, so nothing sees the variable pointing at a missing file.
    monkeypatch.setenv("AGENTKIT_AUTH_DB", str(db_path))
    return db_path
|
|
|
|
|
|
@pytest.fixture
async def user_id(auth_db: Path) -> str:
    """Seed the auth DB with one active ``member`` user ("alice") and return its UUID string."""
    new_id = str(uuid.uuid4())
    stamp = datetime.now(timezone.utc).isoformat()

    insert_sql = (
        "INSERT INTO users (id, username, email, password_hash, role, "
        "is_active, is_terminal_authorized, is_server_terminal_authorized, "
        "created_at, updated_at, last_login_at, created_by) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )
    row = (
        new_id,
        "alice",              # username
        "alice@example.com",  # email
        "x",                  # password_hash: placeholder, not used by these tests
        "member",             # role
        1,                    # is_active
        0,                    # is_terminal_authorized
        0,                    # is_server_terminal_authorized
        stamp,                # created_at
        stamp,                # updated_at
        None,                 # last_login_at
        None,                 # created_by
    )

    async with aiosqlite.connect(str(auth_db)) as conn:
        await conn.execute(insert_sql, row)
        await conn.commit()
    return new_id
|
|
|
|
|
|
@pytest.fixture
async def svc(auth_db: Path) -> SessionService:
    """Build a SessionService on ``auth_db`` with an in-memory denylist and a cap of 3 sessions.

    The in-memory denylist keeps the tests deterministic.
    """
    denylist = InMemoryRecentlyRevoked()
    service = SessionService(db_path=auth_db, denylist=denylist, session_cap=3)
    return service
|
|
|
|
|
|
def _make_create(user_id: str, refresh_token: str = "r1") -> SessionCreate:
    """Return a SessionCreate for ``user_id`` with fixed device metadata and ``ttl_seconds=3600``."""
    fields = dict(
        user_id=user_id,
        refresh_token=refresh_token,
        device_fingerprint="fp",
        device_label="Test device",
        ip="127.0.0.1",
        user_agent="pytest",
        auth_provider="local",
        ttl_seconds=3600,
    )
    return SessionCreate(**fields)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# create
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_create_inserts_row_and_returns_info(svc: SessionService, user_id: str):
    """create() returns a non-revoked session for the user that get() can read back by id."""
    created = await svc.create(_make_create(user_id))
    assert created.user_id == user_id
    assert created.revoked is False

    # The read API must hand back the row that was just written.
    roundtrip = await svc.get(created.id)
    assert roundtrip is not None
    assert roundtrip.id == created.id
|
|
|
|
|
|
async def test_create_evicts_oldest_when_at_cap(svc: SessionService, user_id: str):
    """With ``session_cap=3`` (svc fixture), a 4th create() revokes the oldest session."""
    sessions = []
    for token in ("rt-a", "rt-b", "rt-c", "rt-d"):
        sessions.append(await svc.create(_make_create(user_id, token)))
    oldest, *survivors = sessions

    assert (await svc.get(oldest.id)).revoked is True
    assert (await svc.get(oldest.id)).revoked_reason == REVOKE_REASON_SESSION_CAP_EVICTION
    for kept in survivors:
        assert (await svc.get(kept.id)).revoked is False

    # Exactly session_cap sessions remain live (cap is 3 in the svc fixture).
    still_active = [s for s in await svc.list_for_user(user_id) if not s.revoked]
    assert len(still_active) == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# rotate
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_rotate_replaces_refresh_hash(svc: SessionService, user_id: str):
    """rotate() keeps the session id but changes which refresh token resolves to it."""
    original = await svc.create(_make_create(user_id, "rt-old"))
    rotated = await svc.rotate(
        old_refresh_token="rt-old",
        new_refresh_token="rt-new",
        new_ttl_seconds=3600,
    )
    assert rotated.id == original.id

    # The old token no longer resolves; the new one maps to the same session.
    assert await svc.find_by_refresh_token("rt-old") is None
    assert (await svc.find_by_refresh_token("rt-new")).id == original.id
|
|
|
|
|
|
async def test_rotate_adds_old_to_denylist(svc: SessionService, user_id: str):
    """After rotate(), the old refresh token's SHA-256 hex digest is on the denylist."""
    import hashlib  # local import: only this test needs hashlib

    await svc.create(_make_create(user_id, "rt-1"))
    await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)

    # NOTE(review): this reaches into the private ``_denylist``. It assumes the service
    # keys the denylist by sha256(token).hexdigest(), as the original assertion did.
    old_hash = hashlib.sha256(b"rt-1").hexdigest()
    assert svc._denylist.contains(old_hash)
|
|
|
|
|
|
async def test_rotate_unknown_token_raises(svc: SessionService, user_id: str):
    """rotate() with a refresh token that was never issued raises SessionNotFound."""
    # ``user_id`` is not referenced below, but requesting it still seeds a user row.
    never_issued = "never-issued"
    with pytest.raises(SessionNotFound):
        await svc.rotate(never_issued, "rt-new", new_ttl_seconds=3600)
|
|
|
|
|
|
async def test_rotate_reuse_raises_and_revokes_all(svc: SessionService, user_id: str):
    """Replaying an already-rotated refresh token raises and revokes every session of the user.

    A second, unrelated session is created so that "revokes all" is actually
    exercised: rotation keeps the session id, so with a single session there is
    only one row to check. The list is also checked for its expected size, so
    the ``all(...)`` assertions cannot pass vacuously on an empty result.
    """
    await svc.create(_make_create(user_id, "rt-1"))
    await svc.create(_make_create(user_id, "rt-other"))
    await svc.rotate("rt-1", "rt-2", new_ttl_seconds=3600)

    # Second attempt with the already-rotated token: must raise and revoke all.
    with pytest.raises(SessionReuseDetected):
        await svc.rotate("rt-1", "rt-2b", new_ttl_seconds=3600)

    sessions = await svc.list_for_user(user_id, include_revoked=True)
    assert len(sessions) == 2
    assert all(s.revoked for s in sessions)
    assert all(s.revoked_reason == REVOKE_REASON_REUSE_DETECTED for s in sessions)
|
|
|
|
|
|
async def test_rotate_revoked_session_raises(svc: SessionService, user_id: str):
    """A refresh token whose session has been revoked cannot be rotated."""
    token = "rt-1"
    await svc.create(_make_create(user_id, token))
    await svc.revoke_by_refresh_token(token)

    with pytest.raises(SessionNotFound):
        await svc.rotate(token, "rt-2", new_ttl_seconds=3600)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# revoke
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_revoke_marks_session_revoked(svc: SessionService, user_id: str):
    """revoke(id) returns True and records the user-terminated reason on the session."""
    session = await svc.create(_make_create(user_id))
    assert await svc.revoke(session.id) is True
    assert (await svc.get(session.id)).revoked is True
    assert (await svc.get(session.id)).revoked_reason == REVOKE_REASON_USER_TERMINATED
|
|
|
|
|
|
async def test_revoke_returns_false_if_already_revoked(svc: SessionService, user_id: str):
    """revoke() is True on the first call and False once the session is already revoked."""
    sid = (await svc.create(_make_create(user_id))).id
    for expected in (True, False):
        assert await svc.revoke(sid) is expected
|
|
|
|
|
|
async def test_revoke_by_refresh_token(svc: SessionService, user_id: str):
    """revoke_by_refresh_token() is True once for a live token, then False; unknown tokens give False."""
    await svc.create(_make_create(user_id, "rt-1"))
    cases = (
        ("rt-1", True),       # first revocation succeeds
        ("rt-1", False),      # already revoked
        ("rt-never", False),  # never issued
    )
    for token, expected in cases:
        assert await svc.revoke_by_refresh_token(token) is expected
|
|
|
|
|
|
async def test_revoke_all_for_user(svc: SessionService, user_id: str):
    """revoke_all_for_user() returns 2 for two live sessions and leaves none listed by default."""
    for token in ("a", "b"):
        await svc.create(_make_create(user_id, token))

    revoked_count = await svc.revoke_all_for_user(
        user_id, reason=REVOKE_REASON_PASSWORD_CHANGED
    )
    assert revoked_count == 2
    assert await svc.list_for_user(user_id) == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list / get
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def test_list_for_user_excludes_revoked_by_default(svc: SessionService, user_id: str):
    """list_for_user() hides revoked sessions unless include_revoked=True is passed."""
    revoked = await svc.create(_make_create(user_id, "a"))
    live = await svc.create(_make_create(user_id, "b"))
    await svc.revoke(revoked.id)

    default_ids = [s.id for s in await svc.list_for_user(user_id)]
    assert default_ids == [live.id]

    every_id = {s.id for s in await svc.list_for_user(user_id, include_revoked=True)}
    assert every_id == {revoked.id, live.id}
|
|
|
|
|
|
async def test_is_session_valid_rejects_revoked(svc: SessionService, user_id: str):
    """is_session_valid() flips from True to False once the session is revoked."""
    sid = (await svc.create(_make_create(user_id))).id
    assert await svc.is_session_valid(sid) is True
    await svc.revoke(sid)
    assert await svc.is_session_valid(sid) is False
|
|
|
|
|
|
async def test_is_session_valid_rejects_expired(svc: SessionService, user_id: str):
    """A session created with a zero TTL is reported as invalid (expired).

    With ``ttl_seconds=0`` the session should expire at essentially its creation
    time. The validity check may therefore land exactly on the expiry boundary
    when the clock is coarse. A short sleep moves the check strictly past expiry,
    so the result does not depend on how the boundary instant is treated.
    """
    create = SessionCreate(
        user_id=user_id,
        refresh_token="rt-x",
        device_fingerprint="fp",
        device_label="d",
        ip="",
        user_agent="",
        auth_provider="local",
        ttl_seconds=0,  # expires immediately
    )
    info = await svc.create(create)
    # Step past the expiry instant before checking (expires_at is not in the future).
    await asyncio.sleep(0.05)
    assert await svc.is_session_valid(info.id) is False
|
|
|
|
|
|
async def test_is_session_valid_returns_false_for_unknown(svc: SessionService):
    """is_session_valid() returns False for a session id that was never created."""
    missing_id = "nonexistent-id"
    assert await svc.is_session_valid(missing_id) is False
|