feat(admin): U7 — usage dashboard + quota enforcement

UsageRecord extended with user_id + department_id (backward compatible).
UsageStore Protocol extended: record() accepts user_id/department_id,
get_usage() accepts filters, new get_usage_by_user/department methods.
RedisUsageStore uses versioned keys (v2) for new records.

LLMGateway.chat()/chat_stream() accept user_id, department_ids, db_path.
Quota check before provider call: model whitelist + token limit + cost
limit (daily). Multi-department uses strictest-wins (any exceed → reject).
QuotaExceededError → 429 at route layer.

UsageService: summary, timeseries, by-model, top-users, export (CSV/JSON).
5 new admin endpoints under /admin/usage/*.

llm_gateway.py routes pass DepartmentContext + db_path to gateway,
catch QuotaExceededError → 429 (JSON for /chat, SSE error for /stream).

84 new tests. 441 admin+usage tests pass, no regressions.
This commit is contained in:
chiguyong 2026-06-21 17:23:20 +08:00
parent fd7f6816b8
commit 09feca3307
10 changed files with 2215 additions and 80 deletions

View File

@ -3,6 +3,8 @@
import asyncio
import logging
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
@ -15,6 +17,32 @@ from agentkit.telemetry.metrics import llm_token_histogram
logger = logging.getLogger(__name__)
class QuotaExceededError(Exception):
"""Raised when a department's LLM quota is exceeded.
Carries enough metadata for the API layer to return a structured
429 response (department_id, quota_type, period, limit, current).
"""
def __init__(
self,
department_id: str,
quota_type: str,
period: str,
limit: Any,
current: Any,
) -> None:
self.department_id = department_id
self.quota_type = quota_type
self.period = period
self.limit = limit
self.current = current
super().__init__(
f"Quota exceeded for department {department_id}: "
f"{quota_type} {period} (limit={limit}, current={current})"
)
class LLMGateway:
"""LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache"""
@ -83,6 +111,9 @@ class LLMGateway:
tools: list[dict] | None = None,
tool_choice: str = "auto",
timeout: float | None = None,
user_id: str | None = None,
department_ids: list[str] | None = None,
db_path: Path | str | None = None,
**kwargs,
) -> LLMResponse:
"""发送 chat 请求,自动解析别名和 Fallback"""
@ -91,6 +122,12 @@ class LLMGateway:
if not self._providers:
raise LLMProviderError("", "No provider registered")
# ── Quota enforcement ──
# Only enforce when department_ids + db_path are provided
# (other call sites pass None — no quota check).
if department_ids and db_path:
await self._enforce_quota(db_path, department_ids, resolved_model)
# Telemetry: start LLM span
_span_cm = None
_span = None
@ -131,12 +168,14 @@ class LLMGateway:
result = await self._cache.get(cache_key)
if result.hit:
latency_ms = (time.monotonic() - start) * 1000
self._usage_tracker.record(
self._record_usage(
agent_name=agent_name,
model=result.response.model,
usage=result.response.usage,
cost=0.0,
latency_ms=latency_ms,
user_id=user_id,
department_ids=department_ids,
)
if _span is not None:
_span.set_attribute("gen_ai.cache.hit", True)
@ -158,12 +197,14 @@ class LLMGateway:
result = await self._cache.semantic_search(query_embedding)
if result.hit:
latency_ms = (time.monotonic() - start) * 1000
self._usage_tracker.record(
self._record_usage(
agent_name=agent_name,
model=result.response.model,
usage=result.response.usage,
cost=0.0,
latency_ms=latency_ms,
user_id=user_id,
department_ids=department_ids,
)
if _span is not None:
_span.set_attribute("gen_ai.cache.hit", True)
@ -204,12 +245,14 @@ class LLMGateway:
if response.usage:
latency_ms = (time.monotonic() - start) * 1000
cost = self._calculate_cost(model_name, response.usage)
self._usage_tracker.record(
self._record_usage(
agent_name=agent_name,
model=model_name,
usage=response.usage,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_ids=department_ids,
)
logger.warning(
f"Model '{model_name}' returned empty content with no tool_calls, "
@ -243,12 +286,14 @@ class LLMGateway:
cost = self._calculate_cost(response.model, response.usage)
# 记录使用量
self._usage_tracker.record(
self._record_usage(
agent_name=agent_name,
model=response.model,
usage=response.usage,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_ids=department_ids,
)
# Telemetry: record token usage and end span
@ -278,6 +323,9 @@ class LLMGateway:
tools: list[dict] | None = None,
tool_choice: str = "auto",
timeout: float | None = None,
user_id: str | None = None,
department_ids: list[str] | None = None,
db_path: Path | str | None = None,
**kwargs,
):
"""Stream chat response with fallback support.
@ -293,6 +341,10 @@ class LLMGateway:
if not self._providers:
raise LLMProviderError("", "No provider registered")
# ── Quota enforcement ──
if department_ids and db_path:
await self._enforce_quota(db_path, department_ids, resolved_model)
models_to_try = self._get_models_to_try(resolved_model)
last_error: Exception | None = None
@ -354,12 +406,14 @@ class LLMGateway:
if final_usage is None:
final_usage = TokenUsage()
cost = self._calculate_cost(final_model, final_usage)
self._usage_tracker.record(
self._record_usage(
agent_name=agent_name,
model=final_model,
usage=final_usage,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_ids=department_ids,
)
# Empty stream detection: if no content was produced,
@ -453,3 +507,132 @@ class LLMGateway:
start_time=start_time,
end_time=end_time,
)
# ------------------------------------------------------------------
# Quota enforcement helpers (U7)
# ------------------------------------------------------------------
def _record_usage(
self,
agent_name: str,
model: str,
usage: TokenUsage,
cost: float,
latency_ms: float,
user_id: str | None,
department_ids: list[str] | None,
) -> None:
"""Record a usage event, attaching user_id and (first) department_id.
We attach only the first department_id to the record because
usage attribution is per-department. If a user belongs to
multiple departments, the caller is responsible for choosing
which department to bill the gateway just records what it's
told.
"""
dept_id = department_ids[0] if department_ids else None
self._usage_tracker.record(
agent_name=agent_name,
model=model,
usage=usage,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_id=dept_id,
)
async def _enforce_quota(
self,
db_path: Path | str,
department_ids: list[str],
resolved_model: str,
) -> None:
"""Run all quota checks for the given departments.
Strictest-wins: if ANY department fails ANY check, raises
:class:`QuotaExceededError` and the request is rejected.
"""
# Lazy import to avoid circular dependency (admin → ... → gateway).
from agentkit.server.admin.quota_service import get_quota_service
quota_service = get_quota_service()
db = Path(db_path)
for dept_id in department_ids:
# 1. Model whitelist
allowed, _reason = await quota_service.is_model_allowed(db, dept_id, resolved_model)
if not allowed:
raise QuotaExceededError(
department_id=dept_id,
quota_type="model_whitelist",
period="",
limit="",
current=resolved_model,
)
# 2. Token limit (daily)
current_tokens = await self._get_current_usage_for_quota(dept_id, "daily")
allowed, _reason = await quota_service.check_quota(
db, dept_id, "token_limit", "daily", current_tokens
)
if not allowed:
quota = await quota_service.get_quota(db, dept_id, "token_limit", "daily")
limit = quota["limit_value"] if quota else None
raise QuotaExceededError(
department_id=dept_id,
quota_type="token_limit",
period="daily",
limit=limit,
current=current_tokens,
)
# 3. Cost limit (daily)
current_cost = await self._get_current_cost_for_quota(dept_id, "daily")
allowed, _reason = await quota_service.check_quota(
db, dept_id, "cost_limit", "daily", current_cost
)
if not allowed:
quota = await quota_service.get_quota(db, dept_id, "cost_limit", "daily")
limit = quota["limit_value"] if quota else None
raise QuotaExceededError(
department_id=dept_id,
quota_type="cost_limit",
period="daily",
limit=limit,
current=current_cost,
)
async def _get_current_usage_for_quota(self, department_id: str, period: str) -> int:
"""Return total tokens used by ``department_id`` in the current period.
``period`` is ``"daily"`` or ``"monthly"``. For ``"daily"`` the
window is since 00:00 UTC today; for ``"monthly"`` since the
first of the current month.
"""
now = datetime.now(timezone.utc)
if period == "monthly":
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
else:
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
summary = self._usage_tracker.get_usage(
department_id=department_id, start_time=start, end_time=now
)
return int(summary.total_tokens)
async def _get_current_cost_for_quota(self, department_id: str, period: str) -> float:
"""Return total cost (in cents) for ``department_id`` in the current period.
``period`` is ``"daily"`` or ``"monthly"``. Quota cost_limit is
stored in cents, so we convert the float USD cost from the usage
store to cents (×100) for comparison.
"""
now = datetime.now(timezone.utc)
if period == "monthly":
start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
else:
start = now.replace(hour=0, minute=0, second=0, microsecond=0)
summary = self._usage_tracker.get_usage(
department_id=department_id, start_time=start, end_time=now
)
# cost_limit is stored in cents; convert from USD to cents.
return float(summary.total_cost) * 100.0

View File

@ -23,15 +23,38 @@ class UsageTracker:
usage: TokenUsage,
cost: float,
latency_ms: float,
user_id: str | None = None,
department_id: str | None = None,
) -> None:
"""记录一次使用"""
self._store.record(agent_name, model, usage, cost, latency_ms)
self._store.record(
agent_name,
model,
usage,
cost,
latency_ms,
user_id=user_id,
department_id=department_id,
)
def get_usage(
self,
agent_name: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_id: str | None = None,
department_id: str | None = None,
) -> UsageSummary:
"""查询使用量汇总"""
return self._store.get_usage(agent_name, start_time, end_time)
return self._store.get_usage(
agent_name=agent_name,
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
@property
def store(self) -> UsageStore:
"""Expose the underlying store for service-layer queries."""
return self._store

View File

@ -5,8 +5,12 @@ backends. Replaces the in-memory list in UsageTracker with a pluggable
store that survives restarts and supports multi-instance deployment.
Key schema (Redis):
agentkit:usage:{date} Hash: {agent_name:model JSON(UsageBucket)}
agentkit:usage_records:{date} List: JSON(UsageRecord) with LTRIM
agentkit:usage:v2:{date}:{user_id}:{department_id} Hash: {agent_name:model JSON(UsageBucket)}
agentkit:usage_records:v2:{date}:{user_id}:{department_id} List: JSON(UsageRecord) with LTRIM
Legacy v1 keys (still readable for backward compat):
agentkit:usage:{date} Hash
agentkit:usage_records:{date} List
"""
import json
@ -32,6 +36,8 @@ class UsageRecord:
cost: float
latency_ms: float
timestamp: str = "" # ISO 8601 string for JSON serialization
user_id: str | None = None
department_id: str | None = None
def __post_init__(self):
if not self.timestamp:
@ -57,6 +63,8 @@ class UsageSummary:
total_cost: float = 0.0
by_model: dict[str, dict[str, int | float]] = field(default_factory=dict)
records: list[UsageRecord] = field(default_factory=list)
by_user: dict[str, dict[str, int | float]] = field(default_factory=dict)
by_department: dict[str, dict[str, int | float]] = field(default_factory=dict)
# ---------------------------------------------------------------------------
@ -75,6 +83,8 @@ class UsageStore(Protocol):
usage: TokenUsage,
cost: float,
latency_ms: float,
user_id: str | None = None,
department_id: str | None = None,
) -> None:
"""Record a usage event."""
...
@ -84,6 +94,8 @@ class UsageStore(Protocol):
agent_name: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_id: str | None = None,
department_id: str | None = None,
) -> UsageSummary:
"""Query usage summary."""
...
@ -109,6 +121,8 @@ class InMemoryUsageStore:
usage: TokenUsage,
cost: float,
latency_ms: float,
user_id: str | None = None,
department_id: str | None = None,
) -> None:
rec = UsageRecord(
agent_name=agent_name,
@ -118,16 +132,20 @@ class InMemoryUsageStore:
total_tokens=usage.total_tokens,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_id=department_id,
)
self._records.append(rec)
if len(self._records) > self.MAX_RECORDS:
self._records = self._records[-self.MAX_RECORDS:]
self._records = self._records[-self.MAX_RECORDS :]
def get_usage(
self,
agent_name: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_id: str | None = None,
department_id: str | None = None,
) -> UsageSummary:
filtered = self._records
@ -137,26 +155,65 @@ class InMemoryUsageStore:
filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) >= start_time]
if end_time is not None:
filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) <= end_time]
if user_id is not None:
filtered = [r for r in filtered if r.user_id == user_id]
if department_id is not None:
filtered = [r for r in filtered if r.department_id == department_id]
if not filtered:
return UsageSummary()
total_tokens = sum(r.total_tokens for r in filtered)
total_cost = sum(r.cost for r in filtered)
return self._aggregate(filtered)
def get_usage_by_user(
self,
user_id: str,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> UsageSummary:
"""Aggregate usage for a specific user."""
return self.get_usage(user_id=user_id, start_time=start_time, end_time=end_time)
def get_usage_by_department(
self,
department_id: str,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> UsageSummary:
"""Aggregate usage for a specific department."""
return self.get_usage(department_id=department_id, start_time=start_time, end_time=end_time)
@staticmethod
def _aggregate(records: list[UsageRecord]) -> UsageSummary:
"""Build a :class:`UsageSummary` from a list of records."""
total_tokens = sum(r.total_tokens for r in records)
total_cost = sum(r.cost for r in records)
by_model: dict[str, dict[str, int | float]] = {}
for r in filtered:
if r.model not in by_model:
by_model[r.model] = {"total_tokens": 0, "total_cost": 0.0, "count": 0}
by_model[r.model]["total_tokens"] += r.total_tokens
by_model[r.model]["total_cost"] += r.cost
by_model[r.model]["count"] += 1
by_user: dict[str, dict[str, int | float]] = {}
by_department: dict[str, dict[str, int | float]] = {}
def _bump(bucket_map: dict[str, dict[str, int | float]], key: str, r: UsageRecord) -> None:
if key not in bucket_map:
bucket_map[key] = {"total_tokens": 0, "total_cost": 0.0, "count": 0}
bucket_map[key]["total_tokens"] += r.total_tokens
bucket_map[key]["total_cost"] += r.cost
bucket_map[key]["count"] += 1
for r in records:
_bump(by_model, r.model, r)
if r.user_id is not None:
_bump(by_user, r.user_id, r)
if r.department_id is not None:
_bump(by_department, r.department_id, r)
return UsageSummary(
total_tokens=total_tokens,
total_cost=total_cost,
by_model=by_model,
records=filtered,
by_user=by_user,
by_department=by_department,
records=records,
)
@ -168,13 +225,21 @@ class InMemoryUsageStore:
class RedisUsageStore:
"""Redis-backed usage store using Hash per date for O(1) writes.
Key schema:
agentkit:usage:{YYYY-MM-DD} Hash: {agent:model JSON(UsageBucket)}
agentkit:usage_records:{YYYY-MM-DD} List: JSON(UsageRecord) with LTRIM
Key schema (v2 includes user_id/department_id in key):
agentkit:usage:v2:{YYYY-MM-DD}:{user_id or 'none'}:{department_id or 'none'}
Hash: {agent:model JSON(UsageBucket)}
agentkit:usage_records:v2:{YYYY-MM-DD}:{user_id or 'none'}:{department_id or 'none'}
List: JSON(UsageRecord) with LTRIM
Legacy v1 keys (still readable for backward compat):
agentkit:usage:{YYYY-MM-DD} Hash
agentkit:usage_records:{YYYY-MM-DD} List
"""
USAGE_PREFIX = "agentkit:usage:"
RECORDS_PREFIX = "agentkit:usage_records:"
USAGE_PREFIX_V2 = "agentkit:usage:v2:"
RECORDS_PREFIX_V2 = "agentkit:usage_records:v2:"
MAX_RECORDS_PER_DAY = 50000
TTL_DAYS = 90 # Auto-expire after 90 days
@ -188,6 +253,7 @@ class RedisUsageStore:
async def _get_redis(self):
if self._redis is None:
import redis.asyncio as aioredis
self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
return self._redis
@ -195,9 +261,8 @@ class RedisUsageStore:
"""Get or create a persistent sync Redis client (connection pool backed)."""
if self._sync_redis is None:
import redis as sync_redis
self._sync_redis = sync_redis.from_url(
self._redis_url, decode_responses=True
)
self._sync_redis = sync_redis.from_url(self._redis_url, decode_responses=True)
return self._sync_redis
async def aclose(self) -> None:
@ -218,6 +283,22 @@ class RedisUsageStore:
def _today_key(self) -> str:
return datetime.now(timezone.utc).strftime("%Y-%m-%d")
@staticmethod
def _scope_key(part: str | None) -> str:
"""Normalize a user_id/department_id for use in a Redis key."""
return part if part else "none"
def _v2_keys(
self, date_key: str, user_id: str | None, department_id: str | None
) -> tuple[str, str]:
"""Return (hash_key, list_key) for v2 schema."""
u = self._scope_key(user_id)
d = self._scope_key(department_id)
return (
f"{self.USAGE_PREFIX_V2}{date_key}:{u}:{d}",
f"{self.RECORDS_PREFIX_V2}{date_key}:{u}:{d}",
)
def record(
self,
agent_name: str,
@ -225,6 +306,8 @@ class RedisUsageStore:
usage: TokenUsage,
cost: float,
latency_ms: float,
user_id: str | None = None,
department_id: str | None = None,
) -> None:
"""Record usage — sync wrapper for async Redis.
@ -233,15 +316,22 @@ class RedisUsageStore:
needing an event loop in the caller.
"""
if self._degraded and self._fallback is not None:
self._fallback.record(agent_name, model, usage, cost, latency_ms)
self._fallback.record(
agent_name,
model,
usage,
cost,
latency_ms,
user_id=user_id,
department_id=department_id,
)
return
try:
r = self._get_sync_redis()
date_key = self._today_key()
hash_key = f"{self.USAGE_PREFIX}{date_key}"
list_key = f"{self.RECORDS_PREFIX}{date_key}"
hash_key, list_key = self._v2_keys(date_key, user_id, department_id)
bucket_field = f"{agent_name}:{model}"
# Atomic HINCRBYFLOAT for bucket aggregation
@ -261,17 +351,26 @@ class RedisUsageStore:
total_tokens=usage.total_tokens,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_id=department_id,
)
pipe.rpush(
list_key,
json.dumps(
{
"agent_name": rec.agent_name,
"model": rec.model,
"prompt_tokens": rec.prompt_tokens,
"completion_tokens": rec.completion_tokens,
"total_tokens": rec.total_tokens,
"cost": rec.cost,
"latency_ms": rec.latency_ms,
"timestamp": rec.timestamp,
"user_id": rec.user_id,
"department_id": rec.department_id,
}
),
)
pipe.rpush(list_key, json.dumps({
"agent_name": rec.agent_name,
"model": rec.model,
"prompt_tokens": rec.prompt_tokens,
"completion_tokens": rec.completion_tokens,
"total_tokens": rec.total_tokens,
"cost": rec.cost,
"latency_ms": rec.latency_ms,
"timestamp": rec.timestamp,
}))
pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)
# Set TTL on first write of the day
@ -283,17 +382,38 @@ class RedisUsageStore:
logger.warning(f"Redis usage record failed: {e}")
self._degrade_to_fallback()
if self._fallback is not None:
self._fallback.record(agent_name, model, usage, cost, latency_ms)
self._fallback.record(
agent_name,
model,
usage,
cost,
latency_ms,
user_id=user_id,
department_id=department_id,
)
def get_usage(
self,
agent_name: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_id: str | None = None,
department_id: str | None = None,
) -> UsageSummary:
"""Query usage summary from Redis."""
"""Query usage summary from Redis.
Scans v2 keys (filtered by user_id/department_id when provided)
and legacy v1 keys (no per-user/department scoping). Records
from both schemas are merged.
"""
if self._degraded and self._fallback is not None:
return self._fallback.get_usage(agent_name, start_time, end_time)
return self._fallback.get_usage(
agent_name=agent_name,
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
try:
r = self._get_sync_redis()
@ -303,47 +423,115 @@ class RedisUsageStore:
end = end_time or datetime.now(timezone.utc)
all_records: list[UsageRecord] = []
# Scan date keys in range
# Scan v2 keys.
current = start.date()
end_date = end.date()
while current <= end_date:
list_key = f"{self.RECORDS_PREFIX}{current.isoformat()}"
raw_records = r.lrange(list_key, 0, -1)
for raw in raw_records:
data = json.loads(raw)
rec = UsageRecord(**data)
rec_ts = datetime.fromisoformat(rec.timestamp)
if rec_ts >= start and rec_ts <= end:
if agent_name is None or rec.agent_name == agent_name:
all_records.append(rec)
date_key = current.isoformat()
# When user_id/department_id is provided, scan only the
# matching scope key. Otherwise scan all scopes for that
# date via SCAN.
if user_id is not None or department_id is not None:
list_key = f"{self.RECORDS_PREFIX_V2}{date_key}:{self._scope_key(user_id)}:{self._scope_key(department_id)}"
all_records.extend(self._read_list(r, list_key, start, end, agent_name))
else:
# Scan all v2 list keys for this date.
pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:*"
for key in r.scan_iter(match=pattern, count=200):
all_records.extend(self._read_list(r, key, start, end, agent_name))
current = current + timedelta(days=1)
# Also scan legacy v1 keys (no user/department scoping).
current = start.date()
while current <= end_date:
list_key = f"{self.RECORDS_PREFIX}{current.isoformat()}"
all_records.extend(self._read_list(r, list_key, start, end, agent_name))
current = current + timedelta(days=1)
# Apply user_id/department_id filters to records from legacy
# v1 keys (which don't carry these fields — they'll be None).
if user_id is not None:
all_records = [r for r in all_records if r.user_id == user_id]
if department_id is not None:
all_records = [r for r in all_records if r.department_id == department_id]
if not all_records:
return UsageSummary()
total_tokens = sum(r.total_tokens for r in all_records)
total_cost = sum(r.cost for r in all_records)
by_model: dict[str, dict[str, int | float]] = {}
for r in all_records:
if r.model not in by_model:
by_model[r.model] = {"total_tokens": 0, "total_cost": 0.0, "count": 0}
by_model[r.model]["total_tokens"] += r.total_tokens
by_model[r.model]["total_cost"] += r.cost
by_model[r.model]["count"] += 1
return UsageSummary(
total_tokens=total_tokens,
total_cost=total_cost,
by_model=by_model,
records=all_records,
)
return InMemoryUsageStore._aggregate(all_records)
except Exception as e:
logger.warning(f"Redis usage query failed: {e}")
if self._fallback is not None:
return self._fallback.get_usage(agent_name, start_time, end_time)
return self._fallback.get_usage(
agent_name=agent_name,
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
return UsageSummary()
def get_usage_by_user(
self,
user_id: str,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> UsageSummary:
"""Aggregate usage for a specific user."""
return self.get_usage(user_id=user_id, start_time=start_time, end_time=end_time)
def get_usage_by_department(
self,
department_id: str,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> UsageSummary:
"""Aggregate usage for a specific department."""
return self.get_usage(department_id=department_id, start_time=start_time, end_time=end_time)
@staticmethod
def _read_list(
r: Any,
list_key: str,
start: datetime,
end: datetime,
agent_name: str | None,
) -> list[UsageRecord]:
"""Read all records from a Redis list, filtered by time range and agent."""
out: list[UsageRecord] = []
raw_records = r.lrange(list_key, 0, -1)
for raw in raw_records:
try:
data = json.loads(raw)
except json.JSONDecodeError:
continue
# Build record, tolerating legacy records without user_id/department_id.
rec = UsageRecord(
agent_name=data["agent_name"],
model=data["model"],
prompt_tokens=data["prompt_tokens"],
completion_tokens=data["completion_tokens"],
total_tokens=data["total_tokens"],
cost=data["cost"],
latency_ms=data["latency_ms"],
timestamp=data.get("timestamp", ""),
user_id=data.get("user_id"),
department_id=data.get("department_id"),
)
if not rec.timestamp:
continue
try:
rec_ts = datetime.fromisoformat(rec.timestamp)
except ValueError:
continue
if rec_ts < start or rec_ts > end:
continue
if agent_name is not None and rec.agent_name != agent_name:
continue
out.append(rec)
return out
# ---------------------------------------------------------------------------
# Factory
@ -366,6 +554,7 @@ def create_usage_store(
if backend in ("auto", "redis"):
try:
import redis # noqa: F401
return RedisUsageStore(redis_url=redis_url)
except ImportError:
logger.warning("redis package not available, falling back to in-memory usage store")

View File

@ -0,0 +1,298 @@
"""UsageService — read-side aggregations for the usage dashboard (U7).
This module provides read-only aggregations over a :class:`UsageStore`
for the admin usage dashboard. It is intentionally a thin layer the
store already produces :class:`UsageSummary` aggregations, and this
service just shapes them for the dashboard endpoints (timeseries,
top-N, CSV/JSON export).
The service is a module-level singleton (see :func:`get_usage_service`)
so tests can inject a custom instance via :func:`set_usage_service`.
"""
from __future__ import annotations
import csv
import io
import json
import logging
from datetime import datetime
from typing import Any
from agentkit.llm.providers.usage_store import UsageStore, UsageSummary
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _bucket_start(ts: datetime, interval: str) -> datetime:
"""Return the start of the time bucket containing ``ts``."""
if interval == "hour":
return ts.replace(minute=0, second=0, microsecond=0)
# Default: day
return ts.replace(hour=0, minute=0, second=0, microsecond=0)
# ---------------------------------------------------------------------------
# Service
# ---------------------------------------------------------------------------
class UsageService:
"""Read-side aggregations for the usage dashboard."""
async def get_usage_summary(
self,
usage_store: UsageStore,
department_id: str | None = None,
user_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> dict[str, Any]:
"""Return a flat usage summary dict.
Shape::
{
"total_tokens": int,
"total_cost": float,
"total_requests": int,
"by_model": {model: {total_tokens, total_cost, count}, ...},
"by_user": {user_id: {...}, ...},
"by_department": {department_id: {...}, ...},
}
"""
summary = usage_store.get_usage(
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
return self._summary_to_dict(summary)
async def get_usage_timeseries(
self,
usage_store: UsageStore,
department_id: str | None = None,
user_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
interval: str = "day",
) -> list[dict[str, Any]]:
"""Return a time-bucketed series.
Each item has shape ``{timestamp, tokens, cost, requests}``.
Buckets with no activity are omitted (callers can fill gaps).
"""
summary = usage_store.get_usage(
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
buckets: dict[datetime, dict[str, Any]] = {}
for rec in summary.records:
try:
ts = datetime.fromisoformat(rec.timestamp)
except ValueError:
continue
bucket = _bucket_start(ts, interval)
if bucket not in buckets:
buckets[bucket] = {"tokens": 0, "cost": 0.0, "requests": 0}
buckets[bucket]["tokens"] += rec.total_tokens
buckets[bucket]["cost"] += rec.cost
buckets[bucket]["requests"] += 1
return [
{
"timestamp": bucket.isoformat(),
"tokens": data["tokens"],
"cost": data["cost"],
"requests": data["requests"],
}
for bucket, data in sorted(buckets.items())
]
async def get_usage_by_model(
self,
usage_store: UsageStore,
department_id: str | None = None,
user_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> list[dict[str, Any]]:
"""Return a per-model breakdown."""
summary = usage_store.get_usage(
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
return [
{
"model": model,
"tokens": data["total_tokens"],
"cost": data["total_cost"],
"requests": data["count"],
}
for model, data in sorted(summary.by_model.items())
]
async def get_top_users(
self,
usage_store: UsageStore,
department_id: str | None = None,
user_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int = 10,
) -> list[dict[str, Any]]:
"""Return the top-N users by total token usage."""
summary = usage_store.get_usage(
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
rows = [
{
"user_id": uid,
"tokens": data["total_tokens"],
"cost": data["total_cost"],
"requests": data["count"],
}
for uid, data in summary.by_user.items()
]
rows.sort(key=lambda r: r["tokens"], reverse=True)
return rows[:limit]
async def get_top_departments(
self,
usage_store: UsageStore,
department_id: str | None = None,
user_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int = 10,
) -> list[dict[str, Any]]:
"""Return the top-N departments by total token usage."""
summary = usage_store.get_usage(
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
rows = [
{
"department_id": did,
"tokens": data["total_tokens"],
"cost": data["total_cost"],
"requests": data["count"],
}
for did, data in summary.by_department.items()
]
rows.sort(key=lambda r: r["tokens"], reverse=True)
return rows[:limit]
async def export_usage(
self,
usage_store: UsageStore,
department_id: str | None = None,
user_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
format: str = "csv",
) -> str:
"""Export raw usage records as CSV or JSON.
``format`` is ``"csv"`` (default) or ``"json"``.
"""
summary = usage_store.get_usage(
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
records = [
{
"timestamp": rec.timestamp,
"agent_name": rec.agent_name,
"model": rec.model,
"prompt_tokens": rec.prompt_tokens,
"completion_tokens": rec.completion_tokens,
"total_tokens": rec.total_tokens,
"cost": rec.cost,
"latency_ms": rec.latency_ms,
"user_id": rec.user_id or "",
"department_id": rec.department_id or "",
}
for rec in summary.records
]
if format == "json":
return json.dumps(records, ensure_ascii=False, indent=2)
# Default: CSV
out = io.StringIO()
if records:
writer = csv.DictWriter(out, fieldnames=list(records[0].keys()))
writer.writeheader()
writer.writerows(records)
else:
# Empty CSV with just headers
writer = csv.DictWriter(
out,
fieldnames=[
"timestamp",
"agent_name",
"model",
"prompt_tokens",
"completion_tokens",
"total_tokens",
"cost",
"latency_ms",
"user_id",
"department_id",
],
)
writer.writeheader()
return out.getvalue()
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _summary_to_dict(summary: UsageSummary) -> dict[str, Any]:
"""Convert a :class:`UsageSummary` to a flat dict response."""
return {
"total_tokens": summary.total_tokens,
"total_cost": summary.total_cost,
"total_requests": len(summary.records),
"by_model": dict(summary.by_model),
"by_user": dict(summary.by_user),
"by_department": dict(summary.by_department),
}
# ---------------------------------------------------------------------------
# Module-level singleton (overridable in tests via set_usage_service)
# ---------------------------------------------------------------------------
_usage_service: UsageService | None = None
def get_usage_service() -> UsageService:
"""Return the process-wide :class:`UsageService` (lazy singleton)."""
global _usage_service
if _usage_service is None:
_usage_service = UsageService()
return _usage_service
def set_usage_service(service: UsageService | None) -> None:
"""Inject a custom :class:`UsageService` (used by tests)."""
global _usage_service
_usage_service = service

View File

@ -16,10 +16,12 @@ import time (keeps the module self-contained and test-friendly).
from __future__ import annotations
import logging
from datetime import datetime
from pathlib import Path
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel, ConfigDict
from agentkit.server.admin.department_service import get_department_service
@ -30,6 +32,7 @@ from agentkit.server.admin.llm_config_service import (
)
from agentkit.server.admin.quota_service import get_quota_service
from agentkit.server.admin.skill_service import get_skill_service
from agentkit.server.admin.usage_service import get_usage_service
from agentkit.server.admin.user_service import get_user_service
from agentkit.server.auth.dependencies import require_authenticated
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
@ -1138,3 +1141,192 @@ async def rebuild_kb_source(
return svc.rebuild_index(source_id)
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
# ---------------------------------------------------------------------------
# Usage dashboard endpoints (U7) — usage aggregations + export
# ---------------------------------------------------------------------------
def _get_usage_store(request: Request) -> Any:
"""Return the live :class:`UsageStore` from ``app.state.llm_gateway``.
Raises HTTPException(500) if the gateway or usage store is missing
usage endpoints cannot function without it.
"""
gateway = getattr(request.app.state, "llm_gateway", None)
if gateway is None:
raise HTTPException(
status_code=500,
detail="LLM gateway not initialized on app.state",
)
try:
return gateway._usage_tracker.store # type: ignore[attr-defined]
except AttributeError as exc:
raise HTTPException(
status_code=500,
detail="Usage store not available on LLM gateway",
) from exc
def _parse_iso(value: str | None) -> datetime | None:
"""Parse an ISO 8601 string into a timezone-aware datetime."""
if value is None or value == "":
return None
try:
dt = datetime.fromisoformat(value)
except ValueError:
# Try the trailing-Z form.
if value.endswith("Z"):
try:
from datetime import timezone
dt = datetime.fromisoformat(value[:-1]).replace(tzinfo=timezone.utc)
except ValueError:
raise HTTPException(
status_code=400, detail=f"Invalid ISO 8601 timestamp: {value!r}"
)
else:
raise HTTPException(status_code=400, detail=f"Invalid ISO 8601 timestamp: {value!r}")
return dt
@admin_router.get("/usage/summary")
async def get_usage_summary(
request: Request,
department_id: str | None = None,
user_id: str | None = None,
start: str | None = None,
end: str | None = None,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Return an aggregated usage summary.
Query params: ``department_id``, ``user_id``, ``start``, ``end``
(ISO 8601). Admins see all data; non-admin callers are blocked by
``_require_admin`` (403).
"""
store = _get_usage_store(request)
svc = get_usage_service()
return await svc.get_usage_summary(
store,
department_id=department_id,
user_id=user_id,
start_time=_parse_iso(start),
end_time=_parse_iso(end),
)
@admin_router.get("/usage/timeseries")
async def get_usage_timeseries(
request: Request,
department_id: str | None = None,
user_id: str | None = None,
start: str | None = None,
end: str | None = None,
interval: str = "day",
admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
"""Return a time-bucketed usage series.
Query params: ``department_id``, ``user_id``, ``start``, ``end``
(ISO 8601), ``interval`` (``day`` or ``hour``, default ``day``).
"""
if interval not in ("day", "hour"):
raise HTTPException(status_code=400, detail="interval must be 'day' or 'hour'")
store = _get_usage_store(request)
svc = get_usage_service()
return await svc.get_usage_timeseries(
store,
department_id=department_id,
user_id=user_id,
start_time=_parse_iso(start),
end_time=_parse_iso(end),
interval=interval,
)
@admin_router.get("/usage/by-model")
async def get_usage_by_model(
request: Request,
department_id: str | None = None,
user_id: str | None = None,
start: str | None = None,
end: str | None = None,
admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
"""Return a per-model usage breakdown."""
store = _get_usage_store(request)
svc = get_usage_service()
return await svc.get_usage_by_model(
store,
department_id=department_id,
user_id=user_id,
start_time=_parse_iso(start),
end_time=_parse_iso(end),
)
@admin_router.get("/usage/top-users")
async def get_top_users(
request: Request,
department_id: str | None = None,
user_id: str | None = None,
start: str | None = None,
end: str | None = None,
limit: int = 10,
admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
"""Return the top-N users by total token usage.
Query params: ``department_id``, ``user_id``, ``start``, ``end``,
``limit`` (default 10, max 100).
"""
if limit < 1:
limit = 1
if limit > 100:
limit = 100
store = _get_usage_store(request)
svc = get_usage_service()
return await svc.get_top_users(
store,
department_id=department_id,
user_id=user_id,
start_time=_parse_iso(start),
end_time=_parse_iso(end),
limit=limit,
)
@admin_router.get("/usage/export")
async def export_usage(
request: Request,
department_id: str | None = None,
user_id: str | None = None,
start: str | None = None,
end: str | None = None,
format: str = "csv",
admin: dict[str, Any] = Depends(_require_admin),
) -> Any:
"""Export raw usage records as CSV or JSON.
Query params: ``department_id``, ``user_id``, ``start``, ``end``,
``format`` (``csv`` or ``json``, default ``csv``).
Returns ``text/csv`` for CSV or ``application/json`` for JSON.
"""
if format not in ("csv", "json"):
raise HTTPException(status_code=400, detail="format must be 'csv' or 'json'")
store = _get_usage_store(request)
svc = get_usage_service()
body = await svc.export_usage(
store,
department_id=department_id,
user_id=user_id,
start_time=_parse_iso(start),
end_time=_parse_iso(end),
format=format,
)
if format == "csv":
return PlainTextResponse(content=body, media_type="text/csv")
return PlainTextResponse(content=body, media_type="application/json")

View File

@ -7,12 +7,14 @@ Supports both non-streaming (`POST /api/v1/llm/chat`) and SSE streaming
import json
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, Field
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.gateway import QuotaExceededError
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall
from agentkit.server.admin.context import get_department_context
router = APIRouter(prefix="/llm", tags=["llm-gateway"])
@ -66,14 +68,32 @@ def _serialize_chunk(chunk: StreamChunk) -> dict[str, Any]:
return payload
def _quota_error_payload(exc: QuotaExceededError) -> dict[str, Any]:
"""Build a structured 429 error body from a QuotaExceededError."""
return {
"error": "quota_exceeded",
"department_id": exc.department_id,
"quota_type": exc.quota_type,
"period": exc.period,
"limit": exc.limit,
"current": exc.current,
}
@router.post("/chat")
async def chat(request: Request, body: LLMChatRequest) -> dict[str, Any]:
async def chat(
request: Request,
body: LLMChatRequest,
ctx: Any = Depends(get_department_context),
) -> dict[str, Any]:
"""Non-streaming LLM chat proxy.
Forwards the request to the configured LLMGateway and returns the
serialized LLMResponse.
serialized LLMResponse. Quota-exceeded errors from the gateway are
translated to HTTP 429.
"""
gateway = request.app.state.llm_gateway
db_path = getattr(request.app.state, "auth_db_path", None)
try:
response = await gateway.chat(
messages=body.messages,
@ -83,7 +103,12 @@ async def chat(request: Request, body: LLMChatRequest) -> dict[str, Any]:
timeout=body.timeout,
temperature=body.temperature,
max_tokens=body.max_tokens,
user_id=ctx.user_id,
department_ids=ctx.department_ids if ctx.department_ids else None,
db_path=db_path,
)
except QuotaExceededError as e:
raise HTTPException(status_code=429, detail=_quota_error_payload(e)) from e
except ModelNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
except LLMProviderError as e:
@ -92,7 +117,11 @@ async def chat(request: Request, body: LLMChatRequest) -> dict[str, Any]:
@router.post("/chat/stream")
async def chat_stream(request: Request, body: LLMChatRequest) -> StreamingResponse:
async def chat_stream(
request: Request,
body: LLMChatRequest,
ctx: Any = Depends(get_department_context),
) -> StreamingResponse:
"""SSE streaming LLM chat proxy.
Each StreamChunk is serialized as `data: {json}\\n\\n`. The stream
@ -101,6 +130,7 @@ async def chat_stream(request: Request, body: LLMChatRequest) -> StreamingRespon
async def event_generator():
gateway = request.app.state.llm_gateway
db_path = getattr(request.app.state, "auth_db_path", None)
try:
async for chunk in gateway.chat_stream(
messages=body.messages,
@ -110,9 +140,16 @@ async def chat_stream(request: Request, body: LLMChatRequest) -> StreamingRespon
timeout=body.timeout,
temperature=body.temperature,
max_tokens=body.max_tokens,
user_id=ctx.user_id,
department_ids=ctx.department_ids if ctx.department_ids else None,
db_path=db_path,
):
payload = _serialize_chunk(chunk)
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
except QuotaExceededError as e:
error_payload = _quota_error_payload(e)
error_payload["error"] = "quota_exceeded"
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
except ModelNotFoundError as e:
error_payload = {"error": "model_not_found", "detail": str(e)}
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"

View File

@ -0,0 +1,341 @@
"""Integration tests for the admin usage dashboard routes (U7).
Uses FastAPI TestClient with a test app that mounts only the
``admin_router`` from ``routes.admin``. The ``_require_admin``
dependency is overridden via ``app.dependency_overrides`` so the tests
don't need real JWTs — they can simulate admin and non-admin callers
directly.
The LLM gateway is replaced with a stub that exposes a
``_usage_tracker.store`` attribute pointing at an
:class:`InMemoryUsageStore` pre-populated with test records.
"""
from __future__ import annotations
import csv
import io
import json
from pathlib import Path
from typing import Any
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from agentkit.llm.protocol import TokenUsage
from agentkit.llm.providers.usage_store import InMemoryUsageStore
from agentkit.server.admin.usage_service import set_usage_service
from agentkit.server.auth.models import init_auth_db
from agentkit.server.routes import admin as admin_routes_module
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
class _StubTracker:
"""Minimal stub matching the UsageTracker surface used by routes."""
def __init__(self, store: InMemoryUsageStore) -> None:
self.store = store
class _StubGateway:
"""Minimal stub matching the LLMGateway surface used by routes."""
def __init__(self, store: InMemoryUsageStore) -> None:
self._usage_tracker = _StubTracker(store)
@pytest.fixture
def store() -> InMemoryUsageStore:
return InMemoryUsageStore()
@pytest.fixture
def populated_store() -> InMemoryUsageStore:
"""Pre-populated store with a mix of records across users/depts/models."""
s = InMemoryUsageStore()
s.record(
"agent1",
"gpt-4o",
TokenUsage(prompt_tokens=60, completion_tokens=40),
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
s.record(
"agent1",
"claude",
TokenUsage(prompt_tokens=120, completion_tokens=80),
cost=0.10,
latency_ms=300,
user_id="u1",
department_id="d1",
)
s.record(
"agent2",
"gpt-4o",
TokenUsage(prompt_tokens=30, completion_tokens=20),
cost=0.02,
latency_ms=100,
user_id="u2",
department_id="d2",
)
s.record(
"agent3",
"gpt-4o",
TokenUsage(prompt_tokens=300, completion_tokens=200),
cost=0.50,
latency_ms=400,
user_id="u3",
department_id="d1",
)
return s
@pytest.fixture(autouse=True)
def _reset_singletons():
set_usage_service(None)
yield
set_usage_service(None)
@pytest.fixture
async def tmp_auth_db(tmp_path: Path) -> Path:
db_path = tmp_path / "usage_routes.db"
await init_auth_db(db_path)
return db_path
def _make_admin_app(store: InMemoryUsageStore, tmp_auth_db: Path) -> FastAPI:
"""Build a FastAPI app with admin router + stub gateway."""
app = FastAPI()
app.state.auth_db_path = str(tmp_auth_db)
app.state.llm_gateway = _StubGateway(store)
app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
# Default: allow admin access.
app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user()
return app
def _make_admin_user() -> dict[str, Any]:
return {"user_id": "admin-1", "username": "admin", "role": "admin"}
def _raise_forbidden() -> dict[str, Any]:
raise HTTPException(status_code=403, detail="Admin permission required")
# ---------------------------------------------------------------------------
# /admin/usage/summary
# ---------------------------------------------------------------------------
class TestUsageSummaryRoute:
def test_returns_200_with_data(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/summary")
assert resp.status_code == 200
body = resp.json()
assert body["total_tokens"] == 850
assert abs(body["total_cost"] - 0.67) < 1e-6
assert body["total_requests"] == 4
assert "gpt-4o" in body["by_model"]
assert "u1" in body["by_user"]
assert "d1" in body["by_department"]
def test_with_department_filter(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/summary", params={"department_id": "d2"})
assert resp.status_code == 200
body = resp.json()
assert body["total_tokens"] == 50
assert body["total_requests"] == 1
def test_empty_store_returns_200_with_zeros(self, store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/summary")
assert resp.status_code == 200
body = resp.json()
assert body["total_tokens"] == 0
assert body["total_cost"] == 0.0
assert body["total_requests"] == 0
def test_non_admin_returns_403(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(populated_store, tmp_auth_db)
app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/summary")
assert resp.status_code == 403
# ---------------------------------------------------------------------------
# /admin/usage/timeseries
# ---------------------------------------------------------------------------
class TestUsageTimeseriesRoute:
def test_returns_200_with_data(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/timeseries")
assert resp.status_code == 200
body = resp.json()
assert isinstance(body, list)
assert len(body) >= 1
assert "timestamp" in body[0]
assert "tokens" in body[0]
assert body[0]["tokens"] == 850
def test_invalid_interval_returns_400(
self, populated_store: InMemoryUsageStore, tmp_auth_db: Path
):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/timeseries", params={"interval": "week"})
assert resp.status_code == 400
def test_empty_store_returns_200_empty_list(self, store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/timeseries")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# /admin/usage/by-model
# ---------------------------------------------------------------------------
class TestUsageByModelRoute:
def test_returns_200_with_breakdown(
self, populated_store: InMemoryUsageStore, tmp_auth_db: Path
):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/by-model")
assert resp.status_code == 200
body = resp.json()
assert isinstance(body, list)
models = {row["model"] for row in body}
assert models == {"gpt-4o", "claude"}
def test_empty_store_returns_200_empty_list(self, store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/by-model")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# /admin/usage/top-users
# ---------------------------------------------------------------------------
class TestTopUsersRoute:
def test_returns_200_sorted_by_tokens(
self, populated_store: InMemoryUsageStore, tmp_auth_db: Path
):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/top-users")
assert resp.status_code == 200
body = resp.json()
assert len(body) == 3
# u3 (500), u1 (300), u2 (50)
assert body[0]["user_id"] == "u3"
assert body[0]["tokens"] == 500
assert body[1]["user_id"] == "u1"
assert body[2]["user_id"] == "u2"
def test_limit_param_respected(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/top-users", params={"limit": 2})
assert resp.status_code == 200
body = resp.json()
assert len(body) == 2
def test_empty_store_returns_200_empty_list(self, store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/top-users")
assert resp.status_code == 200
assert resp.json() == []
# ---------------------------------------------------------------------------
# /admin/usage/export
# ---------------------------------------------------------------------------
class TestUsageExportRoute:
def test_csv_export_returns_200(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/export", params={"format": "csv"})
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/csv")
reader = csv.DictReader(io.StringIO(resp.text))
rows = list(reader)
assert len(rows) == 4
assert "timestamp" in rows[0]
assert "user_id" in rows[0]
assert "department_id" in rows[0]
def test_json_export_returns_200(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/export", params={"format": "json"})
assert resp.status_code == 200
# PlainTextResponse returns text/plain or application/json depending on media_type.
body = json.loads(resp.text)
assert isinstance(body, list)
assert len(body) == 4
def test_invalid_format_returns_400(
self, populated_store: InMemoryUsageStore, tmp_auth_db: Path
):
app = _make_admin_app(populated_store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/export", params={"format": "xml"})
assert resp.status_code == 400
def test_empty_store_csv_returns_header_only(
self, store: InMemoryUsageStore, tmp_auth_db: Path
):
app = _make_admin_app(store, tmp_auth_db)
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/export", params={"format": "csv"})
assert resp.status_code == 200
reader = csv.DictReader(io.StringIO(resp.text))
rows = list(reader)
assert rows == []
# Header should still be present.
assert "timestamp" in resp.text
# ---------------------------------------------------------------------------
# Missing gateway
# ---------------------------------------------------------------------------
class TestMissingGateway:
def test_summary_returns_500_without_gateway(self, tmp_auth_db: Path):
app = FastAPI()
app.state.auth_db_path = str(tmp_auth_db)
# No llm_gateway on app.state.
app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user()
client = TestClient(app)
resp = client.get("/api/v1/admin/usage/summary")
assert resp.status_code == 500

View File

@ -0,0 +1,330 @@
"""Unit tests for UsageService (U7 — usage dashboard aggregations)."""
from __future__ import annotations
import csv
import io
import json
import pytest
from agentkit.llm.protocol import TokenUsage
from agentkit.llm.providers.usage_store import InMemoryUsageStore
from agentkit.server.admin.usage_service import (
UsageService,
get_usage_service,
set_usage_service,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def store() -> InMemoryUsageStore:
return InMemoryUsageStore()
@pytest.fixture
def service() -> UsageService:
return UsageService()
@pytest.fixture(autouse=True)
def _reset_singleton():
"""Reset the UsageService singleton before and after each test."""
set_usage_service(None)
yield
set_usage_service(None)
def _populate_store(store: InMemoryUsageStore) -> None:
"""Populate ``store`` with a mix of records for testing."""
# User u1 in dept d1, gpt-4o, 100 tokens, $0.05
store.record(
"agent1",
"gpt-4o",
TokenUsage(prompt_tokens=60, completion_tokens=40),
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
# User u1 in dept d1, claude, 200 tokens, $0.10
store.record(
"agent1",
"claude",
TokenUsage(prompt_tokens=120, completion_tokens=80),
cost=0.10,
latency_ms=300,
user_id="u1",
department_id="d1",
)
# User u2 in dept d2, gpt-4o, 50 tokens, $0.02
store.record(
"agent2",
"gpt-4o",
TokenUsage(prompt_tokens=30, completion_tokens=20),
cost=0.02,
latency_ms=100,
user_id="u2",
department_id="d2",
)
# User u3 in dept d1, gpt-4o, 500 tokens, $0.50 (top user)
store.record(
"agent3",
"gpt-4o",
TokenUsage(prompt_tokens=300, completion_tokens=200),
cost=0.50,
latency_ms=400,
user_id="u3",
department_id="d1",
)
# ---------------------------------------------------------------------------
# get_usage_summary
# ---------------------------------------------------------------------------
class TestGetUsageSummary:
async def test_summary_aggregates_all(self, service: UsageService, store: InMemoryUsageStore):
_populate_store(store)
result = await service.get_usage_summary(store)
assert result["total_tokens"] == 850
assert abs(result["total_cost"] - 0.67) < 1e-6
assert result["total_requests"] == 4
# by_model: gpt-4o (3 records, 650 tokens), claude (1, 200)
assert "gpt-4o" in result["by_model"]
assert "claude" in result["by_model"]
assert result["by_model"]["gpt-4o"]["count"] == 3
assert result["by_model"]["gpt-4o"]["total_tokens"] == 650
# by_user: u1 (2 records, 300 tokens), u2 (1, 50), u3 (1, 500)
assert result["by_user"]["u1"]["total_tokens"] == 300
assert result["by_user"]["u2"]["total_tokens"] == 50
assert result["by_user"]["u3"]["total_tokens"] == 500
# by_department: d1 (3 records, 800 tokens), d2 (1, 50)
assert result["by_department"]["d1"]["total_tokens"] == 800
assert result["by_department"]["d2"]["total_tokens"] == 50
async def test_summary_with_department_filter(
self, service: UsageService, store: InMemoryUsageStore
):
_populate_store(store)
result = await service.get_usage_summary(store, department_id="d1")
assert result["total_tokens"] == 800
assert result["total_requests"] == 3
# Only u1 and u3 are in d1.
assert "u1" in result["by_user"]
assert "u3" in result["by_user"]
assert "u2" not in result["by_user"]
async def test_summary_with_user_filter(self, service: UsageService, store: InMemoryUsageStore):
_populate_store(store)
result = await service.get_usage_summary(store, user_id="u2")
assert result["total_tokens"] == 50
assert result["total_requests"] == 1
assert "u2" in result["by_user"]
async def test_summary_with_empty_store(self, service: UsageService, store: InMemoryUsageStore):
result = await service.get_usage_summary(store)
assert result["total_tokens"] == 0
assert result["total_cost"] == 0.0
assert result["total_requests"] == 0
assert result["by_model"] == {}
assert result["by_user"] == {}
assert result["by_department"] == {}
# ---------------------------------------------------------------------------
# get_usage_timeseries
# ---------------------------------------------------------------------------
class TestGetUsageTimeseries:
async def test_timeseries_day_buckets(self, service: UsageService, store: InMemoryUsageStore):
_populate_store(store)
result = await service.get_usage_timeseries(store, interval="day")
# All records are within the same day (today), so we expect one bucket.
assert len(result) >= 1
bucket = result[0]
assert "timestamp" in bucket
assert bucket["tokens"] == 850
assert abs(bucket["cost"] - 0.67) < 1e-6
assert bucket["requests"] == 4
async def test_timeseries_hour_buckets(self, service: UsageService, store: InMemoryUsageStore):
_populate_store(store)
result = await service.get_usage_timeseries(store, interval="hour")
# All records are within the same hour (now), so we expect one bucket.
assert len(result) >= 1
assert result[0]["tokens"] == 850
async def test_timeseries_with_department_filter(
self, service: UsageService, store: InMemoryUsageStore
):
_populate_store(store)
result = await service.get_usage_timeseries(store, department_id="d2", interval="day")
assert len(result) >= 1
assert result[0]["tokens"] == 50
async def test_timeseries_empty_store(self, service: UsageService, store: InMemoryUsageStore):
result = await service.get_usage_timeseries(store, interval="day")
assert result == []
# ---------------------------------------------------------------------------
# get_usage_by_model
# ---------------------------------------------------------------------------
class TestGetUsageByModel:
async def test_by_model_breakdown(self, service: UsageService, store: InMemoryUsageStore):
_populate_store(store)
result = await service.get_usage_by_model(store)
# Sorted by model name: claude, gpt-4o
assert len(result) == 2
models = {row["model"] for row in result}
assert models == {"gpt-4o", "claude"}
gpt_row = next(r for r in result if r["model"] == "gpt-4o")
assert gpt_row["tokens"] == 650
assert gpt_row["requests"] == 3
claude_row = next(r for r in result if r["model"] == "claude")
assert claude_row["tokens"] == 200
assert claude_row["requests"] == 1
async def test_by_model_with_department_filter(
self, service: UsageService, store: InMemoryUsageStore
):
_populate_store(store)
result = await service.get_usage_by_model(store, department_id="d2")
# d2 only has gpt-4o (50 tokens, 1 request)
assert len(result) == 1
assert result[0]["model"] == "gpt-4o"
assert result[0]["tokens"] == 50
async def test_by_model_empty_store(self, service: UsageService, store: InMemoryUsageStore):
result = await service.get_usage_by_model(store)
assert result == []
# ---------------------------------------------------------------------------
# get_top_users
# ---------------------------------------------------------------------------
class TestGetTopUsers:
async def test_top_users_sorted_by_tokens(
self, service: UsageService, store: InMemoryUsageStore
):
_populate_store(store)
result = await service.get_top_users(store, limit=10)
# u3 (500), u1 (300), u2 (50)
assert len(result) == 3
assert result[0]["user_id"] == "u3"
assert result[0]["tokens"] == 500
assert result[1]["user_id"] == "u1"
assert result[1]["tokens"] == 300
assert result[2]["user_id"] == "u2"
assert result[2]["tokens"] == 50
async def test_top_users_respects_limit(self, service: UsageService, store: InMemoryUsageStore):
_populate_store(store)
result = await service.get_top_users(store, limit=2)
assert len(result) == 2
assert result[0]["user_id"] == "u3"
assert result[1]["user_id"] == "u1"
async def test_top_users_with_department_filter(
self, service: UsageService, store: InMemoryUsageStore
):
_populate_store(store)
result = await service.get_top_users(store, department_id="d1", limit=10)
# d1 has u1 and u3
assert len(result) == 2
assert result[0]["user_id"] == "u3"
assert result[1]["user_id"] == "u1"
async def test_top_users_empty_store(self, service: UsageService, store: InMemoryUsageStore):
result = await service.get_top_users(store, limit=10)
assert result == []
# ---------------------------------------------------------------------------
# export_usage
# ---------------------------------------------------------------------------
class TestExportUsage:
async def test_export_csv_has_header_and_rows(
self, service: UsageService, store: InMemoryUsageStore
):
_populate_store(store)
body = await service.export_usage(store, format="csv")
reader = csv.DictReader(io.StringIO(body))
rows = list(reader)
assert len(rows) == 4
# Verify headers
assert "timestamp" in rows[0]
assert "agent_name" in rows[0]
assert "model" in rows[0]
assert "user_id" in rows[0]
assert "department_id" in rows[0]
# Verify a known record
gpt_rows = [r for r in rows if r["model"] == "gpt-4o"]
assert len(gpt_rows) == 3
async def test_export_json_returns_valid_json(
self, service: UsageService, store: InMemoryUsageStore
):
_populate_store(store)
body = await service.export_usage(store, format="json")
data = json.loads(body)
assert isinstance(data, list)
assert len(data) == 4
assert "timestamp" in data[0]
assert "user_id" in data[0]
assert "department_id" in data[0]
async def test_export_csv_empty_store_returns_header_only(
self, service: UsageService, store: InMemoryUsageStore
):
body = await service.export_usage(store, format="csv")
reader = csv.DictReader(io.StringIO(body))
rows = list(reader)
assert rows == []
# Header should still be present.
assert "timestamp" in body
async def test_export_with_department_filter(
self, service: UsageService, store: InMemoryUsageStore
):
_populate_store(store)
body = await service.export_usage(store, department_id="d2", format="csv")
reader = csv.DictReader(io.StringIO(body))
rows = list(reader)
assert len(rows) == 1
assert rows[0]["department_id"] == "d2"
# ---------------------------------------------------------------------------
# Singleton helpers
# ---------------------------------------------------------------------------
class TestSingletonHelpers:
def test_get_usage_service_returns_singleton(self):
first = get_usage_service()
second = get_usage_service()
assert first is second
def test_set_usage_service_overrides(self):
custom = UsageService()
set_usage_service(custom)
assert get_usage_service() is custom
# Clearing falls back to a new lazy instance.
set_usage_service(None)
new_one = get_usage_service()
assert new_one is not custom

View File

@ -0,0 +1,321 @@
"""Unit tests for LLMGateway quota enforcement (U7).
Covers:
- QuotaExceededError raised when token_limit exceeded
- QuotaExceededError raised when cost_limit exceeded
- QuotaExceededError raised when model not in whitelist
- No quota set request allowed
- Multi-department: strictest-wins (one exceeds, other doesn't → rejected)
- QuotaExceededError carries the right metadata
- Usage recording still attaches user_id + department_id on success
"""
from __future__ import annotations
import uuid
from pathlib import Path
import pytest
from agentkit.llm.gateway import LLMGateway, QuotaExceededError
from agentkit.llm.protocol import (
LLMProvider,
LLMRequest,
LLMResponse,
TokenUsage,
)
from agentkit.llm.providers.usage_store import InMemoryUsageStore
from agentkit.server.admin.quota_service import (
get_quota_service,
set_quota_service,
)
from agentkit.server.auth.models import init_auth_db
# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------
class FakeProvider(LLMProvider):
"""A minimal LLMProvider that returns a fixed response."""
def __init__(self, name: str = "fake"):
self._name = name
self.last_request: LLMRequest | None = None
self.call_count = 0
async def chat(self, request: LLMRequest) -> LLMResponse:
self.last_request = request
self.call_count += 1
return LLMResponse(
content=f"response from {self._name}",
model=request.model,
usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def store() -> InMemoryUsageStore:
return InMemoryUsageStore()
@pytest.fixture
def gateway(store: InMemoryUsageStore) -> LLMGateway:
gw = LLMGateway(usage_store=store)
gw.register_provider("openai", FakeProvider("openai"))
return gw
@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
db_path = tmp_path / "auth.db"
await init_auth_db(db_path)
return db_path
@pytest.fixture(autouse=True)
def _reset_quota_singleton():
"""Reset the QuotaService singleton before and after each test."""
set_quota_service(None)
yield
set_quota_service(None)
def _random_dept_id() -> str:
return str(uuid.uuid4())
# ---------------------------------------------------------------------------
# Quota enforcement
# ---------------------------------------------------------------------------
class TestQuotaEnforcement:
async def test_no_quota_set_allows_request(self, gateway: LLMGateway, fresh_db: Path):
"""When no quota is configured, the request is allowed."""
dept_id = _random_dept_id()
response = await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model="openai/gpt-4o",
user_id="u1",
department_ids=[dept_id],
db_path=fresh_db,
)
assert response.content == "response from openai"
async def test_token_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path):
"""token_limit quota exceeded → QuotaExceededError."""
dept_id = _random_dept_id()
svc = get_quota_service()
# Set a tiny token limit (1 token) — any usage will exceed it.
await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily")
# Pre-populate the usage store so the daily total > 1.
gateway._usage_tracker.record(
agent_name="prev",
model="openai/gpt-4o",
usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
cost=0.0,
latency_ms=10,
user_id="u1",
department_id=dept_id,
)
with pytest.raises(QuotaExceededError) as exc_info:
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model="openai/gpt-4o",
user_id="u1",
department_ids=[dept_id],
db_path=fresh_db,
)
err = exc_info.value
assert err.department_id == dept_id
assert err.quota_type == "token_limit"
assert err.period == "daily"
assert err.limit == 1
assert err.current == 150 # 100 prompt + 50 completion
async def test_cost_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path):
"""cost_limit quota exceeded → QuotaExceededError."""
dept_id = _random_dept_id()
svc = get_quota_service()
# cost_limit is in cents. Set 1 cent.
await svc.set_quota(fresh_db, dept_id, "cost_limit", 1, period="daily")
# Pre-populate usage with $1.00 cost = 100 cents, exceeding the 1-cent limit.
gateway._usage_tracker.record(
agent_name="prev",
model="openai/gpt-4o",
usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
cost=1.00, # $1.00 = 100 cents
latency_ms=10,
user_id="u1",
department_id=dept_id,
)
with pytest.raises(QuotaExceededError) as exc_info:
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model="openai/gpt-4o",
user_id="u1",
department_ids=[dept_id],
db_path=fresh_db,
)
err = exc_info.value
assert err.quota_type == "cost_limit"
assert err.period == "daily"
assert err.limit == 1
# current is in cents (100 cents = $1.00).
assert err.current == 100.0
async def test_model_whitelist_rejection_raises(self, gateway: LLMGateway, fresh_db: Path):
"""Model not in whitelist → QuotaExceededError with quota_type=model_whitelist."""
dept_id = _random_dept_id()
svc = get_quota_service()
# Whitelist only allows "claude" — gateway is calling "gpt-4o".
await svc.set_quota(fresh_db, dept_id, "model_whitelist", ["claude"], period="daily")
with pytest.raises(QuotaExceededError) as exc_info:
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model="openai/gpt-4o",
user_id="u1",
department_ids=[dept_id],
db_path=fresh_db,
)
err = exc_info.value
assert err.quota_type == "model_whitelist"
assert err.department_id == dept_id
# For model_whitelist, current is the rejected model name.
assert err.current == "openai/gpt-4o"
async def test_model_whitelist_allows_listed_model(self, gateway: LLMGateway, fresh_db: Path):
"""Model in whitelist → request allowed."""
dept_id = _random_dept_id()
svc = get_quota_service()
# Whitelist uses the full resolved model identifier (provider/model).
await svc.set_quota(fresh_db, dept_id, "model_whitelist", ["openai/gpt-4o"], period="daily")
response = await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model="openai/gpt-4o",
user_id="u1",
department_ids=[dept_id],
db_path=fresh_db,
)
assert response.content == "response from openai"
async def test_multi_department_strictest_wins(self, gateway: LLMGateway, fresh_db: Path):
"""One department exceeds, the other doesn't → rejected (strictest wins)."""
dept_ok = _random_dept_id()
dept_bad = _random_dept_id()
svc = get_quota_service()
# dept_bad has a 1-token limit; dept_ok has a 1M-token limit.
await svc.set_quota(fresh_db, dept_bad, "token_limit", 1, period="daily")
await svc.set_quota(fresh_db, dept_ok, "token_limit", 1_000_000, period="daily")
# Pre-populate usage for dept_bad so it exceeds.
gateway._usage_tracker.record(
agent_name="prev",
model="openai/gpt-4o",
usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
cost=0.0,
latency_ms=10,
user_id="u1",
department_id=dept_bad,
)
with pytest.raises(QuotaExceededError) as exc_info:
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model="openai/gpt-4o",
user_id="u1",
department_ids=[dept_ok, dept_bad],
db_path=fresh_db,
)
# The error should reference dept_bad (the one that exceeded).
assert exc_info.value.department_id == dept_bad
async def test_quota_check_skipped_without_db_path(self, gateway: LLMGateway, fresh_db: Path):
"""When db_path is None, no quota check is performed."""
dept_id = _random_dept_id()
svc = get_quota_service()
await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily")
# Even with a quota set, calling without db_path should succeed.
response = await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model="openai/gpt-4o",
user_id="u1",
department_ids=[dept_id],
db_path=None,
)
assert response.content == "response from openai"
async def test_quota_check_skipped_without_department_ids(
self, gateway: LLMGateway, fresh_db: Path
):
"""When department_ids is None, no quota check is performed."""
response = await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model="openai/gpt-4o",
user_id="u1",
department_ids=None,
db_path=fresh_db,
)
assert response.content == "response from openai"
async def test_usage_recorded_with_user_and_department(
self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
):
"""After a successful call, the usage record carries user_id + department_id."""
dept_id = _random_dept_id()
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model="openai/gpt-4o",
user_id="u1",
department_ids=[dept_id],
db_path=fresh_db,
)
summary = store.get_usage()
assert len(summary.records) == 1
rec = summary.records[0]
assert rec.user_id == "u1"
assert rec.department_id == dept_id
assert rec.model == "gpt-4o"
assert rec.total_tokens == 150 # 100 prompt + 50 completion
# ---------------------------------------------------------------------------
# QuotaExceededError dataclass-like behavior
# ---------------------------------------------------------------------------
class TestQuotaExceededError:
def test_error_message_includes_metadata(self):
err = QuotaExceededError(
department_id="d1",
quota_type="token_limit",
period="daily",
limit=1000,
current=1500,
)
msg = str(err)
assert "d1" in msg
assert "token_limit" in msg
assert "daily" in msg
assert "1000" in msg
assert "1500" in msg
def test_error_attributes_preserved(self):
err = QuotaExceededError("d1", "cost_limit", "monthly", 5000, 6000)
assert err.department_id == "d1"
assert err.quota_type == "cost_limit"
assert err.period == "monthly"
assert err.limit == 5000
assert err.current == 6000

View File

@ -1,6 +1,5 @@
"""Unit tests for UsageStore (U4 — UsageStore Persistence)."""
import pytest
from datetime import datetime, timedelta, timezone
from agentkit.llm.protocol import TokenUsage
@ -114,6 +113,128 @@ class TestInMemoryUsageStore:
# Should be parseable as ISO 8601
datetime.fromisoformat(rec.timestamp)
def test_record_with_user_and_department(self):
store = InMemoryUsageStore()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
rec = store.get_usage().records[0]
assert rec.user_id == "u1"
assert rec.department_id == "d1"
def test_record_defaults_user_department_to_none(self):
store = InMemoryUsageStore()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200)
rec = store.get_usage().records[0]
assert rec.user_id is None
assert rec.department_id is None
def test_get_usage_filters_by_user(self):
store = InMemoryUsageStore()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200, user_id="u1")
store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200, user_id="u2")
summary = store.get_usage(user_id="u1")
assert len(summary.records) == 1
assert summary.records[0].user_id == "u1"
def test_get_usage_filters_by_department(self):
store = InMemoryUsageStore()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200, department_id="d1")
store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200, department_id="d2")
summary = store.get_usage(department_id="d1")
assert len(summary.records) == 1
assert summary.records[0].department_id == "d1"
def test_get_usage_by_user(self):
store = InMemoryUsageStore()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
store.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u2",
department_id="d2",
)
summary = store.get_usage_by_user("u1")
assert len(summary.records) == 1
assert summary.records[0].user_id == "u1"
assert summary.total_tokens == 150
def test_get_usage_by_department(self):
store = InMemoryUsageStore()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
store.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u2",
department_id="d2",
)
summary = store.get_usage_by_department("d1")
assert len(summary.records) == 1
assert summary.records[0].department_id == "d1"
assert summary.total_tokens == 150
def test_summary_includes_by_user_and_by_department(self):
store = InMemoryUsageStore()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
store.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
summary = store.get_usage()
assert "u1" in summary.by_user
assert summary.by_user["u1"]["count"] == 2
assert summary.by_user["u1"]["total_tokens"] == 300
assert "d1" in summary.by_department
assert summary.by_department["d1"]["count"] == 2
# ---------------------------------------------------------------------------
# UsageRecord / UsageBucket / UsageSummary dataclasses
@ -123,17 +244,25 @@ class TestInMemoryUsageStore:
class TestDataclasses:
def test_usage_record_auto_timestamp(self):
rec = UsageRecord(
agent_name="a", model="m",
prompt_tokens=1, completion_tokens=1,
total_tokens=2, cost=0.01, latency_ms=100,
agent_name="a",
model="m",
prompt_tokens=1,
completion_tokens=1,
total_tokens=2,
cost=0.01,
latency_ms=100,
)
assert rec.timestamp != ""
def test_usage_record_explicit_timestamp(self):
rec = UsageRecord(
agent_name="a", model="m",
prompt_tokens=1, completion_tokens=1,
total_tokens=2, cost=0.01, latency_ms=100,
agent_name="a",
model="m",
prompt_tokens=1,
completion_tokens=1,
total_tokens=2,
cost=0.01,
latency_ms=100,
timestamp="2026-01-01T00:00:00+00:00",
)
assert rec.timestamp == "2026-01-01T00:00:00+00:00"
@ -208,6 +337,98 @@ class TestRedisUsageStoreMocked:
assert len(key) == 10
assert key[4] == "-"
def test_v2_keys_with_user_and_department(self):
store = self._make_store()
hash_key, list_key = store._v2_keys("2026-06-21", "u1", "d1")
assert hash_key == "agentkit:usage:v2:2026-06-21:u1:d1"
assert list_key == "agentkit:usage_records:v2:2026-06-21:u1:d1"
def test_v2_keys_with_none_user_and_department(self):
store = self._make_store()
hash_key, list_key = store._v2_keys("2026-06-21", None, None)
# None values are normalized to "none" in the key.
assert hash_key == "agentkit:usage:v2:2026-06-21:none:none"
assert list_key == "agentkit:usage_records:v2:2026-06-21:none:none"
def test_record_degraded_with_user_and_department(self):
store = self._make_store()
store._degrade_to_fallback()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
# Should be in fallback with user/department attached.
summary = store._fallback.get_usage()
assert len(summary.records) == 1
assert summary.records[0].user_id == "u1"
assert summary.records[0].department_id == "d1"
def test_get_usage_degraded_with_user_filter(self):
store = self._make_store()
store._degrade_to_fallback()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store._fallback.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
store._fallback.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u2",
department_id="d2",
)
summary = store.get_usage(user_id="u1")
assert len(summary.records) == 1
assert summary.records[0].user_id == "u1"
def test_get_usage_by_user_degraded(self):
store = self._make_store()
store._degrade_to_fallback()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store._fallback.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
summary = store.get_usage_by_user("u1")
assert len(summary.records) == 1
assert summary.records[0].user_id == "u1"
def test_get_usage_by_department_degraded(self):
store = self._make_store()
store._degrade_to_fallback()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store._fallback.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
summary = store.get_usage_by_department("d1")
assert len(summary.records) == 1
assert summary.records[0].department_id == "d1"
# ---------------------------------------------------------------------------
# Factory