From 09feca330740893be0b8fac41a49846c25ed4f2a Mon Sep 17 00:00:00 2001 From: chiguyong Date: Sun, 21 Jun 2026 17:23:20 +0800 Subject: [PATCH] =?UTF-8?q?feat(admin):=20U7=20=E2=80=94=20usage=20dashboa?= =?UTF-8?q?rd=20+=20quota=20enforcement?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit UsageRecord extended with user_id + department_id (backward compatible). UsageStore Protocol extended: record() accepts user_id/department_id, get_usage() accepts filters, new get_usage_by_user/department methods. RedisUsageStore uses versioned keys (v2) for new records. LLMGateway.chat()/chat_stream() accept user_id, department_ids, db_path. Quota check before provider call: model whitelist + token limit + cost limit (daily). Multi-department uses strictest-wins (any exceed → reject). QuotaExceededError → 429 at route layer. UsageService: summary, timeseries, by-model, top-users, export (CSV/JSON). 5 new admin endpoints under /admin/usage/*. llm_gateway.py routes pass DepartmentContext + db_path to gateway, catch QuotaExceededError → 429 (JSON for /chat, SSE error for /stream). 84 new tests. 441 admin+usage tests pass, no regressions. --- src/agentkit/llm/gateway.py | 193 ++++++++++- src/agentkit/llm/providers/tracker.py | 27 +- src/agentkit/llm/providers/usage_store.py | 313 +++++++++++++---- src/agentkit/server/admin/usage_service.py | 298 ++++++++++++++++ src/agentkit/server/routes/admin.py | 192 +++++++++++ src/agentkit/server/routes/llm_gateway.py | 45 ++- tests/integration/admin/test_usage_routes.py | 341 +++++++++++++++++++ tests/unit/admin/test_usage_service.py | 330 ++++++++++++++++++ tests/unit/llm/test_quota_enforcement.py | 321 +++++++++++++++++ tests/unit/llm/test_usage_store.py | 235 ++++++++++++- 10 files changed, 2215 insertions(+), 80 deletions(-) create mode 100644 src/agentkit/server/admin/usage_service.py create mode 100644 tests/integration/admin/test_usage_routes.py create mode 100644 tests/unit/admin/test_usage_service.py create mode 100644 tests/unit/llm/test_quota_enforcement.py diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index 6bed7ab..8d93d00 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -3,6 +3,8 @@ import asyncio import logging import time +from datetime import datetime, timezone +from pathlib import Path from typing import Any from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError @@ -15,6 +17,32 @@ from agentkit.telemetry.metrics import llm_token_histogram logger = logging.getLogger(__name__) +class QuotaExceededError(Exception): + """Raised when a department's LLM quota is exceeded. + + Carries enough metadata for the API layer to return a structured + 429 response (department_id, quota_type, period, limit, current). + """ + + def __init__( + self, + department_id: str, + quota_type: str, + period: str, + limit: Any, + current: Any, + ) -> None: + self.department_id = department_id + self.quota_type = quota_type + self.period = period + self.limit = limit + self.current = current + super().__init__( + f"Quota exceeded for department {department_id}: " + f"{quota_type} {period} (limit={limit}, current={current})" + ) + + class LLMGateway: """LLM 网关 - Provider 注册、模型别名解析、Fallback、Usage 追踪、Cache""" @@ -83,6 +111,9 @@ class LLMGateway: tools: list[dict] | None = None, tool_choice: str = "auto", timeout: float | None = None, + user_id: str | None = None, + department_ids: list[str] | None = None, + db_path: Path | str | None = None, **kwargs, ) -> LLMResponse: """发送 chat 请求,自动解析别名和 Fallback""" @@ -91,6 +122,12 @@ class LLMGateway: if not self._providers: raise LLMProviderError("", "No provider registered") + # ── Quota enforcement ── + # Only enforce when department_ids + db_path are provided + # (other call sites pass None — no quota check). + if department_ids and db_path: + await self._enforce_quota(db_path, department_ids, resolved_model) + # Telemetry: start LLM span _span_cm = None _span = None @@ -131,12 +168,14 @@ class LLMGateway: result = await self._cache.get(cache_key) if result.hit: latency_ms = (time.monotonic() - start) * 1000 - self._usage_tracker.record( + self._record_usage( agent_name=agent_name, model=result.response.model, usage=result.response.usage, cost=0.0, latency_ms=latency_ms, + user_id=user_id, + department_ids=department_ids, ) if _span is not None: _span.set_attribute("gen_ai.cache.hit", True) @@ -158,12 +197,14 @@ class LLMGateway: result = await self._cache.semantic_search(query_embedding) if result.hit: latency_ms = (time.monotonic() - start) * 1000 - self._usage_tracker.record( + self._record_usage( agent_name=agent_name, model=result.response.model, usage=result.response.usage, cost=0.0, latency_ms=latency_ms, + user_id=user_id, + department_ids=department_ids, ) if _span is not None: _span.set_attribute("gen_ai.cache.hit", True) @@ -204,12 +245,14 @@ class LLMGateway: if response.usage: latency_ms = (time.monotonic() - start) * 1000 cost = self._calculate_cost(model_name, response.usage) - self._usage_tracker.record( + self._record_usage( agent_name=agent_name, model=model_name, usage=response.usage, cost=cost, latency_ms=latency_ms, + user_id=user_id, + department_ids=department_ids, ) logger.warning( f"Model '{model_name}' returned empty content with no tool_calls, " @@ -243,12 +286,14 @@ class LLMGateway: cost = self._calculate_cost(response.model, response.usage) # 记录使用量 - self._usage_tracker.record( + self._record_usage( agent_name=agent_name, model=response.model, usage=response.usage, cost=cost, latency_ms=latency_ms, + user_id=user_id, + department_ids=department_ids, ) # Telemetry: record token usage and end span @@ -278,6 +323,9 @@ class LLMGateway: tools: list[dict] | None = None, tool_choice: str = "auto", timeout: float | None = None, + user_id: str | None = None, + department_ids: list[str] | None = None, + db_path: Path | str | None = None, **kwargs, ): """Stream chat response with fallback support. @@ -293,6 +341,10 @@ class LLMGateway: if not self._providers: raise LLMProviderError("", "No provider registered") + # ── Quota enforcement ── + if department_ids and db_path: + await self._enforce_quota(db_path, department_ids, resolved_model) + models_to_try = self._get_models_to_try(resolved_model) last_error: Exception | None = None @@ -354,12 +406,14 @@ class LLMGateway: if final_usage is None: final_usage = TokenUsage() cost = self._calculate_cost(final_model, final_usage) - self._usage_tracker.record( + self._record_usage( agent_name=agent_name, model=final_model, usage=final_usage, cost=cost, latency_ms=latency_ms, + user_id=user_id, + department_ids=department_ids, ) # Empty stream detection: if no content was produced, @@ -453,3 +507,132 @@ class LLMGateway: start_time=start_time, end_time=end_time, ) + + # ------------------------------------------------------------------ + # Quota enforcement helpers (U7) + # ------------------------------------------------------------------ + + def _record_usage( + self, + agent_name: str, + model: str, + usage: TokenUsage, + cost: float, + latency_ms: float, + user_id: str | None, + department_ids: list[str] | None, + ) -> None: + """Record a usage event, attaching user_id and (first) department_id. + + We attach only the first department_id to the record because + usage attribution is per-department. If a user belongs to + multiple departments, the caller is responsible for choosing + which department to bill — the gateway just records what it's + told. + """ + dept_id = department_ids[0] if department_ids else None + self._usage_tracker.record( + agent_name=agent_name, + model=model, + usage=usage, + cost=cost, + latency_ms=latency_ms, + user_id=user_id, + department_id=dept_id, + ) + + async def _enforce_quota( + self, + db_path: Path | str, + department_ids: list[str], + resolved_model: str, + ) -> None: + """Run all quota checks for the given departments. + + Strictest-wins: if ANY department fails ANY check, raises + :class:`QuotaExceededError` and the request is rejected. + """ + # Lazy import to avoid circular dependency (admin → ... → gateway). + from agentkit.server.admin.quota_service import get_quota_service + + quota_service = get_quota_service() + db = Path(db_path) + + for dept_id in department_ids: + # 1. Model whitelist + allowed, _reason = await quota_service.is_model_allowed(db, dept_id, resolved_model) + if not allowed: + raise QuotaExceededError( + department_id=dept_id, + quota_type="model_whitelist", + period="", + limit="", + current=resolved_model, + ) + + # 2. Token limit (daily) + current_tokens = await self._get_current_usage_for_quota(dept_id, "daily") + allowed, _reason = await quota_service.check_quota( + db, dept_id, "token_limit", "daily", current_tokens + ) + if not allowed: + quota = await quota_service.get_quota(db, dept_id, "token_limit", "daily") + limit = quota["limit_value"] if quota else None + raise QuotaExceededError( + department_id=dept_id, + quota_type="token_limit", + period="daily", + limit=limit, + current=current_tokens, + ) + + # 3. Cost limit (daily) + current_cost = await self._get_current_cost_for_quota(dept_id, "daily") + allowed, _reason = await quota_service.check_quota( + db, dept_id, "cost_limit", "daily", current_cost + ) + if not allowed: + quota = await quota_service.get_quota(db, dept_id, "cost_limit", "daily") + limit = quota["limit_value"] if quota else None + raise QuotaExceededError( + department_id=dept_id, + quota_type="cost_limit", + period="daily", + limit=limit, + current=current_cost, + ) + + async def _get_current_usage_for_quota(self, department_id: str, period: str) -> int: + """Return total tokens used by ``department_id`` in the current period. + + ``period`` is ``"daily"`` or ``"monthly"``. For ``"daily"`` the + window is since 00:00 UTC today; for ``"monthly"`` since the + first of the current month. + """ + now = datetime.now(timezone.utc) + if period == "monthly": + start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + else: + start = now.replace(hour=0, minute=0, second=0, microsecond=0) + summary = self._usage_tracker.get_usage( + department_id=department_id, start_time=start, end_time=now + ) + return int(summary.total_tokens) + + async def _get_current_cost_for_quota(self, department_id: str, period: str) -> float: + """Return total cost (in cents) for ``department_id`` in the current period. + + ``period`` is ``"daily"`` or ``"monthly"``. Quota cost_limit is + stored in cents, so we convert the float USD cost from the usage + store to cents (×100) for comparison. + """ + now = datetime.now(timezone.utc) + if period == "monthly": + start = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + else: + start = now.replace(hour=0, minute=0, second=0, microsecond=0) + summary = self._usage_tracker.get_usage( + department_id=department_id, start_time=start, end_time=now + ) + # cost_limit is stored in cents; convert from USD to cents. + return float(summary.total_cost) * 100.0 diff --git a/src/agentkit/llm/providers/tracker.py b/src/agentkit/llm/providers/tracker.py index fe9d056..a962a87 100644 --- a/src/agentkit/llm/providers/tracker.py +++ b/src/agentkit/llm/providers/tracker.py @@ -23,15 +23,38 @@ class UsageTracker: usage: TokenUsage, cost: float, latency_ms: float, + user_id: str | None = None, + department_id: str | None = None, ) -> None: """记录一次使用""" - self._store.record(agent_name, model, usage, cost, latency_ms) + self._store.record( + agent_name, + model, + usage, + cost, + latency_ms, + user_id=user_id, + department_id=department_id, + ) def get_usage( self, agent_name: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, + user_id: str | None = None, + department_id: str | None = None, ) -> UsageSummary: """查询使用量汇总""" - return self._store.get_usage(agent_name, start_time, end_time) + return self._store.get_usage( + agent_name=agent_name, + start_time=start_time, + end_time=end_time, + user_id=user_id, + department_id=department_id, + ) + + @property + def store(self) -> UsageStore: + """Expose the underlying store for service-layer queries.""" + return self._store diff --git a/src/agentkit/llm/providers/usage_store.py b/src/agentkit/llm/providers/usage_store.py index 97916c8..67ee4b2 100644 --- a/src/agentkit/llm/providers/usage_store.py +++ b/src/agentkit/llm/providers/usage_store.py @@ -5,8 +5,12 @@ backends. Replaces the in-memory list in UsageTracker with a pluggable store that survives restarts and supports multi-instance deployment. Key schema (Redis): - agentkit:usage:{date} → Hash: {agent_name:model → JSON(UsageBucket)} - agentkit:usage_records:{date} → List: JSON(UsageRecord) with LTRIM + agentkit:usage:v2:{date}:{user_id}:{department_id} → Hash: {agent_name:model → JSON(UsageBucket)} + agentkit:usage_records:v2:{date}:{user_id}:{department_id} → List: JSON(UsageRecord) with LTRIM + +Legacy v1 keys (still readable for backward compat): + agentkit:usage:{date} → Hash + agentkit:usage_records:{date} → List """ import json @@ -32,6 +36,8 @@ class UsageRecord: cost: float latency_ms: float timestamp: str = "" # ISO 8601 string for JSON serialization + user_id: str | None = None + department_id: str | None = None def __post_init__(self): if not self.timestamp: @@ -57,6 +63,8 @@ class UsageSummary: total_cost: float = 0.0 by_model: dict[str, dict[str, int | float]] = field(default_factory=dict) records: list[UsageRecord] = field(default_factory=list) + by_user: dict[str, dict[str, int | float]] = field(default_factory=dict) + by_department: dict[str, dict[str, int | float]] = field(default_factory=dict) # --------------------------------------------------------------------------- @@ -75,6 +83,8 @@ class UsageStore(Protocol): usage: TokenUsage, cost: float, latency_ms: float, + user_id: str | None = None, + department_id: str | None = None, ) -> None: """Record a usage event.""" ... @@ -84,6 +94,8 @@ class UsageStore(Protocol): agent_name: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, + user_id: str | None = None, + department_id: str | None = None, ) -> UsageSummary: """Query usage summary.""" ... @@ -109,6 +121,8 @@ class InMemoryUsageStore: usage: TokenUsage, cost: float, latency_ms: float, + user_id: str | None = None, + department_id: str | None = None, ) -> None: rec = UsageRecord( agent_name=agent_name, @@ -118,16 +132,20 @@ class InMemoryUsageStore: total_tokens=usage.total_tokens, cost=cost, latency_ms=latency_ms, + user_id=user_id, + department_id=department_id, ) self._records.append(rec) if len(self._records) > self.MAX_RECORDS: - self._records = self._records[-self.MAX_RECORDS:] + self._records = self._records[-self.MAX_RECORDS :] def get_usage( self, agent_name: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, + user_id: str | None = None, + department_id: str | None = None, ) -> UsageSummary: filtered = self._records @@ -137,26 +155,65 @@ class InMemoryUsageStore: filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) >= start_time] if end_time is not None: filtered = [r for r in filtered if datetime.fromisoformat(r.timestamp) <= end_time] + if user_id is not None: + filtered = [r for r in filtered if r.user_id == user_id] + if department_id is not None: + filtered = [r for r in filtered if r.department_id == department_id] if not filtered: return UsageSummary() - total_tokens = sum(r.total_tokens for r in filtered) - total_cost = sum(r.cost for r in filtered) + return self._aggregate(filtered) + + def get_usage_by_user( + self, + user_id: str, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> UsageSummary: + """Aggregate usage for a specific user.""" + return self.get_usage(user_id=user_id, start_time=start_time, end_time=end_time) + + def get_usage_by_department( + self, + department_id: str, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> UsageSummary: + """Aggregate usage for a specific department.""" + return self.get_usage(department_id=department_id, start_time=start_time, end_time=end_time) + + @staticmethod + def _aggregate(records: list[UsageRecord]) -> UsageSummary: + """Build a :class:`UsageSummary` from a list of records.""" + total_tokens = sum(r.total_tokens for r in records) + total_cost = sum(r.cost for r in records) by_model: dict[str, dict[str, int | float]] = {} - for r in filtered: - if r.model not in by_model: - by_model[r.model] = {"total_tokens": 0, "total_cost": 0.0, "count": 0} - by_model[r.model]["total_tokens"] += r.total_tokens - by_model[r.model]["total_cost"] += r.cost - by_model[r.model]["count"] += 1 + by_user: dict[str, dict[str, int | float]] = {} + by_department: dict[str, dict[str, int | float]] = {} + + def _bump(bucket_map: dict[str, dict[str, int | float]], key: str, r: UsageRecord) -> None: + if key not in bucket_map: + bucket_map[key] = {"total_tokens": 0, "total_cost": 0.0, "count": 0} + bucket_map[key]["total_tokens"] += r.total_tokens + bucket_map[key]["total_cost"] += r.cost + bucket_map[key]["count"] += 1 + + for r in records: + _bump(by_model, r.model, r) + if r.user_id is not None: + _bump(by_user, r.user_id, r) + if r.department_id is not None: + _bump(by_department, r.department_id, r) return UsageSummary( total_tokens=total_tokens, total_cost=total_cost, by_model=by_model, - records=filtered, + by_user=by_user, + by_department=by_department, + records=records, ) @@ -168,13 +225,21 @@ class InMemoryUsageStore: class RedisUsageStore: """Redis-backed usage store using Hash per date for O(1) writes. - Key schema: - agentkit:usage:{YYYY-MM-DD} → Hash: {agent:model → JSON(UsageBucket)} - agentkit:usage_records:{YYYY-MM-DD} → List: JSON(UsageRecord) with LTRIM + Key schema (v2 — includes user_id/department_id in key): + agentkit:usage:v2:{YYYY-MM-DD}:{user_id or 'none'}:{department_id or 'none'} + → Hash: {agent:model → JSON(UsageBucket)} + agentkit:usage_records:v2:{YYYY-MM-DD}:{user_id or 'none'}:{department_id or 'none'} + → List: JSON(UsageRecord) with LTRIM + + Legacy v1 keys (still readable for backward compat): + agentkit:usage:{YYYY-MM-DD} → Hash + agentkit:usage_records:{YYYY-MM-DD} → List """ USAGE_PREFIX = "agentkit:usage:" RECORDS_PREFIX = "agentkit:usage_records:" + USAGE_PREFIX_V2 = "agentkit:usage:v2:" + RECORDS_PREFIX_V2 = "agentkit:usage_records:v2:" MAX_RECORDS_PER_DAY = 50000 TTL_DAYS = 90 # Auto-expire after 90 days @@ -188,6 +253,7 @@ class RedisUsageStore: async def _get_redis(self): if self._redis is None: import redis.asyncio as aioredis + self._redis = aioredis.from_url(self._redis_url, decode_responses=True) return self._redis @@ -195,9 +261,8 @@ class RedisUsageStore: """Get or create a persistent sync Redis client (connection pool backed).""" if self._sync_redis is None: import redis as sync_redis - self._sync_redis = sync_redis.from_url( - self._redis_url, decode_responses=True - ) + + self._sync_redis = sync_redis.from_url(self._redis_url, decode_responses=True) return self._sync_redis async def aclose(self) -> None: @@ -218,6 +283,22 @@ class RedisUsageStore: def _today_key(self) -> str: return datetime.now(timezone.utc).strftime("%Y-%m-%d") + @staticmethod + def _scope_key(part: str | None) -> str: + """Normalize a user_id/department_id for use in a Redis key.""" + return part if part else "none" + + def _v2_keys( + self, date_key: str, user_id: str | None, department_id: str | None + ) -> tuple[str, str]: + """Return (hash_key, list_key) for v2 schema.""" + u = self._scope_key(user_id) + d = self._scope_key(department_id) + return ( + f"{self.USAGE_PREFIX_V2}{date_key}:{u}:{d}", + f"{self.RECORDS_PREFIX_V2}{date_key}:{u}:{d}", + ) + def record( self, agent_name: str, @@ -225,6 +306,8 @@ class RedisUsageStore: usage: TokenUsage, cost: float, latency_ms: float, + user_id: str | None = None, + department_id: str | None = None, ) -> None: """Record usage — sync wrapper for async Redis. @@ -233,15 +316,22 @@ class RedisUsageStore: needing an event loop in the caller. """ if self._degraded and self._fallback is not None: - self._fallback.record(agent_name, model, usage, cost, latency_ms) + self._fallback.record( + agent_name, + model, + usage, + cost, + latency_ms, + user_id=user_id, + department_id=department_id, + ) return try: r = self._get_sync_redis() date_key = self._today_key() - hash_key = f"{self.USAGE_PREFIX}{date_key}" - list_key = f"{self.RECORDS_PREFIX}{date_key}" + hash_key, list_key = self._v2_keys(date_key, user_id, department_id) bucket_field = f"{agent_name}:{model}" # Atomic HINCRBYFLOAT for bucket aggregation @@ -261,17 +351,26 @@ class RedisUsageStore: total_tokens=usage.total_tokens, cost=cost, latency_ms=latency_ms, + user_id=user_id, + department_id=department_id, + ) + pipe.rpush( + list_key, + json.dumps( + { + "agent_name": rec.agent_name, + "model": rec.model, + "prompt_tokens": rec.prompt_tokens, + "completion_tokens": rec.completion_tokens, + "total_tokens": rec.total_tokens, + "cost": rec.cost, + "latency_ms": rec.latency_ms, + "timestamp": rec.timestamp, + "user_id": rec.user_id, + "department_id": rec.department_id, + } + ), ) - pipe.rpush(list_key, json.dumps({ - "agent_name": rec.agent_name, - "model": rec.model, - "prompt_tokens": rec.prompt_tokens, - "completion_tokens": rec.completion_tokens, - "total_tokens": rec.total_tokens, - "cost": rec.cost, - "latency_ms": rec.latency_ms, - "timestamp": rec.timestamp, - })) pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1) # Set TTL on first write of the day @@ -283,17 +382,38 @@ class RedisUsageStore: logger.warning(f"Redis usage record failed: {e}") self._degrade_to_fallback() if self._fallback is not None: - self._fallback.record(agent_name, model, usage, cost, latency_ms) + self._fallback.record( + agent_name, + model, + usage, + cost, + latency_ms, + user_id=user_id, + department_id=department_id, + ) def get_usage( self, agent_name: str | None = None, start_time: datetime | None = None, end_time: datetime | None = None, + user_id: str | None = None, + department_id: str | None = None, ) -> UsageSummary: - """Query usage summary from Redis.""" + """Query usage summary from Redis. + + Scans v2 keys (filtered by user_id/department_id when provided) + and legacy v1 keys (no per-user/department scoping). Records + from both schemas are merged. + """ if self._degraded and self._fallback is not None: - return self._fallback.get_usage(agent_name, start_time, end_time) + return self._fallback.get_usage( + agent_name=agent_name, + start_time=start_time, + end_time=end_time, + user_id=user_id, + department_id=department_id, + ) try: r = self._get_sync_redis() @@ -303,47 +423,115 @@ class RedisUsageStore: end = end_time or datetime.now(timezone.utc) all_records: list[UsageRecord] = [] - # Scan date keys in range + + # Scan v2 keys. current = start.date() end_date = end.date() while current <= end_date: - list_key = f"{self.RECORDS_PREFIX}{current.isoformat()}" - raw_records = r.lrange(list_key, 0, -1) - for raw in raw_records: - data = json.loads(raw) - rec = UsageRecord(**data) - rec_ts = datetime.fromisoformat(rec.timestamp) - if rec_ts >= start and rec_ts <= end: - if agent_name is None or rec.agent_name == agent_name: - all_records.append(rec) + date_key = current.isoformat() + # When user_id/department_id is provided, scan only the + # matching scope key. Otherwise scan all scopes for that + # date via SCAN. + if user_id is not None or department_id is not None: + list_key = f"{self.RECORDS_PREFIX_V2}{date_key}:{self._scope_key(user_id)}:{self._scope_key(department_id)}" + all_records.extend(self._read_list(r, list_key, start, end, agent_name)) + else: + # Scan all v2 list keys for this date. + pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:*" + for key in r.scan_iter(match=pattern, count=200): + all_records.extend(self._read_list(r, key, start, end, agent_name)) current = current + timedelta(days=1) + # Also scan legacy v1 keys (no user/department scoping). + current = start.date() + while current <= end_date: + list_key = f"{self.RECORDS_PREFIX}{current.isoformat()}" + all_records.extend(self._read_list(r, list_key, start, end, agent_name)) + current = current + timedelta(days=1) + + # Apply user_id/department_id filters to records from legacy + # v1 keys (which don't carry these fields — they'll be None). + if user_id is not None: + all_records = [r for r in all_records if r.user_id == user_id] + if department_id is not None: + all_records = [r for r in all_records if r.department_id == department_id] + if not all_records: return UsageSummary() - total_tokens = sum(r.total_tokens for r in all_records) - total_cost = sum(r.cost for r in all_records) - - by_model: dict[str, dict[str, int | float]] = {} - for r in all_records: - if r.model not in by_model: - by_model[r.model] = {"total_tokens": 0, "total_cost": 0.0, "count": 0} - by_model[r.model]["total_tokens"] += r.total_tokens - by_model[r.model]["total_cost"] += r.cost - by_model[r.model]["count"] += 1 - - return UsageSummary( - total_tokens=total_tokens, - total_cost=total_cost, - by_model=by_model, - records=all_records, - ) + return InMemoryUsageStore._aggregate(all_records) except Exception as e: logger.warning(f"Redis usage query failed: {e}") if self._fallback is not None: - return self._fallback.get_usage(agent_name, start_time, end_time) + return self._fallback.get_usage( + agent_name=agent_name, + start_time=start_time, + end_time=end_time, + user_id=user_id, + department_id=department_id, + ) return UsageSummary() + def get_usage_by_user( + self, + user_id: str, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> UsageSummary: + """Aggregate usage for a specific user.""" + return self.get_usage(user_id=user_id, start_time=start_time, end_time=end_time) + + def get_usage_by_department( + self, + department_id: str, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> UsageSummary: + """Aggregate usage for a specific department.""" + return self.get_usage(department_id=department_id, start_time=start_time, end_time=end_time) + + @staticmethod + def _read_list( + r: Any, + list_key: str, + start: datetime, + end: datetime, + agent_name: str | None, + ) -> list[UsageRecord]: + """Read all records from a Redis list, filtered by time range and agent.""" + out: list[UsageRecord] = [] + raw_records = r.lrange(list_key, 0, -1) + for raw in raw_records: + try: + data = json.loads(raw) + except json.JSONDecodeError: + continue + # Build record, tolerating legacy records without user_id/department_id. + rec = UsageRecord( + agent_name=data["agent_name"], + model=data["model"], + prompt_tokens=data["prompt_tokens"], + completion_tokens=data["completion_tokens"], + total_tokens=data["total_tokens"], + cost=data["cost"], + latency_ms=data["latency_ms"], + timestamp=data.get("timestamp", ""), + user_id=data.get("user_id"), + department_id=data.get("department_id"), + ) + if not rec.timestamp: + continue + try: + rec_ts = datetime.fromisoformat(rec.timestamp) + except ValueError: + continue + if rec_ts < start or rec_ts > end: + continue + if agent_name is not None and rec.agent_name != agent_name: + continue + out.append(rec) + return out + # --------------------------------------------------------------------------- # Factory @@ -366,6 +554,7 @@ def create_usage_store( if backend in ("auto", "redis"): try: import redis # noqa: F401 + return RedisUsageStore(redis_url=redis_url) except ImportError: logger.warning("redis package not available, falling back to in-memory usage store") diff --git a/src/agentkit/server/admin/usage_service.py b/src/agentkit/server/admin/usage_service.py new file mode 100644 index 0000000..c868a82 --- /dev/null +++ b/src/agentkit/server/admin/usage_service.py @@ -0,0 +1,298 @@ +"""UsageService — read-side aggregations for the usage dashboard (U7). + +This module provides read-only aggregations over a :class:`UsageStore` +for the admin usage dashboard. It is intentionally a thin layer — the +store already produces :class:`UsageSummary` aggregations, and this +service just shapes them for the dashboard endpoints (timeseries, +top-N, CSV/JSON export). + +The service is a module-level singleton (see :func:`get_usage_service`) +so tests can inject a custom instance via :func:`set_usage_service`. +""" + +from __future__ import annotations + +import csv +import io +import json +import logging +from datetime import datetime +from typing import Any + +from agentkit.llm.providers.usage_store import UsageStore, UsageSummary + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _bucket_start(ts: datetime, interval: str) -> datetime: + """Return the start of the time bucket containing ``ts``.""" + if interval == "hour": + return ts.replace(minute=0, second=0, microsecond=0) + # Default: day + return ts.replace(hour=0, minute=0, second=0, microsecond=0) + + +# --------------------------------------------------------------------------- +# Service +# --------------------------------------------------------------------------- + + +class UsageService: + """Read-side aggregations for the usage dashboard.""" + + async def get_usage_summary( + self, + usage_store: UsageStore, + department_id: str | None = None, + user_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> dict[str, Any]: + """Return a flat usage summary dict. + + Shape:: + + { + "total_tokens": int, + "total_cost": float, + "total_requests": int, + "by_model": {model: {total_tokens, total_cost, count}, ...}, + "by_user": {user_id: {...}, ...}, + "by_department": {department_id: {...}, ...}, + } + """ + summary = usage_store.get_usage( + start_time=start_time, + end_time=end_time, + user_id=user_id, + department_id=department_id, + ) + return self._summary_to_dict(summary) + + async def get_usage_timeseries( + self, + usage_store: UsageStore, + department_id: str | None = None, + user_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + interval: str = "day", + ) -> list[dict[str, Any]]: + """Return a time-bucketed series. + + Each item has shape ``{timestamp, tokens, cost, requests}``. + Buckets with no activity are omitted (callers can fill gaps). + """ + summary = usage_store.get_usage( + start_time=start_time, + end_time=end_time, + user_id=user_id, + department_id=department_id, + ) + buckets: dict[datetime, dict[str, Any]] = {} + for rec in summary.records: + try: + ts = datetime.fromisoformat(rec.timestamp) + except ValueError: + continue + bucket = _bucket_start(ts, interval) + if bucket not in buckets: + buckets[bucket] = {"tokens": 0, "cost": 0.0, "requests": 0} + buckets[bucket]["tokens"] += rec.total_tokens + buckets[bucket]["cost"] += rec.cost + buckets[bucket]["requests"] += 1 + return [ + { + "timestamp": bucket.isoformat(), + "tokens": data["tokens"], + "cost": data["cost"], + "requests": data["requests"], + } + for bucket, data in sorted(buckets.items()) + ] + + async def get_usage_by_model( + self, + usage_store: UsageStore, + department_id: str | None = None, + user_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> list[dict[str, Any]]: + """Return a per-model breakdown.""" + summary = usage_store.get_usage( + start_time=start_time, + end_time=end_time, + user_id=user_id, + department_id=department_id, + ) + return [ + { + "model": model, + "tokens": data["total_tokens"], + "cost": data["total_cost"], + "requests": data["count"], + } + for model, data in sorted(summary.by_model.items()) + ] + + async def get_top_users( + self, + usage_store: UsageStore, + department_id: str | None = None, + user_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + limit: int = 10, + ) -> list[dict[str, Any]]: + """Return the top-N users by total token usage.""" + summary = usage_store.get_usage( + start_time=start_time, + end_time=end_time, + user_id=user_id, + department_id=department_id, + ) + rows = [ + { + "user_id": uid, + "tokens": data["total_tokens"], + "cost": data["total_cost"], + "requests": data["count"], + } + for uid, data in summary.by_user.items() + ] + rows.sort(key=lambda r: r["tokens"], reverse=True) + return rows[:limit] + + async def get_top_departments( + self, + usage_store: UsageStore, + department_id: str | None = None, + user_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + limit: int = 10, + ) -> list[dict[str, Any]]: + """Return the top-N departments by total token usage.""" + summary = usage_store.get_usage( + start_time=start_time, + end_time=end_time, + user_id=user_id, + department_id=department_id, + ) + rows = [ + { + "department_id": did, + "tokens": data["total_tokens"], + "cost": data["total_cost"], + "requests": data["count"], + } + for did, data in summary.by_department.items() + ] + rows.sort(key=lambda r: r["tokens"], reverse=True) + return rows[:limit] + + async def export_usage( + self, + usage_store: UsageStore, + department_id: str | None = None, + user_id: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + format: str = "csv", + ) -> str: + """Export raw usage records as CSV or JSON. + + ``format`` is ``"csv"`` (default) or ``"json"``. + """ + summary = usage_store.get_usage( + start_time=start_time, + end_time=end_time, + user_id=user_id, + department_id=department_id, + ) + records = [ + { + "timestamp": rec.timestamp, + "agent_name": rec.agent_name, + "model": rec.model, + "prompt_tokens": rec.prompt_tokens, + "completion_tokens": rec.completion_tokens, + "total_tokens": rec.total_tokens, + "cost": rec.cost, + "latency_ms": rec.latency_ms, + "user_id": rec.user_id or "", + "department_id": rec.department_id or "", + } + for rec in summary.records + ] + if format == "json": + return json.dumps(records, ensure_ascii=False, indent=2) + # Default: CSV + out = io.StringIO() + if records: + writer = csv.DictWriter(out, fieldnames=list(records[0].keys())) + writer.writeheader() + writer.writerows(records) + else: + # Empty CSV with just headers + writer = csv.DictWriter( + out, + fieldnames=[ + "timestamp", + "agent_name", + "model", + "prompt_tokens", + "completion_tokens", + "total_tokens", + "cost", + "latency_ms", + "user_id", + "department_id", + ], + ) + writer.writeheader() + return out.getvalue() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _summary_to_dict(summary: UsageSummary) -> dict[str, Any]: + """Convert a :class:`UsageSummary` to a flat dict response.""" + return { + "total_tokens": summary.total_tokens, + "total_cost": summary.total_cost, + "total_requests": len(summary.records), + "by_model": dict(summary.by_model), + "by_user": dict(summary.by_user), + "by_department": dict(summary.by_department), + } + + +# --------------------------------------------------------------------------- +# Module-level singleton (overridable in tests via set_usage_service) +# --------------------------------------------------------------------------- + + +_usage_service: UsageService | None = None + + +def get_usage_service() -> UsageService: + """Return the process-wide :class:`UsageService` (lazy singleton).""" + global _usage_service + if _usage_service is None: + _usage_service = UsageService() + return _usage_service + + +def set_usage_service(service: UsageService | None) -> None: + """Inject a custom :class:`UsageService` (used by tests).""" + global _usage_service + _usage_service = service diff --git a/src/agentkit/server/routes/admin.py b/src/agentkit/server/routes/admin.py index 1a0d739..bc940be 100644 --- a/src/agentkit/server/routes/admin.py +++ b/src/agentkit/server/routes/admin.py @@ -16,10 +16,12 @@ import time (keeps the module self-contained and test-friendly). from __future__ import annotations import logging +from datetime import datetime from pathlib import Path from typing import Any from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import PlainTextResponse from pydantic import BaseModel, ConfigDict from agentkit.server.admin.department_service import get_department_service @@ -30,6 +32,7 @@ from agentkit.server.admin.llm_config_service import ( ) from agentkit.server.admin.quota_service import get_quota_service from agentkit.server.admin.skill_service import get_skill_service +from agentkit.server.admin.usage_service import get_usage_service from agentkit.server.admin.user_service import get_user_service from agentkit.server.auth.dependencies import require_authenticated from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db @@ -1138,3 +1141,192 @@ async def rebuild_kb_source( return svc.rebuild_index(source_id) except ValueError as exc: raise HTTPException(status_code=404, detail=str(exc)) from exc + + +# --------------------------------------------------------------------------- +# Usage dashboard endpoints (U7) — usage aggregations + export +# --------------------------------------------------------------------------- + + +def _get_usage_store(request: Request) -> Any: + """Return the live :class:`UsageStore` from ``app.state.llm_gateway``. + + Raises HTTPException(500) if the gateway or usage store is missing — + usage endpoints cannot function without it. + """ + gateway = getattr(request.app.state, "llm_gateway", None) + if gateway is None: + raise HTTPException( + status_code=500, + detail="LLM gateway not initialized on app.state", + ) + try: + return gateway._usage_tracker.store # type: ignore[attr-defined] + except AttributeError as exc: + raise HTTPException( + status_code=500, + detail="Usage store not available on LLM gateway", + ) from exc + + +def _parse_iso(value: str | None) -> datetime | None: + """Parse an ISO 8601 string into a timezone-aware datetime.""" + if value is None or value == "": + return None + try: + dt = datetime.fromisoformat(value) + except ValueError: + # Try the trailing-Z form. + if value.endswith("Z"): + try: + from datetime import timezone + + dt = datetime.fromisoformat(value[:-1]).replace(tzinfo=timezone.utc) + except ValueError: + raise HTTPException( + status_code=400, detail=f"Invalid ISO 8601 timestamp: {value!r}" + ) + else: + raise HTTPException(status_code=400, detail=f"Invalid ISO 8601 timestamp: {value!r}") + return dt + + +@admin_router.get("/usage/summary") +async def get_usage_summary( + request: Request, + department_id: str | None = None, + user_id: str | None = None, + start: str | None = None, + end: str | None = None, + admin: dict[str, Any] = Depends(_require_admin), +) -> dict[str, Any]: + """Return an aggregated usage summary. + + Query params: ``department_id``, ``user_id``, ``start``, ``end`` + (ISO 8601). Admins see all data; non-admin callers are blocked by + ``_require_admin`` (403). + """ + store = _get_usage_store(request) + svc = get_usage_service() + return await svc.get_usage_summary( + store, + department_id=department_id, + user_id=user_id, + start_time=_parse_iso(start), + end_time=_parse_iso(end), + ) + + +@admin_router.get("/usage/timeseries") +async def get_usage_timeseries( + request: Request, + department_id: str | None = None, + user_id: str | None = None, + start: str | None = None, + end: str | None = None, + interval: str = "day", + admin: dict[str, Any] = Depends(_require_admin), +) -> list[dict[str, Any]]: + """Return a time-bucketed usage series. + + Query params: ``department_id``, ``user_id``, ``start``, ``end`` + (ISO 8601), ``interval`` (``day`` or ``hour``, default ``day``). + """ + if interval not in ("day", "hour"): + raise HTTPException(status_code=400, detail="interval must be 'day' or 'hour'") + store = _get_usage_store(request) + svc = get_usage_service() + return await svc.get_usage_timeseries( + store, + department_id=department_id, + user_id=user_id, + start_time=_parse_iso(start), + end_time=_parse_iso(end), + interval=interval, + ) + + +@admin_router.get("/usage/by-model") +async def get_usage_by_model( + request: Request, + department_id: str | None = None, + user_id: str | None = None, + start: str | None = None, + end: str | None = None, + admin: dict[str, Any] = Depends(_require_admin), +) -> list[dict[str, Any]]: + """Return a per-model usage breakdown.""" + store = _get_usage_store(request) + svc = get_usage_service() + return await svc.get_usage_by_model( + store, + department_id=department_id, + user_id=user_id, + start_time=_parse_iso(start), + end_time=_parse_iso(end), + ) + + +@admin_router.get("/usage/top-users") +async def get_top_users( + request: Request, + department_id: str | None = None, + user_id: str | None = None, + start: str | None = None, + end: str | None = None, + limit: int = 10, + admin: dict[str, Any] = Depends(_require_admin), +) -> list[dict[str, Any]]: + """Return the top-N users by total token usage. + + Query params: ``department_id``, ``user_id``, ``start``, ``end``, + ``limit`` (default 10, max 100). + """ + if limit < 1: + limit = 1 + if limit > 100: + limit = 100 + store = _get_usage_store(request) + svc = get_usage_service() + return await svc.get_top_users( + store, + department_id=department_id, + user_id=user_id, + start_time=_parse_iso(start), + end_time=_parse_iso(end), + limit=limit, + ) + + +@admin_router.get("/usage/export") +async def export_usage( + request: Request, + department_id: str | None = None, + user_id: str | None = None, + start: str | None = None, + end: str | None = None, + format: str = "csv", + admin: dict[str, Any] = Depends(_require_admin), +) -> Any: + """Export raw usage records as CSV or JSON. + + Query params: ``department_id``, ``user_id``, ``start``, ``end``, + ``format`` (``csv`` or ``json``, default ``csv``). + + Returns ``text/csv`` for CSV or ``application/json`` for JSON. + """ + if format not in ("csv", "json"): + raise HTTPException(status_code=400, detail="format must be 'csv' or 'json'") + store = _get_usage_store(request) + svc = get_usage_service() + body = await svc.export_usage( + store, + department_id=department_id, + user_id=user_id, + start_time=_parse_iso(start), + end_time=_parse_iso(end), + format=format, + ) + if format == "csv": + return PlainTextResponse(content=body, media_type="text/csv") + return PlainTextResponse(content=body, media_type="application/json") diff --git a/src/agentkit/server/routes/llm_gateway.py b/src/agentkit/server/routes/llm_gateway.py index 79f4c0e..1da87dc 100644 --- a/src/agentkit/server/routes/llm_gateway.py +++ b/src/agentkit/server/routes/llm_gateway.py @@ -7,12 +7,14 @@ Supports both non-streaming (`POST /api/v1/llm/chat`) and SSE streaming import json from typing import Any -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel, ConfigDict, Field from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError +from agentkit.llm.gateway import QuotaExceededError from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall +from agentkit.server.admin.context import get_department_context router = APIRouter(prefix="/llm", tags=["llm-gateway"]) @@ -66,14 +68,32 @@ def _serialize_chunk(chunk: StreamChunk) -> dict[str, Any]: return payload +def _quota_error_payload(exc: QuotaExceededError) -> dict[str, Any]: + """Build a structured 429 error body from a QuotaExceededError.""" + return { + "error": "quota_exceeded", + "department_id": exc.department_id, + "quota_type": exc.quota_type, + "period": exc.period, + "limit": exc.limit, + "current": exc.current, + } + + @router.post("/chat") -async def chat(request: Request, body: LLMChatRequest) -> dict[str, Any]: +async def chat( + request: Request, + body: LLMChatRequest, + ctx: Any = Depends(get_department_context), +) -> dict[str, Any]: """Non-streaming LLM chat proxy. Forwards the request to the configured LLMGateway and returns the - serialized LLMResponse. + serialized LLMResponse. Quota-exceeded errors from the gateway are + translated to HTTP 429. """ gateway = request.app.state.llm_gateway + db_path = getattr(request.app.state, "auth_db_path", None) try: response = await gateway.chat( messages=body.messages, @@ -83,7 +103,12 @@ async def chat(request: Request, body: LLMChatRequest) -> dict[str, Any]: timeout=body.timeout, temperature=body.temperature, max_tokens=body.max_tokens, + user_id=ctx.user_id, + department_ids=ctx.department_ids if ctx.department_ids else None, + db_path=db_path, ) + except QuotaExceededError as e: + raise HTTPException(status_code=429, detail=_quota_error_payload(e)) from e except ModelNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) from e except LLMProviderError as e: @@ -92,7 +117,11 @@ async def chat(request: Request, body: LLMChatRequest) -> dict[str, Any]: @router.post("/chat/stream") -async def chat_stream(request: Request, body: LLMChatRequest) -> StreamingResponse: +async def chat_stream( + request: Request, + body: LLMChatRequest, + ctx: Any = Depends(get_department_context), +) -> StreamingResponse: """SSE streaming LLM chat proxy. Each StreamChunk is serialized as `data: {json}\\n\\n`. The stream @@ -101,6 +130,7 @@ async def chat_stream(request: Request, body: LLMChatRequest) -> StreamingRespon async def event_generator(): gateway = request.app.state.llm_gateway + db_path = getattr(request.app.state, "auth_db_path", None) try: async for chunk in gateway.chat_stream( messages=body.messages, @@ -110,9 +140,16 @@ async def chat_stream(request: Request, body: LLMChatRequest) -> StreamingRespon timeout=body.timeout, temperature=body.temperature, max_tokens=body.max_tokens, + user_id=ctx.user_id, + department_ids=ctx.department_ids if ctx.department_ids else None, + db_path=db_path, ): payload = _serialize_chunk(chunk) yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + except QuotaExceededError as e: + error_payload = _quota_error_payload(e) + error_payload["error"] = "quota_exceeded" + yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n" except ModelNotFoundError as e: error_payload = {"error": "model_not_found", "detail": str(e)} yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n" diff --git a/tests/integration/admin/test_usage_routes.py b/tests/integration/admin/test_usage_routes.py new file mode 100644 index 0000000..13974e4 --- /dev/null +++ b/tests/integration/admin/test_usage_routes.py @@ -0,0 +1,341 @@ +"""Integration tests for the admin usage dashboard routes (U7). + +Uses FastAPI TestClient with a test app that mounts only the +``admin_router`` from ``routes.admin``. The ``_require_admin`` +dependency is overridden via ``app.dependency_overrides`` so the tests +don't need real JWTs — they can simulate admin and non-admin callers +directly. + +The LLM gateway is replaced with a stub that exposes a +``_usage_tracker.store`` attribute pointing at an +:class:`InMemoryUsageStore` pre-populated with test records. +""" + +from __future__ import annotations + +import csv +import io +import json +from pathlib import Path +from typing import Any + +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient + +from agentkit.llm.protocol import TokenUsage +from agentkit.llm.providers.usage_store import InMemoryUsageStore +from agentkit.server.admin.usage_service import set_usage_service +from agentkit.server.auth.models import init_auth_db +from agentkit.server.routes import admin as admin_routes_module + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +class _StubTracker: + """Minimal stub matching the UsageTracker surface used by routes.""" + + def __init__(self, store: InMemoryUsageStore) -> None: + self.store = store + + +class _StubGateway: + """Minimal stub matching the LLMGateway surface used by routes.""" + + def __init__(self, store: InMemoryUsageStore) -> None: + self._usage_tracker = _StubTracker(store) + + +@pytest.fixture +def store() -> InMemoryUsageStore: + return InMemoryUsageStore() + + +@pytest.fixture +def populated_store() -> InMemoryUsageStore: + """Pre-populated store with a mix of records across users/depts/models.""" + s = InMemoryUsageStore() + s.record( + "agent1", + "gpt-4o", + TokenUsage(prompt_tokens=60, completion_tokens=40), + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + s.record( + "agent1", + "claude", + TokenUsage(prompt_tokens=120, completion_tokens=80), + cost=0.10, + latency_ms=300, + user_id="u1", + department_id="d1", + ) + s.record( + "agent2", + "gpt-4o", + TokenUsage(prompt_tokens=30, completion_tokens=20), + cost=0.02, + latency_ms=100, + user_id="u2", + department_id="d2", + ) + s.record( + "agent3", + "gpt-4o", + TokenUsage(prompt_tokens=300, completion_tokens=200), + cost=0.50, + latency_ms=400, + user_id="u3", + department_id="d1", + ) + return s + + +@pytest.fixture(autouse=True) +def _reset_singletons(): + set_usage_service(None) + yield + set_usage_service(None) + + +@pytest.fixture +async def tmp_auth_db(tmp_path: Path) -> Path: + db_path = tmp_path / "usage_routes.db" + await init_auth_db(db_path) + return db_path + + +def _make_admin_app(store: InMemoryUsageStore, tmp_auth_db: Path) -> FastAPI: + """Build a FastAPI app with admin router + stub gateway.""" + app = FastAPI() + app.state.auth_db_path = str(tmp_auth_db) + app.state.llm_gateway = _StubGateway(store) + app.include_router(admin_routes_module.admin_router, prefix="/api/v1") + # Default: allow admin access. + app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user() + return app + + +def _make_admin_user() -> dict[str, Any]: + return {"user_id": "admin-1", "username": "admin", "role": "admin"} + + +def _raise_forbidden() -> dict[str, Any]: + raise HTTPException(status_code=403, detail="Admin permission required") + + +# --------------------------------------------------------------------------- +# /admin/usage/summary +# --------------------------------------------------------------------------- + + +class TestUsageSummaryRoute: + def test_returns_200_with_data(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/summary") + assert resp.status_code == 200 + body = resp.json() + assert body["total_tokens"] == 850 + assert abs(body["total_cost"] - 0.67) < 1e-6 + assert body["total_requests"] == 4 + assert "gpt-4o" in body["by_model"] + assert "u1" in body["by_user"] + assert "d1" in body["by_department"] + + def test_with_department_filter(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/summary", params={"department_id": "d2"}) + assert resp.status_code == 200 + body = resp.json() + assert body["total_tokens"] == 50 + assert body["total_requests"] == 1 + + def test_empty_store_returns_200_with_zeros(self, store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/summary") + assert resp.status_code == 200 + body = resp.json() + assert body["total_tokens"] == 0 + assert body["total_cost"] == 0.0 + assert body["total_requests"] == 0 + + def test_non_admin_returns_403(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(populated_store, tmp_auth_db) + app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/summary") + assert resp.status_code == 403 + + +# --------------------------------------------------------------------------- +# /admin/usage/timeseries +# --------------------------------------------------------------------------- + + +class TestUsageTimeseriesRoute: + def test_returns_200_with_data(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/timeseries") + assert resp.status_code == 200 + body = resp.json() + assert isinstance(body, list) + assert len(body) >= 1 + assert "timestamp" in body[0] + assert "tokens" in body[0] + assert body[0]["tokens"] == 850 + + def test_invalid_interval_returns_400( + self, populated_store: InMemoryUsageStore, tmp_auth_db: Path + ): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/timeseries", params={"interval": "week"}) + assert resp.status_code == 400 + + def test_empty_store_returns_200_empty_list(self, store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/timeseries") + assert resp.status_code == 200 + assert resp.json() == [] + + +# --------------------------------------------------------------------------- +# /admin/usage/by-model +# --------------------------------------------------------------------------- + + +class TestUsageByModelRoute: + def test_returns_200_with_breakdown( + self, populated_store: InMemoryUsageStore, tmp_auth_db: Path + ): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/by-model") + assert resp.status_code == 200 + body = resp.json() + assert isinstance(body, list) + models = {row["model"] for row in body} + assert models == {"gpt-4o", "claude"} + + def test_empty_store_returns_200_empty_list(self, store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/by-model") + assert resp.status_code == 200 + assert resp.json() == [] + + +# --------------------------------------------------------------------------- +# /admin/usage/top-users +# --------------------------------------------------------------------------- + + +class TestTopUsersRoute: + def test_returns_200_sorted_by_tokens( + self, populated_store: InMemoryUsageStore, tmp_auth_db: Path + ): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/top-users") + assert resp.status_code == 200 + body = resp.json() + assert len(body) == 3 + # u3 (500), u1 (300), u2 (50) + assert body[0]["user_id"] == "u3" + assert body[0]["tokens"] == 500 + assert body[1]["user_id"] == "u1" + assert body[2]["user_id"] == "u2" + + def test_limit_param_respected(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/top-users", params={"limit": 2}) + assert resp.status_code == 200 + body = resp.json() + assert len(body) == 2 + + def test_empty_store_returns_200_empty_list(self, store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/top-users") + assert resp.status_code == 200 + assert resp.json() == [] + + +# --------------------------------------------------------------------------- +# /admin/usage/export +# --------------------------------------------------------------------------- + + +class TestUsageExportRoute: + def test_csv_export_returns_200(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/export", params={"format": "csv"}) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/csv") + reader = csv.DictReader(io.StringIO(resp.text)) + rows = list(reader) + assert len(rows) == 4 + assert "timestamp" in rows[0] + assert "user_id" in rows[0] + assert "department_id" in rows[0] + + def test_json_export_returns_200(self, populated_store: InMemoryUsageStore, tmp_auth_db: Path): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/export", params={"format": "json"}) + assert resp.status_code == 200 + # PlainTextResponse returns text/plain or application/json depending on media_type. + body = json.loads(resp.text) + assert isinstance(body, list) + assert len(body) == 4 + + def test_invalid_format_returns_400( + self, populated_store: InMemoryUsageStore, tmp_auth_db: Path + ): + app = _make_admin_app(populated_store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/export", params={"format": "xml"}) + assert resp.status_code == 400 + + def test_empty_store_csv_returns_header_only( + self, store: InMemoryUsageStore, tmp_auth_db: Path + ): + app = _make_admin_app(store, tmp_auth_db) + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/export", params={"format": "csv"}) + assert resp.status_code == 200 + reader = csv.DictReader(io.StringIO(resp.text)) + rows = list(reader) + assert rows == [] + # Header should still be present. + assert "timestamp" in resp.text + + +# --------------------------------------------------------------------------- +# Missing gateway +# --------------------------------------------------------------------------- + + +class TestMissingGateway: + def test_summary_returns_500_without_gateway(self, tmp_auth_db: Path): + app = FastAPI() + app.state.auth_db_path = str(tmp_auth_db) + # No llm_gateway on app.state. + app.include_router(admin_routes_module.admin_router, prefix="/api/v1") + app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user() + client = TestClient(app) + resp = client.get("/api/v1/admin/usage/summary") + assert resp.status_code == 500 diff --git a/tests/unit/admin/test_usage_service.py b/tests/unit/admin/test_usage_service.py new file mode 100644 index 0000000..bc32882 --- /dev/null +++ b/tests/unit/admin/test_usage_service.py @@ -0,0 +1,330 @@ +"""Unit tests for UsageService (U7 — usage dashboard aggregations).""" + +from __future__ import annotations + +import csv +import io +import json + +import pytest + +from agentkit.llm.protocol import TokenUsage +from agentkit.llm.providers.usage_store import InMemoryUsageStore +from agentkit.server.admin.usage_service import ( + UsageService, + get_usage_service, + set_usage_service, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def store() -> InMemoryUsageStore: + return InMemoryUsageStore() + + +@pytest.fixture +def service() -> UsageService: + return UsageService() + + +@pytest.fixture(autouse=True) +def _reset_singleton(): + """Reset the UsageService singleton before and after each test.""" + set_usage_service(None) + yield + set_usage_service(None) + + +def _populate_store(store: InMemoryUsageStore) -> None: + """Populate ``store`` with a mix of records for testing.""" + # User u1 in dept d1, gpt-4o, 100 tokens, $0.05 + store.record( + "agent1", + "gpt-4o", + TokenUsage(prompt_tokens=60, completion_tokens=40), + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + # User u1 in dept d1, claude, 200 tokens, $0.10 + store.record( + "agent1", + "claude", + TokenUsage(prompt_tokens=120, completion_tokens=80), + cost=0.10, + latency_ms=300, + user_id="u1", + department_id="d1", + ) + # User u2 in dept d2, gpt-4o, 50 tokens, $0.02 + store.record( + "agent2", + "gpt-4o", + TokenUsage(prompt_tokens=30, completion_tokens=20), + cost=0.02, + latency_ms=100, + user_id="u2", + department_id="d2", + ) + # User u3 in dept d1, gpt-4o, 500 tokens, $0.50 (top user) + store.record( + "agent3", + "gpt-4o", + TokenUsage(prompt_tokens=300, completion_tokens=200), + cost=0.50, + latency_ms=400, + user_id="u3", + department_id="d1", + ) + + +# --------------------------------------------------------------------------- +# get_usage_summary +# --------------------------------------------------------------------------- + + +class TestGetUsageSummary: + async def test_summary_aggregates_all(self, service: UsageService, store: InMemoryUsageStore): + _populate_store(store) + result = await service.get_usage_summary(store) + assert result["total_tokens"] == 850 + assert abs(result["total_cost"] - 0.67) < 1e-6 + assert result["total_requests"] == 4 + # by_model: gpt-4o (3 records, 650 tokens), claude (1, 200) + assert "gpt-4o" in result["by_model"] + assert "claude" in result["by_model"] + assert result["by_model"]["gpt-4o"]["count"] == 3 + assert result["by_model"]["gpt-4o"]["total_tokens"] == 650 + # by_user: u1 (2 records, 300 tokens), u2 (1, 50), u3 (1, 500) + assert result["by_user"]["u1"]["total_tokens"] == 300 + assert result["by_user"]["u2"]["total_tokens"] == 50 + assert result["by_user"]["u3"]["total_tokens"] == 500 + # by_department: d1 (3 records, 800 tokens), d2 (1, 50) + assert result["by_department"]["d1"]["total_tokens"] == 800 + assert result["by_department"]["d2"]["total_tokens"] == 50 + + async def test_summary_with_department_filter( + self, service: UsageService, store: InMemoryUsageStore + ): + _populate_store(store) + result = await service.get_usage_summary(store, department_id="d1") + assert result["total_tokens"] == 800 + assert result["total_requests"] == 3 + # Only u1 and u3 are in d1. + assert "u1" in result["by_user"] + assert "u3" in result["by_user"] + assert "u2" not in result["by_user"] + + async def test_summary_with_user_filter(self, service: UsageService, store: InMemoryUsageStore): + _populate_store(store) + result = await service.get_usage_summary(store, user_id="u2") + assert result["total_tokens"] == 50 + assert result["total_requests"] == 1 + assert "u2" in result["by_user"] + + async def test_summary_with_empty_store(self, service: UsageService, store: InMemoryUsageStore): + result = await service.get_usage_summary(store) + assert result["total_tokens"] == 0 + assert result["total_cost"] == 0.0 + assert result["total_requests"] == 0 + assert result["by_model"] == {} + assert result["by_user"] == {} + assert result["by_department"] == {} + + +# --------------------------------------------------------------------------- +# get_usage_timeseries +# --------------------------------------------------------------------------- + + +class TestGetUsageTimeseries: + async def test_timeseries_day_buckets(self, service: UsageService, store: InMemoryUsageStore): + _populate_store(store) + result = await service.get_usage_timeseries(store, interval="day") + # All records are within the same day (today), so we expect one bucket. + assert len(result) >= 1 + bucket = result[0] + assert "timestamp" in bucket + assert bucket["tokens"] == 850 + assert abs(bucket["cost"] - 0.67) < 1e-6 + assert bucket["requests"] == 4 + + async def test_timeseries_hour_buckets(self, service: UsageService, store: InMemoryUsageStore): + _populate_store(store) + result = await service.get_usage_timeseries(store, interval="hour") + # All records are within the same hour (now), so we expect one bucket. + assert len(result) >= 1 + assert result[0]["tokens"] == 850 + + async def test_timeseries_with_department_filter( + self, service: UsageService, store: InMemoryUsageStore + ): + _populate_store(store) + result = await service.get_usage_timeseries(store, department_id="d2", interval="day") + assert len(result) >= 1 + assert result[0]["tokens"] == 50 + + async def test_timeseries_empty_store(self, service: UsageService, store: InMemoryUsageStore): + result = await service.get_usage_timeseries(store, interval="day") + assert result == [] + + +# --------------------------------------------------------------------------- +# get_usage_by_model +# --------------------------------------------------------------------------- + + +class TestGetUsageByModel: + async def test_by_model_breakdown(self, service: UsageService, store: InMemoryUsageStore): + _populate_store(store) + result = await service.get_usage_by_model(store) + # Sorted by model name: claude, gpt-4o + assert len(result) == 2 + models = {row["model"] for row in result} + assert models == {"gpt-4o", "claude"} + gpt_row = next(r for r in result if r["model"] == "gpt-4o") + assert gpt_row["tokens"] == 650 + assert gpt_row["requests"] == 3 + claude_row = next(r for r in result if r["model"] == "claude") + assert claude_row["tokens"] == 200 + assert claude_row["requests"] == 1 + + async def test_by_model_with_department_filter( + self, service: UsageService, store: InMemoryUsageStore + ): + _populate_store(store) + result = await service.get_usage_by_model(store, department_id="d2") + # d2 only has gpt-4o (50 tokens, 1 request) + assert len(result) == 1 + assert result[0]["model"] == "gpt-4o" + assert result[0]["tokens"] == 50 + + async def test_by_model_empty_store(self, service: UsageService, store: InMemoryUsageStore): + result = await service.get_usage_by_model(store) + assert result == [] + + +# --------------------------------------------------------------------------- +# get_top_users +# --------------------------------------------------------------------------- + + +class TestGetTopUsers: + async def test_top_users_sorted_by_tokens( + self, service: UsageService, store: InMemoryUsageStore + ): + _populate_store(store) + result = await service.get_top_users(store, limit=10) + # u3 (500), u1 (300), u2 (50) + assert len(result) == 3 + assert result[0]["user_id"] == "u3" + assert result[0]["tokens"] == 500 + assert result[1]["user_id"] == "u1" + assert result[1]["tokens"] == 300 + assert result[2]["user_id"] == "u2" + assert result[2]["tokens"] == 50 + + async def test_top_users_respects_limit(self, service: UsageService, store: InMemoryUsageStore): + _populate_store(store) + result = await service.get_top_users(store, limit=2) + assert len(result) == 2 + assert result[0]["user_id"] == "u3" + assert result[1]["user_id"] == "u1" + + async def test_top_users_with_department_filter( + self, service: UsageService, store: InMemoryUsageStore + ): + _populate_store(store) + result = await service.get_top_users(store, department_id="d1", limit=10) + # d1 has u1 and u3 + assert len(result) == 2 + assert result[0]["user_id"] == "u3" + assert result[1]["user_id"] == "u1" + + async def test_top_users_empty_store(self, service: UsageService, store: InMemoryUsageStore): + result = await service.get_top_users(store, limit=10) + assert result == [] + + +# --------------------------------------------------------------------------- +# export_usage +# --------------------------------------------------------------------------- + + +class TestExportUsage: + async def test_export_csv_has_header_and_rows( + self, service: UsageService, store: InMemoryUsageStore + ): + _populate_store(store) + body = await service.export_usage(store, format="csv") + reader = csv.DictReader(io.StringIO(body)) + rows = list(reader) + assert len(rows) == 4 + # Verify headers + assert "timestamp" in rows[0] + assert "agent_name" in rows[0] + assert "model" in rows[0] + assert "user_id" in rows[0] + assert "department_id" in rows[0] + # Verify a known record + gpt_rows = [r for r in rows if r["model"] == "gpt-4o"] + assert len(gpt_rows) == 3 + + async def test_export_json_returns_valid_json( + self, service: UsageService, store: InMemoryUsageStore + ): + _populate_store(store) + body = await service.export_usage(store, format="json") + data = json.loads(body) + assert isinstance(data, list) + assert len(data) == 4 + assert "timestamp" in data[0] + assert "user_id" in data[0] + assert "department_id" in data[0] + + async def test_export_csv_empty_store_returns_header_only( + self, service: UsageService, store: InMemoryUsageStore + ): + body = await service.export_usage(store, format="csv") + reader = csv.DictReader(io.StringIO(body)) + rows = list(reader) + assert rows == [] + # Header should still be present. + assert "timestamp" in body + + async def test_export_with_department_filter( + self, service: UsageService, store: InMemoryUsageStore + ): + _populate_store(store) + body = await service.export_usage(store, department_id="d2", format="csv") + reader = csv.DictReader(io.StringIO(body)) + rows = list(reader) + assert len(rows) == 1 + assert rows[0]["department_id"] == "d2" + + +# --------------------------------------------------------------------------- +# Singleton helpers +# --------------------------------------------------------------------------- + + +class TestSingletonHelpers: + def test_get_usage_service_returns_singleton(self): + first = get_usage_service() + second = get_usage_service() + assert first is second + + def test_set_usage_service_overrides(self): + custom = UsageService() + set_usage_service(custom) + assert get_usage_service() is custom + # Clearing falls back to a new lazy instance. + set_usage_service(None) + new_one = get_usage_service() + assert new_one is not custom diff --git a/tests/unit/llm/test_quota_enforcement.py b/tests/unit/llm/test_quota_enforcement.py new file mode 100644 index 0000000..5769a14 --- /dev/null +++ b/tests/unit/llm/test_quota_enforcement.py @@ -0,0 +1,321 @@ +"""Unit tests for LLMGateway quota enforcement (U7). + +Covers: +- QuotaExceededError raised when token_limit exceeded +- QuotaExceededError raised when cost_limit exceeded +- QuotaExceededError raised when model not in whitelist +- No quota set → request allowed +- Multi-department: strictest-wins (one exceeds, other doesn't → rejected) +- QuotaExceededError carries the right metadata +- Usage recording still attaches user_id + department_id on success +""" + +from __future__ import annotations + +import uuid +from pathlib import Path + +import pytest + +from agentkit.llm.gateway import LLMGateway, QuotaExceededError +from agentkit.llm.protocol import ( + LLMProvider, + LLMRequest, + LLMResponse, + TokenUsage, +) +from agentkit.llm.providers.usage_store import InMemoryUsageStore +from agentkit.server.admin.quota_service import ( + get_quota_service, + set_quota_service, +) +from agentkit.server.auth.models import init_auth_db + + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + + +class FakeProvider(LLMProvider): + """A minimal LLMProvider that returns a fixed response.""" + + def __init__(self, name: str = "fake"): + self._name = name + self.last_request: LLMRequest | None = None + self.call_count = 0 + + async def chat(self, request: LLMRequest) -> LLMResponse: + self.last_request = request + self.call_count += 1 + return LLMResponse( + content=f"response from {self._name}", + model=request.model, + usage=TokenUsage(prompt_tokens=100, completion_tokens=50), + ) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def store() -> InMemoryUsageStore: + return InMemoryUsageStore() + + +@pytest.fixture +def gateway(store: InMemoryUsageStore) -> LLMGateway: + gw = LLMGateway(usage_store=store) + gw.register_provider("openai", FakeProvider("openai")) + return gw + + +@pytest.fixture +async def fresh_db(tmp_path: Path) -> Path: + db_path = tmp_path / "auth.db" + await init_auth_db(db_path) + return db_path + + +@pytest.fixture(autouse=True) +def _reset_quota_singleton(): + """Reset the QuotaService singleton before and after each test.""" + set_quota_service(None) + yield + set_quota_service(None) + + +def _random_dept_id() -> str: + return str(uuid.uuid4()) + + +# --------------------------------------------------------------------------- +# Quota enforcement +# --------------------------------------------------------------------------- + + +class TestQuotaEnforcement: + async def test_no_quota_set_allows_request(self, gateway: LLMGateway, fresh_db: Path): + """When no quota is configured, the request is allowed.""" + dept_id = _random_dept_id() + response = await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model="openai/gpt-4o", + user_id="u1", + department_ids=[dept_id], + db_path=fresh_db, + ) + assert response.content == "response from openai" + + async def test_token_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path): + """token_limit quota exceeded → QuotaExceededError.""" + dept_id = _random_dept_id() + svc = get_quota_service() + # Set a tiny token limit (1 token) — any usage will exceed it. + await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily") + + # Pre-populate the usage store so the daily total > 1. + gateway._usage_tracker.record( + agent_name="prev", + model="openai/gpt-4o", + usage=TokenUsage(prompt_tokens=100, completion_tokens=50), + cost=0.0, + latency_ms=10, + user_id="u1", + department_id=dept_id, + ) + + with pytest.raises(QuotaExceededError) as exc_info: + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model="openai/gpt-4o", + user_id="u1", + department_ids=[dept_id], + db_path=fresh_db, + ) + err = exc_info.value + assert err.department_id == dept_id + assert err.quota_type == "token_limit" + assert err.period == "daily" + assert err.limit == 1 + assert err.current == 150 # 100 prompt + 50 completion + + async def test_cost_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path): + """cost_limit quota exceeded → QuotaExceededError.""" + dept_id = _random_dept_id() + svc = get_quota_service() + # cost_limit is in cents. Set 1 cent. + await svc.set_quota(fresh_db, dept_id, "cost_limit", 1, period="daily") + + # Pre-populate usage with $1.00 cost = 100 cents, exceeding the 1-cent limit. + gateway._usage_tracker.record( + agent_name="prev", + model="openai/gpt-4o", + usage=TokenUsage(prompt_tokens=100, completion_tokens=50), + cost=1.00, # $1.00 = 100 cents + latency_ms=10, + user_id="u1", + department_id=dept_id, + ) + + with pytest.raises(QuotaExceededError) as exc_info: + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model="openai/gpt-4o", + user_id="u1", + department_ids=[dept_id], + db_path=fresh_db, + ) + err = exc_info.value + assert err.quota_type == "cost_limit" + assert err.period == "daily" + assert err.limit == 1 + # current is in cents (100 cents = $1.00). + assert err.current == 100.0 + + async def test_model_whitelist_rejection_raises(self, gateway: LLMGateway, fresh_db: Path): + """Model not in whitelist → QuotaExceededError with quota_type=model_whitelist.""" + dept_id = _random_dept_id() + svc = get_quota_service() + # Whitelist only allows "claude" — gateway is calling "gpt-4o". + await svc.set_quota(fresh_db, dept_id, "model_whitelist", ["claude"], period="daily") + + with pytest.raises(QuotaExceededError) as exc_info: + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model="openai/gpt-4o", + user_id="u1", + department_ids=[dept_id], + db_path=fresh_db, + ) + err = exc_info.value + assert err.quota_type == "model_whitelist" + assert err.department_id == dept_id + # For model_whitelist, current is the rejected model name. + assert err.current == "openai/gpt-4o" + + async def test_model_whitelist_allows_listed_model(self, gateway: LLMGateway, fresh_db: Path): + """Model in whitelist → request allowed.""" + dept_id = _random_dept_id() + svc = get_quota_service() + # Whitelist uses the full resolved model identifier (provider/model). + await svc.set_quota(fresh_db, dept_id, "model_whitelist", ["openai/gpt-4o"], period="daily") + response = await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model="openai/gpt-4o", + user_id="u1", + department_ids=[dept_id], + db_path=fresh_db, + ) + assert response.content == "response from openai" + + async def test_multi_department_strictest_wins(self, gateway: LLMGateway, fresh_db: Path): + """One department exceeds, the other doesn't → rejected (strictest wins).""" + dept_ok = _random_dept_id() + dept_bad = _random_dept_id() + svc = get_quota_service() + # dept_bad has a 1-token limit; dept_ok has a 1M-token limit. + await svc.set_quota(fresh_db, dept_bad, "token_limit", 1, period="daily") + await svc.set_quota(fresh_db, dept_ok, "token_limit", 1_000_000, period="daily") + + # Pre-populate usage for dept_bad so it exceeds. + gateway._usage_tracker.record( + agent_name="prev", + model="openai/gpt-4o", + usage=TokenUsage(prompt_tokens=100, completion_tokens=50), + cost=0.0, + latency_ms=10, + user_id="u1", + department_id=dept_bad, + ) + + with pytest.raises(QuotaExceededError) as exc_info: + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model="openai/gpt-4o", + user_id="u1", + department_ids=[dept_ok, dept_bad], + db_path=fresh_db, + ) + # The error should reference dept_bad (the one that exceeded). + assert exc_info.value.department_id == dept_bad + + async def test_quota_check_skipped_without_db_path(self, gateway: LLMGateway, fresh_db: Path): + """When db_path is None, no quota check is performed.""" + dept_id = _random_dept_id() + svc = get_quota_service() + await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily") + # Even with a quota set, calling without db_path should succeed. + response = await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model="openai/gpt-4o", + user_id="u1", + department_ids=[dept_id], + db_path=None, + ) + assert response.content == "response from openai" + + async def test_quota_check_skipped_without_department_ids( + self, gateway: LLMGateway, fresh_db: Path + ): + """When department_ids is None, no quota check is performed.""" + response = await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model="openai/gpt-4o", + user_id="u1", + department_ids=None, + db_path=fresh_db, + ) + assert response.content == "response from openai" + + async def test_usage_recorded_with_user_and_department( + self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path + ): + """After a successful call, the usage record carries user_id + department_id.""" + dept_id = _random_dept_id() + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model="openai/gpt-4o", + user_id="u1", + department_ids=[dept_id], + db_path=fresh_db, + ) + summary = store.get_usage() + assert len(summary.records) == 1 + rec = summary.records[0] + assert rec.user_id == "u1" + assert rec.department_id == dept_id + assert rec.model == "gpt-4o" + assert rec.total_tokens == 150 # 100 prompt + 50 completion + + +# --------------------------------------------------------------------------- +# QuotaExceededError dataclass-like behavior +# --------------------------------------------------------------------------- + + +class TestQuotaExceededError: + def test_error_message_includes_metadata(self): + err = QuotaExceededError( + department_id="d1", + quota_type="token_limit", + period="daily", + limit=1000, + current=1500, + ) + msg = str(err) + assert "d1" in msg + assert "token_limit" in msg + assert "daily" in msg + assert "1000" in msg + assert "1500" in msg + + def test_error_attributes_preserved(self): + err = QuotaExceededError("d1", "cost_limit", "monthly", 5000, 6000) + assert err.department_id == "d1" + assert err.quota_type == "cost_limit" + assert err.period == "monthly" + assert err.limit == 5000 + assert err.current == 6000 diff --git a/tests/unit/llm/test_usage_store.py b/tests/unit/llm/test_usage_store.py index d9ce3d4..93e59ba 100644 --- a/tests/unit/llm/test_usage_store.py +++ b/tests/unit/llm/test_usage_store.py @@ -1,6 +1,5 @@ """Unit tests for UsageStore (U4 — UsageStore Persistence).""" -import pytest from datetime import datetime, timedelta, timezone from agentkit.llm.protocol import TokenUsage @@ -114,6 +113,128 @@ class TestInMemoryUsageStore: # Should be parseable as ISO 8601 datetime.fromisoformat(rec.timestamp) + def test_record_with_user_and_department(self): + store = InMemoryUsageStore() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + rec = store.get_usage().records[0] + assert rec.user_id == "u1" + assert rec.department_id == "d1" + + def test_record_defaults_user_department_to_none(self): + store = InMemoryUsageStore() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200) + rec = store.get_usage().records[0] + assert rec.user_id is None + assert rec.department_id is None + + def test_get_usage_filters_by_user(self): + store = InMemoryUsageStore() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200, user_id="u1") + store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200, user_id="u2") + summary = store.get_usage(user_id="u1") + assert len(summary.records) == 1 + assert summary.records[0].user_id == "u1" + + def test_get_usage_filters_by_department(self): + store = InMemoryUsageStore() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200, department_id="d1") + store.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200, department_id="d2") + summary = store.get_usage(department_id="d1") + assert len(summary.records) == 1 + assert summary.records[0].department_id == "d1" + + def test_get_usage_by_user(self): + store = InMemoryUsageStore() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + store.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u2", + department_id="d2", + ) + summary = store.get_usage_by_user("u1") + assert len(summary.records) == 1 + assert summary.records[0].user_id == "u1" + assert summary.total_tokens == 150 + + def test_get_usage_by_department(self): + store = InMemoryUsageStore() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + store.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u2", + department_id="d2", + ) + summary = store.get_usage_by_department("d1") + assert len(summary.records) == 1 + assert summary.records[0].department_id == "d1" + assert summary.total_tokens == 150 + + def test_summary_includes_by_user_and_by_department(self): + store = InMemoryUsageStore() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + store.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + summary = store.get_usage() + assert "u1" in summary.by_user + assert summary.by_user["u1"]["count"] == 2 + assert summary.by_user["u1"]["total_tokens"] == 300 + assert "d1" in summary.by_department + assert summary.by_department["d1"]["count"] == 2 + # --------------------------------------------------------------------------- # UsageRecord / UsageBucket / UsageSummary dataclasses @@ -123,17 +244,25 @@ class TestInMemoryUsageStore: class TestDataclasses: def test_usage_record_auto_timestamp(self): rec = UsageRecord( - agent_name="a", model="m", - prompt_tokens=1, completion_tokens=1, - total_tokens=2, cost=0.01, latency_ms=100, + agent_name="a", + model="m", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost=0.01, + latency_ms=100, ) assert rec.timestamp != "" def test_usage_record_explicit_timestamp(self): rec = UsageRecord( - agent_name="a", model="m", - prompt_tokens=1, completion_tokens=1, - total_tokens=2, cost=0.01, latency_ms=100, + agent_name="a", + model="m", + prompt_tokens=1, + completion_tokens=1, + total_tokens=2, + cost=0.01, + latency_ms=100, timestamp="2026-01-01T00:00:00+00:00", ) assert rec.timestamp == "2026-01-01T00:00:00+00:00" @@ -208,6 +337,98 @@ class TestRedisUsageStoreMocked: assert len(key) == 10 assert key[4] == "-" + def test_v2_keys_with_user_and_department(self): + store = self._make_store() + hash_key, list_key = store._v2_keys("2026-06-21", "u1", "d1") + assert hash_key == "agentkit:usage:v2:2026-06-21:u1:d1" + assert list_key == "agentkit:usage_records:v2:2026-06-21:u1:d1" + + def test_v2_keys_with_none_user_and_department(self): + store = self._make_store() + hash_key, list_key = store._v2_keys("2026-06-21", None, None) + # None values are normalized to "none" in the key. + assert hash_key == "agentkit:usage:v2:2026-06-21:none:none" + assert list_key == "agentkit:usage_records:v2:2026-06-21:none:none" + + def test_record_degraded_with_user_and_department(self): + store = self._make_store() + store._degrade_to_fallback() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + # Should be in fallback with user/department attached. + summary = store._fallback.get_usage() + assert len(summary.records) == 1 + assert summary.records[0].user_id == "u1" + assert summary.records[0].department_id == "d1" + + def test_get_usage_degraded_with_user_filter(self): + store = self._make_store() + store._degrade_to_fallback() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store._fallback.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + store._fallback.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u2", + department_id="d2", + ) + summary = store.get_usage(user_id="u1") + assert len(summary.records) == 1 + assert summary.records[0].user_id == "u1" + + def test_get_usage_by_user_degraded(self): + store = self._make_store() + store._degrade_to_fallback() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store._fallback.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + summary = store.get_usage_by_user("u1") + assert len(summary.records) == 1 + assert summary.records[0].user_id == "u1" + + def test_get_usage_by_department_degraded(self): + store = self._make_store() + store._degrade_to_fallback() + usage = TokenUsage(prompt_tokens=100, completion_tokens=50) + store._fallback.record( + "agent1", + "gpt-4", + usage, + cost=0.05, + latency_ms=200, + user_id="u1", + department_id="d1", + ) + summary = store.get_usage_by_department("d1") + assert len(summary.records) == 1 + assert summary.records[0].department_id == "d1" + # --------------------------------------------------------------------------- # Factory