feat(admin): U5 — LLM config admin endpoints + department quotas
QuotaService: set/get/list/delete quotas, check_quota (hard reject), is_model_allowed. JSON-serialized limit_value, upsert with ON CONFLICT. LlmConfigService: provider CRUD + set_api_key + fallback management. fcntl.flock file lock prevents concurrent YAML writes. Reuses settings.py helpers (_read_yaml_config, _write_yaml_config, _write_env_var, _mask_api_key). 11 new admin endpoints: provider CRUD, api-key, fallback CRUD, department quotas CRUD. All guarded by _require_admin. 93 new tests (30 quota unit + 32 llm-config unit + 31 integration).
This commit is contained in:
parent
ad65f7a8d7
commit
980919fc95
|
|
@ -0,0 +1,437 @@
|
||||||
|
"""LlmConfigService — runtime CRUD for LLM providers/fallbacks (U5).
|
||||||
|
|
||||||
|
This module wraps the YAML read/write logic from
|
||||||
|
:mod:`agentkit.server.routes.settings` as a reusable service. Web UI
|
||||||
|
routes (``/api/v1/admin/llm/*``) and the CLI ``agentkit admin llm``
|
||||||
|
sub-app both call into :class:`LlmConfigService` rather than touching
|
||||||
|
the YAML directly, keeping the validation rules (duplicate-provider,
|
||||||
|
fallback-in-use guard, API-key masking) in one place.
|
||||||
|
|
||||||
|
The service writes back to ``agentkit.yaml`` using the existing
|
||||||
|
:func:`_write_yaml_config` helper (which preserves ``${VAR}`` env-var
|
||||||
|
references and comments via ruamel.yaml). Concurrent writes are
|
||||||
|
serialized with :func:`fcntl.flock` to prevent corruption.
|
||||||
|
|
||||||
|
The service is a module-level singleton (see
|
||||||
|
:func:`get_llm_config_service`) so tests can inject a custom instance
|
||||||
|
via :func:`set_llm_config_service`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import fcntl
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.server.routes.settings import (
|
||||||
|
_mask_api_key,
|
||||||
|
_read_yaml_config,
|
||||||
|
_write_env_var,
|
||||||
|
_write_yaml_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Constants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#: Regex matching ``${ENV_VAR}`` references in YAML api_key fields.
|
||||||
|
# Anchored at both ends, so only a value that consists entirely of a single
# ``${...}`` reference matches. Group 1 is everything between the braces; it
# may carry a ``:-default`` suffix, which the users of this pattern strip off
# with ``split(":-")[0]`` to obtain the bare variable name.
_ENV_REF_RE = re.compile(r"^\$\{([^}]+)\}$")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_to_dict(name: str, pconf: dict[str, Any]) -> dict[str, Any]:
    """Build the response view of one provider entry from its raw YAML dict.

    The stored ``api_key`` is first passed through :func:`_resolve_env_var`
    (so a ``${VAR}`` reference is replaced by the variable's value when that
    variable is set) and the result is then masked with
    :func:`_mask_api_key`; the unmasked key is never placed in the result.
    Settings that are absent from ``pconf`` fall back to the defaults
    spelled out below.

    Args:
        name: Provider name, copied into the ``"name"`` field.
        pconf: Raw provider mapping as read from the YAML file.

    Returns:
        A dict with the keys ``name``, ``type``, ``api_key`` (masked),
        ``base_url``, ``models``, ``max_tokens`` and ``timeout``.
    """
    masked_key = _mask_api_key(_resolve_env_var(pconf.get("api_key", "")))

    view: dict[str, Any] = {"name": name}
    view["type"] = pconf.get("type", "openai")
    view["api_key"] = masked_key
    view["base_url"] = pconf.get("base_url", "")
    # A present-but-empty ``models`` entry (None, {}) is normalised to {}.
    view["models"] = pconf.get("models", {}) or {}
    view["max_tokens"] = pconf.get("max_tokens", 4096)
    view["timeout"] = pconf.get("timeout", 60.0)
    return view
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_env_var(value: str) -> str:
    """Replace a ``${VAR}`` reference with the value of ``VAR``, if set.

    Only a string that is, in its entirety, one ``${...}`` reference is
    treated as a reference. A ``:-default`` suffix inside the braces is
    stripped to find the variable name; the default itself is not applied.

    Args:
        value: The raw configuration value.

    Returns:
        ``""`` when ``value`` is not a string; the environment value when
        ``value`` is a reference to a variable that is set; otherwise
        ``value`` unchanged.
    """
    if not isinstance(value, str):
        return ""

    reference = _ENV_REF_RE.match(value)
    if reference is None:
        return value

    env_name = reference.group(1).partition(":-")[0]
    # Fall back to the literal reference text when the variable is unset.
    return os.environ.get(env_name, value)
|
||||||
|
|
||||||
|
|
||||||
|
def _env_var_name_for_provider(name: str) -> str:
|
||||||
|
"""Derive a sensible env var name from a provider name."""
|
||||||
|
return f"{name.upper().replace('-', '_')}_API_KEY"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# File-locking context manager
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _YamlLock:
|
||||||
|
"""Context manager that acquires an exclusive ``fcntl.flock`` on the
|
||||||
|
YAML file's sibling lockfile (``<config_path>.lock``).
|
||||||
|
|
||||||
|
The lockfile is created on demand and is not removed after release
|
||||||
|
(it's reused on subsequent writes). Holding the lock prevents
|
||||||
|
concurrent writers from corrupting the YAML file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: Path) -> None:
|
||||||
|
self._lock_path = config_path.with_suffix(config_path.suffix + ".lock")
|
||||||
|
self._fh = None
|
||||||
|
|
||||||
|
def __enter__(self) -> "_YamlLock":
|
||||||
|
self._fh = open(self._lock_path, "w")
|
||||||
|
fcntl.flock(self._fh.fileno(), fcntl.LOCK_EX)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb) -> None:
|
||||||
|
if self._fh is not None:
|
||||||
|
fcntl.flock(self._fh.fileno(), fcntl.LOCK_UN)
|
||||||
|
self._fh.close()
|
||||||
|
self._fh = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Service
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class LlmConfigService:
    """CRUD for LLM providers, API keys, and fallback chains.

    All methods read/write the YAML file at ``self._config_path``.
    Every mutating method holds :class:`_YamlLock` (``fcntl.flock``) for
    its whole read-modify-write cycle so concurrent writers cannot
    interleave. Read-only methods do not take the lock. No explicit
    notification is sent after a write; per the original design notes the
    existing ``watch_config()`` mechanism is expected to pick up the file
    change and rebuild the :class:`LLMGateway`.

    API keys are NEVER returned in full: every response masks them via
    :func:`_mask_api_key`. New plaintext keys are written to ``.env``
    through :func:`_write_env_var` and the YAML stores a ``${ENV_VAR}``
    reference.

    The service tolerates ``llm``, ``providers``, ``fallbacks`` or a
    single provider entry being present but null in the YAML; they are
    treated as empty mappings.
    """

    def __init__(self, config_path: Path) -> None:
        """Bind the service to one YAML config file.

        Args:
            config_path: Location of the YAML file (it need not exist yet).
        """
        self._config_path = Path(config_path)

    # ------------------------------------------------------------------
    # Provider CRUD
    # ------------------------------------------------------------------

    def list_providers(self) -> list[dict[str, Any]]:
        """Return all providers with masked API keys."""
        providers = self._section(self._read(), "providers")
        return [_provider_to_dict(name, pconf or {}) for name, pconf in providers.items()]

    def get_provider(self, name: str) -> dict[str, Any] | None:
        """Return a single provider by name (masked key), or ``None``.

        A provider whose YAML entry is null is reported with default
        settings, consistent with :meth:`list_providers`.
        """
        providers = self._section(self._read(), "providers")
        if name not in providers:
            return None
        return _provider_to_dict(name, providers[name] or {})

    def create_provider(
        self,
        name: str,
        provider_type: str,
        api_key: str,
        base_url: str = "",
        models: dict[str, Any] | None = None,
        max_tokens: int = 4096,
        timeout: float = 60.0,
    ) -> dict[str, Any]:
        """Add a new provider to the YAML file.

        Args:
            name: Provider name (must be unique within the YAML).
            provider_type: Provider type (``openai`` / ``anthropic`` /
                ``gemini`` / ``doubao`` / ``wenxin`` / ``yuanbao``).
                The value is stored as given; it is not validated here.
            api_key: Plaintext API key. Written to ``.env`` under the
                name from :func:`_env_var_name_for_provider`; the YAML
                stores a ``${...}`` reference to it.
            base_url: Optional base URL override.
            models: Optional models dict (e.g. ``{"gpt-4o": {}}``). The
                ``models`` key is omitted from the YAML when ``None``.
            max_tokens: Default max tokens for completions.
            timeout: Request timeout in seconds.

        Returns:
            The newly-created provider dict (masked key).

        Raises:
            ValueError: If a provider with the same name already exists.
        """
        with _YamlLock(self._config_path):
            data = self._read()
            providers = self._ensure_section(data, "providers")
            if name in providers:
                raise ValueError(f"Provider {name!r} already exists")

            pconf: dict[str, Any] = {
                "type": provider_type,
                "api_key": self._store_api_key(name, {}, api_key),
                "base_url": base_url,
                "max_tokens": max_tokens,
                "timeout": timeout,
            }
            if models is not None:
                pconf["models"] = models
            providers[name] = pconf
            self._write(data)

            return self._reload_provider(name)

    def update_provider(
        self,
        name: str,
        provider_type: str | None = None,
        api_key: str | None = None,
        base_url: str | None = None,
        models: dict[str, Any] | None = None,
        max_tokens: int | None = None,
        timeout: float | None = None,
    ) -> dict[str, Any]:
        """Partially update a provider.

        Only arguments that are not ``None`` are applied. If ``api_key``
        starts with ``****`` it is taken to be a masked value echoed back
        by a client and is ignored, so the existing key is preserved.

        Returns:
            The updated provider dict (masked key).

        Raises:
            ValueError: If the provider does not exist.
        """
        with _YamlLock(self._config_path):
            data = self._read()
            providers = self._section(data, "providers")
            if name not in providers:
                raise ValueError(f"Provider {name!r} not found")

            # ``providers`` is non-empty here, so it is the mapping that is
            # attached to ``data`` and in-place edits are written back.
            pconf = providers[name] or {}
            plain_updates = {
                "type": provider_type,
                "base_url": base_url,
                "models": models,
                "max_tokens": max_tokens,
                "timeout": timeout,
            }
            for field, new_value in plain_updates.items():
                if new_value is not None:
                    pconf[field] = new_value
            if api_key is not None and not api_key.startswith("****"):
                # New plaintext key: write to .env, keep ${VAR} ref in YAML.
                pconf["api_key"] = self._store_api_key(name, pconf, api_key)

            providers[name] = pconf
            self._write(data)

            return self._reload_provider(name)

    def delete_provider(self, name: str) -> bool:
        """Remove a provider from the YAML file.

        Returns:
            ``True`` if the provider was deleted.

        Raises:
            ValueError: If the provider does not exist, or if it is
                referenced by any fallback chain (removing it would
                break the chain).
        """
        with _YamlLock(self._config_path):
            data = self._read()
            providers = self._section(data, "providers")
            if name not in providers:
                raise ValueError(f"Provider {name!r} not found")

            # Guard: refuse to delete if referenced in any fallback chain.
            # Scalar chains are normalised to one-element lists, exactly as
            # get_fallbacks() reports them.
            for model, chain in self._section(data, "fallbacks").items():
                entries = chain if isinstance(chain, list) else [chain]
                if any(self._references_provider(entry, name) for entry in entries):
                    raise ValueError(
                        f"Provider {name!r} is used in fallback chain for model "
                        f"{model!r}; remove the fallback first"
                    )

            del providers[name]
            self._write(data)
            return True

    # ------------------------------------------------------------------
    # API key management
    # ------------------------------------------------------------------

    def set_api_key(self, provider_name: str, api_key: str) -> dict[str, Any]:
        """Set the API key for a provider.

        The plaintext key is written to ``.env`` and the YAML stores a
        ``${...}`` reference. If the provider already references an env
        var, that variable is reused; otherwise the name comes from
        :func:`_env_var_name_for_provider`. If the provider does not yet
        exist in the YAML, a stub entry is created with ``type=openai``.

        Returns:
            The updated provider dict (masked key).
        """
        with _YamlLock(self._config_path):
            data = self._read()
            providers = self._ensure_section(data, "providers")
            pconf = providers.get(provider_name) or {}
            pconf["api_key"] = self._store_api_key(provider_name, pconf, api_key)
            if "type" not in pconf:
                pconf["type"] = "openai"
            providers[provider_name] = pconf
            self._write(data)

            return self._reload_provider(provider_name)

    # ------------------------------------------------------------------
    # Fallback chain management
    # ------------------------------------------------------------------

    def get_fallbacks(self) -> dict[str, list[str]]:
        """Return the fallback chains (model -> list of provider/model refs).

        A chain stored as a scalar is returned as a one-element list.
        """
        fallbacks = self._section(self._read(), "fallbacks")
        return {
            str(model): list(chain) if isinstance(chain, list) else [str(chain)]
            for model, chain in fallbacks.items()
        }

    def set_fallback(self, model: str, chain: list[str]) -> dict[str, Any]:
        """Set (create or replace) the fallback chain for a model.

        Returns:
            ``{"model": model, "chain": chain}`` dict.
        """
        with _YamlLock(self._config_path):
            data = self._read()
            self._ensure_section(data, "fallbacks")[model] = list(chain)
            self._write(data)
            return {"model": model, "chain": list(chain)}

    def delete_fallback(self, model: str) -> bool:
        """Delete the fallback chain for a model.

        Returns:
            ``True`` if a chain was deleted, ``False`` if no chain
            existed for ``model`` (the file is not rewritten then).
        """
        with _YamlLock(self._config_path):
            data = self._read()
            fallbacks = self._section(data, "fallbacks")
            if model not in fallbacks:
                return False
            # Non-empty, hence the mapping attached to ``data``.
            del fallbacks[model]
            self._write(data)
            return True

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _section(data: dict[str, Any], key: str) -> dict[str, Any]:
        """Return ``data["llm"][key]`` for reading.

        A missing or null ``llm`` / ``key`` yields a detached empty dict.
        When the returned mapping is non-empty it is the object stored in
        ``data``, so in-place edits to it are reflected in ``data``.
        """
        llm = data.get("llm") or {}
        return llm.get(key) or {}

    @staticmethod
    def _ensure_section(data: dict[str, Any], key: str) -> dict[str, Any]:
        """Return ``data["llm"][key]``, creating it in ``data`` if needed.

        Missing or null ``llm`` / ``key`` entries are replaced by new
        empty dicts; existing mappings are returned as-is so they can be
        edited in place.
        """
        llm = data.get("llm")
        if llm is None:
            llm = data["llm"] = {}
        section = llm.get(key)
        if section is None:
            section = llm[key] = {}
        return section

    @staticmethod
    def _references_provider(entry: Any, name: str) -> bool:
        """Tell whether one fallback-chain entry refers to provider ``name``.

        get_fallbacks() describes entries as "provider/model refs", so both
        the bare provider name and the ``name/<model>`` form are matched.
        """
        ref = str(entry)
        return ref == name or ref.startswith(f"{name}/")

    def _store_api_key(self, provider_name: str, pconf: dict[str, Any], api_key: str) -> str:
        """Write ``api_key`` to ``.env`` and return the ``${VAR}`` reference.

        If ``pconf`` already stores a ``${VAR}`` (or ``${VAR:-default}``)
        reference, ``VAR`` is reused; otherwise the variable name is
        derived from ``provider_name``.
        """
        env_match = _ENV_REF_RE.match(str(pconf.get("api_key", "")))
        if env_match:
            env_key = env_match.group(1).split(":-")[0]
        else:
            env_key = _env_var_name_for_provider(provider_name)
        _write_env_var(str(self._config_path), env_key, api_key)
        return f"${{{env_key}}}"

    def _reload_provider(self, name: str) -> dict[str, Any]:
        """Re-read provider ``name`` from disk after a write.

        Raises:
            RuntimeError: If the provider is unexpectedly absent. This is
                raised explicitly rather than asserted so the check is
                not stripped under ``python -O``.
        """
        provider = self.get_provider(name)
        if provider is None:
            raise RuntimeError(f"Provider {name!r} missing from config after write")
        return provider

    def _read(self) -> dict[str, Any]:
        """Read the YAML config (returns ``{}`` if file is missing/empty)."""
        if not self._config_path.exists():
            return {}
        # ``or {}`` guarantees the documented ``{}`` even if the reader
        # hands back None for an empty file.
        return _read_yaml_config(str(self._config_path)) or {}

    def _write(self, data: dict[str, Any]) -> None:
        """Write the full config dict back to the YAML file."""
        _write_yaml_config(str(self._config_path), data)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level singleton (overridable in tests via set_llm_config_service)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
_llm_config_service: LlmConfigService | None = None


def get_llm_config_service(config_path: Path | str | None = None) -> LlmConfigService:
    """Return the process-wide :class:`LlmConfigService`.

    The instance is created lazily on the first call, which therefore
    needs ``config_path``. Once an instance is cached, ``config_path``
    is ignored and the cached instance is returned.

    Args:
        config_path: YAML config location; only used when no instance is
            cached yet.

    Raises:
        ValueError: If no instance is cached and ``config_path`` is
            ``None``.
    """
    global _llm_config_service
    if _llm_config_service is not None:
        return _llm_config_service
    if config_path is None:
        raise ValueError("config_path is required to initialize LlmConfigService on first call")
    _llm_config_service = LlmConfigService(Path(config_path))
    return _llm_config_service


def set_llm_config_service(service: LlmConfigService | None) -> None:
    """Replace the cached :class:`LlmConfigService` (used by tests).

    Passing ``None`` clears the cache, so the next
    :func:`get_llm_config_service` call builds a new instance.
    """
    global _llm_config_service
    _llm_config_service = service
|
||||||
|
|
||||||
|
|
||||||
|
# Re-export yaml for tests that need to construct sample config files.
|
||||||
|
# Names exported by ``from <this module> import *``. ``yaml`` is listed so
# the PyYAML module imported at the top of this file is re-exported too.
__all__ = ["LlmConfigService", "get_llm_config_service", "set_llm_config_service", "yaml"]
|
||||||
|
|
@ -0,0 +1,349 @@
|
||||||
|
"""QuotaService — per-department LLM quota CRUD + enforcement (U5/U7).
|
||||||
|
|
||||||
|
This module is the single owner of the ``department_quotas`` table.
|
||||||
|
Web UI routes (``/api/v1/admin/departments/{id}/quotas``) and the CLI
|
||||||
|
``agentkit admin llm set-quota`` sub-app both call into
|
||||||
|
:class:`QuotaService` rather than touching the table directly, keeping
|
||||||
|
the validation rules (quota_type/period enums, JSON serialization of
|
||||||
|
``model_whitelist``) in one place.
|
||||||
|
|
||||||
|
Quota semantics
|
||||||
|
---------------
|
||||||
|
- ``token_limit`` (int): max tokens per ``period``. ``check_quota``
|
||||||
|
compares ``current_usage`` (tokens used) against the limit.
|
||||||
|
- ``cost_limit`` (int, in cents): max spend per ``period``.
|
||||||
|
``check_quota`` compares ``current_usage`` (cost in cents) against
|
||||||
|
the limit.
|
||||||
|
- ``model_whitelist`` (list[str]): allowed model names. Not a quota
|
||||||
|
check — it's a model-access check (use :meth:`is_model_allowed`).
|
||||||
|
|
||||||
|
The service is a module-level singleton (see :func:`get_quota_service`)
|
||||||
|
so tests can inject a custom instance via :func:`set_quota_service`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Constants
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#: Valid ``quota_type`` values.
|
||||||
|
# Membership in this set is what _validate_quota_type() checks. It is a
# frozenset, so the set of accepted values cannot be mutated at runtime.
QUOTA_TYPES: frozenset[str] = frozenset({"token_limit", "cost_limit", "model_whitelist"})
|
||||||
|
|
||||||
|
#: Valid ``period`` values.
|
||||||
|
# Membership in this set is what _validate_period() checks. It is a
# frozenset, so the set of accepted values cannot be mutated at runtime.
PERIODS: frozenset[str] = frozenset({"daily", "monthly"})
|
||||||
|
|
||||||
|
|
||||||
|
def _now_iso() -> str:
|
||||||
|
"""Return current UTC time as ISO 8601 string."""
|
||||||
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def _new_id() -> str:
|
||||||
|
"""Return a new UUID4 string."""
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_limit_value(quota_type: str, limit_value: Any) -> str:
|
||||||
|
"""Serialize ``limit_value`` to a JSON string for storage.
|
||||||
|
|
||||||
|
For ``token_limit`` / ``cost_limit``: ``limit_value`` is an int,
|
||||||
|
stored as ``json.dumps(int)`` (e.g. ``"1000"``).
|
||||||
|
For ``model_whitelist``: ``limit_value`` is a ``list[str]``, stored
|
||||||
|
as ``json.dumps(list)`` (e.g. ``'["gpt-4o", "claude"]'``).
|
||||||
|
"""
|
||||||
|
if quota_type == "model_whitelist":
|
||||||
|
if not isinstance(limit_value, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"model_whitelist limit_value must be a list, got {type(limit_value).__name__}"
|
||||||
|
)
|
||||||
|
return json.dumps([str(m) for m in limit_value])
|
||||||
|
if not isinstance(limit_value, int):
|
||||||
|
raise ValueError(
|
||||||
|
f"{quota_type} limit_value must be an int, got {type(limit_value).__name__}"
|
||||||
|
)
|
||||||
|
return json.dumps(limit_value)
|
||||||
|
|
||||||
|
|
||||||
|
def _deserialize_limit_value(quota_type: str, raw: str) -> Any:
|
||||||
|
"""Deserialize ``limit_value`` from JSON string back to native type."""
|
||||||
|
try:
|
||||||
|
value = json.loads(raw)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
logger.warning("Failed to deserialize limit_value %r: %s", raw, exc)
|
||||||
|
return None
|
||||||
|
if quota_type == "model_whitelist":
|
||||||
|
return list(value) if isinstance(value, list) else []
|
||||||
|
return int(value) if isinstance(value, (int, float)) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_quota_type(quota_type: str) -> None:
    """Accept ``quota_type`` only if it is a member of ``QUOTA_TYPES``.

    Raises:
        ValueError: If ``quota_type`` is not a known quota type; the
            message lists the accepted values in sorted order.
    """
    if quota_type in QUOTA_TYPES:
        return
    accepted = sorted(QUOTA_TYPES)
    raise ValueError(f"Invalid quota_type {quota_type!r}; must be one of {accepted}")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_period(period: str) -> None:
    """Accept ``period`` only if it is a member of ``PERIODS``.

    Raises:
        ValueError: If ``period`` is not a known period; the message
            lists the accepted values in sorted order.
    """
    if period in PERIODS:
        return
    accepted = sorted(PERIODS)
    raise ValueError(f"Invalid period {period!r}; must be one of {accepted}")
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaService:
    """CRUD + enforcement for per-department LLM quotas.

    All methods are async and take ``db_path: Path`` as the first
    argument (after ``self``). Each method opens its own short-lived
    :class:`aiosqlite.Connection`; the service keeps no connection or
    other state between calls.
    """

    # ------------------------------------------------------------------
    # Quota CRUD
    # ------------------------------------------------------------------

    async def set_quota(
        self,
        db_path: Path,
        department_id: str,
        quota_type: str,
        limit_value: Any,
        period: str = "daily",
    ) -> dict[str, Any]:
        """Upsert a quota for a department.

        Args:
            db_path: Path to the auth SQLite DB.
            department_id: Department id.
            quota_type: One of ``token_limit`` / ``cost_limit`` /
                ``model_whitelist``.
            limit_value: ``int`` for token/cost limits, ``list[str]``
                for model_whitelist.
            period: ``daily`` or ``monthly`` (default ``daily``).

        Returns:
            The stored quota as a dict (with deserialized
            ``limit_value``). The ``id`` is that of the stored row: when
            an existing ``(department_id, quota_type, period)`` row is
            updated, its original id is returned rather than a new one.

        Raises:
            ValueError: If ``quota_type``, ``period`` or the type of
                ``limit_value`` is invalid.
        """
        _validate_quota_type(quota_type)
        _validate_period(period)
        serialized = _serialize_limit_value(quota_type, limit_value)
        now = _now_iso()
        # Only used if the upsert takes the INSERT path.
        quota_id = _new_id()

        async with aiosqlite.connect(str(db_path)) as db:
            db.row_factory = aiosqlite.Row
            # Upsert: if (department_id, quota_type, period) exists,
            # update limit_value + updated_at; otherwise insert.
            await db.execute(
                "INSERT INTO department_quotas "
                "(id, department_id, quota_type, limit_value, period, updated_at) "
                "VALUES (?, ?, ?, ?, ?, ?) "
                "ON CONFLICT(department_id, quota_type, period) DO UPDATE SET "
                "limit_value = excluded.limit_value, "
                "updated_at = excluded.updated_at",
                (quota_id, department_id, quota_type, serialized, period, now),
            )
            await db.commit()
            # Read the row back: on the UPDATE path the row keeps its
            # existing id, so ``quota_id`` above would be wrong.
            cursor = await db.execute(
                "SELECT * FROM department_quotas "
                "WHERE department_id = ? AND quota_type = ? AND period = ?",
                (department_id, quota_type, period),
            )
            row = await cursor.fetchone()

        if row is not None:
            return self._row_to_dict(row)

        # The row vanished between commit and read-back (e.g. a concurrent
        # delete); report what was written.
        return {
            "id": quota_id,
            "department_id": department_id,
            "quota_type": quota_type,
            "limit_value": _deserialize_limit_value(quota_type, serialized),
            "period": period,
            "updated_at": now,
        }

    async def get_quota(
        self,
        db_path: Path,
        department_id: str,
        quota_type: str,
        period: str = "daily",
    ) -> dict[str, Any] | None:
        """Return a single quota, or ``None`` if not set.

        Raises:
            ValueError: If ``quota_type`` or ``period`` is invalid.
        """
        _validate_quota_type(quota_type)
        _validate_period(period)
        async with aiosqlite.connect(str(db_path)) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(
                "SELECT * FROM department_quotas "
                "WHERE department_id = ? AND quota_type = ? AND period = ?",
                (department_id, quota_type, period),
            )
            row = await cursor.fetchone()
        if row is None:
            return None
        return self._row_to_dict(row)

    async def list_department_quotas(
        self,
        db_path: Path,
        department_id: str,
    ) -> list[dict[str, Any]]:
        """List all quotas for a department (all types/periods).

        Results are ordered by ``quota_type`` and then ``period``.
        """
        async with aiosqlite.connect(str(db_path)) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(
                "SELECT * FROM department_quotas WHERE department_id = ? "
                "ORDER BY quota_type ASC, period ASC",
                (department_id,),
            )
            rows = await cursor.fetchall()
        return [self._row_to_dict(row) for row in rows]

    async def delete_quota(
        self,
        db_path: Path,
        department_id: str,
        quota_type: str,
        period: str = "daily",
    ) -> bool:
        """Delete a quota. Returns ``True`` if a row was deleted.

        Raises:
            ValueError: If ``quota_type`` or ``period`` is invalid.
        """
        _validate_quota_type(quota_type)
        _validate_period(period)
        async with aiosqlite.connect(str(db_path)) as db:
            cursor = await db.execute(
                "DELETE FROM department_quotas "
                "WHERE department_id = ? AND quota_type = ? AND period = ?",
                (department_id, quota_type, period),
            )
            await db.commit()
            return cursor.rowcount > 0

    # ------------------------------------------------------------------
    # Quota enforcement
    # ------------------------------------------------------------------

    async def check_quota(
        self,
        db_path: Path,
        department_id: str,
        quota_type: str,
        period: str,
        current_usage: int | float,
    ) -> tuple[bool, str]:
        """Check whether ``current_usage`` is within the configured quota.

        Args:
            db_path: Path to the auth SQLite DB.
            department_id: Department id.
            quota_type: ``token_limit`` or ``cost_limit``.
                ``model_whitelist`` is not a quota check; use
                :meth:`is_model_allowed` instead.
            period: ``daily`` or ``monthly``.
            current_usage: Current usage value (tokens for
                ``token_limit``, cents for ``cost_limit``).

        Returns:
            ``(allowed, reason)`` tuple. ``allowed`` is ``True`` when no
            quota is set, when the stored limit is not an int, or when
            ``current_usage`` is strictly below the limit. Usage equal
            to the limit is denied. ``reason`` is ``"ok"`` when allowed,
            or a human-readable denial reason.

        Raises:
            ValueError: If ``quota_type`` or ``period`` is invalid.
        """
        _validate_quota_type(quota_type)
        _validate_period(period)
        if quota_type == "model_whitelist":
            return (
                False,
                "model_whitelist is not a quota check; use is_model_allowed()",
            )

        quota = await self.get_quota(db_path, department_id, quota_type, period)
        if quota is None:
            return (True, "ok")

        limit = quota["limit_value"]
        if not isinstance(limit, int):
            return (True, "ok")  # malformed limit; allow

        if current_usage >= limit:
            return (
                False,
                f"{quota_type} ({period}) exceeded: usage={current_usage}, limit={limit}",
            )
        return (True, "ok")

    async def is_model_allowed(
        self,
        db_path: Path,
        department_id: str,
        model: str,
    ) -> tuple[bool, str]:
        """Check whether ``model`` is in the department's whitelist.

        If no ``model_whitelist`` quota is set for the department (any
        period), all models are allowed. Otherwise the model must appear
        in the whitelist.

        Returns:
            ``(allowed, reason)`` tuple.
        """
        async with aiosqlite.connect(str(db_path)) as db:
            db.row_factory = aiosqlite.Row
            # NOTE(review): only the first row by period is consulted, so
            # if both a "daily" and a "monthly" whitelist exist the
            # "daily" one wins. Confirm that is the intended precedence.
            cursor = await db.execute(
                "SELECT * FROM department_quotas "
                "WHERE department_id = ? AND quota_type = 'model_whitelist' "
                "ORDER BY period ASC LIMIT 1",
                (department_id,),
            )
            row = await cursor.fetchone()
        if row is None:
            return (True, "ok")
        whitelist = _deserialize_limit_value("model_whitelist", row["limit_value"])
        # NOTE(review): an undecodable whitelist is treated as "no
        # restriction" (fail-open). Confirm this is acceptable.
        if not isinstance(whitelist, list):
            return (True, "ok")
        if model in whitelist:
            return (True, "ok")
        return (
            False,
            f"model {model!r} not in department {department_id!r} whitelist",
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _row_to_dict(row: aiosqlite.Row | Any) -> dict[str, Any]:
        """Convert a ``department_quotas`` row to a JSON-safe dict.

        ``limit_value`` is decoded from its stored JSON text according to
        the row's ``quota_type``.
        """
        quota_type = row["quota_type"]
        return {
            "id": row["id"],
            "department_id": row["department_id"],
            "quota_type": quota_type,
            "limit_value": _deserialize_limit_value(quota_type, row["limit_value"]),
            "period": row["period"],
            "updated_at": row["updated_at"],
        }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Module-level singleton (overridable in tests via set_quota_service)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Lazily created process-wide instance; ``None`` until first requested.
_quota_service: QuotaService | None = None


def get_quota_service() -> QuotaService:
    """Return the shared :class:`QuotaService`, creating it on first use."""
    global _quota_service
    current = _quota_service
    if current is None:
        current = _quota_service = QuotaService()
    return current
|
||||||
|
|
||||||
|
|
||||||
|
def set_quota_service(service: QuotaService | None) -> None:
    """Replace the module-level :class:`QuotaService` (test hook).

    Args:
        service: Instance to install as the module-level singleton, or
            ``None`` to store ``None`` in its place.
    """
    global _quota_service
    _quota_service = service
|
||||||
|
|
@ -23,6 +23,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from agentkit.server.admin.department_service import get_department_service
|
from agentkit.server.admin.department_service import get_department_service
|
||||||
|
from agentkit.server.admin.llm_config_service import (
|
||||||
|
LlmConfigService,
|
||||||
|
get_llm_config_service,
|
||||||
|
)
|
||||||
|
from agentkit.server.admin.quota_service import get_quota_service
|
||||||
from agentkit.server.admin.user_service import get_user_service
|
from agentkit.server.admin.user_service import get_user_service
|
||||||
from agentkit.server.auth.dependencies import require_authenticated
|
from agentkit.server.auth.dependencies import require_authenticated
|
||||||
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
|
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
|
||||||
|
|
@ -550,3 +555,302 @@ async def list_user_departments(
|
||||||
db_path = await _ensure_db(request)
|
db_path = await _ensure_db(request)
|
||||||
svc = get_user_service()
|
svc = get_user_service()
|
||||||
return await svc.list_user_departments(db_path, user_id)
|
return await svc.list_user_departments(db_path, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# LLM config endpoints (U5) — provider CRUD + fallback chains
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _get_llm_config_service(request: Request) -> LlmConfigService:
    """Pick the :class:`LlmConfigService` to use for this request.

    When ``app.state.server_config`` exists and carries a truthy
    ``_config_path``, the service is requested for that path. If that
    lookup raises ``ValueError``, a fresh service is built for the path
    and installed as the singleton. Without a configured path the
    existing singleton is returned (tests pre-populate it via
    :func:`set_llm_config_service`).
    """
    cfg = getattr(request.app.state, "server_config", None)
    has_path = cfg is not None and getattr(cfg, "_config_path", None)
    if not has_path:
        # No path on app.state: rely on whatever singleton is installed.
        return get_llm_config_service()

    try:
        return get_llm_config_service(cfg._config_path)
    except ValueError:
        # The singleton was reset between calls, so rebuild it from the
        # resolved path and register it again.
        from agentkit.server.admin.llm_config_service import set_llm_config_service

        rebuilt = LlmConfigService(Path(cfg._config_path))
        set_llm_config_service(rebuilt)
        return rebuilt
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCreateRequest(BaseModel):
    """Request payload for ``POST /admin/llm/providers``.

    ``name`` and ``api_key`` are required; every other field has a
    default. Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # Required identification and credential.
    name: str
    type: str = "openai"
    api_key: str
    # Optional connection and model settings.
    base_url: str = ""
    models: dict[str, Any] | None = None
    max_tokens: int = 4096
    timeout: float = 60.0
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderUpdateRequest(BaseModel):
    """Request payload for ``PATCH /admin/llm/providers/{name}``.

    Every field is optional and defaults to ``None``. Unknown fields are
    rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    # All fields default to None so a partial body is accepted.
    type: str | None = None
    api_key: str | None = None
    base_url: str | None = None
    models: dict[str, Any] | None = None
    max_tokens: int | None = None
    timeout: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeySetRequest(BaseModel):
    """Request payload for ``POST /admin/llm/providers/{name}/api-key``.

    Carries only the new key; unknown fields are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    # New API key value for the provider named in the URL.
    api_key: str
