fix(review): U2 quota semantics — monthly quota + multi-dept attribution + TOCTOU docs

This commit is contained in:
chiguyong 2026-06-22 16:30:22 +08:00
parent 278d76b381
commit cd371e4155
1 changed file with 90 additions and 42 deletions

View File

@ -524,21 +524,45 @@ class LLMGateway:
) -> None:
"""Record a usage event via the async store interface (KTD-6).
Attaches ``user_id`` and the first ``department_id`` to the
record. Multi-department attribution is handled by the caller
(see U2 when a user belongs to multiple departments, each
department gets its own record).
Multi-department attribution (U2): when a user belongs to
multiple departments, a separate :class:`UsageRecord` is created
for each department. This ensures ``get_usage(dept_id)`` returns
the correct total for every department the user belongs to,
matching the quota check scope (which checks all departments).
TOCTOU (KTD-2): This method is called *after* the LLM response
is received. Between ``_enforce_quota`` (before the call) and
this recording, concurrent requests may push usage over the
limit. This race window is accepted; post-hoc reconciliation
(periodic scans for over-limit users) handles violations.
"""
dept_id = department_ids[0] if department_ids else None
await self._usage_tracker.record_async(
agent_name=agent_name,
model=model,
usage=usage,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_id=dept_id,
)
if not department_ids:
# API key users (no departments) — record once with dept=None.
await self._usage_tracker.record_async(
agent_name=agent_name,
model=model,
usage=usage,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_id=None,
)
return
# Record one entry per department so each department's aggregate
# includes this usage. The cost is attributed in full to each
# department (not split) — this matches how quota checks work
# (each department is checked against the full usage).
for dept_id in department_ids:
await self._usage_tracker.record_async(
agent_name=agent_name,
model=model,
usage=usage,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_id=dept_id,
)
async def _enforce_quota(
self,
@ -551,9 +575,18 @@ class LLMGateway:
Strictest-wins: if ANY department fails ANY check, raises
:class:`QuotaExceededError` and the request is rejected.
Both daily and monthly periods are checked (U2): for each
department, ``token_limit`` and ``cost_limit`` are evaluated
against both ``daily`` and ``monthly`` windows.
Fail-closed (KTD-1): if the usage store is unavailable (Redis
degraded), raises :class:`UsageStoreUnavailableError`. The
caller must translate this to HTTP 503.
TOCTOU (KTD-2): quota is checked *before* the LLM call, and
usage is recorded *after*. Concurrent requests in this window
may exceed the limit. This race is accepted; see
:meth:`_record_usage` for the reconciliation strategy.
"""
# Lazy import to avoid circular dependency (admin → ... → gateway).
from agentkit.server.admin.quota_service import get_quota_service
@ -573,37 +606,52 @@ class LLMGateway:
current=resolved_model,
)
# 2. Token limit (daily)
current_tokens = await self._get_current_usage_for_quota(dept_id, "daily")
allowed, _reason = await quota_service.check_quota(
db, dept_id, "token_limit", "daily", current_tokens
# 2. Token + cost limits (daily AND monthly)
await self._check_quota_period(
quota_service, db, dept_id, "daily", "token_limit"
)
await self._check_quota_period(
quota_service, db, dept_id, "daily", "cost_limit"
)
await self._check_quota_period(
quota_service, db, dept_id, "monthly", "token_limit"
)
await self._check_quota_period(
quota_service, db, dept_id, "monthly", "cost_limit"
)
if not allowed:
quota = await quota_service.get_quota(db, dept_id, "token_limit", "daily")
limit = quota["limit_value"] if quota else None
raise QuotaExceededError(
department_id=dept_id,
quota_type="token_limit",
period="daily",
limit=limit,
current=current_tokens,
)
# 3. Cost limit (daily)
current_cost = await self._get_current_cost_for_quota(dept_id, "daily")
allowed, _reason = await quota_service.check_quota(
db, dept_id, "cost_limit", "daily", current_cost
async def _check_quota_period(
self,
quota_service: Any,
db: Path,
dept_id: str,
period: str,
quota_type: str,
) -> None:
"""Check a single quota (token_limit or cost_limit) for a period.
Raises :class:`QuotaExceededError` if the current usage exceeds
the configured limit. ``period`` is ``"daily"`` or ``"monthly"``;
``quota_type`` is ``"token_limit"`` or ``"cost_limit"``.
"""
if quota_type == "token_limit":
current = await self._get_current_usage_for_quota(dept_id, period)
else:
current = await self._get_current_cost_for_quota(dept_id, period)
allowed, _reason = await quota_service.check_quota(
db, dept_id, quota_type, period, current
)
if not allowed:
quota = await quota_service.get_quota(db, dept_id, quota_type, period)
limit = quota["limit_value"] if quota else None
raise QuotaExceededError(
department_id=dept_id,
quota_type=quota_type,
period=period,
limit=limit,
current=current,
)
if not allowed:
quota = await quota_service.get_quota(db, dept_id, "cost_limit", "daily")
limit = quota["limit_value"] if quota else None
raise QuotaExceededError(
department_id=dept_id,
quota_type="cost_limit",
period="daily",
limit=limit,
current=current_cost,
)
async def _get_current_usage_for_quota(self, department_id: str, period: str) -> int:
"""Return total tokens used by ``department_id`` in the current period.