"""Unit tests for the JWT auth module (U2). Covers: - password hashing / verification - JWT token pair creation / verification (success, expired, invalid) - AuthMiddleware (whitelist, JWT, API key, dev mode, 401) - /auth/login (correct + wrong password) - /auth/refresh (valid refresh token) - /auth/me (returns user info) """ from __future__ import annotations import uuid from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any import aiosqlite import jwt import pytest from fastapi import FastAPI, Request from fastapi.testclient import TestClient from agentkit.server.auth.denylist import InMemoryRecentlyRevoked from agentkit.server.auth.jwt_utils import ( create_token_pair, verify_token, ) from agentkit.server.auth.middleware import AuthMiddleware from agentkit.server.auth.models import init_auth_db from agentkit.server.auth.password import hash_password, verify_password from agentkit.server.auth.session_service import SessionService, set_session_service from agentkit.server.routes import auth as auth_routes # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture def jwt_secret() -> str: """A fixed JWT secret for deterministic tests.""" return "test-jwt-secret-for-unit-tests-do-not-use-in-prod" @pytest.fixture async def tmp_auth_db(tmp_path: Path) -> Path: """Create a fresh auth DB in a temp directory and return its path.""" db_path = tmp_path / "auth.db" await init_auth_db(db_path) return db_path @pytest.fixture async def auth_db_with_user(tmp_auth_db: Path) -> dict[str, Any]: """Insert a test user into the auth DB and return user fields + plaintext password.""" user_id = str(uuid.uuid4()) username = "testuser" email = "testuser@example.com" password = "correct-horse-battery-staple" password_hash = hash_password(password) now_iso = datetime.now(timezone.utc).isoformat() async with 
aiosqlite.connect(str(tmp_auth_db)) as db: await db.execute( "INSERT INTO users " "(id, username, email, password_hash, role, is_active, " " is_terminal_authorized, is_server_terminal_authorized, " " created_at, updated_at) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", ( user_id, username, email, password_hash, "member", 1, 0, 0, now_iso, now_iso, ), ) await db.commit() return { "id": user_id, "username": username, "email": email, "password": password, "password_hash": password_hash, "role": "member", "db_path": tmp_auth_db, } @pytest.fixture def auth_app(jwt_secret: str, auth_db_with_user: dict[str, Any]) -> FastAPI: """A FastAPI app with auth routes + AuthMiddleware, wired to the test DB + secret.""" app = FastAPI() app.state.jwt_secret = jwt_secret app.state.auth_db_path = str(auth_db_with_user["db_path"]) app.add_middleware(AuthMiddleware, jwt_secret=jwt_secret) app.include_router(auth_routes.router, prefix="/api/v1") return app @pytest.fixture def auth_client(auth_app: FastAPI) -> TestClient: """TestClient for the auth-only app (no auth middleware). Also overrides the global :class:`SessionService` singleton with a per-test instance bound to the test's auth DB path. This ensures the login/refresh routes write to the test database instead of the project-default auth DB. """ # Bind a fresh SessionService to the test DB and inject it. test_svc = SessionService( db_path=auth_app.state.auth_db_path, denylist=InMemoryRecentlyRevoked(), ) set_session_service(test_svc) try: yield TestClient(auth_app) finally: # Reset the singleton so the next test gets a fresh one. 
set_session_service(None) # --------------------------------------------------------------------------- # Password tests # --------------------------------------------------------------------------- class TestPassword: """bcrypt hash_password / verify_password.""" def test_hash_and_verify_correct_password(self): """hash_password then verify_password with the same password → True.""" password = "my-secret-password-123" hashed = hash_password(password) assert hashed != password assert hashed.startswith("$2b$12$") assert verify_password(password, hashed) is True def test_verify_wrong_password_returns_false(self): """verify_password with a different password → False.""" hashed = hash_password("correct-password") assert verify_password("wrong-password", hashed) is False def test_hash_is_salt_randomized(self): """Same password hashed twice → different hashes (salt randomization).""" password = "same-password" h1 = hash_password(password) h2 = hash_password(password) assert h1 != h2 # Both should still verify against the original password assert verify_password(password, h1) is True assert verify_password(password, h2) is True def test_verify_malformed_hash_returns_false(self): """verify_password with a malformed hash → False (no exception).""" assert verify_password("any-password", "not-a-valid-bcrypt-hash") is False # --------------------------------------------------------------------------- # JWT tests # --------------------------------------------------------------------------- class TestJWT: """create_token_pair / verify_token.""" def test_create_token_pair_returns_valid_tokens(self, jwt_secret: str): """create_token_pair returns two non-empty JWT strings.""" pair = create_token_pair( user_id="user-123", username="alice", role="member", secret=jwt_secret, ) assert pair.access_token assert pair.refresh_token assert pair.access_token != pair.refresh_token assert pair.access_expires_at > datetime.now(timezone.utc) assert pair.refresh_expires_at > pair.access_expires_at 
def test_verify_access_token_succeeds(self, jwt_secret: str): """verify_token on a fresh access token returns the payload.""" pair = create_token_pair( user_id="user-123", username="alice", role="member", secret=jwt_secret, ) payload = verify_token(pair.access_token, jwt_secret) assert payload["sub"] == "user-123" assert payload["username"] == "alice" assert payload["role"] == "member" assert payload["type"] == "access" def test_verify_refresh_token_succeeds(self, jwt_secret: str): """verify_token on a refresh token returns the payload with type=refresh.""" pair = create_token_pair( user_id="user-123", username="alice", role="member", secret=jwt_secret, ) payload = verify_token(pair.refresh_token, jwt_secret) assert payload["type"] == "refresh" def test_verify_expired_token_raises(self, jwt_secret: str): """verify_token on an expired token raises jwt.ExpiredSignatureError.""" past = datetime.now(timezone.utc) - timedelta(hours=1) pair = create_token_pair( user_id="user-123", username="alice", role="member", secret=jwt_secret, now=past, ) with pytest.raises(jwt.ExpiredSignatureError): verify_token(pair.access_token, jwt_secret) def test_verify_invalid_token_raises(self, jwt_secret: str): """verify_token on a malformed token raises jwt.InvalidTokenError.""" with pytest.raises(jwt.InvalidTokenError): verify_token("not.a.valid.jwt", jwt_secret) def test_verify_wrong_secret_raises(self, jwt_secret: str): """verify_token with the wrong secret raises jwt.InvalidSignatureError.""" pair = create_token_pair( user_id="user-123", username="alice", role="member", secret=jwt_secret, ) with pytest.raises(jwt.InvalidTokenError): verify_token(pair.access_token, "a-different-secret") def test_empty_secret_raises(self): """create_token_pair with an empty secret raises ValueError.""" with pytest.raises(ValueError): create_token_pair( user_id="user-123", username="alice", role="member", secret="", ) # --------------------------------------------------------------------------- # 
# ---------------------------------------------------------------------------
# AuthMiddleware tests
# ---------------------------------------------------------------------------