|
||||||
|
|
||||||
|
|
||||||
|
class FallbackSetRequest(BaseModel):
    """Request payload for ``PUT /admin/llm/fallbacks/{model}``.

    Carries the fallback chain for the model named in the URL; unknown
    fields are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    # Fallback chain entries for the model.
    chain: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/llm/providers")
async def list_llm_providers(
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
    """Return every configured LLM provider, with API keys masked."""
    providers = _get_llm_config_service(request).list_providers()
    return providers
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/llm/providers", status_code=201)
async def create_llm_provider(
    payload: ProviderCreateRequest,
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
    """Register a new LLM provider.

    Responds 201 with the created provider (key masked). A ``ValueError``
    from the service, such as a duplicate provider name, becomes a 409.
    """
    service = _get_llm_config_service(request)
    # Map the request body onto the service's keyword arguments.
    fields = {
        "name": payload.name,
        "provider_type": payload.type,
        "api_key": payload.api_key,
        "base_url": payload.base_url,
        "models": payload.models,
        "max_tokens": payload.max_tokens,
        "timeout": payload.timeout,
    }
    try:
        created = service.create_provider(**fields)
    except ValueError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    return created
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.patch("/llm/providers/{name}")
async def update_llm_provider(
    name: str,
    payload: ProviderUpdateRequest,
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
    """Apply a partial update to an existing LLM provider.

    Responds 200 with the updated provider (key masked). A ``ValueError``
    from the service, such as an unknown provider, becomes a 404.
    """
    service = _get_llm_config_service(request)
    # Fields left out of the body are forwarded as None.
    changes = {
        "provider_type": payload.type,
        "api_key": payload.api_key,
        "base_url": payload.base_url,
        "models": payload.models,
        "max_tokens": payload.max_tokens,
        "timeout": payload.timeout,
    }
    try:
        updated = service.update_provider(name=name, **changes)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    return updated
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.delete("/llm/providers/{name}")
async def delete_llm_provider(
    name: str,
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
    """Remove an LLM provider.

    Responds 200 ``{deleted: true}`` on success. A ``ValueError`` whose
    message contains ``"not found"`` becomes a 404. Any other
    ``ValueError``, such as the provider being used in a fallback chain,
    becomes a 400. A falsy result from the service is also a 404.
    """
    service = _get_llm_config_service(request)
    try:
        removed = service.delete_provider(name)
    except ValueError as exc:
        detail = str(exc)
        status = 404 if "not found" in detail else 400
        raise HTTPException(status_code=status, detail=detail) from exc
    if removed:
        return {"deleted": True}
    raise HTTPException(status_code=404, detail="Provider not found")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/llm/providers/{name}/api-key")
async def set_llm_provider_api_key(
    name: str,
    payload: ApiKeySetRequest,
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
    """Set the API key for a provider.

    The plaintext key is written to ``.env`` (not the YAML file); the
    YAML stores a ``${ENV_VAR}`` reference. Returns 200 with the
    updated provider dict (masked key).

    Raises:
        HTTPException: 404 when the service raises a ``ValueError`` whose
            message contains ``"not found"``, 400 for any other
            ``ValueError``.
    """
    svc = _get_llm_config_service(request)
    try:
        return svc.set_api_key(name, payload.api_key)
    except ValueError as exc:
        # Without this translation a rejected request would surface as
        # HTTP 500. The 404/400 split mirrors delete_llm_provider.
        # NOTE(review): assumes set_api_key raises ValueError for an
        # unknown provider, as update/delete do; confirm in the service.
        msg = str(exc)
        status_code = 404 if "not found" in msg else 400
        raise HTTPException(status_code=status_code, detail=msg) from exc
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/llm/fallbacks")
async def get_llm_fallbacks(
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, list[str]]:
    """Return every fallback chain, keyed by model name."""
    chains = _get_llm_config_service(request).get_fallbacks()
    return chains
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.put("/llm/fallbacks/{model}")
async def set_llm_fallback(
    model: str,
    payload: FallbackSetRequest,
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
    """Store ``payload.chain`` as the fallback chain for ``model``.

    Returns whatever the service's ``set_fallback`` returns.
    """
    service = _get_llm_config_service(request)
    stored = service.set_fallback(model, payload.chain)
    return stored
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.delete("/llm/fallbacks/{model}")
async def delete_llm_fallback(
    model: str,
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
    """Remove the fallback chain configured for ``model``.

    Responds 200 ``{deleted: true}`` when a chain was removed and 404
    when the service reports nothing was deleted.
    """
    removed = _get_llm_config_service(request).delete_fallback(model)
    if removed:
        return {"deleted": True}
    raise HTTPException(status_code=404, detail="Fallback chain not found")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Department quota endpoints (U5) — per-department LLM quotas
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaSetRequest(BaseModel):
    """Request payload for ``PUT /admin/departments/{id}/quotas``.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    quota_type: str
    # Either an integer limit or a list of strings (model whitelist).
    limit_value: int | list[str]
    period: str = "daily"
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/departments/{department_id}/quotas")
async def list_department_quotas(
    department_id: str,
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> list[dict[str, Any]]:
    """Return every quota stored for ``department_id``."""
    database = await _ensure_db(request)
    quotas = await get_quota_service().list_department_quotas(database, department_id)
    return quotas
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.put("/departments/{department_id}/quotas")
async def set_department_quota(
    department_id: str,
    payload: QuotaSetRequest,
    request: Request,
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
    """Create or replace one quota of a department.

    Responds 200 with the stored quota. A ``ValueError`` from the service
    becomes a 400. That covers an invalid ``quota_type`` or ``period``,
    and a ``limit_value`` that does not match the quota type (int for
    token/cost limits, list for ``model_whitelist``).
    """
    database = await _ensure_db(request)
    service = get_quota_service()
    arguments = (
        database,
        department_id,
        payload.quota_type,
        payload.limit_value,
        payload.period,
    )
    try:
        stored = await service.set_quota(*arguments)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return stored
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.delete("/departments/{department_id}/quotas")
async def delete_department_quota(
    department_id: str,
    request: Request,
    quota_type: str,
    period: str = "daily",
    admin: dict[str, Any] = Depends(_require_admin),
) -> dict[str, Any]:
    """Remove one quota of a department.

    ``quota_type`` is a required query parameter; ``period`` defaults to
    ``daily``. Responds 200 ``{deleted: true}`` when a quota was removed
    and 404 when none matched. A ``ValueError`` from the service, such
    as an invalid ``quota_type`` or ``period``, becomes a 400.
    """
    database = await _ensure_db(request)
    service = get_quota_service()
    try:
        removed = await service.delete_quota(database, department_id, quota_type, period)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    if removed:
        return {"deleted": True}
    raise HTTPException(status_code=404, detail="Quota not found")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,464 @@
|
||||||
|
"""Integration tests for the LLM config admin routes (U5).
|
||||||
|
|
||||||
|
Uses FastAPI TestClient with a test app that mounts only the
|
||||||
|
``admin_router`` from ``routes.admin``. The ``_require_admin`` dependency
|
||||||
|
is overridden via ``app.dependency_overrides`` so the tests don't need
|
||||||
|
real JWTs — they can simulate admin and non-admin callers directly.
|
||||||
|
|
||||||
|
The :class:`LlmConfigService` is injected via
|
||||||
|
:func:`set_llm_config_service` so the tests can point it at a temp
|
||||||
|
YAML file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from agentkit.server.admin.llm_config_service import (
|
||||||
|
LlmConfigService,
|
||||||
|
set_llm_config_service,
|
||||||
|
)
|
||||||
|
from agentkit.server.auth.models import init_auth_db
|
||||||
|
from agentkit.server.routes import admin as admin_routes_module
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_config() -> dict[str, Any]:
    """Build a small agentkit.yaml-shaped config for the fixtures.

    The config has a single ``openai`` provider, one model alias and no
    fallback chains.
    """
    openai_provider = {
        "type": "openai",
        "api_key": "sk-test-12345678",
        "base_url": "https://api.openai.com/v1",
        "models": {"gpt-4o": {}},
        "max_tokens": 4096,
        "timeout": 120.0,
    }
    llm_section = {
        "providers": {"openai": openai_provider},
        "model_aliases": {"gpt4": "openai/gpt-4o"},
        "fallbacks": {},
    }
    return {
        "server": {"host": "0.0.0.0", "port": 8001},
        "llm": llm_section,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def config_path(tmp_path: Path) -> Path:
    """Write the sample config to a temporary ``agentkit.yaml``; return its path."""
    target = tmp_path / "agentkit.yaml"
    with target.open("w", encoding="utf-8") as handle:
        yaml.dump(_sample_config(), handle, default_flow_style=False, allow_unicode=True)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def tmp_auth_db(tmp_path: Path) -> Path:
    """Initialize an auth database under ``tmp_path`` and return its path."""
    database = tmp_path / "admin_llm.db"
    await init_auth_db(database)
    return database
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset_llm_config_singleton():
    """Clear the LlmConfigService singleton around every test."""
    # Start each test without an installed service.
    set_llm_config_service(None)
    yield
    # Drop whatever the test (or a fixture) installed.
    set_llm_config_service(None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def admin_app(config_path: Path, tmp_auth_db: Path) -> FastAPI:
    """Build a FastAPI app exposing only the admin router.

    The :class:`LlmConfigService` singleton is pointed at the temporary
    YAML file, and ``_require_admin`` is overridden so that every request
    is treated as coming from a fake admin user.
    """
    # Point the routes at the temp YAML via the module singleton.
    service = LlmConfigService(config_path)
    set_llm_config_service(service)

    application = FastAPI()
    application.state.auth_db_path = str(tmp_auth_db)
    application.include_router(admin_routes_module.admin_router, prefix="/api/v1")

    # Admin access is granted by default; individual tests may replace this.
    overrides = application.dependency_overrides
    overrides[admin_routes_module._require_admin] = lambda: _make_admin_user()
    return application
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def admin_client(admin_app: FastAPI) -> TestClient:
    """Return a ``TestClient`` bound to the admin-only app."""
    client = TestClient(admin_app)
    return client
|
||||||
|
|
||||||
|
|
||||||
|
def _make_admin_user() -> dict[str, Any]:
    """Return the fake admin identity used by the dependency override."""
    identity = {
        "user_id": "admin-1",
        "username": "admin",
        "role": "admin",
    }
    return identity
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_forbidden() -> dict[str, Any]:
    """Dependency override that mimics a non-admin caller.

    Always raises a 403 ``HTTPException``; it never returns.
    """
    forbidden = HTTPException(status_code=403, detail="Admin permission required")
    raise forbidden
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Provider CRUD
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestListProviders:
    """``GET /admin/llm/providers`` against the sample config."""

    URL = "/api/v1/admin/llm/providers"

    def test_list_returns_existing_providers(self, admin_client: TestClient):
        response = admin_client.get(self.URL)
        assert response.status_code == 200
        listed = response.json()
        assert len(listed) == 1
        assert listed[0]["name"] == "openai"

    def test_list_masks_api_keys(self, admin_client: TestClient):
        first = admin_client.get(self.URL).json()[0]
        assert first["api_key"].startswith("****")
        # The plaintext key must never be echoed back.
        assert "sk-test-12345678" not in first["api_key"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateProvider:
    """``POST /admin/llm/providers``: success, duplicate and forbidden cases."""

    URL = "/api/v1/admin/llm/providers"

    def test_create_returns_201(self, admin_client: TestClient):
        new_provider = {
            "name": "anthropic",
            "type": "anthropic",
            "api_key": "sk-ant-test-abcd1234",
            "base_url": "https://api.anthropic.com",
            "models": {"claude-sonnet-4-20250514": {}},
            "max_tokens": 8192,
            "timeout": 90.0,
        }
        response = admin_client.post(self.URL, json=new_provider)
        assert response.status_code == 201
        created = response.json()
        assert created["name"] == "anthropic"
        assert created["type"] == "anthropic"
        # The response carries a masked key only.
        assert created["api_key"].startswith("****")

    def test_create_duplicate_returns_409(self, admin_client: TestClient):
        # "openai" already exists in the sample config.
        duplicate = {
            "name": "openai",
            "type": "openai",
            "api_key": "sk-duplicate",
        }
        response = admin_client.post(self.URL, json=duplicate)
        assert response.status_code == 409

    def test_non_admin_returns_403(self, admin_app: FastAPI):
        admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
        non_admin = TestClient(admin_app)
        response = non_admin.post(
            self.URL,
            json={"name": "x", "type": "openai", "api_key": "sk-x"},
        )
        assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateProvider:
    """``PATCH /admin/llm/providers/{name}`` behaviour."""

    def test_update_partial_returns_200(self, admin_client: TestClient):
        response = admin_client.patch(
            "/api/v1/admin/llm/providers/openai",
            json={"max_tokens": 2048},
        )
        assert response.status_code == 200
        updated = response.json()
        assert updated["max_tokens"] == 2048
        # Fields absent from the body keep their previous values.
        assert updated["base_url"] == "https://api.openai.com/v1"

    def test_update_unknown_returns_404(self, admin_client: TestClient):
        response = admin_client.patch(
            "/api/v1/admin/llm/providers/nonexistent",
            json={"max_tokens": 2048},
        )
        assert response.status_code == 404

    def test_update_with_masked_api_key_preserves_existing(
        self, admin_client: TestClient, config_path: Path
    ):
        response = admin_client.patch(
            "/api/v1/admin/llm/providers/openai",
            json={"api_key": "****5678"},
        )
        assert response.status_code == 200
        # The YAML on disk must still hold the original key.
        on_disk = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        assert on_disk["llm"]["providers"]["openai"]["api_key"] == "sk-test-12345678"
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteProvider:
    """``DELETE /admin/llm/providers/{name}`` behaviour."""

    BASE = "/api/v1/admin/llm/providers"

    def test_delete_returns_200(self, admin_client: TestClient):
        response = admin_client.delete(f"{self.BASE}/openai")
        assert response.status_code == 200
        assert response.json() == {"deleted": True}

    def test_delete_unknown_returns_404(self, admin_client: TestClient):
        response = admin_client.delete(f"{self.BASE}/nonexistent")
        assert response.status_code == 404

    def test_delete_provider_used_in_fallback_returns_400(self, admin_client: TestClient):
        # Create a second provider, then reference it from a fallback chain.
        second_provider = {
            "name": "anthropic",
            "type": "anthropic",
            "api_key": "sk-ant-test",
        }
        admin_client.post(self.BASE, json=second_provider)
        admin_client.put(
            "/api/v1/admin/llm/fallbacks/gpt-4o",
            json={"chain": ["anthropic"]},
        )
        response = admin_client.delete(f"{self.BASE}/anthropic")
        assert response.status_code == 400
        detail = response.json()["detail"]
        assert "fallback" in detail.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# API key management
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSetApiKey:
    """``POST /admin/llm/providers/{name}/api-key`` behaviour."""

    URL = "/api/v1/admin/llm/providers/openai/api-key"
    BODY = {"api_key": "sk-brand-new-key-1234"}

    def test_set_api_key_returns_200(self, admin_client: TestClient):
        response = admin_client.post(self.URL, json=self.BODY)
        assert response.status_code == 200
        provider = response.json()
        assert provider["api_key"].startswith("****")

    def test_set_api_key_writes_to_env_file(self, admin_client: TestClient, config_path: Path):
        admin_client.post(self.URL, json=self.BODY)
        # The plaintext key is expected in a .env file next to the YAML.
        env_file = config_path.parent / ".env"
        assert env_file.exists()
        written = env_file.read_text(encoding="utf-8")
        assert "OPENAI_API_KEY=sk-brand-new-key-1234" in written
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fallback chain management
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFallbackRoutes:
    """``/admin/llm/fallbacks`` listing, upsert and delete behaviour."""

    COLLECTION = "/api/v1/admin/llm/fallbacks"
    GPT4O = "/api/v1/admin/llm/fallbacks/gpt-4o"

    def test_get_fallbacks_returns_empty_when_none(self, admin_client: TestClient):
        response = admin_client.get(self.COLLECTION)
        assert response.status_code == 200
        assert response.json() == {}

    def test_set_fallback_returns_200(self, admin_client: TestClient):
        response = admin_client.put(self.GPT4O, json={"chain": ["openai", "anthropic"]})
        assert response.status_code == 200
        stored = response.json()
        assert stored["model"] == "gpt-4o"
        assert stored["chain"] == ["openai", "anthropic"]

    def test_get_fallbacks_returns_set_chain(self, admin_client: TestClient):
        admin_client.put(self.GPT4O, json={"chain": ["openai", "anthropic"]})
        response = admin_client.get(self.COLLECTION)
        assert response.status_code == 200
        assert response.json() == {"gpt-4o": ["openai", "anthropic"]}

    def test_delete_fallback_returns_200(self, admin_client: TestClient):
        admin_client.put(self.GPT4O, json={"chain": ["openai"]})
        response = admin_client.delete(self.GPT4O)
        assert response.status_code == 200
        assert response.json() == {"deleted": True}

    def test_delete_unknown_fallback_returns_404(self, admin_client: TestClient):
        response = admin_client.delete(f"{self.COLLECTION}/nonexistent")
        assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Department quota endpoints
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDepartmentQuotaRoutes:
    """``/admin/departments/{id}/quotas`` list, upsert and delete behaviour."""

    @staticmethod
    def _quota_url(department_id: str) -> str:
        """Return the quota collection URL for ``department_id``."""
        return f"/api/v1/admin/departments/{department_id}/quotas"

    def test_set_quota_returns_200(self, admin_client: TestClient):
        url = self._quota_url(str(uuid.uuid4()))
        quota = {
            "quota_type": "token_limit",
            "limit_value": 1000,
            "period": "daily",
        }
        response = admin_client.put(url, json=quota)
        assert response.status_code == 200
        stored = response.json()
        assert stored["quota_type"] == "token_limit"
        assert stored["limit_value"] == 1000
        assert stored["period"] == "daily"

    def test_set_quota_invalid_type_returns_400(self, admin_client: TestClient):
        url = self._quota_url(str(uuid.uuid4()))
        response = admin_client.put(
            url,
            json={"quota_type": "invalid", "limit_value": 1000},
        )
        assert response.status_code == 400

    def test_set_quota_model_whitelist_accepts_list(self, admin_client: TestClient):
        url = self._quota_url(str(uuid.uuid4()))
        quota = {
            "quota_type": "model_whitelist",
            "limit_value": ["gpt-4o", "claude"],
            "period": "daily",
        }
        response = admin_client.put(url, json=quota)
        assert response.status_code == 200
        stored = response.json()
        assert stored["limit_value"] == ["gpt-4o", "claude"]

    def test_list_quotas_returns_empty_for_new_department(self, admin_client: TestClient):
        url = self._quota_url(str(uuid.uuid4()))
        response = admin_client.get(url)
        assert response.status_code == 200
        assert response.json() == []

    def test_list_quotas_returns_set_quotas(self, admin_client: TestClient):
        url = self._quota_url(str(uuid.uuid4()))
        admin_client.put(url, json={"quota_type": "token_limit", "limit_value": 1000})
        admin_client.put(
            url,
            json={
                "quota_type": "cost_limit",
                "limit_value": 5000,
                "period": "monthly",
            },
        )
        response = admin_client.get(url)
        assert response.status_code == 200
        listed = response.json()
        assert len(listed) == 2
        seen_types = {entry["quota_type"] for entry in listed}
        assert seen_types == {"token_limit", "cost_limit"}

    def test_delete_quota_returns_200(self, admin_client: TestClient):
        url = self._quota_url(str(uuid.uuid4()))
        admin_client.put(url, json={"quota_type": "token_limit", "limit_value": 1000})
        response = admin_client.delete(
            url,
            params={"quota_type": "token_limit", "period": "daily"},
        )
        assert response.status_code == 200
        assert response.json() == {"deleted": True}

    def test_delete_unknown_quota_returns_404(self, admin_client: TestClient):
        url = self._quota_url(str(uuid.uuid4()))
        response = admin_client.delete(
            url,
            params={"quota_type": "token_limit", "period": "daily"},
        )
        assert response.status_code == 404

    def test_delete_quota_invalid_period_returns_400(self, admin_client: TestClient):
        url = self._quota_url(str(uuid.uuid4()))
        response = admin_client.delete(
            url,
            params={"quota_type": "token_limit", "period": "weekly"},
        )
        assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Non-admin access
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNonAdminAccess:
    """A selection of admin endpoints must answer 403 to non-admin callers."""

    @staticmethod
    def _forbidden_client(app: FastAPI) -> TestClient:
        """Swap in the 403-raising admin dependency and return a client."""
        app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
        return TestClient(app)

    def test_non_admin_cannot_list_providers(self, admin_app: FastAPI):
        response = self._forbidden_client(admin_app).get("/api/v1/admin/llm/providers")
        assert response.status_code == 403

    def test_non_admin_cannot_create_provider(self, admin_app: FastAPI):
        response = self._forbidden_client(admin_app).post(
            "/api/v1/admin/llm/providers",
            json={"name": "x", "type": "openai", "api_key": "sk-x"},
        )
        assert response.status_code == 403

    def test_non_admin_cannot_delete_provider(self, admin_app: FastAPI):
        response = self._forbidden_client(admin_app).delete(
            "/api/v1/admin/llm/providers/openai"
        )
        assert response.status_code == 403

    def test_non_admin_cannot_set_api_key(self, admin_app: FastAPI):
        response = self._forbidden_client(admin_app).post(
            "/api/v1/admin/llm/providers/openai/api-key",
            json={"api_key": "sk-x"},
        )
        assert response.status_code == 403

    def test_non_admin_cannot_get_fallbacks(self, admin_app: FastAPI):
        response = self._forbidden_client(admin_app).get("/api/v1/admin/llm/fallbacks")
        assert response.status_code == 403

    def test_non_admin_cannot_set_quota(self, admin_app: FastAPI):
        non_admin = self._forbidden_client(admin_app)
        response = non_admin.put(
            f"/api/v1/admin/departments/{uuid.uuid4()}/quotas",
            json={"quota_type": "token_limit", "limit_value": 1000},
        )
        assert response.status_code == 403
|
||||||
|
|
@ -0,0 +1,410 @@
|
||||||
|
"""Unit tests for LlmConfigService (U5 — LLM provider/fallback CRUD).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Happy path: list → get → create → update → delete providers
|
||||||
|
- API key management: set_api_key writes to .env, masks in responses
|
||||||
|
- Fallback chain: get → set → delete
|
||||||
|
- Edge case: create duplicate provider → ValueError
|
||||||
|
- Edge case: update non-existent provider → ValueError
|
||||||
|
- Edge case: delete provider used in fallback → ValueError
|
||||||
|
- Edge case: delete non-existent provider → ValueError
|
||||||
|
- File locking: concurrent writes from multiple threads don't corrupt the YAML
|
||||||
|
- Singleton helpers (get/set_llm_config_service)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.server.admin.llm_config_service import (
|
||||||
|
LlmConfigService,
|
||||||
|
get_llm_config_service,
|
||||||
|
set_llm_config_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_config() -> dict[str, Any]:
|
||||||
|
"""A minimal agentkit.yaml-style config for testing."""
|
||||||
|
return {
|
||||||
|
"server": {"host": "0.0.0.0", "port": 8001},
|
||||||
|
"llm": {
|
||||||
|
"providers": {
|
||||||
|
"openai": {
|
||||||
|
"type": "openai",
|
||||||
|
"api_key": "sk-test-12345678",
|
||||||
|
"base_url": "https://api.openai.com/v1",
|
||||||
|
"models": {"gpt-4o": {}},
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"timeout": 120.0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"model_aliases": {"gpt4": "openai/gpt-4o"},
|
||||||
|
"fallbacks": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def config_path(tmp_path: Path) -> Path:
    """Write the sample config to ``agentkit.yaml`` under ``tmp_path`` and return its path."""
    yaml_file = tmp_path / "agentkit.yaml"
    rendered = yaml.dump(_sample_config(), default_flow_style=False, allow_unicode=True)
    yaml_file.write_text(rendered, encoding="utf-8")
    return yaml_file
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def service(config_path: Path) -> LlmConfigService:
    """Provide an ``LlmConfigService`` constructed from the temporary config path."""
    llm_service = LlmConfigService(config_path)
    return llm_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset_singleton():
    """Clear the module singleton on both sides of every test."""
    # Setup: drop any instance installed before this test runs.
    set_llm_config_service(None)
    yield
    # Teardown: leave no instance behind for the next test.
    set_llm_config_service(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Provider CRUD happy path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderCrudHappyPath:
    """List / get / create / update / delete flows against the sample config."""

    def test_list_returns_existing_providers(self, service: LlmConfigService):
        listed = service.list_providers()
        assert len(listed) == 1
        only_provider = listed[0]
        assert only_provider["name"] == "openai"
        assert only_provider["type"] == "openai"

    def test_list_masks_api_keys(self, service: LlmConfigService):
        masked_key = service.list_providers()[0]["api_key"]
        # Masked form: "****" followed by the last four characters only.
        assert masked_key.startswith("****")
        assert masked_key == "****5678"

    def test_get_returns_provider_by_name(self, service: LlmConfigService):
        found = service.get_provider("openai")
        assert found is not None
        assert found["name"] == "openai"
        assert found["type"] == "openai"
        assert found["base_url"] == "https://api.openai.com/v1"

    def test_get_returns_none_for_unknown(self, service: LlmConfigService):
        missing = service.get_provider("nonexistent")
        assert missing is None

    def test_create_adds_provider_to_yaml(self, service: LlmConfigService, config_path: Path):
        new_provider = service.create_provider(
            name="anthropic",
            provider_type="anthropic",
            api_key="sk-ant-test-abcd1234",
            base_url="https://api.anthropic.com",
            models={"claude-sonnet-4-20250514": {}},
            max_tokens=8192,
            timeout=90.0,
        )
        assert new_provider["name"] == "anthropic"
        assert new_provider["type"] == "anthropic"
        # The response must carry the masked key, never the plaintext.
        assert new_provider["api_key"].startswith("****")
        # Re-read the YAML from disk to confirm the write happened.
        on_disk = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        stored_providers = on_disk["llm"]["providers"]
        assert "anthropic" in stored_providers
        # The file keeps a ${VAR} reference instead of the secret itself.
        assert stored_providers["anthropic"]["api_key"] == "${ANTHROPIC_API_KEY}"

    def test_create_writes_api_key_to_env_file(self, service: LlmConfigService, config_path: Path):
        service.create_provider(
            name="anthropic",
            provider_type="anthropic",
            api_key="sk-ant-test-abcd1234",
        )
        env_file = config_path.parent / ".env"
        assert env_file.exists()
        env_text = env_file.read_text(encoding="utf-8")
        assert "ANTHROPIC_API_KEY=sk-ant-test-abcd1234" in env_text

    def test_update_partial_only_changes_provided_fields(
        self, service: LlmConfigService, config_path: Path
    ):
        result = service.update_provider("openai", max_tokens=2048)
        assert result["max_tokens"] == 2048
        # Fields that were not passed keep their sample-config values.
        assert result["base_url"] == "https://api.openai.com/v1"
        assert result["type"] == "openai"

    def test_update_with_masked_api_key_preserves_existing(
        self, service: LlmConfigService, config_path: Path
    ):
        """Sending back the masked key (``****xxxx``) must leave the stored key untouched."""
        service.update_provider("openai", api_key="****5678")
        on_disk = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        # The plaintext key from the sample config is still in the YAML.
        assert on_disk["llm"]["providers"]["openai"]["api_key"] == "sk-test-12345678"

    def test_update_with_new_api_key_writes_to_env(
        self, service: LlmConfigService, config_path: Path
    ):
        service.update_provider("openai", api_key="sk-new-key-9999")
        env_text = (config_path.parent / ".env").read_text(encoding="utf-8")
        assert "OPENAI_API_KEY=sk-new-key-9999" in env_text
        # The YAML entry is rewritten to a ${VAR} reference.
        on_disk = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        assert on_disk["llm"]["providers"]["openai"]["api_key"] == "${OPENAI_API_KEY}"

    def test_delete_removes_provider(self, service: LlmConfigService, config_path: Path):
        outcome = service.delete_provider("openai")
        assert outcome is True
        # The provider is gone from the file on disk ...
        on_disk = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        assert "openai" not in on_disk["llm"]["providers"]
        # ... and from the service's own listing.
        assert service.list_providers() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Provider CRUD edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderCrudEdgeCases:
    """Provider operations that must be rejected with ``ValueError``."""

    def test_create_duplicate_raises_value_error(self, service: LlmConfigService):
        # "openai" already exists in the sample config.
        duplicate = {"name": "openai", "provider_type": "openai", "api_key": "sk-duplicate"}
        with pytest.raises(ValueError, match="already exists"):
            service.create_provider(**duplicate)

    def test_update_nonexistent_raises_value_error(self, service: LlmConfigService):
        with pytest.raises(ValueError, match="not found"):
            service.update_provider("nonexistent", max_tokens=2048)

    def test_delete_nonexistent_raises_value_error(self, service: LlmConfigService):
        with pytest.raises(ValueError, match="not found"):
            service.delete_provider("nonexistent")

    def test_delete_provider_used_in_fallback_raises_value_error(self, service: LlmConfigService):
        # Register a second provider, then reference it from a fallback chain.
        service.create_provider(
            name="anthropic",
            provider_type="anthropic",
            api_key="sk-ant-test",
        )
        service.set_fallback("gpt-4o", ["anthropic"])
        with pytest.raises(ValueError, match="used in fallback"):
            service.delete_provider("anthropic")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# API key management
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestApiKeyManagement:
    """``set_api_key``: .env write, YAML reference, masked response, stub creation."""

    def test_set_api_key_writes_to_env_file(self, service: LlmConfigService, config_path: Path):
        service.set_api_key("openai", "sk-brand-new-key-1234")
        env_file = config_path.parent / ".env"
        assert env_file.exists()
        env_text = env_file.read_text(encoding="utf-8")
        assert "OPENAI_API_KEY=sk-brand-new-key-1234" in env_text

    def test_set_api_key_updates_yaml_reference(self, service: LlmConfigService, config_path: Path):
        service.set_api_key("openai", "sk-brand-new-key-1234")
        on_disk = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        assert on_disk["llm"]["providers"]["openai"]["api_key"] == "${OPENAI_API_KEY}"

    def test_set_api_key_returns_masked_provider(self, service: LlmConfigService):
        returned_key = service.set_api_key("openai", "sk-brand-new-key-1234")["api_key"]
        assert returned_key.startswith("****")
        # The plaintext secret must not appear in the returned value.
        assert "sk-brand-new-key-1234" not in returned_key

    def test_set_api_key_for_new_provider_creates_stub(
        self, service: LlmConfigService, config_path: Path
    ):
        """A key set for a provider missing from the YAML creates a stub entry for it."""
        stub = service.set_api_key("gemini", "AIza-test-key-1234")
        assert stub["name"] == "gemini"
        assert stub["type"] == "openai"  # default type for stub
        assert stub["api_key"].startswith("****")
        stored_providers = yaml.safe_load(config_path.read_text(encoding="utf-8"))["llm"][
            "providers"
        ]
        assert "gemini" in stored_providers
        assert stored_providers["gemini"]["api_key"] == "${GEMINI_API_KEY}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fallback chain management
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFallbackManagement:
    """Get / set / overwrite / delete behaviour of fallback chains."""

    def test_get_fallbacks_returns_empty_when_none_configured(self, service: LlmConfigService):
        current = service.get_fallbacks()
        assert current == {}

    def test_set_fallback_adds_chain(self, service: LlmConfigService, config_path: Path):
        chain = ["openai", "anthropic"]
        outcome = service.set_fallback("gpt-4o", ["openai", "anthropic"])
        assert outcome["model"] == "gpt-4o"
        assert outcome["chain"] == chain
        # The chain must also be present in the YAML on disk.
        on_disk = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        assert on_disk["llm"]["fallbacks"]["gpt-4o"] == chain

    def test_get_fallbacks_returns_set_chain(self, service: LlmConfigService):
        service.set_fallback("gpt-4o", ["openai", "anthropic"])
        current = service.get_fallbacks()
        assert current == {"gpt-4o": ["openai", "anthropic"]}

    def test_set_fallback_overwrites_existing(self, service: LlmConfigService):
        # Second call for the same model replaces the first chain.
        service.set_fallback("gpt-4o", ["openai"])
        service.set_fallback("gpt-4o", ["anthropic", "gemini"])
        current = service.get_fallbacks()
        assert current["gpt-4o"] == ["anthropic", "gemini"]

    def test_delete_fallback_removes_chain(self, service: LlmConfigService):
        service.set_fallback("gpt-4o", ["openai"])
        removed = service.delete_fallback("gpt-4o")
        assert removed is True
        assert service.get_fallbacks() == {}

    def test_delete_fallback_returns_false_for_unset(self, service: LlmConfigService):
        removed = service.delete_fallback("nonexistent")
        assert removed is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# File locking
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileLocking:
    """Writes from several threads, and the lockfile left next to the config."""

    def test_concurrent_writes_do_not_corrupt(self, service: LlmConfigService, config_path: Path):
        """Create five providers from five threads that target the same YAML file.

        Afterwards no thread may have raised, the YAML must still parse,
        and it must list every new provider alongside the original one.
        """
        errors: list[Exception] = []

        def _worker(name: str, api_key: str) -> None:
            # One service per thread: nothing shared in memory, same file on disk.
            try:
                LlmConfigService(config_path).create_provider(
                    name=name, provider_type="openai", api_key=api_key
                )
            except Exception as exc:
                errors.append(exc)

        workers = []
        for idx in range(5):
            worker_args = (f"provider_{idx}", f"sk-key-{idx:04d}")
            workers.append(threading.Thread(target=_worker, args=worker_args))
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        # Any exception captured in a worker fails the test here.
        assert errors == [], f"Concurrent writes failed: {errors}"

        # The file must parse and hold the five new providers plus "openai".
        on_disk = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        providers = on_disk["llm"]["providers"]
        expected_names = [f"provider_{idx}" for idx in range(5)] + ["openai"]
        for expected in expected_names:
            assert expected in providers

    def test_lockfile_is_created(self, service: LlmConfigService, config_path: Path):
        """After a write, ``<config file name>.lock`` must exist beside the config."""
        service.create_provider(
            name="anthropic",
            provider_type="anthropic",
            api_key="sk-ant-test",
        )
        lock_suffix = config_path.suffix + ".lock"
        lock_path = config_path.with_suffix(lock_suffix)
        assert lock_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Singleton helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonHelpers:
    """``get_llm_config_service`` / ``set_llm_config_service`` behaviour."""

    def test_get_llm_config_service_requires_path_on_first_call(self):
        with pytest.raises(ValueError, match="config_path is required"):
            get_llm_config_service()

    def test_get_llm_config_service_returns_singleton(self, config_path: Path):
        first = get_llm_config_service(config_path)
        # No path on the second call: the same instance must come back.
        second = get_llm_config_service()
        assert first is second

    def test_set_llm_config_service_overrides_singleton(self, config_path: Path):
        injected = LlmConfigService(config_path)
        set_llm_config_service(injected)
        assert get_llm_config_service() is injected

    def test_set_llm_config_service_none_resets_singleton(self, config_path: Path):
        earlier = get_llm_config_service(config_path)
        set_llm_config_service(None)
        # With the singleton cleared, a path-less lookup must raise again.
        with pytest.raises(ValueError, match="config_path is required"):
            get_llm_config_service()
        # The instance handed out before the reset keeps working.
        assert earlier.list_providers() is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Env var resolution
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvVarResolution:
    """Masking of API keys that the YAML stores as ``${VAR}`` references."""

    def test_list_providers_masks_env_var_referenced_keys(
        self, config_path: Path, monkeypatch: pytest.MonkeyPatch
    ):
        """With ``${VAR}`` in the YAML, the listing shows the masked *resolved*
        env value rather than the literal ``${VAR}``."""
        # Overwrite the config so the key is an env-var reference.
        cfg = _sample_config()
        cfg["llm"]["providers"]["openai"]["api_key"] = "${TEST_OPENAI_KEY}"
        rendered = yaml.dump(cfg, default_flow_style=False, allow_unicode=True)
        config_path.write_text(rendered, encoding="utf-8")
        # Provide the value the reference should resolve to.
        monkeypatch.setenv("TEST_OPENAI_KEY", "sk-resolved-1234")
        listed = LlmConfigService(config_path).list_providers()
        assert listed[0]["api_key"] == "****1234"
|
||||||
|
|
@ -0,0 +1,333 @@
|
||||||
|
"""Unit tests for QuotaService (U5 — per-department LLM quotas).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Happy path: set → get → list → delete quotas
|
||||||
|
- Quota check: token_limit (allowed / denied)
|
||||||
|
- Quota check: cost_limit (allowed / denied)
|
||||||
|
- Quota check: no quota set → always allowed
|
||||||
|
- model_whitelist serialization (list → JSON → list)
|
||||||
|
- Upsert: set same quota twice updates the value
|
||||||
|
- Validation: invalid quota_type / period raises ValueError
|
||||||
|
- is_model_allowed: whitelist enforcement
|
||||||
|
- Singleton helpers (get/set_quota_service)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.server.admin.quota_service import (
|
||||||
|
QuotaService,
|
||||||
|
get_quota_service,
|
||||||
|
set_quota_service,
|
||||||
|
)
|
||||||
|
from agentkit.server.auth.models import init_auth_db
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
    """Initialise an auth DB at ``tmp_path / "auth.db"`` and return that path."""
    database_file = tmp_path / "auth.db"
    await init_auth_db(database_file)
    return database_file
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def service() -> QuotaService:
    """Provide a ``QuotaService`` built with no arguments."""
    quota_service = QuotaService()
    return quota_service
|
||||||
|
|
||||||
|
|
||||||
|
def _random_dept_id() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quota CRUD happy path
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaCrudHappyPath:
    """Set / get / list / delete round-trips on a freshly initialised DB."""

    async def test_set_token_limit_returns_quota_dict(self, service: QuotaService, fresh_db: Path):
        dept = _random_dept_id()
        stored = await service.set_quota(fresh_db, dept, "token_limit", 1000, period="daily")
        assert stored["department_id"] == dept
        assert stored["quota_type"] == "token_limit"
        assert stored["limit_value"] == 1000
        assert stored["period"] == "daily"
        assert "id" in stored
        assert "updated_at" in stored

    async def test_set_cost_limit_returns_quota_dict(self, service: QuotaService, fresh_db: Path):
        dept = _random_dept_id()
        stored = await service.set_quota(fresh_db, dept, "cost_limit", 5000, period="monthly")
        assert stored["quota_type"] == "cost_limit"
        assert stored["limit_value"] == 5000
        assert stored["period"] == "monthly"

    async def test_set_model_whitelist_serializes_list(self, service: QuotaService, fresh_db: Path):
        dept = _random_dept_id()
        models = ["gpt-4o", "claude-sonnet-4-20250514", "gemini-pro"]
        stored = await service.set_quota(
            fresh_db, dept, "model_whitelist", models, period="daily"
        )
        assert stored["quota_type"] == "model_whitelist"
        # The value handed back must equal the list that was passed in.
        assert stored["limit_value"] == models

    async def test_get_returns_set_quota(self, service: QuotaService, fresh_db: Path):
        dept = _random_dept_id()
        await service.set_quota(fresh_db, dept, "token_limit", 1000)
        found = await service.get_quota(fresh_db, dept, "token_limit")
        assert found is not None
        assert found["limit_value"] == 1000
        assert found["quota_type"] == "token_limit"

    async def test_get_returns_none_for_unset_quota(self, service: QuotaService, fresh_db: Path):
        dept = _random_dept_id()
        found = await service.get_quota(fresh_db, dept, "token_limit")
        assert found is None

    async def test_list_returns_all_quotas_for_department(
        self, service: QuotaService, fresh_db: Path
    ):
        dept = _random_dept_id()
        await service.set_quota(fresh_db, dept, "token_limit", 1000, period="daily")
        await service.set_quota(fresh_db, dept, "cost_limit", 5000, period="monthly")
        await service.set_quota(fresh_db, dept, "model_whitelist", ["gpt-4o"], period="daily")
        listed = await service.list_department_quotas(fresh_db, dept)
        assert len(listed) == 3
        seen_types = {entry["quota_type"] for entry in listed}
        assert seen_types == {"token_limit", "cost_limit", "model_whitelist"}

    async def test_list_returns_empty_for_department_with_no_quotas(
        self, service: QuotaService, fresh_db: Path
    ):
        dept = _random_dept_id()
        listed = await service.list_department_quotas(fresh_db, dept)
        assert listed == []

    async def test_delete_removes_quota(self, service: QuotaService, fresh_db: Path):
        dept = _random_dept_id()
        await service.set_quota(fresh_db, dept, "token_limit", 1000)
        removed = await service.delete_quota(fresh_db, dept, "token_limit")
        assert removed is True
        # A follow-up lookup must find nothing.
        assert await service.get_quota(fresh_db, dept, "token_limit") is None

    async def test_delete_returns_false_for_unset_quota(
        self, service: QuotaService, fresh_db: Path
    ):
        dept = _random_dept_id()
        removed = await service.delete_quota(fresh_db, dept, "token_limit")
        assert removed is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Upsert behavior
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaUpsert:
    """Repeated ``set_quota`` calls: same key updates, a different period adds a row."""

    async def test_set_same_quota_twice_updates_value(self, service: QuotaService, fresh_db: Path):
        dept = _random_dept_id()
        await service.set_quota(fresh_db, dept, "token_limit", 1000)
        second = await service.set_quota(fresh_db, dept, "token_limit", 2000)
        assert second["limit_value"] == 2000
        # The department must still have exactly one quota row.
        listed = await service.list_department_quotas(fresh_db, dept)
        assert len(listed) == 1
        assert listed[0]["limit_value"] == 2000

    async def test_set_same_quota_with_different_period_creates_new_row(
        self, service: QuotaService, fresh_db: Path
    ):
        dept = _random_dept_id()
        await service.set_quota(fresh_db, dept, "token_limit", 1000, period="daily")
        await service.set_quota(fresh_db, dept, "token_limit", 30000, period="monthly")
        listed = await service.list_department_quotas(fresh_db, dept)
        assert len(listed) == 2
        seen_periods = {entry["period"] for entry in listed}
        assert seen_periods == {"daily", "monthly"}

    async def test_upsert_model_whitelist_replaces_list(
        self, service: QuotaService, fresh_db: Path
    ):
        dept = _random_dept_id()
        await service.set_quota(fresh_db, dept, "model_whitelist", ["gpt-4o", "claude"])
        await service.set_quota(fresh_db, dept, "model_whitelist", ["gpt-4o-mini"])
        found = await service.get_quota(fresh_db, dept, "model_whitelist")
        assert found is not None
        assert found["limit_value"] == ["gpt-4o-mini"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Quota enforcement
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaCheck:
|
||||||
|
async def test_token_limit_allowed_when_under_limit(
|
||||||
|
self, service: QuotaService, fresh_db: Path
|
||||||
|
):
|
||||||
|
dept_id = _random_dept_id()
|
||||||
|
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
|
||||||
|
allowed, reason = await service.check_quota(
|
||||||
|
fresh_db, dept_id, "token_limit", "daily", current_usage=500
|
||||||
|
)
|
||||||
|
assert allowed is True
|
||||||
|
assert reason == "ok"
|
||||||
|
|
||||||
|
async def test_token_limit_denied_at_limit(self, service: QuotaService, fresh_db: Path):
|
||||||
|
dept_id = _random_dept_id()
|
||||||
|
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
|
||||||
|
allowed, reason = await service.check_quota(
|
||||||
|
fresh_db, dept_id, "token_limit", "daily", current_usage=1000
|
||||||
|
)
|
||||||
|
assert allowed is False
|
||||||
|
assert "exceeded" in reason
|
||||||
|
|
||||||
|
async def test_token_limit_denied_over_limit(self, service: QuotaService, fresh_db: Path):
|
||||||
|
dept_id = _random_dept_id()
|
||||||
|
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
|
||||||
|
allowed, reason = await service.check_quota(
|
||||||
|
fresh_db, dept_id, "token_limit", "daily", current_usage=1500
|
||||||
|
)
|
||||||
|
assert allowed is False
|
||||||
|
assert "exceeded" in reason
|
||||||
|
|
||||||
|
async def test_cost_limit_allowed_when_under_limit(self, service: QuotaService, fresh_db: Path):
|
||||||
|
dept_id = _random_dept_id()
|
||||||
|
await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
|
||||||
|
allowed, reason = await service.check_quota(
|
||||||
|
fresh_db, dept_id, "cost_limit", "monthly", current_usage=2500
|
||||||
|
)
|
||||||
|
assert allowed is True
|
||||||
|
assert reason == "ok"
|
||||||
|
|
||||||
|
async def test_cost_limit_denied_at_limit(self, service: QuotaService, fresh_db: Path):
|
||||||
|
dept_id = _random_dept_id()
|
||||||
|
await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
|
||||||
|
allowed, _ = await service.check_quota(
|
||||||
|
fresh_db, dept_id, "cost_limit", "monthly", current_usage=5000
|
||||||
|
)
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
async def test_no_quota_set_always_allowed(self, service: QuotaService, fresh_db: Path):
|
||||||
|
dept_id = _random_dept_id()
|
||||||
|
allowed, reason = await service.check_quota(
|
||||||
|
fresh_db, dept_id, "token_limit", "daily", current_usage=999999
|
||||||
|
)
|
||||||
|
assert allowed is True
|
||||||
|
assert reason == "ok"
|
||||||
|
|
||||||
|
async def test_model_whitelist_not_a_quota_check(self, service: QuotaService, fresh_db: Path):
    """check_quota on a model_whitelist entry is denied and the reason names is_model_allowed."""
    department = _random_dept_id()
    await service.set_quota(fresh_db, department, "model_whitelist", ["gpt-4o"], period="daily")

    outcome = await service.check_quota(
        fresh_db,
        department,
        "model_whitelist",
        "daily",
        current_usage=0,
    )
    is_allowed, message = outcome

    assert is_allowed is False
    assert "is_model_allowed" in message
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Model whitelist enforcement
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelWhitelist:
    """Checks of ``QuotaService.is_model_allowed`` against a department's model whitelist."""

    async def test_is_model_allowed_when_in_whitelist(self, service: QuotaService, fresh_db: Path):
        """A model listed in the whitelist is allowed with reason "ok"."""
        department = _random_dept_id()
        await service.set_quota(fresh_db, department, "model_whitelist", ["gpt-4o", "claude"])

        verdict = await service.is_model_allowed(fresh_db, department, "gpt-4o")
        is_allowed, message = verdict

        assert is_allowed is True
        assert message == "ok"

    async def test_is_model_denied_when_not_in_whitelist(
        self, service: QuotaService, fresh_db: Path
    ):
        """A model missing from the whitelist is denied with a "not in" reason."""
        department = _random_dept_id()
        await service.set_quota(fresh_db, department, "model_whitelist", ["gpt-4o", "claude"])

        verdict = await service.is_model_allowed(fresh_db, department, "gemini-pro")
        is_allowed, message = verdict

        assert is_allowed is False
        assert "not in" in message

    async def test_is_model_allowed_when_no_whitelist(self, service: QuotaService, fresh_db: Path):
        """Without any whitelist stored, an arbitrary model name is allowed."""
        department = _random_dept_id()

        # Nothing was stored for this department, so every model should pass.
        verdict = await service.is_model_allowed(fresh_db, department, "any-model")
        is_allowed, message = verdict

        assert is_allowed is True
        assert message == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Validation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaValidation:
    """Argument validation: bad quota types, periods and value shapes raise ``ValueError``."""

    async def test_set_invalid_quota_type_raises(self, service: QuotaService, fresh_db: Path):
        """set_quota rejects an unknown quota type."""
        department = _random_dept_id()
        expect_error = pytest.raises(ValueError, match="Invalid quota_type")
        with expect_error:
            await service.set_quota(fresh_db, department, "invalid_type", 1000)

    async def test_set_invalid_period_raises(self, service: QuotaService, fresh_db: Path):
        """set_quota rejects the period "weekly"."""
        department = _random_dept_id()
        expect_error = pytest.raises(ValueError, match="Invalid period")
        with expect_error:
            await service.set_quota(fresh_db, department, "token_limit", 1000, period="weekly")

    async def test_set_token_limit_with_list_raises(self, service: QuotaService, fresh_db: Path):
        """set_quota rejects a list value for a token_limit quota."""
        department = _random_dept_id()
        expect_error = pytest.raises(ValueError, match="must be an int")
        with expect_error:
            await service.set_quota(fresh_db, department, "token_limit", ["gpt-4o"])

    async def test_set_model_whitelist_with_int_raises(self, service: QuotaService, fresh_db: Path):
        """set_quota rejects an int value for a model_whitelist quota."""
        department = _random_dept_id()
        expect_error = pytest.raises(ValueError, match="must be a list")
        with expect_error:
            await service.set_quota(fresh_db, department, "model_whitelist", 1000)

    async def test_get_invalid_quota_type_raises(self, service: QuotaService, fresh_db: Path):
        """get_quota rejects an unknown quota type."""
        department = _random_dept_id()
        expect_error = pytest.raises(ValueError, match="Invalid quota_type")
        with expect_error:
            await service.get_quota(fresh_db, department, "invalid_type")

    async def test_delete_invalid_period_raises(self, service: QuotaService, fresh_db: Path):
        """delete_quota rejects the period "weekly"."""
        department = _random_dept_id()
        expect_error = pytest.raises(ValueError, match="Invalid period")
        with expect_error:
            await service.delete_quota(fresh_db, department, "token_limit", period="weekly")

    async def test_check_quota_invalid_quota_type_raises(
        self, service: QuotaService, fresh_db: Path
    ):
        """check_quota rejects an unknown quota type."""
        department = _random_dept_id()
        expect_error = pytest.raises(ValueError, match="Invalid quota_type")
        with expect_error:
            await service.check_quota(
                fresh_db,
                department,
                "invalid_type",
                "daily",
                current_usage=0,
            )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Singleton helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonHelpers:
    """Exercises the ``get_quota_service`` / ``set_quota_service`` module-level accessors."""

    def test_get_quota_service_returns_singleton(self):
        """An injected service is returned by the getter; clearing it yields a different instance."""
        # Remember whatever is currently installed so other tests are unaffected.
        saved = get_quota_service()
        try:
            injected = QuotaService()
            set_quota_service(injected)
            assert get_quota_service() is injected

            # After clearing, the getter must hand back something other than
            # the instance that was injected above.
            set_quota_service(None)
            replacement = get_quota_service()
            assert replacement is not injected
        finally:
            set_quota_service(saved)
|
||||||
Loading…
Reference in New Issue