"""Tests for by_day aggregation and user quota service."""

import uuid
from datetime import datetime, timezone, timedelta

import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from app.database import Base
from app.models.user import User
# NOTE(review): UsageRecord is not referenced below; presumably the import is
# needed so its table is registered on Base.metadata before create_all — confirm
# before removing.
from app.models.usage_record import UsageRecord  # noqa: F401
from app.repositories.usage_repository import UsageRepository
from app.services.user_quota_service import UserQuotaService, PLAN_MONTHLY_LIMITS
from tests.fixtures.auth import _to_uuid


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _day_key(moment: datetime) -> str:
    """Format *moment* as the ``YYYY-MM-DD`` string these tests use as a ``by_day`` key."""
    return moment.strftime("%Y-%m-%d")


def _utc_today_key() -> str:
    """Return the current UTC date formatted as a ``by_day`` key."""
    return _day_key(datetime.now(timezone.utc))


async def _record_usage(repo: UsageRepository, user: User, **fields):
    """Create one usage record for *user* through *repo* and return the result.

    ``user_id`` is derived from ``user.id`` via ``_to_uuid``; every other
    column of the record is taken verbatim from *fields*.
    """
    return await repo.create({"user_id": _to_uuid(user.id), **fields})


async def _create_user(
    session: AsyncSession,
    *,
    plan: str,
    email: str,
    first_name: str,
    max_queries: int,
) -> User:
    """Persist and return an active, email-verified user on the given *plan*.

    The user is committed and then refreshed so that database-populated
    attributes are loaded before the instance is handed to a test.
    """
    user = User(
        id=str(uuid.uuid4()),
        email=email,
        password="hashed_password",
        firstName=first_name,
        plan=plan,
        max_queries=max_queries,
        isActive=True,
        emailVerified=True,
    )
    session.add(user)
    await session.commit()
    await session.refresh(user)
    return user


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest_asyncio.fixture
async def async_engine():
    """Create async engine for testing with SQLite.

    The database lives in memory; ``StaticPool`` hands out one shared
    connection so the schema created by ``create_all`` is visible to every
    session bound to this engine. The engine is disposed on teardown.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Create async session for testing.

    Any uncommitted work is rolled back on teardown. Data committed by tests
    disappears with the per-test in-memory engine.
    """
    async_session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with async_session_maker() as session:
        yield session
        await session.rollback()


@pytest_asyncio.fixture
async def test_user_free(async_session):
    """Create a free plan test user."""
    return await _create_user(
        async_session,
        plan="free",
        email="free@example.com",
        first_name="Free User",
        max_queries=5,
    )


@pytest_asyncio.fixture
async def test_user_basic(async_session):
    """Create a basic plan test user."""
    return await _create_user(
        async_session,
        plan="basic",
        email="basic@example.com",
        first_name="Basic User",
        max_queries=50,
    )


@pytest_asyncio.fixture
async def test_user_pro(async_session):
    """Create a pro plan test user."""
    return await _create_user(
        async_session,
        plan="pro",
        email="pro@example.com",
        first_name="Pro User",
        max_queries=500,
    )


# ---------------------------------------------------------------------------
# by_day aggregation
# ---------------------------------------------------------------------------


