# Listing metadata: 176 lines, 5.7 KiB, Python.
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any

from sqlalchemy import select, func, and_, case
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.usage_record import UsageRecord


# Look-back window, in days, for each reporting period understood by
# _compute_cutoff.  0 is special-cased there: it means "since 00:00 of the
# current day", not "zero days ago".
_PERIOD_CUTOFF_DAYS: dict[str, int] = {"day": 0, "week": 7, "month": 30}

# Usage percentages at which _quota_status moves from "ok" to "warning" and
# from "warning" to "exceeded".  Both comparisons are inclusive (>=).
_QUOTA_WARNING_PCT = 80.0
_QUOTA_EXCEEDED_PCT = 100.0


def _compute_cutoff(period: str, now: datetime) -> datetime:
    """Return the earliest timestamp that falls inside *period*, relative to *now*.

    Args:
        period: A key of ``_PERIOD_CUTOFF_DAYS``.  Unrecognised names do not
            raise; they fall back to a 30-day window.
        now: The reference instant.  Its tzinfo (if any) is carried over to
            the result.

    Returns:
        For a period mapped to 0 days, midnight at the start of ``now``'s
        calendar day.  Otherwise ``now`` minus the mapped number of days,
        i.e. a rolling window that ends at ``now``.
    """
    days = _PERIOD_CUTOFF_DAYS.get(period, 30)
    if days == 0:
        # "Today so far": truncate to the start of the day instead of
        # subtracting an empty timedelta.
        return now.replace(hour=0, minute=0, second=0, microsecond=0)
    return now - timedelta(days=days)


def _quota_status(usage_pct: float) -> str:
    """Classify a usage percentage as ``"ok"``, ``"warning"`` or ``"exceeded"``.

    Both thresholds are inclusive: a value equal to ``_QUOTA_WARNING_PCT`` is
    already ``"warning"`` and a value equal to ``_QUOTA_EXCEEDED_PCT`` is
    already ``"exceeded"``.
    """
    # Test the higher threshold first so "exceeded" takes precedence.
    if usage_pct >= _QUOTA_EXCEEDED_PCT:
        return "exceeded"
    if usage_pct >= _QUOTA_WARNING_PCT:
        return "warning"
    return "ok"


class UsageRepository:
    """Create and query ``UsageRecord`` rows through an async SQLAlchemy session."""

    def __init__(self, session: AsyncSession):
        """Keep *session*; every method issues its statements through it."""
        self.session = session

    @staticmethod
    def _as_uuid(value: str | uuid.UUID) -> uuid.UUID:
        """Return *value* as a ``uuid.UUID``, parsing it when given as text.

        Raises:
            ValueError: if *value* is a string that is not a valid UUID.
        """
        return uuid.UUID(value) if isinstance(value, str) else value

    @staticmethod
    def _new_bucket() -> dict[str, Any]:
        """Return a zeroed aggregation bucket for one engine or one day."""
        return {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0}

    async def create(self, data: dict[str, Any]) -> UsageRecord:
        """Insert one usage record, commit, and return the refreshed instance.

        Args:
            data: Field values for the new row.  ``user_id``, ``engine_type``
                and ``query`` are required.  ``brand_id`` defaults to
                ``None``, the token counts to ``0``, ``cost`` to ``0.0``,
                ``extra_data`` to an empty dict and ``timestamp`` to the
                current UTC time.

        Raises:
            KeyError: if one of the required keys is missing from *data*.
        """
        record = UsageRecord(
            user_id=data["user_id"],
            brand_id=data.get("brand_id"),
            engine_type=data["engine_type"],
            query=data["query"],
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
            cost=data.get("cost", 0.0),
            extra_data=data.get("extra_data", {}),
            timestamp=data.get("timestamp", datetime.now(timezone.utc)),
        )
        self.session.add(record)
        await self.session.commit()
        # Reload the instance after the commit before handing it back.
        await self.session.refresh(record)
        return record

    async def get_summary(
        self,
        user_id: str | uuid.UUID,
        period: str = "month",
        brand_id: str | uuid.UUID | None = None,
    ) -> dict[str, Any]:
        """Aggregate a user's usage from the period cutoff up to now.

        Args:
            user_id: Owner of the records; a string is parsed as a UUID.
            period: Period name handed to ``_compute_cutoff`` to obtain the
                start of the window.
            brand_id: Optional extra filter; a falsy value (``None`` or an
                empty string) disables it.

        Returns:
            A dict with ``period``, ISO-formatted ``start_date`` and
            ``end_date``, the overall ``total_queries``,
            ``total_input_tokens``, ``total_output_tokens`` and
            ``total_cost`` (rounded to 4 decimals), plus ``by_engine`` and
            ``by_day`` breakdowns.  Each breakdown maps a key (engine type,
            or ``YYYY-MM-DD`` of the record timestamp) to a bucket with
            ``queries``, ``input_tokens``, ``output_tokens`` and ``cost``.

        Raises:
            ValueError: if *user_id* or *brand_id* is a string that is not a
                valid UUID.
        """
        now = datetime.now(timezone.utc)
        cutoff = _compute_cutoff(period, now)

        user_id = self._as_uuid(user_id)
        if brand_id:
            brand_id = self._as_uuid(brand_id)

        conditions = [
            UsageRecord.user_id == user_id,
            UsageRecord.timestamp >= cutoff,
        ]
        if brand_id:
            conditions.append(UsageRecord.brand_id == brand_id)

        result = await self.session.execute(
            select(UsageRecord).where(and_(*conditions))
        )
        records = list(result.scalars().all())

        total_queries = len(records)
        total_input_tokens = sum(r.input_tokens for r in records)
        total_output_tokens = sum(r.output_tokens for r in records)
        total_cost = round(sum(r.cost for r in records), 4)

        by_engine: dict[str, dict[str, Any]] = {}
        by_day: dict[str, dict[str, Any]] = {}
        for record in records:
            day_key = record.timestamp.strftime("%Y-%m-%d")
            # Each record feeds exactly one engine bucket and one day bucket.
            for bucket in (
                by_engine.setdefault(record.engine_type, self._new_bucket()),
                by_day.setdefault(day_key, self._new_bucket()),
            ):
                bucket["queries"] += 1
                bucket["input_tokens"] += record.input_tokens
                bucket["output_tokens"] += record.output_tokens
                # Accumulate the raw cost.  Rounding after every record would
                # discard each contribution below 0.00005, so a bucket made
                # of many tiny costs would stay at 0.0.
                bucket["cost"] += record.cost

        # Round once per bucket, matching how total_cost is rounded.
        for bucket in (*by_engine.values(), *by_day.values()):
            bucket["cost"] = round(bucket["cost"], 4)

        return {
            "period": period,
            "start_date": cutoff.isoformat(),
            "end_date": now.isoformat(),
            "total_queries": total_queries,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_cost": total_cost,
            "by_engine": by_engine,
            "by_day": by_day,
        }

    async def check_quota(
        self,
        user_id: str | uuid.UUID,
        monthly_limit: float = 100.0,
    ) -> dict[str, Any]:
        """Compare the user's ``"month"`` summary cost against *monthly_limit*.

        Args:
            user_id: Owner of the records; a string is parsed as a UUID.
            monthly_limit: Cost ceiling.  A value that is not positive makes
                the usage percentage 0, so the status is ``"ok"``.

        Returns:
            A dict with ``used`` (the summary's ``total_cost``), ``limit``,
            ``usage_percentage`` (rounded to 1 decimal) and ``status`` as
            returned by ``_quota_status``.  The status is derived from the
            unrounded percentage.
        """
        summary = await self.get_summary(user_id=user_id, period="month")
        used = summary["total_cost"]
        usage_pct = (used / monthly_limit * 100) if monthly_limit > 0 else 0
        return {
            "used": used,
            "limit": monthly_limit,
            "usage_percentage": round(usage_pct, 1),
            "status": _quota_status(usage_pct),
        }

    async def get_by_id(self, record_id: uuid.UUID) -> UsageRecord | None:
        """Return the record whose ``id`` equals *record_id*, or ``None``."""
        result = await self.session.execute(
            select(UsageRecord).where(UsageRecord.id == record_id)
        )
        return result.scalar_one_or_none()

    async def get_by_user(
        self,
        user_id: str | uuid.UUID,
        limit: int = 100,
        offset: int = 0,
    ) -> list[UsageRecord]:
        """Return one page of a user's records, newest timestamp first.

        Args:
            user_id: Owner of the records; a string is parsed as a UUID.
            limit: Maximum number of rows to return.
            offset: Number of leading rows to skip.

        Raises:
            ValueError: if *user_id* is a string that is not a valid UUID.
        """
        user_id = self._as_uuid(user_id)

        result = await self.session.execute(
            select(UsageRecord)
            .where(UsageRecord.user_id == user_id)
            .order_by(UsageRecord.timestamp.desc())
            .limit(limit)
            .offset(offset)
        )
        return list(result.scalars().all())

    async def get_by_user_and_engine(
        self,
        user_id: str | uuid.UUID,
        engine_type: str,
        limit: int = 100,
        offset: int = 0,
    ) -> list[UsageRecord]:
        """Return one page of a user's records for one engine, newest first.

        Args:
            user_id: Owner of the records; a string is parsed as a UUID.
            engine_type: Only rows with this ``engine_type`` are returned.
            limit: Maximum number of rows to return.
            offset: Number of leading rows to skip; defaults to 0 so existing
                callers get the same rows as before.

        Raises:
            ValueError: if *user_id* is a string that is not a valid UUID.
        """
        user_id = self._as_uuid(user_id)

        result = await self.session.execute(
            select(UsageRecord)
            .where(
                and_(
                    UsageRecord.user_id == user_id,
                    UsageRecord.engine_type == engine_type,
                )
            )
            .order_by(UsageRecord.timestamp.desc())
            .limit(limit)
            .offset(offset)
        )
        return list(result.scalars().all())