geo/backend/tests/test_repositories/test_usage_repository.py

373 lines
12 KiB
Python

"""Tests for UsageRepository."""
import uuid
from datetime import datetime, timezone, timedelta
import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.user import User
from app.models.usage_record import UsageRecord
from app.repositories.usage_repository import UsageRepository
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async engine over a fresh in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the engine
    is handed to the test. The engine is disposed on teardown, and also when
    schema creation fails during setup.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        # Allow the pooled sqlite3 connection to be used from a thread other
        # than the one that opened it.
        connect_args={"check_same_thread": False},
        # Presumably here so every session shares the one connection (and
        # therefore the one ``:memory:`` database) -- confirm if pool changes.
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        # Runs on normal teardown and if create_all raised above; without the
        # try/finally a failed setup would never dispose the engine.
        await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the per-test engine.

    ``expire_on_commit=False`` keeps ORM attributes readable after a commit,
    and ``autoflush=False`` disables automatic flushing before queries. Any
    transaction still open when the test finishes is rolled back before the
    session is closed.
    """
    # ``autocommit`` is intentionally not passed: in SQLAlchemy 2.0 it is a
    # legacy keyword that must stay at its default (False) anyway.
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    async with session_factory() as session:
        yield session
        await session.rollback()
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist a free-plan, active, email-verified user and return it.

    The user is committed and then refreshed, so the returned instance
    reflects what the database stored.
    """
    user_fields = dict(
        id=str(uuid.uuid4()),
        email="test@example.com",
        password="hashed_password",
        firstName="Test User",
        plan="free",
        max_queries=5,
        isActive=True,
        emailVerified=True,
    )
    created = User(**user_fields)
    async_session.add(created)
    await async_session.commit()
    await async_session.refresh(created)
    return created
