"""品牌API测试""" import uuid import pytest import pytest_asyncio from httpx import AsyncClient, ASGITransport from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine from sqlalchemy.pool import StaticPool from app.database import Base from app.main import app from app.models.user import User from app.models.brand import Brand from app.api.deps import get_current_user, get_db from app.services.auth import hash_password, create_access_token from tests.fixtures.auth import _to_uuid # ==================== Fixtures ==================== @pytest_asyncio.fixture async def async_engine(): """创建测试用SQLite异步引擎""" engine = create_async_engine( "sqlite+aiosqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool, ) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield engine await engine.dispose() @pytest_asyncio.fixture async def async_session(async_engine): """创建测试用异步数据库会话""" async_session_maker = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False, ) async with async_session_maker() as session: yield session @pytest_asyncio.fixture async def test_user(async_session): """创建测试用户""" user = User( id=str(uuid.uuid4()), email="test@example.com", password=hash_password("Test@123456"), firstName="Test User", plan="free", max_queries=5, isActive=True, emailVerified=True, ) async_session.add(user) await async_session.commit() await async_session.refresh(user) return user @pytest_asyncio.fixture async def test_brand(async_session, test_user): """创建测试品牌""" brand = Brand( id=uuid.uuid4(), user_id=_to_uuid(test_user.id), name="Test Brand", aliases=["TestBrand", "TB"], website="https://testbrand.com", industry="technology", platforms=["wenxin", "kimi"], frequency="weekly", status="active", ) async_session.add(brand) await async_session.commit() await async_session.refresh(brand) return brand @pytest_asyncio.fixture async def async_client(async_session, test_user): """创建异步HTTP客户端用于API测试""" async def override_get_db(): yield async_session async def override_get_current_user(): return test_user app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_current_user] = override_get_current_user transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: yield client app.dependency_overrides.clear() @pytest.fixture def auth_headers(test_user): """创建认证请求头""" token = create_access_token(data={"sub": str(test_user.id)}) return {"Authorization": f"Bearer {token}"} # ==================== 测试类 ==================== class TestBrandsAPI: """品牌API测试""" @pytest.mark.asyncio async def test_list_brands_empty(self, async_client): """测试获取空品牌列表""" response = await async_client.get("/api/v1/brands/") assert response.status_code == 200 data = response.json() assert data["items"] == [] assert data["total"] == 0 @pytest.mark.asyncio async def test_create_brand(self, async_client, async_session, test_user): """测试创建品牌""" brand_data = { "name": "华为", "aliases": ["Huawei", "HW"], "website": "https://www.huawei.com", "industry": "technology", "platforms": ["wenxin", "kimi", "doubao"], "frequency": "daily", } response = await async_client.post("/api/v1/brands/", json=brand_data) assert response.status_code == 201 data = response.json() assert data["name"] == "华为" assert data["aliases"] == ["Huawei", "HW"] assert data["website"] == "https://www.huawei.com" assert data["industry"] == "technology" assert data["platforms"] == ["wenxin", "kimi", "doubao"] assert data["frequency"] == "daily" assert data["status"] == "active" assert data["user_id"] == str(test_user.id) assert "id" in data @pytest.mark.asyncio async def test_create_brand_minimal(self, async_client): """测试创建品牌(最小数据)""" brand_data = { "name": "minimal_brand", } response = await async_client.post("/api/v1/brands/", json=brand_data) assert response.status_code == 201 data = response.json() assert data["name"] == "minimal_brand" assert data["aliases"] == [] assert data["platforms"] == ["wenxin", "kimi"] # 默认值 @pytest.mark.asyncio async def test_list_brands(self, async_client, async_session, test_user): """测试列出多个品牌""" # 创建多个品牌 for i in range(3): brand = Brand( user_id=_to_uuid(test_user.id), name=f"Brand {i}", platforms=["wenxin"], ) async_session.add(brand) await async_session.commit() response = await async_client.get("/api/v1/brands/") assert response.status_code == 200 data = response.json() assert len(data["items"]) == 3 assert data["total"] == 3 @pytest.mark.asyncio async def test_get_brand_by_id(self, async_client, test_brand): """测试通过ID获取品牌""" response = await async_client.get(f"/api/v1/brands/{test_brand.id}/") assert response.status_code == 200 data = response.json() assert data["id"] == str(test_brand.id) assert data["name"] == "Test Brand" assert data["aliases"] == ["TestBrand", "TB"] @pytest.mark.asyncio async def test_get_brand_not_found(self, async_client): """测试获取不存在的品牌""" non_existent_id = uuid.uuid4() response = await async_client.get(f"/api/v1/brands/{non_existent_id}/") assert response.status_code == 404 assert "品牌不存在" in response.json()["detail"] @pytest.mark.asyncio async def test_update_brand(self, async_client, test_brand): """测试更新品牌""" # 注意:BrandUpdate schema 不允许更新 name 字段 update_data = { "aliases": ["Updated", "Alias"], "frequency": "daily", } response = await async_client.put( f"/api/v1/brands/{test_brand.id}/", json=update_data ) assert response.status_code == 200 data = response.json() assert data["aliases"] == ["Updated", "Alias"] assert data["frequency"] == "daily" assert data["name"] == "Test Brand" # name 保持不变 @pytest.mark.asyncio async def test_update_brand_partial(self, async_client, test_brand): """测试部分更新品牌""" update_data = { "frequency": "monthly", } response = await async_client.put( f"/api/v1/brands/{test_brand.id}/", json=update_data ) assert response.status_code == 200 data = response.json() # 只更新frequency,其他字段保持不变 assert data["frequency"] == "monthly" assert data["name"] == "Test Brand" assert data["aliases"] == ["TestBrand", "TB"] @pytest.mark.asyncio async def test_update_brand_not_found(self, async_client): """测试更新不存在的品牌""" non_existent_id = uuid.uuid4() response = await async_client.put( f"/api/v1/brands/{non_existent_id}/", json={"name": "New Name"} ) assert response.status_code == 404 @pytest.mark.asyncio async def test_delete_brand(self, async_client, test_brand): """测试删除品牌""" response = await async_client.delete(f"/api/v1/brands/{test_brand.id}/") assert response.status_code == 204 # 验证品牌已删除 get_response = await async_client.get(f"/api/v1/brands/{test_brand.id}/") assert get_response.status_code == 404 @pytest.mark.asyncio async def test_delete_brand_not_found(self, async_client): """测试删除不存在的品牌""" non_existent_id = uuid.uuid4() response = await async_client.delete(f"/api/v1/brands/{non_existent_id}/") assert response.status_code == 404 @pytest.mark.asyncio async def test_brand_unauthorized_access(self, async_session): """测试未授权访问品牌API(无效token)""" # 使用无效的token访问品牌API async def override_get_db(): yield async_session app.dependency_overrides[get_db] = override_get_db transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: # 使用无效的token headers = {"Authorization": "Bearer invalid_token"} # 尝试访问品牌列表 response = await client.get( "/api/v1/brands/", headers=headers ) # 无效token应该返回401 assert response.status_code == 401 app.dependency_overrides.clear() class TestBrandsValidation: """品牌API验证测试""" @pytest.mark.asyncio async def test_create_brand_name_too_short(self, async_client): """测试品牌名称过短""" brand_data = { "name": "A", # 最小长度为2 } response = await async_client.post("/api/v1/brands/", json=brand_data) assert response.status_code == 422 # 验证错误 @pytest.mark.asyncio async def test_create_brand_invalid_frequency(self, async_client): """测试无效的更新频率""" brand_data = { "name": "Valid Brand", "frequency": "invalid_frequency", } response = await async_client.post("/api/v1/brands/", json=brand_data) # frequency字段在BrandCreate中没有严格验证,但应该是有效值 # 这里主要测试API不会崩溃 assert response.status_code in [201, 422]