371 lines
14 KiB
Python
371 lines
14 KiB
Python
"""Unit tests for the permission model and RBAC dependencies (U5).

Covers:
- Permission enum values
- ROLE_PERMISSIONS mapping (member / operator / admin)
- has_permission() for each role
- require_permission() dependency (success, 403, dev mode)
- require_terminal_authorized() (success, 403 on role, 403 on flag)
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
from fastapi import Depends, FastAPI, Request
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.server.auth.dependencies import (
|
|
require_authenticated,
|
|
require_permission,
|
|
require_terminal_authorized,
|
|
)
|
|
from agentkit.server.auth.models import init_auth_db
|
|
from agentkit.server.auth.password import hash_password
|
|
from agentkit.server.auth.permissions import (
|
|
Permission,
|
|
get_role_permissions,
|
|
has_permission,
|
|
is_dev_mode,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Permission model tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPermissionModel:
    """Tests for the Permission enum, the per-role permission sets and helpers."""

    def test_permission_values_are_strings(self):
        """Each Permission value is a str (used in JWT + audit logs)."""
        for member in Permission:
            value = member.value
            assert isinstance(value, str)
            assert value == member.name  # value matches name for simplicity

    def test_member_role_permissions(self):
        """member gets chat, KB query and workflow; nothing terminal/admin."""
        granted = get_role_permissions("member")
        must_have = (
            Permission.CHAT,
            Permission.KB_QUERY,
            Permission.WORKFLOW_EXECUTE,
        )
        must_lack = (
            Permission.KB_WRITE,
            Permission.TERMINAL_LOCAL_USE,
            Permission.TERMINAL_SERVER_USE,
            Permission.USER_MANAGE,
            Permission.SYSTEM_CONFIG,
        )
        for perm in must_have:
            assert perm in granted
        for perm in must_lack:
            assert perm not in granted

    def test_operator_role_permissions(self):
        """operator adds KB write, local terminal and whitelist management."""
        granted = get_role_permissions("operator")
        must_have = (
            Permission.CHAT,
            Permission.KB_QUERY,
            Permission.KB_WRITE,
            Permission.WORKFLOW_EXECUTE,
            Permission.TERMINAL_LOCAL_USE,
            Permission.TERMINAL_WHITELIST_MANAGE,
        )
        # operator does NOT have server terminal or admin perms
        must_lack = (
            Permission.TERMINAL_SERVER_USE,
            Permission.USER_MANAGE,
            Permission.SYSTEM_CONFIG,
        )
        for perm in must_have:
            assert perm in granted
        for perm in must_lack:
            assert perm not in granted

    def test_admin_role_permissions(self):
        """admin is granted every member of Permission."""
        granted = get_role_permissions("admin")
        for member in Permission:
            assert member in granted, f"admin should have {member.value}"

    def test_unknown_role_returns_empty(self):
        """An unrecognised or empty role name yields an empty frozenset."""
        for role_name in ("superuser", ""):
            assert get_role_permissions(role_name) == frozenset()

    def test_none_role_returns_empty(self):
        """A None role (dev mode) yields an empty frozenset."""
        nothing = frozenset()
        assert get_role_permissions(None) == nothing

    def test_has_permission_member(self):
        """A member user passes the CHAT check but fails TERMINAL_LOCAL_USE."""
        alice = {"role": "member", "user_id": "u1", "username": "alice"}
        can_chat = has_permission(alice, Permission.CHAT)
        can_use_terminal = has_permission(alice, Permission.TERMINAL_LOCAL_USE)
        assert can_chat is True
        assert can_use_terminal is False

    def test_has_permission_admin(self):
        """An admin user passes has_permission for every Permission member."""
        root = {"role": "admin", "user_id": "u1", "username": "admin"}
        for member in Permission:
            assert has_permission(root, member) is True

    def test_has_permission_none_user(self):
        """has_permission returns False when the user is None (dev mode)."""
        outcome = has_permission(None, Permission.CHAT)
        assert outcome is False

    def test_is_dev_mode(self):
        """is_dev_mode is True for a None user and False for a user dict."""
        assert is_dev_mode(None) is True
        assert is_dev_mode({"role": "member"}) is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# require_permission dependency tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_protected_app() -> FastAPI:
    """Build a FastAPI app whose routes are each guarded by one auth dependency.

    Routes:
        /chat      guarded by require_permission(Permission.CHAT)
        /terminal  guarded by require_permission(Permission.TERMINAL_LOCAL_USE)
        /admin     guarded by require_permission(Permission.USER_MANAGE)
        /any-auth  guarded by require_authenticated; echoes the resolved user
    """
    api = FastAPI()

    # Build the permission dependencies up front so each route reads cleanly.
    needs_chat = require_permission(Permission.CHAT)
    needs_local_terminal = require_permission(Permission.TERMINAL_LOCAL_USE)
    needs_user_manage = require_permission(Permission.USER_MANAGE)

    @api.get("/chat")
    async def chat_endpoint(_user=Depends(needs_chat)):
        return {"ok": True}

    @api.get("/terminal")
    async def terminal_endpoint(_user=Depends(needs_local_terminal)):
        return {"ok": True}

    @api.get("/admin")
    async def admin_endpoint(_user=Depends(needs_user_manage)):
        return {"ok": True}

    @api.get("/any-auth")
    async def any_auth_endpoint(user=Depends(require_authenticated)):
        return {"user": user}

    return api
|
|
|
|
|
|
def _set_user(app: FastAPI, user: dict[str, Any] | None) -> None:
    """Register an HTTP middleware on *app* that injects *user* into each request.

    Args:
        app: Application to attach the middleware to.
        user: Value stored on ``request.state.current_user``. When ``None``
            the attribute is left unset, which simulates dev mode.
    """

    @app.middleware("http")
    async def set_user_middleware(request: Request, call_next):
        # Dev-mode simulation: leave current_user absent rather than None.
        if user is not None:
            request.state.current_user = user
        return await call_next(request)
|
|
|
|
|
|
class TestRequirePermission:
    """Behaviour of the require_permission / require_authenticated dependencies."""

    @staticmethod
    def _user(role: str, username: str) -> dict[str, Any]:
        """Return a fresh user dict with the fixed test user_id ``u1``."""
        return {"role": role, "user_id": "u1", "username": username}

    @staticmethod
    def _get(path: str, user: dict[str, Any] | None):
        """GET *path* on a freshly built protected app acting as *user*.

        Passing ``None`` for *user* leaves current_user unset (dev mode).
        """
        app = _make_protected_app()
        _set_user(app, user)
        return TestClient(app).get(path)

    def test_member_can_chat(self):
        """member role → 200 on /chat."""
        response = self._get("/chat", self._user("member", "alice"))
        assert response.status_code == 200

    def test_member_cannot_use_terminal(self):
        """member role → 403 on /terminal, detail names TERMINAL_LOCAL_USE."""
        response = self._get("/terminal", self._user("member", "alice"))
        assert response.status_code == 403
        assert "TERMINAL_LOCAL_USE" in response.json()["detail"]

    def test_operator_can_use_terminal(self):
        """operator role → 200 on /terminal."""
        response = self._get("/terminal", self._user("operator", "bob"))
        assert response.status_code == 200

    def test_member_cannot_access_admin(self):
        """member role → 403 on /admin (no USER_MANAGE)."""
        response = self._get("/admin", self._user("member", "alice"))
        assert response.status_code == 403

    def test_admin_can_access_admin(self):
        """admin role → 200 on /admin."""
        response = self._get("/admin", self._user("admin", "root"))
        assert response.status_code == 200

    def test_dev_mode_allows_low_risk(self):
        """Dev mode (no user) → 200 on /chat (low-risk allowed)."""
        response = self._get("/chat", None)
        assert response.status_code == 200

    def test_dev_mode_blocks_high_risk(self):
        """Dev mode (no user) → 401 on /terminal (high-risk requires auth)."""
        response = self._get("/terminal", None)
        assert response.status_code == 401

    def test_dev_mode_blocks_admin(self):
        """Dev mode (no user) → 401 on /admin (high-risk requires auth)."""
        response = self._get("/admin", None)
        assert response.status_code == 401

    def test_require_authenticated_blocks_dev_mode(self):
        """require_authenticated → 401 when no user is set (dev mode)."""
        response = self._get("/any-auth", None)
        assert response.status_code == 401

    def test_require_authenticated_passes_authenticated(self):
        """require_authenticated → 200 and the user is echoed back."""
        response = self._get("/any-auth", self._user("member", "alice"))
        assert response.status_code == 200
        assert response.json()["user"]["username"] == "alice"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# require_terminal_authorized tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
async def tmp_auth_db_with_users(tmp_path: Path) -> Path:
    """Create an auth DB seeded with three users and return its path.

    Seeded accounts (all active, none server-terminal authorized):
        op1  operator, is_terminal_authorized = 1
        op2  operator, is_terminal_authorized = 0
        m1   member,   is_terminal_authorized = 1
    """
    auth_db = tmp_path / "auth.db"
    await init_auth_db(auth_db)
    stamp = datetime.now(timezone.utc).isoformat()

    insert_user_sql = (
        "INSERT INTO users "
        "(id, username, email, password_hash, role, is_active, "
        " is_terminal_authorized, is_server_terminal_authorized, "
        " created_at, updated_at) "
        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
    )

    # (id, username, email, role, is_terminal_authorized)
    accounts = (
        # operator with terminal authorized
        (str(uuid.uuid4()), "op1", "op1@x.com", "operator", 1),
        # operator without terminal authorized
        (str(uuid.uuid4()), "op2", "op2@x.com", "operator", 0),
        # member (no terminal permission in role)
        (str(uuid.uuid4()), "m1", "m1@x.com", "member", 1),
    )

    async with aiosqlite.connect(str(auth_db)) as conn:
        for user_id, username, email, role, terminal_flag in accounts:
            await conn.execute(
                insert_user_sql,
                (
                    user_id,
                    username,
                    email,
                    hash_password("p"),
                    role,
                    1,
                    terminal_flag,
                    0,
                    stamp,
                    stamp,
                ),
            )
        await conn.commit()

    return auth_db
|
|
|
|
|
|
class TestRequireTerminalAuthorized:
    """Behaviour of the require_terminal_authorized dependency."""

    @staticmethod
    def _get_term(db_path: Path, username: str, role: str):
        """GET /term on a one-route app acting as the given seeded user.

        The app points ``app.state.auth_db_path`` at *db_path*, and a
        middleware sets ``request.state.current_user`` with the user's ID
        looked up from that database on each request.
        """
        app = FastAPI()
        app.state.auth_db_path = str(db_path)

        @app.get("/term")
        async def term_endpoint(_user=Depends(require_terminal_authorized)):
            return {"ok": True}

        @app.middleware("http")
        async def inject_user(request: Request, call_next):
            request.state.current_user = {
                "user_id": _get_user_id(db_path, username),
                "username": username,
                "role": role,
            }
            return await call_next(request)

        return TestClient(app).get("/term")

    def test_operator_with_flag_can_access(self, tmp_auth_db_with_users: Path):
        """operator with is_terminal_authorized=True → 200."""
        response = self._get_term(tmp_auth_db_with_users, "op1", "operator")
        assert response.status_code == 200

    def test_operator_without_flag_blocked(self, tmp_auth_db_with_users: Path):
        """operator with is_terminal_authorized=False → 403."""
        response = self._get_term(tmp_auth_db_with_users, "op2", "operator")
        assert response.status_code == 403
        assert "Terminal access not authorized" in response.json()["detail"]

    def test_member_blocked_by_role(self, tmp_auth_db_with_users: Path):
        """member (no TERMINAL_LOCAL_USE permission) → 403."""
        response = self._get_term(tmp_auth_db_with_users, "m1", "member")
        assert response.status_code == 403
        assert "TERMINAL_LOCAL_USE" in response.json()["detail"]
|
|
|
|
|
|
def _get_user_id(db_path: Path, username: str) -> str:
    """Look up a user's ID from the test DB (synchronous helper).

    Args:
        db_path: Path to the SQLite database containing a ``users`` table.
        username: Username to search for.

    Returns:
        The ``id`` column of the first matching row, or ``""`` when no row
        matches.

    Raises:
        sqlite3.Error: If the query fails (for example, the ``users`` table
            does not exist). The connection is still closed in that case.
    """
    import sqlite3
    from contextlib import closing

    # A sqlite3 connection used directly as a context manager only scopes a
    # transaction; closing() is what guarantees the handle is released, even
    # when execute()/fetchone() raises.
    with closing(sqlite3.connect(str(db_path))) as conn:
        row = conn.execute(
            "SELECT id FROM users WHERE username = ?", (username,)
        ).fetchone()

    return row[0] if row else ""
|