def _make_minimal_app(*, include_ws_echo: bool = False) -> FastAPI:
    """Minimal FastAPI app with public and protected endpoints.

    Routes:

    - ``GET /api/v1/health`` -- health check (whitelisted by the middleware).
    - ``GET /api/v1/auth/login`` -- stand-in login page.
    - ``GET /api/v1/protected`` -- echoes ``request.state.current_user``.
    - ``GET /api/v1/ws/echo`` -- only when *include_ws_echo* is true. It echoes
      ``request.state.current_user`` on a ``/ws`` path, so ``?token=`` query
      handling can be exercised.

    No middleware is attached; each test adds ``AuthMiddleware`` itself.
    """
    app = FastAPI()

    @app.get("/api/v1/health")
    async def health():
        return {"status": "ok"}

    @app.get("/api/v1/auth/login")
    async def login_page():
        return {"page": "login"}

    @app.get("/api/v1/protected")
    async def protected(request: Request):
        user = getattr(request.state, "current_user", None)
        return {"user": user}

    if include_ws_echo:

        @app.get("/api/v1/ws/echo")
        async def ws_echo(request: Request):
            user = getattr(request.state, "current_user", None)
            return {"user": user}

    return app


def _alice_tokens(secret: str):
    """Token pair for a fixed user (id ``user-1``, ``alice``, ``member``) signed with *secret*."""
    return create_token_pair(
        user_id="user-1",
        username="alice",
        role="member",
        secret=secret,
    )


class TestAuthMiddleware:
    """AuthMiddleware dispatch logic."""

    def test_whitelist_path_passes_without_auth(self):
        """Whitelisted paths (e.g. /api/v1/health) pass without any auth."""
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware, jwt_secret="s", api_key="k")
        client = TestClient(app)
        resp = client.get("/api/v1/health")
        assert resp.status_code == 200
        assert resp.json() == {"status": "ok"}

    def test_login_path_whitelisted(self):
        """/api/v1/auth/login is whitelisted so users can log in."""
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware, jwt_secret="s", api_key="k")
        client = TestClient(app)
        resp = client.get("/api/v1/auth/login")
        assert resp.status_code == 200

    def test_jwt_auth_success_sets_current_user(self, jwt_secret: str):
        """Valid JWT → 200 and request.state.current_user is populated."""
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware, jwt_secret=jwt_secret)
        client = TestClient(app)
        pair = _alice_tokens(jwt_secret)
        resp = client.get(
            "/api/v1/protected",
            headers={"Authorization": f"Bearer {pair.access_token}"},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["user"]["user_id"] == "user-1"
        assert body["user"]["username"] == "alice"
        assert body["user"]["role"] == "member"

    def test_jwt_with_wrong_secret_returns_401(self, jwt_secret: str):
        """JWT signed with a different secret → 401."""
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware, jwt_secret=jwt_secret)
        client = TestClient(app)
        pair = _alice_tokens("a-different-secret")
        resp = client.get(
            "/api/v1/protected",
            headers={"Authorization": f"Bearer {pair.access_token}"},
        )
        assert resp.status_code == 401

    def test_refresh_token_rejected_for_request_auth(self, jwt_secret: str):
        """A refresh token (type=refresh) must not authenticate requests."""
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware, jwt_secret=jwt_secret, api_key="fallback-key")
        client = TestClient(app)
        pair = _alice_tokens(jwt_secret)
        resp = client.get(
            "/api/v1/protected",
            headers={"Authorization": f"Bearer {pair.refresh_token}"},
        )
        # Should fall through to API key check (no X-API-Key) → 401
        assert resp.status_code == 401

    def test_api_key_auth_success(self):
        """Valid X-API-Key → 200."""
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware, api_key="my-global-key")
        client = TestClient(app)
        resp = client.get(
            "/api/v1/protected",
            headers={"X-API-Key": "my-global-key"},
        )
        assert resp.status_code == 200

    def test_api_key_client_keys_success(self):
        """API key matching a client_keys entry → 200."""
        app = _make_minimal_app()
        app.add_middleware(
            AuthMiddleware,
            client_keys={"client-a": "key-for-a"},
        )
        client = TestClient(app)
        resp = client.get(
            "/api/v1/protected",
            headers={"X-API-Key": "key-for-a"},
        )
        assert resp.status_code == 200

    def test_api_key_wrong_returns_401(self):
        """Wrong X-API-Key → 401."""
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware, api_key="my-global-key")
        client = TestClient(app)
        resp = client.get(
            "/api/v1/protected",
            headers={"X-API-Key": "wrong-key"},
        )
        assert resp.status_code == 401

    def test_dev_mode_passes_through(self):
        """No JWT secret, no API key, no client keys → dev mode, all pass."""
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware)
        client = TestClient(app)
        resp = client.get("/api/v1/protected")
        assert resp.status_code == 200

    def test_no_auth_returns_401(self):
        """Auth configured but no credentials provided → 401."""
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware, jwt_secret="s", api_key="k")
        client = TestClient(app)
        resp = client.get("/api/v1/protected")
        assert resp.status_code == 401
        body = resp.json()
        assert body["error"] == "Unauthorized"

    def test_jwt_via_query_param_on_ws_path(self, jwt_secret: str):
        """WebSocket paths accept ?token= as a fallback.

        This serves clients that cannot set the Authorization header
        (e.g. the browser WebSocket API).
        """
        app = _make_minimal_app(include_ws_echo=True)
        app.add_middleware(AuthMiddleware, jwt_secret=jwt_secret)
        client = TestClient(app)
        pair = _alice_tokens(jwt_secret)
        resp = client.get("/api/v1/ws/echo", params={"token": pair.access_token})
        assert resp.status_code == 200
        body = resp.json()
        assert body["user"]["user_id"] == "user-1"
        assert body["user"]["username"] == "alice"

    def test_jwt_via_query_param_rejected_on_non_ws_path(self, jwt_secret: str):
        """The ?token= query parameter is only honored on /ws paths.

        Using it on a regular REST endpoint should not authenticate.
        """
        app = _make_minimal_app()
        app.add_middleware(AuthMiddleware, jwt_secret=jwt_secret, api_key="fallback")
        client = TestClient(app)
        pair = _alice_tokens(jwt_secret)
        # /api/v1/protected is NOT a /ws path → token query param ignored
        resp = client.get("/api/v1/protected", params={"token": pair.access_token})
        # No Authorization header, no X-API-Key → 401 (query param not honored)
        assert resp.status_code == 401

    def test_jwt_via_query_param_invalid_token_returns_401(self, jwt_secret: str):
        """An invalid ?token= value on a /ws path → 401 (falls through)."""
        app = _make_minimal_app(include_ws_echo=True)
        app.add_middleware(AuthMiddleware, jwt_secret=jwt_secret, api_key="fallback")
        client = TestClient(app)
        resp = client.get("/api/v1/ws/echo", params={"token": "not.a.valid.jwt"})
        assert resp.status_code == 401


# ---------------------------------------------------------------------------
# Auth routes tests
# ---------------------------------------------------------------------------


def _login(client: TestClient, user: dict[str, Any]) -> dict[str, Any]:
    """Log *user* in via POST /api/v1/auth/login and return the JSON body.

    Asserts the login succeeded (200), and includes the response text on
    failure. Without this, a broken login surfaces later as an opaque
    ``KeyError`` on ``access_token`` / ``refresh_token``.
    """
    resp = client.post(
        "/api/v1/auth/login",
        json={"username": user["username"], "password": user["password"]},
    )
    assert resp.status_code == 200, resp.text
    return resp.json()


class TestLoginRoute:
    """POST /api/v1/auth/login."""

    def test_login_correct_password_returns_token(
        self,
        auth_client: TestClient,
        auth_db_with_user: dict[str, Any],
    ):
        """Login with correct password → 200 + TokenResponse."""
        resp = auth_client.post(
            "/api/v1/auth/login",
            json={
                "username": auth_db_with_user["username"],
                "password": auth_db_with_user["password"],
            },
        )
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert body["access_token"]
        assert body["refresh_token"]
        assert body["token_type"] == "bearer"
        # Access-token lifetime in seconds.
        assert body["expires_in"] == 900
        assert body["user"]["username"] == auth_db_with_user["username"]
        assert body["user"]["email"] == auth_db_with_user["email"]
        assert body["user"]["role"] == "member"
        assert body["user"]["is_active"] is True

    def test_login_wrong_password_returns_401(
        self,
        auth_client: TestClient,
        auth_db_with_user: dict[str, Any],
    ):
        """Login with wrong password → 401."""
        resp = auth_client.post(
            "/api/v1/auth/login",
            json={
                "username": auth_db_with_user["username"],
                "password": "totally-wrong-password",
            },
        )
        assert resp.status_code == 401
        assert "invalid username or password" in resp.json()["detail"].lower()

    def test_login_unknown_user_returns_401(
        self,
        auth_client: TestClient,
    ):
        """Login with unknown username → 401."""
        resp = auth_client.post(
            "/api/v1/auth/login",
            json={"username": "ghost", "password": "anything"},
        )
        assert resp.status_code == 401


