# geo/backend/app/services/usage_tracker.py
#
# Listing metadata: 202 lines, 6.3 KiB, Python.

from __future__ import annotations
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.repositories.usage_repository import UsageRepository
from app.models.usage_record import UsageRecord as UsageRecordModel
# Usage percentage at or above which _quota_status reports "warning".
_QUOTA_WARNING_PCT: float = 80.0
# Usage percentage at or above which _quota_status reports "exceeded".
_QUOTA_EXCEEDED_PCT: float = 100.0
# Look-back window, in days, for each named reporting period. The value 0 is
# treated specially by _compute_cutoff: it means "since midnight of the
# current day" rather than "zero days ago".
_PERIOD_CUTOFF_DAYS: dict[str, int] = {"day": 0, "week": 7, "month": 30}
@dataclass
class UsageRecord:
    """One in-memory usage entry: who ran which query on which engine, and what it consumed."""

    # Identity of the entry and of the user/brand it is attributed to.
    id: str
    user_id: str
    brand_id: str

    # The engine that served the request and the query text that was sent.
    engine_type: str
    query: str

    # Consumption figures for this single query.
    input_tokens: int
    output_tokens: int
    cost: float

    # Moment the entry was created.
    timestamp: datetime

    # Free-form extra information; every instance gets its own fresh dict by default.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class UsageSummary:
    """Aggregated usage figures for one reporting window."""

    # Name of the requested period and the window bounds as strings.
    period: str
    start_date: str
    end_date: str

    # Totals across every record that fell inside the window.
    total_queries: int
    total_input_tokens: int
    total_output_tokens: int
    total_cost: float

    # Per-group breakdowns; each inner dict holds that group's own totals.
    by_engine: dict[str, dict[str, Any]]
    by_day: dict[str, dict[str, Any]]
def _aggregate_by_engine(records: list[UsageRecord]) -> dict[str, dict[str, Any]]:
result: dict[str, dict[str, Any]] = {}
for r in records:
bucket = result.setdefault(r.engine_type, {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0})
bucket["queries"] += 1
bucket["input_tokens"] += r.input_tokens
bucket["output_tokens"] += r.output_tokens
bucket["cost"] += r.cost
return result
def _aggregate_by_day(records: list[UsageRecord]) -> dict[str, dict[str, Any]]:
result: dict[str, dict[str, Any]] = {}
for r in records:
day_key = r.timestamp.strftime("%Y-%m-%d")
bucket = result.setdefault(day_key, {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0})
bucket["queries"] += 1
bucket["input_tokens"] += r.input_tokens
bucket["output_tokens"] += r.output_tokens
bucket["cost"] += r.cost
return result
def _compute_cutoff(period: str, now: datetime) -> datetime:
    """Return the earliest timestamp that still belongs to ``period``.

    Args:
        period: Period name looked up in ``_PERIOD_CUTOFF_DAYS``. Names that
            are not in the table use a 30-day window.
        now: Reference instant that marks the end of the window.

    Returns:
        ``now`` minus the period's number of days, or, when the table maps
        the period to 0 days, ``now`` truncated to midnight of the same day.
    """
    window_days = _PERIOD_CUTOFF_DAYS.get(period, 30)
    if window_days != 0:
        return now - timedelta(days=window_days)
    # A zero-day window means "today so far": snap back to 00:00:00.000000.
    return now.replace(hour=0, minute=0, second=0, microsecond=0)
def _quota_status(usage_pct: float) -> str:
    """Classify a quota usage percentage as ``"exceeded"``, ``"warning"`` or ``"ok"``.

    Args:
        usage_pct: Share of the quota that has been consumed, in percent.

    Returns:
        ``"exceeded"`` when the value is at or above ``_QUOTA_EXCEEDED_PCT``,
        otherwise ``"warning"`` when it is at or above ``_QUOTA_WARNING_PCT``,
        otherwise ``"ok"``.
    """
    # Checked in this order so the "exceeded" threshold is always tested first.
    thresholds = (
        (_QUOTA_EXCEEDED_PCT, "exceeded"),
        (_QUOTA_WARNING_PCT, "warning"),
    )
    for threshold, label in thresholds:
        if usage_pct >= threshold:
            return label
    return "ok"
class UsageTracker:
def __init__(self, session: AsyncSession | None = None) -> None:
self._records: list[UsageRecord] = []
self._session = session
self._repository = UsageRepository(session) if session else None
def record(
self,
user_id: str,
brand_id: str,
engine_type: str,
query: str,
input_tokens: int,
output_tokens: int,
cost: float,
metadata: dict | None = None,
) -> UsageRecord:
rec = UsageRecord(
id=f"usage_{len(self._records) + 1}",
user_id=user_id,
brand_id=brand_id,
engine_type=engine_type,
query=query,
input_tokens=input_tokens,
output_tokens=output_tokens,
cost=cost,
timestamp=datetime.now(UTC),
metadata=metadata or {},
)
self._records.append(rec)
return rec
def get_summary(
self,
user_id: str | None = None,
period: str = "month",
brand_id: str | None = None,
) -> UsageSummary:
now = datetime.now(UTC)
filtered = list(self._records)
if user_id:
filtered = [r for r in filtered if r.user_id == user_id]
if brand_id:
filtered = [r for r in filtered if r.brand_id == brand_id]
cutoff = _compute_cutoff(period, now)
filtered = [r for r in filtered if r.timestamp >= cutoff]
return UsageSummary(
period=period,
start_date=cutoff.isoformat(),
end_date=now.isoformat(),
total_queries=len(filtered),
total_input_tokens=sum(r.input_tokens for r in filtered),
total_output_tokens=sum(r.output_tokens for r in filtered),
total_cost=round(sum(r.cost for r in filtered), 4),
by_engine=_aggregate_by_engine(filtered),
by_day=_aggregate_by_day(filtered),
)
def check_quota(self, user_id: str, monthly_limit: float = 100.0) -> dict:
summary = self.get_summary(user_id=user_id, period="month")
usage_pct = (summary.total_cost / monthly_limit * 100) if monthly_limit > 0 else 0
return {
"used": summary.total_cost,
"limit": monthly_limit,
"usage_percentage": round(usage_pct, 1),
"status": _quota_status(usage_pct),
}
async def record_async(
self,
user_id: str | uuid.UUID,
brand_id: str | uuid.UUID | None,
engine_type: str,
query: str,
input_tokens: int,
output_tokens: int,
cost: float,
extra_data: dict | None = None,
) -> UsageRecordModel:
if not self._repository:
raise RuntimeError("UsageTracker not initialized with AsyncSession")
user_uuid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id
brand_uuid = uuid.UUID(brand_id) if (brand_id and isinstance(brand_id, str)) else brand_id
data = {
"user_id": user_uuid,
"brand_id": brand_uuid,
"engine_type": engine_type,
"query": query,
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"cost": cost,
"extra_data": extra_data or {},
}
return await self._repository.create(data)
async def get_summary_async(
self,
user_id: str | uuid.UUID,
period: str = "month",
brand_id: str | uuid.UUID | None = None,
) -> dict:
if not self._repository:
raise RuntimeError("UsageTracker not initialized with AsyncSession")
return await self._repository.get_summary(user_id, period, brand_id)
async def check_quota_async(
self,
user_id: str | uuid.UUID,
monthly_limit: float = 100.0,
) -> dict:
if not self._repository:
raise RuntimeError("UsageTracker not initialized with AsyncSession")
return await self._repository.check_quota(user_id, monthly_limit)