124 lines
4.5 KiB
Python
124 lines
4.5 KiB
Python
import uuid
|
|
from datetime import date, datetime
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
|
|
from app.models.subscription import Subscription
|
|
from app.models.user import User
|
|
|
|
|
|
class TestSubscriptionModel:
    """Tests for the ``Subscription`` model.

    The synchronous tests inspect the mapped table metadata (table name,
    column names, column types, relationships, nullability and defaults).
    The asynchronous tests persist rows through the ``async_session``
    fixture and use the ``test_user`` fixture as the owning user; neither
    fixture is defined in this module (presumably they come from a
    ``conftest.py`` -- verify).
    """

    def test_subscription_table_name(self):
        """The model maps to the ``subscriptions`` table."""
        table_name = Subscription.__tablename__
        assert table_name == "subscriptions"

    def test_subscription_has_required_fields(self):
        """Every expected column name is present on the mapped table."""
        column_names = Subscription.__table__.columns.keys()
        expected_names = (
            "id",
            "user_id",
            "plan",
            "status",
            "start_date",
            "end_date",
            "amount",
            "payment_method",
            "payment_id",
            "created_at",
        )
        for name in expected_names:
            assert name in column_names

    def test_subscription_field_types(self):
        """The rendered type of each column contains the expected keyword.

        These are substring checks on the upper-cased ``str()`` of the
        column type, so e.g. ``"DATE"`` is also satisfied by a type that
        renders as ``DATETIME``.
        """
        table_columns = Subscription.__table__.columns

        def rendered_type(column_key):
            # Upper-cased textual form of the column's SQLAlchemy type.
            return str(table_columns[column_key].type).upper()

        assert "UUID" in rendered_type("id")

        for column_key in ("user_id", "plan", "status"):
            type_text = rendered_type(column_key)
            assert "VARCHAR" in type_text or "STRING" in type_text

        for column_key in ("start_date", "end_date"):
            assert "DATE" in rendered_type(column_key)

        assert "NUMERIC" in rendered_type("amount")

    def test_subscription_relationships_defined(self):
        """The mapper exposes a relationship named ``user``."""
        relationship_names = Subscription.__mapper__.relationships.keys()
        assert "user" in relationship_names

    def test_subscription_plan_field_allows_values(self):
        """The ``plan`` column is declared non-nullable.

        Despite the test name, no specific plan values are checked here.
        """
        plan_column = Subscription.__table__.columns["plan"]
        assert plan_column.nullable is False

    def test_subscription_status_default(self):
        """The ``status`` column has a column-level default of ``"active"``."""
        status_default = Subscription.__table__.columns["status"].default
        assert status_default is not None
        assert status_default.arg == "active"

    @pytest.mark.asyncio
    async def test_subscription_create(self, async_session, test_user):
        """A fully populated subscription round-trips through the session."""
        period_start = date(2025, 1, 1)
        period_end = date(2025, 12, 31)
        constructor_kwargs = {
            "id": uuid.uuid4(),
            "user_id": test_user.id,
            "plan": "pro",
            "status": "active",
            "start_date": period_start,
            "end_date": period_end,
            "amount": 99.99,
            "payment_method": "credit_card",
            "payment_id": "pay_abc123",
        }
        record = Subscription(**constructor_kwargs)

        async_session.add(record)
        await async_session.commit()
        await async_session.refresh(record)

        assert record.id is not None

        # Values that must come back exactly as they were written.
        expected_values = (
            ("user_id", test_user.id),
            ("plan", "pro"),
            ("status", "active"),
            ("start_date", date(2025, 1, 1)),
            ("end_date", date(2025, 12, 31)),
            ("payment_method", "credit_card"),
            ("payment_id", "pay_abc123"),
        )
        for attribute, expected in expected_values:
            assert getattr(record, attribute) == expected

        # Only presence is checked for these two; the amount's exact value
        # is deliberately not compared.
        assert record.amount is not None
        assert record.created_at is not None

    @pytest.mark.asyncio
    async def test_subscription_default_status(self, async_session, test_user):
        """Omitted fields fall back to ``"active"`` status and ``None`` payment data."""
        record = Subscription(
            user_id=test_user.id,
            plan="free",
            start_date=date(2025, 1, 1),
            end_date=date(2025, 12, 31),
        )

        async_session.add(record)
        await async_session.commit()
        await async_session.refresh(record)

        assert record.status == "active"
        for optional_attribute in ("amount", "payment_method", "payment_id"):
            assert getattr(record, optional_attribute) is None

    @pytest.mark.asyncio
    async def test_subscription_query_by_user(self, async_session, test_user):
        """Filtering by ``user_id`` returns both subscriptions saved for that user."""
        # (plan, start_date, end_date) for the two rows to insert.
        periods = (
            ("free", date(2025, 1, 1), date(2025, 6, 30)),
            ("pro", date(2025, 7, 1), date(2025, 12, 31)),
        )
        for plan, starts_on, ends_on in periods:
            async_session.add(
                Subscription(
                    user_id=test_user.id,
                    plan=plan,
                    start_date=starts_on,
                    end_date=ends_on,
                )
            )
        await async_session.commit()

        query = select(Subscription).where(Subscription.user_id == test_user.id)
        result = await async_session.execute(query)
        found = result.scalars().all()

        assert len(found) == 2
|