378 lines
12 KiB
Python
378 lines
12 KiB
Python
"""Tests for by_day aggregation and user quota service."""
|
|
import uuid
|
|
from datetime import datetime, timezone, timedelta
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.models.user import User
|
|
from app.models.usage_record import UsageRecord
|
|
from app.repositories.usage_repository import UsageRepository
|
|
from app.services.user_quota_service import UserQuotaService, PLAN_MONTHLY_LIMITS
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Create an in-memory SQLite async engine with the full ORM schema.

    ``StaticPool`` keeps a single underlying connection, so every session
    created from this engine talks to the same in-memory database.  The
    engine is disposed in a ``finally`` block so it is released even when
    schema creation fails before the ``yield``.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Runs on normal teardown and if create_all raises above.
        await engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps ORM objects (such as the user fixtures,
    which commit before returning) readable after commit without a reload.
    Any work left uncommitted by a test is rolled back on teardown.
    """
    # The legacy ``autocommit`` argument was removed in SQLAlchemy 2.0
    # (``async_sessionmaker`` is 2.0-only); non-autocommit is the only mode.
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    async with session_factory() as session:
        yield session
        await session.rollback()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user_free(async_session):
    """Persist and return an active, verified user on the ``free`` plan."""
    fields = {
        "id": uuid.uuid4(),
        "email": "free@example.com",
        "password_hash": "hashed_password",
        "name": "Free User",
        "plan": "free",
        "max_queries": 5,
        "is_active": True,
        "email_verified": True,
    }
    new_user = User(**fields)
    async_session.add(new_user)
    await async_session.commit()
    # Reload server-side state (defaults, etc.) before handing the row out.
    await async_session.refresh(new_user)
    return new_user
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user_basic(async_session):
    """Persist and return an active, verified user on the ``basic`` plan."""
    fields = {
        "id": uuid.uuid4(),
        "email": "basic@example.com",
        "password_hash": "hashed_password",
        "name": "Basic User",
        "plan": "basic",
        "max_queries": 50,
        "is_active": True,
        "email_verified": True,
    }
    new_user = User(**fields)
    async_session.add(new_user)
    await async_session.commit()
    # Reload server-side state (defaults, etc.) before handing the row out.
    await async_session.refresh(new_user)
    return new_user
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user_pro(async_session):
    """Persist and return an active, verified user on the ``pro`` plan."""
    fields = {
        "id": uuid.uuid4(),
        "email": "pro@example.com",
        "password_hash": "hashed_password",
        "name": "Pro User",
        "plan": "pro",
        "max_queries": 500,
        "is_active": True,
        "email_verified": True,
    }
    new_user = User(**fields)
    async_session.add(new_user)
    await async_session.commit()
    # Reload server-side state (defaults, etc.) before handing the row out.
    await async_session.refresh(new_user)
    return new_user
|
|
|
|
|
|
class TestByDayAggregation:
    """Test cases for by_day aggregation in usage summary.

    ``by_day`` keys are compared against ``YYYY-MM-DD`` strings derived from
    UTC datetimes.  Aggregated costs are float sums, so they are compared
    with ``pytest.approx`` instead of exact equality.
    """

    @staticmethod
    def _day_key(moment: datetime) -> str:
        """Format *moment* as the ``YYYY-MM-DD`` key these tests look up in by_day."""
        return moment.strftime("%Y-%m-%d")

    @pytest.mark.asyncio
    async def test_by_day_returns_data(self, async_session, test_user_free):
        """Test that by_day aggregation returns non-empty data when records exist."""
        repo = UsageRepository(async_session)

        for i in range(3):
            await repo.create({
                "user_id": test_user_free.id,
                "engine_type": "chatgpt",
                "query": f"Query {i}",
                "cost": 0.01,
                "input_tokens": 100,
                "output_tokens": 200,
            })

        summary = await repo.get_summary(str(test_user_free.id), period="month")

        assert "by_day" in summary
        assert len(summary["by_day"]) > 0, "by_day should not be empty when records exist"

    @pytest.mark.asyncio
    async def test_by_day_groups_by_date(self, async_session, test_user_free):
        """Test that records are correctly grouped by date."""
        repo = UsageRepository(async_session)

        for i in range(5):
            await repo.create({
                "user_id": test_user_free.id,
                "engine_type": "deepseek",
                "query": f"Query {i}",
                "cost": 0.02,
            })

        summary = await repo.get_summary(str(test_user_free.id), period="month")
        today = self._day_key(datetime.now(timezone.utc))

        assert today in summary["by_day"], f"Today's date {today} should be in by_day"
        assert summary["by_day"][today]["queries"] == 5
        # Five 0.02 costs summed as floats need not equal 0.10 bit-for-bit.
        assert summary["by_day"][today]["cost"] == pytest.approx(0.10)

    @pytest.mark.asyncio
    async def test_by_day_aggregates_tokens(self, async_session, test_user_free):
        """Test that by_day correctly aggregates tokens."""
        repo = UsageRepository(async_session)

        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "qwen",
            "query": "Query 1",
            "input_tokens": 100,
            "output_tokens": 200,
            "cost": 0.01,
        })
        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "qwen",
            "query": "Query 2",
            "input_tokens": 150,
            "output_tokens": 300,
            "cost": 0.02,
        })

        summary = await repo.get_summary(str(test_user_free.id), period="month")
        today = self._day_key(datetime.now(timezone.utc))

        assert summary["by_day"][today]["input_tokens"] == 250
        assert summary["by_day"][today]["output_tokens"] == 500
        # 0.01 + 0.02 is a float sum; compare with a tolerance.
        assert summary["by_day"][today]["cost"] == pytest.approx(0.03)

    @pytest.mark.asyncio
    async def test_by_day_empty_when_no_records(self, async_session, test_user_free):
        """Test that by_day is empty when no records exist."""
        repo = UsageRepository(async_session)

        summary = await repo.get_summary(str(test_user_free.id), period="month")

        assert summary["by_day"] == {}

    @pytest.mark.asyncio
    async def test_by_day_multiple_days(self, async_session, test_user_free):
        """Test by_day when records span multiple days."""
        repo = UsageRepository(async_session)
        # Derive both timestamps from a single clock reading so the expected
        # day keys cannot drift apart if the test straddles UTC midnight.
        now = datetime.now(timezone.utc)
        yesterday = now - timedelta(days=1)

        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "gemini",
            "query": "Yesterday query",
            "cost": 0.05,
            "timestamp": yesterday,
        })
        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "gemini",
            "query": "Today query",
            "cost": 0.05,
            "timestamp": now,
        })

        summary = await repo.get_summary(str(test_user_free.id), period="week")
        today = self._day_key(now)
        yesterday_str = self._day_key(yesterday)

        assert today in summary["by_day"]
        assert yesterday_str in summary["by_day"]
        # Each day must hold exactly its own record, not both.
        assert summary["by_day"][today]["queries"] == 1
        assert summary["by_day"][yesterday_str]["queries"] == 1
|
|
|
|
|
|
class TestUserQuotaService:
    """Test cases for UserQuotaService with plan-based monthly limits.

    Plan limits are constants and are compared exactly.  ``used`` is an
    aggregated sum and ``usage_percentage`` is a division result, so both
    are compared with ``pytest.approx``.
    """

    def test_plan_monthly_limits_defined(self):
        """Test that all plan monthly limits are defined."""
        assert "free" in PLAN_MONTHLY_LIMITS
        assert "basic" in PLAN_MONTHLY_LIMITS
        assert "pro" in PLAN_MONTHLY_LIMITS
        assert "enterprise" in PLAN_MONTHLY_LIMITS

    def test_free_plan_monthly_limit(self):
        """Test that free plan has 10 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["free"] == 10.0

    def test_basic_plan_monthly_limit(self):
        """Test that basic plan has 50 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["basic"] == 50.0

    def test_pro_plan_monthly_limit(self):
        """Test that pro plan has 200 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["pro"] == 200.0

    def test_enterprise_plan_monthly_limit(self):
        """Test that enterprise plan has 1000 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["enterprise"] == 1000.0

    def test_unknown_plan_defaults_to_free(self):
        """Test that unknown plan defaults to free plan limit."""
        service = UserQuotaService()
        limit = service.get_monthly_limit("unknown_plan")
        # Compare against the free-plan entry itself so the test states the
        # "falls back to free" contract rather than a duplicated number.
        assert limit == PLAN_MONTHLY_LIMITS["free"]

    def test_get_monthly_limit_free(self):
        """Test get_monthly_limit for free plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("free") == 10.0

    def test_get_monthly_limit_basic(self):
        """Test get_monthly_limit for basic plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("basic") == 50.0

    def test_get_monthly_limit_pro(self):
        """Test get_monthly_limit for pro plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("pro") == 200.0

    def test_get_monthly_limit_enterprise(self):
        """Test get_monthly_limit for enterprise plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("enterprise") == 1000.0

    @pytest.mark.asyncio
    async def test_check_quota_with_free_plan(self, async_session, test_user_free):
        """Test quota check uses free plan limit (10 yuan)."""
        repo = UsageRepository(async_session)

        for i in range(5):
            await repo.create({
                "user_id": test_user_free.id,
                "engine_type": "chatgpt",
                "query": f"Query {i}",
                "cost": 1.0,
            })

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_free.id),
            user_plan="free"
        )

        assert result["limit"] == 10.0
        assert result["used"] == pytest.approx(5.0)
        assert result["usage_percentage"] == pytest.approx(50.0)

    @pytest.mark.asyncio
    async def test_check_quota_with_basic_plan(self, async_session, test_user_basic):
        """Test quota check uses basic plan limit (50 yuan)."""
        repo = UsageRepository(async_session)

        for i in range(10):
            await repo.create({
                "user_id": test_user_basic.id,
                "engine_type": "deepseek",
                "query": f"Query {i}",
                "cost": 2.0,
            })

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_basic.id),
            user_plan="basic"
        )

        assert result["limit"] == 50.0
        assert result["used"] == pytest.approx(20.0)
        assert result["usage_percentage"] == pytest.approx(40.0)

    @pytest.mark.asyncio
    async def test_check_quota_with_pro_plan(self, async_session, test_user_pro):
        """Test quota check uses pro plan limit (200 yuan)."""
        repo = UsageRepository(async_session)

        for i in range(10):
            await repo.create({
                "user_id": test_user_pro.id,
                "engine_type": "qwen",
                "query": f"Query {i}",
                "cost": 10.0,
            })

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_pro.id),
            user_plan="pro"
        )

        assert result["limit"] == 200.0
        assert result["used"] == pytest.approx(100.0)
        assert result["usage_percentage"] == pytest.approx(50.0)

    @pytest.mark.asyncio
    async def test_check_quota_exceeded_status(self, async_session, test_user_free):
        """Test that exceeded status works correctly with plan limits."""
        repo = UsageRepository(async_session)

        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "gemini",
            "query": "Expensive query",
            "cost": 15.0,
        })

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_free.id),
            user_plan="free"
        )

        assert result["limit"] == 10.0
        assert result["used"] == pytest.approx(15.0)
        assert result["usage_percentage"] == pytest.approx(150.0)
        assert result["status"] == "exceeded"

    @pytest.mark.asyncio
    async def test_check_quota_warning_status(self, async_session, test_user_basic):
        """Test that warning status works correctly with plan limits."""
        repo = UsageRepository(async_session)

        await repo.create({
            "user_id": test_user_basic.id,
            "engine_type": "kimi",
            "query": "Moderate query",
            "cost": 45.0,
        })

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_basic.id),
            user_plan="basic"
        )

        assert result["limit"] == 50.0
        assert result["used"] == pytest.approx(45.0)
        # 45 / 50 * 100 goes through float division; allow rounding slack.
        assert result["usage_percentage"] == pytest.approx(90.0)
        assert result["status"] == "warning"
|