# geo/backend/tests/test_infrastructure/test_security.py
#
# 669 lines
# 26 KiB
# Python

"""Security tests: SQL injection, XSS protection, and authentication security."""
import uuid
from datetime import datetime, timedelta, timezone

import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from jose import jwt
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.pool import StaticPool

from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.models.query import Query
from app.models.competitor import Competitor
from app.models.suggestion import Suggestion
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token, create_refresh_token, hash_password
from app.config import settings
from tests.fixtures.auth import _to_uuid
# Tables materialised for the security tests. Handing create_all() this
# explicit subset keeps SQLite away from PostgreSQL-only column types (JSONB).
_TEST_TABLES = tuple(
    model.__table__ for model in (User, Brand, Query, Competitor, Suggestion)
)
# ─────────────────────── Fixtures ───────────────────────
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine holding the security-test schema.

    Only the tables listed in ``_TEST_TABLES`` are created, so the
    PostgreSQL-only types (JSONB) used by other models never reach SQLite.
    ``StaticPool`` reuses a single connection, which an in-memory SQLite
    database needs for every session to see the same data.

    The engine is disposed on teardown, and also when schema creation fails.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(
                lambda sync_conn: Base.metadata.create_all(
                    sync_conn, tables=list(_TEST_TABLES)
                )
            )
        yield engine
    finally:
        # Runs on normal fixture teardown and if create_all() raised above.
        await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single AsyncSession bound to the per-test SQLite engine.

    Attributes are not expired on commit (``expire_on_commit=False``) and
    pending changes are not flushed implicitly (``autoflush=False``).
    """
    factory = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with factory() as session:
        yield session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return the primary user (plain password ``SecurePass123!``)."""
    profile = {
        "email": "security_test@example.com",
        "password": hash_password("SecurePass123!"),
        "firstName": "Security Test User",
        "plan": "free",
        "max_queries": 5,
        "isActive": True,
        "emailVerified": True,
    }
    account = User(id=str(uuid.uuid4()), **profile)
    async_session.add(account)
    await async_session.commit()
    # Reload server-side state so the returned instance is fully populated.
    await async_session.refresh(account)
    return account
