feat(admin): U5 — LLM config admin endpoints + department quotas

QuotaService: set/get/list/delete quotas, check_quota (hard reject),
is_model_allowed. JSON-serialized limit_value, upsert with ON CONFLICT.

LlmConfigService: provider CRUD + set_api_key + fallback management.
fcntl.flock file lock prevents concurrent YAML writes. Reuses
settings.py helpers (_read_yaml_config, _write_yaml_config,
_write_env_var, _mask_api_key).

11 new admin endpoints: provider CRUD, api-key, fallback CRUD,
department quotas CRUD. All guarded by _require_admin.

93 new tests (30 quota unit + 32 llm-config unit + 31 integration).
This commit is contained in:
chiguyong 2026-06-21 15:03:38 +08:00
parent ad65f7a8d7
commit 980919fc95
6 changed files with 2297 additions and 0 deletions

View File

@ -0,0 +1,437 @@
"""LlmConfigService — runtime CRUD for LLM providers/fallbacks (U5).
This module wraps the YAML read/write logic from
:mod:`agentkit.server.routes.settings` as a reusable service. Web UI
routes (``/api/v1/admin/llm/*``) and the CLI ``agentkit admin llm``
sub-app both call into :class:`LlmConfigService` rather than touching
the YAML directly, keeping the validation rules (duplicate-provider,
fallback-in-use guard, API-key masking) in one place.
The service writes back to ``agentkit.yaml`` using the existing
:func:`_write_yaml_config` helper (which preserves ``${VAR}`` env-var
references and comments via ruamel.yaml). Concurrent writes are
serialized with :func:`fcntl.flock` to prevent corruption.
The service is a module-level singleton (see
:func:`get_llm_config_service`) so tests can inject a custom instance
via :func:`set_llm_config_service`.
"""
from __future__ import annotations
import fcntl
import logging
import os
import re
from pathlib import Path
from typing import Any
import yaml
from agentkit.server.routes.settings import (
_mask_api_key,
_read_yaml_config,
_write_env_var,
_write_yaml_config,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
#: Regex matching ``${ENV_VAR}`` references in YAML api_key fields.
_ENV_REF_RE = re.compile(r"^\$\{([^}]+)\}$")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _provider_to_dict(name: str, pconf: dict[str, Any]) -> dict[str, Any]:
"""Convert a raw YAML provider dict to a masked response dict.
The ``api_key`` field is masked via :func:`_mask_api_key`. The
``models`` field is preserved as-is (it's already a dict in YAML).
"""
raw_key = pconf.get("api_key", "")
# Resolve ${VAR} references for masking — we want to show the
# masked *resolved* value, not the literal "${VAR}" string.
resolved_key = _resolve_env_var(raw_key)
return {
"name": name,
"type": pconf.get("type", "openai"),
"api_key": _mask_api_key(resolved_key),
"base_url": pconf.get("base_url", ""),
"models": pconf.get("models", {}) or {},
"max_tokens": pconf.get("max_tokens", 4096),
"timeout": pconf.get("timeout", 60.0),
}
def _resolve_env_var(value: str) -> str:
"""Resolve a ``${VAR}`` reference to its env value (if set).
If ``value`` is ``${OPENAI_API_KEY}`` and ``OPENAI_API_KEY`` is set
in the environment, returns the env value. Otherwise returns the
input unchanged.
"""
if not isinstance(value, str):
return ""
match = _ENV_REF_RE.match(value)
if match:
var_name = match.group(1).split(":-")[0]
return os.environ.get(var_name, value)
return value
def _env_var_name_for_provider(name: str) -> str:
"""Derive a sensible env var name from a provider name."""
return f"{name.upper().replace('-', '_')}_API_KEY"
# ---------------------------------------------------------------------------
# File-locking context manager
# ---------------------------------------------------------------------------
class _YamlLock:
"""Context manager that acquires an exclusive ``fcntl.flock`` on the
YAML file's sibling lockfile (``<config_path>.lock``).
The lockfile is created on demand and is not removed after release
(it's reused on subsequent writes). Holding the lock prevents
concurrent writers from corrupting the YAML file.
"""
def __init__(self, config_path: Path) -> None:
self._lock_path = config_path.with_suffix(config_path.suffix + ".lock")
self._fh = None
def __enter__(self) -> "_YamlLock":
self._fh = open(self._lock_path, "w")
fcntl.flock(self._fh.fileno(), fcntl.LOCK_EX)
return self
def __exit__(self, exc_type, exc, tb) -> None:
if self._fh is not None:
fcntl.flock(self._fh.fileno(), fcntl.LOCK_UN)
self._fh.close()
self._fh = None
# ---------------------------------------------------------------------------
# Service
# ---------------------------------------------------------------------------
class LlmConfigService:
"""CRUD for LLM providers, API keys, and fallback chains.
All methods read/write the YAML file at ``self._config_path``.
Writes are serialized via :class:`_YamlLock` (``fcntl.flock``) to
prevent concurrent corruption. The existing ``watch_config()``
mechanism detects the file change and rebuilds the
:class:`LLMGateway` no explicit notification is needed.
API keys are NEVER returned in full every response masks them via
:func:`_mask_api_key`. New plaintext keys are written to ``.env``
(next to the YAML) and the YAML stores a ``${ENV_VAR}`` reference.
"""
def __init__(self, config_path: Path) -> None:
self._config_path = Path(config_path)
# ------------------------------------------------------------------
# Provider CRUD
# ------------------------------------------------------------------
def list_providers(self) -> list[dict[str, Any]]:
"""Return all providers with masked API keys."""
data = self._read()
providers = data.get("llm", {}).get("providers", {}) or {}
return [_provider_to_dict(name, pconf) for name, pconf in providers.items()]
def get_provider(self, name: str) -> dict[str, Any] | None:
"""Return a single provider by name (masked key), or ``None``."""
data = self._read()
pconf = data.get("llm", {}).get("providers", {}).get(name)
if pconf is None:
return None
return _provider_to_dict(name, pconf)
def create_provider(
self,
name: str,
provider_type: str,
api_key: str,
base_url: str = "",
models: dict[str, Any] | None = None,
max_tokens: int = 4096,
timeout: float = 60.0,
) -> dict[str, Any]:
"""Add a new provider to the YAML file.
Args:
name: Provider name (must be unique within the YAML).
provider_type: Provider type (``openai`` / ``anthropic`` /
``gemini`` / ``doubao`` / ``wenxin`` / ``yuanbao``).
api_key: Plaintext API key. Stored in ``.env`` as
``{NAME}_API_KEY``; the YAML stores ``${...}`` ref.
base_url: Optional base URL override.
models: Optional models dict (e.g. ``{"gpt-4o": {}}``).
max_tokens: Default max tokens for completions.
timeout: Request timeout in seconds.
Returns:
The newly-created provider dict (masked key).
Raises:
ValueError: If a provider with the same name already exists.
"""
with _YamlLock(self._config_path):
data = self._read()
data.setdefault("llm", {}).setdefault("providers", {})
if name in data["llm"]["providers"]:
raise ValueError(f"Provider {name!r} already exists")
env_key = _env_var_name_for_provider(name)
_write_env_var(str(self._config_path), env_key, api_key)
pconf: dict[str, Any] = {
"type": provider_type,
"api_key": f"${{{env_key}}}",
"base_url": base_url,
"max_tokens": max_tokens,
"timeout": timeout,
}
if models is not None:
pconf["models"] = models
data["llm"]["providers"][name] = pconf
self._write(data)
created = self.get_provider(name)
assert created is not None # we just inserted it
return created
def update_provider(
self,
name: str,
provider_type: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
models: dict[str, Any] | None = None,
max_tokens: int | None = None,
timeout: float | None = None,
) -> dict[str, Any]:
"""Partially update a provider.
Only the provided fields are updated. If ``api_key`` is a masked
value (starts with ``****``), it is ignored the existing key
is preserved.
Raises:
ValueError: If the provider does not exist.
"""
with _YamlLock(self._config_path):
data = self._read()
providers = data.get("llm", {}).get("providers", {}) or {}
if name not in providers:
raise ValueError(f"Provider {name!r} not found")
pconf = providers[name]
if provider_type is not None:
pconf["type"] = provider_type
if base_url is not None:
pconf["base_url"] = base_url
if models is not None:
pconf["models"] = models
if max_tokens is not None:
pconf["max_tokens"] = max_tokens
if timeout is not None:
pconf["timeout"] = timeout
if api_key is not None and not api_key.startswith("****"):
# New plaintext key — write to .env, keep ${VAR} ref in YAML.
existing_key = str(pconf.get("api_key", ""))
env_match = _ENV_REF_RE.match(existing_key)
if env_match:
env_key = env_match.group(1).split(":-")[0]
else:
env_key = _env_var_name_for_provider(name)
_write_env_var(str(self._config_path), env_key, api_key)
pconf["api_key"] = f"${{{env_key}}}"
providers[name] = pconf
data.setdefault("llm", {})["providers"] = providers
self._write(data)
updated = self.get_provider(name)
assert updated is not None
return updated
def delete_provider(self, name: str) -> bool:
"""Remove a provider from the YAML file.
Returns:
``True`` if the provider was deleted.
Raises:
ValueError: If the provider does not exist, or if it is
referenced by any fallback chain (removing it would
break the chain).
"""
with _YamlLock(self._config_path):
data = self._read()
providers = data.get("llm", {}).get("providers", {}) or {}
if name not in providers:
raise ValueError(f"Provider {name!r} not found")
# Guard: refuse to delete if referenced in any fallback chain.
fallbacks = data.get("llm", {}).get("fallbacks", {}) or {}
for model, chain in fallbacks.items():
if isinstance(chain, list) and name in chain:
raise ValueError(
f"Provider {name!r} is used in fallback chain for model "
f"{model!r}; remove the fallback first"
)
del providers[name]
data.setdefault("llm", {})["providers"] = providers
self._write(data)
return True
# ------------------------------------------------------------------
# API key management
# ------------------------------------------------------------------
def set_api_key(self, provider_name: str, api_key: str) -> dict[str, Any]:
"""Set the API key for a provider.
The plaintext key is written to ``.env`` (as
``{NAME}_API_KEY``) and the YAML stores a ``${...}`` reference.
If the provider does not yet exist in the YAML, a stub entry is
created with ``type=openai`` and default settings.
Returns:
The updated provider dict (masked key).
"""
with _YamlLock(self._config_path):
data = self._read()
data.setdefault("llm", {}).setdefault("providers", {})
providers = data["llm"]["providers"]
pconf = providers.get(provider_name, {})
existing_key = str(pconf.get("api_key", ""))
env_match = _ENV_REF_RE.match(existing_key)
if env_match:
env_key = env_match.group(1).split(":-")[0]
else:
env_key = _env_var_name_for_provider(provider_name)
_write_env_var(str(self._config_path), env_key, api_key)
pconf["api_key"] = f"${{{env_key}}}"
if "type" not in pconf:
pconf["type"] = "openai"
providers[provider_name] = pconf
self._write(data)
updated = self.get_provider(provider_name)
assert updated is not None
return updated
# ------------------------------------------------------------------
# Fallback chain management
# ------------------------------------------------------------------
def get_fallbacks(self) -> dict[str, list[str]]:
"""Return the fallback chains (model → list of provider/model refs)."""
data = self._read()
fallbacks = data.get("llm", {}).get("fallbacks", {}) or {}
# Normalize: ensure all values are lists.
return {
str(model): list(chain) if isinstance(chain, list) else [str(chain)]
for model, chain in fallbacks.items()
}
def set_fallback(self, model: str, chain: list[str]) -> dict[str, Any]:
"""Set the fallback chain for a model.
Returns:
``{"model": model, "chain": chain}`` dict.
"""
with _YamlLock(self._config_path):
data = self._read()
data.setdefault("llm", {}).setdefault("fallbacks", {})
data["llm"]["fallbacks"][model] = list(chain)
self._write(data)
return {"model": model, "chain": list(chain)}
def delete_fallback(self, model: str) -> bool:
"""Delete the fallback chain for a model.
Returns:
``True`` if a chain was deleted, ``False`` if no chain
existed for ``model``.
"""
with _YamlLock(self._config_path):
data = self._read()
fallbacks = data.get("llm", {}).get("fallbacks", {}) or {}
if model not in fallbacks:
return False
del fallbacks[model]
data.setdefault("llm", {})["fallbacks"] = fallbacks
self._write(data)
return True
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _read(self) -> dict[str, Any]:
"""Read the YAML config (returns ``{}`` if file is missing/empty)."""
if not self._config_path.exists():
return {}
return _read_yaml_config(str(self._config_path))
def _write(self, data: dict[str, Any]) -> None:
"""Write the full config dict back to the YAML file."""
_write_yaml_config(str(self._config_path), data)
# ---------------------------------------------------------------------------
# Module-level singleton (overridable in tests via set_llm_config_service)
# ---------------------------------------------------------------------------
_llm_config_service: LlmConfigService | None = None
def get_llm_config_service(config_path: Path | str | None = None) -> LlmConfigService:
"""Return the process-wide :class:`LlmConfigService`.
On first call, ``config_path`` must be provided (or the service
will raise ``ValueError``). Subsequent calls ignore ``config_path``
and return the cached singleton.
"""
global _llm_config_service
if _llm_config_service is None:
if config_path is None:
raise ValueError("config_path is required to initialize LlmConfigService on first call")
_llm_config_service = LlmConfigService(Path(config_path))
return _llm_config_service
def set_llm_config_service(service: LlmConfigService | None) -> None:
"""Inject a custom :class:`LlmConfigService` (used by tests)."""
global _llm_config_service
_llm_config_service = service
# Re-export yaml for tests that need to construct sample config files.
__all__ = [
"LlmConfigService",
"get_llm_config_service",
"set_llm_config_service",
"yaml",
]

View File

@ -0,0 +1,349 @@
"""QuotaService — per-department LLM quota CRUD + enforcement (U5/U7).
This module is the single owner of the ``department_quotas`` table.
Web UI routes (``/api/v1/admin/departments/{id}/quotas``) and the CLI
``agentkit admin llm set-quota`` sub-app both call into
:class:`QuotaService` rather than touching the table directly, keeping
the validation rules (quota_type/period enums, JSON serialization of
``model_whitelist``) in one place.
Quota semantics
---------------
- ``token_limit`` (int): max tokens per ``period``. ``check_quota``
compares ``current_usage`` (tokens used) against the limit.
- ``cost_limit`` (int, in cents): max spend per ``period``.
``check_quota`` compares ``current_usage`` (cost in cents) against
the limit.
- ``model_whitelist`` (list[str]): allowed model names. Not a quota
check it's a model-access check (use :meth:`is_model_allowed`).
The service is a module-level singleton (see :func:`get_quota_service`)
so tests can inject a custom instance via :func:`set_quota_service`.
"""
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import aiosqlite
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
#: Valid ``quota_type`` values.
QUOTA_TYPES: frozenset[str] = frozenset({"token_limit", "cost_limit", "model_whitelist"})
#: Valid ``period`` values.
PERIODS: frozenset[str] = frozenset({"daily", "monthly"})
def _now_iso() -> str:
"""Return current UTC time as ISO 8601 string."""
return datetime.now(timezone.utc).isoformat()
def _new_id() -> str:
"""Return a new UUID4 string."""
return str(uuid.uuid4())
def _serialize_limit_value(quota_type: str, limit_value: Any) -> str:
"""Serialize ``limit_value`` to a JSON string for storage.
For ``token_limit`` / ``cost_limit``: ``limit_value`` is an int,
stored as ``json.dumps(int)`` (e.g. ``"1000"``).
For ``model_whitelist``: ``limit_value`` is a ``list[str]``, stored
as ``json.dumps(list)`` (e.g. ``'["gpt-4o", "claude"]'``).
"""
if quota_type == "model_whitelist":
if not isinstance(limit_value, list):
raise ValueError(
f"model_whitelist limit_value must be a list, got {type(limit_value).__name__}"
)
return json.dumps([str(m) for m in limit_value])
if not isinstance(limit_value, int):
raise ValueError(
f"{quota_type} limit_value must be an int, got {type(limit_value).__name__}"
)
return json.dumps(limit_value)
def _deserialize_limit_value(quota_type: str, raw: str) -> Any:
"""Deserialize ``limit_value`` from JSON string back to native type."""
try:
value = json.loads(raw)
except json.JSONDecodeError as exc:
logger.warning("Failed to deserialize limit_value %r: %s", raw, exc)
return None
if quota_type == "model_whitelist":
return list(value) if isinstance(value, list) else []
return int(value) if isinstance(value, (int, float)) else None
def _validate_quota_type(quota_type: str) -> None:
if quota_type not in QUOTA_TYPES:
raise ValueError(f"Invalid quota_type {quota_type!r}; must be one of {sorted(QUOTA_TYPES)}")
def _validate_period(period: str) -> None:
if period not in PERIODS:
raise ValueError(f"Invalid period {period!r}; must be one of {sorted(PERIODS)}")
class QuotaService:
"""CRUD + enforcement for per-department LLM quotas.
All methods are async and take ``db_path: Path`` as the first
argument (after ``self``). Each method opens its own short-lived
:class:`aiosqlite.Connection` there is no shared connection state,
which keeps the service safe to call from any async context.
"""
# ------------------------------------------------------------------
# Quota CRUD
# ------------------------------------------------------------------
async def set_quota(
self,
db_path: Path,
department_id: str,
quota_type: str,
limit_value: Any,
period: str = "daily",
) -> dict[str, Any]:
"""Upsert a quota for a department.
Args:
db_path: Path to the auth SQLite DB.
department_id: Department id.
quota_type: One of ``token_limit`` / ``cost_limit`` /
``model_whitelist``.
limit_value: ``int`` for token/cost limits, ``list[str]``
for model_whitelist.
period: ``daily`` or ``monthly`` (default ``daily``).
Returns:
The upserted quota as a dict (with deserialized
``limit_value``).
"""
_validate_quota_type(quota_type)
_validate_period(period)
serialized = _serialize_limit_value(quota_type, limit_value)
now = _now_iso()
quota_id = _new_id()
async with aiosqlite.connect(str(db_path)) as db:
# Upsert: if (department_id, quota_type, period) exists,
# update limit_value + updated_at; otherwise insert.
await db.execute(
"INSERT INTO department_quotas "
"(id, department_id, quota_type, limit_value, period, updated_at) "
"VALUES (?, ?, ?, ?, ?, ?) "
"ON CONFLICT(department_id, quota_type, period) DO UPDATE SET "
"limit_value = excluded.limit_value, "
"updated_at = excluded.updated_at",
(quota_id, department_id, quota_type, serialized, period, now),
)
await db.commit()
return {
"id": quota_id,
"department_id": department_id,
"quota_type": quota_type,
"limit_value": _deserialize_limit_value(quota_type, serialized),
"period": period,
"updated_at": now,
}
async def get_quota(
self,
db_path: Path,
department_id: str,
quota_type: str,
period: str = "daily",
) -> dict[str, Any] | None:
"""Return a single quota, or ``None`` if not set."""
_validate_quota_type(quota_type)
_validate_period(period)
async with aiosqlite.connect(str(db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM department_quotas "
"WHERE department_id = ? AND quota_type = ? AND period = ?",
(department_id, quota_type, period),
)
row = await cursor.fetchone()
if row is None:
return None
return self._row_to_dict(row)
async def list_department_quotas(
self,
db_path: Path,
department_id: str,
) -> list[dict[str, Any]]:
"""List all quotas for a department (all types/periods)."""
async with aiosqlite.connect(str(db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM department_quotas WHERE department_id = ? "
"ORDER BY quota_type ASC, period ASC",
(department_id,),
)
rows = await cursor.fetchall()
return [self._row_to_dict(row) for row in rows]
async def delete_quota(
self,
db_path: Path,
department_id: str,
quota_type: str,
period: str = "daily",
) -> bool:
"""Delete a quota. Returns ``True`` if a row was deleted."""
_validate_quota_type(quota_type)
_validate_period(period)
async with aiosqlite.connect(str(db_path)) as db:
cursor = await db.execute(
"DELETE FROM department_quotas "
"WHERE department_id = ? AND quota_type = ? AND period = ?",
(department_id, quota_type, period),
)
await db.commit()
return cursor.rowcount > 0
# ------------------------------------------------------------------
# Quota enforcement
# ------------------------------------------------------------------
async def check_quota(
self,
db_path: Path,
department_id: str,
quota_type: str,
period: str,
current_usage: int | float,
) -> tuple[bool, str]:
"""Check whether ``current_usage`` is within the configured quota.
Args:
db_path: Path to the auth SQLite DB.
department_id: Department id.
quota_type: ``token_limit`` or ``cost_limit``.
``model_whitelist`` is not a quota check use
:meth:`is_model_allowed` instead.
period: ``daily`` or ``monthly``.
current_usage: Current usage value (tokens for
``token_limit``, cents for ``cost_limit``).
Returns:
``(allowed, reason)`` tuple. ``allowed`` is ``True`` if the
usage is within the limit (or no quota is set). ``reason``
is ``"ok"`` when allowed, or a human-readable denial reason.
"""
_validate_quota_type(quota_type)
_validate_period(period)
if quota_type == "model_whitelist":
return (
False,
"model_whitelist is not a quota check; use is_model_allowed()",
)
quota = await self.get_quota(db_path, department_id, quota_type, period)
if quota is None:
return (True, "ok")
limit = quota["limit_value"]
if not isinstance(limit, int):
return (True, "ok") # malformed limit; allow
if current_usage >= limit:
return (
False,
f"{quota_type} ({period}) exceeded: usage={current_usage}, limit={limit}",
)
return (True, "ok")
async def is_model_allowed(
self,
db_path: Path,
department_id: str,
model: str,
) -> tuple[bool, str]:
"""Check whether ``model`` is in the department's whitelist.
If no ``model_whitelist`` quota is set for the department (any
period), all models are allowed. Otherwise the model must appear
in the whitelist.
Returns:
``(allowed, reason)`` tuple.
"""
async with aiosqlite.connect(str(db_path)) as db:
db.row_factory = aiosqlite.Row
cursor = await db.execute(
"SELECT * FROM department_quotas "
"WHERE department_id = ? AND quota_type = 'model_whitelist' "
"ORDER BY period ASC LIMIT 1",
(department_id,),
)
row = await cursor.fetchone()
if row is None:
return (True, "ok")
whitelist = _deserialize_limit_value("model_whitelist", row["limit_value"])
if not isinstance(whitelist, list):
return (True, "ok")
if model in whitelist:
return (True, "ok")
return (
False,
f"model {model!r} not in department {department_id!r} whitelist",
)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _row_to_dict(row: aiosqlite.Row | Any) -> dict[str, Any]:
"""Convert a ``department_quotas`` row to a JSON-safe dict."""
quota_type = row["quota_type"]
return {
"id": row["id"],
"department_id": row["department_id"],
"quota_type": quota_type,
"limit_value": _deserialize_limit_value(quota_type, row["limit_value"]),
"period": row["period"],
"updated_at": row["updated_at"],
}
# ---------------------------------------------------------------------------
# Module-level singleton (overridable in tests via set_quota_service)
# ---------------------------------------------------------------------------
_quota_service: QuotaService | None = None
def get_quota_service() -> QuotaService:
"""Return the process-wide :class:`QuotaService` (lazy singleton)."""
global _quota_service
if _quota_service is None:
_quota_service = QuotaService()
return _quota_service
def set_quota_service(service: QuotaService | None) -> None:
"""Inject a custom :class:`QuotaService` (used by tests)."""
global _quota_service
_quota_service = service

View File

@ -23,6 +23,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, ConfigDict
from agentkit.server.admin.department_service import get_department_service
from agentkit.server.admin.llm_config_service import (
LlmConfigService,
get_llm_config_service,
)
from agentkit.server.admin.quota_service import get_quota_service
from agentkit.server.admin.user_service import get_user_service
from agentkit.server.auth.dependencies import require_authenticated
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
@ -550,3 +555,302 @@ async def list_user_departments(
db_path = await _ensure_db(request)
svc = get_user_service()
return await svc.list_user_departments(db_path, user_id)
# ---------------------------------------------------------------------------
# LLM config endpoints (U5) — provider CRUD + fallback chains
# ---------------------------------------------------------------------------
def _get_llm_config_service(request: Request) -> LlmConfigService:
"""Resolve the :class:`LlmConfigService` from ``app.state`` or the
module singleton.
The service needs the YAML config path, which is read from
``app.state.server_config._config_path`` (same source as the
existing ``GET/PUT /settings/llm`` endpoints). Falls back to the
module singleton (which tests can pre-populate via
:func:`set_llm_config_service`).
"""
server_config = getattr(request.app.state, "server_config", None)
if server_config is not None and getattr(server_config, "_config_path", None):
try:
return get_llm_config_service(server_config._config_path)
except ValueError:
# Singleton was reset between calls — re-initialize with the
# resolved path.
from agentkit.server.admin.llm_config_service import set_llm_config_service
svc = LlmConfigService(Path(server_config._config_path))
set_llm_config_service(svc)
return svc
# Fall back to the existing singleton (tests inject it directly).
return get_llm_config_service()
class ProviderCreateRequest(BaseModel):
"""Body for ``POST /admin/llm/providers``."""
model_config = ConfigDict(extra="forbid")
name: str
type: str = "openai"
api_key: str
base_url: str = ""
models: dict[str, Any] | None = None
max_tokens: int = 4096
timeout: float = 60.0
class ProviderUpdateRequest(BaseModel):
"""Body for ``PATCH /admin/llm/providers/{name}``."""
model_config = ConfigDict(extra="forbid")
type: str | None = None
api_key: str | None = None
base_url: str | None = None
models: dict[str, Any] | None = None
max_tokens: int | None = None
timeout: float | None = None
class ApiKeySetRequest(BaseModel):
"""Body for ``POST /admin/llm/providers/{name}/api-key``."""
model_config = ConfigDict(extra="forbid")
api_key: str
class FallbackSetRequest(BaseModel):
"""Body for ``PUT /admin/llm/fallbacks/{model}``."""
model_config = ConfigDict(extra="forbid")
chain: list[str]
@admin_router.get("/llm/providers")
async def list_llm_providers(
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
"""List all LLM providers (API keys masked)."""
svc = _get_llm_config_service(request)
return svc.list_providers()
@admin_router.post("/llm/providers", status_code=201)
async def create_llm_provider(
payload: ProviderCreateRequest,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Create a new LLM provider.
Returns 201 with the provider dict (masked key) on success, 409 if
a provider with the same name already exists.
"""
svc = _get_llm_config_service(request)
try:
return svc.create_provider(
name=payload.name,
provider_type=payload.type,
api_key=payload.api_key,
base_url=payload.base_url,
models=payload.models,
max_tokens=payload.max_tokens,
timeout=payload.timeout,
)
except ValueError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
@admin_router.patch("/llm/providers/{name}")
async def update_llm_provider(
name: str,
payload: ProviderUpdateRequest,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Partially update an LLM provider.
Returns 200 with the updated provider dict (masked key), 404 if the
provider does not exist.
"""
svc = _get_llm_config_service(request)
try:
return svc.update_provider(
name=name,
provider_type=payload.type,
api_key=payload.api_key,
base_url=payload.base_url,
models=payload.models,
max_tokens=payload.max_tokens,
timeout=payload.timeout,
)
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
@admin_router.delete("/llm/providers/{name}")
async def delete_llm_provider(
name: str,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Delete an LLM provider.
Returns 200 ``{deleted: true}`` on success, 404 if the provider
does not exist, 400 if the provider is used in a fallback chain.
"""
svc = _get_llm_config_service(request)
try:
deleted = svc.delete_provider(name)
except ValueError as exc:
msg = str(exc)
if "not found" in msg:
raise HTTPException(status_code=404, detail=msg) from exc
raise HTTPException(status_code=400, detail=msg) from exc
if not deleted:
raise HTTPException(status_code=404, detail="Provider not found")
return {"deleted": True}
@admin_router.post("/llm/providers/{name}/api-key")
async def set_llm_provider_api_key(
name: str,
payload: ApiKeySetRequest,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Set the API key for a provider.
The plaintext key is written to ``.env`` (not the YAML file); the
YAML stores a ``${ENV_VAR}`` reference. Returns 200 with the
updated provider dict (masked key).
"""
svc = _get_llm_config_service(request)
return svc.set_api_key(name, payload.api_key)
@admin_router.get("/llm/fallbacks")
async def get_llm_fallbacks(
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, list[str]]:
"""Return all fallback chains (model → list of provider/model refs)."""
svc = _get_llm_config_service(request)
return svc.get_fallbacks()
@admin_router.put("/llm/fallbacks/{model}")
async def set_llm_fallback(
model: str,
payload: FallbackSetRequest,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Set the fallback chain for a model."""
svc = _get_llm_config_service(request)
return svc.set_fallback(model, payload.chain)
@admin_router.delete("/llm/fallbacks/{model}")
async def delete_llm_fallback(
model: str,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Delete the fallback chain for a model.
Returns 200 ``{deleted: true}`` if a chain was deleted, 404 if no
chain existed for ``model``.
"""
svc = _get_llm_config_service(request)
deleted = svc.delete_fallback(model)
if not deleted:
raise HTTPException(status_code=404, detail="Fallback chain not found")
return {"deleted": True}
# ---------------------------------------------------------------------------
# Department quota endpoints (U5) — per-department LLM quotas
# ---------------------------------------------------------------------------
class QuotaSetRequest(BaseModel):
"""Body for ``PUT /admin/departments/{id}/quotas``."""
model_config = ConfigDict(extra="forbid")
quota_type: str
limit_value: int | list[str]
period: str = "daily"
@admin_router.get("/departments/{department_id}/quotas")
async def list_department_quotas(
department_id: str,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
"""List all quotas for a department."""
db_path = await _ensure_db(request)
svc = get_quota_service()
return await svc.list_department_quotas(db_path, department_id)
@admin_router.put("/departments/{department_id}/quotas")
async def set_department_quota(
department_id: str,
payload: QuotaSetRequest,
request: Request,
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Set (upsert) a quota for a department.
Returns 200 with the upserted quota dict. Returns 400 if
``quota_type`` or ``period`` is invalid, or if ``limit_value``
doesn't match the quota type (int for token/cost limits, list for
model_whitelist).
"""
db_path = await _ensure_db(request)
svc = get_quota_service()
try:
return await svc.set_quota(
db_path,
department_id,
payload.quota_type,
payload.limit_value,
payload.period,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@admin_router.delete("/departments/{department_id}/quotas")
async def delete_department_quota(
department_id: str,
request: Request,
quota_type: str,
period: str = "daily",
admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
"""Delete a quota for a department.
Query params: ``quota_type`` (required), ``period`` (default
``daily``). Returns 200 ``{deleted: true}`` if a quota was deleted,
404 if no quota matched, 400 if ``quota_type`` or ``period`` is
invalid.
"""
db_path = await _ensure_db(request)
svc = get_quota_service()
try:
deleted = await svc.delete_quota(db_path, department_id, quota_type, period)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
if not deleted:
raise HTTPException(status_code=404, detail="Quota not found")
return {"deleted": True}

View File

@ -0,0 +1,464 @@
"""Integration tests for the LLM config admin routes (U5).
Uses FastAPI TestClient with a test app that mounts only the
``admin_router`` from ``routes.admin``. The ``_require_admin`` dependency
is overridden via ``app.dependency_overrides`` so the tests don't need
real JWTs they can simulate admin and non-admin callers directly.
The :class:`LlmConfigService` is injected via
:func:`set_llm_config_service` so the tests can point it at a temp
YAML file.
"""
from __future__ import annotations
import uuid
from pathlib import Path
from typing import Any
import pytest
import yaml
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from agentkit.server.admin.llm_config_service import (
LlmConfigService,
set_llm_config_service,
)
from agentkit.server.auth.models import init_auth_db
from agentkit.server.routes import admin as admin_routes_module
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _sample_config() -> dict[str, Any]:
"""A minimal agentkit.yaml-style config for testing."""
return {
"server": {"host": "0.0.0.0", "port": 8001},
"llm": {
"providers": {
"openai": {
"type": "openai",
"api_key": "sk-test-12345678",
"base_url": "https://api.openai.com/v1",
"models": {"gpt-4o": {}},
"max_tokens": 4096,
"timeout": 120.0,
},
},
"model_aliases": {"gpt4": "openai/gpt-4o"},
"fallbacks": {},
},
}
@pytest.fixture
def config_path(tmp_path: Path) -> Path:
"""Create a temporary agentkit.yaml config file."""
path = tmp_path / "agentkit.yaml"
with open(path, "w", encoding="utf-8") as f:
yaml.dump(_sample_config(), f, default_flow_style=False, allow_unicode=True)
return path
@pytest.fixture
async def tmp_auth_db(tmp_path: Path) -> Path:
db_path = tmp_path / "admin_llm.db"
await init_auth_db(db_path)
return db_path
@pytest.fixture(autouse=True)
def _reset_llm_config_singleton():
"""Reset the LlmConfigService singleton before and after each test."""
set_llm_config_service(None)
yield
set_llm_config_service(None)
@pytest.fixture
def admin_app(config_path: Path, tmp_auth_db: Path) -> FastAPI:
"""A minimal FastAPI app with only the admin router mounted.
The ``_require_admin`` dependency is overridden to return a fake
admin user. The :class:`LlmConfigService` singleton is pre-populated
so the routes use the temp YAML file.
"""
# Inject the LlmConfigService singleton pointing at the temp YAML.
set_llm_config_service(LlmConfigService(config_path))
app = FastAPI()
app.state.auth_db_path = str(tmp_auth_db)
app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
# Default: allow admin access.
app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user()
return app
@pytest.fixture
def admin_client(admin_app: FastAPI) -> TestClient:
return TestClient(admin_app)
def _make_admin_user() -> dict[str, Any]:
return {"user_id": "admin-1", "username": "admin", "role": "admin"}
def _raise_forbidden() -> dict[str, Any]:
"""Dependency override that simulates a non-admin (403) response."""
raise HTTPException(status_code=403, detail="Admin permission required")
# ---------------------------------------------------------------------------
# Provider CRUD
# ---------------------------------------------------------------------------
class TestListProviders:
def test_list_returns_existing_providers(self, admin_client: TestClient):
resp = admin_client.get("/api/v1/admin/llm/providers")
assert resp.status_code == 200
providers = resp.json()
assert len(providers) == 1
assert providers[0]["name"] == "openai"
def test_list_masks_api_keys(self, admin_client: TestClient):
resp = admin_client.get("/api/v1/admin/llm/providers")
providers = resp.json()
assert providers[0]["api_key"].startswith("****")
# Should not contain the plaintext key.
assert "sk-test-12345678" not in providers[0]["api_key"]
class TestCreateProvider:
def test_create_returns_201(self, admin_client: TestClient):
resp = admin_client.post(
"/api/v1/admin/llm/providers",
json={
"name": "anthropic",
"type": "anthropic",
"api_key": "sk-ant-test-abcd1234",
"base_url": "https://api.anthropic.com",
"models": {"claude-sonnet-4-20250514": {}},
"max_tokens": 8192,
"timeout": 90.0,
},
)
assert resp.status_code == 201
body = resp.json()
assert body["name"] == "anthropic"
assert body["type"] == "anthropic"
# API key should be masked.
assert body["api_key"].startswith("****")
def test_create_duplicate_returns_409(self, admin_client: TestClient):
resp = admin_client.post(
"/api/v1/admin/llm/providers",
json={
"name": "openai",
"type": "openai",
"api_key": "sk-duplicate",
},
)
assert resp.status_code == 409
def test_non_admin_returns_403(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.post(
"/api/v1/admin/llm/providers",
json={"name": "x", "type": "openai", "api_key": "sk-x"},
)
assert resp.status_code == 403
class TestUpdateProvider:
def test_update_partial_returns_200(self, admin_client: TestClient):
resp = admin_client.patch(
"/api/v1/admin/llm/providers/openai",
json={"max_tokens": 2048},
)
assert resp.status_code == 200
body = resp.json()
assert body["max_tokens"] == 2048
# Other fields should be preserved.
assert body["base_url"] == "https://api.openai.com/v1"
def test_update_unknown_returns_404(self, admin_client: TestClient):
resp = admin_client.patch(
"/api/v1/admin/llm/providers/nonexistent",
json={"max_tokens": 2048},
)
assert resp.status_code == 404
def test_update_with_masked_api_key_preserves_existing(
self, admin_client: TestClient, config_path: Path
):
resp = admin_client.patch(
"/api/v1/admin/llm/providers/openai",
json={"api_key": "****5678"},
)
assert resp.status_code == 200
# Verify the YAML file still has the original key.
with open(config_path, encoding="utf-8") as f:
saved = yaml.safe_load(f)
assert saved["llm"]["providers"]["openai"]["api_key"] == "sk-test-12345678"
class TestDeleteProvider:
def test_delete_returns_200(self, admin_client: TestClient):
resp = admin_client.delete("/api/v1/admin/llm/providers/openai")
assert resp.status_code == 200
assert resp.json() == {"deleted": True}
def test_delete_unknown_returns_404(self, admin_client: TestClient):
resp = admin_client.delete("/api/v1/admin/llm/providers/nonexistent")
assert resp.status_code == 404
def test_delete_provider_used_in_fallback_returns_400(self, admin_client: TestClient):
# Add a second provider and a fallback chain that references it.
admin_client.post(
"/api/v1/admin/llm/providers",
json={
"name": "anthropic",
"type": "anthropic",
"api_key": "sk-ant-test",
},
)
admin_client.put(
"/api/v1/admin/llm/fallbacks/gpt-4o",
json={"chain": ["anthropic"]},
)
resp = admin_client.delete("/api/v1/admin/llm/providers/anthropic")
assert resp.status_code == 400
assert "fallback" in resp.json()["detail"].lower()
# ---------------------------------------------------------------------------
# API key management
# ---------------------------------------------------------------------------
class TestSetApiKey:
def test_set_api_key_returns_200(self, admin_client: TestClient):
resp = admin_client.post(
"/api/v1/admin/llm/providers/openai/api-key",
json={"api_key": "sk-brand-new-key-1234"},
)
assert resp.status_code == 200
body = resp.json()
assert body["api_key"].startswith("****")
def test_set_api_key_writes_to_env_file(self, admin_client: TestClient, config_path: Path):
admin_client.post(
"/api/v1/admin/llm/providers/openai/api-key",
json={"api_key": "sk-brand-new-key-1234"},
)
env_path = config_path.parent / ".env"
assert env_path.exists()
with open(env_path, encoding="utf-8") as f:
content = f.read()
assert "OPENAI_API_KEY=sk-brand-new-key-1234" in content
# ---------------------------------------------------------------------------
# Fallback chain management
# ---------------------------------------------------------------------------
class TestFallbackRoutes:
def test_get_fallbacks_returns_empty_when_none(self, admin_client: TestClient):
resp = admin_client.get("/api/v1/admin/llm/fallbacks")
assert resp.status_code == 200
assert resp.json() == {}
def test_set_fallback_returns_200(self, admin_client: TestClient):
resp = admin_client.put(
"/api/v1/admin/llm/fallbacks/gpt-4o",
json={"chain": ["openai", "anthropic"]},
)
assert resp.status_code == 200
body = resp.json()
assert body["model"] == "gpt-4o"
assert body["chain"] == ["openai", "anthropic"]
def test_get_fallbacks_returns_set_chain(self, admin_client: TestClient):
admin_client.put(
"/api/v1/admin/llm/fallbacks/gpt-4o",
json={"chain": ["openai", "anthropic"]},
)
resp = admin_client.get("/api/v1/admin/llm/fallbacks")
assert resp.status_code == 200
assert resp.json() == {"gpt-4o": ["openai", "anthropic"]}
def test_delete_fallback_returns_200(self, admin_client: TestClient):
admin_client.put(
"/api/v1/admin/llm/fallbacks/gpt-4o",
json={"chain": ["openai"]},
)
resp = admin_client.delete("/api/v1/admin/llm/fallbacks/gpt-4o")
assert resp.status_code == 200
assert resp.json() == {"deleted": True}
def test_delete_unknown_fallback_returns_404(self, admin_client: TestClient):
resp = admin_client.delete("/api/v1/admin/llm/fallbacks/nonexistent")
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# Department quota endpoints
# ---------------------------------------------------------------------------
class TestDepartmentQuotaRoutes:
def test_set_quota_returns_200(self, admin_client: TestClient):
dept_id = str(uuid.uuid4())
resp = admin_client.put(
f"/api/v1/admin/departments/{dept_id}/quotas",
json={
"quota_type": "token_limit",
"limit_value": 1000,
"period": "daily",
},
)
assert resp.status_code == 200
body = resp.json()
assert body["quota_type"] == "token_limit"
assert body["limit_value"] == 1000
assert body["period"] == "daily"
def test_set_quota_invalid_type_returns_400(self, admin_client: TestClient):
dept_id = str(uuid.uuid4())
resp = admin_client.put(
f"/api/v1/admin/departments/{dept_id}/quotas",
json={"quota_type": "invalid", "limit_value": 1000},
)
assert resp.status_code == 400
def test_set_quota_model_whitelist_accepts_list(self, admin_client: TestClient):
dept_id = str(uuid.uuid4())
resp = admin_client.put(
f"/api/v1/admin/departments/{dept_id}/quotas",
json={
"quota_type": "model_whitelist",
"limit_value": ["gpt-4o", "claude"],
"period": "daily",
},
)
assert resp.status_code == 200
body = resp.json()
assert body["limit_value"] == ["gpt-4o", "claude"]
def test_list_quotas_returns_empty_for_new_department(self, admin_client: TestClient):
dept_id = str(uuid.uuid4())
resp = admin_client.get(f"/api/v1/admin/departments/{dept_id}/quotas")
assert resp.status_code == 200
assert resp.json() == []
def test_list_quotas_returns_set_quotas(self, admin_client: TestClient):
dept_id = str(uuid.uuid4())
admin_client.put(
f"/api/v1/admin/departments/{dept_id}/quotas",
json={"quota_type": "token_limit", "limit_value": 1000},
)
admin_client.put(
f"/api/v1/admin/departments/{dept_id}/quotas",
json={
"quota_type": "cost_limit",
"limit_value": 5000,
"period": "monthly",
},
)
resp = admin_client.get(f"/api/v1/admin/departments/{dept_id}/quotas")
assert resp.status_code == 200
body = resp.json()
assert len(body) == 2
types = {q["quota_type"] for q in body}
assert types == {"token_limit", "cost_limit"}
def test_delete_quota_returns_200(self, admin_client: TestClient):
dept_id = str(uuid.uuid4())
admin_client.put(
f"/api/v1/admin/departments/{dept_id}/quotas",
json={"quota_type": "token_limit", "limit_value": 1000},
)
resp = admin_client.delete(
f"/api/v1/admin/departments/{dept_id}/quotas",
params={"quota_type": "token_limit", "period": "daily"},
)
assert resp.status_code == 200
assert resp.json() == {"deleted": True}
def test_delete_unknown_quota_returns_404(self, admin_client: TestClient):
dept_id = str(uuid.uuid4())
resp = admin_client.delete(
f"/api/v1/admin/departments/{dept_id}/quotas",
params={"quota_type": "token_limit", "period": "daily"},
)
assert resp.status_code == 404
def test_delete_quota_invalid_period_returns_400(self, admin_client: TestClient):
dept_id = str(uuid.uuid4())
resp = admin_client.delete(
f"/api/v1/admin/departments/{dept_id}/quotas",
params={"quota_type": "token_limit", "period": "weekly"},
)
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# Non-admin access
# ---------------------------------------------------------------------------
class TestNonAdminAccess:
"""All LLM config endpoints must return 403 for non-admin users."""
def test_non_admin_cannot_list_providers(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.get("/api/v1/admin/llm/providers")
assert resp.status_code == 403
def test_non_admin_cannot_create_provider(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.post(
"/api/v1/admin/llm/providers",
json={"name": "x", "type": "openai", "api_key": "sk-x"},
)
assert resp.status_code == 403
def test_non_admin_cannot_delete_provider(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.delete("/api/v1/admin/llm/providers/openai")
assert resp.status_code == 403
def test_non_admin_cannot_set_api_key(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.post(
"/api/v1/admin/llm/providers/openai/api-key",
json={"api_key": "sk-x"},
)
assert resp.status_code == 403
def test_non_admin_cannot_get_fallbacks(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.get("/api/v1/admin/llm/fallbacks")
assert resp.status_code == 403
def test_non_admin_cannot_set_quota(self, admin_app: FastAPI):
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
client = TestClient(admin_app)
resp = client.put(
f"/api/v1/admin/departments/{uuid.uuid4()}/quotas",
json={"quota_type": "token_limit", "limit_value": 1000},
)
assert resp.status_code == 403

View File

@ -0,0 +1,410 @@
"""Unit tests for LlmConfigService (U5 — LLM provider/fallback CRUD).
Covers:
- Happy path: list get create update delete providers
- API key management: set_api_key writes to .env, masks in responses
- Fallback chain: get set delete
- Edge case: create duplicate provider ValueError
- Edge case: update non-existent provider ValueError
- Edge case: delete provider used in fallback ValueError
- Edge case: delete non-existent provider ValueError
- File locking: concurrent writes don't corrupt (sequential simulation)
- Singleton helpers (get/set_llm_config_service)
"""
from __future__ import annotations
import threading
from pathlib import Path
from typing import Any
import pytest
import yaml
from agentkit.server.admin.llm_config_service import (
LlmConfigService,
get_llm_config_service,
set_llm_config_service,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _sample_config() -> dict[str, Any]:
"""A minimal agentkit.yaml-style config for testing."""
return {
"server": {"host": "0.0.0.0", "port": 8001},
"llm": {
"providers": {
"openai": {
"type": "openai",
"api_key": "sk-test-12345678",
"base_url": "https://api.openai.com/v1",
"models": {"gpt-4o": {}},
"max_tokens": 4096,
"timeout": 120.0,
},
},
"model_aliases": {"gpt4": "openai/gpt-4o"},
"fallbacks": {},
},
}
@pytest.fixture
def config_path(tmp_path: Path) -> Path:
"""Create a temporary agentkit.yaml config file."""
path = tmp_path / "agentkit.yaml"
with open(path, "w", encoding="utf-8") as f:
yaml.dump(_sample_config(), f, default_flow_style=False, allow_unicode=True)
return path
@pytest.fixture
def service(config_path: Path) -> LlmConfigService:
return LlmConfigService(config_path)
@pytest.fixture(autouse=True)
def _reset_singleton():
"""Reset the module singleton before and after each test."""
set_llm_config_service(None)
yield
set_llm_config_service(None)
# ---------------------------------------------------------------------------
# Provider CRUD happy path
# ---------------------------------------------------------------------------
class TestProviderCrudHappyPath:
def test_list_returns_existing_providers(self, service: LlmConfigService):
providers = service.list_providers()
assert len(providers) == 1
assert providers[0]["name"] == "openai"
assert providers[0]["type"] == "openai"
def test_list_masks_api_keys(self, service: LlmConfigService):
providers = service.list_providers()
# API key should be masked: ****xxxx
assert providers[0]["api_key"].startswith("****")
# Should show only last 4 chars
assert providers[0]["api_key"] == "****5678"
def test_get_returns_provider_by_name(self, service: LlmConfigService):
provider = service.get_provider("openai")
assert provider is not None
assert provider["name"] == "openai"
assert provider["type"] == "openai"
assert provider["base_url"] == "https://api.openai.com/v1"
def test_get_returns_none_for_unknown(self, service: LlmConfigService):
assert service.get_provider("nonexistent") is None
def test_create_adds_provider_to_yaml(self, service: LlmConfigService, config_path: Path):
created = service.create_provider(
name="anthropic",
provider_type="anthropic",
api_key="sk-ant-test-abcd1234",
base_url="https://api.anthropic.com",
models={"claude-sonnet-4-20250514": {}},
max_tokens=8192,
timeout=90.0,
)
assert created["name"] == "anthropic"
assert created["type"] == "anthropic"
# API key should be masked in the response.
assert created["api_key"].startswith("****")
# Verify the YAML file was updated.
with open(config_path, encoding="utf-8") as f:
saved = yaml.safe_load(f)
assert "anthropic" in saved["llm"]["providers"]
# The YAML should store a ${VAR} reference, not the plaintext key.
assert saved["llm"]["providers"]["anthropic"]["api_key"] == "${ANTHROPIC_API_KEY}"
def test_create_writes_api_key_to_env_file(self, service: LlmConfigService, config_path: Path):
service.create_provider(
name="anthropic",
provider_type="anthropic",
api_key="sk-ant-test-abcd1234",
)
env_path = config_path.parent / ".env"
assert env_path.exists()
with open(env_path, encoding="utf-8") as f:
content = f.read()
assert "ANTHROPIC_API_KEY=sk-ant-test-abcd1234" in content
def test_update_partial_only_changes_provided_fields(
self, service: LlmConfigService, config_path: Path
):
updated = service.update_provider("openai", max_tokens=2048)
assert updated["max_tokens"] == 2048
# Other fields should be preserved.
assert updated["base_url"] == "https://api.openai.com/v1"
assert updated["type"] == "openai"
def test_update_with_masked_api_key_preserves_existing(
self, service: LlmConfigService, config_path: Path
):
"""When user sends back a masked key (****xxxx), the real key should
be preserved."""
service.update_provider("openai", api_key="****5678")
with open(config_path, encoding="utf-8") as f:
saved = yaml.safe_load(f)
# The original plaintext key should be preserved.
assert saved["llm"]["providers"]["openai"]["api_key"] == "sk-test-12345678"
def test_update_with_new_api_key_writes_to_env(
self, service: LlmConfigService, config_path: Path
):
service.update_provider("openai", api_key="sk-new-key-9999")
env_path = config_path.parent / ".env"
with open(env_path, encoding="utf-8") as f:
content = f.read()
assert "OPENAI_API_KEY=sk-new-key-9999" in content
# YAML should now have a ${VAR} reference.
with open(config_path, encoding="utf-8") as f:
saved = yaml.safe_load(f)
assert saved["llm"]["providers"]["openai"]["api_key"] == "${OPENAI_API_KEY}"
def test_delete_removes_provider(self, service: LlmConfigService, config_path: Path):
deleted = service.delete_provider("openai")
assert deleted is True
# Verify the YAML file was updated.
with open(config_path, encoding="utf-8") as f:
saved = yaml.safe_load(f)
assert "openai" not in saved["llm"]["providers"]
# list_providers should return empty.
assert service.list_providers() == []
# ---------------------------------------------------------------------------
# Provider CRUD edge cases
# ---------------------------------------------------------------------------
class TestProviderCrudEdgeCases:
def test_create_duplicate_raises_value_error(self, service: LlmConfigService):
with pytest.raises(ValueError, match="already exists"):
service.create_provider(
name="openai",
provider_type="openai",
api_key="sk-duplicate",
)
def test_update_nonexistent_raises_value_error(self, service: LlmConfigService):
with pytest.raises(ValueError, match="not found"):
service.update_provider("nonexistent", max_tokens=2048)
def test_delete_nonexistent_raises_value_error(self, service: LlmConfigService):
with pytest.raises(ValueError, match="not found"):
service.delete_provider("nonexistent")
def test_delete_provider_used_in_fallback_raises_value_error(self, service: LlmConfigService):
# Add a second provider and a fallback chain that references it.
service.create_provider(
name="anthropic",
provider_type="anthropic",
api_key="sk-ant-test",
)
service.set_fallback("gpt-4o", ["anthropic"])
with pytest.raises(ValueError, match="used in fallback"):
service.delete_provider("anthropic")
# ---------------------------------------------------------------------------
# API key management
# ---------------------------------------------------------------------------
class TestApiKeyManagement:
def test_set_api_key_writes_to_env_file(self, service: LlmConfigService, config_path: Path):
service.set_api_key("openai", "sk-brand-new-key-1234")
env_path = config_path.parent / ".env"
assert env_path.exists()
with open(env_path, encoding="utf-8") as f:
content = f.read()
assert "OPENAI_API_KEY=sk-brand-new-key-1234" in content
def test_set_api_key_updates_yaml_reference(self, service: LlmConfigService, config_path: Path):
service.set_api_key("openai", "sk-brand-new-key-1234")
with open(config_path, encoding="utf-8") as f:
saved = yaml.safe_load(f)
assert saved["llm"]["providers"]["openai"]["api_key"] == "${OPENAI_API_KEY}"
def test_set_api_key_returns_masked_provider(self, service: LlmConfigService):
result = service.set_api_key("openai", "sk-brand-new-key-1234")
assert result["api_key"].startswith("****")
# Should not contain the plaintext key.
assert "sk-brand-new-key-1234" not in result["api_key"]
def test_set_api_key_for_new_provider_creates_stub(
self, service: LlmConfigService, config_path: Path
):
"""Setting an API key for a provider that doesn't yet exist in the
YAML should create a stub entry."""
result = service.set_api_key("gemini", "AIza-test-key-1234")
assert result["name"] == "gemini"
assert result["type"] == "openai" # default type for stub
assert result["api_key"].startswith("****")
with open(config_path, encoding="utf-8") as f:
saved = yaml.safe_load(f)
assert "gemini" in saved["llm"]["providers"]
assert saved["llm"]["providers"]["gemini"]["api_key"] == "${GEMINI_API_KEY}"
# ---------------------------------------------------------------------------
# Fallback chain management
# ---------------------------------------------------------------------------
class TestFallbackManagement:
def test_get_fallbacks_returns_empty_when_none_configured(self, service: LlmConfigService):
fallbacks = service.get_fallbacks()
assert fallbacks == {}
def test_set_fallback_adds_chain(self, service: LlmConfigService, config_path: Path):
result = service.set_fallback("gpt-4o", ["openai", "anthropic"])
assert result["model"] == "gpt-4o"
assert result["chain"] == ["openai", "anthropic"]
# Verify it was written to YAML.
with open(config_path, encoding="utf-8") as f:
saved = yaml.safe_load(f)
assert saved["llm"]["fallbacks"]["gpt-4o"] == ["openai", "anthropic"]
def test_get_fallbacks_returns_set_chain(self, service: LlmConfigService):
service.set_fallback("gpt-4o", ["openai", "anthropic"])
fallbacks = service.get_fallbacks()
assert fallbacks == {"gpt-4o": ["openai", "anthropic"]}
def test_set_fallback_overwrites_existing(self, service: LlmConfigService):
service.set_fallback("gpt-4o", ["openai"])
service.set_fallback("gpt-4o", ["anthropic", "gemini"])
fallbacks = service.get_fallbacks()
assert fallbacks["gpt-4o"] == ["anthropic", "gemini"]
def test_delete_fallback_removes_chain(self, service: LlmConfigService):
service.set_fallback("gpt-4o", ["openai"])
deleted = service.delete_fallback("gpt-4o")
assert deleted is True
assert service.get_fallbacks() == {}
def test_delete_fallback_returns_false_for_unset(self, service: LlmConfigService):
deleted = service.delete_fallback("nonexistent")
assert deleted is False
# ---------------------------------------------------------------------------
# File locking
# ---------------------------------------------------------------------------
class TestFileLocking:
def test_concurrent_writes_do_not_corrupt(self, service: LlmConfigService, config_path: Path):
"""Simulate concurrent writes from multiple threads.
Each thread creates a different provider. The file lock should
serialize the writes so the final YAML is valid and contains
all providers.
"""
errors: list[Exception] = []
def _create_provider(name: str, api_key: str) -> None:
try:
# Each thread needs its own service instance to avoid
# sharing state, but they all point to the same file.
svc = LlmConfigService(config_path)
svc.create_provider(name=name, provider_type="openai", api_key=api_key)
except Exception as exc:
errors.append(exc)
threads = [
threading.Thread(
target=_create_provider,
args=(f"provider_{i}", f"sk-key-{i:04d}"),
)
for i in range(5)
]
for t in threads:
t.start()
for t in threads:
t.join()
# No errors should have occurred.
assert errors == [], f"Concurrent writes failed: {errors}"
# The YAML file should be valid and contain all 5 new providers
# plus the original "openai" provider.
with open(config_path, encoding="utf-8") as f:
saved = yaml.safe_load(f)
providers = saved["llm"]["providers"]
for i in range(5):
assert f"provider_{i}" in providers
assert "openai" in providers
def test_lockfile_is_created(self, service: LlmConfigService, config_path: Path):
"""The lockfile should be created on the first write."""
service.create_provider(
name="anthropic",
provider_type="anthropic",
api_key="sk-ant-test",
)
lockfile = config_path.with_suffix(config_path.suffix + ".lock")
assert lockfile.exists()
# ---------------------------------------------------------------------------
# Singleton helpers
# ---------------------------------------------------------------------------
class TestSingletonHelpers:
def test_get_llm_config_service_requires_path_on_first_call(self):
with pytest.raises(ValueError, match="config_path is required"):
get_llm_config_service()
def test_get_llm_config_service_returns_singleton(self, config_path: Path):
svc1 = get_llm_config_service(config_path)
svc2 = get_llm_config_service() # subsequent call ignores path
assert svc1 is svc2
def test_set_llm_config_service_overrides_singleton(self, config_path: Path):
custom = LlmConfigService(config_path)
set_llm_config_service(custom)
assert get_llm_config_service() is custom
def test_set_llm_config_service_none_resets_singleton(self, config_path: Path):
svc = get_llm_config_service(config_path)
set_llm_config_service(None)
# Next call should raise (no path provided after reset).
with pytest.raises(ValueError, match="config_path is required"):
get_llm_config_service()
# The previously-created instance should still be usable.
assert svc.list_providers() is not None
# ---------------------------------------------------------------------------
# Env var resolution
# ---------------------------------------------------------------------------
class TestEnvVarResolution:
def test_list_providers_masks_env_var_referenced_keys(
self, config_path: Path, monkeypatch: pytest.MonkeyPatch
):
"""When the YAML stores ``${VAR}``, the masked response should
show the masked *resolved* env value, not the literal ``${VAR}``."""
# Rewrite the config to use a ${VAR} reference.
with open(config_path, "w", encoding="utf-8") as f:
cfg = _sample_config()
cfg["llm"]["providers"]["openai"]["api_key"] = "${TEST_OPENAI_KEY}"
yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True)
# Set the env var so it can be resolved.
monkeypatch.setenv("TEST_OPENAI_KEY", "sk-resolved-1234")
svc = LlmConfigService(config_path)
providers = svc.list_providers()
assert providers[0]["api_key"] == "****1234"

View File

@ -0,0 +1,333 @@
"""Unit tests for QuotaService (U5 — per-department LLM quotas).
Covers:
- Happy path: set get list delete quotas
- Quota check: token_limit (allowed / denied)
- Quota check: cost_limit (allowed / denied)
- Quota check: no quota set always allowed
- model_whitelist serialization (list JSON list)
- Upsert: set same quota twice updates the value
- Validation: invalid quota_type / period raises ValueError
- is_model_allowed: whitelist enforcement
- Singleton helpers (get/set_quota_service)
"""
from __future__ import annotations
import uuid
from pathlib import Path
import pytest
from agentkit.server.admin.quota_service import (
QuotaService,
get_quota_service,
set_quota_service,
)
from agentkit.server.auth.models import init_auth_db
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
"""A brand-new auth DB on a fresh path (no data)."""
db_path = tmp_path / "auth.db"
await init_auth_db(db_path)
return db_path
@pytest.fixture
def service() -> QuotaService:
return QuotaService()
def _random_dept_id() -> str:
return str(uuid.uuid4())
# ---------------------------------------------------------------------------
# Quota CRUD happy path
# ---------------------------------------------------------------------------
class TestQuotaCrudHappyPath:
async def test_set_token_limit_returns_quota_dict(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
quota = await service.set_quota(fresh_db, dept_id, "token_limit", 1000, period="daily")
assert quota["department_id"] == dept_id
assert quota["quota_type"] == "token_limit"
assert quota["limit_value"] == 1000
assert quota["period"] == "daily"
assert "id" in quota
assert "updated_at" in quota
async def test_set_cost_limit_returns_quota_dict(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
quota = await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
assert quota["quota_type"] == "cost_limit"
assert quota["limit_value"] == 5000
assert quota["period"] == "monthly"
async def test_set_model_whitelist_serializes_list(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
whitelist = ["gpt-4o", "claude-sonnet-4-20250514", "gemini-pro"]
quota = await service.set_quota(
fresh_db, dept_id, "model_whitelist", whitelist, period="daily"
)
assert quota["quota_type"] == "model_whitelist"
# The deserialized value should match the input list.
assert quota["limit_value"] == whitelist
async def test_get_returns_set_quota(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
fetched = await service.get_quota(fresh_db, dept_id, "token_limit")
assert fetched is not None
assert fetched["limit_value"] == 1000
assert fetched["quota_type"] == "token_limit"
async def test_get_returns_none_for_unset_quota(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
fetched = await service.get_quota(fresh_db, dept_id, "token_limit")
assert fetched is None
async def test_list_returns_all_quotas_for_department(
self, service: QuotaService, fresh_db: Path
):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "token_limit", 1000, period="daily")
await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o"], period="daily")
quotas = await service.list_department_quotas(fresh_db, dept_id)
assert len(quotas) == 3
types = {q["quota_type"] for q in quotas}
assert types == {"token_limit", "cost_limit", "model_whitelist"}
async def test_list_returns_empty_for_department_with_no_quotas(
self, service: QuotaService, fresh_db: Path
):
dept_id = _random_dept_id()
quotas = await service.list_department_quotas(fresh_db, dept_id)
assert quotas == []
async def test_delete_removes_quota(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
deleted = await service.delete_quota(fresh_db, dept_id, "token_limit")
assert deleted is True
# Confirm it's gone.
assert await service.get_quota(fresh_db, dept_id, "token_limit") is None
async def test_delete_returns_false_for_unset_quota(
self, service: QuotaService, fresh_db: Path
):
dept_id = _random_dept_id()
deleted = await service.delete_quota(fresh_db, dept_id, "token_limit")
assert deleted is False
# ---------------------------------------------------------------------------
# Upsert behavior
# ---------------------------------------------------------------------------
class TestQuotaUpsert:
async def test_set_same_quota_twice_updates_value(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
updated = await service.set_quota(fresh_db, dept_id, "token_limit", 2000)
assert updated["limit_value"] == 2000
# Only one row should exist.
quotas = await service.list_department_quotas(fresh_db, dept_id)
assert len(quotas) == 1
assert quotas[0]["limit_value"] == 2000
async def test_set_same_quota_with_different_period_creates_new_row(
self, service: QuotaService, fresh_db: Path
):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "token_limit", 1000, period="daily")
await service.set_quota(fresh_db, dept_id, "token_limit", 30000, period="monthly")
quotas = await service.list_department_quotas(fresh_db, dept_id)
assert len(quotas) == 2
periods = {q["period"] for q in quotas}
assert periods == {"daily", "monthly"}
async def test_upsert_model_whitelist_replaces_list(
self, service: QuotaService, fresh_db: Path
):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o", "claude"])
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o-mini"])
fetched = await service.get_quota(fresh_db, dept_id, "model_whitelist")
assert fetched is not None
assert fetched["limit_value"] == ["gpt-4o-mini"]
# ---------------------------------------------------------------------------
# Quota enforcement
# ---------------------------------------------------------------------------
class TestQuotaCheck:
async def test_token_limit_allowed_when_under_limit(
self, service: QuotaService, fresh_db: Path
):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
allowed, reason = await service.check_quota(
fresh_db, dept_id, "token_limit", "daily", current_usage=500
)
assert allowed is True
assert reason == "ok"
async def test_token_limit_denied_at_limit(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
allowed, reason = await service.check_quota(
fresh_db, dept_id, "token_limit", "daily", current_usage=1000
)
assert allowed is False
assert "exceeded" in reason
async def test_token_limit_denied_over_limit(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
allowed, reason = await service.check_quota(
fresh_db, dept_id, "token_limit", "daily", current_usage=1500
)
assert allowed is False
assert "exceeded" in reason
async def test_cost_limit_allowed_when_under_limit(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
allowed, reason = await service.check_quota(
fresh_db, dept_id, "cost_limit", "monthly", current_usage=2500
)
assert allowed is True
assert reason == "ok"
async def test_cost_limit_denied_at_limit(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
allowed, _ = await service.check_quota(
fresh_db, dept_id, "cost_limit", "monthly", current_usage=5000
)
assert allowed is False
async def test_no_quota_set_always_allowed(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
allowed, reason = await service.check_quota(
fresh_db, dept_id, "token_limit", "daily", current_usage=999999
)
assert allowed is True
assert reason == "ok"
async def test_model_whitelist_not_a_quota_check(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o"], period="daily")
allowed, reason = await service.check_quota(
fresh_db, dept_id, "model_whitelist", "daily", current_usage=0
)
assert allowed is False
assert "is_model_allowed" in reason
# ---------------------------------------------------------------------------
# Model whitelist enforcement
# ---------------------------------------------------------------------------
class TestModelWhitelist:
async def test_is_model_allowed_when_in_whitelist(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o", "claude"])
allowed, reason = await service.is_model_allowed(fresh_db, dept_id, "gpt-4o")
assert allowed is True
assert reason == "ok"
async def test_is_model_denied_when_not_in_whitelist(
self, service: QuotaService, fresh_db: Path
):
dept_id = _random_dept_id()
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o", "claude"])
allowed, reason = await service.is_model_allowed(fresh_db, dept_id, "gemini-pro")
assert allowed is False
assert "not in" in reason
async def test_is_model_allowed_when_no_whitelist(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
# No whitelist set — all models allowed.
allowed, reason = await service.is_model_allowed(fresh_db, dept_id, "any-model")
assert allowed is True
assert reason == "ok"
# ---------------------------------------------------------------------------
# Validation
# ---------------------------------------------------------------------------
class TestQuotaValidation:
async def test_set_invalid_quota_type_raises(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
with pytest.raises(ValueError, match="Invalid quota_type"):
await service.set_quota(fresh_db, dept_id, "invalid_type", 1000)
async def test_set_invalid_period_raises(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
with pytest.raises(ValueError, match="Invalid period"):
await service.set_quota(fresh_db, dept_id, "token_limit", 1000, period="weekly")
async def test_set_token_limit_with_list_raises(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
with pytest.raises(ValueError, match="must be an int"):
await service.set_quota(fresh_db, dept_id, "token_limit", ["gpt-4o"])
async def test_set_model_whitelist_with_int_raises(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
with pytest.raises(ValueError, match="must be a list"):
await service.set_quota(fresh_db, dept_id, "model_whitelist", 1000)
async def test_get_invalid_quota_type_raises(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
with pytest.raises(ValueError, match="Invalid quota_type"):
await service.get_quota(fresh_db, dept_id, "invalid_type")
async def test_delete_invalid_period_raises(self, service: QuotaService, fresh_db: Path):
dept_id = _random_dept_id()
with pytest.raises(ValueError, match="Invalid period"):
await service.delete_quota(fresh_db, dept_id, "token_limit", period="weekly")
async def test_check_quota_invalid_quota_type_raises(
self, service: QuotaService, fresh_db: Path
):
dept_id = _random_dept_id()
with pytest.raises(ValueError, match="Invalid quota_type"):
await service.check_quota(fresh_db, dept_id, "invalid_type", "daily", current_usage=0)
# ---------------------------------------------------------------------------
# Singleton helpers
# ---------------------------------------------------------------------------
class TestSingletonHelpers:
def test_get_quota_service_returns_singleton(self):
# Save the original singleton so we don't disturb other tests.
original = get_quota_service()
try:
custom = QuotaService()
set_quota_service(custom)
assert get_quota_service() is custom
# Clearing falls back to a new lazy instance.
set_quota_service(None)
new_one = get_quota_service()
assert new_one is not custom
finally:
set_quota_service(original)