322 lines
12 KiB
Python
322 lines
12 KiB
Python
"""Unit tests for LLMGateway quota enforcement (U7).

Covers:
- QuotaExceededError raised when token_limit is exceeded
- QuotaExceededError raised when cost_limit is exceeded
- QuotaExceededError raised when the model is not in the whitelist
- No quota set → request allowed
- Multi-department: strictest-wins (one exceeds, other doesn't → rejected)
- QuotaExceededError carries the right metadata
- Usage recording still attaches user_id + department_id on success
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from agentkit.llm.gateway import LLMGateway, QuotaExceededError
|
|
from agentkit.llm.protocol import (
|
|
LLMProvider,
|
|
LLMRequest,
|
|
LLMResponse,
|
|
TokenUsage,
|
|
)
|
|
from agentkit.llm.providers.usage_store import InMemoryUsageStore
|
|
from agentkit.server.admin.quota_service import (
|
|
get_quota_service,
|
|
set_quota_service,
|
|
)
|
|
from agentkit.server.auth.models import init_auth_db
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test doubles
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class FakeProvider(LLMProvider):
    """Test double for ``LLMProvider`` that answers every chat with a canned reply.

    The most recent request and the number of ``chat`` invocations are kept on
    the instance so tests can inspect what reached the provider.
    """

    def __init__(self, name: str = "fake"):
        self._name = name
        self.last_request: LLMRequest | None = None
        self.call_count = 0

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Record ``request`` and return a fixed response that echoes its model."""
        # Bookkeeping first, so the counters reflect the call being answered.
        self.last_request = request
        self.call_count += 1

        reply_text = f"response from {self._name}"
        # Fixed usage figures: 100 prompt tokens + 50 completion tokens.
        canned_usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
        return LLMResponse(content=reply_text, model=request.model, usage=canned_usage)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def store() -> InMemoryUsageStore:
    """Provide a newly constructed ``InMemoryUsageStore`` for each test."""
    usage_store = InMemoryUsageStore()
    return usage_store
|
|
|
|
|
|
@pytest.fixture
def gateway(store: InMemoryUsageStore) -> LLMGateway:
    """Build an ``LLMGateway`` on top of ``store`` with a fake "openai" provider.

    The provider is a ``FakeProvider`` registered under the name ``"openai"``.
    """
    llm_gateway = LLMGateway(usage_store=store)
    openai_double = FakeProvider("openai")
    llm_gateway.register_provider("openai", openai_double)
    return llm_gateway
|
|
|
|
|
|
@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
    """Run ``init_auth_db`` on ``auth.db`` inside ``tmp_path`` and return that path."""
    auth_db_file = tmp_path.joinpath("auth.db")
    await init_auth_db(auth_db_file)
    return auth_db_file
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _reset_quota_singleton():
    """Clear the QuotaService singleton on both sides of every test.

    Autouse, so each test starts and ends with ``set_quota_service(None)``.
    """

    def _clear() -> None:
        set_quota_service(None)

    _clear()  # setup: drop whatever an earlier test left behind
    yield
    _clear()  # teardown: do not leak this test's instance forward
