geo/backend/app/repositories/usage_repository.py

176 lines
5.7 KiB
Python

import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import select, func, and_, case
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.usage_record import UsageRecord
# Look-back length, in days, for each reporting period name.
# A value of 0 is special-cased by _compute_cutoff: it selects the start of
# the current day (midnight) rather than subtracting a number of days.
_PERIOD_CUTOFF_DAYS = {"day": 0, "week": 7, "month": 30}
# Usage percentage at or above which _quota_status returns "warning".
_QUOTA_WARNING_PCT = 80.0
# Usage percentage at or above which _quota_status returns "exceeded".
_QUOTA_EXCEEDED_PCT = 100.0
def _compute_cutoff(period: str, now: datetime) -> datetime:
    """Return the earliest timestamp that still belongs to *period*.

    The look-back length is read from ``_PERIOD_CUTOFF_DAYS``; a period name
    that is not listed there is treated as 30 days.

    Args:
        period: Reporting period name (e.g. ``"day"``, ``"week"``, ``"month"``).
        now: Reference instant the window is measured back from.

    Returns:
        ``now`` minus the look-back length, or -- when the length is 0 --
        ``now`` truncated to 00:00:00.000000 of the same day.
    """
    window_days = _PERIOD_CUTOFF_DAYS.get(period, 30)
    if window_days:
        return now - timedelta(days=window_days)
    # A zero-day window means "today so far": snap back to midnight.
    return now.replace(hour=0, minute=0, second=0, microsecond=0)
def _quota_status(usage_pct: float) -> str:
    """Classify a usage percentage as ``"exceeded"``, ``"warning"`` or ``"ok"``.

    Thresholds are inclusive and checked from the highest down, so the first
    one that *usage_pct* reaches decides the label.
    """
    thresholds = (
        (_QUOTA_EXCEEDED_PCT, "exceeded"),
        (_QUOTA_WARNING_PCT, "warning"),
    )
    for floor, label in thresholds:
        if usage_pct >= floor:
            return label
    return "ok"
class UsageRepository:
    """Create, look up and aggregate ``UsageRecord`` rows via an async session."""

    def __init__(self, session: AsyncSession):
        """Keep *session*; every method issues its statements through it."""
        self.session = session

    @staticmethod
    def _as_uuid(value: str | uuid.UUID) -> uuid.UUID:
        """Return *value* as a ``uuid.UUID``, parsing it when it is a string.

        Raises:
            ValueError: If a string is given that is not a valid UUID.
        """
        if isinstance(value, str):
            return uuid.UUID(value)
        return value

    @staticmethod
    def _summarize(records) -> dict:
        """Aggregate usage records into totals plus per-engine/per-day buckets.

        Args:
            records: Sized, re-iterable collection of objects exposing
                ``engine_type``, ``timestamp``, ``input_tokens``,
                ``output_tokens`` and ``cost`` attributes.

        Returns:
            Dict with ``total_queries``, ``total_input_tokens``,
            ``total_output_tokens``, ``total_cost``, ``by_engine`` and
            ``by_day``. Each bucket in the last two holds ``queries``,
            ``input_tokens``, ``output_tokens`` and ``cost``; ``by_day`` is
            keyed by the record timestamp formatted as ``YYYY-MM-DD``.
        """

        def empty_bucket() -> dict:
            return {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0}

        by_engine: dict[str, dict] = {}
        by_day: dict[str, dict] = {}
        for record in records:
            day_key = record.timestamp.strftime("%Y-%m-%d")
            # Every record contributes once to its engine bucket and once to
            # its day bucket; the update is identical for both.
            for bucket in (
                by_engine.setdefault(record.engine_type, empty_bucket()),
                by_day.setdefault(day_key, empty_bucket()),
            ):
                bucket["queries"] += 1
                bucket["input_tokens"] += record.input_tokens
                bucket["output_tokens"] += record.output_tokens
                # Accumulate the raw cost. Rounding after every record would
                # discard any single cost below 0.00005 and make the
                # breakdowns drift away from total_cost.
                bucket["cost"] += record.cost

        # Round once, after accumulation, exactly like total_cost below.
        for bucket in (*by_engine.values(), *by_day.values()):
            bucket["cost"] = round(bucket["cost"], 4)

        return {
            "total_queries": len(records),
            "total_input_tokens": sum(r.input_tokens for r in records),
            "total_output_tokens": sum(r.output_tokens for r in records),
            "total_cost": round(sum(r.cost for r in records), 4),
            "by_engine": by_engine,
            "by_day": by_day,
        }

    async def create(self, data: dict) -> UsageRecord:
        """Insert a usage record built from *data*, commit, and return it.

        Args:
            data: Must contain ``user_id``, ``engine_type`` and ``query``.
                Optional keys and their fallbacks: ``brand_id`` (``None``),
                ``input_tokens`` / ``output_tokens`` (``0``), ``cost``
                (``0.0``), ``extra_data`` (``{}``) and ``timestamp`` (the
                current UTC time).

        Returns:
            The persisted record, refreshed from the database after commit.

        Raises:
            KeyError: If a required key is missing from *data*.
        """
        record = UsageRecord(
            user_id=data["user_id"],
            brand_id=data.get("brand_id"),
            engine_type=data["engine_type"],
            query=data["query"],
            input_tokens=data.get("input_tokens", 0),
            output_tokens=data.get("output_tokens", 0),
            cost=data.get("cost", 0.0),
            extra_data=data.get("extra_data", {}),
            timestamp=data.get("timestamp", datetime.now(timezone.utc)),
        )
        self.session.add(record)
        await self.session.commit()
        await self.session.refresh(record)
        return record

    async def get_summary(
        self,
        user_id: str | uuid.UUID,
        period: str = "month",
        brand_id: str | uuid.UUID | None = None,
    ) -> dict:
        """Summarize a user's usage from the period cutoff until now.

        Args:
            user_id: Owner of the records; a string is parsed as a UUID.
            period: Period name handed to ``_compute_cutoff`` to derive the
                start of the window.
            brand_id: When truthy, restrict the summary to this brand; a
                string is parsed as a UUID. Falsy values apply no filter.

        Returns:
            Dict with ``period``, ``start_date`` and ``end_date`` (ISO-8601
            strings), followed by the totals and the ``by_engine`` /
            ``by_day`` breakdowns. Costs are rounded to 4 decimals.

        Raises:
            ValueError: If *user_id* or *brand_id* is a malformed UUID string.
        """
        now = datetime.now(timezone.utc)
        cutoff = _compute_cutoff(period, now)

        conditions = [
            UsageRecord.user_id == self._as_uuid(user_id),
            UsageRecord.timestamp >= cutoff,
        ]
        if brand_id:
            conditions.append(UsageRecord.brand_id == self._as_uuid(brand_id))

        result = await self.session.execute(
            select(UsageRecord).where(and_(*conditions))
        )
        records = list(result.scalars().all())

        return {
            "period": period,
            "start_date": cutoff.isoformat(),
            "end_date": now.isoformat(),
            **self._summarize(records),
        }

    async def check_quota(
        self,
        user_id: str | uuid.UUID,
        monthly_limit: float = 100.0,
    ) -> dict:
        """Compare the user's ``"month"`` summary cost against *monthly_limit*.

        Args:
            user_id: Owner of the records; a string is parsed as a UUID.
            monthly_limit: Cost budget. A non-positive limit is reported as
                0% usage instead of dividing by it.

        Returns:
            Dict with ``used``, ``limit``, ``usage_percentage`` (rounded to
            one decimal) and ``status`` as returned by ``_quota_status``.
        """
        summary = await self.get_summary(user_id=user_id, period="month")
        used = summary["total_cost"]
        usage_pct = (used / monthly_limit * 100) if monthly_limit > 0 else 0
        return {
            "used": used,
            "limit": monthly_limit,
            "usage_percentage": round(usage_pct, 1),
            # The status is derived from the unrounded percentage.
            "status": _quota_status(usage_pct),
        }

    async def get_by_id(self, record_id: uuid.UUID) -> UsageRecord | None:
        """Return the record whose id equals *record_id*, or ``None``."""
        result = await self.session.execute(
            select(UsageRecord).where(UsageRecord.id == record_id)
        )
        return result.scalar_one_or_none()

    async def get_by_user(
        self,
        user_id: str | uuid.UUID,
        limit: int = 100,
        offset: int = 0,
    ) -> list[UsageRecord]:
        """Return one page of a user's records, newest timestamp first.

        Args:
            user_id: Owner of the records; a string is parsed as a UUID.
            limit: Maximum number of rows to return.
            offset: Number of leading rows to skip.
        """
        result = await self.session.execute(
            select(UsageRecord)
            .where(UsageRecord.user_id == self._as_uuid(user_id))
            .order_by(UsageRecord.timestamp.desc())
            .limit(limit)
            .offset(offset)
        )
        return list(result.scalars().all())

    async def get_by_user_and_engine(
        self,
        user_id: str | uuid.UUID,
        engine_type: str,
        limit: int = 100,
        offset: int = 0,
    ) -> list[UsageRecord]:
        """Return one page of a user's records for one engine, newest first.

        Args:
            user_id: Owner of the records; a string is parsed as a UUID.
            engine_type: Only records with this engine type are returned.
            limit: Maximum number of rows to return.
            offset: Number of leading rows to skip (default 0), mirroring
                ``get_by_user``.
        """
        result = await self.session.execute(
            select(UsageRecord)
            .where(
                and_(
                    UsageRecord.user_id == self._as_uuid(user_id),
                    UsageRecord.engine_type == engine_type,
                )
            )
            .order_by(UsageRecord.timestamp.desc())
            .limit(limit)
            .offset(offset)
        )
        return list(result.scalars().all())