class TestByDayAggregation:
    """Test cases for by_day aggregation in usage summary."""

    @pytest.mark.asyncio
    async def test_by_day_returns_data(self, async_session, test_user_free):
        """Test that by_day aggregation returns non-empty data when records exist."""
        repo = UsageRepository(async_session)
        for i in range(3):
            await _record_usage(
                repo,
                test_user_free,
                engine_type="chatgpt",
                query=f"Query {i}",
                cost=0.01,
                input_tokens=100,
                output_tokens=200,
            )

        summary = await repo.get_summary(str(test_user_free.id), period="month")

        assert "by_day" in summary
        assert len(summary["by_day"]) > 0, "by_day should not be empty when records exist"

    @pytest.mark.asyncio
    async def test_by_day_groups_by_date(self, async_session, test_user_free):
        """Test that records are correctly grouped by date."""
        repo = UsageRepository(async_session)
        for i in range(5):
            await _record_usage(
                repo,
                test_user_free,
                engine_type="deepseek",
                query=f"Query {i}",
                cost=0.02,
            )

        summary = await repo.get_summary(str(test_user_free.id), period="month")

        today = _utc_today_key()
        assert today in summary["by_day"], f"Today's date {today} should be in by_day"
        day = summary["by_day"][today]
        assert day["queries"] == 5
        # Summed float costs are not exactly representable; compare approximately.
        assert day["cost"] == pytest.approx(0.10)

    @pytest.mark.asyncio
    async def test_by_day_aggregates_tokens(self, async_session, test_user_free):
        """Test that by_day correctly aggregates tokens."""
        repo = UsageRepository(async_session)
        await _record_usage(
            repo,
            test_user_free,
            engine_type="qwen",
            query="Query 1",
            input_tokens=100,
            output_tokens=200,
            cost=0.01,
        )
        await _record_usage(
            repo,
            test_user_free,
            engine_type="qwen",
            query="Query 2",
            input_tokens=150,
            output_tokens=300,
            cost=0.02,
        )

        summary = await repo.get_summary(str(test_user_free.id), period="month")

        day = summary["by_day"][_utc_today_key()]
        assert day["input_tokens"] == 250
        assert day["output_tokens"] == 500
        # 0.01 + 0.02 is a float sum; exact equality would be rounding-dependent.
        assert day["cost"] == pytest.approx(0.03)

    @pytest.mark.asyncio
    async def test_by_day_empty_when_no_records(self, async_session, test_user_free):
        """Test that by_day is empty when no records exist."""
        repo = UsageRepository(async_session)

        summary = await repo.get_summary(str(test_user_free.id), period="month")

        assert summary["by_day"] == {}

    @pytest.mark.asyncio
    async def test_by_day_multiple_days(self, async_session, test_user_free):
        """Test by_day when records span multiple days."""
        repo = UsageRepository(async_session)
        yesterday = datetime.now(timezone.utc) - timedelta(days=1)
        await _record_usage(
            repo,
            test_user_free,
            engine_type="gemini",
            query="Yesterday query",
            cost=0.05,
            timestamp=yesterday,
        )
        await _record_usage(
            repo,
            test_user_free,
            engine_type="gemini",
            query="Today query",
            cost=0.05,
        )

        # "week" rather than "month" so yesterday stays in range on the 1st.
        summary = await repo.get_summary(str(test_user_free.id), period="week")

        today = _utc_today_key()
        yesterday_str = _day_key(yesterday)
        assert today in summary["by_day"]
        assert yesterday_str in summary["by_day"]
        # Each day received exactly one record.
        assert summary["by_day"][today]["queries"] == 1
        assert summary["by_day"][yesterday_str]["queries"] == 1


# ---------------------------------------------------------------------------
# Plan-based monthly quota
# ---------------------------------------------------------------------------


