fix(review): U1 Redis quota enforcement — key construction + fail-closed + degradation recovery + async

This commit is contained in:
chiguyong 2026-06-22 16:22:33 +08:00
parent abe2a66436
commit 00c8386939
6 changed files with 720 additions and 137 deletions

View File

@ -168,7 +168,7 @@ class LLMGateway:
result = await self._cache.get(cache_key) result = await self._cache.get(cache_key)
if result.hit: if result.hit:
latency_ms = (time.monotonic() - start) * 1000 latency_ms = (time.monotonic() - start) * 1000
self._record_usage( await self._record_usage(
agent_name=agent_name, agent_name=agent_name,
model=result.response.model, model=result.response.model,
usage=result.response.usage, usage=result.response.usage,
@ -197,7 +197,7 @@ class LLMGateway:
result = await self._cache.semantic_search(query_embedding) result = await self._cache.semantic_search(query_embedding)
if result.hit: if result.hit:
latency_ms = (time.monotonic() - start) * 1000 latency_ms = (time.monotonic() - start) * 1000
self._record_usage( await self._record_usage(
agent_name=agent_name, agent_name=agent_name,
model=result.response.model, model=result.response.model,
usage=result.response.usage, usage=result.response.usage,
@ -245,7 +245,7 @@ class LLMGateway:
if response.usage: if response.usage:
latency_ms = (time.monotonic() - start) * 1000 latency_ms = (time.monotonic() - start) * 1000
cost = self._calculate_cost(model_name, response.usage) cost = self._calculate_cost(model_name, response.usage)
self._record_usage( await self._record_usage(
agent_name=agent_name, agent_name=agent_name,
model=model_name, model=model_name,
usage=response.usage, usage=response.usage,
@ -286,7 +286,7 @@ class LLMGateway:
cost = self._calculate_cost(response.model, response.usage) cost = self._calculate_cost(response.model, response.usage)
# 记录使用量 # 记录使用量
self._record_usage( await self._record_usage(
agent_name=agent_name, agent_name=agent_name,
model=response.model, model=response.model,
usage=response.usage, usage=response.usage,
@ -406,7 +406,7 @@ class LLMGateway:
if final_usage is None: if final_usage is None:
final_usage = TokenUsage() final_usage = TokenUsage()
cost = self._calculate_cost(final_model, final_usage) cost = self._calculate_cost(final_model, final_usage)
self._record_usage( await self._record_usage(
agent_name=agent_name, agent_name=agent_name,
model=final_model, model=final_model,
usage=final_usage, usage=final_usage,
@ -512,7 +512,7 @@ class LLMGateway:
# Quota enforcement helpers (U7) # Quota enforcement helpers (U7)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _record_usage( async def _record_usage(
self, self,
agent_name: str, agent_name: str,
model: str, model: str,
@ -522,16 +522,15 @@ class LLMGateway:
user_id: str | None, user_id: str | None,
department_ids: list[str] | None, department_ids: list[str] | None,
) -> None: ) -> None:
"""Record a usage event, attaching user_id and (first) department_id. """Record a usage event via the async store interface (KTD-6).
We attach only the first department_id to the record because Attaches ``user_id`` and the first ``department_id`` to the
usage attribution is per-department. If a user belongs to record. Multi-department attribution is handled by the caller
multiple departments, the caller is responsible for choosing (see U2 when a user belongs to multiple departments, each
which department to bill the gateway just records what it's department gets its own record).
told.
""" """
dept_id = department_ids[0] if department_ids else None dept_id = department_ids[0] if department_ids else None
self._usage_tracker.record( await self._usage_tracker.record_async(
agent_name=agent_name, agent_name=agent_name,
model=model, model=model,
usage=usage, usage=usage,
@ -551,6 +550,10 @@ class LLMGateway:
Strictest-wins: if ANY department fails ANY check, raises Strictest-wins: if ANY department fails ANY check, raises
:class:`QuotaExceededError` and the request is rejected. :class:`QuotaExceededError` and the request is rejected.
Fail-closed (KTD-1): if the usage store is unavailable (Redis
degraded), raises :class:`UsageStoreUnavailableError`. The
caller must translate this to HTTP 503.
""" """
# Lazy import to avoid circular dependency (admin → ... → gateway). # Lazy import to avoid circular dependency (admin → ... → gateway).
from agentkit.server.admin.quota_service import get_quota_service from agentkit.server.admin.quota_service import get_quota_service

View File

@ -26,7 +26,7 @@ class UsageTracker:
user_id: str | None = None, user_id: str | None = None,
department_id: str | None = None, department_id: str | None = None,
) -> None: ) -> None:
"""记录一次使用""" """记录一次使用sync — 在 async 上下文中可能阻塞事件循环)"""
self._store.record( self._store.record(
agent_name, agent_name,
model, model,
@ -37,6 +37,44 @@ class UsageTracker:
department_id=department_id, department_id=department_id,
) )
async def record_async(
    self,
    agent_name: str,
    model: str,
    usage: TokenUsage,
    cost: float,
    latency_ms: float,
    user_id: str | None = None,
    department_id: str | None = None,
) -> None:
    """Record one usage event from async code (KTD-6).

    Prefers the store's own ``record_async`` coroutine when the store
    exposes one. A store without that attribute is written through its
    synchronous ``record`` instead, which runs inline on the event loop;
    that path is intended for stores that do no I/O (for example
    ``InMemoryUsageStore``).

    Args:
        agent_name: Agent the usage is attributed to.
        model: Model name the usage is attributed to.
        usage: Token counts for the call.
        cost: Cost of the call.
        latency_ms: Latency of the call in milliseconds.
        user_id: Optional user scope, forwarded by keyword.
        department_id: Optional department scope, forwarded by keyword.
    """
    event = (agent_name, model, usage, cost, latency_ms)
    scope = {"user_id": user_id, "department_id": department_id}

    async_writer = getattr(self._store, "record_async", None)
    if async_writer is None:
        # Sync-only store: no coroutine to await, write directly.
        self._store.record(*event, **scope)
        return

    await async_writer(*event, **scope)
def get_usage( def get_usage(
self, self,
agent_name: str | None = None, agent_name: str | None = None,

View File

@ -13,6 +13,7 @@ Legacy v1 keys (still readable for backward compat):
agentkit:usage_records:{date} List agentkit:usage_records:{date} List
""" """
import asyncio
import json import json
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -24,6 +25,14 @@ from agentkit.llm.protocol import TokenUsage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UsageStoreUnavailableError(Exception):
    """Signal that the usage store cannot answer a quota-critical query.

    Callers are expected to fail closed: the request is refused
    (surfaced as HTTP 503) instead of being served without usage
    tracking.
    """
@dataclass @dataclass
class UsageRecord: class UsageRecord:
"""使用量记录""" """使用量记录"""
@ -86,7 +95,7 @@ class UsageStore(Protocol):
user_id: str | None = None, user_id: str | None = None,
department_id: str | None = None, department_id: str | None = None,
) -> None: ) -> None:
"""Record a usage event.""" """Record a usage event (sync — may block in async contexts)."""
... ...
def get_usage( def get_usage(
@ -97,7 +106,12 @@ class UsageStore(Protocol):
user_id: str | None = None, user_id: str | None = None,
department_id: str | None = None, department_id: str | None = None,
) -> UsageSummary: ) -> UsageSummary:
"""Query usage summary.""" """Query usage summary.
Raises:
UsageStoreUnavailableError: When the store is degraded and
cannot answer quota-critical queries (fail-closed).
"""
... ...
@ -139,6 +153,27 @@ class InMemoryUsageStore:
if len(self._records) > self.MAX_RECORDS: if len(self._records) > self.MAX_RECORDS:
self._records = self._records[-self.MAX_RECORDS :] self._records = self._records[-self.MAX_RECORDS :]
async def record_async(
    self,
    agent_name: str,
    model: str,
    usage: TokenUsage,
    cost: float,
    latency_ms: float,
    user_id: str | None = None,
    department_id: str | None = None,
) -> None:
    """Awaitable counterpart of ``record`` for callers on an event loop.

    The in-memory store has no I/O to wait on, so this coroutine awaits
    nothing itself: it forwards every argument unchanged to the
    synchronous ``record`` and returns ``None``.
    """
    scope = {"user_id": user_id, "department_id": department_id}
    self.record(agent_name, model, usage, cost, latency_ms, **scope)
def get_usage( def get_usage(
self, self,
agent_name: str | None = None, agent_name: str | None = None,
@ -234,6 +269,19 @@ class RedisUsageStore:
Legacy v1 keys (still readable for backward compat): Legacy v1 keys (still readable for backward compat):
agentkit:usage:{YYYY-MM-DD} Hash agentkit:usage:{YYYY-MM-DD} Hash
agentkit:usage_records:{YYYY-MM-DD} List agentkit:usage_records:{YYYY-MM-DD} List
Fail-closed semantics:
When Redis is unreachable and the store degrades to the in-memory
fallback, ``get_usage`` raises :class:`UsageStoreUnavailableError`.
The gateway translates this to HTTP 503, refusing the request
rather than allowing untracked usage (KTD-1). The fallback store
is used only for ``record`` (best-effort persistence) never for
quota queries.
Degradation recovery:
A background health-check task pings Redis every 30 seconds. On
success, the degraded flag is cleared and the fallback store is
discarded (KTD-5).
""" """
USAGE_PREFIX = "agentkit:usage:" USAGE_PREFIX = "agentkit:usage:"
@ -242,6 +290,7 @@ class RedisUsageStore:
RECORDS_PREFIX_V2 = "agentkit:usage_records:v2:" RECORDS_PREFIX_V2 = "agentkit:usage_records:v2:"
MAX_RECORDS_PER_DAY = 50000 MAX_RECORDS_PER_DAY = 50000
TTL_DAYS = 90 # Auto-expire after 90 days TTL_DAYS = 90 # Auto-expire after 90 days
HEALTH_CHECK_INTERVAL = 30.0 # seconds
def __init__(self, redis_url: str = "redis://localhost:6379"): def __init__(self, redis_url: str = "redis://localhost:6379"):
self._redis_url = redis_url self._redis_url = redis_url
@ -249,6 +298,7 @@ class RedisUsageStore:
self._sync_redis: Any = None self._sync_redis: Any = None
self._fallback: InMemoryUsageStore | None = None self._fallback: InMemoryUsageStore | None = None
self._degraded = False self._degraded = False
self._health_check_task: asyncio.Task[None] | None = None
async def _get_redis(self): async def _get_redis(self):
if self._redis is None: if self._redis is None:
@ -266,6 +316,14 @@ class RedisUsageStore:
return self._sync_redis return self._sync_redis
async def aclose(self) -> None: async def aclose(self) -> None:
# Cancel the health-check task first to avoid races.
if self._health_check_task is not None:
self._health_check_task.cancel()
try:
await self._health_check_task
except (asyncio.CancelledError, Exception):
pass
self._health_check_task = None
if self._redis is not None: if self._redis is not None:
await self._redis.aclose() await self._redis.aclose()
self._redis = None self._redis = None
@ -279,6 +337,45 @@ class RedisUsageStore:
if self._fallback is None: if self._fallback is None:
self._fallback = InMemoryUsageStore() self._fallback = InMemoryUsageStore()
logger.warning("Redis usage store unreachable, degraded to in-memory") logger.warning("Redis usage store unreachable, degraded to in-memory")
self._start_health_check()
def _start_health_check(self) -> None:
    """Launch the background Redis health-check task if none is active.

    Safe to call repeatedly: while a previously created task is still
    running, the call is a no-op. A task can only be created on a running
    event loop; when the caller has none (for example a purely
    synchronous code path) the call logs at debug level and returns
    without scheduling anything.

    The scheduled coroutine is ``_health_check_loop``, which pings Redis
    every ``HEALTH_CHECK_INTERVAL`` seconds and clears the degraded state
    on success (KTD-5).
    """
    existing = self._health_check_task
    if existing is not None and not existing.done():
        return

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop (e.g. sync-only context) — skip.
        logger.debug("No event loop available for health-check task")
    else:
        self._health_check_task = running_loop.create_task(self._health_check_loop())
async def _health_check_loop(self) -> None:
    """Poll Redis until it answers, then leave degraded mode.

    While the store is flagged as degraded, sleep
    ``HEALTH_CHECK_INTERVAL`` seconds and then ping Redis through the
    async client. On the first successful ping the degraded flag is
    cleared, the in-memory fallback store is dropped and the loop ends.
    Nothing in this method replays the fallback's records into Redis.

    A failed probe is logged at debug level (with the exception attached)
    and the loop keeps going.

    Raises:
        asyncio.CancelledError: When the task is cancelled. Cancellation
            is deliberately not swallowed, so the task ends in the
            *cancelled* state; whoever cancels and awaits the task must
            handle the error.
    """
    while self._degraded:
        # CancelledError derives from BaseException, so it is not caught by
        # the ``except Exception`` below and propagates out of the task.
        await asyncio.sleep(self.HEALTH_CHECK_INTERVAL)
        try:
            redis_client = await self._get_redis()
            await redis_client.ping()
        except Exception:
            # Still unreachable: keep the reason available for debugging
            # and try again after the next interval.
            logger.debug("Redis health check still failing", exc_info=True)
            continue

        # Redis is back — clear degraded state.
        self._degraded = False
        self._fallback = None
        logger.info("Redis usage store recovered, cleared degraded state")
        return
def _today_key(self) -> str: def _today_key(self) -> str:
return datetime.now(timezone.utc).strftime("%Y-%m-%d") return datetime.now(timezone.utc).strftime("%Y-%m-%d")
@ -299,6 +396,47 @@ class RedisUsageStore:
f"{self.RECORDS_PREFIX_V2}{date_key}:{u}:{d}", f"{self.RECORDS_PREFIX_V2}{date_key}:{u}:{d}",
) )
def _build_record_json(
    self,
    agent_name: str,
    model: str,
    usage: TokenUsage,
    cost: float,
    latency_ms: float,
    user_id: str | None,
    department_id: str | None,
) -> tuple[str, str, str, str]:
    """Prepare everything a single usage write needs.

    Returns:
        A 4-tuple ``(hash_key, list_key, bucket_field, record_json)``:
        the two v2 keys for today's date and the given user/department
        scope, the ``"{agent_name}:{model}"`` prefix used for the
        aggregate hash fields, and the JSON-encoded usage record to
        append to the list.
    """
    hash_key, list_key = self._v2_keys(self._today_key(), user_id, department_id)

    # Going through UsageRecord supplies the ``timestamp`` value, which is
    # not passed in by the caller.
    entry = UsageRecord(
        agent_name=agent_name,
        model=model,
        prompt_tokens=usage.prompt_tokens,
        completion_tokens=usage.completion_tokens,
        total_tokens=usage.total_tokens,
        cost=cost,
        latency_ms=latency_ms,
        user_id=user_id,
        department_id=department_id,
    )
    # Field order is fixed so the serialized form stays stable.
    serialized_fields = (
        "agent_name",
        "model",
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "cost",
        "latency_ms",
        "timestamp",
        "user_id",
        "department_id",
    )
    payload = {name: getattr(entry, name) for name in serialized_fields}
    return hash_key, list_key, f"{agent_name}:{model}", json.dumps(payload)
def record( def record(
self, self,
agent_name: str, agent_name: str,
@ -309,11 +447,12 @@ class RedisUsageStore:
user_id: str | None = None, user_id: str | None = None,
department_id: str | None = None, department_id: str | None = None,
) -> None: ) -> None:
"""Record usage — sync wrapper for async Redis. """Record usage — sync wrapper using sync Redis client.
Note: This is a sync method because UsageTracker.record() is sync. Note: This is a sync method because UsageTracker.record() is sync.
For Redis, we use a sync Redis client for writes to avoid For Redis, we use a sync Redis client for writes to avoid
needing an event loop in the caller. needing an event loop in the caller. In async contexts prefer
:meth:`record_async` to avoid blocking the event loop (KTD-6).
""" """
if self._degraded and self._fallback is not None: if self._degraded and self._fallback is not None:
self._fallback.record( self._fallback.record(
@ -329,10 +468,9 @@ class RedisUsageStore:
try: try:
r = self._get_sync_redis() r = self._get_sync_redis()
hash_key, list_key, bucket_field, record_json = self._build_record_json(
date_key = self._today_key() agent_name, model, usage, cost, latency_ms, user_id, department_id
hash_key, list_key = self._v2_keys(date_key, user_id, department_id) )
bucket_field = f"{agent_name}:{model}"
# Atomic HINCRBYFLOAT for bucket aggregation # Atomic HINCRBYFLOAT for bucket aggregation
pipe = r.pipeline() pipe = r.pipeline()
@ -343,34 +481,7 @@ class RedisUsageStore:
pipe.hincrby(hash_key, f"{bucket_field}:count", 1) pipe.hincrby(hash_key, f"{bucket_field}:count", 1)
# Append record # Append record
rec = UsageRecord( pipe.rpush(list_key, record_json)
agent_name=agent_name,
model=model,
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
total_tokens=usage.total_tokens,
cost=cost,
latency_ms=latency_ms,
user_id=user_id,
department_id=department_id,
)
pipe.rpush(
list_key,
json.dumps(
{
"agent_name": rec.agent_name,
"model": rec.model,
"prompt_tokens": rec.prompt_tokens,
"completion_tokens": rec.completion_tokens,
"total_tokens": rec.total_tokens,
"cost": rec.cost,
"latency_ms": rec.latency_ms,
"timestamp": rec.timestamp,
"user_id": rec.user_id,
"department_id": rec.department_id,
}
),
)
pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1) pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)
# Set TTL on first write of the day # Set TTL on first write of the day
@ -392,6 +503,65 @@ class RedisUsageStore:
department_id=department_id, department_id=department_id,
) )
async def record_async(
    self,
    agent_name: str,
    model: str,
    usage: TokenUsage,
    cost: float,
    latency_ms: float,
    user_id: str | None = None,
    department_id: str | None = None,
) -> None:
    """Record usage through the async Redis client (KTD-6).

    Intended for async callers (gateway, routes) so the write does not
    block the event loop the way the synchronous :meth:`record` can.

    Behaviour mirrors the sync version:

    * Already degraded with a fallback store present: the event goes
      straight to the fallback.
    * Otherwise the aggregate counters, the record append, the list trim
      and both TTL refreshes are queued on one pipeline and executed
      together.
    * If anything in that attempt raises, a warning is logged, the store
      degrades to the in-memory fallback, and the event is written there
      when a fallback exists.
    """
    event = (agent_name, model, usage, cost, latency_ms)
    scope = {"user_id": user_id, "department_id": department_id}

    if self._degraded and self._fallback is not None:
        await self._fallback.record_async(*event, **scope)
        return

    try:
        client = await self._get_redis()
        hash_key, list_key, bucket, record_json = self._build_record_json(
            *event, user_id, department_id
        )
        ttl_seconds = self.TTL_DAYS * 86400

        batch = client.pipeline()
        # Per-bucket aggregates: cost is a float, the rest are integers.
        batch.hincrbyfloat(hash_key, f"{bucket}:cost", cost)
        integer_counters = (
            ("prompt_tokens", usage.prompt_tokens),
            ("completion_tokens", usage.completion_tokens),
            ("total_tokens", usage.total_tokens),
            ("count", 1),
        )
        for suffix, amount in integer_counters:
            batch.hincrby(hash_key, f"{bucket}:{suffix}", amount)

        # Append the raw record and cap the list length.
        batch.rpush(list_key, record_json)
        batch.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)

        batch.expire(hash_key, ttl_seconds)
        batch.expire(list_key, ttl_seconds)
        await batch.execute()
    except Exception as e:
        logger.warning(f"Redis async usage record failed: {e}")
        self._degrade_to_fallback()
        if self._fallback is not None:
            await self._fallback.record_async(*event, **scope)
def get_usage( def get_usage(
self, self,
agent_name: str | None = None, agent_name: str | None = None,
@ -405,14 +575,28 @@ class RedisUsageStore:
Scans v2 keys (filtered by user_id/department_id when provided) Scans v2 keys (filtered by user_id/department_id when provided)
and legacy v1 keys (no per-user/department scoping). Records and legacy v1 keys (no per-user/department scoping). Records
from both schemas are merged. from both schemas are merged.
Key construction (U1 fix):
- Both user_id and department_id provided direct key lookup
(exact match).
- Only user_id provided SCAN ``...:{user_id}:*`` to aggregate
across all departments the user belongs to.
- Only department_id provided SCAN ``...:*:{department_id}``
to aggregate across all users in that department.
- Neither provided SCAN ``...:*`` (all records for the date).
Fail-closed (KTD-1):
When the store is degraded (Redis unreachable), raises
:class:`UsageStoreUnavailableError`. The fallback store is
used only for ``record`` never for quota queries.
""" """
if self._degraded and self._fallback is not None: # Fail-closed: when degraded, refuse quota-critical queries.
return self._fallback.get_usage( # The in-memory fallback is only for recording (best-effort),
agent_name=agent_name, # not for quota checks — returning an empty summary would
start_time=start_time, # make quota checks pass (fail-open), which is a security bug.
end_time=end_time, if self._degraded:
user_id=user_id, raise UsageStoreUnavailableError(
department_id=department_id, "Redis usage store is degraded — cannot answer quota-critical query"
) )
try: try:
@ -429,14 +613,29 @@ class RedisUsageStore:
end_date = end.date() end_date = end.date()
while current <= end_date: while current <= end_date:
date_key = current.isoformat() date_key = current.isoformat()
# When user_id/department_id is provided, scan only the # Determine the SCAN pattern based on which scope dims are provided.
# matching scope key. Otherwise scan all scopes for that # Key format: {RECORDS_PREFIX_V2}{date}:{user_or_none}:{dept_or_none}
# date via SCAN. if user_id is not None and department_id is not None:
if user_id is not None or department_id is not None: # Both provided → direct key lookup (exact match).
list_key = f"{self.RECORDS_PREFIX_V2}{date_key}:{self._scope_key(user_id)}:{self._scope_key(department_id)}" list_key = (
f"{self.RECORDS_PREFIX_V2}{date_key}:"
f"{self._scope_key(user_id)}:{self._scope_key(department_id)}"
)
all_records.extend(self._read_list(r, list_key, start, end, agent_name)) all_records.extend(self._read_list(r, list_key, start, end, agent_name))
elif user_id is not None:
# Only user_id → aggregate across all departments.
pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:{self._scope_key(user_id)}:*"
for key in r.scan_iter(match=pattern, count=200):
all_records.extend(self._read_list(r, key, start, end, agent_name))
elif department_id is not None:
# Only department_id → aggregate across all users.
pattern = (
f"{self.RECORDS_PREFIX_V2}{date_key}:*:{self._scope_key(department_id)}"
)
for key in r.scan_iter(match=pattern, count=200):
all_records.extend(self._read_list(r, key, start, end, agent_name))
else: else:
# Scan all v2 list keys for this date. # Neither provided → scan all v2 list keys for this date.
pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:*" pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:*"
for key in r.scan_iter(match=pattern, count=200): for key in r.scan_iter(match=pattern, count=200):
all_records.extend(self._read_list(r, key, start, end, agent_name)) all_records.extend(self._read_list(r, key, start, end, agent_name))
@ -460,17 +659,13 @@ class RedisUsageStore:
return UsageSummary() return UsageSummary()
return InMemoryUsageStore._aggregate(all_records) return InMemoryUsageStore._aggregate(all_records)
except UsageStoreUnavailableError:
raise
except Exception as e: except Exception as e:
logger.warning(f"Redis usage query failed: {e}") logger.warning(f"Redis usage query failed: {e}")
if self._fallback is not None: # Degrade for future calls, but fail-closed for this one.
return self._fallback.get_usage( self._degrade_to_fallback()
agent_name=agent_name, raise UsageStoreUnavailableError(f"Redis usage query failed: {e}") from e
start_time=start_time,
end_time=end_time,
user_id=user_id,
department_id=department_id,
)
return UsageSummary()
def get_usage_by_user( def get_usage_by_user(
self, self,

View File

@ -30,7 +30,7 @@ from pathlib import Path
from typing import Any from typing import Any
import aiosqlite import aiosqlite
from fastapi import Request from fastapi import HTTPException, Request
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH
@ -134,6 +134,11 @@ async def get_department_context(request: Request) -> DepartmentContext:
If ``request.state.current_user`` is missing entirely (e.g. the If ``request.state.current_user`` is missing entirely (e.g. the
auth middleware was not installed), returns an empty context auth middleware was not installed), returns an empty context
equivalent to the unauthenticated case. equivalent to the unauthenticated case.
Fail-closed (KTD-1): if the DB lookup fails for a regular user,
raises ``HTTPException(503)``. Returning an empty list would make
quota enforcement skip the check (fail-open), which is a security
bug. Admins and API-key clients are unaffected (no DB lookup).
""" """
current_user: dict[str, Any] | None = getattr(request.state, "current_user", None) current_user: dict[str, Any] | None = getattr(request.state, "current_user", None)
if current_user is None: if current_user is None:
@ -155,15 +160,23 @@ async def get_department_context(request: Request) -> DepartmentContext:
return DepartmentContext(user_id=None, department_ids=[], is_admin=False) return DepartmentContext(user_id=None, department_ids=[], is_admin=False)
# Regular user: look up their active department ids. # Regular user: look up their active department ids.
# Fail-closed on DB errors — returning an empty list would bypass
# quota enforcement (fail-open), which is a security vulnerability.
db_path = _resolve_db_path(request) db_path = _resolve_db_path(request)
try: try:
department_ids = await _fetch_user_department_ids(db_path, user_id) department_ids = await _fetch_user_department_ids(db_path, user_id)
except Exception: # noqa: BLE001 — never block a request on DB errors except Exception:
logger.exception( logger.exception(
"Failed to fetch department ids for user %s — falling back to empty list", "Failed to fetch department ids for user %s — fail-closed (503)",
user_id, user_id,
) )
department_ids = [] raise HTTPException(
status_code=503,
detail={
"error": "department_lookup_failed",
"detail": "Cannot verify department membership — refusing request",
},
) from None
return DepartmentContext( return DepartmentContext(
user_id=user_id, user_id=user_id,

View File

@ -14,6 +14,7 @@ from pydantic import BaseModel, ConfigDict, Field
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.gateway import QuotaExceededError from agentkit.llm.gateway import QuotaExceededError
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall
from agentkit.llm.providers.usage_store import UsageStoreUnavailableError
from agentkit.server.admin.context import get_department_context from agentkit.server.admin.context import get_department_context
router = APIRouter(prefix="/llm", tags=["llm-gateway"]) router = APIRouter(prefix="/llm", tags=["llm-gateway"])
@ -109,6 +110,11 @@ async def chat(
) )
except QuotaExceededError as e: except QuotaExceededError as e:
raise HTTPException(status_code=429, detail=_quota_error_payload(e)) from e raise HTTPException(status_code=429, detail=_quota_error_payload(e)) from e
except UsageStoreUnavailableError as e:
raise HTTPException(
status_code=503,
detail={"error": "usage_store_unavailable", "detail": str(e)},
) from e
except ModelNotFoundError as e: except ModelNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) from e raise HTTPException(status_code=404, detail=str(e)) from e
except LLMProviderError as e: except LLMProviderError as e:
@ -150,6 +156,12 @@ async def chat_stream(
error_payload = _quota_error_payload(e) error_payload = _quota_error_payload(e)
error_payload["error"] = "quota_exceeded" error_payload["error"] = "quota_exceeded"
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
except UsageStoreUnavailableError as e:
error_payload = {
"error": "usage_store_unavailable",
"detail": str(e),
}
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
except ModelNotFoundError as e: except ModelNotFoundError as e:
error_payload = {"error": "model_not_found", "detail": str(e)} error_payload = {"error": "model_not_found", "detail": str(e)}
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"

View File

@ -2,6 +2,8 @@
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import pytest
from agentkit.llm.protocol import TokenUsage from agentkit.llm.protocol import TokenUsage
from agentkit.llm.providers.usage_store import ( from agentkit.llm.providers.usage_store import (
InMemoryUsageStore, InMemoryUsageStore,
@ -10,6 +12,7 @@ from agentkit.llm.providers.usage_store import (
UsageBucket, UsageBucket,
UsageSummary, UsageSummary,
UsageStore, UsageStore,
UsageStoreUnavailableError,
create_usage_store, create_usage_store,
) )
@ -315,20 +318,40 @@ class TestRedisUsageStoreMocked:
summary = store._fallback.get_usage() summary = store._fallback.get_usage()
assert len(summary.records) == 1 assert len(summary.records) == 1
def test_get_usage_degraded_uses_fallback(self): def test_get_usage_degraded_raises_unavailable(self):
"""Fail-closed (KTD-1): degraded get_usage raises UsageStoreUnavailableError."""
store = self._make_store()
store._degrade_to_fallback()
with pytest.raises(UsageStoreUnavailableError):
store.get_usage()
def test_get_usage_degraded_raises_even_with_fallback_data(self):
"""Even if fallback has data, get_usage must fail-closed (no quota bypass)."""
store = self._make_store() store = self._make_store()
store._degrade_to_fallback() store._degrade_to_fallback()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50) usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store._fallback.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200) store._fallback.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200)
summary = store.get_usage() with pytest.raises(UsageStoreUnavailableError):
assert len(summary.records) == 1 store.get_usage()
def test_get_usage_degraded_no_fallback_returns_empty(self): def test_get_usage_degraded_with_user_filter_raises(self):
"""Fail-closed applies even when filtering by user_id."""
store = self._make_store() store = self._make_store()
store._degraded = True store._degrade_to_fallback()
# No fallback set — should return empty with pytest.raises(UsageStoreUnavailableError):
summary = store.get_usage() store.get_usage(user_id="u1")
assert summary.total_tokens == 0
def test_get_usage_by_user_degraded_raises(self):
    """Per-user usage queries raise UsageStoreUnavailableError once degraded."""
    degraded_store = self._make_store()
    degraded_store._degrade_to_fallback()

    with pytest.raises(UsageStoreUnavailableError):
        degraded_store.get_usage_by_user("u1")
def test_get_usage_by_department_degraded_raises(self):
    """Per-department usage queries raise UsageStoreUnavailableError once degraded."""
    degraded_store = self._make_store()
    degraded_store._degrade_to_fallback()

    with pytest.raises(UsageStoreUnavailableError):
        degraded_store.get_usage_by_department("d1")
def test_today_key_format(self): def test_today_key_format(self):
store = self._make_store() store = self._make_store()
@ -369,65 +392,20 @@ class TestRedisUsageStoreMocked:
assert summary.records[0].user_id == "u1" assert summary.records[0].user_id == "u1"
assert summary.records[0].department_id == "d1" assert summary.records[0].department_id == "d1"
def test_get_usage_degraded_with_user_filter(self): def test_record_async_degraded_uses_fallback(self):
store = self._make_store() """Async record in degraded state uses fallback (recording is allowed)."""
store._degrade_to_fallback() import asyncio
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store._fallback.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
store._fallback.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u2",
department_id="d2",
)
summary = store.get_usage(user_id="u1")
assert len(summary.records) == 1
assert summary.records[0].user_id == "u1"
def test_get_usage_by_user_degraded(self):
store = self._make_store() store = self._make_store()
store._degrade_to_fallback() store._degrade_to_fallback()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50) usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
store._fallback.record(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
summary = store.get_usage_by_user("u1")
assert len(summary.records) == 1
assert summary.records[0].user_id == "u1"
def test_get_usage_by_department_degraded(self): async def _run():
store = self._make_store() await store.record_async("agent1", "gpt-4", usage, cost=0.05, latency_ms=200)
store._degrade_to_fallback()
usage = TokenUsage(prompt_tokens=100, completion_tokens=50) asyncio.run(_run())
store._fallback.record( summary = store._fallback.get_usage()
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
summary = store.get_usage_by_department("d1")
assert len(summary.records) == 1 assert len(summary.records) == 1
assert summary.records[0].department_id == "d1"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -448,3 +426,347 @@ class TestCreateUsageStore:
store = create_usage_store(backend="redis") store = create_usage_store(backend="redis")
# May be InMemory if redis package unavailable # May be InMemory if redis package unavailable
assert isinstance(store, (InMemoryUsageStore, RedisUsageStore)) assert isinstance(store, (InMemoryUsageStore, RedisUsageStore))
# ---------------------------------------------------------------------------
# U1: Key construction fix — SCAN patterns for partial scope queries
# ---------------------------------------------------------------------------
class TestRedisUsageStoreKeyConstruction:
    """U1: Check the SCAN patterns ``get_usage`` builds for partial scopes.

    Covers queries where only ``user_id`` OR only ``department_id`` is
    supplied (not both).

    Previously, ``get_usage(department_id=X)`` constructed the key
    ``...:none:X``, which only matched records with no user — missing
    all records from actual users in that department. The fix uses
    SCAN with pattern ``...:*:X`` to aggregate across all users.
    """

    # Serialized usage records, exactly as the mocked ``lrange`` returns them.
    # All of them are stamped 2026-06-21T10:00:00 UTC.
    _REC_U1_D1 = (
        '{"agent_name":"a","model":"m","prompt_tokens":100,'
        '"completion_tokens":50,"total_tokens":150,"cost":0.05,'
        '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
        '"user_id":"u1","department_id":"d1"}'
    )
    _REC_U2_D1 = (
        '{"agent_name":"a","model":"m","prompt_tokens":200,'
        '"completion_tokens":100,"total_tokens":300,"cost":0.10,'
        '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
        '"user_id":"u2","department_id":"d1"}'
    )
    _REC_NONE_D1 = (
        '{"agent_name":"a","model":"m","prompt_tokens":50,'
        '"completion_tokens":25,"total_tokens":75,"cost":0.02,'
        '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
        '"user_id":null,"department_id":"d1"}'
    )
    _REC_U1_D2 = (
        '{"agent_name":"a","model":"m","prompt_tokens":200,'
        '"completion_tokens":100,"total_tokens":300,"cost":0.10,'
        '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
        '"user_id":"u1","department_id":"d2"}'
    )

    @staticmethod
    def _half_day_window():
        """Return ``(start, end)`` spanning 00:00-12:00 UTC on 2026-06-21.

        The end stays inside the same day as the 10:00 UTC record timestamps.
        """
        window_start = datetime(2026, 6, 21, tzinfo=timezone.utc)
        return window_start, window_start + timedelta(hours=12)

    @staticmethod
    def _scan_pattern(redis_mock):
        """Return the match pattern passed to the latest ``scan_iter`` call.

        Accepts the pattern either as the ``match`` keyword or as the first
        positional argument.
        """
        last_call = redis_mock.scan_iter.call_args
        return last_call.kwargs.get("match") or last_call.args[0]

    def _make_store_with_mock_redis(self, mock_redis):
        """Build a ``RedisUsageStore`` whose sync client is ``mock_redis``."""
        store = RedisUsageStore(redis_url="redis://localhost:6379")
        store._sync_redis = mock_redis
        return store

    def test_get_usage_by_department_scans_all_users(self):
        """get_usage(department_id=X) should SCAN ...:*:X (all users)."""
        from unittest.mock import MagicMock

        redis_mock = MagicMock()
        # Single-day range so scan_iter is called once per scope.
        redis_mock.scan_iter.return_value = iter(
            [
                "agentkit:usage_records:v2:2026-06-21:u1:d1",
                "agentkit:usage_records:v2:2026-06-21:u2:d1",
                "agentkit:usage_records:v2:2026-06-21:none:d1",
            ]
        )
        # One lrange answer per scanned v2 key, followed by the legacy v1 key.
        redis_mock.lrange.side_effect = [
            [self._REC_U1_D1],  # for key u1:d1
            [self._REC_U2_D1],  # for key u2:d1
            [self._REC_NONE_D1],  # for key none:d1
            [],  # legacy v1 key
        ]
        store = self._make_store_with_mock_redis(redis_mock)
        window_start, window_end = self._half_day_window()

        result = store.get_usage(
            department_id="d1",
            start_time=window_start,
            end_time=window_end,
        )

        # Records from every user in d1 are summed: 150 + 300 + 75.
        assert result.total_tokens == 525
        # The scan must have used the department-wide wildcard pattern.
        assert "*:d1" in self._scan_pattern(redis_mock)

    def test_get_usage_by_user_scans_all_departments(self):
        """get_usage(user_id=X) should SCAN ...:X:* (all departments)."""
        from unittest.mock import MagicMock

        redis_mock = MagicMock()
        redis_mock.scan_iter.return_value = iter(
            [
                "agentkit:usage_records:v2:2026-06-21:u1:d1",
                "agentkit:usage_records:v2:2026-06-21:u1:d2",
            ]
        )
        redis_mock.lrange.side_effect = [
            [self._REC_U1_D1],
            [self._REC_U1_D2],
            [],  # legacy v1 key
        ]
        store = self._make_store_with_mock_redis(redis_mock)
        window_start, window_end = self._half_day_window()

        result = store.get_usage(
            user_id="u1",
            start_time=window_start,
            end_time=window_end,
        )

        assert result.total_tokens == 450  # 150 + 300
        assert "u1:*" in self._scan_pattern(redis_mock)

    def test_get_usage_both_user_and_dept_uses_direct_key(self):
        """get_usage(user_id=X, department_id=Y) → direct key lookup (no SCAN)."""
        from unittest.mock import MagicMock

        redis_mock = MagicMock()
        redis_mock.lrange.side_effect = [
            [self._REC_U1_D1],  # direct v2 key
            [],  # legacy v1 key
        ]
        store = self._make_store_with_mock_redis(redis_mock)
        window_start, window_end = self._half_day_window()

        result = store.get_usage(
            user_id="u1",
            department_id="d1",
            start_time=window_start,
            end_time=window_end,
        )

        assert result.total_tokens == 150
        # SCAN should NOT be called (direct key lookup instead).
        redis_mock.scan_iter.assert_not_called()

    def test_get_usage_no_filter_scans_all(self):
        """get_usage() with no filters → SCAN ...:* (all records)."""
        from unittest.mock import MagicMock

        redis_mock = MagicMock()
        redis_mock.scan_iter.return_value = iter(["agentkit:usage_records:v2:2026-06-21:u1:d1"])
        redis_mock.lrange.side_effect = [
            [self._REC_U1_D1],
            [],  # legacy v1 key
        ]
        store = self._make_store_with_mock_redis(redis_mock)
        window_start, window_end = self._half_day_window()

        result = store.get_usage(start_time=window_start, end_time=window_end)

        assert result.total_tokens == 150
        redis_mock.scan_iter.assert_called_once()

    def test_get_usage_empty_redis_returns_empty_summary(self):
        """No records in Redis → empty UsageSummary (not an error)."""
        from unittest.mock import MagicMock

        redis_mock = MagicMock()
        redis_mock.scan_iter.return_value = iter([])
        redis_mock.lrange.return_value = []
        store = self._make_store_with_mock_redis(redis_mock)
        window_start, window_end = self._half_day_window()

        result = store.get_usage(
            department_id="d1",
            start_time=window_start,
            end_time=window_end,
        )

        assert result.total_tokens == 0
        assert len(result.records) == 0
# ---------------------------------------------------------------------------
# U1: Degradation recovery — health check clears degraded state
# ---------------------------------------------------------------------------
class TestRedisUsageStoreDegradationRecovery:
    """U1 (KTD-5): Redis recovery clears degraded state via health check."""

    @staticmethod
    async def _drive_health_check(store, **ping_behaviour):
        """Run ``store._health_check_loop`` briefly against a fake client.

        ``_get_redis`` is patched to return an ``AsyncMock`` whose ``ping``
        is configured from ``ping_behaviour`` (``return_value=...`` or
        ``side_effect=...``). The loop task is started with a very short
        interval, left running for 0.05 s, then cancelled and awaited.
        """
        import asyncio
        from unittest.mock import AsyncMock, patch

        fake_redis = AsyncMock()
        fake_redis.ping = AsyncMock(**ping_behaviour)
        with patch.object(store, "_get_redis", return_value=fake_redis):
            # Shrink the interval so the loop gets to act within the sleep below.
            store.HEALTH_CHECK_INTERVAL = 0.01
            loop_task = asyncio.create_task(store._health_check_loop())
            await asyncio.sleep(0.05)
            loop_task.cancel()
            try:
                await loop_task
            except asyncio.CancelledError:
                pass

    def test_health_check_clears_degraded_on_recovery(self):
        """When Redis recovers, _degraded is cleared and _fallback discarded."""
        import asyncio

        store = RedisUsageStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()
        assert store._degraded
        assert store._fallback is not None

        # The fake client's ping succeeds, i.e. Redis is reachable again.
        asyncio.run(self._drive_health_check(store, return_value=True))

        assert not store._degraded
        assert store._fallback is None

    def test_health_check_keeps_degraded_on_failure(self):
        """When Redis is still down, degraded state persists."""
        import asyncio

        store = RedisUsageStore(redis_url="redis://localhost:6379")
        store._degrade_to_fallback()

        # Every ping raises, so the store has no reason to recover.
        asyncio.run(
            self._drive_health_check(store, side_effect=ConnectionError("still down"))
        )

        assert store._degraded  # Still degraded

    def test_aclose_cancels_health_check_task(self):
        """aclose() cancels the health check task."""
        import asyncio

        store = RedisUsageStore(redis_url="redis://localhost:6379")

        async def _scenario():
            # Degrade while an event loop is running; the test expects a
            # health check task to exist afterwards.
            store._degrade_to_fallback()
            assert store._health_check_task is not None
            await store.aclose()
            assert store._health_check_task is None

        asyncio.run(_scenario())
# ---------------------------------------------------------------------------
# U1: Fail-closed — get_usage raises when Redis query fails
# ---------------------------------------------------------------------------
class TestRedisUsageStoreFailClosed:
    """U1 (KTD-1): Redis query failure → UsageStoreUnavailableError (not empty)."""

    @staticmethod
    def _store_with_failing_scan():
        """Return a store whose sync client raises on every ``scan_iter``."""
        from unittest.mock import MagicMock

        broken_redis = MagicMock()
        broken_redis.scan_iter.side_effect = ConnectionError("Redis gone")
        store = RedisUsageStore(redis_url="redis://localhost:6379")
        store._sync_redis = broken_redis
        return store

    def test_get_usage_raises_on_redis_connection_failure(self):
        """When Redis connection fails during get_usage, raise (fail-closed)."""
        store = self._store_with_failing_scan()
        window_start = datetime(2026, 6, 21, tzinfo=timezone.utc)
        window_end = window_start + timedelta(days=1)

        with pytest.raises(UsageStoreUnavailableError):
            store.get_usage(
                department_id="d1",
                start_time=window_start,
                end_time=window_end,
            )

        # Should also degrade for future calls.
        assert store._degraded

    def test_get_usage_fail_closed_not_empty_summary(self):
        """Critical: must NOT return empty summary on Redis failure (that's fail-open)."""
        store = self._store_with_failing_scan()
        window_start = datetime(2026, 6, 21, tzinfo=timezone.utc)
        window_end = window_start + timedelta(days=1)

        # Must raise, not return UsageSummary().
        with pytest.raises(UsageStoreUnavailableError):
            store.get_usage(start_time=window_start, end_time=window_end)
# ---------------------------------------------------------------------------
# U1: Async record — record_async uses redis.asyncio
# ---------------------------------------------------------------------------
class TestRedisUsageStoreAsyncRecord:
"""U1 (KTD-6): record_async uses async Redis client, doesn't block event loop."""
def test_record_async_writes_to_redis(self):
"""record_async should write to Redis via async client."""
import asyncio
from unittest.mock import AsyncMock, MagicMock
store = RedisUsageStore(redis_url="redis://localhost:6379")
mock_pipeline = MagicMock()
mock_pipeline.execute = AsyncMock(return_value=[True])
mock_redis = MagicMock()
mock_redis.pipeline.return_value = mock_pipeline
async def _run():
from unittest.mock import patch
with patch.object(store, "_get_redis", return_value=mock_redis):
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
await store.record_async(
"agent1",
"gpt-4",
usage,
cost=0.05,
latency_ms=200,
user_id="u1",
department_id="d1",
)
asyncio.run(_run())
# Verify pipeline was used.
mock_redis.pipeline.assert_called_once()
mock_pipeline.hincrbyfloat.assert_called_once()
mock_pipeline.rpush.assert_called_once()