class TestUsageRepository:
    """Test cases for UsageRepository.

    Each test receives a fresh in-memory database through the function-scoped
    ``async_engine``/``async_session`` fixtures, so counts and totals only
    reflect the records created inside that test.
    """

    @staticmethod
    def _payload(user, engine_type, query, **fields):
        """Return a ``UsageRepository.create`` payload owned by ``user``.

        ``fields`` adds optional columns (token counts, cost, brand_id,
        extra_data); omitted ones are left to the repository's defaults.
        """
        return {
            "user_id": _to_uuid(user.id),
            "engine_type": engine_type,
            "query": query,
            **fields,
        }

    @pytest.mark.asyncio
    async def test_create(self, async_session, test_user):
        """create() stores every supplied field and assigns an id."""
        repo = UsageRepository(async_session)
        record = await repo.create(
            self._payload(
                test_user,
                "chatgpt",
                "Test query",
                input_tokens=100,
                output_tokens=200,
                cost=0.015,
                extra_data={"model": "gpt-4"},
            )
        )
        assert record.id is not None
        assert record.user_id == _to_uuid(test_user.id)
        assert record.engine_type == "chatgpt"
        assert record.query == "Test query"
        assert record.input_tokens == 100
        assert record.output_tokens == 200
        assert record.cost == 0.015
        assert record.extra_data == {"model": "gpt-4"}

    @pytest.mark.asyncio
    async def test_create_minimal(self, async_session, test_user):
        """create() fills zero tokens, zero cost and empty extra_data by default."""
        repo = UsageRepository(async_session)
        record = await repo.create(
            self._payload(test_user, "deepseek", "Minimal query")
        )
        assert record.id is not None
        assert record.user_id == _to_uuid(test_user.id)
        assert record.engine_type == "deepseek"
        assert record.query == "Minimal query"
        assert record.input_tokens == 0
        assert record.output_tokens == 0
        assert record.cost == 0.0
        assert record.extra_data == {}

    @pytest.mark.asyncio
    async def test_get_summary(self, async_session, test_user):
        """get_summary() totals queries, tokens and cost for the period."""
        repo = UsageRepository(async_session)
        for i in range(3):
            await repo.create(
                self._payload(
                    test_user,
                    "chatgpt",
                    f"Query {i}",
                    input_tokens=100,
                    output_tokens=200,
                    cost=0.01,
                )
            )
        summary = await repo.get_summary(str(test_user.id), period="month")
        assert summary["period"] == "month"
        assert summary["total_queries"] == 3
        assert summary["total_input_tokens"] == 300
        assert summary["total_output_tokens"] == 600
        # Float sums depend on the database's summation order; use a tolerance.
        assert summary["total_cost"] == pytest.approx(0.03)
        assert "chatgpt" in summary["by_engine"]
        assert summary["by_engine"]["chatgpt"]["queries"] == 3

    @pytest.mark.asyncio
    async def test_get_summary_by_engine(self, async_session, test_user):
        """get_summary() reports query counts and cost per engine."""
        repo = UsageRepository(async_session)
        await repo.create(
            self._payload(test_user, "chatgpt", "ChatGPT query", cost=0.02)
        )
        await repo.create(
            self._payload(test_user, "deepseek", "DeepSeek query", cost=0.01)
        )
        summary = await repo.get_summary(str(test_user.id), period="month")
        by_engine = summary["by_engine"]
        assert len(by_engine) == 2
        assert by_engine["chatgpt"]["queries"] == 1
        assert by_engine["chatgpt"]["cost"] == pytest.approx(0.02)
        assert by_engine["deepseek"]["queries"] == 1
        assert by_engine["deepseek"]["cost"] == pytest.approx(0.01)

    @pytest.mark.asyncio
    async def test_get_summary_by_day(self, async_session, test_user):
        """get_summary() buckets records under their UTC calendar date."""
        repo = UsageRepository(async_session)
        # Capture the UTC date on both sides of the inserts so a run that
        # crosses midnight still accepts the bucket the records landed in.
        day_before = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        for i in range(2):
            await repo.create(
                self._payload(test_user, "qwen", f"Query {i}", cost=0.01)
            )
        summary = await repo.get_summary(str(test_user.id), period="month")
        day_after = datetime.now(timezone.utc).strftime("%Y-%m-%d")

        by_day = summary["by_day"]
        assert by_day, "expected at least one day bucket"
        assert set(by_day) <= {day_before, day_after}
        assert sum(bucket["queries"] for bucket in by_day.values()) == 2

    @pytest.mark.asyncio
    async def test_get_summary_with_brand_filter(self, async_session, test_user):
        """brand_id restricts the summary to records tagged with that brand."""
        repo = UsageRepository(async_session)
        brand_id = uuid.uuid4()
        await repo.create(
            self._payload(
                test_user, "kimi", "Brand query", brand_id=brand_id, cost=0.05
            )
        )
        await repo.create(
            self._payload(test_user, "kimi", "No brand query", cost=0.03)
        )
        summary = await repo.get_summary(
            str(test_user.id),
            period="month",
            brand_id=str(brand_id),
        )
        assert summary["total_queries"] == 1
        assert summary["total_cost"] == pytest.approx(0.05)

    @pytest.mark.asyncio
    async def test_check_quota(self, async_session, test_user):
        """check_quota() reports "ok" well below the monthly limit."""
        repo = UsageRepository(async_session)
        for i in range(5):
            await repo.create(
                self._payload(test_user, "gemini", f"Query {i}", cost=1.0)
            )
        result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)
        assert result["used"] == pytest.approx(5.0)
        assert result["limit"] == 100.0
        assert result["usage_percentage"] == pytest.approx(5.0)
        assert result["status"] == "ok"

    @pytest.mark.asyncio
    async def test_check_quota_warning(self, async_session, test_user):
        """check_quota() reports "warning" at 85% of the monthly limit."""
        repo = UsageRepository(async_session)
        await repo.create(
            self._payload(test_user, "chatgpt", "Expensive query", cost=85.0)
        )
        result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)
        assert result["status"] == "warning"
        assert result["usage_percentage"] == pytest.approx(85.0)

    @pytest.mark.asyncio
    async def test_check_quota_exceeded(self, async_session, test_user):
        """check_quota() reports "exceeded" once usage passes the limit."""
        repo = UsageRepository(async_session)
        await repo.create(
            self._payload(test_user, "chatgpt", "Very expensive query", cost=120.0)
        )
        result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)
        assert result["status"] == "exceeded"
        assert result["usage_percentage"] == pytest.approx(120.0)

    @pytest.mark.asyncio
    async def test_get_by_id(self, async_session, test_user):
        """get_by_id() returns the record that create() produced."""
        repo = UsageRepository(async_session)
        created = await repo.create(
            self._payload(test_user, "wenxin", "Get by ID test", cost=0.5)
        )
        fetched = await repo.get_by_id(created.id)
        assert fetched is not None
        assert fetched.id == created.id
        assert fetched.engine_type == "wenxin"

    @pytest.mark.asyncio
    async def test_get_by_user(self, async_session, test_user):
        """get_by_user() returns every record owned by the user."""
        repo = UsageRepository(async_session)
        for i in range(3):
            await repo.create(
                self._payload(test_user, "doubao", f"User query {i}", cost=0.1)
            )
        records = await repo.get_by_user(str(test_user.id))
        assert len(records) == 3

    @pytest.mark.asyncio
    async def test_get_by_user_and_engine(self, async_session, test_user):
        """get_by_user_and_engine() filters the user's records by engine."""
        repo = UsageRepository(async_session)
        await repo.create(
            self._payload(test_user, "xinghuo", "Xinghuo query", cost=0.1)
        )
        await repo.create(
            self._payload(test_user, "perplexity", "Perplexity query", cost=0.2)
        )
        records = await repo.get_by_user_and_engine(
            str(test_user.id),
            "xinghuo",
        )
        assert len(records) == 1
        assert records[0].engine_type == "xinghuo"

    @pytest.mark.asyncio
    async def test_empty_summary(self, async_session, test_user):
        """get_summary() returns zeroed totals and empty groupings with no data."""
        repo = UsageRepository(async_session)
        summary = await repo.get_summary(str(test_user.id), period="month")
        assert summary["total_queries"] == 0
        assert summary["total_input_tokens"] == 0
        assert summary["total_output_tokens"] == 0
        assert summary["total_cost"] == 0.0
        assert summary["by_engine"] == {}
        assert summary["by_day"] == {}

    @pytest.mark.asyncio
    async def test_empty_quota_check(self, async_session, test_user):
        """check_quota() reports zero usage and "ok" with no records."""
        repo = UsageRepository(async_session)
        result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)
        assert result["used"] == 0.0
        assert result["usage_percentage"] == 0.0
        assert result["status"] == "ok"

    @pytest.mark.asyncio
    async def test_uuid_handling(self, async_session, test_user):
        """get_summary() accepts the raw ``test_user.id``, not only its str() form."""
        repo = UsageRepository(async_session)
        record = await repo.create(
            self._payload(test_user, "yuanbao", "UUID test", cost=0.5)
        )
        summary = await repo.get_summary(test_user.id, period="month")
        assert summary["total_queries"] == 1
        assert record.user_id == _to_uuid(test_user.id)