202 lines
6.3 KiB
Python
202 lines
6.3 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.repositories.usage_repository import UsageRepository
|
|
from app.models.usage_record import UsageRecord as UsageRecordModel
|
|
|
|
# Usage percentage at or above which _quota_status reports "warning".
_QUOTA_WARNING_PCT = 80.0
# Usage percentage at or above which _quota_status reports "exceeded".
_QUOTA_EXCEEDED_PCT = 100.0

# Look-back window, in days, for each named period. A value of 0 is treated
# by _compute_cutoff as "since midnight of the current day".
_PERIOD_CUTOFF_DAYS = {
    "day": 0,
    "week": 7,
    "month": 30,
}
|
|
|
|
|
|
@dataclass
class UsageRecord:
    """One in-memory usage entry as created by ``UsageTracker.record``."""

    # Identifier string assigned by the tracker.
    id: str
    # Owner and brand the usage is attributed to.
    user_id: str
    brand_id: str
    # Name of the engine that served the query.
    engine_type: str
    # The query text that was submitted.
    query: str
    # Token counts and cost reported for this query.
    input_tokens: int
    output_tokens: int
    cost: float
    # Moment the record was created.
    timestamp: datetime
    # Free-form extra data; every instance gets its own dict by default.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class UsageSummary:
    """Aggregated usage figures for one reporting window."""

    # Name of the requested period (e.g. "day", "week", "month").
    period: str
    # Window boundaries as strings; UsageTracker.get_summary fills them with
    # ISO-8601 timestamps.
    start_date: str
    end_date: str
    # Totals across every record in the window.
    total_queries: int
    total_input_tokens: int
    total_output_tokens: int
    total_cost: float
    # Per-engine and per-day breakdowns, each keyed by engine name / day string.
    by_engine: dict[str, dict[str, Any]]
    by_day: dict[str, dict[str, Any]]
|
|
|
|
|
|
def _aggregate_by_engine(records: list[UsageRecord]) -> dict[str, dict[str, Any]]:
    """Group *records* by ``engine_type`` and total their counters.

    Returns a mapping from engine name to a dict with the keys ``queries``,
    ``input_tokens``, ``output_tokens`` and ``cost``. Engines appear in the
    order they are first seen.
    """
    totals: dict[str, dict[str, Any]] = {}
    for rec in records:
        engine = rec.engine_type
        if engine not in totals:
            # First record for this engine: start from zeroed counters.
            totals[engine] = {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0}
        stats = totals[engine]
        stats["queries"] = stats["queries"] + 1
        stats["input_tokens"] = stats["input_tokens"] + rec.input_tokens
        stats["output_tokens"] = stats["output_tokens"] + rec.output_tokens
        stats["cost"] = stats["cost"] + rec.cost
    return totals
|
|
|
|
|
|
def _aggregate_by_day(records: list[UsageRecord]) -> dict[str, dict[str, Any]]:
    """Group *records* by calendar day of their timestamp and total their counters.

    Keys are the timestamp formatted as ``YYYY-MM-DD``; each value is a dict
    with the keys ``queries``, ``input_tokens``, ``output_tokens`` and
    ``cost``. Days appear in the order they are first seen.
    """
    totals: dict[str, dict[str, Any]] = {}
    for rec in records:
        day = rec.timestamp.strftime("%Y-%m-%d")
        if day not in totals:
            # First record for this day: start from zeroed counters.
            totals[day] = {"queries": 0, "input_tokens": 0, "output_tokens": 0, "cost": 0.0}
        stats = totals[day]
        stats["queries"] = stats["queries"] + 1
        stats["input_tokens"] = stats["input_tokens"] + rec.input_tokens
        stats["output_tokens"] = stats["output_tokens"] + rec.output_tokens
        stats["cost"] = stats["cost"] + rec.cost
    return totals
|
|
|
|
|
|
def _compute_cutoff(period: str, now: datetime) -> datetime:
    """Return the earliest timestamp that belongs to *period* ending at *now*.

    The look-back length comes from ``_PERIOD_CUTOFF_DAYS``; a period name
    that is not in that table falls back to 30 days. A length of 0 means
    "today", in which case the cutoff is *now* truncated to midnight.
    """
    lookback_days = _PERIOD_CUTOFF_DAYS.get(period, 30)
    if lookback_days != 0:
        # Rolling window: step back a fixed number of days from now.
        return now - timedelta(days=lookback_days)
    # Zero-day window: start of the current day, keeping now's tzinfo.
    return now.replace(hour=0, minute=0, second=0, microsecond=0)
|
|
|
|
|
|
def _quota_status(usage_pct: float) -> str:
    """Classify a usage percentage as ``"ok"``, ``"warning"`` or ``"exceeded"``.

    The thresholds are ``_QUOTA_WARNING_PCT`` and ``_QUOTA_EXCEEDED_PCT``;
    both comparisons are inclusive.
    """
    status = "ok"
    if usage_pct >= _QUOTA_WARNING_PCT:
        status = "warning"
    # Checked last so that it always takes precedence over "warning".
    if usage_pct >= _QUOTA_EXCEEDED_PCT:
        status = "exceeded"
    return status
