geo/backend/tests/test_repositories/test_usage_quota_integratio...

378 lines
12 KiB
Python

"""Tests for by_day aggregation and user quota service."""
import uuid
from datetime import datetime, timezone, timedelta
import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.user import User
from app.models.usage_record import UsageRecord
from app.repositories.usage_repository import UsageRepository
from app.services.user_quota_service import UserQuotaService, PLAN_MONTHLY_LIMITS
@pytest_asyncio.fixture
async def async_engine():
    """Yield a fresh in-memory SQLite async engine with the full schema created.

    ``StaticPool`` keeps a single underlying connection, so every session
    bound to this engine sees the same ``:memory:`` database (a new
    connection to ``:memory:`` would otherwise open a separate, empty DB).
    The engine is disposed on teardown, and also if schema creation fails.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        # The sqlite connection may be used from a thread other than the one
        # that opened it, so disable sqlite3's same-thread check.
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the per-test in-memory engine.

    ``expire_on_commit=False`` keeps attributes of committed objects (such as
    the user fixtures) loaded after ``commit()``. Any work the test leaves
    uncommitted is rolled back before the session is closed.
    """
    session_factory = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session
        await db_session.rollback()
async def _create_user(session, *, email, name, plan, max_queries):
    """Persist an active, email-verified ``User`` and return it re-read from the DB.

    Shared by the plan-specific user fixtures below, which differ only in
    e-mail, display name, plan and ``max_queries``.
    """
    user = User(
        id=uuid.uuid4(),
        email=email,
        password_hash="hashed_password",
        name=name,
        plan=plan,
        max_queries=max_queries,
        is_active=True,
        email_verified=True,
    )
    session.add(user)
    await session.commit()
    # Reload the row so every column is populated on the returned object.
    await session.refresh(user)
    return user


@pytest_asyncio.fixture
async def test_user_free(async_session):
    """Create a free plan test user."""
    return await _create_user(
        async_session,
        email="free@example.com",
        name="Free User",
        plan="free",
        max_queries=5,
    )


@pytest_asyncio.fixture
async def test_user_basic(async_session):
    """Create a basic plan test user."""
    return await _create_user(
        async_session,
        email="basic@example.com",
        name="Basic User",
        plan="basic",
        max_queries=50,
    )


@pytest_asyncio.fixture
async def test_user_pro(async_session):
    """Create a pro plan test user."""
    return await _create_user(
        async_session,
        email="pro@example.com",
        name="Pro User",
        plan="pro",
        max_queries=500,
    )
