660 lines
23 KiB
Python
660 lines
23 KiB
Python
"""Tests for U8: server-side terminal + approval mechanism.
|
|
|
|
Covers:
|
|
- Approval management API (list, approve, reject, expire)
|
|
- Global whitelist CRUD
|
|
- Global whitelist integration with safety check
|
|
- Approval creation and lifecycle
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from agentkit.server.auth.models import init_auth_db
|
|
from agentkit.server.auth.terminal_security import check_command_safety_v2
|
|
from agentkit.server.routes import terminal_server, terminal_whitelist
|
|
|
|
|
|
def _future_iso(minutes: int = 5) -> str:
    """Build an ISO-8601 UTC timestamp lying ``minutes`` after the current time.

    The result is timezone-aware (``+00:00`` offset); this file uses it for
    ``expires_at`` values of approvals that should still be live.
    """
    offset = timedelta(minutes=minutes)
    moment = datetime.now(timezone.utc) + offset
    return moment.isoformat()
|
|
|
|
|
|
def _past_iso(minutes: int = 10) -> str:
    """Build an ISO-8601 UTC timestamp lying ``minutes`` before the current time.

    The result is timezone-aware (``+00:00`` offset); this file uses it for
    ``expires_at`` values of approvals that should already be stale.
    """
    offset = timedelta(minutes=minutes)
    moment = datetime.now(timezone.utc) - offset
    return moment.isoformat()
|
|
|
|
|
|
# ── Fixtures ──────────────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.fixture
async def server_app(tmp_path: Path) -> FastAPI:
    """Build a FastAPI app exposing the terminal approval and whitelist routes.

    A fresh auth database is initialised under ``tmp_path`` for every test.
    An HTTP middleware stores a fixed admin identity in
    ``request.state.current_user`` on every request (presumably what the
    routes' auth checks read — confirm against the route modules).
    """
    auth_db = tmp_path / "test_auth.db"
    await init_auth_db(auth_db)

    app = FastAPI()
    app.state.auth_db_path = str(auth_db)
    app.state.allow_dev_terminal = True
    for routes_module in (terminal_server, terminal_whitelist):
        app.include_router(routes_module.router, prefix="/api/v1")

    admin_identity = {
        "user_id": "dev-admin-id",
        "username": "dev-admin",
        "role": "admin",
        "dev_mode": False,
    }

    @app.middleware("http")
    async def _set_dev_admin_user(request, call_next):
        # Hand each request its own dict so a handler mutating it cannot
        # affect later requests.
        request.state.current_user = dict(admin_identity)
        return await call_next(request)

    return app
|
|
|
|
|
|
@pytest.fixture
async def server_client(server_app: FastAPI):
    """Yield an httpx ``AsyncClient`` that talks to ``server_app`` via ASGI.

    Requests are dispatched in-process through ``ASGITransport``; the client is
    closed when the test finishes.
    """
    async with AsyncClient(
        transport=ASGITransport(app=server_app),
        base_url="http://test",
    ) as http:
        yield http
|
|
|
|
|
|
@pytest.fixture
async def server_db_path(server_app: FastAPI) -> Path:
    """Return the auth database path configured on ``server_app`` as a ``Path``."""
    configured = server_app.state.auth_db_path
    return Path(configured)
|
|
|
|
|
|
# ── Approval management API tests ─────────────────────────────────────
|
|
|
|
|
|
class TestApprovalListAPI:
    """Test GET /api/v1/terminal/approvals."""

    @pytest.mark.asyncio
    async def test_list_empty_approvals(self, server_client: AsyncClient):
        resp = await server_client.get("/api/v1/terminal/approvals")
        assert resp.status_code == 200
        data = resp.json()
        assert data["approvals"] == []
        assert data["total"] == 0

    @pytest.mark.asyncio
    async def test_list_approvals_after_creation(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        # Create an approval directly in the DB
        async with aiosqlite.connect(str(server_db_path)) as db:
            await db.execute(
                """INSERT INTO terminal_approvals
                (id, user_id, username, session_id, command, reason, status,
                created_at, expires_at)
                VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?)""",
                (
                    "appr-1",
                    "user-1",
                    "alice",
                    "session-1",
                    "rm -rf /tmp/test",
                    "Dangerous command",
                    "2026-01-01T00:00:00Z",
                    _future_iso(),
                ),
            )
            await db.commit()

        resp = await server_client.get("/api/v1/terminal/approvals")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["approvals"][0]["command"] == "rm -rf /tmp/test"
        assert data["approvals"][0]["status"] == "pending"

    @pytest.mark.asyncio
    async def test_list_approvals_filter_by_status(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        """Each ``?status=`` filter returns exactly the one row seeded with it."""
        statuses = ["pending", "approved", "rejected"]

        # Create approvals with different statuses (all with a live expiry).
        async with aiosqlite.connect(str(server_db_path)) as db:
            for i, status in enumerate(statuses):
                await db.execute(
                    """INSERT INTO terminal_approvals
                    (id, user_id, username, session_id, command, status,
                    created_at, expires_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                    (
                        f"appr-{i}",
                        "user-1",
                        "alice",
                        f"session-{i}",
                        f"cmd-{i}",
                        status,
                        "2026-01-01T00:00:00Z",
                        _future_iso(),
                    ),
                )
            await db.commit()

        # Check every seeded status, including "rejected", and make sure the
        # returned row is the one inserted with that status.
        for i, status in enumerate(statuses):
            resp = await server_client.get(
                "/api/v1/terminal/approvals", params={"status": status}
            )
            assert resp.status_code == 200
            data = resp.json()
            assert data["total"] == 1
            assert data["approvals"][0]["status"] == status
            assert data["approvals"][0]["id"] == f"appr-{i}"
