# fischer-agentkit/tests/unit/llm/test_quota_enforcement.py
#
# 322 lines
# 12 KiB
# Python

"""Unit tests for LLMGateway quota enforcement (U7).

Covers:
- QuotaExceededError raised when token_limit exceeded
- QuotaExceededError raised when cost_limit exceeded
- QuotaExceededError raised when model not in whitelist
- Model in whitelist → request allowed
- No quota set → request allowed
- Multi-department: strictest-wins (one exceeds, other doesn't → rejected)
- Quota check skipped when db_path or department_ids is None
- QuotaExceededError carries the right metadata
- Usage recording still attaches user_id + department_id on success
"""
from __future__ import annotations
import uuid
from pathlib import Path
import pytest
from agentkit.llm.gateway import LLMGateway, QuotaExceededError
from agentkit.llm.protocol import (
LLMProvider,
LLMRequest,
LLMResponse,
TokenUsage,
)
from agentkit.llm.providers.usage_store import InMemoryUsageStore
from agentkit.server.admin.quota_service import (
get_quota_service,
set_quota_service,
)
from agentkit.server.auth.models import init_auth_db
# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------
class FakeProvider(LLMProvider):
    """Stub ``LLMProvider`` that answers every chat with a canned response.

    The instance remembers the most recent request and counts how many
    times ``chat`` was invoked so tests can inspect how it was used.
    """

    def __init__(self, name: str = "fake"):
        self.call_count = 0
        self.last_request: LLMRequest | None = None
        self._name = name

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Record *request* and return a fixed reply for the requested model."""
        self.call_count += 1
        self.last_request = request
        # Fixed usage figures: 100 prompt + 50 completion tokens per call.
        canned_usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
        reply_text = f"response from {self._name}"
        return LLMResponse(
            content=reply_text,
            model=request.model,
            usage=canned_usage,
        )
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def store() -> InMemoryUsageStore:
    """Provide a freshly constructed in-memory usage store for each test."""
    usage_store = InMemoryUsageStore()
    return usage_store
@pytest.fixture
def gateway(store: InMemoryUsageStore) -> LLMGateway:
    """Provide an ``LLMGateway`` backed by the ``store`` fixture.

    A ``FakeProvider`` is registered under the provider name "openai".
    """
    fake_openai = FakeProvider("openai")
    instance = LLMGateway(usage_store=store)
    instance.register_provider("openai", fake_openai)
    return instance
@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
    """Initialise an auth database under ``tmp_path`` and return its path."""
    location = tmp_path.joinpath("auth.db")
    await init_auth_db(location)
    return location
@pytest.fixture(autouse=True)
def _reset_quota_singleton():
    """Clear the QuotaService singleton on both sides of every test."""

    def _clear() -> None:
        set_quota_service(None)

    _clear()  # before the test body runs
    yield
    _clear()  # after the test body has finished
def _random_dept_id() -> str:
    """Return a new random (version 4) UUID rendered as a string."""
    fresh = uuid.uuid4()
    return str(fresh)
# ---------------------------------------------------------------------------
# Quota enforcement
# ---------------------------------------------------------------------------
class TestQuotaEnforcement:
    """Quota checks performed by ``LLMGateway.chat``.

    Every test sends the same request (model "openai/gpt-4o", user "u1")
    through the ``gateway`` fixture and varies only the configured quotas,
    the previously recorded usage, and the ``department_ids`` / ``db_path``
    arguments.
    """

    @staticmethod
    async def _chat(
        gateway: LLMGateway,
        department_ids: list[str] | None,
        db_path: Path | None,
    ) -> LLMResponse:
        """Send the canonical test request through *gateway* and return the reply."""
        return await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=department_ids,
            db_path=db_path,
        )

    @staticmethod
    def _seed_usage(gateway: LLMGateway, department_id: str, cost: float = 0.0) -> None:
        """Record one earlier call (100 prompt + 50 completion tokens) for a department."""
        gateway._usage_tracker.record(
            agent_name="prev",
            model="openai/gpt-4o",
            usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
            cost=cost,
            latency_ms=10,
            user_id="u1",
            department_id=department_id,
        )

    async def test_no_quota_set_allows_request(self, gateway: LLMGateway, fresh_db: Path):
        """When no quota is configured, the request is allowed."""
        dept_id = _random_dept_id()

        response = await self._chat(gateway, [dept_id], fresh_db)

        assert response.content == "response from openai"

    async def test_token_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path):
        """token_limit quota exceeded → QuotaExceededError."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # Set a tiny token limit (1 token) — any usage will exceed it.
        await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily")
        # Pre-populate the usage store so the daily total > 1.
        self._seed_usage(gateway, dept_id)

        with pytest.raises(QuotaExceededError) as exc_info:
            await self._chat(gateway, [dept_id], fresh_db)

        err = exc_info.value
        assert err.department_id == dept_id
        assert err.quota_type == "token_limit"
        assert err.period == "daily"
        assert err.limit == 1
        assert err.current == 150  # 100 prompt + 50 completion

    async def test_cost_limit_exceeded_raises(self, gateway: LLMGateway, fresh_db: Path):
        """cost_limit quota exceeded → QuotaExceededError."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # cost_limit is in cents. Set 1 cent.
        await svc.set_quota(fresh_db, dept_id, "cost_limit", 1, period="daily")
        # Pre-populate usage with $1.00 cost = 100 cents, exceeding the 1-cent limit.
        self._seed_usage(gateway, dept_id, cost=1.00)

        with pytest.raises(QuotaExceededError) as exc_info:
            await self._chat(gateway, [dept_id], fresh_db)

        err = exc_info.value
        assert err.quota_type == "cost_limit"
        assert err.period == "daily"
        assert err.limit == 1
        # current is in cents (100 cents = $1.00).
        assert err.current == 100.0

    async def test_model_whitelist_rejection_raises(self, gateway: LLMGateway, fresh_db: Path):
        """Model not in whitelist → QuotaExceededError with quota_type=model_whitelist."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # Whitelist only allows "claude" — gateway is calling "gpt-4o".
        await svc.set_quota(fresh_db, dept_id, "model_whitelist", ["claude"], period="daily")

        with pytest.raises(QuotaExceededError) as exc_info:
            await self._chat(gateway, [dept_id], fresh_db)

        err = exc_info.value
        assert err.quota_type == "model_whitelist"
        assert err.department_id == dept_id
        # For model_whitelist, current is the rejected model name.
        assert err.current == "openai/gpt-4o"

    async def test_model_whitelist_allows_listed_model(self, gateway: LLMGateway, fresh_db: Path):
        """Model in whitelist → request allowed."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # Whitelist uses the full resolved model identifier (provider/model).
        await svc.set_quota(
            fresh_db, dept_id, "model_whitelist", ["openai/gpt-4o"], period="daily"
        )

        response = await self._chat(gateway, [dept_id], fresh_db)

        assert response.content == "response from openai"

    async def test_multi_department_strictest_wins(self, gateway: LLMGateway, fresh_db: Path):
        """One department exceeds, the other doesn't → rejected (strictest wins)."""
        dept_ok = _random_dept_id()
        dept_bad = _random_dept_id()
        svc = get_quota_service()
        # dept_bad has a 1-token limit; dept_ok has a 1M-token limit.
        await svc.set_quota(fresh_db, dept_bad, "token_limit", 1, period="daily")
        await svc.set_quota(fresh_db, dept_ok, "token_limit", 1_000_000, period="daily")
        # Pre-populate usage for dept_bad so it exceeds.
        self._seed_usage(gateway, dept_bad)

        with pytest.raises(QuotaExceededError) as exc_info:
            await self._chat(gateway, [dept_ok, dept_bad], fresh_db)

        # The error should reference dept_bad (the one that exceeded).
        assert exc_info.value.department_id == dept_bad

    async def test_quota_check_skipped_without_db_path(self, gateway: LLMGateway, fresh_db: Path):
        """When db_path is None, no quota check is performed."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily")
        # Record usage above the 1-token limit. With zero recorded usage the
        # limit would not be exceeded, so the test would pass even if the
        # check ran. Seeding makes a wrongly performed check fail loudly.
        self._seed_usage(gateway, dept_id)

        # Even with an exceeded quota, calling without db_path should succeed.
        response = await self._chat(gateway, [dept_id], None)

        assert response.content == "response from openai"

    async def test_quota_check_skipped_without_department_ids(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """When department_ids is None, no quota check is performed."""
        response = await self._chat(gateway, None, fresh_db)

        assert response.content == "response from openai"

    async def test_usage_recorded_with_user_and_department(
        self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
    ):
        """After a successful call, the usage record carries user_id + department_id."""
        dept_id = _random_dept_id()

        await self._chat(gateway, [dept_id], fresh_db)

        summary = store.get_usage()
        assert len(summary.records) == 1
        rec = summary.records[0]
        assert rec.user_id == "u1"
        assert rec.department_id == dept_id
        assert rec.model == "gpt-4o"
        assert rec.total_tokens == 150  # 100 prompt + 50 completion
# ---------------------------------------------------------------------------
# QuotaExceededError dataclass-like behavior
# ---------------------------------------------------------------------------
class TestQuotaExceededError:
    """Checks on the ``QuotaExceededError`` exception object itself."""

    def test_error_message_includes_metadata(self):
        """``str(err)`` mentions every value the error was built with."""
        fields = {
            "department_id": "d1",
            "quota_type": "token_limit",
            "period": "daily",
            "limit": 1000,
            "current": 1500,
        }
        rendered = str(QuotaExceededError(**fields))
        for value in fields.values():
            assert str(value) in rendered

    def test_error_attributes_preserved(self):
        """Positional constructor arguments are exposed as same-named attributes."""
        positional = ("d1", "cost_limit", "monthly", 5000, 6000)
        err = QuotaExceededError(*positional)
        observed = (
            err.department_id,
            err.quota_type,
            err.period,
            err.limit,
            err.current,
        )
        assert observed == positional