geo/backend/tests/test_infrastructure/test_performance.py

381 lines
14 KiB
Python

"""Performance tests: concurrent access, response time, and rate limiting."""
import asyncio
import time
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.query import Query
from app.models.brand import Brand
from app.models.competitor import Competitor
from app.models.suggestion import Suggestion
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token, hash_password
# Only the tables needed for performance tests (avoids JSONB/SQLite incompatibility)
_TEST_TABLES = (
User.__table__,
Query.__table__,
Brand.__table__,
Competitor.__table__,
Suggestion.__table__,
)
# ─────────────────────── Fixtures ───────────────────────
@pytest_asyncio.fixture
async def async_engine():
"""Create async engine for testing with SQLite.
Only creates the specific tables needed by performance tests,
avoiding PostgreSQL-only types (JSONB) that fail on SQLite.
"""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async with engine.begin() as conn:
await conn.run_sync(
lambda sync_conn: Base.metadata.create_all(
sync_conn, tables=[t for t in _TEST_TABLES]
)
)
yield engine
await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
"""Create async session for testing."""
async_session_maker = async_sessionmaker(
async_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async with async_session_maker() as session:
yield session
@pytest_asyncio.fixture
async def test_user(async_session):
"""Create a test user with properly hashed password."""
user = User(
id=uuid.uuid4(),
email="perf_test@example.com",
password_hash=hash_password("PerfTest123!"),
name="Performance Test User",
plan="free",
max_queries=50,
is_active=True,
email_verified=True,
)
async_session.add(user)
await async_session.commit()
await async_session.refresh(user)
return user
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
"""Create async client for API testing with dependency overrides."""
session = async_session
async def override_get_db():
yield session
async def override_get_current_user():
return test_user
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_get_current_user
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def client_no_override(async_session):
"""Create async client WITHOUT overriding get_current_user (for real auth flow)."""
session = async_session
async def override_get_db():
yield session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
@pytest.fixture
def auth_headers(test_user):
"""Create authentication headers."""
token = create_access_token(data={"sub": str(test_user.id)})
return {"Authorization": f"Bearer {token}"}
# ═══════════════════════════════════════════════════════════
# API Response Time Tests
# ═══════════════════════════════════════════════════════════
class TestAPIPerformance:
"""Test API response time and concurrency behavior."""
@pytest.mark.asyncio
async def test_health_check_fast(self, async_client):
"""Health check endpoint should respond quickly (< 100ms)."""
start = time.time()
response = await async_client.get("/health")
elapsed = time.time() - start
assert response.status_code == 200
assert elapsed < 0.1, f"Health check took {elapsed:.3f}s, expected < 0.1s"
@pytest.mark.asyncio
async def test_brand_list_performance(self, async_client, async_session, test_user, auth_headers):
"""Brand list API should respond within 500ms."""
# Create several brands for a more realistic test
for i in range(10):
brand = Brand(
user_id=test_user.id,
name=f"Brand {i}",
platforms=["wenxin"],
status="active",
)
async_session.add(brand)
await async_session.commit()
start = time.time()
response = await async_client.get("/api/v1/brands/", headers=auth_headers)
elapsed = time.time() - start
assert response.status_code == 200
assert elapsed < 0.5, f"Brand list took {elapsed:.3f}s, expected < 0.5s"
@pytest.mark.asyncio
async def test_query_list_performance(self, async_client, async_session, test_user, auth_headers):
"""Query list API should respond within 500ms."""
# Create several queries for a more realistic test
for i in range(10):
query = Query(
user_id=test_user.id,
keyword=f"query keyword {i}",
target_brand=f"Brand {i}",
platforms=["wenxin"],
status="active",
)
async_session.add(query)
await async_session.commit()
start = time.time()
response = await async_client.get("/api/v1/queries/", headers=auth_headers)
elapsed = time.time() - start
assert response.status_code == 200
assert elapsed < 0.5, f"Query list took {elapsed:.3f}s, expected < 0.5s"
@pytest.mark.asyncio
async def test_me_endpoint_performance(self, async_client, auth_headers):
"""Current user endpoint should respond within 200ms."""
start = time.time()
response = await async_client.get("/api/v1/auth/me", headers=auth_headers)
elapsed = time.time() - start
assert response.status_code == 200
assert elapsed < 0.2, f"/auth/me took {elapsed:.3f}s, expected < 0.2s"
# ═══════════════════════════════════════════════════════════
# Concurrency Tests
# ═══════════════════════════════════════════════════════════
class TestConcurrency:
"""Test concurrent access behavior."""
@pytest.mark.asyncio
async def test_concurrent_login_no_crash(self, client_no_override, test_user):
"""50 concurrent login requests should not cause system crash.
Note: Rate limiting will kick in after 5 attempts from same IP,
so most requests will get 429. The key point is no 500 errors.
"""
tasks = []
for _ in range(50):
task = client_no_override.post(
"/api/v1/auth/login",
json={
"email": "perf_test@example.com",
"password": "PerfTest123!",
},
)
tasks.append(task)
responses = await asyncio.gather(*tasks, return_exceptions=True)
for i, resp in enumerate(responses):
if isinstance(resp, Exception):
# Network/transport errors are acceptable under heavy load
continue
# Should NOT get 500 — rate limiting (429) or auth errors (401/422) are fine
assert resp.status_code in (200, 401, 422, 429), (
f"Concurrent login request {i} returned {resp.status_code}, "
f"expected 200/401/422/429"
)
@pytest.mark.asyncio
async def test_concurrent_brand_reads(self, async_client, async_session, test_user, auth_headers):
"""Concurrent brand list reads should all succeed."""
# Pre-create data
for i in range(5):
brand = Brand(
user_id=test_user.id,
name=f"Concurrent Brand {i}",
platforms=["wenxin"],
status="active",
)
async_session.add(brand)
await async_session.commit()
# 20 concurrent read requests
tasks = [
async_client.get("/api/v1/brands/", headers=auth_headers)
for _ in range(20)
]
responses = await asyncio.gather(*tasks, return_exceptions=True)
success_count = 0
for resp in responses:
if isinstance(resp, Exception):
continue
if resp.status_code == 200:
success_count += 1
# At least some should succeed
assert success_count > 0, "No concurrent brand reads succeeded"
@pytest.mark.asyncio
async def test_concurrent_query_reads(self, async_client, async_session, test_user, auth_headers):
"""Concurrent query list reads should all succeed."""
for i in range(5):
query = Query(
user_id=test_user.id,
keyword=f"concurrent query {i}",
target_brand=f"Brand {i}",
platforms=["wenxin"],
status="active",
)
async_session.add(query)
await async_session.commit()
tasks = [
async_client.get("/api/v1/queries/", headers=auth_headers)
for _ in range(20)
]
responses = await asyncio.gather(*tasks, return_exceptions=True)
success_count = 0
for resp in responses:
if isinstance(resp, Exception):
continue
if resp.status_code == 200:
success_count += 1
assert success_count > 0, "No concurrent query reads succeeded"
# ═══════════════════════════════════════════════════════════
# Rate Limiting Tests
# ═══════════════════════════════════════════════════════════
class TestRateLimiting:
"""Test rate limiting enforcement."""
@pytest.mark.asyncio
async def test_login_rate_limit_enforcement(self, client_no_override):
"""Login endpoint should enforce rate limiting (5 req/min/IP).
After 5 rapid login attempts, subsequent requests should get 429.
"""
# Note: RateLimitMiddleware state is shared across the app instance.
# Since other tests may have already sent requests, we need to
# send enough to trigger the limit. The auth_strict rule allows
# 5 requests per 60 seconds per IP.
responses = []
for _ in range(8):
response = await client_no_override.post(
"/api/v1/auth/login",
json={
"email": "ratelimit@example.com",
"password": "somepassword123",
},
)
responses.append(response)
# At least one of the later requests should be rate-limited (429)
status_codes = [r.status_code for r in responses]
# After 5 requests, additional ones should get 429
rate_limited = [code for code in status_codes if code == 429]
assert len(rate_limited) > 0, (
f"No rate limiting detected. Status codes: {status_codes}. "
f"Expected at least one 429 after 8 rapid login attempts."
)
@pytest.mark.asyncio
async def test_global_rate_limit_high_threshold(self, async_client, auth_headers):
"""Global rate limit (100 req/min) should allow normal usage patterns.
Sending 10 rapid requests should all succeed (well within limit).
"""
responses = []
for _ in range(10):
response = await async_client.get("/health")
responses.append(response)
# All should succeed — well under global limit
success_count = sum(1 for r in responses if r.status_code == 200)
assert success_count == 10, (
f"Expected all 10 health checks to succeed, got {success_count}/10"
)
@pytest.mark.asyncio
async def test_rate_limit_429_response_format(self, client_no_override):
"""Rate-limited responses should have proper 429 format."""
# Exhaust login rate limit
for _ in range(8):
await client_no_override.post(
"/api/v1/auth/login",
json={"email": "rl@example.com", "password": "password123"},
)
# This one should be rate-limited
response = await client_no_override.post(
"/api/v1/auth/login",
json={"email": "rl@example.com", "password": "password123"},
)
if response.status_code == 429:
data = response.json()
# Should have a detail message
assert "detail" in data or "message" in data, (
"429 response should include a detail or message field"
)