|
|
|
|
|
|
class UsageTracker:
|
|
def __init__(self, session: AsyncSession | None = None) -> None:
|
|
self._records: list[UsageRecord] = []
|
|
self._session = session
|
|
self._repository = UsageRepository(session) if session else None
|
|
|
|
def record(
|
|
self,
|
|
user_id: str,
|
|
brand_id: str,
|
|
engine_type: str,
|
|
query: str,
|
|
input_tokens: int,
|
|
output_tokens: int,
|
|
cost: float,
|
|
metadata: dict | None = None,
|
|
) -> UsageRecord:
|
|
rec = UsageRecord(
|
|
id=f"usage_{len(self._records) + 1}",
|
|
user_id=user_id,
|
|
brand_id=brand_id,
|
|
engine_type=engine_type,
|
|
query=query,
|
|
input_tokens=input_tokens,
|
|
output_tokens=output_tokens,
|
|
cost=cost,
|
|
timestamp=datetime.now(UTC),
|
|
metadata=metadata or {},
|
|
)
|
|
self._records.append(rec)
|
|
return rec
|
|
|
|
def get_summary(
|
|
self,
|
|
user_id: str | None = None,
|
|
period: str = "month",
|
|
brand_id: str | None = None,
|
|
) -> UsageSummary:
|
|
now = datetime.now(UTC)
|
|
filtered = list(self._records)
|
|
|
|
if user_id:
|
|
filtered = [r for r in filtered if r.user_id == user_id]
|
|
if brand_id:
|
|
filtered = [r for r in filtered if r.brand_id == brand_id]
|
|
|
|
cutoff = _compute_cutoff(period, now)
|
|
filtered = [r for r in filtered if r.timestamp >= cutoff]
|
|
|
|
return UsageSummary(
|
|
period=period,
|
|
start_date=cutoff.isoformat(),
|
|
end_date=now.isoformat(),
|
|
total_queries=len(filtered),
|
|
total_input_tokens=sum(r.input_tokens for r in filtered),
|
|
total_output_tokens=sum(r.output_tokens for r in filtered),
|
|
total_cost=round(sum(r.cost for r in filtered), 4),
|
|
by_engine=_aggregate_by_engine(filtered),
|
|
by_day=_aggregate_by_day(filtered),
|
|
)
|
|
|
|
def check_quota(self, user_id: str, monthly_limit: float = 100.0) -> dict:
|
|
summary = self.get_summary(user_id=user_id, period="month")
|
|
usage_pct = (summary.total_cost / monthly_limit * 100) if monthly_limit > 0 else 0
|
|
return {
|
|
"used": summary.total_cost,
|
|
"limit": monthly_limit,
|
|
"usage_percentage": round(usage_pct, 1),
|
|
"status": _quota_status(usage_pct),
|
|
}
|
|
|
|
async def record_async(
|
|
self,
|
|
user_id: str | uuid.UUID,
|
|
brand_id: str | uuid.UUID | None,
|
|
engine_type: str,
|
|
query: str,
|
|
input_tokens: int,
|
|
output_tokens: int,
|
|
cost: float,
|
|
extra_data: dict | None = None,
|
|
) -> UsageRecordModel:
|
|
if not self._repository:
|
|
raise RuntimeError("UsageTracker not initialized with AsyncSession")
|
|
|
|
user_uuid = uuid.UUID(user_id) if isinstance(user_id, str) else user_id
|
|
brand_uuid = uuid.UUID(brand_id) if (brand_id and isinstance(brand_id, str)) else brand_id
|
|
|
|
data = {
|
|
"user_id": user_uuid,
|
|
"brand_id": brand_uuid,
|
|
"engine_type": engine_type,
|
|
"query": query,
|
|
"input_tokens": input_tokens,
|
|
"output_tokens": output_tokens,
|
|
"cost": cost,
|
|
"extra_data": extra_data or {},
|
|
}
|
|
return await self._repository.create(data)
|
|
|
|
async def get_summary_async(
|
|
self,
|
|
user_id: str | uuid.UUID,
|
|
period: str = "month",
|
|
brand_id: str | uuid.UUID | None = None,
|
|
) -> dict:
|
|
if not self._repository:
|
|
raise RuntimeError("UsageTracker not initialized with AsyncSession")
|
|
return await self._repository.get_summary(user_id, period, brand_id)
|
|
|
|
async def check_quota_async(
|
|
self,
|
|
user_id: str | uuid.UUID,
|
|
monthly_limit: float = 100.0,
|
|
) -> dict:
|
|
if not self._repository:
|
|
raise RuntimeError("UsageTracker not initialized with AsyncSession")
|
|
return await self._repository.check_quota(user_id, monthly_limit)
|