|
|
|
|
|
|
def _random_dept_id() -> str:
    """Return a fresh random (version 4) UUID rendered as a string."""
    fresh_uuid = uuid.uuid4()
    return str(fresh_uuid)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Quota enforcement
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQuotaEnforcement:
    """Exercise the department quota checks performed by ``LLMGateway.chat``."""

    @staticmethod
    async def _ask(gateway: LLMGateway, *, department_ids, db_path):
        """Send the single-message chat request shared by every test here."""
        return await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=department_ids,
            db_path=db_path,
        )

    @staticmethod
    def _seed_usage(gateway: LLMGateway, department_id: str, *, cost: float) -> None:
        """Record one earlier call (100 + 50 tokens) for ``department_id``."""
        gateway._usage_tracker.record(
            agent_name="prev",
            model="openai/gpt-4o",
            usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
            cost=cost,
            latency_ms=10,
            user_id="u1",
            department_id=department_id,
        )

    async def test_no_quota_set_allows_request(self, gateway: LLMGateway, fresh_db: Path):
        """A department without any configured quota gets its request through."""
        department = _random_dept_id()
        reply = await self._ask(gateway, department_ids=[department], db_path=fresh_db)
        assert reply.content == "response from openai"

    async def test_token_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path):
        """Usage above the token_limit quota makes chat raise QuotaExceededError."""
        department = _random_dept_id()
        quotas = get_quota_service()
        # A 1-token daily limit: any recorded usage is already over it.
        await quotas.set_quota(fresh_db, department, "token_limit", 1, period="daily")

        # Seed the usage store so the daily total is above 1.
        self._seed_usage(gateway, department, cost=0.0)

        with pytest.raises(QuotaExceededError) as caught:
            await self._ask(gateway, department_ids=[department], db_path=fresh_db)

        failure = caught.value
        assert failure.department_id == department
        assert failure.quota_type == "token_limit"
        assert failure.period == "daily"
        assert failure.limit == 1
        assert failure.current == 150  # 100 prompt + 50 completion

    async def test_cost_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path):
        """Usage above the cost_limit quota makes chat raise QuotaExceededError."""
        department = _random_dept_id()
        quotas = get_quota_service()
        # cost_limit is expressed in cents; allow just 1 cent.
        await quotas.set_quota(fresh_db, department, "cost_limit", 1, period="daily")

        # Seed $1.00 of prior spend (= 100 cents), well past the 1-cent limit.
        self._seed_usage(gateway, department, cost=1.00)

        with pytest.raises(QuotaExceededError) as caught:
            await self._ask(gateway, department_ids=[department], db_path=fresh_db)

        failure = caught.value
        assert failure.quota_type == "cost_limit"
        assert failure.period == "daily"
        assert failure.limit == 1
        # current is reported in cents (100 cents = $1.00).
        assert failure.current == 100.0

    async def test_model_whitelist_rejection_raises(self, gateway: LLMGateway, fresh_db: Path):
        """A model missing from the whitelist is rejected with quota_type=model_whitelist."""
        department = _random_dept_id()
        quotas = get_quota_service()
        # Only "claude" is whitelisted, while the request asks for gpt-4o.
        await quotas.set_quota(
            fresh_db, department, "model_whitelist", ["claude"], period="daily"
        )

        with pytest.raises(QuotaExceededError) as caught:
            await self._ask(gateway, department_ids=[department], db_path=fresh_db)

        failure = caught.value
        assert failure.quota_type == "model_whitelist"
        assert failure.department_id == department
        # For model_whitelist, current holds the rejected model name.
        assert failure.current == "openai/gpt-4o"

    async def test_model_whitelist_allows_listed_model(self, gateway: LLMGateway, fresh_db: Path):
        """A model present in the whitelist is allowed through."""
        department = _random_dept_id()
        quotas = get_quota_service()
        # The whitelist entry is the full provider/model identifier.
        await quotas.set_quota(
            fresh_db, department, "model_whitelist", ["openai/gpt-4o"], period="daily"
        )

        reply = await self._ask(gateway, department_ids=[department], db_path=fresh_db)
        assert reply.content == "response from openai"

    async def test_multi_department_strictest_wins(self, gateway: LLMGateway, fresh_db: Path):
        """With two departments, one over quota is enough to reject the request."""
        dept_ok = _random_dept_id()
        dept_bad = _random_dept_id()
        quotas = get_quota_service()
        # dept_bad: 1-token limit; dept_ok: 1M-token limit.
        await quotas.set_quota(fresh_db, dept_bad, "token_limit", 1, period="daily")
        await quotas.set_quota(fresh_db, dept_ok, "token_limit", 1_000_000, period="daily")

        # Only dept_bad gets prior usage, pushing it past its limit.
        self._seed_usage(gateway, dept_bad, cost=0.0)

        with pytest.raises(QuotaExceededError) as caught:
            await self._ask(gateway, department_ids=[dept_ok, dept_bad], db_path=fresh_db)

        # The error must name the department that exceeded its quota.
        assert caught.value.department_id == dept_bad

    async def test_quota_check_skipped_without_db_path(self, gateway: LLMGateway, fresh_db: Path):
        """Passing ``db_path=None`` bypasses quota checking."""
        department = _random_dept_id()
        quotas = get_quota_service()
        await quotas.set_quota(fresh_db, department, "token_limit", 1, period="daily")

        # A quota exists, but without db_path the call is expected to succeed.
        reply = await self._ask(gateway, department_ids=[department], db_path=None)
        assert reply.content == "response from openai"

    async def test_quota_check_skipped_without_department_ids(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """Passing ``department_ids=None`` bypasses quota checking."""
        reply = await self._ask(gateway, department_ids=None, db_path=fresh_db)
        assert reply.content == "response from openai"

    async def test_usage_recorded_with_user_and_department(
        self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
    ):
        """A successful call leaves one usage record tagged with user and department."""
        department = _random_dept_id()
        await self._ask(gateway, department_ids=[department], db_path=fresh_db)

        recorded = store.get_usage().records
        assert len(recorded) == 1
        entry = recorded[0]
        assert entry.user_id == "u1"
        assert entry.department_id == department
        assert entry.model == "gpt-4o"
        assert entry.total_tokens == 150  # 100 prompt + 50 completion
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# QuotaExceededError dataclass-like behavior
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestQuotaExceededError:
    """Checks on the message and attributes of ``QuotaExceededError`` itself."""

    def test_error_message_includes_metadata(self):
        """``str(error)`` mentions every piece of metadata it was built with."""
        rendered = str(
            QuotaExceededError(
                department_id="d1",
                quota_type="token_limit",
                period="daily",
                limit=1000,
                current=1500,
            )
        )
        for fragment in ("d1", "token_limit", "daily", "1000", "1500"):
            assert fragment in rendered

    def test_error_attributes_preserved(self):
        """Positional constructor arguments are exposed as same-named attributes."""
        failure = QuotaExceededError("d1", "cost_limit", "monthly", 5000, 6000)
        expected = (
            ("department_id", "d1"),
            ("quota_type", "cost_limit"),
            ("period", "monthly"),
            ("limit", 5000),
            ("current", 6000),
        )
        for attribute, value in expected:
            assert getattr(failure, attribute) == value
|