"""Usage Store — persistent usage tracking with a Redis Hash backend.

Provides the ``UsageStore`` Protocol with ``InMemoryUsageStore`` and
``RedisUsageStore`` backends. Replaces the in-memory list in UsageTracker
with a pluggable store that survives restarts and supports multi-instance
deployment.

Key schema (Redis), one pair of keys per UTC date:
    agentkit:usage:{date}
        Hash: {"{agent_name}:{model}:{metric}" → number}, where metric is one
        of cost, prompt_tokens, completion_tokens, total_tokens, count
    agentkit:usage_records:{date}
        List: JSON(UsageRecord), capped with LTRIM
"""

import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, Protocol, runtime_checkable

from agentkit.llm.protocol import TokenUsage

logger = logging.getLogger(__name__)


@dataclass
class UsageRecord:
    """A single recorded usage event for one agent/model pair."""

    agent_name: str
    model: str
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    cost: float
    latency_ms: float
    # ISO 8601 string (kept as str for JSON serialization); an empty value is
    # replaced with the current UTC time in __post_init__.
    timestamp: str = ""

    def __post_init__(self):
        if not self.timestamp:
            self.timestamp = datetime.now(timezone.utc).isoformat()


@dataclass
class UsageBucket:
    """Aggregated usage for an agent+model pair on a given date.

    The fields mirror the per-metric counters that ``RedisUsageStore.record``
    increments in the daily usage hash.
    """

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    cost: float = 0.0
    count: int = 0


@dataclass
class UsageSummary:
    """Usage totals computed over a set of records."""

    total_tokens: int = 0
    total_cost: float = 0.0
    # model -> {"total_tokens": ..., "total_cost": ..., "count": ...}
    by_model: dict[str, dict[str, int | float]] = field(default_factory=dict)
    # The records the totals were computed from, in storage order.
    records: list[UsageRecord] = field(default_factory=list)


def _ensure_utc(dt: datetime) -> datetime:
    """Return *dt* as an aware datetime, treating naive values as UTC.

    Stored timestamps are timezone-aware UTC, and Python refuses to compare
    aware and naive datetimes, so query bounds are normalized first.
    """
    return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt


# ---------------------------------------------------------------------------
# UsageStore Protocol
# ---------------------------------------------------------------------------


@runtime_checkable
class UsageStore(Protocol):
    """Persistent usage store interface."""

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Record a usage event."""
        ...

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Query usage summary."""
        ...


# ---------------------------------------------------------------------------
# InMemoryUsageStore
# ---------------------------------------------------------------------------


class InMemoryUsageStore:
    """In-memory usage store (drop-in replacement for the old UsageTracker).

    Keeps only the most recent ``MAX_RECORDS`` records; nothing survives a
    process restart.
    """

    MAX_RECORDS = 10000

    def __init__(self):
        self._records: list[UsageRecord] = []

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Append a usage record, discarding the oldest beyond ``MAX_RECORDS``."""
        self._records.append(
            UsageRecord(
                agent_name=agent_name,
                model=model,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                cost=cost,
                latency_ms=latency_ms,
            )
        )
        overflow = len(self._records) - self.MAX_RECORDS
        if overflow > 0:
            # Trim in place instead of rebuilding the whole list each call.
            del self._records[:overflow]

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Summarize stored records matching the given filters.

        Args:
            agent_name: Only include records for this agent, if given.
            start_time: Inclusive lower bound on the record timestamp.
                Naive datetimes are treated as UTC.
            end_time: Inclusive upper bound on the record timestamp.
                Naive datetimes are treated as UTC.

        Returns:
            A UsageSummary whose ``records`` is a new list, so callers cannot
            mutate the store through it. Empty if nothing matches.
        """
        start = _ensure_utc(start_time) if start_time is not None else None
        end = _ensure_utc(end_time) if end_time is not None else None

        filtered: list[UsageRecord] = []
        for rec in self._records:
            if agent_name is not None and rec.agent_name != agent_name:
                continue
            if start is not None or end is not None:
                ts = _ensure_utc(datetime.fromisoformat(rec.timestamp))
                if start is not None and ts < start:
                    continue
                if end is not None and ts > end:
                    continue
            filtered.append(rec)

        if not filtered:
            return UsageSummary()

        by_model: dict[str, dict[str, int | float]] = {}
        for rec in filtered:
            bucket = by_model.setdefault(
                rec.model, {"total_tokens": 0, "total_cost": 0.0, "count": 0}
            )
            bucket["total_tokens"] += rec.total_tokens
            bucket["total_cost"] += rec.cost
            bucket["count"] += 1

        return UsageSummary(
            total_tokens=sum(rec.total_tokens for rec in filtered),
            total_cost=sum(rec.cost for rec in filtered),
            by_model=by_model,
            records=filtered,
        )
