322 lines
10 KiB
Python
322 lines
10 KiB
Python
"""品牌API测试"""
|
||
import uuid
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
from httpx import AsyncClient, ASGITransport
|
||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from app.database import Base
|
||
from app.main import app
|
||
from app.models.user import User
|
||
from app.models.brand import Brand
|
||
from app.api.deps import get_current_user, get_db
|
||
from app.services.auth import hash_password, create_access_token
|
||
|
||
|
||
# ==================== Fixtures ====================
|
||
|
||
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLAlchemy engine backed by an in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the engine
    is handed to the test, and the engine is disposed during teardown.
    """
    # StaticPool reuses one connection, so all sessions share the same
    # in-memory database instead of each connection getting an empty one.
    test_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield test_engine

    await test_engine.dispose()
|
||
|
||
|
||
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps attributes of committed ORM objects
    loaded instead of expiring them. ``autoflush`` is disabled.
    """
    make_session = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        autoflush=False,
        autocommit=False,
        expire_on_commit=False,
    )

    async with make_session() as session:
        yield session
|
||
|
||
|
||
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the "free" plan."""
    user_fields = dict(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    user = User(**user_fields)

    async_session.add(user)
    await async_session.commit()
    # Reload server-side state so the returned object reflects the DB row.
    await async_session.refresh(user)
    return user
|
||
|
||
|
||
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active brand owned by ``test_user``."""
    brand_fields = dict(
        id=uuid.uuid4(),
        user_id=test_user.id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    brand = Brand(**brand_fields)

    async_session.add(brand)
    await async_session.commit()
    # Reload server-side state so the returned object reflects the DB row.
    await async_session.refresh(brand)
    return brand
|
||
|
||
|
||
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``httpx.AsyncClient`` that talks to ``app`` in-process.

    ``get_db`` is overridden to yield the test session, and
    ``get_current_user`` is overridden to return ``test_user``. Requests made
    through this client are therefore treated as coming from ``test_user``
    without sending any token.

    On teardown, the dependency overrides that existed before this fixture
    ran are restored. Other overrides are not wiped. The restore runs in
    ``finally`` so it still happens if closing the client raises.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    # Snapshot the global override table so teardown puts it back exactly.
    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
|
||
|
||
|
||
@pytest.fixture
def auth_headers(test_user):
    """Return request headers carrying a bearer access token for ``test_user``."""
    token_payload = {"sub": str(test_user.id)}
    access_token = create_access_token(data=token_payload)
    return {"Authorization": f"Bearer {access_token}"}
|
||
|
||
|
||
# ==================== 测试类 ====================
|
||
|
||
class TestBrandsAPI:
    """CRUD and authorization tests for the ``/api/v1/brands/`` endpoints."""

    @pytest.mark.asyncio
    async def test_list_brands_empty(self, async_client):
        """Listing brands when the user owns none returns an empty result."""
        response = await async_client.get("/api/v1/brands/")

        assert response.status_code == 200
        data = response.json()
        assert data["items"] == []
        assert data["total"] == 0

    @pytest.mark.asyncio
    async def test_create_brand(self, async_client, test_user):
        """Creating a brand with every field set echoes those fields back."""
        brand_data = {
            "name": "华为",
            "aliases": ["Huawei", "HW"],
            "website": "https://www.huawei.com",
            "industry": "technology",
            "platforms": ["wenxin", "kimi", "doubao"],
            "frequency": "daily",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)

        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "华为"
        assert data["aliases"] == ["Huawei", "HW"]
        assert data["website"] == "https://www.huawei.com"
        assert data["industry"] == "technology"
        assert data["platforms"] == ["wenxin", "kimi", "doubao"]
        assert data["frequency"] == "daily"
        assert data["status"] == "active"
        assert data["user_id"] == str(test_user.id)
        assert "id" in data

    @pytest.mark.asyncio
    async def test_create_brand_minimal(self, async_client):
        """Creating a brand with only a name fills in defaults."""
        brand_data = {
            "name": "minimal_brand",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)

        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "minimal_brand"
        assert data["aliases"] == []
        assert data["platforms"] == ["wenxin", "kimi"]  # default value

    @pytest.mark.asyncio
    async def test_list_brands(self, async_client, async_session, test_user):
        """Listing returns every brand the user owns."""
        # Insert brands directly through the shared test session.
        async_session.add_all(
            [
                Brand(user_id=test_user.id, name=f"Brand {i}", platforms=["wenxin"])
                for i in range(3)
            ]
        )
        await async_session.commit()

        response = await async_client.get("/api/v1/brands/")

        assert response.status_code == 200
        data = response.json()
        assert len(data["items"]) == 3
        assert data["total"] == 3

    @pytest.mark.asyncio
    async def test_get_brand_by_id(self, async_client, test_brand):
        """Fetching an existing brand by id returns its data."""
        response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_brand.id)
        assert data["name"] == "Test Brand"
        assert data["aliases"] == ["TestBrand", "TB"]

    @pytest.mark.asyncio
    async def test_get_brand_not_found(self, async_client):
        """Fetching an unknown brand id returns 404 with a not-found message."""
        non_existent_id = uuid.uuid4()
        response = await async_client.get(f"/api/v1/brands/{non_existent_id}/")

        assert response.status_code == 404
        assert "品牌不存在" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_update_brand(self, async_client, test_brand):
        """Updating aliases and frequency changes them and leaves name intact."""
        # Note: the BrandUpdate schema does not allow updating the name field.
        update_data = {
            "aliases": ["Updated", "Alias"],
            "frequency": "daily",
        }
        response = await async_client.put(
            f"/api/v1/brands/{test_brand.id}/", json=update_data
        )

        assert response.status_code == 200
        data = response.json()
        assert data["aliases"] == ["Updated", "Alias"]
        assert data["frequency"] == "daily"
        assert data["name"] == "Test Brand"  # name is unchanged

    @pytest.mark.asyncio
    async def test_update_brand_partial(self, async_client, test_brand):
        """A partial update changes only the supplied field."""
        update_data = {
            "frequency": "monthly",
        }
        response = await async_client.put(
            f"/api/v1/brands/{test_brand.id}/", json=update_data
        )

        assert response.status_code == 200
        data = response.json()
        # Only frequency is updated; the other fields keep their values.
        assert data["frequency"] == "monthly"
        assert data["name"] == "Test Brand"
        assert data["aliases"] == ["TestBrand", "TB"]

    @pytest.mark.asyncio
    async def test_update_brand_not_found(self, async_client):
        """Updating an unknown brand id returns 404."""
        non_existent_id = uuid.uuid4()
        # Send a payload BrandUpdate accepts (it does not allow ``name``, per
        # the note in test_update_brand). Otherwise the request may be rejected
        # or reduced to an empty update before the not-found path is reached.
        response = await async_client.put(
            f"/api/v1/brands/{non_existent_id}/", json={"frequency": "daily"}
        )

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_delete_brand(self, async_client, test_brand):
        """Deleting a brand returns 204, and the brand can no longer be fetched."""
        response = await async_client.delete(f"/api/v1/brands/{test_brand.id}/")

        assert response.status_code == 204

        # Confirm the brand is gone.
        get_response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")
        assert get_response.status_code == 404

    @pytest.mark.asyncio
    async def test_delete_brand_not_found(self, async_client):
        """Deleting an unknown brand id returns 404."""
        non_existent_id = uuid.uuid4()
        response = await async_client.delete(f"/api/v1/brands/{non_existent_id}/")

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_brand_unauthorized_access(self, async_session):
        """Requests carrying an invalid bearer token are rejected with 401.

        Only ``get_db`` is overridden. ``get_current_user`` is deliberately
        left un-overridden, so the application's own auth dependency handles
        the request.
        """

        async def override_get_db():
            yield async_session

        saved_overrides = dict(app.dependency_overrides)
        app.dependency_overrides[get_db] = override_get_db
        # Make sure no stray user override bypasses authentication.
        app.dependency_overrides.pop(get_current_user, None)
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}
                response = await client.get(
                    "/api/v1/brands/",
                    headers=headers,
                )
            # An invalid token should produce 401.
            assert response.status_code == 401
        finally:
            # Restore even if the assertion fails, so the get_db override
            # (bound to this test's session) cannot leak into later tests.
            app.dependency_overrides.clear()
            app.dependency_overrides.update(saved_overrides)
|
||
|
||
|
||
class TestBrandsValidation:
    """Request-validation tests for brand creation."""

    @pytest.mark.asyncio
    async def test_create_brand_name_too_short(self, async_client):
        """A one-character brand name is rejected with a validation error."""
        # The minimum name length is 2.
        payload = {"name": "A"}

        response = await async_client.post("/api/v1/brands/", json=payload)

        assert response.status_code == 422  # validation error

    @pytest.mark.asyncio
    async def test_create_brand_invalid_frequency(self, async_client):
        """An unknown frequency value must not crash the endpoint."""
        payload = {"name": "Valid Brand", "frequency": "invalid_frequency"}

        response = await async_client.post("/api/v1/brands/", json=payload)

        # BrandCreate does not strictly validate frequency, though it should be
        # a valid value. This test mainly checks that the API does not crash,
        # so either acceptance or a validation error is tolerated.
        assert response.status_code in (201, 422)
|