@pytest_asyncio.fixture
async def second_user(async_session):
    """Persist and return a second user, used to check cross-user isolation."""
    profile = {
        "email": "second_user@example.com",
        "password": hash_password("SecondPass456!"),
        "firstName": "Second User",
        "plan": "free",
        "max_queries": 5,
        "isActive": True,
        "emailVerified": True,
    }
    account = User(id=str(uuid.uuid4()), **profile)
    async_session.add(account)
    await async_session.commit()
    # Reload server-side state so the returned instance is fully populated.
    await async_session.refresh(account)
    return account
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client whose requests are treated as ``test_user``.

    ``get_db`` is redirected to the test session and ``get_current_user`` is
    stubbed to return ``test_user``, so no JWT validation happens. All
    dependency overrides are cleared on teardown.
    """

    async def _db_override():
        yield async_session

    async def _user_override():
        return test_user

    app.dependency_overrides.update(
        {get_db: _db_override, get_current_user: _user_override}
    )
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as client:
        yield client
    app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def client_no_override(async_session):
    """Yield an HTTP client that goes through the real JWT authentication flow.

    Only ``get_db`` is redirected to the test session. ``get_current_user`` is
    deliberately left untouched. All dependency overrides are cleared on
    teardown.
    """

    async def _db_override():
        yield async_session

    app.dependency_overrides[get_db] = _db_override
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as client:
        yield client
    app.dependency_overrides.clear()
@pytest.fixture
def auth_headers(test_user):
    """Return a Bearer ``Authorization`` header carrying an access token for ``test_user``."""
    access_token = create_access_token(data={"sub": str(test_user.id)})
    return dict(Authorization=f"Bearer {access_token}")
@pytest.fixture
def second_auth_headers(second_user):
    """Return a Bearer ``Authorization`` header carrying an access token for ``second_user``."""
    access_token = create_access_token(data={"sub": str(second_user.id)})
    return dict(Authorization=f"Bearer {access_token}")
# ═══════════════════════════════════════════════════════════
# SQL Injection Protection Tests
# ═══════════════════════════════════════════════════════════
class TestSQLInjection:
    """SQL injection strings must be rejected or handled as inert data."""

    @pytest.mark.asyncio
    async def test_login_sql_injection_rejected(self, client_no_override):
        """Login must refuse classic SQL injection strings.

        The email field is validated as an email address (EmailStr), so
        non-email strings should fail validation with 422. A payload that
        slipped past validation would only reach the database through the
        ORM's parameterized queries and fail authentication (401). A 429
        from the rate limiter is also an acceptable outcome.
        """
        attack_strings = (
            "' OR '1'='1",
            "'; DROP TABLE users; --",
            "admin'--",
            "1' UNION SELECT * FROM users--",
            "' OR 1=1 --",
            '" OR ""=""',
            "1; SELECT * FROM users WHERE '1' = '1'",
        )
        # 422 validation error, 401 failed auth, 429 rate-limited.
        accepted = (401, 422, 429)
        for attack in attack_strings:
            resp = await client_no_override.post(
                "/api/v1/auth/login",
                json={"email": attack, "password": attack},
            )
            assert resp.status_code in accepted, (
                f"SQL injection payload '{attack}' returned {resp.status_code}, "
                f"expected 401/422/429"
            )

    @pytest.mark.asyncio
    async def test_register_sql_injection_rejected(self, client_no_override):
        """Registration must refuse SQL injection strings in the email field."""
        attack_strings = (
            "' OR '1'='1",
            "admin'--; DROP TABLE users;--",
            "1' UNION SELECT * FROM users--",
        )
        for attack in attack_strings:
            body = {"email": attack, "password": "password123", "name": "test"}
            resp = await client_no_override.post("/api/v1/auth/register", json=body)
            # A 429 from the rate limiter is also a valid rejection.
            assert resp.status_code in (400, 422, 429), (
                f"SQL injection payload '{attack}' returned {resp.status_code}"
            )

    @pytest.mark.asyncio
    async def test_query_path_param_injection(self, async_client, auth_headers):
        """Query IDs in the URL path must not be interpretable as SQL.

        A non-UUID path segment is expected to fail validation (422) or
        simply not match a row (404).
        """
        attack_strings = (
            "1' OR '1'='1",
            "1; DROP TABLE queries; --",
            "' UNION SELECT * FROM users--",
        )
        for attack in attack_strings:
            resp = await async_client.get(
                f"/api/v1/queries/{attack}", headers=auth_headers
            )
            assert resp.status_code in (404, 422), (
                f"Path param injection '{attack}' returned {resp.status_code}"
            )

    @pytest.mark.asyncio
    async def test_brand_path_param_injection(self, async_client, auth_headers):
        """Brand IDs in the URL path must not be interpretable as SQL."""
        for attack in ("1' OR '1'='1", "1; DROP TABLE brands; --"):
            resp = await async_client.get(
                f"/api/v1/brands/{attack}/", headers=auth_headers
            )
            assert resp.status_code in (404, 422), (
                f"Brand path param injection '{attack}' returned {resp.status_code}"
            )

    @pytest.mark.asyncio
    async def test_forgot_password_sql_injection_rejected(self, client_no_override):
        """Forgot-password must refuse or inertly handle SQL injection strings."""
        for attack in ("' OR '1'='1", "admin'--"):
            resp = await client_no_override.post(
                "/api/v1/auth/forgot-password", json={"email": attack}
            )
            assert resp.status_code in (200, 422, 429), (
                f"SQL injection payload '{attack}' returned {resp.status_code}"
            )
            # A 200 must be the generic acknowledgement (a body with "message"),
            # so the response does not reveal whether the account exists.
            if resp.status_code == 200:
                assert "message" in resp.json()

    @pytest.mark.asyncio
    async def test_reset_password_sql_injection_rejected(self, client_no_override):
        """Reset-password must treat an injection string as an invalid token."""
        body = {
            "token": "' OR '1'='1",
            "new_password": "newpassword123",
        }
        resp = await client_no_override.post("/api/v1/auth/reset-password", json=body)
        # Expect a normal rejection of the token, never a database error.
        assert resp.status_code in (400, 422, 429), (
            f"Reset password injection returned {resp.status_code}"
        )
# ═══════════════════════════════════════════════════════════
# XSS Protection Tests
# ═══════════════════════════════════════════════════════════
class TestXSSProtection:
    """Verify that XSS attack vectors are properly mitigated.

    For a pure JSON API, XSS payloads stored as-is in the database is
    expected behavior — JSON responses are not rendered as HTML by
    browsers. The real XSS protections are:
    1. Content-Type: application/json (prevents HTML rendering)
    2. Security headers (X-XSS-Protection, X-Content-Type-Options, etc.)
    3. Frontend escaping when rendering data in HTML
    """

    @pytest.mark.asyncio
    async def test_api_returns_json_content_type(self, async_client, auth_headers):
        """All API responses must have Content-Type: application/json.

        This is the primary XSS defense for JSON APIs — browsers will
        not execute scripts in JSON responses. Only 200 responses are
        checked.
        """
        endpoints = [
            ("/api/v1/brands/", "GET"),
            ("/api/v1/queries/", "GET"),
        ]
        for path, method in endpoints:
            # Dispatch on the declared method rather than hard-coding GET, so
            # a non-GET entry added to the table is actually exercised.
            response = await async_client.request(method, path, headers=auth_headers)
            if response.status_code == 200:
                content_type = response.headers.get("content-type", "")
                assert "application/json" in content_type, (
                    f"Response for {path} has Content-Type '{content_type}', "
                    f"expected 'application/json'"
                )

    @pytest.mark.asyncio
    async def test_brand_name_xss_not_executable(self, async_client, auth_headers):
        """XSS payloads in brand name should be stored as plain text.

        In a JSON API, script tags are stored as text and not executed
        because the response Content-Type is application/json.
        The key assertion is that the response is valid JSON (not HTML).
        Payloads whose creation is not accepted (non-200/201) are skipped.
        """
        xss_payloads = [
            "<script>alert('xss')</script>",
            "<img src=x onerror=alert(1)>",
            "javascript:alert(1)",
            "<svg onload=alert(1)>",
            "<iframe src='javascript:alert(1)'>",
        ]
        for payload in xss_payloads:
            response = await async_client.post(
                "/api/v1/brands/",
                json={"name": payload, "platforms": ["kimi"]},
                headers=auth_headers,
            )
            if response.status_code in (200, 201):
                # The response must be a JSON object, not an HTML page.
                data = response.json()
                assert isinstance(data, dict), "Response should be a JSON object"
                content_type = response.headers.get("content-type", "")
                assert "application/json" in content_type, (
                    f"Content-Type should be application/json, got '{content_type}'"
                )
                # The payload is stored verbatim as plain text; escaping it on
                # render is the frontend's responsibility.
                name = data.get("name", "")
                assert name == payload, (
                    f"Brand name should store XSS payload as-is (plain text), "
                    f"got '{name}' instead of '{payload}'"
                )

    @pytest.mark.asyncio
    async def test_brand_update_xss_as_plain_text(self, async_client, async_session, test_user, auth_headers):
        """XSS payloads in brand aliases should be stored as plain text."""
        brand = Brand(
            id=uuid.uuid4(),
            user_id=_to_uuid(test_user.id),
            name="Safe Brand",
            platforms=["wenxin"],
            status="active",
        )
        async_session.add(brand)
        await async_session.commit()
        await async_session.refresh(brand)
        xss_payloads = [
            "<script>alert('xss')</script>",
            "<img src=x onerror=alert(1)>",
        ]
        for payload in xss_payloads:
            response = await async_client.put(
                f"/api/v1/brands/{brand.id}/",
                json={"aliases": [payload]},
                headers=auth_headers,
            )
            if response.status_code == 200:
                data = response.json()
                content_type = response.headers.get("content-type", "")
                assert "application/json" in content_type
                # The payload comes back verbatim inside a JSON document.
                assert payload in data.get("aliases", [])

    @pytest.mark.asyncio
    async def test_query_keyword_xss_as_plain_text(self, async_client, auth_headers):
        """XSS payloads in query keyword should be stored as plain text."""
        xss_payload = "<script>alert('xss')</script>"
        response = await async_client.post(
            "/api/v1/queries/",
            json={
                "keyword": xss_payload,
                "target_brand": "Test Brand",
                "platforms": ["wenxin"],
            },
            headers=auth_headers,
        )
        if response.status_code in (200, 201):
            data = response.json()
            content_type = response.headers.get("content-type", "")
            assert "application/json" in content_type
            # The payload comes back verbatim inside a JSON document.
            assert data.get("keyword") == xss_payload

    @pytest.mark.asyncio
    async def test_security_headers_present(self, async_client):
        """Verify that security response headers are set on all responses."""
        response = await async_client.get("/health")
        assert response.status_code == 200
        assert response.headers.get("x-content-type-options") == "nosniff", (
            "X-Content-Type-Options header missing or incorrect"
        )
        assert response.headers.get("x-frame-options") == "DENY", (
            "X-Frame-Options header missing or incorrect"
        )
        assert response.headers.get("x-xss-protection") == "1; mode=block", (
            "X-XSS-Protection header missing or incorrect"
        )
        assert response.headers.get("referrer-policy") == "strict-origin-when-cross-origin", (
            "Referrer-Policy header missing or incorrect"
        )

    @pytest.mark.asyncio
    async def test_security_headers_on_api_endpoints(self, async_client, auth_headers):
        """Security headers should be present on API endpoints too.

        Checks the same header set as ``test_security_headers_present``,
        applied to a successful API response.
        """
        response = await async_client.get("/api/v1/brands/", headers=auth_headers)
        if response.status_code == 200:
            assert response.headers.get("x-content-type-options") == "nosniff"
            assert response.headers.get("x-frame-options") == "DENY"
            assert response.headers.get("x-xss-protection") == "1; mode=block"
            assert (
                response.headers.get("referrer-policy")
                == "strict-origin-when-cross-origin"
            )
# ═══════════════════════════════════════════════════════════
# Authentication Security Tests
# ═══════════════════════════════════════════════════════════
class TestAuthSecurity:
    """Verify authentication security mechanisms.

    These tests use ``client_no_override`` so the real ``get_current_user``
    dependency runs and tokens go through actual JWT handling.
    """

    @pytest.mark.asyncio
    async def test_expired_token_rejected(self, client_no_override):
        """Expired JWT tokens should be rejected with 401."""
        # Token whose "exp" lies one hour in the past. A timezone-aware "now"
        # replaces datetime.utcnow(), which is deprecated since Python 3.12.
        expired_payload = {
            "sub": str(uuid.uuid4()),
            "exp": datetime.now(timezone.utc) - timedelta(hours=1),
            "type": "access",
        }
        expired_token = jwt.encode(expired_payload, settings.JWT_SECRET, algorithm="HS256")
        response = await client_no_override.get(
            "/api/v1/auth/me",
            headers={"Authorization": f"Bearer {expired_token}"},
        )
        assert response.status_code == 401, (
            f"Expired token should be rejected, got {response.status_code}"
        )

    @pytest.mark.asyncio
    async def test_invalid_token_rejected(self, client_no_override):
        """Invalid JWT tokens should be rejected with 401."""
        invalid_tokens = [
            "invalid.token.here",
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.invalid.payload",
            "null",
            "undefined",
        ]
        for token in invalid_tokens:
            response = await client_no_override.get(
                "/api/v1/auth/me",
                headers={"Authorization": f"Bearer {token}"},
            )
            assert response.status_code == 401, (
                f"Invalid token '{token[:20]}...' should be rejected, got {response.status_code}"
            )

    @pytest.mark.asyncio
    async def test_invalid_refresh_token(self, client_no_override):
        """Invalid refresh tokens should return 401."""
        invalid_tokens = [
            "invalid.refresh.token",
            "eyJhbGciOiJIUzI1NiJ9.invalid.payload",
        ]
        for token in invalid_tokens:
            response = await client_no_override.post(
                "/api/v1/auth/refresh",
                json={"refresh_token": token},
            )
            assert response.status_code == 401, (
                f"Invalid refresh token should be rejected, got {response.status_code}"
            )

    @pytest.mark.asyncio
    async def test_access_token_used_as_refresh_rejected(self, client_no_override, test_user):
        """Access tokens should not be accepted as refresh tokens."""
        access_token = create_access_token(data={"sub": str(test_user.id)})
        response = await client_no_override.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": access_token},
        )
        assert response.status_code == 401, (
            "Access token used as refresh token should be rejected"
        )

    @pytest.mark.asyncio
    async def test_refresh_token_used_as_access_rejected(self, client_no_override, test_user):
        """Refresh tokens should not be accepted as access tokens."""
        refresh_token = create_refresh_token(data={"sub": str(test_user.id)})
        response = await client_no_override.get(
            "/api/v1/auth/me",
            headers={"Authorization": f"Bearer {refresh_token}"},
        )
        assert response.status_code == 401, (
            "Refresh token used as access token should be rejected"
        )

    @pytest.mark.asyncio
    async def test_missing_authorization_header(self, client_no_override):
        """Requests without Authorization header should return 401 (or 403)."""
        response = await client_no_override.get("/api/v1/auth/me")
        assert response.status_code in (401, 403), (
            f"Missing auth should return 401/403, got {response.status_code}"
        )

    @pytest.mark.asyncio
    async def test_cross_user_data_isolation(
        self, client_no_override, async_session, second_user, auth_headers
    ):
        """User A should not be able to access User B's brand data.

        ``auth_headers`` carries a real access token for ``test_user``. The
        ``client_no_override`` fixture keeps JWT auth real and clears the
        dependency overrides in its teardown, even when the assertion fails.
        """
        brand = Brand(
            id=uuid.uuid4(),
            user_id=_to_uuid(second_user.id),
            name="Second User's Brand",
            platforms=["wenxin"],
            status="active",
        )
        async_session.add(brand)
        await async_session.commit()
        await async_session.refresh(brand)
        # test_user tries to read second_user's brand.
        response = await client_no_override.get(
            f"/api/v1/brands/{brand.id}/",
            headers=auth_headers,
        )
        assert response.status_code == 404, (
            f"User should not access another user's brand, got {response.status_code}"
        )

    @pytest.mark.asyncio
    async def test_cross_user_query_isolation(
        self, client_no_override, async_session, second_user, auth_headers
    ):
        """User A should not be able to access User B's query data.

        ``auth_headers`` carries a real access token for ``test_user``. The
        ``client_no_override`` fixture clears dependency overrides in its
        teardown, even when the assertion fails.
        """
        query = Query(
            id=uuid.uuid4(),
            # NOTE(review): Brand rows in this file pass _to_uuid(user.id) for
            # user_id, but this passes the raw string id; confirm this matches
            # the column type of Query.user_id.
            user_id=second_user.id,
            keyword="Second User Query",
            target_brand="Second Brand",
            platforms=["wenxin"],
            status="active",
        )
        async_session.add(query)
        await async_session.commit()
        await async_session.refresh(query)
        # test_user tries to read second_user's query.
        response = await client_no_override.get(
            f"/api/v1/queries/{query.id}",
            headers=auth_headers,
        )
        assert response.status_code == 404, (
            f"User should not access another user's query, got {response.status_code}"
        )

    @pytest.mark.asyncio
    async def test_nonexistent_user_token_rejected(self, client_no_override):
        """Token with nonexistent user_id should be rejected."""
        fake_user_id = str(uuid.uuid4())
        token = create_access_token(data={"sub": fake_user_id})
        response = await client_no_override.get(
            "/api/v1/auth/me",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert response.status_code == 401, (
            f"Token for nonexistent user should be rejected, got {response.status_code}"
        )

    @pytest.mark.asyncio
    async def test_login_error_message_no_user_enumeration(self, client_no_override, test_user):
        """Login error messages should not reveal whether email exists.

        Both a non-existent email and a wrong password for an existing
        account should return the same status and the same error body.
        ``test_user`` is requested so that the "existing account" case really
        exists in the database. Rate limiting (429) is also acceptable.
        """
        # Non-existent email
        response1 = await client_no_override.post(
            "/api/v1/auth/login",
            json={"email": "nonexistent@example.com", "password": "password123"},
        )
        # Existing email (test_user) with wrong password
        response2 = await client_no_override.post(
            "/api/v1/auth/login",
            json={"email": test_user.email, "password": "WrongPassword999!"},
        )
        assert response1.status_code in (401, 429), (
            f"Non-existent email login returned {response1.status_code}, expected 401/429"
        )
        assert response2.status_code in (401, 429), (
            f"Wrong password login returned {response2.status_code}, expected 401/429"
        )
        # When neither request was rate-limited, the statuses already match, so
        # compare the bodies: any difference reveals whether the email exists.
        if response1.status_code == 401 and response2.status_code == 401:
            assert response1.json() == response2.json(), (
                "Login error bodies differ between unknown email and wrong "
                "password, which allows user enumeration"
            )