|
|
|
|
|
|
class TestApprovalApproveAPI:
    """Exercise POST /api/v1/terminal/approvals/{id}/approve."""

    @pytest.mark.asyncio
    async def test_approve_pending_request(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        # Seed a still-live pending request straight into the database.
        async with aiosqlite.connect(str(server_db_path)) as conn:
            row = (
                "appr-approve-1",
                "user-1",
                "alice",
                "session-1",
                "rm -rf /tmp",
                "2026-01-01T00:00:00Z",
                _future_iso(),
            )
            await conn.execute(
                """INSERT INTO terminal_approvals
                (id, user_id, username, session_id, command, status,
                created_at, expires_at)
                VALUES (?, ?, ?, ?, ?, 'pending', ?, ?)""",
                row,
            )
            await conn.commit()

        response = await server_client.post(
            "/api/v1/terminal/approvals/appr-approve-1/approve",
            json={"note": "Approved for testing"},
        )
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "approved"
        assert body["reviewer_username"] == "dev-admin"
        assert body["review_note"] == "Approved for testing"
        assert body["reviewed_at"] is not None

    @pytest.mark.asyncio
    async def test_approve_nonexistent_returns_404(self, server_client: AsyncClient):
        url = "/api/v1/terminal/approvals/nonexistent/approve"
        response = await server_client.post(url)
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_approve_already_reviewed_returns_404(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        # Seed a request that has already been approved and reviewed.
        async with aiosqlite.connect(str(server_db_path)) as conn:
            row = (
                "appr-done",
                "user-1",
                "alice",
                "session-1",
                "ls",
                "2026-01-01T00:00:00Z",
                _future_iso(),
                "2026-01-01T00:01:00Z",
            )
            await conn.execute(
                """INSERT INTO terminal_approvals
                (id, user_id, username, session_id, command, status,
                created_at, expires_at, reviewed_at)
                VALUES (?, ?, ?, ?, ?, 'approved', ?, ?, ?)""",
                row,
            )
            await conn.commit()

        url = "/api/v1/terminal/approvals/appr-done/approve"
        response = await server_client.post(url)
        assert response.status_code == 404
|
|
|
|
|
|
class TestApprovalRejectAPI:
    """Exercise POST /api/v1/terminal/approvals/{id}/reject."""

    @pytest.mark.asyncio
    async def test_reject_pending_request(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        seed_sql = """INSERT INTO terminal_approvals
                (id, user_id, username, session_id, command, status,
                created_at, expires_at)
                VALUES (?, ?, ?, ?, ?, 'pending', ?, ?)"""
        # Seed a still-live pending request that the admin will reject.
        async with aiosqlite.connect(str(server_db_path)) as conn:
            await conn.execute(
                seed_sql,
                (
                    "appr-reject-1",
                    "user-1",
                    "alice",
                    "session-1",
                    "mkfs /dev/sda",
                    "2026-01-01T00:00:00Z",
                    _future_iso(),
                ),
            )
            await conn.commit()

        response = await server_client.post(
            "/api/v1/terminal/approvals/appr-reject-1/reject",
            json={"note": "Too dangerous"},
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "rejected"
        assert payload["review_note"] == "Too dangerous"
|
|
|
|
|
|
class TestApprovalExpiration:
    """Test that stale approvals are auto-expired."""

    @pytest.mark.asyncio
    async def test_expired_approvals_auto_marked(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        # Create an approval whose deadline has already passed. created_at is
        # placed before expires_at, so the row is internally consistent no
        # matter which wall-clock date the suite runs on. A fixed future
        # created_at would follow its own expiry.
        created = _past_iso(minutes=20)
        past = _past_iso()
        async with aiosqlite.connect(str(server_db_path)) as db:
            await db.execute(
                """INSERT INTO terminal_approvals
                (id, user_id, username, session_id, command, status,
                created_at, expires_at)
                VALUES (?, ?, ?, ?, ?, 'pending', ?, ?)""",
                (
                    "appr-expired",
                    "user-1",
                    "alice",
                    "session-1",
                    "rm -rf /",
                    created,
                    past,  # Already expired
                ),
            )
            await db.commit()

        # Listing triggers auto-expiration
        resp = await server_client.get("/api/v1/terminal/approvals")
        assert resp.status_code == 200

        # The expired approval should now have status "expired"
        resp = await server_client.get("/api/v1/terminal/approvals?status=expired")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["approvals"][0]["id"] == "appr-expired"
        assert data["approvals"][0]["status"] == "expired"

        # It should no longer be pending
        resp = await server_client.get("/api/v1/terminal/approvals?status=pending")
        assert resp.status_code == 200
        assert resp.json()["total"] == 0
|
|
|
|
|
|
# ── Global whitelist CRUD tests ───────────────────────────────────────
|
|
|
|
|
|
class TestGlobalWhitelistAPI:
    """Test the global whitelist management endpoints."""

    # Collection endpoint shared by every test in this class.
    _URL = "/api/v1/terminal/whitelist/global"

    @pytest.mark.asyncio
    async def test_list_empty_global_whitelist(self, server_client: AsyncClient):
        resp = await server_client.get(self._URL)
        assert resp.status_code == 200
        assert resp.json()["entries"] == []

    @pytest.mark.asyncio
    async def test_add_global_whitelist_entry(self, server_client: AsyncClient):
        resp = await server_client.post(self._URL, json={"command_pattern": "docker"})
        assert resp.status_code == 201
        data = resp.json()
        assert data["command_pattern"] == "docker"
        assert "id" in data

    @pytest.mark.asyncio
    async def test_add_duplicate_global_whitelist_returns_409(
        self, server_client: AsyncClient
    ):
        first = await server_client.post(self._URL, json={"command_pattern": "docker"})
        # The first insert must succeed; otherwise the 409 below proves nothing
        # about duplicate detection.
        assert first.status_code == 201

        resp = await server_client.post(self._URL, json={"command_pattern": "docker"})
        assert resp.status_code == 409

        # The rejected duplicate must not have been persisted.
        listing = await server_client.get(self._URL)
        assert listing.status_code == 200
        assert len(listing.json()["entries"]) == 1

    @pytest.mark.asyncio
    async def test_list_global_whitelist_after_add(self, server_client: AsyncClient):
        patterns = ("docker", "kubectl")
        for pattern in patterns:
            added = await server_client.post(
                self._URL, json={"command_pattern": pattern}
            )
            assert added.status_code == 201

        resp = await server_client.get(self._URL)
        assert resp.status_code == 200
        entries = resp.json()["entries"]
        assert len(entries) == 2
        assert {entry["command_pattern"] for entry in entries} == set(patterns)

    @pytest.mark.asyncio
    async def test_delete_global_whitelist_entry(self, server_client: AsyncClient):
        add_resp = await server_client.post(
            self._URL, json={"command_pattern": "docker"}
        )
        assert add_resp.status_code == 201
        entry_id = add_resp.json()["id"]

        del_resp = await server_client.delete(f"{self._URL}/{entry_id}")
        assert del_resp.status_code == 204

        # Verify deleted
        list_resp = await server_client.get(self._URL)
        assert list_resp.status_code == 200
        assert list_resp.json()["entries"] == []

    @pytest.mark.asyncio
    async def test_delete_nonexistent_global_returns_404(
        self, server_client: AsyncClient
    ):
        resp = await server_client.delete(f"{self._URL}/nonexistent")
        assert resp.status_code == 404
|
|
|
|
|
|
# ── Integration: global whitelist + safety check ──────────────────────
|
|
|
|
|
|
class TestGlobalWhitelistIntegration:
    """Integration: global whitelist allows commands in safety check."""

    @staticmethod
    async def _whitelist_globally(client: AsyncClient, pattern: str) -> None:
        """Add ``pattern`` to the global whitelist via the API and require success.

        Asserting the 201 here matters: if setup silently failed, the tests
        below would be checking something other than what they claim.
        """
        resp = await client.post(
            "/api/v1/terminal/whitelist/global",
            json={"command_pattern": pattern},
        )
        assert resp.status_code == 201

    @pytest.mark.asyncio
    async def test_global_whitelist_allows_dangerous_command(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        """Add 'rm' to global whitelist, then verify safety check passes."""
        await self._whitelist_globally(server_client, "rm")

        # Check command safety — rm should now be allowed via global whitelist
        decision = await check_command_safety_v2(
            "rm -rf /tmp/test",
            session_id="test-session",
            session_whitelist=set(),
            user_id="user-1",
            db_path=server_db_path,
        )
        assert decision.safe is True
        assert decision.matched_layer == "global_whitelist"

    @pytest.mark.asyncio
    async def test_global_whitelist_priority_over_user(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        """A global whitelist entry alone allows a command for a user with no entries.

        NOTE: no user-level whitelist is configured here, so this checks that
        the global layer is the one that matches for such a user. It does not
        set up a competing user-whitelist entry to prove ordering between the
        two layers.
        """
        await self._whitelist_globally(server_client, "rm")

        decision = await check_command_safety_v2(
            "rm -rf /tmp/test",
            session_id="test-session",
            session_whitelist=set(),
            user_id="user-without-whitelist",
            db_path=server_db_path,
        )
        assert decision.safe is True
        assert decision.matched_layer == "global_whitelist"

    @pytest.mark.asyncio
    async def test_blocklist_overrides_global_whitelist(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        """Blocklist takes priority over global whitelist."""
        # The whitelist entry must really exist, or "blocked" below would not
        # demonstrate any override.
        await self._whitelist_globally(server_client, "rm")

        # Add "rm" to the blocklist via the management API
        resp = await server_client.post(
            "/api/v1/terminal/blocklist",
            json={"command_pattern": "rm", "reason": "rm is globally blocked"},
        )
        assert resp.status_code == 201

        # Blocklist should take priority
        decision = await check_command_safety_v2(
            "rm -rf /tmp/test",
            session_id="test-session",
            session_whitelist=set(),
            user_id="user-1",
            db_path=server_db_path,
        )
        assert decision.safe is False
        assert decision.decision == "blocked"
|
|
|
|
|
|
# ── Approval lifecycle integration ────────────────────────────────────
|
|
|
|
|
|
class TestApprovalLifecycle:
    """Test the full approval lifecycle via the API."""

    @pytest.mark.asyncio
    async def test_full_approval_lifecycle(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        """Create → list pending → approve → list approved."""
        # 1. Create a pending approval directly
        approval_id = "lifecycle-1"
        async with aiosqlite.connect(str(server_db_path)) as db:
            await db.execute(
                """INSERT INTO terminal_approvals
                (id, user_id, username, session_id, command, reason, status,
                created_at, expires_at)
                VALUES (?, ?, ?, ?, ?, ?, 'pending', ?, ?)""",
                (
                    approval_id,
                    "user-1",
                    "alice",
                    "session-1",
                    "rm -rf /tmp/test",
                    "Cleanup temp directory",
                    "2026-01-01T00:00:00Z",
                    _future_iso(),
                ),
            )
            await db.commit()

        # 2. List pending — should see it. Status codes are checked first so an
        # error response fails clearly instead of as a KeyError on the body.
        resp = await server_client.get("/api/v1/terminal/approvals?status=pending")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["approvals"][0]["id"] == approval_id

        # 3. Approve it
        resp = await server_client.post(
            f"/api/v1/terminal/approvals/{approval_id}/approve",
            json={"note": "OK to proceed"},
        )
        assert resp.status_code == 200
        assert resp.json()["status"] == "approved"

        # 4. List pending — should be empty now
        resp = await server_client.get("/api/v1/terminal/approvals?status=pending")
        assert resp.status_code == 200
        assert resp.json()["total"] == 0

        # 5. List approved — should see it
        resp = await server_client.get("/api/v1/terminal/approvals?status=approved")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["approvals"][0]["id"] == approval_id

    @pytest.mark.asyncio
    async def test_reject_then_cannot_approve(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        """Once rejected, cannot approve the same request."""
        approval_id = "reject-then-approve"
        async with aiosqlite.connect(str(server_db_path)) as db:
            await db.execute(
                """INSERT INTO terminal_approvals
                (id, user_id, username, session_id, command, status,
                created_at, expires_at)
                VALUES (?, ?, ?, ?, ?, 'pending', ?, ?)""",
                (
                    approval_id,
                    "user-1",
                    "alice",
                    "session-1",
                    "mkfs /dev/sda",
                    "2026-01-01T00:00:00Z",
                    _future_iso(),
                ),
            )
            await db.commit()

        # Reject it
        resp = await server_client.post(
            f"/api/v1/terminal/approvals/{approval_id}/reject"
        )
        assert resp.status_code == 200
        assert resp.json()["status"] == "rejected"

        # Try to approve — should fail (already reviewed)
        resp = await server_client.post(
            f"/api/v1/terminal/approvals/{approval_id}/approve"
        )
        assert resp.status_code == 404
|
|
|
|
|
|
# ── Approval future notification test ─────────────────────────────────
|
|
|
|
|
|
class TestApprovalFutureNotification:
    """Test that approving/rejecting via API notifies waiting futures."""

    @staticmethod
    async def _insert_pending(db_path: Path, approval_id: str, command: str) -> None:
        """Insert a pending approval row with a live (future) expiry."""
        async with aiosqlite.connect(str(db_path)) as db:
            await db.execute(
                """INSERT INTO terminal_approvals
                (id, user_id, username, session_id, command, status,
                created_at, expires_at)
                VALUES (?, ?, ?, ?, ?, 'pending', ?, ?)""",
                (
                    approval_id,
                    "user-1",
                    "alice",
                    "session-1",
                    command,
                    "2026-01-01T00:00:00Z",
                    _future_iso(),
                ),
            )
            await db.commit()

    @pytest.mark.asyncio
    async def test_approve_notifies_future(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        """When admin approves via API, the waiting future is resolved."""
        from agentkit.server.routes.terminal_server import _pending_approvals

        approval_id = "future-test-1"
        await self._insert_pending(server_db_path, approval_id, "rm -rf /tmp")

        # Register a future (simulating a waiting WebSocket)
        loop = asyncio.get_running_loop()
        future: asyncio.Future[bool] = loop.create_future()
        _pending_approvals[approval_id] = future
        try:
            resp = await server_client.post(
                f"/api/v1/terminal/approvals/{approval_id}/approve"
            )
            assert resp.status_code == 200

            # The future should be resolved with True
            result = await asyncio.wait_for(future, timeout=1.0)
            assert result is True
        finally:
            # Always unregister, even if an assertion above failed, so the
            # module-level registry does not leak this entry into later tests.
            _pending_approvals.pop(approval_id, None)
            if not future.done():
                future.cancel()

    @pytest.mark.asyncio
    async def test_reject_notifies_future(
        self, server_client: AsyncClient, server_db_path: Path
    ):
        """When admin rejects via API, the waiting future is resolved with False."""
        from agentkit.server.routes.terminal_server import _pending_approvals

        approval_id = "future-test-2"
        await self._insert_pending(server_db_path, approval_id, "mkfs /dev/sda")

        loop = asyncio.get_running_loop()
        future: asyncio.Future[bool] = loop.create_future()
        _pending_approvals[approval_id] = future
        try:
            resp = await server_client.post(
                f"/api/v1/terminal/approvals/{approval_id}/reject"
            )
            assert resp.status_code == 200

            result = await asyncio.wait_for(future, timeout=1.0)
            assert result is False
        finally:
            # See test_approve_notifies_future: cleanup must survive failures.
            _pending_approvals.pop(approval_id, None)
            if not future.done():
                future.cancel()
|