geo/backend/tests/test_api/test_brands_api.py

323 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""品牌API测试"""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
from tests.fixtures.auth import _to_uuid
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with all ORM tables created.

    ``StaticPool`` hands out one shared connection, so every session built on
    this engine sees the same ``:memory:`` database. The engine is disposed
    during fixture teardown.
    """
    test_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with test_engine.begin() as connection:
        # create_all is a sync API, so run it through the connection's sync bridge.
        await connection.run_sync(Base.metadata.create_all)
    yield test_engine
    await test_engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the per-test engine.

    ``expire_on_commit=False`` keeps attributes of committed ORM instances
    loaded, so objects returned after a commit can still be read.
    """
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the ``free`` plan.

    The password is stored via ``hash_password``; the plain value is
    ``"Test@123456"``.
    """
    user_fields = dict(
        id=str(uuid.uuid4()),
        email="test@example.com",
        password=hash_password("Test@123456"),
        firstName="Test User",
        plan="free",
        max_queries=5,
        isActive=True,
        emailVerified=True,
    )
    new_user = User(**user_fields)
    async_session.add(new_user)
    await async_session.commit()
    # Reload the row so the returned instance reflects what the DB stored.
    await async_session.refresh(new_user)
    return new_user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active brand named ``"Test Brand"`` owned by ``test_user``."""
    owner_id = _to_uuid(test_user.id)
    brand_fields = dict(
        id=uuid.uuid4(),
        user_id=owner_id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    new_brand = Brand(**brand_fields)
    async_session.add(new_brand)
    await async_session.commit()
    # Reload the row so the returned instance reflects what the DB stored.
    await async_session.refresh(new_brand)
    return new_brand
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``AsyncClient`` that calls the application in-process.

    The ``get_db`` and ``get_current_user`` dependencies are overridden, so
    requests share the test session and are treated as coming from
    ``test_user`` without needing a token. All overrides are cleared on
    teardown, even if closing the client fails.
    """
    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # ``app`` is module-global: never let these overrides leak into other tests.
        app.dependency_overrides.clear()
@pytest.fixture
def auth_headers(test_user):
    """Return an ``Authorization`` header carrying a bearer token for ``test_user``.

    The token's ``sub`` claim is the user's id rendered as a string.
    """
    claims = {"sub": str(test_user.id)}
    access_token = create_access_token(data=claims)
    header_value = f"Bearer {access_token}"
    return {"Authorization": header_value}
