255 lines
9.2 KiB
Python
255 lines
9.2 KiB
Python
"""Tests for LocalAuthProvider (U11 — concrete Local implementation)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
|
|
from agentkit.server.auth.models import init_auth_db
|
|
from agentkit.server.auth.password import hash_password
|
|
from agentkit.server.auth.providers import LocalAuthProvider
|
|
from agentkit.server.auth.providers.exceptions import InvalidCredentials
|
|
from agentkit.server.auth.providers.user import User
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
async def auth_db_with_users(tmp_path: Path) -> dict:
    """Build a fresh auth DB seeded with one active and one inactive user.

    Both users are inserted with role ``member``, no terminal authorizations,
    and the same UTC ISO-8601 timestamp for ``created_at`` / ``updated_at``.

    Returns:
        A dict with ``db_path`` (the SQLite file), ``active`` (id, username,
        plaintext password, email, role) and ``inactive`` (id, username,
        plaintext password).
    """
    db_path = tmp_path / "auth.db"
    await init_auth_db(db_path)

    stamp = datetime.now(timezone.utc).isoformat()

    alice = {
        "id": str(uuid.uuid4()),
        "username": "alice",
        "password": "correct-horse-battery-staple",
        "email": "alice@example.com",
        "role": "member",
    }
    bob = {
        "id": str(uuid.uuid4()),
        "username": "bob_inactive",
        "password": "disabled-user-pw",
    }

    # Column order: id, username, email, password_hash, role, is_active,
    # is_terminal_authorized, is_server_terminal_authorized,
    # created_at, updated_at.
    seed_rows = [
        (
            alice["id"],
            alice["username"],
            alice["email"],
            hash_password(alice["password"]),
            alice["role"],
            1,
            0,
            0,
            stamp,
            stamp,
        ),
        (
            bob["id"],
            bob["username"],
            "bob@example.com",
            hash_password(bob["password"]),
            "member",
            0,  # is_active = 0: this user must be rejected by the provider
            0,
            0,
            stamp,
            stamp,
        ),
    ]

    async with aiosqlite.connect(str(db_path)) as conn:
        await conn.executemany(
            "INSERT INTO users "
            "(id, username, email, password_hash, role, is_active, "
            " is_terminal_authorized, is_server_terminal_authorized, "
            " created_at, updated_at) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            seed_rows,
        )
        await conn.commit()

    return {"db_path": db_path, "active": alice, "inactive": bob}
|
|
|
|
|
|
@pytest.fixture
def provider(auth_db_with_users: dict) -> LocalAuthProvider:
    """Return a ``LocalAuthProvider`` pointed at the seeded test database."""
    seeded_db = auth_db_with_users["db_path"]
    return LocalAuthProvider(db_path=seeded_db)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# authenticate
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAuthenticate:
    """``authenticate`` returns a ``User`` or raises ``InvalidCredentials``."""

    async def test_valid_credentials_returns_user(
        self, provider: LocalAuthProvider, auth_db_with_users: dict
    ):
        user = await provider.authenticate(
            username=auth_db_with_users["active"]["username"],
            password=auth_db_with_users["active"]["password"],
        )
        assert isinstance(user, User)
        assert user.id == auth_db_with_users["active"]["id"]
        assert user.username == "alice"
        assert user.email == "alice@example.com"
        assert user.role == "member"
        assert user.is_active is True

    async def test_wrong_password_raises_invalid_credentials(
        self, provider: LocalAuthProvider, auth_db_with_users: dict
    ):
        with pytest.raises(InvalidCredentials):
            await provider.authenticate(
                username=auth_db_with_users["active"]["username"],
                password="definitely-not-the-password",
            )

    async def test_unknown_user_raises_invalid_credentials(self, provider: LocalAuthProvider):
        with pytest.raises(InvalidCredentials):
            await provider.authenticate(username="nonexistent", password="anything")

    async def test_inactive_user_raises_invalid_credentials(
        self, provider: LocalAuthProvider, auth_db_with_users: dict
    ):
        with pytest.raises(InvalidCredentials):
            await provider.authenticate(
                username=auth_db_with_users["inactive"]["username"],
                password=auth_db_with_users["inactive"]["password"],
            )

    async def test_error_message_does_not_leak_username_existence(
        self, provider: LocalAuthProvider, auth_db_with_users: dict
    ):
        """Wrong-password and unknown-user errors must have the same message.

        ``pytest.raises`` is used instead of a bare ``try``/``except`` so that
        an ``authenticate`` call that unexpectedly succeeds fails the test
        with "DID NOT RAISE" rather than a ``NameError`` on an unbound
        message variable.
        """
        with pytest.raises(InvalidCredentials) as wrong_pw:
            await provider.authenticate(
                username=auth_db_with_users["active"]["username"],
                password="wrong",
            )
        with pytest.raises(InvalidCredentials) as unknown_user:
            await provider.authenticate(username="nobody-here", password="x")
        assert str(wrong_pw.value) == str(unknown_user.value)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_user_by_id
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetUserById:
    """``get_user_by_id`` yields a ``User`` only for active, existing ids."""

    async def test_returns_user_for_active_id(
        self, provider: LocalAuthProvider, auth_db_with_users: dict
    ):
        wanted_id = auth_db_with_users["active"]["id"]
        found = await provider.get_user_by_id(wanted_id)
        assert found is not None
        assert found.id == wanted_id
        assert found.username == "alice"

    async def test_returns_none_for_inactive_user(
        self, provider: LocalAuthProvider, auth_db_with_users: dict
    ):
        disabled_id = auth_db_with_users["inactive"]["id"]
        assert await provider.get_user_by_id(disabled_id) is None

    async def test_returns_none_for_unknown_id(self, provider: LocalAuthProvider):
        # A freshly generated UUID cannot match any seeded user.
        assert await provider.get_user_by_id(str(uuid.uuid4())) is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# sync_user_attributes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSyncUserAttributes:
    """The local provider has no upstream source, so syncing is a no-op."""

    async def test_is_noop(self, provider: LocalAuthProvider, auth_db_with_users: dict):
        """Syncing a known user completes and returns ``None``."""
        known_id = auth_db_with_users["active"]["id"]
        assert await provider.sync_user_attributes(known_id) is None

    async def test_noop_does_not_throw_for_unknown_user(self, provider: LocalAuthProvider):
        """Syncing an id with no matching user must not raise."""
        missing_id = str(uuid.uuid4())
        # No existence check is expected: this call simply has to return.
        await provider.sync_user_attributes(missing_id)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# revoke_user
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRevokeUser:
    """``revoke_user`` deactivates a user so they can no longer log in."""

    async def test_sets_is_active_to_zero(
        self, provider: LocalAuthProvider, auth_db_with_users: dict
    ):
        user_id = auth_db_with_users["active"]["id"]
        await provider.revoke_user(user_id)

        # Read the column straight from SQLite so the assertion checks the
        # stored state rather than the provider's own read path.
        async with aiosqlite.connect(str(auth_db_with_users["db_path"])) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(
                "SELECT is_active FROM users WHERE id = ?",
                (user_id,),
            )
            row = await cursor.fetchone()
            # Fail with a clear message (not a TypeError on ``None[...]``)
            # if the row was deleted instead of deactivated.
            assert row is not None, "revoke_user must deactivate, not delete, the user row"
            assert bool(row["is_active"]) is False

    async def test_revoked_user_can_no_longer_authenticate(
        self, provider: LocalAuthProvider, auth_db_with_users: dict
    ):
        await provider.revoke_user(auth_db_with_users["active"]["id"])
        with pytest.raises(InvalidCredentials):
            await provider.authenticate(
                username=auth_db_with_users["active"]["username"],
                password=auth_db_with_users["active"]["password"],
            )

    async def test_revoke_unknown_user_does_not_raise(self, provider: LocalAuthProvider):
        """``UPDATE ... WHERE id = ?`` with no match is a no-op, not an error."""
        await provider.revoke_user(str(uuid.uuid4()))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# default db_path
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDefaultDbPath:
    """Behaviour of ``LocalAuthProvider`` when no ``db_path`` is passed."""

    def test_default_db_path_uses_models_default(
        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
    ):
        """Without ``db_path``, the provider resolves the default DB location.

        The default is driven here through the ``AGENTKIT_AUTH_DB`` environment
        variable. The test asserts that ``db_path`` is a ``Path`` equal to that
        value. It does not create or check the file.

        The variable is set after the provider module has been imported, so
        this test requires the default to be resolved at construction (or
        access) time rather than frozen at import time.
        """
        expected = tmp_path / "default.db"
        monkeypatch.setenv("AGENTKIT_AUTH_DB", str(expected))

        provider = LocalAuthProvider()

        assert isinstance(provider.db_path, Path)
        assert provider.db_path == expected
|