diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index 8d93d00..fdf6648 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -168,7 +168,7 @@ class LLMGateway: result = await self._cache.get(cache_key) if result.hit: latency_ms = (time.monotonic() - start) * 1000 - self._record_usage( + await self._record_usage( agent_name=agent_name, model=result.response.model, usage=result.response.usage, @@ -197,7 +197,7 @@ class LLMGateway: result = await self._cache.semantic_search(query_embedding) if result.hit: latency_ms = (time.monotonic() - start) * 1000 - self._record_usage( + await self._record_usage( agent_name=agent_name, model=result.response.model, usage=result.response.usage, @@ -245,7 +245,7 @@ class LLMGateway: if response.usage: latency_ms = (time.monotonic() - start) * 1000 cost = self._calculate_cost(model_name, response.usage) - self._record_usage( + await self._record_usage( agent_name=agent_name, model=model_name, usage=response.usage, @@ -286,7 +286,7 @@ class LLMGateway: cost = self._calculate_cost(response.model, response.usage) # 记录使用量 - self._record_usage( + await self._record_usage( agent_name=agent_name, model=response.model, usage=response.usage, @@ -406,7 +406,7 @@ class LLMGateway: if final_usage is None: final_usage = TokenUsage() cost = self._calculate_cost(final_model, final_usage) - self._record_usage( + await self._record_usage( agent_name=agent_name, model=final_model, usage=final_usage, @@ -512,7 +512,7 @@ class LLMGateway: # Quota enforcement helpers (U7) # ------------------------------------------------------------------ - def _record_usage( + async def _record_usage( self, agent_name: str, model: str, @@ -522,16 +522,15 @@ class LLMGateway: user_id: str | None, department_ids: list[str] | None, ) -> None: - """Record a usage event, attaching user_id and (first) department_id. + """Record a usage event via the async store interface (KTD-6). 
- We attach only the first department_id to the record because - usage attribution is per-department. If a user belongs to - multiple departments, the caller is responsible for choosing - which department to bill — the gateway just records what it's - told. + Attaches ``user_id`` and the first ``department_id`` to the + record. Multi-department attribution is handled by the caller + (see U2 — when a user belongs to multiple departments, each + department gets its own record). """ dept_id = department_ids[0] if department_ids else None - self._usage_tracker.record( + await self._usage_tracker.record_async( agent_name=agent_name, model=model, usage=usage, @@ -551,6 +550,10 @@ class LLMGateway: Strictest-wins: if ANY department fails ANY check, raises :class:`QuotaExceededError` and the request is rejected. + + Fail-closed (KTD-1): if the usage store is unavailable (Redis + degraded), raises :class:`UsageStoreUnavailableError`. The + caller must translate this to HTTP 503. """ # Lazy import to avoid circular dependency (admin → ... → gateway). 
from agentkit.server.admin.quota_service import get_quota_service diff --git a/src/agentkit/llm/providers/tracker.py b/src/agentkit/llm/providers/tracker.py index a962a87..7ea4eff 100644 --- a/src/agentkit/llm/providers/tracker.py +++ b/src/agentkit/llm/providers/tracker.py @@ -26,7 +26,7 @@ class UsageTracker: user_id: str | None = None, department_id: str | None = None, ) -> None: - """记录一次使用""" + """记录一次使用(sync — 在 async 上下文中可能阻塞事件循环)""" self._store.record( agent_name, model, @@ -37,6 +37,44 @@ class UsageTracker: department_id=department_id, ) + async def record_async( + self, + agent_name: str, + model: str, + usage: TokenUsage, + cost: float, + latency_ms: float, + user_id: str | None = None, + department_id: str | None = None, + ) -> None: + """记录一次使用(async — 不阻塞事件循环,KTD-6)。 + + 如果底层 store 提供了 ``record_async`` 方法则调用它;否则 + 回退到同步 ``record``(适用于 InMemoryUsageStore 等无 I/O 的 store)。 + """ + record_async = getattr(self._store, "record_async", None) + if record_async is not None: + await record_async( + agent_name, + model, + usage, + cost, + latency_ms, + user_id=user_id, + department_id=department_id, + ) + else: + # Fallback: sync store without async support. 
+ self._store.record( + agent_name, + model, + usage, + cost, + latency_ms, + user_id=user_id, + department_id=department_id, + ) + def get_usage( self, agent_name: str | None = None, diff --git a/src/agentkit/llm/providers/usage_store.py b/src/agentkit/llm/providers/usage_store.py index 67ee4b2..f7cf77d 100644 --- a/src/agentkit/llm/providers/usage_store.py +++ b/src/agentkit/llm/providers/usage_store.py @@ -13,6 +13,7 @@ Legacy v1 keys (still readable for backward compat): agentkit:usage_records:{date} → List """ +import asyncio import json import logging from dataclasses import dataclass, field @@ -24,6 +25,14 @@ from agentkit.llm.protocol import TokenUsage logger = logging.getLogger(__name__) +class UsageStoreUnavailableError(Exception): + """Raised when the usage store is unavailable for quota-critical queries. + + The gateway treats this as a fail-closed condition (HTTP 503) — + refusing the request rather than allowing untracked usage. + """ + + @dataclass class UsageRecord: """使用量记录""" @@ -86,7 +95,7 @@ class UsageStore(Protocol): user_id: str | None = None, department_id: str | None = None, ) -> None: - """Record a usage event.""" + """Record a usage event (sync — may block in async contexts).""" ... def get_usage( @@ -97,7 +106,12 @@ class UsageStore(Protocol): user_id: str | None = None, department_id: str | None = None, ) -> UsageSummary: - """Query usage summary.""" + """Query usage summary. + + Raises: + UsageStoreUnavailableError: When the store is degraded and + cannot answer quota-critical queries (fail-closed). + """ ... 
@@ -139,6 +153,27 @@ class InMemoryUsageStore: if len(self._records) > self.MAX_RECORDS: self._records = self._records[-self.MAX_RECORDS :] + async def record_async( + self, + agent_name: str, + model: str, + usage: TokenUsage, + cost: float, + latency_ms: float, + user_id: str | None = None, + department_id: str | None = None, + ) -> None: + """Async wrapper — InMemory store has no I/O, so just delegates to sync.""" + self.record( + agent_name, + model, + usage, + cost, + latency_ms, + user_id=user_id, + department_id=department_id, + ) + def get_usage( self, agent_name: str | None = None, @@ -234,6 +269,19 @@ class RedisUsageStore: Legacy v1 keys (still readable for backward compat): agentkit:usage:{YYYY-MM-DD} → Hash agentkit:usage_records:{YYYY-MM-DD} → List + + Fail-closed semantics: + When Redis is unreachable and the store degrades to the in-memory + fallback, ``get_usage`` raises :class:`UsageStoreUnavailableError`. + The gateway translates this to HTTP 503, refusing the request + rather than allowing untracked usage (KTD-1). The fallback store + is used only for ``record`` (best-effort persistence) — never for + quota queries. + + Degradation recovery: + A background health-check task pings Redis every 30 seconds. On + success, the degraded flag is cleared and the fallback store is + discarded (KTD-5). 
""" USAGE_PREFIX = "agentkit:usage:" @@ -242,6 +290,7 @@ class RedisUsageStore: RECORDS_PREFIX_V2 = "agentkit:usage_records:v2:" MAX_RECORDS_PER_DAY = 50000 TTL_DAYS = 90 # Auto-expire after 90 days + HEALTH_CHECK_INTERVAL = 30.0 # seconds def __init__(self, redis_url: str = "redis://localhost:6379"): self._redis_url = redis_url @@ -249,6 +298,7 @@ class RedisUsageStore: self._sync_redis: Any = None self._fallback: InMemoryUsageStore | None = None self._degraded = False + self._health_check_task: asyncio.Task[None] | None = None async def _get_redis(self): if self._redis is None: @@ -266,6 +316,14 @@ class RedisUsageStore: return self._sync_redis async def aclose(self) -> None: + # Cancel the health-check task first to avoid races. + if self._health_check_task is not None: + self._health_check_task.cancel() + try: + await self._health_check_task + except (asyncio.CancelledError, Exception): + pass + self._health_check_task = None if self._redis is not None: await self._redis.aclose() self._redis = None @@ -279,6 +337,45 @@ class RedisUsageStore: if self._fallback is None: self._fallback = InMemoryUsageStore() logger.warning("Redis usage store unreachable, degraded to in-memory") + self._start_health_check() + + def _start_health_check(self) -> None: + """Start the background health-check task (idempotent). + + The task pings Redis every ``HEALTH_CHECK_INTERVAL`` seconds. On + success, the degraded flag is cleared and the fallback store is + discarded (KTD-5). + """ + if self._health_check_task is not None and not self._health_check_task.done(): + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # No running event loop (e.g. sync-only context) — skip. 
+ logger.debug("No event loop available for health-check task") + return + self._health_check_task = loop.create_task(self._health_check_loop()) + + async def _health_check_loop(self) -> None: + """Periodically ping Redis; clear degraded state on success.""" + while self._degraded: + try: + await asyncio.sleep(self.HEALTH_CHECK_INTERVAL) + except asyncio.CancelledError: + return + try: + redis_client = await self._get_redis() + await redis_client.ping() + # Redis is back — clear degraded state. + self._degraded = False + self._fallback = None + logger.info("Redis usage store recovered, cleared degraded state") + return + except asyncio.CancelledError: + return + except Exception: + logger.debug("Redis health check still failing") + # Keep looping until Redis recovers. def _today_key(self) -> str: return datetime.now(timezone.utc).strftime("%Y-%m-%d") @@ -299,6 +396,47 @@ class RedisUsageStore: f"{self.RECORDS_PREFIX_V2}{date_key}:{u}:{d}", ) + def _build_record_json( + self, + agent_name: str, + model: str, + usage: TokenUsage, + cost: float, + latency_ms: float, + user_id: str | None, + department_id: str | None, + ) -> tuple[str, str, str, str]: + """Build (hash_key, list_key, bucket_field, record_json) for a write.""" + date_key = self._today_key() + hash_key, list_key = self._v2_keys(date_key, user_id, department_id) + bucket_field = f"{agent_name}:{model}" + rec = UsageRecord( + agent_name=agent_name, + model=model, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + total_tokens=usage.total_tokens, + cost=cost, + latency_ms=latency_ms, + user_id=user_id, + department_id=department_id, + ) + record_json = json.dumps( + { + "agent_name": rec.agent_name, + "model": rec.model, + "prompt_tokens": rec.prompt_tokens, + "completion_tokens": rec.completion_tokens, + "total_tokens": rec.total_tokens, + "cost": rec.cost, + "latency_ms": rec.latency_ms, + "timestamp": rec.timestamp, + "user_id": rec.user_id, + "department_id": 
rec.department_id, + } + ) + return hash_key, list_key, bucket_field, record_json + def record( self, agent_name: str, @@ -309,11 +447,12 @@ class RedisUsageStore: user_id: str | None = None, department_id: str | None = None, ) -> None: - """Record usage — sync wrapper for async Redis. + """Record usage — sync wrapper using sync Redis client. Note: This is a sync method because UsageTracker.record() is sync. For Redis, we use a sync Redis client for writes to avoid - needing an event loop in the caller. + needing an event loop in the caller. In async contexts prefer + :meth:`record_async` to avoid blocking the event loop (KTD-6). """ if self._degraded and self._fallback is not None: self._fallback.record( @@ -329,10 +468,9 @@ class RedisUsageStore: try: r = self._get_sync_redis() - - date_key = self._today_key() - hash_key, list_key = self._v2_keys(date_key, user_id, department_id) - bucket_field = f"{agent_name}:{model}" + hash_key, list_key, bucket_field, record_json = self._build_record_json( + agent_name, model, usage, cost, latency_ms, user_id, department_id + ) # Atomic HINCRBYFLOAT for bucket aggregation pipe = r.pipeline() @@ -343,34 +481,7 @@ class RedisUsageStore: pipe.hincrby(hash_key, f"{bucket_field}:count", 1) # Append record - rec = UsageRecord( - agent_name=agent_name, - model=model, - prompt_tokens=usage.prompt_tokens, - completion_tokens=usage.completion_tokens, - total_tokens=usage.total_tokens, - cost=cost, - latency_ms=latency_ms, - user_id=user_id, - department_id=department_id, - ) - pipe.rpush( - list_key, - json.dumps( - { - "agent_name": rec.agent_name, - "model": rec.model, - "prompt_tokens": rec.prompt_tokens, - "completion_tokens": rec.completion_tokens, - "total_tokens": rec.total_tokens, - "cost": rec.cost, - "latency_ms": rec.latency_ms, - "timestamp": rec.timestamp, - "user_id": rec.user_id, - "department_id": rec.department_id, - } - ), - ) + pipe.rpush(list_key, record_json) pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1) # 
Set TTL on first write of the day @@ -392,6 +503,65 @@ class RedisUsageStore: department_id=department_id, ) + async def record_async( + self, + agent_name: str, + model: str, + usage: TokenUsage, + cost: float, + latency_ms: float, + user_id: str | None = None, + department_id: str | None = None, + ) -> None: + """Async record using ``redis.asyncio`` — does not block the event loop (KTD-6). + + Preferred over :meth:`record` in async contexts (gateway, routes). + Falls back to the in-memory store on Redis failure, same as the + sync version. + """ + if self._degraded and self._fallback is not None: + await self._fallback.record_async( + agent_name, + model, + usage, + cost, + latency_ms, + user_id=user_id, + department_id=department_id, + ) + return + + try: + r = await self._get_redis() + hash_key, list_key, bucket_field, record_json = self._build_record_json( + agent_name, model, usage, cost, latency_ms, user_id, department_id + ) + + pipe = r.pipeline() + pipe.hincrbyfloat(hash_key, f"{bucket_field}:cost", cost) + pipe.hincrby(hash_key, f"{bucket_field}:prompt_tokens", usage.prompt_tokens) + pipe.hincrby(hash_key, f"{bucket_field}:completion_tokens", usage.completion_tokens) + pipe.hincrby(hash_key, f"{bucket_field}:total_tokens", usage.total_tokens) + pipe.hincrby(hash_key, f"{bucket_field}:count", 1) + pipe.rpush(list_key, record_json) + pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1) + pipe.expire(hash_key, self.TTL_DAYS * 86400) + pipe.expire(list_key, self.TTL_DAYS * 86400) + await pipe.execute() + except Exception as e: + logger.warning(f"Redis async usage record failed: {e}") + self._degrade_to_fallback() + if self._fallback is not None: + await self._fallback.record_async( + agent_name, + model, + usage, + cost, + latency_ms, + user_id=user_id, + department_id=department_id, + ) + def get_usage( self, agent_name: str | None = None, @@ -405,14 +575,28 @@ class RedisUsageStore: Scans v2 keys (filtered by user_id/department_id when provided) and legacy 
v1 keys (no per-user/department scoping). Records from both schemas are merged. + + Key construction (U1 fix): + - Both user_id and department_id provided → direct key lookup + (exact match). + - Only user_id provided → SCAN ``...:{user_id}:*`` to aggregate + across all departments the user belongs to. + - Only department_id provided → SCAN ``...:*:{department_id}`` + to aggregate across all users in that department. + - Neither provided → SCAN ``...:*`` (all records for the date). + + Fail-closed (KTD-1): + When the store is degraded (Redis unreachable), raises + :class:`UsageStoreUnavailableError`. The fallback store is + used only for ``record`` — never for quota queries. """ - if self._degraded and self._fallback is not None: - return self._fallback.get_usage( - agent_name=agent_name, - start_time=start_time, - end_time=end_time, - user_id=user_id, - department_id=department_id, + # Fail-closed: when degraded, refuse quota-critical queries. + # The in-memory fallback is only for recording (best-effort), + # not for quota checks — returning an empty summary would + # make quota checks pass (fail-open), which is a security bug. + if self._degraded: + raise UsageStoreUnavailableError( + "Redis usage store is degraded — cannot answer quota-critical query" ) try: @@ -429,14 +613,29 @@ class RedisUsageStore: end_date = end.date() while current <= end_date: date_key = current.isoformat() - # When user_id/department_id is provided, scan only the - # matching scope key. Otherwise scan all scopes for that - # date via SCAN. - if user_id is not None or department_id is not None: - list_key = f"{self.RECORDS_PREFIX_V2}{date_key}:{self._scope_key(user_id)}:{self._scope_key(department_id)}" + # Determine the SCAN pattern based on which scope dims are provided. + # Key format: {RECORDS_PREFIX_V2}{date}:{user_or_none}:{dept_or_none} + if user_id is not None and department_id is not None: + # Both provided → direct key lookup (exact match). 
+ list_key = ( + f"{self.RECORDS_PREFIX_V2}{date_key}:" + f"{self._scope_key(user_id)}:{self._scope_key(department_id)}" + ) all_records.extend(self._read_list(r, list_key, start, end, agent_name)) + elif user_id is not None: + # Only user_id → aggregate across all departments. + pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:{self._scope_key(user_id)}:*" + for key in r.scan_iter(match=pattern, count=200): + all_records.extend(self._read_list(r, key, start, end, agent_name)) + elif department_id is not None: + # Only department_id → aggregate across all users. + pattern = ( + f"{self.RECORDS_PREFIX_V2}{date_key}:*:{self._scope_key(department_id)}" + ) + for key in r.scan_iter(match=pattern, count=200): + all_records.extend(self._read_list(r, key, start, end, agent_name)) else: - # Scan all v2 list keys for this date. + # Neither provided → scan all v2 list keys for this date. pattern = f"{self.RECORDS_PREFIX_V2}{date_key}:*" for key in r.scan_iter(match=pattern, count=200): all_records.extend(self._read_list(r, key, start, end, agent_name)) @@ -460,17 +659,13 @@ class RedisUsageStore: return UsageSummary() return InMemoryUsageStore._aggregate(all_records) + except UsageStoreUnavailableError: + raise except Exception as e: logger.warning(f"Redis usage query failed: {e}") - if self._fallback is not None: - return self._fallback.get_usage( - agent_name=agent_name, - start_time=start_time, - end_time=end_time, - user_id=user_id, - department_id=department_id, - ) - return UsageSummary() + # Degrade for future calls, but fail-closed for this one. 
+ self._degrade_to_fallback() + raise UsageStoreUnavailableError(f"Redis usage query failed: {e}") from e def get_usage_by_user( self, diff --git a/src/agentkit/server/admin/context.py b/src/agentkit/server/admin/context.py index ee97f94..eee9890 100644 --- a/src/agentkit/server/admin/context.py +++ b/src/agentkit/server/admin/context.py @@ -30,7 +30,7 @@ from pathlib import Path from typing import Any import aiosqlite -from fastapi import Request +from fastapi import HTTPException, Request from agentkit.server.auth.models import DEFAULT_AUTH_DB_PATH @@ -134,6 +134,11 @@ async def get_department_context(request: Request) -> DepartmentContext: If ``request.state.current_user`` is missing entirely (e.g. the auth middleware was not installed), returns an empty context equivalent to the unauthenticated case. + + Fail-closed (KTD-1): if the DB lookup fails for a regular user, + raises ``HTTPException(503)``. Returning an empty list would make + quota enforcement skip the check (fail-open), which is a security + bug. Admins and API-key clients are unaffected (no DB lookup). """ current_user: dict[str, Any] | None = getattr(request.state, "current_user", None) if current_user is None: @@ -155,15 +160,23 @@ async def get_department_context(request: Request) -> DepartmentContext: return DepartmentContext(user_id=None, department_ids=[], is_admin=False) # Regular user: look up their active department ids. + # Fail-closed on DB errors — returning an empty list would bypass + # quota enforcement (fail-open), which is a security vulnerability. 
db_path = _resolve_db_path(request) try: department_ids = await _fetch_user_department_ids(db_path, user_id) - except Exception: # noqa: BLE001 — never block a request on DB errors + except Exception: logger.exception( - "Failed to fetch department ids for user %s — falling back to empty list", + "Failed to fetch department ids for user %s — fail-closed (503)", user_id, ) - department_ids = [] + raise HTTPException( + status_code=503, + detail={ + "error": "department_lookup_failed", + "detail": "Cannot verify department membership — refusing request", + }, + ) from None return DepartmentContext( user_id=user_id, diff --git a/src/agentkit/server/routes/llm_gateway.py b/src/agentkit/server/routes/llm_gateway.py index 1da87dc..702ae1e 100644 --- a/src/agentkit/server/routes/llm_gateway.py +++ b/src/agentkit/server/routes/llm_gateway.py @@ -14,6 +14,7 @@ from pydantic import BaseModel, ConfigDict, Field from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError from agentkit.llm.gateway import QuotaExceededError from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall +from agentkit.llm.providers.usage_store import UsageStoreUnavailableError from agentkit.server.admin.context import get_department_context router = APIRouter(prefix="/llm", tags=["llm-gateway"]) @@ -109,6 +110,11 @@ async def chat( ) except QuotaExceededError as e: raise HTTPException(status_code=429, detail=_quota_error_payload(e)) from e + except UsageStoreUnavailableError as e: + raise HTTPException( + status_code=503, + detail={"error": "usage_store_unavailable", "detail": str(e)}, + ) from e except ModelNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) from e except LLMProviderError as e: @@ -150,6 +156,12 @@ async def chat_stream( error_payload = _quota_error_payload(e) error_payload["error"] = "quota_exceeded" yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n" + except UsageStoreUnavailableError as e: + error_payload = { 
+ "error": "usage_store_unavailable", + "detail": str(e), + } + yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n" except ModelNotFoundError as e: error_payload = {"error": "model_not_found", "detail": str(e)} yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n" diff --git a/tests/unit/llm/test_usage_store.py b/tests/unit/llm/test_usage_store.py index 93e59ba..39597a1 100644 --- a/tests/unit/llm/test_usage_store.py +++ b/tests/unit/llm/test_usage_store.py @@ -2,6 +2,8 @@ from datetime import datetime, timedelta, timezone +import pytest + from agentkit.llm.protocol import TokenUsage from agentkit.llm.providers.usage_store import ( InMemoryUsageStore, @@ -10,6 +12,7 @@ from agentkit.llm.providers.usage_store import ( UsageBucket, UsageSummary, UsageStore, + UsageStoreUnavailableError, create_usage_store, ) @@ -315,20 +318,40 @@ class TestRedisUsageStoreMocked: summary = store._fallback.get_usage() assert len(summary.records) == 1 - def test_get_usage_degraded_uses_fallback(self): + def test_get_usage_degraded_raises_unavailable(self): + """Fail-closed (KTD-1): degraded get_usage raises UsageStoreUnavailableError.""" + store = self._make_store() + store._degrade_to_fallback() + with pytest.raises(UsageStoreUnavailableError): + store.get_usage() + + def test_get_usage_degraded_raises_even_with_fallback_data(self): + """Even if fallback has data, get_usage must fail-closed (no quota bypass).""" store = self._make_store() store._degrade_to_fallback() usage = TokenUsage(prompt_tokens=100, completion_tokens=50) store._fallback.record("agent1", "gpt-4", usage, cost=0.05, latency_ms=200) - summary = store.get_usage() - assert len(summary.records) == 1 + with pytest.raises(UsageStoreUnavailableError): + store.get_usage() - def test_get_usage_degraded_no_fallback_returns_empty(self): + def test_get_usage_degraded_with_user_filter_raises(self): + """Fail-closed applies even when filtering by user_id.""" store = self._make_store() - store._degraded 
= True - # No fallback set — should return empty - summary = store.get_usage() - assert summary.total_tokens == 0 + store._degrade_to_fallback() + with pytest.raises(UsageStoreUnavailableError): + store.get_usage(user_id="u1") + + def test_get_usage_by_user_degraded_raises(self): + store = self._make_store() + store._degrade_to_fallback() + with pytest.raises(UsageStoreUnavailableError): + store.get_usage_by_user("u1") + + def test_get_usage_by_department_degraded_raises(self): + store = self._make_store() + store._degrade_to_fallback() + with pytest.raises(UsageStoreUnavailableError): + store.get_usage_by_department("d1") def test_today_key_format(self): store = self._make_store() @@ -369,65 +392,20 @@ class TestRedisUsageStoreMocked: assert summary.records[0].user_id == "u1" assert summary.records[0].department_id == "d1" - def test_get_usage_degraded_with_user_filter(self): - store = self._make_store() - store._degrade_to_fallback() - usage = TokenUsage(prompt_tokens=100, completion_tokens=50) - store._fallback.record( - "agent1", - "gpt-4", - usage, - cost=0.05, - latency_ms=200, - user_id="u1", - department_id="d1", - ) - store._fallback.record( - "agent1", - "gpt-4", - usage, - cost=0.05, - latency_ms=200, - user_id="u2", - department_id="d2", - ) - summary = store.get_usage(user_id="u1") - assert len(summary.records) == 1 - assert summary.records[0].user_id == "u1" + def test_record_async_degraded_uses_fallback(self): + """Async record in degraded state uses fallback (recording is allowed).""" + import asyncio - def test_get_usage_by_user_degraded(self): store = self._make_store() store._degrade_to_fallback() usage = TokenUsage(prompt_tokens=100, completion_tokens=50) - store._fallback.record( - "agent1", - "gpt-4", - usage, - cost=0.05, - latency_ms=200, - user_id="u1", - department_id="d1", - ) - summary = store.get_usage_by_user("u1") - assert len(summary.records) == 1 - assert summary.records[0].user_id == "u1" - def 
test_get_usage_by_department_degraded(self): - store = self._make_store() - store._degrade_to_fallback() - usage = TokenUsage(prompt_tokens=100, completion_tokens=50) - store._fallback.record( - "agent1", - "gpt-4", - usage, - cost=0.05, - latency_ms=200, - user_id="u1", - department_id="d1", - ) - summary = store.get_usage_by_department("d1") + async def _run(): + await store.record_async("agent1", "gpt-4", usage, cost=0.05, latency_ms=200) + + asyncio.run(_run()) + summary = store._fallback.get_usage() assert len(summary.records) == 1 - assert summary.records[0].department_id == "d1" # --------------------------------------------------------------------------- @@ -448,3 +426,347 @@ class TestCreateUsageStore: store = create_usage_store(backend="redis") # May be InMemory if redis package unavailable assert isinstance(store, (InMemoryUsageStore, RedisUsageStore)) + + +# --------------------------------------------------------------------------- +# U1: Key construction fix — SCAN patterns for partial scope queries +# --------------------------------------------------------------------------- + + +class TestRedisUsageStoreKeyConstruction: + """U1: Verify get_usage constructs correct SCAN patterns when only + user_id OR only department_id is provided (not both). + + Previously, ``get_usage(department_id=X)`` constructed key + ``...:none:X`` which only matched records with no user — missing + all records from actual users in that department. The fix uses + SCAN with pattern ``...:*:X`` to aggregate across all users. + """ + + def _make_store_with_mock_redis(self, mock_redis): + store = RedisUsageStore(redis_url="redis://localhost:6379") + store._sync_redis = mock_redis + return store + + def test_get_usage_by_department_scans_all_users(self): + """get_usage(department_id=X) should SCAN ...:*:X (all users).""" + from unittest.mock import MagicMock + + mock_redis = MagicMock() + # Single-day range so scan_iter is called once per scope. 
+ keys_for_d1 = [ + "agentkit:usage_records:v2:2026-06-21:u1:d1", + "agentkit:usage_records:v2:2026-06-21:u2:d1", + "agentkit:usage_records:v2:2026-06-21:none:d1", + ] + mock_redis.scan_iter.return_value = iter(keys_for_d1) + # lrange returns records for each v2 key + 1 legacy v1 key. + record_u1 = ( + '{"agent_name":"a","model":"m","prompt_tokens":100,' + '"completion_tokens":50,"total_tokens":150,"cost":0.05,' + '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",' + '"user_id":"u1","department_id":"d1"}' + ) + record_u2 = ( + '{"agent_name":"a","model":"m","prompt_tokens":200,' + '"completion_tokens":100,"total_tokens":300,"cost":0.10,' + '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",' + '"user_id":"u2","department_id":"d1"}' + ) + record_none = ( + '{"agent_name":"a","model":"m","prompt_tokens":50,' + '"completion_tokens":25,"total_tokens":75,"cost":0.02,' + '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",' + '"user_id":null,"department_id":"d1"}' + ) + mock_redis.lrange.side_effect = [ + [record_u1], # for key u1:d1 + [record_u2], # for key u2:d1 + [record_none], # for key none:d1 + [], # legacy v1 key + ] + + store = self._make_store_with_mock_redis(mock_redis) + day = datetime(2026, 6, 21, tzinfo=timezone.utc) + # Use end_time within same day (records are at 10:00 UTC). + end = day + timedelta(hours=12) + summary = store.get_usage(department_id="d1", start_time=day, end_time=end) + + # Should aggregate records from all users in d1. + assert summary.total_tokens == 525 # 150 + 300 + 75 + # Verify SCAN was called with the correct pattern. 
+        scan_call = mock_redis.scan_iter.call_args
+        pattern = scan_call.kwargs.get("match") or scan_call.args[0]
+        assert "*:d1" in pattern
+
+    def test_get_usage_by_user_scans_all_departments(self):
+        """get_usage(user_id=X) should SCAN ...:X:* (all departments)."""
+        from unittest.mock import MagicMock
+
+        mock_redis = MagicMock()
+        # SCAN yields one v2 key per department the user has records in.
+        mock_redis.scan_iter.return_value = iter(
+            [
+                "agentkit:usage_records:v2:2026-06-21:u1:d1",
+                "agentkit:usage_records:v2:2026-06-21:u1:d2",
+            ]
+        )
+        record_d1 = (
+            '{"agent_name":"a","model":"m","prompt_tokens":100,'
+            '"completion_tokens":50,"total_tokens":150,"cost":0.05,'
+            '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
+            '"user_id":"u1","department_id":"d1"}'
+        )
+        record_d2 = (
+            '{"agent_name":"a","model":"m","prompt_tokens":200,'
+            '"completion_tokens":100,"total_tokens":300,"cost":0.10,'
+            '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
+            '"user_id":"u1","department_id":"d2"}'
+        )
+        # side_effect replies are handed out one per lrange() call, in order.
+        # NOTE(review): assumes the store reads the scanned v2 keys first and
+        # the legacy v1 key last — confirm against RedisUsageStore.get_usage.
+        mock_redis.lrange.side_effect = [
+            [record_d1],
+            [record_d2],
+            [],  # legacy v1 key
+        ]
+
+        store = self._make_store_with_mock_redis(mock_redis)
+        day = datetime(2026, 6, 21, tzinfo=timezone.utc)
+        end = day + timedelta(hours=12)
+        summary = store.get_usage(user_id="u1", start_time=day, end_time=end)
+
+        assert summary.total_tokens == 450  # 150 + 300
+        scan_call = mock_redis.scan_iter.call_args
+        # The match pattern may arrive as a keyword or as the first positional
+        # argument; accept either.
+        pattern = scan_call.kwargs.get("match") or scan_call.args[0]
+        assert "u1:*" in pattern
+
+    def test_get_usage_both_user_and_dept_uses_direct_key(self):
+        """get_usage(user_id=X, department_id=Y) → direct key lookup (no SCAN)."""
+        from unittest.mock import MagicMock
+
+        mock_redis = MagicMock()
+        record = (
+            '{"agent_name":"a","model":"m","prompt_tokens":100,'
+            '"completion_tokens":50,"total_tokens":150,"cost":0.05,'
+            '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
+            '"user_id":"u1","department_id":"d1"}'
+        )
+        # No scan_iter stub here: with both filters given the test expects the
+        # store to go straight to lrange().
+        mock_redis.lrange.side_effect = [
+            [record],  # direct v2 key
+            [],  # legacy v1 key
+        ]
+
+        store = self._make_store_with_mock_redis(mock_redis)
+        day = datetime(2026, 6, 21, tzinfo=timezone.utc)
+        end = day + timedelta(hours=12)
+        summary = store.get_usage(
+            user_id="u1",
+            department_id="d1",
+            start_time=day,
+            end_time=end,
+        )
+
+        assert summary.total_tokens == 150
+        # SCAN should NOT be called (direct key lookup instead).
+        mock_redis.scan_iter.assert_not_called()
+
+    def test_get_usage_no_filter_scans_all(self):
+        """get_usage() with no filters → SCAN ...:* (all records)."""
+        from unittest.mock import MagicMock
+
+        mock_redis = MagicMock()
+        mock_redis.scan_iter.return_value = iter(["agentkit:usage_records:v2:2026-06-21:u1:d1"])
+        record = (
+            '{"agent_name":"a","model":"m","prompt_tokens":100,'
+            '"completion_tokens":50,"total_tokens":150,"cost":0.05,'
+            '"latency_ms":200,"timestamp":"2026-06-21T10:00:00+00:00",'
+            '"user_id":"u1","department_id":"d1"}'
+        )
+        # One reply for the single scanned key, then one for the legacy key.
+        mock_redis.lrange.side_effect = [
+            [record],
+            [],  # legacy v1 key
+        ]
+
+        store = self._make_store_with_mock_redis(mock_redis)
+        day = datetime(2026, 6, 21, tzinfo=timezone.utc)
+        end = day + timedelta(hours=12)
+        summary = store.get_usage(start_time=day, end_time=end)
+
+        assert summary.total_tokens == 150
+        # Exactly one SCAN is expected for the whole query.
+        mock_redis.scan_iter.assert_called_once()
+
+    def test_get_usage_empty_redis_returns_empty_summary(self):
+        """No records in Redis → empty UsageSummary (not an error)."""
+        from unittest.mock import MagicMock
+
+        mock_redis = MagicMock()
+        # Both the SCAN and every LRANGE come back empty.
+        mock_redis.scan_iter.return_value = iter([])
+        mock_redis.lrange.return_value = []
+
+        store = self._make_store_with_mock_redis(mock_redis)
+        day = datetime(2026, 6, 21, tzinfo=timezone.utc)
+        end = day + timedelta(hours=12)
+        summary = store.get_usage(department_id="d1", start_time=day, end_time=end)
+
+        assert summary.total_tokens == 0
+        assert len(summary.records) == 0
+
+
+# ---------------------------------------------------------------------------
+# U1: Degradation recovery — health check clears degraded state
+# ---------------------------------------------------------------------------
+
+
+class TestRedisUsageStoreDegradationRecovery:
+    """U1 (KTD-5): Redis recovery clears degraded state via health check."""
+
+    def test_health_check_clears_degraded_on_recovery(self):
+        """When Redis recovers, _degraded is cleared and _fallback discarded."""
+        import asyncio
+        from unittest.mock import AsyncMock, patch
+
+        store = RedisUsageStore(redis_url="redis://localhost:6379")
+        # Force the degraded state up front and check the precondition.
+        store._degrade_to_fallback()
+        assert store._degraded
+        assert store._fallback is not None
+
+        async def _run():
+            # Mock the async Redis client's ping to succeed.
+            mock_redis = AsyncMock()
+            mock_redis.ping = AsyncMock(return_value=True)
+            with patch.object(store, "_get_redis", return_value=mock_redis):
+                # Manually trigger one iteration of the health check loop.
+                # We set a very short interval and run one cycle.
+                store.HEALTH_CHECK_INTERVAL = 0.01
+                task = asyncio.create_task(store._health_check_loop())
+                # Wait several nominal intervals so the loop gets a chance to
+                # ping before it is cancelled (wall-clock based).
+                await asyncio.sleep(0.05)
+                task.cancel()
+                # Wait for the task to finish; a CancelledError from the
+                # cancellation is expected and ignored.
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
+        asyncio.run(_run())
+
+        assert not store._degraded
+        assert store._fallback is None
+
+    def test_health_check_keeps_degraded_on_failure(self):
+        """When Redis is still down, degraded state persists."""
+        import asyncio
+        from unittest.mock import AsyncMock, patch
+
+        store = RedisUsageStore(redis_url="redis://localhost:6379")
+        store._degrade_to_fallback()
+
+        async def _run():
+            mock_redis = AsyncMock()
+            # Every ping fails, simulating a Redis that has not come back.
+            mock_redis.ping = AsyncMock(side_effect=ConnectionError("still down"))
+            with patch.object(store, "_get_redis", return_value=mock_redis):
+                # Same short-interval / sleep / cancel sequence as the
+                # recovery test above in this class.
+                store.HEALTH_CHECK_INTERVAL = 0.01
+                task = asyncio.create_task(store._health_check_loop())
+                await asyncio.sleep(0.05)
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
+        asyncio.run(_run())
+
+        assert store._degraded  # Still degraded
+
+    def test_aclose_cancels_health_check_task(self):
+        """aclose() cancels the health check task."""
+        import asyncio
+
+        store = RedisUsageStore(redis_url="redis://localhost:6379")
+
+        async def _run():
+            # Degrade inside the coroutine: the assertion below expects this
+            # call to have created the health check task.
+            # NOTE(review): presumably task creation needs a running event
+            # loop — confirm against _degrade_to_fallback.
+            store._degrade_to_fallback()
+            assert store._health_check_task is not None
+            await store.aclose()
+            assert store._health_check_task is None
+
+        asyncio.run(_run())
+
+
+# ---------------------------------------------------------------------------
+# U1: Fail-closed — get_usage raises when Redis query fails
+# ---------------------------------------------------------------------------
+
+
+class TestRedisUsageStoreFailClosed:
+    """U1 (KTD-1): Redis query failure → UsageStoreUnavailableError (not empty)."""
+
+    def test_get_usage_raises_on_redis_connection_failure(self):
+        """When Redis connection fails during get_usage, raise (fail-closed)."""
+        from unittest.mock import MagicMock
+
+        mock_redis = MagicMock()
+        mock_redis.scan_iter.side_effect = ConnectionError("Redis gone")
+
+        store = RedisUsageStore(redis_url="redis://localhost:6379")
+        # Inject the failing client through the private _sync_redis attribute.
+        store._sync_redis = mock_redis
+
+        day = datetime(2026, 6, 21, tzinfo=timezone.utc)
+        with pytest.raises(UsageStoreUnavailableError):
+            store.get_usage(department_id="d1", start_time=day, end_time=day + timedelta(days=1))
+
+        # Should also degrade for future calls.
+        assert store._degraded
+
+    def test_get_usage_fail_closed_not_empty_summary(self):
+        """Critical: must NOT return empty summary on Redis failure (that's fail-open)."""
+        from unittest.mock import MagicMock
+
+        mock_redis = MagicMock()
+        mock_redis.scan_iter.side_effect = ConnectionError("Redis gone")
+
+        store = RedisUsageStore(redis_url="redis://localhost:6379")
+        store._sync_redis = mock_redis
+
+        day = datetime(2026, 6, 21, tzinfo=timezone.utc)
+        # Must raise, not return UsageSummary().
+        with pytest.raises(UsageStoreUnavailableError):
+            store.get_usage(start_time=day, end_time=day + timedelta(days=1))
+
+
+# ---------------------------------------------------------------------------
+# U1: Async record — record_async uses redis.asyncio
+# ---------------------------------------------------------------------------
+
+
+class TestRedisUsageStoreAsyncRecord:
+    """U1 (KTD-6): record_async uses async Redis client, doesn't block event loop."""
+
+    def test_record_async_writes_to_redis(self):
+        """record_async should write to Redis via async client."""
+        import asyncio
+        from unittest.mock import AsyncMock, MagicMock
+
+        store = RedisUsageStore(redis_url="redis://localhost:6379")
+        # The client and pipeline are plain MagicMocks; only execute() is an
+        # AsyncMock, so it is the one pipeline call that can be awaited.
+        mock_pipeline = MagicMock()
+        mock_pipeline.execute = AsyncMock(return_value=[True])
+        mock_redis = MagicMock()
+        mock_redis.pipeline.return_value = mock_pipeline
+
+        async def _run():
+            from unittest.mock import patch
+
+            # Replace the store's client factory so no real connection is made.
+            with patch.object(store, "_get_redis", return_value=mock_redis):
+                usage = TokenUsage(prompt_tokens=100, completion_tokens=50)
+                await store.record_async(
+                    "agent1",
+                    "gpt-4",
+                    usage,
+                    cost=0.05,
+                    latency_ms=200,
+                    user_id="u1",
+                    department_id="d1",
+                )
+
+        asyncio.run(_run())
+
+        # Verify pipeline was used.
+        mock_redis.pipeline.assert_called_once()
+        mock_pipeline.hincrbyfloat.assert_called_once()
+        mock_pipeline.rpush.assert_called_once()