diff --git a/src/agentkit/server/auth/session_service.py b/src/agentkit/server/auth/session_service.py index 1ea2dd9..83046bb 100644 --- a/src/agentkit/server/auth/session_service.py +++ b/src/agentkit/server/auth/session_service.py @@ -233,6 +233,25 @@ class SessionService: row = await cursor.fetchone() return _row_to_info(row) if row else None + async def get_stored_refresh_hash(self, session_id: str) -> str | None: + """Return the stored ``refresh_token_hash`` for ``session_id``. + + Used by callers (e.g. ``/auth/whoami``) that need to verify a + presented refresh token against the session's current hash without + going through :meth:`rotate`. Returns ``None`` if the session + does not exist. + """ + async with aiosqlite.connect(str(self._db_path)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + "SELECT refresh_token_hash FROM auth_sessions WHERE id = ?", + (session_id,), + ) + row = await cursor.fetchone() + if row is None: + return None + return row["refresh_token_hash"] + # ------------------------------------------------------------------ # Write # ------------------------------------------------------------------ diff --git a/src/agentkit/server/routes/auth.py b/src/agentkit/server/routes/auth.py index d569478..04d0727 100644 --- a/src/agentkit/server/routes/auth.py +++ b/src/agentkit/server/routes/auth.py @@ -21,6 +21,7 @@ The auth DB (SQLite via aiosqlite) and JWT secret are resolved from from __future__ import annotations +import hmac import logging from datetime import datetime, timezone from pathlib import Path @@ -586,6 +587,22 @@ async def whoami(request: Request) -> WhoamiResponse: info = await svc.get(sid) if info is None or info.revoked: raise HTTPException(status_code=401, detail="session revoked or expired") + # Cold-start defense (R9): when the presented token is a refresh + # token, verify its hash matches the session's current + # ``refresh_token_hash``. If the token has been rotated (via + # ``/auth/refresh``) or revoked, the stored hash will differ and + # we reject the request. Comparison uses ``hmac.compare_digest`` + # for constant-time equality to prevent timing attacks. + if token_type == "refresh": + stored_hash = await svc.get_stored_refresh_hash(sid) + if stored_hash is None: + raise HTTPException(status_code=401, detail="session not found") + presented_hash = hash_token(token) + if not hmac.compare_digest(presented_hash, stored_hash): + raise HTTPException( + status_code=401, + detail="refresh token has been rotated or revoked", + ) session_response = SessionResponse( **auth_session_row_to_dict(_info_to_dict(info)), is_current=True, diff --git a/tests/integration/auth/test_auth_routes.py b/tests/integration/auth/test_auth_routes.py index fdea777..e255ac1 100644 --- a/tests/integration/auth/test_auth_routes.py +++ b/tests/integration/auth/test_auth_routes.py @@ -27,6 +27,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient from agentkit.server.auth.denylist import InMemoryRecentlyRevoked +from agentkit.server.auth.jwt_utils import verify_token from agentkit.server.auth.middleware import AuthMiddleware from agentkit.server.auth.models import init_auth_db from agentkit.server.auth.password import hash_password @@ -243,6 +244,220 @@ class TestWhoamiColdStart: assert "refresh_token" not in data +class TestWhoamiTokenHash: + """GET /auth/whoami refresh-token hash verification (U7 — R9). + + After refresh-token rotation via ``/auth/refresh``, the old refresh + token must NOT be usable for cold-start on ``/auth/whoami``. The + route verifies the presented token's SHA-256 hash against the + session's stored ``refresh_token_hash`` and rejects mismatches with + 401 (constant-time comparison via ``hmac.compare_digest``). + + Note: ``create_token_pair`` does not add ``jti`` to refresh tokens, + so login + refresh within the same second produce identical refresh + tokens. To test the hash-mismatch path deterministically we update + ``refresh_token_hash`` directly in the DB (simulating rotation to a + different token). + """ + + def test_whoami_with_valid_refresh_token_returns_200( + self, + auth_client: TestClient, + auth_db_with_user: dict[str, Any], + ): + """A valid (non-rotated) refresh token passes the hash check.""" + body = _login( + auth_client, + auth_db_with_user["username"], + auth_db_with_user["password"], + ) + resp = auth_client.get( + "/api/v1/auth/whoami", + headers={"Authorization": f"Bearer {body['refresh_token']}"}, + ) + assert resp.status_code == 200, resp.text + data = resp.json() + assert data["access_token"] is not None + assert data["user"]["username"] == auth_db_with_user["username"] + + async def test_whoami_with_rotated_refresh_token_returns_401( + self, + auth_client: TestClient, + auth_db_with_user: dict[str, Any], + tmp_auth_db: Path, + ): + """After rotation, the old refresh token is rejected. + + We simulate rotation by overwriting ``refresh_token_hash`` in + the DB with a different value (the hash of a hypothetical new + token). The old token's hash no longer matches and whoami + returns 401. + """ + body = _login( + auth_client, + auth_db_with_user["username"], + auth_db_with_user["password"], + ) + old_refresh = body["refresh_token"] + + # Decode the old token to get the session id. + old_payload = verify_token( + old_refresh, auth_client.app.state.jwt_secret, expected_type="refresh" + ) + sid = old_payload.get("sid") + assert sid is not None, "refresh token must carry sid" + + # Simulate rotation: replace the stored hash with a different one. + from agentkit.server.auth.denylist import hash_token as _hash_token + + new_hash = _hash_token("rotated-new-token-different-from-old") + async with aiosqlite.connect(str(tmp_auth_db)) as db: + await db.execute( + "UPDATE auth_sessions SET refresh_token_hash = ? WHERE id = ?", + (new_hash, sid), + ) + await db.commit() + + # Old refresh token must now be rejected by whoami. + resp = auth_client.get( + "/api/v1/auth/whoami", + headers={"Authorization": f"Bearer {old_refresh}"}, + ) + assert resp.status_code == 401 + detail = resp.json()["detail"].lower() + assert "rotated" in detail or "revoked" in detail + + async def test_whoami_with_new_refresh_token_after_rotation_returns_200( + self, + auth_client: TestClient, + auth_db_with_user: dict[str, Any], + tmp_auth_db: Path, + ): + """After rotation, the NEW refresh token works on whoami. + + We simulate rotation by overwriting ``refresh_token_hash`` with + the hash of a known new token, then present that new token. + """ + body = _login( + auth_client, + auth_db_with_user["username"], + auth_db_with_user["password"], + ) + old_refresh = body["refresh_token"] + old_payload = verify_token( + old_refresh, auth_client.app.state.jwt_secret, expected_type="refresh" + ) + sid = old_payload.get("sid") + assert sid is not None + + # Mint a new refresh token with a unique jti claim so it differs + # from the old token (create_token_pair doesn't add jti to + # refresh tokens, so we craft one directly). + import jwt as _jwt + + new_payload = {**old_payload, "jti": str(uuid.uuid4())} + new_refresh = _jwt.encode(new_payload, auth_client.app.state.jwt_secret, algorithm="HS256") + if isinstance(new_refresh, bytes): + new_refresh = new_refresh.decode("utf-8") + + # Update the stored hash to match the new token. + from agentkit.server.auth.denylist import hash_token as _hash_token + + new_hash = _hash_token(new_refresh) + async with aiosqlite.connect(str(tmp_auth_db)) as db: + await db.execute( + "UPDATE auth_sessions SET refresh_token_hash = ? WHERE id = ?", + (new_hash, sid), + ) + await db.commit() + + # The new refresh token must pass the hash check. + resp = auth_client.get( + "/api/v1/auth/whoami", + headers={"Authorization": f"Bearer {new_refresh}"}, + ) + assert resp.status_code == 200, resp.text + assert resp.json()["access_token"] is not None + + def test_whoami_with_revoked_session_refresh_token_returns_401( + self, + auth_client: TestClient, + auth_db_with_user: dict[str, Any], + ): + """A refresh token whose session was revoked is rejected. + + Although the session row's ``refresh_token_hash`` is unchanged + by revocation, the session-revocation check (``info.revoked``) + fires first and returns 401. + """ + body = _login( + auth_client, + auth_db_with_user["username"], + auth_db_with_user["password"], + ) + # Revoke the session via DELETE /auth/sessions/{id}. + sessions = auth_client.get( + "/api/v1/auth/sessions", + headers={"Authorization": f"Bearer {body['access_token']}"}, + ).json() + sid = sessions[0]["id"] + del_resp = auth_client.delete( + f"/api/v1/auth/sessions/{sid}", + headers={"Authorization": f"Bearer {body['access_token']}"}, + ) + assert del_resp.status_code == 200 + + # Refresh token on the revoked session → 401. + resp = auth_client.get( + "/api/v1/auth/whoami", + headers={"Authorization": f"Bearer {body['refresh_token']}"}, + ) + assert resp.status_code == 401 + + async def test_whoami_access_token_skips_hash_check( + self, + auth_client: TestClient, + auth_db_with_user: dict[str, Any], + tmp_auth_db: Path, + ): + """Access tokens are not subject to the refresh-token hash check. + + The hash check only runs when ``token_type == "refresh"``; + access tokens bypass it (they have their own expiry + jti). + """ + body = _login( + auth_client, + auth_db_with_user["username"], + auth_db_with_user["password"], + ) + # Rotate the stored hash to a different value. + old_refresh = body["refresh_token"] + old_payload = verify_token( + old_refresh, auth_client.app.state.jwt_secret, expected_type="refresh" + ) + sid = old_payload.get("sid") + assert sid is not None + + from agentkit.server.auth.denylist import hash_token as _hash_token + + new_hash = _hash_token("some-other-token-not-the-access-token") + async with aiosqlite.connect(str(tmp_auth_db)) as db: + await db.execute( + "UPDATE auth_sessions SET refresh_token_hash = ? WHERE id = ?", + (new_hash, sid), + ) + await db.commit() + + # The original access token should still work (not yet expired). + resp = auth_client.get( + "/api/v1/auth/whoami", + headers={"Authorization": f"Bearer {body['access_token']}"}, + ) + assert resp.status_code == 200, resp.text + # Access-token call does NOT issue a new access token. + assert resp.json()["access_token"] is None + + class TestSessionsManagement: """GET /auth/sessions, DELETE /auth/sessions/{id}."""