"""Tests for UsageRepository.""" import uuid from datetime import datetime, timezone, timedelta import pytest import pytest_asyncio from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.pool import StaticPool from app.database import Base from app.models.user import User from app.models.usage_record import UsageRecord from app.repositories.usage_repository import UsageRepository @pytest_asyncio.fixture async def async_engine(): """Create async engine for testing with SQLite.""" engine = create_async_engine( "sqlite+aiosqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool, ) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield engine await engine.dispose() @pytest_asyncio.fixture async def async_session(async_engine): """Create async session for testing.""" async_session_maker = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False, ) async with async_session_maker() as session: yield session await session.rollback() @pytest_asyncio.fixture async def test_user(async_session): """Create a test user.""" user = User( id=uuid.uuid4(), email="test@example.com", password_hash="hashed_password", name="Test User", plan="free", max_queries=5, is_active=True, email_verified=True, ) async_session.add(user) await async_session.commit() await async_session.refresh(user) return user class TestUsageRepository: """Test cases for UsageRepository.""" @pytest.mark.asyncio async def test_create(self, async_session, test_user): """Test creating a usage record.""" repo = UsageRepository(async_session) data = { "user_id": test_user.id, "engine_type": "chatgpt", "query": "Test query", "input_tokens": 100, "output_tokens": 200, "cost": 0.015, "extra_data": {"model": "gpt-4"}, } record = await repo.create(data) assert record.id is not None assert record.user_id == test_user.id assert record.engine_type == "chatgpt" assert record.query == "Test query" assert record.input_tokens == 100 assert record.output_tokens == 200 assert record.cost == 0.015 assert record.extra_data == {"model": "gpt-4"} @pytest.mark.asyncio async def test_create_minimal(self, async_session, test_user): """Test creating a usage record with minimal data.""" repo = UsageRepository(async_session) data = { "user_id": test_user.id, "engine_type": "deepseek", "query": "Minimal query", } record = await repo.create(data) assert record.id is not None assert record.user_id == test_user.id assert record.engine_type == "deepseek" assert record.query == "Minimal query" assert record.input_tokens == 0 assert record.output_tokens == 0 assert record.cost == 0.0 assert record.extra_data == {} @pytest.mark.asyncio async def test_get_summary(self, async_session, test_user): """Test getting usage summary.""" repo = UsageRepository(async_session) for i in range(3): await repo.create({ "user_id": test_user.id, "engine_type": "chatgpt", "query": f"Query {i}", "input_tokens": 100, "output_tokens": 200, "cost": 0.01, }) summary = await repo.get_summary(str(test_user.id), period="month") assert summary["period"] == "month" assert summary["total_queries"] == 3 assert summary["total_input_tokens"] == 300 assert summary["total_output_tokens"] == 600 assert summary["total_cost"] == 0.03 assert "chatgpt" in summary["by_engine"] assert summary["by_engine"]["chatgpt"]["queries"] == 3 @pytest.mark.asyncio async def test_get_summary_by_engine(self, async_session, test_user): """Test getting usage summary grouped by engine.""" repo = UsageRepository(async_session) await repo.create({ "user_id": test_user.id, "engine_type": "chatgpt", "query": "ChatGPT query", "cost": 0.02, }) await repo.create({ "user_id": test_user.id, "engine_type": "deepseek", "query": "DeepSeek query", "cost": 0.01, }) summary = await repo.get_summary(str(test_user.id), period="month") assert len(summary["by_engine"]) == 2 assert summary["by_engine"]["chatgpt"]["queries"] == 1 assert summary["by_engine"]["chatgpt"]["cost"] == 0.02 assert summary["by_engine"]["deepseek"]["queries"] == 1 assert summary["by_engine"]["deepseek"]["cost"] == 0.01 @pytest.mark.asyncio async def test_get_summary_by_day(self, async_session, test_user): """Test getting usage summary grouped by day.""" repo = UsageRepository(async_session) for i in range(2): await repo.create({ "user_id": test_user.id, "engine_type": "qwen", "query": f"Query {i}", "cost": 0.01, }) summary = await repo.get_summary(str(test_user.id), period="month") assert len(summary["by_day"]) >= 1 today = datetime.now(timezone.utc).strftime("%Y-%m-%d") assert today in summary["by_day"] assert summary["by_day"][today]["queries"] == 2 @pytest.mark.asyncio async def test_get_summary_with_brand_filter(self, async_session, test_user): """Test getting usage summary filtered by brand.""" repo = UsageRepository(async_session) brand_id = uuid.uuid4() await repo.create({ "user_id": test_user.id, "brand_id": brand_id, "engine_type": "kimi", "query": "Brand query", "cost": 0.05, }) await repo.create({ "user_id": test_user.id, "engine_type": "kimi", "query": "No brand query", "cost": 0.03, }) summary = await repo.get_summary( str(test_user.id), period="month", brand_id=str(brand_id), ) assert summary["total_queries"] == 1 assert summary["total_cost"] == 0.05 @pytest.mark.asyncio async def test_check_quota(self, async_session, test_user): """Test checking quota usage.""" repo = UsageRepository(async_session) for i in range(5): await repo.create({ "user_id": test_user.id, "engine_type": "gemini", "query": f"Query {i}", "cost": 1.0, }) result = await repo.check_quota(str(test_user.id), monthly_limit=100.0) assert result["used"] == 5.0 assert result["limit"] == 100.0 assert result["usage_percentage"] == 5.0 assert result["status"] == "ok" @pytest.mark.asyncio async def test_check_quota_warning(self, async_session, test_user): """Test quota warning status.""" repo = UsageRepository(async_session) await repo.create({ "user_id": test_user.id, "engine_type": "chatgpt", "query": "Expensive query", "cost": 85.0, }) result = await repo.check_quota(str(test_user.id), monthly_limit=100.0) assert result["status"] == "warning" assert result["usage_percentage"] == 85.0 @pytest.mark.asyncio async def test_check_quota_exceeded(self, async_session, test_user): """Test quota exceeded status.""" repo = UsageRepository(async_session) await repo.create({ "user_id": test_user.id, "engine_type": "chatgpt", "query": "Very expensive query", "cost": 120.0, }) result = await repo.check_quota(str(test_user.id), monthly_limit=100.0) assert result["status"] == "exceeded" assert result["usage_percentage"] == 120.0 @pytest.mark.asyncio async def test_get_by_id(self, async_session, test_user): """Test getting a usage record by ID.""" repo = UsageRepository(async_session) created = await repo.create({ "user_id": test_user.id, "engine_type": "wenxin", "query": "Get by ID test", "cost": 0.5, }) fetched = await repo.get_by_id(created.id) assert fetched is not None assert fetched.id == created.id assert fetched.engine_type == "wenxin" @pytest.mark.asyncio async def test_get_by_user(self, async_session, test_user): """Test getting usage records by user.""" repo = UsageRepository(async_session) for i in range(3): await repo.create({ "user_id": test_user.id, "engine_type": "doubao", "query": f"User query {i}", "cost": 0.1, }) records = await repo.get_by_user(str(test_user.id)) assert len(records) == 3 @pytest.mark.asyncio async def test_get_by_user_and_engine(self, async_session, test_user): """Test getting usage records by user and engine.""" repo = UsageRepository(async_session) await repo.create({ "user_id": test_user.id, "engine_type": "xinghuo", "query": "Xinghuo query", "cost": 0.1, }) await repo.create({ "user_id": test_user.id, "engine_type": "perplexity", "query": "Perplexity query", "cost": 0.2, }) records = await repo.get_by_user_and_engine( str(test_user.id), "xinghuo", ) assert len(records) == 1 assert records[0].engine_type == "xinghuo" @pytest.mark.asyncio async def test_empty_summary(self, async_session, test_user): """Test getting summary when no records exist.""" repo = UsageRepository(async_session) summary = await repo.get_summary(str(test_user.id), period="month") assert summary["total_queries"] == 0 assert summary["total_input_tokens"] == 0 assert summary["total_output_tokens"] == 0 assert summary["total_cost"] == 0.0 assert summary["by_engine"] == {} assert summary["by_day"] == {} @pytest.mark.asyncio async def test_empty_quota_check(self, async_session, test_user): """Test checking quota when no records exist.""" repo = UsageRepository(async_session) result = await repo.check_quota(str(test_user.id), monthly_limit=100.0) assert result["used"] == 0.0 assert result["usage_percentage"] == 0.0 assert result["status"] == "ok" @pytest.mark.asyncio async def test_uuid_handling(self, async_session, test_user): """Test that UUID handling works correctly.""" repo = UsageRepository(async_session) data = { "user_id": test_user.id, "engine_type": "yuanbao", "query": "UUID test", "cost": 0.5, } record = await repo.create(data) summary = await repo.get_summary(test_user.id, period="month") assert summary["total_queries"] == 1 assert record.user_id == test_user.id