372 lines
12 KiB
Python
372 lines
12 KiB
Python
"""Tests for UsageRepository."""
|
|
import uuid
|
|
from datetime import datetime, timezone, timedelta
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.models.user import User
|
|
from app.models.usage_record import UsageRecord
|
|
from app.repositories.usage_repository import UsageRepository
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLite engine backed by an in-memory database.

    All tables registered on ``Base.metadata`` are created before the engine
    is handed to the test, and the engine is disposed during teardown.
    ``StaticPool`` is used so that the single in-memory database is shared by
    every connection the engine hands out (per SQLAlchemy's SQLite guidance).
    """
    test_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    # Build the schema once, up front, inside a short-lived transaction.
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield test_engine

    await test_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the per-test engine.

    Objects are not expired on commit (so attributes stay readable after
    ``commit()``) and autoflush is off. Any still-open transaction is rolled
    back when the test finishes.
    """
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )

    async with session_factory() as session:
        yield session
        await session.rollback()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the "free" plan.

    The user gets a random UUID primary key and ``max_queries=5``. It is
    committed and refreshed so the returned instance reflects the stored row.
    """
    user_fields = {
        "id": uuid.uuid4(),
        "email": "test@example.com",
        "password_hash": "hashed_password",
        "name": "Test User",
        "plan": "free",
        "max_queries": 5,
        "is_active": True,
        "email_verified": True,
    }
    new_user = User(**user_fields)

    async_session.add(new_user)
    await async_session.commit()
    await async_session.refresh(new_user)
    return new_user
|
|
|
|
|
|
class TestUsageRepository:
    """Test cases for UsageRepository.

    Each test gets a fresh in-memory database through the function-scoped
    ``async_session`` fixture, plus a single persisted ``test_user``.
    Values produced by floating-point arithmetic (summed costs, quota
    percentages) are compared with ``pytest.approx``; exact ``==`` on such
    results is fragile.
    """

    @pytest.mark.asyncio
    async def test_create(self, async_session, test_user):
        """Test creating a usage record."""
        repo = UsageRepository(async_session)

        data = {
            "user_id": test_user.id,
            "engine_type": "chatgpt",
            "query": "Test query",
            "input_tokens": 100,
            "output_tokens": 200,
            "cost": 0.015,
            "extra_data": {"model": "gpt-4"},
        }

        record = await repo.create(data)

        assert record.id is not None
        assert record.user_id == test_user.id
        assert record.engine_type == "chatgpt"
        assert record.query == "Test query"
        assert record.input_tokens == 100
        assert record.output_tokens == 200
        assert record.cost == 0.015
        assert record.extra_data == {"model": "gpt-4"}

    @pytest.mark.asyncio
    async def test_create_minimal(self, async_session, test_user):
        """Test creating a usage record with minimal data.

        Omitted token counts, cost and extra_data are expected to fall back
        to zero / empty defaults.
        """
        repo = UsageRepository(async_session)

        data = {
            "user_id": test_user.id,
            "engine_type": "deepseek",
            "query": "Minimal query",
        }

        record = await repo.create(data)

        assert record.id is not None
        assert record.user_id == test_user.id
        assert record.engine_type == "deepseek"
        assert record.query == "Minimal query"
        assert record.input_tokens == 0
        assert record.output_tokens == 0
        assert record.cost == 0.0
        assert record.extra_data == {}

    @pytest.mark.asyncio
    async def test_get_summary(self, async_session, test_user):
        """Test getting usage summary."""
        repo = UsageRepository(async_session)

        for i in range(3):
            await repo.create({
                "user_id": test_user.id,
                "engine_type": "chatgpt",
                "query": f"Query {i}",
                "input_tokens": 100,
                "output_tokens": 200,
                "cost": 0.01,
            })

        summary = await repo.get_summary(str(test_user.id), period="month")

        assert summary["period"] == "month"
        assert summary["total_queries"] == 3
        assert summary["total_input_tokens"] == 300
        assert summary["total_output_tokens"] == 600
        # 3 * 0.01 is a float sum; exact equality depends on summation order.
        assert summary["total_cost"] == pytest.approx(0.03)
        assert "chatgpt" in summary["by_engine"]
        assert summary["by_engine"]["chatgpt"]["queries"] == 3

    @pytest.mark.asyncio
    async def test_get_summary_by_engine(self, async_session, test_user):
        """Test getting usage summary grouped by engine."""
        repo = UsageRepository(async_session)

        await repo.create({
            "user_id": test_user.id,
            "engine_type": "chatgpt",
            "query": "ChatGPT query",
            "cost": 0.02,
        })
        await repo.create({
            "user_id": test_user.id,
            "engine_type": "deepseek",
            "query": "DeepSeek query",
            "cost": 0.01,
        })

        summary = await repo.get_summary(str(test_user.id), period="month")

        assert len(summary["by_engine"]) == 2
        assert summary["by_engine"]["chatgpt"]["queries"] == 1
        assert summary["by_engine"]["chatgpt"]["cost"] == pytest.approx(0.02)
        assert summary["by_engine"]["deepseek"]["queries"] == 1
        assert summary["by_engine"]["deepseek"]["cost"] == pytest.approx(0.01)

    @pytest.mark.asyncio
    async def test_get_summary_by_day(self, async_session, test_user):
        """Test getting usage summary grouped by day.

        The current UTC date is captured both before the records are created
        and after the summary is fetched. The test then accepts either date
        as a bucket key, so it does not fail spuriously when it runs across
        midnight UTC.
        """
        repo = UsageRepository(async_session)

        day_before = datetime.now(timezone.utc).strftime("%Y-%m-%d")

        for i in range(2):
            await repo.create({
                "user_id": test_user.id,
                "engine_type": "qwen",
                "query": f"Query {i}",
                "cost": 0.01,
            })

        summary = await repo.get_summary(str(test_user.id), period="month")

        day_after = datetime.now(timezone.utc).strftime("%Y-%m-%d")

        by_day = summary["by_day"]
        assert len(by_day) >= 1
        # Every bucket must be a day on which the records could have been
        # written. On a normal (same-day) run this means exactly "today".
        assert set(by_day) <= {day_before, day_after}
        assert sum(bucket["queries"] for bucket in by_day.values()) == 2

    @pytest.mark.asyncio
    async def test_get_summary_with_brand_filter(self, async_session, test_user):
        """Test getting usage summary filtered by brand."""
        repo = UsageRepository(async_session)
        brand_id = uuid.uuid4()

        await repo.create({
            "user_id": test_user.id,
            "brand_id": brand_id,
            "engine_type": "kimi",
            "query": "Brand query",
            "cost": 0.05,
        })
        await repo.create({
            "user_id": test_user.id,
            "engine_type": "kimi",
            "query": "No brand query",
            "cost": 0.03,
        })

        summary = await repo.get_summary(
            str(test_user.id),
            period="month",
            brand_id=str(brand_id),
        )

        assert summary["total_queries"] == 1
        assert summary["total_cost"] == pytest.approx(0.05)

    @pytest.mark.asyncio
    async def test_check_quota(self, async_session, test_user):
        """Test checking quota usage."""
        repo = UsageRepository(async_session)

        for i in range(5):
            await repo.create({
                "user_id": test_user.id,
                "engine_type": "gemini",
                "query": f"Query {i}",
                "cost": 1.0,
            })

        result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)

        assert result["used"] == pytest.approx(5.0)
        assert result["limit"] == 100.0
        # Percentages come from float division; compare approximately.
        assert result["usage_percentage"] == pytest.approx(5.0)
        assert result["status"] == "ok"

    @pytest.mark.asyncio
    async def test_check_quota_warning(self, async_session, test_user):
        """Test quota warning status."""
        repo = UsageRepository(async_session)

        await repo.create({
            "user_id": test_user.id,
            "engine_type": "chatgpt",
            "query": "Expensive query",
            "cost": 85.0,
        })

        result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)

        assert result["status"] == "warning"
        assert result["usage_percentage"] == pytest.approx(85.0)

    @pytest.mark.asyncio
    async def test_check_quota_exceeded(self, async_session, test_user):
        """Test quota exceeded status."""
        repo = UsageRepository(async_session)

        await repo.create({
            "user_id": test_user.id,
            "engine_type": "chatgpt",
            "query": "Very expensive query",
            "cost": 120.0,
        })

        result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)

        assert result["status"] == "exceeded"
        assert result["usage_percentage"] == pytest.approx(120.0)

    @pytest.mark.asyncio
    async def test_get_by_id(self, async_session, test_user):
        """Test getting a usage record by ID."""
        repo = UsageRepository(async_session)

        created = await repo.create({
            "user_id": test_user.id,
            "engine_type": "wenxin",
            "query": "Get by ID test",
            "cost": 0.5,
        })

        fetched = await repo.get_by_id(created.id)

        assert fetched is not None
        assert fetched.id == created.id
        assert fetched.engine_type == "wenxin"

    @pytest.mark.asyncio
    async def test_get_by_user(self, async_session, test_user):
        """Test getting usage records by user."""
        repo = UsageRepository(async_session)

        for i in range(3):
            await repo.create({
                "user_id": test_user.id,
                "engine_type": "doubao",
                "query": f"User query {i}",
                "cost": 0.1,
            })

        records = await repo.get_by_user(str(test_user.id))

        assert len(records) == 3

    @pytest.mark.asyncio
    async def test_get_by_user_and_engine(self, async_session, test_user):
        """Test getting usage records by user and engine."""
        repo = UsageRepository(async_session)

        await repo.create({
            "user_id": test_user.id,
            "engine_type": "xinghuo",
            "query": "Xinghuo query",
            "cost": 0.1,
        })
        await repo.create({
            "user_id": test_user.id,
            "engine_type": "perplexity",
            "query": "Perplexity query",
            "cost": 0.2,
        })

        records = await repo.get_by_user_and_engine(
            str(test_user.id),
            "xinghuo",
        )

        assert len(records) == 1
        assert records[0].engine_type == "xinghuo"

    @pytest.mark.asyncio
    async def test_empty_summary(self, async_session, test_user):
        """Test getting summary when no records exist."""
        repo = UsageRepository(async_session)

        summary = await repo.get_summary(str(test_user.id), period="month")

        assert summary["total_queries"] == 0
        assert summary["total_input_tokens"] == 0
        assert summary["total_output_tokens"] == 0
        assert summary["total_cost"] == 0.0
        assert summary["by_engine"] == {}
        assert summary["by_day"] == {}

    @pytest.mark.asyncio
    async def test_empty_quota_check(self, async_session, test_user):
        """Test checking quota when no records exist."""
        repo = UsageRepository(async_session)

        result = await repo.check_quota(str(test_user.id), monthly_limit=100.0)

        assert result["used"] == 0.0
        assert result["usage_percentage"] == 0.0
        assert result["status"] == "ok"

    @pytest.mark.asyncio
    async def test_uuid_handling(self, async_session, test_user):
        """Test that UUID handling works correctly.

        ``get_summary`` is deliberately called with a ``uuid.UUID`` rather
        than a string here, unlike the other tests.
        """
        repo = UsageRepository(async_session)

        data = {
            "user_id": test_user.id,
            "engine_type": "yuanbao",
            "query": "UUID test",
            "cost": 0.5,
        }

        record = await repo.create(data)
        summary = await repo.get_summary(test_user.id, period="month")

        assert summary["total_queries"] == 1
        assert record.user_id == test_user.id
|