# ---------------------------------------------------------------------------
# RedisUsageStore
# ---------------------------------------------------------------------------


class RedisUsageStore:
    """Redis-backed usage store using one Hash per UTC date for O(1) writes.

    Key schema:
        agentkit:usage:{YYYY-MM-DD}
            Hash: {"{agent}:{model}:{metric}" → number}, where metric is one
            of cost, prompt_tokens, completion_tokens, total_tokens, count
        agentkit:usage_records:{YYYY-MM-DD}
            List: JSON(UsageRecord), capped at MAX_RECORDS_PER_DAY with LTRIM

    Both keys get their TTL refreshed on every write. After the first Redis
    failure the store switches to an in-memory fallback and stays there.
    """

    USAGE_PREFIX = "agentkit:usage:"
    RECORDS_PREFIX = "agentkit:usage_records:"
    MAX_RECORDS_PER_DAY = 50000
    TTL_DAYS = 90  # Auto-expire after 90 days
    QUERY_BATCH_DAYS = 31  # Daily record lists fetched per pipeline round trip

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self._redis_url = redis_url
        self._redis: Any = None
        self._sync_redis: Any = None
        self._fallback: InMemoryUsageStore | None = None
        self._degraded = False

    async def _get_redis(self):
        """Get or lazily create the async Redis client."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    def _get_sync_redis(self):
        """Get or create a persistent sync Redis client (connection pool backed)."""
        if self._sync_redis is None:
            import redis as sync_redis

            self._sync_redis = sync_redis.from_url(
                self._redis_url, decode_responses=True
            )
        return self._sync_redis

    async def aclose(self) -> None:
        """Close both clients; the sync client is closed even if the async close fails."""
        try:
            if self._redis is not None:
                await self._redis.aclose()
                self._redis = None
        finally:
            if self._sync_redis is not None:
                self._sync_redis.close()
                self._sync_redis = None

    def _degrade_to_fallback(self) -> None:
        """Switch to the in-memory fallback store (logged once)."""
        if not self._degraded:
            self._degraded = True
            if self._fallback is None:
                self._fallback = InMemoryUsageStore()
            logger.warning("Redis usage store unreachable, degraded to in-memory")

    def _today_key(self) -> str:
        """Return today's UTC date as YYYY-MM-DD (the suffix of the daily keys)."""
        return datetime.now(timezone.utc).strftime("%Y-%m-%d")

    @staticmethod
    def _to_utc(dt: datetime) -> datetime:
        """Convert *dt* to UTC; naive datetimes are assumed to already be UTC."""
        if dt.tzinfo is None:
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    def record(
        self,
        agent_name: str,
        model: str,
        usage: TokenUsage,
        cost: float,
        latency_ms: float,
    ) -> None:
        """Record usage — sync wrapper for Redis.

        This is a sync method because UsageTracker.record() is sync. For
        Redis, a sync client is used for writes so the caller does not need
        an event loop. All writes for one event go through one pipeline.
        On any failure the store degrades to its in-memory fallback and the
        event is recorded there instead.
        """
        if self._degraded and self._fallback is not None:
            self._fallback.record(agent_name, model, usage, cost, latency_ms)
            return

        try:
            client = self._get_sync_redis()
            # Use one clock reading for both the timestamp and the date key,
            # so an event written around midnight is filed under the day its
            # timestamp falls on (get_usage locates records by that date).
            now = datetime.now(timezone.utc)
            date_key = now.strftime("%Y-%m-%d")
            hash_key = f"{self.USAGE_PREFIX}{date_key}"
            list_key = f"{self.RECORDS_PREFIX}{date_key}"
            bucket_field = f"{agent_name}:{model}"

            rec = UsageRecord(
                agent_name=agent_name,
                model=model,
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
                total_tokens=usage.total_tokens,
                cost=cost,
                latency_ms=latency_ms,
                timestamp=now.isoformat(),
            )

            pipe = client.pipeline()
            # Per-metric counters for the agent:model bucket.
            pipe.hincrbyfloat(hash_key, f"{bucket_field}:cost", cost)
            pipe.hincrby(hash_key, f"{bucket_field}:prompt_tokens", usage.prompt_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:completion_tokens", usage.completion_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:total_tokens", usage.total_tokens)
            pipe.hincrby(hash_key, f"{bucket_field}:count", 1)

            # Append the raw record, keeping only the newest MAX_RECORDS_PER_DAY.
            pipe.rpush(list_key, json.dumps({
                "agent_name": rec.agent_name,
                "model": rec.model,
                "prompt_tokens": rec.prompt_tokens,
                "completion_tokens": rec.completion_tokens,
                "total_tokens": rec.total_tokens,
                "cost": rec.cost,
                "latency_ms": rec.latency_ms,
                "timestamp": rec.timestamp,
            }))
            pipe.ltrim(list_key, -self.MAX_RECORDS_PER_DAY, -1)

            # The TTL is refreshed on every write, so each daily key expires
            # TTL_DAYS after that day's last write.
            ttl_seconds = self.TTL_DAYS * 86400
            pipe.expire(hash_key, ttl_seconds)
            pipe.expire(list_key, ttl_seconds)
            pipe.execute()
        except Exception as e:
            logger.warning("Redis usage record failed: %s", e)
            self._degrade_to_fallback()
            if self._fallback is not None:
                self._fallback.record(agent_name, model, usage, cost, latency_ms)

    def _fetch_records(
        self,
        agent_name: str | None,
        start: datetime,
        end: datetime,
    ) -> list[UsageRecord]:
        """Load records with ``start <= timestamp <= end`` (both aware UTC)."""
        client = self._get_sync_redis()

        # Keys are named by UTC date, so the dates come from UTC datetimes.
        keys: list[str] = []
        day, last_day = start.date(), end.date()
        while day <= last_day:
            keys.append(f"{self.RECORDS_PREFIX}{day.isoformat()}")
            day += timedelta(days=1)

        matched: list[UsageRecord] = []
        skipped = 0
        batch = max(1, self.QUERY_BATCH_DAYS)
        # Batch LRANGEs so a long range costs a few round trips, not one per day.
        for offset in range(0, len(keys), batch):
            pipe = client.pipeline(transaction=False)
            for key in keys[offset:offset + batch]:
                pipe.lrange(key, 0, -1)
            for raw_records in pipe.execute():
                for raw in raw_records:
                    try:
                        rec = UsageRecord(**json.loads(raw))
                        ts = self._to_utc(datetime.fromisoformat(rec.timestamp))
                    except (TypeError, ValueError):
                        # One corrupt entry should not fail the whole query.
                        skipped += 1
                        continue
                    if not (start <= ts <= end):
                        continue
                    if agent_name is not None and rec.agent_name != agent_name:
                        continue
                    matched.append(rec)

        if skipped:
            logger.warning("Skipped %d malformed usage records", skipped)
        return matched

    @staticmethod
    def _summarize(records: list[UsageRecord]) -> UsageSummary:
        """Aggregate *records* into totals plus a per-model breakdown."""
        if not records:
            return UsageSummary()

        by_model: dict[str, dict[str, int | float]] = {}
        for rec in records:
            bucket = by_model.setdefault(
                rec.model, {"total_tokens": 0, "total_cost": 0.0, "count": 0}
            )
            bucket["total_tokens"] += rec.total_tokens
            bucket["total_cost"] += rec.cost
            bucket["count"] += 1

        return UsageSummary(
            total_tokens=sum(rec.total_tokens for rec in records),
            total_cost=sum(rec.cost for rec in records),
            by_model=by_model,
            records=records,
        )

    def get_usage(
        self,
        agent_name: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> UsageSummary:
        """Query a usage summary from the per-day record lists in Redis.

        Args:
            agent_name: Only include records for this agent, if given.
            start_time: Inclusive lower bound (default 2020-01-01 UTC).
                Naive datetimes are treated as UTC.
            end_time: Inclusive upper bound (default: now). Naive datetimes
                are treated as UTC.

        Returns:
            The summary of matching records. If Redis fails, returns the
            fallback store's summary when a fallback exists, else an empty
            summary.
        """
        if self._degraded and self._fallback is not None:
            return self._fallback.get_usage(agent_name, start_time, end_time)

        try:
            start = (
                self._to_utc(start_time)
                if start_time is not None
                else datetime(2020, 1, 1, tzinfo=timezone.utc)
            )
            end = (
                self._to_utc(end_time)
                if end_time is not None
                else datetime.now(timezone.utc)
            )
            records = self._fetch_records(agent_name, start, end)
        except Exception as e:
            logger.warning("Redis usage query failed: %s", e)
            if self._fallback is not None:
                return self._fallback.get_usage(agent_name, start_time, end_time)
            return UsageSummary()

        return self._summarize(records)


# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------


def create_usage_store(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
) -> UsageStore:
    """Create a usage store backend.

    Args:
        backend: "auto" (try Redis, fallback to memory), "redis", "memory".
            "auto" and "redis" both return a RedisUsageStore whenever the
            ``redis`` package is importable; connectivity is not checked here
            (the store degrades to in-memory on its first Redis failure).
            Unknown values log a warning and use the in-memory store.
        redis_url: Redis connection URL.

    Returns:
        A UsageStore instance.
    """
    if backend not in ("auto", "redis", "memory"):
        logger.warning("Unknown usage store backend %r, using in-memory store", backend)

    if backend in ("auto", "redis"):
        try:
            import redis  # noqa: F401
        except ImportError:
            logger.warning("redis package not available, falling back to in-memory usage store")
        else:
            return RedisUsageStore(redis_url=redis_url)

    return InMemoryUsageStore()