class TestUserQuotaService:
    """Test cases for UserQuotaService with plan-based monthly limits."""

    def test_plan_monthly_limits_defined(self):
        """Test that all plan monthly limits are defined."""
        missing = [
            plan
            for plan in ("free", "basic", "pro", "enterprise")
            if plan not in PLAN_MONTHLY_LIMITS
        ]
        assert not missing, f"PLAN_MONTHLY_LIMITS is missing plans: {missing}"

    def test_free_plan_monthly_limit(self):
        """Test that free plan has 10 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["free"] == 10.0

    def test_basic_plan_monthly_limit(self):
        """Test that basic plan has 50 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["basic"] == 50.0

    def test_pro_plan_monthly_limit(self):
        """Test that pro plan has 200 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["pro"] == 200.0

    def test_enterprise_plan_monthly_limit(self):
        """Test that enterprise plan has 1000 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["enterprise"] == 1000.0

    def test_unknown_plan_defaults_to_free(self):
        """Test that unknown plan defaults to free plan limit."""
        service = UserQuotaService()
        limit = service.get_monthly_limit("unknown_plan")
        assert limit == PLAN_MONTHLY_LIMITS["free"]
        assert limit == 10.0

    def test_get_monthly_limit_free(self):
        """Test get_monthly_limit for free plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("free") == 10.0

    def test_get_monthly_limit_basic(self):
        """Test get_monthly_limit for basic plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("basic") == 50.0

    def test_get_monthly_limit_pro(self):
        """Test get_monthly_limit for pro plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("pro") == 200.0

    def test_get_monthly_limit_enterprise(self):
        """Test get_monthly_limit for enterprise plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("enterprise") == 1000.0

    @pytest.mark.asyncio
    async def test_check_quota_with_free_plan(self, async_session, test_user_free):
        """Test quota check uses free plan limit (10 yuan)."""
        repo = UsageRepository(async_session)
        for i in range(5):
            await _record_usage(
                repo,
                test_user_free,
                engine_type="chatgpt",
                query=f"Query {i}",
                cost=1.0,
            )

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_free.id), user_plan="free"
        )

        assert result["limit"] == 10.0
        assert result["used"] == pytest.approx(5.0)
        assert result["usage_percentage"] == pytest.approx(50.0)

    @pytest.mark.asyncio
    async def test_check_quota_with_basic_plan(self, async_session, test_user_basic):
        """Test quota check uses basic plan limit (50 yuan)."""
        repo = UsageRepository(async_session)
        for i in range(10):
            await _record_usage(
                repo,
                test_user_basic,
                engine_type="deepseek",
                query=f"Query {i}",
                cost=2.0,
            )

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_basic.id), user_plan="basic"
        )

        assert result["limit"] == 50.0
        assert result["used"] == pytest.approx(20.0)
        assert result["usage_percentage"] == pytest.approx(40.0)

    @pytest.mark.asyncio
    async def test_check_quota_with_pro_plan(self, async_session, test_user_pro):
        """Test quota check uses pro plan limit (200 yuan)."""
        repo = UsageRepository(async_session)
        for i in range(10):
            await _record_usage(
                repo,
                test_user_pro,
                engine_type="qwen",
                query=f"Query {i}",
                cost=10.0,
            )

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_pro.id), user_plan="pro"
        )

        assert result["limit"] == 200.0
        assert result["used"] == pytest.approx(100.0)
        assert result["usage_percentage"] == pytest.approx(50.0)

    @pytest.mark.asyncio
    async def test_check_quota_exceeded_status(self, async_session, test_user_free):
        """Test that exceeded status works correctly with plan limits."""
        repo = UsageRepository(async_session)
        await _record_usage(
            repo,
            test_user_free,
            engine_type="gemini",
            query="Expensive query",
            cost=15.0,
        )

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_free.id), user_plan="free"
        )

        assert result["limit"] == 10.0
        assert result["used"] == pytest.approx(15.0)
        assert result["usage_percentage"] == pytest.approx(150.0)
        assert result["status"] == "exceeded"

    @pytest.mark.asyncio
    async def test_check_quota_warning_status(self, async_session, test_user_basic):
        """Test that warning status works correctly with plan limits."""
        repo = UsageRepository(async_session)
        await _record_usage(
            repo,
            test_user_basic,
            engine_type="kimi",
            query="Moderate query",
            cost=45.0,
        )

        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_basic.id), user_plan="basic"
        )

        assert result["limit"] == 50.0
        assert result["used"] == pytest.approx(45.0)
        # 45 / 50 * 100 goes through float division; compare approximately.
        assert result["usage_percentage"] == pytest.approx(90.0)
        assert result["status"] == "warning"