class TestByDayAggregation:
    """Test cases for by_day aggregation in usage summary.

    Expected day keys are built with ``strftime("%Y-%m-%d")`` from UTC
    datetimes, so these tests assume the repository buckets records by UTC
    calendar date.
    """

    @staticmethod
    def _day_key(moment: datetime) -> str:
        """Return *moment* formatted as a ``YYYY-MM-DD`` by_day key."""
        return moment.strftime("%Y-%m-%d")

    @pytest.mark.asyncio
    async def test_by_day_returns_data(self, async_session, test_user_free):
        """Test that by_day aggregation returns non-empty data when records exist."""
        repo = UsageRepository(async_session)
        for i in range(3):
            await repo.create({
                "user_id": test_user_free.id,
                "engine_type": "chatgpt",
                "query": f"Query {i}",
                "cost": 0.01,
                "input_tokens": 100,
                "output_tokens": 200,
            })
        summary = await repo.get_summary(str(test_user_free.id), period="month")
        assert "by_day" in summary
        assert len(summary["by_day"]) > 0, "by_day should not be empty when records exist"

    @pytest.mark.asyncio
    async def test_by_day_groups_by_date(self, async_session, test_user_free):
        """Test that records are correctly grouped by date."""
        repo = UsageRepository(async_session)
        # Pin one timestamp for every record so the expected key cannot drift
        # if the test happens to run across midnight UTC.
        now = datetime.now(timezone.utc)
        for i in range(5):
            await repo.create({
                "user_id": test_user_free.id,
                "engine_type": "deepseek",
                "query": f"Query {i}",
                "cost": 0.02,
                "timestamp": now,
            })
        summary = await repo.get_summary(str(test_user_free.id), period="month")
        today = self._day_key(now)
        assert today in summary["by_day"], f"Today's date {today} should be in by_day"
        assert summary["by_day"][today]["queries"] == 5
        # Cost is a sum of floats; exact equality is not reliable.
        assert summary["by_day"][today]["cost"] == pytest.approx(0.10)

    @pytest.mark.asyncio
    async def test_by_day_aggregates_tokens(self, async_session, test_user_free):
        """Test that by_day correctly aggregates tokens."""
        repo = UsageRepository(async_session)
        now = datetime.now(timezone.utc)
        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "qwen",
            "query": "Query 1",
            "input_tokens": 100,
            "output_tokens": 200,
            "cost": 0.01,
            "timestamp": now,
        })
        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "qwen",
            "query": "Query 2",
            "input_tokens": 150,
            "output_tokens": 300,
            "cost": 0.02,
            "timestamp": now,
        })
        summary = await repo.get_summary(str(test_user_free.id), period="month")
        today = self._day_key(now)
        assert summary["by_day"][today]["input_tokens"] == 250
        assert summary["by_day"][today]["output_tokens"] == 500
        assert summary["by_day"][today]["cost"] == pytest.approx(0.03)

    @pytest.mark.asyncio
    async def test_by_day_empty_when_no_records(self, async_session, test_user_free):
        """Test that by_day is empty when no records exist."""
        repo = UsageRepository(async_session)
        summary = await repo.get_summary(str(test_user_free.id), period="month")
        assert summary["by_day"] == {}

    @pytest.mark.asyncio
    async def test_by_day_multiple_days(self, async_session, test_user_free):
        """Test by_day when records span multiple days."""
        repo = UsageRepository(async_session)
        # Derive both days from one instant so "today" and "yesterday" are
        # always exactly one calendar day apart.
        now = datetime.now(timezone.utc)
        yesterday = now - timedelta(days=1)
        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "gemini",
            "query": "Yesterday query",
            "cost": 0.05,
            "timestamp": yesterday,
        })
        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "gemini",
            "query": "Today query",
            "cost": 0.05,
            "timestamp": now,
        })
        summary = await repo.get_summary(str(test_user_free.id), period="week")
        assert self._day_key(now) in summary["by_day"]
        assert self._day_key(yesterday) in summary["by_day"]
