fix(review): U1 Redis quota enforcement — key construction + fail-closed + degradation recovery + async
This commit is contained in:
parent
abe2a66436
commit
00c8386939
|
|
@ -168,7 +168,7 @@ class LLMGateway:
|
||||||
result = await self._cache.get(cache_key)
|
result = await self._cache.get(cache_key)
|
||||||
if result.hit:
|
if result.hit:
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
self._record_usage(
|
await self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=result.response.model,
|
model=result.response.model,
|
||||||
usage=result.response.usage,
|
usage=result.response.usage,
|
||||||
|
|
@ -197,7 +197,7 @@ class LLMGateway:
|
||||||
result = await self._cache.semantic_search(query_embedding)
|
result = await self._cache.semantic_search(query_embedding)
|
||||||
if result.hit:
|
if result.hit:
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
self._record_usage(
|
await self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=result.response.model,
|
model=result.response.model,
|
||||||
usage=result.response.usage,
|
usage=result.response.usage,
|
||||||
|
|
@ -245,7 +245,7 @@ class LLMGateway:
|
||||||
if response.usage:
|
if response.usage:
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
cost = self._calculate_cost(model_name, response.usage)
|
cost = self._calculate_cost(model_name, response.usage)
|
||||||
self._record_usage(
|
await self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=model_name,
|
model=model_name,
|
||||||
usage=response.usage,
|
usage=response.usage,
|
||||||
|
|
@ -286,7 +286,7 @@ class LLMGateway:
|
||||||
cost = self._calculate_cost(response.model, response.usage)
|
cost = self._calculate_cost(response.model, response.usage)
|
||||||
|
|
||||||
# 记录使用量
|
# 记录使用量
|
||||||
self._record_usage(
|
await self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=response.model,
|
model=response.model,
|
||||||
usage=response.usage,
|
usage=response.usage,
|
||||||
|
|
@ -406,7 +406,7 @@ class LLMGateway:
|
||||||
if final_usage is None:
|
if final_usage is None:
|
||||||
final_usage = TokenUsage()
|
final_usage = TokenUsage()
|
||||||
cost = self._calculate_cost(final_model, final_usage)
|
cost = self._calculate_cost(final_model, final_usage)
|
||||||
self._record_usage(
|
await self._record_usage(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=final_model,
|
model=final_model,
|
||||||
usage=final_usage,
|
usage=final_usage,
|
||||||
|
|
@ -512,7 +512,7 @@ class LLMGateway:
|
||||||
# Quota enforcement helpers (U7)
|
# Quota enforcement helpers (U7)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _record_usage(
|
async def _record_usage(
|
||||||
self,
|
self,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
model: str,
|
model: str,
|
||||||
|
|
@ -522,16 +522,15 @@ class LLMGateway:
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
department_ids: list[str] | None,
|
department_ids: list[str] | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record a usage event, attaching user_id and (first) department_id.
|
"""Record a usage event via the async store interface (KTD-6).
|
||||||
|
|
||||||
We attach only the first department_id to the record because
|
Attaches ``user_id`` and the first ``department_id`` to the
|
||||||
usage attribution is per-department. If a user belongs to
|
record. Multi-department attribution is handled by the caller
|
||||||
multiple departments, the caller is responsible for choosing
|
(see U2 — when a user belongs to multiple departments, each
|
||||||
which department to bill — the gateway just records what it's
|
department gets its own record).
|
||||||
told.
|
|
||||||
"""
|
"""
|
||||||
dept_id = department_ids[0] if department_ids else None
|
dept_id = department_ids[0] if department_ids else None
|
||||||
self._usage_tracker.record(
|
await self._usage_tracker.record_async(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
model=model,
|
model=model,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
|
|
@ -551,6 +550,10 @@ class LLMGateway:
|
||||||
|
|
||||||
Strictest-wins: if ANY department fails ANY check, raises
|
Strictest-wins: if ANY department fails ANY check, raises
|
||||||
:class:`QuotaExceededError` and the request is rejected.
|
:class:`QuotaExceededError` and the request is rejected.
|
||||||
|
|
||||||
|
Fail-closed (KTD-1): if the usage store is unavailable (Redis
|
||||||
|
degraded), raises :class:`UsageStoreUnavailableError`. The
|
||||||
|
caller must translate this to HTTP 503.
|
||||||
"""
|
"""
|
||||||
# Lazy import to avoid circular dependency (admin → ... → gateway).
|
# Lazy import to avoid circular dependency (admin → ... → gateway).
|
||||||
from agentkit.server.admin.quota_service import get_quota_service
|
from agentkit.server.admin.quota_service import get_quota_service
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ class UsageTracker:
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
department_id: str | None = None,
|
department_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""记录一次使用"""
|
"""记录一次使用(sync — 在 async 上下文中可能阻塞事件循环)"""
|
||||||
self._store.record(
|
self._store.record(
|
||||||
agent_name,
|
agent_name,
|
||||||
model,
|
model,
|
||||||
|
|
@ -37,6 +37,44 @@ class UsageTracker:
|
||||||
department_id=department_id,
|
department_id=department_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def record_async(
    self,
    agent_name: str,
    model: str,
    usage: TokenUsage,
    cost: float,
    latency_ms: float,
    user_id: str | None = None,
    department_id: str | None = None,
) -> None:
    """Record one usage event without blocking the event loop (KTD-6).

    If the underlying store exposes a ``record_async`` attribute, that
    coroutine is awaited. Otherwise the call falls back to the store's
    synchronous ``record`` (intended for stores that do no I/O, such as
    the in-memory store).

    Args:
        agent_name: Name of the agent the usage is attributed to.
        model: Model identifier the usage is attributed to.
        usage: Token usage for the call.
        cost: Cost of the call.
        latency_ms: Latency of the call in milliseconds.
        user_id: Optional user the usage is attributed to.
        department_id: Optional department the usage is attributed to.
    """
    # Both code paths forward exactly the same arguments, so build them once.
    positional = (agent_name, model, usage, cost, latency_ms)
    scope = {"user_id": user_id, "department_id": department_id}

    async_writer = getattr(self._store, "record_async", None)
    if async_writer is None:
        # Fallback: sync store without async support.
        self._store.record(*positional, **scope)
        return

    await async_writer(*positional, **scope)
|
||||||
|
|
||||||
def get_usage(
|
def get_usage(
|
||||||
self,
|
self,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ Legacy v1 keys (still readable for backward compat):
|
||||||
agentkit:usage_records:{date} → List
|
agentkit:usage_records:{date} → List
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
@ -24,6 +25,14 @@ from agentkit.llm.protocol import TokenUsage
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UsageStoreUnavailableError(Exception):
    """Signal that the usage store cannot serve a quota-critical query.

    This is the fail-closed signal: the gateway maps it to HTTP 503 and
    refuses the request instead of letting usage go untracked.
    """
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UsageRecord:
|
class UsageRecord:
|
||||||
"""使用量记录"""
|
"""使用量记录"""
|
||||||
|
|
@ -86,7 +95,7 @@ class UsageStore(Protocol):
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
department_id: str | None = None,
|
department_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record a usage event."""
|
"""Record a usage event (sync — may block in async contexts)."""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_usage(
|
def get_usage(
|
||||||
|
|
@ -97,7 +106,12 @@ class UsageStore(Protocol):
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
department_id: str | None = None,
|
department_id: str | None = None,
|
||||||
) -> UsageSummary:
|
) -> UsageSummary:
|
||||||
"""Query usage summary."""
|
"""Query usage summary.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
UsageStoreUnavailableError: When the store is degraded and
|
||||||
|
cannot answer quota-critical queries (fail-closed).
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -139,6 +153,27 @@ class InMemoryUsageStore:
|
||||||
if len(self._records) > self.MAX_RECORDS:
|
if len(self._records) > self.MAX_RECORDS:
|
||||||
self._records = self._records[-self.MAX_RECORDS :]
|
self._records = self._records[-self.MAX_RECORDS :]
|
||||||
|
|
||||||
|
async def record_async(
    self,
    agent_name: str,
    model: str,
    usage: TokenUsage,
    cost: float,
    latency_ms: float,
    user_id: str | None = None,
    department_id: str | None = None,
) -> None:
    """Async-compatible entry point that delegates to the sync ``record``.

    The in-memory store performs no I/O, so there is nothing to await;
    the coroutine exists so callers can use one async interface for
    every store.
    """
    scope = {"user_id": user_id, "department_id": department_id}
    self.record(agent_name, model, usage, cost, latency_ms, **scope)
|
||||||
|
|
||||||
def get_usage(
|
def get_usage(
|
||||||
self,
|
self,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
|
|
@ -234,6 +269,19 @@ class RedisUsageStore:
|
||||||
Legacy v1 keys (still readable for backward compat):
|
Legacy v1 keys (still readable for backward compat):
|
||||||
agentkit:usage:{YYYY-MM-DD} → Hash
|
agentkit:usage:{YYYY-MM-DD} → Hash
|
||||||
agentkit:usage_records:{YYYY-MM-DD} → List
|
agentkit:usage_records:{YYYY-MM-DD} → List
|
||||||
|
|
||||||
|
Fail-closed semantics:
|
||||||
|
When Redis is unreachable and the store degrades to the in-memory
|
||||||
|
fallback, ``get_usage`` raises :class:`UsageStoreUnavailableError`.
|
||||||
|
The gateway translates this to HTTP 503, refusing the request
|
||||||
|
rather than allowing untracked usage (KTD-1). The fallback store
|
||||||
|
is used only for ``record`` (best-effort persistence) — never for
|
||||||
|
quota queries.
|
||||||
|
|
||||||
|
Degradation recovery:
|
||||||
|
A background health-check task pings Redis every 30 seconds. On
|
||||||
|
success, the degraded flag is cleared and the fallback store is
|
||||||
|
discarded (KTD-5).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
USAGE_PREFIX = "agentkit:usage:"
|
USAGE_PREFIX = "agentkit:usage:"
|
||||||
|
|
@ -242,6 +290,7 @@ class RedisUsageStore:
|
||||||
RECORDS_PREFIX_V2 = "agentkit:usage_records:v2:"
|
RECORDS_PREFIX_V2 = "agentkit:usage_records:v2:"
|
||||||
MAX_RECORDS_PER_DAY = 50000
|
MAX_RECORDS_PER_DAY = 50000
|
||||||
TTL_DAYS = 90 # Auto-expire after 90 days
|
TTL_DAYS = 90 # Auto-expire after 90 days
|
||||||
|
HEALTH_CHECK_INTERVAL = 30.0 # seconds
|
||||||
|
|
||||||
def __init__(self, redis_url: str = "redis://localhost:6379"):
|
def __init__(self, redis_url: str = "redis://localhost:6379"):
|
||||||
self._redis_url = redis_url
|
self._redis_url = redis_url
|
||||||
|
|
@ -249,6 +298,7 @@ class RedisUsageStore:
|
||||||
self._sync_redis: Any = None
|
self._sync_redis: Any = None
|
||||||
self._fallback: InMemoryUsageStore | None = None
|
self._fallback: InMemoryUsageStore | None = None
|
||||||
self._degraded = False
|
self._degraded = False
|
||||||
|
self._health_check_task: asyncio.Task[None] | None = None
|
||||||
|
|
||||||
async def _get_redis(self):
|
async def _get_redis(self):
|
||||||
if self._redis is None:
|
if self._redis is None:
|
||||||
|
|
@ -266,6 +316,14 @@ class RedisUsageStore:
|
||||||
return self._sync_redis
|
return self._sync_redis
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
|
# Cancel the health-check task first to avoid races.
|
||||||
|
if self._health_check_task is not None:
|
||||||
|
self._health_check_task.cancel()
|
||||||
|
try:
|
||||||
|
await self._health_check_task
|
||||||
|
except (asyncio.CancelledError, Exception):
|
||||||
|
pass
|
||||||
|
self._health_check_task = None
|
||||||
if self._redis is not None:
|
if self._redis is not None:
|
||||||
await self._redis.aclose()
|
await self._redis.aclose()
|
||||||
self._redis = None
|
self._redis = None
|
||||||
|
|
@ -279,6 +337,45 @@ class RedisUsageStore:
|
||||||
if self._fallback is None:
|
if self._fallback is None:
|
||||||
self._fallback = InMemoryUsageStore()
|
self._fallback = InMemoryUsageStore()
|
||||||
logger.warning("Redis usage store unreachable, degraded to in-memory")
|
logger.warning("Redis usage store unreachable, degraded to in-memory")
|
||||||
|
self._start_health_check()
|
||||||
|
|
||||||
|
def _start_health_check(self) -> None:
    """Start the background Redis health-check task if none is active.

    Idempotent: when a previously created task exists and has not
    finished, nothing happens. The task itself is
    ``self._health_check_loop()``, scheduled on the currently running
    event loop and stored in ``self._health_check_task``.

    When no event loop is running the task cannot be scheduled and the
    call returns after a debug log.
    """
    existing = self._health_check_task
    already_running = existing is not None and not existing.done()
    if already_running:
        return

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop (e.g. sync-only context) — skip.
        # NOTE(review): nothing in this method retries later; confirm that
        # some caller starts the health check again once a loop exists,
        # otherwise the store may stay degraded.
        logger.debug("No event loop available for health-check task")
        return

    self._health_check_task = running_loop.create_task(self._health_check_loop())
|
||||||
|
|
||||||
|
async def _health_check_loop(self) -> None:
    """Ping Redis periodically and leave degraded mode once it answers.

    Runs while ``self._degraded`` is true. Each iteration sleeps for
    ``HEALTH_CHECK_INTERVAL`` seconds and then pings Redis through the
    async client. On the first successful ping the degraded flag is
    cleared, the reference to the in-memory fallback store is dropped
    (this loop does not replay its contents to Redis), and the loop ends.

    Raises:
        asyncio.CancelledError: When the task is cancelled. Cancellation
            is deliberately propagated rather than swallowed so the task
            ends in the *cancelled* state, as asyncio expects.
    """
    while self._degraded:
        # Sleep before probing: the loop is started right after a failure,
        # so an immediate ping would most likely fail again.
        await asyncio.sleep(self.HEALTH_CHECK_INTERVAL)

        try:
            redis_client = await self._get_redis()
            await redis_client.ping()
        except Exception:
            # CancelledError derives from BaseException and is therefore
            # not caught here; only genuine probe failures are.
            logger.debug("Redis health check still failing", exc_info=True)
            # Keep looping until Redis recovers.
            continue

        # Redis is back — clear degraded state. There is no await between
        # the two assignments, so other coroutines never observe a
        # half-updated state.
        self._degraded = False
        self._fallback = None
        logger.info("Redis usage store recovered, cleared degraded state")
        return
|
||||||
|
|
||||||
def _today_key(self) -> str:
|
def _today_key(self) -> str:
|
||||||
return datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
return datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
|
@ -299,6 +396,47 @@ class RedisUsageStore:
|
||||||
f"{self.RECORDS_PREFIX_V2}{date_key}:{u}:{d}",
|
f"{self.RECORDS_PREFIX_V2}{date_key}:{u}:{d}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _build_record_json(
    self,
    agent_name: str,
    model: str,
    usage: TokenUsage,
    cost: float,
    latency_ms: float,
    user_id: str | None,
    department_id: str | None,
) -> tuple[str, str, str, str]:
    """Prepare everything a single usage write needs.

    Returns:
        A ``(hash_key, list_key, bucket_field, record_json)`` tuple:
        the two v2 keys for today's date and the given scope, the
        ``"{agent_name}:{model}"`` hash-field prefix, and the record
        serialized as JSON.
    """
    date_key = self._today_key()
    hash_key, list_key = self._v2_keys(date_key, user_id, department_id)

    rec = UsageRecord(
        agent_name=agent_name,
        model=model,
        prompt_tokens=usage.prompt_tokens,
        completion_tokens=usage.completion_tokens,
        total_tokens=usage.total_tokens,
        cost=cost,
        latency_ms=latency_ms,
        user_id=user_id,
        department_id=department_id,
    )

    # Serialized fields, in the order they appear in the stored JSON.
    serialized_fields = (
        "agent_name",
        "model",
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "cost",
        "latency_ms",
        "timestamp",
        "user_id",
        "department_id",
    )
    payload = {name: getattr(rec, name) for name in serialized_fields}

    return hash_key, list_key, f"{agent_name}:{model}", json.dumps(payload)
|
||||||
|
|
||||||
def record(
|
def record(
|
||||||
self,
|
self,
|
||||||
agent_name: str,
|
agent_name: str,
|
||||||
|
|
@ -309,11 +447,12 @@ class RedisUsageStore:
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
department_id: str | None = None,
|
department_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Record usage — sync wrapper for async Redis.
|
"""Record usage — sync wrapper using sync Redis client.
|
||||||
|
|
||||||
Note: This is a sync method because UsageTracker.record() is sync.
|
Note: This is a sync method because UsageTracker.record() is sync.
|
||||||
For Redis, we use a sync Redis client for writes to avoid
|
For Redis, we use a sync Redis client for writes to avoid
|
||||||
needing an event loop in the caller.
|
needing an event loop in the caller. In async contexts prefer
|
||||||
|
:meth:`record_async` to avoid blocking the event loop (KTD-6).
|
||||||
"""
|
"""
|
||||||
if self._degraded and self._fallback is not None:
|
if self._degraded and self._fallback is not None:
|
||||||
self._fallback.record(
|
self._fallback.record(
|
||||||
|
|
@ -329,10 +468,9 @@ class RedisUsageStore:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
r = self._get_sync_redis()
|
r = self._get_sync_redis()
|
||||||
|
hash_key, list_key, bucket_field, record_json = self._build_record_json(
|
||||||
date_key = self._today_key()
|
agent_name, model, usage, cost, latency_ms, user_id, department_id
|
||||||
hash_key, list_key = self._v2_keys(date_key, user_id, department_id)
|
)
|
||||||
bucket_field = f"{agent_name}:{model}"
|
|
||||||
|
|
||||||
# Atomic HINCRBYFLOAT for bucket aggregation
|
# Atomic HINCRBYFLOAT for bucket aggregation
|
||||||
pipe = r.pipeline()
|
pipe = r.pipeline()
|
||||||
|
|
@ -343,34 +481,7 @@ class RedisUsageStore:
|
||||||
pipe.hincrby(hash_key, f"{bucket_field}:count", 1)
|
pipe.hincrby(hash_key, f"{bucket_field}:count", 1)
|
||||||
|
|
||||||
# Append record
|
# Append record
|
||||||
rec = UsageRecord(
|
pipe.rpush(list_key, record_json)
|
||||||
agent_name=agent_name,
|
|
||||||
model=model,
|
|
||||||
prompt_tokens=usage.prompt_tokens,
|
|
||||||
completion_tokens=usage.completion_tokens,
|
|
||||||
total_tokens=usage.total_tokens,
|
|
||||||
cost=cost,
|
|
||||||
latency_ms=latency_ms,
|
|
||||||
user_id=user_id,
|
|
||||||
department_id=department_id,
|
|
||||||
)
|
|
||||||
pipe.rpush(
|
|
||||||
list_key,
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"agent_name": rec.agent_name,
|
|
||||||
"model": rec.model,
|
|
||||||
"prompt_tokens": rec.prompt_tokens,
|
|
||||||
"completion_tokens": rec.completion_tokens,
|
|
||||||
"total_tokens": rec.total_tokens,
|
|
||||||
"cost": rec.cost,
|
|
||||||
"latency_ms": rec.latency_ms,
|
|
||||||
"timestamp": rec.timestamp,
|
|
||||||
"user_id": rec.user_id,
|
|
||||||
"department_id": rec.department_id,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
)
|
|
||||||
pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)
|
pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)
|
||||||
|
|
||||||
# Set TTL on first write of the day
|
# Set TTL on first write of the day
|
||||||
|
|
@ -392,6 +503,65 @@ class RedisUsageStore:
|
||||||
department_id=department_id,
|
department_id=department_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def record_async(
    self,
    agent_name: str,
    model: str,
    usage: TokenUsage,
    cost: float,
    latency_ms: float,
    user_id: str | None = None,
    department_id: str | None = None,
) -> None:
    """Record usage through the async Redis client (KTD-6).

    Preferred over :meth:`record` in async contexts because it does not
    block the event loop. While the store is degraded, the event goes
    straight to the in-memory fallback. If the Redis write raises, the
    store degrades and the event is written to the fallback instead,
    mirroring the sync version.
    """
    scope = {"user_id": user_id, "department_id": department_id}

    if self._degraded and self._fallback is not None:
        await self._fallback.record_async(
            agent_name, model, usage, cost, latency_ms, **scope
        )
        return

    try:
        client = await self._get_redis()
        hash_key, list_key, bucket_field, record_json = self._build_record_json(
            agent_name, model, usage, cost, latency_ms, user_id, department_id
        )
        ttl_seconds = self.TTL_DAYS * 86400
        counters = (
            ("prompt_tokens", usage.prompt_tokens),
            ("completion_tokens", usage.completion_tokens),
            ("total_tokens", usage.total_tokens),
            ("count", 1),
        )

        # Queue every command on one pipeline and send them together.
        batch = client.pipeline()
        batch.hincrbyfloat(hash_key, f"{bucket_field}:cost", cost)
        for suffix, amount in counters:
            batch.hincrby(hash_key, f"{bucket_field}:{suffix}", amount)
        batch.rpush(list_key, record_json)
        # Keep only the newest MAX_RECORDS_PER_DAY entries in the list.
        batch.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)
        for key in (hash_key, list_key):
            batch.expire(key, ttl_seconds)
        await batch.execute()
    except Exception as exc:
        logger.warning(f"Redis async usage record failed: {exc}")
        self._degrade_to_fallback()
        if self._fallback is not None:
            await self._fallback.record_async(
                agent_name, model, usage, cost, latency_ms, **scope
            )
|
||||||
|
|
||||||
def get_usage(
|
def get_usage(
|
||||||
self,
|
self,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
|
|
@ -405,14 +575,28 @@ class RedisUsageStore:
|
||||||
Scans v2 keys (filtered by user_id/department_id when provided)
|
Scans v2 keys (filtered by user_id/department_id when provided)
|
||||||
and legacy v1 keys (no per-user/department scoping). Records
|
and legacy v1 keys (no per-user/department scoping). Records
|
||||||
from both schemas are merged.
|
from both schemas are merged.
|
||||||
|
|
||||||
|
Key construction (U1 fix):
|
||||||
|
- Both user_id and department_id provided → direct key lookup
|
||||||
|
(exact match).
|
||||||
|
- Only user_id provided → SCAN ``...:{user_id}:*`` to aggregate
|
||||||
|
across all departments the user belongs to.
|
||||||
|
- Only department_id provided → SCAN ``...:*:{department_id}``
|
||||||
|
to aggregate across all users in that department.
|
||||||
|
- Neither provided → SCAN ``...:*`` (all records for the date).
|
||||||
|
|
||||||
|
Fail-closed (KTD-1):
|
||||||
|
When the store is degraded (Redis unreachable), raises
|
||||||
|
:class:`UsageStoreUnavailableError`. The fallback store is
|
||||||
|
used only for ``record`` — never for quota queries.
|
||||||
"""
|
"""
|
||||||
if self._degraded and self._fallback is not None:
|
# Fail-closed: when degraded, refuse quota-critical queries.
|
||||||
return self._fallback.get_usage(
|
# The in-memory fallback is only for recording (best-effort),
|
||||||
agent_name=agent_name,
|
# not for quota checks — returning an empty summary would
|
||||||
start_time=start_time,
|
# make quota checks pass (fail-open), which is a security bug.
|
||||||
end_time=end_time,
|
if self._degraded:
|
||||||
user_id=user_id,
|
raise UsageStoreUnavailableError(
|
||||||
department_id=department_id,
|
"Redis usage store is degraded — cannot answer quota-critical query"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -429,14 +613,29 @@ class RedisUsageStore:
|
||||||
end_date = end.date()
|
end_date = end.date()
|
||||||
while current <= end_date:
|
while current <= end_date:
|
||||||
date_key = current.isoformat()
|
date_key = current.isoformat()
|
||||||
# When user_id/department_id is provided, scan only the
|
# Determine the SCAN pattern based on which scope dims are provided.
|
||||||
# matching scope key. Otherwise scan all scopes for that
|
# Key format: {RECORDS_PREFIX_V2}{date}:{user_or_none}:{dept_or_none}
|
||||||
# date via SCAN.
|
if user_id is not None and department_id is not None:
|
||||||
if user_id is not None or department_id is not None:
|
# Both provided → direct key lookup (exact match).
|
||||||
list_key = f"{self.RECORDS_PREFIX_V2}{date_key}:{self._scope_key(user_id)}:{self._scope_key(department_id)}"
|
list_key = (
|
||||||
|
f"{self.RECORDS_PREFIX_V2}{date_key}:"
|
||||||
|
f"{self._scope_key(user_id)}:{self._scope_key(department_id)}"
|
||||||
|
)
|
||||||
all_records.extend(self._read_list(r, list_key, start, end, agent_name))
|
all_records.extend(self._read_list(r, list_key, start, end, agent_name))
|
||||||
|
elif user_id is not None:
|
||||||
|
# Only user_id → aggregate across all departments.
|
||||||
|
pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:{self._scope_key(user_id)}:*"
|
||||||
|
for key in r.scan_iter(match=pattern, count=200):
|
||||||
|
all_records.extend(self._read_list(r, key, start, end, agent_name))
|
||||||
|
elif department_id is not None:
|
||||||
|
# Only department_id → aggregate across all users.
|
||||||
|
pattern = (
|
||||||
|
f"{self.RECORDS_PREFIX_V2}{date_key}:*:{self._scope_key(department_id)}"
|
||||||
|
)
|
||||||
|
for key in r.scan_iter(match=pattern, count=200):
|
||||||
|
all_records.extend(self._read_list(r, key, start, end, agent_name))
|
||||||
else:
|
else:
|
||||||
# Scan all v2 list keys for this date.
|
# Neither provided → scan all v2 list keys for this date.
|
||||||
pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:*"
|
pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:*"
|
||||||
for key in r.scan_iter(match=pattern, count=200):
|
for key in r.scan_iter(match=pattern, count=200):
|
||||||
all_records.extend(self._read_list(r, key, start, end, agent_name))
|
all_records.extend(self._read_list(r, key, start, end, agent_name))
|
||||||
|
|
@ -460,17 +659,13 @@ class RedisUsageStore:
|
||||||
return UsageSummary()
|
return UsageSummary()
|
||||||
|
|
||||||
return InMemoryUsageStore._aggregate(all_records)
|
return InMemoryUsageStore._aggregate(all_records)
|
||||||
|
except UsageStoreUnavailableError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Redis usage query failed: {e}")
|
logger.warning(f"Redis usage query failed: {e}")
|
||||||
if self._fallback is not None:
|
# Degrade for future calls, but fail-closed for this one.
|
||||||
return self._fallback.get_usage(
|
self._degrade_to_fallback()
|
||||||
agent_name=agent_name,
|
raise UsageStoreUnavailableError(f"Redis usage query failed: {e}") from e
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
user_id=user_id,
|
|
||||||
department_id=department_id,
|
|
||||||
)
|
|
||||||
return UsageSummary()
|
|
||||||
|
|
||||||
def get_usage_by_user(
|
def get_usage_by_user(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from fastapi import Request
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH
|
from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH
|
||||||
|
|
||||||
|
|
@ -134,6 +134,11 @@ async def get_department_context(request: Request) -> DepartmentContext:
|
||||||
If ``request.state.current_user`` is missing entirely (e.g. the
|
If ``request.state.current_user`` is missing entirely (e.g. the
|
||||||
auth middleware was not installed), returns an empty context
|
auth middleware was not installed), returns an empty context
|
||||||
equivalent to the unauthenticated case.
|
equivalent to the unauthenticated case.
|
||||||
|
|
||||||
|
Fail-closed (KTD-1): if the DB lookup fails for a regular user,
|
||||||
|
raises ``HTTPException(503)``. Returning an empty list would make
|
||||||
|
quota enforcement skip the check (fail-open), which is a security
|
||||||
|
bug. Admins and API-key clients are unaffected (no DB lookup).
|
||||||
"""
|
"""
|
||||||
current_user: dict[str, Any] | None = getattr(request.state, "current_user", None)
|
current_user: dict[str, Any] | None = getattr(request.state, "current_user", None)
|
||||||
if current_user is None:
|
if current_user is None:
|
||||||
|
|
@ -155,15 +160,23 @@ async def get_department_context(request: Request) -> DepartmentContext:
|
||||||
return DepartmentContext(user_id=None, department_ids=[], is_admin=False)
|
return DepartmentContext(user_id=None, department_ids=[], is_admin=False)
|
||||||
|
|
||||||
# Regular user: look up their active department ids.
|
# Regular user: look up their active department ids.
|
||||||
|
# Fail-closed on DB errors — returning an empty list would bypass
|
||||||
|
# quota enforcement (fail-open), which is a security vulnerability.
|
||||||
db_path = _resolve_db_path(request)
|
db_path = _resolve_db_path(request)
|
||||||
try:
|
try:
|
||||||
department_ids = await _fetch_user_department_ids(db_path, user_id)
|
department_ids = await _fetch_user_department_ids(db_path, user_id)
|
||||||
except Exception: # noqa: BLE001 — never block a request on DB errors
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to fetch department ids for user %s — falling back to empty list",
|
"Failed to fetch department ids for user %s — fail-closed (503)",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
department_ids = []
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail={
|
||||||
|
"error": "department_lookup_failed",
|
||||||
|
"detail": "Cannot verify department membership — refusing request",
|
||||||
|
},
|
||||||
|
) from None
|
||||||
|
|
||||||
return DepartmentContext(
|
return DepartmentContext(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||||
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
||||||
from agentkit.llm.gateway import QuotaExceededError
|
from agentkit.llm.gateway import QuotaExceededError
|
||||||
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall
|
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall
|
||||||
|
from agentkit.llm.providers.usage_store import UsageStoreUnavailableError
|
||||||
from agentkit.server.admin.context import get_department_context
|
from agentkit.server.admin.context import get_department_context
|
||||||
|
|
||||||
router = APIRouter(prefix="/llm", tags=["llm-gateway"])
|
router = APIRouter(prefix="/llm", tags=["llm-gateway"])
|
||||||
|
|
@ -109,6 +110,11 @@ async def chat(
|
||||||
)
|
)
|
||||||
except QuotaExceededError as e:
|
except QuotaExceededError as e:
|
||||||
raise HTTPException(status_code=429, detail=_quota_error_payload(e)) from e
|
raise HTTPException(status_code=429, detail=_quota_error_payload(e)) from e
|
||||||
|
except UsageStoreUnavailableError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail={"error": "usage_store_unavailable", "detail": str(e)},
|
||||||
|
) from e
|
||||||
except ModelNotFoundError as e:
|
except ModelNotFoundError as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
raise HTTPException(status_code=404, detail=str(e)) from e
|
||||||
except LLMProviderError as e:
|
except LLMProviderError as e:
|
||||||
|
|
@ -150,6 +156,12 @@ async def chat_stream(
|
||||||
error_payload = _quota_error_payload(e)
|
error_payload = _quota_error_payload(e)
|
||||||
error_payload["error"] = "quota_exceeded"
|
error_payload["error"] = "quota_exceeded"
|
||||||
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
|
||||||
|
except UsageStoreUnavailableError as e:
|
||||||
|
error_payload = {
|
||||||
|
"error": "usage_store_unavailable",
|
||||||
|
"detail": str(e),
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
|
||||||
except ModelNotFoundError as e:
|
except ModelNotFoundError as e:
|
||||||
error_payload = {"error": "model_not_found", "detail": str(e)}
|
error_payload = {"error": "model_not_found", "detail": str(e)}
|
||||||
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@
|
||||||
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from agentkit.llm.protocol import TokenUsage
|
from agentkit.llm.protocol import TokenUsage
|
||||||
from agentkit.llm.providers.usage_store import (
|
from agentkit.llm.providers.usage_store import (
|
||||||
InMemoryUsageStore,
|
InMemoryUsageStore,
|
||||||
|
|
@ -10,6 +12,7 @@ from agentkit.llm.providers.usage_store import (
|
||||||
UsageBucket,
|
UsageBucket,
|
||||||
UsageSummary,
|
UsageSummary,
|
||||||
UsageStore,
|
UsageStore,
|
||||||
|
UsageStoreUnavailableError,
|
||||||
create_usage_store,
|
create_usage_store,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -315,20 +318,40 @@ class TestRedisUsageStoreMocked:
|
||||||
summary = store._fallback.get_usage()
|
summary = store._fallback.get_usage()
|
||||||
assert len(summary.records) == 1
|
assert len(summary.records) == 1
|
||||||
|
|
||||||
def test_get_usage_degraded_uses_fallback(self):
|
def test_get_usage_degraded_raises_unavailable(self):
|
||||||
|
"""Fail-closed (KTD-1): degraded get_usage raises UsageStoreUnavailableError."""
|
||||||
|
store = self._make_store()
|
||||||
|
store._degrade_to_fallback()
|
||||||
|
with pytest.raises(UsageStoreUnavailableError):
|
||||||
|
store.get_usage()
|
||||||
|
|
||||||
|
def test_get_usage_degraded_raises_even_with_fallback_data(self):
|
||||||
|
"""Even if fallback has data, get_usage must fail-closed (no quota bypass)."""
|
||||||
store = self._make_store()
|
store = self._make_store()
|
||||||
store._degrade_to_fallback()
|
store._degrade_to_fallback()
|
||||||
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
|
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
|
||||||
store._fallback.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200)
|
store._fallback.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200)
|
||||||
summary = store.get_usage()
|
with pytest.raises(UsageStoreUnavailableError):
|
||||||
assert len(summary.records) == 1
|
store.get_usage()
|
||||||
|
|
||||||
def test_get_usage_degraded_no_fallback_returns_empty(self):
|
def test_get_usage_degraded_with_user_filter_raises(self):
|
||||||
|
"""Fail-closed applies even when filtering by user_id."""
|
||||||
store = self._make_store()
|
store = self._make_store()
|
||||||
store._degraded = True
|
store._degrade_to_fallback()
|
||||||
# No fallback set — should return empty
|
with pytest.raises(UsageStoreUnavailableError):
|
||||||
summary = store.get_usage()
|
store.get_usage(user_id="u1")
|
||||||
assert summary.total_tokens == 0
|
|
||||||
|
def test_get_usage_by_user_degraded_raises(self):
|
||||||
|
store = self._make_store()
|
||||||
|
store._degrade_to_fallback()
|
||||||
|
with pytest.raises(UsageStoreUnavailableError):
|
||||||
|
store.get_usage_by_user("u1")
|
||||||
|
|
||||||
|
def test_get_usage_by_department_degraded_raises(self):
|
||||||
|
store = self._make_store()
|
||||||
|
store._degrade_to_fallback()
|
||||||
|
with pytest.raises(UsageStoreUnavailableError):
|
||||||
|
store.get_usage_by_department("d1")
|
||||||
|
|
||||||
def test_today_key_format(self):
|
def test_today_key_format(self):
|
||||||
store = self._make_store()
|
store = self._make_store()
|
||||||
|
|
@ -369,65 +392,20 @@ class TestRedisUsageStoreMocked:
|
||||||
assert summary.records[0].user_id == "u1"
|
assert summary.records[0].user_id == "u1"
|
||||||
assert summary.records[0].department_id == "d1"
|
assert summary.records[0].department_id == "d1"
|
||||||
|
|
||||||
def test_get_usage_degraded_with_user_filter(self):
|
def test_record_async_degraded_uses_fallback(self):
|
||||||
store = self._make_store()
|
"""Async record in degraded state uses fallback (recording is allowed)."""
|
||||||
store._degrade_to_fallback()
|
import asyncio
|
||||||
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
|
|
||||||
store._fallback.record(
|
|
||||||
"agent1",
|
|
||||||
"gpt-4",
|
|
||||||
usage,
|
|
||||||
cost=0.05,
|
|
||||||
latency_ms=200,
|
|
||||||
user_id="u1",
|
|
||||||
department_id="d1",
|
|
||||||
)
|
|
||||||
store._fallback.record(
|
|
||||||
"agent1",
|
|
||||||
"gpt-4",
|
|
||||||
usage,
|
|
||||||
cost=0.05,
|
|
||||||
latency_ms=200,
|
|
||||||
user_id="u2",
|
|
||||||
department_id="d2",
|
|
||||||
)
|
|
||||||
summary = store.get_usage(user_id="u1")
|
|
||||||
assert len(summary.records) == 1
|
|
||||||
assert summary.records[0].user_id == "u1"
|
|
||||||
|
|
||||||
def test_get_usage_by_user_degraded(self):
|
|
||||||
store = self._make_store()
|
store = self._make_store()
|
||||||
store._degrade_to_fallback()
|
store._degrade_to_fallback()
|
||||||
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
|
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
|
||||||
store._fallback.record(
|
|
||||||
"agent1",
|
|
||||||
"gpt-4",
|
|
||||||
usage,
|
|
||||||
cost=0.05,
|
|
||||||
latency_ms=200,
|
|
||||||
user_id="u1",
|
|
||||||
department_id="d1",
|
|
||||||
)
|
|
||||||
summary = store.get_usage_by_user("u1")
|
|
||||||
assert len(summary.records) == 1
|
|
||||||
assert summary.records[0].user_id == "u1"
|
|
||||||
|
|
||||||
def test_get_usage_by_department_degraded(self):
|
async def _run():
|
||||||
store = self._make_store()
|
await store.record_async("agent1", "gpt-4", usage, cost=0.05, latency_ms=200)
|
||||||
store._degrade_to_fallback()
|
|
||||||
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
|
asyncio.run(_run())
|
||||||
store._fallback.record(
|
summary = store._fallback.get_usage()
|
||||||
"agent1",
|
|
||||||
"gpt-4",
|
|
||||||
usage,
|
|
||||||
cost=0.05,
|
|
||||||
latency_ms=200,
|
|
||||||
user_id="u1",
|
|
||||||
department_id="d1",
|
|
||||||
)
|
|
||||||
summary = store.get_usage_by_department("d1")
|
|
||||||
assert len(summary.records) == 1
|
assert len(summary.records) == 1
|
||||||
assert summary.records[0].department_id == "d1"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -448,3 +426,347 @@ class TestCreateUsageStore:
|
||||||
store = create_usage_store(backend="redis")
|
store = create_usage_store(backend="redis")
|
||||||
# May be InMemory if redis package unavailable
|
# May be InMemory if redis package unavailable
|
||||||
assert isinstance(store, (InMemoryUsageStore, RedisUsageStore))
|
assert isinstance(store, (InMemoryUsageStore, RedisUsageStore))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# U1: Key construction fix — SCAN patterns for partial scope queries
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisUsageStoreKeyConstruction:
|
||||||
|
"""U1: Verify get_usage constructs correct SCAN patterns when only
|
||||||
|
user_id OR only department_id is provided (not both).
|
||||||
|
|
||||||
|
Previously, ``get_usage(department_id=X)`` constructed key
|
||||||
|
``...:none:X`` which only matched records with no user — missing
|
||||||
|
all records from actual users in that department. The fix uses
|
||||||
|
SCAN with pattern ``...:*:X`` to aggregate across all users.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _make_store_with_mock_redis(self, mock_redis):
|
||||||
|
store = RedisUsageStore(redis_url="redis://localhost:6379")
|
||||||
|
store._sync_redis = mock_redis
|
||||||
|
return store
|
||||||
|
|
||||||
|
def test_get_usage_by_department_scans_all_users(self):
|
||||||
|
"""get_usage(department_id=X) should SCAN ...:*:X (all users)."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
# Single-day range so scan_iter is called once per scope.
|
||||||
|
keys_for_d1 = [
|
||||||
|
"agentkit:usage_records:v2:2026-06-21:u1:d1",
|
||||||
|
"agentkit:usage_records:v2:2026-06-21:u2:d1",
|
||||||
|
"agentkit:usage_records:v2:2026-06-21:none:d1",
|
||||||
|
]
|
||||||
|
mock_redis.scan_iter.return_value = iter(keys_for_d1)
|
||||||
|
# lrange returns records for each v2 key + 1 legacy v1 key.
|
||||||
|
record_u1 = (
|
||||||
|
'{"agent_name":"a","model":"m","prompt_tokens":100,'
|
||||||
|
'"completion_tokens":50,"total_tokens":150,"cost":0.05,'
|
||||||
|
'"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
|
||||||
|
'"user_id":"u1","department_id":"d1"}'
|
||||||
|
)
|
||||||
|
record_u2 = (
|
||||||
|
'{"agent_name":"a","model":"m","prompt_tokens":200,'
|
||||||
|
'"completion_tokens":100,"total_tokens":300,"cost":0.10,'
|
||||||
|
'"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
|
||||||
|
'"user_id":"u2","department_id":"d1"}'
|
||||||
|
)
|
||||||
|
record_none = (
|
||||||
|
'{"agent_name":"a","model":"m","prompt_tokens":50,'
|
||||||
|
'"completion_tokens":25,"total_tokens":75,"cost":0.02,'
|
||||||
|
'"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
|
||||||
|
'"user_id":null,"department_id":"d1"}'
|
||||||
|
)
|
||||||
|
mock_redis.lrange.side_effect = [
|
||||||
|
[record_u1], # for key u1:d1
|
||||||
|
[record_u2], # for key u2:d1
|
||||||
|
[record_none], # for key none:d1
|
||||||
|
[], # legacy v1 key
|
||||||
|
]
|
||||||
|
|
||||||
|
store = self._make_store_with_mock_redis(mock_redis)
|
||||||
|
day = datetime(2026, 6, 21, tzinfo=timezone.utc)
|
||||||
|
# Use end_time within same day (records are at 10:00 UTC).
|
||||||
|
end = day + timedelta(hours=12)
|
||||||
|
summary = store.get_usage(department_id="d1", start_time=day, end_time=end)
|
||||||
|
|
||||||
|
# Should aggregate records from all users in d1.
|
||||||
|
assert summary.total_tokens == 525 # 150 + 300 + 75
|
||||||
|
# Verify SCAN was called with the correct pattern.
|
||||||
|
scan_call = mock_redis.scan_iter.call_args
|
||||||
|
pattern = scan_call.kwargs.get("match") or scan_call.args[0]
|
||||||
|
assert "*:d1" in pattern
|
||||||
|
|
||||||
|
def test_get_usage_by_user_scans_all_departments(self):
|
||||||
|
"""get_usage(user_id=X) should SCAN ...:X:* (all departments)."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
mock_redis.scan_iter.return_value = iter(
|
||||||
|
[
|
||||||
|
"agentkit:usage_records:v2:2026-06-21:u1:d1",
|
||||||
|
"agentkit:usage_records:v2:2026-06-21:u1:d2",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
record_d1 = (
|
||||||
|
'{"agent_name":"a","model":"m","prompt_tokens":100,'
|
||||||
|
'"completion_tokens":50,"total_tokens":150,"cost":0.05,'
|
||||||
|
'"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
|
||||||
|
'"user_id":"u1","department_id":"d1"}'
|
||||||
|
)
|
||||||
|
record_d2 = (
|
||||||
|
'{"agent_name":"a","model":"m","prompt_tokens":200,'
|
||||||
|
'"completion_tokens":100,"total_tokens":300,"cost":0.10,'
|
||||||
|
'"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
|
||||||
|
'"user_id":"u1","department_id":"d2"}'
|
||||||
|
)
|
||||||
|
mock_redis.lrange.side_effect = [
|
||||||
|
[record_d1],
|
||||||
|
[record_d2],
|
||||||
|
[], # legacy v1 key
|
||||||
|
]
|
||||||
|
|
||||||
|
store = self._make_store_with_mock_redis(mock_redis)
|
||||||
|
day = datetime(2026, 6, 21, tzinfo=timezone.utc)
|
||||||
|
end = day + timedelta(hours=12)
|
||||||
|
summary = store.get_usage(user_id="u1", start_time=day, end_time=end)
|
||||||
|
|
||||||
|
assert summary.total_tokens == 450 # 150 + 300
|
||||||
|
scan_call = mock_redis.scan_iter.call_args
|
||||||
|
pattern = scan_call.kwargs.get("match") or scan_call.args[0]
|
||||||
|
assert "u1:*" in pattern
|
||||||
|
|
||||||
|
def test_get_usage_both_user_and_dept_uses_direct_key(self):
|
||||||
|
"""get_usage(user_id=X, department_id=Y) → direct key lookup (no SCAN)."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
record = (
|
||||||
|
'{"agent_name":"a","model":"m","prompt_tokens":100,'
|
||||||
|
'"completion_tokens":50,"total_tokens":150,"cost":0.05,'
|
||||||
|
'"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
|
||||||
|
'"user_id":"u1","department_id":"d1"}'
|
||||||
|
)
|
||||||
|
mock_redis.lrange.side_effect = [
|
||||||
|
[record], # direct v2 key
|
||||||
|
[], # legacy v1 key
|
||||||
|
]
|
||||||
|
|
||||||
|
store = self._make_store_with_mock_redis(mock_redis)
|
||||||
|
day = datetime(2026, 6, 21, tzinfo=timezone.utc)
|
||||||
|
end = day + timedelta(hours=12)
|
||||||
|
summary = store.get_usage(
|
||||||
|
user_id="u1",
|
||||||
|
department_id="d1",
|
||||||
|
start_time=day,
|
||||||
|
end_time=end,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert summary.total_tokens == 150
|
||||||
|
# SCAN should NOT be called (direct key lookup instead).
|
||||||
|
mock_redis.scan_iter.assert_not_called()
|
||||||
|
|
||||||
|
def test_get_usage_no_filter_scans_all(self):
|
||||||
|
"""get_usage() with no filters → SCAN ...:* (all records)."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
mock_redis.scan_iter.return_value = iter(["agentkit:usage_records:v2:2026-06-21:u1:d1"])
|
||||||
|
record = (
|
||||||
|
'{"agent_name":"a","model":"m","prompt_tokens":100,'
|
||||||
|
'"completion_tokens":50,"total_tokens":150,"cost":0.05,'
|
||||||
|
'"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
|
||||||
|
'"user_id":"u1","department_id":"d1"}'
|
||||||
|
)
|
||||||
|
mock_redis.lrange.side_effect = [
|
||||||
|
[record],
|
||||||
|
[], # legacy v1 key
|
||||||
|
]
|
||||||
|
|
||||||
|
store = self._make_store_with_mock_redis(mock_redis)
|
||||||
|
day = datetime(2026, 6, 21, tzinfo=timezone.utc)
|
||||||
|
end = day + timedelta(hours=12)
|
||||||
|
summary = store.get_usage(start_time=day, end_time=end)
|
||||||
|
|
||||||
|
assert summary.total_tokens == 150
|
||||||
|
mock_redis.scan_iter.assert_called_once()
|
||||||
|
|
||||||
|
def test_get_usage_empty_redis_returns_empty_summary(self):
|
||||||
|
"""No records in Redis → empty UsageSummary (not an error)."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
mock_redis.scan_iter.return_value = iter([])
|
||||||
|
mock_redis.lrange.return_value = []
|
||||||
|
|
||||||
|
store = self._make_store_with_mock_redis(mock_redis)
|
||||||
|
day = datetime(2026, 6, 21, tzinfo=timezone.utc)
|
||||||
|
end = day + timedelta(hours=12)
|
||||||
|
summary = store.get_usage(department_id="d1", start_time=day, end_time=end)
|
||||||
|
|
||||||
|
assert summary.total_tokens == 0
|
||||||
|
assert len(summary.records) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# U1: Degradation recovery — health check clears degraded state
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisUsageStoreDegradationRecovery:
|
||||||
|
"""U1 (KTD-5): Redis recovery clears degraded state via health check."""
|
||||||
|
|
||||||
|
def test_health_check_clears_degraded_on_recovery(self):
|
||||||
|
"""When Redis recovers, _degraded is cleared and _fallback discarded."""
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
store = RedisUsageStore(redis_url="redis://localhost:6379")
|
||||||
|
store._degrade_to_fallback()
|
||||||
|
assert store._degraded
|
||||||
|
assert store._fallback is not None
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
# Mock the async Redis client's ping to succeed.
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ping = AsyncMock(return_value=True)
|
||||||
|
with patch.object(store, "_get_redis", return_value=mock_redis):
|
||||||
|
# Manually trigger one iteration of the health check loop.
|
||||||
|
# We set a very short interval and run one cycle.
|
||||||
|
store.HEALTH_CHECK_INTERVAL = 0.01
|
||||||
|
task = asyncio.create_task(store._health_check_loop())
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
assert not store._degraded
|
||||||
|
assert store._fallback is None
|
||||||
|
|
||||||
|
def test_health_check_keeps_degraded_on_failure(self):
|
||||||
|
"""When Redis is still down, degraded state persists."""
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
store = RedisUsageStore(redis_url="redis://localhost:6379")
|
||||||
|
store._degrade_to_fallback()
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.ping = AsyncMock(side_effect=ConnectionError("still down"))
|
||||||
|
with patch.object(store, "_get_redis", return_value=mock_redis):
|
||||||
|
store.HEALTH_CHECK_INTERVAL = 0.01
|
||||||
|
task = asyncio.create_task(store._health_check_loop())
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
assert store._degraded # Still degraded
|
||||||
|
|
||||||
|
def test_aclose_cancels_health_check_task(self):
|
||||||
|
"""aclose() cancels the health check task."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
store = RedisUsageStore(redis_url="redis://localhost:6379")
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
store._degrade_to_fallback()
|
||||||
|
assert store._health_check_task is not None
|
||||||
|
await store.aclose()
|
||||||
|
assert store._health_check_task is None
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# U1: Fail-closed — get_usage raises when Redis query fails
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisUsageStoreFailClosed:
|
||||||
|
"""U1 (KTD-1): Redis query failure → UsageStoreUnavailableError (not empty)."""
|
||||||
|
|
||||||
|
def test_get_usage_raises_on_redis_connection_failure(self):
|
||||||
|
"""When Redis connection fails during get_usage, raise (fail-closed)."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
mock_redis.scan_iter.side_effect = ConnectionError("Redis gone")
|
||||||
|
|
||||||
|
store = RedisUsageStore(redis_url="redis://localhost:6379")
|
||||||
|
store._sync_redis = mock_redis
|
||||||
|
|
||||||
|
day = datetime(2026, 6, 21, tzinfo=timezone.utc)
|
||||||
|
with pytest.raises(UsageStoreUnavailableError):
|
||||||
|
store.get_usage(department_id="d1", start_time=day, end_time=day + timedelta(days=1))
|
||||||
|
|
||||||
|
# Should also degrade for future calls.
|
||||||
|
assert store._degraded
|
||||||
|
|
||||||
|
def test_get_usage_fail_closed_not_empty_summary(self):
|
||||||
|
"""Critical: must NOT return empty summary on Redis failure (that's fail-open)."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
mock_redis.scan_iter.side_effect = ConnectionError("Redis gone")
|
||||||
|
|
||||||
|
store = RedisUsageStore(redis_url="redis://localhost:6379")
|
||||||
|
store._sync_redis = mock_redis
|
||||||
|
|
||||||
|
day = datetime(2026, 6, 21, tzinfo=timezone.utc)
|
||||||
|
# Must raise, not return UsageSummary().
|
||||||
|
with pytest.raises(UsageStoreUnavailableError):
|
||||||
|
store.get_usage(start_time=day, end_time=day + timedelta(days=1))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# U1: Async record — record_async uses redis.asyncio
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisUsageStoreAsyncRecord:
|
||||||
|
"""U1 (KTD-6): record_async uses async Redis client, doesn't block event loop."""
|
||||||
|
|
||||||
|
def test_record_async_writes_to_redis(self):
|
||||||
|
"""record_async should write to Redis via async client."""
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
store = RedisUsageStore(redis_url="redis://localhost:6379")
|
||||||
|
mock_pipeline = MagicMock()
|
||||||
|
mock_pipeline.execute = AsyncMock(return_value=[True])
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
mock_redis.pipeline.return_value = mock_pipeline
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
with patch.object(store, "_get_redis", return_value=mock_redis):
|
||||||
|
usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
|
||||||
|
await store.record_async(
|
||||||
|
"agent1",
|
||||||
|
"gpt-4",
|
||||||
|
usage,
|
||||||
|
cost=0.05,
|
||||||
|
latency_ms=200,
|
||||||
|
user_id="u1",
|
||||||
|
department_id="d1",
|
||||||
|
)
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
# Verify pipeline was used.
|
||||||
|
mock_redis.pipeline.assert_called_once()
|
||||||
|
mock_pipeline.hincrbyfloat.assert_called_once()
|
||||||
|
mock_pipeline.rpush.assert_called_once()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue