"""Unit tests for LLMGateway quota enforcement (U7).

Covers:
- QuotaExceededError raised when token_limit exceeded
- QuotaExceededError raised when cost_limit exceeded
- QuotaExceededError raised when model not in whitelist
- No quota set → request allowed
- Multi-department: strictest-wins (one exceeds, other doesn't → rejected)
- Quota check skipped when db_path or department_ids is not supplied
- QuotaExceededError carries the right metadata
- Usage recording still attaches user_id + department_id on success
"""

from __future__ import annotations

import uuid
from pathlib import Path

import pytest

from agentkit.llm.gateway import LLMGateway, QuotaExceededError
from agentkit.llm.protocol import (
    LLMProvider,
    LLMRequest,
    LLMResponse,
    TokenUsage,
)
from agentkit.llm.providers.usage_store import InMemoryUsageStore
from agentkit.server.admin.quota_service import (
    get_quota_service,
    set_quota_service,
)
from agentkit.server.auth.models import init_auth_db


# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------


class FakeProvider(LLMProvider):
    """A minimal LLMProvider that returns a fixed response.

    Every call reports 100 prompt + 50 completion tokens and echoes back the
    model named in the request it receives, so tests can assert exact totals.
    """

    def __init__(self, name: str = "fake"):
        self._name = name
        self.last_request: LLMRequest | None = None
        self.call_count = 0

    async def chat(self, request: LLMRequest) -> LLMResponse:
        self.last_request = request
        self.call_count += 1
        return LLMResponse(
            content=f"response from {self._name}",
            model=request.model,
            usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
        )


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def store() -> InMemoryUsageStore:
    """A fresh in-memory usage store for each test."""
    return InMemoryUsageStore()


@pytest.fixture
def gateway(store: InMemoryUsageStore) -> LLMGateway:
    """A gateway backed by ``store`` with a fake ``openai`` provider registered."""
    gw = LLMGateway(usage_store=store)
    gw.register_provider("openai", FakeProvider("openai"))
    return gw


@pytest.fixture
async def fresh_db(tmp_path: Path) -> Path:
    """Initialise an empty auth database under ``tmp_path`` and return its path."""
    db_path = tmp_path / "auth.db"
    await init_auth_db(db_path)
    return db_path


@pytest.fixture(autouse=True)
def _reset_quota_singleton():
    """Reset the QuotaService singleton before and after each test."""
    set_quota_service(None)
    yield
    set_quota_service(None)


def _random_dept_id() -> str:
    """Return a unique department id so tests never share quota state."""
    return str(uuid.uuid4())


def _record_prior_usage(
    gateway: LLMGateway, department_id: str, *, cost: float = 0.0
) -> None:
    """Seed the gateway's usage tracker with one earlier 150-token call.

    Used to push a department over a quota *before* the call under test; the
    expected ``err.current`` values below reflect only this seeded usage.

    Args:
        gateway: Gateway whose internal usage tracker receives the record.
        department_id: Department the seeded usage is attributed to.
        cost: Cost of the seeded call in dollars.
    """
    gateway._usage_tracker.record(
        agent_name="prev",
        model="openai/gpt-4o",
        usage=TokenUsage(prompt_tokens=100, completion_tokens=50),
        cost=cost,
        latency_ms=10,
        user_id="u1",
        department_id=department_id,
    )


# ---------------------------------------------------------------------------
# Quota enforcement
# ---------------------------------------------------------------------------


class TestQuotaEnforcement:
    async def test_no_quota_set_allows_request(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """When no quota is configured, the request is allowed."""
        dept_id = _random_dept_id()
        response = await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=[dept_id],
            db_path=fresh_db,
        )
        assert response.content == "response from openai"

    async def test_token_limit_exceeded_raises(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """token_limit quota exceeded → QuotaExceededError."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # Set a tiny token limit (1 token) — any recorded usage exceeds it.
        await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily")
        # Pre-populate the usage store so the daily total (150) > 1.
        _record_prior_usage(gateway, dept_id)

        with pytest.raises(QuotaExceededError) as exc_info:
            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model="openai/gpt-4o",
                user_id="u1",
                department_ids=[dept_id],
                db_path=fresh_db,
            )

        err = exc_info.value
        assert err.department_id == dept_id
        assert err.quota_type == "token_limit"
        assert err.period == "daily"
        assert err.limit == 1
        assert err.current == 150  # 100 prompt + 50 completion

    async def test_cost_limit_exceeded_raises(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """cost_limit quota exceeded → QuotaExceededError."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # cost_limit is in cents. Set 1 cent.
        await svc.set_quota(fresh_db, dept_id, "cost_limit", 1, period="daily")
        # Pre-populate usage with $1.00 cost = 100 cents, exceeding the 1-cent limit.
        _record_prior_usage(gateway, dept_id, cost=1.00)

        with pytest.raises(QuotaExceededError) as exc_info:
            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model="openai/gpt-4o",
                user_id="u1",
                department_ids=[dept_id],
                db_path=fresh_db,
            )

        err = exc_info.value
        assert err.department_id == dept_id
        assert err.quota_type == "cost_limit"
        assert err.period == "daily"
        assert err.limit == 1
        # current is in cents (100 cents = $1.00); approx guards against
        # float rounding in the dollars → cents conversion.
        assert err.current == pytest.approx(100.0)

    async def test_model_whitelist_rejection_raises(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """Model not in whitelist → QuotaExceededError with quota_type=model_whitelist."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # Whitelist only allows "claude" — gateway is calling "gpt-4o".
        await svc.set_quota(
            fresh_db, dept_id, "model_whitelist", ["claude"], period="daily"
        )

        with pytest.raises(QuotaExceededError) as exc_info:
            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model="openai/gpt-4o",
                user_id="u1",
                department_ids=[dept_id],
                db_path=fresh_db,
            )

        err = exc_info.value
        assert err.quota_type == "model_whitelist"
        assert err.department_id == dept_id
        # For model_whitelist, current is the rejected model name.
        assert err.current == "openai/gpt-4o"

    async def test_model_whitelist_allows_listed_model(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """Model in whitelist → request allowed."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        # Whitelist uses the full resolved model identifier (provider/model).
        await svc.set_quota(
            fresh_db, dept_id, "model_whitelist", ["openai/gpt-4o"], period="daily"
        )

        response = await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=[dept_id],
            db_path=fresh_db,
        )
        assert response.content == "response from openai"

    async def test_multi_department_strictest_wins(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """One department exceeds, the other doesn't → rejected (strictest wins)."""
        dept_ok = _random_dept_id()
        dept_bad = _random_dept_id()
        svc = get_quota_service()
        # dept_bad has a 1-token limit; dept_ok has a 1M-token limit.
        await svc.set_quota(fresh_db, dept_bad, "token_limit", 1, period="daily")
        await svc.set_quota(
            fresh_db, dept_ok, "token_limit", 1_000_000, period="daily"
        )
        # Pre-populate usage for dept_bad so it exceeds.
        _record_prior_usage(gateway, dept_bad)

        # dept_ok is listed first: a passing department must not short-circuit
        # the check for the remaining ones.
        with pytest.raises(QuotaExceededError) as exc_info:
            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model="openai/gpt-4o",
                user_id="u1",
                department_ids=[dept_ok, dept_bad],
                db_path=fresh_db,
            )

        # The error should reference dept_bad (the one that exceeded).
        assert exc_info.value.department_id == dept_bad
        assert exc_info.value.quota_type == "token_limit"

    async def test_quota_check_skipped_without_db_path(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """When db_path is None, no quota check is performed."""
        dept_id = _random_dept_id()
        svc = get_quota_service()
        await svc.set_quota(fresh_db, dept_id, "token_limit", 1, period="daily")
        # Push the department over its limit first; otherwise usage is 0 and
        # the request would succeed even if the quota *were* checked.
        _record_prior_usage(gateway, dept_id)

        # Even with an exceeded quota, calling without db_path should succeed.
        response = await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=[dept_id],
            db_path=None,
        )
        assert response.content == "response from openai"

    async def test_quota_check_skipped_without_department_ids(
        self, gateway: LLMGateway, fresh_db: Path
    ):
        """When department_ids is None, no quota check is performed."""
        response = await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=None,
            db_path=fresh_db,
        )
        assert response.content == "response from openai"

    async def test_usage_recorded_with_user_and_department(
        self, gateway: LLMGateway, store: InMemoryUsageStore, fresh_db: Path
    ):
        """After a successful call, the usage record carries user_id + department_id."""
        dept_id = _random_dept_id()
        await gateway.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="openai/gpt-4o",
            user_id="u1",
            department_ids=[dept_id],
            db_path=fresh_db,
        )

        summary = store.get_usage()
        assert len(summary.records) == 1
        rec = summary.records[0]
        assert rec.user_id == "u1"
        assert rec.department_id == dept_id
        assert rec.model == "gpt-4o"
        assert rec.total_tokens == 150  # 100 prompt + 50 completion


# ---------------------------------------------------------------------------
# QuotaExceededError dataclass-like behavior
# ---------------------------------------------------------------------------


class TestQuotaExceededError:
    def test_error_message_includes_metadata(self):
        """str(err) mentions department, quota type, period, limit and current."""
        err = QuotaExceededError(
            department_id="d1",
            quota_type="token_limit",
            period="daily",
            limit=1000,
            current=1500,
        )
        msg = str(err)
        assert "d1" in msg
        assert "token_limit" in msg
        assert "daily" in msg
        assert "1000" in msg
        assert "1500" in msg

    def test_error_attributes_preserved(self):
        """Positional constructor args map onto the matching attributes."""
        err = QuotaExceededError("d1", "cost_limit", "monthly", 5000, 6000)
        assert err.department_id == "d1"
        assert err.quota_type == "cost_limit"
        assert err.period == "monthly"
        assert err.limit == 5000
        assert err.current == 6000