feat(admin): U5 — LLM config admin endpoints + department quotas
QuotaService: set/get/list/delete quotas, check_quota (hard reject), is_model_allowed. JSON-serialized limit_value, upsert with ON CONFLICT. LlmConfigService: provider CRUD + set_api_key + fallback management. fcntl.flock file lock prevents concurrent YAML writes. Reuses settings.py helpers (_read_yaml_config, _write_yaml_config, _write_env_var, _mask_api_key). 11 new admin endpoints: provider CRUD, api-key, fallback CRUD, department quotas CRUD. All guarded by _require_admin. 93 new tests (30 quota unit + 32 llm-config unit + 31 integration).
This commit is contained in:
parent
ad65f7a8d7
commit
980919fc95
|
|
@ -0,0 +1,437 @@
|
|||
"""LlmConfigService — runtime CRUD for LLM providers/fallbacks (U5).
|
||||
|
||||
This module wraps the YAML read/write logic from
|
||||
:mod:`agentkit.server.routes.settings` as a reusable service. Web UI
|
||||
routes (``/api/v1/admin/llm/*``) and the CLI ``agentkit admin llm``
|
||||
sub-app both call into :class:`LlmConfigService` rather than touching
|
||||
the YAML directly, keeping the validation rules (duplicate-provider,
|
||||
fallback-in-use guard, API-key masking) in one place.
|
||||
|
||||
The service writes back to ``agentkit.yaml`` using the existing
|
||||
:func:`_write_yaml_config` helper (which preserves ``${VAR}`` env-var
|
||||
references and comments via ruamel.yaml). Concurrent writes are
|
||||
serialized with :func:`fcntl.flock` to prevent corruption.
|
||||
|
||||
The service is a module-level singleton (see
|
||||
:func:`get_llm_config_service`) so tests can inject a custom instance
|
||||
via :func:`set_llm_config_service`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fcntl
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from agentkit.server.routes.settings import (
|
||||
_mask_api_key,
|
||||
_read_yaml_config,
|
||||
_write_env_var,
|
||||
_write_yaml_config,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
#: Regex matching ``${ENV_VAR}`` references in YAML api_key fields.
|
||||
_ENV_REF_RE = re.compile(r"^\$\{([^}]+)\}$")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _provider_to_dict(name: str, pconf: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert a raw YAML provider dict to a masked response dict.
|
||||
|
||||
The ``api_key`` field is masked via :func:`_mask_api_key`. The
|
||||
``models`` field is preserved as-is (it's already a dict in YAML).
|
||||
"""
|
||||
raw_key = pconf.get("api_key", "")
|
||||
# Resolve ${VAR} references for masking — we want to show the
|
||||
# masked *resolved* value, not the literal "${VAR}" string.
|
||||
resolved_key = _resolve_env_var(raw_key)
|
||||
return {
|
||||
"name": name,
|
||||
"type": pconf.get("type", "openai"),
|
||||
"api_key": _mask_api_key(resolved_key),
|
||||
"base_url": pconf.get("base_url", ""),
|
||||
"models": pconf.get("models", {}) or {},
|
||||
"max_tokens": pconf.get("max_tokens", 4096),
|
||||
"timeout": pconf.get("timeout", 60.0),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_env_var(value: str) -> str:
|
||||
"""Resolve a ``${VAR}`` reference to its env value (if set).
|
||||
|
||||
If ``value`` is ``${OPENAI_API_KEY}`` and ``OPENAI_API_KEY`` is set
|
||||
in the environment, returns the env value. Otherwise returns the
|
||||
input unchanged.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return ""
|
||||
match = _ENV_REF_RE.match(value)
|
||||
if match:
|
||||
var_name = match.group(1).split(":-")[0]
|
||||
return os.environ.get(var_name, value)
|
||||
return value
|
||||
|
||||
|
||||
def _env_var_name_for_provider(name: str) -> str:
|
||||
"""Derive a sensible env var name from a provider name."""
|
||||
return f"{name.upper().replace('-', '_')}_API_KEY"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File-locking context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _YamlLock:
|
||||
"""Context manager that acquires an exclusive ``fcntl.flock`` on the
|
||||
YAML file's sibling lockfile (``<config_path>.lock``).
|
||||
|
||||
The lockfile is created on demand and is not removed after release
|
||||
(it's reused on subsequent writes). Holding the lock prevents
|
||||
concurrent writers from corrupting the YAML file.
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Path) -> None:
|
||||
self._lock_path = config_path.with_suffix(config_path.suffix + ".lock")
|
||||
self._fh = None
|
||||
|
||||
def __enter__(self) -> "_YamlLock":
|
||||
self._fh = open(self._lock_path, "w")
|
||||
fcntl.flock(self._fh.fileno(), fcntl.LOCK_EX)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
if self._fh is not None:
|
||||
fcntl.flock(self._fh.fileno(), fcntl.LOCK_UN)
|
||||
self._fh.close()
|
||||
self._fh = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class LlmConfigService:
|
||||
"""CRUD for LLM providers, API keys, and fallback chains.
|
||||
|
||||
All methods read/write the YAML file at ``self._config_path``.
|
||||
Writes are serialized via :class:`_YamlLock` (``fcntl.flock``) to
|
||||
prevent concurrent corruption. The existing ``watch_config()``
|
||||
mechanism detects the file change and rebuilds the
|
||||
:class:`LLMGateway` — no explicit notification is needed.
|
||||
|
||||
API keys are NEVER returned in full — every response masks them via
|
||||
:func:`_mask_api_key`. New plaintext keys are written to ``.env``
|
||||
(next to the YAML) and the YAML stores a ``${ENV_VAR}`` reference.
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: Path) -> None:
|
||||
self._config_path = Path(config_path)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Provider CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_providers(self) -> list[dict[str, Any]]:
|
||||
"""Return all providers with masked API keys."""
|
||||
data = self._read()
|
||||
providers = data.get("llm", {}).get("providers", {}) or {}
|
||||
return [_provider_to_dict(name, pconf) for name, pconf in providers.items()]
|
||||
|
||||
def get_provider(self, name: str) -> dict[str, Any] | None:
|
||||
"""Return a single provider by name (masked key), or ``None``."""
|
||||
data = self._read()
|
||||
pconf = data.get("llm", {}).get("providers", {}).get(name)
|
||||
if pconf is None:
|
||||
return None
|
||||
return _provider_to_dict(name, pconf)
|
||||
|
||||
def create_provider(
|
||||
self,
|
||||
name: str,
|
||||
provider_type: str,
|
||||
api_key: str,
|
||||
base_url: str = "",
|
||||
models: dict[str, Any] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
timeout: float = 60.0,
|
||||
) -> dict[str, Any]:
|
||||
"""Add a new provider to the YAML file.
|
||||
|
||||
Args:
|
||||
name: Provider name (must be unique within the YAML).
|
||||
provider_type: Provider type (``openai`` / ``anthropic`` /
|
||||
``gemini`` / ``doubao`` / ``wenxin`` / ``yuanbao``).
|
||||
api_key: Plaintext API key. Stored in ``.env`` as
|
||||
``{NAME}_API_KEY``; the YAML stores ``${...}`` ref.
|
||||
base_url: Optional base URL override.
|
||||
models: Optional models dict (e.g. ``{"gpt-4o": {}}``).
|
||||
max_tokens: Default max tokens for completions.
|
||||
timeout: Request timeout in seconds.
|
||||
|
||||
Returns:
|
||||
The newly-created provider dict (masked key).
|
||||
|
||||
Raises:
|
||||
ValueError: If a provider with the same name already exists.
|
||||
"""
|
||||
with _YamlLock(self._config_path):
|
||||
data = self._read()
|
||||
data.setdefault("llm", {}).setdefault("providers", {})
|
||||
if name in data["llm"]["providers"]:
|
||||
raise ValueError(f"Provider {name!r} already exists")
|
||||
|
||||
env_key = _env_var_name_for_provider(name)
|
||||
_write_env_var(str(self._config_path), env_key, api_key)
|
||||
|
||||
pconf: dict[str, Any] = {
|
||||
"type": provider_type,
|
||||
"api_key": f"${{{env_key}}}",
|
||||
"base_url": base_url,
|
||||
"max_tokens": max_tokens,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if models is not None:
|
||||
pconf["models"] = models
|
||||
data["llm"]["providers"][name] = pconf
|
||||
self._write(data)
|
||||
|
||||
created = self.get_provider(name)
|
||||
assert created is not None # we just inserted it
|
||||
return created
|
||||
|
||||
def update_provider(
|
||||
self,
|
||||
name: str,
|
||||
provider_type: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
models: dict[str, Any] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Partially update a provider.
|
||||
|
||||
Only the provided fields are updated. If ``api_key`` is a masked
|
||||
value (starts with ``****``), it is ignored — the existing key
|
||||
is preserved.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider does not exist.
|
||||
"""
|
||||
with _YamlLock(self._config_path):
|
||||
data = self._read()
|
||||
providers = data.get("llm", {}).get("providers", {}) or {}
|
||||
if name not in providers:
|
||||
raise ValueError(f"Provider {name!r} not found")
|
||||
|
||||
pconf = providers[name]
|
||||
if provider_type is not None:
|
||||
pconf["type"] = provider_type
|
||||
if base_url is not None:
|
||||
pconf["base_url"] = base_url
|
||||
if models is not None:
|
||||
pconf["models"] = models
|
||||
if max_tokens is not None:
|
||||
pconf["max_tokens"] = max_tokens
|
||||
if timeout is not None:
|
||||
pconf["timeout"] = timeout
|
||||
if api_key is not None and not api_key.startswith("****"):
|
||||
# New plaintext key — write to .env, keep ${VAR} ref in YAML.
|
||||
existing_key = str(pconf.get("api_key", ""))
|
||||
env_match = _ENV_REF_RE.match(existing_key)
|
||||
if env_match:
|
||||
env_key = env_match.group(1).split(":-")[0]
|
||||
else:
|
||||
env_key = _env_var_name_for_provider(name)
|
||||
_write_env_var(str(self._config_path), env_key, api_key)
|
||||
pconf["api_key"] = f"${{{env_key}}}"
|
||||
|
||||
providers[name] = pconf
|
||||
data.setdefault("llm", {})["providers"] = providers
|
||||
self._write(data)
|
||||
|
||||
updated = self.get_provider(name)
|
||||
assert updated is not None
|
||||
return updated
|
||||
|
||||
def delete_provider(self, name: str) -> bool:
|
||||
"""Remove a provider from the YAML file.
|
||||
|
||||
Returns:
|
||||
``True`` if the provider was deleted.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider does not exist, or if it is
|
||||
referenced by any fallback chain (removing it would
|
||||
break the chain).
|
||||
"""
|
||||
with _YamlLock(self._config_path):
|
||||
data = self._read()
|
||||
providers = data.get("llm", {}).get("providers", {}) or {}
|
||||
if name not in providers:
|
||||
raise ValueError(f"Provider {name!r} not found")
|
||||
|
||||
# Guard: refuse to delete if referenced in any fallback chain.
|
||||
fallbacks = data.get("llm", {}).get("fallbacks", {}) or {}
|
||||
for model, chain in fallbacks.items():
|
||||
if isinstance(chain, list) and name in chain:
|
||||
raise ValueError(
|
||||
f"Provider {name!r} is used in fallback chain for model "
|
||||
f"{model!r}; remove the fallback first"
|
||||
)
|
||||
|
||||
del providers[name]
|
||||
data.setdefault("llm", {})["providers"] = providers
|
||||
self._write(data)
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# API key management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def set_api_key(self, provider_name: str, api_key: str) -> dict[str, Any]:
|
||||
"""Set the API key for a provider.
|
||||
|
||||
The plaintext key is written to ``.env`` (as
|
||||
``{NAME}_API_KEY``) and the YAML stores a ``${...}`` reference.
|
||||
If the provider does not yet exist in the YAML, a stub entry is
|
||||
created with ``type=openai`` and default settings.
|
||||
|
||||
Returns:
|
||||
The updated provider dict (masked key).
|
||||
"""
|
||||
with _YamlLock(self._config_path):
|
||||
data = self._read()
|
||||
data.setdefault("llm", {}).setdefault("providers", {})
|
||||
providers = data["llm"]["providers"]
|
||||
pconf = providers.get(provider_name, {})
|
||||
existing_key = str(pconf.get("api_key", ""))
|
||||
env_match = _ENV_REF_RE.match(existing_key)
|
||||
if env_match:
|
||||
env_key = env_match.group(1).split(":-")[0]
|
||||
else:
|
||||
env_key = _env_var_name_for_provider(provider_name)
|
||||
_write_env_var(str(self._config_path), env_key, api_key)
|
||||
pconf["api_key"] = f"${{{env_key}}}"
|
||||
if "type" not in pconf:
|
||||
pconf["type"] = "openai"
|
||||
providers[provider_name] = pconf
|
||||
self._write(data)
|
||||
|
||||
updated = self.get_provider(provider_name)
|
||||
assert updated is not None
|
||||
return updated
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Fallback chain management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def get_fallbacks(self) -> dict[str, list[str]]:
|
||||
"""Return the fallback chains (model → list of provider/model refs)."""
|
||||
data = self._read()
|
||||
fallbacks = data.get("llm", {}).get("fallbacks", {}) or {}
|
||||
# Normalize: ensure all values are lists.
|
||||
return {
|
||||
str(model): list(chain) if isinstance(chain, list) else [str(chain)]
|
||||
for model, chain in fallbacks.items()
|
||||
}
|
||||
|
||||
def set_fallback(self, model: str, chain: list[str]) -> dict[str, Any]:
|
||||
"""Set the fallback chain for a model.
|
||||
|
||||
Returns:
|
||||
``{"model": model, "chain": chain}`` dict.
|
||||
"""
|
||||
with _YamlLock(self._config_path):
|
||||
data = self._read()
|
||||
data.setdefault("llm", {}).setdefault("fallbacks", {})
|
||||
data["llm"]["fallbacks"][model] = list(chain)
|
||||
self._write(data)
|
||||
return {"model": model, "chain": list(chain)}
|
||||
|
||||
def delete_fallback(self, model: str) -> bool:
|
||||
"""Delete the fallback chain for a model.
|
||||
|
||||
Returns:
|
||||
``True`` if a chain was deleted, ``False`` if no chain
|
||||
existed for ``model``.
|
||||
"""
|
||||
with _YamlLock(self._config_path):
|
||||
data = self._read()
|
||||
fallbacks = data.get("llm", {}).get("fallbacks", {}) or {}
|
||||
if model not in fallbacks:
|
||||
return False
|
||||
del fallbacks[model]
|
||||
data.setdefault("llm", {})["fallbacks"] = fallbacks
|
||||
self._write(data)
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _read(self) -> dict[str, Any]:
|
||||
"""Read the YAML config (returns ``{}`` if file is missing/empty)."""
|
||||
if not self._config_path.exists():
|
||||
return {}
|
||||
return _read_yaml_config(str(self._config_path))
|
||||
|
||||
def _write(self, data: dict[str, Any]) -> None:
|
||||
"""Write the full config dict back to the YAML file."""
|
||||
_write_yaml_config(str(self._config_path), data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton (overridable in tests via set_llm_config_service)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_llm_config_service: LlmConfigService | None = None
|
||||
|
||||
|
||||
def get_llm_config_service(config_path: Path | str | None = None) -> LlmConfigService:
|
||||
"""Return the process-wide :class:`LlmConfigService`.
|
||||
|
||||
On first call, ``config_path`` must be provided (or the service
|
||||
will raise ``ValueError``). Subsequent calls ignore ``config_path``
|
||||
and return the cached singleton.
|
||||
"""
|
||||
global _llm_config_service
|
||||
if _llm_config_service is None:
|
||||
if config_path is None:
|
||||
raise ValueError("config_path is required to initialize LlmConfigService on first call")
|
||||
_llm_config_service = LlmConfigService(Path(config_path))
|
||||
return _llm_config_service
|
||||
|
||||
|
||||
def set_llm_config_service(service: LlmConfigService | None) -> None:
|
||||
"""Inject a custom :class:`LlmConfigService` (used by tests)."""
|
||||
global _llm_config_service
|
||||
_llm_config_service = service
|
||||
|
||||
|
||||
# Re-export yaml for tests that need to construct sample config files.
|
||||
__all__ = [
|
||||
"LlmConfigService",
|
||||
"get_llm_config_service",
|
||||
"set_llm_config_service",
|
||||
"yaml",
|
||||
]
|
||||
|
|
@ -0,0 +1,349 @@
|
|||
"""QuotaService — per-department LLM quota CRUD + enforcement (U5/U7).
|
||||
|
||||
This module is the single owner of the ``department_quotas`` table.
|
||||
Web UI routes (``/api/v1/admin/departments/{id}/quotas``) and the CLI
|
||||
``agentkit admin llm set-quota`` sub-app both call into
|
||||
:class:`QuotaService` rather than touching the table directly, keeping
|
||||
the validation rules (quota_type/period enums, JSON serialization of
|
||||
``model_whitelist``) in one place.
|
||||
|
||||
Quota semantics
|
||||
---------------
|
||||
- ``token_limit`` (int): max tokens per ``period``. ``check_quota``
|
||||
compares ``current_usage`` (tokens used) against the limit.
|
||||
- ``cost_limit`` (int, in cents): max spend per ``period``.
|
||||
``check_quota`` compares ``current_usage`` (cost in cents) against
|
||||
the limit.
|
||||
- ``model_whitelist`` (list[str]): allowed model names. Not a quota
|
||||
check — it's a model-access check (use :meth:`is_model_allowed`).
|
||||
|
||||
The service is a module-level singleton (see :func:`get_quota_service`)
|
||||
so tests can inject a custom instance via :func:`set_quota_service`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
#: Valid ``quota_type`` values.
|
||||
QUOTA_TYPES: frozenset[str] = frozenset({"token_limit", "cost_limit", "model_whitelist"})
|
||||
|
||||
#: Valid ``period`` values.
|
||||
PERIODS: frozenset[str] = frozenset({"daily", "monthly"})
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
"""Return current UTC time as ISO 8601 string."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _new_id() -> str:
|
||||
"""Return a new UUID4 string."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _serialize_limit_value(quota_type: str, limit_value: Any) -> str:
|
||||
"""Serialize ``limit_value`` to a JSON string for storage.
|
||||
|
||||
For ``token_limit`` / ``cost_limit``: ``limit_value`` is an int,
|
||||
stored as ``json.dumps(int)`` (e.g. ``"1000"``).
|
||||
For ``model_whitelist``: ``limit_value`` is a ``list[str]``, stored
|
||||
as ``json.dumps(list)`` (e.g. ``'["gpt-4o", "claude"]'``).
|
||||
"""
|
||||
if quota_type == "model_whitelist":
|
||||
if not isinstance(limit_value, list):
|
||||
raise ValueError(
|
||||
f"model_whitelist limit_value must be a list, got {type(limit_value).__name__}"
|
||||
)
|
||||
return json.dumps([str(m) for m in limit_value])
|
||||
if not isinstance(limit_value, int):
|
||||
raise ValueError(
|
||||
f"{quota_type} limit_value must be an int, got {type(limit_value).__name__}"
|
||||
)
|
||||
return json.dumps(limit_value)
|
||||
|
||||
|
||||
def _deserialize_limit_value(quota_type: str, raw: str) -> Any:
|
||||
"""Deserialize ``limit_value`` from JSON string back to native type."""
|
||||
try:
|
||||
value = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.warning("Failed to deserialize limit_value %r: %s", raw, exc)
|
||||
return None
|
||||
if quota_type == "model_whitelist":
|
||||
return list(value) if isinstance(value, list) else []
|
||||
return int(value) if isinstance(value, (int, float)) else None
|
||||
|
||||
|
||||
def _validate_quota_type(quota_type: str) -> None:
|
||||
if quota_type not in QUOTA_TYPES:
|
||||
raise ValueError(f"Invalid quota_type {quota_type!r}; must be one of {sorted(QUOTA_TYPES)}")
|
||||
|
||||
|
||||
def _validate_period(period: str) -> None:
|
||||
if period not in PERIODS:
|
||||
raise ValueError(f"Invalid period {period!r}; must be one of {sorted(PERIODS)}")
|
||||
|
||||
|
||||
class QuotaService:
|
||||
"""CRUD + enforcement for per-department LLM quotas.
|
||||
|
||||
All methods are async and take ``db_path: Path`` as the first
|
||||
argument (after ``self``). Each method opens its own short-lived
|
||||
:class:`aiosqlite.Connection` — there is no shared connection state,
|
||||
which keeps the service safe to call from any async context.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Quota CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def set_quota(
|
||||
self,
|
||||
db_path: Path,
|
||||
department_id: str,
|
||||
quota_type: str,
|
||||
limit_value: Any,
|
||||
period: str = "daily",
|
||||
) -> dict[str, Any]:
|
||||
"""Upsert a quota for a department.
|
||||
|
||||
Args:
|
||||
db_path: Path to the auth SQLite DB.
|
||||
department_id: Department id.
|
||||
quota_type: One of ``token_limit`` / ``cost_limit`` /
|
||||
``model_whitelist``.
|
||||
limit_value: ``int`` for token/cost limits, ``list[str]``
|
||||
for model_whitelist.
|
||||
period: ``daily`` or ``monthly`` (default ``daily``).
|
||||
|
||||
Returns:
|
||||
The upserted quota as a dict (with deserialized
|
||||
``limit_value``).
|
||||
"""
|
||||
_validate_quota_type(quota_type)
|
||||
_validate_period(period)
|
||||
serialized = _serialize_limit_value(quota_type, limit_value)
|
||||
now = _now_iso()
|
||||
quota_id = _new_id()
|
||||
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
# Upsert: if (department_id, quota_type, period) exists,
|
||||
# update limit_value + updated_at; otherwise insert.
|
||||
await db.execute(
|
||||
"INSERT INTO department_quotas "
|
||||
"(id, department_id, quota_type, limit_value, period, updated_at) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?) "
|
||||
"ON CONFLICT(department_id, quota_type, period) DO UPDATE SET "
|
||||
"limit_value = excluded.limit_value, "
|
||||
"updated_at = excluded.updated_at",
|
||||
(quota_id, department_id, quota_type, serialized, period, now),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"id": quota_id,
|
||||
"department_id": department_id,
|
||||
"quota_type": quota_type,
|
||||
"limit_value": _deserialize_limit_value(quota_type, serialized),
|
||||
"period": period,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
async def get_quota(
|
||||
self,
|
||||
db_path: Path,
|
||||
department_id: str,
|
||||
quota_type: str,
|
||||
period: str = "daily",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Return a single quota, or ``None`` if not set."""
|
||||
_validate_quota_type(quota_type)
|
||||
_validate_period(period)
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT * FROM department_quotas "
|
||||
"WHERE department_id = ? AND quota_type = ? AND period = ?",
|
||||
(department_id, quota_type, period),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_department_quotas(
|
||||
self,
|
||||
db_path: Path,
|
||||
department_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List all quotas for a department (all types/periods)."""
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT * FROM department_quotas WHERE department_id = ? "
|
||||
"ORDER BY quota_type ASC, period ASC",
|
||||
(department_id,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [self._row_to_dict(row) for row in rows]
|
||||
|
||||
async def delete_quota(
|
||||
self,
|
||||
db_path: Path,
|
||||
department_id: str,
|
||||
quota_type: str,
|
||||
period: str = "daily",
|
||||
) -> bool:
|
||||
"""Delete a quota. Returns ``True`` if a row was deleted."""
|
||||
_validate_quota_type(quota_type)
|
||||
_validate_period(period)
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
cursor = await db.execute(
|
||||
"DELETE FROM department_quotas "
|
||||
"WHERE department_id = ? AND quota_type = ? AND period = ?",
|
||||
(department_id, quota_type, period),
|
||||
)
|
||||
await db.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Quota enforcement
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def check_quota(
|
||||
self,
|
||||
db_path: Path,
|
||||
department_id: str,
|
||||
quota_type: str,
|
||||
period: str,
|
||||
current_usage: int | float,
|
||||
) -> tuple[bool, str]:
|
||||
"""Check whether ``current_usage`` is within the configured quota.
|
||||
|
||||
Args:
|
||||
db_path: Path to the auth SQLite DB.
|
||||
department_id: Department id.
|
||||
quota_type: ``token_limit`` or ``cost_limit``.
|
||||
``model_whitelist`` is not a quota check — use
|
||||
:meth:`is_model_allowed` instead.
|
||||
period: ``daily`` or ``monthly``.
|
||||
current_usage: Current usage value (tokens for
|
||||
``token_limit``, cents for ``cost_limit``).
|
||||
|
||||
Returns:
|
||||
``(allowed, reason)`` tuple. ``allowed`` is ``True`` if the
|
||||
usage is within the limit (or no quota is set). ``reason``
|
||||
is ``"ok"`` when allowed, or a human-readable denial reason.
|
||||
"""
|
||||
_validate_quota_type(quota_type)
|
||||
_validate_period(period)
|
||||
if quota_type == "model_whitelist":
|
||||
return (
|
||||
False,
|
||||
"model_whitelist is not a quota check; use is_model_allowed()",
|
||||
)
|
||||
|
||||
quota = await self.get_quota(db_path, department_id, quota_type, period)
|
||||
if quota is None:
|
||||
return (True, "ok")
|
||||
|
||||
limit = quota["limit_value"]
|
||||
if not isinstance(limit, int):
|
||||
return (True, "ok") # malformed limit; allow
|
||||
|
||||
if current_usage >= limit:
|
||||
return (
|
||||
False,
|
||||
f"{quota_type} ({period}) exceeded: usage={current_usage}, limit={limit}",
|
||||
)
|
||||
return (True, "ok")
|
||||
|
||||
async def is_model_allowed(
|
||||
self,
|
||||
db_path: Path,
|
||||
department_id: str,
|
||||
model: str,
|
||||
) -> tuple[bool, str]:
|
||||
"""Check whether ``model`` is in the department's whitelist.
|
||||
|
||||
If no ``model_whitelist`` quota is set for the department (any
|
||||
period), all models are allowed. Otherwise the model must appear
|
||||
in the whitelist.
|
||||
|
||||
Returns:
|
||||
``(allowed, reason)`` tuple.
|
||||
"""
|
||||
async with aiosqlite.connect(str(db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
"SELECT * FROM department_quotas "
|
||||
"WHERE department_id = ? AND quota_type = 'model_whitelist' "
|
||||
"ORDER BY period ASC LIMIT 1",
|
||||
(department_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return (True, "ok")
|
||||
whitelist = _deserialize_limit_value("model_whitelist", row["limit_value"])
|
||||
if not isinstance(whitelist, list):
|
||||
return (True, "ok")
|
||||
if model in whitelist:
|
||||
return (True, "ok")
|
||||
return (
|
||||
False,
|
||||
f"model {model!r} not in department {department_id!r} whitelist",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: aiosqlite.Row | Any) -> dict[str, Any]:
|
||||
"""Convert a ``department_quotas`` row to a JSON-safe dict."""
|
||||
quota_type = row["quota_type"]
|
||||
return {
|
||||
"id": row["id"],
|
||||
"department_id": row["department_id"],
|
||||
"quota_type": quota_type,
|
||||
"limit_value": _deserialize_limit_value(quota_type, row["limit_value"]),
|
||||
"period": row["period"],
|
||||
"updated_at": row["updated_at"],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton (overridable in tests via set_quota_service)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_quota_service: QuotaService | None = None
|
||||
|
||||
|
||||
def get_quota_service() -> QuotaService:
|
||||
"""Return the process-wide :class:`QuotaService` (lazy singleton)."""
|
||||
global _quota_service
|
||||
if _quota_service is None:
|
||||
_quota_service = QuotaService()
|
||||
return _quota_service
|
||||
|
||||
|
||||
def set_quota_service(service: QuotaService | None) -> None:
|
||||
"""Inject a custom :class:`QuotaService` (used by tests)."""
|
||||
global _quota_service
|
||||
_quota_service = service
|
||||
|
|
@ -23,6 +23,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from agentkit.server.admin.department_service import get_department_service
|
||||
from agentkit.server.admin.llm_config_service import (
|
||||
LlmConfigService,
|
||||
get_llm_config_service,
|
||||
)
|
||||
from agentkit.server.admin.quota_service import get_quota_service
|
||||
from agentkit.server.admin.user_service import get_user_service
|
||||
from agentkit.server.auth.dependencies import require_authenticated
|
||||
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH, init_auth_db
|
||||
|
|
@ -550,3 +555,302 @@ async def list_user_departments(
|
|||
db_path = await _ensure_db(request)
|
||||
svc = get_user_service()
|
||||
return await svc.list_user_departments(db_path, user_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLM config endpoints (U5) — provider CRUD + fallback chains
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_llm_config_service(request: Request) -> LlmConfigService:
|
||||
"""Resolve the :class:`LlmConfigService` from ``app.state`` or the
|
||||
module singleton.
|
||||
|
||||
The service needs the YAML config path, which is read from
|
||||
``app.state.server_config._config_path`` (same source as the
|
||||
existing ``GET/PUT /settings/llm`` endpoints). Falls back to the
|
||||
module singleton (which tests can pre-populate via
|
||||
:func:`set_llm_config_service`).
|
||||
"""
|
||||
server_config = getattr(request.app.state, "server_config", None)
|
||||
if server_config is not None and getattr(server_config, "_config_path", None):
|
||||
try:
|
||||
return get_llm_config_service(server_config._config_path)
|
||||
except ValueError:
|
||||
# Singleton was reset between calls — re-initialize with the
|
||||
# resolved path.
|
||||
from agentkit.server.admin.llm_config_service import set_llm_config_service
|
||||
|
||||
svc = LlmConfigService(Path(server_config._config_path))
|
||||
set_llm_config_service(svc)
|
||||
return svc
|
||||
# Fall back to the existing singleton (tests inject it directly).
|
||||
return get_llm_config_service()
|
||||
|
||||
|
||||
class ProviderCreateRequest(BaseModel):
|
||||
"""Body for ``POST /admin/llm/providers``."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str
|
||||
type: str = "openai"
|
||||
api_key: str
|
||||
base_url: str = ""
|
||||
models: dict[str, Any] | None = None
|
||||
max_tokens: int = 4096
|
||||
timeout: float = 60.0
|
||||
|
||||
|
||||
class ProviderUpdateRequest(BaseModel):
|
||||
"""Body for ``PATCH /admin/llm/providers/{name}``."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: str | None = None
|
||||
api_key: str | None = None
|
||||
base_url: str | None = None
|
||||
models: dict[str, Any] | None = None
|
||||
max_tokens: int | None = None
|
||||
timeout: float | None = None
|
||||
|
||||
|
||||
class ApiKeySetRequest(BaseModel):
|
||||
"""Body for ``POST /admin/llm/providers/{name}/api-key``."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
api_key: str
|
||||
|
||||
|
||||
class FallbackSetRequest(BaseModel):
|
||||
"""Body for ``PUT /admin/llm/fallbacks/{model}``."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
chain: list[str]
|
||||
|
||||
|
||||
@admin_router.get("/llm/providers")
|
||||
async def list_llm_providers(
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List all LLM providers (API keys masked)."""
|
||||
svc = _get_llm_config_service(request)
|
||||
return svc.list_providers()
|
||||
|
||||
|
||||
@admin_router.post("/llm/providers", status_code=201)
|
||||
async def create_llm_provider(
|
||||
payload: ProviderCreateRequest,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new LLM provider.
|
||||
|
||||
Returns 201 with the provider dict (masked key) on success, 409 if
|
||||
a provider with the same name already exists.
|
||||
"""
|
||||
svc = _get_llm_config_service(request)
|
||||
try:
|
||||
return svc.create_provider(
|
||||
name=payload.name,
|
||||
provider_type=payload.type,
|
||||
api_key=payload.api_key,
|
||||
base_url=payload.base_url,
|
||||
models=payload.models,
|
||||
max_tokens=payload.max_tokens,
|
||||
timeout=payload.timeout,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@admin_router.patch("/llm/providers/{name}")
|
||||
async def update_llm_provider(
|
||||
name: str,
|
||||
payload: ProviderUpdateRequest,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Partially update an LLM provider.
|
||||
|
||||
Returns 200 with the updated provider dict (masked key), 404 if the
|
||||
provider does not exist.
|
||||
"""
|
||||
svc = _get_llm_config_service(request)
|
||||
try:
|
||||
return svc.update_provider(
|
||||
name=name,
|
||||
provider_type=payload.type,
|
||||
api_key=payload.api_key,
|
||||
base_url=payload.base_url,
|
||||
models=payload.models,
|
||||
max_tokens=payload.max_tokens,
|
||||
timeout=payload.timeout,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@admin_router.delete("/llm/providers/{name}")
|
||||
async def delete_llm_provider(
|
||||
name: str,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Delete an LLM provider.
|
||||
|
||||
Returns 200 ``{deleted: true}`` on success, 404 if the provider
|
||||
does not exist, 400 if the provider is used in a fallback chain.
|
||||
"""
|
||||
svc = _get_llm_config_service(request)
|
||||
try:
|
||||
deleted = svc.delete_provider(name)
|
||||
except ValueError as exc:
|
||||
msg = str(exc)
|
||||
if "not found" in msg:
|
||||
raise HTTPException(status_code=404, detail=msg) from exc
|
||||
raise HTTPException(status_code=400, detail=msg) from exc
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@admin_router.post("/llm/providers/{name}/api-key")
|
||||
async def set_llm_provider_api_key(
|
||||
name: str,
|
||||
payload: ApiKeySetRequest,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Set the API key for a provider.
|
||||
|
||||
The plaintext key is written to ``.env`` (not the YAML file); the
|
||||
YAML stores a ``${ENV_VAR}`` reference. Returns 200 with the
|
||||
updated provider dict (masked key).
|
||||
"""
|
||||
svc = _get_llm_config_service(request)
|
||||
return svc.set_api_key(name, payload.api_key)
|
||||
|
||||
|
||||
@admin_router.get("/llm/fallbacks")
|
||||
async def get_llm_fallbacks(
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, list[str]]:
|
||||
"""Return all fallback chains (model → list of provider/model refs)."""
|
||||
svc = _get_llm_config_service(request)
|
||||
return svc.get_fallbacks()
|
||||
|
||||
|
||||
@admin_router.put("/llm/fallbacks/{model}")
|
||||
async def set_llm_fallback(
|
||||
model: str,
|
||||
payload: FallbackSetRequest,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Set the fallback chain for a model."""
|
||||
svc = _get_llm_config_service(request)
|
||||
return svc.set_fallback(model, payload.chain)
|
||||
|
||||
|
||||
@admin_router.delete("/llm/fallbacks/{model}")
|
||||
async def delete_llm_fallback(
|
||||
model: str,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Delete the fallback chain for a model.
|
||||
|
||||
Returns 200 ``{deleted: true}`` if a chain was deleted, 404 if no
|
||||
chain existed for ``model``.
|
||||
"""
|
||||
svc = _get_llm_config_service(request)
|
||||
deleted = svc.delete_fallback(model)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Fallback chain not found")
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Department quota endpoints (U5) — per-department LLM quotas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class QuotaSetRequest(BaseModel):
|
||||
"""Body for ``PUT /admin/departments/{id}/quotas``."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
quota_type: str
|
||||
limit_value: int | list[str]
|
||||
period: str = "daily"
|
||||
|
||||
|
||||
@admin_router.get("/departments/{department_id}/quotas")
|
||||
async def list_department_quotas(
|
||||
department_id: str,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""List all quotas for a department."""
|
||||
db_path = await _ensure_db(request)
|
||||
svc = get_quota_service()
|
||||
return await svc.list_department_quotas(db_path, department_id)
|
||||
|
||||
|
||||
@admin_router.put("/departments/{department_id}/quotas")
|
||||
async def set_department_quota(
|
||||
department_id: str,
|
||||
payload: QuotaSetRequest,
|
||||
request: Request,
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Set (upsert) a quota for a department.
|
||||
|
||||
Returns 200 with the upserted quota dict. Returns 400 if
|
||||
``quota_type`` or ``period`` is invalid, or if ``limit_value``
|
||||
doesn't match the quota type (int for token/cost limits, list for
|
||||
model_whitelist).
|
||||
"""
|
||||
db_path = await _ensure_db(request)
|
||||
svc = get_quota_service()
|
||||
try:
|
||||
return await svc.set_quota(
|
||||
db_path,
|
||||
department_id,
|
||||
payload.quota_type,
|
||||
payload.limit_value,
|
||||
payload.period,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@admin_router.delete("/departments/{department_id}/quotas")
|
||||
async def delete_department_quota(
|
||||
department_id: str,
|
||||
request: Request,
|
||||
quota_type: str,
|
||||
period: str = "daily",
|
||||
admin: dict[str, Any] = Depends(_require_admin),
|
||||
) -> dict[str, Any]:
|
||||
"""Delete a quota for a department.
|
||||
|
||||
Query params: ``quota_type`` (required), ``period`` (default
|
||||
``daily``). Returns 200 ``{deleted: true}`` if a quota was deleted,
|
||||
404 if no quota matched, 400 if ``quota_type`` or ``period`` is
|
||||
invalid.
|
||||
"""
|
||||
db_path = await _ensure_db(request)
|
||||
svc = get_quota_service()
|
||||
try:
|
||||
deleted = await svc.delete_quota(db_path, department_id, quota_type, period)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Quota not found")
|
||||
return {"deleted": True}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,464 @@
|
|||
"""Integration tests for the LLM config admin routes (U5).
|
||||
|
||||
Uses FastAPI TestClient with a test app that mounts only the
|
||||
``admin_router`` from ``routes.admin``. The ``_require_admin`` dependency
|
||||
is overridden via ``app.dependency_overrides`` so the tests don't need
|
||||
real JWTs — they can simulate admin and non-admin callers directly.
|
||||
|
||||
The :class:`LlmConfigService` is injected via
|
||||
:func:`set_llm_config_service` so the tests can point it at a temp
|
||||
YAML file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from agentkit.server.admin.llm_config_service import (
|
||||
LlmConfigService,
|
||||
set_llm_config_service,
|
||||
)
|
||||
from agentkit.server.auth.models import init_auth_db
|
||||
from agentkit.server.routes import admin as admin_routes_module
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sample_config() -> dict[str, Any]:
|
||||
"""A minimal agentkit.yaml-style config for testing."""
|
||||
return {
|
||||
"server": {"host": "0.0.0.0", "port": 8001},
|
||||
"llm": {
|
||||
"providers": {
|
||||
"openai": {
|
||||
"type": "openai",
|
||||
"api_key": "sk-test-12345678",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"models": {"gpt-4o": {}},
|
||||
"max_tokens": 4096,
|
||||
"timeout": 120.0,
|
||||
},
|
||||
},
|
||||
"model_aliases": {"gpt4": "openai/gpt-4o"},
|
||||
"fallbacks": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_path(tmp_path: Path) -> Path:
|
||||
"""Create a temporary agentkit.yaml config file."""
|
||||
path = tmp_path / "agentkit.yaml"
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(_sample_config(), f, default_flow_style=False, allow_unicode=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def tmp_auth_db(tmp_path: Path) -> Path:
|
||||
db_path = tmp_path / "admin_llm.db"
|
||||
await init_auth_db(db_path)
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_llm_config_singleton():
|
||||
"""Reset the LlmConfigService singleton before and after each test."""
|
||||
set_llm_config_service(None)
|
||||
yield
|
||||
set_llm_config_service(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_app(config_path: Path, tmp_auth_db: Path) -> FastAPI:
|
||||
"""A minimal FastAPI app with only the admin router mounted.
|
||||
|
||||
The ``_require_admin`` dependency is overridden to return a fake
|
||||
admin user. The :class:`LlmConfigService` singleton is pre-populated
|
||||
so the routes use the temp YAML file.
|
||||
"""
|
||||
# Inject the LlmConfigService singleton pointing at the temp YAML.
|
||||
set_llm_config_service(LlmConfigService(config_path))
|
||||
|
||||
app = FastAPI()
|
||||
app.state.auth_db_path = str(tmp_auth_db)
|
||||
app.include_router(admin_routes_module.admin_router, prefix="/api/v1")
|
||||
|
||||
# Default: allow admin access.
|
||||
app.dependency_overrides[admin_routes_module._require_admin] = lambda: _make_admin_user()
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(admin_app: FastAPI) -> TestClient:
|
||||
return TestClient(admin_app)
|
||||
|
||||
|
||||
def _make_admin_user() -> dict[str, Any]:
|
||||
return {"user_id": "admin-1", "username": "admin", "role": "admin"}
|
||||
|
||||
|
||||
def _raise_forbidden() -> dict[str, Any]:
|
||||
"""Dependency override that simulates a non-admin (403) response."""
|
||||
raise HTTPException(status_code=403, detail="Admin permission required")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListProviders:
|
||||
def test_list_returns_existing_providers(self, admin_client: TestClient):
|
||||
resp = admin_client.get("/api/v1/admin/llm/providers")
|
||||
assert resp.status_code == 200
|
||||
providers = resp.json()
|
||||
assert len(providers) == 1
|
||||
assert providers[0]["name"] == "openai"
|
||||
|
||||
def test_list_masks_api_keys(self, admin_client: TestClient):
|
||||
resp = admin_client.get("/api/v1/admin/llm/providers")
|
||||
providers = resp.json()
|
||||
assert providers[0]["api_key"].startswith("****")
|
||||
# Should not contain the plaintext key.
|
||||
assert "sk-test-12345678" not in providers[0]["api_key"]
|
||||
|
||||
|
||||
class TestCreateProvider:
|
||||
def test_create_returns_201(self, admin_client: TestClient):
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/llm/providers",
|
||||
json={
|
||||
"name": "anthropic",
|
||||
"type": "anthropic",
|
||||
"api_key": "sk-ant-test-abcd1234",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"models": {"claude-sonnet-4-20250514": {}},
|
||||
"max_tokens": 8192,
|
||||
"timeout": 90.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
body = resp.json()
|
||||
assert body["name"] == "anthropic"
|
||||
assert body["type"] == "anthropic"
|
||||
# API key should be masked.
|
||||
assert body["api_key"].startswith("****")
|
||||
|
||||
def test_create_duplicate_returns_409(self, admin_client: TestClient):
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/llm/providers",
|
||||
json={
|
||||
"name": "openai",
|
||||
"type": "openai",
|
||||
"api_key": "sk-duplicate",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_non_admin_returns_403(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.post(
|
||||
"/api/v1/admin/llm/providers",
|
||||
json={"name": "x", "type": "openai", "api_key": "sk-x"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
class TestUpdateProvider:
|
||||
def test_update_partial_returns_200(self, admin_client: TestClient):
|
||||
resp = admin_client.patch(
|
||||
"/api/v1/admin/llm/providers/openai",
|
||||
json={"max_tokens": 2048},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["max_tokens"] == 2048
|
||||
# Other fields should be preserved.
|
||||
assert body["base_url"] == "https://api.openai.com/v1"
|
||||
|
||||
def test_update_unknown_returns_404(self, admin_client: TestClient):
|
||||
resp = admin_client.patch(
|
||||
"/api/v1/admin/llm/providers/nonexistent",
|
||||
json={"max_tokens": 2048},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_update_with_masked_api_key_preserves_existing(
|
||||
self, admin_client: TestClient, config_path: Path
|
||||
):
|
||||
resp = admin_client.patch(
|
||||
"/api/v1/admin/llm/providers/openai",
|
||||
json={"api_key": "****5678"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
# Verify the YAML file still has the original key.
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
saved = yaml.safe_load(f)
|
||||
assert saved["llm"]["providers"]["openai"]["api_key"] == "sk-test-12345678"
|
||||
|
||||
|
||||
class TestDeleteProvider:
|
||||
def test_delete_returns_200(self, admin_client: TestClient):
|
||||
resp = admin_client.delete("/api/v1/admin/llm/providers/openai")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": True}
|
||||
|
||||
def test_delete_unknown_returns_404(self, admin_client: TestClient):
|
||||
resp = admin_client.delete("/api/v1/admin/llm/providers/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_delete_provider_used_in_fallback_returns_400(self, admin_client: TestClient):
|
||||
# Add a second provider and a fallback chain that references it.
|
||||
admin_client.post(
|
||||
"/api/v1/admin/llm/providers",
|
||||
json={
|
||||
"name": "anthropic",
|
||||
"type": "anthropic",
|
||||
"api_key": "sk-ant-test",
|
||||
},
|
||||
)
|
||||
admin_client.put(
|
||||
"/api/v1/admin/llm/fallbacks/gpt-4o",
|
||||
json={"chain": ["anthropic"]},
|
||||
)
|
||||
resp = admin_client.delete("/api/v1/admin/llm/providers/anthropic")
|
||||
assert resp.status_code == 400
|
||||
assert "fallback" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API key management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetApiKey:
|
||||
def test_set_api_key_returns_200(self, admin_client: TestClient):
|
||||
resp = admin_client.post(
|
||||
"/api/v1/admin/llm/providers/openai/api-key",
|
||||
json={"api_key": "sk-brand-new-key-1234"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["api_key"].startswith("****")
|
||||
|
||||
def test_set_api_key_writes_to_env_file(self, admin_client: TestClient, config_path: Path):
|
||||
admin_client.post(
|
||||
"/api/v1/admin/llm/providers/openai/api-key",
|
||||
json={"api_key": "sk-brand-new-key-1234"},
|
||||
)
|
||||
env_path = config_path.parent / ".env"
|
||||
assert env_path.exists()
|
||||
with open(env_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
assert "OPENAI_API_KEY=sk-brand-new-key-1234" in content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback chain management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFallbackRoutes:
|
||||
def test_get_fallbacks_returns_empty_when_none(self, admin_client: TestClient):
|
||||
resp = admin_client.get("/api/v1/admin/llm/fallbacks")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {}
|
||||
|
||||
def test_set_fallback_returns_200(self, admin_client: TestClient):
|
||||
resp = admin_client.put(
|
||||
"/api/v1/admin/llm/fallbacks/gpt-4o",
|
||||
json={"chain": ["openai", "anthropic"]},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["model"] == "gpt-4o"
|
||||
assert body["chain"] == ["openai", "anthropic"]
|
||||
|
||||
def test_get_fallbacks_returns_set_chain(self, admin_client: TestClient):
|
||||
admin_client.put(
|
||||
"/api/v1/admin/llm/fallbacks/gpt-4o",
|
||||
json={"chain": ["openai", "anthropic"]},
|
||||
)
|
||||
resp = admin_client.get("/api/v1/admin/llm/fallbacks")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"gpt-4o": ["openai", "anthropic"]}
|
||||
|
||||
def test_delete_fallback_returns_200(self, admin_client: TestClient):
|
||||
admin_client.put(
|
||||
"/api/v1/admin/llm/fallbacks/gpt-4o",
|
||||
json={"chain": ["openai"]},
|
||||
)
|
||||
resp = admin_client.delete("/api/v1/admin/llm/fallbacks/gpt-4o")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": True}
|
||||
|
||||
def test_delete_unknown_fallback_returns_404(self, admin_client: TestClient):
|
||||
resp = admin_client.delete("/api/v1/admin/llm/fallbacks/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Department quota endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDepartmentQuotaRoutes:
|
||||
def test_set_quota_returns_200(self, admin_client: TestClient):
|
||||
dept_id = str(uuid.uuid4())
|
||||
resp = admin_client.put(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
json={
|
||||
"quota_type": "token_limit",
|
||||
"limit_value": 1000,
|
||||
"period": "daily",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["quota_type"] == "token_limit"
|
||||
assert body["limit_value"] == 1000
|
||||
assert body["period"] == "daily"
|
||||
|
||||
def test_set_quota_invalid_type_returns_400(self, admin_client: TestClient):
|
||||
dept_id = str(uuid.uuid4())
|
||||
resp = admin_client.put(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
json={"quota_type": "invalid", "limit_value": 1000},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_set_quota_model_whitelist_accepts_list(self, admin_client: TestClient):
|
||||
dept_id = str(uuid.uuid4())
|
||||
resp = admin_client.put(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
json={
|
||||
"quota_type": "model_whitelist",
|
||||
"limit_value": ["gpt-4o", "claude"],
|
||||
"period": "daily",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["limit_value"] == ["gpt-4o", "claude"]
|
||||
|
||||
def test_list_quotas_returns_empty_for_new_department(self, admin_client: TestClient):
|
||||
dept_id = str(uuid.uuid4())
|
||||
resp = admin_client.get(f"/api/v1/admin/departments/{dept_id}/quotas")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
def test_list_quotas_returns_set_quotas(self, admin_client: TestClient):
|
||||
dept_id = str(uuid.uuid4())
|
||||
admin_client.put(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
json={"quota_type": "token_limit", "limit_value": 1000},
|
||||
)
|
||||
admin_client.put(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
json={
|
||||
"quota_type": "cost_limit",
|
||||
"limit_value": 5000,
|
||||
"period": "monthly",
|
||||
},
|
||||
)
|
||||
resp = admin_client.get(f"/api/v1/admin/departments/{dept_id}/quotas")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert len(body) == 2
|
||||
types = {q["quota_type"] for q in body}
|
||||
assert types == {"token_limit", "cost_limit"}
|
||||
|
||||
def test_delete_quota_returns_200(self, admin_client: TestClient):
|
||||
dept_id = str(uuid.uuid4())
|
||||
admin_client.put(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
json={"quota_type": "token_limit", "limit_value": 1000},
|
||||
)
|
||||
resp = admin_client.delete(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
params={"quota_type": "token_limit", "period": "daily"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": True}
|
||||
|
||||
def test_delete_unknown_quota_returns_404(self, admin_client: TestClient):
|
||||
dept_id = str(uuid.uuid4())
|
||||
resp = admin_client.delete(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
params={"quota_type": "token_limit", "period": "daily"},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_delete_quota_invalid_period_returns_400(self, admin_client: TestClient):
|
||||
dept_id = str(uuid.uuid4())
|
||||
resp = admin_client.delete(
|
||||
f"/api/v1/admin/departments/{dept_id}/quotas",
|
||||
params={"quota_type": "token_limit", "period": "weekly"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-admin access
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNonAdminAccess:
|
||||
"""All LLM config endpoints must return 403 for non-admin users."""
|
||||
|
||||
def test_non_admin_cannot_list_providers(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.get("/api/v1/admin/llm/providers")
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_non_admin_cannot_create_provider(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.post(
|
||||
"/api/v1/admin/llm/providers",
|
||||
json={"name": "x", "type": "openai", "api_key": "sk-x"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_non_admin_cannot_delete_provider(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.delete("/api/v1/admin/llm/providers/openai")
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_non_admin_cannot_set_api_key(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.post(
|
||||
"/api/v1/admin/llm/providers/openai/api-key",
|
||||
json={"api_key": "sk-x"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_non_admin_cannot_get_fallbacks(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.get("/api/v1/admin/llm/fallbacks")
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_non_admin_cannot_set_quota(self, admin_app: FastAPI):
|
||||
admin_app.dependency_overrides[admin_routes_module._require_admin] = _raise_forbidden
|
||||
client = TestClient(admin_app)
|
||||
resp = client.put(
|
||||
f"/api/v1/admin/departments/{uuid.uuid4()}/quotas",
|
||||
json={"quota_type": "token_limit", "limit_value": 1000},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
|
@ -0,0 +1,410 @@
|
|||
"""Unit tests for LlmConfigService (U5 — LLM provider/fallback CRUD).
|
||||
|
||||
Covers:
|
||||
- Happy path: list → get → create → update → delete providers
|
||||
- API key management: set_api_key writes to .env, masks in responses
|
||||
- Fallback chain: get → set → delete
|
||||
- Edge case: create duplicate provider → ValueError
|
||||
- Edge case: update non-existent provider → ValueError
|
||||
- Edge case: delete provider used in fallback → ValueError
|
||||
- Edge case: delete non-existent provider → ValueError
|
||||
- File locking: concurrent writes don't corrupt (sequential simulation)
|
||||
- Singleton helpers (get/set_llm_config_service)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from agentkit.server.admin.llm_config_service import (
|
||||
LlmConfigService,
|
||||
get_llm_config_service,
|
||||
set_llm_config_service,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sample_config() -> dict[str, Any]:
|
||||
"""A minimal agentkit.yaml-style config for testing."""
|
||||
return {
|
||||
"server": {"host": "0.0.0.0", "port": 8001},
|
||||
"llm": {
|
||||
"providers": {
|
||||
"openai": {
|
||||
"type": "openai",
|
||||
"api_key": "sk-test-12345678",
|
||||
"base_url": "https://api.openai.com/v1",
|
||||
"models": {"gpt-4o": {}},
|
||||
"max_tokens": 4096,
|
||||
"timeout": 120.0,
|
||||
},
|
||||
},
|
||||
"model_aliases": {"gpt4": "openai/gpt-4o"},
|
||||
"fallbacks": {},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_path(tmp_path: Path) -> Path:
|
||||
"""Create a temporary agentkit.yaml config file."""
|
||||
path = tmp_path / "agentkit.yaml"
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(_sample_config(), f, default_flow_style=False, allow_unicode=True)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(config_path: Path) -> LlmConfigService:
|
||||
return LlmConfigService(config_path)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singleton():
|
||||
"""Reset the module singleton before and after each test."""
|
||||
set_llm_config_service(None)
|
||||
yield
|
||||
set_llm_config_service(None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider CRUD happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProviderCrudHappyPath:
|
||||
def test_list_returns_existing_providers(self, service: LlmConfigService):
|
||||
providers = service.list_providers()
|
||||
assert len(providers) == 1
|
||||
assert providers[0]["name"] == "openai"
|
||||
assert providers[0]["type"] == "openai"
|
||||
|
||||
def test_list_masks_api_keys(self, service: LlmConfigService):
|
||||
providers = service.list_providers()
|
||||
# API key should be masked: ****xxxx
|
||||
assert providers[0]["api_key"].startswith("****")
|
||||
# Should show only last 4 chars
|
||||
assert providers[0]["api_key"] == "****5678"
|
||||
|
||||
def test_get_returns_provider_by_name(self, service: LlmConfigService):
|
||||
provider = service.get_provider("openai")
|
||||
assert provider is not None
|
||||
assert provider["name"] == "openai"
|
||||
assert provider["type"] == "openai"
|
||||
assert provider["base_url"] == "https://api.openai.com/v1"
|
||||
|
||||
def test_get_returns_none_for_unknown(self, service: LlmConfigService):
|
||||
assert service.get_provider("nonexistent") is None
|
||||
|
||||
def test_create_adds_provider_to_yaml(self, service: LlmConfigService, config_path: Path):
|
||||
created = service.create_provider(
|
||||
name="anthropic",
|
||||
provider_type="anthropic",
|
||||
api_key="sk-ant-test-abcd1234",
|
||||
base_url="https://api.anthropic.com",
|
||||
models={"claude-sonnet-4-20250514": {}},
|
||||
max_tokens=8192,
|
||||
timeout=90.0,
|
||||
)
|
||||
assert created["name"] == "anthropic"
|
||||
assert created["type"] == "anthropic"
|
||||
# API key should be masked in the response.
|
||||
assert created["api_key"].startswith("****")
|
||||
# Verify the YAML file was updated.
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
saved = yaml.safe_load(f)
|
||||
assert "anthropic" in saved["llm"]["providers"]
|
||||
# The YAML should store a ${VAR} reference, not the plaintext key.
|
||||
assert saved["llm"]["providers"]["anthropic"]["api_key"] == "${ANTHROPIC_API_KEY}"
|
||||
|
||||
def test_create_writes_api_key_to_env_file(self, service: LlmConfigService, config_path: Path):
|
||||
service.create_provider(
|
||||
name="anthropic",
|
||||
provider_type="anthropic",
|
||||
api_key="sk-ant-test-abcd1234",
|
||||
)
|
||||
env_path = config_path.parent / ".env"
|
||||
assert env_path.exists()
|
||||
with open(env_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
assert "ANTHROPIC_API_KEY=sk-ant-test-abcd1234" in content
|
||||
|
||||
def test_update_partial_only_changes_provided_fields(
|
||||
self, service: LlmConfigService, config_path: Path
|
||||
):
|
||||
updated = service.update_provider("openai", max_tokens=2048)
|
||||
assert updated["max_tokens"] == 2048
|
||||
# Other fields should be preserved.
|
||||
assert updated["base_url"] == "https://api.openai.com/v1"
|
||||
assert updated["type"] == "openai"
|
||||
|
||||
def test_update_with_masked_api_key_preserves_existing(
|
||||
self, service: LlmConfigService, config_path: Path
|
||||
):
|
||||
"""When user sends back a masked key (****xxxx), the real key should
|
||||
be preserved."""
|
||||
service.update_provider("openai", api_key="****5678")
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
saved = yaml.safe_load(f)
|
||||
# The original plaintext key should be preserved.
|
||||
assert saved["llm"]["providers"]["openai"]["api_key"] == "sk-test-12345678"
|
||||
|
||||
def test_update_with_new_api_key_writes_to_env(
|
||||
self, service: LlmConfigService, config_path: Path
|
||||
):
|
||||
service.update_provider("openai", api_key="sk-new-key-9999")
|
||||
env_path = config_path.parent / ".env"
|
||||
with open(env_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
assert "OPENAI_API_KEY=sk-new-key-9999" in content
|
||||
# YAML should now have a ${VAR} reference.
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
saved = yaml.safe_load(f)
|
||||
assert saved["llm"]["providers"]["openai"]["api_key"] == "${OPENAI_API_KEY}"
|
||||
|
||||
def test_delete_removes_provider(self, service: LlmConfigService, config_path: Path):
|
||||
deleted = service.delete_provider("openai")
|
||||
assert deleted is True
|
||||
# Verify the YAML file was updated.
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
saved = yaml.safe_load(f)
|
||||
assert "openai" not in saved["llm"]["providers"]
|
||||
# list_providers should return empty.
|
||||
assert service.list_providers() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider CRUD edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProviderCrudEdgeCases:
|
||||
def test_create_duplicate_raises_value_error(self, service: LlmConfigService):
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
service.create_provider(
|
||||
name="openai",
|
||||
provider_type="openai",
|
||||
api_key="sk-duplicate",
|
||||
)
|
||||
|
||||
def test_update_nonexistent_raises_value_error(self, service: LlmConfigService):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.update_provider("nonexistent", max_tokens=2048)
|
||||
|
||||
def test_delete_nonexistent_raises_value_error(self, service: LlmConfigService):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.delete_provider("nonexistent")
|
||||
|
||||
def test_delete_provider_used_in_fallback_raises_value_error(self, service: LlmConfigService):
|
||||
# Add a second provider and a fallback chain that references it.
|
||||
service.create_provider(
|
||||
name="anthropic",
|
||||
provider_type="anthropic",
|
||||
api_key="sk-ant-test",
|
||||
)
|
||||
service.set_fallback("gpt-4o", ["anthropic"])
|
||||
with pytest.raises(ValueError, match="used in fallback"):
|
||||
service.delete_provider("anthropic")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# API key management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApiKeyManagement:
|
||||
def test_set_api_key_writes_to_env_file(self, service: LlmConfigService, config_path: Path):
|
||||
service.set_api_key("openai", "sk-brand-new-key-1234")
|
||||
env_path = config_path.parent / ".env"
|
||||
assert env_path.exists()
|
||||
with open(env_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
assert "OPENAI_API_KEY=sk-brand-new-key-1234" in content
|
||||
|
||||
def test_set_api_key_updates_yaml_reference(self, service: LlmConfigService, config_path: Path):
|
||||
service.set_api_key("openai", "sk-brand-new-key-1234")
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
saved = yaml.safe_load(f)
|
||||
assert saved["llm"]["providers"]["openai"]["api_key"] == "${OPENAI_API_KEY}"
|
||||
|
||||
def test_set_api_key_returns_masked_provider(self, service: LlmConfigService):
|
||||
result = service.set_api_key("openai", "sk-brand-new-key-1234")
|
||||
assert result["api_key"].startswith("****")
|
||||
# Should not contain the plaintext key.
|
||||
assert "sk-brand-new-key-1234" not in result["api_key"]
|
||||
|
||||
def test_set_api_key_for_new_provider_creates_stub(
|
||||
self, service: LlmConfigService, config_path: Path
|
||||
):
|
||||
"""Setting an API key for a provider that doesn't yet exist in the
|
||||
YAML should create a stub entry."""
|
||||
result = service.set_api_key("gemini", "AIza-test-key-1234")
|
||||
assert result["name"] == "gemini"
|
||||
assert result["type"] == "openai" # default type for stub
|
||||
assert result["api_key"].startswith("****")
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
saved = yaml.safe_load(f)
|
||||
assert "gemini" in saved["llm"]["providers"]
|
||||
assert saved["llm"]["providers"]["gemini"]["api_key"] == "${GEMINI_API_KEY}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback chain management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFallbackManagement:
|
||||
def test_get_fallbacks_returns_empty_when_none_configured(self, service: LlmConfigService):
|
||||
fallbacks = service.get_fallbacks()
|
||||
assert fallbacks == {}
|
||||
|
||||
def test_set_fallback_adds_chain(self, service: LlmConfigService, config_path: Path):
|
||||
result = service.set_fallback("gpt-4o", ["openai", "anthropic"])
|
||||
assert result["model"] == "gpt-4o"
|
||||
assert result["chain"] == ["openai", "anthropic"]
|
||||
# Verify it was written to YAML.
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
saved = yaml.safe_load(f)
|
||||
assert saved["llm"]["fallbacks"]["gpt-4o"] == ["openai", "anthropic"]
|
||||
|
||||
def test_get_fallbacks_returns_set_chain(self, service: LlmConfigService):
|
||||
service.set_fallback("gpt-4o", ["openai", "anthropic"])
|
||||
fallbacks = service.get_fallbacks()
|
||||
assert fallbacks == {"gpt-4o": ["openai", "anthropic"]}
|
||||
|
||||
def test_set_fallback_overwrites_existing(self, service: LlmConfigService):
|
||||
service.set_fallback("gpt-4o", ["openai"])
|
||||
service.set_fallback("gpt-4o", ["anthropic", "gemini"])
|
||||
fallbacks = service.get_fallbacks()
|
||||
assert fallbacks["gpt-4o"] == ["anthropic", "gemini"]
|
||||
|
||||
def test_delete_fallback_removes_chain(self, service: LlmConfigService):
|
||||
service.set_fallback("gpt-4o", ["openai"])
|
||||
deleted = service.delete_fallback("gpt-4o")
|
||||
assert deleted is True
|
||||
assert service.get_fallbacks() == {}
|
||||
|
||||
def test_delete_fallback_returns_false_for_unset(self, service: LlmConfigService):
|
||||
deleted = service.delete_fallback("nonexistent")
|
||||
assert deleted is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File locking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileLocking:
|
||||
def test_concurrent_writes_do_not_corrupt(self, service: LlmConfigService, config_path: Path):
|
||||
"""Simulate concurrent writes from multiple threads.
|
||||
|
||||
Each thread creates a different provider. The file lock should
|
||||
serialize the writes so the final YAML is valid and contains
|
||||
all providers.
|
||||
"""
|
||||
errors: list[Exception] = []
|
||||
|
||||
def _create_provider(name: str, api_key: str) -> None:
|
||||
try:
|
||||
# Each thread needs its own service instance to avoid
|
||||
# sharing state, but they all point to the same file.
|
||||
svc = LlmConfigService(config_path)
|
||||
svc.create_provider(name=name, provider_type="openai", api_key=api_key)
|
||||
except Exception as exc:
|
||||
errors.append(exc)
|
||||
|
||||
threads = [
|
||||
threading.Thread(
|
||||
target=_create_provider,
|
||||
args=(f"provider_{i}", f"sk-key-{i:04d}"),
|
||||
)
|
||||
for i in range(5)
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# No errors should have occurred.
|
||||
assert errors == [], f"Concurrent writes failed: {errors}"
|
||||
|
||||
# The YAML file should be valid and contain all 5 new providers
|
||||
# plus the original "openai" provider.
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
saved = yaml.safe_load(f)
|
||||
providers = saved["llm"]["providers"]
|
||||
for i in range(5):
|
||||
assert f"provider_{i}" in providers
|
||||
assert "openai" in providers
|
||||
|
||||
def test_lockfile_is_created(self, service: LlmConfigService, config_path: Path):
|
||||
"""The lockfile should be created on the first write."""
|
||||
service.create_provider(
|
||||
name="anthropic",
|
||||
provider_type="anthropic",
|
||||
api_key="sk-ant-test",
|
||||
)
|
||||
lockfile = config_path.with_suffix(config_path.suffix + ".lock")
|
||||
assert lockfile.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingletonHelpers:
|
||||
def test_get_llm_config_service_requires_path_on_first_call(self):
|
||||
with pytest.raises(ValueError, match="config_path is required"):
|
||||
get_llm_config_service()
|
||||
|
||||
def test_get_llm_config_service_returns_singleton(self, config_path: Path):
|
||||
svc1 = get_llm_config_service(config_path)
|
||||
svc2 = get_llm_config_service() # subsequent call ignores path
|
||||
assert svc1 is svc2
|
||||
|
||||
def test_set_llm_config_service_overrides_singleton(self, config_path: Path):
|
||||
custom = LlmConfigService(config_path)
|
||||
set_llm_config_service(custom)
|
||||
assert get_llm_config_service() is custom
|
||||
|
||||
def test_set_llm_config_service_none_resets_singleton(self, config_path: Path):
|
||||
svc = get_llm_config_service(config_path)
|
||||
set_llm_config_service(None)
|
||||
# Next call should raise (no path provided after reset).
|
||||
with pytest.raises(ValueError, match="config_path is required"):
|
||||
get_llm_config_service()
|
||||
# The previously-created instance should still be usable.
|
||||
assert svc.list_providers() is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Env var resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnvVarResolution:
|
||||
def test_list_providers_masks_env_var_referenced_keys(
|
||||
self, config_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""When the YAML stores ``${VAR}``, the masked response should
|
||||
show the masked *resolved* env value, not the literal ``${VAR}``."""
|
||||
# Rewrite the config to use a ${VAR} reference.
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
cfg = _sample_config()
|
||||
cfg["llm"]["providers"]["openai"]["api_key"] = "${TEST_OPENAI_KEY}"
|
||||
yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True)
|
||||
# Set the env var so it can be resolved.
|
||||
monkeypatch.setenv("TEST_OPENAI_KEY", "sk-resolved-1234")
|
||||
svc = LlmConfigService(config_path)
|
||||
providers = svc.list_providers()
|
||||
assert providers[0]["api_key"] == "****1234"
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
"""Unit tests for QuotaService (U5 — per-department LLM quotas).
|
||||
|
||||
Covers:
|
||||
- Happy path: set → get → list → delete quotas
|
||||
- Quota check: token_limit (allowed / denied)
|
||||
- Quota check: cost_limit (allowed / denied)
|
||||
- Quota check: no quota set → always allowed
|
||||
- model_whitelist serialization (list → JSON → list)
|
||||
- Upsert: set same quota twice updates the value
|
||||
- Validation: invalid quota_type / period raises ValueError
|
||||
- is_model_allowed: whitelist enforcement
|
||||
- Singleton helpers (get/set_quota_service)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.server.admin.quota_service import (
|
||||
QuotaService,
|
||||
get_quota_service,
|
||||
set_quota_service,
|
||||
)
|
||||
from agentkit.server.auth.models import init_auth_db
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def fresh_db(tmp_path: Path) -> Path:
|
||||
"""A brand-new auth DB on a fresh path (no data)."""
|
||||
db_path = tmp_path / "auth.db"
|
||||
await init_auth_db(db_path)
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service() -> QuotaService:
|
||||
return QuotaService()
|
||||
|
||||
|
||||
def _random_dept_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quota CRUD happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuotaCrudHappyPath:
|
||||
async def test_set_token_limit_returns_quota_dict(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
quota = await service.set_quota(fresh_db, dept_id, "token_limit", 1000, period="daily")
|
||||
assert quota["department_id"] == dept_id
|
||||
assert quota["quota_type"] == "token_limit"
|
||||
assert quota["limit_value"] == 1000
|
||||
assert quota["period"] == "daily"
|
||||
assert "id" in quota
|
||||
assert "updated_at" in quota
|
||||
|
||||
async def test_set_cost_limit_returns_quota_dict(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
quota = await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
|
||||
assert quota["quota_type"] == "cost_limit"
|
||||
assert quota["limit_value"] == 5000
|
||||
assert quota["period"] == "monthly"
|
||||
|
||||
async def test_set_model_whitelist_serializes_list(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
whitelist = ["gpt-4o", "claude-sonnet-4-20250514", "gemini-pro"]
|
||||
quota = await service.set_quota(
|
||||
fresh_db, dept_id, "model_whitelist", whitelist, period="daily"
|
||||
)
|
||||
assert quota["quota_type"] == "model_whitelist"
|
||||
# The deserialized value should match the input list.
|
||||
assert quota["limit_value"] == whitelist
|
||||
|
||||
async def test_get_returns_set_quota(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
|
||||
fetched = await service.get_quota(fresh_db, dept_id, "token_limit")
|
||||
assert fetched is not None
|
||||
assert fetched["limit_value"] == 1000
|
||||
assert fetched["quota_type"] == "token_limit"
|
||||
|
||||
async def test_get_returns_none_for_unset_quota(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
fetched = await service.get_quota(fresh_db, dept_id, "token_limit")
|
||||
assert fetched is None
|
||||
|
||||
async def test_list_returns_all_quotas_for_department(
|
||||
self, service: QuotaService, fresh_db: Path
|
||||
):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 1000, period="daily")
|
||||
await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
|
||||
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o"], period="daily")
|
||||
quotas = await service.list_department_quotas(fresh_db, dept_id)
|
||||
assert len(quotas) == 3
|
||||
types = {q["quota_type"] for q in quotas}
|
||||
assert types == {"token_limit", "cost_limit", "model_whitelist"}
|
||||
|
||||
async def test_list_returns_empty_for_department_with_no_quotas(
|
||||
self, service: QuotaService, fresh_db: Path
|
||||
):
|
||||
dept_id = _random_dept_id()
|
||||
quotas = await service.list_department_quotas(fresh_db, dept_id)
|
||||
assert quotas == []
|
||||
|
||||
async def test_delete_removes_quota(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
|
||||
deleted = await service.delete_quota(fresh_db, dept_id, "token_limit")
|
||||
assert deleted is True
|
||||
# Confirm it's gone.
|
||||
assert await service.get_quota(fresh_db, dept_id, "token_limit") is None
|
||||
|
||||
async def test_delete_returns_false_for_unset_quota(
|
||||
self, service: QuotaService, fresh_db: Path
|
||||
):
|
||||
dept_id = _random_dept_id()
|
||||
deleted = await service.delete_quota(fresh_db, dept_id, "token_limit")
|
||||
assert deleted is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upsert behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuotaUpsert:
|
||||
async def test_set_same_quota_twice_updates_value(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
|
||||
updated = await service.set_quota(fresh_db, dept_id, "token_limit", 2000)
|
||||
assert updated["limit_value"] == 2000
|
||||
# Only one row should exist.
|
||||
quotas = await service.list_department_quotas(fresh_db, dept_id)
|
||||
assert len(quotas) == 1
|
||||
assert quotas[0]["limit_value"] == 2000
|
||||
|
||||
async def test_set_same_quota_with_different_period_creates_new_row(
|
||||
self, service: QuotaService, fresh_db: Path
|
||||
):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 1000, period="daily")
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 30000, period="monthly")
|
||||
quotas = await service.list_department_quotas(fresh_db, dept_id)
|
||||
assert len(quotas) == 2
|
||||
periods = {q["period"] for q in quotas}
|
||||
assert periods == {"daily", "monthly"}
|
||||
|
||||
async def test_upsert_model_whitelist_replaces_list(
|
||||
self, service: QuotaService, fresh_db: Path
|
||||
):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o", "claude"])
|
||||
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o-mini"])
|
||||
fetched = await service.get_quota(fresh_db, dept_id, "model_whitelist")
|
||||
assert fetched is not None
|
||||
assert fetched["limit_value"] == ["gpt-4o-mini"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quota enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuotaCheck:
|
||||
async def test_token_limit_allowed_when_under_limit(
|
||||
self, service: QuotaService, fresh_db: Path
|
||||
):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
|
||||
allowed, reason = await service.check_quota(
|
||||
fresh_db, dept_id, "token_limit", "daily", current_usage=500
|
||||
)
|
||||
assert allowed is True
|
||||
assert reason == "ok"
|
||||
|
||||
async def test_token_limit_denied_at_limit(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
|
||||
allowed, reason = await service.check_quota(
|
||||
fresh_db, dept_id, "token_limit", "daily", current_usage=1000
|
||||
)
|
||||
assert allowed is False
|
||||
assert "exceeded" in reason
|
||||
|
||||
async def test_token_limit_denied_over_limit(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 1000)
|
||||
allowed, reason = await service.check_quota(
|
||||
fresh_db, dept_id, "token_limit", "daily", current_usage=1500
|
||||
)
|
||||
assert allowed is False
|
||||
assert "exceeded" in reason
|
||||
|
||||
async def test_cost_limit_allowed_when_under_limit(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
|
||||
allowed, reason = await service.check_quota(
|
||||
fresh_db, dept_id, "cost_limit", "monthly", current_usage=2500
|
||||
)
|
||||
assert allowed is True
|
||||
assert reason == "ok"
|
||||
|
||||
async def test_cost_limit_denied_at_limit(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "cost_limit", 5000, period="monthly")
|
||||
allowed, _ = await service.check_quota(
|
||||
fresh_db, dept_id, "cost_limit", "monthly", current_usage=5000
|
||||
)
|
||||
assert allowed is False
|
||||
|
||||
async def test_no_quota_set_always_allowed(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
allowed, reason = await service.check_quota(
|
||||
fresh_db, dept_id, "token_limit", "daily", current_usage=999999
|
||||
)
|
||||
assert allowed is True
|
||||
assert reason == "ok"
|
||||
|
||||
async def test_model_whitelist_not_a_quota_check(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o"], period="daily")
|
||||
allowed, reason = await service.check_quota(
|
||||
fresh_db, dept_id, "model_whitelist", "daily", current_usage=0
|
||||
)
|
||||
assert allowed is False
|
||||
assert "is_model_allowed" in reason
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model whitelist enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestModelWhitelist:
|
||||
async def test_is_model_allowed_when_in_whitelist(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o", "claude"])
|
||||
allowed, reason = await service.is_model_allowed(fresh_db, dept_id, "gpt-4o")
|
||||
assert allowed is True
|
||||
assert reason == "ok"
|
||||
|
||||
async def test_is_model_denied_when_not_in_whitelist(
|
||||
self, service: QuotaService, fresh_db: Path
|
||||
):
|
||||
dept_id = _random_dept_id()
|
||||
await service.set_quota(fresh_db, dept_id, "model_whitelist", ["gpt-4o", "claude"])
|
||||
allowed, reason = await service.is_model_allowed(fresh_db, dept_id, "gemini-pro")
|
||||
assert allowed is False
|
||||
assert "not in" in reason
|
||||
|
||||
async def test_is_model_allowed_when_no_whitelist(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
# No whitelist set — all models allowed.
|
||||
allowed, reason = await service.is_model_allowed(fresh_db, dept_id, "any-model")
|
||||
assert allowed is True
|
||||
assert reason == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQuotaValidation:
|
||||
async def test_set_invalid_quota_type_raises(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
with pytest.raises(ValueError, match="Invalid quota_type"):
|
||||
await service.set_quota(fresh_db, dept_id, "invalid_type", 1000)
|
||||
|
||||
async def test_set_invalid_period_raises(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
with pytest.raises(ValueError, match="Invalid period"):
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", 1000, period="weekly")
|
||||
|
||||
async def test_set_token_limit_with_list_raises(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
with pytest.raises(ValueError, match="must be an int"):
|
||||
await service.set_quota(fresh_db, dept_id, "token_limit", ["gpt-4o"])
|
||||
|
||||
async def test_set_model_whitelist_with_int_raises(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
with pytest.raises(ValueError, match="must be a list"):
|
||||
await service.set_quota(fresh_db, dept_id, "model_whitelist", 1000)
|
||||
|
||||
async def test_get_invalid_quota_type_raises(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
with pytest.raises(ValueError, match="Invalid quota_type"):
|
||||
await service.get_quota(fresh_db, dept_id, "invalid_type")
|
||||
|
||||
async def test_delete_invalid_period_raises(self, service: QuotaService, fresh_db: Path):
|
||||
dept_id = _random_dept_id()
|
||||
with pytest.raises(ValueError, match="Invalid period"):
|
||||
await service.delete_quota(fresh_db, dept_id, "token_limit", period="weekly")
|
||||
|
||||
async def test_check_quota_invalid_quota_type_raises(
|
||||
self, service: QuotaService, fresh_db: Path
|
||||
):
|
||||
dept_id = _random_dept_id()
|
||||
with pytest.raises(ValueError, match="Invalid quota_type"):
|
||||
await service.check_quota(fresh_db, dept_id, "invalid_type", "daily", current_usage=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingletonHelpers:
|
||||
def test_get_quota_service_returns_singleton(self):
|
||||
# Save the original singleton so we don't disturb other tests.
|
||||
original = get_quota_service()
|
||||
try:
|
||||
custom = QuotaService()
|
||||
set_quota_service(custom)
|
||||
assert get_quota_service() is custom
|
||||
# Clearing falls back to a new lazy instance.
|
||||
set_quota_service(None)
|
||||
new_one = get_quota_service()
|
||||
assert new_one is not custom
|
||||
finally:
|
||||
set_quota_service(original)
|
||||
Loading…
Reference in New Issue