# geo/backend/tests/test_api/test_brands.py
#
# 210 lines
# 6.7 KiB
# Python

"""Tests for brands API."""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import create_access_token
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
async def async_engine():
    """Provide an in-memory SQLite async engine with every table created.

    ``StaticPool`` hands out one shared connection, so all sessions built on
    this engine see the same in-memory database. The engine is disposed once
    the test using it has finished.
    """
    from sqlalchemy.pool import StaticPool
    from sqlalchemy.ext.asyncio import create_async_engine

    test_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    # Build the schema from the ORM metadata before any test touches it.
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield test_engine

    await test_engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps committed objects' attributes loaded, so
    fixtures can read them after ``commit()`` without another query.
    """
    session_factory = async_sessionmaker(
        async_engine,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
        class_=AsyncSession,
    )
    async with session_factory() as db_session:
        yield db_session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the ``free`` plan.

    The primary key is a freshly generated UUID stored as a string.
    """
    user_fields = {
        "id": str(uuid.uuid4()),
        "email": "test@example.com",
        "password": "hashed_password",
        "firstName": "Test User",
        "plan": "free",
        "max_queries": 5,
        "isActive": True,
        "emailVerified": True,
    }
    user = User(**user_fields)

    async_session.add(user)
    await async_session.commit()
    # Reload server-side state (defaults, timestamps) onto the instance.
    await async_session.refresh(user)
    return user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active weekly brand owned by ``test_user``."""
    brand_id = uuid.uuid4()
    # The user's id is stored as a string; Brand.user_id is given a UUID.
    owner_id = _to_uuid(test_user.id)

    new_brand = Brand(
        id=brand_id,
        user_id=owner_id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )

    async_session.add(new_brand)
    await async_session.commit()
    await async_session.refresh(new_brand)
    return new_brand
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``httpx.AsyncClient`` wired directly to the ASGI app.

    ``get_db`` is overridden to yield this test's ``async_session``, and
    ``get_current_user`` to return ``test_user``. Requests therefore act as
    that user without a real token.

    On teardown, the overrides that existed before this fixture ran are
    restored, rather than every override on the shared ``app`` being wiped.
    The restore runs in ``finally``, so it happens even if closing the client
    raises.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Restore the snapshot instead of a bare clear(), so overrides
        # installed by other fixtures or conftest survive this teardown.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
@pytest.fixture
def auth_headers(test_user):
    """Return request headers carrying a bearer token whose ``sub`` is the user's id."""
    claims = {"sub": str(test_user.id)}
    access_token = create_access_token(data=claims)
    return {"Authorization": f"Bearer {access_token}"}
class TestBrandsAPI:
    """Tests for the ``/api/v1/brands/`` endpoints.

    Requests go through the ``async_client`` fixture. It routes ``get_db`` to
    the per-test in-memory session and makes ``get_current_user`` return
    ``test_user``.
    """

    @pytest.mark.asyncio
    async def test_get_brands_empty(self, async_client):
        """Listing brands when none exist returns an empty page."""
        response = await async_client.get("/api/v1/brands/")
        assert response.status_code == 200
        data = response.json()
        assert data["items"] == []
        assert data["total"] == 0

    @pytest.mark.asyncio
    async def test_create_brand(self, async_client, test_user):
        """Creating a brand echoes the payload and defaults status to ``active``.

        The new brand must also be retrievable by the id the API returned.
        """
        brand_data = {
            "name": "New Brand",
            "aliases": ["NewBrand", "NB"],
            "website": "https://newbrand.com",
            "industry": "technology",
            "platforms": ["wenxin", "kimi"],
            "frequency": "weekly",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "New Brand"
        assert data["aliases"] == ["NewBrand", "NB"]
        assert data["website"] == "https://newbrand.com"
        assert data["industry"] == "technology"
        assert data["platforms"] == ["wenxin", "kimi"]
        assert data["frequency"] == "weekly"
        assert data["status"] == "active"
        assert data["user_id"] == str(test_user.id)

        # Round-trip: a 201 alone does not prove the row was persisted.
        fetched = await async_client.get(f"/api/v1/brands/{data['id']}/")
        assert fetched.status_code == 200
        assert fetched.json()["name"] == "New Brand"

    @pytest.mark.asyncio
    async def test_get_brand_by_id(self, async_client, test_brand):
        """Fetching an existing brand by id returns its stored fields."""
        response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_brand.id)
        assert data["name"] == "Test Brand"
        assert data["aliases"] == ["TestBrand", "TB"]

    @pytest.mark.asyncio
    async def test_get_brand_not_found(self, async_client):
        """Fetching a random, never-created id returns 404."""
        non_existent_id = uuid.uuid4()
        response = await async_client.get(f"/api/v1/brands/{non_existent_id}/")
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_update_brand(self, async_client, test_brand):
        """A partial PUT changes only the supplied fields."""
        update_data = {
            "aliases": ["Updated", "Alias"],
            "frequency": "daily",
        }
        response = await async_client.put(
            f"/api/v1/brands/{test_brand.id}/", json=update_data
        )
        assert response.status_code == 200
        data = response.json()
        assert data["aliases"] == ["Updated", "Alias"]
        assert data["frequency"] == "daily"
        assert data["name"] == "Test Brand"  # Not in the payload, so unchanged

    @pytest.mark.asyncio
    async def test_delete_brand(self, async_client, test_brand):
        """Deleting a brand returns 204, and a later GET returns 404."""
        response = await async_client.delete(f"/api/v1/brands/{test_brand.id}/")
        assert response.status_code == 204

        # Verify the brand is gone.
        response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_list_brands(self, async_client, async_session, test_user):
        """Listing returns every brand owned by the user, with an accurate total."""
        expected_names = {f"Brand {i}" for i in range(3)}
        for i in range(3):
            brand = Brand(
                user_id=_to_uuid(test_user.id),
                name=f"Brand {i}",
                platforms=["wenxin"],
            )
            async_session.add(brand)
        await async_session.commit()

        response = await async_client.get("/api/v1/brands/")
        assert response.status_code == 200
        data = response.json()
        assert len(data["items"]) == 3
        assert data["total"] == 3
        # Compare as a set: the listing order is not specified here.
        assert {item["name"] for item in data["items"]} == expected_names