286 lines
9.1 KiB
Python
286 lines
9.1 KiB
Python
"""Tests for Competitor model."""
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
|
|
from app.models.brand import Brand
|
|
from app.models.competitor import Competitor
|
|
from tests.fixtures.auth import _to_uuid
|
|
|
|
|
|
class TestCompetitorModel:
    """Persistence tests for the Competitor model.

    Each test first stores a parent ``Brand`` owned by ``test_user`` and then
    exercises ``Competitor`` rows attached to that brand through the
    ``async_session`` fixture.
    """

    @staticmethod
    async def _save(async_session, *instances):
        """Add ``instances`` to the session, commit, and refresh each one.

        Refreshing after the commit reloads the rows, so callers can read
        generated columns (``id``, ``created_at``) without relying on how the
        session fixture treats attributes once a transaction is committed.
        """
        for instance in instances:
            async_session.add(instance)
        await async_session.commit()
        for instance in instances:
            await async_session.refresh(instance)

    @staticmethod
    async def _create_brand(async_session, test_user, name, platforms=("wenxin",), **extra):
        """Persist and return a ``Brand`` owned by ``test_user``.

        Args:
            async_session: Session fixture used to store the brand.
            test_user: User fixture; its ``id`` is converted with ``_to_uuid``.
            name: Brand name.
            platforms: Platform identifiers, stored on the brand as a list.
            **extra: Additional ``Brand`` constructor fields (e.g. ``id``).

        Returns:
            The committed and refreshed ``Brand`` instance.
        """
        brand = Brand(
            user_id=_to_uuid(test_user.id),
            name=name,
            platforms=list(platforms),
            **extra,
        )
        await TestCompetitorModel._save(async_session, brand)
        return brand

    @pytest.mark.asyncio
    async def test_competitor_create(self, async_session, test_user):
        """Test creating a new competitor."""
        brand = await self._create_brand(
            async_session,
            test_user,
            "Test Brand for Competitor",
            id=uuid.uuid4(),
        )

        competitor = Competitor(
            id=uuid.uuid4(),
            brand_id=brand.id,
            name="Test Competitor",
            aliases=["TC", "TestComp"],
        )
        await self._save(async_session, competitor)

        assert competitor.id is not None
        assert competitor.brand_id == brand.id
        assert competitor.name == "Test Competitor"
        assert competitor.aliases == ["TC", "TestComp"]
        assert competitor.created_at is not None

    @pytest.mark.asyncio
    async def test_competitor_default_values(self, async_session, test_user):
        """Test competitor default values."""
        brand = await self._create_brand(
            async_session, test_user, "Brand for Default Competitor"
        )

        # Only the required fields are supplied; aliases is left to its default.
        competitor = Competitor(
            brand_id=brand.id,
            name="Minimal Competitor",
        )
        await self._save(async_session, competitor)

        assert competitor.aliases == []

    @pytest.mark.asyncio
    async def test_competitor_fields(self, async_session, test_user):
        """Test that competitor fields round-trip through the database."""
        brand = await self._create_brand(
            async_session,
            test_user,
            "Brand for Field Test",
            platforms=["wenxin", "kimi"],
        )

        competitor = Competitor(
            brand_id=brand.id,
            name="Field Test Competitor",
            aliases=["FTC", "FieldTestComp", "Competitor3"],
        )
        await self._save(async_session, competitor)

        assert competitor.name == "Field Test Competitor"
        assert len(competitor.name) <= 50
        assert competitor.aliases == ["FTC", "FieldTestComp", "Competitor3"]

    @pytest.mark.asyncio
    async def test_competitor_query_by_id(self, async_session, test_user):
        """Test querying competitor by ID."""
        brand = await self._create_brand(
            async_session, test_user, "Brand for Query Test"
        )

        competitor_id = uuid.uuid4()
        competitor = Competitor(
            id=competitor_id,
            brand_id=brand.id,
            name="Query Test Competitor",
        )
        await self._save(async_session, competitor)

        result = await async_session.execute(
            select(Competitor).where(Competitor.id == competitor_id)
        )
        # scalar_one() raises unless exactly one row matches, so no separate
        # "is not None" assertion is needed.
        fetched_competitor = result.scalar_one()

        assert fetched_competitor.id == competitor_id
        assert fetched_competitor.name == "Query Test Competitor"

    @pytest.mark.asyncio
    async def test_competitor_query_by_brand_id(self, async_session, test_user):
        """Test querying competitors by brand ID."""
        brand = await self._create_brand(
            async_session, test_user, "Brand for Multi Competitor Test"
        )
        brand_id = brand.id

        # Two competitors attached to the same brand.
        competitor1 = Competitor(
            brand_id=brand_id,
            name="Competitor 1",
            aliases=["C1"],
        )
        competitor2 = Competitor(
            brand_id=brand_id,
            name="Competitor 2",
            aliases=["C2"],
        )
        await self._save(async_session, competitor1, competitor2)

        result = await async_session.execute(
            select(Competitor).where(Competitor.brand_id == brand_id)
        )
        competitors = result.scalars().all()

        assert len(competitors) == 2
        # Check identity as well as count so unrelated rows cannot satisfy the test.
        assert {c.name for c in competitors} == {"Competitor 1", "Competitor 2"}

    @pytest.mark.asyncio
    async def test_competitor_timestamp(self, async_session, test_user):
        """Test competitor created_at timestamp."""
        brand = await self._create_brand(
            async_session, test_user, "Brand for Timestamp Test"
        )

        competitor = Competitor(
            brand_id=brand.id,
            name="Timestamp Test Competitor",
        )
        await self._save(async_session, competitor)

        assert competitor.created_at is not None
        assert isinstance(competitor.created_at, datetime)

    @pytest.mark.asyncio
    async def test_competitor_update(self, async_session, test_user):
        """Test updating competitor fields."""
        brand = await self._create_brand(
            async_session, test_user, "Brand for Update Test"
        )

        competitor = Competitor(
            brand_id=brand.id,
            name="Update Test Competitor",
            aliases=["UTC"],
        )
        await self._save(async_session, competitor)

        # Mutate the persisted row, then reload it to observe the stored values.
        competitor.name = "Updated Competitor Name"
        competitor.aliases = ["Updated", "NewAlias"]
        await async_session.commit()
        await async_session.refresh(competitor)

        assert competitor.name == "Updated Competitor Name"
        assert competitor.aliases == ["Updated", "NewAlias"]

    @pytest.mark.asyncio
    async def test_competitor_delete(self, async_session, test_user):
        """Test deleting a competitor."""
        brand = await self._create_brand(
            async_session, test_user, "Brand for Delete Test"
        )

        competitor = Competitor(
            brand_id=brand.id,
            name="Delete Test Competitor",
        )
        await self._save(async_session, competitor)
        # Read the key right after the refresh, while the instance is loaded;
        # it is needed for the lookup once the row has been deleted.
        competitor_id = competitor.id

        await async_session.delete(competitor)
        await async_session.commit()

        result = await async_session.execute(
            select(Competitor).where(Competitor.id == competitor_id)
        )
        deleted_competitor = result.scalar_one_or_none()

        assert deleted_competitor is None

    @pytest.mark.asyncio
    async def test_competitor_cascade_delete_with_brand(self, async_session, test_user):
        """Test that competitors are deleted when brand is deleted."""
        brand = await self._create_brand(
            async_session, test_user, "Brand for Cascade Test"
        )
        # Capture the key while the brand is freshly refreshed; the commit
        # below may expire its attributes depending on session configuration.
        brand_id = brand.id

        competitor1 = Competitor(
            brand_id=brand_id,
            name="Cascade Competitor 1",
        )
        competitor2 = Competitor(
            brand_id=brand_id,
            name="Cascade Competitor 2",
        )
        await self._save(async_session, competitor1, competitor2)

        # Delete brand (should cascade delete competitors)
        await async_session.delete(brand)
        await async_session.commit()

        # Verify competitors are also deleted
        result = await async_session.execute(
            select(Competitor).where(Competitor.brand_id == brand_id)
        )
        competitors = result.scalars().all()

        assert len(competitors) == 0
|