class TestUserQuotaService:
    """Test cases for UserQuotaService with plan-based monthly limits."""

    @staticmethod
    async def _seed_usage(session, user_id, engine_type, *, cost, count):
        """Insert *count* usage records costing *cost* each for *user_id*.

        Records are labelled ``"Query 0"``, ``"Query 1"``, ... and written
        through ``UsageRepository.create``.
        """
        repo = UsageRepository(session)
        for i in range(count):
            await repo.create({
                "user_id": user_id,
                "engine_type": engine_type,
                "query": f"Query {i}",
                "cost": cost,
            })

    def test_plan_monthly_limits_defined(self):
        """Test that all plan monthly limits are defined."""
        for plan in ("free", "basic", "pro", "enterprise"):
            assert plan in PLAN_MONTHLY_LIMITS, f"missing monthly limit for plan {plan!r}"

    def test_free_plan_monthly_limit(self):
        """Test that free plan has 10 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["free"] == 10.0

    def test_basic_plan_monthly_limit(self):
        """Test that basic plan has 50 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["basic"] == 50.0

    def test_pro_plan_monthly_limit(self):
        """Test that pro plan has 200 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["pro"] == 200.0

    def test_enterprise_plan_monthly_limit(self):
        """Test that enterprise plan has 1000 yuan monthly limit."""
        assert PLAN_MONTHLY_LIMITS["enterprise"] == 1000.0

    def test_unknown_plan_defaults_to_free(self):
        """Test that unknown plan defaults to free plan limit."""
        service = UserQuotaService()
        limit = service.get_monthly_limit("unknown_plan")
        # Assert the stated intent ("same as free") rather than repeating the
        # literal; the free value itself is pinned by test_free_plan_monthly_limit.
        assert limit == PLAN_MONTHLY_LIMITS["free"]

    def test_get_monthly_limit_free(self):
        """Test get_monthly_limit for free plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("free") == 10.0

    def test_get_monthly_limit_basic(self):
        """Test get_monthly_limit for basic plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("basic") == 50.0

    def test_get_monthly_limit_pro(self):
        """Test get_monthly_limit for pro plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("pro") == 200.0

    def test_get_monthly_limit_enterprise(self):
        """Test get_monthly_limit for enterprise plan."""
        service = UserQuotaService()
        assert service.get_monthly_limit("enterprise") == 1000.0

    @pytest.mark.asyncio
    async def test_check_quota_with_free_plan(self, async_session, test_user_free):
        """Test quota check uses free plan limit (10 yuan)."""
        await self._seed_usage(
            async_session, test_user_free.id, "chatgpt", cost=1.0, count=5
        )
        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_free.id),
            user_plan="free",
        )
        assert result["limit"] == 10.0
        # "used" is an aggregated float sum and the percentage is derived by
        # division, so compare with a tolerance.
        assert result["used"] == pytest.approx(5.0)
        assert result["usage_percentage"] == pytest.approx(50.0)

    @pytest.mark.asyncio
    async def test_check_quota_with_basic_plan(self, async_session, test_user_basic):
        """Test quota check uses basic plan limit (50 yuan)."""
        await self._seed_usage(
            async_session, test_user_basic.id, "deepseek", cost=2.0, count=10
        )
        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_basic.id),
            user_plan="basic",
        )
        assert result["limit"] == 50.0
        assert result["used"] == pytest.approx(20.0)
        assert result["usage_percentage"] == pytest.approx(40.0)

    @pytest.mark.asyncio
    async def test_check_quota_with_pro_plan(self, async_session, test_user_pro):
        """Test quota check uses pro plan limit (200 yuan)."""
        await self._seed_usage(
            async_session, test_user_pro.id, "qwen", cost=10.0, count=10
        )
        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_pro.id),
            user_plan="pro",
        )
        assert result["limit"] == 200.0
        assert result["used"] == pytest.approx(100.0)
        assert result["usage_percentage"] == pytest.approx(50.0)

    @pytest.mark.asyncio
    async def test_check_quota_exceeded_status(self, async_session, test_user_free):
        """Test that exceeded status works correctly with plan limits."""
        repo = UsageRepository(async_session)
        await repo.create({
            "user_id": test_user_free.id,
            "engine_type": "gemini",
            "query": "Expensive query",
            "cost": 15.0,
        })
        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_free.id),
            user_plan="free",
        )
        assert result["limit"] == 10.0
        assert result["used"] == pytest.approx(15.0)
        assert result["usage_percentage"] == pytest.approx(150.0)
        assert result["status"] == "exceeded"

    @pytest.mark.asyncio
    async def test_check_quota_warning_status(self, async_session, test_user_basic):
        """Test that warning status works correctly with plan limits."""
        repo = UsageRepository(async_session)
        await repo.create({
            "user_id": test_user_basic.id,
            "engine_type": "kimi",
            "query": "Moderate query",
            "cost": 45.0,
        })
        service = UserQuotaService(session=async_session)
        result = await service.check_quota_with_plan(
            user_id=str(test_user_basic.id),
            user_plan="basic",
        )
        assert result["limit"] == 50.0
        assert result["used"] == pytest.approx(45.0)
        assert result["usage_percentage"] == pytest.approx(90.0)
        assert result["status"] == "warning"