# ==================== 测试类 ====================
class TestBrandsAPI:
    """CRUD tests for the ``/api/v1/brands/`` endpoints."""

    @pytest.mark.asyncio
    async def test_list_brands_empty(self, async_client):
        """Listing brands returns an empty page when the user has none."""
        response = await async_client.get("/api/v1/brands/")
        assert response.status_code == 200
        data = response.json()
        assert data["items"] == []
        assert data["total"] == 0

    @pytest.mark.asyncio
    async def test_create_brand(self, async_client, async_session, test_user):
        """Creating a brand with every field echoes the submitted values back."""
        brand_data = {
            "name": "华为",
            "aliases": ["Huawei", "HW"],
            "website": "https://www.huawei.com",
            "industry": "technology",
            "platforms": ["wenxin", "kimi", "doubao"],
            "frequency": "daily",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "华为"
        assert data["aliases"] == ["Huawei", "HW"]
        assert data["website"] == "https://www.huawei.com"
        assert data["industry"] == "technology"
        assert data["platforms"] == ["wenxin", "kimi", "doubao"]
        assert data["frequency"] == "daily"
        # "status" was not sent, so this checks the server-side default.
        assert data["status"] == "active"
        assert data["user_id"] == str(test_user.id)
        assert "id" in data

    @pytest.mark.asyncio
    async def test_create_brand_minimal(self, async_client):
        """Creating a brand with only a name fills in the defaults."""
        brand_data = {
            "name": "minimal_brand",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "minimal_brand"
        assert data["aliases"] == []
        assert data["platforms"] == ["wenxin", "kimi"]  # default value

    @pytest.mark.asyncio
    async def test_list_brands(self, async_client, async_session, test_user):
        """Listing returns every brand the user owns, with a matching total."""
        # Seed several brands directly through the session.
        for i in range(3):
            brand = Brand(
                user_id=_to_uuid(test_user.id),
                name=f"Brand {i}",
                platforms=["wenxin"],
            )
            async_session.add(brand)
        await async_session.commit()
        response = await async_client.get("/api/v1/brands/")
        assert response.status_code == 200
        data = response.json()
        assert len(data["items"]) == 3
        assert data["total"] == 3

    @pytest.mark.asyncio
    async def test_get_brand_by_id(self, async_client, test_brand):
        """Fetching an existing brand by id returns its stored fields."""
        response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_brand.id)
        assert data["name"] == "Test Brand"
        assert data["aliases"] == ["TestBrand", "TB"]

    @pytest.mark.asyncio
    async def test_get_brand_not_found(self, async_client):
        """Fetching an unknown brand id returns 404 with a not-found message."""
        non_existent_id = uuid.uuid4()
        response = await async_client.get(f"/api/v1/brands/{non_existent_id}/")
        assert response.status_code == 404
        assert "品牌不存在" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_update_brand(self, async_client, test_brand):
        """Updating aliases and frequency returns the new values."""
        # NOTE: the BrandUpdate schema does not allow updating the name field.
        update_data = {
            "aliases": ["Updated", "Alias"],
            "frequency": "daily",
        }
        response = await async_client.put(
            f"/api/v1/brands/{test_brand.id}/", json=update_data
        )
        assert response.status_code == 200
        data = response.json()
        assert data["aliases"] == ["Updated", "Alias"]
        assert data["frequency"] == "daily"
        assert data["name"] == "Test Brand"  # name is unchanged

    @pytest.mark.asyncio
    async def test_update_brand_partial(self, async_client, test_brand):
        """A partial update changes only the fields that were sent."""
        update_data = {
            "frequency": "monthly",
        }
        response = await async_client.put(
            f"/api/v1/brands/{test_brand.id}/", json=update_data
        )
        assert response.status_code == 200
        data = response.json()
        # Only frequency changes; other fields keep their values.
        assert data["frequency"] == "monthly"
        assert data["name"] == "Test Brand"
        assert data["aliases"] == ["TestBrand", "TB"]

    @pytest.mark.asyncio
    async def test_update_brand_not_found(self, async_client):
        """Updating an unknown brand id returns 404."""
        non_existent_id = uuid.uuid4()
        # Send a field BrandUpdate accepts (see test_update_brand). "name" is
        # not updatable, so it could yield a 422 or an empty update instead of
        # exercising the missing-brand path.
        response = await async_client.put(
            f"/api/v1/brands/{non_existent_id}/", json={"frequency": "daily"}
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_delete_brand(self, async_client, test_brand):
        """Deleting a brand returns 204, and the brand is gone afterwards."""
        response = await async_client.delete(f"/api/v1/brands/{test_brand.id}/")
        assert response.status_code == 204
        # Verify the brand can no longer be fetched.
        get_response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")
        assert get_response.status_code == 404

    @pytest.mark.asyncio
    async def test_delete_brand_not_found(self, async_client):
        """Deleting an unknown brand id returns 404."""
        non_existent_id = uuid.uuid4()
        response = await async_client.delete(f"/api/v1/brands/{non_existent_id}/")
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_brand_unauthorized_access(self, async_session):
        """An invalid bearer token is rejected with 401.

        Only ``get_db`` is overridden here. ``get_current_user`` is left in
        place so the token is really validated.
        """
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                # Use a deliberately invalid token.
                headers = {"Authorization": "Bearer invalid_token"}
                # Try to list brands.
                response = await client.get(
                    "/api/v1/brands/",
                    headers=headers
                )
                # An invalid token must yield 401.
                assert response.status_code == 401
        finally:
            # Run even when the assertion fails, so the override cannot leak
            # into later tests that share the global ``app``.
            app.dependency_overrides.clear()
class TestBrandsValidation:
    """Request-validation tests for brand creation."""

    @pytest.mark.asyncio
    async def test_create_brand_name_too_short(self, async_client):
        """A one-character brand name is rejected with a validation error (422)."""
        # The minimum name length is 2, so a single character must fail.
        response = await async_client.post("/api/v1/brands/", json={"name": "A"})
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_create_brand_invalid_frequency(self, async_client):
        """An unrecognised frequency value must not crash the endpoint.

        BrandCreate does not appear to constrain ``frequency`` strictly, so
        either acceptance (201) or a validation error (422) is tolerated. Any
        other status, such as a 500, fails the test.
        """
        payload = {"name": "Valid Brand", "frequency": "invalid_frequency"}
        response = await async_client.post("/api/v1/brands/", json=payload)
        assert response.status_code in (201, 422)