class TestRefreshRoute:
    """POST /api/v1/auth/refresh."""

    def test_refresh_valid_token_returns_new_access_token(
        self,
        auth_client: TestClient,
        auth_db_with_user: dict[str, Any],
    ):
        """Refresh with a valid refresh token → 200 + new TokenResponse."""
        refresh_token = _login(auth_client, auth_db_with_user)["refresh_token"]

        resp = auth_client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert body["access_token"]
        assert body["refresh_token"]
        assert body["user"]["username"] == auth_db_with_user["username"]

    def test_refresh_invalid_token_returns_401(
        self,
        auth_client: TestClient,
    ):
        """Refresh with an invalid token → 401."""
        resp = auth_client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": "not.a.valid.jwt"},
        )
        assert resp.status_code == 401

    def test_refresh_revoked_token_returns_401(
        self,
        auth_client: TestClient,
        auth_db_with_user: dict[str, Any],
    ):
        """Refresh with a token revoked via /auth/logout → 401."""
        refresh_token = _login(auth_client, auth_db_with_user)["refresh_token"]

        # Logout (revokes the refresh token).
        logout_resp = auth_client.post(
            "/api/v1/auth/logout",
            json={"refresh_token": refresh_token},
        )
        assert logout_resp.status_code == 200
        assert logout_resp.json()["revoked"] is True

        # Refresh should now fail.
        resp = auth_client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert resp.status_code == 401


class TestMeRoute:
    """GET /api/v1/auth/me."""

    def test_me_returns_user_info(
        self,
        auth_client: TestClient,
        auth_db_with_user: dict[str, Any],
    ):
        """Authenticated /me → 200 + UserResponse."""
        access_token = _login(auth_client, auth_db_with_user)["access_token"]

        resp = auth_client.get(
            "/api/v1/auth/me",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert body["id"] == auth_db_with_user["id"]
        assert body["username"] == auth_db_with_user["username"]
        assert body["email"] == auth_db_with_user["email"]
        assert body["role"] == "member"
        assert body["is_active"] is True

    def test_me_without_auth_returns_401(
        self,
        auth_client: TestClient,
    ):
        """/me without authentication → 401."""
        resp = auth_client.get("/api/v1/auth/me")
        assert resp.status_code == 401


# ---------------------------------------------------------------------------
# init_auth_db test
# ---------------------------------------------------------------------------


class TestInitAuthDb:
    """init_auth_db creates the expected tables."""

    async def test_init_creates_tables(self, tmp_path: Path):
        """init_auth_db creates users, user_api_keys, user_sessions tables."""
        db_path = tmp_path / "auth.db"
        await init_auth_db(db_path)
        assert db_path.exists()

        # Verify tables exist by querying sqlite_master; the cursor is closed
        # by the async context manager.
        async with aiosqlite.connect(str(db_path)) as db:
            async with db.execute(
                "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
            ) as cursor:
                rows = await cursor.fetchall()
        names = {row[0] for row in rows}
        missing = {"users", "user_api_keys", "user_sessions"} - names
        assert not missing, f"missing tables: {sorted(missing)}; found: {sorted(names)}"

    async def test_init_is_idempotent(self, tmp_path: Path):
        """init_auth_db can be called twice without error."""
        db_path = tmp_path / "auth.db"
        await init_auth_db(db_path)
        await init_auth_db(db_path)
        assert db_path.exists()