feat(admin): U7 — usage dashboard + quota enforcement
UsageRecord extended with user_id + department_id (backward compatible). UsageStore Protocol extended: record() accepts user_id/department_id, get_usage() accepts filters, new get_usage_by_user/department methods. RedisUsageStore uses versioned keys (v2) for new records. LLMGateway.chat()/chat_stream() accept user_id, department_ids, db_path. Quota check before provider call: model whitelist + token limit + cost limit (daily). Multi-department uses strictest-wins (any exceed → reject). QuotaExceededError → 429 at route layer. UsageService: summary, timeseries, by-model, top-users, export (CSV/JSON). 5 new admin endpoints under /admin/usage/*. llm_gateway.py routes pass DepartmentContext + db_path to gateway, catch QuotaExceededError → 429 (JSON for /chat, SSE error for /stream). 84 new tests. 441 admin+usage tests pass, no regressions.
This commit is contained in:
parent
fd7f6816b8
commit
09feca3307
|
|
@ -3,6 +3,8 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
||||||
|
|
@ -15,6 +17,32 @@ from agentkit.telemetry.metrics import llm_token_histogram
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaExceededError(Exception):
    """Signal that an LLM request was rejected by a department quota check.

    The department, the kind of quota, the period, the configured limit and
    the observed value are kept as attributes so the API layer can build a
    structured 429 response from them.
    """

    def __init__(
        self,
        department_id: str,
        quota_type: str,
        period: str,
        limit: Any,
        current: Any,
    ) -> None:
        # Build the human-readable message first, then keep the raw fields.
        headline = f"Quota exceeded for department {department_id}: "
        detail = f"{quota_type} {period} (limit={limit}, current={current})"
        super().__init__(headline + detail)

        self.department_id, self.quota_type, self.period = department_id, quota_type, period
        self.limit, self.current = limit, current
|
||||||
|
|
||||||
|
|
||||||
class LLMGateway:
|
class LLMGateway:
|
||||||
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
|
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
|
||||||
|
|
||||||
|
|
@ -83,6 +111,9 @@ class LLMGateway:
|
||||||
tools: list[dict] | None = None,
|
tools: list[dict] | None = None,
|
||||||
tool_choice: str = "auto",
|
tool_choice: str = "auto",
|
||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
department_ids: list[str] | None = None,
|
||||||
|
db_path: Path | str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""发送 chat 请求,自动解析别名和 Fallback"""
|
"""发送 chat 请求,自动解析别名和 Fallback"""
|
||||||
|
|
@ -91,6 +122,12 @@ class LLMGateway:
|
||||||
if not self._providers:
|
if not self._providers:
|
||||||
raise LLMProviderError("", "No provider registered")
|
raise LLMProviderError("", "No provider registered")
|
||||||
|
|
||||||
|
# ── Quota enforcement ──
|
||||||
|
# Only enforce when department_ids + db_path are provided
|
||||||
|
# (other call sites pass None — no quota check).
|
||||||
|
if department_ids and db_path:
|
||||||
|
await self._enforce_quota(db_path, department_ids, resolved_model)
|
||||||
|
|
||||||
# Telemetry: start LLM span
|
# Telemetry: start LLM span
|
||||||
_span_cm = None
|
_span_cm = None
|
||||||
_span = None
|
_span = None
|
||||||
|
|
@ -131,12 +168,14 @@ class LLMGateway:
|
||||||
result = await self._cache.get(cache_key)
|
result = await self._cache.get(cache_key)
|
||||||
if result.hit:
|
if result.hit:
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
self._usage_tracker.record(
|
self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=result.response.model,
|
model=result.response.model,
|
||||||
usage=result.response.usage,
|
usage=result.response.usage,
|
||||||
cost=0.0,
|
cost=0.0,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_ids=department_ids,
|
||||||
)
|
)
|
||||||
if _span is not None:
|
if _span is not None:
|
||||||
_span.set_attribute("gen_ai.cache.hit", True)
|
_span.set_attribute("gen_ai.cache.hit", True)
|
||||||
|
|
@ -158,12 +197,14 @@ class LLMGateway:
|
||||||
result = await self._cache.semantic_search(query_embedding)
|
result = await self._cache.semantic_search(query_embedding)
|
||||||
if result.hit:
|
if result.hit:
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
self._usage_tracker.record(
|
self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=result.response.model,
|
model=result.response.model,
|
||||||
usage=result.response.usage,
|
usage=result.response.usage,
|
||||||
cost=0.0,
|
cost=0.0,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_ids=department_ids,
|
||||||
)
|
)
|
||||||
if _span is not None:
|
if _span is not None:
|
||||||
_span.set_attribute("gen_ai.cache.hit", True)
|
_span.set_attribute("gen_ai.cache.hit", True)
|
||||||
|
|
@ -204,12 +245,14 @@ class LLMGateway:
|
||||||
if response.usage:
|
if response.usage:
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
cost = self._calculate_cost(model_name, response.usage)
|
cost = self._calculate_cost(model_name, response.usage)
|
||||||
self._usage_tracker.record(
|
self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=model_name,
|
model=model_name,
|
||||||
usage=response.usage,
|
usage=response.usage,
|
||||||
cost=cost,
|
cost=cost,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_ids=department_ids,
|
||||||
)
|
)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Model '{model_name}' returned empty content with no tool_calls, "
|
f"Model '{model_name}' returned empty content with no tool_calls, "
|
||||||
|
|
@ -243,12 +286,14 @@ class LLMGateway:
|
||||||
cost = self._calculate_cost(response.model, response.usage)
|
cost = self._calculate_cost(response.model, response.usage)
|
||||||
|
|
||||||
# 记录使用量
|
# 记录使用量
|
||||||
self._usage_tracker.record(
|
self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=response.model,
|
model=response.model,
|
||||||
usage=response.usage,
|
usage=response.usage,
|
||||||
cost=cost,
|
cost=cost,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_ids=department_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Telemetry: record token usage and end span
|
# Telemetry: record token usage and end span
|
||||||
|
|
@ -278,6 +323,9 @@ class LLMGateway:
|
||||||
tools: list[dict] | None = None,
|
tools: list[dict] | None = None,
|
||||||
tool_choice: str = "auto",
|
tool_choice: str = "auto",
|
||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
department_ids: list[str] | None = None,
|
||||||
|
db_path: Path | str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Stream chat response with fallback support.
|
"""Stream chat response with fallback support.
|
||||||
|
|
@ -293,6 +341,10 @@ class LLMGateway:
|
||||||
if not self._providers:
|
if not self._providers:
|
||||||
raise LLMProviderError("", "No provider registered")
|
raise LLMProviderError("", "No provider registered")
|
||||||
|
|
||||||
|
# ── Quota enforcement ──
|
||||||
|
if department_ids and db_path:
|
||||||
|
await self._enforce_quota(db_path, department_ids, resolved_model)
|
||||||
|
|
||||||
models_to_try = self._get_models_to_try(resolved_model)
|
models_to_try = self._get_models_to_try(resolved_model)
|
||||||
last_error: Exception | None = None
|
last_error: Exception | None = None
|
||||||
|
|
||||||
|
|
@ -354,12 +406,14 @@ class LLMGateway:
|
||||||
if final_usage is None:
|
if final_usage is None:
|
||||||
final_usage = TokenUsage()
|
final_usage = TokenUsage()
|
||||||
cost = self._calculate_cost(final_model, final_usage)
|
cost = self._calculate_cost(final_model, final_usage)
|
||||||
self._usage_tracker.record(
|
self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=final_model,
|
model=final_model,
|
||||||
usage=final_usage,
|
usage=final_usage,
|
||||||
cost=cost,
|
cost=cost,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_ids=department_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Empty stream detection: if no content was produced,
|
# Empty stream detection: if no content was produced,
|
||||||
|
|
@ -453,3 +507,132 @@ class LLMGateway:
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Quota enforcement helpers (U7)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _record_usage(
    self,
    agent_name: str,
    model: str,
    usage: TokenUsage,
    cost: float,
    latency_ms: float,
    user_id: str | None,
    department_ids: list[str] | None,
) -> None:
    """Forward one usage event to the tracker with user/department attribution.

    Only the first entry of ``department_ids`` is recorded: attribution is
    per-department, so a caller whose user belongs to several departments
    decides which one is billed by putting it first. With no departments
    (``None`` or empty) the record carries ``department_id=None``.
    """
    billed_department = None
    if department_ids:
        billed_department = department_ids[0]

    self._usage_tracker.record(
        agent_name=agent_name,
        model=model,
        usage=usage,
        cost=cost,
        latency_ms=latency_ms,
        user_id=user_id,
        department_id=billed_department,
    )
|
||||||
|
|
||||||
|
async def _enforce_quota(
    self,
    db_path: Path | str,
    department_ids: list[str],
    resolved_model: str,
) -> None:
    """Reject the request if any department fails any quota check.

    For each department, in order: the model whitelist, then the daily
    token limit, then the daily cost limit. The first failing check raises
    :class:`QuotaExceededError` (strictest-wins across departments); if every
    check passes the method returns ``None``.
    """
    # Imported lazily to avoid a circular dependency (admin → ... → gateway).
    from agentkit.server.admin.quota_service import get_quota_service

    service = get_quota_service()
    database = Path(db_path)

    # Daily limits, evaluated in this order for every department:
    # (quota_type, coroutine returning the current value for that quota).
    daily_checks = (
        ("token_limit", self._get_current_usage_for_quota),
        ("cost_limit", self._get_current_cost_for_quota),
    )

    for department in department_ids:
        # 1. Model whitelist
        model_ok, _ = await service.is_model_allowed(database, department, resolved_model)
        if not model_ok:
            raise QuotaExceededError(
                department_id=department,
                quota_type="model_whitelist",
                period="",
                limit="",
                current=resolved_model,
            )

        # 2./3. Token limit, then cost limit (both daily)
        for quota_type, read_current in daily_checks:
            observed = await read_current(department, "daily")
            within_limit, _ = await service.check_quota(
                database, department, quota_type, "daily", observed
            )
            if within_limit:
                continue
            # Look the quota up again only to report its limit in the error.
            configured = await service.get_quota(database, department, quota_type, "daily")
            raise QuotaExceededError(
                department_id=department,
                quota_type=quota_type,
                period="daily",
                limit=configured["limit_value"] if configured else None,
                current=observed,
            )
|
||||||
|
|
||||||
|
async def _get_current_usage_for_quota(self, department_id: str, period: str) -> int:
    """Return the tokens consumed by ``department_id`` in the current period.

    The window ends now (UTC). For ``period == "monthly"`` it starts at
    00:00 UTC on the first day of the current month; for any other value
    (``"daily"`` is the expected one) it starts at 00:00 UTC today.
    """
    window_end = datetime.now(timezone.utc)
    midnight = {"hour": 0, "minute": 0, "second": 0, "microsecond": 0}
    window_start = (
        window_end.replace(day=1, **midnight)
        if period == "monthly"
        else window_end.replace(**midnight)
    )
    totals = self._usage_tracker.get_usage(
        department_id=department_id,
        start_time=window_start,
        end_time=window_end,
    )
    return int(totals.total_tokens)
|
||||||
|
|
||||||
|
async def _get_current_cost_for_quota(self, department_id: str, period: str) -> float:
    """Return the cost, in cents, accrued by ``department_id`` this period.

    The window ends now (UTC). For ``period == "monthly"`` it starts at
    00:00 UTC on the first day of the current month; for any other value
    (``"daily"`` is the expected one) it starts at 00:00 UTC today.
    The usage store reports cost as float USD while ``cost_limit`` quotas
    are stored in cents, hence the ×100 conversion on return.
    """
    window_end = datetime.now(timezone.utc)
    midnight = {"hour": 0, "minute": 0, "second": 0, "microsecond": 0}
    window_start = (
        window_end.replace(day=1, **midnight)
        if period == "monthly"
        else window_end.replace(**midnight)
    )
    totals = self._usage_tracker.get_usage(
        department_id=department_id,
        start_time=window_start,
        end_time=window_end,
    )
    usd_total = float(totals.total_cost)
    # USD → cents, to be comparable with the stored cost_limit.
    return usd_total * 100.0
|
||||||
|
|
|
||||||
|
|
@ -23,15 +23,38 @@ class UsageTracker:
|
||||||
usage: TokenUsage,
|
usage: TokenUsage,
|
||||||
cost: float,
|
cost: float,
|
||||||
latency_ms: float,
|
latency_ms: float,
|
||||||
|
user_id: str | None = None,
|
||||||
|
department_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""记录一次使用"""
|
"""记录一次使用"""
|
||||||
self._store.record(agent_name, model, usage, cost, latency_ms)
|
self._store.record(
|
||||||
|
agent_name,
|
||||||
|
model,
|
||||||
|
usage,
|
||||||
|
cost,
|
||||||
|
latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_id=department_id,
|
||||||
|
)
|
||||||
|
|
||||||
def get_usage(
    self,
    agent_name: str | None = None,
    start_time: datetime | None = None,
    end_time: datetime | None = None,
    user_id: str | None = None,
    department_id: str | None = None,
) -> UsageSummary:
    """Query an aggregated usage summary from the underlying store.

    All filters are optional and are passed through to the store unchanged,
    by keyword.
    """
    filters = {
        "agent_name": agent_name,
        "start_time": start_time,
        "end_time": end_time,
        "user_id": user_id,
        "department_id": department_id,
    }
    return self._store.get_usage(**filters)
|
||||||
|
|
||||||
|
@property
def store(self) -> UsageStore:
    """Return the backing usage store so service-layer code can query it directly."""
    backing_store = self._store
    return backing_store
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,12 @@ backends. Replaces the in-memory list in UsageTracker with a pluggable
|
||||||
store that survives restarts and supports multi-instance deployment.
|
store that survives restarts and supports multi-instance deployment.
|
||||||
|
|
||||||
Key schema (Redis):
|
Key schema (Redis):
|
||||||
agentkit:usage:{date} → Hash: {agent_name:model → JSON(UsageBucket)}
|
agentkit:usage:v2:{date}:{user_id}:{department_id} → Hash: {agent_name:model → JSON(UsageBucket)}
|
||||||
agentkit:usage_records:{date} → List: JSON(UsageRecord) with LTRIM
|
agentkit:usage_records:v2:{date}:{user_id}:{department_id} → List: JSON(UsageRecord) with LTRIM
|
||||||
|
|
||||||
|
Legacy v1 keys (still readable for backward compat):
|
||||||
|
agentkit:usage:{date} → Hash
|
||||||
|
agentkit:usage_records:{date} → List
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
@ -32,6 +36,8 @@ class UsageRecord:
|
||||||
cost: float
|
cost: float
|
||||||
latency_ms: float
|
latency_ms: float
|
||||||
timestamp: str = "" # ISO 8601 string for JSON serialization
|
timestamp: str = "" # ISO 8601 string for JSON serialization
|
||||||
|
user_id: str | None = None
|
||||||
|
department_id: str | None = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not self.timestamp:
|
if not self.timestamp:
|
||||||
|
|
@ -57,6 +63,8 @@ class UsageSummary:
|
||||||
total_cost: float = 0.0
|
total_cost: float = 0.0
|
||||||
by_model: dict[str, dict[str, int | float]] = field(default_factory=dict)
|
by_model: dict[str, dict[str, int | float]] = field(default_factory=dict)
|
||||||
records: list[UsageRecord] = field(default_factory=list)
|
records: list[UsageRecord] = field(default_factory=list)
|
||||||
|
by_user: dict[str, dict[str, int | float]] = field(default_factory=dict)
|
||||||
|
by_department: dict[str, dict[str, int | float]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -75,6 +83,8 @@ class UsageStore(Protocol):
|
||||||
usage: TokenUsage,
|
usage: TokenUsage,
|
||||||
cost: float,
|
cost: float,
|
||||||
latency_ms: float,
|
latency_ms: float,
|
||||||
|
user_id: str | None = None,
|
||||||
|
department_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record a usage event."""
|
"""Record a usage event."""
|
||||||
...
|
...
|
||||||
|
|
@ -84,6 +94,8 @@ class UsageStore(Protocol):
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
start_time: datetime | None = None,
|
start_time: datetime | None = None,
|
||||||
end_time: datetime | None = None,
|
end_time: datetime | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
department_id: str | None = None,
|
||||||
) -> UsageSummary:
|
) -> UsageSummary:
|
||||||
"""Query usage summary."""
|
"""Query usage summary."""
|
||||||
...
|
...
|
||||||
|
|
@ -109,6 +121,8 @@ class InMemoryUsageStore:
|
||||||
usage: TokenUsage,
|
usage: TokenUsage,
|
||||||
cost: float,
|
cost: float,
|
||||||
latency_ms: float,
|
latency_ms: float,
|
||||||
|
user_id: str | None = None,
|
||||||
|
department_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
rec = UsageRecord(
|
rec = UsageRecord(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
|
|
@ -118,16 +132,20 @@ class InMemoryUsageStore:
|
||||||
total_tokens=usage.total_tokens,
|
total_tokens=usage.total_tokens,
|
||||||
cost=cost,
|
cost=cost,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_id=department_id,
|
||||||
)
|
)
|
||||||
self._records.append(rec)
|
self._records.append(rec)
|
||||||
if len(self._records) > self.MAX_RECORDS:
|
if len(self._records) > self.MAX_RECORDS:
|
||||||
self._records = self._records[-self.MAX_RECORDS:]
|
self._records = self._records[-self.MAX_RECORDS :]
|
||||||
|
|
||||||
def get_usage(
|
def get_usage(
|
||||||
self,
|
self,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
start_time: datetime | None = None,
|
start_time: datetime | None = None,
|
||||||
end_time: datetime | None = None,
|
end_time: datetime | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
department_id: str | None = None,
|
||||||
) -> UsageSummary:
|
) -> UsageSummary:
|
||||||
filtered = self._records
|
filtered = self._records
|
||||||
|
|
||||||
|
|
@ -137,26 +155,65 @@ class InMemoryUsageStore:
|
||||||
filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) >= start_time]
|
filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) >= start_time]
|
||||||
if end_time is not None:
|
if end_time is not None:
|
||||||
filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) <= end_time]
|
filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) <= end_time]
|
||||||
|
if user_id is not None:
|
||||||
|
filtered = [r for r in filtered if r.user_id == user_id]
|
||||||
|
if department_id is not None:
|
||||||
|
filtered = [r for r in filtered if r.department_id == department_id]
|
||||||
|
|
||||||
if not filtered:
|
if not filtered:
|
||||||
return UsageSummary()
|
return UsageSummary()
|
||||||
|
|
||||||
total_tokens = sum(r.total_tokens for r in filtered)
|
return self._aggregate(filtered)
|
||||||
total_cost = sum(r.cost for r in filtered)
|
|
||||||
|
def get_usage_by_user(
    self,
    user_id: str,
    start_time: datetime | None = None,
    end_time: datetime | None = None,
) -> UsageSummary:
    """Return the usage summary restricted to one user.

    Thin convenience wrapper: delegates to :meth:`get_usage` with the
    ``user_id`` filter and the optional time window.
    """
    window = {"start_time": start_time, "end_time": end_time}
    return self.get_usage(user_id=user_id, **window)
|
||||||
|
|
||||||
|
def get_usage_by_department(
    self,
    department_id: str,
    start_time: datetime | None = None,
    end_time: datetime | None = None,
) -> UsageSummary:
    """Return the usage summary restricted to one department.

    Thin convenience wrapper: delegates to :meth:`get_usage` with the
    ``department_id`` filter and the optional time window.
    """
    window = {"start_time": start_time, "end_time": end_time}
    return self.get_usage(department_id=department_id, **window)
|
||||||
|
|
||||||
|
@staticmethod
def _aggregate(records: list[UsageRecord]) -> UsageSummary:
    """Fold a list of usage records into a :class:`UsageSummary`.

    Produces overall token/cost totals plus per-model, per-user and
    per-department buckets, each holding ``total_tokens``, ``total_cost``
    and ``count``. Records whose ``user_id`` / ``department_id`` is ``None``
    are left out of the corresponding breakdown but still count toward the
    totals and the per-model buckets.
    """
    per_model: dict[str, dict[str, int | float]] = {}
    per_user: dict[str, dict[str, int | float]] = {}
    per_department: dict[str, dict[str, int | float]] = {}

    def _accumulate(buckets: dict[str, dict[str, int | float]], key: str, rec: UsageRecord) -> None:
        # Create the bucket on first sight of the key, then add this record.
        slot = buckets.setdefault(key, {"total_tokens": 0, "total_cost": 0.0, "count": 0})
        slot["total_tokens"] += rec.total_tokens
        slot["total_cost"] += rec.cost
        slot["count"] += 1

    for rec in records:
        _accumulate(per_model, rec.model, rec)
        if rec.user_id is not None:
            _accumulate(per_user, rec.user_id, rec)
        if rec.department_id is not None:
            _accumulate(per_department, rec.department_id, rec)

    return UsageSummary(
        total_tokens=sum(rec.total_tokens for rec in records),
        total_cost=sum(rec.cost for rec in records),
        by_model=per_model,
        by_user=per_user,
        by_department=per_department,
        records=records,
    )
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -168,13 +225,21 @@ class InMemoryUsageStore:
|
||||||
class RedisUsageStore:
|
class RedisUsageStore:
|
||||||
"""Redis-backed usage store using Hash per date for O(1) writes.
|
"""Redis-backed usage store using Hash per date for O(1) writes.
|
||||||
|
|
||||||
Key schema:
|
Key schema (v2 — includes user_id/department_id in key):
|
||||||
agentkit:usage:{YYYY-MM-DD} → Hash: {agent:model → JSON(UsageBucket)}
|
agentkit:usage:v2:{YYYY-MM-DD}:{user_id or 'none'}:{department_id or 'none'}
|
||||||
agentkit:usage_records:{YYYY-MM-DD} → List: JSON(UsageRecord) with LTRIM
|
→ Hash: {agent:model → JSON(UsageBucket)}
|
||||||
|
agentkit:usage_records:v2:{YYYY-MM-DD}:{user_id or 'none'}:{department_id or 'none'}
|
||||||
|
→ List: JSON(UsageRecord) with LTRIM
|
||||||
|
|
||||||
|
Legacy v1 keys (still readable for backward compat):
|
||||||
|
agentkit:usage:{YYYY-MM-DD} → Hash
|
||||||
|
agentkit:usage_records:{YYYY-MM-DD} → List
|
||||||
"""
|
"""
|
||||||
|
|
||||||
USAGE_PREFIX = "agentkit:usage:"
|
USAGE_PREFIX = "agentkit:usage:"
|
||||||
RECORDS_PREFIX = "agentkit:usage_records:"
|
RECORDS_PREFIX = "agentkit:usage_records:"
|
||||||
|
USAGE_PREFIX_V2 = "agentkit:usage:v2:"
|
||||||
|
RECORDS_PREFIX_V2 = "agentkit:usage_records:v2:"
|
||||||
MAX_RECORDS_PER_DAY = 50000
|
MAX_RECORDS_PER_DAY = 50000
|
||||||
TTL_DAYS = 90 # Auto-expire after 90 days
|
TTL_DAYS = 90 # Auto-expire after 90 days
|
||||||
|
|
||||||
|
|
@ -188,6 +253,7 @@ class RedisUsageStore:
|
||||||
async def _get_redis(self):
    """Return the async Redis client, creating it from the URL on first use."""
    client = self._redis
    if client is not None:
        return client

    # Imported here so the redis package is only needed once a client is requested.
    import redis.asyncio as aioredis

    self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
    return self._redis
|
||||||
|
|
||||||
|
|
@ -195,9 +261,8 @@ class RedisUsageStore:
|
||||||
"""Get or create a persistent sync Redis client (connection pool backed)."""
|
"""Get or create a persistent sync Redis client (connection pool backed)."""
|
||||||
if self._sync_redis is None:
|
if self._sync_redis is None:
|
||||||
import redis as sync_redis
|
import redis as sync_redis
|
||||||
self._sync_redis = sync_redis.from_url(
|
|
||||||
self._redis_url, decode_responses=True
|
self._sync_redis = sync_redis.from_url(self._redis_url, decode_responses=True)
|
||||||
)
|
|
||||||
return self._sync_redis
|
return self._sync_redis
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
|
|
@ -218,6 +283,22 @@ class RedisUsageStore:
|
||||||
def _today_key(self) -> str:
    """Return today's UTC date as ``YYYY-MM-DD``, the date part of usage keys."""
    today_utc = datetime.now(tz=timezone.utc).date()
    return today_utc.isoformat()
|
||||||
|
|
||||||
|
@staticmethod
def _scope_key(part: str | None) -> str:
    """Normalize a user_id/department_id for use as one Redis key segment.

    ``None`` and the empty string map to the literal ``"none"``. Any other
    value is returned with the characters that are structural in a key or
    in a Redis glob pattern (``:``, ``*``, ``?``, ``[``, ``]``, ``\\``) and
    the escape character ``%`` itself replaced by ``%XX`` hex escapes.

    Without this, an id containing ``:`` shifts the segment boundaries of
    the ``{date}:{user}:{department}`` key, so two different
    (user, department) pairs can land on the same key; glob characters would
    also widen any pattern-based key scan. Ids made of ordinary characters
    (letters, digits, ``-``, ``_``, ``.``, ``@`` …) are returned unchanged,
    so existing keys for such ids stay valid.

    NOTE(review): readers must derive keys/patterns through this helper too —
    confirm no call site interpolates raw ids into a key.
    """
    if not part:
        return "none"
    reserved = ":*?[]\\%"
    return "".join(f"%{ord(ch):02X}" if ch in reserved else ch for ch in part)
|
||||||
|
|
||||||
|
def _v2_keys(
    self, date_key: str, user_id: str | None, department_id: str | None
) -> tuple[str, str]:
    """Build the v2 ``(hash_key, list_key)`` pair for one date and scope.

    Both keys share the suffix ``{date_key}:{user}:{department}``, where the
    user and department segments come from :meth:`_scope_key`; they differ
    only in their prefix (``USAGE_PREFIX_V2`` vs ``RECORDS_PREFIX_V2``).
    """
    user_segment = self._scope_key(user_id)
    department_segment = self._scope_key(department_id)
    suffix = f"{date_key}:{user_segment}:{department_segment}"
    return f"{self.USAGE_PREFIX_V2}{suffix}", f"{self.RECORDS_PREFIX_V2}{suffix}"
|
||||||
|
|
||||||
def record(
|
def record(
|
||||||
self,
|
self,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
|
|
@ -225,6 +306,8 @@ class RedisUsageStore:
|
||||||
usage: TokenUsage,
|
usage: TokenUsage,
|
||||||
cost: float,
|
cost: float,
|
||||||
latency_ms: float,
|
latency_ms: float,
|
||||||
|
user_id: str | None = None,
|
||||||
|
department_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record usage — sync wrapper for async Redis.
|
"""Record usage — sync wrapper for async Redis.
|
||||||
|
|
||||||
|
|
@ -233,15 +316,22 @@ class RedisUsageStore:
|
||||||
needing an event loop in the caller.
|
needing an event loop in the caller.
|
||||||
"""
|
"""
|
||||||
if self._degraded and self._fallback is not None:
|
if self._degraded and self._fallback is not None:
|
||||||
self._fallback.record(agent_name, model, usage, cost, latency_ms)
|
self._fallback.record(
|
||||||
|
agent_name,
|
||||||
|
model,
|
||||||
|
usage,
|
||||||
|
cost,
|
||||||
|
latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_id=department_id,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = self._get_sync_redis()
|
r = self._get_sync_redis()
|
||||||
|
|
||||||
date_key = self._today_key()
|
date_key = self._today_key()
|
||||||
hash_key = f"{self.USAGE_PREFIX}{date_key}"
|
hash_key, list_key = self._v2_keys(date_key, user_id, department_id)
|
||||||
list_key = f"{self.RECORDS_PREFIX}{date_key}"
|
|
||||||
bucket_field = f"{agent_name}:{model}"
|
bucket_field = f"{agent_name}:{model}"
|
||||||
|
|
||||||
# Atomic HINCRBYFLOAT for bucket aggregation
|
# Atomic HINCRBYFLOAT for bucket aggregation
|
||||||
|
|
@ -261,17 +351,26 @@ class RedisUsageStore:
|
||||||
total_tokens=usage.total_tokens,
|
total_tokens=usage.total_tokens,
|
||||||
cost=cost,
|
cost=cost,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_id=department_id,
|
||||||
|
)
|
||||||
|
pipe.rpush(
|
||||||
|
list_key,
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"agent_name": rec.agent_name,
|
||||||
|
"model": rec.model,
|
||||||
|
"prompt_tokens": rec.prompt_tokens,
|
||||||
|
"completion_tokens": rec.completion_tokens,
|
||||||
|
"total_tokens": rec.total_tokens,
|
||||||
|
"cost": rec.cost,
|
||||||
|
"latency_ms": rec.latency_ms,
|
||||||
|
"timestamp": rec.timestamp,
|
||||||
|
"user_id": rec.user_id,
|
||||||
|
"department_id": rec.department_id,
|
||||||
|
}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
pipe.rpush(list_key, json.dumps({
|
|
||||||
"agent_name": rec.agent_name,
|
|
||||||
"model": rec.model,
|
|
||||||
"prompt_tokens": rec.prompt_tokens,
|
|
||||||
"completion_tokens": rec.completion_tokens,
|
|
||||||
"total_tokens": rec.total_tokens,
|
|
||||||
"cost": rec.cost,
|
|
||||||
"latency_ms": rec.latency_ms,
|
|
||||||
"timestamp": rec.timestamp,
|
|
||||||
}))
|
|
||||||
pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)
|
pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)
|
||||||
|
|
||||||
# Set TTL on first write of the day
|
# Set TTL on first write of the day
|
||||||
|
|
@ -283,17 +382,38 @@ class RedisUsageStore:
|
||||||
logger.warning(f"Redis usage record failed: {e}")
|
logger.warning(f"Redis usage record failed: {e}")
|
||||||
self._degrade_to_fallback()
|
self._degrade_to_fallback()
|
||||||
if self._fallback is not None:
|
if self._fallback is not None:
|
||||||
self._fallback.record(agent_name, model, usage, cost, latency_ms)
|
self._fallback.record(
|
||||||
|
agent_name,
|
||||||
|
model,
|
||||||
|
usage,
|
||||||
|
cost,
|
||||||
|
latency_ms,
|
||||||
|
user_id=user_id,
|
||||||
|
department_id=department_id,
|
||||||
|
)
|
||||||
|
|
||||||
def get_usage(
|
def get_usage(
|
||||||
self,
|
self,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
start_time: datetime | None = None,
|
start_time: datetime | None = None,
|
||||||
end_time: datetime | None = None,
|
end_time: datetime | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
department_id: str | None = None,
|
||||||
) -> UsageSummary:
|
) -> UsageSummary:
|
||||||
"""Query usage summary from Redis."""
|
"""Query usage summary from Redis.
|
||||||
|
|
||||||
|
Scans v2 keys (filtered by user_id/department_id when provided)
|
||||||
|
and legacy v1 keys (no per-user/department scoping). Records
|
||||||
|
from both schemas are merged.
|
||||||
|
"""
|
||||||
if self._degraded and self._fallback is not None:
|
if self._degraded and self._fallback is not None:
|
||||||
return self._fallback.get_usage(agent_name, start_time, end_time)
|
return self._fallback.get_usage(
|
||||||
|
agent_name=agent_name,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
user_id=user_id,
|
||||||
|
department_id=department_id,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = self._get_sync_redis()
|
r = self._get_sync_redis()
|
||||||
|
|
@ -303,47 +423,115 @@ class RedisUsageStore:
|
||||||
end = end_time or datetime.now(timezone.utc)
|
end = end_time or datetime.now(timezone.utc)
|
||||||
|
|
||||||
all_records: list[UsageRecord] = []
|
all_records: list[UsageRecord] = []
|
||||||
# Scan date keys in range
|
|
||||||
|
# Scan v2 keys.
|
||||||
current = start.date()
|
current = start.date()
|
||||||
end_date = end.date()
|
end_date = end.date()
|
||||||
while current <= end_date:
|
while current <= end_date:
|
||||||
list_key = f"{self.RECORDS_PREFIX}{current.isoformat()}"
|
date_key = current.isoformat()
|
||||||
raw_records = r.lrange(list_key, 0, -1)
|
# When user_id/department_id is provided, scan only the
|
||||||
for raw in raw_records:
|
# matching scope key. Otherwise scan all scopes for that
|
||||||
data = json.loads(raw)
|
# date via SCAN.
|
||||||
rec = UsageRecord(**data)
|
if user_id is not None or department_id is not None:
|
||||||
rec_ts = datetime.fromisoformat(rec.timestamp)
|
list_key = f"{self.RECORDS_PREFIX_V2}{date_key}:{self._scope_key(user_id)}:{self._scope_key(department_id)}"
|
||||||
if rec_ts >= start and rec_ts <= end:
|
all_records.extend(self._read_list(r, list_key, start, end, agent_name))
|
||||||
if agent_name is None or rec.agent_name == agent_name:
|
else:
|
||||||
all_records.append(rec)
|
# Scan all v2 list keys for this date.
|
||||||
|
pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:*"
|
||||||
|
for key in r.scan_iter(match=pattern, count=200):
|
||||||
|
all_records.extend(self._read_list(r, key, start, end, agent_name))
|
||||||
current = current + timedelta(days=1)
|
current = current + timedelta(days=1)
|
||||||
|
|
||||||
|
# Also scan legacy v1 keys (no user/department scoping).
|
||||||
|
current = start.date()
|
||||||
|
while current <= end_date:
|
||||||
|
list_key = f"{self.RECORDS_PREFIX}{current.isoformat()}"
|
||||||
|
all_records.extend(self._read_list(r, list_key, start, end, agent_name))
|
||||||
|
current = current + timedelta(days=1)
|
||||||
|
|
||||||
|
# Apply user_id/department_id filters to records from legacy
|
||||||
|
# v1 keys (which don't carry these fields — they'll be None).
|
||||||
|
if user_id is not None:
|
||||||
|
all_records = [r for r in all_records if r.user_id == user_id]
|
||||||
|
if department_id is not None:
|
||||||
|
all_records = [r for r in all_records if r.department_id == department_id]
|
||||||
|
|
||||||
if not all_records:
|
if not all_records:
|
||||||
return UsageSummary()
|
return UsageSummary()
|
||||||
|
|
||||||
total_tokens = sum(r.total_tokens for r in all_records)
|
return InMemoryUsageStore._aggregate(all_records)
|
||||||
total_cost = sum(r.cost for r in all_records)
|
|
||||||
|
|
||||||
by_model: dict[str, dict[str, int | float]] = {}
|
|
||||||
for r in all_records:
|
|
||||||
if r.model not in by_model:
|
|
||||||
by_model[r.model] = {"total_tokens": 0, "total_cost": 0.0, "count": 0}
|
|
||||||
by_model[r.model]["total_tokens"] += r.total_tokens
|
|
||||||
by_model[r.model]["total_cost"] += r.cost
|
|
||||||
by_model[r.model]["count"] += 1
|
|
||||||
|
|
||||||
return UsageSummary(
|
|
||||||
total_tokens=total_tokens,
|
|
||||||
total_cost=total_cost,
|
|
||||||
by_model=by_model,
|
|
||||||
records=all_records,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Redis usage query failed: {e}")
|
logger.warning(f"Redis usage query failed: {e}")
|
||||||
if self._fallback is not None:
|
if self._fallback is not None:
|
||||||
return self._fallback.get_usage(agent_name, start_time, end_time)
|
return self._fallback.get_usage(
|
||||||
|
agent_name=agent_name,
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
user_id=user_id,
|
||||||
|
department_id=department_id,
|
||||||
|
)
|
||||||
return UsageSummary()
|
return UsageSummary()
|
||||||
|
|
||||||
|
def get_usage_by_user(
    self,
    user_id: str,
    start_time: datetime | None = None,
    end_time: datetime | None = None,
) -> UsageSummary:
    """Return the usage summary restricted to a single user.

    Thin convenience wrapper: delegates to :meth:`get_usage` with the
    ``user_id`` filter and the optional time window; no other filter is set.
    """
    window = {"start_time": start_time, "end_time": end_time}
    return self.get_usage(user_id=user_id, **window)
|
||||||
|
|
||||||
|
def get_usage_by_department(
    self,
    department_id: str,
    start_time: datetime | None = None,
    end_time: datetime | None = None,
) -> UsageSummary:
    """Return the usage summary restricted to a single department.

    Thin convenience wrapper: delegates to :meth:`get_usage` with the
    ``department_id`` filter and the optional time window; no other filter
    is set.
    """
    window = {"start_time": start_time, "end_time": end_time}
    return self.get_usage(department_id=department_id, **window)
|
||||||
|
|
||||||
|
@staticmethod
def _read_list(
    r: Any,
    list_key: str,
    start: datetime,
    end: datetime,
    agent_name: str | None,
) -> list[UsageRecord]:
    """Read all records from a Redis list, filtered by time range and agent.

    Args:
        r: Redis client; only ``r.lrange(list_key, 0, -1)`` is called on it.
        list_key: Key of the Redis list holding JSON-encoded usage records.
        start: Inclusive lower bound on the record timestamp.
        end: Inclusive upper bound on the record timestamp.
        agent_name: When not ``None``, keep only records for this agent.

    Returns:
        The records that parsed cleanly and passed both filters, in list
        order.

    Malformed entries are skipped rather than raised, so one bad entry
    cannot abort the whole read: undecodable JSON, a JSON value that is
    not an object, a record missing a required field, and a missing or
    unparseable timestamp.
    """

    def _as_utc(dt: datetime) -> datetime:
        # Comparing a naive datetime with an aware one raises TypeError.
        # NOTE(review): naive values are assumed to be UTC -- confirm that
        # every writer stores UTC timestamps.
        return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)

    lower = _as_utc(start)
    upper = _as_utc(end)

    out: list[UsageRecord] = []
    for raw in r.lrange(list_key, 0, -1):
        try:
            data = json.loads(raw)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and undecodable bytes.
            continue
        if not isinstance(data, dict):
            continue
        # Build record, tolerating legacy records without user_id/department_id.
        try:
            rec = UsageRecord(
                agent_name=data["agent_name"],
                model=data["model"],
                prompt_tokens=data["prompt_tokens"],
                completion_tokens=data["completion_tokens"],
                total_tokens=data["total_tokens"],
                cost=data["cost"],
                latency_ms=data["latency_ms"],
                timestamp=data.get("timestamp", ""),
                user_id=data.get("user_id"),
                department_id=data.get("department_id"),
            )
        except KeyError:
            # A required field is missing: skip this entry only.
            continue
        if not rec.timestamp:
            continue
        try:
            rec_ts = datetime.fromisoformat(rec.timestamp)
        except (TypeError, ValueError):
            continue
        if not lower <= _as_utc(rec_ts) <= upper:
            continue
        if agent_name is not None and rec.agent_name != agent_name:
            continue
        out.append(rec)
    return out
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Factory
|
# Factory
|
||||||
|
|
@ -366,6 +554,7 @@ def create_usage_store(
|
||||||
if backend in ("auto", "redis"):
|
if backend in ("auto", "redis"):
|
||||||
try:
|
try:
|
||||||
import redis # noqa: F401
|
import redis # noqa: F401
|
||||||
|
|
||||||
return RedisUsageStore(redis_url=redis_url)
|
return RedisUsageStore(redis_url=redis_url)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("redis package not available, falling back to in-memory usage store")
|
logger.warning("redis package not available, falling back to in-memory usage store")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,298 @@
|
||||||
|
"""UsageService — read-side aggregations for the usage dashboard (U7).
|
||||||
|
|
||||||
|
This module provides read-only aggregations over a :class:`UsageStore`
|
||||||
|
for the admin usage dashboard. It is intentionally a thin layer — the
|
||||||
|
store already produces :class:`UsageSummary` aggregations, and this
|
||||||
|
service just shapes them for the dashboard endpoints (timeseries,
|
||||||
|
top-N, CSV/JSON export).
|
||||||
|
|
||||||
|
The service is a module-level singleton (see :func:`get_usage_service`)
|
||||||
|
so tests can inject a custom instance via :func:`set_usage_service`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.llm.providers.usage_store import UsageStore, UsageSummary
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _bucket_start(ts: datetime, interval: str) -> datetime:
|
||||||
|
"""Return the start of the time bucket containing ``ts``."""
|
||||||
|
if interval == "hour":
|
||||||
|
return ts.replace(minute=0, second=0, microsecond=0)
|
||||||
|
# Default: day
|
||||||
|
return ts.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Service
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class UsageService:
    """Read-side aggregations for the usage dashboard.

    Every public method takes the ``usage_store`` to query plus the same
    optional filters (``department_id``, ``user_id``, ``start_time``,
    ``end_time``), forwards them to ``usage_store.get_usage`` and reshapes
    the returned summary.

    NOTE(review): the methods are ``async`` but call ``get_usage``
    synchronously (no ``await``); if the store performs blocking I/O this
    runs on the event loop -- confirm that is acceptable.
    """

    # Column order of exported records. One definition drives both the
    # record dicts and the CSV header, so the two cannot drift apart.
    _EXPORT_FIELDS = (
        "timestamp",
        "agent_name",
        "model",
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "cost",
        "latency_ms",
        "user_id",
        "department_id",
    )

    async def get_usage_summary(
        self,
        usage_store: UsageStore,
        department_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> dict[str, Any]:
        """Return a flat usage summary dict.

        Shape::

            {
                "total_tokens": int,
                "total_cost": float,
                "total_requests": int,
                "by_model": {model: {total_tokens, total_cost, count}, ...},
                "by_user": {user_id: {...}, ...},
                "by_department": {department_id: {...}, ...},
            }
        """
        summary = self._query(usage_store, department_id, user_id, start_time, end_time)
        return self._summary_to_dict(summary)

    async def get_usage_timeseries(
        self,
        usage_store: UsageStore,
        department_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        interval: str = "day",
    ) -> list[dict[str, Any]]:
        """Return a time-bucketed series, ordered by bucket start.

        Each item has shape ``{timestamp, tokens, cost, requests}``.
        Buckets with no activity are omitted (callers can fill gaps).
        Records whose timestamp cannot be parsed are ignored.
        """
        summary = self._query(usage_store, department_id, user_id, start_time, end_time)
        buckets: dict[datetime, dict[str, Any]] = {}
        for rec in summary.records:
            try:
                ts = datetime.fromisoformat(rec.timestamp)
            except (TypeError, ValueError):
                # TypeError: timestamp is not a string (e.g. None).
                continue
            totals = buckets.setdefault(
                _bucket_start(ts, interval),
                {"tokens": 0, "cost": 0.0, "requests": 0},
            )
            totals["tokens"] += rec.total_tokens
            totals["cost"] += rec.cost
            totals["requests"] += 1
        return [
            {
                "timestamp": bucket.isoformat(),
                "tokens": buckets[bucket]["tokens"],
                "cost": buckets[bucket]["cost"],
                "requests": buckets[bucket]["requests"],
            }
            for bucket in sorted(buckets)
        ]

    async def get_usage_by_model(
        self,
        usage_store: UsageStore,
        department_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> list[dict[str, Any]]:
        """Return a per-model breakdown, sorted by model name."""
        summary = self._query(usage_store, department_id, user_id, start_time, end_time)
        return [
            {
                "model": model,
                "tokens": data["total_tokens"],
                "cost": data["total_cost"],
                "requests": data["count"],
            }
            for model, data in sorted(summary.by_model.items())
        ]

    async def get_top_users(
        self,
        usage_store: UsageStore,
        department_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Return the top-N users by total token usage (highest first)."""
        summary = self._query(usage_store, department_id, user_id, start_time, end_time)
        return self._top_n(summary.by_user, "user_id", limit)

    async def get_top_departments(
        self,
        usage_store: UsageStore,
        department_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        limit: int = 10,
    ) -> list[dict[str, Any]]:
        """Return the top-N departments by total token usage (highest first)."""
        summary = self._query(usage_store, department_id, user_id, start_time, end_time)
        return self._top_n(summary.by_department, "department_id", limit)

    async def export_usage(
        self,
        usage_store: UsageStore,
        department_id: str | None = None,
        user_id: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        format: str = "csv",
    ) -> str:
        """Export raw usage records as CSV or JSON.

        ``format`` is ``"json"`` for a JSON array; any other value yields
        CSV (the default). A missing ``user_id`` / ``department_id`` is
        exported as an empty string. An empty result yields ``[]`` for JSON
        and a header-only document for CSV.
        """
        summary = self._query(usage_store, department_id, user_id, start_time, end_time)
        records = [
            {
                "timestamp": rec.timestamp,
                "agent_name": rec.agent_name,
                "model": rec.model,
                "prompt_tokens": rec.prompt_tokens,
                "completion_tokens": rec.completion_tokens,
                "total_tokens": rec.total_tokens,
                "cost": rec.cost,
                "latency_ms": rec.latency_ms,
                "user_id": rec.user_id or "",
                "department_id": rec.department_id or "",
            }
            for rec in summary.records
        ]
        if format == "json":
            return json.dumps(records, ensure_ascii=False, indent=2)
        # Default: CSV. The header is always written, even with no rows.
        out = io.StringIO()
        writer = csv.DictWriter(out, fieldnames=list(self._EXPORT_FIELDS))
        writer.writeheader()
        writer.writerows(records)
        return out.getvalue()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _query(
        usage_store: UsageStore,
        department_id: str | None,
        user_id: str | None,
        start_time: datetime | None,
        end_time: datetime | None,
    ) -> UsageSummary:
        """Run the store query shared by every public method."""
        return usage_store.get_usage(
            start_time=start_time,
            end_time=end_time,
            user_id=user_id,
            department_id=department_id,
        )

    @staticmethod
    def _top_n(
        groups: dict[str, dict[str, Any]],
        id_field: str,
        limit: int,
    ) -> list[dict[str, Any]]:
        """Rank per-key aggregates by token count and keep the first ``limit``."""
        rows = [
            {
                id_field: key,
                "tokens": data["total_tokens"],
                "cost": data["total_cost"],
                "requests": data["count"],
            }
            for key, data in groups.items()
        ]
        rows.sort(key=lambda row: row["tokens"], reverse=True)
        # A negative slice bound would drop rows from the end instead of
        # returning nothing, so clamp it to zero.
        return rows[: max(limit, 0)]

    @staticmethod
    def _summary_to_dict(summary: UsageSummary) -> dict[str, Any]:
        """Convert a :class:`UsageSummary` to a flat dict response."""
        return {
            "total_tokens": summary.total_tokens,
            "total_cost": summary.total_cost,
            "total_requests": len(summary.records),
            "by_model": dict(summary.by_model),
            "by_user": dict(summary.by_user),
            "by_department": dict(summary.by_department),
        }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level singleton (overridable in tests via set_usage_service)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_usage_service: UsageService | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_usage_service() -> UsageService:
    """Return the module-wide :class:`UsageService`, creating it on first use.

    The instance is cached in the module-level ``_usage_service`` variable,
    so later calls return the same object until it is replaced.
    """
    global _usage_service
    service = _usage_service
    if service is None:
        service = _usage_service = UsageService()
    return service
|
||||||
|
|
||||||
|
|
||||||
|
def set_usage_service(service: UsageService | None) -> None:
    """Replace the module-wide :class:`UsageService` (used by tests).

    Stores ``service`` in the module-level ``_usage_service`` variable;
    ``None`` is stored as-is.
    """
    global _usage_service
    _usage_service = service
|
||||||
|
|
@ -16,10 +16,12 @@ import time (keeps the module self-contained and test-friendly).
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import PlainTextResponse
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from agentkit.server.admin.department_service import get_department_service
|
from agentkit.server.admin.department_service import get_department_service
|
||||||
|
|
@ -30,6 +32,7 @@ from agentkit.server.admin.llm_config_service import (
|
||||||
)
|
)
|
||||||
from agentkit.server.admin.quota_service import get_quota_service
|
from agentkit.server.admin.quota_service import get_quota_service
|
||||||
from agentkit.server.admin.skill_service import get_skill_service
|
from agentkit.server.admin.skill_service import get_skill_service
|
||||||
|
from agentkit.server.admin.usage_service import get_usage_service
|
||||||
from agentkit.server.admin.user_service import get_user_service
|
from agentkit.server.admin.user_service import get_user_service
|
||||||
from agentkit.server.auth.dependencies import require_authenticated
|
from agentkit.server.auth.dependencies import require_authenticated
|
||||||
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
|
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
|
||||||
|
|
@ -1138,3 +1141,192 @@ async def rebuild_kb_source(
|
||||||
return svc.rebuild_index(source_id)
|
return svc.rebuild_index(source_id)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Usage dashboard endpoints (U7) — usage aggregations + export
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _get_usage_store(request: Request) -> Any:
|
||||||
|
"""Return the live :class:`UsageStore` from ``app.state.llm_gateway``.
|
||||||
|
|
||||||
|
Raises HTTPException(500) if the gateway or usage store is missing —
|
||||||
|
usage endpoints cannot function without it.
|
||||||
|
"""
|
||||||
|
gateway = getattr(request.app.state, "llm_gateway", None)
|
||||||
|
if gateway is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="LLM gateway not initialized on app.state",
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
return gateway._usage_tracker.store # type: ignore[attr-defined]
|
||||||
|
except AttributeError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail="Usage store not available on LLM gateway",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_iso(value: str | None) -> datetime | None:
|
||||||
|
"""Parse an ISO 8601 string into a timezone-aware datetime."""
|
||||||
|
if value is None or value == "":
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
dt = datetime.fromisoformat(value)
|
||||||
|
except ValueError:
|
||||||
|
# Try the trailing-Z form.
|
||||||
|
if value.endswith("Z"):
|
||||||
|
try:
|
||||||
|
from datetime import timezone
|
||||||
|
|
||||||
|
dt = datetime.fromisoformat(value[:-1]).replace(tzinfo=timezone.utc)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"Invalid ISO 8601 timestamp: {value!r}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid ISO 8601 timestamp: {value!r}")
|
||||||
|
return dt
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/usage/summary")
async def get_usage_summary(
    request: Request,
    department_id: str | None = None,
    user_id: str | None = None,
    start: str | None = None,
    end: str | None = None,
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
    """Return an aggregated usage summary.

    Query params: ``department_id``, ``user_id``, ``start``, ``end``
    (ISO 8601). Access is gated by the ``_require_admin`` dependency.
    """
    usage_store = _get_usage_store(request)
    service = get_usage_service()
    window_start = _parse_iso(start)
    window_end = _parse_iso(end)
    return await service.get_usage_summary(
        usage_store,
        department_id=department_id,
        user_id=user_id,
        start_time=window_start,
        end_time=window_end,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/usage/timeseries")
async def get_usage_timeseries(
    request: Request,
    department_id: str | None = None,
    user_id: str | None = None,
    start: str | None = None,
    end: str | None = None,
    interval: str = "day",
    admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
    """Return a time-bucketed usage series.

    Query params: ``department_id``, ``user_id``, ``start``, ``end``
    (ISO 8601), ``interval`` (``day`` or ``hour``, default ``day``).
    Any other ``interval`` is rejected with HTTP 400 before the store is
    looked up.
    """
    if interval != "day" and interval != "hour":
        raise HTTPException(status_code=400, detail="interval must be 'day' or 'hour'")
    usage_store = _get_usage_store(request)
    service = get_usage_service()
    window_start = _parse_iso(start)
    window_end = _parse_iso(end)
    return await service.get_usage_timeseries(
        usage_store,
        department_id=department_id,
        user_id=user_id,
        start_time=window_start,
        end_time=window_end,
        interval=interval,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/usage/by-model")
async def get_usage_by_model(
    request: Request,
    department_id: str | None = None,
    user_id: str | None = None,
    start: str | None = None,
    end: str | None = None,
    admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
    """Return a per-model usage breakdown.

    Query params: ``department_id``, ``user_id``, ``start``, ``end``
    (ISO 8601).
    """
    usage_store = _get_usage_store(request)
    service = get_usage_service()
    window_start = _parse_iso(start)
    window_end = _parse_iso(end)
    return await service.get_usage_by_model(
        usage_store,
        department_id=department_id,
        user_id=user_id,
        start_time=window_start,
        end_time=window_end,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/usage/top-users")
async def get_top_users(
    request: Request,
    department_id: str | None = None,
    user_id: str | None = None,
    start: str | None = None,
    end: str | None = None,
    limit: int = 10,
    admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
    """Return the top-N users by total token usage.

    Query params: ``department_id``, ``user_id``, ``start``, ``end``,
    ``limit`` (default 10). Out-of-range limits are clamped into
    ``1..100`` rather than rejected.
    """
    limit = min(max(limit, 1), 100)
    usage_store = _get_usage_store(request)
    service = get_usage_service()
    window_start = _parse_iso(start)
    window_end = _parse_iso(end)
    return await service.get_top_users(
        usage_store,
        department_id=department_id,
        user_id=user_id,
        start_time=window_start,
        end_time=window_end,
        limit=limit,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/usage/export")
async def export_usage(
    request: Request,
    department_id: str | None = None,
    user_id: str | None = None,
    start: str | None = None,
    end: str | None = None,
    format: str = "csv",
    admin: dict[str, Any] = Depends(_require_admin),
) -> Any:
    """Export raw usage records as CSV or JSON.

    Query params: ``department_id``, ``user_id``, ``start``, ``end``,
    ``format`` (``csv`` or ``json``, default ``csv``). Any other
    ``format`` is rejected with HTTP 400.

    The serialized body is returned with media type ``text/csv`` for CSV
    or ``application/json`` for JSON.
    """
    if format != "csv" and format != "json":
        raise HTTPException(status_code=400, detail="format must be 'csv' or 'json'")
    usage_store = _get_usage_store(request)
    service = get_usage_service()
    window_start = _parse_iso(start)
    window_end = _parse_iso(end)
    body = await service.export_usage(
        usage_store,
        department_id=department_id,
        user_id=user_id,
        start_time=window_start,
        end_time=window_end,
        format=format,
    )
    # The service already returns a serialized string, so it is sent as-is
    # under the matching media type.
    media_type = "text/csv" if format == "csv" else "application/json"
    return PlainTextResponse(content=body, media_type=media_type)
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,14 @@ Supports both non-streaming (`POST /api/v1/llm/chat`) and SSE streaming
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
||||||
|
from agentkit.llm.gateway import QuotaExceededError
|
||||||
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall
|
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall
|
||||||
|
from agentkit.server.admin.context import get_department_context
|
||||||
|
|
||||||
router = APIRouter(prefix="/llm", tags=["llm-gateway"])
|
router = APIRouter(prefix="/llm", tags=["llm-gateway"])
|
||||||
|
|
||||||
|
|
@ -66,14 +68,32 @@ def _serialize_chunk(chunk: StreamChunk) -> dict[str, Any]:
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _quota_error_payload(exc: QuotaExceededError) -> dict[str, Any]:
|
||||||
|
"""Build a structured 429 error body from a QuotaExceededError."""
|
||||||
|
return {
|
||||||
|
"error": "quota_exceeded",
|
||||||
|
"department_id": exc.department_id,
|
||||||
|
"quota_type": exc.quota_type,
|
||||||
|
"period": exc.period,
|
||||||
|
"limit": exc.limit,
|
||||||
|
"current": exc.current,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat")
|
@router.post("/chat")
|
||||||
async def chat(request: Request, body: LLMChatRequest) -> dict[str, Any]:
|
async def chat(
|
||||||
|
request: Request,
|
||||||
|
body: LLMChatRequest,
|
||||||
|
ctx: Any = Depends(get_department_context),
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Non-streaming LLM chat proxy.
|
"""Non-streaming LLM chat proxy.
|
||||||
|
|
||||||
Forwards the request to the configured LLMGateway and returns the
|
Forwards the request to the configured LLMGateway and returns the
|
||||||
serialized LLMResponse.
|
serialized LLMResponse. Quota-exceeded errors from the gateway are
|
||||||
|
translated to HTTP 429.
|
||||||
"""
|
"""
|
||||||
gateway = request.app.state.llm_gateway
|
gateway = request.app.state.llm_gateway
|
||||||
|
db_path = getattr(request.app.state, "auth_db_path", None)
|
||||||
try:
|
try:
|
||||||
response = await gateway.chat(
|
response = await gateway.chat(
|
||||||
messages=body.messages,
|
messages=body.messages,
|
||||||
|
|
@ -83,7 +103,12 @@ async def chat(request: Request, body: LLMChatRequest) -> dict[str, Any]:
|
||||||
timeout=body.timeout,
|
timeout=body.timeout,
|
||||||
temperature=body.temperature,
|
temperature=body.temperature,
|
||||||
max_tokens=body.max_tokens,
|
max_tokens=body.max_tokens,
|
||||||
|
user_id=ctx.user_id,
|
||||||
|
department_ids=ctx.department_ids if ctx.department_ids else None,
|
||||||
|
db_path=db_path,
|
||||||
)
|
)
|
||||||
|
except QuotaExceededError as e:
|
||||||
|
raise HTTPException(status_code=429, detail=_quota_error_payload(e)) from e
|
||||||
except ModelNotFoundError as e:
|
except ModelNotFoundError as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
raise HTTPException(status_code=404, detail=str(e)) from e
|
||||||
except LLMProviderError as e:
|
except LLMProviderError as e:
|
||||||
|
|
@ -92,7 +117,11 @@ async def chat(request: Request, body: LLMChatRequest) -> dict[str, Any]:
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat/stream")
|
@router.post("/chat/stream")
|
||||||
async def chat_stream(request: Request, body: LLMChatRequest) -> StreamingResponse:
|
async def chat_stream(
|
||||||
|
request: Request,
|
||||||
|
body: LLMChatRequest,
|
||||||
|
ctx: Any = Depends(get_department_context),
|
||||||
|
) -> StreamingResponse:
|
||||||
"""SSE streaming LLM chat proxy.
|
"""SSE streaming LLM chat proxy.
|
||||||
|
|
||||||
Each StreamChunk is serialized as `data: {json}\\n\\n`. The stream
|
Each StreamChunk is serialized as `data: {json}\\n\\n`. The stream
|
||||||
|
|
@ -101,6 +130,7 @@ async def chat_stream(request: Request, body: LLMChatRequest) -> StreamingRespon
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
gateway = request.app.state.llm_gateway
|
gateway = request.app.state.llm_gateway
|
||||||
|
db_path = getattr(request.app.state, "auth_db_path", None)
|
||||||
try:
|
try:
|
||||||
async for chunk in gateway.chat_stream(
|
async for chunk in gateway.chat_stream(
|
||||||
messages=body.messages,
|
messages=body.messages,
|
||||||
|
|
@ -110,9 +140,16 @@ async def chat_stream(request: Request, body: LLMChatRequest) -> StreamingRespon
|
||||||
timeout=body.timeout,
|
timeout=body.timeout,
|
||||||
temperature=body.temperature,
|
temperature=body.temperature,
|
||||||
max_tokens=body.max_tokens,
|
max_tokens=body.max_tokens,
|
||||||
|
user_id=ctx.user_id,
|
||||||
|
department_ids=ctx.department_ids if ctx.department_ids else None,
|
||||||
|
db_path=db_path,
|
||||||
):
|
):
|
||||||
payload = _serialize_chunk(chunk)
|
payload = _serialize_chunk(chunk)
|
||||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
except QuotaExceededError as e:
|
||||||
|
error_payload = _quota_error_payload(e)
|
||||||
|
error_payload["error"] = "quota_exceeded"
|
||||||
|
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
|
||||||
except ModelNotFoundError as e:
|
except ModelNotFoundError as e:
|
||||||
error_payload = {"error": "model_not_found", "detail": str(e)}
|
error_payload = {"error": "model_not_found", "detail": str(e)}
|
||||||
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,341 @@
|
||||||
|
"""Integration tests for the admin usage dashboard routes (U7).
|
||||||
|
|
||||||
|
Uses FastAPI TestClient with a test app that mounts only the
|
||||||
|
``admin_router`` from ``routes.admin``. The ``_require_admin``
|
||||||
|
dependency is overridden via ``app.dependency_overrides`` so the tests
|
||||||
|
don't need real JWTs — they can simulate admin and non-admin callers
|
||||||
|
directly.
|
||||||
|
|
||||||
|
The LLM gateway is replaced with a stub that exposes a
|
||||||
|
``_usage_tracker.store`` attribute pointing at an
|
||||||
|
:class:`InMemoryUsageStore` pre-populated with test records.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from agentkit.llm.protocol import TokenUsage
|
||||||
|
from agentkit.llm.providers.usage_store import InMemoryUsageStore
|
||||||
|
from agentkit.server.admin.usage_service import set_usage_service
|
||||||
|
from agentkit.server.auth.models import init_auth_db
|
||||||
|
from agentkit.server.routes import admin as admin_routes_module
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _StubTracker:
|
||||||
|
"""Minimal stub matching the UsageTracker surface used by routes."""
|
||||||
|
|
||||||
|
def __init__(self, store: InMemoryUsageStore) -> None:
|
||||||
|
self.store = store
|
||||||
|
|
||||||
|
|
||||||
|
class _StubGateway:
    """Minimal stub matching the LLMGateway surface used by routes."""

    def __init__(self, store: InMemoryUsageStore) -> None:
        # The store is exposed as ``gateway._usage_tracker.store``.
        self._usage_tracker = _StubTracker(store)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def store() -> InMemoryUsageStore:
    """Return a newly constructed ``InMemoryUsageStore`` with nothing recorded."""
    return InMemoryUsageStore()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def populated_store() -> InMemoryUsageStore:
    """Pre-populated store with a mix of records across users/depts/models.

    Four records are written in a fixed order; their prompt + completion
    tokens add up to 850 and their costs to 0.67.
    """
    s = InMemoryUsageStore()
    # (agent, model, prompt_tokens, completion_tokens, cost, latency_ms, user_id, department_id)
    rows = (
        ("agent1", "gpt-4o", 60, 40, 0.05, 200, "u1", "d1"),
        ("agent1", "claude", 120, 80, 0.10, 300, "u1", "d1"),
        ("agent2", "gpt-4o", 30, 20, 0.02, 100, "u2", "d2"),
        ("agent3", "gpt-4o", 300, 200, 0.50, 400, "u3", "d1"),
    )
    for agent, model, prompt, completion, cost, latency_ms, user_id, department_id in rows:
        s.record(
            agent,
            model,
            TokenUsage(prompt_tokens=prompt, completion_tokens=completion),
            cost=cost,
            latency_ms=latency_ms,
            user_id=user_id,
            department_id=department_id,
        )
    return s
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset_singletons():
    """Call ``set_usage_service(None)`` before and after every test."""
    set_usage_service(None)
    yield
    set_usage_service(None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def tmp_auth_db(tmp_path: Path) -> Path:
    """Run ``init_auth_db`` on ``tmp_path / "usage_routes.db"`` and return that path."""
    db_path = tmp_path / "usage_routes.db"
    await init_auth_db(db_path)
    return db_path
|
||||||
|
|
||||||
|
|
||||||
|
def _make_admin_app(store: InMemoryUsageStore, tmp_auth_db: Path) -> FastAPI:
    """Build a FastAPI app with admin router + stub gateway.

    Args:
        store: Usage store wrapped by the stub gateway placed on ``app.state``.
        tmp_auth_db: Auth database path; stored on ``app.state`` as a string.

    Returns:
        An app with the admin router mounted under ``/api/v1`` and the
        ``_require_admin`` dependency overridden to admit an admin user.
    """
    app = FastAPI()
    app.state.auth_db_path = str(tmp_auth_db)
    app.state.llm_gateway = _StubGateway(store)
    app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
    # Default: allow admin access. Tests that need a rejection replace this override.
    app.dependency_overrides[admin_routes_module._require_admin] = _make_admin_user
    return app
|
||||||
|
|
||||||
|
|
||||||
|
def _make_admin_user() -> dict[str, Any]:
|
||||||
|
return {"user_id": "admin-1", "username": "admin", "role": "admin"}
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_forbidden() -> NoReturn:
    """Dependency override that simulates a non-admin caller.

    Raises:
        HTTPException: Always, with status 403.
    """
    # NoReturn is imported locally so this block does not depend on a
    # file-level import; annotations are lazy in this module.
    from typing import NoReturn  # noqa: F401

    raise HTTPException(status_code=403, detail="Admin permission required")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /admin/usage/summary
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUsageSummaryRoute:
    """Route tests for ``GET /api/v1/admin/usage/summary``."""

    def test_returns_200_with_data(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
        """Unfiltered summary reports totals over all four fixture records."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/summary")
        assert resp.status_code == 200
        body = resp.json()
        assert body["total_tokens"] == 850
        assert body["total_cost"] == pytest.approx(0.67, abs=1e-6)
        assert body["total_requests"] == 4
        assert "gpt-4o" in body["by_model"]
        assert "u1" in body["by_user"]
        assert "d1" in body["by_department"]

    def test_with_department_filter(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
        """``department_id=d2`` narrows the summary to the single d2 record."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/summary", params={"department_id": "d2"})
        assert resp.status_code == 200
        body = resp.json()
        assert body["total_tokens"] == 50
        assert body["total_requests"] == 1

    def test_empty_store_returns_200_with_zeros(
        self, store: InMemoryUsageStore, tmp_auth_db: Path
    ):
        """An empty store yields a 200 response with zeroed totals."""
        app = _make_admin_app(store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/summary")
        assert resp.status_code == 200
        body = resp.json()
        assert body["total_tokens"] == 0
        assert body["total_cost"] == 0.0
        assert body["total_requests"] == 0

    def test_non_admin_returns_403(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
        """When the admin dependency raises 403, the route responds with 403."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/summary")
        assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /admin/usage/timeseries
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUsageTimeseriesRoute:
    """Route tests for ``GET /api/v1/admin/usage/timeseries``."""

    def test_returns_200_with_data(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
        """The response is a non-empty list of buckets covering all 850 tokens."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/timeseries")
        assert resp.status_code == 200
        body = resp.json()
        assert isinstance(body, list)
        assert len(body) >= 1
        assert "timestamp" in body[0]
        assert "tokens" in body[0]
        # Sum over buckets rather than reading body[0]: more than one bucket is
        # allowed above, so the total must not depend on all records sharing one.
        assert sum(bucket["tokens"] for bucket in body) == 850

    def test_invalid_interval_returns_400(
        self, populated_store: InMemoryUsageStore, tmp_auth_db: Path
    ):
        """``interval=week`` is rejected with 400."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/timeseries", params={"interval": "week"})
        assert resp.status_code == 400

    def test_empty_store_returns_200_empty_list(
        self, store: InMemoryUsageStore, tmp_auth_db: Path
    ):
        """An empty store yields a 200 response with an empty list."""
        app = _make_admin_app(store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/timeseries")
        assert resp.status_code == 200
        assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /admin/usage/by-model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUsageByModelRoute:
    """Route tests for ``GET /api/v1/admin/usage/by-model``."""

    def test_returns_200_with_breakdown(
        self, populated_store: InMemoryUsageStore, tmp_auth_db: Path
    ):
        """The response lists one row per model present in the fixture data."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/by-model")
        assert resp.status_code == 200
        body = resp.json()
        assert isinstance(body, list)
        models = {row["model"] for row in body}
        assert models == {"gpt-4o", "claude"}

    def test_empty_store_returns_200_empty_list(
        self, store: InMemoryUsageStore, tmp_auth_db: Path
    ):
        """An empty store yields a 200 response with an empty list."""
        app = _make_admin_app(store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/by-model")
        assert resp.status_code == 200
        assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /admin/usage/top-users
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTopUsersRoute:
    """Route tests for ``GET /api/v1/admin/usage/top-users``."""

    def test_returns_200_sorted_by_tokens(
        self, populated_store: InMemoryUsageStore, tmp_auth_db: Path
    ):
        """Users come back ordered by token count, highest first."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/top-users")
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 3
        # u3 (500), u1 (300), u2 (50)
        assert body[0]["user_id"] == "u3"
        assert body[0]["tokens"] == 500
        assert body[1]["user_id"] == "u1"
        assert body[2]["user_id"] == "u2"

    def test_limit_param_respected(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
        """``limit=2`` caps the response at two rows."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/top-users", params={"limit": 2})
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 2

    def test_empty_store_returns_200_empty_list(
        self, store: InMemoryUsageStore, tmp_auth_db: Path
    ):
        """An empty store yields a 200 response with an empty list."""
        app = _make_admin_app(store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/top-users")
        assert resp.status_code == 200
        assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# /admin/usage/export
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUsageExportRoute:
    """Route tests for ``GET /api/v1/admin/usage/export``."""

    def test_csv_export_returns_200(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
        """CSV export returns ``text/csv`` with one row per fixture record."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/export", params={"format": "csv"})
        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/csv")
        reader = csv.DictReader(io.StringIO(resp.text))
        rows = list(reader)
        assert len(rows) == 4
        assert "timestamp" in rows[0]
        assert "user_id" in rows[0]
        assert "department_id" in rows[0]

    def test_json_export_returns_200(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
        """JSON export returns a list with one entry per fixture record."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/export", params={"format": "json"})
        assert resp.status_code == 200
        # PlainTextResponse returns text/plain or application/json depending on media_type,
        # so the body is parsed from the raw text and the content type is not asserted.
        body = json.loads(resp.text)
        assert isinstance(body, list)
        assert len(body) == 4

    def test_invalid_format_returns_400(
        self, populated_store: InMemoryUsageStore, tmp_auth_db: Path
    ):
        """``format=xml`` is rejected with 400."""
        app = _make_admin_app(populated_store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/export", params={"format": "xml"})
        assert resp.status_code == 400

    def test_empty_store_csv_returns_header_only(
        self, store: InMemoryUsageStore, tmp_auth_db: Path
    ):
        """CSV export of an empty store has no data rows but keeps the header."""
        app = _make_admin_app(store, tmp_auth_db)
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/export", params={"format": "csv"})
        assert resp.status_code == 200
        reader = csv.DictReader(io.StringIO(resp.text))
        rows = list(reader)
        assert rows == []
        # Header should still be present.
        assert "timestamp" in resp.text
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Missing gateway
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMissingGateway:
    """Behaviour of the usage routes when ``app.state`` has no ``llm_gateway``."""

    def test_summary_returns_500_without_gateway(self, tmp_auth_db: Path):
        """The summary route responds with 500 when no gateway is configured."""
        app = FastAPI()
        app.state.auth_db_path = str(tmp_auth_db)
        # No llm_gateway on app.state.
        app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
        app.dependency_overrides[admin_routes_module._require_admin] = _make_admin_user
        client = TestClient(app)
        resp = client.get("/api/v1/admin/usage/summary")
        assert resp.status_code == 500
|
||||||
|
|
@ -0,0 +1,330 @@
|
||||||
|
"""Unit tests for UsageService (U7 — usage dashboard aggregations)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.llm.protocol import TokenUsage
|
||||||
|
from agentkit.llm.providers.usage_store import InMemoryUsageStore
|
||||||
|
from agentkit.server.admin.usage_service import (
|
||||||
|
UsageService,
|
||||||
|
get_usage_service,
|
||||||
|
set_usage_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def store() -> InMemoryUsageStore:
    """Return a newly constructed ``InMemoryUsageStore`` with nothing recorded."""
    return InMemoryUsageStore()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def service() -> UsageService:
    """Return a newly constructed ``UsageService``."""
    return UsageService()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset_singleton():
    """Reset the UsageService singleton before and after each test."""
    set_usage_service(None)
    yield
    set_usage_service(None)
|
||||||
|
|
||||||
|
|
||||||
|
def _populate_store(store: InMemoryUsageStore) -> None:
    """Populate ``store`` with a mix of records for testing.

    Four records are written in a fixed order; their prompt + completion
    tokens add up to 850 and their costs to 0.67.
    """
    # (agent, model, prompt_tokens, completion_tokens, cost, latency_ms, user_id, department_id)
    rows = (
        # User u1 in dept d1, gpt-4o, 100 tokens, $0.05
        ("agent1", "gpt-4o", 60, 40, 0.05, 200, "u1", "d1"),
        # User u1 in dept d1, claude, 200 tokens, $0.10
        ("agent1", "claude", 120, 80, 0.10, 300, "u1", "d1"),
        # User u2 in dept d2, gpt-4o, 50 tokens, $0.02
        ("agent2", "gpt-4o", 30, 20, 0.02, 100, "u2", "d2"),
        # User u3 in dept d1, gpt-4o, 500 tokens, $0.50 (top user)
        ("agent3", "gpt-4o", 300, 200, 0.50, 400, "u3", "d1"),
    )
    for agent, model, prompt, completion, cost, latency_ms, user_id, department_id in rows:
        store.record(
            agent,
            model,
            TokenUsage(prompt_tokens=prompt, completion_tokens=completion),
            cost=cost,
            latency_ms=latency_ms,
            user_id=user_id,
            department_id=department_id,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_usage_summary
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUsageSummary:
    """Unit tests for ``UsageService.get_usage_summary``."""

    async def test_summary_aggregates_all(self, service: UsageService, store: InMemoryUsageStore):
        """Unfiltered summary totals all records and breaks them down three ways."""
        _populate_store(store)
        result = await service.get_usage_summary(store)
        assert result["total_tokens"] == 850
        assert result["total_cost"] == pytest.approx(0.67, abs=1e-6)
        assert result["total_requests"] == 4
        # by_model: gpt-4o (3 records, 650 tokens), claude (1, 200)
        assert "gpt-4o" in result["by_model"]
        assert "claude" in result["by_model"]
        assert result["by_model"]["gpt-4o"]["count"] == 3
        assert result["by_model"]["gpt-4o"]["total_tokens"] == 650
        # by_user: u1 (2 records, 300 tokens), u2 (1, 50), u3 (1, 500)
        assert result["by_user"]["u1"]["total_tokens"] == 300
        assert result["by_user"]["u2"]["total_tokens"] == 50
        assert result["by_user"]["u3"]["total_tokens"] == 500
        # by_department: d1 (3 records, 800 tokens), d2 (1, 50)
        assert result["by_department"]["d1"]["total_tokens"] == 800
        assert result["by_department"]["d2"]["total_tokens"] == 50

    async def test_summary_with_department_filter(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """``department_id="d1"`` keeps only the three d1 records."""
        _populate_store(store)
        result = await service.get_usage_summary(store, department_id="d1")
        assert result["total_tokens"] == 800
        assert result["total_requests"] == 3
        # Only u1 and u3 are in d1.
        assert "u1" in result["by_user"]
        assert "u3" in result["by_user"]
        assert "u2" not in result["by_user"]

    async def test_summary_with_user_filter(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """``user_id="u2"`` keeps only the single u2 record."""
        _populate_store(store)
        result = await service.get_usage_summary(store, user_id="u2")
        assert result["total_tokens"] == 50
        assert result["total_requests"] == 1
        assert "u2" in result["by_user"]

    async def test_summary_with_empty_store(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """An empty store produces zero totals and empty breakdowns."""
        result = await service.get_usage_summary(store)
        assert result["total_tokens"] == 0
        assert result["total_cost"] == 0.0
        assert result["total_requests"] == 0
        assert result["by_model"] == {}
        assert result["by_user"] == {}
        assert result["by_department"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_usage_timeseries
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUsageTimeseries:
    """Unit tests for ``UsageService.get_usage_timeseries``."""

    async def test_timeseries_day_buckets(self, service: UsageService, store: InMemoryUsageStore):
        """Day buckets together account for every fixture record."""
        _populate_store(store)
        result = await service.get_usage_timeseries(store, interval="day")
        # Records are written back-to-back, so normally there is one bucket.
        # Totals are summed over all buckets so the test still holds if the
        # writes happen to straddle a bucket boundary.
        assert len(result) >= 1
        assert "timestamp" in result[0]
        assert sum(bucket["tokens"] for bucket in result) == 850
        assert sum(bucket["cost"] for bucket in result) == pytest.approx(0.67, abs=1e-6)
        assert sum(bucket["requests"] for bucket in result) == 4

    async def test_timeseries_hour_buckets(self, service: UsageService, store: InMemoryUsageStore):
        """Hour buckets together account for all 850 tokens."""
        _populate_store(store)
        result = await service.get_usage_timeseries(store, interval="hour")
        # Summed for the same boundary-straddling reason as the day-bucket test.
        assert len(result) >= 1
        assert sum(bucket["tokens"] for bucket in result) == 850

    async def test_timeseries_with_department_filter(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """``department_id="d2"`` leaves only the single 50-token record."""
        _populate_store(store)
        result = await service.get_usage_timeseries(store, department_id="d2", interval="day")
        assert len(result) >= 1
        assert result[0]["tokens"] == 50

    async def test_timeseries_empty_store(self, service: UsageService, store: InMemoryUsageStore):
        """An empty store produces an empty list."""
        result = await service.get_usage_timeseries(store, interval="day")
        assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_usage_by_model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetUsageByModel:
    """Unit tests for ``UsageService.get_usage_by_model``."""

    async def test_by_model_breakdown(self, service: UsageService, store: InMemoryUsageStore):
        """Each model gets one row with its token and request totals."""
        _populate_store(store)
        result = await service.get_usage_by_model(store)
        # Two models expected: claude and gpt-4o (row order is not asserted).
        assert len(result) == 2
        models = {row["model"] for row in result}
        assert models == {"gpt-4o", "claude"}
        gpt_row = next(r for r in result if r["model"] == "gpt-4o")
        assert gpt_row["tokens"] == 650
        assert gpt_row["requests"] == 3
        claude_row = next(r for r in result if r["model"] == "claude")
        assert claude_row["tokens"] == 200
        assert claude_row["requests"] == 1

    async def test_by_model_with_department_filter(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """``department_id="d2"`` leaves a single gpt-4o row."""
        _populate_store(store)
        result = await service.get_usage_by_model(store, department_id="d2")
        # d2 only has gpt-4o (50 tokens, 1 request)
        assert len(result) == 1
        assert result[0]["model"] == "gpt-4o"
        assert result[0]["tokens"] == 50

    async def test_by_model_empty_store(self, service: UsageService, store: InMemoryUsageStore):
        """An empty store produces an empty list."""
        result = await service.get_usage_by_model(store)
        assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_top_users
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetTopUsers:
    """Unit tests for ``UsageService.get_top_users``."""

    async def test_top_users_sorted_by_tokens(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """Users are returned ordered by token count, highest first."""
        _populate_store(store)
        result = await service.get_top_users(store, limit=10)
        # u3 (500), u1 (300), u2 (50)
        assert len(result) == 3
        assert result[0]["user_id"] == "u3"
        assert result[0]["tokens"] == 500
        assert result[1]["user_id"] == "u1"
        assert result[1]["tokens"] == 300
        assert result[2]["user_id"] == "u2"
        assert result[2]["tokens"] == 50

    async def test_top_users_respects_limit(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """``limit=2`` keeps only the two highest-usage users."""
        _populate_store(store)
        result = await service.get_top_users(store, limit=2)
        assert len(result) == 2
        assert result[0]["user_id"] == "u3"
        assert result[1]["user_id"] == "u1"

    async def test_top_users_with_department_filter(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """``department_id="d1"`` restricts the ranking to d1's users."""
        _populate_store(store)
        result = await service.get_top_users(store, department_id="d1", limit=10)
        # d1 has u1 and u3
        assert len(result) == 2
        assert result[0]["user_id"] == "u3"
        assert result[1]["user_id"] == "u1"

    async def test_top_users_empty_store(self, service: UsageService, store: InMemoryUsageStore):
        """An empty store produces an empty list."""
        result = await service.get_top_users(store, limit=10)
        assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# export_usage
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExportUsage:
    """Unit tests for ``UsageService.export_usage``."""

    async def test_export_csv_has_header_and_rows(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """CSV export has the expected columns and one row per record."""
        _populate_store(store)
        body = await service.export_usage(store, format="csv")
        reader = csv.DictReader(io.StringIO(body))
        rows = list(reader)
        assert len(rows) == 4
        # Verify headers
        assert "timestamp" in rows[0]
        assert "agent_name" in rows[0]
        assert "model" in rows[0]
        assert "user_id" in rows[0]
        assert "department_id" in rows[0]
        # Verify a known record
        gpt_rows = [r for r in rows if r["model"] == "gpt-4o"]
        assert len(gpt_rows) == 3

    async def test_export_json_returns_valid_json(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """JSON export parses to a list with one entry per record."""
        _populate_store(store)
        body = await service.export_usage(store, format="json")
        data = json.loads(body)
        assert isinstance(data, list)
        assert len(data) == 4
        assert "timestamp" in data[0]
        assert "user_id" in data[0]
        assert "department_id" in data[0]

    async def test_export_csv_empty_store_returns_header_only(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """CSV export of an empty store has no data rows but keeps the header."""
        body = await service.export_usage(store, format="csv")
        reader = csv.DictReader(io.StringIO(body))
        rows = list(reader)
        assert rows == []
        # Header should still be present.
        assert "timestamp" in body

    async def test_export_with_department_filter(
        self, service: UsageService, store: InMemoryUsageStore
    ):
        """``department_id="d2"`` exports only the single d2 record."""
        _populate_store(store)
        body = await service.export_usage(store, department_id="d2", format="csv")
        reader = csv.DictReader(io.StringIO(body))
        rows = list(reader)
        assert len(rows) == 1
        assert rows[0]["department_id"] == "d2"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Singleton helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonHelpers:
    """Tests for ``get_usage_service`` / ``set_usage_service``."""

    def test_get_usage_service_returns_singleton(self):
        """Two consecutive calls return the same object."""
        first = get_usage_service()
        second = get_usage_service()
        assert first is second

    def test_set_usage_service_overrides(self):
        """An explicit instance is returned until it is cleared with ``None``."""
        custom = UsageService()
        set_usage_service(custom)
        assert get_usage_service() is custom
        # Clearing falls back to a new lazy instance.
        set_usage_service(None)
        new_one = get_usage_service()
        assert new_one is not custom
|
||||||
|
|
@ -0,0 +1,321 @@
|
||||||
|
"""Unit tests for LLMGateway quota enforcement (U7).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- QuotaExceededError raised when token_limit exceeded
|
||||||
|
- QuotaExceededError raised when cost_limit exceeded
|
||||||
|
- QuotaExceededError raised when model not in whitelist
|
||||||
|
- No quota set → request allowed
|
||||||
|
- Multi-department: strictest-wins (one exceeds, other doesn't → rejected)
|
||||||
|
- QuotaExceededError carries the right metadata
|
||||||
|
- Usage recording still attaches user_id + department_id on success
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.llm.gateway import LLMGateway, QuotaExceededError
|
||||||
|
from agentkit.llm.protocol import (
|
||||||
|
LLMProvider,
|
||||||
|
LLMRequest,
|
||||||
|
LLMResponse,
|
||||||
|
TokenUsage,
|
||||||
|
)
|
||||||
|
from agentkit.llm.providers.usage_store import InMemoryUsageStore
|
||||||
|
from agentkit.server.admin.quota_service import (
|
||||||
|
get_quota_service,
|
||||||
|
set_quota_service,
|
||||||
|
)
|
||||||
|
from agentkit.server.auth.models import init_auth_db
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test doubles
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class FakeProvider(LLMProvider):
    """Stub LLMProvider whose ``chat`` always answers with a canned reply.

    The most recent request and the number of ``chat`` calls are kept on
    the instance so tests can inspect them.
    """

    def __init__(self, name: str = "fake"):
        self.call_count = 0
        self.last_request: LLMRequest | None = None
        self._name = name

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Remember *request*, bump the call counter and return the canned reply."""
        self.last_request = request
        self.call_count += 1
        reply_text = f"response from {self._name}"
        # Fixed usage: 100 prompt + 50 completion tokens on every call.
        canned_usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
        return LLMResponse(content=reply_text, model=request.model, usage=canned_usage)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def store() -> InMemoryUsageStore:
    """Provide a newly constructed in-memory usage store."""
    fresh_store = InMemoryUsageStore()
    return fresh_store
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def gateway(store: InMemoryUsageStore) -> LLMGateway:
    """Build an LLMGateway that records into ``store`` and has a fake "openai" provider."""
    llm_gateway = LLMGateway(usage_store=store)
    fake_openai = FakeProvider("openai")
    llm_gateway.register_provider("openai", fake_openai)
    return llm_gateway
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
    """Initialise an auth database at ``tmp_path / "auth.db"`` and return its path."""
    target = tmp_path / "auth.db"
    await init_auth_db(target)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset_quota_singleton():
    """Clear the QuotaService singleton on both sides of every test."""
    # Setup: make sure no service instance is installed when the test starts.
    set_quota_service(None)
    yield
    # Teardown: drop whatever instance the test left installed.
    set_quota_service(None)
|
||||||
|
|
||||||
|
|
||||||
|
def _random_dept_id() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quota enforcement
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaEnforcement:
    """Quota behaviour of ``LLMGateway.chat`` when user/department context is passed."""

    @staticmethod
    def _record_prior_usage(gateway: LLMGateway, department_id: str, cost: float = 0.0) -> None:
        """Seed the gateway's usage tracker with one 150-token record for *department_id*.

        ``cost`` is passed through unchanged to the tracker's ``record`` call.
        """
        gateway._usage_tracker.record(
            agent_name="prev",
            model="openai/gpt-4o",
            usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
            cost=cost,
            latency_ms=10,
            user_id="u1",
            department_id=department_id,
        )

    async def test_no_quota_set_allows_request(self, gateway: LLMGateway, fresh_db: Path):
        """When no quota is configured, the request is allowed."""
        dept_id = _random_dept_id()
        response = await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=[dept_id],
            db_path=fresh_db,
        )
        assert response.content == "response from openai"

    async def test_token_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path):
        """token_limit quota exceeded → QuotaExceededError."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # Set a tiny token limit (1 token) — any usage will exceed it.
        await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily")

        # Pre-populate the usage store so the daily total > 1.
        self._record_prior_usage(gateway, dept_id)

        with pytest.raises(QuotaExceededError) as exc_info:
            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model="openai/gpt-4o",
                user_id="u1",
                department_ids=[dept_id],
                db_path=fresh_db,
            )
        err = exc_info.value
        assert err.department_id == dept_id
        assert err.quota_type == "token_limit"
        assert err.period == "daily"
        assert err.limit == 1
        assert err.current == 150  # 100 prompt + 50 completion

    async def test_cost_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path):
        """cost_limit quota exceeded → QuotaExceededError."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # cost_limit is in cents. Set 1 cent.
        await svc.set_quota(fresh_db, dept_id, "cost_limit", 1, period="daily")

        # Pre-populate usage with $1.00 cost = 100 cents, exceeding the 1-cent limit.
        self._record_prior_usage(gateway, dept_id, cost=1.00)

        with pytest.raises(QuotaExceededError) as exc_info:
            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model="openai/gpt-4o",
                user_id="u1",
                department_ids=[dept_id],
                db_path=fresh_db,
            )
        err = exc_info.value
        assert err.quota_type == "cost_limit"
        assert err.period == "daily"
        assert err.limit == 1
        # current is in cents (100 cents = $1.00).
        assert err.current == 100.0

    async def test_model_whitelist_rejection_raises(self, gateway: LLMGateway, fresh_db: Path):
        """Model not in whitelist → QuotaExceededError with quota_type=model_whitelist."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # Whitelist only allows "claude" — gateway is calling "gpt-4o".
        await svc.set_quota(fresh_db, dept_id, "model_whitelist", ["claude"], period="daily")

        with pytest.raises(QuotaExceededError) as exc_info:
            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model="openai/gpt-4o",
                user_id="u1",
                department_ids=[dept_id],
                db_path=fresh_db,
            )
        err = exc_info.value
        assert err.quota_type == "model_whitelist"
        assert err.department_id == dept_id
        # For model_whitelist, current is the rejected model name.
        assert err.current == "openai/gpt-4o"

    async def test_model_whitelist_allows_listed_model(self, gateway: LLMGateway, fresh_db: Path):
        """Model in whitelist → request allowed."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # Whitelist uses the full resolved model identifier (provider/model).
        await svc.set_quota(fresh_db, dept_id, "model_whitelist", ["openai/gpt-4o"], period="daily")
        response = await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=[dept_id],
            db_path=fresh_db,
        )
        assert response.content == "response from openai"

    async def test_multi_department_strictest_wins(self, gateway: LLMGateway, fresh_db: Path):
        """One department exceeds, the other doesn't → rejected (strictest wins)."""
        dept_ok = _random_dept_id()
        dept_bad = _random_dept_id()
        svc = get_quota_service()
        # dept_bad has a 1-token limit; dept_ok has a 1M-token limit.
        await svc.set_quota(fresh_db, dept_bad, "token_limit", 1, period="daily")
        await svc.set_quota(fresh_db, dept_ok, "token_limit", 1_000_000, period="daily")

        # Pre-populate usage for dept_bad so it exceeds.
        self._record_prior_usage(gateway, dept_bad)

        with pytest.raises(QuotaExceededError) as exc_info:
            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model="openai/gpt-4o",
                user_id="u1",
                department_ids=[dept_ok, dept_bad],
                db_path=fresh_db,
            )
        # The error should reference dept_bad (the one that exceeded).
        assert exc_info.value.department_id == dept_bad

    async def test_quota_check_skipped_without_db_path(self, gateway: LLMGateway, fresh_db: Path):
        """When db_path is None, no quota check is performed."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily")

        # Seed usage above the limit. With db_path=fresh_db this exact setup
        # is rejected (see test_token_limit_exceeded_raises), so a success
        # below shows that the check was skipped rather than merely passed.
        self._record_prior_usage(gateway, dept_id)

        # Even with a quota set and exceeded, calling without db_path should succeed.
        response = await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=[dept_id],
            db_path=None,
        )
        assert response.content == "response from openai"

    async def test_quota_check_skipped_without_department_ids(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """When department_ids is None, no quota check is performed."""
        response = await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=None,
            db_path=fresh_db,
        )
        assert response.content == "response from openai"

    async def test_usage_recorded_with_user_and_department(
        self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
    ):
        """After a successful call, the usage record carries user_id + department_id."""
        dept_id = _random_dept_id()
        await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=[dept_id],
            db_path=fresh_db,
        )
        summary = store.get_usage()
        assert len(summary.records) == 1
        rec = summary.records[0]
        assert rec.user_id == "u1"
        assert rec.department_id == dept_id
        assert rec.model == "gpt-4o"
        assert rec.total_tokens == 150  # 100 prompt + 50 completion
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# QuotaExceededError dataclass-like behavior
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaExceededError:
    """Construction and string rendering of ``QuotaExceededError``."""

    def test_error_message_includes_metadata(self):
        """``str(err)`` mentions every field given to the constructor."""
        err = QuotaExceededError(
            department_id="d1",
            quota_type="token_limit",
            period="daily",
            limit=1000,
            current=1500,
        )
        rendered = str(err)
        for fragment in ("d1", "token_limit", "daily", "1000", "1500"):
            assert fragment in rendered

    def test_error_attributes_preserved(self):
        """Positional constructor arguments are exposed as same-named attributes."""
        positional = ("d1", "cost_limit", "monthly", 5000, 6000)
        err = QuotaExceededError(*positional)
        observed = (err.department_id, err.quota_type, err.period, err.limit, err.current)
        assert observed == positional
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
"""Unit tests for UsageStore (U4 — UsageStore Persistence)."""
|
"""Unit tests for UsageStore (U4 — UsageStore Persistence)."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from agentkit.llm.protocol import TokenUsage
|
from agentkit.llm.protocol import TokenUsage
|
||||||
|
|
@ -114,6 +113,128 @@ class TestInMemoryUsageStore:
|
||||||
# Should be parseable as ISO 8601
|
# Should be parseable as ISO 8601
|
||||||
datetime.fromisoformat(rec.timestamp)
|
datetime.fromisoformat(rec.timestamp)
|
||||||
|
|
||||||
|
def test_record_with_user_and_department(self):
    """A record stored with user_id/department_id exposes both on read-back."""
    usage_store = InMemoryUsageStore()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    attribution = {"user_id": "u1", "department_id": "d1"}
    usage_store.record("agent1", "gpt-4", tokens, cost=0.05, latency_ms=200, **attribution)
    stored = usage_store.get_usage().records[0]
    assert stored.user_id == "u1"
    assert stored.department_id == "d1"
|
||||||
|
|
||||||
|
def test_record_defaults_user_department_to_none(self):
    """Omitting user_id/department_id in ``record`` leaves both as ``None``."""
    usage_store = InMemoryUsageStore()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    usage_store.record("agent1", "gpt-4", tokens, cost=0.05, latency_ms=200)
    stored = usage_store.get_usage().records[0]
    for field_value in (stored.user_id, stored.department_id):
        assert field_value is None
|
||||||
|
|
||||||
|
def test_get_usage_filters_by_user(self):
    """``get_usage(user_id=...)`` returns only that user's records."""
    usage_store = InMemoryUsageStore()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    # One record per user, otherwise identical.
    for uid in ("u1", "u2"):
        usage_store.record("agent1", "gpt-4", tokens, cost=0.05, latency_ms=200, user_id=uid)
    filtered = usage_store.get_usage(user_id="u1")
    assert len(filtered.records) == 1
    assert filtered.records[0].user_id == "u1"
|
||||||
|
|
||||||
|
def test_get_usage_filters_by_department(self):
    """``get_usage(department_id=...)`` returns only that department's records."""
    usage_store = InMemoryUsageStore()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    # One record per department, otherwise identical.
    for dept in ("d1", "d2"):
        usage_store.record(
            "agent1", "gpt-4", tokens, cost=0.05, latency_ms=200, department_id=dept
        )
    filtered = usage_store.get_usage(department_id="d1")
    assert len(filtered.records) == 1
    assert filtered.records[0].department_id == "d1"
|
||||||
|
|
||||||
|
def test_get_usage_by_user(self):
    """``get_usage_by_user`` returns one user's records and their token total."""
    usage_store = InMemoryUsageStore()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    # Two records that differ only in who they are attributed to.
    for uid, dept in (("u1", "d1"), ("u2", "d2")):
        usage_store.record(
            "agent1",
            "gpt-4",
            tokens,
            cost=0.05,
            latency_ms=200,
            user_id=uid,
            department_id=dept,
        )
    per_user = usage_store.get_usage_by_user("u1")
    assert len(per_user.records) == 1
    assert per_user.records[0].user_id == "u1"
    assert per_user.total_tokens == 150
|
||||||
|
|
||||||
|
def test_get_usage_by_department(self):
    """``get_usage_by_department`` returns one department's records and token total."""
    usage_store = InMemoryUsageStore()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    # Two records that differ only in who they are attributed to.
    for uid, dept in (("u1", "d1"), ("u2", "d2")):
        usage_store.record(
            "agent1",
            "gpt-4",
            tokens,
            cost=0.05,
            latency_ms=200,
            user_id=uid,
            department_id=dept,
        )
    per_department = usage_store.get_usage_by_department("d1")
    assert len(per_department.records) == 1
    assert per_department.records[0].department_id == "d1"
    assert per_department.total_tokens == 150
|
||||||
|
|
||||||
|
def test_summary_includes_by_user_and_by_department(self):
    """The unfiltered summary carries per-user and per-department aggregates."""
    usage_store = InMemoryUsageStore()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    # Two identical records for the same user and department.
    for _ in range(2):
        usage_store.record(
            "agent1",
            "gpt-4",
            tokens,
            cost=0.05,
            latency_ms=200,
            user_id="u1",
            department_id="d1",
        )
    overall = usage_store.get_usage()
    assert "u1" in overall.by_user
    user_bucket = overall.by_user["u1"]
    assert user_bucket["count"] == 2
    assert user_bucket["total_tokens"] == 300
    assert "d1" in overall.by_department
    assert overall.by_department["d1"]["count"] == 2
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# UsageRecord / UsageBucket / UsageSummary dataclasses
|
# UsageRecord / UsageBucket / UsageSummary dataclasses
|
||||||
|
|
@ -123,17 +244,25 @@ class TestInMemoryUsageStore:
|
||||||
class TestDataclasses:
|
class TestDataclasses:
|
||||||
def test_usage_record_auto_timestamp(self):
|
def test_usage_record_auto_timestamp(self):
|
||||||
rec = UsageRecord(
|
rec = UsageRecord(
|
||||||
agent_name="a", model="m",
|
agent_name="a",
|
||||||
prompt_tokens=1, completion_tokens=1,
|
model="m",
|
||||||
total_tokens=2, cost=0.01, latency_ms=100,
|
prompt_tokens=1,
|
||||||
|
completion_tokens=1,
|
||||||
|
total_tokens=2,
|
||||||
|
cost=0.01,
|
||||||
|
latency_ms=100,
|
||||||
)
|
)
|
||||||
assert rec.timestamp != ""
|
assert rec.timestamp != ""
|
||||||
|
|
||||||
def test_usage_record_explicit_timestamp(self):
|
def test_usage_record_explicit_timestamp(self):
|
||||||
rec = UsageRecord(
|
rec = UsageRecord(
|
||||||
agent_name="a", model="m",
|
agent_name="a",
|
||||||
prompt_tokens=1, completion_tokens=1,
|
model="m",
|
||||||
total_tokens=2, cost=0.01, latency_ms=100,
|
prompt_tokens=1,
|
||||||
|
completion_tokens=1,
|
||||||
|
total_tokens=2,
|
||||||
|
cost=0.01,
|
||||||
|
latency_ms=100,
|
||||||
timestamp="2026-01-01T00:00:00+00:00",
|
timestamp="2026-01-01T00:00:00+00:00",
|
||||||
)
|
)
|
||||||
assert rec.timestamp == "2026-01-01T00:00:00+00:00"
|
assert rec.timestamp == "2026-01-01T00:00:00+00:00"
|
||||||
|
|
@ -208,6 +337,98 @@ class TestRedisUsageStoreMocked:
|
||||||
assert len(key) == 10
|
assert len(key) == 10
|
||||||
assert key[4] == "-"
|
assert key[4] == "-"
|
||||||
|
|
||||||
|
def test_v2_keys_with_user_and_department(self):
    """``_v2_keys`` embeds date, user id and department id in both key names."""
    redis_store = self._make_store()
    hash_key, list_key = redis_store._v2_keys("2026-06-21", "u1", "d1")
    expected_suffix = "v2:2026-06-21:u1:d1"
    assert hash_key == "agentkit:usage:" + expected_suffix
    assert list_key == "agentkit:usage_records:" + expected_suffix
|
||||||
|
|
||||||
|
def test_v2_keys_with_none_user_and_department(self):
    """``_v2_keys`` renders a missing user/department as the literal "none"."""
    redis_store = self._make_store()
    hash_key, list_key = redis_store._v2_keys("2026-06-21", None, None)
    # None values are normalized to "none" in the key.
    expected_suffix = "v2:2026-06-21:none:none"
    assert hash_key == "agentkit:usage:" + expected_suffix
    assert list_key == "agentkit:usage_records:" + expected_suffix
|
||||||
|
|
||||||
|
def test_record_degraded_with_user_and_department(self):
    """Once degraded, ``record`` is expected to land in the fallback store with attribution."""
    redis_store = self._make_store()
    redis_store._degrade_to_fallback()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    attribution = {"user_id": "u1", "department_id": "d1"}
    redis_store.record("agent1", "gpt-4", tokens, cost=0.05, latency_ms=200, **attribution)
    # Should be in fallback with user/department attached.
    fallback_records = redis_store._fallback.get_usage().records
    assert len(fallback_records) == 1
    assert fallback_records[0].user_id == "u1"
    assert fallback_records[0].department_id == "d1"
|
||||||
|
|
||||||
|
def test_get_usage_degraded_with_user_filter(self):
    """Once degraded, ``get_usage(user_id=...)`` filters the fallback store's records."""
    redis_store = self._make_store()
    redis_store._degrade_to_fallback()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    # Seed the fallback directly with one record per user/department pair.
    for uid, dept in (("u1", "d1"), ("u2", "d2")):
        redis_store._fallback.record(
            "agent1",
            "gpt-4",
            tokens,
            cost=0.05,
            latency_ms=200,
            user_id=uid,
            department_id=dept,
        )
    filtered = redis_store.get_usage(user_id="u1")
    assert len(filtered.records) == 1
    assert filtered.records[0].user_id == "u1"
|
||||||
|
|
||||||
|
def test_get_usage_by_user_degraded(self):
    """Once degraded, ``get_usage_by_user`` returns records seeded in the fallback store."""
    redis_store = self._make_store()
    redis_store._degrade_to_fallback()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    attribution = {"user_id": "u1", "department_id": "d1"}
    redis_store._fallback.record(
        "agent1", "gpt-4", tokens, cost=0.05, latency_ms=200, **attribution
    )
    per_user = redis_store.get_usage_by_user("u1")
    assert len(per_user.records) == 1
    assert per_user.records[0].user_id == "u1"
|
||||||
|
|
||||||
|
def test_get_usage_by_department_degraded(self):
    """Once degraded, ``get_usage_by_department`` returns records seeded in the fallback store."""
    redis_store = self._make_store()
    redis_store._degrade_to_fallback()
    tokens = TokenUsage(prompt_tokens=100, completion_tokens=50)
    attribution = {"user_id": "u1", "department_id": "d1"}
    redis_store._fallback.record(
        "agent1", "gpt-4", tokens, cost=0.05, latency_ms=200, **attribution
    )
    per_department = redis_store.get_usage_by_department("d1")
    assert len(per_department.records) == 1
    assert per_department.records[0].department_id == "d1"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Factory
|
# Factory
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue